mirror of
https://github.com/hiyouga/LlamaFactory.git
synced 2026-01-31 06:42:05 +00:00
[v1] add cli sampler (#9721)
This commit is contained in:
@@ -13,6 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from collections import defaultdict
|
||||
from collections.abc import Callable
|
||||
|
||||
from . import logging
|
||||
@@ -27,7 +28,7 @@ class BasePlugin:
|
||||
A plugin is a callable object that can be registered and called by name.
|
||||
"""
|
||||
|
||||
_registry: dict[str, Callable] = {}
|
||||
_registry: dict[str, dict[str, Callable]] = defaultdict(dict)
|
||||
|
||||
def __init__(self, name: str | None = None):
|
||||
"""Initialize the plugin with a name.
|
||||
@@ -37,8 +38,7 @@ class BasePlugin:
|
||||
"""
|
||||
self.name = name
|
||||
|
||||
@property
|
||||
def register(self):
|
||||
def register(self, method_name: str = "__call__"):
|
||||
"""Decorator to register a function as a plugin.
|
||||
|
||||
Example usage:
|
||||
@@ -46,16 +46,21 @@ class BasePlugin:
|
||||
@PrintPlugin("hello").register()
|
||||
def print_hello():
|
||||
print("Hello world!")
|
||||
|
||||
|
||||
@PrintPlugin("hello").register("again")
|
||||
def print_hello_again():
|
||||
print("Hello world! Again.")
|
||||
```
|
||||
"""
|
||||
if self.name is None:
|
||||
raise ValueError("Plugin name is not specified.")
|
||||
raise ValueError("Plugin name should be specified.")
|
||||
|
||||
if self.name in self._registry:
|
||||
logger.warning_rank0_once(f"Plugin {self.name} is already registered.")
|
||||
if method_name in self._registry[self.name]:
|
||||
logger.warning_rank0_once(f"Method {method_name} of plugin {self.name} is already registered.")
|
||||
|
||||
def decorator(func: Callable) -> Callable:
|
||||
self._registry[self.name] = func
|
||||
self._registry[self.name][method_name] = func
|
||||
return func
|
||||
|
||||
return decorator
|
||||
@@ -68,10 +73,23 @@ class BasePlugin:
|
||||
PrintPlugin("hello")()
|
||||
```
|
||||
"""
|
||||
if self.name not in self._registry:
|
||||
raise ValueError(f"Plugin {self.name} is not registered.")
|
||||
if "__call__" not in self._registry[self.name]:
|
||||
raise ValueError(f"Method __call__ of plugin {self.name} is not registered.")
|
||||
|
||||
return self._registry[self.name](*args, **kwargs)
|
||||
return self._registry[self.name]["__call__"](*args, **kwargs)
|
||||
|
||||
def __getattr__(self, method_name: str):
|
||||
"""Get the registered function with the given name.
|
||||
|
||||
Example usage:
|
||||
```python
|
||||
PrintPlugin("hello").again()
|
||||
```
|
||||
"""
|
||||
if method_name not in self._registry[self.name]:
|
||||
raise ValueError(f"Method {method_name} of plugin {self.name} is not registered.")
|
||||
|
||||
return self._registry[self.name][method_name]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
@@ -82,8 +100,13 @@ if __name__ == "__main__":
|
||||
class PrintPlugin(BasePlugin):
|
||||
pass
|
||||
|
||||
@PrintPlugin("hello").register
|
||||
@PrintPlugin("hello").register()
|
||||
def print_hello():
|
||||
print("Hello world!")
|
||||
|
||||
@PrintPlugin("hello").register("again")
|
||||
def print_hello_again():
|
||||
print("Hello world! Again.")
|
||||
|
||||
PrintPlugin("hello")()
|
||||
PrintPlugin("hello").again()
|
||||
|
||||
Reference in New Issue
Block a user