refactor model_dtype, fix PPO trainer

Former-commit-id: 3e17ee5afbcb823a7c9a2f91864b3750cd79edb4
This commit is contained in:
hiyouga
2023-10-11 23:16:01 +08:00
parent a2d08ce961
commit 3198a7e5f4
10 changed files with 104 additions and 119 deletions

View File

@@ -145,6 +145,9 @@ class Runner:
)
args[compute_type] = True
if args["quantization_bit"] is not None:
args["upcast_layernorm"] = True
if args["stage"] == "ppo":
args["reward_model"] = reward_model
val_size = 0