support badam for all stages

Former-commit-id: 7a1380646119bfe6855f73dd90570defcea05281
This commit is contained in:
hiyouga
2024-04-16 17:44:48 +08:00
parent 42084e08ae
commit a4167fd925
9 changed files with 61 additions and 28 deletions

View File

@@ -1,6 +1,7 @@
import math
import os
import sys
from types import MethodType
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
import torch
@@ -124,6 +125,11 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
else:
self.reward_model = self.accelerator.prepare_model(self.reward_model, evaluation_mode=True)
if finetuning_args.use_badam:
from badam import clip_grad_norm_for_sparse_tensor
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator)
def ppo_train(self, resume_from_checkpoint: Optional[str] = None) -> None:
r"""
Implements training loop for the PPO stage, like _inner_training_loop() in Huggingface's Trainer.