fix mm inference

Former-commit-id: fa782c15a07ed40f8a6381acdf2da395377efd04
This commit is contained in:
hiyouga
2024-09-02 01:47:40 +08:00
parent f203a9d78e
commit a7fbae47d5
6 changed files with 19 additions and 23 deletions

View File

@@ -27,7 +27,7 @@ from .vllm_engine import VllmEngine
if TYPE_CHECKING:
from numpy.typing import NDArray
from PIL.Image import Image
from .base_engine import BaseEngine, Response
@@ -56,7 +56,7 @@ class ChatModel:
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["NDArray"] = None,
image: Optional["Image"] = None,
**input_kwargs,
) -> List["Response"]:
task = asyncio.run_coroutine_threadsafe(self.achat(messages, system, tools, image, **input_kwargs), self._loop)
@@ -67,7 +67,7 @@ class ChatModel:
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["NDArray"] = None,
image: Optional["Image"] = None,
**input_kwargs,
) -> List["Response"]:
return await self.engine.chat(messages, system, tools, image, **input_kwargs)
@@ -77,7 +77,7 @@ class ChatModel:
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["NDArray"] = None,
image: Optional["Image"] = None,
**input_kwargs,
) -> Generator[str, None, None]:
generator = self.astream_chat(messages, system, tools, image, **input_kwargs)
@@ -93,7 +93,7 @@ class ChatModel:
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["NDArray"] = None,
image: Optional["Image"] = None,
**input_kwargs,
) -> AsyncGenerator[str, None]:
async for new_token in self.engine.stream_chat(messages, system, tools, image, **input_kwargs):

View File

@@ -30,7 +30,7 @@ from .base_engine import BaseEngine, Response
if TYPE_CHECKING:
from numpy.typing import NDArray
from PIL.Image import Image
from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
from trl import PreTrainedModelWrapper
@@ -78,7 +78,7 @@ class HuggingfaceEngine(BaseEngine):
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["NDArray"] = None,
image: Optional["Image"] = None,
input_kwargs: Optional[Dict[str, Any]] = {},
) -> Tuple[Dict[str, Any], int]:
if image is not None:
@@ -175,7 +175,7 @@ class HuggingfaceEngine(BaseEngine):
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["NDArray"] = None,
image: Optional["Image"] = None,
input_kwargs: Optional[Dict[str, Any]] = {},
) -> List["Response"]:
gen_kwargs, prompt_length = HuggingfaceEngine._process_args(
@@ -210,7 +210,7 @@ class HuggingfaceEngine(BaseEngine):
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["NDArray"] = None,
image: Optional["Image"] = None,
input_kwargs: Optional[Dict[str, Any]] = {},
) -> Callable[[], str]:
gen_kwargs, _ = HuggingfaceEngine._process_args(
@@ -267,7 +267,7 @@ class HuggingfaceEngine(BaseEngine):
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["NDArray"] = None,
image: Optional["Image"] = None,
**input_kwargs,
) -> List["Response"]:
if not self.can_generate:
@@ -295,7 +295,7 @@ class HuggingfaceEngine(BaseEngine):
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["NDArray"] = None,
image: Optional["Image"] = None,
**input_kwargs,
) -> AsyncGenerator[str, None]:
if not self.can_generate:

View File

@@ -39,7 +39,7 @@ if is_vllm_available():
if TYPE_CHECKING:
from numpy.typing import NDArray
from PIL.Image import Image
from transformers.image_processing_utils import BaseImageProcessor
from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
@@ -111,7 +111,7 @@ class VllmEngine(BaseEngine):
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["NDArray"] = None,
image: Optional["Image"] = None,
**input_kwargs,
) -> AsyncIterator["RequestOutput"]:
request_id = "chatcmpl-{}".format(uuid.uuid4().hex)
@@ -195,7 +195,7 @@ class VllmEngine(BaseEngine):
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["NDArray"] = None,
image: Optional["Image"] = None,
**input_kwargs,
) -> List["Response"]:
final_output = None
@@ -221,7 +221,7 @@ class VllmEngine(BaseEngine):
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["NDArray"] = None,
image: Optional["Image"] = None,
**input_kwargs,
) -> AsyncGenerator[str, None]:
generated_text = ""