add CMMLU, update eval script
Former-commit-id: 47f31f06a946eefa5a972e4a566cf3ce05e1e111

@@ -1,16 +1,15 @@
 # coding=utf-8
 # Evaluates the performance of pre-trained models.
 # Usage: python evaluate.py --model_name_or_path path_to_model --checkpoint_dir path_to_ckpt --template vanilla
-#    --task ceval --split validation --lang zh --n_shot 5 --batch_size 4
+#    --task ceval --split validation --lang zh --n_shot 5 --batch_size 4 --save_name result
 # Inspired by: https://github.com/hendrycks/test/blob/master/evaluate_flan.py

 import os
 import fire
 import json
 import torch
-import random
 import numpy as np
-from tqdm import tqdm
+from tqdm import tqdm, trange
 from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple
 from datasets import load_dataset
 from dataclasses import dataclass

@@ -30,6 +29,7 @@ class EvalTemplate:
     system: str
     choice: str
     answer: str
+    prefix: str

     def parse_example(
         self,

@@ -49,7 +49,6 @@ class EvalTemplate:
         history = [self.parse_example(support_set[k]) for k in range(len(support_set))]

         if len(history):
-            random.shuffle(history)
             temp = history.pop(0)
             history.insert(0, (self.system.format(subject=subject_name) + temp[0], temp[1]))
         else:

@@ -65,12 +64,14 @@ eval_templates = {
     "en": EvalTemplate(
         system="The following are multiple choice questions (with answers) about {subject}.\n\n",
         choice="\n{choice}. {content}",
-        answer="\nAnswer: "
+        answer="\nAnswer: ",
+        prefix=" "
     ),
     "zh": EvalTemplate(
         system="以下是中国关于{subject}考试的单项选择题,请选出其中的正确答案。\n\n",
         choice="\n{choice}. {content}",
-        answer="\n答案:"
+        answer="\n答案:",
+        prefix="\n"
     )
 }

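Note (not part of the patch): the templates now carry an explicit prefix used at scoring time. To make the field roles concrete, here is a minimal sketch of rendering one question with the English template. The exact assembly done in parse_example/format_example is not fully shown in this diff, so the example record layout (question text plus "A"-"D" and "answer" keys) is an assumption.

example = {
    "question": "What is the capital of France?",
    "A": "Berlin", "B": "Paris", "C": "Rome", "D": "Madrid",
    "answer": "B",
}

system = "The following are multiple choice questions (with answers) about {subject}.\n\n"
choice = "\n{choice}. {content}"
answer = "\nAnswer: "

# system header + question + one choice line per option, ending with the answer cue
prompt = system.format(subject="geography") + example["question"]
for ch in ["A", "B", "C", "D"]:
    prompt += choice.format(choice=ch, content=example[ch])
prompt += answer
print(prompt)  # the model is then scored on which option letter follows "Answer: "
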
@@ -79,9 +80,8 @@ eval_templates = {
 def batch_inference(
     chat_model: ChatModel,
     batch_input: Dict[str, torch.Tensor],
-    lang: Literal["zh", "en"]
+    prefix_char: str
 ) -> List[str]:
-    prefix_char = "\n" if lang == "zh" else " "
     logits = chat_model.model(**batch_input).logits
     probs = torch.nn.functional.softmax(
         torch.stack(

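Note (not part of the patch): the hunk above is truncated mid-expression, so for readers following the switch from a lang flag to an explicit prefix_char argument, here is a minimal, self-contained sketch of this style of choice scoring. It is not the committed implementation; the helper name score_choices and the direct model/tokenizer arguments are illustrative, and it only assumes left-padded inputs and four single-letter options.

import torch

CHOICES = ["A", "B", "C", "D"]

def score_choices(model, tokenizer, batch_input, prefix_char):
    # Compare the logits of the candidate letters (each tokenized together with
    # prefix_char, e.g. " A" for English or "\nA" for Chinese) at the last
    # position of each left-padded sequence, then pick the most probable one.
    with torch.no_grad():
        logits = model(**batch_input).logits        # (batch, seq_len, vocab)
    last_logits = logits[:, -1, :]                  # left padding => answer slot is the last token
    choice_ids = [
        tokenizer.encode(prefix_char + ch, add_special_tokens=False)[-1] for ch in CHOICES
    ]
    choice_logits = torch.stack([last_logits[:, idx] for idx in choice_ids], dim=-1)
    probs = torch.nn.functional.softmax(choice_logits, dim=-1)
    return [CHOICES[i] for i in torch.argmax(probs, dim=-1).tolist()]
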
@@ -108,7 +108,8 @@ def evaluate(
     split: Optional[Literal["validation", "test"]] = "validation",
     lang: Optional[Literal["zh", "en"]] = "zh",
     n_shot: Optional[int] = 5,
-    batch_size: Optional[int] = 4
+    batch_size: Optional[int] = 4,
+    save_name: Optional[str] = None
 ):
     with open(os.path.join(dataset_dir, task, "mapping.json"), "r", encoding="utf-8") as f:
         categorys = json.load(f)

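Note (not part of the patch): with the new save_name parameter in place, the entry point can be driven through the CLI shown in the header comment or called directly from Python. A hedged sketch follows; the task id "cmmlu" is an assumption based on the commit title, and the import path of this script is illustrative.

from evaluate import evaluate  # hypothetical import of this script

evaluate(
    model_name_or_path="path_to_model",
    template="vanilla",
    task="cmmlu",              # assumed task id for the newly added benchmark
    split="validation",
    lang="zh",
    n_shot=5,
    batch_size=4,
    save_name="cmmlu_result",  # new in this commit: writes cmmlu_result.json and cmmlu_result.log
)
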
@@ -119,25 +120,25 @@ def evaluate(
         checkpoint_dir=checkpoint_dir,
         template=template
     ))
-    chat_model.tokenizer.padding_side = "left"
     eval_template = eval_templates[lang]
+    assert chat_model.tokenizer.padding_side == "left", "only left-padded tensor can be accepted."

     category_corrects: Dict[str, np.ndarray] = {
-        "STEM": np.array([], dtype="bool"),
-        "Social Sciences": np.array([], dtype="bool"),
-        "Humanities": np.array([], dtype="bool"),
-        "Other": np.array([], dtype="bool")
+        subj: np.array([], dtype="bool") for subj in ["STEM", "Social Sciences", "Humanities", "Other"]
     }
     overall_corrects = np.array([], dtype="bool")

-    pbar = tqdm(categorys.keys())
+    pbar = tqdm(categorys.keys(), desc="Processing subjects", position=0)
+    results = {}
     for subject in pbar:
         pbar.set_postfix_str(categorys[subject]["name"])
         inputs, labels = [], []
         dataset = load_dataset(os.path.join(dataset_dir, task), subject)
         for i in range(len(dataset[split])):
+            support_set = dataset["train"].shuffle().select(range(min(n_shot, len(dataset["train"]))))
             query, resp, history = eval_template.format_example(
                 target_data=dataset[split][i],
-                support_set=dataset["train"].select(range(min(n_shot, len(dataset["train"])))),
+                support_set=support_set,
                 subject_name=categorys[subject]["name"],
                 use_history=chat_model.template.use_history
             )

@@ -154,23 +155,33 @@ def evaluate(
             labels.append(resp)

         outputs = []
-        for i in range(0, len(inputs), batch_size):
+        for i in trange(0, len(inputs), batch_size, desc="Processing batches", position=1, leave=False):
             batch_input = chat_model.tokenizer.pad(
                 inputs[i : i + batch_size],
                 return_attention_mask=True,
                 return_tensors="pt"
             ).to(chat_model.model.device)
-            preds = batch_inference(chat_model, batch_input, lang)
+            preds = batch_inference(chat_model, batch_input, eval_template.prefix)
             outputs += preds

         corrects = (np.array(outputs) == np.array(labels))
         category_name = categorys[subject]["category"]
         category_corrects[category_name] = np.concatenate([category_corrects[category_name], corrects], axis=0)
         overall_corrects = np.concatenate([overall_corrects, corrects], axis=0)
+        results[subject] = {str(i): outputs[i] for i in range(len(outputs))}

-    print("Average accuracy: {:.2f}".format(100 * np.mean(overall_corrects)))
+    score_info = "Average accuracy: {:.2f}".format(100 * np.mean(overall_corrects))
     for category_name, category_correct in category_corrects.items():
-        print(" {} - {:.2f}".format(category_name, 100 * np.mean(category_correct)))
+        if len(category_correct):
+            score_info += "\n{:>16}: {:.2f}".format(category_name, 100 * np.mean(category_correct))
+
+    print(score_info)
+    if save_name is not None:
+        with open(save_name + ".json", "w", encoding="utf-8", newline="\n") as f:
+            json.dump(results, f, indent=2)
+
+        with open(save_name + ".log", "w", encoding="utf-8", newline="\n") as f:
+            f.write(score_info)


 if __name__ == "__main__":
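
Note (not part of the patch): a small sketch of reading back the files produced when save_name is given, assuming save_name="result". The .json file maps each subject to its per-question predictions keyed by example index, and the .log file holds the same accuracy summary that gets printed.

import json

# "result" must match the save_name passed to evaluate.py
with open("result.json", "r", encoding="utf-8") as f:
    results = json.load(f)          # {subject: {"0": "A", "1": "C", ...}, ...}
print("subjects evaluated:", len(results))

with open("result.log", "r", encoding="utf-8") as f:
    print(f.read())                 # "Average accuracy: ..." plus per-category lines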