add docstrings, refactor logger

Former-commit-id: c34e489d71f8f539028543ccf8ee92cecedd6276
This commit is contained in:
hiyouga
2024-09-08 00:56:56 +08:00
parent 93d4570a59
commit 7f71276ad8
30 changed files with 334 additions and 57 deletions

View File

@@ -35,6 +35,12 @@ class Response:
class BaseEngine(ABC):
r"""
Base class for inference engine of chat models.
Must implements async methods: chat(), stream_chat() and get_scores().
"""
model: Union["PreTrainedModel", "AsyncLLMEngine"]
tokenizer: "PreTrainedTokenizer"
can_generate: bool
@@ -48,7 +54,11 @@ class BaseEngine(ABC):
data_args: "DataArguments",
finetuning_args: "FinetuningArguments",
generating_args: "GeneratingArguments",
) -> None: ...
) -> None:
r"""
Initializes an inference engine.
"""
...
@abstractmethod
async def chat(
@@ -59,7 +69,11 @@ class BaseEngine(ABC):
image: Optional["ImageInput"] = None,
video: Optional["VideoInput"] = None,
**input_kwargs,
) -> List["Response"]: ...
) -> List["Response"]:
r"""
Gets a list of responses of the chat model.
"""
...
@abstractmethod
async def stream_chat(
@@ -70,11 +84,19 @@ class BaseEngine(ABC):
image: Optional["ImageInput"] = None,
video: Optional["VideoInput"] = None,
**input_kwargs,
) -> AsyncGenerator[str, None]: ...
) -> AsyncGenerator[str, None]:
r"""
Gets the response token-by-token of the chat model.
"""
...
@abstractmethod
async def get_scores(
self,
batch_input: List[str],
**input_kwargs,
) -> List[float]: ...
) -> List[float]:
r"""
Gets a list of scores of the reward model.
"""
...

View File

@@ -37,8 +37,17 @@ def _start_background_loop(loop: "asyncio.AbstractEventLoop") -> None:
class ChatModel:
r"""
General class for chat models. Backed by huggingface or vllm engines.
Supports both sync and async methods.
Sync methods: chat(), stream_chat() and get_scores().
Async methods: achat(), astream_chat() and aget_scores().
"""
def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
model_args, data_args, finetuning_args, generating_args = get_infer_args(args)
self.engine_type = model_args.infer_backend
if model_args.infer_backend == "huggingface":
self.engine: "BaseEngine" = HuggingfaceEngine(model_args, data_args, finetuning_args, generating_args)
elif model_args.infer_backend == "vllm":
@@ -59,6 +68,9 @@ class ChatModel:
video: Optional["VideoInput"] = None,
**input_kwargs,
) -> List["Response"]:
r"""
Gets a list of responses of the chat model.
"""
task = asyncio.run_coroutine_threadsafe(
self.achat(messages, system, tools, image, video, **input_kwargs), self._loop
)
@@ -73,6 +85,9 @@ class ChatModel:
video: Optional["VideoInput"] = None,
**input_kwargs,
) -> List["Response"]:
r"""
Asynchronously gets a list of responses of the chat model.
"""
return await self.engine.chat(messages, system, tools, image, video, **input_kwargs)
def stream_chat(
@@ -84,6 +99,9 @@ class ChatModel:
video: Optional["VideoInput"] = None,
**input_kwargs,
) -> Generator[str, None, None]:
r"""
Gets the response token-by-token of the chat model.
"""
generator = self.astream_chat(messages, system, tools, image, video, **input_kwargs)
while True:
try:
@@ -101,6 +119,9 @@ class ChatModel:
video: Optional["VideoInput"] = None,
**input_kwargs,
) -> AsyncGenerator[str, None]:
r"""
Asynchronously gets the response token-by-token of the chat model.
"""
async for new_token in self.engine.stream_chat(messages, system, tools, image, video, **input_kwargs):
yield new_token
@@ -109,6 +130,9 @@ class ChatModel:
batch_input: List[str],
**input_kwargs,
) -> List[float]:
r"""
Gets a list of scores of the reward model.
"""
task = asyncio.run_coroutine_threadsafe(self.aget_scores(batch_input, **input_kwargs), self._loop)
return task.result()
@@ -117,6 +141,9 @@ class ChatModel:
batch_input: List[str],
**input_kwargs,
) -> List[float]:
r"""
Asynchronously gets a list of scores of the reward model.
"""
return await self.engine.get_scores(batch_input, **input_kwargs)

View File

@@ -20,6 +20,7 @@ from typing import TYPE_CHECKING, Any, AsyncGenerator, Callable, Dict, List, Opt
import torch
from transformers import GenerationConfig, TextIteratorStreamer
from typing_extensions import override
from ..data import get_template_and_fix_tokenizer
from ..extras.constants import IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
@@ -271,6 +272,7 @@ class HuggingfaceEngine(BaseEngine):
return scores
@override
async def chat(
self,
messages: Sequence[Dict[str, str]],
@@ -301,6 +303,7 @@ class HuggingfaceEngine(BaseEngine):
with concurrent.futures.ThreadPoolExecutor() as pool:
return await loop.run_in_executor(pool, self._chat, *input_args)
@override
async def stream_chat(
self,
messages: Sequence[Dict[str, str]],
@@ -336,6 +339,7 @@ class HuggingfaceEngine(BaseEngine):
except StopAsyncIteration:
break
@override
async def get_scores(
self,
batch_input: List[str],

View File

@@ -15,6 +15,8 @@
import uuid
from typing import TYPE_CHECKING, Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Sequence, Union
from typing_extensions import override
from ..data import get_template_and_fix_tokenizer
from ..extras.constants import IMAGE_PLACEHOLDER
from ..extras.logging import get_logger
@@ -191,6 +193,7 @@ class VllmEngine(BaseEngine):
)
return result_generator
@override
async def chat(
self,
messages: Sequence[Dict[str, str]],
@@ -218,6 +221,7 @@ class VllmEngine(BaseEngine):
return results
@override
async def stream_chat(
self,
messages: Sequence[Dict[str, str]],
@@ -234,6 +238,7 @@ class VllmEngine(BaseEngine):
generated_text = result.outputs[0].text
yield delta_text
@override
async def get_scores(
self,
batch_input: List[str],