fix incorrect loss value for vlms
Former-commit-id: 0aa29a71ce958343a2086090d647eb63b8f5f5be
This commit is contained in:
@@ -181,7 +181,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
||||
self.add_callback(SaveProcessorCallback(processor))
|
||||
|
||||
if finetuning_args.use_badam:
|
||||
from badam import BAdamCallback, clip_grad_norm_old_version
|
||||
from badam import BAdamCallback, clip_grad_norm_old_version # type: ignore
|
||||
|
||||
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
|
||||
self.add_callback(BAdamCallback)
|
||||
|
||||
Reference in New Issue
Block a user