fix evaluator and cached_file in 4.31.0

Former-commit-id: 970897da402f604220d45084d492de4dab809ba4
This commit is contained in:
hiyouga
2023-11-18 19:39:23 +08:00
parent 775ce62950
commit 3d1ee27ccd
2 changed files with 37 additions and 12 deletions

View File

@@ -1,4 +1,5 @@
import torch
import inspect
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple
from transformers.utils import cached_file
@@ -94,19 +95,35 @@ def load_valuehead_params(
"""
kwargs = {
"path_or_repo_id": path_or_repo_id,
"cache_dir": model_args.cache_dir,
"token": model_args.hf_hub_token
"cache_dir": model_args.cache_dir
}
if "token" in inspect.signature(cached_file).parameters:
kwargs["token"] = model_args.hf_hub_token
elif "use_auth_token" in inspect.signature(cached_file).parameters: # for transformers==4.31.0
kwargs["use_auth_token"] = model_args.hf_hub_token
else:
logger.warning("Ignore `hf_hub_token` since matched parameter is not found.")
try:
vhead_file = cached_file(filename=WEIGHTS_NAME, **kwargs)
except:
try:
vhead_file = cached_file(filename=SAFE_WEIGHTS_NAME, **kwargs)
except:
logger.warning("Provided path ({}) does not contain valuehead weights.".format(path_or_repo_id))
return None
return torch.load(vhead_file, map_location="cpu")
except Exception as err:
logger.info("Failed to load {}: {}".format(WEIGHTS_NAME, str(err)))
return torch.load(vhead_file, map_location="cpu")
try:
from safetensors import safe_open
vhead_file = cached_file(filename=SAFE_WEIGHTS_NAME, **kwargs)
with safe_open(vhead_file, framework="pt", device="cpu") as f:
return {
"v_head.summary.weight": f.get_tensor("v_head.summary.weight"),
"v_head.summary.bias": f.get_tensor("v_head.summary.bias")
}
except Exception as err:
logger.info("Failed to load {}: {}".format(SAFE_WEIGHTS_NAME, str(err)))
logger.warning("Provided path ({}) does not contain valuehead weights.".format(path_or_repo_id))
return None
def prepare_model_for_training(