Adapt BAdam optimizer integration for DeepSpeed ZeRO stage 3

Former-commit-id: fff2a020ec8713022bd8145f4a7168168ea07ca4
This commit is contained in:
Jonery
2024-06-17 18:18:10 +08:00
parent 4bd276f58f
commit ba303fd1aa
3 changed files with 28 additions and 6 deletions

View File

@@ -55,6 +55,21 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
output_dir = output_dir if output_dir is not None else self.args.output_dir
getattr(self.processor, "image_processor").save_pretrained(output_dir)
def training_step(self, *args, **kwargs):
    r"""
    Run one training step, re-linking the BAdam optimizer to the
    DeepSpeed optimizer wrapper when ZeRO stage 3 is enabled.

    Under ZeRO-3, `self.optimizer` wraps the DeepSpeed optimizer, which
    in turn wraps the BAdam optimizer; BAdam needs a back-reference
    (`ds_optimizer`) to that DeepSpeed wrapper, so it is refreshed here
    before delegating to the parent implementation.
    """
    if self.finetuning_args.use_badam:
        plugin = self.args.deepspeed_plugin
        if plugin is not None and plugin.zero_stage == 3:
            # self.optimizer.optimizer is the DeepSpeed optimizer; its
            # inner .optimizer is the BAdam optimizer that needs the link.
            ds_wrapper = self.optimizer.optimizer
            ds_wrapper.optimizer.ds_optimizer = ds_wrapper
    return super().training_step(*args, **kwargs)
def prediction_step(
self,
model: "torch.nn.Module",