fix bug in inference
Former-commit-id: df9b41af4401006b8040eb53c44dd290b604e0eb
@@ -29,8 +29,8 @@ def main():
         return prompt

     def predict(query, history: list):
-        inputs = tokenizer([format_example(query)], return_tensors="pt")
-        inputs = inputs.to(model.device)
+        input_ids = tokenizer([format_example(query)], return_tensors="pt")["input_ids"]
+        input_ids = input_ids.to(model.device)
         gen_kwargs = {
             "do_sample": True,
             "top_p": 0.9,
@@ -41,8 +41,8 @@ def main():
             "repetition_penalty": 1.5
         }
         with torch.no_grad():
-            generation_output = model.generate(**inputs, **gen_kwargs)
-            outputs = generation_output.tolist()[0][len(inputs["input_ids"][0]):]
+            generation_output = model.generate(input_ids=input_ids, **gen_kwargs)
+            outputs = generation_output.tolist()[0][len(input_ids[0]):]
         response = tokenizer.decode(outputs, skip_special_tokens=True)
         history = history + [(query, response)]
         return response, history
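
For readability, here is a sketch of the predict function as it reads after this fix, stitched together from the two hunks above. It assumes tokenizer, model, and format_example are defined earlier in the script (the function is actually nested inside main()), and that the inference bug came from passing the tokenizer's full output dict to model.generate instead of only the input_ids tensor; the sampling arguments that fall between the two hunks are elided.

def predict(query, history: list):
    # Keep only the token ids; passing the whole tokenizer output to
    # generate() is presumably what broke inference in the old version.
    input_ids = tokenizer([format_example(query)], return_tensors="pt")["input_ids"]
    input_ids = input_ids.to(model.device)
    gen_kwargs = {
        "do_sample": True,
        "top_p": 0.9,
        # ... other sampling arguments from the script, omitted here ...
        "repetition_penalty": 1.5
    }
    with torch.no_grad():
        generation_output = model.generate(input_ids=input_ids, **gen_kwargs)
    # Decode only the newly generated tokens, skipping the prompt.
    outputs = generation_output.tolist()[0][len(input_ids[0]):]
    response = tokenizer.decode(outputs, skip_special_tokens=True)
    history = history + [(query, response)]
    return response, history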