diff --git a/Cargo.lock b/Cargo.lock index 56827d3..7f87101 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3260,6 +3260,7 @@ dependencies = [ "actix-ws", "anyhow", "futures-util", + "rand", "serde", "thiserror 2.0.6", "tokio", @@ -3278,6 +3279,7 @@ dependencies = [ "serde", "serde-aux", "sqlx", + "tokio", "tracing", "tracing-subscriber", "tracked-cancellations", diff --git a/crates/ws-auth/Cargo.toml b/crates/ws-auth/Cargo.toml index a0e467b..8b18db6 100644 --- a/crates/ws-auth/Cargo.toml +++ b/crates/ws-auth/Cargo.toml @@ -8,6 +8,7 @@ actix-web.workspace = true actix-ws.workspace = true anyhow.workspace = true futures-util.workspace = true +rand.workspace = true serde.workspace = true thiserror.workspace = true tokio.workspace = true diff --git a/crates/ws-auth/src/errors.rs b/crates/ws-auth/src/errors.rs index 6c12cb9..6acba33 100644 --- a/crates/ws-auth/src/errors.rs +++ b/crates/ws-auth/src/errors.rs @@ -2,7 +2,7 @@ use crate::WsId; use wykies_shared::host_branch::HostId; #[derive(thiserror::Error, Debug)] -pub enum WebSocketError { +pub enum WebSocketAuthError { /// Client was not expected to be trying to connect #[error("Unexpected Client")] UnexpectedClient { @@ -25,13 +25,13 @@ pub mod conversions { use super::*; use actix_web::http::StatusCode; - impl actix_web::error::ResponseError for WebSocketError { + impl actix_web::error::ResponseError for WebSocketAuthError { fn status_code(&self) -> StatusCode { match self { - WebSocketError::UnexpectedClient { .. } => StatusCode::BAD_REQUEST, - WebSocketError::InvalidToken { .. } => StatusCode::BAD_REQUEST, - WebSocketError::FailedToStartSession(_) => StatusCode::INTERNAL_SERVER_ERROR, - WebSocketError::UnexpectedError(_) => StatusCode::INTERNAL_SERVER_ERROR, + WebSocketAuthError::UnexpectedClient { .. } => StatusCode::BAD_REQUEST, + WebSocketAuthError::InvalidToken { .. } => StatusCode::BAD_REQUEST, + WebSocketAuthError::FailedToStartSession(_) => StatusCode::INTERNAL_SERVER_ERROR, + WebSocketAuthError::UnexpectedError(_) => StatusCode::INTERNAL_SERVER_ERROR, } } } diff --git a/crates/ws-auth/src/id.rs b/crates/ws-auth/src/id.rs new file mode 100644 index 0000000..fa41cec --- /dev/null +++ b/crates/ws-auth/src/id.rs @@ -0,0 +1,48 @@ +use rand::Rng as _; +use std::{hash::Hash, sync::Arc}; +use wykies_shared::session::UserSessionInfo; + +/// Distinguishes different types of Websocket services supported +#[derive(Debug, PartialEq, Eq, Clone, Copy, serde::Serialize, serde::Deserialize)] +pub struct WsId(u8); + +impl WsId { + #[cfg(test)] + pub(crate) const TEST1: Self = Self::new(1); + + pub const fn new(value: u8) -> Self { + Self(value) + } +} + +#[derive(Debug, Clone)] +/// Websocket Connection ID +/// Includes user session info as this is not available in the websocket context +/// Only the id is used for hashing and equality checks +pub struct WsConnId { + id: usize, + pub user_info: Arc, +} + +impl WsConnId { + pub fn new(user_info: Arc) -> Self { + Self { + id: rand::thread_rng().gen::(), + user_info, + } + } +} + +impl PartialEq for WsConnId { + fn eq(&self, other: &Self) -> bool { + self.id == other.id + } +} + +impl Eq for WsConnId {} + +impl Hash for WsConnId { + fn hash(&self, state: &mut H) { + self.id.hash(state); + } +} diff --git a/crates/ws-auth/src/lib.rs b/crates/ws-auth/src/lib.rs index 286fa00..5093fc7 100644 --- a/crates/ws-auth/src/lib.rs +++ b/crates/ws-auth/src/lib.rs @@ -1,20 +1,9 @@ -//! Provides authenticated access to web socket handlers +//! Provides authentication for web socket handlers mod errors; +mod id; mod manager; -pub use errors::WebSocketError; +pub use errors::WebSocketAuthError; +pub use id::{WsConnId, WsId}; pub use manager::{validate_ws_connection, AuthTokenManager}; - -/// Distinguishes different types of Websocket services supported -#[derive(Debug, PartialEq, Eq, Clone, Copy, serde::Serialize, serde::Deserialize)] -pub struct WsId(u8); - -impl WsId { - #[cfg(test)] - const TEST1: Self = Self::new(1); - - pub const fn new(value: u8) -> Self { - Self(value) - } -} diff --git a/crates/wykies-server/Cargo.toml b/crates/wykies-server/Cargo.toml index 0a15a08..c8a6d0c 100644 --- a/crates/wykies-server/Cargo.toml +++ b/crates/wykies-server/Cargo.toml @@ -11,6 +11,7 @@ secrecy.workspace = true serde.workspace = true serde-aux.workspace = true sqlx.workspace = true +tokio.workspace = true tracing.workspace = true tracing-subscriber.workspace = true tracked-cancellations.workspace = true diff --git a/crates/wykies-server/src/lib.rs b/crates/wykies-server/src/lib.rs index 709a486..1fbe2e2 100644 --- a/crates/wykies-server/src/lib.rs +++ b/crates/wykies-server/src/lib.rs @@ -13,12 +13,14 @@ use wykies_shared::telemetry; mod configuration; mod macros; pub mod plugin; +pub mod ws; #[cfg_attr(feature = "mysql", path = "db_types_mysql.rs")] pub mod db_types; pub use configuration::{get_configuration, Configuration, DatabaseSettings, WebSocketSettings}; +// TODO 1: Move server init into module pub struct ServerInit { pub cancellation_token: TrackedCancellationToken, pub cancellation_tracker: CancellationTracker, diff --git a/crates/wykies-server/src/ws.rs b/crates/wykies-server/src/ws.rs new file mode 100644 index 0000000..a6541b0 --- /dev/null +++ b/crates/wykies-server/src/ws.rs @@ -0,0 +1,3 @@ +mod heartbeat; + +pub use heartbeat::HeartbeatConfig; diff --git a/crates/wykies-server/src/ws/heartbeat.rs b/crates/wykies-server/src/ws/heartbeat.rs new file mode 100644 index 0000000..2d394e6 --- /dev/null +++ b/crates/wykies-server/src/ws/heartbeat.rs @@ -0,0 +1,37 @@ +use std::{fmt::Display, time::Duration}; + +use crate::WebSocketSettings; +use tracing::instrument; +use wykies_time::Seconds; + +#[derive(Debug, Clone, Copy)] +pub struct HeartbeatConfig { + interval_time: Seconds, + client_timeout: Seconds, +} + +impl HeartbeatConfig { + #[instrument(ret)] + pub fn new(interval_time: Seconds, ws_config: &WebSocketSettings) -> Self { + let times_missed_allowance = ws_config.heartbeat_times_missed_allowance.into(); + let additional_buffer_time = ws_config.heartbeat_additional_buffer_time_secs; + let client_timeout = interval_time * times_missed_allowance + additional_buffer_time; + + Self { + interval_time, + client_timeout, + } + } + + pub fn interval(&self) -> tokio::time::Interval { + tokio::time::interval(self.interval_time.into()) + } + + pub fn client_timeout(&self) -> Duration { + self.client_timeout.into() + } + + pub fn client_timeout_display(&self) -> impl Display { + format!("{} sec", self.client_timeout) + } +}