[optim] add support for APOLLO (#6617)

Former-commit-id: 5a252e5a458457adbd19da3b68a3897ad2962824
Author: zhuHQ
Date: 2025-01-14 10:24:56 -06:00
Committed by: GitHub
Parent: 66184762e8
Commit: c2120432db
10 changed files with 351 additions and 5 deletions


@@ -32,7 +32,7 @@ from typing_extensions import override
 from ..extras import logging
 from ..extras.constants import IGNORE_INDEX
-from ..extras.packages import is_galore_available, is_ray_available
+from ..extras.packages import is_galore_available, is_ray_available, is_apollo_available
 from ..hparams import FinetuningArguments, ModelArguments
 from ..model import find_all_linear_modules, load_model, load_tokenizer, load_valuehead_params
@@ -40,6 +40,8 @@ from ..model import find_all_linear_modules, load_model, load_tokenizer, load_va
 if is_galore_available():
     from galore_torch import GaLoreAdafactor, GaLoreAdamW, GaLoreAdamW8bit  # type: ignore

+if is_apollo_available():
+    from apollo_torch import APOLLOAdamW  # type: ignore

 if is_ray_available():
     from ray.train import RunConfig, ScalingConfig
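A note on the new guard: `is_apollo_available` presumably follows the same pattern as the existing `is_galore_available` check in `extras/packages.py`. A minimal sketch of that pattern, assuming it is a simple importability probe (this is an assumption, not code from this commit):

import importlib.util


def is_apollo_available() -> bool:
    # True when the optional `apollo_torch` package can be imported in this environment
    return importlib.util.find_spec("apollo_torch") is not None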
@@ -58,7 +60,7 @@ logger = logging.get_logger(__name__)
 class DummyOptimizer(torch.optim.Optimizer):
     r"""
-    A dummy optimizer used for the GaLore algorithm.
+    A dummy optimizer used for the GaLore or APOLLO algorithm.
     """

     def __init__(
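The docstring edit reflects that `DummyOptimizer` now fronts the layerwise APOLLO path as well: the HF `Trainer` needs an optimizer object to drive its loop, while the actual updates happen in per-parameter hooks (see the function below). A stripped-down sketch of the idea, simplified relative to the real class in this file:

import torch


class NoOpOptimizer(torch.optim.Optimizer):
    """Placeholder optimizer whose step/zero_grad do nothing; updates happen in grad hooks."""

    def __init__(self, lr: float = 1e-3):
        dummy = torch.randn(1, requires_grad=True)  # Optimizer requires at least one parameter
        super().__init__([dummy], defaults={"lr": lr})

    def step(self, closure=None):
        pass  # parameters are updated by per-parameter optimizers inside grad hooks

    def zero_grad(self, set_to_none: bool = True):
        pass  # gradients are cleared by those same hooks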
@@ -275,6 +277,90 @@ def _create_galore_optimizer(
     logger.info_rank0("Using GaLore optimizer, may cause hanging at the start of training, wait patiently.")
     return optimizer


+def _create_apollo_optimizer(
+    model: "PreTrainedModel",
+    training_args: "TrainingArguments",
+    finetuning_args: "FinetuningArguments",
+) -> "torch.optim.Optimizer":
+    if len(finetuning_args.apollo_target) == 1 and finetuning_args.apollo_target[0] == "all":
+        apollo_targets = find_all_linear_modules(model, finetuning_args.freeze_vision_tower)
+    else:
+        apollo_targets = finetuning_args.apollo_target
+
+    apollo_params: List["torch.nn.Parameter"] = []
+    for name, module in model.named_modules():
+        if isinstance(module, torch.nn.Linear) and any(target in name for target in apollo_targets):
+            for param in module.parameters():
+                if param.requires_grad and len(param.shape) > 1:
+                    apollo_params.append(param)
+
+    apollo_kwargs = {
+        "rank": finetuning_args.apollo_rank,
+        "proj": finetuning_args.apollo_proj,
+        "proj_type": finetuning_args.apollo_proj_type,
+        "update_proj_gap": finetuning_args.apollo_update_interval,
+        "scale": finetuning_args.apollo_scale,
+        "scale_type": finetuning_args.apollo_scale_type,
+        "scale_front": finetuning_args.apollo_scale_front,
+    }
+
+    id_apollo_params = {id(param) for param in apollo_params}
+    decay_params, nodecay_params = [], []  # they are non-apollo parameters
+    trainable_params: List["torch.nn.Parameter"] = []  # apollo_params + decay_params + nodecay_params
+    decay_param_names = _get_decay_parameter_names(model)
+    for name, param in model.named_parameters():
+        if param.requires_grad:
+            trainable_params.append(param)
+            if id(param) not in id_apollo_params:
+                if name in decay_param_names:
+                    decay_params.append(param)
+                else:
+                    nodecay_params.append(param)
+
+    _, optim_kwargs = Trainer.get_optimizer_cls_and_kwargs(training_args)
+
+    if training_args.optim == "adamw_torch":
+        optim_class = APOLLOAdamW
+    else:
+        raise NotImplementedError(f"Unknown optim: {training_args.optim}")
+
+    if finetuning_args.apollo_layerwise:
+        if training_args.gradient_accumulation_steps != 1:
+            raise ValueError("Per-layer APOLLO does not support gradient accumulation.")
+
+        optimizer_dict: Dict["torch.Tensor", "torch.optim.Optimizer"] = {}
+        for param in nodecay_params:
+            param_groups = [dict(params=[param], weight_decay=0.0)]
+            optimizer_dict[param] = optim_class(param_groups, **optim_kwargs)
+
+        for param in decay_params:
+            param_groups = [dict(params=[param], weight_decay=training_args.weight_decay)]
+            optimizer_dict[param] = optim_class(param_groups, **optim_kwargs)
+
+        for param in apollo_params:  # apollo params have weight decay
+            param_groups = [dict(params=[param], weight_decay=training_args.weight_decay, **apollo_kwargs)]
+            optimizer_dict[param] = optim_class(param_groups, **optim_kwargs)
+
+        def optimizer_hook(param: "torch.nn.Parameter"):
+            if param.grad is not None:
+                optimizer_dict[param].step()
+                optimizer_dict[param].zero_grad()
+
+        for param in trainable_params:
+            param.register_post_accumulate_grad_hook(optimizer_hook)
+
+        optimizer = DummyOptimizer(lr=training_args.learning_rate, optimizer_dict=optimizer_dict)
+    else:
+        param_groups = [
+            dict(params=nodecay_params, weight_decay=0.0),
+            dict(params=decay_params, weight_decay=training_args.weight_decay),
+            dict(params=apollo_params, weight_decay=training_args.weight_decay, **apollo_kwargs),
+        ]
+        optimizer = optim_class(param_groups, **optim_kwargs)
+
+    logger.info_rank0("Using APOLLO optimizer.")
+    return optimizer
+
+
 def _create_loraplus_optimizer(
     model: "PreTrainedModel",
@@ -410,6 +496,9 @@ def create_custom_optimizer(
     if finetuning_args.use_galore:
         return _create_galore_optimizer(model, training_args, finetuning_args)

+    if finetuning_args.use_apollo:
+        return _create_apollo_optimizer(model, training_args, finetuning_args)
+
     if finetuning_args.loraplus_lr_ratio is not None:
         return _create_loraplus_optimizer(model, training_args, finetuning_args)
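End to end, the dispatch above means APOLLO is switched on purely through the finetuning arguments the new function reads. A hypothetical sketch of wiring those fields; the field names come from this diff, but the values shown and the top-level `llamafactory` package path are assumptions:

from llamafactory.hparams import FinetuningArguments  # assumed import path

finetuning_args = FinetuningArguments(
    use_apollo=True,
    apollo_target=["all"],  # "all" expands to every linear module via find_all_linear_modules
    apollo_rank=16,  # rank of the low-rank projection (assumed value)
    apollo_scale=1.0,  # gradient scaling factor (assumed value)
    apollo_layerwise=False,  # True requires gradient_accumulation_steps == 1
)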