[data] fix qwen3omni audio length calculation (#9467)
This commit is contained in:
@@ -1885,6 +1885,12 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
|
|||||||
image_grid_thw = mm_inputs.get("image_grid_thw", [])
|
image_grid_thw = mm_inputs.get("image_grid_thw", [])
|
||||||
video_grid_thw = mm_inputs.get("video_grid_thw", [])
|
video_grid_thw = mm_inputs.get("video_grid_thw", [])
|
||||||
if "feature_attention_mask" in mm_inputs:
|
if "feature_attention_mask" in mm_inputs:
|
||||||
|
if processor.__class__.__name__ == "Qwen3OmniMoeProcessor": # for qwen3omni
|
||||||
|
input_lengths = mm_inputs["feature_attention_mask"].sum(-1)
|
||||||
|
input_lengths_leave = input_lengths % 100
|
||||||
|
feature_lengths = (input_lengths_leave - 1) // 2 + 1
|
||||||
|
audio_lengths = ((feature_lengths - 1) // 2 + 1 - 1) // 2 + 1 + (input_lengths // 100) * 13
|
||||||
|
else:
|
||||||
input_lengths = (mm_inputs["feature_attention_mask"].sum(-1).numpy() - 1) // 2 + 1
|
input_lengths = (mm_inputs["feature_attention_mask"].sum(-1).numpy() - 1) // 2 + 1
|
||||||
audio_lengths = (input_lengths - 2) // 2 + 1
|
audio_lengths = (input_lengths - 2) // 2 + 1
|
||||||
else:
|
else:
|
||||||
|
|||||||
Reference in New Issue
Block a user