Former-commit-id: fede563aeb716ba5d1e368fd3e1182e4e580d248
This commit is contained in:
fzc8578
2025-01-10 20:01:22 +08:00
parent 8c2a712247
commit 9e972bc9ec
5 changed files with 45 additions and 13 deletions

View File

@@ -6,6 +6,7 @@ import re
import numpy as np
import torch
from torch.nn.utils.rnn import pad_sequence
from transformers.image_utils import get_image_size, to_numpy_array
from typing_extensions import override
@@ -297,7 +298,6 @@ class CpmOPlugin(BasePlugin):
image_index += 1
final_text += text_chunks[-1]
messages[index]['content'] = final_text
# print(messages)
if len(images) != num_image_tokens:
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
@@ -310,6 +310,7 @@ class CpmOPlugin(BasePlugin):
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
processor: "ProcessorMixin",
**kwargs,
) -> Dict[str, "torch.Tensor"]:
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
@@ -321,6 +322,14 @@ class CpmOPlugin(BasePlugin):
images,
image_resolution=getattr(processor, "image_resolution", 512 * 512),
)
if "valid_image_nums_ls" in kwargs:
valid_image_nums_ls = kwargs['valid_image_nums_ls']
new_images = []
idx = 0
for valid_image_nums in valid_image_nums_ls:
new_images.append(images[idx:idx+valid_image_nums])
idx += valid_image_nums
images = new_images
image_inputs = image_processor(images, do_pad=True, max_slice_nums=image_processor.max_slice_nums, return_tensors="pt")
mm_inputs.update(image_inputs)
@@ -333,6 +342,26 @@ class CpmOPlugin(BasePlugin):
)
return mm_inputs
def trim_and_pad(self, seq, padding_value=0):
return pad_sequence([s for s in seq], batch_first=True, padding_value=padding_value)
def pad_data(self, features):
features['position_ids'] = [torch.arange(input_ids.size(0)).long() for input_ids in features['input_ids']]
features['input_ids'] = self.trim_and_pad(
[input_ids for input_ids in features['input_ids']],
)
features['position_ids'] = self.trim_and_pad(
[position_ids for position_ids in features['position_ids']],
)
features['labels'] = self.trim_and_pad(
[labels for labels in features['labels']],
padding_value=-100,
)
features['attention_mask'] = self.trim_and_pad(
[attention_mask for attention_mask in features['attention_mask']],
)
return features
@override
def get_mm_inputs(
@@ -345,9 +374,9 @@ class CpmOPlugin(BasePlugin):
processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
self._validate_input(images, videos)
mm_inputs = self._get_mm_inputs(images, videos, processor)
image_bounds_list = []
position_ids = []
valid_image_nums_ls = []
for input_ids in batch_ids:
input_ids_ = torch.tensor(input_ids)
start_cond = (input_ids_ == processor.tokenizer.im_start_id) | (input_ids_ == processor.tokenizer.slice_start_id)
@@ -356,6 +385,7 @@ class CpmOPlugin(BasePlugin):
image_start_tokens += 1
image_end_tokens = torch.where(end_cond)[0]
valid_image_nums = max(len(image_start_tokens), len(image_end_tokens))
valid_image_nums_ls.append(valid_image_nums)
image_bounds = torch.hstack(
[
image_start_tokens[:valid_image_nums].unsqueeze(-1),
@@ -363,14 +393,9 @@ class CpmOPlugin(BasePlugin):
]
)
image_bounds_list.append(image_bounds)
position_ids_ = list(range(input_ids_.size(0)))
# print(input_ids_.shape, len(position_ids_)
position_ids.append(position_ids_)
#TODO add pad
position_ids = torch.tensor(position_ids, dtype=torch.int64)
mm_inputs = self._get_mm_inputs(images, videos, processor, valid_image_nums_ls=valid_image_nums_ls)
mm_inputs.update({
"image_bound": image_bounds_list,
"position_ids": position_ids,
})
return mm_inputs