support lora target auto find
Former-commit-id: bce9984733d88bf013847eed523d1c75fdf0995e
This commit is contained in:
@@ -42,8 +42,8 @@ def load_valuehead_params(model: torch.nn.Module, checkpoint_dir: os.PathLike) -
|
||||
logger.warning("Provided path ({}) does not contain valuehead weights.".format(checkpoint_dir))
|
||||
return False
|
||||
valuehead_state_dict = torch.load(valuehead_file, map_location="cpu")
|
||||
model.register_buffer("reward_head_weight", valuehead_state_dict["summary.weight"])
|
||||
model.register_buffer("reward_head_bias", valuehead_state_dict["summary.bias"])
|
||||
model.register_buffer("default_head_weight", torch.zeros_like(valuehead_state_dict["summary.weight"]))
|
||||
model.register_buffer("default_head_bias", torch.zeros_like(valuehead_state_dict["summary.bias"]))
|
||||
model.register_buffer("reward_head_weight", valuehead_state_dict["summary.weight"], persistent=False)
|
||||
model.register_buffer("reward_head_bias", valuehead_state_dict["summary.bias"], persistent=False)
|
||||
model.register_buffer("default_head_weight", torch.zeros_like(valuehead_state_dict["summary.weight"]), persistent=False)
|
||||
model.register_buffer("default_head_bias", torch.zeros_like(valuehead_state_dict["summary.bias"]), persistent=False)
|
||||
return True
|
||||
|
||||
Reference in New Issue
Block a user