fix layer norm dtype
Former-commit-id: 67af21961b68d9b54d07b09e444c7140869f26da
@@ -2,7 +2,7 @@ IGNORE_INDEX = -100
 
 LOG_FILE_NAME = "trainer_log.jsonl"
 
-LAYERNORM_NAMES = ["norm", "ln_f", "ln_attn", "ln_mlp"]
+LAYERNORM_NAMES = ["norm", "ln_f", "ln_attn", "ln_mlp", "ln_1", "ln_2"]
 
 METHODS = ["full", "freeze", "lora"]
 
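Note: a name list like LAYERNORM_NAMES is typically consumed by upcasting the matching layer-norm parameters to float32 for training stability; the new "ln_1"/"ln_2" entries match GPT-2-style block names. A minimal sketch of that pattern (the helper name cast_layernorms_to_fp32 is hypothetical and not part of this commit):

import torch
from torch import nn

LAYERNORM_NAMES = ["norm", "ln_f", "ln_attn", "ln_mlp", "ln_1", "ln_2"]

def cast_layernorms_to_fp32(model: nn.Module) -> nn.Module:
    # Upcast every 1-D parameter whose name matches a known layer-norm pattern.
    for name, param in model.named_parameters():
        if param.ndim == 1 and any(ln_name in name for ln_name in LAYERNORM_NAMES):
            param.data = param.data.to(torch.float32)
    return model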
@@ -19,21 +19,6 @@ except ImportError:
 logger = logging.get_logger(__name__)
 
 
-class LlamaRMSNorm(nn.Module):
-
-    def __init__(self, hidden_size, eps=1e-6):
-        super().__init__()
-        self.weight = nn.Parameter(torch.ones(hidden_size))
-        self.variance_epsilon = eps
-
-    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        input_dtype = hidden_states.dtype
-        hidden_states = hidden_states.to(torch.float32)
-        variance = hidden_states.pow(2).mean(-1, keepdim=True)
-        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
-        return (self.weight * hidden_states).to(input_dtype)
-
-
 class LlamaShiftShortAttention(LlamaAttention):
 
     def forward(
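Note: the removed class follows the standard RMSNorm recipe of computing the variance in float32 and casting the result back to the input dtype, which is the same behavior as the upstream transformers LlamaRMSNorm, so the patch can presumably rely on the upstream class instead of a local copy. A minimal sketch of that dtype round-trip (illustrative only, not part of this commit):

import torch
from transformers.models.llama.modeling_llama import LlamaRMSNorm

norm = LlamaRMSNorm(hidden_size=8, eps=1e-6).to(torch.float16)
x = torch.randn(2, 4, 8, dtype=torch.float16)
y = norm(x)
# The variance is computed in float32 internally, but the output keeps the input dtype.
assert y.dtype == torch.float16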
@@ -162,6 +147,14 @@ class LlamaFlashAttention2(LlamaAttention):
 
         past_key_value = (key_states, value_states) if use_cache else None
 
+        # cast to half precision
+        input_dtype = query_states.dtype
+        if input_dtype == torch.float32:
+            logger.warning_once("The input hidden states seems to be silently casted in float32.")
+            query_states = query_states.to(torch.float16)
+            key_states = key_states.to(torch.float16)
+            value_states = value_states.to(torch.float16)
+
         if getattr(self, "num_key_value_groups"):
             key_states = repeat_kv(key_states, self.num_key_value_groups)
             value_states = repeat_kv(value_states, self.num_key_value_groups)
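Note: FlashAttention kernels only accept float16/bfloat16 inputs, so float32 activations (e.g. from upcast layer norms) must be downcast before the kernel call, which is what the added block does. A hedged sketch of the same guard as a standalone helper (the function name and the bfloat16 option are assumptions, not part of this commit):

import torch

def ensure_half_precision(*tensors: torch.Tensor, prefer_bf16: bool = False):
    # FlashAttention supports only fp16/bf16; downcast any fp32 tensor before the kernel call.
    target = torch.bfloat16 if prefer_bf16 else torch.float16
    return tuple(t.to(target) if t.dtype == torch.float32 else t for t in tensors)

# Usage inside an attention forward, e.g.:
# query_states, key_states, value_states = ensure_half_precision(
#     query_states, key_states, value_states
# )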