add rlhf-v dataset
Former-commit-id: 3fd18fc34a0c994a738504746abfd5548e002437
This commit is contained in:
@@ -19,6 +19,23 @@ if TYPE_CHECKING:
|
||||
from transformers.image_processing_utils import BaseImageProcessor
|
||||
|
||||
|
||||
def _regularize_images(images: Sequence["ImageObject"], processor: "ProcessorMixin") -> List["ImageObject"]:
|
||||
r"""
|
||||
Regularizes images to avoid error. Including resizing and mode convert.
|
||||
"""
|
||||
images = images[:]
|
||||
image_resolution = getattr(processor, "image_resolution", 512)
|
||||
for i in range(len(images)):
|
||||
if max(images[i].width, images[i].height) > image_resolution:
|
||||
factor = image_resolution / max(images[i].width, images[i].height)
|
||||
images[i] = images[i].resize((int(images[i].width * factor), int(images[i].height * factor)))
|
||||
|
||||
if images[i].mode != "RGB":
|
||||
images[i] = images[i].convert("RGB")
|
||||
|
||||
return images
|
||||
|
||||
|
||||
def _get_mm_inputs(images: Sequence["ImageObject"], processor: "ProcessorMixin") -> Dict[str, "torch.Tensor"]:
|
||||
r"""
|
||||
Processes visual inputs.
|
||||
@@ -34,6 +51,7 @@ def _get_mm_inputs(images: Sequence["ImageObject"], processor: "ProcessorMixin")
|
||||
"""
|
||||
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
|
||||
if len(images) != 0:
|
||||
images = _regularize_images(images, processor)
|
||||
image_inputs = image_processor(images=images, return_tensors="pt")
|
||||
else: # add NoneType for fake images
|
||||
image = Image.new("RGB", (64, 64), (255, 255, 255))
|
||||
|
||||
Reference in New Issue
Block a user