support badam for all stages

Former-commit-id: 7a1380646119bfe6855f73dd90570defcea05281
hiyouga
2024-04-16 17:44:48 +08:00
parent 42084e08ae
commit a4167fd925
9 changed files with 61 additions and 28 deletions

@@ -1,5 +1,6 @@
 import json
 import os
+from types import MethodType
 from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union

 import torch
@@ -28,6 +29,10 @@ class PairwiseTrainer(Trainer):
         super().__init__(**kwargs)
         self.finetuning_args = finetuning_args
         self.can_return_loss = True  # override property to return eval_loss
+        if finetuning_args.use_badam:
+            from badam import clip_grad_norm_for_sparse_tensor
+
+            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator)

     def create_optimizer(self) -> "torch.optim.Optimizer":
         if self.optimizer is None:
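
For context, the added lines rebind the Accelerator instance's clip_grad_norm_ method to BAdam's sparse-aware implementation via types.MethodType, so gradient clipping works with BAdam's block-wise (partial) gradients. Below is a minimal, self-contained sketch of that instance-level monkey patch; ToyAccelerator and sparse_clip_grad_norm_ are illustrative stand-ins, not the real Accelerate or BAdam APIs.

from types import MethodType


class ToyAccelerator:
    """Illustrative stand-in for accelerate.Accelerator (not the real API)."""

    def clip_grad_norm_(self, parameters, max_norm):
        print("default dense gradient clipping")


def sparse_clip_grad_norm_(self, parameters, max_norm):
    # Illustrative stand-in for badam.clip_grad_norm_for_sparse_tensor.
    # `self` is the accelerator instance this function gets bound to.
    print("BAdam-style clipping that tolerates the sparse/partial gradients of block-wise updates")


accelerator = ToyAccelerator()
# Rebind clip_grad_norm_ on this specific instance, mirroring the patch in PairwiseTrainer.__init__ above.
accelerator.clip_grad_norm_ = MethodType(sparse_clip_grad_norm_, accelerator)
accelerator.clip_grad_norm_(parameters=[], max_norm=1.0)  # now dispatches to the patched function

Given the commit title and the nine touched files, the same patch is presumably applied in each training stage's trainer so that BAdam is usable regardless of stage.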