lazy image load

Former-commit-id: cdd733b575411e003bc5ffd6560dd8eff8aa09cf
Author: hiyouga
Date: 2024-09-04 02:27:08 +08:00
Parent: fed7ae5661
Commit: 7056087e92
19 changed files with 353 additions and 366 deletions

@@ -47,11 +47,15 @@ IMAGES = [Image.new("RGB", (32, 32), (255, 255, 255))]
 NO_IMAGES = []
+IMGLENS = [1]
+NO_IMGLENS = [0]
 INPUT_IDS = [0, 1, 2, 3, 4]
 LABELS = [0, 1, 2, 3, 4]
-FEATURE_SEQLENS = {"token_type_ids": 1024}
+SEQLENS = [1024]
 def _get_mm_inputs(processor: "ProcessorMixin") -> Dict[str, "torch.Tensor"]:
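
The assertions in the hunks below go through an _is_close helper whose body is not part of this diff. A minimal sketch of such a comparator, assuming it only needs to match keys and tolerate tensor-valued entries:

from typing import Any, Dict

import torch


def _is_close(batch_a: Dict[str, Any], batch_b: Dict[str, Any]) -> None:
    # Hypothetical comparator (not shown in this commit): same keys on both
    # sides, tensors compared with allclose, everything else with equality.
    assert batch_a.keys() == batch_b.keys()
    for key in batch_a.keys():
        if isinstance(batch_a[key], torch.Tensor):
            assert torch.allclose(batch_a[key], batch_b[key], rtol=1e-4, atol=1e-5)
        else:
            assert batch_a[key] == batch_b[key]
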
@@ -80,11 +84,11 @@ def test_base_plugin():
     # test mm_messages
     assert base_plugin.process_messages(MM_MESSAGES, IMAGES, processor) == MM_MESSAGES
     assert base_plugin.process_token_ids(INPUT_IDS, LABELS, IMAGES, tokenizer, processor) == (INPUT_IDS, LABELS)
-    _is_close(base_plugin.get_mm_inputs(IMAGES, FEATURE_SEQLENS, processor), {})
+    _is_close(base_plugin.get_mm_inputs(IMAGES, IMGLENS, SEQLENS, processor), {})
     # test text_messages
     assert base_plugin.process_messages(TEXT_MESSAGES, NO_IMAGES, processor) == TEXT_MESSAGES
     assert base_plugin.process_token_ids(INPUT_IDS, LABELS, NO_IMAGES, tokenizer, processor) == (INPUT_IDS, LABELS)
-    _is_close(base_plugin.get_mm_inputs(NO_IMAGES, FEATURE_SEQLENS, processor), {})
+    _is_close(base_plugin.get_mm_inputs(NO_IMAGES, NO_IMGLENS, SEQLENS, processor), {})
 def test_llava_plugin():
@@ -101,11 +105,11 @@ def test_llava_plugin():
     # test mm_messages
     assert llava_plugin.process_messages(MM_MESSAGES, IMAGES, processor) == expected_mm_messages
     assert llava_plugin.process_token_ids(INPUT_IDS, LABELS, IMAGES, tokenizer, processor) == (INPUT_IDS, LABELS)
-    _is_close(llava_plugin.get_mm_inputs(IMAGES, FEATURE_SEQLENS, processor), mm_inputs)
+    _is_close(llava_plugin.get_mm_inputs(IMAGES, IMGLENS, SEQLENS, processor), mm_inputs)
     # test text_messages
     assert llava_plugin.process_messages(TEXT_MESSAGES, NO_IMAGES, processor) == TEXT_MESSAGES
     assert llava_plugin.process_token_ids(INPUT_IDS, LABELS, NO_IMAGES, tokenizer, processor) == (INPUT_IDS, LABELS)
-    _is_close(llava_plugin.get_mm_inputs(NO_IMAGES, FEATURE_SEQLENS, processor), {"pixel_values": None})
+    _is_close(llava_plugin.get_mm_inputs(NO_IMAGES, NO_IMGLENS, SEQLENS, processor), {})
 @pytest.mark.skipif(not HF_TOKEN, reason="Gated model.")
@@ -128,7 +132,7 @@ def test_paligemma_plugin():
         expected_input_ids,
         expected_labels,
     )
-    _is_close(paligemma_plugin.get_mm_inputs(IMAGES, FEATURE_SEQLENS, processor), mm_inputs)
+    _is_close(paligemma_plugin.get_mm_inputs(IMAGES, IMGLENS, SEQLENS, processor), mm_inputs)
     # test text_messages
     assert paligemma_plugin.process_messages(TEXT_MESSAGES, NO_IMAGES, processor) == TEXT_MESSAGES
     assert paligemma_plugin.process_token_ids(INPUT_IDS, LABELS, NO_IMAGES, tokenizer, processor) == (
@@ -136,8 +140,8 @@ def test_paligemma_plugin():
         LABELS,
     )
     _is_close(
-        paligemma_plugin.get_mm_inputs(NO_IMAGES, FEATURE_SEQLENS, processor),
-        {"pixel_values": None, "token_type_ids": [[1] * 1024]},
+        paligemma_plugin.get_mm_inputs(NO_IMAGES, NO_IMGLENS, SEQLENS, processor),
+        {"token_type_ids": [[1] * 1024]},
     )
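
For reference, the expected PaliGemma value in the text-only branch above reduces to a single all-ones token_type_ids mask whose length comes from SEQLENS (presumably one mask per sample). A quick restatement using the constant from the first hunk:

SEQLENS = [1024]  # tokens per sample, as defined in the first hunk

# Presumably one mask per sample, all ones, as long as that sample's seqlen.
expected = {"token_type_ids": [[1] * seqlen for seqlen in SEQLENS]}
assert expected == {"token_type_ids": [[1] * 1024]}
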
@@ -158,11 +162,8 @@ def test_qwen2_vl_plugin():
     # test mm_messages
     assert qwen2_vl_plugin.process_messages(MM_MESSAGES, IMAGES, processor) == expected_mm_messages
     assert qwen2_vl_plugin.process_token_ids(INPUT_IDS, LABELS, IMAGES, tokenizer, processor) == (INPUT_IDS, LABELS)
-    _is_close(qwen2_vl_plugin.get_mm_inputs(IMAGES, FEATURE_SEQLENS, processor), mm_inputs)
+    _is_close(qwen2_vl_plugin.get_mm_inputs(IMAGES, IMGLENS, SEQLENS, processor), mm_inputs)
     # test text_messages
     assert qwen2_vl_plugin.process_messages(TEXT_MESSAGES, NO_IMAGES, processor) == TEXT_MESSAGES
     assert qwen2_vl_plugin.process_token_ids(INPUT_IDS, LABELS, NO_IMAGES, tokenizer, processor) == (INPUT_IDS, LABELS)
-    _is_close(
-        qwen2_vl_plugin.get_mm_inputs(NO_IMAGES, FEATURE_SEQLENS, processor),
-        {"pixel_values": None, "image_grid_thw": None},
-    )
+    _is_close(qwen2_vl_plugin.get_mm_inputs(NO_IMAGES, NO_IMGLENS, SEQLENS, processor), {})
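
Across all hunks, get_mm_inputs now takes per-sample image counts (imglens) and sequence lengths (seqlens) instead of the old feature_seqlens mapping. A hypothetical sketch, not taken from this commit, of how a caller might assemble those positional arguments from a tokenized batch:

from typing import Any, Dict, List, Tuple

from PIL import Image


def gather_mm_lens(batch: List[Dict[str, Any]]) -> Tuple[List["Image.Image"], List[int], List[int]]:
    # Flatten images across the batch and record how many belong to each
    # sample, plus each sample's token length (field names are illustrative).
    images, imglens, seqlens = [], [], []
    for example in batch:
        images.extend(example["images"])
        imglens.append(len(example["images"]))
        seqlens.append(len(example["input_ids"]))
    return images, imglens, seqlens

In the tests above these values are fixed constants (IMGLENS = [1], NO_IMGLENS = [0], SEQLENS = [1024]) rather than derived from real examples.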