diff --git a/src/llamafactory/data/mm_plugin.py b/src/llamafactory/data/mm_plugin.py index badd29ac6..fe87f53e8 100644 --- a/src/llamafactory/data/mm_plugin.py +++ b/src/llamafactory/data/mm_plugin.py @@ -1489,10 +1489,11 @@ class Qwen2VLPlugin(BasePlugin): @override def _regularize_videos(self, videos: list["VideoInput"], **kwargs) -> "RegularizedVideoOutput": - results, fps_per_video, durations = [], [], [] + results, fps_per_video, durations, frames_indices = [], [], [], [] for video in videos: frames: list[ImageObject] = [] if _check_video_is_nested_images(video): + # we assume already sample frames from videos for frame in video: if not is_valid_image(frame) and not isinstance(frame, dict) and not os.path.exists(frame): raise ValueError("Invalid image found in video frames.") @@ -1500,10 +1501,14 @@ class Qwen2VLPlugin(BasePlugin): frames = video fps_per_video.append(kwargs.get("video_fps", 2.0)) durations.append(len(frames) / kwargs.get("video_fps", 2.0)) + frames_indices.append(list(range(len(frames)))) else: container = av.open(video, "r") video_stream = next(stream for stream in container.streams if stream.type == "video") - sample_indices = self._get_video_sample_indices(video_stream, **kwargs) + sample_indices = self._get_video_sample_indices(video_stream, **kwargs) + original_fps = float(video_stream.average_rate) + # for qwen3vl video timestamp calculation + frames_indices.append([idx / original_fps * kwargs.get("video_fps", 2.0) for idx in sample_indices]) # hack usage when do_sample_frames=False container.seek(0) for frame_idx, frame in enumerate(container.decode(video_stream)): if frame_idx in sample_indices: @@ -1522,7 +1527,7 @@ class Qwen2VLPlugin(BasePlugin): frames = self._regularize_images(frames, **kwargs)["images"] results.append(frames) - return {"videos": results, "fps_per_video": fps_per_video, "durations": durations} + return {"videos": results, "fps_per_video": fps_per_video, "durations": durations, "frames_indices": frames_indices} @override def _get_mm_inputs( @@ -1637,8 +1642,8 @@ class Qwen3VLPlugin(Qwen2VLPlugin): video_maxlen=getattr(processor, "video_maxlen", 128), ) video_metadata = [ - {"fps": getattr(processor, "video_fps", 24.0), "duration": duration, "total_num_frames": len(video)} - for video, duration in zip(videos["videos"], videos["durations"]) + {"fps": getattr(processor, "video_fps", 24.0), "duration": duration, "total_num_frames": len(video), "frames_indices": sample_indices} + for video, duration, sample_indices in zip(videos["videos"], videos["durations"], videos["frames_indices"]) ] mm_inputs.update( video_processor( @@ -1646,6 +1651,7 @@ class Qwen3VLPlugin(Qwen2VLPlugin): video_metadata=video_metadata, fps=getattr(processor, "video_fps", 2.0), return_metadata=True, + do_sample_frames=False, # avoid changing frames_indices ) ) temporal_patch_size: int = getattr(image_processor, "temporal_patch_size", 2)