add ziya prompt template

Former-commit-id: 321e44ac54a91260cf00a4caa1991708814473fc
This commit is contained in:
hiyouga
2023-06-03 19:05:51 +08:00
parent 5eef8d5d98
commit fa850ae6e5
6 changed files with 79 additions and 24 deletions

View File

@@ -7,14 +7,12 @@ import torch
import mdtex2html
import gradio as gr
from utils import ModelArguments, FinetuningArguments, load_pretrained, get_logits_processor
from transformers import HfArgumentParser
from utils import load_pretrained, prepare_infer_args, get_logits_processor
from transformers.utils.versions import require_version
require_version("gradio==3.27.0", "To fix: pip install gradio==3.27.0") # higher version may cause problems
parser = HfArgumentParser((ModelArguments, FinetuningArguments))
model_args, finetuning_args = parser.parse_args_into_dataclasses()
model_args, data_args, finetuning_args = prepare_infer_args()
model, tokenizer = load_pretrained(model_args, finetuning_args)
if torch.cuda.device_count() > 1:
@@ -75,17 +73,31 @@ def parse_text(text): # copy from https://github.com/GaiZhenbiao/ChuanhuChatGPT
return text
def format_example(query):
def format_example_alpaca(query, history):
prompt = "Below is an instruction that describes a task. "
prompt += "Write a response that appropriately completes the request.\n"
prompt += "Instruction:\nHuman: {}\nAssistant: ".format(query)
prompt += "Instruction:\n"
for old_query, response in history:
prompt += "Human: {}\nAssistant: {}\n".format(old_query, response)
prompt += "Human: {}\nAssistant:".format(query)
return prompt
def format_example_ziya(query, history):
prompt = ""
for old_query, response in history:
prompt += "<human>: {}\n<bot>: {}\n".format(old_query, response)
prompt += "<human>: {}\n<bot>:".format(query)
return prompt
format_example = format_example_alpaca if data_args.prompt_template == "alpaca" else format_example_ziya
def predict(input, chatbot, max_length, top_p, temperature, history):
chatbot.append((parse_text(input), ""))
input_ids = tokenizer([format_example(input)], return_tensors="pt")["input_ids"]
input_ids = tokenizer([format_example(input, history)], return_tensors="pt")["input_ids"]
input_ids = input_ids.to(model.device)
gen_kwargs = {
"do_sample": True,