support batch infer in vllm
Former-commit-id: 3ef5ed3b9a44eed2f7e3ff221dfc343d0a97c0b5
@@ -12,127 +12,106 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
 import json
-import os
-import sys
-from typing import List
 
 import fire
+from transformers import Seq2SeqTrainingArguments
 from vllm import LLM, SamplingParams
 from vllm.lora.request import LoRARequest
 
 from llamafactory.data import get_dataset, get_template_and_fix_tokenizer
 from llamafactory.extras.constants import IGNORE_INDEX
-from llamafactory.hparams import get_train_args
+from llamafactory.extras.misc import get_device_count
+from llamafactory.hparams import get_infer_args
 from llamafactory.model import load_tokenizer
 
 
-max_tokens = 2048
-
-
 def vllm_infer(
-    model_name_or_path: str = None,
+    model_name_or_path: str,
     adapter_name_or_path: str = None,
+    dataset: str = "alpaca_en_demo",
     dataset_dir: str = "data",
-    eval_dataset: str = None,
     template: str = "default",
-    max_sample: int = None,
-    preprocessing_num_workers: int = 16,
-    predict_with_generate: bool = True,
-    do_predict: bool = True,
-    temperature: float = 0.7,
+    cutoff_len: int = 2048,
+    max_samples: int = None,
+    vllm_config: str = "{}",
+    save_name: str = "generated_predictions.jsonl",
+    temperature: float = 0.95,
     top_p: float = 0.7,
-    top_k: float = 50,
-    output_dir: str = "output",
+    top_k: int = 50,
+    max_new_tokens: int = 1024,
+    repetition_penalty: float = 1.0,
 ):
-    if len(sys.argv) == 1:
-        model_args, data_args, training_args, finetuning_args, generating_args = (
-            get_train_args(
-                dict(
-                    model_name_or_path=model_name_or_path,
-                    adapter_name_or_path=adapter_name_or_path,
-                    dataset_dir=dataset_dir,
-                    eval_dataset=eval_dataset,
-                    template=template,
-                    max_sample=max_sample,
-                    preprocessing_num_workers=preprocessing_num_workers,
-                    predict_with_generate=predict_with_generate,
-                    do_predict=do_predict,
-                    temperature=temperature,
-                    top_p=top_p,
-                    top_k=top_k,
-                    output_dir=output_dir,
-                )
-            )
-        )
-    else:
-        model_args, data_args, training_args, finetuning_args, generating_args = (
-            get_train_args()
-        )
+    r"""
+    Performs batch generation using vLLM engine, which supports tensor parallelism.
+    Usage: python vllm_infer.py --model_name_or_path meta-llama/Llama-2-7b-hf --template llama --dataset alpaca_en_demo
+    """
+    model_args, data_args, _, generating_args = get_infer_args(
+        dict(
+            model_name_or_path=model_name_or_path,
+            adapter_name_or_path=adapter_name_or_path,
+            dataset=dataset,
+            dataset_dir=dataset_dir,
+            template=template,
+            cutoff_len=cutoff_len,
+            max_samples=max_samples,
+            vllm_config=vllm_config,
+            temperature=temperature,
+            top_p=top_p,
+            top_k=top_k,
+            max_new_tokens=max_new_tokens,
+            repetition_penalty=repetition_penalty,
+        )
+    )
 
-    tokenizer = load_tokenizer(model_args)["tokenizer"]
+    training_args = Seq2SeqTrainingArguments(output_dir="dummy_dir", predict_with_generate=True)
+    tokenizer_module = load_tokenizer(model_args)
+    tokenizer = tokenizer_module["tokenizer"]
     template = get_template_and_fix_tokenizer(tokenizer, data_args)
+    dataset = get_dataset(template, model_args, data_args, training_args, "ppo", **tokenizer_module)["train_dataset"]
 
-    eval_dataset = get_dataset(
-        template, model_args, data_args, training_args, finetuning_args.stage, tokenizer
-    )["eval_dataset"]
-
-    prompts = [item["input_ids"] for item in eval_dataset]
-    prompts = tokenizer.batch_decode(prompts, skip_special_tokens=False)
-
-    labels = [
-        list(filter(lambda x: x != IGNORE_INDEX, item["labels"]))
-        for item in eval_dataset
-    ]
-    labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
+    inputs, prompts, labels = [], [], []
+    for sample in dataset:
+        inputs.append({"prompt_token_ids": sample["input_ids"]})
+        prompts.append(tokenizer.decode(sample["input_ids"], skip_special_tokens=False))
+        labels.append(
+            tokenizer.decode(list(filter(lambda x: x != IGNORE_INDEX, sample["labels"])), skip_special_tokens=False)
+        )
 
     sampling_params = SamplingParams(
+        repetition_penalty=generating_args.repetition_penalty or 1.0,  # repetition_penalty must > 0
         temperature=generating_args.temperature,
+        top_p=generating_args.top_p or 1.0,  # top_p must > 0
         top_k=generating_args.top_k,
-        top_p=generating_args.top_p,
-        max_tokens=max_tokens,
-        stop_token_ids=[tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids,
+        max_tokens=generating_args.max_new_tokens,
+        skip_special_tokens=False,
     )
 
-    if model_args.adapter_name_or_path:
-        if isinstance(model_args.adapter_name_or_path, list):
-            lora_path = model_args.adapter_name_or_path[0]
-        else:
-            lora_path = model_args.adapter_name_or_path
-
-        lora_requests = LoRARequest("lora_adapter_0", 0, lora_path=lora_path)
-        enable_lora = True
+    if model_args.adapter_name_or_path is not None:
+        lora_request = LoRARequest("default", 1, model_args.adapter_name_or_path[0])
     else:
-        lora_requests = None
-        enable_lora = False
+        lora_request = None
 
-    llm = LLM(
-        model=model_args.model_name_or_path,
-        trust_remote_code=True,
-        tokenizer=model_args.model_name_or_path,
-        enable_lora=enable_lora,
-    )
+    engine_args = {
+        "model": model_args.model_name_or_path,
+        "trust_remote_code": True,
+        "dtype": model_args.infer_dtype,
+        "tensor_parallel_size": get_device_count() or 1,
+        "disable_log_stats": True,
+        "enable_lora": model_args.adapter_name_or_path is not None,
+    }
+    if isinstance(model_args.vllm_config, dict):
+        engine_args.update(model_args.vllm_config)
 
-    outputs = llm.generate(prompts, sampling_params, lora_request=lora_requests)
+    results = LLM(**engine_args).generate(inputs, sampling_params, lora_request=lora_request)
+    preds = [result.outputs[0].text for result in results]
+    with open(save_name, "w", encoding="utf-8") as f:
+        for text, pred, label in zip(prompts, preds, labels):
+            f.write(json.dumps({"prompt": text, "predict": pred, "label": label}, ensure_ascii=False) + "\n")
 
-    if not os.path.exists(training_args.output_dir):
-        os.makedirs(training_args.output_dir, exist_ok=True)
-
-    output_prediction_file = os.path.join(
-        training_args.output_dir, "generated_predictions.jsonl"
-    )
-
-    with open(output_prediction_file, "w", encoding="utf-8") as writer:
-        res: List[str] = []
-        for text, pred, label in zip(prompts, outputs, labels):
-            res.append(
-                json.dumps(
-                    {"prompt": text, "predict": pred.outputs[0].text, "label": label},
-                    ensure_ascii=False,
-                )
-            )
-        writer.write("\n".join(res))
+    print("*" * 70)
+    print(f"{len(prompts)} generated results have been saved at {save_name}.")
+    print("*" * 70)
 
 
 if __name__ == "__main__":
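
Note: the core of this change is that each sample is pre-tokenized with the project template and handed to vLLM as a `{"prompt_token_ids": ...}` dict, so the entire dataset is generated in a single `LLM.generate()` call. Below is a minimal standalone sketch of that pattern; the model id and prompt strings are placeholder assumptions, not part of the commit.

```python
# Minimal sketch of vLLM batch inference with pre-tokenized prompts.
# The model id and prompt strings below are illustrative assumptions.
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams

model_id = "meta-llama/Llama-2-7b-hf"  # placeholder
tokenizer = AutoTokenizer.from_pretrained(model_id)

prompts = ["Give three uses of a paperclip.", "Summarize what tensor parallelism does."]
# One dict per sample; vLLM consumes the token ids directly, no re-tokenization inside.
inputs = [{"prompt_token_ids": tokenizer.encode(p)} for p in prompts]

sampling_params = SamplingParams(temperature=0.95, top_p=0.7, top_k=50, max_tokens=128)
llm = LLM(model=model_id, tensor_parallel_size=1)  # raise tensor_parallel_size for multi-GPU

# The whole batch is scheduled in one call; results come back in input order.
for result in llm.generate(inputs, sampling_params):
    print(result.outputs[0].text)
```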
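The script writes one JSON object per line with `prompt`, `predict`, and `label` keys, so downstream evaluation can stream the file. A small sketch of consuming it follows; the exact-match metric is only an example and not something the commit provides.

```python
# Sketch: load generated_predictions.jsonl written by vllm_infer.py and report a toy metric.
import json

records = []
with open("generated_predictions.jsonl", encoding="utf-8") as f:
    for line in f:
        if line.strip():
            records.append(json.loads(line))

# Exact match is illustrative only; real evaluation would be task-specific.
exact = sum(r["predict"].strip() == r["label"].strip() for r in records)
print(f"{len(records)} samples, exact match rate: {exact / max(len(records), 1):.2%}")
```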