Feature BAdam

Former-commit-id: d8d2807fbcf587c37f7fd34a23e9397d2775ceed
Author: Jonery
Date: 2024-04-15 23:15:27 +08:00
Parent: 276f2cb24e
Commit: d4d471450f
9 changed files with 195 additions and 7 deletions


@@ -163,6 +163,47 @@ class RLHFArguments:
         metadata={"help": "The type of the reward model in PPO training. Lora model only supports lora training."},
     )
 
 
+@dataclass
+class BAdamArgument:
+    r"""
+    Arguments for the BAdam optimizer.
+    """
+
+    use_badam: bool = field(
+        default=False,
+        metadata={"help": "Whether or not to use the BAdam optimizer."},
+    )
+    badam_mode: Literal["layer", "ratio"] = field(
+        default="layer",
+        metadata={"help": "The mode of the BAdam optimizer. 'layer' for layer-wise, 'ratio' for ratio-wise."},
+    )
+    # ======== Arguments for layer-wise update ========
+    start_block: Optional[int] = field(
+        default=None,
+        metadata={"help": "The starting block index for block-wise fine-tuning."},
+    )
+    switch_block_every: Optional[int] = field(
+        default=50,
+        metadata={"help": "How often to switch the block to update. Set to -1 to disable block switching."},
+    )
+    switch_mode: Optional[Literal["ascending", "descending", "random", "fixed"]] = field(
+        default="ascending",
+        metadata={"help": "The strategy for picking the next block to update."},
+    )
+    # ======== Arguments for ratio-wise update ========
+    badam_update_ratio: float = field(
+        default=0.0,
+        metadata={"help": "The ratio of the parameters to update in ratio-wise mode."},
+    )
+    badam_mask_mode: Literal["adjacent", "scatter"] = field(
+        default="adjacent",
+        metadata={"help": "The mode of the mask for the BAdam optimizer. `adjacent` means the trainable parameters are adjacent to each other; `scatter` means the trainable parameters are randomly chosen from the weight matrix."},
+    )
+    badam_verbose: int = field(
+        default=0,
+        metadata={"help": "The verbosity level of the BAdam optimizer. 0 for no output, 1 to print the block prefix, 2 to print trainable parameters."},
+    )
+
+
 @dataclass
 class GaloreArguments:
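
To make the layer-wise arguments concrete, here is a minimal sketch of the block-switching schedule they describe. The helper below is hypothetical (the real scheduling lives inside the BAdam package); it only illustrates how start_block, switch_block_every, and switch_mode interact.

import random
from typing import Optional


def next_active_block(
    step: int,
    num_blocks: int,
    start_block: Optional[int] = None,
    switch_block_every: int = 50,
    switch_mode: str = "ascending",
) -> int:
    """Pick the block whose parameters are trainable at `step` (illustrative only)."""
    start = start_block or 0
    if switch_block_every <= 0:  # -1 disables block switching
        return start
    num_switches = step // switch_block_every
    if switch_mode == "ascending":
        return (start + num_switches) % num_blocks
    if switch_mode == "descending":
        return (start - num_switches) % num_blocks
    if switch_mode == "random":
        # A deterministic draw per switching period, so every step inside
        # one period trains the same randomly chosen block.
        return random.Random(num_switches).randrange(num_blocks)
    return start  # "fixed": always train the starting block


# With 32 blocks and the defaults, steps 0-49 train block 0,
# steps 50-99 train block 1, and so on.
assert next_active_block(step=75, num_blocks=32) == 1
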
@@ -204,7 +245,7 @@ class GaloreArguments:
 
 @dataclass
-class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreArguments):
+class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreArguments, BAdamArgument):
     r"""
     Arguments pertaining to which techniques we are going to fine-tune with.
     """