add rlhf-v dataset

Former-commit-id: 3fd18fc34a0c994a738504746abfd5548e002437
This commit is contained in:
hiyouga
2024-09-01 22:57:41 +08:00
parent 7621526d22
commit 60cf12727b
12 changed files with 107 additions and 33 deletions

View File

@@ -19,6 +19,23 @@ if TYPE_CHECKING:
from transformers.image_processing_utils import BaseImageProcessor
def _regularize_images(images: Sequence["ImageObject"], processor: "ProcessorMixin") -> List["ImageObject"]:
r"""
Regularizes images to avoid error. Including resizing and mode convert.
"""
images = images[:]
image_resolution = getattr(processor, "image_resolution", 512)
for i in range(len(images)):
if max(images[i].width, images[i].height) > image_resolution:
factor = image_resolution / max(images[i].width, images[i].height)
images[i] = images[i].resize((int(images[i].width * factor), int(images[i].height * factor)))
if images[i].mode != "RGB":
images[i] = images[i].convert("RGB")
return images
def _get_mm_inputs(images: Sequence["ImageObject"], processor: "ProcessorMixin") -> Dict[str, "torch.Tensor"]:
r"""
Processes visual inputs.
@@ -34,6 +51,7 @@ def _get_mm_inputs(images: Sequence["ImageObject"], processor: "ProcessorMixin")
"""
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
if len(images) != 0:
images = _regularize_images(images, processor)
image_inputs = image_processor(images=images, return_tensors="pt")
else: # add NoneType for fake images
image = Image.new("RGB", (64, 64), (255, 255, 255))