support batch infer in vllm

Former-commit-id: 3ef5ed3b9a44eed2f7e3ff221dfc343d0a97c0b5
hiyouga
2024-12-04 13:50:00 +00:00
parent 53edd62f8b
commit c1768cfb14
29 changed files with 148 additions and 407 deletions
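
The heart of the change is replacing the old trainer-style, per-run setup with a single vLLM generate() call over pre-tokenized prompts. A condensed sketch of that pattern, stripped of the LLaMA-Factory plumbing shown in the diff below (the model name, token ids, and adapter path here are placeholders, not taken from this commit):

from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

# Hypothetical pre-tokenized prompts; the actual script builds these from get_dataset().
prompt_token_ids = [[1, 15043, 3186], [1, 1724, 338, 1781]]
inputs = [{"prompt_token_ids": ids} for ids in prompt_token_ids]

sampling_params = SamplingParams(temperature=0.95, top_p=0.7, top_k=50, max_tokens=1024)
llm = LLM(model="meta-llama/Llama-2-7b-hf", enable_lora=True)  # placeholder model

# One generate() call schedules the whole batch through the vLLM engine;
# pass a LoRARequest only when an adapter is loaded (path is a placeholder).
results = llm.generate(
    inputs,
    sampling_params,
    lora_request=LoRARequest("default", 1, "path/to/adapter"),
)
preds = [result.outputs[0].text for result in results]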

scripts/vllm_infer.py

@@ -12,127 +12,106 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import json
-import os
-import sys
-from typing import List
 
 import fire
+from transformers import Seq2SeqTrainingArguments
 from vllm import LLM, SamplingParams
 from vllm.lora.request import LoRARequest
 
 from llamafactory.data import get_dataset, get_template_and_fix_tokenizer
 from llamafactory.extras.constants import IGNORE_INDEX
-from llamafactory.hparams import get_train_args
+from llamafactory.extras.misc import get_device_count
+from llamafactory.hparams import get_infer_args
 from llamafactory.model import load_tokenizer
 
-max_tokens = 2048
-
 
 def vllm_infer(
-    model_name_or_path: str = None,
+    model_name_or_path: str,
     adapter_name_or_path: str = None,
+    dataset: str = "alpaca_en_demo",
     dataset_dir: str = "data",
-    eval_dataset: str = None,
     template: str = "default",
-    max_sample: int = None,
-    preprocessing_num_workers: int = 16,
-    predict_with_generate: bool = True,
-    do_predict: bool = True,
-    temperature: float = 0.7,
+    cutoff_len: int = 2048,
+    max_samples: int = None,
+    vllm_config: str = "{}",
+    save_name: str = "generated_predictions.jsonl",
+    temperature: float = 0.95,
     top_p: float = 0.7,
-    top_k: float = 50,
-    output_dir: str = "output",
+    top_k: int = 50,
+    max_new_tokens: int = 1024,
+    repetition_penalty: float = 1.0,
 ):
-    if len(sys.argv) == 1:
-        model_args, data_args, training_args, finetuning_args, generating_args = (
-            get_train_args(
-                dict(
-                    model_name_or_path=model_name_or_path,
-                    adapter_name_or_path=adapter_name_or_path,
-                    dataset_dir=dataset_dir,
-                    eval_dataset=eval_dataset,
-                    template=template,
-                    max_sample=max_sample,
-                    preprocessing_num_workers=preprocessing_num_workers,
-                    predict_with_generate=predict_with_generate,
-                    do_predict=do_predict,
-                    temperature=temperature,
-                    top_p=top_p,
-                    top_k=top_k,
-                    output_dir=output_dir,
-                )
-            )
-        )
-    else:
-        model_args, data_args, training_args, finetuning_args, generating_args = (
-            get_train_args()
-        )
+    r"""
+    Performs batch generation using the vLLM engine, which supports tensor parallelism.
+    Usage: python vllm_infer.py --model_name_or_path meta-llama/Llama-2-7b-hf --template llama --dataset alpaca_en_demo
+    """
+    model_args, data_args, _, generating_args = get_infer_args(
+        dict(
+            model_name_or_path=model_name_or_path,
+            adapter_name_or_path=adapter_name_or_path,
+            dataset=dataset,
+            dataset_dir=dataset_dir,
+            template=template,
+            cutoff_len=cutoff_len,
+            max_samples=max_samples,
+            vllm_config=vllm_config,
+            temperature=temperature,
+            top_p=top_p,
+            top_k=top_k,
+            max_new_tokens=max_new_tokens,
+            repetition_penalty=repetition_penalty,
+        )
+    )
 
-    tokenizer = load_tokenizer(model_args)["tokenizer"]
+    training_args = Seq2SeqTrainingArguments(output_dir="dummy_dir", predict_with_generate=True)
+    tokenizer_module = load_tokenizer(model_args)
+    tokenizer = tokenizer_module["tokenizer"]
     template = get_template_and_fix_tokenizer(tokenizer, data_args)
-    eval_dataset = get_dataset(
-        template, model_args, data_args, training_args, finetuning_args.stage, tokenizer
-    )["eval_dataset"]
-    prompts = [item["input_ids"] for item in eval_dataset]
-    prompts = tokenizer.batch_decode(prompts, skip_special_tokens=False)
-    labels = [
-        list(filter(lambda x: x != IGNORE_INDEX, item["labels"]))
-        for item in eval_dataset
-    ]
-    labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
+    dataset = get_dataset(template, model_args, data_args, training_args, "ppo", **tokenizer_module)["train_dataset"]
+
+    inputs, prompts, labels = [], [], []
+    for sample in dataset:
+        inputs.append({"prompt_token_ids": sample["input_ids"]})
+        prompts.append(tokenizer.decode(sample["input_ids"], skip_special_tokens=False))
+        labels.append(
+            tokenizer.decode(list(filter(lambda x: x != IGNORE_INDEX, sample["labels"])), skip_special_tokens=False)
+        )
+
     sampling_params = SamplingParams(
+        repetition_penalty=generating_args.repetition_penalty or 1.0,  # repetition_penalty must be > 0
         temperature=generating_args.temperature,
+        top_p=generating_args.top_p or 1.0,  # top_p must be > 0
        top_k=generating_args.top_k,
-        top_p=generating_args.top_p,
-        max_tokens=max_tokens,
-        stop_token_ids=[tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids,
+        max_tokens=generating_args.max_new_tokens,
+        skip_special_tokens=False,
     )
-    if model_args.adapter_name_or_path:
-        if isinstance(model_args.adapter_name_or_path, list):
-            lora_path = model_args.adapter_name_or_path[0]
-        else:
-            lora_path = model_args.adapter_name_or_path
-        lora_requests = LoRARequest("lora_adapter_0", 0, lora_path=lora_path)
-        enable_lora = True
+
+    if model_args.adapter_name_or_path is not None:
+        lora_request = LoRARequest("default", 1, model_args.adapter_name_or_path[0])
     else:
-        lora_requests = None
-        enable_lora = False
+        lora_request = None
 
-    llm = LLM(
-        model=model_args.model_name_or_path,
-        trust_remote_code=True,
-        tokenizer=model_args.model_name_or_path,
-        enable_lora=enable_lora,
-    )
-    outputs = llm.generate(prompts, sampling_params, lora_request=lora_requests)
-
-    if not os.path.exists(training_args.output_dir):
-        os.makedirs(training_args.output_dir, exist_ok=True)
-    output_prediction_file = os.path.join(
-        training_args.output_dir, "generated_predictions.jsonl"
-    )
-    with open(output_prediction_file, "w", encoding="utf-8") as writer:
-        res: List[str] = []
-        for text, pred, label in zip(prompts, outputs, labels):
-            res.append(
-                json.dumps(
-                    {"prompt": text, "predict": pred.outputs[0].text, "label": label},
-                    ensure_ascii=False,
-                )
-            )
-        writer.write("\n".join(res))
+    engine_args = {
+        "model": model_args.model_name_or_path,
+        "trust_remote_code": True,
+        "dtype": model_args.infer_dtype,
+        "tensor_parallel_size": get_device_count() or 1,
+        "disable_log_stats": True,
+        "enable_lora": model_args.adapter_name_or_path is not None,
+    }
+    if isinstance(model_args.vllm_config, dict):
+        engine_args.update(model_args.vllm_config)
+
+    results = LLM(**engine_args).generate(inputs, sampling_params, lora_request=lora_request)
+    preds = [result.outputs[0].text for result in results]
+    with open(save_name, "w", encoding="utf-8") as f:
+        for text, pred, label in zip(prompts, preds, labels):
+            f.write(json.dumps({"prompt": text, "predict": pred, "label": label}, ensure_ascii=False) + "\n")
+
+    print("*" * 70)
+    print(f"{len(prompts)} generated results have been saved at {save_name}.")
+    print("*" * 70)
 
 
 if __name__ == "__main__":
     fire.Fire(vllm_infer)
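
The new script writes one JSON object per line with "prompt", "predict", and "label" fields. A minimal consumer of that file, assuming the default save_name and a non-empty run (the exact-match metric is an illustration, not part of the commit):

import json

# Load the predictions written by vllm_infer (default save_name).
with open("generated_predictions.jsonl", encoding="utf-8") as f:
    records = [json.loads(line) for line in f]

# Illustrative sanity check: how often the prediction matches the label verbatim.
exact_match = sum(r["predict"].strip() == r["label"].strip() for r in records) / len(records)
print(f"{len(records)} samples, exact match = {exact_match:.2%}")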