initial-commit

Former-commit-id: b6a39847a10b417b09db4b5512dd835e9e4ce928
This commit is contained in:
simonJJJ
2024-08-28 16:51:35 +08:00
parent 7272792f65
commit 0f3d54d8a0
8 changed files with 183 additions and 5 deletions

View File

@@ -78,6 +78,20 @@ def get_paligemma_token_type_ids(input_len: int, processor: "ProcessorMixin") ->
return [0] * image_seq_length + [1] * (input_len - image_seq_length)
def get_qwen2vl_image_inputs(images: Sequence["ImageObject"], processor: "ProcessorMixin") -> "NDArray":
r"""
Processes visual inputs. support multi images
"""
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
if len(images) != 0:
image_inputs = image_processor(images=images, return_tensors="pt")
else:
image = Image.new("RGB", (56, 56), (255, 255, 255))
image_inputs = image_processor(images=[image], return_tensors="pt")
image_inputs["image_grid_thw"][0][0] = 0
return {"pixel_values": image_inputs["pixel_values"], "image_grid_thw": image_inputs["image_grid_thw"]}
def infer_seqlen(source_len: int, target_len: int, cutoff_len: int) -> Tuple[int, int]:
r"""
Computes the real sequence length after truncation by the cutoff_len.

View File

@@ -17,10 +17,17 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
from ...extras.constants import IGNORE_INDEX
from ...extras.logging import get_logger
from .processor_utils import get_paligemma_token_type_ids, get_pixel_values, greedy_knapsack, infer_seqlen
from .processor_utils import (
get_paligemma_token_type_ids,
get_pixel_values,
get_qwen2vl_image_inputs,
greedy_knapsack,
infer_seqlen,
)
if TYPE_CHECKING:
from PIL.Image import Image as ImageObject
from transformers import PreTrainedTokenizer, ProcessorMixin
from ...hparams import DataArguments
@@ -36,12 +43,31 @@ def _encode_supervised_example(
system: Optional[str],
tools: Optional[str],
template: "Template",
images: Sequence["ImageObject"],
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
cutoff_len: int,
train_on_prompt: bool,
mask_history: bool,
) -> Tuple[List[int], List[int]]:
if "image_grid_thw" in processor.model_input_names: # qwen2_vl models
image_processor = getattr(processor, "image_processor")
merge_length = image_processor.merge_size**2
if len(images) > 0:
image_grid_thw = get_qwen2vl_image_inputs(images, processor)["image_grid_thw"]
index = 0
for message in prompt:
content = message["content"]
while "<|image_pad|>" in content:
content = content.replace(
"<|image_pad|>",
template.vision_start_token
+ "<|placeholder|>" * (image_grid_thw[index].prod() // merge_length)
+ template.vision_end_token,
1,
)
index += 1
message["content"] = content.replace("<|placeholder|>", "<|image_pad|>")
if processor is not None and not hasattr(processor, "image_seq_length"): # llava-like models
prompt[0]["content"] = template.image_token + prompt[0]["content"]
@@ -107,6 +133,8 @@ def preprocess_supervised_dataset(
model_inputs["pixel_values"] = []
if hasattr(processor, "image_seq_length"): # paligemma models
model_inputs["token_type_ids"] = []
if "image_grid_thw" in processor.model_input_names: # qwen2_vl models
model_inputs["image_grid_thw"] = []
for i in range(len(examples["prompt"])):
if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) != 1:
@@ -129,9 +157,14 @@ def preprocess_supervised_dataset(
model_inputs["attention_mask"].append([1] * len(input_ids))
model_inputs["labels"].append(labels)
if processor is not None:
model_inputs["pixel_values"].append(get_pixel_values(examples["images"][i], processor))
if hasattr(processor, "image_seq_length"): # paligemma models
model_inputs["token_type_ids"].append(get_paligemma_token_type_ids(len(input_ids), processor))
if "image_grid_thw" in processor.model_input_names: # qwen2_vl models
image_inputs = get_qwen2vl_image_inputs(examples["images"][i], processor)
model_inputs["pixel_values"].append(image_inputs["pixel_values"])
model_inputs["image_grid_thw"].append(image_inputs["image_grid_thw"])
else:
model_inputs["pixel_values"].append(get_pixel_values(examples["images"][i], processor))
if hasattr(processor, "image_seq_length"): # paligemma models
model_inputs["token_type_ids"].append(get_paligemma_token_type_ids(len(input_ids), processor))
return model_inputs