From 0a956ba785a2cb0b9437d191183ef7e3c73885e2 Mon Sep 17 00:00:00 2001 From: Mitya Selivanov Date: Tue, 9 Apr 2024 11:59:47 +0200 Subject: [PATCH] Fixes for different symbol size --- README.md | 28 +++++++++++----------------- example.py | 26 ++++++++++---------------- src/lib.rs | 49 +++++++++++++++++++++++++++++++++---------------- 3 files changed, 54 insertions(+), 49 deletions(-) diff --git a/README.md b/README.md index 0d2c241..9afc02d 100644 --- a/README.md +++ b/README.md @@ -10,35 +10,29 @@ maturin build ``` ### Example -Python program to use the library with Sip hash and 64-byte symbols: +Python program to use the library with SipHash and 1-byte symbols: ```py import riblt_rust_py as riblt -symbol_size = 64 +symbol_size = 1 key_0 = 123 key_1 = 456 -zeros = [0] * 63 - -sym0 = bytes([1] + zeros) -sym1 = bytes([2] + zeros) -sym2 = bytes([3] + zeros) -sym3 = bytes([4] + zeros) - enc = riblt.new_encoder_sip(symbol_size, key_0, key_1) -enc.add_symbol(sym0) -enc.add_symbol(sym1) -enc.add_symbol(sym3) +enc.add_symbol(bytes([1])) +enc.add_symbol(bytes([2])) +enc.add_symbol(bytes([4])) dec = riblt.new_decoder_sip(symbol_size, key_0, key_1) -dec.add_symbol(sym0) -dec.add_symbol(sym2) -dec.add_symbol(sym3) +dec.add_symbol(bytes([1])) +dec.add_symbol(bytes([3])) +dec.add_symbol(bytes([4])) while True: s = enc.produce_next_coded_symbol() + print("coded: " + str(s.data[0]) + ", " + str(s.hash) + ", " + str(s.count)) dec.add_coded_symbol(s) dec.try_decode() if dec.decoded(): @@ -47,8 +41,8 @@ while True: local_symbols = dec.get_local_symbols() remote_symbols = dec.get_remote_symbols() -print("local symbol: " + str(local_symbols[0].symbol[0])) -print("remote symbol: " + str(remote_symbols[0].symbol[0])) +print("local symbol: " + str(local_symbols[0].data[0])) +print("remote symbol: " + str(remote_symbols[0].data[0])) ``` To run the example: diff --git a/example.py b/example.py index 7abedde..bc5b0cb 100644 --- a/example.py +++ b/example.py @@ -1,31 +1,25 @@ import riblt_rust_py as riblt def example(): - symbol_size = 64 + symbol_size = 1 key_0 = 123 key_1 = 456 - zeros = [0] * 63 - - sym0 = bytes([1] + zeros) - sym1 = bytes([2] + zeros) - sym2 = bytes([3] + zeros) - sym3 = bytes([4] + zeros) - enc = riblt.new_encoder_sip(symbol_size, key_0, key_1) - enc.add_symbol(sym0) - enc.add_symbol(sym1) - enc.add_symbol(sym3) + enc.add_symbol(bytes([1])) + enc.add_symbol(bytes([2])) + enc.add_symbol(bytes([4])) dec = riblt.new_decoder_sip(symbol_size, key_0, key_1) - dec.add_symbol(sym0) - dec.add_symbol(sym2) - dec.add_symbol(sym3) + dec.add_symbol(bytes([1])) + dec.add_symbol(bytes([3])) + dec.add_symbol(bytes([4])) while True: s = enc.produce_next_coded_symbol() + print("coded: " + str(s.data[0]) + ", " + str(s.hash) + ", " + str(s.count)) dec.add_coded_symbol(s) dec.try_decode() if dec.decoded(): @@ -34,7 +28,7 @@ def example(): local_symbols = dec.get_local_symbols() remote_symbols = dec.get_remote_symbols() - print("local symbol: " + str(local_symbols[0].symbol[0])) - print("remote symbol: " + str(remote_symbols[0].symbol[0])) + print("local symbol: " + str(local_symbols[0].data[0])) + print("remote symbol: " + str(remote_symbols[0].data[0])) example() diff --git a/src/lib.rs b/src/lib.rs index 1e84006..b00d446 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -23,16 +23,21 @@ struct PySymbol { } #[pyclass] -struct PyCodedSymbol { - sym: CodedSymbol, +struct PyHashedSymbol { + #[pyo3(get, set)] + pub data : [u8; MAX_SIZE], + #[pyo3(get, set)] + pub hash : u64, } #[pyclass] -struct PyHashedSymbol { +struct PyCodedSymbol { #[pyo3(get, set)] - pub symbol : [u8; 64], + pub data : [u8; MAX_SIZE], #[pyo3(get, set)] - pub hash : u64, + pub hash : u64, + #[pyo3(get, set)] + pub count : i64, } #[pyclass] @@ -98,11 +103,11 @@ impl PyEncoder { } fn add_symbol(&mut self, bytes: &[u8]) -> PyResult<()> { - if bytes.len() > MAX_SIZE { + if bytes.len() > MAX_SIZE || bytes.len() != self.symbol_size { return Err(PyTypeError::new_err("invalid bytearray size")) } self.enc.add_symbol(&PySymbol { - bytes : core::array::from_fn(|i| bytes[i]), + bytes : core::array::from_fn(|i| if i < self.symbol_size { bytes[i] } else { 0 }), size : self.symbol_size, hash_type : self.hash_type, hash_keys : self.hash_keys, @@ -111,8 +116,11 @@ impl PyEncoder { } fn produce_next_coded_symbol(&mut self) -> PyResult { + let sym = self.enc.produce_next_coded_symbol(); Ok(PyCodedSymbol { - sym: self.enc.produce_next_coded_symbol(), + data : sym.symbol.bytes, + hash : sym.hash, + count : sym.count, }) } } @@ -125,11 +133,11 @@ impl PyDecoder { } fn add_symbol(&mut self, bytes: &[u8]) -> PyResult<()> { - if bytes.len() > MAX_SIZE { + if bytes.len() > MAX_SIZE || bytes.len() != self.symbol_size { return Err(PyTypeError::new_err("invalid byte array size")) } self.dec.add_symbol(&PySymbol { - bytes : core::array::from_fn(|i| bytes[i]), + bytes : core::array::from_fn(|i| if i < self.symbol_size { bytes[i] } else { 0 }), size : self.symbol_size, hash_type : self.hash_type, hash_keys : self.hash_keys, @@ -138,7 +146,16 @@ impl PyDecoder { } fn add_coded_symbol(&mut self, sym: &PyCodedSymbol) -> PyResult<()> { - self.dec.add_coded_symbol(&sym.sym); + self.dec.add_coded_symbol(&CodedSymbol:: { + symbol : PySymbol { + bytes : sym.data, + size : self.symbol_size, + hash_type : self.hash_type, + hash_keys : self.hash_keys, + }, + hash : sym.hash, + count : sym.count, + }); Ok(()) } @@ -159,8 +176,8 @@ impl PyDecoder { pyv.reserve_exact(v.len()); for i in 0..v.len() { pyv.push(PyHashedSymbol { - symbol : v[i].symbol.bytes, - hash : v[i].hash, + data : v[i].symbol.bytes, + hash : v[i].hash, }); } Ok(pyv) @@ -172,8 +189,8 @@ impl PyDecoder { pyv.reserve_exact(v.len()); for i in 0..v.len() { pyv.push(PyHashedSymbol { - symbol : v[i].symbol.bytes, - hash : v[i].hash, + data : v[i].symbol.bytes, + hash : v[i].hash, }); } Ok(pyv) @@ -203,8 +220,8 @@ fn new_decoder_sip(size: usize, key_0: u64, key_1: u64) -> PyResult { #[pymodule] fn riblt_rust_py(_py: Python, m: &PyModule) -> PyResult<()> { m.add_class::()?; - m.add_class::()?; m.add_class::()?; + m.add_class::()?; m.add_class::()?; m.add_class::()?; m.add_function(wrap_pyfunction!(new_encoder_sip, m)?)?;