Add zlib support to Decompressor #1189

Open: wants to merge 1 commit into base: main
88 changes: 88 additions & 0 deletions test/test_local_io.py
@@ -18,6 +18,7 @@
import unittest
import warnings
import zipfile
import zlib
from functools import partial

from json.decoder import JSONDecodeError
@@ -51,6 +52,7 @@
XzFileLoader,
ZipArchiveLoader,
)
from torchdata.datapipes.iter.util.decompressor import _ZlibFile

try:
import iopath
@@ -557,13 +559,21 @@ def _write_single_gz_file(self):
with open(self.temp_files[0], "rb") as f:
k.write(f.read())

def _write_single_zlib_file(self):
import zlib

with open(f"{self.temp_dir.name}/temp.zlib", "wb") as k:
with open(self.temp_files[0], "rb") as f:
k.write(zlib.compress(f.read()))

def test_decompressor_iterdatapipe(self):
self._write_test_tar_files()
self._write_test_tar_gz_files()
self._write_single_gz_file()
self._write_test_zip_files()
self._write_test_xz_files()
self._write_test_bz2_files()
self._write_single_zlib_file()

# Functional Test: work with .tar files
tar_file_dp = FileLister(self.temp_dir.name, "*.tar")
@@ -606,6 +616,14 @@ def test_decompressor_iterdatapipe(self):
bz2_decompress_dp = Decompressor(bz2_load_dp, file_type="bz2")
self._decompressor_bz2_test_helper(bz2_decompress_dp)

# Functional Test: work with .zlib files
zlib_file_dp = IterableWrapper([f"{self.temp_dir.name}/temp.zlib"])
zlib_load_dp = FileOpener(zlib_file_dp, mode="b")
zlib_decompress_dp = Decompressor(zlib_load_dp, file_type="zlib")
for _, zlib_stream in zlib_decompress_dp:
with open(self.temp_files[0], "rb") as f:
self.assertEqual(f.read(), zlib_stream.read())

# Functional Test: work without file type as input for .tar files
tar_decompress_dp = Decompressor(tar_load_dp, file_type=None)
self._decompressor_tar_test_helper(self.temp_files, tar_decompress_dp)
@@ -622,6 +640,12 @@ def test_decompressor_iterdatapipe(self):
bz2_decompress_dp = Decompressor(bz2_load_dp, file_type=None)
self._decompressor_bz2_test_helper(bz2_decompress_dp)

# Functional Test: work without file type as input for .zlib files
zlib_decompress_dp = Decompressor(zlib_load_dp, file_type=None)
for _, zlib_stream in zlib_decompress_dp:
with open(self.temp_files[0], "rb") as f:
self.assertEqual(f.read(), zlib_stream.read())

# Functional Test: Compression Type works for both upper and lower case strings
tar_decompress_dp = Decompressor(tar_load_dp, file_type="TAr")
self._decompressor_tar_test_helper(self.temp_files, tar_decompress_dp)
@@ -640,6 +664,70 @@ def test_decompressor_iterdatapipe(self):
with self.assertRaisesRegex(TypeError, "has no len"):
len(tar_decompress_dp)

def test_zlibfile_readall(self):
uncompressed_data_test_cases = [b"", b"some data", 10_000 * b"some data"]
for uncompressed_data in uncompressed_data_test_cases:
Author comment:
I didn't find an obvious way of using pytest parameterized tests that was compatible with expecttest, so I resorted to just manually iterating over a list of inputs.
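For reference, the parametrized form I had in mind would look roughly like this (a sketch only; it assumes a plain pytest-style test function, which is exactly what clashed with the expecttest/unittest setup used here):

import pytest

# Sketch: relies on the io, zlib, and _ZlibFile imports already present in this module.
@pytest.mark.parametrize("uncompressed_data", [b"", b"some data", 10_000 * b"some data"])
def test_zlibfile_readall(uncompressed_data):
    compressed_file = _ZlibFile(io.BytesIO(zlib.compress(uncompressed_data)))
    assert compressed_file.readall() == uncompressed_data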

compressed_file = _ZlibFile(io.BytesIO(zlib.compress(uncompressed_data)))
self.assertEqual(compressed_file.readall(), uncompressed_data)

def test_zlibfile_read(self):
uncompressed_data_test_cases = [b"", b"some data", 10_000 * b"some data"]
num_bytes_to_read_test_cases = [-1, 0, 1, 2, 64_000, 128_000]
for uncompressed_data in uncompressed_data_test_cases:
for num_bytes_to_read in num_bytes_to_read_test_cases:
compressed_file = _ZlibFile(io.BytesIO(zlib.compress(uncompressed_data)))
result = bytearray()
chunk = compressed_file.read(num_bytes_to_read)
while chunk:
result.extend(chunk)
chunk = compressed_file.read(num_bytes_to_read)
if num_bytes_to_read == 0:
self.assertEqual(result, b"")
else:
self.assertEqual(result, uncompressed_data)

def test_zlibfile_stream_ends_prematurely(self):
compressed_bytes = zlib.compress(b"some data")
# slice compressed bytes so that the stream ends prematurely
compressed_file = _ZlibFile(io.BytesIO(compressed_bytes[:-2]))
with self.assertRaises(EOFError):
compressed_file.read()

def test_zlibfile_iteration(self):
# Ensure there are at least io.DEFAULT_BUFFER_SIZE bytes so that multiple read calls are
# performed under the hood
uncompressed_bytes = b"1234\n56\n\n78\n" + b"9" * io.DEFAULT_BUFFER_SIZE
compressed_bytes = zlib.compress(uncompressed_bytes)

# Test _ZlibFile.__next__
compressed_file = _ZlibFile(io.BytesIO(compressed_bytes))
self.assertEqual(next(compressed_file), b"1234\n")
self.assertEqual(next(compressed_file), b"56\n")
self.assertEqual(next(compressed_file), b"\n")
self.assertEqual(next(compressed_file), b"78\n")
self.assertEqual(next(compressed_file), b"9" * io.DEFAULT_BUFFER_SIZE)
with self.assertRaises(StopIteration):
next(compressed_file)

# Test _ZlibFile iterator creation as performed in StreamWrapper
def create_iterator():
yield from _ZlibFile(io.BytesIO(compressed_bytes))

self.assertEqual(list(create_iterator()), [b"1234\n", b"56\n", b"\n", b"78\n", b"9" * io.DEFAULT_BUFFER_SIZE])

# Test that interleaving calls to `read` with calls to `next` works as expected
compressed_file = _ZlibFile(io.BytesIO(compressed_bytes))
compressed_file.read(2)
self.assertEqual(next(compressed_file), b"34\n")
self.assertEqual(compressed_file.read(5), b"56\n\n7")
self.assertEqual(next(compressed_file), b"8\n")
self.assertEqual(compressed_file.read(3), b"999")
self.assertEqual(next(compressed_file), b"9" * (io.DEFAULT_BUFFER_SIZE - 3))
with self.assertRaises(StopIteration):
next(compressed_file)
self.assertEqual(compressed_file.read(1), b"")
self.assertEqual(compressed_file.read(), b"")

def _write_text_files(self):
name_to_data = {"1.text": b"DATA", "2.text": b"DATA", "3.text": b"DATA"}
source_dp = IterableWrapper(sorted(name_to_data.items()))
112 changes: 111 additions & 1 deletion torchdata/datapipes/iter/util/decompressor.py
@@ -6,11 +6,14 @@

import bz2
import gzip
import io
import lzma
import os
import pathlib
import sys
import tarfile
import zipfile
import zlib

from enum import Enum
from io import IOBase
@@ -27,6 +30,7 @@ class CompressionType(Enum):
TAR = "tar"
ZIP = "zip"
BZIP2 = "bz2"
ZLIB = "zlib"


@functional_datapipe("decompress")
@@ -38,7 +42,7 @@ class DecompressorIterDataPipe(IterDataPipe[Tuple[str, StreamWrapper]]):

Args:
source_datapipe: IterDataPipe containing tuples of path and compressed stream of data
file_type: Optional `string` or ``CompressionType`` that represents what compression format of the inputs
file_type: Optional `string` or ``CompressionType`` that represents the compression format of the inputs

Example:
>>> from torchdata.datapipes.iter import FileLister, FileOpener
@@ -58,6 +62,7 @@ class DecompressorIterDataPipe(IterDataPipe[Tuple[str, StreamWrapper]]):
types.TAR: lambda file: tarfile.open(fileobj=file, mode="r:*"),
types.ZIP: lambda file: zipfile.ZipFile(file=file),
types.BZIP2: lambda file: bz2.BZ2File(filename=file),
types.ZLIB: lambda file: _ZlibFile(file),
}

def __init__(
@@ -87,6 +92,8 @@ def _detect_compression_type(self, path: str) -> CompressionType:
return self.types.ZIP
elif ext == ".bz2":
return self.types.BZIP2
elif ext == ".zlib":
Author comment:
There doesn't seem to be a strong standard for what extension zlib files should have. I've also seen .z and .zz.
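Purely as an illustration of an alternative (not something this change does): zlib streams can also be recognized by content rather than extension. Per RFC 1950, the low four bits of the first header byte are 8 for deflate, and the first two bytes, read as a big-endian integer, are a multiple of 31, so a hypothetical sniffing helper could look like this:

def _looks_like_zlib(header: bytes) -> bool:
    # RFC 1950 header check: CM (low nibble of CMF) == 8 for deflate, and
    # CMF * 256 + FLG must be divisible by 31.
    if len(header) < 2:
        return False
    cmf, flg = header[0], header[1]
    return (cmf & 0x0F) == 8 and (cmf * 256 + flg) % 31 == 0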

return self.types.ZLIB
else:
raise RuntimeError(
f"File at {path} has file extension {ext}, which does not match what are supported by"
@@ -114,3 +121,106 @@ def __new__(
cls, source_datapipe: IterDataPipe[Tuple[str, IOBase]], file_type: Optional[Union[str, CompressionType]] = None
):
return DecompressorIterDataPipe(source_datapipe, file_type)


class _ZlibFile:
"""
A minimal read-only file object for decompressing zlib data. It is only intended to be wrapped by
StreamWrapper inside decompressor.py, and it supports only the specific operations that
StreamWrapper expects.
"""

def __init__(self, file) -> None:
self._decompressor = zlib.decompressobj()
self._file = file

# Stores decompressed bytes leftover from a call to __next__: since __next__ only returns
# up until the next newline, we need to store the bytes from after the newline for future
# calls to read or __next__.
self._buffer = bytearray()

# Whether or not self._file still has bytes left to read
self._file_exhausted = False

def read(self, size: int = -1) -> bytearray:
if size < 0:
return self.readall()

if not size:
return bytearray()

result = self._buffer[:size]
self._buffer = self._buffer[size:]
while len(result) < size and self._compressed_bytes_remain():
# If decompress was called previously, there might be some compressed bytes from a previous chunk
# that haven't been decompressed yet (because decompress() was passed a max_length value that
# didn't exhaust the compressed bytes in the chunk). We can retrieve these leftover bytes from
# unconsumed_tail:
chunk = self._decompressor.unconsumed_tail
# Let's read compressed bytes in chunks of io.DEFAULT_BUFFER_SIZE because this is what python's gzip
# library does:
# https://github.com/python/cpython/blob/a6326972253bf5282c5bf422f4a16d93ace77b57/Lib/gzip.py#L505
if len(chunk) < io.DEFAULT_BUFFER_SIZE:
compressed_bytes = self._file.read(io.DEFAULT_BUFFER_SIZE - len(chunk))
if compressed_bytes:
chunk += compressed_bytes
else:
self._file_exhausted = True
decompressed_chunk = self._decompressor.decompress(chunk, max_length=size - len(result))
result.extend(decompressed_chunk)
if not self._compressed_bytes_remain() and not self._decompressor.eof:
# There are no more compressed bytes available to decompress, but we haven't reached the
# zlib EOF, so something is wrong.
raise EOFError("Compressed file ended before the end-of-stream marker was reached")
return result

def readall(self):
"""
This is just mimicking python's internal DecompressReader.readall:
https://github.com/python/cpython/blob/a6326972253bf5282c5bf422f4a16d93ace77b57/Lib/_compression.py#L113
"""
chunks = []
# sys.maxsize means the max length of output buffer is unlimited,
# so that the whole input buffer can be decompressed within one
# .decompress() call.
data = self.read(sys.maxsize)
while data:
Author comment:
I wasn't sure if python 3.7 support was desired, so I didn't use the walrus operator here.
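For context, the Python 3.8+ form this avoids would be the following (equivalent behavior, just more compact):

while data := self.read(sys.maxsize):
    chunks.append(data)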

chunks.append(data)
data = self.read(sys.maxsize)

return b"".join(chunks)

def __iter__(self):
return self

def __next__(self):
if not self._buffer and not self._compressed_bytes_remain():
raise StopIteration

# Check if the buffer already has a newline, in which case we don't need to do any decompression of
# remaining bytes
newline_index = self._buffer.find(b"\n")
if newline_index != -1:
line = self._buffer[: newline_index + 1]
self._buffer = self._buffer[newline_index + 1 :]
return line

# Keep decompressing bytes until we find a newline or run out of bytes
line = self._buffer
self._buffer = bytearray()
while self._compressed_bytes_remain():
decompressed_chunk = self.read(io.DEFAULT_BUFFER_SIZE)
newline_index = decompressed_chunk.find(b"\n")
if newline_index == -1:
line.extend(decompressed_chunk)
else:
line.extend(decompressed_chunk[: newline_index + 1])
self._buffer.extend(decompressed_chunk[newline_index + 1 :])
return line
return line

def _compressed_bytes_remain(self) -> bool:
"""
True if there are compressed bytes still left to decompress. False otherwise.
"""
return not self._file_exhausted or self._decompressor.unconsumed_tail != b""
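
For reviewers, here is a minimal end-to-end sketch of the new code path, mirroring the functional test above (the ./data directory and the *.zlib files in it are hypothetical):

from torchdata.datapipes.iter import FileLister, FileOpener

# List and open some existing *.zlib files in binary mode, then decompress them
# with the new zlib support (functional form of Decompressor).
dp = FileLister("./data", "*.zlib")
dp = FileOpener(dp, mode="b")
dp = dp.decompress(file_type="zlib")
for path, stream in dp:
    print(path, stream.read()[:20])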