Skip to content

Commit

Permalink
Added default jpeg quality constant, added test
Browse files Browse the repository at this point in the history
  • Loading branch information
brandon-groundlight committed Sep 22, 2023
1 parent f395a95 commit dde012a
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 4 deletions.
10 changes: 6 additions & 4 deletions src/groundlight/images.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

from groundlight.optional_imports import Image, np

DEFAULT_JPEG_QUALITY = 95


class ByteStreamWrapper(IOBase):
"""This class acts as a thin wrapper around bytes in order to
Expand All @@ -28,11 +30,11 @@ def close(self) -> None:
pass


def bytestream_from_filename(image_filename: str, jpeg_quality) -> ByteStreamWrapper:
def bytestream_from_filename(image_filename: str, jpeg_quality: int = DEFAULT_JPEG_QUALITY) -> ByteStreamWrapper:
"""Determines what to do with an arbitrary filename
Only supports JPEG and PNG files for now.
For PNG files, we convert to a JPEG.
For PNG files, we convert to RGB format used in JPEGs.
"""
if imghdr.what(image_filename) == "jpeg":
buffer = buffer_from_jpeg_file(image_filename)
Expand All @@ -57,7 +59,7 @@ def buffer_from_jpeg_file(image_filename: str) -> BufferedReader:
raise ValueError("We only support JPEG files, for now.")


def jpeg_from_numpy(img: np.ndarray, jpeg_quality: int = 95) -> bytes:
def jpeg_from_numpy(img: np.ndarray, jpeg_quality: int = DEFAULT_JPEG_QUALITY) -> bytes:
"""Converts a numpy array to BytesIO."""
pilim = Image.fromarray(img.astype("uint8"), "RGB")
with BytesIO() as buf:
Expand All @@ -66,7 +68,7 @@ def jpeg_from_numpy(img: np.ndarray, jpeg_quality: int = 95) -> bytes:
return out


def bytestream_from_pil(pil_image: Image.Image, jpeg_quality: int = 95) -> ByteStreamWrapper:
def bytestream_from_pil(pil_image: Image.Image, jpeg_quality: int = DEFAULT_JPEG_QUALITY) -> ByteStreamWrapper:
"""Converts a PIL image to a BytesIO."""
bytesio = BytesIO()
pil_image.save(bytesio, "jpeg", quality=jpeg_quality)
Expand Down
17 changes: 17 additions & 0 deletions test/unit/test_imagefuncs.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from io import BytesIO

import pytest
import PIL
from groundlight.images import *
from groundlight.optional_imports import *

Expand All @@ -26,6 +27,22 @@ def test_jpeg_from_numpy():
assert len(jpeg2) > len(jpeg3)


def test_bytestream_from_filename():
images_streams = []
images_streams.append(bytestream_from_filename("test/assets/cat.jpeg"))
images_streams.append(bytestream_from_filename("test/assets/cat.png"))
images_streams.append(bytestream_from_filename("test/assets/cat.png", jpeg_quality=95))
for i in images_streams:
assert isinstance(i, ByteStreamWrapper)
image = Image.open(i)
assert image.mode == "RGB"

# pixel based test, verified the image is correct by eye, then got a pixel whose value to check against
png_bytestream = bytestream_from_filename("test/assets/cat.png", jpeg_quality=95)
png_image = Image.open(png_bytestream)
assert png_image.getpixel((200, 200)) == (215, 209, 197)


def test_unsupported_image_type():
with pytest.raises(TypeError):
parse_supported_image_types(1) # type: ignore
Expand Down

0 comments on commit dde012a

Please sign in to comment.