fix torch gc
Former-commit-id: e173799d057598e5692a407601c30d8ce1513461
This commit is contained in:
@@ -212,12 +212,17 @@ def has_tokenized_data(path: os.PathLike) -> bool:
|
||||
|
||||
def torch_gc() -> None:
|
||||
r"""
|
||||
Collects GPU memory.
|
||||
Collects GPU or NPU memory.
|
||||
"""
|
||||
gc.collect()
|
||||
if torch.cuda.is_available():
|
||||
if is_torch_xpu_available():
|
||||
torch.xpu.empty_cache()
|
||||
elif is_torch_npu_available():
|
||||
torch.npu.empty_cache()
|
||||
elif is_torch_mps_available():
|
||||
torch.mps.empty_cache()
|
||||
elif is_torch_cuda_available():
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.ipc_collect()
|
||||
|
||||
|
||||
def try_download_model_from_ms(model_args: "ModelArguments") -> str:
|
||||
|
||||
Reference in New Issue
Block a user