[model] gemma4 (#10346)

This commit is contained in:
Kingsley
2026-04-05 12:10:28 +08:00
committed by GitHub
parent acac63ef35
commit eae6f0b541
8 changed files with 576 additions and 7 deletions

View File

@@ -57,7 +57,7 @@ TEXT_MESSAGES = [
]
# Sample user/assistant exchange whose user turn carries a "<video>" placeholder.
# NOTE(review): this is a diff hunk with its +/- markers stripped — the first user
# entry (typo "viode") looks like the removed line and the second its replacement;
# presumably only the corrected entry exists in the resulting file — confirm.
VIDEO_MESSAGES = [
{"role": "user", "content": "<video>What is in this viode?"},
{"role": "user", "content": "<video>What is in this video?"},
{"role": "assistant", "content": "A cat."},
]
@@ -210,6 +210,34 @@ def test_gemma3_plugin():
_check_plugin(**check_inputs)
@pytest.mark.runs_on(["cpu", "mps"])
@pytest.mark.skipif(not is_transformers_version_greater_than("5.6.0"), reason="Requires transformers>=5.6.0")
def test_gemma4_plugin():
    """Check the ``gemma4`` plugin against hand-built expected messages and inputs.

    The expected user message replaces the image placeholder with the processor's
    ``boi_token``, ``num_image_soft_tokens`` copies of ``image_token`` and ``eoi_token``.
    The expected inputs carry an ``mm_token_type_ids`` mask that is 1 wherever a token
    id in ``BATCH_IDS`` equals the processor's ``image_token_id`` and 0 elsewhere.
    """
    tokenizer_module = _load_tokenizer_module(model_name_or_path="google/gemma-4-31B-it")
    processor = tokenizer_module["processor"]
    gemma4_plugin = get_mm_plugin(name="gemma4", image_token="<|image|>", video_token="<|video|>")
    check_inputs = {"plugin": gemma4_plugin, **tokenizer_module}

    # validate
    mm_inputs = gemma4_plugin._get_mm_inputs(IMAGES, NO_VIDEOS, NO_AUDIOS, processor)
    num_image_soft_tokens = 256  # when we use default max_soft_tokens=280
    image_token = processor.image_token
    boi_token = processor.boi_token
    eoi_token = processor.eoi_token

    # Look the id up once instead of once per token inside the comprehension.
    image_token_id = processor.image_token_id
    expected_mm_type_ids = [[int(token_id == image_token_id) for token_id in token_ids] for token_ids in BATCH_IDS]

    check_inputs["expected_mm_messages"] = [
        {
            "role": "user",
            "content": f"{boi_token}{image_token * num_image_soft_tokens}{eoi_token}What is in this image?",
        },
        {"role": "assistant", "content": "A cat."},
    ]

    # Dropped from the expectation if present; presumably the plugin does not
    # return this key to callers — TODO confirm against the plugin implementation.
    mm_inputs.pop("num_soft_tokens_per_image", None)
    mm_inputs["mm_token_type_ids"] = expected_mm_type_ids
    check_inputs["expected_mm_inputs"] = mm_inputs
    check_inputs["expected_no_mm_inputs"] = {"mm_token_type_ids": expected_mm_type_ids}
    _check_plugin(**check_inputs)
@pytest.mark.runs_on(["cpu", "mps"])
@pytest.mark.skipif(not is_transformers_version_greater_than("4.52.0"), reason="Requires transformers>=4.52.0")
def test_internvl_plugin():