try to past test

Former-commit-id: 3b6bfae0e5fe795a70d530b2765f27d95c5862f8
This commit is contained in:
BUAADreamer
2024-09-10 13:12:51 +08:00
parent 66b870fd08
commit 514f976cc1
4 changed files with 54 additions and 43 deletions

View File

@@ -296,11 +296,11 @@ class LlavaPlugin(BasePlugin):
class LlavaNextPlugin(BasePlugin):
@override
def process_messages(
self,
messages: Sequence[Dict[str, str]],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
processor: Optional["ProcessorMixin"],
self,
messages: Sequence[Dict[str, str]],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]:
self._validate_input(images, videos)
num_image_tokens = 0
@@ -318,13 +318,13 @@ class LlavaNextPlugin(BasePlugin):
@override
def get_mm_inputs(
self,
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
imglens: Sequence[int],
vidlens: Sequence[int],
seqlens: Sequence[int],
processor: Optional["ProcessorMixin"],
self,
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
imglens: Sequence[int],
vidlens: Sequence[int],
seqlens: Sequence[int],
processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
self._validate_input(images, videos)
return _get_mm_inputs(images, videos, processor)
@@ -379,6 +379,7 @@ class LlavaNextVideoPlugin(BasePlugin):
res.update(video_res)
return res
class PaliGemmaPlugin(BasePlugin):
@override
def process_messages(