[model] fix kv cache (#7564)
This commit is contained in:
@@ -1186,6 +1186,9 @@ class Qwen2OmniPlugin(BasePlugin):
|
||||
messages = deepcopy(messages)
|
||||
if self.expand_mm_tokens:
|
||||
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
|
||||
else:
|
||||
mm_inputs = {}
|
||||
|
||||
num_audio_tokens, num_image_tokens, num_video_tokens = 0, 0, 0
|
||||
use_audio_in_video = getattr(processor, "use_audio_in_video", False)
|
||||
|
||||
@@ -1193,18 +1196,22 @@ class Qwen2OmniPlugin(BasePlugin):
|
||||
if "feature_attention_mask" in mm_inputs:
|
||||
input_lengths = (mm_inputs["feature_attention_mask"].sum(-1).numpy() - 1) // 2 + 1
|
||||
audio_lengths = (input_lengths - 2) // 2 + 1
|
||||
|
||||
if mm_inputs.get("image_grid_thw", None) is not None:
|
||||
image_grid_thw = mm_inputs["image_grid_thw"]
|
||||
merge_length = processor.omni_processor.merge_size**2
|
||||
|
||||
if mm_inputs.get("video_grid_thw", None) is not None:
|
||||
video_grid_thw = mm_inputs["video_grid_thw"]
|
||||
merge_length = processor.omni_processor.merge_size**2
|
||||
|
||||
if use_audio_in_video:
|
||||
assert audio_lengths is not None, "audio_lengths should be exist when use_audio_in_video is `True`"
|
||||
assert mm_inputs.get("video_grid_thw", None) is not None, (
|
||||
"video_grid_thw should be exist when use_audio_in_video is `True`"
|
||||
)
|
||||
if audio_lengths is None:
|
||||
raise ValueError("audio_lengths should exist when use_audio_in_video is `True`.")
|
||||
|
||||
if not mm_inputs.get("video_grid_thw", None):
|
||||
raise ValueError("video_grid_thw should exist when use_audio_in_video is `True`.")
|
||||
|
||||
positions_list = []
|
||||
for i, message in enumerate(messages): # get multimodal index when use_audio
|
||||
positions = []
|
||||
@@ -1216,6 +1223,7 @@ class Qwen2OmniPlugin(BasePlugin):
|
||||
break
|
||||
positions.append((pos, special_token))
|
||||
start = pos + len(special_token)
|
||||
|
||||
positions_list.append(positions.sort(key=lambda x: x[0]))
|
||||
|
||||
for message in messages:
|
||||
@@ -1278,6 +1286,7 @@ class Qwen2OmniPlugin(BasePlugin):
|
||||
content = content.replace(AUDIO_PLACEHOLDER, "", 1)
|
||||
num_audio_tokens += 1
|
||||
num_video_tokens += 1
|
||||
|
||||
message["content"] = content
|
||||
|
||||
if len(audios) != num_audio_tokens:
|
||||
|
||||
Reference in New Issue
Block a user