[data] fix mllama (#7053)

* fix mllama

* fix test

Former-commit-id: f5af20a63f3d59a6a68d323a7c6f68e551edb3a3
This commit is contained in:
hoshi-hiyouga
2025-02-24 22:05:38 +08:00
committed by GitHub
parent c1d5073bd3
commit 065f7fb5da
2 changed files with 85 additions and 66 deletions

View File

@@ -103,15 +103,17 @@ def _check_plugin(
expected_no_mm_inputs: Dict[str, Any] = {},
) -> None:
# test mm_messages
assert plugin.process_messages(MM_MESSAGES, IMAGES, NO_VIDEOS, NO_AUDIOS, processor) == expected_mm_messages
assert plugin.process_token_ids(INPUT_IDS, LABELS, IMAGES, NO_VIDEOS, NO_AUDIOS, tokenizer, processor) == (
expected_input_ids,
expected_labels,
)
_is_close(
plugin.get_mm_inputs(IMAGES, NO_VIDEOS, NO_AUDIOS, IMGLENS, NO_VIDLENS, NO_AUDLENS, BATCH_IDS, processor),
expected_mm_inputs,
)
if plugin.__class__.__name__ != "BasePlugin":
assert plugin.process_messages(MM_MESSAGES, IMAGES, NO_VIDEOS, NO_AUDIOS, processor) == expected_mm_messages
assert plugin.process_token_ids(INPUT_IDS, LABELS, IMAGES, NO_VIDEOS, NO_AUDIOS, tokenizer, processor) == (
expected_input_ids,
expected_labels,
)
_is_close(
plugin.get_mm_inputs(IMAGES, NO_VIDEOS, NO_AUDIOS, IMGLENS, NO_VIDLENS, NO_AUDLENS, BATCH_IDS, processor),
expected_mm_inputs,
)
# test text_messages
assert plugin.process_messages(TEXT_MESSAGES, NO_IMAGES, NO_VIDEOS, NO_AUDIOS, processor) == TEXT_MESSAGES
assert plugin.process_token_ids(INPUT_IDS, LABELS, NO_IMAGES, NO_VIDEOS, NO_AUDIOS, tokenizer, processor) == (
@@ -128,7 +130,7 @@ def _check_plugin(
def test_base_plugin():
tokenizer_module = _load_tokenizer_module(model_name_or_path=TINY_LLAMA)
base_plugin = get_mm_plugin(name="base", image_token="<image>")
base_plugin = get_mm_plugin(name="base")
check_inputs = {"plugin": base_plugin, **tokenizer_module}
_check_plugin(**check_inputs)