[feature] adding orthogonal finetuning (OFT) to LLaMA Factory (#8623)

Co-authored-by: Zeju <zqiu@g003.internal.cluster.is.localnet>
Co-authored-by: Zeju <zqiu@login2.is.localnet>
Co-authored-by: Yaowei Zheng <hiyouga@buaa.edu.cn>
Author: Zeju Qiu
Date: 2025-08-18 12:22:47 +02:00
Committed by: GitHub
Parent: 1ada15981a
Commit: 003a2acb1a
13 changed files with 375 additions and 47 deletions


@@ -16,10 +16,11 @@ import re
 from typing import TYPE_CHECKING

 import torch
-from peft import LoraConfig, LoraModel, PeftModel, TaskType, get_peft_model
+from peft import LoraConfig, LoraModel, OFTConfig, OFTModel, PeftModel, TaskType, get_peft_model
 from transformers.integrations import is_deepspeed_zero3_enabled

 from ..extras import logging
 from ..extras.misc import check_version
 from .model_utils.misc import find_all_linear_modules, find_expanded_modules
 from .model_utils.quantization import QuantizationMethod
 from .model_utils.unsloth import get_unsloth_peft_model, load_unsloth_peft_model
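
The patch leans on peft's built-in OFT implementation rather than shipping its own: OFTConfig and OFTModel are the OFT analogues of LoraConfig and LoraModel. As a quick, self-contained sketch of what these imports provide (not part of the patch; assumes a recent peft release in which OFTConfig accepts oft_block_size and requires exactly one of r and oft_block_size to be non-zero):

    import torch.nn as nn
    from peft import OFTConfig, get_peft_model

    # Toy stand-in for a transformer block; OFT wraps nn.Linear layers with
    # block-diagonal orthogonal rotations instead of LoRA's low-rank updates.
    base = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 8))
    config = OFTConfig(
        r=0,                   # leave the block count unset...
        oft_block_size=32,     # ...and pick a block size that divides 64
        target_modules=["0"],  # the first Linear in the Sequential
        module_dropout=0.0,
    )
    peft_model = get_peft_model(base, config)
    peft_model.print_trainable_parameters()  # only the OFT parameters are trainable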
@@ -147,7 +148,10 @@ def _setup_lora_tuning(
     cast_trainable_params_to_fp32: bool,
 ) -> "PeftModel":
     if is_trainable:
-        logger.info_rank0("Fine-tuning method: {}".format("DoRA" if finetuning_args.use_dora else "LoRA"))
+        if finetuning_args.finetuning_type == "oft":
+            logger.info_rank0("Fine-tuning method: OFT")
+        else:
+            logger.info_rank0("Fine-tuning method: {}".format("DoRA" if finetuning_args.use_dora else "LoRA"))

     adapter_to_resume = None

@@ -223,17 +227,29 @@ def _setup_lora_tuning(
             finetuning_args.additional_target = module_names
             logger.warning_rank0("Vocab has been resized, add {} to trainable params.".format(",".join(module_names)))

-        peft_kwargs = {
-            "r": finetuning_args.lora_rank,
-            "target_modules": target_modules,
-            "lora_alpha": finetuning_args.lora_alpha,
-            "lora_dropout": finetuning_args.lora_dropout,
-            "use_rslora": finetuning_args.use_rslora,
-            "use_dora": finetuning_args.use_dora,
-            "modules_to_save": finetuning_args.additional_target,
-        }
+        if finetuning_args.finetuning_type == "lora":
+            peft_kwargs = {
+                "r": finetuning_args.lora_rank,
+                "target_modules": target_modules,
+                "lora_alpha": finetuning_args.lora_alpha,
+                "lora_dropout": finetuning_args.lora_dropout,
+                "use_rslora": finetuning_args.use_rslora,
+                "use_dora": finetuning_args.use_dora,
+                "modules_to_save": finetuning_args.additional_target,
+            }
+        elif finetuning_args.finetuning_type == "oft":
+            peft_kwargs = {
+                "r": finetuning_args.oft_rank,
+                "oft_block_size": finetuning_args.oft_block_size,
+                "target_modules": target_modules,
+                "module_dropout": finetuning_args.module_dropout,
+                "modules_to_save": finetuning_args.additional_target,
+            }

         if model_args.use_unsloth:
+            if finetuning_args.finetuning_type == "oft":
+                raise ValueError("Unsloth is currently not supported for OFT.")
+
             model = get_unsloth_peft_model(model, model_args, peft_kwargs)
         else:
             if finetuning_args.pissa_init:
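
One subtlety the OFT branch inherits from peft (hedged, based on recent peft releases): r counts the orthogonal blocks per layer, oft_block_size fixes the width of each block, and since their product must equal the width of the targeted layer, exactly one of the two should be non-zero. The new oft_rank and oft_block_size arguments are therefore presumably alternatives rather than independent knobs. A sketch with illustrative values:

    # Two valid ways to parameterize OFT for a 64-wide layer (made-up values):
    by_block_size = {"r": 0, "oft_block_size": 32}  # 2 blocks of width 32
    by_rank = {"r": 8, "oft_block_size": 0}         # 8 blocks of width 8
    # Recent peft raises a ValueError when both are non-zero (or both zero).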
@@ -244,12 +260,19 @@ def _setup_lora_tuning(
                     logger.info_rank0(f"Using PiSSA initialization with FSVD steps {finetuning_args.pissa_iter}.")
                     peft_kwargs["init_lora_weights"] = f"pissa_niter_{finetuning_args.pissa_iter}"

-            lora_config = LoraConfig(
-                task_type=TaskType.CAUSAL_LM,
-                inference_mode=False,
-                **peft_kwargs,
-            )
-            model = get_peft_model(model, lora_config)
+            if finetuning_args.finetuning_type == "lora":
+                peft_config = LoraConfig(
+                    task_type=TaskType.CAUSAL_LM,
+                    inference_mode=False,
+                    **peft_kwargs,
+                )
+            elif finetuning_args.finetuning_type == "oft":
+                peft_config = OFTConfig(
+                    task_type=TaskType.CAUSAL_LM,
+                    inference_mode=False,
+                    **peft_kwargs,
+                )
+            model = get_peft_model(model, peft_config)

         if is_trainable and cast_trainable_params_to_fp32:
             for param in filter(lambda p: p.requires_grad, model.parameters()):
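
Outside of LLaMA Factory, the OFT path this hunk takes boils down to the following (a minimal sketch, assuming a LLaMA-style checkpoint whose attention projections are nn.Linear layers; "path/to/base-model" is a placeholder, not a real checkpoint):

    from peft import OFTConfig, TaskType, get_peft_model
    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained("path/to/base-model")  # placeholder
    oft_config = OFTConfig(
        task_type=TaskType.CAUSAL_LM,
        inference_mode=False,
        r=0,
        oft_block_size=32,  # must evenly divide the width of the targeted projections
        target_modules=["q_proj", "v_proj"],
        module_dropout=0.0,
    )
    model = get_peft_model(model, oft_config)
    model.print_trainable_parameters()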
@@ -272,8 +295,8 @@ def init_adapter(
     Note that the trainable parameters must be cast to float32.
     """
     if is_trainable and getattr(model, "quantization_method", None) is not None:
-        if finetuning_args.finetuning_type != "lora":
-            raise ValueError("Quantized models can only be used for the LoRA tuning.")
+        if finetuning_args.finetuning_type not in ["lora", "oft"]:
+            raise ValueError("Quantized models can only be used for the LoRA or OFT tuning.")

         if finetuning_args.pissa_init:
             raise ValueError("Cannot initialize PiSSA adapter on quantized models.")
@@ -296,7 +319,7 @@ def init_adapter(
         _setup_full_tuning(model, finetuning_args, is_trainable, cast_trainable_params_to_fp32)
     elif finetuning_args.finetuning_type == "freeze":
         _setup_freeze_tuning(model, finetuning_args, is_trainable, cast_trainable_params_to_fp32)
-    elif finetuning_args.finetuning_type == "lora":
+    elif finetuning_args.finetuning_type in ["lora", "oft"]:
         model = _setup_lora_tuning(
             config, model, model_args, finetuning_args, is_trainable, cast_trainable_params_to_fp32
         )
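
From the user's side, OFT thus rides the existing LoRA code path: the dispatcher only adds "oft" as a finetuning_type, and everything else is handled inside _setup_lora_tuning. A hypothetical sketch of the new FinetuningArguments fields the diff references (field names come from the patch; the values here are made up for illustration):

    # Hypothetical example values for the new OFT arguments referenced above.
    oft_args = dict(
        finetuning_type="oft",
        oft_rank=0,          # number of orthogonal blocks; leave 0 to use block size
        oft_block_size=32,   # width of each orthogonal block
        module_dropout=0.0,  # dropout over whole OFT modules (cf. lora_dropout)
    )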