Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a simple WebSocketStream::send method to replace Sink trait u… #144

Merged
merged 1 commit into from
Dec 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ jobs:
strategy:
matrix:
rust:
- 1.63.0
- 1.64.0

steps:
- name: Checkout sources
Expand Down
2 changes: 2 additions & 0 deletions Cargo.lock.msrv

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

20 changes: 14 additions & 6 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,11 @@ version = "0.28.0"
edition = "2018"
readme = "README.md"
include = ["examples/**/*", "src/**/*", "LICENSE", "README.md", "CHANGELOG.md"]
rust-version = "1.63"
rust-version = "1.64"

[features]
default = ["handshake"]
default = ["handshake", "futures-03-sink"]
futures-03-sink = ["futures-util"]
handshake = ["tungstenite/handshake"]
async-std-runtime = ["async-std", "handshake"]
tokio-runtime = ["tokio", "handshake"]
Expand All @@ -37,10 +38,17 @@ features = ["async-std-runtime", "tokio-runtime", "gio-runtime", "async-tls", "a

[dependencies]
log = "0.4"
futures-util = { version = "0.3", default-features = false, features = ["sink", "std"] }
futures-core = { version = "0.3", default-features = false }
atomic-waker = { version = "1.1", default-features = false }
futures-io = { version = "0.3", default-features = false, features = ["std"] }
pin-project-lite = "0.2"

[dependencies.futures-util]
optional = true
version = "0.3"
default-features = false
features = ["sink"]

[dependencies.tungstenite]
version = "0.24"
default-features = false
Expand Down Expand Up @@ -141,7 +149,7 @@ required-features = ["async-std-runtime"]

[[example]]
name = "autobahn-server"
required-features = ["async-std-runtime"]
required-features = ["async-std-runtime", "futures-03-sink"]

[[example]]
name = "server"
Expand All @@ -153,7 +161,7 @@ required-features = ["async-std-runtime"]

[[example]]
name = "server-headers"
required-features = ["async-std-runtime", "handshake"]
required-features = ["async-std-runtime", "handshake", "futures-util"]

[[example]]
name = "interval-server"
Expand All @@ -173,4 +181,4 @@ required-features = ["tokio-runtime"]

[[example]]
name = "server-custom-accept"
required-features = ["tokio-runtime"]
required-features = ["tokio-runtime", "futures-util"]
1 change: 1 addition & 0 deletions examples/autobahn-client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ async fn run_test(case: u32) -> Result<()> {
while let Some(msg) = ws_stream.next().await {
let msg = msg?;
if msg.is_text() || msg.is_binary() {
// for Sink of futures 0.3, see autobahn-server example
ws_stream.send(msg).await?;
}
}
Expand Down
4 changes: 3 additions & 1 deletion examples/autobahn-server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@ async fn handle_connection(peer: SocketAddr, stream: TcpStream) -> Result<()> {
while let Some(msg) = ws_stream.next().await {
let msg = msg?;
if msg.is_text() || msg.is_binary() {
ws_stream.send(msg).await?;
// here we explicitly using futures 0.3's Sink implementation for send message
// for WebSocketStream::send, see autobahn-client example
futures::SinkExt::send(&mut ws_stream, msg).await?;
}
}

Expand Down
2 changes: 1 addition & 1 deletion examples/server-headers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ use async_tungstenite::{
use url::Url;
#[macro_use]
extern crate log;
use futures_util::{SinkExt, StreamExt};
use futures_util::StreamExt;

#[async_std::main]
async fn main() {
Expand Down
26 changes: 14 additions & 12 deletions src/compat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
use log::*;
use std::io::{Read, Write};
use std::pin::Pin;
use std::task::{Context, Poll};
use std::task::{Context, Poll, Wake, Waker};

use atomic_waker::AtomicWaker;
use futures_io::{AsyncRead, AsyncWrite};
use futures_util::task;
use std::sync::Arc;
use tungstenite::Error as WsError;

Expand Down Expand Up @@ -49,18 +49,20 @@ pub(crate) struct AllowStd<S> {
// read waker slot for this, but any would do.
//
// Don't ever use this from multiple tasks at the same time!
#[cfg(feature = "handshake")]
pub(crate) trait SetWaker {
fn set_waker(&self, waker: &task::Waker);
fn set_waker(&self, waker: &Waker);
}

#[cfg(feature = "handshake")]
impl<S> SetWaker for AllowStd<S> {
fn set_waker(&self, waker: &task::Waker) {
fn set_waker(&self, waker: &Waker) {
self.set_waker(ContextWaker::Read, waker);
}
}

impl<S> AllowStd<S> {
pub(crate) fn new(inner: S, waker: &task::Waker) -> Self {
pub(crate) fn new(inner: S, waker: &Waker) -> Self {
let res = Self {
inner,
write_waker_proxy: Default::default(),
Expand All @@ -83,7 +85,7 @@ impl<S> AllowStd<S> {
//
// Write: this is only supposde to be called by write operations, i.e. the Sink impl on the
// WebSocketStream.
pub(crate) fn set_waker(&self, kind: ContextWaker, waker: &task::Waker) {
pub(crate) fn set_waker(&self, kind: ContextWaker, waker: &Waker) {
match kind {
ContextWaker::Read => {
self.write_waker_proxy.read_waker.register(waker);
Expand All @@ -103,11 +105,11 @@ impl<S> AllowStd<S> {
// reads and writes, and the same for writes.
#[derive(Debug, Default)]
struct WakerProxy {
read_waker: task::AtomicWaker,
write_waker: task::AtomicWaker,
read_waker: AtomicWaker,
write_waker: AtomicWaker,
}

impl std::task::Wake for WakerProxy {
impl Wake for WakerProxy {
fn wake(self: Arc<Self>) {
self.wake_by_ref()
}
Expand All @@ -129,10 +131,10 @@ where
#[cfg(feature = "verbose-logging")]
trace!("{}:{} AllowStd.with_context", file!(), line!());
let waker = match kind {
ContextWaker::Read => task::Waker::from(self.read_waker_proxy.clone()),
ContextWaker::Write => task::Waker::from(self.write_waker_proxy.clone()),
ContextWaker::Read => Waker::from(self.read_waker_proxy.clone()),
ContextWaker::Write => Waker::from(self.write_waker_proxy.clone()),
};
let mut context = task::Context::from_waker(&waker);
let mut context = Context::from_waker(&waker);
f(&mut context, Pin::new(&mut self.inner))
}

Expand Down
98 changes: 89 additions & 9 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,17 +58,16 @@ mod handshake;
))]
pub mod stream;

use std::io::{Read, Write};
use std::{
io::{Read, Write},
pin::Pin,
task::{ready, Context, Poll},
};

use compat::{cvt, AllowStd, ContextWaker};
use futures_core::stream::{FusedStream, Stream};
use futures_io::{AsyncRead, AsyncWrite};
use futures_util::{
sink::{Sink, SinkExt},
stream::{FusedStream, Stream},
};
use log::*;
use std::pin::Pin;
use std::task::{Context, Poll};

#[cfg(feature = "handshake")]
use tungstenite::{
Expand Down Expand Up @@ -227,6 +226,7 @@ where
#[derive(Debug)]
pub struct WebSocketStream<S> {
inner: WebSocket<AllowStd<S>>,
#[cfg(feature = "futures-03-sink")]
closing: bool,
ended: bool,
/// Tungstenite is probably ready to receive more data.
Expand Down Expand Up @@ -269,6 +269,7 @@ impl<S> WebSocketStream<S> {
pub(crate) fn new(ws: WebSocket<AllowStd<S>>) -> Self {
Self {
inner: ws,
#[cfg(feature = "futures-03-sink")]
closing: false,
ended: false,
ready: true,
Expand Down Expand Up @@ -337,7 +338,7 @@ where
return Poll::Ready(None);
}

match futures_util::ready!(self.with_context(Some((ContextWaker::Read, cx)), |s| {
match ready!(self.with_context(Some((ContextWaker::Read, cx)), |s| {
#[cfg(feature = "verbose-logging")]
trace!(
"{}:{} Stream.with_context poll_next -> read()",
Expand Down Expand Up @@ -368,7 +369,8 @@ where
}
}

impl<T> Sink<Message> for WebSocketStream<T>
#[cfg(feature = "futures-03-sink")]
impl<T> futures_util::Sink<Message> for WebSocketStream<T>
where
T: AsyncRead + AsyncWrite + Unpin,
{
Expand Down Expand Up @@ -446,6 +448,84 @@ where
}
}

impl<S> WebSocketStream<S> {
/// Simple send method to replace `futures_sink::Sink` (till v0.3).
pub async fn send(&mut self, msg: Message) -> Result<(), WsError>
where
S: AsyncRead + AsyncWrite + Unpin,
{
Send::new(self, msg).await
}
}

struct Send<'a, S> {
ws: &'a mut WebSocketStream<S>,
msg: Option<Message>,
}

impl<'a, S> Send<'a, S>
where
S: AsyncRead + AsyncWrite + Unpin,
{
fn new(ws: &'a mut WebSocketStream<S>, msg: Message) -> Self {
Self { ws, msg: Some(msg) }
}
}

impl<S> std::future::Future for Send<'_, S>
where
S: AsyncRead + AsyncWrite + Unpin,
{
type Output = Result<(), WsError>;

fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
if self.msg.is_some() {
if !self.ws.ready {
// Currently blocked so try to flush the blockage away
let polled = self
.ws
.with_context(Some((ContextWaker::Write, cx)), |s| cvt(s.flush()))
.map(|r| {
self.ws.ready = true;
r
});
ready!(polled)?
}

let msg = self.msg.take().expect("unreachable");
match self.ws.with_context(None, |s| s.write(msg)) {
Ok(_) => Ok(()),
Err(WsError::Io(err)) if err.kind() == std::io::ErrorKind::WouldBlock => {
// the message was accepted and queued so not an error
//
// set to false here for cancellation safety of *this* Future
self.ws.ready = false;
Ok(())
}
Err(e) => {
debug!("websocket start_send error: {}", e);
Err(e)
}
}?;
}

let polled = self
.ws
.with_context(Some((ContextWaker::Write, cx)), |s| cvt(s.flush()))
.map(|r| {
self.ws.ready = true;
match r {
// WebSocket connection has just been closed. Flushing completed, not an error.
Err(WsError::ConnectionClosed) => Ok(()),
other => other,
}
});
ready!(polled)?;

Poll::Ready(Ok(()))
}
}

#[cfg(any(
feature = "async-tls",
feature = "async-std-runtime",
Expand Down
Loading