[v1] add cli sampler (#9721)

This commit is contained in:
Yaowei Zheng
2026-01-06 23:31:27 +08:00
committed by GitHub
parent e944dc442c
commit ea0b4e2466
45 changed files with 1091 additions and 505 deletions

View File

@@ -55,12 +55,12 @@ jobs:
uv pip install -e . uv pip install -e .
uv pip install -r requirements/dev.txt uv pip install -r requirements/dev.txt
- name: Cache HuggingFace models - name: Cache files
id: hf-hub-cache id: hf-hub-cache
uses: actions/cache@v4 uses: actions/cache@v4
with: with:
path: ${{ runner.temp }}/huggingface path: ${{ runner.temp }}/huggingface
key: hf-cache-${{ runner.os }}-${{ hashFiles('tests/version.txt') }} key: huggingface-${{ matrix.os }}-${{ matrix.python }}-${{ hashFiles('tests/version.txt') }}
- name: Check quality - name: Check quality
run: | run: |

View File

@@ -73,7 +73,7 @@ dependencies = [
# api # api
"uvicorn", "uvicorn",
"fastapi", "fastapi",
"sse-starlette" "sse-starlette",
] ]
[project.scripts] [project.scripts]

View File

@@ -119,9 +119,19 @@ def synchronize() -> None:
@requires_accelerator @requires_accelerator
def set_device() -> None: def set_device_index() -> None:
"""Set current accelerator.""" """Set current accelerator index to local rank."""
torch.accelerator.set_device_index(get_local_rank()) if get_current_accelerator().type != DeviceType.CPU:
torch.accelerator.set_device_index(get_local_rank())
@requires_accelerator
def get_current_device() -> torch.device:
"""Get current accelerator device."""
if get_current_accelerator().type == DeviceType.CPU:
return torch.device(DeviceType.CPU.value)
else:
return torch.device(type=get_current_accelerator().type, index=torch.accelerator.current_device_index())
def is_torch_cuda_available(): def is_torch_cuda_available():

View File

@@ -123,12 +123,13 @@ class DistributedInterface:
if self._initialized: if self._initialized:
return return
helper.set_device_index()
self._is_distributed = helper.is_distributed() self._is_distributed = helper.is_distributed()
self._rank = helper.get_rank() self._rank = helper.get_rank()
self._world_size = helper.get_world_size() self._world_size = helper.get_world_size()
self._local_rank = helper.get_local_rank() self._local_rank = helper.get_local_rank()
self._local_world_size = helper.get_local_world_size() self._local_world_size = helper.get_local_world_size()
self.current_accelerator = helper.get_current_accelerator() self.current_device = helper.get_current_device()
self.device_count = helper.get_device_count() self.device_count = helper.get_device_count()
if config is None: if config is None:
@@ -144,15 +145,14 @@ class DistributedInterface:
timeout = config.get("timeout", 18000) timeout = config.get("timeout", 18000)
if self._is_distributed: if self._is_distributed:
helper.set_device()
init_process_group(timeout=timedelta(seconds=timeout)) init_process_group(timeout=timedelta(seconds=timeout))
self.model_device_mesh = init_device_mesh( self.model_device_mesh = init_device_mesh(
device_type=self.current_accelerator.type, device_type=self.current_device.type,
mesh_shape=self.strategy.model_mesh_shape, mesh_shape=self.strategy.model_mesh_shape,
mesh_dim_names=self.strategy.model_mesh_dim_names, mesh_dim_names=self.strategy.model_mesh_dim_names,
) )
self.data_device_mesh = init_device_mesh( self.data_device_mesh = init_device_mesh(
device_type=self.current_accelerator.type, device_type=self.current_device.type,
mesh_shape=self.strategy.data_mesh_shape, mesh_shape=self.strategy.data_mesh_shape,
mesh_dim_names=self.strategy.data_mesh_dim_names, mesh_dim_names=self.strategy.data_mesh_dim_names,
) )
@@ -161,12 +161,12 @@ class DistributedInterface:
self.data_device_mesh = None self.data_device_mesh = None
self._initialized = True self._initialized = True
logger.info_rank0(f"DistributedInterface initialized with strategy={self.strategy}.") logger.info_rank0(f"DistributedInterface initialized: {self}.")
def __str__(self) -> str: def __str__(self) -> str:
return ( return (
f"DistributedInterface(strategy={self.strategy}), is_distributed={self._is_distributed}, " f"DistributedInterface(strategy={self.strategy}), is_distributed={self._is_distributed}, "
f"current_accelerator={self.current_accelerator}, rank={self._rank}, world_size={self._world_size}, " f"current_device={self.current_device}, rank={self._rank}, world_size={self._world_size}, "
f"model_device_mesh={self.model_device_mesh}, data_device_mesh={self.data_device_mesh}" f"model_device_mesh={self.model_device_mesh}, data_device_mesh={self.data_device_mesh}"
) )
@@ -251,4 +251,7 @@ class DistributedInterface:
if __name__ == "__main__": if __name__ == "__main__":
print(DistributedInterface(DistributedStrategy())) """
python -m llamafactory.v1.accelerator.interface
"""
print(DistributedInterface())

View File

@@ -17,7 +17,7 @@
import json import json
from enum import Enum, unique from enum import StrEnum, unique
class PluginConfig(dict): class PluginConfig(dict):
@@ -36,7 +36,7 @@ PluginArgument = PluginConfig | dict | str | None
@unique @unique
class ModelClass(str, Enum): class ModelClass(StrEnum):
"""Auto class for model config.""" """Auto class for model config."""
LLM = "llm" LLM = "llm"
@@ -45,7 +45,7 @@ class ModelClass(str, Enum):
@unique @unique
class SampleBackend(str, Enum): class SampleBackend(StrEnum):
HF = "hf" HF = "hf"
VLLM = "vllm" VLLM = "vllm"

View File

@@ -21,8 +21,13 @@ from .arg_utils import ModelClass, PluginConfig, get_plugin_config
@dataclass @dataclass
class ModelArguments: class ModelArguments:
model: str = field( model: str = field(
default="Qwen/Qwen3-4B-Instruct-2507",
metadata={"help": "Path to the model or model identifier from Hugging Face."}, metadata={"help": "Path to the model or model identifier from Hugging Face."},
) )
template: str = field(
default="chatml",
metadata={"help": "Template for the model."},
)
trust_remote_code: bool = field( trust_remote_code: bool = field(
default=False, default=False,
metadata={"help": "Trust remote code from Hugging Face."}, metadata={"help": "Trust remote code from Hugging Face."},

View File

@@ -12,10 +12,20 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import asyncio
import os
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import AsyncGenerator, Generator
from threading import Thread
import torch
from transformers import TextIteratorStreamer
from ..accelerator.interface import DistributedInterface
from ..config import ModelArguments, SampleArguments, SampleBackend from ..config import ModelArguments, SampleArguments, SampleBackend
from ..utils.types import HFModel, Processor, TorchDataset from ..utils.helper import get_tokenizer
from ..utils.types import HFModel, Message, Sample, TorchDataset
from .utils.rendering import Renderer
class BaseEngine(ABC): class BaseEngine(ABC):
@@ -24,8 +34,8 @@ class BaseEngine(ABC):
self, self,
args: SampleArguments, args: SampleArguments,
model_args: ModelArguments, model_args: ModelArguments,
model: HFModel = None, model: HFModel,
processor: Processor = None, renderer: Renderer,
) -> None: ) -> None:
"""Initialize the engine. """Initialize the engine.
@@ -33,17 +43,34 @@ class BaseEngine(ABC):
args: Sample arguments. args: Sample arguments.
model_args: Model arguments. model_args: Model arguments.
model: Model. model: Model.
processor: Processor. renderer: Renderer.
""" """
... ...
@abstractmethod @abstractmethod
async def generate(self, messages): async def generate(self, messages: list[Message], tools: str | None = None) -> AsyncGenerator[str, None]:
pass """Generate tokens asynchronously.
Args:
messages: List of messages.
tools: Tools string.
Yields:
Generated tokens.
"""
...
@abstractmethod @abstractmethod
async def batch_infer(self, data: TorchDataset) -> None: async def batch_infer(self, dataset: TorchDataset) -> list[Sample]:
pass """Batch infer samples.
Args:
dataset: Torch dataset.
Returns:
List of samples.
"""
...
class HuggingFaceEngine(BaseEngine): class HuggingFaceEngine(BaseEngine):
@@ -52,26 +79,103 @@ class HuggingFaceEngine(BaseEngine):
args: SampleArguments, args: SampleArguments,
model_args: ModelArguments, model_args: ModelArguments,
model: HFModel, model: HFModel,
processor: Processor, renderer: Renderer,
) -> None: ) -> None:
self.args = args self.args = args
self.model_args = model_args
self.model = model
self.renderer = renderer
self.semaphore = asyncio.Semaphore(int(os.getenv("MAX_CONCURRENT", "1")))
@torch.inference_mode()
def get_response(self, messages: list[Message], tools: str | None = None) -> Generator[str, None, None]:
model_inputs = self.renderer.render_messages(messages, tools, is_generate=True)
streamer = TextIteratorStreamer(
tokenizer=get_tokenizer(self.renderer.processor),
skip_prompt=True,
skip_special_tokens=True, # TODO: configurable
)
device = DistributedInterface().current_device
kwargs = {
"input_ids": torch.tensor([model_inputs["input_ids"]]).to(device),
"attention_mask": torch.tensor([model_inputs["attention_mask"]]).to(device),
"max_new_tokens": self.args.max_new_tokens,
"streamer": streamer,
}
thread = Thread(target=self.model.generate, kwargs=kwargs, daemon=True)
thread.start()
def stream():
try:
return streamer.__next__()
except StopIteration:
raise StopAsyncIteration()
return stream
async def generate(self, messages: list[Message], tools: str | None = None) -> AsyncGenerator[str, None]:
async with self.semaphore:
response = self.get_response(messages, tools)
while True:
try:
yield await asyncio.to_thread(response)
except StopAsyncIteration:
break
async def batch_infer(self, dataset: TorchDataset) -> list[Sample]:
"""Batch infer samples.
Args:
dataset: Torch dataset.
Returns:
List of samples.
"""
raise NotImplementedError("Batch infer is not implemented.")
class BaseSampler: class BaseSampler:
"""Base sampler.
Args:
args: Sample arguments.
model_args: Model arguments.
model: Model.
renderer: Renderer.
"""
def __init__( def __init__(
self, self,
args: SampleArguments, args: SampleArguments,
model_args: ModelArguments, model_args: ModelArguments,
model: HFModel, model: HFModel,
processor: Processor, renderer: Renderer,
) -> None: ) -> None:
if args.sample_backend == SampleBackend.HF: if args.sample_backend == SampleBackend.HF:
self.engine = HuggingFaceEngine(args, model_args, model, processor) self.engine = HuggingFaceEngine(args, model_args, model, renderer)
else: else:
raise ValueError(f"Unknown sample backend: {args.sample_backend}") raise ValueError(f"Unknown sample backend: {args.sample_backend}")
async def generate(self, messages): async def generate(self, messages: list[Message], tools: str | None = None) -> AsyncGenerator[str, None]:
return await self.engine.generate(messages) """Generate tokens asynchronously.
async def batch_infer(self, data: TorchDataset) -> None: Args:
return await self.engine.batch_infer(data) messages: List of messages.
tools: Tools string.
Yields:
Generated tokens.
"""
async for token in self.engine.generate(messages, tools):
yield token
async def batch_infer(self, dataset: TorchDataset) -> list[Sample]:
"""Batch infer samples.
Args:
dataset: Torch dataset.
Returns:
List of samples.
"""
return await self.engine.batch_infer(dataset)

View File

@@ -14,15 +14,23 @@
"""The definition of data engine. """The definition of data engine.
Init Data engine: How to use:
data_engine = DataEngine(data_args)
data_engine[i]: Get the sample via index.
Init workflow:
1. Parse dataset info from arguments. 1. Parse dataset info from arguments.
2. Load datasets according to dataset info. 2. Load datasets according to dataset info.
3. Build data index (and reweight samples if necessary). 3. Build data index (and reweight samples if necessary).
Get Data Sample: Get data sample:
1. Get sample from data index. 1. Get sample from data index.
2. Convert sample to standard format. 2. Convert sample to standard format.
3. Return sample. 3. Return sample.
Note:
1. The data engine is equivalent to the torch dataset.
2. The data engine is agnostic to the model used.
""" """
import os import os
@@ -98,10 +106,10 @@ class DataEngine(Dataset):
size = self.dataset_infos[dataset_name].get("size") size = self.dataset_infos[dataset_name].get("size")
weight = self.dataset_infos[dataset_name].get("weight") weight = self.dataset_infos[dataset_name].get("weight")
if size or weight: # data index plugin if size or weight:
from ..plugins.data_plugins.loader import DataIndexPlugin from ..plugins.data_plugins.loader import adjust_data_index
data_index = DataIndexPlugin().adjust_data_index(data_index, size, weight) data_index = adjust_data_index(data_index, size, weight)
self.data_index.extend(data_index) self.data_index.extend(data_index)
@@ -150,9 +158,9 @@ class DataEngine(Dataset):
dataset_name, sample_index = self.data_index[index] dataset_name, sample_index = self.data_index[index]
return self._convert_data_sample(self.datasets[dataset_name][sample_index], dataset_name) return self._convert_data_sample(self.datasets[dataset_name][sample_index], dataset_name)
else: # data selector plugin else: # data selector plugin
from ..plugins.data_plugins.loader import DataSelectorPlugin from ..plugins.data_plugins.loader import select_data_sample
selected_index = DataSelectorPlugin().select(self.data_index, index) selected_index = select_data_sample(self.data_index, index)
if isinstance(selected_index, list): if isinstance(selected_index, list):
return [ return [
self._convert_data_sample(self.datasets[dataset_name][sample_index], dataset_name) self._convert_data_sample(self.datasets[dataset_name][sample_index], dataset_name)

View File

@@ -12,16 +12,18 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""The definition of model loader. """The definition of model engine.
How to use: How to use:
model_loader = ModelLoader(model_args, is_trainable=True) model_engine = ModelEngine(model_args, is_train=True)
model_loader.processor: Get the tokenizer or multi-modal processor. model_engine.processor: Get the tokenizer or multi-modal processor.
model_loader.model_config: Get the model configuration. model_engine.renderer: Get the renderer.
model_loader.model: Get the HF model. model_engine.model_config: Get the model configuration.
model_engine.model: Get the HF model.
Init Workflow: Init workflow:
1. Init processor. 1. Init processor.
2. Init render.
2. Init model config. 2. Init model config.
3. Init model. 3. Init model.
4. Init adapter. 4. Init adapter.
@@ -36,17 +38,18 @@ from ..accelerator.interface import DistributedInterface
from ..config.model_args import ModelArguments, ModelClass from ..config.model_args import ModelArguments, ModelClass
from ..utils import logging from ..utils import logging
from ..utils.types import HFConfig, HFModel, Processor from ..utils.types import HFConfig, HFModel, Processor
from .utils.rendering import Renderer
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
class ModelLoader: class ModelEngine:
"""Model loader. """Model engine.
Args: Args:
model_args: Model arguments. model_args: Model arguments.
is_trainable: Whether to train the model. is_train: Whether to train the model.
""" """
def __init__(self, model_args: ModelArguments, is_train: bool = False) -> None: def __init__(self, model_args: ModelArguments, is_train: bool = False) -> None:
@@ -56,6 +59,8 @@ class ModelLoader:
"""Whether to train the model.""" """Whether to train the model."""
self.processor = self._init_processor() self.processor = self._init_processor()
"""Tokenizer or multi-modal processor.""" """Tokenizer or multi-modal processor."""
self.renderer = Renderer(self.args.template, self.processor)
"""Renderer."""
self.model_config = self._init_model_config() self.model_config = self._init_model_config()
"""Model configuration.""" """Model configuration."""
self.model = self._init_model() self.model = self._init_model()
@@ -107,7 +112,7 @@ class ModelLoader:
init_device = InitPlugin(self.args.init_config.name)() init_device = InitPlugin(self.args.init_config.name)()
else: else:
init_device = DistributedInterface().current_accelerator init_device = DistributedInterface().current_device
if init_device.type == DeviceType.META: if init_device.type == DeviceType.META:
with init_empty_weights(): with init_empty_weights():
@@ -144,12 +149,12 @@ class ModelLoader:
if __name__ == "__main__": if __name__ == "__main__":
""" """
python -m llamafactory.v1.core.model_loader --model llamafactory/tiny-random-qwen2.5 python -m llamafactory.v1.core.model_engine --model llamafactory/tiny-random-qwen2.5
""" """
from ..config.arg_parser import get_args from ..config.arg_parser import get_args
_, model_args, *_ = get_args() _, model_args, *_ = get_args()
model_loader = ModelLoader(model_args=model_args) model_engine = ModelEngine(model_args=model_args)
print(model_loader.processor) print(model_engine.processor)
print(model_loader.model_config) print(model_engine.model_config)
print(model_loader.model) print(model_engine.model)

View File

@@ -0,0 +1,239 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import re
from ...utils.constants import IGNORE_INDEX
from ...utils.helper import get_tokenizer
from ...utils.types import Message, ModelInput, Processor
def _update_model_input(
processor: Processor,
input_ids: list[int],
labels: list[int],
loss_weights: list[int],
temp_str: str,
temp_weight: float,
) -> str:
"""Update model input with temporary string."""
if not temp_str:
return ""
tokenizer = get_tokenizer(processor)
temp_ids = tokenizer.encode(temp_str, add_special_tokens=False)
input_ids.extend(temp_ids)
loss_weights.extend([temp_weight] * len(temp_ids))
if temp_weight > 1e-6:
labels.extend(temp_ids)
else:
labels.extend([IGNORE_INDEX] * len(temp_ids))
return ""
def render_chatml_messages(
processor: Processor,
messages: list[Message],
tools: str | None = None,
is_generate: bool = False,
) -> ModelInput:
"""Apply chatml template to messages and convert them to model input.
See https://huggingface.co/spaces/huggingfacejs/chat-template-playground
"""
input_ids, labels, loss_weights = [], [], []
temp_str, temp_weight = "", 0.0
if tools:
temp_str += "<|im_start|>system\n"
if messages[0]["role"] == "system":
for content in messages[0]["content"]:
if content["type"] == "text":
temp_str += content["value"]
else:
raise ValueError(f"Unsupported content type: {content['type']}")
temp_str += "\n\n"
temp_weight = messages[0].get("loss_weight", 0.0)
temp_str += (
"# Tools\n\nYou may call one or more functions to assist with the user query.\n\n"
"You are provided with function signatures within <tools></tools> XML tags:\n<tools>"
)
try:
tools = json.loads(tools)
except json.JSONDecodeError:
raise ValueError(f"Invalid tools format: {str(tools)}.")
if not isinstance(tools, list):
tools = [tools]
for tool in tools:
temp_str += "\n" + json.dumps(tool, ensure_ascii=False)
temp_str += (
"\n</tools>\n\nFor each function call, return a json object with function name "
'and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{"name": '
'<function-name>, "arguments": <args-json-object>}\n</tool_call><|im_end|>\n'
)
elif messages[0]["role"] == "system":
temp_str += "<|im_start|>system\n"
for content in messages[0]["content"]:
if content["type"] == "text":
temp_str += content["value"]
else:
raise ValueError(f"Unsupported content type: {content['type']}")
temp_str += "<|im_end|>\n"
temp_weight = messages[0].get("loss_weight", 0.0)
temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight)
for turn_idx, message in enumerate(messages):
if message["role"] == "user" or (message["role"] == "system" and turn_idx != 0):
temp_str += "<|im_start|>" + message["role"] + "\n"
for content in message["content"]:
if content["type"] == "text":
temp_str += content["value"]
else:
raise ValueError(f"Unsupported content type: {content['type']}")
temp_str += "<|im_end|>\n"
temp_weight = message.get("loss_weight", 0.0)
elif message["role"] == "assistant":
temp_str += "<|im_start|>" + message["role"] + "\n"
for val_idx, content in enumerate(message["content"]):
if content["type"] == "text":
temp_str += content["value"]
elif content["type"] == "reasoning":
temp_str += "<thinking>\n" + content["value"] + "\n</thinking>\n\n" # avoid using special tokens
elif content["type"] == "tool_call":
if val_idx != 0 and message["content"][val_idx - 1]["type"] in ["text", "tool_call"]:
temp_str += "\n"
try:
tool_call = json.loads(content["value"])
except json.JSONDecodeError:
raise ValueError(f"Invalid tool call format: {content['value']}.")
temp_str += (
'<tool_call>\n{"name": "'
+ tool_call["name"]
+ '", "arguments": '
+ json.dumps(tool_call["arguments"], ensure_ascii=False)
+ "}\n</tool_call>"
)
else:
raise ValueError(f"Unsupported content type: {content['type']}")
temp_str += "<|im_end|>\n"
temp_weight = message.get("loss_weight", 1.0)
elif message["role"] == "tool":
if turn_idx == 0 or messages[turn_idx - 1]["role"] != "tool":
temp_str += "<|im_start|>user"
temp_str += "\n<tool_response>\n"
for content in message["content"]:
if content["type"] == "text":
temp_str += content["value"]
else:
raise ValueError(f"Unsupported content type: {content['type']}")
temp_str += "\n</tool_response>"
if turn_idx == len(messages) - 1 or messages[turn_idx + 1]["role"] != "tool":
temp_str += "<|im_end|>\n"
temp_weight = message.get("loss_weight", 0.0)
temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight)
if is_generate:
temp_str += "<|im_start|>assistant\n"
temp_weight = 0.0
temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight)
attention_mask = [1] * len(input_ids)
return ModelInput(
input_ids=input_ids,
attention_mask=attention_mask,
labels=labels,
loss_weights=loss_weights,
)
def parse_chatml_message(generated_text: str) -> Message:
"""Parse a message in ChatML format. Supports interleaved reasoning and tool calls.
Args:
generated_text (str): The generated text in ChatML format.
Returns:
Message: The parsed message.
"""
pattern = re.compile(r"<(thinking|tool_call)>\s*(.*?)\s*</\1>\s*", re.DOTALL)
content = []
last_end = 0
for match in pattern.finditer(generated_text):
start, end = match.span()
if start > last_end:
text = generated_text[last_end:start].strip()
if text:
content.append({"type": "text", "value": text})
tag_type = match.group(1)
tag_value = match.group(2).strip()
if tag_type == "thinking":
content.append({"type": "reasoning", "value": tag_value.strip()})
elif tag_type == "tool_call":
try:
json.loads(tag_value.strip())
except json.JSONDecodeError:
raise ValueError(f"Invalid tool call format: {tag_value.strip()}.")
content.append({"type": "tool_call", "value": tag_value.strip()})
last_end = end
if last_end < len(generated_text):
text = generated_text[last_end:].strip()
if text:
content.append({"type": "text", "value": text})
return Message(role="assistant", content=content)
class Renderer:
def __init__(self, template: str, processor: Processor):
self.template = template
self.processor = processor
def render_messages(
self, messages: list[Message], tools: str | None = None, is_generate: bool = False
) -> ModelInput:
if self.template == "chatml":
return render_chatml_messages(self.processor, messages, tools, is_generate)
else:
from ...plugins.model_plugins.rendering import RenderingPlugin
return RenderingPlugin(self.template).render_messages(self.processor, messages, tools, is_generate)
def parse_message(self, generated_text: str) -> Message:
if self.template == "chatml":
return parse_chatml_message(generated_text)
else:
from ...plugins.model_plugins.rendering import RenderingPlugin
return RenderingPlugin(self.template).parse_message(generated_text)

View File

@@ -49,6 +49,11 @@ def launch():
run_sft() run_sft()
elif command == "chat":
from .samplers.cli_sampler import run_chat
run_chat()
elif command == "env": elif command == "env":
print_env() print_env()

View File

@@ -13,11 +13,12 @@
# limitations under the License. # limitations under the License.
import json
from typing import Any, Literal, NotRequired, TypedDict from typing import Any, Literal, NotRequired, TypedDict
from ...utils import logging from ...utils import logging
from ...utils.plugin import BasePlugin from ...utils.plugin import BasePlugin
from ...utils.types import DPOSample, Sample, SFTSample from ...utils.types import DPOSample, Sample, SFTSample, ToolCall
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
@@ -61,7 +62,7 @@ class DataConverterPlugin(BasePlugin):
return super().__call__(raw_sample) return super().__call__(raw_sample)
@DataConverterPlugin("alpaca").register @DataConverterPlugin("alpaca").register()
def alpaca_converter(raw_sample: AlpacaSample) -> SFTSample: def alpaca_converter(raw_sample: AlpacaSample) -> SFTSample:
"""Convert Alpaca sample to SFT sample. """Convert Alpaca sample to SFT sample.
@@ -98,7 +99,7 @@ def alpaca_converter(raw_sample: AlpacaSample) -> SFTSample:
return {"messages": messages} return {"messages": messages}
@DataConverterPlugin("sharegpt").register @DataConverterPlugin("sharegpt").register()
def sharegpt_converter(raw_sample: SharegptSample) -> SFTSample: def sharegpt_converter(raw_sample: SharegptSample) -> SFTSample:
"""Convert ShareGPT sample to SFT sample. """Convert ShareGPT sample to SFT sample.
@@ -118,17 +119,32 @@ def sharegpt_converter(raw_sample: SharegptSample) -> SFTSample:
"function_call": "assistant", "function_call": "assistant",
} }
messages = [] messages = []
tools = raw_sample.get("tools", "") tools = raw_sample.get("tools")
if tools:
try:
tools: list[dict[str, Any]] = json.loads(tools)
except json.JSONDecodeError:
logger.warning_rank0(f"Invalid tools format: {str(tools)}")
tools = []
for message in raw_sample.get("conversations", []): for message in raw_sample.get("conversations", []):
tag = message["from"] tag = message["from"]
if tag not in tag_mapping: if tag not in tag_mapping:
logger.warning_rank0(f"Unsupported role tag {tag} in message: {message}") logger.warning_rank0(f"Unsupported role tag {tag} in message: {message}")
elif tag == "function_call": elif tag == "function_call":
try:
tool_calls: ToolCall | list[ToolCall] = json.loads(message["value"])
except json.JSONDecodeError:
logger.warning_rank0(f"Invalid tool call format: {str(message['value'])}")
continue
if not isinstance(tool_calls, list):
tool_calls = [tool_calls]
messages.append( messages.append(
{ {
"role": "assistant", "role": "assistant",
"content": [{"type": "tool_calls", "value": message["value"]}], "content": [{"type": "tool_call", "value": json.dumps(tool_call)} for tool_call in tool_calls],
"loss_weight": 1.0, "loss_weight": 1.0,
} }
) )
@@ -142,15 +158,12 @@ def sharegpt_converter(raw_sample: SharegptSample) -> SFTSample:
) )
if tools: if tools:
if messages and messages[0]["role"] == "system": return {"messages": messages, "extra_info": json.dumps({"tools": tools})}
messages[0]["content"].append({"type": "tools", "value": tools}) else:
else: return {"messages": messages}
messages.insert(0, {"role": "system", "content": [{"type": "tools", "value": tools}], "loss_weight": 0.0})
return {"messages": messages}
@DataConverterPlugin("pair").register @DataConverterPlugin("pair").register()
def pair_converter(raw_sample: PairSample) -> DPOSample: def pair_converter(raw_sample: PairSample) -> DPOSample:
"""Convert Pair sample to DPO sample. """Convert Pair sample to DPO sample.

View File

@@ -49,7 +49,7 @@ def _get_builder_name(path: str) -> Literal["arrow", "csv", "json", "parquet", "
raise ValueError(f"Unknown dataset filetype: {filetype}.") raise ValueError(f"Unknown dataset filetype: {filetype}.")
@DataLoaderPlugin("local").register @DataLoaderPlugin("local").register()
def load_data_from_file(filepath: str, split: str, streaming: bool) -> HFDataset: def load_data_from_file(filepath: str, split: str, streaming: bool) -> HFDataset:
if os.path.isdir(filepath): if os.path.isdir(filepath):
filetype = _get_builder_name(os.listdir(filepath)[0]) filetype = _get_builder_name(os.listdir(filepath)[0])
@@ -66,49 +66,43 @@ def load_data_from_file(filepath: str, split: str, streaming: bool) -> HFDataset
return dataset return dataset
class DataIndexPlugin(BasePlugin): def adjust_data_index(
"""Plugin for adjusting dataset index.""" data_index: list[tuple[str, int]], size: int | None, weight: float | None
) -> list[tuple[str, int]]:
"""Adjust dataset index by size and weight.
def adjust_data_index( Args:
self, data_index: list[tuple[str, int]], size: int | None, weight: float | None data_index (list[tuple[str, int]]): List of (dataset_name, sample_index).
) -> list[tuple[str, int]]: size (Optional[int]): Desired dataset size.
"""Adjust dataset index by size and weight. weight (Optional[float]): Desired dataset weight.
Args: Returns:
data_index (list[tuple[str, int]]): List of (dataset_name, sample_index). list[tuple[str, int]]: Adjusted dataset index.
size (Optional[int]): Desired dataset size. """
weight (Optional[float]): Desired dataset weight. if size is not None:
data_index = random.choices(data_index, k=size)
Returns: if weight is not None:
list[tuple[str, int]]: Adjusted dataset index. data_index = random.choices(data_index, k=int(len(data_index) * weight))
"""
if size is not None:
data_index = random.choices(data_index, k=size)
if weight is not None: return data_index
data_index = random.choices(data_index, k=int(len(data_index) * weight))
return data_index
class DataSelectorPlugin(BasePlugin): def select_data_sample(
"""Plugin for selecting dataset samples.""" data_index: list[tuple[str, int]], index: slice | list[int] | Any
) -> tuple[str, int] | list[tuple[str, int]]:
"""Select dataset samples.
def select( Args:
self, data_index: list[tuple[str, int]], index: slice | list[int] | Any data_index (list[tuple[str, int]]): List of (dataset_name, sample_index).
) -> tuple[str, int] | list[tuple[str, int]]: index (Union[slice, list[int], Any]): Index of dataset samples.
"""Select dataset samples.
Args: Returns:
data_index (list[tuple[str, int]]): List of (dataset_name, sample_index). Union[tuple[str, int], list[tuple[str, int]]]: Selected dataset samples.
index (Union[slice, list[int], Any]): Index of dataset samples. """
if isinstance(index, slice):
Returns: return [data_index[i] for i in range(*index.indices(len(data_index)))]
Union[tuple[str, int], list[tuple[str, int]]]: Selected dataset samples. elif isinstance(index, list):
""" return [data_index[i] for i in index]
if isinstance(index, slice): else:
return [data_index[i] for i in range(*index.indices(len(data_index)))] raise ValueError(f"Invalid index type {type(index)}.")
elif isinstance(index, list):
return [data_index[i] for i in index]
else:
raise ValueError(f"Invalid index type {type(index)}.")

View File

@@ -1,133 +0,0 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
@dataclass
class Template:
user_template: str
assistant_template: str
system_template: str
def render_message(self, message: dict[str, str]) -> str:
return self.user_template.format(**message)
@dataclass
class QwenTemplate:
message_template: str = "<|im_start|>{role}\n{content}<|im_end|>\n" # FIXME if role: tool
thinking_template: str = "<think>\n{content}\n</think>\n\n"
def _extract_content(self, content_data: str | list[dict[str, str]]) -> str:
if isinstance(content_data, str):
return content_data.strip()
if isinstance(content_data, list):
parts = []
for item in content_data:
if item.get("type") == "text":
parts.append(item.get("value", ""))
elif item.get("type") == "image_url":
pass
return "\n".join(parts).strip()
return ""
def render_message(self, message: dict[str, str | list[dict[str, str]]]) -> str:
role = message["role"]
content = self._extract_content(message.get("content", ""))
if role == "assistant":
reasoning_content = message.get("reasoning_content", "")
if reasoning_content:
reasoning_content = self.thinking_template.format(content=str(reasoning_content).strip())
return self.message_template.format(role="assistant", content=reasoning_content + content)
else:
return self.message_template.format(role=role, content=content)
def encode_messages(self, tokenizer, messages: list[dict[str, str]], max_seq_len: int = 8192) -> any:
"""Encode one message."""
input_ids, attention_mask, labels = [], [], []
for message in messages:
content_str = self.render_message(message)
content_ids = tokenizer.encode(content_str, add_special_tokens=False)
input_ids += content_ids
attention_mask += [1] * len(content_ids)
if hasattr(message, "loss_weight"):
loss_weight = message["loss_weight"]
else:
loss_weight = 1 if message["role"] == "assistant" else 0
if loss_weight == 1:
labels += content_ids
else:
labels += [-100] * len(content_ids)
model_inputs = {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}
model_inputs.update({"position_ids": list(range(len(input_ids)))})
model_inputs = {k: v[-max_seq_len:] for k, v in model_inputs.items()}
return model_inputs
if __name__ == "__main__":
def to_qwen3_messages(template: QwenTemplate, messages: list[dict]):
out = []
for m in messages:
role = m["role"]
content = template._extract_content(m.get("content", ""))
if role == "assistant":
reasoning = (m.get("reasoning_content") or "").strip()
if reasoning:
content = template.thinking_template.format(content=reasoning) + content
out.append({"role": role, "content": content})
return out
from transformers import AutoTokenizer
tok = AutoTokenizer.from_pretrained(
"Qwen/Qwen3-30B-A3B-Thinking-2507",
trust_remote_code=True,
)
test_messages = [
{"role": "system", "content": "You are a helpful assistant."},
{
"role": "user",
"content": [{"type": "text", "text": "1+1等于几"}, {"type": "text", "text": "2+2等于几"}],
},
{
"role": "assistant",
"reasoning_content": "这是一个简单的数学问题。1加1的结果是2。",
"content": [{"type": "text", "text": "1+1=2"}, {"type": "text", "text": "2+2=4"}],
},
]
template = QwenTemplate()
rendered_custom = "".join([template.render_message(m) for m in test_messages])
qwen3_messages = to_qwen3_messages(template, test_messages)
rendered_hf = tok.apply_chat_template(qwen3_messages, tokenize=False, add_generation_prompt=False)
print("==== custom ====")
print(rendered_custom)
print("==== hf ====")
print(rendered_hf)
assert rendered_custom.strip() == rendered_hf.strip(), "Rendered text mismatch"
ids_custom = tok.encode(rendered_custom, add_special_tokens=False)
ids_hf = tok.apply_chat_template(qwen3_messages, tokenize=True, add_generation_prompt=False)
assert ids_custom == ids_hf, f"Token ids mismatch: custom={len(ids_custom)} hf={len(ids_hf)}"

View File

@@ -25,12 +25,12 @@ class InitPlugin(BasePlugin):
return super().__call__() return super().__call__()
@InitPlugin("init_on_meta").register @InitPlugin("init_on_meta").register()
def init_on_meta() -> torch.device: def init_on_meta() -> torch.device:
return torch.device(DeviceType.META.value) return torch.device(DeviceType.META.value)
@InitPlugin("init_on_rank0").register @InitPlugin("init_on_rank0").register()
def init_on_rank0() -> torch.device: def init_on_rank0() -> torch.device:
if DistributedInterface().get_rank() == 0: if DistributedInterface().get_rank() == 0:
return torch.device(DeviceType.CPU.value) return torch.device(DeviceType.CPU.value)
@@ -38,6 +38,6 @@ def init_on_rank0() -> torch.device:
return torch.device(DeviceType.META.value) return torch.device(DeviceType.META.value)
@InitPlugin("init_on_default").register @InitPlugin("init_on_default").register()
def init_on_default() -> torch.device: def init_on_default() -> torch.device:
return DistributedInterface().current_accelerator return DistributedInterface().current_device

View File

@@ -38,17 +38,17 @@ class BaseKernel(ABC):
@classmethod @classmethod
def get_kernel_id(cls) -> str: def get_kernel_id(cls) -> str:
r"""Returns the unique identifier for the kernel.""" """Returns the unique identifier for the kernel."""
return cls._kernel_id return cls._kernel_id
@classmethod @classmethod
def get_device(cls) -> str: def get_device(cls) -> str:
r"""Returns the device type associated with the kernel (e.g., "cuda", "npu", "cpu").""" """Returns the device type associated with the kernel (e.g., "cuda", "npu", "cpu")."""
return cls._device return cls._device
@classmethod @classmethod
def check_deps(cls) -> bool: def check_deps(cls) -> bool:
r"""Checks if the required dependencies for the kernel are available. """Checks if the required dependencies for the kernel are available.
Returns: Returns:
bool: ``True`` if dependencies are met, ``False`` otherwise. bool: ``True`` if dependencies are met, ``False`` otherwise.
@@ -65,7 +65,7 @@ class BaseKernel(ABC):
@classmethod @classmethod
@abstractmethod @abstractmethod
def apply(cls, **kwargs) -> HFModel: def apply(cls, **kwargs) -> HFModel:
r"""Applies the kernel optimization to the model. """Applies the kernel optimization to the model.
Args: Args:
**kwargs: Arbitrary keyword arguments, usually containing the model instance and the kernel configuration. **kwargs: Arbitrary keyword arguments, usually containing the model instance and the kernel configuration.

View File

@@ -33,7 +33,7 @@ logger = get_logger(__name__)
def scan_all_kernels(): def scan_all_kernels():
r"""Scan all kernels in the ``ops`` directory. """Scan all kernels in the ``ops`` directory.
Scans the ``ops`` directory for all ``.py`` files and attempts to import them. Scans the ``ops`` directory for all ``.py`` files and attempts to import them.
Importing triggers the :func:`~registry.register_kernel` decorator, which automatically registers the kernels. Importing triggers the :func:`~registry.register_kernel` decorator, which automatically registers the kernels.
@@ -77,7 +77,7 @@ default_kernels = scan_all_kernels()
def get_default_kernels(): def get_default_kernels():
r"""Get a list of default registered kernel IDs. """Get a list of default registered kernel IDs.
Returns: Returns:
list[str]: List of kernel IDs. list[str]: List of kernel IDs.
@@ -86,7 +86,7 @@ def get_default_kernels():
def apply_kernel(kernel_id: str, **kwargs): def apply_kernel(kernel_id: str, **kwargs):
r"""Applies a specific kernel to the model. """Applies a specific kernel to the model.
Args: Args:
kernel_id (str): The ID of the kernel to apply. kernel_id (str): The ID of the kernel to apply.
@@ -99,18 +99,19 @@ def apply_kernel(kernel_id: str, **kwargs):
kernel = default_kernels.get(kernel_id) kernel = default_kernels.get(kernel_id)
if kernel is None: if kernel is None:
raise ValueError(f"Kernel {kernel_id} not found") raise ValueError(f"Kernel {kernel_id} not found")
kernel.apply(**kwargs) kernel.apply(**kwargs)
class KernelPlugin(BasePlugin): class KernelPlugin(BasePlugin):
r"""Plugin for managing kernel optimizations.""" """Plugin for managing kernel optimizations."""
pass pass
@KernelPlugin("auto").register @KernelPlugin("auto").register()
def apply_default_kernels(**kwargs): def apply_default_kernels(**kwargs):
r"""Applies all default registered kernels to the model. """Applies all default registered kernels to the model.
Args: Args:
**kwargs: Keyword arguments passed to the kernel application function. **kwargs: Keyword arguments passed to the kernel application function.
@@ -125,8 +126,11 @@ def apply_default_kernels(**kwargs):
use_kernels = default_kernels.keys() use_kernels = default_kernels.keys()
else: else:
use_kernels = kwargs.get("include_kernels").split(",") # "kernel_id1,kernel_id2,kernel_id3" use_kernels = kwargs.get("include_kernels").split(",") # "kernel_id1,kernel_id2,kernel_id3"
for kernel in use_kernels: for kernel in use_kernels:
if kernel not in default_kernels: if kernel not in default_kernels:
raise ValueError(f"Kernel {kernel} not found") raise ValueError(f"Kernel {kernel} not found")
apply_kernel(kernel, **kwargs) apply_kernel(kernel, **kwargs)
return kwargs.get("model") return kwargs.get("model")

View File

@@ -40,11 +40,11 @@ from ...registry import register_kernel
class GmmFunction(torch.autograd.Function): class GmmFunction(torch.autograd.Function):
r"""Custom autograd function for NPU Grouped Matrix Multiplication (GMM).""" """Custom autograd function for NPU Grouped Matrix Multiplication (GMM)."""
@staticmethod @staticmethod
def forward(ctx, x, weight, group_list): def forward(ctx, x, weight, group_list):
r"""Performs the forward pass of Grouped Matrix Multiplication. """Performs the forward pass of Grouped Matrix Multiplication.
Args: Args:
ctx: Context object to save tensors for backward pass. ctx: Context object to save tensors for backward pass.
@@ -65,7 +65,7 @@ class GmmFunction(torch.autograd.Function):
@staticmethod @staticmethod
def backward(ctx, grad_output): def backward(ctx, grad_output):
r"""Performs the backward pass of Grouped Matrix Multiplication. """Performs the backward pass of Grouped Matrix Multiplication.
Args: Args:
ctx: Context object containing saved tensors. ctx: Context object containing saved tensors.
@@ -94,11 +94,11 @@ class GmmFunction(torch.autograd.Function):
class HybridGmmFunction(torch.autograd.Function): class HybridGmmFunction(torch.autograd.Function):
r"""Custom autograd function for Hybrid Grouped Matrix Multiplication on NPU.""" """Custom autograd function for Hybrid Grouped Matrix Multiplication on NPU."""
@staticmethod @staticmethod
def forward(ctx, num_experts, *args): def forward(ctx, num_experts, *args):
r"""Performs the forward pass of Hybrid GMM. """Performs the forward pass of Hybrid GMM.
Args: Args:
ctx: Context object to save tensors. ctx: Context object to save tensors.
@@ -124,7 +124,7 @@ class HybridGmmFunction(torch.autograd.Function):
@staticmethod @staticmethod
def backward(ctx, *grad_outputs): def backward(ctx, *grad_outputs):
r"""Performs the backward pass of Hybrid GMM. """Performs the backward pass of Hybrid GMM.
Args: Args:
ctx: Context object containing saved tensors. ctx: Context object containing saved tensors.
@@ -176,13 +176,13 @@ class HybridGmmFunction(torch.autograd.Function):
class NpuMoeFused: class NpuMoeFused:
r"""Container for NPU fused MoE forward functions.""" """Container for NPU fused MoE forward functions."""
@staticmethod @staticmethod
def npu_moe_experts_forward( def npu_moe_experts_forward(
self, hidden_states: torch.Tensor, routing_weights: torch.Tensor, router_indices: torch.Tensor self, hidden_states: torch.Tensor, routing_weights: torch.Tensor, router_indices: torch.Tensor
) -> torch.Tensor: ) -> torch.Tensor:
r"""Forward pass for MoE experts using NPU fused operations. """Forward pass for MoE experts using NPU fused operations.
Args: Args:
self: The MoE layer instance. self: The MoE layer instance.
@@ -230,11 +230,11 @@ class NpuMoeFused:
class Qwen3NpuMoeFused: class Qwen3NpuMoeFused:
r"""Container for Qwen3 NPU fused MoE forward functions.""" """Container for Qwen3 NPU fused MoE forward functions."""
@staticmethod @staticmethod
def qwen3moe_sparse_moe_block_forward(self, hidden_states: torch.Tensor): def qwen3moe_sparse_moe_block_forward(self, hidden_states: torch.Tensor):
r"""Forward pass for Qwen3 sparse MoE block using NPU fused operations. """Forward pass for Qwen3 sparse MoE block using NPU fused operations.
Args: Args:
self: The Qwen3 MoE block instance. self: The Qwen3 MoE block instance.
@@ -298,14 +298,14 @@ if not is_transformers_version_greater_than("5.0.0"):
@register_kernel @register_kernel
class NpuFusedMoEKernel(BaseKernel): class NpuFusedMoEKernel(BaseKernel):
r"""NPU Fused MoE Kernel implementation.""" """NPU Fused MoE Kernel implementation."""
_kernel_id = "npu_fused_moe" _kernel_id = "npu_fused_moe"
_device = DeviceType.NPU _device = DeviceType.NPU
@classmethod @classmethod
def apply(cls, **kwargs) -> HFModel: def apply(cls, **kwargs) -> HFModel:
r"""Applies the NPU fused MoE kernel to the model. """Applies the NPU fused MoE kernel to the model.
Args: Args:
**kwargs: Keyword arguments containing the model. **kwargs: Keyword arguments containing the model.
@@ -333,6 +333,7 @@ class NpuFusedMoEKernel(BaseKernel):
if target_moe_mapping is None: if target_moe_mapping is None:
return model return model
for module in model.modules(): for module in model.modules():
class_name = module.__class__.__name__ class_name = module.__class__.__name__
if class_name in target_moe_mapping: if class_name in target_moe_mapping:

View File

@@ -38,7 +38,7 @@ except ImportError:
def npu_swiglu_forward(self, hidden_state): def npu_swiglu_forward(self, hidden_state):
r"""SwiGLU forward pass for NPU. """SwiGLU forward pass for NPU.
Args: Args:
self: The MLP layer instance. self: The MLP layer instance.
@@ -53,7 +53,7 @@ def npu_swiglu_forward(self, hidden_state):
def _npu_swiglu_glm4_forward(self, hidden_states): def _npu_swiglu_glm4_forward(self, hidden_states):
r"""SwiGLU forward pass for GLM4 on NPU. """SwiGLU forward pass for GLM4 on NPU.
Args: Args:
self: The GLM4 MLP layer instance. self: The GLM4 MLP layer instance.
@@ -68,7 +68,7 @@ def _npu_swiglu_glm4_forward(self, hidden_states):
def _npu_swiglu_gemma3ntext_forward(self, hidden_states): def _npu_swiglu_gemma3ntext_forward(self, hidden_states):
r"""SwiGLU forward pass for Gemma3nText on NPU. """SwiGLU forward pass for Gemma3nText on NPU.
Args: Args:
self: The Gemma3nText MLP layer instance. self: The Gemma3nText MLP layer instance.
@@ -88,7 +88,7 @@ def _npu_swiglu_gemma3ntext_forward(self, hidden_states):
@register_kernel @register_kernel
class NpuSwiGluKernel(BaseKernel): class NpuSwiGluKernel(BaseKernel):
r"""NPU Kernel for fused SwiGLU activation.""" """NPU Kernel for fused SwiGLU activation."""
# just support apply to the following module layers # just support apply to the following module layers
expect_modules = frozenset( expect_modules = frozenset(
@@ -126,7 +126,7 @@ class NpuSwiGluKernel(BaseKernel):
@classmethod @classmethod
def apply(cls, **kwargs) -> "HFModel": def apply(cls, **kwargs) -> "HFModel":
r"""Applies the NPU fused SwiGLU kernel to the model. """Applies the NPU fused SwiGLU kernel to the model.
Args: Args:
**kwargs: Keyword arguments containing the model. **kwargs: Keyword arguments containing the model.

View File

@@ -30,7 +30,7 @@ from ...registry import register_kernel
def npu_rms_norm_forward(self, hidden_states): def npu_rms_norm_forward(self, hidden_states):
r"""NPU forward implementation for RMSNorm. """NPU forward implementation for RMSNorm.
Args: Args:
self: RMSNorm module instance with `weight` and `variance_epsilon`. self: RMSNorm module instance with `weight` and `variance_epsilon`.
@@ -46,14 +46,14 @@ def npu_rms_norm_forward(self, hidden_states):
@register_kernel @register_kernel
class NpuRMSNormKernel(BaseKernel): class NpuRMSNormKernel(BaseKernel):
r"""NPU kernel wrapper for RMSNorm that applies the replacement within a model.""" """NPU kernel wrapper for RMSNorm that applies the replacement within a model."""
_kernel_id = "npu_fused_rmsnorm" _kernel_id = "npu_fused_rmsnorm"
_device = DeviceType.NPU _device = DeviceType.NPU
@classmethod @classmethod
def apply(cls, **kwargs) -> "HFModel": def apply(cls, **kwargs) -> "HFModel":
r"""Iterate the model and apply NPU-optimized forward to matched RMSNorm modules. """Iterate the model and apply NPU-optimized forward to matched RMSNorm modules.
Key points: Key points:
- Match modules whose class name contains "RMSNorm" (case-insensitive). - Match modules whose class name contains "RMSNorm" (case-insensitive).
@@ -78,6 +78,7 @@ class NpuRMSNormKernel(BaseKernel):
if not cls.check_deps(): if not cls.check_deps():
raise RuntimeError(f"torch_npu is not available but {cls.__name__} was called.") raise RuntimeError(f"torch_npu is not available but {cls.__name__} was called.")
rms_norm_pattern = re.compile("RMSNorm", re.IGNORECASE) rms_norm_pattern = re.compile("RMSNorm", re.IGNORECASE)
for name, module in model.named_modules(): for name, module in model.named_modules():

View File

@@ -40,7 +40,7 @@ except ImportError:
def _apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): def _apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
r"""Applies Rotary Position Embedding to the query and key tensors using NPU optimization. """Applies Rotary Position Embedding to the query and key tensors using NPU optimization.
Args: Args:
q (Tensor): Query tensor. q (Tensor): Query tensor.
@@ -61,7 +61,7 @@ def _apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
def _apply_multimodal_rotary_pos_emb_qwen25_vl(q, k, cos, sin, mrope_section, unsqueeze_dim=1): def _apply_multimodal_rotary_pos_emb_qwen25_vl(q, k, cos, sin, mrope_section, unsqueeze_dim=1):
r"""Applies Rotary Position Embedding with multimodal sections (Qwen2-VL) on NPU. """Applies Rotary Position Embedding with multimodal sections (Qwen2-VL) on NPU.
Args: Args:
q (Tensor): Query tensor. q (Tensor): Query tensor.
@@ -89,14 +89,14 @@ def _apply_multimodal_rotary_pos_emb_qwen25_vl(q, k, cos, sin, mrope_section, un
@register_kernel @register_kernel
class NpuRoPEKernel(BaseKernel): class NpuRoPEKernel(BaseKernel):
r"""NPU Kernel for Rotary Position Embedding.""" """NPU Kernel for Rotary Position Embedding."""
_kernel_id = "npu_fused_rope" _kernel_id = "npu_fused_rope"
_device = DeviceType.NPU _device = DeviceType.NPU
@classmethod @classmethod
def apply(cls, **kwargs) -> "HFModel": def apply(cls, **kwargs) -> "HFModel":
r"""Apply RoPE acceleration by monkey-patching `apply_rotary_pos_emb`. """Apply RoPE acceleration by monkey-patching `apply_rotary_pos_emb`.
This function iterates through the model's modules to find attention layers, This function iterates through the model's modules to find attention layers,
identifies the module where they are defined, and replaces the original identifies the module where they are defined, and replaces the original
@@ -115,9 +115,11 @@ class NpuRoPEKernel(BaseKernel):
""" """
if not cls.check_deps(): if not cls.check_deps():
raise RuntimeError(f"torch_npu is not available but {cls.__name__} was called.") raise RuntimeError(f"torch_npu is not available but {cls.__name__} was called.")
model = kwargs.get("model", None) model = kwargs.get("model", None)
if model is None: if model is None:
raise ValueError(f"HFModel instance is required for {cls.__name__}.") raise ValueError(f"HFModel instance is required for {cls.__name__}.")
_modules = set() _modules = set()
for module in model.modules(): for module in model.modules():
if "Attention" in module.__class__.__name__: if "Attention" in module.__class__.__name__:
@@ -143,4 +145,5 @@ class NpuRoPEKernel(BaseKernel):
_modules.add(module_name) _modules.add(module_name)
except Exception as e: except Exception as e:
logger.warning_rank0_once(f"Failed to apply RoPE kernel to module {module_name}: {e}") logger.warning_rank0_once(f"Failed to apply RoPE kernel to module {module_name}: {e}")
return model return model

View File

@@ -30,7 +30,7 @@ __all__ = ["Registry", "register_kernel"]
class Registry: class Registry:
r"""Registry for managing kernel implementations. """Registry for managing kernel implementations.
Storage structure: ``{ "kernel_id": Class }`` Storage structure: ``{ "kernel_id": Class }``
""" """
@@ -38,8 +38,8 @@ class Registry:
_kernels: dict[str, type[BaseKernel]] = {} _kernels: dict[str, type[BaseKernel]] = {}
@classmethod @classmethod
def register(cls, kernel_cls: type[BaseKernel]): def register(cls, kernel_cls: type[BaseKernel]) -> type[BaseKernel] | None:
r"""Decorator to register a kernel class. """Decorator to register a kernel class.
The class must inherit from :class:`BaseKernel` and specify ``_kernel_id`` and ``_device`` attributes. The class must inherit from :class:`BaseKernel` and specify ``_kernel_id`` and ``_device`` attributes.
@@ -47,7 +47,7 @@ class Registry:
kernel_cls (type[BaseKernel]): The kernel class to register. kernel_cls (type[BaseKernel]): The kernel class to register.
Returns: Returns:
type[BaseKernel]: The registered kernel class. type[BaseKernel] | None: The registered kernel class if the device type matches the current accelerator
Raises: Raises:
TypeError: If the class does not inherit from :class:`BaseKernel`. TypeError: If the class does not inherit from :class:`BaseKernel`.
@@ -55,6 +55,7 @@ class Registry:
""" """
if not issubclass(kernel_cls, BaseKernel): if not issubclass(kernel_cls, BaseKernel):
raise TypeError(f"Class {kernel_cls} must inherit from BaseKernel") raise TypeError(f"Class {kernel_cls} must inherit from BaseKernel")
kernel_id = kernel_cls.get_kernel_id() kernel_id = kernel_cls.get_kernel_id()
device = kernel_cls.get_device() device = kernel_cls.get_device()
@@ -73,7 +74,7 @@ class Registry:
@classmethod @classmethod
def get(cls, kernel_id: str) -> Optional[type[BaseKernel]]: def get(cls, kernel_id: str) -> Optional[type[BaseKernel]]:
r"""Retrieves a registered kernel implementation by its ID. """Retrieves a registered kernel implementation by its ID.
Args: Args:
kernel_id (str): The ID of the kernel to retrieve. kernel_id (str): The ID of the kernel to retrieve.
@@ -85,7 +86,7 @@ class Registry:
@classmethod @classmethod
def get_registered_kernels(cls) -> dict[str, type[BaseKernel]]: def get_registered_kernels(cls) -> dict[str, type[BaseKernel]]:
r"""Returns a dictionary of all registered kernels. """Returns a dictionary of all registered kernels.
Returns: Returns:
dict[str, type[BaseKernel]]: Dictionary mapping kernel IDs to kernel classes. dict[str, type[BaseKernel]]: Dictionary mapping kernel IDs to kernel classes.

View File

@@ -45,13 +45,13 @@ class PeftPlugin(BasePlugin):
return super().__call__(model, config) return super().__call__(model, config)
@PeftPlugin("lora").register @PeftPlugin("lora").register()
def get_lora_model(model: HFModel, config: LoraConfigDict, is_train: bool) -> PeftModel: def get_lora_model(model: HFModel, config: LoraConfigDict, is_train: bool) -> PeftModel:
peft_config = LoraConfig(**config) peft_config = LoraConfig(**config)
model = get_peft_model(model, peft_config) model = get_peft_model(model, peft_config)
return model return model
@PeftPlugin("freeze").register @PeftPlugin("freeze").register()
def get_freeze_model(model: HFModel, config: FreezeConfigDict, is_train: bool) -> HFModel: def get_freeze_model(model: HFModel, config: FreezeConfigDict, is_train: bool) -> HFModel:
raise NotImplementedError() raise NotImplementedError()

View File

@@ -0,0 +1,36 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ...utils.plugin import BasePlugin
from ...utils.types import Message, ModelInput, Processor
class RenderingPlugin(BasePlugin):
pass
@RenderingPlugin("qwen").register("render_messages")
def render_qwen_messages(
processor: Processor,
messages: list[Message],
tools: str | None = None,
is_generate: bool = False,
) -> ModelInput:
raise NotImplementedError()
@RenderingPlugin("qwen").register("parse_message")
def parse_qwen_message(generated_text: str) -> Message:
raise NotImplementedError()

View File

@@ -12,10 +12,64 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import asyncio
import os
from collections.abc import Generator
from threading import Thread
from ..config import InputArgument, SampleBackend, get_args from ..config import InputArgument, ModelArguments, SampleArguments, SampleBackend, get_args
from ..core.base_sampler import BaseSampler from ..core.base_sampler import BaseSampler
from ..core.model_loader import ModelLoader from ..core.data_engine import DataEngine
from ..core.model_engine import ModelEngine
from ..core.utils.rendering import Renderer
from ..utils.types import HFModel, Message, Sample, TorchDataset
class SyncSampler(BaseSampler):
def __init__(
self,
args: SampleArguments,
model_args: ModelArguments,
model: HFModel,
renderer: Renderer,
) -> None:
def _start_background_loop(loop: asyncio.AbstractEventLoop) -> None:
asyncio.set_event_loop(loop)
loop.run_forever()
super().__init__(args, model_args, model, renderer)
self._loop = asyncio.new_event_loop()
self._thread = Thread(target=_start_background_loop, args=(self._loop,), daemon=True)
self._thread.start()
def generate(self, messages: list[Message], tools: str | None = None) -> Generator[str, None, None]:
"""Generate tokens synchronously.
Args:
messages: List of messages.
tools: Tools string.
Yields:
Generated tokens.
"""
generator = super().generate(messages, tools)
while True:
try:
token = asyncio.run_coroutine_threadsafe(generator.__anext__(), self._loop).result()
yield token
except StopAsyncIteration:
break
def batch_infer(self, dataset: TorchDataset) -> list[Sample]:
"""Batch infer samples synchronously.
Args:
dataset: Torch dataset.
Returns:
List of samples.
"""
return asyncio.run_coroutine_threadsafe(super().batch_infer(dataset), self._loop).result()
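The class above drives an async generator from synchronous code by parking an event loop on a daemon thread and stepping it with `run_coroutine_threadsafe`. A self-contained sketch of that bridge pattern, independent of the LlamaFactory classes (the token stream below is fake):

```python
# Sync-over-async bridge: a daemon thread runs the event loop, and the caller
# pulls items out of an async generator one step at a time.
import asyncio
from threading import Thread

def start_background_loop(loop: asyncio.AbstractEventLoop) -> None:
    asyncio.set_event_loop(loop)
    loop.run_forever()

loop = asyncio.new_event_loop()
Thread(target=start_background_loop, args=(loop,), daemon=True).start()

async def fake_token_stream():
    for token in ["This ", "is ", "a ", "test."]:
        await asyncio.sleep(0)  # stand-in for real async generation
        yield token

agen = fake_token_stream()
while True:
    try:
        # schedule one generator step on the background loop and block for it
        print(asyncio.run_coroutine_threadsafe(agen.__anext__(), loop).result(), end="")
    except StopAsyncIteration:
        break
print()
```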
def run_chat(args: InputArgument = None): def run_chat(args: InputArgument = None):
@@ -23,12 +77,48 @@ def run_chat(args: InputArgument = None):
if sample_args.sample_backend != SampleBackend.HF: if sample_args.sample_backend != SampleBackend.HF:
model_args.init_plugin = {"name": "init_on_meta"} model_args.init_plugin = {"name": "init_on_meta"}
model_loader = ModelLoader(model_args) model_engine = ModelEngine(model_args)
sampler = BaseSampler(sample_args, model_args, model_loader.model, model_loader.processor) sampler = SyncSampler(sample_args, model_args, model_engine.model, model_engine.renderer)
if data_args.dataset is not None: if data_args.dataset is not None:
sampler.batch_infer() dataset = DataEngine(data_args)
sampler.batch_infer(dataset)
else: else:
sampler.generate() if os.name != "nt":
try:
import readline # noqa: F401
except ImportError:
print("Install `readline` for a better experience.")
messages = []
print("Welcome to the CLI application, use `clear` to remove the history, use `exit` to exit the application.")
while True:
try:
query = input("\nUser: ")
except UnicodeDecodeError:
print("Detected decoding error at the inputs, please set the terminal encoding to utf-8.")
continue
except Exception:
raise
if query.strip() == "exit":
break
if query.strip() == "clear":
messages = []
print("History has been removed.")
continue
messages.append({"role": "user", "content": [{"type": "text", "value": query}]})
print("Assistant: ", end="", flush=True)
response = ""
for new_text in sampler.generate(messages):
print(new_text, end="", flush=True)
response += new_text
print()
messages.append(model_engine.renderer.parse_message(response))
if __name__ == "__main__": if __name__ == "__main__":

View File

@@ -17,7 +17,7 @@ from ..accelerator.interface import DistributedInterface
from ..config.arg_parser import get_args from ..config.arg_parser import get_args
from ..core.base_trainer import BaseTrainer from ..core.base_trainer import BaseTrainer
from ..core.data_engine import DataEngine from ..core.data_engine import DataEngine
from ..core.model_loader import ModelLoader from ..core.model_engine import ModelEngine
class SFTTrainer(BaseTrainer): class SFTTrainer(BaseTrainer):
@@ -28,11 +28,11 @@ def run_sft(user_args):
model_args, data_args, training_args, _ = get_args(user_args) model_args, data_args, training_args, _ = get_args(user_args)
DistributedInterface(training_args.dist_config) DistributedInterface(training_args.dist_config)
data_engine = DataEngine(data_args) data_engine = DataEngine(data_args)
model_loader = ModelLoader(model_args) model_engine = ModelEngine(model_args)
trainer = SFTTrainer( trainer = SFTTrainer(
args=training_args, args=training_args,
model=model_loader.model, model=model_engine.model,
processor=model_loader.processor, processor=model_engine.processor,
dataset=data_engine, dataset=data_engine,
) )
trainer.fit() trainer.fit()

View File

@@ -11,3 +11,5 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
IGNORE_INDEX = -100

View File

@@ -0,0 +1,29 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from transformers import PreTrainedTokenizer
from .types import Processor
def get_tokenizer(processor: Processor) -> PreTrainedTokenizer:
"""Get tokenizer from processor.
Args:
processor: Processor.
Returns:
Tokenizer.
"""
return processor.tokenizer if hasattr(processor, "tokenizer") else processor
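A hedged usage sketch; the helper is reproduced inline because its module path is not visible in this diff, and the tiny tokenizer is the one used elsewhere in the tests:

```python
# get_tokenizer is a no-op for plain tokenizers and unwraps multimodal
# processors that expose an inner `.tokenizer` attribute.
from transformers import AutoTokenizer

def get_tokenizer(processor):
    return processor.tokenizer if hasattr(processor, "tokenizer") else processor

tokenizer = AutoTokenizer.from_pretrained("llamafactory/tiny-random-qwen2.5")
assert get_tokenizer(tokenizer) is tokenizer  # plain tokenizers pass through unchanged
```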

View File

@@ -54,7 +54,7 @@ def _get_default_logging_level() -> "logging._Level":
def _get_library_name() -> str: def _get_library_name() -> str:
return __name__.split(".")[0] return ".".join(__name__.split(".")[:2]) # llamafactory.v1
def _get_library_root_logger() -> "_Logger": def _get_library_root_logger() -> "_Logger":

View File

@@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
from collections import defaultdict
from collections.abc import Callable from collections.abc import Callable
from . import logging from . import logging
@@ -27,7 +28,7 @@ class BasePlugin:
A plugin is a callable object that can be registered and called by name. A plugin is a callable object that can be registered and called by name.
""" """
_registry: dict[str, Callable] = {} _registry: dict[str, dict[str, Callable]] = defaultdict(dict)
def __init__(self, name: str | None = None): def __init__(self, name: str | None = None):
"""Initialize the plugin with a name. """Initialize the plugin with a name.
@@ -37,8 +38,7 @@ class BasePlugin:
""" """
self.name = name self.name = name
@property def register(self, method_name: str = "__call__"):
def register(self):
"""Decorator to register a function as a plugin. """Decorator to register a function as a plugin.
Example usage: Example usage:
@@ -46,16 +46,21 @@ class BasePlugin:
@PrintPlugin("hello").register() @PrintPlugin("hello").register()
def print_hello(): def print_hello():
print("Hello world!") print("Hello world!")
@PrintPlugin("hello").register("again")
def print_hello_again():
print("Hello world! Again.")
``` ```
""" """
if self.name is None: if self.name is None:
raise ValueError("Plugin name is not specified.") raise ValueError("Plugin name should be specified.")
if self.name in self._registry: if method_name in self._registry[self.name]:
logger.warning_rank0_once(f"Plugin {self.name} is already registered.") logger.warning_rank0_once(f"Method {method_name} of plugin {self.name} is already registered.")
def decorator(func: Callable) -> Callable: def decorator(func: Callable) -> Callable:
self._registry[self.name] = func self._registry[self.name][method_name] = func
return func return func
return decorator return decorator
@@ -68,10 +73,23 @@ class BasePlugin:
PrintPlugin("hello")() PrintPlugin("hello")()
``` ```
""" """
if self.name not in self._registry: if "__call__" not in self._registry[self.name]:
raise ValueError(f"Plugin {self.name} is not registered.") raise ValueError(f"Method __call__ of plugin {self.name} is not registered.")
return self._registry[self.name](*args, **kwargs) return self._registry[self.name]["__call__"](*args, **kwargs)
def __getattr__(self, method_name: str):
"""Get the registered function with the given name.
Example usage:
```python
PrintPlugin("hello").again()
```
"""
if method_name not in self._registry[self.name]:
raise ValueError(f"Method {method_name} of plugin {self.name} is not registered.")
return self._registry[self.name][method_name]
if __name__ == "__main__": if __name__ == "__main__":
@@ -82,8 +100,13 @@ if __name__ == "__main__":
class PrintPlugin(BasePlugin): class PrintPlugin(BasePlugin):
pass pass
@PrintPlugin("hello").register @PrintPlugin("hello").register()
def print_hello(): def print_hello():
print("Hello world!") print("Hello world!")
@PrintPlugin("hello").register("again")
def print_hello_again():
print("Hello world! Again.")
PrintPlugin("hello")() PrintPlugin("hello")()
PrintPlugin("hello").again()

View File

@@ -84,27 +84,59 @@ class DistributedConfig(TypedDict, total=False):
class Content(TypedDict): class Content(TypedDict):
type: Literal["text", "reasoning", "tools", "tool_calls", "image_url"] type: Literal["text", "reasoning", "tool_call", "image_url"]
"""Type of the content."""
value: str value: str
"""Value of the content."""
class Message(TypedDict): class Message(TypedDict):
role: Literal["system", "user", "assistant", "tool"] role: Literal["system", "user", "assistant", "tool"]
"""Role of the message."""
content: list[Content] content: list[Content]
loss_weight: float """Content of the message."""
loss_weight: NotRequired[float]
"""Loss weight for this message, default to 1.0. Required in training."""
class SFTSample(TypedDict): class SFTSample(TypedDict):
messages: list[Message] messages: list[Message]
"""Messages in the sample."""
extra_info: NotRequired[str] extra_info: NotRequired[str]
"""Extra information for the sample, including tools, kto_labels."""
_dataset_name: NotRequired[str] _dataset_name: NotRequired[str]
"""Dataset name for the sample."""
class DPOSample(TypedDict): class DPOSample(TypedDict):
chosen_messages: list[Message] chosen_messages: list[Message]
"""Chosen messages in the sample."""
rejected_messages: list[Message] rejected_messages: list[Message]
"""Rejected messages in the sample."""
extra_info: NotRequired[str] extra_info: NotRequired[str]
"""Extra information for the sample, including tools, kto_labels."""
_dataset_name: NotRequired[str] _dataset_name: NotRequired[str]
"""Dataset name for the sample."""
Sample = Union[SFTSample, DPOSample] Sample = Union[SFTSample, DPOSample]
class ToolCall(TypedDict):
name: str
"""Function name."""
arguments: str
"""Function arguments."""
class ModelInput(TypedDict, total=False):
input_ids: list[int]
"""Input ids for the model."""
attention_mask: list[int]
"""Attention mask for the model."""
labels: list[int]
"""Labels for the model."""
loss_weights: list[float]
"""Loss weight for each token, default to 1.0."""
position_ids: NotRequired[list[int] | list[list[int]]]
"""Position ids for the model (optional)."""

View File

@@ -18,7 +18,7 @@ Contains shared fixtures, pytest configuration, and custom markers.
""" """
import os import os
from typing import Optional import sys
import pytest import pytest
import torch import torch
@@ -73,7 +73,7 @@ def _handle_slow_tests(items: list[Item]):
item.add_marker(skip_slow) item.add_marker(skip_slow)
def _get_visible_devices_env() -> Optional[str]: def _get_visible_devices_env() -> str | None:
"""Return device visibility env var name.""" """Return device visibility env var name."""
if CURRENT_DEVICE == "cuda": if CURRENT_DEVICE == "cuda":
return "CUDA_VISIBLE_DEVICES" return "CUDA_VISIBLE_DEVICES"
@@ -149,6 +149,14 @@ def _manage_distributed_env(request: FixtureRequest, monkeypatch: MonkeyPatch) -
devices_str = ",".join(str(i) for i in range(required)) devices_str = ",".join(str(i) for i in range(required))
monkeypatch.setenv(env_key, devices_str) monkeypatch.setenv(env_key, devices_str)
# add project root dir to sys.path and PYTHONPATH so spawned worker processes can import the package
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
if project_root not in sys.path:
sys.path.insert(0, project_root)
os.environ["PYTHONPATH"] = project_root + os.pathsep + os.environ.get("PYTHONPATH", "")
else: # non-distributed test else: # non-distributed test
if old_value: if old_value:
visible_devices = [v for v in old_value.split(",") if v != ""] visible_devices = [v for v in old_value.split(",") if v != ""]

View File

@@ -1,2 +1,2 @@
# change if test fails or cache is outdated # change if test fails or cache is outdated
0.9.4.105 0.9.5.101

View File

@@ -1,173 +0,0 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Integration tests for DataLoader with different combinations of packing and dynamic batching.
Tests the 4 scenarios:
a) non pack + non dynamic.
b) non pack + dynamic.
c) pack + non dynamic.
d) pack + dynamic.
"""
import torch
from torch.utils.data import DataLoader as TorchDataLoader
from torch.utils.data import Dataset
from transformers import AutoTokenizer
from llamafactory.v1.config.data_args import DataArguments
from llamafactory.v1.core.data_engine import DataEngine
from llamafactory.v1.core.trainer_utils.data_collator import (
DefaultCollator,
)
from llamafactory.v1.core.trainer_utils.data_loader import DataLoader
from llamafactory.v1.plugins.data_plugins.template import QwenTemplate
from llamafactory.v1.utils.batching_queue import TextBatchingQueue
class TensorDataset(Dataset):
"""Wrapper dataset that converts DataEngine samples to tensor format."""
def __init__(self, data_engine: DataEngine, processor, template, max_samples: int = None):
self.data_engine = data_engine
self.processor = processor
self.template = template
self.max_samples = max_samples or len(data_engine)
self.tokenizer = processor.tokenizer if hasattr(processor, "tokenizer") else processor
def __len__(self):
return min(self.max_samples, len(self.data_engine))
def __getitem__(self, idx):
# Get sample from DataEngine
sample = self.data_engine[idx]
# Extract messages from sample
# DataEngine returns samples with format like {"messages": [...], ...}
# For llamafactory/v1-sft-demo, the format should have "messages" field
messages = None
if "messages" in sample:
messages = sample["messages"]
elif "conversations" in sample:
messages = sample["conversations"]
elif "conversation" in sample:
messages = sample["conversation"]
else:
# Try to find message-like fields (skip _dataset_name)
for key, value in sample.items():
if key.startswith("_"):
continue
if isinstance(value, list) and len(value) > 0:
# Check if it looks like a message list
if isinstance(value[0], dict) and "role" in value[0]:
messages = value
break
if messages is None:
raise ValueError(f"Could not find messages in sample: {list(sample.keys())}")
# Encode messages using template
encoded = self.template.encode_messages(self.tokenizer, messages)
# Convert to tensors
return {
"input_ids": torch.tensor(encoded["input_ids"], dtype=torch.long),
"attention_mask": torch.tensor(encoded["attention_mask"], dtype=torch.long),
"labels": torch.tensor(encoded["labels"], dtype=torch.long),
}
def create_real_dataset(max_samples: int = 20, batch_size: int = 4):
"""Create a real dataset using DataEngine."""
data_args = DataArguments(dataset="llamafactory/v1-sft-demo")
data_engine = DataEngine(data_args)
# Create processor and template
processor = AutoTokenizer.from_pretrained("llamafactory/tiny-random-qwen2.5")
template = QwenTemplate()
# Create tensor dataset
raw_data_dataset = TensorDataset(data_engine, processor, template, max_samples=max_samples)
# Create torch DataLoader
torch_dataloader = TorchDataLoader(
raw_data_dataset,
batch_size=batch_size,
shuffle=False,
collate_fn=lambda x: x,
)
return torch_dataloader, processor, template
class TestDataLoaderNonPackNonDynamic:
"""Test case a) non pack + non dynamic."""
def test_basic_functionality(self):
"""Test DataLoader without packing and without dynamic batching."""
# Create real dataset
torch_dataloader, processor, template = create_real_dataset(max_samples=80, batch_size=8)
# Create collator (non-packing)
collator = DefaultCollator(processor=processor, template=template)
# Create DataLoader without batching_queue (non-dynamic)
data_loader = DataLoader(
dataloader=torch_dataloader,
collate_fn=collator,
num_micro_batch=1,
batching_queue=None,
)
# Iterate and check results
batches = list(iter(data_loader))
assert len(batches) > 0
# Check first batch
one_batch = batches[0]
micro_batches = one_batch[0]
assert "input_ids" in micro_batches
assert "attention_mask" in micro_batches
assert "labels" in micro_batches
assert micro_batches["input_ids"].shape[0] == 1 # batch_size=1
assert micro_batches["input_ids"].ndim == 2 # [batch_size, seq_len]
class TestDataLoaderNonPackDynamic:
"""Test case b) non pack + dynamic."""
def test_basic_functionality(self):
"""Test DataLoader without packing but with dynamic batching."""
# Create real dataset
torch_dataloader, processor, template = create_real_dataset(max_samples=80, batch_size=8)
collator = DefaultCollator(processor=processor, template=template)
# Create batching queue for dynamic batching
batching_queue = TextBatchingQueue(
token_micro_bsz=120,
buffer_size=8,
)
data_loader = DataLoader(
dataloader=torch_dataloader,
collate_fn=collator,
num_micro_batch=4,
batching_queue=batching_queue,
)
# Iterate and check
batches = list(iter(data_loader))
micro_batch_tokens_first = [micro_batch["attention_mask"].sum() for micro_batch in batches[0]]
assert all(num_tokens <= 120 for num_tokens in micro_batch_tokens_first)
assert len(batches) > 0

View File

@@ -15,18 +15,18 @@
import torch import torch
from llamafactory.v1.config.model_args import ModelArguments, PluginConfig from llamafactory.v1.config.model_args import ModelArguments, PluginConfig
from llamafactory.v1.core.model_loader import ModelLoader from llamafactory.v1.core.model_engine import ModelEngine
def test_tiny_qwen(): def test_tiny_qwen():
from transformers import Qwen2Config, Qwen2ForCausalLM, Qwen2TokenizerFast from transformers import Qwen2Config, Qwen2ForCausalLM, Qwen2TokenizerFast
model_args = ModelArguments(model="llamafactory/tiny-random-qwen2.5") model_args = ModelArguments(model="llamafactory/tiny-random-qwen2.5")
model_loader = ModelLoader(model_args) model_engine = ModelEngine(model_args)
assert isinstance(model_loader.processor, Qwen2TokenizerFast) assert isinstance(model_engine.processor, Qwen2TokenizerFast)
assert isinstance(model_loader.model.config, Qwen2Config) assert isinstance(model_engine.model_config, Qwen2Config)
assert isinstance(model_loader.model, Qwen2ForCausalLM) assert isinstance(model_engine.model, Qwen2ForCausalLM)
assert model_loader.model.dtype == torch.bfloat16 assert model_engine.model.dtype == torch.bfloat16
def test_tiny_qwen_with_kernel_plugin(): def test_tiny_qwen_with_kernel_plugin():
@@ -37,13 +37,14 @@ def test_tiny_qwen_with_kernel_plugin():
model_args = ModelArguments( model_args = ModelArguments(
model="llamafactory/tiny-random-qwen2.5", kernel_config=PluginConfig(name="auto", include_kernels="auto") model="llamafactory/tiny-random-qwen2.5", kernel_config=PluginConfig(name="auto", include_kernels="auto")
) )
model_loader = ModelLoader(model_args) model_engine = ModelEngine(model_args)
# test enable apply kernel plugin # test enable apply kernel plugin
if hasattr(torch, "npu"): if hasattr(torch, "npu"):
assert model_loader.model.model.layers[0].input_layernorm.forward.__code__ == npu_rms_norm_forward.__code__ assert model_engine.model.model.layers[0].input_layernorm.forward.__code__ == npu_rms_norm_forward.__code__
else: else:
assert model_loader.model.model.layers[0].input_layernorm.forward.__code__ != npu_rms_norm_forward.__code__ assert model_engine.model.model.layers[0].input_layernorm.forward.__code__ != npu_rms_norm_forward.__code__
assert isinstance(model_loader.model, Qwen2ForCausalLM)
assert isinstance(model_engine.model, Qwen2ForCausalLM)
if __name__ == "__main__": if __name__ == "__main__":

View File

@@ -0,0 +1,171 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Integration tests for DataLoader with different combinations of packing and dynamic batching.
Tests the 4 scenarios:
a) non pack + non dynamic.
b) non pack + dynamic.
c) pack + non dynamic.
d) pack + dynamic.
"""
# import torch
# from torch.utils.data import DataLoader as TorchDataLoader
# from torch.utils.data import Dataset
# from transformers import AutoTokenizer
# from llamafactory.v1.config.data_args import DataArguments
# from llamafactory.v1.core.data_engine import DataEngine
# from llamafactory.v1.core.utils.data_collator import DefaultCollator
# from llamafactory.v1.core.utils.data_loader import DataLoader
# from llamafactory.v1.plugins.data_plugins.rendering import QwenTemplate
# from llamafactory.v1.utils.batching_queue import TextBatchingQueue
# class TensorDataset(Dataset):
# """Wrapper dataset that converts DataEngine samples to tensor format."""
# def __init__(self, data_engine: DataEngine, processor, template, max_samples: int = None):
# self.data_engine = data_engine
# self.processor = processor
# self.template = template
# self.max_samples = max_samples or len(data_engine)
# self.tokenizer = processor.tokenizer if hasattr(processor, "tokenizer") else processor
# def __len__(self):
# return min(self.max_samples, len(self.data_engine))
# def __getitem__(self, idx):
# # Get sample from DataEngine
# sample = self.data_engine[idx]
# # Extract messages from sample
# # DataEngine returns samples with format like {"messages": [...], ...}
# # For llamafactory/v1-sft-demo, the format should have "messages" field
# messages = None
# if "messages" in sample:
# messages = sample["messages"]
# elif "conversations" in sample:
# messages = sample["conversations"]
# elif "conversation" in sample:
# messages = sample["conversation"]
# else:
# # Try to find message-like fields (skip _dataset_name)
# for key, value in sample.items():
# if key.startswith("_"):
# continue
# if isinstance(value, list) and len(value) > 0:
# # Check if it looks like a message list
# if isinstance(value[0], dict) and "role" in value[0]:
# messages = value
# break
# if messages is None:
# raise ValueError(f"Could not find messages in sample: {list(sample.keys())}")
# # Encode messages using template
# encoded = self.template.encode_messages(self.tokenizer, messages)
# # Convert to tensors
# return {
# "input_ids": torch.tensor(encoded["input_ids"], dtype=torch.long),
# "attention_mask": torch.tensor(encoded["attention_mask"], dtype=torch.long),
# "labels": torch.tensor(encoded["labels"], dtype=torch.long),
# }
# def create_real_dataset(max_samples: int = 20, batch_size: int = 4):
# """Create a real dataset using DataEngine."""
# data_args = DataArguments(dataset="llamafactory/v1-sft-demo")
# data_engine = DataEngine(data_args)
# # Create processor and template
# processor = AutoTokenizer.from_pretrained("llamafactory/tiny-random-qwen2.5")
# template = QwenTemplate()
# # Create tensor dataset
# raw_data_dataset = TensorDataset(data_engine, processor, template, max_samples=max_samples)
# # Create torch DataLoader
# torch_dataloader = TorchDataLoader(
# raw_data_dataset,
# batch_size=batch_size,
# shuffle=False,
# collate_fn=lambda x: x,
# )
# return torch_dataloader, processor, template
# class TestDataLoaderNonPackNonDynamic:
# """Test case a) non pack + non dynamic."""
# def test_basic_functionality(self):
# """Test DataLoader without packing and without dynamic batching."""
# # Create real dataset
# torch_dataloader, processor, template = create_real_dataset(max_samples=80, batch_size=8)
# # Create collator (non-packing)
# collator = DefaultCollator(processor=processor, template=template)
# # Create DataLoader without batching_queue (non-dynamic)
# data_loader = DataLoader(
# dataloader=torch_dataloader,
# collate_fn=collator,
# num_micro_batch=1,
# batching_queue=None,
# )
# # Iterate and check results
# batches = list(iter(data_loader))
# assert len(batches) > 0
# # Check first batch
# one_batch = batches[0]
# micro_batches = one_batch[0]
# assert "input_ids" in micro_batches
# assert "attention_mask" in micro_batches
# assert "labels" in micro_batches
# assert micro_batches["input_ids"].shape[0] == 1 # batch_size=1
# assert micro_batches["input_ids"].ndim == 2 # [batch_size, seq_len]
# class TestDataLoaderNonPackDynamic:
# """Test case b) non pack + dynamic."""
# def test_basic_functionality(self):
# """Test DataLoader without packing but with dynamic batching."""
# # Create real dataset
# torch_dataloader, processor, template = create_real_dataset(max_samples=80, batch_size=8)
# collator = DefaultCollator(processor=processor, template=template)
# # Create batching queue for dynamic batching
# batching_queue = TextBatchingQueue(
# token_micro_bsz=120,
# buffer_size=8,
# )
# data_loader = DataLoader(
# dataloader=torch_dataloader,
# collate_fn=collator,
# num_micro_batch=4,
# batching_queue=batching_queue,
# )
# # Iterate and check
# batches = list(iter(data_loader))
# micro_batch_tokens_first = [micro_batch["attention_mask"].sum() for micro_batch in batches[0]]
# assert all(num_tokens <= 120 for num_tokens in micro_batch_tokens_first)
# assert len(batches) > 0

View File

@@ -0,0 +1,65 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from transformers import AutoTokenizer
from llamafactory.v1.core.utils.rendering import Renderer
from llamafactory.v1.utils.types import Processor
HF_MESSAGES = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "What is LLM?"},
{"role": "assistant", "content": "LLM stands for Large Language Model."},
]
V1_MESSAGES = [
{"role": "system", "content": [{"type": "text", "value": "You are a helpful assistant."}]},
{"role": "user", "content": [{"type": "text", "value": "What is LLM?"}]},
{"role": "assistant", "content": [{"type": "text", "value": "LLM stands for Large Language Model."}]},
]
def test_chatml_rendering():
tokenizer: Processor = AutoTokenizer.from_pretrained("llamafactory/tiny-random-qwen3")
renderer = Renderer(template="chatml", processor=tokenizer)
hf_inputs = tokenizer.apply_chat_template(HF_MESSAGES[:-1], add_generation_prompt=True)
v1_inputs = renderer.render_messages(V1_MESSAGES[:-1], is_generate=True)
assert v1_inputs["input_ids"] == hf_inputs
assert v1_inputs["attention_mask"] == [1] * len(hf_inputs)
assert v1_inputs["labels"] == [-100] * len(hf_inputs)
assert v1_inputs["loss_weights"] == [0.0] * len(hf_inputs)
hf_inputs_part = tokenizer.apply_chat_template(HF_MESSAGES[:-1], add_generation_prompt=False)
hf_inputs_full = tokenizer.apply_chat_template(HF_MESSAGES, add_generation_prompt=False)
v1_inputs_full = renderer.render_messages(V1_MESSAGES, is_generate=False)
assert v1_inputs_full["input_ids"] == hf_inputs_full
assert v1_inputs_full["attention_mask"] == [1] * len(hf_inputs_full)
assert v1_inputs_full["labels"] == [-100] * len(hf_inputs_part) + hf_inputs_full[len(hf_inputs_part) :]
assert v1_inputs_full["loss_weights"] == [0.0] * len(hf_inputs_part) + [1.0] * (
len(hf_inputs_full) - len(hf_inputs_part)
)
def test_chatml_parse():
tokenizer: Processor = AutoTokenizer.from_pretrained("llamafactory/tiny-random-qwen3")
renderer = Renderer(template="chatml", processor=tokenizer)
generated_text = "LLM stands for Large Language Model."
parsed_message = renderer.parse_message(generated_text)
assert parsed_message == V1_MESSAGES[-1]
if __name__ == "__main__":
test_chatml_rendering()
test_chatml_parse()

View File

@@ -54,7 +54,7 @@ def test_sharegpt_converter():
"conversations": [ "conversations": [
{"from": "system", "value": "System"}, {"from": "system", "value": "System"},
{"from": "human", "value": "User"}, {"from": "human", "value": "User"},
{"from": "function_call", "value": "Tool"}, {"from": "function_call", "value": "1"},
{"from": "observation", "value": "Observation"}, {"from": "observation", "value": "Observation"},
{"from": "gpt", "value": "Assistant"}, {"from": "gpt", "value": "Assistant"},
] ]
@@ -63,7 +63,7 @@ def test_sharegpt_converter():
"messages": [ "messages": [
{"content": [{"type": "text", "value": "System"}], "loss_weight": 0.0, "role": "system"}, {"content": [{"type": "text", "value": "System"}], "loss_weight": 0.0, "role": "system"},
{"content": [{"type": "text", "value": "User"}], "loss_weight": 0.0, "role": "user"}, {"content": [{"type": "text", "value": "User"}], "loss_weight": 0.0, "role": "user"},
{"content": [{"type": "tool_calls", "value": "Tool"}], "loss_weight": 1.0, "role": "assistant"}, {"content": [{"type": "tool_call", "value": "1"}], "loss_weight": 1.0, "role": "assistant"},
{"content": [{"type": "text", "value": "Observation"}], "loss_weight": 0.0, "role": "tool"}, {"content": [{"type": "text", "value": "Observation"}], "loss_weight": 0.0, "role": "tool"},
{"content": [{"type": "text", "value": "Assistant"}], "loss_weight": 1.0, "role": "assistant"}, {"content": [{"type": "text", "value": "Assistant"}], "loss_weight": 1.0, "role": "assistant"},
] ]

View File

@@ -12,11 +12,10 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import pytest
from llamafactory.v1.accelerator.interface import DistributedInterface from llamafactory.v1.accelerator.interface import DistributedInterface
from llamafactory.v1.config.arg_parser import get_args from llamafactory.v1.config.arg_parser import get_args
from llamafactory.v1.core.model_loader import ModelLoader from llamafactory.v1.core.model_engine import ModelEngine
def test_init_on_meta(): def test_init_on_meta():
@@ -26,11 +25,10 @@ def test_init_on_meta():
init_config={"name": "init_on_meta"}, init_config={"name": "init_on_meta"},
) )
) )
model_loader = ModelLoader(model_args=model_args) model_engine = ModelEngine(model_args=model_args)
assert model_loader.model.device.type == "meta" assert model_engine.model.device.type == "meta"
@pytest.mark.runs_on(["cuda", "npu"])
def test_init_on_rank0(): def test_init_on_rank0():
_, model_args, *_ = get_args( _, model_args, *_ = get_args(
dict( dict(
@@ -38,11 +36,11 @@ def test_init_on_rank0():
init_config={"name": "init_on_rank0"}, init_config={"name": "init_on_rank0"},
) )
) )
model_loader = ModelLoader(model_args=model_args) model_engine = ModelEngine(model_args=model_args)
if DistributedInterface().get_rank() == 0: if DistributedInterface().get_rank() == 0:
assert model_loader.model.device.type == "cpu" assert model_engine.model.device.type == "cpu"
else: else:
assert model_loader.model.device.type == "meta" assert model_engine.model.device.type == "meta"
def test_init_on_default(): def test_init_on_default():
@@ -52,5 +50,5 @@ def test_init_on_default():
init_config={"name": "init_on_default"}, init_config={"name": "init_on_default"},
) )
) )
model_loader = ModelLoader(model_args=model_args) model_engine = ModelEngine(model_args=model_args)
assert model_loader.model.device.type == DistributedInterface().current_accelerator.type assert model_engine.model.device == DistributedInterface().current_device

View File

@@ -0,0 +1,41 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
from llamafactory.v1.config import ModelArguments, SampleArguments
from llamafactory.v1.core.model_engine import ModelEngine
from llamafactory.v1.samplers.cli_sampler import SyncSampler
@pytest.mark.runs_on(["cuda", "npu"])
def test_sync_sampler():
model_args = ModelArguments(model="Qwen/Qwen3-4B-Instruct-2507")
sample_args = SampleArguments()
model_engine = ModelEngine(model_args)
sampler = SyncSampler(sample_args, model_args, model_engine.model, model_engine.renderer)
messages = [{"role": "user", "content": [{"type": "text", "value": "Say 'This is a test.'"}]}]
response = ""
for new_text in sampler.generate(messages):
response += new_text
print(response)
assert model_engine.renderer.parse_message(response) == {
"role": "assistant",
"content": [{"type": "text", "value": "This is a test."}],
}
if __name__ == "__main__":
test_sync_sampler()