support llama pro #2338, add rslora

Former-commit-id: 40d659b7f30dd5a004703c176ec1f22dc864e505
Author: hiyouga
Date: 2024-02-15 02:27:36 +08:00
parent b403f8d8a8
commit 596b6828cb
24 changed files with 438 additions and 203 deletions


@@ -10,20 +10,25 @@ class FreezeArguments:
     """
     name_module_trainable: Optional[str] = field(
-        default="mlp",
+        default=None,
         metadata={
-            "help": 'Name of trainable modules for partial-parameter (freeze) fine-tuning. \
-                    Use commas to separate multiple modules. \
-                    LLaMA choices: ["mlp", "self_attn"], \
-                    BLOOM & Falcon & ChatGLM choices: ["mlp", "self_attention"], \
-                    Qwen choices: ["mlp", "attn"], \
-                    Phi choices: ["mlp", "mixer"], \
-                    InternLM2 choices: ["feed_forward", "attention"], \
-                    Others choices: the same as LLaMA.'
+            "help": """Name of trainable modules for partial-parameter (freeze) fine-tuning. \
+                    Use commas to separate multiple modules. \
+                    Use "all" to specify all the available modules. \
+                    LLaMA choices: ["mlp", "self_attn"], \
+                    BLOOM & Falcon & ChatGLM choices: ["mlp", "self_attention"], \
+                    Qwen choices: ["mlp", "attn"], \
+                    InternLM2 choices: ["feed_forward", "attention"], \
+                    Others choices: the same as LLaMA."""
         },
     )
     num_layer_trainable: Optional[int] = field(
-        default=3, metadata={"help": "The number of trainable layers for partial-parameter (freeze) fine-tuning."}
+        default=3,
+        metadata={"help": "The number of trainable layers for partial-parameter (freeze) fine-tuning."},
     )
+    use_llama_pro: Optional[bool] = field(
+        default=False,
+        metadata={"help": "Whether or not to use llama pro for partial-parameter (freeze) fine-tuning."},
+    )
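The new `use_llama_pro` switch changes how the freeze method picks its trainable blocks: instead of unfreezing the last `num_layer_trainable` decoder layers, LLaMA Pro-style block expansion trains only the newly inserted, evenly spaced blocks. Below is a minimal illustrative sketch of that selection logic, assuming a decoder-only model whose blocks are named `model.layers.{i}`; the helper names are hypothetical and this is not the repository's implementation.

```python
from typing import List


def select_trainable_layers(num_layers: int, num_layer_trainable: int, use_llama_pro: bool) -> List[int]:
    if use_llama_pro:
        # LLaMA Pro style (assumption): one expanded block was appended after every
        # `stride` original blocks, so train the evenly spaced block indices.
        stride = num_layers // num_layer_trainable
        return list(range(stride - 1, num_layers, stride))[:num_layer_trainable]
    # Plain freeze tuning: only the last `num_layer_trainable` blocks stay trainable.
    return list(range(num_layers - num_layer_trainable, num_layers))


def apply_freeze(model, trainable_layers: List[int], name_module_trainable: str = "all") -> None:
    # Freeze everything except the chosen layers, optionally restricted to
    # sub-modules such as "mlp" or "self_attn" (matching the help text above).
    modules = None if name_module_trainable == "all" else [m.strip() for m in name_module_trainable.split(",")]
    for name, param in model.named_parameters():
        in_layer = any(".{:d}.".format(idx) in name for idx in trainable_layers)
        in_module = modules is None or any(m in name for m in modules)
        param.requires_grad_(in_layer and in_module)
```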
@@ -40,27 +45,42 @@ class LoraArguments:
         },
     )
     lora_alpha: Optional[int] = field(
-        default=None, metadata={"help": "The scale factor for LoRA fine-tuning (default: lora_rank * 2)."}
+        default=None,
+        metadata={"help": "The scale factor for LoRA fine-tuning (default: lora_rank * 2)."},
+    )
+    lora_dropout: Optional[float] = field(
+        default=0.0,
+        metadata={"help": "Dropout rate for the LoRA fine-tuning."},
+    )
+    lora_rank: Optional[int] = field(
+        default=8,
+        metadata={"help": "The intrinsic dimension for LoRA fine-tuning."},
     )
-    lora_dropout: Optional[float] = field(default=0.0, metadata={"help": "Dropout rate for the LoRA fine-tuning."})
-    lora_rank: Optional[int] = field(default=8, metadata={"help": "The intrinsic dimension for LoRA fine-tuning."})
     lora_target: Optional[str] = field(
         default=None,
         metadata={
-            "help": 'Name(s) of target modules to apply LoRA. Use commas to separate multiple modules. \
-                    LLaMA choices: ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"], \
-                    BLOOM & Falcon & ChatGLM choices: ["query_key_value", "dense", "dense_h_to_4h", "dense_4h_to_h"], \
-                    Baichuan choices: ["W_pack", "o_proj", "gate_proj", "up_proj", "down_proj"], \
-                    Qwen choices: ["c_attn", "attn.c_proj", "w1", "w2", "mlp.c_proj"], \
-                    Phi choices: ["Wqkv", "out_proj", "fc1", "fc2"], \
-                    Others choices: the same as LLaMA.'
+            "help": """Name(s) of target modules to apply LoRA. \
+                    Use commas to separate multiple modules. \
+                    Use "all" to specify all the available modules. \
+                    LLaMA choices: ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"], \
+                    BLOOM & Falcon & ChatGLM choices: ["query_key_value", "dense", "dense_h_to_4h", "dense_4h_to_h"], \
+                    Baichuan choices: ["W_pack", "o_proj", "gate_proj", "up_proj", "down_proj"], \
+                    Qwen choices: ["c_attn", "attn.c_proj", "w1", "w2", "mlp.c_proj"], \
+                    InternLM2 choices: ["wqkv", "wo", "w1", "w2", "w3"], \
+                    Others choices: the same as LLaMA."""
         },
     )
     lora_bf16_mode: Optional[bool] = field(
-        default=False, metadata={"help": "Whether or not to train lora adapters in bf16 precision."}
+        default=False,
+        metadata={"help": "Whether or not to train lora adapters in bf16 precision."},
+    )
+    use_rslora: Optional[bool] = field(
+        default=False,
+        metadata={"help": "Whether or not to use the rank stabilization scaling factor for LoRA layer."},
     )
     create_new_adapter: Optional[bool] = field(
-        default=False, metadata={"help": "Whether or not to create a new adapter with randomly initialized weight."}
+        default=False,
+        metadata={"help": "Whether or not to create a new adapter with randomly initialized weight."},
     )
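`use_rslora` toggles the rank-stabilized LoRA scaling: the adapter update is scaled by `lora_alpha / sqrt(lora_rank)` instead of the vanilla `lora_alpha / lora_rank`, which keeps the effective scale from collapsing at higher ranks. A minimal sketch of how these fields could be forwarded to peft, assuming a recent peft release that exposes a `use_rslora` flag on `LoraConfig`; the `build_lora_config` helper is hypothetical.

```python
from peft import LoraConfig


def build_lora_config(finetuning_args) -> LoraConfig:
    # Vanilla LoRA scales the BA update by lora_alpha / lora_rank; with
    # use_rslora=True, peft uses lora_alpha / sqrt(lora_rank) instead (rsLoRA).
    # Note: resolving the special value "all" to concrete module names is omitted here.
    target_modules = (
        [t.strip() for t in finetuning_args.lora_target.split(",")] if finetuning_args.lora_target else None
    )
    return LoraConfig(
        task_type="CAUSAL_LM",
        r=finetuning_args.lora_rank,
        lora_alpha=finetuning_args.lora_alpha or finetuning_args.lora_rank * 2,
        lora_dropout=finetuning_args.lora_dropout,
        target_modules=target_modules,
        use_rslora=finetuning_args.use_rslora,
    )


# Usage (sketch): model = peft.get_peft_model(base_model, build_lora_config(finetuning_args))
```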
@@ -70,49 +90,65 @@ class RLHFArguments:
     Arguments pertaining to the PPO and DPO training.
     """
-    dpo_beta: Optional[float] = field(default=0.1, metadata={"help": "The beta parameter for the DPO loss."})
+    dpo_beta: Optional[float] = field(
+        default=0.1,
+        metadata={"help": "The beta parameter for the DPO loss."},
+    )
     dpo_loss: Optional[Literal["sigmoid", "hinge", "ipo", "kto"]] = field(
-        default="sigmoid", metadata={"help": "The type of DPO loss to use."}
+        default="sigmoid",
+        metadata={"help": "The type of DPO loss to use."},
     )
     dpo_ftx: Optional[float] = field(
-        default=0, metadata={"help": "The supervised fine-tuning loss coefficient in DPO training."}
+        default=0,
+        metadata={"help": "The supervised fine-tuning loss coefficient in DPO training."},
     )
     ppo_buffer_size: Optional[int] = field(
         default=1,
         metadata={"help": "The number of mini-batches to make experience buffer in a PPO optimization step."},
     )
     ppo_epochs: Optional[int] = field(
-        default=4, metadata={"help": "The number of epochs to perform in a PPO optimization step."}
+        default=4,
+        metadata={"help": "The number of epochs to perform in a PPO optimization step."},
     )
     ppo_logger: Optional[str] = field(
-        default=None, metadata={"help": 'Log with either "wandb" or "tensorboard" in PPO training.'}
+        default=None,
+        metadata={"help": 'Log with either "wandb" or "tensorboard" in PPO training.'},
     )
     ppo_score_norm: Optional[bool] = field(
-        default=False, metadata={"help": "Use score normalization in PPO training."}
+        default=False,
+        metadata={"help": "Use score normalization in PPO training."},
     )
     ppo_target: Optional[float] = field(
-        default=6.0, metadata={"help": "Target KL value for adaptive KL control in PPO training."}
+        default=6.0,
+        metadata={"help": "Target KL value for adaptive KL control in PPO training."},
     )
     ppo_whiten_rewards: Optional[bool] = field(
-        default=False, metadata={"help": "Whiten the rewards before compute advantages in PPO training."}
+        default=False,
+        metadata={"help": "Whiten the rewards before compute advantages in PPO training."},
     )
     ref_model: Optional[str] = field(
-        default=None, metadata={"help": "Path to the reference model used for the PPO or DPO training."}
+        default=None,
+        metadata={"help": "Path to the reference model used for the PPO or DPO training."},
     )
     ref_model_adapters: Optional[str] = field(
-        default=None, metadata={"help": "Path to the adapters of the reference model."}
+        default=None,
+        metadata={"help": "Path to the adapters of the reference model."},
    )
     ref_model_quantization_bit: Optional[int] = field(
-        default=None, metadata={"help": "The number of bits to quantize the reference model."}
+        default=None,
+        metadata={"help": "The number of bits to quantize the reference model."},
     )
     reward_model: Optional[str] = field(
-        default=None, metadata={"help": "Path to the reward model used for the PPO training."}
+        default=None,
+        metadata={"help": "Path to the reward model used for the PPO training."},
     )
     reward_model_adapters: Optional[str] = field(
-        default=None, metadata={"help": "Path to the adapters of the reward model."}
+        default=None,
+        metadata={"help": "Path to the adapters of the reward model."},
     )
     reward_model_quantization_bit: Optional[int] = field(
-        default=None, metadata={"help": "The number of bits to quantize the reward model."}
+        default=None,
+        metadata={"help": "The number of bits to quantize the reward model."},
     )
     reward_model_type: Optional[Literal["lora", "full", "api"]] = field(
         default="lora",
@@ -127,16 +163,20 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments):
     """
     stage: Optional[Literal["pt", "sft", "rm", "ppo", "dpo"]] = field(
-        default="sft", metadata={"help": "Which stage will be performed in training."}
+        default="sft",
+        metadata={"help": "Which stage will be performed in training."},
     )
     finetuning_type: Optional[Literal["lora", "freeze", "full"]] = field(
-        default="lora", metadata={"help": "Which fine-tuning method to use."}
+        default="lora",
+        metadata={"help": "Which fine-tuning method to use."},
     )
     disable_version_checking: Optional[bool] = field(
-        default=False, metadata={"help": "Whether or not to disable version checking."}
+        default=False,
+        metadata={"help": "Whether or not to disable version checking."},
     )
     plot_loss: Optional[bool] = field(
-        default=False, metadata={"help": "Whether or not to save the training loss curves."}
+        default=False,
+        metadata={"help": "Whether or not to save the training loss curves."},
     )
     def __post_init__(self):
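Since every option above is a plain dataclass field, it surfaces as a command-line flag once the hyperparameters are parsed with `transformers.HfArgumentParser`, the parsing approach used in this project, so this commit effectively adds `--use_llama_pro` and `--use_rslora`. A hypothetical smoke test; the module path `llmtuner.hparams.finetuning_args` is an assumption.

```python
from transformers import HfArgumentParser

from llmtuner.hparams.finetuning_args import FinetuningArguments  # assumed module path

parser = HfArgumentParser(FinetuningArguments)
(finetuning_args,) = parser.parse_args_into_dataclasses(
    args=[
        "--stage", "sft",
        "--finetuning_type", "lora",
        "--lora_rank", "8",
        "--lora_target", "q_proj,v_proj",
        "--use_rslora", "true",
    ]
)
print(finetuning_args.use_rslora)  # True
```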