support loftq
Former-commit-id: e7ac2eb7f7daae17525a278ffbe2f82c0fbd8093
This commit is contained in:
@@ -91,6 +91,16 @@ def init_adapter(
|
||||
else:
|
||||
target_modules = finetuning_args.lora_target
|
||||
|
||||
config_kwargs = {}
|
||||
if model_args.quantization_bit is not None and finetuning_args.loftq_init:
|
||||
if model_args.quantization_bit != 4:
|
||||
raise ValueError("LoftQ initialization only support 4-bit quantized training.")
|
||||
|
||||
from peft import LoftQConfig # type: ignore
|
||||
loftq_config = LoftQConfig(loftq_bits=4)
|
||||
config_kwargs["init_lora_weights"] = "loftq"
|
||||
config_kwargs["loftq_config"] = loftq_config
|
||||
|
||||
lora_config = LoraConfig(
|
||||
task_type=TaskType.CAUSAL_LM,
|
||||
inference_mode=False,
|
||||
@@ -98,7 +108,8 @@ def init_adapter(
|
||||
lora_alpha=finetuning_args.lora_alpha,
|
||||
lora_dropout=finetuning_args.lora_dropout,
|
||||
target_modules=target_modules,
|
||||
modules_to_save=finetuning_args.additional_target
|
||||
modules_to_save=finetuning_args.additional_target,
|
||||
**config_kwargs
|
||||
)
|
||||
model = get_peft_model(model, lora_config)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user