[model] support LiquidAI's LFM2.5-VL vision-language model (#9729)
@@ -2092,6 +2092,73 @@ class VideoLlavaPlugin(BasePlugin):
        return messages


@dataclass
class LFMVLPlugin(BasePlugin):
    r"""Plugin for LFM2.5-VL vision-language models.

    LFM2.5-VL uses dynamic image token counts based on image resolution.
    The image processor returns spatial_shapes tensor with [height, width] grid dimensions.
    Token count per image = (spatial_h * spatial_w) / (downsample_factor^2)
    """

    @override
    def _get_mm_inputs(
        self,
        images: list["ImageInput"],
        videos: list["VideoInput"],
        audios: list["AudioInput"],
        processor: "MMProcessor",
    ) -> dict[str, "torch.Tensor"]:
        image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
        mm_inputs = {}
        if len(images) != 0:
            images = self._regularize_images(
                images,
                image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768),
                image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32),
            )["images"]
            mm_inputs.update(image_processor(images, return_tensors="pt"))

        return mm_inputs

    @override
    def process_messages(
        self,
        messages: list[dict[str, str]],
        images: list["ImageInput"],
        videos: list["VideoInput"],
        audios: list["AudioInput"],
        processor: Optional["MMProcessor"],
    ) -> list[dict[str, str]]:
        self._validate_input(processor, images, videos, audios)
        self._validate_messages(messages, images, videos, audios)
        num_image_tokens = 0
        messages = deepcopy(messages)
        image_processor: BaseImageProcessor = getattr(processor, "image_processor")
        downsample_factor: int = getattr(image_processor, "downsample_factor", 2)

        if self.expand_mm_tokens and len(images) > 0:
            mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
            spatial_shapes = mm_inputs.get("spatial_shapes", [])
        else:
            spatial_shapes = []

        for message in messages:
            content = message["content"]
            while IMAGE_PLACEHOLDER in content:
                if self.expand_mm_tokens and len(spatial_shapes) > num_image_tokens:
                    h, w = spatial_shapes[num_image_tokens].tolist()
                    image_seqlen = (h * w) // (downsample_factor * downsample_factor)
                else:
                    image_seqlen = 1

                content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
                num_image_tokens += 1

            message["content"] = content.replace("{{image}}", self.image_token)

        return messages
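
For intuition about the docstring's token-count rule, the following standalone Python sketch reproduces the expansion arithmetic outside of LlamaFactory; the grid size, downsample factor, and token strings are illustrative assumptions rather than values from a real LFM2.5-VL checkpoint.

IMAGE_PLACEHOLDER = "<image>"  # placeholder written in the raw prompt
IMAGE_TOKEN = "<img>"          # hypothetical stand-in for the model's real image token


def expand_image_placeholder(content: str, spatial_shape: tuple[int, int], downsample_factor: int = 2) -> str:
    """Replace one placeholder with one image token per downsampled patch (LFMVLPlugin's rule)."""
    grid_h, grid_w = spatial_shape
    image_seqlen = (grid_h * grid_w) // (downsample_factor * downsample_factor)
    return content.replace(IMAGE_PLACEHOLDER, IMAGE_TOKEN * image_seqlen, 1)


# A [16, 24] patch grid with downsample_factor=2 expands to (16 * 24) // 4 = 96 tokens.
expanded = expand_image_placeholder("Describe this photo: <image>", (16, 24))
print(expanded.count(IMAGE_TOKEN))  # 96

Note that the plugin itself first substitutes a temporary "{{image}}" marker and only swaps in self.image_token after the while loop; this keeps the loop from re-matching tokens it has just inserted when the model's image token is identical to IMAGE_PLACEHOLDER.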

PLUGINS = {
    "base": BasePlugin,
    "ernie_vl": ErnieVLPlugin,
@@ -2104,6 +2171,7 @@ PLUGINS = {
    "llava": LlavaPlugin,
    "llava_next": LlavaNextPlugin,
    "llava_next_video": LlavaNextVideoPlugin,
    "lfm2_vl": LFMVLPlugin,
    "minicpm_v": MiniCPMVPlugin,
    "mllama": MllamaPlugin,
    "paligemma": PaliGemmaPlugin,