Feature BAdam
Former-commit-id: d8d2807fbcf587c37f7fd34a23e9397d2775ceed
@@ -163,6 +163,47 @@ class RLHFArguments:
         metadata={"help": "The type of the reward model in PPO training. Lora model only supports lora training."},
     )

+@dataclass
+class BAdamArgument:
+    r"""
+    Arguments for the BAdam optimizer.
+    """
+    use_badam: bool = field(
+        default=False,
+        metadata={"help": "Whether or not to use the BAdam optimizer."},
+    )
+    badam_mode: Literal["layer", "ratio"] = field(
+        default="layer",
+        metadata={"help": "The mode of the BAdam optimizer: 'layer' for layer-wise, 'ratio' for ratio-wise."},
+    )
+
+    # ======== Arguments for layer-wise update ========
+    start_block: Optional[int] = field(
+        default=None,
+        metadata={"help": "The starting block index for block-wise fine-tuning."},
+    )
+    switch_block_every: Optional[int] = field(
+        default=50,
+        metadata={"help": "How often to switch the block being updated. Set to -1 to disable block switching."},
+    )
+    switch_mode: Optional[Literal["ascending", "descending", "random", "fixed"]] = field(
+        default="ascending",
+        metadata={"help": "The strategy for picking the next block to update."},
+    )
+
+    # ======== Arguments for ratio-wise update ========
+    badam_update_ratio: float = field(
+        default=0.0,
+        metadata={"help": "The ratio of parameters updated at each step by the BAdam optimizer."},
+    )
+    badam_mask_mode: Literal["adjacent", "scatter"] = field(
+        default="adjacent",
+        metadata={"help": "The mask mode of the BAdam optimizer. 'adjacent' keeps the trainable parameters contiguous within the weight, while 'scatter' chooses them randomly from the weight."},
+    )
+    badam_verbose: int = field(
+        default=0,
+        metadata={"help": "The verbosity level of the BAdam optimizer. 0: no output, 1: print the block prefix, 2: print the trainable parameters."},
+    )

 @dataclass
 class GaloreArguments:
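To make the layer-wise options above concrete, the following is a minimal, hypothetical sketch, not code from this commit and not the BAdam package's actual logic: it only illustrates the schedule that start_block, switch_block_every and switch_mode describe, where one block of the model is trainable at a time and the active block changes every switch_block_every optimizer steps.

import random
from typing import List


def badam_block_schedule(
    num_blocks: int,
    total_steps: int,
    start_block: int = 0,
    switch_block_every: int = 50,
    switch_mode: str = "ascending",
    seed: int = 0,
) -> List[int]:
    # Hypothetical illustration: index of the block that is trainable at each step.
    order = list(range(num_blocks))
    if switch_mode == "descending":
        order.reverse()
    elif switch_mode == "random":
        random.Random(seed).shuffle(order)

    schedule = []
    current = start_block
    for step in range(total_steps):
        # "fixed" keeps the starting block; switch_block_every=-1 disables switching entirely.
        if switch_mode != "fixed" and switch_block_every > 0 and step > 0 and step % switch_block_every == 0:
            current = order[(order.index(current) + 1) % num_blocks]
        schedule.append(current)
    return schedule


# Four blocks, switching every two steps in ascending order: 0, 0, 1, 1, 2, 2, 3, 3
print(badam_block_schedule(num_blocks=4, total_steps=8, switch_block_every=2))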
@@ -204,7 +245,7 @@ class GaloreArguments:


 @dataclass
-class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreArguments):
+class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreArguments, BAdamArgument):
     r"""
     Arguments pertaining to which techniques we are going to fine-tune with.
     """
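Because FinetuningArguments now also inherits BAdamArgument, the new fields are picked up by HfArgumentParser like every other finetuning flag, with field names mapping to command-line options such as --use_badam. Below is a self-contained sketch that uses a trimmed inline copy of the fields rather than importing the real class from the repository, so it runs on its own.

from dataclasses import dataclass, field
from typing import Literal

from transformers import HfArgumentParser


@dataclass
class BAdamArgument:
    # Trimmed copy of the fields added above, for illustration only.
    use_badam: bool = field(default=False)
    badam_mode: Literal["layer", "ratio"] = field(default="layer")
    switch_block_every: int = field(default=50)
    badam_update_ratio: float = field(default=0.0)


parser = HfArgumentParser(BAdamArgument)
# Field names become CLI flags, e.g. --use_badam --badam_mode ratio --badam_update_ratio 0.05
(badam_args,) = parser.parse_args_into_dataclasses(
    args=["--use_badam", "--badam_mode", "ratio", "--badam_update_ratio", "0.05"]
)
print(badam_args)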
@@ -171,6 +171,12 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
     if finetuning_args.use_galore and training_args.deepspeed is not None:
         raise ValueError("GaLore is incompatible with DeepSpeed.")

+    if (finetuning_args.use_badam
+        and finetuning_args.badam_mode == "layer"
+        and training_args.parallel_mode.value == "distributed"
+    ):
+        raise ValueError("BAdam with layer-wise mode is not yet supported in distributed training; use the ratio mode instead.")
+
     if model_args.infer_backend == "vllm":
         raise ValueError("vLLM backend is only available for API, CLI and Web.")