update ppo trainer

Former-commit-id: caa525a5c6f228b9ad71387d1fe4f1c2ffa2479e
author hiyouga
date 2023-11-20 21:39:15 +08:00
parent e585950c54
commit 28258aecd2
7 changed files with 68 additions and 41 deletions


@@ -8,10 +8,6 @@ class FreezeArguments:
     r"""
     Arguments pertaining to the freeze (partial-parameter) training.
     """
-    num_layer_trainable: Optional[int] = field(
-        default=3,
-        metadata={"help": "Number of trainable layers for partial-parameter (freeze) fine-tuning."}
-    )
     name_module_trainable: Optional[str] = field(
         default="mlp",
         metadata={"help": "Name of trainable modules for partial-parameter (freeze) fine-tuning. \
@@ -22,6 +18,10 @@ class FreezeArguments:
                   Phi-1.5 choices: [\"mlp\", \"mixer\"], \
                   Others choices: the same as LLaMA."}
     )
+    num_layer_trainable: Optional[int] = field(
+        default=3,
+        metadata={"help": "The number of trainable layers for partial-parameter (freeze) fine-tuning."}
+    )


 @dataclass
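Aside: a minimal sketch of how freeze (partial-parameter) fine-tuning options like these are typically consumed, unfreezing only the name_module_trainable submodules in the last num_layer_trainable decoder layers. The layer path model.model.layers is an assumption (LLaMA-style models); this is an illustration, not code from this commit.

# Hypothetical sketch: freeze everything, then re-enable gradients for the
# chosen submodules of the last few decoder layers.
def apply_freeze_tuning(model, num_layer_trainable=3, name_module_trainable="mlp"):
    for param in model.parameters():
        param.requires_grad_(False)
    layers = model.model.layers  # assumed LLaMA-style layer container
    for layer in layers[-num_layer_trainable:]:
        for name, param in layer.named_parameters():
            if name_module_trainable in name:
                param.requires_grad_(True)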
@@ -29,9 +29,9 @@ class LoraArguments:
     r"""
     Arguments pertaining to the LoRA training.
     """
-    lora_rank: Optional[int] = field(
-        default=8,
-        metadata={"help": "The intrinsic dimension for LoRA fine-tuning."}
+    additional_target: Optional[str] = field(
+        default=None,
+        metadata={"help": "Name(s) of modules apart from LoRA layers to be set as trainable and saved in the final checkpoint."}
     )
     lora_alpha: Optional[float] = field(
         default=None,
@@ -41,6 +41,10 @@ class LoraArguments:
         default=0.1,
         metadata={"help": "Dropout rate for the LoRA fine-tuning."}
     )
+    lora_rank: Optional[int] = field(
+        default=8,
+        metadata={"help": "The intrinsic dimension for LoRA fine-tuning."}
+    )
     lora_target: Optional[str] = field(
         default=None,
         metadata={"help": "Name(s) of target modules to apply LoRA. Use commas to separate multiple modules. \
@@ -51,10 +55,6 @@ class LoraArguments:
                   Phi-1.5 choices: [\"Wqkv\", \"out_proj\", \"fc1\", \"fc2\"], \
                   Others choices: the same as LLaMA."}
     )
-    additional_target: Optional[str] = field(
-        default=None,
-        metadata={"help": "Name(s) of modules apart from LoRA layers to be set as trainable and saved in the final checkpoint."}
-    )
     resume_lora_training: Optional[bool] = field(
         default=True,
         metadata={"help": "Whether to resume training from the last LoRA weights or create new weights after merging them."}
@@ -70,13 +70,17 @@ class RLHFArguments:
         default=0.1,
         metadata={"help": "The beta parameter for the DPO loss."}
     )
-    ppo_logger: Optional[str] = field(
-        default=None,
-        metadata={"help": "Log with either 'wandb' or 'tensorboard' in PPO training."}
+    ppo_buffer_size: Optional[int] = field(
+        default=1,
+        metadata={"help": "The number of mini-batches to make experience buffer in a PPO optimization step."}
     )
     ppo_epochs: Optional[int] = field(
         default=4,
-        metadata={"help": "Number of optimisation epochs per batch of samples"},
+        metadata={"help": "The number of epochs to perform in a PPO optimization step."}
     )
+    ppo_logger: Optional[str] = field(
+        default=None,
+        metadata={"help": "Log with either \"wandb\" or \"tensorboard\" in PPO training."}
+    )
     ppo_score_norm: Optional[bool] = field(
         default=False,