diff --git a/src/llamafactory/data/mm_plugin.py b/src/llamafactory/data/mm_plugin.py index 430fa885b..06aec7c8b 100644 --- a/src/llamafactory/data/mm_plugin.py +++ b/src/llamafactory/data/mm_plugin.py @@ -1987,6 +1987,7 @@ class Qwen2OmniPlugin(Qwen2VLPlugin): f"Each {VIDEO_PLACEHOLDER} must be followed by an {AUDIO_PLACEHOLDER} when using audio in video." ) + position_id_per_seconds: int = getattr(processor, "position_id_per_seconds", 25) audio_t_index = torch.arange(audio_lengths[num_audio_tokens]) video_t_index = ( torch.arange(video_grid_thw[num_video_tokens][0]) @@ -1998,9 +1999,9 @@ class Qwen2OmniPlugin(Qwen2VLPlugin): ) .flatten() * mm_inputs["video_second_per_grid"][num_video_tokens] - * 25 # FIXME hardcode of position_id_per_seconds=25 + * position_id_per_seconds ).long() - t_ntoken_per_chunk = 50 # FIXME hardcode: [25 * 2] + t_ntoken_per_chunk = position_id_per_seconds * 2 video_chunk_indices = processor.get_chunked_index(video_t_index, t_ntoken_per_chunk) audio_chunk_indices = processor.get_chunked_index(audio_t_index, t_ntoken_per_chunk) placeholder_string = ""