fix streaming in pt stage #548 #549

Former-commit-id: 050e992bee2a9293cc7399b578de807b5bf9bddc
This commit is contained in:
hiyouga
2023-08-17 17:59:26 +08:00
parent c80e56423a
commit e993e717a5
4 changed files with 43 additions and 40 deletions

View File

@@ -25,7 +25,7 @@ class ComputeMetrics:
Uses the model predictions to compute metrics.
"""
preds, labels = eval_preds
score_dict = {"accuracy": [], "rouge-1": [], "rouge-2": [], "rouge-l": [], "bleu-4": []}
score_dict = {"rouge-1": [], "rouge-2": [], "rouge-l": [], "bleu-4": []}
preds = np.where(preds != IGNORE_INDEX, preds, self.tokenizer.pad_token_id)
labels = np.where(labels != IGNORE_INDEX, labels, self.tokenizer.pad_token_id)
@@ -49,6 +49,5 @@ class ComputeMetrics:
bleu_score = sentence_bleu([list(label)], list(pred), smoothing_function=SmoothingFunction().method3)
score_dict["bleu-4"].append(round(bleu_score * 100, 4))
score_dict["accuracy"].append(float(len(label) != 0 and pred[:len(label)] == label))
return {k: float(np.mean(v)) for k, v in score_dict.items()}