diff --git a/Cargo.lock b/Cargo.lock index baab2d4..1b527ef 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -125,9 +125,9 @@ checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" [[package]] name = "bitflags" -version = "2.4.0" +version = "2.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b4682ae6287fcf752ecaabbfcc7b6f9b72aa33933dc23a554d853aea8eea8635" +checksum = "327762f6e5a765692301e5bb513e0d9fef63be86bbc14528052b1cd3e6f03e07" [[package]] name = "block-buffer" @@ -324,9 +324,12 @@ dependencies = [ [[package]] name = "deranged" -version = "0.3.8" +version = "0.3.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f2696e8a945f658fd14dc3b87242e6b80cd0f36ff04ea560fa39082368847946" +checksum = "0f32d04922c60427da6f9fef14d042d9edddef64cb9d4ce0d64d0685fbeb1fd3" +dependencies = [ + "powerfmt", +] [[package]] name = "digest" @@ -857,9 +860,9 @@ dependencies = [ [[package]] name = "ic0" -version = "0.18.12" +version = "0.18.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "16efdbe5d9b0ea368da50aedbf7640a054139569236f1a5249deb5fd9af5a5d5" +checksum = "576c539151d4769fb4d1a0c25c4108dd18facd04c5695b02cf2d226ab4e43aa5" [[package]] name = "idna" @@ -1149,6 +1152,12 @@ dependencies = [ "spki", ] +[[package]] +name = "powerfmt" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "439ee305def115ba05938db6eb1644ff94165c5ab5e9420d1c1bcedbba909391" + [[package]] name = "ppv-lite86" version = "0.2.17" @@ -1193,7 +1202,7 @@ checksum = "7c003ac8c77cb07bb74f5f198bce836a689bcd5a42574612bf14d17bfd08c20e" dependencies = [ "bit-set", "bit-vec", - "bitflags 2.4.0", + "bitflags 2.4.1", "lazy_static", "num-traits", "rand", @@ -1352,9 +1361,9 @@ dependencies = [ [[package]] name = "ring" -version = "0.17.2" +version = "0.17.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "911b295d2d302948838c8ac142da1ee09fa7863163b44e6715bc9357905878b8" +checksum = "fce3045ffa7c981a6ee93f640b538952e155f1ae3a1a02b84547fc7a56b7059a" dependencies = [ "cc", "getrandom", @@ -1372,11 +1381,11 @@ checksum = "d626bb9dae77e28219937af045c257c28bfd3f69333c512553507f5f9798cb76" [[package]] name = "rustix" -version = "0.38.17" +version = "0.38.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f25469e9ae0f3d0047ca8b93fc56843f38e6774f0914a107ff8b41be8be8e0b7" +checksum = "745ecfa778e66b2b63c88a61cb36e0eea109e803b0b86bf9879fbc77c70e86ed" dependencies = [ - "bitflags 2.4.0", + "bitflags 2.4.1", "errno", "libc", "linux-raw-sys", @@ -1476,9 +1485,9 @@ dependencies = [ [[package]] name = "serde" -version = "1.0.188" +version = "1.0.189" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cf9e0fcba69a370eed61bcf2b728575f726b50b55cba78064753d708ddc7549e" +checksum = "8e422a44e74ad4001bdc8eede9a4570ab52f71190e9c076d14369f38b9200537" dependencies = [ "serde_derive", ] @@ -1504,9 +1513,9 @@ dependencies = [ [[package]] name = "serde_derive" -version = "1.0.188" +version = "1.0.189" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4eca7ac642d82aa35b60049a6eccb4be6be75e599bd2e9adb5f875a737654af2" +checksum = "1e48d1f918009ce3145511378cf68d613e3b3d9137d67272562080d68a2b32d5" dependencies = [ "proc-macro2", "quote", @@ -1782,12 +1791,13 @@ dependencies = [ [[package]] name = "time" -version = "0.3.29" +version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "426f806f4089c493dcac0d24c29c01e2c38baf8e30f1b716ee37e83d200b18fe" +checksum = "c4a34ab300f2dee6e562c10a046fc05e358b29f9bf92277f30c3c8d82275f6f5" dependencies = [ "deranged", "itoa", + "powerfmt", "serde", "time-core", "time-macros", @@ -1888,20 +1898,19 @@ checksum = "b6bc1c9ce2b5135ac7f93c72918fc37feb872bdc6a5533a8b85eb4b86bfdae52" [[package]] name = "tracing" -version = "0.1.37" +version = "0.1.39" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8ce8c33a8d48bd45d624a6e523445fd21ec13d3653cd51f681abf67418f54eb8" +checksum = "ee2ef2af84856a50c1d430afce2fdded0a4ec7eda868db86409b4543df0797f9" dependencies = [ - "cfg-if", "pin-project-lite", "tracing-core", ] [[package]] name = "tracing-core" -version = "0.1.31" +version = "0.1.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0955b8137a1df6f1a2e9a37d8a6656291ff0297c1a97c24e0d8425fe2312f79a" +checksum = "c06d3da6113f116aaee68e4d601191614c9053067f9ab7f6edbcb161237daa54" dependencies = [ "once_cell", ] @@ -2105,7 +2114,7 @@ version = "0.22.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ed63aea5ce73d0ff405984102c42de94fc55a6b75765d621c65262469b3c9b53" dependencies = [ - "ring 0.17.2", + "ring 0.17.4", "untrusted 0.9.0", ] @@ -2214,9 +2223,9 @@ checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538" [[package]] name = "winnow" -version = "0.5.16" +version = "0.5.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "037711d82167854aff2018dfd193aa0fef5370f456732f0d5a0c59b0f1b4b907" +checksum = "a3b801d0e0a6726477cc207f60162da452f3a95adb368399bef20a946e06f65c" dependencies = [ "memchr", ] diff --git a/README.md b/README.md index f285ec3..0c884e6 100644 --- a/README.md +++ b/README.md @@ -31,7 +31,6 @@ In order for the frontend clients and the Gateway to work properly, the canister import "./ws_types.did"; service : { - "ws_register" : (CanisterWsRegisterArguments) -> (CanisterWsRegisterResult); "ws_open" : (CanisterWsOpenArguments) -> (CanisterWsOpenResult); "ws_close" : (CanisterWsCloseArguments) -> (CanisterWsCloseResult); "ws_message" : (CanisterWsMessageArguments) -> (CanisterWsMessageResult); diff --git a/src/ic-websocket-cdk/src/lib.rs b/src/ic-websocket-cdk/src/lib.rs index ffcfad7..4f20384 100644 --- a/src/ic-websocket-cdk/src/lib.rs +++ b/src/ic-websocket-cdk/src/lib.rs @@ -2,24 +2,26 @@ use candid::{decode_one, encode_one, CandidType, Principal}; #[cfg(not(test))] use ic_cdk::api::time; use ic_cdk::api::{caller, data_certificate, set_certified_data}; +use ic_cdk_timers::{clear_timer, set_timer, TimerId}; use ic_certified_map::{labeled, labeled_hash, AsHashTree, Hash as ICHash, RbTree}; use serde::{Deserialize, Serialize}; use serde_cbor::Serializer; use sha2::{Digest, Sha256}; use std::fmt; use std::panic; +use std::time::Duration; use std::{cell::RefCell, collections::HashMap, collections::VecDeque, convert::AsRef}; mod logger; /// The label used when constructing the certification tree. const LABEL_WEBSOCKET: &[u8] = b"websocket"; -/// The maximum number of messages returned by [ws_get_messages] at each poll. -const MAX_NUMBER_OF_RETURNED_MESSAGES: usize = 10; -/// The default delay between two consecutive acknowledgements sent to the client. -const DEFAULT_SEND_ACK_DELAY_MS: u64 = 60_000; // 60 seconds -/// The default delay to wait for the client to send a keep alive after receiving an acknowledgement. -const DEFAULT_CLIENT_KEEP_ALIVE_DELAY_MS: u64 = 10_000; // 10 seconds +/// The default maximum number of messages returned by [ws_get_messages] at each poll. +const DEFAULT_MAX_NUMBER_OF_RETURNED_MESSAGES: usize = 10; +/// The default interval at which to send acknowledgements to the client. +const DEFAULT_SEND_ACK_INTERVAL_MS: u64 = 60_000; // 60 seconds +/// The default timeout to wait for the client to send a keep alive after receiving an acknowledgement. +const DEFAULT_CLIENT_KEEP_ALIVE_TIMEOUT_MS: u64 = 10_000; // 10 seconds pub type ClientPrincipal = Principal; #[derive(CandidType, Clone, Deserialize, Serialize, Eq, PartialEq, Debug, Hash)] @@ -158,6 +160,16 @@ impl RegisteredClient { last_keep_alive_timestamp: get_current_time(), } } + + /// Gets the last keep alive timestamp. + fn get_last_keep_alive_timestamp(&self) -> u64 { + self.last_keep_alive_timestamp + } + + /// Set the last keep alive timestamp to the current time. + fn update_last_keep_alive_timestamp(&mut self) { + self.last_keep_alive_timestamp = get_current_time(); + } } thread_local! { @@ -179,12 +191,12 @@ thread_local! { /// - the WS Gateway uses to specify the first index of the certified messages to be returned when polling /// - the client uses as part of the path in the Merkle tree in order to verify the certificate of the messages relayed by the WS Gateway /* flexible */ static OUTGOING_MESSAGE_NONCE: RefCell = RefCell::new(0u64); - /// The callback handlers for the WebSocket. - /* flexible */ static HANDLERS: RefCell = RefCell::new(WsHandlers { - on_open: None, - on_message: None, - on_close: None, - }); + /// The parameters passed in the CDK initialization + /* flexible */ static PARAMS: RefCell = RefCell::new(WsInitParams::default()); + /// The acknowledgement active timer. + /* flexible */ static ACK_TIMER: RefCell> = RefCell::new(None); + /// The keep alive active timer. + /* flexible */ static KEEP_ALIVE_TIMER: RefCell> = RefCell::new(None); } /// Resets all RefCells to their initial state. @@ -360,7 +372,7 @@ fn remove_client(client_key: &ClientKey) { map.borrow_mut().remove(client_key); }); - let handlers = HANDLERS.with(|state| state.borrow().clone()); + let handlers = get_handlers_from_params(); handlers.call_on_close(OnCloseCallbackArgs { client_principal: client_key.client_principal, }); @@ -371,14 +383,16 @@ fn get_message_for_gateway_key(gateway_principal: Principal, nonce: u64) -> Stri } fn get_messages_for_gateway_range(gateway_principal: Principal, nonce: u64) -> (usize, usize) { + let max_number_of_returned_messages = get_params().max_number_of_returned_messages; + MESSAGES_FOR_GATEWAY.with(|m| { let queue_len = m.borrow().len(); if nonce == 0 && queue_len > 0 { // this is the case in which the poller on the gateway restarted - // the range to return is end:last index and start: max(end - MAX_NUMBER_OF_RETURNED_MESSAGES, 0) - let start_index = if queue_len > MAX_NUMBER_OF_RETURNED_MESSAGES { - queue_len - MAX_NUMBER_OF_RETURNED_MESSAGES + // the range to return is end:last index and start: max(end - max_number_of_returned_messages, 0) + let start_index = if queue_len > max_number_of_returned_messages { + queue_len - max_number_of_returned_messages } else { 0 }; @@ -392,8 +406,8 @@ fn get_messages_for_gateway_range(gateway_principal: Principal, nonce: u64) -> ( let start_index = m.borrow().partition_point(|x| x.key < smallest_key); // message at index corresponding to end index is excluded let mut end_index = queue_len; - if end_index - start_index > MAX_NUMBER_OF_RETURNED_MESSAGES { - end_index = start_index + MAX_NUMBER_OF_RETURNED_MESSAGES; + if end_index - start_index > max_number_of_returned_messages { + end_index = start_index + max_number_of_returned_messages; } (start_index, end_index) }) @@ -474,23 +488,64 @@ fn get_cert_for_range(first: &String, last: &String) -> (Vec, Vec) { }) } -#[derive(CandidType, Deserialize)] +fn put_ack_timer_id(timer_id: TimerId) { + ACK_TIMER.with(|timer| timer.borrow_mut().replace(timer_id)); +} + +fn reset_ack_timer() { + let timer_id = ACK_TIMER.with(|timer| timer.borrow_mut().take()); + + if let Some(t_id) = timer_id { + clear_timer(t_id); + } +} + +fn put_keep_alive_timer_id(timer_id: TimerId) { + KEEP_ALIVE_TIMER.with(|timer| timer.borrow_mut().replace(timer_id)); +} + +fn reset_keep_alive_timer() { + let timer_id = KEEP_ALIVE_TIMER.with(|timer| timer.borrow_mut().take()); + + if let Some(t_id) = timer_id { + clear_timer(t_id); + } +} + +fn reset_timers() { + reset_ack_timer(); + reset_keep_alive_timer(); +} + +fn set_params(params: WsInitParams) { + PARAMS.with(|state| *state.borrow_mut() = params); +} + +fn get_params() -> WsInitParams { + PARAMS.with(|state| state.borrow().clone()) +} + +fn get_handlers_from_params() -> WsHandlers { + get_params().get_handlers() +} + +#[derive(CandidType, Debug, Deserialize)] struct CanisterOpenMessageContent { client_key: ClientKey, } -#[derive(CandidType, Deserialize)] +#[derive(CandidType, Debug, Deserialize)] struct CanisterAckMessageContent { last_incoming_sequence_num: u64, } -#[derive(CandidType, Deserialize)] +#[derive(CandidType, Debug, Deserialize)] struct ClientKeepAliveMessageContent { last_incoming_sequence_num: u64, } /// A service message sent by the CDK to the client or vice versa. -#[derive(CandidType, Deserialize)] +#[derive(CandidType, Debug, Deserialize)] enum WebsocketServiceMessageContent { /// Message sent by the **canister** when a client opens a connection. OpenMessage(CanisterOpenMessageContent), @@ -501,7 +556,7 @@ enum WebsocketServiceMessageContent { } impl WebsocketServiceMessageContent { - fn from_candid_bytes(bytes: Vec) -> Result { + fn from_candid_bytes(bytes: &[u8]) -> Result { decode_one(&bytes).map_err(|e| { let mut err = String::from("Error decoding service message content: "); err.push_str(&e.to_string()); @@ -518,6 +573,123 @@ fn send_service_message_to_client( _ws_send(client_key, message_bytes, true) } +/// Schedules a timer to send an acknowledgement message to the client. +/// +/// The timer callback is [send_ack_to_clients_timer_callback]. After the callback is executed, +/// a timer is scheduled to check if the registered clients have sent a keep alive message. +fn schedule_send_ack_to_clients() { + let ack_interval_ms = get_params().send_ack_interval_ms; + let timer_id = set_timer(Duration::from_millis(ack_interval_ms), move || { + send_ack_to_clients_timer_callback(); + + schedule_check_keep_alive(); + }); + + put_ack_timer_id(timer_id); +} + +/// Schedules a timer to check if the registered clients have sent a keep alive message +/// after receiving an acknowledgement message. +/// +/// The timer callback is [check_keep_alive_timer_callback]. After the callback is executed, +/// a timer is scheduled again to send an acknowledgement message to the registered clients. +fn schedule_check_keep_alive() { + let keep_alive_timeout_ms = get_params().keep_alive_timeout_ms; + let timer_id = set_timer(Duration::from_millis(keep_alive_timeout_ms), move || { + check_keep_alive_timer_callback(keep_alive_timeout_ms); + + schedule_send_ack_to_clients(); + }); + + put_keep_alive_timer_id(timer_id); +} + +/// Sends an acknowledgement message to the client. +/// The message contains the current incoming message sequence number for that client, +/// so that the client knows that all the messages it sent have been received by the canister. +fn send_ack_to_clients_timer_callback() { + REGISTERED_CLIENTS.with(|state| { + let map = state.borrow(); + for client_key in map.keys() { + // ignore the error, which shouldn't happen since the client is registered and the sequence number is initialized + match get_expected_incoming_message_from_client_num(client_key) { + Ok(expected_incoming_sequence_num) => { + let ack_message = CanisterAckMessageContent { + last_incoming_sequence_num: expected_incoming_sequence_num - 1, + }; + let message = WebsocketServiceMessageContent::AckMessage(ack_message); + if let Err(e) = send_service_message_to_client(client_key, message) { + // TODO: decide what to do when sending the message fails + + custom_print!( + "[ack-to-clients-timer-cb]: Error sending ack message to client {}: {:?}", + client_key, + e + ); + }; + }, + Err(e) => { + // TODO: decide what to do when getting the expected incoming sequence number fails (shouldn't happen) + custom_print!( + "[ack-to-clients-timer-cb]: Error getting expected incoming sequence number for client {}: {:?}", + client_key, + e, + ); + } + } + } + }); + + custom_print!("[ack-to-clients-timer-cb]: Sent ack messages to all clients"); +} + +/// Checks if the registered clients have sent a keep alive message. +/// If a client has not sent a keep alive message, it is removed from the registered clients. +fn check_keep_alive_timer_callback(keep_alive_timeout_ms: u64) { + let client_keys_to_remove: Vec = REGISTERED_CLIENTS.with(|state| { + let map = state.borrow(); + map.iter() + .filter_map(|(client_key, client_metadata)| { + let last_keep_alive = client_metadata.get_last_keep_alive_timestamp(); + if get_current_time() - last_keep_alive > (keep_alive_timeout_ms * 1_000_000) { + Some(client_key.to_owned()) + } else { + None + } + }) + .collect() + }); + + for client_key in client_keys_to_remove { + remove_client(&client_key); + + custom_print!( + "[check-keep-alive-timer-cb]: Client {} has not sent a keep alive message in the last {} ms and has been removed", + client_key, + keep_alive_timeout_ms + ); + } + + custom_print!("[check-keep-alive-timer-cb]: Checked keep alive messages for all clients"); +} + +fn handle_keep_alive_client_message( + client_key: &ClientKey, + _keep_alive_message: ClientKeepAliveMessageContent, +) -> Result<(), String> { + // TODO: delete messages from the queue that have been acknowledged by the client + + // update the last keep alive timestamp for the client + REGISTERED_CLIENTS.with(|map| { + let mut map = map.borrow_mut(); + if let Some(client_metadata) = map.get_mut(client_key) { + client_metadata.update_last_keep_alive_timestamp(); + } + }); + + Ok(()) +} + /// Internal function used to put the messages in the outgoing messages queue and certify them. fn _ws_send( client_key: &ClientKey, @@ -567,16 +739,18 @@ fn _ws_send( Ok(()) } -fn handle_received_service_message(content: Vec) -> CanisterWsMessageResult { +fn handle_received_service_message( + client_key: &ClientKey, + content: &[u8], +) -> CanisterWsMessageResult { let decoded = WebsocketServiceMessageContent::from_candid_bytes(content)?; match decoded { WebsocketServiceMessageContent::OpenMessage(_) | WebsocketServiceMessageContent::AckMessage(_) => { Err(String::from("Invalid received service message")) }, - WebsocketServiceMessageContent::KeepAliveMessage(_) => { - custom_print!("Service message handling not implemented yet"); - Ok(()) + WebsocketServiceMessageContent::KeepAliveMessage(keep_alive_message) => { + handle_keep_alive_client_message(client_key, keep_alive_message) }, } } @@ -653,13 +827,6 @@ impl WsHandlers { } } -fn initialize_handlers(handlers: WsHandlers) { - HANDLERS.with(|h| { - let mut h = h.borrow_mut(); - *h = handlers; - }); -} - /// Parameters for the IC WebSocket CDK initialization. For default parameters and simpler initialization, use [`WsInitParams::new`]. #[derive(Clone)] pub struct WsInitParams { @@ -667,13 +834,16 @@ pub struct WsInitParams { pub handlers: WsHandlers, /// The principal of the WS Gateway that will be polling the canister. pub gateway_principal: String, + /// The maximum number of messages to be returned in a polling iteration. + /// Defaults to `10`. + pub max_number_of_returned_messages: usize, /// The interval at which to send an acknowledgement message to the client, /// so that the client knows that all the messages it sent have been received by the canister (in milliseconds). /// Defaults to `60_000` (60 seconds). pub send_ack_interval_ms: u64, /// The delay to wait for the client to send a keep alive after receiving an acknowledgement (in milliseconds). /// Defaults to `10_000` (10 seconds). - pub keep_alive_delay_ms: u64, + pub keep_alive_timeout_ms: u64, } impl WsInitParams { @@ -682,8 +852,23 @@ impl WsInitParams { Self { handlers, gateway_principal, - send_ack_interval_ms: DEFAULT_SEND_ACK_DELAY_MS, - keep_alive_delay_ms: DEFAULT_CLIENT_KEEP_ALIVE_DELAY_MS, + ..Default::default() + } + } + + fn get_handlers(&self) -> WsHandlers { + self.handlers.clone() + } +} + +impl Default for WsInitParams { + fn default() -> Self { + Self { + handlers: WsHandlers::default(), + gateway_principal: String::new(), + max_number_of_returned_messages: DEFAULT_MAX_NUMBER_OF_RETURNED_MESSAGES, + send_ack_interval_ms: DEFAULT_SEND_ACK_INTERVAL_MS, + keep_alive_timeout_ms: DEFAULT_CLIENT_KEEP_ALIVE_TIMEOUT_MS, } } } @@ -691,14 +876,19 @@ impl WsInitParams { /// Initialize the CDK by setting the callback handlers and the **principal** of the WS Gateway that /// will be polling the canister. /// -/// Under the hood, an interval (**60 seconds**) is started using [ic_cdk_timers::set_timer] -/// to check if the WS Gateway is still alive. +/// **Note**: Resets the timers under the hood. pub fn init(params: WsInitParams) { // set the handlers specified by the canister that the CDK uses to manage the IC WebSocket connection - initialize_handlers(params.handlers); + set_params(params.clone()); // set the principal of the (only) WS Gateway that will be polling the canister initialize_registered_gateway(¶ms.gateway_principal); + + // reset initial timers + reset_timers(); + + // schedule a timer that will send an acknowledgement message to clients + schedule_send_ack_to_clients(); } /// Handles the WS connection open event sent by the client and relayed by the Gateway. @@ -736,10 +926,7 @@ pub fn ws_open(args: CanisterWsOpenArguments) -> CanisterWsOpenResult { send_service_message_to_client(&client_key, message)?; // call the on_open handler initialized in init() - HANDLERS.with(|h| { - h.borrow() - .call_on_open(OnOpenCallbackArgs { client_principal }); - }); + get_handlers_from_params().call_on_open(OnOpenCallbackArgs { client_principal }); Ok(()) } @@ -794,17 +981,13 @@ pub fn ws_message(args: CanisterWsMessageArguments) -> CanisterWsMessageResult { increment_expected_incoming_message_from_client_num(&client_key)?; if is_service_message { - return handle_received_service_message(content); + return handle_received_service_message(&client_key, &content); } // call the on_message handler initialized in init() - HANDLERS.with(|h| { - // trigger the on_message handler initialized by canister - // create message to send to client - h.borrow().call_on_message(OnMessageCallbackArgs { - client_principal, - message: content, - }); + get_handlers_from_params().call_on_message(OnMessageCallbackArgs { + client_principal, + message: content, }); Ok(()) } @@ -924,31 +1107,32 @@ mod test { static CUSTOM_STATE : RefCell = RefCell::new(CustomState::new()); } - let mut handlers = WsHandlers { + let mut h = WsHandlers { on_open: None, on_message: None, on_close: None, }; - initialize_handlers(handlers); + set_params(WsInitParams { + handlers: h.clone(), + ..Default::default() + }); - HANDLERS.with(|h| { - let h = h.borrow(); + let handlers = get_handlers_from_params(); - assert!(h.on_open.is_none()); - assert!(h.on_message.is_none()); - assert!(h.on_close.is_none()); + assert!(handlers.on_open.is_none()); + assert!(handlers.on_message.is_none()); + assert!(handlers.on_close.is_none()); - h.call_on_open(OnOpenCallbackArgs { - client_principal: test_utils::generate_random_principal(), - }); - h.call_on_message(OnMessageCallbackArgs { - client_principal: test_utils::generate_random_principal(), - message: vec![], - }); - h.call_on_close(OnCloseCallbackArgs { - client_principal: test_utils::generate_random_principal(), - }); + handlers.call_on_open(OnOpenCallbackArgs { + client_principal: test_utils::generate_random_principal(), + }); + handlers.call_on_message(OnMessageCallbackArgs { + client_principal: test_utils::generate_random_principal(), + message: vec![], + }); + handlers.call_on_close(OnCloseCallbackArgs { + client_principal: test_utils::generate_random_principal(), }); // test that the handlers are not called if they are not initialized @@ -976,31 +1160,32 @@ mod test { }); }; - handlers = WsHandlers { + h = WsHandlers { on_open: Some(on_open), on_message: Some(on_message), on_close: Some(on_close), }; - initialize_handlers(handlers); + set_params(WsInitParams { + handlers: h.clone(), + ..Default::default() + }); - HANDLERS.with(|h| { - let h = h.borrow(); + let handlers = get_handlers_from_params(); - assert!(h.on_open.is_some()); - assert!(h.on_message.is_some()); - assert!(h.on_close.is_some()); + assert!(handlers.on_open.is_some()); + assert!(handlers.on_message.is_some()); + assert!(handlers.on_close.is_some()); - h.call_on_open(OnOpenCallbackArgs { - client_principal: test_utils::generate_random_principal(), - }); - h.call_on_message(OnMessageCallbackArgs { - client_principal: test_utils::generate_random_principal(), - message: vec![], - }); - h.call_on_close(OnCloseCallbackArgs { - client_principal: test_utils::generate_random_principal(), - }); + handlers.call_on_open(OnOpenCallbackArgs { + client_principal: test_utils::generate_random_principal(), + }); + handlers.call_on_message(OnMessageCallbackArgs { + client_principal: test_utils::generate_random_principal(), + message: vec![], + }); + handlers.call_on_close(OnCloseCallbackArgs { + client_principal: test_utils::generate_random_principal(), }); // test that the handlers are called if they are initialized @@ -1011,7 +1196,7 @@ mod test { #[test] fn test_ws_handlers_panic_is_handled() { - let handlers = WsHandlers { + let h = WsHandlers { on_open: Some(|_| { panic!("on_open_panic"); }), @@ -1023,9 +1208,12 @@ mod test { }), }; - initialize_handlers(handlers); + set_params(WsInitParams { + handlers: h.clone(), + ..Default::default() + }); - let handlers = HANDLERS.with(|h| h.borrow().clone()); + let handlers = get_handlers_from_params(); let res = panic::catch_unwind(|| { handlers.call_on_open(OnOpenCallbackArgs { @@ -1330,11 +1518,17 @@ mod test { } #[test] - fn test_get_messages_for_gateway_range_larger_than_max(gateway_principal in any::().prop_map(|_| test_utils::get_static_principal())) { + fn test_get_messages_for_gateway_range_larger_than_max(gateway_principal in any::().prop_map(|_| test_utils::get_static_principal()), max_number_of_returned_messages in any::().prop_map(|c| c % 1000)) { // Set up + PARAMS.with(|p| { + *p.borrow_mut() = WsInitParams { + max_number_of_returned_messages, + ..Default::default() + } + }); REGISTERED_GATEWAY.with(|p| *p.borrow_mut() = Some(RegisteredGateway::new(gateway_principal.clone()))); - let messages_count: u64 = (2 * MAX_NUMBER_OF_RETURNED_MESSAGES).try_into().unwrap(); + let messages_count: u64 = (2 * max_number_of_returned_messages).try_into().unwrap(); let test_client_key = test_utils::get_random_client_key(); test_utils::add_messages_for_gateway(test_client_key, gateway_principal, messages_count); @@ -1343,10 +1537,10 @@ mod test { // the case in which the start index is 0 is tested in test_get_messages_for_gateway_range_initial_nonce for i in 1..messages_count + 1 { let (start_index, end_index) = get_messages_for_gateway_range(gateway_principal, i); - let expected_end_index = if (i as usize) + MAX_NUMBER_OF_RETURNED_MESSAGES > messages_count as usize { + let expected_end_index = if (i as usize) + max_number_of_returned_messages > messages_count as usize { messages_count as usize } else { - (i as usize) + MAX_NUMBER_OF_RETURNED_MESSAGES + (i as usize) + max_number_of_returned_messages }; prop_assert_eq!(start_index, i as usize); prop_assert_eq!(end_index, expected_end_index); @@ -1357,8 +1551,14 @@ mod test { } #[test] - fn test_get_messages_for_gateway_initial_nonce(gateway_principal in any::().prop_map(|_| test_utils::get_static_principal()), messages_count in any::().prop_map(|c| c % 100)) { + fn test_get_messages_for_gateway_initial_nonce(gateway_principal in any::().prop_map(|_| test_utils::get_static_principal()), messages_count in any::().prop_map(|c| c % 100), max_number_of_returned_messages in any::().prop_map(|c| c % 1000)) { // Set up + PARAMS.with(|p| { + *p.borrow_mut() = WsInitParams { + max_number_of_returned_messages, + ..Default::default() + } + }); REGISTERED_GATEWAY.with(|p| *p.borrow_mut() = Some(RegisteredGateway::new(gateway_principal.clone()))); let test_client_key = test_utils::get_random_client_key(); @@ -1366,8 +1566,8 @@ mod test { // Test let (start_index, end_index) = get_messages_for_gateway_range(gateway_principal, 0); - let expected_start_index = if (messages_count as usize) > MAX_NUMBER_OF_RETURNED_MESSAGES { - (messages_count as usize) - MAX_NUMBER_OF_RETURNED_MESSAGES + let expected_start_index = if (messages_count as usize) > max_number_of_returned_messages { + (messages_count as usize) - max_number_of_returned_messages } else { 0 }; diff --git a/tests/integration/canister.test.ts b/tests/integration/canister.test.ts index f1f76a7..c63de28 100644 --- a/tests/integration/canister.test.ts +++ b/tests/integration/canister.test.ts @@ -11,7 +11,7 @@ import { gateway2, } from "./utils/actors"; import { - reinitialize, + initializeCdk, wsClose, wsGetMessages, wsMessage, @@ -33,8 +33,10 @@ import { generateClientKey, getRandomClientNonce } from "./utils/random"; import { CanisterOpenMessageContent, WebsocketServiceMessageContent, + decodeWebsocketServiceMessageContent, encodeWebsocketServiceMessageContent, - getServiceMessageFromCanisterMessage, + getServiceMessageContentFromCanisterMessage, + getWebsocketMessageFromCanisterMessage, isClientKeyEq, } from "./utils/idl"; import { @@ -46,21 +48,35 @@ import { getNextPollingNonceFromMessages, } from "./utils/messages"; import { formatClientKey } from "./utils/client"; +import { sleep } from "./utils/helpers"; /** - * The maximum number of messages returned by the **ws_get_messages** method. Set in the CDK. + * The maximum number of messages returned by the **ws_get_messages** method. * - * Value: `10` + * Value: `20` */ -const MAX_NUMBER_OF_RETURNED_MESSAGES = 10; +const DEFAULT_TEST_MAX_NUMBER_OF_RETURNED_MESSAGES = 20; /** + * Add more messages than the max to check the indexes and limits. * @{@link MAX_NUMBER_OF_RETURNED_MESSAGES} + 2 * * Value: `12` */ -const SEND_MESSAGES_COUNT = MAX_NUMBER_OF_RETURNED_MESSAGES + 2; // test with more messages to check the indexes and limits -// const DEFAULT_TEST_SEND_ACK_INTERVAL_MS = 300_000; // 5 minutes to make sure the canister doesn't reset the client -// const DEFAULT_TEST_KEEP_ALIVE_DELAY_MS = 300_000; // 5 minutes to make sure the canister doesn't reset the client +const SEND_MESSAGES_COUNT = DEFAULT_TEST_MAX_NUMBER_OF_RETURNED_MESSAGES + 2; +/** + * The interval between sending acks from the canister. + * Set to a high value to make sure the canister doesn't reset the client while testing other functions. + * + * Value: `300_000` (5 minutes) + */ +const DEFAULT_TEST_SEND_ACK_INTERVAL_MS = 300_000; +/** + * The interval between keep alive checks in the canister. + * Set to a high value to make sure the canister doesn't reset the client while testing other functions. + * + * Value: `300_000` (5 minutes) + */ +const DEFAULT_TEST_KEEP_ALIVE_TIMEOUT_MS = 300_000; let client1Key: ClientKey; let client2Key: ClientKey; @@ -128,13 +144,13 @@ describe("Canister - ws_open", () => { const serviceMessages = filterServiceMessagesFromCanisterMessages(msgs.messages); - expect(isClientKeyEq(serviceMessages[0].client_key, client1Key)).toBe(true); - const openMessage = getServiceMessageFromCanisterMessage(serviceMessages[0]); + expect(isClientKeyEq(serviceMessages[0].client_key, client1Key)).toEqual(true); + const openMessage = getServiceMessageContentFromCanisterMessage(serviceMessages[0]); expect(openMessage).toMatchObject({ OpenMessage: expect.any(Object), }); const openMessageContent = (openMessage as { OpenMessage: CanisterOpenMessageContent }).OpenMessage; - expect(isClientKeyEq(openMessageContent.client_key, client1Key)).toBe(true); + expect(isClientKeyEq(openMessageContent.client_key, client1Key)).toEqual(true); }); it("fails for a client with the same nonce", async () => { @@ -172,12 +188,12 @@ describe("Canister - ws_open", () => { const serviceMessages = filterServiceMessagesFromCanisterMessages(msgs.messages); const serviceMessagesForClient = serviceMessages.filter((msg) => isClientKeyEq(msg.client_key, clientKey)); - const openMessage = getServiceMessageFromCanisterMessage(serviceMessagesForClient[0]); + const openMessage = getServiceMessageContentFromCanisterMessage(serviceMessagesForClient[0]); expect(openMessage).toMatchObject({ OpenMessage: expect.any(Object), }); const openMessageContent = (openMessage as { OpenMessage: CanisterOpenMessageContent }).OpenMessage; - expect(isClientKeyEq(openMessageContent.client_key, clientKey)).toBe(true); + expect(isClientKeyEq(openMessageContent.client_key, clientKey)).toEqual(true); }); }); @@ -185,6 +201,12 @@ describe("Canister - ws_message", () => { beforeAll(async () => { await assignKeysToClients(); + await initializeCdk({ + maxNumberOfReturnedMessages: DEFAULT_TEST_MAX_NUMBER_OF_RETURNED_MESSAGES, + sendAckIntervalMs: DEFAULT_TEST_SEND_ACK_INTERVAL_MS, + keepAliveDelayMs: DEFAULT_TEST_KEEP_ALIVE_TIMEOUT_MS, + }); + await wsOpen({ clientNonce: client1Key.client_nonce, canisterId, @@ -273,7 +295,7 @@ describe("Canister - ws_message", () => { // wrong content encoding const res = await wsMessage({ - message: createWebsocketMessage(client1Key, 1, true, new Uint8Array([1, 2, 3])), + message: createWebsocketMessage(client1Key, 1, new Uint8Array([1, 2, 3]), true), actor: client1, }); @@ -288,7 +310,7 @@ describe("Canister - ws_message", () => { } }; const res2 = await wsMessage({ - message: createWebsocketMessage(client1Key, 2, true, encodeWebsocketServiceMessageContent(wrongServiceMessage)), + message: createWebsocketMessage(client1Key, 2, encodeWebsocketServiceMessageContent(wrongServiceMessage), true), actor: client1, }); @@ -304,7 +326,7 @@ describe("Canister - ws_message", () => { }, }; const res = await wsMessage({ - message: createWebsocketMessage(client1Key, 3, true, encodeWebsocketServiceMessageContent(clientServiceMessage)), + message: createWebsocketMessage(client1Key, 3, encodeWebsocketServiceMessageContent(clientServiceMessage), true), actor: client1, }); @@ -356,11 +378,11 @@ describe("Canister - ws_get_messages (receive)", () => { beforeAll(async () => { await assignKeysToClients(); - // reset the internal timers - // await reinitialize({ - // sendAckIntervalMs: DEFAULT_TEST_SEND_ACK_INTERVAL_MS, - // keepAliveDelayMs: DEFAULT_TEST_KEEP_ALIVE_DELAY_MS, - // }); + await initializeCdk({ + maxNumberOfReturnedMessages: DEFAULT_TEST_MAX_NUMBER_OF_RETURNED_MESSAGES, + sendAckIntervalMs: DEFAULT_TEST_SEND_ACK_INTERVAL_MS, + keepAliveDelayMs: DEFAULT_TEST_KEEP_ALIVE_TIMEOUT_MS, + }); await wsOpen({ clientNonce: client1Key.client_nonce, @@ -378,8 +400,6 @@ describe("Canister - ws_get_messages (receive)", () => { actor: client1, messages, }, true); - - await commonAgent.fetchRootKey(); }); afterAll(async () => { @@ -403,9 +423,9 @@ describe("Canister - ws_get_messages (receive)", () => { }); const messagesResult = (res as { Ok: CanisterOutputCertifiedMessages }).Ok; - expect(messagesResult.messages.length).toBe( - messagesCount - i > MAX_NUMBER_OF_RETURNED_MESSAGES - ? MAX_NUMBER_OF_RETURNED_MESSAGES + expect(messagesResult.messages.length).toEqual( + messagesCount - i > DEFAULT_TEST_MAX_NUMBER_OF_RETURNED_MESSAGES + ? DEFAULT_TEST_MAX_NUMBER_OF_RETURNED_MESSAGES : messagesCount - i ); } @@ -431,7 +451,7 @@ describe("Canister - ws_get_messages (receive)", () => { }); const firstBatchMessagesResult = (firstBatchRes as { Ok: CanisterOutputCertifiedMessages }).Ok; - expect(firstBatchMessagesResult.messages.length).toBe(MAX_NUMBER_OF_RETURNED_MESSAGES); + expect(firstBatchMessagesResult.messages.length).toEqual(DEFAULT_TEST_MAX_NUMBER_OF_RETURNED_MESSAGES); let expectedSequenceNumber = 2; // first is the service open message and the number is incremented before sending let i = 0; @@ -456,14 +476,14 @@ describe("Canister - ws_get_messages (receive)", () => { firstBatchMessagesResult.tree as Uint8Array, commonAgent ) - ).resolves.toBe(true); + ).resolves.toEqual(true); await expect( isMessageBodyValid( message.key, message.content as Uint8Array, firstBatchMessagesResult.tree as Uint8Array, ) - ).resolves.toBe(true); + ).resolves.toEqual(true); expectedSequenceNumber++; i++; @@ -477,7 +497,7 @@ describe("Canister - ws_get_messages (receive)", () => { }); const secondBatchMessagesResult = (secondBatchRes as { Ok: CanisterOutputCertifiedMessages }).Ok; - expect(secondBatchMessagesResult.messages.length).toBe(SEND_MESSAGES_COUNT - MAX_NUMBER_OF_RETURNED_MESSAGES); // remaining from SEND_MESSAGES_COUNT + expect(secondBatchMessagesResult.messages.length).toEqual(SEND_MESSAGES_COUNT - DEFAULT_TEST_MAX_NUMBER_OF_RETURNED_MESSAGES); // remaining from SEND_MESSAGES_COUNT for (const message of secondBatchMessagesResult.messages) { expect(isClientKeyEq(message.client_key, client1Key)).toEqual(true); @@ -500,14 +520,14 @@ describe("Canister - ws_get_messages (receive)", () => { secondBatchMessagesResult.tree as Uint8Array, commonAgent ) - ).resolves.toBe(true); + ).resolves.toEqual(true); await expect( isMessageBodyValid( message.key, message.content as Uint8Array, secondBatchMessagesResult.tree as Uint8Array, ) - ).resolves.toBe(true); + ).resolves.toEqual(true); expectedSequenceNumber++; i++; @@ -521,10 +541,10 @@ describe("Canister - ws_get_messages (receive)", () => { // we expect that the messages returned are the last MAX_NUMBER_OF_RETURNED_MESSAGES const messagesResult = (batchRes as { Ok: CanisterOutputCertifiedMessages }).Ok; - expect(messagesResult.messages.length).toBe(MAX_NUMBER_OF_RETURNED_MESSAGES); + expect(messagesResult.messages.length).toEqual(DEFAULT_TEST_MAX_NUMBER_OF_RETURNED_MESSAGES); - let expectedSequenceNumber = SEND_MESSAGES_COUNT - MAX_NUMBER_OF_RETURNED_MESSAGES + 1 + 1; // +1 for the service open message +1 because the seq num is incremented before sending - let i = SEND_MESSAGES_COUNT - MAX_NUMBER_OF_RETURNED_MESSAGES; + let expectedSequenceNumber = SEND_MESSAGES_COUNT - DEFAULT_TEST_MAX_NUMBER_OF_RETURNED_MESSAGES + 1 + 1; // +1 for the service open message +1 because the seq num is incremented before sending + let i = SEND_MESSAGES_COUNT - DEFAULT_TEST_MAX_NUMBER_OF_RETURNED_MESSAGES; for (const message of messagesResult.messages) { expect(isClientKeyEq(message.client_key, client1Key)).toEqual(true); const websocketMessage = decodeWebsocketMessage(new Uint8Array(message.content)); @@ -546,14 +566,14 @@ describe("Canister - ws_get_messages (receive)", () => { messagesResult.tree as Uint8Array, commonAgent ) - ).resolves.toBe(true); + ).resolves.toEqual(true); await expect( isMessageBodyValid( message.key, message.content as Uint8Array, messagesResult.tree as Uint8Array, ) - ).resolves.toBe(true); + ).resolves.toEqual(true); expectedSequenceNumber++; i++; @@ -565,6 +585,12 @@ describe("Canister - ws_close", () => { beforeAll(async () => { await assignKeysToClients(); + await initializeCdk({ + maxNumberOfReturnedMessages: DEFAULT_TEST_MAX_NUMBER_OF_RETURNED_MESSAGES, + sendAckIntervalMs: DEFAULT_TEST_SEND_ACK_INTERVAL_MS, + keepAliveDelayMs: DEFAULT_TEST_KEEP_ALIVE_TIMEOUT_MS, + }); + await wsOpen({ clientNonce: client1Key.client_nonce, canisterId, @@ -614,6 +640,12 @@ describe("Canister - ws_send", () => { beforeAll(async () => { await assignKeysToClients(); + await initializeCdk({ + maxNumberOfReturnedMessages: DEFAULT_TEST_MAX_NUMBER_OF_RETURNED_MESSAGES, + sendAckIntervalMs: DEFAULT_TEST_SEND_ACK_INTERVAL_MS, + keepAliveDelayMs: DEFAULT_TEST_KEEP_ALIVE_TIMEOUT_MS, + }); + await wsOpen({ clientNonce: client1Key.client_nonce, canisterId, @@ -649,3 +681,202 @@ describe("Canister - ws_send", () => { }); }); }); + +describe("Messages acknowledgement", () => { + beforeAll(async () => { + await assignKeysToClients(); + }); + + afterEach(async () => { + await wsWipe(); + }); + + it("client should receive ack messages", async () => { + const sendAckIntervalMs = 5_000; // 5 seconds + await initializeCdk({ + maxNumberOfReturnedMessages: DEFAULT_TEST_MAX_NUMBER_OF_RETURNED_MESSAGES, + sendAckIntervalMs, + keepAliveDelayMs: DEFAULT_TEST_KEEP_ALIVE_TIMEOUT_MS, // keep alive timeout still high to avoid removing connected client + }); + + await wsOpen({ + clientNonce: client1Key.client_nonce, + canisterId, + clientActor: client1, + }, true); + + const res = await gateway1.ws_get_messages({ + nonce: BigInt(1), // skip the service open message + }); + + expect(res).toMatchObject({ + Ok: { + messages: [], + cert: new Uint8Array(), + tree: new Uint8Array(), + }, + }); + + await wsMessage({ + actor: client1, + message: createWebsocketMessage(client1Key, 1), + }, true); + + // sleep for 5 seconds, which is more than the sendAckIntervalMs due to the previous calls + // so we are sure that the CDK has sent an ack + await sleep(sendAckIntervalMs); + + const res2 = await gateway1.ws_get_messages({ + nonce: BigInt(1), + }); + + expect(res2).toMatchObject({ + Ok: { + messages: expect.any(Array), + cert: expect.any(Uint8Array), + tree: expect.any(Uint8Array), + }, + }); + const messagesResult = (res2 as { Ok: CanisterOutputCertifiedMessages }).Ok; + expect(messagesResult.messages.length).toEqual(1); + + const ackMessage = messagesResult.messages[0]; + expect(isClientKeyEq(ackMessage.client_key, client1Key)).toEqual(true); + expect(getServiceMessageContentFromCanisterMessage(ackMessage)).toMatchObject({ + AckMessage: { + last_incoming_sequence_num: BigInt(1), + } + }); + + // check if the certification is correct + await expect( + isValidCertificate( + canisterId, + messagesResult.cert as Uint8Array, + messagesResult.tree as Uint8Array, + commonAgent + ) + ).resolves.toEqual(true); + await expect( + isMessageBodyValid( + ackMessage.key, + ackMessage.content as Uint8Array, + messagesResult.tree as Uint8Array, + ) + ).resolves.toEqual(true); + }); + + it("client is removed if keep alive timeout is reached", async () => { + const sendAckIntervalMs = 2_000; // 2 seconds + const keepAliveDelayMs = 5_000; // 5 seconds + await initializeCdk({ + maxNumberOfReturnedMessages: DEFAULT_TEST_MAX_NUMBER_OF_RETURNED_MESSAGES, + sendAckIntervalMs, + keepAliveDelayMs, + }); + + await wsOpen({ + clientNonce: client1Key.client_nonce, + canisterId, + clientActor: client1, + }, true); + + await sleep(sendAckIntervalMs); + + const res = await gateway1.ws_get_messages({ + nonce: BigInt(1), // skip the service open message + }); + + const messagesResult = (res as { Ok: CanisterOutputCertifiedMessages }).Ok; + expect(messagesResult.messages.length).toEqual(1); + // just check if the received message is a service message and belongs to the client + expect(isClientKeyEq(messagesResult.messages[0].client_key, client1Key)).toEqual(true); + const websocketMessage = getWebsocketMessageFromCanisterMessage(messagesResult.messages[0]); + expect(websocketMessage.is_service_message).toEqual(true); + + await sleep(keepAliveDelayMs); + + // to check if the client is not registered anymore, we try to send a message + const res2 = await wsSend({ + clientPrincipal: client1Key.client_principal, + actor: client1, + messages: [{ text: "test" }], + }); + + expect(res2).toMatchObject({ + Err: `client with principal ${client1Key.client_principal.toText()} doesn't have an open connection`, + }); + }); + + it("client is not removed if it sends a keep alive before timeout", async () => { + const sendAckIntervalMs = 3_000; // 3 seconds + const keepAliveDelayMs = 5_000; // 5 seconds + await initializeCdk({ + maxNumberOfReturnedMessages: DEFAULT_TEST_MAX_NUMBER_OF_RETURNED_MESSAGES, + sendAckIntervalMs, + keepAliveDelayMs, + }); + + await wsOpen({ + clientNonce: client1Key.client_nonce, + canisterId, + clientActor: client1, + }, true); + + await sleep(sendAckIntervalMs); + + let res = await gateway1.ws_get_messages({ + nonce: BigInt(1), // skip the service open message + }); + let messagesResult = (res as { Ok: CanisterOutputCertifiedMessages }).Ok; + expect(messagesResult.messages.length).toEqual(1); + + const keepAliveMessage: WebsocketServiceMessageContent = { + KeepAliveMessage: { + last_incoming_sequence_num: BigInt(1), // not relevant + }, + }; + await wsMessage({ + actor: client1, + message: createWebsocketMessage(client1Key, 1, encodeWebsocketServiceMessageContent(keepAliveMessage), true), + }, true); + + await sleep(keepAliveDelayMs); + + // send a message to the canister to see the the sequence number increase in the ack message + await wsMessage({ + actor: client1, + message: createWebsocketMessage(client1Key, 2), + }, true); + + // wait to receive the next acknowledgement + await sleep(sendAckIntervalMs); + + res = await gateway1.ws_get_messages({ + nonce: BigInt(1), // skip the service open message + }); + + messagesResult = (res as { Ok: CanisterOutputCertifiedMessages }).Ok; + expect(messagesResult.messages.length).toEqual(2); + + let expectedClientSequenceNum = 0; + let expectedCanisterSequenceNum = 2; // first message is skipped and sequence number starts from 1 + for (const message of messagesResult.messages) { + expect(isClientKeyEq(message.client_key, client1Key)).toEqual(true); + + const websocketMessage = getWebsocketMessageFromCanisterMessage(message); + expect(websocketMessage.is_service_message).toEqual(true); + expect(websocketMessage.sequence_num).toEqual(expectedCanisterSequenceNum); + + const serviceMessageContent = decodeWebsocketServiceMessageContent(websocketMessage.content as Uint8Array); + expect(serviceMessageContent).toMatchObject({ + AckMessage: { + last_incoming_sequence_num: BigInt(expectedClientSequenceNum), + }, + }); + + expectedClientSequenceNum++; + expectedCanisterSequenceNum++; + } + }); +}); diff --git a/tests/integration/utils/api.ts b/tests/integration/utils/api.ts index 731ed54..f999ec9 100644 --- a/tests/integration/utils/api.ts +++ b/tests/integration/utils/api.ts @@ -100,18 +100,19 @@ export const wsWipe = async () => { await anonymousClient.ws_wipe(); }; -type ReinitializeArgs = { +type InitializeCdkArgs = { + maxNumberOfReturnedMessages: number, sendAckIntervalMs: number, keepAliveDelayMs: number, }; /** - * Used to reinitialize the canister with the provided intervals. - * @param args {@link ReinitializeArgs} + * Used to initialize the CDK again with the provided parameters. + * @param args {@link InitializeCdkArgs} */ -export const reinitialize = async (args: ReinitializeArgs) => { +export const initializeCdk = async (args: InitializeCdkArgs) => { const gatewayPrincipal = (await gateway1Data.identity).getPrincipal().toText(); - await anonymousClient.reinitialize(gatewayPrincipal, BigInt(args.sendAckIntervalMs), BigInt(args.keepAliveDelayMs)); + await anonymousClient.initialize(gatewayPrincipal, BigInt(args.maxNumberOfReturnedMessages), BigInt(args.sendAckIntervalMs), BigInt(args.keepAliveDelayMs)); }; type WsSendArgs = { diff --git a/tests/integration/utils/helpers.ts b/tests/integration/utils/helpers.ts new file mode 100644 index 0000000..77366f9 --- /dev/null +++ b/tests/integration/utils/helpers.ts @@ -0,0 +1,5 @@ +/** + * Sleeps for the specified number of milliseconds. + * @param ms Number of milliseconds to sleep + */ +export const sleep = (ms: number) => new Promise((resolve) => setTimeout(resolve, ms)); diff --git a/tests/integration/utils/idl.ts b/tests/integration/utils/idl.ts index 317486e..159572d 100644 --- a/tests/integration/utils/idl.ts +++ b/tests/integration/utils/idl.ts @@ -56,7 +56,7 @@ export const isClientKeyEq = (a: ClientKey, b: ClientKey): boolean => { return a.client_principal.compareTo(b.client_principal) === "eq" && a.client_nonce === b.client_nonce; } -export const getServiceMessageFromCanisterMessage = (msg: CanisterOutputMessage): WebsocketServiceMessageContent => { +export const getServiceMessageContentFromCanisterMessage = (msg: CanisterOutputMessage): WebsocketServiceMessageContent => { const content = getWebsocketMessageFromCanisterMessage(msg).content; return decodeWebsocketServiceMessageContent(content as Uint8Array); } diff --git a/tests/integration/utils/messages.ts b/tests/integration/utils/messages.ts index da2ac8e..55f7f0f 100644 --- a/tests/integration/utils/messages.ts +++ b/tests/integration/utils/messages.ts @@ -14,8 +14,8 @@ export const filterServiceMessagesFromCanisterMessages = (messages: CanisterOutp export const createWebsocketMessage = ( clientKey: ClientKey, sequenceNumber: number, + content?: ArrayBuffer | Uint8Array, isServiceMessage = false, - content?: ArrayBuffer | Uint8Array ): WebsocketMessage => { const websocketMessage: WebsocketMessage = { client_key: clientKey, @@ -62,6 +62,10 @@ export const isValidCertificate = async (canisterId: string, certificate: Uint8A const canisterPrincipal = Principal.fromText(canisterId); let cert: Certificate; + if (!agent["_rootKeyFetched"]) { + await agent.fetchRootKey(); + } + try { cert = await Certificate.create({ certificate, diff --git a/tests/package-lock.json b/tests/package-lock.json index 627ec17..e69ed7a 100644 --- a/tests/package-lock.json +++ b/tests/package-lock.json @@ -10,11 +10,10 @@ "devDependencies": { "@babel/preset-env": "^7.22.9", "@babel/preset-typescript": "^7.22.5", - "@dfinity/agent": "^0.18.1", - "@dfinity/candid": "^0.18.1", - "@dfinity/identity-secp256k1": "^0.18.1", - "@dfinity/principal": "^0.18.1", - "@noble/ed25519": "^2.0.0", + "@dfinity/agent": "^0.19.2", + "@dfinity/candid": "^0.19.2", + "@dfinity/identity-secp256k1": "^0.19.2", + "@dfinity/principal": "^0.19.2", "@types/hdkey": "^2.0.1", "@types/jest": "^29.5.3", "assert": "2.0.0", @@ -2179,43 +2178,50 @@ } }, "node_modules/@dfinity/agent": { - "version": "0.18.1", + "version": "0.19.2", + "resolved": "https://registry.npmjs.org/@dfinity/agent/-/agent-0.19.2.tgz", + "integrity": "sha512-KLRWEjeU9SyyaS7IBVJ9ZUcufxufr55e/kRIyClK157+0pkTG9a8xKjUIMx3QzKvLsqqzXL238nWwdoP6jAD8g==", "dev": true, - "license": "Apache-2.0", "dependencies": { + "@noble/hashes": "^1.3.1", "base64-arraybuffer": "^0.2.0", "borc": "^2.1.1", - "js-sha256": "0.9.0", "simple-cbor": "^0.4.1" }, "peerDependencies": { - "@dfinity/candid": "^0.18.1", - "@dfinity/principal": "^0.18.1" + "@dfinity/candid": "^0.19.2", + "@dfinity/principal": "^0.19.2" } }, "node_modules/@dfinity/candid": { - "version": "0.18.1", - "resolved": "https://registry.npmjs.org/@dfinity/candid/-/candid-0.18.1.tgz", - "integrity": "sha512-/PC3wDnrGcWhaF/veYKevcAAn5A5jK0mRkVKcz0YxK/a78Ai9wMJg0fUk7aoyZryCOz8JPVgR4nK1/zTmvEBHg==", - "dev": true + "version": "0.19.2", + "resolved": "https://registry.npmjs.org/@dfinity/candid/-/candid-0.19.2.tgz", + "integrity": "sha512-X2hCqNMhnnmwtnOc0WnymOZYx3qphjEMuSYbBr7tMIkV7Hwt9BmXXlLnQTxUytTPxf+3he0GcS3KzsSQ9CK8ew==", + "dev": true, + "peerDependencies": { + "@dfinity/principal": "^0.19.2" + } }, "node_modules/@dfinity/identity-secp256k1": { - "version": "0.18.1", + "version": "0.19.2", + "resolved": "https://registry.npmjs.org/@dfinity/identity-secp256k1/-/identity-secp256k1-0.19.2.tgz", + "integrity": "sha512-rzkmNE9n1XWicjt9R4kw3gpZimoalIoR5TLbXPaBILUvNNO7aGZ6Cz5bvrq8HI14yETkxAXNL5rjGwAtoFqAhQ==", "dev": true, - "license": "Apache-2.0", "dependencies": { - "@dfinity/agent": "^0.18.1", + "@dfinity/agent": "^0.19.2", + "@noble/hashes": "^1.3.1", "bip39": "^3.0.4", "bs58check": "^2.1.2", "secp256k1": "^4.0.3" } }, "node_modules/@dfinity/principal": { - "version": "0.18.1", + "version": "0.19.2", + "resolved": "https://registry.npmjs.org/@dfinity/principal/-/principal-0.19.2.tgz", + "integrity": "sha512-vsKN6BKya70bQUsjgKRDlR2lOpv/XpUkCMIiji6rjMtKHIuWEB5Eu3JqZsOuBmWo3A3TT/K/osT9VPm0k4qdYQ==", "dev": true, - "license": "Apache-2.0", "dependencies": { - "js-sha256": "^0.9.0" + "@noble/hashes": "^1.3.1" } }, "node_modules/@expo/bunyan": { @@ -3819,16 +3825,17 @@ "semver": "bin/semver.js" } }, - "node_modules/@noble/ed25519": { - "version": "2.0.0", + "node_modules/@noble/hashes": { + "version": "1.3.2", + "resolved": "https://registry.npmjs.org/@noble/hashes/-/hashes-1.3.2.tgz", + "integrity": "sha512-MVC8EAQp7MvEcm30KWENFjgR+Mkmf+D189XJTkFIlwohU5hcBbn1ZkKq7KVTi2Hme3PMGF390DaL52beVrIihQ==", "dev": true, - "funding": [ - { - "type": "individual", - "url": "https://paulmillr.com/funding/" - } - ], - "license": "MIT" + "engines": { + "node": ">= 16" + }, + "funding": { + "url": "https://paulmillr.com/funding/" + } }, "node_modules/@nodelib/fs.scandir": { "version": "2.1.5", @@ -9665,11 +9672,6 @@ "optional": true, "peer": true }, - "node_modules/js-sha256": { - "version": "0.9.0", - "dev": true, - "license": "MIT" - }, "node_modules/js-tokens": { "version": "4.0.0", "dev": true, @@ -15954,36 +15956,44 @@ } }, "@dfinity/agent": { - "version": "0.18.1", + "version": "0.19.2", + "resolved": "https://registry.npmjs.org/@dfinity/agent/-/agent-0.19.2.tgz", + "integrity": "sha512-KLRWEjeU9SyyaS7IBVJ9ZUcufxufr55e/kRIyClK157+0pkTG9a8xKjUIMx3QzKvLsqqzXL238nWwdoP6jAD8g==", "dev": true, "requires": { + "@noble/hashes": "^1.3.1", "base64-arraybuffer": "^0.2.0", "borc": "^2.1.1", - "js-sha256": "0.9.0", "simple-cbor": "^0.4.1" } }, "@dfinity/candid": { - "version": "0.18.1", - "resolved": "https://registry.npmjs.org/@dfinity/candid/-/candid-0.18.1.tgz", - "integrity": "sha512-/PC3wDnrGcWhaF/veYKevcAAn5A5jK0mRkVKcz0YxK/a78Ai9wMJg0fUk7aoyZryCOz8JPVgR4nK1/zTmvEBHg==", - "dev": true + "version": "0.19.2", + "resolved": "https://registry.npmjs.org/@dfinity/candid/-/candid-0.19.2.tgz", + "integrity": "sha512-X2hCqNMhnnmwtnOc0WnymOZYx3qphjEMuSYbBr7tMIkV7Hwt9BmXXlLnQTxUytTPxf+3he0GcS3KzsSQ9CK8ew==", + "dev": true, + "requires": {} }, "@dfinity/identity-secp256k1": { - "version": "0.18.1", + "version": "0.19.2", + "resolved": "https://registry.npmjs.org/@dfinity/identity-secp256k1/-/identity-secp256k1-0.19.2.tgz", + "integrity": "sha512-rzkmNE9n1XWicjt9R4kw3gpZimoalIoR5TLbXPaBILUvNNO7aGZ6Cz5bvrq8HI14yETkxAXNL5rjGwAtoFqAhQ==", "dev": true, "requires": { - "@dfinity/agent": "^0.18.1", + "@dfinity/agent": "^0.19.2", + "@noble/hashes": "^1.3.1", "bip39": "^3.0.4", "bs58check": "^2.1.2", "secp256k1": "^4.0.3" } }, "@dfinity/principal": { - "version": "0.18.1", + "version": "0.19.2", + "resolved": "https://registry.npmjs.org/@dfinity/principal/-/principal-0.19.2.tgz", + "integrity": "sha512-vsKN6BKya70bQUsjgKRDlR2lOpv/XpUkCMIiji6rjMtKHIuWEB5Eu3JqZsOuBmWo3A3TT/K/osT9VPm0k4qdYQ==", "dev": true, "requires": { - "js-sha256": "^0.9.0" + "@noble/hashes": "^1.3.1" } }, "@expo/bunyan": { @@ -17200,8 +17210,10 @@ "version": "6.3.3", "dev": true }, - "@noble/ed25519": { - "version": "2.0.0", + "@noble/hashes": { + "version": "1.3.2", + "resolved": "https://registry.npmjs.org/@noble/hashes/-/hashes-1.3.2.tgz", + "integrity": "sha512-MVC8EAQp7MvEcm30KWENFjgR+Mkmf+D189XJTkFIlwohU5hcBbn1ZkKq7KVTi2Hme3PMGF390DaL52beVrIihQ==", "dev": true }, "@nodelib/fs.scandir": { @@ -21321,10 +21333,6 @@ "optional": true, "peer": true }, - "js-sha256": { - "version": "0.9.0", - "dev": true - }, "js-tokens": { "version": "4.0.0", "dev": true diff --git a/tests/package.json b/tests/package.json index 7313deb..a87879d 100644 --- a/tests/package.json +++ b/tests/package.json @@ -10,17 +10,16 @@ ], "scripts": { "generate": "dfx generate test_canister", - "deploy:tests": "dfx deploy test_canister --no-wallet --argument '(\"i3gux-m3hwt-5mh2w-t7wwm-fwx5j-6z6ht-hxguo-t4rfw-qp24z-g5ivt-2qe\", 300_000 : nat64, 300_000 : nat64)'", + "deploy:tests": "dfx deploy test_canister --no-wallet --argument '(\"i3gux-m3hwt-5mh2w-t7wwm-fwx5j-6z6ht-hxguo-t4rfw-qp24z-g5ivt-2qe\", 10 : nat64, 300_000 : nat64, 300_000 : nat64)'", "test:integration": "jest integration" }, "devDependencies": { "@babel/preset-env": "^7.22.9", "@babel/preset-typescript": "^7.22.5", - "@dfinity/agent": "^0.18.1", - "@dfinity/candid": "^0.18.1", - "@dfinity/identity-secp256k1": "^0.18.1", - "@dfinity/principal": "^0.18.1", - "@noble/ed25519": "^2.0.0", + "@dfinity/agent": "^0.19.2", + "@dfinity/candid": "^0.19.2", + "@dfinity/identity-secp256k1": "^0.19.2", + "@dfinity/principal": "^0.19.2", "@types/hdkey": "^2.0.1", "@types/jest": "^29.5.3", "assert": "2.0.0", diff --git a/tests/src/lib.rs b/tests/src/lib.rs index 68fe993..c18c9c5 100644 --- a/tests/src/lib.rs +++ b/tests/src/lib.rs @@ -11,7 +11,12 @@ use ic_websocket_cdk::{ mod canister; #[init] -fn init(gateway_principal: String, send_ack_interval_ms: u64, keep_alive_delay_ms: u64) { +fn init( + gateway_principal: String, + max_number_of_returned_messages: usize, + send_ack_interval_ms: u64, + keep_alive_timeout_ms: u64, +) { let handlers = WsHandlers { on_open: Some(on_open), on_message: Some(on_message), @@ -21,16 +26,27 @@ fn init(gateway_principal: String, send_ack_interval_ms: u64, keep_alive_delay_m let params = WsInitParams { handlers, gateway_principal, + max_number_of_returned_messages, send_ack_interval_ms, - keep_alive_delay_ms, + keep_alive_timeout_ms, }; ic_websocket_cdk::init(params) } #[post_upgrade] -fn post_upgrade(gateway_principal: String, send_ack_interval_ms: u64, keep_alive_delay_ms: u64) { - init(gateway_principal, send_ack_interval_ms, keep_alive_delay_ms); +fn post_upgrade( + gateway_principal: String, + max_number_of_returned_messages: usize, + send_ack_interval_ms: u64, + keep_alive_timeout_ms: u64, +) { + init( + gateway_principal, + max_number_of_returned_messages, + send_ack_interval_ms, + keep_alive_timeout_ms, + ); } // method called by the WS Gateway after receiving FirstMessage from the client @@ -76,9 +92,18 @@ fn ws_send(client_principal: ClientPrincipal, messages: Vec>) -> Caniste Ok(()) } -// reinitialize the canister +// initialize the CK again #[update] -fn reinitialize(gateway_principal: String, send_ack_interval_ms: u64, keep_alive_delay_ms: u64) { - ic_websocket_cdk::wipe(); - init(gateway_principal, send_ack_interval_ms, keep_alive_delay_ms); +fn initialize( + gateway_principal: String, + max_number_of_returned_messages: usize, + send_ack_interval_ms: u64, + keep_alive_delay_ms: u64, +) { + init( + gateway_principal, + max_number_of_returned_messages, + send_ack_interval_ms, + keep_alive_delay_ms, + ); } diff --git a/tests/test_canister.did b/tests/test_canister.did index 785e50f..763924e 100644 --- a/tests/test_canister.did +++ b/tests/test_canister.did @@ -5,7 +5,7 @@ type CanisterWsSendResult = variant { Err : text; }; -service : (text, nat64, nat64) -> { +service : (text, nat64, nat64, nat64) -> { "ws_open" : (CanisterWsOpenArguments) -> (CanisterWsOpenResult); "ws_close" : (CanisterWsCloseArguments) -> (CanisterWsCloseResult); "ws_message" : (CanisterWsMessageArguments) -> (CanisterWsMessageResult); @@ -14,5 +14,5 @@ service : (text, nat64, nat64) -> { // methods used just for debugging/testing "ws_wipe" : () -> (); "ws_send" : (ClientPrincipal, vec blob) -> (CanisterWsSendResult); - "reinitialize" : (text, nat64, nat64) -> (); + "initialize" : (text, nat64, nat64, nat64) -> (); };