support badam for all stages

Former-commit-id: 7a1380646119bfe6855f73dd90570defcea05281
This commit is contained in:
hiyouga
2024-04-16 17:44:48 +08:00
parent 42084e08ae
commit a4167fd925
9 changed files with 61 additions and 28 deletions


@@ -1,5 +1,6 @@
 from collections import defaultdict
 from contextlib import nullcontext
+from types import MethodType
 from typing import TYPE_CHECKING, Dict, Literal, Optional, Tuple, Union

 import torch
@@ -63,6 +64,11 @@ class CustomDPOTrainer(DPOTrainer):
             else:
                 self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)

+        if finetuning_args.use_badam:
+            from badam import clip_grad_norm_for_sparse_tensor
+
+            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator)
+
     def create_optimizer(self) -> "torch.optim.Optimizer":
         if self.optimizer is None:
             self.optimizer = create_custom_optimzer(self.model, self.args, self.finetuning_args)
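Note on the patch above: `MethodType(func, obj)` binds a plain function as a method of an existing instance, so the Trainer's later calls to `self.accelerator.clip_grad_norm_(...)` transparently dispatch to BAdam's sparse-aware implementation. Below is a minimal, self-contained sketch of that binding pattern; the class and print statements are illustrative stand-ins, not code from the repo or from badam.

from types import MethodType


class Accelerator:
    """Stand-in for accelerate.Accelerator with its default clipping hook."""

    def clip_grad_norm_(self, parameters, max_norm):
        print("default clip_grad_norm_")


def clip_grad_norm_for_sparse_tensor(self, parameters, max_norm):
    # Replacement hook: 'self' is the Accelerator instance it gets bound to,
    # so it can reuse accelerator state (device placement, flags, ...) here.
    print("sparse-aware clip_grad_norm_ on", type(self).__name__)


accelerator = Accelerator()
# Rebind the instance attribute: subsequent calls go through the replacement,
# which is exactly what the diff does when finetuning_args.use_badam is set.
accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, accelerator)
accelerator.clip_grad_norm_(parameters=[], max_norm=1.0)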