add CMMLU, update eval script

Former-commit-id: 47f31f06a946eefa5a972e4a566cf3ce05e1e111
Author: hiyouga
Date: 2023-09-23 21:10:17 +08:00
Parent: f7cecd20e3
Commit: 73c48d0463

5 changed files with 237 additions and 61 deletions

evaluate.py

@@ -1,16 +1,15 @@
 # coding=utf-8
 # Evaluates the performance of pre-trained models.
 # Usage: python evaluate.py --model_name_or_path path_to_model --checkpoint_dir path_to_ckpt --template vanilla
-#    --task ceval --split validation --lang zh --n_shot 5 --batch_size 4
+#    --task ceval --split validation --lang zh --n_shot 5 --batch_size 4 --save_name result
 # Inspired by: https://github.com/hendrycks/test/blob/master/evaluate_flan.py

 import os
 import fire
 import json
 import torch
-import random
 import numpy as np
-from tqdm import tqdm
+from tqdm import tqdm, trange
 from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple
 from datasets import load_dataset
 from dataclasses import dataclass
@@ -30,6 +29,7 @@ class EvalTemplate:
     system: str
     choice: str
     answer: str
+    prefix: str

     def parse_example(
         self,
@@ -49,7 +49,6 @@ class EvalTemplate:
         history = [self.parse_example(support_set[k]) for k in range(len(support_set))]
         if len(history):
-            random.shuffle(history)
             temp = history.pop(0)
             history.insert(0, (self.system.format(subject=subject_name) + temp[0], temp[1]))
         else:
@@ -65,12 +64,14 @@ eval_templates = {
     "en": EvalTemplate(
         system="The following are multiple choice questions (with answers) about {subject}.\n\n",
         choice="\n{choice}. {content}",
-        answer="\nAnswer: "
+        answer="\nAnswer: ",
+        prefix=" "
     ),
     "zh": EvalTemplate(
         system="以下是中国关于{subject}考试的单项选择题,请选出其中的正确答案。\n\n",
         choice="\n{choice}. {content}",
-        answer="\n答案:"
+        answer="\n答案:",
+        prefix="\n"
     )
 }
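
The new prefix field carries the language-specific character that sits between the answer cue and the choice letter; batch_inference below consumes it when scoring candidates. As a minimal sketch of how one record renders with the "en" template, assuming rows carry "question", "A".."D" and "answer" keys (this diff does not show parse_example's exact schema):

    tmpl = eval_templates["en"]
    row = {"question": "1 + 1 = ?", "A": "1", "B": "2", "C": "3", "D": "4", "answer": "B"}
    query = row["question"] + "".join(
        tmpl.choice.format(choice=c, content=row[c]) for c in "ABCD"
    ) + tmpl.answer
    # query renders as:
    # 1 + 1 = ?
    # A. 1
    # B. 2
    # C. 3
    # D. 4
    # Answer:
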
@@ -79,9 +80,8 @@ eval_templates = {
 def batch_inference(
     chat_model: ChatModel,
     batch_input: Dict[str, torch.Tensor],
-    lang: Literal["zh", "en"]
+    prefix_char: str
 ) -> List[str]:
-    prefix_char = "\n" if lang == "zh" else " "
     logits = chat_model.model(**batch_input).logits
     probs = torch.nn.functional.softmax(
         torch.stack(
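
batch_inference now takes the template's prefix instead of the language tag, so the template alone decides how candidate letters are tokenized. The scoring body is truncated in this hunk; a hedged sketch of the idea it implements, where pick_choices is a hypothetical helper and not the repo's code:

    import torch
    from typing import List

    def pick_choices(logits: torch.Tensor, tokenizer, prefix_char: str) -> List[str]:
        # SentencePiece-style tokenizers map " A" and "\nA" to different ids
        # than a bare "A", which is why the prefix character matters.
        choice_ids = [
            tokenizer.encode(prefix_char + c, add_special_tokens=False)[-1]
            for c in ["A", "B", "C", "D"]
        ]
        # Compare only the four answer letters at the last (non-padded) position.
        probs = torch.nn.functional.softmax(logits[:, -1, choice_ids], dim=-1)
        return [chr(ord("A") + i.item()) for i in probs.argmax(dim=-1)]
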
@@ -108,7 +108,8 @@ def evaluate(
     split: Optional[Literal["validation", "test"]] = "validation",
     lang: Optional[Literal["zh", "en"]] = "zh",
     n_shot: Optional[int] = 5,
-    batch_size: Optional[int] = 4
+    batch_size: Optional[int] = 4,
+    save_name: Optional[str] = None
 ):
     with open(os.path.join(dataset_dir, task, "mapping.json"), "r", encoding="utf-8") as f:
         categorys = json.load(f)
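
mapping.json assigns every subject a display name and one of the four categories aggregated below. A hedged sketch of the layout implied by the categorys[subject]["name"] and categorys[subject]["category"] lookups, with illustrative subject keys:

    # What categorys is assumed to hold after json.load (keys are examples only):
    categorys = {
        "computer_science": {"name": "Computer Science", "category": "STEM"},
        "world_history": {"name": "World History", "category": "Humanities"},
    }
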
@@ -119,25 +120,25 @@ def evaluate(
         checkpoint_dir=checkpoint_dir,
         template=template
     ))
-    chat_model.tokenizer.padding_side = "left"
     eval_template = eval_templates[lang]
+    assert chat_model.tokenizer.padding_side == "left", "only left-padded tensor can be accepted."

     category_corrects: Dict[str, np.ndarray] = {
-        "STEM": np.array([], dtype="bool"),
-        "Social Sciences": np.array([], dtype="bool"),
-        "Humanities": np.array([], dtype="bool"),
-        "Other": np.array([], dtype="bool")
+        subj: np.array([], dtype="bool") for subj in ["STEM", "Social Sciences", "Humanities", "Other"]
     }
     overall_corrects = np.array([], dtype="bool")
-    pbar = tqdm(categorys.keys())
+    pbar = tqdm(categorys.keys(), desc="Processing subjects", position=0)
+    results = {}
     for subject in pbar:
         pbar.set_postfix_str(categorys[subject]["name"])
         inputs, labels = [], []
         dataset = load_dataset(os.path.join(dataset_dir, task), subject)
         for i in range(len(dataset[split])):
+            support_set = dataset["train"].shuffle().select(range(min(n_shot, len(dataset["train"]))))
             query, resp, history = eval_template.format_example(
                 target_data=dataset[split][i],
-                support_set=dataset["train"].select(range(min(n_shot, len(dataset["train"])))),
+                support_set=support_set,
                 subject_name=categorys[subject]["name"],
                 use_history=chat_model.template.use_history
             )
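
Two behavioral changes land here: the few-shot support set is reshuffled for every evaluation example (replacing the old random.shuffle of the rendered history), and left padding is now asserted rather than set in this script. Left padding matters because batch_inference reads logits at index -1; a hedged illustration, using gpt2 purely as a stand-in tokenizer:

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("gpt2")
    tok.pad_token = tok.eos_token
    tok.padding_side = "left"
    batch = tok(["short", "a much longer prompt"], padding=True, return_tensors="pt")
    # With left padding every row ends in a real token, so logits[:, -1, :]
    # is the next-token distribution after each prompt; with right padding
    # the shorter rows would end in pad tokens instead.
    print(batch["input_ids"][:, -1])
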
@@ -154,23 +155,33 @@ def evaluate(
             labels.append(resp)

         outputs = []
-        for i in range(0, len(inputs), batch_size):
+        for i in trange(0, len(inputs), batch_size, desc="Processing batches", position=1, leave=False):
             batch_input = chat_model.tokenizer.pad(
                 inputs[i : i + batch_size],
                 return_attention_mask=True,
                 return_tensors="pt"
             ).to(chat_model.model.device)
-            preds = batch_inference(chat_model, batch_input, lang)
+            preds = batch_inference(chat_model, batch_input, eval_template.prefix)
             outputs += preds

         corrects = (np.array(outputs) == np.array(labels))
         category_name = categorys[subject]["category"]
         category_corrects[category_name] = np.concatenate([category_corrects[category_name], corrects], axis=0)
         overall_corrects = np.concatenate([overall_corrects, corrects], axis=0)
+        results[subject] = {str(i): outputs[i] for i in range(len(outputs))}

-    print("Average accuracy: {:.2f}".format(100 * np.mean(overall_corrects)))
+    score_info = "Average accuracy: {:.2f}".format(100 * np.mean(overall_corrects))
     for category_name, category_correct in category_corrects.items():
-        print(" {} - {:.2f}".format(category_name, 100 * np.mean(category_correct)))
+        if len(category_correct):
+            score_info += "\n{:>16}: {:.2f}".format(category_name, 100 * np.mean(category_correct))
+
+    print(score_info)
+
+    if save_name is not None:
+        with open(save_name + ".json", "w", encoding="utf-8", newline="\n") as f:
+            json.dump(results, f, indent=2)
+        with open(save_name + ".log", "w", encoding="utf-8", newline="\n") as f:
+            f.write(score_info)


 if __name__ == "__main__":
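
With --save_name result (as in the updated usage comment), a run now writes result.json, per-subject predictions keyed by example index, and result.log, the accuracy summary. A short snippet for loading them back:

    import json

    with open("result.json", "r", encoding="utf-8") as f:
        results = json.load(f)  # e.g. {"some_subject": {"0": "A", "1": "C", ...}}

    with open("result.log", "r", encoding="utf-8") as f:
        print(f.read())  # "Average accuracy: ..." plus per-category lines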