diff --git a/common-http/src/main/java/io/cdap/common/http/HttpContentConsumer.java b/common-http/src/main/java/io/cdap/common/http/HttpContentConsumer.java new file mode 100644 index 0000000..6c503c8 --- /dev/null +++ b/common-http/src/main/java/io/cdap/common/http/HttpContentConsumer.java @@ -0,0 +1,53 @@ +/* + * Copyright © 2021 Cask Data, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package io.cdap.common.http; + +import java.nio.ByteBuffer; + +/** + * Consumer for {@link HttpResponse} body. + */ +public abstract class HttpContentConsumer { + // Default 64K chunk size + private static final int DEFAULT_CHUNK_SIZE = 65536; + private int chunkSize; + + public HttpContentConsumer() { + this.chunkSize = DEFAULT_CHUNK_SIZE; + } + + public HttpContentConsumer(int chunkSize) { + this.chunkSize = chunkSize; + } + + /** + * This method is invoked when a new chunk of the response body is available to be consumed. + * + * @param chunk a {@link ByteBuffer} containing a chunk of the response body + * @return true to continue reading from the response stream, false to stop reading and close the connection. + */ + public abstract boolean onReceived(ByteBuffer chunk); + + /** + * This method is invoked when the end of the response body is reached. + */ + public abstract void onFinished(); + + int getChunkSize() { + return chunkSize; + } +} diff --git a/common-http/src/main/java/io/cdap/common/http/HttpRequest.java b/common-http/src/main/java/io/cdap/common/http/HttpRequest.java index 237c57b..96a87f6 100644 --- a/common-http/src/main/java/io/cdap/common/http/HttpRequest.java +++ b/common-http/src/main/java/io/cdap/common/http/HttpRequest.java @@ -45,15 +45,23 @@ public class HttpRequest { private final Multimap headers; private final ContentProvider body; private final Long bodyLength; + private HttpContentConsumer consumer; public HttpRequest(HttpMethod method, URL url, @Nullable Multimap headers, @Nullable ContentProvider body, @Nullable Long bodyLength) { + this(method, url, headers, body, bodyLength, null); + } + + public HttpRequest(HttpMethod method, URL url, @Nullable Multimap headers, + @Nullable ContentProvider body, + @Nullable Long bodyLength, @Nullable HttpContentConsumer consumer) { this.method = method; this.url = url; this.headers = headers; this.body = body; this.bodyLength = bodyLength; + this.consumer = consumer; } public static Builder get(URL url) { @@ -103,6 +111,15 @@ public Long getBodyLength() { return bodyLength; } + @Nullable + public HttpContentConsumer getConsumer() { + return consumer; + } + + public boolean hasContentConsumer() { + return consumer != null; + } + /** * Builder for {@link HttpRequest}. */ @@ -112,6 +129,7 @@ public static final class Builder { private final Multimap headers; private ContentProvider body; private Long bodyLength; + private HttpContentConsumer consumer; Builder(HttpMethod method, URL url) { this.method = method; @@ -203,8 +221,13 @@ public InputStream getInput() { return this; } + public Builder withContentConsumer(HttpContentConsumer consumer) { + this.consumer = consumer; + return this; + } + public HttpRequest build() { - return new HttpRequest(method, url, headers, body, bodyLength); + return new HttpRequest(method, url, headers, body, bodyLength, consumer); } } } diff --git a/common-http/src/main/java/io/cdap/common/http/HttpRequests.java b/common-http/src/main/java/io/cdap/common/http/HttpRequests.java index 66951cb..b47161f 100644 --- a/common-http/src/main/java/io/cdap/common/http/HttpRequests.java +++ b/common-http/src/main/java/io/cdap/common/http/HttpRequests.java @@ -22,7 +22,6 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import java.io.FileNotFoundException; import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; @@ -62,6 +61,20 @@ private HttpRequests() { } * @return HTTP response */ public static HttpResponse execute(HttpRequest request, HttpRequestConfig requestConfig) throws IOException { + HttpURLConnection conn = getConnection(request, requestConfig); + conn.connect(); + + ContentProvider bodySrc = request.getBody(); + if (bodySrc != null) { + try (InputStream input = bodySrc.getInput(); OutputStream os = conn.getOutputStream()) { + ByteStreams.copy(input, os); + } + } + return request.hasContentConsumer() ? new HttpResponse(conn, request.getConsumer()) : new HttpResponse(conn); + } + + private static HttpURLConnection getConnection(HttpRequest request, HttpRequestConfig requestConfig) + throws IOException { String requestMethod = request.getMethod().name(); URL url = request.getURL(); @@ -77,8 +90,7 @@ public static HttpResponse execute(HttpRequest request, HttpRequestConfig reques } } - ContentProvider bodySrc = request.getBody(); - if (bodySrc != null) { + if (request.getBody() != null) { conn.setDoOutput(true); Long bodyLength = request.getBodyLength(); if (bodyLength != null) { @@ -100,34 +112,7 @@ public static HttpResponse execute(HttpRequest request, HttpRequestConfig reques LOG.error("Got exception while disabling SSL certificate check for {}", request.getURL()); } } - - conn.connect(); - - try { - if (bodySrc != null) { - try (InputStream input = bodySrc.getInput(); OutputStream os = conn.getOutputStream()) { - ByteStreams.copy(input, os); - } - } - - try { - if (isSuccessful(conn.getResponseCode())) { - try (InputStream inputStream = conn.getInputStream()) { - return new HttpResponse(conn.getResponseCode(), conn.getResponseMessage(), - ByteStreams.toByteArray(inputStream), conn.getHeaderFields()); - } - } - } catch (FileNotFoundException e) { - // Server returns 404. Hence handle as error flow below. Intentional having empty catch block. - } - - // Non 2xx response - InputStream es = conn.getErrorStream(); - byte[] content = (es == null) ? new byte[0] : ByteStreams.toByteArray(es); - return new HttpResponse(conn.getResponseCode(), conn.getResponseMessage(), content, conn.getHeaderFields()); - } finally { - conn.disconnect(); - } + return conn; } /** diff --git a/common-http/src/main/java/io/cdap/common/http/HttpResponse.java b/common-http/src/main/java/io/cdap/common/http/HttpResponse.java index 065fdf8..6438c45 100644 --- a/common-http/src/main/java/io/cdap/common/http/HttpResponse.java +++ b/common-http/src/main/java/io/cdap/common/http/HttpResponse.java @@ -16,21 +16,37 @@ package io.cdap.common.http; import com.google.common.base.Charsets; +import com.google.common.base.Throwables; import com.google.common.collect.ImmutableListMultimap; import com.google.common.collect.Multimap; +import com.google.common.io.ByteStreams; +import com.google.common.io.Closeables; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import java.io.IOException; +import java.io.InputStream; +import java.net.HttpURLConnection; +import java.nio.ByteBuffer; +import java.nio.channels.Channels; +import java.nio.channels.ReadableByteChannel; import java.nio.charset.Charset; import java.util.List; import java.util.Map; +import javax.annotation.Nullable; /** * Return type for http requests executed by {@link HttpResponse} */ public class HttpResponse { + private static final Logger LOG = LoggerFactory.getLogger(HttpResponse.class); private final int responseCode; private final String responseMessage; - private final byte[] responseBody; + private byte[] responseBody; private final Multimap headers; + private InputStream inputStream; + private HttpURLConnection conn; + private HttpContentConsumer consumer; HttpResponse(int responseCode, String responseMessage, byte[] responseBody, Map> headers) { @@ -45,6 +61,20 @@ public class HttpResponse { this.headers = headers; } + HttpResponse(HttpURLConnection conn) throws IOException { + this(conn, null); + this.responseBody = getResponseBodyFromStream(); + } + + HttpResponse(HttpURLConnection conn, @Nullable HttpContentConsumer consumer) throws IOException { + this.conn = conn; + this.responseCode = conn.getResponseCode(); + this.responseMessage = conn.getResponseMessage(); + this.headers = parseHeaders(conn.getHeaderFields()); + this.inputStream = isSuccessful(responseCode) ? conn.getInputStream() : conn.getErrorStream(); + this.consumer = consumer; + } + public int getResponseCode() { return responseCode; } @@ -69,6 +99,30 @@ public Multimap getHeaders() { return headers; } + public void consumeContent() throws IOException { + if (inputStream == null) { + conn.disconnect(); + consumer.onFinished(); + return; + } + + try (ReadableByteChannel channel = Channels.newChannel(inputStream)) { + ByteBuffer buffer = ByteBuffer.allocate(consumer.getChunkSize()); + while (channel.read(buffer) >= 0) { + // Flip the buffer for the consumer to read + buffer.flip(); + boolean continueReading = consumer.onReceived(buffer); + buffer.clear(); + if (!continueReading) { + break; + } + } + } finally { + conn.disconnect(); + consumer.onFinished(); + } + } + private static Multimap parseHeaders(Map> headers) { ImmutableListMultimap.Builder builder = new ImmutableListMultimap.Builder(); for (Map.Entry> entry : headers.entrySet()) { @@ -80,6 +134,25 @@ private static Multimap parseHeaders(Map> h return builder.build(); } + private byte[] getResponseBodyFromStream() { + try { + if (inputStream == null) { + return new byte[0]; + } + return ByteStreams.toByteArray(inputStream); + } catch (IOException e) { + throw Throwables.propagate(e); + } finally { + Closeables.closeQuietly(inputStream); + inputStream = null; + conn.disconnect(); + } + } + + private boolean isSuccessful(int responseCode) { + return 200 <= responseCode && responseCode < 300; + } + @Override public String toString() { return String.format("Response code: %s, message: '%s', body: '%s'", diff --git a/common-http/src/test/java/io/cdap/common/http/HttpRequestsStreamTest.java b/common-http/src/test/java/io/cdap/common/http/HttpRequestsStreamTest.java new file mode 100644 index 0000000..561dda3 --- /dev/null +++ b/common-http/src/test/java/io/cdap/common/http/HttpRequestsStreamTest.java @@ -0,0 +1,61 @@ +/* + * Copyright © 2021 Cask Data, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package io.cdap.common.http; + +import org.junit.After; +import org.junit.Before; + +import java.net.InetSocketAddress; +import java.net.URI; +import java.net.URISyntaxException; + +public class HttpRequestsStreamTest extends HttpRequestsTestBase { + + private static TestHttpService httpService; + + @Before + public void setUp() throws Exception { + httpService = new TestHttpService(false); + httpService.startAndWait(); + } + + @After + public void tearDown() { + httpService.stopAndWait(); + } + + @Override + protected URI getBaseURI() throws URISyntaxException { + InetSocketAddress bindAddress = httpService.getBindAddress(); + return new URI("http://" + bindAddress.getHostName() + ":" + bindAddress.getPort()); + } + + @Override + protected HttpRequestConfig getHttpRequestsConfig() { + return new HttpRequestConfig(0, 0, false); + } + + @Override + protected int getNumConnectionsOpened() { + return httpService.getNumConnectionsOpened(); + } + + @Override + protected boolean returnResponseStream() { + return true; + } +} diff --git a/common-http/src/test/java/io/cdap/common/http/HttpRequestsTest.java b/common-http/src/test/java/io/cdap/common/http/HttpRequestsTest.java index 7ed8404..7dc3e58 100644 --- a/common-http/src/test/java/io/cdap/common/http/HttpRequestsTest.java +++ b/common-http/src/test/java/io/cdap/common/http/HttpRequestsTest.java @@ -56,4 +56,9 @@ protected HttpRequestConfig getHttpRequestsConfig() { protected int getNumConnectionsOpened() { return httpService.getNumConnectionsOpened(); } + + @Override + protected boolean returnResponseStream() { + return false; + } } diff --git a/common-http/src/test/java/io/cdap/common/http/HttpRequestsTestBase.java b/common-http/src/test/java/io/cdap/common/http/HttpRequestsTestBase.java index f449464..ddca028 100644 --- a/common-http/src/test/java/io/cdap/common/http/HttpRequestsTestBase.java +++ b/common-http/src/test/java/io/cdap/common/http/HttpRequestsTestBase.java @@ -30,10 +30,15 @@ import org.junit.Assert; import org.junit.Test; +import java.io.IOException; import java.net.URI; import java.net.URISyntaxException; import java.net.URL; +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; import javax.ws.rs.DELETE; import javax.ws.rs.GET; import javax.ws.rs.POST; @@ -42,6 +47,7 @@ import static com.google.inject.matcher.Matchers.any; import static com.google.inject.matcher.Matchers.only; +import static org.junit.Assert.fail; /** * Test base for {@link HttpRequests}. @@ -54,6 +60,8 @@ public abstract class HttpRequestsTestBase { protected abstract int getNumConnectionsOpened(); + protected abstract boolean returnResponseStream(); + @Test public void testHttpStatus() throws Exception { testGet("/fake/fake", only(404), only("Not Found"), @@ -110,14 +118,42 @@ public void testHttpStatus() throws Exception { testDelete("/api/testDelete", only(200), any(), any(), any()); } + private HttpRequest getRequest(HttpRequest.Builder builder, CompletableFuture future) { + if (returnResponseStream()) { + // Set a small chunk size to verify that we're doing chunked reads correctly. + int chunkSize = 5; + StringBuilder sb = new StringBuilder(); + return builder.withContentConsumer(new HttpContentConsumer(chunkSize) { + @Override + public boolean onReceived(ByteBuffer buffer) { + // create byte array with length = number of bytes written to the buffer + byte[] bytes = new byte[buffer.remaining()]; + // read the bytes that were written to the buffer + buffer.get(bytes); + sb.append(new String(bytes, StandardCharsets.UTF_8)); + return true; + } + + @Override + public void onFinished() { + future.complete(sb.toString()); + } + }).build(); + } else { + return builder.build(); + } + } + private void testPost(String path, Map headers, String body, Matcher expectedResponseCode, Matcher expectedMessage, Matcher expectedBody, Matcher expectedHeaders) throws Exception { URL url = getBaseURI().resolve(path).toURL(); - HttpRequest request = HttpRequest.post(url).addHeaders(headers).withBody(body).build(); - HttpResponse response = HttpRequests.execute(request, getHttpRequestsConfig()); - verifyResponse(response, expectedResponseCode, expectedMessage, expectedBody, expectedHeaders); + CompletableFuture future = new CompletableFuture<>(); + HttpRequest request = getRequest(HttpRequest.post(url).addHeaders(headers).withBody(body), future); + HttpRequestConfig requestConfig = getHttpRequestsConfig(); + HttpResponse response = HttpRequests.execute(request, requestConfig); + verifyResponse(response, expectedResponseCode, expectedMessage, expectedBody, expectedHeaders, future); } private void testPost(String path, Matcher expectedResponseCode, Matcher expectedMessage, @@ -132,32 +168,38 @@ private void testPut(String path, Map headers, String body, Matcher expectedBody, Matcher expectedHeaders) throws Exception { URL url = getBaseURI().resolve(path).toURL(); - HttpRequest request = HttpRequest.put(url).addHeaders(headers).withBody(body).build(); - HttpResponse response = HttpRequests.execute(request, getHttpRequestsConfig()); - verifyResponse(response, expectedResponseCode, expectedMessage, expectedBody, expectedHeaders); + CompletableFuture future = new CompletableFuture<>(); + HttpRequest request = getRequest(HttpRequest.put(url).addHeaders(headers).withBody(body), future); + HttpRequestConfig requestConfig = getHttpRequestsConfig(); + HttpResponse response = HttpRequests.execute(request, requestConfig); + verifyResponse(response, expectedResponseCode, expectedMessage, expectedBody, expectedHeaders, future); } private void testGet(String path, Matcher expectedResponseCode, Matcher expectedMessage, Matcher expectedBody, Matcher expectedHeaders) throws Exception { URL url = getBaseURI().resolve(path).toURL(); - HttpRequest request = HttpRequest.get(url).build(); - HttpResponse response = HttpRequests.execute(request, getHttpRequestsConfig()); - verifyResponse(response, expectedResponseCode, expectedMessage, expectedBody, expectedHeaders); + CompletableFuture future = new CompletableFuture<>(); + HttpRequest request = getRequest(HttpRequest.get(url), future); + HttpRequestConfig requestConfig = getHttpRequestsConfig(); + HttpResponse response = HttpRequests.execute(request, requestConfig); + verifyResponse(response, expectedResponseCode, expectedMessage, expectedBody, expectedHeaders, future); } private void testDelete(String path, Matcher expectedResponseCode, Matcher expectedMessage, Matcher expectedBody, Matcher expectedHeaders) throws Exception { URL url = getBaseURI().resolve(path).toURL(); - HttpRequest request = HttpRequest.delete(url).build(); - HttpResponse response = HttpRequests.execute(request, getHttpRequestsConfig()); - verifyResponse(response, expectedResponseCode, expectedMessage, expectedBody, expectedHeaders); + CompletableFuture future = new CompletableFuture<>(); + HttpRequest request = getRequest(HttpRequest.delete(url), future); + HttpRequestConfig requestConfig = getHttpRequestsConfig(); + HttpResponse response = HttpRequests.execute(request, requestConfig); + verifyResponse(response, expectedResponseCode, expectedMessage, expectedBody, expectedHeaders, future); } private void verifyResponse(HttpResponse response, Matcher expectedResponseCode, Matcher expectedMessage, Matcher expectedBody, - Matcher expectedHeaders) { + Matcher expectedHeaders, CompletableFuture future) { Assert.assertTrue("Response code - expected: " + expectedResponseCode.toString() + " actual: " + response.getResponseCode(), @@ -167,7 +209,17 @@ private void verifyResponse(HttpResponse response, Matcher expectedRespo + " actual: " + response.getResponseMessage(), expectedMessage.matches(response.getResponseMessage())); - String actualResponseBody = new String(response.getResponseBody()); + String actualResponseBody = ""; + if (!returnResponseStream()) { + actualResponseBody = response.getResponseBodyAsString(); + } else { + try { + response.consumeContent(); + actualResponseBody = future.get(); + } catch (IOException | InterruptedException | ExecutionException e) { + fail("Unexpected exception"); + } + } Assert.assertTrue("Response body - expected: " + expectedBody.toString() + " actual: " + actualResponseBody, expectedBody.matches(actualResponseBody)); diff --git a/common-http/src/test/java/io/cdap/common/http/HttpsRequestsTest.java b/common-http/src/test/java/io/cdap/common/http/HttpsRequestsTest.java index 3390ba6..d93ae43 100644 --- a/common-http/src/test/java/io/cdap/common/http/HttpsRequestsTest.java +++ b/common-http/src/test/java/io/cdap/common/http/HttpsRequestsTest.java @@ -56,4 +56,9 @@ protected HttpRequestConfig getHttpRequestsConfig() { protected int getNumConnectionsOpened() { return httpsService.getNumConnectionsOpened(); } + + @Override + protected boolean returnResponseStream() { + return false; + } }