fix full/freeze tuning for mllm
Former-commit-id: df5860ddb593d5b82163a585d12160b41dbce0f3
This commit is contained in:
@@ -10,7 +10,6 @@ from ..extras.logging import get_logger
|
||||
from .utils.misc import find_all_linear_modules, find_expanded_modules
|
||||
from .utils.quantization import QuantizationMethod
|
||||
from .utils.unsloth import get_unsloth_peft_model, load_unsloth_peft_model
|
||||
from .utils.visual import filter_vision_tower_linear
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -53,21 +52,33 @@ def init_adapter(
|
||||
|
||||
if finetuning_args.finetuning_type == "full" and is_trainable:
|
||||
logger.info("Fine-tuning method: Full")
|
||||
if cast_trainable_params_to_fp32:
|
||||
model = model.float()
|
||||
|
||||
if model_args.visual_inputs and hasattr(model, "vision_tower"): # freeze vision model
|
||||
model.vision_tower.requires_grad_(False)
|
||||
forbidden_modules = set()
|
||||
if model_args.visual_inputs and finetuning_args.freeze_vision_tower:
|
||||
forbidden_modules.add("vision_tower")
|
||||
|
||||
if model_args.visual_inputs and hasattr(model, "language_model") and model_args.tune_mm_proj: # freeze language model if only tune mm_proj
|
||||
model.language_model.requires_grad_(False)
|
||||
if model_args.visual_inputs and finetuning_args.train_mm_proj_only:
|
||||
forbidden_modules.add("language_model")
|
||||
|
||||
for name, param in model.named_parameters():
|
||||
if not any(forbidden_module in name for forbidden_module in forbidden_modules):
|
||||
if cast_trainable_params_to_fp32:
|
||||
param.data = param.data.to(torch.float32)
|
||||
else:
|
||||
param.requires_grad_(False)
|
||||
|
||||
if finetuning_args.finetuning_type == "freeze" and is_trainable:
|
||||
logger.info("Fine-tuning method: Freeze")
|
||||
|
||||
if model_args.visual_inputs:
|
||||
config = model.config.text_config
|
||||
else:
|
||||
config = model.config
|
||||
|
||||
num_layers = (
|
||||
getattr(model.config, "num_hidden_layers", None)
|
||||
or getattr(model.config, "num_layers", None)
|
||||
or getattr(model.config, "n_layer", None)
|
||||
getattr(config, "num_hidden_layers", None)
|
||||
or getattr(config, "num_layers", None)
|
||||
or getattr(config, "n_layer", None)
|
||||
)
|
||||
if not num_layers:
|
||||
raise ValueError("Current model does not support freeze tuning.")
|
||||
@@ -119,16 +130,19 @@ def init_adapter(
|
||||
|
||||
trainable_layers.append(module_name)
|
||||
|
||||
forbidden_modules = set()
|
||||
if model_args.visual_inputs and finetuning_args.freeze_vision_tower:
|
||||
forbidden_modules.add("vision_tower")
|
||||
|
||||
for name, param in model.named_parameters():
|
||||
if any(trainable_layer in name for trainable_layer in trainable_layers):
|
||||
if any(trainable_layer in name for trainable_layer in trainable_layers) and not any(
|
||||
forbidden_module in name for forbidden_module in forbidden_modules
|
||||
):
|
||||
if cast_trainable_params_to_fp32:
|
||||
param.data = param.data.to(torch.float32)
|
||||
else:
|
||||
param.requires_grad_(False)
|
||||
|
||||
if model_args.visual_inputs and hasattr(model, "vision_tower"): # freeze vision model
|
||||
model.vision_tower.requires_grad_(False)
|
||||
|
||||
logger.info("Set trainable layers: {}".format(",".join(map(str, trainable_layer_ids))))
|
||||
|
||||
if finetuning_args.finetuning_type == "lora":
|
||||
@@ -177,15 +191,15 @@ def init_adapter(
|
||||
|
||||
if is_trainable and adapter_to_resume is None: # create new lora weights while training
|
||||
if len(finetuning_args.lora_target) == 1 and finetuning_args.lora_target[0] == "all":
|
||||
target_modules = find_all_linear_modules(model)
|
||||
target_modules = find_all_linear_modules(model, finetuning_args.freeze_vision_tower)
|
||||
else:
|
||||
target_modules = finetuning_args.lora_target
|
||||
|
||||
if finetuning_args.use_llama_pro:
|
||||
target_modules = find_expanded_modules(model, target_modules, finetuning_args.num_layer_trainable)
|
||||
target_modules = find_expanded_modules(model, target_modules, finetuning_args.freeze_trainable_layers)
|
||||
|
||||
if model_args.visual_inputs:
|
||||
target_modules = filter_vision_tower_linear(target_modules)
|
||||
if model_args.visual_inputs and finetuning_args.freeze_vision_tower:
|
||||
target_modules = "^(?!.*vision_tower).*(?:{}).*".format("|".join(target_modules))
|
||||
|
||||
if (
|
||||
finetuning_args.use_dora
|
||||
|
||||
Reference in New Issue
Block a user