[version] support transformers 449 (#6982)
* support transformers 449 * fix mm plugin Former-commit-id: e9118a9df0839d24f6ddff5a0b55ef101a1d3d22
This commit is contained in:
@@ -380,10 +380,8 @@ class LlavaNextPlugin(BasePlugin):
|
||||
num_image_tokens = 0
|
||||
messages = deepcopy(messages)
|
||||
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
|
||||
if "image_sizes" in mm_inputs:
|
||||
image_sizes = iter(mm_inputs["image_sizes"])
|
||||
|
||||
if "pixel_values" in mm_inputs:
|
||||
image_sizes = iter(mm_inputs["image_sizes"].tolist())
|
||||
height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values"][0][0]))
|
||||
|
||||
for message in messages:
|
||||
@@ -439,7 +437,7 @@ class LlavaNextVideoPlugin(BasePlugin):
|
||||
messages = deepcopy(messages)
|
||||
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
|
||||
if "pixel_values" in mm_inputs:
|
||||
image_sizes = iter(mm_inputs["image_sizes"])
|
||||
image_sizes = iter(mm_inputs["image_sizes"].tolist())
|
||||
height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values"][0][0]))
|
||||
for message in messages:
|
||||
content = message["content"]
|
||||
@@ -916,16 +914,14 @@ class PixtralPlugin(BasePlugin):
|
||||
num_image_tokens = 0
|
||||
messages = deepcopy(messages)
|
||||
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
|
||||
image_input_sizes = mm_inputs.get("image_sizes", None)
|
||||
if "pixel_values" in mm_inputs:
|
||||
image_sizes = iter(mm_inputs["image_sizes"].tolist())
|
||||
|
||||
for message in messages:
|
||||
content = message["content"]
|
||||
while IMAGE_PLACEHOLDER in content:
|
||||
if image_input_sizes is None:
|
||||
raise ValueError("Cannot get image input sizes.")
|
||||
|
||||
if self.expand_mm_tokens:
|
||||
image_size = image_input_sizes[0][num_image_tokens]
|
||||
height, width = image_size
|
||||
height, width = next(image_sizes)
|
||||
num_height_tokens = height // patch_size
|
||||
num_width_tokens = width // patch_size
|
||||
replace_tokens = [[image_token] * num_width_tokens + [image_break_token]] * num_height_tokens
|
||||
@@ -959,9 +955,6 @@ class PixtralPlugin(BasePlugin):
|
||||
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
|
||||
self._validate_input(images, videos, audios)
|
||||
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
|
||||
if mm_inputs.get("pixel_values"):
|
||||
mm_inputs["pixel_values"] = mm_inputs["pixel_values"][0]
|
||||
|
||||
mm_inputs.pop("image_sizes", None)
|
||||
return mm_inputs
|
||||
|
||||
|
||||
Reference in New Issue
Block a user