fix mixed mm inputs and rlhf-v
Former-commit-id: 7c248fac20bf85d57a91132ce7a793c7f84e9218
@@ -62,44 +62,49 @@ def prepare_4d_attention_mask(attention_mask_with_indices: "torch.Tensor", dtype
 @dataclass
-class CustomDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
+class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
     r"""
-    Data collator for custom models (like Qwen2-VL).
+    Data collator that supports VLMs.
     """
 
     def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
-        image_grid_thw = None  # TODO: better handle various VLMs
-        if "image_grid_thw" in features[0]:
-            image_grid_thw_list = [
-                torch.Tensor(feature["image_grid_thw"]).long()
-                for feature in features
-                if feature["image_grid_thw"][0][0] > 0
-            ]
-            pixel_values_list = [
-                torch.Tensor(feature["pixel_values"]) for feature in features if feature["image_grid_thw"][0][0] > 0
-            ]
-            if image_grid_thw_list:
-                image_grid_thw = torch.cat(image_grid_thw_list, dim=0)
-                pixel_values = torch.cat(pixel_values_list, dim=0)
-            else:
-                image_grid_thw = None
-                pixel_values = None
-            if "token_type_ids" in features[0].keys():
-                for feature in features:
-                    feature["token_type_ids"] = feature["token_type_ids"][0]
-
-            features = [
-                {key: feature[key] for key in feature if key not in ["image_grid_thw", "pixel_values"]}
-                for feature in features
-            ]
-
-        features = super().__call__(features)
-        if image_grid_thw is not None:
-            features["image_grid_thw"] = image_grid_thw
-            features["pixel_values"] = pixel_values
-
+        extra_features = {}
+        if "pixel_values" in features[0].keys():
+            pixel_values = []
+            for feature in features:
+                if feature["pixel_values"] is None:
+                    pixel_values.append(torch.zeros(0, dtype=torch.float))
+                else:
+                    pixel_values.append(torch.tensor(feature["pixel_values"], dtype=torch.float))
+
+            extra_features["pixel_values"] = torch.cat(pixel_values, dim=0)
+            if extra_features["pixel_values"].numel() == 0:
+                extra_features["pixel_values"] = None
+
+        if "image_grid_thw" in features[0].keys():
+            image_grid_thw = []
+            for feature in features:
+                if feature["image_grid_thw"] is None:
+                    image_grid_thw.append(torch.zeros(0, dtype=torch.long))
+                else:
+                    image_grid_thw.append(torch.tensor(feature["image_grid_thw"], dtype=torch.long))
+
+            extra_features["image_grid_thw"] = torch.cat(image_grid_thw, dim=0)
+            if extra_features["image_grid_thw"].numel() == 0:
+                extra_features["image_grid_thw"] = None
+
+        features = [{key: feature[key] for key in feature if key not in extra_features.keys()} for feature in features]
+        features: Dict[str, "torch.Tensor"] = super().__call__(features)
+        features.update({key: value for key, value in extra_features.items() if value is not None})
         return features
 
 
 @dataclass
-class SFTDataCollatorWith4DAttentionMask(CustomDataCollatorForSeq2Seq):
+class SFTDataCollatorWith4DAttentionMask(MultiModalDataCollatorForSeq2Seq):
     r"""
     Data collator for 4d attention mask.
     """
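Note on the hunk above: the old collator keyed everything off image_grid_thw, which breaks for VLMs without that field and for batches that mix multimodal and text-only samples. The new version instead substitutes a zero-length placeholder tensor for each missing visual input, so torch.cat always produces a well-formed batch. A minimal sketch of that padding trick in isolation (the feature dicts are invented for illustration; the concatenation relies on torch.cat's legacy acceptance of shape-(0,) empty tensors):

import torch

# Hypothetical mixed batch: one multimodal sample, one text-only sample.
features = [
    {"pixel_values": [[0.1, 0.2], [0.3, 0.4]]},  # image patch rows
    {"pixel_values": None},                      # text-only
]

pixel_values = []
for feature in features:
    if feature["pixel_values"] is None:
        # Zero-length placeholder keeps torch.cat well-defined for text-only rows.
        pixel_values.append(torch.zeros(0, dtype=torch.float))
    else:
        pixel_values.append(torch.tensor(feature["pixel_values"], dtype=torch.float))

batched = torch.cat(pixel_values, dim=0)
print(batched.shape)  # torch.Size([2, 2]): only the real image rows remain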
@@ -117,7 +122,7 @@ class SFTDataCollatorWith4DAttentionMask(CustomDataCollatorForSeq2Seq):
 
 
 @dataclass
-class PairwiseDataCollatorWithPadding(CustomDataCollatorForSeq2Seq):
+class PairwiseDataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
     r"""
     Data collator for pairwise data.
     """
@@ -152,7 +157,7 @@ class PairwiseDataCollatorWithPadding(CustomDataCollatorForSeq2Seq):
 
 
 @dataclass
-class KTODataCollatorWithPadding(CustomDataCollatorForSeq2Seq):
+class KTODataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
     r"""
     Data collator for KTO data.
     """
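The three one-line rebases above appear to be the "rlhf-v" half of the fix: the SFT, pairwise (preference), and KTO collators now inherit from MultiModalDataCollatorForSeq2Seq, so image-bearing preference data passes through the same multimodal collation before each subclass's own padding logic. One detail worth spelling out is the numel() == 0 guard in the base class: when every sample in a batch is text-only, the concatenated tensor is empty and the key is dropped, so the model never receives empty visual inputs. A sketch of that degenerate case (batch contents invented):

import torch

# Hypothetical all-text batch: no sample carries image grid metadata.
features = [{"image_grid_thw": None}, {"image_grid_thw": None}]

image_grid_thw = [
    torch.zeros(0, dtype=torch.long)
    if feature["image_grid_thw"] is None
    else torch.tensor(feature["image_grid_thw"], dtype=torch.long)
    for feature in features
]

extra_features = {"image_grid_thw": torch.cat(image_grid_thw, dim=0)}
if extra_features["image_grid_thw"].numel() == 0:
    extra_features["image_grid_thw"] = None  # mark for removal

# Only non-None extras are merged back into the padded batch.
batch = {key: value for key, value in extra_features.items() if value is not None}
print(batch)  # {}: no visual keys reach the model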