mirror of
https://github.com/hiyouga/LlamaFactory.git
synced 2026-02-04 21:23:09 +00:00
[feature] support using ray.remote to start distributed training. (#10109)
This commit is contained in:
@@ -157,6 +157,33 @@ def get_current_device() -> "torch.device":
|
||||
return torch.device(device)
|
||||
|
||||
|
||||
def get_device_name() -> str:
    r"""Get the name of available devices.

    Returns the short name of the first available accelerator backend,
    checked in priority order (xpu > npu > mps > cuda), falling back to
    ``"cpu"`` when none is present. Note that CUDA is reported as ``"gpu"``
    here, not ``"cuda"``.
    """
    # Priority-ordered table of (availability probe, reported name);
    # the first probe that succeeds determines the device name.
    backend_probes = (
        (is_torch_xpu_available, "xpu"),
        (is_torch_npu_available, "npu"),
        (is_torch_mps_available, "mps"),
        (is_torch_cuda_available, "gpu"),
    )
    for probe, name in backend_probes:
        if probe():
            return name

    return "cpu"
|
||||
|
||||
|
||||
def get_torch_device():
    r"""Get the torch device namespace for the available devices.

    Resolves the current device name via ``get_device_name()`` and returns
    the matching ``torch.<name>`` namespace (e.g. ``torch.cuda``,
    ``torch.npu``). Falls back to ``torch.cuda`` with a warning when torch
    exposes no namespace under that name.
    """
    device_name = get_device_name()
    # get_device_name() reports CUDA as "gpu", but torch's namespace is torch.cuda.
    if device_name == "gpu":
        device_name = "cuda"

    namespace = getattr(torch, device_name, None)
    if namespace is None:
        logger.warning_rank0(f"Device namespace '{device_name}' not found in torch, try to load torch.cuda.")
        return torch.cuda

    return namespace
|
||||
|
||||
|
||||
def get_device_count() -> int:
|
||||
r"""Get the number of available devices."""
|
||||
if is_torch_xpu_available():
|
||||
|
||||
Reference in New Issue
Block a user