support lora target auto find

Former-commit-id: bce9984733d88bf013847eed523d1c75fdf0995e
This commit is contained in:
hiyouga
2023-09-09 15:38:37 +08:00
parent 50e93392dd
commit 7143c551ab
11 changed files with 117 additions and 72 deletions

View File

@@ -42,8 +42,8 @@ def load_valuehead_params(model: torch.nn.Module, checkpoint_dir: os.PathLike) -
logger.warning("Provided path ({}) does not contain valuehead weights.".format(checkpoint_dir))
return False
valuehead_state_dict = torch.load(valuehead_file, map_location="cpu")
model.register_buffer("reward_head_weight", valuehead_state_dict["summary.weight"])
model.register_buffer("reward_head_bias", valuehead_state_dict["summary.bias"])
model.register_buffer("default_head_weight", torch.zeros_like(valuehead_state_dict["summary.weight"]))
model.register_buffer("default_head_bias", torch.zeros_like(valuehead_state_dict["summary.bias"]))
model.register_buffer("reward_head_weight", valuehead_state_dict["summary.weight"], persistent=False)
model.register_buffer("reward_head_bias", valuehead_state_dict["summary.bias"], persistent=False)
model.register_buffer("default_head_weight", torch.zeros_like(valuehead_state_dict["summary.weight"]), persistent=False)
model.register_buffer("default_head_bias", torch.zeros_like(valuehead_state_dict["summary.bias"]), persistent=False)
return True