mirror of
https://github.com/hiyouga/LlamaFactory.git
synced 2026-03-29 02:33:09 +00:00
[v1] add init on rank0 for fsdp2 (#10264)
This commit is contained in:
@@ -140,6 +140,9 @@ class ModelEngine:
|
||||
**init_kwargs,
|
||||
)
|
||||
|
||||
init_mode = self.args.init_config.name if self.args.init_config is not None else "init_on_default"
|
||||
model._init_mode = init_mode
|
||||
|
||||
if self.args.peft_config is None:
|
||||
if self.is_train:
|
||||
logger.info_rank0("Fine-tuning mode: full tuning")
|
||||
@@ -147,6 +150,9 @@ class ModelEngine:
|
||||
else:
|
||||
logger.info_rank0("Inference the original model")
|
||||
else:
|
||||
if self.args.peft_config.name == "lora" and init_mode == "init_on_meta":
|
||||
raise ValueError("Currently lora stage does not support loading model by meta.")
|
||||
|
||||
from ..plugins.model_plugins.peft import PeftPlugin
|
||||
|
||||
model = PeftPlugin(self.args.peft_config.name)(model, self.args.peft_config, self.is_train)
|
||||
|
||||
@@ -150,9 +150,6 @@ def load_adapter(model: HFModel, adapter_name_or_path: Union[list[str], str], is
|
||||
|
||||
@PeftPlugin("lora").register()
|
||||
def get_lora_model(model: HFModel, config: LoraConfigDict, is_train: bool = False) -> HFModel:
|
||||
if model.device.type == "meta":
|
||||
raise ValueError("Currently lora stage does not support loading model by meta.")
|
||||
|
||||
adapter_name_or_path = config.get("adapter_name_or_path")
|
||||
|
||||
if adapter_name_or_path:
|
||||
|
||||
@@ -17,6 +17,7 @@ import gc
|
||||
import os
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
from peft.tuners.lora import LoraLayer
|
||||
from torch.distributed.checkpoint.state_dict import StateDictOptions, get_model_state_dict, set_model_state_dict
|
||||
@@ -244,23 +245,57 @@ class FSDP2Engine:
|
||||
logger.info(f"Restored {len(saved_buffers)} non-persistent buffers")
|
||||
|
||||
def shard_model(self, model: HFModel) -> HFModel:
|
||||
if model.device.type == "meta":
|
||||
init_mode = getattr(model, "_init_mode", "init_on_default")
|
||||
|
||||
if init_mode == "init_on_rank0":
|
||||
if getattr(model.config, "tie_word_embeddings", False):
|
||||
model.tie_weights()
|
||||
|
||||
if self.rank == 0:
|
||||
logger.info("init_on_rank0 detected: sharding then scattering Rank 0 CPU weights.")
|
||||
full_sd = {k: v.clone() for k, v in model.state_dict().items()}
|
||||
else:
|
||||
full_sd = {}
|
||||
|
||||
# Reuse existing helper to save persistent=False buffers (e.g. inv_freq) before shard
|
||||
saved_buffers = self._save_non_persistent_buffers(model) if self.rank == 0 else {}
|
||||
|
||||
model = self.prepare_model(model)
|
||||
|
||||
device = get_current_accelerator()
|
||||
model.to_empty(device=device)
|
||||
|
||||
# Scatter params from Rank 0 into all DTensor shards
|
||||
# Broadcast the full state dict from the global rank-0 process to all ranks in this group.
|
||||
options = StateDictOptions(full_state_dict=True, cpu_offload=True, broadcast_from_rank0=True)
|
||||
set_model_state_dict(model, full_sd, options=options)
|
||||
|
||||
# Broadcast and restore non-persistent buffers
|
||||
buffers_to_sync = [saved_buffers]
|
||||
dist.broadcast_object_list(buffers_to_sync, src=0, group=self.fsdp_mesh.get_group())
|
||||
self._restore_non_persistent_buffers(model, buffers_to_sync[0])
|
||||
|
||||
if self.rank == 0:
|
||||
logger.info("init_on_rank0 sync complete.")
|
||||
|
||||
elif init_mode == "init_on_meta":
|
||||
non_persistent_buffers = self._save_non_persistent_buffers(model)
|
||||
|
||||
if getattr(model.config, "tie_word_embeddings", None):
|
||||
if getattr(model.config, "tie_word_embeddings", False):
|
||||
model.tie_weights()
|
||||
|
||||
model = self.prepare_model(model)
|
||||
model = self.materialize_and_load(model, hf_model_path=model.config.name_or_path, dcp_path=self.dcp_path)
|
||||
|
||||
# fix tied broken for no-fsdp-wrap case
|
||||
if getattr(model.config, "tie_word_embeddings", None):
|
||||
if getattr(model.config, "tie_word_embeddings", False):
|
||||
model.tie_weights()
|
||||
|
||||
self._restore_non_persistent_buffers(model, non_persistent_buffers)
|
||||
|
||||
else:
|
||||
model = self.prepare_model(model)
|
||||
|
||||
return model
|
||||
|
||||
def _load_from_dcp(self, model: HFModel, dcp_path: str):
|
||||
|
||||
Reference in New Issue
Block a user