fix dpo trainer

Former-commit-id: c160dd7cd86e296e32775ace2e4258a473449c41
hiyouga
2023-12-23 01:51:55 +08:00
parent 0fdd6074c3
commit d358d955e5
2 changed files with 5 additions and 2 deletions


@@ -69,7 +69,7 @@ class CustomDPOTrainer(DPOTrainer):
         Returns:
             A tensor of shape (batch_size,) containing the cross-entropy loss of each sample.
         """
-        all_logps = self._get_batch_logps(
+        all_logps = self.get_batch_logps(
             chosen_logits,
             chosen_labels,
             average_log_prob=True
@@ -89,7 +89,7 @@ class CustomDPOTrainer(DPOTrainer):
             return_dict=True
         ).logits.to(torch.float32)
-        all_logps = self._get_batch_logps(
+        all_logps = self.get_batch_logps(
             all_logits,
             batch["labels"],
             average_log_prob=False
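
For context (an illustrative sketch, not part of this commit): the rename appears to track a TRL update in which the private helper _get_batch_logps became the public get_batch_logps on DPOTrainer. Assuming that is the case, a hypothetical subclass that must run against both TRL generations could dispatch on whichever name the installed version provides:

    # Illustrative compatibility shim (assumption: the helper was renamed
    # between TRL releases; CompatDPOTrainer is a hypothetical subclass).
    from trl import DPOTrainer

    class CompatDPOTrainer(DPOTrainer):
        def batch_logps(self, logits, labels, average_log_prob):
            # Prefer the public name used by newer TRL; fall back to the
            # pre-rename private helper on older installs.
            if hasattr(self, "get_batch_logps"):
                return self.get_batch_logps(logits, labels, average_log_prob=average_log_prob)
            return self._get_batch_logps(logits, labels, average_log_prob=average_log_prob)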