add eval acc

Former-commit-id: 7ffde76fbfb6192e3aac31ccc098f31ce89181ae
This commit is contained in:
hiyouga
2024-07-01 03:51:20 +08:00
parent 38c94d2e9c
commit 884b49e662
3 changed files with 31 additions and 17 deletions

View File

@@ -17,9 +17,11 @@
# limitations under the License.
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, Sequence, Tuple, Union
from typing import TYPE_CHECKING, Dict
import numpy as np
import torch
from transformers import EvalPrediction
from transformers.utils import is_jieba_available, is_nltk_available
from ...extras.constants import IGNORE_INDEX
@@ -42,6 +44,22 @@ if is_rouge_available():
from rouge_chinese import Rouge
def compute_accuracy(eval_preds: "EvalPrediction") -> Dict[str, float]:
    r"""
    Computes the token-level accuracy of the predictions, averaged over samples.

    Args:
        eval_preds: provides `predictions` and `label_ids`, each indexed below as a
            2-D array of shape (sample, position).
            NOTE(review): `predictions` are presumably argmax token ids rather than
            raw logits - confirm against the logit preprocessor used by the caller.

    Returns:
        A dict with the single key "accuracy": the mean of the per-sample accuracies,
        where positions labelled IGNORE_INDEX are excluded from each sample.
    """
    preds, labels = eval_preds.predictions, eval_preds.label_ids
    accuracies = []
    for pred_row, label_row in zip(preds, labels):
        # In next-token prediction the output at position t is the guess for the token
        # at position t + 1, so drop the last prediction and the first label to line
        # them up (assumes labels are the unshifted, input-aligned ids).
        pred, label = pred_row[:-1], label_row[1:]
        label_mask = label != IGNORE_INDEX
        # NOTE: a sample with no supervised position yields NaN here (mean of an empty
        # selection), which then propagates to the final average.
        accuracies.append(np.mean(pred[label_mask] == label[label_mask]))

    return {"accuracy": float(np.mean(accuracies))}
def eval_logit_processor(logits: "torch.Tensor", labels: "torch.Tensor") -> "torch.Tensor":
    r"""
    Reduces logits to predicted indices by taking the argmax over the last dimension.

    If `logits` is a list or tuple, only its first element is used.
    `labels` is accepted but never read.
    """
    if isinstance(logits, (list, tuple)):
        scores = logits[0]
    else:
        scores = logits

    return torch.argmax(scores, dim=-1)
@dataclass
class ComputeMetrics:
r"""
@@ -50,11 +68,11 @@ class ComputeMetrics:
tokenizer: "PreTrainedTokenizer"
def __call__(self, eval_preds: Sequence[Union[np.ndarray, Tuple[np.ndarray]]]) -> Dict[str, float]:
def __call__(self, eval_preds: "EvalPrediction") -> Dict[str, float]:
r"""
Uses the model predictions to compute metrics.
"""
preds, labels = eval_preds
preds, labels = eval_preds.predictions, eval_preds.label_ids
score_dict = {"rouge-1": [], "rouge-2": [], "rouge-l": [], "bleu-4": []}
preds = np.where(preds != IGNORE_INDEX, preds, self.tokenizer.pad_token_id)