Skip to content

Commit

Permalink
refactor: timers small refactor, comments
Browse files Browse the repository at this point in the history
  • Loading branch information
ilbertt committed Dec 3, 2023
1 parent ba8bef1 commit fe627fb
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 17 deletions.
8 changes: 3 additions & 5 deletions src/ic-websocket-cdk/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,8 @@ pub fn init(params: WsInitParams) {
// set the handlers specified by the canister that the CDK uses to manage the IC WebSocket connection
set_params(params.clone());

// reset initial timers
reset_timers();
// cancel possibly running timers
cancel_timers();

// schedule a timer that will send an acknowledgement message to clients
schedule_send_ack_to_clients();
Expand Down Expand Up @@ -107,7 +107,7 @@ pub fn ws_close(args: CanisterWsCloseArguments) -> CanisterWsCloseResult {
// check if the gateway is registered
check_is_gateway_registered(&gateway_principal)?;

// check if client registered its principal by calling ws_open
// check if client registered itself by calling ws_open
check_registered_client_exists(&args.client_key)?;

// check if the client is registered to the gateway that is closing the connection
Expand Down Expand Up @@ -149,7 +149,6 @@ pub fn ws_message<T: CandidType + for<'a> Deserialize<'a>>(
_message_type: Option<T>,
) -> CanisterWsMessageResult {
let client_principal = caller();
// check if client registered its principal by calling ws_open
let registered_client_key = get_client_key_from_principal(&client_principal)?;

let WebsocketMessage {
Expand Down Expand Up @@ -203,7 +202,6 @@ pub fn ws_message<T: CandidType + for<'a> Deserialize<'a>>(

/// Returns messages to the WS Gateway in response of a polling iteration.
pub fn ws_get_messages(args: CanisterWsGetMessagesArguments) -> CanisterWsGetMessagesResult {
// check if the caller of this method is the WS Gateway that has been set during the initialization of the SDK
let gateway_principal = caller();
if !is_registered_gateway(&gateway_principal) {
return get_cert_messages_empty();
Expand Down
5 changes: 0 additions & 5 deletions src/ic-websocket-cdk/src/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ use std::{

use candid::{encode_one, Principal};
use ic_cdk::api::{data_certificate, set_certified_data};
use ic_cdk_timers::TimerId;
use ic_certified_map::{labeled, labeled_hash, AsHashTree, Hash as ICHash, RbTree};
use serde::Serialize;
use serde_cbor::Serializer;
Expand Down Expand Up @@ -34,10 +33,6 @@ thread_local! {
/* flexible */ pub(crate) static REGISTERED_GATEWAYS: RefCell<HashMap<GatewayPrincipal, RegisteredGateway>> = RefCell::new(HashMap::new());
/// The parameters passed in the CDK initialization
/* flexible */ pub(crate) static PARAMS: RefCell<WsInitParams> = RefCell::new(WsInitParams::default());
/// The acknowledgement active timer.
/* flexible */ pub(crate) static ACK_TIMER: Rc<RefCell<Option<TimerId>>> = Rc::new(RefCell::new(None));
/// The keep alive active timer.
/* flexible */ pub(crate) static KEEP_ALIVE_TIMER: Rc<RefCell<Option<TimerId>>> = Rc::new(RefCell::new(None));
}

/// Resets all RefCells to their initial state.
Expand Down
22 changes: 15 additions & 7 deletions src/ic-websocket-cdk/src/timers.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use ic_cdk_timers::{clear_timer, TimerId};
use ic_cdk_timers::{set_timer, set_timer_interval};
use std::cell::RefCell;
use std::rc::Rc;
use std::time::Duration;

Expand All @@ -8,12 +9,19 @@ use crate::state::*;
use crate::types::*;
use crate::utils::*;

thread_local! {
/// The acknowledgement active timer.
/* flexible */ pub(crate) static ACK_TIMER: RefCell<Option<TimerId>> = RefCell::new(None);
/// The keep alive active timer.
/* flexible */ pub(crate) static KEEP_ALIVE_TIMER: RefCell<Option<TimerId>> = RefCell::new(None);
}

fn put_ack_timer_id(timer_id: TimerId) {
ACK_TIMER.with(|timer| timer.borrow_mut().replace(timer_id));
}

fn reset_ack_timer() {
if let Some(t_id) = ACK_TIMER.with(Rc::clone).borrow_mut().take() {
fn cancel_ack_timer() {
if let Some(t_id) = ACK_TIMER.with(|timer| timer.borrow_mut().take()) {
clear_timer(t_id);
}
}
Expand All @@ -22,15 +30,15 @@ fn put_keep_alive_timer_id(timer_id: TimerId) {
KEEP_ALIVE_TIMER.with(|timer| timer.borrow_mut().replace(timer_id));
}

fn reset_keep_alive_timer() {
if let Some(t_id) = KEEP_ALIVE_TIMER.with(Rc::clone).borrow_mut().take() {
fn cancel_keep_alive_timer() {
if let Some(t_id) = KEEP_ALIVE_TIMER.with(|timer| timer.borrow_mut().take()) {
clear_timer(t_id);
}
}

pub(crate) fn reset_timers() {
reset_ack_timer();
reset_keep_alive_timer();
pub(crate) fn cancel_timers() {
cancel_ack_timer();
cancel_keep_alive_timer();
}

/// Start an interval to send an acknowledgement messages to the clients.
Expand Down

0 comments on commit fe627fb

Please sign in to comment.