@@ -3,8 +3,8 @@ import inspect
|
||||
from typing import TYPE_CHECKING, Any, Dict, List
|
||||
from transformers import PreTrainedModel
|
||||
from transformers.utils import cached_file
|
||||
from transformers.trainer import WEIGHTS_NAME, SAFE_WEIGHTS_NAME
|
||||
|
||||
from llmtuner.extras.constants import V_HEAD_WEIGHTS_NAME, V_HEAD_SAFE_WEIGHTS_NAME
|
||||
from llmtuner.extras.logging import get_logger
|
||||
from llmtuner.extras.misc import get_current_device
|
||||
|
||||
@@ -103,22 +103,20 @@ def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") ->
|
||||
|
||||
try:
|
||||
from safetensors import safe_open
|
||||
vhead_file = cached_file(filename=SAFE_WEIGHTS_NAME, **kwargs)
|
||||
vhead_file = cached_file(filename=V_HEAD_SAFE_WEIGHTS_NAME, **kwargs)
|
||||
with safe_open(vhead_file, framework="pt", device="cpu") as f:
|
||||
return {
|
||||
"v_head.summary.weight": f.get_tensor("v_head.summary.weight"),
|
||||
"v_head.summary.bias": f.get_tensor("v_head.summary.bias")
|
||||
}
|
||||
return {key: f.get_tensor(key) for key in f.keys()}
|
||||
except Exception as err:
|
||||
logger.info("Failed to load {}: {}".format(SAFE_WEIGHTS_NAME, str(err)))
|
||||
logger.info("Failed to load {}: {}".format(V_HEAD_SAFE_WEIGHTS_NAME, str(err)))
|
||||
|
||||
try:
|
||||
vhead_file = cached_file(filename=WEIGHTS_NAME, **kwargs)
|
||||
vhead_file = cached_file(filename=V_HEAD_WEIGHTS_NAME, **kwargs)
|
||||
return torch.load(vhead_file, map_location="cpu")
|
||||
except Exception as err:
|
||||
logger.info("Failed to load {}: {}".format(WEIGHTS_NAME, str(err)))
|
||||
logger.info("Failed to load {}: {}".format(V_HEAD_WEIGHTS_NAME, str(err)))
|
||||
|
||||
logger.warning("Provided path ({}) does not contain valuehead weights.".format(path_or_repo_id))
|
||||
logger.info("Provided path ({}) does not contain value head weights.".format(path_or_repo_id))
|
||||
logger.info("Ignore these messages if you are not resuming the training of a value head model.")
|
||||
return None
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user