Skip to content

Commit

Permalink
Merge branch 'tests' into feat/authenticated-messages
Browse files Browse the repository at this point in the history
  • Loading branch information
ilbertt committed Oct 9, 2023
2 parents 18a8920 + 5d6c31f commit bd0d597
Show file tree
Hide file tree
Showing 14 changed files with 809 additions and 1,068 deletions.
275 changes: 161 additions & 114 deletions Cargo.lock

Large diffs are not rendered by default.

74 changes: 66 additions & 8 deletions src/ic-websocket-cdk/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use candid::{encode_one, CandidType, Principal};
use candid::{decode_one, encode_one, CandidType, Principal};
#[cfg(not(test))]
use ic_cdk::api::time;
use ic_cdk::api::{caller, data_certificate, set_certified_data};
Expand Down Expand Up @@ -374,7 +374,6 @@ fn get_messages_for_gateway_range(gateway_principal: Principal, nonce: u64) -> (
MESSAGES_FOR_GATEWAY.with(|m| {
let queue_len = m.borrow().len();

// TODO: test
if nonce == 0 && queue_len > 0 {
// this is the case in which the poller on the gateway restarted
// the range to return is end:last index and start: max(end - MAX_NUMBER_OF_RETURNED_MESSAGES, 0)
Expand Down Expand Up @@ -490,7 +489,7 @@ struct ClientKeepAliveMessageContent {
last_incoming_sequence_num: u64,
}

/// A service message sent by the CDK to the client.
/// A service message sent by the CDK to the client or vice versa.
#[derive(CandidType, Deserialize)]
enum WebsocketServiceMessageContent {
/// Message sent by the **canister** when a client opens a connection.
Expand All @@ -501,6 +500,16 @@ enum WebsocketServiceMessageContent {
KeepAliveMessage(ClientKeepAliveMessageContent),
}

impl WebsocketServiceMessageContent {
fn from_candid_bytes(bytes: Vec<u8>) -> Result<Self, String> {
decode_one(&bytes).map_err(|e| {
let mut err = String::from("Error decoding service message content: ");
err.push_str(&e.to_string());
err
})
}
}

fn send_service_message_to_client(
client_key: &ClientKey,
message: WebsocketServiceMessageContent,
Expand Down Expand Up @@ -558,6 +567,20 @@ fn _ws_send(
Ok(())
}

fn handle_received_service_message(content: Vec<u8>) -> CanisterWsMessageResult {
let decoded = WebsocketServiceMessageContent::from_candid_bytes(content)?;
match decoded {
WebsocketServiceMessageContent::OpenMessage(_)
| WebsocketServiceMessageContent::AckMessage(_) => {
Err(String::from("Invalid received service message"))
},
WebsocketServiceMessageContent::KeepAliveMessage(_) => {
custom_print!("Service message handling not implemented yet");
Ok(())
},
}
}

/// Arguments passed to the `on_open` handler.
pub struct OnOpenCallbackArgs {
pub client_principal: ClientPrincipal,
Expand Down Expand Up @@ -595,7 +618,6 @@ pub struct WsHandlers {
impl WsHandlers {
fn call_on_open(&self, args: OnOpenCallbackArgs) {
if let Some(on_open) = self.on_open {
// TODO: test the panic handling
let res = panic::catch_unwind(|| {
on_open(args);
});
Expand All @@ -608,7 +630,6 @@ impl WsHandlers {

fn call_on_message(&self, args: OnMessageCallbackArgs) {
if let Some(on_message) = self.on_message {
// TODO: test the panic handling
let res = panic::catch_unwind(|| {
on_message(args);
});
Expand All @@ -621,7 +642,6 @@ impl WsHandlers {

fn call_on_close(&self, args: OnCloseCallbackArgs) {
if let Some(on_close) = self.on_close {
// TODO: test the panic handling
let res = panic::catch_unwind(|| {
on_close(args);
});
Expand Down Expand Up @@ -773,9 +793,8 @@ pub fn ws_message(args: CanisterWsMessageArguments) -> CanisterWsMessageResult {
// increase the expected sequence number by 1
increment_expected_incoming_message_from_client_num(&client_key)?;

// TODO: test
if is_service_message {
custom_print!("Service message handling not implemented yet");
return handle_received_service_message(content);
}

// call the on_message handler initialized in init()
Expand Down Expand Up @@ -990,6 +1009,45 @@ mod test {
assert!(CUSTOM_STATE.with(|h| h.borrow().is_on_close_called));
}

#[test]
fn test_ws_handlers_panic_is_handled() {
let handlers = WsHandlers {
on_open: Some(|_| {
panic!("on_open_panic");
}),
on_message: Some(|_| {
panic!("on_close_panic");
}),
on_close: Some(|_| {
panic!("on_close_panic");
}),
};

initialize_handlers(handlers);

let handlers = HANDLERS.with(|h| h.borrow().clone());

let res = panic::catch_unwind(|| {
handlers.call_on_open(OnOpenCallbackArgs {
client_principal: test_utils::generate_random_principal(),
});
});
assert!(res.is_ok());
let res = panic::catch_unwind(|| {
handlers.call_on_message(OnMessageCallbackArgs {
client_principal: test_utils::generate_random_principal(),
message: vec![],
});
});
assert!(res.is_ok());
let res = panic::catch_unwind(|| {
handlers.call_on_close(OnCloseCallbackArgs {
client_principal: test_utils::generate_random_principal(),
});
});
assert!(res.is_ok());
}

#[test]
fn test_current_time() {
// test
Expand Down
Loading

0 comments on commit bd0d597

Please sign in to comment.