Cleaner integration.

Former-commit-id: 26d4b05d424bd71f570195dd433258caf6465d92
Author: Jonery
Date: 2024-06-19 12:29:40 +08:00
Parent: c7479751e8
Commit: fa3150548e
8 changed files with 24 additions and 64 deletions

@@ -371,11 +371,8 @@ def _create_badam_optimizer(
         dict(params=decay_params, weight_decay=training_args.weight_decay),
     ]
 
-    ds_zero3_enabled = False
-    if hasattr(training_args, "deepspeed_plugin") and training_args.deepspeed_plugin is not None:
-        assert training_args.deepspeed_plugin.zero_stage == 3, f"BAdam only supports deepspeed ZeRO-3 stage, got {training_args.deepspeed_plugin.zero_stage}"
-        assert finetuning_args.badam_mode == "layer", "BAdam only supports layer-wise update in ZeRO-3 stage"
-        ds_zero3_enabled = True
+    from transformers.integrations import is_deepspeed_zero3_enabled
+    ds_zero3_enabled = is_deepspeed_zero3_enabled()
 
     if finetuning_args.badam_mode == "layer":
         from badam import BlockOptimizer
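
For context, a minimal sketch (not part of the commit) of what the new check does: is_deepspeed_zero3_enabled() reports whether the DeepSpeed config registered with transformers for the current run uses ZeRO stage 3, which covers the plugin case the removed hasattr() probe handled by hand. The helper name badam_mode_supported below is hypothetical.

from transformers.integrations import is_deepspeed_zero3_enabled

def badam_mode_supported(badam_mode: str) -> bool:
    # Under ZeRO-3, parameters are partitioned across ranks, so only the
    # layer-wise BAdam update is expected to work; ratio-based sparse
    # masking would need full views of tensors it cannot see.
    if is_deepspeed_zero3_enabled():
        return badam_mode == "layer"
    return True  # without ZeRO-3, "layer" and "ratio" are both usable

print(badam_mode_supported("ratio"))  # True when no ZeRO-3 config is active
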
@@ -400,6 +397,7 @@ def _create_badam_optimizer(
     elif finetuning_args.badam_mode == "ratio":
         from badam import BlockOptimizerRatio
 
+        assert not ds_zero3_enabled, "BAdam with ratio-based update does not support Deepspeed ZeRO-3 yet, use layer-wise update instead: --badam_mode layer."
         assert finetuning_args.badam_update_ratio > 1e-6
         optimizer = BlockOptimizerRatio(
             param_groups=param_groups,
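
The assert added above turns this unsupported combination into an immediate, actionable error rather than a later failure inside BlockOptimizerRatio. A minimal sketch of the behavior, with the flag hard-coded for illustration (the commit derives it from is_deepspeed_zero3_enabled()):

ds_zero3_enabled = True  # stand-in for what is_deepspeed_zero3_enabled() returns
badam_mode = "ratio"

if badam_mode == "ratio":
    # Raises: AssertionError: BAdam with ratio-based update does not support
    # Deepspeed ZeRO-3 yet, use layer-wise update instead: --badam_mode layer.
    assert not ds_zero3_enabled, (
        "BAdam with ratio-based update does not support Deepspeed ZeRO-3 yet, "
        "use layer-wise update instead: --badam_mode layer."
    )

In practice, a run that enables a ZeRO-3 DeepSpeed config should pass --badam_mode layer rather than --badam_mode ratio.
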
@@ -411,7 +409,7 @@ def _create_badam_optimizer(
             **optim_kwargs,
         )
         logger.info(
-            f"Using BAdam optimizer with ratio-wise update, update ratio is {finetuning_args.badam_update_ratio}, "
+            f"Using BAdam optimizer with ratio-based update, update ratio is {finetuning_args.badam_update_ratio}, "
             f"mask mode is {finetuning_args.badam_mask_mode}"
         )