Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

release/0.3.2 #5

Merged
merged 18 commits into from
Dec 15, 2023
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions src/ic-websocket-cdk/service_messages.did
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,20 @@ type ClientKeepAliveMessageContent = record {
last_incoming_sequence_num : nat64;
};

type CloseMessageReason = variant {
WrongSequenceNumber;
InvalidServiceMessage;
KeepAliveTimeout;
ClosedByApplication;
};

type CanisterCloseMessageContent = record {
reason : CloseMessageReason;
};

type WebsocketServiceMessageContent = variant {
OpenMessage : CanisterOpenMessageContent;
AckMessage : CanisterAckMessageContent;
KeepAliveMessage : ClientKeepAliveMessageContent;
CloseMessage : CanisterCloseMessageContent;
};
47 changes: 37 additions & 10 deletions src/ic-websocket-cdk/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,13 @@ mod utils;

use state::*;
use timers::*;
#[allow(deprecated)]
pub use types::CanisterWsSendResult;
use types::*;
pub use types::{
CanisterWsCloseArguments, CanisterWsCloseResult, CanisterWsGetMessagesArguments,
CanisterWsGetMessagesResult, CanisterWsMessageArguments, CanisterWsMessageResult,
CanisterWsOpenArguments, CanisterWsOpenResult, CanisterWsSendResult, ClientPrincipal,
CanisterCloseResult, CanisterSendResult, CanisterWsCloseArguments, CanisterWsCloseResult,
CanisterWsGetMessagesArguments, CanisterWsGetMessagesResult, CanisterWsMessageArguments,
CanisterWsMessageResult, CanisterWsOpenArguments, CanisterWsOpenResult, ClientPrincipal,
OnCloseCallbackArgs, OnMessageCallbackArgs, OnOpenCallbackArgs, WsHandlers, WsInitParams,
};

Expand All @@ -28,8 +30,12 @@ const LABEL_WEBSOCKET: &[u8] = b"websocket";
const DEFAULT_MAX_NUMBER_OF_RETURNED_MESSAGES: usize = 50;
/// The default interval at which to send acknowledgements to the client.
const DEFAULT_SEND_ACK_INTERVAL_MS: u64 = 300_000; // 5 minutes
/// The maximum latency allowed between the client and the canister.
const MAX_ALLOWED_CONNECTION_LATENCY_MS: u64 = 30_000; // 30 seconds
ilbertt marked this conversation as resolved.
Show resolved Hide resolved
/// The default timeout to wait for the client to send a keep alive after receiving an acknowledgement.
const DEFAULT_CLIENT_KEEP_ALIVE_TIMEOUT_MS: u64 = 60_000; // 1 minute
const CLIENT_KEEP_ALIVE_TIMEOUT_MS: u64 = 2 * MAX_ALLOWED_CONNECTION_LATENCY_MS;
/// Same as [CLIENT_KEEP_ALIVE_TIMEOUT_MS], but in nanoseconds.
const CLIENT_KEEP_ALIVE_TIMEOUT_NS: u64 = CLIENT_KEEP_ALIVE_TIMEOUT_MS * 1_000_000;

/// The initial nonce for outgoing messages.
const INITIAL_OUTGOING_MESSAGE_NONCE: u64 = 0;
Expand Down Expand Up @@ -88,7 +94,7 @@ pub fn ws_open(args: CanisterWsOpenArguments) -> CanisterWsOpenResult {
// Do nothing
},
Ok(old_client_key) => {
remove_client(&old_client_key);
remove_client(&old_client_key, None);
},
};

Expand All @@ -112,6 +118,9 @@ pub fn ws_open(args: CanisterWsOpenArguments) -> CanisterWsOpenResult {
}

/// Handles the WS connection close event received from the WS Gateway.
///
/// If you want to close the connection with the client in your logic,
/// use the [close] function instead.
pub fn ws_close(args: CanisterWsCloseArguments) -> CanisterWsCloseResult {
let gateway_principal = caller();

Expand All @@ -124,7 +133,7 @@ pub fn ws_close(args: CanisterWsCloseArguments) -> CanisterWsCloseResult {
// check if the client is registered to the gateway that is closing the connection
check_client_registered_to_gateway(&args.client_key, &gateway_principal)?;

remove_client(&args.client_key);
remove_client(&args.client_key, None);

Ok(())
}
Expand Down Expand Up @@ -188,7 +197,7 @@ pub fn ws_message<T: CandidType + for<'a> Deserialize<'a>>(
.eq(&expected_sequence_num)
.then_some(())
.ok_or_else(|| {
remove_client(&client_key);
remove_client(&client_key, Some(CloseMessageReason::WrongSequenceNumber));

WsError::IncomingSequenceNumberWrong {
expected_sequence_num,
Expand Down Expand Up @@ -231,7 +240,7 @@ pub fn ws_get_messages(args: CanisterWsGetMessagesArguments) -> CanisterWsGetMes
/// This example is the serialize equivalent of the [OnMessageCallbackArgs's example](struct.OnMessageCallbackArgs.html#example) deserialize one.
/// ```rust
/// use candid::{encode_one, CandidType, Principal};
/// use ic_websocket_cdk::ws_send;
/// use ic_websocket_cdk::send;
/// use serde::Deserialize;
///
/// #[derive(CandidType, Deserialize)]
Expand All @@ -247,13 +256,31 @@ pub fn ws_get_messages(args: CanisterWsGetMessagesArguments) -> CanisterWsGetMes
/// };
///
/// let msg_bytes = encode_one(&my_message).unwrap();
/// ws_send(my_client_principal, msg_bytes);
/// send(my_client_principal, msg_bytes);
/// ```
pub fn ws_send(client_principal: ClientPrincipal, msg_bytes: Vec<u8>) -> CanisterWsSendResult {
pub fn send(client_principal: ClientPrincipal, msg_bytes: Vec<u8>) -> CanisterSendResult {
let client_key = get_client_key_from_principal(&client_principal)?;
_ws_send(&client_key, msg_bytes, false)
}

#[deprecated(since = "0.3.2", note = "use `ic_websocket_cdk::send` instead")]
#[allow(deprecated)]
/// Deprecated: use [send] instead.
pub fn ws_send(client_principal: ClientPrincipal, msg_bytes: Vec<u8>) -> CanisterWsSendResult {
send(client_principal, msg_bytes)
}

/// Closes the connection with the client.
///
/// This function **must not** be called in the `on_close` callback.
pub fn close(client_principal: ClientPrincipal) -> CanisterCloseResult {
let client_key = get_client_key_from_principal(&client_principal)?;

remove_client(&client_key, Some(CloseMessageReason::ClosedByApplication));

Ok(())
}

/// Resets the internal state of the IC WebSocket CDK.
///
/// **Note:** You should only call this function in tests.
Expand Down
95 changes: 70 additions & 25 deletions src/ic-websocket-cdk/src/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use std::{
};

use candid::{encode_one, Principal};
#[allow(unused_imports)]
use ic_cdk::api::{data_certificate, set_certified_data};
use ic_certified_map::{labeled, labeled_hash, AsHashTree, Hash as ICHash, RbTree};
use serde::Serialize;
Expand All @@ -28,7 +29,7 @@ thread_local! {
/// Maps the client's key to the expected sequence number of the next incoming message (from that client).
/* flexible */ pub(crate) static INCOMING_MESSAGE_FROM_CLIENT_NUM_MAP: RefCell<HashMap<ClientKey, u64>> = RefCell::new(HashMap::new());
/// Keeps track of the Merkle tree used for certified queries
/* flexible */ static CERT_TREE: RefCell<RbTree<String, ICHash>> = RefCell::new(RbTree::new());
/* flexible */ pub(crate) static CERT_TREE: RefCell<RbTree<String, ICHash>> = RefCell::new(RbTree::new());
/// Keeps track of the principals of the WS Gateways that poll the canister
/* flexible */ pub(crate) static REGISTERED_GATEWAYS: RefCell<HashMap<GatewayPrincipal, RegisteredGateway>> = RefCell::new(HashMap::new());
/// The parameters passed in the CDK initialization
Expand All @@ -44,7 +45,7 @@ pub(crate) fn reset_internal_state() {

// for each client, call the on_close handler before clearing the map
for client_key in client_keys_to_remove {
remove_client(&client_key);
remove_client(&client_key, None);
}

// make sure all the maps are cleared
Expand Down Expand Up @@ -79,18 +80,32 @@ pub(crate) fn increment_gateway_clients_count(gateway_principal: GatewayPrincipa
});
}

/// Decrements the clients connected count for the given gateway.
/// If there are no more clients connected, the gateway is removed from the list of registered gateways.
pub(crate) fn decrement_gateway_clients_count(gateway_principal: &GatewayPrincipal) {
REGISTERED_GATEWAYS.with(|map| {
/// Decrements the clients connected count for the given gateway, if it exists.
///
/// If `remove_if_empty` is true, the gateway is removed from the list of registered gateways
ilbertt marked this conversation as resolved.
Show resolved Hide resolved
/// if it has no clients connected.
pub(crate) fn decrement_gateway_clients_count(
gateway_principal: &GatewayPrincipal,
remove_if_empty: bool,
) {
let messages_keys_to_delete = REGISTERED_GATEWAYS.with(|map| {
let mut map = map.borrow_mut();
let g = map.get_mut(gateway_principal).unwrap(); // gateway must be registered at this point
let clients_count = g.decrement_clients_count();
if let Some(g) = map.get_mut(gateway_principal) {
let clients_count = g.decrement_clients_count();

if remove_if_empty && clients_count == 0 {
let g = map.remove(gateway_principal).unwrap();

if clients_count == 0 {
map.remove(gateway_principal);
return Some(g.messages_queue.iter().map(|m| m.key.clone()).collect());
}
}

None
});

if let Some(messages_keys_to_delete) = messages_keys_to_delete {
delete_keys_from_cert_tree(messages_keys_to_delete);
}
}

pub(crate) fn get_registered_gateway(
Expand Down Expand Up @@ -266,7 +281,23 @@ pub(crate) fn add_client(client_key: ClientKey, new_client: RegisteredClient) {
/// Removes a client from the internal state
/// and call the on_close callback,
/// if the client was registered in the state.
pub(crate) fn remove_client(client_key: &ClientKey) {
///
/// If a `close_reason` is provided, it also sends a close message to the client,
/// so that the client can close the WS connection with the gateway.
///
/// If a `close_reason` is **not** provided, it also removes the gateway from the state
/// if it has no clients connected anymore.
pub(crate) fn remove_client(client_key: &ClientKey, close_reason: Option<CloseMessageReason>) {
if let Some(close_reason) = close_reason.clone() {
// ignore the error
let _ = send_service_message_to_client(
client_key,
&WebsocketServiceMessageContent::CloseMessage(CanisterCloseMessageContent {
reason: close_reason,
}),
);
}

CLIENTS_WAITING_FOR_KEEP_ALIVE.with(|set| {
set.borrow_mut().remove(client_key);
});
Expand All @@ -283,7 +314,10 @@ pub(crate) fn remove_client(client_key: &ClientKey) {
if let Some(registered_client) =
REGISTERED_CLIENTS.with(|map| map.borrow_mut().remove(client_key))
{
decrement_gateway_clients_count(&registered_client.gateway_principal);
decrement_gateway_clients_count(
&registered_client.gateway_principal,
close_reason.is_none(),
ilbertt marked this conversation as resolved.
Show resolved Hide resolved
);

let handlers = get_handlers_from_params();
handlers.call_on_close(OnCloseCallbackArgs {
Expand Down Expand Up @@ -382,13 +416,16 @@ pub(crate) fn get_cert_messages_empty() -> CanisterWsGetMessagesResult {
Ok(CanisterOutputCertifiedMessages::empty())
}

fn put_cert_for_message(key: String, value: &Vec<u8>) {
pub(crate) fn put_cert_for_message(key: String, value: &Vec<u8>) {
#[allow(unused_variables)]
let root_hash = CERT_TREE.with(|tree| {
let mut tree = tree.borrow_mut();
tree.insert(key.clone(), Sha256::digest(value).into());
labeled_hash(LABEL_WEBSOCKET, &tree.root_hash())
});

#[cfg(not(test))]
// executing this in unit tests fails because it's an IC-specific API
set_certified_data(&root_hash);
}

Expand All @@ -400,7 +437,7 @@ pub(crate) fn push_message_in_gateway_queue(
) -> Result<(), String> {
REGISTERED_GATEWAYS.with(|map| {
// messages in the queue are inserted with contiguous and increasing nonces
// (from beginning to end of the queue) as ws_send is called sequentially, the nonce
// (from beginning to end of the queue) as `send` is called sequentially, the nonce
// is incremented by one in each call, and the message is pushed at the end of the queue
map.borrow_mut()
.get_mut(gateway_principal)
Expand All @@ -419,25 +456,32 @@ pub(crate) fn delete_old_messages_for_gateway(
) -> Result<(), String> {
let ack_interval_ms = get_params().send_ack_interval_ms;

// allow unused variables because sometimes the compiler complains about unused variables
// since it is only used in production code
#[allow(unused_variables)]
let deleted_messages_keys = REGISTERED_GATEWAYS.with(|map| {
map.borrow_mut()
.get_mut(gateway_principal)
.ok_or_else(|| WsError::GatewayNotRegistered { gateway_principal }.to_string())
.and_then(|g| Ok(g.delete_old_messages(MESSAGES_TO_DELETE_COUNT, ack_interval_ms)))
})?;

#[cfg(not(test))]
// executing this in tests fails because the tree is an IC-specific implementation
CERT_TREE.with(|tree| {
for key in deleted_messages_keys {
tree.borrow_mut().delete(key.as_ref());
delete_keys_from_cert_tree(deleted_messages_keys);

Ok(())
}

pub(crate) fn delete_keys_from_cert_tree(keys: Vec<String>) {
#[allow(unused_variables)]
let root_hash = CERT_TREE.with(|tree| {
let mut tree = tree.borrow_mut();
for key in keys {
tree.delete(key.as_ref());
}
labeled_hash(LABEL_WEBSOCKET, &tree.root_hash())
});

Ok(())
// certify data with the new root hash
#[cfg(not(test))]
// executing this in unit tests fails because it's an IC-specific API
set_certified_data(&root_hash);
}

fn get_cert_for_range(first: &String, last: &String) -> (Vec<u8>, Vec<u8>) {
Expand Down Expand Up @@ -487,7 +531,8 @@ pub(crate) fn handle_received_service_message(
let decoded = WebsocketServiceMessageContent::from_candid_bytes(content)?;
match decoded {
WebsocketServiceMessageContent::OpenMessage(_)
| WebsocketServiceMessageContent::AckMessage(_) => {
| WebsocketServiceMessageContent::AckMessage(_)
| WebsocketServiceMessageContent::CloseMessage(_) => {
WsError::InvalidServiceMessage.to_string_result()
},
WebsocketServiceMessageContent::KeepAliveMessage(keep_alive_message) => {
Expand All @@ -510,7 +555,7 @@ pub(crate) fn _ws_send(
client_key: &ClientKey,
msg_bytes: Vec<u8>,
is_service_message: bool,
) -> CanisterWsSendResult {
) -> CanisterSendResult {
// get the registered client if it exists
let registered_client = get_registered_client(client_key)?;

Expand Down
Loading
Loading