mirror of
https://github.com/hiyouga/LlamaFactory.git
synced 2026-02-01 20:23:37 +00:00
[v1] add cli sampler (#9721)
This commit is contained in:
@@ -11,3 +11,5 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
IGNORE_INDEX = -100
|
||||
|
||||
29
src/llamafactory/v1/utils/helper.py
Normal file
29
src/llamafactory/v1/utils/helper.py
Normal file
@@ -0,0 +1,29 @@
|
||||
# Copyright 2025 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from transformers import PreTrainedTokenizer
|
||||
|
||||
from .types import Processor
|
||||
|
||||
|
||||
def get_tokenizer(processor: Processor) -> PreTrainedTokenizer:
    """Return the tokenizer behind ``processor``.

    Args:
        processor: Object that either exposes a ``tokenizer`` attribute or is
            itself used as the tokenizer.

    Returns:
        ``processor.tokenizer`` when that attribute exists, otherwise
        ``processor`` unchanged.
    """
    if not hasattr(processor, "tokenizer"):
        # Nothing is wrapped: hand the argument back as-is.
        return processor

    return processor.tokenizer
|
||||
@@ -54,7 +54,7 @@ def _get_default_logging_level() -> "logging._Level":
|
||||
|
||||
|
||||
def _get_library_name() -> str:
    """Return the dotted prefix of this module's ``__name__`` used as library name.

    Returns:
        At most the first two dot-separated components of ``__name__``,
        re-joined with ``"."``. A ``__name__`` with a single component is
        returned unchanged.
    """
    # Keep two components rather than one so the result is the sub-package
    # (e.g. ``llamafactory.v1``) and not only the top-level package.
    return ".".join(__name__.split(".")[:2])  # llamafactory.v1
|
||||
|
||||
|
||||
def _get_library_root_logger() -> "_Logger":
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from collections import defaultdict
|
||||
from collections.abc import Callable
|
||||
|
||||
from . import logging
|
||||
@@ -27,7 +28,7 @@ class BasePlugin:
|
||||
A plugin is a callable object that can be registered and called by name.
|
||||
"""
|
||||
|
||||
_registry: dict[str, Callable] = {}
|
||||
_registry: dict[str, dict[str, Callable]] = defaultdict(dict)
|
||||
|
||||
def __init__(self, name: str | None = None):
|
||||
"""Initialize the plugin with a name.
|
||||
@@ -37,8 +38,7 @@ class BasePlugin:
|
||||
"""
|
||||
self.name = name
|
||||
|
||||
@property
|
||||
def register(self):
|
||||
def register(self, method_name: str = "__call__"):
|
||||
"""Decorator to register a function as a plugin.
|
||||
|
||||
Example usage:
|
||||
@@ -46,16 +46,21 @@ class BasePlugin:
|
||||
@PrintPlugin("hello").register()
|
||||
def print_hello():
|
||||
print("Hello world!")
|
||||
|
||||
|
||||
@PrintPlugin("hello").register("again")
|
||||
def print_hello_again():
|
||||
print("Hello world! Again.")
|
||||
```
|
||||
"""
|
||||
if self.name is None:
|
||||
raise ValueError("Plugin name is not specified.")
|
||||
raise ValueError("Plugin name should be specified.")
|
||||
|
||||
if self.name in self._registry:
|
||||
logger.warning_rank0_once(f"Plugin {self.name} is already registered.")
|
||||
if method_name in self._registry[self.name]:
|
||||
logger.warning_rank0_once(f"Method {method_name} of plugin {self.name} is already registered.")
|
||||
|
||||
def decorator(func: Callable) -> Callable:
|
||||
self._registry[self.name] = func
|
||||
self._registry[self.name][method_name] = func
|
||||
return func
|
||||
|
||||
return decorator
|
||||
@@ -68,10 +73,23 @@ class BasePlugin:
|
||||
PrintPlugin("hello")()
|
||||
```
|
||||
"""
|
||||
if self.name not in self._registry:
|
||||
raise ValueError(f"Plugin {self.name} is not registered.")
|
||||
if "__call__" not in self._registry[self.name]:
|
||||
raise ValueError(f"Method __call__ of plugin {self.name} is not registered.")
|
||||
|
||||
return self._registry[self.name](*args, **kwargs)
|
||||
return self._registry[self.name]["__call__"](*args, **kwargs)
|
||||
|
||||
def __getattr__(self, method_name: str):
    """Return the function registered under ``method_name`` for this plugin.

    Python only calls this hook after normal attribute lookup has failed, so
    besides user lookups it also receives protocol probes such as
    ``hasattr(obj, "__setstate__")``.

    Example usage:
    ```python
    PrintPlugin("hello").again()
    ```

    Args:
        method_name: Name the function was registered under.

    Returns:
        The registered callable itself (it is not invoked here).

    Raises:
        AttributeError: If the instance has no ``name`` attribute yet, or if
            ``method_name`` is an unregistered dunder name. Raising
            ``AttributeError`` lets ``hasattr`` and ``getattr(..., default)``
            work for these probes.
        ValueError: If any other ``method_name`` is not registered for this
            plugin (kept for backward compatibility).
    """
    # Read ``name`` from the instance dict. Going through ``self.name`` would
    # re-enter this hook without end on an instance whose ``__init__`` has not
    # run (for example while it is being rebuilt by ``copy``).
    instance_attrs = self.__dict__
    plugin_name = instance_attrs.get("name")

    # ``.get`` avoids inserting an empty entry into the registry on a failed
    # lookup, which plain indexing does when the registry is a ``defaultdict``.
    methods = self._registry.get(plugin_name, {})
    if method_name in methods:
        return methods[method_name]

    is_dunder = method_name.startswith("__") and method_name.endswith("__")
    if is_dunder or "name" not in instance_attrs:
        raise AttributeError(f"{type(self).__name__!r} object has no attribute {method_name!r}")

    raise ValueError(f"Method {method_name} of plugin {plugin_name} is not registered.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
@@ -82,8 +100,13 @@ if __name__ == "__main__":
|
||||
class PrintPlugin(BasePlugin):
|
||||
pass
|
||||
|
||||
@PrintPlugin("hello").register
|
||||
@PrintPlugin("hello").register()
|
||||
def print_hello():
|
||||
print("Hello world!")
|
||||
|
||||
@PrintPlugin("hello").register("again")
|
||||
def print_hello_again():
|
||||
print("Hello world! Again.")
|
||||
|
||||
PrintPlugin("hello")()
|
||||
PrintPlugin("hello").again()
|
||||
|
||||
@@ -84,27 +84,59 @@ class DistributedConfig(TypedDict, total=False):
|
||||
|
||||
|
||||
class Content(TypedDict):
    """A single typed content item; ``Message.content`` is a list of these."""

    # Declared exactly once: an earlier duplicate ``type`` annotation was
    # silently overridden by this one, so only this Literal ever took effect.
    type: Literal["text", "reasoning", "tool_call", "image_url"]
    """Type of the content."""
    value: str
    """Value of the content."""
|
||||
|
||||
|
||||
class Message(TypedDict):
    """A message: a role, its content items, and an optional loss weight."""

    role: Literal["system", "user", "assistant", "tool"]
    """Role of the message."""
    content: list[Content]
    """Content of the message."""
    # Declared exactly once: an earlier ``loss_weight: float`` was overridden by
    # this annotation, so the key has effectively always been optional.
    loss_weight: NotRequired[float]
    """Loss weight for this message, default to 1.0. Required in training."""
|
||||
|
||||
|
||||
class SFTSample(TypedDict):
    """Sample carrying one list of messages plus optional metadata.

    Only ``messages`` is required; the other keys may be omitted.
    """

    messages: list[Message]
    """The messages that make up this sample."""

    extra_info: NotRequired[str]
    """Optional extra information for the sample, including tools, kto_labels."""

    _dataset_name: NotRequired[str]
    """Optional name of the dataset this sample belongs to."""
|
||||
|
||||
|
||||
class DPOSample(TypedDict):
    """Sample carrying a chosen and a rejected message list plus optional metadata.

    ``chosen_messages`` and ``rejected_messages`` are required; the other keys
    may be omitted.
    """

    chosen_messages: list[Message]
    """The chosen messages of this sample."""

    rejected_messages: list[Message]
    """The rejected messages of this sample."""

    extra_info: NotRequired[str]
    """Optional extra information for the sample, including tools, kto_labels."""

    _dataset_name: NotRequired[str]
    """Optional name of the dataset this sample belongs to."""
|
||||
|
||||
|
||||
Sample = Union[SFTSample, DPOSample]
|
||||
|
||||
|
||||
class ToolCall(TypedDict):
    """A function call described by its name and its arguments.

    Both keys are required and both are plain strings.
    """

    name: str
    """Name of the function."""

    # NOTE(review): presumably a serialized argument payload - confirm the format with callers.
    arguments: str
    """Arguments of the function."""
|
||||
|
||||
|
||||
class ModelInput(TypedDict, total=False):
|
||||
input_ids: list[int]
|
||||
"""Input ids for the model."""
|
||||
attention_mask: list[int]
|
||||
"""Attention mask for the model."""
|
||||
labels: list[int]
|
||||
"""Labels for the model."""
|
||||
loss_weights: list[float]
|
||||
"""Loss weight for each token, default to 1.0."""
|
||||
position_ids: NotRequired[list[int] | list[list[int]]]
|
||||
"""Position ids for the model (optional)."""
|
||||
|
||||
Reference in New Issue
Block a user