From 701d50648b3795b1af0b49589dfa21a9cc2cdd1c Mon Sep 17 00:00:00 2001 From: Mitya Selivanov Date: Tue, 9 Apr 2024 17:15:04 +0200 Subject: [PATCH] Macro for different type instances --- README.md | 7 +- example.py | 7 +- src/lib.rs | 420 +++++++++++++++++++++++++++++------------------------ 3 files changed, 240 insertions(+), 194 deletions(-) diff --git a/README.md b/README.md index 9afc02d..2314f84 100644 --- a/README.md +++ b/README.md @@ -15,16 +15,15 @@ Python program to use the library with SipHash and 1-byte symbols: import riblt_rust_py as riblt symbol_size = 1 -key_0 = 123 -key_1 = 456 +key = bytes([16, 15, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 2, 1]) -enc = riblt.new_encoder_sip(symbol_size, key_0, key_1) +enc = riblt.new_encoder_sip(symbol_size, key) enc.add_symbol(bytes([1])) enc.add_symbol(bytes([2])) enc.add_symbol(bytes([4])) -dec = riblt.new_decoder_sip(symbol_size, key_0, key_1) +dec = riblt.new_decoder_sip(symbol_size, key) dec.add_symbol(bytes([1])) dec.add_symbol(bytes([3])) diff --git a/example.py b/example.py index bc5b0cb..02375c1 100644 --- a/example.py +++ b/example.py @@ -2,16 +2,15 @@ def example(): symbol_size = 1 - key_0 = 123 - key_1 = 456 + key = bytes([16, 15, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 2, 1]) - enc = riblt.new_encoder_sip(symbol_size, key_0, key_1) + enc = riblt.new_encoder_sip(symbol_size, key) enc.add_symbol(bytes([1])) enc.add_symbol(bytes([2])) enc.add_symbol(bytes([4])) - dec = riblt.new_decoder_sip(symbol_size, key_0, key_1) + dec = riblt.new_decoder_sip(symbol_size, key) dec.add_symbol(bytes([1])) dec.add_symbol(bytes([3])) diff --git a/src/lib.rs b/src/lib.rs index ec3d0d0..e5c5669 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -5,7 +5,8 @@ use riblt::*; #[allow(deprecated)] use std::hash::{SipHasher, Hasher}; -const MAX_SIZE : usize = 64; +const HASH_SIZE : usize = 8; +const KEY_SIZE : usize = 16; #[derive(Clone, Copy)] enum Hash { @@ -13,223 +14,270 @@ enum Hash { SIP } -#[pyclass] -#[derive(Clone, Copy)] -struct PySymbol { - bytes : [u8; MAX_SIZE], - size : usize, - hash_type : Hash, - hash_keys : (u64, u64), -} +macro_rules! instant { + ($max_size : expr, + $Sym : ident, + $Hashed : ident, + $Coded : ident, + $Enc : ident, + $Dec : ident) => { + #[pyclass] + #[derive(Clone, Copy)] + struct $Sym { + bytes : [u8; $max_size], + size : usize, + hash_type : Hash, + hash_key : [u8; KEY_SIZE], + } -#[pyclass] -struct PyHashedSymbol { - #[pyo3(get, set)] - pub data : [u8; MAX_SIZE], - #[pyo3(get, set)] - pub hash : u64, -} + #[pyclass] + struct $Hashed { + #[pyo3(get, set)] + pub data : [u8; $max_size], + #[pyo3(get, set)] + pub hash : [u8; HASH_SIZE], + } -#[pyclass] -struct PyCodedSymbol { - #[pyo3(get, set)] - pub data : [u8; MAX_SIZE], - #[pyo3(get, set)] - pub hash : u64, - #[pyo3(get, set)] - pub count : i64, -} + #[pyclass] + struct $Coded { + #[pyo3(get, set)] + pub data : [u8; $max_size], + #[pyo3(get, set)] + pub hash : [u8; HASH_SIZE], + #[pyo3(get, set)] + pub count : i64, + } -#[pyclass] -struct PyEncoder { - enc : Encoder, - symbol_size : usize, - hash_type : Hash, - hash_keys : (u64, u64), -} + #[pyclass] + struct $Enc { + enc : Encoder<$Sym>, + symbol_size : usize, + hash_type : Hash, + hash_key : [u8; KEY_SIZE], + } -#[pyclass] -struct PyDecoder { - dec : Decoder, - symbol_size : usize, - hash_type : Hash, - hash_keys : (u64, u64), -} + #[pyclass] + struct $Dec { + dec : Decoder<$Sym>, + symbol_size : usize, + hash_type : Hash, + hash_key : [u8; KEY_SIZE], + } -impl Symbol for PySymbol { - fn zero() -> PySymbol { - return PySymbol { - bytes : core::array::from_fn(|_| 0), - size : 0, - hash_type : Hash::NONE, - hash_keys : (0, 0), - }; - } + impl Symbol for $Sym { + fn zero() -> $Sym { + return $Sym { + bytes : core::array::from_fn(|_| 0), + size : 0, + hash_type : Hash::NONE, + hash_key : core::array::from_fn(|_| 0), + }; + } - fn xor(&self, other: &PySymbol) -> PySymbol { - let (s, t, k0, k1) = match self.hash_type { - Hash::NONE => (other.size, other.hash_type, other.hash_keys.0, other.hash_keys.1), - _ => ( self.size, self.hash_type, self.hash_keys.0, self.hash_keys.1), - }; - return PySymbol { - bytes : core::array::from_fn(|i| self.bytes[i] ^ other.bytes[i]), - size : s, - hash_type : t, - hash_keys : (k0, k1), - }; - } + fn xor(&self, other: &$Sym) -> $Sym { + let (s, t, k) = match self.hash_type { + Hash::NONE => (other.size, other.hash_type, other.hash_key), + _ => ( self.size, self.hash_type, self.hash_key), + }; + return $Sym { + bytes : core::array::from_fn(|i| self.bytes[i] ^ other.bytes[i]), + size : s, + hash_type : t, + hash_key : k, + }; + } + + #[allow(deprecated)] + fn hash(&self) -> u64 { + match self.hash_type { + Hash::SIP => { + let k = self.hash_key; + let mut hasher = SipHasher::new_with_keys( + u64::from_le_bytes([k[ 0], k[ 1], k[ 2], k[ 3], k[ 4], k[ 5], k[ 6], k[ 7]]), + u64::from_le_bytes([k[ 8], k[ 9], k[10], k[11], k[12], k[13], k[14], k[15]]) + ); + hasher.write(&self.bytes); + return hasher.finish(); + }, - #[allow(deprecated)] - fn hash(&self) -> u64 { - match self.hash_type { - Hash::SIP => { - let mut hasher = SipHasher::new_with_keys(self.hash_keys.0, self.hash_keys.1); - hasher.write(&self.bytes); - return hasher.finish(); - }, - - _ => { - return 0; - }, + _ => { + return 0; + }, + } + } } - } -} -#[pymethods] -impl PyEncoder { - fn reset(&mut self) -> PyResult<()> { - self.enc.reset(); - Ok(()) - } + #[pymethods] + impl $Enc { + fn reset(&mut self) -> PyResult<()> { + self.enc.reset(); + Ok(()) + } + + fn add_symbol(&mut self, bytes: &[u8]) -> PyResult<()> { + if bytes.len() > $max_size || bytes.len() != self.symbol_size { + return Err(PyTypeError::new_err("invalid byte array size")) + } + self.enc.add_symbol(&$Sym { + bytes : core::array::from_fn(|i| if i < self.symbol_size { bytes[i] } else { 0 }), + size : self.symbol_size, + hash_type : self.hash_type, + hash_key : self.hash_key, + }); + Ok(()) + } - fn add_symbol(&mut self, bytes: &[u8]) -> PyResult<()> { - if bytes.len() > MAX_SIZE || bytes.len() != self.symbol_size { - return Err(PyTypeError::new_err("invalid byte array size")) + fn produce_next_coded_symbol(&mut self) -> PyResult<$Coded> { + let sym = self.enc.produce_next_coded_symbol(); + Ok($Coded { + data : sym.symbol.bytes, + hash : sym.hash.to_le_bytes(), + count : sym.count, + }) + } } - self.enc.add_symbol(&PySymbol { - bytes : core::array::from_fn(|i| if i < self.symbol_size { bytes[i] } else { 0 }), - size : self.symbol_size, - hash_type : self.hash_type, - hash_keys : self.hash_keys, - }); - Ok(()) - } - fn produce_next_coded_symbol(&mut self) -> PyResult { - let sym = self.enc.produce_next_coded_symbol(); - Ok(PyCodedSymbol { - data : sym.symbol.bytes, - hash : sym.hash, - count : sym.count, - }) - } -} + #[pymethods] + impl $Dec { + fn reset(&mut self) -> PyResult<()> { + self.dec.reset(); + Ok(()) + } -#[pymethods] -impl PyDecoder { - fn reset(&mut self) -> PyResult<()> { - self.dec.reset(); - Ok(()) - } + fn add_symbol(&mut self, bytes: &[u8]) -> PyResult<()> { + if bytes.len() > $max_size || bytes.len() != self.symbol_size { + return Err(PyTypeError::new_err("invalid byte array size")) + } + self.dec.add_symbol(&$Sym { + bytes : core::array::from_fn(|i| if i < self.symbol_size { bytes[i] } else { 0 }), + size : self.symbol_size, + hash_type : self.hash_type, + hash_key : self.hash_key, + }); + Ok(()) + } - fn add_symbol(&mut self, bytes: &[u8]) -> PyResult<()> { - if bytes.len() > MAX_SIZE || bytes.len() != self.symbol_size { - return Err(PyTypeError::new_err("invalid byte array size")) - } - self.dec.add_symbol(&PySymbol { - bytes : core::array::from_fn(|i| if i < self.symbol_size { bytes[i] } else { 0 }), - size : self.symbol_size, - hash_type : self.hash_type, - hash_keys : self.hash_keys, - }); - Ok(()) - } + fn add_coded_symbol(&mut self, sym: &$Coded) -> PyResult<()> { + self.dec.add_coded_symbol(&CodedSymbol::<$Sym> { + symbol : $Sym { + bytes : sym.data, + size : self.symbol_size, + hash_type : self.hash_type, + hash_key : self.hash_key, + }, + hash : u64::from_le_bytes(sym.hash), + count : sym.count, + }); + Ok(()) + } - fn add_coded_symbol(&mut self, sym: &PyCodedSymbol) -> PyResult<()> { - self.dec.add_coded_symbol(&CodedSymbol:: { - symbol : PySymbol { - bytes : sym.data, - size : self.symbol_size, - hash_type : self.hash_type, - hash_keys : self.hash_keys, - }, - hash : sym.hash, - count : sym.count, - }); - Ok(()) - } + fn try_decode(&mut self) -> PyResult<()> { + if self.dec.try_decode().is_err() { + return Err(PyRuntimeError::new_err("decoding error")); + } + Ok(()) + } - fn try_decode(&mut self) -> PyResult<()> { - if self.dec.try_decode().is_err() { - return Err(PyRuntimeError::new_err("decoding error")); - } - Ok(()) - } + fn decoded(&self) -> PyResult { + Ok(self.dec.decoded()) + } - fn decoded(&self) -> PyResult { - Ok(self.dec.decoded()) - } + fn get_remote_symbols(&self) -> PyResult> { + let v = self.dec.get_remote_symbols(); + let mut pyv = Vec::<$Hashed>::new(); + pyv.reserve_exact(v.len()); + for i in 0..v.len() { + pyv.push($Hashed { + data : v[i].symbol.bytes, + hash : v[i].hash.to_le_bytes(), + }); + } + Ok(pyv) + } - fn get_remote_symbols(&self) -> PyResult> { - let v = self.dec.get_remote_symbols(); - let mut pyv = Vec::::new(); - pyv.reserve_exact(v.len()); - for i in 0..v.len() { - pyv.push(PyHashedSymbol { - data : v[i].symbol.bytes, - hash : v[i].hash, - }); + fn get_local_symbols(&self) -> PyResult> { + let v = self.dec.get_local_symbols(); + let mut pyv = Vec::<$Hashed>::new(); + pyv.reserve_exact(v.len()); + for i in 0..v.len() { + pyv.push($Hashed { + data : v[i].symbol.bytes, + hash : v[i].hash.to_le_bytes(), + }); + } + Ok(pyv) + } } - Ok(pyv) - } + }; +} - fn get_local_symbols(&self) -> PyResult> { - let v = self.dec.get_local_symbols(); - let mut pyv = Vec::::new(); - pyv.reserve_exact(v.len()); - for i in 0..v.len() { - pyv.push(PyHashedSymbol { - data : v[i].symbol.bytes, - hash : v[i].hash, - }); - } - Ok(pyv) +macro_rules! add_types { + ($module : ident, + $Sym : ident, + $Hashed : ident, + $Coded : ident, + $Enc : ident, + $Dec : ident) => { + $module.add_class::<$Sym>()?; + $module.add_class::<$Hashed>()?; + $module.add_class::<$Coded>()?; + $module.add_class::<$Enc>()?; + $module.add_class::<$Dec>()?; } } +const SIZE_0 : usize = 64; +const SIZE_MAX : usize = 4096; + +instant!(SIZE_0, Sym0, Hashed0, Coded0, Enc0, Dec0 ); +instant!(SIZE_MAX, SymMax, HashedMax, CodedMax, EncMax, DecMax); + #[pyfunction] -fn new_encoder_sip(size: usize, key_0: u64, key_1: u64) -> PyResult { - if size > MAX_SIZE { - return Err(PyValueError::new_err("size is too big")); +fn new_encoder_sip(py: Python, size: usize, key: [u8; 16]) -> PyResult { + if size <= SIZE_0 { + return Ok(Enc0 { + enc : Encoder::::new(), + symbol_size : size, + hash_type : Hash::SIP, + hash_key : key, + }.into_py(py)); + } + if size <= SIZE_MAX { + return Ok(EncMax { + enc : Encoder::::new(), + symbol_size : size, + hash_type : Hash::SIP, + hash_key : key, + }.into_py(py)); } - return Ok(PyEncoder { - enc : Encoder::::new(), - symbol_size : size, - hash_type : Hash::SIP, - hash_keys : (key_0, key_1), - }); + return Err(PyValueError::new_err("size is too big")); } #[pyfunction] -fn new_decoder_sip(size: usize, key_0: u64, key_1: u64) -> PyResult { - if size > MAX_SIZE { - return Err(PyValueError::new_err("size is too big")); +fn new_decoder_sip(py: Python, size: usize, key: [u8; 16]) -> PyResult { + if size <= SIZE_0 { + return Ok(Dec0 { + dec : Decoder::::new(), + symbol_size : size, + hash_type : Hash::SIP, + hash_key : key, + }.into_py(py)); } - return Ok(PyDecoder { - dec : Decoder::::new(), - symbol_size : size, - hash_type : Hash::SIP, - hash_keys : (key_0, key_1), - }); -} + if size <= SIZE_MAX { + return Ok(DecMax { + dec : Decoder::::new(), + symbol_size : size, + hash_type : Hash::SIP, + hash_key : key, + }.into_py(py)); + } + return Err(PyValueError::new_err("size is too big")); +} #[pymodule] fn riblt_rust_py(_py: Python, m: &PyModule) -> PyResult<()> { - m.add_class::()?; - m.add_class::()?; - m.add_class::()?; - m.add_class::()?; - m.add_class::()?; + add_types!(m, Sym0, Hashed0, Coded0, Enc0, Dec0 ); + add_types!(m, SymMax, HashedMax, CodedMax, EncMax, DecMax); m.add_function(wrap_pyfunction!(new_encoder_sip, m)?)?; m.add_function(wrap_pyfunction!(new_decoder_sip, m)?)?; Ok(())