@@ -6,7 +6,12 @@ from llamafactory.hparams import get_infer_args
|
||||
from llamafactory.model import load_model, load_tokenizer
|
||||
|
||||
|
||||
TINY_LLAMA = os.environ.get("TINY_LLAMA", "llamafactory/tiny-random-LlamaForCausalLM")
|
||||
TINY_LLAMA = os.environ.get("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
|
||||
|
||||
INFER_ARGS = {
|
||||
"model_name_or_path": TINY_LLAMA,
|
||||
"template": "llama3",
|
||||
}
|
||||
|
||||
|
||||
def test_attention():
|
||||
@@ -23,13 +28,7 @@ def test_attention():
|
||||
"fa2": "LlamaFlashAttention2",
|
||||
}
|
||||
for requested_attention in attention_available:
|
||||
model_args, _, finetuning_args, _ = get_infer_args(
|
||||
{
|
||||
"model_name_or_path": TINY_LLAMA,
|
||||
"template": "llama2",
|
||||
"flash_attn": requested_attention,
|
||||
}
|
||||
)
|
||||
model_args, _, finetuning_args, _ = get_infer_args({"flash_attn": requested_attention, **INFER_ARGS})
|
||||
tokenizer_module = load_tokenizer(model_args)
|
||||
model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args)
|
||||
for module in model.modules():
|
||||
|
||||
Reference in New Issue
Block a user