[deps] adapt to transformers v5 (#10147)
Co-authored-by: frozenleaves <frozen@Mac.local>
Co-authored-by: hiyouga <hiyouga@buaa.edu.cn>
@@ -23,6 +23,13 @@ from llamafactory.v1.core.utils.rendering import Renderer
 from llamafactory.v1.utils.types import Processor
 
 
+def _get_input_ids(inputs: list | dict) -> list:
+    if not isinstance(inputs, list):
+        return inputs["input_ids"]
+    else:
+        return inputs
+
+
 HF_MESSAGES = [
     {"role": "system", "content": "You are a helpful assistant."},
     {"role": "user", "content": "What is LLM?"},
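
The new helper exists because the return shape of `apply_chat_template` changed across transformers releases. A minimal sketch of the two shapes it normalizes, assuming transformers v5 returns a dict-like mapping carrying an "input_ids" key by default while v4 returned a bare list of token ids (the token values below are made up for illustration):

    def _get_input_ids(inputs: list | dict) -> list:
        # A dict-like result (e.g. a BatchEncoding) keeps the ids under "input_ids";
        # a bare list already is the ids.
        if not isinstance(inputs, list):
            return inputs["input_ids"]
        else:
            return inputs

    legacy_output = [151644, 8948, 198]  # v4-style: bare list of token ids
    v5_output = {"input_ids": [151644, 8948, 198], "attention_mask": [1, 1, 1]}  # v5-style mapping
    assert _get_input_ids(legacy_output) == _get_input_ids(v5_output) == [151644, 8948, 198]
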
@@ -81,15 +88,15 @@ def test_chatml_rendering():
     tokenizer: Processor = AutoTokenizer.from_pretrained("llamafactory/tiny-random-qwen3")
     renderer = Renderer(template="chatml", processor=tokenizer)
 
-    hf_inputs = tokenizer.apply_chat_template(HF_MESSAGES[:-1], add_generation_prompt=True)
+    hf_inputs = _get_input_ids(tokenizer.apply_chat_template(HF_MESSAGES[:-1], add_generation_prompt=True))
     v1_inputs = renderer.render_messages(V1_MESSAGES[:-1], is_generate=True)
     assert v1_inputs["input_ids"] == hf_inputs
     assert v1_inputs["attention_mask"] == [1] * len(hf_inputs)
     assert v1_inputs["labels"] == [-100] * len(hf_inputs)
     assert v1_inputs["loss_weights"] == [0.0] * len(hf_inputs)
 
-    hf_inputs_part = tokenizer.apply_chat_template(HF_MESSAGES[:-1], add_generation_prompt=False)
-    hf_inputs_full = tokenizer.apply_chat_template(HF_MESSAGES, add_generation_prompt=False)
+    hf_inputs_part = _get_input_ids(tokenizer.apply_chat_template(HF_MESSAGES[:-1], add_generation_prompt=False))
+    hf_inputs_full = _get_input_ids(tokenizer.apply_chat_template(HF_MESSAGES, add_generation_prompt=False))
     v1_inputs_full = renderer.render_messages(V1_MESSAGES, is_generate=False)
     assert v1_inputs_full["input_ids"] == hf_inputs_full
     assert v1_inputs_full["attention_mask"] == [1] * len(hf_inputs_full)
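
An alternative to wrapping every call would have been to pin the return shape at each call site. A sketch assuming `apply_chat_template` accepts a `return_dict` keyword (it does in recent transformers releases, though whether its default flipped in v5 is exactly the kind of drift the helper guards against):

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("llamafactory/tiny-random-qwen3")
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "What is LLM?"},
    ]

    # Request the dict shape explicitly so the result is stable across versions,
    # then read the ids off the mapping.
    encoded = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_dict=True)
    input_ids = encoded["input_ids"]

Normalizing in one helper, as this commit does, keeps the test bodies unchanged on either side of the version boundary.
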
@@ -124,17 +131,21 @@ def test_qwen3_nothink_rendering():
     tokenizer: Processor = AutoTokenizer.from_pretrained("Qwen/Qwen3-4B-Instruct-2507")
     renderer = Renderer(template="qwen3_nothink", processor=tokenizer)
 
-    hf_inputs = tokenizer.apply_chat_template(HF_MESSAGES_WITH_TOOLS[:-1], tools=V1_TOOLS, add_generation_prompt=True)
+    hf_inputs = _get_input_ids(
+        tokenizer.apply_chat_template(HF_MESSAGES_WITH_TOOLS[:-1], tools=V1_TOOLS, add_generation_prompt=True)
+    )
     v1_inputs = renderer.render_messages(V1_MESSAGES_WITH_TOOLS[:-1], tools=json.dumps(V1_TOOLS), is_generate=True)
     assert v1_inputs["input_ids"] == hf_inputs
     assert v1_inputs["attention_mask"] == [1] * len(hf_inputs)
     assert v1_inputs["labels"] == [-100] * len(hf_inputs)
     assert v1_inputs["loss_weights"] == [0.0] * len(hf_inputs)
 
-    hf_inputs_part = tokenizer.apply_chat_template(
-        HF_MESSAGES_WITH_TOOLS[:-1], tools=V1_TOOLS, add_generation_prompt=False
-    )
-    hf_inputs_full = tokenizer.apply_chat_template(HF_MESSAGES_WITH_TOOLS, tools=V1_TOOLS, add_generation_prompt=False)
+    hf_inputs_part = _get_input_ids(
+        tokenizer.apply_chat_template(HF_MESSAGES_WITH_TOOLS[:-1], tools=V1_TOOLS, add_generation_prompt=False)
+    )
+    hf_inputs_full = _get_input_ids(
+        tokenizer.apply_chat_template(HF_MESSAGES_WITH_TOOLS, tools=V1_TOOLS, add_generation_prompt=False)
+    )
     v1_inputs_full = renderer.render_messages(V1_MESSAGES_WITH_TOOLS, tools=json.dumps(V1_TOOLS), is_generate=False)
     assert v1_inputs_full["input_ids"] == hf_inputs_full
     assert v1_inputs_full["attention_mask"] == [1] * len(hf_inputs_full)
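
Why each test renders both a partial (prompt-only) and a full conversation is not visible in this hunk; the elided assertions presumably check that labels are masked over the prompt and real over the response. A toy sketch of that pattern, with stand-in values rather than the actual fixtures:

    # Toy stand-ins for the test's hf_inputs_part / hf_inputs_full.
    hf_inputs_part = [1, 2, 3]           # prompt-only rendering
    hf_inputs_full = [1, 2, 3, 7, 8, 9]  # prompt plus assistant response

    # The partial rendering should be a strict prefix of the full one,
    # and labels should be -100 over exactly that prefix.
    prompt_len = len(hf_inputs_part)
    assert hf_inputs_full[:prompt_len] == hf_inputs_part
    expected_labels = [-100] * prompt_len + hf_inputs_full[prompt_len:]
    assert expected_labels == [-100, -100, -100, 7, 8, 9]
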
@@ -187,7 +198,7 @@ def test_qwen3_nothink_rendering_remote(num_samples: int):
 def test_process_sft_samples():
     tokenizer: Processor = AutoTokenizer.from_pretrained("llamafactory/tiny-random-qwen3")
     renderer = Renderer(template="chatml", processor=tokenizer)
-    hf_inputs = tokenizer.apply_chat_template(HF_MESSAGES)
+    hf_inputs = _get_input_ids(tokenizer.apply_chat_template(HF_MESSAGES))
 
     samples = [{"messages": V1_MESSAGES, "extra_info": "test", "_dataset_name": "default"}]
     model_inputs = renderer.process_samples(samples)
@@ -200,7 +211,7 @@ def test_process_sft_samples():
 def test_process_dpo_samples():
     tokenizer: Processor = AutoTokenizer.from_pretrained("llamafactory/tiny-random-qwen3")
     renderer = Renderer(template="chatml", processor=tokenizer)
-    hf_inputs = tokenizer.apply_chat_template(HF_MESSAGES)
+    hf_inputs = _get_input_ids(tokenizer.apply_chat_template(HF_MESSAGES))
 
     samples = [
         {
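
Since `_get_input_ids` now sits on every version-sensitive call path, a direct unit test of it would be cheap insurance. A hedged sketch, not part of this commit; the helper is copied in so the snippet runs standalone:

    import pytest

    def _get_input_ids(inputs: list | dict) -> list:  # copied from the first hunk
        if not isinstance(inputs, list):
            return inputs["input_ids"]
        else:
            return inputs

    @pytest.mark.parametrize(
        "raw",
        [
            [1, 2, 3],                                              # v4-style bare list
            {"input_ids": [1, 2, 3], "attention_mask": [1, 1, 1]},  # v5-style mapping
        ],
    )
    def test_get_input_ids_normalizes_both_shapes(raw):
        assert _get_input_ids(raw) == [1, 2, 3]
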