refactor adapter hparam

Former-commit-id: f82aece9ebd6df83a7a005cc7cbbcec07fa6e14d
Author: hiyouga
Date: 2023-12-15 20:53:11 +08:00
Parent: 27ef5b1aa7
Commit: f902b0d420
21 changed files with 302 additions and 311 deletions


@@ -1,5 +1,4 @@
 import torch
-import inspect
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple
 from transformers.utils import cached_file
@@ -86,29 +85,26 @@ def get_modelcard_args(
     }


-def load_valuehead_params(
-    path_or_repo_id: str,
-    model_args: "ModelArguments"
-) -> Dict[str, torch.Tensor]:
+def load_valuehead_params(model_args: "ModelArguments") -> Dict[str, torch.Tensor]:
     r"""
     Loads value head parameters from Hugging Face Hub or local disk.

     Returns: dict with keys `v_head.summary.weight` and `v_head.summary.bias`.
     """
+    if model_args.adapter_name_or_path is not None:
+        path_or_repo_id = model_args.adapter_name_or_path[-1]
+    else:
+        path_or_repo_id = model_args.model_name_or_path
+
     kwargs = {
         "path_or_repo_id": path_or_repo_id,
-        "cache_dir": model_args.cache_dir
+        "cache_dir": model_args.cache_dir,
+        "token": model_args.hf_hub_token
     }

-    if "token" in inspect.signature(cached_file).parameters:
-        kwargs["token"] = model_args.hf_hub_token
-    elif "use_auth_token" in inspect.signature(cached_file).parameters: # for transformers==4.31.0
-        kwargs["use_auth_token"] = model_args.hf_hub_token
-    else:
-        logger.warning("Ignore `hf_hub_token` since matched parameter is not found.")
-
     try:
         vhead_file = cached_file(filename=WEIGHTS_NAME, **kwargs)
+        logger.info("Loaded valuehead from {}".format(path_or_repo_id))
         return torch.load(vhead_file, map_location="cpu")
     except Exception as err:
         logger.info("Failed to load {}: {}".format(WEIGHTS_NAME, str(err)))
@@ -116,6 +112,7 @@ def load_valuehead_params(
     try:
         from safetensors import safe_open
         vhead_file = cached_file(filename=SAFE_WEIGHTS_NAME, **kwargs)
+        logger.info("Loaded valuehead from {}".format(path_or_repo_id))
         with safe_open(vhead_file, framework="pt", device="cpu") as f:
             return {
                 "v_head.summary.weight": f.get_tensor("v_head.summary.weight"),