add prompt template class

Former-commit-id: 3d7e3a38d00aa5d9664824093043951af8c3f707
This commit is contained in:
hiyouga
2023-06-07 11:55:25 +08:00
parent 3da427a665
commit b9feb82e4e
8 changed files with 67 additions and 40 deletions

View File

@@ -1,16 +1,45 @@
def prompt_template_alpaca(query, history=None):
prompt = ""
if history:
for old_query, response in history:
prompt += "Human:{}\nAssistant:{}\n".format(old_query, response)
prompt += "Human:{}\nAssistant:".format(query)
return prompt
from typing import Optional
from dataclasses import dataclass
def prompt_template_ziya(query, history=None):
prompt = ""
if history:
for old_query, response in history:
prompt += "<human>:{}\n<bot>:{}\n".format(old_query, response)
prompt += "<human>:{}\n<bot>:".format(query)
return prompt
@dataclass
class Template:
name: str
def get_prompt(self, query: str, history: Optional[list] = None, prefix: Optional[str] = "") -> str:
return getattr(self, "_format_{}".format(self.name))(query, history, prefix)
def _format_alpaca(self, query: str, history: Optional[list], prefix: Optional[str] = "") -> str:
if prefix:
prompt = prefix
else:
prompt = "Below is an instruction that describes a task. "
prompt += "Write a response that appropriately completes the request.\n"
prompt += "Instruction:\n"
if history:
for old_query, response in history:
prompt += "Human:{}\nAssistant:{}\n".format(old_query, response)
prompt += "Human:{}\nAssistant:".format(query)
return prompt
def _format_vicuna(self, query: str, history: Optional[list], prefix: Optional[str] = "") -> str:
if prefix:
prompt = prefix
else:
prompt = "A chat between a curious user and an artificial intelligence assistant. "
prompt += "The assistant gives helpful, detailed, and polite answers to the user's questions. "
if history:
for old_query, response in history:
prompt += "USER: {} ASSISTANT: {}</s>".format(old_query, response)
prompt += "USER: {} ASSISTANT:".format(query)
return prompt
def _format_ziya(self, query: str, history: Optional[list], prefix: Optional[str] = "") -> str:
prompt = prefix
if history:
for old_query, response in history:
prompt += "<human>:{}\n<bot>:{}\n".format(old_query, response)
prompt += "<human>:{}\n<bot>:".format(query)
return prompt