fix ci
Former-commit-id: 95aceebd61d195be5c980a919c12c59b56722898
This commit is contained in:
@@ -23,13 +23,15 @@ def test_attention():
|
||||
"fa2": "LlamaFlashAttention2",
|
||||
}
|
||||
for requested_attention in attention_available:
|
||||
model_args, _, finetuning_args, _ = get_infer_args({
|
||||
"model_name_or_path": TINY_LLAMA,
|
||||
"template": "llama2",
|
||||
"flash_attn": requested_attention,
|
||||
})
|
||||
model_args, _, finetuning_args, _ = get_infer_args(
|
||||
{
|
||||
"model_name_or_path": TINY_LLAMA,
|
||||
"template": "llama2",
|
||||
"flash_attn": requested_attention,
|
||||
}
|
||||
)
|
||||
tokenizer = load_tokenizer(model_args)
|
||||
model = load_model(tokenizer["tokenizer"], model_args, finetuning_args)
|
||||
for module in model.modules():
|
||||
if "Attention" in module.__class__.__name__:
|
||||
assert module.__class__.__name__ == llama_attention_classes[requested_attention]
|
||||
assert module.__class__.__name__ == llama_attention_classes[requested_attention]
|
||||
Reference in New Issue
Block a user