fix layer norm dtype

Former-commit-id: 67af21961b68d9b54d07b09e444c7140869f26da
hiyouga
2023-09-28 00:25:55 +08:00
parent 6c5d8f089e
commit 1c150995ae
6 changed files with 28 additions and 22 deletions


@@ -128,10 +128,6 @@ def load_model_and_tokenizer(
         else:
             logger.warning("Current model does not support RoPE scaling.")
 
-    # Fix RMSNorm in fp32 weight (https://github.com/huggingface/transformers/pull/23535)
-    if getattr(config, "model_type", None) == "llama":
-        LlamaModule.LlamaRMSNorm = LlamaPatches.LlamaRMSNorm
-
     # Set FlashAttention-2
     if model_args.flash_attn:
         if getattr(config, "model_type", None) == "llama":
@@ -205,7 +201,8 @@ def load_model_and_tokenizer(
     tokenizer.__class__.register_for_auto_class()
 
     # Initialize adapters
-    model = prepare_model_for_training(model, finetuning_args.finetuning_type) if is_trainable else model
+    if is_trainable:
+        model = prepare_model_for_training(model, model_args.layernorm_dtype, finetuning_args.finetuning_type)
     model = init_adapter(model, model_args, finetuning_args, is_trainable, is_mergeable)
     model = model.train() if is_trainable else model.eval()
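
Taken together, the two hunks above drop the LLaMA-only fp32 RMSNorm monkey-patch and instead pass a configurable layer-norm dtype into prepare_model_for_training, which now only runs when the model is trainable. A minimal, torch-only sketch of that wiring; the stub objects below are hypothetical stand-ins for the loader's real model, arguments, and helper, not code from the repository:

import torch
import torch.nn as nn

# Hypothetical stand-ins for the objects available inside load_model_and_tokenizer.
model = nn.Sequential(nn.Linear(8, 8))
is_trainable = True
layernorm_dtype = torch.float32   # resolved from model_args.layernorm_dtype in get_train_args
finetuning_type = "lora"

def prepare_model_for_training(model, layernorm_dtype, finetuning_type):
    """Placeholder with the new three-argument signature; the real body is in the last file below."""
    return model

# New call site: no class-level LlamaRMSNorm patch, the dtype travels as an argument,
# and nothing is modified when the model is loaded for inference only.
if is_trainable:
    model = prepare_model_for_training(model, layernorm_dtype, finetuning_type)
model = model.train() if is_trainable else model.eval()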


@@ -226,6 +226,17 @@ def get_train_args(
     else:
         model_args.compute_dtype = _infer_dtype()
 
+    if model_args.layernorm_dtype == "bf16":
+        if not is_bf16_available:
+            raise ValueError("Current device does not support bf16 type.")
+        model_args.layernorm_dtype = torch.bfloat16
+    elif model_args.layernorm_dtype == "fp16":
+        model_args.layernorm_dtype = torch.float16
+    elif model_args.layernorm_dtype == "fp32":
+        model_args.layernorm_dtype = torch.float32
+    else:
+        model_args.layernorm_dtype = model_args.compute_dtype
+
     model_args.model_max_length = data_args.cutoff_len
 
     # Log on each process the small summary:
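
The block added above maps the user-facing string ("bf16", "fp16", "fp32") to an actual torch.dtype, refusing bf16 on hardware that cannot use it and falling back to the model's compute dtype otherwise; is_bf16_available is a flag defined elsewhere in that file. A minimal standalone sketch of the same mapping, assuming the bf16 check is done via torch.cuda.is_bf16_supported():

import torch

def resolve_layernorm_dtype(layernorm_dtype: str, compute_dtype: torch.dtype) -> torch.dtype:
    """Turn the CLI string into a torch.dtype, falling back to the compute dtype."""
    if layernorm_dtype == "bf16":
        # Assumption: stand-in for the parser's is_bf16_available flag.
        if not (torch.cuda.is_available() and torch.cuda.is_bf16_supported()):
            raise ValueError("Current device does not support bf16 type.")
        return torch.bfloat16
    elif layernorm_dtype == "fp16":
        return torch.float16
    elif layernorm_dtype == "fp32":
        return torch.float32
    else:
        return compute_dtype

print(resolve_layernorm_dtype("fp32", torch.float16))  # torch.float32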


@@ -31,6 +31,7 @@ def find_all_linear_modules(
 
 def prepare_model_for_training(
     model: "PreTrainedModel",
+    layernorm_dtype: torch.dtype,
     finetuning_type: str,
     output_layer_name: Optional[str] = "lm_head",
     use_gradient_checkpointing: Optional[bool] = True,
@@ -45,7 +46,7 @@ def prepare_model_for_training(
     """
     for name, param in model.named_parameters():
         if param.ndim == 1 and any(layer_norm_name in name for layer_norm_name in layer_norm_names):
-            param.data = param.data.to(torch.float32)
+            param.data = param.data.to(layernorm_dtype)
 
     if use_gradient_checkpointing:
         if hasattr(model, "enable_input_require_grads"):
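
Finally, prepare_model_for_training now casts every 1-D parameter whose name matches a layer-norm pattern to the requested dtype instead of hard-coding fp32. A hypothetical, torch-only illustration of that loop on a toy module; layer_norm_names is assumed here to hold the usual norm-layer name fragments:

import torch
import torch.nn as nn

layer_norm_names = ["norm", "ln_f", "ln_attn", "ln_mlp"]  # assumed defaults

def cast_layernorm_params(model: nn.Module, layernorm_dtype: torch.dtype) -> nn.Module:
    """Move only the 1-D layer-norm parameters to layernorm_dtype, as in the hunk above."""
    for name, param in model.named_parameters():
        if param.ndim == 1 and any(layer_norm_name in name for layer_norm_name in layer_norm_names):
            param.data = param.data.to(layernorm_dtype)
    return model

class TinyBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)
        self.input_norm = nn.LayerNorm(8)  # name contains "norm", so it is cast

toy = cast_layernorm_params(TinyBlock().to(torch.float16), torch.float32)
print(toy.input_norm.weight.dtype)  # torch.float32
print(toy.proj.weight.dtype)        # torch.float16 (unchanged)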