fix ChatGLM lm_head #494

Former-commit-id: bf0048abdaeb2b9592d38ac991704ad014370b47
This commit is contained in:
hiyouga
2023-08-14 14:14:48 +08:00
parent 0bfeed3a7e
commit bceaba551d
3 changed files with 12 additions and 8 deletions

View File

@@ -153,6 +153,10 @@ def load_model_and_tokenizer(
if "GenerationMixin" not in str(model.generate.__func__):
model.generate = MethodType(PreTrainedModel.generate, model)
# Fix LM head (for ChatGLM2)
if not hasattr(model, "lm_head"):
setattr(model, "lm_head", model.transformer.output_layer)
# Register auto class to save the custom code files.
if isinstance(config, PretrainedConfig) and "AutoConfig" in getattr(config, "auto_map", {}):
config.__class__.register_for_auto_class()

View File

@@ -32,11 +32,11 @@ class DPOPeftTrainer(PeftModelMixin, DPOTrainer):
self._stored_metrics = defaultdict(lambda: defaultdict(list))
Trainer.__init__(self, **kwargs)
+        if not hasattr(self, "accelerator"):
+            raise AttributeError("Please update `transformers`.")
         if ref_model is not None:
-            if hasattr(self, "accelerator"):
-                self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
-            else:
-                raise AttributeError("Please update `transformers`.")
+            self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
def concatenated_forward(
self,