mirror of
https://github.com/hiyouga/LlamaFactory.git
synced 2026-03-28 14:23:08 +00:00
[v1] add init on rank0 for fsdp2 (#10264)
This commit is contained in:
@@ -1,5 +1,4 @@
|
||||
model: Qwen/Qwen3-4B
|
||||
trust_remote_code: true
|
||||
model_class: llm
|
||||
|
||||
template: qwen3_nothink
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
model: Qwen/Qwen3-0.6B
|
||||
|
||||
model_class: llm
|
||||
|
||||
template: qwen3_nothink
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
model: Qwen/Qwen3-0.6B
|
||||
trust_remote_code: true
|
||||
model_class: llm
|
||||
|
||||
template: qwen3_nothink
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
model: Qwen/Qwen3-4B
|
||||
trust_remote_code: true
|
||||
model_class: llm
|
||||
|
||||
template: qwen3_nothink
|
||||
@@ -28,7 +27,6 @@ train_dataset: data/v1_sft_demo.yaml
|
||||
### training
|
||||
output_dir: ./outputs/test_lora
|
||||
micro_batch_size: 1
|
||||
global_batch_size: 4
|
||||
cutoff_len: 2048
|
||||
learning_rate: 1.0e-4
|
||||
bf16: true
|
||||
|
||||
40
examples/v1/train_lora/train_lora_sft_rank0.yaml
Normal file
40
examples/v1/train_lora/train_lora_sft_rank0.yaml
Normal file
@@ -0,0 +1,40 @@
|
||||
model: Qwen/Qwen3-4B
|
||||
model_class: llm
|
||||
|
||||
template: qwen3_nothink
|
||||
|
||||
# PEFT Configuration
|
||||
peft_config:
|
||||
name: lora
|
||||
r: 16
|
||||
lora_alpha: 32
|
||||
lora_dropout: 0.05
|
||||
target_modules: all
|
||||
|
||||
# Kernel Config
|
||||
kernel_config:
|
||||
name: auto
|
||||
include_kernels: auto
|
||||
|
||||
# FSDP Config
|
||||
dist_config:
|
||||
name: fsdp2
|
||||
dcp_path: null
|
||||
|
||||
init_config:
|
||||
name: init_on_rank0
|
||||
|
||||
### data
|
||||
train_dataset: data/v1_sft_demo.yaml
|
||||
|
||||
### training
|
||||
output_dir: ./outputs/test_lora
|
||||
micro_batch_size: 1
|
||||
cutoff_len: 2048
|
||||
learning_rate: 1.0e-4
|
||||
bf16: true
|
||||
max_steps: 10
|
||||
|
||||
### sample
|
||||
sample_backend: hf
|
||||
max_new_tokens: 128
|
||||
@@ -1,5 +1,4 @@
|
||||
model: Qwen/Qwen3-0.6B
|
||||
trust_remote_code: true
|
||||
model_class: llm
|
||||
|
||||
template: qwen3_nothink
|
||||
|
||||
@@ -140,6 +140,9 @@ class ModelEngine:
|
||||
**init_kwargs,
|
||||
)
|
||||
|
||||
init_mode = self.args.init_config.name if self.args.init_config is not None else "init_on_default"
|
||||
model._init_mode = init_mode
|
||||
|
||||
if self.args.peft_config is None:
|
||||
if self.is_train:
|
||||
logger.info_rank0("Fine-tuning mode: full tuning")
|
||||
@@ -147,6 +150,9 @@ class ModelEngine:
|
||||
else:
|
||||
logger.info_rank0("Inference the original model")
|
||||
else:
|
||||
if self.args.peft_config.name == "lora" and init_mode == "init_on_meta":
|
||||
raise ValueError("Currently lora stage does not support loading model by meta.")
|
||||
|
||||
from ..plugins.model_plugins.peft import PeftPlugin
|
||||
|
||||
model = PeftPlugin(self.args.peft_config.name)(model, self.args.peft_config, self.is_train)
|
||||
|
||||
@@ -150,9 +150,6 @@ def load_adapter(model: HFModel, adapter_name_or_path: Union[list[str], str], is
|
||||
|
||||
@PeftPlugin("lora").register()
|
||||
def get_lora_model(model: HFModel, config: LoraConfigDict, is_train: bool = False) -> HFModel:
|
||||
if model.device.type == "meta":
|
||||
raise ValueError("Currently lora stage does not support loading model by meta.")
|
||||
|
||||
adapter_name_or_path = config.get("adapter_name_or_path")
|
||||
|
||||
if adapter_name_or_path:
|
||||
|
||||
@@ -17,6 +17,7 @@ import gc
|
||||
import os
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
from peft.tuners.lora import LoraLayer
|
||||
from torch.distributed.checkpoint.state_dict import StateDictOptions, get_model_state_dict, set_model_state_dict
|
||||
@@ -244,23 +245,57 @@ class FSDP2Engine:
|
||||
logger.info(f"Restored {len(saved_buffers)} non-persistent buffers")
|
||||
|
||||
def shard_model(self, model: HFModel) -> HFModel:
|
||||
if model.device.type == "meta":
|
||||
init_mode = getattr(model, "_init_mode", "init_on_default")
|
||||
|
||||
if init_mode == "init_on_rank0":
|
||||
if getattr(model.config, "tie_word_embeddings", False):
|
||||
model.tie_weights()
|
||||
|
||||
if self.rank == 0:
|
||||
logger.info("init_on_rank0 detected: sharding then scattering Rank 0 CPU weights.")
|
||||
full_sd = {k: v.clone() for k, v in model.state_dict().items()}
|
||||
else:
|
||||
full_sd = {}
|
||||
|
||||
# Reuse the existing helper to save persistent=False buffers (e.g. inv_freq) before sharding
|
||||
saved_buffers = self._save_non_persistent_buffers(model) if self.rank == 0 else {}
|
||||
|
||||
model = self.prepare_model(model)
|
||||
|
||||
device = get_current_accelerator()
|
||||
model.to_empty(device=device)
|
||||
|
||||
# Scatter params from Rank 0 into all DTensor shards
|
||||
# Broadcast the full state dict from the global rank-0 process to all ranks in this group.
|
||||
options = StateDictOptions(full_state_dict=True, cpu_offload=True, broadcast_from_rank0=True)
|
||||
set_model_state_dict(model, full_sd, options=options)
|
||||
|
||||
# Broadcast and restore non-persistent buffers
|
||||
buffers_to_sync = [saved_buffers]
|
||||
dist.broadcast_object_list(buffers_to_sync, src=0, group=self.fsdp_mesh.get_group())
|
||||
self._restore_non_persistent_buffers(model, buffers_to_sync[0])
|
||||
|
||||
if self.rank == 0:
|
||||
logger.info("init_on_rank0 sync complete.")
|
||||
|
||||
elif init_mode == "init_on_meta":
|
||||
non_persistent_buffers = self._save_non_persistent_buffers(model)
|
||||
|
||||
if getattr(model.config, "tie_word_embeddings", None):
|
||||
if getattr(model.config, "tie_word_embeddings", False):
|
||||
model.tie_weights()
|
||||
|
||||
model = self.prepare_model(model)
|
||||
model = self.materialize_and_load(model, hf_model_path=model.config.name_or_path, dcp_path=self.dcp_path)
|
||||
|
||||
# fix broken weight tying in the no-fsdp-wrap case
|
||||
if getattr(model.config, "tie_word_embeddings", None):
|
||||
if getattr(model.config, "tie_word_embeddings", False):
|
||||
model.tie_weights()
|
||||
|
||||
self._restore_non_persistent_buffers(model, non_persistent_buffers)
|
||||
|
||||
else:
|
||||
model = self.prepare_model(model)
|
||||
|
||||
return model
|
||||
|
||||
def _load_from_dcp(self, model: HFModel, dcp_path: str):
|
||||
|
||||
Reference in New Issue
Block a user