update ppo trainer

Former-commit-id: caa525a5c6f228b9ad71387d1fe4f1c2ffa2479e
This commit is contained in:
hiyouga
2023-11-20 21:39:15 +08:00
parent e585950c54
commit 28258aecd2
7 changed files with 68 additions and 41 deletions

View File

@@ -82,13 +82,16 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
raise ValueError("`resume_from_checkpoint` will be supported in the future version.")
total_train_batch_size = (
self.args.per_device_train_batch_size * self.args.gradient_accumulation_steps * self.args.world_size
self.args.per_device_train_batch_size
* self.args.gradient_accumulation_steps
* self.finetuning_args.ppo_buffer_size
* self.args.world_size
)
if self.args.max_steps > 0:
num_examples = total_train_batch_size * self.args.max_steps
num_train_epochs = sys.maxsize
max_steps = self.args.max_steps
steps_in_epoch = self.args.max_steps * self.args.gradient_accumulation_steps
steps_in_epoch = self.args.max_steps
else:
len_dataloader = len(self.dataloader)
num_examples = len(self.dataset)
@@ -103,13 +106,16 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
if self.is_world_process_zero():
logger.info("***** Running training *****")
logger.info(f" Num examples = {num_examples}")
logger.info(f" Num Epochs = {num_train_epochs}")
logger.info(f" Instantaneous batch size per device = {self.args.per_device_train_batch_size}")
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size}")
logger.info(f" Gradient Accumulation steps = {self.args.gradient_accumulation_steps}")
logger.info(f" Total optimization steps = {max_steps}")
logger.info(f" Number of trainable parameters = {count_parameters(self.model)[0]}")
logger.info(" Num examples = {}".format(num_examples))
logger.info(" Num Epochs = {}".format(num_train_epochs))
logger.info(" Instantaneous batch size per device = {}".format(self.args.per_device_train_batch_size))
logger.info(" Total train batch size (w. parallel, buffer, distributed & accumulation) = {}".format(
total_train_batch_size
))
logger.info(" Gradient Accumulation steps = {}".format(self.args.gradient_accumulation_steps))
logger.info(" Num optimization epochs per batch = {}".format(self.finetuning_args.ppo_epochs))
logger.info(" Total training steps = {}".format(max_steps))
logger.info(" Number of trainable parameters = {}".format(count_parameters(self.model)[0]))
unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
dataiter = iter(self.dataloader)

View File

@@ -39,11 +39,12 @@ def run_ppo(
reward_model = create_reward_model(model, model_args, finetuning_args)
# Create ppo config
backward_batch_size = training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps
ppo_config = PPOConfig(
model_name=model_args.model_name_or_path,
learning_rate=training_args.learning_rate,
mini_batch_size=training_args.per_device_train_batch_size,
batch_size=training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps,
batch_size=backward_batch_size * finetuning_args.ppo_buffer_size,
gradient_accumulation_steps=training_args.gradient_accumulation_steps,
ppo_epochs=finetuning_args.ppo_epochs,
max_grad_norm=training_args.max_grad_norm,
@@ -62,9 +63,7 @@ def run_ppo(
if training_args.max_steps > 0:
num_training_steps = training_args.max_steps
else:
total_train_batch_size = (
training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps * training_args.world_size
)
total_train_batch_size = backward_batch_size * finetuning_args.ppo_buffer_size * training_args.world_size
num_training_steps = training_args.num_train_epochs * math.ceil(len(dataset) / total_train_batch_size)
lr_scheduler = get_scheduler(