From 5ac3759fb4e5a5166dfcb4fd326970edd165eff3 Mon Sep 17 00:00:00 2001 From: Luca8991 Date: Mon, 28 Aug 2023 19:19:46 +0200 Subject: [PATCH 01/31] feat: authenticated messages, unit tests passing --- Cargo.lock | 47 +- src/ic-websocket-cdk/Cargo.toml | 1 - src/ic-websocket-cdk/service.example.did | 2 +- src/ic-websocket-cdk/src/lib.rs | 859 ++++++++--------------- src/ic-websocket-cdk/ws_types.did | 56 +- tests/src/canister.rs | 6 +- tests/src/lib.rs | 20 +- tests/test_canister.did | 4 +- 8 files changed, 338 insertions(+), 657 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index cca77ca..3ddbcb9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -69,9 +69,9 @@ checksum = "4c7f02d4ea65f2c1853089ffd8d2787bdbc63de2f0d29dedbcf8ccdfa0ccd4cf" [[package]] name = "base64" -version = "0.21.2" +version = "0.21.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "604178f6c5c21f02dc555784810edfb88d34ac2c73b2eae109655649ee73ce3d" +checksum = "414dcefbc63d77c526a76b3afcf6fbb9b5e2791c19c3aa2297733208750c6e53" [[package]] name = "base64ct" @@ -347,12 +347,6 @@ dependencies = [ "spki", ] -[[package]] -name = "ed25519-compact" -version = "2.0.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6a3d382e8464107391c8706b4c14b087808ecb909f6c15c34114bc42e53a9e4c" - [[package]] name = "either" version = "1.9.0" @@ -842,7 +836,6 @@ version = "0.1.0" dependencies = [ "base64", "candid", - "ed25519-compact", "ic-agent", "ic-cdk", "ic-cdk-macros", @@ -974,9 +967,9 @@ checksum = "b5e6163cb8c49088c2c36f57875e58ccd8c87c7427f7fbd50ea6710b2f3f2e8f" [[package]] name = "memchr" -version = "2.5.0" +version = "2.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2dffe52ecf27772e601905b7522cb4ef790d2cc203488bbd0e2fe85fcb74566d" +checksum = "76fc44e2588d5b436dbc3c6cf62aef290f90dab6235744a93dfe1cc18f451e2c" [[package]] name = "mime" @@ -1130,9 +1123,9 @@ checksum = "9b2a4787296e9989611394c33f193f676704af1686e70b8f8033ab5ba9a35a94" [[package]] name = "pin-project-lite" -version = "0.2.12" +version = "0.2.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "12cc1b0bf1727a77a54b6654e7b5f1af8604923edc8b81885f8ec92f9e3f0a05" +checksum = "8afb450f006bf6385ca15ef45d71d2288452bc3683ce2e2cacc0d18e4be60b58" [[package]] name = "pin-utils" @@ -1358,9 +1351,9 @@ checksum = "d626bb9dae77e28219937af045c257c28bfd3f69333c512553507f5f9798cb76" [[package]] name = "rustix" -version = "0.38.8" +version = "0.38.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "19ed4fa021d81c8392ce04db050a3da9a60299050b7ae1cf482d862b54a7218f" +checksum = "9bfe0f2582b4931a45d1fa608f8a8722e8b3c7ac54dd6d5f3b3212791fedef49" dependencies = [ "bitflags 2.4.0", "errno", @@ -1462,9 +1455,9 @@ dependencies = [ [[package]] name = "serde" -version = "1.0.186" +version = "1.0.188" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9f5db24220c009de9bd45e69fb2938f4b6d2df856aa9304ce377b3180f83b7c1" +checksum = "cf9e0fcba69a370eed61bcf2b728575f726b50b55cba78064753d708ddc7549e" dependencies = [ "serde_derive", ] @@ -1490,9 +1483,9 @@ dependencies = [ [[package]] name = "serde_derive" -version = "1.0.186" +version = "1.0.188" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5ad697f7e0b65af4983a4ce8f56ed5b357e8d3c36651bf6a7e13639c17b8e670" +checksum = "4eca7ac642d82aa35b60049a6eccb4be6be75e599bd2e9adb5f875a737654af2" dependencies = [ "proc-macro2", "quote", @@ -1741,9 +1734,9 @@ dependencies = [ [[package]] name = "time" -version = "0.3.27" +version = "0.3.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0bb39ee79a6d8de55f48f2293a830e040392f1c5f16e336bdd1788cd0aadce07" +checksum = "17f6bb557fd245c28e6411aa56b6403c689ad95061f50e4be16c274e70a17e48" dependencies = [ "deranged", "itoa", @@ -1760,9 +1753,9 @@ checksum = "7300fbefb4dadc1af235a9cef3737cea692a9d97e1b9cbcd4ebdae6f8868e6fb" [[package]] name = "time-macros" -version = "0.2.13" +version = "0.2.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "733d258752e9303d392b94b75230d07b0b9c489350c69b851fc6c065fde3e8f9" +checksum = "1a942f44339478ef67935ab2bbaec2fb0322496cf3cbe84b261e06ac3814c572" dependencies = [ "time-core", ] @@ -1930,9 +1923,9 @@ checksum = "a156c684c91ea7d62626509bce3cb4e1d9ed5c4d978f7b4352658f96a4c26b4a" [[package]] name = "url" -version = "2.4.0" +version = "2.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "50bff7831e19200a85b17131d085c25d7811bc4e186efdaf54bbd132994a88cb" +checksum = "143b538f18257fac9cad154828a57c6bf5157e1aa604d4816b5995bf6de87ae5" dependencies = [ "form_urlencoded", "idna", @@ -2173,9 +2166,9 @@ checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538" [[package]] name = "winnow" -version = "0.5.14" +version = "0.5.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d09770118a7eb1ccaf4a594a221334119a44a814fcb0d31c5b85e83e97227a97" +checksum = "7c2e3184b9c4e92ad5167ca73039d0c42476302ab603e2fec4487511f38ccefc" dependencies = [ "memchr", ] diff --git a/src/ic-websocket-cdk/Cargo.toml b/src/ic-websocket-cdk/Cargo.toml index b45e565..dc65cf2 100644 --- a/src/ic-websocket-cdk/Cargo.toml +++ b/src/ic-websocket-cdk/Cargo.toml @@ -16,7 +16,6 @@ ic-certified-map = "0.4.0" base64 = "0.21.2" sha2 = "0.10.7" serde_bytes = "0.11.12" -ed25519-compact = { version = "2.0.4", default-features = false } [dev-dependencies] ic-agent = "0.25.0" diff --git a/src/ic-websocket-cdk/service.example.did b/src/ic-websocket-cdk/service.example.did index c522572..39d0edb 100644 --- a/src/ic-websocket-cdk/service.example.did +++ b/src/ic-websocket-cdk/service.example.did @@ -1,9 +1,9 @@ import "./ws_types.did"; service : { - "ws_register" : (CanisterWsRegisterArguments) -> (CanisterWsRegisterResult); "ws_open" : (CanisterWsOpenArguments) -> (CanisterWsOpenResult); "ws_close" : (CanisterWsCloseArguments) -> (CanisterWsCloseResult); "ws_message" : (CanisterWsMessageArguments) -> (CanisterWsMessageResult); + "ws_status" : (CanisterWsStatusArguments) -> (CanisterWsStatusResult); "ws_get_messages" : (CanisterWsGetMessagesArguments) -> (CanisterWsGetMessagesResult) query; }; diff --git a/src/ic-websocket-cdk/src/lib.rs b/src/ic-websocket-cdk/src/lib.rs index 7164851..df8884c 100644 --- a/src/ic-websocket-cdk/src/lib.rs +++ b/src/ic-websocket-cdk/src/lib.rs @@ -1,20 +1,14 @@ use candid::{CandidType, Principal}; -use ed25519_compact::{PublicKey, Signature}; #[cfg(not(test))] use ic_cdk::api::time; use ic_cdk::api::{caller, data_certificate, set_certified_data}; use ic_cdk_timers::set_timer; use ic_certified_map::{labeled, labeled_hash, AsHashTree, Hash as ICHash, RbTree}; use serde::{Deserialize, Serialize}; -use serde_cbor::{from_slice, Serializer}; +use serde_cbor::Serializer; use sha2::{Digest, Sha256}; use std::time::Duration; -use std::{ - cell::RefCell, - collections::VecDeque, - collections::{HashMap, HashSet}, - convert::AsRef, -}; +use std::{cell::RefCell, collections::HashMap, collections::VecDeque, convert::AsRef}; mod logger; @@ -27,121 +21,66 @@ const CHECK_REGISTERED_GATEWAY_DELAY_NS: u64 = 60_000_000_000; // 60 seconds /// (**Used for integration tests**) The delay between two consecutive checks if the registered gateway is still alive. const CHECK_REGISTERED_GATEWAY_DELAY_NS_TEST: u64 = 15_000_000_000; // 15 seconds -pub type ClientPublicKey = Vec; +pub type ClientPrincipal = Principal; -/// The result of [ws_register]. -pub type CanisterWsRegisterResult = Result<(), String>; /// The result of [ws_open]. pub type CanisterWsOpenResult = Result; +/// The result of [ws_close]. +pub type CanisterWsCloseResult = Result<(), String>; /// The result of [ws_message]. pub type CanisterWsMessageResult = Result<(), String>; +/// The result of [ws_status]. +pub type CanisterWsStatusResult = Result<(), String>; /// The result of [ws_get_messages]. pub type CanisterWsGetMessagesResult = Result; /// The result of [ws_send]. pub type CanisterWsSendResult = Result<(), String>; -/// The result of [ws_close]. -pub type CanisterWsCloseResult = Result<(), String>; /// The Ok value of CanisterWsOpenResult returned by [ws_open]. #[derive(CandidType, Clone, Deserialize, Serialize, Eq, PartialEq, Debug)] pub struct CanisterWsOpenResultValue { - client_key: ClientPublicKey, - canister_id: Principal, + client_principal: ClientPrincipal, nonce: u64, } -/// The arguments for [ws_register]. -#[derive(CandidType, Clone, Deserialize, Serialize, Eq, PartialEq, Debug)] -pub struct CanisterWsRegisterArguments { - #[serde(with = "serde_bytes")] - client_key: ClientPublicKey, -} - /// The arguments for [ws_open]. #[derive(CandidType, Clone, Deserialize, Serialize, Eq, PartialEq, Debug)] pub struct CanisterWsOpenArguments { - #[serde(with = "serde_bytes")] - content: Vec, - #[serde(with = "serde_bytes")] - sig: Vec, + is_anonymous: bool, } /// The arguments for [ws_close]. #[derive(CandidType, Clone, Deserialize, Serialize, Eq, PartialEq, Debug)] pub struct CanisterWsCloseArguments { - #[serde(with = "serde_bytes")] - client_key: ClientPublicKey, + client_principal: ClientPrincipal, } /// The arguments for [ws_message]. #[derive(CandidType, Clone, Deserialize, Serialize, Eq, PartialEq, Debug)] pub struct CanisterWsMessageArguments { - msg: CanisterIncomingMessage, -} - -/// The arguments for [ws_get_messages]. -#[derive(CandidType, Clone, Deserialize, Serialize, Eq, PartialEq, Debug)] -pub struct CanisterWsGetMessagesArguments { - nonce: u64, + msg: WebsocketMessage, } -/// The first message received by the canister in [ws_open]. +/// The arguments for [ws_status]. #[derive(CandidType, Clone, Deserialize, Serialize, Eq, PartialEq, Debug)] -struct CanisterOpenMessageContent { - #[serde(with = "serde_bytes")] - client_key: ClientPublicKey, - canister_id: Principal, -} - -/// Message + signature from client, **relayed** by the WS Gateway. -#[derive(CandidType, Clone, Debug, Deserialize, Serialize, Eq, PartialEq)] -struct RelayedClientMessage { - #[serde(with = "serde_bytes")] - content: Vec, - #[serde(with = "serde_bytes")] - sig: Vec, -} - -/// Message coming **directly** from client, not relayed by the WS Gateway. -#[derive(CandidType, Clone, Debug, Deserialize, Serialize, Eq, PartialEq)] -struct DirectClientMessage { - message: Vec, - client_key: ClientPublicKey, -} - -/// Heartbeat message sent from the WS Gateway to the canister, so that the canister can -/// verify that the WS Gateway is still alive. -#[derive(CandidType, Clone, Debug, Deserialize, Serialize, Eq, PartialEq)] -struct GatewayStatusMessage { +pub struct CanisterWsStatusArguments { status_index: u64, } -/// The variants of the possible messages received by the canister in [ws_message]. -/// - **IcWebSocketEstablished**: message sent from WS Gateway to the canister to notify it about the -/// establishment of the IcWebSocketConnection -/// - **IcWebSocketGatewayStatus**: message sent from WS Gateway to the canister to notify it about the -/// status of the IcWebSocketConnection -/// - **RelayedByGateway**: message sent from the client to the WS Gateway (via WebSocket) and -/// relayed to the canister by the WS Gateway -/// - **DirectlyFromClient**: message sent from directly client so that it is not necessary to -/// verify the signature -#[derive(CandidType, Clone, Debug, Deserialize, Serialize, Eq, PartialEq)] -enum CanisterIncomingMessage { - DirectlyFromClient(DirectClientMessage), - RelayedByGateway(RelayedClientMessage), - IcWebSocketEstablished(ClientPublicKey), - IcWebSocketGatewayStatus(GatewayStatusMessage), +/// The arguments for [ws_get_messages]. +#[derive(CandidType, Clone, Deserialize, Serialize, Eq, PartialEq, Debug)] +pub struct CanisterWsGetMessagesArguments { + nonce: u64, } /// Messages exchanged through the WebSocket. #[derive(CandidType, Clone, Debug, Deserialize, Serialize, Eq, PartialEq)] struct WebsocketMessage { - #[serde(with = "serde_bytes")] - client_key: ClientPublicKey, // The client that the gateway will forward the message to or that sent the message. + client_principal: ClientPrincipal, // The client that the gateway will forward the message to or that sent the message. sequence_num: u64, // Both ways, messages should arrive with sequence numbers 0, 1, 2... timestamp: u64, // Timestamp of when the message was made for the recipient to inspect. #[serde(with = "serde_bytes")] - message: Vec, // Application message encoded in binary. + content: Vec, // Application message encoded in binary. } impl WebsocketMessage { @@ -158,11 +97,10 @@ impl WebsocketMessage { /// Element of the list of messages returned to the WS Gateway after polling. #[derive(CandidType, Clone, Deserialize, Serialize, Eq, PartialEq)] pub struct CanisterOutputMessage { - #[serde(with = "serde_bytes")] - client_key: Vec, // The client that the gateway will forward the message to or that sent the message. + client_principal: ClientPrincipal, // The client that the gateway will forward the message to or that sent the message. + key: String, // Key for certificate verification. #[serde(with = "serde_bytes")] content: Vec, // The message to be relayed, that contains the application message. - key: String, // Key for certificate verification. } /// List of messages returned to the WS Gateway after polling. @@ -249,57 +187,24 @@ fn get_check_registered_gateway_delay_ns() -> u64 { } } -/// The temporary clients that don't still have a connection open, based on the gateway status index at which they were registered. -/// The last two status indexes are kept in order to be able to handle the case when the gateway crashes and restarts. -struct TmpClients { - second_last_index_clients: HashSet, - last_index_clients: HashSet, +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct RegisteredClient { + is_anonymous: bool, } -impl TmpClients { - fn new() -> Self { - Self { - second_last_index_clients: HashSet::new(), - last_index_clients: HashSet::new(), - } - } - - fn shift(&mut self) { - self.second_last_index_clients = self.last_index_clients.clone(); - self.last_index_clients = HashSet::new(); - } - - fn clear(&mut self) { - self.second_last_index_clients.clear(); - self.last_index_clients.clear(); - } - - fn insert(&mut self, client_key: ClientPublicKey) { - self.last_index_clients.insert(client_key); - } - - fn contain_client(&self, client_key: &ClientPublicKey) -> bool { - self.last_index_clients.contains(client_key) - || self.second_last_index_clients.contains(client_key) - } - - fn remove(&mut self, client_key: &ClientPublicKey) { - let is_removed = self.last_index_clients.remove(client_key); - if !is_removed { - self.second_last_index_clients.remove(client_key); - } +impl RegisteredClient { + fn is_anonymous(&self) -> bool { + self.is_anonymous } } thread_local! { /// Maps the client's public key to the client's identity (anonymous if not authenticated). - /* flexible */ static CLIENT_CALLER_MAP: RefCell> = RefCell::new(HashMap::new()); - /// Maps the clients that still don't have a connection open, based on the gateway status index at which they were registered. - /* flexible */ static TMP_CLIENTS: RefCell = RefCell::new(TmpClients::new()); + /* flexible */ static REGISTERED_CLIENTS: RefCell> = RefCell::new(HashMap::new()); /// Maps the client's public key to the sequence number to use for the next outgoing message (to that client). - /* flexible */ static OUTGOING_MESSAGE_TO_CLIENT_NUM_MAP: RefCell> = RefCell::new(HashMap::new()); + /* flexible */ static OUTGOING_MESSAGE_TO_CLIENT_NUM_MAP: RefCell> = RefCell::new(HashMap::new()); /// Maps the client's public key to the expected sequence number of the next incoming message (from that client). - /* flexible */ static INCOMING_MESSAGE_FROM_CLIENT_NUM_MAP: RefCell> = RefCell::new(HashMap::new()); + /* flexible */ static INCOMING_MESSAGE_FROM_CLIENT_NUM_MAP: RefCell> = RefCell::new(HashMap::new()); /// Keeps track of the Merkle tree used for certified queries /* flexible */ static CERT_TREE: RefCell> = RefCell::new(RbTree::new()); /// Keeps track of the principal of the WS Gateway which polls the canister @@ -324,21 +229,15 @@ fn reset_internal_state() { // get the handlers to call the on_close handler for each client let handlers = HANDLERS.with(|state| state.borrow().clone()); - CLIENT_CALLER_MAP.with(|state| { + REGISTERED_CLIENTS.with(|state| { let mut map = state.borrow_mut(); // for each client, call the on_close handler before clearing the map - for (client_key, _) in map.clone().iter() { - // If a client registers while the gateway crashes and restarts, we have to keep the client in the map, - // so that the ws_open invoked by the gateway doesn't fail. - // To be sure that we retain the latest unregistered clients, - // we keep all the clients that have registered after the last two times the gateway updated the status index - if !is_client_in_tmp_clients(client_key) { - handlers.call_on_close(OnCloseCallbackArgs { - client_key: client_key.clone(), - }); - - map.remove(client_key); - } + for (client_principal, _) in map.clone().iter() { + handlers.call_on_close(OnCloseCallbackArgs { + client_principal: client_principal.clone(), + }); + + map.remove(client_principal); } }); @@ -370,15 +269,10 @@ pub fn wipe() { }); // remove all clients from the map - CLIENT_CALLER_MAP.with(|map| { + REGISTERED_CLIENTS.with(|map| { map.borrow_mut().clear(); }); - // clear the temporary clients - TMP_CLIENTS.with(|tmp_clients| { - tmp_clients.borrow_mut().clear(); - }); - custom_print!("Internal state has been wiped!"); } @@ -390,17 +284,30 @@ fn increment_outgoing_message_nonce() { OUTGOING_MESSAGE_NONCE.with(|n| n.replace_with(|&mut old| old + 1)); } -fn put_client_caller(client_key: ClientPublicKey, caller: Principal) { - CLIENT_CALLER_MAP.with(|map| { - map.borrow_mut().insert(client_key.clone(), caller); +fn insert_client(client_principal: ClientPrincipal, new_client: RegisteredClient) { + REGISTERED_CLIENTS.with(|map| { + map.borrow_mut() + .insert(client_principal.clone(), new_client); }); +} + +fn get_registered_client(client_principal: &ClientPrincipal) -> Option { + REGISTERED_CLIENTS.with(|map| map.borrow().get(client_principal).cloned()) +} - // add the client to the temporary clients - insert_in_tmp_clients(client_key); +fn is_client_registered(client_principal: &ClientPrincipal) -> bool { + REGISTERED_CLIENTS.with(|map| map.borrow().contains_key(client_principal)) } -fn get_client_caller(client_key: &ClientPublicKey) -> Option { - CLIENT_CALLER_MAP.with(|map| Some(map.borrow().get(client_key)?.to_owned())) +fn check_registered_client(client_principal: &ClientPrincipal) -> Result<(), String> { + if !REGISTERED_CLIENTS.with(|map| map.borrow().contains_key(client_principal)) { + return Err(String::from(format!( + "client with principal {:?} doesn't have an open connection", + client_principal + ))); + } + + Ok(()) } fn initialize_registered_gateway(gateway_principal: &str) { @@ -419,34 +326,6 @@ fn get_registered_gateway_principal() -> Principal { }) } -fn insert_in_tmp_clients(client_key: ClientPublicKey) { - TMP_CLIENTS.with(|tmp_clients| { - let mut tmp_clients = tmp_clients.borrow_mut(); - tmp_clients.insert(client_key); - }); -} - -fn shift_tmp_clients() { - TMP_CLIENTS.with(|tmp_clients| { - let mut tmp_clients = tmp_clients.borrow_mut(); - tmp_clients.shift(); - }); -} - -fn is_client_in_tmp_clients(client_key: &ClientPublicKey) -> bool { - TMP_CLIENTS.with(|tmp_clients| { - let tmp_clients = tmp_clients.borrow(); - tmp_clients.contain_client(client_key) - }) -} - -fn remove_client_from_tmp_clients(client_key: &ClientPublicKey) { - TMP_CLIENTS.with(|tmp_clients| { - let mut tmp_clients = tmp_clients.borrow_mut(); - tmp_clients.remove(client_key); - }); -} - /// Updates the registered gateway with the new status index. /// If the status index is not greater than the current one, the function returns an error. fn update_registered_gateway_status_index(status_index: u64) -> Result<(), String> { @@ -463,9 +342,6 @@ fn update_registered_gateway_status_index(status_index: u64) -> Result<(), Strin Ok(()) } else { - // update the temporary clients, shifting the last index clients to the second last index clients - shift_tmp_clients(); - v.update_status_index(status_index) } } else { @@ -474,52 +350,44 @@ fn update_registered_gateway_status_index(status_index: u64) -> Result<(), Strin }) } -fn check_registered_client_key(client_key: &ClientPublicKey) -> Result<(), String> { - if !CLIENT_CALLER_MAP.with(|map| map.borrow().contains_key(client_key)) { - return Err(String::from( - "client's public key has not been previously registered by client", - )); - } - - Ok(()) -} - -fn init_outgoing_message_to_client_num(client_key: ClientPublicKey) { +fn init_outgoing_message_to_client_num(client_principal: ClientPrincipal) { OUTGOING_MESSAGE_TO_CLIENT_NUM_MAP.with(|map| { - map.borrow_mut().insert(client_key, 0); + map.borrow_mut().insert(client_principal, 0); }); } -fn get_outgoing_message_to_client_num(client_key: &ClientPublicKey) -> Result { +fn get_outgoing_message_to_client_num(client_principal: &ClientPrincipal) -> Result { OUTGOING_MESSAGE_TO_CLIENT_NUM_MAP.with(|map| { let map = map.borrow(); - let num = *map.get(client_key).ok_or(String::from( + let num = *map.get(client_principal).ok_or(String::from( "outgoing message to client num not initialized for client", ))?; Ok(num) }) } -fn increment_outgoing_message_to_client_num(client_key: &ClientPublicKey) -> Result<(), String> { - let num = get_outgoing_message_to_client_num(client_key)?; +fn increment_outgoing_message_to_client_num( + client_principal: &ClientPrincipal, +) -> Result<(), String> { + let num = get_outgoing_message_to_client_num(client_principal)?; OUTGOING_MESSAGE_TO_CLIENT_NUM_MAP.with(|map| { let mut map = map.borrow_mut(); - map.insert(client_key.clone(), num + 1); + map.insert(client_principal.clone(), num + 1); Ok(()) }) } -fn init_expected_incoming_message_from_client_num(client_key: ClientPublicKey) { +fn init_expected_incoming_message_from_client_num(client_principal: ClientPrincipal) { INCOMING_MESSAGE_FROM_CLIENT_NUM_MAP.with(|map| { - map.borrow_mut().insert(client_key, 0); + map.borrow_mut().insert(client_principal, 0); }); } fn get_expected_incoming_message_from_client_num( - client_key: &ClientPublicKey, + client_principal: &ClientPrincipal, ) -> Result { INCOMING_MESSAGE_FROM_CLIENT_NUM_MAP.with(|map| { - let num = *map.borrow().get(client_key).ok_or(String::from( + let num = *map.borrow().get(client_principal).ok_or(String::from( "expected incoming message num not initialized for client", ))?; Ok(num) @@ -527,35 +395,34 @@ fn get_expected_incoming_message_from_client_num( } fn increment_expected_incoming_message_from_client_num( - client_key: &ClientPublicKey, + client_principal: &ClientPrincipal, ) -> Result<(), String> { - let num = get_expected_incoming_message_from_client_num(client_key)?; + let num = get_expected_incoming_message_from_client_num(client_principal)?; INCOMING_MESSAGE_FROM_CLIENT_NUM_MAP.with(|map| { let mut map = map.borrow_mut(); - map.insert(client_key.clone(), num + 1); + map.insert(client_principal.clone(), num + 1); Ok(()) }) } -fn add_client(client_key: ClientPublicKey) { +fn add_client(client_principal: ClientPrincipal, new_client: RegisteredClient) { + // insert the client in the map + insert_client(client_principal.clone(), new_client); // initialize incoming client's message sequence number to 0 - init_expected_incoming_message_from_client_num(client_key.clone()); + init_expected_incoming_message_from_client_num(client_principal.clone()); // initialize outgoing message sequence number to 0 - init_outgoing_message_to_client_num(client_key.clone()); - - // now that the client is registered, remove it from the temporary clients - remove_client_from_tmp_clients(&client_key); + init_outgoing_message_to_client_num(client_principal); } -fn remove_client(client_key: ClientPublicKey) { - CLIENT_CALLER_MAP.with(|map| { - map.borrow_mut().remove(&client_key); +fn remove_client(client_principal: ClientPrincipal) { + REGISTERED_CLIENTS.with(|map| { + map.borrow_mut().remove(&client_principal); }); OUTGOING_MESSAGE_TO_CLIENT_NUM_MAP.with(|map| { - map.borrow_mut().remove(&client_key); + map.borrow_mut().remove(&client_principal); }); INCOMING_MESSAGE_FROM_CLIENT_NUM_MAP.with(|map| { - map.borrow_mut().remove(&client_key); + map.borrow_mut().remove(&client_principal); }); } @@ -612,6 +479,11 @@ fn get_cert_messages(gateway_principal: Principal, nonce: u64) -> CanisterWsGetM }) } +fn is_registered_gateway(principal: Principal) -> bool { + let registered_gateway_principal = get_registered_gateway_principal(); + return registered_gateway_principal == principal; +} + /// Checks if the caller of the method is the same as the one that was registered during the initialization of the CDK fn check_is_registered_gateway(input_principal: Principal) -> Result<(), String> { let gateway_principal = get_registered_gateway_principal(); @@ -662,7 +534,7 @@ fn schedule_registered_gateway_check() { /// Checks if the registered gateway has sent a heartbeat recently. /// If not, this means that the gateway has been restarted and all clients registered have been disconnected. -/// In this case, all internal IC WebSocket CDK state is reset. +/// In this case, the internal IC WebSocket CDK state is reset. /// /// At the end, a new timer is scheduled to check again if the registered gateway has sent a heartbeat recently. fn check_registered_gateway_timer_callback() { @@ -692,27 +564,29 @@ fn check_registered_gateway_timer_callback() { /// Arguments passed to the `on_open` handler. pub struct OnOpenCallbackArgs { - pub client_key: ClientPublicKey, + pub client_principal: ClientPrincipal, } -/// Handler initialized by the canister and triggered by the CDK once the IC WebSocket connection -/// is established. +/// Handler initialized by the canister +/// and triggered by the CDK once the IC WebSocket connection is established. type OnOpenCallback = fn(OnOpenCallbackArgs); /// Arguments passed to the `on_message` handler. pub struct OnMessageCallbackArgs { - pub client_key: ClientPublicKey, + pub client_principal: ClientPrincipal, + pub is_anonymous: bool, pub message: Vec, } -/// Handler initialized by the canister and triggered by the CDK once a message is received by -/// the CDK. +/// Handler initialized by the canister +/// and triggered by the CDK once an IC WebSocket message is received. type OnMessageCallback = fn(OnMessageCallbackArgs); /// Arguments passed to the `on_close` handler. pub struct OnCloseCallbackArgs { - pub client_key: ClientPublicKey, + pub client_principal: ClientPrincipal, } -/// Handler initialized by the canister and triggered by the CDK once the WS Gateway closes the -/// IC WebSocket connection. +/// Handler initialized by the canister +/// and triggered by the CDK once the WS Gateway closes the IC WebSocket connection +/// for that client. type OnCloseCallback = fn(OnCloseCallbackArgs); /// Handlers initialized by the canister and triggered by the CDK. @@ -750,23 +624,11 @@ fn initialize_handlers(handlers: WsHandlers) { }); } -/// Checks the content signature -fn check_content_signature( - client_key: &ClientPublicKey, - content: &Vec, - sig: &Vec, -) -> Result<(), String> { - // check if client_key is a Ed25519 public key - let public_key = PublicKey::from_slice(client_key).map_err(|e| e.to_string())?; - // check if the signature relayed by the WS Gateway is a Ed25519 signature - let sig = Signature::from_slice(sig).map_err(|e| e.to_string())?; - // check if the signature on the first message verifies against the public key of the registered client - // if so, the first message came from the same client that registered its public key using ws_register - public_key.verify(content, &sig).map_err(|e| e.to_string()) -} - /// Initialize the CDK by setting the callback handlers and the **principal** of the WS Gateway that /// will be polling the canister. +/// +/// Under the hood, an interval (**60 seconds**) is started using [ic_cdk_timers::set_timer] +/// to check if the WS Gateway is still alive. pub fn init(handlers: WsHandlers, gateway_principal: &str) { // set the handlers specified by the canister that the CDK uses to manage the IC WebSocket connection initialize_handlers(handlers); @@ -778,44 +640,40 @@ pub fn init(handlers: WsHandlers, gateway_principal: &str) { schedule_registered_gateway_check(); } -/// Handles the register event received from the client. -/// -/// Registers the public key that the client SDK has generated to initialize an IcWebSocket connection. -pub fn ws_register(args: CanisterWsRegisterArguments) -> CanisterWsRegisterResult { - // TODO: check who is the caller, which can be a client or the anonymous principal - - // associate the identity of the client to its public key received as input - put_client_caller(args.client_key, caller()); - Ok(()) -} - -/// Handles the WS connection open event received from the WS Gateway -/// -/// WS Gateway relays the first message sent by the client together with its signature -/// to prove that the first message is actually coming from the same client that registered its public key -/// beforehand by calling the [ws_register] method. +/// Handles the WS connection open event sent by the client and relayed by the Gateway. pub fn ws_open(args: CanisterWsOpenArguments) -> CanisterWsOpenResult { - // the caller must be the gateway that was registered during CDK initialization - check_is_registered_gateway(caller())?; - - // decode the first message sent by the client - let CanisterOpenMessageContent { - client_key, - canister_id, - } = from_slice(&args.content).map_err(|e| e.to_string())?; + let client_principal = caller(); - // check if client registered its public key by calling ws_register - check_registered_client_key(&client_key)?; + // check if client is not registered yet + if is_client_registered(&client_principal) { + return Err(format!( + "client with principal {:?} already has an open connection", + client_principal, + )); + } - // check if the signature on the first message verifies against the public key of the registered client - check_content_signature(&client_key, &args.content, &args.sig)?; + // avoid gateway opening a connection for its own principal + if is_registered_gateway(client_principal) { + return Err(String::from( + "caller is the registered gateway, cannot open a connection", + )); + } // initialize client maps - add_client(client_key.clone()); + let new_client = RegisteredClient { + is_anonymous: args.is_anonymous, + }; + add_client(client_principal.clone(), new_client); + + // call the on_open handler initialized in init() + HANDLERS.with(|h| { + h.borrow().call_on_open(OnOpenCallbackArgs { + client_principal: client_principal.clone(), + }); + }); Ok(CanisterWsOpenResultValue { - client_key, - canister_id, + client_principal, // returns the current nonce so that in case the WS Gateway has to open a new poller for this canister // it knows which nonce to start polling from. This is needed in order to make sure that the WS Gateway // does not poll messages it has already relayed when a new it starts polling a canister @@ -829,14 +687,14 @@ pub fn ws_close(args: CanisterWsCloseArguments) -> CanisterWsCloseResult { // the caller must be the gateway that was registered during CDK initialization check_is_registered_gateway(caller())?; - // check if client registered its public key by calling ws_register - check_registered_client_key(&args.client_key)?; + // check if client registered its principal by calling ws_open + check_registered_client(&args.client_principal)?; - remove_client(args.client_key.clone()); + remove_client(args.client_principal.clone()); HANDLERS.with(|h| { h.borrow().call_on_close(OnCloseCallbackArgs { - client_key: args.client_key, + client_principal: args.client_principal, }); }); @@ -845,96 +703,57 @@ pub fn ws_close(args: CanisterWsCloseArguments) -> CanisterWsCloseResult { /// Handles the WS messages received either directly from the client or relayed by the WS Gateway. pub fn ws_message(args: CanisterWsMessageArguments) -> CanisterWsMessageResult { - match args.msg { - // message sent directly from client - CanisterIncomingMessage::DirectlyFromClient(received_message) => { - // check if the identity of the caller corresponds to the one registered for the given public key - let expected_caller = get_client_caller(&received_message.client_key).ok_or( - String::from("client is not registered, call ws_register first"), - )?; - if caller() != expected_caller { - return Err(String::from( - "caller is not the same that registered the public key", - )); - } - // call the on_message handler initialized in init() - HANDLERS.with(|h| { - // trigger the on_message handler initialized by canister - h.borrow().call_on_message(OnMessageCallbackArgs { - client_key: received_message.client_key, - message: received_message.message, - }); - }); - Ok(()) - }, - // WS Gateway relays a message from the client - CanisterIncomingMessage::RelayedByGateway(received_message) => { - // this message can come only from the registered gateway - check_is_registered_gateway(caller())?; - - // decode the message sent by the client - let WebsocketMessage { - client_key, - sequence_num, - timestamp: _timestamp, - message, - } = from_slice(&received_message.content).map_err(|e| e.to_string())?; - - // check if client registered its public key by calling ws_register - check_registered_client_key(&client_key)?; - - // check if the signature on the message verifies against the public key of the registered client - check_content_signature( - &client_key, - &received_message.content, - &received_message.sig, - )?; - - let expected_sequence_num = get_expected_incoming_message_from_client_num(&client_key)?; - - // check if the incoming message has the expected sequence number - if sequence_num == expected_sequence_num { - // increase the expected sequence number by 1 - increment_expected_incoming_message_from_client_num(&client_key)?; - // call the on_message handler initialized in init() - HANDLERS.with(|h| { - // trigger the on_message handler initialized by canister - // create message to send to client - h.borrow().call_on_message(OnMessageCallbackArgs { - client_key, - message, - }); - }); - return Ok(()); - } - Err(String::from( - "incoming client's message relayed from WS Gateway does not have the expected sequence number", + let client_principal = caller(); + // check if client registered its principal by calling ws_open + check_registered_client(&client_principal)?; + let registered_client = match get_registered_client(&client_principal) { + Some(v) => v, + None => { + return Err(String::from( + "client with principal {:?} doesn't have an open connection", )) }, - // WS Gateway notifies the canister of the established IC WebSocket connection - CanisterIncomingMessage::IcWebSocketEstablished(client_key) => { - // this message can come only from the registered gateway - check_is_registered_gateway(caller())?; - - // check if client registered its public key by calling ws_register - check_registered_client_key(&client_key)?; - - custom_print!("Can start notifying client with key: {:?}", client_key); - // call the on_open handler - HANDLERS.with(|h| { - // trigger the on_open handler initialized by canister - h.borrow().call_on_open(OnOpenCallbackArgs { client_key }); - }); - Ok(()) - }, - // WS Gateway notifies the canister that it is up and running - CanisterIncomingMessage::IcWebSocketGatewayStatus(gateway_status) => { - // this message can come only from the registered gateway - check_is_registered_gateway(caller())?; + }; - update_registered_gateway_status_index(gateway_status.status_index) - }, + let WebsocketMessage { + client_principal: _, + sequence_num, + timestamp: _, + content, + } = args.msg; + + let expected_sequence_num = get_expected_incoming_message_from_client_num(&client_principal)?; + + // check if the incoming message has the expected sequence number + if sequence_num != expected_sequence_num { + return Err(String::from( + "incoming client's message relayed from WS Gateway does not have the expected sequence number", + )); } + // increase the expected sequence number by 1 + increment_expected_incoming_message_from_client_num(&client_principal)?; + + // call the on_message handler initialized in init() + HANDLERS.with(|h| { + // trigger the on_message handler initialized by canister + // create message to send to client + h.borrow().call_on_message(OnMessageCallbackArgs { + client_principal, + is_anonymous: registered_client.is_anonymous(), + message: content, + }); + }); + Ok(()) +} + +/// Used by the WS Gateway to update its status on the canister. +/// This way, the canister can check if the WS Gateway is still alive. +pub fn ws_status(args: CanisterWsStatusArguments) -> CanisterWsStatusResult { + // check if the caller of this method is the WS Gateway that has been set during the initialization of the SDK + let gateway_principal = caller(); + check_is_registered_gateway(gateway_principal)?; + + update_registered_gateway_status_index(args.status_index) } /// Returns messages to the WS Gateway in response of a polling iteration. @@ -950,9 +769,9 @@ pub fn ws_get_messages(args: CanisterWsGetMessagesArguments) -> CanisterWsGetMes /// /// Under the hood, the message is serialized and certified, and then it is added to the queue of messages /// that the WS Gateway will poll in the next iteration. -pub fn ws_send(client_key: ClientPublicKey, msg_bytes: Vec) -> CanisterWsSendResult { +pub fn ws_send(client_principal: ClientPrincipal, msg_bytes: Vec) -> CanisterWsSendResult { // check if the client is registered - check_registered_client_key(&client_key)?; + check_registered_client(&client_principal)?; // get the principal of the gateway that is polling the canister let gateway_principal = get_registered_gateway_principal(); @@ -966,13 +785,13 @@ pub fn ws_send(client_key: ClientPublicKey, msg_bytes: Vec) -> CanisterWsSen increment_outgoing_message_nonce(); // increment the sequence number for the next message to the client - increment_outgoing_message_to_client_num(&client_key)?; + increment_outgoing_message_to_client_num(&client_principal)?; let websocket_message = WebsocketMessage { - client_key: client_key.clone(), - sequence_num: get_outgoing_message_to_client_num(&client_key)?, + client_principal: client_principal.clone(), + sequence_num: get_outgoing_message_to_client_num(&client_principal)?, timestamp: get_current_time(), - message: msg_bytes, + content: msg_bytes, }; // CBOR serialize message of type WebsocketMessage @@ -986,7 +805,7 @@ pub fn ws_send(client_key: ClientPublicKey, msg_bytes: Vec) -> CanisterWsSen // (from beginning to end of the queue) as ws_send is called sequentially, the nonce // is incremented by one in each call, and the message is pushed at the end of the queue m.borrow_mut().push_back(CanisterOutputMessage { - client_key, + client_principal, content, key, }); @@ -998,19 +817,18 @@ pub fn ws_send(client_key: ClientPublicKey, msg_bytes: Vec) -> CanisterWsSen mod test { use super::*; use proptest::prelude::*; - use ring::signature::KeyPair; mod test_utils { use candid::Principal; use ic_agent::{identity::BasicIdentity, Identity}; - use ring::signature::{Ed25519KeyPair, KeyPair}; + use ring::signature::Ed25519KeyPair; use crate::{ - get_message_for_gateway_key, CanisterOutputMessage, ClientPublicKey, + get_message_for_gateway_key, CanisterOutputMessage, ClientPrincipal, RegisteredClient, MESSAGES_FOR_GATEWAY, }; - pub fn generate_random_key_pair() -> Ed25519KeyPair { + fn generate_random_key_pair() -> Ed25519KeyPair { let rng = ring::rand::SystemRandom::new(); let key_pair = Ed25519KeyPair::generate_pkcs8(&rng).expect("Could not generate a key pair."); @@ -1025,10 +843,10 @@ mod test { candid::Principal::from_text(identity.sender().unwrap().to_text()).unwrap() } - pub fn generate_random_public_key() -> Vec { - let key_pair = generate_random_key_pair(); - - key_pair.public_key().as_ref().to_vec() + pub fn generate_random_registered_client() -> RegisteredClient { + RegisteredClient { + is_anonymous: false, + } } pub fn get_static_principal() -> Principal { @@ -1037,14 +855,14 @@ mod test { } pub fn add_messages_for_gateway( - client_key: ClientPublicKey, + client_principal: ClientPrincipal, gateway_principal: Principal, count: u64, ) { MESSAGES_FOR_GATEWAY.with(|m| { for i in 0..count { m.borrow_mut().push_back(CanisterOutputMessage { - client_key: client_key.clone(), + client_principal: client_principal.clone(), key: get_message_for_gateway_key(gateway_principal.clone(), i), content: vec![], }); @@ -1102,14 +920,15 @@ mod test { assert!(h.on_close.is_none()); h.call_on_open(OnOpenCallbackArgs { - client_key: test_utils::generate_random_public_key(), + client_principal: test_utils::generate_random_principal(), }); h.call_on_message(OnMessageCallbackArgs { - client_key: test_utils::generate_random_public_key(), + client_principal: test_utils::generate_random_principal(), + is_anonymous: false, // doesn't matter message: vec![], }); h.call_on_close(OnCloseCallbackArgs { - client_key: test_utils::generate_random_public_key(), + client_principal: test_utils::generate_random_principal(), }); }); @@ -1154,14 +973,15 @@ mod test { assert!(h.on_close.is_some()); h.call_on_open(OnOpenCallbackArgs { - client_key: test_utils::generate_random_public_key(), + client_principal: test_utils::generate_random_principal(), }); h.call_on_message(OnMessageCallbackArgs { - client_key: test_utils::generate_random_public_key(), + client_principal: test_utils::generate_random_principal(), + is_anonymous: false, // doesn't matter message: vec![], }); h.call_on_close(OnCloseCallbackArgs { - client_key: test_utils::generate_random_public_key(), + client_principal: test_utils::generate_random_principal(), }); }); @@ -1226,30 +1046,28 @@ mod test { } #[test] - fn test_get_client_caller(test_client_key in any::().prop_map(|_| test_utils::generate_random_public_key())) { + fn test_get_registered_client(test_client_principal in any::().prop_map(|_| test_utils::generate_random_principal())) { // Set up - let caller_principal = test_utils::generate_random_principal(); - CLIENT_CALLER_MAP.with(|map| { - map.borrow_mut().insert(test_client_key.clone(), caller_principal); + let registered_client = test_utils::generate_random_registered_client(); + REGISTERED_CLIENTS.with(|map| { + map.borrow_mut().insert(test_client_principal.clone(), registered_client.clone()); }); - let actual_client_caller = get_client_caller(&test_client_key); - prop_assert_eq!(actual_client_caller, Some(caller_principal)); - let actual_client_caller = get_client_caller(&test_utils::generate_random_public_key()); - prop_assert_eq!(actual_client_caller, None); + let actual_client = get_registered_client(&test_client_principal); + prop_assert_eq!(actual_client, Some(registered_client)); + let actual_client = get_registered_client(&test_utils::generate_random_principal()); + prop_assert_eq!(actual_client, None); } #[test] - fn test_put_client_caller(test_client_key in any::().prop_map(|_| test_utils::generate_random_public_key())) { + fn test_insert_client(test_client_principal in any::().prop_map(|_| test_utils::generate_random_principal())) { // Set up - let caller_principal = test_utils::generate_random_principal(); + let registered_client = test_utils::generate_random_registered_client(); - put_client_caller(test_client_key.clone(), caller_principal); + insert_client(test_client_principal.clone(), registered_client.clone()); - let actual_client = CLIENT_CALLER_MAP.with(|map| map.borrow().get(&test_client_key).unwrap().clone()); - prop_assert_eq!(actual_client, caller_principal); - let actual_result = TMP_CLIENTS.with(|map| map.borrow().last_index_clients.get(&test_client_key).unwrap().clone()); - prop_assert_eq!(actual_result, test_client_key); + let actual_client = REGISTERED_CLIENTS.with(|map| map.borrow().get(&test_client_principal).unwrap().clone()); + prop_assert_eq!(actual_client, registered_client); } #[test] @@ -1262,13 +1080,10 @@ mod test { } #[test] - fn test_update_registered_gateway_status_index(test_gateway_principal in any::().prop_map(|_| test_utils::generate_random_principal()), test_client_key in any::().prop_map(|_| test_utils::generate_random_public_key())) { + fn test_update_registered_gateway_status_index(test_gateway_principal in any::().prop_map(|_| test_utils::generate_random_principal()), test_client_principal in any::().prop_map(|_| test_utils::generate_random_principal())) { // Set up - CLIENT_CALLER_MAP.with(|map| { - map.borrow_mut().insert(test_client_key.clone(), test_utils::generate_random_principal()); - }); - TMP_CLIENTS.with(|map| { - map.borrow_mut().last_index_clients.insert(test_client_key.clone()); + REGISTERED_CLIENTS.with(|map| { + map.borrow_mut().insert(test_client_principal.clone(), test_utils::generate_random_registered_client()); }); REGISTERED_GATEWAY.with(|p| *p.borrow_mut() = Some(RegisteredGateway::new(test_gateway_principal.clone()))); @@ -1276,8 +1091,6 @@ mod test { let _ = update_registered_gateway_status_index(2); let actual_status_index = REGISTERED_GATEWAY.with(|p| p.borrow().as_ref().unwrap().last_status_index.clone()); prop_assert_eq!(actual_status_index, 2); - let actual_result = TMP_CLIENTS.with(|map| map.borrow().second_last_index_clients.get(&test_client_key).unwrap().clone()); - prop_assert_eq!(actual_result, test_client_key.clone()); // test with an invalid status index (behind the current one) let actual_result = update_registered_gateway_status_index(1); @@ -1295,158 +1108,149 @@ mod test { let _ = update_registered_gateway_status_index(10); let actual_status_index = REGISTERED_GATEWAY.with(|p| p.borrow().as_ref().unwrap().last_status_index.clone()); prop_assert_eq!(actual_status_index, 10); - let actual_result = TMP_CLIENTS.with(|map| map.borrow().second_last_index_clients.get(&test_client_key).is_none()); - prop_assert!(actual_result); // reset the registered gateway - let new_client_key = test_utils::generate_random_public_key(); - CLIENT_CALLER_MAP.with(|map| { - map.borrow_mut().insert(new_client_key.clone(), test_utils::generate_random_principal()); - }); - TMP_CLIENTS.with(|map| { - map.borrow_mut().last_index_clients.insert(new_client_key.clone()); + let new_client_principal = test_utils::generate_random_principal(); + REGISTERED_CLIENTS.with(|map| { + map.borrow_mut().insert(new_client_principal.clone(), test_utils::generate_random_registered_client()); }); let _ = update_registered_gateway_status_index(0); let actual_status_index = REGISTERED_GATEWAY.with(|p| p.borrow().as_ref().unwrap().last_status_index.clone()); prop_assert_eq!(actual_status_index, 0); - let actual_result = TMP_CLIENTS.with(|map| map.borrow().last_index_clients.get(&new_client_key).unwrap().clone()); - prop_assert_eq!(actual_result, new_client_key.clone()); - let actual_result = CLIENT_CALLER_MAP.with(|map| { + let actual_result = REGISTERED_CLIENTS.with(|map| { let map = map.borrow(); - map.get(&test_client_key).is_none() && map.get(&new_client_key).is_some() + map.get(&test_client_principal).is_none() && map.get(&new_client_principal).is_none() }); prop_assert!(actual_result); } #[test] - fn test_check_registered_client_key_empty(test_client_key in any::().prop_map(|_| test_utils::generate_random_public_key())) { - let actual_result = check_registered_client_key(&test_client_key); - prop_assert_eq!(actual_result.err(), Some(String::from("client's public key has not been previously registered by client"))); + fn test_check_registered_client_principal_empty(test_client_principal in any::().prop_map(|_| test_utils::generate_random_principal())) { + let actual_result = check_registered_client(&test_client_principal); + prop_assert_eq!(actual_result.err(), Some(format!("client with principal {:?} doesn't have an open connection", test_client_principal))); } #[test] - fn test_check_registered_client_key(test_client_key in any::().prop_map(|_| test_utils::generate_random_public_key())) { + fn test_check_registered_client_principal(test_client_principal in any::().prop_map(|_| test_utils::generate_random_principal())) { // Set up - CLIENT_CALLER_MAP.with(|map| { - map.borrow_mut().insert(test_client_key.clone(), test_utils::generate_random_principal()); + REGISTERED_CLIENTS.with(|map| { + map.borrow_mut().insert(test_client_principal.clone(), test_utils::generate_random_registered_client()); }); - let actual_result = check_registered_client_key(&test_client_key); + let actual_result = check_registered_client(&test_client_principal); prop_assert!(actual_result.is_ok()); - let actual_result = check_registered_client_key(&test_utils::generate_random_public_key()); - prop_assert_eq!(actual_result.err(), Some(String::from("client's public key has not been previously registered by client"))); + let non_existing_client_principal = test_utils::generate_random_principal(); + let actual_result = check_registered_client(&non_existing_client_principal); + prop_assert_eq!(actual_result.err(), Some(format!("client with principal {:?} doesn't have an open connection", non_existing_client_principal))); } #[test] - fn test_init_outgoing_message_to_client_num(test_client_key in any::().prop_map(|_| test_utils::generate_random_public_key())) { - init_outgoing_message_to_client_num(test_client_key.clone()); + fn test_init_outgoing_message_to_client_num(test_client_principal in any::().prop_map(|_| test_utils::generate_random_principal())) { + init_outgoing_message_to_client_num(test_client_principal.clone()); - let actual_result = OUTGOING_MESSAGE_TO_CLIENT_NUM_MAP.with(|map| map.borrow().get(&test_client_key).unwrap().clone()); + let actual_result = OUTGOING_MESSAGE_TO_CLIENT_NUM_MAP.with(|map| map.borrow().get(&test_client_principal).unwrap().clone()); prop_assert_eq!(actual_result, 0); } #[test] - fn test_increment_outgoing_message_to_client_num(test_client_key in any::().prop_map(|_| test_utils::generate_random_public_key()), test_num in any::()) { + fn test_increment_outgoing_message_to_client_num(test_client_principal in any::().prop_map(|_| test_utils::generate_random_principal()), test_num in any::()) { // Set up OUTGOING_MESSAGE_TO_CLIENT_NUM_MAP.with(|map| { - map.borrow_mut().insert(test_client_key.clone(), test_num); + map.borrow_mut().insert(test_client_principal.clone(), test_num); }); - let increment_result = increment_outgoing_message_to_client_num(&test_client_key); + let increment_result = increment_outgoing_message_to_client_num(&test_client_principal); prop_assert!(increment_result.is_ok()); - let actual_result = OUTGOING_MESSAGE_TO_CLIENT_NUM_MAP.with(|map| map.borrow().get(&test_client_key).unwrap().clone()); + let actual_result = OUTGOING_MESSAGE_TO_CLIENT_NUM_MAP.with(|map| map.borrow().get(&test_client_principal).unwrap().clone()); prop_assert_eq!(actual_result, test_num + 1); } #[test] - fn test_get_outgoing_message_to_client_num(test_client_key in any::().prop_map(|_| test_utils::generate_random_public_key()), test_num in any::()) { + fn test_get_outgoing_message_to_client_num(test_client_principal in any::().prop_map(|_| test_utils::generate_random_principal()), test_num in any::()) { // Set up OUTGOING_MESSAGE_TO_CLIENT_NUM_MAP.with(|map| { - map.borrow_mut().insert(test_client_key.clone(), test_num); + map.borrow_mut().insert(test_client_principal.clone(), test_num); }); - let actual_result = get_outgoing_message_to_client_num(&test_client_key); + let actual_result = get_outgoing_message_to_client_num(&test_client_principal); prop_assert!(actual_result.is_ok()); prop_assert_eq!(actual_result.unwrap(), test_num); } #[test] - fn test_init_expected_incoming_message_from_client_num(test_client_key in any::().prop_map(|_| test_utils::generate_random_public_key())) { - init_expected_incoming_message_from_client_num(test_client_key.clone()); + fn test_init_expected_incoming_message_from_client_num(test_client_principal in any::().prop_map(|_| test_utils::generate_random_principal())) { + init_expected_incoming_message_from_client_num(test_client_principal.clone()); - let actual_result = INCOMING_MESSAGE_FROM_CLIENT_NUM_MAP.with(|map| map.borrow().get(&test_client_key).unwrap().clone()); + let actual_result = INCOMING_MESSAGE_FROM_CLIENT_NUM_MAP.with(|map| map.borrow().get(&test_client_principal).unwrap().clone()); prop_assert_eq!(actual_result, 0); } #[test] - fn test_get_expected_incoming_message_from_client_num(test_client_key in any::().prop_map(|_| test_utils::generate_random_public_key()), test_num in any::()) { + fn test_get_expected_incoming_message_from_client_num(test_client_principal in any::().prop_map(|_| test_utils::generate_random_principal()), test_num in any::()) { // Set up INCOMING_MESSAGE_FROM_CLIENT_NUM_MAP.with(|map| { - map.borrow_mut().insert(test_client_key.clone(), test_num); + map.borrow_mut().insert(test_client_principal.clone(), test_num); }); - let actual_result = get_expected_incoming_message_from_client_num(&test_client_key); + let actual_result = get_expected_incoming_message_from_client_num(&test_client_principal); prop_assert!(actual_result.is_ok()); prop_assert_eq!(actual_result.unwrap(), test_num); } #[test] - fn test_increment_expected_incoming_message_from_client_num(test_client_key in any::().prop_map(|_| test_utils::generate_random_public_key()), test_num in any::()) { + fn test_increment_expected_incoming_message_from_client_num(test_client_principal in any::().prop_map(|_| test_utils::generate_random_principal()), test_num in any::()) { // Set up INCOMING_MESSAGE_FROM_CLIENT_NUM_MAP.with(|map| { - map.borrow_mut().insert(test_client_key.clone(), test_num); + map.borrow_mut().insert(test_client_principal.clone(), test_num); }); - let increment_result = increment_expected_incoming_message_from_client_num(&test_client_key); + let increment_result = increment_expected_incoming_message_from_client_num(&test_client_principal); prop_assert!(increment_result.is_ok()); - let actual_result = INCOMING_MESSAGE_FROM_CLIENT_NUM_MAP.with(|map| map.borrow().get(&test_client_key).unwrap().clone()); + let actual_result = INCOMING_MESSAGE_FROM_CLIENT_NUM_MAP.with(|map| map.borrow().get(&test_client_principal).unwrap().clone()); prop_assert_eq!(actual_result, test_num + 1); } #[test] - fn test_add_client(test_client_key in any::().prop_map(|_| test_utils::generate_random_public_key())) { - // Set up - TMP_CLIENTS.with(|map| { - map.borrow_mut().last_index_clients.insert(test_client_key.clone()); - }); + fn test_add_client(test_client_principal in any::().prop_map(|_| test_utils::generate_random_principal())) { + let registered_client = test_utils::generate_random_registered_client(); // Test - add_client(test_client_key.clone()); + add_client(test_client_principal.clone(), registered_client.clone()); - let actual_result = INCOMING_MESSAGE_FROM_CLIENT_NUM_MAP.with(|map| map.borrow().get(&test_client_key).unwrap().clone()); - prop_assert_eq!(actual_result, 0); + let actual_result = REGISTERED_CLIENTS.with(|map| map.borrow().get(&test_client_principal).unwrap().clone()); + prop_assert_eq!(actual_result, registered_client); - let actual_result = OUTGOING_MESSAGE_TO_CLIENT_NUM_MAP.with(|map| map.borrow().get(&test_client_key).unwrap().clone()); + let actual_result = INCOMING_MESSAGE_FROM_CLIENT_NUM_MAP.with(|map| map.borrow().get(&test_client_principal).unwrap().clone()); prop_assert_eq!(actual_result, 0); - let actual_result = TMP_CLIENTS.with(|map| map.borrow().last_index_clients.get(&test_client_key).is_none()); - prop_assert!(actual_result); + let actual_result = OUTGOING_MESSAGE_TO_CLIENT_NUM_MAP.with(|map| map.borrow().get(&test_client_principal).unwrap().clone()); + prop_assert_eq!(actual_result, 0); } #[test] - fn test_remove_client(test_client_key in any::().prop_map(|_| test_utils::generate_random_public_key())) { + fn test_remove_client(test_client_principal in any::().prop_map(|_| test_utils::generate_random_principal())) { // Set up - CLIENT_CALLER_MAP.with(|map| { - map.borrow_mut().insert(test_client_key.clone(), test_utils::generate_random_principal()); + REGISTERED_CLIENTS.with(|map| { + map.borrow_mut().insert(test_client_principal.clone(), test_utils::generate_random_registered_client()); }); INCOMING_MESSAGE_FROM_CLIENT_NUM_MAP.with(|map| { - map.borrow_mut().insert(test_client_key.clone(), 0); + map.borrow_mut().insert(test_client_principal.clone(), 0); }); OUTGOING_MESSAGE_TO_CLIENT_NUM_MAP.with(|map| { - map.borrow_mut().insert(test_client_key.clone(), 0); + map.borrow_mut().insert(test_client_principal.clone(), 0); }); - remove_client(test_client_key.clone()); + remove_client(test_client_principal.clone()); - let is_none = CLIENT_CALLER_MAP.with(|map| map.borrow().get(&test_client_key).is_none()); + let is_none = REGISTERED_CLIENTS.with(|map| map.borrow().get(&test_client_principal).is_none()); prop_assert!(is_none); - let is_none = INCOMING_MESSAGE_FROM_CLIENT_NUM_MAP.with(|map| map.borrow().get(&test_client_key).is_none()); + let is_none = INCOMING_MESSAGE_FROM_CLIENT_NUM_MAP.with(|map| map.borrow().get(&test_client_principal).is_none()); prop_assert!(is_none); - let is_none = OUTGOING_MESSAGE_TO_CLIENT_NUM_MAP.with(|map| map.borrow().get(&test_client_key).is_none()); + let is_none = OUTGOING_MESSAGE_TO_CLIENT_NUM_MAP.with(|map| map.borrow().get(&test_client_principal).is_none()); prop_assert!(is_none); } @@ -1477,8 +1281,8 @@ mod test { REGISTERED_GATEWAY.with(|p| *p.borrow_mut() = Some(RegisteredGateway::new(gateway_principal.clone()))); let messages_count = 4; - let test_client_key = test_utils::generate_random_public_key(); - test_utils::add_messages_for_gateway(test_client_key.clone(), gateway_principal, messages_count); + let test_client_principal = test_utils::generate_random_principal(); + test_utils::add_messages_for_gateway(test_client_principal.clone(), gateway_principal, messages_count); // Test // messages are just 4, so we don't exceed the max number of returned messages @@ -1499,8 +1303,8 @@ mod test { REGISTERED_GATEWAY.with(|p| *p.borrow_mut() = Some(RegisteredGateway::new(gateway_principal.clone()))); let messages_count: u64 = (2 * MAX_NUMBER_OF_RETURNED_MESSAGES).try_into().unwrap(); - let test_client_key = test_utils::generate_random_public_key(); - test_utils::add_messages_for_gateway(test_client_key.clone(), gateway_principal, messages_count); + let test_client_principal = test_utils::generate_random_principal(); + test_utils::add_messages_for_gateway(test_client_principal.clone(), gateway_principal, messages_count); // Test // messages are now MAX_NUMBER_OF_RETURNED_MESSAGES @@ -1524,8 +1328,8 @@ mod test { // Set up REGISTERED_GATEWAY.with(|p| *p.borrow_mut() = Some(RegisteredGateway::new(gateway_principal.clone()))); - let test_client_key = test_utils::generate_random_public_key(); - test_utils::add_messages_for_gateway(test_client_key.clone(), gateway_principal, messages_count); + let test_client_principal = test_utils::generate_random_principal(); + test_utils::add_messages_for_gateway(test_client_principal.clone(), gateway_principal, messages_count); // Test // add one to test the out of range index @@ -1557,117 +1361,18 @@ mod test { prop_assert_eq!(actual_result.err(), Some(String::from("caller is not the gateway that has been registered during CDK initialization"))); } - #[test] - fn test_check_content_signature(test_content in any::>(), test_key_pair in any::().prop_map(|_| test_utils::generate_random_key_pair())) { - let signature = test_key_pair.sign(&test_content); - - // wrong content - let actual_result = check_content_signature(&test_key_pair.public_key().as_ref().to_vec(), &vec![0], &signature.as_ref().to_vec()); - prop_assert_eq!(actual_result.err(), Some(String::from("Signature doesn't verify"))); - - // wrong public key - let other_key_pair = test_utils::generate_random_key_pair(); - let actual_result = check_content_signature(&other_key_pair.public_key().as_ref().to_vec(), &test_content, &signature.as_ref().to_vec()); - prop_assert_eq!(actual_result.err(), Some(String::from("Signature doesn't verify"))); - - // wrong signature - let other_signature = other_key_pair.sign(&test_content); - let actual_result = check_content_signature(&test_key_pair.public_key().as_ref().to_vec(), &test_content, &other_signature.as_ref().to_vec()); - prop_assert_eq!(actual_result.err(), Some(String::from("Signature doesn't verify"))); - - // correct signature - let actual_result = check_content_signature(&test_key_pair.public_key().as_ref().to_vec(), &test_content, &signature.as_ref().to_vec()); - prop_assert!(actual_result.is_ok()); - } - #[test] fn test_serialize_websocket_message(test_msg_bytes in any::>(), test_sequence_num in any::(), test_timestamp in any::()) { // TODO: add more tests, in which we check the serialized message let websocket_message = WebsocketMessage { - client_key: test_utils::generate_random_public_key(), + client_principal: test_utils::generate_random_principal(), sequence_num: test_sequence_num, timestamp: test_timestamp, - message: test_msg_bytes, + content: test_msg_bytes, }; let serialized_message = websocket_message.cbor_serialize(); assert!(serialized_message.is_ok()); // not so useful as a test } - - #[test] - fn test_insert_in_tmp_clients(test_client_key in any::().prop_map(|_| test_utils::generate_random_public_key())) { - insert_in_tmp_clients(test_client_key.clone()); - - let actual_result = TMP_CLIENTS.with(|map| map.borrow().last_index_clients.get(&test_client_key).unwrap().clone()); - prop_assert_eq!(actual_result, test_client_key); - } - - #[test] - fn test_shift_tmp_clients(test_client_key in any::().prop_map(|_| test_utils::generate_random_public_key())) { - // Set up - TMP_CLIENTS.with(|map| { - map.borrow_mut().last_index_clients.insert(test_client_key.clone()); - }); - - // Test - // shift and check if the client is still there - shift_tmp_clients(); - - let actual_result = TMP_CLIENTS.with(|map| map.borrow().last_index_clients.get(&test_client_key).is_none()); - prop_assert!(actual_result); - let actual_result = TMP_CLIENTS.with(|map| map.borrow().second_last_index_clients.get(&test_client_key).unwrap().clone()); - prop_assert_eq!(actual_result, test_client_key.clone()); - - // add a new client to the last index clients, and check if the old one is still there - let new_client = test_utils::generate_random_public_key(); - TMP_CLIENTS.with(|map| { - map.borrow_mut().last_index_clients.insert(new_client.clone()); - }); - - // shift again and check if the old one is removed - shift_tmp_clients(); - - let actual_result = TMP_CLIENTS.with(|map| map.borrow().second_last_index_clients.get(&new_client).unwrap().clone()); - prop_assert_eq!(actual_result, new_client); - let actual_result = TMP_CLIENTS.with(|map| map.borrow().second_last_index_clients.get(&test_client_key).is_none()); - prop_assert!(actual_result); - - // shift again and check if everything is empty - shift_tmp_clients(); - - let actual_result = TMP_CLIENTS.with(|map| map.borrow().second_last_index_clients.is_empty()); - prop_assert!(actual_result); - let actual_result = TMP_CLIENTS.with(|map| map.borrow().last_index_clients.is_empty()); - prop_assert!(actual_result); - } - - #[test] - fn test_is_client_in_tmp_clients(test_client_key in any::().prop_map(|_| test_utils::generate_random_public_key())) { - // Set up - TMP_CLIENTS.with(|map| { - map.borrow_mut().last_index_clients.insert(test_client_key.clone()); - }); - - // Test - let actual_result = is_client_in_tmp_clients(&test_client_key); - prop_assert!(actual_result); - - let actual_result = is_client_in_tmp_clients(&test_utils::generate_random_public_key()); - prop_assert!(!actual_result); - } - - #[test] - fn test_remove_client_from_tmp_clients(test_client_key in any::().prop_map(|_| test_utils::generate_random_public_key())) { - // Set up - TMP_CLIENTS.with(|map| { - map.borrow_mut().last_index_clients.insert(test_client_key.clone()); - }); - - // Test - remove_client_from_tmp_clients(&test_client_key); - - let actual_result = TMP_CLIENTS.with(|map| map.borrow().last_index_clients.get(&test_client_key).is_none()); - prop_assert!(actual_result); - } } } diff --git a/src/ic-websocket-cdk/ws_types.did b/src/ic-websocket-cdk/ws_types.did index c4a5601..3942d61 100644 --- a/src/ic-websocket-cdk/ws_types.did +++ b/src/ic-websocket-cdk/ws_types.did @@ -1,30 +1,16 @@ -type ClientPublicKey = blob; +type ClientPrincipal = principal; -type DirectClientMessage = record { - client_key : ClientPublicKey; - message : blob; -}; - -type RelayedClientMessage = record { +type WebsocketMessage = record { + client_principal : ClientPrincipal; + sequence_num : nat64; + timestamp : nat64; content : blob; - sig : blob; -}; - -type GatewayStatusMessage = record { - status_index : nat64; -}; - -type CanisterIncomingMessage = variant { - DirectlyFromClient : DirectClientMessage; - RelayedByGateway : RelayedClientMessage; - IcWebSocketEstablished : ClientPublicKey; - IcWebSocketGatewayStatus : GatewayStatusMessage; }; type CanisterOutputMessage = record { - client_key : ClientPublicKey; - content : blob; + client_principal : ClientPrincipal; key : text; + content : blob; }; type CanisterOutputCertifiedMessages = record { @@ -33,23 +19,12 @@ type CanisterOutputCertifiedMessages = record { tree : blob; }; -type CanisterWsRegisterArguments = record { - client_key : ClientPublicKey; -}; - -type CanisterWsRegisterResult = variant { - Ok : null; - Err : text; -}; - type CanisterWsOpenArguments = record { - content : blob; - sig : blob; + is_anonymous : bool; }; type CanisterWsOpenResultValue = record { - client_key : ClientPublicKey; - canister_id : principal; + client_principal : ClientPrincipal; nonce : nat64; }; @@ -59,7 +34,7 @@ type CanisterWsOpenResult = variant { }; type CanisterWsCloseArguments = record { - client_key : ClientPublicKey; + client_principal : ClientPrincipal; }; type CanisterWsCloseResult = variant { @@ -68,7 +43,7 @@ type CanisterWsCloseResult = variant { }; type CanisterWsMessageArguments = record { - msg : CanisterIncomingMessage; + msg : WebsocketMessage; }; type CanisterWsMessageResult = variant { @@ -76,6 +51,15 @@ type CanisterWsMessageResult = variant { Err : text; }; +type CanisterWsStatusArguments = record { + status_index : nat64; +}; + +type CanisterWsStatusResult = variant { + Ok : null; + Err : text; +}; + type CanisterWsGetMessagesArguments = record { nonce : nat64; }; diff --git a/tests/src/canister.rs b/tests/src/canister.rs index e9e88a0..eda5b74 100644 --- a/tests/src/canister.rs +++ b/tests/src/canister.rs @@ -3,13 +3,13 @@ use ic_cdk::print; use ic_websocket_cdk::{OnCloseCallbackArgs, OnMessageCallbackArgs, OnOpenCallbackArgs}; pub fn on_open(args: OnOpenCallbackArgs) { - print(format!("Opened websocket: {:?}", args.client_key)); + print(format!("Opened websocket: {:?}", args.client_principal)); } pub fn on_message(args: OnMessageCallbackArgs) { - print(format!("Received message: {:?}", args.client_key)); + print(format!("Received message: {:?}", args.client_principal)); } pub fn on_close(args: OnCloseCallbackArgs) { - print(format!("Client {:?} disconnected", args.client_key)); + print(format!("Client {:?} disconnected", args.client_principal)); } diff --git a/tests/src/lib.rs b/tests/src/lib.rs index 444adcb..f0989ee 100644 --- a/tests/src/lib.rs +++ b/tests/src/lib.rs @@ -4,8 +4,8 @@ use canister::{on_close, on_message, on_open}; use ic_websocket_cdk::{ CanisterWsCloseArguments, CanisterWsCloseResult, CanisterWsGetMessagesArguments, CanisterWsGetMessagesResult, CanisterWsMessageArguments, CanisterWsMessageResult, - CanisterWsOpenArguments, CanisterWsOpenResult, CanisterWsRegisterArguments, - CanisterWsRegisterResult, CanisterWsSendResult, ClientPublicKey, WsHandlers, + CanisterWsOpenArguments, CanisterWsOpenResult, CanisterWsSendResult, CanisterWsStatusArguments, + CanisterWsStatusResult, ClientPrincipal, WsHandlers, }; mod canister; @@ -26,12 +26,6 @@ fn post_upgrade(gateway_principal: String) { init(gateway_principal); } -// method called by the client SDK when instantiating a new IcWebSocket -#[update] -fn ws_register(args: CanisterWsRegisterArguments) -> CanisterWsRegisterResult { - ic_websocket_cdk::ws_register(args) -} - // method called by the WS Gateway after receiving FirstMessage from the client #[update] fn ws_open(args: CanisterWsOpenArguments) -> CanisterWsOpenResult { @@ -50,6 +44,12 @@ fn ws_message(args: CanisterWsMessageArguments) -> CanisterWsMessageResult { ic_websocket_cdk::ws_message(args) } +// method called by the WS Gateway to update its status in the canister +#[update] +fn ws_status(args: CanisterWsStatusArguments) -> CanisterWsStatusResult { + ic_websocket_cdk::ws_status(args) +} + // method called by the WS Gateway to get messages for all the clients it serves #[query] fn ws_get_messages(args: CanisterWsGetMessagesArguments) -> CanisterWsGetMessagesResult { @@ -65,6 +65,6 @@ fn ws_wipe() { // send a message to the client, usually called by the canister itself #[update] -fn ws_send(client_key: ClientPublicKey, msg_bytes: Vec) -> CanisterWsSendResult { - ic_websocket_cdk::ws_send(client_key, msg_bytes) +fn ws_send(client_principal: ClientPrincipal, msg_bytes: Vec) -> CanisterWsSendResult { + ic_websocket_cdk::ws_send(client_principal, msg_bytes) } diff --git a/tests/test_canister.did b/tests/test_canister.did index 62e3e61..c7610c4 100644 --- a/tests/test_canister.did +++ b/tests/test_canister.did @@ -6,13 +6,13 @@ type CanisterWsSendResult = variant { }; service : (text) -> { - "ws_register" : (CanisterWsRegisterArguments) -> (CanisterWsRegisterResult); "ws_open" : (CanisterWsOpenArguments) -> (CanisterWsOpenResult); "ws_close" : (CanisterWsCloseArguments) -> (CanisterWsCloseResult); "ws_message" : (CanisterWsMessageArguments) -> (CanisterWsMessageResult); + "ws_status" : (CanisterWsStatusArguments) -> (CanisterWsStatusResult); "ws_get_messages" : (CanisterWsGetMessagesArguments) -> (CanisterWsGetMessagesResult) query; // methods used just for debugging/testing "ws_wipe" : () -> (); - "ws_send" : (ClientPublicKey, blob) -> (CanisterWsSendResult); + "ws_send" : (ClientPrincipal, blob) -> (CanisterWsSendResult); }; From a8eaf07f6f9636eb6efae4825295b54e3ccb67b4 Mon Sep 17 00:00:00 2001 From: Luca8991 Date: Thu, 31 Aug 2023 09:53:40 +0200 Subject: [PATCH 02/31] fix: verbose errors --- src/ic-websocket-cdk/src/lib.rs | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/src/ic-websocket-cdk/src/lib.rs b/src/ic-websocket-cdk/src/lib.rs index df8884c..b39092d 100644 --- a/src/ic-websocket-cdk/src/lib.rs +++ b/src/ic-websocket-cdk/src/lib.rs @@ -709,8 +709,9 @@ pub fn ws_message(args: CanisterWsMessageArguments) -> CanisterWsMessageResult { let registered_client = match get_registered_client(&client_principal) { Some(v) => v, None => { - return Err(String::from( + return Err(format!( "client with principal {:?} doesn't have an open connection", + client_principal )) }, }; @@ -726,9 +727,11 @@ pub fn ws_message(args: CanisterWsMessageArguments) -> CanisterWsMessageResult { // check if the incoming message has the expected sequence number if sequence_num != expected_sequence_num { - return Err(String::from( - "incoming client's message relayed from WS Gateway does not have the expected sequence number", - )); + return Err( + format!( + "incoming client's message does not have the expected sequence number. Expected: {expected_sequence_num}, actual: {sequence_num}", + ) + ); } // increase the expected sequence number by 1 increment_expected_incoming_message_from_client_num(&client_principal)?; From 26cf767d58d94e7879ddb44132075e59171cf3cd Mon Sep 17 00:00:00 2001 From: Luca8991 Date: Thu, 31 Aug 2023 10:03:58 +0200 Subject: [PATCH 03/31] fix: increment outgoing sequence num after message --- src/ic-websocket-cdk/src/lib.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/ic-websocket-cdk/src/lib.rs b/src/ic-websocket-cdk/src/lib.rs index b39092d..a1dcd56 100644 --- a/src/ic-websocket-cdk/src/lib.rs +++ b/src/ic-websocket-cdk/src/lib.rs @@ -787,9 +787,6 @@ pub fn ws_send(client_principal: ClientPrincipal, msg_bytes: Vec) -> Caniste // increment the nonce for the next message increment_outgoing_message_nonce(); - // increment the sequence number for the next message to the client - increment_outgoing_message_to_client_num(&client_principal)?; - let websocket_message = WebsocketMessage { client_principal: client_principal.clone(), sequence_num: get_outgoing_message_to_client_num(&client_principal)?, @@ -797,6 +794,9 @@ pub fn ws_send(client_principal: ClientPrincipal, msg_bytes: Vec) -> Caniste content: msg_bytes, }; + // increment the sequence number for the next message to the client + increment_outgoing_message_to_client_num(&client_principal)?; + // CBOR serialize message of type WebsocketMessage let content = websocket_message.cbor_serialize()?; From 592a8bcc7bb95c8668bfa4290358a7031e5e3494 Mon Sep 17 00:00:00 2001 From: Luca8991 Date: Thu, 31 Aug 2023 10:55:28 +0200 Subject: [PATCH 04/31] fix: sequence numbers logic --- src/ic-websocket-cdk/src/lib.rs | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/ic-websocket-cdk/src/lib.rs b/src/ic-websocket-cdk/src/lib.rs index a1dcd56..e8a1991 100644 --- a/src/ic-websocket-cdk/src/lib.rs +++ b/src/ic-websocket-cdk/src/lib.rs @@ -379,7 +379,7 @@ fn increment_outgoing_message_to_client_num( fn init_expected_incoming_message_from_client_num(client_principal: ClientPrincipal) { INCOMING_MESSAGE_FROM_CLIENT_NUM_MAP.with(|map| { - map.borrow_mut().insert(client_principal, 0); + map.borrow_mut().insert(client_principal, 1); }); } @@ -786,6 +786,8 @@ pub fn ws_send(client_principal: ClientPrincipal, msg_bytes: Vec) -> Caniste // increment the nonce for the next message increment_outgoing_message_nonce(); + // increment the sequence number for the next message to the client + increment_outgoing_message_to_client_num(&client_principal)?; let websocket_message = WebsocketMessage { client_principal: client_principal.clone(), @@ -794,9 +796,6 @@ pub fn ws_send(client_principal: ClientPrincipal, msg_bytes: Vec) -> Caniste content: msg_bytes, }; - // increment the sequence number for the next message to the client - increment_outgoing_message_to_client_num(&client_principal)?; - // CBOR serialize message of type WebsocketMessage let content = websocket_message.cbor_serialize()?; From c4b300e2685a7568b87de46eac1330edcdbcd665 Mon Sep 17 00:00:00 2001 From: Luca8991 Date: Thu, 31 Aug 2023 22:03:38 +0200 Subject: [PATCH 05/31] fix: get nonce for gateway before on_open callback --- src/ic-websocket-cdk/src/lib.rs | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/src/ic-websocket-cdk/src/lib.rs b/src/ic-websocket-cdk/src/lib.rs index e8a1991..4a9c5fa 100644 --- a/src/ic-websocket-cdk/src/lib.rs +++ b/src/ic-websocket-cdk/src/lib.rs @@ -665,6 +665,16 @@ pub fn ws_open(args: CanisterWsOpenArguments) -> CanisterWsOpenResult { }; add_client(client_principal.clone(), new_client); + // returns the current nonce so that in case the WS Gateway has to open a new poller for this canister + // it knows which nonce to start polling from. This is needed in order to make sure that the WS Gateway + // does not poll messages it has already relayed when a new it starts polling a canister + // (which it might have already polled previously with another thread that was closed after the last client disconnected) + // + // it's important to get the message nonce BEFORE calling the on_open callback, + // otherwise if the developer calls the ws_send from the on_open callback, the nonce would be incremented + // and the WS Gateway would start polling from the next nonce, skipping the previous messages + let nonce = get_outgoing_message_nonce(); + // call the on_open handler initialized in init() HANDLERS.with(|h| { h.borrow().call_on_open(OnOpenCallbackArgs { @@ -674,11 +684,7 @@ pub fn ws_open(args: CanisterWsOpenArguments) -> CanisterWsOpenResult { Ok(CanisterWsOpenResultValue { client_principal, - // returns the current nonce so that in case the WS Gateway has to open a new poller for this canister - // it knows which nonce to start polling from. This is needed in order to make sure that the WS Gateway - // does not poll messages it has already relayed when a new it starts polling a canister - // (which it might have already polled previously with another thread that was closed after the last client disconnected) - nonce: get_outgoing_message_nonce(), + nonce, }) } From f1a4b232febfd25f593a17bf2f2934697fc58702 Mon Sep 17 00:00:00 2001 From: Luca8991 Date: Tue, 5 Sep 2023 15:53:26 +0200 Subject: [PATCH 06/31] fix: first sequence number in tests --- src/ic-websocket-cdk/src/lib.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/ic-websocket-cdk/src/lib.rs b/src/ic-websocket-cdk/src/lib.rs index 4a9c5fa..003917e 100644 --- a/src/ic-websocket-cdk/src/lib.rs +++ b/src/ic-websocket-cdk/src/lib.rs @@ -1191,7 +1191,7 @@ mod test { init_expected_incoming_message_from_client_num(test_client_principal.clone()); let actual_result = INCOMING_MESSAGE_FROM_CLIENT_NUM_MAP.with(|map| map.borrow().get(&test_client_principal).unwrap().clone()); - prop_assert_eq!(actual_result, 0); + prop_assert_eq!(actual_result, 1); } #[test] @@ -1231,7 +1231,7 @@ mod test { prop_assert_eq!(actual_result, registered_client); let actual_result = INCOMING_MESSAGE_FROM_CLIENT_NUM_MAP.with(|map| map.borrow().get(&test_client_principal).unwrap().clone()); - prop_assert_eq!(actual_result, 0); + prop_assert_eq!(actual_result, 1); let actual_result = OUTGOING_MESSAGE_TO_CLIENT_NUM_MAP.with(|map| map.borrow().get(&test_client_principal).unwrap().clone()); prop_assert_eq!(actual_result, 0); From f12fc0a2c87245abbc8428b8bf66ff7f72432dd4 Mon Sep 17 00:00:00 2001 From: Luca8991 Date: Tue, 5 Sep 2023 22:32:56 +0200 Subject: [PATCH 07/31] feat: init params, send ack to clients --- .github/workflows/tests.yml | 2 +- scripts/test_canister.sh | 2 +- src/ic-websocket-cdk/src/lib.rs | 150 ++++++++++++++++++++------------ tests/src/lib.rs | 11 ++- 4 files changed, 107 insertions(+), 58 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index d149afd..977bc1e 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -34,7 +34,7 @@ jobs: cd tests rustup target add wasm32-unknown-unknown dfx start --clean --background - IC_WS_CDK_INTEGRATION_TEST=1 npm run deploy:tests + npm run deploy:tests npm run generate - name: Run integration tests run: cd tests && npm run test:integration diff --git a/scripts/test_canister.sh b/scripts/test_canister.sh index cdf61f9..b550bb0 100755 --- a/scripts/test_canister.sh +++ b/scripts/test_canister.sh @@ -12,7 +12,7 @@ npm install # integration tests dfx start --clean --background -IC_WS_CDK_INTEGRATION_TEST=1 npm run deploy:tests +npm run deploy:tests npm run generate diff --git a/src/ic-websocket-cdk/src/lib.rs b/src/ic-websocket-cdk/src/lib.rs index 003917e..11d0624 100644 --- a/src/ic-websocket-cdk/src/lib.rs +++ b/src/ic-websocket-cdk/src/lib.rs @@ -1,4 +1,4 @@ -use candid::{CandidType, Principal}; +use candid::{CandidType, Encode, Principal}; #[cfg(not(test))] use ic_cdk::api::time; use ic_cdk::api::{caller, data_certificate, set_certified_data}; @@ -16,10 +16,10 @@ mod logger; const LABEL_WEBSOCKET: &[u8] = b"websocket"; /// The maximum number of messages returned by [ws_get_messages] at each poll. const MAX_NUMBER_OF_RETURNED_MESSAGES: usize = 10; -/// The delay between two consecutive checks if the registered gateway is still alive. -const CHECK_REGISTERED_GATEWAY_DELAY_NS: u64 = 60_000_000_000; // 60 seconds -/// (**Used for integration tests**) The delay between two consecutive checks if the registered gateway is still alive. -const CHECK_REGISTERED_GATEWAY_DELAY_NS_TEST: u64 = 15_000_000_000; // 15 seconds +/// The default delay between two consecutive checks if the registered gateway is still alive. +const DEFAULT_CHECK_REGISTERED_GATEWAY_DELAY_MS: u64 = 60_000; // 60 seconds +/// The default delay between two consecutive acknowledgements sent to the client. +const DEFAULT_SEND_ACK_DELAY_MS: u64 = 60_000; // 60 seconds pub type ClientPrincipal = Principal; @@ -169,24 +169,6 @@ fn get_current_time() -> u64 { } } -/// Returns true if the canister is running in an integration test. -/// -/// To run the canister in an integration test, -/// the `IC_WS_CDK_INTEGRATION_TEST` environment variable must be set to `1` at build time. -fn is_integration_test() -> bool { - let integration_test = option_env!("IC_WS_CDK_INTEGRATION_TEST"); - integration_test.is_some() && integration_test.unwrap() == "1" -} - -/// Returns the delay in nanoseconds between two consecutive checks if the registered gateway is still alive. -fn get_check_registered_gateway_delay_ns() -> u64 { - if is_integration_test() { - CHECK_REGISTERED_GATEWAY_DELAY_NS_TEST - } else { - CHECK_REGISTERED_GATEWAY_DELAY_NS - } -} - #[derive(Clone, Debug, Eq, PartialEq)] pub struct RegisteredClient { is_anonymous: bool, @@ -522,14 +504,11 @@ fn get_cert_for_range(first: &String, last: &String) -> (Vec, Vec) { /// Schedules a timer to check if the registered gateway has sent a heartbeat recently. /// -/// The timer delay is given by the [get_check_registered_gateway_delay_ms] function. -/// /// The timer callback is [check_registered_gateway_timer_callback]. -fn schedule_registered_gateway_check() { - set_timer( - Duration::from_nanos(get_check_registered_gateway_delay_ns()), - check_registered_gateway_timer_callback, - ); +fn schedule_registered_gateway_check(interval_ms: u64) { + set_timer(Duration::from_millis(interval_ms), move || { + check_registered_gateway_timer_callback(interval_ms) + }); } /// Checks if the registered gateway has sent a heartbeat recently. @@ -537,13 +516,14 @@ fn schedule_registered_gateway_check() { /// In this case, the internal IC WebSocket CDK state is reset. /// /// At the end, a new timer is scheduled to check again if the registered gateway has sent a heartbeat recently. -fn check_registered_gateway_timer_callback() { +fn check_registered_gateway_timer_callback(interval_ms: u64) { + let interval_ns = interval_ms * 1_000_000; REGISTERED_GATEWAY.with(|state| { let mut registered_gateway = state.borrow_mut(); if let Some(v) = registered_gateway.as_mut() { if let Some(last_heartbeat) = v.last_heartbeat { - if get_current_time() - last_heartbeat > get_check_registered_gateway_delay_ns() { - custom_print!("[timer-cb]: Registered gateway has not sent a heartbeat for more than {} seconds, resetting all internal state", get_check_registered_gateway_delay_ns() / 1_000_000_000); + if get_current_time() - last_heartbeat > interval_ns { + custom_print!("[timer-cb]: Registered gateway has not sent a heartbeat for more than {} seconds, resetting all internal state", interval_ns / 1_000_000_000); reset_internal_state(); @@ -559,7 +539,53 @@ fn check_registered_gateway_timer_callback() { } }); - schedule_registered_gateway_check(); + schedule_registered_gateway_check(interval_ms); +} + +#[derive(CandidType)] +struct AckMessageForClientContent { + last_incoming_message_num: u64, +} + +/// Schedules a timer to send an acknowledgement message to the client. +/// +/// The timer callback is [send_ack_timer_callback]. +fn schedule_send_ack_to_clients(interval_ms: u64) { + set_timer(Duration::from_millis(interval_ms), move || { + send_ack_to_clients_timer_callback(interval_ms) + }); +} + +/// Sends an acknowledgement message to the client. +/// The message contains the current incoming message sequence number for that client, +/// so that the client knows that all the messages it sent have been received by the canister. +/// +/// At the end, a new timer is scheduled to send another acknowledgement message to the client. +fn send_ack_to_clients_timer_callback(interval_ms: u64) { + REGISTERED_CLIENTS.with(|state| { + let map = state.borrow(); + for (client_principal, _) in map.iter() { + let last_incoming_message_num = + get_expected_incoming_message_from_client_num(client_principal).unwrap(); + let ack_message = AckMessageForClientContent { + last_incoming_message_num, + }; + let message_bytes = Encode!(&ack_message).unwrap(); + if let Err(e) = ws_send(*client_principal, message_bytes) { + custom_print!( + "[timer-cb]: Error sending ack message to client {:?}: {:?}", + client_principal, + e + ); + + break; + }; + } + + custom_print!("[timer-cb]: Sent ack messages to all clients"); + }); + + schedule_send_ack_to_clients(interval_ms); } /// Arguments passed to the `on_open` handler. @@ -590,7 +616,7 @@ pub struct OnCloseCallbackArgs { type OnCloseCallback = fn(OnCloseCallbackArgs); /// Handlers initialized by the canister and triggered by the CDK. -#[derive(Clone)] +#[derive(Clone, Default)] pub struct WsHandlers { pub on_open: Option, pub on_message: Option, @@ -624,20 +650,51 @@ fn initialize_handlers(handlers: WsHandlers) { }); } +/// Parameters for the IC WebSocket CDK initialization. For default parameters and simpler initialization, use [`WsInitParams::new`]. +#[derive(Clone)] +pub struct WsInitParams { + /// The callback handlers for the WebSocket. + pub handlers: WsHandlers, + /// The principal of the WS Gateway that will be polling the canister. + pub gateway_principal: String, + /// The interval at which to check if the registered gateway is still alive (in milliseconds). + /// Defaults to `60_000` (60 seconds). + pub check_registered_gateway_interval_ms: u64, + /// The interval at which to send an acknowledgement message to the client, + /// so that the client knows that all the messages it sent have been received by the canister (in milliseconds). + /// Defaults to `60_000` (60 seconds). + pub send_ack_interval_ms: u64, +} + +impl WsInitParams { + /// Creates a new instance of WsInitParams, with default interval values. + pub fn new(handlers: WsHandlers, gateway_principal: String) -> Self { + Self { + handlers, + gateway_principal, + check_registered_gateway_interval_ms: DEFAULT_CHECK_REGISTERED_GATEWAY_DELAY_MS, + send_ack_interval_ms: DEFAULT_SEND_ACK_DELAY_MS, + } + } +} + /// Initialize the CDK by setting the callback handlers and the **principal** of the WS Gateway that /// will be polling the canister. /// /// Under the hood, an interval (**60 seconds**) is started using [ic_cdk_timers::set_timer] /// to check if the WS Gateway is still alive. -pub fn init(handlers: WsHandlers, gateway_principal: &str) { +pub fn init(params: WsInitParams) { // set the handlers specified by the canister that the CDK uses to manage the IC WebSocket connection - initialize_handlers(handlers); + initialize_handlers(params.handlers); // set the principal of the (only) WS Gateway that will be polling the canister - initialize_registered_gateway(gateway_principal); + initialize_registered_gateway(¶ms.gateway_principal); // schedule a timer that will check if the registered gateway is still alive - schedule_registered_gateway_check(); + schedule_registered_gateway_check(params.check_registered_gateway_interval_ms); + + // schedule a timer that will send an acknowledgement message to clients + schedule_send_ack_to_clients(params.send_ack_interval_ms); } /// Handles the WS connection open event sent by the client and relayed by the Gateway. @@ -999,27 +1056,12 @@ mod test { assert!(CUSTOM_STATE.with(|h| h.borrow().is_on_close_called)); } - #[test] - fn test_is_integration_test() { - // test - assert_eq!(is_integration_test(), false); - } - #[test] fn test_current_time() { // test assert_eq!(get_current_time(), 0u64); } - #[test] - fn test_get_check_registered_gateway_delay() { - // test - assert_eq!( - get_check_registered_gateway_delay_ns(), - CHECK_REGISTERED_GATEWAY_DELAY_NS - ); - } - proptest! { #[test] fn test_initialize_registered_gateway(test_gateway_principal in any::().prop_map(|_| test_utils::generate_random_principal())) { diff --git a/tests/src/lib.rs b/tests/src/lib.rs index f0989ee..f1b79e6 100644 --- a/tests/src/lib.rs +++ b/tests/src/lib.rs @@ -5,7 +5,7 @@ use ic_websocket_cdk::{ CanisterWsCloseArguments, CanisterWsCloseResult, CanisterWsGetMessagesArguments, CanisterWsGetMessagesResult, CanisterWsMessageArguments, CanisterWsMessageResult, CanisterWsOpenArguments, CanisterWsOpenResult, CanisterWsSendResult, CanisterWsStatusArguments, - CanisterWsStatusResult, ClientPrincipal, WsHandlers, + CanisterWsStatusResult, ClientPrincipal, WsHandlers, WsInitParams, }; mod canister; @@ -18,7 +18,14 @@ fn init(gateway_principal: String) { on_close: Some(on_close), }; - ic_websocket_cdk::init(handlers, &gateway_principal) + let params = WsInitParams { + handlers, + gateway_principal, + check_registered_gateway_interval_ms: 15_000, + send_ack_interval_ms: 10_000, + }; + + ic_websocket_cdk::init(params) } #[post_upgrade] From 2aabd2b0e2efb8c4e39117fa18efb9fdbabaff84 Mon Sep 17 00:00:00 2001 From: Luca8991 Date: Sun, 10 Sep 2023 19:55:29 +0200 Subject: [PATCH 08/31] wip: service messages, ws_open returns empty --- src/ic-websocket-cdk/src/lib.rs | 190 +++++++++++++++++------------- src/ic-websocket-cdk/ws_types.did | 9 +- 2 files changed, 109 insertions(+), 90 deletions(-) diff --git a/src/ic-websocket-cdk/src/lib.rs b/src/ic-websocket-cdk/src/lib.rs index 11d0624..6671cbf 100644 --- a/src/ic-websocket-cdk/src/lib.rs +++ b/src/ic-websocket-cdk/src/lib.rs @@ -24,7 +24,7 @@ const DEFAULT_SEND_ACK_DELAY_MS: u64 = 60_000; // 60 seconds pub type ClientPrincipal = Principal; /// The result of [ws_open]. -pub type CanisterWsOpenResult = Result; +pub type CanisterWsOpenResult = Result<(), String>; /// The result of [ws_close]. pub type CanisterWsCloseResult = Result<(), String>; /// The result of [ws_message]. @@ -36,13 +36,6 @@ pub type CanisterWsGetMessagesResult = Result; -/// The Ok value of CanisterWsOpenResult returned by [ws_open]. -#[derive(CandidType, Clone, Deserialize, Serialize, Eq, PartialEq, Debug)] -pub struct CanisterWsOpenResultValue { - client_principal: ClientPrincipal, - nonce: u64, -} - /// The arguments for [ws_open]. #[derive(CandidType, Clone, Deserialize, Serialize, Eq, PartialEq, Debug)] pub struct CanisterWsOpenArguments { @@ -79,6 +72,7 @@ struct WebsocketMessage { client_principal: ClientPrincipal, // The client that the gateway will forward the message to or that sent the message. sequence_num: u64, // Both ways, messages should arrive with sequence numbers 0, 1, 2... timestamp: u64, // Timestamp of when the message was made for the recipient to inspect. + is_service_message: bool, // Whether the message is a service message sent by the CDK to the client or vice versa. #[serde(with = "serde_bytes")] content: Vec, // Application message encoded in binary. } @@ -523,19 +517,19 @@ fn check_registered_gateway_timer_callback(interval_ms: u64) { if let Some(v) = registered_gateway.as_mut() { if let Some(last_heartbeat) = v.last_heartbeat { if get_current_time() - last_heartbeat > interval_ns { - custom_print!("[timer-cb]: Registered gateway has not sent a heartbeat for more than {} seconds, resetting all internal state", interval_ns / 1_000_000_000); + custom_print!("[registered-gateway-timer-cb]: Registered gateway has not sent a heartbeat for more than {} seconds, resetting all internal state", interval_ns / 1_000_000_000); reset_internal_state(); v.reset(); } else { - custom_print!("[timer-cb]: Registered gateway is still alive"); + custom_print!("[registered-gateway-timer-cb]: Registered gateway is still alive"); } } else { - custom_print!("[timer-cb]: Registered gateway has not sent a heartbeat yet"); + custom_print!("[registered-gateway-timer-cb]: Registered gateway has not sent a heartbeat yet"); } } else { - custom_print!("[timer-cb]: No registered gateway"); + custom_print!("[registered-gateway-timer-cb]: No registered gateway"); } }); @@ -543,13 +537,33 @@ fn check_registered_gateway_timer_callback(interval_ms: u64) { } #[derive(CandidType)] -struct AckMessageForClientContent { - last_incoming_message_num: u64, +struct CanisterOpenMessageContent { + client_principal: ClientPrincipal, +} + +#[derive(CandidType)] +struct CanisterAckMessageContent { + last_incoming_sequence_num: u64, +} + +/// A service message sent by the CDK to the client. +#[derive(CandidType)] +enum CanisterServiceMessage { + OpenMessage(CanisterOpenMessageContent), + AckMessage(CanisterAckMessageContent), +} + +fn send_service_message_to_client( + client_principal: ClientPrincipal, + message: CanisterServiceMessage, +) -> Result<(), String> { + let message_bytes = Encode!(&message).unwrap(); + _ws_send(client_principal, message_bytes, true) } /// Schedules a timer to send an acknowledgement message to the client. /// -/// The timer callback is [send_ack_timer_callback]. +/// The timer callback is [send_ack_to_clients_timer_callback]. fn schedule_send_ack_to_clients(interval_ms: u64) { set_timer(Duration::from_millis(interval_ms), move || { send_ack_to_clients_timer_callback(interval_ms) @@ -565,29 +579,81 @@ fn send_ack_to_clients_timer_callback(interval_ms: u64) { REGISTERED_CLIENTS.with(|state| { let map = state.borrow(); for (client_principal, _) in map.iter() { - let last_incoming_message_num = - get_expected_incoming_message_from_client_num(client_principal).unwrap(); - let ack_message = AckMessageForClientContent { - last_incoming_message_num, - }; - let message_bytes = Encode!(&ack_message).unwrap(); - if let Err(e) = ws_send(*client_principal, message_bytes) { - custom_print!( - "[timer-cb]: Error sending ack message to client {:?}: {:?}", - client_principal, - e - ); - - break; - }; + // ignore the error, which shouldn't happen since the client is registered and the sequence number is initialized + if let Ok(last_incoming_message_sequence_num) = + get_expected_incoming_message_from_client_num(client_principal) + { + let ack_message = CanisterAckMessageContent { + last_incoming_sequence_num: last_incoming_message_sequence_num, + }; + let message = CanisterServiceMessage::AckMessage(ack_message); + if let Err(e) = send_service_message_to_client(*client_principal, message) { + custom_print!( + "[ack-to-clients-timer-cb]: Error sending ack message to client {:?}: {:?}", + client_principal, + e + ); + + break; + }; + } } - custom_print!("[timer-cb]: Sent ack messages to all clients"); + custom_print!("[ack-to-clients-timer-cb]: Sent ack messages to all clients"); }); schedule_send_ack_to_clients(interval_ms); } +/// Internal function used to put the messages in the outgoing messages queue and certify them. +fn _ws_send( + client_principal: ClientPrincipal, + msg_bytes: Vec, + is_service_message: bool, +) -> CanisterWsSendResult { + // check if the client is registered + check_registered_client(&client_principal)?; + + // get the principal of the gateway that is polling the canister + let gateway_principal = get_registered_gateway_principal(); + + // the nonce in key is used by the WS Gateway to determine the message to start in the polling iteration + // the key is also passed to the client in order to validate the body of the certified message + let outgoing_message_nonce = get_outgoing_message_nonce(); + let key = get_message_for_gateway_key(gateway_principal, outgoing_message_nonce); + + // increment the nonce for the next message + increment_outgoing_message_nonce(); + // increment the sequence number for the next message to the client + increment_outgoing_message_to_client_num(&client_principal)?; + + let websocket_message = WebsocketMessage { + client_principal: client_principal.clone(), + sequence_num: get_outgoing_message_to_client_num(&client_principal)?, + timestamp: get_current_time(), + is_service_message, + content: msg_bytes, + }; + + // CBOR serialize message of type WebsocketMessage + let content = websocket_message.cbor_serialize()?; + + // certify data + put_cert_for_message(key.clone(), &content); + + MESSAGES_FOR_GATEWAY.with(|m| { + // messages in the queue are inserted with contiguous and increasing nonces + // (from beginning to end of the queue) as ws_send is called sequentially, the nonce + // is incremented by one in each call, and the message is pushed at the end of the queue + m.borrow_mut().push_back(CanisterOutputMessage { + client_principal, + content, + key, + }); + }); + Ok(()) +} + /// Arguments passed to the `on_open` handler. pub struct OnOpenCallbackArgs { pub client_principal: ClientPrincipal, @@ -722,15 +788,11 @@ pub fn ws_open(args: CanisterWsOpenArguments) -> CanisterWsOpenResult { }; add_client(client_principal.clone(), new_client); - // returns the current nonce so that in case the WS Gateway has to open a new poller for this canister - // it knows which nonce to start polling from. This is needed in order to make sure that the WS Gateway - // does not poll messages it has already relayed when a new it starts polling a canister - // (which it might have already polled previously with another thread that was closed after the last client disconnected) - // - // it's important to get the message nonce BEFORE calling the on_open callback, - // otherwise if the developer calls the ws_send from the on_open callback, the nonce would be incremented - // and the WS Gateway would start polling from the next nonce, skipping the previous messages - let nonce = get_outgoing_message_nonce(); + let open_message = CanisterOpenMessageContent { + client_principal: client_principal.clone(), + }; + let message = CanisterServiceMessage::OpenMessage(open_message); + send_service_message_to_client(client_principal.clone(), message)?; // call the on_open handler initialized in init() HANDLERS.with(|h| { @@ -739,10 +801,7 @@ pub fn ws_open(args: CanisterWsOpenArguments) -> CanisterWsOpenResult { }); }); - Ok(CanisterWsOpenResultValue { - client_principal, - nonce, - }) + Ok(()) } /// Handles the WS connection close event received from the WS Gateway. @@ -783,6 +842,7 @@ pub fn ws_message(args: CanisterWsMessageArguments) -> CanisterWsMessageResult { client_principal: _, sequence_num, timestamp: _, + is_service_message: _, // TODO: handle service messages content, } = args.msg; @@ -836,46 +896,7 @@ pub fn ws_get_messages(args: CanisterWsGetMessagesArguments) -> CanisterWsGetMes /// Under the hood, the message is serialized and certified, and then it is added to the queue of messages /// that the WS Gateway will poll in the next iteration. pub fn ws_send(client_principal: ClientPrincipal, msg_bytes: Vec) -> CanisterWsSendResult { - // check if the client is registered - check_registered_client(&client_principal)?; - - // get the principal of the gateway that is polling the canister - let gateway_principal = get_registered_gateway_principal(); - - // the nonce in key is used by the WS Gateway to determine the message to start in the polling iteration - // the key is also passed to the client in order to validate the body of the certified message - let outgoing_message_nonce = get_outgoing_message_nonce(); - let key = get_message_for_gateway_key(gateway_principal, outgoing_message_nonce); - - // increment the nonce for the next message - increment_outgoing_message_nonce(); - // increment the sequence number for the next message to the client - increment_outgoing_message_to_client_num(&client_principal)?; - - let websocket_message = WebsocketMessage { - client_principal: client_principal.clone(), - sequence_num: get_outgoing_message_to_client_num(&client_principal)?, - timestamp: get_current_time(), - content: msg_bytes, - }; - - // CBOR serialize message of type WebsocketMessage - let content = websocket_message.cbor_serialize()?; - - // certify data - put_cert_for_message(key.clone(), &content); - - MESSAGES_FOR_GATEWAY.with(|m| { - // messages in the queue are inserted with contiguous and increasing nonces - // (from beginning to end of the queue) as ws_send is called sequentially, the nonce - // is incremented by one in each call, and the message is pushed at the end of the queue - m.borrow_mut().push_back(CanisterOutputMessage { - client_principal, - content, - key, - }); - }); - Ok(()) + _ws_send(client_principal, msg_bytes, false) } #[cfg(test)] @@ -1418,6 +1439,7 @@ mod test { client_principal: test_utils::generate_random_principal(), sequence_num: test_sequence_num, timestamp: test_timestamp, + is_service_message: false, content: test_msg_bytes, }; diff --git a/src/ic-websocket-cdk/ws_types.did b/src/ic-websocket-cdk/ws_types.did index 3942d61..ff64e73 100644 --- a/src/ic-websocket-cdk/ws_types.did +++ b/src/ic-websocket-cdk/ws_types.did @@ -4,6 +4,7 @@ type WebsocketMessage = record { client_principal : ClientPrincipal; sequence_num : nat64; timestamp : nat64; + is_service_message : bool; content : blob; }; @@ -21,15 +22,11 @@ type CanisterOutputCertifiedMessages = record { type CanisterWsOpenArguments = record { is_anonymous : bool; -}; - -type CanisterWsOpenResultValue = record { - client_principal : ClientPrincipal; - nonce : nat64; + gateway_principal : principal; }; type CanisterWsOpenResult = variant { - Ok : CanisterWsOpenResultValue; + Ok : null; Err : text; }; From 96eb1b35c808a69ff5a60358405dab0e0e8bb7e8 Mon Sep 17 00:00:00 2001 From: Luca8991 Date: Sun, 10 Sep 2023 22:07:04 +0200 Subject: [PATCH 09/31] chore: check anonymous principal comment --- src/ic-websocket-cdk/src/lib.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/ic-websocket-cdk/src/lib.rs b/src/ic-websocket-cdk/src/lib.rs index 6671cbf..3d1b6d0 100644 --- a/src/ic-websocket-cdk/src/lib.rs +++ b/src/ic-websocket-cdk/src/lib.rs @@ -767,6 +767,8 @@ pub fn init(params: WsInitParams) { pub fn ws_open(args: CanisterWsOpenArguments) -> CanisterWsOpenResult { let client_principal = caller(); + // TODO: check if the principal is not the anonymous principal + // check if client is not registered yet if is_client_registered(&client_principal) { return Err(format!( From d2976436eb49b4488c64b9a9a4badf14c05da64d Mon Sep 17 00:00:00 2001 From: Luca8991 Date: Mon, 18 Sep 2023 09:54:30 +0200 Subject: [PATCH 10/31] feat: return last messages if nonce is 0 --- src/ic-websocket-cdk/src/lib.rs | 43 ++++++++++++++++++++++++++++++--- 1 file changed, 40 insertions(+), 3 deletions(-) diff --git a/src/ic-websocket-cdk/src/lib.rs b/src/ic-websocket-cdk/src/lib.rs index 3d1b6d0..fcb33e0 100644 --- a/src/ic-websocket-cdk/src/lib.rs +++ b/src/ic-websocket-cdk/src/lib.rs @@ -408,12 +408,26 @@ fn get_message_for_gateway_key(gateway_principal: Principal, nonce: u64) -> Stri fn get_messages_for_gateway_range(gateway_principal: Principal, nonce: u64) -> (usize, usize) { MESSAGES_FOR_GATEWAY.with(|m| { + let queue_len = m.borrow().len(); + + if nonce == 0 && queue_len > 0 { + // this is the case in which the poller on the gateway restarted + // the range to return is end:last index and start: max(end - MAX_NUMBER_OF_RETURNED_MESSAGES, 0) + let start_index = if queue_len > MAX_NUMBER_OF_RETURNED_MESSAGES { + queue_len - MAX_NUMBER_OF_RETURNED_MESSAGES + } else { + 0 + }; + + return (start_index, queue_len); + } + // smallest key used to determine the first message from the queue which has to be returned to the WS Gateway let smallest_key = get_message_for_gateway_key(gateway_principal, nonce); // partition the queue at the message which has the key with the nonce specified as argument to get_cert_messages let start_index = m.borrow().partition_point(|x| x.key < smallest_key); // message at index corresponding to end index is excluded - let mut end_index = m.borrow().len(); + let mut end_index = queue_len; if end_index - start_index > MAX_NUMBER_OF_RETURNED_MESSAGES { end_index = start_index + MAX_NUMBER_OF_RETURNED_MESSAGES; } @@ -1380,8 +1394,9 @@ mod test { test_utils::add_messages_for_gateway(test_client_principal.clone(), gateway_principal, messages_count); // Test - // messages are now MAX_NUMBER_OF_RETURNED_MESSAGES - for i in 0..messages_count + 1 { + // messages are now 2 * MAX_NUMBER_OF_RETURNED_MESSAGES + // the case in which the start index is 0 is tested in test_get_messages_for_gateway_range_initial_nonce + for i in 1..messages_count + 1 { let (start_index, end_index) = get_messages_for_gateway_range(gateway_principal, i); let expected_end_index = if (i as usize) + MAX_NUMBER_OF_RETURNED_MESSAGES > messages_count as usize { messages_count as usize @@ -1396,6 +1411,28 @@ mod test { test_utils::clean_messages_for_gateway(); } + #[test] + fn test_get_messages_for_gateway_initial_nonce(gateway_principal in any::().prop_map(|_| test_utils::get_static_principal()), messages_count in any::().prop_map(|c| c % 100)) { + // Set up + REGISTERED_GATEWAY.with(|p| *p.borrow_mut() = Some(RegisteredGateway::new(gateway_principal.clone()))); + + let test_client_principal = test_utils::generate_random_principal(); + test_utils::add_messages_for_gateway(test_client_principal.clone(), gateway_principal, messages_count); + + // Test + let (start_index, end_index) = get_messages_for_gateway_range(gateway_principal, 0); + let expected_start_index = if (messages_count as usize) > MAX_NUMBER_OF_RETURNED_MESSAGES { + (messages_count as usize) - MAX_NUMBER_OF_RETURNED_MESSAGES + } else { + 0 + }; + prop_assert_eq!(start_index, expected_start_index); + prop_assert_eq!(end_index, messages_count as usize); + + // Clean up + test_utils::clean_messages_for_gateway(); + } + #[test] fn test_get_messages_for_gateway(gateway_principal in any::().prop_map(|_| test_utils::get_static_principal()), messages_count in any::().prop_map(|c| c % 100)) { // Set up From 183a13d44bb3f14d1806575f4ea71849090b1447 Mon Sep 17 00:00:00 2001 From: Luca8991 Date: Mon, 18 Sep 2023 10:02:50 +0200 Subject: [PATCH 11/31] feat: remove ws_status and registered gw checks --- src/ic-websocket-cdk/service.example.did | 1 - src/ic-websocket-cdk/src/lib.rs | 177 +---------------------- src/ic-websocket-cdk/ws_types.did | 9 -- tests/src/lib.rs | 10 +- 4 files changed, 3 insertions(+), 194 deletions(-) diff --git a/src/ic-websocket-cdk/service.example.did b/src/ic-websocket-cdk/service.example.did index 39d0edb..1d2f23d 100644 --- a/src/ic-websocket-cdk/service.example.did +++ b/src/ic-websocket-cdk/service.example.did @@ -4,6 +4,5 @@ service : { "ws_open" : (CanisterWsOpenArguments) -> (CanisterWsOpenResult); "ws_close" : (CanisterWsCloseArguments) -> (CanisterWsCloseResult); "ws_message" : (CanisterWsMessageArguments) -> (CanisterWsMessageResult); - "ws_status" : (CanisterWsStatusArguments) -> (CanisterWsStatusResult); "ws_get_messages" : (CanisterWsGetMessagesArguments) -> (CanisterWsGetMessagesResult) query; }; diff --git a/src/ic-websocket-cdk/src/lib.rs b/src/ic-websocket-cdk/src/lib.rs index fcb33e0..a78c7d9 100644 --- a/src/ic-websocket-cdk/src/lib.rs +++ b/src/ic-websocket-cdk/src/lib.rs @@ -16,8 +16,6 @@ mod logger; const LABEL_WEBSOCKET: &[u8] = b"websocket"; /// The maximum number of messages returned by [ws_get_messages] at each poll. const MAX_NUMBER_OF_RETURNED_MESSAGES: usize = 10; -/// The default delay between two consecutive checks if the registered gateway is still alive. -const DEFAULT_CHECK_REGISTERED_GATEWAY_DELAY_MS: u64 = 60_000; // 60 seconds /// The default delay between two consecutive acknowledgements sent to the client. const DEFAULT_SEND_ACK_DELAY_MS: u64 = 60_000; // 60 seconds @@ -29,8 +27,6 @@ pub type CanisterWsOpenResult = Result<(), String>; pub type CanisterWsCloseResult = Result<(), String>; /// The result of [ws_message]. pub type CanisterWsMessageResult = Result<(), String>; -/// The result of [ws_status]. -pub type CanisterWsStatusResult = Result<(), String>; /// The result of [ws_get_messages]. pub type CanisterWsGetMessagesResult = Result; /// The result of [ws_send]. @@ -54,12 +50,6 @@ pub struct CanisterWsMessageArguments { msg: WebsocketMessage, } -/// The arguments for [ws_status]. -#[derive(CandidType, Clone, Deserialize, Serialize, Eq, PartialEq, Debug)] -pub struct CanisterWsStatusArguments { - status_index: u64, -} - /// The arguments for [ws_get_messages]. #[derive(CandidType, Clone, Deserialize, Serialize, Eq, PartialEq, Debug)] pub struct CanisterWsGetMessagesArguments { @@ -112,43 +102,12 @@ pub struct CanisterOutputCertifiedMessages { struct RegisteredGateway { /// The principal of the gateway. gateway_principal: Principal, - /// The last time the gateway sent a heartbeat message. - last_heartbeat: Option, - /// The last status index received from the gateway. - last_status_index: u64, } impl RegisteredGateway { /// Creates a new instance of RegisteredGateway. fn new(gateway_principal: Principal) -> Self { - Self { - gateway_principal, - last_heartbeat: None, - last_status_index: 0, - } - } - - /// Updates the registered gateway's status index with the given one. - /// Sets the last heartbeat to the current time. - fn update_status_index(&mut self, status_index: u64) -> Result<(), String> { - if status_index <= self.last_status_index { - if status_index == 0 { - custom_print!("Gateway status index set to 0"); - } else { - return Err("Gateway status index is equal to or behind the current one".to_owned()); - } - } - self.last_status_index = status_index; - self.last_heartbeat = Some(get_current_time()); - Ok(()) - } - - /// Resets the registered gateway to the initial state. - fn reset(&mut self) { - self.last_heartbeat = None; - self.last_status_index = 0; - - custom_print!("Gateway has been reset"); + Self { gateway_principal } } } @@ -236,14 +195,6 @@ fn reset_internal_state() { pub fn wipe() { reset_internal_state(); - // if there is a registered gateway, reset its state - REGISTERED_GATEWAY.with(|state| { - let mut registered_gateway = state.borrow_mut(); - if let Some(v) = registered_gateway.as_mut() { - v.reset(); - } - }); - // remove all clients from the map REGISTERED_CLIENTS.with(|map| { map.borrow_mut().clear(); @@ -302,30 +253,6 @@ fn get_registered_gateway_principal() -> Principal { }) } -/// Updates the registered gateway with the new status index. -/// If the status index is not greater than the current one, the function returns an error. -fn update_registered_gateway_status_index(status_index: u64) -> Result<(), String> { - REGISTERED_GATEWAY.with(|state| { - let mut registered_gateway = state.borrow_mut(); - - if let Some(v) = registered_gateway.as_mut() { - // if the current status index is > 0 and the new status index is 0, it means that the gateway has been restarted - // in this case, we reset the internal state because all clients are not connected to the gateway anymore - if v.last_status_index > 0 && status_index == 0 { - reset_internal_state(); - - v.reset(); - - Ok(()) - } else { - v.update_status_index(status_index) - } - } else { - Err("no gateway registered".to_owned()) - } - }) -} - fn init_outgoing_message_to_client_num(client_principal: ClientPrincipal) { OUTGOING_MESSAGE_TO_CLIENT_NUM_MAP.with(|map| { map.borrow_mut().insert(client_principal, 0); @@ -510,46 +437,6 @@ fn get_cert_for_range(first: &String, last: &String) -> (Vec, Vec) { }) } -/// Schedules a timer to check if the registered gateway has sent a heartbeat recently. -/// -/// The timer callback is [check_registered_gateway_timer_callback]. -fn schedule_registered_gateway_check(interval_ms: u64) { - set_timer(Duration::from_millis(interval_ms), move || { - check_registered_gateway_timer_callback(interval_ms) - }); -} - -/// Checks if the registered gateway has sent a heartbeat recently. -/// If not, this means that the gateway has been restarted and all clients registered have been disconnected. -/// In this case, the internal IC WebSocket CDK state is reset. -/// -/// At the end, a new timer is scheduled to check again if the registered gateway has sent a heartbeat recently. -fn check_registered_gateway_timer_callback(interval_ms: u64) { - let interval_ns = interval_ms * 1_000_000; - REGISTERED_GATEWAY.with(|state| { - let mut registered_gateway = state.borrow_mut(); - if let Some(v) = registered_gateway.as_mut() { - if let Some(last_heartbeat) = v.last_heartbeat { - if get_current_time() - last_heartbeat > interval_ns { - custom_print!("[registered-gateway-timer-cb]: Registered gateway has not sent a heartbeat for more than {} seconds, resetting all internal state", interval_ns / 1_000_000_000); - - reset_internal_state(); - - v.reset(); - } else { - custom_print!("[registered-gateway-timer-cb]: Registered gateway is still alive"); - } - } else { - custom_print!("[registered-gateway-timer-cb]: Registered gateway has not sent a heartbeat yet"); - } - } else { - custom_print!("[registered-gateway-timer-cb]: No registered gateway"); - } - }); - - schedule_registered_gateway_check(interval_ms); -} - #[derive(CandidType)] struct CanisterOpenMessageContent { client_principal: ClientPrincipal, @@ -737,9 +624,6 @@ pub struct WsInitParams { pub handlers: WsHandlers, /// The principal of the WS Gateway that will be polling the canister. pub gateway_principal: String, - /// The interval at which to check if the registered gateway is still alive (in milliseconds). - /// Defaults to `60_000` (60 seconds). - pub check_registered_gateway_interval_ms: u64, /// The interval at which to send an acknowledgement message to the client, /// so that the client knows that all the messages it sent have been received by the canister (in milliseconds). /// Defaults to `60_000` (60 seconds). @@ -752,7 +636,6 @@ impl WsInitParams { Self { handlers, gateway_principal, - check_registered_gateway_interval_ms: DEFAULT_CHECK_REGISTERED_GATEWAY_DELAY_MS, send_ack_interval_ms: DEFAULT_SEND_ACK_DELAY_MS, } } @@ -770,9 +653,6 @@ pub fn init(params: WsInitParams) { // set the principal of the (only) WS Gateway that will be polling the canister initialize_registered_gateway(¶ms.gateway_principal); - // schedule a timer that will check if the registered gateway is still alive - schedule_registered_gateway_check(params.check_registered_gateway_interval_ms); - // schedule a timer that will send an acknowledgement message to clients schedule_send_ack_to_clients(params.send_ack_interval_ms); } @@ -888,16 +768,6 @@ pub fn ws_message(args: CanisterWsMessageArguments) -> CanisterWsMessageResult { Ok(()) } -/// Used by the WS Gateway to update its status on the canister. -/// This way, the canister can check if the WS Gateway is still alive. -pub fn ws_status(args: CanisterWsStatusArguments) -> CanisterWsStatusResult { - // check if the caller of this method is the WS Gateway that has been set during the initialization of the SDK - let gateway_principal = caller(); - check_is_registered_gateway(gateway_principal)?; - - update_registered_gateway_status_index(args.status_index) -} - /// Returns messages to the WS Gateway in response of a polling iteration. pub fn ws_get_messages(args: CanisterWsGetMessagesArguments) -> CanisterWsGetMessagesResult { // check if the caller of this method is the WS Gateway that has been set during the initialization of the SDK @@ -1166,51 +1036,6 @@ mod test { prop_assert_eq!(actual_gateway_principal, test_gateway_principal); } - #[test] - fn test_update_registered_gateway_status_index(test_gateway_principal in any::().prop_map(|_| test_utils::generate_random_principal()), test_client_principal in any::().prop_map(|_| test_utils::generate_random_principal())) { - // Set up - REGISTERED_CLIENTS.with(|map| { - map.borrow_mut().insert(test_client_principal.clone(), test_utils::generate_random_registered_client()); - }); - REGISTERED_GATEWAY.with(|p| *p.borrow_mut() = Some(RegisteredGateway::new(test_gateway_principal.clone()))); - - // test with a valid status index - let _ = update_registered_gateway_status_index(2); - let actual_status_index = REGISTERED_GATEWAY.with(|p| p.borrow().as_ref().unwrap().last_status_index.clone()); - prop_assert_eq!(actual_status_index, 2); - - // test with an invalid status index (behind the current one) - let actual_result = update_registered_gateway_status_index(1); - let actual_status_index = REGISTERED_GATEWAY.with(|p| p.borrow().as_ref().unwrap().last_status_index.clone()); - prop_assert_eq!(actual_status_index, 2); - prop_assert_eq!(actual_result.err(), Some(String::from("Gateway status index is equal to or behind the current one"))); - - // test with an invalid status index (equal to the current one) - let actual_result = update_registered_gateway_status_index(2); - let actual_status_index = REGISTERED_GATEWAY.with(|p| p.borrow().as_ref().unwrap().last_status_index.clone()); - prop_assert_eq!(actual_status_index, 2); - prop_assert_eq!(actual_result.err(), Some(String::from("Gateway status index is equal to or behind the current one"))); - - // test with a valid status index (greater one) - let _ = update_registered_gateway_status_index(10); - let actual_status_index = REGISTERED_GATEWAY.with(|p| p.borrow().as_ref().unwrap().last_status_index.clone()); - prop_assert_eq!(actual_status_index, 10); - - // reset the registered gateway - let new_client_principal = test_utils::generate_random_principal(); - REGISTERED_CLIENTS.with(|map| { - map.borrow_mut().insert(new_client_principal.clone(), test_utils::generate_random_registered_client()); - }); - let _ = update_registered_gateway_status_index(0); - let actual_status_index = REGISTERED_GATEWAY.with(|p| p.borrow().as_ref().unwrap().last_status_index.clone()); - prop_assert_eq!(actual_status_index, 0); - let actual_result = REGISTERED_CLIENTS.with(|map| { - let map = map.borrow(); - map.get(&test_client_principal).is_none() && map.get(&new_client_principal).is_none() - }); - prop_assert!(actual_result); - } - #[test] fn test_check_registered_client_principal_empty(test_client_principal in any::().prop_map(|_| test_utils::generate_random_principal())) { let actual_result = check_registered_client(&test_client_principal); diff --git a/src/ic-websocket-cdk/ws_types.did b/src/ic-websocket-cdk/ws_types.did index ff64e73..be8ab22 100644 --- a/src/ic-websocket-cdk/ws_types.did +++ b/src/ic-websocket-cdk/ws_types.did @@ -48,15 +48,6 @@ type CanisterWsMessageResult = variant { Err : text; }; -type CanisterWsStatusArguments = record { - status_index : nat64; -}; - -type CanisterWsStatusResult = variant { - Ok : null; - Err : text; -}; - type CanisterWsGetMessagesArguments = record { nonce : nat64; }; diff --git a/tests/src/lib.rs b/tests/src/lib.rs index f1b79e6..0d5994f 100644 --- a/tests/src/lib.rs +++ b/tests/src/lib.rs @@ -4,8 +4,8 @@ use canister::{on_close, on_message, on_open}; use ic_websocket_cdk::{ CanisterWsCloseArguments, CanisterWsCloseResult, CanisterWsGetMessagesArguments, CanisterWsGetMessagesResult, CanisterWsMessageArguments, CanisterWsMessageResult, - CanisterWsOpenArguments, CanisterWsOpenResult, CanisterWsSendResult, CanisterWsStatusArguments, - CanisterWsStatusResult, ClientPrincipal, WsHandlers, WsInitParams, + CanisterWsOpenArguments, CanisterWsOpenResult, CanisterWsSendResult, ClientPrincipal, + WsHandlers, WsInitParams, }; mod canister; @@ -51,12 +51,6 @@ fn ws_message(args: CanisterWsMessageArguments) -> CanisterWsMessageResult { ic_websocket_cdk::ws_message(args) } -// method called by the WS Gateway to update its status in the canister -#[update] -fn ws_status(args: CanisterWsStatusArguments) -> CanisterWsStatusResult { - ic_websocket_cdk::ws_status(args) -} - // method called by the WS Gateway to get messages for all the clients it serves #[query] fn ws_get_messages(args: CanisterWsGetMessagesArguments) -> CanisterWsGetMessagesResult { From dd6c814d199bba956841824011d7b3fe098f2168 Mon Sep 17 00:00:00 2001 From: Luca8991 Date: Mon, 18 Sep 2023 10:19:46 +0200 Subject: [PATCH 12/31] fix: anonymous principal cannot open a connection --- src/ic-websocket-cdk/src/lib.rs | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/ic-websocket-cdk/src/lib.rs b/src/ic-websocket-cdk/src/lib.rs index a78c7d9..7698354 100644 --- a/src/ic-websocket-cdk/src/lib.rs +++ b/src/ic-websocket-cdk/src/lib.rs @@ -661,7 +661,10 @@ pub fn init(params: WsInitParams) { pub fn ws_open(args: CanisterWsOpenArguments) -> CanisterWsOpenResult { let client_principal = caller(); - // TODO: check if the principal is not the anonymous principal + // TODO: test + if client_principal == Principal::anonymous() { + return Err(String::from("anonymous principal cannot open a connection")); + } // check if client is not registered yet if is_client_registered(&client_principal) { From 7cdf92483db45e050aef4d592823f580034f7882 Mon Sep 17 00:00:00 2001 From: Luca8991 Date: Mon, 18 Sep 2023 10:21:21 +0200 Subject: [PATCH 13/31] chore: todo tests comments --- src/ic-websocket-cdk/src/lib.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/ic-websocket-cdk/src/lib.rs b/src/ic-websocket-cdk/src/lib.rs index 7698354..452e36e 100644 --- a/src/ic-websocket-cdk/src/lib.rs +++ b/src/ic-websocket-cdk/src/lib.rs @@ -337,6 +337,7 @@ fn get_messages_for_gateway_range(gateway_principal: Principal, nonce: u64) -> ( MESSAGES_FOR_GATEWAY.with(|m| { let queue_len = m.borrow().len(); + // TODO: test if nonce == 0 && queue_len > 0 { // this is the case in which the poller on the gateway restarted // the range to return is end:last index and start: max(end - MAX_NUMBER_OF_RETURNED_MESSAGES, 0) @@ -654,6 +655,7 @@ pub fn init(params: WsInitParams) { initialize_registered_gateway(¶ms.gateway_principal); // schedule a timer that will send an acknowledgement message to clients + // TODO: test schedule_send_ack_to_clients(params.send_ack_interval_ms); } From c2f016248f7bd054595dd17d11064e200a8f84ba Mon Sep 17 00:00:00 2001 From: Luca8991 Date: Mon, 18 Sep 2023 10:32:00 +0200 Subject: [PATCH 14/31] feat: remove is_anonymous flag --- src/ic-websocket-cdk/src/lib.rs | 53 ++++----------------------------- 1 file changed, 6 insertions(+), 47 deletions(-) diff --git a/src/ic-websocket-cdk/src/lib.rs b/src/ic-websocket-cdk/src/lib.rs index 452e36e..a441d4b 100644 --- a/src/ic-websocket-cdk/src/lib.rs +++ b/src/ic-websocket-cdk/src/lib.rs @@ -35,7 +35,7 @@ pub type CanisterWsSendResult = Result<(), String>; /// The arguments for [ws_open]. #[derive(CandidType, Clone, Deserialize, Serialize, Eq, PartialEq, Debug)] pub struct CanisterWsOpenArguments { - is_anonymous: bool, + // future versions may need more fields } /// The arguments for [ws_close]. @@ -122,15 +122,10 @@ fn get_current_time() -> u64 { } } +/// The data about a registered client. #[derive(Clone, Debug, Eq, PartialEq)] pub struct RegisteredClient { - is_anonymous: bool, -} - -impl RegisteredClient { - fn is_anonymous(&self) -> bool { - self.is_anonymous - } + // future versions may need more fields } thread_local! { @@ -218,10 +213,6 @@ fn insert_client(client_principal: ClientPrincipal, new_client: RegisteredClient }); } -fn get_registered_client(client_principal: &ClientPrincipal) -> Option { - REGISTERED_CLIENTS.with(|map| map.borrow().get(client_principal).cloned()) -} - fn is_client_registered(client_principal: &ClientPrincipal) -> bool { REGISTERED_CLIENTS.with(|map| map.borrow().contains_key(client_principal)) } @@ -567,7 +558,6 @@ type OnOpenCallback = fn(OnOpenCallbackArgs); /// Arguments passed to the `on_message` handler. pub struct OnMessageCallbackArgs { pub client_principal: ClientPrincipal, - pub is_anonymous: bool, pub message: Vec, } /// Handler initialized by the canister @@ -660,7 +650,7 @@ pub fn init(params: WsInitParams) { } /// Handles the WS connection open event sent by the client and relayed by the Gateway. -pub fn ws_open(args: CanisterWsOpenArguments) -> CanisterWsOpenResult { +pub fn ws_open(_args: CanisterWsOpenArguments) -> CanisterWsOpenResult { let client_principal = caller(); // TODO: test @@ -684,10 +674,7 @@ pub fn ws_open(args: CanisterWsOpenArguments) -> CanisterWsOpenResult { } // initialize client maps - let new_client = RegisteredClient { - is_anonymous: args.is_anonymous, - }; - add_client(client_principal.clone(), new_client); + add_client(client_principal.clone(), RegisteredClient {}); let open_message = CanisterOpenMessageContent { client_principal: client_principal.clone(), @@ -729,15 +716,6 @@ pub fn ws_message(args: CanisterWsMessageArguments) -> CanisterWsMessageResult { let client_principal = caller(); // check if client registered its principal by calling ws_open check_registered_client(&client_principal)?; - let registered_client = match get_registered_client(&client_principal) { - Some(v) => v, - None => { - return Err(format!( - "client with principal {:?} doesn't have an open connection", - client_principal - )) - }, - }; let WebsocketMessage { client_principal: _, @@ -766,7 +744,6 @@ pub fn ws_message(args: CanisterWsMessageArguments) -> CanisterWsMessageResult { // create message to send to client h.borrow().call_on_message(OnMessageCallbackArgs { client_principal, - is_anonymous: registered_client.is_anonymous(), message: content, }); }); @@ -821,9 +798,7 @@ mod test { } pub fn generate_random_registered_client() -> RegisteredClient { - RegisteredClient { - is_anonymous: false, - } + RegisteredClient {} } pub fn get_static_principal() -> Principal { @@ -901,7 +876,6 @@ mod test { }); h.call_on_message(OnMessageCallbackArgs { client_principal: test_utils::generate_random_principal(), - is_anonymous: false, // doesn't matter message: vec![], }); h.call_on_close(OnCloseCallbackArgs { @@ -954,7 +928,6 @@ mod test { }); h.call_on_message(OnMessageCallbackArgs { client_principal: test_utils::generate_random_principal(), - is_anonymous: false, // doesn't matter message: vec![], }); h.call_on_close(OnCloseCallbackArgs { @@ -1007,20 +980,6 @@ mod test { prop_assert_eq!(get_outgoing_message_nonce(), test_nonce + 1); } - #[test] - fn test_get_registered_client(test_client_principal in any::().prop_map(|_| test_utils::generate_random_principal())) { - // Set up - let registered_client = test_utils::generate_random_registered_client(); - REGISTERED_CLIENTS.with(|map| { - map.borrow_mut().insert(test_client_principal.clone(), registered_client.clone()); - }); - - let actual_client = get_registered_client(&test_client_principal); - prop_assert_eq!(actual_client, Some(registered_client)); - let actual_client = get_registered_client(&test_utils::generate_random_principal()); - prop_assert_eq!(actual_client, None); - } - #[test] fn test_insert_client(test_client_principal in any::().prop_map(|_| test_utils::generate_random_principal())) { // Set up From f610c5aeb4c64c7fce4ec92493b88858c3fc130b Mon Sep 17 00:00:00 2001 From: Luca8991 Date: Mon, 18 Sep 2023 10:34:08 +0200 Subject: [PATCH 15/31] chore: comments --- src/ic-websocket-cdk/src/lib.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/ic-websocket-cdk/src/lib.rs b/src/ic-websocket-cdk/src/lib.rs index a441d4b..5a9aaf4 100644 --- a/src/ic-websocket-cdk/src/lib.rs +++ b/src/ic-websocket-cdk/src/lib.rs @@ -122,14 +122,14 @@ fn get_current_time() -> u64 { } } -/// The data about a registered client. +/// The metadata about a registered client. #[derive(Clone, Debug, Eq, PartialEq)] pub struct RegisteredClient { // future versions may need more fields } thread_local! { - /// Maps the client's public key to the client's identity (anonymous if not authenticated). + /// Maps the client's principal to the client metadata /* flexible */ static REGISTERED_CLIENTS: RefCell> = RefCell::new(HashMap::new()); /// Maps the client's public key to the sequence number to use for the next outgoing message (to that client). /* flexible */ static OUTGOING_MESSAGE_TO_CLIENT_NUM_MAP: RefCell> = RefCell::new(HashMap::new()); From cb8db47101701b4e6d9f63008ca03814d1025f31 Mon Sep 17 00:00:00 2001 From: Luca8991 Date: Mon, 18 Sep 2023 11:00:31 +0200 Subject: [PATCH 16/31] fix: unneeded clients removal --- src/ic-websocket-cdk/src/lib.rs | 5 ----- 1 file changed, 5 deletions(-) diff --git a/src/ic-websocket-cdk/src/lib.rs b/src/ic-websocket-cdk/src/lib.rs index 5a9aaf4..7c98be3 100644 --- a/src/ic-websocket-cdk/src/lib.rs +++ b/src/ic-websocket-cdk/src/lib.rs @@ -190,11 +190,6 @@ fn reset_internal_state() { pub fn wipe() { reset_internal_state(); - // remove all clients from the map - REGISTERED_CLIENTS.with(|map| { - map.borrow_mut().clear(); - }); - custom_print!("Internal state has been wiped!"); } From 6dd085a65bf5361d58673d496c327d790777d8fb Mon Sep 17 00:00:00 2001 From: Luca8991 Date: Mon, 18 Sep 2023 12:13:14 +0200 Subject: [PATCH 17/31] feat: timer to wait for client keep alive --- src/ic-websocket-cdk/src/lib.rs | 207 +++++++++++++++++++++++++------- 1 file changed, 164 insertions(+), 43 deletions(-) diff --git a/src/ic-websocket-cdk/src/lib.rs b/src/ic-websocket-cdk/src/lib.rs index 7c98be3..8dc939f 100644 --- a/src/ic-websocket-cdk/src/lib.rs +++ b/src/ic-websocket-cdk/src/lib.rs @@ -1,4 +1,4 @@ -use candid::{CandidType, Encode, Principal}; +use candid::{decode_one, encode_one, CandidType, Principal}; #[cfg(not(test))] use ic_cdk::api::time; use ic_cdk::api::{caller, data_certificate, set_certified_data}; @@ -18,6 +18,8 @@ const LABEL_WEBSOCKET: &[u8] = b"websocket"; const MAX_NUMBER_OF_RETURNED_MESSAGES: usize = 10; /// The default delay between two consecutive acknowledgements sent to the client. const DEFAULT_SEND_ACK_DELAY_MS: u64 = 60_000; // 60 seconds +/// The default delay to wait for the client to send a keep alive after receiving an acknowledgement. +const DEFAULT_CLIENT_KEEP_ALIVE_DELAY_MS: u64 = 10_000; // 10 seconds pub type ClientPrincipal = Principal; @@ -125,7 +127,26 @@ fn get_current_time() -> u64 { /// The metadata about a registered client. #[derive(Clone, Debug, Eq, PartialEq)] pub struct RegisteredClient { - // future versions may need more fields + last_keep_alive_timestamp: u64, +} + +impl RegisteredClient { + /// Creates a new instance of RegisteredClient. + fn new() -> Self { + Self { + last_keep_alive_timestamp: get_current_time(), + } + } + + /// Gets the last keep alive timestamp. + fn get_last_keep_alive_timestamp(&self) -> u64 { + self.last_keep_alive_timestamp + } + + /// Set the last keep alive timestamp to the current time. + fn update_last_keep_alive_timestamp(&mut self) { + self.last_keep_alive_timestamp = get_current_time(); + } } thread_local! { @@ -156,18 +177,11 @@ thread_local! { /// Resets all RefCells to their initial state. /// If there is a registered gateway, resets its state as well. fn reset_internal_state() { - // get the handlers to call the on_close handler for each client - let handlers = HANDLERS.with(|state| state.borrow().clone()); - REGISTERED_CLIENTS.with(|state| { - let mut map = state.borrow_mut(); + let map = state.borrow(); // for each client, call the on_close handler before clearing the map for (client_principal, _) in map.clone().iter() { - handlers.call_on_close(OnCloseCallbackArgs { - client_principal: client_principal.clone(), - }); - - map.remove(client_principal); + remove_client(client_principal); } }); @@ -303,15 +317,21 @@ fn add_client(client_principal: ClientPrincipal, new_client: RegisteredClient) { init_outgoing_message_to_client_num(client_principal); } -fn remove_client(client_principal: ClientPrincipal) { +fn remove_client(client_principal: &ClientPrincipal) { + let handlers = HANDLERS.with(|state| state.borrow().clone()); + + handlers.call_on_close(OnCloseCallbackArgs { + client_principal: client_principal.clone(), + }); + REGISTERED_CLIENTS.with(|map| { - map.borrow_mut().remove(&client_principal); + map.borrow_mut().remove(client_principal); }); OUTGOING_MESSAGE_TO_CLIENT_NUM_MAP.with(|map| { - map.borrow_mut().remove(&client_principal); + map.borrow_mut().remove(client_principal); }); INCOMING_MESSAGE_FROM_CLIENT_NUM_MAP.with(|map| { - map.borrow_mut().remove(&client_principal); + map.borrow_mut().remove(client_principal); }); } @@ -424,46 +444,69 @@ fn get_cert_for_range(first: &String, last: &String) -> (Vec, Vec) { }) } -#[derive(CandidType)] +#[derive(CandidType, Deserialize)] struct CanisterOpenMessageContent { client_principal: ClientPrincipal, } -#[derive(CandidType)] +#[derive(CandidType, Deserialize)] struct CanisterAckMessageContent { last_incoming_sequence_num: u64, } +#[derive(CandidType, Deserialize)] +struct ClientKeepAliveMessage { + last_received_sequence_num: u64, +} + /// A service message sent by the CDK to the client. -#[derive(CandidType)] -enum CanisterServiceMessage { +#[derive(CandidType, Deserialize)] +enum WebsocketServiceMessageContent { + /// Message sent by the **canister** when a client opens a connection. OpenMessage(CanisterOpenMessageContent), + /// Message sent _periodically_ by the **canister** to the client to acknowledge the messages received. AckMessage(CanisterAckMessageContent), + /// Message sent by the **client** in response to an acknowledgement message from the canister. + KeepAliveMessage(ClientKeepAliveMessage), } fn send_service_message_to_client( client_principal: ClientPrincipal, - message: CanisterServiceMessage, + message: WebsocketServiceMessageContent, ) -> Result<(), String> { - let message_bytes = Encode!(&message).unwrap(); + let message_bytes = encode_one(&message).unwrap(); _ws_send(client_principal, message_bytes, true) } /// Schedules a timer to send an acknowledgement message to the client. /// -/// The timer callback is [send_ack_to_clients_timer_callback]. -fn schedule_send_ack_to_clients(interval_ms: u64) { - set_timer(Duration::from_millis(interval_ms), move || { - send_ack_to_clients_timer_callback(interval_ms) +/// The timer callback is [send_ack_to_clients_timer_callback]. After the callback is executed, +/// a timer is scheduled to check if the registered clients have sent a keep alive message. +fn schedule_send_ack_to_clients(ack_interval_ms: u64, check_interval_ms: u64) { + set_timer(Duration::from_millis(ack_interval_ms), move || { + send_ack_to_clients_timer_callback(); + + schedule_check_keep_alive(ack_interval_ms, check_interval_ms); + }); +} + +/// Schedules a timer to check if the registered clients have sent a keep alive message +/// after receiving an acknowledgement message. +/// +/// The timer callback is [check_keep_alive_timer_callback]. After the callback is executed, +/// a timer is scheduled again to send an acknowledgement message to the registered clients. +fn schedule_check_keep_alive(ack_interval_ms: u64, check_interval_ms: u64) { + set_timer(Duration::from_millis(check_interval_ms), move || { + check_keep_alive_timer_callback(); + + schedule_send_ack_to_clients(ack_interval_ms, check_interval_ms); }); } /// Sends an acknowledgement message to the client. /// The message contains the current incoming message sequence number for that client, /// so that the client knows that all the messages it sent have been received by the canister. -/// -/// At the end, a new timer is scheduled to send another acknowledgement message to the client. -fn send_ack_to_clients_timer_callback(interval_ms: u64) { +fn send_ack_to_clients_timer_callback() { REGISTERED_CLIENTS.with(|state| { let map = state.borrow(); for (client_principal, _) in map.iter() { @@ -474,8 +517,9 @@ fn send_ack_to_clients_timer_callback(interval_ms: u64) { let ack_message = CanisterAckMessageContent { last_incoming_sequence_num: last_incoming_message_sequence_num, }; - let message = CanisterServiceMessage::AckMessage(ack_message); + let message = WebsocketServiceMessageContent::AckMessage(ack_message); if let Err(e) = send_service_message_to_client(*client_principal, message) { + // TODO: decide what to do when sending the message fails custom_print!( "[ack-to-clients-timer-cb]: Error sending ack message to client {:?}: {:?}", client_principal, @@ -486,11 +530,78 @@ fn send_ack_to_clients_timer_callback(interval_ms: u64) { }; } } + }); - custom_print!("[ack-to-clients-timer-cb]: Sent ack messages to all clients"); + custom_print!("[ack-to-clients-timer-cb]: Sent ack messages to all clients"); +} + +/// Checks if the registered clients have sent a keep alive message. +/// If a client has not sent a keep alive message, it is removed from the registered clients. +fn check_keep_alive_timer_callback() { + REGISTERED_CLIENTS.with(|state| { + let map = state.borrow(); + for (client_principal, client_metadata) in map.iter() { + let current_time = get_current_time(); + if current_time - client_metadata.get_last_keep_alive_timestamp() + > DEFAULT_CLIENT_KEEP_ALIVE_DELAY_MS + { + remove_client(client_principal); + + custom_print!( + "[check-keep-alive-timer-cb]: Client {:?} has not sent a keep alive message in the last {:?} ms and has been removed", + client_principal, + DEFAULT_CLIENT_KEEP_ALIVE_DELAY_MS + ); + } + } }); - schedule_send_ack_to_clients(interval_ms); + custom_print!("[check-keep-alive-timer-cb]: Checked keep alive messages for all clients"); +} + +fn handle_keep_alive_client_message( + client_principal: &ClientPrincipal, + content: &[u8], +) -> Result<(), String> { + match decode_one::(content) { + Ok(message_content) => { + match message_content { + WebsocketServiceMessageContent::KeepAliveMessage(keep_alive_message) => { + // first, we check if the client received the last message sent by the canister + let last_outgoing_message_sequence_num = + get_outgoing_message_to_client_num(client_principal)?; + + // if the client has not received the last message sent by the canister, we remove it + if last_outgoing_message_sequence_num + != keep_alive_message.last_received_sequence_num + { + custom_print!( + "client {:?} has not received the last message sent by the canister, removing it", + client_principal + ); + + remove_client(client_principal); + + return Ok(()); + } + + // update the last keep alive timestamp for the client + REGISTERED_CLIENTS.with(|map| { + let mut map = map.borrow_mut(); + let client_metadata = map.get_mut(client_principal).unwrap(); + client_metadata.update_last_keep_alive_timestamp(); + }); + + Ok(()) + }, + _ => Err(String::from("invalid keep alive message content")), + } + }, + Err(e) => Err(format!( + "Error decoding service message from client: {:?}", + e + )), + } } /// Internal function used to put the messages in the outgoing messages queue and certify them. @@ -614,6 +725,9 @@ pub struct WsInitParams { /// so that the client knows that all the messages it sent have been received by the canister (in milliseconds). /// Defaults to `60_000` (60 seconds). pub send_ack_interval_ms: u64, + /// The delay to wait for the client to send a keep alive after receiving an acknowledgement (in milliseconds). + /// Defaults to `10_000` (10 seconds). + pub keep_alive_delay_ms: u64, } impl WsInitParams { @@ -623,6 +737,7 @@ impl WsInitParams { handlers, gateway_principal, send_ack_interval_ms: DEFAULT_SEND_ACK_DELAY_MS, + keep_alive_delay_ms: DEFAULT_CLIENT_KEEP_ALIVE_DELAY_MS, } } } @@ -641,7 +756,7 @@ pub fn init(params: WsInitParams) { // schedule a timer that will send an acknowledgement message to clients // TODO: test - schedule_send_ack_to_clients(params.send_ack_interval_ms); + schedule_send_ack_to_clients(params.send_ack_interval_ms, params.keep_alive_delay_ms); } /// Handles the WS connection open event sent by the client and relayed by the Gateway. @@ -649,7 +764,7 @@ pub fn ws_open(_args: CanisterWsOpenArguments) -> CanisterWsOpenResult { let client_principal = caller(); // TODO: test - if client_principal == Principal::anonymous() { + if client_principal == ClientPrincipal::anonymous() { return Err(String::from("anonymous principal cannot open a connection")); } @@ -669,12 +784,13 @@ pub fn ws_open(_args: CanisterWsOpenArguments) -> CanisterWsOpenResult { } // initialize client maps - add_client(client_principal.clone(), RegisteredClient {}); + let new_client = RegisteredClient::new(); + add_client(client_principal.clone(), new_client); let open_message = CanisterOpenMessageContent { client_principal: client_principal.clone(), }; - let message = CanisterServiceMessage::OpenMessage(open_message); + let message = WebsocketServiceMessageContent::OpenMessage(open_message); send_service_message_to_client(client_principal.clone(), message)?; // call the on_open handler initialized in init() @@ -695,7 +811,7 @@ pub fn ws_close(args: CanisterWsCloseArguments) -> CanisterWsCloseResult { // check if client registered its principal by calling ws_open check_registered_client(&args.client_principal)?; - remove_client(args.client_principal.clone()); + remove_client(&args.client_principal); HANDLERS.with(|h| { h.borrow().call_on_close(OnCloseCallbackArgs { @@ -716,7 +832,7 @@ pub fn ws_message(args: CanisterWsMessageArguments) -> CanisterWsMessageResult { client_principal: _, sequence_num, timestamp: _, - is_service_message: _, // TODO: handle service messages + is_service_message, content, } = args.msg; @@ -724,15 +840,20 @@ pub fn ws_message(args: CanisterWsMessageArguments) -> CanisterWsMessageResult { // check if the incoming message has the expected sequence number if sequence_num != expected_sequence_num { - return Err( - format!( - "incoming client's message does not have the expected sequence number. Expected: {expected_sequence_num}, actual: {sequence_num}", - ) + custom_print!( + "incoming client's message does not have the expected sequence number. Expected: {expected_sequence_num}, actual: {sequence_num}. Removing client..." ); + remove_client(&client_principal); + return Ok(()); } // increase the expected sequence number by 1 increment_expected_incoming_message_from_client_num(&client_principal)?; + // TODO: test + if is_service_message { + return handle_keep_alive_client_message(&client_principal, &content); + } + // call the on_message handler initialized in init() HANDLERS.with(|h| { // trigger the on_message handler initialized by canister @@ -793,7 +914,7 @@ mod test { } pub fn generate_random_registered_client() -> RegisteredClient { - RegisteredClient {} + RegisteredClient::new() } pub fn get_static_principal() -> Principal { @@ -1113,7 +1234,7 @@ mod test { map.borrow_mut().insert(test_client_principal.clone(), 0); }); - remove_client(test_client_principal.clone()); + remove_client(&test_client_principal); let is_none = REGISTERED_CLIENTS.with(|map| map.borrow().get(&test_client_principal).is_none()); prop_assert!(is_none); From 135b8305e94d7123c910a92854a3bbc60c8a15b0 Mon Sep 17 00:00:00 2001 From: Luca8991 Date: Mon, 18 Sep 2023 16:44:29 +0200 Subject: [PATCH 18/31] fix: handle callbacks panics --- Cargo.lock | 126 ++++++++++++++++---------------- src/ic-websocket-cdk/src/lib.rs | 39 +++++++--- 2 files changed, 93 insertions(+), 72 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 3ddbcb9..2a37569 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -69,9 +69,9 @@ checksum = "4c7f02d4ea65f2c1853089ffd8d2787bdbc63de2f0d29dedbcf8ccdfa0ccd4cf" [[package]] name = "base64" -version = "0.21.3" +version = "0.21.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "414dcefbc63d77c526a76b3afcf6fbb9b5e2791c19c3aa2297733208750c6e53" +checksum = "9ba43ea6f343b788c8764558649e08df62f86c6ef251fdaeb1ffd010a9ae50a2" [[package]] name = "base64ct" @@ -163,9 +163,9 @@ dependencies = [ [[package]] name = "bumpalo" -version = "3.13.0" +version = "3.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a3e2c3daef883ecc1b5d58c15adae93470a91d425f3532ba1695849656af3fc1" +checksum = "7f30e7476521f6f8af1a1c4c0b8cc94f0bee37d91763d0ca2665f299b6cd8aec" [[package]] name = "byteorder" @@ -175,15 +175,15 @@ checksum = "14c189c53d098945499cdfa7ecc63567cf3886b3332b312a5b4585d8d3a6a610" [[package]] name = "bytes" -version = "1.4.0" +version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "89b2fd2a0dcf38d7971e2194b6b6eebab45ae01067456a7fd93d5547a61b70be" +checksum = "a2bd12c1caf447e69cd4528f47f94d203fd2582878ecb9e9465484c4148a8223" [[package]] name = "candid" -version = "0.9.3" +version = "0.9.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a31e5ab22cdcd093b93b02bdff4ba18ffee324b05e669b25cdd93fdb8402d207" +checksum = "88f6eec0ae850e006ef0fe306f362884d370624094ec55a6a26de18b251774be" dependencies = [ "anyhow", "binread", @@ -208,14 +208,14 @@ dependencies = [ [[package]] name = "candid_derive" -version = "0.6.2" +version = "0.6.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "810b3bd60244f282090652ffc7c30a9d23892e72dfe443e46ee55569044f7dd5" +checksum = "158403ea38fab5904ae47a5d67eb7047650a91681407f5ccbcbcabc4f4ffb489" dependencies = [ "lazy_static", "proc-macro2", "quote", - "syn 2.0.29", + "syn 2.0.37", ] [[package]] @@ -269,9 +269,9 @@ dependencies = [ [[package]] name = "crypto-bigint" -version = "0.5.2" +version = "0.5.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cf4c2f4e1afd912bc40bfd6fed5d9dc1f288e0ba01bfcc835cc5bc3eb13efe15" +checksum = "740fe28e594155f10cfc383984cbefd529d7396050557148f79cb0f621204124" dependencies = [ "generic-array", "rand_core", @@ -390,9 +390,9 @@ checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5" [[package]] name = "errno" -version = "0.3.2" +version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6b30f669a7961ef1631673d2766cc92f52d64f7ef354d4fe0ddfd30ed52f0f4f" +checksum = "136526188508e25c6fef639d7927dfb3e0e3084488bf202267829cf7fc23dbdd" dependencies = [ "errno-dragonfly", "libc", @@ -506,7 +506,7 @@ checksum = "89ca545a94061b6365f2c7355b4b32bd20df3ff95f02da9329b34ccc3bd6ee72" dependencies = [ "proc-macro2", "quote", - "syn 2.0.29", + "syn 2.0.37", ] [[package]] @@ -714,7 +714,7 @@ dependencies = [ "futures-util", "http", "hyper", - "rustls 0.21.6", + "rustls 0.21.7", "tokio", "tokio-rustls", ] @@ -740,7 +740,7 @@ dependencies = [ "rand", "reqwest", "ring", - "rustls 0.20.8", + "rustls 0.20.9", "sec1", "serde", "serde_bytes", @@ -943,9 +943,9 @@ checksum = "884e2677b40cc8c339eaefcb701c32ef1fd2493d71118dc0ca4b6a736c93bd67" [[package]] name = "libc" -version = "0.2.147" +version = "0.2.148" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b4668fb0ea861c1df094127ac5f1da3409a82116a4ba74fca2e58ef927159bb3" +checksum = "9cdc71e17332e86d2e1d38c1f99edcb6288ee11b815fb1a4b049eaa2114d369b" [[package]] name = "libm" @@ -955,9 +955,9 @@ checksum = "f7012b1bbb0719e1097c47611d3898568c546d597c2e74d66f6087edd5233ff4" [[package]] name = "linux-raw-sys" -version = "0.4.5" +version = "0.4.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "57bcfdad1b858c2db7c38303a6d2ad4dfaf5eb53dfeb0910128b2c26d6158503" +checksum = "1a9bad9f94746442c783ca431b22403b519cd7fbeed0533fdd6328b2f2212128" [[package]] name = "log" @@ -967,9 +967,9 @@ checksum = "b5e6163cb8c49088c2c36f57875e58ccd8c87c7427f7fbd50ea6710b2f3f2e8f" [[package]] name = "memchr" -version = "2.6.0" +version = "2.6.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "76fc44e2588d5b436dbc3c6cf62aef290f90dab6235744a93dfe1cc18f451e2c" +checksum = "8f232d6ef707e1956a43342693d2a31e72989554d58299d7a88738cc95b0d35c" [[package]] name = "mime" @@ -1057,14 +1057,14 @@ dependencies = [ "proc-macro-crate", "proc-macro2", "quote", - "syn 2.0.29", + "syn 2.0.37", ] [[package]] name = "object" -version = "0.32.0" +version = "0.32.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "77ac5bbd07aea88c60a577a1ce218075ffd59208b2d7ca97adf9bfc5aeb21ebe" +checksum = "9cf5f9dd3933bd50a9e1f149ec995f39ae2c496d31fd772c1fd45ebc27e902b0" dependencies = [ "memchr", ] @@ -1172,9 +1172,9 @@ dependencies = [ [[package]] name = "proc-macro2" -version = "1.0.66" +version = "1.0.67" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "18fb31db3f9bddb2ea821cde30a9f70117e3f119938b5ee630b7403aa6e2ead9" +checksum = "3d433d9f1a3e8c1263d9456598b16fec66f4acc9a74dacffd35c7bb09b3a1328" dependencies = [ "unicode-ident", ] @@ -1300,7 +1300,7 @@ dependencies = [ "once_cell", "percent-encoding", "pin-project-lite", - "rustls 0.21.6", + "rustls 0.21.7", "rustls-pemfile", "serde", "serde_json", @@ -1351,9 +1351,9 @@ checksum = "d626bb9dae77e28219937af045c257c28bfd3f69333c512553507f5f9798cb76" [[package]] name = "rustix" -version = "0.38.9" +version = "0.38.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9bfe0f2582b4931a45d1fa608f8a8722e8b3c7ac54dd6d5f3b3212791fedef49" +checksum = "d7db8590df6dfcd144d22afd1b83b36c21a18d7cbc1dc4bb5295a8712e9eb662" dependencies = [ "bitflags 2.4.0", "errno", @@ -1364,9 +1364,9 @@ dependencies = [ [[package]] name = "rustls" -version = "0.20.8" +version = "0.20.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fff78fc74d175294f4e83b28343315ffcfb114b156f0185e9741cb5570f50e2f" +checksum = "1b80e3dec595989ea8510028f30c408a4630db12c9cbb8de34203b89d6577e99" dependencies = [ "log", "ring", @@ -1376,9 +1376,9 @@ dependencies = [ [[package]] name = "rustls" -version = "0.21.6" +version = "0.21.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1d1feddffcfcc0b33f5c6ce9a29e341e4cd59c3f78e7ee45f4a40c038b1d6cbb" +checksum = "cd8d6c9f025a446bc4d18ad9632e69aec8f287aa84499ee335599fabd20c3fd8" dependencies = [ "log", "ring", @@ -1397,9 +1397,9 @@ dependencies = [ [[package]] name = "rustls-webpki" -version = "0.101.4" +version = "0.101.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7d93931baf2d282fff8d3a532bbfd7653f734643161b87e3e01e59a04439bf0d" +checksum = "45a27e3b59326c16e23d30aeb7a36a24cc0d29e71d68ff611cdfb4a01d013bed" dependencies = [ "ring", "untrusted", @@ -1489,14 +1489,14 @@ checksum = "4eca7ac642d82aa35b60049a6eccb4be6be75e599bd2e9adb5f875a737654af2" dependencies = [ "proc-macro2", "quote", - "syn 2.0.29", + "syn 2.0.37", ] [[package]] name = "serde_json" -version = "1.0.105" +version = "1.0.107" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "693151e1ac27563d6dbcec9dee9fbd5da8539b20fa14ad3752b2e6d363ace360" +checksum = "6b420ce6e3d8bd882e9b243c6eed35dbc9a6110c9769e74b584e0d68d1f20c65" dependencies = [ "itoa", "ryu", @@ -1511,7 +1511,7 @@ checksum = "8725e1dfadb3a50f7e5ce0b1a540466f6ed3fe7a0fca2ac2b8b831d31316bd00" dependencies = [ "proc-macro2", "quote", - "syn 2.0.29", + "syn 2.0.37", ] [[package]] @@ -1613,9 +1613,9 @@ dependencies = [ [[package]] name = "socket2" -version = "0.5.3" +version = "0.5.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2538b18701741680e0322a2302176d3253a35388e2e62f172f64f4f16605f877" +checksum = "4031e820eb552adee9295814c0ced9e5cf38ddf1e8b7d566d6de8e2538ea989e" dependencies = [ "libc", "windows-sys", @@ -1669,9 +1669,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.29" +version = "2.0.37" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c324c494eba9d92503e6f1ef2e6df781e78f6a7705a0202d9801b198807d518a" +checksum = "7303ef2c05cd654186cb250d29049a24840ca25d2747c25c0381c8d9e2f582e8" dependencies = [ "proc-macro2", "quote", @@ -1714,22 +1714,22 @@ dependencies = [ [[package]] name = "thiserror" -version = "1.0.47" +version = "1.0.48" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "97a802ec30afc17eee47b2855fc72e0c4cd62be9b4efe6591edde0ec5bd68d8f" +checksum = "9d6d7a740b8a666a7e828dd00da9c0dc290dff53154ea77ac109281de90589b7" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.47" +version = "1.0.48" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6bb623b56e39ab7dcd4b1b98bb6c8f8d907ed255b18de254088016b27a8ee19b" +checksum = "49922ecae66cc8a249b77e68d1d0623c1b2c514f0060c27cdc68bd62a1219d35" dependencies = [ "proc-macro2", "quote", - "syn 2.0.29", + "syn 2.0.37", ] [[package]] @@ -1787,7 +1787,7 @@ dependencies = [ "mio", "num_cpus", "pin-project-lite", - "socket2 0.5.3", + "socket2 0.5.4", "windows-sys", ] @@ -1797,7 +1797,7 @@ version = "0.24.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c28327cf380ac148141087fbfb9de9d7bd4e84ab5d2c28fbc911d753de8a7081" dependencies = [ - "rustls 0.21.6", + "rustls 0.21.7", "tokio", ] @@ -1823,9 +1823,9 @@ checksum = "7cda73e2f1397b1262d6dfdcef8aafae14d1de7748d66822d3bfeeb6d03e5e4b" [[package]] name = "toml_edit" -version = "0.19.14" +version = "0.19.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f8123f27e969974a3dfba720fdb560be359f57b44302d280ba72e76a74480e8a" +checksum = "1b5bb770da30e5cbfde35a2d7b9b8a2c4b8ef89548a7a6aeab5c9a576e3e7421" dependencies = [ "indexmap 2.0.0", "toml_datetime", @@ -1872,9 +1872,9 @@ checksum = "6af6ae20167a9ece4bcb41af5b80f8a1f1df981f6391189ce00fd257af04126a" [[package]] name = "typenum" -version = "1.16.0" +version = "1.17.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "497961ef93d974e23eb6f433eb5fe1b7930b659f06d12dec6fc44a8f554c0bba" +checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825" [[package]] name = "unarray" @@ -1890,9 +1890,9 @@ checksum = "92888ba5573ff080736b3648696b70cafad7d250551175acbaa4e0385b3e1460" [[package]] name = "unicode-ident" -version = "1.0.11" +version = "1.0.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "301abaae475aa91687eb82514b328ab47a211a533026cb25fc3e519b86adfc3c" +checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" [[package]] name = "unicode-normalization" @@ -1983,7 +1983,7 @@ dependencies = [ "once_cell", "proc-macro2", "quote", - "syn 2.0.29", + "syn 2.0.37", "wasm-bindgen-shared", ] @@ -2017,7 +2017,7 @@ checksum = "54681b18a46765f095758388f2d0cf16eb8d4169b639ab575a8f5693af210c7b" dependencies = [ "proc-macro2", "quote", - "syn 2.0.29", + "syn 2.0.37", "wasm-bindgen-backend", "wasm-bindgen-shared", ] @@ -2053,9 +2053,9 @@ dependencies = [ [[package]] name = "webpki" -version = "0.22.0" +version = "0.22.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f095d78192e208183081cc07bc5515ef55216397af48b873e5edcd72637fa1bd" +checksum = "f0e74f82d49d545ad128049b7e88f6576df2da6b02e9ce565c6f533be576957e" dependencies = [ "ring", "untrusted", diff --git a/src/ic-websocket-cdk/src/lib.rs b/src/ic-websocket-cdk/src/lib.rs index 8dc939f..3e80a6f 100644 --- a/src/ic-websocket-cdk/src/lib.rs +++ b/src/ic-websocket-cdk/src/lib.rs @@ -7,6 +7,7 @@ use ic_certified_map::{labeled, labeled_hash, AsHashTree, Hash as ICHash, RbTree use serde::{Deserialize, Serialize}; use serde_cbor::Serializer; use sha2::{Digest, Sha256}; +use std::panic; use std::time::Duration; use std::{cell::RefCell, collections::HashMap, collections::VecDeque, convert::AsRef}; @@ -318,12 +319,6 @@ fn add_client(client_principal: ClientPrincipal, new_client: RegisteredClient) { } fn remove_client(client_principal: &ClientPrincipal) { - let handlers = HANDLERS.with(|state| state.borrow().clone()); - - handlers.call_on_close(OnCloseCallbackArgs { - client_principal: client_principal.clone(), - }); - REGISTERED_CLIENTS.with(|map| { map.borrow_mut().remove(client_principal); }); @@ -333,6 +328,11 @@ fn remove_client(client_principal: &ClientPrincipal) { INCOMING_MESSAGE_FROM_CLIENT_NUM_MAP.with(|map| { map.borrow_mut().remove(client_principal); }); + + let handlers = HANDLERS.with(|state| state.borrow().clone()); + handlers.call_on_close(OnCloseCallbackArgs { + client_principal: client_principal.clone(), + }); } fn get_message_for_gateway_key(gateway_principal: Principal, nonce: u64) -> String { @@ -690,19 +690,40 @@ pub struct WsHandlers { impl WsHandlers { fn call_on_open(&self, args: OnOpenCallbackArgs) { if let Some(on_open) = self.on_open { - on_open(args); + // TODO: test the panic handling + let res = panic::catch_unwind(|| { + on_open(args); + }); + + if let Err(e) = res { + custom_print!("Error calling on_open handler: {:?}", e); + } } } fn call_on_message(&self, args: OnMessageCallbackArgs) { if let Some(on_message) = self.on_message { - on_message(args); + // TODO: test the panic handling + let res = panic::catch_unwind(|| { + on_message(args); + }); + + if let Err(e) = res { + custom_print!("Error calling on_message handler: {:?}", e); + } } } fn call_on_close(&self, args: OnCloseCallbackArgs) { if let Some(on_close) = self.on_close { - on_close(args); + // TODO: test the panic handling + let res = panic::catch_unwind(|| { + on_close(args); + }); + + if let Err(e) = res { + custom_print!("Error calling on_close handler: {:?}", e); + } } } } From 13763f7bd03c4abcb1009285cd5ae440dee9f9e8 Mon Sep 17 00:00:00 2001 From: Luca8991 Date: Mon, 18 Sep 2023 16:45:12 +0200 Subject: [PATCH 19/31] fix: test canister typos --- tests/src/lib.rs | 2 +- tests/test_canister.did | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/src/lib.rs b/tests/src/lib.rs index 0d5994f..9e09999 100644 --- a/tests/src/lib.rs +++ b/tests/src/lib.rs @@ -21,8 +21,8 @@ fn init(gateway_principal: String) { let params = WsInitParams { handlers, gateway_principal, - check_registered_gateway_interval_ms: 15_000, send_ack_interval_ms: 10_000, + keep_alive_delay_ms: 5_000, }; ic_websocket_cdk::init(params) diff --git a/tests/test_canister.did b/tests/test_canister.did index c7610c4..a98d79e 100644 --- a/tests/test_canister.did +++ b/tests/test_canister.did @@ -9,7 +9,6 @@ service : (text) -> { "ws_open" : (CanisterWsOpenArguments) -> (CanisterWsOpenResult); "ws_close" : (CanisterWsCloseArguments) -> (CanisterWsCloseResult); "ws_message" : (CanisterWsMessageArguments) -> (CanisterWsMessageResult); - "ws_status" : (CanisterWsStatusArguments) -> (CanisterWsStatusResult); "ws_get_messages" : (CanisterWsGetMessagesArguments) -> (CanisterWsGetMessagesResult) query; // methods used just for debugging/testing From 2035e2502a23456140a94985d3ecb49025468e1e Mon Sep 17 00:00:00 2001 From: Luca8991 Date: Mon, 18 Sep 2023 16:53:29 +0200 Subject: [PATCH 20/31] fix: wsopenarguments candid and renaming --- src/ic-websocket-cdk/src/lib.rs | 4 ++-- src/ic-websocket-cdk/ws_types.did | 5 +---- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/src/ic-websocket-cdk/src/lib.rs b/src/ic-websocket-cdk/src/lib.rs index 3e80a6f..b4387d6 100644 --- a/src/ic-websocket-cdk/src/lib.rs +++ b/src/ic-websocket-cdk/src/lib.rs @@ -455,7 +455,7 @@ struct CanisterAckMessageContent { } #[derive(CandidType, Deserialize)] -struct ClientKeepAliveMessage { +struct ClientKeepAliveMessageContent { last_received_sequence_num: u64, } @@ -467,7 +467,7 @@ enum WebsocketServiceMessageContent { /// Message sent _periodically_ by the **canister** to the client to acknowledge the messages received. AckMessage(CanisterAckMessageContent), /// Message sent by the **client** in response to an acknowledgement message from the canister. - KeepAliveMessage(ClientKeepAliveMessage), + KeepAliveMessage(ClientKeepAliveMessageContent), } fn send_service_message_to_client( diff --git a/src/ic-websocket-cdk/ws_types.did b/src/ic-websocket-cdk/ws_types.did index be8ab22..659238d 100644 --- a/src/ic-websocket-cdk/ws_types.did +++ b/src/ic-websocket-cdk/ws_types.did @@ -20,10 +20,7 @@ type CanisterOutputCertifiedMessages = record { tree : blob; }; -type CanisterWsOpenArguments = record { - is_anonymous : bool; - gateway_principal : principal; -}; +type CanisterWsOpenArguments = record {}; type CanisterWsOpenResult = variant { Ok : null; From 70ed828a4b77e8d1bb01db6afcc9a88cabf773ac Mon Sep 17 00:00:00 2001 From: Luca8991 Date: Tue, 19 Sep 2023 12:34:16 +0200 Subject: [PATCH 21/31] fix: empty ws open arguments --- Cargo.lock | 8 ++++---- src/ic-websocket-cdk/src/lib.rs | 9 +++------ src/ic-websocket-cdk/ws_types.did | 2 +- 3 files changed, 8 insertions(+), 11 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 2a37569..61d1ee2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -851,9 +851,9 @@ dependencies = [ [[package]] name = "ic0" -version = "0.18.11" +version = "0.18.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "576c539151d4769fb4d1a0c25c4108dd18facd04c5695b02cf2d226ab4e43aa5" +checksum = "16efdbe5d9b0ea368da50aedbf7640a054139569236f1a5249deb5fd9af5a5d5" [[package]] name = "idna" @@ -1693,9 +1693,9 @@ dependencies = [ [[package]] name = "termcolor" -version = "1.2.0" +version = "1.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "be55cf8942feac5c765c2c993422806843c9a9a45d4d5c407ad6dd2ea95eb9b6" +checksum = "6093bad37da69aab9d123a8091e4be0aa4a03e4d601ec641c327398315f62b64" dependencies = [ "winapi-util", ] diff --git a/src/ic-websocket-cdk/src/lib.rs b/src/ic-websocket-cdk/src/lib.rs index b4387d6..b5f0256 100644 --- a/src/ic-websocket-cdk/src/lib.rs +++ b/src/ic-websocket-cdk/src/lib.rs @@ -36,10 +36,7 @@ pub type CanisterWsGetMessagesResult = Result; /// The arguments for [ws_open]. -#[derive(CandidType, Clone, Deserialize, Serialize, Eq, PartialEq, Debug)] -pub struct CanisterWsOpenArguments { - // future versions may need more fields -} +pub type CanisterWsOpenArguments = (); /// The arguments for [ws_close]. #[derive(CandidType, Clone, Deserialize, Serialize, Eq, PartialEq, Debug)] @@ -456,7 +453,7 @@ struct CanisterAckMessageContent { #[derive(CandidType, Deserialize)] struct ClientKeepAliveMessageContent { - last_received_sequence_num: u64, + last_incoming_sequence_num: u64, } /// A service message sent by the CDK to the client. @@ -573,7 +570,7 @@ fn handle_keep_alive_client_message( // if the client has not received the last message sent by the canister, we remove it if last_outgoing_message_sequence_num - != keep_alive_message.last_received_sequence_num + != keep_alive_message.last_incoming_sequence_num { custom_print!( "client {:?} has not received the last message sent by the canister, removing it", diff --git a/src/ic-websocket-cdk/ws_types.did b/src/ic-websocket-cdk/ws_types.did index 659238d..84ac40d 100644 --- a/src/ic-websocket-cdk/ws_types.did +++ b/src/ic-websocket-cdk/ws_types.did @@ -20,7 +20,7 @@ type CanisterOutputCertifiedMessages = record { tree : blob; }; -type CanisterWsOpenArguments = record {}; +type CanisterWsOpenArguments = null; type CanisterWsOpenResult = variant { Ok : null; From ca118f23fc5ae52cc195e708eaa51e1dd8638c63 Mon Sep 17 00:00:00 2001 From: Luca8991 Date: Wed, 20 Sep 2023 11:16:11 +0200 Subject: [PATCH 22/31] feat: ClientKey as index --- Cargo.lock | 1 + src/ic-websocket-cdk/Cargo.toml | 3 +- src/ic-websocket-cdk/src/lib.rs | 427 +++++++++++++++++++----------- src/ic-websocket-cdk/ws_types.did | 14 +- 4 files changed, 280 insertions(+), 165 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 61d1ee2..5791941 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -842,6 +842,7 @@ dependencies = [ "ic-cdk-timers", "ic-certified-map", "proptest", + "rand", "ring", "serde", "serde_bytes", diff --git a/src/ic-websocket-cdk/Cargo.toml b/src/ic-websocket-cdk/Cargo.toml index dc65cf2..15c9071 100644 --- a/src/ic-websocket-cdk/Cargo.toml +++ b/src/ic-websocket-cdk/Cargo.toml @@ -20,7 +20,8 @@ serde_bytes = "0.11.12" [dev-dependencies] ic-agent = "0.25.0" proptest = "1.2.0" +rand = "0.8.5" ring = "0.16.20" [package.metadata.docs.rs] -default-target = "wasm32-unknown-unknown" \ No newline at end of file +default-target = "wasm32-unknown-unknown" diff --git a/src/ic-websocket-cdk/src/lib.rs b/src/ic-websocket-cdk/src/lib.rs index b5f0256..7c9ba8c 100644 --- a/src/ic-websocket-cdk/src/lib.rs +++ b/src/ic-websocket-cdk/src/lib.rs @@ -7,6 +7,7 @@ use ic_certified_map::{labeled, labeled_hash, AsHashTree, Hash as ICHash, RbTree use serde::{Deserialize, Serialize}; use serde_cbor::Serializer; use sha2::{Digest, Sha256}; +use std::fmt; use std::panic; use std::time::Duration; use std::{cell::RefCell, collections::HashMap, collections::VecDeque, convert::AsRef}; @@ -23,6 +24,27 @@ const DEFAULT_SEND_ACK_DELAY_MS: u64 = 60_000; // 60 seconds const DEFAULT_CLIENT_KEEP_ALIVE_DELAY_MS: u64 = 10_000; // 10 seconds pub type ClientPrincipal = Principal; +#[derive(CandidType, Clone, Deserialize, Serialize, Eq, PartialEq, Debug, Hash)] +struct ClientKey { + client_principal: ClientPrincipal, + client_nonce: u64, +} + +impl ClientKey { + /// Creates a new instance of ClientKey. + fn new(client_principal: ClientPrincipal, client_nonce: u64) -> Self { + Self { + client_principal, + client_nonce, + } + } +} + +impl fmt::Display for ClientKey { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}_{}", self.client_principal, self.client_nonce) + } +} /// The result of [ws_open]. pub type CanisterWsOpenResult = Result<(), String>; @@ -36,12 +58,15 @@ pub type CanisterWsGetMessagesResult = Result; /// The arguments for [ws_open]. -pub type CanisterWsOpenArguments = (); +#[derive(CandidType, Clone, Deserialize, Serialize, Eq, PartialEq, Debug)] +pub struct CanisterWsOpenArguments { + client_nonce: u64, +} /// The arguments for [ws_close]. #[derive(CandidType, Clone, Deserialize, Serialize, Eq, PartialEq, Debug)] pub struct CanisterWsCloseArguments { - client_principal: ClientPrincipal, + client_key: ClientKey, } /// The arguments for [ws_message]. @@ -59,9 +84,9 @@ pub struct CanisterWsGetMessagesArguments { /// Messages exchanged through the WebSocket. #[derive(CandidType, Clone, Debug, Deserialize, Serialize, Eq, PartialEq)] struct WebsocketMessage { - client_principal: ClientPrincipal, // The client that the gateway will forward the message to or that sent the message. - sequence_num: u64, // Both ways, messages should arrive with sequence numbers 0, 1, 2... - timestamp: u64, // Timestamp of when the message was made for the recipient to inspect. + client_key: ClientKey, // The client that the gateway will forward the message to or that sent the message. + sequence_num: u64, // Both ways, messages should arrive with sequence numbers 0, 1, 2... + timestamp: u64, // Timestamp of when the message was made for the recipient to inspect. is_service_message: bool, // Whether the message is a service message sent by the CDK to the client or vice versa. #[serde(with = "serde_bytes")] content: Vec, // Application message encoded in binary. @@ -81,8 +106,8 @@ impl WebsocketMessage { /// Element of the list of messages returned to the WS Gateway after polling. #[derive(CandidType, Clone, Deserialize, Serialize, Eq, PartialEq)] pub struct CanisterOutputMessage { - client_principal: ClientPrincipal, // The client that the gateway will forward the message to or that sent the message. - key: String, // Key for certificate verification. + client_key: ClientKey, // The client that the gateway will forward the message to or that sent the message. + key: String, // Key for certificate verification. #[serde(with = "serde_bytes")] content: Vec, // The message to be relayed, that contains the application message. } @@ -124,7 +149,7 @@ fn get_current_time() -> u64 { /// The metadata about a registered client. #[derive(Clone, Debug, Eq, PartialEq)] -pub struct RegisteredClient { +struct RegisteredClient { last_keep_alive_timestamp: u64, } @@ -148,12 +173,14 @@ impl RegisteredClient { } thread_local! { - /// Maps the client's principal to the client metadata - /* flexible */ static REGISTERED_CLIENTS: RefCell> = RefCell::new(HashMap::new()); - /// Maps the client's public key to the sequence number to use for the next outgoing message (to that client). - /* flexible */ static OUTGOING_MESSAGE_TO_CLIENT_NUM_MAP: RefCell> = RefCell::new(HashMap::new()); - /// Maps the client's public key to the expected sequence number of the next incoming message (from that client). - /* flexible */ static INCOMING_MESSAGE_FROM_CLIENT_NUM_MAP: RefCell> = RefCell::new(HashMap::new()); + /// Maps the client's key to the client metadata + /* flexible */ static REGISTERED_CLIENTS: RefCell> = RefCell::new(HashMap::new()); + /// Maps the client's principal to the current client key + /* flexible */ static CURRENT_CLIENT_KEY_MAP: RefCell> = RefCell::new(HashMap::new()); + /// Maps the client's key to the sequence number to use for the next outgoing message (to that client). + /* flexible */ static OUTGOING_MESSAGE_TO_CLIENT_NUM_MAP: RefCell> = RefCell::new(HashMap::new()); + /// Maps the client's key to the expected sequence number of the next incoming message (from that client). + /* flexible */ static INCOMING_MESSAGE_FROM_CLIENT_NUM_MAP: RefCell> = RefCell::new(HashMap::new()); /// Keeps track of the Merkle tree used for certified queries /* flexible */ static CERT_TREE: RefCell> = RefCell::new(RbTree::new()); /// Keeps track of the principal of the WS Gateway which polls the canister @@ -178,11 +205,15 @@ fn reset_internal_state() { REGISTERED_CLIENTS.with(|state| { let map = state.borrow(); // for each client, call the on_close handler before clearing the map - for (client_principal, _) in map.clone().iter() { - remove_client(client_principal); + for (client_key, _) in map.clone().iter() { + remove_client(client_key); } }); + CURRENT_CLIENT_KEY_MAP.with(|map| { + map.borrow_mut().clear(); + }); + OUTGOING_MESSAGE_TO_CLIENT_NUM_MAP.with(|map| { map.borrow_mut().clear(); }); @@ -213,22 +244,37 @@ fn increment_outgoing_message_nonce() { OUTGOING_MESSAGE_NONCE.with(|n| n.replace_with(|&mut old| old + 1)); } -fn insert_client(client_principal: ClientPrincipal, new_client: RegisteredClient) { - REGISTERED_CLIENTS.with(|map| { +fn insert_client(client_key: ClientKey, new_client: RegisteredClient) { + CURRENT_CLIENT_KEY_MAP.with(|map| { map.borrow_mut() - .insert(client_principal.clone(), new_client); + .insert(client_key.client_principal.clone(), client_key.clone()); + }); + REGISTERED_CLIENTS.with(|map| { + map.borrow_mut().insert(client_key, new_client); }); } -fn is_client_registered(client_principal: &ClientPrincipal) -> bool { - REGISTERED_CLIENTS.with(|map| map.borrow().contains_key(client_principal)) +fn is_client_registered(client_key: &ClientKey) -> bool { + REGISTERED_CLIENTS.with(|map| map.borrow().contains_key(client_key)) +} + +fn get_client_key_from_principal(client_principal: &ClientPrincipal) -> Result { + CURRENT_CLIENT_KEY_MAP.with(|map| { + map.borrow() + .get(client_principal) + .cloned() + .ok_or(String::from(format!( + "client with principal {} doesn't have an open connection", + client_principal + ))) + }) } -fn check_registered_client(client_principal: &ClientPrincipal) -> Result<(), String> { - if !REGISTERED_CLIENTS.with(|map| map.borrow().contains_key(client_principal)) { +fn check_registered_client(client_key: &ClientKey) -> Result<(), String> { + if !is_client_registered(client_key) { return Err(String::from(format!( - "client with principal {:?} doesn't have an open connection", - client_principal + "client with key {} doesn't have an open connection", + client_key ))); } @@ -251,44 +297,40 @@ fn get_registered_gateway_principal() -> Principal { }) } -fn init_outgoing_message_to_client_num(client_principal: ClientPrincipal) { +fn init_outgoing_message_to_client_num(client_key: ClientKey) { OUTGOING_MESSAGE_TO_CLIENT_NUM_MAP.with(|map| { - map.borrow_mut().insert(client_principal, 0); + map.borrow_mut().insert(client_key, 0); }); } -fn get_outgoing_message_to_client_num(client_principal: &ClientPrincipal) -> Result { +fn get_outgoing_message_to_client_num(client_key: &ClientKey) -> Result { OUTGOING_MESSAGE_TO_CLIENT_NUM_MAP.with(|map| { let map = map.borrow(); - let num = *map.get(client_principal).ok_or(String::from( + let num = *map.get(client_key).ok_or(String::from( "outgoing message to client num not initialized for client", ))?; Ok(num) }) } -fn increment_outgoing_message_to_client_num( - client_principal: &ClientPrincipal, -) -> Result<(), String> { - let num = get_outgoing_message_to_client_num(client_principal)?; +fn increment_outgoing_message_to_client_num(client_key: &ClientKey) -> Result<(), String> { + let num = get_outgoing_message_to_client_num(client_key)?; OUTGOING_MESSAGE_TO_CLIENT_NUM_MAP.with(|map| { let mut map = map.borrow_mut(); - map.insert(client_principal.clone(), num + 1); + map.insert(client_key.clone(), num + 1); Ok(()) }) } -fn init_expected_incoming_message_from_client_num(client_principal: ClientPrincipal) { +fn init_expected_incoming_message_from_client_num(client_key: ClientKey) { INCOMING_MESSAGE_FROM_CLIENT_NUM_MAP.with(|map| { - map.borrow_mut().insert(client_principal, 1); + map.borrow_mut().insert(client_key, 1); }); } -fn get_expected_incoming_message_from_client_num( - client_principal: &ClientPrincipal, -) -> Result { +fn get_expected_incoming_message_from_client_num(client_key: &ClientKey) -> Result { INCOMING_MESSAGE_FROM_CLIENT_NUM_MAP.with(|map| { - let num = *map.borrow().get(client_principal).ok_or(String::from( + let num = *map.borrow().get(client_key).ok_or(String::from( "expected incoming message num not initialized for client", ))?; Ok(num) @@ -296,39 +338,42 @@ fn get_expected_incoming_message_from_client_num( } fn increment_expected_incoming_message_from_client_num( - client_principal: &ClientPrincipal, + client_key: &ClientKey, ) -> Result<(), String> { - let num = get_expected_incoming_message_from_client_num(client_principal)?; + let num = get_expected_incoming_message_from_client_num(client_key)?; INCOMING_MESSAGE_FROM_CLIENT_NUM_MAP.with(|map| { let mut map = map.borrow_mut(); - map.insert(client_principal.clone(), num + 1); + map.insert(client_key.clone(), num + 1); Ok(()) }) } -fn add_client(client_principal: ClientPrincipal, new_client: RegisteredClient) { +fn add_client(client_key: ClientKey, new_client: RegisteredClient) { // insert the client in the map - insert_client(client_principal.clone(), new_client); + insert_client(client_key.clone(), new_client); // initialize incoming client's message sequence number to 0 - init_expected_incoming_message_from_client_num(client_principal.clone()); + init_expected_incoming_message_from_client_num(client_key.clone()); // initialize outgoing message sequence number to 0 - init_outgoing_message_to_client_num(client_principal); + init_outgoing_message_to_client_num(client_key); } -fn remove_client(client_principal: &ClientPrincipal) { +fn remove_client(client_key: &ClientKey) { + CURRENT_CLIENT_KEY_MAP.with(|map| { + map.borrow_mut().remove(&client_key.client_principal); + }); REGISTERED_CLIENTS.with(|map| { - map.borrow_mut().remove(client_principal); + map.borrow_mut().remove(client_key); }); OUTGOING_MESSAGE_TO_CLIENT_NUM_MAP.with(|map| { - map.borrow_mut().remove(client_principal); + map.borrow_mut().remove(client_key); }); INCOMING_MESSAGE_FROM_CLIENT_NUM_MAP.with(|map| { - map.borrow_mut().remove(client_principal); + map.borrow_mut().remove(client_key); }); let handlers = HANDLERS.with(|state| state.borrow().clone()); handlers.call_on_close(OnCloseCallbackArgs { - client_principal: client_principal.clone(), + client_principal: client_key.client_principal, }); } @@ -443,7 +488,7 @@ fn get_cert_for_range(first: &String, last: &String) -> (Vec, Vec) { #[derive(CandidType, Deserialize)] struct CanisterOpenMessageContent { - client_principal: ClientPrincipal, + client_key: ClientKey, } #[derive(CandidType, Deserialize)] @@ -468,11 +513,11 @@ enum WebsocketServiceMessageContent { } fn send_service_message_to_client( - client_principal: ClientPrincipal, + client_key: &ClientKey, message: WebsocketServiceMessageContent, ) -> Result<(), String> { let message_bytes = encode_one(&message).unwrap(); - _ws_send(client_principal, message_bytes, true) + _ws_send(client_key, message_bytes, true) } /// Schedules a timer to send an acknowledgement message to the client. @@ -506,20 +551,20 @@ fn schedule_check_keep_alive(ack_interval_ms: u64, check_interval_ms: u64) { fn send_ack_to_clients_timer_callback() { REGISTERED_CLIENTS.with(|state| { let map = state.borrow(); - for (client_principal, _) in map.iter() { + for (client_key, _) in map.iter() { // ignore the error, which shouldn't happen since the client is registered and the sequence number is initialized if let Ok(last_incoming_message_sequence_num) = - get_expected_incoming_message_from_client_num(client_principal) + get_expected_incoming_message_from_client_num(client_key) { let ack_message = CanisterAckMessageContent { last_incoming_sequence_num: last_incoming_message_sequence_num, }; let message = WebsocketServiceMessageContent::AckMessage(ack_message); - if let Err(e) = send_service_message_to_client(*client_principal, message) { + if let Err(e) = send_service_message_to_client(client_key, message) { // TODO: decide what to do when sending the message fails custom_print!( - "[ack-to-clients-timer-cb]: Error sending ack message to client {:?}: {:?}", - client_principal, + "[ack-to-clients-timer-cb]: Error sending ack message to client {}: {:?}", + client_key, e ); @@ -537,16 +582,16 @@ fn send_ack_to_clients_timer_callback() { fn check_keep_alive_timer_callback() { REGISTERED_CLIENTS.with(|state| { let map = state.borrow(); - for (client_principal, client_metadata) in map.iter() { + for (client_key, client_metadata) in map.iter() { let current_time = get_current_time(); if current_time - client_metadata.get_last_keep_alive_timestamp() > DEFAULT_CLIENT_KEEP_ALIVE_DELAY_MS { - remove_client(client_principal); + remove_client(client_key); custom_print!( - "[check-keep-alive-timer-cb]: Client {:?} has not sent a keep alive message in the last {:?} ms and has been removed", - client_principal, + "[check-keep-alive-timer-cb]: Client {} has not sent a keep alive message in the last {} ms and has been removed", + client_key, DEFAULT_CLIENT_KEEP_ALIVE_DELAY_MS ); } @@ -556,28 +601,25 @@ fn check_keep_alive_timer_callback() { custom_print!("[check-keep-alive-timer-cb]: Checked keep alive messages for all clients"); } -fn handle_keep_alive_client_message( - client_principal: &ClientPrincipal, - content: &[u8], -) -> Result<(), String> { +fn handle_keep_alive_client_message(client_key: &ClientKey, content: &[u8]) -> Result<(), String> { match decode_one::(content) { Ok(message_content) => { match message_content { WebsocketServiceMessageContent::KeepAliveMessage(keep_alive_message) => { // first, we check if the client received the last message sent by the canister let last_outgoing_message_sequence_num = - get_outgoing_message_to_client_num(client_principal)?; + get_outgoing_message_to_client_num(client_key)?; // if the client has not received the last message sent by the canister, we remove it if last_outgoing_message_sequence_num != keep_alive_message.last_incoming_sequence_num { custom_print!( - "client {:?} has not received the last message sent by the canister, removing it", - client_principal + "client {} has not received the last message sent by the canister, removing the client", + client_key ); - remove_client(client_principal); + remove_client(client_key); return Ok(()); } @@ -585,7 +627,7 @@ fn handle_keep_alive_client_message( // update the last keep alive timestamp for the client REGISTERED_CLIENTS.with(|map| { let mut map = map.borrow_mut(); - let client_metadata = map.get_mut(client_principal).unwrap(); + let client_metadata = map.get_mut(client_key).unwrap(); client_metadata.update_last_keep_alive_timestamp(); }); @@ -603,12 +645,12 @@ fn handle_keep_alive_client_message( /// Internal function used to put the messages in the outgoing messages queue and certify them. fn _ws_send( - client_principal: ClientPrincipal, + client_key: &ClientKey, msg_bytes: Vec, is_service_message: bool, ) -> CanisterWsSendResult { // check if the client is registered - check_registered_client(&client_principal)?; + check_registered_client(client_key)?; // get the principal of the gateway that is polling the canister let gateway_principal = get_registered_gateway_principal(); @@ -621,11 +663,11 @@ fn _ws_send( // increment the nonce for the next message increment_outgoing_message_nonce(); // increment the sequence number for the next message to the client - increment_outgoing_message_to_client_num(&client_principal)?; + increment_outgoing_message_to_client_num(client_key)?; let websocket_message = WebsocketMessage { - client_principal: client_principal.clone(), - sequence_num: get_outgoing_message_to_client_num(&client_principal)?, + client_key: client_key.clone(), + sequence_num: get_outgoing_message_to_client_num(client_key)?, timestamp: get_current_time(), is_service_message, content: msg_bytes, @@ -642,7 +684,7 @@ fn _ws_send( // (from beginning to end of the queue) as ws_send is called sequentially, the nonce // is incremented by one in each call, and the message is pushed at the end of the queue m.borrow_mut().push_back(CanisterOutputMessage { - client_principal, + client_key: client_key.clone(), content, key, }); @@ -778,19 +820,19 @@ pub fn init(params: WsInitParams) { } /// Handles the WS connection open event sent by the client and relayed by the Gateway. -pub fn ws_open(_args: CanisterWsOpenArguments) -> CanisterWsOpenResult { +pub fn ws_open(args: CanisterWsOpenArguments) -> CanisterWsOpenResult { let client_principal = caller(); - // TODO: test if client_principal == ClientPrincipal::anonymous() { return Err(String::from("anonymous principal cannot open a connection")); } + let client_key = ClientKey::new(client_principal, args.client_nonce); // check if client is not registered yet - if is_client_registered(&client_principal) { + if is_client_registered(&client_key) { return Err(format!( - "client with principal {:?} already has an open connection", - client_principal, + "client with key {} already has an open connection", + client_key, )); } @@ -803,19 +845,18 @@ pub fn ws_open(_args: CanisterWsOpenArguments) -> CanisterWsOpenResult { // initialize client maps let new_client = RegisteredClient::new(); - add_client(client_principal.clone(), new_client); + add_client(client_key.clone(), new_client); let open_message = CanisterOpenMessageContent { - client_principal: client_principal.clone(), + client_key: client_key.clone(), }; let message = WebsocketServiceMessageContent::OpenMessage(open_message); - send_service_message_to_client(client_principal.clone(), message)?; + send_service_message_to_client(&client_key, message)?; // call the on_open handler initialized in init() HANDLERS.with(|h| { - h.borrow().call_on_open(OnOpenCallbackArgs { - client_principal: client_principal.clone(), - }); + h.borrow() + .call_on_open(OnOpenCallbackArgs { client_principal }); }); Ok(()) @@ -827,13 +868,13 @@ pub fn ws_close(args: CanisterWsCloseArguments) -> CanisterWsCloseResult { check_is_registered_gateway(caller())?; // check if client registered its principal by calling ws_open - check_registered_client(&args.client_principal)?; + check_registered_client(&args.client_key)?; - remove_client(&args.client_principal); + remove_client(&args.client_key); HANDLERS.with(|h| { h.borrow().call_on_close(OnCloseCallbackArgs { - client_principal: args.client_principal, + client_principal: args.client_key.client_principal, }); }); @@ -844,32 +885,40 @@ pub fn ws_close(args: CanisterWsCloseArguments) -> CanisterWsCloseResult { pub fn ws_message(args: CanisterWsMessageArguments) -> CanisterWsMessageResult { let client_principal = caller(); // check if client registered its principal by calling ws_open - check_registered_client(&client_principal)?; + let registered_client_key = get_client_key_from_principal(&client_principal)?; let WebsocketMessage { - client_principal: _, + client_key, sequence_num, timestamp: _, is_service_message, content, } = args.msg; - let expected_sequence_num = get_expected_incoming_message_from_client_num(&client_principal)?; + // check if the client is registered with the same nonce as the one used in the message + if registered_client_key.client_nonce != client_key.client_nonce { + return Err(String::from(format!( + "client with principal {} has a different nonce than the one used in the message", + client_principal + ))); + } + + let expected_sequence_num = get_expected_incoming_message_from_client_num(&client_key)?; // check if the incoming message has the expected sequence number if sequence_num != expected_sequence_num { custom_print!( "incoming client's message does not have the expected sequence number. Expected: {expected_sequence_num}, actual: {sequence_num}. Removing client..." ); - remove_client(&client_principal); + remove_client(&client_key); return Ok(()); } // increase the expected sequence number by 1 - increment_expected_incoming_message_from_client_num(&client_principal)?; + increment_expected_incoming_message_from_client_num(&client_key)?; // TODO: test if is_service_message { - return handle_keep_alive_client_message(&client_principal, &content); + return handle_keep_alive_client_message(&client_key, &content); } // call the on_message handler initialized in init() @@ -898,7 +947,8 @@ pub fn ws_get_messages(args: CanisterWsGetMessagesArguments) -> CanisterWsGetMes /// Under the hood, the message is serialized and certified, and then it is added to the queue of messages /// that the WS Gateway will poll in the next iteration. pub fn ws_send(client_principal: ClientPrincipal, msg_bytes: Vec) -> CanisterWsSendResult { - _ws_send(client_principal, msg_bytes, false) + let client_key = get_client_key_from_principal(&client_principal)?; + _ws_send(&client_key, msg_bytes, false) } #[cfg(test)] @@ -911,8 +961,8 @@ mod test { use ic_agent::{identity::BasicIdentity, Identity}; use ring::signature::Ed25519KeyPair; - use crate::{ - get_message_for_gateway_key, CanisterOutputMessage, ClientPrincipal, RegisteredClient, + use super::{ + get_message_for_gateway_key, CanisterOutputMessage, ClientKey, RegisteredClient, MESSAGES_FOR_GATEWAY, }; @@ -931,7 +981,7 @@ mod test { candid::Principal::from_text(identity.sender().unwrap().to_text()).unwrap() } - pub fn generate_random_registered_client() -> RegisteredClient { + pub(super) fn generate_random_registered_client() -> RegisteredClient { RegisteredClient::new() } @@ -940,15 +990,23 @@ mod test { .unwrap() // a random static but valid principal } - pub fn add_messages_for_gateway( - client_principal: ClientPrincipal, + pub(super) fn get_random_client_key() -> ClientKey { + ClientKey::new( + generate_random_principal(), + // a random nonce + rand::random(), + ) + } + + pub(super) fn add_messages_for_gateway( + client_key: ClientKey, gateway_principal: Principal, count: u64, ) { MESSAGES_FOR_GATEWAY.with(|m| { for i in 0..count { m.borrow_mut().push_back(CanisterOutputMessage { - client_principal: client_principal.clone(), + client_key: client_key.clone(), key: get_message_for_gateway_key(gateway_principal.clone(), i), content: vec![], }); @@ -1115,13 +1173,16 @@ mod test { } #[test] - fn test_insert_client(test_client_principal in any::().prop_map(|_| test_utils::generate_random_principal())) { + fn test_insert_client(test_client_key in any::().prop_map(|_| test_utils::get_random_client_key())) { // Set up let registered_client = test_utils::generate_random_registered_client(); - insert_client(test_client_principal.clone(), registered_client.clone()); + insert_client(test_client_key.clone(), registered_client.clone()); + + let actual_client_key = CURRENT_CLIENT_KEY_MAP.with(|map| map.borrow().get(&test_client_key.client_principal).unwrap().clone()); + prop_assert_eq!(actual_client_key, test_client_key.clone()); - let actual_client = REGISTERED_CLIENTS.with(|map| map.borrow().get(&test_client_principal).unwrap().clone()); + let actual_client = REGISTERED_CLIENTS.with(|map| map.borrow().get(&test_client_key).unwrap().clone()); prop_assert_eq!(actual_client, registered_client); } @@ -1135,132 +1196,178 @@ mod test { } #[test] - fn test_check_registered_client_principal_empty(test_client_principal in any::().prop_map(|_| test_utils::generate_random_principal())) { - let actual_result = check_registered_client(&test_client_principal); - prop_assert_eq!(actual_result.err(), Some(format!("client with principal {:?} doesn't have an open connection", test_client_principal))); + fn test_is_client_registered_empty(test_client_key in any::().prop_map(|_| test_utils::get_random_client_key())) { + let actual_result = is_client_registered(&test_client_key); + prop_assert_eq!(actual_result, false); } #[test] - fn test_check_registered_client_principal(test_client_principal in any::().prop_map(|_| test_utils::generate_random_principal())) { + fn test_is_client_registered(test_client_key in any::().prop_map(|_| test_utils::get_random_client_key())) { // Set up REGISTERED_CLIENTS.with(|map| { - map.borrow_mut().insert(test_client_principal.clone(), test_utils::generate_random_registered_client()); + map.borrow_mut().insert(test_client_key.clone(), test_utils::generate_random_registered_client()); }); - let actual_result = check_registered_client(&test_client_principal); + let actual_result = is_client_registered(&test_client_key); + prop_assert_eq!(actual_result, true); + } + + #[test] + fn test_get_client_key_from_principal_empty(test_client_principal in any::().prop_map(|_| test_utils::generate_random_principal())) { + let actual_result = get_client_key_from_principal(&test_client_principal); + prop_assert_eq!(actual_result.err(), Some(String::from(format!( + "client with principal {} doesn't have an open connection", + test_client_principal + )))); + } + + #[test] + fn test_get_client_key_from_principal(test_client_key in any::().prop_map(|_| test_utils::get_random_client_key())) { + // Set up + CURRENT_CLIENT_KEY_MAP.with(|map| { + map.borrow_mut().insert(test_client_key.client_principal, test_client_key.clone()); + }); + + let actual_result = get_client_key_from_principal(&test_client_key.client_principal); + prop_assert_eq!(actual_result.unwrap(), test_client_key); + } + + #[test] + fn test_check_registered_client_empty(test_client_key in any::().prop_map(|_| test_utils::get_random_client_key())) { + let actual_result = check_registered_client(&test_client_key); + prop_assert_eq!(actual_result.err(), Some(format!("client with key {} doesn't have an open connection", test_client_key))); + } + + #[test] + fn test_check_registered_client(test_client_key in any::().prop_map(|_| test_utils::get_random_client_key())) { + // Set up + REGISTERED_CLIENTS.with(|map| { + map.borrow_mut().insert(test_client_key.clone(), test_utils::generate_random_registered_client()); + }); + + let actual_result = check_registered_client(&test_client_key); prop_assert!(actual_result.is_ok()); - let non_existing_client_principal = test_utils::generate_random_principal(); - let actual_result = check_registered_client(&non_existing_client_principal); - prop_assert_eq!(actual_result.err(), Some(format!("client with principal {:?} doesn't have an open connection", non_existing_client_principal))); + let non_existing_client_key = test_utils::get_random_client_key(); + let actual_result = check_registered_client(&non_existing_client_key); + prop_assert_eq!(actual_result.err(), Some(format!("client with key {} doesn't have an open connection", non_existing_client_key))); } #[test] - fn test_init_outgoing_message_to_client_num(test_client_principal in any::().prop_map(|_| test_utils::generate_random_principal())) { - init_outgoing_message_to_client_num(test_client_principal.clone()); + fn test_init_outgoing_message_to_client_num(test_client_key in any::().prop_map(|_| test_utils::get_random_client_key())) { + init_outgoing_message_to_client_num(test_client_key.clone()); - let actual_result = OUTGOING_MESSAGE_TO_CLIENT_NUM_MAP.with(|map| map.borrow().get(&test_client_principal).unwrap().clone()); + let actual_result = OUTGOING_MESSAGE_TO_CLIENT_NUM_MAP.with(|map| map.borrow().get(&test_client_key).unwrap().clone()); prop_assert_eq!(actual_result, 0); } #[test] - fn test_increment_outgoing_message_to_client_num(test_client_principal in any::().prop_map(|_| test_utils::generate_random_principal()), test_num in any::()) { + fn test_increment_outgoing_message_to_client_num(test_client_key in any::().prop_map(|_| test_utils::get_random_client_key()), test_num in any::()) { // Set up OUTGOING_MESSAGE_TO_CLIENT_NUM_MAP.with(|map| { - map.borrow_mut().insert(test_client_principal.clone(), test_num); + map.borrow_mut().insert(test_client_key.clone(), test_num); }); - let increment_result = increment_outgoing_message_to_client_num(&test_client_principal); + let increment_result = increment_outgoing_message_to_client_num(&test_client_key); prop_assert!(increment_result.is_ok()); - let actual_result = OUTGOING_MESSAGE_TO_CLIENT_NUM_MAP.with(|map| map.borrow().get(&test_client_principal).unwrap().clone()); + let actual_result = OUTGOING_MESSAGE_TO_CLIENT_NUM_MAP.with(|map| map.borrow().get(&test_client_key).unwrap().clone()); prop_assert_eq!(actual_result, test_num + 1); } #[test] - fn test_get_outgoing_message_to_client_num(test_client_principal in any::().prop_map(|_| test_utils::generate_random_principal()), test_num in any::()) { + fn test_get_outgoing_message_to_client_num(test_client_key in any::().prop_map(|_| test_utils::get_random_client_key()), test_num in any::()) { // Set up OUTGOING_MESSAGE_TO_CLIENT_NUM_MAP.with(|map| { - map.borrow_mut().insert(test_client_principal.clone(), test_num); + map.borrow_mut().insert(test_client_key.clone(), test_num); }); - let actual_result = get_outgoing_message_to_client_num(&test_client_principal); + let actual_result = get_outgoing_message_to_client_num(&test_client_key); prop_assert!(actual_result.is_ok()); prop_assert_eq!(actual_result.unwrap(), test_num); } #[test] - fn test_init_expected_incoming_message_from_client_num(test_client_principal in any::().prop_map(|_| test_utils::generate_random_principal())) { - init_expected_incoming_message_from_client_num(test_client_principal.clone()); + fn test_init_expected_incoming_message_from_client_num(test_client_key in any::().prop_map(|_| test_utils::get_random_client_key())) { + init_expected_incoming_message_from_client_num(test_client_key.clone()); - let actual_result = INCOMING_MESSAGE_FROM_CLIENT_NUM_MAP.with(|map| map.borrow().get(&test_client_principal).unwrap().clone()); + let actual_result = INCOMING_MESSAGE_FROM_CLIENT_NUM_MAP.with(|map| map.borrow().get(&test_client_key).unwrap().clone()); prop_assert_eq!(actual_result, 1); } #[test] - fn test_get_expected_incoming_message_from_client_num(test_client_principal in any::().prop_map(|_| test_utils::generate_random_principal()), test_num in any::()) { + fn test_get_expected_incoming_message_from_client_num(test_client_key in any::().prop_map(|_| test_utils::get_random_client_key()), test_num in any::()) { // Set up INCOMING_MESSAGE_FROM_CLIENT_NUM_MAP.with(|map| { - map.borrow_mut().insert(test_client_principal.clone(), test_num); + map.borrow_mut().insert(test_client_key.clone(), test_num); }); - let actual_result = get_expected_incoming_message_from_client_num(&test_client_principal); + let actual_result = get_expected_incoming_message_from_client_num(&test_client_key); prop_assert!(actual_result.is_ok()); prop_assert_eq!(actual_result.unwrap(), test_num); } #[test] - fn test_increment_expected_incoming_message_from_client_num(test_client_principal in any::().prop_map(|_| test_utils::generate_random_principal()), test_num in any::()) { + fn test_increment_expected_incoming_message_from_client_num(test_client_key in any::().prop_map(|_| test_utils::get_random_client_key()), test_num in any::()) { // Set up INCOMING_MESSAGE_FROM_CLIENT_NUM_MAP.with(|map| { - map.borrow_mut().insert(test_client_principal.clone(), test_num); + map.borrow_mut().insert(test_client_key.clone(), test_num); }); - let increment_result = increment_expected_incoming_message_from_client_num(&test_client_principal); + let increment_result = increment_expected_incoming_message_from_client_num(&test_client_key); prop_assert!(increment_result.is_ok()); - let actual_result = INCOMING_MESSAGE_FROM_CLIENT_NUM_MAP.with(|map| map.borrow().get(&test_client_principal).unwrap().clone()); + let actual_result = INCOMING_MESSAGE_FROM_CLIENT_NUM_MAP.with(|map| map.borrow().get(&test_client_key).unwrap().clone()); prop_assert_eq!(actual_result, test_num + 1); } #[test] - fn test_add_client(test_client_principal in any::().prop_map(|_| test_utils::generate_random_principal())) { + fn test_add_client(test_client_key in any::().prop_map(|_| test_utils::get_random_client_key())) { let registered_client = test_utils::generate_random_registered_client(); // Test - add_client(test_client_principal.clone(), registered_client.clone()); + add_client(test_client_key.clone(), registered_client.clone()); + + let actual_result = CURRENT_CLIENT_KEY_MAP.with(|map| map.borrow().get(&test_client_key.client_principal).unwrap().clone()); + prop_assert_eq!(actual_result, test_client_key.clone()); - let actual_result = REGISTERED_CLIENTS.with(|map| map.borrow().get(&test_client_principal).unwrap().clone()); + let actual_result = REGISTERED_CLIENTS.with(|map| map.borrow().get(&test_client_key).unwrap().clone()); prop_assert_eq!(actual_result, registered_client); - let actual_result = INCOMING_MESSAGE_FROM_CLIENT_NUM_MAP.with(|map| map.borrow().get(&test_client_principal).unwrap().clone()); + let actual_result = INCOMING_MESSAGE_FROM_CLIENT_NUM_MAP.with(|map| map.borrow().get(&test_client_key).unwrap().clone()); prop_assert_eq!(actual_result, 1); - let actual_result = OUTGOING_MESSAGE_TO_CLIENT_NUM_MAP.with(|map| map.borrow().get(&test_client_principal).unwrap().clone()); + let actual_result = OUTGOING_MESSAGE_TO_CLIENT_NUM_MAP.with(|map| map.borrow().get(&test_client_key).unwrap().clone()); prop_assert_eq!(actual_result, 0); } #[test] - fn test_remove_client(test_client_principal in any::().prop_map(|_| test_utils::generate_random_principal())) { + fn test_remove_client(test_client_key in any::().prop_map(|_| test_utils::get_random_client_key())) { // Set up + CURRENT_CLIENT_KEY_MAP.with(|map| { + map.borrow_mut().insert(test_client_key.client_principal.clone(), test_client_key.clone()); + }); REGISTERED_CLIENTS.with(|map| { - map.borrow_mut().insert(test_client_principal.clone(), test_utils::generate_random_registered_client()); + map.borrow_mut().insert(test_client_key.clone(), test_utils::generate_random_registered_client()); }); INCOMING_MESSAGE_FROM_CLIENT_NUM_MAP.with(|map| { - map.borrow_mut().insert(test_client_principal.clone(), 0); + map.borrow_mut().insert(test_client_key.clone(), 0); }); OUTGOING_MESSAGE_TO_CLIENT_NUM_MAP.with(|map| { - map.borrow_mut().insert(test_client_principal.clone(), 0); + map.borrow_mut().insert(test_client_key.clone(), 0); }); - remove_client(&test_client_principal); + remove_client(&test_client_key); + + let is_none = CURRENT_CLIENT_KEY_MAP.with(|map| map.borrow().get(&test_client_key.client_principal).is_none()); + prop_assert!(is_none); - let is_none = REGISTERED_CLIENTS.with(|map| map.borrow().get(&test_client_principal).is_none()); + let is_none = REGISTERED_CLIENTS.with(|map| map.borrow().get(&test_client_key).is_none()); prop_assert!(is_none); - let is_none = INCOMING_MESSAGE_FROM_CLIENT_NUM_MAP.with(|map| map.borrow().get(&test_client_principal).is_none()); + let is_none = INCOMING_MESSAGE_FROM_CLIENT_NUM_MAP.with(|map| map.borrow().get(&test_client_key).is_none()); prop_assert!(is_none); - let is_none = OUTGOING_MESSAGE_TO_CLIENT_NUM_MAP.with(|map| map.borrow().get(&test_client_principal).is_none()); + let is_none = OUTGOING_MESSAGE_TO_CLIENT_NUM_MAP.with(|map| map.borrow().get(&test_client_key).is_none()); prop_assert!(is_none); } @@ -1291,8 +1398,8 @@ mod test { REGISTERED_GATEWAY.with(|p| *p.borrow_mut() = Some(RegisteredGateway::new(gateway_principal.clone()))); let messages_count = 4; - let test_client_principal = test_utils::generate_random_principal(); - test_utils::add_messages_for_gateway(test_client_principal.clone(), gateway_principal, messages_count); + let test_client_key = test_utils::get_random_client_key(); + test_utils::add_messages_for_gateway(test_client_key, gateway_principal, messages_count); // Test // messages are just 4, so we don't exceed the max number of returned messages @@ -1313,8 +1420,8 @@ mod test { REGISTERED_GATEWAY.with(|p| *p.borrow_mut() = Some(RegisteredGateway::new(gateway_principal.clone()))); let messages_count: u64 = (2 * MAX_NUMBER_OF_RETURNED_MESSAGES).try_into().unwrap(); - let test_client_principal = test_utils::generate_random_principal(); - test_utils::add_messages_for_gateway(test_client_principal.clone(), gateway_principal, messages_count); + let test_client_key = test_utils::get_random_client_key(); + test_utils::add_messages_for_gateway(test_client_key, gateway_principal, messages_count); // Test // messages are now 2 * MAX_NUMBER_OF_RETURNED_MESSAGES @@ -1339,8 +1446,8 @@ mod test { // Set up REGISTERED_GATEWAY.with(|p| *p.borrow_mut() = Some(RegisteredGateway::new(gateway_principal.clone()))); - let test_client_principal = test_utils::generate_random_principal(); - test_utils::add_messages_for_gateway(test_client_principal.clone(), gateway_principal, messages_count); + let test_client_key = test_utils::get_random_client_key(); + test_utils::add_messages_for_gateway(test_client_key, gateway_principal, messages_count); // Test let (start_index, end_index) = get_messages_for_gateway_range(gateway_principal, 0); @@ -1361,8 +1468,8 @@ mod test { // Set up REGISTERED_GATEWAY.with(|p| *p.borrow_mut() = Some(RegisteredGateway::new(gateway_principal.clone()))); - let test_client_principal = test_utils::generate_random_principal(); - test_utils::add_messages_for_gateway(test_client_principal.clone(), gateway_principal, messages_count); + let test_client_key = test_utils::get_random_client_key(); + test_utils::add_messages_for_gateway(test_client_key, gateway_principal, messages_count); // Test // add one to test the out of range index @@ -1398,7 +1505,7 @@ mod test { fn test_serialize_websocket_message(test_msg_bytes in any::>(), test_sequence_num in any::(), test_timestamp in any::()) { // TODO: add more tests, in which we check the serialized message let websocket_message = WebsocketMessage { - client_principal: test_utils::generate_random_principal(), + client_key: test_utils::get_random_client_key(), sequence_num: test_sequence_num, timestamp: test_timestamp, is_service_message: false, diff --git a/src/ic-websocket-cdk/ws_types.did b/src/ic-websocket-cdk/ws_types.did index 84ac40d..ac375d7 100644 --- a/src/ic-websocket-cdk/ws_types.did +++ b/src/ic-websocket-cdk/ws_types.did @@ -1,7 +1,11 @@ type ClientPrincipal = principal; +type ClientKey = record { + client_principal : ClientPrincipal; + client_nonce : nat64; +}; type WebsocketMessage = record { - client_principal : ClientPrincipal; + client_key : ClientKey; sequence_num : nat64; timestamp : nat64; is_service_message : bool; @@ -9,7 +13,7 @@ type WebsocketMessage = record { }; type CanisterOutputMessage = record { - client_principal : ClientPrincipal; + client_key : ClientKey; key : text; content : blob; }; @@ -20,7 +24,9 @@ type CanisterOutputCertifiedMessages = record { tree : blob; }; -type CanisterWsOpenArguments = null; +type CanisterWsOpenArguments = record { + client_nonce : nat64; +}; type CanisterWsOpenResult = variant { Ok : null; @@ -28,7 +34,7 @@ type CanisterWsOpenResult = variant { }; type CanisterWsCloseArguments = record { - client_principal : ClientPrincipal; + client_key : ClientKey; }; type CanisterWsCloseResult = variant { From 78fec4b3415c54388161bd30eb244091be20d035 Mon Sep 17 00:00:00 2001 From: Luca8991 Date: Wed, 20 Sep 2023 21:22:34 +0200 Subject: [PATCH 23/31] fix: avoid borrow errors --- src/ic-websocket-cdk/src/lib.rs | 49 +++++++++++++++++++-------------- 1 file changed, 29 insertions(+), 20 deletions(-) diff --git a/src/ic-websocket-cdk/src/lib.rs b/src/ic-websocket-cdk/src/lib.rs index 7c9ba8c..76306c4 100644 --- a/src/ic-websocket-cdk/src/lib.rs +++ b/src/ic-websocket-cdk/src/lib.rs @@ -202,14 +202,16 @@ thread_local! { /// Resets all RefCells to their initial state. /// If there is a registered gateway, resets its state as well. fn reset_internal_state() { - REGISTERED_CLIENTS.with(|state| { + let client_keys_to_remove = REGISTERED_CLIENTS.with(|state| { let map = state.borrow(); - // for each client, call the on_close handler before clearing the map - for (client_key, _) in map.clone().iter() { - remove_client(client_key); - } + map.keys().cloned().collect::>() }); + // for each client, call the on_close handler before clearing the map + for client_key in client_keys_to_remove { + remove_client(&client_key); + } + CURRENT_CLIENT_KEY_MAP.with(|map| { map.borrow_mut().clear(); }); @@ -580,24 +582,31 @@ fn send_ack_to_clients_timer_callback() { /// Checks if the registered clients have sent a keep alive message. /// If a client has not sent a keep alive message, it is removed from the registered clients. fn check_keep_alive_timer_callback() { - REGISTERED_CLIENTS.with(|state| { + let client_keys_to_remove: Vec = REGISTERED_CLIENTS.with(|state| { let map = state.borrow(); - for (client_key, client_metadata) in map.iter() { - let current_time = get_current_time(); - if current_time - client_metadata.get_last_keep_alive_timestamp() - > DEFAULT_CLIENT_KEEP_ALIVE_DELAY_MS - { - remove_client(client_key); - - custom_print!( - "[check-keep-alive-timer-cb]: Client {} has not sent a keep alive message in the last {} ms and has been removed", - client_key, - DEFAULT_CLIENT_KEEP_ALIVE_DELAY_MS - ); - } - } + map.iter() + .filter_map(|(client_key, client_metadata)| { + let current_time = get_current_time(); + let last_keep_alive = client_metadata.get_last_keep_alive_timestamp(); + if current_time - last_keep_alive > DEFAULT_CLIENT_KEEP_ALIVE_DELAY_MS { + Some(client_key.to_owned()) + } else { + None + } + }) + .collect() }); + for client_key in client_keys_to_remove { + remove_client(&client_key); + + custom_print!( + "[check-keep-alive-timer-cb]: Client {} has not sent a keep alive message in the last {} ms and has been removed", + client_key, + DEFAULT_CLIENT_KEEP_ALIVE_DELAY_MS + ); + } + custom_print!("[check-keep-alive-timer-cb]: Checked keep alive messages for all clients"); } From 8576f931b4bddf2e97d5b0864745634a085e18c9 Mon Sep 17 00:00:00 2001 From: Luca8991 Date: Wed, 20 Sep 2023 21:46:19 +0200 Subject: [PATCH 24/31] fix: don't check keep alive sequence number --- src/ic-websocket-cdk/src/lib.rs | 20 ++------------------ 1 file changed, 2 insertions(+), 18 deletions(-) diff --git a/src/ic-websocket-cdk/src/lib.rs b/src/ic-websocket-cdk/src/lib.rs index 76306c4..3a818b5 100644 --- a/src/ic-websocket-cdk/src/lib.rs +++ b/src/ic-websocket-cdk/src/lib.rs @@ -614,24 +614,8 @@ fn handle_keep_alive_client_message(client_key: &ClientKey, content: &[u8]) -> R match decode_one::(content) { Ok(message_content) => { match message_content { - WebsocketServiceMessageContent::KeepAliveMessage(keep_alive_message) => { - // first, we check if the client received the last message sent by the canister - let last_outgoing_message_sequence_num = - get_outgoing_message_to_client_num(client_key)?; - - // if the client has not received the last message sent by the canister, we remove it - if last_outgoing_message_sequence_num - != keep_alive_message.last_incoming_sequence_num - { - custom_print!( - "client {} has not received the last message sent by the canister, removing the client", - client_key - ); - - remove_client(client_key); - - return Ok(()); - } + WebsocketServiceMessageContent::KeepAliveMessage(_keep_alive_message) => { + // TODO: delete messages from the queue that have been acknowledged by the client // update the last keep alive timestamp for the client REGISTERED_CLIENTS.with(|map| { From 5eb26ed71cf9dc57bbfd36fc3758c69fe5d907e5 Mon Sep 17 00:00:00 2001 From: Luca8991 Date: Thu, 21 Sep 2023 11:49:06 +0200 Subject: [PATCH 25/31] refactor: collect and current time --- src/ic-websocket-cdk/src/lib.rs | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/ic-websocket-cdk/src/lib.rs b/src/ic-websocket-cdk/src/lib.rs index 3a818b5..b9c6002 100644 --- a/src/ic-websocket-cdk/src/lib.rs +++ b/src/ic-websocket-cdk/src/lib.rs @@ -202,9 +202,9 @@ thread_local! { /// Resets all RefCells to their initial state. /// If there is a registered gateway, resets its state as well. fn reset_internal_state() { - let client_keys_to_remove = REGISTERED_CLIENTS.with(|state| { + let client_keys_to_remove: Vec = REGISTERED_CLIENTS.with(|state| { let map = state.borrow(); - map.keys().cloned().collect::>() + map.keys().cloned().collect() }); // for each client, call the on_close handler before clearing the map @@ -586,9 +586,8 @@ fn check_keep_alive_timer_callback() { let map = state.borrow(); map.iter() .filter_map(|(client_key, client_metadata)| { - let current_time = get_current_time(); let last_keep_alive = client_metadata.get_last_keep_alive_timestamp(); - if current_time - last_keep_alive > DEFAULT_CLIENT_KEEP_ALIVE_DELAY_MS { + if get_current_time() - last_keep_alive > DEFAULT_CLIENT_KEEP_ALIVE_DELAY_MS { Some(client_key.to_owned()) } else { None From 5be8d03c440c6c85825f7aefc10628b04b56c067 Mon Sep 17 00:00:00 2001 From: Luca8991 Date: Mon, 25 Sep 2023 08:53:18 +0200 Subject: [PATCH 26/31] wip: integration tests --- Cargo.lock | 28 +- tests/integration/canister.test.ts | 1312 ++++++++++----------------- tests/integration/utils/actors.ts | 2 + tests/integration/utils/api.ts | 157 ++-- tests/integration/utils/crypto.ts | 22 - tests/integration/utils/identity.ts | 5 + tests/integration/utils/idl.ts | 67 ++ tests/integration/utils/messages.ts | 31 + tests/integration/utils/random.ts | 20 + tests/package.json | 2 +- tests/src/lib.rs | 17 +- tests/test_canister.did | 3 +- 12 files changed, 711 insertions(+), 955 deletions(-) delete mode 100644 tests/integration/utils/crypto.ts create mode 100644 tests/integration/utils/idl.ts create mode 100644 tests/integration/utils/messages.ts create mode 100644 tests/integration/utils/random.ts diff --git a/Cargo.lock b/Cargo.lock index 5791941..61394ca 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -181,9 +181,9 @@ checksum = "a2bd12c1caf447e69cd4528f47f94d203fd2582878ecb9e9465484c4148a8223" [[package]] name = "candid" -version = "0.9.6" +version = "0.9.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "88f6eec0ae850e006ef0fe306f362884d370624094ec55a6a26de18b251774be" +checksum = "f391a0d11d997af68e1a06b5e2ab354079cecb82b6eefb26addb38adf66d351d" dependencies = [ "anyhow", "binread", @@ -628,9 +628,9 @@ checksum = "2c6201b9ff9fd90a5a3bac2e56a830d0caa509576f0e503818ee82c181b3437a" [[package]] name = "hermit-abi" -version = "0.3.2" +version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "443144c8cdadd93ebf52ddb4056d257f5b52c04d3c804e657d19eb73fc33668b" +checksum = "d77f7ec81a6d05a3abb01ab6eb7590f6083d08449fe5a1c8b1e620283546ccb7" [[package]] name = "hex" @@ -1352,9 +1352,9 @@ checksum = "d626bb9dae77e28219937af045c257c28bfd3f69333c512553507f5f9798cb76" [[package]] name = "rustix" -version = "0.38.13" +version = "0.38.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d7db8590df6dfcd144d22afd1b83b36c21a18d7cbc1dc4bb5295a8712e9eb662" +checksum = "747c788e9ce8e92b12cd485c49ddf90723550b654b32508f979b71a7b1ecda4f" dependencies = [ "bitflags 2.4.0", "errno", @@ -1398,9 +1398,9 @@ dependencies = [ [[package]] name = "rustls-webpki" -version = "0.101.5" +version = "0.101.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "45a27e3b59326c16e23d30aeb7a36a24cc0d29e71d68ff611cdfb4a01d013bed" +checksum = "3c7d5dece342910d9ba34d259310cae3e0154b873b35408b787b59bce53d34fe" dependencies = [ "ring", "untrusted", @@ -1804,9 +1804,9 @@ dependencies = [ [[package]] name = "tokio-util" -version = "0.7.8" +version = "0.7.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "806fe8c2c87eccc8b3267cbae29ed3ab2d0bd37fca70ab622e46aaa9375ddb7d" +checksum = "1d68074620f57a0b21594d9735eb2e98ab38b17f80d3fcb189fca266771ca60d" dependencies = [ "bytes", "futures-core", @@ -1912,9 +1912,9 @@ checksum = "1dd624098567895118886609431a7c3b8f516e41d30e0643f03d94592a147e36" [[package]] name = "unicode-width" -version = "0.1.10" +version = "0.1.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c0edd1e5b14653f783770bce4a4dabb4a5108a5370a5f5d8cfe8710c361f6c8b" +checksum = "e51733f11c9c4f72aa0c160008246859e340b00807569a0da0e7a1079b27ba85" [[package]] name = "untrusted" @@ -2086,9 +2086,9 @@ checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" [[package]] name = "winapi-util" -version = "0.1.5" +version = "0.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "70ec6ce85bb158151cae5e5c87f95a8e97d2c0c4b001223f33a334e3ce5de178" +checksum = "f29e6f9198ba0d26b4c9f07dbe6f9ed633e1f3d5b8b414090084349e46a52596" dependencies = [ "winapi", ] diff --git a/tests/integration/canister.test.ts b/tests/integration/canister.test.ts index 5ab291c..ecf5eaa 100644 --- a/tests/integration/canister.test.ts +++ b/tests/integration/canister.test.ts @@ -2,22 +2,24 @@ import { IDL } from "@dfinity/candid"; import { Principal } from "@dfinity/principal"; import { Cbor } from "@dfinity/agent"; import { + anonymousClient, canisterId, client1, + client1Data, client2, + client2Data, commonAgent, gateway1, gateway2, } from "./utils/actors"; -import { getKeyPair, getMessageSignature } from "./utils/crypto"; import { - getWebsocketMessage, isMessageBodyValid, isValidCertificate, + reinitialize, wsClose, + wsGetMessages, wsMessage, wsOpen, - wsRegister, wsSend, wsWipe, } from "./utils/api"; @@ -27,583 +29,263 @@ import type { CanisterWsGetMessagesResult, CanisterWsMessageResult, CanisterWsOpenResult, - CanisterWsRegisterResult, CanisterWsSendResult, + ClientKey, + WebsocketMessage, } from "../src/declarations/test_canister/test_canister.did"; -import type { WebsocketMessage } from "./utils/api"; +import { generateClientKey, getRandomClientNonce } from "./utils/random"; +import { CanisterOpenMessageContent, WebsocketServiceMessageContent, encodeWebsocketServiceMessageContent, getServiceMessageFromCanisterMessage, isClientKeyEq } from "./utils/idl"; +import { createWebsocketMessage, decodeWebsocketMessage, filterServiceMessagesFromCanisterMessages } from "./utils/messages"; const MAX_NUMBER_OF_RETURNED_MESSAGES = 10; // set in the CDK const SEND_MESSAGES_COUNT = MAX_NUMBER_OF_RETURNED_MESSAGES + 2; // test with more messages to check the indexes and limits const MAX_GATEWAY_KEEP_ALIVE_TIME_MS = 15_000; // set in the CDK +const DEFAULT_TEST_SEND_ACK_INTERVAL_MS = 300_000; // 5 minutes to make sure the canister doesn't reset the client +const DEFAULT_TEST_KEEP_ALIVE_DELAY_MS = 300_000; // 5 minutes to make sure the canister doesn't reset the client -let client1KeyPair: { publicKey: Uint8Array; secretKey: Uint8Array | string; }; -let client2KeyPair: { publicKey: Uint8Array; secretKey: Uint8Array | string; }; +let client1Key: ClientKey; +let client2Key: ClientKey; -// the status index used by the gateway to send a keep-alive message -let gatewayStatusIndex = 0; - -const sendGatewayStatusMessage = async (index?: number) => { - const statusIndex = index !== undefined ? index : gatewayStatusIndex; - - await wsMessage({ - message: { - IcWebSocketGatewayStatus: { - status_index: BigInt(statusIndex), - } - }, - actor: gateway1, - }, true); - - gatewayStatusIndex += 1; -}; - -const assignKeyPairsToClients = async () => { - if (!client1KeyPair) { - client1KeyPair = await getKeyPair(); +const assignKeysToClients = async () => { + if (!client1Key) { + client1Key = generateClientKey((await client1Data.identity).getPrincipal()); } - if (!client2KeyPair) { - client2KeyPair = await getKeyPair(); + if (!client2Key) { + client2Key = generateClientKey((await client2Data.identity).getPrincipal()); } }; // testing again canister takes quite a while jest.setTimeout(60_000); -describe("Canister - ws_register", () => { - beforeAll(async () => { - await assignKeyPairsToClients(); - }); - - afterAll(async () => { - await wsWipe(gateway1); - }); - - it("should register a client", async () => { - const res = await wsRegister({ - clientActor: client1, - clientKey: client1KeyPair.publicKey, - }); - - expect(res).toMatchObject({ - Ok: null, - }); - }); -}); - describe("Canister - ws_open", () => { beforeAll(async () => { - await assignKeyPairsToClients(); - - await wsRegister({ - clientActor: client1, - clientKey: client1KeyPair.publicKey, - }, true); + await assignKeysToClients(); }); afterAll(async () => { - await wsWipe(gateway1); + await wsWipe(); }); - beforeEach(async () => { - await sendGatewayStatusMessage(); - }); - - it("fails for a gateway which is not registered", async () => { + it("fails for an anonymous client", async () => { const res = await wsOpen({ - clientPublicKey: client1KeyPair.publicKey, - clientSecretKey: client1KeyPair.secretKey, canisterId, - gatewayActor: gateway2, - }); + clientActor: anonymousClient, + clientNonce: getRandomClientNonce(), + }) expect(res).toMatchObject({ - Err: "caller is not the gateway that has been registered during CDK initialization", + Err: "anonymous principal cannot open a connection", }); }); - it("fails if a registered gateway relays a wrong first message", async () => { - // empty message - let content = Cbor.encode({}) - let res = await gateway1.ws_open({ - content: new Uint8Array(content), - sig: await getMessageSignature(content, client1KeyPair.secretKey), - }); - expect(res).toMatchObject({ - Err: "missing field `client_key`", - }); - - // with client_key - content = Cbor.encode({ - client_key: client1KeyPair.publicKey, - }); - res = await gateway1.ws_open({ - content: new Uint8Array(content), - sig: await getMessageSignature(content, client1KeyPair.secretKey), - }); - expect(res).toMatchObject({ - Err: "missing field `canister_id`", - }); - }); - - it("fails for a client which is not registered", async () => { + it("fails for the registered gateway", async () => { const res = await wsOpen({ - clientPublicKey: client2KeyPair.publicKey, - clientSecretKey: client2KeyPair.secretKey, canisterId, - gatewayActor: gateway1, + clientActor: gateway1, + clientNonce: getRandomClientNonce(), }); expect(res).toMatchObject({ - Err: "client's public key has not been previously registered by client", + Err: "caller is the registered gateway which can't open a connection for itself", }); }); - it("fails for an invalid signature", async () => { - // sign message with client2 secret key but send client1 public key + it("should open a connection", async () => { const res = await wsOpen({ - clientPublicKey: client1KeyPair.publicKey, - clientSecretKey: client2KeyPair.secretKey, canisterId, - gatewayActor: gateway1, + clientActor: client1, + clientNonce: client1Key.client_nonce, }); expect(res).toMatchObject({ - Err: "Signature doesn't verify", + Ok: null, }); - }); - - it("fails for a client which is not registered after the gateway has been reset", async () => { - await sendGatewayStatusMessage(0); - const res = await wsOpen({ - clientPublicKey: client2KeyPair.publicKey, - clientSecretKey: client2KeyPair.secretKey, - canisterId, + const msgs = await wsGetMessages({ + fromNonce: 0, gatewayActor: gateway1, }); - expect(res).toMatchObject({ - Err: "client's public key has not been previously registered by client", + const serviceMessages = filterServiceMessagesFromCanisterMessages(msgs.messages); + + expect(isClientKeyEq(serviceMessages[0].client_key, client1Key)).toBe(true); + const openMessage = getServiceMessageFromCanisterMessage(serviceMessages[0]); + expect(openMessage).toMatchObject({ + OpenMessage: expect.any(Object), }); + const openMessageContent = (openMessage as { OpenMessage: CanisterOpenMessageContent }).OpenMessage; + expect(isClientKeyEq(openMessageContent.client_key, client1Key)).toBe(true); }); - it("fails for a client which is registered, but after the gateway increased the status index by two and then been reset", async () => { - // reset the canister state from the previous test - await wsWipe(gateway1); - // register the client again - await wsRegister({ - clientActor: client1, - clientKey: client1KeyPair.publicKey, - }, true); - - // send two status messages to make the client key shift out of the tmp ones - await sendGatewayStatusMessage(); - await sendGatewayStatusMessage(); - - // reset the gateway on the canister - await sendGatewayStatusMessage(0); - + it("fails for a client with the same nonce", async () => { const res = await wsOpen({ - clientPublicKey: client1KeyPair.publicKey, - clientSecretKey: client1KeyPair.secretKey, canisterId, - gatewayActor: gateway1, + clientActor: client1, + clientNonce: client1Key.client_nonce, }); expect(res).toMatchObject({ - Err: "client's public key has not been previously registered by client", + Err: `client with key ${client1Key.client_principal.toText()}_${client1Key.client_nonce} already has an open connection`, }); }); - it("should open the websocket for a registered client after gateway has been reset", async () => { - // reset the canister state from the previous test - await wsWipe(gateway1); - // register the client again - await wsRegister({ - clientActor: client1, - clientKey: client1KeyPair.publicKey, - }, true); - - // reset the gateway on the canister - await sendGatewayStatusMessage(0); - + it("should open a connection for the same client with a different nonce", async () => { + const clientKey = { + ...client1Key, + client_nonce: getRandomClientNonce(), + } const res = await wsOpen({ - clientPublicKey: client1KeyPair.publicKey, - clientSecretKey: client1KeyPair.secretKey, canisterId, - gatewayActor: gateway1, + clientActor: client1, + clientNonce: clientKey.client_nonce, }); expect(res).toMatchObject({ - Ok: { - client_key: client1KeyPair.publicKey, - canister_id: Principal.fromText(canisterId), - nonce: BigInt(0), - }, + Ok: null, }); - }); - it("should open the websocket for a registered client", async () => { - // reset the canister state from the previous test - await wsWipe(gateway1); - // setup the canister state again - await sendGatewayStatusMessage(); - await wsRegister({ - clientActor: client1, - clientKey: client1KeyPair.publicKey, - }, true); - - // open the websocket - const res = await wsOpen({ - clientPublicKey: client1KeyPair.publicKey, - clientSecretKey: client1KeyPair.secretKey, - canisterId, + const msgs = await wsGetMessages({ + fromNonce: 0, gatewayActor: gateway1, }); - expect(res).toMatchObject({ - Ok: { - client_key: client1KeyPair.publicKey, - canister_id: Principal.fromText(canisterId), - nonce: BigInt(0), - }, + const serviceMessages = filterServiceMessagesFromCanisterMessages(msgs.messages); + const serviceMessagesForClient = serviceMessages.filter((msg) => isClientKeyEq(msg.client_key, clientKey)); + + const openMessage = getServiceMessageFromCanisterMessage(serviceMessagesForClient[0]); + expect(openMessage).toMatchObject({ + OpenMessage: expect.any(Object), }); + const openMessageContent = (openMessage as { OpenMessage: CanisterOpenMessageContent }).OpenMessage; + expect(isClientKeyEq(openMessageContent.client_key, clientKey)).toBe(true); }); }); describe("Canister - ws_message", () => { beforeAll(async () => { - await assignKeyPairsToClients(); - - await wsRegister({ - clientActor: client1, - clientKey: client1KeyPair.publicKey, - }, true); + await assignKeysToClients(); await wsOpen({ - clientPublicKey: client1KeyPair.publicKey, - clientSecretKey: client1KeyPair.secretKey, + clientNonce: client1Key.client_nonce, canisterId, - gatewayActor: gateway1, + clientActor: client1, }, true); }); afterAll(async () => { - await wsWipe(gateway1); - }); - - beforeEach(async () => { - await sendGatewayStatusMessage(); + await wsWipe(); }); - it("fails if a non registered gateway sends an IcWebSocketEstablished message", async () => { - const res = await wsMessage({ - message: { - IcWebSocketEstablished: client1KeyPair.publicKey, - }, - actor: gateway2, - }); - - expect(res).toMatchObject({ - Err: "caller is not the gateway that has been registered during CDK initialization", - }); - }); - - it("fails if a non registered gateway sends a RelayedByGateway message", async () => { - const content = getWebsocketMessage(client1KeyPair.publicKey, 0); - const res = await wsMessage({ - message: { - RelayedByGateway: { - content, - sig: await getMessageSignature(content, client1KeyPair.secretKey), - } - }, - actor: gateway2, - }); - - expect(res).toMatchObject({ - Err: "caller is not the gateway that has been registered during CDK initialization", - }); - }); - - it("fails if a non registered client sends a DirectlyFromClient message", async () => { - const message = IDL.encode([IDL.Record({ 'text': IDL.Text })], [{ text: "pong" }]); - const res = await wsMessage({ - message: { - DirectlyFromClient: { - client_key: client2KeyPair.publicKey, - message: new Uint8Array(message), - } - }, - actor: client2, - }); - - expect(res).toMatchObject({ - Err: "client is not registered, call ws_register first", - }); - }); - - it("fails if a non registered client sends a DirectlyFromClient message using a registered client key", async () => { - const message = IDL.encode([IDL.Record({ 'text': IDL.Text })], [{ text: "pong" }]); + it("fails if client is not registered", async () => { const res = await wsMessage({ - message: { - DirectlyFromClient: { - client_key: client1KeyPair.publicKey, - message: new Uint8Array(message), - } - }, + message: createWebsocketMessage(client2Key, 0), actor: client2, }); expect(res).toMatchObject({ - Err: "caller is not the same that registered the public key", + Err: `client with principal ${client2Key.client_principal.toText()} doesn't have an open connection`, }); }); - it("fails if a registered gateway sends an IcWebSocketEstablished message for a non registered client", async () => { + it("fails if client sends a message with a different client key", async () => { + // first, send a message with a different principal const res = await wsMessage({ - message: { - IcWebSocketEstablished: client2KeyPair.publicKey, - }, - actor: gateway1, - }); - - expect(res).toMatchObject({ - Err: "client's public key has not been previously registered by client", - }); - }); - - it("fails if a registered gateway sends a wrong RelayedByGateway message", async () => { - // empty message - let content = Cbor.encode({}); - let res = await wsMessage({ - message: { - RelayedByGateway: { - content: new Uint8Array(content), - sig: await getMessageSignature(content, client2KeyPair.secretKey), - }, - }, - actor: gateway1, - }); - expect(res).toMatchObject({ - Err: "missing field `client_key`", + message: createWebsocketMessage({ ...client1Key, client_principal: client2Key.client_principal }, 0), + actor: client1, }); - // with client_key - content = Cbor.encode({ - client_key: client1KeyPair.publicKey, - }); - res = await wsMessage({ - message: { - RelayedByGateway: { - content: new Uint8Array(content), - sig: await getMessageSignature(content, client2KeyPair.secretKey), - }, - }, - actor: gateway1, - }); expect(res).toMatchObject({ - Err: "missing field `sequence_num`", + Err: `client with principal ${client1Key.client_principal.toText()} has a different key than the one used in the message`, }); - // with client_key, sequence_num - content = Cbor.encode({ - client_key: client1KeyPair.publicKey, - sequence_num: 0, - }); - res = await wsMessage({ - message: { - RelayedByGateway: { - content: new Uint8Array(content), - sig: await getMessageSignature(content, client2KeyPair.secretKey), - }, - }, - actor: gateway1, - }); - expect(res).toMatchObject({ - Err: "missing field `timestamp`", + // then, send a message with a different nonce + const res2 = await wsMessage({ + message: createWebsocketMessage({ ...client1Key, client_nonce: getRandomClientNonce() }, 0), + actor: client1, }); - // with client_key, sequence_num, timestamp - content = Cbor.encode({ - client_key: client1KeyPair.publicKey, - sequence_num: 0, - timestamp: Date.now(), - }); - res = await wsMessage({ - message: { - RelayedByGateway: { - content: new Uint8Array(content), - sig: await getMessageSignature(content, client2KeyPair.secretKey), - }, - }, - actor: gateway1, - }); - expect(res).toMatchObject({ - Err: "missing field `message`", + expect(res2).toMatchObject({ + Err: `client with principal ${client1Key.client_principal.toText()} has a different key than the one used in the message`, }); }); - it("fails if a registered gateway sends a RelayedByGateway message with an invalid signature", async () => { - const content = getWebsocketMessage(client1KeyPair.publicKey, 0); + it("should send a message from a registered client", async () => { const res = await wsMessage({ - message: { - RelayedByGateway: { - content, - sig: await getMessageSignature(content, client2KeyPair.secretKey), - } - }, - actor: gateway1, + message: createWebsocketMessage(client1Key, 1), + actor: client1, }); expect(res).toMatchObject({ - Err: "Signature doesn't verify", + Ok: null, }); }); - it("fails if registered gateway sends a RelayedByGateway message with a wrong sequence number", async () => { - const appMessage = IDL.Record({ 'text': IDL.Text }).encodeValue({ text: "pong" }); - let content = getWebsocketMessage(client1KeyPair.publicKey, 1, appMessage); - let res = await wsMessage({ - message: { - RelayedByGateway: { - content, - sig: await getMessageSignature(content, client1KeyPair.secretKey), - } - }, - actor: gateway1, - }); - expect(res).toMatchObject({ - Err: "incoming client's message relayed from WS Gateway does not have the expected sequence number", + it("fails if client sends a message with a wrong sequence number", async () => { + const actualSequenceNumber = 1; + const expectedSequenceNumber = 2; // first valid message with sequence number 1 was sent in the previous test + const res = await wsMessage({ + message: createWebsocketMessage(client1Key, actualSequenceNumber), + actor: client1, }); - // send a correct message to increase the sequence number - content = getWebsocketMessage(client1KeyPair.publicKey, 0, appMessage); - res = await wsMessage({ - message: { - RelayedByGateway: { - content, - sig: await getMessageSignature(content, client1KeyPair.secretKey), - } - }, - actor: gateway1, - }); expect(res).toMatchObject({ - Ok: null, + Err: `incoming client's message does not have the expected sequence number. Expected: ${expectedSequenceNumber}, actual: ${actualSequenceNumber}. Client removed.`, }); - // send a message with the old sequence number - content = getWebsocketMessage(client1KeyPair.publicKey, 0, appMessage); - res = await wsMessage({ - message: { - RelayedByGateway: { - content, - sig: await getMessageSignature(content, client1KeyPair.secretKey), - } - }, - actor: gateway1, - }); - expect(res).toMatchObject({ - Err: "incoming client's message relayed from WS Gateway does not have the expected sequence number", + // check if client has been removed + const res2 = await wsMessage({ + message: createWebsocketMessage(client1Key, 0), // here the sequence number doesn't matter + actor: client1, }); - // send a message with a sequence number that is too high - content = getWebsocketMessage(client1KeyPair.publicKey, 2, appMessage); - res = await wsMessage({ - message: { - RelayedByGateway: { - content, - sig: await getMessageSignature(content, client1KeyPair.secretKey), - } - }, - actor: gateway1, - }); - expect(res).toMatchObject({ - Err: "incoming client's message relayed from WS Gateway does not have the expected sequence number", + expect(res2).toMatchObject({ + Err: `client with principal ${client1Key.client_principal.toText()} doesn't have an open connection`, }); }); - it("fails if a registered gateway sends a RelayedByGateway for a registered client that doesn't have open connection", async () => { - // register another client, but don't call ws_open for it - await wsRegister({ - clientActor: client2, - clientKey: client2KeyPair.publicKey, + it("fails if a client sends a wrong service message", async () => { + // open the connection again + await wsOpen({ + clientNonce: client1Key.client_nonce, + canisterId, + clientActor: client1, }, true); - const content = getWebsocketMessage(client2KeyPair.publicKey, 0); + // wring content encoding const res = await wsMessage({ - message: { - RelayedByGateway: { - content, - sig: await getMessageSignature(content, client2KeyPair.secretKey), - } - }, - actor: gateway1, + message: createWebsocketMessage(client1Key, 1, true, new Uint8Array([1, 2, 3])), + actor: client1, }); expect(res).toMatchObject({ - Err: "expected incoming message num not initialized for client", + Err: expect.stringContaining("Error decoding service message from client:"), }); - }); - it("fails if registered gateway sends a DirectlyFromClient message", async () => { - const message = IDL.encode([IDL.Record({ 'text': IDL.Text })], [{ text: "pong" }]); - const res = await wsMessage({ - message: { - DirectlyFromClient: { - client_key: client1KeyPair.publicKey, - message: new Uint8Array(message), - } - }, - actor: gateway1, + const wrongServiceMessage: WebsocketServiceMessageContent = { + // the client can only send KeepAliveMessage variant + AckMessage: { + last_incoming_sequence_num: BigInt(0), + } + }; + const res2 = await wsMessage({ + message: createWebsocketMessage(client1Key, 2, true, encodeWebsocketServiceMessageContent(wrongServiceMessage)), + actor: client1, }); - expect(res).toMatchObject({ - Err: "caller is not the same that registered the public key", + expect(res2).toMatchObject({ + Err: "invalid keep alive message content", }); }); - it("a registered gateway should send a message (IcWebSocketEstablished) for a registered client", async () => { - const res = await wsMessage({ - message: { - IcWebSocketEstablished: client1KeyPair.publicKey, + it("should send a service message from a registered client", async () => { + const clientServiceMessage: WebsocketServiceMessageContent = { + KeepAliveMessage: { + last_incoming_sequence_num: BigInt(0), }, - actor: gateway1, - }); - - expect(res).toMatchObject({ - Ok: null, - }); - }); - - it("a registered gateway should send a message (RelayedByGateway) for a registered client", async () => { - const appMessage = IDL.encode([IDL.Record({ 'text': IDL.Text })], [{ text: "pong" }]); - // the message with sequence number 0 has been sent in a previous test, so we send a message with sequence number 1 - const content = getWebsocketMessage(client1KeyPair.publicKey, 1, appMessage); + }; const res = await wsMessage({ - message: { - RelayedByGateway: { - content, - sig: await getMessageSignature(content, client1KeyPair.secretKey), - } - }, - actor: gateway1, - }); - - expect(res).toMatchObject({ - Ok: null, - }); - }); - - it("a registered client should send a message (DirectlyFromClient)", async () => { - const message = IDL.encode([IDL.Record({ 'text': IDL.Text })], [{ text: "pong" }]); - const res = await wsMessage({ - message: { - DirectlyFromClient: { - client_key: client1KeyPair.publicKey, - message: new Uint8Array(message), - } - }, + message: createWebsocketMessage(client1Key, 3, true, encodeWebsocketServiceMessageContent(clientServiceMessage)), actor: client1, }); @@ -614,32 +296,6 @@ describe("Canister - ws_message", () => { }); describe("Canister - ws_get_messages (failures,empty)", () => { - beforeAll(async () => { - await assignKeyPairsToClients(); - - await wsRegister({ - clientActor: client1, - clientKey: client1KeyPair.publicKey, - }, true); - - await wsOpen({ - clientPublicKey: client1KeyPair.publicKey, - clientSecretKey: client1KeyPair.secretKey, - canisterId, - gatewayActor: gateway1, - }, true); - - await commonAgent.fetchRootKey(); - }); - - afterAll(async () => { - await wsWipe(gateway1); - }); - - beforeEach(async () => { - await sendGatewayStatusMessage(); - }); - it("fails if a non registered gateway tries to get messages", async () => { const res = await gateway2.ws_get_messages({ nonce: BigInt(0), @@ -677,245 +333,245 @@ describe("Canister - ws_get_messages (failures,empty)", () => { }); }); -describe("Canister - ws_message (gateway status)", () => { +// describe("Canister - ws_message (gateway status)", () => { +// beforeAll(async () => { +// await assignKeysToClients(); + +// await wsRegister({ +// clientActor: client1, +// clientKey: client1Key.publicKey, +// }, true); + +// await wsOpen({ +// clientPublicKey: client1Key.publicKey, +// clientSecretKey: client1Key.secretKey, +// canisterId, +// clientActor: gateway1, +// }, true); + +// await wsSend({ +// clientKey: client1Key.publicKey, +// actor: client1, +// message: { text: "test" }, +// }, true); +// }); + +// afterAll(async () => { +// await wsWipe(gateway1); +// }); + +// it("fails if a non registered gateway sends an IcWebSocketGatewayStatus message", async () => { +// const res = await wsMessage({ +// message: { +// IcWebSocketGatewayStatus: { +// status_index: BigInt(1), +// }, +// }, +// actor: gateway2, +// }); + +// expect(res).toMatchObject({ +// Err: "caller is not the gateway that has been registered during CDK initialization", +// }); +// }); + +// it("registered gateway should update the status index", async () => { +// const res = await wsMessage({ +// message: { +// IcWebSocketGatewayStatus: { +// status_index: BigInt(2), // set it high to test behavior for indexes behind the current one +// }, +// }, +// actor: gateway1, +// }); + +// expect(res).toMatchObject({ +// Ok: null, +// }); +// }); + +// it("fails if a registered gateway sends an IcWebSocketGatewayStatus with a wrong status index (equal to current)", async () => { +// const res = await wsMessage({ +// message: { +// IcWebSocketGatewayStatus: { +// status_index: BigInt(2), +// }, +// }, +// actor: gateway1, +// }); + +// expect(res).toMatchObject({ +// Err: "Gateway status index is equal to or behind the current one", +// }); +// }); + +// it("fails if a registered gateway sends an IcWebSocketGatewayStatus with a wrong status index (behind the current)", async () => { +// const res = await wsMessage({ +// message: { +// IcWebSocketGatewayStatus: { +// status_index: BigInt(1), +// }, +// }, +// actor: gateway1, +// }); + +// expect(res).toMatchObject({ +// Err: "Gateway status index is equal to or behind the current one", +// }); +// }); + +// it("registered gateway should disconnect after maximum time", async () => { +// let res = await gateway1.ws_get_messages({ +// nonce: BigInt(0), +// }); + +// expect(res).toMatchObject({ +// Ok: { +// messages: expect.any(Array), +// cert: expect.any(Uint8Array), +// tree: expect.any(Uint8Array), +// }, +// }); +// expect((res as { Ok: CanisterOutputCertifiedMessages }).Ok.messages.length).toEqual(1); + +// // wait for the maximum time the gateway can send a status message, +// // so that the internal canister state is reset +// // double the time to make sure the canister state is reset +// await new Promise((resolve) => setTimeout(resolve, 2 * MAX_GATEWAY_KEEP_ALIVE_TIME_MS)); + +// // check if messages have been deleted +// res = await gateway1.ws_get_messages({ +// nonce: BigInt(0), +// }); +// expect(res).toMatchObject({ +// Ok: { +// messages: [], +// cert: expect.any(Uint8Array), +// tree: expect.any(Uint8Array), +// }, +// }); + +// // check if registered client has been deleted +// const sendRes = await wsSend({ +// clientKey: client1Key.publicKey, +// actor: client1, +// message: { text: "test" }, +// }); +// expect(sendRes).toMatchObject({ +// Err: "client's public key has not been previously registered by client", +// }); +// }); + +// it("registered gateway should reconnect by resetting the status index", async () => { +// let res = await wsMessage({ +// message: { +// IcWebSocketGatewayStatus: { +// status_index: BigInt(0), +// }, +// }, +// actor: gateway1, +// }); + +// expect(res).toMatchObject({ +// Ok: null, +// }); + +// res = await wsMessage({ +// message: { +// IcWebSocketGatewayStatus: { +// status_index: BigInt(1), +// }, +// }, +// actor: gateway1, +// }); + +// expect(res).toMatchObject({ +// Ok: null, +// }); +// }); + +// it("registered gateway should reconnect before maximum time", async () => { +// // reconnect the client +// await wsRegister({ +// clientActor: client1, +// clientKey: client1Key.publicKey, +// }, true); + +// await wsOpen({ +// clientPublicKey: client1Key.publicKey, +// clientSecretKey: client1Key.secretKey, +// canisterId, +// clientActor: gateway1, +// }, true); + +// // send a test message from the canister to check if the internal state is reset +// await wsSend({ +// clientKey: client1Key.publicKey, +// actor: client1, +// message: { text: "test" }, +// }, true); + +// // check if the canister has the message in the queue +// let messagesRes = await gateway1.ws_get_messages({ +// nonce: BigInt(0), +// }); +// expect(messagesRes).toMatchObject({ +// Ok: { +// messages: expect.any(Array), +// cert: expect.any(Uint8Array), +// tree: expect.any(Uint8Array), +// }, +// }); +// expect((messagesRes as { Ok: CanisterOutputCertifiedMessages }).Ok.messages.length).toEqual(1); + +// // simulate a reconnection +// const res = await wsMessage({ +// message: { +// IcWebSocketGatewayStatus: { +// status_index: BigInt(0), +// }, +// }, +// actor: gateway1, +// }); +// expect(res).toMatchObject({ +// Ok: null, +// }); + +// // check if the canister reset the internal state +// messagesRes = await gateway1.ws_get_messages({ +// nonce: BigInt(0), +// }); +// expect(messagesRes).toMatchObject({ +// Ok: { +// messages: [], +// cert: expect.any(Uint8Array), +// tree: expect.any(Uint8Array), +// }, +// }); +// }); +// }); + +describe.only("Canister - ws_get_messages (receive)", () => { beforeAll(async () => { - await assignKeyPairsToClients(); - - await wsRegister({ - clientActor: client1, - clientKey: client1KeyPair.publicKey, - }, true); + await assignKeysToClients(); - await wsOpen({ - clientPublicKey: client1KeyPair.publicKey, - clientSecretKey: client1KeyPair.secretKey, - canisterId, - gatewayActor: gateway1, - }, true); - - await wsSend({ - clientPublicKey: client1KeyPair.publicKey, - actor: client1, - message: { text: "test" }, - }, true); - }); - - afterAll(async () => { - await wsWipe(gateway1); - }); - - it("fails if a non registered gateway sends an IcWebSocketGatewayStatus message", async () => { - const res = await wsMessage({ - message: { - IcWebSocketGatewayStatus: { - status_index: BigInt(1), - }, - }, - actor: gateway2, - }); - - expect(res).toMatchObject({ - Err: "caller is not the gateway that has been registered during CDK initialization", - }); - }); - - it("registered gateway should update the status index", async () => { - const res = await wsMessage({ - message: { - IcWebSocketGatewayStatus: { - status_index: BigInt(2), // set it high to test behavior for indexes behind the current one - }, - }, - actor: gateway1, - }); - - expect(res).toMatchObject({ - Ok: null, - }); - }); - - it("fails if a registered gateway sends an IcWebSocketGatewayStatus with a wrong status index (equal to current)", async () => { - const res = await wsMessage({ - message: { - IcWebSocketGatewayStatus: { - status_index: BigInt(2), - }, - }, - actor: gateway1, - }); - - expect(res).toMatchObject({ - Err: "Gateway status index is equal to or behind the current one", - }); - }); - - it("fails if a registered gateway sends an IcWebSocketGatewayStatus with a wrong status index (behind the current)", async () => { - const res = await wsMessage({ - message: { - IcWebSocketGatewayStatus: { - status_index: BigInt(1), - }, - }, - actor: gateway1, - }); - - expect(res).toMatchObject({ - Err: "Gateway status index is equal to or behind the current one", - }); - }); - - it("registered gateway should disconnect after maximum time", async () => { - let res = await gateway1.ws_get_messages({ - nonce: BigInt(0), - }); - - expect(res).toMatchObject({ - Ok: { - messages: expect.any(Array), - cert: expect.any(Uint8Array), - tree: expect.any(Uint8Array), - }, - }); - expect((res as { Ok: CanisterOutputCertifiedMessages }).Ok.messages.length).toEqual(1); - - // wait for the maximum time the gateway can send a status message, - // so that the internal canister state is reset - // double the time to make sure the canister state is reset - await new Promise((resolve) => setTimeout(resolve, 2 * MAX_GATEWAY_KEEP_ALIVE_TIME_MS)); - - // check if messages have been deleted - res = await gateway1.ws_get_messages({ - nonce: BigInt(0), - }); - expect(res).toMatchObject({ - Ok: { - messages: [], - cert: expect.any(Uint8Array), - tree: expect.any(Uint8Array), - }, + // reset the internal timers + await reinitialize({ + sendAckIntervalMs: DEFAULT_TEST_SEND_ACK_INTERVAL_MS, + keepAliveDelayMs: DEFAULT_TEST_KEEP_ALIVE_DELAY_MS, }); - // check if registered client has been deleted - const sendRes = await wsSend({ - clientPublicKey: client1KeyPair.publicKey, - actor: client1, - message: { text: "test" }, - }); - expect(sendRes).toMatchObject({ - Err: "client's public key has not been previously registered by client", - }); - }); - - it("registered gateway should reconnect by resetting the status index", async () => { - let res = await wsMessage({ - message: { - IcWebSocketGatewayStatus: { - status_index: BigInt(0), - }, - }, - actor: gateway1, - }); - - expect(res).toMatchObject({ - Ok: null, - }); - - res = await wsMessage({ - message: { - IcWebSocketGatewayStatus: { - status_index: BigInt(1), - }, - }, - actor: gateway1, - }); - - expect(res).toMatchObject({ - Ok: null, - }); - }); - - it("registered gateway should reconnect before maximum time", async () => { - // reconnect the client - await wsRegister({ - clientActor: client1, - clientKey: client1KeyPair.publicKey, - }, true); - await wsOpen({ - clientPublicKey: client1KeyPair.publicKey, - clientSecretKey: client1KeyPair.secretKey, + clientNonce: client1Key.client_nonce, canisterId, - gatewayActor: gateway1, - }, true); - - // send a test message from the canister to check if the internal state is reset - await wsSend({ - clientPublicKey: client1KeyPair.publicKey, - actor: client1, - message: { text: "test" }, - }, true); - - // check if the canister has the message in the queue - let messagesRes = await gateway1.ws_get_messages({ - nonce: BigInt(0), - }); - expect(messagesRes).toMatchObject({ - Ok: { - messages: expect.any(Array), - cert: expect.any(Uint8Array), - tree: expect.any(Uint8Array), - }, - }); - expect((messagesRes as { Ok: CanisterOutputCertifiedMessages }).Ok.messages.length).toEqual(1); - - // simulate a reconnection - const res = await wsMessage({ - message: { - IcWebSocketGatewayStatus: { - status_index: BigInt(0), - }, - }, - actor: gateway1, - }); - expect(res).toMatchObject({ - Ok: null, - }); - - // check if the canister reset the internal state - messagesRes = await gateway1.ws_get_messages({ - nonce: BigInt(0), - }); - expect(messagesRes).toMatchObject({ - Ok: { - messages: [], - cert: expect.any(Uint8Array), - tree: expect.any(Uint8Array), - }, - }); - }); -}); - -describe("Canister - ws_get_messages (receive)", () => { - beforeAll(async () => { - await assignKeyPairsToClients(); - - await wsRegister({ clientActor: client1, - clientKey: client1KeyPair.publicKey, - }, true); - - await wsOpen({ - clientPublicKey: client1KeyPair.publicKey, - clientSecretKey: client1KeyPair.secretKey, - canisterId, - gatewayActor: gateway1, }, true); // prepare the messages for (let i = 0; i < SEND_MESSAGES_COUNT; i++) { const appMessage = { text: `test${i}` }; await wsSend({ - clientPublicKey: client1KeyPair.publicKey, + clientPrincipal: client1Key.client_principal, actor: client1, message: appMessage, }, true); @@ -925,15 +581,13 @@ describe("Canister - ws_get_messages (receive)", () => { }); afterAll(async () => { - await wsWipe(gateway1); - }); - - beforeEach(async () => { - await sendGatewayStatusMessage(); + await wsWipe(); }); it("registered gateway can receive correct amount of messages", async () => { - for (let i = 0; i < SEND_MESSAGES_COUNT; i++) { + // on open, the canister puts a service message in the queue + const messagesCount = SEND_MESSAGES_COUNT + 1; // +1 for the service message + for (let i = 0; i < messagesCount; i++) { const res = await gateway1.ws_get_messages({ nonce: BigInt(i), }); @@ -948,15 +602,15 @@ describe("Canister - ws_get_messages (receive)", () => { const messagesResult = (res as { Ok: CanisterOutputCertifiedMessages }).Ok; expect(messagesResult.messages.length).toBe( - SEND_MESSAGES_COUNT - i > MAX_NUMBER_OF_RETURNED_MESSAGES + messagesCount - i > MAX_NUMBER_OF_RETURNED_MESSAGES ? MAX_NUMBER_OF_RETURNED_MESSAGES - : SEND_MESSAGES_COUNT - i + : messagesCount - i ); } // try to get more messages than available const res = await gateway1.ws_get_messages({ - nonce: BigInt(SEND_MESSAGES_COUNT), + nonce: BigInt(messagesCount), }); expect(res).toMatchObject({ @@ -971,21 +625,25 @@ describe("Canister - ws_get_messages (receive)", () => { it("registered gateway can receive certified messages", async () => { // first batch of messages const firstBatchRes = await gateway1.ws_get_messages({ - nonce: BigInt(0), + nonce: BigInt(1), }); const firstBatchMessagesResult = (firstBatchRes as { Ok: CanisterOutputCertifiedMessages }).Ok; + console.log(firstBatchMessagesResult.messages.map((msg) => msg.key)); for (let i = 0; i < firstBatchMessagesResult.messages.length; i++) { const message = firstBatchMessagesResult.messages[i]; - expect(message.client_key).toEqual(client1KeyPair.publicKey); - const decodedContent = Cbor.decode(new Uint8Array(message.content)); - expect(decodedContent).toMatchObject({ - client_key: client1KeyPair.publicKey, - message: expect.any(Uint8Array), - sequence_num: i + 1, - timestamp: expect.any(Object), // weird timestamp deserialization + expect(isClientKeyEq(message.client_key, client1Key)).toEqual(true); + const websocketMessage = decodeWebsocketMessage(new Uint8Array(message.content)); + console.log(websocketMessage); + expect(websocketMessage).toMatchObject({ + client_key: expect.any(Object), + content: expect.any(Uint8Array), + sequence_num: BigInt(i + 1), + timestamp: expect.any(Object), // weird cbor bigint deserialization + is_service_message: false, }); - expect(IDL.decode([IDL.Record({ 'text': IDL.Text })], decodedContent.message as Uint8Array)).toEqual([{ text: `test${i}` }]); + expect(isClientKeyEq(websocketMessage.client_key, client1Key)).toEqual(true); + expect(IDL.decode([IDL.Record({ 'text': IDL.Text })], websocketMessage.content as Uint8Array)).toEqual([{ text: `test${i}` }]); // check the certification await expect( @@ -1013,15 +671,17 @@ describe("Canister - ws_get_messages (receive)", () => { const secondBatchMessagesResult = (secondBatchRes as { Ok: CanisterOutputCertifiedMessages }).Ok; for (let i = 0; i < secondBatchMessagesResult.messages.length; i++) { const message = secondBatchMessagesResult.messages[i]; - expect(message.client_key).toEqual(client1KeyPair.publicKey); - const decodedContent = Cbor.decode(new Uint8Array(message.content)); - expect(decodedContent).toMatchObject({ - client_key: client1KeyPair.publicKey, - message: expect.any(Uint8Array), - sequence_num: i + MAX_NUMBER_OF_RETURNED_MESSAGES + 1, - timestamp: expect.any(Object), // weird timestamp deserialization + expect(isClientKeyEq(message.client_key, client1Key)).toEqual(true); + const websocketMessage = decodeWebsocketMessage(new Uint8Array(message.content)); + expect(websocketMessage).toMatchObject({ + client_key: expect.any(Object), + content: expect.any(Uint8Array), + sequence_num: BigInt(i + MAX_NUMBER_OF_RETURNED_MESSAGES + 1), + timestamp: expect.any(Object), // weird cbor bigint deserialization + is_service_message: false, }); - expect(IDL.decode([IDL.Record({ 'text': IDL.Text })], decodedContent.message as Uint8Array)).toEqual([{ text: `test${i + MAX_NUMBER_OF_RETURNED_MESSAGES}` }]); + expect(isClientKeyEq(websocketMessage.client_key, client1Key)).toEqual(true); + expect(IDL.decode([IDL.Record({ 'text': IDL.Text })], websocketMessage.content as Uint8Array)).toEqual([{ text: `test${i + MAX_NUMBER_OF_RETURNED_MESSAGES}` }]); // check the certification await expect( @@ -1043,107 +703,103 @@ describe("Canister - ws_get_messages (receive)", () => { }); }); -describe("Canister - ws_close", () => { - beforeAll(async () => { - await assignKeyPairsToClients(); - - await wsRegister({ - clientActor: client1, - clientKey: client1KeyPair.publicKey, - }, true); - - await wsOpen({ - clientPublicKey: client1KeyPair.publicKey, - clientSecretKey: client1KeyPair.secretKey, - canisterId, - gatewayActor: gateway1, - }, true); - }); - - afterAll(async () => { - await wsWipe(gateway1); - }); - - beforeEach(async () => { - await sendGatewayStatusMessage(); - }); - - it("fails if gateway is not registered", async () => { - const res = await wsClose({ - clientPublicKey: client1KeyPair.publicKey, - gatewayActor: gateway2, - }); - - expect(res).toMatchObject({ - Err: "caller is not the gateway that has been registered during CDK initialization", - }); - }); - - it("fails if client is not registered", async () => { - const res = await wsClose({ - clientPublicKey: client2KeyPair.publicKey, - gatewayActor: gateway1, - }); - - expect(res).toMatchObject({ - Err: "client's public key has not been previously registered by client", - }); - }); - - it("should close the websocket for a registered client", async () => { - const res = await wsClose({ - clientPublicKey: client1KeyPair.publicKey, - gatewayActor: gateway1, - }); - - expect(res).toMatchObject({ - Ok: null, - }); - }); -}); - -describe("Canister - ws_send", () => { - beforeAll(async () => { - await assignKeyPairsToClients(); - - await wsRegister({ - clientActor: client1, - clientKey: client1KeyPair.publicKey, - }, true); - - await wsOpen({ - clientPublicKey: client1KeyPair.publicKey, - clientSecretKey: client1KeyPair.secretKey, - canisterId, - gatewayActor: gateway1, - }, true); - }); - - afterAll(async () => { - await wsWipe(gateway1); - }); - - it("fails if sending a message to a non registered client", async () => { - const res = await wsSend({ - clientPublicKey: client2KeyPair.publicKey, - actor: client1, - message: { text: "test" }, - }); - - expect(res).toMatchObject({ - Err: "client's public key has not been previously registered by client", - }); - }); - - it("should send a message to a registered client", async () => { - const res = await wsSend({ - clientPublicKey: client1KeyPair.publicKey, - actor: client1, - message: { text: "test" }, - }); - - expect(res).toMatchObject({ - Ok: null, - }); - }); -}); +// describe("Canister - ws_close", () => { +// beforeAll(async () => { +// await assignKeysToClients(); + +// await wsRegister({ +// clientActor: client1, +// clientKey: client1Key.publicKey, +// }, true); + +// await wsOpen({ +// clientPublicKey: client1Key.publicKey, +// clientSecretKey: client1Key.secretKey, +// canisterId, +// clientActor: gateway1, +// }, true); +// }); + +// afterAll(async () => { +// await wsWipe(gateway1); +// }); + +// it("fails if gateway is not registered", async () => { +// const res = await wsClose({ +// clientPublicKey: client1Key.publicKey, +// gatewayActor: gateway2, +// }); + +// expect(res).toMatchObject({ +// Err: "caller is not the gateway that has been registered during CDK initialization", +// }); +// }); + +// it("fails if client is not registered", async () => { +// const res = await wsClose({ +// clientPublicKey: client2Key.publicKey, +// gatewayActor: gateway1, +// }); + +// expect(res).toMatchObject({ +// Err: "client's public key has not been previously registered by client", +// }); +// }); + +// it("should close the websocket for a registered client", async () => { +// const res = await wsClose({ +// clientPublicKey: client1Key.publicKey, +// gatewayActor: gateway1, +// }); + +// expect(res).toMatchObject({ +// Ok: null, +// }); +// }); +// }); + +// describe("Canister - ws_send", () => { +// beforeAll(async () => { +// await assignKeysToClients(); + +// await wsRegister({ +// clientActor: client1, +// clientKey: client1Key.publicKey, +// }, true); + +// await wsOpen({ +// clientPublicKey: client1Key.publicKey, +// clientSecretKey: client1Key.secretKey, +// canisterId, +// clientActor: gateway1, +// }, true); +// }); + +// afterAll(async () => { +// await wsWipe(gateway1); +// }); + +// it("fails if sending a message to a non registered client", async () => { +// const res = await wsSend({ +// clientKey: client2Key.publicKey, +// actor: client1, +// message: { text: "test" }, +// }); + +// expect(res).toMatchObject({ +// Err: "client's public key has not been previously registered by client", +// }); +// }); + +// it("should send a message to a registered client", async () => { +// const res = await wsSend({ +// clientKey: client1Key.publicKey, +// actor: client1, +// message: { text: "test" }, +// }); + +// expect(res).toMatchObject({ +// Ok: null, +// }); +// }); +// }); diff --git a/tests/integration/utils/actors.ts b/tests/integration/utils/actors.ts index 46635a5..525be15 100644 --- a/tests/integration/utils/actors.ts +++ b/tests/integration/utils/actors.ts @@ -84,3 +84,5 @@ export const client2 = createActor(canisterId, { identity: client2Data.identity, }, }); + +export const anonymousClient = createActor(canisterId); diff --git a/tests/integration/utils/api.ts b/tests/integration/utils/api.ts index ef33d37..a380a87 100644 --- a/tests/integration/utils/api.ts +++ b/tests/integration/utils/api.ts @@ -4,124 +4,120 @@ import { ActorSubclass, Cbor, Certificate, HashTree, HttpAgent, compare, lookup_ import { Secp256k1KeyIdentity } from "@dfinity/identity-secp256k1"; import { Principal } from "@dfinity/principal"; import { IDL } from "@dfinity/candid"; -import { getMessageSignature } from "./crypto"; -import type { CanisterIncomingMessage, ClientPublicKey, _SERVICE } from "../../src/declarations/test_canister/test_canister.did"; +import { anonymousClient, gateway1Data } from "./actors"; +import type { CanisterOutputCertifiedMessages, ClientKey, ClientPrincipal, WebsocketMessage, _SERVICE } from "../../src/declarations/test_canister/test_canister.did"; -type WsRegisterArgs = { - clientActor: ActorSubclass<_SERVICE>, - clientKey: Uint8Array, +type GenericResult = { + Ok: T, +} | { + Err: string, }; -export const wsRegister = async (args: WsRegisterArgs, throwIfError = false) => { - const res = await args.clientActor.ws_register({ - client_key: args.clientKey, - }); - - if (throwIfError) { - if ('Err' in res) { - throw new Error(res.Err); - } +const resolveResult = (result: GenericResult, throwIfError: boolean) => { + if (throwIfError && 'Err' in result) { + throw new Error(result.Err); } - return res; + return result; }; type WsOpenArgs = { - clientPublicKey: Uint8Array, - clientSecretKey: Uint8Array | string, + clientNonce: bigint, canisterId: string, - gatewayActor: ActorSubclass<_SERVICE>, -}; - -export type CanisterOpenMessageContent = { - client_key: ClientPublicKey, - canister_id: Principal, + clientActor: ActorSubclass<_SERVICE>, }; +/** + * Sends an update call to the canister to the **ws_open** method, using the provided actor. + * @param args {@link WsOpenArgs} + * @param throwIfError whether to throw if the result is an error (defaults to `false`) + * @returns the result of the **ws_open** method + */ export const wsOpen = async (args: WsOpenArgs, throwIfError = false) => { - const firstMessage: CanisterOpenMessageContent = { - client_key: args.clientPublicKey, - canister_id: Principal.fromText(args.canisterId), - }; - const contentBuf = new Uint8Array(Cbor.encode(firstMessage)); - const sig = await getMessageSignature(contentBuf, args.clientSecretKey); - - const res = await args.gatewayActor.ws_open({ - content: contentBuf, - sig, + const res = await args.clientActor.ws_open({ + client_nonce: args.clientNonce, }); - if (throwIfError) { - if ('Err' in res) { - throw new Error(res.Err); - } - } - - return res; + return resolveResult(res, throwIfError); }; type WsMessageArgs = { - message: CanisterIncomingMessage, + message: WebsocketMessage, actor: ActorSubclass<_SERVICE>, }; +/** + * Sends an update call to the canister to the **ws_message** method, using the provided actor. + * @param args {@link WsMessageArgs} + * @param throwIfError whether to throw if the result is an error (defaults to `false`) + * @returns the result of the **ws_message** method + */ export const wsMessage = async (args: WsMessageArgs, throwIfError = false) => { const res = await args.actor.ws_message({ msg: args.message, }); - if (throwIfError) { - if ('Err' in res) { - throw new Error(res.Err); - } - } - - return res; + return resolveResult(res, throwIfError); }; -export type WebsocketMessage = { - client_key: ClientPublicKey, - sequence_num: number, - timestamp: number, - message: ArrayBuffer | Uint8Array, +type WsCloseArgs = { + clientKey: ClientKey, + gatewayActor: ActorSubclass<_SERVICE>, }; -export const getWebsocketMessage = (clientPublicKey: ClientPublicKey, sequenceNumber: number, content?: ArrayBuffer | Uint8Array): Uint8Array => { - const websocketMessage: WebsocketMessage = { - client_key: clientPublicKey, - sequence_num: sequenceNumber, - timestamp: Date.now(), - message: new Uint8Array(content || []), - }; +/** + * Sends an update call to the canister to the **ws_close** method, using the provided gateway actor. + * @param args {@link WsCloseArgs} + * @param throwIfError whether to throw if the result is an error (defaults to `false`) + * @returns the result of the **ws_close** method + */ +export const wsClose = async (args: WsCloseArgs, throwIfError = false) => { + const res = await args.gatewayActor.ws_close({ + client_key: args.clientKey, + }); - return new Uint8Array(Cbor.encode(websocketMessage)); + return resolveResult(res, throwIfError); }; -type WsCloseArgs = { - clientPublicKey: Uint8Array, +type WsGetMessagesArgs = { + fromNonce: number, gatewayActor: ActorSubclass<_SERVICE>, }; -export const wsClose = async (args: WsCloseArgs, throwIfError = false) => { - const res = await args.gatewayActor.ws_close({ - client_key: args.clientPublicKey, +/** + * Sends a query call to the canister to the **ws_get_messages** method, using the provided gateway actor. + * @param args {@link WsGetMessagesArgs} + */ +export const wsGetMessages = async (args: WsGetMessagesArgs): Promise => { + const res = await args.gatewayActor.ws_get_messages({ + nonce: BigInt(args.fromNonce), }); - if (throwIfError) { - if ('Err' in res) { - throw new Error(res.Err); - } - } + const messages = resolveResult(res, true); - return res; + return (messages as { Ok: CanisterOutputCertifiedMessages }).Ok; }; -export const wsWipe = async (gatewayActor: ActorSubclass<_SERVICE>) => { - await gatewayActor.ws_wipe(); +export const wsWipe = async () => { + await anonymousClient.ws_wipe(); +}; + +type ReinitializeArgs = { + sendAckIntervalMs: number, + keepAliveDelayMs: number, +}; + +/** + * Used to reinitialize the canister with the provided intervals. + * @param args {@link ReinitializeArgs} + */ +export const reinitialize = async (args: ReinitializeArgs) => { + const gatewayPrincipal = (await gateway1Data.identity).getPrincipal().toText(); + await anonymousClient.reinitialize(gatewayPrincipal, BigInt(args.sendAckIntervalMs), BigInt(args.keepAliveDelayMs)); }; type WsSendArgs = { - clientPublicKey: Uint8Array, + clientPrincipal: ClientPrincipal, actor: ActorSubclass<_SERVICE>, message: { text: string, @@ -130,15 +126,9 @@ type WsSendArgs = { export const wsSend = async (args: WsSendArgs, throwIfError = false) => { const msgBytes = IDL.encode([IDL.Record({ 'text': IDL.Text })], [args.message]); - const res = await args.actor.ws_send(args.clientPublicKey, new Uint8Array(msgBytes)); + const res = await args.actor.ws_send(args.clientPrincipal, new Uint8Array(msgBytes)); - if (throwIfError) { - if ('Err' in res) { - throw new Error(res.Err); - } - } - - return res; + return resolveResult(res, throwIfError); }; export const getCertifiedMessageKey = async (gatewayIdentity: Promise, nonce: number) => { @@ -146,7 +136,6 @@ export const getCertifiedMessageKey = async (gatewayIdentity: Promise { const canisterPrincipal = Principal.fromText(canisterId); let cert: Certificate; diff --git a/tests/integration/utils/crypto.ts b/tests/integration/utils/crypto.ts deleted file mode 100644 index a46d653..0000000 --- a/tests/integration/utils/crypto.ts +++ /dev/null @@ -1,22 +0,0 @@ -import * as ed from "@noble/ed25519"; - -export const getKeyPair = async (secretKey?: string | Uint8Array): Promise<{ publicKey: Uint8Array, secretKey: Uint8Array | string }> => { - if (!secretKey) { - secretKey = ed.utils.randomPrivateKey(); - } - - const publicKey = await ed.getPublicKeyAsync(secretKey); - - return { - publicKey, - secretKey, - }; -}; - -export const getMessageSignature = async (buf: ArrayBuffer | Uint8Array, secretKey: Uint8Array | string): Promise => { - // Sign the message so that the gateway can verify canister and client ids match - const toSign = new Uint8Array(buf); - const sig = await ed.signAsync(toSign, secretKey); - - return sig; -} diff --git a/tests/integration/utils/identity.ts b/tests/integration/utils/identity.ts index 21a1d8b..cfac2ee 100644 --- a/tests/integration/utils/identity.ts +++ b/tests/integration/utils/identity.ts @@ -9,3 +9,8 @@ export const identityFromSeed = async (phrase: string) => { return Secp256k1KeyIdentity.generate(addrnode.privateKey); }; + +export const generateRandomIdentity = async () => { + const mnemonic = bip39.generateMnemonic(); + return identityFromSeed(mnemonic); +}; diff --git a/tests/integration/utils/idl.ts b/tests/integration/utils/idl.ts new file mode 100644 index 0000000..317486e --- /dev/null +++ b/tests/integration/utils/idl.ts @@ -0,0 +1,67 @@ +import { IDL } from "@dfinity/candid"; +import { CanisterOutputMessage, ClientKey, WebsocketMessage } from "../../src/declarations/test_canister/test_canister.did"; +import { Cbor } from "@dfinity/agent"; + +export const ClientPrincipalIdl = IDL.Principal; +export const ClientKeyIdl = IDL.Record({ + 'client_principal': ClientPrincipalIdl, + 'client_nonce': IDL.Nat64, +}); + +export type CanisterOpenMessageContent = { + 'client_key': ClientKey, +}; +export type CanisterAckMessageContent = { + 'last_incoming_sequence_num': bigint, +}; +export type ClientKeepAliveMessageContent = { + 'last_incoming_sequence_num': bigint, +}; +export type WebsocketServiceMessageContent = { + OpenMessage: CanisterOpenMessageContent, +} | { + AckMessage: CanisterAckMessageContent, +} | { + KeepAliveMessage: ClientKeepAliveMessageContent, +}; + +export const CanisterOpenMessageContentIdl = IDL.Record({ + 'client_key': ClientKeyIdl, +}); +export const CanisterAckMessageContentIdl = IDL.Record({ + 'last_incoming_sequence_num': IDL.Nat64, +}); +export const ClientKeepAliveMessageContentIdl = IDL.Record({ + 'last_incoming_sequence_num': IDL.Nat64, +}); +export const WebsocketServiceMessageContentIdl = IDL.Variant({ + 'OpenMessage': CanisterOpenMessageContentIdl, + 'AckMessage': CanisterAckMessageContentIdl, + 'KeepAliveMessage': ClientKeepAliveMessageContentIdl, +}); + +export const decodeWebsocketServiceMessageContent = (bytes: Uint8Array): WebsocketServiceMessageContent => { + const decoded = IDL.decode([WebsocketServiceMessageContentIdl], bytes); + if (decoded.length !== 1) { + throw new Error("Invalid CanisterServiceMessage"); + } + return decoded[0] as unknown as WebsocketServiceMessageContent; +}; + +export const encodeWebsocketServiceMessageContent = (msg: WebsocketServiceMessageContent): Uint8Array => { + return new Uint8Array(IDL.encode([WebsocketServiceMessageContentIdl], [msg])); +}; + +export const isClientKeyEq = (a: ClientKey, b: ClientKey): boolean => { + return a.client_principal.compareTo(b.client_principal) === "eq" && a.client_nonce === b.client_nonce; +} + +export const getServiceMessageFromCanisterMessage = (msg: CanisterOutputMessage): WebsocketServiceMessageContent => { + const content = getWebsocketMessageFromCanisterMessage(msg).content; + return decodeWebsocketServiceMessageContent(content as Uint8Array); +} + +export const getWebsocketMessageFromCanisterMessage = (msg: CanisterOutputMessage): WebsocketMessage => { + const websocketMessage: WebsocketMessage = Cbor.decode(msg.content as Uint8Array); + return websocketMessage; +} diff --git a/tests/integration/utils/messages.ts b/tests/integration/utils/messages.ts new file mode 100644 index 0000000..993e36c --- /dev/null +++ b/tests/integration/utils/messages.ts @@ -0,0 +1,31 @@ +import { Cbor } from "@dfinity/agent"; +import { CanisterOutputMessage, ClientKey, WebsocketMessage } from "../../src/declarations/test_canister/test_canister.did"; +import { getWebsocketMessageFromCanisterMessage } from "./idl"; + +export const filterServiceMessagesFromCanisterMessages = (messages: CanisterOutputMessage[]): CanisterOutputMessage[] => { + return messages.filter((msg) => { + const websocketMessage = getWebsocketMessageFromCanisterMessage(msg); + return websocketMessage.is_service_message; + }); +}; + +export const createWebsocketMessage = ( + clientKey: ClientKey, + sequenceNumber: number, + isServiceMessage = false, + content?: ArrayBuffer | Uint8Array +): WebsocketMessage => { + const websocketMessage: WebsocketMessage = { + client_key: clientKey, + sequence_num: BigInt(sequenceNumber), + timestamp: BigInt(Date.now()) * BigInt(1_000_000), // in nanoseconds + content: new Uint8Array(content || []), + is_service_message: isServiceMessage, + }; + + return websocketMessage; +}; + +export const decodeWebsocketMessage = (bytes: Uint8Array): WebsocketMessage => { + return Cbor.decode(bytes); +}; diff --git a/tests/integration/utils/random.ts b/tests/integration/utils/random.ts new file mode 100644 index 0000000..d469381 --- /dev/null +++ b/tests/integration/utils/random.ts @@ -0,0 +1,20 @@ +import { ClientKey, ClientPrincipal } from "../../src/declarations/test_canister/test_canister.did"; +import { generateRandomIdentity } from "./identity"; + +export const getRandomClientNonce = (): bigint => { + const array = new BigUint64Array(1); + globalThis.crypto.getRandomValues(array); + return array[0]; +}; + +export const generateClientKey = (clientPrincipal: ClientPrincipal): ClientKey => { + return { + client_principal: clientPrincipal, + client_nonce: getRandomClientNonce(), + }; +}; + +export const getRandomPrincipal = async (): Promise => { + const identity = await generateRandomIdentity(); + return identity.getPrincipal(); +}; diff --git a/tests/package.json b/tests/package.json index 240d132..7313deb 100644 --- a/tests/package.json +++ b/tests/package.json @@ -10,7 +10,7 @@ ], "scripts": { "generate": "dfx generate test_canister", - "deploy:tests": "dfx deploy test_canister --no-wallet --argument '(\"i3gux-m3hwt-5mh2w-t7wwm-fwx5j-6z6ht-hxguo-t4rfw-qp24z-g5ivt-2qe\")'", + "deploy:tests": "dfx deploy test_canister --no-wallet --argument '(\"i3gux-m3hwt-5mh2w-t7wwm-fwx5j-6z6ht-hxguo-t4rfw-qp24z-g5ivt-2qe\", 300_000 : nat64, 300_000 : nat64)'", "test:integration": "jest integration" }, "devDependencies": { diff --git a/tests/src/lib.rs b/tests/src/lib.rs index 9e09999..26932a3 100644 --- a/tests/src/lib.rs +++ b/tests/src/lib.rs @@ -11,7 +11,7 @@ use ic_websocket_cdk::{ mod canister; #[init] -fn init(gateway_principal: String) { +fn init(gateway_principal: String, send_ack_interval_ms: u64, keep_alive_delay_ms: u64) { let handlers = WsHandlers { on_open: Some(on_open), on_message: Some(on_message), @@ -21,16 +21,16 @@ fn init(gateway_principal: String) { let params = WsInitParams { handlers, gateway_principal, - send_ack_interval_ms: 10_000, - keep_alive_delay_ms: 5_000, + send_ack_interval_ms, + keep_alive_delay_ms, }; ic_websocket_cdk::init(params) } #[post_upgrade] -fn post_upgrade(gateway_principal: String) { - init(gateway_principal); +fn post_upgrade(gateway_principal: String, send_ack_interval_ms: u64, keep_alive_delay_ms: u64) { + init(gateway_principal, send_ack_interval_ms, keep_alive_delay_ms); } // method called by the WS Gateway after receiving FirstMessage from the client @@ -69,3 +69,10 @@ fn ws_wipe() { fn ws_send(client_principal: ClientPrincipal, msg_bytes: Vec) -> CanisterWsSendResult { ic_websocket_cdk::ws_send(client_principal, msg_bytes) } + +// reinitialize the canister +#[update] +fn reinitialize(gateway_principal: String, send_ack_interval_ms: u64, keep_alive_delay_ms: u64) { + ic_websocket_cdk::wipe(); + init(gateway_principal, send_ack_interval_ms, keep_alive_delay_ms); +} diff --git a/tests/test_canister.did b/tests/test_canister.did index a98d79e..da066d3 100644 --- a/tests/test_canister.did +++ b/tests/test_canister.did @@ -5,7 +5,7 @@ type CanisterWsSendResult = variant { Err : text; }; -service : (text) -> { +service : (text, nat64, nat64) -> { "ws_open" : (CanisterWsOpenArguments) -> (CanisterWsOpenResult); "ws_close" : (CanisterWsCloseArguments) -> (CanisterWsCloseResult); "ws_message" : (CanisterWsMessageArguments) -> (CanisterWsMessageResult); @@ -14,4 +14,5 @@ service : (text) -> { // methods used just for debugging/testing "ws_wipe" : () -> (); "ws_send" : (ClientPrincipal, blob) -> (CanisterWsSendResult); + "reinitialize" : (text, nat64, nat64) -> (); }; From 825fabf505432d7467785b49fe59bffde0d8f345 Mon Sep 17 00:00:00 2001 From: Luca8991 Date: Mon, 25 Sep 2023 08:56:58 +0200 Subject: [PATCH 27/31] fix: checks order, no multiple handlers calls --- src/ic-websocket-cdk/src/lib.rs | 42 ++++++++++++++------------------- 1 file changed, 18 insertions(+), 24 deletions(-) diff --git a/src/ic-websocket-cdk/src/lib.rs b/src/ic-websocket-cdk/src/lib.rs index b9c6002..159972d 100644 --- a/src/ic-websocket-cdk/src/lib.rs +++ b/src/ic-websocket-cdk/src/lib.rs @@ -200,7 +200,6 @@ thread_local! { } /// Resets all RefCells to their initial state. -/// If there is a registered gateway, resets its state as well. fn reset_internal_state() { let client_keys_to_remove: Vec = REGISTERED_CLIENTS.with(|state| { let map = state.borrow(); @@ -353,7 +352,7 @@ fn increment_expected_incoming_message_from_client_num( fn add_client(client_key: ClientKey, new_client: RegisteredClient) { // insert the client in the map insert_client(client_key.clone(), new_client); - // initialize incoming client's message sequence number to 0 + // initialize incoming client's message sequence number to 1 init_expected_incoming_message_from_client_num(client_key.clone()); // initialize outgoing message sequence number to 0 init_outgoing_message_to_client_num(client_key); @@ -423,7 +422,7 @@ fn get_messages_for_gateway(start_index: usize, end_index: usize) -> Vec CanisterWsGetMessagesResult { let (start_index, end_index) = get_messages_for_gateway_range(gateway_principal, nonce); let messages = get_messages_for_gateway(start_index, end_index); @@ -814,11 +813,18 @@ pub fn init(params: WsInitParams) { /// Handles the WS connection open event sent by the client and relayed by the Gateway. pub fn ws_open(args: CanisterWsOpenArguments) -> CanisterWsOpenResult { let client_principal = caller(); - // TODO: test + // anonymous clients cannot open a connection if client_principal == ClientPrincipal::anonymous() { return Err(String::from("anonymous principal cannot open a connection")); } + // avoid gateway opening a connection for its own principal + if is_registered_gateway(client_principal) { + return Err(String::from( + "caller is the registered gateway which can't open a connection for itself", + )); + } + let client_key = ClientKey::new(client_principal, args.client_nonce); // check if client is not registered yet if is_client_registered(&client_key) { @@ -828,13 +834,6 @@ pub fn ws_open(args: CanisterWsOpenArguments) -> CanisterWsOpenResult { )); } - // avoid gateway opening a connection for its own principal - if is_registered_gateway(client_principal) { - return Err(String::from( - "caller is the registered gateway, cannot open a connection", - )); - } - // initialize client maps let new_client = RegisteredClient::new(); add_client(client_key.clone(), new_client); @@ -864,12 +863,6 @@ pub fn ws_close(args: CanisterWsCloseArguments) -> CanisterWsCloseResult { remove_client(&args.client_key); - HANDLERS.with(|h| { - h.borrow().call_on_close(OnCloseCallbackArgs { - client_principal: args.client_key.client_principal, - }); - }); - Ok(()) } @@ -887,10 +880,10 @@ pub fn ws_message(args: CanisterWsMessageArguments) -> CanisterWsMessageResult { content, } = args.msg; - // check if the client is registered with the same nonce as the one used in the message - if registered_client_key.client_nonce != client_key.client_nonce { + // check if the client key is correct + if registered_client_key != client_key { return Err(String::from(format!( - "client with principal {} has a different nonce than the one used in the message", + "client with principal {} has a different key than the one used in the message", client_principal ))); } @@ -899,11 +892,12 @@ pub fn ws_message(args: CanisterWsMessageArguments) -> CanisterWsMessageResult { // check if the incoming message has the expected sequence number if sequence_num != expected_sequence_num { - custom_print!( - "incoming client's message does not have the expected sequence number. Expected: {expected_sequence_num}, actual: {sequence_num}. Removing client..." - ); remove_client(&client_key); - return Ok(()); + return Err(String::from( + format!( + "incoming client's message does not have the expected sequence number. Expected: {expected_sequence_num}, actual: {sequence_num}. Client removed.", + ), + )); } // increase the expected sequence number by 1 increment_expected_incoming_message_from_client_num(&client_key)?; From 18a892055d866d602038db3ba6e1f1a1102eeac7 Mon Sep 17 00:00:00 2001 From: Luca8991 Date: Mon, 25 Sep 2023 15:42:10 +0200 Subject: [PATCH 28/31] fix: remove ack messages logic --- src/ic-websocket-cdk/src/lib.rs | 133 +------------------------------- 1 file changed, 2 insertions(+), 131 deletions(-) diff --git a/src/ic-websocket-cdk/src/lib.rs b/src/ic-websocket-cdk/src/lib.rs index 159972d..2816771 100644 --- a/src/ic-websocket-cdk/src/lib.rs +++ b/src/ic-websocket-cdk/src/lib.rs @@ -1,15 +1,13 @@ -use candid::{decode_one, encode_one, CandidType, Principal}; +use candid::{encode_one, CandidType, Principal}; #[cfg(not(test))] use ic_cdk::api::time; use ic_cdk::api::{caller, data_certificate, set_certified_data}; -use ic_cdk_timers::set_timer; use ic_certified_map::{labeled, labeled_hash, AsHashTree, Hash as ICHash, RbTree}; use serde::{Deserialize, Serialize}; use serde_cbor::Serializer; use sha2::{Digest, Sha256}; use std::fmt; use std::panic; -use std::time::Duration; use std::{cell::RefCell, collections::HashMap, collections::VecDeque, convert::AsRef}; mod logger; @@ -160,16 +158,6 @@ impl RegisteredClient { last_keep_alive_timestamp: get_current_time(), } } - - /// Gets the last keep alive timestamp. - fn get_last_keep_alive_timestamp(&self) -> u64 { - self.last_keep_alive_timestamp - } - - /// Set the last keep alive timestamp to the current time. - fn update_last_keep_alive_timestamp(&mut self) { - self.last_keep_alive_timestamp = get_current_time(); - } } thread_local! { @@ -521,119 +509,6 @@ fn send_service_message_to_client( _ws_send(client_key, message_bytes, true) } -/// Schedules a timer to send an acknowledgement message to the client. -/// -/// The timer callback is [send_ack_to_clients_timer_callback]. After the callback is executed, -/// a timer is scheduled to check if the registered clients have sent a keep alive message. -fn schedule_send_ack_to_clients(ack_interval_ms: u64, check_interval_ms: u64) { - set_timer(Duration::from_millis(ack_interval_ms), move || { - send_ack_to_clients_timer_callback(); - - schedule_check_keep_alive(ack_interval_ms, check_interval_ms); - }); -} - -/// Schedules a timer to check if the registered clients have sent a keep alive message -/// after receiving an acknowledgement message. -/// -/// The timer callback is [check_keep_alive_timer_callback]. After the callback is executed, -/// a timer is scheduled again to send an acknowledgement message to the registered clients. -fn schedule_check_keep_alive(ack_interval_ms: u64, check_interval_ms: u64) { - set_timer(Duration::from_millis(check_interval_ms), move || { - check_keep_alive_timer_callback(); - - schedule_send_ack_to_clients(ack_interval_ms, check_interval_ms); - }); -} - -/// Sends an acknowledgement message to the client. -/// The message contains the current incoming message sequence number for that client, -/// so that the client knows that all the messages it sent have been received by the canister. -fn send_ack_to_clients_timer_callback() { - REGISTERED_CLIENTS.with(|state| { - let map = state.borrow(); - for (client_key, _) in map.iter() { - // ignore the error, which shouldn't happen since the client is registered and the sequence number is initialized - if let Ok(last_incoming_message_sequence_num) = - get_expected_incoming_message_from_client_num(client_key) - { - let ack_message = CanisterAckMessageContent { - last_incoming_sequence_num: last_incoming_message_sequence_num, - }; - let message = WebsocketServiceMessageContent::AckMessage(ack_message); - if let Err(e) = send_service_message_to_client(client_key, message) { - // TODO: decide what to do when sending the message fails - custom_print!( - "[ack-to-clients-timer-cb]: Error sending ack message to client {}: {:?}", - client_key, - e - ); - - break; - }; - } - } - }); - - custom_print!("[ack-to-clients-timer-cb]: Sent ack messages to all clients"); -} - -/// Checks if the registered clients have sent a keep alive message. -/// If a client has not sent a keep alive message, it is removed from the registered clients. -fn check_keep_alive_timer_callback() { - let client_keys_to_remove: Vec = REGISTERED_CLIENTS.with(|state| { - let map = state.borrow(); - map.iter() - .filter_map(|(client_key, client_metadata)| { - let last_keep_alive = client_metadata.get_last_keep_alive_timestamp(); - if get_current_time() - last_keep_alive > DEFAULT_CLIENT_KEEP_ALIVE_DELAY_MS { - Some(client_key.to_owned()) - } else { - None - } - }) - .collect() - }); - - for client_key in client_keys_to_remove { - remove_client(&client_key); - - custom_print!( - "[check-keep-alive-timer-cb]: Client {} has not sent a keep alive message in the last {} ms and has been removed", - client_key, - DEFAULT_CLIENT_KEEP_ALIVE_DELAY_MS - ); - } - - custom_print!("[check-keep-alive-timer-cb]: Checked keep alive messages for all clients"); -} - -fn handle_keep_alive_client_message(client_key: &ClientKey, content: &[u8]) -> Result<(), String> { - match decode_one::(content) { - Ok(message_content) => { - match message_content { - WebsocketServiceMessageContent::KeepAliveMessage(_keep_alive_message) => { - // TODO: delete messages from the queue that have been acknowledged by the client - - // update the last keep alive timestamp for the client - REGISTERED_CLIENTS.with(|map| { - let mut map = map.borrow_mut(); - let client_metadata = map.get_mut(client_key).unwrap(); - client_metadata.update_last_keep_alive_timestamp(); - }); - - Ok(()) - }, - _ => Err(String::from("invalid keep alive message content")), - } - }, - Err(e) => Err(format!( - "Error decoding service message from client: {:?}", - e - )), - } -} - /// Internal function used to put the messages in the outgoing messages queue and certify them. fn _ws_send( client_key: &ClientKey, @@ -804,10 +679,6 @@ pub fn init(params: WsInitParams) { // set the principal of the (only) WS Gateway that will be polling the canister initialize_registered_gateway(¶ms.gateway_principal); - - // schedule a timer that will send an acknowledgement message to clients - // TODO: test - schedule_send_ack_to_clients(params.send_ack_interval_ms, params.keep_alive_delay_ms); } /// Handles the WS connection open event sent by the client and relayed by the Gateway. @@ -904,7 +775,7 @@ pub fn ws_message(args: CanisterWsMessageArguments) -> CanisterWsMessageResult { // TODO: test if is_service_message { - return handle_keep_alive_client_message(&client_key, &content); + custom_print!("Service message handling not implemented yet"); } // call the on_message handler initialized in init() From 3c2ec90f42b41a2ea34fa89173c6b25f41d5f22f Mon Sep 17 00:00:00 2001 From: Luca8991 Date: Mon, 9 Oct 2023 13:45:13 +0200 Subject: [PATCH 29/31] chore: unit test panic handling --- Cargo.lock | 255 +++++++++++++++++++------------- src/ic-websocket-cdk/src/lib.rs | 43 +++++- 2 files changed, 190 insertions(+), 108 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 61394ca..baab2d4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -169,9 +169,9 @@ checksum = "7f30e7476521f6f8af1a1c4c0b8cc94f0bee37d91763d0ca2665f299b6cd8aec" [[package]] name = "byteorder" -version = "1.4.3" +version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "14c189c53d098945499cdfa7ecc63567cf3886b3332b312a5b4585d8d3a6a610" +checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" [[package]] name = "bytes" @@ -181,9 +181,9 @@ checksum = "a2bd12c1caf447e69cd4528f47f94d203fd2582878ecb9e9465484c4148a8223" [[package]] name = "candid" -version = "0.9.7" +version = "0.9.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f391a0d11d997af68e1a06b5e2ab354079cecb82b6eefb26addb38adf66d351d" +checksum = "aa0f00717c71b8e9ee4c090b4880ec2418c8506bb6828a2c72df72d3896e905d" dependencies = [ "anyhow", "binread", @@ -201,7 +201,7 @@ dependencies = [ "pretty", "serde", "serde_bytes", - "sha2 0.10.7", + "sha2 0.10.8", "stacker", "thiserror", ] @@ -215,7 +215,7 @@ dependencies = [ "lazy_static", "proc-macro2", "quote", - "syn 2.0.37", + "syn 2.0.38", ] [[package]] @@ -249,6 +249,22 @@ version = "0.9.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "28c122c3980598d243d63d9a704629a2d748d101f278052ff068be5a4423ab6f" +[[package]] +name = "core-foundation" +version = "0.9.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "194a7a9e6de53fa55116934067c844d9d749312f75c6f6d0980e8c252f8c2146" +dependencies = [ + "core-foundation-sys", + "libc", +] + +[[package]] +name = "core-foundation-sys" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e496a50fda8aacccc86d7529e2c1e0892dbd0f898a6b5645b5561b89c3210efa" + [[package]] name = "cpufeatures" version = "0.2.9" @@ -355,9 +371,9 @@ checksum = "a26ae43d7bcc3b814de94796a5e736d4029efb0ee900c12e2d54c993ad1a1e07" [[package]] name = "elliptic-curve" -version = "0.13.5" +version = "0.13.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "968405c8fdc9b3bf4df0a6638858cc0b52462836ab6b1c87377785dd09cf1c0b" +checksum = "d97ca172ae9dc9f9b779a6e3a65d308f2af74e5b8c921299075bdb4a0370e914" dependencies = [ "base16ct", "crypto-bigint", @@ -390,30 +406,19 @@ checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5" [[package]] name = "errno" -version = "0.3.3" +version = "0.3.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "136526188508e25c6fef639d7927dfb3e0e3084488bf202267829cf7fc23dbdd" +checksum = "ac3e13f66a2f95e32a39eaa81f6b95d42878ca0e1db0c7543723dfe12557e860" dependencies = [ - "errno-dragonfly", "libc", "windows-sys", ] -[[package]] -name = "errno-dragonfly" -version = "0.1.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "aa68f1b12764fab894d2755d2518754e71b4fd80ecfb822714a1206c2aab39bf" -dependencies = [ - "cc", - "libc", -] - [[package]] name = "fastrand" -version = "2.0.0" +version = "2.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6999dc1837253364c2ebb0704ba97994bd874e8f195d665c50b7548f6ea92764" +checksum = "25cbce373ec4653f1a01a31e8a5e5ec0c622dc27ff9c4e6606eefef5cbbed4a5" [[package]] name = "ff" @@ -506,7 +511,7 @@ checksum = "89ca545a94061b6365f2c7355b4b32bd20df3ff95f02da9329b34ccc3bd6ee72" dependencies = [ "proc-macro2", "quote", - "syn 2.0.37", + "syn 2.0.38", ] [[package]] @@ -622,9 +627,9 @@ checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888" [[package]] name = "hashbrown" -version = "0.14.0" +version = "0.14.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2c6201b9ff9fd90a5a3bac2e56a830d0caa509576f0e503818ee82c181b3437a" +checksum = "7dfda62a12f55daeae5015f81b0baea145391cb4520f86c248fc615d72640d12" [[package]] name = "hermit-abi" @@ -739,14 +744,14 @@ dependencies = [ "pkcs8", "rand", "reqwest", - "ring", + "ring 0.16.20", "rustls 0.20.9", "sec1", "serde", "serde_bytes", "serde_cbor", "serde_repr", - "sha2 0.10.7", + "sha2 0.10.8", "simple_asn1", "thiserror", "time", @@ -804,7 +809,7 @@ dependencies = [ "hex", "serde", "serde_bytes", - "sha2 0.10.7", + "sha2 0.10.8", ] [[package]] @@ -815,7 +820,7 @@ checksum = "197524aecec47db0b6c0c9f8821aad47272c2bd762c7a0ffe9715eaca0364061" dependencies = [ "serde", "serde_bytes", - "sha2 0.10.7", + "sha2 0.10.8", ] [[package]] @@ -843,11 +848,11 @@ dependencies = [ "ic-certified-map", "proptest", "rand", - "ring", + "ring 0.16.20", "serde", "serde_bytes", "serde_cbor", - "sha2 0.10.7", + "sha2 0.10.8", ] [[package]] @@ -878,12 +883,12 @@ dependencies = [ [[package]] name = "indexmap" -version = "2.0.0" +version = "2.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d5477fe2230a79769d8dc68e0eabf5437907c0457a5614a9e8dddb67f65eb65d" +checksum = "8adf3ddd720272c6ea8bf59463c04e0f93d0bbf7c5439b691bca2987e0270897" dependencies = [ "equivalent", - "hashbrown 0.14.0", + "hashbrown 0.14.1", ] [[package]] @@ -926,7 +931,7 @@ dependencies = [ "ecdsa", "elliptic-curve", "once_cell", - "sha2 0.10.7", + "sha2 0.10.8", "signature", ] @@ -944,21 +949,21 @@ checksum = "884e2677b40cc8c339eaefcb701c32ef1fd2493d71118dc0ca4b6a736c93bd67" [[package]] name = "libc" -version = "0.2.148" +version = "0.2.149" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9cdc71e17332e86d2e1d38c1f99edcb6288ee11b815fb1a4b049eaa2114d369b" +checksum = "a08173bc88b7955d1b3145aa561539096c421ac8debde8cbc3612ec635fee29b" [[package]] name = "libm" -version = "0.2.7" +version = "0.2.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f7012b1bbb0719e1097c47611d3898568c546d597c2e74d66f6087edd5233ff4" +checksum = "4ec2a862134d2a7d32d7983ddcdd1c4923530833c9f2ea1a44fc5fa473989058" [[package]] name = "linux-raw-sys" -version = "0.4.7" +version = "0.4.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1a9bad9f94746442c783ca431b22403b519cd7fbeed0533fdd6328b2f2212128" +checksum = "da2479e8c062e40bf0066ffa0bc823de0a9368974af99c9f6df941d2c231e03f" [[package]] name = "log" @@ -968,9 +973,9 @@ checksum = "b5e6163cb8c49088c2c36f57875e58ccd8c87c7427f7fbd50ea6710b2f3f2e8f" [[package]] name = "memchr" -version = "2.6.3" +version = "2.6.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f232d6ef707e1956a43342693d2a31e72989554d58299d7a88738cc95b0d35c" +checksum = "f665ee40bc4a3c5590afb1e9677db74a508659dfd71e126420da8274909a0167" [[package]] name = "mime" @@ -1022,9 +1027,9 @@ dependencies = [ [[package]] name = "num-traits" -version = "0.2.16" +version = "0.2.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f30b0abd723be7e2ffca1272140fac1a2f084c77ec3e123c192b66af1ee9e6c2" +checksum = "39e3200413f237f41ab11ad6d161bc7239c84dcb631773ccd7de3dfe4b5c267c" dependencies = [ "autocfg", "libm", @@ -1058,7 +1063,7 @@ dependencies = [ "proc-macro-crate", "proc-macro2", "quote", - "syn 2.0.37", + "syn 2.0.38", ] [[package]] @@ -1152,13 +1157,13 @@ checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" [[package]] name = "pretty" -version = "0.12.1" +version = "0.12.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "563c9d701c3a31dfffaaf9ce23507ba09cbe0b9125ba176d15e629b0235e9acc" +checksum = "b55c4d17d994b637e2f4daf6e5dc5d660d209d5642377d675d7a1c3ab69fa579" dependencies = [ "arrayvec", "typed-arena", - "unicode-segmentation", + "unicode-width", ] [[package]] @@ -1173,22 +1178,22 @@ dependencies = [ [[package]] name = "proc-macro2" -version = "1.0.67" +version = "1.0.69" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3d433d9f1a3e8c1263d9456598b16fec66f4acc9a74dacffd35c7bb09b3a1328" +checksum = "134c189feb4956b20f6f547d2cf727d4c0fe06722b20a0eec87ed445a97f92da" dependencies = [ "unicode-ident", ] [[package]] name = "proptest" -version = "1.2.0" +version = "1.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4e35c06b98bf36aba164cc17cb25f7e232f5c4aeea73baa14b8a9f0d92dbfa65" +checksum = "7c003ac8c77cb07bb74f5f198bce836a689bcd5a42574612bf14d17bfd08c20e" dependencies = [ "bit-set", - "bitflags 1.3.2", - "byteorder", + "bit-vec", + "bitflags 2.4.0", "lazy_static", "num-traits", "rand", @@ -1274,15 +1279,15 @@ dependencies = [ [[package]] name = "regex-syntax" -version = "0.6.29" +version = "0.7.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f162c6dd7b008981e4d40210aca20b4bd0f9b60ca9271061b07f78537722f2e1" +checksum = "dbb5fb1acd8a1a18b3dd5be62d25485eb770e05afb408a9627d14d451bae12da" [[package]] name = "reqwest" -version = "0.11.20" +version = "0.11.22" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3e9ad3fe7488d7e34558a2033d45a0c90b72d97b4f80705666fea71472e2e6a1" +checksum = "046cd98826c46c2ac8ddecae268eb5c2e58628688a5fc7a2643704a73faba95b" dependencies = [ "base64", "bytes", @@ -1306,6 +1311,7 @@ dependencies = [ "serde", "serde_json", "serde_urlencoded", + "system-configuration", "tokio", "tokio-rustls", "tokio-util", @@ -1338,12 +1344,26 @@ dependencies = [ "cc", "libc", "once_cell", - "spin", - "untrusted", + "spin 0.5.2", + "untrusted 0.7.1", "web-sys", "winapi", ] +[[package]] +name = "ring" +version = "0.17.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "911b295d2d302948838c8ac142da1ee09fa7863163b44e6715bc9357905878b8" +dependencies = [ + "cc", + "getrandom", + "libc", + "spin 0.9.8", + "untrusted 0.9.0", + "windows-sys", +] + [[package]] name = "rustc-demangle" version = "0.1.23" @@ -1352,9 +1372,9 @@ checksum = "d626bb9dae77e28219937af045c257c28bfd3f69333c512553507f5f9798cb76" [[package]] name = "rustix" -version = "0.38.14" +version = "0.38.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "747c788e9ce8e92b12cd485c49ddf90723550b654b32508f979b71a7b1ecda4f" +checksum = "f25469e9ae0f3d0047ca8b93fc56843f38e6774f0914a107ff8b41be8be8e0b7" dependencies = [ "bitflags 2.4.0", "errno", @@ -1370,7 +1390,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1b80e3dec595989ea8510028f30c408a4630db12c9cbb8de34203b89d6577e99" dependencies = [ "log", - "ring", + "ring 0.16.20", "sct", "webpki", ] @@ -1382,7 +1402,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cd8d6c9f025a446bc4d18ad9632e69aec8f287aa84499ee335599fabd20c3fd8" dependencies = [ "log", - "ring", + "ring 0.16.20", "rustls-webpki", "sct", ] @@ -1402,8 +1422,8 @@ version = "0.101.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3c7d5dece342910d9ba34d259310cae3e0154b873b35408b787b59bce53d34fe" dependencies = [ - "ring", - "untrusted", + "ring 0.16.20", + "untrusted 0.7.1", ] [[package]] @@ -1436,8 +1456,8 @@ version = "0.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d53dcdb7c9f8158937a7981b48accfd39a43af418591a5d008c7b22b5e1b7ca4" dependencies = [ - "ring", - "untrusted", + "ring 0.16.20", + "untrusted 0.7.1", ] [[package]] @@ -1490,7 +1510,7 @@ checksum = "4eca7ac642d82aa35b60049a6eccb4be6be75e599bd2e9adb5f875a737654af2" dependencies = [ "proc-macro2", "quote", - "syn 2.0.37", + "syn 2.0.38", ] [[package]] @@ -1512,7 +1532,7 @@ checksum = "8725e1dfadb3a50f7e5ce0b1a540466f6ed3fe7a0fca2ac2b8b831d31316bd00" dependencies = [ "proc-macro2", "quote", - "syn 2.0.37", + "syn 2.0.38", ] [[package]] @@ -1553,9 +1573,9 @@ dependencies = [ [[package]] name = "sha2" -version = "0.10.7" +version = "0.10.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "479fb9d862239e610720565ca91403019f2f00410f1864c5aa7479b950a76ed8" +checksum = "793db75ad2bcafc3ffa7c68b215fee268f537982cd901d132f89c6343f3a3dc8" dependencies = [ "cfg-if", "cpufeatures", @@ -1628,6 +1648,12 @@ version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6e63cff320ae2c57904679ba7cb63280a3dc4613885beafb148ee7bf9aa9042d" +[[package]] +name = "spin" +version = "0.9.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" + [[package]] name = "spki" version = "0.7.2" @@ -1670,15 +1696,36 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.37" +version = "2.0.38" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7303ef2c05cd654186cb250d29049a24840ca25d2747c25c0381c8d9e2f582e8" +checksum = "e96b79aaa137db8f61e26363a0c9b47d8b4ec75da28b7d1d614c2303e232408b" dependencies = [ "proc-macro2", "quote", "unicode-ident", ] +[[package]] +name = "system-configuration" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba3a3adc5c275d719af8cb4272ea1c4a6d668a777f37e115f6d11ddbc1c8e0e7" +dependencies = [ + "bitflags 1.3.2", + "core-foundation", + "system-configuration-sys", +] + +[[package]] +name = "system-configuration-sys" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a75fb188eb626b924683e3b95e3a48e63551fcfb51949de2f06a9d91dbee93c9" +dependencies = [ + "core-foundation-sys", + "libc", +] + [[package]] name = "tempfile" version = "3.8.0" @@ -1715,29 +1762,29 @@ dependencies = [ [[package]] name = "thiserror" -version = "1.0.48" +version = "1.0.49" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9d6d7a740b8a666a7e828dd00da9c0dc290dff53154ea77ac109281de90589b7" +checksum = "1177e8c6d7ede7afde3585fd2513e611227efd6481bd78d2e82ba1ce16557ed4" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.48" +version = "1.0.49" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "49922ecae66cc8a249b77e68d1d0623c1b2c514f0060c27cdc68bd62a1219d35" +checksum = "10712f02019e9288794769fba95cd6847df9874d49d871d062172f9dd41bc4cc" dependencies = [ "proc-macro2", "quote", - "syn 2.0.37", + "syn 2.0.38", ] [[package]] name = "time" -version = "0.3.28" +version = "0.3.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "17f6bb557fd245c28e6411aa56b6403c689ad95061f50e4be16c274e70a17e48" +checksum = "426f806f4089c493dcac0d24c29c01e2c38baf8e30f1b716ee37e83d200b18fe" dependencies = [ "deranged", "itoa", @@ -1748,15 +1795,15 @@ dependencies = [ [[package]] name = "time-core" -version = "0.1.1" +version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7300fbefb4dadc1af235a9cef3737cea692a9d97e1b9cbcd4ebdae6f8868e6fb" +checksum = "ef927ca75afb808a4d64dd374f00a2adf8d0fcff8e7b184af886c3c87ec4a3f3" [[package]] name = "time-macros" -version = "0.2.14" +version = "0.2.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1a942f44339478ef67935ab2bbaec2fb0322496cf3cbe84b261e06ac3814c572" +checksum = "4ad70d68dba9e1f8aceda7aa6711965dfec1cac869f311a51bd08b3a2ccbce20" dependencies = [ "time-core", ] @@ -1778,9 +1825,9 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] name = "tokio" -version = "1.32.0" +version = "1.33.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "17ed6077ed6cd6c74735e21f37eb16dc3935f96878b1fe961074089cc80893f9" +checksum = "4f38200e3ef7995e5ef13baec2f432a6da0aa9ac495b2c0e8f3b7eec2c92d653" dependencies = [ "backtrace", "bytes", @@ -1828,7 +1875,7 @@ version = "0.19.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1b5bb770da30e5cbfde35a2d7b9b8a2c4b8ef89548a7a6aeab5c9a576e3e7421" dependencies = [ - "indexmap 2.0.0", + "indexmap 2.0.2", "toml_datetime", "winnow", ] @@ -1904,12 +1951,6 @@ dependencies = [ "tinyvec", ] -[[package]] -name = "unicode-segmentation" -version = "1.10.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1dd624098567895118886609431a7c3b8f516e41d30e0643f03d94592a147e36" - [[package]] name = "unicode-width" version = "0.1.11" @@ -1922,6 +1963,12 @@ version = "0.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a156c684c91ea7d62626509bce3cb4e1d9ed5c4d978f7b4352658f96a4c26b4a" +[[package]] +name = "untrusted" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" + [[package]] name = "url" version = "2.4.1" @@ -1984,7 +2031,7 @@ dependencies = [ "once_cell", "proc-macro2", "quote", - "syn 2.0.37", + "syn 2.0.38", "wasm-bindgen-shared", ] @@ -2018,7 +2065,7 @@ checksum = "54681b18a46765f095758388f2d0cf16eb8d4169b639ab575a8f5693af210c7b" dependencies = [ "proc-macro2", "quote", - "syn 2.0.37", + "syn 2.0.38", "wasm-bindgen-backend", "wasm-bindgen-shared", ] @@ -2054,12 +2101,12 @@ dependencies = [ [[package]] name = "webpki" -version = "0.22.1" +version = "0.22.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f0e74f82d49d545ad128049b7e88f6576df2da6b02e9ce565c6f533be576957e" +checksum = "ed63aea5ce73d0ff405984102c42de94fc55a6b75765d621c65262469b3c9b53" dependencies = [ - "ring", - "untrusted", + "ring 0.17.2", + "untrusted 0.9.0", ] [[package]] @@ -2167,9 +2214,9 @@ checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538" [[package]] name = "winnow" -version = "0.5.15" +version = "0.5.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7c2e3184b9c4e92ad5167ca73039d0c42476302ab603e2fec4487511f38ccefc" +checksum = "037711d82167854aff2018dfd193aa0fef5370f456732f0d5a0c59b0f1b4b907" dependencies = [ "memchr", ] diff --git a/src/ic-websocket-cdk/src/lib.rs b/src/ic-websocket-cdk/src/lib.rs index 2816771..648ff2f 100644 --- a/src/ic-websocket-cdk/src/lib.rs +++ b/src/ic-websocket-cdk/src/lib.rs @@ -374,7 +374,6 @@ fn get_messages_for_gateway_range(gateway_principal: Principal, nonce: u64) -> ( MESSAGES_FOR_GATEWAY.with(|m| { let queue_len = m.borrow().len(); - // TODO: test if nonce == 0 && queue_len > 0 { // this is the case in which the poller on the gateway restarted // the range to return is end:last index and start: max(end - MAX_NUMBER_OF_RETURNED_MESSAGES, 0) @@ -595,7 +594,6 @@ pub struct WsHandlers { impl WsHandlers { fn call_on_open(&self, args: OnOpenCallbackArgs) { if let Some(on_open) = self.on_open { - // TODO: test the panic handling let res = panic::catch_unwind(|| { on_open(args); }); @@ -608,7 +606,6 @@ impl WsHandlers { fn call_on_message(&self, args: OnMessageCallbackArgs) { if let Some(on_message) = self.on_message { - // TODO: test the panic handling let res = panic::catch_unwind(|| { on_message(args); }); @@ -621,7 +618,6 @@ impl WsHandlers { fn call_on_close(&self, args: OnCloseCallbackArgs) { if let Some(on_close) = self.on_close { - // TODO: test the panic handling let res = panic::catch_unwind(|| { on_close(args); }); @@ -990,6 +986,45 @@ mod test { assert!(CUSTOM_STATE.with(|h| h.borrow().is_on_close_called)); } + #[test] + fn test_ws_handlers_panic_is_handled() { + let handlers = WsHandlers { + on_open: Some(|_| { + panic!("on_open_panic"); + }), + on_message: Some(|_| { + panic!("on_close_panic"); + }), + on_close: Some(|_| { + panic!("on_close_panic"); + }), + }; + + initialize_handlers(handlers); + + let handlers = HANDLERS.with(|h| h.borrow().clone()); + + let res = panic::catch_unwind(|| { + handlers.call_on_open(OnOpenCallbackArgs { + client_principal: test_utils::generate_random_principal(), + }); + }); + assert!(res.is_ok()); + let res = panic::catch_unwind(|| { + handlers.call_on_message(OnMessageCallbackArgs { + client_principal: test_utils::generate_random_principal(), + message: vec![], + }); + }); + assert!(res.is_ok()); + let res = panic::catch_unwind(|| { + handlers.call_on_close(OnCloseCallbackArgs { + client_principal: test_utils::generate_random_principal(), + }); + }); + assert!(res.is_ok()); + } + #[test] fn test_current_time() { // test From 924e3544d4fde8d74f2a67d7f5f735650abffe39 Mon Sep 17 00:00:00 2001 From: Luca8991 Date: Mon, 9 Oct 2023 14:04:22 +0200 Subject: [PATCH 30/31] chore: handle received service message (mock) --- src/ic-websocket-cdk/src/lib.rs | 31 +++++++++++++++++++++++++++---- 1 file changed, 27 insertions(+), 4 deletions(-) diff --git a/src/ic-websocket-cdk/src/lib.rs b/src/ic-websocket-cdk/src/lib.rs index 648ff2f..ffcfad7 100644 --- a/src/ic-websocket-cdk/src/lib.rs +++ b/src/ic-websocket-cdk/src/lib.rs @@ -1,4 +1,4 @@ -use candid::{encode_one, CandidType, Principal}; +use candid::{decode_one, encode_one, CandidType, Principal}; #[cfg(not(test))] use ic_cdk::api::time; use ic_cdk::api::{caller, data_certificate, set_certified_data}; @@ -489,7 +489,7 @@ struct ClientKeepAliveMessageContent { last_incoming_sequence_num: u64, } -/// A service message sent by the CDK to the client. +/// A service message sent by the CDK to the client or vice versa. #[derive(CandidType, Deserialize)] enum WebsocketServiceMessageContent { /// Message sent by the **canister** when a client opens a connection. @@ -500,6 +500,16 @@ enum WebsocketServiceMessageContent { KeepAliveMessage(ClientKeepAliveMessageContent), } +impl WebsocketServiceMessageContent { + fn from_candid_bytes(bytes: Vec) -> Result { + decode_one(&bytes).map_err(|e| { + let mut err = String::from("Error decoding service message content: "); + err.push_str(&e.to_string()); + err + }) + } +} + fn send_service_message_to_client( client_key: &ClientKey, message: WebsocketServiceMessageContent, @@ -557,6 +567,20 @@ fn _ws_send( Ok(()) } +fn handle_received_service_message(content: Vec) -> CanisterWsMessageResult { + let decoded = WebsocketServiceMessageContent::from_candid_bytes(content)?; + match decoded { + WebsocketServiceMessageContent::OpenMessage(_) + | WebsocketServiceMessageContent::AckMessage(_) => { + Err(String::from("Invalid received service message")) + }, + WebsocketServiceMessageContent::KeepAliveMessage(_) => { + custom_print!("Service message handling not implemented yet"); + Ok(()) + }, + } +} + /// Arguments passed to the `on_open` handler. pub struct OnOpenCallbackArgs { pub client_principal: ClientPrincipal, @@ -769,9 +793,8 @@ pub fn ws_message(args: CanisterWsMessageArguments) -> CanisterWsMessageResult { // increase the expected sequence number by 1 increment_expected_incoming_message_from_client_num(&client_key)?; - // TODO: test if is_service_message { - custom_print!("Service message handling not implemented yet"); + return handle_received_service_message(content); } // call the on_message handler initialized in init() From 5d6c31f8bb035d4bcfb90a937ed4a9e1a76f75ca Mon Sep 17 00:00:00 2001 From: Luca8991 Date: Mon, 9 Oct 2023 14:05:09 +0200 Subject: [PATCH 31/31] chore: integration tests --- tests/integration/canister.test.ts | 564 ++++++++++------------------ tests/integration/utils/api.ts | 65 +--- tests/integration/utils/client.ts | 5 + tests/integration/utils/messages.ts | 78 +++- tests/src/lib.rs | 10 +- tests/test_canister.did | 2 +- 6 files changed, 302 insertions(+), 422 deletions(-) create mode 100644 tests/integration/utils/client.ts diff --git a/tests/integration/canister.test.ts b/tests/integration/canister.test.ts index ecf5eaa..f1f76a7 100644 --- a/tests/integration/canister.test.ts +++ b/tests/integration/canister.test.ts @@ -1,6 +1,4 @@ import { IDL } from "@dfinity/candid"; -import { Principal } from "@dfinity/principal"; -import { Cbor } from "@dfinity/agent"; import { anonymousClient, canisterId, @@ -13,8 +11,6 @@ import { gateway2, } from "./utils/actors"; import { - isMessageBodyValid, - isValidCertificate, reinitialize, wsClose, wsGetMessages, @@ -34,14 +30,37 @@ import type { WebsocketMessage, } from "../src/declarations/test_canister/test_canister.did"; import { generateClientKey, getRandomClientNonce } from "./utils/random"; -import { CanisterOpenMessageContent, WebsocketServiceMessageContent, encodeWebsocketServiceMessageContent, getServiceMessageFromCanisterMessage, isClientKeyEq } from "./utils/idl"; -import { createWebsocketMessage, decodeWebsocketMessage, filterServiceMessagesFromCanisterMessages } from "./utils/messages"; - -const MAX_NUMBER_OF_RETURNED_MESSAGES = 10; // set in the CDK +import { + CanisterOpenMessageContent, + WebsocketServiceMessageContent, + encodeWebsocketServiceMessageContent, + getServiceMessageFromCanisterMessage, + isClientKeyEq, +} from "./utils/idl"; +import { + isMessageBodyValid, + isValidCertificate, + createWebsocketMessage, + decodeWebsocketMessage, + filterServiceMessagesFromCanisterMessages, + getNextPollingNonceFromMessages, +} from "./utils/messages"; +import { formatClientKey } from "./utils/client"; + +/** + * The maximum number of messages returned by the **ws_get_messages** method. Set in the CDK. + * + * Value: `10` + */ +const MAX_NUMBER_OF_RETURNED_MESSAGES = 10; +/** + * @{@link MAX_NUMBER_OF_RETURNED_MESSAGES} + 2 + * + * Value: `12` + */ const SEND_MESSAGES_COUNT = MAX_NUMBER_OF_RETURNED_MESSAGES + 2; // test with more messages to check the indexes and limits -const MAX_GATEWAY_KEEP_ALIVE_TIME_MS = 15_000; // set in the CDK -const DEFAULT_TEST_SEND_ACK_INTERVAL_MS = 300_000; // 5 minutes to make sure the canister doesn't reset the client -const DEFAULT_TEST_KEEP_ALIVE_DELAY_MS = 300_000; // 5 minutes to make sure the canister doesn't reset the client +// const DEFAULT_TEST_SEND_ACK_INTERVAL_MS = 300_000; // 5 minutes to make sure the canister doesn't reset the client +// const DEFAULT_TEST_KEEP_ALIVE_DELAY_MS = 300_000; // 5 minutes to make sure the canister doesn't reset the client let client1Key: ClientKey; let client2Key: ClientKey; @@ -252,14 +271,14 @@ describe("Canister - ws_message", () => { clientActor: client1, }, true); - // wring content encoding + // wrong content encoding const res = await wsMessage({ message: createWebsocketMessage(client1Key, 1, true, new Uint8Array([1, 2, 3])), actor: client1, }); expect(res).toMatchObject({ - Err: expect.stringContaining("Error decoding service message from client:"), + Err: expect.stringContaining("Error decoding service message content:"), }); const wrongServiceMessage: WebsocketServiceMessageContent = { @@ -274,7 +293,7 @@ describe("Canister - ws_message", () => { }); expect(res2).toMatchObject({ - Err: "invalid keep alive message content", + Err: "Invalid received service message", }); }); @@ -333,233 +352,15 @@ describe("Canister - ws_get_messages (failures,empty)", () => { }); }); -// describe("Canister - ws_message (gateway status)", () => { -// beforeAll(async () => { -// await assignKeysToClients(); - -// await wsRegister({ -// clientActor: client1, -// clientKey: client1Key.publicKey, -// }, true); - -// await wsOpen({ -// clientPublicKey: client1Key.publicKey, -// clientSecretKey: client1Key.secretKey, -// canisterId, -// clientActor: gateway1, -// }, true); - -// await wsSend({ -// clientKey: client1Key.publicKey, -// actor: client1, -// message: { text: "test" }, -// }, true); -// }); - -// afterAll(async () => { -// await wsWipe(gateway1); -// }); - -// it("fails if a non registered gateway sends an IcWebSocketGatewayStatus message", async () => { -// const res = await wsMessage({ -// message: { -// IcWebSocketGatewayStatus: { -// status_index: BigInt(1), -// }, -// }, -// actor: gateway2, -// }); - -// expect(res).toMatchObject({ -// Err: "caller is not the gateway that has been registered during CDK initialization", -// }); -// }); - -// it("registered gateway should update the status index", async () => { -// const res = await wsMessage({ -// message: { -// IcWebSocketGatewayStatus: { -// status_index: BigInt(2), // set it high to test behavior for indexes behind the current one -// }, -// }, -// actor: gateway1, -// }); - -// expect(res).toMatchObject({ -// Ok: null, -// }); -// }); - -// it("fails if a registered gateway sends an IcWebSocketGatewayStatus with a wrong status index (equal to current)", async () => { -// const res = await wsMessage({ -// message: { -// IcWebSocketGatewayStatus: { -// status_index: BigInt(2), -// }, -// }, -// actor: gateway1, -// }); - -// expect(res).toMatchObject({ -// Err: "Gateway status index is equal to or behind the current one", -// }); -// }); - -// it("fails if a registered gateway sends an IcWebSocketGatewayStatus with a wrong status index (behind the current)", async () => { -// const res = await wsMessage({ -// message: { -// IcWebSocketGatewayStatus: { -// status_index: BigInt(1), -// }, -// }, -// actor: gateway1, -// }); - -// expect(res).toMatchObject({ -// Err: "Gateway status index is equal to or behind the current one", -// }); -// }); - -// it("registered gateway should disconnect after maximum time", async () => { -// let res = await gateway1.ws_get_messages({ -// nonce: BigInt(0), -// }); - -// expect(res).toMatchObject({ -// Ok: { -// messages: expect.any(Array), -// cert: expect.any(Uint8Array), -// tree: expect.any(Uint8Array), -// }, -// }); -// expect((res as { Ok: CanisterOutputCertifiedMessages }).Ok.messages.length).toEqual(1); - -// // wait for the maximum time the gateway can send a status message, -// // so that the internal canister state is reset -// // double the time to make sure the canister state is reset -// await new Promise((resolve) => setTimeout(resolve, 2 * MAX_GATEWAY_KEEP_ALIVE_TIME_MS)); - -// // check if messages have been deleted -// res = await gateway1.ws_get_messages({ -// nonce: BigInt(0), -// }); -// expect(res).toMatchObject({ -// Ok: { -// messages: [], -// cert: expect.any(Uint8Array), -// tree: expect.any(Uint8Array), -// }, -// }); - -// // check if registered client has been deleted -// const sendRes = await wsSend({ -// clientKey: client1Key.publicKey, -// actor: client1, -// message: { text: "test" }, -// }); -// expect(sendRes).toMatchObject({ -// Err: "client's public key has not been previously registered by client", -// }); -// }); - -// it("registered gateway should reconnect by resetting the status index", async () => { -// let res = await wsMessage({ -// message: { -// IcWebSocketGatewayStatus: { -// status_index: BigInt(0), -// }, -// }, -// actor: gateway1, -// }); - -// expect(res).toMatchObject({ -// Ok: null, -// }); - -// res = await wsMessage({ -// message: { -// IcWebSocketGatewayStatus: { -// status_index: BigInt(1), -// }, -// }, -// actor: gateway1, -// }); - -// expect(res).toMatchObject({ -// Ok: null, -// }); -// }); - -// it("registered gateway should reconnect before maximum time", async () => { -// // reconnect the client -// await wsRegister({ -// clientActor: client1, -// clientKey: client1Key.publicKey, -// }, true); - -// await wsOpen({ -// clientPublicKey: client1Key.publicKey, -// clientSecretKey: client1Key.secretKey, -// canisterId, -// clientActor: gateway1, -// }, true); - -// // send a test message from the canister to check if the internal state is reset -// await wsSend({ -// clientKey: client1Key.publicKey, -// actor: client1, -// message: { text: "test" }, -// }, true); - -// // check if the canister has the message in the queue -// let messagesRes = await gateway1.ws_get_messages({ -// nonce: BigInt(0), -// }); -// expect(messagesRes).toMatchObject({ -// Ok: { -// messages: expect.any(Array), -// cert: expect.any(Uint8Array), -// tree: expect.any(Uint8Array), -// }, -// }); -// expect((messagesRes as { Ok: CanisterOutputCertifiedMessages }).Ok.messages.length).toEqual(1); - -// // simulate a reconnection -// const res = await wsMessage({ -// message: { -// IcWebSocketGatewayStatus: { -// status_index: BigInt(0), -// }, -// }, -// actor: gateway1, -// }); -// expect(res).toMatchObject({ -// Ok: null, -// }); - -// // check if the canister reset the internal state -// messagesRes = await gateway1.ws_get_messages({ -// nonce: BigInt(0), -// }); -// expect(messagesRes).toMatchObject({ -// Ok: { -// messages: [], -// cert: expect.any(Uint8Array), -// tree: expect.any(Uint8Array), -// }, -// }); -// }); -// }); - -describe.only("Canister - ws_get_messages (receive)", () => { +describe("Canister - ws_get_messages (receive)", () => { beforeAll(async () => { await assignKeysToClients(); // reset the internal timers - await reinitialize({ - sendAckIntervalMs: DEFAULT_TEST_SEND_ACK_INTERVAL_MS, - keepAliveDelayMs: DEFAULT_TEST_KEEP_ALIVE_DELAY_MS, - }); + // await reinitialize({ + // sendAckIntervalMs: DEFAULT_TEST_SEND_ACK_INTERVAL_MS, + // keepAliveDelayMs: DEFAULT_TEST_KEEP_ALIVE_DELAY_MS, + // }); await wsOpen({ clientNonce: client1Key.client_nonce, @@ -568,14 +369,15 @@ describe.only("Canister - ws_get_messages (receive)", () => { }, true); // prepare the messages - for (let i = 0; i < SEND_MESSAGES_COUNT; i++) { - const appMessage = { text: `test${i}` }; - await wsSend({ - clientPrincipal: client1Key.client_principal, - actor: client1, - message: appMessage, - }, true); - } + const messages = Array.from({ length: SEND_MESSAGES_COUNT }, (_, i) => { + return { text: `test${i}` }; + }); + + await wsSend({ + clientPrincipal: client1Key.client_principal, + actor: client1, + messages, + }, true); await commonAgent.fetchRootKey(); }); @@ -586,7 +388,7 @@ describe.only("Canister - ws_get_messages (receive)", () => { it("registered gateway can receive correct amount of messages", async () => { // on open, the canister puts a service message in the queue - const messagesCount = SEND_MESSAGES_COUNT + 1; // +1 for the service message + const messagesCount = SEND_MESSAGES_COUNT + 1; // +1 for the open service message for (let i = 0; i < messagesCount; i++) { const res = await gateway1.ws_get_messages({ nonce: BigInt(i), @@ -625,21 +427,22 @@ describe.only("Canister - ws_get_messages (receive)", () => { it("registered gateway can receive certified messages", async () => { // first batch of messages const firstBatchRes = await gateway1.ws_get_messages({ - nonce: BigInt(1), + nonce: BigInt(1), // skip the case in which the gateway restarts polling from the beginning (tested below) }); const firstBatchMessagesResult = (firstBatchRes as { Ok: CanisterOutputCertifiedMessages }).Ok; - console.log(firstBatchMessagesResult.messages.map((msg) => msg.key)); - for (let i = 0; i < firstBatchMessagesResult.messages.length; i++) { - const message = firstBatchMessagesResult.messages[i]; + expect(firstBatchMessagesResult.messages.length).toBe(MAX_NUMBER_OF_RETURNED_MESSAGES); + + let expectedSequenceNumber = 2; // first is the service open message and the number is incremented before sending + let i = 0; + for (const message of firstBatchMessagesResult.messages) { expect(isClientKeyEq(message.client_key, client1Key)).toEqual(true); const websocketMessage = decodeWebsocketMessage(new Uint8Array(message.content)); - console.log(websocketMessage); expect(websocketMessage).toMatchObject({ client_key: expect.any(Object), content: expect.any(Uint8Array), - sequence_num: BigInt(i + 1), - timestamp: expect.any(Object), // weird cbor bigint deserialization + sequence_num: BigInt(expectedSequenceNumber), + timestamp: expect.any(BigInt), is_service_message: false, }); expect(isClientKeyEq(websocketMessage.client_key, client1Key)).toEqual(true); @@ -661,27 +464,33 @@ describe.only("Canister - ws_get_messages (receive)", () => { firstBatchMessagesResult.tree as Uint8Array, ) ).resolves.toBe(true); + + expectedSequenceNumber++; + i++; } + const nextPollingNonce = getNextPollingNonceFromMessages(firstBatchMessagesResult.messages); + // second batch of messages, starting from the last nonce of the first batch const secondBatchRes = await gateway1.ws_get_messages({ - nonce: BigInt(MAX_NUMBER_OF_RETURNED_MESSAGES), + nonce: BigInt(nextPollingNonce), }); const secondBatchMessagesResult = (secondBatchRes as { Ok: CanisterOutputCertifiedMessages }).Ok; - for (let i = 0; i < secondBatchMessagesResult.messages.length; i++) { - const message = secondBatchMessagesResult.messages[i]; + expect(secondBatchMessagesResult.messages.length).toBe(SEND_MESSAGES_COUNT - MAX_NUMBER_OF_RETURNED_MESSAGES); // remaining from SEND_MESSAGES_COUNT + + for (const message of secondBatchMessagesResult.messages) { expect(isClientKeyEq(message.client_key, client1Key)).toEqual(true); const websocketMessage = decodeWebsocketMessage(new Uint8Array(message.content)); expect(websocketMessage).toMatchObject({ client_key: expect.any(Object), content: expect.any(Uint8Array), - sequence_num: BigInt(i + MAX_NUMBER_OF_RETURNED_MESSAGES + 1), - timestamp: expect.any(Object), // weird cbor bigint deserialization + sequence_num: BigInt(expectedSequenceNumber), + timestamp: expect.any(BigInt), is_service_message: false, }); expect(isClientKeyEq(websocketMessage.client_key, client1Key)).toEqual(true); - expect(IDL.decode([IDL.Record({ 'text': IDL.Text })], websocketMessage.content as Uint8Array)).toEqual([{ text: `test${i + MAX_NUMBER_OF_RETURNED_MESSAGES}` }]); + expect(IDL.decode([IDL.Record({ 'text': IDL.Text })], websocketMessage.content as Uint8Array)).toEqual([{ text: `test${i}` }]); // check the certification await expect( @@ -699,107 +508,144 @@ describe.only("Canister - ws_get_messages (receive)", () => { secondBatchMessagesResult.tree as Uint8Array, ) ).resolves.toBe(true); + + expectedSequenceNumber++; + i++; + } + }); + + it("registered gateway can poll messages after restart", async () => { + const batchRes = await gateway1.ws_get_messages({ + nonce: BigInt(0), // start polling from the beginning, as if the gateway restarted + }); + + // we expect that the messages returned are the last MAX_NUMBER_OF_RETURNED_MESSAGES + const messagesResult = (batchRes as { Ok: CanisterOutputCertifiedMessages }).Ok; + expect(messagesResult.messages.length).toBe(MAX_NUMBER_OF_RETURNED_MESSAGES); + + let expectedSequenceNumber = SEND_MESSAGES_COUNT - MAX_NUMBER_OF_RETURNED_MESSAGES + 1 + 1; // +1 for the service open message +1 because the seq num is incremented before sending + let i = SEND_MESSAGES_COUNT - MAX_NUMBER_OF_RETURNED_MESSAGES; + for (const message of messagesResult.messages) { + expect(isClientKeyEq(message.client_key, client1Key)).toEqual(true); + const websocketMessage = decodeWebsocketMessage(new Uint8Array(message.content)); + expect(websocketMessage).toMatchObject({ + client_key: expect.any(Object), + content: expect.any(Uint8Array), + sequence_num: BigInt(expectedSequenceNumber), + timestamp: expect.any(BigInt), + is_service_message: false, + }); + expect(isClientKeyEq(websocketMessage.client_key, client1Key)).toEqual(true); + expect(IDL.decode([IDL.Record({ 'text': IDL.Text })], websocketMessage.content as Uint8Array)).toEqual([{ text: `test${i}` }]); + + // check the certification + await expect( + isValidCertificate( + canisterId, + messagesResult.cert as Uint8Array, + messagesResult.tree as Uint8Array, + commonAgent + ) + ).resolves.toBe(true); + await expect( + isMessageBodyValid( + message.key, + message.content as Uint8Array, + messagesResult.tree as Uint8Array, + ) + ).resolves.toBe(true); + + expectedSequenceNumber++; + i++; } }); }); -// describe("Canister - ws_close", () => { -// beforeAll(async () => { -// await assignKeysToClients(); - -// await wsRegister({ -// clientActor: client1, -// clientKey: client1Key.publicKey, -// }, true); - -// await wsOpen({ -// clientPublicKey: client1Key.publicKey, -// clientSecretKey: client1Key.secretKey, -// canisterId, -// clientActor: gateway1, -// }, true); -// }); - -// afterAll(async () => { -// await wsWipe(gateway1); -// }); - -// it("fails if gateway is not registered", async () => { -// const res = await wsClose({ -// clientPublicKey: client1Key.publicKey, -// gatewayActor: gateway2, -// }); - -// expect(res).toMatchObject({ -// Err: "caller is not the gateway that has been registered during CDK initialization", -// }); -// }); - -// it("fails if client is not registered", async () => { -// const res = await wsClose({ -// clientPublicKey: client2Key.publicKey, -// gatewayActor: gateway1, -// }); - -// expect(res).toMatchObject({ -// Err: "client's public key has not been previously registered by client", -// }); -// }); - -// it("should close the websocket for a registered client", async () => { -// const res = await wsClose({ -// clientPublicKey: client1Key.publicKey, -// gatewayActor: gateway1, -// }); - -// expect(res).toMatchObject({ -// Ok: null, -// }); -// }); -// }); - -// describe("Canister - ws_send", () => { -// beforeAll(async () => { -// await assignKeysToClients(); - -// await wsRegister({ -// clientActor: client1, -// clientKey: client1Key.publicKey, -// }, true); - -// await wsOpen({ -// clientPublicKey: client1Key.publicKey, -// clientSecretKey: client1Key.secretKey, -// canisterId, -// clientActor: gateway1, -// }, true); -// }); - -// afterAll(async () => { -// await wsWipe(gateway1); -// }); - -// it("fails if sending a message to a non registered client", async () => { -// const res = await wsSend({ -// clientKey: client2Key.publicKey, -// actor: client1, -// message: { text: "test" }, -// }); - -// expect(res).toMatchObject({ -// Err: "client's public key has not been previously registered by client", -// }); -// }); - -// it("should send a message to a registered client", async () => { -// const res = await wsSend({ -// clientKey: client1Key.publicKey, -// actor: client1, -// message: { text: "test" }, -// }); - -// expect(res).toMatchObject({ -// Ok: null, -// }); -// }); -// }); +describe("Canister - ws_close", () => { + beforeAll(async () => { + await assignKeysToClients(); + + await wsOpen({ + clientNonce: client1Key.client_nonce, + canisterId, + clientActor: client1, + }, true); + }); + + afterAll(async () => { + await wsWipe(); + }); + + it("fails if gateway is not registered", async () => { + const res = await wsClose({ + clientKey: client1Key, + gatewayActor: gateway2, + }); + + expect(res).toMatchObject({ + Err: "caller is not the gateway that has been registered during CDK initialization", + }); + }); + + it("fails if client is not registered", async () => { + const res = await wsClose({ + clientKey: client2Key, + gatewayActor: gateway1, + }); + + expect(res).toMatchObject({ + Err: `client with key ${formatClientKey(client2Key)} doesn't have an open connection`, + }); + }); + + it("should close the websocket for a registered client", async () => { + const res = await wsClose({ + clientKey: client1Key, + gatewayActor: gateway1, + }); + + expect(res).toMatchObject({ + Ok: null, + }); + }); +}); + +describe("Canister - ws_send", () => { + beforeAll(async () => { + await assignKeysToClients(); + + await wsOpen({ + clientNonce: client1Key.client_nonce, + canisterId, + clientActor: client1, + }, true); + }); + + afterAll(async () => { + await wsWipe(); + }); + + it("fails if sending a message to a non registered client", async () => { + const res = await wsSend({ + clientPrincipal: client2Key.client_principal, + actor: client1, + messages: [{ text: "test" }], + }); + + expect(res).toMatchObject({ + Err: `client with principal ${client2Key.client_principal.toText()} doesn't have an open connection`, + }); + }); + + it("should send a message to a registered client", async () => { + const res = await wsSend({ + clientPrincipal: client1Key.client_principal, + actor: client1, + messages: [{ text: "test" }], + }); + + expect(res).toMatchObject({ + Ok: null, + }); + }); +}); diff --git a/tests/integration/utils/api.ts b/tests/integration/utils/api.ts index a380a87..731ed54 100644 --- a/tests/integration/utils/api.ts +++ b/tests/integration/utils/api.ts @@ -1,8 +1,6 @@ // helpers for functions that are called frequently in tests -import { ActorSubclass, Cbor, Certificate, HashTree, HttpAgent, compare, lookup_path, reconstruct } from "@dfinity/agent"; -import { Secp256k1KeyIdentity } from "@dfinity/identity-secp256k1"; -import { Principal } from "@dfinity/principal"; +import { ActorSubclass } from "@dfinity/agent"; import { IDL } from "@dfinity/candid"; import { anonymousClient, gateway1Data } from "./actors"; import type { CanisterOutputCertifiedMessages, ClientKey, ClientPrincipal, WebsocketMessage, _SERVICE } from "../../src/declarations/test_canister/test_canister.did"; @@ -119,65 +117,16 @@ export const reinitialize = async (args: ReinitializeArgs) => { type WsSendArgs = { clientPrincipal: ClientPrincipal, actor: ActorSubclass<_SERVICE>, - message: { + messages: Array<{ text: string, - }, + }>, }; export const wsSend = async (args: WsSendArgs, throwIfError = false) => { - const msgBytes = IDL.encode([IDL.Record({ 'text': IDL.Text })], [args.message]); - const res = await args.actor.ws_send(args.clientPrincipal, new Uint8Array(msgBytes)); + const serializedMessages = args.messages.map((msg) => { + return new Uint8Array(IDL.encode([IDL.Record({ 'text': IDL.Text })], [msg])); + }); + const res = await args.actor.ws_send(args.clientPrincipal, serializedMessages); return resolveResult(res, throwIfError); }; - -export const getCertifiedMessageKey = async (gatewayIdentity: Promise, nonce: number) => { - const gatewayPrincipal = (await gatewayIdentity).getPrincipal().toText(); - return `${gatewayPrincipal}_${String(nonce).padStart(20, '0')}`; -}; - -export const isValidCertificate = async (canisterId: string, certificate: Uint8Array, tree: Uint8Array, agent: HttpAgent) => { - const canisterPrincipal = Principal.fromText(canisterId); - let cert: Certificate; - - try { - cert = await Certificate.create({ - certificate, - canisterId: canisterPrincipal, - rootKey: agent.rootKey! - }); - } catch (error) { - console.error("Error creating certificate:", error); - return false; - } - - const hashTree = Cbor.decode(tree); - const reconstructed = await reconstruct(hashTree); - const witness = cert.lookup([ - "canister", - canisterPrincipal.toUint8Array(), - "certified_data" - ]); - - if (!witness) { - throw new Error( - "Could not find certified data for this canister in the certificate." - ); - } - - // First validate that the Tree is as good as the certification. - return compare(witness, reconstructed) === 0; -}; - -export const isMessageBodyValid = async (path: string, body: Uint8Array | ArrayBuffer, tree: Uint8Array) => { - const hashTree = Cbor.decode(tree); - const sha = await crypto.subtle.digest("SHA-256", body); - let treeSha = lookup_path(["websocket", path], hashTree); - - if (!treeSha) { - // Allow fallback to index path. - treeSha = lookup_path(["websocket"], hashTree); - } - - return !!treeSha && (compare(sha, treeSha) === 0); -}; diff --git a/tests/integration/utils/client.ts b/tests/integration/utils/client.ts new file mode 100644 index 0000000..2a66f27 --- /dev/null +++ b/tests/integration/utils/client.ts @@ -0,0 +1,5 @@ +import { ClientKey } from "../../src/declarations/test_canister/test_canister.did"; + +export const formatClientKey = (clientKey: ClientKey): string => { + return `${clientKey.client_principal.toText()}_${clientKey.client_nonce.toString()}`; +}; diff --git a/tests/integration/utils/messages.ts b/tests/integration/utils/messages.ts index 993e36c..da2ac8e 100644 --- a/tests/integration/utils/messages.ts +++ b/tests/integration/utils/messages.ts @@ -1,6 +1,8 @@ -import { Cbor } from "@dfinity/agent"; +import { Cbor, Certificate, HashTree, HttpAgent, compare, lookup_path, reconstruct } from "@dfinity/agent"; import { CanisterOutputMessage, ClientKey, WebsocketMessage } from "../../src/declarations/test_canister/test_canister.did"; import { getWebsocketMessageFromCanisterMessage } from "./idl"; +import { Principal } from "@dfinity/principal"; +import { Secp256k1KeyIdentity } from "@dfinity/identity-secp256k1"; export const filterServiceMessagesFromCanisterMessages = (messages: CanisterOutputMessage[]): CanisterOutputMessage[] => { return messages.filter((msg) => { @@ -27,5 +29,77 @@ export const createWebsocketMessage = ( }; export const decodeWebsocketMessage = (bytes: Uint8Array): WebsocketMessage => { - return Cbor.decode(bytes); + const decoded: any = Cbor.decode(bytes); + + // normalize the decoded message + return { + client_key: { + client_principal: Principal.fromUint8Array(decoded.client_key.client_principal), + client_nonce: BigInt(decoded.client_key.client_nonce), + }, + sequence_num: BigInt(decoded.sequence_num), // not clear why cbor deserializes bigint as number + timestamp: BigInt(decoded.timestamp), + content: decoded.content, + is_service_message: decoded.is_service_message, + } +}; + +export const getPollingNonceFromMessage = (message: CanisterOutputMessage): number => { + const nonceStr = message.key.split("_")[1]; + return parseInt(nonceStr); +}; + +export const getNextPollingNonceFromMessages = (messages: CanisterOutputMessage[]): number => { + return getPollingNonceFromMessage(messages[messages.length - 1]) + 1; +}; + +export const getCertifiedMessageKey = async (gatewayIdentity: Promise, nonce: number) => { + const gatewayPrincipal = (await gatewayIdentity).getPrincipal().toText(); + return `${gatewayPrincipal}_${String(nonce).padStart(20, '0')}`; +}; + +export const isValidCertificate = async (canisterId: string, certificate: Uint8Array, tree: Uint8Array, agent: HttpAgent) => { + const canisterPrincipal = Principal.fromText(canisterId); + let cert: Certificate; + + try { + cert = await Certificate.create({ + certificate, + canisterId: canisterPrincipal, + rootKey: agent.rootKey! + }); + } catch (error) { + console.error("Error creating certificate:", error); + return false; + } + + const hashTree = Cbor.decode(tree); + const reconstructed = await reconstruct(hashTree); + const witness = cert.lookup([ + "canister", + canisterPrincipal.toUint8Array(), + "certified_data" + ]); + + if (!witness) { + throw new Error( + "Could not find certified data for this canister in the certificate." + ); + } + + // First validate that the Tree is as good as the certification. + return compare(witness, reconstructed) === 0; +}; + +export const isMessageBodyValid = async (path: string, body: Uint8Array | ArrayBuffer, tree: Uint8Array) => { + const hashTree = Cbor.decode(tree); + const sha = await crypto.subtle.digest("SHA-256", body); + let treeSha = lookup_path(["websocket", path], hashTree); + + if (!treeSha) { + // Allow fallback to index path. + treeSha = lookup_path(["websocket"], hashTree); + } + + return !!treeSha && (compare(sha, treeSha) === 0); }; diff --git a/tests/src/lib.rs b/tests/src/lib.rs index 26932a3..68fe993 100644 --- a/tests/src/lib.rs +++ b/tests/src/lib.rs @@ -66,8 +66,14 @@ fn ws_wipe() { // send a message to the client, usually called by the canister itself #[update] -fn ws_send(client_principal: ClientPrincipal, msg_bytes: Vec) -> CanisterWsSendResult { - ic_websocket_cdk::ws_send(client_principal, msg_bytes) +fn ws_send(client_principal: ClientPrincipal, messages: Vec>) -> CanisterWsSendResult { + for msg_bytes in messages { + match ic_websocket_cdk::ws_send(client_principal, msg_bytes) { + Ok(_) => {}, + Err(e) => return Err(e), + } + } + Ok(()) } // reinitialize the canister diff --git a/tests/test_canister.did b/tests/test_canister.did index da066d3..785e50f 100644 --- a/tests/test_canister.did +++ b/tests/test_canister.did @@ -13,6 +13,6 @@ service : (text, nat64, nat64) -> { // methods used just for debugging/testing "ws_wipe" : () -> (); - "ws_send" : (ClientPrincipal, blob) -> (CanisterWsSendResult); + "ws_send" : (ClientPrincipal, vec blob) -> (CanisterWsSendResult); "reinitialize" : (text, nat64, nat64) -> (); };