diff --git a/modules/transport-netty4/src/internalClusterTest/java/org/elasticsearch/http/netty4/Netty4IncrementalRequestHandlingIT.java b/modules/transport-netty4/src/internalClusterTest/java/org/elasticsearch/http/netty4/Netty4IncrementalRequestHandlingIT.java index 4de7ca97ed51b..b3139fd336a70 100644 --- a/modules/transport-netty4/src/internalClusterTest/java/org/elasticsearch/http/netty4/Netty4IncrementalRequestHandlingIT.java +++ b/modules/transport-netty4/src/internalClusterTest/java/org/elasticsearch/http/netty4/Netty4IncrementalRequestHandlingIT.java @@ -10,6 +10,7 @@ import io.netty.bootstrap.Bootstrap; import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufInputStream; import io.netty.buffer.Unpooled; import io.netty.channel.Channel; import io.netty.channel.ChannelHandlerContext; @@ -24,11 +25,16 @@ import io.netty.handler.codec.http.DefaultLastHttpContent; import io.netty.handler.codec.http.FullHttpRequest; import io.netty.handler.codec.http.FullHttpResponse; +import io.netty.handler.codec.http.HttpChunkedInput; import io.netty.handler.codec.http.HttpClientCodec; import io.netty.handler.codec.http.HttpContent; import io.netty.handler.codec.http.HttpObjectAggregator; import io.netty.handler.codec.http.HttpRequest; +import io.netty.handler.codec.http.HttpResponseStatus; +import io.netty.handler.codec.http.HttpUtil; import io.netty.handler.codec.http.LastHttpContent; +import io.netty.handler.stream.ChunkedStream; +import io.netty.handler.stream.ChunkedWriteHandler; import org.elasticsearch.ESNetty4IntegTestCase; import org.elasticsearch.action.support.SubscribableListener; @@ -41,9 +47,13 @@ import org.elasticsearch.common.settings.IndexScopedSettings; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.settings.SettingsFilter; +import org.elasticsearch.common.unit.ByteSizeUnit; +import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.common.util.CollectionUtils; import org.elasticsearch.features.NodeFeature; +import org.elasticsearch.http.HttpHandlingSettings; import org.elasticsearch.http.HttpServerTransport; +import org.elasticsearch.http.HttpTransportSettings; import org.elasticsearch.plugins.ActionPlugin; import org.elasticsearch.plugins.Plugin; import org.elasticsearch.rest.BaseRestHandler; @@ -61,9 +71,7 @@ import java.util.List; import java.util.concurrent.BlockingDeque; import java.util.concurrent.BlockingQueue; -import java.util.concurrent.Callable; import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.Future; import java.util.concurrent.LinkedBlockingDeque; import java.util.concurrent.TimeUnit; import java.util.function.Predicate; @@ -78,6 +86,13 @@ @ESIntegTestCase.ClusterScope(numDataNodes = 1) public class Netty4IncrementalRequestHandlingIT extends ESNetty4IntegTestCase { + @Override + protected Settings nodeSettings(int nodeOrdinal, Settings otherSettings) { + Settings.Builder builder = Settings.builder().put(super.nodeSettings(nodeOrdinal, otherSettings)); + builder.put(HttpTransportSettings.SETTING_HTTP_MAX_CONTENT_LENGTH.getKey(), new ByteSizeValue(50, ByteSizeUnit.MB)); + return builder.build(); + } + // ensure empty http content has single 0 size chunk public void testEmptyContent() throws Exception { try (var ctx = setupClientCtx()) { @@ -111,7 +126,7 @@ public void testReceiveAllChunks() throws Exception { var opaqueId = opaqueId(reqNo); // this dataset will be compared with one on server side - var dataSize = randomIntBetween(1024, 10 * 1024 * 1024); + var dataSize = randomIntBetween(1024, maxContentLength()); var sendData = Unpooled.wrappedBuffer(randomByteArrayOfLength(dataSize)); sendData.retain(); ctx.clientChannel.writeAndFlush(fullHttpRequest(opaqueId, sendData)); @@ -212,12 +227,98 @@ public void testClientBackpressure() throws Exception { bufSize >= minBufSize && bufSize <= maxBufSize ); }); - handler.consumeBytes(MBytes(10)); + handler.readBytes(MBytes(10)); } assertTrue(handler.stream.hasLast()); } } + // ensures that server reply 100-continue on acceptable request size + public void test100Continue() throws Exception { + try (var ctx = setupClientCtx()) { + for (int reqNo = 0; reqNo < randomIntBetween(2, 10); reqNo++) { + var id = opaqueId(reqNo); + var acceptableContentLength = randomIntBetween(0, maxContentLength()); + + // send request header and await 100-continue + var req = httpRequest(id, acceptableContentLength); + HttpUtil.set100ContinueExpected(req, true); + ctx.clientChannel.writeAndFlush(req); + var resp = (FullHttpResponse) safePoll(ctx.clientRespQueue); + assertEquals(HttpResponseStatus.CONTINUE, resp.status()); + resp.release(); + + // send content + var content = randomContent(acceptableContentLength, true); + ctx.clientChannel.writeAndFlush(content); + + // consume content and reply 200 + var handler = ctx.awaitRestChannelAccepted(id); + var consumed = handler.readAllBytes(); + assertEquals(acceptableContentLength, consumed); + handler.sendResponse(new RestResponse(RestStatus.OK, "")); + + resp = (FullHttpResponse) safePoll(ctx.clientRespQueue); + assertEquals(HttpResponseStatus.OK, resp.status()); + resp.release(); + } + } + } + + // ensures that server reply 413-too-large on oversized request with expect-100-continue + public void test413TooLargeOnExpect100Continue() throws Exception { + try (var ctx = setupClientCtx()) { + for (int reqNo = 0; reqNo < randomIntBetween(2, 10); reqNo++) { + var id = opaqueId(reqNo); + var oversized = maxContentLength() + 1; + + // send request header and await 413 too large + var req = httpRequest(id, oversized); + HttpUtil.set100ContinueExpected(req, true); + ctx.clientChannel.writeAndFlush(req); + var resp = (FullHttpResponse) safePoll(ctx.clientRespQueue); + assertEquals(HttpResponseStatus.REQUEST_ENTITY_TOO_LARGE, resp.status()); + resp.release(); + + // terminate request + ctx.clientChannel.writeAndFlush(LastHttpContent.EMPTY_LAST_CONTENT); + } + } + } + + // ensures that oversized chunked encoded request has no limits at http layer + // rest handler is responsible for oversized requests + public void testOversizedChunkedEncodingNoLimits() throws Exception { + try (var ctx = setupClientCtx()) { + for (var reqNo = 0; reqNo < randomIntBetween(2, 10); reqNo++) { + var id = opaqueId(reqNo); + var contentSize = maxContentLength() + 1; + var content = randomByteArrayOfLength(contentSize); + var is = new ByteBufInputStream(Unpooled.wrappedBuffer(content)); + var chunkedIs = new ChunkedStream(is); + var httpChunkedIs = new HttpChunkedInput(chunkedIs, LastHttpContent.EMPTY_LAST_CONTENT); + var req = httpRequest(id, 0); + HttpUtil.setTransferEncodingChunked(req, true); + + ctx.clientChannel.pipeline().addLast(new ChunkedWriteHandler()); + ctx.clientChannel.writeAndFlush(req); + ctx.clientChannel.writeAndFlush(httpChunkedIs); + var handler = ctx.awaitRestChannelAccepted(id); + var consumed = handler.readAllBytes(); + assertEquals(contentSize, consumed); + handler.sendResponse(new RestResponse(RestStatus.OK, "")); + + var resp = (FullHttpResponse) safePoll(ctx.clientRespQueue); + assertEquals(HttpResponseStatus.OK, resp.status()); + resp.release(); + } + } + } + + private int maxContentLength() { + return HttpHandlingSettings.fromSettings(internalCluster().getInstance(Settings.class)).maxContentLength(); + } + private String opaqueId(int reqNo) { return getTestName() + "-" + reqNo; } @@ -368,24 +469,25 @@ void sendResponse(RestResponse response) { channel.sendResponse(response); } - void consumeBytes(int bytes) { - if (recvLast) { - return; - } - while (bytes > 0) { - stream.next(); - var recvChunk = safePoll(recvChunks); - bytes -= recvChunk.chunk.length(); - recvChunk.chunk.close(); - if (recvChunk.isLast) { - recvLast = true; - break; + int readBytes(int bytes) { + var consumed = 0; + if (recvLast == false) { + while (consumed < bytes) { + stream.next(); + var recvChunk = safePoll(recvChunks); + consumed += recvChunk.chunk.length(); + recvChunk.chunk.close(); + if (recvChunk.isLast) { + recvLast = true; + break; + } } } + return consumed; } - Future onChannelThread(Callable task) { - return this.stream.channel().eventLoop().submit(task); + int readAllBytes() { + return readBytes(Integer.MAX_VALUE); } record Chunk(ReleasableBytesReference chunk, boolean isLast) {} diff --git a/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpAggregator.java b/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpAggregator.java index 16f1c2bbd2e37..031e803737ee8 100644 --- a/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpAggregator.java +++ b/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpAggregator.java @@ -8,19 +8,32 @@ package org.elasticsearch.http.netty4; +import io.netty.channel.ChannelHandlerContext; +import io.netty.handler.codec.http.FullHttpRequest; +import io.netty.handler.codec.http.FullHttpResponse; +import io.netty.handler.codec.http.HttpContent; +import io.netty.handler.codec.http.HttpObject; import io.netty.handler.codec.http.HttpObjectAggregator; import io.netty.handler.codec.http.HttpRequest; +import io.netty.handler.codec.http.HttpResponseStatus; +import io.netty.handler.codec.http.HttpUtil; import org.elasticsearch.http.HttpPreRequest; import org.elasticsearch.http.netty4.internal.HttpHeadersAuthenticatorUtils; import java.util.function.Predicate; +/** + * A wrapper around {@link HttpObjectAggregator}. Provides optional content aggregation based on + * predicate. {@link HttpObjectAggregator} also handles Expect: 100-continue and oversized content. + * Unfortunately, Netty does not provide handlers for oversized messages beyond HttpObjectAggregator. + */ public class Netty4HttpAggregator extends HttpObjectAggregator { private static final Predicate IGNORE_TEST = (req) -> req.uri().startsWith("/_test/request-stream") == false; private final Predicate decider; - private boolean shouldAggregate; + private boolean aggregating = true; + private boolean ignoreContentAfterContinueResponse = false; public Netty4HttpAggregator(int maxContentLength) { this(maxContentLength, IGNORE_TEST); @@ -32,15 +45,43 @@ public Netty4HttpAggregator(int maxContentLength, Predicate deci } @Override - public boolean acceptInboundMessage(Object msg) throws Exception { + public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + assert msg instanceof HttpObject; if (msg instanceof HttpRequest request) { var preReq = HttpHeadersAuthenticatorUtils.asHttpPreRequest(request); - shouldAggregate = decider.test(preReq); + aggregating = decider.test(preReq); + } + if (aggregating || msg instanceof FullHttpRequest) { + super.channelRead(ctx, msg); + } else { + handle(ctx, (HttpObject) msg); } - if (shouldAggregate) { - return super.acceptInboundMessage(msg); + } + + private void handle(ChannelHandlerContext ctx, HttpObject msg) { + if (msg instanceof HttpRequest request) { + var continueResponse = newContinueResponse(request, maxContentLength(), ctx.pipeline()); + if (continueResponse != null) { + // there are 3 responses expected: 100, 413, 417 + // on 100 we pass request further and reply to client to continue + // on 413/417 we ignore following content + ctx.writeAndFlush(continueResponse); + var resp = (FullHttpResponse) continueResponse; + if (resp.status() != HttpResponseStatus.CONTINUE) { + ignoreContentAfterContinueResponse = true; + return; + } + HttpUtil.set100ContinueExpected(request, false); + } + ignoreContentAfterContinueResponse = false; + ctx.fireChannelRead(msg); } else { - return false; + var httpContent = (HttpContent) msg; + if (ignoreContentAfterContinueResponse) { + httpContent.release(); + } else { + ctx.fireChannelRead(msg); + } } } }