fix layer norm dtype
Former-commit-id: 67af21961b68d9b54d07b09e444c7140869f26da
@@ -226,6 +226,17 @@ def get_train_args(
     else:
         model_args.compute_dtype = _infer_dtype()
 
+    if model_args.layernorm_dtype == "bf16":
+        if not is_bf16_available:
+            raise ValueError("Current device does not support bf16 type.")
+        model_args.layernorm_dtype = torch.bfloat16
+    elif model_args.layernorm_dtype == "fp16":
+        model_args.layernorm_dtype = torch.float16
+    elif model_args.layernorm_dtype == "fp32":
+        model_args.layernorm_dtype = torch.float32
+    else:
+        model_args.layernorm_dtype = model_args.compute_dtype
+
     model_args.model_max_length = data_args.cutoff_len
 
     # Log on each process the small summary:
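For reference, a minimal standalone sketch of the dtype resolution introduced above: the string option maps to a torch dtype, bf16 is rejected when the device lacks support, and anything else falls back to the compute dtype. The helper name resolve_layernorm_dtype and the _DTYPE_MAP table are illustrative only, not part of the project's API.

import torch

# Illustrative mapping of the accepted string options to torch dtypes.
_DTYPE_MAP = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}

def resolve_layernorm_dtype(name, compute_dtype, bf16_available):
    # Reject bf16 on devices that do not support it, mirroring the check in the diff.
    if name == "bf16" and not bf16_available:
        raise ValueError("Current device does not support bf16 type.")
    # Unknown or unset options fall back to the model's compute dtype.
    return _DTYPE_MAP.get(name, compute_dtype)

# Example: resolve_layernorm_dtype("fp16", torch.float32, bf16_available=False) -> torch.float16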