[misc] fix packing and eval plot (#7623)
@@ -29,7 +29,8 @@ if TYPE_CHECKING:
 HF_TOKEN = os.getenv("HF_TOKEN")
 
-TINY_LLAMA = os.getenv("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
+TINY_LLAMA3 = os.getenv("TINY_LLAMA3", "llamafactory/tiny-random-Llama-3")
+TINY_LLAMA4 = os.getenv("TINY_LLAMA4", "llamafactory/tiny-random-Llama-4")
 
 MESSAGES = [
     {"role": "user", "content": "How are you"},
@@ -75,7 +76,7 @@ def _check_template(model_id: str, template_name: str, prompt_str: str, answer_s
 
 @pytest.mark.parametrize("use_fast", [True, False])
 def test_encode_oneturn(use_fast: bool):
-    tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA, use_fast=use_fast)
+    tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3, use_fast=use_fast)
     template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template="llama3"))
     prompt_ids, answer_ids = template.encode_oneturn(tokenizer, MESSAGES)
     prompt_str = (
@@ -90,7 +91,7 @@ def test_encode_oneturn(use_fast: bool):
 
 @pytest.mark.parametrize("use_fast", [True, False])
 def test_encode_multiturn(use_fast: bool):
-    tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA, use_fast=use_fast)
+    tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3, use_fast=use_fast)
     template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template="llama3"))
     encoded_pairs = template.encode_multiturn(tokenizer, MESSAGES)
     prompt_str_1 = (
@@ -111,8 +112,8 @@ def test_encode_multiturn(use_fast: bool):
 
 @pytest.mark.parametrize("use_fast", [True, False])
 def test_jinja_template(use_fast: bool):
-    tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA, use_fast=use_fast)
-    ref_tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA, use_fast=use_fast)
+    tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3, use_fast=use_fast)
+    ref_tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3, use_fast=use_fast)
     template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template="llama3"))
     tokenizer.chat_template = template._get_jinja_template(tokenizer)  # llama3 template no replace
     assert tokenizer.chat_template != ref_tokenizer.chat_template
@@ -120,7 +121,7 @@ def test_jinja_template(use_fast: bool):
 
 
 def test_ollama_modelfile():
-    tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA)
+    tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3)
     template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template="llama3"))
     assert template.get_ollama_modelfile(tokenizer) == (
         "# ollama modelfile auto-generated by llamafactory\n\n"
@@ -137,7 +138,7 @@ def test_ollama_modelfile():
 
 
 def test_get_stop_token_ids():
-    tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA)
+    tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3)
     template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template="llama3"))
     assert set(template.get_stop_token_ids(tokenizer)) == {128008, 128009}
 
@@ -152,7 +153,7 @@ def test_gemma_template(use_fast: bool):
         "<start_of_turn>model\n"
     )
     answer_str = "很高兴认识你!<end_of_turn>\n"
-    _check_template("google/gemma-2-9b-it", "gemma", prompt_str, answer_str, use_fast)
+    _check_template("google/gemma-3-4b-it", "gemma", prompt_str, answer_str, use_fast)
 
 
 @pytest.mark.skipif(not HF_TOKEN, reason="Gated model.")
@@ -168,7 +169,20 @@ def test_llama3_template(use_fast: bool):
     _check_template("meta-llama/Meta-Llama-3-8B-Instruct", "llama3", prompt_str, answer_str, use_fast)
 
 
+@pytest.mark.skipif(not HF_TOKEN, reason="Gated model.")
+@pytest.mark.parametrize(
+    "use_fast", [True, pytest.param(False, marks=pytest.mark.xfail(reason="Llama 4 has no slow tokenizer."))]
+)
+def test_llama4_template(use_fast: bool):
+    prompt_str = (
+        "<|begin_of_text|><|header_start|>user<|header_end|>\n\nHow are you<|eot|>"
+        "<|header_start|>assistant<|header_end|>\n\nI am fine!<|eot|>"
+        "<|header_start|>user<|header_end|>\n\n你好<|eot|>"
+        "<|header_start|>assistant<|header_end|>\n\n"
+    )
+    answer_str = "很高兴认识你!<|eot|>"
+    _check_template(TINY_LLAMA4, "llama4", prompt_str, answer_str, use_fast)
+
+
 @pytest.mark.parametrize(
     "use_fast", [True, pytest.param(False, marks=pytest.mark.xfail(reason="Phi-4 slow tokenizer is broken."))]
 )
@@ -183,35 +197,21 @@ def test_phi4_template(use_fast: bool):
     _check_template("microsoft/phi-4", "phi4", prompt_str, answer_str, use_fast)
 
 
 @pytest.mark.skipif(not HF_TOKEN, reason="Gated model.")  # TODO: why it is gated?
 @pytest.mark.parametrize("use_fast", [True, False])
 def test_qwen_template(use_fast: bool):
     prompt_str = (
-        "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
+        "<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n"
         "<|im_start|>user\nHow are you<|im_end|>\n"
         "<|im_start|>assistant\nI am fine!<|im_end|>\n"
         "<|im_start|>user\n你好<|im_end|>\n"
         "<|im_start|>assistant\n"
     )
     answer_str = "很高兴认识你!<|im_end|>\n"
-    _check_template("Qwen/Qwen2-7B-Instruct", "qwen", prompt_str, answer_str, use_fast)
+    _check_template("Qwen/Qwen2.5-7B-Instruct", "qwen", prompt_str, answer_str, use_fast)
 
 
-@pytest.mark.parametrize("use_fast", [True, False])
-@pytest.mark.xfail(reason="Yi tokenizer is broken.")
-def test_yi_template(use_fast: bool):
-    prompt_str = (
-        "<|im_start|>user\nHow are you<|im_end|>\n"
-        "<|im_start|>assistant\nI am fine!<|im_end|>\n"
-        "<|im_start|>user\n你好<|im_end|>\n"
-        "<|im_start|>assistant\n"
-    )
-    answer_str = "很高兴认识你!<|im_end|>\n"
-    _check_template("01-ai/Yi-1.5-6B-Chat", "yi", prompt_str, answer_str, use_fast)
-
-
-def test_parse_template():
-    tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA, token=HF_TOKEN)
+def test_parse_llama3_template():
+    tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3, token=HF_TOKEN)
     template = parse_template(tokenizer)
     assert template.format_user.slots == [
         "<|start_header_id|>user<|end_header_id|>\n\n{{content}}<|eot_id|>"
@@ -223,12 +223,11 @@ def test_parse_template():
     assert template.default_system == ""
 
 
 @pytest.mark.skipif(not HF_TOKEN, reason="Gated model.")
 def test_parse_qwen_template():
-    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-7B-Instruct", token=HF_TOKEN)
+    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct", token=HF_TOKEN)
     template = parse_template(tokenizer)
     assert template.format_user.slots == ["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]
     assert template.format_assistant.slots == ["{{content}}<|im_end|>\n"]
     assert template.format_system.slots == ["<|im_start|>system\n{{content}}<|im_end|>\n"]
     assert template.format_prefix.slots == []
-    assert template.default_system == "You are a helpful assistant."
+    assert template.default_system == "You are Qwen, created by Alibaba Cloud. You are a helpful assistant."
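For orientation, a minimal sketch of the template API these tests exercise after the rename to TINY_LLAMA3: load the tiny Llama-3 tokenizer, fix it with the llama3 template, and encode one turn. This is not part of the commit; the import paths (llamafactory.data, llamafactory.hparams) and the abbreviated MESSAGES below are assumptions based on the names shown in the diff.

import os

from transformers import AutoTokenizer

from llamafactory.data import get_template_and_fix_tokenizer  # assumed import path
from llamafactory.hparams import DataArguments  # assumed import path

TINY_LLAMA3 = os.getenv("TINY_LLAMA3", "llamafactory/tiny-random-Llama-3")

# One user/assistant turn; the test file's MESSAGES also carries a second turn
# ("你好" / "很高兴认识你!"), abbreviated here.
MESSAGES = [
    {"role": "user", "content": "How are you"},
    {"role": "assistant", "content": "I am fine!"},
]

tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3)
template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template="llama3"))

# encode_oneturn returns token ids for the prompt (everything before the final
# assistant reply) and for the answer; the tests compare their decoded forms
# against hand-written llama3-formatted strings.
prompt_ids, answer_ids = template.encode_oneturn(tokenizer, MESSAGES)
print(tokenizer.decode(prompt_ids))
print(tokenizer.decode(answer_ids))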
||||