[optim] clean apollo (#6645)

* clean apollo code

* update readme

Former-commit-id: 38b8ec4a99189483124b54df9d6bc6b0d318855a
Author: hoshi-hiyouga
Date: 2025-01-15 01:42:50 +08:00
Committed by: GitHub
Parent: c2120432db
Commit: 7638f1070e
14 changed files with 110 additions and 103 deletions


@@ -32,7 +32,7 @@ from typing_extensions import override
 from ..extras import logging
 from ..extras.constants import IGNORE_INDEX
-from ..extras.packages import is_galore_available, is_ray_available, is_apollo_available
+from ..extras.packages import is_apollo_available, is_galore_available, is_ray_available
 from ..hparams import FinetuningArguments, ModelArguments
 from ..model import find_all_linear_modules, load_model, load_tokenizer, load_valuehead_params
@@ -40,9 +40,11 @@ from ..model import find_all_linear_modules, load_model, load_tokenizer, load_valuehead_params
 if is_galore_available():
     from galore_torch import GaLoreAdafactor, GaLoreAdamW, GaLoreAdamW8bit  # type: ignore

 if is_apollo_available():
     from apollo_torch import APOLLOAdamW  # type: ignore

 if is_ray_available():
     from ray.train import RunConfig, ScalingConfig
     from ray.train.torch import TorchTrainer
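The guards above come from `extras.packages`. A minimal sketch of the usual importlib-based pattern behind such `is_*_available` helpers, assuming the standard implementation (the real helpers may cache the result or also check versions):

```python
import importlib.util


def _is_package_available(name: str) -> bool:
    # find_spec returns None when the module cannot be imported,
    # so this tests availability without importing the package
    return importlib.util.find_spec(name) is not None


def is_apollo_available() -> bool:
    return _is_package_available("apollo_torch")
```

This keeps `apollo_torch` an optional dependency: the import only runs when the package is installed, and optimizer creation fails fast otherwise.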
@@ -240,9 +242,10 @@ def _create_galore_optimizer(
     elif training_args.optim == "adafactor":
         optim_class = GaLoreAdafactor
     else:
-        raise NotImplementedError(f"Unknow optim: {training_args.optim}")
+        raise NotImplementedError(f"Unknown optim: {training_args.optim}.")

     if finetuning_args.galore_layerwise:
         logger.warning_rank0("The displayed gradient norm will be all zeros in layerwise GaLore.")
         if training_args.gradient_accumulation_steps != 1:
             raise ValueError("Per-layer GaLore does not support gradient accumulation.")
@@ -274,9 +277,13 @@ def _create_galore_optimizer(
         ]
         optimizer = optim_class(param_groups, **optim_kwargs)
-    logger.info_rank0("Using GaLore optimizer, may cause hanging at the start of training, wait patiently.")
+
+    logger.info_rank0(
+        f"Using GaLore optimizer with args: {galore_kwargs}. "
+        "It may cause hanging at the start of training, wait patiently."
+    )
     return optimizer


 def _create_apollo_optimizer(
     model: "PreTrainedModel",
     training_args: "TrainingArguments",
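For orientation: the `]` that opens this hunk closes a param-group list built just above it. A hedged reconstruction of that non-layerwise branch, using only names visible in the surrounding diff (exact lines may differ):

```python
# hedged reconstruction; names taken from the surrounding diff
param_groups = [
    dict(params=nodecay_params, weight_decay=0.0),
    dict(params=decay_params, weight_decay=training_args.weight_decay),
    # only this group carries the low-rank projection settings
    dict(params=galore_params, weight_decay=training_args.weight_decay, **galore_kwargs),
]
optimizer = optim_class(param_groups, **optim_kwargs)
```

Only the third group is projected, so the new log line surfaces `galore_kwargs`: those are the settings that distinguish this optimizer from plain AdamW.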
@@ -304,11 +311,9 @@ def _create_apollo_optimizer(
         "scale_front": finetuning_args.apollo_scale_front,
     }
-    print(apollo_kwargs)

     id_apollo_params = {id(param) for param in apollo_params}
-    decay_params, nodecay_params = [], []  # they are non-galore parameters
-    trainable_params: List["torch.nn.Parameter"] = []  # galore_params + decay_params + nodecay_params
+    decay_params, nodecay_params = [], []  # they are non-apollo parameters
+    trainable_params: List["torch.nn.Parameter"] = []  # apollo_params + decay_params + nodecay_params
     decay_param_names = _get_decay_parameter_names(model)
     for name, param in model.named_parameters():
         if param.requires_grad:
@@ -324,9 +329,10 @@ def _create_apollo_optimizer(
     if training_args.optim == "adamw_torch":
         optim_class = APOLLOAdamW
     else:
-        raise NotImplementedError(f"Unknow optim: {training_args.optim}")
+        raise NotImplementedError(f"Unknown optim: {training_args.optim}.")

     if finetuning_args.apollo_layerwise:
         logger.warning_rank0("The displayed gradient norm will be all zeros in layerwise APOLLO.")
         if training_args.gradient_accumulation_steps != 1:
             raise ValueError("Per-layer APOLLO does not support gradient accumulation.")
@@ -337,7 +343,7 @@ def _create_apollo_optimizer(
         for param in decay_params:
             param_groups = [dict(params=[param], weight_decay=training_args.weight_decay)]
             optimizer_dict[param] = optim_class(param_groups, **optim_kwargs)

-        for param in apollo_params:  # galore params have weight decay
+        for param in apollo_params:  # apollo params have weight decay
             param_groups = [dict(params=[param], weight_decay=training_args.weight_decay, **apollo_kwargs)]
             optimizer_dict[param] = optim_class(param_groups, **optim_kwargs)
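In layerwise mode, `optimizer_dict` maps each parameter to its own single-parameter optimizer. A sketch of the hook mechanism that typically drives these optimizers (assumed here; it follows the standard layerwise-GaLore recipe and needs PyTorch >= 2.1), which also explains why gradient accumulation is rejected:

```python
def optimizer_hook(param: "torch.nn.Parameter") -> None:
    # step this parameter's optimizer as soon as its gradient is ready,
    # then drop the gradient so it is never held for the whole model at once
    if param.grad is not None:
        optimizer_dict[param].step()
        optimizer_dict[param].zero_grad()


for param in trainable_params:
    # fires right after the gradient is accumulated during backward()
    param.register_post_accumulate_grad_hook(optimizer_hook)
```

Because the hook steps on every backward pass, gradients cannot be accumulated across micro-batches; and since each gradient is zeroed immediately after its local step, the trainer's displayed global gradient norm comes out as zero, matching the warning above.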
@@ -358,7 +364,7 @@ def _create_apollo_optimizer(
         ]
         optimizer = optim_class(param_groups, **optim_kwargs)

-    logger.info_rank0("Using APOLLO optimizer.")
+    logger.info_rank0(f"Using APOLLO optimizer with args: {apollo_kwargs}.")
     return optimizer