[inference] fix stop token for object detection (#6624)
* fix stop token
* update minicpm data pipeline
* fix npu qlora examples

Former-commit-id: 844919fadaa8a61dfae47020971ea80730b2346f
@@ -50,11 +50,15 @@ def vllm_infer(
     top_k: int = 50,
     max_new_tokens: int = 1024,
     repetition_penalty: float = 1.0,
+    pipeline_parallel_size: int = 1,
 ):
     r"""
     Performs batch generation using vLLM engine, which supports tensor parallelism.
     Usage: python vllm_infer.py --model_name_or_path meta-llama/Llama-2-7b-hf --template llama --dataset alpaca_en_demo
     """
+    if pipeline_parallel_size > get_device_count():
+        raise ValueError("Pipeline parallel size should be smaller than the number of gpus.")
+
     model_args, data_args, _, generating_args = get_infer_args(
         dict(
             model_name_or_path=model_name_or_path,
@@ -107,7 +111,7 @@ def vllm_infer(
         temperature=generating_args.temperature,
         top_p=generating_args.top_p or 1.0,  # top_p must > 0
         top_k=generating_args.top_k,
-        stop_token_ids=[tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids,
+        stop_token_ids=template_obj.get_stop_token_ids(tokenizer),
         max_tokens=generating_args.max_new_tokens,
         skip_special_tokens=False,
     )
@@ -120,7 +124,8 @@ def vllm_infer(
         "model": model_args.model_name_or_path,
         "trust_remote_code": True,
         "dtype": model_args.infer_dtype,
-        "tensor_parallel_size": get_device_count() or 1,
+        "tensor_parallel_size": (get_device_count() // pipeline_parallel_size) or 1,
+        "pipeline_parallel_size": pipeline_parallel_size,
         "disable_log_stats": True,
         "enable_lora": model_args.adapter_name_or_path is not None,
     }
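The stop-token change is the core of the inference fix: instead of treating every additional special token as a terminator, the script now asks the chat template for its stop tokens. Below is a minimal sketch of what a template-driven lookup like `get_stop_token_ids` could do; the real helper lives in LLaMA-Factory's template module, and the `stop_words` attribute used here is an assumption for illustration.

```python
# Sketch only: the actual get_stop_token_ids is defined on LLaMA-Factory's
# template objects; `stop_words` is an assumed attribute holding the
# template's terminator strings.
def get_stop_token_ids(template, tokenizer) -> list[int]:
    # Always stop on the tokenizer's EOS token.
    stop_ids = {tokenizer.eos_token_id}
    # Add only the terminators the template declares (e.g. "<|im_end|>").
    # Blindly using tokenizer.additional_special_tokens_ids, as the old
    # code did, would also stop on unrelated special tokens such as the
    # bounding-box markers used by object-detection templates.
    for word in getattr(template, "stop_words", []):
        stop_ids.add(tokenizer.convert_tokens_to_ids(word))
    return [tid for tid in stop_ids if tid is not None]
```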
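The new `pipeline_parallel_size` argument splits the visible devices between pipeline stages and tensor-parallel ranks, which is why `tensor_parallel_size` is now derived by integer division. A quick sanity check of that arithmetic, with an assumed 8-GPU node standing in for `get_device_count()`:

```python
# Assumed values: 8 visible GPUs, 2 pipeline stages.
device_count = 8             # stand-in for get_device_count()
pipeline_parallel_size = 2   # the new CLI argument

if pipeline_parallel_size > device_count:
    raise ValueError("Pipeline parallel size should be smaller than the number of gpus.")

# 8 GPUs // 2 pipeline stages -> 4-way tensor parallelism per stage;
# the `or 1` guards the CPU-only case where device_count is 0.
tensor_parallel_size = (device_count // pipeline_parallel_size) or 1
assert tensor_parallel_size == 4
```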