imporve log
Former-commit-id: a6abf375975ffea3d51e1b944c9855b5f62ffac8
This commit is contained in:
@@ -15,9 +15,9 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from transformers.utils import is_flash_attn_2_available, is_torch_sdpa_available
|
||||
from transformers.utils.versions import require_version
|
||||
|
||||
from ...extras import logging
|
||||
from ...extras.misc import check_version
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -35,8 +35,8 @@ def configure_attn_implementation(
|
||||
if getattr(config, "model_type", None) == "gemma2" and is_trainable:
|
||||
if model_args.flash_attn == "auto" or model_args.flash_attn == "fa2":
|
||||
if is_flash_attn_2_available():
|
||||
require_version("transformers>=4.42.4", "To fix: pip install transformers>=4.42.4")
|
||||
require_version("flash_attn>=2.6.3", "To fix: pip install flash_attn>=2.6.3")
|
||||
check_version("transformers>=4.42.4")
|
||||
check_version("flash_attn>=2.6.3")
|
||||
if model_args.flash_attn != "fa2":
|
||||
logger.warning_rank0("Gemma-2 should use flash attention 2, change `flash_attn` to fa2.")
|
||||
model_args.flash_attn = "fa2"
|
||||
|
||||
Reference in New Issue
Block a user