[v1] use async streamer (#9741)

This commit is contained in:
Yaowei Zheng
2026-01-09 16:07:40 +08:00
committed by hiyouga
parent 766d5ae6ad
commit 8abb8fb533
6 changed files with 47 additions and 57 deletions

View File

@@ -24,12 +24,13 @@ Init Phase:
import importlib
from pathlib import Path
from ....utils.logging import get_logger
from ....utils import logging
from ....utils.plugin import BasePlugin
from ....utils.types import HFModel
from .registry import Registry
logger = get_logger(__name__)
logger = logging.get_logger(__name__)
def scan_all_kernels():
@@ -110,27 +111,30 @@ class KernelPlugin(BasePlugin):
@KernelPlugin("auto").register()
def apply_default_kernels(**kwargs):
def apply_default_kernels(model: HFModel, include_kernels: str = None) -> HFModel:
"""Applies all default registered kernels to the model.
Args:
**kwargs: Keyword arguments passed to the kernel application function.
Typically includes the model instance and the include_kernels configuration.
model (HFModel): The model instance to apply kernels to.
include_kernels (str, optional): Comma-separated list of kernel IDs to apply.
If "auto" or True, applies all default kernels.
If None or False, no kernels are applied.
Defaults to None.
Returns:
HFModel: The model with applied kernels.
"""
if not kwargs.get("include_kernels"): # None/False/empty string
return kwargs.get("model")
elif kwargs.get("include_kernels") == "auto" or kwargs.get("include_kernels") is True: # True/auto
if not include_kernels:
return model
elif include_kernels == "auto" or include_kernels is True:
use_kernels = default_kernels.keys()
else:
use_kernels = kwargs.get("include_kernels").split(",") # "kernel_id1,kernel_id2,kernel_id3"
use_kernels = include_kernels.split(",") # "kernel_id1,kernel_id2,kernel_id3"
for kernel in use_kernels:
if kernel not in default_kernels:
raise ValueError(f"Kernel {kernel} not found")
apply_kernel(kernel, **kwargs)
apply_kernel(kernel, model=model)
return kwargs.get("model")
return model

View File

@@ -20,8 +20,6 @@ Init Phase:
"""
from typing import Optional
from ....accelerator.helper import get_current_accelerator
from .base import BaseKernel
@@ -73,14 +71,14 @@ class Registry:
return kernel_cls
@classmethod
def get(cls, kernel_id: str) -> Optional[type[BaseKernel]]:
def get(cls, kernel_id: str) -> type[BaseKernel] | None:
"""Retrieves a registered kernel implementation by its ID.
Args:
kernel_id (str): The ID of the kernel to retrieve.
Returns:
Optional[type[BaseKernel]]: The kernel class if found, else ``None``.
type[BaseKernel] | None: The kernel class if found, else ``None``.
"""
return cls._kernels.get(kernel_id)