[infer] fix vllm args (#7235)

Author: hoshi-hiyouga
Date: 2025-03-11 01:15:35 +08:00
Committed by: GitHub
Commit: 522a3e8493
Parent: 18968405d0
Former-commit-id: 999be5b4512890b8cf4f45874a77e35cf35626f5

4 changed files with 32 additions and 26 deletions


@@ -38,7 +38,7 @@ def vllm_infer(
     dataset_dir: str = "data",
     template: str = "default",
     cutoff_len: int = 2048,
-    max_samples: int = None,
+    max_samples: Optional[int] = None,
     vllm_config: str = "{}",
     save_name: str = "generated_predictions.jsonl",
     temperature: float = 0.95,
@@ -46,6 +46,7 @@ def vllm_infer(
     top_k: int = 50,
     max_new_tokens: int = 1024,
     repetition_penalty: float = 1.0,
+    skip_special_tokens: bool = True,
     seed: Optional[int] = None,
     pipeline_parallel_size: int = 1,
     image_max_pixels: int = 768 * 768,
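With skip_special_tokens exposed as a regular argument (defaulting to True), callers can decide whether special tokens are kept in the decoded prompts, labels, and predictions instead of relying on the previously hard-coded behavior. A minimal sketch of a direct call, assuming the rest of the vllm_infer signature; the model and dataset names below are placeholders, not values from this commit:

    # Hypothetical direct invocation; parameter names outside this diff are assumed.
    vllm_infer(
        model_name_or_path="Qwen/Qwen2.5-7B-Instruct",  # placeholder model
        dataset="alpaca_en_demo",                       # placeholder dataset
        template="default",
        cutoff_len=2048,
        max_new_tokens=1024,
        skip_special_tokens=False,  # keep special tokens in decoded text
        seed=42,
    )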
@@ -97,19 +98,21 @@ def vllm_infer(
             multi_modal_data = None
 
         inputs.append({"prompt_token_ids": sample["input_ids"], "multi_modal_data": multi_modal_data})
-        prompts.append(tokenizer.decode(sample["input_ids"], skip_special_tokens=False))
+        prompts.append(tokenizer.decode(sample["input_ids"], skip_special_tokens=skip_special_tokens))
         labels.append(
-            tokenizer.decode(list(filter(lambda x: x != IGNORE_INDEX, sample["labels"])), skip_special_tokens=False)
+            tokenizer.decode(
+                list(filter(lambda x: x != IGNORE_INDEX, sample["labels"])), skip_special_tokens=skip_special_tokens
+            )
         )
 
     sampling_params = SamplingParams(
         repetition_penalty=generating_args.repetition_penalty or 1.0,  # repetition_penalty must > 0
         temperature=generating_args.temperature,
         top_p=generating_args.top_p or 1.0,  # top_p must > 0
-        top_k=generating_args.top_k,
+        top_k=generating_args.top_k or -1,  # top_k must > 0
         stop_token_ids=template_obj.get_stop_token_ids(tokenizer),
         max_tokens=generating_args.max_new_tokens,
-        skip_special_tokens=False,
+        skip_special_tokens=skip_special_tokens,
         seed=seed,
     )
     if model_args.adapter_name_or_path is not None:
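The "or" fallbacks normalize Hugging Face-style generation values, where 0 or 0.0 often means "disabled", into what vLLM accepts: top_p and repetition_penalty must be strictly positive, and top_k uses -1 rather than 0 as its "no top-k filtering" sentinel. A standalone sketch of the same normalization, with hypothetical input values:

    from vllm import SamplingParams

    # Hypothetical HF-style values where 0 / 0.0 mean "off".
    top_k, top_p, repetition_penalty = 0, 0.0, 0.0

    sampling_params = SamplingParams(
        temperature=0.95,
        top_p=top_p or 1.0,                            # 0.0 -> 1.0 (top-p disabled)
        top_k=top_k or -1,                             # 0 -> -1, vLLM's "no top-k" sentinel
        repetition_penalty=repetition_penalty or 1.0,  # 0.0 -> 1.0 (no penalty)
        max_tokens=1024,
        skip_special_tokens=True,
        seed=42,
    )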
@@ -121,6 +124,7 @@ def vllm_infer(
         "model": model_args.model_name_or_path,
         "trust_remote_code": True,
         "dtype": model_args.infer_dtype,
+        "max_model_len": cutoff_len + max_new_tokens,
         "tensor_parallel_size": (get_device_count() // pipeline_parallel_size) or 1,
         "pipeline_parallel_size": pipeline_parallel_size,
         "disable_log_stats": True,