diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml new file mode 100644 index 0000000..09571f2 --- /dev/null +++ b/.github/workflows/release.yml @@ -0,0 +1,66 @@ +name: Release ic-websocket-cdk-mo + +# only run when the tests complete +on: + workflow_run: + workflows: [ic-websocket-cdk-mo tests] + types: + - completed + branches: + - main + +jobs: + publish: + runs-on: ubuntu-latest + # only run if the tests were successful + if: ${{ github.event.workflow_run.conclusion == 'success' }} + outputs: + version: ${{ steps.npm-publish.outputs.version }} + steps: + - uses: actions/checkout@v3 + - uses: actions/setup-node@v3 + with: + node-version: 18 + + - uses: aviate-labs/setup-dfx@v0.2.6 + with: + dfx-version: 0.14.1 + env: + DFX_IDENTITY_PEM: ${{ secrets.DFX_IDENTITY_PEM }} + + - name: install mops + run: npm i ic-mops -g + + - run: | + dfx identity use action + mops import-identity --no-encrypt -- "$(dfx identity export action)" + mops publish --no-docs + echo "version=$(cat mops.toml | grep "version =" | cut -d\" -f2)" >> "$GITHUB_OUTPUT" + env: + CARGO_REGISTRY_TOKEN: ${{ secrets.CARGO_REGISTRY_TOKEN }} + + tag: + needs: publish + runs-on: ubuntu-latest + outputs: + version: ${{ steps.tag_version.outputs.new_tag }} + steps: + - name: Checkout + uses: actions/checkout@v3 + - name: Bump version and push tag + id: tag_version + uses: mathieudutour/github-tag-action@v6.1 + with: + github_token: ${{ secrets.GITHUB_TOKEN }} + custom_tag: ${{ needs.publish.outputs.version }} + + release: + needs: tag + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@v3 + - name: Release + uses: softprops/action-gh-release@v1 + with: + tag_name: ${{ needs.tag.outputs.version }} diff --git a/mops.toml b/mops.toml index 7012451..2520650 100644 --- a/mops.toml +++ b/mops.toml @@ -1,6 +1,6 @@ [package] name = "ic-websocket-cdk" -version = "0.1.0" +version = "0.1.1" description = "IC WebSocket Motoko CDK" repository = "https://github.com/omnia-network/ic-websocket-cdk-mo" keywords = ["ic", "websocket", "motoko", "cdk"] diff --git a/src/lib.mo b/src/lib.mo index 51d9de5..a54fab6 100644 --- a/src/lib.mo +++ b/src/lib.mo @@ -3,6 +3,7 @@ import Array "mo:base/Array"; import Blob "mo:base/Blob"; import CertifiedData "mo:base/CertifiedData"; +import Debug "mo:base/Debug"; import Deque "mo:base/Deque"; import HashMap "mo:base/HashMap"; import Hash "mo:base/Hash"; @@ -17,6 +18,7 @@ import Time "mo:base/Time"; import Timer "mo:base/Timer"; import Bool "mo:base/Bool"; import Error "mo:base/Error"; +import TrieSet "mo:base/TrieSet"; import CborValue "mo:cbor/Value"; import CborDecoder "mo:cbor/Decoder"; import CborEncoder "mo:cbor/Encoder"; @@ -42,6 +44,14 @@ module { /// The default timeout to wait for the client to send a keep alive after receiving an acknowledgement. let DEFAULT_CLIENT_KEEP_ALIVE_TIMEOUT_MS : Nat64 = 10_000; // 10 seconds + /// The initial nonce for outgoing messages. + let INITIAL_OUTGOING_MESSAGE_NONCE : Nat64 = 0; + /// The initial sequence number to expect from messages coming from clients. + /// The first message coming from the client will have sequence number `1` because on the client the sequence number is incremented before sending the message. + let INITIAL_CLIENT_SEQUENCE_NUM : Nat64 = 1; + /// The initial sequence number for outgoing messages. + let INITIAL_CANISTER_SEQUENCE_NUM : Nat64 = 0; + //// TYPES //// type CandidType = Type.Type; type CandidValue = Value.Value; @@ -588,6 +598,8 @@ module { public var REGISTERED_CLIENTS = HashMap.HashMap(0, areClientKeysEqual, hashClientKey); /// Maps the client's principal to the current client key public var CURRENT_CLIENT_KEY_MAP = HashMap.HashMap(0, Principal.equal, Principal.hash); + /// Keeps track of all the clients for which we're waiting for a keep alive message. + public var CLIENTS_WAITING_FOR_KEEP_ALIVE : TrieSet.Set = TrieSet.empty(); /// Maps the client's public key to the sequence number to use for the next outgoing message (to that client). public var OUTGOING_MESSAGE_TO_CLIENT_NUM_MAP = HashMap.HashMap(0, areClientKeysEqual, hashClientKey); /// Maps the client's public key to the expected sequence number of the next incoming message (from that client). @@ -602,7 +614,7 @@ module { /// Keeps track of the nonce which: /// - the WS Gateway uses to specify the first index of the certified messages to be returned when polling /// - the client uses as part of the path in the Merkle tree in order to verify the certificate of the messages relayed by the WS Gateway - public var OUTGOING_MESSAGE_NONCE : Nat64 = 0; + public var OUTGOING_MESSAGE_NONCE : Nat64 = INITIAL_OUTGOING_MESSAGE_NONCE; /// The acknowledgement active timer. public var ACK_TIMER : ?Timer.TimerId = null; // /// The keep alive active timer. @@ -616,13 +628,15 @@ module { await remove_client(client_key, handlers); }; + // make sure all the maps are cleared CURRENT_CLIENT_KEY_MAP := HashMap.HashMap(0, Principal.equal, Principal.hash); + CLIENTS_WAITING_FOR_KEEP_ALIVE := TrieSet.empty(); OUTGOING_MESSAGE_TO_CLIENT_NUM_MAP := HashMap.HashMap(0, areClientKeysEqual, hashClientKey); INCOMING_MESSAGE_FROM_CLIENT_NUM_MAP := HashMap.HashMap(0, areClientKeysEqual, hashClientKey); CERT_TREE_STORE := CertTree.newStore(); CERT_TREE := CertTree.Ops(CERT_TREE_STORE); MESSAGES_FOR_GATEWAY := List.nil(); - OUTGOING_MESSAGE_NONCE := 0; + OUTGOING_MESSAGE_NONCE := INITIAL_OUTGOING_MESSAGE_NONCE; }; public func get_outgoing_message_nonce() : Nat64 { @@ -657,12 +671,16 @@ module { #Ok; }; + public func add_client_to_wait_for_keep_alive(client_key : ClientKey) { + CLIENTS_WAITING_FOR_KEEP_ALIVE := TrieSet.put(CLIENTS_WAITING_FOR_KEEP_ALIVE, client_key, hashClientKey(client_key), areClientKeysEqual); + }; + public func get_registered_gateway_principal() : Principal { REGISTERED_GATEWAY.gateway_principal; }; func init_outgoing_message_to_client_num(client_key : ClientKey) { - OUTGOING_MESSAGE_TO_CLIENT_NUM_MAP.put(client_key, 0); + OUTGOING_MESSAGE_TO_CLIENT_NUM_MAP.put(client_key, INITIAL_CANISTER_SEQUENCE_NUM); }; public func get_outgoing_message_to_client_num(client_key : ClientKey) : Result { @@ -684,7 +702,7 @@ module { }; func init_expected_incoming_message_from_client_num(client_key : ClientKey) { - INCOMING_MESSAGE_FROM_CLIENT_NUM_MAP.put(client_key, 1); + INCOMING_MESSAGE_FROM_CLIENT_NUM_MAP.put(client_key, INITIAL_CLIENT_SEQUENCE_NUM); }; public func get_expected_incoming_message_from_client_num(client_key : ClientKey) : Result { @@ -715,6 +733,7 @@ module { }; public func remove_client(client_key : ClientKey, handlers : WsHandlers) : async () { + CLIENTS_WAITING_FOR_KEEP_ALIVE := TrieSet.delete(CLIENTS_WAITING_FOR_KEEP_ALIVE, client_key, hashClientKey(client_key), areClientKeysEqual); CURRENT_CLIENT_KEY_MAP.delete(client_key.client_principal); REGISTERED_CLIENTS.delete(client_key); OUTGOING_MESSAGE_TO_CLIENT_NUM_MAP.delete(client_key); @@ -902,37 +921,36 @@ module { reset_keep_alive_timer(); }; - /// Schedules a timer to send an acknowledgement message to the client. + /// Start an interval to send an acknowledgement messages to the clients. /// - /// The timer callback is [send_ack_to_clients_timer_callback]. After the callback is executed, + /// The interval callback is [send_ack_to_clients_timer_callback]. After the callback is executed, /// a timer is scheduled to check if the registered clients have sent a keep alive message. public func schedule_send_ack_to_clients(send_ack_interval_ms : Nat64, keep_alive_timeout_ms : Nat64, handlers : WsHandlers) { - let timer_id = Timer.setTimer( + let timer_id = Timer.recurringTimer( #nanoseconds(Nat64.toNat(send_ack_interval_ms) * 1_000_000), func() : async () { send_ack_to_clients_timer_callback(); - schedule_check_keep_alive(send_ack_interval_ms, keep_alive_timeout_ms, handlers); + schedule_check_keep_alive(keep_alive_timeout_ms, handlers); }, ); put_ack_timet_id(timer_id); }; - /// Schedules a timer to check if the registered clients have sent a keep alive message + /// Schedules a timer to check if the clients (only those to which an ack message was sent) have sent a keep alive message /// after receiving an acknowledgement message. /// - /// The timer callback is [check_keep_alive_timer_callback]. After the callback is executed, - /// a timer is scheduled again to send an acknowledgement message to the registered clients. - func schedule_check_keep_alive(send_ack_interval_ms : Nat64, keep_alive_timeout_ms : Nat64, handlers : WsHandlers) { + /// The timer callback is [check_keep_alive_timer_callback]. + func schedule_check_keep_alive(keep_alive_timeout_ms : Nat64, handlers : WsHandlers) { let timer_id = Timer.setTimer( #nanoseconds(Nat64.toNat(keep_alive_timeout_ms) * 1_000_000), func() : async () { await check_keep_alive_timer_callback(keep_alive_timeout_ms, handlers); - - schedule_send_ack_to_clients(send_ack_interval_ms, keep_alive_timeout_ms, handlers); }, ); + + put_keep_alive_timer_id(timer_id); }; /// Sends an acknowledgement message to the client. @@ -943,6 +961,7 @@ module { switch (get_expected_incoming_message_from_client_num(client_key)) { case (#Ok(expected_incoming_sequence_num)) { let ack_message : CanisterAckMessageContent = { + // the expected sequence number is 1 more because it's incremented when a message is received last_incoming_sequence_num = expected_incoming_sequence_num - 1; }; let message : WebsocketServiceMessageContent = #AckMessage(ack_message); @@ -953,7 +972,7 @@ module { Logger.custom_print("[ack-to-clients-timer-cb]: Error sending ack message to client" # clientKeyToText(client_key) # ": " # err); }; case (#Ok(_)) { - // Do nothing + add_client_to_wait_for_keep_alive(client_key); }; }; }; @@ -967,16 +986,24 @@ module { Logger.custom_print("[ack-to-clients-timer-cb]: Sent ack messages to all clients"); }; - /// Checks if the registered clients have sent a keep alive message. - /// If a client has not sent a keep alive message, it is removed from the registered clients. + /// Checks if the clients for which we are waiting for keep alive have sent a keep alive message. + /// If a client has not sent a keep alive message, it is removed from the connected clients. func check_keep_alive_timer_callback(keep_alive_timeout_ms : Nat64, handlers : WsHandlers) : async () { - for ((client_key, client_metadata) in REGISTERED_CLIENTS.entries()) { - let last_keep_alive = client_metadata.last_keep_alive_timestamp; + for (client_key in Array.vals(TrieSet.toArray(CLIENTS_WAITING_FOR_KEEP_ALIVE))) { + let client_metadata = REGISTERED_CLIENTS.get(client_key); + switch (client_metadata) { + case (?client_metadata) { + let last_keep_alive = client_metadata.last_keep_alive_timestamp; - if (get_current_time() - last_keep_alive > keep_alive_timeout_ms * 1_000_000) { - await remove_client(client_key, handlers); + if (get_current_time() - last_keep_alive > keep_alive_timeout_ms * 1_000_000) { + await remove_client(client_key, handlers); - Logger.custom_print("[check-keep-alive-timer-cb]: Client " # clientKeyToText(client_key) # " has not sent a keep alive message in the last " # debug_show (keep_alive_timeout_ms) # " ms and has been removed"); + Logger.custom_print("[check-keep-alive-timer-cb]: Client " # clientKeyToText(client_key) # " has not sent a keep alive message in the last " # debug_show (keep_alive_timeout_ms) # " ms and has been removed"); + }; + }; + case (null) { + // Do nothing + }; }; }; @@ -1114,12 +1141,18 @@ module { }; /// The interval at which to send an acknowledgement message to the client, /// so that the client knows that all the messages it sent have been received by the canister (in milliseconds). + /// + /// Must be greater than `keep_alive_timeout_ms`. + /// /// Defaults to `60_000` (60 seconds). public var send_ack_interval_ms : Nat64 = switch (init_send_ack_interval_ms) { case (?value) { value }; case (null) { DEFAULT_SEND_ACK_INTERVAL_MS }; }; /// The delay to wait for the client to send a keep alive after receiving an acknowledgement (in milliseconds). + /// + /// Must be lower than `send_ack_interval_ms`. + /// /// Defaults to `10_000` (10 seconds). public var keep_alive_timeout_ms : Nat64 = switch (init_keep_alive_timeout_ms) { case (?value) { value }; @@ -1129,8 +1162,23 @@ module { public func get_handlers() : WsHandlers { return handlers; }; + + /// Checks the validity of the timer parameters. + /// `send_ack_interval_ms` must be greater than `keep_alive_timeout_ms`. + /// + /// # Traps + /// If `send_ack_interval_ms` < `keep_alive_timeout_ms`. + public func check_validity() { + if (keep_alive_timeout_ms > send_ack_interval_ms) { + Debug.trap("send_ack_interval_ms must be greater than keep_alive_timeout_ms"); + }; + }; }; + /// The IC WebSocket instance. + /// + /// # Traps + /// If the parameters are invalid. See [`WsInitParams::check_validity`] for more details. public class IcWebSocket(init_ws_state : IcWebSocketState, params : WsInitParams) { /// The state of the IC WebSocket. private var WS_STATE : IcWebSocketState = init_ws_state; @@ -1138,6 +1186,9 @@ module { private var HANDLERS : WsHandlers = params.get_handlers(); do { + // check if the parameters are valid + params.check_validity(); + // reset initial timers WS_STATE.reset_timers(); diff --git a/tests/integration/canister.test.ts b/tests/integration/canister.test.ts index c63de28..68a531d 100644 --- a/tests/integration/canister.test.ts +++ b/tests/integration/canister.test.ts @@ -74,9 +74,9 @@ const DEFAULT_TEST_SEND_ACK_INTERVAL_MS = 300_000; * The interval between keep alive checks in the canister. * Set to a high value to make sure the canister doesn't reset the client while testing other functions. * - * Value: `300_000` (5 minutes) + * Value: `120_000` (2 minutes) */ -const DEFAULT_TEST_KEEP_ALIVE_TIMEOUT_MS = 300_000; +const DEFAULT_TEST_KEEP_ALIVE_TIMEOUT_MS = 120_000; let client1Key: ClientKey; let client2Key: ClientKey; @@ -692,11 +692,12 @@ describe("Messages acknowledgement", () => { }); it("client should receive ack messages", async () => { - const sendAckIntervalMs = 5_000; // 5 seconds + const sendAckIntervalMs = 10_000; // 10 seconds + const keepAliveDelayMs = 5_000; // 5 seconds await initializeCdk({ maxNumberOfReturnedMessages: DEFAULT_TEST_MAX_NUMBER_OF_RETURNED_MESSAGES, sendAckIntervalMs, - keepAliveDelayMs: DEFAULT_TEST_KEEP_ALIVE_TIMEOUT_MS, // keep alive timeout still high to avoid removing connected client + keepAliveDelayMs, }); await wsOpen({ @@ -722,7 +723,7 @@ describe("Messages acknowledgement", () => { message: createWebsocketMessage(client1Key, 1), }, true); - // sleep for 5 seconds, which is more than the sendAckIntervalMs due to the previous calls + // sleep for 10 seconds, which is more than the sendAckIntervalMs due to the previous calls // so we are sure that the CDK has sent an ack await sleep(sendAckIntervalMs); @@ -767,7 +768,7 @@ describe("Messages acknowledgement", () => { }); it("client is removed if keep alive timeout is reached", async () => { - const sendAckIntervalMs = 2_000; // 2 seconds + const sendAckIntervalMs = 10_000; // 10 seconds const keepAliveDelayMs = 5_000; // 5 seconds await initializeCdk({ maxNumberOfReturnedMessages: DEFAULT_TEST_MAX_NUMBER_OF_RETURNED_MESSAGES, @@ -809,7 +810,7 @@ describe("Messages acknowledgement", () => { }); it("client is not removed if it sends a keep alive before timeout", async () => { - const sendAckIntervalMs = 3_000; // 3 seconds + const sendAckIntervalMs = 15_000; // 15 seconds const keepAliveDelayMs = 5_000; // 5 seconds await initializeCdk({ maxNumberOfReturnedMessages: DEFAULT_TEST_MAX_NUMBER_OF_RETURNED_MESSAGES, @@ -829,11 +830,24 @@ describe("Messages acknowledgement", () => { nonce: BigInt(1), // skip the service open message }); let messagesResult = (res as { Ok: CanisterOutputCertifiedMessages }).Ok; + // the queue contains only the first ack message expect(messagesResult.messages.length).toEqual(1); + let ackMessage = messagesResult.messages[0]; + expect(isClientKeyEq(ackMessage.client_key, client1Key)).toEqual(true); + let websocketMessage = getWebsocketMessageFromCanisterMessage(ackMessage); + expect(websocketMessage.is_service_message).toEqual(true); + expect(websocketMessage.sequence_num).toEqual(2); + let serviceMessageContent = decodeWebsocketServiceMessageContent(websocketMessage.content as Uint8Array); + expect(serviceMessageContent).toMatchObject({ + AckMessage: { + last_incoming_sequence_num: BigInt(0), + }, + }); + // send the keep alive message const keepAliveMessage: WebsocketServiceMessageContent = { KeepAliveMessage: { - last_incoming_sequence_num: BigInt(1), // not relevant + last_incoming_sequence_num: BigInt(1), // ignored in the CDK }, }; await wsMessage({ @@ -841,9 +855,11 @@ describe("Messages acknowledgement", () => { message: createWebsocketMessage(client1Key, 1, encodeWebsocketServiceMessageContent(keepAliveMessage), true), }, true); + // wait for the canister to check if the client has sent the keep alive await sleep(keepAliveDelayMs); - // send a message to the canister to see the the sequence number increase in the ack message + // send a message to the canister to see the sequence number increasing in the ack message + // and be sure that the client can still send messages await wsMessage({ actor: client1, message: createWebsocketMessage(client1Key, 2), @@ -853,30 +869,81 @@ describe("Messages acknowledgement", () => { await sleep(sendAckIntervalMs); res = await gateway1.ws_get_messages({ - nonce: BigInt(1), // skip the service open message + nonce: BigInt(2), // skip the service open message and the first service ack message }); messagesResult = (res as { Ok: CanisterOutputCertifiedMessages }).Ok; - expect(messagesResult.messages.length).toEqual(2); + // the fetched queue only contains the second ack message + expect(messagesResult.messages.length).toEqual(1); + ackMessage = messagesResult.messages[0]; + expect(isClientKeyEq(ackMessage.client_key, client1Key)).toEqual(true); + websocketMessage = getWebsocketMessageFromCanisterMessage(ackMessage); + expect(websocketMessage.is_service_message).toEqual(true); + expect(websocketMessage.sequence_num).toEqual(3); + serviceMessageContent = decodeWebsocketServiceMessageContent(websocketMessage.content as Uint8Array); + expect(serviceMessageContent).toMatchObject({ + AckMessage: { + last_incoming_sequence_num: BigInt(2), // as expected, the canister acks both the messages that we've sent + }, + }); + }); - let expectedClientSequenceNum = 0; - let expectedCanisterSequenceNum = 2; // first message is skipped and sequence number starts from 1 - for (const message of messagesResult.messages) { - expect(isClientKeyEq(message.client_key, client1Key)).toEqual(true); + it("client is not removed if it connects while canister is waiting for keep alive", async () => { + const sendAckIntervalMs = 15_000; // 15 seconds + const keepAliveDelayMs = 10_000; // 10 seconds, to make sure the canister is waiting for keep alive when the client connects + await initializeCdk({ + maxNumberOfReturnedMessages: DEFAULT_TEST_MAX_NUMBER_OF_RETURNED_MESSAGES, + sendAckIntervalMs, + keepAliveDelayMs, + }); - const websocketMessage = getWebsocketMessageFromCanisterMessage(message); - expect(websocketMessage.is_service_message).toEqual(true); - expect(websocketMessage.sequence_num).toEqual(expectedCanisterSequenceNum); + // make sure the canister is waiting for keep alive + await sleep(sendAckIntervalMs); - const serviceMessageContent = decodeWebsocketServiceMessageContent(websocketMessage.content as Uint8Array); - expect(serviceMessageContent).toMatchObject({ - AckMessage: { - last_incoming_sequence_num: BigInt(expectedClientSequenceNum), - }, - }); + await wsOpen({ + clientNonce: client1Key.client_nonce, + canisterId, + clientActor: client1, + }, true); - expectedClientSequenceNum++; - expectedCanisterSequenceNum++; - } - }); + let res = await gateway1.ws_get_messages({ + nonce: BigInt(1), // skip the service open message + }); + let messagesResult = (res as { Ok: CanisterOutputCertifiedMessages }).Ok; + expect(messagesResult.messages.length).toEqual(0); // client doesn't expect any other messages at this point + + // send a message to the canister to see the sequence number increasing in the ack message + // ad verify that the client is still connected to the canister + const wsMessageRes = await wsMessage({ + actor: client1, + message: createWebsocketMessage(client1Key, 1), + }); + expect(wsMessageRes).toMatchObject({ + Ok: null, + }); + + // wait to for the keep alive timeout to expire + await sleep(keepAliveDelayMs); + // wait for the canister to send the next ack + await sleep(sendAckIntervalMs - keepAliveDelayMs); + + res = await gateway1.ws_get_messages({ + nonce: BigInt(1), // skip the service open message + }); + + messagesResult = (res as { Ok: CanisterOutputCertifiedMessages }).Ok; + expect(messagesResult.messages.length).toEqual(1); + + const ackMessage = messagesResult.messages[0]; + expect(isClientKeyEq(ackMessage.client_key, client1Key)).toEqual(true); + const websocketMessage = getWebsocketMessageFromCanisterMessage(ackMessage); + expect(websocketMessage.is_service_message).toEqual(true); + expect(websocketMessage.sequence_num).toEqual(2); // first message is skipped and sequence number starts from 1 + const serviceMessageContent = decodeWebsocketServiceMessageContent(websocketMessage.content as Uint8Array); + expect(serviceMessageContent).toMatchObject({ + AckMessage: { + last_incoming_sequence_num: BigInt(1), + }, + }); + }) });