allow non-packing pretraining
Former-commit-id: 3fee5cc5a3db9ce874ad90f2500ec092d904bd4e
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
import json
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
|
||||
from typing import TYPE_CHECKING, Dict, List, Tuple, Union
|
||||
|
||||
import torch
|
||||
from transformers import Trainer
|
||||
@@ -26,7 +26,7 @@ class PairwiseTrainer(Trainer):
|
||||
self.can_return_loss = True # override property to return eval_loss
|
||||
|
||||
def compute_loss(
|
||||
self, model: "PreTrainedModel", inputs: Dict[str, torch.Tensor], return_outputs: Optional[bool] = False
|
||||
self, model: "PreTrainedModel", inputs: Dict[str, torch.Tensor], return_outputs: bool = False
|
||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
|
||||
r"""
|
||||
Computes pairwise loss. The first n examples are chosen and the last n examples are rejected.
|
||||
|
||||
Reference in New Issue
Block a user