[v1] support training with fsdp2 (#9773)

Co-authored-by: frozenleaves <frozen@Mac.local>
Co-authored-by: Yaowei Zheng <hiyouga@buaa.edu.cn>
Authored by 浮梦, committed by GitHub on 2026-01-25 19:41:58 +08:00
parent 641bfdd482
commit f9f11dcb97
15 changed files with 801 additions and 33 deletions


@@ -0,0 +1,34 @@
model: Qwen/Qwen3-0.6B
trust_remote_code: true
model_class: llm
template: qwen3_nothink
kernel_config:
name: auto
include_kernels: auto # choice: null/true/false/auto/kernel_id1,kernel_id2,kernel_id3, default is null
quant_config: null
dist_config:
name: fsdp2
dcp_path: null # /mnt/f/pretrain_models/Qwen3-0.6B-dcp
init_config:
name: init_on_meta
### data
train_dataset: data/v1_sft_demo.yaml
### training
output_dir: outputs/test_fsdp2
micro_batch_size: 1
global_batch_size: 1
cutoff_len: 2048
learning_rate: 1.0e-4
bf16: false
max_steps: 10
### sample
sample_backend: hf
max_new_tokens: 128
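The parser below reads such files with OmegaConf, so the example can be sanity-checked without launching a run; a minimal sketch, assuming the example is saved under a hypothetical path such as examples/v1/qwen3_fsdp2.yaml (the actual file name is not shown in this diff):

from omegaconf import OmegaConf

cfg = OmegaConf.load("examples/v1/qwen3_fsdp2.yaml")  # hypothetical path
print(cfg.dist_config.name)   # fsdp2
print(cfg.init_config.name)   # init_on_meta
print(cfg.max_steps)          # 10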

scripts/hf2dcp.py Normal file

@@ -0,0 +1,55 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert a HuggingFace model to DCP checkpoint format.
Usage:
python scripts/hf2dcp.py convert --hf_path=/path/to/hf --dcp_path=/path/to/dcp
Arguments:
hf_path: Path to the HuggingFace model directory.
dcp_path: Output path (directory) for DCP checkpoint.
"""
import fire
import torch
import torch.distributed.checkpoint as dcp
from transformers import AutoModelForCausalLM
def convert(hf_path: str, dcp_path: str) -> None:
"""Convert HF model weights to DCP.
Args:
hf_path: HuggingFace model directory.
dcp_path: Output path (directory) for DCP checkpoint.
"""
if not hf_path or not dcp_path:
raise ValueError("Both 'hf_path' and 'dcp_path' are required.")
print(f"Loading HF model from {hf_path}...")
model = AutoModelForCausalLM.from_pretrained(hf_path, device_map="cpu", torch_dtype=torch.bfloat16)
print(f"Saving to DCP format at {dcp_path}...")
dcp.save(model.state_dict(), checkpoint_id=dcp_path)
print("Done!")
def help() -> None:
"""Show help message."""
print(__doc__)
if __name__ == "__main__":
fire.Fire({"convert": convert, "help": help, "--convert": convert})


@@ -180,6 +180,16 @@ def operate_tensorlike(fn: Callable[[...], Tensor], data: TensorLike, **kwargs)
return result.tolist()
def get_process_group_backend() -> str:
"""Get backend for init process group."""
if get_current_accelerator().type == DeviceType.NPU:
return "hccl"
elif get_current_accelerator().type == DeviceType.CUDA:
return "nccl"
else:
return "gloo"
def all_gather(tensor: Tensor, group: Optional[ProcessGroup] = None) -> Tensor:
"""Gathers the tensor from all ranks and stacks them at the first dim."""
world_size = get_world_size()
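The backend choice above can be sanity-checked against the local PyTorch build (hccl availability comes from the Ascend torch_npu plugin rather than stock PyTorch); a standalone illustration, not part of the patch:

import torch.distributed as dist

# gloo ships with standard wheels; nccl only with CUDA builds.
print("nccl available:", dist.is_nccl_available())
print("gloo available:", dist.is_gloo_available())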


@@ -145,7 +145,7 @@ class DistributedInterface:
timeout = config.get("timeout", 18000)
if self._is_distributed:
- init_process_group(timeout=timedelta(seconds=timeout))
+ init_process_group(timeout=timedelta(seconds=timeout), backend=helper.get_process_group_backend())
self.model_device_mesh = init_device_mesh(
device_type=self.current_device.type,
mesh_shape=self.strategy.model_mesh_shape,


@@ -20,7 +20,7 @@ from typing import Any
from omegaconf import OmegaConf
from transformers import HfArgumentParser
- from ...extras.misc import is_env_enabled
+ from ..utils.env import is_env_enabled
from .data_args import DataArguments
from .model_args import ModelArguments
from .sample_args import SampleArguments


@@ -45,6 +45,10 @@ class TrainingArguments:
default=3,
metadata={"help": "Number of training epochs."},
)
max_steps: int | None = field(
default=None,
metadata={"help": "Maximum number of training steps. If set, overrides num_train_epochs."},
)
max_grad_norm: float = field(
default=1.0,
metadata={"help": "Maximum gradient norm for training."},


@@ -67,6 +67,10 @@ class BaseTrainer:
self.model_input_names = self.renderer.processor.model_input_names
self._create_batch_generator()
# Calculate num_training_steps: max_steps takes priority if set
if self.args.max_steps is not None and self.args.max_steps > 0:
self.num_training_steps = self.args.max_steps
else:
self.num_training_steps = self.args.num_train_epochs * len(self.train_batch_generator)
if self.args.enable_activation_checkpointing:
@@ -98,7 +102,22 @@ class BaseTrainer:
)
def _shard_model(self) -> None:
- pass
+ if self.args.dist_config is None:
if DistributedInterface().get_world_size(Dim.DP) > 1:
from torch.nn.parallel import DistributedDataParallel as DDP
logger.warning_rank0(
"dist_config is None but distributed training is enabled; falling back to DistributedDataParallel."
)
device_ids = None if self.device.type == "cpu" else [self.device.index]
self.model = DDP(self.model, device_ids=device_ids)
else:
from ..plugins.trainer_plugins.distributed.hub import DistributedPlugin
self.model = DistributedPlugin(self.args.dist_config.name)(
self.model,
self.args.dist_config,
)
def _init_optimizer(self) -> None:
"""Init optimizer."""
@@ -162,7 +181,9 @@ class BaseTrainer:
step_loss += loss.item()
grad_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.max_grad_norm).item()
- if not torch.isfinite(grad_norm):
# isfinite(): argument 'input' (position 1) must be Tensor, not float
+ if not torch.isfinite(torch.tensor(grad_norm)):  # type: ignore # pyright: ignore [reportUnknownReturnType]
logger.warning_rank0(f"Gradient norm is not finite: {grad_norm}")
else:
self.optimizer.step()
@@ -172,10 +193,17 @@ class BaseTrainer:
step_loss, grad_norm = DistributedInterface().all_reduce([step_loss, grad_norm])
DistributedInterface().sync()
if DistributedInterface().get_rank() == 0:
print(f"Epoch {epoch}, Step {self.global_step}, Loss: {step_loss:.4f}, Grad Norm: {grad_norm:.4f}")
# Check if max_steps is reached
if self.global_step >= self.num_training_steps:
logger.info_rank0(f"Reached max_steps ({self.num_training_steps}), stopping training.")
return
def save_model(self) -> None:
"""Save the model."""
- self.model.save_pretrained(self.args.output_dir)
+ model_to_save = self.model.module if hasattr(self.model, "module") else self.model
model_to_save.save_pretrained(self.args.output_dir)
self.renderer.processor.save_pretrained(self.args.output_dir)
logger.info_rank0(f"Model saved to {self.args.output_dir}")


@@ -30,7 +30,7 @@ from torch.utils.data import default_collate
from torchdata.stateful_dataloader import StatefulDataLoader
from torchdata.stateful_dataloader.sampler import StatefulDistributedSampler
- from ...accelerator.interface import DistributedInterface
+ from ...accelerator.interface import Dim, DistributedInterface
from ...config import BatchingStrategy
from ...utils import logging
from ...utils.helper import pad_and_truncate
@@ -83,8 +83,7 @@ class BatchGenerator(Iterator):
self.pin_memory = pin_memory
self.drop_last = drop_last
# TODO: support length and infinity
- dp_size = DistributedInterface().get_world_size("dp")
+ dp_size = DistributedInterface().get_world_size(Dim.DP)
if self.global_batch_size is None:
self.global_batch_size = dp_size * micro_batch_size
@@ -126,8 +125,8 @@ class BatchGenerator(Iterator):
if len(self.dataset) != -1:
sampler = StatefulDistributedSampler(
self.dataset,
- num_replicas=DistributedInterface().get_world_size("dp"),
+ num_replicas=DistributedInterface().get_world_size(Dim.DP),
- rank=DistributedInterface().get_rank("dp"),
+ rank=DistributedInterface().get_rank(Dim.DP),
shuffle=True,
seed=0,
drop_last=self.drop_last,
@@ -142,6 +141,7 @@ class BatchGenerator(Iterator):
num_workers=self.batching_workers,
collate_fn=self.renderer.process_samples,
pin_memory=self.pin_memory,
pin_memory_device=DistributedInterface().current_device.type,
drop_last=self.drop_last,
)
if self.batching_strategy == BatchingStrategy.NORMAL:


@@ -12,9 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import subprocess
import sys
from copy import deepcopy
from ..extras.env import VERSION, print_env
USAGE = (
@@ -27,27 +28,97 @@ USAGE = (
+ "-" * 70 + "-" * 70
) )
_DIST_TRAIN_COMMANDS = ("train", "sft", "dpo", "rm")
WELCOME = (
"-" * 58
+ "\n"
+ f"| Welcome to LLaMA Factory, version {VERSION}"
+ " " * (21 - len(VERSION))
+ "|\n|"
+ " " * 56
+ "|\n"
+ "| Project page: https://github.com/hiyouga/LLaMA-Factory |\n"
+ "-" * 58
)
def launch():
from .accelerator.helper import get_device_count
from .utils.env import find_available_port, is_env_enabled, use_kt, use_ray
from .utils.logging import get_logger
logger = get_logger(__name__)
# NOTE:
# `llamafactory-cli <command> ...` enters here first.
# We may re-launch via `torchrun` for distributed training. In that case we must
# forward `<command>` as argv[1] to the re-executed script, otherwise the script
# will misinterpret the first user argument (e.g. yaml config) as the command.
command = sys.argv.pop(1) if len(sys.argv) > 1 else "help"
- if command == "sft":  # train command will fallback to sft command
- from .trainers.sft_trainer import run_sft
+ if command in _DIST_TRAIN_COMMANDS and (
+ is_env_enabled("FORCE_TORCHRUN") or (get_device_count() > 1 and not use_ray() and not use_kt())
):
nnodes = os.getenv("NNODES", "1")
node_rank = os.getenv("NODE_RANK", "0")
nproc_per_node = os.getenv("NPROC_PER_NODE", str(get_device_count()))
master_addr = os.getenv("MASTER_ADDR", "127.0.0.1")
master_port = os.getenv("MASTER_PORT", str(find_available_port()))
logger.info_rank0(f"Initializing {nproc_per_node} distributed tasks at: {master_addr}:{master_port}")
if int(nnodes) > 1:
logger.info_rank0(f"Multi-node training enabled: num nodes: {nnodes}, node rank: {node_rank}")
- run_sft()
+ # elastic launch support
max_restarts = os.getenv("MAX_RESTARTS", "0")
rdzv_id = os.getenv("RDZV_ID")
min_nnodes = os.getenv("MIN_NNODES")
max_nnodes = os.getenv("MAX_NNODES")
env = deepcopy(os.environ)
if is_env_enabled("OPTIM_TORCH", "1"):
# optimize DDP, see https://zhuanlan.zhihu.com/p/671834539
env["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
env["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
torchrun_args = [
"torchrun",
"--nproc-per-node",
nproc_per_node,
]
if rdzv_id is not None:
# launch elastic job with fault tolerant support when possible
# see also https://docs.pytorch.org/docs/stable/elastic/train_script.html
rdzv_nnodes = nnodes
# elastic number of nodes if MIN_NNODES and MAX_NNODES are set
if min_nnodes is not None and max_nnodes is not None:
rdzv_nnodes = f"{min_nnodes}:{max_nnodes}"
torchrun_args.extend(
[
"--nnodes",
rdzv_nnodes,
"--rdzv-id",
rdzv_id,
"--rdzv-backend",
"c10d",
"--rdzv-endpoint",
f"{master_addr}:{master_port}",
"--max-restarts",
max_restarts,
]
)
else:
# NOTE: DO NOT USE shell=True to avoid security risk
torchrun_args.extend(
[
"--nnodes",
nnodes,
"--node_rank",
node_rank,
"--master_addr",
master_addr,
"--master_port",
master_port,
]
)
script_args = [__file__, command] + sys.argv[1:]
process = subprocess.run(
torchrun_args + script_args,
env=env,
check=True,
)
sys.exit(process.returncode)
elif command == "chat": elif command == "chat":
from .samplers.cli_sampler import run_chat from .samplers.cli_sampler import run_chat
@@ -55,17 +126,54 @@ def launch():
run_chat()
elif command == "env":
- print_env()
+ raise NotImplementedError("Environment information is not implemented yet.")
elif command == "version":
- print(WELCOME)
+ raise NotImplementedError("Version information is not implemented yet.")
elif command == "help":
print(USAGE)
elif command in _DIST_TRAIN_COMMANDS:
# Single GPU training without torchrun
if command in ("train", "sft"):
from llamafactory.v1.trainers.sft_trainer import run_sft
run_sft()
elif command == "dpo":
raise NotImplementedError("DPO trainer is not implemented yet.")
elif command == "rm":
raise NotImplementedError("RM trainer is not implemented yet.")
else:
print(f"Unknown command: {command}.\n{USAGE}")
def main():
# sys.argv[1] contains the command (sft/dpo/rm/train), sys.argv[2:] contains the rest args
command = sys.argv[1] if len(sys.argv) > 1 else "sft"
# Routing needs the sub-command, but downstream trainers usually expect argv without it.
if command in _DIST_TRAIN_COMMANDS:
sys.argv.pop(1)
else:
# Backward-compat: if someone runs `torchrun launcher.py config.yaml`,
# treat it as sft by default.
if len(sys.argv) > 1 and sys.argv[1].endswith((".yaml", ".yml")):
command = "sft"
if command in ("train", "sft"):
from llamafactory.v1.trainers.sft_trainer import run_sft
run_sft()
elif command == "dpo":
# from llamafactory.v1.trainers.dpo_trainer import run_dpo
# run_dpo()
raise NotImplementedError("DPO trainer is not implemented yet.")
elif command == "rm":
# from llamafactory.v1.trainers.rm_trainer import run_rm
# run_rm()
raise NotImplementedError("RM trainer is not implemented yet.")
if __name__ == "__main__": if __name__ == "__main__":
pass main()
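The environment variables consumed above can also be set programmatically before invoking the CLI, mirroring the FSDP2 test at the bottom of this commit; a sketch for a hypothetical two-process, single-node run (config.yaml is a placeholder):

import os
import subprocess
import sys

env = os.environ.copy()
env["USE_V1"] = "1"            # route through the v1 launcher
env["FORCE_TORCHRUN"] = "1"    # always re-launch via torchrun
env["NPROC_PER_NODE"] = "2"
env["NNODES"] = "1"
env["MASTER_ADDR"] = "127.0.0.1"
# MASTER_PORT falls back to find_available_port() when unset.

subprocess.run([sys.executable, "-m", "llamafactory.cli", "sft", "config.yaml"], env=env, check=True)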


@@ -0,0 +1,399 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import os
import torch
import torch.nn as nn
from torch.distributed.checkpoint.state_dict import StateDictOptions, get_model_state_dict, set_model_state_dict
from torch.distributed.fsdp import (
CPUOffloadPolicy,
MixedPrecisionPolicy,
fully_shard,
)
from transformers import PreTrainedModel
from ....accelerator.helper import get_current_accelerator
from ....accelerator.interface import DistributedInterface
from ....utils.logging import get_logger
logger = get_logger(__name__)
def get_transformer_layer_cls(model: PreTrainedModel) -> type[nn.Module] | None:
no_split_modules = getattr(model, "_no_split_modules", None)
if no_split_modules:
if isinstance(no_split_modules, (list, tuple)):
for name, module in model.named_modules():
for cls_name in no_split_modules:
if module.__class__.__name__ == cls_name:
return module.__class__
if hasattr(model, "model") and hasattr(model.model, "layers"):
return type(model.model.layers[0])
if hasattr(model, "layers"):
return type(model.layers[0])
return None
class FSDP2Engine:
def __init__(self, dist_config: dict):
self.dist_interface = DistributedInterface()
self.rank = self.dist_interface.get_rank()
self.local_rank = self.dist_interface.get_local_rank()
self.world_size = self.dist_interface.get_world_size()
self.mixed_precision = dist_config.get("mixed_precision", "bf16")
self.reshard_after_forward = dist_config.get("reshard_after_forward", True)
self.offload_params = dist_config.get("offload_params", False)
self.pin_memory = dist_config.get("pin_memory", True)
self.dcp_path = dist_config.get("dcp_path", None)
self.device_mesh = self.dist_interface.data_device_mesh
if self.device_mesh is None:
logger.warning(
"Device Mesh not found in DistributedInterface. FSDP2 might fail if not running in distributed mode."
)
if self.device_mesh is not None:
try:
self.fsdp_mesh = self.device_mesh["dp"]
except Exception:
self.fsdp_mesh = self.device_mesh
logger.info(f"Using Device Mesh: {self.fsdp_mesh}")
else:
self.fsdp_mesh = None
def get_mp_policy(self) -> MixedPrecisionPolicy:
if self.mixed_precision == "bf16":
param_dtype = torch.bfloat16
reduce_dtype = torch.float32
elif self.mixed_precision == "fp16":
param_dtype = torch.float16
reduce_dtype = torch.float32
else:
param_dtype = torch.float32
reduce_dtype = torch.float32
return MixedPrecisionPolicy(
param_dtype=param_dtype,
reduce_dtype=reduce_dtype,
cast_forward_inputs=True,
)
def prepare_model(self, model: PreTrainedModel) -> PreTrainedModel:
if self.fsdp_mesh is None:
logger.warning("No FSDP Mesh available, skipping FSDP wrapping.")
return model
mp_policy = self.get_mp_policy()
layer_cls = get_transformer_layer_cls(model)
if layer_cls is None:
logger.warning(
"Could not identify Transformer Layer class, applying FSDP to the whole model structure only."
)
transformer_layer_cls_to_wrap = set()
else:
logger.info(f"Applying per-layer FSDP to {layer_cls.__name__}")
transformer_layer_cls_to_wrap = {layer_cls}
for name, module in model.named_modules():
should_wrap = False
if type(module) in transformer_layer_cls_to_wrap:
should_wrap = True
elif isinstance(module, nn.Embedding):
if not getattr(model.config, "tie_word_embeddings", True):
should_wrap = True
if should_wrap:
fully_shard(
module,
mesh=self.fsdp_mesh,
reshard_after_forward=self.reshard_after_forward,
mp_policy=mp_policy,
offload_policy=CPUOffloadPolicy(pin_memory=self.pin_memory) if self.offload_params else None,
)
use_gradient_checkpointing = True # Could be configurable
if use_gradient_checkpointing:
if self.rank == 0:
logger.info("Enabling gradient checkpointing (transformers native)...")
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
if hasattr(model, "enable_input_require_grads"):
model.enable_input_require_grads()
else:
def make_inputs_require_grad(module, input, output):
output.requires_grad_(True)
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
fully_shard(
model,
mesh=self.fsdp_mesh,
reshard_after_forward=self.reshard_after_forward,
mp_policy=mp_policy,
offload_policy=CPUOffloadPolicy(pin_memory=self.pin_memory) if self.offload_params else None,
)
return model
@torch.no_grad()
def materialize_and_load(self, model: PreTrainedModel, hf_model_path: str, dcp_path: str = None):
if self.rank == 0:
logger.info("Materializing sharded model params...")
device = get_current_accelerator()
model.to_empty(device=device)
if dcp_path and os.path.exists(dcp_path):
if self.rank == 0:
logger.info(f"DCP path found at {dcp_path}. Using efficient Sharded Loading (DCP Load).")
self._load_from_dcp(model, dcp_path)
else:
if self.rank == 0:
if dcp_path:
logger.warning(f"DCP path {dcp_path} not found.")
logger.info("Using HF Meta Loading (Chunk Load).")
self._load_weights_from_hf_checkpoint(model, hf_model_path)
return model
def shard_model(self, model: PreTrainedModel) -> PreTrainedModel:
if model.device.type == "meta":
model = self.prepare_model(model)
model = self.materialize_and_load(model, hf_model_path=model.config.name_or_path, dcp_path=self.dcp_path)
else:
model = self.prepare_model(model)
return model
def _load_from_dcp(self, model: PreTrainedModel, dcp_path: str):
import torch.distributed.checkpoint as dcp
try:
if self.rank == 0:
logger.info(f"Loading distributed checkpoint from {dcp_path} ...")
options = StateDictOptions(full_state_dict=False, cpu_offload=True)
local_state_dict = get_model_state_dict(model, options=options)
dcp.load(state_dict=local_state_dict, checkpoint_id=dcp_path)
set_model_state_dict(model, local_state_dict, options=options)
if self.rank == 0:
logger.info("DCP weights loaded successfully.")
except Exception as e:
logger.error(f"Failed to load from DCP: {e}")
raise e
def _load_weights_from_hf_checkpoint(self, model, hf_model_path):
import glob
import json
hf_model_path = self._resolve_hf_checkpoint_dir(hf_model_path)
if self.rank == 0:
logger.info(f"Loading weights from {hf_model_path} ...")
index_file = os.path.join(hf_model_path, "model.safetensors.index.json")
is_safetensors = True
checkpoint_files = []
if os.path.exists(index_file):
with open(index_file) as f:
index = json.load(f)
checkpoint_files = sorted(set(index["weight_map"].values()))
checkpoint_files = [os.path.join(hf_model_path, f) for f in checkpoint_files]
elif os.path.exists(os.path.join(hf_model_path, "model.safetensors")):
checkpoint_files = [os.path.join(hf_model_path, "model.safetensors")]
else:
is_safetensors = False
index_file = os.path.join(hf_model_path, "pytorch_model.bin.index.json")
if os.path.exists(index_file):
with open(index_file) as f:
index = json.load(f)
checkpoint_files = sorted(set(index["weight_map"].values()))
checkpoint_files = [os.path.join(hf_model_path, f) for f in checkpoint_files]
elif os.path.exists(os.path.join(hf_model_path, "pytorch_model.bin")):
checkpoint_files = [os.path.join(hf_model_path, "pytorch_model.bin")]
else:
checkpoint_files = sorted(glob.glob(os.path.join(hf_model_path, "*.safetensors")))
if checkpoint_files:
is_safetensors = True
else:
checkpoint_files = sorted(glob.glob(os.path.join(hf_model_path, "*.bin")))
if not checkpoint_files:
raise ValueError(f"No checkpoint files found in {hf_model_path}")
param_map = dict(model.named_parameters())
total_files = len(checkpoint_files)
for i, ckpt_file in enumerate(checkpoint_files):
if self.rank == 0:
logger.info(f"[{i + 1}/{total_files}] Loading {os.path.basename(ckpt_file)} ...")
if is_safetensors:
from safetensors import safe_open
with safe_open(ckpt_file, framework="pt", device="cpu") as f:
for key in f.keys():
if key in param_map:
tensor = f.get_tensor(key)
self._copy_weights(param_map[key], tensor)
else:
state_dict = torch.load(ckpt_file, map_location="cpu")
for key, tensor in state_dict.items():
if key in param_map:
self._copy_weights(param_map[key], tensor)
del state_dict
gc.collect()
def _resolve_hf_checkpoint_dir(self, hf_model_path: str) -> str:
"""Resolve a HF model identifier or local path to a local directory containing checkpoint files.
- If `hf_model_path` is an existing directory, return it.
- If it's a file path, return its parent directory.
- Otherwise treat it as a Hugging Face Hub repo id and download/resolve to the local cache dir.
"""
if not hf_model_path:
return hf_model_path
# Local directory or file path.
if os.path.isdir(hf_model_path):
return hf_model_path
if os.path.isfile(hf_model_path):
return os.path.dirname(hf_model_path)
# HuggingFace Hub repo id: snapshot to local cache so we can glob/index files.
try:
from huggingface_hub import snapshot_download
except ImportError as e:
raise ValueError(
f"hf_model_path='{hf_model_path}' does not exist locally and huggingface_hub is not available "
f"to download it. Please provide a local model directory or install huggingface_hub. Error: {e}"
) from e
revision = os.getenv("HF_REVISION")
offline = os.getenv("HF_HUB_OFFLINE") == "1" or os.getenv("TRANSFORMERS_OFFLINE") == "1"
# In distributed runs, let rank0 download first to avoid N-way concurrent downloads.
if torch.distributed.is_available() and torch.distributed.is_initialized():
if self.rank == 0:
local_dir = snapshot_download(
repo_id=hf_model_path,
revision=revision,
local_files_only=offline,
allow_patterns=[
"*.safetensors",
"*.bin",
"*.index.json",
"model.safetensors",
"model.safetensors.index.json",
"pytorch_model.bin",
"pytorch_model.bin.index.json",
"config.json",
],
)
logger.info(f"Resolved HF repo id '{hf_model_path}' to local dir: {local_dir}")
torch.distributed.barrier()
if self.rank != 0:
local_dir = snapshot_download(
repo_id=hf_model_path,
revision=revision,
local_files_only=True,
allow_patterns=[
"*.safetensors",
"*.bin",
"*.index.json",
"model.safetensors",
"model.safetensors.index.json",
"pytorch_model.bin",
"pytorch_model.bin.index.json",
"config.json",
],
)
return local_dir
local_dir = snapshot_download(
repo_id=hf_model_path,
revision=revision,
local_files_only=offline,
allow_patterns=[
"*.safetensors",
"*.bin",
"*.index.json",
"model.safetensors",
"model.safetensors.index.json",
"pytorch_model.bin",
"pytorch_model.bin.index.json",
"config.json",
],
)
if self.rank == 0:
logger.info(f"Resolved HF repo id '{hf_model_path}' to local dir: {local_dir}")
return local_dir
def _copy_weights(self, param, loaded_tensor):
from torch.distributed._tensor import DTensor, Shard
if loaded_tensor.dtype != param.dtype:
loaded_tensor = loaded_tensor.to(param.dtype)
if isinstance(param, DTensor):
shard_placement = None
mesh_dim = -1
for i, placement in enumerate(param.placements):
if isinstance(placement, Shard):
shard_placement = placement
mesh_dim = i
break
local_tensor = param.to_local()
if shard_placement is None:
local_tensor.copy_(loaded_tensor)
else:
dim = shard_placement.dim
mesh = param.device_mesh
my_coordinate = mesh.get_coordinate()
if my_coordinate is None:
return
rank_in_dim = my_coordinate[mesh_dim]
world_size_in_dim = mesh.size(mesh_dim)
full_size = param.shape[dim]
chunk_size = (full_size + world_size_in_dim - 1) // world_size_in_dim
start = rank_in_dim * chunk_size
end = min(start + chunk_size, full_size)
if start >= full_size:
return
sliced_tensor = loaded_tensor.narrow(dim, start, end - start)
slices = [slice(None)] * local_tensor.ndim
slices[dim] = slice(0, sliced_tensor.shape[dim])
local_tensor[tuple(slices)].copy_(sliced_tensor)
else:
param.data.copy_(loaded_tensor)
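For orientation, prepare_model boils down to per-layer fully_shard calls followed by a root-level call, with the dist_config keys (mixed_precision, reshard_after_forward, offload_params, pin_memory, dcp_path) feeding the corresponding policies. A stripped-down standalone sketch, assuming PyTorch 2.6+ (where fully_shard lives under torch.distributed.fsdp) and an already-initialized process group:

import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import MixedPrecisionPolicy, fully_shard

# Assumes torchrun has already initialized torch.distributed on CUDA devices.
mesh = init_device_mesh("cuda", (torch.distributed.get_world_size(),), mesh_dim_names=("dp",))
mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32, cast_forward_inputs=True)

model = nn.Transformer(d_model=64, num_encoder_layers=2, num_decoder_layers=2)
for layer in list(model.encoder.layers) + list(model.decoder.layers):
    fully_shard(layer, mesh=mesh, reshard_after_forward=True, mp_policy=mp_policy)  # shard each block
fully_shard(model, mesh=mesh, reshard_after_forward=True, mp_policy=mp_policy)      # then the root module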


@@ -0,0 +1,34 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ....config.arg_utils import PluginConfig
from ....utils.plugin import BasePlugin
from ....utils.types import HFModel
class DistributedPlugin(BasePlugin):
def __call__(self, model: HFModel, dist_config: PluginConfig, **kwargs) -> HFModel:
return super().__call__(model, dist_config, **kwargs)
@DistributedPlugin("fsdp2").register()
def shard_model_fsdp2(model: HFModel, dist_config: PluginConfig) -> HFModel:
from .fsdp2 import FSDP2Engine
return FSDP2Engine(dist_config).shard_model(model)
@DistributedPlugin("deepspeed").register()
def shard_model_deepspeed(model: HFModel, dist_config: PluginConfig) -> HFModel:
return model
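Because the hub reuses the existing BasePlugin registry, additional sharding backends can be added with the same decorator pattern; a hypothetical example (the "ddp" entry and its wrapping logic are illustrative, not part of this commit):

@DistributedPlugin("ddp").register()
def shard_model_ddp(model: HFModel, dist_config: PluginConfig) -> HFModel:
    from torch.nn.parallel import DistributedDataParallel as DDP

    return DDP(model)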


@@ -28,3 +28,11 @@ def find_available_port() -> int:
def is_env_enabled(env_var: str, default: str = "0") -> bool:
"""Check if the environment variable is enabled."""
return os.getenv(env_var, default).lower() in ["true", "yes", "on", "t", "y", "1"]
def use_ray() -> bool:
return False
def use_kt() -> bool:
return False


@@ -15,7 +15,6 @@
import sys
from unittest.mock import MagicMock, patch
- import pytest
import torch.multiprocessing as mp
from transformers import AutoModelForCausalLM


@@ -0,0 +1,89 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import subprocess
import sys
from pathlib import Path
import pytest
@pytest.mark.xfail(reason="CI machines may OOM when heavily loaded.")
@pytest.mark.runs_on(["cuda", "npu"])
def test_fsdp2_sft_trainer(tmp_path: Path):
"""Test FSDP2 SFT trainer by simulating `llamafactory-cli sft config.yaml` behavior."""
config_yaml = """\
model: Qwen/Qwen3-0.6B
trust_remote_code: true
model_class: llm
template: qwen3_nothink
kernel_config:
name: auto
include_kernels: auto
quant_config: null
dist_config:
name: fsdp2
dcp_path: null
init_config:
name: init_on_meta
### data
train_dataset: data/v1_sft_demo.yaml
### training
output_dir: {output_dir}
micro_batch_size: 1
global_batch_size: 1
cutoff_len: 2048
learning_rate: 1.0e-4
bf16: false
max_steps: 1
### sample
sample_backend: hf
max_new_tokens: 128
"""
# Create output directory
output_dir = tmp_path / "outputs"
output_dir.mkdir(parents=True, exist_ok=True)
config_file = tmp_path / "config.yaml"
config_file.write_text(config_yaml.format(output_dir=str(output_dir)))
# Set up environment variables
env = os.environ.copy()
env["USE_V1"] = "1" # Use v1 launcher
env["FORCE_TORCHRUN"] = "1" # Force distributed training via torchrun
# Run the CLI command via subprocess
# This simulates: llamafactory-cli sft config.yaml
result = subprocess.run(
[sys.executable, "-m", "llamafactory.cli", "sft", str(config_file)],
env=env,
capture_output=True,
cwd=str(Path(__file__).parent.parent.parent), # LLaMA-Factory root
)
# Decode output with error handling (progress bars may contain non-UTF-8 bytes)
stderr = result.stderr.decode("utf-8", errors="replace")
# Check the result
assert result.returncode == 0, f"Training failed with return code {result.returncode}\nSTDERR: {stderr}"
# Verify output files exist (optional - adjust based on what run_sft produces)
# assert (output_dir / "some_expected_file").exists()
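To run only this test from the repository root (it spawns a real torchrun job, needs a CUDA or NPU device, and downloads Qwen/Qwen3-0.6B), something along these lines should work; the test file's path is not shown in this diff:

import pytest

# -k selects the test by name; -s streams the subprocess output.
raise SystemExit(pytest.main(["-k", "test_fsdp2_sft_trainer", "-s"]))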