[trainer] Add LD-DPO objective (#8362)

This commit is contained in:
Aman Gupta
2025-06-12 01:10:38 -07:00
committed by GitHub
parent 44f1b9b5ad
commit 8e4ac78607
3 changed files with 35 additions and 5 deletions

View File

@@ -585,7 +585,7 @@ def create_custom_scheduler(
def get_batch_logps(
logits: "torch.Tensor", labels: "torch.Tensor", label_pad_token_id: int = IGNORE_INDEX
logits: "torch.Tensor", labels: "torch.Tensor", label_pad_token_id: int = IGNORE_INDEX, ld_alpha: Optional[float] = None
) -> tuple["torch.Tensor", "torch.Tensor"]:
r"""Compute the log probabilities of the given labels under the given logits.
@@ -602,7 +602,30 @@ def get_batch_logps(
loss_mask = labels != label_pad_token_id
labels[labels == label_pad_token_id] = 0 # dummy token
per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)
return (per_token_logps * loss_mask).sum(-1), loss_mask.sum(-1)
valid_length = loss_mask.sum(-1)
if ld_alpha is not None:
num_examples = labels.shape[0] // 2
chosen_lengths = valid_length[:num_examples]
rejected_lengths = valid_length[num_examples:]
min_lengths = torch.min(chosen_lengths, rejected_lengths)
start_positions = torch.argmax(loss_mask.int(), dim=1)
public_lengths = start_positions + torch.cat([min_lengths, min_lengths], dim=0)
seq_len = labels.shape[-1]
position_ids = torch.arange(seq_len, device=per_token_logps.device).expand_as(per_token_logps)
ld_mask = position_ids < public_lengths.unsqueeze(1)
front_mask = (ld_mask * loss_mask).float()
rear_mask = (~ld_mask * loss_mask).float()
front_logps = (per_token_logps * front_mask).sum(-1)
rear_logps = (per_token_logps * rear_mask).sum(-1)
logps = front_logps + ld_alpha * rear_logps
else:
logps = (per_token_logps * loss_mask).sum(-1)
return logps, valid_length
def nested_detach(