update get template

Former-commit-id: 21ea0d0786f91c0bce79630963e66b815a6792a0
This commit is contained in:
hiyouga
2024-09-04 22:36:20 +08:00
parent 5d85be31ca
commit af178cbcd1
17 changed files with 57 additions and 56 deletions

View File

@@ -19,6 +19,7 @@ import pytest
from transformers import AutoTokenizer
from llamafactory.data import get_template_and_fix_tokenizer
from llamafactory.hparams import DataArguments
if TYPE_CHECKING:
@@ -51,7 +52,7 @@ def _check_single_template(
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=use_fast, token=HF_TOKEN)
content_str = tokenizer.apply_chat_template(MESSAGES, tokenize=False)
content_ids = tokenizer.apply_chat_template(MESSAGES, tokenize=True)
template = get_template_and_fix_tokenizer(tokenizer, name=template_name)
template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template=template_name))
prompt_ids, answer_ids = template.encode_oneturn(tokenizer, MESSAGES)
assert content_str == prompt_str + answer_str + extra_str
assert content_ids == prompt_ids + answer_ids + tokenizer.encode(extra_str, add_special_tokens=False)
@@ -78,7 +79,7 @@ def _check_template(model_id: str, template_name: str, prompt_str: str, answer_s
@pytest.mark.parametrize("use_fast", [True, False])
def test_encode_oneturn(use_fast: bool):
tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA, use_fast=use_fast)
template = get_template_and_fix_tokenizer(tokenizer, name="llama3")
template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template="llama3"))
prompt_ids, answer_ids = template.encode_oneturn(tokenizer, MESSAGES)
prompt_str = (
"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nHow are you<|eot_id|>"
@@ -93,7 +94,7 @@ def test_encode_oneturn(use_fast: bool):
@pytest.mark.parametrize("use_fast", [True, False])
def test_encode_multiturn(use_fast: bool):
tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA, use_fast=use_fast)
template = get_template_and_fix_tokenizer(tokenizer, name="llama3")
template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template="llama3"))
encoded_pairs = template.encode_multiturn(tokenizer, MESSAGES)
prompt_str_1 = (
"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nHow are you<|eot_id|>"