[model] support gemma3 (#7273)

hoshi-hiyouga
2025-03-13 01:35:23 +08:00
committed by GitHub
parent e6159ad730
commit 4b9d8da5a4
9 changed files with 356 additions and 274 deletions


@@ -1,4 +1,4 @@
-# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
 #
 # This code is inspired by the HuggingFace's Transformers library.
 # https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llava/modeling_llava.py
@@ -28,7 +28,7 @@ from ...extras import logging
 if TYPE_CHECKING:
-    from transformers import LlavaConfig, PretrainedConfig, PreTrainedModel, ProcessorMixin
+    from transformers import LlavaConfig, PretrainedConfig, PreTrainedModel
     from ...hparams import FinetuningArguments, ModelArguments
@@ -62,6 +62,16 @@ def _register_composite_model(
     language_model_keys: Optional[list[str]] = None,
     lora_conflict_keys: Optional[list[str]] = None,
 ):
+    r"""Register a new composite model.
+
+    Args:
+        model_type: model type
+        projector_key: multi_modal_projector
+        vision_model_keys: vision_tower
+        language_model_keys: language_model
+        lora_conflict_keys: None
+
+    """
     COMPOSITE_MODELS[model_type] = CompositeModel(
         model_type=model_type,
         projector_key=projector_key or "multi_modal_projector",
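The hunk above documents the heart of the mechanism: every supported VLM is described once in a COMPOSITE_MODELS registry that records which submodules act as the multimodal projector, the vision tower, and the language model. Below is a minimal, self-contained sketch of that pattern; the CompositeModel field types and the vision_tower/language_model fallbacks are inferred from the call site and the new docstring in this diff, not copied from the library.

# Minimal sketch of the registry pattern shown in the hunk above (field types
# assumed; the real CompositeModel in LLaMA-Factory may carry more logic).
from dataclasses import dataclass
from typing import Optional


@dataclass
class CompositeModel:
    model_type: str
    projector_key: str
    vision_model_keys: list[str]
    language_model_keys: list[str]
    lora_conflict_keys: list[str]


COMPOSITE_MODELS: dict[str, CompositeModel] = {}


def register_composite_model(
    model_type: str,
    projector_key: Optional[str] = None,
    vision_model_keys: Optional[list[str]] = None,
    language_model_keys: Optional[list[str]] = None,
    lora_conflict_keys: Optional[list[str]] = None,
) -> None:
    # Fall back to the conventional HF VLM submodule names when unspecified,
    # matching the defaults documented in the new docstring.
    COMPOSITE_MODELS[model_type] = CompositeModel(
        model_type=model_type,
        projector_key=projector_key or "multi_modal_projector",
        vision_model_keys=vision_model_keys or ["vision_tower"],
        language_model_keys=language_model_keys or ["language_model"],
        lora_conflict_keys=lora_conflict_keys or [],
    )


register_composite_model(model_type="gemma3")  # all defaults apply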
@@ -169,39 +179,10 @@ def get_forbidden_modules(config: "PretrainedConfig", finetuning_args: "FinetuningArguments") -> set[str]:
     return forbidden_modules
 
 
-def get_image_seqlen(config: "PretrainedConfig") -> int:
-    r"""Compute the number of special tokens per image."""
-    model_type = getattr(config, "model_type", None)
-    if model_type == "llava":
-        image_seqlen = (config.vision_config.image_size // config.vision_config.patch_size) ** 2
-        if getattr(config, "vision_feature_select_strategy", "default") == "full":  # add [CLS] token
-            image_seqlen += 1
-    elif model_type == "paligemma":
-        image_seqlen = config.vision_config.num_image_tokens
-    else:
-        image_seqlen = -1
-
-    return image_seqlen
-
-
-def get_patch_size(config: "PretrainedConfig", processor: "ProcessorMixin") -> int:
-    r"""Compute the patch size of the vit."""
-    patch_size = getattr(config.vision_config, "patch_size", getattr(processor, "patch_size", -1))
-    return patch_size
-
-
-def get_vision_feature_select_strategy(config: "PretrainedConfig", processor: "ProcessorMixin") -> int:
-    r"""Get the vision_feature_select_strategy."""
-    vision_feature_select_strategy = getattr(
-        config, "vision_feature_select_strategy", getattr(processor, "vision_feature_select_strategy", "default")
-    )
-    return vision_feature_select_strategy
-
-
 def patch_target_modules(
     model: "PreTrainedModel", finetuning_args: "FinetuningArguments", target_modules: Sequence[str]
 ) -> list[str]:
-    r"""Freezes vision tower for VLM LoRA tuning."""
+    r"""Freeze vision tower for VLM LoRA tuning."""
     model_type = getattr(model.config, "model_type", None)
     if model_type in COMPOSITE_MODELS:
         forbidden_modules = get_forbidden_modules(model.config, finetuning_args)
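For reference, the deleted get_image_seqlen computed the number of image placeholder tokens directly from the vision config. Here is a worked example of that arithmetic; the 336-pixel image size and 14-pixel patch size match the CLIP ViT-L/14-336 tower used by LLaVA-1.5 and are chosen purely for illustration.

# Worked example of the arithmetic in the deleted get_image_seqlen helper
# (336 and 14 are illustrative CLIP ViT-L/14-336 values, not from this diff).
image_size, patch_size = 336, 14
vision_feature_select_strategy = "default"

image_seqlen = (image_size // patch_size) ** 2  # 24 ** 2 = 576 patch tokens
if vision_feature_select_strategy == "full":  # "full" also keeps the [CLS] token
    image_seqlen += 1

print(image_seqlen)  # 576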
@@ -218,6 +199,11 @@ def patch_target_modules(
     return target_modules
 
 
+_register_composite_model(
+    model_type="gemma3",
+)
+
+
 _register_composite_model(
     model_type="llava",
 )
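With gemma3 registered, patch_target_modules can keep LoRA targets out of the frozen vision side. The sketch below illustrates the kind of name filtering implied by the get_forbidden_modules call above; the function name and matching rules here are hypothetical simplifications, not LLaMA-Factory's exact code.

# Hypothetical sketch of the filtering implied by patch_target_modules:
# keep modules that match a LoRA target name, drop those under a forbidden
# (vision-side) prefix. Names and matching rules are simplified assumptions.
def filter_lora_targets(
    module_names: list[str],      # e.g. names from model.named_modules()
    target_modules: list[str],    # e.g. ["q_proj", "v_proj"]
    forbidden_modules: set[str],  # e.g. {"vision_tower", "multi_modal_projector"}
) -> list[str]:
    return [
        name
        for name in module_names
        if any(target in name for target in target_modules)
        and not any(forbidden in name for forbidden in forbidden_modules)
    ]


# The vision tower's q_proj is excluded; the language model's q_proj is kept,
# so LoRA never touches the vision encoder during VLM tuning.
names = ["vision_tower.blocks.0.attn.q_proj", "language_model.layers.0.self_attn.q_proj"]
print(filter_lora_targets(names, ["q_proj"], {"vision_tower"}))
# -> ['language_model.layers.0.self_attn.q_proj']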