fix system prompt

Former-commit-id: 411e775aa939bdd154a3f1e92921ede90d989f18
This commit is contained in:
hiyouga
2023-08-16 01:35:52 +08:00
parent ca9a494d0c
commit baa709674f
15 changed files with 170 additions and 152 deletions

View File

@@ -26,7 +26,7 @@ class WebChatModel(ChatModel):
finetuning_type: str,
quantization_bit: str,
template: str,
source_prefix: str
system_prompt: str
):
if self.model is not None:
yield ALERTS["err_exists"][lang]
@@ -55,7 +55,7 @@ class WebChatModel(ChatModel):
finetuning_type=finetuning_type,
quantization_bit=int(quantization_bit) if quantization_bit != "None" else None,
template=template,
source_prefix=source_prefix
system_prompt=system_prompt
)
super().__init__(args)
@@ -73,7 +73,7 @@ class WebChatModel(ChatModel):
chatbot: List[Tuple[str, str]],
query: str,
history: List[Tuple[str, str]],
prefix: str,
system: str,
max_new_tokens: int,
top_p: float,
temperature: float
@@ -81,7 +81,7 @@ class WebChatModel(ChatModel):
chatbot.append([query, ""])
response = ""
for new_text in self.stream_chat(
query, history, prefix, max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature
query, history, system, max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature
):
response += new_text
response = self.postprocess(response)

View File

@@ -17,7 +17,7 @@ def create_chat_box(
with gr.Row():
with gr.Column(scale=4):
prefix = gr.Textbox(show_label=False)
system = gr.Textbox(show_label=False)
query = gr.Textbox(show_label=False, lines=8)
submit_btn = gr.Button(variant="primary")
@@ -31,7 +31,7 @@ def create_chat_box(
submit_btn.click(
chat_model.predict,
[chatbot, query, history, prefix, max_new_tokens, top_p, temperature],
[chatbot, query, history, system, max_new_tokens, top_p, temperature],
[chatbot, history],
show_progress=True
).then(
@@ -41,7 +41,7 @@ def create_chat_box(
clear_btn.click(lambda: ([], []), outputs=[chatbot, history], show_progress=True)
return chat_box, chatbot, history, dict(
prefix=prefix,
system=system,
query=query,
submit_btn=submit_btn,
clear_btn=clear_btn,

View File

@@ -52,7 +52,7 @@ def create_eval_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict
top_elems["finetuning_type"],
top_elems["quantization_bit"],
top_elems["template"],
top_elems["source_prefix"],
top_elems["system_prompt"],
dataset_dir,
dataset,
max_source_length,

View File

@@ -28,7 +28,7 @@ def create_infer_tab(top_elems: Dict[str, "Component"]) -> Dict[str, "Component"
top_elems["finetuning_type"],
top_elems["quantization_bit"],
top_elems["template"],
top_elems["source_prefix"]
top_elems["system_prompt"]
],
[info_box]
).then(

View File

@@ -28,7 +28,7 @@ def create_top() -> Dict[str, "Component"]:
with gr.Row():
quantization_bit = gr.Dropdown(choices=["None", "8", "4"], value="None", scale=1)
template = gr.Dropdown(choices=list(templates.keys()), value="default", scale=1)
source_prefix = gr.Textbox(scale=2)
system_prompt = gr.Textbox(scale=2)
lang.change(save_config, [lang, model_name, model_path])
@@ -62,5 +62,5 @@ def create_top() -> Dict[str, "Component"]:
advanced_tab=advanced_tab,
quantization_bit=quantization_bit,
template=template,
source_prefix=source_prefix
system_prompt=system_prompt
)

View File

@@ -101,7 +101,7 @@ def create_train_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dic
top_elems["finetuning_type"],
top_elems["quantization_bit"],
top_elems["template"],
top_elems["source_prefix"],
top_elems["system_prompt"],
training_stage,
dataset_dir,
dataset,

View File

@@ -77,7 +77,7 @@ LOCALES = {
"info": "构建提示词时使用的模板"
}
},
"source_prefix": {
"system_prompt": {
"en": {
"label": "System prompt (optional)",
"info": "A sequence used as the default system prompt."
@@ -455,7 +455,7 @@ LOCALES = {
"value": "模型未加载,请先加载模型。"
}
},
"prefix": {
"system": {
"en": {
"placeholder": "System prompt (optional)"
},

View File

@@ -69,7 +69,7 @@ class Runner:
finetuning_type: str,
quantization_bit: str,
template: str,
source_prefix: str,
system_prompt: str,
training_stage: str,
dataset_dir: str,
dataset: List[str],
@@ -114,7 +114,7 @@ class Runner:
finetuning_type=finetuning_type,
quantization_bit=int(quantization_bit) if quantization_bit != "None" else None,
template=template,
source_prefix=source_prefix,
system_prompt=system_prompt,
dataset_dir=dataset_dir,
dataset=",".join(dataset),
max_source_length=max_source_length,
@@ -170,7 +170,7 @@ class Runner:
finetuning_type: str,
quantization_bit: str,
template: str,
source_prefix: str,
system_prompt: str,
dataset_dir: str,
dataset: List[str],
max_source_length: int,
@@ -198,7 +198,7 @@ class Runner:
finetuning_type=finetuning_type,
quantization_bit=int(quantization_bit) if quantization_bit != "None" else None,
template=template,
source_prefix=source_prefix,
system_prompt=system_prompt,
dataset_dir=dataset_dir,
dataset=",".join(dataset),
max_source_length=max_source_length,