[optim] clean apollo (#6645)
* clean apollo code
* update readme

Former-commit-id: 38b8ec4a99189483124b54df9d6bc6b0d318855a
@@ -32,7 +32,7 @@ from typing_extensions import override
 
 from ..extras import logging
 from ..extras.constants import IGNORE_INDEX
-from ..extras.packages import is_galore_available, is_ray_available, is_apollo_available
+from ..extras.packages import is_apollo_available, is_galore_available, is_ray_available
 from ..hparams import FinetuningArguments, ModelArguments
 from ..model import find_all_linear_modules, load_model, load_tokenizer, load_valuehead_params
 
@@ -40,9 +40,11 @@ from ..model import find_all_linear_modules, load_model, load_tokenizer, load_va
 if is_galore_available():
     from galore_torch import GaLoreAdafactor, GaLoreAdamW, GaLoreAdamW8bit  # type: ignore
 
+
 if is_apollo_available():
     from apollo_torch import APOLLOAdamW  # type: ignore
 
+
 if is_ray_available():
     from ray.train import RunConfig, ScalingConfig
     from ray.train.torch import TorchTrainer
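Note: is_apollo_available lives in ..extras.packages next to the existing availability helpers. A minimal sketch of such a check, assuming the usual find_spec pattern rather than the repository's exact helper:

import importlib.util


def is_apollo_available() -> bool:
    # Hypothetical sketch; the real helper in ..extras.packages may also check the installed version.
    return importlib.util.find_spec("apollo_torch") is not None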
@@ -240,9 +242,10 @@ def _create_galore_optimizer(
     elif training_args.optim == "adafactor":
         optim_class = GaLoreAdafactor
     else:
-        raise NotImplementedError(f"Unknow optim: {training_args.optim}")
+        raise NotImplementedError(f"Unknown optim: {training_args.optim}.")
 
     if finetuning_args.galore_layerwise:
         logger.warning_rank0("The displayed gradient norm will be all zeros in layerwise GaLore.")
         if training_args.gradient_accumulation_steps != 1:
             raise ValueError("Per-layer GaLore does not support gradient accumulation.")
@@ -274,9 +277,13 @@ def _create_galore_optimizer(
         ]
         optimizer = optim_class(param_groups, **optim_kwargs)
 
-    logger.info_rank0("Using GaLore optimizer, may cause hanging at the start of training, wait patiently.")
+    logger.info_rank0(
+        f"Using GaLore optimizer with args: {galore_kwargs}. "
+        "It may cause hanging at the start of training, wait patiently."
+    )
     return optimizer
 
 
 def _create_apollo_optimizer(
     model: "PreTrainedModel",
     training_args: "TrainingArguments",
@@ -304,11 +311,9 @@ def _create_apollo_optimizer(
         "scale_front": finetuning_args.apollo_scale_front,
     }
 
-    print(apollo_kwargs)
-
     id_apollo_params = {id(param) for param in apollo_params}
-    decay_params, nodecay_params = [], []  # they are non-galore parameters
-    trainable_params: List["torch.nn.Parameter"] = []  # galore_params + decay_params + nodecay_params
+    decay_params, nodecay_params = [], []  # they are non-apollo parameters
+    trainable_params: List["torch.nn.Parameter"] = []  # apollo_params + decay_params + nodecay_params
     decay_param_names = _get_decay_parameter_names(model)
     for name, param in model.named_parameters():
         if param.requires_grad:
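The comments fixed above describe a three-way split: parameters owned by the APOLLO-projected modules, remaining parameters that take weight decay, and remaining parameters that do not. A hedged, standalone sketch of that split (the loop body below `if param.requires_grad:` is not shown in the diff, so its exact shape is an assumption):

from typing import List

import torch


def split_trainable_parameters(
    model: "torch.nn.Module",
    apollo_params: List["torch.nn.Parameter"],
    decay_param_names: List[str],
):
    # Hypothetical standalone version of the split implied by the comments above.
    id_apollo_params = {id(param) for param in apollo_params}
    decay_params, nodecay_params = [], []  # non-apollo parameters
    trainable_params: List["torch.nn.Parameter"] = []  # apollo_params + decay_params + nodecay_params
    for name, param in model.named_parameters():
        if param.requires_grad:
            trainable_params.append(param)
            if id(param) not in id_apollo_params:
                if name in decay_param_names:
                    decay_params.append(param)
                else:
                    nodecay_params.append(param)

    return trainable_params, decay_params, nodecay_params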
@@ -324,9 +329,10 @@ def _create_apollo_optimizer(
     if training_args.optim == "adamw_torch":
         optim_class = APOLLOAdamW
     else:
-        raise NotImplementedError(f"Unknow optim: {training_args.optim}")
+        raise NotImplementedError(f"Unknown optim: {training_args.optim}.")
 
     if finetuning_args.apollo_layerwise:
         logger.warning_rank0("The displayed gradient norm will be all zeros in layerwise APOLLO.")
         if training_args.gradient_accumulation_steps != 1:
             raise ValueError("Per-layer APOLLO does not support gradient accumulation.")
@@ -337,7 +343,7 @@ def _create_apollo_optimizer(
         for param in decay_params:
             param_groups = [dict(params=[param], weight_decay=training_args.weight_decay)]
             optimizer_dict[param] = optim_class(param_groups, **optim_kwargs)
-        for param in apollo_params:  # galore params have weight decay
+        for param in apollo_params:  # apollo params have weight decay
             param_groups = [dict(params=[param], weight_decay=training_args.weight_decay, **apollo_kwargs)]
             optimizer_dict[param] = optim_class(param_groups, **optim_kwargs)
 
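In this layerwise branch every trainable parameter gets its own single-parameter optimizer (the optimizer_dict above). The usual way to drive such optimizers, and the reason gradient accumulation is rejected and the displayed gradient norm is all zeros, is a post-accumulate gradient hook that steps and clears each gradient as soon as it is produced. A sketch of that mechanism, assuming PyTorch >= 2.1 (not necessarily the exact wiring used here):

import torch


def attach_layerwise_hooks(optimizer_dict: dict) -> None:
    # Hypothetical sketch: step each one-parameter optimizer as soon as its gradient
    # is ready, then drop the gradient, so nothing is ever accumulated across steps.
    def optimizer_hook(param: "torch.nn.Parameter") -> None:
        optimizer_dict[param].step()
        optimizer_dict[param].zero_grad()  # grad is cleared here, hence the zero grad norm in logs

    for param in optimizer_dict.keys():
        param.register_post_accumulate_grad_hook(optimizer_hook)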
@@ -358,7 +364,7 @@ def _create_apollo_optimizer(
         ]
         optimizer = optim_class(param_groups, **optim_kwargs)
 
-    logger.info_rank0("Using APOLLO optimizer.")
+    logger.info_rank0(f"Using APOLLO optimizer with args: {apollo_kwargs}.")
     return optimizer
 
 
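For the non-layerwise branch, the truncated param_groups above most likely mirrors the GaLore version: plain groups for no-decay and decay parameters, plus one group that carries the APOLLO settings. A hedged sketch of that shape (build_apollo_param_groups is a hypothetical helper; only scale_front is visible in the diff, the other apollo_kwargs keys are omitted):

from typing import Dict, List

import torch


def build_apollo_param_groups(
    nodecay_params: List["torch.nn.Parameter"],
    decay_params: List["torch.nn.Parameter"],
    apollo_params: List["torch.nn.Parameter"],
    weight_decay: float,
    apollo_kwargs: Dict,
) -> List[Dict]:
    # Hypothetical helper: three groups, with the APOLLO projection settings
    # (e.g. scale_front) attached only to the projected parameters.
    return [
        dict(params=nodecay_params, weight_decay=0.0),
        dict(params=decay_params, weight_decay=weight_decay),
        dict(params=apollo_params, weight_decay=weight_decay, **apollo_kwargs),
    ]

An optimizer such as APOLLOAdamW would then consume these groups exactly as optim_class(param_groups, **optim_kwargs) does in the diff.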