[v1] add cli sampler (#9721)

This commit is contained in:
Yaowei Zheng
2026-01-06 23:31:27 +08:00
committed by GitHub
parent e944dc442c
commit ea0b4e2466
45 changed files with 1091 additions and 505 deletions

View File

@@ -13,6 +13,7 @@
# limitations under the License.
from collections import defaultdict
from collections.abc import Callable
from . import logging
@@ -27,7 +28,7 @@ class BasePlugin:
A plugin is a callable object that can be registered and called by name.
"""
_registry: dict[str, Callable] = {}
_registry: dict[str, dict[str, Callable]] = defaultdict(dict)
def __init__(self, name: str | None = None):
"""Initialize the plugin with a name.
@@ -37,8 +38,7 @@ class BasePlugin:
"""
self.name = name
@property
def register(self):
def register(self, method_name: str = "__call__"):
"""Decorator to register a function as a plugin.
Example usage:
@@ -46,16 +46,21 @@ class BasePlugin:
@PrintPlugin("hello").register()
def print_hello():
print("Hello world!")
@PrintPlugin("hello").register("again")
def print_hello_again():
print("Hello world! Again.")
```
"""
if self.name is None:
raise ValueError("Plugin name is not specified.")
raise ValueError("Plugin name should be specified.")
if self.name in self._registry:
logger.warning_rank0_once(f"Plugin {self.name} is already registered.")
if method_name in self._registry[self.name]:
logger.warning_rank0_once(f"Method {method_name} of plugin {self.name} is already registered.")
def decorator(func: Callable) -> Callable:
self._registry[self.name] = func
self._registry[self.name][method_name] = func
return func
return decorator
@@ -68,10 +73,23 @@ class BasePlugin:
PrintPlugin("hello")()
```
"""
if self.name not in self._registry:
raise ValueError(f"Plugin {self.name} is not registered.")
if "__call__" not in self._registry[self.name]:
raise ValueError(f"Method __call__ of plugin {self.name} is not registered.")
return self._registry[self.name](*args, **kwargs)
return self._registry[self.name]["__call__"](*args, **kwargs)
def __getattr__(self, method_name: str):
"""Get the registered function with the given name.
Example usage:
```python
PrintPlugin("hello").again()
```
"""
if method_name not in self._registry[self.name]:
raise ValueError(f"Method {method_name} of plugin {self.name} is not registered.")
return self._registry[self.name][method_name]
if __name__ == "__main__":
@@ -82,8 +100,13 @@ if __name__ == "__main__":
class PrintPlugin(BasePlugin):
pass
@PrintPlugin("hello").register
@PrintPlugin("hello").register()
def print_hello():
print("Hello world!")
@PrintPlugin("hello").register("again")
def print_hello_again():
print("Hello world! Again.")
PrintPlugin("hello")()
PrintPlugin("hello").again()