remove .cpu()
Former-commit-id: 35c57cc9dcba305d40282a9757ddc23968c210ac
This commit is contained in:
@@ -54,7 +54,7 @@ def eval_logit_processor(logits: "torch.Tensor", labels: "torch.Tensor") -> "tor
|
||||
if logits.dim() != 3:
|
||||
raise ValueError("Cannot process the logits.")
|
||||
|
||||
return torch.argmax(logits, dim=-1).cpu()
|
||||
return torch.argmax(logits, dim=-1)
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
Reference in New Issue
Block a user