[v1] add cli sampler (#9721)

This commit is contained in:
Yaowei Zheng
2026-01-06 23:31:27 +08:00
committed by GitHub
parent e944dc442c
commit ea0b4e2466
45 changed files with 1091 additions and 505 deletions

View File

@@ -13,11 +13,12 @@
# limitations under the License.
+import json
from typing import Any, Literal, NotRequired, TypedDict
from ...utils import logging
from ...utils.plugin import BasePlugin
-from ...utils.types import DPOSample, Sample, SFTSample
+from ...utils.types import DPOSample, Sample, SFTSample, ToolCall
logger = logging.get_logger(__name__)
@@ -61,7 +62,7 @@ class DataConverterPlugin(BasePlugin):
return super().__call__(raw_sample)
@DataConverterPlugin("alpaca").register
@DataConverterPlugin("alpaca").register()
def alpaca_converter(raw_sample: AlpacaSample) -> SFTSample:
"""Convert Alpaca sample to SFT sample.
@@ -98,7 +99,7 @@ def alpaca_converter(raw_sample: AlpacaSample) -> SFTSample:
return {"messages": messages}
@DataConverterPlugin("sharegpt").register
@DataConverterPlugin("sharegpt").register()
def sharegpt_converter(raw_sample: SharegptSample) -> SFTSample:
"""Convert ShareGPT sample to SFT sample.
@@ -118,17 +119,32 @@ def sharegpt_converter(raw_sample: SharegptSample) -> SFTSample:
"function_call": "assistant",
}
messages = []
-tools = raw_sample.get("tools", "")
+tools = raw_sample.get("tools")
+if tools:
+try:
+tools: list[dict[str, Any]] = json.loads(tools)
+except json.JSONDecodeError:
+logger.warning_rank0(f"Invalid tools format: {str(tools)}")
+tools = []
for message in raw_sample.get("conversations", []):
tag = message["from"]
if tag not in tag_mapping:
logger.warning_rank0(f"Unsupported role tag {tag} in message: {message}")
elif tag == "function_call":
+try:
+tool_calls: ToolCall | list[ToolCall] = json.loads(message["value"])
+except json.JSONDecodeError:
+logger.warning_rank0(f"Invalid tool call format: {str(message['value'])}")
+continue
+if not isinstance(tool_calls, list):
+tool_calls = [tool_calls]
messages.append(
{
"role": "assistant",
"content": [{"type": "tool_calls", "value": message["value"]}],
"content": [{"type": "tool_call", "value": json.dumps(tool_call)} for tool_call in tool_calls],
"loss_weight": 1.0,
}
)
@@ -142,15 +158,12 @@ def sharegpt_converter(raw_sample: SharegptSample) -> SFTSample:
)
if tools:
-if messages and messages[0]["role"] == "system":
-messages[0]["content"].append({"type": "tools", "value": tools})
-else:
-messages.insert(0, {"role": "system", "content": [{"type": "tools", "value": tools}], "loss_weight": 0.0})
-return {"messages": messages}
+return {"messages": messages, "extra_info": json.dumps({"tools": tools})}
+else:
+return {"messages": messages}
@DataConverterPlugin("pair").register
@DataConverterPlugin("pair").register()
def pair_converter(raw_sample: PairSample) -> DPOSample:
"""Convert Pair sample to DPO sample.

View File

@@ -49,7 +49,7 @@ def _get_builder_name(path: str) -> Literal["arrow", "csv", "json", "parquet", "
raise ValueError(f"Unknown dataset filetype: {filetype}.")
@DataLoaderPlugin("local").register
@DataLoaderPlugin("local").register()
def load_data_from_file(filepath: str, split: str, streaming: bool) -> HFDataset:
if os.path.isdir(filepath):
filetype = _get_builder_name(os.listdir(filepath)[0])
@@ -66,49 +66,43 @@ def load_data_from_file(filepath: str, split: str, streaming: bool) -> HFDataset
return dataset
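A hypothetical call into the local loader (the path is illustrative; directories infer the builder name from their first file, as shown above):

dataset = load_data_from_file("data/alpaca_en.json", split="train", streaming=False)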
class DataIndexPlugin(BasePlugin):
"""Plugin for adjusting dataset index."""
-def adjust_data_index(
-data_index: list[tuple[str, int]], size: int | None, weight: float | None
-) -> list[tuple[str, int]]:
+def adjust_data_index(
+self, data_index: list[tuple[str, int]], size: int | None, weight: float | None
+) -> list[tuple[str, int]]:
"""Adjust dataset index by size and weight.
Args:
data_index (list[tuple[str, int]]): List of (dataset_name, sample_index).
size (Optional[int]): Desired dataset size.
weight (Optional[float]): Desired dataset weight.
Returns:
list[tuple[str, int]]: Adjusted dataset index.
"""
if size is not None:
data_index = random.choices(data_index, k=size)
if weight is not None:
data_index = random.choices(data_index, k=int(len(data_index) * weight))
return data_index
class DataSelectorPlugin(BasePlugin):
"""Plugin for selecting dataset samples."""
-def select_data_sample(
-data_index: list[tuple[str, int]], index: slice | list[int] | Any
-) -> tuple[str, int] | list[tuple[str, int]]:
+def select(
+self, data_index: list[tuple[str, int]], index: slice | list[int] | Any
+) -> tuple[str, int] | list[tuple[str, int]]:
"""Select dataset samples.
Args:
data_index (list[tuple[str, int]]): List of (dataset_name, sample_index).
index (Union[slice, list[int], Any]): Index of dataset samples.
Returns:
Union[tuple[str, int], list[tuple[str, int]]]: Selected dataset samples.
"""
if isinstance(index, slice):
return [data_index[i] for i in range(*index.indices(len(data_index)))]
elif isinstance(index, list):
return [data_index[i] for i in index]
else:
raise ValueError(f"Invalid index type {type(index)}.")

View File

@@ -1,133 +0,0 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
@dataclass
class Template:
user_template: str
assistant_template: str
system_template: str
def render_message(self, message: dict[str, str]) -> str:
return self.user_template.format(**message)
@dataclass
class QwenTemplate:
message_template: str = "<|im_start|>{role}\n{content}<|im_end|>\n" # FIXME if role: tool
thinking_template: str = "<think>\n{content}\n</think>\n\n"
def _extract_content(self, content_data: str | list[dict[str, str]]) -> str:
if isinstance(content_data, str):
return content_data.strip()
if isinstance(content_data, list):
parts = []
for item in content_data:
if item.get("type") == "text":
parts.append(item.get("value", ""))
elif item.get("type") == "image_url":
pass
return "\n".join(parts).strip()
return ""
def render_message(self, message: dict[str, str | list[dict[str, str]]]) -> str:
role = message["role"]
content = self._extract_content(message.get("content", ""))
if role == "assistant":
reasoning_content = message.get("reasoning_content", "")
if reasoning_content:
reasoning_content = self.thinking_template.format(content=str(reasoning_content).strip())
return self.message_template.format(role="assistant", content=reasoning_content + content)
else:
return self.message_template.format(role=role, content=content)
def encode_messages(self, tokenizer, messages: list[dict[str, str]], max_seq_len: int = 8192) -> dict[str, list[int]]:
"""Encode a list of messages into model inputs."""
input_ids, attention_mask, labels = [], [], []
for message in messages:
content_str = self.render_message(message)
content_ids = tokenizer.encode(content_str, add_special_tokens=False)
input_ids += content_ids
attention_mask += [1] * len(content_ids)
if "loss_weight" in message:
loss_weight = message["loss_weight"]
else:
loss_weight = 1 if message["role"] == "assistant" else 0
if loss_weight == 1:
labels += content_ids
else:
labels += [-100] * len(content_ids)
model_inputs = {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}
model_inputs.update({"position_ids": list(range(len(input_ids)))})
model_inputs = {k: v[-max_seq_len:] for k, v in model_inputs.items()}
return model_inputs
if __name__ == "__main__":
def to_qwen3_messages(template: QwenTemplate, messages: list[dict]):
out = []
for m in messages:
role = m["role"]
content = template._extract_content(m.get("content", ""))
if role == "assistant":
reasoning = (m.get("reasoning_content") or "").strip()
if reasoning:
content = template.thinking_template.format(content=reasoning) + content
out.append({"role": role, "content": content})
return out
from transformers import AutoTokenizer
tok = AutoTokenizer.from_pretrained(
"Qwen/Qwen3-30B-A3B-Thinking-2507",
trust_remote_code=True,
)
test_messages = [
{"role": "system", "content": "You are a helpful assistant."},
{
"role": "user",
"content": [{"type": "text", "text": "1+1等于几"}, {"type": "text", "text": "2+2等于几"}],
},
{
"role": "assistant",
"reasoning_content": "这是一个简单的数学问题。1加1的结果是2。",
"content": [{"type": "text", "text": "1+1=2"}, {"type": "text", "text": "2+2=4"}],
},
]
template = QwenTemplate()
rendered_custom = "".join([template.render_message(m) for m in test_messages])
qwen3_messages = to_qwen3_messages(template, test_messages)
rendered_hf = tok.apply_chat_template(qwen3_messages, tokenize=False, add_generation_prompt=False)
print("==== custom ====")
print(rendered_custom)
print("==== hf ====")
print(rendered_hf)
assert rendered_custom.strip() == rendered_hf.strip(), "Rendered text mismatch"
ids_custom = tok.encode(rendered_custom, add_special_tokens=False)
ids_hf = tok.apply_chat_template(qwen3_messages, tokenize=True, add_generation_prompt=False)
assert ids_custom == ids_hf, f"Token ids mismatch: custom={len(ids_custom)} hf={len(ids_hf)}"

View File

@@ -25,12 +25,12 @@ class InitPlugin(BasePlugin):
return super().__call__()
@InitPlugin("init_on_meta").register
@InitPlugin("init_on_meta").register()
def init_on_meta() -> torch.device:
return torch.device(DeviceType.META.value)
@InitPlugin("init_on_rank0").register
@InitPlugin("init_on_rank0").register()
def init_on_rank0() -> torch.device:
if DistributedInterface().get_rank() == 0:
return torch.device(DeviceType.CPU.value)
@@ -38,6 +38,6 @@ def init_on_rank0() -> torch.device:
return torch.device(DeviceType.META.value)
@InitPlugin("init_on_default").register
@InitPlugin("init_on_default").register()
def init_on_default() -> torch.device:
-return DistributedInterface().current_accelerator
+return DistributedInterface().current_device
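For context on the strategies above: meta-device init creates parameters without storage, while rank-0 init keeps a single CPU copy. A sketch of the underlying PyTorch mechanism, independent of this plugin wiring:

import torch
import torch.nn as nn

with torch.device("meta"):
    layer = nn.Linear(4096, 4096)  # parameters exist but allocate no memory

print(layer.weight.device)  # meta
layer = layer.to_empty(device="cpu")  # allocate real storage later, e.g. before loading weights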

View File

@@ -38,17 +38,17 @@ class BaseKernel(ABC):
@classmethod
def get_kernel_id(cls) -> str:
r"""Returns the unique identifier for the kernel."""
"""Returns the unique identifier for the kernel."""
return cls._kernel_id
@classmethod
def get_device(cls) -> str:
r"""Returns the device type associated with the kernel (e.g., "cuda", "npu", "cpu")."""
"""Returns the device type associated with the kernel (e.g., "cuda", "npu", "cpu")."""
return cls._device
@classmethod
def check_deps(cls) -> bool:
r"""Checks if the required dependencies for the kernel are available.
"""Checks if the required dependencies for the kernel are available.
Returns:
bool: ``True`` if dependencies are met, ``False`` otherwise.
@@ -65,7 +65,7 @@ class BaseKernel(ABC):
@classmethod
@abstractmethod
def apply(cls, **kwargs) -> HFModel:
r"""Applies the kernel optimization to the model.
"""Applies the kernel optimization to the model.
Args:
**kwargs: Arbitrary keyword arguments, usually containing the model instance and the kernel configuration.
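To make the contract concrete, a self-contained sketch with a stand-in base class (BaseKernelStub only mirrors the attributes and classmethods shown above; it is not the module's real class):

from abc import ABC, abstractmethod


class BaseKernelStub(ABC):
    _kernel_id: str = ""
    _device: str = ""

    @classmethod
    def get_kernel_id(cls) -> str:
        return cls._kernel_id

    @classmethod
    @abstractmethod
    def apply(cls, **kwargs): ...


class NoopKernel(BaseKernelStub):
    _kernel_id = "noop"
    _device = "cpu"

    @classmethod
    def apply(cls, **kwargs):
        return kwargs.get("model")  # leave the model untouched


assert NoopKernel.get_kernel_id() == "noop"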

View File

@@ -33,7 +33,7 @@ logger = get_logger(__name__)
def scan_all_kernels():
r"""Scan all kernels in the ``ops`` directory.
"""Scan all kernels in the ``ops`` directory.
Scans the ``ops`` directory for all ``.py`` files and attempts to import them.
Importing triggers the :func:`~registry.register_kernel` decorator, which automatically registers the kernels.
@@ -77,7 +77,7 @@ default_kernels = scan_all_kernels()
def get_default_kernels():
r"""Get a list of default registered kernel IDs.
"""Get a list of default registered kernel IDs.
Returns:
list[str]: List of kernel IDs.
@@ -86,7 +86,7 @@ def get_default_kernels():
def apply_kernel(kernel_id: str, **kwargs):
r"""Applies a specific kernel to the model.
"""Applies a specific kernel to the model.
Args:
kernel_id (str): The ID of the kernel to apply.
@@ -99,18 +99,19 @@ def apply_kernel(kernel_id: str, **kwargs):
kernel = default_kernels.get(kernel_id)
if kernel is None:
raise ValueError(f"Kernel {kernel_id} not found")
kernel.apply(**kwargs)
class KernelPlugin(BasePlugin):
r"""Plugin for managing kernel optimizations."""
"""Plugin for managing kernel optimizations."""
pass
@KernelPlugin("auto").register
@KernelPlugin("auto").register()
def apply_default_kernels(**kwargs):
r"""Applies all default registered kernels to the model.
"""Applies all default registered kernels to the model.
Args:
**kwargs: Keyword arguments passed to the kernel application function.
@@ -125,8 +126,11 @@ def apply_default_kernels(**kwargs):
use_kernels = default_kernels.keys()
else:
use_kernels = kwargs.get("include_kernels").split(",") # "kernel_id1,kernel_id2,kernel_id3"
+for kernel in use_kernels:
+if kernel not in default_kernels:
+raise ValueError(f"Kernel {kernel} not found")
apply_kernel(kernel, **kwargs)
return kwargs.get("model")
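Hypothetical call sites for the two branches above (which input selects the apply-everything branch is an assumption; the kernel IDs are ones registered in this commit):

model = apply_default_kernels(model=model)  # assumed default: all registered kernels
model = apply_default_kernels(
    model=model,
    include_kernels="npu_fused_rmsnorm,npu_fused_rope",  # unknown IDs now raise ValueError
)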

View File

@@ -40,11 +40,11 @@ from ...registry import register_kernel
class GmmFunction(torch.autograd.Function):
r"""Custom autograd function for NPU Grouped Matrix Multiplication (GMM)."""
"""Custom autograd function for NPU Grouped Matrix Multiplication (GMM)."""
@staticmethod
def forward(ctx, x, weight, group_list):
r"""Performs the forward pass of Grouped Matrix Multiplication.
"""Performs the forward pass of Grouped Matrix Multiplication.
Args:
ctx: Context object to save tensors for backward pass.
@@ -65,7 +65,7 @@ class GmmFunction(torch.autograd.Function):
@staticmethod
def backward(ctx, grad_output):
r"""Performs the backward pass of Grouped Matrix Multiplication.
"""Performs the backward pass of Grouped Matrix Multiplication.
Args:
ctx: Context object containing saved tensors.
@@ -94,11 +94,11 @@ class GmmFunction(torch.autograd.Function):
class HybridGmmFunction(torch.autograd.Function):
r"""Custom autograd function for Hybrid Grouped Matrix Multiplication on NPU."""
"""Custom autograd function for Hybrid Grouped Matrix Multiplication on NPU."""
@staticmethod
def forward(ctx, num_experts, *args):
r"""Performs the forward pass of Hybrid GMM.
"""Performs the forward pass of Hybrid GMM.
Args:
ctx: Context object to save tensors.
@@ -124,7 +124,7 @@ class HybridGmmFunction(torch.autograd.Function):
@staticmethod
def backward(ctx, *grad_outputs):
r"""Performs the backward pass of Hybrid GMM.
"""Performs the backward pass of Hybrid GMM.
Args:
ctx: Context object containing saved tensors.
@@ -176,13 +176,13 @@ class HybridGmmFunction(torch.autograd.Function):
class NpuMoeFused:
r"""Container for NPU fused MoE forward functions."""
"""Container for NPU fused MoE forward functions."""
@staticmethod
def npu_moe_experts_forward(
self, hidden_states: torch.Tensor, routing_weights: torch.Tensor, router_indices: torch.Tensor
) -> torch.Tensor:
r"""Forward pass for MoE experts using NPU fused operations.
"""Forward pass for MoE experts using NPU fused operations.
Args:
self: The MoE layer instance.
@@ -230,11 +230,11 @@ class NpuMoeFused:
class Qwen3NpuMoeFused:
r"""Container for Qwen3 NPU fused MoE forward functions."""
"""Container for Qwen3 NPU fused MoE forward functions."""
@staticmethod
def qwen3moe_sparse_moe_block_forward(self, hidden_states: torch.Tensor):
r"""Forward pass for Qwen3 sparse MoE block using NPU fused operations.
"""Forward pass for Qwen3 sparse MoE block using NPU fused operations.
Args:
self: The Qwen3 MoE block instance.
@@ -298,14 +298,14 @@ if not is_transformers_version_greater_than("5.0.0"):
@register_kernel
class NpuFusedMoEKernel(BaseKernel):
r"""NPU Fused MoE Kernel implementation."""
"""NPU Fused MoE Kernel implementation."""
_kernel_id = "npu_fused_moe"
_device = DeviceType.NPU
@classmethod
def apply(cls, **kwargs) -> HFModel:
r"""Applies the NPU fused MoE kernel to the model.
"""Applies the NPU fused MoE kernel to the model.
Args:
**kwargs: Keyword arguments containing the model.
@@ -333,6 +333,7 @@ class NpuFusedMoEKernel(BaseKernel):
if target_moe_mapping is None:
return model
for module in model.modules():
class_name = module.__class__.__name__
if class_name in target_moe_mapping:

View File

@@ -38,7 +38,7 @@ except ImportError:
def npu_swiglu_forward(self, hidden_state):
r"""SwiGLU forward pass for NPU.
"""SwiGLU forward pass for NPU.
Args:
self: The MLP layer instance.
@@ -53,7 +53,7 @@ def npu_swiglu_forward(self, hidden_state):
def _npu_swiglu_glm4_forward(self, hidden_states):
r"""SwiGLU forward pass for GLM4 on NPU.
"""SwiGLU forward pass for GLM4 on NPU.
Args:
self: The GLM4 MLP layer instance.
@@ -68,7 +68,7 @@ def _npu_swiglu_glm4_forward(self, hidden_states):
def _npu_swiglu_gemma3ntext_forward(self, hidden_states):
r"""SwiGLU forward pass for Gemma3nText on NPU.
"""SwiGLU forward pass for Gemma3nText on NPU.
Args:
self: The Gemma3nText MLP layer instance.
@@ -88,7 +88,7 @@ def _npu_swiglu_gemma3ntext_forward(self, hidden_states):
@register_kernel
class NpuSwiGluKernel(BaseKernel):
r"""NPU Kernel for fused SwiGLU activation."""
"""NPU Kernel for fused SwiGLU activation."""
# only supports applying to the following module layers
expect_modules = frozenset(
@@ -126,7 +126,7 @@ class NpuSwiGluKernel(BaseKernel):
@classmethod
def apply(cls, **kwargs) -> "HFModel":
r"""Applies the NPU fused SwiGLU kernel to the model.
"""Applies the NPU fused SwiGLU kernel to the model.
Args:
**kwargs: Keyword arguments containing the model.

View File

@@ -30,7 +30,7 @@ from ...registry import register_kernel
def npu_rms_norm_forward(self, hidden_states):
r"""NPU forward implementation for RMSNorm.
"""NPU forward implementation for RMSNorm.
Args:
self: RMSNorm module instance with `weight` and `variance_epsilon`.
@@ -46,14 +46,14 @@ def npu_rms_norm_forward(self, hidden_states):
@register_kernel
class NpuRMSNormKernel(BaseKernel):
r"""NPU kernel wrapper for RMSNorm that applies the replacement within a model."""
"""NPU kernel wrapper for RMSNorm that applies the replacement within a model."""
_kernel_id = "npu_fused_rmsnorm"
_device = DeviceType.NPU
@classmethod
def apply(cls, **kwargs) -> "HFModel":
r"""Iterate the model and apply NPU-optimized forward to matched RMSNorm modules.
"""Iterate the model and apply NPU-optimized forward to matched RMSNorm modules.
Key points:
- Match modules whose class name contains "RMSNorm" (case-insensitive).
@@ -78,6 +78,7 @@ class NpuRMSNormKernel(BaseKernel):
if not cls.check_deps():
raise RuntimeError(f"torch_npu is not available but {cls.__name__} was called.")
rms_norm_pattern = re.compile("RMSNorm", re.IGNORECASE)
for name, module in model.named_modules():
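The apply step is a name-based monkey-patch. A self-contained sketch of that pattern (MyRMSNorm and fake_rms_norm_forward are stand-ins, not the kernel's code):

import re

import torch
import torch.nn as nn


def fake_rms_norm_forward(self, hidden_states):
    # standard RMSNorm math, standing in for the torch_npu fused op
    variance = hidden_states.pow(2).mean(-1, keepdim=True)
    return self.weight * hidden_states * torch.rsqrt(variance + self.variance_epsilon)


class MyRMSNorm(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.variance_epsilon = 1e-6


model = nn.Sequential(MyRMSNorm(8), nn.Linear(8, 8))
pattern = re.compile("RMSNorm", re.IGNORECASE)
for name, module in model.named_modules():
    if pattern.search(module.__class__.__name__):
        # bind the replacement as an instance method; other modules are untouched
        module.forward = fake_rms_norm_forward.__get__(module, module.__class__)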

View File

@@ -40,7 +40,7 @@ except ImportError:
def _apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
r"""Applies Rotary Position Embedding to the query and key tensors using NPU optimization.
"""Applies Rotary Position Embedding to the query and key tensors using NPU optimization.
Args:
q (Tensor): Query tensor.
@@ -61,7 +61,7 @@ def _apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
def _apply_multimodal_rotary_pos_emb_qwen25_vl(q, k, cos, sin, mrope_section, unsqueeze_dim=1):
r"""Applies Rotary Position Embedding with multimodal sections (Qwen2-VL) on NPU.
"""Applies Rotary Position Embedding with multimodal sections (Qwen2-VL) on NPU.
Args:
q (Tensor): Query tensor.
@@ -89,14 +89,14 @@ def _apply_multimodal_rotary_pos_emb_qwen25_vl(q, k, cos, sin, mrope_section, un
@register_kernel
class NpuRoPEKernel(BaseKernel):
r"""NPU Kernel for Rotary Position Embedding."""
"""NPU Kernel for Rotary Position Embedding."""
_kernel_id = "npu_fused_rope"
_device = DeviceType.NPU
@classmethod
def apply(cls, **kwargs) -> "HFModel":
r"""Apply RoPE acceleration by monkey-patching `apply_rotary_pos_emb`.
"""Apply RoPE acceleration by monkey-patching `apply_rotary_pos_emb`.
This function iterates through the model's modules to find attention layers,
identifies the module where they are defined, and replaces the original
@@ -115,9 +115,11 @@ class NpuRoPEKernel(BaseKernel):
"""
if not cls.check_deps():
raise RuntimeError(f"torch_npu is not available but {cls.__name__} was called.")
model = kwargs.get("model", None)
if model is None:
raise ValueError(f"HFModel instance is required for {cls.__name__}.")
_modules = set()
for module in model.modules():
if "Attention" in module.__class__.__name__:
@@ -143,4 +145,5 @@ class NpuRoPEKernel(BaseKernel):
_modules.add(module_name)
except Exception as e:
logger.warning_rank0_once(f"Failed to apply RoPE kernel to module {module_name}: {e}")
return model

View File

@@ -30,7 +30,7 @@ __all__ = ["Registry", "register_kernel"]
class Registry:
r"""Registry for managing kernel implementations.
"""Registry for managing kernel implementations.
Storage structure: ``{ "kernel_id": Class }``
"""
@@ -38,8 +38,8 @@ class Registry:
_kernels: dict[str, type[BaseKernel]] = {}
@classmethod
-def register(cls, kernel_cls: type[BaseKernel]):
-r"""Decorator to register a kernel class.
+def register(cls, kernel_cls: type[BaseKernel]) -> type[BaseKernel] | None:
+"""Decorator to register a kernel class.
The class must inherit from :class:`BaseKernel` and specify ``_kernel_id`` and ``_device`` attributes.
@@ -47,7 +47,7 @@ class Registry:
kernel_cls (type[BaseKernel]): The kernel class to register.
Returns:
-type[BaseKernel]: The registered kernel class.
+type[BaseKernel] | None: The registered kernel class if the device type matches the current accelerator, otherwise ``None``.
Raises:
TypeError: If the class does not inherit from :class:`BaseKernel`.
@@ -55,6 +55,7 @@ class Registry:
"""
if not issubclass(kernel_cls, BaseKernel):
raise TypeError(f"Class {kernel_cls} must inherit from BaseKernel")
kernel_id = kernel_cls.get_kernel_id()
device = kernel_cls.get_device()
@@ -73,7 +74,7 @@ class Registry:
@classmethod
def get(cls, kernel_id: str) -> Optional[type[BaseKernel]]:
r"""Retrieves a registered kernel implementation by its ID.
"""Retrieves a registered kernel implementation by its ID.
Args:
kernel_id (str): The ID of the kernel to retrieve.
@@ -85,7 +86,7 @@ class Registry:
@classmethod
def get_registered_kernels(cls) -> dict[str, type[BaseKernel]]:
r"""Returns a dictionary of all registered kernels.
"""Returns a dictionary of all registered kernels.
Returns:
dict[str, type[BaseKernel]]: Dictionary mapping kernel IDs to kernel classes.
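A minimal sketch of the device-gated registration that the new `type[BaseKernel] | None` return implies (the `_current_device` check is an illustrative stand-in, not the module's logic):

_kernels: dict[str, type] = {}
_current_device = "cpu"  # stand-in for the detected accelerator


def register(kernel_cls: type) -> type | None:
    if kernel_cls._device != _current_device:
        return None  # kernels built for another accelerator are skipped
    _kernels[kernel_cls._kernel_id] = kernel_cls
    return kernel_cls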

View File

@@ -45,13 +45,13 @@ class PeftPlugin(BasePlugin):
return super().__call__(model, config)
@PeftPlugin("lora").register
@PeftPlugin("lora").register()
def get_lora_model(model: HFModel, config: LoraConfigDict, is_train: bool) -> PeftModel:
peft_config = LoraConfig(**config)
model = get_peft_model(model, peft_config)
return model
@PeftPlugin("freeze").register
@PeftPlugin("freeze").register()
def get_freeze_model(model: HFModel, config: FreezeConfigDict, is_train: bool) -> HFModel:
raise NotImplementedError()
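The `lora` path above uses the real peft API directly; a short sketch with illustrative config values (`base_model` stands for a loaded HF model):

from peft import LoraConfig, get_peft_model

config = {"r": 8, "lora_alpha": 16, "target_modules": ["q_proj", "v_proj"]}
peft_config = LoraConfig(**config)
peft_model = get_peft_model(base_model, peft_config)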

View File

@@ -0,0 +1,36 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ...utils.plugin import BasePlugin
from ...utils.types import Message, ModelInput, Processor
class RenderingPlugin(BasePlugin):
pass
@RenderingPlugin("qwen").register("render_messages")
def render_qwen_messages(
processor: Processor,
messages: list[Message],
tools: str | None = None,
is_generate: bool = False,
) -> ModelInput:
raise NotImplementedError()
@RenderingPlugin("qwen").register("parse_message")
def parse_qwen_message(generated_text: str) -> Message:
raise NotImplementedError()
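The named `register("...")` form lets one plugin key expose several entry points. A hypothetical additional backend would follow the same shape (the "llama" key and function are made up):

@RenderingPlugin("llama").register("render_messages")
def render_llama_messages(
    processor: Processor,
    messages: list[Message],
    tools: str | None = None,
    is_generate: bool = False,
) -> ModelInput:
    raise NotImplementedError()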