support batch_eval_metrics, fix #4826

Former-commit-id: 3fe1df17188825f8a32fbe6a1294b4b532ce0c85
This commit is contained in:
hiyouga
2024-07-17 00:33:00 +08:00
parent 45367105fc
commit 8c93921952
7 changed files with 85 additions and 36 deletions

View File

@@ -17,13 +17,14 @@
# limitations under the License.
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict
from typing import TYPE_CHECKING, Dict, Optional
import numpy as np
import torch
from transformers.utils import is_jieba_available, is_nltk_available
from ...extras.constants import IGNORE_INDEX
from ...extras.misc import numpify
from ...extras.packages import is_rouge_available
@@ -43,17 +44,6 @@ if is_rouge_available():
from rouge_chinese import Rouge
def compute_accuracy(eval_preds: "EvalPrediction") -> Dict[str, float]:
preds, labels = eval_preds.predictions, eval_preds.label_ids
accuracies = []
for i in range(len(preds)):
pred, label = preds[i, :-1], labels[i, 1:]
label_mask = label != IGNORE_INDEX
accuracies.append(np.mean(pred[label_mask] == label[label_mask]))
return {"accuracy": float(np.mean(accuracies))}
def eval_logit_processor(logits: "torch.Tensor", labels: "torch.Tensor") -> "torch.Tensor":
if isinstance(logits, (list, tuple)):
if logits[0].dim() == 3: # (batch_size, seq_len, vocab_size)
@@ -68,19 +58,34 @@ def eval_logit_processor(logits: "torch.Tensor", labels: "torch.Tensor") -> "tor
@dataclass
class ComputeMetrics:
class ComputeAccuracy:
def __post_init__(self):
self.score_dict = {"accuracy": []}
def __call__(self, eval_preds: "EvalPrediction", compute_result: bool = True) -> Optional[Dict[str, float]]:
preds, labels = numpify(eval_preds.predictions), numpify(eval_preds.label_ids)
for i in range(len(preds)):
pred, label = preds[i, :-1], labels[i, 1:]
label_mask = label != IGNORE_INDEX
self.score_dict["accuracy"].append(np.mean(pred[label_mask] == label[label_mask]))
if compute_result:
return {"accuracy": float(np.mean(self.score_dict["accuracy"]))}
@dataclass
class ComputeSimilarity:
r"""
Wraps the tokenizer into metric functions, used in Seq2SeqPeftTrainer.
Wraps the tokenizer into metric functions, used in CustomSeq2SeqTrainer.
"""
tokenizer: "PreTrainedTokenizer"
def __call__(self, eval_preds: "EvalPrediction") -> Dict[str, float]:
r"""
Uses the model predictions to compute metrics.
"""
preds, labels = eval_preds.predictions, eval_preds.label_ids
score_dict = {"rouge-1": [], "rouge-2": [], "rouge-l": [], "bleu-4": []}
def __post_init__(self):
self.score_dict = {"rouge-1": [], "rouge-2": [], "rouge-l": [], "bleu-4": []}
def __call__(self, eval_preds: "EvalPrediction", compute_result: bool = True) -> Optional[Dict[str, float]]:
preds, labels = numpify(eval_preds.predictions), numpify(eval_preds.label_ids)
preds = np.where(preds != IGNORE_INDEX, preds, self.tokenizer.pad_token_id)
labels = np.where(labels != IGNORE_INDEX, labels, self.tokenizer.pad_token_id)
@@ -100,9 +105,10 @@ class ComputeMetrics:
result = scores[0]
for k, v in result.items():
score_dict[k].append(round(v["f"] * 100, 4))
self.score_dict[k].append(round(v["f"] * 100, 4))
bleu_score = sentence_bleu([list(label)], list(pred), smoothing_function=SmoothingFunction().method3)
score_dict["bleu-4"].append(round(bleu_score * 100, 4))
self.score_dict["bleu-4"].append(round(bleu_score * 100, 4))
return {k: float(np.mean(v)) for k, v in score_dict.items()}
if compute_result:
return {k: float(np.mean(v)) for k, v in self.score_dict.items()}