@@ -79,7 +79,7 @@ def run_ppo(
|
||||
|
||||
# Training
|
||||
if training_args.do_train:
|
||||
ppo_trainer.ppo_train(max_target_length=data_args.max_target_length)
|
||||
ppo_trainer.ppo_train()
|
||||
ppo_trainer.save_model()
|
||||
ppo_trainer.save_state() # must be called after save_model to have a folder
|
||||
if ppo_trainer.is_world_process_zero() and model_args.plot_loss:
|
||||
|
||||
Reference in New Issue
Block a user