From bf03325c967e6102e26734dcde813902e32b977a Mon Sep 17 00:00:00 2001 From: WyattBlue Date: Wed, 4 Dec 2024 01:01:29 -0500 Subject: [PATCH] Address #1663 --- av/video/reformatter.pyi | 3 ++- av/video/reformatter.pyx | 24 +++++++++++++++++++----- tests/test_videoframe.py | 29 +++++++++++++++++++++++++++-- 3 files changed, 48 insertions(+), 8 deletions(-) diff --git a/av/video/reformatter.pyi b/av/video/reformatter.pyi index a601dd335..fd5dbd053 100644 --- a/av/video/reformatter.pyi +++ b/av/video/reformatter.pyi @@ -27,7 +27,8 @@ class Colorspace(IntEnum): fcc: int itu601: int itu624: int - smpte240: int + smpte170m: int + smpte240m: int default: int class ColorRange(IntEnum): diff --git a/av/video/reformatter.pyx b/av/video/reformatter.pyx index 4511d08af..538cc4953 100644 --- a/av/video/reformatter.pyx +++ b/av/video/reformatter.pyx @@ -45,6 +45,19 @@ class ColorRange(IntEnum): NB: "Not part of ABI" = lib.AVCOL_RANGE_NB +def _resolve_enum_value(value, enum_class, default): + # Helper function to resolve enum values from different input types. + if value is None: + return default + if isinstance(value, enum_class): + return value.value + if isinstance(value, int): + return value + if isinstance(value, str): + return enum_class[value].value + raise ValueError(f"Cannot convert {value} to {enum_class.__name__}") + + cdef class VideoReformatter: """An object for reformatting size and pixel format of :class:`.VideoFrame`. @@ -83,11 +96,12 @@ cdef class VideoReformatter: """ cdef VideoFormat video_format = VideoFormat(format if format is not None else frame.format) - cdef int c_src_colorspace = (Colorspace[src_colorspace].value if src_colorspace is not None else frame.colorspace) - cdef int c_dst_colorspace = (Colorspace[dst_colorspace].value if dst_colorspace is not None else frame.colorspace) - cdef int c_interpolation = (Interpolation[interpolation] if interpolation is not None else Interpolation.BILINEAR).value - cdef int c_src_color_range = (ColorRange[src_color_range].value if src_color_range is not None else 0) - cdef int c_dst_color_range = (ColorRange[dst_color_range].value if dst_color_range is not None else 0) + + cdef int c_src_colorspace = _resolve_enum_value(src_colorspace, Colorspace, frame.colorspace) + cdef int c_dst_colorspace = _resolve_enum_value(dst_colorspace, Colorspace, frame.colorspace) + cdef int c_interpolation = _resolve_enum_value(interpolation, Interpolation, int(Interpolation.BILINEAR)) + cdef int c_src_color_range = _resolve_enum_value(src_color_range, ColorRange, 0) + cdef int c_dst_color_range = _resolve_enum_value(dst_color_range, ColorRange, 0) return self._reformat( frame, diff --git a/tests/test_videoframe.py b/tests/test_videoframe.py index c93a12e32..f044be949 100644 --- a/tests/test_videoframe.py +++ b/tests/test_videoframe.py @@ -7,6 +7,7 @@ import av from av import VideoFrame +from av.video.reformatter import ColorRange, Colorspace, Interpolation from .common import ( TestCase, @@ -145,6 +146,24 @@ def test_roundtrip(self) -> None: img.save(self.sandboxed("roundtrip-high.jpg")) assertImagesAlmostEqual(image, img) + def test_interpolation(self) -> None: + import PIL.Image as Image + + image = Image.open(fate_png()) + frame = VideoFrame.from_image(image) + assert frame.width == 330 and frame.height == 330 + + img = frame.to_image(width=200, height=100, interpolation=Interpolation.BICUBIC) + assert img.width == 200 and img.height == 100 + + img = frame.to_image(width=200, height=100, interpolation="BICUBIC") + assert img.width == 200 and img.height == 100 + + img = frame.to_image( + width=200, height=100, interpolation=int(Interpolation.BICUBIC) + ) + assert img.width == 200 and img.height == 100 + def test_to_image_rgb24(self) -> None: sizes = [(318, 238), (320, 240), (500, 500)] for width, height in sizes: @@ -838,14 +857,20 @@ def test_reformat_identity() -> None: def test_reformat_colorspace() -> None: - # This is allowed. frame = VideoFrame(640, 480, "rgb24") frame.reformat(src_colorspace=None, dst_colorspace="smpte240m") - # I thought this was not allowed, but it seems to be. + frame = VideoFrame(640, 480, "rgb24") + frame.reformat(src_colorspace=None, dst_colorspace=Colorspace.smpte240m) + frame = VideoFrame(640, 480, "yuv420p") frame.reformat(src_colorspace=None, dst_colorspace="smpte240m") + frame = VideoFrame(640, 480, "rgb24") + frame.colorspace = Colorspace.smpte240m + assert frame.colorspace == int(Colorspace.smpte240m) + assert frame.colorspace == Colorspace.smpte240m + def test_reformat_pixel_format_align() -> None: height = 480