From 28513fca40c31e65f768399cf2058c8e74bd19da Mon Sep 17 00:00:00 2001 From: Amogh Jahagirdar Date: Tue, 20 Aug 2024 20:49:32 -0600 Subject: [PATCH] API, AWS: Add RetryableInputStream and use that in S3InputStream Co-authored-by: Jack Ye Co-authored-by: Xiaoxuan Li --- .../apache/iceberg/aws/s3/S3InputStream.java | 34 ++- .../aws/s3/TestFuzzyS3InputStream.java | 224 ++++++++++++++++++ .../iceberg/aws/s3/TestS3InputStream.java | 38 ++- build.gradle | 2 + .../iceberg/io/RetryableInputStream.java | 143 +++++++++++ gradle/libs.versions.toml | 2 + 6 files changed, 422 insertions(+), 21 deletions(-) create mode 100644 aws/src/test/java/org/apache/iceberg/aws/s3/TestFuzzyS3InputStream.java create mode 100644 core/src/main/java/org/apache/iceberg/io/RetryableInputStream.java diff --git a/aws/src/main/java/org/apache/iceberg/aws/s3/S3InputStream.java b/aws/src/main/java/org/apache/iceberg/aws/s3/S3InputStream.java index f1d6c30a27a5..83dd32d1e333 100644 --- a/aws/src/main/java/org/apache/iceberg/aws/s3/S3InputStream.java +++ b/aws/src/main/java/org/apache/iceberg/aws/s3/S3InputStream.java @@ -25,6 +25,7 @@ import org.apache.iceberg.io.FileIOMetricsContext; import org.apache.iceberg.io.IOUtil; import org.apache.iceberg.io.RangeReadable; +import org.apache.iceberg.io.RetryableInputStream; import org.apache.iceberg.io.SeekableInputStream; import org.apache.iceberg.metrics.Counter; import org.apache.iceberg.metrics.MetricsContext; @@ -92,13 +93,13 @@ public void seek(long newPos) { public int read() throws IOException { Preconditions.checkState(!closed, "Cannot read: already closed"); positionStream(); - + int bytesRead = stream.read(); pos += 1; next += 1; readBytes.increment(); readOperations.increment(); - return stream.read(); + return bytesRead; } @Override @@ -139,7 +140,11 @@ private InputStream readRange(String range) { S3RequestUtil.configureEncryption(s3FileIOProperties, requestBuilder); - return s3.getObject(requestBuilder.build(), ResponseTransformer.toInputStream()); + stream = + RetryableInputStream.builderFor( + () -> s3.getObject(requestBuilder.build(), ResponseTransformer.toInputStream())) + .build(); + return stream; } @Override @@ -178,18 +183,23 @@ private void positionStream() throws IOException { } private void openStream() throws IOException { - GetObjectRequest.Builder requestBuilder = - GetObjectRequest.builder() - .bucket(location.bucket()) - .key(location.key()) - .range(String.format("bytes=%s-", pos)); - - S3RequestUtil.configureEncryption(s3FileIOProperties, requestBuilder); - closeStream(); try { - stream = s3.getObject(requestBuilder.build(), ResponseTransformer.toInputStream()); + stream = + RetryableInputStream.builderFor( + rangeStart -> { + GetObjectRequest.Builder requestBuilder = + GetObjectRequest.builder() + .bucket(location.bucket()) + .key(location.key()) + .range(String.format("bytes=%s-", rangeStart)); + S3RequestUtil.configureEncryption(s3FileIOProperties, requestBuilder); + return s3.getObject( + requestBuilder.build(), ResponseTransformer.toInputStream()); + }, + () -> pos) + .build(); } catch (NoSuchKeyException e) { throw new NotFoundException(e, "Location does not exist: %s", location); } diff --git a/aws/src/test/java/org/apache/iceberg/aws/s3/TestFuzzyS3InputStream.java b/aws/src/test/java/org/apache/iceberg/aws/s3/TestFuzzyS3InputStream.java new file mode 100644 index 000000000000..fa9a2e90f7a7 --- /dev/null +++ b/aws/src/test/java/org/apache/iceberg/aws/s3/TestFuzzyS3InputStream.java @@ -0,0 +1,224 @@ +/* + * Licensed to the Apache Software 
Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.aws.s3; + +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.spy; + +import java.io.IOException; +import java.io.InputStream; +import java.net.SocketTimeoutException; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.stream.Stream; +import javax.net.ssl.SSLException; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import software.amazon.awssdk.awscore.exception.AwsServiceException; +import software.amazon.awssdk.core.ResponseInputStream; +import software.amazon.awssdk.core.exception.SdkClientException; +import software.amazon.awssdk.core.sync.RequestBody; +import software.amazon.awssdk.core.sync.ResponseTransformer; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.s3.model.CreateBucketRequest; +import software.amazon.awssdk.services.s3.model.CreateBucketResponse; +import software.amazon.awssdk.services.s3.model.GetObjectRequest; +import software.amazon.awssdk.services.s3.model.GetObjectResponse; +import software.amazon.awssdk.services.s3.model.HeadObjectRequest; +import software.amazon.awssdk.services.s3.model.HeadObjectResponse; +import software.amazon.awssdk.services.s3.model.PutObjectRequest; +import software.amazon.awssdk.services.s3.model.PutObjectResponse; + +public class TestFuzzyS3InputStream extends TestS3InputStream { + + private static final int DATA_SIZE = 100; + private static final int SEEK_SIZE = 4; + private static final int SEEK_NEW_POSITION = 25; + + @ParameterizedTest + @MethodSource("retryableExceptions") + public void testReadWithFuzzyStreamRetrySucceed(IOException exception) throws Exception { + testRead(fuzzyStreamClient(new AtomicInteger(3), exception), DATA_SIZE); + } + + @ParameterizedTest + @MethodSource("retryableExceptions") + public void testReadWithFuzzyStreamExhaustedRetries(IOException exception) { + assertThatThrownBy( + () -> testRead(fuzzyStreamClient(new AtomicInteger(5), exception), DATA_SIZE)) + .isInstanceOf(exception.getClass()) + .hasMessage(exception.getMessage()); + } + + @ParameterizedTest + @MethodSource("nonRetryableExceptions") + public void testReadWithFuzzyStreamNonRetryableException(IOException exception) { + assertThatThrownBy( + () -> testRead(fuzzyStreamClient(new AtomicInteger(3), exception), DATA_SIZE)) + .isInstanceOf(exception.getClass()) + .hasMessage(exception.getMessage()); + } + + @Override + protected void testRead(S3Client s3, int dataSize) throws Exception { + testRead(s3, DATA_SIZE, 4, SEEK_SIZE, SEEK_NEW_POSITION); 
+  }
+
+  private static Stream<Arguments> retryableExceptions() {
+    return Stream.of(
+            new SocketTimeoutException("socket timeout exception"),
+            new SSLException("some ssl exception"))
+        .map(Arguments::of);
+  }
+
+  private static Stream<Arguments> nonRetryableExceptions() {
+    return Stream.of(Arguments.of(new IOException("some generic non-retryable IO exception")));
+  }
+
+  private S3ClientWrapper fuzzyStreamClient(AtomicInteger counter, IOException failure) {
+    S3ClientWrapper fuzzyClient = spy(new S3ClientWrapper(s3Client()));
+    doAnswer(
+            invocation ->
+                new FuzzyResponseInputStream(invocation.callRealMethod(), counter, failure))
+        .when(fuzzyClient)
+        .getObject(any(GetObjectRequest.class), any(ResponseTransformer.class));
+    return fuzzyClient;
+  }
+
+  /** Wrapper for S3 client, used to mock the final class DefaultS3Client */
+  public static class S3ClientWrapper implements S3Client {
+
+    private final S3Client delegate;
+
+    public S3ClientWrapper(S3Client delegate) {
+      this.delegate = delegate;
+    }
+
+    @Override
+    public String serviceName() {
+      return delegate.serviceName();
+    }
+
+    @Override
+    public void close() {
+      delegate.close();
+    }
+
+    @Override
+    public <ReturnT> ReturnT getObject(
+        GetObjectRequest getObjectRequest,
+        ResponseTransformer<GetObjectResponse, ReturnT> responseTransformer)
+        throws AwsServiceException, SdkClientException {
+      return delegate.getObject(getObjectRequest, responseTransformer);
+    }
+
+    @Override
+    public HeadObjectResponse headObject(HeadObjectRequest headObjectRequest)
+        throws AwsServiceException, SdkClientException {
+      return delegate.headObject(headObjectRequest);
+    }
+
+    @Override
+    public PutObjectResponse putObject(PutObjectRequest putObjectRequest, RequestBody requestBody)
+        throws AwsServiceException, SdkClientException {
+      return delegate.putObject(putObjectRequest, requestBody);
+    }
+
+    @Override
+    public CreateBucketResponse createBucket(CreateBucketRequest createBucketRequest)
+        throws AwsServiceException, SdkClientException {
+      return delegate.createBucket(createBucketRequest);
+    }
+  }
+
+  static class FuzzyResponseInputStream extends InputStream {
+
+    private final ResponseInputStream<GetObjectResponse> delegate;
+    private final AtomicInteger counter;
+    private final int round;
+    private final IOException exception;
+
+    FuzzyResponseInputStream(
+        Object invocationResponse, AtomicInteger counter, IOException exception) {
+      this.delegate = (ResponseInputStream<GetObjectResponse>) invocationResponse;
+      this.counter = counter;
+      this.round = counter.get();
+      this.exception = exception;
+    }
+
+    private void checkCounter() throws IOException {
+      // for every round of n invocations, only the last call succeeds
+      if (counter.decrementAndGet() == 0) {
+        counter.set(round);
+      } else {
+        throw exception;
+      }
+    }
+
+    @Override
+    public int read() throws IOException {
+      checkCounter();
+      return delegate.read();
+    }
+
+    @Override
+    public int read(byte[] b) throws IOException {
+      checkCounter();
+      return delegate.read(b);
+    }
+
+    @Override
+    public int read(byte[] b, int off, int len) throws IOException {
+      checkCounter();
+      return delegate.read(b, off, len);
+    }
+
+    @Override
+    public long skip(long n) throws IOException {
+      return delegate.skip(n);
+    }
+
+    @Override
+    public int available() throws IOException {
+      return delegate.available();
+    }
+
+    @Override
+    public void close() throws IOException {
+      delegate.close();
+    }
+
+    @Override
+    public synchronized void mark(int readlimit) {
+      delegate.mark(readlimit);
+    }
+
+    @Override
+    public synchronized void reset() throws IOException {
+      delegate.reset();
+    }
+
+    @Override
+    public boolean 
markSupported() { + return delegate.markSupported(); + } + } +} diff --git a/aws/src/test/java/org/apache/iceberg/aws/s3/TestS3InputStream.java b/aws/src/test/java/org/apache/iceberg/aws/s3/TestS3InputStream.java index ed71e259a26c..90c060eb5bf5 100644 --- a/aws/src/test/java/org/apache/iceberg/aws/s3/TestS3InputStream.java +++ b/aws/src/test/java/org/apache/iceberg/aws/s3/TestS3InputStream.java @@ -28,6 +28,7 @@ import org.apache.iceberg.io.IOUtil; import org.apache.iceberg.io.RangeReadable; import org.apache.iceberg.io.SeekableInputStream; +import org.apache.iceberg.metrics.MetricsContext; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; @@ -54,21 +55,26 @@ public void before() { @Test public void testRead() throws Exception { + testRead(s3, 10 * 1024 * 1024); + } + + protected void testRead(S3Client s3Client, int dataSize) throws Exception { + testRead(s3Client, 10 * 1024 * 1024, 1024, 1024, 2 * 1024 * 1024); + } + + protected void testRead( + S3Client s3Client, int dataSize, int readSize, int seekSize, int seekNewStreamPosition) + throws Exception { S3URI uri = new S3URI("s3://bucket/path/to/read.dat"); - int dataSize = 1024 * 1024 * 10; byte[] data = randomData(dataSize); writeS3Data(uri, data); - try (SeekableInputStream in = new S3InputStream(s3, uri)) { - int readSize = 1024; - byte[] actual = new byte[readSize]; - + try (SeekableInputStream in = new S3InputStream(s3Client, uri)) { readAndCheck(in, in.getPos(), readSize, data, false); readAndCheck(in, in.getPos(), readSize, data, true); // Seek forward in current stream - int seekSize = 1024; readAndCheck(in, in.getPos() + seekSize, readSize, data, false); readAndCheck(in, in.getPos() + seekSize, readSize, data, true); @@ -77,7 +83,6 @@ public void testRead() throws Exception { readAndCheck(in, in.getPos(), readSize, data, false); // Seek with new stream - long seekNewStreamPosition = 2 * 1024 * 1024; readAndCheck(in, in.getPos() + seekNewStreamPosition, readSize, data, true); readAndCheck(in, in.getPos() + seekNewStreamPosition, readSize, data, false); @@ -111,6 +116,11 @@ private void readAndCheck( @Test public void testRangeRead() throws Exception { + testRangeRead(s3, new S3FileIOProperties()); + } + + protected void testRangeRead(S3Client s3Client, S3FileIOProperties awsProperties) + throws Exception { S3URI uri = new S3URI("s3://bucket/path/to/range-read.dat"); int dataSize = 1024 * 1024 * 10; byte[] expected = randomData(dataSize); @@ -122,7 +132,8 @@ public void testRangeRead() throws Exception { writeS3Data(uri, expected); - try (RangeReadable in = new S3InputStream(s3, uri)) { + try (RangeReadable in = + new S3InputStream(s3Client, uri, awsProperties, MetricsContext.nullMetrics())) { // first 1k position = 0; offset = 0; @@ -163,12 +174,17 @@ public void testClose() throws Exception { @Test public void testSeek() throws Exception { + testSeek(s3, new S3FileIOProperties()); + } + + protected void testSeek(S3Client s3Client, S3FileIOProperties awsProperties) throws Exception { S3URI uri = new S3URI("s3://bucket/path/to/seek.dat"); byte[] expected = randomData(1024 * 1024); writeS3Data(uri, expected); - try (SeekableInputStream in = new S3InputStream(s3, uri)) { + try (SeekableInputStream in = + new S3InputStream(s3Client, uri, awsProperties, MetricsContext.nullMetrics())) { in.seek(expected.length / 2); byte[] actual = new byte[expected.length / 2]; IOUtil.readFully(in, actual, 0, expected.length / 2); @@ -200,4 +216,8 @@ private void 
createBucket(String bucketName) { // don't do anything } } + + protected S3Client s3Client() { + return s3; + } } diff --git a/build.gradle b/build.gradle index 7a11943cf8be..08d589589c95 100644 --- a/build.gradle +++ b/build.gradle @@ -347,6 +347,7 @@ project(':iceberg-core') { implementation libs.jackson.core implementation libs.jackson.databind implementation libs.caffeine + implementation libs.failsafe implementation libs.roaringbitmap compileOnly(libs.hadoop2.client) { exclude group: 'org.apache.avro', module: 'avro' @@ -462,6 +463,7 @@ project(':iceberg-aws') { annotationProcessor libs.immutables.value compileOnly libs.immutables.value implementation libs.caffeine + implementation libs.failsafe implementation platform(libs.jackson.bom) implementation libs.jackson.core implementation libs.jackson.databind diff --git a/core/src/main/java/org/apache/iceberg/io/RetryableInputStream.java b/core/src/main/java/org/apache/iceberg/io/RetryableInputStream.java new file mode 100644 index 000000000000..893014b953a1 --- /dev/null +++ b/core/src/main/java/org/apache/iceberg/io/RetryableInputStream.java @@ -0,0 +1,143 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.io; + +import dev.failsafe.Failsafe; +import dev.failsafe.FailsafeException; +import dev.failsafe.RetryPolicy; +import dev.failsafe.RetryPolicyBuilder; +import java.io.IOException; +import java.io.InputStream; +import java.net.SocketException; +import java.net.SocketTimeoutException; +import java.time.Duration; +import java.util.List; +import java.util.function.Function; +import java.util.function.Supplier; +import javax.net.ssl.SSLException; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; + +/** + * RetryableInputStream wraps over an underlying InputStream and retries failures encountered when + * reading through the stream. On retries, the underlying streams will be reinitialized. 
+ */
+public class RetryableInputStream extends InputStream {
+
+  private final InputStream underlyingStream;
+  private final RetryPolicy<Object> retryPolicy;
+
+  private RetryableInputStream(InputStream underlyingStream, RetryPolicy<Object> retryPolicy) {
+    this.underlyingStream = underlyingStream;
+    this.retryPolicy = retryPolicy;
+  }
+
+  @Override
+  public int read() throws IOException {
+    try {
+      return Failsafe.with(retryPolicy).get(() -> underlyingStream.read());
+    } catch (FailsafeException ex) {
+      if (ex.getCause() instanceof IOException) {
+        throw (IOException) ex.getCause();
+      }
+
+      throw ex;
+    }
+  }
+
+  @Override
+  public int read(byte[] b, int off, int len) throws IOException {
+    try {
+      return Failsafe.with(retryPolicy).get(() -> underlyingStream.read(b, off, len));
+    } catch (FailsafeException ex) {
+      if (ex.getCause() instanceof IOException) {
+        throw (IOException) ex.getCause();
+      }
+
+      throw ex;
+    }
+  }
+
+  @Override
+  public void close() throws IOException {
+    underlyingStream.close();
+  }
+
+  public static RetryableInputStream.Builder builderFor(Supplier<InputStream> newStreamSupplier) {
+    return new Builder(unusedPosition -> newStreamSupplier.get(), null);
+  }
+
+  public static RetryableInputStream.Builder builderFor(
+      Function<Long, InputStream> newStreamAtPositionSupplier, Supplier<Long> positionSupplier) {
+    Preconditions.checkArgument(
+        newStreamAtPositionSupplier != null, "New stream supplier cannot be null");
+    Preconditions.checkArgument(
+        positionSupplier != null, "Stream position supplier cannot be null");
+    return new Builder(newStreamAtPositionSupplier, positionSupplier);
+  }
+
+  public static class Builder {
+    private final Supplier<Long> positionSupplier;
+    private InputStream underlyingStream;
+    private final Function<Long, InputStream> newStreamProvider;
+    private List<Class<? extends Throwable>> retryableExceptions =
+        ImmutableList.of(SSLException.class, SocketTimeoutException.class, SocketException.class);
+    private int numRetries = 3;
+    private long delayMs = 400;
+
+    private Builder(
+        Function<Long, InputStream> newStreamProvider, Supplier<Long> positionSupplier) {
+      this.newStreamProvider = newStreamProvider;
+      this.positionSupplier = positionSupplier == null ? () -> null : positionSupplier;
+      initializeUnderlyingStream();
+    }
+
+    public Builder retryOn(Class<? extends Throwable>... 
exceptions) {
+      this.retryableExceptions = Lists.newArrayList(exceptions);
+      return this;
+    }
+
+    public Builder withRetries(int numberRetries) {
+      this.numRetries = numberRetries;
+      return this;
+    }
+
+    public Builder withRetryDelay(long delayMillis) {
+      this.delayMs = delayMillis;
+      return this;
+    }
+
+    public RetryableInputStream build() {
+      RetryPolicyBuilder<Object> retryPolicyBuilder = RetryPolicy.builder();
+      retryableExceptions.forEach(retryPolicyBuilder::handle);
+      retryPolicyBuilder.onRetry((event) -> initializeUnderlyingStream());
+      return new RetryableInputStream(
+          underlyingStream,
+          retryPolicyBuilder
+              .withMaxRetries(numRetries)
+              .withDelay(Duration.ofMillis(delayMs))
+              .build());
+    }
+
+    private void initializeUnderlyingStream() {
+      this.underlyingStream = newStreamProvider.apply(positionSupplier.get());
+    }
+  }
+}
diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml
index 77e610e885f6..35f3f8775a66 100644
--- a/gradle/libs.versions.toml
+++ b/gradle/libs.versions.toml
@@ -38,6 +38,7 @@ delta-standalone = "3.2.0"
 delta-spark = "3.2.0"
 esotericsoftware-kryo = "4.0.3"
 errorprone-annotations = "2.29.2"
+failsafe = "3.3.2"
 findbugs-jsr305 = "3.0.2"
 flink118 = { strictly = "1.18.1"}
 flink119 = { strictly = "1.19.0"}
@@ -107,6 +108,7 @@ calcite-druid = { module = "org.apache.calcite:calcite-druid", version.ref = "ca
 datasketches = { module = "org.apache.datasketches:datasketches-java", version.ref = "datasketches" }
 delta-standalone = { module = "io.delta:delta-standalone_2.12", version.ref = "delta-standalone" }
 errorprone-annotations = { module = "com.google.errorprone:error_prone_annotations", version.ref = "errorprone-annotations" }
+failsafe = { module = "dev.failsafe:failsafe", version.ref = "failsafe"}
 findbugs-jsr305 = { module = "com.google.code.findbugs:jsr305", version.ref = "findbugs-jsr305" }
 flink118-avro = { module = "org.apache.flink:flink-avro", version.ref = "flink118" }
 flink118-connector-base = { module = "org.apache.flink:flink-connector-base", version.ref = "flink118" }
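
Usage note (sketch only, not part of the patch): the snippet below illustrates how a caller is expected to use the builder added above. It mirrors the pattern S3InputStream.openStream() follows with a ranged GetObject request, but substitutes a local file so it is self-contained; the class name, file path, and openAt helper are assumptions made for the example, not APIs from this change. Passing a position supplier lets a retried open resume at the last known offset instead of restarting from byte 0.

// RetryableInputStreamSketch.java -- illustrative sketch against the builder API in this patch
import java.io.IOException;
import java.io.InputStream;
import java.io.UncheckedIOException;
import java.net.SocketTimeoutException;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.concurrent.atomic.AtomicLong;
import javax.net.ssl.SSLException;
import org.apache.iceberg.io.RetryableInputStream;

public class RetryableInputStreamSketch {

  public static void main(String[] args) throws IOException {
    AtomicLong position = new AtomicLong(0);

    try (InputStream stream =
        RetryableInputStream.builderFor(
                // hypothetical path; re-open the source at the last known position on retry
                rangeStart -> openAt("/tmp/example.dat", rangeStart), position::get)
            // SSLException and SocketTimeoutException are already in the default retryable set;
            // retryOn is shown only to illustrate the builder option
            .retryOn(SSLException.class, SocketTimeoutException.class)
            .withRetries(3)
            .withRetryDelay(400)
            .build()) {
      int b;
      while ((b = stream.read()) != -1) {
        position.incrementAndGet(); // keep the resume offset current as bytes are consumed
      }
    }
  }

  // hypothetical helper: open the source and skip to the requested offset (null means start)
  private static InputStream openAt(String path, Long offset) {
    try {
      InputStream in = Files.newInputStream(Paths.get(path));
      long toSkip = offset == null ? 0 : offset;
      while (toSkip > 0) {
        long skipped = in.skip(toSkip);
        if (skipped <= 0) {
          break; // best-effort resume for the sketch
        }
        toSkip -= skipped;
      }
      return in;
    } catch (IOException e) {
      throw new UncheckedIOException(e);
    }
  }
}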