fix galore

Former-commit-id: 62a3ceeef8f60caef43ccc7f971a0c9184e21296
Author: hiyouga
Date: 2024-03-08 00:44:51 +08:00
Parent: 81fcb80466
Commit: e416cecf62
11 changed files with 129 additions and 25 deletions

@@ -154,14 +154,28 @@ def create_custom_optimzer(
         },
     ]
     if training_args.optim == "adamw_torch":
-        optimizer = GaLoreAdamW(param_groups, lr=training_args.learning_rate)
-    elif training_args.optim == "adamw_8bit":
-        optimizer = GaLoreAdamW8bit(param_groups, lr=training_args.learning_rate)
+        optimizer = GaLoreAdamW(
+            param_groups,
+            lr=training_args.learning_rate,
+            eps=training_args.adam_epsilon,
+            betas=(training_args.adam_beta1, training_args.adam_beta2),
+        )
+    elif training_args.optim in ["adamw_bnb_8bit", "adamw_8bit", "paged_adamw_8bit"]:
+        optimizer = GaLoreAdamW8bit(
+            param_groups,
+            lr=training_args.learning_rate,
+            eps=training_args.adam_epsilon,
+            betas=(training_args.adam_beta1, training_args.adam_beta2),
+            optim_bits=8,
+            is_paged="paged" in training_args.optim,
+        )
     elif training_args.optim == "adafactor":
-        optimizer = GaLoreAdafactor(param_groups, lr=training_args.learning_rate)
+        optimizer = GaLoreAdafactor(
+            param_groups,
+            lr=training_args.learning_rate,
+        )
     else:
         raise NotImplementedError("Unknow optim: {}".format(training_args.optim))
-    logger.info("Used the GaLore optimizer, may cause hanging at the start of training, wait patiently.")
+    logger.info("Using GaLore optimizer, may cause hanging at the start of training, wait patiently.")
     return optimizer
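
For reference, the hunk forwards eps and betas from training_args instead of relying on the optimizer defaults, and routes all three 8-bit optimizer names ("adamw_bnb_8bit", "adamw_8bit", "paged_adamw_8bit") to GaLoreAdamW8bit, enabling paged optimizer state when the name contains "paged". The snippet below is a minimal, self-contained sketch of the same dispatch, not the project's code: the build_galore_optimizer helper, the toy model, and all hyperparameter values are illustrative, and the param-group keys (rank, update_proj_gap, scale, proj_type) follow galore_torch's documented usage.

# Minimal sketch (not the project's code): reproduce the dispatch above on an
# illustrative model. Assumes torch and galore_torch are installed.
import torch
from galore_torch import GaLoreAdamW, GaLoreAdamW8bit, GaLoreAdafactor

model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.Linear(64, 8))

# GaLore applies low-rank gradient projection only to 2-D weight matrices;
# all other parameters go into a plain parameter group.
galore_params = [p for p in model.parameters() if p.dim() == 2]
other_params = [p for p in model.parameters() if p.dim() != 2]
param_groups = [
    {"params": other_params},
    # rank / update_proj_gap / scale / proj_type values are illustrative only.
    {"params": galore_params, "rank": 16, "update_proj_gap": 200, "scale": 0.25, "proj_type": "std"},
]


def build_galore_optimizer(optim: str, lr: float, eps: float, betas: tuple):
    """Hypothetical helper mirroring the branches in create_custom_optimzer."""
    if optim == "adamw_torch":
        return GaLoreAdamW(param_groups, lr=lr, eps=eps, betas=betas)
    elif optim in ["adamw_bnb_8bit", "adamw_8bit", "paged_adamw_8bit"]:
        return GaLoreAdamW8bit(
            param_groups, lr=lr, eps=eps, betas=betas,
            optim_bits=8, is_paged="paged" in optim,
        )
    elif optim == "adafactor":
        return GaLoreAdafactor(param_groups, lr=lr)
    raise NotImplementedError("Unknown optim: {}".format(optim))


optimizer = build_galore_optimizer("adamw_8bit", lr=1e-4, eps=1e-8, betas=(0.9, 0.999))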