Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Improvements to buffered reading for parquet #5611

Merged
merged 4 commits into from
Jun 15, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,8 @@ public SeekableByteChannel getReadChannel(@NotNull final SeekableChannelContext
}

@Override
public InputStream getInputStream(SeekableByteChannel channel) throws IOException {
return wrappedProvider.getInputStream(channel);
public InputStream getInputStream(final SeekableByteChannel channel, final int sizeHint) throws IOException {
return wrappedProvider.getInputStream(channel, sizeHint);
}

@Override
Expand All @@ -115,7 +115,7 @@ public SeekableByteChannel getWriteChannel(@NotNull final Path path, final boole
return result == null
? new CachedChannel(wrappedProvider.getWriteChannel(path, append), channelType, pathKey)
: result.position(append ? result.size() : 0); // The seek isn't really necessary for append; will be at
// end no matter what.
// end no matter what.
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
import java.util.stream.Stream;

public class LocalFSChannelProvider implements SeekableChannelsProvider {
private static final int MAX_READ_BUFFER_SIZE = 1 << 16; // 64 KiB

@Override
public SeekableChannelContext makeContext() {
// No additional context required for local FS
Expand All @@ -40,9 +42,10 @@ public SeekableByteChannel getReadChannel(@Nullable final SeekableChannelContext
}

@Override
public InputStream getInputStream(SeekableByteChannel channel) {
public InputStream getInputStream(final SeekableByteChannel channel, final int sizeHint) {
// FileChannel is not buffered, need to buffer
return new BufferedInputStream(Channels.newInputStreamNoClose(channel));
final int bufferSize = Math.min(sizeHint, MAX_READ_BUFFER_SIZE);
return new BufferedInputStream(Channels.newInputStreamNoClose(channel), bufferSize);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,23 +19,23 @@
public interface SeekableChannelsProvider extends SafeCloseable {

/**
* Wraps {@link SeekableChannelsProvider#getInputStream(SeekableByteChannel)} to ensure the channel's position is
* incremented the exact amount that has been consumed from the resulting input stream. To remain valid, the caller
* must ensure that the resulting input stream isn't re-wrapped by any downstream code in a way that would adversely
* affect the position (such as re-wrapping the resulting input stream with buffering).
* Wraps {@link SeekableChannelsProvider#getInputStream(SeekableByteChannel, int)} to ensure the channel's position
* is incremented the exact amount that has been consumed from the resulting input stream. To remain valid, the
* caller must ensure that the resulting input stream isn't re-wrapped by any downstream code in a way that would
* adversely affect the position (such as re-wrapping the resulting input stream with buffering).
*
* <p>
* Equivalent to {@code ChannelPositionInputStream.of(ch, provider.getInputStream(ch))}.
* Equivalent to {@code ChannelPositionInputStream.of(ch, provider.getInputStream(ch, sizeHint))}.
*
* @param provider the provider
* @param ch the seekable channel
* @return the position-safe input stream
malhotrashivam marked this conversation as resolved.
Show resolved Hide resolved
* @throws IOException if an IO exception occurs
* @see ChannelPositionInputStream#of(SeekableByteChannel, InputStream)
*/
static InputStream channelPositionInputStream(SeekableChannelsProvider provider, SeekableByteChannel ch)
throws IOException {
return ChannelPositionInputStream.of(ch, provider.getInputStream(ch));
static InputStream channelPositionInputStream(SeekableChannelsProvider provider, SeekableByteChannel ch,
int sizeHint) throws IOException {
return ChannelPositionInputStream.of(ch, provider.getInputStream(ch, sizeHint));
}

/**
Expand Down Expand Up @@ -66,20 +66,21 @@ SeekableByteChannel getReadChannel(@NotNull SeekableChannelContext channelContex
throws IOException;

/**
* Creates an {@link InputStream} from the current position of {@code channel}; closing the resulting input stream
* does <i>not</i> close the {@code channel}. The {@link InputStream} will be buffered; either explicitly in the
* case where the implementation uses an unbuffered {@link #getReadChannel(SeekableChannelContext, URI)}, or
* implicitly when the implementation uses a buffered {@link #getReadChannel(SeekableChannelContext, URI)}.
* {@code channel} must have been created by {@code this} provider. The caller can't assume the position of
* {@code channel} after consuming the {@link InputStream}. For use-cases that require the channel's position to be
* incremented the exact amount the {@link InputStream} has been consumed, use
* {@link #channelPositionInputStream(SeekableChannelsProvider, SeekableByteChannel)}.
* Creates an {@link InputStream} from the current position of {@code channel} from which the caller expects to read
* {@code sizeHint} number of bytes. Closing the resulting input stream does <i>not</i> close the {@code channel}.
* The {@link InputStream} will be buffered; either explicitly in the case where the implementation uses an
* unbuffered {@link #getReadChannel(SeekableChannelContext, URI)}, or implicitly when the implementation uses a
* buffered {@link #getReadChannel(SeekableChannelContext, URI)}. {@code channel} must have been created by
* {@code this} provider. The caller can't assume the position of {@code channel} after consuming the
* {@link InputStream}. For use-cases that require the channel's position to be incremented the exact amount the
* {@link InputStream} has been consumed, use
* {@link #channelPositionInputStream(SeekableChannelsProvider, SeekableByteChannel, int)}.
*
* @param channel the channel
malhotrashivam marked this conversation as resolved.
Show resolved Hide resolved
* @return the input stream
* @throws IOException if an IO exception occurs
*/
InputStream getInputStream(SeekableByteChannel channel) throws IOException;
InputStream getInputStream(SeekableByteChannel channel, int sizeHint) throws IOException;

default SeekableByteChannel getWriteChannel(@NotNull final String path, final boolean append) throws IOException {
return getWriteChannel(Paths.get(path), append);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ public SeekableByteChannel getReadChannel(@NotNull SeekableChannelContext channe
}

@Override
public InputStream getInputStream(SeekableByteChannel channel) {
public InputStream getInputStream(SeekableByteChannel channel, int sizeHint) {
// TestMockChannel is always empty, so no need to buffer
return Channels.newInputStreamNoClose(channel);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -191,15 +191,7 @@ private Dictionary getDictionary(final SeekableChannelContext channelContext) {
} else {
return NULL_DICTIONARY;
}
// Use the context object provided by the caller, or create (and close) a new one
try (
final ContextHolder holder = SeekableChannelContext.ensureContext(channelsProvider, channelContext);
final SeekableByteChannel ch = channelsProvider.getReadChannel(holder.get(), getURI());
final InputStream in = channelsProvider.getInputStream(ch.position(dictionaryPageOffset))) {
return readDictionary(in, holder.get());
} catch (IOException e) {
throw new UncheckedIOException(e);
}
return readDictionary(dictionaryPageOffset, channelContext);
}

@Override
Expand All @@ -218,28 +210,38 @@ public SeekableChannelsProvider getChannelsProvider() {
}

@NotNull
private Dictionary readDictionary(InputStream in, SeekableChannelContext channelContext) throws IOException {
// explicitly not closing this, caller is responsible
final PageHeader pageHeader = Util.readPageHeader(in);
if (pageHeader.getType() != PageType.DICTIONARY_PAGE) {
// In case our fallback in getDictionary was too optimistic...
return NULL_DICTIONARY;
}
final DictionaryPageHeader dictHeader = pageHeader.getDictionary_page_header();
final int compressedPageSize = pageHeader.getCompressed_page_size();
final BytesInput payload;
if (compressedPageSize == 0) {
// Sometimes the size is explicitly empty, just use an empty payload
payload = BytesInput.empty();
} else {
payload = decompressor.decompress(in, compressedPageSize, pageHeader.getUncompressed_page_size(),
channelContext);
private Dictionary readDictionary(long dictionaryPageOffset, SeekableChannelContext channelContext) {
// Use the context object provided by the caller, or create (and close) a new one
try (
final ContextHolder holder = SeekableChannelContext.ensureContext(channelsProvider, channelContext);
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the original code, we used to make the channel and stream in the calling method and this method would just use the same stream and not touch the underlying channel.
Now we make two streams — one for the header and one for the data — and we use the same channel for both.

Note that the channel's position gets updated after reading the header.
So I wanted to make the channel's lifecycle limited to this method so that no one else should depend on or use this channel. That is why I moved the logic for making the channel inside this method.

final SeekableByteChannel ch =
channelsProvider.getReadChannel(holder.get(), getURI()).position(dictionaryPageOffset)) {
final PageHeader pageHeader = readPageHeader(ch);
if (pageHeader.getType() != PageType.DICTIONARY_PAGE) {
// In case our fallback in getDictionary was too optimistic...
return NULL_DICTIONARY;
}
final DictionaryPageHeader dictHeader = pageHeader.getDictionary_page_header();
final int compressedPageSize = pageHeader.getCompressed_page_size();
final BytesInput payload;
try (final InputStream in = (compressedPageSize == 0) ? null
: channelsProvider.getInputStream(ch, compressedPageSize)) {
if (compressedPageSize == 0) {
// Sometimes the size is explicitly empty, just use an empty payload
payload = BytesInput.empty();
} else {
payload = decompressor.decompress(in, compressedPageSize, pageHeader.getUncompressed_page_size(),
holder.get());
}
final Encoding encoding = Encoding.valueOf(dictHeader.getEncoding().name());
final DictionaryPage dictionaryPage = new DictionaryPage(payload, dictHeader.getNum_values(), encoding);
// We are safe to not copy the payload because the Dictionary doesn't hold a reference to dictionaryPage
// or payload and thus doesn't hold a reference to the input stream.
return encoding.initDictionary(path, dictionaryPage);
}
} catch (IOException e) {
throw new UncheckedIOException(e);
}
final Encoding encoding = Encoding.valueOf(dictHeader.getEncoding().name());
final DictionaryPage dictionaryPage = new DictionaryPage(payload, dictHeader.getNum_values(), encoding);
// We are safe to not copy the payload because the Dictionary doesn't hold a reference to dictionaryPage or
// payload and thus doesn't hold a reference to the input stream.
return encoding.initDictionary(path, dictionaryPage);
}

private final class ColumnPageReaderIteratorImpl implements ColumnPageReaderIterator {
Expand Down Expand Up @@ -314,8 +316,12 @@ private org.apache.parquet.format.Encoding getEncoding(final PageHeader pageHead
}
}

/**
* Read the page header from the given channel and increment the channel position by the number of bytes read.
*/
private PageHeader readPageHeader(final SeekableByteChannel ch) throws IOException {
try (final InputStream in = SeekableChannelsProvider.channelPositionInputStream(channelsProvider, ch)) {
// We expect page headers to be smaller than 128 bytes
try (final InputStream in = SeekableChannelsProvider.channelPositionInputStream(channelsProvider, ch, 128)) {
return Util.readPageHeader(in);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ private int readRowCountFromDataPage(
@NotNull final SeekableChannelContext channelContext) throws IOException {
switch (pageHeader.type) {
case DATA_PAGE:
try (final InputStream in = channelsProvider.getInputStream(ch)) {
try (final InputStream in = channelsProvider.getInputStream(ch, pageHeader.getCompressed_page_size())) {
return readRowCountFromPageV1(readV1Unsafe(in, channelContext));
}
case DATA_PAGE_V2:
Expand All @@ -225,12 +225,12 @@ private IntBuffer readKeysFromDataPage(
@NotNull final SeekableChannelContext channelContext) throws IOException {
switch (pageHeader.type) {
case DATA_PAGE:
try (final InputStream in = channelsProvider.getInputStream(ch)) {
try (final InputStream in = channelsProvider.getInputStream(ch, pageHeader.getCompressed_page_size())) {
return readKeysFromPageV1(readV1Unsafe(in, channelContext), keyDest, nullPlaceholder,
channelContext);
}
case DATA_PAGE_V2:
try (final InputStream in = channelsProvider.getInputStream(ch)) {
try (final InputStream in = channelsProvider.getInputStream(ch, pageHeader.getCompressed_page_size())) {
return readKeysFromPageV2(readV2Unsafe(in, channelContext), keyDest, nullPlaceholder,
channelContext);
}
Expand All @@ -246,11 +246,11 @@ private Object readDataPage(
@NotNull final SeekableChannelContext channelContext) throws IOException {
switch (pageHeader.type) {
case DATA_PAGE:
try (final InputStream in = channelsProvider.getInputStream(ch)) {
try (final InputStream in = channelsProvider.getInputStream(ch, pageHeader.getCompressed_page_size())) {
return readPageV1(readV1Unsafe(in, channelContext), nullValue, channelContext);
}
case DATA_PAGE_V2:
try (final InputStream in = channelsProvider.getInputStream(ch)) {
try (final InputStream in = channelsProvider.getInputStream(ch, pageHeader.getCompressed_page_size())) {
return readPageV2(readV2Unsafe(in, channelContext), nullValue, channelContext);
}
default:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,8 @@ private OffsetIndex readOffsetIndex(@NotNull final SeekableChannelContext channe
SeekableChannelContext.ensureContext(channelsProvider, channelContext);
final SeekableByteChannel readChannel = channelsProvider.getReadChannel(holder.get(), columnChunkURI);
final InputStream in =
channelsProvider.getInputStream(readChannel.position(columnChunk.getOffset_index_offset()))) {
channelsProvider.getInputStream(readChannel.position(columnChunk.getOffset_index_offset()),
columnChunk.getOffset_index_length())) {
return (offsetIndex = ParquetMetadataConverter.fromParquetOffsetIndex(Util.readOffsetIndex(in)));
} catch (final IOException e) {
throw new UncheckedIOException(e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,15 +96,20 @@ private ParquetFileReader(
try (
final SeekableChannelContext context = channelsProvider.makeSingleUseContext();
final SeekableByteChannel ch = channelsProvider.getReadChannel(context, parquetFileURI)) {
positionToFileMetadata(parquetFileURI, ch);
try (final InputStream in = channelsProvider.getInputStream(ch)) {
final int footerLength = positionToFileMetadata(parquetFileURI, ch);
try (final InputStream in = channelsProvider.getInputStream(ch, footerLength)) {
fileMetaData = Util.readFileMetaData(in);
}
}
type = fromParquetSchema(fileMetaData.schema, fileMetaData.column_orders);
}

private static void positionToFileMetadata(URI parquetFileURI, SeekableByteChannel readChannel) throws IOException {
/**
* Read the footer length and position the channel to the start of the footer.
*
* @return The length of the footer
*/
private static int positionToFileMetadata(URI parquetFileURI, SeekableByteChannel readChannel) throws IOException {
final long fileLen = readChannel.size();
if (fileLen < MAGIC.length + FOOTER_LENGTH_SIZE + MAGIC.length) { // MAGIC + data + footer +
// footerIndex + MAGIC
Expand All @@ -128,6 +133,7 @@ private static void positionToFileMetadata(URI parquetFileURI, SeekableByteChann
"corrupted file: the footer index is not within the file: " + footerIndex);
}
readChannel.position(footerIndex);
return footerLength;
}

private static int makeLittleEndianInt(byte b0, byte b1, byte b2, byte b3) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -521,14 +521,10 @@ public void test_lz4_compressed() {

final Table fromDisk = checkSingleTable(table, dest).select();

try {
// The following file is tagged as LZ4 compressed based on its metadata, but is actually compressed with
// LZ4_RAW. We should be able to read it anyway with no exceptions.
String path = TestParquetTools.class.getResource("/sample_lz4_compressed.parquet").getFile();
readTable(path, EMPTY.withLayout(ParquetInstructions.ParquetFileLayout.SINGLE_FILE)).select();
} catch (RuntimeException e) {
TestCase.fail("Failed to read parquet file sample_lz4_compressed.parquet");
}
// The following file is tagged as LZ4 compressed based on its metadata, but is actually compressed with
// LZ4_RAW. We should be able to read it anyway with no exceptions.
String path = TestParquetTools.class.getResource("/sample_lz4_compressed.parquet").getFile();
readTable(path, EMPTY.withLayout(ParquetInstructions.ParquetFileLayout.SINGLE_FILE)).select();
final File randomDest = new File(rootFile, "random.parquet");
writeTable(fromDisk, randomDest.getPath(), ParquetTools.LZ4_RAW);

Expand Down Expand Up @@ -1656,7 +1652,7 @@ public void testVersionChecks() {

/**
* Reference data is generated using the following code:
*
*
* <pre>
* num_rows = 100000
* dh_table = empty_table(num_rows).update(formulas=[
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ public SeekableByteChannel getReadChannel(@NotNull final SeekableChannelContext
}

@Override
public InputStream getInputStream(final SeekableByteChannel channel) {
public InputStream getInputStream(final SeekableByteChannel channel, final int sizeHint) {
// S3SeekableByteChannel is internally buffered, no need to re-buffer
return Channels.newInputStreamNoClose(channel);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@
*/
final class TrackedSeekableChannelsProvider implements SeekableChannelsProvider {

private static final int MAX_READ_BUFFER_SIZE = 1 << 16; // 64 KiB

private final TrackedFileHandleFactory fileHandleFactory;

TrackedSeekableChannelsProvider(@NotNull final TrackedFileHandleFactory fileHandleFactory) {
Expand All @@ -59,9 +61,10 @@ public SeekableByteChannel getReadChannel(@Nullable final SeekableChannelContext
}

@Override
public InputStream getInputStream(SeekableByteChannel channel) {
// TrackedSeekableByteChannel is not buffered, need to buffer
return new BufferedInputStream(Channels.newInputStreamNoClose(channel));
public InputStream getInputStream(SeekableByteChannel channel, int sizeHint) {
// The following stream will read from the channel in chunks of bufferSize bytes
final int bufferSize = Math.min(sizeHint, MAX_READ_BUFFER_SIZE);
return new BufferedInputStream(Channels.newInputStreamNoClose(channel), bufferSize);
}

@Override
Expand Down
Loading