diff --git a/examples/v1/train_freeze/train_freeze_sft.yaml b/examples/v1/train_freeze/train_freeze_sft.yaml index 84bb66427..29233d8e1 100644 --- a/examples/v1/train_freeze/train_freeze_sft.yaml +++ b/examples/v1/train_freeze/train_freeze_sft.yaml @@ -1,5 +1,4 @@ model: Qwen/Qwen3-4B -trust_remote_code: true model_class: llm template: qwen3_nothink diff --git a/examples/v1/train_full/train_full_deepspeed.yaml b/examples/v1/train_full/train_full_deepspeed.yaml index 2b9a6642e..0a9851147 100644 --- a/examples/v1/train_full/train_full_deepspeed.yaml +++ b/examples/v1/train_full/train_full_deepspeed.yaml @@ -1,5 +1,4 @@ model: Qwen/Qwen3-0.6B - model_class: llm template: qwen3_nothink diff --git a/examples/v1/train_full/train_full_fsdp2.yaml b/examples/v1/train_full/train_full_fsdp2.yaml index 57ac6a1f3..1378ec30b 100644 --- a/examples/v1/train_full/train_full_fsdp2.yaml +++ b/examples/v1/train_full/train_full_fsdp2.yaml @@ -1,5 +1,4 @@ model: Qwen/Qwen3-0.6B -trust_remote_code: true model_class: llm template: qwen3_nothink diff --git a/examples/v1/train_lora/train_lora_sft.yaml b/examples/v1/train_lora/train_lora_sft.yaml index e1f160c51..653b1df7f 100644 --- a/examples/v1/train_lora/train_lora_sft.yaml +++ b/examples/v1/train_lora/train_lora_sft.yaml @@ -1,5 +1,4 @@ model: Qwen/Qwen3-4B -trust_remote_code: true model_class: llm template: qwen3_nothink @@ -28,7 +27,6 @@ train_dataset: data/v1_sft_demo.yaml ### training output_dir: ./outputs/test_lora micro_batch_size: 1 -global_batch_size: 4 cutoff_len: 2048 learning_rate: 1.0e-4 bf16: true diff --git a/examples/v1/train_lora/train_lora_sft_rank0.yaml b/examples/v1/train_lora/train_lora_sft_rank0.yaml new file mode 100644 index 000000000..363d6eb60 --- /dev/null +++ b/examples/v1/train_lora/train_lora_sft_rank0.yaml @@ -0,0 +1,40 @@ +model: Qwen/Qwen3-4B +model_class: llm + +template: qwen3_nothink + +# PEFT Configuration +peft_config: + name: lora + r: 16 + lora_alpha: 32 + lora_dropout: 0.05 + target_modules: all + +# Kernel Config +kernel_config: + name: auto + include_kernels: auto + +# FSDP Config +dist_config: + name: fsdp2 + dcp_path: null + +init_config: + name: init_on_rank0 + +### data +train_dataset: data/v1_sft_demo.yaml + +### training +output_dir: ./outputs/test_lora +micro_batch_size: 1 +cutoff_len: 2048 +learning_rate: 1.0e-4 +bf16: true +max_steps: 10 + +### sample +sample_backend: hf +max_new_tokens: 128 diff --git a/examples/v1/train_qlora/quantization.yaml b/examples/v1/train_qlora/quantization.yaml index a063b207c..6edc9745f 100644 --- a/examples/v1/train_qlora/quantization.yaml +++ b/examples/v1/train_qlora/quantization.yaml @@ -1,5 +1,4 @@ model: Qwen/Qwen3-0.6B -trust_remote_code: true model_class: llm template: qwen3_nothink diff --git a/src/llamafactory/v1/core/model_engine.py b/src/llamafactory/v1/core/model_engine.py index c08522402..fbdbd6b0e 100644 --- a/src/llamafactory/v1/core/model_engine.py +++ b/src/llamafactory/v1/core/model_engine.py @@ -140,6 +140,9 @@ class ModelEngine: **init_kwargs, ) + init_mode = self.args.init_config.name if self.args.init_config is not None else "init_on_default" + model._init_mode = init_mode + if self.args.peft_config is None: if self.is_train: logger.info_rank0("Fine-tuning mode: full tuning") @@ -147,6 +150,9 @@ class ModelEngine: else: logger.info_rank0("Inference the original model") else: + if self.args.peft_config.name == "lora" and init_mode == "init_on_meta": + raise ValueError("Currently lora stage does not support loading model by meta.") + from ..plugins.model_plugins.peft import PeftPlugin model = PeftPlugin(self.args.peft_config.name)(model, self.args.peft_config, self.is_train) diff --git a/src/llamafactory/v1/plugins/model_plugins/peft.py b/src/llamafactory/v1/plugins/model_plugins/peft.py index 17ff3779e..2ef2035e1 100644 --- a/src/llamafactory/v1/plugins/model_plugins/peft.py +++ b/src/llamafactory/v1/plugins/model_plugins/peft.py @@ -150,9 +150,6 @@ def load_adapter(model: HFModel, adapter_name_or_path: Union[list[str], str], is @PeftPlugin("lora").register() def get_lora_model(model: HFModel, config: LoraConfigDict, is_train: bool = False) -> HFModel: - if model.device.type == "meta": - raise ValueError("Currently lora stage does not support loading model by meta.") - adapter_name_or_path = config.get("adapter_name_or_path") if adapter_name_or_path: diff --git a/src/llamafactory/v1/plugins/trainer_plugins/distributed/fsdp2.py b/src/llamafactory/v1/plugins/trainer_plugins/distributed/fsdp2.py index 7d4fac3cc..bf6b09b87 100644 --- a/src/llamafactory/v1/plugins/trainer_plugins/distributed/fsdp2.py +++ b/src/llamafactory/v1/plugins/trainer_plugins/distributed/fsdp2.py @@ -17,6 +17,7 @@ import gc import os import torch +import torch.distributed as dist import torch.nn as nn from peft.tuners.lora import LoraLayer from torch.distributed.checkpoint.state_dict import StateDictOptions, get_model_state_dict, set_model_state_dict @@ -244,23 +245,57 @@ class FSDP2Engine: logger.info(f"Restored {len(saved_buffers)} non-persistent buffers") def shard_model(self, model: HFModel) -> HFModel: - if model.device.type == "meta": + init_mode = getattr(model, "_init_mode", "init_on_default") + + if init_mode == "init_on_rank0": + if getattr(model.config, "tie_word_embeddings", False): + model.tie_weights() + + if self.rank == 0: + logger.info("init_on_rank0 detected: sharding then scattering Rank 0 CPU weights.") + full_sd = {k: v.clone() for k, v in model.state_dict().items()} + else: + full_sd = {} + + # Reuse existing helper to save persistent=False buffers (e.g. inv_freq) before shard + saved_buffers = self._save_non_persistent_buffers(model) if self.rank == 0 else {} + + model = self.prepare_model(model) + + device = get_current_accelerator() + model.to_empty(device=device) + + # Scatter params from Rank 0 into all DTensor shards + # Broadcast the full state dict from the global rank-0 process to all ranks in this group. + options = StateDictOptions(full_state_dict=True, cpu_offload=True, broadcast_from_rank0=True) + set_model_state_dict(model, full_sd, options=options) + + # Broadcast and restore non-persistent buffers + buffers_to_sync = [saved_buffers] + dist.broadcast_object_list(buffers_to_sync, src=0, group=self.fsdp_mesh.get_group()) + self._restore_non_persistent_buffers(model, buffers_to_sync[0]) + + if self.rank == 0: + logger.info("init_on_rank0 sync complete.") + + elif init_mode == "init_on_meta": non_persistent_buffers = self._save_non_persistent_buffers(model) - if getattr(model.config, "tie_word_embeddings", None): + if getattr(model.config, "tie_word_embeddings", False): model.tie_weights() model = self.prepare_model(model) model = self.materialize_and_load(model, hf_model_path=model.config.name_or_path, dcp_path=self.dcp_path) # fix tied broken for no-fsdp-wrap case - if getattr(model.config, "tie_word_embeddings", None): + if getattr(model.config, "tie_word_embeddings", False): model.tie_weights() self._restore_non_persistent_buffers(model, non_persistent_buffers) else: model = self.prepare_model(model) + return model def _load_from_dcp(self, model: HFModel, dcp_path: str):