update ppo trainer

Former-commit-id: caa525a5c6f228b9ad71387d1fe4f1c2ffa2479e
Author: hiyouga
Date: 2023-11-20 21:39:15 +08:00
parent e585950c54
commit 28258aecd2
7 changed files with 68 additions and 41 deletions


@@ -82,13 +82,16 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
             raise ValueError("`resume_from_checkpoint` will be supported in the future version.")
         total_train_batch_size = (
-            self.args.per_device_train_batch_size * self.args.gradient_accumulation_steps * self.args.world_size
+            self.args.per_device_train_batch_size
+            * self.args.gradient_accumulation_steps
+            * self.finetuning_args.ppo_buffer_size
+            * self.args.world_size
         )
         if self.args.max_steps > 0:
             num_examples = total_train_batch_size * self.args.max_steps
             num_train_epochs = sys.maxsize
             max_steps = self.args.max_steps
-            steps_in_epoch = self.args.max_steps * self.args.gradient_accumulation_steps
+            steps_in_epoch = self.args.max_steps
         else:
             len_dataloader = len(self.dataloader)
             num_examples = len(self.dataset)
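The first hunk makes the effective PPO batch size account for finetuning_args.ppo_buffer_size (roughly, the number of rollout mini-batches gathered before each optimization pass) and stops scaling steps_in_epoch by the accumulation factor. A minimal sketch of the resulting arithmetic, standing outside the trainer and using made-up hyperparameter values:

import sys

# Hypothetical values; the real ones come from self.args / self.finetuning_args.
per_device_train_batch_size = 4   # self.args.per_device_train_batch_size
gradient_accumulation_steps = 4   # self.args.gradient_accumulation_steps
ppo_buffer_size = 2               # self.finetuning_args.ppo_buffer_size (newly counted)
world_size = 2                    # self.args.world_size
max_steps = 1000                  # self.args.max_steps

# One optimization step now consumes batch * accumulation * buffer * processes examples.
total_train_batch_size = (
    per_device_train_batch_size
    * gradient_accumulation_steps
    * ppo_buffer_size
    * world_size
)  # 4 * 4 * 2 * 2 = 64

if max_steps > 0:
    num_examples = total_train_batch_size * max_steps  # 64 * 1000 = 64000 prompts
    num_train_epochs = sys.maxsize                      # run until max_steps, not by epoch
    steps_in_epoch = max_steps                          # no longer multiplied by accumulation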
@@ -103,13 +106,16 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         if self.is_world_process_zero():
             logger.info("***** Running training *****")
-            logger.info(f" Num examples = {num_examples}")
-            logger.info(f" Num Epochs = {num_train_epochs}")
-            logger.info(f" Instantaneous batch size per device = {self.args.per_device_train_batch_size}")
-            logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size}")
-            logger.info(f" Gradient Accumulation steps = {self.args.gradient_accumulation_steps}")
-            logger.info(f" Total optimization steps = {max_steps}")
-            logger.info(f" Number of trainable parameters = {count_parameters(self.model)[0]}")
+            logger.info(" Num examples = {}".format(num_examples))
+            logger.info(" Num Epochs = {}".format(num_train_epochs))
+            logger.info(" Instantaneous batch size per device = {}".format(self.args.per_device_train_batch_size))
+            logger.info(" Total train batch size (w. parallel, buffer, distributed & accumulation) = {}".format(
+                total_train_batch_size
+            ))
+            logger.info(" Gradient Accumulation steps = {}".format(self.args.gradient_accumulation_steps))
+            logger.info(" Num optimization epochs per batch = {}".format(self.finetuning_args.ppo_epochs))
+            logger.info(" Total training steps = {}".format(max_steps))
+            logger.info(" Number of trainable parameters = {}".format(count_parameters(self.model)[0]))
         unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
         dataiter = iter(self.dataloader)
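The second hunk switches the startup banner to str.format, mentions the buffer in the total-batch-size label, and additionally reports finetuning_args.ppo_epochs as the number of optimization epochs per batch. With the hypothetical values from the sketch above and, say, ppo_epochs = 4, the banner would read roughly as follows (Num Epochs prints sys.maxsize because max_steps drives termination; every number here is illustrative, only the field names come from the diff):

***** Running training *****
 Num examples = 64000
 Num Epochs = 9223372036854775807
 Instantaneous batch size per device = 4
 Total train batch size (w. parallel, buffer, distributed & accumulation) = 64
 Gradient Accumulation steps = 4
 Num optimization epochs per batch = 4
 Total training steps = 1000
 Number of trainable parameters = ...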