mirror of https://github.com/hiyouga/LlamaFactory.git
synced 2026-01-30 06:12:04 +00:00

[v1] support training with fsdp2 (#9773)

Co-authored-by: frozenleaves <frozen@Mac.local>
Co-authored-by: Yaowei Zheng <hiyouga@buaa.edu.cn>

examples/v1/train_full/train_full_fsdp2.yaml (new file, 34 lines)
@@ -0,0 +1,34 @@
model: Qwen/Qwen3-0.6B
trust_remote_code: true
model_class: llm

template: qwen3_nothink

kernel_config:
  name: auto
  include_kernels: auto # choice: null/true/false/auto/kernel_id1,kernel_id2,kernel_id3, default is null

quant_config: null

dist_config:
  name: fsdp2
  dcp_path: null # /mnt/f/pretrain_models/Qwen3-0.6B-dcp

init_config:
  name: init_on_meta

### data
train_dataset: data/v1_sft_demo.yaml

### training
output_dir: outputs/test_fsdp2
micro_batch_size: 1
global_batch_size: 1
cutoff_len: 2048
learning_rate: 1.0e-4
bf16: false
max_steps: 10

### sample
sample_backend: hf
max_new_tokens: 128
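Judging from the test added at the end of this commit, this config is presumably launched through the v1 CLI, e.g. `USE_V1=1 FORCE_TORCHRUN=1 llamafactory-cli sft examples/v1/train_full/train_full_fsdp2.yaml`, which re-executes itself under torchrun when more than one accelerator is visible or FORCE_TORCHRUN is set.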
scripts/hf2dcp.py (new file, 55 lines)
@@ -0,0 +1,55 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Convert a HuggingFace model to DCP checkpoint format.

Usage:
    python scripts/hf2dcp.py convert --hf_path=/path/to/hf --dcp_path=/path/to/dcp

Arguments:
    hf_path: Path to the HuggingFace model directory.
    dcp_path: Output path (directory) for DCP checkpoint.
"""

import fire
import torch
import torch.distributed.checkpoint as dcp
from transformers import AutoModelForCausalLM


def convert(hf_path: str, dcp_path: str) -> None:
    """Convert HF model weights to DCP.

    Args:
        hf_path: HuggingFace model directory.
        dcp_path: Output path (directory) for DCP checkpoint.
    """
    if not hf_path or not dcp_path:
        raise ValueError("Both 'hf_path' and 'dcp_path' are required.")

    print(f"Loading HF model from {hf_path}...")
    model = AutoModelForCausalLM.from_pretrained(hf_path, device_map="cpu", torch_dtype=torch.bfloat16)

    print(f"Saving to DCP format at {dcp_path}...")
    dcp.save(model.state_dict(), checkpoint_id=dcp_path)
    print("Done!")


def help() -> None:
    """Show help message."""
    print(__doc__)


if __name__ == "__main__":
    fire.Fire({"convert": convert, "help": help, "--convert": convert})
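As a quick sanity check (not part of this commit), the resulting DCP directory can be read back into a freshly loaded HF model in a single process; like `dcp.save` above, `dcp.load` falls back to non-distributed mode when no process group is initialized. Paths below are placeholders.

import torch
import torch.distributed.checkpoint as dcp
from transformers import AutoModelForCausalLM

# Hypothetical round-trip check for a checkpoint written by scripts/hf2dcp.py.
model = AutoModelForCausalLM.from_pretrained("/path/to/hf", device_map="cpu", torch_dtype=torch.bfloat16)
state_dict = model.state_dict()  # target buffers, filled in place by dcp.load
dcp.load(state_dict=state_dict, checkpoint_id="/path/to/dcp")  # reads the DCP shards
model.load_state_dict(state_dict)  # weights now mirror the converted checkpoint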
@@ -180,6 +180,16 @@ def operate_tensorlike(fn: Callable[[...], Tensor], data: TensorLike, **kwargs)
     return result.tolist()


+def get_process_group_backend() -> str:
+    """Get backend for init process group."""
+    if get_current_accelerator().type == DeviceType.NPU:
+        return "hccl"
+    elif get_current_accelerator().type == DeviceType.CUDA:
+        return "nccl"
+    else:
+        return "gloo"
+
+
 def all_gather(tensor: Tensor, group: Optional[ProcessGroup] = None) -> Tensor:
     """Gathers the tensor from all ranks and stacks them at the first dim."""
     world_size = get_world_size()
@@ -145,7 +145,7 @@ class DistributedInterface:
         timeout = config.get("timeout", 18000)

         if self._is_distributed:
-            init_process_group(timeout=timedelta(seconds=timeout))
+            init_process_group(timeout=timedelta(seconds=timeout), backend=helper.get_process_group_backend())
             self.model_device_mesh = init_device_mesh(
                 device_type=self.current_device.type,
                 mesh_shape=self.strategy.model_mesh_shape,
@@ -20,7 +20,7 @@ from typing import Any
 from omegaconf import OmegaConf
 from transformers import HfArgumentParser

-from ...extras.misc import is_env_enabled
+from ..utils.env import is_env_enabled
 from .data_args import DataArguments
 from .model_args import ModelArguments
 from .sample_args import SampleArguments
@@ -45,6 +45,10 @@ class TrainingArguments:
         default=3,
         metadata={"help": "Number of training epochs."},
     )
+    max_steps: int | None = field(
+        default=None,
+        metadata={"help": "Maximum number of training steps. If set, overrides num_train_epochs."},
+    )
     max_grad_norm: float = field(
         default=1.0,
         metadata={"help": "Maximum gradient norm for training."},
@@ -67,6 +67,10 @@ class BaseTrainer:
         self.model_input_names = self.renderer.processor.model_input_names

         self._create_batch_generator()
+        # Calculate num_training_steps: max_steps takes priority if set
+        if self.args.max_steps is not None and self.args.max_steps > 0:
+            self.num_training_steps = self.args.max_steps
+        else:
             self.num_training_steps = self.args.num_train_epochs * len(self.train_batch_generator)

         if self.args.enable_activation_checkpointing:
@@ -98,7 +102,22 @@ class BaseTrainer:
         )

     def _shard_model(self) -> None:
-        pass
+        if self.args.dist_config is None:
+            if DistributedInterface().get_world_size(Dim.DP) > 1:
+                from torch.nn.parallel import DistributedDataParallel as DDP
+
+                logger.warning_rank0(
+                    "dist_config is None but distributed training is enabled; falling back to DistributedDataParallel."
+                )
+                device_ids = None if self.device.type == "cpu" else [self.device.index]
+                self.model = DDP(self.model, device_ids=device_ids)
+        else:
+            from ..plugins.trainer_plugins.distributed.hub import DistributedPlugin
+
+            self.model = DistributedPlugin(self.args.dist_config.name)(
+                self.model,
+                self.args.dist_config,
+            )

     def _init_optimizer(self) -> None:
         """Init optimizer."""
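The `DistributedPlugin` hub this dispatches to is added later in this commit: for `dist_config.name: fsdp2` the registered `shard_model_fsdp2` entry wraps the model via `FSDP2Engine`, while a missing `dist_config` under multi-device runs falls back to plain DDP as shown above.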
@@ -162,7 +181,9 @@ class BaseTrainer:
             step_loss += loss.item()

             grad_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.max_grad_norm).item()
-            if not torch.isfinite(grad_norm):
+            # isfinite(): argument 'input' (position 1) must be Tensor, not float
+            if not torch.isfinite(torch.tensor(grad_norm)):  # type: ignore # pyright: ignore [reportUnknownReturnType]
                 logger.warning_rank0(f"Gradient norm is not finite: {grad_norm}")
             else:
                 self.optimizer.step()
@@ -172,10 +193,17 @@ class BaseTrainer:

             step_loss, grad_norm = DistributedInterface().all_reduce([step_loss, grad_norm])
             DistributedInterface().sync()
+            if DistributedInterface().get_rank() == 0:
                 print(f"Epoch {epoch}, Step {self.global_step}, Loss: {step_loss:.4f}, Grad Norm: {grad_norm:.4f}")

+            # Check if max_steps is reached
+            if self.global_step >= self.num_training_steps:
+                logger.info_rank0(f"Reached max_steps ({self.num_training_steps}), stopping training.")
+                return
+
     def save_model(self) -> None:
         """Save the model."""
-        self.model.save_pretrained(self.args.output_dir)
+        model_to_save = self.model.module if hasattr(self.model, "module") else self.model
+        model_to_save.save_pretrained(self.args.output_dir)
         self.renderer.processor.save_pretrained(self.args.output_dir)
         logger.info_rank0(f"Model saved to {self.args.output_dir}")
@@ -30,7 +30,7 @@ from torch.utils.data import default_collate
 from torchdata.stateful_dataloader import StatefulDataLoader
 from torchdata.stateful_dataloader.sampler import StatefulDistributedSampler

-from ...accelerator.interface import DistributedInterface
+from ...accelerator.interface import Dim, DistributedInterface
 from ...config import BatchingStrategy
 from ...utils import logging
 from ...utils.helper import pad_and_truncate
@@ -83,8 +83,7 @@ class BatchGenerator(Iterator):
         self.pin_memory = pin_memory
         self.drop_last = drop_last
         # TODO: support length and infinity
-        dp_size = DistributedInterface().get_world_size("dp")
-
+        dp_size = DistributedInterface().get_world_size(Dim.DP)

         if self.global_batch_size is None:
             self.global_batch_size = dp_size * micro_batch_size
@@ -126,8 +125,8 @@ class BatchGenerator(Iterator):
         if len(self.dataset) != -1:
             sampler = StatefulDistributedSampler(
                 self.dataset,
-                num_replicas=DistributedInterface().get_world_size("dp"),
-                rank=DistributedInterface().get_rank("dp"),
+                num_replicas=DistributedInterface().get_world_size(Dim.DP),
+                rank=DistributedInterface().get_rank(Dim.DP),
                 shuffle=True,
                 seed=0,
                 drop_last=self.drop_last,
@@ -142,6 +141,7 @@ class BatchGenerator(Iterator):
             num_workers=self.batching_workers,
             collate_fn=self.renderer.process_samples,
             pin_memory=self.pin_memory,
+            pin_memory_device=DistributedInterface().current_device.type,
             drop_last=self.drop_last,
         )
         if self.batching_strategy == BatchingStrategy.NORMAL:
@@ -12,9 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import os
+import subprocess
 import sys
-
-from ..extras.env import VERSION, print_env
+from copy import deepcopy


 USAGE = (
@@ -27,27 +28,97 @@ USAGE = (
     + "-" * 70
 )

-WELCOME = (
-    "-" * 58
-    + "\n"
-    + f"| Welcome to LLaMA Factory, version {VERSION}"
-    + " " * (21 - len(VERSION))
-    + "|\n|"
-    + " " * 56
-    + "|\n"
-    + "| Project page: https://github.com/hiyouga/LLaMA-Factory |\n"
-    + "-" * 58
-)
+_DIST_TRAIN_COMMANDS = ("train", "sft", "dpo", "rm")


 def launch():
+    from .accelerator.helper import get_device_count
+    from .utils.env import find_available_port, is_env_enabled, use_kt, use_ray
+    from .utils.logging import get_logger
+
+    logger = get_logger(__name__)
+
+    # NOTE:
+    # `llamafactory-cli <command> ...` enters here first.
+    # We may re-launch via `torchrun` for distributed training. In that case we must
+    # forward `<command>` as argv[1] to the re-executed script, otherwise the script
+    # will misinterpret the first user argument (e.g. yaml config) as the command.
     command = sys.argv.pop(1) if len(sys.argv) > 1 else "help"

-    if command == "sft":  # train command will fallback to sft command
-        from .trainers.sft_trainer import run_sft
-
-        run_sft()
+    if command in _DIST_TRAIN_COMMANDS and (
+        is_env_enabled("FORCE_TORCHRUN") or (get_device_count() > 1 and not use_ray() and not use_kt())
+    ):
+        nnodes = os.getenv("NNODES", "1")
+        node_rank = os.getenv("NODE_RANK", "0")
+        nproc_per_node = os.getenv("NPROC_PER_NODE", str(get_device_count()))
+        master_addr = os.getenv("MASTER_ADDR", "127.0.0.1")
+        master_port = os.getenv("MASTER_PORT", str(find_available_port()))
+        logger.info_rank0(f"Initializing {nproc_per_node} distributed tasks at: {master_addr}:{master_port}")
+        if int(nnodes) > 1:
+            logger.info_rank0(f"Multi-node training enabled: num nodes: {nnodes}, node rank: {node_rank}")
+
+        # elastic launch support
+        max_restarts = os.getenv("MAX_RESTARTS", "0")
+        rdzv_id = os.getenv("RDZV_ID")
+        min_nnodes = os.getenv("MIN_NNODES")
+        max_nnodes = os.getenv("MAX_NNODES")
+
+        env = deepcopy(os.environ)
+        if is_env_enabled("OPTIM_TORCH", "1"):
+            # optimize DDP, see https://zhuanlan.zhihu.com/p/671834539
+            env["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
+            env["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
+
+        torchrun_args = [
+            "torchrun",
+            "--nproc-per-node",
+            nproc_per_node,
+        ]
+        if rdzv_id is not None:
+            # launch elastic job with fault tolerant support when possible
+            # see also https://docs.pytorch.org/docs/stable/elastic/train_script.html
+            rdzv_nnodes = nnodes
+            # elastic number of nodes if MIN_NNODES and MAX_NNODES are set
+            if min_nnodes is not None and max_nnodes is not None:
+                rdzv_nnodes = f"{min_nnodes}:{max_nnodes}"
+
+            torchrun_args.extend(
+                [
+                    "--nnodes",
+                    rdzv_nnodes,
+                    "--rdzv-id",
+                    rdzv_id,
+                    "--rdzv-backend",
+                    "c10d",
+                    "--rdzv-endpoint",
+                    f"{master_addr}:{master_port}",
+                    "--max-restarts",
+                    max_restarts,
+                ]
+            )
+        else:
+            # NOTE: DO NOT USE shell=True to avoid security risk
+            torchrun_args.extend(
+                [
+                    "--nnodes",
+                    nnodes,
+                    "--node_rank",
+                    node_rank,
+                    "--master_addr",
+                    master_addr,
+                    "--master_port",
+                    master_port,
+                ]
+            )
+
+        script_args = [__file__, command] + sys.argv[1:]
+        process = subprocess.run(
+            torchrun_args + script_args,
+            env=env,
+            check=True,
+        )
+
+        sys.exit(process.returncode)
+
     elif command == "chat":
         from .samplers.cli_sampler import run_chat
@@ -55,17 +126,54 @@ def launch():
         run_chat()

     elif command == "env":
-        print_env()
+        raise NotImplementedError("Environment information is not implemented yet.")

     elif command == "version":
-        print(WELCOME)
+        raise NotImplementedError("Version information is not implemented yet.")

     elif command == "help":
         print(USAGE)

+    elif command in _DIST_TRAIN_COMMANDS:
+        # Single GPU training without torchrun
+        if command in ("train", "sft"):
+            from llamafactory.v1.trainers.sft_trainer import run_sft
+
+            run_sft()
+        elif command == "dpo":
+            raise NotImplementedError("DPO trainer is not implemented yet.")
+        elif command == "rm":
+            raise NotImplementedError("RM trainer is not implemented yet.")
+
     else:
         print(f"Unknown command: {command}.\n{USAGE}")


+def main():
+    # sys.argv[1] contains the command (sft/dpo/rm/train), sys.argv[2:] contains the rest args
+    command = sys.argv[1] if len(sys.argv) > 1 else "sft"
+
+    # Routing needs the sub-command, but downstream trainers usually expect argv without it.
+    if command in _DIST_TRAIN_COMMANDS:
+        sys.argv.pop(1)
+    else:
+        # Backward-compat: if someone runs `torchrun launcher.py config.yaml`,
+        # treat it as sft by default.
+        if len(sys.argv) > 1 and sys.argv[1].endswith((".yaml", ".yml")):
+            command = "sft"
+
+    if command in ("train", "sft"):
+        from llamafactory.v1.trainers.sft_trainer import run_sft
+
+        run_sft()
+    elif command == "dpo":
+        # from llamafactory.v1.trainers.dpo_trainer import run_dpo
+        # run_dpo()
+        raise NotImplementedError("DPO trainer is not implemented yet.")
+    elif command == "rm":
+        # from llamafactory.v1.trainers.rm_trainer import run_rm
+        # run_rm()
+        raise NotImplementedError("RM trainer is not implemented yet.")
+
+
 if __name__ == "__main__":
-    pass
+    main()
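For multi-node jobs, the re-launch path above reads NNODES, NODE_RANK, NPROC_PER_NODE, MASTER_ADDR and MASTER_PORT from the environment (for example, a hypothetical two-node run would export NNODES=2 and NODE_RANK=0 or 1 on the respective hosts along with a shared MASTER_ADDR/MASTER_PORT before invoking `llamafactory-cli sft config.yaml`); setting RDZV_ID additionally switches to the c10d elastic rendezvous, honoring MAX_RESTARTS and, when both are set, MIN_NNODES:MAX_NNODES.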
src/llamafactory/v1/plugins/trainer_plugins/distributed/fsdp2.py (new file, 399 lines)
@@ -0,0 +1,399 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import os

import torch
import torch.nn as nn
from torch.distributed.checkpoint.state_dict import StateDictOptions, get_model_state_dict, set_model_state_dict
from torch.distributed.fsdp import (
    CPUOffloadPolicy,
    MixedPrecisionPolicy,
    fully_shard,
)
from transformers import PreTrainedModel

from ....accelerator.helper import get_current_accelerator
from ....accelerator.interface import DistributedInterface
from ....utils.logging import get_logger


logger = get_logger(__name__)


def get_transformer_layer_cls(model: PreTrainedModel) -> type[nn.Module] | None:
    no_split_modules = getattr(model, "_no_split_modules", None)
    if no_split_modules:
        if isinstance(no_split_modules, (list, tuple)):
            for name, module in model.named_modules():
                for cls_name in no_split_modules:
                    if module.__class__.__name__ == cls_name:
                        return module.__class__
    if hasattr(model, "model") and hasattr(model.model, "layers"):
        return type(model.model.layers[0])
    if hasattr(model, "layers"):
        return type(model.layers[0])

    return None


class FSDP2Engine:
    def __init__(self, dist_config: dict):
        self.dist_interface = DistributedInterface()
        self.rank = self.dist_interface.get_rank()
        self.local_rank = self.dist_interface.get_local_rank()
        self.world_size = self.dist_interface.get_world_size()
        self.mixed_precision = dist_config.get("mixed_precision", "bf16")
        self.reshard_after_forward = dist_config.get("reshard_after_forward", True)
        self.offload_params = dist_config.get("offload_params", False)
        self.pin_memory = dist_config.get("pin_memory", True)
        self.dcp_path = dist_config.get("dcp_path", None)
        self.device_mesh = self.dist_interface.data_device_mesh

        if self.device_mesh is None:
            logger.warning(
                "Device Mesh not found in DistributedInterface. FSDP2 might fail if not running in distributed mode."
            )

        if self.device_mesh is not None:
            try:
                self.fsdp_mesh = self.device_mesh["dp"]
            except Exception:
                self.fsdp_mesh = self.device_mesh

            logger.info(f"Using Device Mesh: {self.fsdp_mesh}")
        else:
            self.fsdp_mesh = None

    def get_mp_policy(self) -> MixedPrecisionPolicy:
        if self.mixed_precision == "bf16":
            param_dtype = torch.bfloat16
            reduce_dtype = torch.float32
        elif self.mixed_precision == "fp16":
            param_dtype = torch.float16
            reduce_dtype = torch.float32
        else:
            param_dtype = torch.float32
            reduce_dtype = torch.float32

        return MixedPrecisionPolicy(
            param_dtype=param_dtype,
            reduce_dtype=reduce_dtype,
            cast_forward_inputs=True,
        )

    def prepare_model(self, model: PreTrainedModel) -> PreTrainedModel:
        if self.fsdp_mesh is None:
            logger.warning("No FSDP Mesh available, skipping FSDP wrapping.")
            return model

        mp_policy = self.get_mp_policy()
        layer_cls = get_transformer_layer_cls(model)

        if layer_cls is None:
            logger.warning(
                "Could not identify Transformer Layer class, applying FSDP to the whole model structure only."
            )
            transformer_layer_cls_to_wrap = set()
        else:
            logger.info(f"Applying per-layer FSDP to {layer_cls.__name__}")
            transformer_layer_cls_to_wrap = {layer_cls}

        for name, module in model.named_modules():
            should_wrap = False

            if type(module) in transformer_layer_cls_to_wrap:
                should_wrap = True
            elif isinstance(module, nn.Embedding):
                if not getattr(model.config, "tie_word_embeddings", True):
                    should_wrap = True

            if should_wrap:
                fully_shard(
                    module,
                    mesh=self.fsdp_mesh,
                    reshard_after_forward=self.reshard_after_forward,
                    mp_policy=mp_policy,
                    offload_policy=CPUOffloadPolicy(pin_memory=self.pin_memory) if self.offload_params else None,
                )

        use_gradient_checkpointing = True  # Could be configurable
        if use_gradient_checkpointing:
            if self.rank == 0:
                logger.info("Enabling gradient checkpointing (transformers native)...")

            model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})

            if hasattr(model, "enable_input_require_grads"):
                model.enable_input_require_grads()
            else:

                def make_inputs_require_grad(module, input, output):
                    output.requires_grad_(True)

                model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

        fully_shard(
            model,
            mesh=self.fsdp_mesh,
            reshard_after_forward=self.reshard_after_forward,
            mp_policy=mp_policy,
            offload_policy=CPUOffloadPolicy(pin_memory=self.pin_memory) if self.offload_params else None,
        )

        return model

    @torch.no_grad()
    def materialize_and_load(self, model: PreTrainedModel, hf_model_path: str, dcp_path: str = None):
        if self.rank == 0:
            logger.info("Materializing sharded model params...")

        device = get_current_accelerator()
        model.to_empty(device=device)

        if dcp_path and os.path.exists(dcp_path):
            if self.rank == 0:
                logger.info(f"DCP path found at {dcp_path}. Using efficient Sharded Loading (DCP Load).")
            self._load_from_dcp(model, dcp_path)
        else:
            if self.rank == 0:
                if dcp_path:
                    logger.warning(f"DCP path {dcp_path} not found.")
                logger.info("Using HF Meta Loading (Chunk Load).")
            self._load_weights_from_hf_checkpoint(model, hf_model_path)

        return model

    def shard_model(self, model: PreTrainedModel) -> PreTrainedModel:
        if model.device.type == "meta":
            model = self.prepare_model(model)
            model = self.materialize_and_load(model, hf_model_path=model.config.name_or_path, dcp_path=self.dcp_path)
        else:
            model = self.prepare_model(model)
        return model

    def _load_from_dcp(self, model: PreTrainedModel, dcp_path: str):
        import torch.distributed.checkpoint as dcp

        try:
            if self.rank == 0:
                logger.info(f"Loading distributed checkpoint from {dcp_path} ...")

            options = StateDictOptions(full_state_dict=False, cpu_offload=True)
            local_state_dict = get_model_state_dict(model, options=options)
            dcp.load(state_dict=local_state_dict, checkpoint_id=dcp_path)
            set_model_state_dict(model, local_state_dict, options=options)

            if self.rank == 0:
                logger.info("DCP weights loaded successfully.")

        except Exception as e:
            logger.error(f"Failed to load from DCP: {e}")
            raise e

    def _load_weights_from_hf_checkpoint(self, model, hf_model_path):
        import glob
        import json

        hf_model_path = self._resolve_hf_checkpoint_dir(hf_model_path)

        if self.rank == 0:
            logger.info(f"Loading weights from {hf_model_path} ...")

        index_file = os.path.join(hf_model_path, "model.safetensors.index.json")
        is_safetensors = True
        checkpoint_files = []

        if os.path.exists(index_file):
            with open(index_file) as f:
                index = json.load(f)
            checkpoint_files = sorted(set(index["weight_map"].values()))
            checkpoint_files = [os.path.join(hf_model_path, f) for f in checkpoint_files]
        elif os.path.exists(os.path.join(hf_model_path, "model.safetensors")):
            checkpoint_files = [os.path.join(hf_model_path, "model.safetensors")]
        else:
            is_safetensors = False
            index_file = os.path.join(hf_model_path, "pytorch_model.bin.index.json")
            if os.path.exists(index_file):
                with open(index_file) as f:
                    index = json.load(f)
                checkpoint_files = sorted(set(index["weight_map"].values()))
                checkpoint_files = [os.path.join(hf_model_path, f) for f in checkpoint_files]
            elif os.path.exists(os.path.join(hf_model_path, "pytorch_model.bin")):
                checkpoint_files = [os.path.join(hf_model_path, "pytorch_model.bin")]
            else:
                checkpoint_files = sorted(glob.glob(os.path.join(hf_model_path, "*.safetensors")))
                if checkpoint_files:
                    is_safetensors = True
                else:
                    checkpoint_files = sorted(glob.glob(os.path.join(hf_model_path, "*.bin")))

        if not checkpoint_files:
            raise ValueError(f"No checkpoint files found in {hf_model_path}")

        param_map = dict(model.named_parameters())
        total_files = len(checkpoint_files)

        for i, ckpt_file in enumerate(checkpoint_files):
            if self.rank == 0:
                logger.info(f"[{i + 1}/{total_files}] Loading {os.path.basename(ckpt_file)} ...")

            if is_safetensors:
                from safetensors import safe_open

                with safe_open(ckpt_file, framework="pt", device="cpu") as f:
                    for key in f.keys():
                        if key in param_map:
                            tensor = f.get_tensor(key)
                            self._copy_weights(param_map[key], tensor)
            else:
                state_dict = torch.load(ckpt_file, map_location="cpu")
                for key, tensor in state_dict.items():
                    if key in param_map:
                        self._copy_weights(param_map[key], tensor)
                del state_dict
            gc.collect()

    def _resolve_hf_checkpoint_dir(self, hf_model_path: str) -> str:
        """Resolve a HF model identifier or local path to a local directory containing checkpoint files.

        - If `hf_model_path` is an existing directory, return it.
        - If it's a file path, return its parent directory.
        - Otherwise treat it as a Hugging Face Hub repo id and download/resolve to the local cache dir.
        """
        if not hf_model_path:
            return hf_model_path

        # Local directory or file path.
        if os.path.isdir(hf_model_path):
            return hf_model_path
        if os.path.isfile(hf_model_path):
            return os.path.dirname(hf_model_path)

        # HuggingFace Hub repo id: snapshot to local cache so we can glob/index files.
        try:
            from huggingface_hub import snapshot_download
        except ImportError as e:
            raise ValueError(
                f"hf_model_path='{hf_model_path}' does not exist locally and huggingface_hub is not available "
                f"to download it. Please provide a local model directory or install huggingface_hub. Error: {e}"
            ) from e

        revision = os.getenv("HF_REVISION")
        offline = os.getenv("HF_HUB_OFFLINE") == "1" or os.getenv("TRANSFORMERS_OFFLINE") == "1"

        # In distributed runs, let rank0 download first to avoid N-way concurrent downloads.
        if torch.distributed.is_available() and torch.distributed.is_initialized():
            if self.rank == 0:
                local_dir = snapshot_download(
                    repo_id=hf_model_path,
                    revision=revision,
                    local_files_only=offline,
                    allow_patterns=[
                        "*.safetensors",
                        "*.bin",
                        "*.index.json",
                        "model.safetensors",
                        "model.safetensors.index.json",
                        "pytorch_model.bin",
                        "pytorch_model.bin.index.json",
                        "config.json",
                    ],
                )
                logger.info(f"Resolved HF repo id '{hf_model_path}' to local dir: {local_dir}")
            torch.distributed.barrier()
            if self.rank != 0:
                local_dir = snapshot_download(
                    repo_id=hf_model_path,
                    revision=revision,
                    local_files_only=True,
                    allow_patterns=[
                        "*.safetensors",
                        "*.bin",
                        "*.index.json",
                        "model.safetensors",
                        "model.safetensors.index.json",
                        "pytorch_model.bin",
                        "pytorch_model.bin.index.json",
                        "config.json",
                    ],
                )
            return local_dir

        local_dir = snapshot_download(
            repo_id=hf_model_path,
            revision=revision,
            local_files_only=offline,
            allow_patterns=[
                "*.safetensors",
                "*.bin",
                "*.index.json",
                "model.safetensors",
                "model.safetensors.index.json",
                "pytorch_model.bin",
                "pytorch_model.bin.index.json",
                "config.json",
            ],
        )
        if self.rank == 0:
            logger.info(f"Resolved HF repo id '{hf_model_path}' to local dir: {local_dir}")
        return local_dir

    def _copy_weights(self, param, loaded_tensor):
        from torch.distributed._tensor import DTensor, Shard

        if loaded_tensor.dtype != param.dtype:
            loaded_tensor = loaded_tensor.to(param.dtype)

        if isinstance(param, DTensor):
            shard_placement = None
            mesh_dim = -1

            for i, placement in enumerate(param.placements):
                if isinstance(placement, Shard):
                    shard_placement = placement
                    mesh_dim = i
                    break

            local_tensor = param.to_local()

            if shard_placement is None:
                local_tensor.copy_(loaded_tensor)
            else:
                dim = shard_placement.dim
                mesh = param.device_mesh
                my_coordinate = mesh.get_coordinate()
                if my_coordinate is None:
                    return

                rank_in_dim = my_coordinate[mesh_dim]
                world_size_in_dim = mesh.size(mesh_dim)

                full_size = param.shape[dim]
                chunk_size = (full_size + world_size_in_dim - 1) // world_size_in_dim

                start = rank_in_dim * chunk_size
                end = min(start + chunk_size, full_size)

                if start >= full_size:
                    return

                sliced_tensor = loaded_tensor.narrow(dim, start, end - start)

                slices = [slice(None)] * local_tensor.ndim
                slices[dim] = slice(0, sliced_tensor.shape[dim])
                local_tensor[tuple(slices)].copy_(sliced_tensor)
        else:
            param.data.copy_(loaded_tensor)
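For readers new to the FSDP2 API, the wrapping pattern the engine applies reduces to calling `fully_shard` once per transformer block and once on the root module, sharing a single mixed-precision policy. A minimal sketch, assuming a one-dimensional "dp" device mesh and a standard HF decoder whose blocks live at `model.model.layers` (the helper name below is illustrative, not part of this commit):

import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import MixedPrecisionPolicy, fully_shard


def wrap_fsdp2_sketch(model, world_size: int):
    # one-dimensional data-parallel mesh, analogous to the "dp" dim used by DistributedInterface
    mesh = init_device_mesh("cuda", (world_size,), mesh_dim_names=("dp",))
    mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32, cast_forward_inputs=True)
    for layer in model.model.layers:  # shard every transformer block first
        fully_shard(layer, mesh=mesh, reshard_after_forward=True, mp_policy=mp_policy)
    fully_shard(model, mesh=mesh, reshard_after_forward=True, mp_policy=mp_policy)  # then the root module
    return model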
@@ -0,0 +1,34 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from ....config.arg_utils import PluginConfig
from ....utils.plugin import BasePlugin
from ....utils.types import HFModel


class DistributedPlugin(BasePlugin):
    def __call__(self, model: HFModel, dist_config: PluginConfig, **kwargs) -> HFModel:
        return super().__call__(model, dist_config, **kwargs)


@DistributedPlugin("fsdp2").register()
def shard_model_fsdp2(model: HFModel, dist_config: PluginConfig) -> HFModel:
    from .fsdp2 import FSDP2Engine

    return FSDP2Engine(dist_config).shard_model(model)


@DistributedPlugin("deepspeed").register()
def shard_model_deepspeed(model: HFModel, dist_config: PluginConfig) -> HFModel:
    return model
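This registry is the counterpart of the `BaseTrainer._shard_model` change above: `DistributedPlugin(self.args.dist_config.name)(self.model, self.args.dist_config)` resolves to `shard_model_fsdp2` when the config names `fsdp2`, so the trainer never imports `FSDP2Engine` directly; the `deepspeed` entry is currently a pass-through placeholder.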
@@ -28,3 +28,11 @@ def find_available_port() -> int:
 def is_env_enabled(env_var: str, default: str = "0") -> bool:
     """Check if the environment variable is enabled."""
     return os.getenv(env_var, default).lower() in ["true", "yes", "on", "t", "y", "1"]
+
+
+def use_ray() -> bool:
+    return False
+
+
+def use_kt() -> bool:
+    return False
@@ -15,7 +15,6 @@
 import sys
 from unittest.mock import MagicMock, patch

-import pytest
 import torch.multiprocessing as mp
 from transformers import AutoModelForCausalLM
tests_v1/trainers/test_fsdp2_sft_trainer.py (new file, 89 lines)
@@ -0,0 +1,89 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import subprocess
import sys
from pathlib import Path

import pytest


@pytest.mark.xfail(reason="CI machines may OOM when heavily loaded.")
@pytest.mark.runs_on(["cuda", "npu"])
def test_fsdp2_sft_trainer(tmp_path: Path):
    """Test FSDP2 SFT trainer by simulating `llamafactory-cli sft config.yaml` behavior."""
    config_yaml = """\
model: Qwen/Qwen3-0.6B
trust_remote_code: true
model_class: llm

template: qwen3_nothink

kernel_config:
  name: auto
  include_kernels: auto

quant_config: null

dist_config:
  name: fsdp2
  dcp_path: null

init_config:
  name: init_on_meta

### data
train_dataset: data/v1_sft_demo.yaml

### training
output_dir: {output_dir}
micro_batch_size: 1
global_batch_size: 1
cutoff_len: 2048
learning_rate: 1.0e-4
bf16: false
max_steps: 1

### sample
sample_backend: hf
max_new_tokens: 128
"""
    # Create output directory
    output_dir = tmp_path / "outputs"
    output_dir.mkdir(parents=True, exist_ok=True)
    config_file = tmp_path / "config.yaml"
    config_file.write_text(config_yaml.format(output_dir=str(output_dir)))

    # Set up environment variables
    env = os.environ.copy()
    env["USE_V1"] = "1"  # Use v1 launcher
    env["FORCE_TORCHRUN"] = "1"  # Force distributed training via torchrun

    # Run the CLI command via subprocess
    # This simulates: llamafactory-cli sft config.yaml
    result = subprocess.run(
        [sys.executable, "-m", "llamafactory.cli", "sft", str(config_file)],
        env=env,
        capture_output=True,
        cwd=str(Path(__file__).parent.parent.parent),  # LLaMA-Factory root
    )

    # Decode output with error handling (progress bars may contain non-UTF-8 bytes)
    stderr = result.stderr.decode("utf-8", errors="replace")

    # Check the result
    assert result.returncode == 0, f"Training failed with return code {result.returncode}\nSTDERR: {stderr}"

    # Verify output files exist (optional - adjust based on what run_sft produces)
    # assert (output_dir / "some_expected_file").exists()
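Locally this would presumably be run with `pytest tests_v1/trainers/test_fsdp2_sft_trainer.py` on a multi-accelerator machine; it is marked xfail because heavily loaded CI machines may OOM.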