fix bug in galore optimizer

Former-commit-id: c05ac23261a5a8ba893c2918a43dc7777307407b
This commit is contained in:
hiyouga
2024-04-21 18:53:22 +08:00
parent f8e219dc81
commit d16561e7a4
2 changed files with 7 additions and 13 deletions

View File

@@ -234,14 +234,6 @@ def _create_galore_optimizer(
param_groups = [dict(params=[param], weight_decay=training_args.weight_decay, **galore_kwargs)]
optimizer_dict[param] = optim_class(param_groups, **optim_kwargs)
def optimizer_hook(param: "torch.nn.Parameter"):
if param.grad is not None:
optimizer_dict[param].step()
optimizer_dict[param].zero_grad()
for param in trainable_params:
param.register_post_accumulate_grad_hook(optimizer_hook)
optimizer = DummyOptimizer(lr=training_args.learning_rate, optimizer_dict=optimizer_dict)
else:
param_groups = [
@@ -391,9 +383,11 @@ def create_custom_scheduler(
num_training_steps=num_training_steps * 2,
)
def scheduler_hook(param: "torch.nn.Parameter"):
def optimizer_hook(param: "torch.nn.Parameter"):
if param.grad is not None:
optimizer_dict[param].step()
optimizer_dict[param].zero_grad()
scheduler_dict[param].step()
for param in optimizer_dict.keys():
param.register_post_accumulate_grad_hook(scheduler_hook)
param.register_post_accumulate_grad_hook(optimizer_hook)