From d382efa5a2f3c43c04e1dcc48bbc354f94e67e8e Mon Sep 17 00:00:00 2001 From: "Carson M." Date: Tue, 19 Nov 2024 14:26:50 -0600 Subject: [PATCH 1/3] fix: add lifetime bound to `YieldFut` --- README.md | 4 +- src/lib.rs | 45 ++++++++++++---------- src/tests.rs | 104 +++++++++++++++++++++++++-------------------------- src/try.rs | 22 +++++------ 4 files changed, 91 insertions(+), 84 deletions(-) diff --git a/README.md b/README.md index fc2ae3e..d815bfd 100644 --- a/README.md +++ b/README.md @@ -8,9 +8,9 @@ use futures_util::{pin_mut, stream::StreamExt}; #[tokio::main] async fn main() { - let stream = async_stream(|r#yield| async move { + let stream = async_stream(|yielder| async move { for i in 0..3 { - r#yield(i).await; + yielder.r#yield(i).await; } }); pin_mut!(stream); diff --git a/src/lib.rs b/src/lib.rs index 6abfb60..55b7c45 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -27,21 +27,28 @@ thread_local! { #[thread_local] static STORE: Cell<*mut ()> = Cell::new(ptr::null_mut()); -pub(crate) fn r#yield(value: T) -> YieldFut { - YieldFut { value: Some(value) } +pub struct Yielder { + _p: PhantomData +} + +impl Yielder { + pub fn r#yield(&self, value: T) -> YieldFut<'_, T> { + YieldFut { value: Some(value), _p: PhantomData } + } } /// Future returned by an [`AsyncStream`]'s yield function. /// /// This future must be `.await`ed inside the generator in order for the item to be yielded by the stream. #[must_use = "stream will not yield this item unless the future returned by yield is awaited"] -pub struct YieldFut { - value: Option +pub struct YieldFut<'y, T> { + value: Option, + _p: PhantomData<&'y ()> } -impl Unpin for YieldFut {} +impl Unpin for YieldFut<'_, T> {} -impl Future for YieldFut { +impl Future for YieldFut<'_, T> { type Output = (); fn poll(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll { @@ -153,16 +160,16 @@ where /// Create an asynchronous [`Stream`] from an asynchronous generator function. /// -/// The provided function will be given a "yielder" function, which, when called, causes the stream to yield an item: +/// The provided function will be given a [`Yielder`], which, when called, causes the stream to yield an item: /// ``` /// use async_stream_lite::async_stream; /// use futures::{pin_mut, stream::StreamExt}; /// /// #[tokio::main] /// async fn main() { -/// let stream = async_stream(|r#yield| async move { +/// let stream = async_stream(|yielder| async move { /// for i in 0..3 { -/// r#yield(i).await; +/// yielder.r#yield(i).await; /// } /// }); /// pin_mut!(stream); @@ -181,9 +188,9 @@ where /// }; /// /// fn zero_to_three() -> impl Stream { -/// async_stream(|r#yield| async move { +/// async_stream(|yielder| async move { /// for i in 0..3 { -/// r#yield(i).await; +/// yielder.r#yield(i).await; /// } /// }) /// } @@ -207,9 +214,9 @@ where /// }; /// /// fn zero_to_three() -> BoxStream<'static, u32> { -/// Box::pin(async_stream(|r#yield| async move { +/// Box::pin(async_stream(|yielder| async move { /// for i in 0..3 { -/// r#yield(i).await; +/// yielder.r#yield(i).await; /// } /// })) /// } @@ -232,18 +239,18 @@ where /// }; /// /// fn zero_to_three() -> impl Stream { -/// async_stream(|r#yield| async move { +/// async_stream(|yielder| async move { /// for i in 0..3 { -/// r#yield(i).await; +/// yielder.r#yield(i).await; /// } /// }) /// } /// /// fn double>(input: S) -> impl Stream { -/// async_stream(|r#yield| async move { +/// async_stream(|yielder| async move { /// pin_mut!(input); /// while let Some(value) = input.next().await { -/// r#yield(value * 2).await; +/// yielder.r#yield(value * 2).await; /// } /// }) /// } @@ -261,10 +268,10 @@ where /// See also [`try_async_stream`], a variant of [`async_stream`] which supports try notation (`?`). pub fn async_stream(generator: F) -> AsyncStream where - F: FnOnce(fn(value: T) -> YieldFut) -> U, + F: FnOnce(Yielder) -> U, U: Future { - let generator = generator(r#yield::); + let generator = generator(Yielder { _p: PhantomData }); AsyncStream { _p: PhantomData, done: false, diff --git a/src/tests.rs b/src/tests.rs index 1e6c369..8180199 100644 --- a/src/tests.rs +++ b/src/tests.rs @@ -6,11 +6,11 @@ use futures::{ }; use tokio::sync::mpsc; -use super::{YieldFut, async_stream}; +use super::{Yielder, async_stream}; #[tokio::test] async fn noop_stream() { - let s = async_stream(|_yield: fn(()) -> YieldFut<()>| async move {}); + let s = async_stream(|_yielder: Yielder<()>| async move {}); pin_mut!(s); #[allow(clippy::never_loop)] @@ -25,7 +25,7 @@ async fn empty_stream() { { let r = &mut ran; - let s = async_stream(|_yield: fn(()) -> YieldFut<()>| async move { + let s = async_stream(|_yield: Yielder<()>| async move { *r = true; println!("hello world!"); }); @@ -42,8 +42,8 @@ async fn empty_stream() { #[tokio::test] async fn yield_single_value() { - let s = async_stream(|r#yield| async move { - r#yield("hello").await; + let s = async_stream(|yielder| async move { + yielder.r#yield("hello").await; }); let values: Vec<_> = s.collect().await; @@ -54,8 +54,8 @@ async fn yield_single_value() { #[tokio::test] async fn fused() { - let s = async_stream(|r#yield| async move { - r#yield("hello").await; + let s = async_stream(|yielder| async move { + yielder.r#yield("hello").await; }); pin_mut!(s); @@ -69,10 +69,10 @@ async fn fused() { #[tokio::test] async fn yield_multi_value() { - let stream = async_stream(|r#yield| async move { - r#yield("hello").await; - r#yield("world").await; - r#yield("foobar").await; + let stream = async_stream(|yielder| async move { + yielder.r#yield("hello").await; + yielder.r#yield("world").await; + yielder.r#yield("foobar").await; }); let values: Vec<_> = stream.collect().await; @@ -88,10 +88,10 @@ async fn unit_yield_in_select() { #[allow(clippy::unused_async)] async fn do_stuff_async() {} - let stream = async_stream(|r#yield| async move { + let stream = async_stream(|yielder| async move { tokio::select! { - () = do_stuff_async() => r#yield(()).await, - else => r#yield(()).await + () = do_stuff_async() => yielder.r#yield(()).await, + else => yielder.r#yield(()).await } }); @@ -105,11 +105,11 @@ async fn yield_with_select() { async fn do_stuff_async() {} async fn more_async_work() {} - let stream = async_stream(|r#yield| async move { + let stream = async_stream(|yielder| async move { tokio::select! { - () = do_stuff_async() => r#yield("hey").await, - () = more_async_work() => r#yield("hey").await, - else => r#yield("hey").await + () = do_stuff_async() => yielder.r#yield("hey").await, + () = more_async_work() => yielder.r#yield("hey").await, + else => yielder.r#yield("hey").await } }); @@ -120,10 +120,10 @@ async fn yield_with_select() { #[tokio::test] async fn return_stream() { fn build_stream() -> impl Stream { - async_stream(|r#yield| async move { - r#yield(1).await; - r#yield(2).await; - r#yield(3).await; + async_stream(|yielder| async move { + yielder.r#yield(1).await; + yielder.r#yield(2).await; + yielder.r#yield(3).await; }) } @@ -139,10 +139,10 @@ async fn return_stream() { #[tokio::test] async fn boxed_stream() { fn build_stream() -> BoxStream<'static, u32> { - Box::pin(async_stream(|r#yield| async move { - r#yield(1).await; - r#yield(2).await; - r#yield(3).await; + Box::pin(async_stream(|yielder| async move { + yielder.r#yield(1).await; + yielder.r#yield(2).await; + yielder.r#yield(3).await; })) } @@ -159,9 +159,9 @@ async fn boxed_stream() { async fn consume_channel() { let (tx, mut rx) = mpsc::channel(10); - let stream = async_stream(|r#yield| async move { + let stream = async_stream(|yielder| async move { while let Some(v) = rx.recv().await { - r#yield(v).await; + yielder.r#yield(v).await; } }); @@ -182,8 +182,8 @@ async fn borrow_self() { impl Data { fn stream(&self) -> impl Stream + '_ { - async_stream(|r#yield| async move { - r#yield(&self.0[..]).await; + async_stream(|yielder| async move { + yielder.r#yield(&self.0[..]).await; }) } } @@ -201,8 +201,8 @@ async fn borrow_self_boxed() { impl Data { fn stream(&self) -> BoxStream<'_, &str> { - Box::pin(async_stream(|r#yield| async move { - r#yield(&self.0[..]).await; + Box::pin(async_stream(|yielder| async move { + yielder.r#yield(&self.0[..]).await; })) } } @@ -216,16 +216,16 @@ async fn borrow_self_boxed() { #[tokio::test] async fn stream_in_stream() { - let s = async_stream(|r#yield| async move { - let s = async_stream(|r#yield| async move { + let s = async_stream(|yielder| async move { + let s = async_stream(|yielder| async move { for i in 0..3 { - r#yield(i).await; + yielder.r#yield(i).await; } }); pin_mut!(s); while let Some(v) = s.next().await { - r#yield(v).await; + yielder.r#yield(v).await; } }); @@ -235,37 +235,37 @@ async fn stream_in_stream() { #[tokio::test] async fn streamception() { - let s = async_stream(|r#yield| async move { - let s = async_stream(|r#yield| async move { - let s = async_stream(|r#yield| async move { - let s = async_stream(|r#yield| async move { - let s = async_stream(|r#yield| async move { + let s = async_stream(|yielder| async move { + let s = async_stream(|yielder| async move { + let s = async_stream(|yielder| async move { + let s = async_stream(|yielder| async move { + let s = async_stream(|yielder| async move { for i in 0..3 { - r#yield(i).await; + yielder.r#yield(i).await; } }); pin_mut!(s); while let Some(v) = s.next().await { - r#yield(v).await; + yielder.r#yield(v).await; } }); pin_mut!(s); while let Some(v) = s.next().await { - r#yield(v).await; + yielder.r#yield(v).await; } }); pin_mut!(s); while let Some(v) = s.next().await { - r#yield(v).await; + yielder.r#yield(v).await; } }); pin_mut!(s); while let Some(v) = s.next().await { - r#yield(v).await; + yielder.r#yield(v).await; } }); @@ -275,9 +275,9 @@ async fn streamception() { #[tokio::test] async fn yield_non_unpin_value() { - let s: Vec<_> = async_stream(|r#yield| async move { + let s: Vec<_> = async_stream(|yielder| async move { for i in 0..3 { - r#yield(async move { i }).await; + yielder.r#yield(async move { i }).await; } }) .buffered(1) @@ -290,10 +290,10 @@ async fn yield_non_unpin_value() { #[tokio::test(flavor = "multi_thread", worker_threads = 4)] async fn multithreaded() { fn build_stream() -> impl Stream { - async_stream(|r#yield| async move { - r#yield(1).await; - r#yield(2).await; - r#yield(3).await; + async_stream(|yielder| async move { + yielder.r#yield(1).await; + yielder.r#yield(2).await; + yielder.r#yield(3).await; }) } diff --git a/src/try.rs b/src/try.rs index fe54f4b..162c96a 100644 --- a/src/try.rs +++ b/src/try.rs @@ -7,7 +7,7 @@ use core::{ use futures_core::stream::{FusedStream, Stream}; -use crate::{YieldFut, enter, r#yield}; +use crate::{Yielder, enter}; pin_project_lite::pin_project! { /// A [`Stream`] created from a fallible, asynchronous generator-like function. @@ -78,12 +78,12 @@ where /// use tokio::net::{TcpListener, TcpStream}; /// /// fn bind_and_accept(addr: SocketAddr) -> impl Stream> { -/// try_async_stream(|r#yield| async move { +/// try_async_stream(|yielder| async move { /// let mut listener = TcpListener::bind(addr).await?; /// loop { /// let (stream, addr) = listener.accept().await?; /// println!("received on {addr:?}"); -/// r#yield(stream).await; +/// yielder.r#yield(stream).await; /// } /// }) /// } @@ -93,10 +93,10 @@ where /// error is encountered, the stream yields `Err(E)` and is subsequently terminated. pub fn try_async_stream(generator: F) -> TryAsyncStream where - F: FnOnce(fn(T) -> YieldFut) -> U, + F: FnOnce(Yielder) -> U, U: Future> { - let generator = generator(r#yield::); + let generator = generator(Yielder { _p: PhantomData }); TryAsyncStream { _p: PhantomData, done: false, @@ -112,11 +112,11 @@ mod tests { #[tokio::test] async fn single_err() { - let s = try_async_stream(|r#yield| async move { + let s = try_async_stream(|yielder| async move { if true { Err("hello")?; } else { - r#yield("world").await; + yielder.r#yield("world").await; } Ok(()) }); @@ -128,8 +128,8 @@ mod tests { #[tokio::test] async fn yield_then_err() { - let s = try_async_stream(|r#yield| async move { - r#yield("hello").await; + let s = try_async_stream(|yielder| async move { + yielder.r#yield("hello").await; Err("world")?; unreachable!(); }); @@ -152,13 +152,13 @@ mod tests { } fn test() -> impl Stream> { - try_async_stream(|r#yield| async move { + try_async_stream(|yielder| async move { if true { Err(ErrorA(1))?; } else { Err(ErrorB(2))?; } - r#yield("unreachable").await; + yielder.r#yield("unreachable").await; Ok(()) }) } From 6b98606a4cdaee6e00a53167b1969eca7627faa4 Mon Sep 17 00:00:00 2001 From: "Carson M." Date: Tue, 3 Dec 2024 23:57:11 -0600 Subject: [PATCH 2/3] fix: remove thread local --- Cargo.toml | 1 - README.md | 2 +- src/lib.rs | 119 ++++++++++++++++++++++++--------------------------- src/tests.rs | 15 +++++++ src/try.rs | 24 ++++++----- 5 files changed, 86 insertions(+), 75 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index cb12f57..75c3122 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,4 +17,3 @@ tokio = { version = "1", features = [ "full", "macros" ] } futures = "0.3" [features] -unstable-thread-local = [] diff --git a/README.md b/README.md index d815bfd..45ccb48 100644 --- a/README.md +++ b/README.md @@ -21,4 +21,4 @@ async fn main() { ``` ## `#![no_std]` support -`async-stream-lite` supports `#![no_std]` on nightly Rust (due to the usage of [the unstable `#[thread_local]` attribute](https://doc.rust-lang.org/beta/unstable-book/language-features/thread-local.html)). To enable `#![no_std]` support, enable the `unstable-thread-local` feature. +`async-stream-lite` supports `#![no_std]`, but requires `alloc`. diff --git a/src/lib.rs b/src/lib.rs index 55b7c45..a409c70 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,15 +1,16 @@ #![allow(clippy::tabs_in_doc_comments)] -#![cfg_attr(feature = "unstable-thread-local", feature(thread_local))] -#![cfg_attr(all(not(test), feature = "unstable-thread-local"), no_std)] +#![cfg_attr(not(test), no_std)] +extern crate alloc; extern crate core; +use alloc::sync::{Arc, Weak}; use core::{ cell::Cell, future::Future, marker::PhantomData, pin::Pin, - ptr, + sync::atomic::{AtomicBool, Ordering}, task::{Context, Poll} }; @@ -19,21 +20,49 @@ use futures_core::stream::{FusedStream, Stream}; mod tests; mod r#try; -#[cfg(not(feature = "unstable-thread-local"))] -thread_local! { - static STORE: Cell<*mut ()> = const { Cell::new(ptr::null_mut()) }; +pub(crate) struct SharedStore { + entered: AtomicBool, + cell: Cell> } -#[cfg(feature = "unstable-thread-local")] -#[thread_local] -static STORE: Cell<*mut ()> = Cell::new(ptr::null_mut()); + +impl Default for SharedStore { + fn default() -> Self { + Self { + entered: AtomicBool::new(false), + cell: Cell::new(None) + } + } +} + +impl SharedStore { + pub fn has_value(&self) -> bool { + unsafe { &*self.cell.as_ptr() }.is_some() + } +} + +unsafe impl Sync for SharedStore {} pub struct Yielder { - _p: PhantomData + pub(crate) store: Weak> } impl Yielder { pub fn r#yield(&self, value: T) -> YieldFut<'_, T> { - YieldFut { value: Some(value), _p: PhantomData } + #[cold] + fn invalid_usage() -> ! { + panic!("attempted to use async_stream_lite yielder outside of stream context or across threads") + } + + let Some(store) = self.store.upgrade() else { + invalid_usage(); + }; + if !store.entered.load(Ordering::Relaxed) { + invalid_usage(); + } + + store.cell.replace(Some(value)); + + YieldFut { store, _p: PhantomData } } } @@ -42,7 +71,7 @@ impl Yielder { /// This future must be `.await`ed inside the generator in order for the item to be yielded by the stream. #[must_use = "stream will not yield this item unless the future returned by yield is awaited"] pub struct YieldFut<'y, T> { - value: Option, + store: Arc>, _p: PhantomData<&'y ()> } @@ -51,56 +80,27 @@ impl Unpin for YieldFut<'_, T> {} impl Future for YieldFut<'_, T> { type Output = (); - fn poll(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll { - if self.value.is_none() { + fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll { + if !self.store.has_value() { return Poll::Ready(()); } - fn op(cell: &Cell<*mut ()>, value: &mut Option) { - let ptr = cell.get().cast::>(); - let option_ref = unsafe { ptr.as_mut() }.expect("attempted to use async_stream yielder outside of stream context or across threads"); - if option_ref.is_none() { - *option_ref = value.take(); - } - } - - #[cfg(not(feature = "unstable-thread-local"))] - return STORE.with(|cell| { - op(cell, &mut self.value); - Poll::Pending - }); - #[cfg(feature = "unstable-thread-local")] - { - op(&STORE, &mut self.value); - Poll::Pending - } + Poll::Pending } } -struct Enter<'a, T> { - _p: PhantomData<&'a T>, - prev: *mut () +struct Enter<'s, T> { + store: &'s SharedStore } -fn enter(dst: &mut Option) -> Enter<'_, T> { - fn op(cell: &Cell<*mut ()>, dst: &mut Option) -> *mut () { - let prev = cell.get(); - cell.set((dst as *mut Option).cast::<()>()); - prev - } - #[cfg(not(feature = "unstable-thread-local"))] - let prev = STORE.with(|cell| op(cell, dst)); - #[cfg(feature = "unstable-thread-local")] - let prev = op(&STORE, dst); - Enter { _p: PhantomData, prev } +fn enter(store: &SharedStore) -> Enter<'_, T> { + store.entered.store(true, Ordering::Relaxed); + Enter { store } } impl Drop for Enter<'_, T> { fn drop(&mut self) { - #[cfg(not(feature = "unstable-thread-local"))] - STORE.with(|cell| cell.set(self.prev)); - #[cfg(feature = "unstable-thread-local")] - STORE.set(self.prev); + self.store.entered.store(false, Ordering::Relaxed); } } @@ -108,9 +108,8 @@ pin_project_lite::pin_project! { /// A [`Stream`] created from an asynchronous generator-like function. /// /// To create an [`AsyncStream`], use the [`async_stream`] function. - #[derive(Debug)] pub struct AsyncStream { - _p: PhantomData, + store: Arc>, done: bool, #[pin] generator: U @@ -138,16 +137,15 @@ where return Poll::Ready(None); } - let mut dst = None; let res = { - let _enter = enter(&mut dst); + let _enter = enter(&me.store); me.generator.poll(cx) }; *me.done = res.is_ready(); - if dst.is_some() { - return Poll::Ready(dst.take()); + if me.store.has_value() { + return Poll::Ready(me.store.cell.take()); } if *me.done { Poll::Ready(None) } else { Poll::Pending } @@ -271,12 +269,9 @@ where F: FnOnce(Yielder) -> U, U: Future { - let generator = generator(Yielder { _p: PhantomData }); - AsyncStream { - _p: PhantomData, - done: false, - generator - } + let store = Arc::new(SharedStore::default()); + let generator = generator(Yielder { store: Arc::downgrade(&store) }); + AsyncStream { store, done: false, generator } } pub use self::r#try::{TryAsyncStream, try_async_stream}; diff --git a/src/tests.rs b/src/tests.rs index 8180199..9087c62 100644 --- a/src/tests.rs +++ b/src/tests.rs @@ -310,3 +310,18 @@ async fn multithreaded() { } join_all(futures).await; } + +#[tokio::test] +#[should_panic = "attempted to use async_stream_lite yielder outside of stream context or across threads"] +async fn test_move_yielder() { + let mut slot = None; + let s = async_stream(|yielder: Yielder<()>| async { + slot.replace(yielder); + }); + pin_mut!(s); + + let _ = s.next().await; + drop(s); + + slot.take().unwrap().r#yield(()).await; +} diff --git a/src/try.rs b/src/try.rs index 162c96a..9c29040 100644 --- a/src/try.rs +++ b/src/try.rs @@ -1,3 +1,4 @@ +use alloc::sync::Arc; use core::{ future::Future, marker::PhantomData, @@ -7,18 +8,18 @@ use core::{ use futures_core::stream::{FusedStream, Stream}; -use crate::{Yielder, enter}; +use crate::{SharedStore, Yielder, enter}; pin_project_lite::pin_project! { /// A [`Stream`] created from a fallible, asynchronous generator-like function. /// /// To create a [`TryAsyncStream`], use the [`try_async_stream`] function. See also [`crate::AsyncStream`]. - #[derive(Debug)] pub struct TryAsyncStream { - _p: PhantomData<(T, E)>, + store: Arc>, done: bool, #[pin] - generator: U + generator: U, + _p: PhantomData } } @@ -43,9 +44,8 @@ where return Poll::Ready(None); } - let mut dst = None; let res = { - let _enter = enter(&mut dst); + let _enter = enter(&me.store); me.generator.poll(cx) }; @@ -53,8 +53,8 @@ where if let Poll::Ready(Err(e)) = res { return Poll::Ready(Some(Err(e))); - } else if dst.is_some() { - return Poll::Ready(dst.take().map(Ok)); + } else if me.store.has_value() { + return Poll::Ready(me.store.cell.take().map(Ok)); } if *me.done { Poll::Ready(None) } else { Poll::Pending } @@ -96,11 +96,13 @@ where F: FnOnce(Yielder) -> U, U: Future> { - let generator = generator(Yielder { _p: PhantomData }); + let store = Arc::new(SharedStore::default()); + let generator = generator(Yielder { store: Arc::downgrade(&store) }); TryAsyncStream { - _p: PhantomData, + store, done: false, - generator + generator, + _p: PhantomData } } From fd26250e56c3ab04cb89708511a924c67f26ac94 Mon Sep 17 00:00:00 2001 From: "Carson M." Date: Tue, 3 Dec 2024 23:58:02 -0600 Subject: [PATCH 3/3] fix CI --- .github/workflows/test.yml | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index f89632f..fd68a5d 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -28,15 +28,3 @@ jobs: - uses: Swatinem/rust-cache@v1 - name: Run tests run: cargo test - test-no-std: - name: Run tests (unstable-thread-local) - runs-on: ubuntu-24.04 - steps: - - uses: actions/checkout@v4 - - name: Install nightly Rust toolchain - uses: dtolnay/rust-toolchain@nightly - with: - toolchain: nightly-2024-11-11 - - uses: Swatinem/rust-cache@v1 - - name: Run tests - run: cargo test --features unstable-thread-local