[v1] support training with fsdp2 (#9773)

Co-authored-by: frozenleaves <frozen@Mac.local>
Co-authored-by: Yaowei Zheng <hiyouga@buaa.edu.cn>
This commit is contained in:
浮梦
2026-01-25 19:41:58 +08:00
committed by GitHub
parent 641bfdd482
commit f9f11dcb97
15 changed files with 801 additions and 33 deletions

View File

@@ -0,0 +1,34 @@
model: Qwen/Qwen3-0.6B
trust_remote_code: true
model_class: llm
template: qwen3_nothink
kernel_config:
name: auto
include_kernels: auto # choice: null/true/false/auto/kernel_id1,kernel_id2,kernel_id3, default is null
quant_config: null
dist_config:
name: fsdp2
dcp_path: null # /mnt/f/pretrain_models/Qwen3-0.6B-dcp
init_config:
name: init_on_meta
### data
train_dataset: data/v1_sft_demo.yaml
### training
output_dir: outputs/test_fsdp2
micro_batch_size: 1
global_batch_size: 1
cutoff_len: 2048
learning_rate: 1.0e-4
bf16: false
max_steps: 10
### sample
sample_backend: hf
max_new_tokens: 128

55
scripts/hf2dcp.py Normal file
View File

@@ -0,0 +1,55 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert a HuggingFace model to DCP checkpoint format.
Usage:
python scripts/hf2dcp.py convert --hf_path=/path/to/hf --dcp_path=/path/to/dcp
Arguments:
hf_path: Path to the HuggingFace model directory.
dcp_path: Output path (directory) for DCP checkpoint.
"""
import fire
import torch
import torch.distributed.checkpoint as dcp
from transformers import AutoModelForCausalLM
def convert(hf_path: str, dcp_path: str) -> None:
"""Convert HF model weights to DCP.
Args:
hf_path: HuggingFace model directory.
dcp_path: Output path (directory) for DCP checkpoint.
"""
if not hf_path or not dcp_path:
raise ValueError("Both 'hf_path' and 'dcp_path' are required.")
print(f"Loading HF model from {hf_path}...")
model = AutoModelForCausalLM.from_pretrained(hf_path, device_map="cpu", torch_dtype=torch.bfloat16)
print(f"Saving to DCP format at {dcp_path}...")
dcp.save(model.state_dict(), checkpoint_id=dcp_path)
print("Done!")
def help() -> None:
"""Show help message."""
print(__doc__)
if __name__ == "__main__":
fire.Fire({"convert": convert, "help": help, "--convert": convert})

View File

@@ -180,6 +180,16 @@ def operate_tensorlike(fn: Callable[[...], Tensor], data: TensorLike, **kwargs)
return result.tolist()
def get_process_group_backend() -> str:
"""Get backend for init process group."""
if get_current_accelerator().type == DeviceType.NPU:
return "hccl"
elif get_current_accelerator().type == DeviceType.CUDA:
return "nccl"
else:
return "gloo"
def all_gather(tensor: Tensor, group: Optional[ProcessGroup] = None) -> Tensor:
"""Gathers the tensor from all ranks and stacks them at the first dim."""
world_size = get_world_size()

View File

@@ -145,7 +145,7 @@ class DistributedInterface:
timeout = config.get("timeout", 18000)
if self._is_distributed:
init_process_group(timeout=timedelta(seconds=timeout))
init_process_group(timeout=timedelta(seconds=timeout), backend=helper.get_process_group_backend())
self.model_device_mesh = init_device_mesh(
device_type=self.current_device.type,
mesh_shape=self.strategy.model_mesh_shape,

View File

@@ -20,7 +20,7 @@ from typing import Any
from omegaconf import OmegaConf
from transformers import HfArgumentParser
from ...extras.misc import is_env_enabled
from ..utils.env import is_env_enabled
from .data_args import DataArguments
from .model_args import ModelArguments
from .sample_args import SampleArguments

View File

@@ -45,6 +45,10 @@ class TrainingArguments:
default=3,
metadata={"help": "Number of training epochs."},
)
max_steps: int | None = field(
default=None,
metadata={"help": "Maximum number of training steps. If set, overrides num_train_epochs."},
)
max_grad_norm: float = field(
default=1.0,
metadata={"help": "Maximum gradient norm for training."},

View File

@@ -67,6 +67,10 @@ class BaseTrainer:
self.model_input_names = self.renderer.processor.model_input_names
self._create_batch_generator()
# Calculate num_training_steps: max_steps takes priority if set
if self.args.max_steps is not None and self.args.max_steps > 0:
self.num_training_steps = self.args.max_steps
else:
self.num_training_steps = self.args.num_train_epochs * len(self.train_batch_generator)
if self.args.enable_activation_checkpointing:
@@ -98,7 +102,22 @@ class BaseTrainer:
)
def _shard_model(self) -> None:
pass
if self.args.dist_config is None:
if DistributedInterface().get_world_size(Dim.DP) > 1:
from torch.nn.parallel import DistributedDataParallel as DDP
logger.warning_rank0(
"dist_config is None but distributed training is enabled; falling back to DistributedDataParallel."
)
device_ids = None if self.device.type == "cpu" else [self.device.index]
self.model = DDP(self.model, device_ids=device_ids)
else:
from ..plugins.trainer_plugins.distributed.hub import DistributedPlugin
self.model = DistributedPlugin(self.args.dist_config.name)(
self.model,
self.args.dist_config,
)
def _init_optimizer(self) -> None:
"""Init optimizer."""
@@ -162,7 +181,9 @@ class BaseTrainer:
step_loss += loss.item()
grad_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.max_grad_norm).item()
if not torch.isfinite(grad_norm):
# isfinite(): argument 'input' (position 1) must be Tensor, not float
if not torch.isfinite(torch.tensor(grad_norm)): # type: ignore # pyright: ignore [reportUnknownReturnType]
logger.warning_rank0(f"Gradient norm is not finite: {grad_norm}")
else:
self.optimizer.step()
@@ -172,10 +193,17 @@ class BaseTrainer:
step_loss, grad_norm = DistributedInterface().all_reduce([step_loss, grad_norm])
DistributedInterface().sync()
if DistributedInterface().get_rank() == 0:
print(f"Epoch {epoch}, Step {self.global_step}, Loss: {step_loss:.4f}, Grad Norm: {grad_norm:.4f}")
# Check if max_steps is reached
if self.global_step >= self.num_training_steps:
logger.info_rank0(f"Reached max_steps ({self.num_training_steps}), stopping training.")
return
def save_model(self) -> None:
"""Save the model."""
self.model.save_pretrained(self.args.output_dir)
model_to_save = self.model.module if hasattr(self.model, "module") else self.model
model_to_save.save_pretrained(self.args.output_dir)
self.renderer.processor.save_pretrained(self.args.output_dir)
logger.info_rank0(f"Model saved to {self.args.output_dir}")

View File

@@ -30,7 +30,7 @@ from torch.utils.data import default_collate
from torchdata.stateful_dataloader import StatefulDataLoader
from torchdata.stateful_dataloader.sampler import StatefulDistributedSampler
from ...accelerator.interface import DistributedInterface
from ...accelerator.interface import Dim, DistributedInterface
from ...config import BatchingStrategy
from ...utils import logging
from ...utils.helper import pad_and_truncate
@@ -83,8 +83,7 @@ class BatchGenerator(Iterator):
self.pin_memory = pin_memory
self.drop_last = drop_last
# TODO: support length and infinity
dp_size = DistributedInterface().get_world_size("dp")
dp_size = DistributedInterface().get_world_size(Dim.DP)
if self.global_batch_size is None:
self.global_batch_size = dp_size * micro_batch_size
@@ -126,8 +125,8 @@ class BatchGenerator(Iterator):
if len(self.dataset) != -1:
sampler = StatefulDistributedSampler(
self.dataset,
num_replicas=DistributedInterface().get_world_size("dp"),
rank=DistributedInterface().get_rank("dp"),
num_replicas=DistributedInterface().get_world_size(Dim.DP),
rank=DistributedInterface().get_rank(Dim.DP),
shuffle=True,
seed=0,
drop_last=self.drop_last,
@@ -142,6 +141,7 @@ class BatchGenerator(Iterator):
num_workers=self.batching_workers,
collate_fn=self.renderer.process_samples,
pin_memory=self.pin_memory,
pin_memory_device=DistributedInterface().current_device.type,
drop_last=self.drop_last,
)
if self.batching_strategy == BatchingStrategy.NORMAL:

View File

@@ -12,9 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import subprocess
import sys
from ..extras.env import VERSION, print_env
from copy import deepcopy
USAGE = (
@@ -27,27 +28,97 @@ USAGE = (
+ "-" * 70
)
WELCOME = (
"-" * 58
+ "\n"
+ f"| Welcome to LLaMA Factory, version {VERSION}"
+ " " * (21 - len(VERSION))
+ "|\n|"
+ " " * 56
+ "|\n"
+ "| Project page: https://github.com/hiyouga/LLaMA-Factory |\n"
+ "-" * 58
)
_DIST_TRAIN_COMMANDS = ("train", "sft", "dpo", "rm")
def launch():
from .accelerator.helper import get_device_count
from .utils.env import find_available_port, is_env_enabled, use_kt, use_ray
from .utils.logging import get_logger
logger = get_logger(__name__)
# NOTE:
# `llamafactory-cli <command> ...` enters here first.
# We may re-launch via `torchrun` for distributed training. In that case we must
# forward `<command>` as argv[1] to the re-executed script, otherwise the script
# will misinterpret the first user argument (e.g. yaml config) as the command.
command = sys.argv.pop(1) if len(sys.argv) > 1 else "help"
if command == "sft": # train command will fallback to sft command
from .trainers.sft_trainer import run_sft
if command in _DIST_TRAIN_COMMANDS and (
is_env_enabled("FORCE_TORCHRUN") or (get_device_count() > 1 and not use_ray() and not use_kt())
):
nnodes = os.getenv("NNODES", "1")
node_rank = os.getenv("NODE_RANK", "0")
nproc_per_node = os.getenv("NPROC_PER_NODE", str(get_device_count()))
master_addr = os.getenv("MASTER_ADDR", "127.0.0.1")
master_port = os.getenv("MASTER_PORT", str(find_available_port()))
logger.info_rank0(f"Initializing {nproc_per_node} distributed tasks at: {master_addr}:{master_port}")
if int(nnodes) > 1:
logger.info_rank0(f"Multi-node training enabled: num nodes: {nnodes}, node rank: {node_rank}")
run_sft()
# elastic launch support
max_restarts = os.getenv("MAX_RESTARTS", "0")
rdzv_id = os.getenv("RDZV_ID")
min_nnodes = os.getenv("MIN_NNODES")
max_nnodes = os.getenv("MAX_NNODES")
env = deepcopy(os.environ)
if is_env_enabled("OPTIM_TORCH", "1"):
# optimize DDP, see https://zhuanlan.zhihu.com/p/671834539
env["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
env["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
torchrun_args = [
"torchrun",
"--nproc-per-node",
nproc_per_node,
]
if rdzv_id is not None:
# launch elastic job with fault tolerant support when possible
# see also https://docs.pytorch.org/docs/stable/elastic/train_script.html
rdzv_nnodes = nnodes
# elastic number of nodes if MIN_NNODES and MAX_NNODES are set
if min_nnodes is not None and max_nnodes is not None:
rdzv_nnodes = f"{min_nnodes}:{max_nnodes}"
torchrun_args.extend(
[
"--nnodes",
rdzv_nnodes,
"--rdzv-id",
rdzv_id,
"--rdzv-backend",
"c10d",
"--rdzv-endpoint",
f"{master_addr}:{master_port}",
"--max-restarts",
max_restarts,
]
)
else:
# NOTE: DO NOT USE shell=True to avoid security risk
torchrun_args.extend(
[
"--nnodes",
nnodes,
"--node_rank",
node_rank,
"--master_addr",
master_addr,
"--master_port",
master_port,
]
)
script_args = [__file__, command] + sys.argv[1:]
process = subprocess.run(
torchrun_args + script_args,
env=env,
check=True,
)
sys.exit(process.returncode)
elif command == "chat":
from .samplers.cli_sampler import run_chat
@@ -55,17 +126,54 @@ def launch():
run_chat()
elif command == "env":
print_env()
raise NotImplementedError("Environment information is not implemented yet.")
elif command == "version":
print(WELCOME)
raise NotImplementedError("Version information is not implemented yet.")
elif command == "help":
print(USAGE)
elif command in _DIST_TRAIN_COMMANDS:
# Single GPU training without torchrun
if command in ("train", "sft"):
from llamafactory.v1.trainers.sft_trainer import run_sft
run_sft()
elif command == "dpo":
raise NotImplementedError("DPO trainer is not implemented yet.")
elif command == "rm":
raise NotImplementedError("RM trainer is not implemented yet.")
else:
print(f"Unknown command: {command}.\n{USAGE}")
def main():
# sys.argv[1] contains the command (sft/dpo/rm/train), sys.argv[2:] contains the rest args
command = sys.argv[1] if len(sys.argv) > 1 else "sft"
# Routing needs the sub-command, but downstream trainers usually expect argv without it.
if command in _DIST_TRAIN_COMMANDS:
sys.argv.pop(1)
else:
# Backward-compat: if someone runs `torchrun launcher.py config.yaml`,
# treat it as sft by default.
if len(sys.argv) > 1 and sys.argv[1].endswith((".yaml", ".yml")):
command = "sft"
if command in ("train", "sft"):
from llamafactory.v1.trainers.sft_trainer import run_sft
run_sft()
elif command == "dpo":
# from llamafactory.v1.trainers.dpo_trainer import run_dpo
# run_dpo()
raise NotImplementedError("DPO trainer is not implemented yet.")
elif command == "rm":
# from llamafactory.v1.trainers.rm_trainer import run_rm
# run_rm()
raise NotImplementedError("RM trainer is not implemented yet.")
if __name__ == "__main__":
pass
main()

View File

@@ -0,0 +1,399 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import os
import torch
import torch.nn as nn
from torch.distributed.checkpoint.state_dict import StateDictOptions, get_model_state_dict, set_model_state_dict
from torch.distributed.fsdp import (
CPUOffloadPolicy,
MixedPrecisionPolicy,
fully_shard,
)
from transformers import PreTrainedModel
from ....accelerator.helper import get_current_accelerator
from ....accelerator.interface import DistributedInterface
from ....utils.logging import get_logger
logger = get_logger(__name__)
def get_transformer_layer_cls(model: PreTrainedModel) -> type[nn.Module] | None:
no_split_modules = getattr(model, "_no_split_modules", None)
if no_split_modules:
if isinstance(no_split_modules, (list, tuple)):
for name, module in model.named_modules():
for cls_name in no_split_modules:
if module.__class__.__name__ == cls_name:
return module.__class__
if hasattr(model, "model") and hasattr(model.model, "layers"):
return type(model.model.layers[0])
if hasattr(model, "layers"):
return type(model.layers[0])
return None
class FSDP2Engine:
def __init__(self, dist_config: dict):
self.dist_interface = DistributedInterface()
self.rank = self.dist_interface.get_rank()
self.local_rank = self.dist_interface.get_local_rank()
self.world_size = self.dist_interface.get_world_size()
self.mixed_precision = dist_config.get("mixed_precision", "bf16")
self.reshard_after_forward = dist_config.get("reshard_after_forward", True)
self.offload_params = dist_config.get("offload_params", False)
self.pin_memory = dist_config.get("pin_memory", True)
self.dcp_path = dist_config.get("dcp_path", None)
self.device_mesh = self.dist_interface.data_device_mesh
if self.device_mesh is None:
logger.warning(
"Device Mesh not found in DistributedInterface. FSDP2 might fail if not running in distributed mode."
)
if self.device_mesh is not None:
try:
self.fsdp_mesh = self.device_mesh["dp"]
except Exception:
self.fsdp_mesh = self.device_mesh
logger.info(f"Using Device Mesh: {self.fsdp_mesh}")
else:
self.fsdp_mesh = None
def get_mp_policy(self) -> MixedPrecisionPolicy:
if self.mixed_precision == "bf16":
param_dtype = torch.bfloat16
reduce_dtype = torch.float32
elif self.mixed_precision == "fp16":
param_dtype = torch.float16
reduce_dtype = torch.float32
else:
param_dtype = torch.float32
reduce_dtype = torch.float32
return MixedPrecisionPolicy(
param_dtype=param_dtype,
reduce_dtype=reduce_dtype,
cast_forward_inputs=True,
)
def prepare_model(self, model: PreTrainedModel) -> PreTrainedModel:
if self.fsdp_mesh is None:
logger.warning("No FSDP Mesh available, skipping FSDP wrapping.")
return model
mp_policy = self.get_mp_policy()
layer_cls = get_transformer_layer_cls(model)
if layer_cls is None:
logger.warning(
"Could not identify Transformer Layer class, applying FSDP to the whole model structure only."
)
transformer_layer_cls_to_wrap = set()
else:
logger.info(f"Applying per-layer FSDP to {layer_cls.__name__}")
transformer_layer_cls_to_wrap = {layer_cls}
for name, module in model.named_modules():
should_wrap = False
if type(module) in transformer_layer_cls_to_wrap:
should_wrap = True
elif isinstance(module, nn.Embedding):
if not getattr(model.config, "tie_word_embeddings", True):
should_wrap = True
if should_wrap:
fully_shard(
module,
mesh=self.fsdp_mesh,
reshard_after_forward=self.reshard_after_forward,
mp_policy=mp_policy,
offload_policy=CPUOffloadPolicy(pin_memory=self.pin_memory) if self.offload_params else None,
)
use_gradient_checkpointing = True # Could be configurable
if use_gradient_checkpointing:
if self.rank == 0:
logger.info("Enabling gradient checkpointing (transformers native)...")
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
if hasattr(model, "enable_input_require_grads"):
model.enable_input_require_grads()
else:
def make_inputs_require_grad(module, input, output):
output.requires_grad_(True)
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
fully_shard(
model,
mesh=self.fsdp_mesh,
reshard_after_forward=self.reshard_after_forward,
mp_policy=mp_policy,
offload_policy=CPUOffloadPolicy(pin_memory=self.pin_memory) if self.offload_params else None,
)
return model
@torch.no_grad()
def materialize_and_load(self, model: PreTrainedModel, hf_model_path: str, dcp_path: str = None):
if self.rank == 0:
logger.info("Materializing sharded model params...")
device = get_current_accelerator()
model.to_empty(device=device)
if dcp_path and os.path.exists(dcp_path):
if self.rank == 0:
logger.info(f"DCP path found at {dcp_path}. Using efficient Sharded Loading (DCP Load).")
self._load_from_dcp(model, dcp_path)
else:
if self.rank == 0:
if dcp_path:
logger.warning(f"DCP path {dcp_path} not found.")
logger.info("Using HF Meta Loading (Chunk Load).")
self._load_weights_from_hf_checkpoint(model, hf_model_path)
return model
def shard_model(self, model: PreTrainedModel) -> PreTrainedModel:
if model.device.type == "meta":
model = self.prepare_model(model)
model = self.materialize_and_load(model, hf_model_path=model.config.name_or_path, dcp_path=self.dcp_path)
else:
model = self.prepare_model(model)
return model
def _load_from_dcp(self, model: PreTrainedModel, dcp_path: str):
import torch.distributed.checkpoint as dcp
try:
if self.rank == 0:
logger.info(f"Loading distributed checkpoint from {dcp_path} ...")
options = StateDictOptions(full_state_dict=False, cpu_offload=True)
local_state_dict = get_model_state_dict(model, options=options)
dcp.load(state_dict=local_state_dict, checkpoint_id=dcp_path)
set_model_state_dict(model, local_state_dict, options=options)
if self.rank == 0:
logger.info("DCP weights loaded successfully.")
except Exception as e:
logger.error(f"Failed to load from DCP: {e}")
raise e
def _load_weights_from_hf_checkpoint(self, model, hf_model_path):
import glob
import json
hf_model_path = self._resolve_hf_checkpoint_dir(hf_model_path)
if self.rank == 0:
logger.info(f"Loading weights from {hf_model_path} ...")
index_file = os.path.join(hf_model_path, "model.safetensors.index.json")
is_safetensors = True
checkpoint_files = []
if os.path.exists(index_file):
with open(index_file) as f:
index = json.load(f)
checkpoint_files = sorted(set(index["weight_map"].values()))
checkpoint_files = [os.path.join(hf_model_path, f) for f in checkpoint_files]
elif os.path.exists(os.path.join(hf_model_path, "model.safetensors")):
checkpoint_files = [os.path.join(hf_model_path, "model.safetensors")]
else:
is_safetensors = False
index_file = os.path.join(hf_model_path, "pytorch_model.bin.index.json")
if os.path.exists(index_file):
with open(index_file) as f:
index = json.load(f)
checkpoint_files = sorted(set(index["weight_map"].values()))
checkpoint_files = [os.path.join(hf_model_path, f) for f in checkpoint_files]
elif os.path.exists(os.path.join(hf_model_path, "pytorch_model.bin")):
checkpoint_files = [os.path.join(hf_model_path, "pytorch_model.bin")]
else:
checkpoint_files = sorted(glob.glob(os.path.join(hf_model_path, "*.safetensors")))
if checkpoint_files:
is_safetensors = True
else:
checkpoint_files = sorted(glob.glob(os.path.join(hf_model_path, "*.bin")))
if not checkpoint_files:
raise ValueError(f"No checkpoint files found in {hf_model_path}")
param_map = dict(model.named_parameters())
total_files = len(checkpoint_files)
for i, ckpt_file in enumerate(checkpoint_files):
if self.rank == 0:
logger.info(f"[{i + 1}/{total_files}] Loading {os.path.basename(ckpt_file)} ...")
if is_safetensors:
from safetensors import safe_open
with safe_open(ckpt_file, framework="pt", device="cpu") as f:
for key in f.keys():
if key in param_map:
tensor = f.get_tensor(key)
self._copy_weights(param_map[key], tensor)
else:
state_dict = torch.load(ckpt_file, map_location="cpu")
for key, tensor in state_dict.items():
if key in param_map:
self._copy_weights(param_map[key], tensor)
del state_dict
gc.collect()
def _resolve_hf_checkpoint_dir(self, hf_model_path: str) -> str:
"""Resolve a HF model identifier or local path to a local directory containing checkpoint files.
- If `hf_model_path` is an existing directory, return it.
- If it's a file path, return its parent directory.
- Otherwise treat it as a Hugging Face Hub repo id and download/resolve to the local cache dir.
"""
if not hf_model_path:
return hf_model_path
# Local directory or file path.
if os.path.isdir(hf_model_path):
return hf_model_path
if os.path.isfile(hf_model_path):
return os.path.dirname(hf_model_path)
# HuggingFace Hub repo id: snapshot to local cache so we can glob/index files.
try:
from huggingface_hub import snapshot_download
except ImportError as e:
raise ValueError(
f"hf_model_path='{hf_model_path}' does not exist locally and huggingface_hub is not available "
f"to download it. Please provide a local model directory or install huggingface_hub. Error: {e}"
) from e
revision = os.getenv("HF_REVISION")
offline = os.getenv("HF_HUB_OFFLINE") == "1" or os.getenv("TRANSFORMERS_OFFLINE") == "1"
# In distributed runs, let rank0 download first to avoid N-way concurrent downloads.
if torch.distributed.is_available() and torch.distributed.is_initialized():
if self.rank == 0:
local_dir = snapshot_download(
repo_id=hf_model_path,
revision=revision,
local_files_only=offline,
allow_patterns=[
"*.safetensors",
"*.bin",
"*.index.json",
"model.safetensors",
"model.safetensors.index.json",
"pytorch_model.bin",
"pytorch_model.bin.index.json",
"config.json",
],
)
logger.info(f"Resolved HF repo id '{hf_model_path}' to local dir: {local_dir}")
torch.distributed.barrier()
if self.rank != 0:
local_dir = snapshot_download(
repo_id=hf_model_path,
revision=revision,
local_files_only=True,
allow_patterns=[
"*.safetensors",
"*.bin",
"*.index.json",
"model.safetensors",
"model.safetensors.index.json",
"pytorch_model.bin",
"pytorch_model.bin.index.json",
"config.json",
],
)
return local_dir
local_dir = snapshot_download(
repo_id=hf_model_path,
revision=revision,
local_files_only=offline,
allow_patterns=[
"*.safetensors",
"*.bin",
"*.index.json",
"model.safetensors",
"model.safetensors.index.json",
"pytorch_model.bin",
"pytorch_model.bin.index.json",
"config.json",
],
)
if self.rank == 0:
logger.info(f"Resolved HF repo id '{hf_model_path}' to local dir: {local_dir}")
return local_dir
def _copy_weights(self, param, loaded_tensor):
from torch.distributed._tensor import DTensor, Shard
if loaded_tensor.dtype != param.dtype:
loaded_tensor = loaded_tensor.to(param.dtype)
if isinstance(param, DTensor):
shard_placement = None
mesh_dim = -1
for i, placement in enumerate(param.placements):
if isinstance(placement, Shard):
shard_placement = placement
mesh_dim = i
break
local_tensor = param.to_local()
if shard_placement is None:
local_tensor.copy_(loaded_tensor)
else:
dim = shard_placement.dim
mesh = param.device_mesh
my_coordinate = mesh.get_coordinate()
if my_coordinate is None:
return
rank_in_dim = my_coordinate[mesh_dim]
world_size_in_dim = mesh.size(mesh_dim)
full_size = param.shape[dim]
chunk_size = (full_size + world_size_in_dim - 1) // world_size_in_dim
start = rank_in_dim * chunk_size
end = min(start + chunk_size, full_size)
if start >= full_size:
return
sliced_tensor = loaded_tensor.narrow(dim, start, end - start)
slices = [slice(None)] * local_tensor.ndim
slices[dim] = slice(0, sliced_tensor.shape[dim])
local_tensor[tuple(slices)].copy_(sliced_tensor)
else:
param.data.copy_(loaded_tensor)

View File

@@ -0,0 +1,34 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ....config.arg_utils import PluginConfig
from ....utils.plugin import BasePlugin
from ....utils.types import HFModel
class DistributedPlugin(BasePlugin):
def __call__(self, model: HFModel, dist_config: PluginConfig, **kwargs) -> HFModel:
return super().__call__(model, dist_config, **kwargs)
@DistributedPlugin("fsdp2").register()
def shard_model_fsdp2(model: HFModel, dist_config: PluginConfig) -> HFModel:
from .fsdp2 import FSDP2Engine
return FSDP2Engine(dist_config).shard_model(model)
@DistributedPlugin("deepspeed").register()
def shard_model_deepspeed(model: HFModel, dist_config: PluginConfig) -> HFModel:
return model

View File

@@ -28,3 +28,11 @@ def find_available_port() -> int:
def is_env_enabled(env_var: str, default: str = "0") -> bool:
"""Check if the environment variable is enabled."""
return os.getenv(env_var, default).lower() in ["true", "yes", "on", "t", "y", "1"]
def use_ray() -> bool:
return False
def use_kt() -> bool:
return False

View File

@@ -15,7 +15,6 @@
import sys
from unittest.mock import MagicMock, patch
import pytest
import torch.multiprocessing as mp
from transformers import AutoModelForCausalLM

View File

@@ -0,0 +1,89 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import subprocess
import sys
from pathlib import Path
import pytest
@pytest.mark.xfail(reason="CI machines may OOM when heavily loaded.")
@pytest.mark.runs_on(["cuda", "npu"])
def test_fsdp2_sft_trainer(tmp_path: Path):
"""Test FSDP2 SFT trainer by simulating `llamafactory-cli sft config.yaml` behavior."""
config_yaml = """\
model: Qwen/Qwen3-0.6B
trust_remote_code: true
model_class: llm
template: qwen3_nothink
kernel_config:
name: auto
include_kernels: auto
quant_config: null
dist_config:
name: fsdp2
dcp_path: null
init_config:
name: init_on_meta
### data
train_dataset: data/v1_sft_demo.yaml
### training
output_dir: {output_dir}
micro_batch_size: 1
global_batch_size: 1
cutoff_len: 2048
learning_rate: 1.0e-4
bf16: false
max_steps: 1
### sample
sample_backend: hf
max_new_tokens: 128
"""
# Create output directory
output_dir = tmp_path / "outputs"
output_dir.mkdir(parents=True, exist_ok=True)
config_file = tmp_path / "config.yaml"
config_file.write_text(config_yaml.format(output_dir=str(output_dir)))
# Set up environment variables
env = os.environ.copy()
env["USE_V1"] = "1" # Use v1 launcher
env["FORCE_TORCHRUN"] = "1" # Force distributed training via torchrun
# Run the CLI command via subprocess
# This simulates: llamafactory-cli sft config.yaml
result = subprocess.run(
[sys.executable, "-m", "llamafactory.cli", "sft", str(config_file)],
env=env,
capture_output=True,
cwd=str(Path(__file__).parent.parent.parent), # LLaMA-Factory root
)
# Decode output with error handling (progress bars may contain non-UTF-8 bytes)
stderr = result.stderr.decode("utf-8", errors="replace")
# Check the result
assert result.returncode == 0, f"Training failed with return code {result.returncode}\nSTDERR: {stderr}"
# Verify output files exist (optional - adjust based on what run_sft produces)
# assert (output_dir / "some_expected_file").exists()