Skip to content

Commit

Permalink
MOD: Reduce unsafe in record version conversions
Browse files Browse the repository at this point in the history
  • Loading branch information
threecgreen committed Dec 31, 2024
1 parent 6f712bc commit ec950c3
Show file tree
Hide file tree
Showing 5 changed files with 20 additions and 75 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
- Relaxed `DecodeRecord` trait constraint on `StreamIterDecoder`'s inner decoder
- Added `DbnMetadata` implementation for `StreamInnerDecoder` if the inner decoder
implements `DbnMetadata`
- Eliminate `unsafe` in `From` implementations for record structs from different versions

## 0.25.0 - 2024-12-17

Expand Down
6 changes: 3 additions & 3 deletions rust/dbn/src/compat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -118,15 +118,15 @@ unsafe fn upgrade_record<'a, T, U>(
) -> RecordRef<'a>
where
T: HasRType,
U: HasRType + for<'b> From<&'b T>,
U: AsRef<[u8]> + HasRType + for<'b> From<&'b T>,
{
if ts_out {
let rec = transmute_record_bytes::<WithTsOut<T>>(input).unwrap();
let upgraded = WithTsOut::new(U::from(&rec.rec), rec.ts_out);
std::ptr::copy_nonoverlapping(&upgraded, compat_buffer.as_mut_ptr().cast(), 1);
compat_buffer[..upgraded.as_ref().len()].copy_from_slice(upgraded.as_ref());
} else {
let upgraded = U::from(transmute_record_bytes::<T>(input).unwrap());
std::ptr::copy_nonoverlapping(&upgraded, compat_buffer.as_mut_ptr().cast(), 1);
compat_buffer[..upgraded.as_ref().len()].copy_from_slice(upgraded.as_ref());
}
RecordRef::new(compat_buffer)
}
Expand Down
8 changes: 5 additions & 3 deletions rust/dbn/src/record/conv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ pub unsafe fn transmute_record_mut<T: HasRType>(header: &mut RecordHeader) -> Op
/// This function returns an error if `s` contains more than N - 1 characters. The last
/// character is reserved for the null byte.
pub fn str_to_c_chars<const N: usize>(s: &str) -> Result<[c_char; N]> {
let s = s.as_bytes();
if s.len() > (N - 1) {
return Err(Error::encode(format!(
"string cannot be longer than {}; received str of length {}",
Expand All @@ -109,9 +110,10 @@ pub fn str_to_c_chars<const N: usize>(s: &str) -> Result<[c_char; N]> {
)));
}
let mut res = [0; N];
for (i, byte) in s.as_bytes().iter().enumerate() {
res[i] = *byte as c_char;
}
res[..s.len()].copy_from_slice(
// Safety: checked length of string and okay to interpret `u8` as `c_char`.
unsafe { std::mem::transmute::<&[u8], &[c_char]>(s) },
);
Ok(res)
}

Expand Down
35 changes: 6 additions & 29 deletions rust/dbn/src/v2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,14 +83,7 @@ impl From<&v1::InstrumentDefMsg> for InstrumentDefMsg {
tick_rule: old.tick_rule,
..Default::default()
};
// Safety: SYMBOL_CSTR_LEN_V1 is less than SYMBOL_CSTR_LEN
unsafe {
std::ptr::copy_nonoverlapping(
old.raw_symbol.as_ptr(),
res.raw_symbol.as_mut_ptr(),
v1::SYMBOL_CSTR_LEN,
);
}
res.raw_symbol[..v1::SYMBOL_CSTR_LEN].copy_from_slice(old.raw_symbol.as_slice());
res
}
}
Expand All @@ -106,10 +99,7 @@ impl From<&v1::ErrorMsg> for ErrorMsg {
),
..Default::default()
};
// Safety: new `err` is longer than older
unsafe {
std::ptr::copy_nonoverlapping(old.err.as_ptr(), new.err.as_mut_ptr(), new.err.len());
}
new.err[..old.err.len()].copy_from_slice(old.err.as_slice());
new
}
}
Expand All @@ -127,19 +117,9 @@ impl From<&v1::SymbolMappingMsg> for SymbolMappingMsg {
end_ts: old.end_ts,
..Default::default()
};
// Safety: SYMBOL_CSTR_LEN_V1 is less than SYMBOL_CSTR_LEN
unsafe {
std::ptr::copy_nonoverlapping(
old.stype_in_symbol.as_ptr(),
res.stype_in_symbol.as_mut_ptr(),
v1::SYMBOL_CSTR_LEN,
);
std::ptr::copy_nonoverlapping(
old.stype_out_symbol.as_ptr(),
res.stype_out_symbol.as_mut_ptr(),
v1::SYMBOL_CSTR_LEN,
);
}
res.stype_in_symbol[..v1::SYMBOL_CSTR_LEN].copy_from_slice(old.stype_in_symbol.as_slice());
res.stype_out_symbol[..v1::SYMBOL_CSTR_LEN]
.copy_from_slice(old.stype_out_symbol.as_slice());
res
}
}
Expand All @@ -155,10 +135,7 @@ impl From<&v1::SystemMsg> for SystemMsg {
),
..Default::default()
};
// Safety: new `msg` is longer than older
unsafe {
std::ptr::copy_nonoverlapping(old.msg.as_ptr(), new.msg.as_mut_ptr(), new.msg.len());
}
new.msg[..old.msg.len()].copy_from_slice(old.msg.as_slice());
new
}
}
Expand Down
45 changes: 5 additions & 40 deletions rust/dbn/src/v3/methods.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@ use crate::{
pretty::px_to_f64,
record::{c_chars_to_str, ts_to_dt},
rtype, v1, v2, Error, InstrumentClass, MatchAlgorithm, RecordHeader, Result,
SecurityUpdateAction, Side, UserDefinedInstrument, UNDEF_PRICE,
SecurityUpdateAction, UserDefinedInstrument,
};

use super::{InstrumentDefMsg, SYMBOL_CSTR_LEN};
use super::InstrumentDefMsg;

impl From<&v1::InstrumentDefMsg> for InstrumentDefMsg {
fn from(old: &v1::InstrumentDefMsg) -> Self {
Expand Down Expand Up @@ -74,31 +74,9 @@ impl From<&v1::InstrumentDefMsg> for InstrumentDefMsg {
contract_multiplier_unit: old.contract_multiplier_unit,
flow_schedule_type: old.flow_schedule_type,
tick_rule: old.tick_rule,
// Set later
raw_symbol: [0; SYMBOL_CSTR_LEN],
leg_count: 0,
leg_index: 0,
leg_price: UNDEF_PRICE,
leg_delta: UNDEF_PRICE,
leg_instrument_id: 0,
leg_ratio_price_numerator: 0,
leg_ratio_price_denominator: 0,
leg_ratio_qty_numerator: 0,
leg_ratio_qty_denominator: 0,
leg_underlying_id: 0,
leg_raw_symbol: [0; SYMBOL_CSTR_LEN],
leg_instrument_class: 0,
leg_side: Side::None as c_char,
_reserved: Default::default(),
..Default::default()
};
// Safety: SYMBOL_CSTR_LEN_V1 is less than SYMBOL_CSTR_LEN
unsafe {
std::ptr::copy_nonoverlapping(
old.raw_symbol.as_ptr(),
res.raw_symbol.as_mut_ptr(),
v1::SYMBOL_CSTR_LEN,
);
}
res.raw_symbol[..v1::SYMBOL_CSTR_LEN].copy_from_slice(old.raw_symbol.as_slice());
res
}
}
Expand Down Expand Up @@ -169,20 +147,7 @@ impl From<&v2::InstrumentDefMsg> for InstrumentDefMsg {
flow_schedule_type: old.flow_schedule_type,
tick_rule: old.tick_rule,
raw_symbol: old.raw_symbol,
leg_count: 0,
leg_index: 0,
leg_price: UNDEF_PRICE,
leg_delta: UNDEF_PRICE,
leg_instrument_id: 0,
leg_ratio_price_numerator: 0,
leg_ratio_price_denominator: 0,
leg_ratio_qty_numerator: 0,
leg_ratio_qty_denominator: 0,
leg_underlying_id: 0,
leg_raw_symbol: [0; SYMBOL_CSTR_LEN],
leg_instrument_class: 0,
leg_side: Side::None as c_char,
_reserved: Default::default(),
..Default::default()
}
}
}
Expand Down

0 comments on commit ec950c3

Please sign in to comment.