Former-commit-id: 59f2cbbd52d4646fbd1ba83032bf522ecc49a50f
This commit is contained in:
hiyouga
2023-11-01 23:38:49 +08:00
parent dab8f45033
commit 8d52fb46ca
5 changed files with 33 additions and 24 deletions

View File

@@ -51,10 +51,14 @@ def run_ppo(
)
optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=training_args.learning_rate)
total_train_batch_size = (
training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps * training_args.world_size
)
num_training_steps = training_args.num_train_epochs * math.ceil(len(dataset) / total_train_batch_size)
if training_args.max_steps > 0:
num_training_steps = training_args.max_steps
else:
total_train_batch_size = (
training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps * training_args.world_size
)
num_training_steps = training_args.num_train_epochs * math.ceil(len(dataset) / total_train_batch_size)
lr_scheduler = get_scheduler(
training_args.lr_scheduler_type,
optimizer=optimizer,