refactor model_dtype, fix PPO trainer
Former-commit-id: 3e17ee5afbcb823a7c9a2f91864b3750cd79edb4
@@ -67,9 +67,9 @@ class ModelArguments:
         default=None,
         metadata={"help": "Auth token to log in with Hugging Face Hub."}
     )
-    layernorm_dtype: Optional[Literal["auto", "fp16", "bf16", "fp32"]] = field(
-        default="auto",
-        metadata={"help": "Data type of the layer norm weights."}
+    upcast_layernorm: Optional[bool] = field(
+        default=False,
+        metadata={"help": "Whether to upcast the layernorm weights in fp32."}
     )
 
     def __post_init__(self):
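For context, a minimal sketch of how a boolean upcast_layernorm flag like this is typically consumed when preparing a model for training (the helper name, the LAYERNORM_NAMES heuristic, and the call site are assumptions for illustration, not this repository's actual implementation):

from typing import Optional

import torch


LAYERNORM_NAMES = ("norm", "ln")  # assumed name fragments identifying layer norm parameters


def upcast_layernorm_params(
    model: torch.nn.Module, upcast_layernorm: Optional[bool] = False
) -> None:
    # Hypothetical helper: cast layer norm weights to fp32 for numerical stability
    # while the rest of the model stays in fp16/bf16. Half-precision layer norms
    # are a common source of instability in PPO/RLHF-style fine-tuning.
    if not upcast_layernorm:
        return
    for name, param in model.named_parameters():
        if param.ndim == 1 and any(ln_name in name for ln_name in LAYERNORM_NAMES):
            param.data = param.data.to(torch.float32)

Collapsing the old layernorm_dtype literal into a single boolean keeps the option surface small: the norms either follow the model dtype or are upcast to fp32.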