fix torch gc

Former-commit-id: e173799d057598e5692a407601c30d8ce1513461
This commit is contained in:
hiyouga
2024-06-06 20:30:25 +08:00
parent ca95e98ca0
commit 80f716bc10
3 changed files with 14 additions and 8 deletions

View File

@@ -212,12 +212,17 @@ def has_tokenized_data(path: os.PathLike) -> bool:
def torch_gc() -> None:
r"""
Collects GPU memory.
Collects GPU or NPU memory.
"""
gc.collect()
if torch.cuda.is_available():
if is_torch_xpu_available():
torch.xpu.empty_cache()
elif is_torch_npu_available():
torch.npu.empty_cache()
elif is_torch_mps_available():
torch.mps.empty_cache()
elif is_torch_cuda_available():
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
def try_download_model_from_ms(model_args: "ModelArguments") -> str: