fix vlm zero3 training

Former-commit-id: 86fe7fe71b51077310357b7b1895522258f9bc7a
hiyouga
2024-12-04 09:40:39 +00:00
parent c07ba8ccc0
commit 2f09c34980
3 changed files with 157 additions and 41 deletions
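On the test side (the file shown below), the helper `_load_tokenizer_module` now returns the `TokenizerModule` dict produced by `load_tokenizer` instead of a `(tokenizer, processor)` tuple, so each test can splat it into its `check_inputs` with `**tokenizer_module`. A minimal sketch of that pattern, assuming `TokenizerModule` is a TypedDict with the two keys the tests index (its real definition lives in llamafactory.model.loader):

# Sketch only: the TypedDict shape is assumed from how the tests use it;
# load_tokenizer_stub is a hypothetical stand-in for llamafactory.model.load_tokenizer.
from typing import Any, Optional, TypedDict


class TokenizerModule(TypedDict):
    tokenizer: Any  # a transformers PreTrainedTokenizer in the real code
    processor: Optional[Any]  # a ProcessorMixin, or None for text-only models


def load_tokenizer_stub() -> TokenizerModule:
    return {"tokenizer": "tokenizer", "processor": None}


tokenizer_module = load_tokenizer_stub()
# Dict unpacking inserts both keys at once, which is what lets the tests
# write check_inputs = {"plugin": plugin, **tokenizer_module}.
check_inputs = {"plugin": "base", **tokenizer_module}
print(sorted(check_inputs))  # ['plugin', 'processor', 'tokenizer']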


@@ -13,14 +13,14 @@
 # limitations under the License.

 import os
-from typing import TYPE_CHECKING, Any, Dict, List, Sequence, Tuple
+from typing import TYPE_CHECKING, Any, Dict, List, Sequence

 import pytest
 import torch
 from PIL import Image

 from llamafactory.data.mm_plugin import get_mm_plugin
-from llamafactory.hparams import ModelArguments
+from llamafactory.hparams import get_infer_args
 from llamafactory.model import load_tokenizer
@@ -29,6 +29,7 @@ if TYPE_CHECKING:
     from transformers.image_processing_utils import BaseImageProcessor

     from llamafactory.data.mm_plugin import BasePlugin
+    from llamafactory.model.loader import TokenizerModule


 HF_TOKEN = os.getenv("HF_TOKEN")
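Because the `TokenizerModule` import sits under `if TYPE_CHECKING:`, it exists only for type checkers and is absent at runtime, which is why the new return annotation in the next hunk is written as the string `"TokenizerModule"`. The pattern, in isolation:

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Evaluated by type checkers only; never imported at runtime.
    from llamafactory.model.loader import TokenizerModule


def _load_tokenizer_module(model_name_or_path: str) -> "TokenizerModule":
    # The string annotation defers the name lookup, so no runtime import is needed.
    ...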
@@ -82,10 +83,9 @@ def _is_close(batch_a: Dict[str, Any], batch_b: Dict[str, Any]) -> None:
             assert batch_a[key] == batch_b[key]


-def _load_tokenizer_module(model_name_or_path: str) -> Tuple["PreTrainedTokenizer", "ProcessorMixin"]:
-    model_args = ModelArguments(model_name_or_path=model_name_or_path)
-    tokenizer_module = load_tokenizer(model_args)
-    return tokenizer_module["tokenizer"], tokenizer_module["processor"]
+def _load_tokenizer_module(model_name_or_path: str) -> "TokenizerModule":
+    model_args, *_ = get_infer_args({"model_name_or_path": model_name_or_path, "template": "default"})
+    return load_tokenizer(model_args)


 def _check_plugin(
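Routing the helper through `get_infer_args` instead of hand-building a `ModelArguments` means the test's `model_args` goes through the same parsing and post-processing as a real inference invocation, template selection included. A usage sketch mirroring the added lines (assuming, as the `*_` suggests, that `get_infer_args` returns `model_args` first, followed by the other parsed argument groups):

# Mirrors the new helper above; the extra argument groups are discarded via *_.
from llamafactory.hparams import get_infer_args
from llamafactory.model import load_tokenizer

model_args, *_ = get_infer_args({"model_name_or_path": "llava-hf/llava-1.5-7b-hf", "template": "default"})
tokenizer_module = load_tokenizer(model_args)
tokenizer, processor = tokenizer_module["tokenizer"], tokenizer_module["processor"]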
@@ -121,73 +121,75 @@ def _check_plugin(
 def test_base_plugin():
-    tokenizer, processor = _load_tokenizer_module(model_name_or_path=TINY_LLAMA)
+    tokenizer_module = _load_tokenizer_module(model_name_or_path=TINY_LLAMA)
     base_plugin = get_mm_plugin(name="base", image_token="<image>")
-    check_inputs = {"plugin": base_plugin, "tokenizer": tokenizer, "processor": processor}
+    check_inputs = {"plugin": base_plugin, **tokenizer_module}
     _check_plugin(**check_inputs)


 def test_llava_plugin():
-    tokenizer, processor = _load_tokenizer_module(model_name_or_path="llava-hf/llava-1.5-7b-hf")
-    llava_plugin = get_mm_plugin(name="llava", image_token="<image>")
     image_seqlen = 576
-    check_inputs = {"plugin": llava_plugin, "tokenizer": tokenizer, "processor": processor}
+    tokenizer_module = _load_tokenizer_module(model_name_or_path="llava-hf/llava-1.5-7b-hf")
+    llava_plugin = get_mm_plugin(name="llava", image_token="<image>")
+    check_inputs = {"plugin": llava_plugin, **tokenizer_module}
     check_inputs["expected_mm_messages"] = [
         {key: value.replace("<image>", "<image>" * image_seqlen) for key, value in message.items()}
         for message in MM_MESSAGES
     ]
-    check_inputs["expected_mm_inputs"] = _get_mm_inputs(processor)
+    check_inputs["expected_mm_inputs"] = _get_mm_inputs(tokenizer_module["processor"])
     _check_plugin(**check_inputs)


 def test_llava_next_plugin():
-    tokenizer, processor = _load_tokenizer_module(model_name_or_path="llava-hf/llava-v1.6-vicuna-7b-hf")
-    llava_next_plugin = get_mm_plugin(name="llava_next", image_token="<image>")
-    check_inputs = {"plugin": llava_next_plugin, "tokenizer": tokenizer, "processor": processor}
     image_seqlen = 1176
+    tokenizer_module = _load_tokenizer_module(model_name_or_path="llava-hf/llava-v1.6-vicuna-7b-hf")
+    llava_next_plugin = get_mm_plugin(name="llava_next", image_token="<image>")
+    check_inputs = {"plugin": llava_next_plugin, **tokenizer_module}
     check_inputs["expected_mm_messages"] = [
         {key: value.replace("<image>", "<image>" * image_seqlen) for key, value in message.items()}
         for message in MM_MESSAGES
     ]
-    check_inputs["expected_mm_inputs"] = _get_mm_inputs(processor)
+    check_inputs["expected_mm_inputs"] = _get_mm_inputs(tokenizer_module["processor"])
     _check_plugin(**check_inputs)


 def test_llava_next_video_plugin():
-    tokenizer, processor = _load_tokenizer_module(model_name_or_path="llava-hf/LLaVA-NeXT-Video-7B-hf")
-    llava_next_video_plugin = get_mm_plugin(name="llava_next_video", image_token="<image>", video_token="<video>")
-    check_inputs = {"plugin": llava_next_video_plugin, "tokenizer": tokenizer, "processor": processor}
     image_seqlen = 1176
+    tokenizer_module = _load_tokenizer_module(model_name_or_path="llava-hf/LLaVA-NeXT-Video-7B-hf")
+    llava_next_video_plugin = get_mm_plugin(name="llava_next_video", image_token="<image>", video_token="<video>")
+    check_inputs = {"plugin": llava_next_video_plugin, **tokenizer_module}
     check_inputs["expected_mm_messages"] = [
         {key: value.replace("<image>", "<image>" * image_seqlen) for key, value in message.items()}
         for message in MM_MESSAGES
     ]
-    check_inputs["expected_mm_inputs"] = _get_mm_inputs(processor)
+    check_inputs["expected_mm_inputs"] = _get_mm_inputs(tokenizer_module["processor"])
     _check_plugin(**check_inputs)


 @pytest.mark.skipif(not HF_TOKEN, reason="Gated model.")
 def test_paligemma_plugin():
-    tokenizer, processor = _load_tokenizer_module(model_name_or_path="google/paligemma-3b-pt-224")
-    paligemma_plugin = get_mm_plugin(name="paligemma", image_token="<image>")
     image_seqlen = 256
-    check_inputs = {"plugin": paligemma_plugin, "tokenizer": tokenizer, "processor": processor}
+    tokenizer_module = _load_tokenizer_module(model_name_or_path="google/paligemma-3b-pt-224")
+    paligemma_plugin = get_mm_plugin(name="paligemma", image_token="<image>")
+    check_inputs = {"plugin": paligemma_plugin, **tokenizer_module}
     check_inputs["expected_mm_messages"] = [
         {key: value.replace("<image>", "") for key, value in message.items()} for message in MM_MESSAGES
     ]
-    check_inputs["expected_input_ids"] = [tokenizer.convert_tokens_to_ids("<image>")] * image_seqlen + INPUT_IDS
+    check_inputs["expected_input_ids"] = [
+        tokenizer_module["tokenizer"].convert_tokens_to_ids(paligemma_plugin.image_token)
+    ] * image_seqlen + INPUT_IDS
     check_inputs["expected_labels"] = [-100] * image_seqlen + LABELS
-    check_inputs["expected_mm_inputs"] = _get_mm_inputs(processor)
+    check_inputs["expected_mm_inputs"] = _get_mm_inputs(tokenizer_module["processor"])
     check_inputs["expected_mm_inputs"]["token_type_ids"] = [[0] * image_seqlen + [1] * (1024 - image_seqlen)]
     check_inputs["expected_no_mm_inputs"] = {"token_type_ids": [[1] * 1024]}
     _check_plugin(**check_inputs)


 def test_pixtral_plugin():
-    tokenizer, processor = _load_tokenizer_module(model_name_or_path="mistral-community/pixtral-12b")
-    pixtral_plugin = get_mm_plugin(name="pixtral", image_token="[IMG]")
     image_slice_height, image_slice_width = 2, 2
-    check_inputs = {"plugin": pixtral_plugin, "tokenizer": tokenizer, "processor": processor}
+    tokenizer_module = _load_tokenizer_module(model_name_or_path="mistral-community/pixtral-12b")
+    pixtral_plugin = get_mm_plugin(name="pixtral", image_token="[IMG]")
+    check_inputs = {"plugin": pixtral_plugin, **tokenizer_module}
     check_inputs["expected_mm_messages"] = [
         {
             key: value.replace(
@@ -199,17 +201,17 @@ def test_pixtral_plugin():
         }
         for message in MM_MESSAGES
     ]
-    check_inputs["expected_mm_inputs"] = _get_mm_inputs(processor)
+    check_inputs["expected_mm_inputs"] = _get_mm_inputs(tokenizer_module["processor"])
     check_inputs["expected_mm_inputs"].pop("image_sizes")
     check_inputs["expected_mm_inputs"]["pixel_values"] = check_inputs["expected_mm_inputs"]["pixel_values"][0]
     _check_plugin(**check_inputs)


 def test_qwen2_vl_plugin():
-    tokenizer, processor = _load_tokenizer_module(model_name_or_path="Qwen/Qwen2-VL-7B-Instruct")
-    qwen2_vl_plugin = get_mm_plugin(name="qwen2_vl", image_token="<|image_pad|>")
     image_seqlen = 4
-    check_inputs = {"plugin": qwen2_vl_plugin, "tokenizer": tokenizer, "processor": processor}
+    tokenizer_module = _load_tokenizer_module(model_name_or_path="Qwen/Qwen2-VL-7B-Instruct")
+    qwen2_vl_plugin = get_mm_plugin(name="qwen2_vl", image_token="<|image_pad|>")
+    check_inputs = {"plugin": qwen2_vl_plugin, **tokenizer_module}
     check_inputs["expected_mm_messages"] = [
         {
             key: value.replace("<image>", "<|vision_start|>{}<|vision_end|>".format("<|image_pad|>" * image_seqlen))
@@ -217,18 +219,18 @@ def test_qwen2_vl_plugin():
         }
         for message in MM_MESSAGES
     ]
-    check_inputs["expected_mm_inputs"] = _get_mm_inputs(processor)
+    check_inputs["expected_mm_inputs"] = _get_mm_inputs(tokenizer_module["processor"])
     _check_plugin(**check_inputs)


 def test_video_llava_plugin():
-    tokenizer, processor = _load_tokenizer_module(model_name_or_path="LanguageBind/Video-LLaVA-7B-hf")
-    video_llava_plugin = get_mm_plugin(name="video_llava", image_token="<image>", video_token="<video>")
-    check_inputs = {"plugin": video_llava_plugin, "tokenizer": tokenizer, "processor": processor}
     image_seqlen = 256
+    tokenizer_module = _load_tokenizer_module(model_name_or_path="LanguageBind/Video-LLaVA-7B-hf")
+    video_llava_plugin = get_mm_plugin(name="video_llava", image_token="<image>", video_token="<video>")
+    check_inputs = {"plugin": video_llava_plugin, **tokenizer_module}
     check_inputs["expected_mm_messages"] = [
         {key: value.replace("<image>", "<image>" * image_seqlen) for key, value in message.items()}
         for message in MM_MESSAGES
     ]
-    check_inputs["expected_mm_inputs"] = _get_mm_inputs(processor)
+    check_inputs["expected_mm_inputs"] = _get_mm_inputs(tokenizer_module["processor"])
     _check_plugin(**check_inputs)
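Reassembled from the added (+) lines above, the simplest updated test is now a straight-line use of the new helper, and the expected_mm_messages blocks in the other tests just repeat the model's image token image_seqlen times in place of each <image> placeholder:

def test_base_plugin():
    tokenizer_module = _load_tokenizer_module(model_name_or_path=TINY_LLAMA)
    base_plugin = get_mm_plugin(name="base", image_token="<image>")
    check_inputs = {"plugin": base_plugin, **tokenizer_module}
    _check_plugin(**check_inputs)


# Toy illustration of the expected_mm_messages expansion with image_seqlen = 3:
image_seqlen = 3
message = {"role": "user", "content": "What is in this <image>?"}
expanded = {key: value.replace("<image>", "<image>" * image_seqlen) for key, value in message.items()}
assert expanded["content"] == "What is in this <image><image><image>?"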