fix rlhf callback

Former-commit-id: f5485452d660caef56474cb7dc37abbe4f34599e
This commit is contained in:
hiyouga
2023-11-16 03:26:19 +08:00
parent e017266b98
commit de3a84ac59
4 changed files with 19 additions and 12 deletions

View File

@@ -203,12 +203,13 @@ def load_model_and_tokenizer(
if stage in ["rm", "ppo"]:
model: "AutoModelForCausalLMWithValueHead" = AutoModelForCausalLMWithValueHead.from_pretrained(model)
reset_logging()
if model_args.checkpoint_dir is not None: # load valuehead weights if exists
logger.warning("Only the last checkpoint containing valuehead will be loaded.")
vhead_params = load_valuehead_params(model_args.checkpoint_dir[-1], model_args)
if vhead_params is not None:
model.load_state_dict(vhead_params, strict=False)
logger.info("Loaded valuehead from checkpoint: {}".format(model_args.checkpoint_dir[-1]))
vhead_path = (
model_args.checkpoint_dir[-1] if model_args.checkpoint_dir is not None else model_args.model_name_or_path
)
vhead_params = load_valuehead_params(vhead_path, model_args)
if vhead_params is not None:
model.load_state_dict(vhead_params, strict=False)
logger.info("Loaded valuehead from checkpoint: {}".format(vhead_path))
# Prepare model for inference
if not is_trainable: