support llava-next(video)

Former-commit-id: 27e94593ac467e56e3a7f5c64f4ff6cee81f4b47
This commit is contained in:
BUAADreamer
2024-09-10 12:31:53 +08:00
parent dfff411e1a
commit 484128b641
11 changed files with 394 additions and 33 deletions

View File

@@ -209,6 +209,50 @@ class BasePlugin:
return {}
class Idefics2Plugin(BasePlugin):
@override
def process_messages(
self,
messages: Sequence[Dict[str, str]],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]:
self._validate_input(images, videos)
num_image_tokens = 0
messages = deepcopy(messages)
fake_image_token = processor.fake_image_token.content
image_str = f"{fake_image_token}{self.image_token * processor.image_seq_len}{fake_image_token}"
image_str = image_str * 5
for message in messages:
content = message["content"]
while IMAGE_PLACEHOLDER in content:
num_image_tokens += 1
content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1)
content = content.replace("{{image}}", image_str)
content = content.replace(f"{fake_image_token}{fake_image_token}", f"{fake_image_token}")
message["content"] = content
if len(images) != num_image_tokens:
raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER))
return messages
@override
def get_mm_inputs(
self,
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
imglens: Sequence[int],
vidlens: Sequence[int],
seqlens: Sequence[int],
processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
self._validate_input(images, videos)
return _get_mm_inputs(images, videos, processor)
class LlavaPlugin(BasePlugin):
@override
def process_messages(
@@ -249,6 +293,92 @@ class LlavaPlugin(BasePlugin):
return _get_mm_inputs(images, videos, processor)
class LlavaNextPlugin(BasePlugin):
@override
def process_messages(
self,
messages: Sequence[Dict[str, str]],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]:
self._validate_input(images, videos)
num_image_tokens = 0
messages = deepcopy(messages)
for message in messages:
content = message["content"]
while IMAGE_PLACEHOLDER in content:
num_image_tokens += 1
content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1)
if len(images) != num_image_tokens:
raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER))
return messages
@override
def get_mm_inputs(
self,
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
imglens: Sequence[int],
vidlens: Sequence[int],
seqlens: Sequence[int],
processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
self._validate_input(images, videos)
return _get_mm_inputs(images, videos, processor)
class LlavaNextVideoPlugin(BasePlugin):
@override
def process_messages(
self,
messages: Sequence[Dict[str, str]],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]:
self._validate_input(images, videos)
num_image_tokens = 0
num_video_tokens = 0
messages = deepcopy(messages)
for message in messages:
content = message["content"]
while IMAGE_PLACEHOLDER in content:
num_image_tokens += 1
content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1)
while VIDEO_PLACEHOLDER in content:
num_video_tokens += 1
content = content.replace(VIDEO_PLACEHOLDER, "{{video}}", 1)
if len(images) != num_image_tokens:
raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER))
if len(videos) != num_video_tokens:
raise ValueError("The number of videos does not match the number of {} tokens".format(IMAGE_PLACEHOLDER))
return messages
@override
def get_mm_inputs(
self,
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
imglens: Sequence[int],
vidlens: Sequence[int],
seqlens: Sequence[int],
processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
self._validate_input(images, videos)
video_processor = getattr(processor, "video_processor")
res = _get_mm_inputs(images, [], processor)
if len(videos) != 0:
videos = _regularize_videos(videos, processor)
video_res = video_processor(videos, return_tensors="pt")
res.update(video_res)
return res
class PaliGemmaPlugin(BasePlugin):
@override
def process_messages(
@@ -380,11 +510,59 @@ class Qwen2vlPlugin(BasePlugin):
return _get_mm_inputs(images, videos, processor)
class VideoLlavaPlugin(BasePlugin):
@override
def process_messages(
self,
messages: Sequence[Dict[str, str]],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]:
self._validate_input(images, videos)
num_image_tokens = 0
num_video_tokens = 0
messages = deepcopy(messages)
for message in messages:
content = message["content"]
while IMAGE_PLACEHOLDER in content:
num_image_tokens += 1
content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1)
while VIDEO_PLACEHOLDER in content:
num_video_tokens += 1
content = content.replace(VIDEO_PLACEHOLDER, "{{video}}", 1)
if len(images) != num_image_tokens:
raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER))
if len(videos) != num_video_tokens:
raise ValueError("The number of videos does not match the number of {} tokens".format(IMAGE_PLACEHOLDER))
return messages
@override
def get_mm_inputs(
self,
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
imglens: Sequence[int],
vidlens: Sequence[int],
seqlens: Sequence[int],
processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
self._validate_input(images, videos)
return _get_mm_inputs(images, videos, processor)
PLUGINS = {
"base": BasePlugin,
"idefics2": Idefics2Plugin,
"llava": LlavaPlugin,
"llava_next": LlavaNextPlugin,
"llava_next_video": LlavaNextVideoPlugin,
"paligemma": PaliGemmaPlugin,
"qwen2_vl": Qwen2vlPlugin,
"video_llava": VideoLlavaPlugin,
}