Mirror of https://github.com/hiyouga/LlamaFactory.git (synced 2026-01-29 22:02:03 +00:00)
[feature] support using ray.remote to start distributed training. (#10109)
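For context, this change launches one Ray actor per training rank instead of going through ray.train's TorchTrainer. A minimal sketch of that actor pattern, using only the public Ray API (the class and values below are illustrative, not part of this commit):

import ray

ray.init()

@ray.remote
class EchoWorker:
    # each actor runs in its own process; methods are invoked asynchronously via .remote()
    def run(self, rank: int) -> int:
        return rank

workers = [EchoWorker.remote() for _ in range(2)]
print(ray.get([worker.run.remote(rank) for rank, worker in enumerate(workers)]))  # [0, 1]
ray.shutdown()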
@@ -157,6 +157,33 @@ def get_current_device() -> "torch.device":
     return torch.device(device)
 
 
+def get_device_name() -> str:
+    r"""Get the name of available devices."""
+    if is_torch_xpu_available():
+        device = "xpu"
+    elif is_torch_npu_available():
+        device = "npu"
+    elif is_torch_mps_available():
+        device = "mps"
+    elif is_torch_cuda_available():
+        device = "gpu"
+    else:
+        device = "cpu"
+
+    return device
+
+
+def get_torch_device():
+    r"""Get the torch device namespace for the available devices."""
+    device_name = get_device_name()
+    device_name = "cuda" if device_name == "gpu" else device_name
+    try:
+        return getattr(torch, device_name)
+    except AttributeError:
+        logger.warning_rank0(f"Device namespace '{device_name}' not found in torch, try to load torch.cuda.")
+        return torch.cuda
+
+
 def get_device_count() -> int:
     r"""Get the number of available devices."""
     if is_torch_xpu_available():
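The two helpers added above are used later in this PR to bind each Ray worker to its local accelerator. A short usage sketch, assuming the helpers are importable from llamafactory.extras.misc (the local_rank handling mirrors Worker.__init__ further down in the diff):

import os

from llamafactory.extras.misc import get_device_name, get_torch_device

local_rank = int(os.environ.get("LOCAL_RANK", "0"))
if get_device_name() != "cpu":
    # get_torch_device() returns torch.cuda, torch.npu, torch.xpu, or torch.mps as appropriate
    get_torch_device().set_device(local_rank)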
@@ -14,7 +14,6 @@
 
 import json
 from dataclasses import dataclass, field
-from typing import Literal
 
 from transformers import Seq2SeqTrainingArguments
 from transformers.training_args import _convert_str_dict
@@ -40,56 +39,29 @@ else:
 class RayArguments:
     r"""Arguments pertaining to the Ray training."""
 
-    ray_run_name: str | None = field(
-        default=None,
-        metadata={"help": "The training results will be saved at `<ray_storage_path>/ray_run_name`."},
-    )
-    ray_storage_path: str = field(
-        default="./saves",
-        metadata={"help": "The storage path to save training results to"},
-    )
-    ray_storage_filesystem: Literal["s3", "gs", "gcs"] | None = field(
-        default=None,
-        metadata={"help": "The storage filesystem to use. If None specified, local filesystem will be used."},
-    )
     ray_num_workers: int = field(
         default=1,
         metadata={"help": "The number of workers for Ray training. Default is 1 worker."},
     )
-    resources_per_worker: dict | str = field(
-        default_factory=lambda: {"GPU": 1},
-        metadata={"help": "The resources per worker for Ray training. Default is to use 1 GPU per worker."},
-    )
-    placement_strategy: Literal["SPREAD", "PACK", "STRICT_SPREAD", "STRICT_PACK"] = field(
-        default="PACK",
-        metadata={"help": "The placement strategy for Ray training. Default is PACK."},
-    )
     ray_init_kwargs: dict | str | None = field(
         default=None,
         metadata={"help": "The arguments to pass to ray.init for Ray training. Default is None."},
     )
+    master_addr: str | None = field(
+        default=None,
+        metadata={"help": "The master address for init_process_group"},
+    )
+    master_port: str | None = field(
+        default=None,
+        metadata={"help": "The master port for init_process_group"},
+    )
 
     def __post_init__(self):
         self.use_ray = use_ray()
-        if isinstance(self.resources_per_worker, str) and self.resources_per_worker.startswith("{"):
-            self.resources_per_worker = _convert_str_dict(json.loads(self.resources_per_worker))
 
         if isinstance(self.ray_init_kwargs, str) and self.ray_init_kwargs.startswith("{"):
             self.ray_init_kwargs = _convert_str_dict(json.loads(self.ray_init_kwargs))
-
-        if self.ray_storage_filesystem is not None:
-            if self.ray_storage_filesystem not in ["s3", "gs", "gcs"]:
-                raise ValueError(
-                    f"ray_storage_filesystem must be one of ['s3', 'gs', 'gcs'], got {self.ray_storage_filesystem}."
-                )
-
-            import pyarrow.fs as fs
-
-            if self.ray_storage_filesystem == "s3":
-                self.ray_storage_filesystem = fs.S3FileSystem()
-            elif self.ray_storage_filesystem == "gs" or self.ray_storage_filesystem == "gcs":
-                self.ray_storage_filesystem = fs.GcsFileSystem()
 
 
 @dataclass
 class Fp8Arguments:
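With the storage and scaling fields gone, the Ray configuration reduces to the worker count, optional ray.init kwargs, and the rendezvous address and port. A sketch of how the remaining fields might be set (the values are illustrative; RayArguments is exported from llamafactory.hparams per this PR):

from llamafactory.hparams import RayArguments

# JSON strings are accepted for ray_init_kwargs; __post_init__ converts them to a dict
ray_args = RayArguments(
    ray_num_workers=4,
    ray_init_kwargs='{"address": "auto"}',
    master_addr="10.0.0.1",  # optional; defaults to the Ray head node IP at launch time
    master_port="29500",     # optional; defaults to a free port found at launch time
)
assert ray_args.ray_init_kwargs == {"address": "auto"}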
@@ -20,7 +20,6 @@
 import json
 import os
 from collections.abc import Callable, Mapping
-from pathlib import Path
 from typing import TYPE_CHECKING, Any, Optional, Union
 
 import torch
@@ -34,6 +33,7 @@ from typing_extensions import override
 
 from ..extras import logging
 from ..extras.constants import IGNORE_INDEX, SWANLAB_CONFIG
+from ..extras.misc import get_device_name
 from ..extras.packages import is_apollo_available, is_galore_available, is_ray_available
 from ..hparams import FinetuningArguments, ModelArguments
 from ..model import find_all_linear_modules, load_model, load_tokenizer, load_valuehead_params
@@ -49,15 +49,15 @@ if is_apollo_available():
 
 if is_ray_available():
     import ray
-    from ray.train import RunConfig, ScalingConfig
-    from ray.train.torch import TorchTrainer
+    from ray.util.placement_group import PlacementGroup, placement_group
+    from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
 
 
 if TYPE_CHECKING:
     from transformers import PreTrainedModel, TrainerCallback, TrainerState
     from trl import AutoModelForCausalLMWithValueHead
 
-    from ..hparams import DataArguments, RayArguments, TrainingArguments
+    from ..hparams import DataArguments, TrainingArguments
 
 
 logger = logging.get_logger(__name__)
@@ -807,36 +807,88 @@ def get_swanlab_callback(finetuning_args: "FinetuningArguments") -> "TrainerCallback":
     return swanlab_callback
 
 
-def get_ray_trainer(
-    training_function: Callable,
-    train_loop_config: dict[str, Any],
-    ray_args: "RayArguments",
-) -> "TorchTrainer":
-    if not ray_args.use_ray:
-        raise ValueError("Ray was not enabled. Please set `USE_RAY=1` to enable ray.")
-
-    if ray_args.ray_init_kwargs is not None:
-        ray.init(**ray_args.ray_init_kwargs)
-
-    if ray_args.ray_storage_filesystem is not None:
-        # this means we are using s3/gcs
-        storage_path = ray_args.ray_storage_path
-    else:
-        storage_path = Path(ray_args.ray_storage_path).absolute().as_posix()
-
-    trainer = TorchTrainer(
-        training_function,
-        train_loop_config=train_loop_config,
-        scaling_config=ScalingConfig(
-            num_workers=ray_args.ray_num_workers,
-            resources_per_worker=ray_args.resources_per_worker,
-            placement_strategy=ray_args.placement_strategy,
-            use_gpu=True,
-        ),
-        run_config=RunConfig(
-            name=ray_args.ray_run_name,
-            storage_filesystem=ray_args.ray_storage_filesystem,
-            storage_path=storage_path,
-        ),
-    )
-    return trainer
+def get_placement_group(num_workers: int) -> tuple["PlacementGroup", dict[str, int]]:
+    r"""Get the Ray placement group for distributed training."""
+    bundle = {"CPU": 10}
+    device_name = get_device_name().upper()
+    if device_name != "CPU":
+        bundle[device_name] = 1
+    bundles = [bundle for _ in range(num_workers)]
+    pg = placement_group(bundles, strategy="PACK")
+
+    return pg, bundle
+
+
+def get_ray_remote_config_for_worker(
+    placement_group: "PlacementGroup",
+    bundle_idx: int,
+    rank: int,
+    world_size: int,
+    master_addr: str,
+    master_port: str,
+    env: dict[str, str] = None,
+) -> dict[str, Any]:
+    r"""Get the remote config for a Ray worker."""
+    env_vars = {
+        "RANK": str(rank),
+        "WORLD_SIZE": str(world_size),
+        "MASTER_ADDR": master_addr,
+        "MASTER_PORT": master_port,
+        "TORCHELASTIC_USE_AGENT_STORE": "False",
+    }
+    env.update(env_vars)
+
+    remote_config = {
+        "scheduling_strategy": PlacementGroupSchedulingStrategy(
+            placement_group=placement_group,
+            placement_group_bundle_index=bundle_idx,
+        ),
+        "runtime_env": {"env_vars": env},
+        "num_cpus": 10,
+    }
+
+    device_name = get_device_name()
+    if device_name == "gpu":
+        remote_config["num_gpus"] = 1
+    elif device_name == "npu":
+        remote_config["resources"] = {"NPU": 1}
+
+    return remote_config
+
+
+def get_ray_head_node_ip() -> str:
+    r"""Get the IP address of the Ray head node."""
+    head_ip = next(node["NodeManagerAddress"] for node in ray.nodes() if node.get("IsHead", False))
+    return head_ip
+
+
+def sort_placement_group_by_node_ip(placement_group: "PlacementGroup", master_addr: str = None) -> list[int]:
+    r"""Sort the placement group bundles by their node IP addresses."""
+
+    @ray.remote
+    def _get_node_ip():
+        return ray.util.get_node_ip_address().strip("[]")
+
+    tasks = []
+    for bundle_idx in range(placement_group.bundle_count):
+        task = _get_node_ip.options(
+            scheduling_strategy=PlacementGroupSchedulingStrategy(
+                placement_group=placement_group,
+                placement_group_bundle_index=bundle_idx,
+            ),
+        ).remote()
+        tasks.append(task)
+
+    bundle_ips = ray.get(tasks)
+    bundle_node_ip_list = list(enumerate(bundle_ips))
+
+    sorted_bundle_node_ip_list = sorted(bundle_node_ip_list, key=lambda x: x[1])
+    sorted_bundle_indices = [item[0] for item in sorted_bundle_node_ip_list]
+
+    if master_addr is not None:
+        preferred_indices = [idx for idx, ip in bundle_node_ip_list if ip == master_addr]
+        if preferred_indices:
+            remaining = [i for i in sorted_bundle_indices if i not in preferred_indices]
+            sorted_bundle_indices = preferred_indices + remaining
+
+    return sorted_bundle_indices
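The helpers above follow a standard Ray pattern: reserve resource bundles with placement_group(), then pin each task or actor to a specific bundle via PlacementGroupSchedulingStrategy. A standalone sketch of that pattern, separate from the training code (the resource numbers are illustrative):

import ray
from ray.util.placement_group import placement_group
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy

ray.init()

# reserve two bundles of 1 CPU each, packed onto as few nodes as possible
pg = placement_group([{"CPU": 1}, {"CPU": 1}], strategy="PACK")
ray.get(pg.ready())  # block until the reservation is granted

@ray.remote(num_cpus=1)
def where_am_i() -> str:
    return ray.util.get_node_ip_address()

# pin one task to each bundle, as sort_placement_group_by_node_ip does above
refs = [
    where_am_i.options(
        scheduling_strategy=PlacementGroupSchedulingStrategy(
            placement_group=pg, placement_group_bundle_index=idx
        )
    ).remote()
    for idx in range(2)
]
print(ray.get(refs))
ray.shutdown()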
@@ -23,9 +23,9 @@ from transformers import EarlyStoppingCallback, PreTrainedModel
 from ..data import get_template_and_fix_tokenizer
 from ..extras import logging
 from ..extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
-from ..extras.misc import infer_optim_dtype
+from ..extras.misc import find_available_port, get_device_name, get_torch_device, infer_optim_dtype
 from ..extras.packages import is_mcore_adapter_available, is_ray_available
-from ..hparams import get_infer_args, get_ray_args, get_train_args, read_args
+from ..hparams import RayArguments, get_infer_args, get_ray_args, get_train_args, read_args
 from ..model import load_model, load_tokenizer
 from .callbacks import LogCallback, PissaConvertCallback, ReporterCallback
 from .dpo import run_dpo
@@ -34,12 +34,17 @@ from .ppo import run_ppo
 from .pt import run_pt
 from .rm import run_rm
 from .sft import run_sft
-from .trainer_utils import get_ray_trainer, get_swanlab_callback
+from .trainer_utils import (
+    get_placement_group,
+    get_ray_head_node_ip,
+    get_ray_remote_config_for_worker,
+    get_swanlab_callback,
+    sort_placement_group_by_node_ip,
+)
 
 
 if is_ray_available():
     import ray
-    from ray.train.huggingface.transformers import RayTrainReportCallback
 
 
 if TYPE_CHECKING:
@@ -115,13 +120,7 @@ def run_exp(args: Optional[dict[str, Any]] = None, callbacks: Optional[list["TrainerCallback"]] = None) -> None:
     ray_args = get_ray_args(args)
     callbacks = callbacks or []
     if ray_args.use_ray:
-        callbacks.append(RayTrainReportCallback())
-        trainer = get_ray_trainer(
-            training_function=_training_function,
-            train_loop_config={"args": args, "callbacks": callbacks},
-            ray_args=ray_args,
-        )
-        trainer.fit()
+        _ray_training_function(ray_args, config={"args": args, "callbacks": callbacks})
     else:
         _training_function(config={"args": args, "callbacks": callbacks})
 
@@ -212,3 +211,94 @@ def export_model(args: Optional[dict[str, Any]] = None) -> None:
     with open(ollama_modelfile, "w", encoding="utf-8") as f:
         f.write(template.get_ollama_modelfile(tokenizer))
     logger.info_rank0(f"Ollama modelfile saved in {ollama_modelfile}")
+
+
+class Worker:
+    def __init__(self):
+        self._setup_env_visible_devices()
+
+        local_rank = os.environ.get("LOCAL_RANK", "0")
+        get_torch_device().set_device(int(local_rank))
+
+    def _setup_env_visible_devices(self) -> None:
+        RAY_NOSET_VISIBLE_DEVICES_LIST = [
+            "RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES",
+            "RAY_EXPERIMENTAL_NOSET_ASCEND_RT_VISIBLE_DEVICES",
+        ]
+        is_ray_noset_visible_devices = any(os.environ.get(env_var, None) for env_var in RAY_NOSET_VISIBLE_DEVICES_LIST)
+        if is_ray_noset_visible_devices:
+            device_name = get_device_name().upper()
+            local_rank = ray.get_runtime_context().get_accelerator_ids()[device_name][0]
+            os.environ["LOCAL_RANK"] = local_rank
+        else:
+            os.environ["LOCAL_RANK"] = "0"
+
+    def _training_function(self, config: dict[str, Any]) -> None:
+        _training_function(config)
+
+
+def _ray_training_function(ray_args: "RayArguments", config: dict[str, Any]) -> None:
+    num_workers = ray_args.ray_num_workers
+    master_addr = ray_args.master_addr
+    master_port = ray_args.master_port
+    logger.info(f"Using ray.remote mode with {num_workers} workers for distributed training.")
+
+    # initialize ray
+    if not ray.is_initialized():
+        if ray_args.ray_init_kwargs is not None:
+            ray.init(**ray_args.ray_init_kwargs)
+        else:
+            ray.init()
+
+    # verify resources
+    device_name = get_device_name().upper()
+    total_devices = int(ray.cluster_resources().get(device_name, 0))
+    if num_workers > total_devices:
+        raise ValueError(
+            f"The number of devices in the Ray cluster ({total_devices}) should be greater than num_workers ({num_workers})."
+        )
+
+    # verify master_addr
+    if master_addr is None:
+        master_addr = get_ray_head_node_ip()
+        logger.info(f"`master_addr` is not specified, using head node ip: {master_addr}.")
+    else:
+        nodes = [node["NodeManagerAddress"] for node in ray.nodes() if node["Alive"]]
+        if master_addr not in nodes:
+            raise ValueError(f"The `master_addr` ({master_addr}) is not in Ray cluster or not alive ")
+
+    # create placementgroup for resource management
+    pg, bundle = get_placement_group(total_devices)
+    ray.get(pg.ready())
+    logger.info(f"Create placement group with {num_workers} bundles: {bundle}")
+
+    # get sorted_bundle_indices
+    sorted_bundle_indices = sort_placement_group_by_node_ip(pg, master_addr)
+
+    # get master port
+    if master_port is None:
+        master_port = find_available_port()
+        logger.info(f"`master_port` is not specified, using available port: {master_port}.")
+    master_port = str(master_port)
+
+    # backing up environment variables
+    current_env = dict(os.environ.items())
+
+    # launch workers
+    RayWorker = ray.remote(Worker)
+    workers = []
+    for rank in range(num_workers):
+        remote_config = get_ray_remote_config_for_worker(
+            placement_group=pg,
+            bundle_idx=sorted_bundle_indices[rank],
+            rank=rank,
+            world_size=num_workers,
+            master_addr=master_addr,
+            master_port=master_port,
+            env=current_env,
+        )
+        worker = RayWorker.options(**remote_config).remote()
+        workers.append(worker)
+
+    ray.get([worker._training_function.remote(config=config) for worker in workers])
+    ray.shutdown()
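The RANK / WORLD_SIZE / MASTER_ADDR / MASTER_PORT variables injected per worker are the standard inputs for torch.distributed's env:// rendezvous; the training function itself is unchanged and picks them up when the process group is initialized. A hedged sketch of what that consumption looks like inside a worker process (the backend choice and the explicit init call here are illustrative, not part of this diff):

import os

import torch
import torch.distributed as dist

# read the rendezvous info that get_ray_remote_config_for_worker exported into the actor's runtime_env
rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])

# env:// (the default) reads MASTER_ADDR / MASTER_PORT / RANK / WORLD_SIZE from the environment
dist.init_process_group(
    backend="nccl" if torch.cuda.is_available() else "gloo",
    rank=rank,
    world_size=world_size,
)
print(f"rank {dist.get_rank()} / {dist.get_world_size()} initialized")
dist.destroy_process_group()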