support llava-next(video)

Former-commit-id: 27e94593ac467e56e3a7f5c64f4ff6cee81f4b47
This commit is contained in:
BUAADreamer
2024-09-10 12:31:53 +08:00
parent dfff411e1a
commit 484128b641
11 changed files with 394 additions and 33 deletions

View File

@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import os
from typing import TYPE_CHECKING, Any, Dict, List, Sequence, Tuple
@@ -136,6 +136,47 @@ def test_llava_plugin():
_check_plugin(**check_inputs)
def test_idefics2_plugin():
tokenizer, processor = _load_tokenizer_module(model_name_or_path="HuggingFaceM4/idefics2-8b")
idefics2_plugin = get_mm_plugin(name="idefics2", image_token="<image>")
check_inputs = {"plugin": idefics2_plugin, "tokenizer": tokenizer, "processor": processor}
mm_messages = copy.deepcopy(MM_MESSAGES)
fake_image_token = processor.fake_image_token.content
image_str = f"{fake_image_token}{"<image>" * processor.image_seq_len}{fake_image_token}"
image_str = image_str * 5
for message in mm_messages:
content = message["content"]
content = content.replace("<image>", image_str)
content = content.replace(f"{fake_image_token}{fake_image_token}", f"{fake_image_token}")
message['content'] = content
check_inputs["expected_mm_inputs"] = _get_mm_inputs(processor)
_check_plugin(**check_inputs)
def test_llava_next_plugin():
tokenizer, processor = _load_tokenizer_module(model_name_or_path="llava-hf/llava-v1.6-vicuna-7b-hf")
llava_next_plugin = get_mm_plugin(name="llava_next", image_token="<image>")
check_inputs = {"plugin": llava_next_plugin, "tokenizer": tokenizer, "processor": processor}
check_inputs["expected_mm_messages"] = [
{key: value for key, value in message.items()}
for message in MM_MESSAGES
]
check_inputs["expected_mm_inputs"] = _get_mm_inputs(processor)
_check_plugin(**check_inputs)
def test_llava_next_video_plugin():
tokenizer, processor = _load_tokenizer_module(model_name_or_path="llava-hf/LLaVA-NeXT-Video-7B-hf")
llava_next_video_plugin = get_mm_plugin(name="llava_next_video", image_token="<image>", video_token="<video>")
check_inputs = {"plugin": llava_next_video_plugin, "tokenizer": tokenizer, "processor": processor}
check_inputs["expected_mm_messages"] = [
{key: value for key, value in message.items()}
for message in MM_MESSAGES
]
check_inputs["expected_mm_inputs"] = _get_mm_inputs(processor)
_check_plugin(**check_inputs)
@pytest.mark.skipif(not HF_TOKEN, reason="Gated model.")
def test_paligemma_plugin():
tokenizer, processor = _load_tokenizer_module(model_name_or_path="google/paligemma-3b-pt-224")
@@ -167,3 +208,15 @@ def test_qwen2_vl_plugin():
]
check_inputs["expected_mm_inputs"] = _get_mm_inputs(processor)
_check_plugin(**check_inputs)
def test_video_llava_plugin():
tokenizer, processor = _load_tokenizer_module(model_name_or_path="LanguageBind/Video-LLaVA-7B-hf")
video_llava_plugin = get_mm_plugin(name="video_llava", image_token="<image>", video_token="<video>")
check_inputs = {"plugin": video_llava_plugin, "tokenizer": tokenizer, "processor": processor}
check_inputs["expected_mm_messages"] = [
{key: value for key, value in message.items()}
for message in MM_MESSAGES
]
check_inputs["expected_mm_inputs"] = _get_mm_inputs(processor)
_check_plugin(**check_inputs)