refactor model_dtype, fix PPO trainer
Former-commit-id: 3e17ee5afbcb823a7c9a2f91864b3750cd79edb4
This commit is contained in:
@@ -65,10 +65,10 @@ def run_ppo(
|
||||
|
||||
# Initialize our Trainer
|
||||
ppo_trainer = CustomPPOTrainer(
|
||||
model_args=model_args,
|
||||
training_args=training_args,
|
||||
generating_args=generating_args,
|
||||
callbacks=callbacks + [SavePeftModelCallback()],
|
||||
compute_dtype=model_args.compute_dtype,
|
||||
config=ppo_config,
|
||||
model=model,
|
||||
ref_model=None,
|
||||
|
||||
Reference in New Issue
Block a user