simplify code

Former-commit-id: d3731754ab7c28ae81f60784e0e4213f279d93fe
This commit is contained in:
hiyouga
2023-07-20 15:08:57 +08:00
parent 7a43ff3d89
commit 5ba0b80e5c
18 changed files with 52 additions and 136 deletions

View File

@@ -3,30 +3,13 @@ from dataclasses import dataclass
@dataclass
class Format:
class Template:
prefix: str
prompt: str
sep: str
use_history: bool
templates: Dict[str, Format] = {}
@dataclass
class Template:
name: str
def __post_init__(self):
if self.name in templates:
self.prefix = templates[self.name].prefix
self.prompt = templates[self.name].prompt
self.sep = templates[self.name].sep
self.use_history = templates[self.name].use_history
else:
raise ValueError("Template {} does not exist.".format(self.name))
def get_prompt(
self, query: str, history: Optional[List[Tuple[str, str]]] = None, prefix: Optional[str] = ""
) -> str:
@@ -61,8 +44,11 @@ class Template:
return convs[:-1] # drop last
templates: Dict[str, Template] = {}
def register_template(name: str, prefix: str, prompt: str, sep: str, use_history: bool) -> None:
templates[name] = Format(
templates[name] = Template(
prefix=prefix,
prompt=prompt,
sep=sep,
@@ -70,6 +56,12 @@ def register_template(name: str, prefix: str, prompt: str, sep: str, use_history
)
def get_template(name: str) -> Template:
template = templates.get(name, None)
assert template is not None, "Template {} does not exist.".format(name)
return template
r"""
Supports language model inference without histories.
"""