support pissa

Former-commit-id: ef8e45f2eaf466c54e9a671512a2974575677b08
This commit is contained in:
hiyouga
2024-06-16 01:08:12 +08:00
parent 05f3a3c944
commit 32f45c9e91
19 changed files with 406 additions and 76 deletions

View File

@@ -179,8 +179,16 @@ def _setup_lora_tuning(
else:
adapter_to_merge = model_args.adapter_name_or_path
init_kwargs = {
"subfolder": model_args.adapter_folder,
"offload_folder": model_args.offload_folder,
"cache_dir": model_args.cache_dir,
"revision": model_args.model_revision,
"token": model_args.hf_hub_token,
}
for adapter in adapter_to_merge:
model: "LoraModel" = PeftModel.from_pretrained(model, adapter, offload_folder=model_args.offload_folder)
model: "LoraModel" = PeftModel.from_pretrained(model, adapter, **init_kwargs)
model = model.merge_and_unload()
if len(adapter_to_merge) > 0:
@@ -190,12 +198,7 @@ def _setup_lora_tuning(
if model_args.use_unsloth:
model = load_unsloth_peft_model(config, model_args, is_trainable=is_trainable)
else:
model = PeftModel.from_pretrained(
model,
adapter_to_resume,
is_trainable=is_trainable,
offload_folder=model_args.offload_folder,
)
model = PeftModel.from_pretrained(model, adapter_to_resume, is_trainable=is_trainable, **init_kwargs)
logger.info("Loaded adapter(s): {}".format(",".join(model_args.adapter_name_or_path)))
@@ -242,6 +245,14 @@ def _setup_lora_tuning(
if model_args.use_unsloth:
model = get_unsloth_peft_model(model, model_args, peft_kwargs)
else:
if finetuning_args.pissa_init:
if finetuning_args.pissa_iter == -1:
logger.info("Using PiSSA initialization.")
peft_kwargs["init_lora_weights"] = "pissa"
else:
logger.info("Using PiSSA initialization with FSVD steps {}.".format(finetuning_args.pissa_iter))
peft_kwargs["init_lora_weights"] = "pissa_niter_{}".format(finetuning_args.pissa_iter)
lora_config = LoraConfig(
task_type=TaskType.CAUSAL_LM,
inference_mode=False,