@@ -35,6 +35,7 @@ from transformers.utils.versions import require_version
 
 from ...extras.constants import SUPPORTED_CLASS_FOR_S2ATTN
 from ...extras.logging import get_logger
+from ...extras.packages import is_transformers_version_greater_than_4_43
 
 
 if TYPE_CHECKING:
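The helper imported in the hunk above gates the FlashAttention-2 code path further down. A minimal sketch of the kind of check it presumably performs; the real implementation lives in ...extras.packages and may differ in caching and parsing details:

from functools import lru_cache
import importlib.metadata

from packaging import version


@lru_cache
def is_transformers_version_greater_than_4_43() -> bool:
    # compare the installed transformers version against 4.43.0
    return version.parse(importlib.metadata.version("transformers")) >= version.parse("4.43.0")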
@@ -50,14 +51,15 @@ transformers_logger = logging.get_logger(__name__)
 # https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llama/modeling_llama.py
 def llama_attention_forward(
     self: "LlamaAttention",
-    hidden_states: torch.Tensor,
-    attention_mask: Optional[torch.Tensor] = None,
-    position_ids: Optional[torch.LongTensor] = None,
+    hidden_states: "torch.Tensor",
+    attention_mask: Optional["torch.Tensor"] = None,
+    position_ids: Optional["torch.LongTensor"] = None,
     past_key_value: Optional["Cache"] = None,
     output_attentions: bool = False,
-    cache_position: Optional[torch.LongTensor] = None,
+    cache_position: Optional["torch.LongTensor"] = None,
+    position_embeddings: Optional[Tuple["torch.Tensor", "torch.Tensor"]] = None,
     **kwargs,
-) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+) -> Tuple["torch.Tensor", Optional["torch.Tensor"], Optional[Tuple["torch.Tensor"]]]:
     bsz, q_len, _ = hidden_states.size()
 
     query_states: "torch.Tensor" = self.q_proj(hidden_states)
@@ -68,7 +70,11 @@ def llama_attention_forward(
     key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
     value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
 
-    cos, sin = self.rotary_emb(value_states, position_ids)
+    if position_embeddings is None:
+        cos, sin = self.rotary_emb(value_states, position_ids)
+    else:
+        cos, sin = position_embeddings
+
     query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
 
     if past_key_value is not None:
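Why the fallback in the hunk above: the transformers 4.43 line that this patch targets computes the rotary cos/sin once in the model and hands them to every layer as position_embeddings, while older releases expect each attention module to call its own rotary_emb. A minimal sketch of that preference order; resolve_rotary itself is illustrative, only the names it touches come from the patch:

from typing import Optional, Tuple

import torch


def resolve_rotary(
    rotary_emb: torch.nn.Module,
    value_states: torch.Tensor,
    position_ids: torch.Tensor,
    position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    # prefer the cos/sin tensors precomputed by the model (newer releases pass them in)
    if position_embeddings is not None:
        return position_embeddings
    # fall back to the per-layer rotary embedding call used by older releases
    return rotary_emb(value_states, position_ids)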
@@ -130,14 +136,15 @@ def llama_attention_forward(
 # https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llama/modeling_llama.py
 def llama_flash_attention_2_forward(
     self: "LlamaFlashAttention2",
-    hidden_states: torch.Tensor,
-    attention_mask: Optional[torch.Tensor] = None,
-    position_ids: Optional[torch.LongTensor] = None,
+    hidden_states: "torch.Tensor",
+    attention_mask: Optional["torch.Tensor"] = None,
+    position_ids: Optional["torch.LongTensor"] = None,
     past_key_value: Optional["Cache"] = None,
     output_attentions: bool = False,
-    cache_position: Optional[torch.LongTensor] = None,
+    cache_position: Optional["torch.LongTensor"] = None,
+    position_embeddings: Optional[Tuple["torch.Tensor", "torch.Tensor"]] = None,
     **kwargs,
-) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+) -> Tuple["torch.Tensor", Optional["torch.Tensor"], Optional[Tuple["torch.Tensor"]]]:
     # LlamaFlashAttention2 attention does not support output_attentions
     output_attentions = False
 
@@ -151,7 +158,11 @@ def llama_flash_attention_2_forward(
     key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
     value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
 
-    cos, sin = self.rotary_emb(value_states, position_ids)
+    if position_embeddings is None:
+        cos, sin = self.rotary_emb(value_states, position_ids)
+    else:
+        cos, sin = position_embeddings
+
     query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
 
     if past_key_value is not None:
@@ -198,9 +209,24 @@ def llama_flash_attention_2_forward(
         if attention_mask is not None:
             attention_mask = attention_mask[:, :groupsz].repeat(num_groups, 1)
 
-    attn_output: "torch.Tensor" = self._flash_attention_forward(
-        query_states, key_states, value_states, attention_mask, query_states.size(1), dropout=dropout_rate
-    )
+    if is_transformers_version_greater_than_4_43():
+        from transformers.modeling_flash_attention_utils import _flash_attention_forward
+
+        attn_output: "torch.Tensor" = _flash_attention_forward(
+            query_states,
+            key_states,
+            value_states,
+            attention_mask,
+            query_states.size(1),
+            dropout=dropout_rate,
+            sliding_window=getattr(self, "sliding_window", None),
+            use_top_left_mask=self._flash_attn_uses_top_left_mask,
+            is_causal=self.is_causal,
+        )
+    else:
+        attn_output: "torch.Tensor" = self._flash_attention_forward(
+            query_states, key_states, value_states, attention_mask, query_states.size(1), dropout=dropout_rate
+        )
 
     if getattr(self.config, "group_size_ratio", None) and self.training:  # shift back
         attn_output.reshape(bsz, q_len, self.num_heads, self.head_dim)
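For context on the groupsz and num_groups names used above (they are defined earlier in this function, outside the hunk): during shift short attention training the sequence is folded into groups of groupsz tokens, so the padding mask is cropped to one group and tiled per group. A hedged sketch of that mask handling, assuming group_size_ratio comes from the model config as it does elsewhere in this file:

import torch


def fold_attention_mask(attention_mask: torch.Tensor, q_len: int, group_size_ratio: float) -> torch.Tensor:
    # tokens per attention group; q_len is expected to be divisible by this
    groupsz = int(q_len * group_size_ratio)
    num_groups = q_len // groupsz
    # keep the mask for one group and repeat it for every group in the folded batch
    return attention_mask[:, :groupsz].repeat(num_groups, 1)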
@@ -225,14 +251,15 @@ def llama_flash_attention_2_forward(
 # https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llama/modeling_llama.py
 def llama_sdpa_attention_forward(
     self: "LlamaSdpaAttention",
-    hidden_states: torch.Tensor,
-    attention_mask: Optional[torch.Tensor] = None,
-    position_ids: Optional[torch.LongTensor] = None,
+    hidden_states: "torch.Tensor",
+    attention_mask: Optional["torch.Tensor"] = None,
+    position_ids: Optional["torch.LongTensor"] = None,
     past_key_value: Optional["Cache"] = None,
     output_attentions: bool = False,
-    cache_position: Optional[torch.LongTensor] = None,
+    cache_position: Optional["torch.LongTensor"] = None,
+    position_embeddings: Optional[Tuple["torch.Tensor", "torch.Tensor"]] = None,
     **kwargs,
-) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+) -> Tuple["torch.Tensor", Optional["torch.Tensor"], Optional[Tuple["torch.Tensor"]]]:
     if output_attentions:
         transformers_logger.warning_once(
             "SDPA does not support `output_attentions=True`. Falling back to the vanilla attention"
@@ -258,7 +285,11 @@ def llama_sdpa_attention_forward(
     key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
     value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
 
-    cos, sin = self.rotary_emb(value_states, position_ids)
+    if position_embeddings is None:
+        cos, sin = self.rotary_emb(value_states, position_ids)
+    else:
+        cos, sin = position_embeddings
+
     query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
 
     if past_key_value is not None:
@@ -322,7 +353,7 @@ def llama_sdpa_attention_forward(
 
 
 def _apply_llama_patch() -> None:
-    require_version("transformers>=4.41.2,<=4.42.4", "To fix: pip install transformers>=4.41.2,<=4.42.4")
+    require_version("transformers>=4.41.2,<=4.43.4", "To fix: pip install transformers>=4.41.2,<=4.43.4")
     LlamaAttention.forward = llama_attention_forward
     LlamaFlashAttention2.forward = llama_flash_attention_2_forward
     LlamaSdpaAttention.forward = llama_sdpa_attention_forward
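A hedged usage sketch of how the patch above is expected to be triggered: the guard sits outside this diff, so the function name and the shift_attn flag below are illustrative assumptions rather than the module's actual entry point.

from transformers import PretrainedConfig


def maybe_enable_s2attn(config: PretrainedConfig, shift_attn: bool) -> None:
    # only monkey-patch the Llama attention classes for supported model types
    if shift_attn and getattr(config, "model_type", None) in SUPPORTED_CLASS_FOR_S2ATTN:
        setattr(config, "group_size_ratio", 0.25)  # LongLoRA-style attention within 1/4-length groups
        _apply_llama_patch()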