Skip to content

Commit

Permalink
Fixes for different symbol size
Browse files Browse the repository at this point in the history
  • Loading branch information
automainint committed Apr 9, 2024
1 parent 533cd52 commit 0a956ba
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 49 deletions.
28 changes: 11 additions & 17 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,35 +10,29 @@ maturin build
```

### Example
Python program to use the library with Sip hash and 64-byte symbols:
Python program to use the library with SipHash and 1-byte symbols:
```py
import riblt_rust_py as riblt

symbol_size = 64
symbol_size = 1
key_0 = 123
key_1 = 456

zeros = [0] * 63

sym0 = bytes([1] + zeros)
sym1 = bytes([2] + zeros)
sym2 = bytes([3] + zeros)
sym3 = bytes([4] + zeros)

enc = riblt.new_encoder_sip(symbol_size, key_0, key_1)

enc.add_symbol(sym0)
enc.add_symbol(sym1)
enc.add_symbol(sym3)
enc.add_symbol(bytes([1]))
enc.add_symbol(bytes([2]))
enc.add_symbol(bytes([4]))

dec = riblt.new_decoder_sip(symbol_size, key_0, key_1)

dec.add_symbol(sym0)
dec.add_symbol(sym2)
dec.add_symbol(sym3)
dec.add_symbol(bytes([1]))
dec.add_symbol(bytes([3]))
dec.add_symbol(bytes([4]))

while True:
s = enc.produce_next_coded_symbol()
print("coded: " + str(s.data[0]) + ", " + str(s.hash) + ", " + str(s.count))
dec.add_coded_symbol(s)
dec.try_decode()
if dec.decoded():
Expand All @@ -47,8 +41,8 @@ while True:
local_symbols = dec.get_local_symbols()
remote_symbols = dec.get_remote_symbols()

print("local symbol: " + str(local_symbols[0].symbol[0]))
print("remote symbol: " + str(remote_symbols[0].symbol[0]))
print("local symbol: " + str(local_symbols[0].data[0]))
print("remote symbol: " + str(remote_symbols[0].data[0]))
```

To run the example:
Expand Down
26 changes: 10 additions & 16 deletions example.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,25 @@
import riblt_rust_py as riblt

def example():
symbol_size = 64
symbol_size = 1
key_0 = 123
key_1 = 456

zeros = [0] * 63

sym0 = bytes([1] + zeros)
sym1 = bytes([2] + zeros)
sym2 = bytes([3] + zeros)
sym3 = bytes([4] + zeros)

enc = riblt.new_encoder_sip(symbol_size, key_0, key_1)

enc.add_symbol(sym0)
enc.add_symbol(sym1)
enc.add_symbol(sym3)
enc.add_symbol(bytes([1]))
enc.add_symbol(bytes([2]))
enc.add_symbol(bytes([4]))

dec = riblt.new_decoder_sip(symbol_size, key_0, key_1)

dec.add_symbol(sym0)
dec.add_symbol(sym2)
dec.add_symbol(sym3)
dec.add_symbol(bytes([1]))
dec.add_symbol(bytes([3]))
dec.add_symbol(bytes([4]))

while True:
s = enc.produce_next_coded_symbol()
print("coded: " + str(s.data[0]) + ", " + str(s.hash) + ", " + str(s.count))
dec.add_coded_symbol(s)
dec.try_decode()
if dec.decoded():
Expand All @@ -34,7 +28,7 @@ def example():
local_symbols = dec.get_local_symbols()
remote_symbols = dec.get_remote_symbols()

print("local symbol: " + str(local_symbols[0].symbol[0]))
print("remote symbol: " + str(remote_symbols[0].symbol[0]))
print("local symbol: " + str(local_symbols[0].data[0]))
print("remote symbol: " + str(remote_symbols[0].data[0]))

example()
49 changes: 33 additions & 16 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,21 @@ struct PySymbol {
}

#[pyclass]
struct PyCodedSymbol {
sym: CodedSymbol<PySymbol>,
struct PyHashedSymbol {
#[pyo3(get, set)]
pub data : [u8; MAX_SIZE],
#[pyo3(get, set)]
pub hash : u64,
}

#[pyclass]
struct PyHashedSymbol {
struct PyCodedSymbol {
#[pyo3(get, set)]
pub symbol : [u8; 64],
pub data : [u8; MAX_SIZE],
#[pyo3(get, set)]
pub hash : u64,
pub hash : u64,
#[pyo3(get, set)]
pub count : i64,
}

#[pyclass]
Expand Down Expand Up @@ -98,11 +103,11 @@ impl PyEncoder {
}

fn add_symbol(&mut self, bytes: &[u8]) -> PyResult<()> {
if bytes.len() > MAX_SIZE {
if bytes.len() > MAX_SIZE || bytes.len() != self.symbol_size {
return Err(PyTypeError::new_err("invalid bytearray size"))
}
self.enc.add_symbol(&PySymbol {
bytes : core::array::from_fn(|i| bytes[i]),
bytes : core::array::from_fn(|i| if i < self.symbol_size { bytes[i] } else { 0 }),
size : self.symbol_size,
hash_type : self.hash_type,
hash_keys : self.hash_keys,
Expand All @@ -111,8 +116,11 @@ impl PyEncoder {
}

fn produce_next_coded_symbol(&mut self) -> PyResult<PyCodedSymbol> {
let sym = self.enc.produce_next_coded_symbol();
Ok(PyCodedSymbol {
sym: self.enc.produce_next_coded_symbol(),
data : sym.symbol.bytes,
hash : sym.hash,
count : sym.count,
})
}
}
Expand All @@ -125,11 +133,11 @@ impl PyDecoder {
}

fn add_symbol(&mut self, bytes: &[u8]) -> PyResult<()> {
if bytes.len() > MAX_SIZE {
if bytes.len() > MAX_SIZE || bytes.len() != self.symbol_size {
return Err(PyTypeError::new_err("invalid byte array size"))
}
self.dec.add_symbol(&PySymbol {
bytes : core::array::from_fn(|i| bytes[i]),
bytes : core::array::from_fn(|i| if i < self.symbol_size { bytes[i] } else { 0 }),
size : self.symbol_size,
hash_type : self.hash_type,
hash_keys : self.hash_keys,
Expand All @@ -138,7 +146,16 @@ impl PyDecoder {
}

fn add_coded_symbol(&mut self, sym: &PyCodedSymbol) -> PyResult<()> {
self.dec.add_coded_symbol(&sym.sym);
self.dec.add_coded_symbol(&CodedSymbol::<PySymbol> {
symbol : PySymbol {
bytes : sym.data,
size : self.symbol_size,
hash_type : self.hash_type,
hash_keys : self.hash_keys,
},
hash : sym.hash,
count : sym.count,
});
Ok(())
}

Expand All @@ -159,8 +176,8 @@ impl PyDecoder {
pyv.reserve_exact(v.len());
for i in 0..v.len() {
pyv.push(PyHashedSymbol {
symbol : v[i].symbol.bytes,
hash : v[i].hash,
data : v[i].symbol.bytes,
hash : v[i].hash,
});
}
Ok(pyv)
Expand All @@ -172,8 +189,8 @@ impl PyDecoder {
pyv.reserve_exact(v.len());
for i in 0..v.len() {
pyv.push(PyHashedSymbol {
symbol : v[i].symbol.bytes,
hash : v[i].hash,
data : v[i].symbol.bytes,
hash : v[i].hash,
});
}
Ok(pyv)
Expand Down Expand Up @@ -203,8 +220,8 @@ fn new_decoder_sip(size: usize, key_0: u64, key_1: u64) -> PyResult<PyDecoder> {
#[pymodule]
fn riblt_rust_py(_py: Python, m: &PyModule) -> PyResult<()> {
m.add_class::<PySymbol>()?;
m.add_class::<PyCodedSymbol>()?;
m.add_class::<PyHashedSymbol>()?;
m.add_class::<PyCodedSymbol>()?;
m.add_class::<PyEncoder>()?;
m.add_class::<PyDecoder>()?;
m.add_function(wrap_pyfunction!(new_encoder_sip, m)?)?;
Expand Down

0 comments on commit 0a956ba

Please sign in to comment.