support ORPO
Former-commit-id: f44a4c27e2461cdaa1b16865f597a31033c0e6d9
This commit is contained in:
@@ -74,7 +74,7 @@ class CustomDPOTrainer(DPOTrainer):
|
||||
create_custom_scheduler(self.args, num_training_steps, optimizer)
|
||||
return super().create_scheduler(num_training_steps, optimizer)
|
||||
|
||||
def sft_loss(self, chosen_logits: torch.FloatTensor, chosen_labels: torch.LongTensor) -> torch.Tensor:
|
||||
def sft_loss(self, chosen_logits: "torch.FloatTensor", chosen_labels: "torch.LongTensor") -> "torch.Tensor":
|
||||
r"""
|
||||
Computes supervised cross-entropy loss of given labels under the given logits.
|
||||
|
||||
@@ -85,8 +85,8 @@ class CustomDPOTrainer(DPOTrainer):
|
||||
return -all_logps
|
||||
|
||||
def concatenated_forward(
|
||||
self, model: "PreTrainedModel", batch: Dict[str, torch.Tensor]
|
||||
) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
|
||||
self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
|
||||
) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
|
||||
batch_copied = BatchEncoding({k: v.detach().clone() for k, v in batch.items()}) # avoid error
|
||||
|
||||
all_logits = model(
|
||||
@@ -107,9 +107,9 @@ class CustomDPOTrainer(DPOTrainer):
|
||||
def get_batch_loss_metrics(
|
||||
self,
|
||||
model: "PreTrainedModel",
|
||||
batch: Dict[str, torch.Tensor],
|
||||
batch: Dict[str, "torch.Tensor"],
|
||||
train_eval: Literal["train", "eval"] = "train",
|
||||
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
|
||||
) -> Tuple["torch.Tensor", Dict[str, "torch.Tensor"]]:
|
||||
r"""
|
||||
Computes the DPO loss and other metrics for the given batch of inputs for train or test.
|
||||
"""
|
||||
@@ -142,21 +142,22 @@ class CustomDPOTrainer(DPOTrainer):
|
||||
reference_chosen_logps,
|
||||
reference_rejected_logps,
|
||||
)
|
||||
batch_loss = losses.mean()
|
||||
if self.ftx_gamma > 1e-6:
|
||||
batch_size = batch["input_ids"].size(0) // 2
|
||||
chosen_labels, _ = batch["labels"].split(batch_size, dim=0)
|
||||
losses += self.ftx_gamma * self.sft_loss(policy_chosen_logits, chosen_labels)
|
||||
batch_loss += self.ftx_gamma * self.sft_loss(policy_chosen_logits, chosen_labels).mean()
|
||||
|
||||
reward_accuracies = (chosen_rewards > rejected_rewards).float()
|
||||
|
||||
prefix = "eval_" if train_eval == "eval" else ""
|
||||
metrics[f"{prefix}rewards/chosen"] = chosen_rewards.cpu().mean()
|
||||
metrics[f"{prefix}rewards/rejected"] = rejected_rewards.cpu().mean()
|
||||
metrics[f"{prefix}rewards/accuracies"] = reward_accuracies.cpu().mean()
|
||||
metrics[f"{prefix}rewards/margins"] = (chosen_rewards - rejected_rewards).cpu().mean()
|
||||
metrics[f"{prefix}logps/rejected"] = policy_rejected_logps.detach().cpu().mean()
|
||||
metrics[f"{prefix}logps/chosen"] = policy_chosen_logps.detach().cpu().mean()
|
||||
metrics[f"{prefix}logits/rejected"] = policy_rejected_logits.detach().cpu().mean()
|
||||
metrics[f"{prefix}logits/chosen"] = policy_chosen_logits.detach().cpu().mean()
|
||||
metrics["{}rewards/chosen".format(prefix)] = chosen_rewards.cpu().mean()
|
||||
metrics["{}rewards/rejected".format(prefix)] = rejected_rewards.cpu().mean()
|
||||
metrics["{}rewards/accuracies".format(prefix)] = reward_accuracies.cpu().mean()
|
||||
metrics["{}rewards/margins".format(prefix)] = (chosen_rewards - rejected_rewards).cpu().mean()
|
||||
metrics["{}logps/rejected".format(prefix)] = policy_rejected_logps.detach().cpu().mean()
|
||||
metrics["{}logps/chosen".format(prefix)] = policy_chosen_logps.detach().cpu().mean()
|
||||
metrics["{}logits/rejected".format(prefix)] = policy_rejected_logits.detach().cpu().mean()
|
||||
metrics["{}logits/chosen".format(prefix)] = policy_chosen_logits.detach().cpu().mean()
|
||||
|
||||
return losses.mean(), metrics
|
||||
return batch_loss, metrics
|
||||
|
||||
Reference in New Issue
Block a user