[trainer] Add LD-DPO objective (#8362)

Aman Gupta
2025-06-12 01:10:38 -07:00
committed by GitHub
parent 44f1b9b5ad
commit 8e4ac78607
3 changed files with 35 additions and 5 deletions


@@ -80,6 +80,7 @@ class CustomDPOTrainer(DPOTrainer):
         self.ftx_gamma = finetuning_args.pref_ftx
         self.label_smoothing = finetuning_args.dpo_label_smoothing
         self.simpo_gamma = finetuning_args.simpo_gamma
+        self.ld_alpha = finetuning_args.ld_alpha
         Trainer.__init__(self, model=model, **kwargs)
         self.model_accepts_loss_kwargs = False  # overwrite trainer's default behavior
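
For context, `finetuning_args.ld_alpha` is presumably declared next to the other preference-tuning options in the finetuning-arguments dataclass (one of the three changed files, not shown here). A minimal sketch of what such a field could look like; only the name `ld_alpha` comes from this diff, while the class name, default, and help text below are assumptions:

# Hypothetical sketch only; not the actual LLaMA-Factory argument definition.
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class PreferenceArgumentsSketch:
    """Assumed container for DPO-related options such as the new LD-DPO weight."""

    ld_alpha: Optional[float] = field(
        default=None,
        metadata={"help": "LD-DPO weight in (0, 1] applied to the verbose (excess-length) part of a response."},
    )
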
@@ -177,7 +178,7 @@ class CustomDPOTrainer(DPOTrainer):
     @override
     def concatenated_forward(
-        self, model: "PreTrainedModel", batch: dict[str, "torch.Tensor"]
+        self, model: "PreTrainedModel", batch: dict[str, "torch.Tensor"], is_ref_model: bool = False
     ) -> tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
         r"""Compute the sum log probabilities of the labels under given logits if loss_type is not IPO, ORPO or SimPO.
@@ -187,7 +188,8 @@ class CustomDPOTrainer(DPOTrainer):
         batch = nested_detach(batch, clone=True)  # avoid error
         all_logits: torch.Tensor = model(**batch, return_dict=True, use_cache=False).logits.to(torch.float32)
-        all_logps, valid_length = get_batch_logps(logits=all_logits, labels=batch["labels"])
+        all_logps, valid_length = get_batch_logps(
+            logits=all_logits, labels=batch["labels"], ld_alpha=(self.ld_alpha if not is_ref_model else None))
         if self.loss_type in ["ipo", "orpo", "simpo"]:
             all_logps = all_logps / valid_length
@@ -217,7 +219,8 @@ class CustomDPOTrainer(DPOTrainer):
             ref_context = nullcontext()
         with torch.no_grad(), ref_context:
-            reference_chosen_logps, reference_rejected_logps, *_ = self.concatenated_forward(ref_model, batch)
+            reference_chosen_logps, reference_rejected_logps, *_ = self.concatenated_forward(
+                ref_model, batch, is_ref_model=True)
         return reference_chosen_logps, reference_rejected_logps
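
The third changed file (not shown) extends `get_batch_logps` with the `ld_alpha` argument used above. In LD-DPO (length-desensitized DPO), token log-probabilities for the part of a response that extends beyond the shared ("public") length of the chosen/rejected pair are scaled by an alpha in (0, 1], which dampens the length bias of the DPO objective; the reference forward passes `ld_alpha=None` and keeps the plain sums. A minimal sketch under those assumptions — the real helper lives elsewhere in LLaMA-Factory and may differ in details:

# Sketch of a length-desensitized get_batch_logps. Assumptions: labels mask prompt/padding
# tokens with IGNORE_INDEX (-100), and the batch is [chosen examples; rejected examples].
from typing import Optional

import torch

IGNORE_INDEX = -100


def get_batch_logps_sketch(
    logits: "torch.Tensor", labels: "torch.Tensor", ld_alpha: Optional[float] = None
) -> tuple["torch.Tensor", "torch.Tensor"]:
    """Return (sum log p(labels), number of label tokens) per sequence."""
    labels = labels[:, 1:].clone()  # tokens < n predict token n
    logits = logits[:, :-1, :]
    loss_mask = labels != IGNORE_INDEX
    labels = labels.masked_fill(~loss_mask, 0)  # dummy index so gather stays in range
    per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)

    if ld_alpha is not None:
        # Public length = length of the shorter response in each chosen/rejected pair.
        lengths = loss_mask.sum(-1)
        chosen_len, rejected_len = lengths.chunk(2, dim=0)
        public_len = torch.min(chosen_len, rejected_len).repeat(2).unsqueeze(-1)

        # 0-based position of each label token within its own response.
        position = torch.cumsum(loss_mask.long(), dim=-1) - 1
        excess = (position >= public_len) & loss_mask  # tokens past the public length

        # LD-DPO: keep the public part untouched, down-weight the excess part by alpha.
        per_token_logps = torch.where(excess, ld_alpha * per_token_logps, per_token_logps)

    return (per_token_logps * loss_mask).sum(-1), loss_mask.sum(-1)

With this weighting, the policy's summed log-probability (and hence its implicit reward) no longer grows merely because one response is longer, which is the stated motivation of the LD-DPO objective.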