[optim] add support to APOLLO (#6617)
Former-commit-id: 5a252e5a458457adbd19da3b68a3897ad2962824
@@ -251,6 +251,59 @@ class GaloreArguments:
     )


+@dataclass
+class ApolloArguments:
+    r"""
+    Arguments pertaining to the APOLLO algorithm.
+    """
+
+    use_apollo: bool = field(
+        default=False,
+        metadata={"help": "Whether or not to use the APOLLO optimizer."},
+    )
+    apollo_target: str = field(
+        default="all",
+        metadata={
+            "help": (
+                "Name(s) of modules to apply APOLLO. Use commas to separate multiple modules. "
+                "Use `all` to specify all the linear modules."
+            )
+        },
+    )
+    apollo_rank: int = field(
+        default=16,
+        metadata={"help": "The rank of APOLLO gradients."},
+    )
+    apollo_update_interval: int = field(
+        default=200,
+        metadata={"help": "Number of steps to update the APOLLO projection."},
+    )
+    apollo_scale: float = field(
+        default=1.0,
+        metadata={"help": "APOLLO scaling coefficient."},
+    )
+    apollo_proj: Literal["svd", "random"] = field(
+        default="random",
+        metadata={"help": "Type of APOLLO low-rank projection algorithm (svd or random)."},
+    )
+    apollo_proj_type: Literal["std", "right", "left"] = field(
+        default="std",
+        metadata={"help": "Type of APOLLO projection."},
+    )
+    apollo_scale_type: Literal["channel", "tensor"] = field(
+        default="channel",
+        metadata={"help": "Type of APOLLO scaling (channel or tensor)."},
+    )
+    apollo_layerwise: bool = field(
+        default=False,
+        metadata={"help": "Whether or not to enable layer-wise update to further save memory."},
+    )
+    apollo_scale_front: bool = field(
+        default=False,
+        metadata={"help": "Whether or not to use the norm-growth limiter in front of gradient scaling."},
+    )
+
+
 @dataclass
 class BAdamArgument:
     r"""
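Note: the new options ride on the existing dataclass-based argument parsing, so nothing else is needed to expose them on the command line or in config files. A minimal sketch of how they could be exercised in isolation, assuming the package is installed and `llamafactory.hparams` exports `FinetuningArguments` (which mixes `ApolloArguments` in below); the flag values here are illustrative only:

from transformers import HfArgumentParser

from llamafactory.hparams import FinetuningArguments  # assumed import path

# Parse the APOLLO flags the same way the training entry point would.
# finetuning_type=full sidesteps the LoRA compatibility check added later in this commit.
parser = HfArgumentParser(FinetuningArguments)
(finetuning_args,) = parser.parse_args_into_dataclasses(
    args=[
        "--finetuning_type", "full",
        "--use_apollo", "true",
        "--apollo_target", "q_proj,v_proj",
        "--apollo_rank", "32",
        "--apollo_update_interval", "100",
    ]
)
print(finetuning_args.use_apollo)     # True
print(finetuning_args.apollo_target)  # ["q_proj", "v_proj"] once __post_init__ splits the string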
@@ -334,7 +387,7 @@ class SwanLabArguments:

 @dataclass
 class FinetuningArguments(
-    FreezeArguments, LoraArguments, RLHFArguments, GaloreArguments, BAdamArgument, SwanLabArguments
+    FreezeArguments, LoraArguments, RLHFArguments, GaloreArguments, ApolloArguments, BAdamArgument, SwanLabArguments
 ):
     r"""
     Arguments pertaining to which techniques we are going to fine-tune with.
@@ -401,6 +454,7 @@ class FinetuningArguments(
         self.lora_target: List[str] = split_arg(self.lora_target)
         self.additional_target: Optional[List[str]] = split_arg(self.additional_target)
         self.galore_target: List[str] = split_arg(self.galore_target)
+        self.apollo_target: List[str] = split_arg(self.apollo_target)
         self.freeze_vision_tower = self.freeze_vision_tower or self.train_mm_proj_only
         self.freeze_multi_modal_projector = self.freeze_multi_modal_projector and not self.train_mm_proj_only
         self.use_ref_model = self.stage == "dpo" and self.pref_loss not in ["orpo", "simpo"]
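split_arg is not touched by this diff; for context, a plausible minimal implementation consistent with how it is used above (hypothetical sketch, the real helper lives elsewhere in finetuning_args.py):

def split_arg(arg):
    # Hypothetical sketch: turn "q_proj,v_proj" into ["q_proj", "v_proj"],
    # leave non-string values (e.g. None or an existing list) unchanged.
    if isinstance(arg, str):
        return [item.strip() for item in arg.split(",")]
    return arg

assert split_arg("q_proj,v_proj") == ["q_proj", "v_proj"]
assert split_arg("all") == ["all"]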
@@ -421,12 +475,18 @@ class FinetuningArguments(
         if self.use_llama_pro and self.finetuning_type == "full":
             raise ValueError("`use_llama_pro` is only valid for Freeze or LoRA training.")

-        if self.finetuning_type == "lora" and (self.use_galore or self.use_badam):
+        if self.finetuning_type == "lora" and (self.use_galore or self.use_badam or self.use_apollo):
             raise ValueError("Cannot use LoRA with GaLore or BAdam together.")

         if self.use_galore and self.use_badam:
             raise ValueError("Cannot use GaLore with BAdam together.")

+        if self.use_galore and self.use_apollo:
+            raise ValueError("Cannot use GaLore with APOLLO together.")
+
+        if self.use_badam and self.use_apollo:
+            raise ValueError("Cannot use BAdam with APOLLO together.")
+
         if self.pissa_init and (self.stage in ["ppo", "kto"] or self.use_ref_model):
             raise ValueError("Cannot use PiSSA for current training stage.")
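These checks make GaLore, BAdam, and APOLLO mutually exclusive and keep all of them away from LoRA. A quick way to see the new guard fire, again assuming the installed package exposes FinetuningArguments:

from llamafactory.hparams import FinetuningArguments  # assumed import path

try:
    # Requesting two of the memory-efficient optimizers at once should be rejected.
    FinetuningArguments(finetuning_type="full", use_galore=True, use_apollo=True)
except ValueError as err:
    print(err)  # Cannot use GaLore with APOLLO together.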