improve lora+ impl.
Former-commit-id: 332bad25455a70ad9204e7dd384bb086d789aa39
@@ -57,7 +57,7 @@ class LoraArguments:
         metadata={
             "help": """Name(s) of target modules to apply LoRA. \
                     Use commas to separate multiple modules. \
-                    Use "all" to specify all the available modules. \
+                    Use "all" to specify all the linear modules. \
                     LLaMA choices: ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"], \
                     BLOOM & Falcon & ChatGLM choices: ["query_key_value", "dense", "dense_h_to_4h", "dense_4h_to_h"], \
                     Baichuan choices: ["W_pack", "o_proj", "gate_proj", "up_proj", "down_proj"], \
@@ -66,6 +66,14 @@ class LoraArguments:
                     Others choices: the same as LLaMA."""
         },
     )
+    loraplus_lr_ratio: Optional[float] = field(
+        default=None,
+        metadata={"help": "LoRA plus learning rate ratio (lr_B / lr_A)."},
+    )
+    loraplus_lr_embedding: float = field(
+        default=1e-6,
+        metadata={"help": "LoRA plus learning rate for lora embedding layers."},
+    )
     use_rslora: bool = field(
         default=False,
         metadata={"help": "Whether or not to use the rank stabilization scaling factor for LoRA layer."},
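The two new loraplus_* fields above only declare hyperparameters; the optimizer wiring is outside this hunk. A minimal sketch, assuming PEFT-style parameter names ("lora_A", "lora_B", "lora_embedding") and a hypothetical helper name, of how loraplus_lr_ratio and loraplus_lr_embedding typically map to AdamW parameter groups:

import torch


def create_loraplus_param_groups(
    model: torch.nn.Module,
    lr: float,
    loraplus_lr_ratio: float,
    loraplus_lr_embedding: float = 1e-6,
) -> list:
    # Route trainable parameters into groups: lora_B weights train at
    # lr * loraplus_lr_ratio, lora_A weights at the base lr, and lora
    # embedding layers at the dedicated loraplus_lr_embedding rate.
    lora_a, lora_b, embedding, other = [], [], [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        if "lora_embedding" in name:
            embedding.append(param)
        elif "lora_B" in name:
            lora_b.append(param)
        elif "lora_A" in name:
            lora_a.append(param)
        else:
            other.append(param)
    groups = [
        {"params": lora_a, "lr": lr},
        {"params": lora_b, "lr": lr * loraplus_lr_ratio},
        {"params": embedding, "lr": loraplus_lr_embedding},
        {"params": other, "lr": lr},
    ]
    # Drop empty groups so the optimizer only sees populated ones.
    return [group for group in groups if group["params"]]


# Example usage (hypothetical): optimizer = torch.optim.AdamW(create_loraplus_param_groups(model, 1e-4, 16.0))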
@@ -163,8 +171,11 @@ class GaloreArguments:
         metadata={"help": "Whether or not to use gradient low-Rank projection."},
     )
     galore_target: str = field(
-        default="mlp,attn",
-        metadata={"help": "Name(s) of modules to apply GaLore. Use commas to separate multiple modules."},
+        default="all",
+        metadata={
+            "help": """Name(s) of modules to apply GaLore. Use commas to separate multiple modules. \
+                    Use "all" to specify all the linear modules."""
+        },
     )
     galore_rank: int = field(
         default=16,
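Both help texts now say that "all" targets every linear module. A sketch of how such a keyword could be resolved (hypothetical helper, not part of the file shown in this diff; a real implementation may also exclude output layers such as lm_head):

import torch


def find_all_linear_modules(model: torch.nn.Module) -> list:
    # Collect the leaf names of every nn.Linear layer so they can be
    # passed as LoRA / GaLore target modules when the user specifies "all".
    module_names = set()
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Linear):
            module_names.add(name.split(".")[-1])
    return list(module_names)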
@@ -210,11 +221,6 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA
         default=False,
         metadata={"help": "Whether or not to make only the parameters in the expanded blocks trainable."},
     )
-    # for lora+,[LoRA+: Efficient Low Rank Adaptation of Large Models](https://arxiv.org/pdf/2402.12354.pdf)
-    lora_lr_ratio: Optional[float] = field(
-        default=None,
-        metadata={'help': 'The lora learning_rate ratio of lora_A to lora_B, option:16.0.'},
-    )
     plot_loss: bool = field(
         default=False,
         metadata={"help": "Whether or not to save the training loss curves."},
@@ -230,6 +236,7 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA
         self.lora_alpha = self.lora_alpha or self.lora_rank * 2
         self.lora_target = split_arg(self.lora_target)
         self.additional_target = split_arg(self.additional_target)
+        self.galore_target = split_arg(self.galore_target)
 
         assert self.finetuning_type in ["lora", "freeze", "full"], "Invalid fine-tuning method."
         assert self.ref_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
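split_arg is applied above to lora_target, additional_target, and now galore_target, but its definition lies outside this diff. A plausible minimal version, consistent with that usage and the comma-separated help texts:

from typing import List, Optional


def split_arg(arg: Optional[str]) -> Optional[List[str]]:
    # Turn a comma-separated string such as "q_proj,v_proj" into a list of
    # module names; pass None through unchanged.
    if isinstance(arg, str):
        return [item.strip() for item in arg.split(",")]
    return arg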