remove visual_inputs, fix qlora
Former-commit-id: be30c01c4f1482520ece770bd54c6a4837c26f0a
@@ -24,6 +24,7 @@ from ..extras.logging import get_logger
 from .model_utils.misc import find_all_linear_modules, find_expanded_modules
 from .model_utils.quantization import QuantizationMethod
 from .model_utils.unsloth import get_unsloth_peft_model, load_unsloth_peft_model
+from .model_utils.visual import get_forbidden_modules, patch_target_modules


 if TYPE_CHECKING:
@@ -37,7 +38,6 @@ logger = get_logger(__name__)

 def _setup_full_tuning(
     model: "PreTrainedModel",
-    model_args: "ModelArguments",
     finetuning_args: "FinetuningArguments",
     is_trainable: bool,
     cast_trainable_params_to_fp32: bool,
@@ -46,13 +46,7 @@ def _setup_full_tuning(
         return

     logger.info("Fine-tuning method: Full")
-    forbidden_modules = set()
-    if model_args.visual_inputs and finetuning_args.freeze_vision_tower:
-        forbidden_modules.add("vision_tower")
-
-    if model_args.visual_inputs and finetuning_args.train_mm_proj_only:
-        forbidden_modules.add("language_model")
-
+    forbidden_modules = get_forbidden_modules(model.config, finetuning_args)
     for name, param in model.named_parameters():
         if not any(forbidden_module in name for forbidden_module in forbidden_modules):
             if cast_trainable_params_to_fp32:
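The hunk above routes the per-architecture module names through get_forbidden_modules instead of checking model_args.visual_inputs inline. A minimal sketch of how the resulting set is applied during full tuning; the helper name apply_full_tuning is made up for illustration, and the loop mirrors the unchanged context lines of the hunk:

import torch

def apply_full_tuning(model: torch.nn.Module, forbidden_modules: set, cast_trainable_params_to_fp32: bool) -> None:
    # Parameters whose qualified name contains a forbidden substring (e.g. "vision_tower")
    # stay frozen; everything else is optionally upcast to fp32 for stable training.
    for name, param in model.named_parameters():
        if not any(forbidden_module in name for forbidden_module in forbidden_modules):
            if cast_trainable_params_to_fp32:
                param.data = param.data.to(torch.float32)
        else:
            param.requires_grad_(False)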
@@ -63,7 +57,6 @@ def _setup_full_tuning(

 def _setup_freeze_tuning(
     model: "PreTrainedModel",
-    model_args: "ModelArguments",
     finetuning_args: "FinetuningArguments",
     is_trainable: bool,
     cast_trainable_params_to_fp32: bool,
@@ -72,8 +65,8 @@ def _setup_freeze_tuning(
         return

     logger.info("Fine-tuning method: Freeze")
-    if model_args.visual_inputs:
-        config = model.config.text_config
+    if hasattr(model.config, "text_config"):  # composite models
+        config = getattr(model.config, "text_config")
     else:
         config = model.config

@@ -130,10 +123,7 @@ def _setup_freeze_tuning(

         trainable_layers.append(module_name)

-    forbidden_modules = set()
-    if model_args.visual_inputs and finetuning_args.freeze_vision_tower:
-        forbidden_modules.add("vision_tower")
-
+    forbidden_modules = get_forbidden_modules(model.config, finetuning_args)
     for name, param in model.named_parameters():
         if any(trainable_layer in name for trainable_layer in trainable_layers) and not any(
             forbidden_module in name for forbidden_module in forbidden_modules
@@ -211,8 +201,7 @@ def _setup_lora_tuning(
         if finetuning_args.use_llama_pro:
             target_modules = find_expanded_modules(model, target_modules, finetuning_args.freeze_trainable_layers)

-        if model_args.visual_inputs and finetuning_args.freeze_vision_tower:
-            target_modules = "^(?!.*(?:vision_tower|visual)).*(?:{}).*".format("|".join(target_modules))
+        target_modules = patch_target_modules(model.config, finetuning_args, target_modules)

         if (
             finetuning_args.use_dora
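For LoRA tuning the helper now returns a single regex string instead of a module list when the vision tower is frozen. A sketch of what that pattern matches for a qwen2_vl-style model; the q_proj/v_proj targets are hypothetical:

import re

target_modules = ["q_proj", "v_proj"]  # hypothetical LoRA targets
pattern = "^(?!.*visual).*(?:{}).*".format("|".join(target_modules))

# Language-model projections still match, anything under the vision tower does not.
assert re.match(pattern, "model.layers.0.self_attn.q_proj") is not None
assert re.match(pattern, "visual.blocks.0.attn.q_proj") is None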
@@ -303,9 +292,9 @@ def init_adapter(
         cast_trainable_params_to_fp32 = True

     if finetuning_args.finetuning_type == "full":
-        _setup_full_tuning(model, model_args, finetuning_args, is_trainable, cast_trainable_params_to_fp32)
+        _setup_full_tuning(model, finetuning_args, is_trainable, cast_trainable_params_to_fp32)
     elif finetuning_args.finetuning_type == "freeze":
-        _setup_freeze_tuning(model, model_args, finetuning_args, is_trainable, cast_trainable_params_to_fp32)
+        _setup_freeze_tuning(model, finetuning_args, is_trainable, cast_trainable_params_to_fp32)
     elif finetuning_args.finetuning_type == "lora":
         model = _setup_lora_tuning(
             config, model, model_args, finetuning_args, is_trainable, cast_trainable_params_to_fp32
@@ -93,17 +93,10 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":

     patch_tokenizer(tokenizer)

-    if model_args.visual_inputs:
-        try:
-            processor = AutoProcessor.from_pretrained(model_args.model_name_or_path, **init_kwargs)
-            setattr(processor, "tokenizer", tokenizer)
-        except Exception:
-            raise ValueError(
-                "This multimodal LLM is not supported.\n"
-                "Download LLaVA-1.5 models from: https://huggingface.co/llava-hf\n"
-                "Download Yi-VL models from: https://huggingface.co/BUAADreamer"
-            )
-    else:
+    try:
+        processor = AutoProcessor.from_pretrained(model_args.model_name_or_path, **init_kwargs)
+        setattr(processor, "tokenizer", tokenizer)
+    except Exception:
         processor = None

     return {"tokenizer": tokenizer, "processor": processor}
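With the visual_inputs branch removed, the processor is loaded opportunistically for every checkpoint. A condensed sketch of the fallback behaviour; load_processor_or_none and its argument handling are illustrative, not part of the codebase:

from transformers import AutoProcessor, AutoTokenizer

def load_processor_or_none(model_name_or_path: str) -> dict:
    # Text-only models that ship no processor fall back to None instead of raising.
    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
    try:
        processor = AutoProcessor.from_pretrained(model_name_or_path)
        setattr(processor, "tokenizer", tokenizer)
    except Exception:
        processor = None

    return {"tokenizer": tokenizer, "processor": processor}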
@@ -145,12 +138,16 @@ def load_model(

         if model_args.mixture_of_depths == "load":
             model = load_mod_pretrained_model(**init_kwargs)
-        elif model_args.visual_inputs:
-            model = AutoModelForVision2Seq.from_pretrained(**init_kwargs)
-        elif model_args.train_from_scratch:
-            model = AutoModelForCausalLM.from_config(config)
         else:
-            model = AutoModelForCausalLM.from_pretrained(**init_kwargs)
+            if type(config) in AutoModelForVision2Seq._model_mapping.keys():  # assume built-in models
+                load_class = AutoModelForVision2Seq
+            else:
+                load_class = AutoModelForCausalLM
+
+            if model_args.train_from_scratch:
+                model = load_class.from_config(config)
+            else:
+                model = load_class.from_pretrained(**init_kwargs)

         if model_args.mixture_of_depths == "convert":
             model = convert_pretrained_model_to_mod(model, config, model_args)
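The loader class is now chosen from the config type instead of model_args.visual_inputs. A sketch under the assumption that the checkpoint is a built-in Transformers vision-to-text architecture; the model name is only an example:

from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForVision2Seq

config = AutoConfig.from_pretrained("llava-hf/llava-1.5-7b-hf")  # example checkpoint
if type(config) in AutoModelForVision2Seq._model_mapping.keys():  # built-in VLMs (llava, paligemma, qwen2_vl, ...)
    load_class = AutoModelForVision2Seq
else:
    load_class = AutoModelForCausalLM

model = load_class.from_pretrained("llava-hf/llava-1.5-7b-hf", config=config)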
@@ -15,7 +15,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import TYPE_CHECKING, Tuple
+from typing import TYPE_CHECKING, List, Sequence, Set, Tuple, Union

 import torch
 import transformers.models
@@ -28,7 +28,7 @@ from ...extras.logging import get_logger
 if TYPE_CHECKING:
     from transformers import LlavaConfig, PretrainedConfig, PreTrainedModel

-    from ...hparams import ModelArguments
+    from ...hparams import FinetuningArguments, ModelArguments


 logger = get_logger(__name__)
@@ -80,24 +80,74 @@ class LlavaMultiModalProjectorForYiVLForVLLM(LlavaMultiModalProjectorForYiVL):
         self.act = ACT2FN[projector_hidden_act]


-def autocast_projector_dtype(
-    model: "PreTrainedModel", model_args: "ModelArguments", mm_projector_name: str = "multi_modal_projector"
-) -> None:
+def autocast_projector_dtype(model: "PreTrainedModel", model_args: "ModelArguments") -> None:
+    r"""
+    Casts projector output to half precision for quantized VLMs.
+    """
+
     def _mm_projector_forward_post_hook(
         module: "torch.nn.Module", args: Tuple["torch.Tensor"], output: "torch.Tensor"
     ) -> "torch.Tensor":
         return output.to(model_args.compute_dtype)

-    if hasattr(model, mm_projector_name) and getattr(model, "quantization_method", None):
+    if getattr(model, "quantization_method", None):
+        if getattr(model.config, "model_type", None) in ["llava", "paligemma"]:
+            mm_projector: "torch.nn.Module" = getattr(model, "multi_modal_projector")
+        elif getattr(model.config, "model_type", None) == "qwen2_vl":
+            mm_projector: "torch.nn.Module" = getattr(getattr(model, "visual"), "merger")
+        else:
+            return
+
         logger.info("Casting multimodal projector outputs in {}.".format(model_args.compute_dtype))
-        mm_projector: "torch.nn.Module" = getattr(model, mm_projector_name)
         mm_projector.register_forward_hook(_mm_projector_forward_post_hook)


 def configure_visual_model(config: "PretrainedConfig") -> None:
+    r"""
+    Patches VLMs before loading them.
+    """
     if getattr(config, "model_type", None) == "llava":  # required for ds zero3 and valuehead models
         setattr(config, "hidden_size", getattr(config.text_config, "hidden_size", None))

     if getattr(config, "is_yi_vl_derived_model", None):
         logger.info("Detected Yi-VL model, applying projector patch.")
         transformers.models.llava.modeling_llava.LlavaMultiModalProjector = LlavaMultiModalProjectorForYiVL
+
+
+def get_forbidden_modules(config: "PretrainedConfig", finetuning_args: "FinetuningArguments") -> Set[str]:
+    r"""
+    Freezes vision tower and language model for VLM full/freeze tuning.
+    """
+    forbidden_modules = set()
+    if getattr(config, "model_type", None) in ["llava", "paligemma"]:
+        if finetuning_args.freeze_vision_tower:
+            forbidden_modules.add("vision_tower")
+
+        if finetuning_args.train_mm_proj_only:
+            forbidden_modules.add("language_model")
+
+    elif getattr(config, "model_type", None) == "qwen2_vl":
+        if finetuning_args.freeze_vision_tower:
+            forbidden_modules.add("visual")
+
+        if finetuning_args.train_mm_proj_only:
+            raise ValueError("Qwen2-VL models do not support `train_mm_proj_only`.")
+
+    return forbidden_modules
+
+
+def patch_target_modules(
+    config: "PretrainedConfig", finetuning_args: "FinetuningArguments", target_modules: Sequence[str]
+) -> Union[str, List[str]]:
+    r"""
+    Freezes vision tower for VLM LoRA tuning.
+    """
+    if not finetuning_args.freeze_vision_tower:
+        return target_modules
+
+    if getattr(config, "model_type", None) in ["llava", "paligemma"]:
+        return "^(?!.*vision_tower).*(?:{}).*".format("|".join(target_modules))
+    elif getattr(config, "model_type", None) == "qwen2_vl":
+        return "^(?!.*visual).*(?:{}).*".format("|".join(target_modules))
+    else:
+        return target_modules
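A usage sketch of the two new helpers, with SimpleNamespace objects standing in for the real PretrainedConfig and FinetuningArguments; only the attributes the helpers actually read are provided, and the import path follows the package layout implied by the imports above:

from types import SimpleNamespace

from llamafactory.model.model_utils.visual import get_forbidden_modules, patch_target_modules

config = SimpleNamespace(model_type="qwen2_vl")
finetuning_args = SimpleNamespace(freeze_vision_tower=True, train_mm_proj_only=False)

forbidden = get_forbidden_modules(config, finetuning_args)  # {"visual"}
lora_target = patch_target_modules(config, finetuning_args, ["q_proj", "v_proj"])
# lora_target == "^(?!.*visual).*(?:q_proj|v_proj).*"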
@@ -131,11 +131,9 @@ def patch_model(
     if model_args.resize_vocab:
         resize_embedding_layer(model, tokenizer)

-    if model_args.visual_inputs:
-        autocast_projector_dtype(model, model_args)
-
     if is_trainable:
         prepare_model_for_training(model, model_args)
+        autocast_projector_dtype(model, model_args)
         add_z3_leaf_module(model)

     if not model_args.use_unsloth:
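autocast_projector_dtype is now called only from the is_trainable branch. The mechanism it relies on is a plain forward post-hook, sketched here on a stand-in linear projector:

import torch

projector = torch.nn.Linear(16, 16)  # stand-in for an fp32 multimodal projector

def _cast_output(module: torch.nn.Module, args, output: torch.Tensor) -> torch.Tensor:
    # Mirrors _mm_projector_forward_post_hook: only the output dtype is rewritten.
    return output.to(torch.float16)

projector.register_forward_hook(_cast_output)
assert projector(torch.randn(2, 16)).dtype == torch.float16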