fix inference, add prompt template
Former-commit-id: 3940e50c71472b210bbc1b01248bf85a191c4065
This commit is contained in:
@@ -7,7 +7,14 @@ import mdtex2html
|
||||
import gradio as gr
|
||||
|
||||
from threading import Thread
|
||||
from utils import load_pretrained, prepare_infer_args, get_logits_processor
|
||||
from utils import (
|
||||
load_pretrained,
|
||||
prepare_infer_args,
|
||||
get_logits_processor,
|
||||
prompt_template_alpaca,
|
||||
prompt_template_ziya
|
||||
)
|
||||
|
||||
from transformers import TextIteratorStreamer
|
||||
from transformers.utils.versions import require_version
|
||||
|
||||
@@ -18,26 +25,7 @@ require_version("gradio>=3.30.0", "To fix: pip install gradio>=3.30.0")
|
||||
model_args, data_args, finetuning_args = prepare_infer_args()
|
||||
model, tokenizer = load_pretrained(model_args, finetuning_args)
|
||||
|
||||
|
||||
def format_example_alpaca(query, history):
|
||||
prompt = "Below is an instruction that describes a task. "
|
||||
prompt += "Write a response that appropriately completes the request.\n"
|
||||
prompt += "Instruction:\n"
|
||||
for old_query, response in history:
|
||||
prompt += "Human: {}\nAssistant: {}\n".format(old_query, response)
|
||||
prompt += "Human: {}\nAssistant:".format(query)
|
||||
return prompt
|
||||
|
||||
|
||||
def format_example_ziya(query, history):
|
||||
prompt = ""
|
||||
for old_query, response in history:
|
||||
prompt += "<human>: {}\n<bot>: {}\n".format(old_query, response)
|
||||
prompt += "<human>: {}\n<bot>:".format(query)
|
||||
return prompt
|
||||
|
||||
|
||||
format_example = format_example_alpaca if data_args.prompt_template == "alpaca" else format_example_ziya
|
||||
format_example = prompt_template_alpaca if data_args.prompt_template == "alpaca" else prompt_template_ziya
|
||||
streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user