From 45afeb41614e293e2a7c12df9a1bc10125f3e94d Mon Sep 17 00:00:00 2001
From: John Brock
Date: Sun, 18 Jun 2023 14:32:10 -0700
Subject: [PATCH] add zlib support to Decompressor

---
 test/test_local_io.py                         |  88 ++++++++++++++
 torchdata/datapipes/iter/util/decompressor.py | 112 +++++++++++++++++-
 2 files changed, 199 insertions(+), 1 deletion(-)

diff --git a/test/test_local_io.py b/test/test_local_io.py
index 313b5e361..8704302e0 100644
--- a/test/test_local_io.py
+++ b/test/test_local_io.py
@@ -18,6 +18,7 @@
 import unittest
 import warnings
 import zipfile
+import zlib
 
 from functools import partial
 from json.decoder import JSONDecodeError
@@ -51,6 +52,7 @@
     XzFileLoader,
     ZipArchiveLoader,
 )
+from torchdata.datapipes.iter.util.decompressor import _ZlibFile
 
 try:
     import iopath
@@ -557,6 +559,13 @@ def _write_single_gz_file(self):
             with open(self.temp_files[0], "rb") as f:
                 k.write(f.read())
 
+    def _write_single_zlib_file(self):
+        import zlib
+
+        with open(f"{self.temp_dir.name}/temp.zlib", "wb") as k:
+            with open(self.temp_files[0], "rb") as f:
+                k.write(zlib.compress(f.read()))
+
     def test_decompressor_iterdatapipe(self):
         self._write_test_tar_files()
         self._write_test_tar_gz_files()
@@ -564,6 +573,7 @@ def test_decompressor_iterdatapipe(self):
         self._write_test_zip_files()
         self._write_test_xz_files()
         self._write_test_bz2_files()
+        self._write_single_zlib_file()
 
         # Functional Test: work with .tar files
         tar_file_dp = FileLister(self.temp_dir.name, "*.tar")
@@ -606,6 +616,14 @@ def test_decompressor_iterdatapipe(self):
         bz2_decompress_dp = Decompressor(bz2_load_dp, file_type="bz2")
         self._decompressor_bz2_test_helper(bz2_decompress_dp)
 
+        # Functional Test: work with .zlib files
+        zlib_file_dp = IterableWrapper([f"{self.temp_dir.name}/temp.zlib"])
+        zlib_load_dp = FileOpener(zlib_file_dp, mode="b")
+        zlib_decompress_dp = Decompressor(zlib_load_dp, file_type="zlib")
+        for _, zlib_stream in zlib_decompress_dp:
+            with open(self.temp_files[0], "rb") as f:
+                self.assertEqual(f.read(), zlib_stream.read())
+
         # Functional Test: work without file type as input for .tar files
         tar_decompress_dp = Decompressor(tar_load_dp, file_type=None)
         self._decompressor_tar_test_helper(self.temp_files, tar_decompress_dp)
@@ -622,6 +640,12 @@ def test_decompressor_iterdatapipe(self):
         bz2_decompress_dp = Decompressor(bz2_load_dp, file_type=None)
         self._decompressor_bz2_test_helper(bz2_decompress_dp)
 
+        # Functional Test: work without file type as input for .zlib files
+        zlib_decompress_dp = Decompressor(zlib_load_dp, file_type=None)
+        for _, zlib_stream in zlib_decompress_dp:
+            with open(self.temp_files[0], "rb") as f:
+                self.assertEqual(f.read(), zlib_stream.read())
+
         # Functional Test: Compression Type is works for both upper and lower case strings
         tar_decompress_dp = Decompressor(tar_load_dp, file_type="TAr")
         self._decompressor_tar_test_helper(self.temp_files, tar_decompress_dp)
@@ -640,6 +664,70 @@ def test_decompressor_iterdatapipe(self):
         with self.assertRaisesRegex(TypeError, "has no len"):
             len(tar_decompress_dp)
 
+    def test_zlibfile_readall(self):
+        uncompressed_data_test_cases = [b"", b"some data", 10_000 * b"some data"]
+        for uncompressed_data in uncompressed_data_test_cases:
+            compressed_file = _ZlibFile(io.BytesIO(zlib.compress(uncompressed_data)))
+            self.assertEqual(compressed_file.readall(), uncompressed_data)
+
+    def test_zlibfile_read(self):
+        uncompressed_data_test_cases = [b"", b"some data", 10_000 * b"some data"]
+        num_bytes_to_read_test_cases = [-1, 0, 1, 2, 64_000, 128_000]
+        for uncompressed_data in uncompressed_data_test_cases:
+            for num_bytes_to_read in num_bytes_to_read_test_cases:
+                compressed_file = _ZlibFile(io.BytesIO(zlib.compress(uncompressed_data)))
+                result = bytearray()
+                chunk = compressed_file.read(num_bytes_to_read)
+                while chunk:
+                    result.extend(chunk)
+                    chunk = compressed_file.read(num_bytes_to_read)
+                if num_bytes_to_read == 0:
+                    self.assertEqual(result, b"")
+                else:
+                    self.assertEqual(result, uncompressed_data)
+
+    def test_zlibfile_stream_ends_prematurely(self):
+        compressed_bytes = zlib.compress(b"some data")
+        # slice compressed bytes so that the stream ends prematurely
+        compressed_file = _ZlibFile(io.BytesIO(compressed_bytes[:-2]))
+        with self.assertRaises(EOFError):
+            compressed_file.read()
+
+    def test_zlibfile_iteration(self):
+        # Ensure there are at least io.DEFAULT_BUFFER_SIZE bytes so that multiple read calls are
+        # performed under-the-hood
+        uncompressed_bytes = b"1234\n56\n\n78\n" + b"9" * io.DEFAULT_BUFFER_SIZE
+        compressed_bytes = zlib.compress(uncompressed_bytes)
+
+        # Test _ZlibFile.__next__
+        compressed_file = _ZlibFile(io.BytesIO(compressed_bytes))
+        self.assertEqual(next(compressed_file), b"1234\n")
+        self.assertEqual(next(compressed_file), b"56\n")
+        self.assertEqual(next(compressed_file), b"\n")
+        self.assertEqual(next(compressed_file), b"78\n")
+        self.assertEqual(next(compressed_file), b"9" * io.DEFAULT_BUFFER_SIZE)
+        with self.assertRaises(StopIteration):
+            next(compressed_file)
+
+        # Test _ZlibFile iterator creation as performed in StreamWrapper
+        def create_iterator():
+            yield from _ZlibFile(io.BytesIO(compressed_bytes))
+
+        self.assertEqual(list(create_iterator()), [b"1234\n", b"56\n", b"\n", b"78\n", b"9" * io.DEFAULT_BUFFER_SIZE])
+
+        # Test that interleaving calls to `read` with calls to `next` works as expected
+        compressed_file = _ZlibFile(io.BytesIO(compressed_bytes))
+        compressed_file.read(2)
+        self.assertEqual(next(compressed_file), b"34\n")
+        self.assertEqual(compressed_file.read(5), b"56\n\n7")
+        self.assertEqual(next(compressed_file), b"8\n")
+        self.assertEqual(compressed_file.read(3), b"999")
+        self.assertEqual(next(compressed_file), b"9" * (io.DEFAULT_BUFFER_SIZE - 3))
+        with self.assertRaises(StopIteration):
+            next(compressed_file)
+        self.assertEqual(compressed_file.read(1), b"")
+        self.assertEqual(compressed_file.read(), b"")
+
     def _write_text_files(self):
         name_to_data = {"1.text": b"DATA", "2.text": b"DATA", "3.text": b"DATA"}
         source_dp = IterableWrapper(sorted(name_to_data.items()))
diff --git a/torchdata/datapipes/iter/util/decompressor.py b/torchdata/datapipes/iter/util/decompressor.py
index aafcb7144..849af6441 100644
--- a/torchdata/datapipes/iter/util/decompressor.py
+++ b/torchdata/datapipes/iter/util/decompressor.py
@@ -6,11 +6,14 @@
 
 import bz2
 import gzip
+import io
 import lzma
 import os
 import pathlib
+import sys
 import tarfile
 import zipfile
+import zlib
 from enum import Enum
 from io import IOBase
 
@@ -27,6 +30,7 @@ class CompressionType(Enum):
     TAR = "tar"
     ZIP = "zip"
     BZIP2 = "bz2"
+    ZLIB = "zlib"
 
 
 @functional_datapipe("decompress")
@@ -38,7 +42,7 @@ class DecompressorIterDataPipe(IterDataPipe[Tuple[str, StreamWrapper]]):
 
     Args:
         source_datapipe: IterDataPipe containing tuples of path and compressed stream of data
-        file_type: Optional `string` or ``CompressionType`` that represents what compression format of the inputs
+        file_type: Optional `string` or ``CompressionType`` that represents the compression format of the inputs
 
     Example:
         >>> from torchdata.datapipes.iter import FileLister, FileOpener
@@ -58,6 +62,7 @@ class DecompressorIterDataPipe(IterDataPipe[Tuple[str, StreamWrapper]]):
         types.TAR: lambda file: tarfile.open(fileobj=file, mode="r:*"),
         types.ZIP: lambda file: zipfile.ZipFile(file=file),
         types.BZIP2: lambda file: bz2.BZ2File(filename=file),
+        types.ZLIB: lambda file: _ZlibFile(file),
     }
 
     def __init__(
@@ -87,6 +92,8 @@ def _detect_compression_type(self, path: str) -> CompressionType:
             return self.types.ZIP
         elif ext == ".bz2":
             return self.types.BZIP2
+        elif ext == ".zlib":
+            return self.types.ZLIB
         else:
             raise RuntimeError(
                 f"File at {path} has file extension {ext}, which does not match what are supported by"
@@ -114,3 +121,106 @@ def __new__(
         cls, source_datapipe: IterDataPipe[Tuple[str, IOBase]], file_type: Optional[Union[str, CompressionType]] = None
     ):
         return DecompressorIterDataPipe(source_datapipe, file_type)
+
+
+class _ZlibFile:
+    """
+    A minimal read-only file object for decompressing zlib data. It's only intended to be wrapped by
+    StreamWrapper and isn't intended to be used outside decompressor.py. It only supports the
+    specific operations expected by StreamWrapper.
+    """
+
+    def __init__(self, file) -> None:
+        self._decompressor = zlib.decompressobj()
+        self._file = file
+
+        # Stores decompressed bytes leftover from a call to __next__: since __next__ only returns
+        # up until the next newline, we need to store the bytes from after the newline for future
+        # calls to read or __next__.
+        self._buffer = bytearray()
+
+        # Whether or not self._file still has bytes left to read
+        self._file_exhausted = False
+
+    def read(self, size: int = -1) -> bytearray:
+        if size < 0:
+            return self.readall()
+
+        if not size:
+            return bytearray()
+
+        result = self._buffer[:size]
+        self._buffer = self._buffer[size:]
+        while len(result) < size and self._compressed_bytes_remain():
+            # If decompress was called previously, there might be some compressed bytes from a previous chunk
+            # that haven't been decompressed yet (because decompress() was passed a max_length value that
+            # didn't exhaust the compressed bytes in the chunk). We can retrieve these leftover bytes from
+            # unconsumed_tail:
+            chunk = self._decompressor.unconsumed_tail
+            # Let's read compressed bytes in chunks of io.DEFAULT_BUFFER_SIZE because this is what python's gzip
+            # library does:
+            # https://github.com/python/cpython/blob/a6326972253bf5282c5bf422f4a16d93ace77b57/Lib/gzip.py#L505
+            if len(chunk) < io.DEFAULT_BUFFER_SIZE:
+                compressed_bytes = self._file.read(io.DEFAULT_BUFFER_SIZE - len(chunk))
+                if compressed_bytes:
+                    chunk += compressed_bytes
+                else:
+                    self._file_exhausted = True
+            decompressed_chunk = self._decompressor.decompress(chunk, max_length=size - len(result))
+            result.extend(decompressed_chunk)
+        if not self._compressed_bytes_remain() and not self._decompressor.eof:
+            # There are no more compressed bytes available to decompress, but we haven't reached the
+            # zlib EOF, so something is wrong.
+            raise EOFError("Compressed file ended before the end-of-stream marker was reached")
+        return result
+
+    def readall(self):
+        """
+        This is just mimicking python's internal DecompressReader.readall:
+        https://github.com/python/cpython/blob/a6326972253bf5282c5bf422f4a16d93ace77b57/Lib/_compression.py#L113
+        """
+        chunks = []
+        # sys.maxsize means the max length of output buffer is unlimited,
+        # so that the whole input buffer can be decompressed within one
+        # .decompress() call.
+        data = self.read(sys.maxsize)
+        while data:
+            chunks.append(data)
+            data = self.read(sys.maxsize)
+
+        return b"".join(chunks)
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        if not self._buffer and not self._compressed_bytes_remain():
+            raise StopIteration
+
+        # Check if the buffer already has a newline, in which case we don't need to do any decompression of
+        # remaining bytes
+        newline_index = self._buffer.find(b"\n")
+        if newline_index != -1:
+            line = self._buffer[: newline_index + 1]
+            self._buffer = self._buffer[newline_index + 1 :]
+            return line
+
+        # Keep decompressing bytes until we find a newline or run out of bytes
+        line = self._buffer
+        self._buffer = bytearray()
+        while self._compressed_bytes_remain():
+            decompressed_chunk = self.read(io.DEFAULT_BUFFER_SIZE)
+            newline_index = decompressed_chunk.find(b"\n")
+            if newline_index == -1:
+                line.extend(decompressed_chunk)
+            else:
+                line.extend(decompressed_chunk[: newline_index + 1])
+                self._buffer.extend(decompressed_chunk[newline_index + 1 :])
+                return line
+        return line
+
+    def _compressed_bytes_remain(self) -> bool:
+        """
+        True if there are compressed bytes still left to decompress. False otherwise.
+        """
+        return not self._file_exhausted or self._decompressor.unconsumed_tail != b""
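
Reviewer note (not part of the patch): a minimal end-to-end sketch of how the new file type is
meant to be used, assuming torchdata with this patch applied. The file path and payload below
are made up for illustration.

    import zlib

    from torchdata.datapipes.iter import FileOpener, IterableWrapper

    # Write a zlib-compressed file (hypothetical path, for illustration only).
    with open("example.zlib", "wb") as f:
        f.write(zlib.compress(b"hello\nworld\n"))

    # Open the file in binary mode and decompress it with the new "zlib" file type.
    # Passing file_type=None should also work, since the ".zlib" extension is now
    # handled by _detect_compression_type.
    dp = IterableWrapper(["example.zlib"])
    dp = FileOpener(dp, mode="b")
    dp = dp.decompress(file_type="zlib")

    for path, stream in dp:
        print(path, stream.read())  # expected: b"hello\nworld\n"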
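Reviewer note (not part of the patch): _ZlibFile.read leans on the max_length / unconsumed_tail
contract of zlib.decompressobj; below is a standalone sketch of that contract, independent of
this change, to make the chunked-read loop easier to review.

    import zlib

    data = zlib.compress(b"a" * 100)
    d = zlib.decompressobj()

    # Ask for at most 10 decompressed bytes; compressed input that was not needed yet
    # is parked in d.unconsumed_tail instead of being discarded.
    out = d.decompress(data, 10)
    assert len(out) <= 10
    leftover = d.unconsumed_tail

    # Feeding the leftover back in continues decompression where it stopped, which is
    # what _ZlibFile.read does on each iteration of its while loop.
    rest = d.decompress(leftover)
    assert out + rest == b"a" * 100
    assert d.eof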