mirror of
https://github.com/hiyouga/LlamaFactory.git
synced 2026-02-01 08:13:38 +00:00
format style
Former-commit-id: 53b683531b83cd1d19de97c6565f16c1eca6f5e1
This commit is contained in:
@@ -1,8 +1,9 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Dict, List, Tuple
|
||||
|
||||
from ..extras.constants import CHOICES
|
||||
from ..data import Role
|
||||
from ..extras.constants import CHOICES
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from datasets import Dataset
|
||||
@@ -10,24 +11,17 @@ if TYPE_CHECKING:
|
||||
|
||||
@dataclass
|
||||
class EvalTemplate:
|
||||
|
||||
system: str
|
||||
choice: str
|
||||
answer: str
|
||||
prefix: str
|
||||
|
||||
def parse_example(
|
||||
self,
|
||||
example: Dict[str, str]
|
||||
) -> Tuple[str, str]:
|
||||
def parse_example(self, example: Dict[str, str]) -> Tuple[str, str]:
|
||||
candidates = [self.choice.format(choice=ch, content=example[ch]) for ch in CHOICES if ch in example]
|
||||
return "".join([example["question"]] + candidates + [self.answer]), example["answer"]
|
||||
|
||||
def format_example(
|
||||
self,
|
||||
target_data: Dict[str, str],
|
||||
support_set: "Dataset",
|
||||
subject_name: str
|
||||
self, target_data: Dict[str, str], support_set: "Dataset", subject_name: str
|
||||
) -> List[Dict[str, str]]:
|
||||
messages = []
|
||||
for k in range(len(support_set)):
|
||||
@@ -45,19 +39,8 @@ class EvalTemplate:
|
||||
eval_templates: Dict[str, "EvalTemplate"] = {}
|
||||
|
||||
|
||||
def register_eval_template(
|
||||
name: str,
|
||||
system: str,
|
||||
choice: str,
|
||||
answer: str,
|
||||
prefix: str
|
||||
) -> None:
|
||||
eval_templates[name] = EvalTemplate(
|
||||
system=system,
|
||||
choice=choice,
|
||||
answer=answer,
|
||||
prefix=prefix
|
||||
)
|
||||
def register_eval_template(name: str, system: str, choice: str, answer: str, prefix: str) -> None:
|
||||
eval_templates[name] = EvalTemplate(system=system, choice=choice, answer=answer, prefix=prefix)
|
||||
|
||||
|
||||
def get_eval_template(name: str) -> "EvalTemplate":
|
||||
@@ -71,7 +54,7 @@ register_eval_template(
|
||||
system="The following are multiple choice questions (with answers) about {subject}.\n\n",
|
||||
choice="\n{choice}. {content}",
|
||||
answer="\nAnswer: ",
|
||||
prefix=" "
|
||||
prefix=" ",
|
||||
)
|
||||
|
||||
|
||||
@@ -80,5 +63,5 @@ register_eval_template(
|
||||
system="以下是中国关于{subject}考试的单项选择题,请选出其中的正确答案。\n\n",
|
||||
choice="\n{choice}. {content}",
|
||||
answer="\n答案:",
|
||||
prefix="\n"
|
||||
prefix="\n",
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user