mirror of
https://github.com/hiyouga/LlamaFactory.git
synced 2026-02-01 20:23:37 +00:00
[v1] use async streamer (#9741)
This commit is contained in:
@@ -24,12 +24,13 @@ Init Phase:
|
||||
import importlib
|
||||
from pathlib import Path
|
||||
|
||||
from ....utils.logging import get_logger
|
||||
from ....utils import logging
|
||||
from ....utils.plugin import BasePlugin
|
||||
from ....utils.types import HFModel
|
||||
from .registry import Registry
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def scan_all_kernels():
|
||||
@@ -110,27 +111,30 @@ class KernelPlugin(BasePlugin):
|
||||
|
||||
|
||||
@KernelPlugin("auto").register()
|
||||
def apply_default_kernels(**kwargs):
|
||||
def apply_default_kernels(model: HFModel, include_kernels: str = None) -> HFModel:
|
||||
"""Applies all default registered kernels to the model.
|
||||
|
||||
Args:
|
||||
**kwargs: Keyword arguments passed to the kernel application function.
|
||||
Typically includes the model instance and the include_kernels configuration.
|
||||
model (HFModel): The model instance to apply kernels to.
|
||||
include_kernels (str, optional): Comma-separated list of kernel IDs to apply.
|
||||
If "auto" or True, applies all default kernels.
|
||||
If None or False, no kernels are applied.
|
||||
Defaults to None.
|
||||
|
||||
Returns:
|
||||
HFModel: The model with applied kernels.
|
||||
"""
|
||||
if not kwargs.get("include_kernels"): # None/False/empty string
|
||||
return kwargs.get("model")
|
||||
elif kwargs.get("include_kernels") == "auto" or kwargs.get("include_kernels") is True: # True/auto
|
||||
if not include_kernels:
|
||||
return model
|
||||
elif include_kernels == "auto" or include_kernels is True:
|
||||
use_kernels = default_kernels.keys()
|
||||
else:
|
||||
use_kernels = kwargs.get("include_kernels").split(",") # "kernel_id1,kernel_id2,kernel_id3"
|
||||
use_kernels = include_kernels.split(",") # "kernel_id1,kernel_id2,kernel_id3"
|
||||
|
||||
for kernel in use_kernels:
|
||||
if kernel not in default_kernels:
|
||||
raise ValueError(f"Kernel {kernel} not found")
|
||||
|
||||
apply_kernel(kernel, **kwargs)
|
||||
apply_kernel(kernel, model=model)
|
||||
|
||||
return kwargs.get("model")
|
||||
return model
|
||||
|
||||
@@ -20,8 +20,6 @@ Init Phase:
|
||||
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from ....accelerator.helper import get_current_accelerator
|
||||
from .base import BaseKernel
|
||||
|
||||
@@ -73,14 +71,14 @@ class Registry:
|
||||
return kernel_cls
|
||||
|
||||
@classmethod
|
||||
def get(cls, kernel_id: str) -> Optional[type[BaseKernel]]:
|
||||
def get(cls, kernel_id: str) -> type[BaseKernel] | None:
|
||||
"""Retrieves a registered kernel implementation by its ID.
|
||||
|
||||
Args:
|
||||
kernel_id (str): The ID of the kernel to retrieve.
|
||||
|
||||
Returns:
|
||||
Optional[type[BaseKernel]]: The kernel class if found, else ``None``.
|
||||
type[BaseKernel] | None: The kernel class if found, else ``None``.
|
||||
"""
|
||||
return cls._kernels.get(kernel_id)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user