add rlhf-v dataset
Former-commit-id: 3fd18fc34a0c994a738504746abfd5548e002437
@@ -14,7 +14,7 @@
 import os
 from functools import partial
-from typing import TYPE_CHECKING, Any, Dict, List, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Sequence, Union

 from datasets import Features

@@ -33,19 +33,17 @@ if TYPE_CHECKING:
 logger = get_logger(__name__)


-def _convert_images(images: List[Any], dataset_attr: "DatasetAttr", data_args: "DataArguments") -> List[Any]:
+def _convert_images(images: Sequence[Any], dataset_attr: "DatasetAttr", data_args: "DataArguments") -> List[Any]:
     r"""
     Optionally concatenates the image path to the dataset dir when loading from the local disk.
     """
-    outputs = []
+    images = images[:]
     if dataset_attr.load_from in ["script", "file"]:
-        for image in images:
-            if isinstance(image, str) and os.path.isfile(os.path.join(data_args.dataset_dir, image)):
-                outputs.append(os.path.join(data_args.dataset_dir, image))
-            else:
-                outputs.append(image)
+        for i in range(len(images)):
+            if isinstance(images[i], str) and os.path.isfile(os.path.join(data_args.dataset_dir, images[i])):
+                images[i] = os.path.join(data_args.dataset_dir, images[i])

-    return outputs
+    return images


 def convert_alpaca(
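The hunk above rewrites _convert_images to copy the incoming sequence once (images[:]) and patch entries in place, instead of accumulating into a separate outputs list, and widens the parameter annotation to Sequence[Any]. A minimal sketch of the resulting behavior, with hypothetical stand-ins for DatasetAttr and DataArguments reduced to the two attributes the function actually reads:

    import os
    from types import SimpleNamespace

    # Hypothetical stand-ins: only load_from and dataset_dir are read.
    dataset_attr = SimpleNamespace(load_from="file")
    data_args = SimpleNamespace(dataset_dir="data")

    def _convert_images(images, dataset_attr, data_args):
        images = images[:]  # shallow copy, so the caller's list is left untouched
        if dataset_attr.load_from in ["script", "file"]:
            for i in range(len(images)):
                if isinstance(images[i], str) and os.path.isfile(os.path.join(data_args.dataset_dir, images[i])):
                    images[i] = os.path.join(data_args.dataset_dir, images[i])

        return images

    # Relative paths that exist under dataset_dir get prefixed; missing files
    # and non-string entries (e.g. already-loaded PIL images) pass through.
    print(_convert_images(["images/cat.png", "missing.png"], dataset_attr, data_args))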
@@ -142,15 +142,15 @@ class PairwiseDataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
                 "attention_mask": feature["{}_attention_mask".format(key)],
                 "labels": feature["{}_labels".format(key)],
             }
-            if "{}_token_type_ids".format(key) in feature:
-                target_feature["token_type_ids"] = feature["{}_token_type_ids".format(key)]
-
             if "pixel_values" in feature:  # image data are the same for chosen and rejected
                 target_feature["pixel_values"] = feature["pixel_values"]

+            if "image_grid_thw" in feature:
+                target_feature["image_grid_thw"] = feature["image_grid_thw"]
+
+            if "{}_token_type_ids".format(key) in feature:
+                target_feature["token_type_ids"] = feature["{}_token_type_ids".format(key)]
+
             concatenated_features.append(target_feature)

         return super().__call__(concatenated_features)

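For readers skimming the collator hunk above: the functional change is the new image_grid_thw pass-through (emitted by processors such as Qwen2-VL that report per-image grid sizes), with the token_type_ids handling moved after it. The surrounding flow, which the diff only shows in part, flattens each pairwise example into a chosen copy and a rejected copy that share the same image tensors; a reduced sketch under that assumption, with invented feature dicts:

    # Sketch of the pairwise collation flow: each example carries
    # chosen_*/rejected_* tensors plus shared image tensors, and the
    # collator emits 2 * batch_size plain examples.
    def concatenate_pairwise(features):
        concatenated_features = []
        for key in ("chosen", "rejected"):
            for feature in features:
                target_feature = {
                    "input_ids": feature["{}_input_ids".format(key)],
                    "attention_mask": feature["{}_attention_mask".format(key)],
                    "labels": feature["{}_labels".format(key)],
                }
                # Image tensors are shared between the chosen and rejected copies.
                for image_key in ("pixel_values", "image_grid_thw"):
                    if image_key in feature:
                        target_feature[image_key] = feature[image_key]
                if "{}_token_type_ids".format(key) in feature:
                    target_feature["token_type_ids"] = feature["{}_token_type_ids".format(key)]
                concatenated_features.append(target_feature)

        return concatenated_features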
@@ -177,16 +177,16 @@ class KTODataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
                 "attention_mask": feature["kl_attention_mask"],
                 "labels": feature["kl_labels"],
             }
-            if "token_type_ids" in feature:
-                target_feature["token_type_ids"] = feature["token_type_ids"]
-                kl_feature["token_type_ids"] = feature["kl_token_type_ids"]
-
             if "pixel_values" in feature:
                 target_feature["pixel_values"] = feature["pixel_values"]

+            if "image_grid_thw" in feature:
+                target_feature["image_grid_thw"] = feature["image_grid_thw"]
+
+            if "token_type_ids" in feature:
+                target_feature["token_type_ids"] = feature["token_type_ids"]
+                kl_feature["token_type_ids"] = feature["kl_token_type_ids"]
+
             target_features.append(target_feature)
             kl_features.append(kl_feature)
             kto_tags.append(feature["kto_tags"])

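The KTO hunk mirrors the pairwise one: image_grid_thw is forwarded and token_type_ids moves below the image fields. One detail worth noting is that token_type_ids is the only key copied into both the target batch and the kl_* batch used for KTO's KL reference term, while pixel_values and image_grid_thw stay on the target batch only. A reduced sketch with invented feature dicts:

    # Sketch: each KTO example yields a target feature and a "kl" feature,
    # collated as two parallel batches.
    def split_kto_features(features):
        target_features, kl_features, kto_tags = [], [], []
        for feature in features:
            target_feature = {k: feature[k] for k in ("input_ids", "attention_mask", "labels")}
            kl_feature = {k: feature["kl_" + k] for k in ("input_ids", "attention_mask", "labels")}
            for image_key in ("pixel_values", "image_grid_thw"):
                if image_key in feature:
                    target_feature[image_key] = feature[image_key]  # images only on the target side
            if "token_type_ids" in feature:
                target_feature["token_type_ids"] = feature["token_type_ids"]
                kl_feature["token_type_ids"] = feature["kl_token_type_ids"]
            target_features.append(target_feature)
            kl_features.append(kl_feature)
            kto_tags.append(feature["kto_tags"])

        return target_features, kl_features, kto_tags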
@@ -19,6 +19,23 @@ if TYPE_CHECKING:
     from transformers.image_processing_utils import BaseImageProcessor


+def _regularize_images(images: Sequence["ImageObject"], processor: "ProcessorMixin") -> List["ImageObject"]:
+    r"""
+    Regularizes images to avoid errors, including resizing and mode conversion.
+    """
+    images = images[:]
+    image_resolution = getattr(processor, "image_resolution", 512)
+    for i in range(len(images)):
+        if max(images[i].width, images[i].height) > image_resolution:
+            factor = image_resolution / max(images[i].width, images[i].height)
+            images[i] = images[i].resize((int(images[i].width * factor), int(images[i].height * factor)))
+
+        if images[i].mode != "RGB":
+            images[i] = images[i].convert("RGB")
+
+    return images
+
+
 def _get_mm_inputs(images: Sequence["ImageObject"], processor: "ProcessorMixin") -> Dict[str, "torch.Tensor"]:
     r"""
     Processes visual inputs.

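A note on the resize logic in _regularize_images above: the factor is computed from the longer side, so aspect ratio is preserved and the longer side shrinks to image_resolution (default 512); images already within bounds keep their original size, and any non-RGB mode (RGBA, palette, grayscale) is normalized afterwards. A quick standalone check of the arithmetic, assuming Pillow is installed:

    from PIL import Image

    image = Image.new("RGBA", (2048, 1024))  # oversized, non-RGB test image
    image_resolution = 512

    if max(image.width, image.height) > image_resolution:
        factor = image_resolution / max(image.width, image.height)  # 512 / 2048 = 0.25
        image = image.resize((int(image.width * factor), int(image.height * factor)))

    if image.mode != "RGB":
        image = image.convert("RGB")

    print(image.size, image.mode)  # -> (512, 256) RGB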
@@ -34,6 +51,7 @@ def _get_mm_inputs(images: Sequence["ImageObject"], processor: "ProcessorMixin")
     """
     image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
     if len(images) != 0:
+        images = _regularize_images(images, processor)
         image_inputs = image_processor(images=images, return_tensors="pt")
     else:  # add NoneType for fake images
         image = Image.new("RGB", (64, 64), (255, 255, 255))

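Putting the two hunks together: _get_mm_inputs now regularizes real images before invoking the image processor, and the diff is truncated inside the else branch, which (per its comment) builds a 64x64 white placeholder when the batch contains no images. A sketch of the presumed shape of the function, where the placeholder is assumed to be fed through the same processor so the returned tensors keep a consistent schema:

    from PIL import Image

    def get_mm_inputs_sketch(images, processor):
        # _regularize_images is the helper added in the hunk above.
        image_processor = getattr(processor, "image_processor")
        if len(images) != 0:
            images = _regularize_images(images, processor)
            return image_processor(images=images, return_tensors="pt")

        # No images in this batch: feed a fake white image so the processor
        # still returns tensors of the expected shape (assumption: the code
        # cut off by the diff does something equivalent with this image).
        image = Image.new("RGB", (64, 64), (255, 255, 255))
        return image_processor(images=[image], return_tensors="pt")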