add tool test

Former-commit-id: 639a355a9ceb2e4585b81aea71fc810f4b510776
This commit is contained in:
hiyouga
2024-01-18 10:26:26 +08:00
parent a423274fd9
commit d8affd3967
9 changed files with 63 additions and 37 deletions

View File

@@ -65,17 +65,17 @@ class Evaluator:
inputs, outputs, labels = [], [], []
for i in trange(len(dataset[self.data_args.split]), desc="Formatting batches", position=1, leave=False):
support_set = dataset["train"].shuffle().select(range(min(self.eval_args.n_shot, len(dataset["train"]))))
query, resp, history = self.eval_template.format_example(
messages = self.eval_template.format_example(
target_data=dataset[self.data_args.split][i],
support_set=support_set,
subject_name=categorys[subject]["name"],
use_history=self.template.use_history
subject_name=categorys[subject]["name"]
)
input_ids, _ = self.template.encode_oneturn(
tokenizer=self.tokenizer, query=query, resp=resp, history=history
tokenizer=self.tokenizer, messages=messages
)
inputs.append({"input_ids": input_ids, "attention_mask": [1] * len(input_ids)})
labels.append(resp)
labels.append(messages[-1]["content"])
for i in trange(0, len(inputs), self.eval_args.batch_size, desc="Predicting batches", position=1, leave=False):
batch_input = self.tokenizer.pad(

View File

@@ -2,6 +2,7 @@ from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Tuple
from ..extras.constants import CHOICES
from ..data import Role
if TYPE_CHECKING:
from datasets import Dataset
@@ -28,20 +29,23 @@ class EvalTemplate:
support_set: "Dataset",
subject_name: str,
use_history: bool
) -> Tuple[str, str, List[Tuple[str, str]]]:
query, resp = self.parse_example(target_data)
history = [self.parse_example(support_set[k]) for k in range(len(support_set))]
) -> List[Dict[str, str]]:
messages = []
for k in range(len(support_set)):
prompt, response = self.parse_example(support_set[k])
messages.append({"role": Role.USER, "content": prompt})
messages.append({"role": Role.ASSISTANT, "content": response})
if len(history):
temp = history.pop(0)
history.insert(0, (self.system.format(subject=subject_name) + temp[0], temp[1]))
else:
query = self.system.format(subject=subject_name) + query
prompt, response = self.parse_example(target_data)
messages.append({"role": Role.USER, "content": prompt})
messages.append({"role": Role.ASSISTANT, "content": response})
messages[0]["content"] = self.system.format(subject=subject_name) + messages[0]["content"]
if not use_history:
query = "\n\n".join(["".join(item) for item in history] + [query])
history = []
return query.strip(), resp, history
messages = [{"role": Role.USER, "content": "\n\n".join([message["content"] for message in messages[:-1]])}]
return messages
eval_templates: Dict[str, "EvalTemplate"] = {}