fix ChatGLM lm_head #494
Former-commit-id: bf0048abdaeb2b9592d38ac991704ad014370b47
This commit is contained in:
@@ -32,11 +32,11 @@ class DPOPeftTrainer(PeftModelMixin, DPOTrainer):
|
||||
self._stored_metrics = defaultdict(lambda: defaultdict(list))
|
||||
|
||||
Trainer.__init__(self, **kwargs)
|
||||
if not hasattr(self, "accelerator"):
|
||||
raise AttributeError("Please update `transformers`.")
|
||||
|
||||
if ref_model is not None:
|
||||
if hasattr(self, "accelerator"):
|
||||
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
|
||||
else:
|
||||
raise AttributeError("Please update `transformers`.")
|
||||
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
|
||||
|
||||
def concatenated_forward(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user