diff --git a/examples/v1/train_full/train_full_fsdp2.yaml b/examples/v1/train_full/train_full_fsdp2.yaml new file mode 100644 index 000000000..3bc5e70cc --- /dev/null +++ b/examples/v1/train_full/train_full_fsdp2.yaml @@ -0,0 +1,34 @@ +model: Qwen/Qwen3-0.6B +trust_remote_code: true +model_class: llm + +template: qwen3_nothink + +kernel_config: + name: auto + include_kernels: auto # choice: null/true/false/auto/kernel_id1,kernel_id2,kernel_id3, default is null + +quant_config: null + +dist_config: + name: fsdp2 + dcp_path: null # /mnt/f/pretrain_models/Qwen3-0.6B-dcp + +init_config: + name: init_on_meta + +### data +train_dataset: data/v1_sft_demo.yaml + +### training +output_dir: outputs/test_fsdp2 +micro_batch_size: 1 +global_batch_size: 1 +cutoff_len: 2048 +learning_rate: 1.0e-4 +bf16: false +max_steps: 10 + +### sample +sample_backend: hf +max_new_tokens: 128 diff --git a/scripts/hf2dcp.py b/scripts/hf2dcp.py new file mode 100644 index 000000000..9e6fbf8c5 --- /dev/null +++ b/scripts/hf2dcp.py @@ -0,0 +1,55 @@ +# Copyright 2025 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Convert a HuggingFace model to DCP checkpoint format. + +Usage: + python scripts/hf2dcp.py convert --hf_path=/path/to/hf --dcp_path=/path/to/dcp + +Arguments: + hf_path: Path to the HuggingFace model directory. + dcp_path: Output path (directory) for DCP checkpoint. +""" + +import fire +import torch +import torch.distributed.checkpoint as dcp +from transformers import AutoModelForCausalLM + + +def convert(hf_path: str, dcp_path: str) -> None: + """Convert HF model weights to DCP. + + Args: + hf_path: HuggingFace model directory. + dcp_path: Output path (directory) for DCP checkpoint. 
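+
+    Note: the resulting directory can then be passed as ``dcp_path`` under ``dist_config`` in a v1
+    FSDP2 config (see examples/v1/train_full/train_full_fsdp2.yaml) to enable sharded DCP loading.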
+ """ + if not hf_path or not dcp_path: + raise ValueError("Both 'hf_path' and 'dcp_path' are required.") + + print(f"Loading HF model from {hf_path}...") + model = AutoModelForCausalLM.from_pretrained(hf_path, device_map="cpu", torch_dtype=torch.bfloat16) + + print(f"Saving to DCP format at {dcp_path}...") + dcp.save(model.state_dict(), checkpoint_id=dcp_path) + print("Done!") + + +def help() -> None: + """Show help message.""" + print(__doc__) + + +if __name__ == "__main__": + fire.Fire({"convert": convert, "help": help, "--convert": convert}) diff --git a/src/llamafactory/v1/accelerator/helper.py b/src/llamafactory/v1/accelerator/helper.py index c9cee520c..803ed54e3 100644 --- a/src/llamafactory/v1/accelerator/helper.py +++ b/src/llamafactory/v1/accelerator/helper.py @@ -180,6 +180,16 @@ def operate_tensorlike(fn: Callable[[...], Tensor], data: TensorLike, **kwargs) return result.tolist() +def get_process_group_backend() -> str: + """Get backend for init process group.""" + if get_current_accelerator().type == DeviceType.NPU: + return "hccl" + elif get_current_accelerator().type == DeviceType.CUDA: + return "nccl" + else: + return "gloo" + + def all_gather(tensor: Tensor, group: Optional[ProcessGroup] = None) -> Tensor: """Gathers the tensor from all ranks and stacks them at the first dim.""" world_size = get_world_size() diff --git a/src/llamafactory/v1/accelerator/interface.py b/src/llamafactory/v1/accelerator/interface.py index 6a4c51962..e31afdc79 100644 --- a/src/llamafactory/v1/accelerator/interface.py +++ b/src/llamafactory/v1/accelerator/interface.py @@ -145,7 +145,7 @@ class DistributedInterface: timeout = config.get("timeout", 18000) if self._is_distributed: - init_process_group(timeout=timedelta(seconds=timeout)) + init_process_group(timeout=timedelta(seconds=timeout), backend=helper.get_process_group_backend()) self.model_device_mesh = init_device_mesh( device_type=self.current_device.type, mesh_shape=self.strategy.model_mesh_shape, diff --git a/src/llamafactory/v1/config/arg_parser.py b/src/llamafactory/v1/config/arg_parser.py index 38eddd54a..2122a569f 100644 --- a/src/llamafactory/v1/config/arg_parser.py +++ b/src/llamafactory/v1/config/arg_parser.py @@ -20,7 +20,7 @@ from typing import Any from omegaconf import OmegaConf from transformers import HfArgumentParser -from ...extras.misc import is_env_enabled +from ..utils.env import is_env_enabled from .data_args import DataArguments from .model_args import ModelArguments from .sample_args import SampleArguments diff --git a/src/llamafactory/v1/config/training_args.py b/src/llamafactory/v1/config/training_args.py index 67d09653c..8fe0c1cf1 100644 --- a/src/llamafactory/v1/config/training_args.py +++ b/src/llamafactory/v1/config/training_args.py @@ -45,6 +45,10 @@ class TrainingArguments: default=3, metadata={"help": "Number of training epochs."}, ) + max_steps: int | None = field( + default=None, + metadata={"help": "Maximum number of training steps. 
If set, overrides num_train_epochs."}, + ) max_grad_norm: float = field( default=1.0, metadata={"help": "Maximum gradient norm for training."}, diff --git a/src/llamafactory/v1/core/base_trainer.py b/src/llamafactory/v1/core/base_trainer.py index e474b37a9..2d97c8073 100644 --- a/src/llamafactory/v1/core/base_trainer.py +++ b/src/llamafactory/v1/core/base_trainer.py @@ -67,7 +67,11 @@ class BaseTrainer: self.model_input_names = self.renderer.processor.model_input_names self._create_batch_generator() - self.num_training_steps = self.args.num_train_epochs * len(self.train_batch_generator) + # Calculate num_training_steps: max_steps takes priority if set + if self.args.max_steps is not None and self.args.max_steps > 0: + self.num_training_steps = self.args.max_steps + else: + self.num_training_steps = self.args.num_train_epochs * len(self.train_batch_generator) if self.args.enable_activation_checkpointing: self.model.gradient_checkpointing_enable({"use_reentrant": False}) @@ -98,7 +102,22 @@ class BaseTrainer: ) def _shard_model(self) -> None: - pass + if self.args.dist_config is None: + if DistributedInterface().get_world_size(Dim.DP) > 1: + from torch.nn.parallel import DistributedDataParallel as DDP + + logger.warning_rank0( + "dist_config is None but distributed training is enabled; falling back to DistributedDataParallel." + ) + device_ids = None if self.device.type == "cpu" else [self.device.index] + self.model = DDP(self.model, device_ids=device_ids) + else: + from ..plugins.trainer_plugins.distributed.hub import DistributedPlugin + + self.model = DistributedPlugin(self.args.dist_config.name)( + self.model, + self.args.dist_config, + ) def _init_optimizer(self) -> None: """Init optimizer.""" @@ -162,7 +181,9 @@ class BaseTrainer: step_loss += loss.item() grad_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.max_grad_norm).item() - if not torch.isfinite(grad_norm): + + # isfinite(): argument 'input' (position 1) must be Tensor, not float + if not torch.isfinite(torch.tensor(grad_norm)): # type: ignore # pyright: ignore [reportUnknownReturnType] logger.warning_rank0(f"Gradient norm is not finite: {grad_norm}") else: self.optimizer.step() @@ -172,10 +193,17 @@ class BaseTrainer: step_loss, grad_norm = DistributedInterface().all_reduce([step_loss, grad_norm]) DistributedInterface().sync() - print(f"Epoch {epoch}, Step {self.global_step}, Loss: {step_loss:.4f}, Grad Norm: {grad_norm:.4f}") + if DistributedInterface().get_rank() == 0: + print(f"Epoch {epoch}, Step {self.global_step}, Loss: {step_loss:.4f}, Grad Norm: {grad_norm:.4f}") + + # Check if max_steps is reached + if self.global_step >= self.num_training_steps: + logger.info_rank0(f"Reached max_steps ({self.num_training_steps}), stopping training.") + return def save_model(self) -> None: """Save the model.""" - self.model.save_pretrained(self.args.output_dir) + model_to_save = self.model.module if hasattr(self.model, "module") else self.model + model_to_save.save_pretrained(self.args.output_dir) self.renderer.processor.save_pretrained(self.args.output_dir) logger.info_rank0(f"Model saved to {self.args.output_dir}") diff --git a/src/llamafactory/v1/core/utils/batching.py b/src/llamafactory/v1/core/utils/batching.py index 100bae4e5..e87a95974 100644 --- a/src/llamafactory/v1/core/utils/batching.py +++ b/src/llamafactory/v1/core/utils/batching.py @@ -30,7 +30,7 @@ from torch.utils.data import default_collate from torchdata.stateful_dataloader import StatefulDataLoader from torchdata.stateful_dataloader.sampler 
import StatefulDistributedSampler -from ...accelerator.interface import DistributedInterface +from ...accelerator.interface import Dim, DistributedInterface from ...config import BatchingStrategy from ...utils import logging from ...utils.helper import pad_and_truncate @@ -83,8 +83,7 @@ class BatchGenerator(Iterator): self.pin_memory = pin_memory self.drop_last = drop_last # TODO: support length and infinity - - dp_size = DistributedInterface().get_world_size("dp") + dp_size = DistributedInterface().get_world_size(Dim.DP) if self.global_batch_size is None: self.global_batch_size = dp_size * micro_batch_size @@ -126,8 +125,8 @@ class BatchGenerator(Iterator): if len(self.dataset) != -1: sampler = StatefulDistributedSampler( self.dataset, - num_replicas=DistributedInterface().get_world_size("dp"), - rank=DistributedInterface().get_rank("dp"), + num_replicas=DistributedInterface().get_world_size(Dim.DP), + rank=DistributedInterface().get_rank(Dim.DP), shuffle=True, seed=0, drop_last=self.drop_last, @@ -142,6 +141,7 @@ class BatchGenerator(Iterator): num_workers=self.batching_workers, collate_fn=self.renderer.process_samples, pin_memory=self.pin_memory, + pin_memory_device=DistributedInterface().current_device.type, drop_last=self.drop_last, ) if self.batching_strategy == BatchingStrategy.NORMAL: diff --git a/src/llamafactory/v1/launcher.py b/src/llamafactory/v1/launcher.py index 9f04ab398..b9b077b1c 100644 --- a/src/llamafactory/v1/launcher.py +++ b/src/llamafactory/v1/launcher.py @@ -12,9 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os +import subprocess import sys - -from ..extras.env import VERSION, print_env +from copy import deepcopy USAGE = ( @@ -27,27 +28,97 @@ USAGE = ( + "-" * 70 ) - -WELCOME = ( - "-" * 58 - + "\n" - + f"| Welcome to LLaMA Factory, version {VERSION}" - + " " * (21 - len(VERSION)) - + "|\n|" - + " " * 56 - + "|\n" - + "| Project page: https://github.com/hiyouga/LLaMA-Factory |\n" - + "-" * 58 -) +_DIST_TRAIN_COMMANDS = ("train", "sft", "dpo", "rm") def launch(): + from .accelerator.helper import get_device_count + from .utils.env import find_available_port, is_env_enabled, use_kt, use_ray + from .utils.logging import get_logger + + logger = get_logger(__name__) + + # NOTE: + # `llamafactory-cli ...` enters here first. + # We may re-launch via `torchrun` for distributed training. In that case we must + # forward `` as argv[1] to the re-executed script, otherwise the script + # will misinterpret the first user argument (e.g. yaml config) as the command. 
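+    # For example, `llamafactory-cli sft examples/v1/train_full/train_full_fsdp2.yaml` on a
+    # multi-device host is re-executed roughly as
+    # `torchrun --nproc-per-node <N> launcher.py sft examples/v1/train_full/train_full_fsdp2.yaml`;
+    # single-device runs without FORCE_TORCHRUN skip torchrun and call the trainer directly.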
command = sys.argv.pop(1) if len(sys.argv) > 1 else "help" - if command == "sft": # train command will fallback to sft command - from .trainers.sft_trainer import run_sft + if command in _DIST_TRAIN_COMMANDS and ( + is_env_enabled("FORCE_TORCHRUN") or (get_device_count() > 1 and not use_ray() and not use_kt()) + ): + nnodes = os.getenv("NNODES", "1") + node_rank = os.getenv("NODE_RANK", "0") + nproc_per_node = os.getenv("NPROC_PER_NODE", str(get_device_count())) + master_addr = os.getenv("MASTER_ADDR", "127.0.0.1") + master_port = os.getenv("MASTER_PORT", str(find_available_port())) + logger.info_rank0(f"Initializing {nproc_per_node} distributed tasks at: {master_addr}:{master_port}") + if int(nnodes) > 1: + logger.info_rank0(f"Multi-node training enabled: num nodes: {nnodes}, node rank: {node_rank}") - run_sft() + # elastic launch support + max_restarts = os.getenv("MAX_RESTARTS", "0") + rdzv_id = os.getenv("RDZV_ID") + min_nnodes = os.getenv("MIN_NNODES") + max_nnodes = os.getenv("MAX_NNODES") + + env = deepcopy(os.environ) + if is_env_enabled("OPTIM_TORCH", "1"): + # optimize DDP, see https://zhuanlan.zhihu.com/p/671834539 + env["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" + env["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1" + + torchrun_args = [ + "torchrun", + "--nproc-per-node", + nproc_per_node, + ] + if rdzv_id is not None: + # launch elastic job with fault tolerant support when possible + # see also https://docs.pytorch.org/docs/stable/elastic/train_script.html + rdzv_nnodes = nnodes + # elastic number of nodes if MIN_NNODES and MAX_NNODES are set + if min_nnodes is not None and max_nnodes is not None: + rdzv_nnodes = f"{min_nnodes}:{max_nnodes}" + + torchrun_args.extend( + [ + "--nnodes", + rdzv_nnodes, + "--rdzv-id", + rdzv_id, + "--rdzv-backend", + "c10d", + "--rdzv-endpoint", + f"{master_addr}:{master_port}", + "--max-restarts", + max_restarts, + ] + ) + else: + # NOTE: DO NOT USE shell=True to avoid security risk + torchrun_args.extend( + [ + "--nnodes", + nnodes, + "--node_rank", + node_rank, + "--master_addr", + master_addr, + "--master_port", + master_port, + ] + ) + + script_args = [__file__, command] + sys.argv[1:] + process = subprocess.run( + torchrun_args + script_args, + env=env, + check=True, + ) + + sys.exit(process.returncode) elif command == "chat": from .samplers.cli_sampler import run_chat @@ -55,17 +126,54 @@ def launch(): run_chat() elif command == "env": - print_env() + raise NotImplementedError("Environment information is not implemented yet.") elif command == "version": - print(WELCOME) + raise NotImplementedError("Version information is not implemented yet.") elif command == "help": print(USAGE) + elif command in _DIST_TRAIN_COMMANDS: + # Single GPU training without torchrun + if command in ("train", "sft"): + from llamafactory.v1.trainers.sft_trainer import run_sft + + run_sft() + elif command == "dpo": + raise NotImplementedError("DPO trainer is not implemented yet.") + elif command == "rm": + raise NotImplementedError("RM trainer is not implemented yet.") + else: print(f"Unknown command: {command}.\n{USAGE}") +def main(): + # sys.argv[1] contains the command (sft/dpo/rm/train), sys.argv[2:] contains the rest args + command = sys.argv[1] if len(sys.argv) > 1 else "sft" + + # Routing needs the sub-command, but downstream trainers usually expect argv without it. + if command in _DIST_TRAIN_COMMANDS: + sys.argv.pop(1) + else: + # Backward-compat: if someone runs `torchrun launcher.py config.yaml`, + # treat it as sft by default. 
+ if len(sys.argv) > 1 and sys.argv[1].endswith((".yaml", ".yml")): + command = "sft" + if command in ("train", "sft"): + from llamafactory.v1.trainers.sft_trainer import run_sft + + run_sft() + elif command == "dpo": + # from llamafactory.v1.trainers.dpo_trainer import run_dpo + # run_dpo() + raise NotImplementedError("DPO trainer is not implemented yet.") + elif command == "rm": + # from llamafactory.v1.trainers.rm_trainer import run_rm + # run_rm() + raise NotImplementedError("RM trainer is not implemented yet.") + + if __name__ == "__main__": - pass + main() diff --git a/src/llamafactory/v1/plugins/trainer_plugins/distributed/accelerate.py b/src/llamafactory/v1/plugins/trainer_plugins/distributed/accelerate.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/src/llamafactory/v1/plugins/trainer_plugins/distributed/fsdp2.py b/src/llamafactory/v1/plugins/trainer_plugins/distributed/fsdp2.py new file mode 100644 index 000000000..b40265ce3 --- /dev/null +++ b/src/llamafactory/v1/plugins/trainer_plugins/distributed/fsdp2.py @@ -0,0 +1,399 @@ +# Copyright 2025 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import os + +import torch +import torch.nn as nn +from torch.distributed.checkpoint.state_dict import StateDictOptions, get_model_state_dict, set_model_state_dict +from torch.distributed.fsdp import ( + CPUOffloadPolicy, + MixedPrecisionPolicy, + fully_shard, +) +from transformers import PreTrainedModel + +from ....accelerator.helper import get_current_accelerator +from ....accelerator.interface import DistributedInterface +from ....utils.logging import get_logger + + +logger = get_logger(__name__) + + +def get_transformer_layer_cls(model: PreTrainedModel) -> type[nn.Module] | None: + no_split_modules = getattr(model, "_no_split_modules", None) + if no_split_modules: + if isinstance(no_split_modules, (list, tuple)): + for name, module in model.named_modules(): + for cls_name in no_split_modules: + if module.__class__.__name__ == cls_name: + return module.__class__ + if hasattr(model, "model") and hasattr(model.model, "layers"): + return type(model.model.layers[0]) + if hasattr(model, "layers"): + return type(model.layers[0]) + + return None + + +class FSDP2Engine: + def __init__(self, dist_config: dict): + self.dist_interface = DistributedInterface() + self.rank = self.dist_interface.get_rank() + self.local_rank = self.dist_interface.get_local_rank() + self.world_size = self.dist_interface.get_world_size() + self.mixed_precision = dist_config.get("mixed_precision", "bf16") + self.reshard_after_forward = dist_config.get("reshard_after_forward", True) + self.offload_params = dist_config.get("offload_params", False) + self.pin_memory = dist_config.get("pin_memory", True) + self.dcp_path = dist_config.get("dcp_path", None) + self.device_mesh = self.dist_interface.data_device_mesh + + if self.device_mesh is None: + logger.warning( + "Device Mesh not found in DistributedInterface. FSDP2 might fail if not running in distributed mode." 
+ ) + + if self.device_mesh is not None: + try: + self.fsdp_mesh = self.device_mesh["dp"] + except Exception: + self.fsdp_mesh = self.device_mesh + + logger.info(f"Using Device Mesh: {self.fsdp_mesh}") + else: + self.fsdp_mesh = None + + def get_mp_policy(self) -> MixedPrecisionPolicy: + if self.mixed_precision == "bf16": + param_dtype = torch.bfloat16 + reduce_dtype = torch.float32 + elif self.mixed_precision == "fp16": + param_dtype = torch.float16 + reduce_dtype = torch.float32 + else: + param_dtype = torch.float32 + reduce_dtype = torch.float32 + + return MixedPrecisionPolicy( + param_dtype=param_dtype, + reduce_dtype=reduce_dtype, + cast_forward_inputs=True, + ) + + def prepare_model(self, model: PreTrainedModel) -> PreTrainedModel: + if self.fsdp_mesh is None: + logger.warning("No FSDP Mesh available, skipping FSDP wrapping.") + return model + + mp_policy = self.get_mp_policy() + layer_cls = get_transformer_layer_cls(model) + + if layer_cls is None: + logger.warning( + "Could not identify Transformer Layer class, applying FSDP to the whole model structure only." + ) + transformer_layer_cls_to_wrap = set() + else: + logger.info(f"Applying per-layer FSDP to {layer_cls.__name__}") + transformer_layer_cls_to_wrap = {layer_cls} + + for name, module in model.named_modules(): + should_wrap = False + + if type(module) in transformer_layer_cls_to_wrap: + should_wrap = True + elif isinstance(module, nn.Embedding): + if not getattr(model.config, "tie_word_embeddings", True): + should_wrap = True + + if should_wrap: + fully_shard( + module, + mesh=self.fsdp_mesh, + reshard_after_forward=self.reshard_after_forward, + mp_policy=mp_policy, + offload_policy=CPUOffloadPolicy(pin_memory=self.pin_memory) if self.offload_params else None, + ) + + use_gradient_checkpointing = True # Could be configurable + if use_gradient_checkpointing: + if self.rank == 0: + logger.info("Enabling gradient checkpointing (transformers native)...") + + model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False}) + + if hasattr(model, "enable_input_require_grads"): + model.enable_input_require_grads() + else: + + def make_inputs_require_grad(module, input, output): + output.requires_grad_(True) + + model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) + + fully_shard( + model, + mesh=self.fsdp_mesh, + reshard_after_forward=self.reshard_after_forward, + mp_policy=mp_policy, + offload_policy=CPUOffloadPolicy(pin_memory=self.pin_memory) if self.offload_params else None, + ) + + return model + + @torch.no_grad() + def materialize_and_load(self, model: PreTrainedModel, hf_model_path: str, dcp_path: str = None): + if self.rank == 0: + logger.info("Materializing sharded model params...") + + device = get_current_accelerator() + model.to_empty(device=device) + + if dcp_path and os.path.exists(dcp_path): + if self.rank == 0: + logger.info(f"DCP path found at {dcp_path}. 
Using efficient Sharded Loading (DCP Load).") + self._load_from_dcp(model, dcp_path) + else: + if self.rank == 0: + if dcp_path: + logger.warning(f"DCP path {dcp_path} not found.") + logger.info("Using HF Meta Loading (Chunk Load).") + self._load_weights_from_hf_checkpoint(model, hf_model_path) + + return model + + def shard_model(self, model: PreTrainedModel) -> PreTrainedModel: + if model.device.type == "meta": + model = self.prepare_model(model) + model = self.materialize_and_load(model, hf_model_path=model.config.name_or_path, dcp_path=self.dcp_path) + else: + model = self.prepare_model(model) + return model + + def _load_from_dcp(self, model: PreTrainedModel, dcp_path: str): + import torch.distributed.checkpoint as dcp + + try: + if self.rank == 0: + logger.info(f"Loading distributed checkpoint from {dcp_path} ...") + + options = StateDictOptions(full_state_dict=False, cpu_offload=True) + local_state_dict = get_model_state_dict(model, options=options) + dcp.load(state_dict=local_state_dict, checkpoint_id=dcp_path) + set_model_state_dict(model, local_state_dict, options=options) + + if self.rank == 0: + logger.info("DCP weights loaded successfully.") + + except Exception as e: + logger.error(f"Failed to load from DCP: {e}") + raise e + + def _load_weights_from_hf_checkpoint(self, model, hf_model_path): + import glob + import json + + hf_model_path = self._resolve_hf_checkpoint_dir(hf_model_path) + + if self.rank == 0: + logger.info(f"Loading weights from {hf_model_path} ...") + + index_file = os.path.join(hf_model_path, "model.safetensors.index.json") + is_safetensors = True + checkpoint_files = [] + + if os.path.exists(index_file): + with open(index_file) as f: + index = json.load(f) + checkpoint_files = sorted(set(index["weight_map"].values())) + checkpoint_files = [os.path.join(hf_model_path, f) for f in checkpoint_files] + elif os.path.exists(os.path.join(hf_model_path, "model.safetensors")): + checkpoint_files = [os.path.join(hf_model_path, "model.safetensors")] + else: + is_safetensors = False + index_file = os.path.join(hf_model_path, "pytorch_model.bin.index.json") + if os.path.exists(index_file): + with open(index_file) as f: + index = json.load(f) + checkpoint_files = sorted(set(index["weight_map"].values())) + checkpoint_files = [os.path.join(hf_model_path, f) for f in checkpoint_files] + elif os.path.exists(os.path.join(hf_model_path, "pytorch_model.bin")): + checkpoint_files = [os.path.join(hf_model_path, "pytorch_model.bin")] + else: + checkpoint_files = sorted(glob.glob(os.path.join(hf_model_path, "*.safetensors"))) + if checkpoint_files: + is_safetensors = True + else: + checkpoint_files = sorted(glob.glob(os.path.join(hf_model_path, "*.bin"))) + + if not checkpoint_files: + raise ValueError(f"No checkpoint files found in {hf_model_path}") + + param_map = dict(model.named_parameters()) + total_files = len(checkpoint_files) + + for i, ckpt_file in enumerate(checkpoint_files): + if self.rank == 0: + logger.info(f"[{i + 1}/{total_files}] Loading {os.path.basename(ckpt_file)} ...") + + if is_safetensors: + from safetensors import safe_open + + with safe_open(ckpt_file, framework="pt", device="cpu") as f: + for key in f.keys(): + if key in param_map: + tensor = f.get_tensor(key) + self._copy_weights(param_map[key], tensor) + else: + state_dict = torch.load(ckpt_file, map_location="cpu") + for key, tensor in state_dict.items(): + if key in param_map: + self._copy_weights(param_map[key], tensor) + del state_dict + gc.collect() + + def _resolve_hf_checkpoint_dir(self, 
hf_model_path: str) -> str: + """Resolve a HF model identifier or local path to a local directory containing checkpoint files. + + - If `hf_model_path` is an existing directory, return it. + - If it's a file path, return its parent directory. + - Otherwise treat it as a Hugging Face Hub repo id and download/resolve to the local cache dir. + """ + if not hf_model_path: + return hf_model_path + + # Local directory or file path. + if os.path.isdir(hf_model_path): + return hf_model_path + if os.path.isfile(hf_model_path): + return os.path.dirname(hf_model_path) + + # HuggingFace Hub repo id: snapshot to local cache so we can glob/index files. + try: + from huggingface_hub import snapshot_download + except ImportError as e: + raise ValueError( + f"hf_model_path='{hf_model_path}' does not exist locally and huggingface_hub is not available " + f"to download it. Please provide a local model directory or install huggingface_hub. Error: {e}" + ) from e + + revision = os.getenv("HF_REVISION") + offline = os.getenv("HF_HUB_OFFLINE") == "1" or os.getenv("TRANSFORMERS_OFFLINE") == "1" + + # In distributed runs, let rank0 download first to avoid N-way concurrent downloads. + if torch.distributed.is_available() and torch.distributed.is_initialized(): + if self.rank == 0: + local_dir = snapshot_download( + repo_id=hf_model_path, + revision=revision, + local_files_only=offline, + allow_patterns=[ + "*.safetensors", + "*.bin", + "*.index.json", + "model.safetensors", + "model.safetensors.index.json", + "pytorch_model.bin", + "pytorch_model.bin.index.json", + "config.json", + ], + ) + logger.info(f"Resolved HF repo id '{hf_model_path}' to local dir: {local_dir}") + torch.distributed.barrier() + if self.rank != 0: + local_dir = snapshot_download( + repo_id=hf_model_path, + revision=revision, + local_files_only=True, + allow_patterns=[ + "*.safetensors", + "*.bin", + "*.index.json", + "model.safetensors", + "model.safetensors.index.json", + "pytorch_model.bin", + "pytorch_model.bin.index.json", + "config.json", + ], + ) + return local_dir + + local_dir = snapshot_download( + repo_id=hf_model_path, + revision=revision, + local_files_only=offline, + allow_patterns=[ + "*.safetensors", + "*.bin", + "*.index.json", + "model.safetensors", + "model.safetensors.index.json", + "pytorch_model.bin", + "pytorch_model.bin.index.json", + "config.json", + ], + ) + if self.rank == 0: + logger.info(f"Resolved HF repo id '{hf_model_path}' to local dir: {local_dir}") + return local_dir + + def _copy_weights(self, param, loaded_tensor): + from torch.distributed._tensor import DTensor, Shard + + if loaded_tensor.dtype != param.dtype: + loaded_tensor = loaded_tensor.to(param.dtype) + + if isinstance(param, DTensor): + shard_placement = None + mesh_dim = -1 + + for i, placement in enumerate(param.placements): + if isinstance(placement, Shard): + shard_placement = placement + mesh_dim = i + break + + local_tensor = param.to_local() + + if shard_placement is None: + local_tensor.copy_(loaded_tensor) + else: + dim = shard_placement.dim + mesh = param.device_mesh + my_coordinate = mesh.get_coordinate() + if my_coordinate is None: + return + + rank_in_dim = my_coordinate[mesh_dim] + world_size_in_dim = mesh.size(mesh_dim) + + full_size = param.shape[dim] + chunk_size = (full_size + world_size_in_dim - 1) // world_size_in_dim + + start = rank_in_dim * chunk_size + end = min(start + chunk_size, full_size) + + if start >= full_size: + return + + sliced_tensor = loaded_tensor.narrow(dim, start, end - start) + + slices = [slice(None)] * 
local_tensor.ndim + slices[dim] = slice(0, sliced_tensor.shape[dim]) + local_tensor[tuple(slices)].copy_(sliced_tensor) + else: + param.data.copy_(loaded_tensor) diff --git a/src/llamafactory/v1/plugins/trainer_plugins/distributed/hub.py b/src/llamafactory/v1/plugins/trainer_plugins/distributed/hub.py new file mode 100644 index 000000000..096cae14e --- /dev/null +++ b/src/llamafactory/v1/plugins/trainer_plugins/distributed/hub.py @@ -0,0 +1,34 @@ +# Copyright 2025 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ....config.arg_utils import PluginConfig +from ....utils.plugin import BasePlugin +from ....utils.types import HFModel + + +class DistributedPlugin(BasePlugin): + def __call__(self, model: HFModel, dist_config: PluginConfig, **kwargs) -> HFModel: + return super().__call__(model, dist_config, **kwargs) + + +@DistributedPlugin("fsdp2").register() +def shard_model_fsdp2(model: HFModel, dist_config: PluginConfig) -> HFModel: + from .fsdp2 import FSDP2Engine + + return FSDP2Engine(dist_config).shard_model(model) + + +@DistributedPlugin("deepspeed").register() +def shard_model_deepspeed(model: HFModel, dist_config: PluginConfig) -> HFModel: + return model diff --git a/src/llamafactory/v1/utils/env.py b/src/llamafactory/v1/utils/env.py index 683cf0357..fc17d7cc9 100644 --- a/src/llamafactory/v1/utils/env.py +++ b/src/llamafactory/v1/utils/env.py @@ -28,3 +28,11 @@ def find_available_port() -> int: def is_env_enabled(env_var: str, default: str = "0") -> bool: """Check if the environment variable is enabled.""" return os.getenv(env_var, default).lower() in ["true", "yes", "on", "t", "y", "1"] + + +def use_ray() -> bool: + return False + + +def use_kt() -> bool: + return False diff --git a/tests_v1/plugins/model_plugins/test_kernel_plugin.py b/tests_v1/plugins/model_plugins/test_kernel_plugin.py index 334c89502..a4207fed2 100644 --- a/tests_v1/plugins/model_plugins/test_kernel_plugin.py +++ b/tests_v1/plugins/model_plugins/test_kernel_plugin.py @@ -15,7 +15,6 @@ import sys from unittest.mock import MagicMock, patch -import pytest import torch.multiprocessing as mp from transformers import AutoModelForCausalLM diff --git a/tests_v1/trainers/test_fsdp2_sft_trainer.py b/tests_v1/trainers/test_fsdp2_sft_trainer.py new file mode 100644 index 000000000..875f55785 --- /dev/null +++ b/tests_v1/trainers/test_fsdp2_sft_trainer.py @@ -0,0 +1,89 @@ +# Copyright 2025 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
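+# This test shells out to `python -m llamafactory.cli sft <config>` with USE_V1=1 and
+# FORCE_TORCHRUN=1, exercising the torchrun re-launch path in launcher.py and the FSDP2
+# distributed plugin end to end.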
+import os +import subprocess +import sys +from pathlib import Path + +import pytest + + +@pytest.mark.xfail(reason="CI machines may OOM when heavily loaded.") +@pytest.mark.runs_on(["cuda", "npu"]) +def test_fsdp2_sft_trainer(tmp_path: Path): + """Test FSDP2 SFT trainer by simulating `llamafactory-cli sft config.yaml` behavior.""" + config_yaml = """\ +model: Qwen/Qwen3-0.6B +trust_remote_code: true +model_class: llm + +template: qwen3_nothink + +kernel_config: + name: auto + include_kernels: auto + +quant_config: null + +dist_config: + name: fsdp2 + dcp_path: null + +init_config: + name: init_on_meta + +### data +train_dataset: data/v1_sft_demo.yaml + +### training +output_dir: {output_dir} +micro_batch_size: 1 +global_batch_size: 1 +cutoff_len: 2048 +learning_rate: 1.0e-4 +bf16: false +max_steps: 1 + +### sample +sample_backend: hf +max_new_tokens: 128 +""" + # Create output directory + output_dir = tmp_path / "outputs" + output_dir.mkdir(parents=True, exist_ok=True) + config_file = tmp_path / "config.yaml" + config_file.write_text(config_yaml.format(output_dir=str(output_dir))) + + # Set up environment variables + env = os.environ.copy() + env["USE_V1"] = "1" # Use v1 launcher + env["FORCE_TORCHRUN"] = "1" # Force distributed training via torchrun + + # Run the CLI command via subprocess + # This simulates: llamafactory-cli sft config.yaml + result = subprocess.run( + [sys.executable, "-m", "llamafactory.cli", "sft", str(config_file)], + env=env, + capture_output=True, + cwd=str(Path(__file__).parent.parent.parent), # LLaMA-Factory root + ) + + # Decode output with error handling (progress bars may contain non-UTF-8 bytes) + stderr = result.stderr.decode("utf-8", errors="replace") + + # Check the result + assert result.returncode == 0, f"Training failed with return code {result.returncode}\nSTDERR: {stderr}" + + # Verify output files exist (optional - adjust based on what run_sft produces) + # assert (output_dir / "some_expected_file").exists()
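As a quick sanity check for `scripts/hf2dcp.py`, the DCP checkpoint can be loaded back into a plain state dict and compared against the original HF weights. The sketch below is not part of this PR: the paths are placeholders, and the single-process `dcp.load` call assumes a recent PyTorch that falls back to non-distributed mode when no process group is initialized (the same behavior `dcp.save` relies on in the script above).

```python
# Round-trip check for scripts/hf2dcp.py (sketch; hf_path/dcp_path are placeholders).
import torch
import torch.distributed.checkpoint as dcp
from transformers import AutoModelForCausalLM

hf_path = "Qwen/Qwen3-0.6B"   # same model that was converted
dcp_path = "/path/to/dcp"     # directory written by `python scripts/hf2dcp.py convert ...`

# Reload the reference weights exactly as hf2dcp.py does.
model = AutoModelForCausalLM.from_pretrained(hf_path, device_map="cpu", torch_dtype=torch.bfloat16)
reference = model.state_dict()

# dcp.load() fills the provided state dict in place, so start from empty tensors of the right shapes.
restored = {key: torch.empty_like(value) for key, value in reference.items()}
dcp.load(restored, checkpoint_id=dcp_path)

for key, value in reference.items():
    assert torch.equal(restored[key], value), f"Mismatch in {key}"
print("DCP checkpoint matches the original HF weights.")
```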