fix rlhf callback
Former-commit-id: f5485452d660caef56474cb7dc37abbe4f34599e
This commit is contained in:
@@ -203,12 +203,13 @@ def load_model_and_tokenizer(
|
||||
if stage in ["rm", "ppo"]:
|
||||
model: "AutoModelForCausalLMWithValueHead" = AutoModelForCausalLMWithValueHead.from_pretrained(model)
|
||||
reset_logging()
|
||||
if model_args.checkpoint_dir is not None: # load valuehead weights if exists
|
||||
logger.warning("Only the last checkpoint containing valuehead will be loaded.")
|
||||
vhead_params = load_valuehead_params(model_args.checkpoint_dir[-1], model_args)
|
||||
if vhead_params is not None:
|
||||
model.load_state_dict(vhead_params, strict=False)
|
||||
logger.info("Loaded valuehead from checkpoint: {}".format(model_args.checkpoint_dir[-1]))
|
||||
vhead_path = (
|
||||
model_args.checkpoint_dir[-1] if model_args.checkpoint_dir is not None else model_args.model_name_or_path
|
||||
)
|
||||
vhead_params = load_valuehead_params(vhead_path, model_args)
|
||||
if vhead_params is not None:
|
||||
model.load_state_dict(vhead_params, strict=False)
|
||||
logger.info("Loaded valuehead from checkpoint: {}".format(vhead_path))
|
||||
|
||||
# Prepare model for inference
|
||||
if not is_trainable:
|
||||
|
||||
Reference in New Issue
Block a user