From 90163ae5d85de0e3413acd75affb6416f2933e4d Mon Sep 17 00:00:00 2001 From: Amogh Jahagirdar Date: Tue, 24 Sep 2024 10:26:55 -0600 Subject: [PATCH] API, AWS: Retry S3InputStream reads (#10433) Co-authored-by: Jack Ye Co-authored-by: Xiaoxuan Li --- .../apache/iceberg/aws/s3/S3InputStream.java | 71 ++++-- .../aws/s3/TestFlakyS3InputStream.java | 206 ++++++++++++++++++ .../iceberg/aws/s3/TestS3InputStream.java | 24 +- build.gradle | 2 + gradle/libs.versions.toml | 2 + 5 files changed, 285 insertions(+), 20 deletions(-) create mode 100644 aws/src/test/java/org/apache/iceberg/aws/s3/TestFlakyS3InputStream.java diff --git a/aws/src/main/java/org/apache/iceberg/aws/s3/S3InputStream.java b/aws/src/main/java/org/apache/iceberg/aws/s3/S3InputStream.java index f442a0f04a1c..74e602a27378 100644 --- a/aws/src/main/java/org/apache/iceberg/aws/s3/S3InputStream.java +++ b/aws/src/main/java/org/apache/iceberg/aws/s3/S3InputStream.java @@ -18,9 +18,15 @@ */ package org.apache.iceberg.aws.s3; +import dev.failsafe.Failsafe; +import dev.failsafe.FailsafeException; +import dev.failsafe.RetryPolicy; import java.io.IOException; import java.io.InputStream; +import java.net.SocketException; +import java.net.SocketTimeoutException; import java.util.Arrays; +import javax.net.ssl.SSLException; import org.apache.iceberg.exceptions.NotFoundException; import org.apache.iceberg.io.FileIOMetricsContext; import org.apache.iceberg.io.IOUtil; @@ -31,6 +37,7 @@ import org.apache.iceberg.metrics.MetricsContext.Unit; import org.apache.iceberg.relocated.com.google.common.base.Joiner; import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; import org.apache.iceberg.relocated.com.google.common.io.ByteStreams; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -57,6 +64,14 @@ class S3InputStream extends SeekableInputStream implements RangeReadable { private final Counter readOperations; private int 
skipSize = 1024 * 1024;
+  private RetryPolicy<Object> retryPolicy =
+      RetryPolicy.builder()
+          .handle(
+              ImmutableList.of(
+                  SSLException.class, SocketTimeoutException.class, SocketException.class))
+          .onRetry(failure -> openStream(true)) // reopen the stream before each retry
+          .withMaxRetries(3)
+          .build();
 
   S3InputStream(S3Client s3, S3URI location) {
     this(s3, location, new S3FileIOProperties(), MetricsContext.nullMetrics());
@@ -92,13 +107,21 @@ public void seek(long newPos) {
   public int read() throws IOException {
     Preconditions.checkState(!closed, "Cannot read: already closed");
     positionStream();
+    try {
+      int bytesRead = Failsafe.with(retryPolicy).get(() -> stream.read());
+      pos += 1;
+      next += 1;
+      readBytes.increment();
+      readOperations.increment();
+
+      return bytesRead;
+    } catch (FailsafeException ex) {
+      if (ex.getCause() instanceof IOException) {
+        throw (IOException) ex.getCause();
+      }
 
-    pos += 1;
-    next += 1;
-    readBytes.increment();
-    readOperations.increment();
-
-    return stream.read();
+      throw ex;
+    }
   }
 
   @Override
@@ -106,13 +129,21 @@ public int read(byte[] b, int off, int len) throws IOException {
     Preconditions.checkState(!closed, "Cannot read: already closed");
     positionStream();
 
-    int bytesRead = stream.read(b, off, len);
-    pos += bytesRead;
-    next += bytesRead;
-    readBytes.increment(bytesRead);
-    readOperations.increment();
+    try {
+      int bytesRead = Failsafe.with(retryPolicy).get(() -> stream.read(b, off, len));
+      pos += bytesRead;
+      next += bytesRead;
+      readBytes.increment(bytesRead);
+      readOperations.increment();
+
+      return bytesRead;
+    } catch (FailsafeException ex) {
+      if (ex.getCause() instanceof IOException) {
+        throw (IOException) ex.getCause();
+      }
 
-    return bytesRead;
+      throw ex;
+    }
   }
 
   @Override
@@ -146,7 +177,7 @@ private InputStream readRange(String range) {
   public void close() throws IOException {
     super.close();
     closed = true;
-    closeStream();
+    closeStream(false);
   }
 
   private void positionStream() throws IOException {
@@ -178,6 +209,10 @@ private void positionStream() throws
IOException { } private void openStream() throws IOException { + openStream(false); + } + + private void openStream(boolean closeQuietly) throws IOException { GetObjectRequest.Builder requestBuilder = GetObjectRequest.builder() .bucket(location.bucket()) @@ -186,7 +221,7 @@ private void openStream() throws IOException { S3RequestUtil.configureEncryption(s3FileIOProperties, requestBuilder); - closeStream(); + closeStream(closeQuietly); try { stream = s3.getObject(requestBuilder.build(), ResponseTransformer.toInputStream()); @@ -195,7 +230,7 @@ private void openStream() throws IOException { } } - private void closeStream() throws IOException { + private void closeStream(boolean closeQuietly) throws IOException { if (stream != null) { // if we aren't at the end of the stream, and the stream is abortable, then // call abort() so we don't read the remaining data with the Apache HTTP client @@ -203,6 +238,12 @@ private void closeStream() throws IOException { try { stream.close(); } catch (IOException e) { + if (closeQuietly) { + stream = null; + LOG.warn("An error occurred while closing the stream", e); + return; + } + // the Apache HTTP client will throw a ConnectionClosedException // when closing an aborted stream, which is expected if (!e.getClass().getSimpleName().equals("ConnectionClosedException")) { diff --git a/aws/src/test/java/org/apache/iceberg/aws/s3/TestFlakyS3InputStream.java b/aws/src/test/java/org/apache/iceberg/aws/s3/TestFlakyS3InputStream.java new file mode 100644 index 000000000000..08d14512cdc7 --- /dev/null +++ b/aws/src/test/java/org/apache/iceberg/aws/s3/TestFlakyS3InputStream.java @@ -0,0 +1,206 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.aws.s3; + +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.spy; + +import java.io.IOException; +import java.io.InputStream; +import java.net.SocketTimeoutException; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.stream.Stream; +import javax.net.ssl.SSLException; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import software.amazon.awssdk.awscore.exception.AwsServiceException; +import software.amazon.awssdk.core.ResponseInputStream; +import software.amazon.awssdk.core.exception.SdkClientException; +import software.amazon.awssdk.core.sync.RequestBody; +import software.amazon.awssdk.core.sync.ResponseTransformer; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.s3.model.CreateBucketRequest; +import software.amazon.awssdk.services.s3.model.CreateBucketResponse; +import software.amazon.awssdk.services.s3.model.GetObjectRequest; +import software.amazon.awssdk.services.s3.model.GetObjectResponse; +import software.amazon.awssdk.services.s3.model.HeadObjectRequest; +import software.amazon.awssdk.services.s3.model.HeadObjectResponse; +import 
software.amazon.awssdk.services.s3.model.PutObjectRequest;
+import software.amazon.awssdk.services.s3.model.PutObjectResponse;
+
+public class TestFlakyS3InputStream extends TestS3InputStream {
+
+  @ParameterizedTest
+  @MethodSource("retryableExceptions")
+  public void testReadWithFlakyStreamRetrySucceed(IOException exception) throws Exception {
+    testRead(flakyStreamClient(new AtomicInteger(3), exception));
+  }
+
+  @ParameterizedTest
+  @MethodSource("retryableExceptions")
+  public void testReadWithFlakyStreamExhaustedRetries(IOException exception) {
+    assertThatThrownBy(() -> testRead(flakyStreamClient(new AtomicInteger(5), exception)))
+        .isInstanceOf(exception.getClass())
+        .hasMessage(exception.getMessage());
+  }
+
+  @ParameterizedTest
+  @MethodSource("nonRetryableExceptions")
+  public void testReadWithFlakyStreamNonRetryableException(IOException exception) {
+    assertThatThrownBy(() -> testRead(flakyStreamClient(new AtomicInteger(3), exception)))
+        .isInstanceOf(exception.getClass())
+        .hasMessage(exception.getMessage());
+  }
+
+  @ParameterizedTest
+  @MethodSource("retryableExceptions")
+  public void testSeekWithFlakyStreamRetrySucceed(IOException exception) throws Exception {
+    testSeek(flakyStreamClient(new AtomicInteger(3), exception));
+  }
+
+  @ParameterizedTest
+  @MethodSource("retryableExceptions")
+  public void testSeekWithFlakyStreamExhaustedRetries(IOException exception) {
+    assertThatThrownBy(() -> testSeek(flakyStreamClient(new AtomicInteger(5), exception)))
+        .isInstanceOf(exception.getClass())
+        .hasMessage(exception.getMessage());
+  }
+
+  @ParameterizedTest
+  @MethodSource("nonRetryableExceptions")
+  public void testSeekWithFlakyStreamNonRetryableException(IOException exception) {
+    assertThatThrownBy(() -> testSeek(flakyStreamClient(new AtomicInteger(3), exception)))
+        .isInstanceOf(exception.getClass())
+        .hasMessage(exception.getMessage());
+  }
+
+  private static Stream<Arguments> retryableExceptions() {
+    // one Arguments per exception so each becomes its own test invocation
+    return Stream.of(
+        Arguments.of(new SocketTimeoutException("socket timeout exception")),
+        Arguments.of(new SSLException("some ssl exception")));
+  }
+
+  private static Stream<Arguments> nonRetryableExceptions() {
+    return Stream.of(Arguments.of(new IOException("some generic non-retryable IO exception")));
+  }
+
+  private S3ClientWrapper flakyStreamClient(AtomicInteger counter, IOException failure) {
+    S3ClientWrapper flakyClient = spy(new S3ClientWrapper(s3Client()));
+    doAnswer(invocation -> new FlakyInputStream(invocation.callRealMethod(), counter, failure))
+        .when(flakyClient)
+        .getObject(any(GetObjectRequest.class), any(ResponseTransformer.class));
+    return flakyClient;
+  }
+
+  /** Wrapper for S3 client, used to mock the final class DefaultS3Client */
+  public static class S3ClientWrapper implements S3Client {
+
+    private final S3Client delegate;
+
+    public S3ClientWrapper(S3Client delegate) {
+      this.delegate = delegate;
+    }
+
+    @Override
+    public String serviceName() {
+      return delegate.serviceName();
+    }
+
+    @Override
+    public void close() {
+      delegate.close();
+    }
+
+    @Override
+    public <ReturnT> ReturnT getObject(
+        GetObjectRequest getObjectRequest,
+        ResponseTransformer<GetObjectResponse, ReturnT> responseTransformer)
+        throws AwsServiceException, SdkClientException {
+      return delegate.getObject(getObjectRequest, responseTransformer);
+    }
+
+    @Override
+    public HeadObjectResponse headObject(HeadObjectRequest headObjectRequest)
+        throws AwsServiceException, SdkClientException {
+      return delegate.headObject(headObjectRequest);
+    }
+
+    @Override
+    public PutObjectResponse putObject(PutObjectRequest putObjectRequest, RequestBody requestBody)
+        throws AwsServiceException, SdkClientException {
+      return delegate.putObject(putObjectRequest, requestBody);
+    }
+
+    @Override
+    public CreateBucketResponse createBucket(CreateBucketRequest createBucketRequest)
+        throws AwsServiceException, SdkClientException {
+      return delegate.createBucket(createBucketRequest);
+    }
+  }
+
+  static class FlakyInputStream extends InputStream {
+    private final ResponseInputStream<GetObjectResponse> delegate;
+    private final AtomicInteger counter;
+    private final int round;
+    private final IOException exception;
+
+    FlakyInputStream(Object invocationResponse, AtomicInteger counter, IOException exception) {
+      this.delegate = (ResponseInputStream<GetObjectResponse>) invocationResponse;
+      this.counter = counter;
+      this.round = counter.get();
+      this.exception = exception;
+    }
+
+    private void checkCounter() throws IOException {
+      // for every round of n invocations, only the last call succeeds
+      if (counter.decrementAndGet() == 0) {
+        counter.set(round);
+      } else {
+        throw exception;
+      }
+    }
+
+    @Override
+    public int read() throws IOException {
+      checkCounter();
+      return delegate.read();
+    }
+
+    @Override
+    public int read(byte[] b) throws IOException {
+      checkCounter();
+      return delegate.read(b);
+    }
+
+    @Override
+    public int read(byte[] b, int off, int len) throws IOException {
+      checkCounter();
+      return delegate.read(b, off, len);
+    }
+
+    @Override
+    public void close() throws IOException {
+      delegate.close();
+    }
+  }
+}
diff --git a/aws/src/test/java/org/apache/iceberg/aws/s3/TestS3InputStream.java b/aws/src/test/java/org/apache/iceberg/aws/s3/TestS3InputStream.java
index ed71e259a26c..0e3f8b2136a6 100644
--- a/aws/src/test/java/org/apache/iceberg/aws/s3/TestS3InputStream.java
+++ b/aws/src/test/java/org/apache/iceberg/aws/s3/TestS3InputStream.java
@@ -54,16 +54,18 @@ public void before() {
 
   @Test
   public void testRead() throws Exception {
+    testRead(s3);
+  }
+
+  protected void testRead(S3Client s3Client) throws Exception {
     S3URI uri = new S3URI("s3://bucket/path/to/read.dat");
     int dataSize = 1024 * 1024 * 10;
     byte[] data = randomData(dataSize);
 
     writeS3Data(uri, data);
 
-    try (SeekableInputStream in = new S3InputStream(s3, uri)) {
+    try (SeekableInputStream in = new S3InputStream(s3Client, uri)) {
       int readSize = 1024;
-      byte[] actual = new byte[readSize];
-
       readAndCheck(in, in.getPos(), readSize, data, false);
       readAndCheck(in, in.getPos(), readSize, data, true);
 
@@ -111,6 +113,10
@@ private void readAndCheck( @Test public void testRangeRead() throws Exception { + testRangeRead(s3); + } + + protected void testRangeRead(S3Client s3Client) throws Exception { S3URI uri = new S3URI("s3://bucket/path/to/range-read.dat"); int dataSize = 1024 * 1024 * 10; byte[] expected = randomData(dataSize); @@ -122,7 +128,7 @@ public void testRangeRead() throws Exception { writeS3Data(uri, expected); - try (RangeReadable in = new S3InputStream(s3, uri)) { + try (RangeReadable in = new S3InputStream(s3Client, uri)) { // first 1k position = 0; offset = 0; @@ -163,12 +169,16 @@ public void testClose() throws Exception { @Test public void testSeek() throws Exception { + testSeek(s3); + } + + protected void testSeek(S3Client s3Client) throws Exception { S3URI uri = new S3URI("s3://bucket/path/to/seek.dat"); byte[] expected = randomData(1024 * 1024); writeS3Data(uri, expected); - try (SeekableInputStream in = new S3InputStream(s3, uri)) { + try (SeekableInputStream in = new S3InputStream(s3Client, uri)) { in.seek(expected.length / 2); byte[] actual = new byte[expected.length / 2]; IOUtil.readFully(in, actual, 0, expected.length / 2); @@ -200,4 +210,8 @@ private void createBucket(String bucketName) { // don't do anything } } + + protected S3Client s3Client() { + return s3; + } } diff --git a/build.gradle b/build.gradle index 02758e2a793b..620641a21f92 100644 --- a/build.gradle +++ b/build.gradle @@ -347,6 +347,7 @@ project(':iceberg-core') { implementation libs.jackson.core implementation libs.jackson.databind implementation libs.caffeine + implementation libs.failsafe implementation libs.roaringbitmap compileOnly(libs.hadoop2.client) { exclude group: 'org.apache.avro', module: 'avro' @@ -462,6 +463,7 @@ project(':iceberg-aws') { annotationProcessor libs.immutables.value compileOnly libs.immutables.value implementation libs.caffeine + implementation libs.failsafe implementation platform(libs.jackson.bom) implementation libs.jackson.core implementation 
libs.jackson.databind diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml index 34471c2c4b75..fad4d49a1e62 100644 --- a/gradle/libs.versions.toml +++ b/gradle/libs.versions.toml @@ -39,6 +39,7 @@ delta-standalone = "3.2.0" delta-spark = "3.2.0" esotericsoftware-kryo = "4.0.3" errorprone-annotations = "2.31.0" +failsafe = "3.3.2" findbugs-jsr305 = "3.0.2" flink118 = { strictly = "1.18.1"} flink119 = { strictly = "1.19.0"} @@ -109,6 +110,7 @@ calcite-druid = { module = "org.apache.calcite:calcite-druid", version.ref = "ca datasketches = { module = "org.apache.datasketches:datasketches-java", version.ref = "datasketches" } delta-standalone = { module = "io.delta:delta-standalone_2.12", version.ref = "delta-standalone" } errorprone-annotations = { module = "com.google.errorprone:error_prone_annotations", version.ref = "errorprone-annotations" } +failsafe = { module = "dev.failsafe:failsafe", version.ref = "failsafe"} findbugs-jsr305 = { module = "com.google.code.findbugs:jsr305", version.ref = "findbugs-jsr305" } flink118-avro = { module = "org.apache.flink:flink-avro", version.ref = "flink118" } flink118-connector-base = { module = "org.apache.flink:flink-connector-base", version.ref = "flink118" }