initial-commit

Former-commit-id: b6a39847a10b417b09db4b5512dd835e9e4ce928
This commit is contained in:
simonJJJ
2024-08-28 16:51:35 +08:00
parent 7272792f65
commit 0f3d54d8a0
8 changed files with 183 additions and 5 deletions

View File

@@ -72,9 +72,37 @@ class SFTDataCollatorWith4DAttentionMask(DataCollatorForSeq2Seq):
compute_dtype: "torch.dtype" = torch.float32
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
image_grid_thw = None
if "image_grid_thw" in features[0]:
image_grid_thw_list = [
torch.Tensor(feature["image_grid_thw"]).long()
for feature in features
if feature["image_grid_thw"][0][0] > 0
]
pixel_values_list = [
torch.Tensor(feature["pixel_values"]) for feature in features if feature["image_grid_thw"][0][0] > 0
]
if image_grid_thw_list:
image_grid_thw = torch.cat(image_grid_thw_list, 0)
else:
# Handle the case where the list is empty, for example:
image_grid_thw = None
if pixel_values_list:
pixel_values = torch.cat(pixel_values_list, 0)
else:
# Handle the case where the list is empty, for example:
pixel_values = None
features = [
{key: feature[key] for key in feature if key not in ["image_grid_thw", "pixel_values"]}
for feature in features
]
features = super().__call__(features)
if self.block_diag_attn and self.attn_implementation != "flash_attention_2":
features["attention_mask"] = prepare_4d_attention_mask(features["attention_mask"], self.compute_dtype)
if image_grid_thw is not None:
features["image_grid_thw"] = image_grid_thw
features["pixel_values"] = pixel_values
return features