fix ppo callbacks

Former-commit-id: 54f1c67c2a802b1d8368a6d1837d4c9a729f2695
This commit is contained in:
hiyouga
2024-07-02 17:34:56 +08:00
parent 973cf8e980
commit 96a81ce89d
2 changed files with 5 additions and 5 deletions

View File

@@ -70,7 +70,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
training_args: "Seq2SeqTrainingArguments",
finetuning_args: "FinetuningArguments",
generating_args: "GeneratingArguments",
callbacks: List["TrainerCallback"],
callbacks: Optional[List["TrainerCallback"]],
model: "AutoModelForCausalLMWithValueHead",
reward_model: Optional["AutoModelForCausalLMWithValueHead"],
ref_model: Optional["AutoModelForCausalLMWithValueHead"],
@@ -78,7 +78,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
processor: Optional["ProcessorMixin"],
dataset: "Dataset",
data_collator: "DataCollatorWithPadding",
):
) -> None:
backward_batch_size = training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps
ppo_config = PPOConfig(
model_name=model_args.model_name_or_path,
@@ -144,7 +144,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None
self.callback_handler = CallbackHandler(
[callbacks], self.accelerator.unwrap_model(self.model), self.tokenizer, self.optimizer, self.lr_scheduler
callbacks, self.accelerator.unwrap_model(self.model), self.tokenizer, self.optimizer, self.lr_scheduler
)
if self.args.max_steps > 0: