run style check
Former-commit-id: 5ec33baf5f95df9fa2afe5523c825d3eda8a076b
This commit is contained in:
@@ -1,21 +1,19 @@
|
||||
|
||||
from typing import Any, Callable, Dict
|
||||
|
||||
from ray.train.torch import TorchTrainer
|
||||
from ray.train import ScalingConfig
|
||||
from ray.train.torch import TorchTrainer
|
||||
|
||||
from .ray_train_args import RayTrainArguments
|
||||
|
||||
|
||||
|
||||
def get_ray_trainer(
|
||||
training_function: Callable,
|
||||
train_loop_config: Dict[str, Any],
|
||||
ray_args: RayTrainArguments,
|
||||
) -> TorchTrainer:
|
||||
|
||||
if not ray_args.use_ray:
|
||||
raise ValueError("Ray is not enabled. Please set USE_RAY=1 in your environment.")
|
||||
|
||||
|
||||
trainer = TorchTrainer(
|
||||
training_function,
|
||||
train_loop_config=train_loop_config,
|
||||
@@ -25,4 +23,4 @@ def get_ray_trainer(
|
||||
use_gpu=True,
|
||||
),
|
||||
)
|
||||
return trainer
|
||||
return trainer
|
||||
|
||||
Reference in New Issue
Block a user