Former-commit-id: 8c01ffe8d277d49a413571e0669f460c8d0802bf
This commit is contained in:
hiyouga
2023-11-20 18:46:36 +08:00
parent ba2be6371d
commit adf2730d1d
5 changed files with 34 additions and 36 deletions

View File

@@ -74,10 +74,13 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
else:
self.reward_model = self.accelerator.prepare_model(self.reward_model, evaluation_mode=True)
def ppo_train(self) -> None:
def ppo_train(self, resume_from_checkpoint: Optional[str] = None) -> None:
r"""
Implements training loop for the PPO stage, like _inner_training_loop() in Huggingface's Trainer.
"""
if resume_from_checkpoint is not None:
raise ValueError("`resume_from_checkpoint` will be supported in the future version.")
total_train_batch_size = (
self.args.per_device_train_batch_size * self.args.gradient_accumulation_steps * self.args.world_size
)