[data] qwen3 fixes (#8109)
This commit is contained in:
@@ -126,29 +126,50 @@ def test_encode_multiturn(use_fast: bool):
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_fast", [True, False])
|
||||
def test_reasoning_encode_oneturn(use_fast: bool):
|
||||
@pytest.mark.parametrize("cot_messages", [True, False])
|
||||
@pytest.mark.parametrize("enable_thinking", [True, False])
|
||||
def test_reasoning_encode_oneturn(use_fast: bool, cot_messages: bool, enable_thinking: bool):
|
||||
messages = MESSAGES_WITH_THOUGHT if cot_messages else MESSAGES
|
||||
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B", use_fast=use_fast)
|
||||
template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template="qwen3"))
|
||||
prompt_ids, answer_ids = template.encode_oneturn(tokenizer, MESSAGES)
|
||||
data_args = DataArguments(template="qwen3", enable_thinking=enable_thinking)
|
||||
template = get_template_and_fix_tokenizer(tokenizer, data_args)
|
||||
prompt_ids, answer_ids = template.encode_oneturn(tokenizer, messages)
|
||||
prompt_str = (
|
||||
"<|im_start|>user\nHow are you<|im_end|>\n"
|
||||
"<|im_start|>assistant\nI am fine!<|im_end|>\n"
|
||||
"<|im_start|>user\n你好<|im_end|>\n"
|
||||
"<|im_start|>assistant\n<think>\n\n</think>\n\n"
|
||||
f"<|im_start|>user\n{messages[0]['content']}<|im_end|>\n<|im_start|>assistant\n"
|
||||
f"{MESSAGES[1]['content']}<|im_end|>\n"
|
||||
f"<|im_start|>user\n{messages[2]['content']}<|im_end|>\n<|im_start|>assistant\n"
|
||||
)
|
||||
answer_str = "很高兴认识你!<|im_end|>\n"
|
||||
answer_str = f"{messages[3]['content']}<|im_end|>\n"
|
||||
if not cot_messages:
|
||||
if enable_thinking:
|
||||
answer_str = "<think>\n\n</think>\n\n" + answer_str
|
||||
else:
|
||||
prompt_str = prompt_str + "<think>\n\n</think>\n\n"
|
||||
|
||||
_check_tokenization(tokenizer, (prompt_ids, answer_ids), (prompt_str, answer_str))
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_fast", [True, False])
|
||||
def test_reasoning_encode_multiturn(use_fast: bool):
|
||||
@pytest.mark.parametrize("cot_messages", [True, False])
|
||||
@pytest.mark.parametrize("enable_thinking", [True, False])
|
||||
def test_reasoning_encode_multiturn(use_fast: bool, cot_messages: bool, enable_thinking: bool):
|
||||
messages = MESSAGES_WITH_THOUGHT if cot_messages else MESSAGES
|
||||
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B", use_fast=use_fast)
|
||||
template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template="qwen3"))
|
||||
encoded_pairs = template.encode_multiturn(tokenizer, MESSAGES)
|
||||
prompt_str_1 = "<|im_start|>user\nHow are you<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
|
||||
answer_str_1 = "I am fine!<|im_end|>\n"
|
||||
prompt_str_2 = "<|im_start|>user\n你好<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
|
||||
answer_str_2 = "很高兴认识你!<|im_end|>\n"
|
||||
data_args = DataArguments(template="qwen3", enable_thinking=enable_thinking)
|
||||
template = get_template_and_fix_tokenizer(tokenizer, data_args)
|
||||
encoded_pairs = template.encode_multiturn(tokenizer, messages)
|
||||
prompt_str_1 = f"<|im_start|>user\n{messages[0]['content']}<|im_end|>\n<|im_start|>assistant\n"
|
||||
answer_str_1 = f"{messages[1]['content']}<|im_end|>\n"
|
||||
prompt_str_2 = f"<|im_start|>user\n{messages[2]['content']}<|im_end|>\n<|im_start|>assistant\n"
|
||||
answer_str_2 = f"{messages[3]['content']}<|im_end|>\n"
|
||||
if not cot_messages:
|
||||
if enable_thinking:
|
||||
answer_str_1 = "<think>\n\n</think>\n\n" + answer_str_1
|
||||
answer_str_2 = "<think>\n\n</think>\n\n" + answer_str_2
|
||||
else:
|
||||
prompt_str_1 = prompt_str_1 + "<think>\n\n</think>\n\n"
|
||||
prompt_str_2 = prompt_str_2 + "<think>\n\n</think>\n\n"
|
||||
|
||||
_check_tokenization(
|
||||
tokenizer,
|
||||
(encoded_pairs[0][0], encoded_pairs[0][1], encoded_pairs[1][0], encoded_pairs[1][1]),
|
||||
@@ -193,12 +214,12 @@ def test_get_stop_token_ids():
|
||||
@pytest.mark.parametrize("use_fast", [True, False])
|
||||
def test_gemma_template(use_fast: bool):
|
||||
prompt_str = (
|
||||
"<bos><start_of_turn>user\nHow are you<end_of_turn>\n"
|
||||
"<start_of_turn>model\nI am fine!<end_of_turn>\n"
|
||||
"<start_of_turn>user\n你好<end_of_turn>\n"
|
||||
f"<bos><start_of_turn>user\n{MESSAGES[0]['content']}<end_of_turn>\n"
|
||||
f"<start_of_turn>model\n{MESSAGES[1]['content']}<end_of_turn>\n"
|
||||
f"<start_of_turn>user\n{MESSAGES[2]['content']}<end_of_turn>\n"
|
||||
"<start_of_turn>model\n"
|
||||
)
|
||||
answer_str = "很高兴认识你!<end_of_turn>\n"
|
||||
answer_str = f"{MESSAGES[3]['content']}<end_of_turn>\n"
|
||||
_check_template("google/gemma-3-4b-it", "gemma", prompt_str, answer_str, use_fast)
|
||||
|
||||
|
||||
@@ -206,12 +227,12 @@ def test_gemma_template(use_fast: bool):
|
||||
@pytest.mark.parametrize("use_fast", [True, False])
|
||||
def test_llama3_template(use_fast: bool):
|
||||
prompt_str = (
|
||||
"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nHow are you<|eot_id|>"
|
||||
"<|start_header_id|>assistant<|end_header_id|>\n\nI am fine!<|eot_id|>"
|
||||
"<|start_header_id|>user<|end_header_id|>\n\n你好<|eot_id|>"
|
||||
f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{MESSAGES[0]['content']}<|eot_id|>"
|
||||
f"<|start_header_id|>assistant<|end_header_id|>\n\n{MESSAGES[1]['content']}<|eot_id|>"
|
||||
f"<|start_header_id|>user<|end_header_id|>\n\n{MESSAGES[2]['content']}<|eot_id|>"
|
||||
"<|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||
)
|
||||
answer_str = "很高兴认识你!<|eot_id|>"
|
||||
answer_str = f"{MESSAGES[3]['content']}<|eot_id|>"
|
||||
_check_template("meta-llama/Meta-Llama-3-8B-Instruct", "llama3", prompt_str, answer_str, use_fast)
|
||||
|
||||
|
||||
@@ -220,12 +241,12 @@ def test_llama3_template(use_fast: bool):
|
||||
)
|
||||
def test_llama4_template(use_fast: bool):
|
||||
prompt_str = (
|
||||
"<|begin_of_text|><|header_start|>user<|header_end|>\n\nHow are you<|eot|>"
|
||||
"<|header_start|>assistant<|header_end|>\n\nI am fine!<|eot|>"
|
||||
"<|header_start|>user<|header_end|>\n\n你好<|eot|>"
|
||||
f"<|begin_of_text|><|header_start|>user<|header_end|>\n\n{MESSAGES[0]['content']}<|eot|>"
|
||||
f"<|header_start|>assistant<|header_end|>\n\n{MESSAGES[1]['content']}<|eot|>"
|
||||
f"<|header_start|>user<|header_end|>\n\n{MESSAGES[2]['content']}<|eot|>"
|
||||
"<|header_start|>assistant<|header_end|>\n\n"
|
||||
)
|
||||
answer_str = "很高兴认识你!<|eot|>"
|
||||
answer_str = f"{MESSAGES[3]['content']}<|eot|>"
|
||||
_check_template(TINY_LLAMA4, "llama4", prompt_str, answer_str, use_fast)
|
||||
|
||||
|
||||
@@ -234,12 +255,12 @@ def test_llama4_template(use_fast: bool):
|
||||
)
|
||||
def test_phi4_template(use_fast: bool):
|
||||
prompt_str = (
|
||||
"<|im_start|>user<|im_sep|>How are you<|im_end|>"
|
||||
"<|im_start|>assistant<|im_sep|>I am fine!<|im_end|>"
|
||||
"<|im_start|>user<|im_sep|>你好<|im_end|>"
|
||||
f"<|im_start|>user<|im_sep|>{MESSAGES[0]['content']}<|im_end|>"
|
||||
f"<|im_start|>assistant<|im_sep|>{MESSAGES[1]['content']}<|im_end|>"
|
||||
f"<|im_start|>user<|im_sep|>{MESSAGES[2]['content']}<|im_end|>"
|
||||
"<|im_start|>assistant<|im_sep|>"
|
||||
)
|
||||
answer_str = "很高兴认识你!<|im_end|>"
|
||||
answer_str = f"{MESSAGES[3]['content']}<|im_end|>"
|
||||
_check_template("microsoft/phi-4", "phi4", prompt_str, answer_str, use_fast)
|
||||
|
||||
|
||||
@@ -247,34 +268,30 @@ def test_phi4_template(use_fast: bool):
|
||||
def test_qwen2_5_template(use_fast: bool):
|
||||
prompt_str = (
|
||||
"<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n"
|
||||
"<|im_start|>user\nHow are you<|im_end|>\n"
|
||||
"<|im_start|>assistant\nI am fine!<|im_end|>\n"
|
||||
"<|im_start|>user\n你好<|im_end|>\n"
|
||||
f"<|im_start|>user\n{MESSAGES[0]['content']}<|im_end|>\n"
|
||||
f"<|im_start|>assistant\n{MESSAGES[1]['content']}<|im_end|>\n"
|
||||
f"<|im_start|>user\n{MESSAGES[2]['content']}<|im_end|>\n"
|
||||
"<|im_start|>assistant\n"
|
||||
)
|
||||
answer_str = "很高兴认识你!<|im_end|>\n"
|
||||
answer_str = f"{MESSAGES[3]['content']}<|im_end|>\n"
|
||||
_check_template("Qwen/Qwen2.5-7B-Instruct", "qwen", prompt_str, answer_str, use_fast)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_fast", [True, False])
|
||||
def test_qwen3_template(use_fast: bool):
|
||||
@pytest.mark.parametrize("cot_messages", [True, False])
|
||||
def test_qwen3_template(use_fast: bool, cot_messages: bool):
|
||||
messages = MESSAGES_WITH_THOUGHT if cot_messages else MESSAGES
|
||||
prompt_str = (
|
||||
"<|im_start|>user\nHow are you<|im_end|>\n"
|
||||
"<|im_start|>assistant\nI am fine!<|im_end|>\n"
|
||||
"<|im_start|>user\n你好<|im_end|>\n"
|
||||
"<|im_start|>assistant\n<think>\n\n</think>\n\n"
|
||||
)
|
||||
answer_str = "很高兴认识你!<|im_end|>\n"
|
||||
_check_template("Qwen/Qwen3-8B", "qwen3", prompt_str, answer_str, use_fast)
|
||||
|
||||
prompt_str = (
|
||||
"<|im_start|>user\nHow are you<|im_end|>\n"
|
||||
"<|im_start|>assistant\nI am fine!<|im_end|>\n"
|
||||
"<|im_start|>user\n你好<|im_end|>\n"
|
||||
f"<|im_start|>user\n{messages[0]['content']}<|im_end|>\n"
|
||||
f"<|im_start|>assistant\n{MESSAGES[1]['content']}<|im_end|>\n"
|
||||
f"<|im_start|>user\n{messages[2]['content']}<|im_end|>\n"
|
||||
"<|im_start|>assistant\n"
|
||||
)
|
||||
answer_str = "<think>\n模型思考内容\n</think>\n\n很高兴认识你!<|im_end|>\n"
|
||||
_check_template("Qwen/Qwen3-8B", "qwen3", prompt_str, answer_str, use_fast, messages=MESSAGES_WITH_THOUGHT)
|
||||
answer_str = f"{messages[3]['content']}<|im_end|>\n"
|
||||
if not cot_messages:
|
||||
answer_str = "<think>\n\n</think>\n\n" + answer_str
|
||||
|
||||
_check_template("Qwen/Qwen3-8B", "qwen3", prompt_str, answer_str, use_fast, messages=messages)
|
||||
|
||||
|
||||
def test_parse_llama3_template():
|
||||
@@ -293,6 +310,7 @@ def test_parse_llama3_template():
|
||||
def test_parse_qwen_template():
|
||||
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct", token=HF_TOKEN)
|
||||
template = parse_template(tokenizer)
|
||||
assert template.__class__.__name__ == "Template"
|
||||
assert template.format_user.slots == ["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]
|
||||
assert template.format_assistant.slots == ["{{content}}<|im_end|>\n"]
|
||||
assert template.format_system.slots == ["<|im_start|>system\n{{content}}<|im_end|>\n"]
|
||||
@@ -303,6 +321,7 @@ def test_parse_qwen_template():
|
||||
def test_parse_qwen3_template():
|
||||
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B", token=HF_TOKEN)
|
||||
template = parse_template(tokenizer)
|
||||
assert template.__class__.__name__ == "ReasoningTemplate"
|
||||
assert template.format_user.slots == ["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]
|
||||
assert template.format_assistant.slots == ["{{content}}<|im_end|>\n"]
|
||||
assert template.format_system.slots == ["<|im_start|>system\n{{content}}<|im_end|>\n"]
|
||||
|
||||
Reference in New Issue
Block a user