add LLaVA and InstructBLIP

Former-commit-id: 142fb6f4541a1acfefe66ff2574dabde53b00c06
Author: BUAADreamer
Date:   2024-04-25 00:22:43 +08:00
parent 1451297c78
commit 12c51655ce
16 changed files with 273 additions and 214 deletions


@@ -182,7 +182,8 @@ def init_adapter(
 def init_mm_adapter(
     model: "AutoModelForVision2Seq", model_args: "ModelArguments",
     finetuning_args: "FinetuningArguments",
-    is_trainable: bool
+    is_trainable: bool,
+    use_clm: bool = True,
 ) -> "AutoModelForVision2Seq":
     if finetuning_args.finetuning_type == "lora":
         logger.info("Fine-tuning method: {}".format("DoRA" if finetuning_args.use_dora else "LoRA"))
@@ -253,12 +254,19 @@ def init_mm_adapter(
             }
             model = FastLanguageModel.get_peft_model(**peft_kwargs, **unsloth_peft_kwargs)
         else:
-            lora_config = LoraConfig(
-                # task_type=TaskType.CAUSAL_LM,
-                inference_mode=False,
-                use_dora=finetuning_args.use_dora,
-                **peft_kwargs,
-            )
+            if use_clm:
+                lora_config = LoraConfig(
+                    task_type=TaskType.CAUSAL_LM,
+                    inference_mode=False,
+                    use_dora=finetuning_args.use_dora,
+                    **peft_kwargs,
+                )
+            else:
+                lora_config = LoraConfig(
+                    inference_mode=False,
+                    use_dora=finetuning_args.use_dora,
+                    **peft_kwargs,
+                )
             model = get_peft_model(model, lora_config)

         if (not finetuning_args.pure_bf16) and (not finetuning_args.use_badam):

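The two branches above differ only in whether `task_type=TaskType.CAUSAL_LM` is passed to peft. A minimal standalone sketch of that switch, using only the public peft API; the `peft_kwargs` values here are placeholders, not the ones the project derives from `finetuning_args`:

```python
# Sketch of the use_clm switch, assuming peft>=0.9 (required for use_dora).
# The peft_kwargs values below are placeholders, not the project's defaults.
from peft import LoraConfig, TaskType

peft_kwargs = {"r": 8, "lora_alpha": 16, "target_modules": ["q_proj", "v_proj"]}


def build_lora_config(use_clm: bool = True, use_dora: bool = False) -> LoraConfig:
    if use_clm:
        # Backbone trains as a causal LM (the LLaVA-style path): tag the task
        # type so peft applies its causal-LM-specific wrapping.
        return LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            inference_mode=False,
            use_dora=use_dora,
            **peft_kwargs,
        )
    # Leave task_type unset for Vision2Seq models where CAUSAL_LM does not fit
    # (presumably the InstructBLIP path this commit adds).
    return LoraConfig(
        inference_mode=False,
        use_dora=use_dora,
        **peft_kwargs,
    )
```

Either config is then applied with `get_peft_model(model, lora_config)`, exactly as in the unchanged line below the branch.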

@@ -191,6 +191,7 @@ def load_mm_model(
     finetuning_args: "FinetuningArguments",
     is_trainable: bool = False,
     add_valuehead: bool = False,
+    use_clm: bool = True,
 ) -> "AutoModelForVision2Seq":
     r"""
     Loads pretrained model. Must after load_tokenizer.
@@ -231,7 +232,7 @@ def load_mm_model(
         patch_model(model, tokenizer, model_args, is_trainable)
         register_autoclass(config, model, tokenizer)

-    model = init_mm_adapter(model, model_args, finetuning_args, is_trainable)
+    model = init_mm_adapter(model, model_args, finetuning_args, is_trainable, use_clm)

     if not is_trainable:
         model.requires_grad_(False)
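Taken together, the new flag simply threads from `load_mm_model` into `init_mm_adapter`. A standalone sketch of that threading; the stubs below drop the other parameters (tokenizer, model_args, config, ...) and the per-model flag choice is an assumption, not something stated in the diff:

```python
# Standalone sketch of how use_clm threads through; the real functions take
# more arguments, and the caller-side choice below is an assumption.
from typing import Any


def init_mm_adapter(model: Any, is_trainable: bool, use_clm: bool = True) -> Any:
    # In the diff above, use_clm decides whether LoraConfig gets TaskType.CAUSAL_LM.
    return model


def load_mm_model(model: Any, is_trainable: bool = False, use_clm: bool = True) -> Any:
    model = init_mm_adapter(model, is_trainable, use_clm)  # new pass-through argument
    return model


# Hypothetical caller: keep the causal-LM default for LLaVA-style backbones and
# turn it off where that task type does not apply.
model = load_mm_model(model=object(), is_trainable=True, use_clm=False)
```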