[feature] add support for dft loss (#8917)

This commit is contained in:
XLXW
2025-08-15 23:29:57 +08:00
committed by GitHub
parent 936f4fd78e
commit 1ada15981a
4 changed files with 97 additions and 0 deletions

View File

@@ -631,6 +631,51 @@ def get_batch_logps(
return logps, valid_length
def dft_loss_func(outputs, labels, num_items_in_batch=None):
logits = outputs.get("logits")
if logits is None:
return outputs.get("loss", torch.tensor(0.0))
logits = logits.float()
vocab_size = logits.size(-1)
labels = torch.nn.functional.pad(labels, (0, 1), value=-100)
shift_labels = labels[..., 1:].contiguous()
logits = logits.view(-1, vocab_size)
shift_labels = shift_labels.view(-1)
shift_labels = shift_labels.to(logits.device)
loss = _dft_cross_entropy(logits, shift_labels, num_items_in_batch)
return loss
def _dft_cross_entropy(
source: torch.Tensor,
target: torch.Tensor,
num_items_in_batch: Optional[torch.Tensor] = None,
ignore_index: int = -100,
) -> torch.Tensor:
per_token_loss = torch.nn.functional.cross_entropy(source, target, ignore_index=ignore_index, reduction="none")
valid_mask = target != ignore_index
if not valid_mask.any():
return torch.tensor(0.0, device=source.device, dtype=source.dtype)
valid_losses = per_token_loss[valid_mask]
with torch.no_grad():
target_probs = torch.exp(-valid_losses)
weighted_losses = valid_losses * target_probs
if num_items_in_batch is not None:
total_loss = weighted_losses.sum()
if torch.is_tensor(num_items_in_batch):
num_items_in_batch = num_items_in_batch.to(total_loss.device)
loss = total_loss / num_items_in_batch
else:
loss = weighted_losses.mean()
return loss
def nested_detach(
tensors: Union["torch.Tensor", list["torch.Tensor"], tuple["torch.Tensor"], dict[str, "torch.Tensor"]],
clone: bool = False,