fix GaLore
Former-commit-id: 62a3ceeef8f60caef43ccc7f971a0c9184e21296
@@ -154,14 +154,28 @@ def create_custom_optimzer(
         },
     ]
     if training_args.optim == "adamw_torch":
-        optimizer = GaLoreAdamW(param_groups, lr=training_args.learning_rate)
-    elif training_args.optim == "adamw_8bit":
-        optimizer = GaLoreAdamW8bit(param_groups, lr=training_args.learning_rate)
+        optimizer = GaLoreAdamW(
+            param_groups,
+            lr=training_args.learning_rate,
+            eps=training_args.adam_epsilon,
+            betas=(training_args.adam_beta1, training_args.adam_beta2),
+        )
+    elif training_args.optim in ["adamw_bnb_8bit", "adamw_8bit", "paged_adamw_8bit"]:
+        optimizer = GaLoreAdamW8bit(
+            param_groups,
+            lr=training_args.learning_rate,
+            eps=training_args.adam_epsilon,
+            betas=(training_args.adam_beta1, training_args.adam_beta2),
+            optim_bits=8,
+            is_paged="paged" in training_args.optim,
+        )
     elif training_args.optim == "adafactor":
-        optimizer = GaLoreAdafactor(param_groups, lr=training_args.learning_rate)
+        optimizer = GaLoreAdafactor(
+            param_groups,
+            lr=training_args.learning_rate,
+        )
     else:
         raise NotImplementedError("Unknow optim: {}".format(training_args.optim))

-    logger.info("Used the GaLore optimizer, may cause hanging at the start of training, wait patiently.")
+    logger.info("Using GaLore optimizer, may cause hanging at the start of training, wait patiently.")
     return optimizer
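
For context (not part of this commit): a minimal sketch of how the param_groups consumed by the GaLore optimizers above are typically assembled, following the galore_torch README. The toy model, the dim()-based grouping, and the values for rank, update_proj_gap, scale, and proj_type are illustrative assumptions, not code from this repository.

# Sketch only: assumes galore_torch is installed; the group keys below come
# from the galore_torch documentation, not from this diff.
import torch.nn as nn
from galore_torch import GaLoreAdamW

model = nn.Sequential(nn.Linear(128, 256), nn.ReLU(), nn.Linear(256, 128))

# 2-D weight matrices get the low-rank GaLore projection; everything else
# (biases, norms) stays in a regular AdamW group.
galore_params = [p for p in model.parameters() if p.dim() == 2]
regular_params = [p for p in model.parameters() if p.dim() != 2]

param_groups = [
    {"params": regular_params},
    {"params": galore_params, "rank": 128, "update_proj_gap": 200, "scale": 0.25, "proj_type": "std"},
]

# Same call pattern as the "adamw_torch" branch in the diff above.
optimizer = GaLoreAdamW(param_groups, lr=1e-4, eps=1e-8, betas=(0.9, 0.999))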