Skip to content

Commit

Permalink
Replace AsyncRead impl of Object with Stream of Bytes
Browse files Browse the repository at this point in the history
  • Loading branch information
paolobarbolini committed May 5, 2024
1 parent ecdd941 commit bddbb24
Show file tree
Hide file tree
Showing 4 changed files with 114 additions and 112 deletions.
1 change: 1 addition & 0 deletions async-nats/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ criterion = { version = "0.5", features = ["async_tokio"]}
nats-server = { path = "../nats-server" }
rand = "0.8"
tokio = { version = "1.25.0", features = ["rt-multi-thread"] }
tokio-util = { version = "0.7", features = ["io"] }
futures = { version = "0.3.28", default-features = false, features = ["std", "async-await"] }
tracing-subscriber = "0.3"
async-nats = {path = ".", features = ["experimental"]}
Expand Down
190 changes: 97 additions & 93 deletions async-nats/src/jetstream/object_store/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,21 +12,23 @@
// limitations under the License.

//! Object Store module
use std::collections::VecDeque;
use std::fmt::Display;
use std::{cmp, str::FromStr, task::Poll, time::Duration};
use std::io;
use std::pin::Pin;
use std::task::Context;
use std::{str::FromStr, task::Poll, time::Duration};

use crate::crypto::Sha256;
use crate::subject::Subject;
use crate::{HeaderMap, HeaderValue};
use base64::engine::general_purpose::{STANDARD, URL_SAFE};
use base64::engine::Engine;
use bytes::BytesMut;
use bytes::{Bytes, BytesMut};
use futures::future::BoxFuture;
use once_cell::sync::Lazy;
use tokio::io::AsyncReadExt;

use futures::{Stream, StreamExt};
use futures::{FutureExt, Stream, StreamExt};
use regex::Regex;
use serde::{Deserialize, Serialize};
use tracing::{debug, trace};
Expand Down Expand Up @@ -92,24 +94,42 @@ pub struct ObjectStore {
impl ObjectStore {
/// Gets an [Object] from the [ObjectStore].
///
/// [Object] implements [tokio::io::AsyncRead] that allows
/// to read the data from Object Store.
/// [Object] implements [Stream], yielding the object's data
/// chunk by chunk as it is read from the Object Store.
///
/// # Examples
///
/// ```no_run
/// # #[tokio::main]
/// # async fn main() -> Result<(), async_nats::Error> {
/// use tokio::io::AsyncReadExt;
/// # async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync + 'static>> {
/// use std::env;
///
/// use tokio::fs::File;
///
/// let client = async_nats::connect("demo.nats.io").await?;
/// let jetstream = async_nats::jetstream::new(client);
///
/// let bucket = jetstream.get_object_store("store").await?;
/// let mut object = bucket.get("FOO").await?;
///
/// // Object implements `tokio::io::AsyncRead`.
/// let mut bytes = vec![];
/// object.read_to_end(&mut bytes).await?;
/// // Use the `Stream` implementation
/// use futures::TryStreamExt as _;
/// use tokio::io::AsyncWriteExt as _;
///
/// let mut file = File::create(env::temp_dir().join("FOO.bin")).await?;
/// while let Some(chunk) = object.try_next().await? {
/// file.write_all(&chunk).await?;
/// }
/// file.sync_all().await?;
///
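/// // Alternatively, instead of the loop above, use `tokio_util` with the
/// // `io` feature to convert the `Stream` into `AsyncRead`
/// // (less efficient because of the added memcpy). An `Object` that has
/// // already been read to the end yields no further chunks, so do this
/// // on a freshly fetched `Object`.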
/// let mut reader = tokio_util::io::StreamReader::new(object);
///
/// let mut file = File::create(env::temp_dir().join("FOO.bin")).await?;
/// tokio::io::copy(&mut reader, &mut file).await?;
/// file.sync_all().await?;
/// # Ok(())
/// # }
/// ```
Expand Down Expand Up @@ -920,7 +940,6 @@ impl Stream for List {
/// Represents an object stored in a bucket.
pub struct Object {
pub info: ObjectInfo,
remaining_bytes: VecDeque<u8>,
has_pending_messages: bool,
digest: Option<Sha256>,
subscription: Option<crate::jetstream::consumer::push::Ordered>,
Expand All @@ -933,7 +952,6 @@ impl Object {
Object {
subscription: None,
info,
remaining_bytes: VecDeque::new(),
has_pending_messages: true,
digest: Some(Sha256::new()),
subscription_future: None,
Expand All @@ -947,24 +965,19 @@ impl Object {
}
}

impl tokio::io::AsyncRead for Object {
fn poll_read(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &mut tokio::io::ReadBuf<'_>,
) -> std::task::Poll<std::io::Result<()>> {
let (buf1, _buf2) = self.remaining_bytes.as_slices();
if !buf1.is_empty() {
let len = cmp::min(buf.remaining(), buf1.len());
buf.put_slice(&buf1[..len]);
self.remaining_bytes.drain(..len);
return Poll::Ready(Ok(()));
impl Stream for Object {
type Item = io::Result<Bytes>;

fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
if !self.has_pending_messages {
return Poll::Ready(None);
}

if self.has_pending_messages {
if self.subscription.is_none() {
let future = match self.subscription_future.as_mut() {
Some(future) => future,
let subscription = match &mut self.subscription {
Some(subscription) => subscription,
None => {
let subscription_future = match &mut self.subscription_future {
Some(subscription_future) => subscription_future,
None => {
let stream = self.stream.clone();
let bucket = self.info.bucket.clone();
Expand All @@ -983,77 +996,68 @@ impl tokio::io::AsyncRead for Object {
}))
}
};
match future.as_mut().poll(cx) {
Poll::Ready(subscription) => {
self.subscription = Some(subscription.unwrap());

match subscription_future.poll_unpin(cx) {
Poll::Pending => return Poll::Pending,
Poll::Ready(Ok(subscription)) => self.subscription.insert(subscription),
Poll::Ready(Err(err)) => {
return Poll::Ready(Some(Err(io::Error::new(
io::ErrorKind::Other,
format!("error from JetStream create subscription: {err}"),
))))
}
Poll::Pending => (),
}
}
if let Some(subscription) = self.subscription.as_mut() {
match subscription.poll_next_unpin(cx) {
Poll::Ready(message) => match message {
Some(message) => {
let message = message.map_err(|err| {
std::io::Error::new(
std::io::ErrorKind::Other,
format!("error from JetStream subscription: {err}"),
)
})?;
let len = cmp::min(buf.remaining(), message.payload.len());
buf.put_slice(&message.payload[..len]);
if let Some(context) = &mut self.digest {
context.update(&message.payload);
}
self.remaining_bytes.extend(&message.payload[len..]);
};

let info = message.info().map_err(|err| {
std::io::Error::new(
std::io::ErrorKind::Other,
format!("error from JetStream subscription: {err}"),
)
})?;
if info.pending == 0 {
let digest = self.digest.take().map(Sha256::finish);
if let Some(digest) = digest {
if self
.info
.digest
.as_ref()
.map(|digest_self| {
format!("SHA-256={}", URL_SAFE.encode(digest))
!= *digest_self
})
.unwrap_or(false)
{
return Poll::Ready(Err(std::io::Error::new(
std::io::ErrorKind::InvalidData,
"wrong digest",
)));
}
} else {
return Poll::Ready(Err(std::io::Error::new(
std::io::ErrorKind::InvalidData,
"digest should be Some",
)));
}
self.has_pending_messages = false;
self.subscription = None;
}
Poll::Ready(Ok(()))
match subscription.poll_next_unpin(cx) {
Poll::Pending => Poll::Pending,
Poll::Ready(Some(Ok(message))) => {
if let Some(digest) = &mut self.digest {
digest.update(&message.message.payload);
}

let info = message.info().map_err(|err| {
io::Error::new(
io::ErrorKind::Other,
format!("error from JetStream subscription: {err}"),
)
})?;

if info.pending == 0 {
self.has_pending_messages = false;
self.subscription = None;

if let Some(digest) = self.digest.take() {
let digest = digest.finish();

if self
.info
.digest
.as_ref()
.map(|digest_self| {
format!("SHA-256={}", URL_SAFE.encode(digest)) != *digest_self
})
.unwrap_or(false)
{
return Poll::Ready(Some(Err(io::Error::new(
io::ErrorKind::InvalidData,
"wrong digest",
))));
}
None => Poll::Ready(Err(std::io::Error::new(
std::io::ErrorKind::Other,
"subscription ended before reading whole object",
))),
},
Poll::Pending => Poll::Pending,
}
}
} else {
Poll::Pending

Poll::Ready(Some(Ok(message.message.payload)))
}
} else {
Poll::Ready(Ok(()))
Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(io::Error::new(
io::ErrorKind::Other,
format!("error from JetStream subscription: {err}"),
)))),
Poll::Ready(None) => Poll::Ready(Some(Err(io::Error::new(
io::ErrorKind::Other,
"subscription ended before reading whole object",
)))),
}
}
}
Expand Down
11 changes: 7 additions & 4 deletions async-nats/tests/compatibility.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

#[cfg(feature = "compatibility_tests")]
mod compatibility {
use futures::{pin_mut, stream::Peekable, StreamExt};
use futures::{pin_mut, stream::Peekable, StreamExt, TryStreamExt};
use ring::digest::{self, SHA256};

use core::panic;
Expand All @@ -27,7 +27,6 @@ mod compatibility {
service::{self, ServiceExt},
};
use serde::{Deserialize, Serialize};
use tokio::io::AsyncReadExt;

#[tokio::test]
async fn kv() {
Expand Down Expand Up @@ -226,7 +225,9 @@ mod compatibility {
let mut object = bucket.get(request.object).await.unwrap();
let mut contents = vec![];

object.read_to_end(&mut contents).await.unwrap();
while let Some(chunk) = object.try_next().await.unwrap() {
contents.extend_from_slice(&chunk);
}

let digest = digest::digest(&SHA256, &contents);

Expand Down Expand Up @@ -295,7 +296,9 @@ mod compatibility {
let mut object = bucket.get(request.object).await.unwrap();
let mut contents = vec![];

object.read_to_end(&mut contents).await.unwrap();
while let Some(chunk) = object.try_next().await.unwrap() {
contents.extend_from_slice(&chunk);
}

let digest = digest::digest(&SHA256, &contents);

Expand Down
24 changes: 9 additions & 15 deletions async-nats/tests/object_store.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,8 @@ mod object_store {
stream::DirectGetErrorKind,
};
use base64::Engine;
use futures::StreamExt;
use futures::{StreamExt, TryStreamExt};
use rand::RngCore;
use tokio::io::AsyncReadExt;

use ring::digest::{self, SHA256};

#[tokio::test]
Expand Down Expand Up @@ -51,16 +49,8 @@ mod object_store {
let mut object = bucket.get("FOO").await.unwrap();

let mut result = Vec::new();
loop {
let mut buffer = [0; 1024];
if let Ok(n) = object.read(&mut buffer).await {
if n == 0 {
println!("finished");
break;
}

result.extend_from_slice(&buffer[..n]);
}
while let Some(chunk) = object.try_next().await.unwrap() {
result.extend_from_slice(&chunk);
}
assert_eq!(
Some(format!(
Expand All @@ -79,7 +69,9 @@ mod object_store {
let mut contents = Vec::new();

tracing::info!("reading content");
object_link.read_to_end(&mut contents).await.unwrap();
while let Some(chunk) = object_link.try_next().await.unwrap() {
contents.extend_from_slice(&chunk);
}
assert_eq!(contents, result);

bucket
Expand Down Expand Up @@ -350,7 +342,9 @@ mod object_store {
assert_eq!(object.info.digest, Some(format!("SHA-256={digest}")));

let mut result = Vec::new();
object.read_to_end(&mut result).await.unwrap();
while let Some(chunk) = object.try_next().await.unwrap() {
result.extend_from_slice(&chunk);
}
assert_eq!(result, file);
}
}
Expand Down

0 comments on commit bddbb24

Please sign in to comment.