add tool test

Former-commit-id: 639a355a9ceb2e4585b81aea71fc810f4b510776
This commit is contained in:
hiyouga
2024-01-18 10:26:26 +08:00
parent a423274fd9
commit d8affd3967
9 changed files with 63 additions and 37 deletions

View File

@@ -2,6 +2,7 @@ from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Tuple
from ..extras.constants import CHOICES
from ..data import Role
if TYPE_CHECKING:
from datasets import Dataset
@@ -28,20 +29,23 @@ class EvalTemplate:
support_set: "Dataset",
subject_name: str,
use_history: bool
) -> Tuple[str, str, List[Tuple[str, str]]]:
query, resp = self.parse_example(target_data)
history = [self.parse_example(support_set[k]) for k in range(len(support_set))]
) -> List[Dict[str, str]]:
messages = []
for k in range(len(support_set)):
prompt, response = self.parse_example(support_set[k])
messages.append({"role": Role.USER, "content": prompt})
messages.append({"role": Role.ASSISTANT, "content": response})
if len(history):
temp = history.pop(0)
history.insert(0, (self.system.format(subject=subject_name) + temp[0], temp[1]))
else:
query = self.system.format(subject=subject_name) + query
prompt, response = self.parse_example(target_data)
messages.append({"role": Role.USER, "content": prompt})
messages.append({"role": Role.ASSISTANT, "content": response})
messages[0]["content"] = self.system.format(subject=subject_name) + messages[0]["content"]
if not use_history:
query = "\n\n".join(["".join(item) for item in history] + [query])
history = []
return query.strip(), resp, history
messages = [{"role": Role.USER, "content": "\n\n".join([message["content"] for message in messages[:-1]])}]
return messages
eval_templates: Dict[str, "EvalTemplate"] = {}