Skip to content

Commit

Permalink
opensearch should returns partial results after the timeout in coordi…
Browse files Browse the repository at this point in the history
…nate node when allow_partial_search_results is true
  • Loading branch information
kkewwei committed Nov 19, 2024
1 parent 9b3ee09 commit 3b9454f
Show file tree
Hide file tree
Showing 7 changed files with 157 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ public class SearchRequest extends ActionRequest implements IndicesRequest.Repla
// it's only been used in coordinator, so we don't need to serialize/deserialize it
private long startTimeMills;

private float queryPhaseTimeoutPercentage;
private float queryPhaseTimeoutPercentage = 0.8f;

public SearchRequest() {
this.localClusterAlias = null;
Expand Down Expand Up @@ -358,6 +358,10 @@ public ActionRequestValidationException validate() {
validationException = addValidationError("using [point in time] is not allowed in a scroll context", validationException);
}
}

if (queryPhaseTimeoutPercentage <= 0 || queryPhaseTimeoutPercentage >= 1) {
validationException = addValidationError("[queryPhaseTimeoutPercentage] must be in (0, 1)", validationException);
}
return validationException;
}

Expand Down Expand Up @@ -722,21 +726,27 @@ public String pipeline() {
return pipeline;
}


public void setQueryPhaseTimeoutPercentage(float queryPhaseTimeoutPercentage) {
if (source.timeout() == null) {
throw new IllegalArgumentException("timeout must be set before setting query phase timeout percentage");
throw new IllegalArgumentException("timeout must be set before setting queryPhaseTimeoutPercentage");
}
this.queryPhaseTimeoutPercentage = queryPhaseTimeoutPercentage;
}

public float getQueryPhasePercentage() {
return queryPhaseTimeoutPercentage;
}

@Override
public SearchTask createTask(long id, String type, String action, TaskId parentTaskId, Map<String, String> headers) {
return new SearchTask(id, type, action, this::buildDescription, parentTaskId, headers, cancelAfterTimeInterval, startTimeMills, source.timeout() != null? source.timeout().millis() : -1, queryPhaseTimeoutPercentage);
return new SearchTask(
id,
type,
action,
this::buildDescription,
parentTaskId,
headers,
cancelAfterTimeInterval,
startTimeMills,
(source != null && source.timeout() != null) ? source.timeout().millis() : -1,
queryPhaseTimeoutPercentage
);
}

public final String buildDescription() {
Expand Down Expand Up @@ -788,7 +798,8 @@ public boolean equals(Object o) {
&& ccsMinimizeRoundtrips == that.ccsMinimizeRoundtrips
&& Objects.equals(cancelAfterTimeInterval, that.cancelAfterTimeInterval)
&& Objects.equals(pipeline, that.pipeline)
&& Objects.equals(phaseTook, that.phaseTook);
&& Objects.equals(phaseTook, that.phaseTook)
&& Objects.equals(queryPhaseTimeoutPercentage, that.queryPhaseTimeoutPercentage);
}

@Override
Expand All @@ -810,7 +821,8 @@ public int hashCode() {
absoluteStartMillis,
ccsMinimizeRoundtrips,
cancelAfterTimeInterval,
phaseTook
phaseTook,
queryPhaseTimeoutPercentage
);
}

Expand Down Expand Up @@ -855,6 +867,8 @@ public String toString() {
+ pipeline
+ ", phaseTook="
+ phaseTook
+ ", queryPhaseTimeoutPercentage="
+ queryPhaseTimeoutPercentage
+ "}";
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ public SearchTask(
this.descriptionSupplier = descriptionSupplier;
this.startTimeMills = startTimeMills;
this.timeoutMills = timeoutMills;
assert queryPhaseTimeoutPercentage > 0 && queryPhaseTimeoutPercentage <= 1;
this.queryPhaseTimeoutPercentage = queryPhaseTimeoutPercentage;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -400,17 +400,17 @@ void sendExecuteMultiSearch(final MultiSearchRequest request, SearchTask task, f
);
}

public TransportRequestOptions getTransportRequestOptions(SearchTask task, Consumer<Exception> onFailure, boolean queryPhase) {
if (task.timeoutMills() > 0) {
static TransportRequestOptions getTransportRequestOptions(SearchTask task, Consumer<Exception> onFailure, boolean queryPhase) {
if (task != null && task.timeoutMills() > 0) {
long leftTimeMills;
if (queryPhase) {
//it's costly in query phase.
// it's costly in query phase.
leftTimeMills = task.queryPhaseTimeout() - (System.currentTimeMillis() - task.startTimeMills());
} else {
leftTimeMills = task.timeoutMills() - (System.currentTimeMillis() - task.startTimeMills());
}
if (leftTimeMills <= 0) {
onFailure.accept(new TaskCancelledException("failed to execute fetch phase, timeout exceeded"));
onFailure.accept(new TaskCancelledException("failed to execute fetch phase, timeout exceeded" + leftTimeMills + "ms"));
return null;
} else {
return TransportRequestOptions.builder().withTimeout(leftTimeMills).build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,9 @@ public static void parseSearchRequest(

searchRequest.setCancelAfterTimeInterval(request.paramAsTime("cancel_after_time_interval", null));

searchRequest.setQueryPhaseTimeoutPercentage(request.paramAsFloat("query_phase_timeout_percentage", SearchRequest.DEFAULT_QUERY_PHASE_TIMEOUT_PERCENTAGE));
searchRequest.setQueryPhaseTimeoutPercentage(
request.paramAsFloat("query_phase_timeout_percentage", SearchRequest.DEFAULT_QUERY_PHASE_TIMEOUT_PERCENTAGE)
);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,13 @@
import org.opensearch.action.OriginalIndices;
import org.opensearch.action.support.IndicesOptions;
import org.opensearch.cluster.ClusterState;
import org.opensearch.cluster.node.DiscoveryNode;
import org.opensearch.cluster.routing.GroupShardsIterator;
import org.opensearch.common.UUIDs;
import org.opensearch.common.collect.Tuple;
import org.opensearch.common.settings.ClusterSettings;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.unit.TimeValue;
import org.opensearch.common.util.concurrent.AtomicArray;
import org.opensearch.common.util.concurrent.OpenSearchExecutors;
import org.opensearch.common.util.set.Sets;
Expand All @@ -55,6 +57,7 @@
import org.opensearch.index.shard.ShardNotFoundException;
import org.opensearch.search.SearchPhaseResult;
import org.opensearch.search.SearchShardTarget;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.search.internal.AliasFilter;
import org.opensearch.search.internal.InternalSearchResponse;
import org.opensearch.search.internal.ShardSearchContextId;
Expand All @@ -65,6 +68,7 @@
import org.opensearch.test.OpenSearchTestCase;
import org.opensearch.threadpool.TestThreadPool;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.ReceiveTimeoutTransportException;
import org.opensearch.transport.Transport;
import org.junit.After;
import org.junit.Before;
Expand All @@ -89,6 +93,9 @@
import java.util.function.BiFunction;
import java.util.stream.IntStream;

import org.mockito.Mockito;

import static org.opensearch.action.search.SearchTransportService.QUERY_ACTION_NAME;
import static org.opensearch.tasks.TaskResourceTrackingService.TASK_RESOURCE_USAGE;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.greaterThanOrEqualTo;
Expand Down Expand Up @@ -138,6 +145,7 @@ private AbstractSearchAsyncAction<SearchPhaseResult> createAction(
false,
expected,
resourceUsage,
false,
new SearchShardIterator(null, null, Collections.emptyList(), null)
);
}
Expand All @@ -151,6 +159,7 @@ private AbstractSearchAsyncAction<SearchPhaseResult> createAction(
final boolean catchExceptionWhenExecutePhaseOnShard,
final AtomicLong expected,
final TaskResourceUsage resourceUsage,
final boolean blockTheFirstQueryPhase,
final SearchShardIterator... shards
) {

Expand Down Expand Up @@ -179,7 +188,7 @@ private AbstractSearchAsyncAction<SearchPhaseResult> createAction(
.setNodeId(randomAlphaOfLengthBetween(1, 5))
.build();
threadPool.getThreadContext().addResponseHeader(TASK_RESOURCE_USAGE, taskResourceInfo.toString());

AtomicBoolean firstShard = new AtomicBoolean(true);
return new AbstractSearchAsyncAction<SearchPhaseResult>(
"test",
logger,
Expand Down Expand Up @@ -207,7 +216,13 @@ private AbstractSearchAsyncAction<SearchPhaseResult> createAction(
) {
@Override
protected SearchPhase getNextPhase(final SearchPhaseResults<SearchPhaseResult> results, SearchPhaseContext context) {
return null;
return new SearchPhase("test") {
@Override
public void run() {
listener.onResponse(new SearchResponse(null, null, 0, 0, 0, 0, null, null));
assertingListener.onPhaseEnd(context, null);
}
};
}

@Override
Expand All @@ -218,6 +233,17 @@ protected void executePhaseOnShard(
) {
if (failExecutePhaseOnShard) {
listener.onFailure(new ShardNotFoundException(shardIt.shardId()));
} else if (blockTheFirstQueryPhase && firstShard.compareAndSet(true, false)) {
// Sleep and throw ReceiveTimeoutTransportException to simulate node blocked
try {
Thread.sleep(request.source().timeout().millis());
} catch (InterruptedException e) {}
;
DiscoveryNode node = Mockito.mock(DiscoveryNode.class);
Mockito.when(node.getName()).thenReturn("test_nodes");
listener.onFailure(
new ReceiveTimeoutTransportException(node, QUERY_ACTION_NAME, "request_id [171] timed out after [413ms]")
);
} else {
if (catchExceptionWhenExecutePhaseOnShard) {
try {
Expand All @@ -227,6 +253,7 @@ protected void executePhaseOnShard(
}
} else {
listener.onResponse(new QuerySearchResult());

}
}
}
Expand Down Expand Up @@ -587,6 +614,7 @@ public void onFailure(Exception e) {
false,
new AtomicLong(),
new TaskResourceUsage(randomLong(), randomLong()),
false,
shards
);
action.run();
Expand Down Expand Up @@ -635,6 +663,7 @@ public void onFailure(Exception e) {
false,
new AtomicLong(),
new TaskResourceUsage(randomLong(), randomLong()),
false,
shards
);
action.run();
Expand Down Expand Up @@ -688,6 +717,7 @@ public void onFailure(Exception e) {
catchExceptionWhenExecutePhaseOnShard,
new AtomicLong(),
new TaskResourceUsage(randomLong(), randomLong()),
false,
shards
);
action.run();
Expand Down Expand Up @@ -791,6 +821,41 @@ public void testOnPhaseListenersWithDfsType() throws InterruptedException {
assertEquals(0, testListener.getPhaseCurrent(searchDfsQueryThenFetchAsyncAction.getSearchPhaseName()));
}

public void testExecutePhaseOnShardBlockAndRetrunPartialResult() {
// on shard is blocked in query phase
final Index index = new Index("test", UUID.randomUUID().toString());

final SearchShardIterator[] shards = IntStream.range(0, 2 + randomInt(4))
.mapToObj(i -> new SearchShardIterator(null, new ShardId(index, i), List.of("n1"), null, null, null))
.toArray(SearchShardIterator[]::new);

SearchRequest searchRequest = new SearchRequest().allowPartialSearchResults(true);
searchRequest.source(new SearchSourceBuilder());
long timeoutMills = 500;
searchRequest.source().timeout(new TimeValue(timeoutMills, TimeUnit.MILLISECONDS));
searchRequest.setMaxConcurrentShardRequests(shards.length);
final AtomicBoolean successed = new AtomicBoolean(false);
long current = System.currentTimeMillis();

final ArraySearchPhaseResults<SearchPhaseResult> queryResult = new ArraySearchPhaseResults<>(shards.length);
AbstractSearchAsyncAction<SearchPhaseResult> action = createAction(searchRequest, queryResult, new ActionListener<>() {
@Override
public void onResponse(SearchResponse response) {
successed.set(true);
}

@Override
public void onFailure(Exception e) {
successed.set(false);
}
}, false, false, false, new AtomicLong(), new TaskResourceUsage(randomLong(), randomLong()), true, shards);
action.run();
long s = System.currentTimeMillis() - current;
assertTrue(s > timeoutMills);
assertTrue(successed.get());

}

private SearchDfsQueryThenFetchAsyncAction createSearchDfsQueryThenFetchAsyncAction(
List<SearchRequestOperationsListener> searchRequestOperationsListeners
) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,15 @@ public void testValidate() throws IOException {
assertEquals(1, validationErrors.validationErrors().size());
assertEquals("using [point in time] is not allowed in a scroll context", validationErrors.validationErrors().get(0));
}

{
// queryPhaseTimeoutPercentage must be in (0, 1)
SearchRequest searchRequest = createSearchRequest().source(new SearchSourceBuilder().timeout(TimeValue.timeValueMillis(10)));
searchRequest.setQueryPhaseTimeoutPercentage(-1);
ActionRequestValidationException validationErrors = searchRequest.validate();
assertNotNull(validationErrors);
assertEquals("[queryPhaseTimeoutPercentage] must be in (0, 1)", validationErrors.validationErrors().get(0));
}
}

public void testCopyConstructor() throws IOException {
Expand All @@ -261,6 +270,19 @@ public void testParseSearchRequestWithUnsupportedSearchType() throws IOException
assertEquals("Unsupported search type [query_and_fetch]", exception.getMessage());
}

public void testParseSearchRequestWithTimeoutAndQueryPhaseTimeoutPercentage() throws IOException {
RestRequest restRequest = new FakeRestRequest();
SearchRequest searchRequest = createSearchRequest().source(new SearchSourceBuilder());
IntConsumer setSize = mock(IntConsumer.class);
restRequest.params().put("query_phase_timeout_percentage", "30");

IllegalArgumentException exception = expectThrows(
IllegalArgumentException.class,
() -> RestSearchAction.parseSearchRequest(searchRequest, restRequest, null, namedWriteableRegistry, setSize)
);
assertEquals("timeout must be set before setting queryPhaseTimeoutPercentage", exception.getMessage());
}

public void testEqualsAndHashcode() throws IOException {
checkEqualsAndHashCode(createSearchRequest(), SearchRequest::new, this::mutate);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

package org.opensearch.action.search;

import org.opensearch.core.tasks.TaskCancelledException;
import org.opensearch.test.OpenSearchTestCase;
import org.opensearch.transport.TransportRequestOptions;

public class SearchTransportServiceTests extends OpenSearchTestCase {
public void testGetTransportRequestOptions() {
SearchTask searchTask = new SearchTask(1, null, null, null, null, null, null, System.currentTimeMillis(), 1000, 0.8f);
TransportRequestOptions transportRequestOptions = SearchTransportService.getTransportRequestOptions(searchTask, e -> {}, true);
assertTrue(transportRequestOptions.timeout().millis() > 0);

TransportRequestOptions transportRequestOptions1 = SearchTransportService.getTransportRequestOptions(searchTask, e -> {}, false);
assertTrue(transportRequestOptions.timeout().millis() < transportRequestOptions1.timeout().millis());

SearchTask searchTask1 = new SearchTask(1, null, null, null, null, null, null, System.currentTimeMillis(), 1, 0.8f);

transportRequestOptions = SearchTransportService.getTransportRequestOptions(searchTask1, exception -> {
assertEquals(TaskCancelledException.class, exception.getClass());
assertTrue(exception.getMessage().contains("failed to execute fetch phase, timeout exceeded"));
}, true);
assertNull(transportRequestOptions);

searchTask = new SearchTask(1, null, null, null, null, null, null, System.currentTimeMillis(), 0, 0.8f);
transportRequestOptions = SearchTransportService.getTransportRequestOptions(searchTask, e -> {}, false);
assertEquals(TransportRequestOptions.EMPTY, transportRequestOptions);
}
}

0 comments on commit 3b9454f

Please sign in to comment.