fix bug in PPO training

Former-commit-id: 2e99f0e53ce6de0acbcab85dd50aef874e8c6336
hiyouga
2023-11-16 02:32:54 +08:00
parent f81a8a5e5c
commit e017266b98
3 changed files with 7 additions and 4 deletions


@@ -57,7 +57,7 @@ def create_reward_model(
         for name, param in model.named_parameters(): # https://github.com/huggingface/peft/issues/1090
             if "default" in name:
                 param.data = param.data.to(torch.float32) # trainable params should in fp32
-        vhead_params = load_valuehead_params(model_args.checkpoint_dir[-1], model_args)
+        vhead_params = load_valuehead_params(finetuning_args.reward_model, model_args)
         assert vhead_params is not None, "Reward model is not correctly loaded."
         model.register_buffer("reward_head_weight", vhead_params["v_head.summary.weight"], persistent=False)
         model.register_buffer("reward_head_bias", vhead_params["v_head.summary.bias"], persistent=False)