Former-commit-id: efbb32afdcf0d6aa4ca26f54c95f76dbb84f77dc
This commit is contained in:
hiyouga
2023-12-16 20:50:45 +08:00
parent f927601702
commit 790a31404a
3 changed files with 22 additions and 17 deletions

View File

@@ -81,9 +81,16 @@ def load_model_and_tokenizer(
if add_valuehead:
model: "AutoModelForCausalLMWithValueHead" = AutoModelForCausalLMWithValueHead.from_pretrained(model)
patcher.patch_valuehead_model(model)
vhead_params = load_valuehead_params(model_args)
if model_args.adapter_name_or_path is not None:
vhead_path = model_args.adapter_name_or_path[-1]
else:
vhead_path = model_args.model_name_or_path
vhead_params = load_valuehead_params(vhead_path, model_args)
if vhead_params is not None:
model.load_state_dict(vhead_params, strict=False)
logger.info("Loaded valuehead from checkpoint: {}".format(vhead_path))
if not is_trainable:
model.requires_grad_(False) # fix all model params