optimize predict vram

Former-commit-id: a577e44eee351b3ed8011a33ae01cd713354ff97
This commit is contained in:
hiyouga
2024-08-30 23:08:45 +08:00
parent 66a1abac6a
commit d789b667d7
5 changed files with 10 additions and 10 deletions

View File

@@ -54,7 +54,7 @@ def eval_logit_processor(logits: "torch.Tensor", labels: "torch.Tensor") -> "tor
if logits.dim() != 3:
raise ValueError("Cannot process the logits.")
return torch.argmax(logits, dim=-1)
return torch.argmax(logits, dim=-1).cpu()
@dataclass