lazy image load
Former-commit-id: cdd733b575411e003bc5ffd6560dd8eff8aa09cf
@@ -16,12 +16,18 @@
 # limitations under the License.

 from dataclasses import dataclass
-from typing import Any, Dict, Literal, Sequence
+from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Sequence

 import torch
 from transformers import DataCollatorForSeq2Seq


+if TYPE_CHECKING:
+    from transformers import ProcessorMixin
+
+    from .template import Template
+
+
 def prepare_4d_attention_mask(attention_mask_with_indices: "torch.Tensor", dtype: "torch.dtype") -> "torch.Tensor":
     r"""
     Expands the attention mask with indices from (batch_size, seq_len) to (batch_size, 1, seq_len, seq_len),
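Note: the new `if TYPE_CHECKING:` block makes `ProcessorMixin` and `Template` annotation-only imports. `typing.TYPE_CHECKING` is `False` at runtime, so the imports never execute (which also avoids a potential circular import between this module and `.template`), while type checkers still resolve the quoted annotations. A minimal sketch of the pattern, with `heavy_module` as a stand-in name:

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        from heavy_module import HeavyClass  # never imported at runtime

    def use(obj: "HeavyClass") -> None:  # quoted annotation, resolved lazily
        print(obj)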
@@ -65,41 +71,29 @@ def prepare_4d_attention_mask(attention_mask_with_indices: "torch.Tensor", dtype
 class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
     r"""
     Data collator that supports VLMs.

     Features should contain input_ids, attention_mask, labels and images.
     """

+    template: Optional["Template"] = None
+    processor: Optional["ProcessorMixin"] = None
+
     def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
-        if "token_type_ids" in features[0].keys():
-            for feature in features:
-                feature["token_type_ids"] = feature["token_type_ids"][0]
-
-        extra_features = {}
-        if "pixel_values" in features[0].keys():
-            pixel_values = []
-            for feature in features:
-                if feature["pixel_values"] is None:
-                    pixel_values.append(torch.zeros(0, dtype=torch.float))
-                else:
-                    pixel_values.append(torch.tensor(feature["pixel_values"], dtype=torch.float))
-
-            extra_features["pixel_values"] = torch.cat(pixel_values, dim=0)
-            if extra_features["pixel_values"].numel() == 0:
-                extra_features["pixel_values"] = None
-
-        if "image_grid_thw" in features[0].keys():
-            image_grid_thw = []
-            for feature in features:
-                if feature["image_grid_thw"] is None:
-                    image_grid_thw.append(torch.zeros(0, dtype=torch.long))
-                else:
-                    image_grid_thw.append(torch.tensor(feature["image_grid_thw"], dtype=torch.long))
-
-            extra_features["image_grid_thw"] = torch.cat(image_grid_thw, dim=0)
-            if extra_features["image_grid_thw"].numel() == 0:
-                extra_features["image_grid_thw"] = None
-
-        features = [{key: feature[key] for key in feature if key not in extra_features.keys()} for feature in features]
+        batch_images, batch_imglens, batch_seqlens = [], [], []
+        for feature in features:
+            images = feature.pop("images") or []  # avoid NoneType
+            batch_images.extend(images)
+            batch_imglens.append(len(images))
+            batch_seqlens.append(len(feature["input_ids"]))
+
+        mm_inputs = self.template.mm_plugin.get_mm_inputs(batch_images, batch_imglens, batch_seqlens, self.processor)
+        if "token_type_ids" in mm_inputs:
+            token_type_ids = mm_inputs.pop("token_type_ids")
+            for i, feature in enumerate(features):
+                feature["token_type_ids"] = token_type_ids[i]
+
         features: Dict[str, "torch.Tensor"] = super().__call__(features)
-        features.update({key: value for key, value in extra_features.items() if value is not None})
+        features.update(mm_inputs)
        return features
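This hunk is the core of the lazy image load: instead of carrying preprocessed `pixel_values`/`image_grid_thw` through every feature, each feature keeps its raw `images`, and the collator asks `template.mm_plugin.get_mm_inputs(...)` for the batched visual tensors only when the batch is assembled. A rough usage sketch (not part of the commit) with a dummy template/plugin that only mirrors the calls made in `__call__` above; the import path and the real `Template`/`mm_plugin` interfaces are assumptions based on the repo layout:

    from transformers import AutoTokenizer

    from llamafactory.data.collator import MultiModalDataCollatorForSeq2Seq


    class DummyPlugin:
        def get_mm_inputs(self, images, imglens, seqlens, processor):
            # A real plugin would run the image processor here; a text-only
            # batch simply yields no extra tensors.
            return {}


    class DummyTemplate:
        mm_plugin = DummyPlugin()


    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token  # gpt2 has no pad token by default

    collator = MultiModalDataCollatorForSeq2Seq(
        tokenizer=tokenizer, template=DummyTemplate(), processor=None
    )
    batch = collator([
        {"input_ids": [1, 2, 3], "attention_mask": [1, 1, 1], "labels": [1, 2, 3], "images": None},
        {"input_ids": [4, 5], "attention_mask": [1, 1], "labels": [4, 5], "images": None},
    ])
    print(batch["input_ids"].shape)  # padded to the longest sequence: (2, 3)

Note that `template` is declared `Optional` but is dereferenced unconditionally, so in practice a template must always be supplied.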
@@ -141,16 +135,8 @@ class PairwiseDataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
                     "input_ids": feature["{}_input_ids".format(key)],
                     "attention_mask": feature["{}_attention_mask".format(key)],
                     "labels": feature["{}_labels".format(key)],
+                    "images": feature["images"],
                 }
-                if "{}_token_type_ids".format(key) in feature:
-                    target_feature["token_type_ids"] = feature["{}_token_type_ids".format(key)]
-
-                if "pixel_values" in feature:  # image data are same for chosen and rejected
-                    target_feature["pixel_values"] = feature["pixel_values"]
-
-                if "image_grid_thw" in feature:
-                    target_feature["image_grid_thw"] = feature["image_grid_thw"]
-
                 concatenated_features.append(target_feature)

         return super().__call__(concatenated_features)
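The pairwise collator now only forwards `images` with each copy; the per-key `token_type_ids`/`pixel_values`/`image_grid_thw` plumbing is dropped because the parent collator regenerates those tensors from `images`. In the surrounding code (elided from this hunk) `key` iterates over the chosen/rejected prefixes, an assumption based on the `"{}_input_ids".format(key)` keys, so a batch of n pairs becomes 2n plain examples, chosen block first. An illustrative sketch for n = 1:

    # Sketch of the 2n layout (values abbreviated); the "chosen"/"rejected"
    # prefixes are assumed from the elided surrounding code.
    feature = {
        "chosen_input_ids": [1, 2], "chosen_attention_mask": [1, 1], "chosen_labels": [1, 2],
        "rejected_input_ids": [3], "rejected_attention_mask": [1], "rejected_labels": [3],
        "images": None,
    }
    concatenated = [
        {"input_ids": feature[f"{key}_input_ids"],
         "attention_mask": feature[f"{key}_attention_mask"],
         "labels": feature[f"{key}_labels"],
         "images": feature["images"]}
        for key in ("chosen", "rejected")
    ]
    # concatenated[0] is the chosen example, concatenated[1] the rejected one.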
@@ -171,22 +157,14 @@ class KTODataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
                 "input_ids": feature["input_ids"],
                 "attention_mask": feature["attention_mask"],
                 "labels": feature["labels"],
+                "images": feature["images"],
             }
             kl_feature = {
                 "input_ids": feature["kl_input_ids"],
                 "attention_mask": feature["kl_attention_mask"],
                 "labels": feature["kl_labels"],
+                "images": feature["images"],
             }
-            if "token_type_ids" in feature:
-                target_feature["token_type_ids"] = feature["token_type_ids"]
-                kl_feature["token_type_ids"] = feature["kl_token_type_ids"]
-
-            if "pixel_values" in feature:
-                target_feature["pixel_values"] = feature["pixel_values"]
-
-            if "image_grid_thw" in feature:
-                target_feature["image_grid_thw"] = feature["image_grid_thw"]
-
             target_features.append(target_feature)
             kl_features.append(kl_feature)
             kto_tags.append(feature["kto_tags"])
@@ -196,7 +174,7 @@ class KTODataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
         batch["kl_input_ids"] = kl_batch["input_ids"]
         batch["kl_attention_mask"] = kl_batch["attention_mask"]
         batch["kl_labels"] = kl_batch["labels"]
-        if "token_type_ids" in batch:
+        if "token_type_ids" in kl_batch:
             batch["kl_token_type_ids"] = kl_batch["token_type_ids"]

         batch["kto_tags"] = torch.tensor(kto_tags)
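The KTO collator builds two parallel batches from the same features (the target sequences and their `kl_` counterparts, both now carrying `images`), runs each through the parent collator, then grafts the KL tensors back under `kl_`-prefixed keys. The last hunk also fixes a guard: the old check `"token_type_ids" in batch` tested the target batch but read the value from `kl_batch`, which raises a KeyError whenever the two batches disagree; the new check tests `kl_batch`, the dict actually being read. A toy demonstration of the corrected merge (standalone sketch, not from the commit):

    import torch

    batch = {"input_ids": torch.ones(2, 3, dtype=torch.long)}
    kl_batch = {
        "input_ids": torch.zeros(2, 2, dtype=torch.long),
        "attention_mask": torch.ones(2, 2, dtype=torch.long),
        "labels": torch.full((2, 2), -100),
    }
    batch["kl_input_ids"] = kl_batch["input_ids"]
    batch["kl_attention_mask"] = kl_batch["attention_mask"]
    batch["kl_labels"] = kl_batch["labels"]
    # Guard on kl_batch, not batch: the value is read from kl_batch.
    if "token_type_ids" in kl_batch:
        batch["kl_token_type_ids"] = kl_batch["token_type_ids"]

    batch["kto_tags"] = torch.tensor([True, False])  # one tag per example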