Skip to content

Commit

Permalink
pw_rpc: Temporarily ignore call IDs by default in Java client
Browse files Browse the repository at this point in the history
- The Java pw_rpc client gained support for call IDs in
  15d4ae5. For backwards compatibility,
  temporarily restore the prior behavior of ignoring call IDs by
  default. Support using call IDs by creating a client with the new
  Client.createMultiCall() function.
- Rename PendingRpc.DEFAULT_CALL_ID to Endpoint.FIRST_CALL_ID to clarify
  its meaning.
- Move TestClient.java's server packet creation code to Packets.java.

Bug: b/389777782
Change-Id: I2e15bdb1cd8a9b3e2463f3a0bb212ba6c5e893d6
Reviewed-on: https://pigweed-review.googlesource.com/c/pigweed/pigweed/+/260273
Commit-Queue: Auto-Submit <[email protected]>
Lint: Lint 🤖 <[email protected]>
Reviewed-by: Alexei Frolov <[email protected]>
Pigweed-Auto-Submit: Wyatt Hepler <[email protected]>
Docs-Not-Needed: Wyatt Hepler <[email protected]>
  • Loading branch information
255 authored and CQ Bot Account committed Jan 14, 2025
1 parent 587b9d7 commit c2de658
Show file tree
Hide file tree
Showing 13 changed files with 231 additions and 91 deletions.
1 change: 1 addition & 0 deletions pw_rpc/java/main/dev/pigweed/pw_rpc/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ RPC_CLIENT_SOURCES = [
"ChannelOutputException.java",
"Client.java",
"Endpoint.java",
"CallIdMode.java",
"FutureCall.java",
"Ids.java",
"InvalidRpcChannelException.java",
Expand Down
22 changes: 22 additions & 0 deletions pw_rpc/java/main/dev/pigweed/pw_rpc/CallIdMode.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
// Copyright 2025 The Pigweed Authors
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not
// use this file except in compliance with the License. You may obtain a copy of
// the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations under
// the License.

package dev.pigweed.pw_rpc;

/**
* Enum for whether or not call IDs are enabled for the client.
*
* TODO: b/389777782 - Remove this when call IDs are always enabled.
*/
enum CallIdMode { DISABLED, ENABLED }
81 changes: 62 additions & 19 deletions pw_rpc/java/main/dev/pigweed/pw_rpc/Client.java
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,24 @@
public class Client {
private static final Logger logger = Logger.forClass(Client.class);

private static final Function<RpcKey, StreamObserver<MessageLite>> LOG_DEFAULT_OBSERVER_FACTORY =
(rpc) -> new StreamObserver<MessageLite>() {
@Override
public void onNext(MessageLite value) {
logger.atFine().log("%s received response: %s", rpc, value);
}

@Override
public void onCompleted(Status status) {
logger.atInfo().log("%s completed with status %s", rpc, status);
}

@Override
public void onError(Status status) {
logger.atWarning().log("%s terminated with error %s", rpc, status);
}
};

private final Map<Integer, Service> services;
private final Endpoint endpoint;

Expand All @@ -48,17 +66,25 @@ public class Client {
* @param channels supported channels, which are used to send requests to the server
* @param services which RPC services this client supports; used to handle encoding and decoding
*/
private Client(List<Channel> channels,
private Client(CallIdMode callIdMode,
List<Channel> channels,
List<Service> services,
Function<RpcKey, StreamObserver<MessageLite>> defaultObserverFactory) {
this.services = services.stream().collect(Collectors.toMap(Service::id, s -> s));
this.endpoint = new Endpoint(channels);
this.endpoint = new Endpoint(callIdMode, channels);

this.defaultObserverFactory = defaultObserverFactory;
}

/**
* Creates a new pw_rpc client.
* Creates a new pw_rpc client that only supports one ongoing call per method.
*
* This creates a Client that does not use call IDs. Clients created with this function only
* support one ongoing call per channel / service / method.
*
* This function is deprecated and should not be used. New code should use createMultiCall.
* Migrate to createMultiCall, or use createLegacySingleCall to temporarily maintain the prior
* single call behavior.
*
* @param channels the set of channels for the client to send requests over
* @param services the services to support on this client
Expand All @@ -68,29 +94,41 @@ private Client(List<Channel> channels,
public static Client create(List<Channel> channels,
List<Service> services,
Function<RpcKey, StreamObserver<MessageLite>> defaultObserverFactory) {
return new Client(channels, services, defaultObserverFactory);
return new Client(CallIdMode.DISABLED, channels, services, defaultObserverFactory);
}

/**
* Creates a new pw_rpc client that logs responses when no observer is provided to calls.
* Creates a new single-call pw_rpc client that logs responses when no observer is provided to
* calls.
*
* This function is deprecated and should not be used. New code should use createMultiCall.
* Migrate to createMultiCall, or use createLegacySingleCall to temporarily maintain the prior
* single call behavior.
*
*/
public static Client create(List<Channel> channels, List<Service> services) {
return create(channels, services, (rpc) -> new StreamObserver<MessageLite>() {
@Override
public void onNext(MessageLite value) {
logger.atFine().log("%s received response: %s", rpc, value);
}
return create(channels, services, LOG_DEFAULT_OBSERVER_FACTORY);
}

@Override
public void onCompleted(Status status) {
logger.atInfo().log("%s completed with status %s", rpc, status);
}
/**
* Creates a new pw_rpc client.
*
* @param channels the set of channels for the client to send requests over
* @param services the services to support on this client
* @param defaultObserverFactory function that creates a default observer for each RPC
* @return the new pw.rpc.Client
*/
public static Client createMultiCall(List<Channel> channels,
List<Service> services,
Function<RpcKey, StreamObserver<MessageLite>> defaultObserverFactory) {
return new Client(CallIdMode.ENABLED, channels, services, defaultObserverFactory);
}

@Override
public void onError(Status status) {
logger.atWarning().log("%s terminated with error %s", rpc, status);
}
});
/**
* Creates a new pw_rpc client that logs responses when no observer is provided to calls.
*/
public static Client createMultiCall(List<Channel> channels, List<Service> services) {
return new Client(CallIdMode.ENABLED, channels, services, LOG_DEFAULT_OBSERVER_FACTORY);
}

/**
Expand Down Expand Up @@ -235,4 +273,9 @@ public boolean processPacket(ByteBuffer data) {
}
return endpoint.processClientPacket(method, packet);
}

/** Expose the Packets object for internal use by TestClient. */
Packets getPackets() {
return endpoint.getPackets();
}
}
33 changes: 22 additions & 11 deletions pw_rpc/java/main/dev/pigweed/pw_rpc/Endpoint.java
Original file line number Diff line number Diff line change
Expand Up @@ -46,19 +46,23 @@ class Endpoint {
// Call IDs are varint encoded. Limit the varint size to 2 bytes (14 usable bits).
private static final int MAX_CALL_ID = 1 << 14;

static final int FIRST_CALL_ID = 1;

private final Packets packets;
private final Map<Integer, Channel> channels;
private final Map<PendingRpc, AbstractCall<?, ?>> pending = new HashMap<>();
private final BlockingQueue<Runnable> callUpdates = new LinkedBlockingQueue<>();
private final int maxCallId;

@GuardedBy("this") private int nextCallId = 1;
@GuardedBy("this") private int nextCallId = FIRST_CALL_ID;

public Endpoint(List<Channel> channels) {
this(channels, MAX_CALL_ID);
Endpoint(CallIdMode callIdMode, List<Channel> channels) {
this(callIdMode, channels, MAX_CALL_ID);
}

/** Create endpoint with {@code maxCallId} possible call_ids for testing purposes */
Endpoint(List<Channel> channels, int maxCallId) {
Endpoint(CallIdMode callIdMode, List<Channel> channels, int maxCallId) {
this.packets = new Packets(callIdMode);
this.channels = channels.stream().collect(Collectors.toMap(Channel::id, c -> c));
this.maxCallId = maxCallId;
}
Expand All @@ -83,7 +87,7 @@ public Endpoint(List<Channel> channels) {

try {
// If sending the packet fails, the RPC is never considered pending.
call.rpc().channel().send(Packets.request(call.rpc(), request));
call.rpc().channel().send(packets.request(call.rpc(), request));
} catch (ChannelOutputException e) {
call.handleExceptionOnInitialPacket(e);
}
Expand Down Expand Up @@ -112,7 +116,9 @@ public Endpoint(List<Channel> channels) {
throw InvalidRpcChannelException.unknown(channelId);
}

return createCall.apply(this, PendingRpc.create(channel, method, getNewCallId()));
// Use 0 for call ID when IDs are disabled, which is equivalent to an unset ID in the packet.
int callId = packets.callIdsEnabled() ? getNewCallId() : 0;
return createCall.apply(this, PendingRpc.create(channel, method, callId));
}

private void registerCall(AbstractCall<?, ?> call) {
Expand Down Expand Up @@ -146,7 +152,7 @@ public boolean cancel(AbstractCall<?, ?> call) throws ChannelOutputException {
}

enqueueCallUpdate(() -> call.handleError(Status.CANCELLED));
call.sendPacket(Packets.cancel(call.rpc()));
call.sendPacket(packets.cancel(call.rpc()));
}
} finally {
logger.atFiner().log("Cancelling %s", call);
Expand All @@ -170,12 +176,12 @@ public boolean abandon(AbstractCall<?, ?> call) {

public synchronized boolean clientStream(AbstractCall<?, ?> call, MessageLite payload)
throws ChannelOutputException {
return sendPacket(call, Packets.clientStream(call.rpc(), payload));
return sendPacket(call, packets.clientStream(call.rpc(), payload));
}

public synchronized boolean clientStreamEnd(AbstractCall<?, ?> call)
throws ChannelOutputException {
return sendPacket(call, Packets.clientStreamEnd(call.rpc()));
return sendPacket(call, packets.clientStreamEnd(call.rpc()));
}

private boolean sendPacket(AbstractCall<?, ?> call, byte[] packet) throws ChannelOutputException {
Expand Down Expand Up @@ -299,9 +305,9 @@ private boolean updateCall(RpcPacket packet, PendingRpc rpc) {
return true;
}

private static void sendError(Channel channel, RpcPacket packet, Status status) {
private void sendError(Channel channel, RpcPacket packet, Status status) {
try {
channel.send(Packets.error(packet, status));
channel.send(packets.error(packet, status));
} catch (ChannelOutputException e) {
logger.atWarning().withCause(e).log("Failed to send error packet");
}
Expand Down Expand Up @@ -329,4 +335,9 @@ private synchronized int getNewCallId() {
}
return callId;
}

/** Expose the Packets object for internal use by TestClient. */
Packets getPackets() {
return packets;
}
}
64 changes: 48 additions & 16 deletions pw_rpc/java/main/dev/pigweed/pw_rpc/Packets.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,19 @@

/** Encodes pw_rpc packets of various types. */
/* package */ class Packets {
private Packets() {}
private final CallIdMode callIdMode;

public static byte[] request(PendingRpc rpc, MessageLite payload) {
RpcPacket.Builder builder = RpcPacket.newBuilder()
public Packets(CallIdMode callIdMode) {
this.callIdMode = callIdMode;
}

public boolean callIdsEnabled() {
return callIdMode == CallIdMode.ENABLED;
}

public byte[] request(PendingRpc rpc, MessageLite payload) {
RpcPacket.Builder builder = newBuilder(rpc.callId())
.setType(PacketType.REQUEST)
.setCallId(rpc.callId())
.setChannelId(rpc.channel().id())
.setServiceId(rpc.service().id())
.setMethodId(rpc.method().id());
Expand All @@ -35,10 +42,9 @@ public static byte[] request(PendingRpc rpc, MessageLite payload) {
return builder.build().toByteArray();
}

public static byte[] cancel(PendingRpc rpc) {
return RpcPacket.newBuilder()
public byte[] cancel(PendingRpc rpc) {
return newBuilder(rpc.callId())
.setType(PacketType.CLIENT_ERROR)
.setCallId(rpc.callId())
.setChannelId(rpc.channel().id())
.setServiceId(rpc.service().id())
.setMethodId(rpc.method().id())
Expand All @@ -47,10 +53,9 @@ public static byte[] cancel(PendingRpc rpc) {
.toByteArray();
}

public static byte[] error(RpcPacket packet, Status status) {
return RpcPacket.newBuilder()
public byte[] error(RpcPacket packet, Status status) {
return newBuilder(packet.getCallId())
.setType(PacketType.CLIENT_ERROR)
.setCallId(packet.getCallId())
.setChannelId(packet.getChannelId())
.setServiceId(packet.getServiceId())
.setMethodId(packet.getMethodId())
Expand All @@ -59,10 +64,9 @@ public static byte[] error(RpcPacket packet, Status status) {
.toByteArray();
}

public static byte[] clientStream(PendingRpc rpc, MessageLite payload) {
return RpcPacket.newBuilder()
public byte[] clientStream(PendingRpc rpc, MessageLite payload) {
return newBuilder(rpc.callId())
.setType(PacketType.CLIENT_STREAM)
.setCallId(rpc.callId())
.setChannelId(rpc.channel().id())
.setServiceId(rpc.service().id())
.setMethodId(rpc.method().id())
Expand All @@ -71,14 +75,42 @@ public static byte[] clientStream(PendingRpc rpc, MessageLite payload) {
.toByteArray();
}

public static byte[] clientStreamEnd(PendingRpc rpc) {
return RpcPacket.newBuilder()
public byte[] clientStreamEnd(PendingRpc rpc) {
return newBuilder(rpc.callId())
.setType(PacketType.CLIENT_REQUEST_COMPLETION)
.setCallId(rpc.callId())
.setChannelId(rpc.channel().id())
.setServiceId(rpc.service().id())
.setMethodId(rpc.method().id())
.build()
.toByteArray();
}

public byte[] serverError(
int channelId, String service, String method, int callId, Status error) {
return newBuilder(callId)
.setType(PacketType.SERVER_ERROR)
.setChannelId(channelId)
.setServiceId(Ids.calculate(service))
.setMethodId(Ids.calculate(method))
.setStatus(error.code())
.build()
.toByteArray();
}

public RpcPacket.Builder startServerStream(
int channelId, String service, String method, int callId) {
return newBuilder(callId)
.setType(PacketType.SERVER_STREAM)
.setChannelId(channelId)
.setServiceId(Ids.calculate(service))
.setMethodId(Ids.calculate(method));
}

private RpcPacket.Builder newBuilder(int callId) {
RpcPacket.Builder builder = RpcPacket.newBuilder();
if (callIdMode == CallIdMode.ENABLED) {
builder.setCallId(callId);
}
return builder;
}
}
3 changes: 0 additions & 3 deletions pw_rpc/java/main/dev/pigweed/pw_rpc/PendingRpc.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,6 @@
/** Represents an active RPC invocation: channel + service + method + call id. */
@AutoValue
abstract class PendingRpc {
// The default call id should always be 1 since it is the first id that is chosen by the endpoint.
static final int DEFAULT_CALL_ID = 1;

static PendingRpc create(Channel channel, Method method, int callId) {
return new AutoValue_PendingRpc(channel, method, callId);
}
Expand Down
Loading

0 comments on commit c2de658

Please sign in to comment.