fix: by hiyouga's suggestion
Former-commit-id: 41195f1bc69e4b5da7a265369d368b06754362cf
@@ -40,7 +40,7 @@ from typing_extensions import override
 from ...extras import logging
 from ...extras.misc import AverageMeter, count_parameters, get_current_device, get_logits_processor
 from ..callbacks import FixValueHeadModelCallback, SaveProcessorCallback
-from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
+from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_swanlab_callback
 from .ppo_utils import dump_layernorm, get_rewards_from_server, replace_model, restore_layernorm
 
 
@@ -186,6 +186,9 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
             self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
             self.add_callback(BAdamCallback)
 
+        if finetuning_args.use_swanlab:
+            self.add_callback(get_swanlab_callback(finetuning_args))
+
     def ppo_train(self, resume_from_checkpoint: Optional[str] = None) -> None:
         r"""
         Implements training loop for the PPO stage, like _inner_training_loop() in Huggingface's Trainer.
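
For context, a minimal sketch of what a factory like get_swanlab_callback() could return, assuming the third-party swanlab package and Hugging Face's TrainerCallback API. The class name SwanLabLoggingCallback, the helper name get_swanlab_callback_sketch, and the swanlab_project attribute are illustrative assumptions, not taken from this diff:

# Hypothetical sketch only -- not the code from this repository.
from transformers import TrainerCallback


class SwanLabLoggingCallback(TrainerCallback):
    """Forwards scalar metrics logged by the Trainer to SwanLab."""

    def __init__(self, project: str) -> None:
        import swanlab  # assumed third-party dependency

        self._swanlab = swanlab
        self._swanlab.init(project=project)

    def on_log(self, args, state, control, logs=None, **kwargs):
        # `logs` is a dict of scalar metrics; log once from the main process.
        if logs and state.is_world_process_zero:
            self._swanlab.log(logs, step=state.global_step)


def get_swanlab_callback_sketch(finetuning_args) -> TrainerCallback:
    # Mirrors the shape of the imported get_swanlab_callback(); the
    # `swanlab_project` attribute name is an assumption.
    return SwanLabLoggingCallback(project=getattr(finetuning_args, "swanlab_project", "ppo"))

Because add_callback() accepts any TrainerCallback instance, a callback of this shape plugs into the PPO trainer exactly as the added lines above do.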