mirror of
https://github.com/hiyouga/LlamaFactory.git
synced 2026-04-06 17:23:08 +00:00
[model] gemma4 (#10346)
This commit is contained in:
@@ -57,7 +57,7 @@ TEXT_MESSAGES = [
|
||||
]
|
||||
|
||||
VIDEO_MESSAGES = [
|
||||
{"role": "user", "content": "<video>What is in this viode?"},
|
||||
{"role": "user", "content": "<video>What is in this video?"},
|
||||
{"role": "assistant", "content": "A cat."},
|
||||
]
|
||||
|
||||
@@ -210,6 +210,34 @@ def test_gemma3_plugin():
|
||||
_check_plugin(**check_inputs)
|
||||
|
||||
|
||||
@pytest.mark.runs_on(["cpu", "mps"])
|
||||
@pytest.mark.skipif(not is_transformers_version_greater_than("5.6.0"), reason="Requires transformers>=5.6.0")
|
||||
def test_gemma4_plugin():
|
||||
tokenizer_module = _load_tokenizer_module(model_name_or_path="google/gemma-4-31B-it")
|
||||
processor = tokenizer_module["processor"]
|
||||
gemma4_plugin = get_mm_plugin(name="gemma4", image_token="<|image|>", video_token="<|video|>")
|
||||
check_inputs = {"plugin": gemma4_plugin, **tokenizer_module}
|
||||
# validate
|
||||
mm_inputs = gemma4_plugin._get_mm_inputs(IMAGES, NO_VIDEOS, NO_AUDIOS, processor)
|
||||
num_image_soft_tokens = 256 # when we use default max_soft_tokens=280
|
||||
image_token = getattr(processor, "image_token")
|
||||
boi_token = getattr(processor, "boi_token")
|
||||
eoi_token = getattr(processor, "eoi_token")
|
||||
|
||||
expected_mm_type_ids = [[int(token_id == getattr(processor, "image_token_id")) for token_id in token_ids] for token_ids in BATCH_IDS]
|
||||
check_inputs["expected_mm_messages"] = [
|
||||
{"role": "user", "content": f"{boi_token}{image_token * num_image_soft_tokens}{eoi_token}What is in this image?"},
|
||||
{"role": "assistant", "content": "A cat."},
|
||||
]
|
||||
for key in ("num_soft_tokens_per_image",):
|
||||
mm_inputs.pop(key, None)
|
||||
|
||||
mm_inputs["mm_token_type_ids"] = expected_mm_type_ids
|
||||
check_inputs["expected_mm_inputs"] = mm_inputs
|
||||
check_inputs["expected_no_mm_inputs"] = {"mm_token_type_ids": expected_mm_type_ids}
|
||||
_check_plugin(**check_inputs)
|
||||
|
||||
|
||||
@pytest.mark.runs_on(["cpu", "mps"])
|
||||
@pytest.mark.skipif(not is_transformers_version_greater_than("4.52.0"), reason="Requires transformers>=4.52.0")
|
||||
def test_internvl_plugin():
|
||||
|
||||
Reference in New Issue
Block a user