[model] support audio (#6701)
* support qwen2_audio
* improve code
* lint
* fix
* fix
* fix

---------

Co-authored-by: hiyouga <hiyouga@buaa.edu.cn>
Former-commit-id: 5eacb5629e4d7733cd992a63747a1335f2c6a929
This commit is contained in:
@@ -18,11 +18,12 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Sequence
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from transformers import DataCollatorForSeq2Seq
|
||||
|
||||
from ..extras.constants import IGNORE_INDEX, IMAGE_PLACEHOLDER
|
||||
from ..extras.constants import AUDIO_PLACEHOLDER, IGNORE_INDEX, IMAGE_PLACEHOLDER
|
||||
from ..extras.packages import is_pillow_available
|
||||
|
||||
|
||||
@@ -80,7 +81,7 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
||||
r"""
|
||||
Data collator that supports VLMs.
|
||||
|
||||
Features should contain input_ids, attention_mask, labels, and optionally contain images and videos.
|
||||
Features should contain input_ids, attention_mask, labels, and optionally contain images, videos and audios.
|
||||
"""
|
||||
|
||||
template: Optional["Template"] = None
|
||||
@@ -91,26 +92,54 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
||||
raise ValueError("Template is required for MultiModalDataCollator.")
|
||||
|
||||
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
|
||||
batch_images, batch_videos, batch_imglens, batch_vidlens, batch_input_ids = [], [], [], [], []
|
||||
batch_images, batch_videos, batch_audios = [], [], []
|
||||
batch_imglens, batch_vidlens, batch_audlens, batch_input_ids = [], [], [], []
|
||||
for feature in features:
|
||||
images = feature.pop("images", None) or []
|
||||
videos = feature.pop("videos", None) or []
|
||||
audios = feature.pop("audios", None) or []
|
||||
batch_images.extend(images)
|
||||
batch_videos.extend(videos)
|
||||
batch_audios.extend(audios)
|
||||
batch_imglens.append(len(images))
|
||||
batch_vidlens.append(len(videos))
|
||||
batch_audlens.append(len(audios))
|
||||
batch_input_ids.append(feature["input_ids"])
|
||||
|
||||
fake_input_ids = None
|
||||
if (
|
||||
self.processor is not None and sum(batch_imglens) == 0 and sum(batch_vidlens) == 0
|
||||
self.template.mm_plugin.image_token is not None and sum(batch_imglens) == 0 and sum(batch_vidlens) == 0
|
||||
): # avoid process hanging in zero3/fsdp case
|
||||
fake_messages = [{"role": "user", "content": IMAGE_PLACEHOLDER}]
|
||||
fake_images = [Image.new("RGB", (64, 64), (255, 255, 255))]
|
||||
fake_messages = self.template.mm_plugin.process_messages(fake_messages, fake_images, [], self.processor)
|
||||
fake_messages = self.template.mm_plugin.process_messages(
|
||||
fake_messages, fake_images, [], [], self.processor
|
||||
)
|
||||
fake_input_ids = self.tokenizer.encode(fake_messages[0]["content"], add_special_tokens=False)
|
||||
fake_input_ids, _ = self.template.mm_plugin.process_token_ids(
|
||||
fake_input_ids, None, fake_images, [], self.tokenizer, self.processor
|
||||
fake_input_ids, None, fake_images, [], [], self.tokenizer, self.processor
|
||||
)
|
||||
batch_images = fake_images
|
||||
batch_imglens[0] = 1
|
||||
batch_input_ids[0] = features[0]["input_ids"]
|
||||
|
||||
if (
|
||||
self.template.mm_plugin.audio_token is not None and sum(batch_audlens) == 0
|
||||
): # avoid process hanging in zero3/fsdp case
|
||||
fake_messages = [{"role": "user", "content": AUDIO_PLACEHOLDER}]
|
||||
fake_audios = [np.zeros(1600)]
|
||||
fake_messages = self.template.mm_plugin.process_messages(
|
||||
fake_messages, [], [], fake_audios, self.processor
|
||||
)
|
||||
fake_input_ids = self.tokenizer.encode(fake_messages[0]["content"], add_special_tokens=False)
|
||||
fake_input_ids, _ = self.template.mm_plugin.process_token_ids(
|
||||
fake_input_ids, None, [], [], fake_audios, self.tokenizer, self.processor
|
||||
)
|
||||
batch_audios = fake_audios
|
||||
batch_audlens[0] = 1
|
||||
batch_input_ids[0] = features[0]["input_ids"]
|
||||
|
||||
if fake_input_ids is not None:
|
||||
if self.tokenizer.padding_side == "right":
|
||||
features[0]["input_ids"] = features[0]["input_ids"] + fake_input_ids
|
||||
features[0]["attention_mask"] = features[0]["attention_mask"] + [0] * len(fake_input_ids)
|
||||
@@ -120,12 +149,15 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
||||
features[0]["attention_mask"] = [0] * len(fake_input_ids) + features[0]["attention_mask"]
|
||||
features[0]["labels"] = [IGNORE_INDEX] * len(fake_input_ids) + features[0]["labels"]
|
||||
|
||||
batch_images = fake_images
|
||||
batch_imglens[0] = 1
|
||||
batch_input_ids[0] = features[0]["input_ids"]
|
||||
|
||||
mm_inputs = self.template.mm_plugin.get_mm_inputs(
|
||||
batch_images, batch_videos, batch_imglens, batch_vidlens, batch_input_ids, self.processor
|
||||
batch_images,
|
||||
batch_videos,
|
||||
batch_audios,
|
||||
batch_imglens,
|
||||
batch_vidlens,
|
||||
batch_audlens,
|
||||
batch_input_ids,
|
||||
self.processor,
|
||||
)
|
||||
if "token_type_ids" in mm_inputs:
|
||||
token_type_ids = mm_inputs.pop("token_type_ids")
|
||||
@@ -208,6 +240,7 @@ class PairwiseDataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
|
||||
"labels": feature[f"{key}_labels"],
|
||||
"images": feature["images"],
|
||||
"videos": feature["videos"],
|
||||
"audios": feature["audios"],
|
||||
}
|
||||
concatenated_features.append(target_feature)
|
||||
|
||||
@@ -231,6 +264,7 @@ class KTODataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
|
||||
"labels": feature["labels"],
|
||||
"images": feature["images"],
|
||||
"videos": feature["videos"],
|
||||
"audios": feature["audios"],
|
||||
}
|
||||
kl_feature = {
|
||||
"input_ids": feature["kl_input_ids"],
|
||||
@@ -238,6 +272,7 @@ class KTODataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
|
||||
"labels": feature["kl_labels"],
|
||||
"images": feature["images"],
|
||||
"videos": feature["videos"],
|
||||
"audios": feature["audios"],
|
||||
}
|
||||
target_features.append(target_feature)
|
||||
kl_features.append(kl_feature)
|
||||
|
||||
Reference in new issue
Block a user