optimize predict vram
Former-commit-id: a577e44eee351b3ed8011a33ae01cd713354ff97
This commit is contained in:
@@ -75,8 +75,8 @@ class PairwiseTrainer(Trainer):
|
||||
return super().create_scheduler(num_training_steps, optimizer)
|
||||
|
||||
def compute_loss(
|
||||
self, model: "PreTrainedModel", inputs: Dict[str, torch.Tensor], return_outputs: bool = False
|
||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
|
||||
self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False
|
||||
) -> Union["torch.Tensor", Tuple["torch.Tensor", List["torch.Tensor"]]]:
|
||||
r"""
|
||||
Computes pairwise loss. The first n examples are chosen and the last n examples are rejected.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user