allow non-packing pretraining

Former-commit-id: 3fee5cc5a3db9ce874ad90f2500ec092d904bd4e
This commit is contained in:
hiyouga
2024-03-09 22:21:46 +08:00
parent c631799f5d
commit 4881f4e631
22 changed files with 64 additions and 67 deletions

View File

@@ -1,6 +1,6 @@
import json
import os
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Dict, List, Tuple, Union
import torch
from transformers import Trainer
@@ -26,7 +26,7 @@ class PairwiseTrainer(Trainer):
self.can_return_loss = True # override property to return eval_loss
def compute_loss(
self, model: "PreTrainedModel", inputs: Dict[str, torch.Tensor], return_outputs: Optional[bool] = False
self, model: "PreTrainedModel", inputs: Dict[str, torch.Tensor], return_outputs: bool = False
) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
r"""
Computes pairwise loss. The first n examples are chosen and the last n examples are rejected.