add e2e tests
Former-commit-id: 0156a37450604641c4f5f9756ad84324698fc88c
This commit is contained in:
@@ -19,7 +19,6 @@ if is_pyav_available():
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import torch
|
||||
from numpy.typing import NDArray
|
||||
from transformers import PreTrainedTokenizer, ProcessorMixin
|
||||
from transformers.image_processing_utils import BaseImageProcessor
|
||||
|
||||
@@ -31,11 +30,17 @@ if TYPE_CHECKING:
|
||||
VideoInput = str
|
||||
|
||||
|
||||
def _regularize_images(images: Sequence["ImageInput"], processor: "ProcessorMixin") -> List["ImageObject"]:
|
||||
def _regularize_images(
|
||||
images: Sequence["ImageInput"],
|
||||
processor: "ProcessorMixin",
|
||||
max_resolution: Optional[int] = None,
|
||||
) -> List["ImageObject"]:
|
||||
r"""
|
||||
Regularizes images to avoid error. Including reading, resizing and converting.
|
||||
"""
|
||||
image_resolution: int = getattr(processor, "image_resolution", 512)
|
||||
if max_resolution is None:
|
||||
max_resolution: int = getattr(processor, "image_resolution", 512)
|
||||
|
||||
results = []
|
||||
for image in images:
|
||||
if isinstance(image, str):
|
||||
@@ -49,9 +54,9 @@ def _regularize_images(images: Sequence["ImageInput"], processor: "ProcessorMixi
|
||||
if not isinstance(image, ImageObject):
|
||||
raise ValueError("Expect input is a list of Images, but got {}.".format(type(image)))
|
||||
|
||||
if max(image.width, image.height) > image_resolution:
|
||||
factor = image_resolution / max(image.width, image.height)
|
||||
image = image.resize((int(image.width * factor), int(image.height * factor)))
|
||||
if max(image.width, image.height) > max_resolution:
|
||||
factor = max_resolution / max(image.width, image.height)
|
||||
image = image.resize((int(image.width * factor), int(image.height * factor)), resample=Image.NEAREST)
|
||||
|
||||
if image.mode != "RGB":
|
||||
image = image.convert("RGB")
|
||||
@@ -61,11 +66,16 @@ def _regularize_images(images: Sequence["ImageInput"], processor: "ProcessorMixi
|
||||
return results
|
||||
|
||||
|
||||
def _regularize_videos(videos: Sequence["VideoInput"], processor: "ProcessorMixin") -> List["NDArray"]:
|
||||
def _regularize_videos(
|
||||
videos: Sequence["VideoInput"],
|
||||
processor: "ProcessorMixin",
|
||||
) -> List[List["ImageObject"]]:
|
||||
r"""
|
||||
Regularizes videos to avoid error. Including reading, resizing and converting.
|
||||
"""
|
||||
video_resolution: int = getattr(processor, "video_resolution", 128)
|
||||
video_fps: float = getattr(processor, "video_fps", 1.0)
|
||||
video_maxlen: int = getattr(processor, "video_maxlen", 64)
|
||||
video_factor: int = getattr(processor, "video_factor", 1)
|
||||
results = []
|
||||
for video in videos:
|
||||
@@ -73,6 +83,7 @@ def _regularize_videos(videos: Sequence["VideoInput"], processor: "ProcessorMixi
|
||||
video_stream = next(stream for stream in container.streams if stream.type == "video")
|
||||
total_frames = video_stream.frames
|
||||
sample_frames = float(video_stream.duration * video_stream.time_base) * video_fps
|
||||
sample_frames = min(video_maxlen, sample_frames) # reduce length <= maxlen
|
||||
sample_frames = round(sample_frames / video_factor) * video_factor # for qwen2_vl
|
||||
sample_indices = np.linspace(0, total_frames - 1, sample_frames).astype(np.int32)
|
||||
frames: List["ImageObject"] = []
|
||||
@@ -81,7 +92,7 @@ def _regularize_videos(videos: Sequence["VideoInput"], processor: "ProcessorMixi
|
||||
if frame_idx in sample_indices:
|
||||
frames.append(frame.to_image())
|
||||
|
||||
frames = _regularize_images(frames, processor)
|
||||
frames = _regularize_images(frames, processor, video_resolution)
|
||||
results.append(frames)
|
||||
|
||||
return results
|
||||
|
||||
@@ -562,8 +562,8 @@ _register_template(
|
||||
_register_template(
|
||||
name="cpm3",
|
||||
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
||||
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
||||
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
|
||||
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
||||
stop_words=["<|im_end|>"],
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user