[misc] fix script (#6977)

Former-commit-id: 775efa1d8cbdb1b7d122be2a986d47f85214e0a1
Author: hoshi-hiyouga (committed by GitHub)
Date: 2025-02-18 17:00:46 +08:00
Parent: f5cd17881e
Commit: be33ef67fb
3 changed files with 22 additions and 21 deletions


@@ -21,18 +21,13 @@ from ..data import get_template_and_fix_tokenizer
 from ..extras import logging
 from ..extras.constants import AUDIO_PLACEHOLDER, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
 from ..extras.misc import get_device_count
-from ..extras.packages import is_pillow_available, is_vllm_available
+from ..extras.packages import is_vllm_available
 from ..model import load_config, load_tokenizer
 from ..model.model_utils.quantization import QuantizationMethod
 from ..model.model_utils.visual import LlavaMultiModalProjectorForYiVLForVLLM
 from .base_engine import BaseEngine, Response
 
 
-if is_pillow_available():
-    from PIL import Image
-    from PIL.Image import Image as ImageObject
-
-
 if is_vllm_available():
     from vllm import AsyncEngineArgs, AsyncLLMEngine, RequestOutput, SamplingParams
     from vllm.lora.request import LoRARequest
@@ -54,6 +49,7 @@ class VllmEngine(BaseEngine):
         finetuning_args: "FinetuningArguments",
         generating_args: "GeneratingArguments",
     ) -> None:
+        self.model_args = model_args
         config = load_config(model_args)  # may download model from ms hub
         if getattr(config, "quantization_config", None):  # gptq models should use float16
             quantization_config: Dict[str, Any] = getattr(config, "quantization_config", None)
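Note: the hunk above ends inside the dtype selection for quantized checkpoints ("gptq models should use float16"). A minimal standalone sketch of that idea, not the code that follows in this file (the helper name and defaults below are illustrative):

    from typing import Any, Dict, Optional


    def pick_infer_dtype(quantization_config: Optional[Dict[str, Any]], requested: str = "auto") -> str:
        # Illustrative sketch: GPTQ checkpoints are typically served in float16,
        # so an "auto" dtype is overridden when the config reports quant_method == "gptq".
        if quantization_config and quantization_config.get("quant_method", "") == "gptq" and requested == "auto":
            return "float16"
        return requested


    print(pick_infer_dtype({"quant_method": "gptq", "bits": 4}))  # -> "float16"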
@@ -180,15 +176,13 @@ class VllmEngine(BaseEngine):
         )
 
         if images is not None:  # add image features
-            multi_modal_data = {"image": []}
-            for image in images:
-                if not isinstance(image, (str, ImageObject)):
-                    raise ValueError(f"Expected image input is a path or PIL.Image, but got {type(image)}.")
-
-                if isinstance(image, str):
-                    image = Image.open(image).convert("RGB")
-
-                multi_modal_data["image"].append(image)
+            multi_modal_data = {
+                "image": self.template.mm_plugin._regularize_images(
+                    images,
+                    image_max_pixels=self.model_args.image_max_pixels,
+                    image_min_pixels=self.model_args.image_min_pixels,
+                )
+            }
         else:
             multi_modal_data = None
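Note: the engine no longer touches PIL directly; image normalization is delegated to the template's mm_plugin (hence the removed PIL imports above and the stored self.model_args). A rough standalone sketch of what such a regularize step typically does, assuming pixel-count-based resizing; it is illustrative, not the actual mm_plugin helper:

    import math
    from typing import List, Union

    from PIL import Image
    from PIL.Image import Image as ImageObject


    def regularize_images(images: List[Union[str, ImageObject]], image_max_pixels: int, image_min_pixels: int) -> List[ImageObject]:
        # Sketch: open paths, convert to RGB, and rescale so width * height stays within [min, max] pixels.
        results = []
        for image in images:
            if isinstance(image, str):
                image = Image.open(image)
            image = image.convert("RGB")
            if image.width * image.height > image_max_pixels:
                factor = math.sqrt(image_max_pixels / (image.width * image.height))
                image = image.resize((int(image.width * factor), int(image.height * factor)))
            if image.width * image.height < image_min_pixels:
                factor = math.sqrt(image_min_pixels / (image.width * image.height))
                image = image.resize((int(image.width * factor), int(image.height * factor)))
            results.append(image)
        return results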


@@ -1112,9 +1112,13 @@ class Qwen2vlPlugin(BasePlugin):
         self._validate_input(images, videos, audios)
         image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
         merge_length: int = getattr(image_processor, "merge_size") ** 2
-        mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
-        image_grid_thw = mm_inputs.get("image_grid_thw", [])
-        video_grid_thw = mm_inputs.get("video_grid_thw", [])
+        if self.expand_mm_tokens:
+            mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
+            image_grid_thw = mm_inputs.get("image_grid_thw", [])
+            video_grid_thw = mm_inputs.get("video_grid_thw", [])
+        else:
+            image_grid_thw = [None] * len(images)
+            video_grid_thw = [None] * len(videos)
 
         num_image_tokens, num_video_tokens = 0, 0
         messages = deepcopy(messages)
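Note: the new branch skips the processor call when multimodal tokens are not expanded; the grid shapes are only needed to compute how many placeholder tokens each image occupies. A rough sketch of that relationship for Qwen2-VL-style patch grids (values below are illustrative):

    # Each image yields a (t, h, w) patch grid, and the image placeholder is repeated
    # grid.prod() // merge_size**2 times; with expand_mm_tokens disabled there is no
    # grid, so a single placeholder token is kept per image.
    def num_placeholder_tokens(grid_thw, merge_size: int = 2) -> int:
        if grid_thw is None:  # expand_mm_tokens is False -> keep one placeholder
            return 1
        t, h, w = grid_thw
        return (t * h * w) // (merge_size**2)


    print(num_placeholder_tokens((1, 36, 36)))  # 324 tokens for a 36x36 patch grid
    print(num_placeholder_tokens(None))         # 1 placeholder when not expanding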