Skip to content

Commit

Permalink
Merge pull request #1 from omnia-network/fix/keep-alive-for-new-clients
Browse files Browse the repository at this point in the history
Fix/keep alive for new clients
  • Loading branch information
ilbertt authored Oct 18, 2023
2 parents a00feb4 + 457b1f8 commit 0d560ec
Show file tree
Hide file tree
Showing 4 changed files with 235 additions and 51 deletions.
66 changes: 66 additions & 0 deletions .github/workflows/release.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
name: Release ic-websocket-cdk-mo

# only run when the tests complete
on:
workflow_run:
workflows: [ic-websocket-cdk-mo tests]
types:
- completed
branches:
- main

jobs:
publish:
runs-on: ubuntu-latest
# only run if the tests were successful
if: ${{ github.event.workflow_run.conclusion == 'success' }}
outputs:
version: ${{ steps.npm-publish.outputs.version }}
steps:
- uses: actions/checkout@v3
- uses: actions/setup-node@v3
with:
node-version: 18

- uses: aviate-labs/[email protected]
with:
dfx-version: 0.14.1
env:
DFX_IDENTITY_PEM: ${{ secrets.DFX_IDENTITY_PEM }}

- name: install mops
run: npm i ic-mops -g

- run: |
dfx identity use action
mops import-identity --no-encrypt -- "$(dfx identity export action)"
mops publish --no-docs
echo "version=$(cat mops.toml | grep "version =" | cut -d\" -f2)" >> "$GITHUB_OUTPUT"
env:
CARGO_REGISTRY_TOKEN: ${{ secrets.CARGO_REGISTRY_TOKEN }}
tag:
needs: publish
runs-on: ubuntu-latest
outputs:
version: ${{ steps.tag_version.outputs.new_tag }}
steps:
- name: Checkout
uses: actions/checkout@v3
- name: Bump version and push tag
id: tag_version
uses: mathieudutour/[email protected]
with:
github_token: ${{ secrets.GITHUB_TOKEN }}
custom_tag: ${{ needs.publish.outputs.version }}

release:
needs: tag
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v3
- name: Release
uses: softprops/action-gh-release@v1
with:
tag_name: ${{ needs.tag.outputs.version }}
2 changes: 1 addition & 1 deletion mops.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "ic-websocket-cdk"
version = "0.1.0"
version = "0.1.1"
description = "IC WebSocket Motoko CDK"
repository = "https://github.com/omnia-network/ic-websocket-cdk-mo"
keywords = ["ic", "websocket", "motoko", "cdk"]
Expand Down
95 changes: 73 additions & 22 deletions src/lib.mo
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import Array "mo:base/Array";
import Blob "mo:base/Blob";
import CertifiedData "mo:base/CertifiedData";
import Debug "mo:base/Debug";
import Deque "mo:base/Deque";
import HashMap "mo:base/HashMap";
import Hash "mo:base/Hash";
Expand All @@ -17,6 +18,7 @@ import Time "mo:base/Time";
import Timer "mo:base/Timer";
import Bool "mo:base/Bool";
import Error "mo:base/Error";
import TrieSet "mo:base/TrieSet";
import CborValue "mo:cbor/Value";
import CborDecoder "mo:cbor/Decoder";
import CborEncoder "mo:cbor/Encoder";
Expand All @@ -42,6 +44,14 @@ module {
/// The default timeout to wait for the client to send a keep alive after receiving an acknowledgement.
let DEFAULT_CLIENT_KEEP_ALIVE_TIMEOUT_MS : Nat64 = 10_000; // 10 seconds

/// The initial nonce for outgoing messages.
let INITIAL_OUTGOING_MESSAGE_NONCE : Nat64 = 0;
/// The initial sequence number to expect from messages coming from clients.
/// The first message coming from the client will have sequence number `1` because on the client the sequence number is incremented before sending the message.
let INITIAL_CLIENT_SEQUENCE_NUM : Nat64 = 1;
/// The initial sequence number for outgoing messages.
let INITIAL_CANISTER_SEQUENCE_NUM : Nat64 = 0;

//// TYPES ////
type CandidType = Type.Type;
type CandidValue = Value.Value;
Expand Down Expand Up @@ -588,6 +598,8 @@ module {
public var REGISTERED_CLIENTS = HashMap.HashMap<ClientKey, RegisteredClient>(0, areClientKeysEqual, hashClientKey);
/// Maps the client's principal to the current client key
public var CURRENT_CLIENT_KEY_MAP = HashMap.HashMap<ClientPrincipal, ClientKey>(0, Principal.equal, Principal.hash);
/// Keeps track of all the clients for which we're waiting for a keep alive message.
public var CLIENTS_WAITING_FOR_KEEP_ALIVE : TrieSet.Set<ClientKey> = TrieSet.empty();
/// Maps the client's public key to the sequence number to use for the next outgoing message (to that client).
public var OUTGOING_MESSAGE_TO_CLIENT_NUM_MAP = HashMap.HashMap<ClientKey, Nat64>(0, areClientKeysEqual, hashClientKey);
/// Maps the client's public key to the expected sequence number of the next incoming message (from that client).
Expand All @@ -602,7 +614,7 @@ module {
/// Keeps track of the nonce which:
/// - the WS Gateway uses to specify the first index of the certified messages to be returned when polling
/// - the client uses as part of the path in the Merkle tree in order to verify the certificate of the messages relayed by the WS Gateway
public var OUTGOING_MESSAGE_NONCE : Nat64 = 0;
public var OUTGOING_MESSAGE_NONCE : Nat64 = INITIAL_OUTGOING_MESSAGE_NONCE;
/// The acknowledgement active timer.
public var ACK_TIMER : ?Timer.TimerId = null;
// /// The keep alive active timer.
Expand All @@ -616,13 +628,15 @@ module {
await remove_client(client_key, handlers);
};

// make sure all the maps are cleared
CURRENT_CLIENT_KEY_MAP := HashMap.HashMap<ClientPrincipal, ClientKey>(0, Principal.equal, Principal.hash);
CLIENTS_WAITING_FOR_KEEP_ALIVE := TrieSet.empty<ClientKey>();
OUTGOING_MESSAGE_TO_CLIENT_NUM_MAP := HashMap.HashMap<ClientKey, Nat64>(0, areClientKeysEqual, hashClientKey);
INCOMING_MESSAGE_FROM_CLIENT_NUM_MAP := HashMap.HashMap<ClientKey, Nat64>(0, areClientKeysEqual, hashClientKey);
CERT_TREE_STORE := CertTree.newStore();
CERT_TREE := CertTree.Ops(CERT_TREE_STORE);
MESSAGES_FOR_GATEWAY := List.nil<CanisterOutputMessage>();
OUTGOING_MESSAGE_NONCE := 0;
OUTGOING_MESSAGE_NONCE := INITIAL_OUTGOING_MESSAGE_NONCE;
};

public func get_outgoing_message_nonce() : Nat64 {
Expand Down Expand Up @@ -657,12 +671,16 @@ module {
#Ok;
};

public func add_client_to_wait_for_keep_alive(client_key : ClientKey) {
CLIENTS_WAITING_FOR_KEEP_ALIVE := TrieSet.put<ClientKey>(CLIENTS_WAITING_FOR_KEEP_ALIVE, client_key, hashClientKey(client_key), areClientKeysEqual);
};

public func get_registered_gateway_principal() : Principal {
REGISTERED_GATEWAY.gateway_principal;
};

func init_outgoing_message_to_client_num(client_key : ClientKey) {
OUTGOING_MESSAGE_TO_CLIENT_NUM_MAP.put(client_key, 0);
OUTGOING_MESSAGE_TO_CLIENT_NUM_MAP.put(client_key, INITIAL_CANISTER_SEQUENCE_NUM);
};

public func get_outgoing_message_to_client_num(client_key : ClientKey) : Result<Nat64, Text> {
Expand All @@ -684,7 +702,7 @@ module {
};

func init_expected_incoming_message_from_client_num(client_key : ClientKey) {
INCOMING_MESSAGE_FROM_CLIENT_NUM_MAP.put(client_key, 1);
INCOMING_MESSAGE_FROM_CLIENT_NUM_MAP.put(client_key, INITIAL_CLIENT_SEQUENCE_NUM);
};

public func get_expected_incoming_message_from_client_num(client_key : ClientKey) : Result<Nat64, Text> {
Expand Down Expand Up @@ -715,6 +733,7 @@ module {
};

public func remove_client(client_key : ClientKey, handlers : WsHandlers) : async () {
CLIENTS_WAITING_FOR_KEEP_ALIVE := TrieSet.delete(CLIENTS_WAITING_FOR_KEEP_ALIVE, client_key, hashClientKey(client_key), areClientKeysEqual);
CURRENT_CLIENT_KEY_MAP.delete(client_key.client_principal);
REGISTERED_CLIENTS.delete(client_key);
OUTGOING_MESSAGE_TO_CLIENT_NUM_MAP.delete(client_key);
Expand Down Expand Up @@ -902,37 +921,36 @@ module {
reset_keep_alive_timer();
};

/// Schedules a timer to send an acknowledgement message to the client.
/// Start an interval to send an acknowledgement messages to the clients.
///
/// The timer callback is [send_ack_to_clients_timer_callback]. After the callback is executed,
/// The interval callback is [send_ack_to_clients_timer_callback]. After the callback is executed,
/// a timer is scheduled to check if the registered clients have sent a keep alive message.
public func schedule_send_ack_to_clients(send_ack_interval_ms : Nat64, keep_alive_timeout_ms : Nat64, handlers : WsHandlers) {
let timer_id = Timer.setTimer(
let timer_id = Timer.recurringTimer(
#nanoseconds(Nat64.toNat(send_ack_interval_ms) * 1_000_000),
func() : async () {
send_ack_to_clients_timer_callback();

schedule_check_keep_alive(send_ack_interval_ms, keep_alive_timeout_ms, handlers);
schedule_check_keep_alive(keep_alive_timeout_ms, handlers);
},
);

put_ack_timet_id(timer_id);
};

/// Schedules a timer to check if the registered clients have sent a keep alive message
/// Schedules a timer to check if the clients (only those to which an ack message was sent) have sent a keep alive message
/// after receiving an acknowledgement message.
///
/// The timer callback is [check_keep_alive_timer_callback]. After the callback is executed,
/// a timer is scheduled again to send an acknowledgement message to the registered clients.
func schedule_check_keep_alive(send_ack_interval_ms : Nat64, keep_alive_timeout_ms : Nat64, handlers : WsHandlers) {
/// The timer callback is [check_keep_alive_timer_callback].
func schedule_check_keep_alive(keep_alive_timeout_ms : Nat64, handlers : WsHandlers) {
let timer_id = Timer.setTimer(
#nanoseconds(Nat64.toNat(keep_alive_timeout_ms) * 1_000_000),
func() : async () {
await check_keep_alive_timer_callback(keep_alive_timeout_ms, handlers);

schedule_send_ack_to_clients(send_ack_interval_ms, keep_alive_timeout_ms, handlers);
},
);

put_keep_alive_timer_id(timer_id);
};

/// Sends an acknowledgement message to the client.
Expand All @@ -943,6 +961,7 @@ module {
switch (get_expected_incoming_message_from_client_num(client_key)) {
case (#Ok(expected_incoming_sequence_num)) {
let ack_message : CanisterAckMessageContent = {
// the expected sequence number is 1 more because it's incremented when a message is received
last_incoming_sequence_num = expected_incoming_sequence_num - 1;
};
let message : WebsocketServiceMessageContent = #AckMessage(ack_message);
Expand All @@ -953,7 +972,7 @@ module {
Logger.custom_print("[ack-to-clients-timer-cb]: Error sending ack message to client" # clientKeyToText(client_key) # ": " # err);
};
case (#Ok(_)) {
// Do nothing
add_client_to_wait_for_keep_alive(client_key);
};
};
};
Expand All @@ -967,16 +986,24 @@ module {
Logger.custom_print("[ack-to-clients-timer-cb]: Sent ack messages to all clients");
};

/// Checks if the registered clients have sent a keep alive message.
/// If a client has not sent a keep alive message, it is removed from the registered clients.
/// Checks if the clients for which we are waiting for keep alive have sent a keep alive message.
/// If a client has not sent a keep alive message, it is removed from the connected clients.
func check_keep_alive_timer_callback(keep_alive_timeout_ms : Nat64, handlers : WsHandlers) : async () {
for ((client_key, client_metadata) in REGISTERED_CLIENTS.entries()) {
let last_keep_alive = client_metadata.last_keep_alive_timestamp;
for (client_key in Array.vals(TrieSet.toArray(CLIENTS_WAITING_FOR_KEEP_ALIVE))) {
let client_metadata = REGISTERED_CLIENTS.get(client_key);
switch (client_metadata) {
case (?client_metadata) {
let last_keep_alive = client_metadata.last_keep_alive_timestamp;

if (get_current_time() - last_keep_alive > keep_alive_timeout_ms * 1_000_000) {
await remove_client(client_key, handlers);
if (get_current_time() - last_keep_alive > keep_alive_timeout_ms * 1_000_000) {
await remove_client(client_key, handlers);

Logger.custom_print("[check-keep-alive-timer-cb]: Client " # clientKeyToText(client_key) # " has not sent a keep alive message in the last " # debug_show (keep_alive_timeout_ms) # " ms and has been removed");
Logger.custom_print("[check-keep-alive-timer-cb]: Client " # clientKeyToText(client_key) # " has not sent a keep alive message in the last " # debug_show (keep_alive_timeout_ms) # " ms and has been removed");
};
};
case (null) {
// Do nothing
};
};
};

Expand Down Expand Up @@ -1114,12 +1141,18 @@ module {
};
/// The interval at which to send an acknowledgement message to the client,
/// so that the client knows that all the messages it sent have been received by the canister (in milliseconds).
///
/// Must be greater than `keep_alive_timeout_ms`.
///
/// Defaults to `60_000` (60 seconds).
public var send_ack_interval_ms : Nat64 = switch (init_send_ack_interval_ms) {
case (?value) { value };
case (null) { DEFAULT_SEND_ACK_INTERVAL_MS };
};
/// The delay to wait for the client to send a keep alive after receiving an acknowledgement (in milliseconds).
///
/// Must be lower than `send_ack_interval_ms`.
///
/// Defaults to `10_000` (10 seconds).
public var keep_alive_timeout_ms : Nat64 = switch (init_keep_alive_timeout_ms) {
case (?value) { value };
Expand All @@ -1129,15 +1162,33 @@ module {
public func get_handlers() : WsHandlers {
return handlers;
};

/// Checks the validity of the timer parameters.
/// `send_ack_interval_ms` must be greater than `keep_alive_timeout_ms`.
///
/// # Traps
/// If `send_ack_interval_ms` < `keep_alive_timeout_ms`.
public func check_validity() {
if (keep_alive_timeout_ms > send_ack_interval_ms) {
Debug.trap("send_ack_interval_ms must be greater than keep_alive_timeout_ms");
};
};
};

/// The IC WebSocket instance.
///
/// # Traps
/// If the parameters are invalid. See [`WsInitParams::check_validity`] for more details.
public class IcWebSocket(init_ws_state : IcWebSocketState, params : WsInitParams) {
/// The state of the IC WebSocket.
private var WS_STATE : IcWebSocketState = init_ws_state;
/// The callback handlers for the WebSocket.
private var HANDLERS : WsHandlers = params.get_handlers();

do {
// check if the parameters are valid
params.check_validity();

// reset initial timers
WS_STATE.reset_timers();

Expand Down
Loading

0 comments on commit 0d560ec

Please sign in to comment.