[model] support audio (#6701)

* support qwen2_audio

* improve code

* lint

* fix

* fix

* fix

---------

Co-authored-by: hiyouga <hiyouga@buaa.edu.cn>
Former-commit-id: 5eacb5629e4d7733cd992a63747a1335f2c6a929
Author: Zhangchi Feng
Date: 2025-02-05 04:59:09 +08:00
Committed by: GitHub
Parent: 9feb78e7b4
Commit: 8f401e37f8

35 changed files with 675 additions and 213 deletions

src/llamafactory/data/collator.py

@@ -18,11 +18,12 @@
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Sequence
 
+import numpy as np
 import torch
 import torch.nn.functional as F
 from transformers import DataCollatorForSeq2Seq
 
-from ..extras.constants import IGNORE_INDEX, IMAGE_PLACEHOLDER
+from ..extras.constants import AUDIO_PLACEHOLDER, IGNORE_INDEX, IMAGE_PLACEHOLDER
 from ..extras.packages import is_pillow_available
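The new `AUDIO_PLACEHOLDER` constant plays the same role for audio that `IMAGE_PLACEHOLDER` plays for images: the literal tag a dataset puts in message text where a clip should be spliced in. A hedged illustration of its intended use; the exact placeholder string lives in `extras/constants.py`, and `"<audio>"` is assumed here by analogy with `"<image>"`:

```python
# Hypothetical dataset record; "<audio>" is an assumed placeholder string.
# The template's mm_plugin later replaces the tag with the model's real
# audio special tokens during preprocessing.
example = {
    "messages": [{"role": "user", "content": "<audio>What is said in this clip?"}],
    "audios": ["path/to/clip.wav"],
}
```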
@@ -80,7 +81,7 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
     r"""
     Data collator that supports VLMs.
 
-    Features should contain input_ids, attention_mask, labels, and optionally contain images and videos.
+    Features should contain input_ids, attention_mask, labels, and optionally contain images, videos and audios.
     """
 
     template: Optional["Template"] = None
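After this change, each feature handed to the collator may carry an `audios` list alongside `images` and `videos`. A minimal sketch of one such feature; the field names come from the diff, while representing waveforms as 1-D numpy arrays and the 16 kHz rate are assumptions:

```python
import numpy as np

feature = {
    "input_ids": [1, 2, 3],        # already contains the expanded audio tokens
    "attention_mask": [1, 1, 1],
    "labels": [-100, -100, 3],     # IGNORE_INDEX (-100) masks prompt tokens
    "images": [],
    "videos": [],
    "audios": [np.zeros(16000)],   # e.g. one second of silence at an assumed 16 kHz
}
```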
@@ -91,26 +92,54 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
             raise ValueError("Template is required for MultiModalDataCollator.")
 
     def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
-        batch_images, batch_videos, batch_imglens, batch_vidlens, batch_input_ids = [], [], [], [], []
+        batch_images, batch_videos, batch_audios = [], [], []
+        batch_imglens, batch_vidlens, batch_audlens, batch_input_ids = [], [], [], []
         for feature in features:
             images = feature.pop("images", None) or []
             videos = feature.pop("videos", None) or []
+            audios = feature.pop("audios", None) or []
             batch_images.extend(images)
             batch_videos.extend(videos)
+            batch_audios.extend(audios)
             batch_imglens.append(len(images))
             batch_vidlens.append(len(videos))
+            batch_audlens.append(len(audios))
             batch_input_ids.append(feature["input_ids"])
 
+        fake_input_ids = None
         if (
-            self.processor is not None and sum(batch_imglens) == 0 and sum(batch_vidlens) == 0
+            self.template.mm_plugin.image_token is not None and sum(batch_imglens) == 0 and sum(batch_vidlens) == 0
         ):  # avoid process hanging in zero3/fsdp case
             fake_messages = [{"role": "user", "content": IMAGE_PLACEHOLDER}]
             fake_images = [Image.new("RGB", (64, 64), (255, 255, 255))]
-            fake_messages = self.template.mm_plugin.process_messages(fake_messages, fake_images, [], self.processor)
+            fake_messages = self.template.mm_plugin.process_messages(
+                fake_messages, fake_images, [], [], self.processor
+            )
             fake_input_ids = self.tokenizer.encode(fake_messages[0]["content"], add_special_tokens=False)
             fake_input_ids, _ = self.template.mm_plugin.process_token_ids(
-                fake_input_ids, None, fake_images, [], self.tokenizer, self.processor
+                fake_input_ids, None, fake_images, [], [], self.tokenizer, self.processor
             )
+            batch_images = fake_images
+            batch_imglens[0] = 1
+            batch_input_ids[0] = features[0]["input_ids"]
+
+        if (
+            self.template.mm_plugin.audio_token is not None and sum(batch_audlens) == 0
+        ):  # avoid process hanging in zero3/fsdp case
+            fake_messages = [{"role": "user", "content": AUDIO_PLACEHOLDER}]
+            fake_audios = [np.zeros(1600)]
+            fake_messages = self.template.mm_plugin.process_messages(
+                fake_messages, [], [], fake_audios, self.processor
+            )
+            fake_input_ids = self.tokenizer.encode(fake_messages[0]["content"], add_special_tokens=False)
+            fake_input_ids, _ = self.template.mm_plugin.process_token_ids(
+                fake_input_ids, None, [], [], fake_audios, self.tokenizer, self.processor
+            )
+            batch_audios = fake_audios
+            batch_audlens[0] = 1
+            batch_input_ids[0] = features[0]["input_ids"]
+
+        if fake_input_ids is not None:
             if self.tokenizer.padding_side == "right":
                 features[0]["input_ids"] = features[0]["input_ids"] + fake_input_ids
                 features[0]["attention_mask"] = features[0]["attention_mask"] + [0] * len(fake_input_ids)
@@ -120,12 +149,15 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
                 features[0]["attention_mask"] = [0] * len(fake_input_ids) + features[0]["attention_mask"]
                 features[0]["labels"] = [IGNORE_INDEX] * len(fake_input_ids) + features[0]["labels"]
 
-            batch_images = fake_images
-            batch_imglens[0] = 1
-            batch_input_ids[0] = features[0]["input_ids"]
-
         mm_inputs = self.template.mm_plugin.get_mm_inputs(
-            batch_images, batch_videos, batch_imglens, batch_vidlens, batch_input_ids, self.processor
+            batch_images,
+            batch_videos,
+            batch_audios,
+            batch_imglens,
+            batch_vidlens,
+            batch_audlens,
+            batch_input_ids,
+            self.processor,
         )
         if "token_type_ids" in mm_inputs:
             token_type_ids = mm_inputs.pop("token_type_ids")
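For an audio model like Qwen2-Audio, the extra keys `get_mm_inputs` merges into the batch would come from the processor's feature extractor. A hedged sketch using the standalone Whisper feature extractor, the extractor family Qwen2-Audio builds on; the shapes are what stock `transformers` produces, but their use by this plugin is an assumption:

```python
import numpy as np
from transformers import WhisperFeatureExtractor

extractor = WhisperFeatureExtractor()  # defaults: 16 kHz input, 80 mel bins
audio = np.zeros(1600)                 # the same dummy clip the collator uses
out = extractor(audio, sampling_rate=16000, return_tensors="pt")
print(out["input_features"].shape)     # torch.Size([1, 80, 3000]); clips are padded to 30 s
```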
@@ -208,6 +240,7 @@ class PairwiseDataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
                     "labels": feature[f"{key}_labels"],
                     "images": feature["images"],
                     "videos": feature["videos"],
+                    "audios": feature["audios"],
                 }
                 concatenated_features.append(target_feature)
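In the pairwise collator, and likewise in the KTO collator below, the single `audios` list is copied into every derived feature, so the chosen and rejected variants of an example (or the target and KL variants) share the same clips. A hypothetical trace of the loop shown above, with made-up token ids:

```python
import numpy as np

feature = {
    "chosen_input_ids": [1, 2, 3], "chosen_attention_mask": [1, 1, 1], "chosen_labels": [-100, 2, 3],
    "rejected_input_ids": [1, 2, 4], "rejected_attention_mask": [1, 1, 1], "rejected_labels": [-100, 2, 4],
    "images": [], "videos": [], "audios": [np.zeros(16000)],
}

concatenated_features = []
for key in ("chosen", "rejected"):  # mirrors the loop in PairwiseDataCollatorWithPadding
    concatenated_features.append({
        "input_ids": feature[f"{key}_input_ids"],
        "attention_mask": feature[f"{key}_attention_mask"],
        "labels": feature[f"{key}_labels"],
        "images": feature["images"],
        "videos": feature["videos"],
        "audios": feature["audios"],  # same list object on both preference sides
    })
```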
@@ -231,6 +264,7 @@ class KTODataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
                 "labels": feature["labels"],
                 "images": feature["images"],
                 "videos": feature["videos"],
+                "audios": feature["audios"],
             }
             kl_feature = {
                 "input_ids": feature["kl_input_ids"],
@@ -238,6 +272,7 @@ class KTODataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
                 "labels": feature["kl_labels"],
                 "images": feature["images"],
                 "videos": feature["videos"],
+                "audios": feature["audios"],
             }
             target_features.append(target_feature)
             kl_features.append(kl_feature)