Former-commit-id: 35bc71b2a68fd303798c35fe22ad29ceea87cf9b
This commit is contained in:
Kingsley
2024-09-28 22:50:53 +08:00
parent e4c57f54f8
commit bddb2646bd
5 changed files with 21 additions and 19 deletions

View File

@@ -96,6 +96,9 @@ def autocast_projector_dtype(model: "PreTrainedModel", model_args: "ModelArgumen
mm_projector: "torch.nn.Module" = getattr(model, "multi_modal_projector")
elif model_type == "qwen2_vl":
mm_projector: "torch.nn.Module" = getattr(getattr(model, "visual"), "merger")
# TODO check it
elif model_type == "pixtral":
mm_projector: "torch.nn.Module" = getattr(model, "vision_language_adapte")
else:
return
@@ -122,9 +125,11 @@ def get_forbidden_modules(config: "PretrainedConfig", finetuning_args: "Finetuni
"""
model_type = getattr(config, "model_type", None)
forbidden_modules = set()
if model_type in ["llava", "paligemma"]:
if model_type in ["llava", "paligemma", "pixtral"]:
if finetuning_args.freeze_vision_tower:
forbidden_modules.add("vision_tower")
#TODO check it
forbidden_modules.add("vision_encoder")
if finetuning_args.train_mm_proj_only:
forbidden_modules.add("language_model")
@@ -150,7 +155,7 @@ def get_image_seqlen(config: "PretrainedConfig") -> int:
image_seqlen += 1
elif model_type == "paligemma":
image_seqlen = config.vision_config.num_image_tokens
elif model_type == "qwen2_vl": # variable length
elif model_type in ["qwen2_vl", "pixtral"]: # variable length
image_seqlen = -1
return image_seqlen
@@ -168,10 +173,14 @@ def patch_target_modules(
return "^(?!.*vision_tower).*(?:{}).*".format("|".join(target_modules))
elif model_type == "qwen2_vl":
return "^(?!.*visual).*(?:{}).*".format("|".join(target_modules))
elif model_type == "pixtral":
return "^(?!.*vision_encoder).*(?:{}).*".format("|".join(target_modules))
else:
return target_modules
else:
if model_type == "qwen2_vl":
return "^(?!.*patch_embed).*(?:{}).*".format("|".join(target_modules))
elif model_type == "pixtral":
return "^(?!.*patch_conv).*(?:{}).*".format("|".join(target_modules))
else:
return target_modules