Skip to content

Commit

Permalink
Add more integration tests for flight producer
Browse files Browse the repository at this point in the history
Signed-off-by: Rishabh Maurya <[email protected]>
  • Loading branch information
rishabhmaurya committed Jan 31, 2025
1 parent 2a6590f commit 60a3586
Show file tree
Hide file tree
Showing 24 changed files with 101 additions and 52 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -159,8 +159,8 @@ interface FlushSignal {
/**
* Blocks until the current batch has been consumed or timeout occurs.
*
* @param timeout Maximum milliseconds to wait
* @param timeout Maximum time to wait
*/
void awaitConsumption(int timeout);
void awaitConsumption(TimeValue timeout);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ public void testFlightStreamReader() throws Exception {
// The reader should be accessible from any node in the cluster due to the use of ProxyStreamProducer
try (StreamReader reader = streamManagerCurrentNode.getStreamReader(ticket)) {
int totalBatches = 0;
assertNotNull(reader.getRoot().getVector("docID"));
while (reader.next()) {
IntVector docIDVector = (IntVector) reader.getRoot().getVector("docID");
assertEquals(10, docIDVector.getValueCount());
Expand Down Expand Up @@ -131,7 +132,6 @@ public void testEarlyCancel() throws Exception {
readerThread.start();
assertTrue("Reader thread did not complete in time", readerComplete.await(1, TimeUnit.SECONDS));

// Check for any exceptions in reader thread
if (readerException.get() != null) {
throw readerException.get();
}
Expand All @@ -147,7 +147,7 @@ public void testEarlyCancel() throws Exception {
reader.close();
}

// Wait for onCancel to complete
// Wait for close to complete
// Due to https://github.com/grpc/grpc-java/issues/5882, FlightStream.java contains logic
// that exhausts the stream on the server side before it is actually cancelled.
assertTrue(
Expand All @@ -158,6 +158,38 @@ public void testEarlyCancel() throws Exception {
}
}

/**
 * Checks that a producer-side failure is surfaced to a remote reader.
 *
 * <p>Walks the cluster nodes pairwise: each node after the first acts as the server that
 * registers an error-producing stream, while the node visited just before it acts as the
 * client that reads the stream. The reader is expected to deliver exactly one batch of ten
 * values and then fail with a {@code FlightRuntimeException} carrying the {@code INTERNAL}
 * status code.
 */
public void testFlightStreamServerError() throws Exception {
    DiscoveryNode clientNode = null;
    for (DiscoveryNode serverNode : getClusterState().nodes()) {
        // The first node only seeds the client side of the first pair.
        if (clientNode != null) {
            StreamManager serverStreams = getStreamManager(serverNode.getName());
            TestStreamProducer failingProducer = getStreamProducer();
            failingProducer.setProduceError(true);
            StreamTicket ticket = serverStreams.registerStream(failingProducer, null);

            StreamManager clientStreams = getStreamManager(clientNode.getName());
            try (StreamReader reader = clientStreams.getStreamReader(ticket)) {
                int batchesRead = 0;
                assertNotNull(reader.getRoot().getVector("docID"));
                try {
                    while (reader.next()) {
                        IntVector docIds = (IntVector) reader.getRoot().getVector("docID");
                        assertEquals(10, docIds.getValueCount());
                        batchesRead += 1;
                    }
                    // Reaching the end of the stream without an exception is a test failure.
                    fail("Expected FlightRuntimeException");
                } catch (FlightRuntimeException serverFailure) {
                    String statusName = serverFailure.status().code().name();
                    assertEquals("INTERNAL", statusName);
                    assertEquals("There was an error servicing your request.", serverFailure.getMessage());
                }
                assertEquals(1, batchesRead);
            }
        }
        clientNode = serverNode;
    }
}

public void testFlightGetInfo() throws Exception {
StreamTicket ticket = null;
for (DiscoveryNode node : getClusterState().nodes()) {
Expand Down Expand Up @@ -193,6 +225,14 @@ private TestStreamProducer getStreamProducer() {
private static class TestStreamProducer implements StreamProducer {
volatile boolean isClosed = false;
private final CountDownLatch closeLatch = new CountDownLatch(1);
TimeValue deadline = TimeValue.timeValueSeconds(5);
private volatile boolean produceError = false;

/**
 * Enables or disables error injection for this test producer.
 *
 * @param produceError when {@code true}, the batch-producing loop throws a
 *                     {@link RuntimeException} once it has finished a batch of ten values
 */
public void setProduceError(boolean produceError) {
this.produceError = produceError;
}

TestStreamProducer() {}

VectorSchemaRoot root;

Expand All @@ -214,9 +254,12 @@ public void run(VectorSchemaRoot root, FlushSignal flushSignal) {
for (int i = 0; i < 100; i++) {
docIDVector.setSafe(i % 10, i);
if ((i + 1) % 10 == 0) {
flushSignal.awaitConsumption(1000);
flushSignal.awaitConsumption(TimeValue.timeValueMillis(1000));
docIDVector.clear();
root.setRowCount(10);
if (produceError) {
throw new RuntimeException("Server error while producing batch");
}
}
}
}
Expand All @@ -236,7 +279,7 @@ public boolean isCancelled() {

@Override
public TimeValue getJobDeadline() {
return TimeValue.timeValueSeconds(5);
return deadline;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/
package org.opensearch.arrow.flight.api;
package org.opensearch.arrow.flight.api.flightinfo;

import org.opensearch.client.node.NodeClient;
import org.opensearch.rest.BaseRestHandler;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
* compatible open source license.
*/

package org.opensearch.arrow.flight.api;
package org.opensearch.arrow.flight.api.flightinfo;

import org.opensearch.action.support.nodes.BaseNodeResponse;
import org.opensearch.cluster.node.DiscoveryNode;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
* compatible open source license.
*/

package org.opensearch.arrow.flight.api;
package org.opensearch.arrow.flight.api.flightinfo;

import org.opensearch.action.ActionType;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
* compatible open source license.
*/

package org.opensearch.arrow.flight.api;
package org.opensearch.arrow.flight.api.flightinfo;

import org.opensearch.action.support.nodes.BaseNodesRequest;
import org.opensearch.core.common.io.stream.StreamInput;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
* compatible open source license.
*/

package org.opensearch.arrow.flight.api;
package org.opensearch.arrow.flight.api.flightinfo;

import org.opensearch.action.FailedNodeException;
import org.opensearch.action.support.nodes.BaseNodesResponse;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
* compatible open source license.
*/

package org.opensearch.arrow.flight.api;
package org.opensearch.arrow.flight.api.flightinfo;

import org.opensearch.action.FailedNodeException;
import org.opensearch.action.support.ActionFilters;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,4 @@
/**
* Action to retrieve flight info from nodes
*/
package org.opensearch.arrow.flight.api;
package org.opensearch.arrow.flight.api.flightinfo;
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.Version;
import org.opensearch.arrow.flight.api.NodeFlightInfo;
import org.opensearch.arrow.flight.api.NodesFlightInfoAction;
import org.opensearch.arrow.flight.api.NodesFlightInfoRequest;
import org.opensearch.arrow.flight.api.NodesFlightInfoResponse;
import org.opensearch.arrow.flight.api.flightinfo.NodeFlightInfo;
import org.opensearch.arrow.flight.api.flightinfo.NodesFlightInfoAction;
import org.opensearch.arrow.flight.api.flightinfo.NodesFlightInfoRequest;
import org.opensearch.arrow.flight.api.flightinfo.NodesFlightInfoResponse;
import org.opensearch.arrow.flight.bootstrap.tls.SslContextProvider;
import org.opensearch.client.Client;
import org.opensearch.cluster.ClusterChangedEvent;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@
package org.opensearch.arrow.flight.bootstrap;

import org.opensearch.arrow.flight.BaseFlightStreamPlugin;
import org.opensearch.arrow.flight.api.FlightServerInfoAction;
import org.opensearch.arrow.flight.api.NodesFlightInfoAction;
import org.opensearch.arrow.flight.api.TransportNodesFlightInfoAction;
import org.opensearch.arrow.flight.api.flightinfo.FlightServerInfoAction;
import org.opensearch.arrow.flight.api.flightinfo.NodesFlightInfoAction;
import org.opensearch.arrow.flight.api.flightinfo.TransportNodesFlightInfoAction;
import org.opensearch.arrow.spi.StreamManager;
import org.opensearch.client.Client;
import org.opensearch.cluster.metadata.IndexNameExpressionResolver;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ public void getStream(CallContext context, Ticket ticket, ServerStreamListener l
BackpressureStrategy backpressureStrategy = new BaseBackpressureStrategy(null, batchedJob::onCancel);
backpressureStrategy.register(listener);
StreamProducer.FlushSignal flushSignal = (timeout) -> {
BackpressureStrategy.WaitResult result = backpressureStrategy.waitForListener(timeout);
BackpressureStrategy.WaitResult result = backpressureStrategy.waitForListener(timeout.millis());
if (result.equals(BackpressureStrategy.WaitResult.READY)) {
listener.putNext();
} else if (result.equals(BackpressureStrategy.WaitResult.TIMEOUT)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@ public class FlightStreamManager implements StreamManager {
private final Supplier<BufferAllocator> allocatorSupplier;
private final Cache<String, StreamProducerHolder> streamProducers;
// TODO read from setting
private static final TimeValue DEFAULT_CACHE_EXPIRE = TimeValue.timeValueMinutes(10); // Maximum cache time
private static final TimeValue DEFAULT_CACHE_EXPIRE = TimeValue.timeValueMinutes(10);
private static final int MAX_WEIGHT = 1000;

/**
* Holds a StreamProducer along with its metadata and resources
Expand Down Expand Up @@ -99,6 +100,7 @@ public FlightStreamManager(Supplier<BufferAllocator> allocatorSupplier) {
};
this.streamProducers = CacheBuilder.<String, StreamProducerHolder>builder()
.setExpireAfterWrite(DEFAULT_CACHE_EXPIRE)
.setMaximumWeight(MAX_WEIGHT)
.removalListener(onProducerRemoval)
.build();
}
Expand Down Expand Up @@ -196,7 +198,6 @@ public Optional<StreamProducerHolder> removeStreamProducer(StreamTicket ticket)
*/
@Override
public void close() throws Exception {
// TODO: logic to cancel all threads and clear the streamManager queue
streamProducers.values().forEach(holder -> {
try {
holder.producer().close();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ static class ProxyBatchedJob implements BatchedJob {
@Override
public void run(VectorSchemaRoot root, FlushSignal flushSignal) throws Exception {
while (!isCancelled.get() && remoteStream.next()) {
flushSignal.awaitConsumption(1000);
flushSignal.awaitConsumption(TimeValue.timeValueMillis(1000));
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@

package org.opensearch.arrow.flight;

import org.opensearch.arrow.flight.api.FlightServerInfoAction;
import org.opensearch.arrow.flight.api.NodesFlightInfoAction;
import org.opensearch.arrow.flight.api.flightinfo.FlightServerInfoAction;
import org.opensearch.arrow.flight.api.flightinfo.NodesFlightInfoAction;
import org.opensearch.arrow.flight.bootstrap.FlightService;
import org.opensearch.arrow.spi.StreamManager;
import org.opensearch.cluster.ClusterState;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
* compatible open source license.
*/

package org.opensearch.arrow.flight.api;
package org.opensearch.arrow.flight.api.flightinfo;

import org.opensearch.cluster.ClusterName;
import org.opensearch.cluster.node.DiscoveryNode;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
* compatible open source license.
*/

package org.opensearch.arrow.flight.api;
package org.opensearch.arrow.flight.api.flightinfo;

import org.opensearch.Version;
import org.opensearch.cluster.node.DiscoveryNode;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
* compatible open source license.
*/

package org.opensearch.arrow.flight.api;
package org.opensearch.arrow.flight.api.flightinfo;

import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.core.common.io.stream.StreamInput;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
* compatible open source license.
*/

package org.opensearch.arrow.flight.api;
package org.opensearch.arrow.flight.api.flightinfo;

import org.opensearch.Version;
import org.opensearch.action.FailedNodeException;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
package org.opensearch.arrow.flight.api;/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

package org.opensearch.arrow.flight.api.flightinfo;

import org.opensearch.Version;
import org.opensearch.action.FailedNodeException;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.opensearch.Version;
import org.opensearch.arrow.flight.api.NodeFlightInfo;
import org.opensearch.arrow.flight.api.NodesFlightInfoAction;
import org.opensearch.arrow.flight.api.NodesFlightInfoRequest;
import org.opensearch.arrow.flight.api.NodesFlightInfoResponse;
import org.opensearch.arrow.flight.api.flightinfo.NodeFlightInfo;
import org.opensearch.arrow.flight.api.flightinfo.NodesFlightInfoAction;
import org.opensearch.arrow.flight.api.flightinfo.NodesFlightInfoRequest;
import org.opensearch.arrow.flight.api.flightinfo.NodesFlightInfoResponse;
import org.opensearch.arrow.flight.bootstrap.tls.SslContextProvider;
import org.opensearch.client.Client;
import org.opensearch.cluster.ClusterChangedEvent;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import org.apache.arrow.vector.ipc.message.IpcOption;
import org.opensearch.arrow.flight.bootstrap.FlightClientManager;
import org.opensearch.arrow.spi.StreamProducer;
import org.opensearch.common.unit.TimeValue;
import org.opensearch.common.util.FeatureFlags;
import org.opensearch.test.FeatureFlagSetter;
import org.opensearch.test.OpenSearchTestCase;
Expand Down Expand Up @@ -172,7 +173,7 @@ public void testGetStream_SuccessfulFlow() throws Exception {
});
listener.setReady(false);
clientThread.start();
flushSignal.awaitConsumption(100);
flushSignal.awaitConsumption(TimeValue.timeValueMillis(100));
assertTrue(listener.getDataConsumed());
flushCount.incrementAndGet();
listener.resetConsumptionLatch();
Expand Down Expand Up @@ -215,7 +216,7 @@ public void testGetStream_WithSlowClient() throws Exception {
});
listener.setReady(false);
clientThread.start();
flushSignal.awaitConsumption(300); // waiting for consumption for more than client thread sleep
flushSignal.awaitConsumption(TimeValue.timeValueMillis(300)); // waiting for consumption for more than client thread sleep
assertTrue(listener.getDataConsumed());
flushCount.incrementAndGet();
listener.resetConsumptionLatch();
Expand Down Expand Up @@ -258,7 +259,7 @@ public void testGetStream_WithSlowClientTimeout() throws Exception {
});
listener.setReady(false);
clientThread.start();
flushSignal.awaitConsumption(100); // waiting for consumption for less than client thread sleep
flushSignal.awaitConsumption(TimeValue.timeValueMillis(100)); // waiting for consumption for less than client thread sleep
assertTrue(listener.getDataConsumed());
flushCount.incrementAndGet();
listener.resetConsumptionLatch();
Expand Down Expand Up @@ -302,7 +303,7 @@ public void testGetStream_WithClientCancel() throws Exception {
});
listener.setReady(false);
clientThread.start();
flushSignal.awaitConsumption(100); // waiting for consumption for less than client thread sleep
flushSignal.awaitConsumption(TimeValue.timeValueMillis(100)); // waiting for consumption for less than client thread sleep
assertTrue(listener.getDataConsumed());
flushCount.incrementAndGet();
listener.resetConsumptionLatch();
Expand Down Expand Up @@ -340,7 +341,7 @@ public void testGetStream_WithUnresponsiveClient() throws Exception {
});
listener.setReady(false);
clientThread.start();
flushSignal.awaitConsumption(100); // waiting for consumption for less than client thread sleep
flushSignal.awaitConsumption(TimeValue.timeValueMillis(100)); // waiting for consumption for less than client thread sleep
assertTrue(listener.getDataConsumed());
flushCount.incrementAndGet();
listener.resetConsumptionLatch();
Expand Down Expand Up @@ -380,7 +381,7 @@ public void testGetStream_WithServerBackpressure() throws Exception {
listener.setReady(false);
clientThread.start();
Thread.sleep(100); // simulating writer backpressure
flushSignal.awaitConsumption(100);
flushSignal.awaitConsumption(TimeValue.timeValueMillis(100));
assertTrue(listener.getDataConsumed());
flushCount.incrementAndGet();
listener.resetConsumptionLatch();
Expand Down Expand Up @@ -421,7 +422,7 @@ public void testGetStream_WithServerError() throws Exception {
if (i == 4) {
throw new RuntimeException("Server error");
}
flushSignal.awaitConsumption(100);
flushSignal.awaitConsumption(TimeValue.timeValueMillis(100));
assertTrue(listener.getDataConsumed());
flushCount.incrementAndGet();
listener.resetConsumptionLatch();
Expand Down
Loading

0 comments on commit 60a3586

Please sign in to comment.