add some
Former-commit-id: 58f50b8729083e9ea0fdcf07042b06261670ad57
This commit is contained in:
@@ -265,8 +265,19 @@ class CpmOPlugin(BasePlugin):
|
||||
) -> List[Dict[str, str]]:
|
||||
self._validate_input(images, videos)
|
||||
num_image_tokens = 0
|
||||
num_video_tokens = 0
|
||||
messages = deepcopy(messages)
|
||||
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
|
||||
mm_inputs = {}
|
||||
|
||||
if len(videos) != 0:
|
||||
assert len(images) == 0, "Only support video and image sft seperately"
|
||||
max_slice_nums = 2
|
||||
use_image_id = False
|
||||
mm_inputs = self._get_mm_inputs([], videos, processor)
|
||||
else:
|
||||
max_slice_nums = image_processor.max_slice_nums
|
||||
use_image_id = image_processor.use_image_id
|
||||
|
||||
for message in messages:
|
||||
content = message["content"]
|
||||
@@ -274,15 +285,21 @@ class CpmOPlugin(BasePlugin):
|
||||
num_image_tokens += 1
|
||||
content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1)
|
||||
|
||||
while VIDEO_PLACEHOLDER in content:
|
||||
num_video_tokens += 1
|
||||
content = content.replace(
|
||||
VIDEO_PLACEHOLDER, "{{image}}" * len(mm_inputs["pixel_values"][num_video_tokens - 1]), 1
|
||||
)
|
||||
|
||||
message["content"] = content.replace("{{image}}", "(<image>./</image>)")
|
||||
|
||||
if num_image_tokens > 0:
|
||||
mm_inputs = self._get_mm_inputs(images, videos, processor)
|
||||
mm_inputs = self._get_mm_inputs(images, [], processor)
|
||||
|
||||
if mm_inputs:
|
||||
pattern = "(<image>./</image>)"
|
||||
images, image_sizes = mm_inputs["pixel_values"], mm_inputs["image_sizes"]
|
||||
image_sizes = mm_inputs["image_sizes"]
|
||||
|
||||
image_index = 0
|
||||
for index, message in enumerate(messages):
|
||||
text = message["content"]
|
||||
image_tags = re.findall(pattern, text)
|
||||
@@ -293,19 +310,21 @@ class CpmOPlugin(BasePlugin):
|
||||
final_text
|
||||
+ text_chunks[i]
|
||||
+ image_processor.get_slice_image_placeholder(
|
||||
image_sizes[image_index][i],
|
||||
image_sizes[0][i],
|
||||
i,
|
||||
image_processor.max_slice_nums,
|
||||
image_processor.use_image_id,
|
||||
max_slice_nums,
|
||||
use_image_id,
|
||||
)
|
||||
)
|
||||
image_index += 1
|
||||
final_text += text_chunks[-1]
|
||||
messages[index]["content"] = final_text
|
||||
|
||||
if len(images) != num_image_tokens:
|
||||
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
|
||||
|
||||
if len(videos) != num_video_tokens:
|
||||
raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.")
|
||||
|
||||
return messages
|
||||
|
||||
@override
|
||||
@@ -333,7 +352,7 @@ class CpmOPlugin(BasePlugin):
|
||||
new_images.append(images[idx : idx + valid_image_nums])
|
||||
idx += valid_image_nums
|
||||
images = new_images
|
||||
|
||||
|
||||
image_inputs = image_processor(
|
||||
images, do_pad=True, max_slice_nums=image_processor.max_slice_nums, return_tensors="pt"
|
||||
)
|
||||
@@ -346,6 +365,8 @@ class CpmOPlugin(BasePlugin):
|
||||
video_fps=getattr(processor, "video_fps", 2.0),
|
||||
video_maxlen=getattr(processor, "video_maxlen", 64),
|
||||
)
|
||||
video_inputs = image_processor(videos, do_pad=True, max_slice_nums=2, return_tensors="pt")
|
||||
mm_inputs.update(video_inputs)
|
||||
|
||||
return mm_inputs
|
||||
|
||||
@@ -380,12 +401,9 @@ class CpmOPlugin(BasePlugin):
|
||||
]
|
||||
)
|
||||
image_bounds_list.append(image_bounds)
|
||||
|
||||
mm_inputs = self._get_mm_inputs(images, videos, processor, valid_image_nums_ls=valid_image_nums_ls)
|
||||
mm_inputs.update(
|
||||
{
|
||||
"image_bound": image_bounds_list,
|
||||
}
|
||||
)
|
||||
mm_inputs.update({"image_bound": image_bounds_list})
|
||||
return mm_inputs
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user