add test mm plugin

Former-commit-id: ddea5cca5a3174de1dcc7fdee8ec69e77700b6bf
hiyouga
2024-08-31 01:53:38 +08:00
parent 2f6fc27c8b
commit 43654028eb
5 changed files with 192 additions and 45 deletions

View File

@@ -18,36 +18,14 @@ if TYPE_CHECKING:
     from transformers.image_processing_utils import BaseImageProcessor
-def get_pixel_values(images: Sequence["ImageObject"], processor: "ProcessorMixin") -> "torch.Tensor":
+def _get_mm_inputs(images: Sequence["ImageObject"], processor: "ProcessorMixin") -> Dict[str, "torch.Tensor"]:
     r"""
-    Processes visual inputs. (currently only supports a single image)
+    Processes visual inputs.
-    Returns:
+    Returns: (llava and paligemma)
         pixel_values: tensor with shape (B, C, H, W)
-    """
-    image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
-    image = images[0] if len(images) != 0 else Image.new("RGB", (100, 100), (255, 255, 255))
-    return image_processor([image], return_tensors="pt")["pixel_values"]
-def get_paligemma_token_type_ids(input_len: int, processor: "ProcessorMixin") -> List[List[int]]:
-    r"""
-    Gets paligemma token type ids for computing loss.
-    Returns:
-        token_type_ids: shape (1, seq_len)
-    """
-    image_seq_length = getattr(processor, "image_seq_length")
-    return [[0] * image_seq_length + [1] * (input_len - image_seq_length)]
-def get_qwen2vl_image_inputs(
-    images: Sequence["ImageObject"], processor: "ProcessorMixin"
-) -> Dict[str, "torch.Tensor"]:
-    r"""
-    Processes qwen2-vl visual inputs. Supports multiple images.
-    Returns:
+    Returns: (qwen2-vl)
         pixel_values: tensor with shape (num_patches, patch_dim)
         image_grid_thw: tensor with shape (num_images, 3), where the three numbers are time, width, height
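For intuition, here is roughly what these shapes work out to for a single image, assuming qwen2-vl's usual patch_size=14, temporal_patch_size=2, and 3 input channels (none of these values appear in this diff; a sketch, not taken from the commit):

    # a single 224x224 image (hypothetical numbers):
    #   image_grid_thw = [[1, 16, 16]]      # time=1, 224/14 = 16 patches per side
    #   num_patches = 1 * 16 * 16 = 256     # equals the product over image_grid_thw
    #   patch_dim = 3 * 2 * 14 * 14 = 1176  # channels * temporal_patch_size * patch_size^2
    # so pixel_values would have shape (256, 1176)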
@@ -59,9 +37,22 @@ def get_qwen2vl_image_inputs(
     else:
         image = Image.new("RGB", (56, 56), (255, 255, 255))
         image_inputs = image_processor(images=[image], return_tensors="pt")
-        image_inputs["image_grid_thw"][0][0] = 0  # fake image
+        if "image_grid_thw" in image_inputs:  # fake image for qwen2-vl
+            image_inputs["image_grid_thw"][0][0] = 0
 
-    return {"pixel_values": image_inputs["pixel_values"], "image_grid_thw": image_inputs["image_grid_thw"]}
+    return image_inputs
+
+
+def _get_paligemma_token_type_ids(input_len: int, processor: "ProcessorMixin") -> List[List[int]]:
+    r"""
+    Gets paligemma token type ids for computing loss.
+
+    Returns:
+        token_type_ids: shape (1, seq_len)
+    """
+    image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
+    image_seq_length: int = getattr(image_processor, "image_seq_length")
+    return [[0] * image_seq_length + [1] * (input_len - image_seq_length)]
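A quick worked example of the mask that _get_paligemma_token_type_ids produces (image_seq_length=4 is an arbitrary illustrative value):

    # with image_seq_length = 4 and input_len = 7:
    # _get_paligemma_token_type_ids(7, processor) -> [[0, 0, 0, 0, 1, 1, 1]]
    # zeros mark image positions (masked out of the loss), ones mark text positions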
class BasePlugin:
@@ -131,8 +122,9 @@ class LlavaPlugin(BasePlugin):
                 if image_count > 1:
                     raise ValueError("Llava model only accepts one image per sample.")
 
-                content = content.replace(IMAGE_PLACEHOLDER, self.image_token, 1)
+                content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1)
 
+            content = content.replace("{{image}}", self.image_token)
             new_messages.append({"role": message["role"], "content": content})
 
         return new_messages
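The detour through a "{{image}}" sentinel presumably guards against the degenerate case where the tokenizer's image token contains IMAGE_PLACEHOLDER itself; a direct in-place replacement would then never terminate. A minimal sketch of the failure mode (token values are illustrative):

    # hypothetical: IMAGE_PLACEHOLDER == "<image>" and self.image_token == "<image>"
    while IMAGE_PLACEHOLDER in content:  # old single-pass approach
        content = content.replace(IMAGE_PLACEHOLDER, self.image_token, 1)  # no-op, loops forever
    # replacing with "{{image}}" first makes the loop condition strictly shrink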
@@ -143,7 +135,7 @@ class LlavaPlugin(BasePlugin):
         feature_seqlens: Dict[str, int],
         processor: Optional["ProcessorMixin"],
     ) -> Dict[str, Any]:
-        return {"pixel_values": get_pixel_values(images, processor)}
+        return _get_mm_inputs(images, processor)
 
     def process_model_inputs(
         self,
@@ -153,7 +145,8 @@ class LlavaPlugin(BasePlugin):
         processor: Optional["ProcessorMixin"],
     ) -> None:
         mm_inputs = self.get_mm_inputs(images, feature_seqlens, processor)
-        model_inputs["pixel_values"].append(mm_inputs["pixel_values"][0])
+        for key, value in mm_inputs.items():
+            model_inputs[key].append(value[0])
 
 
 class PaliGemmaPlugin(BasePlugin):
@@ -200,9 +193,9 @@ class PaliGemmaPlugin(BasePlugin):
         feature_seqlens: Dict[str, int],
         processor: Optional["ProcessorMixin"],
     ) -> Dict[str, Any]:
-        mm_inputs = {"pixel_values": get_pixel_values(images, processor)}
+        mm_inputs = _get_mm_inputs(images, processor)
         for feature_name, feature_length in feature_seqlens.items():
-            mm_inputs[feature_name] = get_paligemma_token_type_ids(feature_length, processor)
+            mm_inputs[feature_name] = _get_paligemma_token_type_ids(feature_length, processor)
         return mm_inputs
@@ -214,9 +207,8 @@ class PaliGemmaPlugin(BasePlugin):
         processor: Optional["ProcessorMixin"],
     ) -> None:
         mm_inputs = self.get_mm_inputs(images, feature_seqlens, processor)
-        model_inputs["pixel_values"].append(mm_inputs["pixel_values"][0])
-        for feature_name in feature_seqlens.keys():
-            model_inputs[feature_name].append(mm_inputs[feature_name][0])
+        for key, value in mm_inputs.items():
+            model_inputs[key].append(value[0])
 
 
 class Qwen2vlPlugin(BasePlugin):
@@ -229,7 +221,7 @@ class Qwen2vlPlugin(BasePlugin):
         image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
         merge_length: int = getattr(image_processor, "merge_size") ** 2
         if len(images) > 0:
-            image_grid_thw = get_qwen2vl_image_inputs(images, processor)["image_grid_thw"]
+            image_grid_thw = _get_mm_inputs(images, processor)["image_grid_thw"]
 
         index = 0
         new_messages = []
@@ -255,7 +247,7 @@ class Qwen2vlPlugin(BasePlugin):
         feature_seqlens: Dict[str, int],
         processor: Optional["ProcessorMixin"],
     ) -> Dict[str, Any]:
-        return get_qwen2vl_image_inputs(images, processor)
+        return _get_mm_inputs(images, processor)
 
     def process_model_inputs(
         self,
@@ -265,11 +257,12 @@ class Qwen2vlPlugin(BasePlugin):
         processor: Optional["ProcessorMixin"],
     ) -> None:
         mm_inputs = self.get_mm_inputs(images, feature_seqlens, processor)
-        model_inputs["pixel_values"].append(mm_inputs["pixel_values"])
-        model_inputs["image_grid_thw"].append(mm_inputs["image_grid_thw"])
+        for key, value in mm_inputs.items():
+            model_inputs[key].append(value)  # support multi-image
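All three plugins now funnel through the same collation idiom. A minimal sketch of how a caller might accumulate these features over a batch, assuming model_inputs is a defaultdict(list) (the caller is not part of this diff; batch_images and the loop are placeholders):

    from collections import defaultdict

    model_inputs = defaultdict(list)  # assumed caller-side container
    for sample_images in batch_images:  # hypothetical per-sample loop
        mm_inputs = plugin.get_mm_inputs(sample_images, feature_seqlens, processor)
        for key, value in mm_inputs.items():
            model_inputs[key].append(value)  # keys differ per plugin, so nothing is hard-coded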
 PLUGINS = {
     "base": BasePlugin,
     "llava": LlavaPlugin,
     "paligemma": PaliGemmaPlugin,
     "qwen2_vl": Qwen2vlPlugin,

View File

@@ -19,13 +19,14 @@ from ..extras.constants import IMAGE_PLACEHOLDER
 from ..extras.logging import get_logger
 from .data_utils import Role
 from .formatter import EmptyFormatter, FunctionFormatter, StringFormatter, ToolFormatter
-from .mm_plugin import BasePlugin, get_mm_plugin
+from .mm_plugin import get_mm_plugin
 
 
 if TYPE_CHECKING:
     from transformers import PreTrainedTokenizer
 
     from .formatter import SLOTS, Formatter
+    from .mm_plugin import BasePlugin
 
 
 logger = get_logger(__name__)
@@ -209,7 +210,7 @@ def _register_template(
     stop_words: Sequence[str] = [],
     efficient_eos: bool = False,
     replace_eos: bool = False,
-    mm_plugin: "BasePlugin" = BasePlugin(IMAGE_PLACEHOLDER),
+    mm_plugin: "BasePlugin" = get_mm_plugin(name="base", image_token=IMAGE_PLACEHOLDER),
 ) -> None:
     r"""
     Registers a chat template.
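With the factory as the default, a registration call might look like the following (template name, slot format, and image token are illustrative, and unrelated arguments are elided):

    _register_template(
        name="llava",
        format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT: "]),
        mm_plugin=get_mm_plugin(name="llava", image_token="<image>"),
    )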

View File

@@ -99,6 +99,11 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
     except Exception:
         processor = None
 
+    # Avoid loading tokenizer, see:
+    # https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/auto/processing_auto.py#L324
+    if "Processor" not in processor.__class__.__name__:
+        processor = None
+
     return {"tokenizer": tokenizer, "processor": processor}

View File

@@ -46,7 +46,6 @@ def save_model(
     finetuning_type: str,
     checkpoint_path: Union[str, List[str]],
     template: str,
-    visual_inputs: bool,
     export_size: int,
     export_quantization_bit: str,
     export_quantization_dataset: str,
@@ -78,7 +77,6 @@ def save_model(
         model_name_or_path=model_path,
         finetuning_type=finetuning_type,
         template=template,
-        visual_inputs=visual_inputs,
         export_dir=export_dir,
         export_hub_model_id=export_hub_model_id or None,
         export_size=export_size,
@@ -129,7 +127,6 @@ def create_export_tab(engine: "Engine") -> Dict[str, "Component"]:
engine.manager.get_elem_by_id("top.finetuning_type"),
engine.manager.get_elem_by_id("top.checkpoint_path"),
engine.manager.get_elem_by_id("top.template"),
engine.manager.get_elem_by_id("top.visual_inputs"),
export_size,
export_quantization_bit,
export_quantization_dataset,