Feature BAdam
Former-commit-id: d8d2807fbcf587c37f7fd34a23e9397d2775ceed
@@ -9,7 +9,8 @@ from transformers import Seq2SeqTrainer
from ...extras.constants import IGNORE_INDEX
from ...extras.logging import get_logger
from ..utils import create_custom_optimzer, create_custom_scheduler

from types import MethodType
from packaging import version

if TYPE_CHECKING:
    from transformers.trainer import PredictionOutput
@@ -28,6 +29,9 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
    def __init__(self, finetuning_args: "FinetuningArguments", **kwargs) -> None:
        super().__init__(**kwargs)
        self.finetuning_args = finetuning_args
        if version.parse(torch.__version__) >= version.parse("1.13"):
            from badam import clip_grad_norm_for_sparse_tensor
            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator)

    def create_optimizer(self) -> "torch.optim.Optimizer":
        if self.optimizer is None:
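
Note: the assignment above uses types.MethodType to bind a plain function as a bound method on one Accelerator instance, so badam's clip_grad_norm_for_sparse_tensor replaces the stock gradient-clipping routine only for this trainer's accelerator. Below is a minimal, self-contained sketch of that binding pattern; the ToyAccelerator class and the stand-in replacement function are illustrative only, not part of this commit.

from types import MethodType


class ToyAccelerator:
    def clip_grad_norm_(self, parameters, max_norm):
        print("stock clipping")


def clip_grad_norm_for_sparse_tensor(self, parameters, max_norm):
    # Stand-in for the badam helper: a replacement that would also
    # handle the sparse gradients produced by BAdam's ratio mode.
    print("sparse-aware clipping")


accelerator = ToyAccelerator()
# Bind the free function as a method of this single instance only.
accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, accelerator)
accelerator.clip_grad_norm_(parameters=[], max_norm=1.0)  # prints "sparse-aware clipping"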
@@ -287,12 +287,69 @@ def _create_loraplus_optimizer(
    logger.info("Using LoRA+ optimizer with loraplus lr ratio {:.2f}.".format(finetuning_args.loraplus_lr_ratio))
    return optimizer


def _create_badam_optimizer(
    model: "PreTrainedModel",
    training_args: "Seq2SeqTrainingArguments",
    finetuning_args: "FinetuningArguments",
) -> "torch.optim.Optimizer":

    from transformers.trainer_pt_utils import get_parameter_names
    decay_parameters = list(filter(lambda n: "bias" not in n, get_parameter_names(model, ALL_LAYERNORM_LAYERS)))
    # filter out the embedding layers when using badam ratio mode
    if finetuning_args.badam_mode == "ratio":
        decay_parameters = list(filter(lambda n: "embed" not in n, decay_parameters))  # TODO: make it more general
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if n in decay_parameters],
            "weight_decay": training_args.weight_decay,
        },
        {
            "params": [p for n, p in model.named_parameters() if n not in decay_parameters],
            "weight_decay": 0.0,
        },
    ]

    optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(training_args)

    # create BlockOptimizer
    if finetuning_args.badam_mode == "layer":
        from badam import BlockOptimizer
        base_optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
        optimizer = BlockOptimizer(base_optimizer=base_optimizer,
                                   named_parameters_list=list(model.named_parameters()),
                                   block_prefix_list=None,
                                   switch_block_every=finetuning_args.switch_block_every,
                                   start_block=finetuning_args.start_block,
                                   switch_mode=finetuning_args.switch_mode,
                                   verbose=finetuning_args.badam_verbose)

        logger.info(f"Using BAdam optimizer with layer-wise update, switch mode is {finetuning_args.switch_mode}, "
                    f"switch block every {finetuning_args.switch_block_every} steps, "
                    f"default start block is {finetuning_args.start_block}")

    elif finetuning_args.badam_mode == "ratio":
        assert finetuning_args.badam_update_ratio > 0.
        from badam import BlockOptimizerRatio
        optimizer = BlockOptimizerRatio(param_groups=optimizer_grouped_parameters,
                                        named_parameters_list=list(model.named_parameters()),
                                        update_ratio=finetuning_args.badam_update_ratio,
                                        mask_mode=finetuning_args.badam_mask_mode,
                                        verbose=finetuning_args.badam_verbose,
                                        **optimizer_kwargs)

        logger.info(f"Using BAdam optimizer with ratio update, update ratio is {finetuning_args.badam_update_ratio}, "
                    f"mask mode is {finetuning_args.badam_mask_mode}")

    return optimizer


def create_custom_optimzer(
    model: "PreTrainedModel",
    training_args: "Seq2SeqTrainingArguments",
    finetuning_args: "FinetuningArguments",
) -> Optional["torch.optim.Optimizer"]:
    if finetuning_args.use_badam:
        return _create_badam_optimizer(model, training_args, finetuning_args)

    if finetuning_args.use_galore:
        return _create_galore_optimizer(model, training_args, finetuning_args)
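
The parameter grouping in _create_badam_optimizer follows the usual Trainer convention: every parameter except biases and LayerNorm weights receives weight decay (and ratio mode additionally drops embedding weights). A standalone sketch of that grouping on a toy model; the model, learning rate, and the 0.01 decay value are illustrative assumptions.

import torch
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
from transformers.trainer_pt_utils import get_parameter_names

model = torch.nn.Sequential(
    torch.nn.Linear(16, 16),
    torch.nn.LayerNorm(16),
    torch.nn.Linear(16, 4),
)

# Names of parameters that live outside LayerNorm layers and are not biases.
decay_parameters = [
    n for n in get_parameter_names(model, ALL_LAYERNORM_LAYERS) if "bias" not in n
]
optimizer_grouped_parameters = [
    {
        "params": [p for n, p in model.named_parameters() if n in decay_parameters],
        "weight_decay": 0.01,
    },
    {
        "params": [p for n, p in model.named_parameters() if n not in decay_parameters],
        "weight_decay": 0.0,
    },
]
optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=1e-3)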
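
Outside of the Trainer plumbing, the "layer" branch amounts to wrapping an ordinary dense optimizer with badam's BlockOptimizer, which then updates one block of layers at a time. A hedged usage sketch, assuming the badam package is installed and behaves as a drop-in torch optimizer; only keyword arguments that already appear in the diff are used, and the toy model and hyperparameter values are illustrative.

import torch
from badam import BlockOptimizer

model = torch.nn.Sequential(
    torch.nn.Linear(32, 32),
    torch.nn.Linear(32, 32),
    torch.nn.Linear(32, 2),
)
base_optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

optimizer = BlockOptimizer(
    base_optimizer=base_optimizer,                         # dense optimizer being wrapped
    named_parameters_list=list(model.named_parameters()),
    switch_block_every=50,                                 # rotate the active block every 50 steps
    switch_mode="random",                                  # order in which blocks are activated
    verbose=1,
)

# Training proceeds as with any torch optimizer; only the active block is updated.
x, y = torch.randn(8, 32), torch.randn(8, 2)
loss = torch.nn.functional.mse_loss(model(x), y)
loss.backward()
optimizer.step()
optimizer.zero_grad()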