[optim] clean apollo (#6645)
* clean apollo code * update readme Former-commit-id: 38b8ec4a99189483124b54df9d6bc6b0d318855a
This commit is contained in:
@@ -49,9 +49,11 @@ def is_fastapi_available():
|
||||
def is_galore_available():
|
||||
return _is_package_available("galore_torch")
|
||||
|
||||
|
||||
def is_apollo_available():
|
||||
return _is_package_available("apollo_torch")
|
||||
|
||||
|
||||
def is_gradio_available():
|
||||
return _is_package_available("gradio")
|
||||
|
||||
|
||||
@@ -286,7 +286,7 @@ class ApolloArguments:
|
||||
default="random",
|
||||
metadata={"help": "Type of APOLLO low-rank projection algorithm (svd or random)."},
|
||||
)
|
||||
apollo_proj_type: Literal["std", "right", "left",] = field(
|
||||
apollo_proj_type: Literal["std", "right", "left"] = field(
|
||||
default="std",
|
||||
metadata={"help": "Type of APOLLO projection."},
|
||||
)
|
||||
@@ -475,17 +475,11 @@ class FinetuningArguments(
|
||||
if self.use_llama_pro and self.finetuning_type == "full":
|
||||
raise ValueError("`use_llama_pro` is only valid for Freeze or LoRA training.")
|
||||
|
||||
if self.finetuning_type == "lora" and (self.use_galore or self.use_badam or self.use_apollo):
|
||||
raise ValueError("Cannot use LoRA with GaLore or BAdam together.")
|
||||
if self.finetuning_type == "lora" and (self.use_galore or self.use_apollo or self.use_badam):
|
||||
raise ValueError("Cannot use LoRA with GaLore, APOLLO or BAdam together.")
|
||||
|
||||
if self.use_galore and self.use_badam:
|
||||
raise ValueError("Cannot use GaLore with BAdam together.")
|
||||
|
||||
if self.use_galore and self.use_apollo:
|
||||
raise ValueError("Cannot use GaLore with APOLLO together.")
|
||||
|
||||
if self.use_badam and self.use_apollo:
|
||||
raise ValueError("Cannot use BAdam with APOLLO together.")
|
||||
if int(self.use_galore) + int(self.use_apollo) + (self.use_badam) > 1:
|
||||
raise ValueError("Cannot use GaLore, APOLLO or BAdam together.")
|
||||
|
||||
if self.pissa_init and (self.stage in ["ppo", "kto"] or self.use_ref_model):
|
||||
raise ValueError("Cannot use PiSSA for current training stage.")
|
||||
|
||||
@@ -258,31 +258,21 @@ def get_train_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _
|
||||
if is_deepspeed_zero3_enabled():
|
||||
raise ValueError("`pure_bf16` is incompatible with DeepSpeed ZeRO-3.")
|
||||
|
||||
if (
|
||||
finetuning_args.use_galore
|
||||
and finetuning_args.galore_layerwise
|
||||
and training_args.parallel_mode == ParallelMode.DISTRIBUTED
|
||||
):
|
||||
raise ValueError("Distributed training does not support layer-wise GaLore.")
|
||||
if training_args.parallel_mode == ParallelMode.DISTRIBUTED:
|
||||
if finetuning_args.use_galore and finetuning_args.galore_layerwise:
|
||||
raise ValueError("Distributed training does not support layer-wise GaLore.")
|
||||
|
||||
if (
|
||||
finetuning_args.use_apollo
|
||||
and finetuning_args.apollo_layerwise
|
||||
and training_args.parallel_mode == ParallelMode.DISTRIBUTED
|
||||
):
|
||||
raise ValueError("Distributed training does not support layer-wise APOLLO.")
|
||||
if finetuning_args.use_apollo and finetuning_args.apollo_layerwise:
|
||||
raise ValueError("Distributed training does not support layer-wise APOLLO.")
|
||||
|
||||
if finetuning_args.use_badam and training_args.parallel_mode == ParallelMode.DISTRIBUTED:
|
||||
if finetuning_args.badam_mode == "ratio":
|
||||
raise ValueError("Radio-based BAdam does not yet support distributed training, use layer-wise BAdam.")
|
||||
elif not is_deepspeed_zero3_enabled():
|
||||
raise ValueError("Layer-wise BAdam only supports DeepSpeed ZeRO-3 training.")
|
||||
if finetuning_args.use_badam:
|
||||
if finetuning_args.badam_mode == "ratio":
|
||||
raise ValueError("Radio-based BAdam does not yet support distributed training, use layer-wise BAdam.")
|
||||
elif not is_deepspeed_zero3_enabled():
|
||||
raise ValueError("Layer-wise BAdam only supports DeepSpeed ZeRO-3 training.")
|
||||
|
||||
if finetuning_args.use_galore and training_args.deepspeed is not None:
|
||||
raise ValueError("GaLore is incompatible with DeepSpeed yet.")
|
||||
|
||||
if finetuning_args.use_apollo and training_args.deepspeed is not None:
|
||||
raise ValueError("APOLLO is incompatible with DeepSpeed yet.")
|
||||
if training_args.deepspeed is not None and (finetuning_args.use_galore or finetuning_args.use_apollo):
|
||||
raise ValueError("GaLore and APOLLO are incompatible with DeepSpeed yet.")
|
||||
|
||||
if model_args.infer_backend == "vllm":
|
||||
raise ValueError("vLLM backend is only available for API, CLI and Web.")
|
||||
@@ -314,14 +304,13 @@ def get_train_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _
|
||||
if training_args.do_train and (not training_args.fp16) and (not training_args.bf16):
|
||||
logger.warning_rank0("We recommend enable mixed precision training.")
|
||||
|
||||
if training_args.do_train and finetuning_args.use_galore and not finetuning_args.pure_bf16:
|
||||
if (
|
||||
training_args.do_train
|
||||
and (finetuning_args.use_galore or finetuning_args.use_apollo)
|
||||
and not finetuning_args.pure_bf16
|
||||
):
|
||||
logger.warning_rank0(
|
||||
"Using GaLore with mixed precision training may significantly increases GPU memory usage."
|
||||
)
|
||||
|
||||
if training_args.do_train and finetuning_args.use_apollo and not finetuning_args.pure_bf16:
|
||||
logger.warning_rank0(
|
||||
"Using APOLLO with mixed precision training may significantly increases GPU memory usage."
|
||||
"Using GaLore or APOLLO with mixed precision training may significantly increases GPU memory usage."
|
||||
)
|
||||
|
||||
if (not training_args.do_train) and model_args.quantization_bit is not None:
|
||||
@@ -397,7 +386,6 @@ def get_train_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _
|
||||
str(model_args.compute_dtype),
|
||||
)
|
||||
)
|
||||
|
||||
transformers.set_seed(training_args.seed)
|
||||
|
||||
return model_args, data_args, training_args, finetuning_args, generating_args
|
||||
|
||||
@@ -27,7 +27,7 @@ logger = logging.get_logger(__name__)
|
||||
|
||||
def find_all_linear_modules(model: "PreTrainedModel", freeze_vision_tower: bool) -> List[str]:
|
||||
r"""
|
||||
Finds all available modules to apply lora or galore or apollo.
|
||||
Finds all available modules to apply LoRA, GaLore or APOLLO.
|
||||
"""
|
||||
model_type = getattr(model.config, "model_type", None)
|
||||
forbidden_modules = {"lm_head"}
|
||||
|
||||
@@ -32,7 +32,7 @@ from typing_extensions import override
|
||||
|
||||
from ..extras import logging
|
||||
from ..extras.constants import IGNORE_INDEX
|
||||
from ..extras.packages import is_galore_available, is_ray_available, is_apollo_available
|
||||
from ..extras.packages import is_apollo_available, is_galore_available, is_ray_available
|
||||
from ..hparams import FinetuningArguments, ModelArguments
|
||||
from ..model import find_all_linear_modules, load_model, load_tokenizer, load_valuehead_params
|
||||
|
||||
@@ -40,9 +40,11 @@ from ..model import find_all_linear_modules, load_model, load_tokenizer, load_va
|
||||
if is_galore_available():
|
||||
from galore_torch import GaLoreAdafactor, GaLoreAdamW, GaLoreAdamW8bit # type: ignore
|
||||
|
||||
|
||||
if is_apollo_available():
|
||||
from apollo_torch import APOLLOAdamW # type: ignore
|
||||
|
||||
|
||||
if is_ray_available():
|
||||
from ray.train import RunConfig, ScalingConfig
|
||||
from ray.train.torch import TorchTrainer
|
||||
@@ -240,9 +242,10 @@ def _create_galore_optimizer(
|
||||
elif training_args.optim == "adafactor":
|
||||
optim_class = GaLoreAdafactor
|
||||
else:
|
||||
raise NotImplementedError(f"Unknow optim: {training_args.optim}")
|
||||
raise NotImplementedError(f"Unknown optim: {training_args.optim}.")
|
||||
|
||||
if finetuning_args.galore_layerwise:
|
||||
logger.warning_rank0("The displayed gradient norm will be all zeros in layerwise GaLore.")
|
||||
if training_args.gradient_accumulation_steps != 1:
|
||||
raise ValueError("Per-layer GaLore does not support gradient accumulation.")
|
||||
|
||||
@@ -274,9 +277,13 @@ def _create_galore_optimizer(
|
||||
]
|
||||
optimizer = optim_class(param_groups, **optim_kwargs)
|
||||
|
||||
logger.info_rank0("Using GaLore optimizer, may cause hanging at the start of training, wait patiently.")
|
||||
logger.info_rank0(
|
||||
f"Using GaLore optimizer with args: {galore_kwargs}. "
|
||||
"It may cause hanging at the start of training, wait patiently."
|
||||
)
|
||||
return optimizer
|
||||
|
||||
|
||||
def _create_apollo_optimizer(
|
||||
model: "PreTrainedModel",
|
||||
training_args: "TrainingArguments",
|
||||
@@ -304,11 +311,9 @@ def _create_apollo_optimizer(
|
||||
"scale_front": finetuning_args.apollo_scale_front,
|
||||
}
|
||||
|
||||
print(apollo_kwargs)
|
||||
|
||||
id_apollo_params = {id(param) for param in apollo_params}
|
||||
decay_params, nodecay_params = [], [] # they are non-galore parameters
|
||||
trainable_params: List["torch.nn.Parameter"] = [] # galore_params + decay_params + nodecay_params
|
||||
decay_params, nodecay_params = [], [] # they are non-apollo parameters
|
||||
trainable_params: List["torch.nn.Parameter"] = [] # apollo_params + decay_params + nodecay_params
|
||||
decay_param_names = _get_decay_parameter_names(model)
|
||||
for name, param in model.named_parameters():
|
||||
if param.requires_grad:
|
||||
@@ -324,9 +329,10 @@ def _create_apollo_optimizer(
|
||||
if training_args.optim == "adamw_torch":
|
||||
optim_class = APOLLOAdamW
|
||||
else:
|
||||
raise NotImplementedError(f"Unknow optim: {training_args.optim}")
|
||||
raise NotImplementedError(f"Unknown optim: {training_args.optim}.")
|
||||
|
||||
if finetuning_args.apollo_layerwise:
|
||||
logger.warning_rank0("The displayed gradient norm will be all zeros in layerwise APOLLO.")
|
||||
if training_args.gradient_accumulation_steps != 1:
|
||||
raise ValueError("Per-layer APOLLO does not support gradient accumulation.")
|
||||
|
||||
@@ -337,7 +343,7 @@ def _create_apollo_optimizer(
|
||||
for param in decay_params:
|
||||
param_groups = [dict(params=[param], weight_decay=training_args.weight_decay)]
|
||||
optimizer_dict[param] = optim_class(param_groups, **optim_kwargs)
|
||||
for param in apollo_params: # galore params have weight decay
|
||||
for param in apollo_params: # apollo params have weight decay
|
||||
param_groups = [dict(params=[param], weight_decay=training_args.weight_decay, **apollo_kwargs)]
|
||||
optimizer_dict[param] = optim_class(param_groups, **optim_kwargs)
|
||||
|
||||
@@ -358,7 +364,7 @@ def _create_apollo_optimizer(
|
||||
]
|
||||
optimizer = optim_class(param_groups, **optim_kwargs)
|
||||
|
||||
logger.info_rank0("Using APOLLO optimizer.")
|
||||
logger.info_rank0(f"Using APOLLO optimizer with args: {apollo_kwargs}.")
|
||||
return optimizer
|
||||
|
||||
|
||||
|
||||
@@ -234,8 +234,8 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||
with gr.Row():
|
||||
use_galore = gr.Checkbox()
|
||||
galore_rank = gr.Slider(minimum=1, maximum=1024, value=16, step=1)
|
||||
galore_update_interval = gr.Slider(minimum=1, maximum=1024, value=200, step=1)
|
||||
galore_scale = gr.Slider(minimum=0, maximum=1, value=0.25, step=0.01)
|
||||
galore_update_interval = gr.Slider(minimum=1, maximum=2048, value=200, step=1)
|
||||
galore_scale = gr.Slider(minimum=0, maximum=100, value=2.0, step=0.1)
|
||||
galore_target = gr.Textbox(value="all")
|
||||
|
||||
input_elems.update({use_galore, galore_rank, galore_update_interval, galore_scale, galore_target})
|
||||
@@ -254,9 +254,10 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||
with gr.Row():
|
||||
use_apollo = gr.Checkbox()
|
||||
apollo_rank = gr.Slider(minimum=1, maximum=1024, value=16, step=1)
|
||||
apollo_update_interval = gr.Slider(minimum=1, maximum=1024, value=200, step=1)
|
||||
apollo_scale = gr.Slider(minimum=0, maximum=1, value=0.25, step=0.01)
|
||||
apollo_update_interval = gr.Slider(minimum=1, maximum=2048, value=200, step=1)
|
||||
apollo_scale = gr.Slider(minimum=0, maximum=100, value=32.0, step=0.1)
|
||||
apollo_target = gr.Textbox(value="all")
|
||||
|
||||
input_elems.update({use_apollo, apollo_rank, apollo_update_interval, apollo_scale, apollo_target})
|
||||
elem_dict.update(
|
||||
dict(
|
||||
|
||||
@@ -1162,19 +1162,19 @@ LOCALES = {
|
||||
"use_galore": {
|
||||
"en": {
|
||||
"label": "Use GaLore",
|
||||
"info": "Enable gradient low-Rank projection.",
|
||||
"info": "Use GaLore optimizer.",
|
||||
},
|
||||
"ru": {
|
||||
"label": "Использовать GaLore",
|
||||
"info": "Включить проекцию градиента на низкоранговое пространство.",
|
||||
"info": "Используйте оптимизатор GaLore.",
|
||||
},
|
||||
"zh": {
|
||||
"label": "使用 GaLore",
|
||||
"info": "使用梯度低秩投影。",
|
||||
"info": "使用 GaLore 优化器。",
|
||||
},
|
||||
"ko": {
|
||||
"label": "GaLore 사용",
|
||||
"info": "그레디언트 로우 랭크 프로젝션을 활성화합니다.",
|
||||
"info": "GaLore 최적화를 사용하세요.",
|
||||
},
|
||||
},
|
||||
"galore_rank": {
|
||||
@@ -1266,19 +1266,19 @@ LOCALES = {
|
||||
"use_apollo": {
|
||||
"en": {
|
||||
"label": "Use APOLLO",
|
||||
"info": "Enable gradient low-Rank projection.",
|
||||
"info": "Use APOLLO optimizer.",
|
||||
},
|
||||
"ru": {
|
||||
"label": "Использовать APOLLO",
|
||||
"info": "Включить проекцию градиента на низкоранговое пространство.",
|
||||
"info": "Используйте оптимизатор APOLLO.",
|
||||
},
|
||||
"zh": {
|
||||
"label": "使用 APOLLO",
|
||||
"info": "使用梯度低秩投影。",
|
||||
"info": "使用 APOLLO 优化器。",
|
||||
},
|
||||
"ko": {
|
||||
"label": "APOLLO 사용",
|
||||
"info": "그레디언트 로우 랭크 프로젝션을 활성화합니다.",
|
||||
"info": "APOLLO 최적화를 사용하세요.",
|
||||
},
|
||||
},
|
||||
"apollo_rank": {
|
||||
|
||||
@@ -224,7 +224,7 @@ class Runner:
|
||||
args["galore_update_interval"] = get("train.galore_update_interval")
|
||||
args["galore_scale"] = get("train.galore_scale")
|
||||
args["galore_target"] = get("train.galore_target")
|
||||
|
||||
|
||||
# apollo config
|
||||
if args["use_apollo"]:
|
||||
args["apollo_rank"] = get("train.apollo_rank")
|
||||
|
||||
Reference in New Issue
Block a user