improve template, add phi4 model

Former-commit-id: a785b6796e445a3adba45c5b6947166a2ff99871
This commit is contained in:
hiyouga
2025-01-09 18:27:20 +00:00
parent 4e25d037c8
commit 867980196e
5 changed files with 147 additions and 125 deletions

View File

@@ -13,7 +13,7 @@
# limitations under the License.
import os
from typing import TYPE_CHECKING, List, Sequence
from typing import TYPE_CHECKING, Sequence
import pytest
from transformers import AutoTokenizer
@@ -42,39 +42,36 @@ MESSAGES = [
def _check_tokenization(
tokenizer: "PreTrainedTokenizer", batch_input_ids: Sequence[Sequence[int]], batch_text: Sequence[str]
) -> None:
r"""
Checks token ids and texts.
encode(text) == token_ids
decode(token_ids) == text
"""
for input_ids, text in zip(batch_input_ids, batch_text):
assert input_ids == tokenizer.encode(text, add_special_tokens=False)
assert tokenizer.encode(text, add_special_tokens=False) == input_ids
assert tokenizer.decode(input_ids) == text
def _check_single_template(
model_id: str, template_name: str, prompt_str: str, answer_str: str, extra_str: str, use_fast: bool
) -> List[str]:
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=use_fast, token=HF_TOKEN)
content_str = tokenizer.apply_chat_template(MESSAGES, tokenize=False)
content_ids = tokenizer.apply_chat_template(MESSAGES, tokenize=True)
template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template=template_name))
prompt_ids, answer_ids = template.encode_oneturn(tokenizer, MESSAGES)
assert content_str == prompt_str + answer_str + extra_str
assert content_ids == prompt_ids + answer_ids + tokenizer.encode(extra_str, add_special_tokens=False)
_check_tokenization(tokenizer, (prompt_ids, answer_ids), (prompt_str, answer_str))
return content_ids
def _check_template(model_id: str, template_name: str, prompt_str: str, answer_str: str, extra_str: str = "") -> None:
"""
Checks template for both the slow tokenizer and the fast tokenizer.
def _check_template(model_id: str, template_name: str, prompt_str: str, answer_str: str, use_fast: bool) -> None:
r"""
Checks template.
Args:
model_id: the model id on hugging face hub.
template_name: the template name.
prompt_str: the string corresponding to the prompt part.
answer_str: the string corresponding to the answer part.
extra_str: the extra string in the jinja template of the original tokenizer.
use_fast: whether to use fast tokenizer.
"""
slow_ids = _check_single_template(model_id, template_name, prompt_str, answer_str, extra_str, use_fast=False)
fast_ids = _check_single_template(model_id, template_name, prompt_str, answer_str, extra_str, use_fast=True)
assert slow_ids == fast_ids
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=use_fast, token=HF_TOKEN)
content_str = tokenizer.apply_chat_template(MESSAGES, tokenize=False)
content_ids = tokenizer.apply_chat_template(MESSAGES, tokenize=True)
template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template=template_name))
prompt_ids, answer_ids = template.encode_oneturn(tokenizer, MESSAGES)
assert content_str == prompt_str + answer_str
assert content_ids == prompt_ids + answer_ids
_check_tokenization(tokenizer, (prompt_ids, answer_ids), (prompt_str, answer_str))
@pytest.mark.parametrize("use_fast", [True, False])
@@ -125,19 +122,21 @@ def test_jinja_template(use_fast: bool):
@pytest.mark.skipif(not HF_TOKEN, reason="Gated model.")
def test_gemma_template():
@pytest.mark.parametrize("use_fast", [True, False])
def test_gemma_template(use_fast: bool):
prompt_str = (
"<bos><start_of_turn>user\nHow are you<end_of_turn>\n"
"<start_of_turn>model\nI am fine!<end_of_turn>\n"
"<start_of_turn>user\n你好<end_of_turn>\n"
"<start_of_turn>model\n"
)
answer_str = "很高兴认识你!"
_check_template("google/gemma-2-9b-it", "gemma", prompt_str, answer_str, extra_str="<end_of_turn>\n")
answer_str = "很高兴认识你!<end_of_turn>\n"
_check_template("google/gemma-2-9b-it", "gemma", prompt_str, answer_str, use_fast)
@pytest.mark.skipif(not HF_TOKEN, reason="Gated model.")
def test_llama3_template():
@pytest.mark.parametrize("use_fast", [True, False])
def test_llama3_template(use_fast: bool):
prompt_str = (
"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nHow are you<|eot_id|>"
"<|start_header_id|>assistant<|end_header_id|>\n\nI am fine!<|eot_id|>"
@@ -145,10 +144,25 @@ def test_llama3_template():
"<|start_header_id|>assistant<|end_header_id|>\n\n"
)
answer_str = "很高兴认识你!<|eot_id|>"
_check_template("meta-llama/Meta-Llama-3-8B-Instruct", "llama3", prompt_str, answer_str)
_check_template("meta-llama/Meta-Llama-3-8B-Instruct", "llama3", prompt_str, answer_str, use_fast)
def test_qwen_template():
@pytest.mark.parametrize(
"use_fast", [True, pytest.param(False, marks=pytest.mark.xfail(reason="Phi-4 slow tokenizer is broken."))]
)
def test_phi4_template(use_fast: bool):
prompt_str = (
"<|im_start|>user<|im_sep|>How are you<|im_end|>"
"<|im_start|>assistant<|im_sep|>I am fine!<|im_end|>"
"<|im_start|>user<|im_sep|>你好<|im_end|>"
"<|im_start|>assistant<|im_sep|>"
)
answer_str = "很高兴认识你!<|im_end|>"
_check_template("microsoft/phi-4", "phi4", prompt_str, answer_str, use_fast)
@pytest.mark.parametrize("use_fast", [True, False])
def test_qwen_template(use_fast: bool):
prompt_str = (
"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
"<|im_start|>user\nHow are you<|im_end|>\n"
@@ -156,17 +170,18 @@ def test_qwen_template():
"<|im_start|>user\n你好<|im_end|>\n"
"<|im_start|>assistant\n"
)
answer_str = "很高兴认识你!<|im_end|>"
_check_template("Qwen/Qwen2-7B-Instruct", "qwen", prompt_str, answer_str, extra_str="\n")
answer_str = "很高兴认识你!<|im_end|>\n"
_check_template("Qwen/Qwen2-7B-Instruct", "qwen", prompt_str, answer_str, use_fast)
@pytest.mark.xfail(reason="The fast tokenizer of Yi model is corrupted.")
def test_yi_template():
@pytest.mark.parametrize("use_fast", [True, False])
@pytest.mark.xfail(reason="Yi tokenizer is broken.")
def test_yi_template(use_fast: bool):
prompt_str = (
"<|im_start|>user\nHow are you<|im_end|>\n"
"<|im_start|>assistant\nI am fine!<|im_end|>\n"
"<|im_start|>user\n你好<|im_end|>\n"
"<|im_start|>assistant\n"
)
answer_str = "很高兴认识你!<|im_end|>"
_check_template("01-ai/Yi-1.5-6B-Chat", "yi", prompt_str, answer_str)
answer_str = "很高兴认识你!<|im_end|>\n"
_check_template("01-ai/Yi-1.5-6B-Chat", "yi", prompt_str, answer_str, use_fast)