refactor model_dtype, fix PPO trainer
Former-commit-id: 3e17ee5afbcb823a7c9a2f91864b3750cd79edb4
@@ -145,6 +145,9 @@ class Runner:
         )
         args[compute_type] = True
 
+        if args["quantization_bit"] is not None:
+            args["upcast_layernorm"] = True
+
         if args["stage"] == "ppo":
             args["reward_model"] = reward_model
             val_size = 0
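For context, here is a minimal, self-contained sketch of the argument-assembly pattern this hunk touches. The function name build_train_args, its parameters, and the default val_size are assumptions for illustration only and are not the project's actual API; the hunk itself only shows the mixed-precision flag, the layer-norm upcast for quantized models, and the PPO-specific handling.

from typing import Any, Dict, Optional


def build_train_args(
    compute_type: str,                 # e.g. "fp16" or "bf16" (hypothetical parameter)
    quantization_bit: Optional[int],   # e.g. 4, 8, or None
    stage: str,                        # e.g. "sft" or "ppo"
    reward_model: Optional[str] = None,
) -> Dict[str, Any]:
    args: Dict[str, Any] = {
        "quantization_bit": quantization_bit,
        "stage": stage,
    }
    val_size = 0.1  # hypothetical default validation split

    # Enable the selected mixed-precision flag, e.g. args["bf16"] = True.
    args[compute_type] = True

    # Quantized models keep LayerNorm in fp32 for numerical stability.
    if args["quantization_bit"] is not None:
        args["upcast_layernorm"] = True

    # PPO training needs a reward model and uses no validation split.
    if args["stage"] == "ppo":
        args["reward_model"] = reward_model
        val_size = 0

    args["val_size"] = val_size
    return args


if __name__ == "__main__":
    print(build_train_args("bf16", 4, "ppo", reward_model="path/to/reward_model"))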