Add zlib support to Decompressor #1189
base: main
@@ -6,11 +6,14 @@
import bz2
import gzip
import io
import lzma
import os
import pathlib
import sys
import tarfile
import zipfile
import zlib

from enum import Enum
from io import IOBase
@@ -27,6 +30,7 @@ class CompressionType(Enum):
    TAR = "tar"
    ZIP = "zip"
    BZIP2 = "bz2"
    ZLIB = "zlib"


@functional_datapipe("decompress")
@@ -38,7 +42,7 @@ class DecompressorIterDataPipe(IterDataPipe[Tuple[str, StreamWrapper]]):

    Args:
        source_datapipe: IterDataPipe containing tuples of path and compressed stream of data
-        file_type: Optional `string` or ``CompressionType`` that represents what compression format of the inputs
+        file_type: Optional `string` or ``CompressionType`` that represents the compression format of the inputs

    Example:
        >>> from torchdata.datapipes.iter import FileLister, FileOpener
@@ -58,6 +62,7 @@ class DecompressorIterDataPipe(IterDataPipe[Tuple[str, StreamWrapper]]): | |
types.TAR: lambda file: tarfile.open(fileobj=file, mode="r:*"), | ||
types.ZIP: lambda file: zipfile.ZipFile(file=file), | ||
types.BZIP2: lambda file: bz2.BZ2File(filename=file), | ||
types.ZLIB: lambda file: _ZlibFile(file), | ||
} | ||
|
||
def __init__( | ||
|
@@ -87,6 +92,8 @@ def _detect_compression_type(self, path: str) -> CompressionType:
            return self.types.ZIP
        elif ext == ".bz2":
            return self.types.BZIP2
        elif ext == ".zlib":

Comment: There doesn't seem to be a strong standard for what extension zlib files should have. I've also seen .z and .zz.

            return self.types.ZLIB
        else:
            raise RuntimeError(
                f"File at {path} has file extension {ext}, which does not match what are supported by"
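Following up on the comment above, here is a small hypothetical sketch of what accepting the other suffixes it mentions (.z and .zz) could look like; this is an illustration only, not something the PR implements:

    import os

    # Hypothetical variant of the extension check above: also treat the other
    # common zlib suffixes from the comment as zlib (an assumption, not in the PR).
    def _is_zlib_extension(path: str) -> bool:
        ext = os.path.splitext(path)[1]
        return ext in (".zlib", ".z", ".zz")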
@@ -114,3 +121,106 @@ def __new__(
        cls, source_datapipe: IterDataPipe[Tuple[str, IOBase]], file_type: Optional[Union[str, CompressionType]] = None
    ):
        return DecompressorIterDataPipe(source_datapipe, file_type)

class _ZlibFile:
    """
    A minimal read-only file object for decompressing zlib data. It's only intended to be wrapped by
    StreamWrapper and isn't intended to be used outside decompressor.py. It only supports the
    specific operations expected by StreamWrapper.
    """

    def __init__(self, file) -> None:
        self._decompressor = zlib.decompressobj()
        self._file = file

        # Stores decompressed bytes leftover from a call to __next__: since __next__ only returns
        # up until the next newline, we need to store the bytes from after the newline for future
        # calls to read or __next__.
        self._buffer = bytearray()

        # Whether or not self._file still has bytes left to read
        self._file_exhausted = False
    def read(self, size: int = -1) -> bytearray:
        if size < 0:
            return self.readall()

        if not size:
            return bytearray()

        result = self._buffer[:size]
        self._buffer = self._buffer[size:]
        while len(result) < size and self._compressed_bytes_remain():
            # If decompress was called previously, there might be some compressed bytes from a previous chunk
            # that haven't been decompressed yet (because decompress() was passed a max_length value that
            # didn't exhaust the compressed bytes in the chunk). We can retrieve these leftover bytes from
            # unconsumed_tail:
            chunk = self._decompressor.unconsumed_tail
            # Let's read compressed bytes in chunks of io.DEFAULT_BUFFER_SIZE because this is what python's gzip
            # library does:
            # https://github.com/python/cpython/blob/a6326972253bf5282c5bf422f4a16d93ace77b57/Lib/gzip.py#L505
            if len(chunk) < io.DEFAULT_BUFFER_SIZE:
                compressed_bytes = self._file.read(io.DEFAULT_BUFFER_SIZE - len(chunk))
                if compressed_bytes:
                    chunk += compressed_bytes
                else:
                    self._file_exhausted = True
            decompressed_chunk = self._decompressor.decompress(chunk, max_length=size - len(result))
            result.extend(decompressed_chunk)
            if not self._compressed_bytes_remain() and not self._decompressor.eof:
                # There are no more compressed bytes available to decompress, but we haven't reached the
                # zlib EOF, so something is wrong.
                raise EOFError("Compressed file ended before the end-of-stream marker was reached")
        return result

    def readall(self):
        """
        This is just mimicking python's internal DecompressReader.readall:
        https://github.com/python/cpython/blob/a6326972253bf5282c5bf422f4a16d93ace77b57/Lib/_compression.py#L113
        """
        chunks = []
        # sys.maxsize means the max length of output buffer is unlimited,
        # so that the whole input buffer can be decompressed within one
        # .decompress() call.
        data = self.read(sys.maxsize)
        while data:

Comment: I wasn't sure if Python 3.7 support was desired, so I didn't use the walrus operator here.

            chunks.append(data)
            data = self.read(sys.maxsize)

        return b"".join(chunks)

    def __iter__(self):
        return self

    def __next__(self):
        if not self._buffer and not self._compressed_bytes_remain():
            raise StopIteration

        # Check if the buffer already has a newline, in which case we don't need to do any decompression of
        # remaining bytes
        newline_index = self._buffer.find(b"\n")
        if newline_index != -1:
            line = self._buffer[: newline_index + 1]
            self._buffer = self._buffer[newline_index + 1 :]
            return line

        # Keep decompressing bytes until we find a newline or run out of bytes
        line = self._buffer
        self._buffer = bytearray()
        while self._compressed_bytes_remain():
            decompressed_chunk = self.read(io.DEFAULT_BUFFER_SIZE)
            newline_index = decompressed_chunk.find(b"\n")
            if newline_index == -1:
                line.extend(decompressed_chunk)
            else:
                line.extend(decompressed_chunk[: newline_index + 1])
                self._buffer.extend(decompressed_chunk[newline_index + 1 :])
                return line
        return line

    def _compressed_bytes_remain(self) -> bool:
        """
        True if there are compressed bytes still left to decompress. False otherwise.
        """
        return not self._file_exhausted or self._decompressor.unconsumed_tail != b""
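
To illustrate how the new type plugs into the existing datapipe API, here is a hedged usage sketch (not taken from the PR); the directory and file names are hypothetical, and it assumes files compressed with raw zlib and named with a ".zlib" suffix:

    from torchdata.datapipes.iter import FileLister, FileOpener

    # Hypothetical example data: "./data" holds zlib-compressed files ending in ".zlib".
    dp = FileLister("./data", "*.zlib")
    dp = FileOpener(dp, mode="b")
    # file_type can also be omitted to fall back on the extension detection shown above.
    dp = dp.decompress(file_type="zlib")
    for path, stream in dp:
        print(path, stream.read())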
Comment: I didn't find an obvious way of using pytest parameterized tests that was compatible with expecttest, so I resorted to just manually iterating over a list of inputs.
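
A hedged sketch of the manual-iteration pattern that comment describes; the test name, inputs, and the IterableWrapper wiring are illustrative assumptions, not the PR's actual test code:

    import io
    import zlib

    from torchdata.datapipes.iter import IterableWrapper

    def test_zlib_decompressor_manual_cases():
        # Manually iterate over inputs instead of using @pytest.mark.parametrize.
        payloads = [b"", b"hello\nworld\n", b"x" * 100_000]
        for payload in payloads:
            compressed = zlib.compress(payload)
            # Hypothetical wiring: feed a (path, stream) pair straight into the decompressor.
            dp = IterableWrapper([("archive.zlib", io.BytesIO(compressed))]).decompress(file_type="zlib")
            for _, stream in dp:
                assert stream.read() == payload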