add llamafy_qwen.py

Former-commit-id: 6cdc91543c022edcc98076488f06e809fde9bad7
Author: hiyouga
Date: 2023-10-08 22:05:36 +08:00
parent 728dfb1be7
commit 33af3cbf37
3 changed files with 187 additions and 31 deletions
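The new llamafy_qwen.py script itself is not shown in the hunk below. As a rough sketch of what such a Qwen-to-LLaMA checkpoint converter typically does, the following splits Qwen's fused QKV projection into LLaMA's separate q/k/v projections and renames state-dict keys. All key names here are assumptions based on the public Qwen and LLaMA layouts, not copied from the commit:

```python
# Hypothetical sketch of a Qwen -> LLaMA state-dict converter.
# Key names are assumptions from the public Qwen/LLaMA weight layouts,
# NOT taken from llamafy_qwen.py itself.
import torch


def llamafy_qwen_state_dict(qwen_sd: dict, num_layers: int, hidden_size: int) -> dict:
    llama_sd = {}
    llama_sd["model.embed_tokens.weight"] = qwen_sd["transformer.wte.weight"]
    llama_sd["model.norm.weight"] = qwen_sd["transformer.ln_f.weight"]
    llama_sd["lm_head.weight"] = qwen_sd["lm_head.weight"]
    for i in range(num_layers):
        # Qwen fuses Q, K and V into a single c_attn matrix; LLaMA keeps them separate.
        q, k, v = torch.split(
            qwen_sd[f"transformer.h.{i}.attn.c_attn.weight"], hidden_size, dim=0
        )
        llama_sd[f"model.layers.{i}.self_attn.q_proj.weight"] = q
        llama_sd[f"model.layers.{i}.self_attn.k_proj.weight"] = k
        llama_sd[f"model.layers.{i}.self_attn.v_proj.weight"] = v
        llama_sd[f"model.layers.{i}.self_attn.o_proj.weight"] = qwen_sd[
            f"transformer.h.{i}.attn.c_proj.weight"
        ]
        # ... MLP and layer-norm weights are renamed the same way.
    return llama_sd
```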


@@ -128,7 +128,7 @@ def load_model_and_tokenizer(
     else:
         logger.warning("Current model does not support RoPE scaling.")
 
-    # Set FlashAttention-2
+    # Set FlashAttention-2 and S^2-Attn
     if model_args.flash_attn:
         if getattr(config, "model_type", None) == "llama":
             LlamaModule.LlamaAttention = LlamaPatches.LlamaFlashAttention2
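
For context on the last line of the hunk: FlashAttention-2 is enabled by monkey-patching the attention class on the imported transformers module before the model is instantiated, so every layer is built from the patched class. A minimal sketch of this pattern, where LlamaFlashAttention2 is a hypothetical stand-in for the repo's LlamaPatches.LlamaFlashAttention2:

```python
# Sketch of the monkey-patching pattern used in the hunk: swap the attention
# class on the transformers llama module before from_pretrained() builds the
# model, so all layers pick up the replacement class.
import transformers.models.llama.modeling_llama as LlamaModule


class LlamaFlashAttention2(LlamaModule.LlamaAttention):
    """Hypothetical stand-in for LlamaPatches.LlamaFlashAttention2."""

    def forward(self, *args, **kwargs):
        # A real patch would dispatch to the flash_attn kernels here.
        return super().forward(*args, **kwargs)


# Must run before the model is loaded; layers created afterwards use the patch.
LlamaModule.LlamaAttention = LlamaFlashAttention2
```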