Fix VLM training hang under ZeRO-3/FSDP
Former-commit-id: 86fe7fe71b51077310357b7b1895522258f9bc7a
This commit is contained in:
@@ -22,6 +22,13 @@ import torch
|
||||
import torch.nn.functional as F
|
||||
from transformers import DataCollatorForSeq2Seq
|
||||
|
||||
from ..extras.constants import IGNORE_INDEX, IMAGE_PLACEHOLDER
|
||||
from ..extras.packages import is_pillow_available
|
||||
|
||||
|
||||
if is_pillow_available():
|
||||
from PIL import Image
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import ProcessorMixin
|
||||
@@ -73,7 +80,7 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
||||
r"""
|
||||
Data collator that supports VLMs.
|
||||
|
||||
Features should contain input_ids, attention_mask, labels and images.
|
||||
Features should contain input_ids, attention_mask, labels, and optionally contain images and videos.
|
||||
"""
|
||||
|
||||
template: Optional["Template"] = None
|
||||
@@ -90,6 +97,17 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
||||
batch_vidlens.append(len(videos))
|
||||
batch_input_ids.append(feature["input_ids"])
|
||||
|
||||
if self.processor is not None and sum(batch_imglens) == 0: # avoid process hanging in zero3/fsdp case
|
||||
fake_messages = [{"role": "user", "content": IMAGE_PLACEHOLDER}]
|
||||
fake_images = [Image.new("RGB", (64, 64), (255, 255, 255))]
|
||||
fake_messages = self.template.mm_plugin.process_messages(fake_messages, fake_images, [], self.processor)
|
||||
fake_input_ids = self.processor.tokenizer.encode(fake_messages[0]["content"], add_special_tokens=False)
|
||||
features[0]["input_ids"] = features[0]["input_ids"] + fake_input_ids
|
||||
features[0]["attention_mask"] = features[0]["attention_mask"] + [0] * len(fake_input_ids)
|
||||
features[0]["labels"] = features[0]["labels"] + [IGNORE_INDEX] * len(fake_input_ids)
|
||||
batch_images = fake_images
|
||||
batch_input_ids[0] = features[0]["input_ids"]
|
||||
|
||||
mm_inputs = self.template.mm_plugin.get_mm_inputs(
|
||||
batch_images, batch_videos, batch_imglens, batch_vidlens, batch_input_ids, self.processor
|
||||
)
|
||||
@@ -99,7 +117,7 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
||||
feature["token_type_ids"] = token_type_ids[i]
|
||||
|
||||
features: Dict[str, "torch.Tensor"] = super().__call__(features)
|
||||
if "cross_attention_mask" in mm_inputs: # for mllama inputs
|
||||
if "cross_attention_mask" in mm_inputs: # for mllama inputs when pad_to_multiple_of is enabled
|
||||
cross_attention_mask = mm_inputs.pop("cross_attention_mask")
|
||||
seq_len = features["input_ids"].size(1)
|
||||
orig_len = cross_attention_mask.size(1)
|
||||
|
||||
Reference in New Issue
Block a user