[feature] support using ray.remote to start distributed training. (#10109)

Author: xvxuopop
Date: 2026-01-28 16:05:29 +08:00
Committed by: GitHub
Parent: 9640f79ae5
Commit: 762b480131
4 changed files with 221 additions and 80 deletions
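For context, the pattern this feature enables is dispatching a training entrypoint as a Ray task. Below is a minimal sketch of that pattern, assuming a reachable Ray cluster and a hypothetical `run_training` entrypoint; it is illustrative only and is not the code added by this commit.

    import ray

    ray.init()  # connect to a running cluster, or start a local instance

    @ray.remote(num_gpus=1)  # reserve one GPU per task (assumes GPUs are available)
    def run_training(rank: int) -> int:
        # hypothetical stand-in for the real training entrypoint
        print(f"worker {rank} starting training")
        return rank

    # dispatch one task per worker, then block until all of them finish
    futures = [run_training.remote(rank) for rank in range(4)]
    ray.get(futures)

Wrapping the entrypoint in `ray.remote` lets the Ray scheduler place each worker on a node with the requested resources, instead of the launcher managing processes itself.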

@@ -157,6 +157,33 @@ def get_current_device() -> "torch.device":
        return torch.device(device)


def get_device_name() -> str:
    r"""Get the name of available devices."""
    if is_torch_xpu_available():
        device = "xpu"
    elif is_torch_npu_available():
        device = "npu"
    elif is_torch_mps_available():
        device = "mps"
    elif is_torch_cuda_available():
        device = "gpu"
    else:
        device = "cpu"

    return device


def get_torch_device():
    r"""Get the torch device namespace for the available devices."""
    device_name = get_device_name()
    device_name = "cuda" if device_name == "gpu" else device_name
    try:
        return getattr(torch, device_name)
    except AttributeError:
        logger.warning_rank0(f"Device namespace '{device_name}' not found in torch, falling back to torch.cuda.")
        return torch.cuda


def get_device_count() -> int:
    r"""Get the number of available devices."""
    if is_torch_xpu_available():