[data] optimize qwen3 loss computation (#7923)
This commit is contained in:
@@ -125,6 +125,37 @@ def test_encode_multiturn(use_fast: bool):
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_fast", [True, False])
|
||||
def test_reasoning_encode_oneturn(use_fast: bool):
|
||||
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B", use_fast=use_fast)
|
||||
template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template="qwen3"))
|
||||
prompt_ids, answer_ids = template.encode_oneturn(tokenizer, MESSAGES)
|
||||
prompt_str = (
|
||||
"<|im_start|>user\nHow are you<|im_end|>\n"
|
||||
"<|im_start|>assistant\nI am fine!<|im_end|>\n"
|
||||
"<|im_start|>user\n你好<|im_end|>\n"
|
||||
"<|im_start|>assistant\n<think>\n\n</think>\n\n"
|
||||
)
|
||||
answer_str = "很高兴认识你!<|im_end|>\n"
|
||||
_check_tokenization(tokenizer, (prompt_ids, answer_ids), (prompt_str, answer_str))
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_fast", [True, False])
|
||||
def test_reasoning_encode_multiturn(use_fast: bool):
|
||||
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B", use_fast=use_fast)
|
||||
template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template="qwen3"))
|
||||
encoded_pairs = template.encode_multiturn(tokenizer, MESSAGES)
|
||||
prompt_str_1 = "<|im_start|>user\nHow are you<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
|
||||
answer_str_1 = "I am fine!<|im_end|>\n"
|
||||
prompt_str_2 = "<|im_start|>user\n你好<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
|
||||
answer_str_2 = "很高兴认识你!<|im_end|>\n"
|
||||
_check_tokenization(
|
||||
tokenizer,
|
||||
(encoded_pairs[0][0], encoded_pairs[0][1], encoded_pairs[1][0], encoded_pairs[1][1]),
|
||||
(prompt_str_1, answer_str_1, prompt_str_2, answer_str_2),
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_fast", [True, False])
|
||||
def test_jinja_template(use_fast: bool):
|
||||
tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3, use_fast=use_fast)
|
||||
@@ -227,6 +258,15 @@ def test_qwen2_5_template(use_fast: bool):
|
||||
|
||||
@pytest.mark.parametrize("use_fast", [True, False])
|
||||
def test_qwen3_template(use_fast: bool):
|
||||
prompt_str = (
|
||||
"<|im_start|>user\nHow are you<|im_end|>\n"
|
||||
"<|im_start|>assistant\nI am fine!<|im_end|>\n"
|
||||
"<|im_start|>user\n你好<|im_end|>\n"
|
||||
"<|im_start|>assistant\n<think>\n\n</think>\n\n"
|
||||
)
|
||||
answer_str = "很高兴认识你!<|im_end|>\n"
|
||||
_check_template("Qwen/Qwen3-8B", "qwen3", prompt_str, answer_str, use_fast)
|
||||
|
||||
prompt_str = (
|
||||
"<|im_start|>user\nHow are you<|im_end|>\n"
|
||||
"<|im_start|>assistant\nI am fine!<|im_end|>\n"
|
||||
|
||||
Reference in New Issue
Block a user