[optim] add support to APOLLO (#6617)
Former-commit-id: 5a252e5a458457adbd19da3b68a3897ad2962824
This commit is contained in:
@@ -251,6 +251,59 @@ class GaloreArguments:
|
||||
)
|
||||
|
||||
|
||||
@dataclass
class ApolloArguments:
    r"""
    Arguments pertaining to the APOLLO optimizer.

    Each field carries its user-facing description in ``metadata["help"]``,
    following the HfArgumentParser convention used by the sibling
    argument dataclasses (e.g. GaloreArguments).
    """

    use_apollo: bool = field(
        default=False,
        metadata={"help": "Whether or not to use the APOLLO optimizer."},
    )
    apollo_target: str = field(
        default="all",
        metadata={
            "help": (
                "Name(s) of modules to apply APOLLO. Use commas to separate multiple modules. "
                "Use `all` to specify all the linear modules."
            )
        },
    )
    apollo_rank: int = field(
        default=16,
        metadata={"help": "The rank of APOLLO gradients."},
    )
    apollo_update_interval: int = field(
        default=200,
        metadata={"help": "Number of steps to update the APOLLO projection."},
    )
    apollo_scale: float = field(
        default=1.0,
        metadata={"help": "APOLLO scaling coefficient."},
    )
    apollo_proj: Literal["svd", "random"] = field(
        default="random",
        metadata={"help": "Type of APOLLO low-rank projection algorithm (svd or random)."},
    )
    # NOTE(review): dropped the stray trailing comma in the Literal below for
    # consistency with the other Literal-typed fields; type is unchanged.
    apollo_proj_type: Literal["std", "right", "left"] = field(
        default="std",
        metadata={"help": "Type of APOLLO projection."},
    )
    apollo_scale_type: Literal["channel", "tensor"] = field(
        default="channel",
        metadata={"help": "Type of APOLLO scaling (channel or tensor)."},
    )
    apollo_layerwise: bool = field(
        default=False,
        metadata={"help": "Whether or not to enable layer-wise update to further save memory."},
    )
    apollo_scale_front: bool = field(
        default=False,
        metadata={"help": "Whether or not to use the norm-growth limiter in front of gradient scaling."},
    )
|
||||
|
||||
|
||||
@dataclass
|
||||
class BAdamArgument:
|
||||
r"""
|
||||
@@ -334,7 +387,7 @@ class SwanLabArguments:
|
||||
|
||||
@dataclass
|
||||
class FinetuningArguments(
|
||||
FreezeArguments, LoraArguments, RLHFArguments, GaloreArguments, BAdamArgument, SwanLabArguments
|
||||
FreezeArguments, LoraArguments, RLHFArguments, GaloreArguments, ApolloArguments, BAdamArgument, SwanLabArguments
|
||||
):
|
||||
r"""
|
||||
Arguments pertaining to which techniques we are going to fine-tuning with.
|
||||
@@ -401,6 +454,7 @@ class FinetuningArguments(
|
||||
self.lora_target: List[str] = split_arg(self.lora_target)
|
||||
self.additional_target: Optional[List[str]] = split_arg(self.additional_target)
|
||||
self.galore_target: List[str] = split_arg(self.galore_target)
|
||||
self.apollo_target: List[str] = split_arg(self.apollo_target)
|
||||
self.freeze_vision_tower = self.freeze_vision_tower or self.train_mm_proj_only
|
||||
self.freeze_multi_modal_projector = self.freeze_multi_modal_projector and not self.train_mm_proj_only
|
||||
self.use_ref_model = self.stage == "dpo" and self.pref_loss not in ["orpo", "simpo"]
|
||||
@@ -421,12 +475,18 @@ class FinetuningArguments(
|
||||
if self.use_llama_pro and self.finetuning_type == "full":
|
||||
raise ValueError("`use_llama_pro` is only valid for Freeze or LoRA training.")
|
||||
|
||||
if self.finetuning_type == "lora" and (self.use_galore or self.use_badam):
|
||||
if self.finetuning_type == "lora" and (self.use_galore or self.use_badam or self.use_apollo):
|
||||
raise ValueError("Cannot use LoRA with GaLore, APOLLO or BAdam together.")
|
||||
|
||||
if self.use_galore and self.use_badam:
|
||||
raise ValueError("Cannot use GaLore with BAdam together.")
|
||||
|
||||
if self.use_galore and self.use_apollo:
|
||||
raise ValueError("Cannot use GaLore with APOLLO together.")
|
||||
|
||||
if self.use_badam and self.use_apollo:
|
||||
raise ValueError("Cannot use BAdam with APOLLO together.")
|
||||
|
||||
if self.pissa_init and (self.stage in ["ppo", "kto"] or self.use_ref_model):
|
||||
raise ValueError("Cannot use PiSSA for current training stage.")
|
||||
|
||||
|
||||
@@ -139,6 +139,9 @@ def _check_extra_dependencies(
|
||||
if finetuning_args.use_galore:
|
||||
check_version("galore_torch", mandatory=True)
|
||||
|
||||
if finetuning_args.use_apollo:
|
||||
check_version("apollo_torch", mandatory=True)
|
||||
|
||||
if finetuning_args.use_badam:
|
||||
check_version("badam>=1.2.1", mandatory=True)
|
||||
|
||||
@@ -262,6 +265,13 @@ def get_train_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _
|
||||
):
|
||||
raise ValueError("Distributed training does not support layer-wise GaLore.")
|
||||
|
||||
if (
|
||||
finetuning_args.use_apollo
|
||||
and finetuning_args.apollo_layerwise
|
||||
and training_args.parallel_mode == ParallelMode.DISTRIBUTED
|
||||
):
|
||||
raise ValueError("Distributed training does not support layer-wise APOLLO.")
|
||||
|
||||
if finetuning_args.use_badam and training_args.parallel_mode == ParallelMode.DISTRIBUTED:
|
||||
if finetuning_args.badam_mode == "ratio":
|
||||
raise ValueError("Ratio-based BAdam does not yet support distributed training, use layer-wise BAdam.")
|
||||
@@ -271,6 +281,9 @@ def get_train_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _
|
||||
if finetuning_args.use_galore and training_args.deepspeed is not None:
|
||||
raise ValueError("GaLore is incompatible with DeepSpeed yet.")
|
||||
|
||||
if finetuning_args.use_apollo and training_args.deepspeed is not None:
|
||||
raise ValueError("APOLLO is incompatible with DeepSpeed yet.")
|
||||
|
||||
if model_args.infer_backend == "vllm":
|
||||
raise ValueError("vLLM backend is only available for API, CLI and Web.")
|
||||
|
||||
@@ -306,6 +319,11 @@ def get_train_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _
|
||||
"Using GaLore with mixed precision training may significantly increase GPU memory usage."
|
||||
)
|
||||
|
||||
if training_args.do_train and finetuning_args.use_apollo and not finetuning_args.pure_bf16:
|
||||
logger.warning_rank0(
|
||||
"Using APOLLO with mixed precision training may significantly increase GPU memory usage."
|
||||
)
|
||||
|
||||
if (not training_args.do_train) and model_args.quantization_bit is not None:
|
||||
logger.warning_rank0("Evaluating model in 4/8-bit mode may cause lower scores.")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user