[optim] clean apollo (#6645)
* clean apollo code
* update readme

Former-commit-id: 38b8ec4a99189483124b54df9d6bc6b0d318855a
@@ -258,31 +258,21 @@ def get_train_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _
         if is_deepspeed_zero3_enabled():
             raise ValueError("`pure_bf16` is incompatible with DeepSpeed ZeRO-3.")

-    if (
-        finetuning_args.use_galore
-        and finetuning_args.galore_layerwise
-        and training_args.parallel_mode == ParallelMode.DISTRIBUTED
-    ):
-        raise ValueError("Distributed training does not support layer-wise GaLore.")
+    if training_args.parallel_mode == ParallelMode.DISTRIBUTED:
+        if finetuning_args.use_galore and finetuning_args.galore_layerwise:
+            raise ValueError("Distributed training does not support layer-wise GaLore.")

-    if (
-        finetuning_args.use_apollo
-        and finetuning_args.apollo_layerwise
-        and training_args.parallel_mode == ParallelMode.DISTRIBUTED
-    ):
-        raise ValueError("Distributed training does not support layer-wise APOLLO.")
+        if finetuning_args.use_apollo and finetuning_args.apollo_layerwise:
+            raise ValueError("Distributed training does not support layer-wise APOLLO.")

-    if finetuning_args.use_badam and training_args.parallel_mode == ParallelMode.DISTRIBUTED:
-        if finetuning_args.badam_mode == "ratio":
-            raise ValueError("Radio-based BAdam does not yet support distributed training, use layer-wise BAdam.")
-        elif not is_deepspeed_zero3_enabled():
-            raise ValueError("Layer-wise BAdam only supports DeepSpeed ZeRO-3 training.")
+        if finetuning_args.use_badam:
+            if finetuning_args.badam_mode == "ratio":
+                raise ValueError("Radio-based BAdam does not yet support distributed training, use layer-wise BAdam.")
+            elif not is_deepspeed_zero3_enabled():
+                raise ValueError("Layer-wise BAdam only supports DeepSpeed ZeRO-3 training.")

-    if finetuning_args.use_galore and training_args.deepspeed is not None:
-        raise ValueError("GaLore is incompatible with DeepSpeed yet.")
-
-    if finetuning_args.use_apollo and training_args.deepspeed is not None:
-        raise ValueError("APOLLO is incompatible with DeepSpeed yet.")
+    if training_args.deepspeed is not None and (finetuning_args.use_galore or finetuning_args.use_apollo):
+        raise ValueError("GaLore and APOLLO are incompatible with DeepSpeed yet.")

     if model_args.infer_backend == "vllm":
         raise ValueError("vLLM backend is only available for API, CLI and Web.")
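Note on the hunk above: the separate GaLore, APOLLO and BAdam guards are folded under a single distributed-mode branch, and the two DeepSpeed checks collapse into one condition. A minimal runnable sketch of that consolidated flow, using a hypothetical stand-in dataclass rather than the project's real argument objects:

# Minimal sketch of the consolidated checks above, not the project's actual parser.
# "Args" is a hypothetical stand-in for the real training/finetuning argument
# dataclasses; only the fields these checks read are kept.
from dataclasses import dataclass


@dataclass
class Args:
    distributed: bool = False   # stands in for parallel_mode == ParallelMode.DISTRIBUTED
    deepspeed: bool = False     # stands in for training_args.deepspeed is not None
    use_galore: bool = False
    galore_layerwise: bool = False
    use_apollo: bool = False
    apollo_layerwise: bool = False


def validate(args: Args) -> None:
    # One branch gates every distributed-only incompatibility.
    if args.distributed:
        if args.use_galore and args.galore_layerwise:
            raise ValueError("Distributed training does not support layer-wise GaLore.")

        if args.use_apollo and args.apollo_layerwise:
            raise ValueError("Distributed training does not support layer-wise APOLLO.")

    # The two DeepSpeed guards are merged into one condition.
    if args.deepspeed and (args.use_galore or args.use_apollo):
        raise ValueError("GaLore and APOLLO are incompatible with DeepSpeed yet.")


if __name__ == "__main__":
    validate(Args(use_apollo=True, apollo_layerwise=True))  # single process: passes
    try:
        validate(Args(distributed=True, use_apollo=True, apollo_layerwise=True))
    except ValueError as err:
        print(err)  # layer-wise APOLLO is rejected under distributed training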
@@ -314,14 +304,13 @@ def get_train_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _
     if training_args.do_train and (not training_args.fp16) and (not training_args.bf16):
         logger.warning_rank0("We recommend enable mixed precision training.")

-    if training_args.do_train and finetuning_args.use_galore and not finetuning_args.pure_bf16:
+    if (
+        training_args.do_train
+        and (finetuning_args.use_galore or finetuning_args.use_apollo)
+        and not finetuning_args.pure_bf16
+    ):
         logger.warning_rank0(
-            "Using GaLore with mixed precision training may significantly increases GPU memory usage."
-        )
-
-    if training_args.do_train and finetuning_args.use_apollo and not finetuning_args.pure_bf16:
-        logger.warning_rank0(
-            "Using APOLLO with mixed precision training may significantly increases GPU memory usage."
+            "Using GaLore or APOLLO with mixed precision training may significantly increases GPU memory usage."
         )

     if (not training_args.do_train) and model_args.quantization_bit is not None:
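The two per-optimizer warnings above become a single predicate. A small sketch of the merged condition, with plain booleans standing in for the real argument fields:

# Sketch of the merged warning predicate above; plain booleans stand in for the
# real argument fields (hypothetical helper, not part of the project).
def should_warn(do_train: bool, use_galore: bool, use_apollo: bool, pure_bf16: bool) -> bool:
    # Equivalent to the two removed per-optimizer checks: warn once when either
    # GaLore or APOLLO runs with mixed precision (i.e. pure_bf16 is off).
    return do_train and (use_galore or use_apollo) and not pure_bf16


assert should_warn(True, True, False, False)       # GaLore + mixed precision -> warn
assert should_warn(True, False, True, False)       # APOLLO + mixed precision -> warn
assert not should_warn(True, False, True, True)    # pure_bf16 suppresses the warning
assert not should_warn(False, True, True, False)   # never warn outside training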
@@ -397,7 +386,6 @@ def get_train_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _
             str(model_args.compute_dtype),
         )
     )
-
     transformers.set_seed(training_args.seed)

     return model_args, data_args, training_args, finetuning_args, generating_args
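For context, a usage sketch of the parser whose return value is shown above: get_train_args accepts a dict or a CLI-style list, per the signature in the hunk headers, and returns the five argument objects unpacked below. The import path and every key and value here are assumptions drawn from the surrounding project layout, not verified by this commit.

# Usage sketch only; keys and values are illustrative placeholders.
from llamafactory.hparams import get_train_args

model_args, data_args, training_args, finetuning_args, generating_args = get_train_args(
    {
        "model_name_or_path": "meta-llama/Llama-2-7b-hf",  # placeholder model id
        "template": "llama2",                              # placeholder template name
        "stage": "sft",
        "do_train": True,
        "finetuning_type": "full",
        "use_apollo": True,             # the option these checks guard
        "dataset": "alpaca_en_demo",    # placeholder dataset name
        "output_dir": "outputs/apollo_sft",
    }
)
print(finetuning_args.use_apollo, str(model_args.compute_dtype))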