add llamafy_qwen.py
Former-commit-id: 6cdc91543c022edcc98076488f06e809fde9bad7
@@ -128,7 +128,7 @@ def load_model_and_tokenizer(
         else:
             logger.warning("Current model does not support RoPE scaling.")
 
-    # Set FlashAttention-2
+    # Set FlashAttention-2 and S^2-Attn
     if model_args.flash_attn:
         if getattr(config, "model_type", None) == "llama":
             LlamaModule.LlamaAttention = LlamaPatches.LlamaFlashAttention2
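
The hunk above enables FlashAttention-2 by rebinding a module-level class before the model is constructed, so every attention layer built afterwards comes from the patched class. Below is a minimal sketch of that monkey-patching pattern; BaseAttention, FlashAttention2, attention_module and build_layer are hypothetical stand-ins for illustration, not the real transformers or LLaMA-Factory internals.

# Sketch of the class-rebinding pattern used in the hunk above.
# All names here are hypothetical stand-ins, not real library APIs.
import types


class BaseAttention:
    """Stand-in for the stock attention implementation."""

    def forward(self, hidden_states):
        return hidden_states


class FlashAttention2(BaseAttention):
    """Stand-in for a FlashAttention-2 backed implementation."""

    def forward(self, hidden_states):
        # A real patch would dispatch to the fused flash-attention kernel here.
        return hidden_states


# Fake "module" whose Attention attribute model-building code looks up lazily.
attention_module = types.SimpleNamespace(Attention=BaseAttention)


def build_layer(flash_attn: bool):
    if flash_attn:
        # Rebinding the module attribute means every layer constructed from now
        # on is instantiated from the patched class, analogous to
        # LlamaModule.LlamaAttention = LlamaPatches.LlamaFlashAttention2 above.
        attention_module.Attention = FlashAttention2
    return attention_module.Attention()


print(type(build_layer(flash_attn=True)).__name__)  # FlashAttention2

The patch works only because the model-building code resolves the attention class through the module attribute at construction time rather than holding a direct reference to the original class.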