[data] qwen3 fixes (#8109)

This commit is contained in:
hoshi-hiyouga
2025-05-20 02:00:30 +08:00
committed by GitHub
parent 45030ff803
commit 9b5baa97f0
13 changed files with 197 additions and 160 deletions

View File

@@ -49,6 +49,8 @@ def vllm_infer(
max_new_tokens: int = 1024,
repetition_penalty: float = 1.0,
skip_special_tokens: bool = True,
default_system: Optional[str] = None,
enable_thinking: bool = True,
seed: Optional[int] = None,
pipeline_parallel_size: int = 1,
image_max_pixels: int = 768 * 768,
@@ -74,6 +76,8 @@ def vllm_infer(
cutoff_len=cutoff_len,
max_samples=max_samples,
preprocessing_num_workers=16,
default_system=default_system,
enable_thinking=enable_thinking,
vllm_config=vllm_config,
temperature=temperature,
top_p=top_p,
@@ -127,14 +131,11 @@ def vllm_infer(
lora_request = None
# Store all results in these lists
all_prompts = []
all_preds = []
all_labels = []
all_prompts, all_preds, all_labels = [], [], []
# Add batch process to avoid the issue of too many files opened
for i in tqdm(range(0, len(train_dataset), batch_size), desc="Processing batched inference"):
vllm_inputs, prompts, labels = [], [], []
batch = train_dataset[i : min(i + batch_size, len(train_dataset))]
for j in range(len(batch["input_ids"])):
@@ -176,15 +177,14 @@ def vllm_infer(
)
results = llm.generate(vllm_inputs, sampling_params, lora_request=lora_request)
preds = [result.outputs[0].text for result in results]
# Accumulate results
all_prompts.extend(prompts)
all_preds.extend(preds)
all_labels.extend(labels)
gc.collect()
# Write all results at once outside the loop
with open(save_name, "w", encoding="utf-8") as f:
for text, pred, label in zip(all_prompts, all_preds, all_labels):