[data] gemma3 plugin pan and scan (#7294)
* gemma3 pan and scan * add test case * fix test
This commit is contained in:
@@ -20,6 +20,7 @@ import torch
|
||||
from PIL import Image
|
||||
|
||||
from llamafactory.data.mm_plugin import get_mm_plugin
|
||||
from llamafactory.extras.packages import is_transformers_version_greater_than
|
||||
from llamafactory.hparams import get_infer_args
|
||||
from llamafactory.model import load_tokenizer
|
||||
|
||||
@@ -135,6 +136,27 @@ def test_base_plugin():
|
||||
_check_plugin(**check_inputs)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not HF_TOKEN or not is_transformers_version_greater_than("4.50.0"), reason="Gated model.")
|
||||
def test_gemma3_plugin():
|
||||
image_seqlen = 256
|
||||
tokenizer_module = _load_tokenizer_module(model_name_or_path="google/gemma-3-4b-it")
|
||||
gemma3_plugin = get_mm_plugin(name="gemma3", image_token="<image_soft_token>")
|
||||
image_tokens_expanded = "<image_soft_token>" * image_seqlen
|
||||
check_inputs = {"plugin": gemma3_plugin, **tokenizer_module}
|
||||
check_inputs["expected_mm_messages"] = [
|
||||
{
|
||||
key: value.replace("<image>", f"\n\n<start_of_image>{image_tokens_expanded}<end_of_image>\n\n")
|
||||
for key, value in message.items()
|
||||
}
|
||||
for message in MM_MESSAGES
|
||||
]
|
||||
check_inputs["expected_mm_inputs"] = _get_mm_inputs(tokenizer_module["processor"])
|
||||
check_inputs["expected_mm_inputs"].pop("num_crops")
|
||||
check_inputs["expected_mm_inputs"]["token_type_ids"] = [[0] * 1024]
|
||||
check_inputs["expected_no_mm_inputs"] = {"token_type_ids": [[0] * 1024]}
|
||||
_check_plugin(**check_inputs)
|
||||
|
||||
|
||||
def test_llava_plugin():
|
||||
image_seqlen = 576
|
||||
tokenizer_module = _load_tokenizer_module(model_name_or_path="llava-hf/llava-1.5-7b-hf")
|
||||
|
||||
Reference in New Issue
Block a user