Former-commit-id: 43a56cb331fae899ca35b0c312730d4ab79d0c42
This commit is contained in:
hiyouga
2024-07-15 01:04:56 +08:00
parent 68365045b4
commit e4d11a117b
18 changed files with 46 additions and 41 deletions

View File

@@ -77,9 +77,13 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
ref_model: Optional["AutoModelForCausalLMWithValueHead"],
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
dataset: "Dataset",
data_collator: "DataCollatorWithPadding",
train_dataset: Optional["Dataset"] = None,
eval_dataset: Optional["Dataset"] = None,
) -> None:
if eval_dataset is not None:
raise NotImplementedError("PPOTrainer does not support eval dataset yet.")
backward_batch_size = training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps
ppo_config = PPOConfig(
model_name=model_args.model_name_or_path,
@@ -115,7 +119,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
num_training_steps = training_args.max_steps
else:
total_train_batch_size = backward_batch_size * finetuning_args.ppo_buffer_size * training_args.world_size
num_training_steps = training_args.num_train_epochs * math.ceil(len(dataset) / total_train_batch_size)
num_training_steps = training_args.num_train_epochs * math.ceil(len(train_dataset) / total_train_batch_size)
optimizer = self.create_optimizer(model, training_args, finetuning_args)
scheduler = self.create_scheduler(training_args, num_training_steps, optimizer)
@@ -126,7 +130,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
model=model,
ref_model=ref_model,
tokenizer=tokenizer,
dataset=dataset,
dataset=train_dataset,
data_collator=data_collator,
lr_scheduler=scheduler,
)