From f2cb9f7111a75ccd28d996359f109229582d2f39 Mon Sep 17 00:00:00 2001 From: kkewwei Date: Thu, 21 Nov 2024 20:35:11 +0800 Subject: [PATCH] Coordinator can return partial results after the timeout when allow_partial_search_results is true Signed-off-by: kkewwei Signed-off-by: kkewwei --- CHANGELOG.md | 1 + .../action/search/SearchRequest.java | 47 ++++++++++- .../opensearch/action/search/SearchTask.java | 26 ++++++- .../action/search/SearchTransportService.java | 78 ++++++++++++++++++- .../rest/action/search/RestSearchAction.java | 6 ++ .../AbstractSearchAsyncActionTests.java | 73 ++++++++++++++++- .../action/search/SearchRequestTests.java | 22 ++++++ .../search/SearchTransportServiceTests.java | 36 +++++++++ 8 files changed, 280 insertions(+), 9 deletions(-) create mode 100644 server/src/test/java/org/opensearch/action/search/SearchTransportServiceTests.java diff --git a/CHANGELOG.md b/CHANGELOG.md index 9cfcd4e6dfbd1..9ac4577ededae 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -21,6 +21,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), - Support prefix list for remote repository attributes([#16271](https://github.com/opensearch-project/OpenSearch/pull/16271)) - Add new configuration setting `synonym_analyzer`, to the `synonym` and `synonym_graph` filters, enabling the specification of a custom analyzer for reading the synonym file ([#16488](https://github.com/opensearch-project/OpenSearch/pull/16488)). - Add stats for remote publication failure and move download failure stats to remote methods([#16682](https://github.com/opensearch-project/OpenSearch/pull/16682/)) +- Coordinator can return partial results after the timeout when allow_partial_search_results is true ([#16681](https://github.com/opensearch-project/OpenSearch/pull/16681)). ### Dependencies - Bump `com.google.cloud:google-cloud-core-http` from 2.23.0 to 2.47.0 ([#16504](https://github.com/opensearch-project/OpenSearch/pull/16504)) diff --git a/server/src/main/java/org/opensearch/action/search/SearchRequest.java b/server/src/main/java/org/opensearch/action/search/SearchRequest.java index 4d3bb868b779a..65c6a897a96b2 100644 --- a/server/src/main/java/org/opensearch/action/search/SearchRequest.java +++ b/server/src/main/java/org/opensearch/action/search/SearchRequest.java @@ -84,6 +84,8 @@ public class SearchRequest extends ActionRequest implements IndicesRequest.Repla private static final long DEFAULT_ABSOLUTE_START_MILLIS = -1; + public static final float DEFAULT_QUERY_PHASE_TIMEOUT_PERCENTAGE = 0.8f; + private final String localClusterAlias; private final long absoluteStartMillis; private final boolean finalReduce; @@ -123,10 +125,16 @@ public class SearchRequest extends ActionRequest implements IndicesRequest.Repla private Boolean phaseTook = null; + // it's only been used in coordinator, so we don't need to serialize/deserialize it + private long startTimeMills; + + private float queryPhaseTimeoutPercentage = 0.8f; + public SearchRequest() { this.localClusterAlias = null; this.absoluteStartMillis = DEFAULT_ABSOLUTE_START_MILLIS; this.finalReduce = true; + this.startTimeMills = System.currentTimeMillis(); } /** @@ -228,6 +236,8 @@ private SearchRequest( this.finalReduce = finalReduce; this.cancelAfterTimeInterval = searchRequest.cancelAfterTimeInterval; this.phaseTook = searchRequest.phaseTook; + this.startTimeMills = searchRequest.startTimeMills; + this.queryPhaseTimeoutPercentage = searchRequest.queryPhaseTimeoutPercentage; } /** @@ -275,6 +285,7 @@ public SearchRequest(StreamInput in) throws IOException { if (in.getVersion().onOrAfter(Version.V_2_12_0)) { phaseTook = in.readOptionalBoolean(); } + startTimeMills = -1; } @Override @@ -347,6 +358,10 @@ public ActionRequestValidationException validate() { validationException = addValidationError("using [point in time] is not allowed in a scroll context", validationException); } } + + if (queryPhaseTimeoutPercentage <= 0 || queryPhaseTimeoutPercentage > 1) { + validationException = addValidationError("[query_phase_timeout_percentage] must be in (0, 1]", validationException); + } return validationException; } @@ -711,9 +726,31 @@ public String pipeline() { return pipeline; } + public void setQueryPhaseTimeoutPercentage(float queryPhaseTimeoutPercentage) { + if (source.timeout() == null) { + throw new IllegalArgumentException("timeout must be set before setting queryPhaseTimeoutPercentage"); + } + if (source.size() == 0) { + this.queryPhaseTimeoutPercentage = 1; + } else { + this.queryPhaseTimeoutPercentage = queryPhaseTimeoutPercentage; + } + } + @Override public SearchTask createTask(long id, String type, String action, TaskId parentTaskId, Map headers) { - return new SearchTask(id, type, action, this::buildDescription, parentTaskId, headers, cancelAfterTimeInterval); + return new SearchTask( + id, + type, + action, + this::buildDescription, + parentTaskId, + headers, + cancelAfterTimeInterval, + startTimeMills, + (source != null && source.timeout() != null) ? source.timeout().millis() : -1, + queryPhaseTimeoutPercentage + ); } public final String buildDescription() { @@ -765,7 +802,8 @@ public boolean equals(Object o) { && ccsMinimizeRoundtrips == that.ccsMinimizeRoundtrips && Objects.equals(cancelAfterTimeInterval, that.cancelAfterTimeInterval) && Objects.equals(pipeline, that.pipeline) - && Objects.equals(phaseTook, that.phaseTook); + && Objects.equals(phaseTook, that.phaseTook) + && Objects.equals(queryPhaseTimeoutPercentage, that.queryPhaseTimeoutPercentage); } @Override @@ -787,7 +825,8 @@ public int hashCode() { absoluteStartMillis, ccsMinimizeRoundtrips, cancelAfterTimeInterval, - phaseTook + phaseTook, + queryPhaseTimeoutPercentage ); } @@ -832,6 +871,8 @@ public String toString() { + pipeline + ", phaseTook=" + phaseTook + + ", queryPhaseTimeoutPercentage=" + + queryPhaseTimeoutPercentage + "}"; } } diff --git a/server/src/main/java/org/opensearch/action/search/SearchTask.java b/server/src/main/java/org/opensearch/action/search/SearchTask.java index 2a1a961e7607b..7460095b41d26 100644 --- a/server/src/main/java/org/opensearch/action/search/SearchTask.java +++ b/server/src/main/java/org/opensearch/action/search/SearchTask.java @@ -53,6 +53,9 @@ public class SearchTask extends QueryGroupTask implements SearchBackpressureTask // generating description in a lazy way since source can be quite big private final Supplier descriptionSupplier; private SearchProgressListener progressListener = SearchProgressListener.NOOP; + private final long startTimeMills; + private final long timeoutMills; + private final float queryPhaseTimeoutPercentage; public SearchTask( long id, @@ -62,7 +65,7 @@ public SearchTask( TaskId parentTaskId, Map headers ) { - this(id, type, action, descriptionSupplier, parentTaskId, headers, NO_TIMEOUT); + this(id, type, action, descriptionSupplier, parentTaskId, headers, NO_TIMEOUT, -1, -1, 0.8f); } public SearchTask( @@ -72,10 +75,17 @@ public SearchTask( Supplier descriptionSupplier, TaskId parentTaskId, Map headers, - TimeValue cancelAfterTimeInterval + TimeValue cancelAfterTimeInterval, + long startTimeMills, + long timeoutMills, + float queryPhaseTimeoutPercentage ) { super(id, type, action, null, parentTaskId, headers, cancelAfterTimeInterval); this.descriptionSupplier = descriptionSupplier; + this.startTimeMills = startTimeMills; + this.timeoutMills = timeoutMills; + assert queryPhaseTimeoutPercentage > 0 && queryPhaseTimeoutPercentage <= 1; + this.queryPhaseTimeoutPercentage = queryPhaseTimeoutPercentage; } @Override @@ -106,4 +116,16 @@ public final SearchProgressListener getProgressListener() { public boolean shouldCancelChildrenOnCancellation() { return true; } + + public long startTimeMills() { + return startTimeMills; + } + + public long timeoutMills() { + return timeoutMills; + } + + public long queryPhaseTimeout() { + return (long) (timeoutMills * queryPhaseTimeoutPercentage); + } } diff --git a/server/src/main/java/org/opensearch/action/search/SearchTransportService.java b/server/src/main/java/org/opensearch/action/search/SearchTransportService.java index 64c738f633f2e..0ba8cdd5d3b94 100644 --- a/server/src/main/java/org/opensearch/action/search/SearchTransportService.java +++ b/server/src/main/java/org/opensearch/action/search/SearchTransportService.java @@ -44,6 +44,7 @@ import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.common.io.stream.Writeable; +import org.opensearch.core.tasks.TaskCancelledException; import org.opensearch.core.transport.TransportResponse; import org.opensearch.ratelimitting.admissioncontrol.enums.AdmissionControlActionType; import org.opensearch.search.SearchPhaseResult; @@ -76,6 +77,7 @@ import java.util.Map; import java.util.Objects; import java.util.function.BiFunction; +import java.util.function.Consumer; /** * An encapsulation of {@link org.opensearch.search.SearchService} operations exposed through @@ -167,12 +169,18 @@ public void createPitContext( SearchTask task, ActionListener actionListener ) { + + TransportRequestOptions options = getTransportRequestOptions(task, actionListener::onFailure, false); + if (options == null) { + return; + } + transportService.sendChildRequest( connection, CREATE_READER_CONTEXT_ACTION_NAME, request, task, - TransportRequestOptions.EMPTY, + options, new ActionListenerResponseHandler<>(actionListener, TransportCreatePitAction.CreateReaderContextResponse::new) ); } @@ -183,12 +191,18 @@ public void sendCanMatch( SearchTask task, final ActionListener listener ) { + + TransportRequestOptions options = getTransportRequestOptions(task, listener::onFailure, false); + if (options == null) { + return; + } + transportService.sendChildRequest( connection, QUERY_CAN_MATCH_NAME, request, task, - TransportRequestOptions.EMPTY, + options, new ActionListenerResponseHandler<>(listener, SearchService.CanMatchResponse::new) ); } @@ -223,11 +237,18 @@ public void sendExecuteDfs( SearchTask task, final SearchActionListener listener ) { + + TransportRequestOptions options = getTransportRequestOptions(task, listener::onFailure, true); + if (options == null) { + return; + } + transportService.sendChildRequest( connection, DFS_ACTION_NAME, request, task, + options, new ConnectionCountingHandler<>(listener, DfsSearchResult::new, clientConnections, connection.getNode().getId()) ); } @@ -243,12 +264,18 @@ public void sendExecuteQuery( final boolean fetchDocuments = request.numberOfShards() == 1; Writeable.Reader reader = fetchDocuments ? QueryFetchSearchResult::new : QuerySearchResult::new; + TransportRequestOptions options = getTransportRequestOptions(task, listener::onFailure, true); + if (options == null) { + return; + } + final ActionListener handler = responseWrapper.apply(connection, listener); transportService.sendChildRequest( connection, QUERY_ACTION_NAME, request, task, + options, new ConnectionCountingHandler<>(handler, reader, clientConnections, connection.getNode().getId()) ); } @@ -259,11 +286,18 @@ public void sendExecuteQuery( SearchTask task, final SearchActionListener listener ) { + + TransportRequestOptions options = getTransportRequestOptions(task, listener::onFailure, true); + if (options == null) { + return; + } + transportService.sendChildRequest( connection, QUERY_ID_ACTION_NAME, request, task, + options, new ConnectionCountingHandler<>(listener, QuerySearchResult::new, clientConnections, connection.getNode().getId()) ); } @@ -274,11 +308,18 @@ public void sendExecuteScrollQuery( SearchTask task, final SearchActionListener listener ) { + + TransportRequestOptions options = getTransportRequestOptions(task, listener::onFailure, false); + if (options == null) { + return; + } + transportService.sendChildRequest( connection, QUERY_SCROLL_ACTION_NAME, request, task, + options, new ConnectionCountingHandler<>(listener, ScrollQuerySearchResult::new, clientConnections, connection.getNode().getId()) ); } @@ -323,11 +364,17 @@ private void sendExecuteFetch( SearchTask task, final SearchActionListener listener ) { + TransportRequestOptions options = getTransportRequestOptions(task, listener::onFailure, false); + if (options == null) { + return; + } + transportService.sendChildRequest( connection, action, request, task, + options, new ConnectionCountingHandler<>(listener, FetchSearchResult::new, clientConnections, connection.getNode().getId()) ); } @@ -337,15 +384,42 @@ private void sendExecuteFetch( */ void sendExecuteMultiSearch(final MultiSearchRequest request, SearchTask task, final ActionListener listener) { final Transport.Connection connection = transportService.getConnection(transportService.getLocalNode()); + + TransportRequestOptions options = getTransportRequestOptions(task, listener::onFailure, false); + if (options == null) { + return; + } + transportService.sendChildRequest( connection, MultiSearchAction.NAME, request, task, + options, new ConnectionCountingHandler<>(listener, MultiSearchResponse::new, clientConnections, connection.getNode().getId()) ); } + static TransportRequestOptions getTransportRequestOptions(SearchTask task, Consumer onFailure, boolean queryPhase) { + if (task != null && task.timeoutMills() > 0) { + long leftTimeMills; + if (queryPhase) { + // it's costly in query phase. + leftTimeMills = task.queryPhaseTimeout() - (System.currentTimeMillis() - task.startTimeMills()); + } else { + leftTimeMills = task.timeoutMills() - (System.currentTimeMillis() - task.startTimeMills()); + } + if (leftTimeMills <= 0) { + onFailure.accept(new TaskCancelledException("failed to execute fetch phase, timeout exceeded" + leftTimeMills + "ms")); + return null; + } else { + return TransportRequestOptions.builder().withTimeout(leftTimeMills).build(); + } + } else { + return TransportRequestOptions.EMPTY; + } + } + public RemoteClusterService getRemoteClusterService() { return transportService.getRemoteClusterService(); } diff --git a/server/src/main/java/org/opensearch/rest/action/search/RestSearchAction.java b/server/src/main/java/org/opensearch/rest/action/search/RestSearchAction.java index 05465e32631fd..afd4160f93279 100644 --- a/server/src/main/java/org/opensearch/rest/action/search/RestSearchAction.java +++ b/server/src/main/java/org/opensearch/rest/action/search/RestSearchAction.java @@ -224,6 +224,12 @@ public static void parseSearchRequest( } searchRequest.setCancelAfterTimeInterval(request.paramAsTime("cancel_after_time_interval", null)); + + if (request.hasParam("query_phase_timeout_percentage")) { + searchRequest.setQueryPhaseTimeoutPercentage( + request.paramAsFloat("query_phase_timeout_percentage", SearchRequest.DEFAULT_QUERY_PHASE_TIMEOUT_PERCENTAGE) + ); + } } /** diff --git a/server/src/test/java/org/opensearch/action/search/AbstractSearchAsyncActionTests.java b/server/src/test/java/org/opensearch/action/search/AbstractSearchAsyncActionTests.java index 27336e86e52b0..aef80a38b5b50 100644 --- a/server/src/test/java/org/opensearch/action/search/AbstractSearchAsyncActionTests.java +++ b/server/src/test/java/org/opensearch/action/search/AbstractSearchAsyncActionTests.java @@ -36,11 +36,13 @@ import org.opensearch.action.OriginalIndices; import org.opensearch.action.support.IndicesOptions; import org.opensearch.cluster.ClusterState; +import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.cluster.routing.GroupShardsIterator; import org.opensearch.common.UUIDs; import org.opensearch.common.collect.Tuple; import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; import org.opensearch.common.util.concurrent.AtomicArray; import org.opensearch.common.util.concurrent.OpenSearchExecutors; import org.opensearch.common.util.set.Sets; @@ -55,6 +57,7 @@ import org.opensearch.index.shard.ShardNotFoundException; import org.opensearch.search.SearchPhaseResult; import org.opensearch.search.SearchShardTarget; +import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.search.internal.AliasFilter; import org.opensearch.search.internal.InternalSearchResponse; import org.opensearch.search.internal.ShardSearchContextId; @@ -65,6 +68,7 @@ import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.TestThreadPool; import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.ReceiveTimeoutTransportException; import org.opensearch.transport.Transport; import org.junit.After; import org.junit.Before; @@ -89,6 +93,9 @@ import java.util.function.BiFunction; import java.util.stream.IntStream; +import org.mockito.Mockito; + +import static org.opensearch.action.search.SearchTransportService.QUERY_ACTION_NAME; import static org.opensearch.tasks.TaskResourceTrackingService.TASK_RESOURCE_USAGE; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.greaterThanOrEqualTo; @@ -138,6 +145,7 @@ private AbstractSearchAsyncAction createAction( false, expected, resourceUsage, + false, new SearchShardIterator(null, null, Collections.emptyList(), null) ); } @@ -151,6 +159,7 @@ private AbstractSearchAsyncAction createAction( final boolean catchExceptionWhenExecutePhaseOnShard, final AtomicLong expected, final TaskResourceUsage resourceUsage, + final boolean blockTheFirstQueryPhase, final SearchShardIterator... shards ) { @@ -179,7 +188,7 @@ private AbstractSearchAsyncAction createAction( .setNodeId(randomAlphaOfLengthBetween(1, 5)) .build(); threadPool.getThreadContext().addResponseHeader(TASK_RESOURCE_USAGE, taskResourceInfo.toString()); - + AtomicBoolean firstShard = new AtomicBoolean(true); return new AbstractSearchAsyncAction( "test", logger, @@ -207,7 +216,17 @@ private AbstractSearchAsyncAction createAction( ) { @Override protected SearchPhase getNextPhase(final SearchPhaseResults results, SearchPhaseContext context) { - return null; + if (blockTheFirstQueryPhase) { + return new SearchPhase("test") { + @Override + public void run() { + listener.onResponse(new SearchResponse(null, null, 0, 0, 0, 0, null, null)); + assertingListener.onPhaseEnd(context, null); + } + }; + } else { + return null; + } } @Override @@ -218,6 +237,17 @@ protected void executePhaseOnShard( ) { if (failExecutePhaseOnShard) { listener.onFailure(new ShardNotFoundException(shardIt.shardId())); + } else if (blockTheFirstQueryPhase && firstShard.compareAndSet(true, false)) { + // Sleep and throw ReceiveTimeoutTransportException to simulate node blocked + try { + Thread.sleep(request.source().timeout().millis()); + } catch (InterruptedException e) {} + ; + DiscoveryNode node = Mockito.mock(DiscoveryNode.class); + Mockito.when(node.getName()).thenReturn("test_nodes"); + listener.onFailure( + new ReceiveTimeoutTransportException(node, QUERY_ACTION_NAME, "request_id [171] timed out after [413ms]") + ); } else { if (catchExceptionWhenExecutePhaseOnShard) { try { @@ -227,6 +257,7 @@ protected void executePhaseOnShard( } } else { listener.onResponse(new QuerySearchResult()); + } } } @@ -587,6 +618,7 @@ public void onFailure(Exception e) { false, new AtomicLong(), new TaskResourceUsage(randomLong(), randomLong()), + false, shards ); action.run(); @@ -635,6 +667,7 @@ public void onFailure(Exception e) { false, new AtomicLong(), new TaskResourceUsage(randomLong(), randomLong()), + false, shards ); action.run(); @@ -688,6 +721,7 @@ public void onFailure(Exception e) { catchExceptionWhenExecutePhaseOnShard, new AtomicLong(), new TaskResourceUsage(randomLong(), randomLong()), + false, shards ); action.run(); @@ -791,6 +825,41 @@ public void testOnPhaseListenersWithDfsType() throws InterruptedException { assertEquals(0, testListener.getPhaseCurrent(searchDfsQueryThenFetchAsyncAction.getSearchPhaseName())); } + public void testExecutePhaseOnShardBlockAndRetrunPartialResult() { + // on shard is blocked in query phase + final Index index = new Index("test", UUID.randomUUID().toString()); + + final SearchShardIterator[] shards = IntStream.range(0, 2 + randomInt(4)) + .mapToObj(i -> new SearchShardIterator(null, new ShardId(index, i), List.of("n1"), null, null, null)) + .toArray(SearchShardIterator[]::new); + + SearchRequest searchRequest = new SearchRequest().allowPartialSearchResults(true); + searchRequest.source(new SearchSourceBuilder()); + long timeoutMills = 500; + searchRequest.source().timeout(new TimeValue(timeoutMills, TimeUnit.MILLISECONDS)); + searchRequest.setMaxConcurrentShardRequests(shards.length); + final AtomicBoolean successed = new AtomicBoolean(false); + long current = System.currentTimeMillis(); + + final ArraySearchPhaseResults queryResult = new ArraySearchPhaseResults<>(shards.length); + AbstractSearchAsyncAction action = createAction(searchRequest, queryResult, new ActionListener<>() { + @Override + public void onResponse(SearchResponse response) { + successed.set(true); + } + + @Override + public void onFailure(Exception e) { + successed.set(false); + } + }, false, false, false, new AtomicLong(), new TaskResourceUsage(randomLong(), randomLong()), true, shards); + action.run(); + long s = System.currentTimeMillis() - current; + assertTrue(s > timeoutMills); + assertTrue(successed.get()); + + } + private SearchDfsQueryThenFetchAsyncAction createSearchDfsQueryThenFetchAsyncAction( List searchRequestOperationsListeners ) { diff --git a/server/src/test/java/org/opensearch/action/search/SearchRequestTests.java b/server/src/test/java/org/opensearch/action/search/SearchRequestTests.java index acda1445bacbb..5b51f0d2f8f40 100644 --- a/server/src/test/java/org/opensearch/action/search/SearchRequestTests.java +++ b/server/src/test/java/org/opensearch/action/search/SearchRequestTests.java @@ -238,6 +238,15 @@ public void testValidate() throws IOException { assertEquals(1, validationErrors.validationErrors().size()); assertEquals("using [point in time] is not allowed in a scroll context", validationErrors.validationErrors().get(0)); } + + { + // queryPhaseTimeoutPercentage must be in (0, 1) + SearchRequest searchRequest = createSearchRequest().source(new SearchSourceBuilder().timeout(TimeValue.timeValueMillis(10))); + searchRequest.setQueryPhaseTimeoutPercentage(-1); + ActionRequestValidationException validationErrors = searchRequest.validate(); + assertNotNull(validationErrors); + assertEquals("[query_phase_timeout_percentage] must be in (0, 1]", validationErrors.validationErrors().get(0)); + } } public void testCopyConstructor() throws IOException { @@ -261,6 +270,19 @@ public void testParseSearchRequestWithUnsupportedSearchType() throws IOException assertEquals("Unsupported search type [query_and_fetch]", exception.getMessage()); } + public void testParseSearchRequestWithTimeoutAndQueryPhaseTimeoutPercentage() throws IOException { + RestRequest restRequest = new FakeRestRequest(); + SearchRequest searchRequest = createSearchRequest().source(new SearchSourceBuilder()); + IntConsumer setSize = mock(IntConsumer.class); + restRequest.params().put("query_phase_timeout_percentage", "30"); + + IllegalArgumentException exception = expectThrows( + IllegalArgumentException.class, + () -> RestSearchAction.parseSearchRequest(searchRequest, restRequest, null, namedWriteableRegistry, setSize) + ); + assertEquals("timeout must be set before setting queryPhaseTimeoutPercentage", exception.getMessage()); + } + public void testEqualsAndHashcode() throws IOException { checkEqualsAndHashCode(createSearchRequest(), SearchRequest::new, this::mutate); } diff --git a/server/src/test/java/org/opensearch/action/search/SearchTransportServiceTests.java b/server/src/test/java/org/opensearch/action/search/SearchTransportServiceTests.java new file mode 100644 index 0000000000000..e389ff4b50e5c --- /dev/null +++ b/server/src/test/java/org/opensearch/action/search/SearchTransportServiceTests.java @@ -0,0 +1,36 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.action.search; + +import org.opensearch.core.tasks.TaskCancelledException; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.transport.TransportRequestOptions; + +public class SearchTransportServiceTests extends OpenSearchTestCase { + public void testGetTransportRequestOptions() { + SearchTask searchTask = new SearchTask(1, null, null, null, null, null, null, System.currentTimeMillis(), 1000, 0.8f); + TransportRequestOptions transportRequestOptions = SearchTransportService.getTransportRequestOptions(searchTask, e -> {}, true); + assertTrue(transportRequestOptions.timeout().millis() > 0); + + TransportRequestOptions transportRequestOptions1 = SearchTransportService.getTransportRequestOptions(searchTask, e -> {}, false); + assertTrue(transportRequestOptions.timeout().millis() < transportRequestOptions1.timeout().millis()); + + SearchTask searchTask1 = new SearchTask(1, null, null, null, null, null, null, System.currentTimeMillis(), 1, 0.8f); + + transportRequestOptions = SearchTransportService.getTransportRequestOptions(searchTask1, exception -> { + assertEquals(TaskCancelledException.class, exception.getClass()); + assertTrue(exception.getMessage().contains("failed to execute fetch phase, timeout exceeded")); + }, true); + assertNull(transportRequestOptions); + + searchTask = new SearchTask(1, null, null, null, null, null, null, System.currentTimeMillis(), 0, 0.8f); + transportRequestOptions = SearchTransportService.getTransportRequestOptions(searchTask, e -> {}, false); + assertEquals(TransportRequestOptions.EMPTY, transportRequestOptions); + } +}