[model] support LiquidAI's LFM2.5-VL vision-language model (#9729)

This commit is contained in:
Vo Van Phuc
2026-01-07 16:20:29 +07:00
committed by GitHub
parent b4e051bea4
commit 958fb523a2
5 changed files with 118 additions and 0 deletions

View File

@@ -2092,6 +2092,73 @@ class VideoLlavaPlugin(BasePlugin):
return messages
@dataclass
class LFMVLPlugin(BasePlugin):
r"""Plugin for LFM2.5-VL vision-language models.
LFM2.5-VL uses dynamic image token counts based on image resolution.
The image processor returns spatial_shapes tensor with [height, width] grid dimensions.
Token count per image = (spatial_h * spatial_w) / (downsample_factor^2)
"""
@override
def _get_mm_inputs(
self,
images: list["ImageInput"],
videos: list["VideoInput"],
audios: list["AudioInput"],
processor: "MMProcessor",
) -> dict[str, "torch.Tensor"]:
image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
mm_inputs = {}
if len(images) != 0:
images = self._regularize_images(
images,
image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768),
image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32),
)["images"]
mm_inputs.update(image_processor(images, return_tensors="pt"))
return mm_inputs
@override
def process_messages(
self,
messages: list[dict[str, str]],
images: list["ImageInput"],
videos: list["VideoInput"],
audios: list["AudioInput"],
processor: Optional["MMProcessor"],
) -> list[dict[str, str]]:
self._validate_input(processor, images, videos, audios)
self._validate_messages(messages, images, videos, audios)
num_image_tokens = 0
messages = deepcopy(messages)
image_processor: BaseImageProcessor = getattr(processor, "image_processor")
downsample_factor: int = getattr(image_processor, "downsample_factor", 2)
if self.expand_mm_tokens and len(images) > 0:
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
spatial_shapes = mm_inputs.get("spatial_shapes", [])
else:
spatial_shapes = []
for message in messages:
content = message["content"]
while IMAGE_PLACEHOLDER in content:
if self.expand_mm_tokens and len(spatial_shapes) > num_image_tokens:
h, w = spatial_shapes[num_image_tokens].tolist()
image_seqlen = (h * w) // (downsample_factor * downsample_factor)
else:
image_seqlen = 1
content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
num_image_tokens += 1
message["content"] = content.replace("{{image}}", self.image_token)
return messages
PLUGINS = {
"base": BasePlugin,
"ernie_vl": ErnieVLPlugin,
@@ -2104,6 +2171,7 @@ PLUGINS = {
"llava": LlavaPlugin,
"llava_next": LlavaNextPlugin,
"llava_next_video": LlavaNextVideoPlugin,
"lfm2_vl": LFMVLPlugin,
"minicpm_v": MiniCPMVPlugin,
"mllama": MllamaPlugin,
"paligemma": PaliGemmaPlugin,