clean code

commit 0a75224f62
parent 04d7629abf
Author: hiyouga
Date:   2024-06-13 01:58:16 +08:00
Former-commit-id: f54cafd5c7f0383370d1a2f357834a61a97397ce

4 changed files with 17 additions and 27 deletions


@@ -1,7 +1,8 @@
 from typing import TYPE_CHECKING
+from transformers.utils import is_flash_attn_2_available, is_torch_sdpa_available
 from ...extras.logging import get_logger
-from ...extras.packages import is_flash_attn2_available, is_sdpa_available
 if TYPE_CHECKING:
@@ -21,13 +22,13 @@ def configure_attn_implementation(config: "PretrainedConfig", model_args: "Model
         requested_attn_implementation = "eager"
     elif model_args.flash_attn == "sdpa":
-        if not is_sdpa_available():
+        if not is_torch_sdpa_available():
             logger.warning("torch>=2.1.1 is required for SDPA attention.")
             return
         requested_attn_implementation = "sdpa"
     elif model_args.flash_attn == "fa2":
-        if not is_flash_attn2_available():
+        if not is_flash_attn_2_available():
             logger.warning("FlashAttention-2 is not installed.")
             return
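
For reference, both probes used after this change ship with transformers itself (transformers.utils.is_torch_sdpa_available and transformers.utils.is_flash_attn_2_available), which is what lets the custom wrappers in ...extras.packages be dropped. Below is a minimal sketch of the selection logic the hunk implements, written as a hypothetical standalone helper: the real configure_attn_implementation works on the config and model_args objects, so the function name, signature, and return-value convention here are illustrative assumptions, not the repository's API.

import logging

from transformers.utils import is_flash_attn_2_available, is_torch_sdpa_available

logger = logging.getLogger(__name__)


def resolve_attn_implementation(flash_attn: str):
    # Illustrative helper, not LLaMA-Factory's API: map the user's
    # flash_attn choice to an attn_implementation string, returning
    # None (keep the model's default) when the backend is unavailable.
    if flash_attn == "disabled":
        return "eager"
    if flash_attn == "sdpa":
        if not is_torch_sdpa_available():
            logger.warning("torch>=2.1.1 is required for SDPA attention.")
            return None
        return "sdpa"
    if flash_attn == "fa2":
        if not is_flash_attn_2_available():
            logger.warning("FlashAttention-2 is not installed.")
            return None
        return "flash_attention_2"  # the attn_implementation key transformers expects for FA2
    return None  # "auto" or anything else: let transformers decide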