fix layer norm dtype

Former-commit-id: 67af21961b68d9b54d07b09e444c7140869f26da
Author: hiyouga
Date: 2023-09-28 00:25:55 +08:00
Parent: 6c5d8f089e
Commit: 1c150995ae
6 changed files with 28 additions and 22 deletions


@@ -67,6 +67,10 @@ class ModelArguments:
         default=None,
         metadata={"help": "Auth token to log in with Hugging Face Hub."}
     )
+    layernorm_dtype: Optional[Literal["auto", "fp16", "bf16", "fp32"]] = field(
+        default="auto",
+        metadata={"help": "Data type of the layer norm weights."}
+    )
 
     def __post_init__(self):
         self.compute_dtype = None
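
The new option controls the precision of the layer norm weights; keeping them in fp32 is a common trick to stabilize fp16/bf16 training, where layer norm statistics are sensitive to reduced precision. Below is a minimal sketch of how such an option could be applied after model loading. The helper name cast_layernorm_dtype and the name-based parameter matching are assumptions for illustration, not the commit's actual implementation.

import torch

# Hypothetical mapping from the string option to torch dtypes.
LAYERNORM_DTYPE_MAP = {
    "fp16": torch.float16,
    "bf16": torch.bfloat16,
    "fp32": torch.float32,
}

def cast_layernorm_dtype(model: torch.nn.Module, layernorm_dtype: str = "auto") -> torch.nn.Module:
    # "auto" keeps whatever dtype the weights were loaded in.
    if layernorm_dtype == "auto":
        return model
    target = LAYERNORM_DTYPE_MAP[layernorm_dtype]
    for name, param in model.named_parameters():
        # Most HF transformer models include "norm" in the parameter path
        # of LayerNorm/RMSNorm weights (e.g. "input_layernorm.weight").
        if "norm" in name:
            param.data = param.data.to(target)
    return model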