match api with OpenAI format

Former-commit-id: 9cbe2b98b024393817e86ff8e3ff1636776fa263
This commit is contained in:
hiyouga
2023-06-22 20:27:00 +08:00
parent 84b66010a3
commit 391bf1c699
4 changed files with 144 additions and 135 deletions

View File

@@ -77,7 +77,7 @@ def parse_text(text): # copy from https://github.com/GaiZhenbiao/ChuanhuChatGPT
return text
def predict(query, chatbot, max_length, top_p, temperature, history):
def predict(query, chatbot, max_new_tokens, top_p, temperature, history):
chatbot.append((parse_text(query), ""))
input_ids = tokenizer([prompt_template.get_prompt(query, history, source_prefix)], return_tensors="pt")["input_ids"]
@@ -85,17 +85,15 @@ def predict(query, chatbot, max_length, top_p, temperature, history):
streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
gen_kwargs = {
gen_kwargs = generating_args.to_dict()
gen_kwargs.update({
"input_ids": input_ids,
"do_sample": generating_args.do_sample,
"top_p": top_p,
"temperature": temperature,
"num_beams": generating_args.num_beams,
"max_length": max_length,
"repetition_penalty": generating_args.repetition_penalty,
"max_new_tokens": max_new_tokens,
"logits_processor": get_logits_processor(),
"streamer": streamer
}
})
thread = Thread(target=model.generate, kwargs=gen_kwargs)
thread.start()
@@ -137,13 +135,16 @@ with gr.Blocks() as demo:
with gr.Column(scale=1):
emptyBtn = gr.Button("Clear History")
max_length = gr.Slider(0, 2048, value=1024, step=1.0, label="Maximum length", interactive=True)
top_p = gr.Slider(0, 1, value=generating_args.top_p, step=0.01, label="Top P", interactive=True)
temperature = gr.Slider(0, 1.5, value=generating_args.temperature, step=0.01, label="Temperature", interactive=True)
max_new_tokens = gr.Slider(10, 2048, value=generating_args.max_new_tokens, step=1.0,
label="Maximum new tokens", interactive=True)
top_p = gr.Slider(0.01, 1, value=generating_args.top_p, step=0.01,
label="Top P", interactive=True)
temperature = gr.Slider(0.01, 1.5, value=generating_args.temperature, step=0.01,
label="Temperature", interactive=True)
history = gr.State([])
submitBtn.click(predict, [user_input, chatbot, max_length, top_p, temperature, history], [chatbot, history], show_progress=True)
submitBtn.click(predict, [user_input, chatbot, max_new_tokens, top_p, temperature, history], [chatbot, history], show_progress=True)
submitBtn.click(reset_user_input, [], [user_input])
emptyBtn.click(reset_state, outputs=[chatbot, history], show_progress=True)