fix flashattn warning

Former-commit-id: 6eb095d39bd82fdbdb729a0ea57fc7246e3a60d6
hiyouga
2023-11-10 18:34:54 +08:00
parent f0766a2ab0
commit 989eccd286
2 changed files with 10 additions and 4 deletions


@@ -123,9 +123,12 @@ def load_model_and_tokenizer(
     # Set FlashAttention-2
     if model_args.flash_attn:
         if getattr(config, "model_type", None) == "llama":
-            LlamaModule.LlamaAttention = LlamaPatches.LlamaFlashAttention2
-            LlamaModule.LlamaModel._prepare_decoder_attention_mask = LlamaPatches._prepare_decoder_attention_mask
-            logger.info("Using FlashAttention-2 for faster training and inference.")
+            if LlamaPatches.is_flash_attn_2_available:
+                LlamaModule.LlamaAttention = LlamaPatches.LlamaFlashAttention2
+                LlamaModule.LlamaModel._prepare_decoder_attention_mask = LlamaPatches._prepare_decoder_attention_mask
+                logger.info("Using FlashAttention-2 for faster training and inference.")
+            else:
+                logger.warning("FlashAttention-2 is not installed.")
         elif getattr(config, "model_type", None) in ["qwen", "Yi"]:
             logger.info("Current model automatically enables FlashAttention if installed.")
         else:
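
For context, the is_flash_attn_2_available flag that the new branch checks is the usual import-probe pattern: a module-level boolean set once when the patches module is imported, so the monkey-patch is applied only when the flash-attn package is actually installed. A minimal sketch of that pattern is below; the exact contents of LlamaPatches are not shown in this diff, so treat this as an illustrative assumption rather than the repository's code.

    # Sketch of an import-probe flag like LlamaPatches.is_flash_attn_2_available.
    # Assumption: the patches module sets this boolean at import time.
    try:
        from flash_attn import flash_attn_func  # noqa: F401 (probe only, unused here)
        is_flash_attn_2_available = True
    except ImportError:
        is_flash_attn_2_available = False

With the flag in place, callers such as load_model_and_tokenizer can branch on it as in the hunk above, logging a warning instead of silently patching in attention classes whose backing kernels are missing.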