@@ -168,6 +168,10 @@ class SFTDataCollatorWith4DAttentionMask(MultiModalDataCollatorForSeq2Seq):
        if self.block_diag_attn and self.attn_implementation != "flash_attention_2":
            features["attention_mask"] = prepare_4d_attention_mask(features["attention_mask"], self.compute_dtype)

        for key, value in features.items():  # cast data dtype for paligemma
            if torch.is_tensor(value) and torch.is_floating_point(value):
                features[key] = value.to(self.compute_dtype)

        return features
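For context, here is a minimal sketch of what a block-diagonal 4D mask preparation like the `prepare_4d_attention_mask` call above can do. It assumes the incoming `attention_mask` encodes packed-sequence indices (1, 2, ... per sample, 0 for padding); the function name suffix, the index convention, and the implementation details are assumptions for illustration, not the library's actual code.

```python
import torch


def prepare_4d_attention_mask_sketch(mask_with_indices: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    """Illustrative sketch: turn a (batch, seq_len) mask whose entries are
    packed-sequence indices (1, 2, ..., with 0 marking padding) into a
    (batch, 1, seq_len, seq_len) block-diagonal causal mask in `dtype`."""
    bsz, seq_len = mask_with_indices.size()
    min_value = torch.finfo(dtype).min

    # Broadcast the per-token sequence indices along a new query dimension.
    expanded = mask_with_indices[:, None, None, :].expand(bsz, 1, seq_len, seq_len)

    # A query may attend to a key only if both belong to the same packed
    # sequence (equal, nonzero indices) ...
    same_sequence = torch.eq(expanded, expanded.transpose(-1, -2)) & (expanded != 0)

    # ... and only to positions at or before it (causal constraint).
    causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=mask_with_indices.device))
    allowed = same_sequence & causal

    # Allowed positions become 0, everything else the dtype minimum (an additive -inf).
    zeros = torch.zeros((), dtype=dtype, device=mask_with_indices.device)
    neg_inf = torch.full((), min_value, dtype=dtype, device=mask_with_indices.device)
    return torch.where(allowed, zeros, neg_inf)
```

For example, with input `[[1, 1, 2, 2, 0]]` this yields a causal block over tokens 0-1, a separate causal block over tokens 2-3, and a fully masked padding position, which is the block-diagonal structure `block_diag_attn` is meant to enforce when flash attention 2 is not handling it.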