[deps] adapt to transformers v5 (#10147)

Co-authored-by: frozenleaves <frozen@Mac.local>
Co-authored-by: hiyouga <hiyouga@buaa.edu.cn>
Author: 浮梦
Committed: 2026-02-02 12:07:19 +08:00 (committed by GitHub)
Parent: 762b480131
Commit: bf04ca6af8

23 changed files with 149 additions and 120 deletions

View File

@@ -40,10 +40,10 @@ dependencies = [
     "torch>=2.4.0",
     "torchvision>=0.19.0",
     "torchaudio>=2.4.0",
-    "transformers>=4.51.0,<=4.57.1,!=4.52.0,!=4.57.0",
+    "transformers>=4.51.0,<=5.0.0,!=4.52.0,!=4.57.0",
     "datasets>=2.16.0,<=4.0.0",
     "accelerate>=1.3.0,<=1.11.0",
-    "peft>=0.14.0,<=0.17.1",
+    "peft>=0.18.0,<=0.18.1",
     "trl>=0.18.0,<=0.24.0",
     "torchdata>=0.10.0,<=0.11.0",
     # gui

View File

@@ -1 +1 @@
-deepspeed>=0.10.0,<=0.16.9
+deepspeed>=0.10.0,<=0.18.4

View File

@@ -94,10 +94,10 @@ def check_version(requirement: str, mandatory: bool = False) -> None:
 def check_dependencies() -> None:
     r"""Check the version of the required packages."""
-    check_version("transformers>=4.51.0,<=4.57.1")
+    check_version("transformers>=4.51.0,<=5.0.0")
     check_version("datasets>=2.16.0,<=4.0.0")
     check_version("accelerate>=1.3.0,<=1.11.0")
-    check_version("peft>=0.14.0,<=0.17.1")
+    check_version("peft>=0.18.0,<=0.18.1")
     check_version("trl>=0.18.0,<=0.24.0")

View File

@@ -65,7 +65,9 @@ class DataArguments:
     )
     mix_strategy: Literal["concat", "interleave_under", "interleave_over", "interleave_once"] = field(
         default="concat",
-        metadata={"help": "Strategy to use in dataset mixing (concat/interleave) (undersampling/oversampling/sampling w.o. replacement)."},
+        metadata={
+            "help": "Strategy to use in dataset mixing (concat/interleave) (undersampling/oversampling/sampling w.o. replacement)."
+        },
     )
     interleave_probs: str | None = field(
         default=None,

View File

@@ -206,9 +206,6 @@ class BaseModelArguments:
         if self.model_name_or_path is None:
             raise ValueError("Please provide `model_name_or_path`.")
 
-        if self.split_special_tokens and self.use_fast_tokenizer:
-            raise ValueError("`split_special_tokens` is only supported for slow tokenizers.")
-
         if self.adapter_name_or_path is not None:  # support merging multiple lora weights
             self.adapter_name_or_path = [path.strip() for path in self.adapter_name_or_path.split(",")]

View File

@@ -139,10 +139,6 @@ def _verify_model_args(
         if model_args.adapter_name_or_path is not None and len(model_args.adapter_name_or_path) != 1:
             raise ValueError("Quantized model only accepts a single adapter. Merge them first.")
 
-    if data_args.template == "yi" and model_args.use_fast_tokenizer:
-        logger.warning_rank0("We should use slow tokenizer for the Yi models. Change `use_fast_tokenizer` to False.")
-        model_args.use_fast_tokenizer = False
-
 
 def _check_extra_dependencies(
     model_args: "ModelArguments",
@@ -188,9 +184,7 @@ def _check_extra_dependencies(
     if training_args is not None:
         if training_args.deepspeed:
-            # pin deepspeed version < 0.17 because of https://github.com/deepspeedai/DeepSpeed/issues/7347
             check_version("deepspeed", mandatory=True)
-            check_version("deepspeed>=0.10.0,<=0.16.9")
 
         if training_args.predict_with_generate:
             check_version("jieba", mandatory=True)

View File

@@ -22,7 +22,6 @@ from transformers import (
     AutoModelForImageTextToText,
     AutoModelForSeq2SeqLM,
     AutoModelForTextToWaveform,
-    AutoModelForVision2Seq,
     AutoProcessor,
     AutoTokenizer,
 )
@@ -166,11 +165,9 @@ def load_model(
         else:
             if type(config) in AutoModelForImageTextToText._model_mapping.keys():  # image-text
                 load_class = AutoModelForImageTextToText
-            elif type(config) in AutoModelForVision2Seq._model_mapping.keys():  # image-text
-                load_class = AutoModelForVision2Seq
            elif type(config) in AutoModelForSeq2SeqLM._model_mapping.keys():  # audio-text
                 load_class = AutoModelForSeq2SeqLM
-            elif type(config) in AutoModelForTextToWaveform._model_mapping.keys():  # audio hack for qwen omni
+            elif type(config) in AutoModelForTextToWaveform._model_mapping.keys():  # audio-text for qwen omni
                 load_class = AutoModelForTextToWaveform
             else:
                 load_class = AutoModelForCausalLM
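
Note: the AutoModelForVision2Seq branch is dropped here; AutoModelForImageTextToText already covers the image-text mappings, and the older auto class appears to be gone from transformers v5. Below is a minimal sketch of the same dispatch outside load_model; it reuses the private _model_mapping attribute that the loader already relies on, so it could break if transformers changes that internal, and the model id is merely one that appears elsewhere in this test suite.

# Sketch of the auto-class dispatch under the assumptions stated above.
import torch
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoModelForImageTextToText,
    AutoModelForSeq2SeqLM,
)

model_id = "llamafactory/tiny-random-qwen3"  # any checkpoint id; this one appears in the test suite
config = AutoConfig.from_pretrained(model_id)

if type(config) in AutoModelForImageTextToText._model_mapping.keys():  # image-text (covers former Vision2Seq models)
    load_class = AutoModelForImageTextToText
elif type(config) in AutoModelForSeq2SeqLM._model_mapping.keys():  # audio-text / seq2seq
    load_class = AutoModelForSeq2SeqLM
else:
    load_class = AutoModelForCausalLM  # text-only fallback, e.g. the tiny Qwen3 above

with torch.device("meta"):  # build the architecture without allocating real weights
    model = load_class.from_config(config)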

View File

@@ -374,7 +374,13 @@ _register_composite_model(
 _register_composite_model(
     model_type="qwen3_omni_moe_thinker",
     projector_key="visual.merger",
-    vision_model_keys=["visual.pos_embed", "visual.patch_embed", "visual.blocks", "visual.deepstack_merger_list", "audio_tower"],
+    vision_model_keys=[
+        "visual.pos_embed",
+        "visual.patch_embed",
+        "visual.blocks",
+        "visual.deepstack_merger_list",
+        "audio_tower",
+    ],
     language_model_keys=["model", "lm_head"],
     lora_conflict_keys=["patch_embed"],
 )

View File

@@ -103,7 +103,9 @@ class FixValueHeadModelCallback(TrainerCallback):
         if args.should_save:
             output_dir = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}")
             fix_valuehead_checkpoint(
-                model=kwargs.pop("model"), output_dir=output_dir, safe_serialization=args.save_safetensors
+                model=kwargs.pop("model"),
+                output_dir=output_dir,
+                safe_serialization=getattr(args, "save_safetensors", True),
             )
@@ -137,7 +139,7 @@ class PissaConvertCallback(TrainerCallback):
             if isinstance(model, PeftModel):
                 init_lora_weights = getattr(model.peft_config["default"], "init_lora_weights")
                 setattr(model.peft_config["default"], "init_lora_weights", True)
-                model.save_pretrained(pissa_init_dir, safe_serialization=args.save_safetensors)
+                model.save_pretrained(pissa_init_dir, safe_serialization=getattr(args, "save_safetensors", True))
                 setattr(model.peft_config["default"], "init_lora_weights", init_lora_weights)
 
     @override
@@ -155,11 +157,11 @@ class PissaConvertCallback(TrainerCallback):
             if isinstance(model, PeftModel):
                 init_lora_weights = getattr(model.peft_config["default"], "init_lora_weights")
                 setattr(model.peft_config["default"], "init_lora_weights", True)
-                model.save_pretrained(pissa_backup_dir, safe_serialization=args.save_safetensors)
+                model.save_pretrained(pissa_backup_dir, safe_serialization=getattr(args, "save_safetensors", True))
                 setattr(model.peft_config["default"], "init_lora_weights", init_lora_weights)
                 model.save_pretrained(
                     pissa_convert_dir,
-                    safe_serialization=args.save_safetensors,
+                    safe_serialization=getattr(args, "save_safetensors", True),
                     path_initial_model_for_weight_conversion=pissa_init_dir,
                 )
                 model.load_adapter(pissa_backup_dir, "default", is_trainable=True)
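
Note: the recurring getattr(args, "save_safetensors", True) reads guard against TrainingArguments not exposing save_safetensors in every transformers release, defaulting to safetensors output. A minimal sketch of that pattern as a standalone helper is shown below; resolve_safe_serialization and _ArgsWithoutField are hypothetical names, not part of the repository.

# Sketch of the defensive-read pattern used in the hunks above.
from typing import Any


def resolve_safe_serialization(args: Any) -> bool:
    # Fall back to safetensors when the arguments object lacks the field.
    return getattr(args, "save_safetensors", True)


class _ArgsWithoutField:  # stand-in for a TrainingArguments without `save_safetensors`
    pass


assert resolve_safe_serialization(_ArgsWithoutField()) is True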

View File

@@ -72,7 +72,7 @@ def run_ppo(
         ppo_trainer.ppo_train(resume_from_checkpoint=training_args.resume_from_checkpoint)
         ppo_trainer.save_model()
         if training_args.should_save:
-            fix_valuehead_checkpoint(model, training_args.output_dir, training_args.save_safetensors)
+            fix_valuehead_checkpoint(model, training_args.output_dir, getattr(training_args, "save_safetensors", True))
 
         ppo_trainer.save_state()  # must be called after save_model to have a folder
         if ppo_trainer.is_world_process_zero() and finetuning_args.plot_loss:

View File

@@ -114,7 +114,7 @@ class PairwiseTrainer(Trainer):
         if state_dict is None:
             state_dict = self.model.state_dict()
 
-        if self.args.save_safetensors:
+        if getattr(self.args, "save_safetensors", True):
             from collections import defaultdict
 
             ptrs = defaultdict(list)

View File

@@ -65,7 +65,7 @@ def run_rm(
         train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
         trainer.save_model()
         if training_args.should_save:
-            fix_valuehead_checkpoint(model, training_args.output_dir, training_args.save_safetensors)
+            fix_valuehead_checkpoint(model, training_args.output_dir, getattr(training_args, "save_safetensors", True))
 
         trainer.log_metrics("train", train_result.metrics)
         trainer.save_metrics("train", train_result.metrics)

View File

@@ -18,7 +18,6 @@ Contains shared fixtures, pytest configuration, and custom markers.
 """
 
 import os
-import sys
 
 import pytest
 import torch
@@ -149,14 +148,7 @@ def _manage_distributed_env(request: FixtureRequest, monkeypatch: MonkeyPatch) -
         devices_str = ",".join(str(i) for i in range(required))
         monkeypatch.setenv(env_key, devices_str)
-
-        # add project root dir to path for mp run
-        project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
-        if project_root not in sys.path:
-            sys.path.insert(0, project_root)
-        os.environ["PYTHONPATH"] = project_root + os.pathsep + os.environ.get("PYTHONPATH", "")
+        monkeypatch.syspath_prepend(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
     else:  # non-distributed test
         if old_value:
             visible_devices = [v for v in old_value.split(",") if v != ""]

View File

@@ -20,6 +20,7 @@ from datasets import load_dataset
 from transformers import AutoTokenizer
 
 from llamafactory.extras.constants import IGNORE_INDEX
+from llamafactory.extras.packages import is_transformers_version_greater_than
 from llamafactory.train.test_utils import load_dataset_module
@@ -52,7 +53,12 @@ def test_feedback_data(num_samples: int):
     for index in indexes:
         messages = original_data["messages"][index]
         ref_input_ids = ref_tokenizer.apply_chat_template(messages)
-        prompt_len = len(ref_tokenizer.apply_chat_template(messages[:-1], add_generation_prompt=True))
+        ref_prompt_ids = ref_tokenizer.apply_chat_template(messages[:-1], add_generation_prompt=True)
+        if is_transformers_version_greater_than("5.0.0"):
+            ref_input_ids = ref_input_ids["input_ids"]
+            ref_prompt_ids = ref_prompt_ids["input_ids"]
+
+        prompt_len = len(ref_prompt_ids)
         ref_labels = [IGNORE_INDEX] * prompt_len + ref_input_ids[prompt_len:]
         assert train_dataset["input_ids"][index] == ref_input_ids
         assert train_dataset["labels"][index] == ref_labels

View File

@@ -20,6 +20,7 @@ from datasets import load_dataset
 from transformers import AutoTokenizer
 
 from llamafactory.extras.constants import IGNORE_INDEX
+from llamafactory.extras.packages import is_transformers_version_greater_than
 from llamafactory.train.test_utils import load_dataset_module
@@ -63,13 +64,21 @@ def test_pairwise_data(num_samples: int):
         rejected_messages = original_data["conversations"][index] + [original_data["rejected"][index]]
         chosen_messages = _convert_sharegpt_to_openai(chosen_messages)
         rejected_messages = _convert_sharegpt_to_openai(rejected_messages)
         ref_chosen_input_ids = ref_tokenizer.apply_chat_template(chosen_messages)
-        chosen_prompt_len = len(ref_tokenizer.apply_chat_template(chosen_messages[:-1], add_generation_prompt=True))
-        ref_chosen_labels = [IGNORE_INDEX] * chosen_prompt_len + ref_chosen_input_ids[chosen_prompt_len:]
+        ref_chosen_prompt_ids = ref_tokenizer.apply_chat_template(chosen_messages[:-1], add_generation_prompt=True)
         ref_rejected_input_ids = ref_tokenizer.apply_chat_template(rejected_messages)
-        rejected_prompt_len = len(
-            ref_tokenizer.apply_chat_template(rejected_messages[:-1], add_generation_prompt=True)
-        )
+        ref_rejected_prompt_ids = ref_tokenizer.apply_chat_template(rejected_messages[:-1], add_generation_prompt=True)
+        if is_transformers_version_greater_than("5.0.0"):
+            ref_chosen_input_ids = ref_chosen_input_ids["input_ids"]
+            ref_rejected_input_ids = ref_rejected_input_ids["input_ids"]
+            ref_chosen_prompt_ids = ref_chosen_prompt_ids["input_ids"]
+            ref_rejected_prompt_ids = ref_rejected_prompt_ids["input_ids"]
+
+        chosen_prompt_len = len(ref_chosen_prompt_ids)
+        rejected_prompt_len = len(ref_rejected_prompt_ids)
+        ref_chosen_labels = [IGNORE_INDEX] * chosen_prompt_len + ref_chosen_input_ids[chosen_prompt_len:]
         ref_rejected_labels = [IGNORE_INDEX] * rejected_prompt_len + ref_rejected_input_ids[rejected_prompt_len:]
         assert train_dataset["chosen_input_ids"][index] == ref_chosen_input_ids
         assert train_dataset["chosen_labels"][index] == ref_chosen_labels

View File

@@ -20,6 +20,7 @@ from datasets import load_dataset
 from transformers import AutoTokenizer
 
 from llamafactory.extras.constants import IGNORE_INDEX
+from llamafactory.extras.packages import is_transformers_version_greater_than
 from llamafactory.train.test_utils import load_dataset_module
@@ -59,7 +60,16 @@ def test_supervised_single_turn(num_samples: int):
             {"role": "assistant", "content": original_data["output"][index]},
         ]
         ref_input_ids = ref_tokenizer.apply_chat_template(messages)
+        ref_prompt_ids = ref_tokenizer.apply_chat_template(messages[:-1], add_generation_prompt=True)
+        if is_transformers_version_greater_than("5.0.0"):
+            ref_input_ids = ref_input_ids["input_ids"]
+            ref_prompt_ids = ref_prompt_ids["input_ids"]
+
+        prompt_len = len(ref_prompt_ids)
+        ref_label_ids = [IGNORE_INDEX] * prompt_len + ref_input_ids[prompt_len:]
         assert train_dataset["input_ids"][index] == ref_input_ids
+        assert train_dataset["labels"][index] == ref_label_ids
 
 
 @pytest.mark.runs_on(["cpu", "mps"])
@@ -73,6 +83,10 @@ def test_supervised_multi_turn(num_samples: int):
     indexes = random.choices(range(len(original_data)), k=num_samples)
     for index in indexes:
         ref_input_ids = ref_tokenizer.apply_chat_template(original_data["messages"][index])
+        if is_transformers_version_greater_than("5.0.0"):
+            ref_input_ids = ref_input_ids["input_ids"]
+
+        # cannot test the label ids in multi-turn case
         assert train_dataset["input_ids"][index] == ref_input_ids
@@ -86,9 +100,12 @@ def test_supervised_train_on_prompt(num_samples: int):
     original_data = load_dataset(DEMO_DATA, name="system_chat", split="train")
     indexes = random.choices(range(len(original_data)), k=num_samples)
     for index in indexes:
-        ref_ids = ref_tokenizer.apply_chat_template(original_data["messages"][index])
-        assert train_dataset["input_ids"][index] == ref_ids
-        assert train_dataset["labels"][index] == ref_ids
+        ref_input_ids = ref_tokenizer.apply_chat_template(original_data["messages"][index])
+        if is_transformers_version_greater_than("5.0.0"):
+            ref_input_ids = ref_input_ids["input_ids"]
+
+        assert train_dataset["input_ids"][index] == ref_input_ids
+        assert train_dataset["labels"][index] == ref_input_ids
 
 
 @pytest.mark.runs_on(["cpu", "mps"])
@@ -103,7 +120,13 @@ def test_supervised_mask_history(num_samples: int):
     for index in indexes:
         messages = original_data["messages"][index]
         ref_input_ids = ref_tokenizer.apply_chat_template(messages)
-        prompt_len = len(ref_tokenizer.apply_chat_template(messages[:-1], add_generation_prompt=True))
+        ref_prompt_ids = ref_tokenizer.apply_chat_template(messages[:-1], add_generation_prompt=True)
+        if is_transformers_version_greater_than("5.0.0"):
+            ref_input_ids = ref_input_ids["input_ids"]
+            ref_prompt_ids = ref_prompt_ids["input_ids"]
+
+        prompt_len = len(ref_prompt_ids)
         ref_label_ids = [IGNORE_INDEX] * prompt_len + ref_input_ids[prompt_len:]
         assert train_dataset["input_ids"][index] == ref_input_ids
         assert train_dataset["labels"][index] == ref_label_ids

View File

@@ -19,6 +19,7 @@ import pytest
 from datasets import load_dataset
 from transformers import AutoTokenizer
 
+from llamafactory.extras.packages import is_transformers_version_greater_than
 from llamafactory.train.test_utils import load_dataset_module
@@ -55,8 +56,13 @@ def test_unsupervised_data(num_samples: int):
     indexes = random.choices(range(len(original_data)), k=num_samples)
     for index in indexes:
         messages = original_data["messages"][index]
-        ref_ids = ref_tokenizer.apply_chat_template(messages)
-        ref_input_ids = ref_tokenizer.apply_chat_template(messages[:-1], add_generation_prompt=True)
-        ref_labels = ref_ids[len(ref_input_ids) :]
-        assert train_dataset["input_ids"][index] == ref_input_ids
+        ref_input_ids = ref_tokenizer.apply_chat_template(messages)
+        ref_prompt_ids = ref_tokenizer.apply_chat_template(messages[:-1], add_generation_prompt=True)
+        if is_transformers_version_greater_than("5.0.0"):
+            ref_input_ids = ref_input_ids["input_ids"]
+            ref_prompt_ids = ref_prompt_ids["input_ids"]
+
+        ref_labels = ref_input_ids[len(ref_prompt_ids) :]
+        assert train_dataset["input_ids"][index] == ref_prompt_ids
         assert train_dataset["labels"][index] == ref_labels

View File

@@ -17,7 +17,7 @@ import os
 import pytest
 import torch
 from PIL import Image
-from transformers import AutoConfig, AutoModelForVision2Seq
+from transformers import AutoConfig, AutoModelForImageTextToText
 
 from llamafactory.data import get_template_and_fix_tokenizer
 from llamafactory.data.collator import MultiModalDataCollatorForSeq2Seq, prepare_4d_attention_mask
@@ -82,7 +82,7 @@ def test_multimodal_collator():
     template = get_template_and_fix_tokenizer(tokenizer_module["tokenizer"], data_args)
     config = AutoConfig.from_pretrained(model_args.model_name_or_path)
     with torch.device("meta"):
-        model = AutoModelForVision2Seq.from_config(config)
+        model = AutoModelForImageTextToText.from_config(config)
 
     data_collator = MultiModalDataCollatorForSeq2Seq(
         template=template,

View File

@@ -20,6 +20,7 @@ from transformers import AutoTokenizer
 
 from llamafactory.data import get_template_and_fix_tokenizer
 from llamafactory.data.template import parse_template
+from llamafactory.extras.packages import is_transformers_version_greater_than
 from llamafactory.hparams import DataArguments
@@ -65,7 +66,6 @@ def _check_template(
     template_name: str,
     prompt_str: str,
     answer_str: str,
-    use_fast: bool,
     messages: list[dict[str, str]] = MESSAGES,
 ) -> None:
     r"""Check template.
@@ -75,13 +75,15 @@ def _check_template(
         template_name: the template name.
         prompt_str: the string corresponding to the prompt part.
         answer_str: the string corresponding to the answer part.
-        use_fast: whether to use fast tokenizer.
         messages: the list of messages.
     """
-    tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=use_fast, token=HF_TOKEN)
+    tokenizer = AutoTokenizer.from_pretrained(model_id)
     content_str = tokenizer.apply_chat_template(messages, tokenize=False)
     content_ids = tokenizer.apply_chat_template(messages, tokenize=True)
+    if is_transformers_version_greater_than("5.0.0"):
+        content_ids = content_ids["input_ids"]
+
     template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template=template_name))
     prompt_ids, answer_ids = template.encode_oneturn(tokenizer, messages)
     assert content_str == prompt_str + answer_str
@@ -90,9 +92,8 @@ def _check_template(
 
 @pytest.mark.runs_on(["cpu", "mps"])
-@pytest.mark.parametrize("use_fast", [True, False])
-def test_encode_oneturn(use_fast: bool):
-    tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3, use_fast=use_fast)
+def test_encode_oneturn():
+    tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3)
     template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template="llama3"))
     prompt_ids, answer_ids = template.encode_oneturn(tokenizer, MESSAGES)
     prompt_str = (
@@ -106,9 +107,8 @@ def test_encode_oneturn(use_fast: bool):
 
 @pytest.mark.runs_on(["cpu", "mps"])
-@pytest.mark.parametrize("use_fast", [True, False])
-def test_encode_multiturn(use_fast: bool):
-    tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3, use_fast=use_fast)
+def test_encode_multiturn():
+    tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3)
     template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template="llama3"))
     encoded_pairs = template.encode_multiturn(tokenizer, MESSAGES)
     prompt_str_1 = (
@@ -128,11 +128,10 @@ def test_encode_multiturn(use_fast: bool):
 
 @pytest.mark.runs_on(["cpu", "mps"])
-@pytest.mark.parametrize("use_fast", [True, False])
 @pytest.mark.parametrize("cot_messages", [True, False])
 @pytest.mark.parametrize("enable_thinking", [True, False, None])
-def test_reasoning_encode_oneturn(use_fast: bool, cot_messages: bool, enable_thinking: bool):
-    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B", use_fast=use_fast)
+def test_reasoning_encode_oneturn(cot_messages: bool, enable_thinking: bool):
+    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B")
     data_args = DataArguments(template="qwen3", enable_thinking=enable_thinking)
     template = get_template_and_fix_tokenizer(tokenizer, data_args)
     prompt_ids, answer_ids = template.encode_oneturn(tokenizer, MESSAGES_WITH_THOUGHT if cot_messages else MESSAGES)
@@ -155,11 +154,10 @@ def test_reasoning_encode_oneturn(use_fast: bool, cot_messages: bool, enable_thi
 
 @pytest.mark.runs_on(["cpu", "mps"])
-@pytest.mark.parametrize("use_fast", [True, False])
 @pytest.mark.parametrize("cot_messages", [True, False])
 @pytest.mark.parametrize("enable_thinking", [True, False, None])
-def test_reasoning_encode_multiturn(use_fast: bool, cot_messages: bool, enable_thinking: bool):
-    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B", use_fast=use_fast)
+def test_reasoning_encode_multiturn(cot_messages: bool, enable_thinking: bool):
+    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B")
     data_args = DataArguments(template="qwen3", enable_thinking=enable_thinking)
     template = get_template_and_fix_tokenizer(tokenizer, data_args)
     encoded_pairs = template.encode_multiturn(tokenizer, MESSAGES_WITH_THOUGHT if cot_messages else MESSAGES)
@@ -185,10 +183,9 @@ def test_reasoning_encode_multiturn(use_fast: bool, cot_messages: bool, enable_t
 
 @pytest.mark.runs_on(["cpu", "mps"])
-@pytest.mark.parametrize("use_fast", [True, False])
-def test_jinja_template(use_fast: bool):
-    tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3, use_fast=use_fast)
-    ref_tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3, use_fast=use_fast)
+def test_jinja_template():
+    tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3)
+    ref_tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3)
     template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template="llama3"))
     tokenizer.chat_template = template._get_jinja_template(tokenizer)  # llama3 template no replace
     assert tokenizer.chat_template != ref_tokenizer.chat_template
@@ -222,8 +219,7 @@ def test_get_stop_token_ids():
 
 @pytest.mark.runs_on(["cpu", "mps"])
 @pytest.mark.skipif(not HF_TOKEN, reason="Gated model.")
-@pytest.mark.parametrize("use_fast", [True, False])
-def test_gemma_template(use_fast: bool):
+def test_gemma_template():
     prompt_str = (
         f"<bos><start_of_turn>user\n{MESSAGES[0]['content']}<end_of_turn>\n"
         f"<start_of_turn>model\n{MESSAGES[1]['content']}<end_of_turn>\n"
@@ -231,13 +227,12 @@ def test_gemma_template(use_fast: bool):
         "<start_of_turn>model\n"
     )
     answer_str = f"{MESSAGES[3]['content']}<end_of_turn>\n"
-    _check_template("google/gemma-3-4b-it", "gemma", prompt_str, answer_str, use_fast)
+    _check_template("google/gemma-3-4b-it", "gemma", prompt_str, answer_str)
 
 
 @pytest.mark.runs_on(["cpu", "mps"])
 @pytest.mark.skipif(not HF_TOKEN, reason="Gated model.")
-@pytest.mark.parametrize("use_fast", [True, False])
-def test_gemma2_template(use_fast: bool):
+def test_gemma2_template():
     prompt_str = (
         f"<bos><start_of_turn>user\n{MESSAGES[0]['content']}<end_of_turn>\n"
         f"<start_of_turn>model\n{MESSAGES[1]['content']}<end_of_turn>\n"
@@ -245,13 +240,12 @@ def test_gemma2_template(use_fast: bool):
         "<start_of_turn>model\n"
     )
     answer_str = f"{MESSAGES[3]['content']}<end_of_turn>\n"
-    _check_template("google/gemma-2-2b-it", "gemma2", prompt_str, answer_str, use_fast)
+    _check_template("google/gemma-2-2b-it", "gemma2", prompt_str, answer_str)
 
 
 @pytest.mark.runs_on(["cpu", "mps"])
 @pytest.mark.skipif(not HF_TOKEN, reason="Gated model.")
-@pytest.mark.parametrize("use_fast", [True, False])
-def test_llama3_template(use_fast: bool):
+def test_llama3_template():
     prompt_str = (
         f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{MESSAGES[0]['content']}<|eot_id|>"
         f"<|start_header_id|>assistant<|end_header_id|>\n\n{MESSAGES[1]['content']}<|eot_id|>"
@@ -259,14 +253,11 @@ def test_llama3_template(use_fast: bool):
         "<|start_header_id|>assistant<|end_header_id|>\n\n"
     )
     answer_str = f"{MESSAGES[3]['content']}<|eot_id|>"
-    _check_template("meta-llama/Meta-Llama-3-8B-Instruct", "llama3", prompt_str, answer_str, use_fast)
+    _check_template("meta-llama/Meta-Llama-3-8B-Instruct", "llama3", prompt_str, answer_str)
 
 
 @pytest.mark.runs_on(["cpu", "mps"])
-@pytest.mark.parametrize(
-    "use_fast", [True, pytest.param(False, marks=pytest.mark.xfail(reason="Llama 4 has no slow tokenizer."))]
-)
-def test_llama4_template(use_fast: bool):
+def test_llama4_template():
     prompt_str = (
         f"<|begin_of_text|><|header_start|>user<|header_end|>\n\n{MESSAGES[0]['content']}<|eot|>"
         f"<|header_start|>assistant<|header_end|>\n\n{MESSAGES[1]['content']}<|eot|>"
@@ -274,18 +265,11 @@ def test_llama4_template(use_fast: bool):
         "<|header_start|>assistant<|header_end|>\n\n"
     )
     answer_str = f"{MESSAGES[3]['content']}<|eot|>"
-    _check_template(TINY_LLAMA4, "llama4", prompt_str, answer_str, use_fast)
+    _check_template(TINY_LLAMA4, "llama4", prompt_str, answer_str)
 
 
-@pytest.mark.parametrize(
-    "use_fast",
-    [
-        pytest.param(True, marks=pytest.mark.xfail(not HF_TOKEN, reason="Authorization.")),
-        pytest.param(False, marks=pytest.mark.xfail(reason="Phi-4 slow tokenizer is broken.")),
-    ],
-)
 @pytest.mark.runs_on(["cpu", "mps"])
-def test_phi4_template(use_fast: bool):
+def test_phi4_template():
     prompt_str = (
         f"<|im_start|>user<|im_sep|>{MESSAGES[0]['content']}<|im_end|>"
         f"<|im_start|>assistant<|im_sep|>{MESSAGES[1]['content']}<|im_end|>"
@@ -293,13 +277,12 @@ def test_phi4_template(use_fast: bool):
         "<|im_start|>assistant<|im_sep|>"
     )
     answer_str = f"{MESSAGES[3]['content']}<|im_end|>"
-    _check_template("microsoft/phi-4", "phi4", prompt_str, answer_str, use_fast)
+    _check_template("microsoft/phi-4", "phi4", prompt_str, answer_str)
 
 
 @pytest.mark.runs_on(["cpu", "mps"])
 @pytest.mark.xfail(not HF_TOKEN, reason="Authorization.")
-@pytest.mark.parametrize("use_fast", [True, False])
-def test_qwen2_5_template(use_fast: bool):
+def test_qwen2_5_template():
     prompt_str = (
         "<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n"
         f"<|im_start|>user\n{MESSAGES[0]['content']}<|im_end|>\n"
@@ -308,13 +291,12 @@ def test_qwen2_5_template(use_fast: bool):
         "<|im_start|>assistant\n"
     )
     answer_str = f"{MESSAGES[3]['content']}<|im_end|>\n"
-    _check_template("Qwen/Qwen2.5-7B-Instruct", "qwen", prompt_str, answer_str, use_fast)
+    _check_template("Qwen/Qwen2.5-7B-Instruct", "qwen", prompt_str, answer_str)
 
 
 @pytest.mark.runs_on(["cpu", "mps"])
-@pytest.mark.parametrize("use_fast", [True, False])
 @pytest.mark.parametrize("cot_messages", [True, False])
-def test_qwen3_template(use_fast: bool, cot_messages: bool):
+def test_qwen3_template(cot_messages: bool):
     prompt_str = (
         f"<|im_start|>user\n{MESSAGES[0]['content']}<|im_end|>\n"
         f"<|im_start|>assistant\n{MESSAGES[1]['content']}<|im_end|>\n"
@@ -328,12 +310,12 @@ def test_qwen3_template(use_fast: bool, cot_messages: bool):
         answer_str = f"{MESSAGES_WITH_THOUGHT[3]['content']}<|im_end|>\n"
         messages = MESSAGES_WITH_THOUGHT
 
-    _check_template("Qwen/Qwen3-8B", "qwen3", prompt_str, answer_str, use_fast, messages=messages)
+    _check_template("Qwen/Qwen3-8B", "qwen3", prompt_str, answer_str, messages=messages)
 
 
 @pytest.mark.runs_on(["cpu", "mps"])
 def test_parse_llama3_template():
-    tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3, token=HF_TOKEN)
+    tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3)
     template = parse_template(tokenizer)
     assert template.format_user.slots == [
         "<|start_header_id|>user<|end_header_id|>\n\n{{content}}<|eot_id|>"
@@ -348,7 +330,7 @@ def test_parse_llama3_template():
 @pytest.mark.runs_on(["cpu", "mps"])
 @pytest.mark.xfail(not HF_TOKEN, reason="Authorization.")
 def test_parse_qwen_template():
-    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct", token=HF_TOKEN)
+    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")
     template = parse_template(tokenizer)
     assert template.__class__.__name__ == "Template"
     assert template.format_user.slots == ["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]
@@ -361,7 +343,7 @@ def test_parse_qwen_template():
 @pytest.mark.runs_on(["cpu", "mps"])
 @pytest.mark.xfail(not HF_TOKEN, reason="Authorization.")
 def test_parse_qwen3_template():
-    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B", token=HF_TOKEN)
+    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B")
     template = parse_template(tokenizer)
     assert template.__class__.__name__ == "ReasoningTemplate"
     assert template.format_user.slots == ["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]

View File

@@ -16,7 +16,8 @@ import os
 
 import pytest
 import torch
-from transformers import AutoConfig, AutoModelForVision2Seq
+from safetensors.torch import load_file
+from transformers import AutoConfig, AutoModelForImageTextToText
 
 from llamafactory.extras.packages import is_transformers_version_greater_than
 from llamafactory.hparams import FinetuningArguments, ModelArguments
@@ -36,7 +37,7 @@ def test_visual_full(freeze_vision_tower: bool, freeze_multi_modal_projector: bo
     )
     config = AutoConfig.from_pretrained(model_args.model_name_or_path)
     with torch.device("meta"):
-        model = AutoModelForVision2Seq.from_config(config)
+        model = AutoModelForImageTextToText.from_config(config)
 
     model = init_adapter(config, model, model_args, finetuning_args, is_trainable=True)
     for name, param in model.named_parameters():
@@ -56,7 +57,7 @@ def test_visual_lora(freeze_vision_tower: bool, freeze_language_model: bool):
     )
     config = AutoConfig.from_pretrained(model_args.model_name_or_path)
     with torch.device("meta"):
-        model = AutoModelForVision2Seq.from_config(config)
+        model = AutoModelForImageTextToText.from_config(config)
 
     model = init_adapter(config, model, model_args, finetuning_args, is_trainable=True)
     trainable_params, frozen_params = set(), set()
@@ -86,13 +87,14 @@ def test_visual_model_save_load():
     finetuning_args = FinetuningArguments(finetuning_type="full")
     config = AutoConfig.from_pretrained(model_args.model_name_or_path)
     with torch.device("meta"):
-        model = AutoModelForVision2Seq.from_config(config)
+        model = AutoModelForImageTextToText.from_config(config)
 
     model = init_adapter(config, model, model_args, finetuning_args, is_trainable=False)
+    model.to_empty(device="cpu")
     loaded_model_weight = dict(model.named_parameters())
-    model.save_pretrained(os.path.join("output", "qwen2_vl"), max_shard_size="10GB", safe_serialization=False)
-    saved_model_weight = torch.load(os.path.join("output", "qwen2_vl", "pytorch_model.bin"), weights_only=False)
+    model.save_pretrained(os.path.join("output", "qwen2_vl"), max_shard_size="10GB", safe_serialization=True)
+    saved_model_weight = load_file(os.path.join("output", "qwen2_vl", "model.safetensors"))
     if is_transformers_version_greater_than("4.52.0"):
         assert "model.language_model.layers.0.self_attn.q_proj.weight" in loaded_model_weight

View File

@@ -1,2 +1,2 @@
 # change if test fails or cache is outdated
-0.9.5.105
+0.9.5.106

View File

@@ -23,6 +23,13 @@ from llamafactory.v1.core.utils.rendering import Renderer
 from llamafactory.v1.utils.types import Processor
 
 
+def _get_input_ids(inputs: list | dict) -> list:
+    if not isinstance(inputs, list):
+        return inputs["input_ids"]
+    else:
+        return inputs
+
+
 HF_MESSAGES = [
     {"role": "system", "content": "You are a helpful assistant."},
     {"role": "user", "content": "What is LLM?"},
@@ -81,15 +88,15 @@ def test_chatml_rendering():
     tokenizer: Processor = AutoTokenizer.from_pretrained("llamafactory/tiny-random-qwen3")
     renderer = Renderer(template="chatml", processor=tokenizer)
-    hf_inputs = tokenizer.apply_chat_template(HF_MESSAGES[:-1], add_generation_prompt=True)
+    hf_inputs = _get_input_ids(tokenizer.apply_chat_template(HF_MESSAGES[:-1], add_generation_prompt=True))
     v1_inputs = renderer.render_messages(V1_MESSAGES[:-1], is_generate=True)
     assert v1_inputs["input_ids"] == hf_inputs
     assert v1_inputs["attention_mask"] == [1] * len(hf_inputs)
     assert v1_inputs["labels"] == [-100] * len(hf_inputs)
     assert v1_inputs["loss_weights"] == [0.0] * len(hf_inputs)
 
-    hf_inputs_part = tokenizer.apply_chat_template(HF_MESSAGES[:-1], add_generation_prompt=False)
-    hf_inputs_full = tokenizer.apply_chat_template(HF_MESSAGES, add_generation_prompt=False)
+    hf_inputs_part = _get_input_ids(tokenizer.apply_chat_template(HF_MESSAGES[:-1], add_generation_prompt=False))
+    hf_inputs_full = _get_input_ids(tokenizer.apply_chat_template(HF_MESSAGES, add_generation_prompt=False))
     v1_inputs_full = renderer.render_messages(V1_MESSAGES, is_generate=False)
     assert v1_inputs_full["input_ids"] == hf_inputs_full
     assert v1_inputs_full["attention_mask"] == [1] * len(hf_inputs_full)
@@ -124,17 +131,21 @@ def test_qwen3_nothink_rendering():
     tokenizer: Processor = AutoTokenizer.from_pretrained("Qwen/Qwen3-4B-Instruct-2507")
     renderer = Renderer(template="qwen3_nothink", processor=tokenizer)
-    hf_inputs = tokenizer.apply_chat_template(HF_MESSAGES_WITH_TOOLS[:-1], tools=V1_TOOLS, add_generation_prompt=True)
+    hf_inputs = _get_input_ids(
+        tokenizer.apply_chat_template(HF_MESSAGES_WITH_TOOLS[:-1], tools=V1_TOOLS, add_generation_prompt=True)
+    )
     v1_inputs = renderer.render_messages(V1_MESSAGES_WITH_TOOLS[:-1], tools=json.dumps(V1_TOOLS), is_generate=True)
     assert v1_inputs["input_ids"] == hf_inputs
     assert v1_inputs["attention_mask"] == [1] * len(hf_inputs)
     assert v1_inputs["labels"] == [-100] * len(hf_inputs)
     assert v1_inputs["loss_weights"] == [0.0] * len(hf_inputs)
 
-    hf_inputs_part = tokenizer.apply_chat_template(
-        HF_MESSAGES_WITH_TOOLS[:-1], tools=V1_TOOLS, add_generation_prompt=False
-    )
-    hf_inputs_full = tokenizer.apply_chat_template(HF_MESSAGES_WITH_TOOLS, tools=V1_TOOLS, add_generation_prompt=False)
+    hf_inputs_part = _get_input_ids(
+        tokenizer.apply_chat_template(HF_MESSAGES_WITH_TOOLS[:-1], tools=V1_TOOLS, add_generation_prompt=False)
+    )
+    hf_inputs_full = _get_input_ids(
+        tokenizer.apply_chat_template(HF_MESSAGES_WITH_TOOLS, tools=V1_TOOLS, add_generation_prompt=False)
+    )
     v1_inputs_full = renderer.render_messages(V1_MESSAGES_WITH_TOOLS, tools=json.dumps(V1_TOOLS), is_generate=False)
     assert v1_inputs_full["input_ids"] == hf_inputs_full
     assert v1_inputs_full["attention_mask"] == [1] * len(hf_inputs_full)
@@ -187,7 +198,7 @@ def test_qwen3_nothink_rendering_remote(num_samples: int):
 def test_process_sft_samples():
     tokenizer: Processor = AutoTokenizer.from_pretrained("llamafactory/tiny-random-qwen3")
     renderer = Renderer(template="chatml", processor=tokenizer)
-    hf_inputs = tokenizer.apply_chat_template(HF_MESSAGES)
+    hf_inputs = _get_input_ids(tokenizer.apply_chat_template(HF_MESSAGES))
     samples = [{"messages": V1_MESSAGES, "extra_info": "test", "_dataset_name": "default"}]
     model_inputs = renderer.process_samples(samples)
@@ -200,7 +211,7 @@ def test_process_sft_samples():
 def test_process_dpo_samples():
     tokenizer: Processor = AutoTokenizer.from_pretrained("llamafactory/tiny-random-qwen3")
     renderer = Renderer(template="chatml", processor=tokenizer)
-    hf_inputs = tokenizer.apply_chat_template(HF_MESSAGES)
+    hf_inputs = _get_input_ids(tokenizer.apply_chat_template(HF_MESSAGES))
     samples = [
         {