[model] support gemma3 (#7273)
--- a/src/llamafactory/model/loader.py
+++ b/src/llamafactory/model/loader.py
@@ -19,6 +19,7 @@ import torch
 from transformers import (
     AutoConfig,
     AutoModelForCausalLM,
+    AutoModelForImageTextToText,
     AutoModelForSeq2SeqLM,
     AutoModelForVision2Seq,
     AutoProcessor,
@@ -72,7 +73,6 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
     Note: including inplace operation of model_args.
     """
     init_kwargs = _get_init_kwargs(model_args)
-    config = load_config(model_args)
     try:
         tokenizer = AutoTokenizer.from_pretrained(
             model_args.model_name_or_path,
@@ -94,7 +94,7 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
     patch_tokenizer(tokenizer, model_args)
     try:
         processor = AutoProcessor.from_pretrained(model_args.model_name_or_path, **init_kwargs)
-        patch_processor(processor, config, tokenizer, model_args)
+        patch_processor(processor, tokenizer, model_args)
     except Exception as e:
         logger.debug(f"Processor was not found: {e}.")
         processor = None
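Taken together with the hunk above, `load_tokenizer` no longer materializes the model config just to patch the processor. A hedged usage sketch of the updated loader (the checkpoint id is illustrative, and the exact `ModelArguments` fields may differ):

```python
from llamafactory.hparams import ModelArguments
from llamafactory.model import load_tokenizer

# Illustrative repo id; any gemma3 checkpoint would exercise the new path.
model_args = ModelArguments(model_name_or_path="google/gemma-3-4b-it")
tokenizer_module = load_tokenizer(model_args)
tokenizer = tokenizer_module["tokenizer"]
processor = tokenizer_module["processor"]  # None when no processor is shipped
```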
@@ -141,9 +141,11 @@ def load_model(
         if model_args.mixture_of_depths == "load":
             model = load_mod_pretrained_model(**init_kwargs)
         else:
-            if type(config) in AutoModelForVision2Seq._model_mapping.keys():  # assume built-in models
+            if type(config) in AutoModelForVision2Seq._model_mapping.keys():  # image-text
                 load_class = AutoModelForVision2Seq
-            elif type(config) in AutoModelForSeq2SeqLM._model_mapping.keys():
+            elif type(config) in AutoModelForImageTextToText._model_mapping.keys():  # image-text
+                load_class = AutoModelForImageTextToText
+            elif type(config) in AutoModelForSeq2SeqLM._model_mapping.keys():  # audio-text
                 load_class = AutoModelForSeq2SeqLM
             else:
                 load_class = AutoModelForCausalLM
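Branch order matters here: `AutoModelForVision2Seq` is checked first, so models registered in both mappings keep their previous load class, while Gemma3 (registered only under image-text-to-text) takes the new branch. A minimal sketch, assuming a transformers release that includes Gemma3 and an illustrative repo id:

```python
from transformers import AutoConfig, AutoModelForImageTextToText, AutoModelForVision2Seq

# Expectation (hedged): Gemma3Config appears only in the image-text-to-text mapping.
config = AutoConfig.from_pretrained("google/gemma-3-4b-it")
assert type(config) not in AutoModelForVision2Seq._model_mapping.keys()
assert type(config) in AutoModelForImageTextToText._model_mapping.keys()
model = AutoModelForImageTextToText.from_config(config)  # or .from_pretrained(...)
```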
--- a/src/llamafactory/model/model_utils/visual.py
+++ b/src/llamafactory/model/model_utils/visual.py
@@ -1,4 +1,4 @@
-# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
 #
 # This code is inspired by the HuggingFace's Transformers library.
 # https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llava/modeling_llava.py
@@ -28,7 +28,7 @@ from ...extras import logging


 if TYPE_CHECKING:
-    from transformers import LlavaConfig, PretrainedConfig, PreTrainedModel, ProcessorMixin
+    from transformers import LlavaConfig, PretrainedConfig, PreTrainedModel

     from ...hparams import FinetuningArguments, ModelArguments
@@ -62,6 +62,16 @@ def _register_composite_model(
     language_model_keys: Optional[list[str]] = None,
     lora_conflict_keys: Optional[list[str]] = None,
 ):
+    r"""Register a new composite model.
+
+    Args:
+        model_type: model type
+        projector_key: multi_modal_projector
+        vision_model_keys: vision_tower
+        language_model_keys: language_model
+        lora_conflict_keys: None
+
+    """
     COMPOSITE_MODELS[model_type] = CompositeModel(
         model_type=model_type,
         projector_key=projector_key or "multi_modal_projector",
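For illustration, the call below spells out the defaults the new docstring describes; the `my_vlm` model type is hypothetical, and omitting any keyword argument yields the same result:

```python
# Hypothetical registration making the documented defaults explicit.
_register_composite_model(
    model_type="my_vlm",                     # hypothetical model type
    projector_key="multi_modal_projector",   # default projector key
    vision_model_keys=["vision_tower"],      # default vision tower keys
    language_model_keys=["language_model"],  # default language model keys
    lora_conflict_keys=None,                 # no LoRA-conflicting modules
)
```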
@@ -169,39 +179,10 @@ def get_forbidden_modules(config: "PretrainedConfig", finetuning_args: "Finetuni
     return forbidden_modules


-def get_image_seqlen(config: "PretrainedConfig") -> int:
-    r"""Compute the number of special tokens per image."""
-    model_type = getattr(config, "model_type", None)
-    if model_type == "llava":
-        image_seqlen = (config.vision_config.image_size // config.vision_config.patch_size) ** 2
-        if getattr(config, "vision_feature_select_strategy", "default") == "full":  # add [CLS] token
-            image_seqlen += 1
-    elif model_type == "paligemma":
-        image_seqlen = config.vision_config.num_image_tokens
-    else:
-        image_seqlen = -1
-
-    return image_seqlen
-
-
-def get_patch_size(config: "PretrainedConfig", processor: "ProcessorMixin") -> int:
-    r"""Compute the patch size of the vit."""
-    patch_size = getattr(config.vision_config, "patch_size", getattr(processor, "patch_size", -1))
-    return patch_size
-
-
-def get_vision_feature_select_strategy(config: "PretrainedConfig", processor: "ProcessorMixin") -> int:
-    r"""Get the vision_feature_select_strategy."""
-    vision_feature_select_strategy = getattr(
-        config, "vision_feature_select_strategy", getattr(processor, "vision_feature_select_strategy", "default")
-    )
-    return vision_feature_select_strategy
-
-
 def patch_target_modules(
     model: "PreTrainedModel", finetuning_args: "FinetuningArguments", target_modules: Sequence[str]
 ) -> list[str]:
-    r"""Freezes vision tower for VLM LoRA tuning."""
+    r"""Freeze vision tower for VLM LoRA tuning."""
     model_type = getattr(model.config, "model_type", None)
     if model_type in COMPOSITE_MODELS:
         forbidden_modules = get_forbidden_modules(model.config, finetuning_args)
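`patch_target_modules` keeps vision-side modules out of LoRA for registered composite models. A hedged sketch of the idea, not the repo's exact implementation: fold the forbidden prefixes into a negative lookahead so module names under the vision tower never match the LoRA targets:

```python
import re

# Sketch only: exclude modules under forbidden (vision-side) prefixes while
# still matching the requested LoRA target names.
forbidden_modules = {"vision_tower", "multi_modal_projector"}
target_modules = ["q_proj", "v_proj"]
pattern = r"^(?!.*(?:{})).*(?:{})".format("|".join(forbidden_modules), "|".join(target_modules))

print(bool(re.search(pattern, "language_model.layers.0.self_attn.q_proj")))  # True
print(bool(re.search(pattern, "vision_tower.blocks.0.attn.q_proj")))         # False
```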
@@ -218,6 +199,11 @@ def patch_target_modules(
     return target_modules


+_register_composite_model(
+    model_type="gemma3",
+)
+
+
 _register_composite_model(
     model_type="llava",
 )
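Since only `model_type` is passed, the gemma3 entry picks up the defaults documented above. Assuming `CompositeModel` exposes its fields as attributes, the registration resolves to something like:

```python
# Assumed inspection of the registry entry created above.
info = COMPOSITE_MODELS["gemma3"]
print(info.projector_key)        # "multi_modal_projector"
print(info.vision_model_keys)    # ["vision_tower"]
print(info.language_model_keys)  # ["language_model"]
```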
--- a/src/llamafactory/model/patcher.py
+++ b/src/llamafactory/model/patcher.py
@@ -33,13 +33,7 @@ from .model_utils.packing import configure_packing
 from .model_utils.quantization import configure_quantization
 from .model_utils.rope import configure_rope
 from .model_utils.valuehead import prepare_valuehead_model
-from .model_utils.visual import (
-    autocast_projector_dtype,
-    configure_visual_model,
-    get_image_seqlen,
-    get_patch_size,
-    get_vision_feature_select_strategy,
-)
+from .model_utils.visual import autocast_projector_dtype, configure_visual_model


 if TYPE_CHECKING:
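The helpers dropped here (`get_image_seqlen`, `get_patch_size`, `get_vision_feature_select_strategy`) duplicated values that recent multimodal processors already carry themselves; that is the assumption behind removing them. A hedged check (the repo id is illustrative, and either attribute may be absent for a given model):

```python
from transformers import AutoProcessor

# Assumption: modern processors expose these fields directly, so the
# per-model config helpers are no longer needed.
processor = AutoProcessor.from_pretrained("google/gemma-3-4b-it")
print(getattr(processor, "patch_size", None))
print(getattr(processor, "vision_feature_select_strategy", None))
```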
@@ -72,21 +66,16 @@ def patch_tokenizer(tokenizer: "PreTrainedTokenizer", model_args: "ModelArgument

 def patch_processor(
     processor: "ProcessorMixin",
-    config: "PretrainedConfig",
     tokenizer: "PreTrainedTokenizer",
     model_args: "ModelArguments",
 ) -> None:
     setattr(processor, "tokenizer", tokenizer)
-    if getattr(config, "vision_config", None) is not None:  # visual models
-        setattr(processor, "image_seqlen", get_image_seqlen(config))
-        setattr(processor, "patch_size", get_patch_size(config, processor))
-        setattr(processor, "image_max_pixels", model_args.image_max_pixels)
-        setattr(processor, "image_min_pixels", model_args.image_min_pixels)
-        setattr(processor, "video_max_pixels", model_args.video_max_pixels)
-        setattr(processor, "video_min_pixels", model_args.video_min_pixels)
-        setattr(processor, "video_fps", model_args.video_fps)
-        setattr(processor, "video_maxlen", model_args.video_maxlen)
-        setattr(processor, "vision_feature_select_strategy", get_vision_feature_select_strategy(config, processor))
+    setattr(processor, "image_max_pixels", model_args.image_max_pixels)
+    setattr(processor, "image_min_pixels", model_args.image_min_pixels)
+    setattr(processor, "video_max_pixels", model_args.video_max_pixels)
+    setattr(processor, "video_min_pixels", model_args.video_min_pixels)
+    setattr(processor, "video_fps", model_args.video_fps)
+    setattr(processor, "video_maxlen", model_args.video_maxlen)


 def patch_config(
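With the `config` parameter gone, the media-related attributes are set on every processor unconditionally. A hedged usage sketch, reusing `tokenizer` and `model_args` from the earlier `load_tokenizer`-style setup:

```python
from transformers import AutoProcessor

# Sketch only: patch_processor no longer needs the model config.
processor = AutoProcessor.from_pretrained("google/gemma-3-4b-it")  # illustrative id
patch_processor(processor, tokenizer, model_args)
assert processor.image_max_pixels == model_args.image_max_pixels
assert processor.video_fps == model_args.video_fps
```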