From 77d6c7cdc5417d774db3dac30c1755fe150fd26b Mon Sep 17 00:00:00 2001 From: massimoalbarello Date: Fri, 10 Nov 2023 17:42:41 +0100 Subject: [PATCH 01/27] wip: multiple gateways + unit tests --- src/ic-websocket-cdk/src/lib.rs | 129 +++++++++++------- .../src/tests/integration_tests/a_ws_open.rs | 5 + .../tests/integration_tests/utils/actor.rs | 8 +- src/ic-websocket-cdk/src/tests/unit_tests.rs | 74 ++++++---- src/test_canister/src/lib.rs | 12 +- 5 files changed, 151 insertions(+), 77 deletions(-) diff --git a/src/ic-websocket-cdk/src/lib.rs b/src/ic-websocket-cdk/src/lib.rs index f0f4cf9..9aff3da 100644 --- a/src/ic-websocket-cdk/src/lib.rs +++ b/src/ic-websocket-cdk/src/lib.rs @@ -77,6 +77,7 @@ pub type CanisterWsSendResult = Result<(), String>; #[derive(CandidType, Clone, Deserialize, Serialize, Eq, PartialEq, Debug)] pub struct CanisterWsOpenArguments { pub(crate) client_nonce: u64, + gateway_principal: Principal, } /// The arguments for [ws_close]. @@ -167,13 +168,15 @@ fn get_current_time() -> u64 { #[derive(Clone, Debug, Eq, PartialEq)] struct RegisteredClient { last_keep_alive_timestamp: u64, + gateway_principal: Principal, } impl RegisteredClient { /// Creates a new instance of RegisteredClient. - fn new() -> Self { + fn new(gateway_principal: Principal) -> Self { Self { last_keep_alive_timestamp: get_current_time(), + gateway_principal, } } @@ -202,9 +205,9 @@ thread_local! { /// Keeps track of the Merkle tree used for certified queries /* flexible */ static CERT_TREE: RefCell> = RefCell::new(RbTree::new()); /// Keeps track of the principal of the WS Gateway which polls the canister - /* flexible */ static REGISTERED_GATEWAY: RefCell> = RefCell::new(None); - /// Keeps track of the messages that have to be sent to the WS Gateway - /* flexible */ static MESSAGES_FOR_GATEWAY: RefCell> = RefCell::new(VecDeque::new()); + /* flexible */ static REGISTERED_GATEWAYS: RefCell>> = RefCell::new(None); + /// Keeps track of the messages that have to be sent to each authorized WS Gateway + /* flexible */ static MESSAGES_FOR_GATEWAYS: RefCell>> = RefCell::new(HashMap::new()); /// Keeps track of the nonce which: /// - the WS Gateway uses to specify the first index of the certified messages to be returned when polling /// - the client uses as part of the path in the Merkle tree in order to verify the certificate of the messages relayed by the WS Gateway @@ -245,7 +248,7 @@ fn reset_internal_state() { CERT_TREE.with(|t| { t.replace(RbTree::new()); }); - MESSAGES_FOR_GATEWAY.with(|m| *m.borrow_mut() = VecDeque::new()); + MESSAGES_FOR_GATEWAYS.with(|m| *m.borrow_mut() = HashMap::new()); OUTGOING_MESSAGE_NONCE.with(|next_id| next_id.replace(INITIAL_OUTGOING_MESSAGE_NONCE)); } @@ -303,25 +306,45 @@ fn check_registered_client(client_key: &ClientKey) -> Result<(), String> { Ok(()) } +fn get_gateway_principal_from_registered_client( + client_key: &ClientKey, +) -> Result { + check_registered_client(client_key)?; + let gateway_principal = REGISTERED_CLIENTS.with(|map| { + map.borrow() + .get(client_key) + .expect("must be registered") + .gateway_principal + }); + Ok(gateway_principal) +} + fn add_client_to_wait_for_keep_alive(client_key: &ClientKey) { CLIENTS_WAITING_FOR_KEEP_ALIVE.with(|clients| { clients.borrow_mut().insert(client_key.clone()); }); } -fn initialize_registered_gateway(gateway_principal: &str) { - REGISTERED_GATEWAY.with(|p| { - let gateway_principal = - Principal::from_text(gateway_principal).expect("invalid gateway principal"); - *p.borrow_mut() = Some(RegisteredGateway::new(gateway_principal)); +fn initialize_registered_gateways(gateways_principals: Vec) { + REGISTERED_GATEWAYS.with(|p| { + let registered_gateways = gateways_principals + .iter() + .map(|s| { + RegisteredGateway::new(Principal::from_text(s).expect("invalid gateway principal")) + }) + .collect(); + *p.borrow_mut() = Some(registered_gateways); }); } -fn get_registered_gateway_principal() -> Principal { - REGISTERED_GATEWAY.with(|g| { +fn get_registered_gateways_principals() -> Vec { + REGISTERED_GATEWAYS.with(|g| { g.borrow() + .as_ref() .expect("gateway should be initialized") - .gateway_principal + .iter() + .map(|registered_gateway| registered_gateway.gateway_principal) + .collect() }) } @@ -417,8 +440,8 @@ fn get_message_for_gateway_key(gateway_principal: Principal, nonce: u64) -> Stri fn get_messages_for_gateway_range(gateway_principal: Principal, nonce: u64) -> (usize, usize) { let max_number_of_returned_messages = get_params().max_number_of_returned_messages; - MESSAGES_FOR_GATEWAY.with(|m| { - let queue_len = m.borrow().len(); + MESSAGES_FOR_GATEWAYS.with(|h| { + let queue_len = h.borrow().get(&gateway_principal).expect("TODO").len(); if nonce == 0 && queue_len > 0 { // this is the case in which the poller on the gateway restarted @@ -435,7 +458,11 @@ fn get_messages_for_gateway_range(gateway_principal: Principal, nonce: u64) -> ( // smallest key used to determine the first message from the queue which has to be returned to the WS Gateway let smallest_key = get_message_for_gateway_key(gateway_principal, nonce); // partition the queue at the message which has the key with the nonce specified as argument to get_cert_messages - let start_index = m.borrow().partition_point(|x| x.key < smallest_key); + let start_index = h + .borrow() + .get(&gateway_principal) + .expect("TODO") + .partition_point(|x| x.key < smallest_key); // message at index corresponding to end index is excluded let mut end_index = queue_len; if end_index - start_index > max_number_of_returned_messages { @@ -445,20 +472,31 @@ fn get_messages_for_gateway_range(gateway_principal: Principal, nonce: u64) -> ( }) } -fn get_messages_for_gateway(start_index: usize, end_index: usize) -> Vec { - MESSAGES_FOR_GATEWAY.with(|m| { +fn get_messages_for_gateway( + gateway_principal: Principal, + start_index: usize, + end_index: usize, +) -> Vec { + MESSAGES_FOR_GATEWAYS.with(|h| { let mut messages: Vec = Vec::with_capacity(end_index - start_index); for index in start_index..end_index { - messages.push(m.borrow().get(index).unwrap().clone()); + messages.push( + h.borrow() + .get(&gateway_principal) + .expect("TODO") + .get(index) + .unwrap() + .clone(), + ); } messages }) } -/// Gets the messages in MESSAGES_FOR_GATEWAY starting from the one with the specified nonce +/// Gets the messages in MESSAGES_FOR_GATEWAYS starting from the one with the specified nonce fn get_cert_messages(gateway_principal: Principal, nonce: u64) -> CanisterWsGetMessagesResult { let (start_index, end_index) = get_messages_for_gateway_range(gateway_principal, nonce); - let messages = get_messages_for_gateway(start_index, end_index); + let messages = get_messages_for_gateway(gateway_principal, start_index, end_index); if messages.is_empty() { return Ok(CanisterOutputCertifiedMessages { @@ -480,17 +518,14 @@ fn get_cert_messages(gateway_principal: Principal, nonce: u64) -> CanisterWsGetM } fn is_registered_gateway(principal: Principal) -> bool { - let registered_gateway_principal = get_registered_gateway_principal(); - return registered_gateway_principal == principal; + get_registered_gateways_principals().contains(&principal) } -/// Checks if the caller of the method is the same as the one that was registered during the initialization of the CDK -fn check_is_registered_gateway(input_principal: Principal) -> Result<(), String> { - let gateway_principal = get_registered_gateway_principal(); - // check if the caller is the same as the one that was registered during the initialization of the CDK - if gateway_principal != input_principal { +/// Checks if the caller of the method is one of the authorized WS Gateways that have been registered during the initialization of the CDK +fn check_is_registered_gateway(principal: Principal) -> Result<(), String> { + if !is_registered_gateway(principal) { return Err(String::from( - "caller is not the gateway that has been registered during CDK initialization", + "caller is not one of the authorized gateways that have been registered during CDK initialization", )); } Ok(()) @@ -729,11 +764,8 @@ fn _ws_send( msg_bytes: Vec, is_service_message: bool, ) -> CanisterWsSendResult { - // check if the client is registered - check_registered_client(client_key)?; - - // get the principal of the gateway that is polling the canister - let gateway_principal = get_registered_gateway_principal(); + // get the principal of the gateway that the client is connected to + let gateway_principal = get_gateway_principal_from_registered_client(client_key)?; // the nonce in key is used by the WS Gateway to determine the message to start in the polling iteration // the key is also passed to the client in order to validate the body of the certified message @@ -759,15 +791,18 @@ fn _ws_send( // certify data put_cert_for_message(key.clone(), &content); - MESSAGES_FOR_GATEWAY.with(|m| { + MESSAGES_FOR_GATEWAYS.with(|h| { // messages in the queue are inserted with contiguous and increasing nonces // (from beginning to end of the queue) as ws_send is called sequentially, the nonce // is incremented by one in each call, and the message is pushed at the end of the queue - m.borrow_mut().push_back(CanisterOutputMessage { - client_key: client_key.clone(), - content, - key, - }); + h.borrow_mut() + .get_mut(&gateway_principal) + .expect("TODO") + .push_back(CanisterOutputMessage { + client_key: client_key.clone(), + content, + key, + }); }); Ok(()) } @@ -888,8 +923,8 @@ impl WsHandlers { pub struct WsInitParams { /// The callback handlers for the WebSocket. pub handlers: WsHandlers, - /// The principal of the WS Gateway that will be polling the canister. - pub gateway_principal: String, + /// The principals of the WS Gateways that are authorized to poll the canister. + pub gateway_principals: Vec, /// The maximum number of messages to be returned in a polling iteration. /// Defaults to `10`. pub max_number_of_returned_messages: usize, @@ -910,10 +945,10 @@ pub struct WsInitParams { impl WsInitParams { /// Creates a new instance of WsInitParams, with default interval values. - pub fn new(handlers: WsHandlers, gateway_principal: String) -> Self { + pub fn new(handlers: WsHandlers, gateway_principals: Vec) -> Self { Self { handlers, - gateway_principal, + gateway_principals, ..Default::default() } } @@ -938,7 +973,7 @@ impl Default for WsInitParams { fn default() -> Self { Self { handlers: WsHandlers::default(), - gateway_principal: String::new(), + gateway_principals: Vec::new(), max_number_of_returned_messages: DEFAULT_MAX_NUMBER_OF_RETURNED_MESSAGES, send_ack_interval_ms: DEFAULT_SEND_ACK_INTERVAL_MS, keep_alive_timeout_ms: DEFAULT_CLIENT_KEEP_ALIVE_TIMEOUT_MS, @@ -961,7 +996,7 @@ pub fn init(params: WsInitParams) { set_params(params.clone()); // set the principal of the (only) WS Gateway that will be polling the canister - initialize_registered_gateway(¶ms.gateway_principal); + initialize_registered_gateways(params.gateway_principals); // reset initial timers reset_timers(); @@ -995,7 +1030,7 @@ pub fn ws_open(args: CanisterWsOpenArguments) -> CanisterWsOpenResult { } // initialize client maps - let new_client = RegisteredClient::new(); + let new_client = RegisteredClient::new(args.gateway_principal); add_client(client_key.clone(), new_client); let open_message = CanisterOpenMessageContent { diff --git a/src/ic-websocket-cdk/src/tests/integration_tests/a_ws_open.rs b/src/ic-websocket-cdk/src/tests/integration_tests/a_ws_open.rs index 163749a..bc91aa1 100644 --- a/src/ic-websocket-cdk/src/tests/integration_tests/a_ws_open.rs +++ b/src/ic-websocket-cdk/src/tests/integration_tests/a_ws_open.rs @@ -16,6 +16,7 @@ use super::utils::{ fn test_1_fail_for_an_anonymous_client() { let args = CanisterWsOpenArguments { client_nonce: generate_random_client_nonce(), + gateway_principal: GATEWAY_1.deref().to_owned(), }; let res = call_ws_open(&Principal::anonymous(), args); assert_eq!( @@ -28,6 +29,7 @@ fn test_1_fail_for_an_anonymous_client() { fn test_2_fails_for_the_registered_gateway() { let args = CanisterWsOpenArguments { client_nonce: generate_random_client_nonce(), + gateway_principal: GATEWAY_1.deref().to_owned(), }; let res = call_ws_open(GATEWAY_1.deref(), args); assert_eq!( @@ -43,6 +45,7 @@ fn test_3_should_open_a_connection() { let client_1_key = CLIENT_1_KEY.deref(); let args = CanisterWsOpenArguments { client_nonce: client_1_key.client_nonce, + gateway_principal: GATEWAY_1.deref().to_owned(), }; let res = call_ws_open(CLIENT_1.deref(), args); assert_eq!(res, CanisterWsOpenResult::Ok(())); @@ -73,6 +76,7 @@ fn test_4_fails_for_a_client_with_the_same_nonce() { let client_1_key = CLIENT_1_KEY.deref(); let args = CanisterWsOpenArguments { client_nonce: client_1_key.client_nonce, + gateway_principal: GATEWAY_1.deref().to_owned(), }; let res = call_ws_open(CLIENT_1.deref(), args); assert_eq!( @@ -91,6 +95,7 @@ fn test_5_should_open_a_connection_for_the_same_client_with_a_different_nonce() }; let args = CanisterWsOpenArguments { client_nonce: client_key.client_nonce, + gateway_principal: GATEWAY_1.deref().to_owned(), }; let res = call_ws_open(&client_key.client_principal, args); assert_eq!(res, CanisterWsOpenResult::Ok(())); diff --git a/src/ic-websocket-cdk/src/tests/integration_tests/utils/actor.rs b/src/ic-websocket-cdk/src/tests/integration_tests/utils/actor.rs index 9a7e0b6..c033592 100644 --- a/src/ic-websocket-cdk/src/tests/integration_tests/utils/actor.rs +++ b/src/ic-websocket-cdk/src/tests/integration_tests/utils/actor.rs @@ -4,7 +4,12 @@ use pocket_ic::WasmResult; use super::test_env::TEST_ENV; pub mod ws_open { - use crate::{CanisterWsOpenArguments, CanisterWsOpenResult, ClientKey}; + use std::ops::Deref; + + use crate::{ + tests::integration_tests::utils::clients::GATEWAY_1, CanisterWsOpenArguments, + CanisterWsOpenResult, ClientKey, + }; use super::*; @@ -39,6 +44,7 @@ pub mod ws_open { pub(crate) fn call_ws_open_for_client_key_with_panic(client_key: &ClientKey) { let args = CanisterWsOpenArguments { client_nonce: client_key.client_nonce, + gateway_principal: GATEWAY_1.deref().to_owned(), }; call_ws_open_with_panic(&client_key.client_principal, args); } diff --git a/src/ic-websocket-cdk/src/tests/unit_tests.rs b/src/ic-websocket-cdk/src/tests/unit_tests.rs index 5bb02b2..16a23ae 100644 --- a/src/ic-websocket-cdk/src/tests/unit_tests.rs +++ b/src/ic-websocket-cdk/src/tests/unit_tests.rs @@ -8,7 +8,7 @@ mod test_utils { use super::{ get_message_for_gateway_key, CanisterOutputMessage, ClientKey, RegisteredClient, - MESSAGES_FOR_GATEWAY, + MESSAGES_FOR_GATEWAYS, }; fn generate_random_key_pair() -> Ed25519KeyPair { @@ -27,7 +27,7 @@ mod test_utils { } pub(super) fn generate_random_registered_client() -> RegisteredClient { - RegisteredClient::new() + RegisteredClient::new(Principal::anonymous()) } pub fn get_static_principal() -> Principal { @@ -48,19 +48,22 @@ mod test_utils { gateway_principal: Principal, count: u64, ) { - MESSAGES_FOR_GATEWAY.with(|m| { + MESSAGES_FOR_GATEWAYS.with(|m| { for i in 0..count { - m.borrow_mut().push_back(CanisterOutputMessage { - client_key: client_key.clone(), - key: get_message_for_gateway_key(gateway_principal.clone(), i), - content: vec![], - }); + m.borrow_mut() + .get_mut(&gateway_principal) + .expect("TODO") + .push_back(CanisterOutputMessage { + client_key: client_key.clone(), + key: get_message_for_gateway_key(gateway_principal.clone(), i), + content: vec![], + }); } }); } pub fn clean_messages_for_gateway() { - MESSAGES_FOR_GATEWAY.with(|m| m.borrow_mut().clear()); + MESSAGES_FOR_GATEWAYS.with(|m| m.borrow_mut().clear()); } } @@ -68,7 +71,7 @@ mod test_utils { #[test] #[should_panic = "gateway should be initialized"] fn test_get_gateway_principal_not_set() { - get_registered_gateway_principal(); + get_registered_gateways_principals(); } #[test] @@ -231,13 +234,13 @@ fn test_current_time() { proptest! { #[test] fn test_initialize_registered_gateway(test_gateway_principal in any::().prop_map(|_| test_utils::generate_random_principal())) { - initialize_registered_gateway(&test_gateway_principal.to_string()); + initialize_registered_gateways(vec![test_gateway_principal.to_string()]); - REGISTERED_GATEWAY.with(|p| { + REGISTERED_GATEWAYS.with(|p| { let p = p.borrow(); assert!(p.is_some()); assert_eq!( - p.unwrap(), + p.to_owned().unwrap()[0], RegisteredGateway::new(test_gateway_principal) ); }); @@ -278,9 +281,9 @@ proptest! { #[test] fn test_get_gateway_principal(test_gateway_principal in any::().prop_map(|_| test_utils::generate_random_principal())) { // Set up - REGISTERED_GATEWAY.with(|p| *p.borrow_mut() = Some(RegisteredGateway::new(test_gateway_principal.clone()))); + REGISTERED_GATEWAYS.with(|p| *p.borrow_mut() = Some(vec![RegisteredGateway::new(test_gateway_principal.clone())])); - let actual_gateway_principal = get_registered_gateway_principal(); + let actual_gateway_principal = get_registered_gateways_principals()[0]; prop_assert_eq!(actual_gateway_principal, test_gateway_principal); } @@ -478,7 +481,12 @@ proptest! { fn test_get_messages_for_gateway_range_empty(messages_count in any::().prop_map(|c| c % 1000)) { // Set up let gateway_principal = test_utils::generate_random_principal(); - REGISTERED_GATEWAY.with(|p| *p.borrow_mut() = Some(RegisteredGateway::new(gateway_principal.clone()))); + REGISTERED_GATEWAYS.with(|p| *p.borrow_mut() = Some(vec![RegisteredGateway::new(gateway_principal.clone())])); + MESSAGES_FOR_GATEWAYS.with(|m| *m.borrow_mut() = { + let mut m = HashMap::new(); + m.insert(gateway_principal.clone(), VecDeque::new()); + m + }); // Test // we ask for a random range of messages to check if it always returns the same range for empty messages @@ -492,7 +500,12 @@ proptest! { #[test] fn test_get_messages_for_gateway_range_smaller_than_max(gateway_principal in any::().prop_map(|_| test_utils::get_static_principal())) { // Set up - REGISTERED_GATEWAY.with(|p| *p.borrow_mut() = Some(RegisteredGateway::new(gateway_principal.clone()))); + REGISTERED_GATEWAYS.with(|p| *p.borrow_mut() = Some(vec![RegisteredGateway::new(gateway_principal.clone())])); + MESSAGES_FOR_GATEWAYS.with(|m| *m.borrow_mut() = { + let mut m = HashMap::new(); + m.insert(gateway_principal.clone(), VecDeque::new()); + m + }); let messages_count = 4; let test_client_key = test_utils::get_random_client_key(); @@ -520,7 +533,12 @@ proptest! { ..Default::default() } }); - REGISTERED_GATEWAY.with(|p| *p.borrow_mut() = Some(RegisteredGateway::new(gateway_principal.clone()))); + REGISTERED_GATEWAYS.with(|p| *p.borrow_mut() = Some(vec![RegisteredGateway::new(gateway_principal.clone())])); + MESSAGES_FOR_GATEWAYS.with(|m| *m.borrow_mut() = { + let mut m = HashMap::new(); + m.insert(gateway_principal.clone(), VecDeque::new()); + m + }); let messages_count: u64 = (2 * max_number_of_returned_messages).try_into().unwrap(); let test_client_key = test_utils::get_random_client_key(); @@ -553,7 +571,12 @@ proptest! { ..Default::default() } }); - REGISTERED_GATEWAY.with(|p| *p.borrow_mut() = Some(RegisteredGateway::new(gateway_principal.clone()))); + REGISTERED_GATEWAYS.with(|p| *p.borrow_mut() = Some(vec![RegisteredGateway::new(gateway_principal.clone())])); + MESSAGES_FOR_GATEWAYS.with(|m| *m.borrow_mut() = { + let mut m = HashMap::new(); + m.insert(gateway_principal.clone(), VecDeque::new()); + m + }); let test_client_key = test_utils::get_random_client_key(); test_utils::add_messages_for_gateway(test_client_key, gateway_principal, messages_count); @@ -575,7 +598,12 @@ proptest! { #[test] fn test_get_messages_for_gateway(gateway_principal in any::().prop_map(|_| test_utils::get_static_principal()), messages_count in any::().prop_map(|c| c % 100)) { // Set up - REGISTERED_GATEWAY.with(|p| *p.borrow_mut() = Some(RegisteredGateway::new(gateway_principal.clone()))); + REGISTERED_GATEWAYS.with(|p| *p.borrow_mut() = Some(vec![RegisteredGateway::new(gateway_principal.clone())])); + MESSAGES_FOR_GATEWAYS.with(|m| *m.borrow_mut() = { + let mut m = HashMap::new(); + m.insert(gateway_principal.clone(), VecDeque::new()); + m + }); let test_client_key = test_utils::get_random_client_key(); test_utils::add_messages_for_gateway(test_client_key, gateway_principal, messages_count); @@ -584,7 +612,7 @@ proptest! { // add one to test the out of range index for i in 0..messages_count + 1 { let (start_index, end_index) = get_messages_for_gateway_range(gateway_principal, i); - let messages = get_messages_for_gateway(start_index, end_index); + let messages = get_messages_for_gateway(gateway_principal, start_index, end_index); // check if the messages returned are the ones we expect for (j, message) in messages.iter().enumerate() { @@ -600,14 +628,14 @@ proptest! { #[test] fn test_check_is_registered_gateway(test_gateway_principal in any::().prop_map(|_| test_utils::generate_random_principal())) { // Set up - REGISTERED_GATEWAY.with(|p| *p.borrow_mut() = Some(RegisteredGateway::new(test_gateway_principal.clone()))); + REGISTERED_GATEWAYS.with(|p| *p.borrow_mut() = Some(vec![RegisteredGateway::new(test_gateway_principal.clone())])); let actual_result = check_is_registered_gateway(test_gateway_principal); prop_assert!(actual_result.is_ok()); let other_principal = test_utils::generate_random_principal(); let actual_result = check_is_registered_gateway(other_principal); - prop_assert_eq!(actual_result.err(), Some(String::from("caller is not the gateway that has been registered during CDK initialization"))); + prop_assert_eq!(actual_result.err(), Some(String::from("caller is not one of the authorized gateways that have been registered during CDK initialization"))); } #[test] diff --git a/src/test_canister/src/lib.rs b/src/test_canister/src/lib.rs index ff198fa..d105d5a 100644 --- a/src/test_canister/src/lib.rs +++ b/src/test_canister/src/lib.rs @@ -12,7 +12,7 @@ mod canister; #[init] fn init( - gateway_principal: String, + gateway_principals: Vec, max_number_of_returned_messages: usize, send_ack_interval_ms: u64, keep_alive_timeout_ms: u64, @@ -25,7 +25,7 @@ fn init( let params = WsInitParams { handlers, - gateway_principal, + gateway_principals, max_number_of_returned_messages, send_ack_interval_ms, keep_alive_timeout_ms, @@ -36,13 +36,13 @@ fn init( #[post_upgrade] fn post_upgrade( - gateway_principal: String, + gateway_principals: Vec, max_number_of_returned_messages: usize, send_ack_interval_ms: u64, keep_alive_timeout_ms: u64, ) { init( - gateway_principal, + gateway_principals, max_number_of_returned_messages, send_ack_interval_ms, keep_alive_timeout_ms, @@ -98,13 +98,13 @@ fn ws_send(client_principal: ClientPrincipal, messages: Vec>) -> Caniste // initialize the CDK again #[update] fn initialize( - gateway_principal: String, + gateway_principals: Vec, max_number_of_returned_messages: usize, send_ack_interval_ms: u64, keep_alive_delay_ms: u64, ) { init( - gateway_principal, + gateway_principals, max_number_of_returned_messages, send_ack_interval_ms, keep_alive_delay_ms, From d417100bfce11376ce3222d4481512c948e4c099 Mon Sep 17 00:00:00 2001 From: massimoalbarello Date: Fri, 10 Nov 2023 19:15:15 +0100 Subject: [PATCH 02/27] message nonce for each gateway --- src/ic-websocket-cdk/src/lib.rs | 50 ++++++++++++------- .../tests/integration_tests/utils/test_env.rs | 7 ++- src/ic-websocket-cdk/src/tests/unit_tests.rs | 12 +++-- src/ic-websocket-cdk/ws_types.did | 1 + 4 files changed, 44 insertions(+), 26 deletions(-) diff --git a/src/ic-websocket-cdk/src/lib.rs b/src/ic-websocket-cdk/src/lib.rs index 9aff3da..40c4dba 100644 --- a/src/ic-websocket-cdk/src/lib.rs +++ b/src/ic-websocket-cdk/src/lib.rs @@ -48,7 +48,7 @@ pub(crate) struct ClientKey { impl ClientKey { /// Creates a new instance of ClientKey. - pub(crate) fn new(client_principal: ClientPrincipal, client_nonce: u64) -> Self { + fn new(client_principal: ClientPrincipal, client_nonce: u64) -> Self { Self { client_principal, client_nonce, @@ -76,26 +76,26 @@ pub type CanisterWsSendResult = Result<(), String>; /// The arguments for [ws_open]. #[derive(CandidType, Clone, Deserialize, Serialize, Eq, PartialEq, Debug)] pub struct CanisterWsOpenArguments { - pub(crate) client_nonce: u64, + client_nonce: u64, gateway_principal: Principal, } /// The arguments for [ws_close]. #[derive(CandidType, Clone, Deserialize, Serialize, Eq, PartialEq, Debug)] pub struct CanisterWsCloseArguments { - pub(crate) client_key: ClientKey, + client_key: ClientKey, } /// The arguments for [ws_message]. #[derive(CandidType, Clone, Deserialize, Serialize, Eq, PartialEq, Debug)] pub struct CanisterWsMessageArguments { - pub(crate) msg: WebsocketMessage, + msg: WebsocketMessage, } /// The arguments for [ws_get_messages]. #[derive(CandidType, Clone, Deserialize, Serialize, Eq, PartialEq, Debug)] pub struct CanisterWsGetMessagesArguments { - pub(crate) nonce: u64, + nonce: u64, } /// Messages exchanged through the WebSocket. @@ -123,20 +123,20 @@ impl WebsocketMessage { /// Element of the list of messages returned to the WS Gateway after polling. #[derive(CandidType, Clone, Debug, Deserialize, Serialize, Eq, PartialEq)] pub struct CanisterOutputMessage { - pub(crate) client_key: ClientKey, // The client that the gateway will forward the message to or that sent the message. - pub(crate) key: String, // Key for certificate verification. + client_key: ClientKey, // The client that the gateway will forward the message to or that sent the message. + key: String, // Key for certificate verification. #[serde(with = "serde_bytes")] - pub(crate) content: Vec, // The message to be relayed, that contains the application message. + content: Vec, // The message to be relayed, that contains the application message. } /// List of messages returned to the WS Gateway after polling. #[derive(CandidType, Clone, Debug, Deserialize, Serialize, Eq, PartialEq)] pub struct CanisterOutputCertifiedMessages { - pub(crate) messages: Vec, // List of messages. + messages: Vec, // List of messages. #[serde(with = "serde_bytes")] - pub(crate) cert: Vec, // cert+tree constitute the certificate for all returned messages. + cert: Vec, // cert+tree constitute the certificate for all returned messages. #[serde(with = "serde_bytes")] - pub(crate) tree: Vec, // cert+tree constitute the certificate for all returned messages. + tree: Vec, // cert+tree constitute the certificate for all returned messages. } #[derive(Copy, Clone, Debug, Eq, PartialEq)] @@ -211,7 +211,7 @@ thread_local! { /// Keeps track of the nonce which: /// - the WS Gateway uses to specify the first index of the certified messages to be returned when polling /// - the client uses as part of the path in the Merkle tree in order to verify the certificate of the messages relayed by the WS Gateway - /* flexible */ static OUTGOING_MESSAGE_NONCE: RefCell = RefCell::new(INITIAL_OUTGOING_MESSAGE_NONCE); + /* flexible */ static OUTGOING_MESSAGE_NONCE: RefCell> = RefCell::new(HashMap::new()); /// The parameters passed in the CDK initialization /* flexible */ static PARAMS: RefCell = RefCell::new(WsInitParams::default()); /// The acknowledgement active timer. @@ -249,7 +249,7 @@ fn reset_internal_state() { t.replace(RbTree::new()); }); MESSAGES_FOR_GATEWAYS.with(|m| *m.borrow_mut() = HashMap::new()); - OUTGOING_MESSAGE_NONCE.with(|next_id| next_id.replace(INITIAL_OUTGOING_MESSAGE_NONCE)); + OUTGOING_MESSAGE_NONCE.with(|m| *m.borrow_mut() = HashMap::new()); } /// Resets the internal state of the IC WebSocket CDK. @@ -261,12 +261,16 @@ pub fn wipe() { custom_print!("Internal state has been wiped!"); } -fn get_outgoing_message_nonce() -> u64 { - OUTGOING_MESSAGE_NONCE.with(|n| n.borrow().clone()) +fn get_outgoing_message_nonce(gateway_principal: &Principal) -> u64 { + OUTGOING_MESSAGE_NONCE.with(|n| n.borrow().get(gateway_principal).expect("TODO").clone()) } -fn increment_outgoing_message_nonce() { - OUTGOING_MESSAGE_NONCE.with(|n| n.replace_with(|&mut old| old + 1)); +fn increment_outgoing_message_nonce(gateway_principal: &Principal) { + OUTGOING_MESSAGE_NONCE.with(|n| { + let previous_nonce = *n.borrow().get(gateway_principal).expect("TODO"); + n.borrow_mut() + .insert(*gateway_principal, previous_nonce + 1); + }); } fn insert_client(client_key: ClientKey, new_client: RegisteredClient) { @@ -335,6 +339,14 @@ fn initialize_registered_gateways(gateways_principals: Vec) { .collect(); *p.borrow_mut() = Some(registered_gateways); }); + OUTGOING_MESSAGE_NONCE.with(|n| { + for gateway_principal in gateways_principals { + n.borrow_mut().insert( + Principal::from_text(gateway_principal).expect("invalid gateway principal"), + INITIAL_OUTGOING_MESSAGE_NONCE, + ); + } + }); } fn get_registered_gateways_principals() -> Vec { @@ -769,11 +781,11 @@ fn _ws_send( // the nonce in key is used by the WS Gateway to determine the message to start in the polling iteration // the key is also passed to the client in order to validate the body of the certified message - let outgoing_message_nonce = get_outgoing_message_nonce(); + let outgoing_message_nonce = get_outgoing_message_nonce(&gateway_principal); let key = get_message_for_gateway_key(gateway_principal, outgoing_message_nonce); // increment the nonce for the next message - increment_outgoing_message_nonce(); + increment_outgoing_message_nonce(&gateway_principal); // increment the sequence number for the next message to the client increment_outgoing_message_to_client_num(client_key)?; diff --git a/src/ic-websocket-cdk/src/tests/integration_tests/utils/test_env.rs b/src/ic-websocket-cdk/src/tests/integration_tests/utils/test_env.rs index d8e1b74..2cdf7d8 100644 --- a/src/ic-websocket-cdk/src/tests/integration_tests/utils/test_env.rs +++ b/src/ic-websocket-cdk/src/tests/integration_tests/utils/test_env.rs @@ -25,8 +25,9 @@ pub struct TestEnv { root_ic_key: Vec, } +type AuthorizedGateways = Vec; /// (`gateway_principal`, `max_number_or_returned_messages`, `send_ack_interval_ms`, `send_ack_timeout_ms`) -type CanisterInitArgs = (String, u64, u64, u64); +type CanisterInitArgs = (AuthorizedGateways, u64, u64, u64); impl TestEnv { pub fn new() -> Self { @@ -39,8 +40,10 @@ impl TestEnv { pic.add_cycles(canister_id, 1_000_000_000_000_000); let wasm_bytes = load_canister_wasm_from_bin("test_canister.wasm"); + + let authorized_gateways = vec![GATEWAY_1.to_string()]; let arguments: CanisterInitArgs = ( - GATEWAY_1.to_string(), + authorized_gateways, DEFAULT_TEST_MAX_NUMBER_OF_RETURNED_MESSAGES, DEFAULT_TEST_SEND_ACK_INTERVAL_MS, DEFAULT_TEST_KEEP_ALIVE_TIMEOUT_MS, diff --git a/src/ic-websocket-cdk/src/tests/unit_tests.rs b/src/ic-websocket-cdk/src/tests/unit_tests.rs index 16a23ae..cb3ef53 100644 --- a/src/ic-websocket-cdk/src/tests/unit_tests.rs +++ b/src/ic-websocket-cdk/src/tests/unit_tests.rs @@ -249,19 +249,21 @@ proptest! { #[test] fn test_get_outgoing_message_nonce(test_nonce in any::()) { // Set up - OUTGOING_MESSAGE_NONCE.with(|n| *n.borrow_mut() = test_nonce); + let gateway_principal = test_utils::generate_random_principal(); + OUTGOING_MESSAGE_NONCE.with(|n| n.borrow_mut().insert(gateway_principal, test_nonce)); - let actual_nonce = get_outgoing_message_nonce(); + let actual_nonce = get_outgoing_message_nonce(&gateway_principal); prop_assert_eq!(actual_nonce, test_nonce); } #[test] fn test_increment_outgoing_message_nonce(test_nonce in any::()) { // Set up - OUTGOING_MESSAGE_NONCE.with(|n| *n.borrow_mut() = test_nonce); + let gateway_principal = test_utils::generate_random_principal(); + OUTGOING_MESSAGE_NONCE.with(|n| n.borrow_mut().insert(gateway_principal, test_nonce)); - increment_outgoing_message_nonce(); - prop_assert_eq!(get_outgoing_message_nonce(), test_nonce + 1); + increment_outgoing_message_nonce(&gateway_principal); + prop_assert_eq!(get_outgoing_message_nonce(&gateway_principal), test_nonce + 1); } #[test] diff --git a/src/ic-websocket-cdk/ws_types.did b/src/ic-websocket-cdk/ws_types.did index ac375d7..d801a9f 100644 --- a/src/ic-websocket-cdk/ws_types.did +++ b/src/ic-websocket-cdk/ws_types.did @@ -26,6 +26,7 @@ type CanisterOutputCertifiedMessages = record { type CanisterWsOpenArguments = record { client_nonce : nat64; + gateway_principal : principal; }; type CanisterWsOpenResult = variant { From c1cb3d102b93b33ff8a6bfa9757d121994527fce Mon Sep 17 00:00:00 2001 From: massimoalbarello Date: Fri, 10 Nov 2023 19:41:46 +0100 Subject: [PATCH 03/27] fix: create message queue for each authorized gateway during init --- src/ic-websocket-cdk/src/lib.rs | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/src/ic-websocket-cdk/src/lib.rs b/src/ic-websocket-cdk/src/lib.rs index 40c4dba..f93c884 100644 --- a/src/ic-websocket-cdk/src/lib.rs +++ b/src/ic-websocket-cdk/src/lib.rs @@ -340,13 +340,21 @@ fn initialize_registered_gateways(gateways_principals: Vec) { *p.borrow_mut() = Some(registered_gateways); }); OUTGOING_MESSAGE_NONCE.with(|n| { - for gateway_principal in gateways_principals { + for gateway_principal in &gateways_principals { n.borrow_mut().insert( Principal::from_text(gateway_principal).expect("invalid gateway principal"), INITIAL_OUTGOING_MESSAGE_NONCE, ); } }); + MESSAGES_FOR_GATEWAYS.with(|n| { + for gateway_principal in &gateways_principals { + n.borrow_mut().insert( + Principal::from_text(gateway_principal).expect("invalid gateway principal"), + VecDeque::new(), + ); + } + }); } fn get_registered_gateways_principals() -> Vec { From 620fbe05908f296195c828f62ca6719381aecb90 Mon Sep 17 00:00:00 2001 From: massimoalbarello Date: Fri, 10 Nov 2023 19:50:35 +0100 Subject: [PATCH 04/27] fix: integration tests --- .../src/tests/integration_tests/c_ws_get_messages.rs | 2 +- src/ic-websocket-cdk/src/tests/integration_tests/d_ws_close.rs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/ic-websocket-cdk/src/tests/integration_tests/c_ws_get_messages.rs b/src/ic-websocket-cdk/src/tests/integration_tests/c_ws_get_messages.rs index fad4329..44b5163 100644 --- a/src/ic-websocket-cdk/src/tests/integration_tests/c_ws_get_messages.rs +++ b/src/ic-websocket-cdk/src/tests/integration_tests/c_ws_get_messages.rs @@ -28,7 +28,7 @@ fn test_1_fails_if_a_non_registered_gateway_tries_to_get_messages() { assert_eq!( res, CanisterWsGetMessagesResult::Err(String::from( - "caller is not the gateway that has been registered during CDK initialization", + "caller is not one of the authorized gateways that have been registered during CDK initialization", )), ); } diff --git a/src/ic-websocket-cdk/src/tests/integration_tests/d_ws_close.rs b/src/ic-websocket-cdk/src/tests/integration_tests/d_ws_close.rs index 96eb270..abf83a4 100644 --- a/src/ic-websocket-cdk/src/tests/integration_tests/d_ws_close.rs +++ b/src/ic-websocket-cdk/src/tests/integration_tests/d_ws_close.rs @@ -25,7 +25,7 @@ fn test_1_fails_if_gateway_is_not_registered() { assert_eq!( res, CanisterWsCloseResult::Err(String::from( - "caller is not the gateway that has been registered during CDK initialization", + "caller is not one of the authorized gateways that have been registered during CDK initialization", )), ); } From f76a4622720a3840f1811066504d99dfd6204a15 Mon Sep 17 00:00:00 2001 From: Luca8991 Date: Tue, 14 Nov 2023 15:49:55 +0100 Subject: [PATCH 05/27] chore: update readme tests docs --- README.md | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 874620f..db0e1b2 100644 --- a/README.md +++ b/README.md @@ -45,15 +45,17 @@ The **ic-websocket-cdk** library implementation can be found in the [src/ic-webs ### Testing There are two types of tests available: -- **Unit tests**: tests for CDK functions, written in Rust. -- **Integration tests**: for these tests a local IC replica is set up and the CDK is deployed to a [test canister](./tests/src/lib.rs). Tests are written in Node.js and are available in the [tests](./tests/integration/) folder. +- **Unit tests**: tests for CDK functions, written in Rust and available in the [unit_tests.rs](./src/ic-websocket-cdk/src/tests/unit_tests.rs) file. +- **Integration tests**: for these tests the CDK is deployed to a [test canister](./src/test_canister/). These tests are written in Rust and use [PocketIC](https://github.com/dfinity/pocketic) under the hood. They are available in the [integration_tests](./src/ic-websocket-cdk/src/tests/integration_tests/) folder. -There's a script that runs all the tests together, taking care of setting up the replica and deploying the canister. To run the script, execute the following command: +There's a script that runs all the tests together, taking care of setting up the environment (Linux only!) and deploying the canister. To run the script, execute the following command: ```bash ./scripts/test_canister.sh ``` +> If you're on **macOS**, you have to manually download the PocketIC binary ([guide](https://github.com/dfinity/pocketic#download)) and place it in the [bin](./bin/) folder. + ## License MIT License. See [LICENSE](./LICENSE). From aebef670341991223f749ccf9a5b0a6e8590471a Mon Sep 17 00:00:00 2001 From: Luca8991 Date: Tue, 14 Nov 2023 15:59:04 +0100 Subject: [PATCH 06/27] perf: unneeded pub visibility --- src/ic-websocket-cdk/src/lib.rs | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/src/ic-websocket-cdk/src/lib.rs b/src/ic-websocket-cdk/src/lib.rs index f93c884..dba3422 100644 --- a/src/ic-websocket-cdk/src/lib.rs +++ b/src/ic-websocket-cdk/src/lib.rs @@ -42,8 +42,8 @@ const INITIAL_CANISTER_SEQUENCE_NUM: u64 = 0; pub type ClientPrincipal = Principal; #[derive(CandidType, Clone, Deserialize, Serialize, Eq, PartialEq, Debug, Hash)] pub(crate) struct ClientKey { - pub client_principal: ClientPrincipal, - pub client_nonce: u64, + client_principal: ClientPrincipal, + client_nonce: u64, } impl ClientKey { @@ -101,12 +101,12 @@ pub struct CanisterWsGetMessagesArguments { /// Messages exchanged through the WebSocket. #[derive(CandidType, Clone, Debug, Deserialize, Serialize, Eq, PartialEq)] pub(crate) struct WebsocketMessage { - pub client_key: ClientKey, // The client that the gateway will forward the message to or that sent the message. - pub sequence_num: u64, // Both ways, messages should arrive with sequence numbers 0, 1, 2... - pub timestamp: u64, // Timestamp of when the message was made for the recipient to inspect. - pub is_service_message: bool, // Whether the message is a service message sent by the CDK to the client or vice versa. + client_key: ClientKey, // The client that the gateway will forward the message to or that sent the message. + sequence_num: u64, // Both ways, messages should arrive with sequence numbers 0, 1, 2... + timestamp: u64, // Timestamp of when the message was made for the recipient to inspect. + is_service_message: bool, // Whether the message is a service message sent by the CDK to the client or vice versa. #[serde(with = "serde_bytes")] - pub content: Vec, // Application message encoded in binary. + content: Vec, // Application message encoded in binary. } impl WebsocketMessage { @@ -614,17 +614,17 @@ fn get_handlers_from_params() -> WsHandlers { #[derive(CandidType, Debug, Deserialize, PartialEq, Eq)] pub(crate) struct CanisterOpenMessageContent { - pub client_key: ClientKey, + client_key: ClientKey, } #[derive(CandidType, Debug, Deserialize, PartialEq, Eq)] pub(crate) struct CanisterAckMessageContent { - pub last_incoming_sequence_num: u64, + last_incoming_sequence_num: u64, } #[derive(CandidType, Debug, Deserialize, PartialEq, Eq)] pub(crate) struct ClientKeepAliveMessageContent { - pub last_incoming_sequence_num: u64, + last_incoming_sequence_num: u64, } /// A service message sent by the CDK to the client or vice versa. From 1ead87ea4f674d4acc534c7fdf1022ffd2ebc64d Mon Sep 17 00:00:00 2001 From: Luca8991 Date: Wed, 15 Nov 2023 09:06:41 +0100 Subject: [PATCH 07/27] perf: remove expects, variables names, references --- src/ic-websocket-cdk/src/lib.rs | 195 ++++++++++-------- .../integration_tests/c_ws_get_messages.rs | 2 +- .../src/tests/integration_tests/d_ws_close.rs | 2 +- src/ic-websocket-cdk/src/tests/unit_tests.rs | 51 +++-- 4 files changed, 145 insertions(+), 105 deletions(-) diff --git a/src/ic-websocket-cdk/src/lib.rs b/src/ic-websocket-cdk/src/lib.rs index dba3422..9ceedd7 100644 --- a/src/ic-websocket-cdk/src/lib.rs +++ b/src/ic-websocket-cdk/src/lib.rs @@ -151,6 +151,20 @@ impl RegisteredGateway { fn new(gateway_principal: Principal) -> Self { Self { gateway_principal } } + + /// Creates a new instance of RegisteredGateway from a text representation of its principal. + fn from_text_principal(gateway_principal_text: &str) -> Self { + Self::new( + Principal::from_text(gateway_principal_text).expect(&format!( + "invalid gateway principal {gateway_principal_text}" + )), + ) + } + + /// Gets the gateway principal. + fn get_principal(&self) -> &Principal { + &self.gateway_principal + } } fn get_current_time() -> u64 { @@ -204,7 +218,7 @@ thread_local! { /* flexible */ static INCOMING_MESSAGE_FROM_CLIENT_NUM_MAP: RefCell> = RefCell::new(HashMap::new()); /// Keeps track of the Merkle tree used for certified queries /* flexible */ static CERT_TREE: RefCell> = RefCell::new(RbTree::new()); - /// Keeps track of the principal of the WS Gateway which polls the canister + /// Keeps track of the principals of the WS Gateways that poll the canister /* flexible */ static REGISTERED_GATEWAYS: RefCell>> = RefCell::new(None); /// Keeps track of the messages that have to be sent to each authorized WS Gateway /* flexible */ static MESSAGES_FOR_GATEWAYS: RefCell>> = RefCell::new(HashMap::new()); @@ -261,22 +275,29 @@ pub fn wipe() { custom_print!("Internal state has been wiped!"); } -fn get_outgoing_message_nonce(gateway_principal: &Principal) -> u64 { - OUTGOING_MESSAGE_NONCE.with(|n| n.borrow().get(gateway_principal).expect("TODO").clone()) +fn get_outgoing_message_nonce(gateway_principal: &Principal) -> Result { + OUTGOING_MESSAGE_NONCE.with(|n| { + n.borrow() + .get(gateway_principal) + .cloned() + .ok_or(String::from( + "gateway doesn't have an outgoing message nonce", + )) + }) } -fn increment_outgoing_message_nonce(gateway_principal: &Principal) { +fn increment_outgoing_message_nonce(gateway_principal: Principal) -> Result<(), String> { + let previous_nonce = get_outgoing_message_nonce(&gateway_principal)?; OUTGOING_MESSAGE_NONCE.with(|n| { - let previous_nonce = *n.borrow().get(gateway_principal).expect("TODO"); - n.borrow_mut() - .insert(*gateway_principal, previous_nonce + 1); + n.borrow_mut().insert(gateway_principal, previous_nonce + 1); }); + Ok(()) } fn insert_client(client_key: ClientKey, new_client: RegisteredClient) { CURRENT_CLIENT_KEY_MAP.with(|map| { map.borrow_mut() - .insert(client_key.client_principal.clone(), client_key.clone()); + .insert(client_key.client_principal, client_key.clone()); }); REGISTERED_CLIENTS.with(|map| { map.borrow_mut().insert(client_key, new_client); @@ -310,17 +331,13 @@ fn check_registered_client(client_key: &ClientKey) -> Result<(), String> { Ok(()) } -fn get_gateway_principal_from_registered_client( - client_key: &ClientKey, -) -> Result { - check_registered_client(client_key)?; - let gateway_principal = REGISTERED_CLIENTS.with(|map| { +fn get_gateway_principal_from_registered_client(client_key: &ClientKey) -> Principal { + REGISTERED_CLIENTS.with(|map| { map.borrow() .get(client_key) - .expect("must be registered") + .unwrap() // the value exists because we checked that the client is registered .gateway_principal - }); - Ok(gateway_principal) + }) } fn add_client_to_wait_for_keep_alive(client_key: &ClientKey) { @@ -330,40 +347,36 @@ fn add_client_to_wait_for_keep_alive(client_key: &ClientKey) { } fn initialize_registered_gateways(gateways_principals: Vec) { - REGISTERED_GATEWAYS.with(|p| { - let registered_gateways = gateways_principals - .iter() - .map(|s| { - RegisteredGateway::new(Principal::from_text(s).expect("invalid gateway principal")) - }) - .collect(); - *p.borrow_mut() = Some(registered_gateways); - }); - OUTGOING_MESSAGE_NONCE.with(|n| { - for gateway_principal in &gateways_principals { - n.borrow_mut().insert( - Principal::from_text(gateway_principal).expect("invalid gateway principal"), + let registered_gateways: Vec = gateways_principals + .iter() + .map(|s| RegisteredGateway::from_text_principal(s)) + .collect(); + + REGISTERED_GATEWAYS.with(|p| *p.borrow_mut() = Some(registered_gateways.clone())); + + for registered_gateway in registered_gateways.iter() { + OUTGOING_MESSAGE_NONCE.with(|n| { + let mut n = n.borrow_mut(); + n.insert( + *registered_gateway.get_principal(), INITIAL_OUTGOING_MESSAGE_NONCE, ); - } - }); - MESSAGES_FOR_GATEWAYS.with(|n| { - for gateway_principal in &gateways_principals { - n.borrow_mut().insert( - Principal::from_text(gateway_principal).expect("invalid gateway principal"), - VecDeque::new(), - ); - } - }); + }); + + MESSAGES_FOR_GATEWAYS.with(|n| { + let mut n = n.borrow_mut(); + n.insert(*registered_gateway.get_principal(), VecDeque::new()); + }); + } } fn get_registered_gateways_principals() -> Vec { REGISTERED_GATEWAYS.with(|g| { g.borrow() .as_ref() - .expect("gateway should be initialized") + .expect("gateways map should be initialized") .iter() - .map(|registered_gateway| registered_gateway.gateway_principal) + .map(|registered_gateway| *registered_gateway.get_principal()) .collect() }) } @@ -411,12 +424,12 @@ fn get_expected_incoming_message_from_client_num(client_key: &ClientKey) -> Resu } fn increment_expected_incoming_message_from_client_num( - client_key: &ClientKey, + client_key: ClientKey, ) -> Result<(), String> { - let num = get_expected_incoming_message_from_client_num(client_key)?; + let num = get_expected_incoming_message_from_client_num(&client_key)?; INCOMING_MESSAGE_FROM_CLIENT_NUM_MAP.with(|map| { let mut map = map.borrow_mut(); - map.insert(client_key.clone(), num + 1); + map.insert(client_key, num + 1); Ok(()) }) } @@ -453,70 +466,75 @@ fn remove_client(client_key: &ClientKey) { }); } -fn get_message_for_gateway_key(gateway_principal: Principal, nonce: u64) -> String { +fn get_message_for_gateway_key(gateway_principal: &Principal, nonce: u64) -> String { gateway_principal.to_string() + "_" + &format!("{:0>20}", nonce.to_string()) } -fn get_messages_for_gateway_range(gateway_principal: Principal, nonce: u64) -> (usize, usize) { +fn get_messages_for_gateway_range( + gateway_principal: &Principal, + nonce: u64, +) -> Result<(usize, usize), String> { let max_number_of_returned_messages = get_params().max_number_of_returned_messages; + check_is_registered_gateway(gateway_principal)?; + MESSAGES_FOR_GATEWAYS.with(|h| { - let queue_len = h.borrow().get(&gateway_principal).expect("TODO").len(); + let h = h.borrow(); + let messages_queue = h.get(gateway_principal).unwrap(); // the value exists because we just checked that the gateway is registered + let queue_len = messages_queue.len(); if nonce == 0 && queue_len > 0 { // this is the case in which the poller on the gateway restarted - // the range to return is end:last index and start: max(end - max_number_of_returned_messages, 0) + // the range to return is end=(last index) and start=max(end - max_number_of_returned_messages, 0) let start_index = if queue_len > max_number_of_returned_messages { queue_len - max_number_of_returned_messages } else { 0 }; - return (start_index, queue_len); + return Ok((start_index, queue_len)); } // smallest key used to determine the first message from the queue which has to be returned to the WS Gateway let smallest_key = get_message_for_gateway_key(gateway_principal, nonce); // partition the queue at the message which has the key with the nonce specified as argument to get_cert_messages - let start_index = h - .borrow() - .get(&gateway_principal) - .expect("TODO") - .partition_point(|x| x.key < smallest_key); + let start_index = messages_queue.partition_point(|x| x.key < smallest_key); // message at index corresponding to end index is excluded let mut end_index = queue_len; if end_index - start_index > max_number_of_returned_messages { end_index = start_index + max_number_of_returned_messages; } - (start_index, end_index) + Ok((start_index, end_index)) }) } fn get_messages_for_gateway( - gateway_principal: Principal, + gateway_principal: &Principal, start_index: usize, end_index: usize, -) -> Vec { +) -> Result, String> { + check_is_registered_gateway(gateway_principal)?; + MESSAGES_FOR_GATEWAYS.with(|h| { let mut messages: Vec = Vec::with_capacity(end_index - start_index); for index in start_index..end_index { messages.push( h.borrow() .get(&gateway_principal) - .expect("TODO") + .unwrap() // the value exists because we just checked that the gateway is registered .get(index) - .unwrap() + .unwrap() // the value exists because this function is called only after partitioning the queue .clone(), ); } - messages + Ok(messages) }) } -/// Gets the messages in MESSAGES_FOR_GATEWAYS starting from the one with the specified nonce -fn get_cert_messages(gateway_principal: Principal, nonce: u64) -> CanisterWsGetMessagesResult { - let (start_index, end_index) = get_messages_for_gateway_range(gateway_principal, nonce); - let messages = get_messages_for_gateway(gateway_principal, start_index, end_index); +/// Gets the messages in [MESSAGES_FOR_GATEWAYS] starting from the one with the specified nonce +fn get_cert_messages(gateway_principal: &Principal, nonce: u64) -> CanisterWsGetMessagesResult { + let (start_index, end_index) = get_messages_for_gateway_range(gateway_principal, nonce)?; // TODO: test error case + let messages = get_messages_for_gateway(gateway_principal, start_index, end_index)?; // TODO: test error case if messages.is_empty() { return Ok(CanisterOutputCertifiedMessages { @@ -537,15 +555,15 @@ fn get_cert_messages(gateway_principal: Principal, nonce: u64) -> CanisterWsGetM }) } -fn is_registered_gateway(principal: Principal) -> bool { - get_registered_gateways_principals().contains(&principal) +fn is_registered_gateway(principal: &Principal) -> bool { + get_registered_gateways_principals().contains(principal) } -/// Checks if the caller of the method is one of the authorized WS Gateways that have been registered during the initialization of the CDK -fn check_is_registered_gateway(principal: Principal) -> Result<(), String> { +/// Checks if the principal of the authorized WS Gateways that have been registered during the initialization of the CDK +fn check_is_registered_gateway(principal: &Principal) -> Result<(), String> { if !is_registered_gateway(principal) { return Err(String::from( - "caller is not one of the authorized gateways that have been registered during CDK initialization", + "principal is not one of the authorized gateways that have been registered during CDK initialization", )); } Ok(()) @@ -650,7 +668,7 @@ impl WebsocketServiceMessageContent { fn send_service_message_to_client( client_key: &ClientKey, - message: WebsocketServiceMessageContent, + message: &WebsocketServiceMessageContent, ) -> Result<(), String> { let message_bytes = encode_one(&message).unwrap(); _ws_send(client_key, message_bytes, true) @@ -697,7 +715,7 @@ fn send_ack_to_clients_timer_callback() { last_incoming_sequence_num: expected_incoming_sequence_num - 1, }; let message = WebsocketServiceMessageContent::AckMessage(ack_message); - if let Err(e) = send_service_message_to_client(client_key, message) { + if let Err(e) = send_service_message_to_client(client_key, &message) { // TODO: decide what to do when sending the message fails custom_print!( @@ -784,16 +802,21 @@ fn _ws_send( msg_bytes: Vec, is_service_message: bool, ) -> CanisterWsSendResult { + // check if the client is registered + check_registered_client(client_key)?; + // get the principal of the gateway that the client is connected to - let gateway_principal = get_gateway_principal_from_registered_client(client_key)?; + let gateway_principal = get_gateway_principal_from_registered_client(client_key); + check_is_registered_gateway(&gateway_principal)?; // the nonce in key is used by the WS Gateway to determine the message to start in the polling iteration // the key is also passed to the client in order to validate the body of the certified message - let outgoing_message_nonce = get_outgoing_message_nonce(&gateway_principal); - let key = get_message_for_gateway_key(gateway_principal, outgoing_message_nonce); + let outgoing_message_nonce = get_outgoing_message_nonce(&gateway_principal)?; // TODO: test the error case + let message_key = get_message_for_gateway_key(&gateway_principal, outgoing_message_nonce); // increment the nonce for the next message - increment_outgoing_message_nonce(&gateway_principal); + increment_outgoing_message_nonce(gateway_principal)?; // TODO: test the error case + // increment the sequence number for the next message to the client increment_outgoing_message_to_client_num(client_key)?; @@ -806,10 +829,10 @@ fn _ws_send( }; // CBOR serialize message of type WebsocketMessage - let content = websocket_message.cbor_serialize()?; + let message_content = websocket_message.cbor_serialize()?; // certify data - put_cert_for_message(key.clone(), &content); + put_cert_for_message(message_key.clone(), &message_content); MESSAGES_FOR_GATEWAYS.with(|h| { // messages in the queue are inserted with contiguous and increasing nonces @@ -817,11 +840,11 @@ fn _ws_send( // is incremented by one in each call, and the message is pushed at the end of the queue h.borrow_mut() .get_mut(&gateway_principal) - .expect("TODO") + .unwrap() // the value exists because we just checked that the gateway is registered .push_back(CanisterOutputMessage { client_key: client_key.clone(), - content, - key, + content: message_content, + key: message_key, }); }); Ok(()) @@ -1034,7 +1057,7 @@ pub fn ws_open(args: CanisterWsOpenArguments) -> CanisterWsOpenResult { } // avoid gateway opening a connection for its own principal - if is_registered_gateway(client_principal) { + if is_registered_gateway(&client_principal) { return Err(String::from( "caller is the registered gateway which can't open a connection for itself", )); @@ -1057,7 +1080,7 @@ pub fn ws_open(args: CanisterWsOpenArguments) -> CanisterWsOpenResult { client_key: client_key.clone(), }; let message = WebsocketServiceMessageContent::OpenMessage(open_message); - send_service_message_to_client(&client_key, message)?; + send_service_message_to_client(&client_key, &message)?; // call the on_open handler initialized in init() get_handlers_from_params().call_on_open(OnOpenCallbackArgs { client_principal }); @@ -1068,7 +1091,7 @@ pub fn ws_open(args: CanisterWsOpenArguments) -> CanisterWsOpenResult { /// Handles the WS connection close event received from the WS Gateway. pub fn ws_close(args: CanisterWsCloseArguments) -> CanisterWsCloseResult { // the caller must be the gateway that was registered during CDK initialization - check_is_registered_gateway(caller())?; + check_is_registered_gateway(&caller())?; // check if client registered its principal by calling ws_open check_registered_client(&args.client_key)?; @@ -1140,7 +1163,7 @@ pub fn ws_message Deserialize<'a>>( )); } // increase the expected sequence number by 1 - increment_expected_incoming_message_from_client_num(&client_key)?; + increment_expected_incoming_message_from_client_num(client_key.clone())?; if is_service_message { return handle_received_service_message(&client_key, &content); @@ -1158,9 +1181,9 @@ pub fn ws_message Deserialize<'a>>( pub fn ws_get_messages(args: CanisterWsGetMessagesArguments) -> CanisterWsGetMessagesResult { // check if the caller of this method is the WS Gateway that has been set during the initialization of the SDK let gateway_principal = caller(); - check_is_registered_gateway(gateway_principal)?; + check_is_registered_gateway(&gateway_principal)?; - get_cert_messages(gateway_principal, args.nonce) + get_cert_messages(&gateway_principal, args.nonce) } /// Sends a message to the client. The message must already be serialized **using Candid**. diff --git a/src/ic-websocket-cdk/src/tests/integration_tests/c_ws_get_messages.rs b/src/ic-websocket-cdk/src/tests/integration_tests/c_ws_get_messages.rs index 44b5163..7b9bdd5 100644 --- a/src/ic-websocket-cdk/src/tests/integration_tests/c_ws_get_messages.rs +++ b/src/ic-websocket-cdk/src/tests/integration_tests/c_ws_get_messages.rs @@ -28,7 +28,7 @@ fn test_1_fails_if_a_non_registered_gateway_tries_to_get_messages() { assert_eq!( res, CanisterWsGetMessagesResult::Err(String::from( - "caller is not one of the authorized gateways that have been registered during CDK initialization", + "principal is not one of the authorized gateways that have been registered during CDK initialization", )), ); } diff --git a/src/ic-websocket-cdk/src/tests/integration_tests/d_ws_close.rs b/src/ic-websocket-cdk/src/tests/integration_tests/d_ws_close.rs index abf83a4..ee3ba83 100644 --- a/src/ic-websocket-cdk/src/tests/integration_tests/d_ws_close.rs +++ b/src/ic-websocket-cdk/src/tests/integration_tests/d_ws_close.rs @@ -25,7 +25,7 @@ fn test_1_fails_if_gateway_is_not_registered() { assert_eq!( res, CanisterWsCloseResult::Err(String::from( - "caller is not one of the authorized gateways that have been registered during CDK initialization", + "principal is not one of the authorized gateways that have been registered during CDK initialization", )), ); } diff --git a/src/ic-websocket-cdk/src/tests/unit_tests.rs b/src/ic-websocket-cdk/src/tests/unit_tests.rs index cb3ef53..00a76a0 100644 --- a/src/ic-websocket-cdk/src/tests/unit_tests.rs +++ b/src/ic-websocket-cdk/src/tests/unit_tests.rs @@ -55,7 +55,7 @@ mod test_utils { .expect("TODO") .push_back(CanisterOutputMessage { client_key: client_key.clone(), - key: get_message_for_gateway_key(gateway_principal.clone(), i), + key: get_message_for_gateway_key(&gateway_principal, i), content: vec![], }); } @@ -69,7 +69,7 @@ mod test_utils { // we don't need to proptest get_gateway_principal if principal is not set, as it just panics #[test] -#[should_panic = "gateway should be initialized"] +#[should_panic = "gateways map should be initialized"] fn test_get_gateway_principal_not_set() { get_registered_gateways_principals(); } @@ -252,7 +252,9 @@ proptest! { let gateway_principal = test_utils::generate_random_principal(); OUTGOING_MESSAGE_NONCE.with(|n| n.borrow_mut().insert(gateway_principal, test_nonce)); - let actual_nonce = get_outgoing_message_nonce(&gateway_principal); + let res = get_outgoing_message_nonce(&gateway_principal); + prop_assert!(res.is_ok()); + let actual_nonce = res.unwrap(); prop_assert_eq!(actual_nonce, test_nonce); } @@ -262,8 +264,11 @@ proptest! { let gateway_principal = test_utils::generate_random_principal(); OUTGOING_MESSAGE_NONCE.with(|n| n.borrow_mut().insert(gateway_principal, test_nonce)); - increment_outgoing_message_nonce(&gateway_principal); - prop_assert_eq!(get_outgoing_message_nonce(&gateway_principal), test_nonce + 1); + let res = increment_outgoing_message_nonce(gateway_principal); + prop_assert!(res.is_ok()); + let res = get_outgoing_message_nonce(&gateway_principal); + prop_assert!(res.is_ok()); + prop_assert_eq!(res.unwrap(), test_nonce + 1); } #[test] @@ -407,7 +412,7 @@ proptest! { map.borrow_mut().insert(test_client_key.clone(), test_num); }); - let increment_result = increment_expected_incoming_message_from_client_num(&test_client_key); + let increment_result = increment_expected_incoming_message_from_client_num(test_client_key.clone()); prop_assert!(increment_result.is_ok()); let actual_result = INCOMING_MESSAGE_FROM_CLIENT_NUM_MAP.with(|map| map.borrow().get(&test_client_key).unwrap().clone()); @@ -475,7 +480,7 @@ proptest! { #[test] fn test_get_message_for_gateway_key(test_gateway_principal in any::().prop_map(|_| test_utils::generate_random_principal()), test_nonce in any::()) { - let actual_result = get_message_for_gateway_key(test_gateway_principal.clone(), test_nonce); + let actual_result = get_message_for_gateway_key(&test_gateway_principal, test_nonce); prop_assert_eq!(actual_result, test_gateway_principal.to_string() + "_" + &format!("{:0>20}", test_nonce.to_string())); } @@ -493,7 +498,9 @@ proptest! { // Test // we ask for a random range of messages to check if it always returns the same range for empty messages for i in 0..messages_count { - let (start_index, end_index) = get_messages_for_gateway_range(gateway_principal, i); + let res = get_messages_for_gateway_range(&gateway_principal, i); + prop_assert!(res.is_ok()); + let (start_index, end_index) = res.unwrap(); prop_assert_eq!(start_index, 0); prop_assert_eq!(end_index, 0); } @@ -517,7 +524,9 @@ proptest! { // messages are just 4, so we don't exceed the max number of returned messages // add one to test the out of range index for i in 0..messages_count + 1 { - let (start_index, end_index) = get_messages_for_gateway_range(gateway_principal, i); + let res = get_messages_for_gateway_range(&gateway_principal, i); + prop_assert!(res.is_ok()); + let (start_index, end_index) = res.unwrap(); prop_assert_eq!(start_index, i as usize); prop_assert_eq!(end_index, messages_count as usize); } @@ -550,7 +559,9 @@ proptest! { // messages are now 2 * MAX_NUMBER_OF_RETURNED_MESSAGES // the case in which the start index is 0 is tested in test_get_messages_for_gateway_range_initial_nonce for i in 1..messages_count + 1 { - let (start_index, end_index) = get_messages_for_gateway_range(gateway_principal, i); + let res = get_messages_for_gateway_range(&gateway_principal, i); + prop_assert!(res.is_ok()); + let (start_index, end_index) = res.unwrap(); let expected_end_index = if (i as usize) + max_number_of_returned_messages > messages_count as usize { messages_count as usize } else { @@ -584,7 +595,9 @@ proptest! { test_utils::add_messages_for_gateway(test_client_key, gateway_principal, messages_count); // Test - let (start_index, end_index) = get_messages_for_gateway_range(gateway_principal, 0); + let res = get_messages_for_gateway_range(&gateway_principal, 0); + prop_assert!(res.is_ok()); + let (start_index, end_index) = res.unwrap(); let expected_start_index = if (messages_count as usize) > max_number_of_returned_messages { (messages_count as usize) - max_number_of_returned_messages } else { @@ -613,12 +626,16 @@ proptest! { // Test // add one to test the out of range index for i in 0..messages_count + 1 { - let (start_index, end_index) = get_messages_for_gateway_range(gateway_principal, i); - let messages = get_messages_for_gateway(gateway_principal, start_index, end_index); + let res = get_messages_for_gateway_range(&gateway_principal, i); + prop_assert!(res.is_ok()); + let (start_index, end_index) = res.unwrap(); + let res = get_messages_for_gateway(&gateway_principal, start_index, end_index); + prop_assert!(res.is_ok()); + let messages = res.unwrap(); // check if the messages returned are the ones we expect for (j, message) in messages.iter().enumerate() { - let expected_key = get_message_for_gateway_key(gateway_principal.clone(), (start_index + j) as u64); + let expected_key = get_message_for_gateway_key(&gateway_principal, (start_index + j) as u64); prop_assert_eq!(&message.key, &expected_key); } } @@ -632,12 +649,12 @@ proptest! { // Set up REGISTERED_GATEWAYS.with(|p| *p.borrow_mut() = Some(vec![RegisteredGateway::new(test_gateway_principal.clone())])); - let actual_result = check_is_registered_gateway(test_gateway_principal); + let actual_result = check_is_registered_gateway(&test_gateway_principal); prop_assert!(actual_result.is_ok()); let other_principal = test_utils::generate_random_principal(); - let actual_result = check_is_registered_gateway(other_principal); - prop_assert_eq!(actual_result.err(), Some(String::from("caller is not one of the authorized gateways that have been registered during CDK initialization"))); + let actual_result = check_is_registered_gateway(&other_principal); + prop_assert_eq!(actual_result.err(), Some(String::from("principal is not one of the authorized gateways that have been registered during CDK initialization"))); } #[test] From 920840f4a4cac1bb0b6ea5ccf3926d64131681f0 Mon Sep 17 00:00:00 2001 From: Luca8991 Date: Wed, 15 Nov 2023 09:14:04 +0100 Subject: [PATCH 08/27] perf: unneeded results, removed todos --- src/ic-websocket-cdk/src/lib.rs | 31 ++++++++----------- src/ic-websocket-cdk/src/tests/unit_tests.rs | 32 ++++++-------------- 2 files changed, 22 insertions(+), 41 deletions(-) diff --git a/src/ic-websocket-cdk/src/lib.rs b/src/ic-websocket-cdk/src/lib.rs index 9ceedd7..e6fc40f 100644 --- a/src/ic-websocket-cdk/src/lib.rs +++ b/src/ic-websocket-cdk/src/lib.rs @@ -466,18 +466,13 @@ fn remove_client(client_key: &ClientKey) { }); } -fn get_message_for_gateway_key(gateway_principal: &Principal, nonce: u64) -> String { +fn format_message_for_gateway_key(gateway_principal: &Principal, nonce: u64) -> String { gateway_principal.to_string() + "_" + &format!("{:0>20}", nonce.to_string()) } -fn get_messages_for_gateway_range( - gateway_principal: &Principal, - nonce: u64, -) -> Result<(usize, usize), String> { +fn get_messages_for_gateway_range(gateway_principal: &Principal, nonce: u64) -> (usize, usize) { let max_number_of_returned_messages = get_params().max_number_of_returned_messages; - check_is_registered_gateway(gateway_principal)?; - MESSAGES_FOR_GATEWAYS.with(|h| { let h = h.borrow(); let messages_queue = h.get(gateway_principal).unwrap(); // the value exists because we just checked that the gateway is registered @@ -492,11 +487,11 @@ fn get_messages_for_gateway_range( 0 }; - return Ok((start_index, queue_len)); + return (start_index, queue_len); } // smallest key used to determine the first message from the queue which has to be returned to the WS Gateway - let smallest_key = get_message_for_gateway_key(gateway_principal, nonce); + let smallest_key = format_message_for_gateway_key(gateway_principal, nonce); // partition the queue at the message which has the key with the nonce specified as argument to get_cert_messages let start_index = messages_queue.partition_point(|x| x.key < smallest_key); // message at index corresponding to end index is excluded @@ -504,7 +499,7 @@ fn get_messages_for_gateway_range( if end_index - start_index > max_number_of_returned_messages { end_index = start_index + max_number_of_returned_messages; } - Ok((start_index, end_index)) + (start_index, end_index) }) } @@ -512,9 +507,7 @@ fn get_messages_for_gateway( gateway_principal: &Principal, start_index: usize, end_index: usize, -) -> Result, String> { - check_is_registered_gateway(gateway_principal)?; - +) -> Vec { MESSAGES_FOR_GATEWAYS.with(|h| { let mut messages: Vec = Vec::with_capacity(end_index - start_index); for index in start_index..end_index { @@ -527,14 +520,14 @@ fn get_messages_for_gateway( .clone(), ); } - Ok(messages) + messages }) } /// Gets the messages in [MESSAGES_FOR_GATEWAYS] starting from the one with the specified nonce fn get_cert_messages(gateway_principal: &Principal, nonce: u64) -> CanisterWsGetMessagesResult { - let (start_index, end_index) = get_messages_for_gateway_range(gateway_principal, nonce)?; // TODO: test error case - let messages = get_messages_for_gateway(gateway_principal, start_index, end_index)?; // TODO: test error case + let (start_index, end_index) = get_messages_for_gateway_range(gateway_principal, nonce); + let messages = get_messages_for_gateway(gateway_principal, start_index, end_index); if messages.is_empty() { return Ok(CanisterOutputCertifiedMessages { @@ -811,11 +804,11 @@ fn _ws_send( // the nonce in key is used by the WS Gateway to determine the message to start in the polling iteration // the key is also passed to the client in order to validate the body of the certified message - let outgoing_message_nonce = get_outgoing_message_nonce(&gateway_principal)?; // TODO: test the error case - let message_key = get_message_for_gateway_key(&gateway_principal, outgoing_message_nonce); + let outgoing_message_nonce = get_outgoing_message_nonce(&gateway_principal)?; // we never hit the error case because we just checked that the gateway is registered + let message_key = format_message_for_gateway_key(&gateway_principal, outgoing_message_nonce); // increment the nonce for the next message - increment_outgoing_message_nonce(gateway_principal)?; // TODO: test the error case + increment_outgoing_message_nonce(gateway_principal)?; // we never hit the error case because we just checked that the gateway is registered // increment the sequence number for the next message to the client increment_outgoing_message_to_client_num(client_key)?; diff --git a/src/ic-websocket-cdk/src/tests/unit_tests.rs b/src/ic-websocket-cdk/src/tests/unit_tests.rs index 00a76a0..97e0728 100644 --- a/src/ic-websocket-cdk/src/tests/unit_tests.rs +++ b/src/ic-websocket-cdk/src/tests/unit_tests.rs @@ -7,7 +7,7 @@ mod test_utils { use ring::signature::Ed25519KeyPair; use super::{ - get_message_for_gateway_key, CanisterOutputMessage, ClientKey, RegisteredClient, + format_message_for_gateway_key, CanisterOutputMessage, ClientKey, RegisteredClient, MESSAGES_FOR_GATEWAYS, }; @@ -55,7 +55,7 @@ mod test_utils { .expect("TODO") .push_back(CanisterOutputMessage { client_key: client_key.clone(), - key: get_message_for_gateway_key(&gateway_principal, i), + key: format_message_for_gateway_key(&gateway_principal, i), content: vec![], }); } @@ -480,7 +480,7 @@ proptest! { #[test] fn test_get_message_for_gateway_key(test_gateway_principal in any::().prop_map(|_| test_utils::generate_random_principal()), test_nonce in any::()) { - let actual_result = get_message_for_gateway_key(&test_gateway_principal, test_nonce); + let actual_result = format_message_for_gateway_key(&test_gateway_principal, test_nonce); prop_assert_eq!(actual_result, test_gateway_principal.to_string() + "_" + &format!("{:0>20}", test_nonce.to_string())); } @@ -498,9 +498,7 @@ proptest! { // Test // we ask for a random range of messages to check if it always returns the same range for empty messages for i in 0..messages_count { - let res = get_messages_for_gateway_range(&gateway_principal, i); - prop_assert!(res.is_ok()); - let (start_index, end_index) = res.unwrap(); + let (start_index, end_index) = get_messages_for_gateway_range(&gateway_principal, i); prop_assert_eq!(start_index, 0); prop_assert_eq!(end_index, 0); } @@ -524,9 +522,7 @@ proptest! { // messages are just 4, so we don't exceed the max number of returned messages // add one to test the out of range index for i in 0..messages_count + 1 { - let res = get_messages_for_gateway_range(&gateway_principal, i); - prop_assert!(res.is_ok()); - let (start_index, end_index) = res.unwrap(); + let (start_index, end_index) = get_messages_for_gateway_range(&gateway_principal, i); prop_assert_eq!(start_index, i as usize); prop_assert_eq!(end_index, messages_count as usize); } @@ -559,9 +555,7 @@ proptest! { // messages are now 2 * MAX_NUMBER_OF_RETURNED_MESSAGES // the case in which the start index is 0 is tested in test_get_messages_for_gateway_range_initial_nonce for i in 1..messages_count + 1 { - let res = get_messages_for_gateway_range(&gateway_principal, i); - prop_assert!(res.is_ok()); - let (start_index, end_index) = res.unwrap(); + let (start_index, end_index) = get_messages_for_gateway_range(&gateway_principal, i); let expected_end_index = if (i as usize) + max_number_of_returned_messages > messages_count as usize { messages_count as usize } else { @@ -595,9 +589,7 @@ proptest! { test_utils::add_messages_for_gateway(test_client_key, gateway_principal, messages_count); // Test - let res = get_messages_for_gateway_range(&gateway_principal, 0); - prop_assert!(res.is_ok()); - let (start_index, end_index) = res.unwrap(); + let (start_index, end_index) = get_messages_for_gateway_range(&gateway_principal, 0); let expected_start_index = if (messages_count as usize) > max_number_of_returned_messages { (messages_count as usize) - max_number_of_returned_messages } else { @@ -626,16 +618,12 @@ proptest! { // Test // add one to test the out of range index for i in 0..messages_count + 1 { - let res = get_messages_for_gateway_range(&gateway_principal, i); - prop_assert!(res.is_ok()); - let (start_index, end_index) = res.unwrap(); - let res = get_messages_for_gateway(&gateway_principal, start_index, end_index); - prop_assert!(res.is_ok()); - let messages = res.unwrap(); + let (start_index, end_index) = get_messages_for_gateway_range(&gateway_principal, i); + let messages = get_messages_for_gateway(&gateway_principal, start_index, end_index); // check if the messages returned are the ones we expect for (j, message) in messages.iter().enumerate() { - let expected_key = get_message_for_gateway_key(&gateway_principal, (start_index + j) as u64); + let expected_key = format_message_for_gateway_key(&gateway_principal, (start_index + j) as u64); prop_assert_eq!(&message.key, &expected_key); } } From 672c992c1cefd3cf153144f1232ae235759808a5 Mon Sep 17 00:00:00 2001 From: Luca8991 Date: Wed, 15 Nov 2023 11:16:32 +0100 Subject: [PATCH 09/27] feat: gateway data in one map, unit_tests folder --- src/ic-websocket-cdk/src/lib.rs | 165 +++++----- .../{unit_tests.rs => unit_tests/mod.rs} | 294 +++++++----------- .../src/tests/unit_tests/utils.rs | 75 +++++ 3 files changed, 261 insertions(+), 273 deletions(-) rename src/ic-websocket-cdk/src/tests/{unit_tests.rs => unit_tests/mod.rs} (62%) create mode 100644 src/ic-websocket-cdk/src/tests/unit_tests/utils.rs diff --git a/src/ic-websocket-cdk/src/lib.rs b/src/ic-websocket-cdk/src/lib.rs index e6fc40f..98a3e7b 100644 --- a/src/ic-websocket-cdk/src/lib.rs +++ b/src/ic-websocket-cdk/src/lib.rs @@ -77,7 +77,7 @@ pub type CanisterWsSendResult = Result<(), String>; #[derive(CandidType, Clone, Deserialize, Serialize, Eq, PartialEq, Debug)] pub struct CanisterWsOpenArguments { client_nonce: u64, - gateway_principal: Principal, + gateway_principal: GatewayPrincipal, } /// The arguments for [ws_close]. @@ -139,31 +139,37 @@ pub struct CanisterOutputCertifiedMessages { tree: Vec, // cert+tree constitute the certificate for all returned messages. } -#[derive(Copy, Clone, Debug, Eq, PartialEq)] +type GatewayPrincipal = Principal; + +#[derive(Clone, Debug, Default, Eq, PartialEq)] /// Contains data about the registered WS Gateway. struct RegisteredGateway { - /// The principal of the gateway. - gateway_principal: Principal, + /// The queue of the messages that the gateway can poll. + messages_queue: VecDeque, + /// Keeps track of the nonce which: + /// - the WS Gateway uses to specify the first index of the certified messages to be returned when polling + /// - the client uses as part of the path in the Merkle tree in order to verify the certificate of the messages relayed by the WS Gateway + outgoing_message_nonce: u64, } impl RegisteredGateway { /// Creates a new instance of RegisteredGateway. - fn new(gateway_principal: Principal) -> Self { - Self { gateway_principal } + fn new() -> Self { + Self { + messages_queue: VecDeque::new(), + outgoing_message_nonce: INITIAL_OUTGOING_MESSAGE_NONCE, + } } - /// Creates a new instance of RegisteredGateway from a text representation of its principal. - fn from_text_principal(gateway_principal_text: &str) -> Self { - Self::new( - Principal::from_text(gateway_principal_text).expect(&format!( - "invalid gateway principal {gateway_principal_text}" - )), - ) + /// Resets the messages and nonce to the initial values. + fn reset(&mut self) { + self.messages_queue.clear(); + self.outgoing_message_nonce = INITIAL_OUTGOING_MESSAGE_NONCE; } - /// Gets the gateway principal. - fn get_principal(&self) -> &Principal { - &self.gateway_principal + /// Increments the outgoing message nonce by 1. + fn increment_nonce(&mut self) { + self.outgoing_message_nonce += 1; } } @@ -182,12 +188,12 @@ fn get_current_time() -> u64 { #[derive(Clone, Debug, Eq, PartialEq)] struct RegisteredClient { last_keep_alive_timestamp: u64, - gateway_principal: Principal, + gateway_principal: GatewayPrincipal, } impl RegisteredClient { /// Creates a new instance of RegisteredClient. - fn new(gateway_principal: Principal) -> Self { + fn new(gateway_principal: GatewayPrincipal) -> Self { Self { last_keep_alive_timestamp: get_current_time(), gateway_principal, @@ -219,13 +225,7 @@ thread_local! { /// Keeps track of the Merkle tree used for certified queries /* flexible */ static CERT_TREE: RefCell> = RefCell::new(RbTree::new()); /// Keeps track of the principals of the WS Gateways that poll the canister - /* flexible */ static REGISTERED_GATEWAYS: RefCell>> = RefCell::new(None); - /// Keeps track of the messages that have to be sent to each authorized WS Gateway - /* flexible */ static MESSAGES_FOR_GATEWAYS: RefCell>> = RefCell::new(HashMap::new()); - /// Keeps track of the nonce which: - /// - the WS Gateway uses to specify the first index of the certified messages to be returned when polling - /// - the client uses as part of the path in the Merkle tree in order to verify the certificate of the messages relayed by the WS Gateway - /* flexible */ static OUTGOING_MESSAGE_NONCE: RefCell> = RefCell::new(HashMap::new()); + /* flexible */ static REGISTERED_GATEWAYS: RefCell> = RefCell::new(HashMap::new()); /// The parameters passed in the CDK initialization /* flexible */ static PARAMS: RefCell = RefCell::new(WsInitParams::default()); /// The acknowledgement active timer. @@ -262,8 +262,11 @@ fn reset_internal_state() { CERT_TREE.with(|t| { t.replace(RbTree::new()); }); - MESSAGES_FOR_GATEWAYS.with(|m| *m.borrow_mut() = HashMap::new()); - OUTGOING_MESSAGE_NONCE.with(|m| *m.borrow_mut() = HashMap::new()); + REGISTERED_GATEWAYS.with(|map| { + for g in map.borrow_mut().values_mut() { + g.reset(); + } + }); } /// Resets the internal state of the IC WebSocket CDK. @@ -275,23 +278,18 @@ pub fn wipe() { custom_print!("Internal state has been wiped!"); } -fn get_outgoing_message_nonce(gateway_principal: &Principal) -> Result { - OUTGOING_MESSAGE_NONCE.with(|n| { - n.borrow() - .get(gateway_principal) - .cloned() - .ok_or(String::from( - "gateway doesn't have an outgoing message nonce", - )) - }) +fn get_outgoing_message_nonce(gateway_principal: &GatewayPrincipal) -> Result { + let registered_gateway = get_registered_gateway(gateway_principal)?; + Ok(registered_gateway.outgoing_message_nonce) } -fn increment_outgoing_message_nonce(gateway_principal: Principal) -> Result<(), String> { - let previous_nonce = get_outgoing_message_nonce(&gateway_principal)?; - OUTGOING_MESSAGE_NONCE.with(|n| { - n.borrow_mut().insert(gateway_principal, previous_nonce + 1); +fn increment_outgoing_message_nonce(gateway_principal: &GatewayPrincipal) { + REGISTERED_GATEWAYS.with(|map| { + map.borrow_mut() + .get_mut(gateway_principal) + .unwrap() // we should always have a registered gateway at this point + .increment_nonce(); }); - Ok(()) } fn insert_client(client_key: ClientKey, new_client: RegisteredClient) { @@ -331,7 +329,7 @@ fn check_registered_client(client_key: &ClientKey) -> Result<(), String> { Ok(()) } -fn get_gateway_principal_from_registered_client(client_key: &ClientKey) -> Principal { +fn get_gateway_principal_from_registered_client(client_key: &ClientKey) -> GatewayPrincipal { REGISTERED_CLIENTS.with(|map| { map.borrow() .get(client_key) @@ -347,37 +345,27 @@ fn add_client_to_wait_for_keep_alive(client_key: &ClientKey) { } fn initialize_registered_gateways(gateways_principals: Vec) { - let registered_gateways: Vec = gateways_principals - .iter() - .map(|s| RegisteredGateway::from_text_principal(s)) - .collect(); - - REGISTERED_GATEWAYS.with(|p| *p.borrow_mut() = Some(registered_gateways.clone())); - - for registered_gateway in registered_gateways.iter() { - OUTGOING_MESSAGE_NONCE.with(|n| { - let mut n = n.borrow_mut(); - n.insert( - *registered_gateway.get_principal(), - INITIAL_OUTGOING_MESSAGE_NONCE, - ); - }); - - MESSAGES_FOR_GATEWAYS.with(|n| { - let mut n = n.borrow_mut(); - n.insert(*registered_gateway.get_principal(), VecDeque::new()); + for gateway_principal_text in gateways_principals.iter() { + let gateway_principal = Principal::from_text(gateway_principal_text).expect(&format!( + "invalid gateway principal {gateway_principal_text}" + )); + REGISTERED_GATEWAYS.with(|map| { + map.borrow_mut() + .insert(gateway_principal, RegisteredGateway::new()) }); } } -fn get_registered_gateways_principals() -> Vec { - REGISTERED_GATEWAYS.with(|g| { - g.borrow() - .as_ref() - .expect("gateways map should be initialized") - .iter() - .map(|registered_gateway| *registered_gateway.get_principal()) - .collect() +fn get_registered_gateway( + gateway_principal: &GatewayPrincipal, +) -> Result { + REGISTERED_GATEWAYS.with(|map| { + map.borrow() + .get(gateway_principal) + .cloned() + .ok_or(String::from(format!( + "no gateway registered with principal {gateway_principal}" + ))) }) } @@ -466,16 +454,20 @@ fn remove_client(client_key: &ClientKey) { }); } -fn format_message_for_gateway_key(gateway_principal: &Principal, nonce: u64) -> String { +fn format_message_for_gateway_key(gateway_principal: &GatewayPrincipal, nonce: u64) -> String { gateway_principal.to_string() + "_" + &format!("{:0>20}", nonce.to_string()) } -fn get_messages_for_gateway_range(gateway_principal: &Principal, nonce: u64) -> (usize, usize) { +fn get_messages_for_gateway_range( + gateway_principal: &GatewayPrincipal, + nonce: u64, +) -> (usize, usize) { let max_number_of_returned_messages = get_params().max_number_of_returned_messages; - MESSAGES_FOR_GATEWAYS.with(|h| { - let h = h.borrow(); - let messages_queue = h.get(gateway_principal).unwrap(); // the value exists because we just checked that the gateway is registered + REGISTERED_GATEWAYS.with(|map| { + let map = map.borrow(); + let messages_queue = &map.get(gateway_principal).unwrap().messages_queue; // the value exists because we just checked that the gateway is registered + let queue_len = messages_queue.len(); if nonce == 0 && queue_len > 0 { @@ -504,17 +496,18 @@ fn get_messages_for_gateway_range(gateway_principal: &Principal, nonce: u64) -> } fn get_messages_for_gateway( - gateway_principal: &Principal, + gateway_principal: &GatewayPrincipal, start_index: usize, end_index: usize, ) -> Vec { - MESSAGES_FOR_GATEWAYS.with(|h| { + REGISTERED_GATEWAYS.with(|map| { + let map = map.borrow(); + let messages_queue = &map.get(gateway_principal).unwrap().messages_queue; // the value exists because we just checked that the gateway is registered + let mut messages: Vec = Vec::with_capacity(end_index - start_index); for index in start_index..end_index { messages.push( - h.borrow() - .get(&gateway_principal) - .unwrap() // the value exists because we just checked that the gateway is registered + messages_queue .get(index) .unwrap() // the value exists because this function is called only after partitioning the queue .clone(), @@ -525,7 +518,10 @@ fn get_messages_for_gateway( } /// Gets the messages in [MESSAGES_FOR_GATEWAYS] starting from the one with the specified nonce -fn get_cert_messages(gateway_principal: &Principal, nonce: u64) -> CanisterWsGetMessagesResult { +fn get_cert_messages( + gateway_principal: &GatewayPrincipal, + nonce: u64, +) -> CanisterWsGetMessagesResult { let (start_index, end_index) = get_messages_for_gateway_range(gateway_principal, nonce); let messages = get_messages_for_gateway(gateway_principal, start_index, end_index); @@ -549,7 +545,7 @@ fn get_cert_messages(gateway_principal: &Principal, nonce: u64) -> CanisterWsGet } fn is_registered_gateway(principal: &Principal) -> bool { - get_registered_gateways_principals().contains(principal) + REGISTERED_GATEWAYS.with(|map| map.borrow().contains_key(principal)) } /// Checks if the principal of the authorized WS Gateways that have been registered during the initialization of the CDK @@ -808,7 +804,7 @@ fn _ws_send( let message_key = format_message_for_gateway_key(&gateway_principal, outgoing_message_nonce); // increment the nonce for the next message - increment_outgoing_message_nonce(gateway_principal)?; // we never hit the error case because we just checked that the gateway is registered + increment_outgoing_message_nonce(&gateway_principal); // increment the sequence number for the next message to the client increment_outgoing_message_to_client_num(client_key)?; @@ -827,13 +823,14 @@ fn _ws_send( // certify data put_cert_for_message(message_key.clone(), &message_content); - MESSAGES_FOR_GATEWAYS.with(|h| { + REGISTERED_GATEWAYS.with(|map| { // messages in the queue are inserted with contiguous and increasing nonces // (from beginning to end of the queue) as ws_send is called sequentially, the nonce // is incremented by one in each call, and the message is pushed at the end of the queue - h.borrow_mut() + map.borrow_mut() .get_mut(&gateway_principal) .unwrap() // the value exists because we just checked that the gateway is registered + .messages_queue .push_back(CanisterOutputMessage { client_key: client_key.clone(), content: message_content, diff --git a/src/ic-websocket-cdk/src/tests/unit_tests.rs b/src/ic-websocket-cdk/src/tests/unit_tests/mod.rs similarity index 62% rename from src/ic-websocket-cdk/src/tests/unit_tests.rs rename to src/ic-websocket-cdk/src/tests/unit_tests/mod.rs index 97e0728..5be4f2c 100644 --- a/src/ic-websocket-cdk/src/tests/unit_tests.rs +++ b/src/ic-websocket-cdk/src/tests/unit_tests/mod.rs @@ -1,78 +1,7 @@ -use super::super::*; +use crate::*; use proptest::prelude::*; -mod test_utils { - use candid::Principal; - use ic_agent::{identity::BasicIdentity, Identity}; - use ring::signature::Ed25519KeyPair; - - use super::{ - format_message_for_gateway_key, CanisterOutputMessage, ClientKey, RegisteredClient, - MESSAGES_FOR_GATEWAYS, - }; - - fn generate_random_key_pair() -> Ed25519KeyPair { - let rng = ring::rand::SystemRandom::new(); - let key_pair = - Ed25519KeyPair::generate_pkcs8(&rng).expect("Could not generate a key pair."); - Ed25519KeyPair::from_pkcs8(key_pair.as_ref()).expect("Could not read the key pair.") - } - - pub fn generate_random_principal() -> candid::Principal { - let key_pair = generate_random_key_pair(); - let identity = BasicIdentity::from_key_pair(key_pair); - - // workaround to keep the principal in the version of candid used by the canister - candid::Principal::from_text(identity.sender().unwrap().to_text()).unwrap() - } - - pub(super) fn generate_random_registered_client() -> RegisteredClient { - RegisteredClient::new(Principal::anonymous()) - } - - pub fn get_static_principal() -> Principal { - Principal::from_text("wnkwv-wdqb5-7wlzr-azfpw-5e5n5-dyxrf-uug7x-qxb55-mkmpa-5jqik-tqe") - .unwrap() // a random static but valid principal - } - - pub(super) fn get_random_client_key() -> ClientKey { - ClientKey::new( - generate_random_principal(), - // a random nonce - rand::random(), - ) - } - - pub(super) fn add_messages_for_gateway( - client_key: ClientKey, - gateway_principal: Principal, - count: u64, - ) { - MESSAGES_FOR_GATEWAYS.with(|m| { - for i in 0..count { - m.borrow_mut() - .get_mut(&gateway_principal) - .expect("TODO") - .push_back(CanisterOutputMessage { - client_key: client_key.clone(), - key: format_message_for_gateway_key(&gateway_principal, i), - content: vec![], - }); - } - }); - } - - pub fn clean_messages_for_gateway() { - MESSAGES_FOR_GATEWAYS.with(|m| m.borrow_mut().clear()); - } -} - -// we don't need to proptest get_gateway_principal if principal is not set, as it just panics -#[test] -#[should_panic = "gateways map should be initialized"] -fn test_get_gateway_principal_not_set() { - get_registered_gateways_principals(); -} +mod utils; #[test] fn test_ws_handlers_are_called() { @@ -114,14 +43,14 @@ fn test_ws_handlers_are_called() { assert!(handlers.on_close.is_none()); handlers.call_on_open(OnOpenCallbackArgs { - client_principal: test_utils::generate_random_principal(), + client_principal: utils::generate_random_principal(), }); handlers.call_on_message(OnMessageCallbackArgs { - client_principal: test_utils::generate_random_principal(), + client_principal: utils::generate_random_principal(), message: vec![], }); handlers.call_on_close(OnCloseCallbackArgs { - client_principal: test_utils::generate_random_principal(), + client_principal: utils::generate_random_principal(), }); // test that the handlers are not called if they are not initialized @@ -167,14 +96,14 @@ fn test_ws_handlers_are_called() { assert!(handlers.on_close.is_some()); handlers.call_on_open(OnOpenCallbackArgs { - client_principal: test_utils::generate_random_principal(), + client_principal: utils::generate_random_principal(), }); handlers.call_on_message(OnMessageCallbackArgs { - client_principal: test_utils::generate_random_principal(), + client_principal: utils::generate_random_principal(), message: vec![], }); handlers.call_on_close(OnCloseCallbackArgs { - client_principal: test_utils::generate_random_principal(), + client_principal: utils::generate_random_principal(), }); // test that the handlers are called if they are initialized @@ -206,20 +135,20 @@ fn test_ws_handlers_panic_is_handled() { let res = panic::catch_unwind(|| { handlers.call_on_open(OnOpenCallbackArgs { - client_principal: test_utils::generate_random_principal(), + client_principal: utils::generate_random_principal(), }); }); assert!(res.is_ok()); let res = panic::catch_unwind(|| { handlers.call_on_message(OnMessageCallbackArgs { - client_principal: test_utils::generate_random_principal(), + client_principal: utils::generate_random_principal(), message: vec![], }); }); assert!(res.is_ok()); let res = panic::catch_unwind(|| { handlers.call_on_close(OnCloseCallbackArgs { - client_principal: test_utils::generate_random_principal(), + client_principal: utils::generate_random_principal(), }); }); assert!(res.is_ok()); @@ -233,48 +162,56 @@ fn test_current_time() { proptest! { #[test] - fn test_initialize_registered_gateway(test_gateway_principal in any::().prop_map(|_| test_utils::generate_random_principal())) { + fn test_initialize_registered_gateways(test_gateway_principal in any::().prop_map(|_| utils::generate_random_principal())) { initialize_registered_gateways(vec![test_gateway_principal.to_string()]); - REGISTERED_GATEWAYS.with(|p| { - let p = p.borrow(); - assert!(p.is_some()); - assert_eq!( - p.to_owned().unwrap()[0], - RegisteredGateway::new(test_gateway_principal) - ); + let map = REGISTERED_GATEWAYS.with(|map| map.borrow().clone()); + prop_assert!(map.get(&test_gateway_principal).is_some()); + prop_assert_eq!( + map.get(&test_gateway_principal).unwrap(), + &RegisteredGateway::new() + ); + } + + #[test] + fn test_initialize_registered_gateways_wrong(test_gateway_principal in any::()) { + let res = panic::catch_unwind(|| { + initialize_registered_gateways(vec![test_gateway_principal]); }); + prop_assert!(res.is_err()); } #[test] fn test_get_outgoing_message_nonce(test_nonce in any::()) { // Set up - let gateway_principal = test_utils::generate_random_principal(); - OUTGOING_MESSAGE_NONCE.with(|n| n.borrow_mut().insert(gateway_principal, test_nonce)); + let gateway_principal = utils::generate_random_principal(); + REGISTERED_GATEWAYS.with(|n| n.borrow_mut().insert(gateway_principal, RegisteredGateway { outgoing_message_nonce: test_nonce, ..Default::default() })); let res = get_outgoing_message_nonce(&gateway_principal); - prop_assert!(res.is_ok()); - let actual_nonce = res.unwrap(); - prop_assert_eq!(actual_nonce, test_nonce); + prop_assert_eq!(res.ok(), Some(test_nonce)); + } + + #[test] + fn test_get_outgoing_message_nonce_nonexistent(test_gateway_principal in any::().prop_map(|_| utils::generate_random_principal())) { + let res = get_outgoing_message_nonce(&test_gateway_principal); + prop_assert_eq!(res.err(), Some(String::from(format!("no gateway registered with principal {test_gateway_principal}")))); } #[test] fn test_increment_outgoing_message_nonce(test_nonce in any::()) { // Set up - let gateway_principal = test_utils::generate_random_principal(); - OUTGOING_MESSAGE_NONCE.with(|n| n.borrow_mut().insert(gateway_principal, test_nonce)); + let gateway_principal = utils::generate_random_principal(); + REGISTERED_GATEWAYS.with(|n| n.borrow_mut().insert(gateway_principal, RegisteredGateway { outgoing_message_nonce: test_nonce, ..Default::default() })); - let res = increment_outgoing_message_nonce(gateway_principal); - prop_assert!(res.is_ok()); + increment_outgoing_message_nonce(&gateway_principal); let res = get_outgoing_message_nonce(&gateway_principal); - prop_assert!(res.is_ok()); - prop_assert_eq!(res.unwrap(), test_nonce + 1); + prop_assert_eq!(res.ok(), Some(test_nonce + 1)); } #[test] - fn test_insert_client(test_client_key in any::().prop_map(|_| test_utils::get_random_client_key())) { + fn test_insert_client(test_client_key in any::().prop_map(|_| utils::get_random_client_key())) { // Set up - let registered_client = test_utils::generate_random_registered_client(); + let registered_client = utils::generate_random_registered_client(); insert_client(test_client_key.clone(), registered_client.clone()); @@ -286,25 +223,31 @@ proptest! { } #[test] - fn test_get_gateway_principal(test_gateway_principal in any::().prop_map(|_| test_utils::generate_random_principal())) { + fn test_get_registered_gateway(test_gateway_principal in any::().prop_map(|_| utils::generate_random_principal())) { // Set up - REGISTERED_GATEWAYS.with(|p| *p.borrow_mut() = Some(vec![RegisteredGateway::new(test_gateway_principal.clone())])); + REGISTERED_GATEWAYS.with(|n| n.borrow_mut().insert(test_gateway_principal, RegisteredGateway::new())); - let actual_gateway_principal = get_registered_gateways_principals()[0]; - prop_assert_eq!(actual_gateway_principal, test_gateway_principal); + let res = get_registered_gateway(&test_gateway_principal); + prop_assert_eq!(res.ok(), Some(RegisteredGateway::new())); } #[test] - fn test_is_client_registered_empty(test_client_key in any::().prop_map(|_| test_utils::get_random_client_key())) { + fn test_get_registered_gateway_nonexistent(test_gateway_principal in any::().prop_map(|_| utils::generate_random_principal())) { + let res = get_registered_gateway(&test_gateway_principal); + prop_assert_eq!(res.err(), Some(String::from(format!("no gateway registered with principal {test_gateway_principal}")))); + } + + #[test] + fn test_is_client_registered_empty(test_client_key in any::().prop_map(|_| utils::get_random_client_key())) { let actual_result = is_client_registered(&test_client_key); prop_assert_eq!(actual_result, false); } #[test] - fn test_is_client_registered(test_client_key in any::().prop_map(|_| test_utils::get_random_client_key())) { + fn test_is_client_registered(test_client_key in any::().prop_map(|_| utils::get_random_client_key())) { // Set up REGISTERED_CLIENTS.with(|map| { - map.borrow_mut().insert(test_client_key.clone(), test_utils::generate_random_registered_client()); + map.borrow_mut().insert(test_client_key.clone(), utils::generate_random_registered_client()); }); let actual_result = is_client_registered(&test_client_key); @@ -312,7 +255,7 @@ proptest! { } #[test] - fn test_get_client_key_from_principal_empty(test_client_principal in any::().prop_map(|_| test_utils::generate_random_principal())) { + fn test_get_client_key_from_principal_empty(test_client_principal in any::().prop_map(|_| utils::generate_random_principal())) { let actual_result = get_client_key_from_principal(&test_client_principal); prop_assert_eq!(actual_result.err(), Some(String::from(format!( "client with principal {} doesn't have an open connection", @@ -321,7 +264,7 @@ proptest! { } #[test] - fn test_get_client_key_from_principal(test_client_key in any::().prop_map(|_| test_utils::get_random_client_key())) { + fn test_get_client_key_from_principal(test_client_key in any::().prop_map(|_| utils::get_random_client_key())) { // Set up CURRENT_CLIENT_KEY_MAP.with(|map| { map.borrow_mut().insert(test_client_key.client_principal, test_client_key.clone()); @@ -332,27 +275,27 @@ proptest! { } #[test] - fn test_check_registered_client_empty(test_client_key in any::().prop_map(|_| test_utils::get_random_client_key())) { + fn test_check_registered_client_empty(test_client_key in any::().prop_map(|_| utils::get_random_client_key())) { let actual_result = check_registered_client(&test_client_key); prop_assert_eq!(actual_result.err(), Some(format!("client with key {} doesn't have an open connection", test_client_key))); } #[test] - fn test_check_registered_client(test_client_key in any::().prop_map(|_| test_utils::get_random_client_key())) { + fn test_check_registered_client(test_client_key in any::().prop_map(|_| utils::get_random_client_key())) { // Set up REGISTERED_CLIENTS.with(|map| { - map.borrow_mut().insert(test_client_key.clone(), test_utils::generate_random_registered_client()); + map.borrow_mut().insert(test_client_key.clone(), utils::generate_random_registered_client()); }); let actual_result = check_registered_client(&test_client_key); prop_assert!(actual_result.is_ok()); - let non_existing_client_key = test_utils::get_random_client_key(); + let non_existing_client_key = utils::get_random_client_key(); let actual_result = check_registered_client(&non_existing_client_key); prop_assert_eq!(actual_result.err(), Some(format!("client with key {} doesn't have an open connection", non_existing_client_key))); } #[test] - fn test_init_outgoing_message_to_client_num(test_client_key in any::().prop_map(|_| test_utils::get_random_client_key())) { + fn test_init_outgoing_message_to_client_num(test_client_key in any::().prop_map(|_| utils::get_random_client_key())) { init_outgoing_message_to_client_num(test_client_key.clone()); let actual_result = OUTGOING_MESSAGE_TO_CLIENT_NUM_MAP.with(|map| map.borrow().get(&test_client_key).unwrap().clone()); @@ -360,7 +303,7 @@ proptest! { } #[test] - fn test_increment_outgoing_message_to_client_num(test_client_key in any::().prop_map(|_| test_utils::get_random_client_key()), test_num in any::()) { + fn test_increment_outgoing_message_to_client_num(test_client_key in any::().prop_map(|_| utils::get_random_client_key()), test_num in any::()) { // Set up OUTGOING_MESSAGE_TO_CLIENT_NUM_MAP.with(|map| { map.borrow_mut().insert(test_client_key.clone(), test_num); @@ -374,7 +317,7 @@ proptest! { } #[test] - fn test_get_outgoing_message_to_client_num(test_client_key in any::().prop_map(|_| test_utils::get_random_client_key()), test_num in any::()) { + fn test_get_outgoing_message_to_client_num(test_client_key in any::().prop_map(|_| utils::get_random_client_key()), test_num in any::()) { // Set up OUTGOING_MESSAGE_TO_CLIENT_NUM_MAP.with(|map| { map.borrow_mut().insert(test_client_key.clone(), test_num); @@ -386,7 +329,7 @@ proptest! { } #[test] - fn test_init_expected_incoming_message_from_client_num(test_client_key in any::().prop_map(|_| test_utils::get_random_client_key())) { + fn test_init_expected_incoming_message_from_client_num(test_client_key in any::().prop_map(|_| utils::get_random_client_key())) { init_expected_incoming_message_from_client_num(test_client_key.clone()); let actual_result = INCOMING_MESSAGE_FROM_CLIENT_NUM_MAP.with(|map| map.borrow().get(&test_client_key).unwrap().clone()); @@ -394,7 +337,7 @@ proptest! { } #[test] - fn test_get_expected_incoming_message_from_client_num(test_client_key in any::().prop_map(|_| test_utils::get_random_client_key()), test_num in any::()) { + fn test_get_expected_incoming_message_from_client_num(test_client_key in any::().prop_map(|_| utils::get_random_client_key()), test_num in any::()) { // Set up INCOMING_MESSAGE_FROM_CLIENT_NUM_MAP.with(|map| { map.borrow_mut().insert(test_client_key.clone(), test_num); @@ -406,7 +349,7 @@ proptest! { } #[test] - fn test_increment_expected_incoming_message_from_client_num(test_client_key in any::().prop_map(|_| test_utils::get_random_client_key()), test_num in any::()) { + fn test_increment_expected_incoming_message_from_client_num(test_client_key in any::().prop_map(|_| utils::get_random_client_key()), test_num in any::()) { // Set up INCOMING_MESSAGE_FROM_CLIENT_NUM_MAP.with(|map| { map.borrow_mut().insert(test_client_key.clone(), test_num); @@ -420,7 +363,7 @@ proptest! { } #[test] - fn test_add_client_to_wait_for_keep_alive(test_client_key in any::().prop_map(|_| test_utils::get_random_client_key())) { + fn test_add_client_to_wait_for_keep_alive(test_client_key in any::().prop_map(|_| utils::get_random_client_key())) { add_client_to_wait_for_keep_alive(&test_client_key); let actual_result = CLIENTS_WAITING_FOR_KEEP_ALIVE.with(|map| map.borrow().get(&test_client_key).is_some()); @@ -428,8 +371,8 @@ proptest! { } #[test] - fn test_add_client(test_client_key in any::().prop_map(|_| test_utils::get_random_client_key())) { - let registered_client = test_utils::generate_random_registered_client(); + fn test_add_client(test_client_key in any::().prop_map(|_| utils::get_random_client_key())) { + let registered_client = utils::generate_random_registered_client(); // Test add_client(test_client_key.clone(), registered_client.clone()); @@ -448,13 +391,13 @@ proptest! { } #[test] - fn test_remove_client(test_client_key in any::().prop_map(|_| test_utils::get_random_client_key())) { + fn test_remove_client(test_client_key in any::().prop_map(|_| utils::get_random_client_key())) { // Set up CURRENT_CLIENT_KEY_MAP.with(|map| { map.borrow_mut().insert(test_client_key.client_principal.clone(), test_client_key.clone()); }); REGISTERED_CLIENTS.with(|map| { - map.borrow_mut().insert(test_client_key.clone(), test_utils::generate_random_registered_client()); + map.borrow_mut().insert(test_client_key.clone(), utils::generate_random_registered_client()); }); INCOMING_MESSAGE_FROM_CLIENT_NUM_MAP.with(|map| { map.borrow_mut().insert(test_client_key.clone(), INITIAL_CLIENT_SEQUENCE_NUM); @@ -479,7 +422,7 @@ proptest! { } #[test] - fn test_get_message_for_gateway_key(test_gateway_principal in any::().prop_map(|_| test_utils::generate_random_principal()), test_nonce in any::()) { + fn test_format_message_for_gateway_key(test_gateway_principal in any::().prop_map(|_| utils::generate_random_principal()), test_nonce in any::()) { let actual_result = format_message_for_gateway_key(&test_gateway_principal, test_nonce); prop_assert_eq!(actual_result, test_gateway_principal.to_string() + "_" + &format!("{:0>20}", test_nonce.to_string())); } @@ -487,13 +430,9 @@ proptest! { #[test] fn test_get_messages_for_gateway_range_empty(messages_count in any::().prop_map(|c| c % 1000)) { // Set up - let gateway_principal = test_utils::generate_random_principal(); - REGISTERED_GATEWAYS.with(|p| *p.borrow_mut() = Some(vec![RegisteredGateway::new(gateway_principal.clone())])); - MESSAGES_FOR_GATEWAYS.with(|m| *m.borrow_mut() = { - let mut m = HashMap::new(); - m.insert(gateway_principal.clone(), VecDeque::new()); - m - }); + utils::initialize_params(); + let gateway_principal = utils::generate_random_principal(); + REGISTERED_GATEWAYS.with(|n| n.borrow_mut().insert(gateway_principal, RegisteredGateway::new())); // Test // we ask for a random range of messages to check if it always returns the same range for empty messages @@ -505,18 +444,14 @@ proptest! { } #[test] - fn test_get_messages_for_gateway_range_smaller_than_max(gateway_principal in any::().prop_map(|_| test_utils::get_static_principal())) { + fn test_get_messages_for_gateway_range_smaller_than_max(gateway_principal in any::().prop_map(|_| utils::get_static_principal())) { // Set up - REGISTERED_GATEWAYS.with(|p| *p.borrow_mut() = Some(vec![RegisteredGateway::new(gateway_principal.clone())])); - MESSAGES_FOR_GATEWAYS.with(|m| *m.borrow_mut() = { - let mut m = HashMap::new(); - m.insert(gateway_principal.clone(), VecDeque::new()); - m - }); + utils::initialize_params(); + REGISTERED_GATEWAYS.with(|n| n.borrow_mut().insert(gateway_principal, RegisteredGateway::new())); let messages_count = 4; - let test_client_key = test_utils::get_random_client_key(); - test_utils::add_messages_for_gateway(test_client_key, gateway_principal, messages_count); + let test_client_key = utils::get_random_client_key(); + utils::add_messages_for_gateway(test_client_key, &gateway_principal, messages_count); // Test // messages are just 4, so we don't exceed the max number of returned messages @@ -528,28 +463,21 @@ proptest! { } // Clean up - test_utils::clean_messages_for_gateway(); + utils::clean_messages_for_gateway(&gateway_principal); } #[test] - fn test_get_messages_for_gateway_range_larger_than_max(gateway_principal in any::().prop_map(|_| test_utils::get_static_principal()), max_number_of_returned_messages in any::().prop_map(|c| c % 1000)) { + fn test_get_messages_for_gateway_range_larger_than_max(gateway_principal in any::().prop_map(|_| utils::get_static_principal()), max_number_of_returned_messages in any::().prop_map(|c| c % 1000)) { // Set up - PARAMS.with(|p| { - *p.borrow_mut() = WsInitParams { - max_number_of_returned_messages, - ..Default::default() - } - }); - REGISTERED_GATEWAYS.with(|p| *p.borrow_mut() = Some(vec![RegisteredGateway::new(gateway_principal.clone())])); - MESSAGES_FOR_GATEWAYS.with(|m| *m.borrow_mut() = { - let mut m = HashMap::new(); - m.insert(gateway_principal.clone(), VecDeque::new()); - m + set_params(WsInitParams { + max_number_of_returned_messages, + ..Default::default() }); + REGISTERED_GATEWAYS.with(|n| n.borrow_mut().insert(gateway_principal, RegisteredGateway::new())); let messages_count: u64 = (2 * max_number_of_returned_messages).try_into().unwrap(); - let test_client_key = test_utils::get_random_client_key(); - test_utils::add_messages_for_gateway(test_client_key, gateway_principal, messages_count); + let test_client_key = utils::get_random_client_key(); + utils::add_messages_for_gateway(test_client_key, &gateway_principal, messages_count); // Test // messages are now 2 * MAX_NUMBER_OF_RETURNED_MESSAGES @@ -566,27 +494,20 @@ proptest! { } // Clean up - test_utils::clean_messages_for_gateway(); + utils::clean_messages_for_gateway(&gateway_principal); } #[test] - fn test_get_messages_for_gateway_initial_nonce(gateway_principal in any::().prop_map(|_| test_utils::get_static_principal()), messages_count in any::().prop_map(|c| c % 100), max_number_of_returned_messages in any::().prop_map(|c| c % 1000)) { + fn test_get_messages_for_gateway_initial_nonce(gateway_principal in any::().prop_map(|_| utils::get_static_principal()), messages_count in any::().prop_map(|c| c % 100), max_number_of_returned_messages in any::().prop_map(|c| c % 1000)) { // Set up - PARAMS.with(|p| { - *p.borrow_mut() = WsInitParams { - max_number_of_returned_messages, - ..Default::default() - } - }); - REGISTERED_GATEWAYS.with(|p| *p.borrow_mut() = Some(vec![RegisteredGateway::new(gateway_principal.clone())])); - MESSAGES_FOR_GATEWAYS.with(|m| *m.borrow_mut() = { - let mut m = HashMap::new(); - m.insert(gateway_principal.clone(), VecDeque::new()); - m + set_params(WsInitParams { + max_number_of_returned_messages, + ..Default::default() }); + REGISTERED_GATEWAYS.with(|n| n.borrow_mut().insert(gateway_principal, RegisteredGateway::new())); - let test_client_key = test_utils::get_random_client_key(); - test_utils::add_messages_for_gateway(test_client_key, gateway_principal, messages_count); + let test_client_key = utils::get_random_client_key(); + utils::add_messages_for_gateway(test_client_key, &gateway_principal, messages_count); // Test let (start_index, end_index) = get_messages_for_gateway_range(&gateway_principal, 0); @@ -599,21 +520,16 @@ proptest! { prop_assert_eq!(end_index, messages_count as usize); // Clean up - test_utils::clean_messages_for_gateway(); + utils::clean_messages_for_gateway(&gateway_principal); } #[test] - fn test_get_messages_for_gateway(gateway_principal in any::().prop_map(|_| test_utils::get_static_principal()), messages_count in any::().prop_map(|c| c % 100)) { + fn test_get_messages_for_gateway(gateway_principal in any::().prop_map(|_| utils::get_static_principal()), messages_count in any::().prop_map(|c| c % 100)) { // Set up - REGISTERED_GATEWAYS.with(|p| *p.borrow_mut() = Some(vec![RegisteredGateway::new(gateway_principal.clone())])); - MESSAGES_FOR_GATEWAYS.with(|m| *m.borrow_mut() = { - let mut m = HashMap::new(); - m.insert(gateway_principal.clone(), VecDeque::new()); - m - }); + REGISTERED_GATEWAYS.with(|n| n.borrow_mut().insert(gateway_principal, RegisteredGateway::new())); - let test_client_key = test_utils::get_random_client_key(); - test_utils::add_messages_for_gateway(test_client_key, gateway_principal, messages_count); + let test_client_key = utils::get_random_client_key(); + utils::add_messages_for_gateway(test_client_key, &gateway_principal, messages_count); // Test // add one to test the out of range index @@ -629,18 +545,18 @@ proptest! { } // Clean up - test_utils::clean_messages_for_gateway(); + utils::clean_messages_for_gateway(&gateway_principal); } #[test] - fn test_check_is_registered_gateway(test_gateway_principal in any::().prop_map(|_| test_utils::generate_random_principal())) { + fn test_check_is_registered_gateway(test_gateway_principal in any::().prop_map(|_| utils::generate_random_principal())) { // Set up - REGISTERED_GATEWAYS.with(|p| *p.borrow_mut() = Some(vec![RegisteredGateway::new(test_gateway_principal.clone())])); + REGISTERED_GATEWAYS.with(|n| n.borrow_mut().insert(test_gateway_principal, RegisteredGateway::new())); let actual_result = check_is_registered_gateway(&test_gateway_principal); prop_assert!(actual_result.is_ok()); - let other_principal = test_utils::generate_random_principal(); + let other_principal = utils::generate_random_principal(); let actual_result = check_is_registered_gateway(&other_principal); prop_assert_eq!(actual_result.err(), Some(String::from("principal is not one of the authorized gateways that have been registered during CDK initialization"))); } @@ -649,7 +565,7 @@ proptest! { fn test_serialize_websocket_message(test_msg_bytes in any::>(), test_sequence_num in any::(), test_timestamp in any::()) { // TODO: add more tests, in which we check the serialized message let websocket_message = WebsocketMessage { - client_key: test_utils::get_random_client_key(), + client_key: utils::get_random_client_key(), sequence_num: test_sequence_num, timestamp: test_timestamp, is_service_message: false, diff --git a/src/ic-websocket-cdk/src/tests/unit_tests/utils.rs b/src/ic-websocket-cdk/src/tests/unit_tests/utils.rs new file mode 100644 index 0000000..0966412 --- /dev/null +++ b/src/ic-websocket-cdk/src/tests/unit_tests/utils.rs @@ -0,0 +1,75 @@ +use candid::Principal; +use ic_agent::{identity::BasicIdentity, Identity}; +use ring::signature::Ed25519KeyPair; + +use crate::{ + format_message_for_gateway_key, set_params, CanisterOutputMessage, ClientKey, GatewayPrincipal, + RegisteredClient, WsInitParams, REGISTERED_GATEWAYS, +}; + +fn generate_random_key_pair() -> Ed25519KeyPair { + let rng = ring::rand::SystemRandom::new(); + let key_pair = Ed25519KeyPair::generate_pkcs8(&rng).expect("Could not generate a key pair."); + Ed25519KeyPair::from_pkcs8(key_pair.as_ref()).expect("Could not read the key pair.") +} + +pub fn generate_random_principal() -> Principal { + let key_pair = generate_random_key_pair(); + let identity = BasicIdentity::from_key_pair(key_pair); + + // workaround to keep the principal in the version of candid used by the canister + Principal::from_text(identity.sender().unwrap().to_text()).unwrap() +} + +pub(super) fn generate_random_registered_client() -> RegisteredClient { + RegisteredClient::new(Principal::anonymous()) +} + +pub fn get_static_principal() -> Principal { + Principal::from_text("wnkwv-wdqb5-7wlzr-azfpw-5e5n5-dyxrf-uug7x-qxb55-mkmpa-5jqik-tqe").unwrap() + // a random static but valid principal +} + +pub(super) fn get_random_client_key() -> ClientKey { + ClientKey::new( + generate_random_principal(), + // a random nonce + rand::random(), + ) +} + +pub(super) fn add_messages_for_gateway( + client_key: ClientKey, + gateway_principal: &GatewayPrincipal, + count: u64, +) { + REGISTERED_GATEWAYS.with(|m| { + for i in 0..count { + m.borrow_mut() + .get_mut(gateway_principal) + .unwrap() + .messages_queue + .push_back(CanisterOutputMessage { + client_key: client_key.clone(), + key: format_message_for_gateway_key(&gateway_principal, i), + content: vec![], + }); + } + }); +} + +pub fn clean_messages_for_gateway(gateway_principal: &GatewayPrincipal) { + REGISTERED_GATEWAYS.with(|m| { + m.borrow_mut() + .get_mut(gateway_principal) + .unwrap() + .messages_queue + .clear() + }); +} + +pub fn initialize_params() { + set_params(WsInitParams { + ..Default::default() + }); +} From f55b0b4352a06517b9630988b8cbef344a12010c Mon Sep 17 00:00:00 2001 From: Luca8991 Date: Wed, 15 Nov 2023 11:29:41 +0100 Subject: [PATCH 10/27] chore: update unit tests path in readme --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index db0e1b2..839cb84 100644 --- a/README.md +++ b/README.md @@ -45,7 +45,7 @@ The **ic-websocket-cdk** library implementation can be found in the [src/ic-webs ### Testing There are two types of tests available: -- **Unit tests**: tests for CDK functions, written in Rust and available in the [unit_tests.rs](./src/ic-websocket-cdk/src/tests/unit_tests.rs) file. +- **Unit tests**: tests for CDK functions, written in Rust and available in the [unit_tests](./src/ic-websocket-cdk/src/tests/unit_tests/) folder. - **Integration tests**: for these tests the CDK is deployed to a [test canister](./src/test_canister/). These tests are written in Rust and use [PocketIC](https://github.com/dfinity/pocketic) under the hood. They are available in the [integration_tests](./src/ic-websocket-cdk/src/tests/integration_tests/) folder. There's a script that runs all the tests together, taking care of setting up the environment (Linux only!) and deploying the canister. To run the script, execute the following command: From daf4e24fd034a715d6a8bcaf5028f3f3852b1e92 Mon Sep 17 00:00:00 2001 From: Luca8991 Date: Wed, 15 Nov 2023 11:31:06 +0100 Subject: [PATCH 11/27] fix: rename gateway principal in did --- src/ic-websocket-cdk/ws_types.did | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/ic-websocket-cdk/ws_types.did b/src/ic-websocket-cdk/ws_types.did index d801a9f..6db065a 100644 --- a/src/ic-websocket-cdk/ws_types.did +++ b/src/ic-websocket-cdk/ws_types.did @@ -1,4 +1,5 @@ type ClientPrincipal = principal; +type GatewayPrincipal = principal; type ClientKey = record { client_principal : ClientPrincipal; client_nonce : nat64; @@ -26,7 +27,7 @@ type CanisterOutputCertifiedMessages = record { type CanisterWsOpenArguments = record { client_nonce : nat64; - gateway_principal : principal; + gateway_principal : GatewayPrincipal; }; type CanisterWsOpenResult = variant { From 8b450e679a104409524b81fcd50122bd7c541f34 Mon Sep 17 00:00:00 2001 From: Luca8991 Date: Wed, 15 Nov 2023 11:55:19 +0100 Subject: [PATCH 12/27] feat: TEST_ENV mutex --- .../tests/integration_tests/b_ws_message.rs | 4 +- .../integration_tests/c_ws_get_messages.rs | 15 +++--- .../src/tests/integration_tests/d_ws_close.rs | 4 +- .../src/tests/integration_tests/e_ws_send.rs | 4 +- .../f_messages_acknowledgement.rs | 39 ++++++++------- .../tests/integration_tests/utils/actor.rs | 49 +++++++++++-------- .../tests/integration_tests/utils/test_env.rs | 36 +++++++++++--- 7 files changed, 91 insertions(+), 60 deletions(-) diff --git a/src/ic-websocket-cdk/src/tests/integration_tests/b_ws_message.rs b/src/ic-websocket-cdk/src/tests/integration_tests/b_ws_message.rs index 84db537..a4ba0be 100644 --- a/src/ic-websocket-cdk/src/tests/integration_tests/b_ws_message.rs +++ b/src/ic-websocket-cdk/src/tests/integration_tests/b_ws_message.rs @@ -9,13 +9,13 @@ use super::utils::{ actor::{ws_message::call_ws_message, ws_open::call_ws_open_for_client_key_with_panic}, clients::{generate_random_client_nonce, CLIENT_1_KEY, CLIENT_2, CLIENT_2_KEY}, messages::{create_websocket_message, encode_websocket_service_message_content}, - test_env::TEST_ENV, + test_env::get_test_env, }; #[test] fn test_1_fails_if_client_is_not_registered() { // first, reset the canister - TEST_ENV.reset_canister_with_default_params(); + get_test_env().reset_canister_with_default_params(); // second, open a connection for client 1 call_ws_open_for_client_key_with_panic(CLIENT_1_KEY.deref()); diff --git a/src/ic-websocket-cdk/src/tests/integration_tests/c_ws_get_messages.rs b/src/ic-websocket-cdk/src/tests/integration_tests/c_ws_get_messages.rs index 7b9bdd5..9e11dde 100644 --- a/src/ic-websocket-cdk/src/tests/integration_tests/c_ws_get_messages.rs +++ b/src/ic-websocket-cdk/src/tests/integration_tests/c_ws_get_messages.rs @@ -13,13 +13,13 @@ use super::utils::{ clients::{CLIENT_1_KEY, GATEWAY_1, GATEWAY_2}, constants::{DEFAULT_TEST_MAX_NUMBER_OF_RETURNED_MESSAGES, SEND_MESSAGES_COUNT}, messages::get_next_polling_nonce_from_messages, - test_env::TEST_ENV, + test_env::get_test_env, }; #[test] fn test_1_fails_if_a_non_registered_gateway_tries_to_get_messages() { // first, reset the canister - TEST_ENV.reset_canister_with_default_params(); + get_test_env().reset_canister_with_default_params(); let res = call_ws_get_messages( GATEWAY_2.deref(), @@ -213,14 +213,12 @@ fn test_5_registered_gateway_can_poll_messages_after_restart() { } mod helpers { - use std::ops::Deref; - use crate::{ tests::integration_tests::utils::{ actor::ws_send::AppMessage, certification::{is_message_body_valid, is_valid_certificate}, messages::decode_websocket_message, - test_env::TEST_ENV, + test_env::get_test_env, }, CanisterOutputMessage, ClientKey, }; @@ -261,7 +259,10 @@ mod helpers { let websocket_message = decode_websocket_message(&message.content); assert_eq!(websocket_message.client_key, *client_key); assert_eq!(websocket_message.sequence_num, *expected_sequence_number); - assert_eq!(websocket_message.timestamp, TEST_ENV.get_canister_time()); + assert_eq!( + websocket_message.timestamp, + get_test_env().get_canister_time() + ); assert_eq!(websocket_message.is_service_message, false); let decoded_content: AppMessage = decode_one(&websocket_message.content).unwrap(); assert_eq!( @@ -272,7 +273,7 @@ mod helpers { ); // check the certification - assert!(is_valid_certificate(TEST_ENV.deref(), cert, tree,)); + assert!(is_valid_certificate(&get_test_env(), cert, tree,)); assert!(is_message_body_valid(&message.key, &message.content, tree)); } } diff --git a/src/ic-websocket-cdk/src/tests/integration_tests/d_ws_close.rs b/src/ic-websocket-cdk/src/tests/integration_tests/d_ws_close.rs index ee3ba83..59b2441 100644 --- a/src/ic-websocket-cdk/src/tests/integration_tests/d_ws_close.rs +++ b/src/ic-websocket-cdk/src/tests/integration_tests/d_ws_close.rs @@ -5,13 +5,13 @@ use crate::{CanisterWsCloseArguments, CanisterWsCloseResult}; use super::utils::{ actor::{ws_close::call_ws_close, ws_open::call_ws_open_for_client_key_with_panic}, clients::{CLIENT_1_KEY, CLIENT_2_KEY, GATEWAY_1, GATEWAY_2}, - test_env::TEST_ENV, + test_env::get_test_env, }; #[test] fn test_1_fails_if_gateway_is_not_registered() { // first, reset the canister - TEST_ENV.reset_canister_with_default_params(); + get_test_env().reset_canister_with_default_params(); // second, open a connection for client 1 call_ws_open_for_client_key_with_panic(CLIENT_1_KEY.deref()); diff --git a/src/ic-websocket-cdk/src/tests/integration_tests/e_ws_send.rs b/src/ic-websocket-cdk/src/tests/integration_tests/e_ws_send.rs index f474451..de2a72b 100644 --- a/src/ic-websocket-cdk/src/tests/integration_tests/e_ws_send.rs +++ b/src/ic-websocket-cdk/src/tests/integration_tests/e_ws_send.rs @@ -8,13 +8,13 @@ use super::utils::{ ws_send::{call_ws_send, AppMessage}, }, clients::{CLIENT_1, CLIENT_1_KEY, CLIENT_2}, - test_env::TEST_ENV, + test_env::get_test_env, }; #[test] fn test_1_fails_if_sending_a_message_to_a_non_registered_client() { // first, reset the canister - TEST_ENV.reset_canister_with_default_params(); + get_test_env().reset_canister_with_default_params(); // second, open a connection for client 1 call_ws_open_for_client_key_with_panic(CLIENT_1_KEY.deref()); diff --git a/src/ic-websocket-cdk/src/tests/integration_tests/f_messages_acknowledgement.rs b/src/ic-websocket-cdk/src/tests/integration_tests/f_messages_acknowledgement.rs index f23d986..5e1e4c3 100644 --- a/src/ic-websocket-cdk/src/tests/integration_tests/f_messages_acknowledgement.rs +++ b/src/ic-websocket-cdk/src/tests/integration_tests/f_messages_acknowledgement.rs @@ -15,12 +15,12 @@ use super::utils::{ clients::{CLIENT_1_KEY, GATEWAY_1}, constants::{DEFAULT_TEST_KEEP_ALIVE_TIMEOUT_MS, DEFAULT_TEST_SEND_ACK_INTERVAL_MS}, messages::{create_websocket_message, encode_websocket_service_message_content}, - test_env::TEST_ENV, + test_env::get_test_env, }; #[test] fn test_1_client_should_receive_ack_messages() { - TEST_ENV.reset_canister_with_default_params(); + get_test_env().reset_canister_with_default_params(); // open a connection for client 1 let client_1_key = CLIENT_1_KEY.deref(); call_ws_open_for_client_key_with_panic(client_1_key); @@ -43,7 +43,7 @@ fn test_1_client_should_receive_ack_messages() { }, ); // advance the canister time to make sure the ack timer expires and an ack is sent - TEST_ENV.advance_canister_time_ms(DEFAULT_TEST_SEND_ACK_INTERVAL_MS); + get_test_env().advance_canister_time_ms(DEFAULT_TEST_SEND_ACK_INTERVAL_MS); let res = call_ws_get_messages( GATEWAY_1.deref(), @@ -56,10 +56,10 @@ fn test_1_client_should_receive_ack_messages() { fn test_2_client_is_removed_if_keep_alive_timeout_is_reached() { let client_1_key = CLIENT_1_KEY.deref(); // open a connection for client 1 - TEST_ENV.reset_canister_with_default_params(); + get_test_env().reset_canister_with_default_params(); call_ws_open_for_client_key_with_panic(client_1_key); // advance the canister time to make sure the ack timer expires and an ack is sent - TEST_ENV.advance_canister_time_ms(DEFAULT_TEST_SEND_ACK_INTERVAL_MS); + get_test_env().advance_canister_time_ms(DEFAULT_TEST_SEND_ACK_INTERVAL_MS); // get messages to check if the ack message has been set let res = call_ws_get_messages( GATEWAY_1.deref(), @@ -68,7 +68,7 @@ fn test_2_client_is_removed_if_keep_alive_timeout_is_reached() { helpers::check_ack_message_result(&res, client_1_key, 0, 2); // advance the canister time to make sure the keep alive timeout expires - TEST_ENV.advance_canister_time_ms(DEFAULT_TEST_KEEP_ALIVE_TIMEOUT_MS); + get_test_env().advance_canister_time_ms(DEFAULT_TEST_KEEP_ALIVE_TIMEOUT_MS); // to check if the client has been removed, we try to send the keep alive message late let res = call_ws_message( @@ -100,11 +100,11 @@ fn test_2_client_is_removed_if_keep_alive_timeout_is_reached() { #[test] fn test_3_client_is_not_removed_if_it_sends_a_keep_alive_before_timeout() { let client_1_key = CLIENT_1_KEY.deref(); - TEST_ENV.reset_canister_with_default_params(); + get_test_env().reset_canister_with_default_params(); // open a connection for client 1 call_ws_open_for_client_key_with_panic(client_1_key); // advance the canister time to make sure the ack timer expires and an ack is sent - TEST_ENV.advance_canister_time_ms(DEFAULT_TEST_SEND_ACK_INTERVAL_MS); + get_test_env().advance_canister_time_ms(DEFAULT_TEST_SEND_ACK_INTERVAL_MS); // get messages to check if the ack message has been set let res = call_ws_get_messages( GATEWAY_1.deref(), @@ -131,7 +131,7 @@ fn test_3_client_is_not_removed_if_it_sends_a_keep_alive_before_timeout() { }, ); // advance the canister time to make sure the keep alive timeout expires and the canister checks the keep alive - TEST_ENV.advance_canister_time_ms(DEFAULT_TEST_KEEP_ALIVE_TIMEOUT_MS); + get_test_env().advance_canister_time_ms(DEFAULT_TEST_KEEP_ALIVE_TIMEOUT_MS); // send a message to the canister to see the sequence number increasing in the ack message // and be sure that the client has not been removed call_ws_message_with_panic( @@ -141,7 +141,7 @@ fn test_3_client_is_not_removed_if_it_sends_a_keep_alive_before_timeout() { }, ); // wait to receive the next ack message - TEST_ENV.advance_canister_time_ms(DEFAULT_TEST_SEND_ACK_INTERVAL_MS); + get_test_env().advance_canister_time_ms(DEFAULT_TEST_SEND_ACK_INTERVAL_MS); let res = call_ws_get_messages( GATEWAY_1.deref(), CanisterWsGetMessagesArguments { nonce: 2 }, // skip the service open message and the fist ack message @@ -152,9 +152,9 @@ fn test_3_client_is_not_removed_if_it_sends_a_keep_alive_before_timeout() { #[test] fn test_4_client_is_not_removed_if_it_connects_while_canister_is_waiting_for_keep_alive() { let client_1_key = CLIENT_1_KEY.deref(); - TEST_ENV.reset_canister_with_default_params(); + get_test_env().reset_canister_with_default_params(); // advance the canister time to make sure the ack timer expires and the canister started the keep alive timer - TEST_ENV.advance_canister_time_ms(DEFAULT_TEST_SEND_ACK_INTERVAL_MS); + get_test_env().advance_canister_time_ms(DEFAULT_TEST_SEND_ACK_INTERVAL_MS); // open a connection for client 1 call_ws_open_for_client_key_with_panic(client_1_key); @@ -181,9 +181,9 @@ fn test_4_client_is_not_removed_if_it_connects_while_canister_is_waiting_for_kee assert_eq!(res, CanisterWsMessageResult::Ok(())); // wait for the keep alive timeout to expire - TEST_ENV.advance_canister_time_ms(DEFAULT_TEST_KEEP_ALIVE_TIMEOUT_MS); + get_test_env().advance_canister_time_ms(DEFAULT_TEST_KEEP_ALIVE_TIMEOUT_MS); // wait for the canister to send the next ack - TEST_ENV.advance_canister_time_ms( + get_test_env().advance_canister_time_ms( DEFAULT_TEST_SEND_ACK_INTERVAL_MS - DEFAULT_TEST_KEEP_ALIVE_TIMEOUT_MS, ); @@ -195,8 +195,6 @@ fn test_4_client_is_not_removed_if_it_connects_while_canister_is_waiting_for_kee } mod helpers { - use std::ops::Deref; - use crate::{ tests::integration_tests::utils::{ certification::{is_message_body_valid, is_valid_certificate}, @@ -204,7 +202,7 @@ mod helpers { decode_websocket_service_message_content, get_websocket_message_from_canister_message, }, - test_env::TEST_ENV, + test_env::get_test_env, }, CanisterAckMessageContent, CanisterOutputMessage, CanisterWsGetMessagesResult, ClientKey, WebsocketServiceMessageContent, @@ -227,7 +225,7 @@ mod helpers { expected_websocket_message_sequence_number, ); assert!(is_valid_certificate( - TEST_ENV.deref(), + &get_test_env(), &messages.cert, &messages.tree, )); @@ -254,7 +252,10 @@ mod helpers { websocket_message.sequence_num, expected_websocket_message_sequence_number ); - assert_eq!(websocket_message.timestamp, TEST_ENV.get_canister_time()); + assert_eq!( + websocket_message.timestamp, + get_test_env().get_canister_time() + ); assert_eq!( decode_websocket_service_message_content(&websocket_message.content), WebsocketServiceMessageContent::AckMessage(CanisterAckMessageContent { diff --git a/src/ic-websocket-cdk/src/tests/integration_tests/utils/actor.rs b/src/ic-websocket-cdk/src/tests/integration_tests/utils/actor.rs index c033592..c8b266c 100644 --- a/src/ic-websocket-cdk/src/tests/integration_tests/utils/actor.rs +++ b/src/ic-websocket-cdk/src/tests/integration_tests/utils/actor.rs @@ -1,7 +1,7 @@ use candid::{decode_one, encode_one, Principal}; use pocket_ic::WasmResult; -use super::test_env::TEST_ENV; +use super::test_env::get_test_env; pub mod ws_open { use std::ops::Deref; @@ -16,14 +16,10 @@ pub mod ws_open { /// # Panics /// if the call returns a [WasmResult::Reject]. pub fn call_ws_open(caller: &Principal, args: CanisterWsOpenArguments) -> CanisterWsOpenResult { - let res = TEST_ENV + let canister_id = get_test_env().canister_id; + let res = get_test_env() .pic - .update_call( - TEST_ENV.canister_id, - *caller, - "ws_open", - encode_one(args).unwrap(), - ) + .update_call(canister_id, *caller, "ws_open", encode_one(args).unwrap()) .expect("Failed to call canister"); match res { @@ -48,6 +44,18 @@ pub mod ws_open { }; call_ws_open_with_panic(&client_key.client_principal, args); } + + /// See [call_ws_open_with_panic]. + pub(crate) fn call_ws_open_for_client_key_and_gateway_with_panic( + client_key: &ClientKey, + gateway_principal: Principal, + ) { + let args = CanisterWsOpenArguments { + client_nonce: client_key.client_nonce, + gateway_principal, + }; + call_ws_open_with_panic(&client_key.client_principal, args); + } } pub mod ws_message { @@ -61,10 +69,11 @@ pub mod ws_message { caller: &Principal, args: CanisterWsMessageArguments, ) -> CanisterWsMessageResult { - let res = TEST_ENV + let canister_id = get_test_env().canister_id; + let res = get_test_env() .pic .update_call( - TEST_ENV.canister_id, + canister_id, *caller, "ws_message", encode_one(args).unwrap(), @@ -97,14 +106,10 @@ pub mod ws_close { caller: &Principal, args: CanisterWsCloseArguments, ) -> CanisterWsCloseResult { - let res = TEST_ENV + let canister_id = get_test_env().canister_id; + let res = get_test_env() .pic - .update_call( - TEST_ENV.canister_id, - *caller, - "ws_close", - encode_one(args).unwrap(), - ) + .update_call(canister_id, *caller, "ws_close", encode_one(args).unwrap()) .expect("Failed to call canister"); match res { @@ -123,10 +128,11 @@ pub mod ws_get_messages { caller: &Principal, args: CanisterWsGetMessagesArguments, ) -> CanisterWsGetMessagesResult { - let res = TEST_ENV + let canister_id = get_test_env().canister_id; + let res = get_test_env() .pic .query_call( - TEST_ENV.canister_id, + canister_id, *caller, "ws_get_messages", encode_one(args).unwrap(), @@ -161,10 +167,11 @@ pub mod ws_send { ) -> CanisterWsSendResult { let messages: Vec> = messages.iter().map(|m| encode_one(m).unwrap()).collect(); let args: WsSendArguments = (send_to_principal.clone(), messages); - let res = TEST_ENV + let canister_id = get_test_env().canister_id; + let res = get_test_env() .pic .update_call( - TEST_ENV.canister_id, + canister_id, Principal::anonymous(), "ws_send", encode_args(args).unwrap(), diff --git a/src/ic-websocket-cdk/src/tests/integration_tests/utils/test_env.rs b/src/ic-websocket-cdk/src/tests/integration_tests/utils/test_env.rs index 2cdf7d8..bf42312 100644 --- a/src/ic-websocket-cdk/src/tests/integration_tests/utils/test_env.rs +++ b/src/ic-websocket-cdk/src/tests/integration_tests/utils/test_env.rs @@ -1,4 +1,7 @@ -use std::time::{Duration, SystemTime}; +use std::{ + sync::{Mutex, MutexGuard}, + time::{Duration, SystemTime}, +}; use candid::Principal; use lazy_static::lazy_static; @@ -14,7 +17,11 @@ use super::{ }; lazy_static! { - pub static ref TEST_ENV: TestEnv = TestEnv::new(); + pub static ref TEST_ENV: Mutex = Mutex::new(TestEnv::new()); +} + +pub fn get_test_env<'a>() -> MutexGuard<'a, TestEnv> { + TEST_ENV.lock().unwrap() } pub struct TestEnv { @@ -67,13 +74,14 @@ impl TestEnv { } pub fn reset_canister( - &self, + &mut self, + authorized_gateways: AuthorizedGateways, max_number_or_returned_messages: u64, send_ack_interval_ms: u64, keep_alive_delay_ms: u64, ) { let arguments: CanisterInitArgs = ( - self.canister_init_args.0.clone(), + authorized_gateways, max_number_or_returned_messages, send_ack_interval_ms, keep_alive_delay_ms, @@ -81,20 +89,34 @@ impl TestEnv { let res = self.pic.reinstall_canister( self.canister_id, self.wasm_module.to_owned(), - candid::encode_args(arguments).unwrap(), + candid::encode_args(arguments.clone()).unwrap(), None, ); match res { - Ok(_) => {}, + Ok(_) => { + self.canister_init_args = arguments; + }, Err(err) => { panic!("Failed to reset canister: {:?}", err); }, } } - pub fn reset_canister_with_default_params(&self) { + /// Resets the canister using the default parameters. See [reset_canister]. + pub fn reset_canister_with_default_params(&mut self) { + self.reset_canister( + self.canister_init_args.0.clone(), + DEFAULT_TEST_MAX_NUMBER_OF_RETURNED_MESSAGES, + DEFAULT_TEST_SEND_ACK_INTERVAL_MS, + DEFAULT_TEST_KEEP_ALIVE_TIMEOUT_MS, + ); + } + + /// Resets the canister using the default parameters and the given gateways. See [reset_canister]. + pub fn reset_canister_with_gateways(&mut self, gateways: AuthorizedGateways) { self.reset_canister( + gateways, DEFAULT_TEST_MAX_NUMBER_OF_RETURNED_MESSAGES, DEFAULT_TEST_SEND_ACK_INTERVAL_MS, DEFAULT_TEST_KEEP_ALIVE_TIMEOUT_MS, From b071bea6ceb6958a22dcce5512b3d5c8f9ff9d18 Mon Sep 17 00:00:00 2001 From: Luca8991 Date: Wed, 15 Nov 2023 12:24:35 +0100 Subject: [PATCH 13/27] feat: multiple gateways integration tests --- .../integration_tests/g_multiple_gateways.rs | 108 ++++++++++++++++++ .../src/tests/integration_tests/mod.rs | 1 + .../tests/integration_tests/utils/actor.rs | 10 +- 3 files changed, 117 insertions(+), 2 deletions(-) create mode 100644 src/ic-websocket-cdk/src/tests/integration_tests/g_multiple_gateways.rs diff --git a/src/ic-websocket-cdk/src/tests/integration_tests/g_multiple_gateways.rs b/src/ic-websocket-cdk/src/tests/integration_tests/g_multiple_gateways.rs new file mode 100644 index 0000000..8f0bdea --- /dev/null +++ b/src/ic-websocket-cdk/src/tests/integration_tests/g_multiple_gateways.rs @@ -0,0 +1,108 @@ +use std::ops::Deref; + +use crate::{ + tests::integration_tests::utils::messages::get_service_message_content_from_canister_message, + CanisterOutputCertifiedMessages, CanisterWsCloseArguments, CanisterWsGetMessagesArguments, + CanisterWsGetMessagesResult, WebsocketServiceMessageContent, +}; + +use super::utils::{ + actor::{ + ws_close::call_ws_close_with_panic, + ws_get_messages::call_ws_get_messages, + ws_open::call_ws_open_for_client_key_and_gateway_with_panic, + ws_send::{call_ws_send_with_panic, AppMessage}, + }, + clients::{CLIENT_1_KEY, GATEWAY_1, GATEWAY_2}, + test_env::get_test_env, +}; + +#[test] +fn test_1_client_can_switch_to_another_gateway() { + get_test_env().reset_canister_with_gateways(vec![GATEWAY_1.to_string(), GATEWAY_2.to_string()]); + // open a connection for client 1 + let client_1_key = CLIENT_1_KEY.deref(); + call_ws_open_for_client_key_and_gateway_with_panic(client_1_key, *GATEWAY_1); + // simulate canister sending messages to client + call_ws_send_with_panic( + &client_1_key.client_principal, + (0..10) + .map(|i| AppMessage { + text: format!("test{}", i), + }) + .collect(), + ); + + // test + // gateway 1 can poll the messages + let res_gateway_1 = call_ws_get_messages( + GATEWAY_1.deref(), + CanisterWsGetMessagesArguments { nonce: 0 }, + ); + match res_gateway_1 { + CanisterWsGetMessagesResult::Ok(CanisterOutputCertifiedMessages { messages, .. }) => { + assert_eq!(messages.len() as u64, 10 + 1); // +1 for the open service message + }, + _ => panic!("unexpected result"), + }; + // gateway 2 has no messages + let res_gateway_2 = call_ws_get_messages( + GATEWAY_2.deref(), + CanisterWsGetMessagesArguments { nonce: 0 }, + ); + match res_gateway_2 { + CanisterWsGetMessagesResult::Ok(CanisterOutputCertifiedMessages { messages, .. }) => { + assert_eq!(messages.len() as u64, 0); + }, + _ => panic!("unexpected result"), + }; + + // client disconnects, so gateway 1 closes the connection + call_ws_close_with_panic( + GATEWAY_1.deref(), + CanisterWsCloseArguments { + client_key: client_1_key.clone(), + }, + ); + // client reopens connection with gateway 2 + call_ws_open_for_client_key_and_gateway_with_panic(client_1_key, *GATEWAY_2); + // gateway 2 now has the open message + let res_gateway_2 = call_ws_get_messages( + GATEWAY_2.deref(), + CanisterWsGetMessagesArguments { nonce: 0 }, + ); + match res_gateway_2 { + CanisterWsGetMessagesResult::Ok(CanisterOutputCertifiedMessages { messages, .. }) => { + let first_message = &messages[0]; + assert_eq!(first_message.client_key, *client_1_key); + let open_message = get_service_message_content_from_canister_message(first_message); + match open_message { + WebsocketServiceMessageContent::OpenMessage(open_message) => { + assert_eq!(open_message.client_key, *client_1_key); + }, + _ => panic!("Expected OpenMessage"), + } + }, + _ => panic!("unexpected result"), + }; + + // simulate canister sending other messages to client + call_ws_send_with_panic( + &client_1_key.client_principal, + (0..10) + .map(|i| AppMessage { + text: format!("test{}", i + 10), + }) + .collect(), + ); + let res_gateway_2 = call_ws_get_messages( + GATEWAY_2.deref(), + CanisterWsGetMessagesArguments { nonce: 0 }, + ); + match res_gateway_2 { + CanisterWsGetMessagesResult::Ok(CanisterOutputCertifiedMessages { messages, .. }) => { + assert_eq!(messages.len() as u64, 10 + 1); // +1 for the open service message + }, + _ => panic!("unexpected result"), + }; +} diff --git a/src/ic-websocket-cdk/src/tests/integration_tests/mod.rs b/src/ic-websocket-cdk/src/tests/integration_tests/mod.rs index 85f18a5..8188d2a 100644 --- a/src/ic-websocket-cdk/src/tests/integration_tests/mod.rs +++ b/src/ic-websocket-cdk/src/tests/integration_tests/mod.rs @@ -13,3 +13,4 @@ mod c_ws_get_messages; mod d_ws_close; mod e_ws_send; mod f_messages_acknowledgement; +mod g_multiple_gateways; diff --git a/src/ic-websocket-cdk/src/tests/integration_tests/utils/actor.rs b/src/ic-websocket-cdk/src/tests/integration_tests/utils/actor.rs index c8b266c..a31a77e 100644 --- a/src/ic-websocket-cdk/src/tests/integration_tests/utils/actor.rs +++ b/src/ic-websocket-cdk/src/tests/integration_tests/utils/actor.rs @@ -117,6 +117,13 @@ pub mod ws_close { _ => panic!("Expected reply"), } } + + pub fn call_ws_close_with_panic(caller: &Principal, args: CanisterWsCloseArguments) { + match call_ws_close(caller, args) { + CanisterWsCloseResult::Ok(_) => {}, + CanisterWsCloseResult::Err(err) => panic!("failed ws_close: {:?}", err), + } + } } pub mod ws_get_messages { @@ -184,8 +191,7 @@ pub mod ws_send { } pub fn call_ws_send_with_panic(send_to_principal: &Principal, messages: Vec) { - let res = call_ws_send(send_to_principal, messages); - match res { + match call_ws_send(send_to_principal, messages) { CanisterWsSendResult::Ok(_) => {}, CanisterWsSendResult::Err(err) => panic!("failed ws_send: {:?}", err), } From b1becf1881f4a5beb8470d98c43c171e2089ab4c Mon Sep 17 00:00:00 2001 From: Luca8991 Date: Wed, 15 Nov 2023 12:42:00 +0100 Subject: [PATCH 14/27] chore: test docs --- .github/workflows/tests.yml | 2 ++ scripts/test_canister.sh | 4 +++- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 13e4307..0bb86a5 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -19,6 +19,8 @@ jobs: cache-on-failure: "true" - name: Run unit tests run: cargo test --package ic-websocket-cdk --lib -- tests::unit_tests + - name: Run doc tests + run: cargo test --package ic-websocket-cdk --doc - name: Prepare environment for integration tests run: | rustup target add wasm32-unknown-unknown diff --git a/scripts/test_canister.sh b/scripts/test_canister.sh index edcb452..c68e222 100755 --- a/scripts/test_canister.sh +++ b/scripts/test_canister.sh @@ -4,10 +4,12 @@ set -e # unit tests cargo test --package ic-websocket-cdk --lib -- tests::unit_tests +# doc tests +cargo test --package ic-websocket-cdk --doc # integration tests ./scripts/download-pocket-ic.sh ./scripts/build-test-canister.sh -POCKET_IC_BIN=$(pwd)/bin/pocket-ic RUST_BACKTRACE=1 cargo test --package ic-websocket-cdk --lib -- tests::integration_tests --test-threads 1 +POCKET_IC_BIN="$(pwd)/bin/pocket-ic" RUST_BACKTRACE=1 cargo test --package ic-websocket-cdk --lib -- tests::integration_tests --test-threads 1 From e56a6b2af00ec8e1a7dfecccad42e78b8e36fa38 Mon Sep 17 00:00:00 2001 From: Luca8991 Date: Wed, 15 Nov 2023 14:29:27 +0100 Subject: [PATCH 15/27] perf: proptest strategies --- src/ic-websocket-cdk/src/tests/unit_tests/mod.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/ic-websocket-cdk/src/tests/unit_tests/mod.rs b/src/ic-websocket-cdk/src/tests/unit_tests/mod.rs index 5be4f2c..6348dd4 100644 --- a/src/ic-websocket-cdk/src/tests/unit_tests/mod.rs +++ b/src/ic-websocket-cdk/src/tests/unit_tests/mod.rs @@ -428,7 +428,7 @@ proptest! { } #[test] - fn test_get_messages_for_gateway_range_empty(messages_count in any::().prop_map(|c| c % 1000)) { + fn test_get_messages_for_gateway_range_empty(messages_count in 0..1000u64) { // Set up utils::initialize_params(); let gateway_principal = utils::generate_random_principal(); @@ -467,7 +467,7 @@ proptest! { } #[test] - fn test_get_messages_for_gateway_range_larger_than_max(gateway_principal in any::().prop_map(|_| utils::get_static_principal()), max_number_of_returned_messages in any::().prop_map(|c| c % 1000)) { + fn test_get_messages_for_gateway_range_larger_than_max(gateway_principal in any::().prop_map(|_| utils::get_static_principal()), max_number_of_returned_messages in 0..1000usize) { // Set up set_params(WsInitParams { max_number_of_returned_messages, @@ -498,7 +498,7 @@ proptest! { } #[test] - fn test_get_messages_for_gateway_initial_nonce(gateway_principal in any::().prop_map(|_| utils::get_static_principal()), messages_count in any::().prop_map(|c| c % 100), max_number_of_returned_messages in any::().prop_map(|c| c % 1000)) { + fn test_get_messages_for_gateway_initial_nonce(gateway_principal in any::().prop_map(|_| utils::get_static_principal()), messages_count in 0..100u64, max_number_of_returned_messages in 0..1000usize) { // Set up set_params(WsInitParams { max_number_of_returned_messages, @@ -524,7 +524,7 @@ proptest! { } #[test] - fn test_get_messages_for_gateway(gateway_principal in any::().prop_map(|_| utils::get_static_principal()), messages_count in any::().prop_map(|c| c % 100)) { + fn test_get_messages_for_gateway(gateway_principal in any::().prop_map(|_| utils::get_static_principal()), messages_count in 0..100u64) { // Set up REGISTERED_GATEWAYS.with(|n| n.borrow_mut().insert(gateway_principal, RegisteredGateway::new())); From 15c63ba0b2e87c3195818a8ecb3a3a0b94459656 Mon Sep 17 00:00:00 2001 From: Luca8991 Date: Wed, 15 Nov 2023 14:30:49 +0100 Subject: [PATCH 16/27] perf: destructure CanisterOutputCertifiedMessages --- .../src/tests/integration_tests/a_ws_open.rs | 9 ++--- .../f_messages_acknowledgement.rs | 36 +++++++++---------- 2 files changed, 23 insertions(+), 22 deletions(-) diff --git a/src/ic-websocket-cdk/src/tests/integration_tests/a_ws_open.rs b/src/ic-websocket-cdk/src/tests/integration_tests/a_ws_open.rs index bc91aa1..b6f9019 100644 --- a/src/ic-websocket-cdk/src/tests/integration_tests/a_ws_open.rs +++ b/src/ic-websocket-cdk/src/tests/integration_tests/a_ws_open.rs @@ -1,8 +1,9 @@ use std::ops::Deref; use crate::{ - CanisterOutputMessage, CanisterWsGetMessagesArguments, CanisterWsGetMessagesResult, - CanisterWsOpenArguments, CanisterWsOpenResult, ClientKey, WebsocketServiceMessageContent, + CanisterOutputCertifiedMessages, CanisterOutputMessage, CanisterWsGetMessagesArguments, + CanisterWsGetMessagesResult, CanisterWsOpenArguments, CanisterWsOpenResult, ClientKey, + WebsocketServiceMessageContent, }; use candid::Principal; @@ -56,8 +57,8 @@ fn test_3_should_open_a_connection() { ); match msgs { - CanisterWsGetMessagesResult::Ok(messages) => { - let first_message = &messages.messages[0]; + CanisterWsGetMessagesResult::Ok(CanisterOutputCertifiedMessages { messages, .. }) => { + let first_message = &messages[0]; assert_eq!(first_message.client_key, *client_1_key); let open_message = get_service_message_content_from_canister_message(first_message); match open_message { diff --git a/src/ic-websocket-cdk/src/tests/integration_tests/f_messages_acknowledgement.rs b/src/ic-websocket-cdk/src/tests/integration_tests/f_messages_acknowledgement.rs index 5e1e4c3..58ad36b 100644 --- a/src/ic-websocket-cdk/src/tests/integration_tests/f_messages_acknowledgement.rs +++ b/src/ic-websocket-cdk/src/tests/integration_tests/f_messages_acknowledgement.rs @@ -1,9 +1,9 @@ use std::ops::Deref; use crate::{ - CanisterWsGetMessagesArguments, CanisterWsGetMessagesResult, CanisterWsMessageArguments, - CanisterWsMessageResult, CanisterWsSendResult, ClientKeepAliveMessageContent, - WebsocketServiceMessageContent, + CanisterOutputCertifiedMessages, CanisterWsGetMessagesArguments, CanisterWsGetMessagesResult, + CanisterWsMessageArguments, CanisterWsMessageResult, CanisterWsSendResult, + ClientKeepAliveMessageContent, WebsocketServiceMessageContent, }; use super::utils::{ @@ -30,8 +30,8 @@ fn test_1_client_should_receive_ack_messages() { CanisterWsGetMessagesArguments { nonce: 1 }, // skip the service open message ); match res { - CanisterWsGetMessagesResult::Ok(messages) => { - assert_eq!(messages.messages.len(), 0); + CanisterWsGetMessagesResult::Ok(CanisterOutputCertifiedMessages { messages, .. }) => { + assert_eq!(messages.len(), 0); }, _ => panic!("unexpected result"), } @@ -164,8 +164,8 @@ fn test_4_client_is_not_removed_if_it_connects_while_canister_is_waiting_for_kee CanisterWsGetMessagesArguments { nonce: 1 }, // skip the service open message ); match res { - CanisterWsGetMessagesResult::Ok(messages) => { - assert_eq!(messages.messages.len(), 0); + CanisterWsGetMessagesResult::Ok(CanisterOutputCertifiedMessages { messages, .. }) => { + assert_eq!(messages.len(), 0); }, _ => panic!("unexpected result"), } @@ -204,8 +204,8 @@ mod helpers { }, test_env::get_test_env, }, - CanisterAckMessageContent, CanisterOutputMessage, CanisterWsGetMessagesResult, ClientKey, - WebsocketServiceMessageContent, + CanisterAckMessageContent, CanisterOutputCertifiedMessages, CanisterOutputMessage, + CanisterWsGetMessagesResult, ClientKey, WebsocketServiceMessageContent, }; pub(crate) fn check_ack_message_result( @@ -215,24 +215,24 @@ mod helpers { expected_websocket_message_sequence_number: u64, ) { match res { - CanisterWsGetMessagesResult::Ok(messages) => { - assert_eq!(messages.messages.len(), 1); - let ack_message = messages.messages.first().unwrap(); + CanisterWsGetMessagesResult::Ok(CanisterOutputCertifiedMessages { + messages, + cert, + tree, + }) => { + assert_eq!(messages.len(), 1); + let ack_message = messages.first().unwrap(); check_ack_message_in_messages( ack_message, receiver_client_key, expected_ack_sequence_number, expected_websocket_message_sequence_number, ); - assert!(is_valid_certificate( - &get_test_env(), - &messages.cert, - &messages.tree, - )); + assert!(is_valid_certificate(&get_test_env(), &cert, &tree,)); assert!(is_message_body_valid( &ack_message.key, &ack_message.content, - &messages.tree + &tree )); }, _ => panic!("unexpected result"), From 17d6d2819d7fd4397a0ec7188d4bf25b85ecb59f Mon Sep 17 00:00:00 2001 From: Luca8991 Date: Wed, 15 Nov 2023 14:50:48 +0100 Subject: [PATCH 17/27] perf: typos --- .../src/tests/integration_tests/utils/actor.rs | 4 ++-- src/ic-websocket-cdk/src/tests/unit_tests/mod.rs | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/ic-websocket-cdk/src/tests/integration_tests/utils/actor.rs b/src/ic-websocket-cdk/src/tests/integration_tests/utils/actor.rs index a31a77e..75e9699 100644 --- a/src/ic-websocket-cdk/src/tests/integration_tests/utils/actor.rs +++ b/src/ic-websocket-cdk/src/tests/integration_tests/utils/actor.rs @@ -144,7 +144,7 @@ pub mod ws_get_messages { "ws_get_messages", encode_one(args).unwrap(), ) - .expect("Failed to call counter canister"); + .expect("Failed to call canister"); match res { WasmResult::Reply(bytes) => decode_one(&bytes).unwrap(), @@ -183,7 +183,7 @@ pub mod ws_send { "ws_send", encode_args(args).unwrap(), ) - .expect("Failed to call counter canister"); + .expect("Failed to call canister"); match res { WasmResult::Reply(bytes) => decode_one(&bytes).unwrap(), _ => panic!("Expected reply"), diff --git a/src/ic-websocket-cdk/src/tests/unit_tests/mod.rs b/src/ic-websocket-cdk/src/tests/unit_tests/mod.rs index 6348dd4..cfced7a 100644 --- a/src/ic-websocket-cdk/src/tests/unit_tests/mod.rs +++ b/src/ic-websocket-cdk/src/tests/unit_tests/mod.rs @@ -573,6 +573,6 @@ proptest! { }; let serialized_message = websocket_message.cbor_serialize(); - assert!(serialized_message.is_ok()); // not so useful as a test + prop_assert!(serialized_message.is_ok()); // not so useful as a test } } From a6f21fc416d2909e3e934db8e79c56d9c292e16a Mon Sep 17 00:00:00 2001 From: Luca8991 Date: Wed, 15 Nov 2023 14:53:12 +0100 Subject: [PATCH 18/27] perf: ClientPrincipal for ws_send --- .../src/tests/integration_tests/utils/actor.rs | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/ic-websocket-cdk/src/tests/integration_tests/utils/actor.rs b/src/ic-websocket-cdk/src/tests/integration_tests/utils/actor.rs index 75e9699..da13b9f 100644 --- a/src/ic-websocket-cdk/src/tests/integration_tests/utils/actor.rs +++ b/src/ic-websocket-cdk/src/tests/integration_tests/utils/actor.rs @@ -154,7 +154,7 @@ pub mod ws_get_messages { } pub mod ws_send { - use crate::CanisterWsSendResult; + use crate::{CanisterWsSendResult, ClientPrincipal}; use candid::{encode_args, CandidType}; use serde::{Deserialize, Serialize}; @@ -165,11 +165,11 @@ pub mod ws_send { pub text: String, } - /// (`Principal`, `Vec>`) - type WsSendArguments = (Principal, Vec>); + /// (`ClientPrincipal`, `Vec>`) + type WsSendArguments = (ClientPrincipal, Vec>); pub fn call_ws_send( - send_to_principal: &Principal, + send_to_principal: &ClientPrincipal, messages: Vec, ) -> CanisterWsSendResult { let messages: Vec> = messages.iter().map(|m| encode_one(m).unwrap()).collect(); @@ -190,7 +190,7 @@ pub mod ws_send { } } - pub fn call_ws_send_with_panic(send_to_principal: &Principal, messages: Vec) { + pub fn call_ws_send_with_panic(send_to_principal: &ClientPrincipal, messages: Vec) { match call_ws_send(send_to_principal, messages) { CanisterWsSendResult::Ok(_) => {}, CanisterWsSendResult::Err(err) => panic!("failed ws_send: {:?}", err), From 6f2081013c6039d5b4ad612f7b565648b2a151c6 Mon Sep 17 00:00:00 2001 From: Luca8991 Date: Wed, 15 Nov 2023 15:51:44 +0100 Subject: [PATCH 19/27] feat: prop tests in integration tests --- src/ic-websocket-cdk/src/tests/common.rs | 32 +++ .../src/tests/integration_tests/a_ws_open.rs | 75 +++---- .../tests/integration_tests/b_ws_message.rs | 9 +- .../integration_tests/c_ws_get_messages.rs | 143 +++++++------ .../src/tests/integration_tests/e_ws_send.rs | 8 +- .../integration_tests/g_multiple_gateways.rs | 188 +++++++++--------- .../tests/integration_tests/utils/clients.rs | 26 +-- src/ic-websocket-cdk/src/tests/mod.rs | 2 + .../src/tests/unit_tests/mod.rs | 91 ++++----- .../src/tests/unit_tests/utils.rs | 29 --- 10 files changed, 321 insertions(+), 282 deletions(-) create mode 100644 src/ic-websocket-cdk/src/tests/common.rs diff --git a/src/ic-websocket-cdk/src/tests/common.rs b/src/ic-websocket-cdk/src/tests/common.rs new file mode 100644 index 0000000..c73698e --- /dev/null +++ b/src/ic-websocket-cdk/src/tests/common.rs @@ -0,0 +1,32 @@ +use candid::Principal; +use ic_agent::{identity::BasicIdentity, Identity}; +use ring::signature::Ed25519KeyPair; + +use crate::ClientKey; + +fn generate_random_key_pair() -> Ed25519KeyPair { + let rng = ring::rand::SystemRandom::new(); + let key_pair = Ed25519KeyPair::generate_pkcs8(&rng).expect("Could not generate a key pair."); + Ed25519KeyPair::from_pkcs8(key_pair.as_ref()).expect("Could not read the key pair.") +} + +pub fn generate_random_principal() -> Principal { + let key_pair = generate_random_key_pair(); + let identity = BasicIdentity::from_key_pair(key_pair); + + // workaround to keep the principal in the version of candid used by the canister + Principal::from_text(identity.sender().unwrap().to_text()).unwrap() +} + +pub fn get_static_principal() -> Principal { + Principal::from_text("wnkwv-wdqb5-7wlzr-azfpw-5e5n5-dyxrf-uug7x-qxb55-mkmpa-5jqik-tqe").unwrap() + // a random static but valid principal +} + +pub(super) fn get_random_client_key() -> ClientKey { + ClientKey::new( + generate_random_principal(), + // a random nonce + rand::random(), + ) +} diff --git a/src/ic-websocket-cdk/src/tests/integration_tests/a_ws_open.rs b/src/ic-websocket-cdk/src/tests/integration_tests/a_ws_open.rs index b6f9019..fab9ed4 100644 --- a/src/ic-websocket-cdk/src/tests/integration_tests/a_ws_open.rs +++ b/src/ic-websocket-cdk/src/tests/integration_tests/a_ws_open.rs @@ -1,3 +1,4 @@ +use proptest::prelude::*; use std::ops::Deref; use crate::{ @@ -9,7 +10,7 @@ use candid::Principal; use super::utils::{ actor::{ws_get_messages::call_ws_get_messages, ws_open::call_ws_open}, - clients::{generate_random_client_nonce, CLIENT_1, CLIENT_1_KEY, GATEWAY_1}, + clients::{generate_random_client_nonce, CLIENT_1_KEY, GATEWAY_1}, messages::get_service_message_content_from_canister_message, }; @@ -48,7 +49,7 @@ fn test_3_should_open_a_connection() { client_nonce: client_1_key.client_nonce, gateway_principal: GATEWAY_1.deref().to_owned(), }; - let res = call_ws_open(CLIENT_1.deref(), args); + let res = call_ws_open(&client_1_key.client_principal, args); assert_eq!(res, CanisterWsOpenResult::Ok(())); let msgs = call_ws_get_messages( @@ -79,7 +80,7 @@ fn test_4_fails_for_a_client_with_the_same_nonce() { client_nonce: client_1_key.client_nonce, gateway_principal: GATEWAY_1.deref().to_owned(), }; - let res = call_ws_open(CLIENT_1.deref(), args); + let res = call_ws_open(&client_1_key.client_principal, args); assert_eq!( res, CanisterWsOpenResult::Err(String::from(format!( @@ -88,41 +89,43 @@ fn test_4_fails_for_a_client_with_the_same_nonce() { ); } -#[test] -fn test_5_should_open_a_connection_for_the_same_client_with_a_different_nonce() { - let client_key = ClientKey { - client_principal: CLIENT_1_KEY.deref().client_principal, - client_nonce: generate_random_client_nonce(), - }; - let args = CanisterWsOpenArguments { - client_nonce: client_key.client_nonce, - gateway_principal: GATEWAY_1.deref().to_owned(), - }; - let res = call_ws_open(&client_key.client_principal, args); - assert_eq!(res, CanisterWsOpenResult::Ok(())); +proptest! { + #[test] + fn test_5_should_open_a_connection_for_the_same_client_with_a_different_nonce(test_client_nonce in any::().prop_map(|_| generate_random_client_nonce())) { + let client_key = ClientKey { + client_principal: CLIENT_1_KEY.deref().client_principal, + client_nonce: test_client_nonce, + }; + let args = CanisterWsOpenArguments { + client_nonce: client_key.client_nonce, + gateway_principal: GATEWAY_1.deref().to_owned(), + }; + let res = call_ws_open(&client_key.client_principal, args); + assert_eq!(res, CanisterWsOpenResult::Ok(())); - let msgs = call_ws_get_messages( - GATEWAY_1.deref(), - CanisterWsGetMessagesArguments { nonce: 0 }, - ); + let msgs = call_ws_get_messages( + GATEWAY_1.deref(), + CanisterWsGetMessagesArguments { nonce: 0 }, + ); - match msgs { - CanisterWsGetMessagesResult::Ok(messages) => { - let service_message_for_client = messages - .messages - .iter() - .filter(|msg| msg.client_key == client_key) - .collect::>()[0]; + match msgs { + CanisterWsGetMessagesResult::Ok(messages) => { + let service_message_for_client = messages + .messages + .iter() + .filter(|msg| msg.client_key == client_key) + .collect::>()[0]; - let open_message = - get_service_message_content_from_canister_message(service_message_for_client); - match open_message { - WebsocketServiceMessageContent::OpenMessage(open_message) => { - assert_eq!(open_message.client_key, client_key); - }, - _ => panic!("Expected OpenMessage"), - } - }, - _ => panic!("Expected Ok result"), + let open_message = + get_service_message_content_from_canister_message(service_message_for_client); + match open_message { + WebsocketServiceMessageContent::OpenMessage(open_message) => { + assert_eq!(open_message.client_key, client_key); + }, + _ => panic!("Expected OpenMessage"), + } + }, + _ => panic!("Expected Ok result"), + } } } diff --git a/src/ic-websocket-cdk/src/tests/integration_tests/b_ws_message.rs b/src/ic-websocket-cdk/src/tests/integration_tests/b_ws_message.rs index a4ba0be..36d8887 100644 --- a/src/ic-websocket-cdk/src/tests/integration_tests/b_ws_message.rs +++ b/src/ic-websocket-cdk/src/tests/integration_tests/b_ws_message.rs @@ -1,13 +1,14 @@ use std::ops::Deref; use crate::{ - CanisterAckMessageContent, CanisterWsMessageArguments, CanisterWsMessageResult, - ClientKeepAliveMessageContent, ClientKey, WebsocketServiceMessageContent, + tests::common::generate_random_principal, CanisterAckMessageContent, + CanisterWsMessageArguments, CanisterWsMessageResult, ClientKeepAliveMessageContent, ClientKey, + WebsocketServiceMessageContent, }; use super::utils::{ actor::{ws_message::call_ws_message, ws_open::call_ws_open_for_client_key_with_panic}, - clients::{generate_random_client_nonce, CLIENT_1_KEY, CLIENT_2, CLIENT_2_KEY}, + clients::{generate_random_client_nonce, CLIENT_1_KEY, CLIENT_2_KEY}, messages::{create_websocket_message, encode_websocket_service_message_content}, test_env::get_test_env, }; @@ -46,7 +47,7 @@ fn test_2_fails_if_client_sends_a_message_with_a_different_client_key() { CanisterWsMessageArguments { msg: create_websocket_message( &ClientKey { - client_principal: *CLIENT_2.deref(), + client_principal: generate_random_principal(), ..client_1_key.clone() }, 0, diff --git a/src/ic-websocket-cdk/src/tests/integration_tests/c_ws_get_messages.rs b/src/ic-websocket-cdk/src/tests/integration_tests/c_ws_get_messages.rs index 9e11dde..731efa1 100644 --- a/src/ic-websocket-cdk/src/tests/integration_tests/c_ws_get_messages.rs +++ b/src/ic-websocket-cdk/src/tests/integration_tests/c_ws_get_messages.rs @@ -1,3 +1,4 @@ +use proptest::prelude::*; use std::ops::Deref; use crate::{ @@ -34,7 +35,7 @@ fn test_1_fails_if_a_non_registered_gateway_tries_to_get_messages() { } #[test] -fn test_2_registered_gateway_should_receive_empty_messages_if_no_messages_are_available() { +fn test_2_registered_gateway_should_receive_empty_messages_if_no_messages_initial_nonce() { let res = call_ws_get_messages( GATEWAY_1.deref(), CanisterWsGetMessagesArguments { nonce: 0 }, @@ -47,79 +48,102 @@ fn test_2_registered_gateway_should_receive_empty_messages_if_no_messages_are_av tree: vec![], }), ); - - // test also with a high nonce to make sure the indexes are calculated correctly in the canister - let res = call_ws_get_messages( - GATEWAY_1.deref(), - CanisterWsGetMessagesArguments { nonce: 100 }, - ); - assert_eq!( - res, - CanisterWsGetMessagesResult::Ok(CanisterOutputCertifiedMessages { - messages: vec![], - cert: vec![], - tree: vec![], - }), - ); } -#[test] -fn test_3_registered_gateway_can_receive_correct_amount_of_messages() { - // first, register client 1 - let client_1_key = CLIENT_1_KEY.deref(); - call_ws_open_for_client_key_with_panic(client_1_key); - // second, send a batch of messages to the client - call_ws_send_with_panic( - &client_1_key.client_principal, - (0..SEND_MESSAGES_COUNT) - .map(|i| AppMessage { - text: format!("test{}", i), +proptest! { + #![proptest_config(ProptestConfig::with_cases(50))] + + #[test] + fn test_3_registered_gateway_should_receive_empty_messages_if_no_messages_are_available(test_nonce in any::()) { + // test also with a high nonce to make sure the indexes are calculated correctly in the canister + let res = call_ws_get_messages( + GATEWAY_1.deref(), + CanisterWsGetMessagesArguments { nonce: test_nonce }, + ); + prop_assert_eq!( + res, + CanisterWsGetMessagesResult::Ok(CanisterOutputCertifiedMessages { + messages: vec![], + cert: vec![], + tree: vec![], }) - .collect(), - ); + ); + } + + #[test] + fn test_4_registered_gateway_can_receive_correct_amount_of_messages(test_send_messages_count in 1..100u64) { + // first, reset the canister + get_test_env().reset_canister_with_default_params(); + // second, register client 1 + let client_1_key = CLIENT_1_KEY.deref(); + call_ws_open_for_client_key_with_panic(client_1_key); + // third, send a batch of messages to the client + call_ws_send_with_panic( + &client_1_key.client_principal, + (0..test_send_messages_count) + .map(|i| AppMessage { + text: format!("test{}", i), + }) + .collect(), + ); - // now we can start testing - let messages_count = SEND_MESSAGES_COUNT + 1; // +1 for the open service message - for i in 0..messages_count { + // now we can start testing + let messages_count = test_send_messages_count + 1; // +1 for the open service message + for i in 0..messages_count { + let res = call_ws_get_messages( + GATEWAY_1.deref(), + CanisterWsGetMessagesArguments { nonce: i }, + ); + match res { + CanisterWsGetMessagesResult::Ok(CanisterOutputCertifiedMessages { + messages, .. + }) => { + prop_assert_eq!( + messages.len() as u64, + if (messages_count - i) > DEFAULT_TEST_MAX_NUMBER_OF_RETURNED_MESSAGES { + DEFAULT_TEST_MAX_NUMBER_OF_RETURNED_MESSAGES + } else { + messages_count - i + } + ); + }, + _ => panic!("unexpected result"), + }; + } + + // try to get more messages than available let res = call_ws_get_messages( GATEWAY_1.deref(), - CanisterWsGetMessagesArguments { nonce: i }, + CanisterWsGetMessagesArguments { + nonce: messages_count, + }, ); + prop_assert!(matches!(res, CanisterWsGetMessagesResult::Ok(_))); match res { - CanisterWsGetMessagesResult::Ok(CanisterOutputCertifiedMessages { - messages, .. - }) => { - assert_eq!( - messages.len() as u64, - if messages_count - i > DEFAULT_TEST_MAX_NUMBER_OF_RETURNED_MESSAGES { - DEFAULT_TEST_MAX_NUMBER_OF_RETURNED_MESSAGES - } else { - messages_count - i - }, - ); + CanisterWsGetMessagesResult::Ok(CanisterOutputCertifiedMessages { messages, .. }) => { + prop_assert_eq!(messages.len(), 0); }, _ => panic!("unexpected result"), }; } - - // try to get more messages than available - let res = call_ws_get_messages( - GATEWAY_1.deref(), - CanisterWsGetMessagesArguments { - nonce: messages_count, - }, - ); - match res { - CanisterWsGetMessagesResult::Ok(CanisterOutputCertifiedMessages { messages, .. }) => { - assert_eq!(messages.len(), 0); - }, - _ => panic!("unexpected result"), - }; } #[test] -fn test_4_registered_gateway_can_receive_certified_messages() { +fn test_5_registered_gateway_can_receive_certified_messages() { + // first, reset the canister + get_test_env().reset_canister_with_default_params(); + // second, register client 1 let client_1_key = CLIENT_1_KEY.deref(); + call_ws_open_for_client_key_with_panic(client_1_key); + // third, send a batch of messages to the client + call_ws_send_with_panic( + &client_1_key.client_principal, + (0..SEND_MESSAGES_COUNT) + .map(|i| AppMessage { + text: format!("test{}", i), + }) + .collect(), + ); // first batch of messages match call_ws_get_messages( @@ -148,7 +172,6 @@ fn test_4_registered_gateway_can_receive_certified_messages() { ); let next_polling_nonce = get_next_polling_nonce_from_messages(first_batch_messages); - println!("next polling nonce: {}", next_polling_nonce); // second batch of messages match call_ws_get_messages( GATEWAY_1.deref(), @@ -183,7 +206,7 @@ fn test_4_registered_gateway_can_receive_certified_messages() { } #[test] -fn test_5_registered_gateway_can_poll_messages_after_restart() { +fn test_6_registered_gateway_can_poll_messages_after_restart() { let res = call_ws_get_messages( GATEWAY_1.deref(), CanisterWsGetMessagesArguments { nonce: 0 }, // start polling from the beginning, as if the gateway restarted diff --git a/src/ic-websocket-cdk/src/tests/integration_tests/e_ws_send.rs b/src/ic-websocket-cdk/src/tests/integration_tests/e_ws_send.rs index de2a72b..cbb0d9e 100644 --- a/src/ic-websocket-cdk/src/tests/integration_tests/e_ws_send.rs +++ b/src/ic-websocket-cdk/src/tests/integration_tests/e_ws_send.rs @@ -1,13 +1,13 @@ use std::ops::Deref; -use crate::CanisterWsSendResult; +use crate::{tests::integration_tests::utils::clients::CLIENT_2_KEY, CanisterWsSendResult}; use super::utils::{ actor::{ ws_open::call_ws_open_for_client_key_with_panic, ws_send::{call_ws_send, AppMessage}, }, - clients::{CLIENT_1, CLIENT_1_KEY, CLIENT_2}, + clients::CLIENT_1_KEY, test_env::get_test_env, }; @@ -19,7 +19,7 @@ fn test_1_fails_if_sending_a_message_to_a_non_registered_client() { call_ws_open_for_client_key_with_panic(CLIENT_1_KEY.deref()); // finally, we can start testing - let client_2_principal = CLIENT_2.deref(); + let client_2_principal = &CLIENT_2_KEY.client_principal; let res = call_ws_send( client_2_principal, vec![AppMessage { @@ -37,7 +37,7 @@ fn test_1_fails_if_sending_a_message_to_a_non_registered_client() { #[test] fn test_2_should_send_a_message_to_a_registered_client() { let res = call_ws_send( - CLIENT_1.deref(), + &CLIENT_1_KEY.client_principal, vec![AppMessage { text: String::from("test"), }], diff --git a/src/ic-websocket-cdk/src/tests/integration_tests/g_multiple_gateways.rs b/src/ic-websocket-cdk/src/tests/integration_tests/g_multiple_gateways.rs index 8f0bdea..96b1e44 100644 --- a/src/ic-websocket-cdk/src/tests/integration_tests/g_multiple_gateways.rs +++ b/src/ic-websocket-cdk/src/tests/integration_tests/g_multiple_gateways.rs @@ -1,9 +1,12 @@ -use std::ops::Deref; +use proptest::prelude::*; use crate::{ - tests::integration_tests::utils::messages::get_service_message_content_from_canister_message, + tests::{ + common, + integration_tests::utils::messages::get_service_message_content_from_canister_message, + }, CanisterOutputCertifiedMessages, CanisterWsCloseArguments, CanisterWsGetMessagesArguments, - CanisterWsGetMessagesResult, WebsocketServiceMessageContent, + CanisterWsGetMessagesResult, GatewayPrincipal, WebsocketServiceMessageContent, }; use super::utils::{ @@ -13,96 +16,103 @@ use super::utils::{ ws_open::call_ws_open_for_client_key_and_gateway_with_panic, ws_send::{call_ws_send_with_panic, AppMessage}, }, - clients::{CLIENT_1_KEY, GATEWAY_1, GATEWAY_2}, test_env::get_test_env, }; -#[test] -fn test_1_client_can_switch_to_another_gateway() { - get_test_env().reset_canister_with_gateways(vec![GATEWAY_1.to_string(), GATEWAY_2.to_string()]); - // open a connection for client 1 - let client_1_key = CLIENT_1_KEY.deref(); - call_ws_open_for_client_key_and_gateway_with_panic(client_1_key, *GATEWAY_1); - // simulate canister sending messages to client - call_ws_send_with_panic( - &client_1_key.client_principal, - (0..10) - .map(|i| AppMessage { - text: format!("test{}", i), - }) - .collect(), - ); +proptest! { + #![proptest_config(ProptestConfig::with_cases(50))] + + #[test] + fn test_1_client_can_switch_to_another_gateway( + ref client_key in any::().prop_map(|_| common::get_random_client_key()), + gateways in any::>().prop_map(|_| (0..2).map(|_| common::generate_random_principal()).collect::>()), + ) { + let first_gateway = &gateways[0]; + let second_gateway = &gateways[1]; + get_test_env().reset_canister_with_gateways(gateways.iter().map(|g| g.to_string()).collect()); + // open a connection for client + call_ws_open_for_client_key_and_gateway_with_panic(&client_key, *first_gateway); + // simulate canister sending messages to client + call_ws_send_with_panic( + &client_key.client_principal, + (0..10) + .map(|i| AppMessage { + text: format!("test{}", i), + }) + .collect(), + ); - // test - // gateway 1 can poll the messages - let res_gateway_1 = call_ws_get_messages( - GATEWAY_1.deref(), - CanisterWsGetMessagesArguments { nonce: 0 }, - ); - match res_gateway_1 { - CanisterWsGetMessagesResult::Ok(CanisterOutputCertifiedMessages { messages, .. }) => { - assert_eq!(messages.len() as u64, 10 + 1); // +1 for the open service message - }, - _ => panic!("unexpected result"), - }; - // gateway 2 has no messages - let res_gateway_2 = call_ws_get_messages( - GATEWAY_2.deref(), - CanisterWsGetMessagesArguments { nonce: 0 }, - ); - match res_gateway_2 { - CanisterWsGetMessagesResult::Ok(CanisterOutputCertifiedMessages { messages, .. }) => { - assert_eq!(messages.len() as u64, 0); - }, - _ => panic!("unexpected result"), - }; + // test + // gateway 1 can poll the messages + let res_gateway_1 = call_ws_get_messages( + first_gateway, + CanisterWsGetMessagesArguments { nonce: 0 }, + ); + match res_gateway_1 { + CanisterWsGetMessagesResult::Ok(CanisterOutputCertifiedMessages { messages, .. }) => { + prop_assert_eq!(messages.len() as u64, 10 + 1); // +1 for the open service message + }, + _ => panic!("unexpected result"), + }; + // gateway 2 has no messages + let res_gateway_2 = call_ws_get_messages( + second_gateway, + CanisterWsGetMessagesArguments { nonce: 0 }, + ); + match res_gateway_2 { + CanisterWsGetMessagesResult::Ok(CanisterOutputCertifiedMessages { messages, .. }) => { + prop_assert_eq!(messages.len() as u64, 0); + }, + _ => panic!("unexpected result"), + }; - // client disconnects, so gateway 1 closes the connection - call_ws_close_with_panic( - GATEWAY_1.deref(), - CanisterWsCloseArguments { - client_key: client_1_key.clone(), - }, - ); - // client reopens connection with gateway 2 - call_ws_open_for_client_key_and_gateway_with_panic(client_1_key, *GATEWAY_2); - // gateway 2 now has the open message - let res_gateway_2 = call_ws_get_messages( - GATEWAY_2.deref(), - CanisterWsGetMessagesArguments { nonce: 0 }, - ); - match res_gateway_2 { - CanisterWsGetMessagesResult::Ok(CanisterOutputCertifiedMessages { messages, .. }) => { - let first_message = &messages[0]; - assert_eq!(first_message.client_key, *client_1_key); - let open_message = get_service_message_content_from_canister_message(first_message); - match open_message { - WebsocketServiceMessageContent::OpenMessage(open_message) => { - assert_eq!(open_message.client_key, *client_1_key); - }, - _ => panic!("Expected OpenMessage"), - } - }, - _ => panic!("unexpected result"), - }; + // client disconnects, so gateway 1 closes the connection + call_ws_close_with_panic( + first_gateway, + CanisterWsCloseArguments { + client_key: client_key.clone(), + }, + ); + // client reopens connection with gateway 2 + call_ws_open_for_client_key_and_gateway_with_panic(&client_key, *second_gateway); + // gateway 2 now has the open message + let res_gateway_2 = call_ws_get_messages( + second_gateway, + CanisterWsGetMessagesArguments { nonce: 0 }, + ); + match res_gateway_2 { + CanisterWsGetMessagesResult::Ok(CanisterOutputCertifiedMessages { messages, .. }) => { + let first_message = &messages[0]; + prop_assert_eq!(&first_message.client_key, client_key); + let open_message = get_service_message_content_from_canister_message(first_message); + match open_message { + WebsocketServiceMessageContent::OpenMessage(open_message) => { + prop_assert_eq!(open_message.client_key, client_key.clone()); + }, + _ => panic!("Expected OpenMessage"), + } + }, + _ => panic!("unexpected result"), + }; - // simulate canister sending other messages to client - call_ws_send_with_panic( - &client_1_key.client_principal, - (0..10) - .map(|i| AppMessage { - text: format!("test{}", i + 10), - }) - .collect(), - ); - let res_gateway_2 = call_ws_get_messages( - GATEWAY_2.deref(), - CanisterWsGetMessagesArguments { nonce: 0 }, - ); - match res_gateway_2 { - CanisterWsGetMessagesResult::Ok(CanisterOutputCertifiedMessages { messages, .. }) => { - assert_eq!(messages.len() as u64, 10 + 1); // +1 for the open service message - }, - _ => panic!("unexpected result"), - }; + // simulate canister sending other messages to client + call_ws_send_with_panic( + &client_key.client_principal, + (0..10) + .map(|i| AppMessage { + text: format!("test{}", i + 10), + }) + .collect(), + ); + let res_gateway_2 = call_ws_get_messages( + second_gateway, + CanisterWsGetMessagesArguments { nonce: 0 }, + ); + match res_gateway_2 { + CanisterWsGetMessagesResult::Ok(CanisterOutputCertifiedMessages { messages, .. }) => { + prop_assert_eq!(messages.len() as u64, 10 + 1); // +1 for the open service message + }, + _ => panic!("unexpected result"), + }; + } } diff --git a/src/ic-websocket-cdk/src/tests/integration_tests/utils/clients.rs b/src/ic-websocket-cdk/src/tests/integration_tests/utils/clients.rs index 75739f8..0520d64 100644 --- a/src/ic-websocket-cdk/src/tests/integration_tests/utils/clients.rs +++ b/src/ic-websocket-cdk/src/tests/integration_tests/utils/clients.rs @@ -3,28 +3,24 @@ use candid::Principal; use lazy_static::lazy_static; lazy_static! { - pub static ref CLIENT_1: Principal = - Principal::from_text("pmisz-prtlk-b6oe6-bj4fl-6l5fy-h7c2h-so6i7-jiz2h-bgto7-piqfr-7ae") - .unwrap(); - pub static ref CLIENT_2: Principal = - Principal::from_text("zuh6g-qnmvg-vky2t-tnob7-h4xoj-ykrcx-jqjpi-cdf3k-23i3i-ykozs-fae") - .unwrap(); + pub(crate) static ref CLIENT_1_KEY: ClientKey = + generate_client_key("pmisz-prtlk-b6oe6-bj4fl-6l5fy-h7c2h-so6i7-jiz2h-bgto7-piqfr-7ae"); + pub(crate) static ref CLIENT_2_KEY: ClientKey = + generate_client_key("zuh6g-qnmvg-vky2t-tnob7-h4xoj-ykrcx-jqjpi-cdf3k-23i3i-ykozs-fae"); /// The gateway registered in the local PocketIc env - pub static ref GATEWAY_1: Principal = + pub(crate) static ref GATEWAY_1: Principal = Principal::from_text("i3gux-m3hwt-5mh2w-t7wwm-fwx5j-6z6ht-hxguo-t4rfw-qp24z-g5ivt-2qe") .unwrap(); - pub static ref GATEWAY_2: Principal = + pub(crate) static ref GATEWAY_2: Principal = Principal::from_text("trj6m-u7l6v-zilnb-2hl6a-3jfz3-asri5-mkw3k-e2tpo-5emmk-6hqxb-uae") .unwrap(); } -lazy_static! { - pub(crate) static ref CLIENT_1_KEY: ClientKey = generate_client_key(*CLIENT_1.deref()); - pub(crate) static ref CLIENT_2_KEY: ClientKey = generate_client_key(*CLIENT_2.deref()); -} - -fn generate_client_key(client_principal: Principal) -> ClientKey { - ClientKey::new(client_principal, generate_random_client_nonce()) +fn generate_client_key(client_principal_text: &str) -> ClientKey { + ClientKey::new( + Principal::from_text(client_principal_text).unwrap(), + generate_random_client_nonce(), + ) } pub fn generate_random_client_nonce() -> u64 { diff --git a/src/ic-websocket-cdk/src/tests/mod.rs b/src/ic-websocket-cdk/src/tests/mod.rs index f33905a..bc07667 100644 --- a/src/ic-websocket-cdk/src/tests/mod.rs +++ b/src/ic-websocket-cdk/src/tests/mod.rs @@ -1,4 +1,6 @@ #![cfg(test)] +mod common; + mod integration_tests; mod unit_tests; diff --git a/src/ic-websocket-cdk/src/tests/unit_tests/mod.rs b/src/ic-websocket-cdk/src/tests/unit_tests/mod.rs index cfced7a..ec0dfd8 100644 --- a/src/ic-websocket-cdk/src/tests/unit_tests/mod.rs +++ b/src/ic-websocket-cdk/src/tests/unit_tests/mod.rs @@ -1,3 +1,4 @@ +use super::common; use crate::*; use proptest::prelude::*; @@ -43,14 +44,14 @@ fn test_ws_handlers_are_called() { assert!(handlers.on_close.is_none()); handlers.call_on_open(OnOpenCallbackArgs { - client_principal: utils::generate_random_principal(), + client_principal: common::generate_random_principal(), }); handlers.call_on_message(OnMessageCallbackArgs { - client_principal: utils::generate_random_principal(), + client_principal: common::generate_random_principal(), message: vec![], }); handlers.call_on_close(OnCloseCallbackArgs { - client_principal: utils::generate_random_principal(), + client_principal: common::generate_random_principal(), }); // test that the handlers are not called if they are not initialized @@ -96,14 +97,14 @@ fn test_ws_handlers_are_called() { assert!(handlers.on_close.is_some()); handlers.call_on_open(OnOpenCallbackArgs { - client_principal: utils::generate_random_principal(), + client_principal: common::generate_random_principal(), }); handlers.call_on_message(OnMessageCallbackArgs { - client_principal: utils::generate_random_principal(), + client_principal: common::generate_random_principal(), message: vec![], }); handlers.call_on_close(OnCloseCallbackArgs { - client_principal: utils::generate_random_principal(), + client_principal: common::generate_random_principal(), }); // test that the handlers are called if they are initialized @@ -135,20 +136,20 @@ fn test_ws_handlers_panic_is_handled() { let res = panic::catch_unwind(|| { handlers.call_on_open(OnOpenCallbackArgs { - client_principal: utils::generate_random_principal(), + client_principal: common::generate_random_principal(), }); }); assert!(res.is_ok()); let res = panic::catch_unwind(|| { handlers.call_on_message(OnMessageCallbackArgs { - client_principal: utils::generate_random_principal(), + client_principal: common::generate_random_principal(), message: vec![], }); }); assert!(res.is_ok()); let res = panic::catch_unwind(|| { handlers.call_on_close(OnCloseCallbackArgs { - client_principal: utils::generate_random_principal(), + client_principal: common::generate_random_principal(), }); }); assert!(res.is_ok()); @@ -162,7 +163,7 @@ fn test_current_time() { proptest! { #[test] - fn test_initialize_registered_gateways(test_gateway_principal in any::().prop_map(|_| utils::generate_random_principal())) { + fn test_initialize_registered_gateways(test_gateway_principal in any::().prop_map(|_| common::generate_random_principal())) { initialize_registered_gateways(vec![test_gateway_principal.to_string()]); let map = REGISTERED_GATEWAYS.with(|map| map.borrow().clone()); @@ -184,7 +185,7 @@ proptest! { #[test] fn test_get_outgoing_message_nonce(test_nonce in any::()) { // Set up - let gateway_principal = utils::generate_random_principal(); + let gateway_principal = common::generate_random_principal(); REGISTERED_GATEWAYS.with(|n| n.borrow_mut().insert(gateway_principal, RegisteredGateway { outgoing_message_nonce: test_nonce, ..Default::default() })); let res = get_outgoing_message_nonce(&gateway_principal); @@ -192,7 +193,7 @@ proptest! { } #[test] - fn test_get_outgoing_message_nonce_nonexistent(test_gateway_principal in any::().prop_map(|_| utils::generate_random_principal())) { + fn test_get_outgoing_message_nonce_nonexistent(test_gateway_principal in any::().prop_map(|_| common::generate_random_principal())) { let res = get_outgoing_message_nonce(&test_gateway_principal); prop_assert_eq!(res.err(), Some(String::from(format!("no gateway registered with principal {test_gateway_principal}")))); } @@ -200,7 +201,7 @@ proptest! { #[test] fn test_increment_outgoing_message_nonce(test_nonce in any::()) { // Set up - let gateway_principal = utils::generate_random_principal(); + let gateway_principal = common::generate_random_principal(); REGISTERED_GATEWAYS.with(|n| n.borrow_mut().insert(gateway_principal, RegisteredGateway { outgoing_message_nonce: test_nonce, ..Default::default() })); increment_outgoing_message_nonce(&gateway_principal); @@ -209,7 +210,7 @@ proptest! { } #[test] - fn test_insert_client(test_client_key in any::().prop_map(|_| utils::get_random_client_key())) { + fn test_insert_client(test_client_key in any::().prop_map(|_| common::get_random_client_key())) { // Set up let registered_client = utils::generate_random_registered_client(); @@ -223,7 +224,7 @@ proptest! { } #[test] - fn test_get_registered_gateway(test_gateway_principal in any::().prop_map(|_| utils::generate_random_principal())) { + fn test_get_registered_gateway(test_gateway_principal in any::().prop_map(|_| common::generate_random_principal())) { // Set up REGISTERED_GATEWAYS.with(|n| n.borrow_mut().insert(test_gateway_principal, RegisteredGateway::new())); @@ -232,19 +233,19 @@ proptest! { } #[test] - fn test_get_registered_gateway_nonexistent(test_gateway_principal in any::().prop_map(|_| utils::generate_random_principal())) { + fn test_get_registered_gateway_nonexistent(test_gateway_principal in any::().prop_map(|_| common::generate_random_principal())) { let res = get_registered_gateway(&test_gateway_principal); prop_assert_eq!(res.err(), Some(String::from(format!("no gateway registered with principal {test_gateway_principal}")))); } #[test] - fn test_is_client_registered_empty(test_client_key in any::().prop_map(|_| utils::get_random_client_key())) { + fn test_is_client_registered_empty(test_client_key in any::().prop_map(|_| common::get_random_client_key())) { let actual_result = is_client_registered(&test_client_key); prop_assert_eq!(actual_result, false); } #[test] - fn test_is_client_registered(test_client_key in any::().prop_map(|_| utils::get_random_client_key())) { + fn test_is_client_registered(test_client_key in any::().prop_map(|_| common::get_random_client_key())) { // Set up REGISTERED_CLIENTS.with(|map| { map.borrow_mut().insert(test_client_key.clone(), utils::generate_random_registered_client()); @@ -255,7 +256,7 @@ proptest! { } #[test] - fn test_get_client_key_from_principal_empty(test_client_principal in any::().prop_map(|_| utils::generate_random_principal())) { + fn test_get_client_key_from_principal_empty(test_client_principal in any::().prop_map(|_| common::generate_random_principal())) { let actual_result = get_client_key_from_principal(&test_client_principal); prop_assert_eq!(actual_result.err(), Some(String::from(format!( "client with principal {} doesn't have an open connection", @@ -264,7 +265,7 @@ proptest! { } #[test] - fn test_get_client_key_from_principal(test_client_key in any::().prop_map(|_| utils::get_random_client_key())) { + fn test_get_client_key_from_principal(test_client_key in any::().prop_map(|_| common::get_random_client_key())) { // Set up CURRENT_CLIENT_KEY_MAP.with(|map| { map.borrow_mut().insert(test_client_key.client_principal, test_client_key.clone()); @@ -275,13 +276,13 @@ proptest! { } #[test] - fn test_check_registered_client_empty(test_client_key in any::().prop_map(|_| utils::get_random_client_key())) { + fn test_check_registered_client_empty(test_client_key in any::().prop_map(|_| common::get_random_client_key())) { let actual_result = check_registered_client(&test_client_key); prop_assert_eq!(actual_result.err(), Some(format!("client with key {} doesn't have an open connection", test_client_key))); } #[test] - fn test_check_registered_client(test_client_key in any::().prop_map(|_| utils::get_random_client_key())) { + fn test_check_registered_client(test_client_key in any::().prop_map(|_| common::get_random_client_key())) { // Set up REGISTERED_CLIENTS.with(|map| { map.borrow_mut().insert(test_client_key.clone(), utils::generate_random_registered_client()); @@ -289,13 +290,13 @@ proptest! { let actual_result = check_registered_client(&test_client_key); prop_assert!(actual_result.is_ok()); - let non_existing_client_key = utils::get_random_client_key(); + let non_existing_client_key = common::get_random_client_key(); let actual_result = check_registered_client(&non_existing_client_key); prop_assert_eq!(actual_result.err(), Some(format!("client with key {} doesn't have an open connection", non_existing_client_key))); } #[test] - fn test_init_outgoing_message_to_client_num(test_client_key in any::().prop_map(|_| utils::get_random_client_key())) { + fn test_init_outgoing_message_to_client_num(test_client_key in any::().prop_map(|_| common::get_random_client_key())) { init_outgoing_message_to_client_num(test_client_key.clone()); let actual_result = OUTGOING_MESSAGE_TO_CLIENT_NUM_MAP.with(|map| map.borrow().get(&test_client_key).unwrap().clone()); @@ -303,7 +304,7 @@ proptest! { } #[test] - fn test_increment_outgoing_message_to_client_num(test_client_key in any::().prop_map(|_| utils::get_random_client_key()), test_num in any::()) { + fn test_increment_outgoing_message_to_client_num(test_client_key in any::().prop_map(|_| common::get_random_client_key()), test_num in any::()) { // Set up OUTGOING_MESSAGE_TO_CLIENT_NUM_MAP.with(|map| { map.borrow_mut().insert(test_client_key.clone(), test_num); @@ -317,7 +318,7 @@ proptest! { } #[test] - fn test_get_outgoing_message_to_client_num(test_client_key in any::().prop_map(|_| utils::get_random_client_key()), test_num in any::()) { + fn test_get_outgoing_message_to_client_num(test_client_key in any::().prop_map(|_| common::get_random_client_key()), test_num in any::()) { // Set up OUTGOING_MESSAGE_TO_CLIENT_NUM_MAP.with(|map| { map.borrow_mut().insert(test_client_key.clone(), test_num); @@ -329,7 +330,7 @@ proptest! { } #[test] - fn test_init_expected_incoming_message_from_client_num(test_client_key in any::().prop_map(|_| utils::get_random_client_key())) { + fn test_init_expected_incoming_message_from_client_num(test_client_key in any::().prop_map(|_| common::get_random_client_key())) { init_expected_incoming_message_from_client_num(test_client_key.clone()); let actual_result = INCOMING_MESSAGE_FROM_CLIENT_NUM_MAP.with(|map| map.borrow().get(&test_client_key).unwrap().clone()); @@ -337,7 +338,7 @@ proptest! { } #[test] - fn test_get_expected_incoming_message_from_client_num(test_client_key in any::().prop_map(|_| utils::get_random_client_key()), test_num in any::()) { + fn test_get_expected_incoming_message_from_client_num(test_client_key in any::().prop_map(|_| common::get_random_client_key()), test_num in any::()) { // Set up INCOMING_MESSAGE_FROM_CLIENT_NUM_MAP.with(|map| { map.borrow_mut().insert(test_client_key.clone(), test_num); @@ -349,7 +350,7 @@ proptest! { } #[test] - fn test_increment_expected_incoming_message_from_client_num(test_client_key in any::().prop_map(|_| utils::get_random_client_key()), test_num in any::()) { + fn test_increment_expected_incoming_message_from_client_num(test_client_key in any::().prop_map(|_| common::get_random_client_key()), test_num in any::()) { // Set up INCOMING_MESSAGE_FROM_CLIENT_NUM_MAP.with(|map| { map.borrow_mut().insert(test_client_key.clone(), test_num); @@ -363,7 +364,7 @@ proptest! { } #[test] - fn test_add_client_to_wait_for_keep_alive(test_client_key in any::().prop_map(|_| utils::get_random_client_key())) { + fn test_add_client_to_wait_for_keep_alive(test_client_key in any::().prop_map(|_| common::get_random_client_key())) { add_client_to_wait_for_keep_alive(&test_client_key); let actual_result = CLIENTS_WAITING_FOR_KEEP_ALIVE.with(|map| map.borrow().get(&test_client_key).is_some()); @@ -371,7 +372,7 @@ proptest! { } #[test] - fn test_add_client(test_client_key in any::().prop_map(|_| utils::get_random_client_key())) { + fn test_add_client(test_client_key in any::().prop_map(|_| common::get_random_client_key())) { let registered_client = utils::generate_random_registered_client(); // Test @@ -391,7 +392,7 @@ proptest! { } #[test] - fn test_remove_client(test_client_key in any::().prop_map(|_| utils::get_random_client_key())) { + fn test_remove_client(test_client_key in any::().prop_map(|_| common::get_random_client_key())) { // Set up CURRENT_CLIENT_KEY_MAP.with(|map| { map.borrow_mut().insert(test_client_key.client_principal.clone(), test_client_key.clone()); @@ -422,7 +423,7 @@ proptest! { } #[test] - fn test_format_message_for_gateway_key(test_gateway_principal in any::().prop_map(|_| utils::generate_random_principal()), test_nonce in any::()) { + fn test_format_message_for_gateway_key(test_gateway_principal in any::().prop_map(|_| common::generate_random_principal()), test_nonce in any::()) { let actual_result = format_message_for_gateway_key(&test_gateway_principal, test_nonce); prop_assert_eq!(actual_result, test_gateway_principal.to_string() + "_" + &format!("{:0>20}", test_nonce.to_string())); } @@ -431,7 +432,7 @@ proptest! { fn test_get_messages_for_gateway_range_empty(messages_count in 0..1000u64) { // Set up utils::initialize_params(); - let gateway_principal = utils::generate_random_principal(); + let gateway_principal = common::generate_random_principal(); REGISTERED_GATEWAYS.with(|n| n.borrow_mut().insert(gateway_principal, RegisteredGateway::new())); // Test @@ -444,13 +445,13 @@ proptest! { } #[test] - fn test_get_messages_for_gateway_range_smaller_than_max(gateway_principal in any::().prop_map(|_| utils::get_static_principal())) { + fn test_get_messages_for_gateway_range_smaller_than_max(gateway_principal in any::().prop_map(|_| common::get_static_principal())) { // Set up utils::initialize_params(); REGISTERED_GATEWAYS.with(|n| n.borrow_mut().insert(gateway_principal, RegisteredGateway::new())); let messages_count = 4; - let test_client_key = utils::get_random_client_key(); + let test_client_key = common::get_random_client_key(); utils::add_messages_for_gateway(test_client_key, &gateway_principal, messages_count); // Test @@ -467,7 +468,7 @@ proptest! { } #[test] - fn test_get_messages_for_gateway_range_larger_than_max(gateway_principal in any::().prop_map(|_| utils::get_static_principal()), max_number_of_returned_messages in 0..1000usize) { + fn test_get_messages_for_gateway_range_larger_than_max(gateway_principal in any::().prop_map(|_| common::get_static_principal()), max_number_of_returned_messages in 0..1000usize) { // Set up set_params(WsInitParams { max_number_of_returned_messages, @@ -476,7 +477,7 @@ proptest! { REGISTERED_GATEWAYS.with(|n| n.borrow_mut().insert(gateway_principal, RegisteredGateway::new())); let messages_count: u64 = (2 * max_number_of_returned_messages).try_into().unwrap(); - let test_client_key = utils::get_random_client_key(); + let test_client_key = common::get_random_client_key(); utils::add_messages_for_gateway(test_client_key, &gateway_principal, messages_count); // Test @@ -498,7 +499,7 @@ proptest! { } #[test] - fn test_get_messages_for_gateway_initial_nonce(gateway_principal in any::().prop_map(|_| utils::get_static_principal()), messages_count in 0..100u64, max_number_of_returned_messages in 0..1000usize) { + fn test_get_messages_for_gateway_initial_nonce(gateway_principal in any::().prop_map(|_| common::get_static_principal()), messages_count in 0..100u64, max_number_of_returned_messages in 0..1000usize) { // Set up set_params(WsInitParams { max_number_of_returned_messages, @@ -506,7 +507,7 @@ proptest! { }); REGISTERED_GATEWAYS.with(|n| n.borrow_mut().insert(gateway_principal, RegisteredGateway::new())); - let test_client_key = utils::get_random_client_key(); + let test_client_key = common::get_random_client_key(); utils::add_messages_for_gateway(test_client_key, &gateway_principal, messages_count); // Test @@ -524,11 +525,11 @@ proptest! { } #[test] - fn test_get_messages_for_gateway(gateway_principal in any::().prop_map(|_| utils::get_static_principal()), messages_count in 0..100u64) { + fn test_get_messages_for_gateway(gateway_principal in any::().prop_map(|_| common::get_static_principal()), messages_count in 0..100u64) { // Set up REGISTERED_GATEWAYS.with(|n| n.borrow_mut().insert(gateway_principal, RegisteredGateway::new())); - let test_client_key = utils::get_random_client_key(); + let test_client_key = common::get_random_client_key(); utils::add_messages_for_gateway(test_client_key, &gateway_principal, messages_count); // Test @@ -549,14 +550,14 @@ proptest! { } #[test] - fn test_check_is_registered_gateway(test_gateway_principal in any::().prop_map(|_| utils::generate_random_principal())) { + fn test_check_is_registered_gateway(test_gateway_principal in any::().prop_map(|_| common::generate_random_principal())) { // Set up REGISTERED_GATEWAYS.with(|n| n.borrow_mut().insert(test_gateway_principal, RegisteredGateway::new())); let actual_result = check_is_registered_gateway(&test_gateway_principal); prop_assert!(actual_result.is_ok()); - let other_principal = utils::generate_random_principal(); + let other_principal = common::generate_random_principal(); let actual_result = check_is_registered_gateway(&other_principal); prop_assert_eq!(actual_result.err(), Some(String::from("principal is not one of the authorized gateways that have been registered during CDK initialization"))); } @@ -565,7 +566,7 @@ proptest! { fn test_serialize_websocket_message(test_msg_bytes in any::>(), test_sequence_num in any::(), test_timestamp in any::()) { // TODO: add more tests, in which we check the serialized message let websocket_message = WebsocketMessage { - client_key: utils::get_random_client_key(), + client_key: common::get_random_client_key(), sequence_num: test_sequence_num, timestamp: test_timestamp, is_service_message: false, diff --git a/src/ic-websocket-cdk/src/tests/unit_tests/utils.rs b/src/ic-websocket-cdk/src/tests/unit_tests/utils.rs index 0966412..38e593d 100644 --- a/src/ic-websocket-cdk/src/tests/unit_tests/utils.rs +++ b/src/ic-websocket-cdk/src/tests/unit_tests/utils.rs @@ -1,43 +1,14 @@ use candid::Principal; -use ic_agent::{identity::BasicIdentity, Identity}; -use ring::signature::Ed25519KeyPair; use crate::{ format_message_for_gateway_key, set_params, CanisterOutputMessage, ClientKey, GatewayPrincipal, RegisteredClient, WsInitParams, REGISTERED_GATEWAYS, }; -fn generate_random_key_pair() -> Ed25519KeyPair { - let rng = ring::rand::SystemRandom::new(); - let key_pair = Ed25519KeyPair::generate_pkcs8(&rng).expect("Could not generate a key pair."); - Ed25519KeyPair::from_pkcs8(key_pair.as_ref()).expect("Could not read the key pair.") -} - -pub fn generate_random_principal() -> Principal { - let key_pair = generate_random_key_pair(); - let identity = BasicIdentity::from_key_pair(key_pair); - - // workaround to keep the principal in the version of candid used by the canister - Principal::from_text(identity.sender().unwrap().to_text()).unwrap() -} - pub(super) fn generate_random_registered_client() -> RegisteredClient { RegisteredClient::new(Principal::anonymous()) } -pub fn get_static_principal() -> Principal { - Principal::from_text("wnkwv-wdqb5-7wlzr-azfpw-5e5n5-dyxrf-uug7x-qxb55-mkmpa-5jqik-tqe").unwrap() - // a random static but valid principal -} - -pub(super) fn get_random_client_key() -> ClientKey { - ClientKey::new( - generate_random_principal(), - // a random nonce - rand::random(), - ) -} - pub(super) fn add_messages_for_gateway( client_key: ClientKey, gateway_principal: &GatewayPrincipal, From 3984056fcb8fc687f49708df41f7b1f0ab0f8e64 Mon Sep 17 00:00:00 2001 From: Luca8991 Date: Wed, 15 Nov 2023 15:51:57 +0100 Subject: [PATCH 20/27] perf: functions visibility --- src/ic-websocket-cdk/src/tests/common.rs | 4 ++-- .../integration_tests/c_ws_get_messages.rs | 2 +- .../f_messages_acknowledgement.rs | 2 +- .../tests/integration_tests/utils/actor.rs | 6 +++-- .../tests/integration_tests/utils/clients.rs | 8 +++---- .../tests/integration_tests/utils/messages.rs | 22 ++++++++++++------- 6 files changed, 26 insertions(+), 18 deletions(-) diff --git a/src/ic-websocket-cdk/src/tests/common.rs b/src/ic-websocket-cdk/src/tests/common.rs index c73698e..1ed5ca6 100644 --- a/src/ic-websocket-cdk/src/tests/common.rs +++ b/src/ic-websocket-cdk/src/tests/common.rs @@ -10,7 +10,7 @@ fn generate_random_key_pair() -> Ed25519KeyPair { Ed25519KeyPair::from_pkcs8(key_pair.as_ref()).expect("Could not read the key pair.") } -pub fn generate_random_principal() -> Principal { +pub(super) fn generate_random_principal() -> Principal { let key_pair = generate_random_key_pair(); let identity = BasicIdentity::from_key_pair(key_pair); @@ -18,7 +18,7 @@ pub fn generate_random_principal() -> Principal { Principal::from_text(identity.sender().unwrap().to_text()).unwrap() } -pub fn get_static_principal() -> Principal { +pub(super) fn get_static_principal() -> Principal { Principal::from_text("wnkwv-wdqb5-7wlzr-azfpw-5e5n5-dyxrf-uug7x-qxb55-mkmpa-5jqik-tqe").unwrap() // a random static but valid principal } diff --git a/src/ic-websocket-cdk/src/tests/integration_tests/c_ws_get_messages.rs b/src/ic-websocket-cdk/src/tests/integration_tests/c_ws_get_messages.rs index 731efa1..512a9c8 100644 --- a/src/ic-websocket-cdk/src/tests/integration_tests/c_ws_get_messages.rs +++ b/src/ic-websocket-cdk/src/tests/integration_tests/c_ws_get_messages.rs @@ -247,7 +247,7 @@ mod helpers { }; use candid::decode_one; - pub(crate) fn verify_messages( + pub(super) fn verify_messages( messages: &Vec, client_key: &ClientKey, cert: &[u8], diff --git a/src/ic-websocket-cdk/src/tests/integration_tests/f_messages_acknowledgement.rs b/src/ic-websocket-cdk/src/tests/integration_tests/f_messages_acknowledgement.rs index 58ad36b..34b5d7c 100644 --- a/src/ic-websocket-cdk/src/tests/integration_tests/f_messages_acknowledgement.rs +++ b/src/ic-websocket-cdk/src/tests/integration_tests/f_messages_acknowledgement.rs @@ -208,7 +208,7 @@ mod helpers { CanisterWsGetMessagesResult, ClientKey, WebsocketServiceMessageContent, }; - pub(crate) fn check_ack_message_result( + pub(super) fn check_ack_message_result( res: &CanisterWsGetMessagesResult, receiver_client_key: &ClientKey, expected_ack_sequence_number: u64, diff --git a/src/ic-websocket-cdk/src/tests/integration_tests/utils/actor.rs b/src/ic-websocket-cdk/src/tests/integration_tests/utils/actor.rs index da13b9f..45499c3 100644 --- a/src/ic-websocket-cdk/src/tests/integration_tests/utils/actor.rs +++ b/src/ic-websocket-cdk/src/tests/integration_tests/utils/actor.rs @@ -37,7 +37,9 @@ pub mod ws_open { } /// See [call_ws_open_with_panic]. - pub(crate) fn call_ws_open_for_client_key_with_panic(client_key: &ClientKey) { + pub(in crate::tests::integration_tests) fn call_ws_open_for_client_key_with_panic( + client_key: &ClientKey, + ) { let args = CanisterWsOpenArguments { client_nonce: client_key.client_nonce, gateway_principal: GATEWAY_1.deref().to_owned(), @@ -46,7 +48,7 @@ pub mod ws_open { } /// See [call_ws_open_with_panic]. - pub(crate) fn call_ws_open_for_client_key_and_gateway_with_panic( + pub(in crate::tests::integration_tests) fn call_ws_open_for_client_key_and_gateway_with_panic( client_key: &ClientKey, gateway_principal: Principal, ) { diff --git a/src/ic-websocket-cdk/src/tests/integration_tests/utils/clients.rs b/src/ic-websocket-cdk/src/tests/integration_tests/utils/clients.rs index 0520d64..0352e61 100644 --- a/src/ic-websocket-cdk/src/tests/integration_tests/utils/clients.rs +++ b/src/ic-websocket-cdk/src/tests/integration_tests/utils/clients.rs @@ -3,15 +3,15 @@ use candid::Principal; use lazy_static::lazy_static; lazy_static! { - pub(crate) static ref CLIENT_1_KEY: ClientKey = + pub(in crate::tests::integration_tests) static ref CLIENT_1_KEY: ClientKey = generate_client_key("pmisz-prtlk-b6oe6-bj4fl-6l5fy-h7c2h-so6i7-jiz2h-bgto7-piqfr-7ae"); - pub(crate) static ref CLIENT_2_KEY: ClientKey = + pub(in crate::tests::integration_tests) static ref CLIENT_2_KEY: ClientKey = generate_client_key("zuh6g-qnmvg-vky2t-tnob7-h4xoj-ykrcx-jqjpi-cdf3k-23i3i-ykozs-fae"); /// The gateway registered in the local PocketIc env - pub(crate) static ref GATEWAY_1: Principal = + pub(in crate::tests::integration_tests) static ref GATEWAY_1: Principal = Principal::from_text("i3gux-m3hwt-5mh2w-t7wwm-fwx5j-6z6ht-hxguo-t4rfw-qp24z-g5ivt-2qe") .unwrap(); - pub(crate) static ref GATEWAY_2: Principal = + pub(in crate::tests::integration_tests) static ref GATEWAY_2: Principal = Principal::from_text("trj6m-u7l6v-zilnb-2hl6a-3jfz3-asri5-mkw3k-e2tpo-5emmk-6hqxb-uae") .unwrap(); } diff --git a/src/ic-websocket-cdk/src/tests/integration_tests/utils/messages.rs b/src/ic-websocket-cdk/src/tests/integration_tests/utils/messages.rs index 5d922c7..262a99b 100644 --- a/src/ic-websocket-cdk/src/tests/integration_tests/utils/messages.rs +++ b/src/ic-websocket-cdk/src/tests/integration_tests/utils/messages.rs @@ -3,32 +3,32 @@ use candid::{decode_one, encode_one}; use super::get_current_timestamp_ns; use crate::{CanisterOutputMessage, ClientKey, WebsocketMessage, WebsocketServiceMessageContent}; -pub(crate) fn get_websocket_message_from_canister_message( +pub(in crate::tests::integration_tests) fn get_websocket_message_from_canister_message( msg: &CanisterOutputMessage, ) -> WebsocketMessage { decode_websocket_message(&msg.content) } -pub(crate) fn encode_websocket_service_message_content( +pub(in crate::tests::integration_tests) fn encode_websocket_service_message_content( content: &WebsocketServiceMessageContent, ) -> Vec { encode_one(content).unwrap() } -pub(crate) fn decode_websocket_service_message_content( +pub(in crate::tests::integration_tests) fn decode_websocket_service_message_content( bytes: &[u8], ) -> WebsocketServiceMessageContent { decode_one(bytes).unwrap() } -pub(crate) fn get_service_message_content_from_canister_message( +pub(in crate::tests::integration_tests) fn get_service_message_content_from_canister_message( msg: &CanisterOutputMessage, ) -> WebsocketServiceMessageContent { let websocket_message = get_websocket_message_from_canister_message(msg); decode_websocket_service_message_content(&websocket_message.content) } -pub(crate) fn create_websocket_message( +pub(in crate::tests::integration_tests) fn create_websocket_message( client_key: &ClientKey, sequence_number: u64, content: Option>, @@ -45,14 +45,20 @@ pub(crate) fn create_websocket_message( } } -pub(crate) fn decode_websocket_message(bytes: &[u8]) -> WebsocketMessage { +pub(in crate::tests::integration_tests) fn decode_websocket_message( + bytes: &[u8], +) -> WebsocketMessage { serde_cbor::from_slice(bytes).unwrap() } -pub fn get_polling_nonce_from_message(message: &CanisterOutputMessage) -> u64 { +pub(in crate::tests::integration_tests) fn get_polling_nonce_from_message( + message: &CanisterOutputMessage, +) -> u64 { message.key.split("_").last().unwrap().parse().unwrap() } -pub fn get_next_polling_nonce_from_messages(messages: Vec) -> u64 { +pub(in crate::tests::integration_tests) fn get_next_polling_nonce_from_messages( + messages: Vec, +) -> u64 { get_polling_nonce_from_message(messages.last().unwrap()) + 1 } From 426d5d070654f5ad73269ce5254c329051b9db9b Mon Sep 17 00:00:00 2001 From: Luca8991 Date: Wed, 15 Nov 2023 17:23:27 +0100 Subject: [PATCH 21/27] feat: load test canister wasm also from path just set TEST_CANISTER_WASM_PATH env variable --- .../src/tests/integration_tests/utils/mod.rs | 2 -- .../src/tests/integration_tests/utils/test_env.rs | 8 ++++++-- .../src/tests/integration_tests/utils/wasm.rs | 15 ++++++++++----- 3 files changed, 16 insertions(+), 9 deletions(-) diff --git a/src/ic-websocket-cdk/src/tests/integration_tests/utils/mod.rs b/src/ic-websocket-cdk/src/tests/integration_tests/utils/mod.rs index bb4486f..f648465 100644 --- a/src/ic-websocket-cdk/src/tests/integration_tests/utils/mod.rs +++ b/src/ic-websocket-cdk/src/tests/integration_tests/utils/mod.rs @@ -18,8 +18,6 @@ pub fn bin_folder_path() -> PathBuf { file_path.pop(); file_path.push("bin"); file_path - // println!("{:?}", std::env::current_dir()); - // PathBuf::from("./bin") } /// Returns the current timestamp in nanoseconds. diff --git a/src/ic-websocket-cdk/src/tests/integration_tests/utils/test_env.rs b/src/ic-websocket-cdk/src/tests/integration_tests/utils/test_env.rs index bf42312..2603fe8 100644 --- a/src/ic-websocket-cdk/src/tests/integration_tests/utils/test_env.rs +++ b/src/ic-websocket-cdk/src/tests/integration_tests/utils/test_env.rs @@ -1,4 +1,5 @@ use std::{ + path::PathBuf, sync::{Mutex, MutexGuard}, time::{Duration, SystemTime}, }; @@ -13,7 +14,7 @@ use super::{ DEFAULT_TEST_KEEP_ALIVE_TIMEOUT_MS, DEFAULT_TEST_MAX_NUMBER_OF_RETURNED_MESSAGES, DEFAULT_TEST_SEND_ACK_INTERVAL_MS, }, - wasm::load_canister_wasm_from_bin, + wasm::{load_canister_wasm_from_bin, load_canister_wasm_from_path}, }; lazy_static! { @@ -46,7 +47,10 @@ impl TestEnv { let canister_id = pic.create_canister(None); pic.add_cycles(canister_id, 1_000_000_000_000_000); - let wasm_bytes = load_canister_wasm_from_bin("test_canister.wasm"); + let wasm_bytes = match std::env::var("TEST_CANISTER_WASM_PATH") { + Ok(path) => load_canister_wasm_from_path(&PathBuf::from(path)), + Err(_) => load_canister_wasm_from_bin("test_canister.wasm"), + }; let authorized_gateways = vec![GATEWAY_1.to_string()]; let arguments: CanisterInitArgs = ( diff --git a/src/ic-websocket-cdk/src/tests/integration_tests/utils/wasm.rs b/src/ic-websocket-cdk/src/tests/integration_tests/utils/wasm.rs index b078155..61cf4fa 100644 --- a/src/ic-websocket-cdk/src/tests/integration_tests/utils/wasm.rs +++ b/src/ic-websocket-cdk/src/tests/integration_tests/utils/wasm.rs @@ -1,15 +1,20 @@ use std::fs::File; use std::io::Read; +use std::path::PathBuf; use super::bin_folder_path; +pub fn load_canister_wasm_from_path(path: &PathBuf) -> Vec { + let mut file = File::open(&path) + .unwrap_or_else(|_| panic!("Failed to open file: {}", path.to_str().unwrap())); + let mut bytes = Vec::new(); + file.read_to_end(&mut bytes).expect("Failed to read file"); + bytes +} + pub fn load_canister_wasm_from_bin(wasm_name: &str) -> Vec { let mut file_path = bin_folder_path(); file_path.push(wasm_name); - let mut file = File::open(&file_path) - .unwrap_or_else(|_| panic!("Failed to open file: {}", file_path.to_str().unwrap())); - let mut bytes = Vec::new(); - file.read_to_end(&mut bytes).expect("Failed to read file"); - bytes + load_canister_wasm_from_path(&file_path) } From 4cb928c9a2f05d0f0652ddc505c86ee75c69147f Mon Sep 17 00:00:00 2001 From: Luca8991 Date: Wed, 15 Nov 2023 18:19:02 +0100 Subject: [PATCH 22/27] chore: update to checkout@v4 action --- .github/workflows/release.yml | 6 +++--- .github/workflows/tests.yml | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 128cb93..df0378b 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -18,7 +18,7 @@ jobs: outputs: version: ${{ steps.cargo-publish.outputs.version }} steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - uses: actions-rs/toolchain@v1 with: toolchain: "stable" @@ -41,7 +41,7 @@ jobs: version: ${{ steps.tag_version.outputs.new_tag }} steps: - name: Checkout - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Bump version and push tag id: tag_version uses: mathieudutour/github-tag-action@v6.1 @@ -54,7 +54,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Release uses: softprops/action-gh-release@v1 with: diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 0bb86a5..9a01470 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -10,7 +10,7 @@ jobs: unit-and-integration-tests: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - uses: actions-rs/toolchain@v1 with: toolchain: "stable" From a2d952e32c32661e0c004b919eb3c4a9665efb0b Mon Sep 17 00:00:00 2001 From: Luca8991 Date: Wed, 15 Nov 2023 21:54:56 +0100 Subject: [PATCH 23/27] perf: unneeded methods --- src/ic-websocket-cdk/src/lib.rs | 7 +++---- src/test_canister/src/lib.rs | 22 ---------------------- src/test_canister/test_canister.did | 2 -- 3 files changed, 3 insertions(+), 28 deletions(-) diff --git a/src/ic-websocket-cdk/src/lib.rs b/src/ic-websocket-cdk/src/lib.rs index 98a3e7b..77fec0a 100644 --- a/src/ic-websocket-cdk/src/lib.rs +++ b/src/ic-websocket-cdk/src/lib.rs @@ -770,7 +770,7 @@ fn check_keep_alive_timer_callback(keep_alive_timeout_ms: u64) { fn handle_keep_alive_client_message( client_key: &ClientKey, _keep_alive_message: ClientKeepAliveMessageContent, -) -> Result<(), String> { +) { // TODO: delete messages from the queue that have been acknowledged by the client // update the last keep alive timestamp for the client @@ -781,8 +781,6 @@ fn handle_keep_alive_client_message( { client_metadata.update_last_keep_alive_timestamp(); } - - Ok(()) } /// Internal function used to put the messages in the outgoing messages queue and certify them. @@ -851,7 +849,8 @@ fn handle_received_service_message( Err(String::from("Invalid received service message")) }, WebsocketServiceMessageContent::KeepAliveMessage(keep_alive_message) => { - handle_keep_alive_client_message(client_key, keep_alive_message) + handle_keep_alive_client_message(client_key, keep_alive_message); + Ok(()) }, } } diff --git a/src/test_canister/src/lib.rs b/src/test_canister/src/lib.rs index d105d5a..4ede2aa 100644 --- a/src/test_canister/src/lib.rs +++ b/src/test_canister/src/lib.rs @@ -77,12 +77,6 @@ fn ws_get_messages(args: CanisterWsGetMessagesArguments) -> CanisterWsGetMessage } //// Debug/tests methods -// wipe all websocket data in the canister -#[update] -fn ws_wipe() { - ic_websocket_cdk::wipe(); -} - // send a message to the client, usually called by the canister itself #[update] fn ws_send(client_principal: ClientPrincipal, messages: Vec>) -> CanisterWsSendResult { @@ -94,19 +88,3 @@ fn ws_send(client_principal: ClientPrincipal, messages: Vec>) -> Caniste } Ok(()) } - -// initialize the CDK again -#[update] -fn initialize( - gateway_principals: Vec, - max_number_of_returned_messages: usize, - send_ack_interval_ms: u64, - keep_alive_delay_ms: u64, -) { - init( - gateway_principals, - max_number_of_returned_messages, - send_ack_interval_ms, - keep_alive_delay_ms, - ); -} diff --git a/src/test_canister/test_canister.did b/src/test_canister/test_canister.did index 42b7f6e..11f9c11 100644 --- a/src/test_canister/test_canister.did +++ b/src/test_canister/test_canister.did @@ -16,7 +16,5 @@ service : (text, nat64, nat64, nat64) -> { "ws_get_messages" : (CanisterWsGetMessagesArguments) -> (CanisterWsGetMessagesResult) query; // methods used just for debugging/testing - "ws_wipe" : () -> (); "ws_send" : (ClientPrincipal, vec blob) -> (CanisterWsSendResult); - "initialize" : (text, nat64, nat64, nat64) -> (); }; From 4e573e3a8245c05bd5c0cf30f468d0c9cdc9c508 Mon Sep 17 00:00:00 2001 From: Luca8991 Date: Wed, 15 Nov 2023 21:55:11 +0100 Subject: [PATCH 24/27] fix: send a valid but wrong candid --- .../src/tests/integration_tests/b_ws_message.rs | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/ic-websocket-cdk/src/tests/integration_tests/b_ws_message.rs b/src/ic-websocket-cdk/src/tests/integration_tests/b_ws_message.rs index 36d8887..9a89a2e 100644 --- a/src/ic-websocket-cdk/src/tests/integration_tests/b_ws_message.rs +++ b/src/ic-websocket-cdk/src/tests/integration_tests/b_ws_message.rs @@ -1,5 +1,7 @@ use std::ops::Deref; +use candid::encode_one; + use crate::{ tests::common::generate_random_principal, CanisterAckMessageContent, CanisterWsMessageArguments, CanisterWsMessageResult, ClientKeepAliveMessageContent, ClientKey, @@ -141,7 +143,12 @@ fn test_5_fails_if_client_sends_a_wrong_service_message() { let res = call_ws_message( &client_1_key.client_principal, CanisterWsMessageArguments { - msg: create_websocket_message(client_1_key, 1, Some(vec![1, 2, 3]), true), + msg: create_websocket_message( + client_1_key, + 1, + Some(encode_one(vec![1, 2, 3]).unwrap()), + true, + ), }, ); match res { From 3f024529c53c52f540a582c2d07bf91d64d090b1 Mon Sep 17 00:00:00 2001 From: Luca8991 Date: Thu, 16 Nov 2023 08:26:37 +0100 Subject: [PATCH 25/27] chore: enable tests on pull request --- .github/workflows/tests.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 9a01470..44fdebc 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -5,6 +5,7 @@ on: push: branches: - main + pull_request: jobs: unit-and-integration-tests: From 7e12d38feda916a393128e9792e319544729ca79 Mon Sep 17 00:00:00 2001 From: Luca8991 Date: Fri, 17 Nov 2023 21:10:32 +0100 Subject: [PATCH 26/27] chore: rename test script --- .github/workflows/tests.yml | 1 + README.md | 2 +- scripts/{test_canister.sh => test.sh} | 0 3 files changed, 2 insertions(+), 1 deletion(-) rename scripts/{test_canister.sh => test.sh} (100%) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 44fdebc..6423064 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -22,6 +22,7 @@ jobs: run: cargo test --package ic-websocket-cdk --lib -- tests::unit_tests - name: Run doc tests run: cargo test --package ic-websocket-cdk --doc + # the following steps are replicated in the scripts/test.sh file - name: Prepare environment for integration tests run: | rustup target add wasm32-unknown-unknown diff --git a/README.md b/README.md index 839cb84..5527bc2 100644 --- a/README.md +++ b/README.md @@ -51,7 +51,7 @@ There are two types of tests available: There's a script that runs all the tests together, taking care of setting up the environment (Linux only!) and deploying the canister. To run the script, execute the following command: ```bash -./scripts/test_canister.sh +./scripts/test.sh ``` > If you're on **macOS**, you have to manually download the PocketIC binary ([guide](https://github.com/dfinity/pocketic#download)) and place it in the [bin](./bin/) folder. diff --git a/scripts/test_canister.sh b/scripts/test.sh similarity index 100% rename from scripts/test_canister.sh rename to scripts/test.sh From 969d7677156b5670035e7a979024a6cdb22a2622 Mon Sep 17 00:00:00 2001 From: Luca8991 Date: Fri, 17 Nov 2023 21:10:46 +0100 Subject: [PATCH 27/27] style: typo --- src/ic-websocket-cdk/src/lib.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ic-websocket-cdk/src/lib.rs b/src/ic-websocket-cdk/src/lib.rs index 77fec0a..ee5326c 100644 --- a/src/ic-websocket-cdk/src/lib.rs +++ b/src/ic-websocket-cdk/src/lib.rs @@ -1013,7 +1013,7 @@ impl Default for WsInitParams { } } -/// Initialize the CDK by setting the callback handlers and the **principal** of the WS Gateway that +/// Initialize the CDK by setting the callback handlers and the **principals** of the WS Gateways that /// will be polling the canister. /// /// **Note**: Resets the timers under the hood.