[optim] add support to APOLLO (#6617)
Former-commit-id: 5a252e5a458457adbd19da3b68a3897ad2962824
This commit is contained in:
@@ -139,6 +139,9 @@ def _check_extra_dependencies(
|
||||
if finetuning_args.use_galore:
|
||||
check_version("galore_torch", mandatory=True)
|
||||
|
||||
if finetuning_args.use_apollo:
|
||||
check_version("apollo_torch", mandatory=True)
|
||||
|
||||
if finetuning_args.use_badam:
|
||||
check_version("badam>=1.2.1", mandatory=True)
|
||||
|
||||
@@ -262,6 +265,13 @@ def get_train_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _
|
||||
):
|
||||
raise ValueError("Distributed training does not support layer-wise GaLore.")
|
||||
|
||||
if (
|
||||
finetuning_args.use_apollo
|
||||
and finetuning_args.apollo_layerwise
|
||||
and training_args.parallel_mode == ParallelMode.DISTRIBUTED
|
||||
):
|
||||
raise ValueError("Distributed training does not support layer-wise APOLLO.")
|
||||
|
||||
if finetuning_args.use_badam and training_args.parallel_mode == ParallelMode.DISTRIBUTED:
|
||||
if finetuning_args.badam_mode == "ratio":
|
||||
raise ValueError("Radio-based BAdam does not yet support distributed training, use layer-wise BAdam.")
|
||||
@@ -271,6 +281,9 @@ def get_train_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _
|
||||
if finetuning_args.use_galore and training_args.deepspeed is not None:
|
||||
raise ValueError("GaLore is incompatible with DeepSpeed yet.")
|
||||
|
||||
if finetuning_args.use_apollo and training_args.deepspeed is not None:
|
||||
raise ValueError("APOLLO is incompatible with DeepSpeed yet.")
|
||||
|
||||
if model_args.infer_backend == "vllm":
|
||||
raise ValueError("vLLM backend is only available for API, CLI and Web.")
|
||||
|
||||
@@ -306,6 +319,11 @@ def get_train_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _
|
||||
"Using GaLore with mixed precision training may significantly increases GPU memory usage."
|
||||
)
|
||||
|
||||
if training_args.do_train and finetuning_args.use_apollo and not finetuning_args.pure_bf16:
|
||||
logger.warning_rank0(
|
||||
"Using APOLLO with mixed precision training may significantly increases GPU memory usage."
|
||||
)
|
||||
|
||||
if (not training_args.do_train) and model_args.quantization_bit is not None:
|
||||
logger.warning_rank0("Evaluating model in 4/8-bit mode may cause lower scores.")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user