simplify code
Former-commit-id: d3731754ab7c28ae81f60784e0e4213f279d93fe
This commit is contained in:
@@ -3,30 +3,13 @@ from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class Format:
|
||||
class Template:
|
||||
|
||||
prefix: str
|
||||
prompt: str
|
||||
sep: str
|
||||
use_history: bool
|
||||
|
||||
|
||||
templates: Dict[str, Format] = {}
|
||||
|
||||
|
||||
@dataclass
|
||||
class Template:
|
||||
|
||||
name: str
|
||||
|
||||
def __post_init__(self):
|
||||
if self.name in templates:
|
||||
self.prefix = templates[self.name].prefix
|
||||
self.prompt = templates[self.name].prompt
|
||||
self.sep = templates[self.name].sep
|
||||
self.use_history = templates[self.name].use_history
|
||||
else:
|
||||
raise ValueError("Template {} does not exist.".format(self.name))
|
||||
|
||||
def get_prompt(
|
||||
self, query: str, history: Optional[List[Tuple[str, str]]] = None, prefix: Optional[str] = ""
|
||||
) -> str:
|
||||
@@ -61,8 +44,11 @@ class Template:
|
||||
return convs[:-1] # drop last
|
||||
|
||||
|
||||
templates: Dict[str, Template] = {}
|
||||
|
||||
|
||||
def register_template(name: str, prefix: str, prompt: str, sep: str, use_history: bool) -> None:
|
||||
templates[name] = Format(
|
||||
templates[name] = Template(
|
||||
prefix=prefix,
|
||||
prompt=prompt,
|
||||
sep=sep,
|
||||
@@ -70,6 +56,12 @@ def register_template(name: str, prefix: str, prompt: str, sep: str, use_history
|
||||
)
|
||||
|
||||
|
||||
def get_template(name: str) -> Template:
|
||||
template = templates.get(name, None)
|
||||
assert template is not None, "Template {} does not exist.".format(name)
|
||||
return template
|
||||
|
||||
|
||||
r"""
|
||||
Supports language model inference without histories.
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user