[infer] fix vllm args (#7235)
Former-commit-id: 999be5b4512890b8cf4f45874a77e35cf35626f5
@@ -38,7 +38,7 @@ def vllm_infer(
     dataset_dir: str = "data",
     template: str = "default",
     cutoff_len: int = 2048,
-    max_samples: int = None,
+    max_samples: Optional[int] = None,
     vllm_config: str = "{}",
     save_name: str = "generated_predictions.jsonl",
     temperature: float = 0.95,
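
The hunk above fixes a wrong annotation: a parameter that defaults to None cannot be typed plain `int`. A minimal sketch of the rule, assuming nothing beyond the standard library (the function name is hypothetical):

    from typing import Optional

    # `Optional[int]` means `Union[int, None]`; a plain `int` annotation
    # with a `None` default is rejected by type checkers such as mypy.
    def take_samples(max_samples: Optional[int] = None) -> None:
        if max_samples is None:
            print("keeping all samples")
        else:
            print(f"keeping at most {max_samples} samples")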
@@ -46,6 +46,7 @@ def vllm_infer(
     top_k: int = 50,
     max_new_tokens: int = 1024,
     repetition_penalty: float = 1.0,
+    skip_special_tokens: bool = True,
     seed: Optional[int] = None,
     pipeline_parallel_size: int = 1,
     image_max_pixels: int = 768 * 768,
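
The new `skip_special_tokens` flag is threaded through every decode call in the next hunk. A minimal sketch of what it controls, assuming a Hugging Face tokenizer (the model choice is illustrative):

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    ids = tokenizer.encode("hello") + [tokenizer.eos_token_id]

    # With the flag off, special markers stay in the decoded text;
    # with it on, they are stripped.
    print(tokenizer.decode(ids, skip_special_tokens=False))  # hello<|endoftext|>
    print(tokenizer.decode(ids, skip_special_tokens=True))   # hello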
@@ -97,19 +98,21 @@ def vllm_infer(
             multi_modal_data = None

         inputs.append({"prompt_token_ids": sample["input_ids"], "multi_modal_data": multi_modal_data})
-        prompts.append(tokenizer.decode(sample["input_ids"], skip_special_tokens=False))
+        prompts.append(tokenizer.decode(sample["input_ids"], skip_special_tokens=skip_special_tokens))
         labels.append(
-            tokenizer.decode(list(filter(lambda x: x != IGNORE_INDEX, sample["labels"])), skip_special_tokens=False)
+            tokenizer.decode(
+                list(filter(lambda x: x != IGNORE_INDEX, sample["labels"])), skip_special_tokens=skip_special_tokens
+            )
         )

     sampling_params = SamplingParams(
         repetition_penalty=generating_args.repetition_penalty or 1.0,  # repetition_penalty must > 0
         temperature=generating_args.temperature,
         top_p=generating_args.top_p or 1.0,  # top_p must > 0
-        top_k=generating_args.top_k,
+        top_k=generating_args.top_k or -1,  # top_k must > 0
         stop_token_ids=template_obj.get_stop_token_ids(tokenizer),
         max_tokens=generating_args.max_new_tokens,
-        skip_special_tokens=False,
+        skip_special_tokens=skip_special_tokens,
         seed=seed,
     )
     if model_args.adapter_name_or_path is not None:
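
The `or` fallbacks above map falsy user values onto vLLM's accepted ranges: SamplingParams requires top_p and repetition_penalty to be strictly positive, and (in vLLM versions of this vintage) uses top_k=-1 as the "disabled" sentinel, so passing a raw 0 or None through would fail at construction time. A minimal sketch, with illustrative values not taken from the commit:

    from vllm import SamplingParams

    user_top_k = 0     # hypothetical input meaning "no top-k filtering"
    user_top_p = None  # hypothetical input meaning "no nucleus sampling"

    params = SamplingParams(
        temperature=0.95,
        top_p=user_top_p or 1.0,  # must be > 0
        top_k=user_top_k or -1,   # -1 disables top-k
        repetition_penalty=1.0,   # must be > 0
        max_tokens=1024,
        skip_special_tokens=True,
        seed=42,
    )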
@@ -121,6 +124,7 @@ def vllm_infer(
         "model": model_args.model_name_or_path,
         "trust_remote_code": True,
         "dtype": model_args.infer_dtype,
+        "max_model_len": cutoff_len + max_new_tokens,
         "tensor_parallel_size": (get_device_count() // pipeline_parallel_size) or 1,
         "pipeline_parallel_size": pipeline_parallel_size,
         "disable_log_stats": True,
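
The added `max_model_len` entry caps the engine's context window at the prompt budget plus the generation budget, instead of the model's (possibly much larger) default: this guarantees room for max_new_tokens of output after a cutoff_len-token prompt while bounding KV-cache memory. A minimal sketch of the resulting engine construction, assuming an illustrative model name:

    from vllm import LLM

    cutoff_len = 2048
    max_new_tokens = 1024

    llm = LLM(
        model="Qwen/Qwen2.5-0.5B-Instruct",  # hypothetical choice for the sketch
        trust_remote_code=True,
        max_model_len=cutoff_len + max_new_tokens,  # prompt budget + generation budget
        tensor_parallel_size=1,
        pipeline_parallel_size=1,
        disable_log_stats=True,
    )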