[data] fix qwen2vl pos ids (#8387)

Author: Yaowei Zheng
Date: 2025-06-17 00:48:54 +08:00
Committed by: GitHub
Parent: 31874e4f62
Commit: 3a3bae1cfe

7 changed files with 85 additions and 35 deletions
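The change teaches the multimodal data collator to emit Qwen2-VL's multimodal RoPE inputs at collation time: when a model object is available, each batch now carries position_ids of shape (3, batch, seq_len), one channel each for the temporal, height, and width axes, plus per-sample rope_deltas. The excerpt below shows the updated collator test, one of the seven changed files. As a minimal sketch of the idea, assuming a Qwen2-VL model that exposes get_rope_index (true around transformers 4.45; the helper's exact location varies by release), and using a hypothetical helper name for illustration rather than LLaMA-Factory's exact code:

    def attach_qwen2vl_rope_inputs(features: dict, model) -> dict:
        # Hypothetical helper, not the repository's exact implementation.
        if model is not None and hasattr(model, "get_rope_index"):
            # position_ids: (3, batch, seq_len); rope_deltas: (batch, 1).
            features["position_ids"], features["rope_deltas"] = model.get_rope_index(
                input_ids=features["input_ids"],
                image_grid_thw=features.get("image_grid_thw"),
                attention_mask=features["attention_mask"],
            )
        return features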

@@ -16,6 +16,7 @@ import os
 import torch
 from PIL import Image
+from transformers import AutoConfig, AutoModelForVision2Seq
 
 from llamafactory.data import get_template_and_fix_tokenizer
 from llamafactory.data.collator import MultiModalDataCollatorForSeq2Seq, prepare_4d_attention_mask
@@ -72,12 +73,17 @@ def test_base_collator():
 def test_multimodal_collator():
     model_args, data_args, *_ = get_infer_args(
-        {"model_name_or_path": "Qwen/Qwen2-VL-7B-Instruct", "template": "qwen2_vl"}
+        {"model_name_or_path": "Qwen/Qwen2-VL-2B-Instruct", "template": "qwen2_vl"}
     )
     tokenizer_module = load_tokenizer(model_args)
     template = get_template_and_fix_tokenizer(tokenizer_module["tokenizer"], data_args)
+    config = AutoConfig.from_pretrained(model_args.model_name_or_path)
+    with torch.device("meta"):
+        model = AutoModelForVision2Seq.from_config(config)
+
     data_collator = MultiModalDataCollatorForSeq2Seq(
         template=template,
+        model=model,
         pad_to_multiple_of=4,
         label_pad_token_id=IGNORE_INDEX,
         **tokenizer_module,
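
Note the meta-device trick above: constructing the model under torch.device("meta") allocates only tensor metadata, never real weights, so the test gets a model object whose config-driven helpers (such as the rope-index computation) work without downloading or materializing the 2B parameters. A standalone sketch of the pattern, assuming network access to fetch the config:

    import torch
    from transformers import AutoConfig, AutoModelForVision2Seq

    config = AutoConfig.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
    with torch.device("meta"):
        model = AutoModelForVision2Seq.from_config(config)
    assert model.device.type == "meta"  # parameters hold shapes only, no storage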
@@ -107,8 +113,15 @@ def test_multimodal_collator():
         "labels": [
             [0, 1, 2, 3, q, q, q, q, q, q, q, q],
         ],
+        "position_ids": [
+            [[0, 1, 2, 3, 1, 1, 1, 1, 1, 1, 1, 1]],
+            [[0, 1, 2, 3, 1, 1, 1, 1, 1, 1, 1, 1]],
+            [[0, 1, 2, 3, 1, 1, 1, 1, 1, 1, 1, 1]],
+        ],
+        "rope_deltas": [[-8]],
         **tokenizer_module["processor"].image_processor(fake_image),
     }
     assert batch_input.keys() == expected_input.keys()
     for k in batch_input.keys():
         assert batch_input[k].eq(torch.tensor(expected_input[k])).all()
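
The expected values are easy to verify by hand. The three identical rows are the temporal, height, and width channels of the multimodal RoPE, which coincide for this batch; padded positions keep the fill value 1; and, as computed in recent transformers releases, the per-sample delta is max(position_ids) + 1 − padded_seq_len, here 3 + 1 − 12 = −8. A quick check:

    import torch

    # One channel of the expected (3, 1, 12) position ids.
    position_ids = torch.tensor([[0, 1, 2, 3, 1, 1, 1, 1, 1, 1, 1, 1]])
    rope_delta = position_ids.max().item() + 1 - position_ids.size(-1)
    assert rope_delta == -8  # matches expected_input["rope_deltas"]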
@@ -150,3 +163,7 @@ def test_4d_attention_mask():
     )
     assert list(attention_mask_computed.size()) == [2, 1, 6, 6]
     assert torch.all(attention_mask_computed == attention_mask_expected)
+
+
+if __name__ == "__main__":
+    test_multimodal_collator()
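
The new __main__ guard lets the module double as a quick manual check: running the file directly (for example, python tests/data/test_collator.py, assuming the repository's usual test layout) executes only the multimodal collator test, while pytest still discovers every test_* function as before.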