@@ -85,34 +85,21 @@ def get_modelcard_args(
|
||||
}
|
||||
|
||||
|
||||
def load_valuehead_params(model_args: "ModelArguments") -> Dict[str, torch.Tensor]:
|
||||
def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") -> Dict[str, torch.Tensor]:
|
||||
r"""
|
||||
Loads value head parameters from Hugging Face Hub or local disk.
|
||||
|
||||
Returns: dict with keys `v_head.summary.weight` and `v_head.summary.bias`.
|
||||
"""
|
||||
if model_args.adapter_name_or_path is not None:
|
||||
path_or_repo_id = model_args.adapter_name_or_path[-1]
|
||||
else:
|
||||
path_or_repo_id = model_args.model_name_or_path
|
||||
|
||||
kwargs = {
|
||||
"path_or_repo_id": path_or_repo_id,
|
||||
"cache_dir": model_args.cache_dir,
|
||||
"token": model_args.hf_hub_token
|
||||
}
|
||||
|
||||
try:
|
||||
vhead_file = cached_file(filename=WEIGHTS_NAME, **kwargs)
|
||||
logger.info("Loaded valuehead from {}".format(path_or_repo_id))
|
||||
return torch.load(vhead_file, map_location="cpu")
|
||||
except Exception as err:
|
||||
logger.info("Failed to load {}: {}".format(WEIGHTS_NAME, str(err)))
|
||||
|
||||
try:
|
||||
from safetensors import safe_open
|
||||
vhead_file = cached_file(filename=SAFE_WEIGHTS_NAME, **kwargs)
|
||||
logger.info("Loaded valuehead from {}".format(path_or_repo_id))
|
||||
with safe_open(vhead_file, framework="pt", device="cpu") as f:
|
||||
return {
|
||||
"v_head.summary.weight": f.get_tensor("v_head.summary.weight"),
|
||||
@@ -121,6 +108,12 @@ def load_valuehead_params(model_args: "ModelArguments") -> Dict[str, torch.Tenso
|
||||
except Exception as err:
|
||||
logger.info("Failed to load {}: {}".format(SAFE_WEIGHTS_NAME, str(err)))
|
||||
|
||||
try:
|
||||
vhead_file = cached_file(filename=WEIGHTS_NAME, **kwargs)
|
||||
return torch.load(vhead_file, map_location="cpu")
|
||||
except Exception as err:
|
||||
logger.info("Failed to load {}: {}".format(WEIGHTS_NAME, str(err)))
|
||||
|
||||
logger.warning("Provided path ({}) does not contain valuehead weights.".format(path_or_repo_id))
|
||||
return None
|
||||
|
||||
|
||||
Reference in New Issue
Block a user