Catch and handle disconnect exceptions in search (elastic#115836)
Getting a connection can throw an exception for a disconnected node.
The spots adjusted here failed to handle these exceptions, leading to a phase failure
(and possible memory leaks for outstanding operations) instead of correctly
recording a per-shard failure.
original-brownbear authored Oct 29, 2024 · 1 parent 6742147 · commit 78a531b
Showing 6 changed files with 111 additions and 65 deletions.
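
Every adjusted call site applies the same pattern: resolve the transport connection up front inside a try/catch and feed any failure to the per-shard listener, rather than letting the exception escape and fail the whole phase. Below is a minimal, self-contained sketch of that pattern; the names (DisconnectHandlingSketch, getConnection, executeOnShard, Connection) are hypothetical stand-ins, not the real Elasticsearch transport API.

import java.util.function.Consumer;

class DisconnectHandlingSketch {
    interface Connection {}

    // Stand-in for the transport-layer lookup; like the real one, it can
    // throw for a node that has dropped off the cluster.
    static Connection getConnection(String nodeId) {
        throw new IllegalStateException("node [" + nodeId + "] is not connected");
    }

    static void executeOnShard(String nodeId, Consumer<Exception> onShardFailure) {
        final Connection connection;
        try {
            connection = getConnection(nodeId);
        } catch (Exception e) {
            // Record a per-shard failure and bail out of this shard only;
            // before the fix the exception escaped and failed the whole phase.
            onShardFailure.accept(e);
            return;
        }
        // ... send the shard-level request over `connection` ...
    }

    public static void main(String[] args) {
        executeOnShard("node-1", e -> System.out.println("per-shard failure: " + e.getMessage()));
    }
}

Run directly, the sketch prints a per-shard failure message instead of propagating the exception.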
docs/changelog/115836.yaml (5 additions, 0 deletions)
@@ -0,0 +1,5 @@
+pr: 115836
+summary: Catch and handle disconnect exceptions in search
+area: Search
+type: bug
+issues: []
server/src/main/java/org/elasticsearch/action/search/DfsQueryPhase.java
@@ -84,15 +84,20 @@ public void run() {
 
         for (final DfsSearchResult dfsResult : searchResults) {
             final SearchShardTarget shardTarget = dfsResult.getSearchShardTarget();
-            Transport.Connection connection = context.getConnection(shardTarget.getClusterAlias(), shardTarget.getNodeId());
-            ShardSearchRequest shardRequest = rewriteShardSearchRequest(dfsResult.getShardSearchRequest());
+            final int shardIndex = dfsResult.getShardIndex();
             QuerySearchRequest querySearchRequest = new QuerySearchRequest(
-                context.getOriginalIndices(dfsResult.getShardIndex()),
+                context.getOriginalIndices(shardIndex),
                 dfsResult.getContextId(),
-                shardRequest,
+                rewriteShardSearchRequest(dfsResult.getShardSearchRequest()),
                 dfs
             );
-            final int shardIndex = dfsResult.getShardIndex();
+            final Transport.Connection connection;
+            try {
+                connection = context.getConnection(shardTarget.getClusterAlias(), shardTarget.getNodeId());
+            } catch (Exception e) {
+                shardFailure(e, querySearchRequest, shardIndex, shardTarget, counter);
+                return;
+            }
             searchTransportService.sendExecuteQuery(
                 connection,
                 querySearchRequest,
@@ -112,10 +117,7 @@ protected void innerOnResponse(QuerySearchResult response) {
                 @Override
                 public void onFailure(Exception exception) {
                     try {
-                        context.getLogger()
-                            .debug(() -> "[" + querySearchRequest.contextId() + "] Failed to execute query phase", exception);
-                        progressListener.notifyQueryFailure(shardIndex, shardTarget, exception);
-                        counter.onFailure(shardIndex, shardTarget, exception);
+                        shardFailure(exception, querySearchRequest, shardIndex, shardTarget, counter);
                     } finally {
                         if (context.isPartOfPointInTime(querySearchRequest.contextId()) == false) {
                             // the query might not have been executed at all (for example because thread pool rejected
@@ -134,6 +136,18 @@ public void onFailure(Exception exception) {
         }
     }
 
+    private void shardFailure(
+        Exception exception,
+        QuerySearchRequest querySearchRequest,
+        int shardIndex,
+        SearchShardTarget shardTarget,
+        CountedCollector<SearchPhaseResult> counter
+    ) {
+        context.getLogger().debug(() -> "[" + querySearchRequest.contextId() + "] Failed to execute query phase", exception);
+        progressListener.notifyQueryFailure(shardIndex, shardTarget, exception);
+        counter.onFailure(shardIndex, shardTarget, exception);
+    }
+
     // package private for testing
     ShardSearchRequest rewriteShardSearchRequest(ShardSearchRequest request) {
         SearchSourceBuilder source = request.source();
server/src/main/java/org/elasticsearch/action/search/FetchSearchPhase.java
@@ -21,6 +21,7 @@
 import org.elasticsearch.search.internal.ShardSearchContextId;
 import org.elasticsearch.search.rank.RankDoc;
 import org.elasticsearch.search.rank.RankDocShardInfo;
+import org.elasticsearch.transport.Transport;
 
 import java.util.ArrayList;
 import java.util.HashMap;
@@ -214,9 +215,41 @@ private void executeFetch(
         final ShardSearchContextId contextId = shardPhaseResult.queryResult() != null
             ? shardPhaseResult.queryResult().getContextId()
             : shardPhaseResult.rankFeatureResult().getContextId();
+        var listener = new SearchActionListener<FetchSearchResult>(shardTarget, shardIndex) {
+            @Override
+            public void innerOnResponse(FetchSearchResult result) {
+                try {
+                    progressListener.notifyFetchResult(shardIndex);
+                    counter.onResult(result);
+                } catch (Exception e) {
+                    context.onPhaseFailure(FetchSearchPhase.this, "", e);
+                }
+            }
+
+            @Override
+            public void onFailure(Exception e) {
+                try {
+                    logger.debug(() -> "[" + contextId + "] Failed to execute fetch phase", e);
+                    progressListener.notifyFetchFailure(shardIndex, shardTarget, e);
+                    counter.onFailure(shardIndex, shardTarget, e);
+                } finally {
+                    // the search context might not be cleared on the node where the fetch was executed for example
+                    // because the action was rejected by the thread pool. in this case we need to send a dedicated
+                    // request to clear the search context.
+                    releaseIrrelevantSearchContext(shardPhaseResult, context);
+                }
+            }
+        };
+        final Transport.Connection connection;
+        try {
+            connection = context.getConnection(shardTarget.getClusterAlias(), shardTarget.getNodeId());
+        } catch (Exception e) {
+            listener.onFailure(e);
+            return;
+        }
         context.getSearchTransport()
             .sendExecuteFetch(
-                context.getConnection(shardTarget.getClusterAlias(), shardTarget.getNodeId()),
+                connection,
                 new ShardFetchSearchRequest(
                     context.getOriginalIndices(shardPhaseResult.getShardIndex()),
                     contextId,
@@ -228,31 +261,7 @@ private void executeFetch(
                     aggregatedDfs
                 ),
                 context.getTask(),
-                new SearchActionListener<>(shardTarget, shardIndex) {
-                    @Override
-                    public void innerOnResponse(FetchSearchResult result) {
-                        try {
-                            progressListener.notifyFetchResult(shardIndex);
-                            counter.onResult(result);
-                        } catch (Exception e) {
-                            context.onPhaseFailure(FetchSearchPhase.this, "", e);
-                        }
-                    }
-
-                    @Override
-                    public void onFailure(Exception e) {
-                        try {
-                            logger.debug(() -> "[" + contextId + "] Failed to execute fetch phase", e);
-                            progressListener.notifyFetchFailure(shardIndex, shardTarget, e);
-                            counter.onFailure(shardIndex, shardTarget, e);
-                        } finally {
-                            // the search context might not be cleared on the node where the fetch was executed for example
-                            // because the action was rejected by the thread pool. in this case we need to send a dedicated
-                            // request to clear the search context.
-                            releaseIrrelevantSearchContext(shardPhaseResult, context);
-                        }
-                    }
-                }
+                listener
             );
     }
 
server/src/main/java/org/elasticsearch/action/search/RankFeaturePhase.java
@@ -24,6 +24,7 @@
 import org.elasticsearch.search.rank.feature.RankFeatureDoc;
 import org.elasticsearch.search.rank.feature.RankFeatureResult;
 import org.elasticsearch.search.rank.feature.RankFeatureShardRequest;
+import org.elasticsearch.transport.Transport;
 
 import java.util.List;
 
@@ -131,38 +132,46 @@ private void executeRankFeatureShardPhase(
         final SearchShardTarget shardTarget = queryResult.queryResult().getSearchShardTarget();
         final ShardSearchContextId contextId = queryResult.queryResult().getContextId();
         final int shardIndex = queryResult.getShardIndex();
+        var listener = new SearchActionListener<RankFeatureResult>(shardTarget, shardIndex) {
+            @Override
+            protected void innerOnResponse(RankFeatureResult response) {
+                try {
+                    progressListener.notifyRankFeatureResult(shardIndex);
+                    rankRequestCounter.onResult(response);
+                } catch (Exception e) {
+                    context.onPhaseFailure(RankFeaturePhase.this, "", e);
+                }
+            }
+
+            @Override
+            public void onFailure(Exception e) {
+                try {
+                    logger.debug(() -> "[" + contextId + "] Failed to execute rank phase", e);
+                    progressListener.notifyRankFeatureFailure(shardIndex, shardTarget, e);
+                    rankRequestCounter.onFailure(shardIndex, shardTarget, e);
+                } finally {
+                    releaseIrrelevantSearchContext(queryResult, context);
+                }
+            }
+        };
+        final Transport.Connection connection;
+        try {
+            connection = context.getConnection(shardTarget.getClusterAlias(), shardTarget.getNodeId());
+        } catch (Exception e) {
+            listener.onFailure(e);
+            return;
+        }
         context.getSearchTransport()
             .sendExecuteRankFeature(
-                context.getConnection(shardTarget.getClusterAlias(), shardTarget.getNodeId()),
+                connection,
                 new RankFeatureShardRequest(
                     context.getOriginalIndices(queryResult.getShardIndex()),
                     queryResult.getContextId(),
                     queryResult.getShardSearchRequest(),
                     entry
                 ),
                 context.getTask(),
-                new SearchActionListener<>(shardTarget, shardIndex) {
-                    @Override
-                    protected void innerOnResponse(RankFeatureResult response) {
-                        try {
-                            progressListener.notifyRankFeatureResult(shardIndex);
-                            rankRequestCounter.onResult(response);
-                        } catch (Exception e) {
-                            context.onPhaseFailure(RankFeaturePhase.this, "", e);
-                        }
-                    }
-
-                    @Override
-                    public void onFailure(Exception e) {
-                        try {
-                            logger.debug(() -> "[" + contextId + "] Failed to execute rank phase", e);
-                            progressListener.notifyRankFeatureFailure(shardIndex, shardTarget, e);
-                            rankRequestCounter.onFailure(shardIndex, shardTarget, e);
-                        } finally {
-                            releaseIrrelevantSearchContext(queryResult, context);
-                        }
-                    }
-                }
+                listener
             );
     }
 
server/src/main/java/org/elasticsearch/action/search/SearchDfsQueryThenFetchAsyncAction.java
@@ -87,12 +87,14 @@ protected void executePhaseOnShard(
         final SearchShardTarget shard,
         final SearchActionListener<DfsSearchResult> listener
     ) {
-        getSearchTransport().sendExecuteDfs(
-            getConnection(shard.getClusterAlias(), shard.getNodeId()),
-            buildShardSearchRequest(shardIt, listener.requestIndex),
-            getTask(),
-            listener
-        );
+        final Transport.Connection connection;
+        try {
+            connection = getConnection(shard.getClusterAlias(), shard.getNodeId());
+        } catch (Exception e) {
+            listener.onFailure(e);
+            return;
+        }
+        getSearchTransport().sendExecuteDfs(connection, buildShardSearchRequest(shardIt, listener.requestIndex), getTask(), listener);
     }
 
     @Override
server/src/main/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncAction.java
@@ -94,8 +94,15 @@ protected void executePhaseOnShard(
         final SearchShardTarget shard,
         final SearchActionListener<SearchPhaseResult> listener
    ) {
+        final Transport.Connection connection;
+        try {
+            connection = getConnection(shard.getClusterAlias(), shard.getNodeId());
+        } catch (Exception e) {
+            listener.onFailure(e);
+            return;
+        }
         ShardSearchRequest request = rewriteShardSearchRequest(super.buildShardSearchRequest(shardIt, listener.requestIndex));
-        getSearchTransport().sendExecuteQuery(getConnection(shard.getClusterAlias(), shard.getNodeId()), request, getTask(), listener);
+        getSearchTransport().sendExecuteQuery(connection, request, getTask(), listener);
     }
 
     @Override