4 Commits

Author / SHA1 / Message / Date

ゆり / e7cb145f5d / [logging] Fix race condition in LoggerHandler during multi-GPU training (#10156) / 2026-02-03 11:14:07 +08:00
    Co-authored-by: yurekami <yurekami@users.noreply.github.com>

Hertz / b53d7037c2 / [model] support youtu-vl model (#10152) / 2026-02-02 21:42:43 +08:00

浮梦 / bf04ca6af8 / [deps] adapt to transformers v5 (#10147) / 2026-02-02 12:07:19 +08:00
    Co-authored-by: frozenleaves <frozen@Mac.local>
    Co-authored-by: hiyouga <hiyouga@buaa.edu.cn>

xvxuopop / 762b480131 / [feature] support using ray.remote to start distributed training. (#10109) / 2026-01-28 16:05:29 +08:00
32 changed files with 469 additions and 203 deletions

View File

@@ -18,7 +18,7 @@ init_config:
   name: init_on_meta
 ### data
 train_dataset: data/v1_sft_demo.yaml
 ### training
 output_dir: outputs/test_fsdp2

View File

@@ -40,10 +40,10 @@ dependencies = [
     "torch>=2.4.0",
     "torchvision>=0.19.0",
     "torchaudio>=2.4.0",
-    "transformers>=4.51.0,<=4.57.1,!=4.52.0,!=4.57.0",
+    "transformers>=4.51.0,<=5.0.0,!=4.52.0,!=4.57.0",
     "datasets>=2.16.0,<=4.0.0",
     "accelerate>=1.3.0,<=1.11.0",
-    "peft>=0.14.0,<=0.17.1",
+    "peft>=0.18.0,<=0.18.1",
     "trl>=0.18.0,<=0.24.0",
     "torchdata>=0.10.0,<=0.11.0",
     # gui

View File

@@ -1 +1 @@
-deepspeed>=0.10.0,<=0.16.9
+deepspeed>=0.10.0,<=0.18.4
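
The same ceilings move in three places at once: pyproject.toml above, this requirements pin, and check_dependencies further down. A hedged sketch of verifying a local environment against the new ranges, reusing the project's own check_version helper (the llamafactory.extras.misc module path is assumed from the relative imports elsewhere in this diff):

from llamafactory.extras.misc import check_version

# Check that the installed versions fall inside the newly pinned ranges.
check_version("transformers>=4.51.0,<=5.0.0")
check_version("peft>=0.18.0,<=0.18.1")
check_version("deepspeed>=0.10.0,<=0.18.4")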

View File

@@ -2159,6 +2159,40 @@ class LFMVLPlugin(BasePlugin):
         return messages


+@dataclass
+class YoutuVLPlugin(BasePlugin):
+    r"""Plugin for Youtu-VL vision-language models."""
+
+    vision_bos_token: str = "<|vision_start|>"
+    vision_eos_token: str = "<|vision_end|>"
+
+    @override
+    def process_messages(
+        self,
+        messages: list[dict[str, str]],
+        images: list["ImageInput"],
+        videos: list["VideoInput"],
+        audios: list["AudioInput"],
+        processor: Optional["MMProcessor"],
+    ) -> list[dict[str, str]]:
+        self._validate_input(processor, images, videos, audios)
+        self._validate_messages(messages, images, videos, audios)
+        messages = deepcopy(messages)
+        for message in messages:
+            content = message["content"]
+            content = content.replace(
+                IMAGE_PLACEHOLDER, f"{self.vision_bos_token}{self.image_token}{self.vision_eos_token}"
+            )
+            content = content.replace(
+                VIDEO_PLACEHOLDER, f"{self.vision_bos_token}{self.video_token}{self.vision_eos_token}"
+            )
+            message["content"] = content
+
+        return messages
+
+
 PLUGINS = {
     "base": BasePlugin,
     "ernie_vl": ErnieVLPlugin,
@@ -2181,6 +2215,7 @@ PLUGINS = {
     "qwen2_vl": Qwen2VLPlugin,
     "qwen3_vl": Qwen3VLPlugin,
     "video_llava": VideoLlavaPlugin,
+    "youtu_vl": YoutuVLPlugin,
 }
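
The effect of process_messages is a plain string substitution: each image or video placeholder in a message is wrapped between the vision delimiter tokens. A self-contained sketch of that substitution, where the "<image>" placeholder value and the "<|image_pad|>" token are assumptions taken from the template registration below rather than from this hunk:

# Hedged illustration only; the real plugin operates on the parsed message dicts.
IMAGE_PLACEHOLDER = "<image>"  # assumed value of the constant referenced above

content = f"{IMAGE_PLACEHOLDER}Describe this picture."
wrapped = content.replace(IMAGE_PLACEHOLDER, "<|vision_start|>" + "<|image_pad|>" + "<|vision_end|>")
print(wrapped)  # -> <|vision_start|><|image_pad|><|vision_end|>Describe this picture.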

View File

@@ -2146,6 +2146,19 @@
 )


+register_template(
+    name="youtu_vl",
+    format_user=StringFormatter(
+        slots=["<|begin_of_text|>user\n{{content}}<|end_of_text|>\n<|begin_of_text|>assistant\n"]
+    ),
+    format_assistant=StringFormatter(slots=["{{content}}<|end_of_text|>\n"]),
+    format_system=StringFormatter(slots=["<|begin_of_text|>system\n{{content}}<|end_of_text|>\n"]),
+    default_system="You are a helpful assistant.",
+    stop_words=["<|end_of_text|>"],
+    mm_plugin=get_mm_plugin(name="youtu_vl", image_token="<|image_pad|>", video_token="<|video_pad|>"),
+)
+
+
 register_template(
     name="yuan",
     format_user=StringFormatter(slots=["{{content}}", {"token": "<sep>"}]),
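
Reading the slots together, a single system + user turn under this template should render roughly as follows (a sketch derived from the formatters above; the user and assistant text is illustrative):

# Approximate rendering of one youtu_vl turn, pieced together from the slots above.
prompt = (
    "<|begin_of_text|>system\nYou are a helpful assistant.<|end_of_text|>\n"
    "<|begin_of_text|>user\nWhat is in the image?<|end_of_text|>\n"
    "<|begin_of_text|>assistant\n"
)
answer = "It shows a cat.<|end_of_text|>\n"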

View File

@@ -3375,6 +3375,18 @@
 )


+register_model_group(
+    models={
+        "Youtu-VL-4B-Instruct": {
+            DownloadSource.DEFAULT: "tencent/Youtu-VL-4B-Instruct",
+            DownloadSource.MODELSCOPE: "Tencent-YouTu-Research/Youtu-VL-4B-Instruct",
+        },
+    },
+    template="youtu_vl",
+    multimodal=True,
+)
+
+
 register_model_group(
     models={
         "Yuan2-2B-Chat": {

View File

@@ -41,12 +41,13 @@ class LoggerHandler(logging.Handler):
             datefmt="%Y-%m-%d %H:%M:%S",
         )
         self.setLevel(logging.INFO)
+        self.thread_pool = ThreadPoolExecutor(max_workers=1)
         os.makedirs(output_dir, exist_ok=True)
         self.running_log = os.path.join(output_dir, RUNNING_LOG)
-        if os.path.exists(self.running_log):
-            os.remove(self.running_log)
-
-        self.thread_pool = ThreadPoolExecutor(max_workers=1)
+        try:
+            os.remove(self.running_log)
+        except OSError:
+            pass

     def _write_log(self, log_entry: str) -> None:
         with open(self.running_log, "a", encoding="utf-8") as f:
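
The change closes a time-of-check/time-of-use window: with multiple ranks sharing output_dir, another process can delete the running log between the os.path.exists check and os.remove, so the handler now simply attempts the removal and tolerates failure (and creates its thread pool before any record can be emitted). A minimal standalone sketch of the same idiom:

import os


def remove_if_exists(path: str) -> None:
    # EAFP: attempt the removal and tolerate the file already being gone,
    # instead of checking existence first and racing with other processes.
    try:
        os.remove(path)
    except OSError:
        pass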

View File

@@ -94,10 +94,10 @@ def check_version(requirement: str, mandatory: bool = False) -> None:

 def check_dependencies() -> None:
     r"""Check the version of the required packages."""
-    check_version("transformers>=4.51.0,<=4.57.1")
+    check_version("transformers>=4.51.0,<=5.0.0")
     check_version("datasets>=2.16.0,<=4.0.0")
     check_version("accelerate>=1.3.0,<=1.11.0")
-    check_version("peft>=0.14.0,<=0.17.1")
+    check_version("peft>=0.18.0,<=0.18.1")
     check_version("trl>=0.18.0,<=0.24.0")
@@ -157,6 +157,33 @@
     return torch.device(device)


+def get_device_name() -> str:
+    r"""Get the name of available devices."""
+    if is_torch_xpu_available():
+        device = "xpu"
+    elif is_torch_npu_available():
+        device = "npu"
+    elif is_torch_mps_available():
+        device = "mps"
+    elif is_torch_cuda_available():
+        device = "gpu"
+    else:
+        device = "cpu"
+
+    return device
+
+
+def get_torch_device():
+    r"""Get the torch device namespace for the available devices."""
+    device_name = get_device_name()
+    device_name = "cuda" if device_name == "gpu" else device_name
+    try:
+        return getattr(torch, device_name)
+    except AttributeError:
+        logger.warning_rank0(f"Device namespace '{device_name}' not found in torch, try to load torch.cuda.")
+        return torch.cuda
+
+
 def get_device_count() -> int:
     r"""Get the number of available devices."""
     if is_torch_xpu_available():
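
A hedged usage sketch of the two new helpers (the import path follows the ..extras.misc imports used elsewhere in this diff; the set_device call mirrors how the Ray worker below uses it):

from llamafactory.extras.misc import get_device_name, get_torch_device

device_name = get_device_name()  # "gpu", "npu", "xpu", "mps", or "cpu"

# get_torch_device() maps "gpu" to torch.cuda and falls back to torch.cuda on unknown names.
if device_name in ("gpu", "npu", "xpu"):
    get_torch_device().set_device(0)  # pin the current process to device 0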

View File

@@ -65,7 +65,9 @@ class DataArguments:
     )
     mix_strategy: Literal["concat", "interleave_under", "interleave_over", "interleave_once"] = field(
         default="concat",
-        metadata={"help": "Strategy to use in dataset mixing (concat/interleave) (undersampling/oversampling/sampling w.o. replacement)."},
+        metadata={
+            "help": "Strategy to use in dataset mixing (concat/interleave) (undersampling/oversampling/sampling w.o. replacement)."
+        },
     )
     interleave_probs: str | None = field(
         default=None,

View File

@@ -206,9 +206,6 @@ class BaseModelArguments:
         if self.model_name_or_path is None:
             raise ValueError("Please provide `model_name_or_path`.")

-        if self.split_special_tokens and self.use_fast_tokenizer:
-            raise ValueError("`split_special_tokens` is only supported for slow tokenizers.")
-
         if self.adapter_name_or_path is not None:  # support merging multiple lora weights
             self.adapter_name_or_path = [path.strip() for path in self.adapter_name_or_path.split(",")]

View File

@@ -139,10 +139,6 @@ def _verify_model_args(
         if model_args.adapter_name_or_path is not None and len(model_args.adapter_name_or_path) != 1:
             raise ValueError("Quantized model only accepts a single adapter. Merge them first.")

-    if data_args.template == "yi" and model_args.use_fast_tokenizer:
-        logger.warning_rank0("We should use slow tokenizer for the Yi models. Change `use_fast_tokenizer` to False.")
-        model_args.use_fast_tokenizer = False
-

 def _check_extra_dependencies(
     model_args: "ModelArguments",
@@ -188,9 +184,7 @@ def _check_extra_dependencies(
     if training_args is not None:
         if training_args.deepspeed:
-            # pin deepspeed version < 0.17 because of https://github.com/deepspeedai/DeepSpeed/issues/7347
             check_version("deepspeed", mandatory=True)
-            check_version("deepspeed>=0.10.0,<=0.16.9")

         if training_args.predict_with_generate:
             check_version("jieba", mandatory=True)

View File

@@ -14,7 +14,6 @@

 import json
 from dataclasses import dataclass, field
-from typing import Literal

 from transformers import Seq2SeqTrainingArguments
 from transformers.training_args import _convert_str_dict
@@ -40,56 +39,29 @@ else:
 class RayArguments:
     r"""Arguments pertaining to the Ray training."""

-    ray_run_name: str | None = field(
-        default=None,
-        metadata={"help": "The training results will be saved at `<ray_storage_path>/ray_run_name`."},
-    )
-    ray_storage_path: str = field(
-        default="./saves",
-        metadata={"help": "The storage path to save training results to"},
-    )
-    ray_storage_filesystem: Literal["s3", "gs", "gcs"] | None = field(
-        default=None,
-        metadata={"help": "The storage filesystem to use. If None specified, local filesystem will be used."},
-    )
     ray_num_workers: int = field(
         default=1,
         metadata={"help": "The number of workers for Ray training. Default is 1 worker."},
     )
-    resources_per_worker: dict | str = field(
-        default_factory=lambda: {"GPU": 1},
-        metadata={"help": "The resources per worker for Ray training. Default is to use 1 GPU per worker."},
-    )
-    placement_strategy: Literal["SPREAD", "PACK", "STRICT_SPREAD", "STRICT_PACK"] = field(
-        default="PACK",
-        metadata={"help": "The placement strategy for Ray training. Default is PACK."},
-    )
     ray_init_kwargs: dict | str | None = field(
         default=None,
         metadata={"help": "The arguments to pass to ray.init for Ray training. Default is None."},
     )
+    master_addr: str | None = field(
+        default=None,
+        metadata={"help": "The master address for init_process_group"},
+    )
+    master_port: str | None = field(
+        default=None,
+        metadata={"help": "The master port for init_process_group"},
+    )

     def __post_init__(self):
         self.use_ray = use_ray()
-        if isinstance(self.resources_per_worker, str) and self.resources_per_worker.startswith("{"):
-            self.resources_per_worker = _convert_str_dict(json.loads(self.resources_per_worker))
-
         if isinstance(self.ray_init_kwargs, str) and self.ray_init_kwargs.startswith("{"):
             self.ray_init_kwargs = _convert_str_dict(json.loads(self.ray_init_kwargs))

-        if self.ray_storage_filesystem is not None:
-            if self.ray_storage_filesystem not in ["s3", "gs", "gcs"]:
-                raise ValueError(
-                    f"ray_storage_filesystem must be one of ['s3', 'gs', 'gcs'], got {self.ray_storage_filesystem}."
-                )
-
-            import pyarrow.fs as fs
-
-            if self.ray_storage_filesystem == "s3":
-                self.ray_storage_filesystem = fs.S3FileSystem()
-            elif self.ray_storage_filesystem == "gs" or self.ray_storage_filesystem == "gcs":
-                self.ray_storage_filesystem = fs.GcsFileSystem()
-

 @dataclass
 class Fp8Arguments:
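
After this change the Ray-specific surface is much smaller. A hedged sketch of what is left, constructing the dataclass directly (field names come from the definition above; the values are illustrative, and in normal use they would arrive through the argument parser instead):

from llamafactory.hparams import RayArguments

ray_args = RayArguments(
    ray_num_workers=4,                    # one worker per accelerator
    ray_init_kwargs={"address": "auto"},  # forwarded to ray.init(); optional
    master_addr=None,                     # defaults to the Ray head node IP at runtime
    master_port=None,                     # defaults to a free port at runtime
)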

View File

@@ -22,7 +22,6 @@ from transformers import (
     AutoModelForImageTextToText,
     AutoModelForSeq2SeqLM,
     AutoModelForTextToWaveform,
-    AutoModelForVision2Seq,
     AutoProcessor,
     AutoTokenizer,
 )
@@ -166,11 +165,9 @@ def load_model(
         else:
             if type(config) in AutoModelForImageTextToText._model_mapping.keys():  # image-text
                 load_class = AutoModelForImageTextToText
-            elif type(config) in AutoModelForVision2Seq._model_mapping.keys():  # image-text
-                load_class = AutoModelForVision2Seq
             elif type(config) in AutoModelForSeq2SeqLM._model_mapping.keys():  # audio-text
                 load_class = AutoModelForSeq2SeqLM
-            elif type(config) in AutoModelForTextToWaveform._model_mapping.keys():  # audio hack for qwen omni
+            elif type(config) in AutoModelForTextToWaveform._model_mapping.keys():  # audio-text for qwen omni
                 load_class = AutoModelForTextToWaveform
             else:
                 load_class = AutoModelForCausalLM

View File

@@ -57,6 +57,11 @@ def configure_attn_implementation(config: "PretrainedConfig", model_args: "Model
             "Gemma-2 should use soft-capping attention, while the SDPA attention does not support it."
         )

+    if getattr(config, "model_type", None) in ["youtu", "youtu_vl"]:
+        if model_args.flash_attn in (AttentionFunction.AUTO, AttentionFunction.SDPA):
+            logger.warning_rank0("Youtu-VL does not support SDPA, forcing eager attention.")
+            model_args.flash_attn = AttentionFunction.DISABLED
+
     if model_args.flash_attn == AttentionFunction.AUTO:
         return
@@ -85,6 +90,13 @@
     elif getattr(config, "model_type", None) == "kimi_vl":
         setattr(config.vision_config, "_attn_implementation", requested_attn_implementation)
         setattr(config.text_config, "_attn_implementation", requested_attn_implementation)
+    elif getattr(config, "model_type", None) == "youtu_vl":
+        setattr(config, "attn_implementation", requested_attn_implementation)
+        setattr(config, "_attn_implementation", requested_attn_implementation)
+        if hasattr(config, "vision_config"):
+            setattr(config.vision_config, "_attn_implementation", requested_attn_implementation)
+        if hasattr(config, "text_config"):
+            setattr(config.text_config, "_attn_implementation", requested_attn_implementation)
     else:
         setattr(config, "_attn_implementation", requested_attn_implementation)

View File

@@ -374,7 +374,13 @@ _register_composite_model(
 _register_composite_model(
     model_type="qwen3_omni_moe_thinker",
     projector_key="visual.merger",
-    vision_model_keys=["visual.pos_embed", "visual.patch_embed", "visual.blocks", "visual.deepstack_merger_list", "audio_tower"],
+    vision_model_keys=[
+        "visual.pos_embed",
+        "visual.patch_embed",
+        "visual.blocks",
+        "visual.deepstack_merger_list",
+        "audio_tower",
+    ],
     language_model_keys=["model", "lm_head"],
     lora_conflict_keys=["patch_embed"],
 )

View File

@@ -61,6 +61,26 @@ def patch_qwen3_omni_moe_thinker_text_sparse_moe_block():
     modeling_qwen3_omni_moe.Qwen3OmniMoeThinkerTextSparseMoeBlock = Qwen3OmniMoeThinkerTextSparseMoeBlock


+def patch_youtu_vl_model(model: "PreTrainedModel") -> None:
+    original_forward = model.forward
+
+    def forward(self, *args, **kwargs):
+        outputs = original_forward(*args, **kwargs)
+        if "loss" not in outputs and "labels" in kwargs:
+            logits = outputs.get("logits")
+            labels = kwargs.get("labels")
+            if logits is not None and labels is not None:
+                shift_logits = logits[..., :-1, :].contiguous()
+                shift_labels = labels[..., 1:].contiguous()
+                loss_fct = torch.nn.CrossEntropyLoss()
+                loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))
+                outputs["loss"] = loss
+
+        return outputs
+
+    model.forward = MethodType(forward, model)
+
+
 def patch_tokenizer(tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments") -> None:
     if "PreTrainedTokenizerBase" not in str(tokenizer._pad.__func__):
         tokenizer._pad = MethodType(PreTrainedTokenizerBase._pad, tokenizer)
@@ -207,6 +227,9 @@ def patch_model(
     if getattr(model.config, "model_type", None) == "gemma3n":
         setattr(model_args, "disable_gradient_checkpointing", True)

+    if getattr(model.config, "model_type", None) == "youtu_vl":
+        patch_youtu_vl_model(model)
+
     prepare_model_for_training(model, model_args)
     autocast_projector_dtype(model, model_args)
     add_z3_leaf_module(model)
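
The patched forward only steps in when the model returns no loss of its own: it applies the standard causal-LM shift, scoring the logits at position t against the label at position t + 1. A small self-contained sketch of that shift (shapes and values are illustrative):

import torch

# batch=1, seq_len=4, vocab=8; values are arbitrary.
logits = torch.randn(1, 4, 8)
labels = torch.tensor([[2, 5, 1, 7]])

shift_logits = logits[..., :-1, :]  # predictions for positions 0..2
shift_labels = labels[..., 1:]      # targets are the tokens at positions 1..3
loss = torch.nn.CrossEntropyLoss()(shift_logits.reshape(-1, 8), shift_labels.reshape(-1))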

View File

@@ -103,7 +103,9 @@ class FixValueHeadModelCallback(TrainerCallback):
         if args.should_save:
             output_dir = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}")
             fix_valuehead_checkpoint(
-                model=kwargs.pop("model"), output_dir=output_dir, safe_serialization=args.save_safetensors
+                model=kwargs.pop("model"),
+                output_dir=output_dir,
+                safe_serialization=getattr(args, "save_safetensors", True),
             )
@@ -137,7 +139,7 @@ class PissaConvertCallback(TrainerCallback):
             if isinstance(model, PeftModel):
                 init_lora_weights = getattr(model.peft_config["default"], "init_lora_weights")
                 setattr(model.peft_config["default"], "init_lora_weights", True)
-                model.save_pretrained(pissa_init_dir, safe_serialization=args.save_safetensors)
+                model.save_pretrained(pissa_init_dir, safe_serialization=getattr(args, "save_safetensors", True))
                 setattr(model.peft_config["default"], "init_lora_weights", init_lora_weights)

     @override
@@ -155,11 +157,11 @@
             if isinstance(model, PeftModel):
                 init_lora_weights = getattr(model.peft_config["default"], "init_lora_weights")
                 setattr(model.peft_config["default"], "init_lora_weights", True)
-                model.save_pretrained(pissa_backup_dir, safe_serialization=args.save_safetensors)
+                model.save_pretrained(pissa_backup_dir, safe_serialization=getattr(args, "save_safetensors", True))
                 setattr(model.peft_config["default"], "init_lora_weights", init_lora_weights)
                 model.save_pretrained(
                     pissa_convert_dir,
-                    safe_serialization=args.save_safetensors,
+                    safe_serialization=getattr(args, "save_safetensors", True),
                     path_initial_model_for_weight_conversion=pissa_init_dir,
                 )
                 model.load_adapter(pissa_backup_dir, "default", is_trainable=True)

View File

@@ -72,7 +72,7 @@ def run_ppo(
     ppo_trainer.ppo_train(resume_from_checkpoint=training_args.resume_from_checkpoint)
     ppo_trainer.save_model()
     if training_args.should_save:
-        fix_valuehead_checkpoint(model, training_args.output_dir, training_args.save_safetensors)
+        fix_valuehead_checkpoint(model, training_args.output_dir, getattr(training_args, "save_safetensors", True))

     ppo_trainer.save_state()  # must be called after save_model to have a folder
     if ppo_trainer.is_world_process_zero() and finetuning_args.plot_loss:

View File

@@ -114,7 +114,7 @@ class PairwiseTrainer(Trainer):
         if state_dict is None:
             state_dict = self.model.state_dict()

-        if self.args.save_safetensors:
+        if getattr(self.args, "save_safetensors", True):
             from collections import defaultdict

             ptrs = defaultdict(list)

View File

@@ -65,7 +65,7 @@ def run_rm(
     train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
     trainer.save_model()
     if training_args.should_save:
-        fix_valuehead_checkpoint(model, training_args.output_dir, training_args.save_safetensors)
+        fix_valuehead_checkpoint(model, training_args.output_dir, getattr(training_args, "save_safetensors", True))

     trainer.log_metrics("train", train_result.metrics)
     trainer.save_metrics("train", train_result.metrics)

View File

@@ -20,7 +20,6 @@
 import json
 import os
 from collections.abc import Callable, Mapping
-from pathlib import Path
 from typing import TYPE_CHECKING, Any, Optional, Union

 import torch
@@ -34,6 +33,7 @@ from typing_extensions import override
 from ..extras import logging
 from ..extras.constants import IGNORE_INDEX, SWANLAB_CONFIG
+from ..extras.misc import get_device_name
 from ..extras.packages import is_apollo_available, is_galore_available, is_ray_available
 from ..hparams import FinetuningArguments, ModelArguments
 from ..model import find_all_linear_modules, load_model, load_tokenizer, load_valuehead_params
@@ -49,15 +49,15 @@ if is_apollo_available():

 if is_ray_available():
     import ray
-    from ray.train import RunConfig, ScalingConfig
-    from ray.train.torch import TorchTrainer
+    from ray.util.placement_group import PlacementGroup, placement_group
+    from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy


 if TYPE_CHECKING:
     from transformers import PreTrainedModel, TrainerCallback, TrainerState
     from trl import AutoModelForCausalLMWithValueHead

-    from ..hparams import DataArguments, RayArguments, TrainingArguments
+    from ..hparams import DataArguments, TrainingArguments


 logger = logging.get_logger(__name__)
@@ -807,36 +807,88 @@ def get_swanlab_callback(finetuning_args: "FinetuningArguments") -> "TrainerCall
     return swanlab_callback


-def get_ray_trainer(
-    training_function: Callable,
-    train_loop_config: dict[str, Any],
-    ray_args: "RayArguments",
-) -> "TorchTrainer":
-    if not ray_args.use_ray:
-        raise ValueError("Ray was not enabled. Please set `USE_RAY=1` to enable ray.")
-
-    if ray_args.ray_init_kwargs is not None:
-        ray.init(**ray_args.ray_init_kwargs)
-
-    if ray_args.ray_storage_filesystem is not None:
-        # this means we are using s3/gcs
-        storage_path = ray_args.ray_storage_path
-    else:
-        storage_path = Path(ray_args.ray_storage_path).absolute().as_posix()
-
-    trainer = TorchTrainer(
-        training_function,
-        train_loop_config=train_loop_config,
-        scaling_config=ScalingConfig(
-            num_workers=ray_args.ray_num_workers,
-            resources_per_worker=ray_args.resources_per_worker,
-            placement_strategy=ray_args.placement_strategy,
-            use_gpu=True,
-        ),
-        run_config=RunConfig(
-            name=ray_args.ray_run_name,
-            storage_filesystem=ray_args.ray_storage_filesystem,
-            storage_path=storage_path,
-        ),
-    )
-    return trainer
+def get_placement_group(num_workers: int) -> tuple["PlacementGroup", dict[str, int]]:
+    r"""Get the Ray placement group for distributed training."""
+    bundle = {"CPU": 10}
+    device_name = get_device_name().upper()
+    if device_name != "CPU":
+        bundle[device_name] = 1
+
+    bundles = [bundle for _ in range(num_workers)]
+    pg = placement_group(bundles, strategy="PACK")
+    return pg, bundle
+
+
+def get_ray_remote_config_for_worker(
+    placement_group: "PlacementGroup",
+    bundle_idx: int,
+    rank: int,
+    world_size: int,
+    master_addr: str,
+    master_port: str,
+    env: dict[str, str] = None,
+) -> dict[str, Any]:
+    r"""Get the remote config for a Ray worker."""
+    env_vars = {
+        "RANK": str(rank),
+        "WORLD_SIZE": str(world_size),
+        "MASTER_ADDR": master_addr,
+        "MASTER_PORT": master_port,
+        "TORCHELASTIC_USE_AGENT_STORE": "False",
+    }
+    env.update(env_vars)
+    remote_config = {
+        "scheduling_strategy": PlacementGroupSchedulingStrategy(
+            placement_group=placement_group,
+            placement_group_bundle_index=bundle_idx,
+        ),
+        "runtime_env": {"env_vars": env},
+        "num_cpus": 10,
+    }
+
+    device_name = get_device_name()
+    if device_name == "gpu":
+        remote_config["num_gpus"] = 1
+    elif device_name == "npu":
+        remote_config["resources"] = {"NPU": 1}
+
+    return remote_config
+
+
+def get_ray_head_node_ip() -> str:
+    r"""Get the IP address of the Ray head node."""
+    head_ip = next(node["NodeManagerAddress"] for node in ray.nodes() if node.get("IsHead", False))
+    return head_ip
+
+
+def sort_placement_group_by_node_ip(placement_group: "PlacementGroup", master_addr: str = None) -> list[int]:
+    r"""Sort the placement group bundles by their node IP addresses."""
+
+    @ray.remote
+    def _get_node_ip():
+        return ray.util.get_node_ip_address().strip("[]")
+
+    tasks = []
+    for bundle_idx in range(placement_group.bundle_count):
+        task = _get_node_ip.options(
+            scheduling_strategy=PlacementGroupSchedulingStrategy(
+                placement_group=placement_group,
+                placement_group_bundle_index=bundle_idx,
+            ),
+        ).remote()
+        tasks.append(task)
+
+    bundle_ips = ray.get(tasks)
+    bundle_node_ip_list = list(enumerate(bundle_ips))
+    sorted_bundle_node_ip_list = sorted(bundle_node_ip_list, key=lambda x: x[1])
+    sorted_bundle_indices = [item[0] for item in sorted_bundle_node_ip_list]
+
+    if master_addr is not None:
+        preferred_indices = [idx for idx, ip in bundle_node_ip_list if ip == master_addr]
+        if preferred_indices:
+            remaining = [i for i in sorted_bundle_indices if i not in preferred_indices]
+            sorted_bundle_indices = preferred_indices + remaining
+
+    return sorted_bundle_indices

View File

@@ -23,9 +23,9 @@ from transformers import EarlyStoppingCallback, PreTrainedModel
 from ..data import get_template_and_fix_tokenizer
 from ..extras import logging
 from ..extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
-from ..extras.misc import infer_optim_dtype
+from ..extras.misc import find_available_port, get_device_name, get_torch_device, infer_optim_dtype
 from ..extras.packages import is_mcore_adapter_available, is_ray_available
-from ..hparams import get_infer_args, get_ray_args, get_train_args, read_args
+from ..hparams import RayArguments, get_infer_args, get_ray_args, get_train_args, read_args
 from ..model import load_model, load_tokenizer
 from .callbacks import LogCallback, PissaConvertCallback, ReporterCallback
 from .dpo import run_dpo
@@ -34,12 +34,17 @@ from .ppo import run_ppo
 from .pt import run_pt
 from .rm import run_rm
 from .sft import run_sft
-from .trainer_utils import get_ray_trainer, get_swanlab_callback
+from .trainer_utils import (
+    get_placement_group,
+    get_ray_head_node_ip,
+    get_ray_remote_config_for_worker,
+    get_swanlab_callback,
+    sort_placement_group_by_node_ip,
+)


 if is_ray_available():
     import ray
-    from ray.train.huggingface.transformers import RayTrainReportCallback


 if TYPE_CHECKING:
@@ -115,13 +120,7 @@ def run_exp(args: Optional[dict[str, Any]] = None, callbacks: Optional[list["Tra
     ray_args = get_ray_args(args)
     callbacks = callbacks or []
     if ray_args.use_ray:
-        callbacks.append(RayTrainReportCallback())
-        trainer = get_ray_trainer(
-            training_function=_training_function,
-            train_loop_config={"args": args, "callbacks": callbacks},
-            ray_args=ray_args,
-        )
-        trainer.fit()
+        _ray_training_function(ray_args, config={"args": args, "callbacks": callbacks})
     else:
         _training_function(config={"args": args, "callbacks": callbacks})
@@ -212,3 +211,94 @@ def export_model(args: Optional[dict[str, Any]] = None) -> None:
         with open(ollama_modelfile, "w", encoding="utf-8") as f:
             f.write(template.get_ollama_modelfile(tokenizer))

         logger.info_rank0(f"Ollama modelfile saved in {ollama_modelfile}")
+
+
+class Worker:
+    def __init__(self):
+        self._setup_env_visible_devices()
+        local_rank = os.environ.get("LOCAL_RANK", "0")
+        get_torch_device().set_device(int(local_rank))
+
+    def _setup_env_visible_devices(self) -> None:
+        RAY_NOSET_VISIBLE_DEVICES_LIST = [
+            "RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES",
+            "RAY_EXPERIMENTAL_NOSET_ASCEND_RT_VISIBLE_DEVICES",
+        ]
+        is_ray_noset_visible_devices = any(os.environ.get(env_var, None) for env_var in RAY_NOSET_VISIBLE_DEVICES_LIST)
+        if is_ray_noset_visible_devices:
+            device_name = get_device_name().upper()
+            local_rank = ray.get_runtime_context().get_accelerator_ids()[device_name][0]
+            os.environ["LOCAL_RANK"] = local_rank
+        else:
+            os.environ["LOCAL_RANK"] = "0"
+
+    def _training_function(self, config: dict[str, Any]) -> None:
+        _training_function(config)
+
+
+def _ray_training_function(ray_args: "RayArguments", config: dict[str, Any]) -> None:
+    num_workers = ray_args.ray_num_workers
+    master_addr = ray_args.master_addr
+    master_port = ray_args.master_port
+    logger.info(f"Using ray.remote mode with {num_workers} workers for distributed training.")
+
+    # initialize ray
+    if not ray.is_initialized():
+        if ray_args.ray_init_kwargs is not None:
+            ray.init(**ray_args.ray_init_kwargs)
+        else:
+            ray.init()
+
+    # verify resources
+    device_name = get_device_name().upper()
+    total_devices = int(ray.cluster_resources().get(device_name, 0))
+    if num_workers > total_devices:
+        raise ValueError(
+            f"The number of devices in the Ray cluster ({total_devices}) should be greater than num_workers ({num_workers})."
+        )
+
+    # verify master_addr
+    if master_addr is None:
+        master_addr = get_ray_head_node_ip()
+        logger.info(f"`master_addr` is not specified, using head node ip: {master_addr}.")
+    else:
+        nodes = [node["NodeManagerAddress"] for node in ray.nodes() if node["Alive"]]
+        if master_addr not in nodes:
+            raise ValueError(f"The `master_addr` ({master_addr}) is not in Ray cluster or not alive ")
+
+    # create placementgroup for resource management
+    pg, bundle = get_placement_group(total_devices)
+    ray.get(pg.ready())
+    logger.info(f"Create placement group with {num_workers} bundles: {bundle}")
+
+    # get sorted_bundle_indices
+    sorted_bundle_indices = sort_placement_group_by_node_ip(pg, master_addr)
+
+    # get master port
+    if master_port is None:
+        master_port = find_available_port()
+        logger.info(f"`master_port` is not specified, using available port: {master_port}.")
+
+    master_port = str(master_port)
+
+    # backing up environment variables
+    current_env = dict(os.environ.items())
+
+    # launch workers
+    RayWorker = ray.remote(Worker)
+    workers = []
+    for rank in range(num_workers):
+        remote_config = get_ray_remote_config_for_worker(
+            placement_group=pg,
+            bundle_idx=sorted_bundle_indices[rank],
+            rank=rank,
+            world_size=num_workers,
+            master_addr=master_addr,
+            master_port=master_port,
+            env=current_env,
+        )
+        worker = RayWorker.options(**remote_config).remote()
+        workers.append(worker)
+
+    ray.get([worker._training_function.remote(config=config) for worker in workers])
+    ray.shutdown()
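
A hedged sketch of driving the new code path programmatically. run_exp and the USE_RAY toggle are taken from this diff (the removed error message referred to `USE_RAY=1`); the argument values are illustrative, and a real run also needs the usual model, dataset, and training options:

import os

from llamafactory.train.tuner import run_exp

os.environ["USE_RAY"] = "1"  # assumption: this is the toggle read by use_ray()

run_exp(
    args={
        "ray_num_workers": 2,  # spawns two ray.remote Worker actors, one accelerator each
        # ... model_name_or_path, dataset, output_dir, and the other training arguments ...
    }
)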

View File

@@ -18,7 +18,6 @@ Contains shared fixtures, pytest configuration, and custom markers.
 """

 import os
-import sys

 import pytest
 import torch
@@ -149,14 +148,7 @@ def _manage_distributed_env(request: FixtureRequest, monkeypatch: MonkeyPatch) -
         devices_str = ",".join(str(i) for i in range(required))
         monkeypatch.setenv(env_key, devices_str)
-
-        # add project root dir to path for mp run
-        project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
-        if project_root not in sys.path:
-            sys.path.insert(0, project_root)
-
-        os.environ["PYTHONPATH"] = project_root + os.pathsep + os.environ.get("PYTHONPATH", "")
+        monkeypatch.syspath_prepend(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
     else:  # non-distributed test
         if old_value:
             visible_devices = [v for v in old_value.split(",") if v != ""]

View File

@@ -20,6 +20,7 @@ from datasets import load_dataset
 from transformers import AutoTokenizer

 from llamafactory.extras.constants import IGNORE_INDEX
+from llamafactory.extras.packages import is_transformers_version_greater_than
 from llamafactory.train.test_utils import load_dataset_module
@@ -52,7 +53,12 @@ def test_feedback_data(num_samples: int):
     for index in indexes:
         messages = original_data["messages"][index]
         ref_input_ids = ref_tokenizer.apply_chat_template(messages)
-        prompt_len = len(ref_tokenizer.apply_chat_template(messages[:-1], add_generation_prompt=True))
+        ref_prompt_ids = ref_tokenizer.apply_chat_template(messages[:-1], add_generation_prompt=True)
+        if is_transformers_version_greater_than("5.0.0"):
+            ref_input_ids = ref_input_ids["input_ids"]
+            ref_prompt_ids = ref_prompt_ids["input_ids"]
+
+        prompt_len = len(ref_prompt_ids)
         ref_labels = [IGNORE_INDEX] * prompt_len + ref_input_ids[prompt_len:]
         assert train_dataset["input_ids"][index] == ref_input_ids
         assert train_dataset["labels"][index] == ref_labels
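
The same version shim recurs in every data-pipeline test below: under transformers >= 5.0 the tokenized apply_chat_template output is first indexed with ["input_ids"] before lengths and labels are computed. In isolation the pattern looks like this (the model id and message are illustrative):

from transformers import AutoTokenizer

from llamafactory.extras.packages import is_transformers_version_greater_than

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")  # illustrative model id
messages = [{"role": "user", "content": "hi"}]

ids = tokenizer.apply_chat_template(messages)
if is_transformers_version_greater_than("5.0.0"):
    ids = ids["input_ids"]  # the tests treat the v5 return value as dict-like

prompt_len = len(ids)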

View File

@@ -20,6 +20,7 @@ from datasets import load_dataset
 from transformers import AutoTokenizer

 from llamafactory.extras.constants import IGNORE_INDEX
+from llamafactory.extras.packages import is_transformers_version_greater_than
 from llamafactory.train.test_utils import load_dataset_module
@@ -63,13 +64,21 @@ def test_pairwise_data(num_samples: int):
         rejected_messages = original_data["conversations"][index] + [original_data["rejected"][index]]
         chosen_messages = _convert_sharegpt_to_openai(chosen_messages)
         rejected_messages = _convert_sharegpt_to_openai(rejected_messages)
         ref_chosen_input_ids = ref_tokenizer.apply_chat_template(chosen_messages)
-        chosen_prompt_len = len(ref_tokenizer.apply_chat_template(chosen_messages[:-1], add_generation_prompt=True))
-        ref_chosen_labels = [IGNORE_INDEX] * chosen_prompt_len + ref_chosen_input_ids[chosen_prompt_len:]
+        ref_chosen_prompt_ids = ref_tokenizer.apply_chat_template(chosen_messages[:-1], add_generation_prompt=True)
         ref_rejected_input_ids = ref_tokenizer.apply_chat_template(rejected_messages)
-        rejected_prompt_len = len(
-            ref_tokenizer.apply_chat_template(rejected_messages[:-1], add_generation_prompt=True)
-        )
+        ref_rejected_prompt_ids = ref_tokenizer.apply_chat_template(rejected_messages[:-1], add_generation_prompt=True)
+
+        if is_transformers_version_greater_than("5.0.0"):
+            ref_chosen_input_ids = ref_chosen_input_ids["input_ids"]
+            ref_rejected_input_ids = ref_rejected_input_ids["input_ids"]
+            ref_chosen_prompt_ids = ref_chosen_prompt_ids["input_ids"]
+            ref_rejected_prompt_ids = ref_rejected_prompt_ids["input_ids"]
+
+        chosen_prompt_len = len(ref_chosen_prompt_ids)
+        rejected_prompt_len = len(ref_rejected_prompt_ids)
+        ref_chosen_labels = [IGNORE_INDEX] * chosen_prompt_len + ref_chosen_input_ids[chosen_prompt_len:]
         ref_rejected_labels = [IGNORE_INDEX] * rejected_prompt_len + ref_rejected_input_ids[rejected_prompt_len:]
         assert train_dataset["chosen_input_ids"][index] == ref_chosen_input_ids
         assert train_dataset["chosen_labels"][index] == ref_chosen_labels

View File

@@ -20,6 +20,7 @@ from datasets import load_dataset
 from transformers import AutoTokenizer

 from llamafactory.extras.constants import IGNORE_INDEX
+from llamafactory.extras.packages import is_transformers_version_greater_than
 from llamafactory.train.test_utils import load_dataset_module
@@ -59,7 +60,16 @@ def test_supervised_single_turn(num_samples: int):
             {"role": "assistant", "content": original_data["output"][index]},
         ]
         ref_input_ids = ref_tokenizer.apply_chat_template(messages)
+        ref_prompt_ids = ref_tokenizer.apply_chat_template(messages[:-1], add_generation_prompt=True)
+        if is_transformers_version_greater_than("5.0.0"):
+            ref_input_ids = ref_input_ids["input_ids"]
+            ref_prompt_ids = ref_prompt_ids["input_ids"]
+
+        prompt_len = len(ref_prompt_ids)
+        ref_label_ids = [IGNORE_INDEX] * prompt_len + ref_input_ids[prompt_len:]
         assert train_dataset["input_ids"][index] == ref_input_ids
+        assert train_dataset["labels"][index] == ref_label_ids


 @pytest.mark.runs_on(["cpu", "mps"])
@@ -73,6 +83,10 @@ def test_supervised_multi_turn(num_samples: int):
     indexes = random.choices(range(len(original_data)), k=num_samples)
     for index in indexes:
         ref_input_ids = ref_tokenizer.apply_chat_template(original_data["messages"][index])
+        if is_transformers_version_greater_than("5.0.0"):
+            ref_input_ids = ref_input_ids["input_ids"]
+
+        # cannot test the label ids in multi-turn case
         assert train_dataset["input_ids"][index] == ref_input_ids
@@ -86,9 +100,12 @@ def test_supervised_train_on_prompt(num_samples: int):
     original_data = load_dataset(DEMO_DATA, name="system_chat", split="train")
     indexes = random.choices(range(len(original_data)), k=num_samples)
     for index in indexes:
-        ref_ids = ref_tokenizer.apply_chat_template(original_data["messages"][index])
-        assert train_dataset["input_ids"][index] == ref_ids
-        assert train_dataset["labels"][index] == ref_ids
+        ref_input_ids = ref_tokenizer.apply_chat_template(original_data["messages"][index])
+        if is_transformers_version_greater_than("5.0.0"):
+            ref_input_ids = ref_input_ids["input_ids"]
+
+        assert train_dataset["input_ids"][index] == ref_input_ids
+        assert train_dataset["labels"][index] == ref_input_ids


 @pytest.mark.runs_on(["cpu", "mps"])
@@ -103,7 +120,13 @@ def test_supervised_mask_history(num_samples: int):
     for index in indexes:
         messages = original_data["messages"][index]
         ref_input_ids = ref_tokenizer.apply_chat_template(messages)
-        prompt_len = len(ref_tokenizer.apply_chat_template(messages[:-1], add_generation_prompt=True))
+        ref_prompt_ids = ref_tokenizer.apply_chat_template(messages[:-1], add_generation_prompt=True)
+        if is_transformers_version_greater_than("5.0.0"):
+            ref_input_ids = ref_input_ids["input_ids"]
+            ref_prompt_ids = ref_prompt_ids["input_ids"]
+
+        prompt_len = len(ref_prompt_ids)
         ref_label_ids = [IGNORE_INDEX] * prompt_len + ref_input_ids[prompt_len:]
         assert train_dataset["input_ids"][index] == ref_input_ids
         assert train_dataset["labels"][index] == ref_label_ids

View File

@@ -19,6 +19,7 @@ import pytest
 from datasets import load_dataset
 from transformers import AutoTokenizer

+from llamafactory.extras.packages import is_transformers_version_greater_than
 from llamafactory.train.test_utils import load_dataset_module
@@ -55,8 +56,13 @@ def test_unsupervised_data(num_samples: int):
     indexes = random.choices(range(len(original_data)), k=num_samples)
     for index in indexes:
         messages = original_data["messages"][index]
-        ref_ids = ref_tokenizer.apply_chat_template(messages)
-        ref_input_ids = ref_tokenizer.apply_chat_template(messages[:-1], add_generation_prompt=True)
-        ref_labels = ref_ids[len(ref_input_ids) :]
-        assert train_dataset["input_ids"][index] == ref_input_ids
+        ref_input_ids = ref_tokenizer.apply_chat_template(messages)
+        ref_prompt_ids = ref_tokenizer.apply_chat_template(messages[:-1], add_generation_prompt=True)
+        if is_transformers_version_greater_than("5.0.0"):
+            ref_input_ids = ref_input_ids["input_ids"]
+            ref_prompt_ids = ref_prompt_ids["input_ids"]
+
+        ref_labels = ref_input_ids[len(ref_prompt_ids) :]
+        assert train_dataset["input_ids"][index] == ref_prompt_ids
         assert train_dataset["labels"][index] == ref_labels

View File

@@ -17,7 +17,7 @@ import os
 import pytest
 import torch
 from PIL import Image
-from transformers import AutoConfig, AutoModelForVision2Seq
+from transformers import AutoConfig, AutoModelForImageTextToText

 from llamafactory.data import get_template_and_fix_tokenizer
 from llamafactory.data.collator import MultiModalDataCollatorForSeq2Seq, prepare_4d_attention_mask
@@ -82,7 +82,7 @@ def test_multimodal_collator():
     template = get_template_and_fix_tokenizer(tokenizer_module["tokenizer"], data_args)
     config = AutoConfig.from_pretrained(model_args.model_name_or_path)
     with torch.device("meta"):
-        model = AutoModelForVision2Seq.from_config(config)
+        model = AutoModelForImageTextToText.from_config(config)

     data_collator = MultiModalDataCollatorForSeq2Seq(
         template=template,

View File

@@ -20,6 +20,7 @@ from transformers import AutoTokenizer

 from llamafactory.data import get_template_and_fix_tokenizer
 from llamafactory.data.template import parse_template
+from llamafactory.extras.packages import is_transformers_version_greater_than
 from llamafactory.hparams import DataArguments
@@ -65,7 +66,6 @@ def _check_template(
     template_name: str,
     prompt_str: str,
     answer_str: str,
-    use_fast: bool,
     messages: list[dict[str, str]] = MESSAGES,
 ) -> None:
     r"""Check template.
@@ -75,13 +75,15 @@
         template_name: the template name.
         prompt_str: the string corresponding to the prompt part.
         answer_str: the string corresponding to the answer part.
-        use_fast: whether to use fast tokenizer.
        messages: the list of messages.
     """
-    tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=use_fast, token=HF_TOKEN)
+    tokenizer = AutoTokenizer.from_pretrained(model_id)
     content_str = tokenizer.apply_chat_template(messages, tokenize=False)
     content_ids = tokenizer.apply_chat_template(messages, tokenize=True)
+    if is_transformers_version_greater_than("5.0.0"):
+        content_ids = content_ids["input_ids"]
+
     template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template=template_name))
     prompt_ids, answer_ids = template.encode_oneturn(tokenizer, messages)
     assert content_str == prompt_str + answer_str
@@ -90,9 +92,8 @@

 @pytest.mark.runs_on(["cpu", "mps"])
-@pytest.mark.parametrize("use_fast", [True, False])
-def test_encode_oneturn(use_fast: bool):
-    tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3, use_fast=use_fast)
+def test_encode_oneturn():
+    tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3)
     template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template="llama3"))
     prompt_ids, answer_ids = template.encode_oneturn(tokenizer, MESSAGES)
     prompt_str = (
@@ -106,9 +107,8 @@

 @pytest.mark.runs_on(["cpu", "mps"])
-@pytest.mark.parametrize("use_fast", [True, False])
-def test_encode_multiturn(use_fast: bool):
-    tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3, use_fast=use_fast)
+def test_encode_multiturn():
+    tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3)
     template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template="llama3"))
     encoded_pairs = template.encode_multiturn(tokenizer, MESSAGES)
     prompt_str_1 = (
@@ -128,11 +128,10 @@

 @pytest.mark.runs_on(["cpu", "mps"])
-@pytest.mark.parametrize("use_fast", [True, False])
 @pytest.mark.parametrize("cot_messages", [True, False])
 @pytest.mark.parametrize("enable_thinking", [True, False, None])
-def test_reasoning_encode_oneturn(use_fast: bool, cot_messages: bool, enable_thinking: bool):
-    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B", use_fast=use_fast)
+def test_reasoning_encode_oneturn(cot_messages: bool, enable_thinking: bool):
+    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B")
     data_args = DataArguments(template="qwen3", enable_thinking=enable_thinking)
     template = get_template_and_fix_tokenizer(tokenizer, data_args)
     prompt_ids, answer_ids = template.encode_oneturn(tokenizer, MESSAGES_WITH_THOUGHT if cot_messages else MESSAGES)
@@ -155,11 +154,10 @@ def test_reasoning_encode_oneturn(use_fast: bool, cot_messages: bool, enable_thi

 @pytest.mark.runs_on(["cpu", "mps"])
-@pytest.mark.parametrize("use_fast", [True, False])
 @pytest.mark.parametrize("cot_messages", [True, False])
 @pytest.mark.parametrize("enable_thinking", [True, False, None])
-def test_reasoning_encode_multiturn(use_fast: bool, cot_messages: bool, enable_thinking: bool):
-    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B", use_fast=use_fast)
+def test_reasoning_encode_multiturn(cot_messages: bool, enable_thinking: bool):
+    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B")
     data_args = DataArguments(template="qwen3", enable_thinking=enable_thinking)
     template = get_template_and_fix_tokenizer(tokenizer, data_args)
     encoded_pairs = template.encode_multiturn(tokenizer, MESSAGES_WITH_THOUGHT if cot_messages else MESSAGES)
@@ -185,10 +183,9 @@ def test_reasoning_encode_multiturn(use_fast: bool, cot_messages: bool, enable_t

 @pytest.mark.runs_on(["cpu", "mps"])
-@pytest.mark.parametrize("use_fast", [True, False])
-def test_jinja_template(use_fast: bool):
-    tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3, use_fast=use_fast)
-    ref_tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3, use_fast=use_fast)
+def test_jinja_template():
+    tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3)
+    ref_tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3)
     template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template="llama3"))
     tokenizer.chat_template = template._get_jinja_template(tokenizer)  # llama3 template no replace
     assert tokenizer.chat_template != ref_tokenizer.chat_template
@@ -222,8 +219,7 @@ def test_get_stop_token_ids():

 @pytest.mark.runs_on(["cpu", "mps"])
 @pytest.mark.skipif(not HF_TOKEN, reason="Gated model.")
-@pytest.mark.parametrize("use_fast", [True, False])
-def test_gemma_template(use_fast: bool):
+def test_gemma_template():
     prompt_str = (
         f"<bos><start_of_turn>user\n{MESSAGES[0]['content']}<end_of_turn>\n"
         f"<start_of_turn>model\n{MESSAGES[1]['content']}<end_of_turn>\n"
@@ -231,13 +227,12 @@ def test_gemma_template(use_fast: bool):
         "<start_of_turn>model\n"
     )
     answer_str = f"{MESSAGES[3]['content']}<end_of_turn>\n"
-    _check_template("google/gemma-3-4b-it", "gemma", prompt_str, answer_str, use_fast)
+    _check_template("google/gemma-3-4b-it", "gemma", prompt_str, answer_str)


 @pytest.mark.runs_on(["cpu", "mps"])
 @pytest.mark.skipif(not HF_TOKEN, reason="Gated model.")
-@pytest.mark.parametrize("use_fast", [True, False])
-def test_gemma2_template(use_fast: bool):
+def test_gemma2_template():
     prompt_str = (
         f"<bos><start_of_turn>user\n{MESSAGES[0]['content']}<end_of_turn>\n"
         f"<start_of_turn>model\n{MESSAGES[1]['content']}<end_of_turn>\n"
@@ -245,13 +240,12 @@ def test_gemma2_template(use_fast: bool):
         "<start_of_turn>model\n"
     )
     answer_str = f"{MESSAGES[3]['content']}<end_of_turn>\n"
-    _check_template("google/gemma-2-2b-it", "gemma2", prompt_str, answer_str, use_fast)
+    _check_template("google/gemma-2-2b-it", "gemma2", prompt_str, answer_str)


 @pytest.mark.runs_on(["cpu", "mps"])
 @pytest.mark.skipif(not HF_TOKEN, reason="Gated model.")
-@pytest.mark.parametrize("use_fast", [True, False])
-def test_llama3_template(use_fast: bool):
+def test_llama3_template():
     prompt_str = (
         f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{MESSAGES[0]['content']}<|eot_id|>"
         f"<|start_header_id|>assistant<|end_header_id|>\n\n{MESSAGES[1]['content']}<|eot_id|>"
@@ -259,14 +253,11 @@ def test_llama3_template(use_fast: bool):
         "<|start_header_id|>assistant<|end_header_id|>\n\n"
     )
     answer_str = f"{MESSAGES[3]['content']}<|eot_id|>"
-    _check_template("meta-llama/Meta-Llama-3-8B-Instruct", "llama3", prompt_str, answer_str, use_fast)
+    _check_template("meta-llama/Meta-Llama-3-8B-Instruct", "llama3", prompt_str, answer_str)


 @pytest.mark.runs_on(["cpu", "mps"])
-@pytest.mark.parametrize(
-    "use_fast", [True, pytest.param(False, marks=pytest.mark.xfail(reason="Llama 4 has no slow tokenizer."))]
-)
-def test_llama4_template(use_fast: bool):
+def test_llama4_template():
     prompt_str = (
         f"<|begin_of_text|><|header_start|>user<|header_end|>\n\n{MESSAGES[0]['content']}<|eot|>"
         f"<|header_start|>assistant<|header_end|>\n\n{MESSAGES[1]['content']}<|eot|>"
@@ -274,18 +265,11 @@ def test_llama4_template(use_fast: bool):
         "<|header_start|>assistant<|header_end|>\n\n"
     )
     answer_str = f"{MESSAGES[3]['content']}<|eot|>"
-    _check_template(TINY_LLAMA4, "llama4", prompt_str, answer_str, use_fast)
+    _check_template(TINY_LLAMA4, "llama4", prompt_str, answer_str)


-@pytest.mark.parametrize(
-    "use_fast",
-    [
-        pytest.param(True, marks=pytest.mark.xfail(not HF_TOKEN, reason="Authorization.")),
-        pytest.param(False, marks=pytest.mark.xfail(reason="Phi-4 slow tokenizer is broken.")),
-    ],
-)
 @pytest.mark.runs_on(["cpu", "mps"])
-def test_phi4_template(use_fast: bool):
+def test_phi4_template():
     prompt_str = (
         f"<|im_start|>user<|im_sep|>{MESSAGES[0]['content']}<|im_end|>"
         f"<|im_start|>assistant<|im_sep|>{MESSAGES[1]['content']}<|im_end|>"
@@ -293,13 +277,12 @@ def test_phi4_template(use_fast: bool):
         "<|im_start|>assistant<|im_sep|>"
     )
     answer_str = f"{MESSAGES[3]['content']}<|im_end|>"
-    _check_template("microsoft/phi-4", "phi4", prompt_str, answer_str, use_fast)
+    _check_template("microsoft/phi-4", "phi4", prompt_str, answer_str)


 @pytest.mark.runs_on(["cpu", "mps"])
 @pytest.mark.xfail(not HF_TOKEN, reason="Authorization.")
-@pytest.mark.parametrize("use_fast", [True, False])
-def test_qwen2_5_template(use_fast: bool):
+def test_qwen2_5_template():
     prompt_str = (
         "<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n"
         f"<|im_start|>user\n{MESSAGES[0]['content']}<|im_end|>\n"
@@ -308,13 +291,12 @@ def test_qwen2_5_template(use_fast: bool):
         "<|im_start|>assistant\n"
     )
     answer_str = f"{MESSAGES[3]['content']}<|im_end|>\n"
-    _check_template("Qwen/Qwen2.5-7B-Instruct", "qwen", prompt_str, answer_str, use_fast)
+    _check_template("Qwen/Qwen2.5-7B-Instruct", "qwen", prompt_str, answer_str)


 @pytest.mark.runs_on(["cpu", "mps"])
-@pytest.mark.parametrize("use_fast", [True, False])
 @pytest.mark.parametrize("cot_messages", [True, False])
-def test_qwen3_template(use_fast: bool, cot_messages: bool):
+def test_qwen3_template(cot_messages: bool):
     prompt_str = (
         f"<|im_start|>user\n{MESSAGES[0]['content']}<|im_end|>\n"
         f"<|im_start|>assistant\n{MESSAGES[1]['content']}<|im_end|>\n"
@@ -328,12 +310,12 @@ def test_qwen3_template(use_fast: bool, cot_messages: bool):
     answer_str = f"{MESSAGES_WITH_THOUGHT[3]['content']}<|im_end|>\n"
     messages = MESSAGES_WITH_THOUGHT
-    _check_template("Qwen/Qwen3-8B", "qwen3", prompt_str, answer_str, use_fast, messages=messages)
+    _check_template("Qwen/Qwen3-8B", "qwen3", prompt_str, answer_str, messages=messages)
@pytest.mark.runs_on(["cpu", "mps"]) @pytest.mark.runs_on(["cpu", "mps"])
def test_parse_llama3_template(): def test_parse_llama3_template():
tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3, token=HF_TOKEN) tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3)
template = parse_template(tokenizer) template = parse_template(tokenizer)
assert template.format_user.slots == [ assert template.format_user.slots == [
"<|start_header_id|>user<|end_header_id|>\n\n{{content}}<|eot_id|>" "<|start_header_id|>user<|end_header_id|>\n\n{{content}}<|eot_id|>"
@@ -348,7 +330,7 @@ def test_parse_llama3_template():
@pytest.mark.runs_on(["cpu", "mps"]) @pytest.mark.runs_on(["cpu", "mps"])
@pytest.mark.xfail(not HF_TOKEN, reason="Authorization.") @pytest.mark.xfail(not HF_TOKEN, reason="Authorization.")
def test_parse_qwen_template(): def test_parse_qwen_template():
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct", token=HF_TOKEN) tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")
template = parse_template(tokenizer) template = parse_template(tokenizer)
assert template.__class__.__name__ == "Template" assert template.__class__.__name__ == "Template"
assert template.format_user.slots == ["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"] assert template.format_user.slots == ["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]
@@ -361,7 +343,7 @@ def test_parse_qwen_template():
@pytest.mark.runs_on(["cpu", "mps"]) @pytest.mark.runs_on(["cpu", "mps"])
@pytest.mark.xfail(not HF_TOKEN, reason="Authorization.") @pytest.mark.xfail(not HF_TOKEN, reason="Authorization.")
def test_parse_qwen3_template(): def test_parse_qwen3_template():
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B", token=HF_TOKEN) tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B")
template = parse_template(tokenizer) template = parse_template(tokenizer)
assert template.__class__.__name__ == "ReasoningTemplate" assert template.__class__.__name__ == "ReasoningTemplate"
assert template.format_user.slots == ["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"] assert template.format_user.slots == ["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]

View File

@@ -16,7 +16,8 @@ import os
 import pytest
 import torch
-from transformers import AutoConfig, AutoModelForVision2Seq
+from safetensors.torch import load_file
+from transformers import AutoConfig, AutoModelForImageTextToText

 from llamafactory.extras.packages import is_transformers_version_greater_than
 from llamafactory.hparams import FinetuningArguments, ModelArguments
@@ -36,7 +37,7 @@ def test_visual_full(freeze_vision_tower: bool, freeze_multi_modal_projector: bo
     )
     config = AutoConfig.from_pretrained(model_args.model_name_or_path)
     with torch.device("meta"):
-        model = AutoModelForVision2Seq.from_config(config)
+        model = AutoModelForImageTextToText.from_config(config)

     model = init_adapter(config, model, model_args, finetuning_args, is_trainable=True)
     for name, param in model.named_parameters():
@@ -56,7 +57,7 @@ def test_visual_lora(freeze_vision_tower: bool, freeze_language_model: bool):
     )
     config = AutoConfig.from_pretrained(model_args.model_name_or_path)
     with torch.device("meta"):
-        model = AutoModelForVision2Seq.from_config(config)
+        model = AutoModelForImageTextToText.from_config(config)

     model = init_adapter(config, model, model_args, finetuning_args, is_trainable=True)
     trainable_params, frozen_params = set(), set()
@@ -86,13 +87,14 @@ def test_visual_model_save_load():
     finetuning_args = FinetuningArguments(finetuning_type="full")
     config = AutoConfig.from_pretrained(model_args.model_name_or_path)
     with torch.device("meta"):
-        model = AutoModelForVision2Seq.from_config(config)
+        model = AutoModelForImageTextToText.from_config(config)

     model = init_adapter(config, model, model_args, finetuning_args, is_trainable=False)
+    model.to_empty(device="cpu")
     loaded_model_weight = dict(model.named_parameters())
-    model.save_pretrained(os.path.join("output", "qwen2_vl"), max_shard_size="10GB", safe_serialization=False)
-    saved_model_weight = torch.load(os.path.join("output", "qwen2_vl", "pytorch_model.bin"), weights_only=False)
+    model.save_pretrained(os.path.join("output", "qwen2_vl"), max_shard_size="10GB", safe_serialization=True)
+    saved_model_weight = load_file(os.path.join("output", "qwen2_vl", "model.safetensors"))

     if is_transformers_version_greater_than("4.52.0"):
         assert "model.language_model.layers.0.self_attn.q_proj.weight" in loaded_model_weight

View File

@@ -1,2 +1,2 @@
 # change if test fails or cache is outdated
-0.9.5.105
+0.9.5.106

View File

@@ -23,6 +23,13 @@ from llamafactory.v1.core.utils.rendering import Renderer
 from llamafactory.v1.utils.types import Processor


+def _get_input_ids(inputs: list | dict) -> list:
+    if not isinstance(inputs, list):
+        return inputs["input_ids"]
+    else:
+        return inputs
+
+
 HF_MESSAGES = [
     {"role": "system", "content": "You are a helpful assistant."},
     {"role": "user", "content": "What is LLM?"},
@@ -81,15 +88,15 @@ def test_chatml_rendering():
     tokenizer: Processor = AutoTokenizer.from_pretrained("llamafactory/tiny-random-qwen3")
     renderer = Renderer(template="chatml", processor=tokenizer)
-    hf_inputs = tokenizer.apply_chat_template(HF_MESSAGES[:-1], add_generation_prompt=True)
+    hf_inputs = _get_input_ids(tokenizer.apply_chat_template(HF_MESSAGES[:-1], add_generation_prompt=True))
     v1_inputs = renderer.render_messages(V1_MESSAGES[:-1], is_generate=True)
     assert v1_inputs["input_ids"] == hf_inputs
     assert v1_inputs["attention_mask"] == [1] * len(hf_inputs)
     assert v1_inputs["labels"] == [-100] * len(hf_inputs)
     assert v1_inputs["loss_weights"] == [0.0] * len(hf_inputs)

-    hf_inputs_part = tokenizer.apply_chat_template(HF_MESSAGES[:-1], add_generation_prompt=False)
-    hf_inputs_full = tokenizer.apply_chat_template(HF_MESSAGES, add_generation_prompt=False)
+    hf_inputs_part = _get_input_ids(tokenizer.apply_chat_template(HF_MESSAGES[:-1], add_generation_prompt=False))
+    hf_inputs_full = _get_input_ids(tokenizer.apply_chat_template(HF_MESSAGES, add_generation_prompt=False))
     v1_inputs_full = renderer.render_messages(V1_MESSAGES, is_generate=False)
     assert v1_inputs_full["input_ids"] == hf_inputs_full
     assert v1_inputs_full["attention_mask"] == [1] * len(hf_inputs_full)
@@ -124,17 +131,21 @@ def test_qwen3_nothink_rendering():
     tokenizer: Processor = AutoTokenizer.from_pretrained("Qwen/Qwen3-4B-Instruct-2507")
     renderer = Renderer(template="qwen3_nothink", processor=tokenizer)
-    hf_inputs = tokenizer.apply_chat_template(HF_MESSAGES_WITH_TOOLS[:-1], tools=V1_TOOLS, add_generation_prompt=True)
+    hf_inputs = _get_input_ids(
+        tokenizer.apply_chat_template(HF_MESSAGES_WITH_TOOLS[:-1], tools=V1_TOOLS, add_generation_prompt=True)
+    )
     v1_inputs = renderer.render_messages(V1_MESSAGES_WITH_TOOLS[:-1], tools=json.dumps(V1_TOOLS), is_generate=True)
     assert v1_inputs["input_ids"] == hf_inputs
     assert v1_inputs["attention_mask"] == [1] * len(hf_inputs)
     assert v1_inputs["labels"] == [-100] * len(hf_inputs)
     assert v1_inputs["loss_weights"] == [0.0] * len(hf_inputs)

-    hf_inputs_part = tokenizer.apply_chat_template(
-        HF_MESSAGES_WITH_TOOLS[:-1], tools=V1_TOOLS, add_generation_prompt=False
-    )
-    hf_inputs_full = tokenizer.apply_chat_template(HF_MESSAGES_WITH_TOOLS, tools=V1_TOOLS, add_generation_prompt=False)
+    hf_inputs_part = _get_input_ids(
+        tokenizer.apply_chat_template(HF_MESSAGES_WITH_TOOLS[:-1], tools=V1_TOOLS, add_generation_prompt=False)
+    )
+    hf_inputs_full = _get_input_ids(
+        tokenizer.apply_chat_template(HF_MESSAGES_WITH_TOOLS, tools=V1_TOOLS, add_generation_prompt=False)
+    )
     v1_inputs_full = renderer.render_messages(V1_MESSAGES_WITH_TOOLS, tools=json.dumps(V1_TOOLS), is_generate=False)
     assert v1_inputs_full["input_ids"] == hf_inputs_full
     assert v1_inputs_full["attention_mask"] == [1] * len(hf_inputs_full)
@@ -187,7 +198,7 @@ def test_qwen3_nothink_rendering_remote(num_samples: int):
 def test_process_sft_samples():
     tokenizer: Processor = AutoTokenizer.from_pretrained("llamafactory/tiny-random-qwen3")
     renderer = Renderer(template="chatml", processor=tokenizer)
-    hf_inputs = tokenizer.apply_chat_template(HF_MESSAGES)
+    hf_inputs = _get_input_ids(tokenizer.apply_chat_template(HF_MESSAGES))
     samples = [{"messages": V1_MESSAGES, "extra_info": "test", "_dataset_name": "default"}]
     model_inputs = renderer.process_samples(samples)
@@ -200,7 +211,7 @@ def test_process_sft_samples():
 def test_process_dpo_samples():
     tokenizer: Processor = AutoTokenizer.from_pretrained("llamafactory/tiny-random-qwen3")
     renderer = Renderer(template="chatml", processor=tokenizer)
-    hf_inputs = tokenizer.apply_chat_template(HF_MESSAGES)
+    hf_inputs = _get_input_ids(tokenizer.apply_chat_template(HF_MESSAGES))
     samples = [
         {