Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DO NOT MERGE] Hdfs-17531 rebase #7183

Closed
wants to merge 9 commits into from
Closed
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,8 @@ public class Client implements AutoCloseable {
private static final ThreadLocal<Integer> retryCount = new ThreadLocal<Integer>();
private static final ThreadLocal<Object> EXTERNAL_CALL_HANDLER
= new ThreadLocal<>();
private static final ThreadLocal<AsyncGet<? extends Writable, IOException>>
ASYNC_RPC_RESPONSE = new ThreadLocal<>();
private static final ThreadLocal<CompletableFuture<Writable>> ASYNC_RPC_RESPONSE
= new ThreadLocal<>();
private static final ThreadLocal<Boolean> asynchronousMode =
new ThreadLocal<Boolean>() {
@Override
Expand All @@ -110,7 +110,46 @@ protected Boolean initialValue() {
@Unstable
public static <T extends Writable> AsyncGet<T, IOException>
getAsyncRpcResponse() {
return (AsyncGet<T, IOException>) ASYNC_RPC_RESPONSE.get();
CompletableFuture<Writable> responseFuture = ASYNC_RPC_RESPONSE.get();
return new AsyncGet<T, IOException>() {
@Override
public T get(long timeout, TimeUnit unit)
throws IOException, TimeoutException, InterruptedException {
try {
if (unit == null || timeout < 0) {
return (T) responseFuture.get();
}
return (T) responseFuture.get(timeout, unit);
} catch (ExecutionException e) {
Throwable cause = e.getCause();
if (cause instanceof IOException) {
throw (IOException) cause;
}
throw new IllegalStateException(e);
}
}

@Override
public boolean isDone() {
return responseFuture.isDone();
}
};
}

/**
* Retrieves the current response future from the thread-local storage.
*
* @return A {@link CompletableFuture} of type T that represents the
* asynchronous operation. If no response future is present in
* the thread-local storage, this method returns {@code null}.
* @param <T> The type of the value completed by the returned
* {@link CompletableFuture}. It must be a subclass of
* {@link Writable}.
* @see CompletableFuture
* @see Writable
*/
public static <T extends Writable> CompletableFuture<T> getResponseFuture() {
return (CompletableFuture<T>) ASYNC_RPC_RESPONSE.get();
}

/**
Expand Down Expand Up @@ -277,10 +316,8 @@ static class Call {
final int id; // call id
final int retry; // retry count
final Writable rpcRequest; // the serialized rpc request
Writable rpcResponse; // null if rpc has error
IOException error; // exception, null if success
private final CompletableFuture<Writable> rpcResponseFuture;
final RPC.RpcKind rpcKind; // Rpc EngineKind
boolean done; // true when call is done
private final Object externalHandler;
private AlignmentContext alignmentContext;

Expand All @@ -304,6 +341,7 @@ private Call(RPC.RpcKind rpcKind, Writable param) {
}

this.externalHandler = EXTERNAL_CALL_HANDLER.get();
this.rpcResponseFuture = new CompletableFuture<>();
}

@Override
Expand All @@ -314,9 +352,6 @@ public String toString() {
/** Indicate when the call is complete and the
* value or error are available. Notifies by default. */
protected synchronized void callComplete() {
this.done = true;
notify(); // notify caller

if (externalHandler != null) {
synchronized (externalHandler) {
externalHandler.notify();
Expand All @@ -339,7 +374,7 @@ public synchronized void setAlignmentContext(AlignmentContext ac) {
* @param error exception thrown by the call; either local or remote
*/
public synchronized void setException(IOException error) {
this.error = error;
rpcResponseFuture.completeExceptionally(error);
callComplete();
}

Expand All @@ -349,13 +384,9 @@ public synchronized void setException(IOException error) {
* @param rpcResponse return value of the rpc call.
*/
public synchronized void setRpcResponse(Writable rpcResponse) {
this.rpcResponse = rpcResponse;
rpcResponseFuture.complete(rpcResponse);
callComplete();
}

public synchronized Writable getRpcResponse() {
return rpcResponse;
}
}

/** Thread that reads responses and notifies callers. Each connection owns a
Expand Down Expand Up @@ -1495,39 +1526,19 @@ Writable call(RPC.RpcKind rpcKind, Writable rpcRequest,
}

if (isAsynchronousMode()) {
final AsyncGet<Writable, IOException> asyncGet
= new AsyncGet<Writable, IOException>() {
@Override
public Writable get(long timeout, TimeUnit unit)
throws IOException, TimeoutException{
boolean done = true;
try {
final Writable w = getRpcResponse(call, connection, timeout, unit);
if (w == null) {
done = false;
throw new TimeoutException(call + " timed out "
+ timeout + " " + unit);
}
return w;
} finally {
if (done) {
releaseAsyncCall();
CompletableFuture<Writable> result = call.rpcResponseFuture.handle(
(rpcResponse, e) -> {
releaseAsyncCall();
if (e != null) {
IOException ioe = (IOException) e;
throw new CompletionException(warpIOException(ioe, connection));
}
}
}

@Override
public boolean isDone() {
synchronized (call) {
return call.done;
}
}
};

ASYNC_RPC_RESPONSE.set(asyncGet);
return rpcResponse;
});
ASYNC_RPC_RESPONSE.set(result);
return null;
} else {
return getRpcResponse(call, connection, -1, null);
return getRpcResponse(call, connection);
}
}

Expand Down Expand Up @@ -1564,37 +1575,34 @@ int getAsyncCallCount() {
}

/** @return the rpc response or, in case of timeout, null. */
private Writable getRpcResponse(final Call call, final Connection connection,
final long timeout, final TimeUnit unit) throws IOException {
synchronized (call) {
while (!call.done) {
try {
AsyncGet.Util.wait(call, timeout, unit);
if (timeout >= 0 && !call.done) {
return null;
}
} catch (InterruptedException ie) {
Thread.currentThread().interrupt();
throw new InterruptedIOException("Call interrupted");
}
private Writable getRpcResponse(final Call call, final Connection connection)
throws IOException {
try {
return call.rpcResponseFuture.get();
} catch (InterruptedException ie) {
Thread.currentThread().interrupt();
throw new InterruptedIOException("Call interrupted");
} catch (ExecutionException e) {
Throwable cause = e.getCause();
if (cause instanceof IOException) {
throw warpIOException((IOException) cause, connection);
}
throw new IllegalStateException(e);
}
}

if (call.error != null) {
if (call.error instanceof RemoteException ||
call.error instanceof SaslException) {
call.error.fillInStackTrace();
throw call.error;
} else { // local exception
InetSocketAddress address = connection.getRemoteAddress();
throw NetUtils.wrapException(address.getHostName(),
address.getPort(),
NetUtils.getHostname(),
0,
call.error);
}
} else {
return call.getRpcResponse();
}
private IOException warpIOException(IOException ioe, Connection connection) {
if (ioe instanceof RemoteException ||
ioe instanceof SaslException) {
ioe.fillInStackTrace();
return ioe;
} else { // local exception
InetSocketAddress address = connection.getRemoteAddress();
return NetUtils.wrapException(address.getHostName(),
address.getPort(),
NetUtils.getHostname(),
0,
ioe);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import org.apache.hadoop.ipc.protobuf.RpcHeaderProtos.RpcResponseHeaderProto;
import org.apache.hadoop.net.NetUtils;
import org.apache.hadoop.util.StringUtils;
import org.apache.hadoop.util.Time;
import org.apache.hadoop.util.concurrent.AsyncGetFuture;
import org.junit.Assert;
import org.junit.Before;
Expand All @@ -38,13 +39,16 @@
import java.io.IOException;
import java.net.InetSocketAddress;
import java.util.*;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;

public class TestAsyncIPC {

Expand Down Expand Up @@ -137,6 +141,60 @@ void assertReturnValues(long timeout, TimeUnit unit)
}
}

/**
* For testing the asynchronous calls of the RPC client
* implemented with CompletableFuture.
*/
static class AsyncCompletableFutureCaller extends Thread {
private final Client client;
private final InetSocketAddress server;
private final int count;
private final List<CompletableFuture<Writable>> completableFutures;
private final List<Long> expectedValues;

AsyncCompletableFutureCaller(Client client, InetSocketAddress server, int count) {
this.client = client;
this.server = server;
this.count = count;
this.completableFutures = new ArrayList<>(count);
this.expectedValues = new ArrayList<>(count);
setName("Async CompletableFuture Caller");
}

@Override
public void run() {
// Set the RPC client to use asynchronous mode.
Client.setAsynchronousMode(true);
long startTime = Time.monotonicNow();
try {
for (int i = 0; i < count; i++) {
final long param = TestIPC.RANDOM.nextLong();
TestIPC.call(client, param, server, conf);
expectedValues.add(param);
completableFutures.add(Client.getResponseFuture());
}
// Since the run method is asynchronous,
// it does not need to wait for a response after sending a request,
// so the time taken by the run method is less than count * 100
// (where 100 is the time taken by the server to process a request).
long cost = Time.monotonicNow() - startTime;
assertTrue(cost < count * 100L);
LOG.info("[{}] run cost {}ms", Thread.currentThread().getName(), cost);
} catch (Exception e) {
fail();
}
}

public void assertReturnValues()
throws InterruptedException, ExecutionException {
for (int i = 0; i < count; i++) {
LongWritable value = (LongWritable) completableFutures.get(i).get();
Assert.assertEquals("call" + i + " failed.",
expectedValues.get(i).longValue(), value.get());
}
}
}

static class AsyncLimitlCaller extends Thread {
private Client client;
private InetSocketAddress server;
Expand Down Expand Up @@ -538,4 +596,37 @@ public void run() {
assertEquals(startID + i, callIds.get(i).intValue());
}
}

@Test(timeout = 60000)
public void testAsyncCallWithCompletableFuture() throws IOException,
InterruptedException, ExecutionException {
// Override client to store the call id
final Client client = new Client(LongWritable.class, conf);

// Construct an RPC server, which includes a handler thread.
final TestServer server = new TestIPC.TestServer(1, false, conf);
server.callListener = () -> {
try {
// The server requires at least 100 milliseconds to process a request.
Thread.sleep(100);
} catch (InterruptedException e) {
throw new RuntimeException(e);
}
};

try {
InetSocketAddress addr = NetUtils.getConnectAddress(server);
server.start();
// Send 10 asynchronous requests.
final AsyncCompletableFutureCaller caller =
new AsyncCompletableFutureCaller(client, addr, 10);
caller.start();
caller.join();
// Check if the values returned by the asynchronous call meet the expected values.
caller.assertReturnValues();
} finally {
client.stop();
server.stop();
}
}
}
Loading
Loading