Merge pull request #4352 from Ledzy/main

[Enhancement] Support ZeRO-3 when using BAdam

Former-commit-id: 0dc75275efa7d7540b472783a52ea6aeaa503c0b
hoshi-hiyouga authored on 2024-06-25 01:49:13 +08:00 (committed by GitHub)

12 changed files with 149 additions and 24 deletions
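
Only one of the changed files is shown below: the hunks add a ZeRO-3 gate to _create_badam_optimizer. For context, a hedged usage sketch of what the change enables: BAdam's layer-wise mode combined with a ZeRO-3 DeepSpeed config. The argument names are LLaMA-Factory's BAdam/DeepSpeed finetuning flags, run_exp is the project's programmatic entry point, and the model, dataset, and output paths are placeholders.

from llamafactory.train.tuner import run_exp

# Placeholder model/dataset/paths; the BAdam and DeepSpeed keys are the
# finetuning arguments this PR touches.
run_exp(
    args=dict(
        stage="sft",
        do_train=True,
        model_name_or_path="meta-llama/Meta-Llama-3-8B",
        dataset="alpaca_en_demo",
        template="llama3",
        finetuning_type="full",
        use_badam=True,
        badam_mode="layer",  # ratio mode still rejects ZeRO-3, see the assert below
        badam_switch_mode="ascending",
        badam_switch_interval=50,
        deepspeed="examples/deepspeed/ds_z3_config.json",  # stage-3 config
        output_dir="saves/llama3-badam-z3",
    )
)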


@@ -372,6 +372,9 @@ def _create_badam_optimizer(
         dict(params=decay_params, weight_decay=training_args.weight_decay),
     ]
+    from transformers.integrations import is_deepspeed_zero3_enabled
+    ds_zero3_enabled = is_deepspeed_zero3_enabled()
     if finetuning_args.badam_mode == "layer":
         from badam import BlockOptimizer
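
The hunk above computes the ZeRO-3 probe once and reuses it for both BAdam modes. is_deepspeed_zero3_enabled() is transformers' own helper; roughly, it reports whether the live HF-managed DeepSpeed config was parsed with stage 3. A minimal sketch of equivalent logic, assuming the deepspeed_config() accessor from transformers.integrations.deepspeed (zero3_active is a hypothetical name):

from transformers.integrations.deepspeed import deepspeed_config

def zero3_active() -> bool:
    # deepspeed_config() returns the active DeepSpeed config dict when the
    # Trainer has set one up, or None when training runs without DeepSpeed.
    cfg = deepspeed_config()
    return cfg is not None and cfg.get("zero_optimization", {}).get("stage") == 3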
@@ -384,6 +387,7 @@ def _create_badam_optimizer(
             start_block=finetuning_args.badam_start_block,
             switch_mode=finetuning_args.badam_switch_mode,
             verbose=finetuning_args.badam_verbose,
+            ds_zero3_enabled=ds_zero3_enabled
         )
         logger.info(
             f"Using BAdam optimizer with layer-wise update, switch mode is {finetuning_args.badam_switch_mode}, "
@@ -394,6 +398,7 @@ def _create_badam_optimizer(
     elif finetuning_args.badam_mode == "ratio":
         from badam import BlockOptimizerRatio
+        assert not ds_zero3_enabled, "BAdam with ratio-based update does not support Deepspeed ZeRO-3 yet, use layer-wise update instead: --badam_mode layer."
         assert finetuning_args.badam_update_ratio > 1e-6
         optimizer = BlockOptimizerRatio(
             param_groups=param_groups,
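
The new assert means a ZeRO-3 run with --badam_mode ratio now fails fast at optimizer construction rather than crashing mid-training; switching to --badam_mode layer (or dropping the stage-3 config) resolves it. A standalone, hypothetical reproduction of the guard:

badam_mode = "ratio"
ds_zero3_enabled = True  # pretend a ZeRO-3 DeepSpeed config is active
if badam_mode == "ratio":
    # Raises AssertionError with the message from the diff above.
    assert not ds_zero3_enabled, (
        "BAdam with ratio-based update does not support Deepspeed ZeRO-3 yet, "
        "use layer-wise update instead: --badam_mode layer."
    )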
@@ -405,7 +410,7 @@ def _create_badam_optimizer(
             **optim_kwargs,
         )
         logger.info(
-            f"Using BAdam optimizer with ratio-wise update, update ratio is {finetuning_args.badam_update_ratio}, "
+            f"Using BAdam optimizer with ratio-based update, update ratio is {finetuning_args.badam_update_ratio}, "
             f"mask mode is {finetuning_args.badam_mask_mode}"
         )