Skip to content
This repository has been archived by the owner on Nov 19, 2024. It is now read-only.

Fix unsoundness outlined in #1 #2

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@ use futures_util::{pin_mut, stream::StreamExt};

#[tokio::main]
async fn main() {
let stream = async_stream(|r#yield| async move {
let stream = async_stream(|yielder| async move {
for i in 0..3 {
r#yield(i).await;
yielder.r#yield(i).await;
}
});
pin_mut!(stream);
Expand Down
45 changes: 26 additions & 19 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,21 +27,28 @@ thread_local! {
#[thread_local]
static STORE: Cell<*mut ()> = Cell::new(ptr::null_mut());

pub(crate) fn r#yield<T>(value: T) -> YieldFut<T> {
YieldFut { value: Some(value) }
pub struct Yielder<T> {
_p: PhantomData<T>
}

impl<T> Yielder<T> {
pub fn r#yield(&self, value: T) -> YieldFut<'_, T> {
YieldFut { value: Some(value), _p: PhantomData }
}
}

/// Future returned by an [`AsyncStream`]'s yield function.
///
/// This future must be `.await`ed inside the generator in order for the item to be yielded by the stream.
#[must_use = "stream will not yield this item unless the future returned by yield is awaited"]
pub struct YieldFut<T> {
value: Option<T>
pub struct YieldFut<'y, T> {
value: Option<T>,
_p: PhantomData<&'y ()>
}

impl<T> Unpin for YieldFut<T> {}
impl<T> Unpin for YieldFut<'_, T> {}

impl<T> Future for YieldFut<T> {
impl<T> Future for YieldFut<'_, T> {
type Output = ();

fn poll(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Self::Output> {
Expand Down Expand Up @@ -153,16 +160,16 @@ where

/// Create an asynchronous [`Stream`] from an asynchronous generator function.
///
/// The provided function will be given a "yielder" function, which, when called, causes the stream to yield an item:
/// The provided function will be given a [`Yielder`], which, when called, causes the stream to yield an item:
/// ```
/// use async_stream_lite::async_stream;
/// use futures::{pin_mut, stream::StreamExt};
///
/// #[tokio::main]
/// async fn main() {
/// let stream = async_stream(|r#yield| async move {
/// let stream = async_stream(|yielder| async move {
/// for i in 0..3 {
/// r#yield(i).await;
/// yielder.r#yield(i).await;
/// }
/// });
/// pin_mut!(stream);
Expand All @@ -181,9 +188,9 @@ where
/// };
///
/// fn zero_to_three() -> impl Stream<Item = u32> {
/// async_stream(|r#yield| async move {
/// async_stream(|yielder| async move {
/// for i in 0..3 {
/// r#yield(i).await;
/// yielder.r#yield(i).await;
/// }
/// })
/// }
Expand All @@ -207,9 +214,9 @@ where
/// };
///
/// fn zero_to_three() -> BoxStream<'static, u32> {
/// Box::pin(async_stream(|r#yield| async move {
/// Box::pin(async_stream(|yielder| async move {
/// for i in 0..3 {
/// r#yield(i).await;
/// yielder.r#yield(i).await;
/// }
/// }))
/// }
Expand All @@ -232,18 +239,18 @@ where
/// };
///
/// fn zero_to_three() -> impl Stream<Item = u32> {
/// async_stream(|r#yield| async move {
/// async_stream(|yielder| async move {
/// for i in 0..3 {
/// r#yield(i).await;
/// yielder.r#yield(i).await;
/// }
/// })
/// }
///
/// fn double<S: Stream<Item = u32>>(input: S) -> impl Stream<Item = u32> {
/// async_stream(|r#yield| async move {
/// async_stream(|yielder| async move {
/// pin_mut!(input);
/// while let Some(value) = input.next().await {
/// r#yield(value * 2).await;
/// yielder.r#yield(value * 2).await;
/// }
/// })
/// }
Expand All @@ -261,10 +268,10 @@ where
/// See also [`try_async_stream`], a variant of [`async_stream`] which supports try notation (`?`).
pub fn async_stream<T, F, U>(generator: F) -> AsyncStream<T, U>
where
F: FnOnce(fn(value: T) -> YieldFut<T>) -> U,
F: FnOnce(Yielder<T>) -> U,
U: Future<Output = ()>
{
let generator = generator(r#yield::<T>);
let generator = generator(Yielder { _p: PhantomData });
AsyncStream {
_p: PhantomData,
done: false,
Expand Down
104 changes: 52 additions & 52 deletions src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@ use futures::{
};
use tokio::sync::mpsc;

use super::{YieldFut, async_stream};
use super::{Yielder, async_stream};

#[tokio::test]
async fn noop_stream() {
let s = async_stream(|_yield: fn(()) -> YieldFut<()>| async move {});
let s = async_stream(|_yielder: Yielder<()>| async move {});
pin_mut!(s);

#[allow(clippy::never_loop)]
Expand All @@ -25,7 +25,7 @@ async fn empty_stream() {

{
let r = &mut ran;
let s = async_stream(|_yield: fn(()) -> YieldFut<()>| async move {
let s = async_stream(|_yield: Yielder<()>| async move {
*r = true;
println!("hello world!");
});
Expand All @@ -42,8 +42,8 @@ async fn empty_stream() {

#[tokio::test]
async fn yield_single_value() {
let s = async_stream(|r#yield| async move {
r#yield("hello").await;
let s = async_stream(|yielder| async move {
yielder.r#yield("hello").await;
});

let values: Vec<_> = s.collect().await;
Expand All @@ -54,8 +54,8 @@ async fn yield_single_value() {

#[tokio::test]
async fn fused() {
let s = async_stream(|r#yield| async move {
r#yield("hello").await;
let s = async_stream(|yielder| async move {
yielder.r#yield("hello").await;
});
pin_mut!(s);

Expand All @@ -69,10 +69,10 @@ async fn fused() {

#[tokio::test]
async fn yield_multi_value() {
let stream = async_stream(|r#yield| async move {
r#yield("hello").await;
r#yield("world").await;
r#yield("foobar").await;
let stream = async_stream(|yielder| async move {
yielder.r#yield("hello").await;
yielder.r#yield("world").await;
yielder.r#yield("foobar").await;
});

let values: Vec<_> = stream.collect().await;
Expand All @@ -88,10 +88,10 @@ async fn unit_yield_in_select() {
#[allow(clippy::unused_async)]
async fn do_stuff_async() {}

let stream = async_stream(|r#yield| async move {
let stream = async_stream(|yielder| async move {
tokio::select! {
() = do_stuff_async() => r#yield(()).await,
else => r#yield(()).await
() = do_stuff_async() => yielder.r#yield(()).await,
else => yielder.r#yield(()).await
}
});

Expand All @@ -105,11 +105,11 @@ async fn yield_with_select() {
async fn do_stuff_async() {}
async fn more_async_work() {}

let stream = async_stream(|r#yield| async move {
let stream = async_stream(|yielder| async move {
tokio::select! {
() = do_stuff_async() => r#yield("hey").await,
() = more_async_work() => r#yield("hey").await,
else => r#yield("hey").await
() = do_stuff_async() => yielder.r#yield("hey").await,
() = more_async_work() => yielder.r#yield("hey").await,
else => yielder.r#yield("hey").await
}
});

Expand All @@ -120,10 +120,10 @@ async fn yield_with_select() {
#[tokio::test]
async fn return_stream() {
fn build_stream() -> impl Stream<Item = u32> {
async_stream(|r#yield| async move {
r#yield(1).await;
r#yield(2).await;
r#yield(3).await;
async_stream(|yielder| async move {
yielder.r#yield(1).await;
yielder.r#yield(2).await;
yielder.r#yield(3).await;
})
}

Expand All @@ -139,10 +139,10 @@ async fn return_stream() {
#[tokio::test]
async fn boxed_stream() {
fn build_stream() -> BoxStream<'static, u32> {
Box::pin(async_stream(|r#yield| async move {
r#yield(1).await;
r#yield(2).await;
r#yield(3).await;
Box::pin(async_stream(|yielder| async move {
yielder.r#yield(1).await;
yielder.r#yield(2).await;
yielder.r#yield(3).await;
}))
}

Expand All @@ -159,9 +159,9 @@ async fn boxed_stream() {
async fn consume_channel() {
let (tx, mut rx) = mpsc::channel(10);

let stream = async_stream(|r#yield| async move {
let stream = async_stream(|yielder| async move {
while let Some(v) = rx.recv().await {
r#yield(v).await;
yielder.r#yield(v).await;
}
});

Expand All @@ -182,8 +182,8 @@ async fn borrow_self() {

impl Data {
fn stream(&self) -> impl Stream<Item = &str> + '_ {
async_stream(|r#yield| async move {
r#yield(&self.0[..]).await;
async_stream(|yielder| async move {
yielder.r#yield(&self.0[..]).await;
})
}
}
Expand All @@ -201,8 +201,8 @@ async fn borrow_self_boxed() {

impl Data {
fn stream(&self) -> BoxStream<'_, &str> {
Box::pin(async_stream(|r#yield| async move {
r#yield(&self.0[..]).await;
Box::pin(async_stream(|yielder| async move {
yielder.r#yield(&self.0[..]).await;
}))
}
}
Expand All @@ -216,16 +216,16 @@ async fn borrow_self_boxed() {

#[tokio::test]
async fn stream_in_stream() {
let s = async_stream(|r#yield| async move {
let s = async_stream(|r#yield| async move {
let s = async_stream(|yielder| async move {
let s = async_stream(|yielder| async move {
for i in 0..3 {
r#yield(i).await;
yielder.r#yield(i).await;
}
});
pin_mut!(s);

while let Some(v) = s.next().await {
r#yield(v).await;
yielder.r#yield(v).await;
}
});

Expand All @@ -235,37 +235,37 @@ async fn stream_in_stream() {

#[tokio::test]
async fn streamception() {
let s = async_stream(|r#yield| async move {
let s = async_stream(|r#yield| async move {
let s = async_stream(|r#yield| async move {
let s = async_stream(|r#yield| async move {
let s = async_stream(|r#yield| async move {
let s = async_stream(|yielder| async move {
let s = async_stream(|yielder| async move {
let s = async_stream(|yielder| async move {
let s = async_stream(|yielder| async move {
let s = async_stream(|yielder| async move {
for i in 0..3 {
r#yield(i).await;
yielder.r#yield(i).await;
}
});
pin_mut!(s);

while let Some(v) = s.next().await {
r#yield(v).await;
yielder.r#yield(v).await;
}
});
pin_mut!(s);

while let Some(v) = s.next().await {
r#yield(v).await;
yielder.r#yield(v).await;
}
});
pin_mut!(s);

while let Some(v) = s.next().await {
r#yield(v).await;
yielder.r#yield(v).await;
}
});
pin_mut!(s);

while let Some(v) = s.next().await {
r#yield(v).await;
yielder.r#yield(v).await;
}
});

Expand All @@ -275,9 +275,9 @@ async fn streamception() {

#[tokio::test]
async fn yield_non_unpin_value() {
let s: Vec<_> = async_stream(|r#yield| async move {
let s: Vec<_> = async_stream(|yielder| async move {
for i in 0..3 {
r#yield(async move { i }).await;
yielder.r#yield(async move { i }).await;
}
})
.buffered(1)
Expand All @@ -290,10 +290,10 @@ async fn yield_non_unpin_value() {
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn multithreaded() {
fn build_stream() -> impl Stream<Item = u32> {
async_stream(|r#yield| async move {
r#yield(1).await;
r#yield(2).await;
r#yield(3).await;
async_stream(|yielder| async move {
yielder.r#yield(1).await;
yielder.r#yield(2).await;
yielder.r#yield(3).await;
})
}

Expand Down
Loading
Loading