[misc] fix env vars (#7715)
This commit is contained in:
@@ -177,10 +177,10 @@ def get_peak_memory() -> tuple[int, int]:
|
||||
r"""Get the peak memory usage for the current device (in Bytes)."""
|
||||
if is_torch_npu_available():
|
||||
return torch.npu.max_memory_allocated(), torch.npu.max_memory_reserved()
|
||||
elif is_torch_cuda_available():
|
||||
return torch.cuda.max_memory_allocated(), torch.cuda.max_memory_reserved()
|
||||
elif is_torch_xpu_available():
|
||||
return torch.xpu.max_memory_allocated(), torch.xpu.max_memory_reserved()
|
||||
elif is_torch_cuda_available():
|
||||
return torch.cuda.max_memory_allocated(), torch.cuda.max_memory_reserved()
|
||||
else:
|
||||
return 0, 0
|
||||
|
||||
|
||||
Reference in New Issue
Block a user