[deps] adapt to transformers v5 (#10147)

Co-authored-by: frozenleaves <frozen@Mac.local>
Co-authored-by: hiyouga <hiyouga@buaa.edu.cn>
This commit is contained in:
浮梦
2026-02-02 12:07:19 +08:00
committed by GitHub
parent 762b480131
commit bf04ca6af8
23 changed files with 149 additions and 120 deletions

View File

@@ -20,6 +20,7 @@ from transformers import AutoTokenizer
from llamafactory.data import get_template_and_fix_tokenizer
from llamafactory.data.template import parse_template
from llamafactory.extras.packages import is_transformers_version_greater_than
from llamafactory.hparams import DataArguments
@@ -65,7 +66,6 @@ def _check_template(
template_name: str,
prompt_str: str,
answer_str: str,
use_fast: bool,
messages: list[dict[str, str]] = MESSAGES,
) -> None:
r"""Check template.
@@ -75,13 +75,15 @@ def _check_template(
template_name: the template name.
prompt_str: the string corresponding to the prompt part.
answer_str: the string corresponding to the answer part.
use_fast: whether to use fast tokenizer.
messages: the list of messages.
"""
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=use_fast, token=HF_TOKEN)
tokenizer = AutoTokenizer.from_pretrained(model_id)
content_str = tokenizer.apply_chat_template(messages, tokenize=False)
content_ids = tokenizer.apply_chat_template(messages, tokenize=True)
if is_transformers_version_greater_than("5.0.0"):
content_ids = content_ids["input_ids"]
template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template=template_name))
prompt_ids, answer_ids = template.encode_oneturn(tokenizer, messages)
assert content_str == prompt_str + answer_str
@@ -90,9 +92,8 @@ def _check_template(
@pytest.mark.runs_on(["cpu", "mps"])
@pytest.mark.parametrize("use_fast", [True, False])
def test_encode_oneturn(use_fast: bool):
tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3, use_fast=use_fast)
def test_encode_oneturn():
tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3)
template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template="llama3"))
prompt_ids, answer_ids = template.encode_oneturn(tokenizer, MESSAGES)
prompt_str = (
@@ -106,9 +107,8 @@ def test_encode_oneturn(use_fast: bool):
@pytest.mark.runs_on(["cpu", "mps"])
@pytest.mark.parametrize("use_fast", [True, False])
def test_encode_multiturn(use_fast: bool):
tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3, use_fast=use_fast)
def test_encode_multiturn():
tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3)
template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template="llama3"))
encoded_pairs = template.encode_multiturn(tokenizer, MESSAGES)
prompt_str_1 = (
@@ -128,11 +128,10 @@ def test_encode_multiturn(use_fast: bool):
@pytest.mark.runs_on(["cpu", "mps"])
@pytest.mark.parametrize("use_fast", [True, False])
@pytest.mark.parametrize("cot_messages", [True, False])
@pytest.mark.parametrize("enable_thinking", [True, False, None])
def test_reasoning_encode_oneturn(use_fast: bool, cot_messages: bool, enable_thinking: bool):
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B", use_fast=use_fast)
def test_reasoning_encode_oneturn(cot_messages: bool, enable_thinking: bool):
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B")
data_args = DataArguments(template="qwen3", enable_thinking=enable_thinking)
template = get_template_and_fix_tokenizer(tokenizer, data_args)
prompt_ids, answer_ids = template.encode_oneturn(tokenizer, MESSAGES_WITH_THOUGHT if cot_messages else MESSAGES)
@@ -155,11 +154,10 @@ def test_reasoning_encode_oneturn(use_fast: bool, cot_messages: bool, enable_thi
@pytest.mark.runs_on(["cpu", "mps"])
@pytest.mark.parametrize("use_fast", [True, False])
@pytest.mark.parametrize("cot_messages", [True, False])
@pytest.mark.parametrize("enable_thinking", [True, False, None])
def test_reasoning_encode_multiturn(use_fast: bool, cot_messages: bool, enable_thinking: bool):
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B", use_fast=use_fast)
def test_reasoning_encode_multiturn(cot_messages: bool, enable_thinking: bool):
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B")
data_args = DataArguments(template="qwen3", enable_thinking=enable_thinking)
template = get_template_and_fix_tokenizer(tokenizer, data_args)
encoded_pairs = template.encode_multiturn(tokenizer, MESSAGES_WITH_THOUGHT if cot_messages else MESSAGES)
@@ -185,10 +183,9 @@ def test_reasoning_encode_multiturn(use_fast: bool, cot_messages: bool, enable_t
@pytest.mark.runs_on(["cpu", "mps"])
@pytest.mark.parametrize("use_fast", [True, False])
def test_jinja_template(use_fast: bool):
tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3, use_fast=use_fast)
ref_tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3, use_fast=use_fast)
def test_jinja_template():
tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3)
ref_tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3)
template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template="llama3"))
tokenizer.chat_template = template._get_jinja_template(tokenizer) # llama3 template no replace
assert tokenizer.chat_template != ref_tokenizer.chat_template
@@ -222,8 +219,7 @@ def test_get_stop_token_ids():
@pytest.mark.runs_on(["cpu", "mps"])
@pytest.mark.skipif(not HF_TOKEN, reason="Gated model.")
@pytest.mark.parametrize("use_fast", [True, False])
def test_gemma_template(use_fast: bool):
def test_gemma_template():
prompt_str = (
f"<bos><start_of_turn>user\n{MESSAGES[0]['content']}<end_of_turn>\n"
f"<start_of_turn>model\n{MESSAGES[1]['content']}<end_of_turn>\n"
@@ -231,13 +227,12 @@ def test_gemma_template(use_fast: bool):
"<start_of_turn>model\n"
)
answer_str = f"{MESSAGES[3]['content']}<end_of_turn>\n"
_check_template("google/gemma-3-4b-it", "gemma", prompt_str, answer_str, use_fast)
_check_template("google/gemma-3-4b-it", "gemma", prompt_str, answer_str)
@pytest.mark.runs_on(["cpu", "mps"])
@pytest.mark.skipif(not HF_TOKEN, reason="Gated model.")
@pytest.mark.parametrize("use_fast", [True, False])
def test_gemma2_template(use_fast: bool):
def test_gemma2_template():
prompt_str = (
f"<bos><start_of_turn>user\n{MESSAGES[0]['content']}<end_of_turn>\n"
f"<start_of_turn>model\n{MESSAGES[1]['content']}<end_of_turn>\n"
@@ -245,13 +240,12 @@ def test_gemma2_template(use_fast: bool):
"<start_of_turn>model\n"
)
answer_str = f"{MESSAGES[3]['content']}<end_of_turn>\n"
_check_template("google/gemma-2-2b-it", "gemma2", prompt_str, answer_str, use_fast)
_check_template("google/gemma-2-2b-it", "gemma2", prompt_str, answer_str)
@pytest.mark.runs_on(["cpu", "mps"])
@pytest.mark.skipif(not HF_TOKEN, reason="Gated model.")
@pytest.mark.parametrize("use_fast", [True, False])
def test_llama3_template(use_fast: bool):
def test_llama3_template():
prompt_str = (
f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{MESSAGES[0]['content']}<|eot_id|>"
f"<|start_header_id|>assistant<|end_header_id|>\n\n{MESSAGES[1]['content']}<|eot_id|>"
@@ -259,14 +253,11 @@ def test_llama3_template(use_fast: bool):
"<|start_header_id|>assistant<|end_header_id|>\n\n"
)
answer_str = f"{MESSAGES[3]['content']}<|eot_id|>"
_check_template("meta-llama/Meta-Llama-3-8B-Instruct", "llama3", prompt_str, answer_str, use_fast)
_check_template("meta-llama/Meta-Llama-3-8B-Instruct", "llama3", prompt_str, answer_str)
@pytest.mark.runs_on(["cpu", "mps"])
@pytest.mark.parametrize(
"use_fast", [True, pytest.param(False, marks=pytest.mark.xfail(reason="Llama 4 has no slow tokenizer."))]
)
def test_llama4_template(use_fast: bool):
def test_llama4_template():
prompt_str = (
f"<|begin_of_text|><|header_start|>user<|header_end|>\n\n{MESSAGES[0]['content']}<|eot|>"
f"<|header_start|>assistant<|header_end|>\n\n{MESSAGES[1]['content']}<|eot|>"
@@ -274,18 +265,11 @@ def test_llama4_template(use_fast: bool):
"<|header_start|>assistant<|header_end|>\n\n"
)
answer_str = f"{MESSAGES[3]['content']}<|eot|>"
_check_template(TINY_LLAMA4, "llama4", prompt_str, answer_str, use_fast)
_check_template(TINY_LLAMA4, "llama4", prompt_str, answer_str)
@pytest.mark.parametrize(
"use_fast",
[
pytest.param(True, marks=pytest.mark.xfail(not HF_TOKEN, reason="Authorization.")),
pytest.param(False, marks=pytest.mark.xfail(reason="Phi-4 slow tokenizer is broken.")),
],
)
@pytest.mark.runs_on(["cpu", "mps"])
def test_phi4_template(use_fast: bool):
def test_phi4_template():
prompt_str = (
f"<|im_start|>user<|im_sep|>{MESSAGES[0]['content']}<|im_end|>"
f"<|im_start|>assistant<|im_sep|>{MESSAGES[1]['content']}<|im_end|>"
@@ -293,13 +277,12 @@ def test_phi4_template(use_fast: bool):
"<|im_start|>assistant<|im_sep|>"
)
answer_str = f"{MESSAGES[3]['content']}<|im_end|>"
_check_template("microsoft/phi-4", "phi4", prompt_str, answer_str, use_fast)
_check_template("microsoft/phi-4", "phi4", prompt_str, answer_str)
@pytest.mark.runs_on(["cpu", "mps"])
@pytest.mark.xfail(not HF_TOKEN, reason="Authorization.")
@pytest.mark.parametrize("use_fast", [True, False])
def test_qwen2_5_template(use_fast: bool):
def test_qwen2_5_template():
prompt_str = (
"<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n"
f"<|im_start|>user\n{MESSAGES[0]['content']}<|im_end|>\n"
@@ -308,13 +291,12 @@ def test_qwen2_5_template(use_fast: bool):
"<|im_start|>assistant\n"
)
answer_str = f"{MESSAGES[3]['content']}<|im_end|>\n"
_check_template("Qwen/Qwen2.5-7B-Instruct", "qwen", prompt_str, answer_str, use_fast)
_check_template("Qwen/Qwen2.5-7B-Instruct", "qwen", prompt_str, answer_str)
@pytest.mark.runs_on(["cpu", "mps"])
@pytest.mark.parametrize("use_fast", [True, False])
@pytest.mark.parametrize("cot_messages", [True, False])
def test_qwen3_template(use_fast: bool, cot_messages: bool):
def test_qwen3_template(cot_messages: bool):
prompt_str = (
f"<|im_start|>user\n{MESSAGES[0]['content']}<|im_end|>\n"
f"<|im_start|>assistant\n{MESSAGES[1]['content']}<|im_end|>\n"
@@ -328,12 +310,12 @@ def test_qwen3_template(use_fast: bool, cot_messages: bool):
answer_str = f"{MESSAGES_WITH_THOUGHT[3]['content']}<|im_end|>\n"
messages = MESSAGES_WITH_THOUGHT
_check_template("Qwen/Qwen3-8B", "qwen3", prompt_str, answer_str, use_fast, messages=messages)
_check_template("Qwen/Qwen3-8B", "qwen3", prompt_str, answer_str, messages=messages)
@pytest.mark.runs_on(["cpu", "mps"])
def test_parse_llama3_template():
tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3, token=HF_TOKEN)
tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3)
template = parse_template(tokenizer)
assert template.format_user.slots == [
"<|start_header_id|>user<|end_header_id|>\n\n{{content}}<|eot_id|>"
@@ -348,7 +330,7 @@ def test_parse_llama3_template():
@pytest.mark.runs_on(["cpu", "mps"])
@pytest.mark.xfail(not HF_TOKEN, reason="Authorization.")
def test_parse_qwen_template():
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct", token=HF_TOKEN)
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")
template = parse_template(tokenizer)
assert template.__class__.__name__ == "Template"
assert template.format_user.slots == ["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]
@@ -361,7 +343,7 @@ def test_parse_qwen_template():
@pytest.mark.runs_on(["cpu", "mps"])
@pytest.mark.xfail(not HF_TOKEN, reason="Authorization.")
def test_parse_qwen3_template():
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B", token=HF_TOKEN)
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B")
template = parse_template(tokenizer)
assert template.__class__.__name__ == "ReasoningTemplate"
assert template.format_user.slots == ["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]