refactor mm training
Former-commit-id: 179c0558699e287cbf38a2d73bff47e86d589c5a
This commit is contained in:
@@ -62,15 +62,11 @@ def prepare_4d_attention_mask(attention_mask_with_indices: "torch.Tensor", dtype
|
||||
|
||||
|
||||
@dataclass
|
||||
class SFTDataCollatorWith4DAttentionMask(DataCollatorForSeq2Seq):
|
||||
class CustomDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
||||
r"""
|
||||
Data collator for 4d attention mask.
|
||||
Data collator for custom models (like Qwen2-VL).
|
||||
"""
|
||||
|
||||
block_diag_attn: bool = False
|
||||
attn_implementation: Literal["eager", "sdpa", "flash_attention_2"] = "eager"
|
||||
compute_dtype: "torch.dtype" = torch.float32
|
||||
|
||||
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
|
||||
image_grid_thw = None
|
||||
if "image_grid_thw" in features[0]:
|
||||
@@ -83,23 +79,18 @@ class SFTDataCollatorWith4DAttentionMask(DataCollatorForSeq2Seq):
|
||||
torch.Tensor(feature["pixel_values"]) for feature in features if feature["image_grid_thw"][0][0] > 0
|
||||
]
|
||||
if image_grid_thw_list:
|
||||
image_grid_thw = torch.cat(image_grid_thw_list, 0)
|
||||
image_grid_thw = torch.cat(image_grid_thw_list, dim=0)
|
||||
pixel_values = torch.cat(pixel_values_list, dim=0)
|
||||
else:
|
||||
# Handle the case where the list is empty, for example:
|
||||
image_grid_thw = None
|
||||
if pixel_values_list:
|
||||
pixel_values = torch.cat(pixel_values_list, 0)
|
||||
else:
|
||||
# Handle the case where the list is empty, for example:
|
||||
pixel_values = None
|
||||
|
||||
features = [
|
||||
{key: feature[key] for key in feature if key not in ["image_grid_thw", "pixel_values"]}
|
||||
for feature in features
|
||||
]
|
||||
|
||||
features = super().__call__(features)
|
||||
if self.block_diag_attn and self.attn_implementation != "flash_attention_2":
|
||||
features["attention_mask"] = prepare_4d_attention_mask(features["attention_mask"], self.compute_dtype)
|
||||
if image_grid_thw is not None:
|
||||
features["image_grid_thw"] = image_grid_thw
|
||||
features["pixel_values"] = pixel_values
|
||||
@@ -108,7 +99,25 @@ class SFTDataCollatorWith4DAttentionMask(DataCollatorForSeq2Seq):
|
||||
|
||||
|
||||
@dataclass
|
||||
class PairwiseDataCollatorWithPadding(DataCollatorForSeq2Seq):
|
||||
class SFTDataCollatorWith4DAttentionMask(CustomDataCollatorForSeq2Seq):
|
||||
r"""
|
||||
Data collator for 4d attention mask.
|
||||
"""
|
||||
|
||||
block_diag_attn: bool = False
|
||||
attn_implementation: Literal["eager", "sdpa", "flash_attention_2"] = "eager"
|
||||
compute_dtype: "torch.dtype" = torch.float32
|
||||
|
||||
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
|
||||
features = super().__call__(features)
|
||||
if self.block_diag_attn and self.attn_implementation != "flash_attention_2":
|
||||
features["attention_mask"] = prepare_4d_attention_mask(features["attention_mask"], self.compute_dtype)
|
||||
|
||||
return features
|
||||
|
||||
|
||||
@dataclass
|
||||
class PairwiseDataCollatorWithPadding(CustomDataCollatorForSeq2Seq):
|
||||
r"""
|
||||
Data collator for pairwise data.
|
||||
"""
|
||||
@@ -128,9 +137,12 @@ class PairwiseDataCollatorWithPadding(DataCollatorForSeq2Seq):
|
||||
"attention_mask": feature["{}_attention_mask".format(key)],
|
||||
"labels": feature["{}_labels".format(key)],
|
||||
}
|
||||
if "pixel_values" in feature:
|
||||
if "pixel_values" in feature: # image data are same for chosen and rejected
|
||||
target_feature["pixel_values"] = feature["pixel_values"]
|
||||
|
||||
if "image_grid_thw" in feature:
|
||||
target_feature["image_grid_thw"] = feature["image_grid_thw"]
|
||||
|
||||
if "{}_token_type_ids".format(key) in feature:
|
||||
target_feature["token_type_ids"] = feature["{}_token_type_ids".format(key)]
|
||||
|
||||
@@ -140,7 +152,7 @@ class PairwiseDataCollatorWithPadding(DataCollatorForSeq2Seq):
|
||||
|
||||
|
||||
@dataclass
|
||||
class KTODataCollatorWithPadding(DataCollatorForSeq2Seq):
|
||||
class KTODataCollatorWithPadding(CustomDataCollatorForSeq2Seq):
|
||||
r"""
|
||||
Data collator for KTO data.
|
||||
"""
|
||||
@@ -163,6 +175,9 @@ class KTODataCollatorWithPadding(DataCollatorForSeq2Seq):
|
||||
if "pixel_values" in feature:
|
||||
target_feature["pixel_values"] = feature["pixel_values"]
|
||||
|
||||
if "image_grid_thw" in feature:
|
||||
target_feature["image_grid_thw"] = feature["image_grid_thw"]
|
||||
|
||||
if "token_type_ids" in feature:
|
||||
target_feature["token_type_ids"] = feature["token_type_ids"]
|
||||
kl_feature["token_type_ids"] = feature["kl_token_type_ids"]
|
||||
|
||||
Reference in New Issue
Block a user