fix plots

Former-commit-id: 81355671296b84d438967463bb2a92934ff31aae
This commit is contained in:
hiyouga
2024-03-31 19:43:48 +08:00
parent 00e17a377c
commit c5a46f9113
4 changed files with 5 additions and 6 deletions

View File

@@ -142,11 +142,10 @@ class CustomDPOTrainer(DPOTrainer):
reference_chosen_logps,
reference_rejected_logps,
)
batch_loss = losses.mean()
if self.ftx_gamma > 1e-6:
batch_size = batch["input_ids"].size(0) // 2
chosen_labels, _ = batch["labels"].split(batch_size, dim=0)
batch_loss += self.ftx_gamma * self.sft_loss(policy_chosen_logits, chosen_labels).mean()
losses += self.ftx_gamma * self.sft_loss(policy_chosen_logits, chosen_labels)
reward_accuracies = (chosen_rewards > rejected_rewards).float()
@@ -160,4 +159,4 @@ class CustomDPOTrainer(DPOTrainer):
metrics["{}logits/rejected".format(prefix)] = policy_rejected_logits.detach().cpu().mean()
metrics["{}logits/chosen".format(prefix)] = policy_chosen_logits.detach().cpu().mean()
return batch_loss, metrics
return losses.mean(), metrics