[optim] add support to APOLLO (#6617)

Former-commit-id: 5a252e5a458457adbd19da3b68a3897ad2962824
This commit is contained in:
zhuHQ
2025-01-14 10:24:56 -06:00
committed by GitHub
parent 66184762e8
commit c2120432db
10 changed files with 351 additions and 5 deletions

View File

@@ -139,6 +139,9 @@ def _check_extra_dependencies(
if finetuning_args.use_galore:
check_version("galore_torch", mandatory=True)
if finetuning_args.use_apollo:
check_version("apollo_torch", mandatory=True)
if finetuning_args.use_badam:
check_version("badam>=1.2.1", mandatory=True)
@@ -262,6 +265,13 @@ def get_train_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _
):
raise ValueError("Distributed training does not support layer-wise GaLore.")
if (
finetuning_args.use_apollo
and finetuning_args.apollo_layerwise
and training_args.parallel_mode == ParallelMode.DISTRIBUTED
):
raise ValueError("Distributed training does not support layer-wise APOLLO.")
if finetuning_args.use_badam and training_args.parallel_mode == ParallelMode.DISTRIBUTED:
if finetuning_args.badam_mode == "ratio":
raise ValueError("Ratio-based BAdam does not yet support distributed training, use layer-wise BAdam.")
@@ -271,6 +281,9 @@ def get_train_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _
if finetuning_args.use_galore and training_args.deepspeed is not None:
raise ValueError("GaLore is not compatible with DeepSpeed yet.")
if finetuning_args.use_apollo and training_args.deepspeed is not None:
raise ValueError("APOLLO is not compatible with DeepSpeed yet.")
if model_args.infer_backend == "vllm":
raise ValueError("vLLM backend is only available for API, CLI and Web.")
@@ -306,6 +319,11 @@ def get_train_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _
"Using GaLore with mixed precision training may significantly increase GPU memory usage."
)
if training_args.do_train and finetuning_args.use_apollo and not finetuning_args.pure_bf16:
logger.warning_rank0(
"Using APOLLO with mixed precision training may significantly increase GPU memory usage."
)
if (not training_args.do_train) and model_args.quantization_bit is not None:
logger.warning_rank0("Evaluating model in 4/8-bit mode may cause lower scores.")