From b53d7037c2a477efb01d57b3a88b3dcfa2b46515 Mon Sep 17 00:00:00 2001
From: Hertz <2267379130@qq.com>
Date: Mon, 2 Feb 2026 21:42:43 +0800
Subject: [PATCH] [model] support youtu-vl model (#10152)

---
 src/llamafactory/data/mm_plugin.py   | 35 +++++++++++++++++++
 src/llamafactory/data/template.py    | 13 +++++++
 src/llamafactory/extras/constants.py | 12 +++++++
 .../model/model_utils/attention.py   | 12 +++++++
 src/llamafactory/model/patcher.py    | 23 ++++++++++++
 5 files changed, 95 insertions(+)

diff --git a/src/llamafactory/data/mm_plugin.py b/src/llamafactory/data/mm_plugin.py
index b0a345e1c..fca092cb8 100644
--- a/src/llamafactory/data/mm_plugin.py
+++ b/src/llamafactory/data/mm_plugin.py
@@ -2159,6 +2159,40 @@ class LFMVLPlugin(BasePlugin):
         return messages
 
 
+@dataclass
+class YoutuVLPlugin(BasePlugin):
+    r"""Plugin for Youtu-VL vision-language models."""
+
+    vision_bos_token: str = "<|vision_start|>"
+    vision_eos_token: str = "<|vision_end|>"
+
+    @override
+    def process_messages(
+        self,
+        messages: list[dict[str, str]],
+        images: list["ImageInput"],
+        videos: list["VideoInput"],
+        audios: list["AudioInput"],
+        processor: Optional["MMProcessor"],
+    ) -> list[dict[str, str]]:
+        self._validate_input(processor, images, videos, audios)
+        self._validate_messages(messages, images, videos, audios)
+        messages = deepcopy(messages)
+
+        for message in messages:
+            content = message["content"]
+            content = content.replace(
+                IMAGE_PLACEHOLDER, f"{self.vision_bos_token}{self.image_token}{self.vision_eos_token}"
+            )
+            content = content.replace(
+                VIDEO_PLACEHOLDER, f"{self.vision_bos_token}{self.video_token}{self.vision_eos_token}"
+            )
+
+            message["content"] = content
+
+        return messages
+
+
 PLUGINS = {
     "base": BasePlugin,
     "ernie_vl": ErnieVLPlugin,
@@ -2181,6 +2215,7 @@ PLUGINS = {
     "qwen2_vl": Qwen2VLPlugin,
     "qwen3_vl": Qwen3VLPlugin,
     "video_llava": VideoLlavaPlugin,
+    "youtu_vl": YoutuVLPlugin,
 }
 
 
diff --git a/src/llamafactory/data/template.py b/src/llamafactory/data/template.py
index 562f3375c..8a29e32ee 100644
--- a/src/llamafactory/data/template.py
+++ b/src/llamafactory/data/template.py
@@ -2146,6 +2146,19 @@ register_template(
 )
 
 
+register_template(
+    name="youtu_vl",
+    format_user=StringFormatter(
+        slots=["<|begin_of_text|>user\n{{content}}<|end_of_text|>\n<|begin_of_text|>assistant\n"]
+    ),
+    format_assistant=StringFormatter(slots=["{{content}}<|end_of_text|>\n"]),
+    format_system=StringFormatter(slots=["<|begin_of_text|>system\n{{content}}<|end_of_text|>\n"]),
+    default_system="You are a helpful assistant.",
+    stop_words=["<|end_of_text|>"],
+    mm_plugin=get_mm_plugin(name="youtu_vl", image_token="<|image_pad|>", video_token="<|video_pad|>"),
+)
+
+
 register_template(
     name="yuan",
     format_user=StringFormatter(slots=["{{content}}", {"token": "<sep>"}]),
diff --git a/src/llamafactory/extras/constants.py b/src/llamafactory/extras/constants.py
index 6baf51435..15f618f35 100644
--- a/src/llamafactory/extras/constants.py
+++ b/src/llamafactory/extras/constants.py
@@ -3375,6 +3375,18 @@ register_model_group(
 )
 
 
+register_model_group(
+    models={
+        "Youtu-VL-4B-Instruct": {
+            DownloadSource.DEFAULT: "tencent/Youtu-VL-4B-Instruct",
+            DownloadSource.MODELSCOPE: "Tencent-YouTu-Research/Youtu-VL-4B-Instruct",
+        },
+    },
+    template="youtu_vl",
+    multimodal=True,
+)
+
+
 register_model_group(
     models={
         "Yuan2-2B-Chat": {
diff --git a/src/llamafactory/model/model_utils/attention.py b/src/llamafactory/model/model_utils/attention.py
index 086ab5cf8..290df74a4 100644
--- a/src/llamafactory/model/model_utils/attention.py
+++ b/src/llamafactory/model/model_utils/attention.py
@@ -57,6 +57,11 @@ def configure_attn_implementation(config: "PretrainedConfig", model_args: "Model
             "Gemma-2 should use soft-capping attention, while the SDPA attention does not support it."
         )
 
+    if getattr(config, "model_type", None) in ["youtu", "youtu_vl"]:
+        if model_args.flash_attn in (AttentionFunction.AUTO, AttentionFunction.SDPA):
+            logger.warning_rank0("Youtu-VL does not support SDPA, forcing eager attention.")
+            model_args.flash_attn = AttentionFunction.DISABLED
+
     if model_args.flash_attn == AttentionFunction.AUTO:
         return
 
@@ -85,6 +90,13 @@ def configure_attn_implementation(config: "PretrainedConfig", model_args: "Model
     elif getattr(config, "model_type", None) == "kimi_vl":
         setattr(config.vision_config, "_attn_implementation", requested_attn_implementation)
         setattr(config.text_config, "_attn_implementation", requested_attn_implementation)
+    elif getattr(config, "model_type", None) == "youtu_vl":
+        setattr(config, "attn_implementation", requested_attn_implementation)
+        setattr(config, "_attn_implementation", requested_attn_implementation)
+        if hasattr(config, "vision_config"):
+            setattr(config.vision_config, "_attn_implementation", requested_attn_implementation)
+        if hasattr(config, "text_config"):
+            setattr(config.text_config, "_attn_implementation", requested_attn_implementation)
     else:
         setattr(config, "_attn_implementation", requested_attn_implementation)
 
diff --git a/src/llamafactory/model/patcher.py b/src/llamafactory/model/patcher.py
index 19e174bae..52e8ace21 100644
--- a/src/llamafactory/model/patcher.py
+++ b/src/llamafactory/model/patcher.py
@@ -61,6 +61,26 @@ def patch_qwen3_omni_moe_thinker_text_sparse_moe_block():
     modeling_qwen3_omni_moe.Qwen3OmniMoeThinkerTextSparseMoeBlock = Qwen3OmniMoeThinkerTextSparseMoeBlock
 
 
+def patch_youtu_vl_model(model: "PreTrainedModel") -> None:
+    original_forward = model.forward
+
+    def forward(self, *args, **kwargs):
+        outputs = original_forward(*args, **kwargs)
+        if "loss" not in outputs and "labels" in kwargs:
+            logits = outputs.get("logits")
+            labels = kwargs.get("labels")
+            if logits is not None and labels is not None:
+                shift_logits = logits[..., :-1, :].contiguous()
+                shift_labels = labels[..., 1:].contiguous()
+                loss_fct = torch.nn.CrossEntropyLoss()
+                loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))
+                outputs["loss"] = loss
+
+        return outputs
+
+    model.forward = MethodType(forward, model)
+
+
 def patch_tokenizer(tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments") -> None:
     if "PreTrainedTokenizerBase" not in str(tokenizer._pad.__func__):
         tokenizer._pad = MethodType(PreTrainedTokenizerBase._pad, tokenizer)
@@ -207,6 +227,9 @@ def patch_model(
     if getattr(model.config, "model_type", None) == "gemma3n":
         setattr(model_args, "disable_gradient_checkpointing", True)
 
+    if getattr(model.config, "model_type", None) == "youtu_vl":
+        patch_youtu_vl_model(model)
+
     prepare_model_for_training(model, model_args)
     autocast_projector_dtype(model, model_args)
     add_z3_leaf_module(model)