Mirror of https://github.com/hiyouga/LlamaFactory.git (synced 2026-01-30 06:12:04 +00:00)

Compare commits: 9640f79ae5...main, 1 commit (SHA1 762b480131)
@@ -157,6 +157,33 @@ def get_current_device() -> "torch.device":
    return torch.device(device)


def get_device_name() -> str:
    r"""Get the name of available devices."""
    if is_torch_xpu_available():
        device = "xpu"
    elif is_torch_npu_available():
        device = "npu"
    elif is_torch_mps_available():
        device = "mps"
    elif is_torch_cuda_available():
        device = "gpu"
    else:
        device = "cpu"

    return device


def get_torch_device():
    r"""Get the torch device namespace for the available devices."""
    device_name = get_device_name()
    device_name = "cuda" if device_name == "gpu" else device_name
    try:
        return getattr(torch, device_name)
    except AttributeError:
        logger.warning_rank0(f"Device namespace '{device_name}' not found in torch, try to load torch.cuda.")
        return torch.cuda


def get_device_count() -> int:
    r"""Get the number of available devices."""
    if is_torch_xpu_available():
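For orientation (not part of the commit): the intended call pattern downstream is to resolve the accelerator once and then use the returned torch namespace generically. A minimal sketch, assuming the packaged import path `llamafactory.extras.misc` and an illustrative `LOCAL_RANK`:

```python
import os

from llamafactory.extras.misc import get_device_name, get_torch_device

local_rank = int(os.environ.get("LOCAL_RANK", "0"))

# get_torch_device() maps "gpu" -> torch.cuda, otherwise torch.<name> (torch.npu, torch.xpu, ...).
device_module = get_torch_device()

# Bind this process to its local accelerator; cuda/npu/xpu expose set_device(),
# while a cpu or mps run would simply skip this call.
if get_device_name() not in ("cpu", "mps"):
    device_module.set_device(local_rank)

print(f"device backend: {get_device_name()}")
```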
@@ -14,7 +14,6 @@

import json
from dataclasses import dataclass, field
from typing import Literal

from transformers import Seq2SeqTrainingArguments
from transformers.training_args import _convert_str_dict
@@ -40,56 +39,29 @@ else:
class RayArguments:
    r"""Arguments pertaining to the Ray training."""

    ray_run_name: str | None = field(
        default=None,
        metadata={"help": "The training results will be saved at `<ray_storage_path>/ray_run_name`."},
    )
    ray_storage_path: str = field(
        default="./saves",
        metadata={"help": "The storage path to save training results to"},
    )
    ray_storage_filesystem: Literal["s3", "gs", "gcs"] | None = field(
        default=None,
        metadata={"help": "The storage filesystem to use. If None specified, local filesystem will be used."},
    )
    ray_num_workers: int = field(
        default=1,
        metadata={"help": "The number of workers for Ray training. Default is 1 worker."},
    )
    resources_per_worker: dict | str = field(
        default_factory=lambda: {"GPU": 1},
        metadata={"help": "The resources per worker for Ray training. Default is to use 1 GPU per worker."},
    )
    placement_strategy: Literal["SPREAD", "PACK", "STRICT_SPREAD", "STRICT_PACK"] = field(
        default="PACK",
        metadata={"help": "The placement strategy for Ray training. Default is PACK."},
    )
    ray_init_kwargs: dict | str | None = field(
        default=None,
        metadata={"help": "The arguments to pass to ray.init for Ray training. Default is None."},
    )
    master_addr: str | None = field(
        default=None,
        metadata={"help": "The master address for init_process_group"},
    )
    master_port: str | None = field(
        default=None,
        metadata={"help": "The master port for init_process_group"},
    )

    def __post_init__(self):
        self.use_ray = use_ray()
        if isinstance(self.resources_per_worker, str) and self.resources_per_worker.startswith("{"):
            self.resources_per_worker = _convert_str_dict(json.loads(self.resources_per_worker))

        if isinstance(self.ray_init_kwargs, str) and self.ray_init_kwargs.startswith("{"):
            self.ray_init_kwargs = _convert_str_dict(json.loads(self.ray_init_kwargs))

        if self.ray_storage_filesystem is not None:
            if self.ray_storage_filesystem not in ["s3", "gs", "gcs"]:
                raise ValueError(
                    f"ray_storage_filesystem must be one of ['s3', 'gs', 'gcs'], got {self.ray_storage_filesystem}."
                )

            import pyarrow.fs as fs

            if self.ray_storage_filesystem == "s3":
                self.ray_storage_filesystem = fs.S3FileSystem()
            elif self.ray_storage_filesystem == "gs" or self.ray_storage_filesystem == "gcs":
                self.ray_storage_filesystem = fs.GcsFileSystem()


@dataclass
class Fp8Arguments:
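For reference (not in the diff): the string-typed variants of `resources_per_worker` and `ray_init_kwargs` exist so the values can be passed as JSON on the command line, and `__post_init__` normalizes them. A minimal sketch with hypothetical values, assuming `RayArguments` is importable from `llamafactory.hparams` as the later hunks suggest:

```python
from llamafactory.hparams import RayArguments

# Hypothetical values: a JSON string for the resource spec plus an explicit rendezvous address/port.
ray_args = RayArguments(
    ray_num_workers=2,
    resources_per_worker='{"GPU": 1}',
    master_addr="10.0.0.3",
    master_port="29500",
)

# __post_init__ has parsed the JSON string into a dict.
print(ray_args.resources_per_worker)  # {'GPU': 1}
print(ray_args.use_ray)  # True only when the USE_RAY environment flag is set
```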
@@ -20,7 +20,6 @@
import json
import os
from collections.abc import Callable, Mapping
from pathlib import Path
from typing import TYPE_CHECKING, Any, Optional, Union

import torch
@@ -34,6 +33,7 @@ from typing_extensions import override

from ..extras import logging
from ..extras.constants import IGNORE_INDEX, SWANLAB_CONFIG
from ..extras.misc import get_device_name
from ..extras.packages import is_apollo_available, is_galore_available, is_ray_available
from ..hparams import FinetuningArguments, ModelArguments
from ..model import find_all_linear_modules, load_model, load_tokenizer, load_valuehead_params
@@ -49,15 +49,15 @@ if is_apollo_available():

if is_ray_available():
    import ray
    from ray.train import RunConfig, ScalingConfig
    from ray.train.torch import TorchTrainer
    from ray.util.placement_group import PlacementGroup, placement_group
    from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy


if TYPE_CHECKING:
    from transformers import PreTrainedModel, TrainerCallback, TrainerState
    from trl import AutoModelForCausalLMWithValueHead

    from ..hparams import DataArguments, RayArguments, TrainingArguments
    from ..hparams import DataArguments, TrainingArguments


logger = logging.get_logger(__name__)
@@ -807,36 +807,88 @@ def get_swanlab_callback(finetuning_args: "FinetuningArguments") -> "TrainerCall
    return swanlab_callback


def get_ray_trainer(
    training_function: Callable,
    train_loop_config: dict[str, Any],
    ray_args: "RayArguments",
) -> "TorchTrainer":
    if not ray_args.use_ray:
        raise ValueError("Ray was not enabled. Please set `USE_RAY=1` to enable ray.")
def get_placement_group(num_workers: int) -> tuple["PlacementGroup", dict[str, int]]:
    r"""Get the Ray placement group for distributed training."""
    bundle = {"CPU": 10}
    device_name = get_device_name().upper()
    if device_name != "CPU":
        bundle[device_name] = 1
    bundles = [bundle for _ in range(num_workers)]
    pg = placement_group(bundles, strategy="PACK")

    if ray_args.ray_init_kwargs is not None:
        ray.init(**ray_args.ray_init_kwargs)
    return pg, bundle

    if ray_args.ray_storage_filesystem is not None:
        # this means we are using s3/gcs
        storage_path = ray_args.ray_storage_path
    else:
        storage_path = Path(ray_args.ray_storage_path).absolute().as_posix()

    trainer = TorchTrainer(
        training_function,
        train_loop_config=train_loop_config,
        scaling_config=ScalingConfig(
            num_workers=ray_args.ray_num_workers,
            resources_per_worker=ray_args.resources_per_worker,
            placement_strategy=ray_args.placement_strategy,
            use_gpu=True,
def get_ray_remote_config_for_worker(
    placement_group: "PlacementGroup",
    bundle_idx: int,
    rank: int,
    world_size: int,
    master_addr: str,
    master_port: str,
    env: dict[str, str] = None,
) -> dict[str, Any]:
    r"""Get the remote config for a Ray worker."""
    env_vars = {
        "RANK": str(rank),
        "WORLD_SIZE": str(world_size),
        "MASTER_ADDR": master_addr,
        "MASTER_PORT": master_port,
        "TORCHELASTIC_USE_AGENT_STORE": "False",
    }
    env.update(env_vars)

    remote_config = {
        "scheduling_strategy": PlacementGroupSchedulingStrategy(
            placement_group=placement_group,
            placement_group_bundle_index=bundle_idx,
        ),
        run_config=RunConfig(
            name=ray_args.ray_run_name,
            storage_filesystem=ray_args.ray_storage_filesystem,
            storage_path=storage_path,
        ),
    )
    return trainer
        "runtime_env": {"env_vars": env},
        "num_cpus": 10,
    }

    device_name = get_device_name()
    if device_name == "gpu":
        remote_config["num_gpus"] = 1
    elif device_name == "npu":
        remote_config["resources"] = {"NPU": 1}

    return remote_config


def get_ray_head_node_ip() -> str:
    r"""Get the IP address of the Ray head node."""
    head_ip = next(node["NodeManagerAddress"] for node in ray.nodes() if node.get("IsHead", False))
    return head_ip


def sort_placement_group_by_node_ip(placement_group: "PlacementGroup", master_addr: str = None) -> list[int]:
    r"""Sort the placement group bundles by their node IP addresses."""

    @ray.remote
    def _get_node_ip():
        return ray.util.get_node_ip_address().strip("[]")

    tasks = []
    for bundle_idx in range(placement_group.bundle_count):
        task = _get_node_ip.options(
            scheduling_strategy=PlacementGroupSchedulingStrategy(
                placement_group=placement_group,
                placement_group_bundle_index=bundle_idx,
            ),
        ).remote()
        tasks.append(task)

    bundle_ips = ray.get(tasks)
    bundle_node_ip_list = list(enumerate(bundle_ips))

    sorted_bundle_node_ip_list = sorted(bundle_node_ip_list, key=lambda x: x[1])
    sorted_bundle_indices = [item[0] for item in sorted_bundle_node_ip_list]

    if master_addr is not None:
        preferred_indices = [idx for idx, ip in bundle_node_ip_list if ip == master_addr]
        if preferred_indices:
            remaining = [i for i in sorted_bundle_indices if i not in preferred_indices]
            sorted_bundle_indices = preferred_indices + remaining

    return sorted_bundle_indices
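Both new helpers lean on the same Ray primitive: reserve one resource bundle per worker, then pin each remote task or actor to a bundle index via `PlacementGroupSchedulingStrategy`. A standalone sketch of that mechanism (the bundle sizes here are illustrative, not the `{"CPU": 10, ...}` bundles used above):

```python
import ray
from ray.util.placement_group import placement_group
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy

ray.init()

# One bundle per worker; each bundle reserves the resources that worker will consume.
bundles = [{"CPU": 1} for _ in range(2)]
pg = placement_group(bundles, strategy="PACK")
ray.get(pg.ready())  # block until the reservation is granted


@ray.remote(num_cpus=1)
def which_node() -> str:
    return ray.util.get_node_ip_address()


# Pin each task to a specific bundle index, the same way the worker actors are pinned above.
ips = ray.get(
    [
        which_node.options(
            scheduling_strategy=PlacementGroupSchedulingStrategy(
                placement_group=pg, placement_group_bundle_index=i
            )
        ).remote()
        for i in range(len(bundles))
    ]
)
print(ips)
ray.shutdown()
```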
@@ -23,9 +23,9 @@ from transformers import EarlyStoppingCallback, PreTrainedModel
from ..data import get_template_and_fix_tokenizer
from ..extras import logging
from ..extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
from ..extras.misc import infer_optim_dtype
from ..extras.misc import find_available_port, get_device_name, get_torch_device, infer_optim_dtype
from ..extras.packages import is_mcore_adapter_available, is_ray_available
from ..hparams import get_infer_args, get_ray_args, get_train_args, read_args
from ..hparams import RayArguments, get_infer_args, get_ray_args, get_train_args, read_args
from ..model import load_model, load_tokenizer
from .callbacks import LogCallback, PissaConvertCallback, ReporterCallback
from .dpo import run_dpo
@@ -34,12 +34,17 @@ from .ppo import run_ppo
from .pt import run_pt
from .rm import run_rm
from .sft import run_sft
from .trainer_utils import get_ray_trainer, get_swanlab_callback
from .trainer_utils import (
    get_placement_group,
    get_ray_head_node_ip,
    get_ray_remote_config_for_worker,
    get_swanlab_callback,
    sort_placement_group_by_node_ip,
)


if is_ray_available():
    import ray
    from ray.train.huggingface.transformers import RayTrainReportCallback


if TYPE_CHECKING:
@@ -115,13 +120,7 @@ def run_exp(args: Optional[dict[str, Any]] = None, callbacks: Optional[list["Tra
    ray_args = get_ray_args(args)
    callbacks = callbacks or []
    if ray_args.use_ray:
        callbacks.append(RayTrainReportCallback())
        trainer = get_ray_trainer(
            training_function=_training_function,
            train_loop_config={"args": args, "callbacks": callbacks},
            ray_args=ray_args,
        )
        trainer.fit()
        _ray_training_function(ray_args, config={"args": args, "callbacks": callbacks})
    else:
        _training_function(config={"args": args, "callbacks": callbacks})
@@ -212,3 +211,94 @@ def export_model(args: Optional[dict[str, Any]] = None) -> None:
    with open(ollama_modelfile, "w", encoding="utf-8") as f:
        f.write(template.get_ollama_modelfile(tokenizer))
    logger.info_rank0(f"Ollama modelfile saved in {ollama_modelfile}")


class Worker:
    def __init__(self):
        self._setup_env_visible_devices()

        local_rank = os.environ.get("LOCAL_RANK", "0")
        get_torch_device().set_device(int(local_rank))

    def _setup_env_visible_devices(self) -> None:
        RAY_NOSET_VISIBLE_DEVICES_LIST = [
            "RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES",
            "RAY_EXPERIMENTAL_NOSET_ASCEND_RT_VISIBLE_DEVICES",
        ]
        is_ray_noset_visible_devices = any(os.environ.get(env_var, None) for env_var in RAY_NOSET_VISIBLE_DEVICES_LIST)
        if is_ray_noset_visible_devices:
            device_name = get_device_name().upper()
            local_rank = ray.get_runtime_context().get_accelerator_ids()[device_name][0]
            os.environ["LOCAL_RANK"] = local_rank
        else:
            os.environ["LOCAL_RANK"] = "0"

    def _training_function(self, config: dict[str, Any]) -> None:
        _training_function(config)


def _ray_training_function(ray_args: "RayArguments", config: dict[str, Any]) -> None:
    num_workers = ray_args.ray_num_workers
    master_addr = ray_args.master_addr
    master_port = ray_args.master_port
    logger.info(f"Using ray.remote mode with {num_workers} workers for distributed training.")

    # initialize ray
    if not ray.is_initialized():
        if ray_args.ray_init_kwargs is not None:
            ray.init(**ray_args.ray_init_kwargs)
        else:
            ray.init()

    # verify resources
    device_name = get_device_name().upper()
    total_devices = int(ray.cluster_resources().get(device_name, 0))
    if num_workers > total_devices:
        raise ValueError(
            f"The number of devices in the Ray cluster ({total_devices}) should be greater than num_workers ({num_workers})."
        )

    # verify master_addr
    if master_addr is None:
        master_addr = get_ray_head_node_ip()
        logger.info(f"`master_addr` is not specified, using head node ip: {master_addr}.")
    else:
        nodes = [node["NodeManagerAddress"] for node in ray.nodes() if node["Alive"]]
        if master_addr not in nodes:
            raise ValueError(f"The `master_addr` ({master_addr}) is not in Ray cluster or not alive ")

    # create placementgroup for resource management
    pg, bundle = get_placement_group(total_devices)
    ray.get(pg.ready())
    logger.info(f"Create placement group with {num_workers} bundles: {bundle}")

    # get sorted_bundle_indices
    sorted_bundle_indices = sort_placement_group_by_node_ip(pg, master_addr)

    # get master port
    if master_port is None:
        master_port = find_available_port()
        logger.info(f"`master_port` is not specified, using available port: {master_port}.")
    master_port = str(master_port)

    # backing up environment variables
    current_env = dict(os.environ.items())

    # launch workers
    RayWorker = ray.remote(Worker)
    workers = []
    for rank in range(num_workers):
        remote_config = get_ray_remote_config_for_worker(
            placement_group=pg,
            bundle_idx=sorted_bundle_indices[rank],
            rank=rank,
            world_size=num_workers,
            master_addr=master_addr,
            master_port=master_port,
            env=current_env,
        )
        worker = RayWorker.options(**remote_config).remote()
        workers.append(worker)

    ray.get([worker._training_function.remote(config=config) for worker in workers])
    ray.shutdown()
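For context (not part of the commit): the `RANK`, `WORLD_SIZE`, `MASTER_ADDR`, and `MASTER_PORT` variables injected per worker are the standard `env://` rendezvous inputs that `torch.distributed` reads inside each process. A minimal single-process sketch with hypothetical values:

```python
import os

import torch.distributed as dist

# Inside each Ray worker, the training code forms the process group from these
# variables; the values below mirror what get_ray_remote_config_for_worker()
# injects (the address is hypothetical).
os.environ.setdefault("MASTER_ADDR", "10.0.0.3")
os.environ.setdefault("MASTER_PORT", "29500")
os.environ.setdefault("RANK", "0")
os.environ.setdefault("WORLD_SIZE", "1")

# env:// is the default init method and reads exactly these four variables.
dist.init_process_group(backend="gloo", init_method="env://")
print(dist.get_rank(), dist.get_world_size())
dist.destroy_process_group()
```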