add unittest

Former-commit-id: 8a1f0c5f922989e08a19c65de0b2c4afd2a5771f
hiyouga
2024-07-19 01:06:27 +08:00
parent 4c1513a845
commit 994b9089e9
16 changed files with 436 additions and 260 deletions


@@ -13,7 +13,7 @@
 # limitations under the License.
 
 import os
-from typing import TYPE_CHECKING, Sequence
+from typing import TYPE_CHECKING, List, Sequence
 
 import pytest
 from transformers import AutoTokenizer
@@ -39,7 +39,7 @@ MESSAGES = [
 
 def _check_tokenization(
     tokenizer: "PreTrainedTokenizer", batch_input_ids: Sequence[Sequence[int]], batch_text: Sequence[str]
-):
+) -> None:
     for input_ids, text in zip(batch_input_ids, batch_text):
         assert input_ids == tokenizer.encode(text, add_special_tokens=False)
         assert tokenizer.decode(input_ids) == text
@@ -47,7 +47,7 @@ def _check_tokenization(
 
 def _check_single_template(
     model_id: str, template_name: str, prompt_str: str, answer_str: str, extra_str: str, use_fast: bool
-):
+) -> List[str]:
     tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=use_fast, token=HF_TOKEN)
     content_str = tokenizer.apply_chat_template(MESSAGES, tokenize=False)
     content_ids = tokenizer.apply_chat_template(MESSAGES, tokenize=True)
@@ -59,7 +59,7 @@ def _check_single_template(
     return content_ids
 
 
-def _check_template(model_id: str, template_name: str, prompt_str: str, answer_str: str, extra_str: str = ""):
+def _check_template(model_id: str, template_name: str, prompt_str: str, answer_str: str, extra_str: str = "") -> None:
     """
     Checks template for both the slow tokenizer and the fast tokenizer.
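
For context, a minimal sketch of how a test in this module might call the _check_template helper after this change; the model ID, template name, and expected prompt/answer strings below are illustrative placeholders, not values taken from this commit.

# Hypothetical usage sketch: the model ID, template name, and expected strings
# are placeholders, not the fixtures defined in this test file.
def test_example_template() -> None:
    prompt_str = "<|begin_of_text|><|start_header_id|>user<|end_header_id|>..."  # placeholder expectation
    answer_str = "I am fine!<|eot_id|>"  # placeholder expectation
    # Per its docstring, _check_template checks the rendered template against
    # both the slow and the fast tokenizer.
    _check_template("meta-llama/Meta-Llama-3-8B-Instruct", "llama3", prompt_str, answer_str)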