run style check

Former-commit-id: 5ec33baf5f95df9fa2afe5523c825d3eda8a076b
This commit is contained in:
Eric Tang
2025-01-06 23:55:56 +00:00
committed by hiyouga
parent 8683582300
commit 4f31ad997c
7 changed files with 54 additions and 35 deletions

View File

@@ -1,21 +1,19 @@
from typing import Any, Callable, Dict
from ray.train.torch import TorchTrainer
from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer
from .ray_train_args import RayTrainArguments
def get_ray_trainer(
training_function: Callable,
train_loop_config: Dict[str, Any],
ray_args: RayTrainArguments,
) -> TorchTrainer:
if not ray_args.use_ray:
raise ValueError("Ray is not enabled. Please set USE_RAY=1 in your environment.")
trainer = TorchTrainer(
training_function,
train_loop_config=train_loop_config,
@@ -25,4 +23,4 @@ def get_ray_trainer(
use_gpu=True,
),
)
return trainer
return trainer

View File

@@ -3,14 +3,23 @@ from typing import Any, Dict, Literal, Optional
from .ray_utils import should_use_ray
@dataclass
class RayTrainArguments:
r"""
Arguments pertaining to the Ray training.
"""
resources_per_worker: Optional[Dict[str, Any]] = field(default_factory=lambda: {"GPU": 1}, metadata={"help": "The resources per worker for Ray training. Default is to use 1 GPU per worker."})
num_workers: Optional[int] = field(default=1, metadata={"help": "The number of workers for Ray training. Default is 1 worker."})
placement_strategy: Optional[Literal["SPREAD", "PACK", "STRICT_SPREAD", "STRICT_PACK"]] = field(default="PACK", metadata={"help": "The placement strategy for Ray training. Default is PACK."})
resources_per_worker: Optional[Dict[str, Any]] = field(
default_factory=lambda: {"GPU": 1},
metadata={"help": "The resources per worker for Ray training. Default is to use 1 GPU per worker."},
)
num_workers: Optional[int] = field(
default=1, metadata={"help": "The number of workers for Ray training. Default is 1 worker."}
)
placement_strategy: Optional[Literal["SPREAD", "PACK", "STRICT_SPREAD", "STRICT_PACK"]] = field(
default="PACK", metadata={"help": "The placement strategy for Ray training. Default is PACK."}
)
@property
def use_ray(self) -> bool:
@@ -19,4 +28,3 @@ class RayTrainArguments:
This prevents manual setting of use_ray.
"""
return should_use_ray()

View File

@@ -1,9 +1,5 @@
import os
def should_use_ray():
return os.getenv("USE_RAY", "0").lower() in ["true", "1"]