mirror of
https://github.com/hiyouga/LlamaFactory.git
synced 2026-02-02 20:43:38 +00:00
[deps] adapt to transformers v5 (#10147)
Co-authored-by: frozenleaves <frozen@Mac.local> Co-authored-by: hiyouga <hiyouga@buaa.edu.cn>
This commit is contained in:
@@ -18,7 +18,6 @@ Contains shared fixtures, pytest configuration, and custom markers.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
@@ -149,14 +148,7 @@ def _manage_distributed_env(request: FixtureRequest, monkeypatch: MonkeyPatch) -
|
||||
devices_str = ",".join(str(i) for i in range(required))
|
||||
|
||||
monkeypatch.setenv(env_key, devices_str)
|
||||
|
||||
# add project root dir to path for mp run
|
||||
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||
if project_root not in sys.path:
|
||||
sys.path.insert(0, project_root)
|
||||
|
||||
os.environ["PYTHONPATH"] = project_root + os.pathsep + os.environ.get("PYTHONPATH", "")
|
||||
|
||||
monkeypatch.syspath_prepend(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||||
else: # non-distributed test
|
||||
if old_value:
|
||||
visible_devices = [v for v in old_value.split(",") if v != ""]
|
||||
|
||||
@@ -20,6 +20,7 @@ from datasets import load_dataset
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from llamafactory.extras.constants import IGNORE_INDEX
|
||||
from llamafactory.extras.packages import is_transformers_version_greater_than
|
||||
from llamafactory.train.test_utils import load_dataset_module
|
||||
|
||||
|
||||
@@ -52,7 +53,12 @@ def test_feedback_data(num_samples: int):
|
||||
for index in indexes:
|
||||
messages = original_data["messages"][index]
|
||||
ref_input_ids = ref_tokenizer.apply_chat_template(messages)
|
||||
prompt_len = len(ref_tokenizer.apply_chat_template(messages[:-1], add_generation_prompt=True))
|
||||
ref_prompt_ids = ref_tokenizer.apply_chat_template(messages[:-1], add_generation_prompt=True)
|
||||
if is_transformers_version_greater_than("5.0.0"):
|
||||
ref_input_ids = ref_input_ids["input_ids"]
|
||||
ref_prompt_ids = ref_prompt_ids["input_ids"]
|
||||
|
||||
prompt_len = len(ref_prompt_ids)
|
||||
ref_labels = [IGNORE_INDEX] * prompt_len + ref_input_ids[prompt_len:]
|
||||
assert train_dataset["input_ids"][index] == ref_input_ids
|
||||
assert train_dataset["labels"][index] == ref_labels
|
||||
|
||||
@@ -20,6 +20,7 @@ from datasets import load_dataset
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from llamafactory.extras.constants import IGNORE_INDEX
|
||||
from llamafactory.extras.packages import is_transformers_version_greater_than
|
||||
from llamafactory.train.test_utils import load_dataset_module
|
||||
|
||||
|
||||
@@ -63,13 +64,21 @@ def test_pairwise_data(num_samples: int):
|
||||
rejected_messages = original_data["conversations"][index] + [original_data["rejected"][index]]
|
||||
chosen_messages = _convert_sharegpt_to_openai(chosen_messages)
|
||||
rejected_messages = _convert_sharegpt_to_openai(rejected_messages)
|
||||
|
||||
ref_chosen_input_ids = ref_tokenizer.apply_chat_template(chosen_messages)
|
||||
chosen_prompt_len = len(ref_tokenizer.apply_chat_template(chosen_messages[:-1], add_generation_prompt=True))
|
||||
ref_chosen_labels = [IGNORE_INDEX] * chosen_prompt_len + ref_chosen_input_ids[chosen_prompt_len:]
|
||||
ref_chosen_prompt_ids = ref_tokenizer.apply_chat_template(chosen_messages[:-1], add_generation_prompt=True)
|
||||
ref_rejected_input_ids = ref_tokenizer.apply_chat_template(rejected_messages)
|
||||
rejected_prompt_len = len(
|
||||
ref_tokenizer.apply_chat_template(rejected_messages[:-1], add_generation_prompt=True)
|
||||
)
|
||||
ref_rejected_prompt_ids = ref_tokenizer.apply_chat_template(rejected_messages[:-1], add_generation_prompt=True)
|
||||
|
||||
if is_transformers_version_greater_than("5.0.0"):
|
||||
ref_chosen_input_ids = ref_chosen_input_ids["input_ids"]
|
||||
ref_rejected_input_ids = ref_rejected_input_ids["input_ids"]
|
||||
ref_chosen_prompt_ids = ref_chosen_prompt_ids["input_ids"]
|
||||
ref_rejected_prompt_ids = ref_rejected_prompt_ids["input_ids"]
|
||||
|
||||
chosen_prompt_len = len(ref_chosen_prompt_ids)
|
||||
rejected_prompt_len = len(ref_rejected_prompt_ids)
|
||||
ref_chosen_labels = [IGNORE_INDEX] * chosen_prompt_len + ref_chosen_input_ids[chosen_prompt_len:]
|
||||
ref_rejected_labels = [IGNORE_INDEX] * rejected_prompt_len + ref_rejected_input_ids[rejected_prompt_len:]
|
||||
assert train_dataset["chosen_input_ids"][index] == ref_chosen_input_ids
|
||||
assert train_dataset["chosen_labels"][index] == ref_chosen_labels
|
||||
|
||||
@@ -20,6 +20,7 @@ from datasets import load_dataset
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from llamafactory.extras.constants import IGNORE_INDEX
|
||||
from llamafactory.extras.packages import is_transformers_version_greater_than
|
||||
from llamafactory.train.test_utils import load_dataset_module
|
||||
|
||||
|
||||
@@ -59,7 +60,16 @@ def test_supervised_single_turn(num_samples: int):
|
||||
{"role": "assistant", "content": original_data["output"][index]},
|
||||
]
|
||||
ref_input_ids = ref_tokenizer.apply_chat_template(messages)
|
||||
ref_prompt_ids = ref_tokenizer.apply_chat_template(messages[:-1], add_generation_prompt=True)
|
||||
|
||||
if is_transformers_version_greater_than("5.0.0"):
|
||||
ref_input_ids = ref_input_ids["input_ids"]
|
||||
ref_prompt_ids = ref_prompt_ids["input_ids"]
|
||||
|
||||
prompt_len = len(ref_prompt_ids)
|
||||
ref_label_ids = [IGNORE_INDEX] * prompt_len + ref_input_ids[prompt_len:]
|
||||
assert train_dataset["input_ids"][index] == ref_input_ids
|
||||
assert train_dataset["labels"][index] == ref_label_ids
|
||||
|
||||
|
||||
@pytest.mark.runs_on(["cpu", "mps"])
|
||||
@@ -73,6 +83,10 @@ def test_supervised_multi_turn(num_samples: int):
|
||||
indexes = random.choices(range(len(original_data)), k=num_samples)
|
||||
for index in indexes:
|
||||
ref_input_ids = ref_tokenizer.apply_chat_template(original_data["messages"][index])
|
||||
if is_transformers_version_greater_than("5.0.0"):
|
||||
ref_input_ids = ref_input_ids["input_ids"]
|
||||
|
||||
# cannot test the label ids in multi-turn case
|
||||
assert train_dataset["input_ids"][index] == ref_input_ids
|
||||
|
||||
|
||||
@@ -86,9 +100,12 @@ def test_supervised_train_on_prompt(num_samples: int):
|
||||
original_data = load_dataset(DEMO_DATA, name="system_chat", split="train")
|
||||
indexes = random.choices(range(len(original_data)), k=num_samples)
|
||||
for index in indexes:
|
||||
ref_ids = ref_tokenizer.apply_chat_template(original_data["messages"][index])
|
||||
assert train_dataset["input_ids"][index] == ref_ids
|
||||
assert train_dataset["labels"][index] == ref_ids
|
||||
ref_input_ids = ref_tokenizer.apply_chat_template(original_data["messages"][index])
|
||||
if is_transformers_version_greater_than("5.0.0"):
|
||||
ref_input_ids = ref_input_ids["input_ids"]
|
||||
|
||||
assert train_dataset["input_ids"][index] == ref_input_ids
|
||||
assert train_dataset["labels"][index] == ref_input_ids
|
||||
|
||||
|
||||
@pytest.mark.runs_on(["cpu", "mps"])
|
||||
@@ -103,7 +120,13 @@ def test_supervised_mask_history(num_samples: int):
|
||||
for index in indexes:
|
||||
messages = original_data["messages"][index]
|
||||
ref_input_ids = ref_tokenizer.apply_chat_template(messages)
|
||||
prompt_len = len(ref_tokenizer.apply_chat_template(messages[:-1], add_generation_prompt=True))
|
||||
ref_prompt_ids = ref_tokenizer.apply_chat_template(messages[:-1], add_generation_prompt=True)
|
||||
|
||||
if is_transformers_version_greater_than("5.0.0"):
|
||||
ref_input_ids = ref_input_ids["input_ids"]
|
||||
ref_prompt_ids = ref_prompt_ids["input_ids"]
|
||||
|
||||
prompt_len = len(ref_prompt_ids)
|
||||
ref_label_ids = [IGNORE_INDEX] * prompt_len + ref_input_ids[prompt_len:]
|
||||
assert train_dataset["input_ids"][index] == ref_input_ids
|
||||
assert train_dataset["labels"][index] == ref_label_ids
|
||||
|
||||
@@ -19,6 +19,7 @@ import pytest
|
||||
from datasets import load_dataset
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from llamafactory.extras.packages import is_transformers_version_greater_than
|
||||
from llamafactory.train.test_utils import load_dataset_module
|
||||
|
||||
|
||||
@@ -55,8 +56,13 @@ def test_unsupervised_data(num_samples: int):
|
||||
indexes = random.choices(range(len(original_data)), k=num_samples)
|
||||
for index in indexes:
|
||||
messages = original_data["messages"][index]
|
||||
ref_ids = ref_tokenizer.apply_chat_template(messages)
|
||||
ref_input_ids = ref_tokenizer.apply_chat_template(messages[:-1], add_generation_prompt=True)
|
||||
ref_labels = ref_ids[len(ref_input_ids) :]
|
||||
assert train_dataset["input_ids"][index] == ref_input_ids
|
||||
ref_input_ids = ref_tokenizer.apply_chat_template(messages)
|
||||
ref_prompt_ids = ref_tokenizer.apply_chat_template(messages[:-1], add_generation_prompt=True)
|
||||
|
||||
if is_transformers_version_greater_than("5.0.0"):
|
||||
ref_input_ids = ref_input_ids["input_ids"]
|
||||
ref_prompt_ids = ref_prompt_ids["input_ids"]
|
||||
|
||||
ref_labels = ref_input_ids[len(ref_prompt_ids) :]
|
||||
assert train_dataset["input_ids"][index] == ref_prompt_ids
|
||||
assert train_dataset["labels"][index] == ref_labels
|
||||
|
||||
@@ -17,7 +17,7 @@ import os
|
||||
import pytest
|
||||
import torch
|
||||
from PIL import Image
|
||||
from transformers import AutoConfig, AutoModelForVision2Seq
|
||||
from transformers import AutoConfig, AutoModelForImageTextToText
|
||||
|
||||
from llamafactory.data import get_template_and_fix_tokenizer
|
||||
from llamafactory.data.collator import MultiModalDataCollatorForSeq2Seq, prepare_4d_attention_mask
|
||||
@@ -82,7 +82,7 @@ def test_multimodal_collator():
|
||||
template = get_template_and_fix_tokenizer(tokenizer_module["tokenizer"], data_args)
|
||||
config = AutoConfig.from_pretrained(model_args.model_name_or_path)
|
||||
with torch.device("meta"):
|
||||
model = AutoModelForVision2Seq.from_config(config)
|
||||
model = AutoModelForImageTextToText.from_config(config)
|
||||
|
||||
data_collator = MultiModalDataCollatorForSeq2Seq(
|
||||
template=template,
|
||||
|
||||
@@ -20,6 +20,7 @@ from transformers import AutoTokenizer
|
||||
|
||||
from llamafactory.data import get_template_and_fix_tokenizer
|
||||
from llamafactory.data.template import parse_template
|
||||
from llamafactory.extras.packages import is_transformers_version_greater_than
|
||||
from llamafactory.hparams import DataArguments
|
||||
|
||||
|
||||
@@ -65,7 +66,6 @@ def _check_template(
|
||||
template_name: str,
|
||||
prompt_str: str,
|
||||
answer_str: str,
|
||||
use_fast: bool,
|
||||
messages: list[dict[str, str]] = MESSAGES,
|
||||
) -> None:
|
||||
r"""Check template.
|
||||
@@ -75,13 +75,15 @@ def _check_template(
|
||||
template_name: the template name.
|
||||
prompt_str: the string corresponding to the prompt part.
|
||||
answer_str: the string corresponding to the answer part.
|
||||
use_fast: whether to use fast tokenizer.
|
||||
messages: the list of messages.
|
||||
|
||||
"""
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=use_fast, token=HF_TOKEN)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
content_str = tokenizer.apply_chat_template(messages, tokenize=False)
|
||||
content_ids = tokenizer.apply_chat_template(messages, tokenize=True)
|
||||
if is_transformers_version_greater_than("5.0.0"):
|
||||
content_ids = content_ids["input_ids"]
|
||||
|
||||
template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template=template_name))
|
||||
prompt_ids, answer_ids = template.encode_oneturn(tokenizer, messages)
|
||||
assert content_str == prompt_str + answer_str
|
||||
@@ -90,9 +92,8 @@ def _check_template(
|
||||
|
||||
|
||||
@pytest.mark.runs_on(["cpu", "mps"])
|
||||
@pytest.mark.parametrize("use_fast", [True, False])
|
||||
def test_encode_oneturn(use_fast: bool):
|
||||
tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3, use_fast=use_fast)
|
||||
def test_encode_oneturn():
|
||||
tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3)
|
||||
template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template="llama3"))
|
||||
prompt_ids, answer_ids = template.encode_oneturn(tokenizer, MESSAGES)
|
||||
prompt_str = (
|
||||
@@ -106,9 +107,8 @@ def test_encode_oneturn(use_fast: bool):
|
||||
|
||||
|
||||
@pytest.mark.runs_on(["cpu", "mps"])
|
||||
@pytest.mark.parametrize("use_fast", [True, False])
|
||||
def test_encode_multiturn(use_fast: bool):
|
||||
tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3, use_fast=use_fast)
|
||||
def test_encode_multiturn():
|
||||
tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3)
|
||||
template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template="llama3"))
|
||||
encoded_pairs = template.encode_multiturn(tokenizer, MESSAGES)
|
||||
prompt_str_1 = (
|
||||
@@ -128,11 +128,10 @@ def test_encode_multiturn(use_fast: bool):
|
||||
|
||||
|
||||
@pytest.mark.runs_on(["cpu", "mps"])
|
||||
@pytest.mark.parametrize("use_fast", [True, False])
|
||||
@pytest.mark.parametrize("cot_messages", [True, False])
|
||||
@pytest.mark.parametrize("enable_thinking", [True, False, None])
|
||||
def test_reasoning_encode_oneturn(use_fast: bool, cot_messages: bool, enable_thinking: bool):
|
||||
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B", use_fast=use_fast)
|
||||
def test_reasoning_encode_oneturn(cot_messages: bool, enable_thinking: bool):
|
||||
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B")
|
||||
data_args = DataArguments(template="qwen3", enable_thinking=enable_thinking)
|
||||
template = get_template_and_fix_tokenizer(tokenizer, data_args)
|
||||
prompt_ids, answer_ids = template.encode_oneturn(tokenizer, MESSAGES_WITH_THOUGHT if cot_messages else MESSAGES)
|
||||
@@ -155,11 +154,10 @@ def test_reasoning_encode_oneturn(use_fast: bool, cot_messages: bool, enable_thi
|
||||
|
||||
|
||||
@pytest.mark.runs_on(["cpu", "mps"])
|
||||
@pytest.mark.parametrize("use_fast", [True, False])
|
||||
@pytest.mark.parametrize("cot_messages", [True, False])
|
||||
@pytest.mark.parametrize("enable_thinking", [True, False, None])
|
||||
def test_reasoning_encode_multiturn(use_fast: bool, cot_messages: bool, enable_thinking: bool):
|
||||
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B", use_fast=use_fast)
|
||||
def test_reasoning_encode_multiturn(cot_messages: bool, enable_thinking: bool):
|
||||
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B")
|
||||
data_args = DataArguments(template="qwen3", enable_thinking=enable_thinking)
|
||||
template = get_template_and_fix_tokenizer(tokenizer, data_args)
|
||||
encoded_pairs = template.encode_multiturn(tokenizer, MESSAGES_WITH_THOUGHT if cot_messages else MESSAGES)
|
||||
@@ -185,10 +183,9 @@ def test_reasoning_encode_multiturn(use_fast: bool, cot_messages: bool, enable_t
|
||||
|
||||
|
||||
@pytest.mark.runs_on(["cpu", "mps"])
|
||||
@pytest.mark.parametrize("use_fast", [True, False])
|
||||
def test_jinja_template(use_fast: bool):
|
||||
tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3, use_fast=use_fast)
|
||||
ref_tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3, use_fast=use_fast)
|
||||
def test_jinja_template():
|
||||
tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3)
|
||||
ref_tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3)
|
||||
template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template="llama3"))
|
||||
tokenizer.chat_template = template._get_jinja_template(tokenizer) # llama3 template no replace
|
||||
assert tokenizer.chat_template != ref_tokenizer.chat_template
|
||||
@@ -222,8 +219,7 @@ def test_get_stop_token_ids():
|
||||
|
||||
@pytest.mark.runs_on(["cpu", "mps"])
|
||||
@pytest.mark.skipif(not HF_TOKEN, reason="Gated model.")
|
||||
@pytest.mark.parametrize("use_fast", [True, False])
|
||||
def test_gemma_template(use_fast: bool):
|
||||
def test_gemma_template():
|
||||
prompt_str = (
|
||||
f"<bos><start_of_turn>user\n{MESSAGES[0]['content']}<end_of_turn>\n"
|
||||
f"<start_of_turn>model\n{MESSAGES[1]['content']}<end_of_turn>\n"
|
||||
@@ -231,13 +227,12 @@ def test_gemma_template(use_fast: bool):
|
||||
"<start_of_turn>model\n"
|
||||
)
|
||||
answer_str = f"{MESSAGES[3]['content']}<end_of_turn>\n"
|
||||
_check_template("google/gemma-3-4b-it", "gemma", prompt_str, answer_str, use_fast)
|
||||
_check_template("google/gemma-3-4b-it", "gemma", prompt_str, answer_str)
|
||||
|
||||
|
||||
@pytest.mark.runs_on(["cpu", "mps"])
|
||||
@pytest.mark.skipif(not HF_TOKEN, reason="Gated model.")
|
||||
@pytest.mark.parametrize("use_fast", [True, False])
|
||||
def test_gemma2_template(use_fast: bool):
|
||||
def test_gemma2_template():
|
||||
prompt_str = (
|
||||
f"<bos><start_of_turn>user\n{MESSAGES[0]['content']}<end_of_turn>\n"
|
||||
f"<start_of_turn>model\n{MESSAGES[1]['content']}<end_of_turn>\n"
|
||||
@@ -245,13 +240,12 @@ def test_gemma2_template(use_fast: bool):
|
||||
"<start_of_turn>model\n"
|
||||
)
|
||||
answer_str = f"{MESSAGES[3]['content']}<end_of_turn>\n"
|
||||
_check_template("google/gemma-2-2b-it", "gemma2", prompt_str, answer_str, use_fast)
|
||||
_check_template("google/gemma-2-2b-it", "gemma2", prompt_str, answer_str)
|
||||
|
||||
|
||||
@pytest.mark.runs_on(["cpu", "mps"])
|
||||
@pytest.mark.skipif(not HF_TOKEN, reason="Gated model.")
|
||||
@pytest.mark.parametrize("use_fast", [True, False])
|
||||
def test_llama3_template(use_fast: bool):
|
||||
def test_llama3_template():
|
||||
prompt_str = (
|
||||
f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{MESSAGES[0]['content']}<|eot_id|>"
|
||||
f"<|start_header_id|>assistant<|end_header_id|>\n\n{MESSAGES[1]['content']}<|eot_id|>"
|
||||
@@ -259,14 +253,11 @@ def test_llama3_template(use_fast: bool):
|
||||
"<|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||
)
|
||||
answer_str = f"{MESSAGES[3]['content']}<|eot_id|>"
|
||||
_check_template("meta-llama/Meta-Llama-3-8B-Instruct", "llama3", prompt_str, answer_str, use_fast)
|
||||
_check_template("meta-llama/Meta-Llama-3-8B-Instruct", "llama3", prompt_str, answer_str)
|
||||
|
||||
|
||||
@pytest.mark.runs_on(["cpu", "mps"])
|
||||
@pytest.mark.parametrize(
|
||||
"use_fast", [True, pytest.param(False, marks=pytest.mark.xfail(reason="Llama 4 has no slow tokenizer."))]
|
||||
)
|
||||
def test_llama4_template(use_fast: bool):
|
||||
def test_llama4_template():
|
||||
prompt_str = (
|
||||
f"<|begin_of_text|><|header_start|>user<|header_end|>\n\n{MESSAGES[0]['content']}<|eot|>"
|
||||
f"<|header_start|>assistant<|header_end|>\n\n{MESSAGES[1]['content']}<|eot|>"
|
||||
@@ -274,18 +265,11 @@ def test_llama4_template(use_fast: bool):
|
||||
"<|header_start|>assistant<|header_end|>\n\n"
|
||||
)
|
||||
answer_str = f"{MESSAGES[3]['content']}<|eot|>"
|
||||
_check_template(TINY_LLAMA4, "llama4", prompt_str, answer_str, use_fast)
|
||||
_check_template(TINY_LLAMA4, "llama4", prompt_str, answer_str)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"use_fast",
|
||||
[
|
||||
pytest.param(True, marks=pytest.mark.xfail(not HF_TOKEN, reason="Authorization.")),
|
||||
pytest.param(False, marks=pytest.mark.xfail(reason="Phi-4 slow tokenizer is broken.")),
|
||||
],
|
||||
)
|
||||
@pytest.mark.runs_on(["cpu", "mps"])
|
||||
def test_phi4_template(use_fast: bool):
|
||||
def test_phi4_template():
|
||||
prompt_str = (
|
||||
f"<|im_start|>user<|im_sep|>{MESSAGES[0]['content']}<|im_end|>"
|
||||
f"<|im_start|>assistant<|im_sep|>{MESSAGES[1]['content']}<|im_end|>"
|
||||
@@ -293,13 +277,12 @@ def test_phi4_template(use_fast: bool):
|
||||
"<|im_start|>assistant<|im_sep|>"
|
||||
)
|
||||
answer_str = f"{MESSAGES[3]['content']}<|im_end|>"
|
||||
_check_template("microsoft/phi-4", "phi4", prompt_str, answer_str, use_fast)
|
||||
_check_template("microsoft/phi-4", "phi4", prompt_str, answer_str)
|
||||
|
||||
|
||||
@pytest.mark.runs_on(["cpu", "mps"])
|
||||
@pytest.mark.xfail(not HF_TOKEN, reason="Authorization.")
|
||||
@pytest.mark.parametrize("use_fast", [True, False])
|
||||
def test_qwen2_5_template(use_fast: bool):
|
||||
def test_qwen2_5_template():
|
||||
prompt_str = (
|
||||
"<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n"
|
||||
f"<|im_start|>user\n{MESSAGES[0]['content']}<|im_end|>\n"
|
||||
@@ -308,13 +291,12 @@ def test_qwen2_5_template(use_fast: bool):
|
||||
"<|im_start|>assistant\n"
|
||||
)
|
||||
answer_str = f"{MESSAGES[3]['content']}<|im_end|>\n"
|
||||
_check_template("Qwen/Qwen2.5-7B-Instruct", "qwen", prompt_str, answer_str, use_fast)
|
||||
_check_template("Qwen/Qwen2.5-7B-Instruct", "qwen", prompt_str, answer_str)
|
||||
|
||||
|
||||
@pytest.mark.runs_on(["cpu", "mps"])
|
||||
@pytest.mark.parametrize("use_fast", [True, False])
|
||||
@pytest.mark.parametrize("cot_messages", [True, False])
|
||||
def test_qwen3_template(use_fast: bool, cot_messages: bool):
|
||||
def test_qwen3_template(cot_messages: bool):
|
||||
prompt_str = (
|
||||
f"<|im_start|>user\n{MESSAGES[0]['content']}<|im_end|>\n"
|
||||
f"<|im_start|>assistant\n{MESSAGES[1]['content']}<|im_end|>\n"
|
||||
@@ -328,12 +310,12 @@ def test_qwen3_template(use_fast: bool, cot_messages: bool):
|
||||
answer_str = f"{MESSAGES_WITH_THOUGHT[3]['content']}<|im_end|>\n"
|
||||
messages = MESSAGES_WITH_THOUGHT
|
||||
|
||||
_check_template("Qwen/Qwen3-8B", "qwen3", prompt_str, answer_str, use_fast, messages=messages)
|
||||
_check_template("Qwen/Qwen3-8B", "qwen3", prompt_str, answer_str, messages=messages)
|
||||
|
||||
|
||||
@pytest.mark.runs_on(["cpu", "mps"])
|
||||
def test_parse_llama3_template():
|
||||
tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3, token=HF_TOKEN)
|
||||
tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3)
|
||||
template = parse_template(tokenizer)
|
||||
assert template.format_user.slots == [
|
||||
"<|start_header_id|>user<|end_header_id|>\n\n{{content}}<|eot_id|>"
|
||||
@@ -348,7 +330,7 @@ def test_parse_llama3_template():
|
||||
@pytest.mark.runs_on(["cpu", "mps"])
|
||||
@pytest.mark.xfail(not HF_TOKEN, reason="Authorization.")
|
||||
def test_parse_qwen_template():
|
||||
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct", token=HF_TOKEN)
|
||||
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")
|
||||
template = parse_template(tokenizer)
|
||||
assert template.__class__.__name__ == "Template"
|
||||
assert template.format_user.slots == ["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]
|
||||
@@ -361,7 +343,7 @@ def test_parse_qwen_template():
|
||||
@pytest.mark.runs_on(["cpu", "mps"])
|
||||
@pytest.mark.xfail(not HF_TOKEN, reason="Authorization.")
|
||||
def test_parse_qwen3_template():
|
||||
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B", token=HF_TOKEN)
|
||||
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B")
|
||||
template = parse_template(tokenizer)
|
||||
assert template.__class__.__name__ == "ReasoningTemplate"
|
||||
assert template.format_user.slots == ["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]
|
||||
|
||||
@@ -16,7 +16,8 @@ import os
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from transformers import AutoConfig, AutoModelForVision2Seq
|
||||
from safetensors.torch import load_file
|
||||
from transformers import AutoConfig, AutoModelForImageTextToText
|
||||
|
||||
from llamafactory.extras.packages import is_transformers_version_greater_than
|
||||
from llamafactory.hparams import FinetuningArguments, ModelArguments
|
||||
@@ -36,7 +37,7 @@ def test_visual_full(freeze_vision_tower: bool, freeze_multi_modal_projector: bo
|
||||
)
|
||||
config = AutoConfig.from_pretrained(model_args.model_name_or_path)
|
||||
with torch.device("meta"):
|
||||
model = AutoModelForVision2Seq.from_config(config)
|
||||
model = AutoModelForImageTextToText.from_config(config)
|
||||
|
||||
model = init_adapter(config, model, model_args, finetuning_args, is_trainable=True)
|
||||
for name, param in model.named_parameters():
|
||||
@@ -56,7 +57,7 @@ def test_visual_lora(freeze_vision_tower: bool, freeze_language_model: bool):
|
||||
)
|
||||
config = AutoConfig.from_pretrained(model_args.model_name_or_path)
|
||||
with torch.device("meta"):
|
||||
model = AutoModelForVision2Seq.from_config(config)
|
||||
model = AutoModelForImageTextToText.from_config(config)
|
||||
|
||||
model = init_adapter(config, model, model_args, finetuning_args, is_trainable=True)
|
||||
trainable_params, frozen_params = set(), set()
|
||||
@@ -86,13 +87,14 @@ def test_visual_model_save_load():
|
||||
finetuning_args = FinetuningArguments(finetuning_type="full")
|
||||
config = AutoConfig.from_pretrained(model_args.model_name_or_path)
|
||||
with torch.device("meta"):
|
||||
model = AutoModelForVision2Seq.from_config(config)
|
||||
model = AutoModelForImageTextToText.from_config(config)
|
||||
|
||||
model = init_adapter(config, model, model_args, finetuning_args, is_trainable=False)
|
||||
model.to_empty(device="cpu")
|
||||
loaded_model_weight = dict(model.named_parameters())
|
||||
|
||||
model.save_pretrained(os.path.join("output", "qwen2_vl"), max_shard_size="10GB", safe_serialization=False)
|
||||
saved_model_weight = torch.load(os.path.join("output", "qwen2_vl", "pytorch_model.bin"), weights_only=False)
|
||||
model.save_pretrained(os.path.join("output", "qwen2_vl"), max_shard_size="10GB", safe_serialization=True)
|
||||
saved_model_weight = load_file(os.path.join("output", "qwen2_vl", "model.safetensors"))
|
||||
|
||||
if is_transformers_version_greater_than("4.52.0"):
|
||||
assert "model.language_model.layers.0.self_attn.q_proj.weight" in loaded_model_weight
|
||||
|
||||
@@ -1,2 +1,2 @@
|
||||
# change if test fails or cache is outdated
|
||||
0.9.5.105
|
||||
0.9.5.106
|
||||
|
||||
Reference in New Issue
Block a user