@@ -16,7 +16,7 @@
 # limitations under the License.

 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Set, Tuple, Union
+from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Set, Tuple

 import torch
 import transformers
@@ -42,6 +42,7 @@ class CompositeModel:
     projector_key: str
     vision_model_keys: List[str]
     language_model_keys: List[str]
+    lora_conflict_keys: List[str]

     def get_projector(self, module: "torch.nn.Module") -> "torch.nn.Module":
         for key in self.projector_key.split("."):
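
For reference, `get_projector` (partially shown above) resolves the dotted `projector_key` by walking one attribute per segment. A minimal standalone sketch of that traversal, using a toy module tree and the `visual.merger` key from the qwen2_vl registrations below (the toy class names are made up):

```python
import torch


class Merger(torch.nn.Module):  # hypothetical stand-in for a real projector
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(4, 4)


class Visual(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.merger = Merger()


class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.visual = Visual()


module = ToyModel()
# Same loop as CompositeModel.get_projector: follow each dotted segment.
for key in "visual.merger".split("."):
    module = getattr(module, key)

print(type(module).__name__)  # Merger
```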
@@ -58,15 +59,14 @@ def _register_composite_model(
     projector_key: Optional[str] = None,
     vision_model_keys: Optional[List[str]] = None,
     language_model_keys: Optional[List[str]] = None,
+    lora_conflict_keys: Optional[List[str]] = None,
 ):
-    projector_key = projector_key or "multi_modal_projector"
-    vision_model_keys = vision_model_keys or ["vision_tower"]
-    language_model_keys = language_model_keys or ["language_model"]
     COMPOSITE_MODELS[model_type] = CompositeModel(
         model_type=model_type,
-        projector_key=projector_key,
-        vision_model_keys=vision_model_keys,
-        language_model_keys=language_model_keys,
+        projector_key=projector_key or "multi_modal_projector",
+        vision_model_keys=vision_model_keys or ["vision_tower"],
+        language_model_keys=language_model_keys or ["language_model"],
+        lora_conflict_keys=lora_conflict_keys or [],
     )

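After this change, the `or`-fallbacks are applied inline in the constructor call rather than via local reassignments, and the new `lora_conflict_keys` falls back to an empty list. A quick sketch of the resulting behavior, assuming it runs in this module's namespace (`toy_vlm` is a made-up model type):

```python
# Hypothetical registration: only the model type is given, so every other
# field takes its inline default at construction time.
_register_composite_model(model_type="toy_vlm")

entry = COMPOSITE_MODELS["toy_vlm"]
assert entry.projector_key == "multi_modal_projector"
assert entry.vision_model_keys == ["vision_tower"]
assert entry.language_model_keys == ["language_model"]
assert entry.lora_conflict_keys == []  # new field, defaults to empty
```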
@@ -210,29 +210,25 @@ def get_vision_feature_select_strategy(config: "PretrainedConfig", processor: "P


 def patch_target_modules(
-    config: "PretrainedConfig", finetuning_args: "FinetuningArguments", target_modules: Sequence[str]
-) -> Union[str, List[str]]:
+    model: "PreTrainedModel", finetuning_args: "FinetuningArguments", target_modules: Sequence[str]
+) -> List[str]:
     r"""
     Freezes vision tower for VLM LoRA tuning.
     """
-    model_type = getattr(config, "model_type", None)
-    vit_model_type = getattr(getattr(config, "vision_config", None), "model_type", None)
-    if finetuning_args.freeze_vision_tower:
-        if model_type in COMPOSITE_MODELS:
-            vision_model_keys = COMPOSITE_MODELS[model_type].vision_model_keys
-            logger.info_rank0(f"Set vision model not trainable: {vision_model_keys}.")
-            vision_model_keys = "|".join(vision_model_keys)
-            target_modules = "|".join(target_modules)
-            return f"^(?!.*{vision_model_keys}).*(?:{target_modules}).*"
-        else:
-            return target_modules
+    model_type = getattr(model.config, "model_type", None)
+    if model_type in COMPOSITE_MODELS:
+        forbidden_modules = get_forbidden_modules(model.config, finetuning_args)
+        forbidden_modules.update(COMPOSITE_MODELS[model_type].lora_conflict_keys)
+        module_names = []
+        for name, _ in model.named_modules():
+            if any(target_module in name for target_module in target_modules) and not any(
+                forbidden_module in name for forbidden_module in forbidden_modules
+            ):
+                module_names.append(name)
+
+        return module_names
     else:
-        if model_type == "qwen2_vl":  # avoid attaching lora to Conv3D layer
-            return "^(?!.*patch_embed).*(?:{}).*".format("|".join(target_modules))
-        elif vit_model_type == "pixtral":
-            return "^(?!.*patch_conv).*(?:{}).*".format("|".join(target_modules))
-        else:
-            return target_modules
+        return target_modules


 _register_composite_model(
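
The rewritten `patch_target_modules` stops emitting a negative-lookahead regex and instead enumerates concrete module names via `named_modules()`, filtering them against the forbidden set (which now also absorbs `lora_conflict_keys`); the hardcoded qwen2_vl patch-embedding exclusion, for instance, moves into that model's `lora_conflict_keys` registration below. A self-contained sketch of the filtering step on a toy model, with `get_forbidden_modules` stubbed out as a literal set:

```python
import torch

# Toy VLM layout: LoRA should attach to the language model's projections
# but skip everything under the (frozen) vision tower.
model = torch.nn.ModuleDict({
    "vision_tower": torch.nn.ModuleDict({"q_proj": torch.nn.Linear(8, 8)}),
    "language_model": torch.nn.ModuleDict({
        "q_proj": torch.nn.Linear(8, 8),
        "v_proj": torch.nn.Linear(8, 8),
    }),
})

target_modules = ["q_proj", "v_proj"]
forbidden_modules = {"vision_tower"}  # stub for get_forbidden_modules + lora_conflict_keys

module_names = []
for name, _ in model.named_modules():
    if any(t in name for t in target_modules) and not any(
        f in name for f in forbidden_modules
    ):
        module_names.append(name)

print(module_names)  # ['language_model.q_proj', 'language_model.v_proj']
```

Returning explicit names instead of a regex also lets callers inspect exactly which layers will receive adapters before the list is handed off.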
@@ -252,6 +248,7 @@ _register_composite_model(

 _register_composite_model(
     model_type="minicpmv",
+    projector_key="resampler",
     vision_model_keys=["vpm"],
     language_model_keys=["llm"],
 )
@@ -259,8 +256,10 @@ _register_composite_model(

 _register_composite_model(
     model_type="minicpmo",
-    vision_model_keys=["vpm", "apm", "resampler", "tts"],
+    projector_key="resampler",
+    vision_model_keys=["vpm", "apm", "audio_avg_pooler", "audio_projection_layer", "tts"],
     language_model_keys=["llm"],
+    lora_conflict_keys=["audio_projection_layer"],
 )

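Note that `audio_projection_layer` now appears in both lists: grouping it under `vision_model_keys` lets `freeze_vision_tower` freeze it together with the rest of the non-LLM stack, while `lora_conflict_keys` keeps LoRA adapters off it even when that stack stays trainable. A minimal check, assuming this module's namespace:

```python
spec = COMPOSITE_MODELS["minicpmo"]
assert "audio_projection_layer" in spec.vision_model_keys   # frozen with the stack
assert "audio_projection_layer" in spec.lora_conflict_keys  # never receives LoRA
```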
@@ -291,6 +290,7 @@ _register_composite_model(
     projector_key="visual.merger",
     vision_model_keys=["visual.patch_embed", "visual.blocks"],
     language_model_keys=["model", "lm_head"],
+    lora_conflict_keys=["patch_embed"],
 )

@@ -299,4 +299,5 @@ _register_composite_model(
     projector_key="visual.merger",
     vision_model_keys=["visual.patch_embed", "visual.blocks"],
     language_model_keys=["model", "lm_head"],
+    lora_conflict_keys=["patch_embed"],
 )
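
For context on how the new return value is consumed downstream: a list of explicit module names can be fed directly to peft, whose `LoraConfig.target_modules` accepts either a regex string or a list of names. A usage sketch, not code from this patch; `model` and `finetuning_args` are assumed to be in scope:

```python
from peft import LoraConfig, get_peft_model

# patch_target_modules now returns concrete names such as
# "language_model.q_proj" rather than a lookahead regex.
lora_targets = patch_target_modules(model, finetuning_args, ["q_proj", "v_proj"])
lora_config = LoraConfig(r=8, lora_alpha=16, target_modules=lora_targets)
model = get_peft_model(model, lora_config)
```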