lazy image load

Former-commit-id: cdd733b575411e003bc5ffd6560dd8eff8aa09cf
Author: hiyouga
Date: 2024-09-04 02:27:08 +08:00
parent fed7ae5661
commit 7056087e92

19 changed files with 353 additions and 366 deletions

src/llamafactory/data/collator.py

@@ -16,12 +16,18 @@
 # limitations under the License.
 
 from dataclasses import dataclass
-from typing import Any, Dict, Literal, Sequence
+from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Sequence
 
 import torch
 from transformers import DataCollatorForSeq2Seq
 
 
+if TYPE_CHECKING:
+    from transformers import ProcessorMixin
+
+    from .template import Template
+
+
 def prepare_4d_attention_mask(attention_mask_with_indices: "torch.Tensor", dtype: "torch.dtype") -> "torch.Tensor":
     r"""
     Expands the attention mask with indices from (batch_size, seq_len) to (batch_size, 1, seq_len, seq_len),
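
The `TYPE_CHECKING` block added here is the standard pattern for type-only imports: `Template` and `ProcessorMixin` are needed only for the `Optional[...]` annotations on the new collator fields below, so guarding the imports keeps them out of the runtime import graph (and avoids a circular import with the template module). A minimal sketch of the pattern, using a stand-in module rather than this repo's imports:

```python
from typing import TYPE_CHECKING, Optional

if TYPE_CHECKING:
    # Evaluated only by type checkers, never at runtime, so heavy or
    # circular imports are safe here.
    from decimal import Decimal  # illustrative stand-in dependency


def halve(x: Optional["Decimal"]) -> Optional["Decimal"]:
    # The string annotation works even though Decimal is never imported at runtime.
    return None if x is None else x / 2
```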
@@ -65,41 +71,29 @@ def prepare_4d_attention_mask(attention_mask_with_indices: "torch.Tensor", dtype
 class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
     r"""
     Data collator that supports VLMs.
+
+    Features should contain input_ids, attention_mask, labels and images.
     """
 
+    template: Optional["Template"] = None
+    processor: Optional["ProcessorMixin"] = None
+
     def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
-        if "token_type_ids" in features[0].keys():
-            for feature in features:
-                feature["token_type_ids"] = feature["token_type_ids"][0]
-
-        extra_features = {}
-        if "pixel_values" in features[0].keys():
-            pixel_values = []
-            for feature in features:
-                if feature["pixel_values"] is None:
-                    pixel_values.append(torch.zeros(0, dtype=torch.float))
-                else:
-                    pixel_values.append(torch.tensor(feature["pixel_values"], dtype=torch.float))
-
-            extra_features["pixel_values"] = torch.cat(pixel_values, dim=0)
-            if extra_features["pixel_values"].numel() == 0:
-                extra_features["pixel_values"] = None
-
-        if "image_grid_thw" in features[0].keys():
-            image_grid_thw = []
-            for feature in features:
-                if feature["image_grid_thw"] is None:
-                    image_grid_thw.append(torch.zeros(0, dtype=torch.long))
-                else:
-                    image_grid_thw.append(torch.tensor(feature["image_grid_thw"], dtype=torch.long))
-
-            extra_features["image_grid_thw"] = torch.cat(image_grid_thw, dim=0)
-            if extra_features["image_grid_thw"].numel() == 0:
-                extra_features["image_grid_thw"] = None
-
-        features = [{key: feature[key] for key in feature if key not in extra_features.keys()} for feature in features]
+        batch_images, batch_imglens, batch_seqlens = [], [], []
+        for feature in features:
+            images = feature.pop("images") or []  # avoid NoneType
+            batch_images.extend(images)
+            batch_imglens.append(len(images))
+            batch_seqlens.append(len(feature["input_ids"]))
+
+        mm_inputs = self.template.mm_plugin.get_mm_inputs(batch_images, batch_imglens, batch_seqlens, self.processor)
+        if "token_type_ids" in mm_inputs:
+            token_type_ids = mm_inputs.pop("token_type_ids")
+            for i, feature in enumerate(features):
+                feature["token_type_ids"] = token_type_ids[i]
+
         features: Dict[str, "torch.Tensor"] = super().__call__(features)
-        features.update({key: value for key, value in extra_features.items() if value is not None})
+        features.update(mm_inputs)
         return features
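
Taken together, this hunk is the "lazy image load" of the commit title: raw images now travel inside each feature and are converted to model inputs only at collate time, via `template.mm_plugin.get_mm_inputs`, replacing the old scheme where `pixel_values`/`image_grid_thw` were precomputed during preprocessing and stitched back together here. A self-contained sketch of the pattern, with a stub plugin standing in for the real `mm_plugin` (everything not taken from the diff is illustrative):

```python
from typing import Any, Dict, List, Optional, Sequence

import torch


class StubMMPlugin:
    # Stand-in for template.mm_plugin; a real plugin would invoke the HF processor.
    def get_mm_inputs(
        self, images: List[Any], imglens: List[int], seqlens: List[int], processor: Optional[Any]
    ) -> Dict[str, "torch.Tensor"]:
        if not images:
            return {}  # text-only batch: no image tensors at all
        return {"pixel_values": torch.zeros(len(images), 3, 224, 224)}


def collate(features: Sequence[Dict[str, Any]], plugin: StubMMPlugin) -> Dict[str, "torch.Tensor"]:
    batch_images, batch_imglens, batch_seqlens = [], [], []
    for feature in features:
        images = feature.pop("images") or []  # avoid NoneType, as in the diff
        batch_images.extend(images)
        batch_imglens.append(len(images))
        batch_seqlens.append(len(feature["input_ids"]))

    # Image tensors are materialized only here, at batch time ("lazy image load").
    mm_inputs = plugin.get_mm_inputs(batch_images, batch_imglens, batch_seqlens, None)
    batch = {"input_ids": torch.tensor([f["input_ids"] for f in features])}
    batch.update(mm_inputs)
    return batch


features = [
    {"input_ids": [1, 2, 3], "images": ["img0.png"]},
    {"input_ids": [4, 5, 6], "images": None},  # text-only example
]
print(collate(features, StubMMPlugin()).keys())  # input_ids, pixel_values
```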
@@ -141,16 +135,8 @@ class PairwiseDataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
                 "input_ids": feature["{}_input_ids".format(key)],
                 "attention_mask": feature["{}_attention_mask".format(key)],
                 "labels": feature["{}_labels".format(key)],
+                "images": feature["images"],
             }
-            if "{}_token_type_ids".format(key) in feature:
-                target_feature["token_type_ids"] = feature["{}_token_type_ids".format(key)]
-
-            if "pixel_values" in feature:  # image data are same for chosen and rejected
-                target_feature["pixel_values"] = feature["pixel_values"]
-
-            if "image_grid_thw" in feature:
-                target_feature["image_grid_thw"] = feature["image_grid_thw"]
-
             concatenated_features.append(target_feature)
 
         return super().__call__(concatenated_features)
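
With images handled lazily, the pairwise collator no longer copies `pixel_values`/`image_grid_thw` key by key; each paired feature is flattened into two plain examples that simply share the same raw `images`. A sketch with made-up values, assuming the `chosen`/`rejected` key scheme implied by the `"{}_input_ids".format(key)` lines:

```python
feature = {
    "chosen_input_ids": [1, 2], "chosen_attention_mask": [1, 1], "chosen_labels": [-100, 2],
    "rejected_input_ids": [1, 3], "rejected_attention_mask": [1, 1], "rejected_labels": [-100, 3],
    "images": ["img0.png"],  # the same raw images serve both responses
}
concatenated_features = [
    {
        "input_ids": feature["{}_input_ids".format(key)],
        "attention_mask": feature["{}_attention_mask".format(key)],
        "labels": feature["{}_labels".format(key)],
        "images": feature["images"],
    }
    for key in ("chosen", "rejected")
]
assert len(concatenated_features) == 2  # 2n flattened examples for n pairs
```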
@@ -171,22 +157,14 @@ class KTODataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
                 "input_ids": feature["input_ids"],
                 "attention_mask": feature["attention_mask"],
                 "labels": feature["labels"],
+                "images": feature["images"],
             }
             kl_feature = {
                 "input_ids": feature["kl_input_ids"],
                 "attention_mask": feature["kl_attention_mask"],
                 "labels": feature["kl_labels"],
+                "images": feature["images"],
             }
-            if "token_type_ids" in feature:
-                target_feature["token_type_ids"] = feature["token_type_ids"]
-                kl_feature["token_type_ids"] = feature["kl_token_type_ids"]
-
-            if "pixel_values" in feature:
-                target_feature["pixel_values"] = feature["pixel_values"]
-
-            if "image_grid_thw" in feature:
-                target_feature["image_grid_thw"] = feature["image_grid_thw"]
-
             target_features.append(target_feature)
             kl_features.append(kl_feature)
             kto_tags.append(feature["kto_tags"])
@@ -196,7 +174,7 @@ class KTODataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
         batch["kl_input_ids"] = kl_batch["input_ids"]
         batch["kl_attention_mask"] = kl_batch["attention_mask"]
         batch["kl_labels"] = kl_batch["labels"]
-        if "token_type_ids" in batch:
+        if "token_type_ids" in kl_batch:
             batch["kl_token_type_ids"] = kl_batch["token_type_ids"]
 
         batch["kto_tags"] = torch.tensor(kto_tags)