refactor model_dtype, fix PPO trainer

Former-commit-id: 3e17ee5afbcb823a7c9a2f91864b3750cd79edb4
This commit is contained in:
hiyouga
2023-10-11 23:16:01 +08:00
parent a2d08ce961
commit 3198a7e5f4
10 changed files with 104 additions and 119 deletions

View File

@@ -67,9 +67,9 @@ class ModelArguments:
default=None,
metadata={"help": "Auth token to log in with Hugging Face Hub."}
)
layernorm_dtype: Optional[Literal["auto", "fp16", "bf16", "fp32"]] = field(
default="auto",
metadata={"help": "Data type of the layer norm weights."}
upcast_layernorm: Optional[bool] = field(
default=False,
metadata={"help": "Whether to upcast the layernorm weights in fp32."}
)
def __post_init__(self):