Support layerwise GaLore
Former-commit-id: d43a4da0947897d0be3f62fad3107754d4c89f2b

@@ -1,6 +1,8 @@
-from typing import TYPE_CHECKING, Optional, Union
+import math
+from typing import TYPE_CHECKING, Callable, Dict, Optional, Union
 
 import torch
+from transformers.optimization import get_scheduler
 from transformers.utils.versions import require_version
 
 from ..extras.logging import get_logger
@@ -14,6 +16,7 @@ if is_galore_available():
 
 
 if TYPE_CHECKING:
+    from datasets import Dataset, IterableDataset
     from transformers import Seq2SeqTrainingArguments, Trainer
     from transformers.modeling_utils import PreTrainedModel
     from trl import AutoModelForCausalLMWithValueHead
@@ -24,6 +27,18 @@ if TYPE_CHECKING:
 logger = get_logger(__name__)
 
 
+class DummyOptimizer(torch.optim.Optimizer):
+    def __init__(self, *args, **kwargs):
+        dummy_tensor = torch.randn(1, 1)
+        super().__init__([dummy_tensor], {"lr": 1e-3})
+
+    def zero_grad(self, set_to_none: bool = True) -> None:
+        pass
+
+    def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
+        pass
+
+
 def create_modelcard_and_push(
     trainer: "Trainer",
     model_args: "ModelArguments",
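
The DummyOptimizer added above is a placeholder: in the layerwise branch introduced later in this diff, all real parameter updates happen inside per-parameter gradient hooks, so the caller only needs an object that satisfies the torch.optim.Optimizer interface. A minimal usage sketch (illustrative only, not part of the commit; assumes the class above is in scope):

    dummy = DummyOptimizer()
    dummy.zero_grad()                    # no-op by design
    dummy.step()                         # no-op by design
    assert len(dummy.param_groups) == 1  # only the 1x1 placeholder tensor from __init__
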
@@ -127,7 +142,10 @@ def create_reward_model(
 
 
 def create_custom_optimzer(
-    model: "PreTrainedModel", training_args: "Seq2SeqTrainingArguments", finetuning_args: "FinetuningArguments"
+    model: "PreTrainedModel",
+    dataset: Union["Dataset", "IterableDataset"],
+    training_args: "Seq2SeqTrainingArguments",
+    finetuning_args: "FinetuningArguments",
 ) -> Optional["torch.optim.Optimizer"]:
     if not finetuning_args.use_galore:
         return None
@@ -144,40 +162,80 @@ def create_custom_optimzer(
     trainable_params = filter(lambda p: p.requires_grad, model.parameters())
     non_galore_params = [p for p in trainable_params if id(p) not in id_galore_params]
 
-    # define param groups as galore_params and non_galore_params
-    param_groups = [
-        {"params": non_galore_params},
-        {
-            "params": galore_params,
-            "rank": finetuning_args.galore_rank,
-            "update_proj_gap": finetuning_args.galore_update_interval,
-            "scale": finetuning_args.galore_scale,
-            "proj_type": finetuning_args.galore_proj_type,
-        },
-    ]
     if training_args.optim == "adamw_torch":
-        optimizer = GaLoreAdamW(
-            param_groups,
-            lr=training_args.learning_rate,
-            eps=training_args.adam_epsilon,
-            betas=(training_args.adam_beta1, training_args.adam_beta2),
-        )
+        optim_class = GaLoreAdamW
+        optim_kwargs = {
+            "lr": training_args.learning_rate,
+            "eps": training_args.adam_epsilon,
+            "betas": (training_args.adam_beta1, training_args.adam_beta2),
+        }
+
     elif training_args.optim in ["adamw_bnb_8bit", "adamw_8bit", "paged_adamw_8bit"]:
-        optimizer = GaLoreAdamW8bit(
-            param_groups,
-            lr=training_args.learning_rate,
-            eps=training_args.adam_epsilon,
-            betas=(training_args.adam_beta1, training_args.adam_beta2),
-            optim_bits=8,
-            is_paged="paged" in training_args.optim,
-        )
+        optim_class = GaLoreAdamW8bit
+        optim_kwargs = {
+            "lr": training_args.learning_rate,
+            "eps": training_args.adam_epsilon,
+            "betas": (training_args.adam_beta1, training_args.adam_beta2),
+            "optim_bits": 8,
+            "is_paged": "paged" in training_args.optim,
+        }
+
     elif training_args.optim == "adafactor":
-        optimizer = GaLoreAdafactor(
-            param_groups,
-            lr=training_args.learning_rate,
-        )
+        optim_class = GaLoreAdafactor
+        optim_kwargs = {
+            "lr": training_args.learning_rate,
+        }
+
     else:
         raise NotImplementedError("Unknow optim: {}".format(training_args.optim))
 
+    galore_kwargs = {
+        "rank": finetuning_args.galore_rank,
+        "update_proj_gap": finetuning_args.galore_update_interval,
+        "scale": finetuning_args.galore_scale,
+        "proj_type": finetuning_args.galore_proj_type,
+    }
+
+    if finetuning_args.galore_layerwise:
+        if training_args.gradient_accumulation_steps != 1:
+            raise ValueError("Per-layer GaLore does not support gradient accumulation.")
+
+        if training_args.max_steps > 0:
+            num_training_steps = training_args.max_steps
+        else:
+            total_train_batch_size = training_args.per_device_train_batch_size * training_args.world_size
+            num_training_steps = training_args.num_train_epochs * math.ceil(len(dataset) / total_train_batch_size)
+
+        optimizer_dict: Dict["torch.Tensor", "torch.optim.Optimizer"] = {}
+        for param in non_galore_params:
+            param_groups = [dict(params=[param])]
+            optimizer_dict[param] = optim_class(param_groups, **optim_kwargs)
+        for param in galore_params:
+            param_groups = [dict(params=[param], **galore_kwargs)]
+            optimizer_dict[param] = optim_class(param_groups, **optim_kwargs)
+
+        scheduler_dict: Dict["torch.Tensor", "torch.optim.lr_scheduler.LRScheduler"] = {}
+        for param in non_galore_params + galore_params:
+            scheduler_dict[param] = get_scheduler(
+                training_args.lr_scheduler_type,
+                optimizer=optimizer_dict[param],
+                num_warmup_steps=training_args.get_warmup_steps(num_training_steps) * 2,
+                num_training_steps=num_training_steps * 2,
+            )
+
+        def optimizer_hook(param: "torch.Tensor"):
+            if param.grad is not None:
+                optimizer_dict[param].step()
+                optimizer_dict[param].zero_grad()
+                scheduler_dict[param].step()
+
+        for param in non_galore_params + galore_params:
+            param.register_post_accumulate_grad_hook(optimizer_hook)
+
+        optimizer = DummyOptimizer()
+    else:
+        param_groups = [dict(params=non_galore_params), dict(params=galore_params, **galore_kwargs)]
+        optimizer = optim_class(param_groups, **optim_kwargs)
+
     logger.info("Using GaLore optimizer, may cause hanging at the start of training, wait patiently.")
     return optimizer
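
For illustration, the per-layer update pattern above can be reduced to a minimal, self-contained sketch, with plain torch.optim.AdamW standing in for the GaLore optimizers and the per-parameter schedulers omitted (an assumption-laden sketch, not the commit's code; it requires PyTorch >= 2.1 for Tensor.register_post_accumulate_grad_hook):

    import torch

    # Toy model; in the commit these would be the trainable parameters of a PreTrainedModel.
    model = torch.nn.Sequential(torch.nn.Linear(8, 16), torch.nn.ReLU(), torch.nn.Linear(16, 2))

    # One optimizer per parameter, so each tensor is updated as soon as its gradient is ready.
    optimizer_dict = {param: torch.optim.AdamW([param], lr=1e-3) for param in model.parameters()}

    def optimizer_hook(param: torch.Tensor) -> None:
        # Autograd calls this right after param.grad has been fully accumulated.
        if param.grad is not None:
            optimizer_dict[param].step()
            optimizer_dict[param].zero_grad()

    for param in model.parameters():
        param.register_post_accumulate_grad_hook(optimizer_hook)

    # The training loop only calls backward(); the hooks perform all updates,
    # which is why create_custom_optimzer can hand back a do-nothing DummyOptimizer.
    inputs, labels = torch.randn(4, 8), torch.randint(0, 2, (4,))
    loss = torch.nn.functional.cross_entropy(model(inputs), labels)
    loss.backward()  # each parameter is stepped and its grad cleared inside its own hook

Because every parameter is stepped and zeroed individually during the backward pass, gradients and optimizer state never have to be held for the whole model at once, which is the memory saving the layerwise mode is meant to provide; it is also why the commit rejects gradient_accumulation_steps != 1.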