support multiturn training like FastChat

Former-commit-id: 629cafb1a09924e82d7ea1f9fba318d3f5593196

Author: hiyouga
Date: 2023-06-14 22:27:39 +08:00
Commit: aa1bb8a9a2 (parent: 6f655e3916)
5 changed files with 166 additions and 108 deletions

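On the inference side, the prompt is rebuilt from the full conversation on every turn via prompt_template.get_prompt(query, history) in the hunk below. The template implementation itself is not part of this hunk, so the following is only a minimal sketch of a FastChat-style multi-turn template; the identifiers Template and get_prompt mirror the ones used here, but the system prompt, role names, and separators are placeholders.

```python
from dataclasses import dataclass
from typing import List, Optional, Tuple


@dataclass
class Template:
    """Hypothetical multi-turn prompt template in the FastChat conversation style."""
    name: str
    system: str = "A chat between a curious user and an artificial intelligence assistant."
    sep: str = "\n"

    def get_prompt(self, query: str, history: Optional[List[Tuple[str, str]]] = None) -> str:
        # Replay every past (user, assistant) turn before the current query so the
        # model sees the whole conversation, then leave the assistant slot open.
        turns = [self.system]
        for user_msg, bot_msg in history or []:
            turns.append(f"Human: {user_msg}{self.sep}Assistant: {bot_msg}")
        turns.append(f"Human: {query}{self.sep}Assistant:")
        return self.sep.join(turns)


# Example: the second turn carries the first exchange as history.
template = Template("vicuna")
print(template.get_prompt("And in winter?", history=[("How warm is Rome in summer?", "Around 30°C.")]))
```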

@@ -25,7 +25,6 @@ model_args, data_args, finetuning_args, generating_args = prepare_infer_args()
 model, tokenizer = load_pretrained(model_args, finetuning_args)
 prompt_template = Template(data_args.prompt_template)
-streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
 def postprocess(self, y):
@@ -82,9 +81,12 @@ def predict(query, chatbot, max_length, top_p, temperature, history):
     input_ids = tokenizer([prompt_template.get_prompt(query, history)], return_tensors="pt")["input_ids"]
     input_ids = input_ids.to(model.device)
+    streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
     gen_kwargs = {
         "input_ids": input_ids,
-        "do_sample": True,
+        "do_sample": generating_args.do_sample,
         "top_p": top_p,
         "temperature": temperature,
         "num_beams": generating_args.num_beams,
@@ -93,8 +95,10 @@ def predict(query, chatbot, max_length, top_p, temperature, history):
"logits_processor": get_logits_processor(),
"streamer": streamer
}
thread = Thread(target=model.generate, kwargs=gen_kwargs)
thread.start()
response = ""
for new_text in streamer:
response += new_text
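
For reference, the generation loop above follows the standard transformers streaming pattern: model.generate runs in a background thread and pushes decoded tokens into a TextIteratorStreamer, while the caller iterates over the streamer to assemble the response incrementally. Below is a self-contained sketch of that pattern; the gpt2 checkpoint, prompt, and sampling values are placeholders, not the repo's defaults.

```python
from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder checkpoint
model = AutoModelForCausalLM.from_pretrained("gpt2")

prompt = "Human: What does a streamer do?\nAssistant:"
input_ids = tokenizer([prompt], return_tensors="pt")["input_ids"].to(model.device)

# One streamer per request: generate() feeds it from the worker thread,
# and iterating over it yields decoded text chunks as they become available.
streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
gen_kwargs = {
    "input_ids": input_ids,
    "do_sample": True,
    "top_p": 0.9,
    "temperature": 0.7,
    "max_new_tokens": 64,
    "streamer": streamer,
}

thread = Thread(target=model.generate, kwargs=gen_kwargs)
thread.start()

response = ""
for new_text in streamer:  # blocks until the next chunk arrives or the timeout expires
    response += new_text
    print(new_text, end="", flush=True)
thread.join()
```

Creating the streamer inside predict, as the hunk above now does, keeps each request's token stream isolated instead of sharing one module-level streamer across conversation turns.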