commit b5e9df5df8 (parent 7367c6ec21)
Author: hiyouga
Date:   2024-09-02 23:56:21 +08:00
Former-commit-id: f7aa06c9c0b18c28419ea5792410915d3f322cbf

5 changed files with 57 additions and 38 deletions

@@ -45,6 +45,9 @@ if is_rouge_available():
 def eval_logit_processor(logits: "torch.Tensor", labels: "torch.Tensor") -> "torch.Tensor":
+    r"""
+    Computes the token with the largest likelihood to reduce memory footprint.
+    """
     if isinstance(logits, (list, tuple)):
         if logits[0].dim() == 3:  # (batch_size, seq_len, vocab_size)
             logits = logits[0]
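
For context, a minimal sketch of the memory saving the new docstring describes. The argmax reduction is an assumption (the hunk only shows the input-normalization branch), and greedy_token_processor is a hypothetical name, not the project's function:

    import torch

    def greedy_token_processor(logits: "torch.Tensor", labels: "torch.Tensor") -> "torch.Tensor":
        # Some models return a tuple of tensors; keep the (batch, seq_len, vocab) one.
        if isinstance(logits, (list, tuple)) and logits[0].dim() == 3:
            logits = logits[0]
        # Returning only the most likely token id per position means the evaluation
        # loop accumulates a (batch_size, seq_len) tensor instead of the full logits,
        # cutting the stored size by a factor of vocab_size.
        return torch.argmax(logits, dim=-1)
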
@@ -59,6 +62,9 @@ def eval_logit_processor(logits: "torch.Tensor", labels: "torch.Tensor") -> "torch.Tensor":
 @dataclass
 class ComputeAccuracy:
+    r"""
+    Computes accuracy and supports `batch_eval_metrics`.
+    """
     def _dump(self) -> Optional[Dict[str, float]]:
         result = None
         if hasattr(self, "score_dict"):
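
A minimal sketch of the accumulate-and-dump pattern that `batch_eval_metrics` support implies: the metric object is called once per evaluation batch and only returns the aggregated result on the final call. The class name BatchAccuracy, the -100 ignore index, and the call signature are illustrative assumptions, not the project's code:

    from dataclasses import dataclass, field
    from typing import Dict, List, Optional

    import numpy as np

    @dataclass
    class BatchAccuracy:
        score_dict: Dict[str, List[float]] = field(default_factory=lambda: {"accuracy": []})

        def _dump(self) -> Optional[Dict[str, float]]:
            # Average the scores collected so far and reset the buffer for the next run.
            result = {k: float(np.mean(v)) for k, v in self.score_dict.items() if v}
            self.score_dict = {"accuracy": []}
            return result or None

        def __call__(self, preds: np.ndarray, labels: np.ndarray, compute_result: bool = True) -> Optional[Dict[str, float]]:
            # preds/labels: (batch_size, seq_len) token ids; -100 marks ignored label positions.
            mask = labels != -100
            self.score_dict["accuracy"].append(float(np.mean(preds[mask] == labels[mask])))
            if compute_result:  # set on the final batch by the evaluation loop
                return self._dump()
            return None
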
@@ -84,6 +90,8 @@ class ComputeAccuracy:
 @dataclass
 class ComputeSimilarity:
     r"""
+    Computes text similarity scores and supports `batch_eval_metrics`.
+
     Wraps the tokenizer into metric functions, used in CustomSeq2SeqTrainer.
     """