initial-commit
Former-commit-id: b6a39847a10b417b09db4b5512dd835e9e4ce928
This commit is contained in:
@@ -72,9 +72,37 @@ class SFTDataCollatorWith4DAttentionMask(DataCollatorForSeq2Seq):
|
||||
compute_dtype: "torch.dtype" = torch.float32
|
||||
|
||||
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
    """Collate a batch, handling the optional visual inputs separately.

    When the first feature carries an ``"image_grid_thw"`` entry, the
    ``"image_grid_thw"`` and ``"pixel_values"`` entries of every feature whose
    first grid value is positive are concatenated along dim 0. Both keys are
    then removed from all features before the parent collator runs, and the
    concatenated tensors are attached to the collated batch afterwards.

    Args:
        features: Per-example feature dicts. Must be non-empty, since the
            first element is inspected to detect visual inputs.

    Returns:
        The batch produced by the parent collator, with ``"attention_mask"``
        replaced by the output of ``prepare_4d_attention_mask`` when block
        diagonal attention is enabled without flash-attention 2, and with
        ``"image_grid_thw"`` / ``"pixel_values"`` added when at least one
        example in the batch has an image.
    """
    image_grid_thw = None
    pixel_values = None

    if "image_grid_thw" in features[0]:
        # Select the examples with images once, so the grid sizes and pixel
        # values can never get out of sync with each other.
        # NOTE(review): a non-positive first grid value appears to be a
        # "no image" placeholder - confirm against the dataset preprocessing.
        with_images = [feature for feature in features if feature["image_grid_thw"][0][0] > 0]
        if with_images:
            # Build the grid sizes directly as int64 rather than going
            # through a float32 tensor and casting afterwards.
            image_grid_thw = torch.cat(
                [torch.as_tensor(feature["image_grid_thw"], dtype=torch.long) for feature in with_images],
                0,
            )
            pixel_values = torch.cat(
                [torch.Tensor(feature["pixel_values"]) for feature in with_images],
                0,
            )

        # The visual entries are removed before delegating to the parent
        # collator; they are re-attached to the batch below.
        visual_keys = ("image_grid_thw", "pixel_values")
        features = [
            {key: value for key, value in feature.items() if key not in visual_keys}
            for feature in features
        ]

    features = super().__call__(features)
    if self.block_diag_attn and self.attn_implementation != "flash_attention_2":
        features["attention_mask"] = prepare_4d_attention_mask(features["attention_mask"], self.compute_dtype)

    if image_grid_thw is not None:
        features["image_grid_thw"] = image_grid_thw
        features["pixel_values"] = pixel_values

    return features
|
||||
|
||||
|
||||
Reference in New Issue
Block a user