update patcher

Former-commit-id: afb365e515d615dd62f791622450debab60ce5cc
This commit is contained in:
hiyouga
2024-06-19 21:27:00 +08:00
parent a7d7f79855
commit 5f5d4c1923
3 changed files with 10 additions and 7 deletions

View File

@@ -152,6 +152,10 @@ def patch_valuehead_model(model: "AutoModelForCausalLMWithValueHead") -> None:
if isinstance(self.pretrained_model, PreTrainedModel):
return self.pretrained_model.get_input_embeddings()
def get_output_embeddings(self: "AutoModelForCausalLMWithValueHead") -> torch.nn.Module:
if isinstance(self.pretrained_model, PreTrainedModel):
return self.pretrained_model.get_output_embeddings()
def create_or_update_model_card(self: "AutoModelForCausalLMWithValueHead", output_dir: str) -> None:
if isinstance(self.pretrained_model, PeftModel):
self.pretrained_model.create_or_update_model_card(output_dir)
@@ -160,4 +164,5 @@ def patch_valuehead_model(model: "AutoModelForCausalLMWithValueHead") -> None:
setattr(model, "_keys_to_ignore_on_save", ignore_modules)
setattr(model, "tie_weights", MethodType(tie_weights, model))
setattr(model, "get_input_embeddings", MethodType(get_input_embeddings, model))
setattr(model, "get_output_embeddings", MethodType(get_output_embeddings, model))
setattr(model, "create_or_update_model_card", MethodType(create_or_update_model_card, model))