[optim] add support for APOLLO (#6617)

Former-commit-id: 5a252e5a458457adbd19da3b68a3897ad2962824
Author: zhuHQ
Date: 2025-01-14 10:24:56 -06:00
Committed by: GitHub
Parent: 66184762e8
Commit: c2120432db

10 changed files with 351 additions and 5 deletions


@@ -251,6 +251,59 @@ class GaloreArguments:
)


@dataclass
class ApolloArguments:
r"""
Arguments pertaining to the APOLLO algorithm.
"""
use_apollo: bool = field(
default=False,
metadata={"help": "Whether or not to use the APOLLO optimizer."},
)
apollo_target: str = field(
default="all",
metadata={
"help": (
"Name(s) of modules to apply APOLLO. Use commas to separate multiple modules. "
"Use `all` to specify all the linear modules."
)
},
)
apollo_rank: int = field(
default=16,
metadata={"help": "The rank of APOLLO gradients."},
)
apollo_update_interval: int = field(
default=200,
metadata={"help": "Number of steps to update the APOLLO projection."},
)
apollo_scale: float = field(
default=1.0,
metadata={"help": "APOLLO scaling coefficient."},
)
apollo_proj: Literal["svd", "random"] = field(
default="random",
metadata={"help": "Type of APOLLO low-rank projection algorithm (svd or random)."},
)
apollo_proj_type: Literal["std", "right", "left"] = field(
default="std",
metadata={"help": "Type of APOLLO projection."},
)
apollo_scale_type: Literal["channel", "tensor"] = field(
default="channel",
metadata={"help": "Type of APOLLO scaling (channel or tensor)."},
)
apollo_layerwise: bool = field(
default=False,
metadata={"help": "Whether or not to enable layer-wise update to further save memory."},
)
apollo_scale_front: bool = field(
default=False,
metadata={"help": "Whether or not to use the norm-growth limiter in front of gradient scaling."},
)


@dataclass
class BAdamArgument:
r"""
@@ -334,7 +387,7 @@ class SwanLabArguments:
@dataclass
class FinetuningArguments(
-    FreezeArguments, LoraArguments, RLHFArguments, GaloreArguments, BAdamArgument, SwanLabArguments
+    FreezeArguments, LoraArguments, RLHFArguments, GaloreArguments, ApolloArguments, BAdamArgument, SwanLabArguments
):
r"""
Arguments pertaining to which techniques we are going to fine-tune with.
@@ -401,6 +454,7 @@ class FinetuningArguments(
self.lora_target: List[str] = split_arg(self.lora_target)
self.additional_target: Optional[List[str]] = split_arg(self.additional_target)
self.galore_target: List[str] = split_arg(self.galore_target)
self.apollo_target: List[str] = split_arg(self.apollo_target)
self.freeze_vision_tower = self.freeze_vision_tower or self.train_mm_proj_only
self.freeze_multi_modal_projector = self.freeze_multi_modal_projector and not self.train_mm_proj_only
self.use_ref_model = self.stage == "dpo" and self.pref_loss not in ["orpo", "simpo"]
@@ -421,12 +475,18 @@ class FinetuningArguments(
if self.use_llama_pro and self.finetuning_type == "full":
raise ValueError("`use_llama_pro` is only valid for Freeze or LoRA training.")
-        if self.finetuning_type == "lora" and (self.use_galore or self.use_badam):
-            raise ValueError("Cannot use LoRA with GaLore or BAdam together.")
+        if self.finetuning_type == "lora" and (self.use_galore or self.use_badam or self.use_apollo):
+            raise ValueError("Cannot use LoRA with GaLore, BAdam or APOLLO together.")
if self.use_galore and self.use_badam:
raise ValueError("Cannot use GaLore with BAdam together.")
if self.use_galore and self.use_apollo:
raise ValueError("Cannot use GaLore with APOLLO together.")
if self.use_badam and self.use_apollo:
raise ValueError("Cannot use BAdam with APOLLO together.")
if self.pissa_init and (self.stage in ["ppo", "kto"] or self.use_ref_model):
raise ValueError("Cannot use PiSSA for current training stage.")