[misc] upgrade format to py39 (#7256)
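In short, the commit migrates type annotations to the Python 3.9 style: the deprecated typing.Dict / typing.List aliases become the builtin generics dict / list (subscriptable since PEP 585), and abstract containers such as Sequence and AsyncGenerator move from typing to collections.abc. A minimal before/after sketch of the pattern (illustrative only, not code from the repo):

    # before: py38-style annotations
    from typing import Dict, List, Sequence

    def extract(messages: Sequence[Dict[str, str]]) -> List[str]:
        return [m["content"] for m in messages]

    # after: py39-style annotations (PEP 585 builtin generics,
    # abstract containers imported from collections.abc)
    from collections.abc import Sequence

    def extract(messages: Sequence[dict[str, str]]) -> list[str]:
        return [m["content"] for m in messages]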
@@ -13,7 +13,8 @@
 # limitations under the License.
 
 import uuid
-from typing import TYPE_CHECKING, Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Sequence, Union
+from collections.abc import AsyncGenerator, AsyncIterator, Sequence
+from typing import TYPE_CHECKING, Any, Optional, Union
 
 from typing_extensions import override
 
@@ -53,7 +54,7 @@ class VllmEngine(BaseEngine):
         self.model_args = model_args
         config = load_config(model_args)  # may download model from ms hub
         if getattr(config, "quantization_config", None):  # gptq models should use float16
-            quantization_config: Dict[str, Any] = getattr(config, "quantization_config", None)
+            quantization_config: dict[str, Any] = getattr(config, "quantization_config", None)
             quant_method = quantization_config.get("quant_method", "")
             if quant_method == QuantizationMethod.GPTQ and model_args.infer_dtype == "auto":
                 model_args.infer_dtype = "float16"
@@ -101,7 +102,7 @@ class VllmEngine(BaseEngine):
 
     async def _generate(
         self,
-        messages: Sequence[Dict[str, str]],
+        messages: Sequence[dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
         images: Optional[Sequence["ImageInput"]] = None,
@@ -143,7 +144,7 @@ class VllmEngine(BaseEngine):
         skip_special_tokens: Optional[bool] = input_kwargs.pop("skip_special_tokens", None)
         max_length: Optional[int] = input_kwargs.pop("max_length", None)
         max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None)
-        stop: Optional[Union[str, List[str]]] = input_kwargs.pop("stop", None)
+        stop: Optional[Union[str, list[str]]] = input_kwargs.pop("stop", None)
 
         if length_penalty is not None:
             logger.warning_rank0("Length penalty is not supported by the vllm engine yet.")
@@ -201,14 +202,14 @@ class VllmEngine(BaseEngine):
     @override
     async def chat(
         self,
-        messages: Sequence[Dict[str, str]],
+        messages: Sequence[dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
         images: Optional[Sequence["ImageInput"]] = None,
         videos: Optional[Sequence["VideoInput"]] = None,
         audios: Optional[Sequence["AudioInput"]] = None,
         **input_kwargs,
-    ) -> List["Response"]:
+    ) -> list["Response"]:
         final_output = None
         generator = await self._generate(messages, system, tools, images, videos, audios, **input_kwargs)
         async for request_output in generator:
@@ -230,7 +231,7 @@ class VllmEngine(BaseEngine):
     @override
     async def stream_chat(
         self,
-        messages: Sequence[Dict[str, str]],
+        messages: Sequence[dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
         images: Optional[Sequence["ImageInput"]] = None,
@@ -248,7 +249,7 @@ class VllmEngine(BaseEngine):
     @override
     async def get_scores(
         self,
-        batch_input: List[str],
+        batch_input: list[str],
         **input_kwargs,
-    ) -> List[float]:
+    ) -> list[float]:
         raise NotImplementedError("vLLM engine does not support get_scores.")
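For context on the quantization_config branch in the second hunk: transformers attaches a quantization_config dict to the config of quantized checkpoints, and its quant_method field identifies GPTQ models, which the code pins to float16. A standalone sketch of that detection, assuming a hypothetical GPTQ checkpoint name (the repo compares against its QuantizationMethod enum rather than a raw string):

    from typing import Any
    from transformers import AutoConfig

    # hypothetical GPTQ checkpoint name, for illustration only
    config = AutoConfig.from_pretrained("some-org/some-model-GPTQ")
    quantization_config: dict[str, Any] = getattr(config, "quantization_config", None) or {}
    if quantization_config.get("quant_method", "") == "gptq":  # enum value in the repo
        infer_dtype = "float16"  # gptq models should use float16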