add tests
Former-commit-id: 484634ee9c982e82e919ff67d507e0210345182d
This commit is contained in:
@@ -22,6 +22,7 @@ from transformers import InfNanRemoveLogitsProcessor, LogitsProcessorList, PreTr
|
||||
from transformers.utils import (
|
||||
SAFE_WEIGHTS_NAME,
|
||||
WEIGHTS_NAME,
|
||||
is_safetensors_available,
|
||||
is_torch_bf16_gpu_available,
|
||||
is_torch_cuda_available,
|
||||
is_torch_mps_available,
|
||||
@@ -34,6 +35,11 @@ from .constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
|
||||
from .logging import get_logger
|
||||
|
||||
|
||||
if is_safetensors_available():
|
||||
from safetensors import safe_open
|
||||
from safetensors.torch import save_file
|
||||
|
||||
|
||||
_is_fp16_available = is_torch_npu_available() or is_torch_cuda_available()
|
||||
try:
|
||||
_is_bf16_available = is_torch_bf16_gpu_available()
|
||||
@@ -128,9 +134,6 @@ def fix_valuehead_checkpoint(
|
||||
return
|
||||
|
||||
if safe_serialization:
|
||||
from safetensors import safe_open
|
||||
from safetensors.torch import save_file
|
||||
|
||||
path_to_checkpoint = os.path.join(output_dir, SAFE_WEIGHTS_NAME)
|
||||
with safe_open(path_to_checkpoint, framework="pt", device="cpu") as f:
|
||||
state_dict: Dict[str, torch.Tensor] = {key: f.get_tensor(key) for key in f.keys()}
|
||||
|
||||
Reference in New Issue
Block a user