update readme

Former-commit-id: b8d0170fe0d094acce85dcb5f91775e4685ee055
This commit is contained in:
hiyouga
2024-05-27 18:14:02 +08:00
parent b0d9966663
commit 97a23e1cbe
10 changed files with 71 additions and 62 deletions

View File

@@ -8,7 +8,6 @@ import torch
from transformers import GenerationConfig, TextIteratorStreamer
from ..data import get_template_and_fix_tokenizer
from ..extras.constants import IMAGE_TOKEN
from ..extras.misc import get_logits_processor
from ..model import load_model, load_tokenizer
from .base_engine import BaseEngine, Response
@@ -60,9 +59,9 @@ class HuggingfaceEngine(BaseEngine):
processor is not None
and image is not None
and not hasattr(processor, "image_seq_length")
and IMAGE_TOKEN not in messages[0]["content"]
and template.image_token not in messages[0]["content"]
): # llava-like models
messages[0]["content"] = IMAGE_TOKEN + messages[0]["content"]
messages[0]["content"] = template.image_token + messages[0]["content"]
paired_messages = messages + [{"role": "assistant", "content": ""}]
system = system or generating_args["default_system"]
@@ -75,7 +74,7 @@ class HuggingfaceEngine(BaseEngine):
batch_feature = image_processor(image, return_tensors="pt")
pixel_values = batch_feature.to(model.device)["pixel_values"] # shape (B, C, H, W)
if hasattr(processor, "image_seq_length"): # paligemma models
image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids
prompt_length = len(prompt_ids)

View File

@@ -2,7 +2,6 @@ import uuid
from typing import TYPE_CHECKING, AsyncGenerator, AsyncIterator, Dict, List, Optional, Sequence, Union
from ..data import get_template_and_fix_tokenizer
from ..extras.constants import IMAGE_TOKEN
from ..extras.logging import get_logger
from ..extras.misc import get_device_count, infer_optim_dtype
from ..extras.packages import is_vllm_available
@@ -67,7 +66,7 @@ class VllmEngine(BaseEngine):
patch_size = config.vision_config.patch_size
self.image_feature_size = (image_size // patch_size) ** 2
engine_args["image_input_type"] = "pixel_values"
engine_args["image_token_id"] = self.tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
engine_args["image_token_id"] = self.tokenizer.convert_tokens_to_ids(self.template.image_token)
engine_args["image_input_shape"] = "1,3,{},{}".format(image_size, image_size)
engine_args["image_feature_size"] = self.image_feature_size
if getattr(config, "is_yi_vl_derived_model", None):
@@ -97,9 +96,9 @@ class VllmEngine(BaseEngine):
self.processor is not None
and image is not None
and not hasattr(self.processor, "image_seq_length")
and IMAGE_TOKEN not in messages[0]["content"]
): # llava-like models
messages[0]["content"] = IMAGE_TOKEN * self.image_feature_size + messages[0]["content"]
and self.template.image_token not in messages[0]["content"]
): # llava-like models (TODO: paligemma models)
messages[0]["content"] = self.template.image_token * self.image_feature_size + messages[0]["content"]
paired_messages = messages + [{"role": "assistant", "content": ""}]
system = system or self.generating_args["default_system"]