Commit

ADD: Create custom Python exception for Rust code
threecgreen committed May 7, 2024
1 parent 468a73b commit 3b43027
Showing 10 changed files with 186 additions and 193 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -9,6 +9,8 @@
- Added `new_inferred`, `with_buffer`, `inferred_with_buffer`, `from_file`, `get_mut`,
and `get_ref` methods to `AsyncDynReader` for parity with the sync `DynReader`
- Improved documentation enumerating errors returned by functions
- Added new `DBNError` Python exception that's now the primary exception raised by
`databento_dbn`

### Breaking changes
- Changed type of `flags` in `MboMsg`, `TradeMsg`, `Mbp1Msg`, `Mbp10Msg`, and `CbboMsg`
19 changes: 12 additions & 7 deletions python/python/databento_dbn/_lib.pyi
@@ -43,6 +43,11 @@ _DBNRecord = Union[
StatusMsg,
]

class DBNError(Exception):
"""
An exception from databento_dbn Rust code.
"""

class Side(Enum):
"""
A side of the market. The side of the market for resting orders, or the side
@@ -950,7 +955,7 @@ class Metadata(SupportsBytes):
Raises
------
ValueError
DBNError
When a Metadata instance cannot be parsed from `data`.
"""
@@ -965,7 +970,7 @@ class Metadata(SupportsBytes):
Raises
------
ValueError
DBNError
When the Metadata object cannot be encoded.
"""
@@ -4651,7 +4656,7 @@ class DBNDecoder:
Raises
------
ValueError
DBNError
When the decoding fails.
See Also
@@ -4669,7 +4674,7 @@ class DBNDecoder:
Raises
------
ValueError
DBNError
When the write to the internal buffer fails.
See Also
@@ -4753,7 +4758,7 @@ class Transcoder:
Raises
------
ValueError
DBNError
When the write to the internal buffer or the output fails.
"""

@@ -4765,7 +4770,7 @@ class Transcoder:
Raises
------
ValueError
DBNError
When the write to the output fails.
"""

@@ -4793,7 +4798,7 @@ def update_encoded_metadata(
Raises
------
ValueError
DBNError
When the file update fails.
"""
31 changes: 13 additions & 18 deletions python/src/dbn_decoder.rs
@@ -4,8 +4,8 @@ use pyo3::prelude::*;

use dbn::{
decode::dbn::{MetadataDecoder, RecordDecoder},
python::to_val_err,
rtype_ts_out_dispatch, HasRType, Record, VersionUpgradePolicy,
python::to_py_err,
rtype_ts_out_dispatch, HasRType, VersionUpgradePolicy,
};

#[pyclass(module = "databento_dbn", name = "DBNDecoder")]
@@ -36,7 +36,7 @@ impl DbnDecoder {
}

fn write(&mut self, bytes: &[u8]) -> PyResult<()> {
self.buffer.write_all(bytes).map_err(to_val_err)
self.buffer.write_all(bytes).map_err(to_py_err)
}

fn buffer(&self) -> &[u8] {
@@ -63,7 +63,7 @@ impl DbnDecoder {
{
return Ok(Vec::new());
}
return Err(to_val_err(err));
return Err(PyErr::from(err));
}
}
}
@@ -73,10 +73,9 @@ impl DbnDecoder {
self.input_version,
self.upgrade_policy,
self.ts_out,
)
.map_err(to_val_err)?;
)?;
Python::with_gil(|py| -> PyResult<()> {
while let Some(rec) = decoder.decode_ref().map_err(to_val_err)? {
while let Some(rec) = decoder.decode_ref()? {
// Bug in clippy generates an error here. trivial_copy feature isn't enabled,
// but clippy thinks these records are `Copy`
fn push_rec<R: Clone + HasRType + IntoPy<Py<PyAny>>>(
@@ -89,14 +88,7 @@ impl DbnDecoder {

// Safety: It's safe to cast to `WithTsOut` because we're passing in the `ts_out`
// from the metadata header.
if unsafe { rtype_ts_out_dispatch!(rec, self.ts_out, push_rec, py, &mut recs) }
.is_err()
{
return Err(to_val_err(format!(
"Invalid rtype {} found in record",
rec.header().rtype,
)));
}
unsafe { rtype_ts_out_dispatch!(rec, self.ts_out, push_rec, py, &mut recs) }?;
// keep track of position after last _successful_ decoding to
// ensure buffer is left in correct state in the case where one
// or more successful decodings is followed by a partial one, i.e.
@@ -264,7 +256,7 @@ for record in records[1:]:
setup();
Python::with_gil(|py| {
py.run_bound(
r#"from _lib import DBNDecoder, Metadata, Schema, SType
r#"from _lib import DBNDecoder, DBNError, Metadata, Schema, SType
metadata = Metadata(
dataset="GLBX.MDP3",
@@ -286,8 +278,11 @@ try:
records = decoder.decode()
# If this code is called, the test will fail
assert False
except Exception as ex:
assert "Invalid rtype" in str(ex)
except DBNError as ex:
assert "couldn't convert" in str(ex)
assert "RType" in str(ex)
except Exception:
assert False
"#,
None,
None,
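Note: the dbn_decoder.rs hunks above replace `.map_err(to_val_err)` with plain `?` and `PyErr::from(err)`. That only compiles because the dbn crate's python module, which is not among the files loaded in this view, defines the `DBNError` exception and conversions into `PyErr`. The sketch below is a minimal illustration of what such a module plausibly contains, assuming PyO3's `create_exception!` macro; the stand-in `Error` enum and the implementations shown are illustrative rather than the actual dbn code.

use std::fmt;

use pyo3::PyErr;

// Stand-in for `dbn::Error`; the real enum has more variants.
#[derive(Debug)]
pub enum Error {
    Decode(String),
    Io(std::io::Error),
}

impl fmt::Display for Error {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Error::Decode(msg) => write!(f, "decode error: {msg}"),
            Error::Io(err) => write!(f, "I/O error: {err}"),
        }
    }
}

// The custom exception exported to Python as `databento_dbn.DBNError`.
pyo3::create_exception!(
    databento_dbn,
    DBNError,
    pyo3::exceptions::PyException,
    "An exception from databento_dbn Rust code."
);

// With this conversion, `decoder.decode_record_ref()?` inside #[pymethods]
// raises `DBNError` directly instead of requiring `.map_err(...)`.
impl From<Error> for PyErr {
    fn from(err: Error) -> Self {
        DBNError::new_err(err.to_string())
    }
}

// Counterpart of the removed `to_val_err`: wraps any displayable error (e.g.
// the `std::io::Error` from `Write::write_all`) in `DBNError` rather than
// `ValueError`.
pub fn to_py_err(err: impl fmt::Display) -> PyErr {
    DBNError::new_err(err.to_string())
}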
16 changes: 7 additions & 9 deletions python/src/encode.rs
@@ -3,7 +3,7 @@ use std::{
num::NonZeroU64,
};

use dbn::{encode::dbn::MetadataEncoder, python::to_val_err};
use dbn::encode::dbn::MetadataEncoder;
use pyo3::{exceptions::PyTypeError, intern, prelude::*, types::PyBytes};

/// Updates existing fields that have already been written to the given file.
@@ -19,14 +19,12 @@ pub fn update_encoded_metadata(
let mut buf = [0; 4];
file.read_exact(&mut buf)?;
let version = buf[3];
MetadataEncoder::new(file)
.update_encoded(
version,
start,
end.and_then(NonZeroU64::new),
limit.and_then(NonZeroU64::new),
)
.map_err(to_val_err)
Ok(MetadataEncoder::new(file).update_encoded(
version,
start,
end.and_then(NonZeroU64::new),
limit.and_then(NonZeroU64::new),
)?)
}

/// A Python object that implements the Python file interface.
7 changes: 4 additions & 3 deletions python/src/lib.rs
@@ -5,7 +5,7 @@ use pyo3::{prelude::*, wrap_pyfunction, PyClass};
use dbn::{
compat::{ErrorMsgV1, InstrumentDefMsgV1, SymbolMappingMsgV1, SystemMsgV1},
flags,
python::EnumIterator,
python::{DBNError, EnumIterator},
Action, BidAskPair, CbboMsg, Compression, ConsolidatedBidAskPair, Encoding, ErrorMsg,
ImbalanceMsg, InstrumentClass, InstrumentDefMsg, MatchAlgorithm, MboMsg, Mbp10Msg, Mbp1Msg,
Metadata, OhlcvMsg, RType, RecordHeader, SType, Schema, SecurityUpdateAction, Side, StatMsg,
Expand All @@ -29,10 +29,11 @@ fn databento_dbn(_py: Python<'_>, m: &Bound<PyModule>) -> PyResult<()> {
}
// all functions exposed to Python need to be added here
m.add_wrapped(wrap_pyfunction!(encode::update_encoded_metadata))?;
m.add("DBNError", m.py().get_type_bound::<DBNError>())?;
checked_add_class::<EnumIterator>(m)?;
checked_add_class::<Metadata>(m)?;
checked_add_class::<dbn_decoder::DbnDecoder>(m)?;
checked_add_class::<transcoder::Transcoder>(m)?;
checked_add_class::<Metadata>(m)?;
checked_add_class::<EnumIterator>(m)?;
// Records
checked_add_class::<RecordHeader>(m)?;
checked_add_class::<MboMsg>(m)?;
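Note: the lib.rs hunk above exposes the new exception with `m.add("DBNError", m.py().get_type_bound::<DBNError>())?`. The following hypothetical, self-contained sketch (invented module and function names, not databento code) shows how that registration pattern fits together with raising the exception from a `#[pyfunction]` via `PyResult`.

use pyo3::{prelude::*, wrap_pyfunction};

pyo3::create_exception!(demo_module, DemoError, pyo3::exceptions::PyException);

#[pyfunction]
fn first_byte(data: &[u8]) -> PyResult<u8> {
    // Surfaces in Python as `demo_module.DemoError`, a subclass of `Exception`,
    // so broad `except Exception` handlers keep working.
    data.first()
        .copied()
        .ok_or_else(|| DemoError::new_err("empty input"))
}

#[pymodule]
fn demo_module(_py: Python<'_>, m: &Bound<PyModule>) -> PyResult<()> {
    // Make the type importable: `from demo_module import DemoError`.
    m.add("DemoError", m.py().get_type_bound::<DemoError>())?;
    m.add_wrapped(wrap_pyfunction!(first_byte))?;
    Ok(())
}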
53 changes: 19 additions & 34 deletions python/src/transcoder.rs
@@ -12,7 +12,7 @@ use dbn::{
CsvEncoder, DbnMetadataEncoder, DbnRecordEncoder, DynWriter, EncodeRecordRef,
EncodeRecordTextExt, JsonEncoder,
},
python::{py_to_time_date, to_val_err},
python::{py_to_time_date, to_py_err},
Compression, Encoding, PitSymbolMap, RType, Record, RecordRef, Schema, SymbolIndex,
TsSymbolMap, VersionUpgradePolicy,
};
@@ -51,9 +51,7 @@ impl Transcoder {
}
let start_date = py_to_time_date(start_date)?;
let end_date = py_to_time_date(end_date)?;
symbol_map
.insert(iid, start_date, end_date, Arc::new(symbol))
.map_err(to_val_err)?;
symbol_map.insert(iid, start_date, end_date, Arc::new(symbol))?;
}
}
Some(symbol_map)
@@ -141,7 +139,7 @@ struct Inner<const E: u8> {

impl<const E: u8> Transcode for Inner<E> {
fn write(&mut self, bytes: &[u8]) -> PyResult<()> {
self.buffer.write_all(bytes).map_err(to_val_err)?;
self.buffer.write_all(bytes).map_err(to_py_err)?;
self.encode()
}

@@ -177,7 +175,7 @@ impl<const OUTPUT_ENC: u8> Inner<OUTPUT_ENC> {
}
Ok(Self {
buffer: io::Cursor::default(),
output: DynWriter::new(BufWriter::new(file), compression).map_err(to_val_err)?,
output: DynWriter::new(BufWriter::new(file), compression)?,
use_pretty_px: pretty_px.unwrap_or(true),
use_pretty_ts: pretty_ts.unwrap_or(true),
map_symbols: map_symbols.unwrap_or(true),
@@ -215,14 +213,12 @@ impl<const OUTPUT_ENC: u8> Inner<OUTPUT_ENC> {
self.input_version,
self.upgrade_policy,
self.ts_out,
)
.map_err(to_val_err)?;
)?;
let mut encoder = DbnRecordEncoder::new(&mut self.output);
loop {
match decoder.decode_record_ref() {
Ok(Some(rec)) => {
unsafe { encoder.encode_record_ref_ts_out(rec, self.ts_out) }
.map_err(to_val_err)?;
unsafe { encoder.encode_record_ref_ts_out(rec, self.ts_out) }?;
// keep track of position after last _successful_ decoding to
// ensure buffer is left in correct state in the case where one
// or more successful decodings is followed by a partial one, i.e.
@@ -234,7 +230,7 @@ impl<const OUTPUT_ENC: u8> Inner<OUTPUT_ENC> {
}
Err(err) => {
self.buffer.set_position(orig_position);
return Err(to_val_err(err));
return Err(PyErr::from(err));
}
}
}
@@ -248,15 +244,13 @@ impl<const OUTPUT_ENC: u8> Inner<OUTPUT_ENC> {
self.input_version,
self.upgrade_policy,
self.ts_out,
)
.map_err(to_val_err)?;
)?;

let mut encoder = CsvEncoder::builder(&mut self.output)
.use_pretty_px(self.use_pretty_px)
.use_pretty_ts(self.use_pretty_ts)
.write_header(false)
.build()
.map_err(to_val_err)?;
.build()?;
loop {
match decoder.decode_record_ref() {
Ok(Some(rec)) => {
@@ -275,8 +269,7 @@ impl<const OUTPUT_ENC: u8> Inner<OUTPUT_ENC> {
unsafe { encoder.encode_ref_ts_out_with_sym(rec, self.ts_out, symbol) }
} else {
unsafe { encoder.encode_record_ref_ts_out(rec, self.ts_out) }
}
.map_err(to_val_err)?;
}?;
}
// keep track of position after last _successful_ decoding to
// ensure buffer is left in correct state in the case where one
Expand All @@ -289,7 +282,7 @@ impl<const OUTPUT_ENC: u8> Inner<OUTPUT_ENC> {
}
Err(err) => {
self.buffer.set_position(orig_position);
return Err(to_val_err(err));
return Err(PyErr::from(err));
}
}
}
@@ -303,8 +296,7 @@ impl<const OUTPUT_ENC: u8> Inner<OUTPUT_ENC> {
self.input_version,
self.upgrade_policy,
self.ts_out,
)
.map_err(to_val_err)?;
)?;

let mut encoder = JsonEncoder::builder(&mut self.output)
.use_pretty_px(self.use_pretty_px)
@@ -319,8 +311,7 @@ impl<const OUTPUT_ENC: u8> Inner<OUTPUT_ENC> {
unsafe { encoder.encode_ref_ts_out_with_sym(rec, self.ts_out, symbol) }
} else {
unsafe { encoder.encode_record_ref_ts_out(rec, self.ts_out) }
}
.map_err(to_val_err)?;
}?;
// keep track of position after last _successful_ decoding to
// ensure buffer is left in correct state in the case where one
// or more successful decodings is followed by a partial one, i.e.
@@ -332,7 +323,7 @@ impl<const OUTPUT_ENC: u8> Inner<OUTPUT_ENC> {
}
Err(err) => {
self.buffer.set_position(orig_position);
return Err(to_val_err(err));
return Err(PyErr::from(err));
}
}
}
@@ -353,20 +344,16 @@ impl<const OUTPUT_ENC: u8> Inner<OUTPUT_ENC> {
metadata.upgrade(self.upgrade_policy);
// Setup live symbol mapping
if OUTPUT_ENC == Encoding::Dbn as u8 {
DbnMetadataEncoder::new(&mut self.output)
.encode(&metadata)
.map_err(to_val_err)?;
DbnMetadataEncoder::new(&mut self.output).encode(&metadata)?;
// CSV or JSON
} else if self.map_symbols {
if metadata.schema.is_some() {
// historical
// only read from metadata mappings if symbol_map is unpopulated,
// i.e. no `symbol_map` was passed in
if self.symbol_map.is_empty() {
self.symbol_map = metadata
.symbol_map()
.map(SymbolMap::Historical)
.map_err(to_val_err)?;
self.symbol_map =
metadata.symbol_map().map(SymbolMap::Historical)?;
}
} else {
// live
@@ -381,7 +368,7 @@ impl<const OUTPUT_ENC: u8> Inner<OUTPUT_ENC> {
{
return Ok(false);
}
return Err(to_val_err(err));
return Err(PyErr::from(err));
}
}
// decoding metadata and the header are both done once at the beginning
@@ -393,9 +380,7 @@ impl<const OUTPUT_ENC: u8> Inner<OUTPUT_ENC> {
};
let mut encoder =
CsvEncoder::new(&mut self.output, self.use_pretty_px, self.use_pretty_ts);
encoder
.encode_header_for_schema(schema, self.ts_out, self.map_symbols)
.map_err(to_val_err)?;
encoder.encode_header_for_schema(schema, self.ts_out, self.map_symbols)?;
}
}
Ok(true)