Skip to content

Commit

Permalink
Add a simple WebSocketStream::send method to replace Sink trait u…
Browse files Browse the repository at this point in the history
…sage

And also bump MSRV to 1.64.

Fixes #142
  • Loading branch information
stackinspector authored and sdroege committed Dec 7, 2024
1 parent ce58323 commit bb0b695
Show file tree
Hide file tree
Showing 8 changed files with 122 additions and 27 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ jobs:
strategy:
matrix:
rust:
- 1.63.0
- 1.64.0

steps:
- name: Checkout sources
Expand Down
2 changes: 2 additions & 0 deletions Cargo.lock.msrv

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

14 changes: 11 additions & 3 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,11 @@ version = "0.28.0"
edition = "2018"
readme = "README.md"
include = ["examples/**/*", "src/**/*", "LICENSE", "README.md", "CHANGELOG.md"]
rust-version = "1.63"
rust-version = "1.64"

[features]
default = ["handshake"]
default = ["handshake", "futures-03-sink"]
futures-03-sink = ["futures-util"]
handshake = ["tungstenite/handshake"]
async-std-runtime = ["async-std", "handshake"]
tokio-runtime = ["tokio", "handshake"]
Expand All @@ -37,10 +38,17 @@ features = ["async-std-runtime", "tokio-runtime", "gio-runtime", "async-tls", "a

[dependencies]
log = "0.4"
futures-util = { version = "0.3", default-features = false, features = ["sink", "std"] }
futures-core = { version = "0.3", default-features = false }
atomic-waker = { version = "1.1", default-features = false }
futures-io = { version = "0.3", default-features = false, features = ["std"] }
pin-project-lite = "0.2"

[dependencies.futures-util]
optional = true
version = "0.3"
default-features = false
features = ["sink"]

[dependencies.tungstenite]
version = "0.24"
default-features = false
Expand Down
1 change: 1 addition & 0 deletions examples/autobahn-client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ async fn run_test(case: u32) -> Result<()> {
while let Some(msg) = ws_stream.next().await {
let msg = msg?;
if msg.is_text() || msg.is_binary() {
// for Sink of futures 0.3, see autobahn-server example
ws_stream.send(msg).await?;
}
}
Expand Down
4 changes: 3 additions & 1 deletion examples/autobahn-server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@ async fn handle_connection(peer: SocketAddr, stream: TcpStream) -> Result<()> {
while let Some(msg) = ws_stream.next().await {
let msg = msg?;
if msg.is_text() || msg.is_binary() {
ws_stream.send(msg).await?;
// here we explicitly using futures 0.3's Sink implementation for send message
// for WebSocketStream::send, see autobahn-client example
futures::SinkExt::send(&mut ws_stream, msg).await?;
}
}

Expand Down
2 changes: 1 addition & 1 deletion examples/server-headers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ use async_tungstenite::{
use url::Url;
#[macro_use]
extern crate log;
use futures_util::{SinkExt, StreamExt};
use futures_util::StreamExt;

#[async_std::main]
async fn main() {
Expand Down
26 changes: 14 additions & 12 deletions src/compat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
use log::*;
use std::io::{Read, Write};
use std::pin::Pin;
use std::task::{Context, Poll};
use std::task::{Context, Poll, Wake, Waker};

use atomic_waker::AtomicWaker;
use futures_io::{AsyncRead, AsyncWrite};
use futures_util::task;
use std::sync::Arc;
use tungstenite::Error as WsError;

Expand Down Expand Up @@ -49,18 +49,20 @@ pub(crate) struct AllowStd<S> {
// read waker slot for this, but any would do.
//
// Don't ever use this from multiple tasks at the same time!
#[cfg(feature = "handshake")]
pub(crate) trait SetWaker {
fn set_waker(&self, waker: &task::Waker);
fn set_waker(&self, waker: &Waker);
}

#[cfg(feature = "handshake")]
impl<S> SetWaker for AllowStd<S> {
fn set_waker(&self, waker: &task::Waker) {
fn set_waker(&self, waker: &Waker) {
self.set_waker(ContextWaker::Read, waker);
}
}

impl<S> AllowStd<S> {
pub(crate) fn new(inner: S, waker: &task::Waker) -> Self {
pub(crate) fn new(inner: S, waker: &Waker) -> Self {
let res = Self {
inner,
write_waker_proxy: Default::default(),
Expand All @@ -83,7 +85,7 @@ impl<S> AllowStd<S> {
//
// Write: this is only supposde to be called by write operations, i.e. the Sink impl on the
// WebSocketStream.
pub(crate) fn set_waker(&self, kind: ContextWaker, waker: &task::Waker) {
pub(crate) fn set_waker(&self, kind: ContextWaker, waker: &Waker) {
match kind {
ContextWaker::Read => {
self.write_waker_proxy.read_waker.register(waker);
Expand All @@ -103,11 +105,11 @@ impl<S> AllowStd<S> {
// reads and writes, and the same for writes.
#[derive(Debug, Default)]
struct WakerProxy {
read_waker: task::AtomicWaker,
write_waker: task::AtomicWaker,
read_waker: AtomicWaker,
write_waker: AtomicWaker,
}

impl std::task::Wake for WakerProxy {
impl Wake for WakerProxy {
fn wake(self: Arc<Self>) {
self.wake_by_ref()
}
Expand All @@ -129,10 +131,10 @@ where
#[cfg(feature = "verbose-logging")]
trace!("{}:{} AllowStd.with_context", file!(), line!());
let waker = match kind {
ContextWaker::Read => task::Waker::from(self.read_waker_proxy.clone()),
ContextWaker::Write => task::Waker::from(self.write_waker_proxy.clone()),
ContextWaker::Read => Waker::from(self.read_waker_proxy.clone()),
ContextWaker::Write => Waker::from(self.write_waker_proxy.clone()),
};
let mut context = task::Context::from_waker(&waker);
let mut context = Context::from_waker(&waker);
f(&mut context, Pin::new(&mut self.inner))
}

Expand Down
98 changes: 89 additions & 9 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,17 +58,16 @@ mod handshake;
))]
pub mod stream;

use std::io::{Read, Write};
use std::{
io::{Read, Write},
pin::Pin,
task::{ready, Context, Poll},
};

use compat::{cvt, AllowStd, ContextWaker};
use futures_core::stream::{FusedStream, Stream};
use futures_io::{AsyncRead, AsyncWrite};
use futures_util::{
sink::{Sink, SinkExt},
stream::{FusedStream, Stream},
};
use log::*;
use std::pin::Pin;
use std::task::{Context, Poll};

#[cfg(feature = "handshake")]
use tungstenite::{
Expand Down Expand Up @@ -227,6 +226,7 @@ where
#[derive(Debug)]
pub struct WebSocketStream<S> {
inner: WebSocket<AllowStd<S>>,
#[cfg(feature = "futures-03-sink")]
closing: bool,
ended: bool,
/// Tungstenite is probably ready to receive more data.
Expand Down Expand Up @@ -269,6 +269,7 @@ impl<S> WebSocketStream<S> {
pub(crate) fn new(ws: WebSocket<AllowStd<S>>) -> Self {
Self {
inner: ws,
#[cfg(feature = "futures-03-sink")]
closing: false,
ended: false,
ready: true,
Expand Down Expand Up @@ -337,7 +338,7 @@ where
return Poll::Ready(None);
}

match futures_util::ready!(self.with_context(Some((ContextWaker::Read, cx)), |s| {
match ready!(self.with_context(Some((ContextWaker::Read, cx)), |s| {
#[cfg(feature = "verbose-logging")]
trace!(
"{}:{} Stream.with_context poll_next -> read()",
Expand Down Expand Up @@ -368,7 +369,8 @@ where
}
}

impl<T> Sink<Message> for WebSocketStream<T>
#[cfg(feature = "futures-03-sink")]
impl<T> futures_util::Sink<Message> for WebSocketStream<T>
where
T: AsyncRead + AsyncWrite + Unpin,
{
Expand Down Expand Up @@ -446,6 +448,84 @@ where
}
}

impl<S> WebSocketStream<S> {
/// Simple send method to replace `futures_sink::Sink` (till v0.3).
pub async fn send(&mut self, msg: Message) -> Result<(), WsError>
where
S: AsyncRead + AsyncWrite + Unpin,
{
Send::new(self, msg).await
}
}

struct Send<'a, S> {
ws: &'a mut WebSocketStream<S>,
msg: Option<Message>,
}

impl<'a, S> Send<'a, S>
where
S: AsyncRead + AsyncWrite + Unpin,
{
fn new(ws: &'a mut WebSocketStream<S>, msg: Message) -> Self {
Self { ws, msg: Some(msg) }
}
}

impl<S> std::future::Future for Send<'_, S>
where
S: AsyncRead + AsyncWrite + Unpin,
{
type Output = Result<(), WsError>;

fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
if self.msg.is_some() {
if !self.ws.ready {
// Currently blocked so try to flush the blockage away
let polled = self
.ws
.with_context(Some((ContextWaker::Write, cx)), |s| cvt(s.flush()))
.map(|r| {
self.ws.ready = true;
r
});
ready!(polled)?
}

let msg = self.msg.take().expect("unreachable");
match self.ws.with_context(None, |s| s.write(msg)) {
Ok(_) => Ok(()),
Err(WsError::Io(err)) if err.kind() == std::io::ErrorKind::WouldBlock => {
// the message was accepted and queued so not an error
//
// set to false here for cancellation safety of *this* Future
self.ws.ready = false;
Ok(())
}
Err(e) => {
debug!("websocket start_send error: {}", e);
Err(e)
}
}?;
}

let polled = self
.ws
.with_context(Some((ContextWaker::Write, cx)), |s| cvt(s.flush()))
.map(|r| {
self.ws.ready = true;
match r {
// WebSocket connection has just been closed. Flushing completed, not an error.
Err(WsError::ConnectionClosed) => Ok(()),
other => other,
}
});
ready!(polled)?;

Poll::Ready(Ok(()))
}
}

#[cfg(any(
feature = "async-tls",
feature = "async-std-runtime",
Expand Down

0 comments on commit bb0b695

Please sign in to comment.