[misc] upgrade format to py39 (#7256)

This commit is contained in:
hoshi-hiyouga
2025-03-12 00:08:41 +08:00
committed by GitHub
parent 5995800bce
commit 264538cb26
113 changed files with 984 additions and 1407 deletions

View File

@@ -17,8 +17,9 @@
import asyncio
import os
from collections.abc import AsyncGenerator, Generator, Sequence
from threading import Thread
from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, Generator, List, Optional, Sequence
from typing import TYPE_CHECKING, Any, Optional
from ..extras.constants import EngineName
from ..extras.misc import torch_gc
@@ -38,20 +39,19 @@ def _start_background_loop(loop: "asyncio.AbstractEventLoop") -> None:
class ChatModel:
r"""
General class for chat models. Backed by huggingface or vllm engines.
r"""General class for chat models. Backed by huggingface or vllm engines.
Supports both sync and async methods.
Sync methods: chat(), stream_chat() and get_scores().
Async methods: achat(), astream_chat() and aget_scores().
"""
def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
def __init__(self, args: Optional[dict[str, Any]] = None) -> None:
model_args, data_args, finetuning_args, generating_args = get_infer_args(args)
if model_args.infer_backend == EngineName.HF:
self.engine: "BaseEngine" = HuggingfaceEngine(model_args, data_args, finetuning_args, generating_args)
self.engine: BaseEngine = HuggingfaceEngine(model_args, data_args, finetuning_args, generating_args)
elif model_args.infer_backend == EngineName.VLLM:
self.engine: "BaseEngine" = VllmEngine(model_args, data_args, finetuning_args, generating_args)
self.engine: BaseEngine = VllmEngine(model_args, data_args, finetuning_args, generating_args)
else:
raise NotImplementedError(f"Unknown backend: {model_args.infer_backend}")
@@ -61,17 +61,15 @@ class ChatModel:
def chat(
self,
messages: Sequence[Dict[str, str]],
messages: Sequence[dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
images: Optional[Sequence["ImageInput"]] = None,
videos: Optional[Sequence["VideoInput"]] = None,
audios: Optional[Sequence["AudioInput"]] = None,
**input_kwargs,
) -> List["Response"]:
r"""
Gets a list of responses of the chat model.
"""
) -> list["Response"]:
r"""Get a list of responses of the chat model."""
task = asyncio.run_coroutine_threadsafe(
self.achat(messages, system, tools, images, videos, audios, **input_kwargs), self._loop
)
@@ -79,22 +77,20 @@ class ChatModel:
async def achat(
self,
messages: Sequence[Dict[str, str]],
messages: Sequence[dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
images: Optional[Sequence["ImageInput"]] = None,
videos: Optional[Sequence["VideoInput"]] = None,
audios: Optional[Sequence["AudioInput"]] = None,
**input_kwargs,
) -> List["Response"]:
r"""
Asynchronously gets a list of responses of the chat model.
"""
) -> list["Response"]:
r"""Asynchronously get a list of responses of the chat model."""
return await self.engine.chat(messages, system, tools, images, videos, audios, **input_kwargs)
def stream_chat(
self,
messages: Sequence[Dict[str, str]],
messages: Sequence[dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
images: Optional[Sequence["ImageInput"]] = None,
@@ -102,9 +98,7 @@ class ChatModel:
audios: Optional[Sequence["AudioInput"]] = None,
**input_kwargs,
) -> Generator[str, None, None]:
r"""
Gets the response token-by-token of the chat model.
"""
r"""Get the response token-by-token of the chat model."""
generator = self.astream_chat(messages, system, tools, images, videos, audios, **input_kwargs)
while True:
try:
@@ -115,7 +109,7 @@ class ChatModel:
async def astream_chat(
self,
messages: Sequence[Dict[str, str]],
messages: Sequence[dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
images: Optional[Sequence["ImageInput"]] = None,
@@ -123,9 +117,7 @@ class ChatModel:
audios: Optional[Sequence["AudioInput"]] = None,
**input_kwargs,
) -> AsyncGenerator[str, None]:
r"""
Asynchronously gets the response token-by-token of the chat model.
"""
r"""Asynchronously get the response token-by-token of the chat model."""
async for new_token in self.engine.stream_chat(
messages, system, tools, images, videos, audios, **input_kwargs
):
@@ -133,23 +125,19 @@ class ChatModel:
def get_scores(
self,
batch_input: List[str],
batch_input: list[str],
**input_kwargs,
) -> List[float]:
r"""
Gets a list of scores of the reward model.
"""
) -> list[float]:
r"""Get a list of scores of the reward model."""
task = asyncio.run_coroutine_threadsafe(self.aget_scores(batch_input, **input_kwargs), self._loop)
return task.result()
async def aget_scores(
self,
batch_input: List[str],
batch_input: list[str],
**input_kwargs,
) -> List[float]:
r"""
Asynchronously gets a list of scores of the reward model.
"""
) -> list[float]:
r"""Asynchronously get a list of scores of the reward model."""
return await self.engine.get_scores(batch_input, **input_kwargs)