diff --git a/src/llamafactory/extras/misc.py b/src/llamafactory/extras/misc.py
index 36c140287..47de14bb1 100644
--- a/src/llamafactory/extras/misc.py
+++ b/src/llamafactory/extras/misc.py
@@ -157,6 +157,33 @@ def get_current_device() -> "torch.device":
     return torch.device(device)
 
 
+def get_device_name() -> str:
+    r"""Get the name of the available device."""
+    if is_torch_xpu_available():
+        device = "xpu"
+    elif is_torch_npu_available():
+        device = "npu"
+    elif is_torch_mps_available():
+        device = "mps"
+    elif is_torch_cuda_available():
+        device = "gpu"
+    else:
+        device = "cpu"
+
+    return device
+
+
+def get_torch_device():
+    r"""Get the torch device namespace for the available device."""
+    device_name = get_device_name()
+    device_name = "cuda" if device_name == "gpu" else device_name
+    try:
+        return getattr(torch, device_name)
+    except AttributeError:
+        logger.warning_rank0(f"Device namespace '{device_name}' not found in torch, falling back to torch.cuda.")
+        return torch.cuda
+
+
 def get_device_count() -> int:
     r"""Get the number of available devices."""
     if is_torch_xpu_available():
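Illustrative sketch (not part of the patch): get_torch_device() lets call sites pin a worker to its accelerator without hard-coding torch.cuda, which is how the patch uses it later in tuner.py. A minimal usage sketch, assuming LOCAL_RANK is already set and a CUDA/NPU/XPU backend is installed:

    import os

    from llamafactory.extras.misc import get_device_name, get_torch_device

    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    if get_device_name() in ("gpu", "npu", "xpu"):
        # resolves to torch.cuda, torch.npu, or torch.xpu and pins this process to its device
        get_torch_device().set_device(local_rank)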
Default is None."}, ) + master_addr: str | None = field( + default=None, + metadata={"help": "The master address for init_process_group"}, + ) + master_port: str | None = field( + default=None, + metadata={"help": "The master port for init_process_group"}, + ) def __post_init__(self): self.use_ray = use_ray() - if isinstance(self.resources_per_worker, str) and self.resources_per_worker.startswith("{"): - self.resources_per_worker = _convert_str_dict(json.loads(self.resources_per_worker)) if isinstance(self.ray_init_kwargs, str) and self.ray_init_kwargs.startswith("{"): self.ray_init_kwargs = _convert_str_dict(json.loads(self.ray_init_kwargs)) - if self.ray_storage_filesystem is not None: - if self.ray_storage_filesystem not in ["s3", "gs", "gcs"]: - raise ValueError( - f"ray_storage_filesystem must be one of ['s3', 'gs', 'gcs'], got {self.ray_storage_filesystem}." - ) - - import pyarrow.fs as fs - - if self.ray_storage_filesystem == "s3": - self.ray_storage_filesystem = fs.S3FileSystem() - elif self.ray_storage_filesystem == "gs" or self.ray_storage_filesystem == "gcs": - self.ray_storage_filesystem = fs.GcsFileSystem() - @dataclass class Fp8Arguments: diff --git a/src/llamafactory/train/trainer_utils.py b/src/llamafactory/train/trainer_utils.py index 9d0c8789d..e58316092 100644 --- a/src/llamafactory/train/trainer_utils.py +++ b/src/llamafactory/train/trainer_utils.py @@ -20,7 +20,6 @@ import json import os from collections.abc import Callable, Mapping -from pathlib import Path from typing import TYPE_CHECKING, Any, Optional, Union import torch @@ -34,6 +33,7 @@ from typing_extensions import override from ..extras import logging from ..extras.constants import IGNORE_INDEX, SWANLAB_CONFIG +from ..extras.misc import get_device_name from ..extras.packages import is_apollo_available, is_galore_available, is_ray_available from ..hparams import FinetuningArguments, ModelArguments from ..model import find_all_linear_modules, load_model, load_tokenizer, load_valuehead_params @@ -49,15 +49,15 @@ if is_apollo_available(): if is_ray_available(): import ray - from ray.train import RunConfig, ScalingConfig - from ray.train.torch import TorchTrainer + from ray.util.placement_group import PlacementGroup, placement_group + from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy if TYPE_CHECKING: from transformers import PreTrainedModel, TrainerCallback, TrainerState from trl import AutoModelForCausalLMWithValueHead - from ..hparams import DataArguments, RayArguments, TrainingArguments + from ..hparams import DataArguments, TrainingArguments logger = logging.get_logger(__name__) @@ -807,36 +807,88 @@ def get_swanlab_callback(finetuning_args: "FinetuningArguments") -> "TrainerCall return swanlab_callback -def get_ray_trainer( - training_function: Callable, - train_loop_config: dict[str, Any], - ray_args: "RayArguments", -) -> "TorchTrainer": - if not ray_args.use_ray: - raise ValueError("Ray was not enabled. 
diff --git a/src/llamafactory/train/trainer_utils.py b/src/llamafactory/train/trainer_utils.py
index 9d0c8789d..e58316092 100644
--- a/src/llamafactory/train/trainer_utils.py
+++ b/src/llamafactory/train/trainer_utils.py
@@ -20,7 +20,6 @@
 import json
 import os
 from collections.abc import Callable, Mapping
-from pathlib import Path
 from typing import TYPE_CHECKING, Any, Optional, Union
 
 import torch
@@ -34,6 +33,7 @@ from typing_extensions import override
 
 from ..extras import logging
 from ..extras.constants import IGNORE_INDEX, SWANLAB_CONFIG
+from ..extras.misc import get_device_name
 from ..extras.packages import is_apollo_available, is_galore_available, is_ray_available
 from ..hparams import FinetuningArguments, ModelArguments
 from ..model import find_all_linear_modules, load_model, load_tokenizer, load_valuehead_params
@@ -49,15 +49,15 @@ if is_apollo_available():
 
 if is_ray_available():
     import ray
-    from ray.train import RunConfig, ScalingConfig
-    from ray.train.torch import TorchTrainer
+    from ray.util.placement_group import PlacementGroup, placement_group
+    from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
 
 
 if TYPE_CHECKING:
     from transformers import PreTrainedModel, TrainerCallback, TrainerState
     from trl import AutoModelForCausalLMWithValueHead
 
-    from ..hparams import DataArguments, RayArguments, TrainingArguments
+    from ..hparams import DataArguments, TrainingArguments
 
 
 logger = logging.get_logger(__name__)
@@ -807,36 +807,88 @@ def get_swanlab_callback(finetuning_args: "FinetuningArguments") -> "TrainerCall
     return swanlab_callback
 
 
-def get_ray_trainer(
-    training_function: Callable,
-    train_loop_config: dict[str, Any],
-    ray_args: "RayArguments",
-) -> "TorchTrainer":
-    if not ray_args.use_ray:
-        raise ValueError("Ray was not enabled. Please set `USE_RAY=1` to enable ray.")
+def get_placement_group(num_workers: int) -> tuple["PlacementGroup", dict[str, int]]:
+    r"""Get the Ray placement group for distributed training."""
+    bundle = {"CPU": 10}
+    device_name = get_device_name().upper()
+    if device_name != "CPU":
+        bundle[device_name] = 1
+    bundles = [bundle for _ in range(num_workers)]
+    pg = placement_group(bundles, strategy="PACK")
 
-    if ray_args.ray_init_kwargs is not None:
-        ray.init(**ray_args.ray_init_kwargs)
+    return pg, bundle
 
-    if ray_args.ray_storage_filesystem is not None:
-        # this means we are using s3/gcs
-        storage_path = ray_args.ray_storage_path
-    else:
-        storage_path = Path(ray_args.ray_storage_path).absolute().as_posix()
 
-    trainer = TorchTrainer(
-        training_function,
-        train_loop_config=train_loop_config,
-        scaling_config=ScalingConfig(
-            num_workers=ray_args.ray_num_workers,
-            resources_per_worker=ray_args.resources_per_worker,
-            placement_strategy=ray_args.placement_strategy,
-            use_gpu=True,
+def get_ray_remote_config_for_worker(
+    placement_group: "PlacementGroup",
+    bundle_idx: int,
+    rank: int,
+    world_size: int,
+    master_addr: str,
+    master_port: str,
+    env: Optional[dict[str, str]] = None,
+) -> dict[str, Any]:
+    r"""Get the remote config for a Ray worker."""
+    env_vars = {
+        "RANK": str(rank),
+        "WORLD_SIZE": str(world_size),
+        "MASTER_ADDR": master_addr,
+        "MASTER_PORT": master_port,
+        "TORCHELASTIC_USE_AGENT_STORE": "False",
+    }
+    env = {**(env or {}), **env_vars}  # copy to avoid mutating the caller's dict
+
+    remote_config = {
+        "scheduling_strategy": PlacementGroupSchedulingStrategy(
+            placement_group=placement_group,
+            placement_group_bundle_index=bundle_idx,
         ),
-        run_config=RunConfig(
-            name=ray_args.ray_run_name,
-            storage_filesystem=ray_args.ray_storage_filesystem,
-            storage_path=storage_path,
-        ),
-    )
-    return trainer
+        "runtime_env": {"env_vars": env},
+        "num_cpus": 10,
+    }
+
+    device_name = get_device_name()
+    if device_name == "gpu":
+        remote_config["num_gpus"] = 1
+    elif device_name == "npu":
+        remote_config["resources"] = {"NPU": 1}
+
+    return remote_config
+
+
+def get_ray_head_node_ip() -> str:
+    r"""Get the IP address of the Ray head node."""
+    head_ip = next(node["NodeManagerAddress"] for node in ray.nodes() if node.get("IsHead", False))
+    return head_ip
+
+
+def sort_placement_group_by_node_ip(placement_group: "PlacementGroup", master_addr: str = None) -> list[int]:
+    r"""Sort the placement group bundles by their node IP addresses."""
+
+    @ray.remote
+    def _get_node_ip():
+        return ray.util.get_node_ip_address().strip("[]")
+
+    tasks = []
+    for bundle_idx in range(placement_group.bundle_count):
+        task = _get_node_ip.options(
+            scheduling_strategy=PlacementGroupSchedulingStrategy(
+                placement_group=placement_group,
+                placement_group_bundle_index=bundle_idx,
+            ),
+        ).remote()
+        tasks.append(task)
+
+    bundle_ips = ray.get(tasks)
+    bundle_node_ip_list = list(enumerate(bundle_ips))
+
+    sorted_bundle_node_ip_list = sorted(bundle_node_ip_list, key=lambda x: x[1])
+    sorted_bundle_indices = [item[0] for item in sorted_bundle_node_ip_list]
+
+    if master_addr is not None:
+        preferred_indices = [idx for idx, ip in bundle_node_ip_list if ip == master_addr]
+        if preferred_indices:
+            remaining = [i for i in sorted_bundle_indices if i not in preferred_indices]
+            sorted_bundle_indices = preferred_indices + remaining
+
+    return sorted_bundle_indices
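Illustrative sketch (not part of the patch): the ordering contract of sort_placement_group_by_node_ip is easiest to see on plain data. The snippet below replays only the index-sorting step on a hypothetical list of bundle IPs (no Ray cluster required); bundles on master_addr come first so rank 0 lands on the master node, and the remaining bundles follow in IP-sorted order:

    bundle_ips = ["10.0.0.3", "10.0.0.2", "10.0.0.2", "10.0.0.1"]  # hypothetical ray.get(tasks) result
    master_addr = "10.0.0.2"

    bundle_node_ip_list = list(enumerate(bundle_ips))
    sorted_indices = [idx for idx, _ in sorted(bundle_node_ip_list, key=lambda x: x[1])]
    preferred = [idx for idx, ip in bundle_node_ip_list if ip == master_addr]
    if preferred:
        sorted_indices = preferred + [i for i in sorted_indices if i not in preferred]

    print(sorted_indices)  # [1, 2, 3, 0]: the two master-node bundles first, then the rest by IP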
diff --git a/src/llamafactory/train/tuner.py b/src/llamafactory/train/tuner.py
index 90e284110..5d9a85f87 100644
--- a/src/llamafactory/train/tuner.py
+++ b/src/llamafactory/train/tuner.py
@@ -23,9 +23,9 @@ from transformers import EarlyStoppingCallback, PreTrainedModel
 from ..data import get_template_and_fix_tokenizer
 from ..extras import logging
 from ..extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
-from ..extras.misc import infer_optim_dtype
+from ..extras.misc import find_available_port, get_device_name, get_torch_device, infer_optim_dtype
 from ..extras.packages import is_mcore_adapter_available, is_ray_available
-from ..hparams import get_infer_args, get_ray_args, get_train_args, read_args
+from ..hparams import RayArguments, get_infer_args, get_ray_args, get_train_args, read_args
 from ..model import load_model, load_tokenizer
 from .callbacks import LogCallback, PissaConvertCallback, ReporterCallback
 from .dpo import run_dpo
@@ -34,12 +34,17 @@ from .ppo import run_ppo
 from .pt import run_pt
 from .rm import run_rm
 from .sft import run_sft
-from .trainer_utils import get_ray_trainer, get_swanlab_callback
+from .trainer_utils import (
+    get_placement_group,
+    get_ray_head_node_ip,
+    get_ray_remote_config_for_worker,
+    get_swanlab_callback,
+    sort_placement_group_by_node_ip,
+)
 
 
 if is_ray_available():
     import ray
-    from ray.train.huggingface.transformers import RayTrainReportCallback
 
 
 if TYPE_CHECKING:
@@ -115,13 +120,7 @@ def run_exp(args: Optional[dict[str, Any]] = None, callbacks: Optional[list["Tra
     ray_args = get_ray_args(args)
     callbacks = callbacks or []
     if ray_args.use_ray:
-        callbacks.append(RayTrainReportCallback())
-        trainer = get_ray_trainer(
-            training_function=_training_function,
-            train_loop_config={"args": args, "callbacks": callbacks},
-            ray_args=ray_args,
-        )
-        trainer.fit()
+        _ray_training_function(ray_args, config={"args": args, "callbacks": callbacks})
     else:
         _training_function(config={"args": args, "callbacks": callbacks})
 
@@ -212,3 +211,94 @@ def export_model(args: Optional[dict[str, Any]] = None) -> None:
         with open(ollama_modelfile, "w", encoding="utf-8") as f:
             f.write(template.get_ollama_modelfile(tokenizer))
             logger.info_rank0(f"Ollama modelfile saved in {ollama_modelfile}")
+
+
+class Worker:
+    def __init__(self):
+        self._setup_env_visible_devices()
+
+        local_rank = os.environ.get("LOCAL_RANK", "0")
+        get_torch_device().set_device(int(local_rank))
+
+    def _setup_env_visible_devices(self) -> None:
+        RAY_NOSET_VISIBLE_DEVICES_LIST = [
+            "RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES",
+            "RAY_EXPERIMENTAL_NOSET_ASCEND_RT_VISIBLE_DEVICES",
+        ]
+        is_ray_noset_visible_devices = any(os.environ.get(env_var, None) for env_var in RAY_NOSET_VISIBLE_DEVICES_LIST)
+        if is_ray_noset_visible_devices:
+            device_name = get_device_name().upper()
+            local_rank = ray.get_runtime_context().get_accelerator_ids()[device_name][0]
+            os.environ["LOCAL_RANK"] = local_rank
+        else:
+            os.environ["LOCAL_RANK"] = "0"
+
+    def _training_function(self, config: dict[str, Any]) -> None:
+        _training_function(config)
+
+
+def _ray_training_function(ray_args: "RayArguments", config: dict[str, Any]) -> None:
+    num_workers = ray_args.ray_num_workers
+    master_addr = ray_args.master_addr
+    master_port = ray_args.master_port
+    logger.info(f"Using ray.remote mode with {num_workers} workers for distributed training.")
+
+    # initialize ray
+    if not ray.is_initialized():
+        if ray_args.ray_init_kwargs is not None:
+            ray.init(**ray_args.ray_init_kwargs)
+        else:
+            ray.init()
+
+    # verify resources
+    device_name = get_device_name().upper()
+    total_devices = int(ray.cluster_resources().get(device_name, 0))
+    if num_workers > total_devices:
+        raise ValueError(
+            f"The number of devices in the Ray cluster ({total_devices}) should be no less than num_workers ({num_workers})."
+        )
+
+    # verify master_addr
+    if master_addr is None:
+        master_addr = get_ray_head_node_ip()
+        logger.info(f"`master_addr` is not specified, using head node ip: {master_addr}.")
+    else:
+        nodes = [node["NodeManagerAddress"] for node in ray.nodes() if node["Alive"]]
+        if master_addr not in nodes:
+            raise ValueError(f"The `master_addr` ({master_addr}) is not in the Ray cluster or the node is not alive.")
+
+    # create placement group for resource management
+    pg, bundle = get_placement_group(total_devices)
+    ray.get(pg.ready())
+    logger.info(f"Created placement group with {total_devices} bundles: {bundle}")
+
+    # get sorted_bundle_indices
+    sorted_bundle_indices = sort_placement_group_by_node_ip(pg, master_addr)
+
+    # get master port
+    if master_port is None:
+        master_port = find_available_port()
+        logger.info(f"`master_port` is not specified, using available port: {master_port}.")
+    master_port = str(master_port)
+
+    # back up the current environment variables
+    current_env = dict(os.environ.items())
+
+    # launch workers
+    RayWorker = ray.remote(Worker)
+    workers = []
+    for rank in range(num_workers):
+        remote_config = get_ray_remote_config_for_worker(
+            placement_group=pg,
+            bundle_idx=sorted_bundle_indices[rank],
+            rank=rank,
+            world_size=num_workers,
+            master_addr=master_addr,
+            master_port=master_port,
+            env=current_env,
+        )
+        worker = RayWorker.options(**remote_config).remote()
+        workers.append(worker)
+
+    ray.get([worker._training_function.remote(config=config) for worker in workers])
+    ray.shutdown()
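For illustration only (not part of the patch): with these changes, the Ray code path no longer goes through ray.train's TorchTrainer; run_exp spawns one Worker actor per rank and each actor runs the regular training function, with torch.distributed rendezvous driven by the RANK/WORLD_SIZE/MASTER_ADDR/MASTER_PORT variables set per worker. A minimal driver sketch, assuming a reachable Ray cluster and that the usual model/dataset/finetuning arguments are filled in:

    import os

    from llamafactory.train.tuner import run_exp

    os.environ["USE_RAY"] = "1"  # route run_exp through the Ray launcher
    run_exp({
        # ... regular model / dataset / finetuning arguments go here ...
        "ray_num_workers": 2,    # one Ray worker actor (and one device) per rank
        "master_port": "29500",  # optional; a free port is chosen when omitted
    })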