[feature] support using ray.remote to start distributed training. (#10109)

Author: xvxuopop
Date: 2026-01-28 16:05:29 +08:00
Committed by: GitHub
Parent: 9640f79ae5
Commit: 762b480131
4 changed files with 221 additions and 80 deletions


@@ -157,6 +157,33 @@ def get_current_device() -> "torch.device":
     return torch.device(device)


+def get_device_name() -> str:
+    r"""Get the name of available devices."""
+    if is_torch_xpu_available():
+        device = "xpu"
+    elif is_torch_npu_available():
+        device = "npu"
+    elif is_torch_mps_available():
+        device = "mps"
+    elif is_torch_cuda_available():
+        device = "gpu"
+    else:
+        device = "cpu"
+
+    return device
+
+
+def get_torch_device():
+    r"""Get the torch device namespace for the available devices."""
+    device_name = get_device_name()
+    device_name = "cuda" if device_name == "gpu" else device_name
+    try:
+        return getattr(torch, device_name)
+    except AttributeError:
+        logger.warning_rank0(f"Device namespace '{device_name}' not found in torch; falling back to torch.cuda.")
+        return torch.cuda
+
+
 def get_device_count() -> int:
     r"""Get the number of available devices."""
     if is_torch_xpu_available():
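A minimal usage sketch of the two new helpers (hypothetical, not part of this commit; the import path is assumed from the package layout):

# Hypothetical sketch: pin the current process to its local accelerator,
# mirroring what the Ray worker added later in this commit does.
import os

from llamafactory.extras.misc import get_device_name, get_torch_device  # assumed import path

local_rank = int(os.environ.get("LOCAL_RANK", "0"))
if get_device_name() in ("gpu", "npu", "xpu"):
    # get_torch_device() resolves to torch.cuda, torch.npu or torch.xpu here
    get_torch_device().set_device(local_rank)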


@@ -14,7 +14,6 @@
 import json
 from dataclasses import dataclass, field
-from typing import Literal

 from transformers import Seq2SeqTrainingArguments
 from transformers.training_args import _convert_str_dict
@@ -40,56 +39,29 @@ else:
 class RayArguments:
     r"""Arguments pertaining to the Ray training."""

-    ray_run_name: str | None = field(
-        default=None,
-        metadata={"help": "The training results will be saved at `<ray_storage_path>/ray_run_name`."},
-    )
-    ray_storage_path: str = field(
-        default="./saves",
-        metadata={"help": "The storage path to save training results to"},
-    )
-    ray_storage_filesystem: Literal["s3", "gs", "gcs"] | None = field(
-        default=None,
-        metadata={"help": "The storage filesystem to use. If None specified, local filesystem will be used."},
-    )
     ray_num_workers: int = field(
         default=1,
         metadata={"help": "The number of workers for Ray training. Default is 1 worker."},
     )
-    resources_per_worker: dict | str = field(
-        default_factory=lambda: {"GPU": 1},
-        metadata={"help": "The resources per worker for Ray training. Default is to use 1 GPU per worker."},
-    )
-    placement_strategy: Literal["SPREAD", "PACK", "STRICT_SPREAD", "STRICT_PACK"] = field(
-        default="PACK",
-        metadata={"help": "The placement strategy for Ray training. Default is PACK."},
-    )
     ray_init_kwargs: dict | str | None = field(
         default=None,
         metadata={"help": "The arguments to pass to ray.init for Ray training. Default is None."},
     )
+    master_addr: str | None = field(
+        default=None,
+        metadata={"help": "The master address for init_process_group"},
+    )
+    master_port: str | None = field(
+        default=None,
+        metadata={"help": "The master port for init_process_group"},
+    )

     def __post_init__(self):
         self.use_ray = use_ray()
-        if isinstance(self.resources_per_worker, str) and self.resources_per_worker.startswith("{"):
-            self.resources_per_worker = _convert_str_dict(json.loads(self.resources_per_worker))
         if isinstance(self.ray_init_kwargs, str) and self.ray_init_kwargs.startswith("{"):
             self.ray_init_kwargs = _convert_str_dict(json.loads(self.ray_init_kwargs))
-        if self.ray_storage_filesystem is not None:
-            if self.ray_storage_filesystem not in ["s3", "gs", "gcs"]:
-                raise ValueError(
-                    f"ray_storage_filesystem must be one of ['s3', 'gs', 'gcs'], got {self.ray_storage_filesystem}."
-                )
-
-            import pyarrow.fs as fs
-
-            if self.ray_storage_filesystem == "s3":
-                self.ray_storage_filesystem = fs.S3FileSystem()
-            elif self.ray_storage_filesystem == "gs" or self.ray_storage_filesystem == "gcs":
-                self.ray_storage_filesystem = fs.GcsFileSystem()


 @dataclass
 class Fp8Arguments:
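A short sketch of constructing the slimmed-down RayArguments (hypothetical values; the field names follow the diff above, the import path is an assumption):

# Hypothetical sketch: a JSON string passed as ray_init_kwargs is converted to a
# dict in __post_init__; master_addr/master_port are the newly added fields.
from llamafactory.hparams import RayArguments  # assumed import path

ray_args = RayArguments(
    ray_num_workers=2,
    ray_init_kwargs='{"address": "auto"}',
    master_addr="10.0.0.1",  # example address; defaults to the Ray head node IP
    master_port="29500",  # example port; an open port is picked if omitted
)
assert isinstance(ray_args.ray_init_kwargs, dict)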


@@ -20,7 +20,6 @@
 import json
 import os
 from collections.abc import Callable, Mapping
-from pathlib import Path
 from typing import TYPE_CHECKING, Any, Optional, Union

 import torch
@@ -34,6 +33,7 @@ from typing_extensions import override
 from ..extras import logging
 from ..extras.constants import IGNORE_INDEX, SWANLAB_CONFIG
+from ..extras.misc import get_device_name
 from ..extras.packages import is_apollo_available, is_galore_available, is_ray_available
 from ..hparams import FinetuningArguments, ModelArguments
 from ..model import find_all_linear_modules, load_model, load_tokenizer, load_valuehead_params
@@ -49,15 +49,15 @@ if is_apollo_available():
 if is_ray_available():
     import ray
-    from ray.train import RunConfig, ScalingConfig
-    from ray.train.torch import TorchTrainer
+    from ray.util.placement_group import PlacementGroup, placement_group
+    from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy


 if TYPE_CHECKING:
     from transformers import PreTrainedModel, TrainerCallback, TrainerState
     from trl import AutoModelForCausalLMWithValueHead

-    from ..hparams import DataArguments, RayArguments, TrainingArguments
+    from ..hparams import DataArguments, TrainingArguments


 logger = logging.get_logger(__name__)
@@ -807,36 +807,88 @@ def get_swanlab_callback(finetuning_args: "FinetuningArguments") -> "TrainerCall
     return swanlab_callback


-def get_ray_trainer(
-    training_function: Callable,
-    train_loop_config: dict[str, Any],
-    ray_args: "RayArguments",
-) -> "TorchTrainer":
-    if not ray_args.use_ray:
-        raise ValueError("Ray was not enabled. Please set `USE_RAY=1` to enable ray.")
-
-    if ray_args.ray_init_kwargs is not None:
-        ray.init(**ray_args.ray_init_kwargs)
-
-    if ray_args.ray_storage_filesystem is not None:
-        # this means we are using s3/gcs
-        storage_path = ray_args.ray_storage_path
-    else:
-        storage_path = Path(ray_args.ray_storage_path).absolute().as_posix()
-
-    trainer = TorchTrainer(
-        training_function,
-        train_loop_config=train_loop_config,
-        scaling_config=ScalingConfig(
-            num_workers=ray_args.ray_num_workers,
-            resources_per_worker=ray_args.resources_per_worker,
-            placement_strategy=ray_args.placement_strategy,
-            use_gpu=True,
-        ),
-        run_config=RunConfig(
-            name=ray_args.ray_run_name,
-            storage_filesystem=ray_args.ray_storage_filesystem,
-            storage_path=storage_path,
-        ),
-    )
-    return trainer
+def get_placement_group(num_workers: int) -> tuple["PlacementGroup", dict[str, int]]:
+    r"""Get the Ray placement group for distributed training."""
+    bundle = {"CPU": 10}
+    device_name = get_device_name().upper()
+    if device_name != "CPU":
+        bundle[device_name] = 1
+
+    bundles = [bundle for _ in range(num_workers)]
+    pg = placement_group(bundles, strategy="PACK")
+    return pg, bundle
+
+
+def get_ray_remote_config_for_worker(
+    placement_group: "PlacementGroup",
+    bundle_idx: int,
+    rank: int,
+    world_size: int,
+    master_addr: str,
+    master_port: str,
+    env: dict[str, str] = None,
+) -> dict[str, Any]:
+    r"""Get the remote config for a Ray worker."""
+    env_vars = {
+        "RANK": str(rank),
+        "WORLD_SIZE": str(world_size),
+        "MASTER_ADDR": master_addr,
+        "MASTER_PORT": master_port,
+        "TORCHELASTIC_USE_AGENT_STORE": "False",
+    }
+    env.update(env_vars)
+    remote_config = {
+        "scheduling_strategy": PlacementGroupSchedulingStrategy(
+            placement_group=placement_group,
+            placement_group_bundle_index=bundle_idx,
+        ),
+        "runtime_env": {"env_vars": env},
+        "num_cpus": 10,
+    }
+
+    device_name = get_device_name()
+    if device_name == "gpu":
+        remote_config["num_gpus"] = 1
+    elif device_name == "npu":
+        remote_config["resources"] = {"NPU": 1}
+
+    return remote_config
+
+
+def get_ray_head_node_ip() -> str:
+    r"""Get the IP address of the Ray head node."""
+    head_ip = next(node["NodeManagerAddress"] for node in ray.nodes() if node.get("IsHead", False))
+    return head_ip
+
+
+def sort_placement_group_by_node_ip(placement_group: "PlacementGroup", master_addr: str = None) -> list[int]:
+    r"""Sort the placement group bundles by their node IP addresses."""
+
+    @ray.remote
+    def _get_node_ip():
+        return ray.util.get_node_ip_address().strip("[]")
+
+    tasks = []
+    for bundle_idx in range(placement_group.bundle_count):
+        task = _get_node_ip.options(
+            scheduling_strategy=PlacementGroupSchedulingStrategy(
+                placement_group=placement_group,
+                placement_group_bundle_index=bundle_idx,
+            ),
+        ).remote()
+        tasks.append(task)
+
+    bundle_ips = ray.get(tasks)
+    bundle_node_ip_list = list(enumerate(bundle_ips))
+    sorted_bundle_node_ip_list = sorted(bundle_node_ip_list, key=lambda x: x[1])
+    sorted_bundle_indices = [item[0] for item in sorted_bundle_node_ip_list]
+
+    if master_addr is not None:
+        preferred_indices = [idx for idx, ip in bundle_node_ip_list if ip == master_addr]
+        if preferred_indices:
+            remaining = [i for i in sorted_bundle_indices if i not in preferred_indices]
+            sorted_bundle_indices = preferred_indices + remaining
+
+    return sorted_bundle_indices
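A hypothetical sketch of how these helpers compose (illustrative values only; all function names come from the diff above):

# Hypothetical composition sketch: reserve a CPU+device bundle per worker,
# order bundles so the master node comes first, and build the remote config
# for rank 0.
import os

import ray

ray.init()
pg, bundle = get_placement_group(num_workers=2)
ray.get(pg.ready())

bundle_order = sort_placement_group_by_node_ip(pg, master_addr=None)
remote_config = get_ray_remote_config_for_worker(
    placement_group=pg,
    bundle_idx=bundle_order[0],
    rank=0,
    world_size=2,
    master_addr=get_ray_head_node_ip(),
    master_port="29500",  # example port
    env=dict(os.environ),
)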


@@ -23,9 +23,9 @@ from transformers import EarlyStoppingCallback, PreTrainedModel
 from ..data import get_template_and_fix_tokenizer
 from ..extras import logging
 from ..extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
-from ..extras.misc import infer_optim_dtype
+from ..extras.misc import find_available_port, get_device_name, get_torch_device, infer_optim_dtype
 from ..extras.packages import is_mcore_adapter_available, is_ray_available
-from ..hparams import get_infer_args, get_ray_args, get_train_args, read_args
+from ..hparams import RayArguments, get_infer_args, get_ray_args, get_train_args, read_args
 from ..model import load_model, load_tokenizer
 from .callbacks import LogCallback, PissaConvertCallback, ReporterCallback
 from .dpo import run_dpo
@@ -34,12 +34,17 @@ from .ppo import run_ppo
 from .pt import run_pt
 from .rm import run_rm
 from .sft import run_sft
-from .trainer_utils import get_ray_trainer, get_swanlab_callback
+from .trainer_utils import (
+    get_placement_group,
+    get_ray_head_node_ip,
+    get_ray_remote_config_for_worker,
+    get_swanlab_callback,
+    sort_placement_group_by_node_ip,
+)


 if is_ray_available():
     import ray
-    from ray.train.huggingface.transformers import RayTrainReportCallback


 if TYPE_CHECKING:
@@ -115,13 +120,7 @@ def run_exp(args: Optional[dict[str, Any]] = None, callbacks: Optional[list["Tra
     ray_args = get_ray_args(args)
     callbacks = callbacks or []

     if ray_args.use_ray:
-        callbacks.append(RayTrainReportCallback())
-        trainer = get_ray_trainer(
-            training_function=_training_function,
-            train_loop_config={"args": args, "callbacks": callbacks},
-            ray_args=ray_args,
-        )
-        trainer.fit()
+        _ray_training_function(ray_args, config={"args": args, "callbacks": callbacks})
     else:
         _training_function(config={"args": args, "callbacks": callbacks})
@@ -212,3 +211,94 @@ def export_model(args: Optional[dict[str, Any]] = None) -> None:
     with open(ollama_modelfile, "w", encoding="utf-8") as f:
         f.write(template.get_ollama_modelfile(tokenizer))

     logger.info_rank0(f"Ollama modelfile saved in {ollama_modelfile}")
+
+
+class Worker:
+    def __init__(self):
+        self._setup_env_visible_devices()
+        local_rank = os.environ.get("LOCAL_RANK", "0")
+        get_torch_device().set_device(int(local_rank))
+
+    def _setup_env_visible_devices(self) -> None:
+        RAY_NOSET_VISIBLE_DEVICES_LIST = [
+            "RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES",
+            "RAY_EXPERIMENTAL_NOSET_ASCEND_RT_VISIBLE_DEVICES",
+        ]
+        is_ray_noset_visible_devices = any(os.environ.get(env_var, None) for env_var in RAY_NOSET_VISIBLE_DEVICES_LIST)
+        if is_ray_noset_visible_devices:
+            device_name = get_device_name().upper()
+            local_rank = ray.get_runtime_context().get_accelerator_ids()[device_name][0]
+            os.environ["LOCAL_RANK"] = local_rank
+        else:
+            os.environ["LOCAL_RANK"] = "0"
+
+    def _training_function(self, config: dict[str, Any]) -> None:
+        _training_function(config)
+
+
+def _ray_training_function(ray_args: "RayArguments", config: dict[str, Any]) -> None:
+    num_workers = ray_args.ray_num_workers
+    master_addr = ray_args.master_addr
+    master_port = ray_args.master_port
+    logger.info(f"Using ray.remote mode with {num_workers} workers for distributed training.")
+
+    # initialize ray
+    if not ray.is_initialized():
+        if ray_args.ray_init_kwargs is not None:
+            ray.init(**ray_args.ray_init_kwargs)
+        else:
+            ray.init()
+
+    # verify resources
+    device_name = get_device_name().upper()
+    total_devices = int(ray.cluster_resources().get(device_name, 0))
+    if num_workers > total_devices:
+        raise ValueError(
+            f"The number of devices in the Ray cluster ({total_devices}) should be no less than num_workers ({num_workers})."
+        )
+
+    # verify master_addr
+    if master_addr is None:
+        master_addr = get_ray_head_node_ip()
+        logger.info(f"`master_addr` is not specified, using head node ip: {master_addr}.")
+    else:
+        nodes = [node["NodeManagerAddress"] for node in ray.nodes() if node["Alive"]]
+        if master_addr not in nodes:
+            raise ValueError(f"The `master_addr` ({master_addr}) is not in the Ray cluster or not alive.")
+
+    # create placement group for resource management
+    pg, bundle = get_placement_group(total_devices)
+    ray.get(pg.ready())
+    logger.info(f"Created placement group with {total_devices} bundles: {bundle}")
+
+    # get sorted bundle indices
+    sorted_bundle_indices = sort_placement_group_by_node_ip(pg, master_addr)
+
+    # get master port
+    if master_port is None:
+        master_port = find_available_port()
+        logger.info(f"`master_port` is not specified, using available port: {master_port}.")
+
+    master_port = str(master_port)
+
+    # back up environment variables
+    current_env = dict(os.environ.items())
+
+    # launch workers
+    RayWorker = ray.remote(Worker)
+    workers = []
+    for rank in range(num_workers):
+        remote_config = get_ray_remote_config_for_worker(
+            placement_group=pg,
+            bundle_idx=sorted_bundle_indices[rank],
+            rank=rank,
+            world_size=num_workers,
+            master_addr=master_addr,
+            master_port=master_port,
+            env=current_env,
+        )
+        worker = RayWorker.options(**remote_config).remote()
+        workers.append(worker)
+
+    ray.get([worker._training_function.remote(config=config) for worker in workers])
+    ray.shutdown()
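End to end, a run under this mode might look like the sketch below (hypothetical: the import path and the non-Ray config keys are assumptions; only USE_RAY and the ray_*/master_* options come from this diff):

# Hypothetical launch sketch: with USE_RAY=1, run_exp dispatches to
# _ray_training_function, which spawns one Ray actor (Worker) per rank and
# waits for their _training_function calls to finish.
import os

from llamafactory.train.tuner import run_exp  # assumed import path

os.environ["USE_RAY"] = "1"
run_exp(
    args={
        "model_name_or_path": "Qwen/Qwen2.5-0.5B-Instruct",  # example model
        "stage": "sft",  # example stage
        "do_train": True,
        "ray_num_workers": 2,
        "master_port": "29500",  # optional; an open port is chosen if omitted
    }
)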