Skip to content

Commit

Permalink
MOD: Add methods to AsyncDynReader
Browse files Browse the repository at this point in the history
  • Loading branch information
threecgreen committed Apr 25, 2024
1 parent 304bea5 commit b223661
Show file tree
Hide file tree
Showing 4 changed files with 122 additions and 22 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@
- Added links to example usage in documentation
- Added new predicate methods `InstrumentClass::is_option`, `is_future`, and `is_spread`
to make it easier to work with multiple instrument class variants
- Implemented `DecodeRecord` for `DbnRecordDecoder`
- Added `new_inferred`, `with_buffer`, `inferred_with_buffer`, `from_file`, `get_mut`,
and `get_ref` methods to `AsyncDynReader` for parity with the sync `DynReader`
- Improved documentation enumerating errors returned by functions

### Breaking changes
- Removed `write_dbn_file` function deprecated in version 0.14.0 from Python interface.
Expand Down
128 changes: 107 additions & 21 deletions rust/dbn/src/decode.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
//! Decoding DBN and Zstd-compressed DBN files and streams. Decoders implement the
//! [`DecodeDbn`] trait.
//! Decoding DBN and Zstd-compressed DBN files and streams. Sync decoders implement
//the ! [`DecodeDbn`] trait.
pub mod dbn;
// Having any tests in a deprecated module emits many warnings that can't be silenced, see
// https://github.com/rust-lang/rust/issues/47238
Expand Down Expand Up @@ -33,7 +33,6 @@ use crate::{
enums::{Compression, VersionUpgradePolicy},
record::HasRType,
record_ref::RecordRef,
// record_ref::RecordRef,
Metadata,
};

Expand Down Expand Up @@ -328,6 +327,8 @@ where
R: io::Read,
{
/// Creates a new [`DynReader`] from a reader, with the specified `compression`.
/// If `reader` also implements [`BufRead`](io::BufRead), it's better to use
/// [`with_buffer()`](Self::with_buffer).
///
/// # Errors
/// This function will return an error if it fails to create the zstd decoder.
Expand Down Expand Up @@ -407,7 +408,7 @@ impl<'a> DynReader<'a, BufReader<File>> {
///
/// # Errors
/// This function will return an error if the file doesn't exist, it is unable to
/// determine the encoding of the file or it fails to parse the metadata.
/// determine the encoding of the file, or it fails to create the zstd decoder.
pub fn from_file(path: impl AsRef<Path>) -> crate::Result<Self> {
let file = File::open(path.as_ref()).map_err(|e| {
crate::Error::io(
Expand Down Expand Up @@ -447,11 +448,13 @@ where
}
}

mod private {
#[doc(hidden)]
pub mod private {
/// An implementation detail for the interaction between [`StreamingIterator`] and
/// implementors of [`DecodeDbn`].
/// implementors of [`DecodeRecord`].
#[doc(hidden)]
pub trait BufferSlice {
/// Returns an immutable slice of the decoder's buffer.
fn buffer_slice(&self) -> &[u8];
}
}
Expand All @@ -461,31 +464,35 @@ pub(crate) trait FromLittleEndianSlice {
}

impl FromLittleEndianSlice for u64 {
/// NOTE: assumes the length of `slice` is at least 8 bytes
/// # Panics
/// Panics if the length of `slice` is less than 8 bytes.
fn from_le_slice(slice: &[u8]) -> Self {
let (bytes, _) = slice.split_at(mem::size_of::<Self>());
Self::from_le_bytes(bytes.try_into().unwrap())
}
}

impl FromLittleEndianSlice for i32 {
/// NOTE: assumes the length of `slice` is at least 4 bytes
/// # Panics
/// Panics if the length of `slice` is less than 4 bytes.
fn from_le_slice(slice: &[u8]) -> Self {
let (bytes, _) = slice.split_at(mem::size_of::<Self>());
Self::from_le_bytes(bytes.try_into().unwrap())
}
}

impl FromLittleEndianSlice for u32 {
/// NOTE: assumes the length of `slice` is at least 4 bytes
/// # Panics
/// Panics if the length of `slice` is less than 4 bytes.
fn from_le_slice(slice: &[u8]) -> Self {
let (bytes, _) = slice.split_at(mem::size_of::<Self>());
Self::from_le_bytes(bytes.try_into().unwrap())
}
}

impl FromLittleEndianSlice for u16 {
/// NOTE: assumes the length of `slice` is at least 2 bytes
/// # Panics
/// Panics if the length of `slice` is less than 2 bytes.
fn from_le_slice(slice: &[u8]) -> Self {
let (bytes, _) = slice.split_at(mem::size_of::<Self>());
Self::from_le_bytes(bytes.try_into().unwrap())
Expand Down Expand Up @@ -551,42 +558,121 @@ pub use self::{

#[cfg(feature = "async")]
mod r#async {
use std::pin::Pin;
use std::{path::Path, pin::Pin};

use async_compression::tokio::bufread::ZstdDecoder;
use tokio::io::{self, BufReader};
use tokio::{
fs::File,
io::{self, BufReader},
};

use crate::enums::Compression;

/// A type for runtime polymorphism on compressed and uncompressed input.
/// The async version of [`DynReader`](super::DynReader).
pub struct DynReader<R>(DynReaderImpl<R>)
where
R: io::AsyncReadExt + Unpin;
R: io::AsyncBufReadExt + Unpin;

enum DynReaderImpl<R>
where
R: io::AsyncReadExt + Unpin,
R: io::AsyncBufReadExt + Unpin,
{
Uncompressed(R),
ZStd(ZstdDecoder<BufReader<R>>),
ZStd(ZstdDecoder<R>),
}

impl<R> DynReader<R>
impl<R> DynReader<BufReader<R>>
where
R: io::AsyncReadExt + Unpin,
{
/// Creates a new instance of [`DynReader`] with the specified `compression`.
/// Creates a new instance of [`DynReader`] with the specified `compression`. If
/// `reader` also implements [`AsyncBufRead`](tokio::io::AsyncBufRead), it's
/// better to use [`with_buffer()`](Self::with_buffer).
pub fn new(reader: R, compression: Compression) -> Self {
Self(match compression {
Compression::None => DynReaderImpl::Uncompressed(reader),
Compression::ZStd => DynReaderImpl::ZStd(ZstdDecoder::new(BufReader::new(reader))),
Self::with_buffer(BufReader::new(reader), compression)
}

/// Creates a new [`DynReader`] from a reader, inferring the compression.
/// If `reader` also implements [`AsyncBufRead`](tokio::io::AsyncBufRead), it is
/// better to use [`inferred_with_buffer()`](Self::inferred_with_buffer).
///
/// # Errors
/// This function will return an error if it is unable to read from `reader`.
pub async fn new_inferred(reader: R) -> crate::Result<Self> {
Self::inferred_with_buffer(BufReader::new(reader)).await
}
}

impl<R> DynReader<R>
where
R: io::AsyncBufReadExt + Unpin,
{
/// Creates a new [`DynReader`] from a buffered reader with the specified
/// `compression`.
pub fn with_buffer(reader: R, compression: Compression) -> Self {
match compression {
Compression::None => Self(DynReaderImpl::Uncompressed(reader)),
Compression::ZStd => Self(DynReaderImpl::ZStd(ZstdDecoder::new(reader))),
}
}

/// Creates a new [`DynReader`] from a buffered reader, inferring the compression.
///
/// # Errors
/// This function will return an error if it fails to read from `reader`.
pub async fn inferred_with_buffer(mut reader: R) -> crate::Result<Self> {
let first_bytes = reader
.fill_buf()
.await
.map_err(|e| crate::Error::io(e, "creating buffer to infer encoding"))?;
Ok(if super::zstd::starts_with_prefix(first_bytes) {
Self(DynReaderImpl::ZStd(ZstdDecoder::new(reader)))
} else {
Self(DynReaderImpl::Uncompressed(reader))
})
}

/// Returns a mutable reference to the inner reader.
pub fn get_mut(&mut self) -> &mut R {
match &mut self.0 {
DynReaderImpl::Uncompressed(reader) => reader,
DynReaderImpl::ZStd(reader) => reader.get_mut(),
}
}

/// Returns a reference to the inner reader.
pub fn get_ref(&self) -> &R {
match &self.0 {
DynReaderImpl::Uncompressed(reader) => reader,
DynReaderImpl::ZStd(reader) => reader.get_ref(),
}
}
}

impl DynReader<BufReader<File>> {
/// Creates a new [`DynReader`] from the file at `path`.
///
/// # Errors
/// This function will return an error if the file doesn't exist, it is unable
/// to read from it.
pub async fn from_file(path: impl AsRef<Path>) -> crate::Result<Self> {
let file = File::open(path.as_ref()).await.map_err(|e| {
crate::Error::io(
e,
format!(
"opening file to decode at path '{}'",
path.as_ref().display()
),
)
})?;
DynReader::new_inferred(file).await
}
}

impl<R> io::AsyncRead for DynReader<R>
where
R: io::AsyncRead + io::AsyncReadExt + Unpin,
R: io::AsyncRead + io::AsyncReadExt + io::AsyncBufReadExt + Unpin,
{
fn poll_read(
mut self: std::pin::Pin<&mut Self>,
Expand Down
9 changes: 9 additions & 0 deletions rust/dbn/src/decode/dbn/sync.rs
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,15 @@ where
}
}

impl<R> DecodeRecord for RecordDecoder<R>
where
R: io::Read,
{
fn decode_record<T: HasRType>(&mut self) -> crate::Result<Option<&T>> {
self.decode()
}
}

impl<R> DecodeRecordRef for RecordDecoder<R>
where
R: io::Read,
Expand Down
3 changes: 2 additions & 1 deletion rust/dbn/src/decode/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ where
D: DecodeRecord,
T: HasRType,
{
pub(crate) fn new(decoder: D) -> Self {
/// Creates a new streaming decoder using the given `decoder`.
pub fn new(decoder: D) -> Self {
Self {
decoder,
i: Some(0),
Expand Down

0 comments on commit b223661

Please sign in to comment.