fix dispatch

Former-commit-id: deda82638716506dc690902c51276bb1eb0ddd5e
This commit is contained in:
hiyouga
2024-01-03 16:33:16 +08:00
parent 7168392a51
commit 8c74851b70
2 changed files with 7 additions and 4 deletions

View File

@@ -276,7 +276,5 @@ def patch_valuehead_model(model: "AutoModelForCausalLMWithValueHead") -> None:
ignore_modules = [name for name, _ in model.named_parameters() if "pretrained_model" in name]
setattr(model, "_keys_to_ignore_on_save", ignore_modules)
setattr(model, "_no_split_modules", getattr(model.pretrained_model, "_no_split_modules", None))
setattr(model, "dtype", getattr(model.pretrained_model, "dtype", None))
setattr(model, "tie_weights", MethodType(tie_weights, model))
setattr(model, "get_input_embeddings", MethodType(get_input_embeddings, model))