support badam for all stages

Former-commit-id: 7a1380646119bfe6855f73dd90570defcea05281
hiyouga
2024-04-16 17:44:48 +08:00
parent 42084e08ae
commit a4167fd925
9 changed files with 61 additions and 28 deletions


@@ -1,4 +1,5 @@
 from collections import defaultdict
+from types import MethodType
 from typing import TYPE_CHECKING, Dict, Literal, Optional, Tuple, Union
 
 import torch
@@ -44,6 +45,10 @@ class CustomORPOTrainer(DPOTrainer):
         self._stored_metrics = defaultdict(lambda: defaultdict(list))
         Trainer.__init__(self, model=model, **kwargs)
 
+        if finetuning_args.use_badam:
+            from badam import clip_grad_norm_for_sparse_tensor
+
+            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator)
 
     def create_optimizer(self) -> "torch.optim.Optimizer":
         if self.optimizer is None:
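
The key mechanic in the hunk above is types.MethodType: BAdam's clip_grad_norm_for_sparse_tensor is a plain function whose first parameter plays the role of self, and binding it to the Accelerator instance makes the Trainer's internal self.accelerator.clip_grad_norm_(...) call dispatch to BAdam's sparse-aware clipping instead of the dense default. Below is a minimal, self-contained sketch of that rebinding; FakeAccelerator and the print statements are stand-ins for illustration, not the real accelerate/badam objects.

from types import MethodType


class FakeAccelerator:
    """Stand-in for accelerate.Accelerator (illustrative only)."""

    def clip_grad_norm_(self, parameters, max_norm, norm_type=2):
        print("default dense gradient clipping")


def clip_grad_norm_for_sparse_tensor(self, parameters, max_norm, norm_type=2):
    # Stand-in for badam.clip_grad_norm_for_sparse_tensor: an ordinary
    # function whose first argument becomes `self` once bound below.
    print("BAdam sparse-aware gradient clipping")


accelerator = FakeAccelerator()
# Rebind the replacement as a bound method on this single instance;
# other FakeAccelerator objects keep the original method.
accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, accelerator)
accelerator.clip_grad_norm_([], max_norm=1.0)  # prints the BAdam variant

Patching the instance rather than the class keeps the change local to the one trainer that actually enabled finetuning_args.use_badam.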