Former-commit-id: ff9a3f73961a362d0ddc22079f80a85465fffda8
This commit is contained in:
hiyouga
2024-04-01 22:53:52 +08:00
parent 85726c91ce
commit 1dc963caa6
4 changed files with 23 additions and 15 deletions

View File

@@ -73,7 +73,7 @@ class CustomORPOTrainer(DPOTrainer):
Computes the average log probabilities of the labels under the given logits.
"""
all_logits: "torch.Tensor" = model(
-input_ids=batch["input_ids"], attention_mask=batch["attention_mask"], return_dict=True
+input_ids=batch["input_ids"], attention_mask=batch["attention_mask"], return_dict=True, use_cache=False
).logits.to(torch.float32)
all_logps = self.get_batch_logps(