[trainer] Add LD-DPO objective (#8362)
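LD-DPO (length-desensitized DPO) down-weights the log-probability mass of response tokens lying beyond a preference pair's shared ("public") length, i.e. beyond the first min(len(chosen), len(rejected)) response tokens. In the diff's own variable names, the per-sequence log-probability fed to the DPO loss becomes front_logps + ld_alpha * rear_logps, where front_logps sums the public-length prefix and rear_logps sums the excess tokens of the longer response; leaving ld_alpha unset (None) keeps the original plain masked sum.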
@@ -585,7 +585,7 @@ def create_custom_scheduler(
 
 
 def get_batch_logps(
-    logits: "torch.Tensor", labels: "torch.Tensor", label_pad_token_id: int = IGNORE_INDEX
+    logits: "torch.Tensor", labels: "torch.Tensor", label_pad_token_id: int = IGNORE_INDEX, ld_alpha: Optional[float] = None
 ) -> tuple["torch.Tensor", "torch.Tensor"]:
     r"""Compute the log probabilities of the given labels under the given logits.
 
@@ -602,7 +602,30 @@ def get_batch_logps(
     loss_mask = labels != label_pad_token_id
     labels[labels == label_pad_token_id] = 0  # dummy token
     per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)
-    return (per_token_logps * loss_mask).sum(-1), loss_mask.sum(-1)
+
+    valid_length = loss_mask.sum(-1)
+    if ld_alpha is not None:
+        num_examples = labels.shape[0] // 2
+        chosen_lengths = valid_length[:num_examples]
+        rejected_lengths = valid_length[num_examples:]
+        min_lengths = torch.min(chosen_lengths, rejected_lengths)
+        start_positions = torch.argmax(loss_mask.int(), dim=1)
+        public_lengths = start_positions + torch.cat([min_lengths, min_lengths], dim=0)
+
+        seq_len = labels.shape[-1]
+        position_ids = torch.arange(seq_len, device=per_token_logps.device).expand_as(per_token_logps)
+
+        ld_mask = position_ids < public_lengths.unsqueeze(1)
+        front_mask = (ld_mask * loss_mask).float()
+        rear_mask = (~ld_mask * loss_mask).float()
+
+        front_logps = (per_token_logps * front_mask).sum(-1)
+        rear_logps = (per_token_logps * rear_mask).sum(-1)
+        logps = front_logps + ld_alpha * rear_logps
+    else:
+        logps = (per_token_logps * loss_mask).sum(-1)
+
+    return logps, valid_length
 
 
 def nested_detach(
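
For illustration, a minimal self-contained sketch of the weighting computed by the new ld_alpha branch. The helper name ld_dpo_logps and the toy tensors are mine, not part of the commit; the actual entry point is get_batch_logps(..., ld_alpha=...) above, which additionally derives per_token_logps and loss_mask from logits and labels.

import torch


def ld_dpo_logps(per_token_logps: "torch.Tensor", loss_mask: "torch.Tensor", ld_alpha: float) -> "torch.Tensor":
    """Sum per-token log-probs, scaling tokens beyond each pair's public length by ld_alpha.

    Assumes the batch is ordered [chosen_0, ..., chosen_{n-1}, rejected_0, ..., rejected_{n-1}],
    the same layout get_batch_logps relies on when it splits valid_length in half.
    """
    valid_length = loss_mask.sum(-1)
    num_examples = per_token_logps.shape[0] // 2
    # Public length of each pair: the shorter of the two response lengths.
    min_lengths = torch.min(valid_length[:num_examples], valid_length[num_examples:])
    # First unmasked position (prompt tokens are masked out at the front).
    start_positions = torch.argmax(loss_mask.int(), dim=1)
    public_lengths = start_positions + torch.cat([min_lengths, min_lengths], dim=0)

    position_ids = torch.arange(per_token_logps.shape[-1]).expand_as(per_token_logps)
    ld_mask = position_ids < public_lengths.unsqueeze(1)  # True inside the public span
    front_logps = (per_token_logps * (ld_mask & loss_mask)).sum(-1)
    rear_logps = (per_token_logps * (~ld_mask & loss_mask)).sum(-1)
    return front_logps + ld_alpha * rear_logps


# Toy pair: one prompt token, then a 4-token chosen response and a 2-token rejected response.
per_token_logps = torch.tensor([[0.0, -1.0, -1.0, -1.0, -1.0],   # chosen
                                [0.0, -1.0, -1.0,  0.0,  0.0]])  # rejected
loss_mask = torch.tensor([[False, True, True, True, True],
                          [False, True, True, False, False]])

for alpha in (1.0, 0.5, 0.0):
    print(alpha, ld_dpo_logps(per_token_logps, loss_mask, alpha).tolist())
# 1.0 [-4.0, -2.0]   0.5 [-3.0, -2.0]   0.0 [-2.0, -2.0]

Setting ld_alpha=1.0 reproduces the plain masked sum (the same result as ld_alpha=None), while values below 1.0 shrink the contribution of tokens that only the longer response contains, which is the length desensitization the commit title refers to.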