Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

release/0.3.2 #5

Merged
merged 18 commits into from
Dec 15, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions src/ic-websocket-cdk/service_messages.did
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,20 @@ type ClientKeepAliveMessageContent = record {
last_incoming_sequence_num : nat64;
};

type CloseMessageReason = variant {
WrongSequenceNumber;
InvalidServiceMessage;
KeepAliveTimeout;
ClosedByApplication;
};

type CanisterCloseMessageContent = record {
reason : CloseMessageReason;
};

type WebsocketServiceMessageContent = variant {
OpenMessage : CanisterOpenMessageContent;
AckMessage : CanisterAckMessageContent;
KeepAliveMessage : ClientKeepAliveMessageContent;
CloseMessage : CanisterCloseMessageContent;
};
6 changes: 3 additions & 3 deletions src/ic-websocket-cdk/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ pub fn ws_open(args: CanisterWsOpenArguments) -> CanisterWsOpenResult {
// Do nothing
},
Ok(old_client_key) => {
remove_client(&old_client_key);
remove_client(&old_client_key, None);
},
};

Expand Down Expand Up @@ -124,7 +124,7 @@ pub fn ws_close(args: CanisterWsCloseArguments) -> CanisterWsCloseResult {
// check if the client is registered to the gateway that is closing the connection
check_client_registered_to_gateway(&args.client_key, &gateway_principal)?;

remove_client(&args.client_key);
remove_client(&args.client_key, None);

Ok(())
}
Expand Down Expand Up @@ -188,7 +188,7 @@ pub fn ws_message<T: CandidType + for<'a> Deserialize<'a>>(
.eq(&expected_sequence_num)
.then_some(())
.ok_or_else(|| {
remove_client(&client_key);
remove_client(&client_key, Some(CloseMessageReason::WrongSequenceNumber));

WsError::IncomingSequenceNumberWrong {
expected_sequence_num,
Expand Down
57 changes: 41 additions & 16 deletions src/ic-websocket-cdk/src/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use std::{
};

use candid::{encode_one, Principal};
#[allow(unused_imports)]
use ic_cdk::api::{data_certificate, set_certified_data};
use ic_certified_map::{labeled, labeled_hash, AsHashTree, Hash as ICHash, RbTree};
use serde::Serialize;
Expand Down Expand Up @@ -44,7 +45,7 @@ pub(crate) fn reset_internal_state() {

// for each client, call the on_close handler before clearing the map
for client_key in client_keys_to_remove {
remove_client(&client_key);
remove_client(&client_key, None);
}

// make sure all the maps are cleared
Expand Down Expand Up @@ -79,16 +80,22 @@ pub(crate) fn increment_gateway_clients_count(gateway_principal: GatewayPrincipa
});
}

/// Decrements the clients connected count for the given gateway.
/// If there are no more clients connected, the gateway is removed from the list of registered gateways.
pub(crate) fn decrement_gateway_clients_count(gateway_principal: &GatewayPrincipal) {
/// Decrements the clients connected count for the given gateway, if it exists.
///
/// If `remove_if_empty` is true, the gateway is removed from the list of registered gateways
ilbertt marked this conversation as resolved.
Show resolved Hide resolved
/// if it has no clients connected.
pub(crate) fn decrement_gateway_clients_count(
gateway_principal: &GatewayPrincipal,
remove_if_empty: bool,
) {
REGISTERED_GATEWAYS.with(|map| {
let mut map = map.borrow_mut();
let g = map.get_mut(gateway_principal).unwrap(); // gateway must be registered at this point
let clients_count = g.decrement_clients_count();
if let Some(g) = map.get_mut(gateway_principal) {
let clients_count = g.decrement_clients_count();

if clients_count == 0 {
map.remove(gateway_principal);
if remove_if_empty && clients_count == 0 {
map.remove(gateway_principal);
}
}
});
}
Expand Down Expand Up @@ -266,7 +273,23 @@ pub(crate) fn add_client(client_key: ClientKey, new_client: RegisteredClient) {
/// Removes a client from the internal state
/// and call the on_close callback,
/// if the client was registered in the state.
pub(crate) fn remove_client(client_key: &ClientKey) {
///
/// If a `close_reason` is provided, it also sends a close message to the client,
/// so that the client can close the WS connection with the gateway.
///
/// If a `close_reason` is **not** provided, it also removes the gateway from the state
/// if it has no clients connected anymore.
pub(crate) fn remove_client(client_key: &ClientKey, close_reason: Option<CloseMessageReason>) {
if let Some(close_reason) = close_reason.clone() {
// ignore the error
let _ = send_service_message_to_client(
client_key,
&WebsocketServiceMessageContent::CloseMessage(CanisterCloseMessageContent {
reason: close_reason,
}),
);
}

CLIENTS_WAITING_FOR_KEEP_ALIVE.with(|set| {
set.borrow_mut().remove(client_key);
});
Expand All @@ -283,7 +306,10 @@ pub(crate) fn remove_client(client_key: &ClientKey) {
if let Some(registered_client) =
REGISTERED_CLIENTS.with(|map| map.borrow_mut().remove(client_key))
{
decrement_gateway_clients_count(&registered_client.gateway_principal);
decrement_gateway_clients_count(
&registered_client.gateway_principal,
close_reason.is_none(),
ilbertt marked this conversation as resolved.
Show resolved Hide resolved
);

let handlers = get_handlers_from_params();
handlers.call_on_close(OnCloseCallbackArgs {
Expand Down Expand Up @@ -383,12 +409,15 @@ pub(crate) fn get_cert_messages_empty() -> CanisterWsGetMessagesResult {
}

fn put_cert_for_message(key: String, value: &Vec<u8>) {
#[allow(unused_variables)]
let root_hash = CERT_TREE.with(|tree| {
let mut tree = tree.borrow_mut();
tree.insert(key.clone(), Sha256::digest(value).into());
labeled_hash(LABEL_WEBSOCKET, &tree.root_hash())
});

#[cfg(not(test))]
// executing this in tests fails because the tree is an IC-specific implementation
set_certified_data(&root_hash);
}

Expand Down Expand Up @@ -419,18 +448,13 @@ pub(crate) fn delete_old_messages_for_gateway(
) -> Result<(), String> {
let ack_interval_ms = get_params().send_ack_interval_ms;

// allow unused variables because sometimes the compiler complains about unused variables
// since it is only used in production code
#[allow(unused_variables)]
let deleted_messages_keys = REGISTERED_GATEWAYS.with(|map| {
map.borrow_mut()
.get_mut(gateway_principal)
.ok_or_else(|| WsError::GatewayNotRegistered { gateway_principal }.to_string())
.and_then(|g| Ok(g.delete_old_messages(MESSAGES_TO_DELETE_COUNT, ack_interval_ms)))
})?;

#[cfg(not(test))]
// executing this in tests fails because the tree is an IC-specific implementation
CERT_TREE.with(|tree| {
for key in deleted_messages_keys {
tree.borrow_mut().delete(key.as_ref());
Expand Down Expand Up @@ -487,7 +511,8 @@ pub(crate) fn handle_received_service_message(
let decoded = WebsocketServiceMessageContent::from_candid_bytes(content)?;
match decoded {
WebsocketServiceMessageContent::OpenMessage(_)
| WebsocketServiceMessageContent::AckMessage(_) => {
| WebsocketServiceMessageContent::AckMessage(_)
| WebsocketServiceMessageContent::CloseMessage(_) => {
WsError::InvalidServiceMessage.to_string_result()
},
WebsocketServiceMessageContent::KeepAliveMessage(keep_alive_message) => {
Expand Down
80 changes: 62 additions & 18 deletions src/ic-websocket-cdk/src/tests/integration_tests/b_ws_message.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,19 @@ use std::ops::Deref;
use candid::encode_one;

use crate::{
errors::WsError, tests::common::generate_random_principal, CanisterAckMessageContent,
CanisterWsMessageArguments, CanisterWsMessageResult, ClientKeepAliveMessageContent, ClientKey,
WebsocketServiceMessageContent,
errors::WsError,
tests::{
common::generate_random_principal,
integration_tests::utils::{
actor::{ws_close::call_ws_close, ws_get_messages::call_ws_get_messages_with_panic},
clients::GATEWAY_1,
messages::check_canister_message_has_close_reason,
},
},
types::CloseMessageReason,
CanisterAckMessageContent, CanisterWsCloseArguments, CanisterWsCloseResult,
CanisterWsGetMessagesArguments, CanisterWsMessageArguments, CanisterWsMessageResult,
ClientKeepAliveMessageContent, ClientKey, WebsocketServiceMessageContent,
};

use super::utils::{
Expand Down Expand Up @@ -44,6 +54,10 @@ fn test_1_fails_if_client_is_not_registered() {
#[test]
fn test_2_fails_if_client_sends_a_message_with_a_different_client_key() {
let client_1_key = CLIENT_1_KEY.deref();
// first, reset the canister
get_test_env().reset_canister_with_default_params();
// second, open a connection for client 1
call_ws_open_for_client_key_with_panic(client_1_key);

let wrong_client_key = ClientKey {
client_principal: generate_random_principal(),
Expand Down Expand Up @@ -93,6 +107,11 @@ fn test_2_fails_if_client_sends_a_message_with_a_different_client_key() {
#[test]
fn test_3_should_send_a_message_from_a_registered_client() {
let client_1_key = CLIENT_1_KEY.deref();
// first, reset the canister
get_test_env().reset_canister_with_default_params();
// second, open a connection for client 1
call_ws_open_for_client_key_with_panic(client_1_key);

let res = call_ws_message(
&client_1_key.client_principal,
CanisterWsMessageArguments {
Expand All @@ -105,8 +124,13 @@ fn test_3_should_send_a_message_from_a_registered_client() {
#[test]
fn test_4_fails_if_client_sends_a_message_with_a_wrong_sequence_number() {
let client_1_key = CLIENT_1_KEY.deref();
let wrong_sequence_number = 1; // the message with sequence number 1 has already been sent in the previous test
let expected_sequence_number = 2; // the next valid sequence number
// first, reset the canister
get_test_env().reset_canister_with_default_params();
// second, open a connection for client 1
call_ws_open_for_client_key_with_panic(client_1_key);

let wrong_sequence_number = 2; // the message with sequence number 1 has already been sent in the previous test
let expected_sequence_number = 1; // the next valid sequence number
let res = call_ws_message(
&client_1_key.client_principal,
CanisterWsMessageArguments {
Expand All @@ -124,28 +148,41 @@ fn test_4_fails_if_client_sends_a_message_with_a_wrong_sequence_number() {
)
);

// check if the client has been removed
let res = call_ws_message(
&client_1_key.client_principal,
CanisterWsMessageArguments {
msg: create_websocket_message(client_1_key, 1, None, false), // the sequence number doesn't matter here because the method fails before checking it
// check if the gateway put the close message in the queue
let msgs = call_ws_get_messages_with_panic(
GATEWAY_1.deref(),
CanisterWsGetMessagesArguments { nonce: 1 }, // skip the first open message
);
check_canister_message_has_close_reason(
&msgs.messages[0],
CloseMessageReason::WrongSequenceNumber,
);

// the gateway should still be between the registered gateways
// so calling the ws_close endpoint should return the ClientKeyNotConnected error
let res = call_ws_close(
GATEWAY_1.deref(),
CanisterWsCloseArguments {
client_key: client_1_key.clone(),
},
);
assert_eq!(
res,
CanisterWsMessageResult::Err(
WsError::ClientPrincipalNotConnected {
client_principal: &client_1_key.client_principal
CanisterWsCloseResult::Err(
WsError::ClientKeyNotConnected {
client_key: &client_1_key
}
.to_string()
)
)
);
}

#[test]
fn test_5_fails_if_client_sends_a_wrong_service_message() {
let client_1_key = CLIENT_1_KEY.deref();
// first, open the connection again for client 1
// first, reset the canister
get_test_env().reset_canister_with_default_params();
// second, open a connection for client 1
call_ws_open_for_client_key_with_panic(client_1_key);

// fail with wrong content encoding
Expand All @@ -160,8 +197,10 @@ fn test_5_fails_if_client_sends_a_wrong_service_message() {
),
},
);
let err = res.err().unwrap();
assert!(err.starts_with("Error decoding service message content:"));
assert!(res
.err()
.unwrap()
.starts_with("Error decoding service message content:"));

// fail with wrong service message variant
let wrong_service_message =
Expand Down Expand Up @@ -190,6 +229,11 @@ fn test_5_fails_if_client_sends_a_wrong_service_message() {
#[test]
fn test_6_should_send_a_service_message_from_a_registered_client() {
let client_1_key = CLIENT_1_KEY.deref();
// first, reset the canister
get_test_env().reset_canister_with_default_params();
// second, open a connection for client 1
call_ws_open_for_client_key_with_panic(client_1_key);

let client_service_message =
WebsocketServiceMessageContent::KeepAliveMessage(ClientKeepAliveMessageContent {
last_incoming_sequence_num: 0,
Expand All @@ -199,7 +243,7 @@ fn test_6_should_send_a_service_message_from_a_registered_client() {
CanisterWsMessageArguments {
msg: create_websocket_message(
client_1_key,
3,
1,
Some(encode_websocket_service_message_content(
&client_service_message,
)),
Expand Down
Loading