Support distributed BAdam.

Former-commit-id: bdcb986e37975911c190a74d3e60bb77aa2033bd
Author: Jonery
Date: 2024-06-18 12:27:47 +08:00
Parent: 95ae30f678
Commit: 12fcfc2b72
7 changed files with 46 additions and 30 deletions


@@ -170,6 +170,12 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
             self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator)
+            if (self.args.deepspeed_plugin is not None
+                and self.args.deepspeed_plugin.zero_stage == 3
+            ):
+                from badam.utils import BAdamZeRO3Callback
+                self.callback_handler.add_callback(BAdamZeRO3Callback)
+
     def ppo_train(self, resume_from_checkpoint: Optional[str] = None) -> None:
         r"""
         Implements training loop for the PPO stage, like _inner_training_loop() in Huggingface's Trainer.
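
For context, the pattern the hunk adds can be read in isolation as the minimal sketch below. This is a sketch, not the repository's code: the class name DistributedBAdamTrainer and its constructor are hypothetical, while args.deepspeed_plugin (the accelerate DeepSpeedPlugin that transformers' TrainingArguments exposes when a DeepSpeed config is set), its zero_stage field, and the badam.utils.BAdamZeRO3Callback import are taken from the hunk itself and assume the badam package is installed.

    from transformers import Trainer


    class DistributedBAdamTrainer(Trainer):  # hypothetical name, for illustration only
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            # Under ZeRO stage 3 the model parameters are partitioned across
            # ranks, so BAdam needs an extra callback; stages 0-2 keep full
            # parameters on every rank and register no extra hook here.
            if (self.args.deepspeed_plugin is not None
                and self.args.deepspeed_plugin.zero_stage == 3
            ):
                # Imported lazily, as in the hunk, so badam stays optional.
                from badam.utils import BAdamZeRO3Callback

                self.callback_handler.add_callback(BAdamZeRO3Callback)

Keeping the import inside the branch, as the hunk does, means installations without badam, or runs that do not use ZeRO-3, never touch the package.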