Former-commit-id: 7ec64588c541422875adfdaf5692a27d05b96cb9
This commit is contained in:
hiyouga
2024-01-19 21:44:32 +08:00
parent 384f0e7678
commit 0868d5c550
4 changed files with 18 additions and 11 deletions

View File

@@ -4,6 +4,7 @@ import os
import json
import torch
import numpy as np
import inspect
from tqdm import tqdm, trange
from typing import Any, Dict, List, Optional
@@ -53,13 +54,18 @@ class Evaluator:
pbar = tqdm(categorys.keys(), desc="Processing subjects", position=0)
results = {}
for subject in pbar:
if "trust_remote_code" in inspect.signature(load_dataset).parameters: # for datasets==2.16.0
kwargs = {"trust_remote_code": True}
else:
kwargs = {}
dataset = load_dataset(
path=os.path.join(self.eval_args.task_dir, self.eval_args.task),
name=subject,
cache_dir=self.model_args.cache_dir,
download_mode=self.eval_args.download_mode,
token=self.model_args.hf_hub_token,
trust_remote_code=True
**kwargs
)
pbar.set_postfix_str(categorys[subject]["name"])
inputs, outputs, labels = [], [], []