add Baichuan2 models

Former-commit-id: 36960025e9274b574f57e7a7bf453cd96956e922
This commit is contained in:
hiyouga
2023-09-06 18:36:04 +08:00
parent b91fc1f5b3
commit 218f36bca5
3 changed files with 11 additions and 2 deletions

View File

@@ -1,3 +1,4 @@
import gc
import torch
from typing import TYPE_CHECKING, List, Optional, Tuple
from transformers import InfNanRemoveLogitsProcessor, LogitsProcessorList
@@ -98,6 +99,7 @@ def torch_gc() -> None:
r"""
Collects GPU memory.
"""
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.ipc_collect()