Merge pull request #3287 from Ledzy/badam

[Feature] Add BAdam algorithm

Former-commit-id: 10a5e1e65b34b03e5ca2a41bf6ded09a3fb25f0c
Authored by hoshi-hiyouga on 2024-04-16 17:32:16 +08:00, committed via GitHub
9 changed files with 215 additions and 11 deletions

@@ -172,7 +172,7 @@ class GaloreArguments:
     use_galore: bool = field(
         default=False,
-        metadata={"help": "Whether or not to use gradient low-Rank projection."},
+        metadata={"help": "Whether or not to use the gradient low-Rank projection (GaLore)."},
     )
     galore_target: str = field(
         default="all",
@@ -204,7 +204,54 @@ class GaloreArguments:
 @dataclass
-class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreArguments):
+class BAdamArgument:
+    r"""
+    Arguments pertaining to the BAdam optimizer.
+    """
+
+    use_badam: bool = field(
+        default=False,
+        metadata={"help": "Whether or not to use the BAdam optimizer."},
+    )
+    badam_mode: Literal["layer", "ratio"] = field(
+        default="layer",
+        metadata={"help": "Whether to use the layer-wise or the ratio-wise BAdam optimizer."},
+    )
+    badam_start_block: Optional[int] = field(
+        default=None,
+        metadata={"help": "The starting block index for layer-wise BAdam."},
+    )
+    badam_switch_block_every: Optional[int] = field(
+        default=50,
+        metadata={"help": "How often to switch the block being updated for layer-wise BAdam. Set to -1 to disable block switching."},
+    )
+    badam_switch_mode: Optional[Literal["ascending", "descending", "random", "fixed"]] = field(
+        default="ascending",
+        metadata={"help": "The strategy of picking the block to update for layer-wise BAdam."},
+    )
+    badam_update_ratio: float = field(
+        default=0.0,
+        metadata={"help": "The ratio of the parameters to update for ratio-wise BAdam."},
+    )
+    badam_mask_mode: Literal["adjacent", "scatter"] = field(
+        default="adjacent",
+        metadata={
+            "help": """The mode of the mask for the BAdam optimizer. \
+            `adjacent` means that the trainable parameters are adjacent to each other, \
+            `scatter` means that the trainable parameters are randomly chosen from the weight matrix."""
+        },
+    )
+    badam_verbose: int = field(
+        default=0,
+        metadata={
+            "help": """The verbosity level of the BAdam optimizer. \
+            0 for no output, 1 for printing the block prefix, 2 for printing the trainable parameters."""
+        },
+    )
+
+
+@dataclass
+class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreArguments, BAdamArgument):
     r"""
     Arguments pertaining to which techniques we are going to fine-tune with.
     """

@@ -88,6 +88,9 @@ def _check_extra_dependencies(
     if finetuning_args.use_galore:
         require_version("galore_torch", "To fix: pip install galore_torch")
 
+    if finetuning_args.use_badam:
+        require_version("badam", "To fix: pip install badam")
+
     if training_args is not None and training_args.predict_with_generate:
         require_version("jieba", "To fix: pip install jieba")
         require_version("nltk", "To fix: pip install nltk")
@@ -172,7 +175,14 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
raise ValueError("Distributed training does not support layer-wise GaLore.")
if finetuning_args.use_galore and training_args.deepspeed is not None:
raise ValueError("GaLore is incompatible with DeepSpeed.")
raise ValueError("GaLore is incompatible with DeepSpeed yet.")
if (
finetuning_args.use_badam
and finetuning_args.badam_mode == "layer"
and training_args.parallel_mode.value == "distributed"
):
raise ValueError("Layer-wise BAdam does not yet support distributed training, use ratio-wise BAdam.")
if model_args.infer_backend == "vllm":
raise ValueError("vLLM backend is only available for API, CLI and Web.")