lazy image load
Former-commit-id: cdd733b575411e003bc5ffd6560dd8eff8aa09cf
This commit is contained in:
@@ -47,11 +47,15 @@ IMAGES = [Image.new("RGB", (32, 32), (255, 255, 255))]
|
||||
|
||||
NO_IMAGES = []
|
||||
|
||||
IMGLENS = [1]
|
||||
|
||||
NO_IMGLENS = [0]
|
||||
|
||||
INPUT_IDS = [0, 1, 2, 3, 4]
|
||||
|
||||
LABELS = [0, 1, 2, 3, 4]
|
||||
|
||||
FEATURE_SEQLENS = {"token_type_ids": 1024}
|
||||
SEQLENS = [1024]
|
||||
|
||||
|
||||
def _get_mm_inputs(processor: "ProcessorMixin") -> Dict[str, "torch.Tensor"]:
|
||||
@@ -80,11 +84,11 @@ def test_base_plugin():
|
||||
# test mm_messages
|
||||
assert base_plugin.process_messages(MM_MESSAGES, IMAGES, processor) == MM_MESSAGES
|
||||
assert base_plugin.process_token_ids(INPUT_IDS, LABELS, IMAGES, tokenizer, processor) == (INPUT_IDS, LABELS)
|
||||
_is_close(base_plugin.get_mm_inputs(IMAGES, FEATURE_SEQLENS, processor), {})
|
||||
_is_close(base_plugin.get_mm_inputs(IMAGES, IMGLENS, SEQLENS, processor), {})
|
||||
# test text_messages
|
||||
assert base_plugin.process_messages(TEXT_MESSAGES, NO_IMAGES, processor) == TEXT_MESSAGES
|
||||
assert base_plugin.process_token_ids(INPUT_IDS, LABELS, NO_IMAGES, tokenizer, processor) == (INPUT_IDS, LABELS)
|
||||
_is_close(base_plugin.get_mm_inputs(NO_IMAGES, FEATURE_SEQLENS, processor), {})
|
||||
_is_close(base_plugin.get_mm_inputs(NO_IMAGES, NO_IMGLENS, SEQLENS, processor), {})
|
||||
|
||||
|
||||
def test_llava_plugin():
|
||||
@@ -101,11 +105,11 @@ def test_llava_plugin():
|
||||
# test mm_messages
|
||||
assert llava_plugin.process_messages(MM_MESSAGES, IMAGES, processor) == expected_mm_messages
|
||||
assert llava_plugin.process_token_ids(INPUT_IDS, LABELS, IMAGES, tokenizer, processor) == (INPUT_IDS, LABELS)
|
||||
_is_close(llava_plugin.get_mm_inputs(IMAGES, FEATURE_SEQLENS, processor), mm_inputs)
|
||||
_is_close(llava_plugin.get_mm_inputs(IMAGES, IMGLENS, SEQLENS, processor), mm_inputs)
|
||||
# test text_messages
|
||||
assert llava_plugin.process_messages(TEXT_MESSAGES, NO_IMAGES, processor) == TEXT_MESSAGES
|
||||
assert llava_plugin.process_token_ids(INPUT_IDS, LABELS, NO_IMAGES, tokenizer, processor) == (INPUT_IDS, LABELS)
|
||||
_is_close(llava_plugin.get_mm_inputs(NO_IMAGES, FEATURE_SEQLENS, processor), {"pixel_values": None})
|
||||
_is_close(llava_plugin.get_mm_inputs(NO_IMAGES, NO_IMGLENS, SEQLENS, processor), {})
|
||||
|
||||
|
||||
@pytest.mark.skipif(not HF_TOKEN, reason="Gated model.")
|
||||
@@ -128,7 +132,7 @@ def test_paligemma_plugin():
|
||||
expected_input_ids,
|
||||
expected_labels,
|
||||
)
|
||||
_is_close(paligemma_plugin.get_mm_inputs(IMAGES, FEATURE_SEQLENS, processor), mm_inputs)
|
||||
_is_close(paligemma_plugin.get_mm_inputs(IMAGES, IMGLENS, SEQLENS, processor), mm_inputs)
|
||||
# test text_messages
|
||||
assert paligemma_plugin.process_messages(TEXT_MESSAGES, NO_IMAGES, processor) == TEXT_MESSAGES
|
||||
assert paligemma_plugin.process_token_ids(INPUT_IDS, LABELS, NO_IMAGES, tokenizer, processor) == (
|
||||
@@ -136,8 +140,8 @@ def test_paligemma_plugin():
|
||||
LABELS,
|
||||
)
|
||||
_is_close(
|
||||
paligemma_plugin.get_mm_inputs(NO_IMAGES, FEATURE_SEQLENS, processor),
|
||||
{"pixel_values": None, "token_type_ids": [[1] * 1024]},
|
||||
paligemma_plugin.get_mm_inputs(NO_IMAGES, NO_IMGLENS, SEQLENS, processor),
|
||||
{"token_type_ids": [[1] * 1024]},
|
||||
)
|
||||
|
||||
|
||||
@@ -158,11 +162,8 @@ def test_qwen2_vl_plugin():
|
||||
# test mm_messages
|
||||
assert qwen2_vl_plugin.process_messages(MM_MESSAGES, IMAGES, processor) == expected_mm_messages
|
||||
assert qwen2_vl_plugin.process_token_ids(INPUT_IDS, LABELS, IMAGES, tokenizer, processor) == (INPUT_IDS, LABELS)
|
||||
_is_close(qwen2_vl_plugin.get_mm_inputs(IMAGES, FEATURE_SEQLENS, processor), mm_inputs)
|
||||
_is_close(qwen2_vl_plugin.get_mm_inputs(IMAGES, IMGLENS, SEQLENS, processor), mm_inputs)
|
||||
# test text_messages
|
||||
assert qwen2_vl_plugin.process_messages(TEXT_MESSAGES, NO_IMAGES, processor) == TEXT_MESSAGES
|
||||
assert qwen2_vl_plugin.process_token_ids(INPUT_IDS, LABELS, NO_IMAGES, tokenizer, processor) == (INPUT_IDS, LABELS)
|
||||
_is_close(
|
||||
qwen2_vl_plugin.get_mm_inputs(NO_IMAGES, FEATURE_SEQLENS, processor),
|
||||
{"pixel_values": None, "image_grid_thw": None},
|
||||
)
|
||||
_is_close(qwen2_vl_plugin.get_mm_inputs(NO_IMAGES, NO_IMGLENS, SEQLENS, processor), {})
|
||||
|
||||
Reference in New Issue
Block a user