Support LoRA fine-tuning for LLaMA Pro (block-expanded models)

Former-commit-id: f74c78ba95f0545aae89e603e466f494705ad024
This commit is contained in:
hiyouga
2024-02-21 02:17:22 +08:00
parent a3f30038a0
commit bc16c9a54a
7 changed files with 119 additions and 28 deletions

View File

@@ -5,7 +5,7 @@ from peft import LoraConfig, LoraModel, PeftModel, TaskType, get_peft_model
from transformers.integrations import is_deepspeed_zero3_enabled
from ..extras.logging import get_logger
from .utils import find_all_linear_modules
from .utils import find_all_linear_modules, find_expanded_modules
if TYPE_CHECKING:
@@ -82,6 +82,8 @@ def init_adapter(
else:
param.requires_grad_(False)
logger.info("Set trainable layers: {}".format(",".join(map(str, trainable_layer_ids))))
if finetuning_args.finetuning_type == "lora":
logger.info("Fine-tuning method: LoRA")
adapter_to_resume = None
@@ -118,6 +120,9 @@ def init_adapter(
else:
target_modules = finetuning_args.lora_target
if finetuning_args.use_llama_pro:
target_modules = find_expanded_modules(model, target_modules, finetuning_args.num_layer_trainable)
peft_kwargs = {
"r": finetuning_args.lora_rank,
"target_modules": target_modules,