Former-commit-id: f121d5c4f94af9f165132c4309cb9bdc8217d985
This commit is contained in:
hiyouga
2024-06-10 21:24:15 +08:00
parent 0ecf0d51e3
commit 784088db3f
6 changed files with 41 additions and 54 deletions

View File

@@ -6,7 +6,12 @@ from llamafactory.hparams import get_infer_args
from llamafactory.model import load_model, load_tokenizer
TINY_LLAMA = os.environ.get("TINY_LLAMA", "llamafactory/tiny-random-LlamaForCausalLM")
TINY_LLAMA = os.environ.get("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
INFER_ARGS = {
"model_name_or_path": TINY_LLAMA,
"template": "llama3",
}
def test_attention():
@@ -23,13 +28,7 @@ def test_attention():
"fa2": "LlamaFlashAttention2",
}
for requested_attention in attention_available:
model_args, _, finetuning_args, _ = get_infer_args(
{
"model_name_or_path": TINY_LLAMA,
"template": "llama2",
"flash_attn": requested_attention,
}
)
model_args, _, finetuning_args, _ = get_infer_args({"flash_attn": requested_attention, **INFER_ARGS})
tokenizer_module = load_tokenizer(model_args)
model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args)
for module in model.modules():