fix IPO and ORPO loss

Former-commit-id: fc27955732aedbb12003faf19b760e2768b228f2
This commit is contained in:
hiyouga
2024-04-01 14:37:53 +08:00
parent c5a46f9113
commit 52d402e2a9
2 changed files with 27 additions and 49 deletions

View File

@@ -87,16 +87,22 @@ class CustomDPOTrainer(DPOTrainer):
def concatenated_forward(
self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
r"""
Computes the sum log probabilities of the labels under the given logits if loss_type != IPO.
Otherwise the average log probabilities.
"""
batch_copied = BatchEncoding({k: v.detach().clone() for k, v in batch.items()}) # avoid error
all_logits = model(
all_logits: "torch.Tensor" = model(
input_ids=batch_copied["input_ids"], attention_mask=batch_copied["attention_mask"], return_dict=True
).logits.to(torch.float32)
all_logps = self.get_batch_logps(
all_logits,
batch["labels"],
average_log_prob=False,
logits=all_logits,
labels=batch_copied["labels"],
average_log_prob=(self.loss_type == "ipo"),
is_encoder_decoder=self.is_encoder_decoder,
label_pad_token_id=self.label_pad_token_id,
)
batch_size = batch["input_ids"].size(0) // 2