add tool test

Former-commit-id: 639a355a9ceb2e4585b81aea71fc810f4b510776
This commit is contained in:
hiyouga
2024-01-18 10:26:26 +08:00
parent a423274fd9
commit d8affd3967
9 changed files with 63 additions and 37 deletions

View File

@@ -65,17 +65,17 @@ class Evaluator:
inputs, outputs, labels = [], [], []
for i in trange(len(dataset[self.data_args.split]), desc="Formatting batches", position=1, leave=False):
support_set = dataset["train"].shuffle().select(range(min(self.eval_args.n_shot, len(dataset["train"]))))
query, resp, history = self.eval_template.format_example(
messages = self.eval_template.format_example(
target_data=dataset[self.data_args.split][i],
support_set=support_set,
subject_name=categorys[subject]["name"],
use_history=self.template.use_history
subject_name=categorys[subject]["name"]
)
input_ids, _ = self.template.encode_oneturn(
tokenizer=self.tokenizer, query=query, resp=resp, history=history
tokenizer=self.tokenizer, messages=messages
)
inputs.append({"input_ids": input_ids, "attention_mask": [1] * len(input_ids)})
labels.append(resp)
labels.append(messages[-1]["content"])
for i in trange(0, len(inputs), self.eval_args.batch_size, desc="Predicting batches", position=1, leave=False):
batch_input = self.tokenizer.pad(