From c37e805e9e8bae487665686c4d43051d03d1494a Mon Sep 17 00:00:00 2001 From: Luca8991 Date: Wed, 29 Nov 2023 22:39:02 +0100 Subject: [PATCH 1/8] chore: disable tests on draft PRs --- .github/workflows/tests.yml | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 32f7d23..1e9203f 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -5,11 +5,16 @@ on: branches: - main pull_request: + types: + - opened + - synchronize + - reopened + - ready_for_review jobs: test: + if: github.event.pull_request.draft == false runs-on: ubuntu-latest - steps: - uses: actions/checkout@v4 with: From 5881a5fb3f3d752dabacc67b552c130d27a98e6f Mon Sep 17 00:00:00 2001 From: Luca8991 Date: Sun, 3 Dec 2023 18:11:47 +0100 Subject: [PATCH 2/8] chore: update readme --- README.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 8d9e207..d163388 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ [![mops](https://oknww-riaaa-aaaam-qaf6a-cai.raw.ic0.app/badge/mops/ic-websocket-cdk)](https://mops.one/ic-websocket-cdk) -This repository contains the Motoko implementation of IC WebSocket CDK. For more information about IC WebSockets, see [IC WebSocket Gateway](https://github.com/omnia-network/ic-websocket-gateway). +This repository contains the Motoko implementation of IC WebSocket CDK and is basically a mirror of the [Rust CDK](https://github.com/omnia-network/ic-websocket-cdk-rs). For more information about IC WebSockets, see [IC WebSocket Gateway](https://github.com/omnia-network/ic-websocket-gateway). ## Installation @@ -14,7 +14,7 @@ mops add ic-websocket-cdk ## Usage -Refer to the [ic-websockets-pingpong-mo](https://github.com/iamenochchirima/ic-websockets-pingpong-mo) repository for an example of how to use this library. +Refer to the [ic-websockets-pingpong-mo](https://github.com/iamenochchirima/ic-websockets-pingpong-mo) and/or [ic-websockets-chat-mo](https://github.com/iamenochchirima/ic-websockets-chat-mo) repositories for examples of how to use this library. ### Candid interface In order for the frontend clients and the Gateway to work properly, the canister must expose some specific methods in its Candid interface, between the custom methods that you've implemented for your logic. A valid Candid interface for the canister is the following: @@ -23,7 +23,7 @@ In order for the frontend clients and the Gateway to work properly, the canister import "./ws_types.did"; // define here your message type -type MyMessageType = { +type MyMessageType = record { some_field : text; }; @@ -53,7 +53,7 @@ Clone the repo with submodules: git clone --recurse-submodules https://github.com/omnia-network/ic-websocket-cdk-mo.git ``` -Integration tests are imported from the [IC WebSocket Rust CDK](https://github.com/omnia-network/ic-websocket-cdk-rs.git), linked to this repo from the `ic-websocket-cdk-rs` submodule in the [tests](./tests/) folder. +Integration tests are imported from the [IC WebSocket Rust CDK](https://github.com/omnia-network/ic-websocket-cdk-rs.git), linked to this repo through the [`ic-websocket-cdk-rs`](./tests/ic-websocket-cdk-rs) submodule. There's a script that runs the integration tests, taking care of installing dependencies and setting up the local environment. To run the script, execute the following command: From 8b58f0040fe6a5420800a62c3486137b75759930 Mon Sep 17 00:00:00 2001 From: Luca8991 Date: Sun, 3 Dec 2023 23:12:35 +0100 Subject: [PATCH 3/8] refactor: align to Rust CDK https://github.com/omnia-network/ic-websocket-cdk-rs/commit/fe627fbbc9ef7e872b7ff291da9c5d20c2459599 --- src/Constants.mo | 21 + src/Errors.mo | 88 ++ src/Logger.mo | 8 - src/State.mo | 592 ++++++++ src/Timers.mo | 137 ++ src/Types.mo | 557 +++++++ src/Utils.mo | 18 + src/lib.mo | 1292 ++--------------- tests/test_canister/src/test_canister/main.mo | 22 +- 9 files changed, 1537 insertions(+), 1198 deletions(-) create mode 100644 src/Constants.mo create mode 100644 src/Errors.mo delete mode 100644 src/Logger.mo create mode 100644 src/State.mo create mode 100644 src/Timers.mo create mode 100644 src/Types.mo create mode 100644 src/Utils.mo diff --git a/src/Constants.mo b/src/Constants.mo new file mode 100644 index 0000000..e242e72 --- /dev/null +++ b/src/Constants.mo @@ -0,0 +1,21 @@ +module { + /// The label used when constructing the certification tree. + public let LABEL_WEBSOCKET : Blob = "websocket"; + /// The default maximum number of messages returned by [ws_get_messages] at each poll. + public let DEFAULT_MAX_NUMBER_OF_RETURNED_MESSAGES : Nat = 50; + /// The default interval at which to send acknowledgements to the client. + public let DEFAULT_SEND_ACK_INTERVAL_MS : Nat64 = 300_000; // 5 minutes + /// The default timeout to wait for the client to send a keep alive after receiving an acknowledgement. + public let DEFAULT_CLIENT_KEEP_ALIVE_TIMEOUT_MS : Nat64 = 60_000; // 1 minute + + /// The initial nonce for outgoing messages. + public let INITIAL_OUTGOING_MESSAGE_NONCE : Nat64 = 0; + /// The initial sequence number to expect from messages coming from clients. + /// The first message coming from the client will have sequence number `1` because on the client the sequence number is incremented before sending the message. + public let INITIAL_CLIENT_SEQUENCE_NUM : Nat64 = 1; + /// The initial sequence number for outgoing messages. + public let INITIAL_CANISTER_SEQUENCE_NUM : Nat64 = 0; + + /// The number of messages to delete from the outgoing messages queue every time a new message is added. + public let MESSAGES_TO_DELETE_COUNT : Nat = 5; +}; diff --git a/src/Errors.mo b/src/Errors.mo new file mode 100644 index 0000000..a0904aa --- /dev/null +++ b/src/Errors.mo @@ -0,0 +1,88 @@ +import Principal "mo:base/Principal"; +import Nat64 "mo:base/Nat64"; + +import Types "Types"; + +module { + public type WsError = { + #AnonymousPrincipalNotAllowed; + #ClientKeyAlreadyConnected : { + client_key : Types.ClientKey; + }; + #ClientKeyMessageMismatch : { + client_key : Types.ClientKey; + }; + #ClientKeyNotConnected : { + client_key : Types.ClientKey; + }; + #ClientNotRegisteredToGateway : { + client_key : Types.ClientKey; + gateway_principal : Types.GatewayPrincipal; + }; + #ClientPrincipalNotConnected : { + client_principal : Types.ClientPrincipal; + }; + #DecodeServiceMessageContent : { + err : Text; + }; + #ExpectedIncomingMessageToClientNumNotInitialized : { + client_key : Types.ClientKey; + }; + #GatewayNotRegistered : { + gateway_principal : Types.GatewayPrincipal; + }; + #InvalidServiceMessage; + #IncomingSequenceNumberWrong : { + expected_sequence_num : Nat64; + actual_sequence_num : Nat64; + }; + #OutgoingMessageToClientNumNotInitialized : { + client_key : Types.ClientKey; + }; + }; + + public func to_string(err : WsError) : Text { + switch (err) { + case (#AnonymousPrincipalNotAllowed) { + "Anonymous principal is not allowed"; + }; + case (#ClientKeyAlreadyConnected({ client_key })) { + "Client with key " # Types.clientKeyToText(client_key) # " already has an open connection"; + }; + case (#ClientKeyMessageMismatch({ client_key })) { + "Client with principal " # Principal.toText(client_key.client_principal) # " has a different key than the one used in the message"; + }; + case (#ClientKeyNotConnected({ client_key })) { + "Client with key " # Types.clientKeyToText(client_key) # " doesn't have an open connection"; + }; + case (#ClientNotRegisteredToGateway({ client_key; gateway_principal })) { + "Client with key " # Types.clientKeyToText(client_key) # " was not registered to gateway " # Principal.toText(gateway_principal); + }; + case (#ClientPrincipalNotConnected({ client_principal })) { + "Client with principal " # Principal.toText(client_principal) # " doesn't have an open connection"; + }; + case (#DecodeServiceMessageContent({ err })) { + "Error decoding service message content: " # err; + }; + case (#ExpectedIncomingMessageToClientNumNotInitialized({ client_key })) { + "Expected incoming message to client num not initialized for client key " # Types.clientKeyToText(client_key); + }; + case (#GatewayNotRegistered({ gateway_principal })) { + "Gateway with principal " # Principal.toText(gateway_principal) # " is not registered"; + }; + case (#InvalidServiceMessage) { + "Invalid service message"; + }; + case (#IncomingSequenceNumberWrong({ expected_sequence_num; actual_sequence_num })) { + "Expected incoming sequence number " # Nat64.toText(expected_sequence_num) # " but got " # Nat64.toText(actual_sequence_num); + }; + case (#OutgoingMessageToClientNumNotInitialized({ client_key })) { + "Outgoing message to client num not initialized for client key " # Types.clientKeyToText(client_key); + }; + }; + }; + + public func to_string_result(err : WsError) : Types.Result<(), Text> { + #Err(to_string(err)); + }; +}; diff --git a/src/Logger.mo b/src/Logger.mo deleted file mode 100644 index 22b2d1f..0000000 --- a/src/Logger.mo +++ /dev/null @@ -1,8 +0,0 @@ -import Text "mo:base/Text"; -import Debug "mo:base/Debug"; - -module { - public func custom_print(s : Text) { - Debug.print(Text.concat("[IC-WEBSOCKET-CDK]: ", s)); - }; -}; diff --git a/src/State.mo b/src/State.mo new file mode 100644 index 0000000..55daff6 --- /dev/null +++ b/src/State.mo @@ -0,0 +1,592 @@ +import HashMap "mo:base/HashMap"; +import TrieSet "mo:base/TrieSet"; +import Timer "mo:base/Timer"; +import List "mo:base/List"; +import Iter "mo:base/Iter"; +import Principal "mo:base/Principal"; +import Prelude "mo:base/Prelude"; +import Option "mo:base/Option"; +import Nat64 "mo:base/Nat64"; +import Text "mo:base/Text"; +import Blob "mo:base/Blob"; +import CertifiedData "mo:base/CertifiedData"; +import CertTree "mo:ic-certification/CertTree"; +import Sha256 "mo:sha2/Sha256"; + +import Constants "Constants"; +import Errors "Errors"; +import Types "Types"; +import Utils "Utils"; + +module { + type CanisterOutputMessage = Types.CanisterOutputMessage; + type CanisterWsGetMessagesResult = Types.CanisterWsGetMessagesResult; + type CanisterWsSendResult = Types.CanisterWsSendResult; + type ClientKey = Types.ClientKey; + type ClientPrincipal = Types.ClientPrincipal; + type GatewayPrincipal = Types.GatewayPrincipal; + type RegisteredClient = Types.RegisteredClient; + type RegisteredGateway = Types.RegisteredGateway; + type Result = Types.Result; + type WsInitParams = Types.WsInitParams; + type WsHandlers = Types.WsHandlers; + + /// IC WebSocket class that holds the internal state of the IC WebSocket. + /// + /// Arguments: + /// + /// - `init_params`: `WsInitParams`. + /// + /// **Note**: you should only pass an instance of this class to the IcWebSocket class constructor, without using the methods or accessing the fields directly. + /// + /// # Traps + /// If the parameters are invalid. See [`WsInitParams.check_validity`] for more details. + public class IcWebSocketState(init_params : WsInitParams) = self { + //// STATE //// + /// Maps the client's key to the client metadata. + public var REGISTERED_CLIENTS = HashMap.HashMap(0, Types.areClientKeysEqual, Types.hashClientKey); + /// Maps the client's principal to the current client key. + var CURRENT_CLIENT_KEY_MAP = HashMap.HashMap(0, Principal.equal, Principal.hash); + /// Keeps track of all the clients for which we're waiting for a keep alive message. + public var CLIENTS_WAITING_FOR_KEEP_ALIVE : TrieSet.Set = TrieSet.empty(); + /// Maps the client's public key to the sequence number to use for the next outgoing message (to that client). + var OUTGOING_MESSAGE_TO_CLIENT_NUM_MAP = HashMap.HashMap(0, Types.areClientKeysEqual, Types.hashClientKey); + /// Maps the client's public key to the expected sequence number of the next incoming message (from that client). + var INCOMING_MESSAGE_FROM_CLIENT_NUM_MAP = HashMap.HashMap(0, Types.areClientKeysEqual, Types.hashClientKey); + /// Keeps track of the Merkle tree used for certified queries. + var CERT_TREE_STORE : CertTree.Store = CertTree.newStore(); + var CERT_TREE = CertTree.Ops(CERT_TREE_STORE); + /// Keeps track of the principals of the WS Gateways that poll the canister. + var REGISTERED_GATEWAYS = HashMap.HashMap(0, Principal.equal, Principal.hash); + /// The acknowledgement active timer. + public var ACK_TIMER : ?Timer.TimerId = null; + /// The keep alive active timer. + public var KEEP_ALIVE_TIMER : ?Timer.TimerId = null; + + //// FUNCTIONS //// + /// Resets all state to the initial state. + public func reset_internal_state(handlers : WsHandlers) : async () { + // for each client, call the on_close handler before clearing the map + for (client_key in REGISTERED_CLIENTS.keys()) { + await remove_client(client_key, handlers); + }; + + // make sure all the maps are cleared + CURRENT_CLIENT_KEY_MAP := HashMap.HashMap(0, Principal.equal, Principal.hash); + CLIENTS_WAITING_FOR_KEEP_ALIVE := TrieSet.empty(); + OUTGOING_MESSAGE_TO_CLIENT_NUM_MAP := HashMap.HashMap(0, Types.areClientKeysEqual, Types.hashClientKey); + INCOMING_MESSAGE_FROM_CLIENT_NUM_MAP := HashMap.HashMap(0, Types.areClientKeysEqual, Types.hashClientKey); + CERT_TREE_STORE := CertTree.newStore(); + CERT_TREE := CertTree.Ops(CERT_TREE_STORE); + REGISTERED_GATEWAYS := HashMap.HashMap(0, Principal.equal, Principal.hash); + }; + + /// Increments the clients connected count for the given gateway. + /// If the gateway is not registered, a new entry is created with a clients connected count of 1. + func increment_gateway_clients_count(gateway_principal : GatewayPrincipal) { + switch (REGISTERED_GATEWAYS.get(gateway_principal)) { + case (?registered_gateway) { + registered_gateway.increment_clients_count(); + }; + case (null) { + let new_gw = Types.RegisteredGateway(); + new_gw.increment_clients_count(); + REGISTERED_GATEWAYS.put(gateway_principal, new_gw); + }; + }; + }; + + /// Decrements the clients connected count for the given gateway. + /// If there are no more clients connected, the gateway is removed from the list of registered gateways. + func decrement_gateway_clients_count(gateway_principal : GatewayPrincipal) { + switch (REGISTERED_GATEWAYS.get(gateway_principal)) { + case (?registered_gateway) { + let clients_count = registered_gateway.decrement_clients_count(); + if (clients_count == 0) { + REGISTERED_GATEWAYS.delete(gateway_principal); + }; + }; + case (null) { + Prelude.unreachable(); // gateway must be registered at this point + }; + }; + }; + + func get_registered_gateway(gateway_principal : GatewayPrincipal) : Result { + switch (REGISTERED_GATEWAYS.get(gateway_principal)) { + case (?registered_gateway) { #Ok(registered_gateway) }; + case (null) { + #Err(Errors.to_string(#GatewayNotRegistered({ gateway_principal }))); + }; + }; + }; + + public func check_is_gateway_registered(gateway_principal : GatewayPrincipal) : Result<(), Text> { + switch (get_registered_gateway(gateway_principal)) { + case (#Ok(_)) { #Ok }; + case (#Err(err)) { #Err(err) }; + }; + }; + + public func is_registered_gateway(gateway_principal : GatewayPrincipal) : Bool { + switch (get_registered_gateway(gateway_principal)) { + case (#Ok(_)) { true }; + case (#Err(err)) { false }; + }; + }; + + public func get_outgoing_message_nonce(gateway_principal : GatewayPrincipal) : Result { + switch (get_registered_gateway(gateway_principal)) { + case (#Ok(registered_gateway)) { + #Ok(registered_gateway.outgoing_message_nonce); + }; + case (#Err(err)) { #Err(err) }; + }; + }; + + public func increment_outgoing_message_nonce(gateway_principal : GatewayPrincipal) : Result<(), Text> { + switch (get_registered_gateway(gateway_principal)) { + case (#Ok(registered_gateway)) { + registered_gateway.increment_nonce(); + #Ok; + }; + case (#Err(err)) { #Err(err) }; + }; + }; + + func insert_client(client_key : ClientKey, new_client : RegisteredClient) { + CURRENT_CLIENT_KEY_MAP.put(client_key.client_principal, client_key); + REGISTERED_CLIENTS.put(client_key, new_client); + }; + + func get_registered_client(client_key : ClientKey) : Result { + switch (REGISTERED_CLIENTS.get(client_key)) { + case (?registered_client) { #Ok(registered_client) }; + case (null) { + #Err(Errors.to_string(#ClientKeyNotConnected({ client_key }))); + }; + }; + }; + + public func get_client_key_from_principal(client_principal : ClientPrincipal) : Result { + switch (CURRENT_CLIENT_KEY_MAP.get(client_principal)) { + case (?client_key) #Ok(client_key); + case (null) #Err(Errors.to_string(#ClientPrincipalNotConnected({ client_principal }))); + }; + }; + + public func check_registered_client_exists(client_key : ClientKey) : Result<(), Text> { + switch (get_registered_client(client_key)) { + case (#Ok(_)) { #Ok }; + case (#Err(err)) { #Err(err) }; + }; + }; + + public func check_client_registered_to_gateway(client_key : ClientKey, gateway_principal : GatewayPrincipal) : Result<(), Text> { + switch (get_registered_client(client_key)) { + case (#Ok(registered_client)) { + if (Principal.equal(registered_client.gateway_principal, gateway_principal)) { + #Ok; + } else { + #Err(Errors.to_string(#ClientNotRegisteredToGateway({ client_key; gateway_principal }))); + }; + }; + case (#Err(err)) { #Err(err) }; + }; + }; + + public func add_client_to_wait_for_keep_alive(client_key : ClientKey) { + CLIENTS_WAITING_FOR_KEEP_ALIVE := TrieSet.put(CLIENTS_WAITING_FOR_KEEP_ALIVE, client_key, Types.hashClientKey(client_key), Types.areClientKeysEqual); + }; + + func init_outgoing_message_to_client_num(client_key : ClientKey) { + OUTGOING_MESSAGE_TO_CLIENT_NUM_MAP.put(client_key, Constants.INITIAL_CANISTER_SEQUENCE_NUM); + }; + + public func get_outgoing_message_to_client_num(client_key : ClientKey) : Result { + switch (OUTGOING_MESSAGE_TO_CLIENT_NUM_MAP.get(client_key)) { + case (?num) #Ok(num); + case (null) #Err(Errors.to_string(#OutgoingMessageToClientNumNotInitialized({ client_key }))); + }; + }; + + public func increment_outgoing_message_to_client_num(client_key : ClientKey) : Result<(), Text> { + switch (get_outgoing_message_to_client_num(client_key)) { + case (#Ok(num)) { + OUTGOING_MESSAGE_TO_CLIENT_NUM_MAP.put(client_key, num + 1); + #Ok; + }; + case (#Err(error)) #Err(error); + }; + }; + + func init_expected_incoming_message_from_client_num(client_key : ClientKey) { + INCOMING_MESSAGE_FROM_CLIENT_NUM_MAP.put(client_key, Constants.INITIAL_CLIENT_SEQUENCE_NUM); + }; + + public func get_expected_incoming_message_from_client_num(client_key : ClientKey) : Result { + switch (INCOMING_MESSAGE_FROM_CLIENT_NUM_MAP.get(client_key)) { + case (?num) #Ok(num); + case (null) #Err(Errors.to_string(#ExpectedIncomingMessageToClientNumNotInitialized({ client_key }))); + }; + }; + + public func increment_expected_incoming_message_from_client_num(client_key : ClientKey) : Result<(), Text> { + switch (get_expected_incoming_message_from_client_num(client_key)) { + case (#Ok(num)) { + INCOMING_MESSAGE_FROM_CLIENT_NUM_MAP.put(client_key, num + 1); + #Ok; + }; + case (#Err(error)) #Err(error); + }; + }; + + public func add_client(client_key : ClientKey, new_client : RegisteredClient) { + // insert the client in the map + insert_client(client_key, new_client); + // initialize incoming client's message sequence number to 1 + init_expected_incoming_message_from_client_num(client_key); + // initialize outgoing message sequence number to 0 + init_outgoing_message_to_client_num(client_key); + + increment_gateway_clients_count(new_client.gateway_principal); + }; + + public func remove_client(client_key : ClientKey, handlers : WsHandlers) : async () { + CLIENTS_WAITING_FOR_KEEP_ALIVE := TrieSet.delete(CLIENTS_WAITING_FOR_KEEP_ALIVE, client_key, Types.hashClientKey(client_key), Types.areClientKeysEqual); + CURRENT_CLIENT_KEY_MAP.delete(client_key.client_principal); + OUTGOING_MESSAGE_TO_CLIENT_NUM_MAP.delete(client_key); + INCOMING_MESSAGE_FROM_CLIENT_NUM_MAP.delete(client_key); + + let registered_client = REGISTERED_CLIENTS.remove(client_key); + switch (registered_client) { + case (?registered_client) { + decrement_gateway_clients_count(registered_client.gateway_principal); + }; + case (null) { + Prelude.unreachable(); + }; + }; + + await handlers.call_on_close({ + client_principal = client_key.client_principal; + }); + }; + + public func format_message_for_gateway_key(gateway_principal : Principal, nonce : Nat64) : Text { + let nonce_to_text = do { + // prints the nonce with 20 padding zeros + var nonce_str = Nat64.toText(nonce); + let padding : Nat = 20 - Text.size(nonce_str); + if (padding > 0) { + for (i in Iter.range(0, padding - 1)) { + nonce_str := "0" # nonce_str; + }; + }; + + nonce_str; + }; + Principal.toText(gateway_principal) # "_" # nonce_to_text; + }; + + func get_gateway_messages_queue(gateway_principal : Principal) : List.List { + switch (get_registered_gateway(gateway_principal)) { + case (#Ok(registered_gateway)) { + registered_gateway.messages_queue; + }; + case (#Err(error)) { + // the value exists because we just checked that the gateway is registered + Prelude.unreachable(); + }; + }; + }; + + func get_messages_for_gateway_range(gateway_principal : Principal, nonce : Nat64, max_number_of_returned_messages : Nat) : Types.MessagesForGatewayRange { + let messages_queue = get_gateway_messages_queue(gateway_principal); + + let queue_len = List.size(messages_queue); + + // smallest key used to determine the first message from the queue which has to be returned to the WS Gateway + let smallest_key = format_message_for_gateway_key(gateway_principal, nonce); + // partition the queue at the message which has the key with the nonce specified as argument to get_cert_messages + let start_index = do { + let partitions = List.partition( + messages_queue, + func(el : CanisterOutputMessage) : Bool { + Text.less(el.key, smallest_key); + }, + ); + List.size(partitions.0); + }; + let (end_index, is_end_of_queue) = if ((queue_len - start_index) : Nat > max_number_of_returned_messages) { + (start_index + max_number_of_returned_messages, false); + } else { (queue_len, true) }; + + { + start_index; + end_index; + is_end_of_queue; + }; + }; + + func get_messages_for_gateway(gateway_principal : Principal, start_index : Nat, end_index : Nat) : List.List { + let messages_queue = get_gateway_messages_queue(gateway_principal); + + var messages : List.List = List.nil(); + for (i in Iter.range(start_index, end_index - 1)) { + let message = List.get(messages_queue, i); + switch (message) { + case (?message) { + messages := List.push(message, messages); + }; + case (null) { + Prelude.unreachable(); // the value exists because this function is called only after partitioning the queue + }; + }; + }; + + List.reverse(messages); + }; + + /// Gets the messages in [MESSAGES_FOR_GATEWAYS] starting from the one with the specified nonce + public func get_cert_messages(gateway_principal : Principal, nonce : Nat64, max_number_of_returned_messages : Nat) : CanisterWsGetMessagesResult { + let { start_index; end_index; is_end_of_queue } = get_messages_for_gateway_range(gateway_principal, nonce, max_number_of_returned_messages); + let messages = get_messages_for_gateway(gateway_principal, start_index, end_index); + + if (List.isNil(messages)) { + return get_cert_messages_empty(); + }; + + let keys = List.map( + messages, + func(message : CanisterOutputMessage) : CertTree.Path { + [Text.encodeUtf8(message.key)]; + }, + ); + let (cert, tree) = get_cert_for_range(List.toIter(keys)); + + #Ok({ + messages = List.toArray(messages); + cert = cert; + tree = tree; + is_end_of_queue = is_end_of_queue; + }); + }; + + public func get_cert_messages_empty() : CanisterWsGetMessagesResult { + #Ok({ + messages = []; + cert = Blob.fromArray([]); + tree = Blob.fromArray([]); + is_end_of_queue = true; + }); + }; + + func labeledHash(l : Blob, content : CertTree.Hash) : CertTree.Hash { + let d = Sha256.Digest(#sha256); + d.writeBlob("\13ic-hashtree-labeled"); + d.writeBlob(l); + d.writeBlob(content); + d.sum(); + }; + + public func put_cert_for_message(key : Text, value : Blob) { + let root_hash = do { + CERT_TREE.put([Text.encodeUtf8(key)], Sha256.fromBlob(#sha256, value)); + labeledHash(Constants.LABEL_WEBSOCKET, CERT_TREE.treeHash()); + }; + + CertifiedData.set(root_hash); + }; + + /// Adds the message to the gateway queue. + func push_message_in_gateway_queue(gateway_principal : Principal, message : CanisterOutputMessage, message_timestamp : Nat64) : Result<(), Text> { + switch (get_registered_gateway(gateway_principal)) { + case (#Ok(registered_gateway)) { + // messages in the queue are inserted with contiguous and increasing nonces + // (from beginning to end of the queue) as ws_send is called sequentially, the nonce + // is incremented by one in each call, and the message is pushed at the end of the queue + registered_gateway.add_message_to_queue(message, message_timestamp); + #Ok; + }; + case (#Err(err)) { #Err(err) }; + }; + }; + + /// Deletes the an amount of [MESSAGES_TO_DELETE] messages from the queue + /// that are older than the ack interval. + func delete_old_messages_for_gateway(gateway_principal : GatewayPrincipal) : Result<(), Text> { + let ack_interval_ms = init_params.send_ack_interval_ms; + let deleted_messages_keys = switch (get_registered_gateway(gateway_principal)) { + case (#Ok(registered_gateway)) { + registered_gateway.delete_old_messages(Constants.MESSAGES_TO_DELETE_COUNT, ack_interval_ms); + }; + case (#Err(err)) { return #Err(err) }; + }; + + for (key in Iter.fromList(deleted_messages_keys)) { + CERT_TREE.delete([Text.encodeUtf8(key)]); + }; + + #Ok; + }; + + func get_cert_for_range(keys : Iter.Iter) : (Blob, Blob) { + let witness = CERT_TREE.reveals(keys); + let tree : CertTree.Witness = #labeled(Constants.LABEL_WEBSOCKET, witness); + + switch (CertifiedData.getCertificate()) { + case (?cert) { + let tree_blob = CERT_TREE.encodeWitness(tree); + (cert, tree_blob); + }; + case (null) Prelude.unreachable(); + }; + }; + + func handle_keep_alive_client_message(client_key : ClientKey, _keep_alive_message : Types.ClientKeepAliveMessageContent) { + // update the last keep alive timestamp for the client + switch (REGISTERED_CLIENTS.get(client_key)) { + case (?client_metadata) { + client_metadata.update_last_keep_alive_timestamp(); + }; + case (null) { + // Do nothing. + }; + }; + }; + + public func handle_received_service_message(client_key : ClientKey, content : Blob) : async Result<(), Text> { + let decoded = switch (Types.decode_websocket_service_message_content(content)) { + case (#Err(err)) { + return #Err(err); + }; + case (#Ok(message_content)) { + message_content; + }; + }; + + switch (decoded) { + case (#KeepAliveMessage(keep_alive_message)) { + handle_keep_alive_client_message(client_key, keep_alive_message); + #Ok; + }; + case (_) { + return #Err(Errors.to_string(#InvalidServiceMessage)); + }; + }; + }; + + public func send_service_message_to_client(client_key : ClientKey, message : Types.WebsocketServiceMessageContent) : Result<(), Text> { + let message_bytes = Types.encode_websocket_service_message_content(message); + _ws_send(client_key, message_bytes, true); + }; + + /// Internal function used to put the messages in the outgoing messages queue and certify them. + public func _ws_send(client_key : ClientKey, msg_bytes : Blob, is_service_message : Bool) : CanisterWsSendResult { + // get the registered client if it exists + let registered_client = switch (get_registered_client(client_key)) { + case (#Err(err)) { + return #Err(err); + }; + case (#Ok(registered_client)) { + registered_client; + }; + }; + + // the nonce in key is used by the WS Gateway to determine the message to start in the polling iteration + // the key is also passed to the client in order to validate the body of the certified message + let outgoing_message_nonce = switch (get_outgoing_message_nonce(registered_client.gateway_principal)) { + case (#Err(err)) { + return #Err(err); + }; + case (#Ok(nonce)) { + nonce; + }; + }; + let message_key = format_message_for_gateway_key(registered_client.gateway_principal, outgoing_message_nonce); + + // increment the nonce for the next message + switch (increment_outgoing_message_nonce(registered_client.gateway_principal)) { + case (#Err(err)) { + return #Err(err); + }; + case (_) { + // do nothing + }; + }; + + // increment the sequence number for the next message to the client + switch (increment_outgoing_message_to_client_num(client_key)) { + case (#Err(err)) { + return #Err(err); + }; + case (_) { + // do nothing + }; + }; + + let message_timestamp = Utils.get_current_time(); + + let websocket_message : Types.WebsocketMessage = { + client_key; + sequence_num = switch (get_outgoing_message_to_client_num(client_key)) { + case (#Err(err)) { + return #Err(err); + }; + case (#Ok(sequence_num)) { + sequence_num; + }; + }; + timestamp = message_timestamp; + is_service_message; + content = msg_bytes; + }; + + // CBOR serialize message of type WebsocketMessage + let message_content = switch (Types.encode_websocket_message(websocket_message)) { + case (#Err(err)) { + return #Err(err); + }; + case (#Ok(content)) { + content; + }; + }; + + // delete old messages from the gateway queue + switch (delete_old_messages_for_gateway(registered_client.gateway_principal)) { + case (#Err(err)) { + return #Err(err); + }; + case (_) { + // do nothing + }; + }; + + // certify data + put_cert_for_message(message_key, message_content); + + push_message_in_gateway_queue( + registered_client.gateway_principal, + { + client_key; + content = message_content; + key = message_key; + }, + message_timestamp, + ); + }; + + public func _ws_send_to_client_principal(client_principal : ClientPrincipal, msg_bytes : Blob) : CanisterWsSendResult { + let client_key = switch (get_client_key_from_principal(client_principal)) { + case (#Err(err)) { + return #Err(err); + }; + case (#Ok(client_key)) { + client_key; + }; + }; + _ws_send(client_key, msg_bytes, false); + }; + }; +}; diff --git a/src/Timers.mo b/src/Timers.mo new file mode 100644 index 0000000..49cbef2 --- /dev/null +++ b/src/Timers.mo @@ -0,0 +1,137 @@ +import Timer "mo:base/Timer"; +import Nat64 "mo:base/Nat64"; +import Array "mo:base/Array"; +import TrieSet "mo:base/TrieSet"; + +import Types "Types"; +import State "State"; +import Utils "Utils"; + +module { + func put_ack_timer_id(ws_state : State.IcWebSocketState, timer_id : Timer.TimerId) { + ws_state.ACK_TIMER := ?timer_id; + }; + + func cancel_ack_timer(ws_state : State.IcWebSocketState) { + switch (ws_state.ACK_TIMER) { + case (?t_id) { + Timer.cancelTimer(t_id); + ws_state.ACK_TIMER := null; + }; + case (null) { + // Do nothing + }; + }; + }; + + func put_keep_alive_timer_id(ws_state : State.IcWebSocketState, timer_id : Timer.TimerId) { + ws_state.KEEP_ALIVE_TIMER := ?timer_id; + }; + + func cancel_keep_alive_timer(ws_state : State.IcWebSocketState) { + switch (ws_state.KEEP_ALIVE_TIMER) { + case (?t_id) { + Timer.cancelTimer(t_id); + ws_state.KEEP_ALIVE_TIMER := null; + }; + case (null) { + // Do nothing + }; + }; + }; + + public func cancel_timers(ws_state : State.IcWebSocketState) { + cancel_ack_timer(ws_state); + cancel_keep_alive_timer(ws_state); + }; + + /// Start an interval to send an acknowledgement messages to the clients. + /// + /// The interval callback is [send_ack_to_clients_timer_callback]. After the callback is executed, + /// a timer is scheduled to check if the registered clients have sent a keep alive message. + public func schedule_send_ack_to_clients(ws_state : State.IcWebSocketState, ack_interval_ms : Nat64, keep_alive_timeout_ms : Nat64, handlers : Types.WsHandlers) { + let timer_id = Timer.recurringTimer( + #nanoseconds(Nat64.toNat(ack_interval_ms * 1_000_000)), + func() : async () { + send_ack_to_clients_timer_callback(ws_state, ack_interval_ms); + + schedule_check_keep_alive(ws_state, keep_alive_timeout_ms, handlers); + }, + ); + + put_ack_timer_id(ws_state, timer_id); + }; + + /// Schedules a timer to check if the clients (only those to which an ack message was sent) have sent a keep alive message + /// after receiving an acknowledgement message. + /// + /// The timer callback is [check_keep_alive_timer_callback]. + func schedule_check_keep_alive(ws_state : State.IcWebSocketState, keep_alive_timeout_ms : Nat64, handlers : Types.WsHandlers) { + let timer_id = Timer.setTimer( + #nanoseconds(Nat64.toNat(keep_alive_timeout_ms * 1_000_000)), + func() : async () { + await check_keep_alive_timer_callback(ws_state, keep_alive_timeout_ms, handlers); + }, + ); + + put_keep_alive_timer_id(ws_state, timer_id); + }; + + /// Sends an acknowledgement message to the client. + /// The message contains the current incoming message sequence number for that client, + /// so that the client knows that all the messages it sent have been received by the canister. + func send_ack_to_clients_timer_callback(ws_state : State.IcWebSocketState, ack_interval_ms : Nat64) { + for (client_key in ws_state.REGISTERED_CLIENTS.keys()) { + // ignore the error, which shouldn't happen since the client is registered and the sequence number is initialized + switch (ws_state.get_expected_incoming_message_from_client_num(client_key)) { + case (#Ok(expected_incoming_sequence_num)) { + let ack_message : Types.CanisterAckMessageContent = { + // the expected sequence number is 1 more because it's incremented when a message is received + last_incoming_sequence_num = expected_incoming_sequence_num - 1; + }; + let message : Types.WebsocketServiceMessageContent = #AckMessage(ack_message); + switch (ws_state.send_service_message_to_client(client_key, message)) { + case (#Err(err)) { + // TODO: decide what to do when sending the message fails + + Utils.custom_print("[ack-to-clients-timer-cb]: Error sending ack message to client" # Types.clientKeyToText(client_key) # ": " # err); + }; + case (#Ok(_)) { + ws_state.add_client_to_wait_for_keep_alive(client_key); + }; + }; + }; + case (#Err(err)) { + // TODO: decide what to do when getting the expected incoming sequence number fails (shouldn't happen) + Utils.custom_print("[ack-to-clients-timer-cb]: Error getting expected incoming sequence number for client" # Types.clientKeyToText(client_key) # ": " # err); + }; + }; + }; + + Utils.custom_print("[ack-to-clients-timer-cb]: Sent ack messages to all clients"); + }; + + /// Checks if the clients for which we are waiting for keep alive have sent a keep alive message. + /// If a client has not sent a keep alive message, it is removed from the connected clients. + func check_keep_alive_timer_callback(ws_state : State.IcWebSocketState, keep_alive_timeout_ms : Nat64, handlers : Types.WsHandlers) : async () { + for (client_key in Array.vals(TrieSet.toArray(ws_state.CLIENTS_WAITING_FOR_KEEP_ALIVE))) { + // get the last keep alive timestamp for the client and check if it has exceeded the timeout + switch (ws_state.REGISTERED_CLIENTS.get(client_key)) { + case (?client_metadata) { + let last_keep_alive = client_metadata.get_last_keep_alive_timestamp(); + + if (Utils.get_current_time() - last_keep_alive > (keep_alive_timeout_ms * 1_000_000)) { + await ws_state.remove_client(client_key, handlers); + + Utils.custom_print("[check-keep-alive-timer-cb]: Client " # Types.clientKeyToText(client_key) # " has not sent a keep alive message in the last " # debug_show (keep_alive_timeout_ms) # " ms and has been removed"); + }; + }; + case (null) { + // Do nothing + }; + }; + }; + + Utils.custom_print("[check-keep-alive-timer-cb]: Checked keep alive messages for all clients"); + }; +}; diff --git a/src/Types.mo b/src/Types.mo new file mode 100644 index 0000000..2f95581 --- /dev/null +++ b/src/Types.mo @@ -0,0 +1,557 @@ +import Hash "mo:base/Hash"; +import Principal "mo:base/Principal"; +import Text "mo:base/Text"; +import Nat64 "mo:base/Nat64"; +import Bool "mo:base/Bool"; +import List "mo:base/List"; +import Blob "mo:base/Blob"; +import Array "mo:base/Array"; +import Prelude "mo:base/Prelude"; +import Iter "mo:base/Iter"; +import Error "mo:base/Error"; +import CborDecoder "mo:cbor/Decoder"; +import CborEncoder "mo:cbor/Encoder"; +import CborValue "mo:cbor/Value"; + +import Constants "Constants"; +import Utils "Utils"; + +module { + /// Just to be compatible with the Rust version. + public type Result = { #Ok : Ok; #Err : Err }; + + public type ClientPrincipal = Principal; + + public type ClientKey = { + client_principal : ClientPrincipal; + client_nonce : Nat64; + }; + // functions needed for ClientKey + public func areClientKeysEqual(k1 : ClientKey, k2 : ClientKey) : Bool { + Principal.equal(k1.client_principal, k2.client_principal) and Nat64.equal(k1.client_nonce, k2.client_nonce); + }; + public func clientKeyToText(k : ClientKey) : Text { + Principal.toText(k.client_principal) # "_" # Nat64.toText(k.client_nonce); + }; + public func hashClientKey(k : ClientKey) : Hash.Hash { + Text.hash(clientKeyToText(k)); + }; + + /// The result of [ws_open]. + public type CanisterWsOpenResult = Result<(), Text>; + /// The result of [ws_close]. + public type CanisterWsCloseResult = Result<(), Text>; + // The result of [ws_message]. + public type CanisterWsMessageResult = Result<(), Text>; + /// The result of [ws_get_messages]. + public type CanisterWsGetMessagesResult = Result; + /// The result of [ws_send]. + public type CanisterWsSendResult = Result<(), Text>; + + /// The arguments for [ws_open]. + public type CanisterWsOpenArguments = { + client_nonce : Nat64; + gateway_principal : GatewayPrincipal; + }; + + /// The arguments for [ws_close]. + public type CanisterWsCloseArguments = { + client_key : ClientKey; + }; + + /// The arguments for [ws_message]. + public type CanisterWsMessageArguments = { + msg : WebsocketMessage; + }; + + /// The arguments for [ws_get_messages]. + public type CanisterWsGetMessagesArguments = { + nonce : Nat64; + }; + + /// Messages exchanged through the WebSocket. + public type WebsocketMessage = { + client_key : ClientKey; // The client that the gateway will forward the message to or that sent the message. + sequence_num : Nat64; // Both ways, messages should arrive with sequence numbers 0, 1, 2... + timestamp : Nat64; // Timestamp of when the message was made for the recipient to inspect. + is_service_message : Bool; // Whether the message is a service message sent by the CDK to the client or vice versa. + content : Blob; // Application message encoded in binary. + }; + /// Encodes the `WebsocketMessage` into a CBOR blob. + public func encode_websocket_message(websocket_message : WebsocketMessage) : Result { + let principal_blob = Blob.toArray(Principal.toBlob(websocket_message.client_key.client_principal)); + let cbor_value : CborValue.Value = #majorType5([ + (#majorType3("client_key"), #majorType5([(#majorType3("client_principal"), #majorType2(principal_blob)), (#majorType3("client_nonce"), #majorType0(websocket_message.client_key.client_nonce))])), + (#majorType3("sequence_num"), #majorType0(websocket_message.sequence_num)), + (#majorType3("timestamp"), #majorType0(websocket_message.timestamp)), + (#majorType3("is_service_message"), #majorType7(#bool(websocket_message.is_service_message))), + (#majorType3("content"), #majorType2(Blob.toArray(websocket_message.content))), + ]); + + switch (CborEncoder.encode(cbor_value)) { + case (#err(#invalidValue(err))) { + return #Err(err); + }; + case (#ok(data)) { + #Ok(Blob.fromArray(data)); + }; + }; + }; + + /// Decodes the CBOR blob into a `WebsocketMessage`. + func decode_websocket_message(bytes : Blob) : Result { + switch (CborDecoder.decode(bytes)) { + case (#err(err)) { + #Err("deserialization failed"); + }; + case (#ok(c)) { + switch (c) { + case (#majorType6({ tag; value })) { + switch (value) { + case (#majorType5(raw_content)) { + #Ok({ + client_key = do { + let client_key_key_value = Array.find(raw_content, func((key, _) : (CborValue.Value, CborValue.Value)) : Bool = key == #majorType3("client_key")); + switch (client_key_key_value) { + case (?(_, #majorType5(raw_client_key))) { + let client_principal_value = Array.find(raw_client_key, func((key, _) : (CborValue.Value, CborValue.Value)) : Bool = key == #majorType3("client_principal")); + let client_principal = switch (client_principal_value) { + case (?(_, #majorType2(client_principal_blob))) { + Principal.fromBlob( + Blob.fromArray(client_principal_blob) + ); + }; + case (_) { + return #Err("missing field `client_key.client_principal`"); + }; + }; + let client_nonce_value = Array.find(raw_client_key, func((key, _) : (CborValue.Value, CborValue.Value)) : Bool = key == #majorType3("client_nonce")); + let client_nonce = switch (client_nonce_value) { + case (?(_, #majorType0(client_nonce))) { + client_nonce; + }; + case (_) { + return #Err("missing field `client_key.client_nonce`"); + }; + }; + + { + client_principal; + client_nonce; + }; + }; + case (_) { + return #Err("missing field `client_key`"); + }; + }; + }; + sequence_num = do { + let sequence_num_key_value = Array.find(raw_content, func((key, _) : (CborValue.Value, CborValue.Value)) : Bool = key == #majorType3("sequence_num")); + switch (sequence_num_key_value) { + case (?(_, #majorType0(sequence_num))) { + sequence_num; + }; + case (_) { + return #Err("missing field `sequence_num`"); + }; + }; + }; + timestamp = do { + let timestamp_key_value = Array.find(raw_content, func((key, _) : (CborValue.Value, CborValue.Value)) : Bool = key == #majorType3("timestamp")); + switch (timestamp_key_value) { + case (?(_, #majorType0(timestamp))) { + timestamp; + }; + case (_) { + return #Err("missing field `timestamp`"); + }; + }; + }; + is_service_message = do { + let is_service_message_key_value = Array.find(raw_content, func((key, _) : (CborValue.Value, CborValue.Value)) : Bool = key == #majorType3("is_service_message")); + switch (is_service_message_key_value) { + case (?(_, #majorType7(#bool(is_service_message)))) { + is_service_message; + }; + case (_) { + return #Err("missing field `is_service_message`"); + }; + }; + }; + content = do { + let content_key_value = Array.find(raw_content, func((key, _) : (CborValue.Value, CborValue.Value)) : Bool = key == #majorType3("message")); + switch (content_key_value) { + case (?(_, #majorType2(content_blob))) { + Blob.fromArray(content_blob); + }; + case (_) { + return #Err("missing field `content`"); + }; + }; + }; + }); + }; + case (_) { + #Err("invalid CBOR message content"); + }; + }; + }; + case (_) { + #Err("invalid CBOR message content"); + }; + }; + }; + }; + }; + + // Element of the list of messages returned to the WS Gateway after polling. + public type CanisterOutputMessage = { + client_key : ClientKey; // The client that the gateway will forward the message to. + key : Text; // Key for certificate verification. + content : Blob; // The message to be relayed, that contains the application message. + }; + + /// List of messages returned to the WS Gateway after polling. + public type CanisterOutputCertifiedMessages = { + messages : [CanisterOutputMessage]; // List of messages. + cert : Blob; // cert+tree constitute the certificate for all returned messages. + tree : Blob; // cert+tree constitute the certificate for all returned messages. + is_end_of_queue : Bool; // Whether the end of the queue has been reached. + }; + + public type MessagesForGatewayRange = { + start_index : Nat; + end_index : Nat; + is_end_of_queue : Bool; + }; + + type MessageToDelete = { + timestamp : Nat64; + }; + + public type GatewayPrincipal = Principal; + + /// Contains data about the registered WS Gateway. + public class RegisteredGateway() { + /// The queue of the messages that the gateway can poll. + public var messages_queue : List.List = List.nil(); + /// The queue of messages' keys to delete. + public var messages_to_delete : List.List = List.nil(); + /// Keeps track of the nonce which: + /// - the WS Gateway uses to specify the first index of the certified messages to be returned when polling + /// - the client uses as part of the path in the Merkle tree in order to verify the certificate of the messages relayed by the WS Gateway + public var outgoing_message_nonce : Nat64 = Constants.INITIAL_OUTGOING_MESSAGE_NONCE; + /// The number of clients connected to this gateway. + public var connected_clients_count : Nat64 = 0; + + /// Increments the outgoing message nonce by 1. + public func increment_nonce() { + outgoing_message_nonce += 1; + }; + + /// Increments the connected clients count by 1. + public func increment_clients_count() { + connected_clients_count += 1; + }; + + /// Decrements the connected clients count by 1, returning the new value. + public func decrement_clients_count() : Nat64 { + connected_clients_count -= 1; + connected_clients_count; + }; + + /// Adds the message to the queue and its metadata to the `messages_to_delete` queue. + public func add_message_to_queue(message : CanisterOutputMessage, message_timestamp : Nat64) { + messages_queue := List.append( + messages_queue, + List.fromArray([message]), + ); + messages_to_delete := List.append( + messages_to_delete, + List.fromArray([{ + timestamp = message_timestamp; + }]), + ); + }; + + /// Deletes the oldest `n` messages that are older than `message_max_age_ms` from the queue. + /// + /// Returns the deleted messages keys. + public func delete_old_messages(n : Nat, message_max_age_ms : Nat64) : List.List { + let time = Utils.get_current_time(); + var deleted_keys : List.List = List.nil(); + + label f for (_ in Iter.range(0, n - 1)) { + let message_to_delete = do { + let (m, l) = List.pop(messages_to_delete); + messages_to_delete := l; + m; + }; + switch (message_to_delete) { + case (?message_to_delete) { + if ((time - message_to_delete.timestamp) > (message_max_age_ms * 1_000_000)) { + let deleted_message = do { + let (m, l) = List.pop(messages_queue); + messages_queue := l; + m; + }; + switch (deleted_message) { + case (?deleted_message) { + deleted_keys := List.append( + deleted_keys, + List.fromArray([deleted_message.key]), + ); + }; + case (null) { + // there is no case in which the messages_to_delete queue is populated + // while the messages_queue is empty + Prelude.unreachable(); + }; + }; + } else { + // In this case, no messages can be deleted because + // they're all not older than `message_max_age_ms`. + break f; + }; + }; + case (null) { + // There are no messages in the queue. Shouldn't happen. + break f; + }; + }; + }; + + deleted_keys; + }; + }; + + /// The metadata about a registered client. + public class RegisteredClient(gw_principal : GatewayPrincipal) { + public var last_keep_alive_timestamp : Nat64 = Utils.get_current_time(); + public let gateway_principal : GatewayPrincipal = gw_principal; + + /// Gets the last keep alive timestamp. + public func get_last_keep_alive_timestamp() : Nat64 { + last_keep_alive_timestamp; + }; + + /// Set the last keep alive timestamp to the current time. + public func update_last_keep_alive_timestamp() { + last_keep_alive_timestamp := Utils.get_current_time(); + }; + }; + + public type CanisterOpenMessageContent = { + client_key : ClientKey; + }; + + public type CanisterAckMessageContent = { + last_incoming_sequence_num : Nat64; + }; + + public type ClientKeepAliveMessageContent = { + last_incoming_sequence_num : Nat64; + }; + + public type WebsocketServiceMessageContent = { + #OpenMessage : CanisterOpenMessageContent; + #AckMessage : CanisterAckMessageContent; + #KeepAliveMessage : ClientKeepAliveMessageContent; + }; + public func encode_websocket_service_message_content(content : WebsocketServiceMessageContent) : Blob { + to_candid (content); + }; + public func decode_websocket_service_message_content(bytes : Blob) : Result { + let decoded : ?WebsocketServiceMessageContent = from_candid (bytes); // traps if the bytes are not a valid candid message + return switch (decoded) { + case (?value) { #Ok(value) }; + case (null) { #Err("Error decoding service message content: unknown") }; + }; + }; + + /// Arguments passed to the `on_open` handler. + public type OnOpenCallbackArgs = { + client_principal : ClientPrincipal; + }; + /// Handler initialized by the canister and triggered by the CDK once the IC WebSocket connection + /// is established. + public type OnOpenCallback = (OnOpenCallbackArgs) -> async (); + + /// Arguments passed to the `on_message` handler. + /// The `message` argument is the message received from the client, serialized in Candid. + /// Use [`from_candid`] to deserialize the message. + /// + /// # Example + /// This example is the deserialize equivalent of the [`ws_send`]'s serialize one. + /// ```motoko + /// import IcWebSocketCdk "mo:ic-websocket-cdk"; + /// + /// actor MyCanister { + /// // ... + /// + /// type MyMessage = { + /// some_field: Text; + /// }; + /// + /// // initialize the CDK + /// + /// func on_message(args : IcWebSocketCdk.OnMessageCallbackArgs) : async () { + /// let received_message: ?MyMessage = from_candid(args.message); + /// switch (received_message) { + /// case (?received_message) { + /// Debug.print("Received message: some_field: " # received_message.some_field); + /// }; + /// case (invalid_arg) { + /// return #Err("invalid argument: " # debug_show (invalid_arg)); + /// }; + /// }; + /// }; + /// + /// // ... + /// } + /// ``` + public type OnMessageCallbackArgs = { + /// The principal of the client sending the message to the canister. + client_principal : ClientPrincipal; + /// The message received from the client, serialized in Candid. See [OnMessageCallbackArgs] for an example on how to deserialize the message. + message : Blob; + }; + /// Handler initialized by the canister and triggered by the CDK once a message is received by + /// the CDK. + public type OnMessageCallback = (OnMessageCallbackArgs) -> async (); + + /// Arguments passed to the `on_close` handler. + public type OnCloseCallbackArgs = { + client_principal : ClientPrincipal; + }; + /// Handler initialized by the canister and triggered by the CDK once the WS Gateway closes the + /// IC WebSocket connection. + public type OnCloseCallback = (OnCloseCallbackArgs) -> async (); + + /// Handlers initialized by the canister and triggered by the CDK. + public class WsHandlers( + init_on_open : ?OnOpenCallback, + init_on_message : ?OnMessageCallback, + init_on_close : ?OnCloseCallback, + ) { + var on_open : ?OnOpenCallback = init_on_open; + var on_message : ?OnMessageCallback = init_on_message; + var on_close : ?OnCloseCallback = init_on_close; + + public func call_on_open(args : OnOpenCallbackArgs) : async () { + switch (on_open) { + case (?callback) { + try { + await callback(args); + } catch (err) { + Utils.custom_print("Error calling on_open handler: " # Error.message(err)); + }; + }; + case (null) { + // Do nothing. + }; + }; + }; + + public func call_on_message(args : OnMessageCallbackArgs) : async () { + switch (on_message) { + case (?callback) { + try { + await callback(args); + } catch (err) { + Utils.custom_print("Error calling on_message handler: " # Error.message(err)); + }; + }; + case (null) { + // Do nothing. + }; + }; + }; + + public func call_on_close(args : OnCloseCallbackArgs) : async () { + switch (on_close) { + case (?callback) { + try { + await callback(args); + } catch (err) { + Utils.custom_print("Error calling on_close handler: " # Error.message(err)); + }; + }; + case (null) { + // Do nothing. + }; + }; + }; + }; + + /// Parameters for the IC WebSocket CDK initialization. + /// + /// Arguments: + /// + /// - `init_max_number_of_returned_messages`: Maximum number of returned messages. Defaults to `10` if null. + /// - `init_send_ack_interval_ms`: Send ack interval in milliseconds. Defaults to `60_000` (60 seconds) if null. + /// - `init_keep_alive_timeout_ms`: Keep alive timeout in milliseconds. Defaults to `10_000` (10 seconds) if null. + public class WsInitParams( + init_max_number_of_returned_messages : ?Nat, + init_send_ack_interval_ms : ?Nat64, + init_keep_alive_timeout_ms : ?Nat64, + ) = self { + /// The maximum number of messages to be returned in a polling iteration. + /// Defaults to `50`. + public var max_number_of_returned_messages : Nat = switch (init_max_number_of_returned_messages) { + case (?value) { value }; + case (null) { Constants.DEFAULT_MAX_NUMBER_OF_RETURNED_MESSAGES }; + }; + /// The interval at which to send an acknowledgement message to the client, + /// so that the client knows that all the messages it sent have been received by the canister (in milliseconds). + /// + /// Must be greater than `keep_alive_timeout_ms`. + /// + /// Defaults to `300_000` (5 minutes). + public var send_ack_interval_ms : Nat64 = switch (init_send_ack_interval_ms) { + case (?value) { value }; + case (null) { Constants.DEFAULT_SEND_ACK_INTERVAL_MS }; + }; + /// The delay to wait for the client to send a keep alive after receiving an acknowledgement (in milliseconds). + /// + /// Must be lower than `send_ack_interval_ms`. + /// + /// Defaults to `60_000` (1 minute). + public var keep_alive_timeout_ms : Nat64 = switch (init_keep_alive_timeout_ms) { + case (?value) { value }; + case (null) { Constants.DEFAULT_CLIENT_KEEP_ALIVE_TIMEOUT_MS }; + }; + + /// Checks the validity of the timer parameters. + /// `send_ack_interval_ms` must be greater than `keep_alive_timeout_ms`. + /// + /// # Traps + /// If `send_ack_interval_ms` < `keep_alive_timeout_ms`. + public func check_validity() { + if (keep_alive_timeout_ms > send_ack_interval_ms) { + Utils.custom_trap("send_ack_interval_ms must be greater than keep_alive_timeout_ms"); + }; + }; + + public func with_max_number_of_returned_messages( + n : Nat + ) : WsInitParams { + max_number_of_returned_messages := n; + self; + }; + + public func with_send_ack_interval_ms( + ms : Nat64 + ) : WsInitParams { + send_ack_interval_ms := ms; + self; + }; + + public func with_keep_alive_timeout_ms( + ms : Nat64 + ) : WsInitParams { + keep_alive_timeout_ms := ms; + self; + }; + }; +}; diff --git a/src/Utils.mo b/src/Utils.mo new file mode 100644 index 0000000..3e4571b --- /dev/null +++ b/src/Utils.mo @@ -0,0 +1,18 @@ +import Text "mo:base/Text"; +import Debug "mo:base/Debug"; +import Nat64 "mo:base/Nat64"; +import Time "mo:base/Time"; + +module { + public func custom_print(s : Text) { + Debug.print(Text.concat("[IC-WEBSOCKET-CDK]: ", s)); + }; + + public func custom_trap(s : Text) { + Debug.trap(s); + }; + + public func get_current_time() : Nat64 { + Nat64.fromIntWrap(Time.now()); + }; +}; diff --git a/src/lib.mo b/src/lib.mo index c9e1750..70e302a 100644 --- a/src/lib.mo +++ b/src/lib.mo @@ -26,1172 +26,87 @@ import CborEncoder "mo:cbor/Encoder"; import CertTree "mo:ic-certification/CertTree"; import Sha256 "mo:sha2/Sha256"; -import Logger "Logger"; +import State "State"; +import Types "Types"; +import Utils "Utils"; +import Timers "Timers"; +import Errors "Errors"; module { - //// CONSTANTS //// - /// The label used when constructing the certification tree. - let LABEL_WEBSOCKET : Blob = "websocket"; - /// The default maximum number of messages returned by [ws_get_messages] at each poll. - let DEFAULT_MAX_NUMBER_OF_RETURNED_MESSAGES : Nat = 10; - /// The default interval at which to send acknowledgements to the client. - let DEFAULT_SEND_ACK_INTERVAL_MS : Nat64 = 60_000; // 60 seconds - /// The default timeout to wait for the client to send a keep alive after receiving an acknowledgement. - let DEFAULT_CLIENT_KEEP_ALIVE_TIMEOUT_MS : Nat64 = 10_000; // 10 seconds - - /// The initial nonce for outgoing messages. - let INITIAL_OUTGOING_MESSAGE_NONCE : Nat64 = 0; - /// The initial sequence number to expect from messages coming from clients. - /// The first message coming from the client will have sequence number `1` because on the client the sequence number is incremented before sending the message. - let INITIAL_CLIENT_SEQUENCE_NUM : Nat64 = 1; - /// The initial sequence number for outgoing messages. - let INITIAL_CANISTER_SEQUENCE_NUM : Nat64 = 0; - //// TYPES //// - /// Just to be compatible with the Rust version. - type Result = { #Ok : Ok; #Err : Err }; - - public type ClientPrincipal = Principal; - - public type ClientKey = { - client_principal : ClientPrincipal; - client_nonce : Nat64; - }; - // functions needed for ClientKey - func areClientKeysEqual(k1 : ClientKey, k2 : ClientKey) : Bool { - Principal.equal(k1.client_principal, k2.client_principal) and Nat64.equal(k1.client_nonce, k2.client_nonce); - }; - func clientKeyToText(k : ClientKey) : Text { - Principal.toText(k.client_principal) # "_" # Nat64.toText(k.client_nonce); - }; - func hashClientKey(k : ClientKey) : Hash.Hash { - Text.hash(clientKeyToText(k)); - }; - - /// The result of [ws_open]. - public type CanisterWsOpenResult = Result<(), Text>; - /// The result of [ws_close]. - public type CanisterWsCloseResult = Result<(), Text>; - // The result of [ws_message]. - public type CanisterWsMessageResult = Result<(), Text>; - /// The result of [ws_get_messages]. - public type CanisterWsGetMessagesResult = Result; - /// The result of [ws_send]. - public type CanisterWsSendResult = Result<(), Text>; - - /// The arguments for [ws_open]. - public type CanisterWsOpenArguments = { - client_nonce : Nat64; - gateway_principal : GatewayPrincipal; - }; - - /// The arguments for [ws_close]. - public type CanisterWsCloseArguments = { - client_key : ClientKey; - }; - - /// The arguments for [ws_message]. - public type CanisterWsMessageArguments = { - msg : WebsocketMessage; - }; - - /// The arguments for [ws_get_messages]. - public type CanisterWsGetMessagesArguments = { - nonce : Nat64; - }; - - /// Messages exchanged through the WebSocket. - type WebsocketMessage = { - client_key : ClientKey; // The client that the gateway will forward the message to or that sent the message. - sequence_num : Nat64; // Both ways, messages should arrive with sequence numbers 0, 1, 2... - timestamp : Nat64; // Timestamp of when the message was made for the recipient to inspect. - is_service_message : Bool; // Whether the message is a service message sent by the CDK to the client or vice versa. - content : Blob; // Application message encoded in binary. - }; - /// Encodes the `WebsocketMessage` into a CBOR blob. - func encode_websocket_message(websocket_message : WebsocketMessage) : Result { - let principal_blob = Blob.toArray(Principal.toBlob(websocket_message.client_key.client_principal)); - let cbor_value : CborValue.Value = #majorType5([ - (#majorType3("client_key"), #majorType5([(#majorType3("client_principal"), #majorType2(principal_blob)), (#majorType3("client_nonce"), #majorType0(websocket_message.client_key.client_nonce))])), - (#majorType3("sequence_num"), #majorType0(websocket_message.sequence_num)), - (#majorType3("timestamp"), #majorType0(websocket_message.timestamp)), - (#majorType3("is_service_message"), #majorType7(#bool(websocket_message.is_service_message))), - (#majorType3("content"), #majorType2(Blob.toArray(websocket_message.content))), - ]); - - switch (CborEncoder.encode(cbor_value)) { - case (#err(#invalidValue(err))) { - return #Err(err); - }; - case (#ok(data)) { - #Ok(Blob.fromArray(data)); - }; - }; - }; - - /// Decodes the CBOR blob into a `WebsocketMessage`. - func decode_websocket_message(bytes : Blob) : Result { - switch (CborDecoder.decode(bytes)) { - case (#err(err)) { - #Err("deserialization failed"); - }; - case (#ok(c)) { - switch (c) { - case (#majorType6({ tag; value })) { - switch (value) { - case (#majorType5(raw_content)) { - #Ok({ - client_key = do { - let client_key_key_value = Array.find(raw_content, func((key, _) : (CborValue.Value, CborValue.Value)) : Bool = key == #majorType3("client_key")); - switch (client_key_key_value) { - case (?(_, #majorType5(raw_client_key))) { - let client_principal_value = Array.find(raw_client_key, func((key, _) : (CborValue.Value, CborValue.Value)) : Bool = key == #majorType3("client_principal")); - let client_principal = switch (client_principal_value) { - case (?(_, #majorType2(client_principal_blob))) { - Principal.fromBlob( - Blob.fromArray(client_principal_blob) - ); - }; - case (_) { - return #Err("missing field `client_key.client_principal`"); - }; - }; - let client_nonce_value = Array.find(raw_client_key, func((key, _) : (CborValue.Value, CborValue.Value)) : Bool = key == #majorType3("client_nonce")); - let client_nonce = switch (client_nonce_value) { - case (?(_, #majorType0(client_nonce))) { - client_nonce; - }; - case (_) { - return #Err("missing field `client_key.client_nonce`"); - }; - }; - - { - client_principal; - client_nonce; - }; - }; - case (_) { - return #Err("missing field `client_key`"); - }; - }; - }; - sequence_num = do { - let sequence_num_key_value = Array.find(raw_content, func((key, _) : (CborValue.Value, CborValue.Value)) : Bool = key == #majorType3("sequence_num")); - switch (sequence_num_key_value) { - case (?(_, #majorType0(sequence_num))) { - sequence_num; - }; - case (_) { - return #Err("missing field `sequence_num`"); - }; - }; - }; - timestamp = do { - let timestamp_key_value = Array.find(raw_content, func((key, _) : (CborValue.Value, CborValue.Value)) : Bool = key == #majorType3("timestamp")); - switch (timestamp_key_value) { - case (?(_, #majorType0(timestamp))) { - timestamp; - }; - case (_) { - return #Err("missing field `timestamp`"); - }; - }; - }; - is_service_message = do { - let is_service_message_key_value = Array.find(raw_content, func((key, _) : (CborValue.Value, CborValue.Value)) : Bool = key == #majorType3("is_service_message")); - switch (is_service_message_key_value) { - case (?(_, #majorType7(#bool(is_service_message)))) { - is_service_message; - }; - case (_) { - return #Err("missing field `is_service_message`"); - }; - }; - }; - content = do { - let content_key_value = Array.find(raw_content, func((key, _) : (CborValue.Value, CborValue.Value)) : Bool = key == #majorType3("message")); - switch (content_key_value) { - case (?(_, #majorType2(content_blob))) { - Blob.fromArray(content_blob); - }; - case (_) { - return #Err("missing field `content`"); - }; - }; - }; - }); - }; - case (_) { - #Err("invalid CBOR message content"); - }; - }; - }; - case (_) { - #Err("invalid CBOR message content"); - }; - }; - }; - }; - }; - - /// Element of the list of messages returned to the WS Gateway after polling. - public type CanisterOutputMessage = { - client_key : ClientKey; // The client that the gateway will forward the message to. - key : Text; // Key for certificate verification. - content : Blob; // The message to be relayed, that contains the application message. - }; - - /// List of messages returned to the WS Gateway after polling. - public type CanisterOutputCertifiedMessages = { - messages : [CanisterOutputMessage]; // List of messages. - cert : Blob; // cert+tree constitute the certificate for all returned messages. - tree : Blob; // cert+tree constitute the certificate for all returned messages. - }; - - type GatewayPrincipal = Principal; - - /// Contains data about the registered WS Gateway. - class RegisteredGateway(gw_principal : Principal) { - /// The principal of the gateway. - public var gateway_principal : Principal = gw_principal; - /// The queue of the messages that the gateway can poll. - public var messages_queue : List.List = List.nil(); - /// Keeps track of the nonce which: - /// - the WS Gateway uses to specify the first index of the certified messages to be returned when polling - /// - the client uses as part of the path in the Merkle tree in order to verify the certificate of the messages relayed by the WS Gateway - public var outgoing_message_nonce : Nat64 = INITIAL_OUTGOING_MESSAGE_NONCE; - - /// Resets the messages and nonce to the initial values. - public func reset() { - messages_queue := List.nil(); - outgoing_message_nonce := INITIAL_OUTGOING_MESSAGE_NONCE; - }; - - /// Increments the outgoing message nonce by 1. - public func increment_nonce() { - outgoing_message_nonce += 1; - }; - }; - - /// The metadata about a registered client. - class RegisteredClient(gw_principal : GatewayPrincipal) { - public var last_keep_alive_timestamp : Nat64 = get_current_time(); - public let gateway_principal : GatewayPrincipal = gw_principal; - - /// Gets the last keep alive timestamp. - public func get_last_keep_alive_timestamp() : Nat64 { - last_keep_alive_timestamp; - }; - - /// Set the last keep alive timestamp to the current time. - public func update_last_keep_alive_timestamp() { - last_keep_alive_timestamp := get_current_time(); - }; - }; - - type CanisterOpenMessageContent = { - client_key : ClientKey; - }; - - type CanisterAckMessageContent = { - last_incoming_sequence_num : Nat64; - }; - - type ClientKeepAliveMessageContent = { - last_incoming_sequence_num : Nat64; - }; - - type WebsocketServiceMessageContent = { - #OpenMessage : CanisterOpenMessageContent; - #AckMessage : CanisterAckMessageContent; - #KeepAliveMessage : ClientKeepAliveMessageContent; - }; - func encode_websocket_service_message_content(content : WebsocketServiceMessageContent) : Blob { - to_candid (content); - }; - func decode_websocket_service_message_content(bytes : Blob) : Result { - let decoded : ?WebsocketServiceMessageContent = from_candid (bytes); // traps if the bytes are not a valid candid message - return switch (decoded) { - case (?value) { #Ok(value) }; - case (null) { #Err("Error decoding service message content: unknown") }; - }; - }; - - /// Arguments passed to the `on_open` handler. - public type OnOpenCallbackArgs = { - client_principal : ClientPrincipal; - }; - /// Handler initialized by the canister and triggered by the CDK once the IC WebSocket connection - /// is established. - public type OnOpenCallback = (OnOpenCallbackArgs) -> async (); - - /// Arguments passed to the `on_message` handler. - /// The `message` argument is the message received from the client, serialized in Candid. - /// Use [`from_candid`] to deserialize the message. - /// - /// # Example - /// This example is the deserialize equivalent of the [`ws_send`]'s serialize one. - /// ```motoko - /// import IcWebSocketCdk "mo:ic-websocket-cdk"; - /// - /// actor MyCanister { - /// // ... - /// - /// type MyMessage = { - /// some_field: Text; - /// }; - /// - /// // initialize the CDK - /// - /// func on_message(args : IcWebSocketCdk.OnMessageCallbackArgs) : async () { - /// let received_message: ?MyMessage = from_candid(args.message); - /// switch (received_message) { - /// case (?received_message) { - /// Debug.print("Received message: some_field: " # received_message.some_field); - /// }; - /// case (invalid_arg) { - /// return #Err("invalid argument: " # debug_show (invalid_arg)); - /// }; - /// }; - /// }; - /// - /// // ... - /// } - /// ``` - public type OnMessageCallbackArgs = { - /// The principal of the client sending the message to the canister. - client_principal : ClientPrincipal; - /// The message received from the client, serialized in Candid. See [OnMessageCallbackArgs] for an example on how to deserialize the message. - message : Blob; - }; - /// Handler initialized by the canister and triggered by the CDK once a message is received by - /// the CDK. - public type OnMessageCallback = (OnMessageCallbackArgs) -> async (); - - /// Arguments passed to the `on_close` handler. - public type OnCloseCallbackArgs = { - client_principal : ClientPrincipal; - }; - /// Handler initialized by the canister and triggered by the CDK once the WS Gateway closes the - /// IC WebSocket connection. - public type OnCloseCallback = (OnCloseCallbackArgs) -> async (); - - //// FUNCTIONS //// - func get_current_time() : Nat64 { - Nat64.fromIntWrap(Time.now()); - }; - - /// Handlers initialized by the canister and triggered by the CDK. - public class WsHandlers( - init_on_open : ?OnOpenCallback, - init_on_message : ?OnMessageCallback, - init_on_close : ?OnCloseCallback, - ) { - var on_open : ?OnOpenCallback = init_on_open; - var on_message : ?OnMessageCallback = init_on_message; - var on_close : ?OnCloseCallback = init_on_close; - - public func call_on_open(args : OnOpenCallbackArgs) : async () { - switch (on_open) { - case (?callback) { - try { - await callback(args); - } catch (err) { - Logger.custom_print("Error calling on_open handler: " # Error.message(err)); - }; - }; - case (null) { - // Do nothing. - }; - }; - }; - - public func call_on_message(args : OnMessageCallbackArgs) : async () { - switch (on_message) { - case (?callback) { - try { - await callback(args); - } catch (err) { - Logger.custom_print("Error calling on_message handler: " # Error.message(err)); - }; - }; - case (null) { - // Do nothing. - }; - }; - }; - - public func call_on_close(args : OnCloseCallbackArgs) : async () { - switch (on_close) { - case (?callback) { - try { - await callback(args); - } catch (err) { - Logger.custom_print("Error calling on_close handler: " # Error.message(err)); - }; - }; - case (null) { - // Do nothing. - }; - }; - }; - }; - - /// IC WebSocket class that holds the internal state of the IC WebSocket. - /// - /// Arguments: - /// - /// - `gateway_principals`: An array of the principals of the WS Gateways that are allowed to poll the canister. - /// - /// **Note**: you should only pass an instance of this class to the IcWebSocket class constructor, without using the methods or accessing the fields directly. - public class IcWebSocketState(gateway_principals : [Text]) = self { - //// STATE //// - /// Maps the client's key to the client metadata - var REGISTERED_CLIENTS = HashMap.HashMap(0, areClientKeysEqual, hashClientKey); - /// Maps the client's principal to the current client key - var CURRENT_CLIENT_KEY_MAP = HashMap.HashMap(0, Principal.equal, Principal.hash); - /// Keeps track of all the clients for which we're waiting for a keep alive message. - var CLIENTS_WAITING_FOR_KEEP_ALIVE : TrieSet.Set = TrieSet.empty(); - /// Maps the client's public key to the sequence number to use for the next outgoing message (to that client). - var OUTGOING_MESSAGE_TO_CLIENT_NUM_MAP = HashMap.HashMap(0, areClientKeysEqual, hashClientKey); - /// Maps the client's public key to the expected sequence number of the next incoming message (from that client). - var INCOMING_MESSAGE_FROM_CLIENT_NUM_MAP = HashMap.HashMap(0, areClientKeysEqual, hashClientKey); - /// Keeps track of the Merkle tree used for certified queries - var CERT_TREE_STORE : CertTree.Store = CertTree.newStore(); - var CERT_TREE = CertTree.Ops(CERT_TREE_STORE); - /// Keeps track of the principal of the WS Gateway which polls the canister - var REGISTERED_GATEWAYS = do { - let map = HashMap.HashMap(0, Principal.equal, Principal.hash); - - for (gateway_principal_text in Iter.fromArray(gateway_principals)) { - let gateway_principal = Principal.fromText(gateway_principal_text); - map.put(gateway_principal, RegisteredGateway(gateway_principal)); - }; - - map; - }; - /// The acknowledgement active timer. - var ACK_TIMER : ?Timer.TimerId = null; - /// The keep alive active timer. - var KEEP_ALIVE_TIMER : ?Timer.TimerId = null; - - //// FUNCTIONS //// - /// Resets all state to the initial state. - public func reset_internal_state(handlers : WsHandlers) : async () { - // for each client, call the on_close handler before clearing the map - for (client_key in REGISTERED_CLIENTS.keys()) { - await remove_client(client_key, handlers); - }; - - // make sure all the maps are cleared - CURRENT_CLIENT_KEY_MAP := HashMap.HashMap(0, Principal.equal, Principal.hash); - CLIENTS_WAITING_FOR_KEEP_ALIVE := TrieSet.empty(); - OUTGOING_MESSAGE_TO_CLIENT_NUM_MAP := HashMap.HashMap(0, areClientKeysEqual, hashClientKey); - INCOMING_MESSAGE_FROM_CLIENT_NUM_MAP := HashMap.HashMap(0, areClientKeysEqual, hashClientKey); - CERT_TREE_STORE := CertTree.newStore(); - CERT_TREE := CertTree.Ops(CERT_TREE_STORE); - for (g in REGISTERED_GATEWAYS.vals()) { - g.reset(); - }; - }; - - public func get_outgoing_message_nonce(gateway_principal : GatewayPrincipal) : Result { - switch (get_registered_gateway(gateway_principal)) { - case (#Ok(registered_gateway)) { - #Ok(registered_gateway.outgoing_message_nonce); - }; - case (#Err(err)) { #Err(err) }; - }; - }; - - public func increment_outgoing_message_nonce(gateway_principal : GatewayPrincipal) { - switch (REGISTERED_GATEWAYS.get(gateway_principal)) { - case (?registered_gateway) { - registered_gateway.increment_nonce(); - }; - case (null) { - Prelude.unreachable(); // we should always have a registered gateway at this point - }; - }; - }; - - func insert_client(client_key : ClientKey, new_client : RegisteredClient) { - CURRENT_CLIENT_KEY_MAP.put(client_key.client_principal, client_key); - REGISTERED_CLIENTS.put(client_key, new_client); - }; - - public func is_client_registered(client_key : ClientKey) : Bool { - Option.isSome(REGISTERED_CLIENTS.get(client_key)); - }; - - public func get_client_key_from_principal(client_principal : ClientPrincipal) : Result { - switch (CURRENT_CLIENT_KEY_MAP.get(client_principal)) { - case (?client_key) #Ok(client_key); - case (null) #Err("client with principal " # Principal.toText(client_principal) # " doesn't have an open connection"); - }; - }; - - public func check_registered_client(client_key : ClientKey) : Result<(), Text> { - if (not is_client_registered(client_key)) { - return #Err("client with key " # clientKeyToText(client_key) # " doesn't have an open connection"); - }; - - #Ok; - }; - - public func get_gateway_principal_from_registered_client(client_key : ClientKey) : GatewayPrincipal { - switch (REGISTERED_CLIENTS.get(client_key)) { - case (?registered_client) { registered_client.gateway_principal }; - case (null) { - Prelude.unreachable(); // the value exists because we checked that the client is registered - }; - }; - }; - - func add_client_to_wait_for_keep_alive(client_key : ClientKey) { - CLIENTS_WAITING_FOR_KEEP_ALIVE := TrieSet.put(CLIENTS_WAITING_FOR_KEEP_ALIVE, client_key, hashClientKey(client_key), areClientKeysEqual); - }; - - public func get_registered_gateway(gateway_principal : GatewayPrincipal) : Result { - switch (REGISTERED_GATEWAYS.get(gateway_principal)) { - case (?registered_gateway) { #Ok(registered_gateway) }; - case (null) { - #Err("no gateway registered with principal " # Principal.toText(gateway_principal)); - }; - }; - }; - - func init_outgoing_message_to_client_num(client_key : ClientKey) { - OUTGOING_MESSAGE_TO_CLIENT_NUM_MAP.put(client_key, INITIAL_CANISTER_SEQUENCE_NUM); - }; - - public func get_outgoing_message_to_client_num(client_key : ClientKey) : Result { - switch (OUTGOING_MESSAGE_TO_CLIENT_NUM_MAP.get(client_key)) { - case (?num) #Ok(num); - case (null) #Err("outgoing message to client num not initialized for client"); - }; - }; - - public func increment_outgoing_message_to_client_num(client_key : ClientKey) : Result<(), Text> { - let num = get_outgoing_message_to_client_num(client_key); - switch (num) { - case (#Ok(num)) { - OUTGOING_MESSAGE_TO_CLIENT_NUM_MAP.put(client_key, num + 1); - #Ok; - }; - case (#Err(error)) #Err(error); - }; - }; - - func init_expected_incoming_message_from_client_num(client_key : ClientKey) { - INCOMING_MESSAGE_FROM_CLIENT_NUM_MAP.put(client_key, INITIAL_CLIENT_SEQUENCE_NUM); - }; - - public func get_expected_incoming_message_from_client_num(client_key : ClientKey) : Result { - switch (INCOMING_MESSAGE_FROM_CLIENT_NUM_MAP.get(client_key)) { - case (?num) #Ok(num); - case (null) #Err("expected incoming message num not initialized for client"); - }; - }; - - public func increment_expected_incoming_message_from_client_num(client_key : ClientKey) : Result<(), Text> { - let num = get_expected_incoming_message_from_client_num(client_key); - switch (num) { - case (#Ok(num)) { - INCOMING_MESSAGE_FROM_CLIENT_NUM_MAP.put(client_key, num + 1); - #Ok; - }; - case (#Err(error)) #Err(error); - }; - }; - - public func add_client(client_key : ClientKey, new_client : RegisteredClient) { - // insert the client in the map - insert_client(client_key, new_client); - // initialize incoming client's message sequence number to 1 - init_expected_incoming_message_from_client_num(client_key); - // initialize outgoing message sequence number to 0 - init_outgoing_message_to_client_num(client_key); - }; - - public func remove_client(client_key : ClientKey, handlers : WsHandlers) : async () { - CLIENTS_WAITING_FOR_KEEP_ALIVE := TrieSet.delete(CLIENTS_WAITING_FOR_KEEP_ALIVE, client_key, hashClientKey(client_key), areClientKeysEqual); - CURRENT_CLIENT_KEY_MAP.delete(client_key.client_principal); - REGISTERED_CLIENTS.delete(client_key); - OUTGOING_MESSAGE_TO_CLIENT_NUM_MAP.delete(client_key); - INCOMING_MESSAGE_FROM_CLIENT_NUM_MAP.delete(client_key); - - await handlers.call_on_close({ - client_principal = client_key.client_principal; - }); - }; - - public func format_message_for_gateway_key(gateway_principal : Principal, nonce : Nat64) : Text { - let nonce_to_text = do { - // prints the nonce with 20 padding zeros - var nonce_str = Nat64.toText(nonce); - let padding : Nat = 20 - Text.size(nonce_str); - if (padding > 0) { - for (i in Iter.range(0, padding - 1)) { - nonce_str := "0" # nonce_str; - }; - }; - - nonce_str; - }; - Principal.toText(gateway_principal) # "_" # nonce_to_text; - }; - - func get_gateway_messages_queue(gateway_principal : Principal) : List.List { - switch (REGISTERED_GATEWAYS.get(gateway_principal)) { - case (?registered_gateway) { - registered_gateway.messages_queue; - }; - case (null) { - Prelude.unreachable(); // the value exists because we just checked that the gateway is registered - }; - }; - }; - - func get_messages_for_gateway_range(gateway_principal : Principal, nonce : Nat64, max_number_of_returned_messages : Nat) : (Nat, Nat) { - let messages_queue = get_gateway_messages_queue(gateway_principal); - - let queue_len = List.size(messages_queue); - - if (nonce == 0 and queue_len > 0) { - // this is the case in which the poller on the gateway restarted - // the range to return is end:last index and start: max(end - max_number_of_returned_messages, 0) - let start_index = if (queue_len > max_number_of_returned_messages) { - (queue_len - max_number_of_returned_messages) : Nat; - } else { - 0; - }; - - return (start_index, queue_len); - }; - - // smallest key used to determine the first message from the queue which has to be returned to the WS Gateway - let smallest_key = format_message_for_gateway_key(gateway_principal, nonce); - // partition the queue at the message which has the key with the nonce specified as argument to get_cert_messages - let start_index = do { - let partitions = List.partition( - messages_queue, - func(el : CanisterOutputMessage) : Bool { - Text.less(el.key, smallest_key); - }, - ); - List.size(partitions.0); - }; - var end_index = queue_len; - if (((end_index - start_index) : Nat) > max_number_of_returned_messages) { - end_index := start_index + max_number_of_returned_messages; - }; - - (start_index, end_index); - }; - - func get_messages_for_gateway(gateway_principal : Principal, start_index : Nat, end_index : Nat) : List.List { - let messages_queue = get_gateway_messages_queue(gateway_principal); - - var messages : List.List = List.nil(); - for (i in Iter.range(start_index, end_index - 1)) { - let message = List.get(messages_queue, i); - switch (message) { - case (?message) { - messages := List.push(message, messages); - }; - case (null) { - Prelude.unreachable(); // the value exists because this function is called only after partitioning the queue - }; - }; - }; - - List.reverse(messages); - }; - - public func get_cert_messages(gateway_principal : Principal, nonce : Nat64, max_number_of_returned_messages : Nat) : CanisterWsGetMessagesResult { - let (start_index, end_index) = get_messages_for_gateway_range(gateway_principal, nonce, max_number_of_returned_messages); - let messages = get_messages_for_gateway(gateway_principal, start_index, end_index); - - if (List.isNil(messages)) { - return #Ok({ - messages = []; - cert = Blob.fromArray([]); - tree = Blob.fromArray([]); - }); - }; - - let keys = List.map( - messages, - func(message : CanisterOutputMessage) : CertTree.Path { - [Text.encodeUtf8(message.key)]; - }, - ); - let (cert, tree) = get_cert_for_range(List.toIter(keys)); - - #Ok({ - messages = List.toArray(messages); - cert = cert; - tree = tree; - }); - }; - - public func is_registered_gateway(principal : Principal) : Bool { - switch (REGISTERED_GATEWAYS.get(principal)) { - case (?_) { true }; - case (null) { false }; - }; - }; - - /// Checks if the caller of the method is the same as the one that was registered during the initialization of the CDK - public func check_is_registered_gateway(input_principal : Principal) : Result<(), Text> { - if (not is_registered_gateway(input_principal)) { - return #Err("principal is not one of the authorized gateways that have been registered during CDK initialization"); - }; - - #Ok; - }; - - func labeledHash(l : Blob, content : CertTree.Hash) : CertTree.Hash { - let d = Sha256.Digest(#sha256); - d.writeBlob("\13ic-hashtree-labeled"); - d.writeBlob(l); - d.writeBlob(content); - d.sum(); - }; - - public func put_cert_for_message(key : Text, value : Blob) { - let root_hash = do { - CERT_TREE.put([Text.encodeUtf8(key)], Sha256.fromBlob(#sha256, value)); - labeledHash(LABEL_WEBSOCKET, CERT_TREE.treeHash()); - }; - - CertifiedData.set(root_hash); - }; - - func get_cert_for_range(keys : Iter.Iter) : (Blob, Blob) { - let witness = CERT_TREE.reveals(keys); - let tree : CertTree.Witness = #labeled(LABEL_WEBSOCKET, witness); - - switch (CertifiedData.getCertificate()) { - case (?cert) { - let tree_blob = CERT_TREE.encodeWitness(tree); - (cert, tree_blob); - }; - case (null) Prelude.unreachable(); - }; - }; - - func put_ack_timet_id(timer_id : Timer.TimerId) { - ACK_TIMER := ?timer_id; - }; - - func reset_ack_timer() { - switch (ACK_TIMER) { - case (?value) { - Timer.cancelTimer(value); - ACK_TIMER := null; - }; - case (null) { - // Do nothing - }; - }; - }; - - func put_keep_alive_timer_id(timer_id : Timer.TimerId) { - KEEP_ALIVE_TIMER := ?timer_id; - }; - - func reset_keep_alive_timer() { - switch (KEEP_ALIVE_TIMER) { - case (?value) { - Timer.cancelTimer(value); - KEEP_ALIVE_TIMER := null; - }; - case (null) { - // Do nothing - }; - }; - }; - - public func reset_timers() { - reset_ack_timer(); - reset_keep_alive_timer(); - }; - - /// Start an interval to send an acknowledgement messages to the clients. - /// - /// The interval callback is [send_ack_to_clients_timer_callback]. After the callback is executed, - /// a timer is scheduled to check if the registered clients have sent a keep alive message. - public func schedule_send_ack_to_clients(send_ack_interval_ms : Nat64, keep_alive_timeout_ms : Nat64, handlers : WsHandlers) { - let timer_id = Timer.recurringTimer( - #nanoseconds(Nat64.toNat(send_ack_interval_ms) * 1_000_000), - func() : async () { - send_ack_to_clients_timer_callback(); - - schedule_check_keep_alive(keep_alive_timeout_ms, handlers); - }, - ); - - put_ack_timet_id(timer_id); - }; - - /// Schedules a timer to check if the clients (only those to which an ack message was sent) have sent a keep alive message - /// after receiving an acknowledgement message. - /// - /// The timer callback is [check_keep_alive_timer_callback]. - func schedule_check_keep_alive(keep_alive_timeout_ms : Nat64, handlers : WsHandlers) { - let timer_id = Timer.setTimer( - #nanoseconds(Nat64.toNat(keep_alive_timeout_ms) * 1_000_000), - func() : async () { - await check_keep_alive_timer_callback(keep_alive_timeout_ms, handlers); - }, - ); - - put_keep_alive_timer_id(timer_id); - }; - - /// Sends an acknowledgement message to the client. - /// The message contains the current incoming message sequence number for that client, - /// so that the client knows that all the messages it sent have been received by the canister. - func send_ack_to_clients_timer_callback() { - for (client_key in REGISTERED_CLIENTS.keys()) { - switch (get_expected_incoming_message_from_client_num(client_key)) { - case (#Ok(expected_incoming_sequence_num)) { - let ack_message : CanisterAckMessageContent = { - // the expected sequence number is 1 more because it's incremented when a message is received - last_incoming_sequence_num = expected_incoming_sequence_num - 1; - }; - let message : WebsocketServiceMessageContent = #AckMessage(ack_message); - switch (send_service_message_to_client(self, client_key, message)) { - case (#Err(err)) { - // TODO: decide what to do when sending the message fails - - Logger.custom_print("[ack-to-clients-timer-cb]: Error sending ack message to client" # clientKeyToText(client_key) # ": " # err); - }; - case (#Ok(_)) { - add_client_to_wait_for_keep_alive(client_key); - }; - }; - }; - case (#Err(err)) { - // TODO: decide what to do when getting the expected incoming sequence number fails (shouldn't happen) - Logger.custom_print("[ack-to-clients-timer-cb]: Error getting expected incoming sequence number for client" # clientKeyToText(client_key) # ": " # err); - }; - }; - }; - - Logger.custom_print("[ack-to-clients-timer-cb]: Sent ack messages to all clients"); - }; - - /// Checks if the clients for which we are waiting for keep alive have sent a keep alive message. - /// If a client has not sent a keep alive message, it is removed from the connected clients. - func check_keep_alive_timer_callback(keep_alive_timeout_ms : Nat64, handlers : WsHandlers) : async () { - for (client_key in Array.vals(TrieSet.toArray(CLIENTS_WAITING_FOR_KEEP_ALIVE))) { - let client_metadata = REGISTERED_CLIENTS.get(client_key); - switch (client_metadata) { - case (?client_metadata) { - let last_keep_alive = client_metadata.get_last_keep_alive_timestamp(); - - if (get_current_time() - last_keep_alive > keep_alive_timeout_ms * 1_000_000) { - await remove_client(client_key, handlers); - - Logger.custom_print("[check-keep-alive-timer-cb]: Client " # clientKeyToText(client_key) # " has not sent a keep alive message in the last " # debug_show (keep_alive_timeout_ms) # " ms and has been removed"); - }; - }; - case (null) { - // Do nothing - }; - }; - }; - - Logger.custom_print("[check-keep-alive-timer-cb]: Checked keep alive messages for all clients"); - }; - - public func update_last_keep_alive_timestamp_for_client(client_key : ClientKey) { - let client = REGISTERED_CLIENTS.get(client_key); - switch (client) { - case (?client_metadata) { - client_metadata.update_last_keep_alive_timestamp(); - REGISTERED_CLIENTS.put(client_key, client_metadata); - }; - case (null) { - // Do nothing. - }; - }; - }; - }; - - /// Internal function used to put the messages in the outgoing messages queue and certify them. - func _ws_send(ws_state : IcWebSocketState, client_principal : ClientPrincipal, msg_bytes : Blob, is_service_message : Bool) : CanisterWsSendResult { - // better to get the client key here to not replicate the same logic across functions - let client_key = switch (ws_state.get_client_key_from_principal(client_principal)) { - case (#Err(err)) { - return #Err(err); - }; - case (#Ok(client_key)) { - client_key; - }; - }; - - // check if the client is registered - switch (ws_state.check_registered_client(client_key)) { - case (#Err(err)) { - return #Err(err); - }; - case (_) { - // do nothing - }; - }; - - // get the principal of the gateway that is polling the canister - let gateway_principal = ws_state.get_gateway_principal_from_registered_client(client_key); - switch (ws_state.check_is_registered_gateway(gateway_principal)) { - case (#Err(err)) { - return #Err(err); - }; - case (_) { - // do nothing - }; - }; - - // the nonce in key is used by the WS Gateway to determine the message to start in the polling iteration - // the key is also passed to the client in order to validate the body of the certified message - let outgoing_message_nonce = switch (ws_state.get_outgoing_message_nonce(gateway_principal)) { - case (#Err(err)) { - return #Err(err); - }; - case (#Ok(nonce)) { - nonce; - }; - }; - let message_key = ws_state.format_message_for_gateway_key(gateway_principal, outgoing_message_nonce); - - // increment the nonce for the next message - ws_state.increment_outgoing_message_nonce(gateway_principal); - - // increment the sequence number for the next message to the client - switch (ws_state.increment_outgoing_message_to_client_num(client_key)) { - case (#Err(err)) { - return #Err(err); - }; - case (_) { - // do nothing - }; - }; - - let sequence_num = switch (ws_state.get_outgoing_message_to_client_num(client_key)) { - case (#Err(err)) { - return #Err(err); - }; - case (#Ok(sequence_num)) { - sequence_num; - }; - }; - - let websocket_message : WebsocketMessage = { - client_key; - sequence_num; - timestamp = get_current_time(); - is_service_message; - content = msg_bytes; - }; - - // CBOR serialize message of type WebsocketMessage - let message_content = switch (encode_websocket_message(websocket_message)) { - case (#Err(err)) { - return #Err(err); - }; - case (#Ok(content)) { - content; - }; - }; - - // certify data - ws_state.put_cert_for_message(message_key, message_content); - - switch (ws_state.get_registered_gateway(gateway_principal)) { - case (#Ok(registered_gateway)) { - // messages in the queue are inserted with contiguous and increasing nonces - // (from beginning to end of the queue) as ws_send is called sequentially, the nonce - // is incremented by one in each call, and the message is pushed at the end of the queue - registered_gateway.messages_queue := List.append( - registered_gateway.messages_queue, - List.fromArray([{ - client_key; - content = message_content; - key = message_key; - }]), - ); - }; - case (_) { - Prelude.unreachable(); // the value exists because we just checked that the gateway is registered - }; - }; - #Ok; - }; - - /// Sends a message to the client. The message must already be serialized **using Candid**. - /// Use [`to_candid`] to serialize the message. - /// - /// Under the hood, the message is certified and added to the queue of messages - /// that the WS Gateway will poll in the next iteration. - /// - /// # Example - /// This example is the serialize equivalent of the [`OnMessageCallbackArgs`]'s deserialize one. - /// ```motoko - /// import IcWebSocketCdk "mo:ic-websocket-cdk"; - /// - /// actor MyCanister { - /// // ... - /// - /// type MyMessage = { - /// some_field: Text; - /// }; - /// - /// // initialize the CDK - /// - /// // at some point in your code - /// let msg : MyMessage = { - /// some_field: "Hello, World!"; - /// }; - /// - /// IcWebSocketCdk.ws_send(ws_state, client_principal, to_candid(msg)); - /// } - /// ``` - public func ws_send(ws_state : IcWebSocketState, client_principal : ClientPrincipal, msg_bytes : Blob) : async CanisterWsSendResult { - _ws_send(ws_state, client_principal, msg_bytes, false); - }; - - func send_service_message_to_client(ws_state : IcWebSocketState, client_key : ClientKey, message : WebsocketServiceMessageContent) : Result<(), Text> { - let message_bytes = encode_websocket_service_message_content(message); - _ws_send(ws_state, client_key.client_principal, message_bytes, true); - }; - - /// Parameters for the IC WebSocket CDK initialization. - /// - /// Arguments: - /// - /// - `init_handlers`: Handlers initialized by the canister and triggered by the CDK. - /// - `init_max_number_of_returned_messages`: Maximum number of returned messages. Defaults to `10` if null. - /// - `init_send_ack_interval_ms`: Send ack interval in milliseconds. Defaults to `60_000` (60 seconds) if null. - /// - `init_keep_alive_timeout_ms`: Keep alive timeout in milliseconds. Defaults to `10_000` (10 seconds) if null. - public class WsInitParams( - init_handlers : WsHandlers, - init_max_number_of_returned_messages : ?Nat, - init_send_ack_interval_ms : ?Nat64, - init_keep_alive_timeout_ms : ?Nat64, - ) { - /// The callback handlers for the WebSocket. - public var handlers : WsHandlers = init_handlers; - /// The maximum number of messages to be returned in a polling iteration. - /// Defaults to `10`. - public var max_number_of_returned_messages : Nat = switch (init_max_number_of_returned_messages) { - case (?value) { value }; - case (null) { DEFAULT_MAX_NUMBER_OF_RETURNED_MESSAGES }; - }; - /// The interval at which to send an acknowledgement message to the client, - /// so that the client knows that all the messages it sent have been received by the canister (in milliseconds). - /// - /// Must be greater than `keep_alive_timeout_ms`. - /// - /// Defaults to `60_000` (60 seconds). - public var send_ack_interval_ms : Nat64 = switch (init_send_ack_interval_ms) { - case (?value) { value }; - case (null) { DEFAULT_SEND_ACK_INTERVAL_MS }; - }; - /// The delay to wait for the client to send a keep alive after receiving an acknowledgement (in milliseconds). - /// - /// Must be lower than `send_ack_interval_ms`. - /// - /// Defaults to `10_000` (10 seconds). - public var keep_alive_timeout_ms : Nat64 = switch (init_keep_alive_timeout_ms) { - case (?value) { value }; - case (null) { DEFAULT_CLIENT_KEEP_ALIVE_TIMEOUT_MS }; - }; - - public func get_handlers() : WsHandlers { - return handlers; - }; - - /// Checks the validity of the timer parameters. - /// `send_ack_interval_ms` must be greater than `keep_alive_timeout_ms`. - /// - /// # Traps - /// If `send_ack_interval_ms` < `keep_alive_timeout_ms`. - public func check_validity() { - if (keep_alive_timeout_ms > send_ack_interval_ms) { - Debug.trap("send_ack_interval_ms must be greater than keep_alive_timeout_ms"); - }; - }; - }; + // re-export types + public type CanisterWsCloseArguments = Types.CanisterWsCloseArguments; + public type CanisterWsCloseResult = Types.CanisterWsCloseResult; + public type CanisterWsGetMessagesArguments = Types.CanisterWsGetMessagesArguments; + public type CanisterWsGetMessagesResult = Types.CanisterWsGetMessagesResult; + public type CanisterWsMessageArguments = Types.CanisterWsMessageArguments; + public type CanisterWsMessageResult = Types.CanisterWsMessageResult; + public type CanisterWsOpenArguments = Types.CanisterWsOpenArguments; + public type CanisterWsOpenResult = Types.CanisterWsOpenResult; + public type CanisterWsSendResult = Types.CanisterWsSendResult; + public type ClientPrincipal = Types.ClientPrincipal; + public type OnCloseCallbackArgs = Types.OnCloseCallbackArgs; + public type OnMessageCallbackArgs = Types.OnMessageCallbackArgs; + public type OnOpenCallbackArgs = Types.OnOpenCallbackArgs; + + // these classes cannot be re-exported + type WsHandlers = Types.WsHandlers; + type WsInitParams = Types.WsInitParams; + type IcWebSocketState = State.IcWebSocketState; /// The IC WebSocket instance. /// + /// **Note**: Restarts the acknowledgement timers under the hood. + /// /// # Traps /// If the parameters are invalid. See [`WsInitParams::check_validity`] for more details. - public class IcWebSocket(init_ws_state : IcWebSocketState, params : WsInitParams) { + public class IcWebSocket(init_ws_state : IcWebSocketState, params : WsInitParams, handlers : WsHandlers) { /// The state of the IC WebSocket. private var WS_STATE : IcWebSocketState = init_ws_state; - /// The callback handlers for the WebSocket. - private var HANDLERS : WsHandlers = params.get_handlers(); // the equivalent of the [init] function for the Rust CDK do { // check if the parameters are valid params.check_validity(); - // reset initial timers - WS_STATE.reset_timers(); + // cancel initial timers + Timers.cancel_timers(WS_STATE); // schedule a timer that will send an acknowledgement message to clients - WS_STATE.schedule_send_ack_to_clients(params.send_ack_interval_ms, params.keep_alive_timeout_ms, HANDLERS); - }; - - /// Resets the internal state of the IC WebSocket CDK. - /// - /// **Note:** You should only call this function in tests. - public func wipe() : async () { - await WS_STATE.reset_internal_state(HANDLERS); - - Logger.custom_print("Internal state has been wiped!"); + Timers.schedule_send_ack_to_clients(WS_STATE, params.send_ack_interval_ms, params.keep_alive_timeout_ms, handlers); }; - /// Handles the WS connection open event received from the WS Gateway - /// - /// WS Gateway relays the first message sent by the client together with its signature - /// to prove that the first message is actually coming from the same client that registered its public key - /// beforehand by calling the [ws_register] method. + /// Handles the WS connection open event sent by the client and relayed by the Gateway. public func ws_open(caller : Principal, args : CanisterWsOpenArguments) : async CanisterWsOpenResult { // anonymous clients cannot open a connection if (Principal.isAnonymous(caller)) { - return #Err("anonymous principal cannot open a connection"); + return #Err(Errors.to_string(#AnonymousPrincipalNotAllowed)); }; - // avoid gateway opening a connection for its own principal - if (WS_STATE.is_registered_gateway(caller)) { - return #Err("caller is the registered gateway which can't open a connection for itself"); - }; - - let client_key : ClientKey = { + let client_key : Types.ClientKey = { client_principal = caller; client_nonce = args.client_nonce; }; // check if client is not registered yet - if (WS_STATE.is_client_registered(client_key)) { - return #Err("client with key " # clientKeyToText(client_key) # " already has an open connection"); + // by swapping the result of the check_registered_client_exists function + switch (WS_STATE.check_registered_client_exists(client_key)) { + case (#Err(err)) { + // do nothing + }; + case (#Ok(_)) { + return #Err(Errors.to_string(#ClientKeyAlreadyConnected({ client_key }))); + }; }; // initialize client maps - let new_client = RegisteredClient(args.gateway_principal); + let new_client = Types.RegisteredClient(args.gateway_principal); WS_STATE.add_client(client_key, new_client); - let open_message : CanisterOpenMessageContent = { + let open_message : Types.CanisterOpenMessageContent = { client_key; }; - let message : WebsocketServiceMessageContent = #OpenMessage(open_message); - switch (send_service_message_to_client(WS_STATE, client_key, message)) { + let message : Types.WebsocketServiceMessageContent = #OpenMessage(open_message); + switch (WS_STATE.send_service_message_to_client(client_key, message)) { case (#Err(err)) { return #Err(err); }; @@ -1200,7 +115,7 @@ module { }; }; - await HANDLERS.call_on_open({ + await handlers.call_on_open({ client_principal = client_key.client_principal; }); @@ -1209,7 +124,20 @@ module { /// Handles the WS connection close event received from the WS Gateway. public func ws_close(caller : Principal, args : CanisterWsCloseArguments) : async CanisterWsCloseResult { - switch (WS_STATE.check_is_registered_gateway(caller)) { + let gateway_principal = caller; + + // check if the gateway is registered + switch (WS_STATE.check_is_gateway_registered(gateway_principal)) { + case (#Err(err)) { + return #Err(err); + }; + case (_) { + // do nothing + }; + }; + + // check if client registered itself by calling ws_open + switch (WS_STATE.check_registered_client_exists(args.client_key)) { case (#Err(err)) { return #Err(err); }; @@ -1218,7 +146,8 @@ module { }; }; - switch (WS_STATE.check_registered_client(args.client_key)) { + // check if the client is registered to the gateway that is closing the connection + switch (WS_STATE.check_client_registered_to_gateway(args.client_key, gateway_principal)) { case (#Err(err)) { return #Err(err); }; @@ -1227,7 +156,7 @@ module { }; }; - await WS_STATE.remove_client(args.client_key, HANDLERS); + await WS_STATE.remove_client(args.client_key, handlers); #Ok; }; @@ -1258,6 +187,7 @@ module { /// } /// ``` public func ws_message(caller : Principal, args : CanisterWsMessageArguments, _msg_type : ?Any) : async CanisterWsMessageResult { + let client_principal = caller; // check if client registered its principal by calling ws_open let registered_client_key = switch (WS_STATE.get_client_key_from_principal(caller)) { case (#Err(err)) { @@ -1277,8 +207,8 @@ module { } = args.msg; // check if the client key is correct - if (not areClientKeysEqual(registered_client_key, client_key)) { - return #Err("client with principal " #Principal.toText(caller) # " has a different key than the one used in the message"); + if (not Types.areClientKeysEqual(registered_client_key, client_key)) { + return #Err(Errors.to_string(#ClientKeyMessageMismatch({ client_key }))); }; let expected_sequence_num = switch (WS_STATE.get_expected_incoming_message_from_client_num(client_key)) { @@ -1292,14 +222,8 @@ module { // check if the incoming message has the expected sequence number if (sequence_num != expected_sequence_num) { - await WS_STATE.remove_client(client_key, HANDLERS); - return #Err( - "incoming client's message does not have the expected sequence number. Expected: " # - Nat64.toText(expected_sequence_num) - # ", actual: " # - Nat64.toText(sequence_num) - # ". Client removed." - ); + await WS_STATE.remove_client(client_key, handlers); + return #Err(Errors.to_string(#IncomingSequenceNumberWrong({ expected_sequence_num; actual_sequence_num = sequence_num }))); }; // increase the expected sequence number by 1 switch (WS_STATE.increment_expected_incoming_message_from_client_num(client_key)) { @@ -1312,60 +236,70 @@ module { }; if (is_service_message) { - return await handle_received_service_message(client_key, content); + return await WS_STATE.handle_received_service_message(client_key, content); }; - await HANDLERS.call_on_message({ + await handlers.call_on_message({ client_principal = client_key.client_principal; message = content; }); - #Ok; }; /// Returns messages to the WS Gateway in response of a polling iteration. public func ws_get_messages(caller : Principal, args : CanisterWsGetMessagesArguments) : CanisterWsGetMessagesResult { - // check if the caller of this method is the WS Gateway that has been set during the initialization of the SDK - switch (WS_STATE.check_is_registered_gateway(caller)) { - case (#Err(err)) { - #Err(err); - }; - case (_) { - WS_STATE.get_cert_messages(caller, args.nonce, params.max_number_of_returned_messages); - }; + let gateway_principal = caller; + if (not WS_STATE.is_registered_gateway(gateway_principal)) { + return WS_STATE.get_cert_messages_empty(); }; + + WS_STATE.get_cert_messages(caller, args.nonce, params.max_number_of_returned_messages); }; /// Sends a message to the client. See [IcWebSocketCdk.ws_send] function for reference. public func send(client_principal : ClientPrincipal, msg_bytes : Blob) : async CanisterWsSendResult { - await ws_send(WS_STATE, client_principal, msg_bytes); + WS_STATE._ws_send_to_client_principal(client_principal, msg_bytes); }; - func handle_received_service_message(client_key : ClientKey, content : Blob) : async Result<(), Text> { - let message_content = switch (decode_websocket_service_message_content(content)) { - case (#Err(err)) { - return #Err(err); - }; - case (#Ok(message_content)) { - message_content; - }; - }; + /// Resets the internal state of the IC WebSocket CDK. + /// + /// **Note:** You should only call this function in tests. + public func wipe() : async () { + await WS_STATE.reset_internal_state(handlers); - switch (message_content) { - case (#KeepAliveMessage(keep_alive_message)) { - handle_keep_alive_client_message(client_key, keep_alive_message); - #Ok; - }; - case (_) { - return #Err("Invalid received service message"); - }; - }; + Utils.custom_print("Internal state has been wiped!"); }; + }; - func handle_keep_alive_client_message(client_key : ClientKey, _keep_alive_message : ClientKeepAliveMessageContent) { - // TODO: delete messages from the queue that have been acknowledged by the client - - WS_STATE.update_last_keep_alive_timestamp_for_client(client_key); - }; + /// Sends a message to the client. The message must already be serialized **using Candid**. + /// Use [`to_candid`] to serialize the message. + /// + /// Under the hood, the message is certified and added to the queue of messages + /// that the WS Gateway will poll in the next iteration. + /// + /// # Example + /// This example is the serialize equivalent of the [`OnMessageCallbackArgs`]'s deserialize one. + /// ```motoko + /// import IcWebSocketCdk "mo:ic-websocket-cdk"; + /// + /// actor MyCanister { + /// // ... + /// + /// type MyMessage = { + /// some_field: Text; + /// }; + /// + /// // initialize the CDK + /// + /// // at some point in your code + /// let msg : MyMessage = { + /// some_field: "Hello, World!"; + /// }; + /// + /// IcWebSocketCdk.ws_send(ws_state, client_principal, to_candid(msg)); + /// } + /// ``` + public func ws_send(ws_state : IcWebSocketState, client_principal : ClientPrincipal, msg_bytes : Blob) : async CanisterWsSendResult { + ws_state._ws_send_to_client_principal(client_principal, msg_bytes); }; }; diff --git a/tests/test_canister/src/test_canister/main.mo b/tests/test_canister/src/test_canister/main.mo index ca078ca..669d4d7 100644 --- a/tests/test_canister/src/test_canister/main.mo +++ b/tests/test_canister/src/test_canister/main.mo @@ -2,9 +2,10 @@ import Array "mo:base/Array"; import Debug "mo:base/Debug"; import Nat64 "mo:base/Nat64"; import IcWebSocketCdk "mo:ic-websocket-cdk"; +import IcWebSocketCdkTypes "mo:ic-websocket-cdk/Types"; +import IcWebSocketCdkState "mo:ic-websocket-cdk/State"; actor class TestCanister( - gateway_principals : [Text], init_max_number_of_returned_messages : Nat64, init_send_ack_interval_ms : Nat64, init_keep_alive_timeout_ms : Nat64, @@ -14,7 +15,13 @@ actor class TestCanister( text : Text; }; - var ws_state = IcWebSocketCdk.IcWebSocketState(gateway_principals); + let params = IcWebSocketCdkTypes.WsInitParams( + ?Nat64.toNat(init_max_number_of_returned_messages), + ?init_send_ack_interval_ms, + ?init_keep_alive_timeout_ms, + ); + + var ws_state = IcWebSocketCdkState.IcWebSocketState(params); func on_open(args : IcWebSocketCdk.OnOpenCallbackArgs) : async () { Debug.print("Opened websocket: " # debug_show (args.client_principal)); @@ -28,20 +35,13 @@ actor class TestCanister( Debug.print("Client " # debug_show (args.client_principal) # " disconnected"); }; - let handlers = IcWebSocketCdk.WsHandlers( + let handlers = IcWebSocketCdkTypes.WsHandlers( ?on_open, ?on_message, ?on_close, ); - let params = IcWebSocketCdk.WsInitParams( - handlers, - ?Nat64.toNat(init_max_number_of_returned_messages), - ?init_send_ack_interval_ms, - ?init_keep_alive_timeout_ms, - ); - - var ws = IcWebSocketCdk.IcWebSocket(ws_state, params); + var ws = IcWebSocketCdk.IcWebSocket(ws_state, params, handlers); // method called by the WS Gateway after receiving FirstMessage from the client public shared ({ caller }) func ws_open(args : IcWebSocketCdk.CanisterWsOpenArguments) : async IcWebSocketCdk.CanisterWsOpenResult { From 82f4f6cfb6f8811bf2aac399a1b3866d80629fa0 Mon Sep 17 00:00:00 2001 From: Luca8991 Date: Mon, 4 Dec 2023 10:33:15 +0100 Subject: [PATCH 4/8] fix: remove client if it was connected on ws_open --- src/State.mo | 18 ++++++++++-------- src/lib.mo | 11 +++++++++++ 2 files changed, 21 insertions(+), 8 deletions(-) diff --git a/src/State.mo b/src/State.mo index 55daff6..a815c66 100644 --- a/src/State.mo +++ b/src/State.mo @@ -252,25 +252,27 @@ module { increment_gateway_clients_count(new_client.gateway_principal); }; + /// Removes a client from the internal state + /// and call the on_close callback, + /// if the client was registered in the state. public func remove_client(client_key : ClientKey, handlers : WsHandlers) : async () { CLIENTS_WAITING_FOR_KEEP_ALIVE := TrieSet.delete(CLIENTS_WAITING_FOR_KEEP_ALIVE, client_key, Types.hashClientKey(client_key), Types.areClientKeysEqual); CURRENT_CLIENT_KEY_MAP.delete(client_key.client_principal); OUTGOING_MESSAGE_TO_CLIENT_NUM_MAP.delete(client_key); INCOMING_MESSAGE_FROM_CLIENT_NUM_MAP.delete(client_key); - let registered_client = REGISTERED_CLIENTS.remove(client_key); - switch (registered_client) { + switch (REGISTERED_CLIENTS.remove(client_key)) { case (?registered_client) { decrement_gateway_clients_count(registered_client.gateway_principal); + + await handlers.call_on_close({ + client_principal = client_key.client_principal; + }); }; case (null) { - Prelude.unreachable(); + // Do nothing }; }; - - await handlers.call_on_close({ - client_principal = client_key.client_principal; - }); }; public func format_message_for_gateway_key(gateway_principal : Principal, nonce : Nat64) : Text { @@ -413,7 +415,7 @@ module { }; }; - /// Deletes the an amount of [MESSAGES_TO_DELETE] messages from the queue + /// Deletes the an amount of [MESSAGES_TO_DELETE_COUNT] messages from the queue /// that are older than the ack interval. func delete_old_messages_for_gateway(gateway_principal : GatewayPrincipal) : Result<(), Text> { let ack_interval_ms = init_params.send_ack_interval_ms; diff --git a/src/lib.mo b/src/lib.mo index 70e302a..7fd4a81 100644 --- a/src/lib.mo +++ b/src/lib.mo @@ -98,6 +98,17 @@ module { }; }; + // check if there's a client already registered with the same principal + // and remove it if there is + switch (WS_STATE.get_client_key_from_principal(client_key.client_principal)) { + case (#Err(err)) { + // Do nothing + }; + case (#Ok(old_client_key)) { + await WS_STATE.remove_client(old_client_key, handlers); + }; + }; + // initialize client maps let new_client = Types.RegisteredClient(args.gateway_principal); WS_STATE.add_client(client_key, new_client); From e56239c0b1dcb35f435c322f1a6361c1cb3a1903 Mon Sep 17 00:00:00 2001 From: Luca8991 Date: Mon, 4 Dec 2023 10:37:35 +0100 Subject: [PATCH 5/8] chore: disable backtrace in tests --- scripts/test.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/test.sh b/scripts/test.sh index 4af5665..9ffe164 100755 --- a/scripts/test.sh +++ b/scripts/test.sh @@ -11,4 +11,4 @@ export TEST_CANISTER_WASM_PATH="$(pwd)/bin/test_canister.wasm" cd tests/ic-websocket-cdk-rs -RUST_BACKTRACE=1 cargo test --package ic-websocket-cdk --lib -- tests::integration_tests --test-threads 1 +cargo test --package ic-websocket-cdk --lib -- tests::integration_tests --test-threads 1 From b7c428e629fcb6aabe07eb3d316876abc246f691 Mon Sep 17 00:00:00 2001 From: Luca8991 Date: Mon, 4 Dec 2023 11:16:17 +0100 Subject: [PATCH 6/8] fix: delete old messages for gateway integration tests are passing! --- src/Types.mo | 9 +++------ tests/ic-websocket-cdk-rs | 2 +- 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/src/Types.mo b/src/Types.mo index 2f95581..cb1d576 100644 --- a/src/Types.mo +++ b/src/Types.mo @@ -282,12 +282,7 @@ module { var deleted_keys : List.List = List.nil(); label f for (_ in Iter.range(0, n - 1)) { - let message_to_delete = do { - let (m, l) = List.pop(messages_to_delete); - messages_to_delete := l; - m; - }; - switch (message_to_delete) { + switch (List.get(messages_to_delete, 0)) { case (?message_to_delete) { if ((time - message_to_delete.timestamp) > (message_max_age_ms * 1_000_000)) { let deleted_message = do { @@ -308,6 +303,8 @@ module { Prelude.unreachable(); }; }; + let (_, l) = List.pop(messages_to_delete); + messages_to_delete := l; } else { // In this case, no messages can be deleted because // they're all not older than `message_max_age_ms`. diff --git a/tests/ic-websocket-cdk-rs b/tests/ic-websocket-cdk-rs index e0104cf..4ff0311 160000 --- a/tests/ic-websocket-cdk-rs +++ b/tests/ic-websocket-cdk-rs @@ -1 +1 @@ -Subproject commit e0104cf432c8732ac50da70e147134c866218942 +Subproject commit 4ff03111d5efcc3f02fe61981d40d9787d6254df From 2b21b5ce65006cefc27b8a745e6c4d6dc71906a8 Mon Sep 17 00:00:00 2001 From: Luca8991 Date: Mon, 4 Dec 2023 12:26:47 +0100 Subject: [PATCH 7/8] fix: update dids --- README.md | 2 +- did/service.example.did | 2 +- did/ws_types.did | 3 +++ 3 files changed, 5 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index d163388..0e5e854 100644 --- a/README.md +++ b/README.md @@ -22,7 +22,7 @@ In order for the frontend clients and the Gateway to work properly, the canister ``` import "./ws_types.did"; -// define here your message type +// define your message type here type MyMessageType = record { some_field : text; }; diff --git a/did/service.example.did b/did/service.example.did index 478789c..456a033 100644 --- a/did/service.example.did +++ b/did/service.example.did @@ -1,7 +1,7 @@ import "./ws_types.did"; // define your message type here -type MyMessageType = { +type MyMessageType = record { some_field : text; }; diff --git a/did/ws_types.did b/did/ws_types.did index ac375d7..2fa715a 100644 --- a/did/ws_types.did +++ b/did/ws_types.did @@ -1,4 +1,5 @@ type ClientPrincipal = principal; +type GatewayPrincipal = principal; type ClientKey = record { client_principal : ClientPrincipal; client_nonce : nat64; @@ -22,10 +23,12 @@ type CanisterOutputCertifiedMessages = record { messages : vec CanisterOutputMessage; cert : blob; tree : blob; + is_end_of_queue : bool; }; type CanisterWsOpenArguments = record { client_nonce : nat64; + gateway_principal : GatewayPrincipal; }; type CanisterWsOpenResult = variant { From f1a48381c1149281e9d435f4f8f1a8b45652b4b0 Mon Sep 17 00:00:00 2001 From: Luca8991 Date: Mon, 4 Dec 2023 12:52:21 +0100 Subject: [PATCH 8/8] chore: bump to version v0.3.1 --- mops.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mops.toml b/mops.toml index 64c7f8e..544313a 100644 --- a/mops.toml +++ b/mops.toml @@ -1,6 +1,6 @@ [package] name = "ic-websocket-cdk" -version = "0.3.0" +version = "0.3.1" description = "IC WebSocket Motoko CDK" repository = "https://github.com/omnia-network/ic-websocket-cdk-mo" keywords = ["ic", "websocket", "motoko", "cdk"]