fix flashattn warning
Former-commit-id: 6eb095d39bd82fdbdb729a0ea57fc7246e3a60d6
@@ -5,11 +5,14 @@ from typing import Optional, Tuple
from transformers.utils import logging
from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb, repeat_kv

is_flash_attn_2_available = False

try:
    from flash_attn import flash_attn_func, flash_attn_varlen_func  # type: ignore
    from flash_attn.bert_padding import pad_input, unpad_input  # type: ignore
    is_flash_attn_2_available = True
except ImportError:
    print("FlashAttention-2 is not installed, ignore this if you are not using FlashAttention.")
    is_flash_attn_2_available = False


logger = logging.get_logger(__name__)
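The hunk wraps the flash_attn imports in a try/except and records the outcome in a module-level flag instead of failing or warning unconditionally. Below is a minimal, self-contained sketch of how such an availability flag might be consumed downstream; the helper name use_flash_attention is an assumption for illustration and is not code from this repository.

# Sketch: guarded import plus a downstream check of the availability flag.
try:
    from flash_attn import flash_attn_func  # noqa: F401  # type: ignore
    is_flash_attn_2_available = True
except ImportError:
    is_flash_attn_2_available = False

def use_flash_attention(requested: bool) -> bool:
    # Warn only when FlashAttention was explicitly requested but is unavailable,
    # then fall back to the default attention implementation.
    if requested and not is_flash_attn_2_available:
        print("FlashAttention-2 is not installed, falling back to the default attention.")
        return False
    return requested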