Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace AsyncRead impl of Object with Stream of Bytes #1205

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions async-nats/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ criterion = { version = "0.5", features = ["async_tokio"]}
nats-server = { path = "../nats-server" }
rand = "0.8"
tokio = { version = "1.25.0", features = ["rt-multi-thread"] }
tokio-util = { version = "0.7", features = ["io"] }
futures = { version = "0.3.28", default-features = false, features = ["std", "async-await"] }
tracing-subscriber = "0.3"
async-nats = {path = ".", features = ["experimental"]}
Expand Down
190 changes: 97 additions & 93 deletions async-nats/src/jetstream/object_store/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,21 +12,23 @@
// limitations under the License.

//! Object Store module
use std::collections::VecDeque;
use std::fmt::Display;
use std::{cmp, str::FromStr, task::Poll, time::Duration};
use std::io;
use std::pin::Pin;
use std::task::Context;
use std::{str::FromStr, task::Poll, time::Duration};

use crate::crypto::Sha256;
use crate::subject::Subject;
use crate::{HeaderMap, HeaderValue};
use base64::engine::general_purpose::{STANDARD, URL_SAFE};
use base64::engine::Engine;
use bytes::BytesMut;
use bytes::{Bytes, BytesMut};
use futures::future::BoxFuture;
use once_cell::sync::Lazy;
use tokio::io::AsyncReadExt;

use futures::{Stream, StreamExt};
use futures::{FutureExt, Stream, StreamExt};
use regex::Regex;
use serde::{Deserialize, Serialize};
use tracing::{debug, trace};
Expand Down Expand Up @@ -92,24 +94,42 @@ pub struct ObjectStore {
impl ObjectStore {
/// Gets an [Object] from the [ObjectStore].
///
/// [Object] implements [tokio::io::AsyncRead] that allows
/// to read the data from Object Store.
/// [Object] implements [Stream] that allows
/// to stream chunks from Object Store.
///
/// # Examples
///
/// ```no_run
/// # #[tokio::main]
/// # async fn main() -> Result<(), async_nats::Error> {
/// use tokio::io::AsyncReadExt;
/// # async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync + 'static>> {
/// use std::env;
///
/// use tokio::fs::File;
///
/// let client = async_nats::connect("demo.nats.io").await?;
/// let jetstream = async_nats::jetstream::new(client);
///
/// let bucket = jetstream.get_object_store("store").await?;
/// let mut object = bucket.get("FOO").await?;
///
/// // Object implements `tokio::io::AsyncRead`.
/// let mut bytes = vec![];
/// object.read_to_end(&mut bytes).await?;
/// // Use the `Stream` implementation
/// use futures::TryStreamExt as _;
/// use tokio::io::AsyncWriteExt as _;
///
/// let mut file = File::create(env::temp_dir().join("FOO.bin")).await?;
/// while let Some(chunk) = object.try_next().await? {
/// file.write_all(&chunk).await?;
/// }
/// file.sync_all().await?;
///
/// // Alternatively use `tokio_util` with the `io` feature
/// // to convert the `Stream` into `AsyncRead`
/// // (less efficient because of the added memcpy)
/// let mut reader = tokio_util::io::StreamReader::new(object);
///
/// let mut file = File::create(env::temp_dir().join("FOO.bin")).await?;
/// tokio::io::copy(&mut reader, &mut file).await?;
/// file.sync_all().await?;
/// # Ok(())
/// # }
/// ```
Expand Down Expand Up @@ -920,7 +940,6 @@ impl Stream for List {
/// Represents an object stored in a bucket.
pub struct Object {
pub info: ObjectInfo,
remaining_bytes: VecDeque<u8>,
has_pending_messages: bool,
digest: Option<Sha256>,
subscription: Option<crate::jetstream::consumer::push::Ordered>,
Expand All @@ -933,7 +952,6 @@ impl Object {
Object {
subscription: None,
info,
remaining_bytes: VecDeque::new(),
has_pending_messages: true,
digest: Some(Sha256::new()),
subscription_future: None,
Expand All @@ -947,24 +965,19 @@ impl Object {
}
}

impl tokio::io::AsyncRead for Object {
fn poll_read(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &mut tokio::io::ReadBuf<'_>,
) -> std::task::Poll<std::io::Result<()>> {
let (buf1, _buf2) = self.remaining_bytes.as_slices();
if !buf1.is_empty() {
let len = cmp::min(buf.remaining(), buf1.len());
buf.put_slice(&buf1[..len]);
self.remaining_bytes.drain(..len);
return Poll::Ready(Ok(()));
impl Stream for Object {
type Item = io::Result<Bytes>;

fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
if !self.has_pending_messages {
return Poll::Ready(None);
}

if self.has_pending_messages {
if self.subscription.is_none() {
let future = match self.subscription_future.as_mut() {
Some(future) => future,
let subscription = match &mut self.subscription {
Some(subscription) => subscription,
None => {
let subscription_future = match &mut self.subscription_future {
Some(subscription_future) => subscription_future,
None => {
let stream = self.stream.clone();
let bucket = self.info.bucket.clone();
Expand All @@ -983,77 +996,68 @@ impl tokio::io::AsyncRead for Object {
}))
}
};
match future.as_mut().poll(cx) {
Poll::Ready(subscription) => {
self.subscription = Some(subscription.unwrap());

match subscription_future.poll_unpin(cx) {
Poll::Pending => return Poll::Pending,
Poll::Ready(Ok(subscription)) => self.subscription.insert(subscription),
Poll::Ready(Err(err)) => {
return Poll::Ready(Some(Err(io::Error::new(
io::ErrorKind::Other,
format!("error from JetStream create subscription: {err}"),
))))
}
Poll::Pending => (),
}
}
if let Some(subscription) = self.subscription.as_mut() {
match subscription.poll_next_unpin(cx) {
Poll::Ready(message) => match message {
Some(message) => {
let message = message.map_err(|err| {
std::io::Error::new(
std::io::ErrorKind::Other,
format!("error from JetStream subscription: {err}"),
)
})?;
let len = cmp::min(buf.remaining(), message.payload.len());
buf.put_slice(&message.payload[..len]);
if let Some(context) = &mut self.digest {
context.update(&message.payload);
}
self.remaining_bytes.extend(&message.payload[len..]);
};

let info = message.info().map_err(|err| {
std::io::Error::new(
std::io::ErrorKind::Other,
format!("error from JetStream subscription: {err}"),
)
})?;
if info.pending == 0 {
let digest = self.digest.take().map(Sha256::finish);
if let Some(digest) = digest {
if self
.info
.digest
.as_ref()
.map(|digest_self| {
format!("SHA-256={}", URL_SAFE.encode(digest))
!= *digest_self
})
.unwrap_or(false)
{
return Poll::Ready(Err(std::io::Error::new(
std::io::ErrorKind::InvalidData,
"wrong digest",
)));
}
} else {
return Poll::Ready(Err(std::io::Error::new(
std::io::ErrorKind::InvalidData,
"digest should be Some",
)));
}
self.has_pending_messages = false;
self.subscription = None;
}
Poll::Ready(Ok(()))
match subscription.poll_next_unpin(cx) {
Poll::Pending => Poll::Pending,
Poll::Ready(Some(Ok(message))) => {
if let Some(digest) = &mut self.digest {
digest.update(&message.message.payload);
}

let info = message.info().map_err(|err| {
io::Error::new(
io::ErrorKind::Other,
format!("error from JetStream subscription: {err}"),
)
})?;

if info.pending == 0 {
self.has_pending_messages = false;
self.subscription = None;

if let Some(digest) = self.digest.take() {
let digest = digest.finish();

if self
.info
.digest
.as_ref()
.map(|digest_self| {
format!("SHA-256={}", URL_SAFE.encode(digest)) != *digest_self
})
.unwrap_or(false)
{
return Poll::Ready(Some(Err(io::Error::new(
io::ErrorKind::InvalidData,
"wrong digest",
))));
}
None => Poll::Ready(Err(std::io::Error::new(
std::io::ErrorKind::Other,
"subscription ended before reading whole object",
))),
},
Poll::Pending => Poll::Pending,
}
}
} else {
Poll::Pending

Poll::Ready(Some(Ok(message.message.payload)))
}
} else {
Poll::Ready(Ok(()))
Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(io::Error::new(
io::ErrorKind::Other,
format!("error from JetStream subscription: {err}"),
)))),
Poll::Ready(None) => Poll::Ready(Some(Err(io::Error::new(
io::ErrorKind::Other,
"subscription ended before reading whole object",
)))),
}
}
}
Expand Down
11 changes: 7 additions & 4 deletions async-nats/tests/compatibility.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

#[cfg(feature = "compatibility_tests")]
mod compatibility {
use futures::{pin_mut, stream::Peekable, StreamExt};
use futures::{pin_mut, stream::Peekable, StreamExt, TryStreamExt};
use ring::digest::{self, SHA256};

use core::panic;
Expand All @@ -27,7 +27,6 @@ mod compatibility {
service::{self, ServiceExt},
};
use serde::{Deserialize, Serialize};
use tokio::io::AsyncReadExt;

#[tokio::test]
async fn kv() {
Expand Down Expand Up @@ -226,7 +225,9 @@ mod compatibility {
let mut object = bucket.get(request.object).await.unwrap();
let mut contents = vec![];

object.read_to_end(&mut contents).await.unwrap();
while let Some(chunk) = object.try_next().await.unwrap() {
contents.extend_from_slice(&chunk);
}

let digest = digest::digest(&SHA256, &contents);

Expand Down Expand Up @@ -295,7 +296,9 @@ mod compatibility {
let mut object = bucket.get(request.object).await.unwrap();
let mut contents = vec![];

object.read_to_end(&mut contents).await.unwrap();
while let Some(chunk) = object.try_next().await.unwrap() {
contents.extend_from_slice(&chunk);
}

let digest = digest::digest(&SHA256, &contents);

Expand Down
24 changes: 9 additions & 15 deletions async-nats/tests/object_store.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,8 @@ mod object_store {
stream::DirectGetErrorKind,
};
use base64::Engine;
use futures::StreamExt;
use futures::{StreamExt, TryStreamExt};
use rand::RngCore;
use tokio::io::AsyncReadExt;

use ring::digest::{self, SHA256};

#[tokio::test]
Expand Down Expand Up @@ -51,16 +49,8 @@ mod object_store {
let mut object = bucket.get("FOO").await.unwrap();

let mut result = Vec::new();
loop {
let mut buffer = [0; 1024];
if let Ok(n) = object.read(&mut buffer).await {
if n == 0 {
println!("finished");
break;
}

result.extend_from_slice(&buffer[..n]);
}
while let Some(chunk) = object.try_next().await.unwrap() {
result.extend_from_slice(&chunk);
}
assert_eq!(
Some(format!(
Expand All @@ -79,7 +69,9 @@ mod object_store {
let mut contents = Vec::new();

tracing::info!("reading content");
object_link.read_to_end(&mut contents).await.unwrap();
while let Some(chunk) = object_link.try_next().await.unwrap() {
contents.extend_from_slice(&chunk);
}
assert_eq!(contents, result);

bucket
Expand Down Expand Up @@ -350,7 +342,9 @@ mod object_store {
assert_eq!(object.info.digest, Some(format!("SHA-256={digest}")));

let mut result = Vec::new();
object.read_to_end(&mut result).await.unwrap();
while let Some(chunk) = object.try_next().await.unwrap() {
result.extend_from_slice(&chunk);
}
assert_eq!(result, file);
}
}
Expand Down
Loading