refactor finetuning Args
Former-commit-id: be425a70a4c8f051717cf1e4464dbd79dae4c0b5
This commit is contained in:
@@ -12,18 +12,6 @@ class FinetuningArguments:
|
||||
default="lora",
|
||||
metadata={"help": "Which fine-tuning method to use."}
|
||||
)
|
||||
num_hidden_layers: Optional[int] = field(
|
||||
default=32,
|
||||
metadata={"help": "Number of decoder blocks in the model for partial-parameter (freeze) fine-tuning. \
|
||||
LLaMA choices: [\"32\", \"40\", \"60\", \"80\"], \
|
||||
LLaMA-2 choices: [\"32\", \"40\", \"80\"], \
|
||||
BLOOM choices: [\"24\", \"30\", \"70\"], \
|
||||
Falcon choices: [\"32\", \"60\"], \
|
||||
Baichuan choices: [\"32\", \"40\"] \
|
||||
Qwen choices: [\"32\"], \
|
||||
XVERSE choices: [\"40\"], \
|
||||
ChatGLM2 choices: [\"28\"]"}
|
||||
)
|
||||
num_layer_trainable: Optional[int] = field(
|
||||
default=3,
|
||||
metadata={"help": "Number of trainable layers for partial-parameter (freeze) fine-tuning."}
|
||||
@@ -33,9 +21,9 @@ class FinetuningArguments:
|
||||
metadata={"help": "Name of trainable modules for partial-parameter (freeze) fine-tuning. \
|
||||
LLaMA choices: [\"mlp\", \"self_attn\"], \
|
||||
BLOOM & Falcon & ChatGLM2 choices: [\"mlp\", \"self_attention\"], \
|
||||
Baichuan choices: [\"mlp\", \"self_attn\"], \
|
||||
Qwen choices: [\"mlp\", \"attn\"], \
|
||||
LLaMA-2, InternLM, XVERSE choices: the same as LLaMA."}
|
||||
Phi-1.5 choices: [\"mlp\", \"mixer\"], \
|
||||
LLaMA-2, Baichuan, InternLM, XVERSE choices: the same as LLaMA."}
|
||||
)
|
||||
lora_rank: Optional[int] = field(
|
||||
default=8,
|
||||
@@ -56,8 +44,13 @@ class FinetuningArguments:
|
||||
BLOOM & Falcon & ChatGLM2 choices: [\"query_key_value\", \"self_attention.dense\", \"mlp.dense\"], \
|
||||
Baichuan choices: [\"W_pack\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"], \
|
||||
Qwen choices: [\"c_attn\", \"attn.c_proj\", \"w1\", \"w2\", \"mlp.c_proj\"], \
|
||||
Phi-1.5 choices: [\"Wqkv\", \"out_proj\", \"fc1\", \"fc2\"], \
|
||||
LLaMA-2, InternLM, XVERSE choices: the same as LLaMA."}
|
||||
)
|
||||
additional_target: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Name(s) of modules apart from LoRA layers to be set as trainable and saved in the final checkpoint."}
|
||||
)
|
||||
resume_lora_training: Optional[bool] = field(
|
||||
default=True,
|
||||
metadata={"help": "Whether to resume training from the last LoRA weights or create new weights after merging them."}
|
||||
@@ -75,12 +68,8 @@ class FinetuningArguments:
|
||||
if isinstance(self.lora_target, str): # support custom target modules/layers of LoRA
|
||||
self.lora_target = [target.strip() for target in self.lora_target.split(",")]
|
||||
|
||||
if self.num_layer_trainable > 0: # fine-tuning the last n layers if num_layer_trainable > 0
|
||||
trainable_layer_ids = [self.num_hidden_layers - k - 1 for k in range(self.num_layer_trainable)]
|
||||
else: # fine-tuning the first n layers if num_layer_trainable < 0
|
||||
trainable_layer_ids = [k for k in range(-self.num_layer_trainable)]
|
||||
|
||||
self.trainable_layers = ["{:d}.{}".format(idx, self.name_module_trainable) for idx in trainable_layer_ids]
|
||||
if isinstance(self.additional_target, str):
|
||||
self.additional_target = [target.strip() for target in self.additional_target.split(",")]
|
||||
|
||||
assert self.finetuning_type in ["lora", "freeze", "full", "none"], "Invalid fine-tuning method."
|
||||
|
||||
|
||||
Reference in New Issue
Block a user