[v1] add cli sampler (#9721)

This commit is contained in:
Yaowei Zheng
2026-01-06 23:31:27 +08:00
committed by GitHub
parent e944dc442c
commit ea0b4e2466
45 changed files with 1091 additions and 505 deletions

View File

@@ -11,3 +11,5 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Sentinel label value; judging by the name, it marks positions to be ignored.
# NOTE(review): presumably matches the default ``ignore_index`` of PyTorch's
# cross-entropy loss — confirm against the code that consumes it.
IGNORE_INDEX = -100

View File

@@ -0,0 +1,29 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from transformers import PreTrainedTokenizer
from .types import Processor
def get_tokenizer(processor: Processor) -> PreTrainedTokenizer:
    """Return the tokenizer behind ``processor``.

    Args:
        processor: Object that either exposes a ``tokenizer`` attribute or is
            itself to be used as the tokenizer.

    Returns:
        ``processor.tokenizer`` when that attribute exists, otherwise
        ``processor`` unchanged.
    """
    # Fall back to the argument itself when there is no ``tokenizer`` attribute.
    return getattr(processor, "tokenizer", processor)

View File

@@ -54,7 +54,7 @@ def _get_default_logging_level() -> "logging._Level":
def _get_library_name() -> str:
    """Return the dotted name of the package this module's logger tree hangs off.

    Returns:
        The first two dot-separated components of this module's ``__name__``
        joined by ``"."``; a ``__name__`` with fewer components is returned whole.
    """
    # Two levels are kept so the name identifies the versioned sub-package
    # rather than only the top-level package.
    return ".".join(__name__.split(".")[:2])  # llamafactory.v1
def _get_library_root_logger() -> "_Logger":

View File

@@ -13,6 +13,7 @@
# limitations under the License.
from collections import defaultdict
from collections.abc import Callable
from . import logging
@@ -27,7 +28,7 @@ class BasePlugin:
A plugin is a callable object that can be registered and called by name.
"""
_registry: dict[str, Callable] = {}
_registry: dict[str, dict[str, Callable]] = defaultdict(dict)
def __init__(self, name: str | None = None):
"""Initialize the plugin with a name.
@@ -37,8 +38,7 @@ class BasePlugin:
"""
self.name = name
@property
def register(self):
def register(self, method_name: str = "__call__"):
"""Decorator to register a function as a plugin.
Example usage:
@@ -46,16 +46,21 @@ class BasePlugin:
@PrintPlugin("hello").register()
def print_hello():
print("Hello world!")
@PrintPlugin("hello").register("again")
def print_hello_again():
print("Hello world! Again.")
```
"""
if self.name is None:
raise ValueError("Plugin name is not specified.")
raise ValueError("Plugin name should be specified.")
if self.name in self._registry:
logger.warning_rank0_once(f"Plugin {self.name} is already registered.")
if method_name in self._registry[self.name]:
logger.warning_rank0_once(f"Method {method_name} of plugin {self.name} is already registered.")
def decorator(func: Callable) -> Callable:
self._registry[self.name] = func
self._registry[self.name][method_name] = func
return func
return decorator
@@ -68,10 +73,23 @@ class BasePlugin:
PrintPlugin("hello")()
```
"""
if self.name not in self._registry:
raise ValueError(f"Plugin {self.name} is not registered.")
if "__call__" not in self._registry[self.name]:
raise ValueError(f"Method __call__ of plugin {self.name} is not registered.")
return self._registry[self.name](*args, **kwargs)
return self._registry[self.name]["__call__"](*args, **kwargs)
def __getattr__(self, method_name: str):
"""Get the registered function with the given name.
Example usage:
```python
PrintPlugin("hello").again()
```
"""
if method_name not in self._registry[self.name]:
raise ValueError(f"Method {method_name} of plugin {self.name} is not registered.")
return self._registry[self.name][method_name]
if __name__ == "__main__":
@@ -82,8 +100,13 @@ if __name__ == "__main__":
class PrintPlugin(BasePlugin):
pass
@PrintPlugin("hello").register
@PrintPlugin("hello").register()
def print_hello():
print("Hello world!")
@PrintPlugin("hello").register("again")
def print_hello_again():
print("Hello world! Again.")
PrintPlugin("hello")()
PrintPlugin("hello").again()

View File

@@ -84,27 +84,59 @@ class DistributedConfig(TypedDict, total=False):
class Content(TypedDict):
    """One typed piece of content: a ``type`` tag and its string ``value``."""

    # Declared once: a second ``type`` annotation in the same class body would
    # silently override the first.
    type: Literal["text", "reasoning", "tool_call", "image_url"]
    """Type of the content."""
    value: str
    """Value of the content."""
class Message(TypedDict):
    """A single message: a role, its content parts and an optional loss weight."""

    role: Literal["system", "user", "assistant", "tool"]
    """Role of the message."""
    content: list[Content]
    """Content of the message."""
    # Declared once, as optional: a preceding plain ``float`` declaration of the
    # same key would be overridden by this one anyway.
    loss_weight: NotRequired[float]
    """Loss weight for this message, default to 1.0. Required in training."""
class SFTSample(TypedDict):
    """Sample made of one list of messages plus optional metadata."""

    messages: list[Message]
    """The messages that make up this sample."""
    extra_info: NotRequired[str]
    """Additional information attached to the sample, e.g. tools, kto_labels."""
    _dataset_name: NotRequired[str]
    """Name of the dataset this sample belongs to."""
class DPOSample(TypedDict):
    """Sample holding a chosen and a rejected list of messages plus optional metadata."""

    chosen_messages: list[Message]
    """The chosen messages of this sample."""
    rejected_messages: list[Message]
    """The rejected messages of this sample."""
    extra_info: NotRequired[str]
    """Additional information attached to the sample, e.g. tools, kto_labels."""
    _dataset_name: NotRequired[str]
    """Name of the dataset this sample belongs to."""
# A sample is either an SFT sample (one message list) or a DPO sample
# (chosen and rejected message lists).
Sample = Union[SFTSample, DPOSample]
class ToolCall(TypedDict):
    """A function call described by its name and its arguments, both strings."""

    name: str
    """Name of the function."""
    arguments: str
    """Arguments of the function."""
class ModelInput(TypedDict, total=False):
input_ids: list[int]
"""Input ids for the model."""
attention_mask: list[int]
"""Attention mask for the model."""
labels: list[int]
"""Labels for the model."""
loss_weights: list[float]
"""Loss weight for each token, default to 1.0."""
position_ids: NotRequired[list[int] | list[list[int]]]
"""Position ids for the model (optional)."""