[misc] fix env vars (#7715)

This commit is contained in:
hoshi-hiyouga
2025-04-14 16:04:04 +08:00
committed by GitHub
parent 7c61b35106
commit 3f91a95250
3 changed files with 5 additions and 3 deletions

View File

@@ -177,10 +177,10 @@ def get_peak_memory() -> tuple[int, int]:
r"""Get the peak memory usage for the current device (in Bytes)."""
if is_torch_npu_available():
return torch.npu.max_memory_allocated(), torch.npu.max_memory_reserved()
elif is_torch_cuda_available():
return torch.cuda.max_memory_allocated(), torch.cuda.max_memory_reserved()
elif is_torch_xpu_available():
return torch.xpu.max_memory_allocated(), torch.xpu.max_memory_reserved()
elif is_torch_cuda_available():
return torch.cuda.max_memory_allocated(), torch.cuda.max_memory_reserved()
else:
return 0, 0