fix full/freeze tuning for mllm

Former-commit-id: df5860ddb593d5b82163a585d12160b41dbce0f3
hiyouga
2024-05-27 20:37:57 +08:00
parent 48ff9fb150
commit 3b28c003dd
6 changed files with 62 additions and 47 deletions


@@ -10,7 +10,6 @@ from ..extras.logging import get_logger
 from .utils.misc import find_all_linear_modules, find_expanded_modules
 from .utils.quantization import QuantizationMethod
 from .utils.unsloth import get_unsloth_peft_model, load_unsloth_peft_model
-from .utils.visual import filter_vision_tower_linear


 if TYPE_CHECKING:
@@ -53,21 +52,33 @@ def init_adapter(
     if finetuning_args.finetuning_type == "full" and is_trainable:
         logger.info("Fine-tuning method: Full")
-        if cast_trainable_params_to_fp32:
-            model = model.float()
-
-        if model_args.visual_inputs and hasattr(model, "vision_tower"):  # freeze vision model
-            model.vision_tower.requires_grad_(False)
-
-        if model_args.visual_inputs and hasattr(model, "language_model") and model_args.tune_mm_proj:  # freeze language model if only tune mm_proj
-            model.language_model.requires_grad_(False)
+        forbidden_modules = set()
+        if model_args.visual_inputs and finetuning_args.freeze_vision_tower:
+            forbidden_modules.add("vision_tower")
+
+        if model_args.visual_inputs and finetuning_args.train_mm_proj_only:
+            forbidden_modules.add("language_model")
+
+        for name, param in model.named_parameters():
+            if not any(forbidden_module in name for forbidden_module in forbidden_modules):
+                if cast_trainable_params_to_fp32:
+                    param.data = param.data.to(torch.float32)
+            else:
+                param.requires_grad_(False)

     if finetuning_args.finetuning_type == "freeze" and is_trainable:
         logger.info("Fine-tuning method: Freeze")
+        if model_args.visual_inputs:
+            config = model.config.text_config
+        else:
+            config = model.config
+
         num_layers = (
-            getattr(model.config, "num_hidden_layers", None)
-            or getattr(model.config, "num_layers", None)
-            or getattr(model.config, "n_layer", None)
+            getattr(config, "num_hidden_layers", None)
+            or getattr(config, "num_layers", None)
+            or getattr(config, "n_layer", None)
         )
         if not num_layers:
             raise ValueError("Current model does not support freeze tuning.")
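Note: the new full-tuning branch no longer calls `requires_grad_(False)` on model attributes directly; it freezes modules by substring matching on parameter names. A minimal runnable sketch of that logic, using a toy module in place of a real multimodal model (the names `ToyMultimodalModel` and `multi_modal_projector` are illustrative, not from the repository):

```python
import torch
import torch.nn as nn


class ToyMultimodalModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.vision_tower = nn.Linear(4, 4)        # stands in for the vision encoder
        self.multi_modal_projector = nn.Linear(4, 4)
        self.language_model = nn.Linear(4, 4)      # stands in for the LLM backbone


model = ToyMultimodalModel()

# Mirror the diff: collect module-name prefixes that must stay frozen,
# then decide per parameter by substring match on its qualified name.
forbidden_modules = {"vision_tower"}  # add "language_model" for train_mm_proj_only

for name, param in model.named_parameters():
    if not any(forbidden in name for forbidden in forbidden_modules):
        param.data = param.data.to(torch.float32)  # cast_trainable_params_to_fp32 branch
    else:
        param.requires_grad_(False)

print([n for n, p in model.named_parameters() if p.requires_grad])
# ['multi_modal_projector.weight', 'multi_modal_projector.bias',
#  'language_model.weight', 'language_model.bias']
```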
@@ -119,16 +130,19 @@ def init_adapter(
             trainable_layers.append(module_name)

+        forbidden_modules = set()
+        if model_args.visual_inputs and finetuning_args.freeze_vision_tower:
+            forbidden_modules.add("vision_tower")
+
         for name, param in model.named_parameters():
-            if any(trainable_layer in name for trainable_layer in trainable_layers):
+            if any(trainable_layer in name for trainable_layer in trainable_layers) and not any(
+                forbidden_module in name for forbidden_module in forbidden_modules
+            ):
                 if cast_trainable_params_to_fp32:
                     param.data = param.data.to(torch.float32)
             else:
                 param.requires_grad_(False)

-        if model_args.visual_inputs and hasattr(model, "vision_tower"):  # freeze vision model
-            model.vision_tower.requires_grad_(False)
-
         logger.info("Set trainable layers: {}".format(",".join(map(str, trainable_layer_ids))))

     if finetuning_args.finetuning_type == "lora":
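In the freeze branch, a parameter now stays trainable only if it matches one of the selected layers and is not inside a forbidden module. A tiny illustration of that combined check (the parameter names below are made up for the example):

```python
# Layer suffixes and forbidden prefixes as the freeze branch would build them.
trainable_layers = [".30.", ".31."]   # e.g. the last two decoder layers
forbidden_modules = {"vision_tower"}

names = [
    "language_model.model.layers.31.mlp.up_proj.weight",           # trainable
    "language_model.model.layers.0.mlp.up_proj.weight",            # frozen: not a trainable layer
    "vision_tower.encoder.layers.31.mlp.fc1.weight",               # frozen: forbidden module
]

for name in names:
    trainable = any(layer in name for layer in trainable_layers) and not any(
        module in name for module in forbidden_modules
    )
    print(name, "->", "trainable" if trainable else "frozen")
```

This is also why `num_layers` is now read from `config.text_config` when `visual_inputs` is set: for composite multimodal configs the layer count of the language backbone lives on the nested text config. A short sketch, assuming Hugging Face `transformers` and using `LlavaConfig` purely as an example of such a composite config:

```python
from transformers import LlavaConfig

multimodal = LlavaConfig()          # defaults to a LlamaConfig text backbone
config = multimodal.text_config     # what the diff selects when visual_inputs is set

num_layers = (
    getattr(config, "num_hidden_layers", None)
    or getattr(config, "num_layers", None)
    or getattr(config, "n_layer", None)
)
print(num_layers)  # layer count of the text backbone (32 with LlamaConfig defaults)
```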
@@ -177,15 +191,15 @@ def init_adapter(
         if is_trainable and adapter_to_resume is None:  # create new lora weights while training
             if len(finetuning_args.lora_target) == 1 and finetuning_args.lora_target[0] == "all":
-                target_modules = find_all_linear_modules(model)
+                target_modules = find_all_linear_modules(model, finetuning_args.freeze_vision_tower)
             else:
                 target_modules = finetuning_args.lora_target

             if finetuning_args.use_llama_pro:
-                target_modules = find_expanded_modules(model, target_modules, finetuning_args.num_layer_trainable)
+                target_modules = find_expanded_modules(model, target_modules, finetuning_args.freeze_trainable_layers)

-            if model_args.visual_inputs:
-                target_modules = filter_vision_tower_linear(target_modules)
+            if model_args.visual_inputs and finetuning_args.freeze_vision_tower:
+                target_modules = "^(?!.*vision_tower).*(?:{}).*".format("|".join(target_modules))

             if (
                 finetuning_args.use_dora
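The regex replaces the old `filter_vision_tower_linear` helper: PEFT treats a string `target_modules` as a regular expression matched against each module name, so the negative lookahead excludes every `vision_tower` submodule from LoRA injection. A small sketch of how that pattern behaves (the module names are examples, not taken from a specific model):

```python
import re

# Build the pattern exactly as in the diff.
target_modules = ["q_proj", "v_proj"]
pattern = "^(?!.*vision_tower).*(?:{}).*".format("|".join(target_modules))

candidates = [
    "language_model.model.layers.0.self_attn.q_proj",               # matched -> gets a LoRA adapter
    "vision_tower.vision_model.encoder.layers.0.self_attn.q_proj",  # excluded by the lookahead
    "multi_modal_projector.linear_1",                               # not a targeted projection
]

for name in candidates:
    print(name, "->", bool(re.fullmatch(pattern, name)))
```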