[trainer] Add LD-DPO objective (#8362)
@@ -80,6 +80,7 @@ class CustomDPOTrainer(DPOTrainer):
         self.ftx_gamma = finetuning_args.pref_ftx
         self.label_smoothing = finetuning_args.dpo_label_smoothing
         self.simpo_gamma = finetuning_args.simpo_gamma
+        self.ld_alpha = finetuning_args.ld_alpha

         Trainer.__init__(self, model=model, **kwargs)
         self.model_accepts_loss_kwargs = False  # overwrite trainer's default behavior
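The new `ld_alpha` attribute is read straight from the fine-tuning arguments, alongside the existing `pref_ftx`, `dpo_label_smoothing`, and `simpo_gamma` hyperparameters. For illustration only, a hedged sketch of what a companion argument definition could look like follows; the class name, default value, and help text are assumptions, not the repository's actual definition.

# Hypothetical argument definition (name, default, and help text assumed, not copied
# from the repository). ld_alpha in (0, 1] scales the log-probabilities of the "verbose"
# tail of a response in LD-DPO; leaving it unset (None) keeps standard DPO behavior.
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class RLHFArgumentsSketch:
    ld_alpha: Optional[float] = field(
        default=None,
        metadata={"help": "Scale for log-probs of tokens beyond the pair's shared response length (LD-DPO)."},
    )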
@@ -177,7 +178,7 @@ class CustomDPOTrainer(DPOTrainer):

     @override
     def concatenated_forward(
-        self, model: "PreTrainedModel", batch: dict[str, "torch.Tensor"]
+        self, model: "PreTrainedModel", batch: dict[str, "torch.Tensor"], is_ref_model: bool = False
     ) -> tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
         r"""Compute the sum log probabilities of the labels under given logits if loss_type is not IPO, ORPO or SimPO.
@@ -187,7 +188,8 @@ class CustomDPOTrainer(DPOTrainer):
         batch = nested_detach(batch, clone=True)  # avoid error

         all_logits: torch.Tensor = model(**batch, return_dict=True, use_cache=False).logits.to(torch.float32)
-        all_logps, valid_length = get_batch_logps(logits=all_logits, labels=batch["labels"])
+        all_logps, valid_length = get_batch_logps(logits=all_logits, labels=batch["labels"],
+                                                  ld_alpha=(self.ld_alpha if not is_ref_model else None))
         if self.loss_type in ["ipo", "orpo", "simpo"]:
             all_logps = all_logps / valid_length
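`ld_alpha` is forwarded to `get_batch_logps` only for the policy model; the reference forward pass (`is_ref_model=True`) receives `None` and keeps plain summed log-probabilities. The actual helper lives elsewhere in the repository; below is a minimal, hypothetical sketch of how an `ld_alpha`-aware version could implement the LD-DPO idea, assuming an `IGNORE_INDEX` of -100, a chosen-first/rejected-second concatenated batch, and the usual one-token causal shift.

# Hypothetical sketch, not the repository's helper: sum label log-probs per sequence and,
# when ld_alpha is given, down-weight the tokens that extend past the pair's shared
# ("public") response length, which removes the length bias from the summed log-probs.
from typing import Optional

import torch

IGNORE_INDEX = -100  # assumed label value for prompt/padding positions


def get_batch_logps_sketch(
    logits: "torch.Tensor", labels: "torch.Tensor", ld_alpha: Optional[float] = None
) -> tuple["torch.Tensor", "torch.Tensor"]:
    labels = labels[:, 1:].clone()               # shift: token t is predicted from positions < t
    logits = logits[:, :-1, :]
    loss_mask = labels != IGNORE_INDEX
    labels = labels.masked_fill(~loss_mask, 0)   # safe indices for gather
    per_token_logps = torch.gather(logits.log_softmax(dim=-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)
    per_token_logps = per_token_logps * loss_mask
    valid_length = loss_mask.sum(dim=-1)

    if ld_alpha is None:
        return per_token_logps.sum(dim=-1), valid_length

    # Public length l_p = min(|y_chosen|, |y_rejected|) per pair (chosen half stacked first).
    chosen_len, rejected_len = valid_length.chunk(2, dim=0)
    public_len = torch.minimum(chosen_len, rejected_len).repeat(2)
    token_pos = loss_mask.cumsum(dim=-1)                            # 1-based position among response tokens
    front_mask = loss_mask & (token_pos <= public_len.unsqueeze(1))
    rear_mask = loss_mask & ~front_mask
    # Full weight for the shared-length part, ld_alpha weight for the verbose excess.
    all_logps = (per_token_logps * front_mask).sum(dim=-1) + ld_alpha * (per_token_logps * rear_mask).sum(dim=-1)
    return all_logps, valid_length

With ld_alpha set to 1.0 the sketch reduces to the ordinary sum of label log-probabilities, so standard DPO behavior is recovered.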
@@ -217,7 +219,8 @@ class CustomDPOTrainer(DPOTrainer):
             ref_context = nullcontext()

         with torch.no_grad(), ref_context:
-            reference_chosen_logps, reference_rejected_logps, *_ = self.concatenated_forward(ref_model, batch)
+            reference_chosen_logps, reference_rejected_logps, *_ = self.concatenated_forward(ref_model, batch,
+                                                                                             is_ref_model=True)

         return reference_chosen_logps, reference_rejected_logps