[optim] clean apollo (#6645)

* clean apollo code

* update readme

Former-commit-id: 38b8ec4a99189483124b54df9d6bc6b0d318855a
hoshi-hiyouga
2025-01-15 01:42:50 +08:00
committed by GitHub
parent c2120432db
commit 7638f1070e
14 changed files with 110 additions and 103 deletions

View File

@@ -49,9 +49,11 @@ def is_fastapi_available():
def is_galore_available():
    return _is_package_available("galore_torch")

+def is_apollo_available():
+    return _is_package_available("apollo_torch")

def is_gradio_available():
    return _is_package_available("gradio")
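For context, each of these availability helpers delegates to a shared probe. A minimal sketch of how such a probe can be written (the actual helper in packages.py may differ):

import importlib.util

def _is_package_available(name: str) -> bool:
    # a package counts as available if an import spec can be resolved for it
    return importlib.util.find_spec(name) is not None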

View File

@@ -286,7 +286,7 @@ class ApolloArguments:
        default="random",
        metadata={"help": "Type of APOLLO low-rank projection algorithm (svd or random)."},
    )
-    apollo_proj_type: Literal["std", "right", "left",] = field(
+    apollo_proj_type: Literal["std", "right", "left"] = field(
        default="std",
        metadata={"help": "Type of APOLLO projection."},
    )
@@ -475,17 +475,11 @@ class FinetuningArguments(
        if self.use_llama_pro and self.finetuning_type == "full":
            raise ValueError("`use_llama_pro` is only valid for Freeze or LoRA training.")

-        if self.finetuning_type == "lora" and (self.use_galore or self.use_badam or self.use_apollo):
-            raise ValueError("Cannot use LoRA with GaLore or BAdam together.")
+        if self.finetuning_type == "lora" and (self.use_galore or self.use_apollo or self.use_badam):
+            raise ValueError("Cannot use LoRA with GaLore, APOLLO or BAdam together.")

-        if self.use_galore and self.use_badam:
-            raise ValueError("Cannot use GaLore with BAdam together.")
-
-        if self.use_galore and self.use_apollo:
-            raise ValueError("Cannot use GaLore with APOLLO together.")
-
-        if self.use_badam and self.use_apollo:
-            raise ValueError("Cannot use BAdam with APOLLO together.")
+        if int(self.use_galore) + int(self.use_apollo) + int(self.use_badam) > 1:
+            raise ValueError("Cannot use GaLore, APOLLO or BAdam together.")

        if self.pissa_init and (self.stage in ["ppo", "kto"] or self.use_ref_model):
            raise ValueError("Cannot use PiSSA for current training stage.")

View File

@@ -258,31 +258,21 @@ def get_train_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _
    if is_deepspeed_zero3_enabled():
        raise ValueError("`pure_bf16` is incompatible with DeepSpeed ZeRO-3.")

-    if (
-        finetuning_args.use_galore
-        and finetuning_args.galore_layerwise
-        and training_args.parallel_mode == ParallelMode.DISTRIBUTED
-    ):
-        raise ValueError("Distributed training does not support layer-wise GaLore.")
+    if training_args.parallel_mode == ParallelMode.DISTRIBUTED:
+        if finetuning_args.use_galore and finetuning_args.galore_layerwise:
+            raise ValueError("Distributed training does not support layer-wise GaLore.")

-    if (
-        finetuning_args.use_apollo
-        and finetuning_args.apollo_layerwise
-        and training_args.parallel_mode == ParallelMode.DISTRIBUTED
-    ):
-        raise ValueError("Distributed training does not support layer-wise APOLLO.")
+        if finetuning_args.use_apollo and finetuning_args.apollo_layerwise:
+            raise ValueError("Distributed training does not support layer-wise APOLLO.")

-    if finetuning_args.use_badam and training_args.parallel_mode == ParallelMode.DISTRIBUTED:
-        if finetuning_args.badam_mode == "ratio":
-            raise ValueError("Radio-based BAdam does not yet support distributed training, use layer-wise BAdam.")
-        elif not is_deepspeed_zero3_enabled():
-            raise ValueError("Layer-wise BAdam only supports DeepSpeed ZeRO-3 training.")
+        if finetuning_args.use_badam:
+            if finetuning_args.badam_mode == "ratio":
+                raise ValueError("Ratio-based BAdam does not yet support distributed training, use layer-wise BAdam.")
+            elif not is_deepspeed_zero3_enabled():
+                raise ValueError("Layer-wise BAdam only supports DeepSpeed ZeRO-3 training.")

-    if finetuning_args.use_galore and training_args.deepspeed is not None:
-        raise ValueError("GaLore is incompatible with DeepSpeed yet.")
-
-    if finetuning_args.use_apollo and training_args.deepspeed is not None:
-        raise ValueError("APOLLO is incompatible with DeepSpeed yet.")
+    if training_args.deepspeed is not None and (finetuning_args.use_galore or finetuning_args.use_apollo):
+        raise ValueError("GaLore and APOLLO are not compatible with DeepSpeed yet.")

    if model_args.infer_backend == "vllm":
        raise ValueError("vLLM backend is only available for API, CLI and Web.")
@@ -314,14 +304,13 @@ def get_train_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _
    if training_args.do_train and (not training_args.fp16) and (not training_args.bf16):
        logger.warning_rank0("We recommend enabling mixed precision training.")

-    if training_args.do_train and finetuning_args.use_galore and not finetuning_args.pure_bf16:
-        logger.warning_rank0(
-            "Using GaLore with mixed precision training may significantly increases GPU memory usage."
-        )
-
-    if training_args.do_train and finetuning_args.use_apollo and not finetuning_args.pure_bf16:
-        logger.warning_rank0(
-            "Using APOLLO with mixed precision training may significantly increases GPU memory usage."
-        )
+    if (
+        training_args.do_train
+        and (finetuning_args.use_galore or finetuning_args.use_apollo)
+        and not finetuning_args.pure_bf16
+    ):
+        logger.warning_rank0(
+            "Using GaLore or APOLLO with mixed precision training may significantly increase GPU memory usage."
+        )

    if (not training_args.do_train) and model_args.quantization_bit is not None:
if (not training_args.do_train) and model_args.quantization_bit is not None:
@@ -397,7 +386,6 @@ def get_train_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _
            str(model_args.compute_dtype),
        )
    )

-    transformers.set_seed(training_args.seed)

    return model_args, data_args, training_args, finetuning_args, generating_args
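Dropping the explicit seeding here should be safe because the HF Trainer seeds from TrainingArguments.seed on its own during initialization, which presumably made this call redundant. For reference, the removed helper seeds all the common RNGs in one call:

from transformers import set_seed

set_seed(42)  # seeds Python's random, NumPy and PyTorch (including CUDA)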

View File

@@ -27,7 +27,7 @@ logger = logging.get_logger(__name__)
def find_all_linear_modules(model: "PreTrainedModel", freeze_vision_tower: bool) -> List[str]:
    r"""
-    Finds all available modules to apply lora or galore or apollo.
+    Finds all available modules to apply LoRA, GaLore or APOLLO.
    """
    model_type = getattr(model.config, "model_type", None)
    forbidden_modules = {"lm_head"}
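A rough sketch of what a scan like find_all_linear_modules boils down to (simplified; the real function also handles model-specific cases such as frozen vision towers):

import torch

def linear_module_names(model: torch.nn.Module, forbidden=("lm_head",)) -> list:
    # collect the leaf names of all Linear layers, skipping forbidden heads
    names = set()
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Linear):
            leaf = name.split(".")[-1]
            if leaf not in forbidden:
                names.add(leaf)
    return sorted(names)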

View File

@@ -32,7 +32,7 @@ from typing_extensions import override
from ..extras import logging
from ..extras.constants import IGNORE_INDEX
-from ..extras.packages import is_galore_available, is_ray_available, is_apollo_available
+from ..extras.packages import is_apollo_available, is_galore_available, is_ray_available
from ..hparams import FinetuningArguments, ModelArguments
from ..model import find_all_linear_modules, load_model, load_tokenizer, load_valuehead_params
@@ -40,9 +40,11 @@ from ..model import find_all_linear_modules, load_model, load_tokenizer, load_va
if is_galore_available():
    from galore_torch import GaLoreAdafactor, GaLoreAdamW, GaLoreAdamW8bit  # type: ignore

+if is_apollo_available():
+    from apollo_torch import APOLLOAdamW  # type: ignore

if is_ray_available():
    from ray.train import RunConfig, ScalingConfig
    from ray.train.torch import TorchTrainer
@@ -240,9 +242,10 @@ def _create_galore_optimizer(
    elif training_args.optim == "adafactor":
        optim_class = GaLoreAdafactor
    else:
-        raise NotImplementedError(f"Unknow optim: {training_args.optim}")
+        raise NotImplementedError(f"Unknown optim: {training_args.optim}.")

    if finetuning_args.galore_layerwise:
+        logger.warning_rank0("The displayed gradient norm will be all zeros in layerwise GaLore.")
        if training_args.gradient_accumulation_steps != 1:
            raise ValueError("Per-layer GaLore does not support gradient accumulation.")
@@ -274,9 +277,13 @@ def _create_galore_optimizer(
        ]
        optimizer = optim_class(param_groups, **optim_kwargs)

-    logger.info_rank0("Using GaLore optimizer, may cause hanging at the start of training, wait patiently.")
+    logger.info_rank0(
+        f"Using GaLore optimizer with args: {galore_kwargs}. "
+        "It may cause hanging at the start of training, wait patiently."
+    )
    return optimizer


def _create_apollo_optimizer(
    model: "PreTrainedModel",
    training_args: "TrainingArguments",
@@ -304,11 +311,9 @@ def _create_apollo_optimizer(
        "scale_front": finetuning_args.apollo_scale_front,
    }

-    print(apollo_kwargs)
    id_apollo_params = {id(param) for param in apollo_params}
-    decay_params, nodecay_params = [], []  # they are non-galore parameters
-    trainable_params: List["torch.nn.Parameter"] = []  # galore_params + decay_params + nodecay_params
+    decay_params, nodecay_params = [], []  # they are non-apollo parameters
+    trainable_params: List["torch.nn.Parameter"] = []  # apollo_params + decay_params + nodecay_params
    decay_param_names = _get_decay_parameter_names(model)
    for name, param in model.named_parameters():
        if param.requires_grad:
@@ -324,9 +329,10 @@ def _create_apollo_optimizer(
    if training_args.optim == "adamw_torch":
        optim_class = APOLLOAdamW
    else:
-        raise NotImplementedError(f"Unknow optim: {training_args.optim}")
+        raise NotImplementedError(f"Unknown optim: {training_args.optim}.")

    if finetuning_args.apollo_layerwise:
+        logger.warning_rank0("The displayed gradient norm will be all zeros in layerwise APOLLO.")
        if training_args.gradient_accumulation_steps != 1:
            raise ValueError("Per-layer APOLLO does not support gradient accumulation.")
@@ -337,7 +343,7 @@ def _create_apollo_optimizer(
        for param in decay_params:
            param_groups = [dict(params=[param], weight_decay=training_args.weight_decay)]
            optimizer_dict[param] = optim_class(param_groups, **optim_kwargs)

-        for param in apollo_params:  # galore params have weight decay
+        for param in apollo_params:  # apollo params have weight decay
            param_groups = [dict(params=[param], weight_decay=training_args.weight_decay, **apollo_kwargs)]
            optimizer_dict[param] = optim_class(param_groups, **optim_kwargs)
@@ -358,7 +364,7 @@ def _create_apollo_optimizer(
        ]
        optimizer = optim_class(param_groups, **optim_kwargs)

-    logger.info_rank0("Using APOLLO optimizer.")
+    logger.info_rank0(f"Using APOLLO optimizer with args: {apollo_kwargs}.")
    return optimizer
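The layerwise branches above keep one small optimizer per parameter in optimizer_dict and step it from a gradient hook, which is also why gradient accumulation is rejected: every parameter updates as soon as its own gradient is ready. A self-contained sketch of that pattern, not the project's exact code (requires torch >= 2.1):

import torch

model = torch.nn.Linear(8, 8)
optimizer_dict = {p: torch.optim.SGD([p], lr=0.01) for p in model.parameters()}

def step_fn(param: torch.Tensor) -> None:
    # step and reset this parameter's private optimizer right after backward
    optimizer_dict[param].step()
    optimizer_dict[param].zero_grad()

for p in model.parameters():
    p.register_post_accumulate_grad_hook(step_fn)

loss = model(torch.randn(2, 8)).sum()
loss.backward()  # each hook fires once its parameter's gradient is accumulated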

View File

@@ -234,8 +234,8 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
    with gr.Row():
        use_galore = gr.Checkbox()
        galore_rank = gr.Slider(minimum=1, maximum=1024, value=16, step=1)
-        galore_update_interval = gr.Slider(minimum=1, maximum=1024, value=200, step=1)
-        galore_scale = gr.Slider(minimum=0, maximum=1, value=0.25, step=0.01)
+        galore_update_interval = gr.Slider(minimum=1, maximum=2048, value=200, step=1)
+        galore_scale = gr.Slider(minimum=0, maximum=100, value=2.0, step=0.1)
        galore_target = gr.Textbox(value="all")

    input_elems.update({use_galore, galore_rank, galore_update_interval, galore_scale, galore_target})
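The wider bounds let the optimizer-recommended defaults fit in range (a GaLore scale of 2.0 and, below, an APOLLO scale of 32.0 would overflow the old 0-1 slider). A standalone sketch of such a checkbox-plus-sliders row, with made-up labels:

import gradio as gr

with gr.Blocks() as demo:
    with gr.Row():
        use_galore = gr.Checkbox(label="Use GaLore")
        galore_rank = gr.Slider(minimum=1, maximum=1024, value=16, step=1, label="Rank")
        galore_scale = gr.Slider(minimum=0, maximum=100, value=2.0, step=0.1, label="Scale")

demo.launch()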
@@ -254,9 +254,10 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
    with gr.Row():
        use_apollo = gr.Checkbox()
        apollo_rank = gr.Slider(minimum=1, maximum=1024, value=16, step=1)
-        apollo_update_interval = gr.Slider(minimum=1, maximum=1024, value=200, step=1)
-        apollo_scale = gr.Slider(minimum=0, maximum=1, value=0.25, step=0.01)
+        apollo_update_interval = gr.Slider(minimum=1, maximum=2048, value=200, step=1)
+        apollo_scale = gr.Slider(minimum=0, maximum=100, value=32.0, step=0.1)
        apollo_target = gr.Textbox(value="all")

    input_elems.update({use_apollo, apollo_rank, apollo_update_interval, apollo_scale, apollo_target})

    elem_dict.update(
        dict(
View File

@@ -1162,19 +1162,19 @@ LOCALES = {
"use_galore": {
    "en": {
        "label": "Use GaLore",
-        "info": "Enable gradient low-Rank projection.",
+        "info": "Use GaLore optimizer.",
    },
    "ru": {
        "label": "Использовать GaLore",
-        "info": "Включить проекцию градиента на низкоранговое пространство.",
+        "info": "Используйте оптимизатор GaLore.",
    },
    "zh": {
        "label": "使用 GaLore",
-        "info": "使用梯度低秩投影",
+        "info": "使用 GaLore 优化器",
    },
    "ko": {
        "label": "GaLore 사용",
-        "info": "그레디언트 로우 랭크 프로젝션을 활성화합니다.",
+        "info": "GaLore 최적화를 사용하세요.",
    },
},
"galore_rank": {
@@ -1266,19 +1266,19 @@ LOCALES = {
"use_apollo": {
    "en": {
        "label": "Use APOLLO",
-        "info": "Enable gradient low-Rank projection.",
+        "info": "Use APOLLO optimizer.",
    },
    "ru": {
        "label": "Использовать APOLLO",
-        "info": "Включить проекцию градиента на низкоранговое пространство.",
+        "info": "Используйте оптимизатор APOLLO.",
    },
    "zh": {
        "label": "使用 APOLLO",
-        "info": "使用梯度低秩投影",
+        "info": "使用 APOLLO 优化器",
    },
    "ko": {
        "label": "APOLLO 사용",
-        "info": "그레디언트 로우 랭크 프로젝션을 활성화합니다.",
+        "info": "APOLLO 최적화를 사용하세요.",
    },
},
"apollo_rank": {

View File

@@ -224,7 +224,7 @@ class Runner:
        args["galore_update_interval"] = get("train.galore_update_interval")
        args["galore_scale"] = get("train.galore_scale")
        args["galore_target"] = get("train.galore_target")

        # apollo config
        if args["use_apollo"]:
            args["apollo_rank"] = get("train.apollo_rank")