update trainers
Former-commit-id: d0dd6eefed0b86895ed00a7cafb331e5193db645
This commit is contained in:
@@ -5,6 +5,7 @@ from transformers import Trainer
|
||||
from transformers.optimization import get_scheduler
|
||||
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
|
||||
from transformers.trainer_pt_utils import get_parameter_names
|
||||
from transformers.utils.versions import require_version
|
||||
|
||||
from ..extras.logging import get_logger
|
||||
from ..extras.packages import is_galore_available
|
||||
@@ -28,7 +29,13 @@ logger = get_logger(__name__)
|
||||
|
||||
|
||||
class DummyOptimizer(torch.optim.Optimizer):
|
||||
def __init__(self, lr: float = 1e-3, optimizer_dict: Optional[dict] = None, *args, **kwargs) -> None:
|
||||
r"""
|
||||
A dummy optimizer used for the GaLore algorithm.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, lr: float = 1e-3, optimizer_dict: Optional[Dict["torch.nn.Parameter", "torch.optim.Optimizer"]] = None
|
||||
) -> None:
|
||||
dummy_tensor = torch.randn(1, 1)
|
||||
self.optimizer_dict = optimizer_dict
|
||||
super().__init__([dummy_tensor], {"lr": lr})
|
||||
@@ -155,8 +162,9 @@ def _create_galore_optimizer(
|
||||
model: "PreTrainedModel",
|
||||
training_args: "Seq2SeqTrainingArguments",
|
||||
finetuning_args: "FinetuningArguments",
|
||||
max_steps: int,
|
||||
) -> "torch.optim.Optimizer":
|
||||
require_version("galore_torch", "To fix: pip install galore_torch")
|
||||
|
||||
if len(finetuning_args.galore_target) == 1 and finetuning_args.galore_target[0] == "all":
|
||||
galore_targets = find_all_linear_modules(model)
|
||||
else:
|
||||
@@ -211,29 +219,19 @@ def _create_galore_optimizer(
|
||||
for param in decay_params:
|
||||
param_groups = [dict(params=[param], weight_decay=training_args.weight_decay)]
|
||||
optimizer_dict[param] = optim_class(param_groups, **optim_kwargs)
|
||||
for param in galore_params:
|
||||
for param in galore_params: # galore params have weight decay
|
||||
param_groups = [dict(params=[param], weight_decay=training_args.weight_decay, **galore_kwargs)]
|
||||
optimizer_dict[param] = optim_class(param_groups, **optim_kwargs)
|
||||
|
||||
scheduler_dict: Dict["torch.Tensor", "torch.optim.lr_scheduler.LRScheduler"] = {}
|
||||
for param in trainable_params:
|
||||
scheduler_dict[param] = get_scheduler(
|
||||
training_args.lr_scheduler_type,
|
||||
optimizer=optimizer_dict[param],
|
||||
num_warmup_steps=training_args.get_warmup_steps(max_steps) * 2,
|
||||
num_training_steps=max_steps * 2,
|
||||
)
|
||||
|
||||
def optimizer_hook(param: "torch.Tensor"):
|
||||
def optimizer_hook(param: "torch.nn.Parameter"):
|
||||
if param.grad is not None:
|
||||
optimizer_dict[param].step()
|
||||
optimizer_dict[param].zero_grad()
|
||||
scheduler_dict[param].step()
|
||||
|
||||
for param in trainable_params:
|
||||
param.register_post_accumulate_grad_hook(optimizer_hook)
|
||||
|
||||
optimizer = DummyOptimizer(lr=training_args.learning_rate) # display scheduler result
|
||||
optimizer = DummyOptimizer(lr=training_args.learning_rate, optimizer_dict=optimizer_dict)
|
||||
else:
|
||||
param_groups = [
|
||||
dict(params=nodecay_params),
|
||||
@@ -292,10 +290,34 @@ def create_custom_optimzer(
|
||||
model: "PreTrainedModel",
|
||||
training_args: "Seq2SeqTrainingArguments",
|
||||
finetuning_args: "FinetuningArguments",
|
||||
max_steps: int,
|
||||
) -> Optional["torch.optim.Optimizer"]:
|
||||
if finetuning_args.use_galore:
|
||||
return _create_galore_optimizer(model, training_args, finetuning_args, max_steps)
|
||||
return _create_galore_optimizer(model, training_args, finetuning_args)
|
||||
|
||||
if finetuning_args.loraplus_lr_ratio is not None:
|
||||
return _create_loraplus_optimizer(model, training_args, finetuning_args)
|
||||
|
||||
|
||||
def create_custom_scheduler(
|
||||
training_args: "Seq2SeqTrainingArguments",
|
||||
num_training_steps: int,
|
||||
optimizer: Optional["torch.optim.Optimizer"] = None,
|
||||
) -> None:
|
||||
if optimizer is not None and isinstance(optimizer, DummyOptimizer):
|
||||
optimizer_dict = optimizer.optimizer_dict
|
||||
scheduler_dict: Dict["torch.nn.Parameter", "torch.optim.lr_scheduler.LRScheduler"] = {}
|
||||
|
||||
for param in optimizer_dict.keys():
|
||||
scheduler_dict[param] = get_scheduler(
|
||||
training_args.lr_scheduler_type,
|
||||
optimizer=optimizer_dict[param],
|
||||
num_warmup_steps=training_args.get_warmup_steps(num_training_steps) * 2,
|
||||
num_training_steps=num_training_steps * 2,
|
||||
)
|
||||
|
||||
def scheduler_hook(param: "torch.nn.Parameter"):
|
||||
if param.grad is not None:
|
||||
scheduler_dict[param].step()
|
||||
|
||||
for param in optimizer_dict.keys():
|
||||
param.register_post_accumulate_grad_hook(scheduler_hook)
|
||||
|
||||
Reference in New Issue
Block a user