[data] fix qwen2vl pos ids (#8387)
@@ -16,6 +16,7 @@ import os
 
 import torch
 from PIL import Image
+from transformers import AutoConfig, AutoModelForVision2Seq
 
 from llamafactory.data import get_template_and_fix_tokenizer
 from llamafactory.data.collator import MultiModalDataCollatorForSeq2Seq, prepare_4d_attention_mask
@@ -72,12 +73,17 @@ def test_base_collator():
 
 def test_multimodal_collator():
     model_args, data_args, *_ = get_infer_args(
-        {"model_name_or_path": "Qwen/Qwen2-VL-7B-Instruct", "template": "qwen2_vl"}
+        {"model_name_or_path": "Qwen/Qwen2-VL-2B-Instruct", "template": "qwen2_vl"}
     )
     tokenizer_module = load_tokenizer(model_args)
     template = get_template_and_fix_tokenizer(tokenizer_module["tokenizer"], data_args)
+    config = AutoConfig.from_pretrained(model_args.model_name_or_path)
+    with torch.device("meta"):
+        model = AutoModelForVision2Seq.from_config(config)
+
     data_collator = MultiModalDataCollatorForSeq2Seq(
         template=template,
+        model=model,
         pad_to_multiple_of=4,
         label_pad_token_id=IGNORE_INDEX,
         **tokenizer_module,
@@ -107,8 +113,15 @@ def test_multimodal_collator():
         "labels": [
             [0, 1, 2, 3, q, q, q, q, q, q, q, q],
         ],
+        "position_ids": [
+            [[0, 1, 2, 3, 1, 1, 1, 1, 1, 1, 1, 1]],
+            [[0, 1, 2, 3, 1, 1, 1, 1, 1, 1, 1, 1]],
+            [[0, 1, 2, 3, 1, 1, 1, 1, 1, 1, 1, 1]],
+        ],
+        "rope_deltas": [[-8]],
         **tokenizer_module["processor"].image_processor(fake_image),
     }
+    assert batch_input.keys() == expected_input.keys()
     for k in batch_input.keys():
         assert batch_input[k].eq(torch.tensor(expected_input[k])).all()
 
@@ -150,3 +163,7 @@ def test_4d_attention_mask():
     )
     assert list(attention_mask_computed.size()) == [2, 1, 6, 6]
     assert torch.all(attention_mask_computed == attention_mask_expected)
+
+
+if __name__ == "__main__":
+    test_multimodal_collator()
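The meta-device construction added above presumably exists so the collator receives a real model object (and with it the model's config and position-id helper) without downloading or allocating any weights. A minimal standalone sketch of that pattern, reusing the checkpoint name from the test:

import torch
from transformers import AutoConfig, AutoModelForVision2Seq

# Parameters created inside this context live on the "meta" device, so no
# real weight storage is allocated; only the (small) config is fetched.
config = AutoConfig.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
with torch.device("meta"):
    model = AutoModelForVision2Seq.from_config(config)

print(next(model.parameters()).device)  # prints: meta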
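The expected position_ids and rope_deltas values are consistent with each other: Qwen2-VL uses a three-row position-id tensor of shape (3, batch, seq_len), the padded tail of the sequence shows up as 1s, and the delta equals the largest used position plus one minus the padded length (3 + 1 - 12 = -8). A small illustration of just that arithmetic (an assumption-laden sketch, not the transformers implementation):

import torch

seq_len = 12          # the test batch is padded to a multiple of 4
num_real_tokens = 4   # positions 0..3 are actually used

row = torch.ones(seq_len, dtype=torch.long)             # padded slots stay at 1
row[:num_real_tokens] = torch.arange(num_real_tokens)   # 0, 1, 2, 3
position_ids = row.expand(3, 1, seq_len)                # (3, batch=1, seq_len)

# gap between the "next" rope position and the padded sequence length
rope_deltas = (position_ids.max() + 1 - seq_len).reshape(1, 1)

assert position_ids.tolist() == [[[0, 1, 2, 3, 1, 1, 1, 1, 1, 1, 1, 1]]] * 3
assert rope_deltas.tolist() == [[-8]]

This reproduces the three identical rows and the [[-8]] delta that the collator is now expected to emit for this batch.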