fix mllama cross_mask
Former-commit-id: c33967308bebd99489d28bd5a879525cf304c1f9
@@ -19,6 +19,7 @@ from dataclasses import dataclass
 from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Sequence

 import torch
+import torch.nn.functional as F
 from transformers import DataCollatorForSeq2Seq


@@ -98,6 +99,12 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
                 feature["token_type_ids"] = token_type_ids[i]

         features: Dict[str, "torch.Tensor"] = super().__call__(features)
+        if "cross_attention_mask" in mm_inputs:  # for mllama inputs
+            cross_attention_mask = mm_inputs.pop("cross_attention_mask")
+            seq_len = features["input_ids"].size(1)
+            orig_len = cross_attention_mask.size(1)
+            mm_inputs["cross_attention_mask"] = F.pad(cross_attention_mask, (0, 0, 0, 0, 0, seq_len - orig_len))
+
         features.update(mm_inputs)
         if isinstance(features.get("pixel_values"), list):  # for pixtral inputs
             features = features.data  # use default_collate() instead of BatchEncoding.to()
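For context, a minimal sketch (not part of the commit) of what the new padding step does. It assumes the usual mllama processor layout for cross_attention_mask, (batch, text_seq_len, num_images, num_tiles); the tensor sizes below are invented for illustration.

import torch
import torch.nn.functional as F

# Toy mask: batch of 2, text length 5 before batch padding, 1 image with 4 tiles.
cross_attention_mask = torch.ones(2, 5, 1, 4)
seq_len = 8  # length of input_ids after the collator pads the batch
orig_len = cross_attention_mask.size(1)

# F.pad reads the pad tuple from the last dimension backwards:
# (0, 0) for tiles, (0, 0) for images, (0, seq_len - orig_len) for the text axis.
padded = F.pad(cross_attention_mask, (0, 0, 0, 0, 0, seq_len - orig_len))

print(padded.shape)                # torch.Size([2, 8, 1, 4])
print(padded[:, orig_len:].sum())  # tensor(0.) -- padded positions attend to no tiles

Zero-padding the text axis keeps the positions introduced by batch padding masked out of cross-attention, so the mask stays aligned with the padded input_ids.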