fix inference, add prompt template

Former-commit-id: 3940e50c71472b210bbc1b01248bf85a191c4065
This commit is contained in:
hiyouga
2023-06-07 10:52:35 +08:00
parent 12094c1db5
commit 3da427a665
6 changed files with 44 additions and 81 deletions

View File

@@ -23,7 +23,9 @@ from fastapi import FastAPI, Request
from utils import (
load_pretrained,
prepare_infer_args,
get_logits_processor
get_logits_processor,
prompt_template_alpaca,
prompt_template_ziya
)
@@ -96,23 +98,6 @@ async def create_item(request: Request):
if __name__ == "__main__":
model_args, data_args, finetuning_args = prepare_infer_args()
model, tokenizer = load_pretrained(model_args, finetuning_args)
def format_example_alpaca(query, history):
prompt = "Below is an instruction that describes a task. "
prompt += "Write a response that appropriately completes the request.\n"
prompt += "Instruction:\n"
for old_query, response in history:
prompt += "Human: {}\nAssistant: {}\n".format(old_query, response)
prompt += "Human: {}\nAssistant:".format(query)
return prompt
def format_example_ziya(query, history):
prompt = ""
for old_query, response in history:
prompt += "<human>: {}\n<bot>: {}\n".format(old_query, response)
prompt += "<human>: {}\n<bot>:".format(query)
return prompt
format_example = format_example_alpaca if data_args.prompt_template == "alpaca" else format_example_ziya
format_example = prompt_template_alpaca if data_args.prompt_template == "alpaca" else prompt_template_ziya
uvicorn.run(app, host='0.0.0.0', port=8000, workers=1)