diff --git a/src/groundlight/images.py b/src/groundlight/images.py index 41b2c704..50e60295 100644 --- a/src/groundlight/images.py +++ b/src/groundlight/images.py @@ -4,6 +4,8 @@ from groundlight.optional_imports import Image, np +DEFAULT_JPEG_QUALITY = 95 + class ByteStreamWrapper(IOBase): """This class acts as a thin wrapper around bytes in order to @@ -28,11 +30,11 @@ def close(self) -> None: pass -def bytestream_from_filename(image_filename: str, jpeg_quality) -> ByteStreamWrapper: +def bytestream_from_filename(image_filename: str, jpeg_quality: int = DEFAULT_JPEG_QUALITY) -> ByteStreamWrapper: """Determines what to do with an arbitrary filename Only supports JPEG and PNG files for now. - For PNG files, we convert to a JPEG. + For PNG files, we convert to RGB format used in JPEGs. """ if imghdr.what(image_filename) == "jpeg": buffer = buffer_from_jpeg_file(image_filename) @@ -57,7 +59,7 @@ def buffer_from_jpeg_file(image_filename: str) -> BufferedReader: raise ValueError("We only support JPEG files, for now.") -def jpeg_from_numpy(img: np.ndarray, jpeg_quality: int = 95) -> bytes: +def jpeg_from_numpy(img: np.ndarray, jpeg_quality: int = DEFAULT_JPEG_QUALITY) -> bytes: """Converts a numpy array to BytesIO.""" pilim = Image.fromarray(img.astype("uint8"), "RGB") with BytesIO() as buf: @@ -66,7 +68,7 @@ def jpeg_from_numpy(img: np.ndarray, jpeg_quality: int = 95) -> bytes: return out -def bytestream_from_pil(pil_image: Image.Image, jpeg_quality: int = 95) -> ByteStreamWrapper: +def bytestream_from_pil(pil_image: Image.Image, jpeg_quality: int = DEFAULT_JPEG_QUALITY) -> ByteStreamWrapper: """Converts a PIL image to a BytesIO.""" bytesio = BytesIO() pil_image.save(bytesio, "jpeg", quality=jpeg_quality) diff --git a/test/unit/test_imagefuncs.py b/test/unit/test_imagefuncs.py index a159639d..961e0945 100644 --- a/test/unit/test_imagefuncs.py +++ b/test/unit/test_imagefuncs.py @@ -5,6 +5,7 @@ from io import BytesIO import pytest +import PIL from groundlight.images import * from groundlight.optional_imports import * @@ -26,6 +27,22 @@ def test_jpeg_from_numpy(): assert len(jpeg2) > len(jpeg3) +def test_bytestream_from_filename(): + images_streams = [] + images_streams.append(bytestream_from_filename("test/assets/cat.jpeg")) + images_streams.append(bytestream_from_filename("test/assets/cat.png")) + images_streams.append(bytestream_from_filename("test/assets/cat.png", jpeg_quality=95)) + for i in images_streams: + assert isinstance(i, ByteStreamWrapper) + image = Image.open(i) + assert image.mode == "RGB" + + # pixel based test, verified the image is correct by eye, then got a pixel whose value to check against + png_bytestream = bytestream_from_filename("test/assets/cat.png", jpeg_quality=95) + png_image = Image.open(png_bytestream) + assert png_image.getpixel((200, 200)) == (215, 209, 197) + + def test_unsupported_image_type(): with pytest.raises(TypeError): parse_supported_image_types(1) # type: ignore