fix loading valuehead

Former-commit-id: 7872375d7a0c1d8826206631f6717a91ec49f1b3
commit 6f655e3916 (parent 6828f07d54)
Author: hiyouga
Date: 2023-06-13 11:13:06 +08:00
2 changed files with 18 additions and 11 deletions


@@ -94,7 +94,7 @@ def _init_adapter(
     if model_args.checkpoint_dir is not None:
         if finetuning_args.finetuning_type != "lora":
             assert is_mergeable and len(model_args.checkpoint_dir) == 1, "Only LoRA tuning accepts multiple checkpoints."
-            load_trainable_params(model, model_args.checkpoint_dir[0]) # load model checkpoints for non-peft methods
+            assert load_trainable_params(model, model_args.checkpoint_dir[0]), "Model checkpoint is not correctly loaded."
         else:
             assert is_mergeable or len(model_args.checkpoint_dir) == 1, "Quantized model only accepts a single checkpoint."
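The non-LoRA branch now asserts on the loader's return value instead of ignoring it, which implies load_trainable_params returns a success flag. A minimal sketch of that contract, assuming a single-file checkpoint layout (the helper presumably lives in the other changed file, which is not shown in this view; the trainable_params.bin filename and the strict=False choice are assumptions):

import os
import torch

TRAINABLE_PARAMS_FILE_NAME = "trainable_params.bin"  # assumed checkpoint filename

def load_trainable_params(model: torch.nn.Module, checkpoint_dir: str) -> bool:
    weights_file = os.path.join(checkpoint_dir, TRAINABLE_PARAMS_FILE_NAME)
    if not os.path.exists(weights_file):
        return False  # let the caller decide whether a missing file is fatal
    model_state_dict = torch.load(weights_file, map_location="cpu")
    model.load_state_dict(model_state_dict, strict=False)  # restore only the trainable subset
    return True

Returning a bool rather than raising keeps the happy path quiet and lets each call site attach its own error message via assert, as this commit does.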
@@ -217,18 +217,19 @@ def load_pretrained(
         model = AutoModelForCausalLMWithValueHead.from_pretrained(model)
 
         if stage == "rm" and model_args.checkpoint_dir is not None: # load valuehead weights to evaluate reward model
-            load_valuehead_params(model, model_args.checkpoint_dir[0])
-            model.v_head.load_state_dict({
-                "summary.weight": getattr(model, "reward_head_weight"),
-                "summary.bias": getattr(model, "reward_head_bias")
-            })
+            logger.warning("Only the last checkpoint containing valuehead will be loaded as the valuehead.")
+            if load_valuehead_params(model, model_args.checkpoint_dir[-1]):
+                model.v_head.load_state_dict({
+                    "summary.weight": getattr(model, "reward_head_weight"),
+                    "summary.bias": getattr(model, "reward_head_bias")
+                })
 
         if stage == "ppo": # load reward model
             assert is_trainable, "PPO stage cannot be performed at evaluation."
             assert model_args.reward_model is not None, "Reward model is necessary for PPO training."
             logger.info("Load reward model from {}".format(model_args.reward_model))
             model.pretrained_model.load_adapter(model_args.reward_model, "reward", is_trainable=False)
-            load_valuehead_params(model, model_args.reward_model)
+            assert load_valuehead_params(model, model_args.reward_model), "Reward model is not correctly loaded."
 
     if not is_trainable:
         model.requires_grad_(False) # fix all model params
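Both new call sites treat load_valuehead_params as a predicate: the "rm" branch silently skips v_head restoration when the checkpoint carries no value head, while the "ppo" branch makes a missing value head fatal. A sketch of the contract inferred from those call sites; the attribute names reward_head_weight and reward_head_bias come from the diff, while the value_head.bin filename and the use of register_buffer are assumptions:

import os
import torch

VALUE_HEAD_FILE_NAME = "value_head.bin"  # assumed checkpoint filename

def load_valuehead_params(model: torch.nn.Module, checkpoint_dir: str) -> bool:
    valuehead_file = os.path.join(checkpoint_dir, VALUE_HEAD_FILE_NAME)
    if not os.path.exists(valuehead_file):
        return False  # "rm" skips v_head loading, "ppo" fails fast via assert
    valuehead_state_dict = torch.load(valuehead_file, map_location="cpu")
    # Stash the tensors on the model so call sites can fetch them with getattr(...)
    # and feed them into model.v_head.load_state_dict, as the diff above does.
    model.register_buffer("reward_head_weight", valuehead_state_dict["summary.weight"])
    model.register_buffer("reward_head_bias", valuehead_state_dict["summary.bias"])
    return True

Note the switch from checkpoint_dir[0] to checkpoint_dir[-1]: when multiple LoRA checkpoints are stacked, only the last one is expected to carry value-head weights, which is exactly what the new warning message tells the user.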