refactor mm training

Former-commit-id: 179c0558699e287cbf38a2d73bff47e86d589c5a
This commit is contained in:
hiyouga
2024-08-30 02:14:31 +08:00
parent 77c2c7076b
commit c62a6ca59d
29 changed files with 499 additions and 312 deletions

View File

@@ -62,15 +62,11 @@ def prepare_4d_attention_mask(attention_mask_with_indices: "torch.Tensor", dtype
@dataclass
class SFTDataCollatorWith4DAttentionMask(DataCollatorForSeq2Seq):
class CustomDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
r"""
Data collator for 4d attention mask.
Data collator for custom models (like Qwen2-VL).
"""
block_diag_attn: bool = False
attn_implementation: Literal["eager", "sdpa", "flash_attention_2"] = "eager"
compute_dtype: "torch.dtype" = torch.float32
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
image_grid_thw = None
if "image_grid_thw" in features[0]:
@@ -83,23 +79,18 @@ class SFTDataCollatorWith4DAttentionMask(DataCollatorForSeq2Seq):
torch.Tensor(feature["pixel_values"]) for feature in features if feature["image_grid_thw"][0][0] > 0
]
if image_grid_thw_list:
image_grid_thw = torch.cat(image_grid_thw_list, 0)
image_grid_thw = torch.cat(image_grid_thw_list, dim=0)
pixel_values = torch.cat(pixel_values_list, dim=0)
else:
# Handle the case where the list is empty, for example:
image_grid_thw = None
if pixel_values_list:
pixel_values = torch.cat(pixel_values_list, 0)
else:
# Handle the case where the list is empty, for example:
pixel_values = None
features = [
{key: feature[key] for key in feature if key not in ["image_grid_thw", "pixel_values"]}
for feature in features
]
features = super().__call__(features)
if self.block_diag_attn and self.attn_implementation != "flash_attention_2":
features["attention_mask"] = prepare_4d_attention_mask(features["attention_mask"], self.compute_dtype)
if image_grid_thw is not None:
features["image_grid_thw"] = image_grid_thw
features["pixel_values"] = pixel_values
@@ -108,7 +99,25 @@ class SFTDataCollatorWith4DAttentionMask(DataCollatorForSeq2Seq):
@dataclass
class PairwiseDataCollatorWithPadding(DataCollatorForSeq2Seq):
class SFTDataCollatorWith4DAttentionMask(CustomDataCollatorForSeq2Seq):
r"""
Data collator for 4d attention mask.
"""
block_diag_attn: bool = False
attn_implementation: Literal["eager", "sdpa", "flash_attention_2"] = "eager"
compute_dtype: "torch.dtype" = torch.float32
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
features = super().__call__(features)
if self.block_diag_attn and self.attn_implementation != "flash_attention_2":
features["attention_mask"] = prepare_4d_attention_mask(features["attention_mask"], self.compute_dtype)
return features
@dataclass
class PairwiseDataCollatorWithPadding(CustomDataCollatorForSeq2Seq):
r"""
Data collator for pairwise data.
"""
@@ -128,9 +137,12 @@ class PairwiseDataCollatorWithPadding(DataCollatorForSeq2Seq):
"attention_mask": feature["{}_attention_mask".format(key)],
"labels": feature["{}_labels".format(key)],
}
if "pixel_values" in feature:
if "pixel_values" in feature: # image data are same for chosen and rejected
target_feature["pixel_values"] = feature["pixel_values"]
if "image_grid_thw" in feature:
target_feature["image_grid_thw"] = feature["image_grid_thw"]
if "{}_token_type_ids".format(key) in feature:
target_feature["token_type_ids"] = feature["{}_token_type_ids".format(key)]
@@ -140,7 +152,7 @@ class PairwiseDataCollatorWithPadding(DataCollatorForSeq2Seq):
@dataclass
class KTODataCollatorWithPadding(DataCollatorForSeq2Seq):
class KTODataCollatorWithPadding(CustomDataCollatorForSeq2Seq):
r"""
Data collator for KTO data.
"""
@@ -163,6 +175,9 @@ class KTODataCollatorWithPadding(DataCollatorForSeq2Seq):
if "pixel_values" in feature:
target_feature["pixel_values"] = feature["pixel_values"]
if "image_grid_thw" in feature:
target_feature["image_grid_thw"] = feature["image_grid_thw"]
if "token_type_ids" in feature:
target_feature["token_type_ids"] = feature["token_type_ids"]
kl_feature["token_type_ids"] = feature["kl_token_type_ids"]