Former-commit-id: 8577f52b4152efe6cc7a8b5f6d37b4f9ba6684e7
This commit is contained in:
hiyouga
2024-12-30 05:55:15 +00:00
parent 5f473e2696
commit f8f05a883b
7 changed files with 26 additions and 11 deletions

View File

@@ -168,6 +168,10 @@ class SFTDataCollatorWith4DAttentionMask(MultiModalDataCollatorForSeq2Seq):
if self.block_diag_attn and self.attn_implementation != "flash_attention_2":
features["attention_mask"] = prepare_4d_attention_mask(features["attention_mask"], self.compute_dtype)
for key, value in features.items(): # cast data dtype for paligemma
if torch.is_tensor(value) and torch.is_floating_point(value):
features[key] = value.to(self.compute_dtype)
return features