mirror of
https://github.com/hiyouga/LlamaFactory.git
synced 2026-04-03 16:23:07 +00:00
[data] fix qwen3vl timestamp (#10338)
This commit is contained in:
@@ -1489,10 +1489,11 @@ class Qwen2VLPlugin(BasePlugin):
|
|||||||
|
|
||||||
@override
|
@override
|
||||||
def _regularize_videos(self, videos: list["VideoInput"], **kwargs) -> "RegularizedVideoOutput":
|
def _regularize_videos(self, videos: list["VideoInput"], **kwargs) -> "RegularizedVideoOutput":
|
||||||
results, fps_per_video, durations = [], [], []
|
results, fps_per_video, durations, frames_indices = [], [], [], []
|
||||||
for video in videos:
|
for video in videos:
|
||||||
frames: list[ImageObject] = []
|
frames: list[ImageObject] = []
|
||||||
if _check_video_is_nested_images(video):
|
if _check_video_is_nested_images(video):
|
||||||
|
# we assume the frames have already been sampled from the video
|
||||||
for frame in video:
|
for frame in video:
|
||||||
if not is_valid_image(frame) and not isinstance(frame, dict) and not os.path.exists(frame):
|
if not is_valid_image(frame) and not isinstance(frame, dict) and not os.path.exists(frame):
|
||||||
raise ValueError("Invalid image found in video frames.")
|
raise ValueError("Invalid image found in video frames.")
|
||||||
@@ -1500,10 +1501,14 @@ class Qwen2VLPlugin(BasePlugin):
|
|||||||
frames = video
|
frames = video
|
||||||
fps_per_video.append(kwargs.get("video_fps", 2.0))
|
fps_per_video.append(kwargs.get("video_fps", 2.0))
|
||||||
durations.append(len(frames) / kwargs.get("video_fps", 2.0))
|
durations.append(len(frames) / kwargs.get("video_fps", 2.0))
|
||||||
|
frames_indices.append(list(range(len(frames))))
|
||||||
else:
|
else:
|
||||||
container = av.open(video, "r")
|
container = av.open(video, "r")
|
||||||
video_stream = next(stream for stream in container.streams if stream.type == "video")
|
video_stream = next(stream for stream in container.streams if stream.type == "video")
|
||||||
sample_indices = self._get_video_sample_indices(video_stream, **kwargs)
|
sample_indices = self._get_video_sample_indices(video_stream, **kwargs)
|
||||||
|
original_fps = float(video_stream.average_rate)
|
||||||
|
# for qwen3vl video timestamp calculation
|
||||||
|
frames_indices.append([idx / original_fps * kwargs.get("video_fps", 2.0) for idx in sample_indices]) # hack usage when do_sample_frames=False
|
||||||
container.seek(0)
|
container.seek(0)
|
||||||
for frame_idx, frame in enumerate(container.decode(video_stream)):
|
for frame_idx, frame in enumerate(container.decode(video_stream)):
|
||||||
if frame_idx in sample_indices:
|
if frame_idx in sample_indices:
|
||||||
@@ -1522,7 +1527,7 @@ class Qwen2VLPlugin(BasePlugin):
|
|||||||
frames = self._regularize_images(frames, **kwargs)["images"]
|
frames = self._regularize_images(frames, **kwargs)["images"]
|
||||||
results.append(frames)
|
results.append(frames)
|
||||||
|
|
||||||
return {"videos": results, "fps_per_video": fps_per_video, "durations": durations}
|
return {"videos": results, "fps_per_video": fps_per_video, "durations": durations, "frames_indices": frames_indices}
|
||||||
|
|
||||||
@override
|
@override
|
||||||
def _get_mm_inputs(
|
def _get_mm_inputs(
|
||||||
@@ -1637,8 +1642,8 @@ class Qwen3VLPlugin(Qwen2VLPlugin):
|
|||||||
video_maxlen=getattr(processor, "video_maxlen", 128),
|
video_maxlen=getattr(processor, "video_maxlen", 128),
|
||||||
)
|
)
|
||||||
video_metadata = [
|
video_metadata = [
|
||||||
{"fps": getattr(processor, "video_fps", 24.0), "duration": duration, "total_num_frames": len(video)}
|
{"fps": getattr(processor, "video_fps", 24.0), "duration": duration, "total_num_frames": len(video), "frames_indices": sample_indices}
|
||||||
for video, duration in zip(videos["videos"], videos["durations"])
|
for video, duration, sample_indices in zip(videos["videos"], videos["durations"], videos["frames_indices"])
|
||||||
]
|
]
|
||||||
mm_inputs.update(
|
mm_inputs.update(
|
||||||
video_processor(
|
video_processor(
|
||||||
@@ -1646,6 +1651,7 @@ class Qwen3VLPlugin(Qwen2VLPlugin):
|
|||||||
video_metadata=video_metadata,
|
video_metadata=video_metadata,
|
||||||
fps=getattr(processor, "video_fps", 2.0),
|
fps=getattr(processor, "video_fps", 2.0),
|
||||||
return_metadata=True,
|
return_metadata=True,
|
||||||
|
do_sample_frames=False, # avoid changing frames_indices
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
temporal_patch_size: int = getattr(image_processor, "temporal_patch_size", 2)
|
temporal_patch_size: int = getattr(image_processor, "temporal_patch_size", 2)
|
||||||
|
|||||||
Reference in New Issue
Block a user