drafting ray integration
Signed-off-by: Kourosh Hakhamaneshi <kourosh@anyscale.com> Former-commit-id: 19c12ddae9350f6e25a270fe3372f5b9094cf960
This commit is contained in:
committed by
hiyouga
parent
5ccc607222
commit
8683582300
0
src/llamafactory/integrations/__init__.py
Normal file
0
src/llamafactory/integrations/__init__.py
Normal file
0
src/llamafactory/integrations/ray/__init__.py
Normal file
0
src/llamafactory/integrations/ray/__init__.py
Normal file
28
src/llamafactory/integrations/ray/ray_train.py
Normal file
28
src/llamafactory/integrations/ray/ray_train.py
Normal file
@@ -0,0 +1,28 @@
|
||||
|
||||
from typing import Any, Callable, Dict
|
||||
|
||||
from ray.train.torch import TorchTrainer
|
||||
from ray.train import ScalingConfig
|
||||
|
||||
from .ray_train_args import RayTrainArguments
|
||||
|
||||
|
||||
def get_ray_trainer(
    training_function: Callable,
    train_loop_config: Dict[str, Any],
    ray_args: RayTrainArguments,
) -> TorchTrainer:
    r"""
    Build a Ray ``TorchTrainer`` that runs ``training_function`` on a Ray cluster.

    Args:
        training_function: The per-worker training loop Ray executes on each worker.
        train_loop_config: Configuration dict forwarded to ``training_function``.
        ray_args: Ray-specific settings (worker count, per-worker resources,
            placement strategy).

    Returns:
        A configured ``TorchTrainer`` instance.

    Raises:
        ValueError: If Ray is not enabled via the ``USE_RAY`` environment variable.
    """
    if not ray_args.use_ray:
        raise ValueError("Ray is not enabled. Please set USE_RAY=1 in your environment.")

    # Only request GPUs when the per-worker resource spec actually asks for them.
    # The previous hard-coded ``use_gpu=True`` contradicted a CPU-only
    # ``resources_per_worker`` and would fail on GPU-less clusters.
    resources = ray_args.resources_per_worker or {}
    use_gpu = bool(resources.get("GPU", 0))

    trainer = TorchTrainer(
        training_function,
        train_loop_config=train_loop_config,
        scaling_config=ScalingConfig(
            num_workers=ray_args.num_workers,
            resources_per_worker=ray_args.resources_per_worker,
            # Fix: ``placement_strategy`` is declared on RayTrainArguments but was
            # silently dropped here, so the user's setting had no effect.
            placement_strategy=ray_args.placement_strategy,
            use_gpu=use_gpu,
        ),
    )
    return trainer
|
||||
22
src/llamafactory/integrations/ray/ray_train_args.py
Normal file
22
src/llamafactory/integrations/ray/ray_train_args.py
Normal file
@@ -0,0 +1,22 @@
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, Literal, Optional
|
||||
|
||||
from .ray_utils import should_use_ray
|
||||
|
||||
@dataclass
class RayTrainArguments:
    r"""
    Arguments pertaining to the Ray training.

    Note: the previous ``Optional[...]`` annotations were misleading — every
    field has a concrete default and ``None`` is never a valid value, so the
    annotations now state the actual types.
    """

    # Resources allocated to each Ray worker, e.g. {"GPU": 1}.
    resources_per_worker: Dict[str, Any] = field(
        default_factory=lambda: {"GPU": 1},
        metadata={"help": "The resources per worker for Ray training. Default is to use 1 GPU per worker."},
    )
    # Number of Ray workers participating in training.
    num_workers: int = field(
        default=1,
        metadata={"help": "The number of workers for Ray training. Default is 1 worker."},
    )
    # How Ray places workers across cluster nodes.
    placement_strategy: Literal["SPREAD", "PACK", "STRICT_SPREAD", "STRICT_PACK"] = field(
        default="PACK",
        metadata={"help": "The placement strategy for Ray training. Default is PACK."},
    )

    @property
    def use_ray(self) -> bool:
        """
        Whether Ray is enabled, derived solely from the ``USE_RAY`` environment
        variable check. Exposed as a read-only property so it cannot be set
        manually.
        """
        return should_use_ray()
|
||||
|
||||
9
src/llamafactory/integrations/ray/ray_utils.py
Normal file
9
src/llamafactory/integrations/ray/ray_utils.py
Normal file
@@ -0,0 +1,9 @@
|
||||
|
||||
import os
|
||||
|
||||
|
||||
def should_use_ray() -> bool:
    """Return ``True`` when the ``USE_RAY`` environment variable enables Ray.

    Accepted truthy values (case-insensitive): ``"1"`` and ``"true"``.
    Anything else — including an unset variable — disables Ray.
    """
    flag = os.environ.get("USE_RAY", "0")
    return flag.lower() in ("1", "true")
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user