Merge pull request #4352 from Ledzy/main
[Enhancement] Support ZeRO-3 when using BAdam

Former-commit-id: 0dc75275efa7d7540b472783a52ea6aeaa503c0b
@@ -214,13 +214,15 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
     if (
         finetuning_args.use_badam
-        and finetuning_args.badam_mode == "layer"
-        and training_args.parallel_mode == ParallelMode.DISTRIBUTED
+        and training_args.parallel_mode.value == "distributed"
     ):
-        raise ValueError("Layer-wise BAdam does not yet support distributed training, use ratio-wise BAdam.")
+        if finetuning_args.badam_mode == "ratio":
+            raise ValueError("Ratio-wise BAdam does not yet support distributed training, use layer-wise BAdam: --badam_mode layer")
+        if finetuning_args.badam_mode == "layer" and (not is_deepspeed_zero3_enabled()):
+            raise ValueError(f"Layer-wise BAdam only supports DeepSpeed ZeRO 3 stage.")
 
-    if (finetuning_args.use_galore or finetuning_args.use_badam) and training_args.deepspeed is not None:
-        raise ValueError("GaLore and BAdam are incompatible with DeepSpeed yet.")
+    if (finetuning_args.use_galore) and training_args.deepspeed is not None:
+        raise ValueError("GaLore are incompatible with DeepSpeed yet.")
 
     if model_args.infer_backend == "vllm":
         raise ValueError("vLLM backend is only available for API, CLI and Web.")
 
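Note: to make the reshaped check easier to read outside the diff, here is a self-contained sketch of the new validation path. The argument containers and the ZeRO-3 probe are stubbed with plain objects for illustration; only the condition structure and error messages come from the hunk above.

# Sketch only: SimpleNamespace stands in for the real argument dataclasses,
# and `zero3_enabled` replaces the is_deepspeed_zero3_enabled() call.
from types import SimpleNamespace


def check_badam_args(finetuning_args, training_args, zero3_enabled: bool) -> None:
    if finetuning_args.use_badam and training_args.parallel_mode == "distributed":
        if finetuning_args.badam_mode == "ratio":
            raise ValueError(
                "Ratio-wise BAdam does not yet support distributed training, "
                "use layer-wise BAdam: --badam_mode layer"
            )
        if finetuning_args.badam_mode == "layer" and not zero3_enabled:
            raise ValueError("Layer-wise BAdam only supports DeepSpeed ZeRO 3 stage.")


# Layer-wise BAdam under distributed training is now accepted when ZeRO-3 is active:
check_badam_args(
    SimpleNamespace(use_badam=True, badam_mode="layer"),
    SimpleNamespace(parallel_mode="distributed"),
    zero3_enabled=True,
)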
@@ -96,9 +96,9 @@ class CustomDPOTrainer(DPOTrainer):
             self.save_model(os.path.join(self.args.output_dir, "pissa_init"))
 
         if finetuning_args.use_badam:
-            from badam import clip_grad_norm_for_sparse_tensor
-
-            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator)
+            from badam import clip_grad_norm_old_version, BAdamCallback
+            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
+            self.callback_handler.add_callback(BAdamCallback)
 
     def create_optimizer(self) -> "torch.optim.Optimizer":
         if self.optimizer is None:
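Note: the same three-line replacement recurs in every trainer hunk below: bind a BAdam-aware gradient-clipping function onto the Accelerator instance via types.MethodType and register BAdamCallback. A generic, runnable illustration of that binding pattern (all class and function names here are invented for the example; only MethodType and the patched attribute mirror the diff):

from types import MethodType


class FakeAccelerator:  # stand-in for accelerate.Accelerator
    def clip_grad_norm_(self, parameters, max_norm, norm_type=2):
        return "original clipping"


def patched_clip_grad_norm_(self, parameters, max_norm, norm_type=2):
    # A BAdam-aware replacement would account for the block-wise frozen or
    # sparse parameters before delegating to the usual clipping routine.
    return "patched clipping"


accelerator = FakeAccelerator()
# Bind the replacement as a bound method of this instance, as the trainers do:
accelerator.clip_grad_norm_ = MethodType(patched_clip_grad_norm_, accelerator)
assert accelerator.clip_grad_norm_([], max_norm=1.0) == "patched clipping"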
@@ -91,9 +91,9 @@ class CustomKTOTrainer(KTOTrainer):
                 self.ref_model.eval()
 
         if finetuning_args.use_badam:
-            from badam import clip_grad_norm_for_sparse_tensor
-
-            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator)
+            from badam import clip_grad_norm_old_version, BAdamCallback
+            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
+            self.callback_handler.add_callback(BAdamCallback)
 
     def create_optimizer(self) -> "torch.optim.Optimizer":
         if self.optimizer is None:
@@ -166,9 +166,9 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
                 self.reward_model = self.accelerator.prepare_model(self.reward_model, evaluation_mode=True)
 
         if finetuning_args.use_badam:
-            from badam import clip_grad_norm_for_sparse_tensor
-
-            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator)
+            from badam import clip_grad_norm_old_version, BAdamCallback
+            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
+            self.callback_handler.add_callback(BAdamCallback)
 
     def ppo_train(self, resume_from_checkpoint: Optional[str] = None) -> None:
         r"""
@@ -48,9 +48,9 @@ class CustomTrainer(Trainer):
             self.save_model(os.path.join(self.args.output_dir, "pissa_init"))
 
         if finetuning_args.use_badam:
-            from badam import clip_grad_norm_for_sparse_tensor
-
-            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator)
+            from badam import clip_grad_norm_old_version, BAdamCallback
+            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
+            self.callback_handler.add_callback(BAdamCallback)
 
     def create_optimizer(self) -> "torch.optim.Optimizer":
         if self.optimizer is None:
@@ -72,9 +72,9 @@ class PairwiseTrainer(Trainer):
         self.processor = processor
         self.can_return_loss = True  # override property to return eval_loss
         if finetuning_args.use_badam:
-            from badam import clip_grad_norm_for_sparse_tensor
-
-            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator)
+            from badam import clip_grad_norm_old_version, BAdamCallback
+            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
+            self.callback_handler.add_callback(BAdamCallback)
 
     def create_optimizer(self) -> "torch.optim.Optimizer":
         if self.optimizer is None:
@@ -56,9 +56,9 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
             self.save_model(os.path.join(self.args.output_dir, "pissa_init"))
 
         if finetuning_args.use_badam:
-            from badam import clip_grad_norm_for_sparse_tensor
-
-            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator)
+            from badam import clip_grad_norm_old_version, BAdamCallback
+            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
+            self.callback_handler.add_callback(BAdamCallback)
 
     def create_optimizer(self) -> "torch.optim.Optimizer":
         if self.optimizer is None:
@@ -372,6 +372,9 @@ def _create_badam_optimizer(
         dict(params=decay_params, weight_decay=training_args.weight_decay),
     ]
 
+    from transformers.integrations import is_deepspeed_zero3_enabled
+    ds_zero3_enabled = is_deepspeed_zero3_enabled()
+
     if finetuning_args.badam_mode == "layer":
         from badam import BlockOptimizer
 
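Note: the new `ds_zero3_enabled` flag is read from transformers' own DeepSpeed state. As a hedged aside (not part of this commit, and it requires `deepspeed` to be installed), this is roughly how that probe flips to True outside of the Trainer, which is what the checks above rely on:

# Illustration only: is_deepspeed_zero3_enabled() consults a module-level
# weak reference that HfDeepSpeedConfig registers while the object is alive.
from transformers.integrations import HfDeepSpeedConfig, is_deepspeed_zero3_enabled

print(is_deepspeed_zero3_enabled())  # False: no DeepSpeed config is active yet

ds_config = {
    "train_micro_batch_size_per_gpu": 1,
    "zero_optimization": {"stage": 3},
}
dschf = HfDeepSpeedConfig(ds_config)  # keep this reference alive

print(is_deepspeed_zero3_enabled())  # True while `dschf` exists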
@@ -384,6 +387,7 @@ def _create_badam_optimizer(
             start_block=finetuning_args.badam_start_block,
             switch_mode=finetuning_args.badam_switch_mode,
             verbose=finetuning_args.badam_verbose,
+            ds_zero3_enabled=ds_zero3_enabled
         )
         logger.info(
             f"Using BAdam optimizer with layer-wise update, switch mode is {finetuning_args.badam_switch_mode}, "
@@ -394,6 +398,7 @@ def _create_badam_optimizer(
     elif finetuning_args.badam_mode == "ratio":
         from badam import BlockOptimizerRatio
 
+        assert not ds_zero3_enabled, "BAdam with ratio-based update does not support Deepspeed ZeRO-3 yet, use layer-wise update instead: --badam_mode layer."
         assert finetuning_args.badam_update_ratio > 1e-6
         optimizer = BlockOptimizerRatio(
             param_groups=param_groups,
@@ -405,7 +410,7 @@ def _create_badam_optimizer(
             **optim_kwargs,
         )
         logger.info(
-            f"Using BAdam optimizer with ratio-wise update, update ratio is {finetuning_args.badam_update_ratio}, "
+            f"Using BAdam optimizer with ratio-based update, update ratio is {finetuning_args.badam_update_ratio}, "
             f"mask mode is {finetuning_args.badam_mask_mode}"
         )
 