@@ -45,6 +45,9 @@ if is_rouge_available():
def eval_logit_processor(logits: "torch.Tensor", labels: "torch.Tensor") -> "torch.Tensor":
    r"""
    Computes the token with the largest likelihood to reduce memory footprint.
    """
    if isinstance(logits, (list, tuple)):
        if logits[0].dim() == 3:  # (batch_size, seq_len, vocab_size)
            logits = logits[0]
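For context, the hook above is the kind of function passed to the trainer's preprocess_logits_for_metrics: reducing the (batch_size, seq_len, vocab_size) logits to argmax token ids before they are accumulated shrinks eval memory by roughly a factor of vocab_size. A minimal standalone sketch, assuming plain tensors (the name `argmax_logit_processor` and the toy shapes are illustrative, not the repository's exact code):

import torch


def argmax_logit_processor(logits: "torch.Tensor", labels: "torch.Tensor") -> "torch.Tensor":
    # Some models return a tuple of tensors; keep the 3-D vocabulary scores.
    if isinstance(logits, (list, tuple)):
        logits = logits[0]
    if logits.dim() != 3:
        raise ValueError("Expected logits of shape (batch_size, seq_len, vocab_size).")
    # Keep only the most likely token id per position: (batch_size, seq_len).
    return torch.argmax(logits, dim=-1)


preds = argmax_logit_processor(torch.randn(2, 8, 32000), torch.zeros(2, 8, dtype=torch.long))
assert preds.shape == (2, 8)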
@@ -59,6 +62,9 @@ def eval_logit_processor(logits: "torch.Tensor", labels: "torch.Tensor") -> "torch.Tensor":
@dataclass
class ComputeAccuracy:
    r"""
    Computes accuracy and supports `batch_eval_metrics`.
    """

    def _dump(self) -> Optional[Dict[str, float]]:
        result = None
        if hasattr(self, "score_dict"):
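The accumulate-then-dump shape above is what `batch_eval_metrics` expects: the metric callable is invoked once per eval batch with compute_result=False and a final time with compute_result=True to return the aggregated scores. A hedged sketch of that pattern (the token-accuracy logic and the -100 padding mask are assumptions for illustration; only the `score_dict`/`_dump` names mirror the diff):

from dataclasses import dataclass, field
from typing import Dict, List, Optional

import numpy as np


@dataclass
class SketchAccuracy:
    score_dict: Dict[str, List[float]] = field(default_factory=lambda: {"accuracy": []})

    def _dump(self) -> Optional[Dict[str, float]]:
        result = {k: float(np.mean(v)) for k, v in self.score_dict.items()}
        self.score_dict = {"accuracy": []}  # reset for the next eval run
        return result

    def __call__(self, eval_preds, compute_result: bool = True) -> Optional[Dict[str, float]]:
        preds, labels = eval_preds  # token-id predictions and label ids, shape (batch, seq_len)
        for pred, label in zip(preds, labels):
            mask = label != -100  # skip positions the loss ignores (assumption: -100 padding)
            self.score_dict["accuracy"].append(float(np.mean(pred[mask] == label[mask])))
        return self._dump() if compute_result else None


batch = (np.array([[1, 2, 3]]), np.array([[1, -100, 3]]))
print(SketchAccuracy()(batch))  # {'accuracy': 1.0}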
@@ -84,6 +90,8 @@ class ComputeAccuracy:
@dataclass
class ComputeSimilarity:
    r"""
    Computes text similarity scores and supports `batch_eval_metrics`.

    Wraps the tokenizer into metric functions, used in CustomSeq2SeqTrainer.
    """
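Wrapping the tokenizer is what lets the metric receive raw token ids yet score decoded text. A rough standalone sketch of that shape (difflib's SequenceMatcher stands in for the actual ROUGE/BLEU scoring, and `tokenizer` is assumed to expose the usual `batch_decode(..., skip_special_tokens=True)`):

from dataclasses import dataclass
from difflib import SequenceMatcher
from typing import Any, Dict, Optional


@dataclass
class SketchSimilarity:
    tokenizer: Any  # anything exposing batch_decode(ids, skip_special_tokens=True)

    def __call__(self, eval_preds, compute_result: bool = True) -> Optional[Dict[str, float]]:
        preds, labels = eval_preds
        # In practice, -100 positions in `labels` must be replaced with
        # pad_token_id before decoding; omitted here to keep the sketch short.
        decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True)
        decoded_labels = self.tokenizer.batch_decode(labels, skip_special_tokens=True)
        scores = [SequenceMatcher(None, p, l).ratio() for p, l in zip(decoded_preds, decoded_labels)]
        return {"similarity": sum(scores) / max(len(scores), 1)} if compute_result else None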