Skip to content

Commit

Permalink
Merge pull request #4 from omnia-network/release/0.3.2
Browse files Browse the repository at this point in the history
release/0.3.2
  • Loading branch information
ilbertt authored Dec 15, 2023
2 parents 736a4e1 + 658de2c commit 344bd5e
Show file tree
Hide file tree
Showing 13 changed files with 269 additions and 113 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/release.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ jobs:

- uses: aviate-labs/[email protected]
with:
dfx-version: 0.15.1
dfx-version: 0.15.2
env:
DFX_IDENTITY_PEM: ${{ secrets.DFX_IDENTITY_PEM }}

Expand Down
3 changes: 2 additions & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ jobs:

- uses: aviate-labs/[email protected]
with:
dfx-version: 0.15.1
dfx-version: 0.15.2

# rust toolchain is needed for integration tests
- uses: actions-rs/toolchain@v1
Expand All @@ -46,6 +46,7 @@ jobs:
- name: Run integration tests
run: |
export POCKET_IC_MUTE_SERVER=1
export POCKET_IC_BIN="$(pwd)/bin/pocket-ic"
export TEST_CANISTER_WASM_PATH="$(pwd)/bin/test_canister.wasm"
cd tests/ic-websocket-cdk-rs
Expand Down
2 changes: 1 addition & 1 deletion mops.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "ic-websocket-cdk"
version = "0.3.1"
version = "0.3.2"
description = "IC WebSocket Motoko CDK"
repository = "https://github.com/omnia-network/ic-websocket-cdk-mo"
keywords = ["ic", "websocket", "motoko", "cdk"]
Expand Down
18 changes: 0 additions & 18 deletions scripts/download-pocket-ic.sh

This file was deleted.

1 change: 1 addition & 0 deletions scripts/download-pocket-ic.sh
1 change: 1 addition & 0 deletions scripts/test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ set -e

./scripts/build-test-canister.sh

export POCKET_IC_MUTE_SERVER=1
export POCKET_IC_BIN="$(pwd)/bin/pocket-ic"
export TEST_CANISTER_WASM_PATH="$(pwd)/bin/test_canister.wasm"

Expand Down
10 changes: 8 additions & 2 deletions src/Constants.mo
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,14 @@ module {
public let DEFAULT_MAX_NUMBER_OF_RETURNED_MESSAGES : Nat = 50;
/// The default interval at which to send acknowledgements to the client.
public let DEFAULT_SEND_ACK_INTERVAL_MS : Nat64 = 300_000; // 5 minutes
/// The default timeout to wait for the client to send a keep alive after receiving an acknowledgement.
public let DEFAULT_CLIENT_KEEP_ALIVE_TIMEOUT_MS : Nat64 = 60_000; // 1 minute
/// The maximum communication latency allowed between the client and the canister.
public let COMMUNICATION_LATENCY_BOUND_MS : Nat64 = 30_000; // 30 seconds
public class Computed() {
/// The default timeout to wait for the client to send a keep alive after receiving an acknowledgement.
public let CLIENT_KEEP_ALIVE_TIMEOUT_MS : Nat64 = 2 * COMMUNICATION_LATENCY_BOUND_MS;
/// Same as [CLIENT_KEEP_ALIVE_TIMEOUT_MS], but in nanoseconds.
public let CLIENT_KEEP_ALIVE_TIMEOUT_NS : Nat64 = CLIENT_KEEP_ALIVE_TIMEOUT_MS * 1_000_000;
};

/// The initial nonce for outgoing messages.
public let INITIAL_OUTGOING_MESSAGE_NONCE : Nat64 = 0;
Expand Down
129 changes: 112 additions & 17 deletions src/State.mo
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import Nat64 "mo:base/Nat64";
import Text "mo:base/Text";
import Blob "mo:base/Blob";
import CertifiedData "mo:base/CertifiedData";
import Buffer "mo:base/Buffer";
import CertTree "mo:ic-certification/CertTree";
import Sha256 "mo:sha2/Sha256";

Expand All @@ -21,7 +22,8 @@ import Utils "Utils";
module {
type CanisterOutputMessage = Types.CanisterOutputMessage;
type CanisterWsGetMessagesResult = Types.CanisterWsGetMessagesResult;
type CanisterWsSendResult = Types.CanisterWsSendResult;
type CanisterCloseResult = Types.CanisterCloseResult;
type CanisterSendResult = Types.CanisterSendResult;
type ClientKey = Types.ClientKey;
type ClientPrincipal = Types.ClientPrincipal;
type GatewayPrincipal = Types.GatewayPrincipal;
Expand Down Expand Up @@ -58,6 +60,8 @@ module {
var CERT_TREE = CertTree.Ops(CERT_TREE_STORE);
/// Keeps track of the principals of the WS Gateways that poll the canister.
var REGISTERED_GATEWAYS = HashMap.HashMap<GatewayPrincipal, RegisteredGateway>(0, Principal.equal, Principal.hash);
/// Keeps track of the gateways that must be removed from the list of registered gateways in the next ack interval
var GATEWAYS_TO_REMOVE = HashMap.HashMap<GatewayPrincipal, Types.TimestampNs>(0, Principal.equal, Principal.hash);
/// The acknowledgement active timer.
public var ACK_TIMER : ?Timer.TimerId = null;
/// The keep alive active timer.
Expand All @@ -68,7 +72,7 @@ module {
public func reset_internal_state(handlers : WsHandlers) : async () {
// for each client, call the on_close handler before clearing the map
for (client_key in REGISTERED_CLIENTS.keys()) {
await remove_client(client_key, handlers);
await remove_client(client_key, ?handlers, null);
};

// make sure all the maps are cleared
Expand All @@ -79,11 +83,14 @@ module {
CERT_TREE_STORE := CertTree.newStore();
CERT_TREE := CertTree.Ops(CERT_TREE_STORE);
REGISTERED_GATEWAYS := HashMap.HashMap<GatewayPrincipal, RegisteredGateway>(0, Principal.equal, Principal.hash);
GATEWAYS_TO_REMOVE := HashMap.HashMap<GatewayPrincipal, Types.TimestampNs>(0, Principal.equal, Principal.hash);
};

/// Increments the clients connected count for the given gateway.
/// If the gateway is not registered, a new entry is created with a clients connected count of 1.
func increment_gateway_clients_count(gateway_principal : GatewayPrincipal) {
ignore GATEWAYS_TO_REMOVE.remove(gateway_principal);

switch (REGISTERED_GATEWAYS.get(gateway_principal)) {
case (?registered_gateway) {
registered_gateway.increment_clients_count();
Expand All @@ -96,18 +103,61 @@ module {
};
};

/// Decrements the clients connected count for the given gateway.
/// If there are no more clients connected, the gateway is removed from the list of registered gateways.
/// Decrements the clients connected count for the given gateway, if it exists.
///
/// If the gateway has no more clients connected, it is added to the [GATEWAYS_TO_REMOVE] map,
/// in order to remove it in the next keep alive check.
func decrement_gateway_clients_count(gateway_principal : GatewayPrincipal) {
switch (REGISTERED_GATEWAYS.get(gateway_principal)) {
case (?registered_gateway) {
let clients_count = registered_gateway.decrement_clients_count();

if (clients_count == 0) {
REGISTERED_GATEWAYS.delete(gateway_principal);
GATEWAYS_TO_REMOVE.put(gateway_principal, Utils.get_current_time());
};
};
case (null) {
Prelude.unreachable(); // gateway must be registered at this point
// do nothing
};
};
};

/// Removes the gateways that were added to the [GATEWAYS_TO_REMOVE] map
/// more than the ack interval ms time ago from the list of registered gateways
public func remove_empty_expired_gateways() {
let ack_interval_ms = init_params.send_ack_interval_ms;
let time = Utils.get_current_time();

let gateway_principals_to_remove : Buffer.Buffer<GatewayPrincipal> = Buffer.Buffer(GATEWAYS_TO_REMOVE.size());
GATEWAYS_TO_REMOVE := HashMap.mapFilter(
GATEWAYS_TO_REMOVE,
Principal.equal,
Principal.hash,
func(gp : GatewayPrincipal, added_at : Types.TimestampNs) : ?Types.TimestampNs {
if (time - added_at > (ack_interval_ms * 1_000_000)) {
gateway_principals_to_remove.add(gp);
null;
} else {
?added_at;
};
},
);

for (gateway_principal in gateway_principals_to_remove.vals()) {
switch (
Option.map(
REGISTERED_GATEWAYS.remove(gateway_principal),
func(g : RegisteredGateway) : List.List<Text> {
List.map(g.messages_queue, func(m : CanisterOutputMessage) : Text { m.key });
},
)
) {
case (?messages_keys_to_delete) {
delete_keys_from_cert_tree(messages_keys_to_delete);
};
case (null) {
// do nothing
};
};
};
};
Expand Down Expand Up @@ -253,9 +303,22 @@ module {
};

/// Removes a client from the internal state
/// and call the on_close callback,
/// and call the on_close callback (if handlers are provided),
/// if the client was registered in the state.
public func remove_client(client_key : ClientKey, handlers : WsHandlers) : async () {
///
/// If a `close_reason` is provided, it also sends a close message to the client,
/// so that the client can close the WS connection with the gateway.
public func remove_client(client_key : ClientKey, handlers : ?WsHandlers, close_reason : ?Types.CloseMessageReason) : async () {
switch (close_reason) {
case (?close_reason) {
// ignore the error
ignore send_service_message_to_client(client_key, #CloseMessage({ reason = close_reason }));
};
case (null) {
// Do nothing
};
};

CLIENTS_WAITING_FOR_KEEP_ALIVE := TrieSet.delete(CLIENTS_WAITING_FOR_KEEP_ALIVE, client_key, Types.hashClientKey(client_key), Types.areClientKeysEqual);
CURRENT_CLIENT_KEY_MAP.delete(client_key.client_principal);
OUTGOING_MESSAGE_TO_CLIENT_NUM_MAP.delete(client_key);
Expand All @@ -265,9 +328,16 @@ module {
case (?registered_client) {
decrement_gateway_clients_count(registered_client.gateway_principal);

await handlers.call_on_close({
client_principal = client_key.client_principal;
});
switch (handlers) {
case (?handlers) {
await handlers.call_on_close({
client_principal = client_key.client_principal;
});
};
case (null) {
// Do nothing
};
};
};
case (null) {
// Do nothing
Expand Down Expand Up @@ -406,7 +476,7 @@ module {
switch (get_registered_gateway(gateway_principal)) {
case (#Ok(registered_gateway)) {
// messages in the queue are inserted with contiguous and increasing nonces
// (from beginning to end of the queue) as ws_send is called sequentially, the nonce
// (from beginning to end of the queue) as `send` is called sequentially, the nonce
// is incremented by one in each call, and the message is pushed at the end of the queue
registered_gateway.add_message_to_queue(message, message_timestamp);
#Ok;
Expand All @@ -426,13 +496,23 @@ module {
case (#Err(err)) { return #Err(err) };
};

for (key in Iter.fromList(deleted_messages_keys)) {
CERT_TREE.delete([Text.encodeUtf8(key)]);
};
delete_keys_from_cert_tree(deleted_messages_keys);

#Ok;
};

func delete_keys_from_cert_tree(keys : List.List<Text>) {
let root_hash = do {
for (key in Iter.fromList(keys)) {
CERT_TREE.delete([Text.encodeUtf8(key)]);
};
labeledHash(Constants.LABEL_WEBSOCKET, CERT_TREE.treeHash());
};

// certify data with the new root hash
CertifiedData.set(root_hash);
};

func get_cert_for_range(keys : Iter.Iter<CertTree.Path>) : (Blob, Blob) {
let witness = CERT_TREE.reveals(keys);
let tree : CertTree.Witness = #labeled(Constants.LABEL_WEBSOCKET, witness);
Expand Down Expand Up @@ -485,7 +565,7 @@ module {
};

/// Internal function used to put the messages in the outgoing messages queue and certify them.
public func _ws_send(client_key : ClientKey, msg_bytes : Blob, is_service_message : Bool) : CanisterWsSendResult {
public func _ws_send(client_key : ClientKey, msg_bytes : Blob, is_service_message : Bool) : CanisterSendResult {
// get the registered client if it exists
let registered_client = switch (get_registered_client(client_key)) {
case (#Err(err)) {
Expand Down Expand Up @@ -579,7 +659,7 @@ module {
);
};

public func _ws_send_to_client_principal(client_principal : ClientPrincipal, msg_bytes : Blob) : CanisterWsSendResult {
public func _ws_send_to_client_principal(client_principal : ClientPrincipal, msg_bytes : Blob) : CanisterSendResult {
let client_key = switch (get_client_key_from_principal(client_principal)) {
case (#Err(err)) {
return #Err(err);
Expand All @@ -590,5 +670,20 @@ module {
};
_ws_send(client_key, msg_bytes, false);
};

public func _close_for_client_principal(client_principal : ClientPrincipal, handlers : ?WsHandlers) : async CanisterCloseResult {
let client_key = switch (get_client_key_from_principal(client_principal)) {
case (#Err(err)) {
return #Err(err);
};
case (#Ok(client_key)) {
client_key;
};
};

await remove_client(client_key, handlers, ? #ClosedByApplication);

#Ok;
};
};
};
23 changes: 14 additions & 9 deletions src/Timers.mo
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import Nat64 "mo:base/Nat64";
import Array "mo:base/Array";
import TrieSet "mo:base/TrieSet";

import Constants "Constants";
import Types "Types";
import State "State";
import Utils "Utils";
Expand Down Expand Up @@ -49,13 +50,13 @@ module {
///
/// The interval callback is [send_ack_to_clients_timer_callback]. After the callback is executed,
/// a timer is scheduled to check if the registered clients have sent a keep alive message.
public func schedule_send_ack_to_clients(ws_state : State.IcWebSocketState, ack_interval_ms : Nat64, keep_alive_timeout_ms : Nat64, handlers : Types.WsHandlers) {
public func schedule_send_ack_to_clients(ws_state : State.IcWebSocketState, ack_interval_ms : Nat64, handlers : Types.WsHandlers) {
let timer_id = Timer.recurringTimer(
#nanoseconds(Nat64.toNat(ack_interval_ms * 1_000_000)),
func() : async () {
send_ack_to_clients_timer_callback(ws_state, ack_interval_ms);

schedule_check_keep_alive(ws_state, keep_alive_timeout_ms, handlers);
schedule_check_keep_alive(ws_state, handlers);
},
);

Expand All @@ -66,11 +67,11 @@ module {
/// after receiving an acknowledgement message.
///
/// The timer callback is [check_keep_alive_timer_callback].
func schedule_check_keep_alive(ws_state : State.IcWebSocketState, keep_alive_timeout_ms : Nat64, handlers : Types.WsHandlers) {
func schedule_check_keep_alive(ws_state : State.IcWebSocketState, handlers : Types.WsHandlers) {
let timer_id = Timer.setTimer(
#nanoseconds(Nat64.toNat(keep_alive_timeout_ms * 1_000_000)),
#nanoseconds(Nat64.toNat(Constants.Computed().CLIENT_KEEP_ALIVE_TIMEOUT_NS)),
func() : async () {
await check_keep_alive_timer_callback(ws_state, keep_alive_timeout_ms, handlers);
await check_keep_alive_timer_callback(ws_state, handlers);
},
);

Expand Down Expand Up @@ -113,17 +114,21 @@ module {

/// Checks if the clients for which we are waiting for keep alive have sent a keep alive message.
/// If a client has not sent a keep alive message, it is removed from the connected clients.
func check_keep_alive_timer_callback(ws_state : State.IcWebSocketState, keep_alive_timeout_ms : Nat64, handlers : Types.WsHandlers) : async () {
///
/// Before checking the clients, it removes all the empty expired gateways from the list of registered gateways.
func check_keep_alive_timer_callback(ws_state : State.IcWebSocketState, handlers : Types.WsHandlers) : async () {
ws_state.remove_empty_expired_gateways();

for (client_key in Array.vals(TrieSet.toArray(ws_state.CLIENTS_WAITING_FOR_KEEP_ALIVE))) {
// get the last keep alive timestamp for the client and check if it has exceeded the timeout
switch (ws_state.REGISTERED_CLIENTS.get(client_key)) {
case (?client_metadata) {
let last_keep_alive = client_metadata.get_last_keep_alive_timestamp();

if (Utils.get_current_time() - last_keep_alive > (keep_alive_timeout_ms * 1_000_000)) {
await ws_state.remove_client(client_key, handlers);
if (Utils.get_current_time() - last_keep_alive > Constants.Computed().CLIENT_KEEP_ALIVE_TIMEOUT_NS) {
await ws_state.remove_client(client_key, ?handlers, ? #KeepAliveTimeout);

Utils.custom_print("[check-keep-alive-timer-cb]: Client " # Types.clientKeyToText(client_key) # " has not sent a keep alive message in the last " # debug_show (keep_alive_timeout_ms) # " ms and has been removed");
Utils.custom_print("[check-keep-alive-timer-cb]: Client " # Types.clientKeyToText(client_key) # " has not sent a keep alive message in the last " # debug_show (Constants.Computed().CLIENT_KEEP_ALIVE_TIMEOUT_MS) # " ms and has been removed");
};
};
case (null) {
Expand Down
Loading

0 comments on commit 344bd5e

Please sign in to comment.