clean code

Former-commit-id: f54cafd5c7f0383370d1a2f357834a61a97397ce
hiyouga
2024-06-13 01:58:16 +08:00
parent 04d7629abf
commit 0a75224f62
4 changed files with 17 additions and 27 deletions

@@ -1,12 +1,10 @@
 import uuid
 from typing import TYPE_CHECKING, AsyncGenerator, AsyncIterator, Dict, List, Optional, Sequence, Union
 
-from packaging import version
-
 from ..data import get_template_and_fix_tokenizer
 from ..extras.logging import get_logger
 from ..extras.misc import get_device_count
-from ..extras.packages import is_vllm_available, _get_package_version
+from ..extras.packages import is_vllm_available, is_vllm_version_greater_than_0_5
 from ..model import load_config, load_tokenizer
 from ..model.model_utils.visual import LlavaMultiModalProjectorForYiVLForVLLM
 from .base_engine import BaseEngine, Response
@@ -16,7 +14,7 @@ if is_vllm_available():
     from vllm import AsyncEngineArgs, AsyncLLMEngine, RequestOutput, SamplingParams
     from vllm.lora.request import LoRARequest
 
-    if _get_package_version("vllm") >= version.parse("0.5.0"):
+    if is_vllm_version_greater_than_0_5():
         from vllm.multimodal.image import ImagePixelData
     else:
         from vllm.sequence import MultiModalData
@@ -112,9 +110,9 @@ class VllmEngine(BaseEngine):
         if self.processor is not None and image is not None:  # add image features
             image_processor: "BaseImageProcessor" = getattr(self.processor, "image_processor")
             pixel_values = image_processor(image, return_tensors="pt")["pixel_values"]
-            if _get_package_version("vllm") >= version.parse("0.5.0"):
-                multi_modal_data = ImagePixelData(pixel_values)
-            else:
+            if is_vllm_version_greater_than_0_5():
+                multi_modal_data = ImagePixelData(image=pixel_values)
+            else:  # TODO: remove vllm 0.4.3 support
                 multi_modal_data = MultiModalData(type=MultiModalData.Type.IMAGE, data=pixel_values)
         else:
             multi_modal_data = None
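
For reference, a minimal sketch of what the new is_vllm_version_greater_than_0_5 helper (imported from ..extras.packages, whose changes are not shown in this hunk) might look like. The body below is an assumption, not the actual implementation; it only mirrors the version comparison that the removed inline check performed (installed vllm >= 0.5.0), and the caching decorator is likewise assumed.

import functools
import importlib.metadata

from packaging import version


def _get_package_version(name: str) -> "version.Version":
    # Assumed helper body: read the installed version, falling back to 0.0.0 if the package is absent.
    try:
        return version.parse(importlib.metadata.version(name))
    except Exception:
        return version.parse("0.0.0")


@functools.lru_cache
def is_vllm_version_greater_than_0_5() -> bool:
    # Same comparison as the removed inline check: vllm >= 0.5.0.
    return _get_package_version("vllm") >= version.parse("0.5.0")

With a helper like this, the call sites above reduce to a plain boolean call, and the packaging import no longer needs to live in vllm_engine.py.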