fix layer norm dtype

Former-commit-id: 67af21961b68d9b54d07b09e444c7140869f26da
hiyouga
2023-09-28 00:25:55 +08:00
parent 6c5d8f089e
commit 1c150995ae
6 changed files with 28 additions and 22 deletions


@@ -2,7 +2,7 @@ IGNORE_INDEX = -100
 LOG_FILE_NAME = "trainer_log.jsonl"
-LAYERNORM_NAMES = ["norm", "ln_f", "ln_attn", "ln_mlp"]
+LAYERNORM_NAMES = ["norm", "ln_f", "ln_attn", "ln_mlp", "ln_1", "ln_2"]
 METHODS = ["full", "freeze", "lora"]
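Adding "ln_1" and "ln_2" extends the match list to models whose transformer blocks use those layer-norm names (e.g. GPT-2- or Qwen-style blocks). The list is consumed by substring-matching parameter names and upcasting the hits to float32 for stable mixed-precision training; a minimal sketch of that pattern, assuming a loaded model (the helper name cast_layernorm_params_to_fp32 is illustrative, not part of this commit):

import torch

LAYERNORM_NAMES = ["norm", "ln_f", "ln_attn", "ln_mlp", "ln_1", "ln_2"]

def cast_layernorm_params_to_fp32(model: torch.nn.Module) -> None:
    # Substring matching keeps the list model-agnostic: "norm" also catches
    # "input_layernorm" and "post_attention_layernorm" in LLaMA checkpoints.
    for name, param in model.named_parameters():
        if param.ndim == 1 and any(ln_name in name for ln_name in LAYERNORM_NAMES):
            param.data = param.data.to(torch.float32)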


@@ -19,21 +19,6 @@ except ImportError:
 logger = logging.get_logger(__name__)
 
-class LlamaRMSNorm(nn.Module):
-
-    def __init__(self, hidden_size, eps=1e-6):
-        super().__init__()
-        self.weight = nn.Parameter(torch.ones(hidden_size))
-        self.variance_epsilon = eps
-
-    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        input_dtype = hidden_states.dtype
-        hidden_states = hidden_states.to(torch.float32)
-        variance = hidden_states.pow(2).mean(-1, keepdim=True)
-        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
-        return (self.weight * hidden_states).to(input_dtype)
-
 class LlamaShiftShortAttention(LlamaAttention):
 
     def forward(
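The deleted class closely mirrors the upstream LlamaRMSNorm in transformers, so the patch module no longer needs to redefine it; together with the LAYERNORM_NAMES change above, layer-norm dtype appears to be handled by name-based parameter casting instead. The float32 upcast inside RMSNorm matters because squaring activations in half precision overflows quickly (float16 tops out around 65504). A standalone sketch of the failure mode, independent of this commit:

import torch

x = torch.full((4,), 300.0, dtype=torch.float16)

# 300 ** 2 = 90000 exceeds the float16 maximum (~65504), so the variance is inf.
print(x.pow(2).mean(-1, keepdim=True))                    # tensor([inf], dtype=torch.float16)

# Upcasting first, as the RMSNorm forward does, keeps the variance finite.
print(x.to(torch.float32).pow(2).mean(-1, keepdim=True))  # tensor([90000.])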
@@ -162,6 +147,14 @@ class LlamaFlashAttention2(LlamaAttention):
         past_key_value = (key_states, value_states) if use_cache else None
 
+        # cast to half precision
+        input_dtype = query_states.dtype
+        if input_dtype == torch.float32:
+            logger.warning_once("The input hidden states seems to be silently casted in float32.")
+            query_states = query_states.to(torch.float16)
+            key_states = key_states.to(torch.float16)
+            value_states = value_states.to(torch.float16)
+
         if getattr(self, "num_key_value_groups"):
             key_states = repeat_kv(key_states, self.num_key_value_groups)
             value_states = repeat_kv(value_states, self.num_key_value_groups)
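FlashAttention kernels only accept fp16/bf16 inputs, so float32 hidden states (for example, ones produced by the upcast layer norms above) are downcast before the attention call. Hardcoding torch.float16 is fine under fp16 training but would silently drop bf16 activations to fp16; a sketch of one possible variant, not part of this commit, that picks the target dtype from the layer's own projection weights (the helper name flash_attn_target_dtype is illustrative):

import torch

def flash_attn_target_dtype(attn_module: torch.nn.Module) -> torch.dtype:
    # Prefer the dtype the attention projections are stored in (fp16 or bf16);
    # fall back to float16 if the weights are still in float32.
    weight_dtype = attn_module.q_proj.weight.dtype
    if weight_dtype in (torch.float16, torch.bfloat16):
        return weight_dtype
    return torch.float16

With such a helper, query_states.to(flash_attn_target_dtype(self)) would replace the hardcoded torch.float16 casts shown in the hunk above.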