[misc] fix packing and eval plot (#7623)
@@ -20,7 +20,6 @@ import torch
 from PIL import Image
 
 from llamafactory.data.mm_plugin import get_mm_plugin
-from llamafactory.extras.packages import is_transformers_version_greater_than
 from llamafactory.hparams import get_infer_args
 from llamafactory.model import load_tokenizer
 
@@ -35,7 +34,8 @@ if TYPE_CHECKING:
 
 HF_TOKEN = os.getenv("HF_TOKEN")
 
-TINY_LLAMA = os.getenv("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
+TINY_LLAMA3 = os.getenv("TINY_LLAMA3", "llamafactory/tiny-random-Llama-3")
+TINY_LLAMA4 = os.getenv("TINY_LLAMA4", "llamafactory/tiny-random-Llama-4")
 
 MM_MESSAGES = [
     {"role": "user", "content": "<image>What is in this image?"},
@@ -130,13 +130,13 @@ def _check_plugin(
 
 
 def test_base_plugin():
-    tokenizer_module = _load_tokenizer_module(model_name_or_path=TINY_LLAMA)
+    tokenizer_module = _load_tokenizer_module(model_name_or_path=TINY_LLAMA3)
     base_plugin = get_mm_plugin(name="base")
     check_inputs = {"plugin": base_plugin, **tokenizer_module}
     _check_plugin(**check_inputs)
 
 
-@pytest.mark.skipif(not HF_TOKEN or not is_transformers_version_greater_than("4.50.0"), reason="Gated model.")
+@pytest.mark.skipif(not HF_TOKEN, reason="Gated model.")
 def test_gemma3_plugin():
     image_seqlen = 256
     tokenizer_module = _load_tokenizer_module(model_name_or_path="google/gemma-3-4b-it")
@@ -157,6 +157,27 @@ def test_gemma3_plugin():
     _check_plugin(**check_inputs)
 
 
+@pytest.mark.xfail(reason="Unknown error.")
+def test_llama4_plugin():
+    tokenizer_module = _load_tokenizer_module(model_name_or_path=TINY_LLAMA4)
+    processor = tokenizer_module["processor"]
+    llama4_plugin = get_mm_plugin(name="llama4", image_token="<|image|>")
+    check_inputs = {"plugin": llama4_plugin, **tokenizer_module}
+    mm_inputs = _get_mm_inputs(tokenizer_module["processor"])
+    image_height, image_width = mm_inputs["pixel_values"][0].shape[-2:]
+    num_patches_per_chunk = int(
+        (image_height // processor.patch_size) * (image_width // processor.patch_size) // processor.downsample_ratio
+    )
+    aspect_ratios = mm_inputs.pop("aspect_ratios")
+    tokens_for_this_image = processor._prompt_split_image(aspect_ratios[0], num_patches_per_chunk)
+    check_inputs["expected_mm_messages"] = [
+        {key: value.replace("<image>", tokens_for_this_image) for key, value in message.items()}
+        for message in MM_MESSAGES
+    ]
+    check_inputs["expected_mm_inputs"] = mm_inputs
+    _check_plugin(**check_inputs)
+
+
 def test_llava_plugin():
     image_seqlen = 576
     tokenizer_module = _load_tokenizer_module(model_name_or_path="llava-hf/llava-1.5-7b-hf")
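The new test_llama4_plugin derives the expected number of patch tokens per image chunk from the processor's patch size and downsample ratio before expanding each <image> placeholder via processor._prompt_split_image. The standalone sketch below is not part of the commit; it only reproduces that arithmetic, and the 336x336 chunk size, 14-pixel patch size, and 2x downsample ratio are illustrative placeholders rather than the Llama 4 processor's actual configuration.

def expected_patch_tokens(image_height: int, image_width: int, patch_size: int, downsample_ratio: int) -> int:
    # Count patches along each axis, then merge them by the downsample ratio,
    # mirroring the num_patches_per_chunk computation in the test above.
    return (image_height // patch_size) * (image_width // patch_size) // downsample_ratio

# Placeholder example: a 336x336 chunk with 14-pixel patches and 2x downsampling -> 288 tokens.
print(expected_patch_tokens(336, 336, 14, 2))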