Former-commit-id: 412d856eeada2abcea598fac0a8d35ae90cc9c01
Author: hiyouga
Date:   2024-02-06 15:23:08 +08:00
Parent: 0dd68d1e06
Commit: b564b97b7e
2 changed files with 15 additions and 1 deletion


@@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
 import torch
 from datasets import load_dataset
+from peft import PeftModel
 from transformers import BitsAndBytesConfig, GPTQConfig, PreTrainedModel, PreTrainedTokenizerBase
 from transformers.integrations import is_deepspeed_zero3_enabled
 from transformers.utils.versions import require_version
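
For context, here is a minimal sketch of where a PeftModel instance typically comes from; the class imported above is only used for an isinstance check in the next hunk. The base model name and LoRA settings below are illustrative assumptions, not part of this commit.

# Illustrative only: how a PeftModel instance is usually created, so the
# isinstance(self.pretrained_model, PeftModel) check in the next hunk has
# something to match. Model name and LoRA settings are assumptions.
from peft import LoraConfig, PeftModel, get_peft_model
from transformers import AutoModelForCausalLM

base = AutoModelForCausalLM.from_pretrained("gpt2")  # any causal LM works; gpt2 keeps it small
lora_config = LoraConfig(r=8, lora_alpha=16, target_modules=["c_attn"], task_type="CAUSAL_LM")
peft_model = get_peft_model(base, lora_config)

assert isinstance(peft_model, PeftModel)  # True: all PEFT task wrappers subclass PeftModel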
@@ -307,7 +308,12 @@ def patch_valuehead_model(model: "AutoModelForCausalLMWithValueHead") -> None:
         if isinstance(self.pretrained_model, PreTrainedModel):
             return self.pretrained_model.get_input_embeddings()
 
+    def create_or_update_model_card(self: "AutoModelForCausalLMWithValueHead", output_dir: str) -> None:
+        if isinstance(self.pretrained_model, PeftModel):
+            self.pretrained_model.create_or_update_model_card(output_dir)
+
     ignore_modules = [name for name, _ in model.named_parameters() if "pretrained_model" in name]
     setattr(model, "_keys_to_ignore_on_save", ignore_modules)
     setattr(model, "tie_weights", MethodType(tie_weights, model))
     setattr(model, "get_input_embeddings", MethodType(get_input_embeddings, model))
+    setattr(model, "create_or_update_model_card", MethodType(create_or_update_model_card, model))
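
The setattr/MethodType calls above bind plain functions as instance methods on the value-head model. Below is a self-contained sketch of that binding pattern, using a hypothetical stand-in class (FakeValueHeadModel) rather than trl's real AutoModelForCausalLMWithValueHead.

# Minimal sketch of the MethodType monkey-patching pattern used in this hunk.
# FakeValueHeadModel is a hypothetical stand-in, not trl's real class.
from types import MethodType


class FakeValueHeadModel:
    def __init__(self, pretrained_model):
        self.pretrained_model = pretrained_model


def create_or_update_model_card(self, output_dir: str) -> None:
    # Delegate to the wrapped model when it knows how to write a model card.
    if hasattr(self.pretrained_model, "create_or_update_model_card"):
        self.pretrained_model.create_or_update_model_card(output_dir)


model = FakeValueHeadModel(pretrained_model=object())
# MethodType(func, instance) produces a bound method, so later calls look like
# an ordinary model.create_or_update_model_card("outputs/") call.
setattr(model, "create_or_update_model_card", MethodType(create_or_update_model_card, model))
model.create_or_update_model_card("outputs/")  # no-op here: object() has no such method

Patching the instance rather than subclassing keeps the change local to the model objects produced by this code path, so trl's class itself is left untouched.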