add tests

Former-commit-id: 484634ee9c982e82e919ff67d507e0210345182d
This commit is contained in:
hiyouga
2024-06-15 19:51:20 +08:00
parent 308abfec6c
commit 7f90b0cd20
8 changed files with 166 additions and 14 deletions

View File

@@ -22,6 +22,7 @@ from transformers import InfNanRemoveLogitsProcessor, LogitsProcessorList, PreTr
from transformers.utils import (
SAFE_WEIGHTS_NAME,
WEIGHTS_NAME,
is_safetensors_available,
is_torch_bf16_gpu_available,
is_torch_cuda_available,
is_torch_mps_available,
@@ -34,6 +35,11 @@ from .constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
from .logging import get_logger
if is_safetensors_available():
from safetensors import safe_open
from safetensors.torch import save_file
_is_fp16_available = is_torch_npu_available() or is_torch_cuda_available()
try:
_is_bf16_available = is_torch_bf16_gpu_available()
@@ -128,9 +134,6 @@ def fix_valuehead_checkpoint(
return
if safe_serialization:
from safetensors import safe_open
from safetensors.torch import save_file
path_to_checkpoint = os.path.join(output_dir, SAFE_WEIGHTS_NAME)
with safe_open(path_to_checkpoint, framework="pt", device="cpu") as f:
state_dict: Dict[str, torch.Tensor] = {key: f.get_tensor(key) for key in f.keys()}