resolve gradient checkpointing issue.

Former-commit-id: 6df9135d063bb6102f0cbcdf0d702076f5febbae
Jonery
2024-04-16 12:05:27 +08:00
parent d4d471450f
commit 6dd6b3e396
4 changed files with 8 additions and 14 deletions


@@ -29,7 +29,7 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
     def __init__(self, finetuning_args: "FinetuningArguments", **kwargs) -> None:
         super().__init__(**kwargs)
         self.finetuning_args = finetuning_args
-        if version.parse(torch.__version__) >= version.parse("1.13"):
+        if finetuning_args.use_badam:
             from badam import clip_grad_norm_for_sparse_tensor
             self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator)
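The hunk narrows the guard: previously the BAdam gradient-clipping patch was applied whenever torch >= 1.13, even in runs that never enabled BAdam; now it is applied only when finetuning_args.use_badam is set. A minimal sketch of the same MethodType patch in isolation, assuming badam and accelerate are installed (the standalone use_badam flag below stands in for finetuning_args.use_badam):

from types import MethodType

from accelerate import Accelerator

accelerator = Accelerator()
use_badam = True  # hypothetical stand-in for finetuning_args.use_badam

if use_badam:
    # BAdam updates only a block of parameters per step, so gradients can be
    # sparse; this helper clips norms correctly in that case.
    from badam import clip_grad_norm_for_sparse_tensor

    # Bind the helper as a method on this Accelerator instance, so later
    # calls to accelerator.clip_grad_norm_(...) dispatch to the BAdam-aware
    # version instead of the default implementation.
    accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, accelerator)

Binding with MethodType (rather than reassigning a plain function) keeps the call signature identical to the original accelerator.clip_grad_norm_, since the instance is passed through as the first argument.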