diff --git a/src/ic-websocket-cdk/src/lib.rs b/src/ic-websocket-cdk/src/lib.rs index a205591..d0e4acc 100644 --- a/src/ic-websocket-cdk/src/lib.rs +++ b/src/ic-websocket-cdk/src/lib.rs @@ -81,6 +81,17 @@ pub fn ws_open(args: CanisterWsOpenArguments) -> CanisterWsOpenResult { .to_string_result() })?; + // check if there's a client already registered with the same principal + // and remove it if there is + match get_client_key_from_principal(&client_key.client_principal) { + Err(_) => { + // Do nothing + }, + Ok(old_client_key) => { + remove_client(&old_client_key); + }, + }; + // initialize client maps let new_client = RegisteredClient::new(args.gateway_principal); add_client(client_key.clone(), new_client); diff --git a/src/ic-websocket-cdk/src/state.rs b/src/ic-websocket-cdk/src/state.rs index 4d914b9..f521f57 100644 --- a/src/ic-websocket-cdk/src/state.rs +++ b/src/ic-websocket-cdk/src/state.rs @@ -263,6 +263,9 @@ pub(crate) fn add_client(client_key: ClientKey, new_client: RegisteredClient) { increment_gateway_clients_count(new_client.gateway_principal); } +/// Removes a client from the internal state +/// and call the on_close callback, +/// if the client was registered in the state. pub(crate) fn remove_client(client_key: &ClientKey) { CLIENTS_WAITING_FOR_KEEP_ALIVE.with(|set| { set.borrow_mut().remove(client_key); @@ -277,14 +280,16 @@ pub(crate) fn remove_client(client_key: &ClientKey) { map.borrow_mut().remove(client_key); }); - let registered_client = - REGISTERED_CLIENTS.with(|map| map.borrow_mut().remove(client_key).unwrap()); - decrement_gateway_clients_count(®istered_client.gateway_principal); + if let Some(registered_client) = + REGISTERED_CLIENTS.with(|map| map.borrow_mut().remove(client_key)) + { + decrement_gateway_clients_count(®istered_client.gateway_principal); - let handlers = get_handlers_from_params(); - handlers.call_on_close(OnCloseCallbackArgs { - client_principal: client_key.client_principal, - }); + let handlers = get_handlers_from_params(); + handlers.call_on_close(OnCloseCallbackArgs { + client_principal: client_key.client_principal, + }); + }; } pub(crate) fn format_message_for_gateway_key( diff --git a/src/ic-websocket-cdk/src/tests/unit_tests/mod.rs b/src/ic-websocket-cdk/src/tests/unit_tests/mod.rs index 2d1f2c0..04fdac4 100644 --- a/src/ic-websocket-cdk/src/tests/unit_tests/mod.rs +++ b/src/ic-websocket-cdk/src/tests/unit_tests/mod.rs @@ -574,6 +574,14 @@ proptest! { REGISTERED_GATEWAYS.with(|map| map.borrow_mut().clear()); } + #[test] + fn test_remove_client_nonexistent(test_client_key in any::().prop_map(|_| common::get_random_client_key())) { + let res = panic::catch_unwind(|| { + remove_client(&test_client_key); + }); + prop_assert!(res.is_ok()); + } + #[test] fn test_remove_client(test_client_key in any::().prop_map(|_| common::get_random_client_key())) { // Set up