improve lora+ impl.

Former-commit-id: 332bad25455a70ad9204e7dd384bb086d789aa39
Author: hiyouga
Date: 2024-03-13 23:32:51 +08:00
parent 73f4513c84
commit 46f99ff277
12 changed files with 165 additions and 169 deletions


@@ -57,7 +57,7 @@ class LoraArguments:
         metadata={
             "help": """Name(s) of target modules to apply LoRA. \
                     Use commas to separate multiple modules. \
-                    Use "all" to specify all the available modules. \
+                    Use "all" to specify all the linear modules. \
                     LLaMA choices: ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"], \
                     BLOOM & Falcon & ChatGLM choices: ["query_key_value", "dense", "dense_h_to_4h", "dense_4h_to_h"], \
                     Baichuan choices: ["W_pack", "o_proj", "gate_proj", "up_proj", "down_proj"], \
@@ -66,6 +66,14 @@ class LoraArguments:
                     Others choices: the same as LLaMA."""
         },
     )
+    loraplus_lr_ratio: Optional[float] = field(
+        default=None,
+        metadata={"help": "LoRA plus learning rate ratio (lr_B / lr_A)."},
+    )
+    loraplus_lr_embedding: float = field(
+        default=1e-6,
+        metadata={"help": "LoRA plus learning rate for lora embedding layers."},
+    )
     use_rslora: bool = field(
         default=False,
         metadata={"help": "Whether or not to use the rank stabilization scaling factor for LoRA layer."},
@@ -163,8 +171,11 @@ class GaloreArguments:
metadata={"help": "Whether or not to use gradient low-Rank projection."},
)
galore_target: str = field(
default="mlp,attn",
metadata={"help": "Name(s) of modules to apply GaLore. Use commas to separate multiple modules."},
default="all",
metadata={
"help": """Name(s) of modules to apply GaLore. Use commas to separate multiple modules. \
Use "all" to specify all the linear modules."""
},
)
galore_rank: int = field(
default=16,
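
Both `lora_target` and `galore_target` now accept "all" to mean every linear module. A hypothetical sketch of how such a default could be resolved against a model follows; the function name and the exclusion of the output head are illustrative assumptions.

```python
from typing import List

from torch import nn


def find_all_linear_modules(model: nn.Module) -> List[str]:
    """Collect the short names of every nn.Linear submodule, e.g. "q_proj"."""
    names = set()
    for full_name, module in model.named_modules():
        if isinstance(module, nn.Linear) and "lm_head" not in full_name:
            names.add(full_name.split(".")[-1])  # "model.layers.0.self_attn.q_proj" -> "q_proj"
    return list(names)
```
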
@@ -210,11 +221,6 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA
         default=False,
         metadata={"help": "Whether or not to make only the parameters in the expanded blocks trainable."},
     )
-    # for lora+,[LoRA+: Efficient Low Rank Adaptation of Large Models](https://arxiv.org/pdf/2402.12354.pdf)
-    lora_lr_ratio: Optional[float] = field(
-        default=None,
-        metadata={'help': 'The lora learning_rate ratio of lora_A to lora_B, option:16.0.'},
-    )
     plot_loss: bool = field(
         default=False,
         metadata={"help": "Whether or not to save the training loss curves."},
@@ -230,6 +236,7 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA
         self.lora_alpha = self.lora_alpha or self.lora_rank * 2
         self.lora_target = split_arg(self.lora_target)
         self.additional_target = split_arg(self.additional_target)
+        self.galore_target = split_arg(self.galore_target)
         assert self.finetuning_type in ["lora", "freeze", "full"], "Invalid fine-tuning method."
         assert self.ref_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
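
`split_arg` is presumably the comma-splitting helper used throughout this dataclass; a minimal sketch consistent with how it is called above:

```python
from typing import List, Optional, Union


def split_arg(arg: Optional[str]) -> Optional[Union[List[str], str]]:
    """Turn "q_proj,v_proj" into ["q_proj", "v_proj"]; pass non-strings through unchanged."""
    if isinstance(arg, str):
        return [item.strip() for item in arg.split(",")]
    return arg
```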