refactor model_dtype, fix PPO trainer

Former-commit-id: 3e17ee5afbcb823a7c9a2f91864b3750cd79edb4
Author: hiyouga
Date: 2023-10-11 23:16:01 +08:00
parent a2d08ce961
commit 3198a7e5f4
10 changed files with 104 additions and 119 deletions


@@ -65,10 +65,10 @@ def run_ppo(
     # Initialize our Trainer
     ppo_trainer = CustomPPOTrainer(
+        model_args=model_args,
         training_args=training_args,
         generating_args=generating_args,
         callbacks=callbacks + [SavePeftModelCallback()],
-        compute_dtype=model_args.compute_dtype,
         config=ppo_config,
         model=model,
         ref_model=None,
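
For context: the hunk above replaces the standalone `compute_dtype` keyword with the whole `model_args` object, so the trainer can read the compute dtype (and any other model-level settings) from one place instead of receiving it as a separate argument. The sketch below only illustrates that pattern; `ModelArguments` and `SketchPPOTrainer` are hypothetical stand-ins, not the repository's actual argument dataclass or `CustomPPOTrainer`.

from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class ModelArguments:  # hypothetical stand-in for the real arguments dataclass
    compute_dtype: Optional[torch.dtype] = None  # e.g. torch.float16 or torch.bfloat16


class SketchPPOTrainer:  # illustrative only, not the repository's CustomPPOTrainer
    def __init__(self, model_args: ModelArguments, **kwargs):
        # Keep the whole arguments object and derive the dtype from it,
        # instead of threading a lone compute_dtype keyword through the call site.
        self.model_args = model_args
        self.compute_dtype = model_args.compute_dtype or torch.float32

Keeping `model_args` on the trainer also means fields added to the dataclass later are available without widening the constructor signature again.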