[data] optimize qwen3 loss computation (#7923)

This commit is contained in:
hoshi-hiyouga
2025-04-30 16:18:00 +08:00
committed by GitHub
parent 73198a6645
commit 052ca871bd
11 changed files with 205 additions and 39 deletions

View File

@@ -125,6 +125,37 @@ def test_encode_multiturn(use_fast: bool):
)
@pytest.mark.parametrize("use_fast", [True, False])
def test_reasoning_encode_oneturn(use_fast: bool):
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B", use_fast=use_fast)
template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template="qwen3"))
prompt_ids, answer_ids = template.encode_oneturn(tokenizer, MESSAGES)
prompt_str = (
"<|im_start|>user\nHow are you<|im_end|>\n"
"<|im_start|>assistant\nI am fine!<|im_end|>\n"
"<|im_start|>user\n你好<|im_end|>\n"
"<|im_start|>assistant\n<think>\n\n</think>\n\n"
)
answer_str = "很高兴认识你!<|im_end|>\n"
_check_tokenization(tokenizer, (prompt_ids, answer_ids), (prompt_str, answer_str))
@pytest.mark.parametrize("use_fast", [True, False])
def test_reasoning_encode_multiturn(use_fast: bool):
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B", use_fast=use_fast)
template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template="qwen3"))
encoded_pairs = template.encode_multiturn(tokenizer, MESSAGES)
prompt_str_1 = "<|im_start|>user\nHow are you<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
answer_str_1 = "I am fine!<|im_end|>\n"
prompt_str_2 = "<|im_start|>user\n你好<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
answer_str_2 = "很高兴认识你!<|im_end|>\n"
_check_tokenization(
tokenizer,
(encoded_pairs[0][0], encoded_pairs[0][1], encoded_pairs[1][0], encoded_pairs[1][1]),
(prompt_str_1, answer_str_1, prompt_str_2, answer_str_2),
)
@pytest.mark.parametrize("use_fast", [True, False])
def test_jinja_template(use_fast: bool):
tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3, use_fast=use_fast)
@@ -227,6 +258,15 @@ def test_qwen2_5_template(use_fast: bool):
@pytest.mark.parametrize("use_fast", [True, False])
def test_qwen3_template(use_fast: bool):
prompt_str = (
"<|im_start|>user\nHow are you<|im_end|>\n"
"<|im_start|>assistant\nI am fine!<|im_end|>\n"
"<|im_start|>user\n你好<|im_end|>\n"
"<|im_start|>assistant\n<think>\n\n</think>\n\n"
)
answer_str = "很高兴认识你!<|im_end|>\n"
_check_template("Qwen/Qwen3-8B", "qwen3", prompt_str, answer_str, use_fast)
prompt_str = (
"<|im_start|>user\nHow are you<|im_end|>\n"
"<|im_start|>assistant\nI am fine!<|im_end|>\n"