Skip to content

Commit

Permalink
Ensure thread context set for streaming (elastic#115683)
Browse files Browse the repository at this point in the history
Currently the thread context is lost between streaming context switches.
This commit ensures that each time the thread context is properly set
before providing new data to the stream.
  • Loading branch information
Tim-Brooks authored Oct 29, 2024
1 parent 6182921 commit 23e1116
Show file tree
Hide file tree
Showing 11 changed files with 148 additions and 96 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,10 @@ public void channelRead(final ChannelHandlerContext ctx, final Object msg) {
netty4HttpRequest = new Netty4HttpRequest(readSequence++, fullHttpRequest);
currentRequestStream = null;
} else {
var contentStream = new Netty4HttpRequestBodyStream(ctx.channel());
var contentStream = new Netty4HttpRequestBodyStream(
ctx.channel(),
serverTransport.getThreadPool().getThreadContext()
);
currentRequestStream = contentStream;
netty4HttpRequest = new Netty4HttpRequest(readSequence++, request, contentStream);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import io.netty.handler.codec.http.HttpContent;
import io.netty.handler.codec.http.LastHttpContent;

import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.core.Releasables;
import org.elasticsearch.http.HttpBody;
import org.elasticsearch.transport.netty4.Netty4Utils;
Expand All @@ -34,14 +35,18 @@ public class Netty4HttpRequestBodyStream implements HttpBody.Stream {
private final Channel channel;
private final ChannelFutureListener closeListener = future -> doClose();
private final List<ChunkHandler> tracingHandlers = new ArrayList<>(4);
private final ThreadContext threadContext;
private ByteBuf buf;
private boolean hasLast = false;
private boolean requested = false;
private boolean closing = false;
private HttpBody.ChunkHandler handler;
private ThreadContext.StoredContext requestContext;

public Netty4HttpRequestBodyStream(Channel channel) {
public Netty4HttpRequestBodyStream(Channel channel, ThreadContext threadContext) {
this.channel = channel;
this.threadContext = threadContext;
this.requestContext = threadContext.newStoredContext();
Netty4Utils.addListener(channel.closeFuture(), closeListener);
channel.config().setAutoRead(false);
}
Expand All @@ -66,6 +71,7 @@ public void addTracingHandler(ChunkHandler chunkHandler) {
public void next() {
assert closing == false : "cannot request next chunk on closing stream";
assert handler != null : "handler must be set before requesting next chunk";
requestContext = threadContext.newStoredContext();
channel.eventLoop().submit(() -> {
requested = true;
if (buf == null) {
Expand Down Expand Up @@ -108,11 +114,6 @@ private void addChunk(ByteBuf chunk) {
}
}

// visible for test
Channel channel() {
return channel;
}

// visible for test
ByteBuf buf() {
return buf;
Expand All @@ -129,10 +130,12 @@ private void send() {
var bytesRef = Netty4Utils.toReleasableBytesReference(buf);
requested = false;
buf = null;
for (var tracer : tracingHandlers) {
tracer.onNext(bytesRef, hasLast);
try (var ignored = threadContext.restoreExistingContext(requestContext)) {
for (var tracer : tracingHandlers) {
tracer.onNext(bytesRef, hasLast);
}
handler.onNext(bytesRef, hasLast);
}
handler.onNext(bytesRef, hasLast);
if (hasLast) {
channel.config().setAutoRead(true);
channel.closeFuture().removeListener(closeListener);
Expand All @@ -150,11 +153,13 @@ public void close() {

private void doClose() {
closing = true;
for (var tracer : tracingHandlers) {
Releasables.closeExpectNoException(tracer);
}
if (handler != null) {
handler.close();
try (var ignored = threadContext.restoreExistingContext(requestContext)) {
for (var tracer : tracingHandlers) {
Releasables.closeExpectNoException(tracer);
}
if (handler != null) {
handler.close();
}
}
if (buf != null) {
buf.release();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,24 +19,33 @@
import io.netty.handler.flow.FlowControlHandler;

import org.elasticsearch.common.bytes.ReleasableBytesReference;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.http.HttpBody;
import org.elasticsearch.test.ESTestCase;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;

import static org.hamcrest.Matchers.hasEntry;

public class Netty4HttpRequestBodyStreamTests extends ESTestCase {

EmbeddedChannel channel;
Netty4HttpRequestBodyStream stream;
private final ThreadContext threadContext = new ThreadContext(Settings.EMPTY);
private EmbeddedChannel channel;
private Netty4HttpRequestBodyStream stream;
static HttpBody.ChunkHandler discardHandler = (chunk, isLast) -> chunk.close();

@Override
public void setUp() throws Exception {
super.setUp();
channel = new EmbeddedChannel();
stream = new Netty4HttpRequestBodyStream(channel);
threadContext.putHeader("header1", "value1");
stream = new Netty4HttpRequestBodyStream(channel, threadContext);
stream.setHandler(discardHandler); // set default handler, each test might override one
channel.pipeline().addLast(new SimpleChannelInboundHandler<HttpContent>(false) {
@Override
Expand Down Expand Up @@ -118,6 +127,60 @@ public void testReadFromChannel() {
assertTrue("should receive last content", gotLast.get());
}

public void testReadFromHasCorrectThreadContext() throws InterruptedException {
var gotLast = new AtomicBoolean(false);
AtomicReference<Map<String, String>> headers = new AtomicReference<>();
stream.setHandler(new HttpBody.ChunkHandler() {
@Override
public void onNext(ReleasableBytesReference chunk, boolean isLast) {
headers.set(threadContext.getHeaders());
gotLast.set(isLast);
chunk.close();
}

@Override
public void close() {
headers.set(threadContext.getHeaders());
}
});
channel.pipeline().addFirst(new FlowControlHandler()); // block all incoming messages, need explicit channel.read()
var chunkSize = 1024;

channel.writeInbound(randomContent(chunkSize));
channel.writeInbound(randomLastContent(chunkSize));

threadContext.putHeader("header2", "value2");
stream.next();

Thread thread = new Thread(() -> channel.runPendingTasks());
thread.start();
thread.join();

assertThat(headers.get(), hasEntry("header1", "value1"));
assertThat(headers.get(), hasEntry("header2", "value2"));

threadContext.putHeader("header3", "value3");
stream.next();

thread = new Thread(() -> channel.runPendingTasks());
thread.start();
thread.join();

assertThat(headers.get(), hasEntry("header1", "value1"));
assertThat(headers.get(), hasEntry("header2", "value2"));
assertThat(headers.get(), hasEntry("header3", "value3"));

assertTrue("should receive last content", gotLast.get());

headers.set(new HashMap<>());

stream.close();

assertThat(headers.get(), hasEntry("header1", "value1"));
assertThat(headers.get(), hasEntry("header2", "value2"));
assertThat(headers.get(), hasEntry("header3", "value3"));
}

HttpContent randomContent(int size, boolean isLast) {
var buf = Unpooled.wrappedBuffer(randomByteArrayOfLength(size));
if (isLast) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import org.elasticsearch.common.settings.ClusterSettings;
import org.elasticsearch.common.settings.Setting;
import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.Releasables;
Expand All @@ -43,12 +42,10 @@ public class IncrementalBulkService {
private final Client client;
private final AtomicBoolean enabledForTests = new AtomicBoolean(true);
private final IndexingPressure indexingPressure;
private final ThreadContext threadContext;

public IncrementalBulkService(Client client, IndexingPressure indexingPressure, ThreadContext threadContext) {
public IncrementalBulkService(Client client, IndexingPressure indexingPressure) {
this.client = client;
this.indexingPressure = indexingPressure;
this.threadContext = threadContext;
}

public Handler newBulkRequest() {
Expand All @@ -58,7 +55,7 @@ public Handler newBulkRequest() {

public Handler newBulkRequest(@Nullable String waitForActiveShards, @Nullable TimeValue timeout, @Nullable String refresh) {
ensureEnabled();
return new Handler(client, threadContext, indexingPressure, waitForActiveShards, timeout, refresh);
return new Handler(client, indexingPressure, waitForActiveShards, timeout, refresh);
}

private void ensureEnabled() {
Expand Down Expand Up @@ -94,7 +91,6 @@ public static class Handler implements Releasable {
public static final BulkRequest.IncrementalState EMPTY_STATE = new BulkRequest.IncrementalState(Collections.emptyMap(), true);

private final Client client;
private final ThreadContext threadContext;
private final IndexingPressure indexingPressure;
private final ActiveShardCount waitForActiveShards;
private final TimeValue timeout;
Expand All @@ -106,22 +102,18 @@ public static class Handler implements Releasable {
private boolean globalFailure = false;
private boolean incrementalRequestSubmitted = false;
private boolean bulkInProgress = false;
private ThreadContext.StoredContext requestContext;
private Exception bulkActionLevelFailure = null;
private long currentBulkSize = 0L;
private BulkRequest bulkRequest = null;

protected Handler(
Client client,
ThreadContext threadContext,
IndexingPressure indexingPressure,
@Nullable String waitForActiveShards,
@Nullable TimeValue timeout,
@Nullable String refresh
) {
this.client = client;
this.threadContext = threadContext;
this.requestContext = threadContext.newStoredContext();
this.indexingPressure = indexingPressure;
this.waitForActiveShards = waitForActiveShards != null ? ActiveShardCount.parseString(waitForActiveShards) : null;
this.timeout = timeout;
Expand All @@ -141,31 +133,28 @@ public void addItems(List<DocWriteRequest<?>> items, Releasable releasable, Runn
if (shouldBackOff()) {
final boolean isFirstRequest = incrementalRequestSubmitted == false;
incrementalRequestSubmitted = true;
try (var ignored = threadContext.restoreExistingContext(requestContext)) {
final ArrayList<Releasable> toRelease = new ArrayList<>(releasables);
releasables.clear();
bulkInProgress = true;
client.bulk(bulkRequest, ActionListener.runAfter(new ActionListener<>() {

@Override
public void onResponse(BulkResponse bulkResponse) {
handleBulkSuccess(bulkResponse);
createNewBulkRequest(
new BulkRequest.IncrementalState(bulkResponse.getIncrementalState().shardLevelFailures(), true)
);
}

@Override
public void onFailure(Exception e) {
handleBulkFailure(isFirstRequest, e);
}
}, () -> {
bulkInProgress = false;
requestContext = threadContext.newStoredContext();
toRelease.forEach(Releasable::close);
nextItems.run();
}));
}
final ArrayList<Releasable> toRelease = new ArrayList<>(releasables);
releasables.clear();
bulkInProgress = true;
client.bulk(bulkRequest, ActionListener.runAfter(new ActionListener<>() {

@Override
public void onResponse(BulkResponse bulkResponse) {
handleBulkSuccess(bulkResponse);
createNewBulkRequest(
new BulkRequest.IncrementalState(bulkResponse.getIncrementalState().shardLevelFailures(), true)
);
}

@Override
public void onFailure(Exception e) {
handleBulkFailure(isFirstRequest, e);
}
}, () -> {
bulkInProgress = false;
toRelease.forEach(Releasable::close);
nextItems.run();
}));
} else {
nextItems.run();
}
Expand All @@ -187,28 +176,26 @@ public void lastItems(List<DocWriteRequest<?>> items, Releasable releasable, Act
} else {
assert bulkRequest != null;
if (internalAddItems(items, releasable)) {
try (var ignored = threadContext.restoreExistingContext(requestContext)) {
final ArrayList<Releasable> toRelease = new ArrayList<>(releasables);
releasables.clear();
// We do not need to set this back to false as this will be the last request.
bulkInProgress = true;
client.bulk(bulkRequest, ActionListener.runBefore(new ActionListener<>() {

private final boolean isFirstRequest = incrementalRequestSubmitted == false;

@Override
public void onResponse(BulkResponse bulkResponse) {
handleBulkSuccess(bulkResponse);
listener.onResponse(combineResponses());
}
final ArrayList<Releasable> toRelease = new ArrayList<>(releasables);
releasables.clear();
// We do not need to set this back to false as this will be the last request.
bulkInProgress = true;
client.bulk(bulkRequest, ActionListener.runBefore(new ActionListener<>() {

private final boolean isFirstRequest = incrementalRequestSubmitted == false;

@Override
public void onResponse(BulkResponse bulkResponse) {
handleBulkSuccess(bulkResponse);
listener.onResponse(combineResponses());
}

@Override
public void onFailure(Exception e) {
handleBulkFailure(isFirstRequest, e);
errorResponse(listener);
}
}, () -> toRelease.forEach(Releasable::close)));
}
@Override
public void onFailure(Exception e) {
handleBulkFailure(isFirstRequest, e);
errorResponse(listener);
}
}, () -> toRelease.forEach(Releasable::close)));
} else {
errorResponse(listener);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -915,11 +915,7 @@ private void construct(
terminationHandler = getSinglePlugin(terminationHandlers, TerminationHandler.class).orElse(null);

final IndexingPressure indexingLimits = new IndexingPressure(settings);
final IncrementalBulkService incrementalBulkService = new IncrementalBulkService(
client,
indexingLimits,
threadPool.getThreadContext()
);
final IncrementalBulkService incrementalBulkService = new IncrementalBulkService(client, indexingLimits);

ActionModule actionModule = new ActionModule(
settings,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ public final void handleRequest(RestRequest request, RestChannel channel, NodeCl
if (request.isStreamedContent()) {
assert action instanceof RequestBodyChunkConsumer;
var chunkConsumer = (RequestBodyChunkConsumer) action;

request.contentStream().setHandler(new HttpBody.ChunkHandler() {
@Override
public void onNext(ReleasableBytesReference chunk, boolean isLast) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -173,8 +173,7 @@ static class ChunkHandler implements BaseRestHandler.RequestBodyChunkConsumer {
this.defaultListExecutedPipelines = request.paramAsBoolean("list_executed_pipelines", false);
this.defaultRequireAlias = request.paramAsBoolean(DocWriteRequest.REQUIRE_ALIAS, false);
this.defaultRequireDataStream = request.paramAsBoolean(DocWriteRequest.REQUIRE_DATA_STREAM, false);
// TODO: Fix type deprecation logging
this.parser = new BulkRequestParser(false, request.getRestApiVersion());
this.parser = new BulkRequestParser(true, request.getRestApiVersion());
this.handlerSupplier = handlerSupplier;
}

Expand Down
Loading

0 comments on commit 23e1116

Please sign in to comment.