add docstrings, refactor logger
Former-commit-id: c34e489d71f8f539028543ccf8ee92cecedd6276
This commit is contained in:
@@ -20,6 +20,7 @@ from typing import TYPE_CHECKING, Any, AsyncGenerator, Callable, Dict, List, Opt
|
||||
|
||||
import torch
|
||||
from transformers import GenerationConfig, TextIteratorStreamer
|
||||
from typing_extensions import override
|
||||
|
||||
from ..data import get_template_and_fix_tokenizer
|
||||
from ..extras.constants import IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
|
||||
@@ -271,6 +272,7 @@ class HuggingfaceEngine(BaseEngine):
|
||||
|
||||
return scores
|
||||
|
||||
@override
|
||||
async def chat(
|
||||
self,
|
||||
messages: Sequence[Dict[str, str]],
|
||||
@@ -301,6 +303,7 @@ class HuggingfaceEngine(BaseEngine):
|
||||
with concurrent.futures.ThreadPoolExecutor() as pool:
|
||||
return await loop.run_in_executor(pool, self._chat, *input_args)
|
||||
|
||||
@override
|
||||
async def stream_chat(
|
||||
self,
|
||||
messages: Sequence[Dict[str, str]],
|
||||
@@ -336,6 +339,7 @@ class HuggingfaceEngine(BaseEngine):
|
||||
except StopAsyncIteration:
|
||||
break
|
||||
|
||||
@override
|
||||
async def get_scores(
|
||||
self,
|
||||
batch_input: List[str],
|
||||
|
||||
Reference in New Issue
Block a user