[data] gemma3 plugin pan and scan (#7294)

* gemma3 pan and scan

* add test case

* fix test
This commit is contained in:
hoshi-hiyouga
2025-03-13 23:29:23 +08:00
committed by GitHub
parent 0be0d7796a
commit 93e6184cbe
5 changed files with 65 additions and 4 deletions

View File

@@ -20,6 +20,7 @@ import torch
from PIL import Image
from llamafactory.data.mm_plugin import get_mm_plugin
from llamafactory.extras.packages import is_transformers_version_greater_than
from llamafactory.hparams import get_infer_args
from llamafactory.model import load_tokenizer
@@ -135,6 +136,27 @@ def test_base_plugin():
_check_plugin(**check_inputs)
@pytest.mark.skipif(not HF_TOKEN or not is_transformers_version_greater_than("4.50.0"), reason="Gated model.")
def test_gemma3_plugin():
    """Check that the gemma3 mm plugin expands ``<image>`` placeholders and emits token_type_ids."""
    seq_len = 256  # gemma3 uses 256 soft image tokens per image
    tokenizer_module = _load_tokenizer_module(model_name_or_path="google/gemma-3-4b-it")
    plugin = get_mm_plugin(name="gemma3", image_token="<image_soft_token>")

    # Each <image> placeholder is replaced by the full gemma3 image token sequence.
    replacement = f"\n\n<start_of_image>{'<image_soft_token>' * seq_len}<end_of_image>\n\n"
    expected_messages = []
    for message in MM_MESSAGES:
        expected_messages.append({k: v.replace("<image>", replacement) for k, v in message.items()})

    # Processor output, minus pan-and-scan crop counts, plus the token-type mask the plugin adds.
    mm_inputs = _get_mm_inputs(tokenizer_module["processor"])
    mm_inputs.pop("num_crops")
    mm_inputs["token_type_ids"] = [[0] * 1024]

    check_inputs = {
        "plugin": plugin,
        **tokenizer_module,
        "expected_mm_messages": expected_messages,
        "expected_mm_inputs": mm_inputs,
        "expected_no_mm_inputs": {"token_type_ids": [[0] * 1024]},
    }
    _check_plugin(**check_inputs)
def test_llava_plugin():
image_seqlen = 576
tokenizer_module = _load_tokenizer_module(model_name_or_path="llava-hf/llava-1.5-7b-hf")