fix plots
Former-commit-id: 81355671296b84d438967463bb2a92934ff31aae
This commit is contained in:
@@ -142,11 +142,10 @@ class CustomDPOTrainer(DPOTrainer):
|
||||
reference_chosen_logps,
|
||||
reference_rejected_logps,
|
||||
)
|
||||
batch_loss = losses.mean()
|
||||
if self.ftx_gamma > 1e-6:
|
||||
batch_size = batch["input_ids"].size(0) // 2
|
||||
chosen_labels, _ = batch["labels"].split(batch_size, dim=0)
|
||||
batch_loss += self.ftx_gamma * self.sft_loss(policy_chosen_logits, chosen_labels).mean()
|
||||
losses += self.ftx_gamma * self.sft_loss(policy_chosen_logits, chosen_labels)
|
||||
|
||||
reward_accuracies = (chosen_rewards > rejected_rewards).float()
|
||||
|
||||
@@ -160,4 +159,4 @@ class CustomDPOTrainer(DPOTrainer):
|
||||
metrics["{}logits/rejected".format(prefix)] = policy_rejected_logits.detach().cpu().mean()
|
||||
metrics["{}logits/chosen".format(prefix)] = policy_chosen_logits.detach().cpu().mean()
|
||||
|
||||
return batch_loss, metrics
|
||||
return losses.mean(), metrics
|
||||
|
||||
Reference in New Issue
Block a user