video datasets
Former-commit-id: 33f28ce82d9e44d2615909250dc56d6a4a03cd99
This commit is contained in:
@@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Any, Dict, Tuple
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Sequence, Tuple
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
@@ -28,6 +28,8 @@ if TYPE_CHECKING:
|
||||
from transformers import PreTrainedTokenizer, ProcessorMixin
|
||||
from transformers.image_processing_utils import BaseImageProcessor
|
||||
|
||||
from llamafactory.data.mm_plugin import BasePlugin
|
||||
|
||||
|
||||
HF_TOKEN = os.environ.get("HF_TOKEN", None)
|
||||
|
||||
@@ -47,10 +49,14 @@ IMAGES = [Image.new("RGB", (32, 32), (255, 255, 255))]
|
||||
|
||||
NO_IMAGES = []
|
||||
|
||||
NO_VIDEOS = []
|
||||
|
||||
IMGLENS = [1]
|
||||
|
||||
NO_IMGLENS = [0]
|
||||
|
||||
NO_VIDLENS = [0]
|
||||
|
||||
INPUT_IDS = [0, 1, 2, 3, 4]
|
||||
|
||||
LABELS = [0, 1, 2, 3, 4]
|
||||
@@ -78,92 +84,97 @@ def _load_tokenizer_module(model_name_or_path: str) -> Tuple["PreTrainedTokenize
|
||||
return tokenizer_module["tokenizer"], tokenizer_module["processor"]
|
||||
|
||||
|
||||
def _check_plugin(
|
||||
plugin: "BasePlugin",
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
processor: "ProcessorMixin",
|
||||
expected_mm_messages: Sequence[Dict[str, str]],
|
||||
expected_input_ids: List[int],
|
||||
expected_labels: List[int],
|
||||
expected_mm_inputs: Dict[str, Any],
|
||||
expected_no_mm_inputs: Dict[str, Any],
|
||||
) -> None:
|
||||
# test mm_messages
|
||||
assert plugin.process_messages(MM_MESSAGES, IMAGES, NO_VIDEOS, processor) == expected_mm_messages
|
||||
assert plugin.process_token_ids(INPUT_IDS, LABELS, IMAGES, NO_VIDEOS, tokenizer, processor) == (
|
||||
expected_input_ids,
|
||||
expected_labels,
|
||||
)
|
||||
_is_close(
|
||||
plugin.get_mm_inputs(IMAGES, NO_VIDEOS, IMGLENS, NO_VIDLENS, SEQLENS, processor),
|
||||
expected_mm_inputs,
|
||||
)
|
||||
# test text_messages
|
||||
assert plugin.process_messages(TEXT_MESSAGES, NO_IMAGES, NO_VIDEOS, processor) == TEXT_MESSAGES
|
||||
assert plugin.process_token_ids(INPUT_IDS, LABELS, NO_IMAGES, NO_VIDEOS, tokenizer, processor) == (
|
||||
INPUT_IDS,
|
||||
LABELS,
|
||||
)
|
||||
_is_close(
|
||||
plugin.get_mm_inputs(NO_IMAGES, NO_VIDEOS, NO_IMGLENS, NO_VIDLENS, SEQLENS, processor),
|
||||
expected_no_mm_inputs,
|
||||
)
|
||||
|
||||
|
||||
def test_base_plugin():
|
||||
tokenizer, processor = _load_tokenizer_module(model_name_or_path=TINY_LLAMA)
|
||||
base_plugin = get_mm_plugin(name="base", image_token="<image>")
|
||||
# test mm_messages
|
||||
assert base_plugin.process_messages(MM_MESSAGES, IMAGES, processor) == MM_MESSAGES
|
||||
assert base_plugin.process_token_ids(INPUT_IDS, LABELS, IMAGES, tokenizer, processor) == (INPUT_IDS, LABELS)
|
||||
_is_close(base_plugin.get_mm_inputs(IMAGES, IMGLENS, SEQLENS, processor), {})
|
||||
# test text_messages
|
||||
assert base_plugin.process_messages(TEXT_MESSAGES, NO_IMAGES, processor) == TEXT_MESSAGES
|
||||
assert base_plugin.process_token_ids(INPUT_IDS, LABELS, NO_IMAGES, tokenizer, processor) == (INPUT_IDS, LABELS)
|
||||
_is_close(base_plugin.get_mm_inputs(NO_IMAGES, NO_IMGLENS, SEQLENS, processor), {})
|
||||
check_inputs = {"plugin": base_plugin, "tokenizer": tokenizer, "processor": processor}
|
||||
check_inputs["expected_mm_messages"] = MM_MESSAGES
|
||||
check_inputs["expected_input_ids"] = INPUT_IDS
|
||||
check_inputs["expected_labels"] = LABELS
|
||||
check_inputs["expected_mm_inputs"] = {}
|
||||
check_inputs["expected_no_mm_inputs"] = {}
|
||||
_check_plugin(**check_inputs)
|
||||
|
||||
|
||||
def test_llava_plugin():
|
||||
tokenizer, processor = _load_tokenizer_module(model_name_or_path="llava-hf/llava-1.5-7b-hf")
|
||||
llava_plugin = get_mm_plugin(name="llava", image_token="<image>")
|
||||
image_seqlen = 576
|
||||
|
||||
mm_inputs = _get_mm_inputs(processor)
|
||||
expected_mm_messages = [
|
||||
check_inputs = {"plugin": llava_plugin, "tokenizer": tokenizer, "processor": processor}
|
||||
check_inputs["expected_mm_messages"] = [
|
||||
{key: value.replace("<image>", "<image>" * image_seqlen) for key, value in message.items()}
|
||||
for message in MM_MESSAGES
|
||||
]
|
||||
|
||||
llava_plugin = get_mm_plugin(name="llava", image_token="<image>")
|
||||
# test mm_messages
|
||||
assert llava_plugin.process_messages(MM_MESSAGES, IMAGES, processor) == expected_mm_messages
|
||||
assert llava_plugin.process_token_ids(INPUT_IDS, LABELS, IMAGES, tokenizer, processor) == (INPUT_IDS, LABELS)
|
||||
_is_close(llava_plugin.get_mm_inputs(IMAGES, IMGLENS, SEQLENS, processor), mm_inputs)
|
||||
# test text_messages
|
||||
assert llava_plugin.process_messages(TEXT_MESSAGES, NO_IMAGES, processor) == TEXT_MESSAGES
|
||||
assert llava_plugin.process_token_ids(INPUT_IDS, LABELS, NO_IMAGES, tokenizer, processor) == (INPUT_IDS, LABELS)
|
||||
_is_close(llava_plugin.get_mm_inputs(NO_IMAGES, NO_IMGLENS, SEQLENS, processor), {})
|
||||
check_inputs["expected_input_ids"] = INPUT_IDS
|
||||
check_inputs["expected_labels"] = LABELS
|
||||
check_inputs["expected_mm_inputs"] = _get_mm_inputs(processor)
|
||||
check_inputs["expected_no_mm_inputs"] = {}
|
||||
_check_plugin(**check_inputs)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not HF_TOKEN, reason="Gated model.")
|
||||
def test_paligemma_plugin():
|
||||
tokenizer, processor = _load_tokenizer_module(model_name_or_path="google/paligemma-3b-pt-224")
|
||||
paligemma_plugin = get_mm_plugin(name="paligemma", image_token="<image>")
|
||||
image_seqlen = 256
|
||||
|
||||
mm_inputs = _get_mm_inputs(processor)
|
||||
mm_inputs["token_type_ids"] = [[0] * image_seqlen + [1] * (1024 - image_seqlen)]
|
||||
expected_mm_messages = [
|
||||
check_inputs = {"plugin": paligemma_plugin, "tokenizer": tokenizer, "processor": processor}
|
||||
check_inputs["expected_mm_messages"] = [
|
||||
{key: value.replace("<image>", "") for key, value in message.items()} for message in MM_MESSAGES
|
||||
]
|
||||
expected_input_ids = [tokenizer.convert_tokens_to_ids("<image>")] * image_seqlen + INPUT_IDS
|
||||
expected_labels = [-100] * image_seqlen + LABELS
|
||||
|
||||
paligemma_plugin = get_mm_plugin(name="paligemma", image_token="<image>")
|
||||
# test mm_messages
|
||||
assert paligemma_plugin.process_messages(MM_MESSAGES, IMAGES, processor) == expected_mm_messages
|
||||
assert paligemma_plugin.process_token_ids(INPUT_IDS, LABELS, IMAGES, tokenizer, processor) == (
|
||||
expected_input_ids,
|
||||
expected_labels,
|
||||
)
|
||||
_is_close(paligemma_plugin.get_mm_inputs(IMAGES, IMGLENS, SEQLENS, processor), mm_inputs)
|
||||
# test text_messages
|
||||
assert paligemma_plugin.process_messages(TEXT_MESSAGES, NO_IMAGES, processor) == TEXT_MESSAGES
|
||||
assert paligemma_plugin.process_token_ids(INPUT_IDS, LABELS, NO_IMAGES, tokenizer, processor) == (
|
||||
INPUT_IDS,
|
||||
LABELS,
|
||||
)
|
||||
_is_close(
|
||||
paligemma_plugin.get_mm_inputs(NO_IMAGES, NO_IMGLENS, SEQLENS, processor),
|
||||
{"token_type_ids": [[1] * 1024]},
|
||||
)
|
||||
check_inputs["expected_input_ids"] = [tokenizer.convert_tokens_to_ids("<image>")] * image_seqlen + INPUT_IDS
|
||||
check_inputs["expected_labels"] = [-100] * image_seqlen + LABELS
|
||||
check_inputs["expected_mm_inputs"] = _get_mm_inputs(processor)
|
||||
check_inputs["expected_mm_inputs"]["token_type_ids"] = [[0] * image_seqlen + [1] * (1024 - image_seqlen)]
|
||||
check_inputs["expected_no_mm_inputs"] = {"token_type_ids": [[1] * 1024]}
|
||||
_check_plugin(**check_inputs)
|
||||
|
||||
|
||||
def test_qwen2_vl_plugin():
|
||||
tokenizer, processor = _load_tokenizer_module(model_name_or_path="Qwen/Qwen2-VL-7B-Instruct")
|
||||
qwen2_vl_plugin = get_mm_plugin(name="qwen2_vl", image_token="<|image_pad|>")
|
||||
image_seqlen = 4
|
||||
|
||||
mm_inputs = _get_mm_inputs(processor)
|
||||
expected_mm_messages = [
|
||||
check_inputs = {"plugin": qwen2_vl_plugin, "tokenizer": tokenizer, "processor": processor}
|
||||
check_inputs["expected_mm_messages"] = [
|
||||
{
|
||||
key: value.replace("<image>", "<|vision_start|>{}<|vision_end|>".format("<|image_pad|>" * image_seqlen))
|
||||
for key, value in message.items()
|
||||
}
|
||||
for message in MM_MESSAGES
|
||||
]
|
||||
|
||||
qwen2_vl_plugin = get_mm_plugin(name="qwen2_vl", image_token="<|image_pad|>")
|
||||
# test mm_messages
|
||||
assert qwen2_vl_plugin.process_messages(MM_MESSAGES, IMAGES, processor) == expected_mm_messages
|
||||
assert qwen2_vl_plugin.process_token_ids(INPUT_IDS, LABELS, IMAGES, tokenizer, processor) == (INPUT_IDS, LABELS)
|
||||
_is_close(qwen2_vl_plugin.get_mm_inputs(IMAGES, IMGLENS, SEQLENS, processor), mm_inputs)
|
||||
# test text_messages
|
||||
assert qwen2_vl_plugin.process_messages(TEXT_MESSAGES, NO_IMAGES, processor) == TEXT_MESSAGES
|
||||
assert qwen2_vl_plugin.process_token_ids(INPUT_IDS, LABELS, NO_IMAGES, tokenizer, processor) == (INPUT_IDS, LABELS)
|
||||
_is_close(qwen2_vl_plugin.get_mm_inputs(NO_IMAGES, NO_IMGLENS, SEQLENS, processor), {})
|
||||
check_inputs["expected_input_ids"] = INPUT_IDS
|
||||
check_inputs["expected_labels"] = LABELS
|
||||
check_inputs["expected_mm_inputs"] = _get_mm_inputs(processor)
|
||||
check_inputs["expected_no_mm_inputs"] = {}
|
||||
_check_plugin(**check_inputs)
|
||||
|
||||
Reference in New Issue
Block a user