support multiturn training like FastChat
Former-commit-id: 629cafb1a09924e82d7ea1f9fba318d3f5593196
This commit is contained in:
@@ -25,7 +25,6 @@ model_args, data_args, finetuning_args, generating_args = prepare_infer_args()
|
||||
model, tokenizer = load_pretrained(model_args, finetuning_args)
|
||||
|
||||
prompt_template = Template(data_args.prompt_template)
|
||||
streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
|
||||
|
||||
|
||||
def postprocess(self, y):
|
||||
@@ -82,9 +81,12 @@ def predict(query, chatbot, max_length, top_p, temperature, history):
|
||||
|
||||
input_ids = tokenizer([prompt_template.get_prompt(query, history)], return_tensors="pt")["input_ids"]
|
||||
input_ids = input_ids.to(model.device)
|
||||
|
||||
streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
|
||||
|
||||
gen_kwargs = {
|
||||
"input_ids": input_ids,
|
||||
"do_sample": True,
|
||||
"do_sample": generating_args.do_sample,
|
||||
"top_p": top_p,
|
||||
"temperature": temperature,
|
||||
"num_beams": generating_args.num_beams,
|
||||
@@ -93,8 +95,10 @@ def predict(query, chatbot, max_length, top_p, temperature, history):
|
||||
"logits_processor": get_logits_processor(),
|
||||
"streamer": streamer
|
||||
}
|
||||
|
||||
thread = Thread(target=model.generate, kwargs=gen_kwargs)
|
||||
thread.start()
|
||||
|
||||
response = ""
|
||||
for new_text in streamer:
|
||||
response += new_text
|
||||
|
||||
Reference in New Issue
Block a user