@@ -23,7 +23,6 @@ from typing import TYPE_CHECKING, Optional
 import torch
 import torch.nn as nn
 import transformers
-from transformers.models.llama.modeling_llama import Cache, apply_rotary_pos_emb, repeat_kv
 
 from ...extras import logging
 from ...extras.constants import SUPPORTED_CLASS_FOR_S2ATTN
@@ -32,7 +31,15 @@ from ...extras.packages import is_transformers_version_greater_than
 
 
-if not is_transformers_version_greater_than("4.48.0"):
-    from transformers.models.llama.modeling_llama import LlamaAttention, LlamaFlashAttention2, LlamaSdpaAttention
+from transformers.modeling_flash_attention_utils import _flash_attention_forward
+from transformers.models.llama.modeling_llama import (
+    Cache,
+    LlamaAttention,
+    LlamaFlashAttention2,
+    LlamaSdpaAttention,
+    apply_rotary_pos_emb,
+    repeat_kv,
+)
 
 
 if TYPE_CHECKING:
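
Reviewer note on the re-grouped imports: `LlamaFlashAttention2` and `LlamaSdpaAttention` were removed in transformers 4.48, which is what the dropped `is_transformers_version_greater_than("4.48.0")` guard protected against at import time. With the guard gone, these imports raise `ImportError` the moment the module is loaded on 4.48 or newer, before the `check_version` call in `_apply_llama_patch` below can fire, so the `<4.48.0` ceiling has to be guaranteed upstream, presumably by the project's pinned requirements. A minimal sketch of the gate pattern being removed, assuming a helper equivalent to `extras.packages.is_transformers_version_greater_than`:

```python
# Hedged sketch: a cached version gate comparing the installed transformers
# release against a threshold, as the removed guard relied on.
from functools import lru_cache
from importlib.metadata import version

from packaging.version import Version


@lru_cache
def is_transformers_version_greater_than(content: str) -> bool:
    return Version(version("transformers")) >= Version(content)


# the pattern the hunk drops: only import symbols that still exist
if not is_transformers_version_greater_than("4.48.0"):
    from transformers.models.llama.modeling_llama import LlamaAttention  # noqa: F401
```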
@@ -206,9 +213,6 @@ def llama_flash_attention_2_forward(
         if attention_mask is not None:
             attention_mask = attention_mask[:, :groupsz].repeat(num_groups, 1)
 
-    if is_transformers_version_greater_than("4.43.0"):
-        from transformers.modeling_flash_attention_utils import _flash_attention_forward
-
     attn_output: torch.Tensor = _flash_attention_forward(
         query_states,
         key_states,
@@ -220,10 +224,6 @@ def llama_flash_attention_2_forward(
         use_top_left_mask=self._flash_attn_uses_top_left_mask,
         is_causal=self.is_causal,
     )
-    else:
-        attn_output: torch.Tensor = self._flash_attention_forward(
-            query_states, key_states, value_states, attention_mask, query_states.size(1), dropout=dropout_rate
-        )
 
     if getattr(self.config, "group_size_ratio", None) and self.training:  # shift back
         attn_output.reshape(bsz, q_len, self.num_heads, self.head_dim)
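
Together with the previous hunk, this flattens the version fork around the flash-attention call: `_flash_attention_forward` is now imported once at module level, and the pre-4.43 `self._flash_attention_forward` fallback disappears. Two observations. First, the call body must de-indent out of the removed `if`, yet no such lines show as changed, so this view is apparently rendered with whitespace changes hidden. Second, in the surviving "shift back" context, the result of `attn_output.reshape(...)` is discarded (`Tensor.reshape` is not in-place), so that statement is a no-op as quoted. For readers landing here without the rest of the function, a sketch of the shifted-group trick this forward implements (LongLoRA's S2-Attn); the names and the 1/4 group ratio are illustrative:

```python
# Illustrative sketch of S2-Attn regrouping, assuming states shaped
# (bsz, q_len, n_heads, head_dim); the "shift back" branch above undoes this.
import torch


def shift_and_group(states: torch.Tensor, group_size_ratio: float = 0.25) -> torch.Tensor:
    bsz, q_len, n_heads, head_dim = states.shape
    groupsz = int(q_len * group_size_ratio)
    assert groupsz > 0 and q_len % groupsz == 0, "q_len must split into equal groups"
    num_groups = q_len // groupsz
    states = states.clone()
    # shift the second half of the heads by half a group so tokens near a
    # boundary can still mix information across neighboring groups
    states[:, :, n_heads // 2 :] = states[:, :, n_heads // 2 :].roll(-groupsz // 2, dims=1)
    # fold each group into the batch dimension; attention then stays local
    return states.reshape(bsz * num_groups, groupsz, n_heads, head_dim)
```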
@@ -350,7 +350,7 @@ def llama_sdpa_attention_forward(
 
 
 def _apply_llama_patch() -> None:
-    check_version("transformers>=4.43.0,<4.48.0", mandatory=True)
+    check_version("transformers>=4.45.0,<4.48.0", mandatory=True)
     LlamaAttention.forward = llama_attention_forward
     LlamaFlashAttention2.forward = llama_flash_attention_2_forward
     LlamaSdpaAttention.forward = llama_sdpa_attention_forward
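
Raising the floor from 4.43 to 4.45 is consistent with dropping the 4.43 conditionals above; whatever the removed fallback used to cover is now simply outside the supported window. Note the patch mechanism itself: rebinding `forward` on the attention classes reroutes every existing and future instance, since Python resolves methods through the class at call time. An illustrative stand-in for the gate (the real `check_version` lives in the project's extras; this sketch only mirrors the behavior the hunk relies on):

```python
# Hedged sketch: validate the installed transformers release against a
# PEP 508 requirement string, raising on mismatch like mandatory=True above.
from importlib.metadata import version

from packaging.requirements import Requirement


def check_version(requirement: str, mandatory: bool = False) -> None:
    req = Requirement(requirement)
    installed = version(req.name)
    if installed not in req.specifier:
        message = f"{req.name}=={installed} found, but {requirement} is required"
        if mandatory:
            raise RuntimeError(message)
        print(f"WARNING: {message}")


check_version("transformers>=4.45.0,<4.48.0", mandatory=True)
```

The remaining hunks belong to a second file, the sequence-packing module (going by `configure_packing` and `get_unpad_data`, presumably a `packing.py` alongside this one); the diff view gives no file header, so the jump back to line 43 marks the file boundary.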
@@ -43,11 +43,6 @@ import torch
 import torch.nn.functional as F
 
 from ...extras import logging
-from ...extras.packages import is_transformers_version_greater_than
-
-
-if is_transformers_version_greater_than("4.43.0"):
-    import transformers.modeling_flash_attention_utils
 
 
 if TYPE_CHECKING:
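
With transformers now pinned at 4.45 or above, `transformers.modeling_flash_attention_utils` always exists, so the 4.43 gate is dead weight and the import can move next to its single use in `configure_packing` below. One detail worth keeping in mind for the next hunk: the patch must rebind the attribute on the module object, because transformers' own unpadding code resolves `_get_unpad_data` through that module's globals at call time. A minimal illustration, with `fake_unpad_data` as a hypothetical stand-in:

```python
# Rebinding the module attribute is visible to callers inside transformers;
# rebinding a name created by "from x import y" in our module would not be.
import transformers.modeling_flash_attention_utils as fa_utils


def fake_unpad_data(attention_mask):  # hypothetical stand-in
    raise NotImplementedError


fa_utils._get_unpad_data = fake_unpad_data
```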
@@ -116,5 +111,7 @@ def configure_packing(model_args: "ModelArguments", is_trainable: bool) -> None:
     if not is_trainable or not model_args.block_diag_attn:
         return
 
+    import transformers.modeling_flash_attention_utils
+
     transformers.modeling_flash_attention_utils._get_unpad_data = get_unpad_data
     logger.info_rank0("Using block diagonal attention for sequence packing without cross-attention.")
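
For context on the contract being patched: flash-attention's variable-length path expects `_get_unpad_data` to return `(indices, cu_seqlens, max_seqlen_in_batch)` describing each sequence once padding is stripped. A hedged sketch of a block-diagonal replacement, assuming the packed `attention_mask` labels sub-sequences 1, 2, 3, ... with 0 for padding (the convention the log message refers to); treating every label as its own sequence is what prevents cross-attention between packed samples:

```python
# Sketch only: mirrors the (indices, cu_seqlens, max_seqlen) contract; the
# segment-labeled mask is an assumption, and empty masks are not handled.
from typing import Tuple

import torch
import torch.nn.functional as F


def get_unpad_data(attention_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, int]:
    max_num = int(attention_mask.max().item())
    # length of every labeled segment, sample by sample: (bsz, max_num)
    counts = torch.stack([(attention_mask == i + 1).sum(dim=-1) for i in range(max_num)], dim=-1)
    seqlens = counts.flatten()
    seqlens = seqlens[seqlens.nonzero().squeeze(-1)].to(torch.int32)  # drop empty slots
    # flat positions of non-padding tokens, plus cumulative segment lengths
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))
    return indices, cu_seqlens, int(seqlens.max().item())
```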