fix ppo args

Former-commit-id: 0f12899951808f53a482082eb116bda309775930
This commit is contained in:
hiyouga
2023-10-11 23:40:50 +08:00
parent 3198a7e5f4
commit 97b74d328b
4 changed files with 18 additions and 9 deletions

View File

@@ -42,15 +42,14 @@ def run_ppo(
ppo_epochs=1,
max_grad_norm=training_args.max_grad_norm,
seed=training_args.seed,
log_with=training_args.report_to,
optimize_cuda_cache=True,
target=finetuning_args.ppo_target,
log_with=finetuning_args.ppo_logger,
use_score_scaling=finetuning_args.ppo_score_norm,
use_score_norm=finetuning_args.ppo_score_norm,
accelerator_kwargs={"step_scheduler_with_optimizer": False}
)
if finetuning_args.ppo_score_norm:
ppo_config.use_score_scaling = True
ppo_config.use_score_norm = True
optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=training_args.learning_rate)
total_train_batch_size = (
training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps * training_args.world_size