[misc] upgrade format to py39 (#7256)
This commit is contained in:
@@ -13,8 +13,9 @@
|
||||
# limitations under the License.
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import AsyncGenerator, Sequence
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, List, Literal, Optional, Sequence, Union
|
||||
from typing import TYPE_CHECKING, Any, Literal, Optional, Union
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -36,8 +37,7 @@ class Response:
|
||||
|
||||
|
||||
class BaseEngine(ABC):
|
||||
r"""
|
||||
Base class for inference engine of chat models.
|
||||
r"""Base class for inference engine of chat models.
|
||||
|
||||
Must implements async methods: chat(), stream_chat() and get_scores().
|
||||
"""
|
||||
@@ -47,7 +47,7 @@ class BaseEngine(ABC):
|
||||
tokenizer: "PreTrainedTokenizer"
|
||||
can_generate: bool
|
||||
template: "Template"
|
||||
generating_args: Dict[str, Any]
|
||||
generating_args: dict[str, Any]
|
||||
|
||||
@abstractmethod
|
||||
def __init__(
|
||||
@@ -57,31 +57,27 @@ class BaseEngine(ABC):
|
||||
finetuning_args: "FinetuningArguments",
|
||||
generating_args: "GeneratingArguments",
|
||||
) -> None:
|
||||
r"""
|
||||
Initializes an inference engine.
|
||||
"""
|
||||
r"""Initialize an inference engine."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def chat(
|
||||
self,
|
||||
messages: Sequence[Dict[str, str]],
|
||||
messages: Sequence[dict[str, str]],
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
images: Optional[Sequence["ImageInput"]] = None,
|
||||
videos: Optional[Sequence["VideoInput"]] = None,
|
||||
audios: Optional[Sequence["AudioInput"]] = None,
|
||||
**input_kwargs,
|
||||
) -> List["Response"]:
|
||||
r"""
|
||||
Gets a list of responses of the chat model.
|
||||
"""
|
||||
) -> list["Response"]:
|
||||
r"""Get a list of responses of the chat model."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def stream_chat(
|
||||
self,
|
||||
messages: Sequence[Dict[str, str]],
|
||||
messages: Sequence[dict[str, str]],
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
images: Optional[Sequence["ImageInput"]] = None,
|
||||
@@ -89,18 +85,14 @@ class BaseEngine(ABC):
|
||||
audios: Optional[Sequence["AudioInput"]] = None,
|
||||
**input_kwargs,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
r"""
|
||||
Gets the response token-by-token of the chat model.
|
||||
"""
|
||||
r"""Get the response token-by-token of the chat model."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def get_scores(
|
||||
self,
|
||||
batch_input: List[str],
|
||||
batch_input: list[str],
|
||||
**input_kwargs,
|
||||
) -> List[float]:
|
||||
r"""
|
||||
Gets a list of scores of the reward model.
|
||||
"""
|
||||
) -> list[float]:
|
||||
r"""Get a list of scores of the reward model."""
|
||||
...
|
||||
|
||||
@@ -17,8 +17,9 @@
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from collections.abc import AsyncGenerator, Generator, Sequence
|
||||
from threading import Thread
|
||||
from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, Generator, List, Optional, Sequence
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
from ..extras.constants import EngineName
|
||||
from ..extras.misc import torch_gc
|
||||
@@ -38,20 +39,19 @@ def _start_background_loop(loop: "asyncio.AbstractEventLoop") -> None:
|
||||
|
||||
|
||||
class ChatModel:
|
||||
r"""
|
||||
General class for chat models. Backed by huggingface or vllm engines.
|
||||
r"""General class for chat models. Backed by huggingface or vllm engines.
|
||||
|
||||
Supports both sync and async methods.
|
||||
Sync methods: chat(), stream_chat() and get_scores().
|
||||
Async methods: achat(), astream_chat() and aget_scores().
|
||||
"""
|
||||
|
||||
def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
|
||||
def __init__(self, args: Optional[dict[str, Any]] = None) -> None:
|
||||
model_args, data_args, finetuning_args, generating_args = get_infer_args(args)
|
||||
if model_args.infer_backend == EngineName.HF:
|
||||
self.engine: "BaseEngine" = HuggingfaceEngine(model_args, data_args, finetuning_args, generating_args)
|
||||
self.engine: BaseEngine = HuggingfaceEngine(model_args, data_args, finetuning_args, generating_args)
|
||||
elif model_args.infer_backend == EngineName.VLLM:
|
||||
self.engine: "BaseEngine" = VllmEngine(model_args, data_args, finetuning_args, generating_args)
|
||||
self.engine: BaseEngine = VllmEngine(model_args, data_args, finetuning_args, generating_args)
|
||||
else:
|
||||
raise NotImplementedError(f"Unknown backend: {model_args.infer_backend}")
|
||||
|
||||
@@ -61,17 +61,15 @@ class ChatModel:
|
||||
|
||||
def chat(
|
||||
self,
|
||||
messages: Sequence[Dict[str, str]],
|
||||
messages: Sequence[dict[str, str]],
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
images: Optional[Sequence["ImageInput"]] = None,
|
||||
videos: Optional[Sequence["VideoInput"]] = None,
|
||||
audios: Optional[Sequence["AudioInput"]] = None,
|
||||
**input_kwargs,
|
||||
) -> List["Response"]:
|
||||
r"""
|
||||
Gets a list of responses of the chat model.
|
||||
"""
|
||||
) -> list["Response"]:
|
||||
r"""Get a list of responses of the chat model."""
|
||||
task = asyncio.run_coroutine_threadsafe(
|
||||
self.achat(messages, system, tools, images, videos, audios, **input_kwargs), self._loop
|
||||
)
|
||||
@@ -79,22 +77,20 @@ class ChatModel:
|
||||
|
||||
async def achat(
|
||||
self,
|
||||
messages: Sequence[Dict[str, str]],
|
||||
messages: Sequence[dict[str, str]],
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
images: Optional[Sequence["ImageInput"]] = None,
|
||||
videos: Optional[Sequence["VideoInput"]] = None,
|
||||
audios: Optional[Sequence["AudioInput"]] = None,
|
||||
**input_kwargs,
|
||||
) -> List["Response"]:
|
||||
r"""
|
||||
Asynchronously gets a list of responses of the chat model.
|
||||
"""
|
||||
) -> list["Response"]:
|
||||
r"""Asynchronously get a list of responses of the chat model."""
|
||||
return await self.engine.chat(messages, system, tools, images, videos, audios, **input_kwargs)
|
||||
|
||||
def stream_chat(
|
||||
self,
|
||||
messages: Sequence[Dict[str, str]],
|
||||
messages: Sequence[dict[str, str]],
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
images: Optional[Sequence["ImageInput"]] = None,
|
||||
@@ -102,9 +98,7 @@ class ChatModel:
|
||||
audios: Optional[Sequence["AudioInput"]] = None,
|
||||
**input_kwargs,
|
||||
) -> Generator[str, None, None]:
|
||||
r"""
|
||||
Gets the response token-by-token of the chat model.
|
||||
"""
|
||||
r"""Get the response token-by-token of the chat model."""
|
||||
generator = self.astream_chat(messages, system, tools, images, videos, audios, **input_kwargs)
|
||||
while True:
|
||||
try:
|
||||
@@ -115,7 +109,7 @@ class ChatModel:
|
||||
|
||||
async def astream_chat(
|
||||
self,
|
||||
messages: Sequence[Dict[str, str]],
|
||||
messages: Sequence[dict[str, str]],
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
images: Optional[Sequence["ImageInput"]] = None,
|
||||
@@ -123,9 +117,7 @@ class ChatModel:
|
||||
audios: Optional[Sequence["AudioInput"]] = None,
|
||||
**input_kwargs,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
r"""
|
||||
Asynchronously gets the response token-by-token of the chat model.
|
||||
"""
|
||||
r"""Asynchronously get the response token-by-token of the chat model."""
|
||||
async for new_token in self.engine.stream_chat(
|
||||
messages, system, tools, images, videos, audios, **input_kwargs
|
||||
):
|
||||
@@ -133,23 +125,19 @@ class ChatModel:
|
||||
|
||||
def get_scores(
|
||||
self,
|
||||
batch_input: List[str],
|
||||
batch_input: list[str],
|
||||
**input_kwargs,
|
||||
) -> List[float]:
|
||||
r"""
|
||||
Gets a list of scores of the reward model.
|
||||
"""
|
||||
) -> list[float]:
|
||||
r"""Get a list of scores of the reward model."""
|
||||
task = asyncio.run_coroutine_threadsafe(self.aget_scores(batch_input, **input_kwargs), self._loop)
|
||||
return task.result()
|
||||
|
||||
async def aget_scores(
|
||||
self,
|
||||
batch_input: List[str],
|
||||
batch_input: list[str],
|
||||
**input_kwargs,
|
||||
) -> List[float]:
|
||||
r"""
|
||||
Asynchronously gets a list of scores of the reward model.
|
||||
"""
|
||||
) -> list[float]:
|
||||
r"""Asynchronously get a list of scores of the reward model."""
|
||||
return await self.engine.get_scores(batch_input, **input_kwargs)
|
||||
|
||||
|
||||
|
||||
@@ -15,8 +15,9 @@
|
||||
import asyncio
|
||||
import concurrent.futures
|
||||
import os
|
||||
from collections.abc import AsyncGenerator, Sequence
|
||||
from threading import Thread
|
||||
from typing import TYPE_CHECKING, Any, AsyncGenerator, Callable, Dict, List, Optional, Sequence, Tuple, Union
|
||||
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
from transformers import GenerationConfig, TextIteratorStreamer
|
||||
@@ -76,15 +77,15 @@ class HuggingfaceEngine(BaseEngine):
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
processor: Optional["ProcessorMixin"],
|
||||
template: "Template",
|
||||
generating_args: Dict[str, Any],
|
||||
messages: Sequence[Dict[str, str]],
|
||||
generating_args: dict[str, Any],
|
||||
messages: Sequence[dict[str, str]],
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
images: Optional[Sequence["ImageInput"]] = None,
|
||||
videos: Optional[Sequence["VideoInput"]] = None,
|
||||
audios: Optional[Sequence["AudioInput"]] = None,
|
||||
input_kwargs: Optional[Dict[str, Any]] = {},
|
||||
) -> Tuple[Dict[str, Any], int]:
|
||||
input_kwargs: Optional[dict[str, Any]] = {},
|
||||
) -> tuple[dict[str, Any], int]:
|
||||
mm_input_dict = {"images": [], "videos": [], "audios": [], "imglens": [0], "vidlens": [0], "audlens": [0]}
|
||||
if images is not None:
|
||||
mm_input_dict.update({"images": images, "imglens": [len(images)]})
|
||||
@@ -130,7 +131,7 @@ class HuggingfaceEngine(BaseEngine):
|
||||
skip_special_tokens: Optional[bool] = input_kwargs.pop("skip_special_tokens", None)
|
||||
max_length: Optional[int] = input_kwargs.pop("max_length", None)
|
||||
max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None)
|
||||
stop: Optional[Union[str, List[str]]] = input_kwargs.pop("stop", None)
|
||||
stop: Optional[Union[str, list[str]]] = input_kwargs.pop("stop", None)
|
||||
|
||||
if stop is not None:
|
||||
logger.warning_rank0("Stop parameter is not supported by the huggingface engine yet.")
|
||||
@@ -217,15 +218,15 @@ class HuggingfaceEngine(BaseEngine):
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
processor: Optional["ProcessorMixin"],
|
||||
template: "Template",
|
||||
generating_args: Dict[str, Any],
|
||||
messages: Sequence[Dict[str, str]],
|
||||
generating_args: dict[str, Any],
|
||||
messages: Sequence[dict[str, str]],
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
images: Optional[Sequence["ImageInput"]] = None,
|
||||
videos: Optional[Sequence["VideoInput"]] = None,
|
||||
audios: Optional[Sequence["AudioInput"]] = None,
|
||||
input_kwargs: Optional[Dict[str, Any]] = {},
|
||||
) -> List["Response"]:
|
||||
input_kwargs: Optional[dict[str, Any]] = {},
|
||||
) -> list["Response"]:
|
||||
gen_kwargs, prompt_length = HuggingfaceEngine._process_args(
|
||||
model,
|
||||
tokenizer,
|
||||
@@ -272,14 +273,14 @@ class HuggingfaceEngine(BaseEngine):
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
processor: Optional["ProcessorMixin"],
|
||||
template: "Template",
|
||||
generating_args: Dict[str, Any],
|
||||
messages: Sequence[Dict[str, str]],
|
||||
generating_args: dict[str, Any],
|
||||
messages: Sequence[dict[str, str]],
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
images: Optional[Sequence["ImageInput"]] = None,
|
||||
videos: Optional[Sequence["VideoInput"]] = None,
|
||||
audios: Optional[Sequence["AudioInput"]] = None,
|
||||
input_kwargs: Optional[Dict[str, Any]] = {},
|
||||
input_kwargs: Optional[dict[str, Any]] = {},
|
||||
) -> Callable[[], str]:
|
||||
gen_kwargs, _ = HuggingfaceEngine._process_args(
|
||||
model,
|
||||
@@ -317,12 +318,12 @@ class HuggingfaceEngine(BaseEngine):
|
||||
def _get_scores(
|
||||
model: "PreTrainedModelWrapper",
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
batch_input: List[str],
|
||||
input_kwargs: Optional[Dict[str, Any]] = {},
|
||||
) -> List[float]:
|
||||
batch_input: list[str],
|
||||
input_kwargs: Optional[dict[str, Any]] = {},
|
||||
) -> list[float]:
|
||||
max_length: Optional[int] = input_kwargs.pop("max_length", None)
|
||||
device = getattr(model.pretrained_model, "device", "cuda")
|
||||
inputs: Dict[str, "torch.Tensor"] = tokenizer(
|
||||
inputs: dict[str, torch.Tensor] = tokenizer(
|
||||
batch_input,
|
||||
padding=True,
|
||||
truncation=True,
|
||||
@@ -330,21 +331,21 @@ class HuggingfaceEngine(BaseEngine):
|
||||
return_tensors="pt",
|
||||
add_special_tokens=False,
|
||||
).to(device)
|
||||
values: "torch.Tensor" = model(**inputs, return_dict=True, use_cache=False)[-1]
|
||||
values: torch.Tensor = model(**inputs, return_dict=True, use_cache=False)[-1]
|
||||
scores = values.gather(dim=-1, index=(inputs["attention_mask"].sum(dim=-1, keepdim=True) - 1))
|
||||
return scores
|
||||
|
||||
@override
|
||||
async def chat(
|
||||
self,
|
||||
messages: Sequence[Dict[str, str]],
|
||||
messages: Sequence[dict[str, str]],
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
images: Optional[Sequence["ImageInput"]] = None,
|
||||
videos: Optional[Sequence["VideoInput"]] = None,
|
||||
audios: Optional[Sequence["AudioInput"]] = None,
|
||||
**input_kwargs,
|
||||
) -> List["Response"]:
|
||||
) -> list["Response"]:
|
||||
if not self.can_generate:
|
||||
raise ValueError("The current model does not support `chat`.")
|
||||
|
||||
@@ -370,7 +371,7 @@ class HuggingfaceEngine(BaseEngine):
|
||||
@override
|
||||
async def stream_chat(
|
||||
self,
|
||||
messages: Sequence[Dict[str, str]],
|
||||
messages: Sequence[dict[str, str]],
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
images: Optional[Sequence["ImageInput"]] = None,
|
||||
@@ -408,9 +409,9 @@ class HuggingfaceEngine(BaseEngine):
|
||||
@override
|
||||
async def get_scores(
|
||||
self,
|
||||
batch_input: List[str],
|
||||
batch_input: list[str],
|
||||
**input_kwargs,
|
||||
) -> List[float]:
|
||||
) -> list[float]:
|
||||
if self.can_generate:
|
||||
raise ValueError("Cannot get scores using an auto-regressive model.")
|
||||
|
||||
|
||||
@@ -13,7 +13,8 @@
|
||||
# limitations under the License.
|
||||
|
||||
import uuid
|
||||
from typing import TYPE_CHECKING, Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Sequence, Union
|
||||
from collections.abc import AsyncGenerator, AsyncIterator, Sequence
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
@@ -53,7 +54,7 @@ class VllmEngine(BaseEngine):
|
||||
self.model_args = model_args
|
||||
config = load_config(model_args) # may download model from ms hub
|
||||
if getattr(config, "quantization_config", None): # gptq models should use float16
|
||||
quantization_config: Dict[str, Any] = getattr(config, "quantization_config", None)
|
||||
quantization_config: dict[str, Any] = getattr(config, "quantization_config", None)
|
||||
quant_method = quantization_config.get("quant_method", "")
|
||||
if quant_method == QuantizationMethod.GPTQ and model_args.infer_dtype == "auto":
|
||||
model_args.infer_dtype = "float16"
|
||||
@@ -101,7 +102,7 @@ class VllmEngine(BaseEngine):
|
||||
|
||||
async def _generate(
|
||||
self,
|
||||
messages: Sequence[Dict[str, str]],
|
||||
messages: Sequence[dict[str, str]],
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
images: Optional[Sequence["ImageInput"]] = None,
|
||||
@@ -143,7 +144,7 @@ class VllmEngine(BaseEngine):
|
||||
skip_special_tokens: Optional[bool] = input_kwargs.pop("skip_special_tokens", None)
|
||||
max_length: Optional[int] = input_kwargs.pop("max_length", None)
|
||||
max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None)
|
||||
stop: Optional[Union[str, List[str]]] = input_kwargs.pop("stop", None)
|
||||
stop: Optional[Union[str, list[str]]] = input_kwargs.pop("stop", None)
|
||||
|
||||
if length_penalty is not None:
|
||||
logger.warning_rank0("Length penalty is not supported by the vllm engine yet.")
|
||||
@@ -201,14 +202,14 @@ class VllmEngine(BaseEngine):
|
||||
@override
|
||||
async def chat(
|
||||
self,
|
||||
messages: Sequence[Dict[str, str]],
|
||||
messages: Sequence[dict[str, str]],
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
images: Optional[Sequence["ImageInput"]] = None,
|
||||
videos: Optional[Sequence["VideoInput"]] = None,
|
||||
audios: Optional[Sequence["AudioInput"]] = None,
|
||||
**input_kwargs,
|
||||
) -> List["Response"]:
|
||||
) -> list["Response"]:
|
||||
final_output = None
|
||||
generator = await self._generate(messages, system, tools, images, videos, audios, **input_kwargs)
|
||||
async for request_output in generator:
|
||||
@@ -230,7 +231,7 @@ class VllmEngine(BaseEngine):
|
||||
@override
|
||||
async def stream_chat(
|
||||
self,
|
||||
messages: Sequence[Dict[str, str]],
|
||||
messages: Sequence[dict[str, str]],
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
images: Optional[Sequence["ImageInput"]] = None,
|
||||
@@ -248,7 +249,7 @@ class VllmEngine(BaseEngine):
|
||||
@override
|
||||
async def get_scores(
|
||||
self,
|
||||
batch_input: List[str],
|
||||
batch_input: list[str],
|
||||
**input_kwargs,
|
||||
) -> List[float]:
|
||||
) -> list[float]:
|
||||
raise NotImplementedError("vLLM engine does not support get_scores.")
|
||||
|
||||
Reference in New Issue
Block a user