Former-commit-id: 95aceebd61d195be5c980a919c12c59b56722898
This commit is contained in:
hiyouga
2024-06-08 01:57:36 +08:00
parent 6d17c59090
commit 1364190a66
3 changed files with 12 additions and 10 deletions

View File

@@ -23,13 +23,15 @@ def test_attention():
"fa2": "LlamaFlashAttention2",
}
for requested_attention in attention_available:
model_args, _, finetuning_args, _ = get_infer_args({
"model_name_or_path": TINY_LLAMA,
"template": "llama2",
"flash_attn": requested_attention,
})
model_args, _, finetuning_args, _ = get_infer_args(
{
"model_name_or_path": TINY_LLAMA,
"template": "llama2",
"flash_attn": requested_attention,
}
)
tokenizer = load_tokenizer(model_args)
model = load_model(tokenizer["tokenizer"], model_args, finetuning_args)
for module in model.modules():
if "Attention" in module.__class__.__name__:
assert module.__class__.__name__ == llama_attention_classes[requested_attention]
assert module.__class__.__name__ == llama_attention_classes[requested_attention]