Moved HEAD request to constructor, removed one buffer copy, and moved to CRT client
malhotrashivam committed Dec 21, 2023
1 parent 093567f · commit d3e9da8
Showing 5 changed files with 67 additions and 68 deletions.
@@ -229,17 +229,6 @@ public boolean hasNext() {
return remainingValues > 0;
}

// TODO Move to a separate file
final class PositionedBufferedInputStream extends BufferedInputStream {
PositionedBufferedInputStream(final ReadableByteChannel readChannel, final int size) {
super(Channels.newInputStream(readChannel), size);
}

long position() throws IOException {
return this.pos;
}
}

@Override
public ColumnPageReader next() {
if (!hasNext()) {
@@ -78,10 +78,6 @@ public ColumnChunkReaderImpl getColumnChunk(@NotNull final List<String> path) {
if (columnChunk.isSetOffset_index_offset()) {
try (final SeekableByteChannel readChannel = channelsProvider.getReadChannel(rootPath)) {
readChannel.position(columnChunk.getOffset_index_offset());
// TODO Think if we need to reduce the buffer size.
// We read BUFFER_SIZE (=65536) number of bytes from the channel, which leads to a big read request to
// aws, even if the offset index is much smaller. Same thing happens for non aws parquet files too but
// reads are less expensive there.
offsetIndex = ParquetMetadataConverter.fromParquetOffsetIndex(Util.readOffsetIndex(
new BufferedInputStream(Channels.newInputStream(readChannel), BUFFER_SIZE)));
} catch (IOException e) {
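The removed TODO flags a real cost: the fixed BUFFER_SIZE (65536) buffered read can issue a much larger S3 request than the offset index needs. A minimal sketch of one way to bound it, assuming the Thrift-generated ColumnChunk exposes the optional offset_index_length field; this is an illustration, not what the commit implements:

    // Sketch: size the buffer to the offset index when the footer records its
    // length (offset_index_length is optional in the Parquet Thrift spec).
    final int bufferSize = columnChunk.isSetOffset_index_length()
            ? Math.min(columnChunk.getOffset_index_length(), BUFFER_SIZE)
            : BUFFER_SIZE;
    offsetIndex = ParquetMetadataConverter.fromParquetOffsetIndex(Util.readOffsetIndex(
            new BufferedInputStream(Channels.newInputStream(readChannel), bufferSize)));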
2 changes: 1 addition & 1 deletion extensions/parquet/table/build.gradle
@@ -29,7 +29,7 @@ dependencies {
implementation depCommonsLang3
implementation platform('software.amazon.awssdk:bom:2.21.43')
implementation 'software.amazon.awssdk:s3'
implementation 'software.amazon.awssdk:apache-client:2.21.43'
implementation 'software.amazon.awssdk:aws-crt-client'

implementation("com.github.ben-manes.caffeine:caffeine:3.1.8")

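Note that the new dependency carries no version suffix: the software.amazon.awssdk:bom platform declared above pins the aws-crt-client version, whereas the removed apache-client line hard-coded :2.21.43 alongside the BOM.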
@@ -6,8 +6,11 @@
import io.deephaven.configuration.Configuration;
import io.deephaven.parquet.base.util.SeekableChannelsProvider;
import org.jetbrains.annotations.NotNull;
import software.amazon.awssdk.http.async.SdkAsyncHttpClient;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.s3.S3AsyncClient;
import software.amazon.awssdk.services.s3.model.HeadObjectResponse;
import software.amazon.awssdk.http.crt.AwsCrtAsyncHttpClient;

import java.io.IOException;
import java.io.UnsupportedEncodingException;
@@ -16,16 +19,24 @@
import java.nio.ByteBuffer;
import java.nio.channels.SeekableByteChannel;
import java.nio.file.Path;
import java.time.Duration;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;

public final class S3BackedSeekableChannelProvider implements SeekableChannelsProvider {

private final S3AsyncClient s3AsyncClient;
private final URI uri;
private final String s3uri, bucket, key;
private final long size;

private static final int MAX_CACHE_SIZE =
Configuration.getInstance().getIntegerWithDefault("s3.spi.read.max-cache-size", 10);
private static final int MAX_AWS_CONCURRENT_REQUESTS =
Configuration.getInstance().getIntegerWithDefault("s3.spi.read.max-concurrency", 20);

private final Cache<Integer, CompletableFuture<ByteBuffer>> readAheadBuffersCache;

public S3BackedSeekableChannelProvider(final String awsRegionName, final String uriStr) throws IOException {
@@ -35,26 +46,55 @@ public S3BackedSeekableChannelProvider(final String awsRegionName, final String
if (uriStr == null || uriStr.isEmpty()) {
throw new IllegalArgumentException("uri cannot be null or empty");
}
if (MAX_AWS_CONCURRENT_REQUESTS < 1) {
throw new IllegalArgumentException("maxConcurrency must be >= 1");
}

try {
uri = new URI(uriStr);
} catch (final URISyntaxException e) {
throw new UncheckedDeephavenException("Failed to parse URI " + uriStr, e);
}
this.s3uri = uriStr;
this.bucket = uri.getHost();
this.key = uri.getPath().substring(1);

final SdkAsyncHttpClient asyncHttpClient = AwsCrtAsyncHttpClient.builder()
.maxConcurrency(MAX_AWS_CONCURRENT_REQUESTS)
.connectionTimeout(Duration.ofSeconds(5))
.build();
s3AsyncClient = S3AsyncClient.builder()
.region(Region.of(awsRegionName))
.httpClient(asyncHttpClient)
.build();

if (MAX_CACHE_SIZE < 1)
throw new IllegalArgumentException("maxCacheSize must be >= 1");
this.readAheadBuffersCache = Caffeine.newBuilder().maximumSize(MAX_CACHE_SIZE).recordStats().build();

this.s3uri = uriStr;
this.bucket = uri.getHost();
this.key = uri.getPath().substring(1);
// Send HEAD request to S3 to get the size of the file
{
final long timeOut = 1L;
final TimeUnit unit = TimeUnit.MINUTES;

final HeadObjectResponse headObjectResponse;
try {
headObjectResponse = s3AsyncClient.headObject(builder -> builder
.bucket(bucket)
.key(key)).get(timeOut, unit);
} catch (final InterruptedException e) {
Thread.currentThread().interrupt();
throw new RuntimeException(e);
} catch (final ExecutionException | TimeoutException e) {
throw new IOException(e);
}
this.size = headObjectResponse.contentLength();
}
}

@Override
public SeekableByteChannel getReadChannel(@NotNull final Path path) throws IOException {
return new S3SeekableByteChannel(s3uri, bucket, key, s3AsyncClient, 0, null, null, readAheadBuffersCache);
return new S3SeekableByteChannel(s3uri, bucket, key, s3AsyncClient, 0, size, null, null, readAheadBuffersCache);
}

@Override
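With this change the constructor does all per-object setup: it builds the CRT-backed async client, validates the configuration knobs, and issues the HEAD request once, so every channel created afterwards already knows the object size (which is why the lazy HEAD logic disappears from size() below). A minimal usage sketch; the region, bucket, and key are illustrative placeholders, not values from the commit:

    // Hypothetical caller, for illustration only.
    final SeekableChannelsProvider provider = new S3BackedSeekableChannelProvider(
            "us-east-1", "s3://my-bucket/path/to/table.parquet");
    // This implementation reads the URI captured at construction time, so the
    // Path argument is effectively unused.
    try (final SeekableByteChannel channel = provider.getReadChannel(Path.of("table.parquet"))) {
        final ByteBuffer dst = ByteBuffer.allocate(64 * 1024);
        channel.read(dst); // served from the shared read-ahead fragment cache
    }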
@@ -8,7 +8,6 @@
import software.amazon.awssdk.core.BytesWrapper;
import software.amazon.awssdk.core.async.AsyncResponseTransformer;
import software.amazon.awssdk.services.s3.S3AsyncClient;
import software.amazon.awssdk.services.s3.model.HeadObjectResponse;

import java.io.IOException;
import java.nio.ByteBuffer;
@@ -26,32 +25,33 @@
public final class S3SeekableByteChannel implements SeekableByteChannel {
private long position;
private final S3AsyncClient s3Client;
private boolean closed;
private volatile boolean closed;
private final String s3uri, bucket, key;
private final int maxFragmentSize;
private final int maxNumberFragments;
private final int numFragmentsInObject;
private volatile long size = -1L;
private final long size;
private final Long timeout;
private final TimeUnit timeUnit;
private boolean open;
private Cache<Integer, CompletableFuture<ByteBuffer>> readAheadBuffersCache;
private final Cache<Integer, CompletableFuture<ByteBuffer>> readAheadBuffersCache;

private final Logger logger = LoggerFactory.getLogger(this.getClass());

public static final int MAX_FRAGMENT_SIZE =
private static final int MAX_FRAGMENT_SIZE =
Configuration.getInstance().getIntegerWithDefault("s3.spi.read.max-fragment-size", 512 * 1024); // 512 KB
private static final int MAX_FRAGMENT_NUMBER =
Configuration.getInstance().getIntegerWithDefault("s3.spi.read.max-fragment-number", 2);


S3SeekableByteChannel(String s3uri, String bucket, String key, S3AsyncClient s3Client, long startAt,
S3SeekableByteChannel(String s3uri, String bucket, String key, S3AsyncClient s3Client, long startAt, long size,
Long timeout, TimeUnit timeUnit, final Cache<Integer, CompletableFuture<ByteBuffer>> readAheadBuffersCache) throws IOException {
Objects.requireNonNull(s3Client);
if (MAX_FRAGMENT_SIZE < 1)
throw new IllegalArgumentException("maxFragmentSize must be >= 1");
if (MAX_FRAGMENT_NUMBER < 1)
throw new IllegalArgumentException("maxNumberFragments must be >= 1");
if (size < 1)
throw new IllegalArgumentException("size must be >= 1");

this.position = startAt;
this.bucket = bucket;
@@ -63,12 +63,10 @@ public final class S3SeekableByteChannel implements SeekableByteChannel {
this.maxFragmentSize = MAX_FRAGMENT_SIZE;
this.readAheadBuffersCache = readAheadBuffersCache;
this.maxNumberFragments = MAX_FRAGMENT_NUMBER;
this.open = true;
this.timeout = timeout != null ? timeout : 5L;
this.timeUnit = timeUnit != null ? timeUnit : TimeUnit.MINUTES;

size(); // Will populate the size
this.numFragmentsInObject = (int) Math.ceil((float) size / (float) maxFragmentSize);
this.size = size;
this.numFragmentsInObject = (int) Math.ceil((double) size / maxFragmentSize);
}

/**
@@ -83,7 +81,7 @@ public final class S3SeekableByteChannel implements SeekableByteChannel {
* @return the number of bytes read or -1 if no more bytes can be read.
*/
@Override
public int read(ByteBuffer dst) throws IOException {
public int read(final ByteBuffer dst) throws IOException {
validateOpen();

Objects.requireNonNull(dst);
@@ -96,21 +94,18 @@ public int read(ByteBuffer dst) throws IOException {
}

//figure out the index of the fragment the bytes would start in
final Integer fragmentIndex = fragmentIndexForByteNumber(channelPosition);
final int fragmentOffset = (int) (channelPosition - (fragmentIndex.longValue() * maxFragmentSize));
final int fragmentIndex = fragmentIndexForByteNumber(channelPosition);
final int fragmentOffset = (int) (channelPosition - ((long) fragmentIndex * maxFragmentSize));
try {
final ByteBuffer fragment = Objects.requireNonNull(readAheadBuffersCache.get(fragmentIndex, this::computeFragmentFuture))
.get(timeout, timeUnit)
.asReadOnlyBuffer();

fragment.position(fragmentOffset);

// put the bytes from the fragment, from the offset up to the min of fragment remaining or dst remaining
fragment.position(fragmentOffset);
final int limit = Math.min(fragment.remaining(), dst.remaining());

final byte[] copiedBytes = new byte[limit];
fragment.get(copiedBytes, 0, limit);
dst.put(copiedBytes);
fragment.limit(fragment.position() + limit);
dst.put(fragment);

if (fragment.position() >= fragment.limit() / 2) {

@@ -128,8 +123,8 @@ public int read(ByteBuffer dst) throws IOException {
}
}

position(channelPosition + copiedBytes.length);
return copiedBytes.length;
position(channelPosition + limit);
return limit;

} catch (final InterruptedException e) {
Thread.currentThread().interrupt();
@@ -150,11 +145,11 @@ public int read(ByteBuffer dst) throws IOException {
* @param byteNumber the number of the byte in the object accessed by this channel
* @return the index of the fragment in which {@code byteNumber} will be found.
*/
private Integer fragmentIndexForByteNumber(final long byteNumber) {
return Integer.valueOf(Math.toIntExact(Math.floorDiv(byteNumber, (long) maxFragmentSize)));
private int fragmentIndexForByteNumber(final long byteNumber) {
return Math.toIntExact(Math.floorDiv(byteNumber, (long) maxFragmentSize));
}

private CompletableFuture<ByteBuffer> computeFragmentFuture(int fragmentIndex) {
private CompletableFuture<ByteBuffer> computeFragmentFuture(final int fragmentIndex) {
final long readFrom = (long) fragmentIndex * maxFragmentSize;
final long readTo = Math.min(readFrom + maxFragmentSize, size) - 1;
final String range = "bytes=" + readFrom + "-" + readTo;
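
The arithmetic above maps a channel position to a fragment index, an offset within that fragment, and an HTTP Range header. A self-contained worked example with the default 512 KiB fragment size (an object larger than 1 MiB is assumed):

    final int maxFragmentSize = 512 * 1024;                                    // 524288
    final long position = 1_000_000L;
    final int fragmentIndex =
            Math.toIntExact(Math.floorDiv(position, (long) maxFragmentSize));  // 1
    final int fragmentOffset =
            (int) (position - ((long) fragmentIndex * maxFragmentSize));       // 475712
    final long readFrom = (long) fragmentIndex * maxFragmentSize;              // 524288
    final long readTo = readFrom + maxFragmentSize - 1;                        // 1048575
    System.out.println("bytes=" + readFrom + "-" + readTo);                    // bytes=524288-1048575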
@@ -251,27 +246,6 @@ public SeekableByteChannel position(long newPosition) throws IOException {
@Override
public long size() throws IOException {
validateOpen();

if (size < 0) {
long timeOut = 1L;
TimeUnit unit = TimeUnit.MINUTES;

synchronized (this) {
final HeadObjectResponse headObjectResponse;
try {
headObjectResponse = s3Client.headObject(builder -> builder
.bucket(bucket)
.key(key)).get(timeOut, unit);
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
throw new RuntimeException(e);
} catch (ExecutionException | TimeoutException e) {
throw new IOException(e);
}

this.size = headObjectResponse.contentLength();
}
}
return this.size;
}

@@ -314,7 +288,7 @@ public boolean isOpen() {
*
* @return the size of the cache after any async evictions or reloads have happened.
*/
protected int numberOfCachedFragments() {
int numberOfCachedFragments() {
readAheadBuffersCache.cleanUp();
return (int) readAheadBuffersCache.estimatedSize();
}
@@ -325,7 +299,7 @@ protected int numberOfCachedFragments() {
*
* @return the statistics of the internal cache.
*/
protected CacheStats cacheStatistics() {
CacheStats cacheStatistics() {
return readAheadBuffersCache.stats();
}

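Underneath all of the above, the read-ahead cache is a standard Caffeine get-or-compute over futures: one in-flight request per fragment index, shared by concurrent readers. A reduced, self-contained sketch of that pattern, independent of the S3 specifics:

    import com.github.benmanes.caffeine.cache.Cache;
    import com.github.benmanes.caffeine.cache.Caffeine;

    import java.nio.ByteBuffer;
    import java.util.concurrent.CompletableFuture;

    public final class ReadAheadCacheSketch {
        private final Cache<Integer, CompletableFuture<ByteBuffer>> cache =
                Caffeine.newBuilder().maximumSize(10).recordStats().build();

        // Cache.get(key, mappingFunction) runs the function at most once per
        // absent key, so concurrent callers share a single in-flight fetch.
        CompletableFuture<ByteBuffer> fragment(final int fragmentIndex) {
            return cache.get(fragmentIndex, this::fetchFragment);
        }

        // Stand-in for the async ranged S3 GET; illustrative only.
        private CompletableFuture<ByteBuffer> fetchFragment(final int fragmentIndex) {
            return CompletableFuture.supplyAsync(() -> ByteBuffer.allocate(512 * 1024));
        }
    }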
