Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prepare for platform-dependant websockets backend #441

Merged
merged 5 commits into from
Nov 19, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 65 additions & 0 deletions src/gateway/backend_tungstenite.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
use futures_util::{
stream::{SplitSink, SplitStream},
StreamExt,
};
use tokio::net::TcpStream;
use tokio_tungstenite::{
connect_async_tls_with_config, tungstenite, Connector, MaybeTlsStream, WebSocketStream,
};

use super::GatewayMessage;
use crate::errors::GatewayError;

#[derive(Debug, Clone)]
pub struct WebSocketBackend;

// These could be made into inherent associated types when that's stabilized
pub type WsSink = SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, tungstenite::Message>;
pub type WsStream = SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>;

impl WebSocketBackend {
pub async fn new(
websocket_url: &str,
) -> Result<(WsSink, WsStream), crate::errors::GatewayError> {
let mut roots = rustls::RootCertStore::empty();
for cert in rustls_native_certs::load_native_certs().expect("could not load platform certs")
{
roots.add(&rustls::Certificate(cert.0)).unwrap();
}
let (websocket_stream, _) = match connect_async_tls_with_config(
websocket_url,
None,
false,
Some(Connector::Rustls(
rustls::ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates(roots)
.with_no_client_auth()
.into(),
)),
)
.await
{
Ok(websocket_stream) => websocket_stream,
Err(e) => {
return Err(GatewayError::CannotConnect {
error: e.to_string(),
})
}
};

Ok(websocket_stream.split())
}
Fixed Show fixed Hide fixed
}

impl From<GatewayMessage> for tungstenite::Message {
fn from(message: GatewayMessage) -> Self {
Self::Text(message.0)
}
}

impl From<tungstenite::Message> for GatewayMessage {
fn from(value: tungstenite::Message) -> Self {
Self(value.to_string())
}
}
86 changes: 22 additions & 64 deletions src/gateway/gateway.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
use self::event::Events;
use super::handle::GatewayHandle;
use super::heartbeat::HeartbeatHandler;
use super::*;
use crate::types::{
self, AutoModerationRule, AutoModerationRuleUpdate, Channel, ChannelCreate, ChannelDelete,
Expand All @@ -10,15 +12,8 @@ use crate::types::{
pub struct Gateway {
events: Arc<Mutex<Events>>,
heartbeat_handler: HeartbeatHandler,
websocket_send: Arc<
Mutex<
SplitSink<
WebSocketStream<MaybeTlsStream<TcpStream>>,
tokio_tungstenite::tungstenite::Message,
>,
>,
>,
websocket_receive: SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>,
websocket_send: Arc<Mutex<WsSink>>,
websocket_receive: WsStream,
kill_send: tokio::sync::broadcast::Sender<()>,
store: Arc<Mutex<HashMap<Snowflake, Arc<RwLock<ObservableObject>>>>>,
url: String,
Expand All @@ -27,34 +22,7 @@ pub struct Gateway {
impl Gateway {
#[allow(clippy::new_ret_no_self)]
pub async fn new(websocket_url: String) -> Result<GatewayHandle, GatewayError> {
let mut roots = rustls::RootCertStore::empty();
for cert in rustls_native_certs::load_native_certs().expect("could not load platform certs")
{
roots.add(&rustls::Certificate(cert.0)).unwrap();
}
let (websocket_stream, _) = match connect_async_tls_with_config(
&websocket_url,
None,
false,
Some(Connector::Rustls(
rustls::ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates(roots)
.with_no_client_auth()
.into(),
)),
)
.await
{
Ok(websocket_stream) => websocket_stream,
Err(e) => {
return Err(GatewayError::CannotConnect {
error: e.to_string(),
})
}
};

let (websocket_send, mut websocket_receive) = websocket_stream.split();
let (websocket_send, mut websocket_receive) = WebSocketBackend::new(&websocket_url).await?;

let shared_websocket_send = Arc::new(Mutex::new(websocket_send));

Expand All @@ -63,9 +31,8 @@ impl Gateway {

// Wait for the first hello and then spawn both tasks so we avoid nested tasks
// This automatically spawns the heartbeat task, but from the main thread
let msg = websocket_receive.next().await.unwrap().unwrap();
let gateway_payload: types::GatewayReceivePayload =
serde_json::from_str(msg.to_text().unwrap()).unwrap();
let msg: GatewayMessage = websocket_receive.next().await.unwrap().unwrap().into();
let gateway_payload: types::GatewayReceivePayload = serde_json::from_str(&msg.0).unwrap();

if gateway_payload.op_code != GATEWAY_HELLO {
return Err(GatewayError::NonHelloOnInitiate {
Expand Down Expand Up @@ -120,8 +87,7 @@ impl Gateway {

// This if chain can be much better but if let is unstable on stable rust
if let Some(Ok(message)) = msg {
self.handle_message(GatewayMessage::from_tungstenite_message(message))
.await;
self.handle_message(message.into()).await;
continue;
}

Expand All @@ -134,7 +100,7 @@ impl Gateway {
/// Closes the websocket connection and stops all tasks
async fn close(&mut self) {
self.kill_send.send(()).unwrap();
self.websocket_send.lock().await.close().await.unwrap();
let _ = self.websocket_send.lock().await.close().await;
}

/// Deserializes and updates a dispatched event, when we already know its type;
Expand All @@ -156,31 +122,23 @@ impl Gateway {

/// This handles a message as a websocket event and updates its events along with the events' observers
pub async fn handle_message(&mut self, msg: GatewayMessage) {
if msg.is_empty() {
if msg.0.is_empty() {
return;
}

if !msg.is_error() && !msg.is_payload() {
warn!(
"Message unrecognised: {:?}, please open an issue on the chorus github",
msg.message.to_string()
);
return;
}

if msg.is_error() {
let error = msg.error().unwrap();

warn!("GW: Received error {:?}, connection will close..", error);

self.close().await;

self.events.lock().await.error.notify(error).await;

let Ok(gateway_payload) = msg.payload() else {
if let Some(error) = msg.error() {
warn!("GW: Received error {:?}, connection will close..", error);
self.close().await;
self.events.lock().await.error.notify(error).await;
} else {
warn!(
"Message unrecognised: {:?}, please open an issue on the chorus github",
msg.0
);
}
return;
}

let gateway_payload = msg.payload().unwrap();
};

// See https://discord.com/developers/docs/topics/opcodes-and-status-codes#gateway-gateway-opcodes
match gateway_payload.op_code {
Expand Down
14 changes: 3 additions & 11 deletions src/gateway/handle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,7 @@ use crate::types::{self, Composite};
pub struct GatewayHandle {
pub url: String,
pub events: Arc<Mutex<Events>>,
pub websocket_send: Arc<
Mutex<
SplitSink<
WebSocketStream<MaybeTlsStream<TcpStream>>,
tokio_tungstenite::tungstenite::Message,
>,
>,
>,
pub websocket_send: Arc<Mutex<WsSink>>,
/// Tells gateway tasks to close
pub(super) kill_send: tokio::sync::broadcast::Sender<()>,
pub(crate) store: Arc<Mutex<HashMap<Snowflake, Arc<RwLock<ObservableObject>>>>>,
Expand All @@ -32,13 +25,12 @@ impl GatewayHandle {
};

let payload_json = serde_json::to_string(&gateway_payload).unwrap();

let message = tokio_tungstenite::tungstenite::Message::text(payload_json);
let message = GatewayMessage(payload_json);

self.websocket_send
.lock()
.await
.send(message)
.send(message.into())
.await
.unwrap();
}
Expand Down
35 changes: 7 additions & 28 deletions src/gateway/heartbeat.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use crate::types;

use super::*;
use crate::types;

/// The amount of time we wait for a heartbeat ack before resending our heartbeat in ms
const HEARTBEAT_ACK_TIMEOUT: u64 = 2000;
Expand All @@ -20,27 +19,14 @@ pub(super) struct HeartbeatHandler {
impl HeartbeatHandler {
pub fn new(
heartbeat_interval: Duration,
websocket_tx: Arc<
Mutex<
SplitSink<
WebSocketStream<MaybeTlsStream<TcpStream>>,
tokio_tungstenite::tungstenite::Message,
>,
>,
>,
websocket_tx: Arc<Mutex<WsSink>>,
kill_rc: tokio::sync::broadcast::Receiver<()>,
) -> HeartbeatHandler {
) -> Self {
let (send, receive) = tokio::sync::mpsc::channel(32);
let kill_receive = kill_rc.resubscribe();

let handle: JoinHandle<()> = task::spawn(async move {
HeartbeatHandler::heartbeat_task(
websocket_tx,
heartbeat_interval,
receive,
kill_receive,
)
.await;
Self::heartbeat_task(websocket_tx, heartbeat_interval, receive, kill_receive).await;
});

Self {
Expand All @@ -55,14 +41,7 @@ impl HeartbeatHandler {
/// Can be killed by the kill broadcast;
/// If the websocket is closed, will die out next time it tries to send a heartbeat;
pub async fn heartbeat_task(
websocket_tx: Arc<
Mutex<
SplitSink<
WebSocketStream<MaybeTlsStream<TcpStream>>,
tokio_tungstenite::tungstenite::Message,
>,
>,
>,
websocket_tx: Arc<Mutex<WsSink>>,
heartbeat_interval: Duration,
mut receive: tokio::sync::mpsc::Receiver<HeartbeatThreadCommunication>,
mut kill_receive: tokio::sync::broadcast::Receiver<()>,
Expand Down Expand Up @@ -122,9 +101,9 @@ impl HeartbeatHandler {

let heartbeat_json = serde_json::to_string(&heartbeat).unwrap();

let msg = tokio_tungstenite::tungstenite::Message::text(heartbeat_json);
let msg = GatewayMessage(heartbeat_json);

let send_result = websocket_tx.lock().await.send(msg).await;
let send_result = websocket_tx.lock().await.send(msg.into()).await;
if send_result.is_err() {
// We couldn't send, the websocket is broken
warn!("GW: Couldnt send heartbeat, websocket seems broken");
Expand Down
36 changes: 3 additions & 33 deletions src/gateway/message.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,24 +5,14 @@
/// Represents a messsage received from the gateway. This will be either a [types::GatewayReceivePayload], containing events, or a [GatewayError].
/// This struct is used internally when handling messages.
#[derive(Clone, Debug)]
pub struct GatewayMessage {
/// The message we received from the server
pub(super) message: tokio_tungstenite::tungstenite::Message,
}
pub struct GatewayMessage(pub String);

impl GatewayMessage {
/// Creates self from a tungstenite message
pub fn from_tungstenite_message(message: tokio_tungstenite::tungstenite::Message) -> Self {
Self { message }
}

/// Parses the message as an error;
/// Returns the error if succesfully parsed, None if the message isn't an error
pub fn error(&self) -> Option<GatewayError> {
let content = self.message.to_string();

// Some error strings have dots on the end, which we don't care about
let processed_content = content.to_lowercase().replace('.', "");
let processed_content = self.0.to_lowercase().replace('.', "");

match processed_content.as_str() {
"unknown error" | "4000" => Some(GatewayError::Unknown),
Expand All @@ -45,29 +35,9 @@
}
}

/// Returns whether or not the message is an error
pub fn is_error(&self) -> bool {
self.error().is_some()
}

/// Parses the message as a payload;
/// Returns a result of deserializing
pub fn payload(&self) -> Result<types::GatewayReceivePayload, serde_json::Error> {
return serde_json::from_str(self.message.to_text().unwrap());
}

/// Returns whether or not the message is a payload
pub fn is_payload(&self) -> bool {
// close messages are never payloads, payloads are only text messages
if self.message.is_close() | !self.message.is_text() {
return false;
}

return self.payload().is_ok();
}

/// Returns whether or not the message is empty
pub fn is_empty(&self) -> bool {
self.message.is_empty()
return serde_json::from_str(&self.0);
Fixed Show fixed Hide fixed
}
}
12 changes: 6 additions & 6 deletions src/gateway/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,13 @@ pub mod handle;
pub mod heartbeat;
pub mod message;

#[cfg(not(wasm))]
pub mod backend_tungstenite;
#[cfg(not(wasm))]
use backend_tungstenite::*;

pub use gateway::*;
pub use handle::*;
pub use handle::GatewayHandle;
use heartbeat::*;
pub use message::*;

Expand All @@ -19,20 +24,15 @@ use std::sync::{Arc, RwLock};
use std::time::Duration;
use tokio::time::sleep_until;

use futures_util::stream::SplitSink;
use futures_util::stream::SplitStream;
use futures_util::SinkExt;
use futures_util::StreamExt;
use log::{info, trace, warn};
use tokio::net::TcpStream;
use tokio::sync::mpsc::Sender;
use tokio::sync::Mutex;
use tokio::task;
use tokio::task::JoinHandle;
use tokio::time;
use tokio::time::Instant;
use tokio_tungstenite::MaybeTlsStream;
use tokio_tungstenite::{connect_async_tls_with_config, Connector, WebSocketStream};

// Gateway opcodes
/// Opcode received when the server dispatches a [crate::types::WebSocketEvent]
Expand Down
Loading