optimize predict vram

Former-commit-id: a577e44eee351b3ed8011a33ae01cd713354ff97
This commit is contained in:
hiyouga
2024-08-30 23:08:45 +08:00
parent 66a1abac6a
commit d789b667d7
5 changed files with 10 additions and 10 deletions

View File

@@ -75,8 +75,8 @@ class PairwiseTrainer(Trainer):
return super().create_scheduler(num_training_steps, optimizer)
def compute_loss(
self, model: "PreTrainedModel", inputs: Dict[str, torch.Tensor], return_outputs: bool = False
) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False
) -> Union["torch.Tensor", Tuple["torch.Tensor", List["torch.Tensor"]]]:
r"""
Computes pairwise loss. The first n examples are chosen and the last n examples are rejected.