support layerwise galore

Former-commit-id: d43a4da0947897d0be3f62fad3107754d4c89f2b
Author: hiyouga
Date: 2024-03-10 00:24:11 +08:00
Parent: c635bbe465
Commit: 7ff8a064f3

14 changed files with 109 additions and 51 deletions


@@ -1,6 +1,8 @@
-from typing import TYPE_CHECKING, Optional, Union
+import math
+from typing import TYPE_CHECKING, Callable, Dict, Optional, Union

 import torch
+from transformers.optimization import get_scheduler
 from transformers.utils.versions import require_version

 from ..extras.logging import get_logger
@@ -14,6 +16,7 @@ if is_galore_available():
 if TYPE_CHECKING:
+    from datasets import Dataset, IterableDataset
     from transformers import Seq2SeqTrainingArguments, Trainer
     from transformers.modeling_utils import PreTrainedModel
     from trl import AutoModelForCausalLMWithValueHead
@@ -24,6 +27,18 @@ if TYPE_CHECKING:
 logger = get_logger(__name__)


+class DummyOptimizer(torch.optim.Optimizer):
+    def __init__(self, *args, **kwargs):
+        dummy_tensor = torch.randn(1, 1)
+        super().__init__([dummy_tensor], {"lr": 1e-3})
+
+    def zero_grad(self, set_to_none: bool = True) -> None:
+        pass
+
+    def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
+        pass
+
+
 def create_modelcard_and_push(
     trainer: "Trainer",
     model_args: "ModelArguments",
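Note on the DummyOptimizer added above: torch.optim.Optimizer raises an error on an empty parameter list, which is presumably why the constructor registers a single throwaway tensor, while step() and zero_grad() are deliberate no-ops. In the layerwise branch introduced in the final hunk, the actual updates happen inside per-parameter gradient hooks, and this placeholder appears to exist only so the surrounding Trainer still has an optimizer object to call into (see the sketch after the diff).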
@@ -127,7 +142,10 @@ def create_reward_model(
 def create_custom_optimzer(
-    model: "PreTrainedModel", training_args: "Seq2SeqTrainingArguments", finetuning_args: "FinetuningArguments"
+    model: "PreTrainedModel",
+    dataset: Union["Dataset", "IterableDataset"],
+    training_args: "Seq2SeqTrainingArguments",
+    finetuning_args: "FinetuningArguments",
 ) -> Optional["torch.optim.Optimizer"]:
     if not finetuning_args.use_galore:
         return None
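The new dataset argument is consumed by the layerwise branch in the final hunk: when max_steps is not set, the number of scheduler steps is derived from the dataset length and the global batch size. A minimal sketch of that arithmetic, with a hypothetical helper name that is not part of the commit:

import math


def estimate_num_training_steps(
    dataset_len: int, per_device_train_batch_size: int, world_size: int, num_train_epochs: float
) -> int:
    # Mirrors the hunk below: steps = epochs * ceil(dataset size / global batch size).
    total_train_batch_size = per_device_train_batch_size * world_size
    return int(num_train_epochs * math.ceil(dataset_len / total_train_batch_size))


# Example: 50_000 samples, batch size 8 on 2 GPUs, 3 epochs -> 3 * ceil(50_000 / 16) = 9_375 steps.
print(estimate_num_training_steps(50_000, 8, 2, 3))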
@@ -144,40 +162,80 @@ def create_custom_optimzer(
     trainable_params = filter(lambda p: p.requires_grad, model.parameters())
     non_galore_params = [p for p in trainable_params if id(p) not in id_galore_params]

-    # define param groups as galore_params and non_galore_params
-    param_groups = [
-        {"params": non_galore_params},
-        {
-            "params": galore_params,
-            "rank": finetuning_args.galore_rank,
-            "update_proj_gap": finetuning_args.galore_update_interval,
-            "scale": finetuning_args.galore_scale,
-            "proj_type": finetuning_args.galore_proj_type,
-        },
-    ]
     if training_args.optim == "adamw_torch":
-        optimizer = GaLoreAdamW(
-            param_groups,
-            lr=training_args.learning_rate,
-            eps=training_args.adam_epsilon,
-            betas=(training_args.adam_beta1, training_args.adam_beta2),
-        )
+        optim_class = GaLoreAdamW
+        optim_kwargs = {
+            "lr": training_args.learning_rate,
+            "eps": training_args.adam_epsilon,
+            "betas": (training_args.adam_beta1, training_args.adam_beta2),
+        }
     elif training_args.optim in ["adamw_bnb_8bit", "adamw_8bit", "paged_adamw_8bit"]:
-        optimizer = GaLoreAdamW8bit(
-            param_groups,
-            lr=training_args.learning_rate,
-            eps=training_args.adam_epsilon,
-            betas=(training_args.adam_beta1, training_args.adam_beta2),
-            optim_bits=8,
-            is_paged="paged" in training_args.optim,
-        )
+        optim_class = GaLoreAdamW8bit
+        optim_kwargs = {
+            "lr": training_args.learning_rate,
+            "eps": training_args.adam_epsilon,
+            "betas": (training_args.adam_beta1, training_args.adam_beta2),
+            "optim_bits": 8,
+            "is_paged": "paged" in training_args.optim,
+        }
     elif training_args.optim == "adafactor":
-        optimizer = GaLoreAdafactor(
-            param_groups,
-            lr=training_args.learning_rate,
-        )
+        optim_class = GaLoreAdafactor
+        optim_kwargs = {
+            "lr": training_args.learning_rate,
+        }
     else:
         raise NotImplementedError("Unknow optim: {}".format(training_args.optim))

+    galore_kwargs = {
+        "rank": finetuning_args.galore_rank,
+        "update_proj_gap": finetuning_args.galore_update_interval,
+        "scale": finetuning_args.galore_scale,
+        "proj_type": finetuning_args.galore_proj_type,
+    }
+
+    if finetuning_args.galore_layerwise:
+        if training_args.gradient_accumulation_steps != 1:
+            raise ValueError("Per-layer GaLore does not support gradient accumulation.")
+
+        if training_args.max_steps > 0:
+            num_training_steps = training_args.max_steps
+        else:
+            total_train_batch_size = training_args.per_device_train_batch_size * training_args.world_size
+            num_training_steps = training_args.num_train_epochs * math.ceil(len(dataset) / total_train_batch_size)
+
+        optimizer_dict: Dict["torch.Tensor", "torch.optim.Optimizer"] = {}
+        for param in non_galore_params:
+            param_groups = [dict(params=[param])]
+            optimizer_dict[param] = optim_class(param_groups, **optim_kwargs)
+        for param in galore_params:
+            param_groups = [dict(params=[param], **galore_kwargs)]
+            optimizer_dict[param] = optim_class(param_groups, **optim_kwargs)
+
+        scheduler_dict: Dict["torch.Tensor", "torch.optim.lr_scheduler.LRScheduler"] = {}
+        for param in non_galore_params + galore_params:
+            scheduler_dict[param] = get_scheduler(
+                training_args.lr_scheduler_type,
+                optimizer=optimizer_dict[param],
+                num_warmup_steps=training_args.get_warmup_steps(num_training_steps) * 2,
+                num_training_steps=num_training_steps * 2,
+            )
+
+        def optimizer_hook(param: "torch.Tensor"):
+            if param.grad is not None:
+                optimizer_dict[param].step()
+                optimizer_dict[param].zero_grad()
+                scheduler_dict[param].step()
+
+        for param in non_galore_params + galore_params:
+            param.register_post_accumulate_grad_hook(optimizer_hook)
+
+        optimizer = DummyOptimizer()
+    else:
+        param_groups = [dict(params=non_galore_params), dict(params=galore_params, **galore_kwargs)]
+        optimizer = optim_class(param_groups, **optim_kwargs)
+
     logger.info("Using GaLore optimizer, may cause hanging at the start of training, wait patiently.")
     return optimizer
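For context, the layerwise scheme above boils down to registering one small optimizer per parameter and driving it from PyTorch's post-accumulate-grad hooks (available since PyTorch 2.1), so no global optimizer step is ever taken. The following is a self-contained sketch of that pattern only, using plain AdamW in place of the GaLore optimizers and illustrative names (model, layerwise_optimizers) that are not part of the commit:

import torch

# Toy model; any nn.Module works the same way.
model = torch.nn.Sequential(torch.nn.Linear(16, 32), torch.nn.ReLU(), torch.nn.Linear(32, 4))

# One optimizer per parameter, so only one parameter's update runs at a time.
layerwise_optimizers = {p: torch.optim.AdamW([p], lr=1e-3) for p in model.parameters()}


def _update_hook(param: torch.Tensor) -> None:
    # Called by autograd once the gradient for `param` has been fully accumulated.
    if param.grad is not None:
        layerwise_optimizers[param].step()
        layerwise_optimizers[param].zero_grad()


for p in model.parameters():
    p.register_post_accumulate_grad_hook(_update_hook)

# Training loop: backward() now performs all updates through the hooks, so the
# loop never calls optimizer.step() itself; a no-op optimizer (like the
# DummyOptimizer above) can fill the slot if a framework insists on one.
for _ in range(3):
    x = torch.randn(8, 16)
    loss = model(x).pow(2).mean()
    loss.backward()

Because each gradient is consumed and cleared as soon as it is produced, full-model gradients never need to be held at once, which is the memory saving the layerwise mode targets and presumably also why the hunk rejects gradient accumulation.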