Mirror of https://github.com/hiyouga/LlamaFactory.git (synced 2026-02-03 08:53:38 +00:00)
[model] support youtu-vl model (#10152)
@@ -2159,6 +2159,40 @@ class LFMVLPlugin(BasePlugin):
         return messages


+@dataclass
+class YoutuVLPlugin(BasePlugin):
+    r"""Plugin for Youtu-VL vision-language models."""
+
+    vision_bos_token: str = "<|vision_start|>"
+    vision_eos_token: str = "<|vision_end|>"
+
+    @override
+    def process_messages(
+        self,
+        messages: list[dict[str, str]],
+        images: list["ImageInput"],
+        videos: list["VideoInput"],
+        audios: list["AudioInput"],
+        processor: Optional["MMProcessor"],
+    ) -> list[dict[str, str]]:
+        self._validate_input(processor, images, videos, audios)
+        self._validate_messages(messages, images, videos, audios)
+        messages = deepcopy(messages)
+
+        for message in messages:
+            content = message["content"]
+            content = content.replace(
+                IMAGE_PLACEHOLDER, f"{self.vision_bos_token}{self.image_token}{self.vision_eos_token}"
+            )
+            content = content.replace(
+                VIDEO_PLACEHOLDER, f"{self.vision_bos_token}{self.video_token}{self.vision_eos_token}"
+            )
+
+            message["content"] = content
+
+        return messages
+
+
 PLUGINS = {
     "base": BasePlugin,
     "ernie_vl": ErnieVLPlugin,
@@ -2181,6 +2215,7 @@ PLUGINS = {
     "qwen2_vl": Qwen2VLPlugin,
     "qwen3_vl": Qwen3VLPlugin,
     "video_llava": VideoLlavaPlugin,
+    "youtu_vl": YoutuVLPlugin,
 }
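A quick way to see what the new plugin does to a message content string: the sketch below inlines the plugin's token constants and applies the same replace. It assumes IMAGE_PLACEHOLDER is LlamaFactory's usual "<image>" marker; the "<|image_pad|>" token is the one supplied via get_mm_plugin in the template hunk further down.

# Standalone sketch of the placeholder rewriting YoutuVLPlugin performs.
# Assumption: IMAGE_PLACEHOLDER is the usual "<image>" marker.
IMAGE_PLACEHOLDER = "<image>"
vision_bos_token = "<|vision_start|>"
vision_eos_token = "<|vision_end|>"
image_token = "<|image_pad|>"  # passed in by get_mm_plugin (see template hunk)

content = f"{IMAGE_PLACEHOLDER} What is shown in this picture?"
content = content.replace(IMAGE_PLACEHOLDER, f"{vision_bos_token}{image_token}{vision_eos_token}")
print(content)
# <|vision_start|><|image_pad|><|vision_end|> What is shown in this picture?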
@@ -2146,6 +2146,19 @@ register_template(
 )


+register_template(
+    name="youtu_vl",
+    format_user=StringFormatter(
+        slots=["<|begin_of_text|>user\n{{content}}<|end_of_text|>\n<|begin_of_text|>assistant\n"]
+    ),
+    format_assistant=StringFormatter(slots=["{{content}}<|end_of_text|>\n"]),
+    format_system=StringFormatter(slots=["<|begin_of_text|>system\n{{content}}<|end_of_text|>\n"]),
+    default_system="You are a helpful assistant.",
+    stop_words=["<|end_of_text|>"],
+    mm_plugin=get_mm_plugin(name="youtu_vl", image_token="<|image_pad|>", video_token="<|video_pad|>"),
+)
+
+
 register_template(
     name="yuan",
     format_user=StringFormatter(slots=["{{content}}", {"token": "<sep>"}]),
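For reference, here is a hand-built rendering of one exchange under this template, assembled directly from the StringFormatter slots above rather than through LlamaFactory's template engine:

# Sketch: what one system + user + assistant turn looks like after formatting.
system = "You are a helpful assistant."  # the template's default_system
user = "<|vision_start|><|image_pad|><|vision_end|> Describe the image."
reply = "A cat sitting on a windowsill."

prompt = (
    f"<|begin_of_text|>system\n{system}<|end_of_text|>\n"
    f"<|begin_of_text|>user\n{user}<|end_of_text|>\n"
    f"<|begin_of_text|>assistant\n"
)
completion = f"{reply}<|end_of_text|>\n"  # <|end_of_text|> doubles as the stop word
print(prompt + completion)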
@@ -3375,6 +3375,18 @@ register_model_group(
 )


+register_model_group(
+    models={
+        "Youtu-VL-4B-Instruct": {
+            DownloadSource.DEFAULT: "tencent/Youtu-VL-4B-Instruct",
+            DownloadSource.MODELSCOPE: "Tencent-YouTu-Research/Youtu-VL-4B-Instruct",
+        },
+    },
+    template="youtu_vl",
+    multimodal=True,
+)
+
+
 register_model_group(
     models={
         "Yuan2-2B-Chat": {
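What this registration buys: LlamaFactory can resolve the display name to a hub repo per download source and pick the matching chat template. A toy lookup illustrating the mapping (these dicts are illustrative, not LlamaFactory's internal structures):

# Toy name -> repo and name -> template resolution.
SUPPORTED_MODELS = {
    "Youtu-VL-4B-Instruct": {
        "DEFAULT": "tencent/Youtu-VL-4B-Instruct",  # Hugging Face Hub
        "MODELSCOPE": "Tencent-YouTu-Research/Youtu-VL-4B-Instruct",
    }
}
DEFAULT_TEMPLATE = {"Youtu-VL-4B-Instruct": "youtu_vl"}

use_modelscope = False  # e.g. toggled by an environment flag in the real code
source = "MODELSCOPE" if use_modelscope else "DEFAULT"
print(SUPPORTED_MODELS["Youtu-VL-4B-Instruct"][source])  # tencent/Youtu-VL-4B-Instruct
print(DEFAULT_TEMPLATE["Youtu-VL-4B-Instruct"])          # youtu_vl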
@@ -57,6 +57,11 @@ def configure_attn_implementation(config: "PretrainedConfig", model_args: "Model
                 "Gemma-2 should use soft-capping attention, while the SDPA attention does not support it."
             )

+    if getattr(config, "model_type", None) in ["youtu", "youtu_vl"]:
+        if model_args.flash_attn in (AttentionFunction.AUTO, AttentionFunction.SDPA):
+            logger.warning_rank0("Youtu-VL does not support SDPA, forcing eager attention.")
+            model_args.flash_attn = AttentionFunction.DISABLED
+
     if model_args.flash_attn == AttentionFunction.AUTO:
         return
@@ -85,6 +90,13 @@ def configure_attn_implementation(config: "PretrainedConfig", model_args: "Model
     elif getattr(config, "model_type", None) == "kimi_vl":
         setattr(config.vision_config, "_attn_implementation", requested_attn_implementation)
         setattr(config.text_config, "_attn_implementation", requested_attn_implementation)
+    elif getattr(config, "model_type", None) == "youtu_vl":
+        setattr(config, "attn_implementation", requested_attn_implementation)
+        setattr(config, "_attn_implementation", requested_attn_implementation)
+        if hasattr(config, "vision_config"):
+            setattr(config.vision_config, "_attn_implementation", requested_attn_implementation)
+        if hasattr(config, "text_config"):
+            setattr(config.text_config, "_attn_implementation", requested_attn_implementation)
     else:
         setattr(config, "_attn_implementation", requested_attn_implementation)
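Taken together, the two attention hunks first downgrade AUTO/SDPA to DISABLED for youtu models, then propagate the resolved implementation to the top-level config and any nested vision/text configs. A self-contained sketch of the combined effect, using SimpleNamespace stand-ins for the config objects (the DISABLED -> "eager" mapping is an assumption based on how this function already handles AttentionFunction.DISABLED):

from types import SimpleNamespace

# Toy stand-ins for PretrainedConfig / ModelArguments.
config = SimpleNamespace(
    model_type="youtu_vl",
    vision_config=SimpleNamespace(),
    text_config=SimpleNamespace(),
)
flash_attn = "auto"

# First hunk: downgrade AUTO/SDPA for youtu models.
if config.model_type in ["youtu", "youtu_vl"] and flash_attn in ("auto", "sdpa"):
    flash_attn = "disabled"

# Assumed: DISABLED resolves to the eager implementation downstream.
requested = "eager" if flash_attn == "disabled" else flash_attn

# Second hunk: propagate to the top-level config and nested sub-configs.
config.attn_implementation = requested
config._attn_implementation = requested
for sub in ("vision_config", "text_config"):
    if hasattr(config, sub):
        setattr(getattr(config, sub), "_attn_implementation", requested)

print(config._attn_implementation, config.vision_config._attn_implementation)
# eager eager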
@@ -61,6 +61,26 @@ def patch_qwen3_omni_moe_thinker_text_sparse_moe_block():
     modeling_qwen3_omni_moe.Qwen3OmniMoeThinkerTextSparseMoeBlock = Qwen3OmniMoeThinkerTextSparseMoeBlock


+def patch_youtu_vl_model(model: "PreTrainedModel") -> None:
+    original_forward = model.forward
+
+    def forward(self, *args, **kwargs):
+        outputs = original_forward(*args, **kwargs)
+        if "loss" not in outputs and "labels" in kwargs:
+            logits = outputs.get("logits")
+            labels = kwargs.get("labels")
+            if logits is not None and labels is not None:
+                shift_logits = logits[..., :-1, :].contiguous()
+                shift_labels = labels[..., 1:].contiguous()
+                loss_fct = torch.nn.CrossEntropyLoss()
+                loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))
+                outputs["loss"] = loss
+
+        return outputs
+
+    model.forward = MethodType(forward, model)
+
+
 def patch_tokenizer(tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments") -> None:
     if "PreTrainedTokenizerBase" not in str(tokenizer._pad.__func__):
         tokenizer._pad = MethodType(PreTrainedTokenizerBase._pad, tokenizer)
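The shift-by-one in the fallback loss is standard next-token alignment: logits at position t score the token at position t+1, so the last logit and the first label are dropped. A minimal numeric check:

import torch

# Toy shapes: batch 1, sequence length 5, vocab 8.
vocab_size, seq_len = 8, 5
logits = torch.randn(1, seq_len, vocab_size)
labels = torch.randint(0, vocab_size, (1, seq_len))

shift_logits = logits[..., :-1, :].contiguous()  # positions 0..3 ...
shift_labels = labels[..., 1:].contiguous()      # ... scored against tokens 1..4
loss = torch.nn.CrossEntropyLoss()(  # ignore_index=-100 by default, matching HF-masked labels
    shift_logits.view(-1, vocab_size), shift_labels.view(-1)
)
print(loss)  # scalar cross-entropy over the four shifted positions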
@@ -207,6 +227,9 @@ def patch_model(
     if getattr(model.config, "model_type", None) == "gemma3n":
         setattr(model_args, "disable_gradient_checkpointing", True)

+    if getattr(model.config, "model_type", None) == "youtu_vl":
+        patch_youtu_vl_model(model)
+
     prepare_model_for_training(model, model_args)
     autocast_projector_dtype(model, model_args)
     add_z3_leaf_module(model)
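To see the MethodType wrapper end to end, here is a dummy model whose forward returns only logits, patched with the same closure pattern as patch_youtu_vl_model (DummyModel is illustrative only, standing in for remote modeling code that never computes a loss):

from types import MethodType, SimpleNamespace

import torch


class DummyModel:
    """Stand-in for a remote-code model whose forward never computes a loss."""

    config = SimpleNamespace(vocab_size=8)

    def forward(self, input_ids=None, labels=None):
        batch, seqlen = input_ids.shape
        return {"logits": torch.randn(batch, seqlen, self.config.vocab_size)}


model = DummyModel()
original_forward = model.forward  # bound method, captured exactly as in the patch


def forward(self, *args, **kwargs):
    outputs = original_forward(*args, **kwargs)
    if "loss" not in outputs and "labels" in kwargs:
        logits, labels = outputs.get("logits"), kwargs.get("labels")
        if logits is not None and labels is not None:
            loss = torch.nn.CrossEntropyLoss()(
                logits[..., :-1, :].contiguous().view(-1, self.config.vocab_size),
                labels[..., 1:].contiguous().view(-1),
            )
            outputs["loss"] = loss
    return outputs


model.forward = MethodType(forward, model)  # rebind forward on this instance
ids = torch.randint(0, 8, (1, 6))
print(model.forward(input_ids=ids, labels=ids)["loss"])  # loss is now present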