fix dpo trainer
Former-commit-id: c160dd7cd86e296e32775ace2e4258a473449c41
This commit is contained in:
@@ -69,7 +69,7 @@ class CustomDPOTrainer(DPOTrainer):
|
||||
Returns:
|
||||
A tensor of shape (batch_size,) containing the cross-entropy loss of each samples.
|
||||
"""
|
||||
all_logps = self._get_batch_logps(
|
||||
all_logps = self.get_batch_logps(
|
||||
chosen_logits,
|
||||
chosen_labels,
|
||||
average_log_prob=True
|
||||
@@ -89,7 +89,7 @@ class CustomDPOTrainer(DPOTrainer):
|
||||
return_dict=True
|
||||
).logits.to(torch.float32)
|
||||
|
||||
all_logps = self._get_batch_logps(
|
||||
all_logps = self.get_batch_logps(
|
||||
all_logits,
|
||||
batch["labels"],
|
||||
average_log_prob=False
|
||||
|
||||
Reference in New Issue
Block a user