fix shift short attention
Former-commit-id: 9a49cce8e6f6b222f74a07bdab40efee6a77b0f1
@@ -55,46 +55,32 @@ class LlamaShiftShortAttention(LlamaAttention):
         key_states = repeat_kv(key_states, self.num_key_value_groups)
         value_states = repeat_kv(value_states, self.num_key_value_groups)
 
-        if getattr(self, "shift_ratio", None) and self.training: # shift
-            group_size = int(q_len * getattr(self, "shift_ratio"))
-            if q_len % group_size > 0:
-                raise ValueError("q_len {} should be divisible by group size {}.".format(q_len, group_size))
-            num_group = q_len // group_size
-            for state in (query_states, key_states, value_states):
+        if getattr(self.config, "group_size_ratio", None) and self.training: # shift
+            groupsz = int(q_len * getattr(self.config, "group_size_ratio"))
+            assert q_len % groupsz == 0, "q_len {} should be divisible by group size {}.".format(q_len, groupsz)
+            num_groups = q_len // groupsz
+            def shift(state: torch.Tensor) -> torch.Tensor:
                 state = state.transpose(1, 2) # output: (bsz, seq_len, n_heads, head_dim)
-                state[:, :, self.num_heads//2:] = state[:, :, self.num_heads//2:].roll(-group_size//2, dims=1)
-                state = state.reshape(bsz * num_group, group_size, self.num_heads, self.head_dim).transpose(1, 2)
+                state[:, :, self.num_heads//2:] = state[:, :, self.num_heads//2:].roll(-groupsz//2, dims=1)
+                return state.reshape(bsz * num_groups, groupsz, self.num_heads, self.head_dim).transpose(1, 2)
+
+            query_states, key_states, value_states = shift(query_states), shift(key_states), shift(value_states)
+            if attention_mask is not None:
+                attention_mask = attention_mask[:, :, :groupsz, :groupsz].repeat(num_groups, 1, 1, 1)
 
         attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
 
-        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
-            raise ValueError(
-                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
-                f" {attn_weights.size()}"
-            )
-
         if attention_mask is not None:
-            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
-                raise ValueError(
-                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
-                )
             attn_weights = attn_weights + attention_mask
 
         # upcast attention to fp32
         attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
-        attn_output = torch.matmul(attn_weights, value_states)
-
-        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
-            raise ValueError(
-                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
-                f" {attn_output.size()}"
-            )
-
+        attn_output = torch.matmul(attn_weights, value_states) # (bsz, :, seq_len, :) or (bsz*n_group, :, groupsz, :)
         attn_output = attn_output.transpose(1, 2).contiguous()
 
-        if getattr(self, "shift_ratio", None) and self.training: # shift back
+        if getattr(self.config, "group_size_ratio", None) and self.training: # shift back
             attn_output.reshape(bsz, q_len, self.num_heads, self.head_dim)
-            attn_output[:, :, self.num_heads//2:] = attn_output[:, :, self.num_heads//2:].roll(group_size//2, dims=1)
+            attn_output[:, :, self.num_heads//2:] = attn_output[:, :, self.num_heads//2:].roll(groupsz//2, dims=1)
 
         attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
         attn_output = self.o_proj(attn_output)
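For context, the hunk above implements the shifted short attention grouping: the second half of the heads is rolled by half a group so that neighbouring groups overlap, and the sequence is then folded into bsz * num_groups independent blocks of length groupsz. The snippet below is a minimal, self-contained sketch of that shift and of the shift back with made-up sizes; it is not the project's code, and group_size_ratio here is a local stand-in for config.group_size_ratio.

    import torch

    # toy sizes; group_size_ratio stands in for config.group_size_ratio
    bsz, n_heads, seq_len, head_dim = 2, 8, 16, 4
    group_size_ratio = 0.25
    groupsz = int(seq_len * group_size_ratio)   # 4 tokens per group
    num_groups = seq_len // groupsz             # 4 groups

    def shift(state: torch.Tensor) -> torch.Tensor:
        # (bsz, n_heads, seq_len, head_dim) -> (bsz, seq_len, n_heads, head_dim)
        state = state.transpose(1, 2).contiguous()
        # roll the second half of the heads by half a group so adjacent groups overlap
        state[:, :, n_heads // 2:] = state[:, :, n_heads // 2:].roll(-groupsz // 2, dims=1)
        # fold the sequence into groups: (bsz * num_groups, n_heads, groupsz, head_dim)
        return state.reshape(bsz * num_groups, groupsz, n_heads, head_dim).transpose(1, 2)

    def shift_back(output: torch.Tensor) -> torch.Tensor:
        # output: (bsz * num_groups, groupsz, n_heads, head_dim), i.e. attn_output after .transpose(1, 2)
        output = output.reshape(bsz, seq_len, n_heads, head_dim)
        output[:, :, n_heads // 2:] = output[:, :, n_heads // 2:].roll(groupsz // 2, dims=1)
        return output

    q = torch.randn(bsz, n_heads, seq_len, head_dim)
    restored = shift_back(shift(q).transpose(1, 2))
    assert torch.equal(restored, q.transpose(1, 2))  # shifting and shifting back is lossless

Returning the transformed tensor is the important part of the fix: the previous for state in (...) loop only rebound its loop variable, so query_states, key_states and value_states were never actually regrouped.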
@@ -160,19 +146,21 @@ class LlamaFlashAttention2(LlamaAttention):
         key_states = key_states.transpose(1, 2) # (bsz, seq_len, n_heads, head_dim)
         value_states = value_states.transpose(1, 2) # (bsz, seq_len, n_heads, head_dim)
 
-        if getattr(self, "shift_ratio", None) and self.training: # shift
-            group_size = int(q_len * getattr(self, "shift_ratio"))
-            if q_len % group_size > 0:
-                raise ValueError("q_len {} should be divisible by group size {}.".format(q_len, group_size))
-            num_group = q_len // group_size
-            for state in (query_states, key_states, value_states):
-                state[:, :, self.num_heads//2:] = state[:, :, self.num_heads//2:].roll(-group_size//2, dims=1)
-                state = state.reshape(bsz * num_group, group_size, self.num_heads, self.head_dim)
+        if getattr(self.config, "group_size_ratio", None) and self.training: # shift
+            groupsz = int(q_len * getattr(self.config, "group_size_ratio"))
+            assert q_len % groupsz == 0, "q_len {} should be divisible by group size {}.".format(q_len, groupsz)
+            num_groups = q_len // groupsz
+            def shift(state: torch.Tensor) -> torch.Tensor:
+                state[:, :, self.num_heads//2:] = state[:, :, self.num_heads//2:].roll(-groupsz//2, dims=1)
+                return state.reshape(bsz * num_groups, groupsz, self.num_heads, self.head_dim)
+
+            query_states, key_states, value_states = shift(query_states), shift(key_states), shift(value_states)
+            if attention_mask is not None:
+                attention_mask = attention_mask.reshape(bsz * num_groups, groupsz)
 
         if attention_mask is not None:
             logger.warning_once("Padded sequences are less efficient in FlashAttention.")
-            batch_size = query_states.shape[0]
-            # -q_len: assumes left padding
+            # -q_len: assumes left padding when q_len != kv_len
             unpadded_q, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(query_states, attention_mask[:, -q_len:])
             unpadded_k, _, cu_seqlens_k, max_seqlen_k = unpad_input(key_states, attention_mask)
             unpadded_v, _, _, _ = unpad_input(value_states, attention_mask)
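The two code paths also regroup the attention mask differently: the eager path crops its 4-D additive mask to a single group and tiles it over the enlarged batch dimension, while the FlashAttention path above keeps the 2-D padding mask and simply reshapes it. A minimal sketch with made-up sizes (not the project's code):

    import torch

    bsz, q_len = 2, 16
    groupsz, num_groups = 4, 4

    # eager path: additive causal mask of shape (bsz, 1, q_len, kv_seq_len)
    causal_mask = torch.full((bsz, 1, q_len, q_len), float("-inf")).triu(1)
    grouped_causal = causal_mask[:, :, :groupsz, :groupsz].repeat(num_groups, 1, 1, 1)
    assert grouped_causal.shape == (bsz * num_groups, 1, groupsz, groupsz)

    # FlashAttention path: 2-D padding mask of shape (bsz, q_len), one row per sequence
    padding_mask = torch.ones(bsz, q_len, dtype=torch.bool)
    grouped_padding = padding_mask.reshape(bsz * num_groups, groupsz)
    assert grouped_padding.shape == (bsz * num_groups, groupsz)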
@@ -188,15 +176,15 @@ class LlamaFlashAttention2(LlamaAttention):
                 softmax_scale=None,
                 causal=True,
             )
-            attn_output = pad_input(attn_output_unpad, indices_q, batch_size, q_len)
+            attn_output = pad_input(attn_output_unpad, indices_q, bsz, q_len)
         else:
             attn_output = flash_attn_func(
                 query_states, key_states, value_states, 0.0, softmax_scale=None, causal=True
             )
 
-        if getattr(self, "shift_ratio", None) and self.training: # shift back
+        if getattr(self.config, "group_size_ratio", None) and self.training: # shift back
             attn_output.reshape(bsz, q_len, self.num_heads, self.head_dim)
-            attn_output[:, :, self.num_heads//2:] = attn_output[:, :, self.num_heads//2:].roll(group_size//2, dims=1)
+            attn_output[:, :, self.num_heads//2:] = attn_output[:, :, self.num_heads//2:].roll(groupsz//2, dims=1)
 
         attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
         attn_output = self.o_proj(attn_output)
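Finally, the gate for both paths moved from a shift_ratio attribute patched onto the attention module to group_size_ratio read from the model config, and it only fires when self.training is true. An illustrative (not verbatim) way a training script could switch it on, assuming the patched attention classes are already in place:

    from transformers import AutoConfig, AutoModelForCausalLM

    config = AutoConfig.from_pretrained("meta-llama/Llama-2-7b-hf")
    setattr(config, "group_size_ratio", 0.25)  # each group spans a quarter of the sequence
    model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", config=config)
    # the patched attention reads config.group_size_ratio during training and
    # skips the shift entirely at inference time (self.training is False)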