support batch infer in vllm
Former-commit-id: 3ef5ed3b9a44eed2f7e3ff221dfc343d0a97c0b5
@@ -161,12 +161,9 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
                 preds[i] = np.concatenate((preds[i][pad_len[0] :], preds[i][: pad_len[0]]), axis=-1)
 
         decoded_inputs = self.tokenizer.batch_decode(dataset["input_ids"], skip_special_tokens=True)
-        decoded_labels = self.tokenizer.batch_decode(labels, skip_special_tokens=True)
         decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True)
+        decoded_labels = self.tokenizer.batch_decode(labels, skip_special_tokens=True)
 
-        with open(output_prediction_file, "w", encoding="utf-8") as writer:
-            res: List[str] = []
-            for text, label, pred in zip(decoded_inputs, decoded_labels, decoded_preds):
-                res.append(json.dumps({"prompt": text, "label": label, "predict": pred}, ensure_ascii=False))
-
-            writer.write("\n".join(res))
+        with open(output_prediction_file, "w", encoding="utf-8") as f:
+            for text, pred, label in zip(decoded_inputs, decoded_preds, decoded_labels):
+                f.write(json.dumps({"prompt": text, "predict": pred, "label": label}, ensure_ascii=False) + "\n")
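
For context, the `preds[i]` line kept at the top of the hunk rotates each left-padded prediction so the pad tokens end up at the tail before decoding. A minimal standalone sketch of that rotation; `PAD_ID` and the token values are made up here (the trainer itself uses `self.tokenizer.pad_token_id`):

    import numpy as np

    PAD_ID = 0  # hypothetical pad token id

    # a left-padded prediction: [pad, pad, tok, tok, tok]
    pred = np.array([PAD_ID, PAD_ID, 11, 12, 13])

    # indices of the non-pad tokens; the first one marks the pad length
    pad_len = np.nonzero(pred != PAD_ID)[0]
    if len(pad_len):
        # move the leading pads to the tail: [11, 12, 13, pad, pad]
        pred = np.concatenate((pred[pad_len[0]:], pred[:pad_len[0]]), axis=-1)

    print(pred)  # [11 12 13  0  0]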
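
The rewritten block streams one JSON object per line (JSON Lines) instead of accumulating everything in `res` and joining, so memory stays flat and every record, including the last, ends with a newline. A sketch of reading the file back, assuming an illustrative `generated_predictions.jsonl` path:

    import json

    # stream the JSONL prediction file back, one record per line
    with open("generated_predictions.jsonl", encoding="utf-8") as f:
        for line in f:
            record = json.loads(line)
            print(record["prompt"], record["predict"], record["label"])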