Former-commit-id: 8c01ffe8d277d49a413571e0669f460c8d0802bf
This commit is contained in:
hiyouga
2023-11-20 18:46:36 +08:00
parent ba2be6371d
commit adf2730d1d
5 changed files with 34 additions and 36 deletions

View File

@@ -74,10 +74,13 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
else:
self.reward_model = self.accelerator.prepare_model(self.reward_model, evaluation_mode=True)
def ppo_train(self) -> None:
def ppo_train(self, resume_from_checkpoint: Optional[str] = None) -> None:
r"""
Implements training loop for the PPO stage, like _inner_training_loop() in Huggingface's Trainer.
"""
if resume_from_checkpoint is not None:
raise ValueError("`resume_from_checkpoint` will be supported in the future version.")
total_train_batch_size = (
self.args.per_device_train_batch_size * self.args.gradient_accumulation_steps * self.args.world_size
)

View File

@@ -94,7 +94,7 @@ def run_ppo(
# Training
if training_args.do_train:
ppo_trainer.ppo_train()
ppo_trainer.ppo_train(resume_from_checkpoint=training_args.resume_from_checkpoint)
ppo_trainer.save_model()
ppo_trainer.save_state() # must be called after save_model to have a folder
if ppo_trainer.is_world_process_zero() and finetuning_args.plot_loss:

View File

@@ -47,7 +47,7 @@ def run_rm(
# Training
if training_args.do_train:
train_result = trainer.train()
train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
trainer.save_model()
trainer.log_metrics("train", train_result.metrics)
trainer.save_metrics("train", train_result.metrics)