fix mixed mm inputs and rlhf-v

Former-commit-id: 7c248fac20bf85d57a91132ce7a793c7f84e9218
hiyouga
2024-09-01 20:52:47 +08:00
parent 1d8e9c7897
commit 7e4c5d4bb3
20 changed files with 306 additions and 277 deletions


@@ -62,44 +62,49 @@ def prepare_4d_attention_mask(attention_mask_with_indices: "torch.Tensor", dtype
 @dataclass
-class CustomDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
+class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
     r"""
-    Data collator for custom models (like Qwen2-VL).
+    Data collator that supports VLMs.
     """

     def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
-        image_grid_thw = None  # TODO: better handle various VLMs
-        if "image_grid_thw" in features[0]:
-            image_grid_thw_list = [
-                torch.Tensor(feature["image_grid_thw"]).long()
-                for feature in features
-                if feature["image_grid_thw"][0][0] > 0
-            ]
-            pixel_values_list = [
-                torch.Tensor(feature["pixel_values"]) for feature in features if feature["image_grid_thw"][0][0] > 0
-            ]
-            if image_grid_thw_list:
-                image_grid_thw = torch.cat(image_grid_thw_list, dim=0)
-                pixel_values = torch.cat(pixel_values_list, dim=0)
-            else:
-                image_grid_thw = None
-                pixel_values = None
-
-            if "token_type_ids" in features[0].keys():
-                for feature in features:
-                    feature["token_type_ids"] = feature["token_type_ids"][0]
-
-            features = [
-                {key: feature[key] for key in feature if key not in ["image_grid_thw", "pixel_values"]}
-                for feature in features
-            ]
+        extra_features = {}
+        if "pixel_values" in features[0].keys():
+            pixel_values = []
+            for feature in features:
+                if feature["pixel_values"] is None:
+                    pixel_values.append(torch.zeros(0, dtype=torch.float))
+                else:
+                    pixel_values.append(torch.tensor(feature["pixel_values"], dtype=torch.float))

-        features = super().__call__(features)
-
-        if image_grid_thw is not None:
-            features["image_grid_thw"] = image_grid_thw
-            features["pixel_values"] = pixel_values
extra_features["pixel_values"] = torch.cat(pixel_values, dim=0)
if extra_features["pixel_values"].numel() == 0:
extra_features["pixel_values"] = None
if "image_grid_thw" in features[0].keys():
image_grid_thw = []
for feature in features:
if feature["image_grid_thw"] is None:
image_grid_thw.append(torch.zeros(0, dtype=torch.long))
else:
image_grid_thw.append(torch.tensor(feature["image_grid_thw"], dtype=torch.long))
extra_features["image_grid_thw"] = torch.cat(pixel_values, dim=0)
if extra_features["image_grid_thw"].numel() == 0:
extra_features["image_grid_thw"] = None
features = [{key: feature[key] for key in feature if key not in extra_features.keys()} for feature in features]
features: Dict[str, "torch.Tensor"] = super().__call__(features)
features.update({key: value for key, value in extra_features.items() if value is not None})
return features

 @dataclass
-class SFTDataCollatorWith4DAttentionMask(CustomDataCollatorForSeq2Seq):
+class SFTDataCollatorWith4DAttentionMask(MultiModalDataCollatorForSeq2Seq):
     r"""
     Data collator for 4d attention mask.
     """
@@ -117,7 +122,7 @@ class SFTDataCollatorWith4DAttentionMask(CustomDataCollatorForSeq2Seq):
 @dataclass
-class PairwiseDataCollatorWithPadding(CustomDataCollatorForSeq2Seq):
+class PairwiseDataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
     r"""
     Data collator for pairwise data.
     """
@@ -152,7 +157,7 @@ class PairwiseDataCollatorWithPadding(CustomDataCollatorForSeq2Seq):
 @dataclass
-class KTODataCollatorWithPadding(CustomDataCollatorForSeq2Seq):
+class KTODataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
     r"""
     Data collator for KTO data.
     """