Tiny fix.

Former-commit-id: 4c1cef12d812832eed58b5da562ba083104756d3
marko1616
2024-09-26 11:06:21 -04:00
committed by hiyouga
parent 3d35aeca72
commit 918a367378
2 changed files with 4 additions and 9 deletions

@@ -686,19 +686,12 @@ class MllamaPlugin(BasePlugin):
         processor: Optional["ProcessorMixin"],
     ) -> List[Dict[str, str]]:
         self._validate_input(images, videos)
-        num_image_tokens = 0
         messages = deepcopy(messages)
         for message in messages:
             content = message["content"]
-            while IMAGE_PLACEHOLDER in content:
-                num_image_tokens += 1
-                content = content.replace(IMAGE_PLACEHOLDER, "<|image|>", 1)
-
+            content = content.replace(IMAGE_PLACEHOLDER, "<|image|>", 1)
             message["content"] = content
 
-        if len(images) != num_image_tokens:
-            raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER))
-
         return messages
 
     def get_mm_inputs(
@@ -710,7 +703,7 @@ class MllamaPlugin(BasePlugin):
         seqlens: Sequence[int],
         processor: Optional["ProcessorMixin"],
     ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
-        super().get_mm_inputs(images, videos, imglens, vidlens, seqlens, processor)
+        self._get_mm_inputs(images, videos, processor)
         if images is not None:
             images = [Image.open(image) if isinstance(image, str) else image for image in images]
             image_features = processor.image_processor(images)