diff --git a/python/mpmc_queue_blocking_mixed.py b/python/mpmc_queue_blocking_mixed.py new file mode 100644 index 0000000..e84a3e3 --- /dev/null +++ b/python/mpmc_queue_blocking_mixed.py @@ -0,0 +1,58 @@ +import time +import os +from rocksq import remove_mpmc_queue, StartPosition +from rocksq.blocking import MpmcQueue + +NUM = 1 +OPS = 1000 +RELEASE_GIL = True +PATH = '/tmp/mpmc-queue' +TTL = 60 +LABEL_ONE = 'label1' +LABEL_TWO = 'label2' +LABEL_THREE = 'label3' + +# if directory exists, remove it +if os.path.exists(PATH): + remove_mpmc_queue(PATH) + +q = MpmcQueue(PATH, TTL) + +start = time.time() +for i in range(OPS): + data = [bytes(str(i), 'utf-8')] + q.add(data, no_gil=RELEASE_GIL) + v = q.next(label=LABEL_ONE, start_position=StartPosition.Oldest, max_elements=NUM, no_gil=RELEASE_GIL) + assert len(v) == NUM + assert v == data + +end = time.time() + +print("Time taken: %f" % (end - start)) + +v = q.next(label=LABEL_ONE, start_position=StartPosition.Oldest, max_elements=NUM, no_gil=RELEASE_GIL) +assert len(v) == 0 + +v = q.next(label=LABEL_TWO, start_position=StartPosition.Newest, max_elements=NUM, no_gil=RELEASE_GIL) +assert len(v) == NUM +assert v == [bytes(str(OPS-1), 'utf-8')] + +labels = q.labels +assert len(labels) == 2 +assert LABEL_ONE in labels +assert LABEL_TWO in labels + +r = q.remove_label(LABEL_THREE) +assert not r + +r = q.remove_label(LABEL_ONE) +assert r + +labels = q.labels +assert len(labels) == 1 +assert LABEL_ONE not in labels +assert LABEL_TWO in labels + +v = q.next(label=LABEL_ONE, start_position=StartPosition.Oldest, max_elements=NUM, no_gil=RELEASE_GIL) +assert len(v) == NUM +assert v == [bytes(str(0), 'utf-8')] diff --git a/python/mpmc_queue_blocking_w_r.py b/python/mpmc_queue_blocking_w_r.py new file mode 100644 index 0000000..8b2f6a4 --- /dev/null +++ b/python/mpmc_queue_blocking_w_r.py @@ -0,0 +1,31 @@ +import time +import os +from rocksq import remove_mpmc_queue, StartPosition +from rocksq.blocking import MpmcQueue + +NUM = 1 +OPS = 1000 +RELEASE_GIL = True +PATH = '/tmp/mpmc-queue' +TTL = 60 +LABEL = 'label' + +# if directory exists, remove it +if os.path.exists(PATH): + remove_mpmc_queue(PATH) + +q = MpmcQueue(PATH, TTL) + +start = time.time() +for i in range(OPS): + data = [bytes(str(i), 'utf-8')] + q.add(data, no_gil=RELEASE_GIL) + +for i in range(OPS): + v = q.next(label=LABEL, start_position=StartPosition.Oldest, max_elements=NUM, no_gil=RELEASE_GIL) + assert len(v) == NUM + assert v == [bytes(str(i), 'utf-8')] + +end = time.time() + +print("Time taken: %f" % (end - start)) diff --git a/python/mpmc_queue_nonblocking_mixed.py b/python/mpmc_queue_nonblocking_mixed.py new file mode 100644 index 0000000..0f85497 --- /dev/null +++ b/python/mpmc_queue_nonblocking_mixed.py @@ -0,0 +1,58 @@ +import time +import os +from rocksq import remove_mpmc_queue, StartPosition +from rocksq.nonblocking import MpmcQueue + +NUM = 1 +OPS = 1000 +RELEASE_GIL = True +PATH = '/tmp/mpmc-queue' +TTL = 60 +LABEL_ONE = 'label1' +LABEL_TWO = 'label2' +LABEL_THREE = 'label3' + +# if directory exists, remove it +if os.path.exists(PATH): + remove_mpmc_queue(PATH) + +q = MpmcQueue(PATH, TTL) + +start = time.time() +for i in range(OPS): + data = [bytes(str(i), 'utf-8')] + q.add(data, no_gil=RELEASE_GIL).get() + v = q.next(label=LABEL_ONE, start_position=StartPosition.Oldest, max_elements=NUM, no_gil=RELEASE_GIL).get().data + assert len(v) == NUM + assert v == data + +end = time.time() + +print("Time taken: %f" % (end - start)) + +v = q.next(label=LABEL_ONE, start_position=StartPosition.Oldest, max_elements=NUM, no_gil=RELEASE_GIL).get().data +assert len(v) == 0 + +v = q.next(label=LABEL_TWO, start_position=StartPosition.Newest, max_elements=NUM, no_gil=RELEASE_GIL).get().data +assert len(v) == NUM +assert v == [bytes(str(OPS-1), 'utf-8')] + +labels = q.labels.get().labels +assert len(labels) == 2 +assert LABEL_ONE in labels +assert LABEL_TWO in labels + +r = q.remove_label(LABEL_THREE).get().removed_label +assert not r + +r = q.remove_label(LABEL_ONE).get().removed_label +assert r + +labels = q.labels.get().labels +assert len(labels) == 1 +assert LABEL_ONE not in labels +assert LABEL_TWO in labels + +v = q.next(label=LABEL_ONE, start_position=StartPosition.Oldest, max_elements=NUM, no_gil=RELEASE_GIL).get().data +assert len(v) == NUM +assert v == [bytes(str(0), 'utf-8')] diff --git a/queue_py/python/rocksq/blocking/blocking.pyi b/queue_py/python/rocksq/blocking/blocking.pyi index 1a7272d..bb8feae 100644 --- a/queue_py/python/rocksq/blocking/blocking.pyi +++ b/queue_py/python/rocksq/blocking/blocking.pyi @@ -1,3 +1,5 @@ +from rocksq import StartPosition + class PersistentQueueWithCapacity: def __init__(self, path: str, max_elements: int = 1_000_000_000): ... @@ -15,4 +17,28 @@ class PersistentQueueWithCapacity: def payload_size(self) -> int: ... @property - def len(self) -> int: ... \ No newline at end of file + def len(self) -> int: ... + +class MpmcQueue: + def __init__(self, path: str, ttl: int): ... + + def add(self, items: list[bytes], no_gil: bool = True): ... + + def next(self, label: str, start_position: StartPosition, max_elements: int = 1, no_gil: bool = True) -> list[bytes]: ... + + @property + def is_empty(self) -> bool: ... + + @property + def disk_size(self) -> int: ... + + @property + def payload_size(self) -> int: ... + + @property + def len(self) -> int: ... + + @property + def labels(self): list[str]: ... + + def remove_label(self, label: str) -> bool: ... diff --git a/queue_py/python/rocksq/nonblocking/nonblocking.pyi b/queue_py/python/rocksq/nonblocking/nonblocking.pyi index 31e72d0..e4ba247 100644 --- a/queue_py/python/rocksq/nonblocking/nonblocking.pyi +++ b/queue_py/python/rocksq/nonblocking/nonblocking.pyi @@ -1,5 +1,5 @@ from typing import Optional - +from rocksq import StartPosition class ResponseVariant: @property @@ -39,3 +39,48 @@ class PersistentQueueWithCapacity: @property def len(self) -> Response: ... + +class MpmcResponseVariant: + @property + def data(self) -> Optional[list[bytes]]: ... + + @property + def labels(self) -> Optional[list[str]]: ... + + @property + def removed_label(self) -> Optional[bool]: ... + + @property + def len(self) -> Optional[int]: ... + + @property + def size(self) -> Optional[int]: ... + +class MpmcResponse: + @property + def is_ready(self) -> bool: ... + + def try_get(self) -> Optional[MpmcResponseVariant]: ... + + def get(self) -> MpmcResponseVariant: ... + +class MpmcQueue: + def __init__(self, path: str, ttl: int, max_inflight_ops: int = 1_000): ... + + def add(self, items: list[bytes], no_gil: bool = True) -> MpmcResponse: ... + + @property + def inflight_ops(self) -> int: ... + + def next(self, label: str, start_position: StartPosition, max_elements = 1, no_gil: bool = True) -> MpmcResponse: ... + + @property + def disk_size(self) -> MpmcResponse: ... + + @property + def len(self) -> MpmcResponse: ... + + @property + def labels(self) -> MpmcResponse: ... + + def remove_label(self, label: str) -> MpmcResponse: ... diff --git a/queue_py/python/rocksq/rocksq.pyi b/queue_py/python/rocksq/rocksq.pyi index f20501b..ee2c9e5 100644 --- a/queue_py/python/rocksq/rocksq.pyi +++ b/queue_py/python/rocksq/rocksq.pyi @@ -1,4 +1,11 @@ +from enum import Enum + def version() -> str: ... def remove_queue(queue_name: str): ... +def remove_mpmc_queue(queue_name: str): ... + +class StartPosition(Enum): + Oldest=0 + Newest=1 diff --git a/queue_py/src/blocking.rs b/queue_py/src/blocking.rs index b51a75a..a839365 100644 --- a/queue_py/src/blocking.rs +++ b/queue_py/src/blocking.rs @@ -1,7 +1,10 @@ +use crate::StartPosition; use pyo3::exceptions::PyRuntimeError; use pyo3::prelude::*; use pyo3::types::PyBytes; +use queue_rs::mpmc; use rocksdb::Options; +use std::time::Duration; /// A persistent queue with a fixed capacity. This is a blocking implementation. /// @@ -176,3 +179,232 @@ impl PersistentQueueWithCapacity { self.0.len() } } + +/// A persistent queue with a ttl that supports multiple consumers marked with labels. This is a +/// blocking implementation. +/// +/// Parameters +/// ---------- +/// path : str +/// The path to the queue. +/// ttl : int +/// The amount of seconds after which the element in the queue will be removed. Ttl is non-strict +/// meaning that it is guaranteed that the element inserted will remain in the queue for at least +/// ttl amount of time and the queue will make efforts to remove the element as soon as possible +/// after ttl seconds of its insertion. +/// +/// Raises +/// ------ +/// PyRuntimeError +/// If the queue could not be created. +/// +#[pyclass] +pub struct MpmcQueue(queue_rs::blocking::MpmcQueue); + +#[pymethods] +impl MpmcQueue { + #[new] + #[pyo3(signature=(path, ttl))] + fn new(path: &str, ttl: u32) -> PyResult { + let queue = queue_rs::blocking::MpmcQueue::new(path, Duration::from_secs(ttl as u64)) + .map_err(|e| PyRuntimeError::new_err(format!("Failed to create mpmc queue: {}", e)))?; + Ok(Self(queue)) + } + + /// Adds items to the queue. + /// + /// **GIL**: the method can optionally be called without the GIL. + /// + /// Parameters + /// ---------- + /// items : list of bytes + /// The items to add to the queue. + /// no_gil : bool + /// If True, the method will be called without the GIL. Default is ``True``. + /// + /// Raises + /// ------ + /// PyRuntimeError + /// If the method fails. + /// + /// Returns + /// ------- + /// None + /// + #[pyo3(signature = (items, no_gil = true))] + fn add(&self, items: Vec<&PyBytes>, no_gil: bool) -> PyResult<()> { + let data = items.iter().map(|e| e.as_bytes()).collect::>(); + Python::with_gil(|py| { + let f = || { + self.0 + .add(&data) + .map_err(|e| PyRuntimeError::new_err(format!("Failed to add items: {}", e))) + }; + + if no_gil { + py.allow_threads(f) + } else { + f() + } + }) + } + + /// Retrieves items from the queue. + /// + /// **GIL**: the method can optionally be called without the GIL. + /// + /// Parameters + /// ---------- + /// label: str + /// The consumer label that determines the start position in the queue to retrieve elements. + /// If the label does not exist the start position is determined by ``option` parameter. If + /// the label exists the start position is the next element after the last call of this + /// method. If some elements are expired between the last and this call next non-expired + /// elements will be retrieved. + /// start_position: StartPosition + /// The option that determines the start position in the queue to retrieve elements if the + /// consumer label does not exist. + /// max_elements : int + /// The maximum number of elements to retrieve. Default is ``1``. + /// no_gil : bool + /// If True, the method will be called without the GIL. Default is ``True``. + /// + /// Raises + /// ------ + /// PyRuntimeError + /// If the method fails. + /// + /// Returns + /// ------- + /// items : list of bytes + /// The items retrieved from the queue. + /// + #[pyo3(signature = (label, start_position, max_elements = 1, no_gil = true))] + fn next( + &self, + label: &str, + start_position: StartPosition, + max_elements: usize, + no_gil: bool, + ) -> PyResult> { + Python::with_gil(|py| { + let start_position = match start_position { + StartPosition::Oldest => mpmc::StartPosition::Oldest, + StartPosition::Newest => mpmc::StartPosition::Newest, + }; + if no_gil { + py.allow_threads(|| self.0.next(max_elements, label, start_position)) + } else { + self.0.next(max_elements, label, start_position) + } + .map(|results| { + results + .into_iter() + .map(|r| { + PyBytes::new_with(py, r.len(), |b: &mut [u8]| { + b.copy_from_slice(&r); + Ok(()) + }) + .map(PyObject::from) + }) + .collect::>>() + }) + .map_err(|_| PyRuntimeError::new_err("Failed to retrieve items")) + })? + } + + /// Checks if the queue is empty. + /// + /// Returns + /// ------- + /// bool + /// ``True`` if the queue is empty, ``False`` otherwise. + /// + #[getter] + fn is_empty(&self) -> bool { + self.0.is_empty() + } + + /// Returns the disk size of the queue in bytes. + /// + /// Returns + /// ------- + /// size : int + /// + /// Raises + /// ------ + /// PyRuntimeError + /// If the method fails. + /// + #[getter] + fn disk_size(&self) -> PyResult { + Python::with_gil(|py| { + py.allow_threads(|| { + self.0.disk_size().map_err(|e| { + PyRuntimeError::new_err(format!("Failed to get queue size: {}", e)) + }) + }) + }) + } + + /// Returns the number of elements in the queue. + /// + /// Returns + /// ------- + /// int + /// The number of elements in the queue. + /// + #[getter] + fn len(&self) -> usize { + self.0.len() + } + + /// Returns the consumer labels. + /// + /// Returns + /// ------- + /// labels: list of str + /// The consumer labels. + /// + #[getter] + fn labels(&self) -> Vec { + self.0.get_labels() + } + + /// Removes the consumer label from the queue. + /// + /// **GIL**: the method can optionally be called without the GIL. + /// + /// Parameters + /// ---------- + /// label : str + /// The consumer label to remove. + /// no_gil : bool + /// If True, the method will be called without the GIL. Default is ``True``. + /// + /// Raises + /// ------ + /// PyRuntimeError + /// If the method fails. + /// + /// Returns + /// ------- + /// None + /// + #[pyo3(signature = (label, no_gil = true))] + fn remove_label(&self, label: &str, no_gil: bool) -> PyResult { + Python::with_gil(|py| { + let f = || { + self.0 + .remove_label(label) + .map_err(|e| PyRuntimeError::new_err(format!("Failed to remove label: {}", e))) + }; + + if no_gil { + py.allow_threads(f) + } else { + f() + } + }) + } +} diff --git a/queue_py/src/lib.rs b/queue_py/src/lib.rs index 5c25170..566de17 100644 --- a/queue_py/src/lib.rs +++ b/queue_py/src/lib.rs @@ -18,7 +18,7 @@ pub fn version() -> String { queue_rs::version().to_string() } -/// Removes the queue at the given path. The queue must be closed. +/// Removes ``PersistentQueueWithCapacity`` at the given path. The queue must be closed. /// /// Parameters /// ---------- @@ -36,9 +36,35 @@ fn remove_queue(path: &str) -> PyResult<()> { .map_err(|e| PyRuntimeError::new_err(format!("Failed to remove persistent queue: {}", e))) } +/// Removes ``MpmcQueue`` at the given path. The queue must be closed. +/// +/// Parameters +/// ---------- +/// path : str +/// The path to the queue to remove. +/// +/// Raises +/// ------ +/// PyRuntimeError +/// If the queue could not be removed. +/// +#[pyfunction] +fn remove_mpmc_queue(path: &str) -> PyResult<()> { + queue_rs::mpmc::MpmcQueue::remove_db(path) + .map_err(|e| PyRuntimeError::new_err(format!("Failed to remove mpmc queue: {}", e))) +} + +#[pyclass] +#[derive(PartialEq, Copy, Clone)] +enum StartPosition { + Oldest, + Newest, +} + #[pymodule] fn rocksq_blocking(_: Python, m: &PyModule) -> PyResult<()> { m.add_class::()?; + m.add_class::()?; Ok(()) } @@ -47,6 +73,11 @@ fn rocksq_nonblocking(_: Python, m: &PyModule) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; + + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + Ok(()) } @@ -54,10 +85,13 @@ fn rocksq_nonblocking(_: Python, m: &PyModule) -> PyResult<()> { fn rocksq(py: Python, m: &PyModule) -> PyResult<()> { m.add_function(wrap_pyfunction!(version, m)?)?; m.add_function(wrap_pyfunction!(remove_queue, m)?)?; + m.add_function(wrap_pyfunction!(remove_mpmc_queue, m)?)?; m.add_wrapped(wrap_pymodule!(rocksq_blocking))?; m.add_wrapped(wrap_pymodule!(rocksq_nonblocking))?; + m.add_class::()?; + let sys = PyModule::import(py, "sys")?; let sys_modules: &PyDict = sys.getattr("modules")?.downcast()?; diff --git a/queue_py/src/nonblocking.rs b/queue_py/src/nonblocking.rs index c81f753..712fe1d 100644 --- a/queue_py/src/nonblocking.rs +++ b/queue_py/src/nonblocking.rs @@ -1,10 +1,14 @@ +use crate::StartPosition; use pyo3::exceptions::PyRuntimeError; use pyo3::prelude::*; use pyo3::types::PyBytes; +use queue_rs::mpmc; use rocksdb::Options; +use std::time::Duration; -/// A response variant containing the actual data for push, pop, size and length operations. -/// The object is created only by the library, there is no public constructor. +/// A response variant containing the actual data for push, pop, size and length operations of +/// ``PersistentQueueWithCapacity``. The object is created only by the library, there is no +/// public constructor. /// #[pyclass] pub struct ResponseVariant(queue_rs::nonblocking::ResponseVariant); @@ -344,3 +348,446 @@ impl PersistentQueueWithCapacity { .map_err(|e| PyRuntimeError::new_err(format!("Failed to get length: {}", e))) } } + +/// A response variant containing the actual data for add, next, size and length operations of +/// ``MpmcQueue``. The object is created only by the library, there is no public constructor. +/// +#[pyclass] +pub struct MpmcResponseVariant(queue_rs::nonblocking::MpmcResponseVariant); + +#[pymethods] +impl MpmcResponseVariant { + /// Returns the data for the ``next()`` operation. + /// + /// Raises + /// ------ + /// PyRuntimeError + /// If the method fails. + /// + /// Returns + /// ------- + /// list of bytes + /// The data for the ``next()`` operation if the operation was successful, + /// ``None`` + /// if the future doesn't represent the ``next()`` operation. + /// + #[getter] + fn data(&self) -> PyResult>> { + Python::with_gil(|py| match &self.0 { + queue_rs::nonblocking::MpmcResponseVariant::Next(data) => Ok(Some( + data.as_ref() + .map(|results| { + results + .iter() + .map(|r| { + PyBytes::new_with(py, r.len(), |b: &mut [u8]| { + b.copy_from_slice(r); + Ok(()) + }) + .map(PyObject::from) + }) + .collect::>>() + }) + .map_err(|e| { + PyRuntimeError::new_err(format!("Failed to get response: {}", e)) + })??, + )), + _ => Ok(None), + }) + } + + /// Returns the data for the ``get_labels()`` operation. + /// + /// Raises + /// ------ + /// PyRuntimeError + /// If the method fails. + /// + /// Returns + /// ------- + /// list of str + /// The data for the ``get_labels()`` operation if the operation was successful, + /// ``None`` + /// if the future doesn't represent the ``get_labels()`` operation. + /// + #[getter] + fn labels(&self) -> Option> { + match &self.0 { + queue_rs::nonblocking::MpmcResponseVariant::GetLabels(data) => Some(data.to_vec()), + _ => None, + } + } + + /// Returns the result of ``remove_label()`` operation. + /// + /// Raises + /// ------ + /// PyRuntimeError + /// If the method fails. + /// + /// Returns + /// ------- + /// bool + /// ``True`` if the consumer label existed, ``False`` otherwise. + /// ``None`` + /// if the future doesn't represent the ``size()`` operation. + /// + #[getter] + fn removed_label(&self) -> PyResult> { + match &self.0 { + queue_rs::nonblocking::MpmcResponseVariant::RemoveLabel(data) => Ok(data + .as_ref() + .map(|r| Some(*r)) + .map_err(|e| PyRuntimeError::new_err(format!("Failed to get response: {}", e)))?), + _ => Ok(None), + } + } + + /// Returns the length of the queue. + /// + /// Raises + /// ------ + /// PyRuntimeError + /// If the method fails. + /// + /// Returns + /// ------- + /// ``int`` + /// The length of the queue if the operation was successful, + /// ``None`` + /// if the future doesn't represent the ``length()`` operation. + /// + #[getter] + fn len(&self) -> Option { + match &self.0 { + queue_rs::nonblocking::MpmcResponseVariant::Length(data) => Some(*data), + _ => None, + } + } + + /// Returns the size of the queue. + /// + /// Raises + /// ------ + /// PyRuntimeError + /// If the method fails. + /// + /// Returns + /// ------- + /// ``int`` + /// The size of the queue if the operation was successful, + /// ``None`` + /// if the future doesn't represent the ``size()`` operation. + /// + #[getter] + fn size(&self) -> PyResult> { + match &self.0 { + queue_rs::nonblocking::MpmcResponseVariant::Size(data) => Ok(data + .as_ref() + .map(|r| Some(*r)) + .map_err(|e| PyRuntimeError::new_err(format!("Failed to get response: {}", e)))?), + _ => Ok(None), + } + } +} + +#[pyclass] +pub struct MpmcResponse(queue_rs::nonblocking::MpmcResponse); + +#[pymethods] +impl MpmcResponse { + /// Checks if the response is ready. + /// + /// Returns + /// ------- + /// ``bool`` + /// ``True`` if the response is ready, ``False`` otherwise. + /// + #[getter] + fn is_ready(&self) -> bool { + self.0.is_ready() + } + + /// Returns the response if it is ready. + /// + /// Raises + /// ------ + /// PyRuntimeError + /// If the method fails. + /// + /// Returns + /// ------- + /// :py:class:`MpmcResponseVariant` + /// The response if it is ready, + /// ``None`` + /// otherwise. + /// + fn try_get(&self) -> PyResult> { + self.0 + .try_get() + .map(|rvo| rvo.map(MpmcResponseVariant)) + .map_err(|e| PyRuntimeError::new_err(format!("Failed to get response: {}", e))) + } + + /// Returns the response in a blocking way. + /// + /// **GIL**: the method releases the GIL + /// + /// Raises + /// ------ + /// PyRuntimeError + /// If the method fails. + /// + /// Returns + /// ------- + /// :py:class:`MpmcResponseVariant` + /// The response when it is ready. + /// + fn get(&self) -> PyResult { + Python::with_gil(|py| { + py.allow_threads(|| { + self.0 + .get() + .map(MpmcResponseVariant) + .map_err(|e| PyRuntimeError::new_err(format!("Failed to get response: {}", e))) + }) + }) + } +} + +/// A persistent queue with a ttl that supports multiple consumers marked with labels. This is a +/// non-blocking implementation. All methods return the future-like object :py:class:`MpmcResponse` +/// which must be used to get the actual response. +/// +/// Parameters +/// ---------- +/// path : str +/// The path to the queue. +/// ttl : int +/// The amount of seconds after which the element in the queue will be removed. Ttl is non-strict +/// meaning that it is guaranteed that the element inserted will remain in the queue for at least +/// ttl amount of time and the queue will make efforts to remove the element as soon as possible +/// after ttl seconds of its insertion. +/// max_inflight_ops : int +/// The maximum number of inflight operations. If the number of inflight operations reached its limit, +/// further ops are blocked until the capacity is available. Default to ``1_000``. +/// +/// Raises +/// ------ +/// PyRuntimeError +/// If the queue could not be created. +/// +#[pyclass] +pub struct MpmcQueue(queue_rs::nonblocking::MpmcQueue); + +#[pymethods] +impl MpmcQueue { + #[new] + #[pyo3(signature=(path, ttl, max_inflight_ops = 1_000))] + fn new(path: &str, ttl: u32, max_inflight_ops: usize) -> PyResult { + let q = queue_rs::nonblocking::MpmcQueue::new( + path, + Duration::from_secs(ttl as u64), + max_inflight_ops, + ) + .map_err(|e| PyRuntimeError::new_err(format!("Failed to create mpmc queue: {}", e)))?; + Ok(Self(q)) + } + + /// Adds items to the queue. + /// + /// **GIL**: the method can optionally be called without the GIL. + /// + /// Parameters + /// ---------- + /// items : list of bytes + /// The items to add to the queue. + /// no_gil : bool + /// If True, the method will be called without the GIL. Default is ``True``. + /// + /// Raises + /// ------ + /// PyRuntimeError + /// If the method fails. + /// + /// Returns + /// ------- + /// :py:class:`MpmcResponse` + /// The future-like object which must be used to get the actual response. For the add operation, + /// the response object is only useful to call for `is_ready()`. + /// + #[pyo3(signature = (items, no_gil = true))] + fn add(&self, items: Vec<&PyBytes>, no_gil: bool) -> PyResult { + let data = items.iter().map(|e| e.as_bytes()).collect::>(); + Python::with_gil(|py| { + let f = || { + self.0 + .add(&data) + .map_err(|e| PyRuntimeError::new_err(format!("Failed to add items: {}", e))) + }; + + if no_gil { + py.allow_threads(f) + } else { + f() + } + }) + .map(MpmcResponse) + } + + #[getter] + pub fn inflight_ops(&self) -> PyResult { + self.0 + .inflight_ops() + .map_err(|e| PyRuntimeError::new_err(format!("Failed to get inflight ops: {}", e))) + } + + /// Retrieves items from the queue. + /// + /// **GIL**: the method can optionally be called without the GIL. + /// + /// Parameters + /// ---------- + /// label: str + /// The consumer label that determines the start position in the queue to retrieve elements. + /// If the label does not exist the start position is determined by ``option` parameter. If + /// the label exists the start position is the next element after the last call of this + /// method. If some elements are expired between the last and this call next non-expired + /// elements will be retrieved. + /// start_position: StartPosition + /// The option that determines the start position in the queue to retrieve elements if the + /// consumer label does not exist. + /// max_elements : int + /// The maximum number of elements to retrieve. Default is ``1``. + /// no_gil : bool + /// If True, the method will be called without the GIL. Default is ``True``. + /// + /// Raises + /// ------ + /// PyRuntimeError + /// If the method fails. + /// + /// Returns + /// ------- + /// :py:class:`MpmcResponse` + /// The future-like object which must be used to get the actual response. For the add operation, + /// the response object is useful to call for ``is_ready()``, ``try_get()`` and ``get()``. + /// + #[pyo3(signature = (label, start_position, max_elements = 1, no_gil = true))] + fn next( + &self, + label: &str, + start_position: StartPosition, + max_elements: usize, + no_gil: bool, + ) -> PyResult { + Python::with_gil(|py| { + let start_position = match start_position { + StartPosition::Oldest => mpmc::StartPosition::Oldest, + StartPosition::Newest => mpmc::StartPosition::Newest, + }; + if no_gil { + py.allow_threads(|| self.0.next(max_elements, label, start_position)) + } else { + self.0.next(max_elements, label, start_position) + } + .map(MpmcResponse) + .map_err(|e| PyRuntimeError::new_err(format!("Failed to pop items: {}", e))) + }) + } + + /// Returns the disk size of the queue. + /// + /// Raises + /// ------ + /// PyRuntimeError + /// + /// Returns + /// ------- + /// :py:class:`MpmcResponse` + /// The future-like object which must be used to get the actual response. For the size operation, + /// the response object is useful to call for ``is_ready()``, ``try_get()`` and ``get()``. + /// + #[getter] + pub fn disk_size(&self) -> PyResult { + self.0 + .disk_size() + .map(MpmcResponse) + .map_err(|e| PyRuntimeError::new_err(format!("Failed to get size: {}", e))) + } + + /// Returns the length of the queue. + /// + /// Raises + /// ------ + /// PyRuntimeError + /// If the method fails. + /// + /// Returns + /// ------- + /// :py:class:`MpmcResponse` + /// The future-like object which must be used to get the actual response. For the length operation, + /// the response object is useful to call for ``is_ready()``, ``try_get()`` and ``get()``. + /// + #[getter] + pub fn len(&self) -> PyResult { + self.0 + .len() + .map(MpmcResponse) + .map_err(|e| PyRuntimeError::new_err(format!("Failed to get length: {}", e))) + } + + /// Returns the consumer labels. + /// + /// Returns + /// ------- + /// labels: list of str + /// The consumer labels. + /// + #[getter] + fn labels(&self) -> PyResult { + self.0 + .get_labels() + .map(MpmcResponse) + .map_err(|e| PyRuntimeError::new_err(format!("Failed to get labels: {}", e))) + } + + /// Remove the consumer label from the queue. + /// + /// **GIL**: the method can optionally be called without the GIL. + /// + /// Parameters + /// ---------- + /// label : str + /// The consumer label to remove. + /// no_gil : bool + /// If True, the method will be called without the GIL. Default is ``True``. + /// + /// Raises + /// ------ + /// PyRuntimeError + /// If the method fails. + /// + /// Returns + /// ------- + /// bool + /// ``True`` if the consumer label existed, ``False`` otherwise. + /// + #[pyo3(signature = (label, no_gil = true))] + fn remove_label(&self, label: &str, no_gil: bool) -> PyResult { + Python::with_gil(|py| { + let f = || { + self.0 + .remove_label(label) + .map_err(|e| PyRuntimeError::new_err(format!("Failed to remove label: {}", e))) + }; + + if no_gil { + py.allow_threads(f) + } else { + f() + } + }) + .map(MpmcResponse) + } +} diff --git a/queue_rs/Cargo.toml b/queue_rs/Cargo.toml index 4f63580..09f090f 100644 --- a/queue_rs/Cargo.toml +++ b/queue_rs/Cargo.toml @@ -21,7 +21,12 @@ crate-type = ["cdylib", "lib"] anyhow = "1" parking_lot = "0.12" crossbeam-channel = "0.5.8" +bincode = "2.0.0-rc.3" +chrono = "0.4.38" [dependencies.rocksdb] version = "0.22" default-features = false + +[dev-dependencies] +tempfile = "3.12.0" diff --git a/queue_rs/src/blocking.rs b/queue_rs/src/blocking.rs index 01ded4b..5c5d1c8 100644 --- a/queue_rs/src/blocking.rs +++ b/queue_rs/src/blocking.rs @@ -1,7 +1,10 @@ +use crate::mpmc; +use crate::mpmc::StartPosition; use anyhow::Result; use parking_lot::Mutex; use rocksdb::Options; use std::sync::Arc; +use std::time::Duration; #[derive(Clone)] pub struct PersistentQueueWithCapacity(Arc>); @@ -41,3 +44,50 @@ impl PersistentQueueWithCapacity { crate::PersistentQueueWithCapacity::remove_db(path) } } + +#[derive(Clone)] +pub struct MpmcQueue(Arc>); + +impl MpmcQueue { + pub fn new(path: &str, ttl: Duration) -> Result { + let inner = mpmc::MpmcQueue::new(path, ttl)?; + Ok(Self(Arc::new(Mutex::new(inner)))) + } + + pub fn remove_db(path: &str) -> Result<()> { + mpmc::MpmcQueue::remove_db(path) + } + + pub fn disk_size(&self) -> Result { + self.0.lock().disk_size() + } + + pub fn len(&self) -> usize { + self.0.lock().len() + } + + pub fn is_empty(&self) -> bool { + self.0.lock().is_empty() + } + + pub fn add(&self, values: &[&[u8]]) -> Result<()> { + self.0.lock().add(values) + } + + pub fn next( + &self, + max_elts: usize, + label: &str, + start_position: StartPosition, + ) -> Result>> { + self.0.lock().next(max_elts, label, start_position) + } + + pub fn get_labels(&self) -> Vec { + self.0.lock().get_labels() + } + + pub fn remove_label(&self, label: &str) -> Result { + self.0.lock().remove_label(label) + } +} diff --git a/queue_rs/src/lib.rs b/queue_rs/src/lib.rs index ca848e0..be67694 100644 --- a/queue_rs/src/lib.rs +++ b/queue_rs/src/lib.rs @@ -1,7 +1,10 @@ pub mod blocking; mod fs; +pub mod mpmc; pub mod nonblocking; +mod utilities; +use crate::utilities::{index_to_key, next_index, u64_from_byte_vec}; use anyhow::{anyhow, Result}; use rocksdb::{Options, DB}; use std::cmp::Ordering; @@ -27,7 +30,7 @@ const READ_INDEX_CELL: u64 = u64::MAX - 1; const SPACE_STAT_CELL: u64 = u64::MAX - 2; #[cfg(test)] -const MAX_ALLOWED_INDEX: u64 = 4; +const MAX_ALLOWED_INDEX: u64 = 6; #[cfg(not(test))] const MAX_ALLOWED_INDEX: u64 = u64::MAX - 100; @@ -48,37 +51,25 @@ impl PersistentQueueWithCapacity { let db = DB::open(&db_opts, path)?; - let write_index_opt = db.get(Self::index_to_key(WRITE_INDEX_CELL))?; + let write_index_opt = db.get(index_to_key(WRITE_INDEX_CELL))?; let write_index = match write_index_opt { - Some(v) => { - let mut buf = [0u8; U64_BYTE_LEN]; - buf.copy_from_slice(&v); - u64::from_le_bytes(buf) - } + Some(v) => u64_from_byte_vec(&v), None => 0u64, }; - let read_index_opt = db.get(Self::index_to_key(READ_INDEX_CELL))?; + let read_index_opt = db.get(index_to_key(READ_INDEX_CELL))?; let read_index = match read_index_opt { - Some(v) => { - let mut buf = [0u8; U64_BYTE_LEN]; - buf.copy_from_slice(&v); - u64::from_le_bytes(buf) - } + Some(v) => u64_from_byte_vec(&v), None => 0u64, }; - let space_stat_opt = db.get(Self::index_to_key(SPACE_STAT_CELL))?; + let space_stat_opt = db.get(index_to_key(SPACE_STAT_CELL))?; let space_stat = match space_stat_opt { - Some(v) => { - let mut buf = [0u8; U64_BYTE_LEN]; - buf.copy_from_slice(&v); - u64::from_le_bytes(buf) - } + Some(v) => u64_from_byte_vec(&v), None => 0u64, }; - let empty = db.get(Self::index_to_key(read_index))?.is_none(); + let empty = db.get(index_to_key(read_index))?.is_none(); Ok(Self { db, @@ -91,10 +82,6 @@ impl PersistentQueueWithCapacity { }) } - fn index_to_key(index: u64) -> [u8; U64_BYTE_LEN] { - index.to_le_bytes() - } - pub fn remove_db(path: &str) -> Result<()> { Ok(DB::destroy(&Options::default(), path)?) } @@ -135,24 +122,15 @@ impl PersistentQueueWithCapacity { let mut write_index = self.write_index; for value in values { - batch.put(Self::index_to_key(write_index), value); - write_index += 1; - if write_index == MAX_ALLOWED_INDEX { - write_index = 0; - } + batch.put(index_to_key(write_index), value); + write_index = next_index(write_index); } - batch.put( - Self::index_to_key(WRITE_INDEX_CELL), - write_index.to_le_bytes(), - ); + batch.put(index_to_key(WRITE_INDEX_CELL), write_index.to_le_bytes()); let space_stat = self.space_stat + values.iter().map(|v| v.len() as u64).sum::(); - batch.put( - Self::index_to_key(SPACE_STAT_CELL), - space_stat.to_le_bytes(), - ); + batch.put(index_to_key(SPACE_STAT_CELL), space_stat.to_le_bytes()); self.db.write(batch)?; @@ -168,15 +146,12 @@ impl PersistentQueueWithCapacity { let mut batch = rocksdb::WriteBatch::default(); let mut read_index = self.read_index; loop { - let key = Self::index_to_key(read_index); + let key = index_to_key(read_index); let value = self.db.get(key)?; if let Some(v) = value { batch.delete(key); res.push(v); - read_index += 1; - if read_index == MAX_ALLOWED_INDEX { - read_index = 0; - } + read_index = next_index(read_index); max_elts -= 1; } else { break; @@ -191,14 +166,8 @@ impl PersistentQueueWithCapacity { if !res.is_empty() { let empty = read_index == self.write_index; let space_stat = self.space_stat - res.iter().map(|v| v.len() as u64).sum::(); - batch.put( - Self::index_to_key(SPACE_STAT_CELL), - space_stat.to_le_bytes(), - ); - batch.put( - Self::index_to_key(READ_INDEX_CELL), - read_index.to_le_bytes(), - ); + batch.put(index_to_key(SPACE_STAT_CELL), space_stat.to_le_bytes()); + batch.put(index_to_key(READ_INDEX_CELL), read_index.to_le_bytes()); self.db.write(batch)?; self.read_index = read_index; @@ -226,20 +195,26 @@ mod tests { .unwrap(); db.push(&[&[1, 2, 3]]).unwrap(); db.push(&[&[4, 5, 6]]).unwrap(); - assert_eq!(db.len(), 2); - assert_eq!(db.payload_size(), 6); + db.push(&[&[7, 8, 9]]).unwrap(); + db.push(&[&[10, 11, 12]]).unwrap(); + assert!(!db.is_empty()); + assert_eq!(db.len(), 4); + assert_eq!(db.payload_size(), 12); assert!(matches!(db.pop(1), Ok(v ) if v == vec![vec![1, 2, 3]])); assert!(matches!(db.pop(1), Ok(v) if v == vec![vec![4, 5, 6]])); + assert!(matches!(db.pop(1), Ok(v) if v == vec![vec![7, 8, 9]])); + assert!(matches!(db.pop(1), Ok(v) if v == vec![vec![10, 11, 12]])); assert!(db.is_empty()); + assert_eq!(db.len(), 0); db.push(&[&[1, 2, 3]]).unwrap(); db.push(&[&[4, 5, 6]]).unwrap(); db.push(&[&[7, 8, 9]]).unwrap(); assert_eq!(db.len(), 3); - assert_eq!(db.read_index, 2); + assert_eq!(db.read_index, 4); assert_eq!(db.write_index, 1); assert!(matches!(db.pop(1), Ok(v) if v == vec![vec![1, 2, 3]])); assert_eq!(db.len(), 2); - assert_eq!(db.read_index, 3); + assert_eq!(db.read_index, 5); assert_eq!(db.write_index, 1); assert!(matches!(db.pop(1), Ok(v) if v == vec![vec![4, 5, 6]])); assert_eq!(db.len(), 1); @@ -248,6 +223,8 @@ mod tests { let data = db.pop(1).unwrap(); assert!(db.is_empty()); assert_eq!(db.len(), 0); + assert_eq!(db.read_index, 1); + assert_eq!(db.write_index, 1); assert_eq!(data, vec![vec![7, 8, 9]]); } PersistentQueueWithCapacity::remove_db(&path).unwrap(); diff --git a/queue_rs/src/mpmc.rs b/queue_rs/src/mpmc.rs new file mode 100644 index 0000000..abdcb67 --- /dev/null +++ b/queue_rs/src/mpmc.rs @@ -0,0 +1,1309 @@ +use std::cmp::Ordering; +use std::collections::HashMap; +use std::time::Duration; + +use anyhow::{Error, Result}; +use bincode::config::Configuration; +use bincode::{Decode, Encode}; +use rocksdb::{ColumnFamilyDescriptor, Direction, IteratorMode, Options, SliceTransform, DB}; + +use crate::utilities::{ + current_timestamp, index_to_key, key_to_index, next_index, previous_index, u64_from_byte_vec, +}; +use crate::{fs, MAX_ALLOWED_INDEX}; + +const DATA_CF: &str = "data"; +const SYSTEM_CF: &str = "system"; +const READER_CF: &str = "reader"; +const START_INDEX_KEY: u64 = u64::MAX; +const WRITE_INDEX_KEY: u64 = u64::MAX - 1; +const WRITE_TIMESTAMP_KEY: u64 = u64::MAX - 2; + +#[derive(Clone, Copy)] +pub enum StartPosition { + Oldest, + Newest, +} + +#[derive(Encode, Decode, PartialEq, Debug, Clone)] +struct Reader { + index: u64, + end_timestamp: Option, + expired: bool, +} + +impl Reader { + fn new(index: u64, end_timestamp: Option, expired: bool) -> Self { + Self { + index, + end_timestamp, + expired, + } + } +} + +pub struct MpmcQueue { + db: DB, + path: String, + empty: bool, + start_index: u64, + write_index: u64, + write_timestamp: u64, + read_indices: HashMap, + configuration: Configuration, +} + +impl MpmcQueue { + pub fn new(path: &str, ttl: Duration) -> Result { + let configuration = bincode::config::standard(); + + let mut cf_opts = Options::default(); + cf_opts.create_if_missing(true); + cf_opts.set_prefix_extractor(SliceTransform::create_fixed_prefix(crate::U64_BYTE_LEN)); + let data_cf = ColumnFamilyDescriptor::new(DATA_CF, cf_opts); + + let mut cf_opts = Options::default(); + cf_opts.set_prefix_extractor(SliceTransform::create_fixed_prefix(crate::U64_BYTE_LEN)); + let system_cf = ColumnFamilyDescriptor::new(SYSTEM_CF, cf_opts); + + let reader_cf = ColumnFamilyDescriptor::new(READER_CF, Options::default()); + + let mut db_opts = Options::default(); + db_opts.create_missing_column_families(true); + db_opts.create_if_missing(true); + + let db = DB::open_cf_descriptors_with_ttl( + &db_opts, + path, + vec![system_cf, data_cf, reader_cf], + ttl, + )?; + + let system_cf = db.cf_handle(SYSTEM_CF).unwrap(); + let start_index_opt = db.get_cf(&system_cf, index_to_key(START_INDEX_KEY))?; + let start_index = match start_index_opt { + Some(v) => u64_from_byte_vec(&v), + None => 0u64, + }; + let write_index_opt = db.get_cf(&system_cf, index_to_key(WRITE_INDEX_KEY))?; + let write_index = match write_index_opt { + Some(v) => u64_from_byte_vec(&v), + None => 0u64, + }; + let write_timestamp_opt = db.get_cf(&system_cf, index_to_key(WRITE_TIMESTAMP_KEY))?; + let write_timestamp = match write_timestamp_opt { + Some(v) => u64_from_byte_vec(&v), + None => current_timestamp(), + }; + + let data_cf = db.cf_handle(DATA_CF).unwrap(); + let mut empty = true; + let iterator = db.iterator_cf(data_cf, IteratorMode::Start); + // the error with iterator.next() - error[E0505]: cannot move out of `db` because it is borrowed + #[allow(clippy::never_loop)] + for item in iterator { + item?; + empty = false; + break; + } + + let mut read_indices = HashMap::new(); + let reader_cf = db.cf_handle(READER_CF).unwrap(); + let iterator = db.iterator_cf(reader_cf, IteratorMode::Start); + for item in iterator { + let (key, value) = item?; + + let key = String::from_utf8(Vec::from(key)).map_err(Error::from)?; + let value = bincode::decode_from_slice(&value, configuration)?.0; + + read_indices.insert(key, value); + } + + Ok(Self { + db, + path: path.to_string(), + empty, + start_index, + write_index, + write_timestamp, + read_indices, + configuration, + }) + } + + pub fn remove_db(path: &str) -> Result<()> { + Ok(DB::destroy(&Options::default(), path)?) + } + + pub fn disk_size(&self) -> Result { + Ok(fs::dir_size(&self.path)?) + } + + pub fn len(&self) -> usize { + if self.empty { + 0 + } else { + (match self.write_index.cmp(&self.start_index) { + Ordering::Less => MAX_ALLOWED_INDEX - self.start_index + self.write_index, + Ordering::Equal => MAX_ALLOWED_INDEX, + Ordering::Greater => self.write_index - self.start_index, + }) as usize + } + } + + pub fn is_empty(&self) -> bool { + self.empty + } + + pub fn add(&mut self, values: &[&[u8]]) -> Result<()> { + if values.is_empty() { + return Ok(()); + } + self.actualize_indices()?; + if self.len() + values.len() > MAX_ALLOWED_INDEX as usize { + return Err(anyhow::anyhow!("Queue is full")); + } + + let data_cf = self.db.cf_handle(DATA_CF).unwrap(); + let system_cf = self.db.cf_handle(SYSTEM_CF).unwrap(); + let mut batch = rocksdb::WriteBatch::default(); + let mut write_index = self.write_index; + + for value in values { + batch.put_cf(data_cf, index_to_key(write_index), value); + write_index = next_index(write_index); + } + + batch.put_cf( + system_cf, + index_to_key(WRITE_INDEX_KEY), + write_index.to_le_bytes(), + ); + let write_timestamp = current_timestamp(); + batch.put_cf( + system_cf, + index_to_key(WRITE_TIMESTAMP_KEY), + write_timestamp.to_le_bytes(), + ); + + self.db.write(batch)?; + self.write_index = write_index; + self.write_timestamp = write_timestamp; + self.empty = false; + + Ok(()) + } + + pub fn next( + &mut self, + max_elts: usize, + label: &str, + start_position: StartPosition, + ) -> Result>> { + let mut res = Vec::with_capacity(max_elts); + + self.actualize_indices()?; + let label = label.to_string(); + let data_cf = self.db.cf_handle(DATA_CF).unwrap(); + let reader_cf = self.db.cf_handle(READER_CF).unwrap(); + let mut reader = match self.read_indices.get(&label) { + Some(e) => e.clone(), + None => { + let index = match start_position { + StartPosition::Oldest => self.start_index, + StartPosition::Newest => { + if self.empty { + self.write_index + } else { + previous_index(self.write_index) + } + } + }; + Reader::new(index, None, false) + } + }; + + let mut end = match reader.end_timestamp { + None => reader.index == self.write_index && self.empty, + Some(timestamp) => timestamp == self.write_timestamp, + }; + + while !end && res.len() < max_elts { + let value = self.db.get_cf(data_cf, index_to_key(reader.index))?; + if let Some(v) = value { + res.push(v); + } else { + res.clear(); + reader.expired = true; + } + reader.index = next_index(reader.index); + end = reader.index == self.write_index; + } + reader.end_timestamp = if end { + Some(self.write_timestamp) + } else { + None + }; + // for future to let the user know + let _expired = reader.expired; + reader.expired = false; + if !self.read_indices.get(&label).is_some_and(|e| *e == reader) { + self.db.put_cf( + reader_cf, + label.as_bytes(), + bincode::encode_to_vec(reader.clone(), self.configuration)?, + )?; + + self.read_indices.insert(label, reader); + } + + Ok(res) + } + + pub fn get_labels(&self) -> Vec { + self.read_indices + .iter() + .map(|e| e.0.clone()) + .collect::>() + } + + pub fn remove_label(&mut self, label: &str) -> Result { + let label = label.to_string(); + if self.read_indices.contains_key(&label) { + let reader_cf = self.db.cf_handle(READER_CF).unwrap(); + self.db.delete_cf(reader_cf, label.as_bytes())?; + + self.read_indices.remove(&label); + + return Ok(true); + } + Ok(false) + } + + fn actualize_indices(&mut self) -> Result<()> { + if self.empty { + return Ok(()); + } + + let data_cf = self.db.cf_handle(DATA_CF).unwrap(); + let system_cf = self.db.cf_handle(SYSTEM_CF).unwrap(); + let reader_cf = self.db.cf_handle(READER_CF).unwrap(); + + let mut iter = self.db.iterator_cf( + data_cf, + IteratorMode::From(&index_to_key(self.start_index), Direction::Forward), + ); + let first_entry = if let Some(e) = iter.next() { + Some(e) + } else { + // MAX_ALLOWED_INDEX handling + let mut iter = self.db.iterator_cf(data_cf, IteratorMode::Start); + iter.next() + }; + + let (start_index, empty, f) = match first_entry { + Some(Err(e)) => return Err(anyhow::Error::from(e)), + Some(Ok(e)) => { + let start_index = key_to_index(e.0); + if self.start_index == start_index { + // no elements have been expired + return Ok(()); + } else { + // some elements have been expired + let f: fn(u64, u64, u64, &mut Reader) = + |start, write_index, _write_timestamp, reader| { + let index = if reader.index == write_index { + reader.index + } else if reader.index > write_index { + if start > write_index { + u64::max(start, reader.index) + } else { + start + } + } else if start > write_index { + reader.index + } else { + u64::max(start, reader.index) + }; + reader.expired = index != reader.index; + reader.index = index; + }; + (start_index, false, f) + } + } + None => { + // all elements have been expired + let f: fn(u64, u64, u64, &mut Reader) = + |_start, write_index, write_timestamp, reader| { + reader.expired = reader.end_timestamp != Some(write_timestamp); + reader.index = write_index; + reader.end_timestamp = Some(write_timestamp); + }; + (self.write_index, true, f) + } + }; + + let mut batch = rocksdb::WriteBatch::default(); + + for (label, reader) in self.read_indices.iter() { + let mut reader = reader.clone(); + f( + start_index, + self.write_index, + self.write_timestamp, + &mut reader, + ); + batch.put_cf( + reader_cf, + label.as_bytes(), + bincode::encode_to_vec(reader, self.configuration)?, + ); + } + batch.put_cf( + system_cf, + index_to_key(START_INDEX_KEY), + start_index.to_le_bytes(), + ); + + self.db.write(batch)?; + + self.start_index = start_index; + self.empty = empty; + self.read_indices + .iter_mut() + .for_each(|e| f(start_index, self.write_index, self.write_timestamp, e.1)); + + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use crate::mpmc::{MpmcQueue, Reader, StartPosition, DATA_CF}; + use crate::utilities::{current_timestamp, index_to_key}; + use crate::MAX_ALLOWED_INDEX; + use std::collections::HashMap; + use std::fs; + use std::ops::{Add, Div, Mul}; + use std::thread::sleep; + use std::time::Duration; + + #[test] + pub fn test_new_empty() { + let path = std::env::temp_dir().join("empty"); + let path = path.to_str().unwrap(); + let _ = fs::remove_dir_all(path); + + let ttl = Duration::from_secs(60); + let now = current_timestamp(); + + let queue = MpmcQueue::new(path, ttl).unwrap(); + + assert_eq!(queue.path, path); + assert_eq!(queue.start_index, 0); + assert_eq!(queue.write_index, 0); + assert!(queue.write_timestamp > now); + assert_eq!(queue.read_indices.is_empty(), true); + assert_eq!(queue.empty, true); + assert_eq!(queue.len(), 0); + + let _ = fs::remove_dir_all(path); + } + + #[test] + pub fn test_new_non_empty() { + let path = std::env::temp_dir().join("non-empty"); + let path = path.to_str().unwrap(); + let _ = fs::remove_dir_all(path); + + let ttl = Duration::from_secs(1); + let values = vec!["a".as_bytes(), "b".as_bytes()]; + let label = "label"; + let write_timestamp = { + let mut queue = MpmcQueue::new(path, ttl).unwrap(); + + queue.add(&values).unwrap(); + queue.next(1, label, StartPosition::Oldest).unwrap(); + + wait_and_expire(&mut queue, ttl.mul(2)); + + queue.add(&values).unwrap(); + queue.next(1, label, StartPosition::Oldest).unwrap(); + + assert_eq!(queue.start_index, 2); + assert_eq!(queue.write_index, 4); + assert_eq!( + queue.read_indices, + HashMap::from([(label.to_string(), Reader::new(3, None, false))]) + ); + assert_eq!(queue.empty, false); + queue.write_timestamp + }; + + { + let queue = MpmcQueue::new(path, ttl).unwrap(); + + assert_eq!(queue.start_index, 2); + assert_eq!(queue.write_index, 4); + assert_eq!(queue.write_timestamp, write_timestamp); + assert_eq!( + queue.read_indices, + HashMap::from([(label.to_string(), Reader::new(3, None, false))]) + ); + assert_eq!(queue.empty, false); + } + + let _ = fs::remove_dir_all(path); + } + + #[test] + pub fn test_add() { + test(Duration::from_secs(10), |mut queue| { + queue.add(&["a".as_bytes()]).unwrap(); + + assert_eq!(queue.start_index, 0); + assert_eq!(queue.write_index, 1); + assert_eq!(queue.read_indices.is_empty(), true); + assert_eq!(queue.empty, false); + assert_eq!(queue.len(), 1); + }); + } + + #[test] + pub fn test_add_too_big_batch() { + test(Duration::from_secs(10), |mut queue| { + let values = vec!["a".as_bytes(); (MAX_ALLOWED_INDEX + 1) as usize]; + let result = queue.add(&values); + + assert_eq!(result.is_err(), true); + assert_eq!(queue.start_index, 0); + assert_eq!(queue.write_index, 0); + assert_eq!(queue.read_indices.is_empty(), true); + assert_eq!(queue.empty, true); + assert_eq!(queue.len(), 0); + }); + } + + #[test] + pub fn test_add_full_queue() { + test(Duration::from_secs(10), |mut queue| { + let values = vec!["a".as_bytes(); MAX_ALLOWED_INDEX as usize]; + queue.add(&values).unwrap(); + + assert_eq!(queue.start_index, 0); + assert_eq!(queue.write_index, 0); + assert_eq!(queue.read_indices.is_empty(), true); + assert_eq!(queue.empty, false); + assert_eq!(queue.len(), MAX_ALLOWED_INDEX as usize); + + let result = queue.add(&["b".as_bytes()]); + + assert_eq!(result.is_err(), true); + + assert_eq!(queue.start_index, 0); + assert_eq!(queue.write_index, 0); + assert_eq!(queue.read_indices.is_empty(), true); + assert_eq!(queue.empty, false); + assert_eq!(queue.len(), MAX_ALLOWED_INDEX as usize); + }); + } + + #[test] + pub fn test_next_new_label_oldest_empty_queue() { + let ttl = Duration::from_secs(60); + let label = "label"; + + test(ttl, |mut queue| { + let result = queue.next(100, label, StartPosition::Oldest).unwrap(); + + assert_eq!(result.is_empty(), true); + assert_eq!( + queue.read_indices, + HashMap::from([( + label.to_string(), + Reader::new(0, Some(queue.write_timestamp), false) + )]) + ); + }); + } + + #[test] + pub fn test_next_new_label_newest_empty_queue() { + let ttl = Duration::from_secs(60); + let label = "label"; + + test(ttl, |mut queue| { + let result = queue.next(100, label, StartPosition::Newest).unwrap(); + + assert_eq!(result.is_empty(), true); + assert_eq!( + queue.read_indices, + HashMap::from([( + label.to_string(), + Reader::new(0, Some(queue.write_timestamp), false) + )]) + ); + }); + } + + #[test] + pub fn test_next_new_label_oldest_non_empty_queue() { + let ttl = Duration::from_secs(60); + let label = "label"; + let value_one = "a".as_bytes(); + let value_two = "b".as_bytes(); + let value_three = "c".as_bytes(); + let start_position = StartPosition::Oldest; + + test(ttl, |mut queue| { + queue.add(&[value_one, value_two, value_three]).unwrap(); + + let result = queue.next(2, label, start_position).unwrap(); + + assert_eq!(result, vec![value_one.to_vec(), value_two.to_vec()]); + assert_eq!( + queue.read_indices, + HashMap::from([(label.to_string(), Reader::new(2, None, false))]) + ); + + let result = queue.next(2, label, start_position).unwrap(); + + assert_eq!(result, vec![value_three.to_vec()]); + assert_eq!( + queue.read_indices, + HashMap::from([( + label.to_string(), + Reader::new(3, Some(queue.write_timestamp), false) + )]) + ); + }); + } + + #[test] + pub fn test_next_new_label_newest_non_empty_queue() { + let ttl = Duration::from_secs(60); + let label = "label"; + let value_one = "a".as_bytes(); + let value_two = "b".as_bytes(); + let value_three = "c".as_bytes(); + let start_position = StartPosition::Newest; + + test(ttl, |mut queue| { + queue.add(&[value_one, value_two, value_three]).unwrap(); + + let result = queue.next(2, label, start_position).unwrap(); + + assert_eq!(result, vec![value_three.to_vec()]); + assert_eq!( + queue.read_indices, + HashMap::from([( + label.to_string(), + Reader::new(3, Some(queue.write_timestamp), false) + )]) + ); + + let result = queue.next(2, label, start_position).unwrap(); + + assert_eq!(result.is_empty(), true); + assert_eq!( + queue.read_indices, + HashMap::from([( + label.to_string(), + Reader::new(3, Some(queue.write_timestamp), false) + )]) + ); + }); + } + + #[test] + pub fn test_next_new_label_newest_full_queue() { + let ttl = Duration::from_secs(1); + let label = "label"; + let last_value = "last".as_bytes(); + + test(ttl, |mut queue| { + queue + .add(&["value".as_bytes(); (MAX_ALLOWED_INDEX - 1) as usize]) + .unwrap(); + queue.add(&[last_value]).unwrap(); + + let result = queue.next(1, label, StartPosition::Newest).unwrap(); + + assert_eq!(result, vec![last_value.to_vec()]); + assert_eq!( + queue.read_indices, + HashMap::from([( + label.to_string(), + Reader::new(0, Some(queue.write_timestamp), false) + )]) + ); + + let result = queue.next(1, label, StartPosition::Newest).unwrap(); + + assert_eq!(result.is_empty(), true); + assert_eq!( + queue.read_indices, + HashMap::from([( + label.to_string(), + Reader::new(0, Some(queue.write_timestamp), false) + )]) + ); + }); + } + + #[test] + pub fn test_next_new_label_oldest_all_expired() { + let ttl = Duration::from_secs(1); + let label_one = "label1"; + let label_two = "label2"; + + test(ttl, |mut queue| { + queue.add(&["a".as_bytes(), "b".as_bytes()]).unwrap(); + + wait_and_expire(&mut queue, ttl.mul(2)); + + let result = queue.next(2, label_one, StartPosition::Oldest).unwrap(); + + assert_eq!(result.is_empty(), true); + assert_eq!(queue.start_index, 2); + assert_eq!(queue.write_index, 2); + assert_eq!( + queue.read_indices, + HashMap::from([( + label_one.to_string(), + Reader::new(2, Some(queue.write_timestamp), false) + )]) + ); + assert_eq!(queue.empty, true); + assert_eq!(queue.len(), 0); + + let write_timestamp = queue.write_timestamp; + + queue + .add(&["a".as_bytes(); MAX_ALLOWED_INDEX as usize]) + .unwrap(); + + assert_eq!(queue.start_index, 2); + assert_eq!(queue.write_index, 2); + assert_eq!( + queue.read_indices, + HashMap::from([( + label_one.to_string(), + Reader::new(2, Some(write_timestamp), false) + )]) + ); + assert_eq!(queue.empty, false); + assert_eq!(queue.len(), MAX_ALLOWED_INDEX as usize); + assert!(queue.write_timestamp > write_timestamp); + + // expire all + wait_and_expire(&mut queue, ttl.mul(2)); + + queue.next(1, label_two, StartPosition::Oldest).unwrap(); + + assert_eq!(queue.start_index, 2); + assert_eq!(queue.write_index, 2); + assert_eq!( + queue.read_indices, + HashMap::from([ + ( + label_one.to_string(), + Reader::new(2, Some(queue.write_timestamp), true) + ), + ( + label_two.to_string(), + Reader::new(2, Some(queue.write_timestamp), false) + ) + ]) + ); + assert_eq!(queue.empty, true); + assert_eq!(queue.len(), 0); + }); + } + + #[test] + pub fn test_next_new_label_newest_all_expired() { + let ttl = Duration::from_secs(1); + let label_one = "label1"; + let label_two = "label2"; + + test(ttl, |mut queue| { + queue.add(&["a".as_bytes(), "b".as_bytes()]).unwrap(); + + wait_and_expire(&mut queue, ttl.mul(2)); + + let result = queue.next(2, label_one, StartPosition::Newest).unwrap(); + + assert_eq!(result.is_empty(), true); + assert_eq!(queue.start_index, 2); + assert_eq!(queue.write_index, 2); + assert_eq!( + queue.read_indices, + HashMap::from([( + label_one.to_string(), + Reader::new(2, Some(queue.write_timestamp), false) + )]) + ); + assert_eq!(queue.empty, true); + assert_eq!(queue.len(), 0); + + let write_timestamp = queue.write_timestamp; + + queue + .add(&["a".as_bytes(); MAX_ALLOWED_INDEX as usize]) + .unwrap(); + + assert_eq!(queue.start_index, 2); + assert_eq!(queue.write_index, 2); + assert_eq!( + queue.read_indices, + HashMap::from([( + label_one.to_string(), + Reader::new(2, Some(write_timestamp), false) + )]) + ); + assert_eq!(queue.empty, false); + assert_eq!(queue.len(), MAX_ALLOWED_INDEX as usize); + assert!(queue.write_timestamp > write_timestamp); + + // expire all + wait_and_expire(&mut queue, ttl.mul(2)); + + queue.next(1, label_two, StartPosition::Newest).unwrap(); + + assert_eq!(queue.start_index, 2); + assert_eq!(queue.write_index, 2); + assert_eq!( + queue.read_indices, + HashMap::from([ + ( + label_one.to_string(), + Reader::new(2, Some(queue.write_timestamp), true) + ), + ( + label_two.to_string(), + Reader::new(2, Some(queue.write_timestamp), false) + ) + ]) + ); + assert_eq!(queue.empty, true); + assert_eq!(queue.len(), 0); + }); + } + + #[test] + pub fn test_next_new_label_oldest_some_expired() { + let ttl = Duration::from_secs(1); + let label = "label"; + let value = "c".as_bytes(); + + test(ttl, |mut queue| { + queue.add(&["a".as_bytes()]).unwrap(); + + sleep(ttl.mul(2)); + + queue.add(&[value, "b".as_bytes()]).unwrap(); + + let data_cf = queue.db.cf_handle(DATA_CF).unwrap(); + queue + .db + .compact_range_cf(data_cf, None::<&[u8]>, None::<&[u8]>); + + let result = queue.next(1, label, StartPosition::Oldest).unwrap(); + + assert_eq!(result, vec![value]); + assert_eq!(queue.start_index, 1); + assert_eq!(queue.write_index, 3); + assert_eq!( + queue.read_indices, + HashMap::from([(label.to_string(), Reader::new(2, None, false))]) + ); + assert_eq!(queue.empty, false); + assert_eq!(queue.len(), 2); + }); + } + + #[test] + pub fn test_next_new_label_newest_some_expired() { + let ttl = Duration::from_secs(1); + let label = "label"; + let value = "c".as_bytes(); + + test(ttl, |mut queue| { + queue.add(&["a".as_bytes()]).unwrap(); + + sleep(ttl.mul(2)); + + queue.add(&["b".as_bytes(), value]).unwrap(); + + let data_cf = queue.db.cf_handle(DATA_CF).unwrap(); + queue + .db + .compact_range_cf(data_cf, None::<&[u8]>, None::<&[u8]>); + + let result = queue.next(1, label, StartPosition::Newest).unwrap(); + + assert_eq!(result, vec![value]); + assert_eq!(queue.start_index, 1); + assert_eq!(queue.write_index, 3); + assert_eq!( + queue.read_indices, + HashMap::from([( + label.to_string(), + Reader::new(3, Some(queue.write_timestamp), false) + )]) + ); + assert_eq!(queue.empty, false); + assert_eq!(queue.len(), 2); + }); + } + + #[test] + pub fn test_next_after_ttl_start_index_less_write_index() { + let quarter_ttl = Duration::from_secs(1); + let ttl = quarter_ttl.mul(4); + let label_one = "label1"; + let label_two = "label2"; + let label_three = "label3"; + let label_four = "label4"; + let label_five = "label5"; + let value_one = "a".as_bytes(); + let value_two = "b".as_bytes(); + let value_three = "c".as_bytes(); + let value_four = "d".as_bytes(); + let start_position = StartPosition::Oldest; + + test(ttl, |mut queue| { + queue.add(&[value_one, value_two]).unwrap(); + + wait_and_expire(&mut queue, quarter_ttl.mul(3)); + + queue.add(&[value_three, value_four]).unwrap(); + + // read < write, start > read + queue.next(1, label_one, start_position).unwrap(); + // read < write, start = read + queue.next(2, label_two, start_position).unwrap(); + // read < write, start < read + queue.next(3, label_three, start_position).unwrap(); + // read = write + queue.next(4, label_four, start_position).unwrap(); + + assert_eq!(queue.start_index, 0); + assert_eq!(queue.write_index, 4); + assert_eq!( + queue.read_indices, + HashMap::from([ + (label_one.to_string(), Reader::new(1, None, false)), + (label_two.to_string(), Reader::new(2, None, false)), + (label_three.to_string(), Reader::new(3, None, false)), + ( + label_four.to_string(), + Reader::new(4, Some(queue.write_timestamp), false) + ) + ]) + ); + assert_eq!(queue.empty, false); + assert_eq!(queue.len(), 4); + + wait_and_expire(&mut queue, quarter_ttl.mul(2)); + + queue.next(1, label_five, start_position).unwrap(); + + assert_eq!(queue.start_index, 2); + assert_eq!(queue.write_index, 4); + assert_eq!( + queue.read_indices, + HashMap::from([ + (label_one.to_string(), Reader::new(2, None, true)), + (label_two.to_string(), Reader::new(2, None, false)), + (label_three.to_string(), Reader::new(3, None, false)), + ( + label_four.to_string(), + Reader::new(4, Some(queue.write_timestamp), false) + ), + (label_five.to_string(), Reader::new(3, None, false)) + ]) + ); + assert_eq!(queue.empty, false); + assert_eq!(queue.len(), 2); + + wait_and_expire(&mut queue, quarter_ttl.mul(3)); + + queue.next(1, label_five, start_position).unwrap(); + + assert_eq!(queue.start_index, 4); + assert_eq!(queue.write_index, 4); + assert_eq!( + queue.read_indices, + HashMap::from([ + ( + label_one.to_string(), + Reader::new(4, Some(queue.write_timestamp), true) + ), + ( + label_two.to_string(), + Reader::new(4, Some(queue.write_timestamp), true) + ), + ( + label_three.to_string(), + Reader::new(4, Some(queue.write_timestamp), true) + ), + ( + label_four.to_string(), + Reader::new(4, Some(queue.write_timestamp), false) + ), + ( + label_five.to_string(), + Reader::new(4, Some(queue.write_timestamp), false) + ) + ]) + ); + assert_eq!(queue.empty, true); + assert_eq!(queue.len(), 0); + }); + } + + #[test] + pub fn test_next_after_ttl_start_index_greater_write_index() { + let quarter_ttl = Duration::from_secs(1); + let ttl = quarter_ttl.mul(4); + let label_one = "label1"; + let label_two = "label2"; + let label_three = "label3"; + let label_four = "label4"; + let label_five = "label5"; + let label_six = "label6"; + let value_one = "a".as_bytes(); + let value_two = "b".as_bytes(); + let value_three = "c".as_bytes(); + let value_four = "d".as_bytes(); + let value_five = "e".as_bytes(); + let value_six = "f".as_bytes(); + let start_position = StartPosition::Oldest; + + test(ttl, |mut queue| { + queue.add(&[value_one, value_two]).unwrap(); + + // expire all to emulate write index < start index + wait_and_expire(&mut queue, quarter_ttl.mul(5)); + + queue.next(1, label_one, start_position).unwrap(); + + assert_eq!(queue.start_index, 2); + assert_eq!(queue.write_index, 2); + assert_eq!( + queue.read_indices, + HashMap::from([( + label_one.to_string(), + Reader::new(2, Some(queue.write_timestamp), false) + )]) + ); + assert_eq!(queue.empty, true); + assert_eq!(queue.len(), 0); + + queue.add(&[value_one, value_two]).unwrap(); + + // value one and value two will expire than elements added later + sleep(quarter_ttl.mul(3)); + + queue + .add(&[value_three, value_four, value_five, value_six]) + .unwrap(); + // read > write, start > read + queue.next(1, label_one, start_position).unwrap(); + // read > write, start = read + queue.next(2, label_two, start_position).unwrap(); + // read > write, start < read + queue.next(3, label_three, start_position).unwrap(); + // read < write, start > read + queue.next(4, label_four, start_position).unwrap(); + queue.next(5, label_five, start_position).unwrap(); + // read == write + queue.next(6, label_six, start_position).unwrap(); + + assert_eq!(queue.start_index, 2); + assert_eq!(queue.write_index, 2); + assert_eq!( + queue.read_indices, + HashMap::from([ + (label_one.to_string(), Reader::new(3, None, false)), + (label_two.to_string(), Reader::new(4, None, false)), + (label_three.to_string(), Reader::new(5, None, false)), + (label_four.to_string(), Reader::new(0, None, false)), + (label_five.to_string(), Reader::new(1, None, false)), + ( + label_six.to_string(), + Reader::new(2, Some(queue.write_timestamp), false) + ) + ]) + ); + assert_eq!(queue.empty, false); + assert_eq!(queue.len(), 6); + + // expire value one and value two + wait_and_expire(&mut queue, quarter_ttl.mul(2)); + + queue.next(1, label_five, start_position).unwrap(); + + assert_eq!(queue.start_index, 4); + assert_eq!(queue.write_index, 2); + assert_eq!( + queue.read_indices, + HashMap::from([ + (label_one.to_string(), Reader::new(4, None, true)), + (label_two.to_string(), Reader::new(4, None, false)), + (label_three.to_string(), Reader::new(5, None, false)), + (label_four.to_string(), Reader::new(0, None, false)), + ( + label_five.to_string(), + Reader::new(2, Some(queue.write_timestamp), false) + ), + ( + label_six.to_string(), + Reader::new(2, Some(queue.write_timestamp), false) + ) + ]) + ); + assert_eq!(queue.empty, false); + assert_eq!(queue.len(), 4); + + // expire all elements + wait_and_expire(&mut queue, quarter_ttl.mul(3)); + + queue.next(1, label_five, start_position).unwrap(); + + assert_eq!(queue.start_index, 2); + assert_eq!(queue.write_index, 2); + assert_eq!( + queue.read_indices, + HashMap::from([ + ( + label_one.to_string(), + Reader::new(2, Some(queue.write_timestamp), true) + ), + ( + label_two.to_string(), + Reader::new(2, Some(queue.write_timestamp), true) + ), + ( + label_three.to_string(), + Reader::new(2, Some(queue.write_timestamp), true) + ), + ( + label_four.to_string(), + Reader::new(2, Some(queue.write_timestamp), true) + ), + ( + label_five.to_string(), + Reader::new(2, Some(queue.write_timestamp), false) + ), + ( + label_six.to_string(), + Reader::new(2, Some(queue.write_timestamp), false) + ) + ]) + ); + assert_eq!(queue.empty, true); + assert_eq!(queue.len(), 0); + }); + } + + #[test] + pub fn test_next_after_ttl_start_index_greater_write_index_reverse() { + let ttl = Duration::from_secs(4); + let label_one = "label1"; + let label_two = "label2"; + + test(ttl, |mut queue| { + queue.add(&["a".as_bytes(); 2]).unwrap(); + + // expire all to emulate write index < start index + wait_and_expire(&mut queue, ttl.add(Duration::from_secs(1))); + + queue + .add(&[ + "a".as_bytes(), + "b".as_bytes(), + "c".as_bytes(), + "d".as_bytes(), + ]) + .unwrap(); + queue.next(3, label_one, StartPosition::Oldest).unwrap(); + + assert_eq!(queue.start_index, 2); + assert_eq!(queue.write_index, 0); + assert_eq!( + queue.read_indices, + HashMap::from([(label_one.to_string(), Reader::new(5, None, false))]) + ); + assert_eq!(queue.empty, false); + assert_eq!(queue.len(), 4); + + sleep(ttl.mul_f32(0.75)); + + queue.add(&["e".as_bytes(), "f".as_bytes()]).unwrap(); + + assert_eq!(queue.start_index, 2); + assert_eq!(queue.write_index, 2); + assert_eq!( + queue.read_indices, + HashMap::from([(label_one.to_string(), Reader::new(5, None, false))]) + ); + assert_eq!(queue.empty, false); + assert_eq!(queue.len(), 6); + + wait_and_expire(&mut queue, ttl.div(2)); + + queue.next(1, label_two, StartPosition::Oldest).unwrap(); + + assert_eq!(queue.start_index, 0); + assert_eq!(queue.write_index, 2); + assert_eq!( + queue.read_indices, + HashMap::from([ + (label_one.to_string(), Reader::new(0, None, true)), + (label_two.to_string(), Reader::new(1, None, false)) + ]) + ); + assert_eq!(queue.empty, false); + assert_eq!(queue.len(), 2); + }); + } + + #[test] + pub fn test_next_during_expiration_some_elements() { + let ttl = Duration::from_secs(4); + let label = "label"; + let value_three = "c".as_bytes(); + let value_four = "d".as_bytes(); + + test(ttl, |mut queue| { + queue + .add(&["a".as_bytes(), "b".as_bytes(), value_three, value_four]) + .unwrap(); + + // emulate that the second value expires right after reading the first value + let data_cf = queue.db.cf_handle(DATA_CF).unwrap(); + queue.db.delete_cf(data_cf, index_to_key(1)).unwrap(); + + let result = queue.next(2, label, StartPosition::Oldest).unwrap(); + + assert_eq!(result, vec![value_three, value_four]); + assert_eq!(queue.start_index, 0); + assert_eq!(queue.write_index, 4); + assert_eq!( + queue.read_indices, + HashMap::from([( + label.to_string(), + Reader::new(4, Some(queue.write_timestamp), false) + )]) + ); + assert_eq!(queue.empty, false); + assert_eq!(queue.len(), 4); + }); + } + + #[test] + pub fn test_next_during_expiration_all_elements() { + let ttl = Duration::from_secs(4); + let label = "label"; + let value_three = "c".as_bytes(); + let value_four = "d".as_bytes(); + + test(ttl, |mut queue| { + queue + .add(&["a".as_bytes(), "b".as_bytes(), value_three, value_four]) + .unwrap(); + + // emulate that the all values expire right after reading the first value + let data_cf = queue.db.cf_handle(DATA_CF).unwrap(); + queue.db.delete_cf(data_cf, index_to_key(1)).unwrap(); + queue.db.delete_cf(data_cf, index_to_key(2)).unwrap(); + queue.db.delete_cf(data_cf, index_to_key(3)).unwrap(); + + let result = queue.next(4, label, StartPosition::Oldest).unwrap(); + + assert_eq!(result.is_empty(), true); + assert_eq!(queue.start_index, 0); + assert_eq!(queue.write_index, 4); + assert_eq!( + queue.read_indices, + HashMap::from([( + label.to_string(), + Reader::new(4, Some(queue.write_timestamp), false) + )]) + ); + assert_eq!(queue.empty, false); + assert_eq!(queue.len(), 4); + }); + } + + #[test] + pub fn test_get_labels() { + let label_one = "label1"; + let label_two = "label2"; + test(Duration::from_secs(10), |mut queue| { + let result = queue.get_labels(); + + assert_eq!(result.is_empty(), true); + + queue.next(1, label_one, StartPosition::Oldest).unwrap(); + let result = queue.get_labels(); + + assert_eq!(result, vec![label_one.to_string()]); + + queue.next(1, label_two, StartPosition::Oldest).unwrap(); + let result = queue.get_labels(); + + assert_eq!(result.len(), 2); + assert_eq!(result.contains(&label_one.to_string()), true); + assert_eq!(result.contains(&label_two.to_string()), true); + }); + } + + #[test] + pub fn test_remove_label() { + let label = "label"; + test(Duration::from_secs(10), |mut queue| { + let result = queue.remove_label(label).unwrap(); + + assert_eq!(result, false); + assert_eq!(queue.read_indices.is_empty(), true); + + queue.next(1, label, StartPosition::Oldest).unwrap(); + assert_eq!( + queue.read_indices, + HashMap::from([( + label.to_string(), + Reader::new(0, Some(queue.write_timestamp), false) + )]) + ); + + let result = queue.remove_label(label).unwrap(); + + assert_eq!(result, true); + assert_eq!(queue.read_indices.is_empty(), true); + }); + } + + fn test(ttl: Duration, mut f: F) + where + F: FnMut(MpmcQueue), + { + let directory = tempfile::TempDir::new().unwrap(); + let path = directory.path().to_str().unwrap(); + + let queue = MpmcQueue::new(path, ttl).unwrap(); + + f(queue); + } + + fn wait_and_expire(queue: &mut MpmcQueue, duration: Duration) { + sleep(duration); + let data_cf = queue.db.cf_handle(DATA_CF).unwrap(); + queue + .db + .compact_range_cf(data_cf, None::<&[u8]>, None::<&[u8]>); + } +} diff --git a/queue_rs/src/nonblocking.rs b/queue_rs/src/nonblocking.rs index 09fcdf7..23ed4e1 100644 --- a/queue_rs/src/nonblocking.rs +++ b/queue_rs/src/nonblocking.rs @@ -1,7 +1,11 @@ +use crate::mpmc; +use crate::mpmc::StartPosition; use anyhow::Result; use crossbeam_channel::{Receiver, Sender}; use std::thread; +use std::time::Duration; +#[derive(Clone)] pub enum Operation { Push(Vec>), Pop(usize), @@ -19,14 +23,37 @@ pub enum ResponseVariant { Stop, } -pub struct Response(Receiver); +#[derive(Clone)] +pub enum MpmcOperation { + Add(Vec>), + Next(usize, String, StartPosition), + Length, + DiskSize, + GetLabels, + RemoveLabel(String), + Stop, +} + +pub enum MpmcResponseVariant { + Add(Result<()>), + Next(Result>>), + Length(usize), + Size(Result), + GetLabels(Vec), + RemoveLabel(Result), + Stop, +} + +pub struct TypedResponse(Receiver); +pub type Response = TypedResponse; +pub type MpmcResponse = TypedResponse; -impl Response { +impl TypedResponse { pub fn is_ready(&self) -> bool { !self.0.is_empty() } - pub fn try_get(&self) -> Result> { + pub fn try_get(&self) -> Result> { let res = self.0.try_recv(); if let Err(crossbeam_channel::TryRecvError::Empty) = &res { return Ok(None); @@ -34,83 +61,56 @@ impl Response { Ok(Some(res?)) } - pub fn get(&self) -> Result { + pub fn get(&self) -> Result { Ok(self.0.recv()?) } } type WorkingThread = Option>>; -type QueueSender = Sender<(Operation, Sender)>; -type QueueType = (WorkingThread, QueueSender); - -pub struct PersistentQueueWithCapacity(QueueType); - -fn start_op_loop( - path: &str, - max_elements: usize, - max_inflight_ops: usize, - db_options: rocksdb::Options, -) -> (WorkingThread, QueueSender) { - let mut queue = - crate::PersistentQueueWithCapacity::new(path, max_elements, db_options).unwrap(); - let (tx, rx) = - crossbeam_channel::bounded::<(Operation, Sender)>(max_inflight_ops); - let handle = thread::spawn(move || { - loop { - match rx.recv() { - Ok((Operation::Push(values), resp_tx)) => { - let value_slices = values.iter().map(|e| e.as_slice()).collect::>(); - let resp = queue.push(&value_slices); - resp_tx.send(ResponseVariant::Push(resp))?; - } - Ok((Operation::Pop(max_elements), resp_tx)) => { - let resp = queue.pop(max_elements); - resp_tx.send(ResponseVariant::Pop(resp))?; - } - Ok((Operation::Length, resp_tx)) => { - let resp = queue.len(); - resp_tx.send(ResponseVariant::Length(resp))?; - } - Ok((Operation::DiskSize, resp_tx)) => { - let resp = queue.disk_size(); - resp_tx.send(ResponseVariant::Size(resp))?; - } - Ok((Operation::PayloadSize, resp_tx)) => { - let resp = queue.payload_size(); - resp_tx.send(ResponseVariant::Size(Ok(resp as usize)))?; - } - Ok((Operation::Stop, resp_tx)) => { - resp_tx.send(ResponseVariant::Stop)?; - break; - } - Err(e) => return Err(anyhow::anyhow!("Error receiving operation: {}", e)), - } - } - Ok(()) - }); +type QueueSender = Sender<(O, Sender)>; +type QueueType = (WorkingThread, QueueSender); +pub struct NonBlockingQueueWrapper(QueueType, O) +where + O: Clone + Send + Sync + 'static, + R: Send + 'static; +pub type PersistentQueueWithCapacity = NonBlockingQueueWrapper; +pub type MpmcQueue = NonBlockingQueueWrapper; + +fn start_op_loop(mut f: F, max_inflight_ops: usize) -> (WorkingThread, QueueSender) +where + F: FnMut(Receiver<(O, Sender)>) -> Result<()> + Send + 'static, + O: Send + Sync + 'static, + R: Send + Sync + 'static, +{ + let (tx, rx) = crossbeam_channel::bounded::<(O, Sender)>(max_inflight_ops); + let handle = thread::spawn(move || f(rx)); (Some(handle), tx) } -impl PersistentQueueWithCapacity { - pub fn new( - path: &str, - max_elements: usize, - max_inflight_ops: usize, - db_options: rocksdb::Options, - ) -> Result { - let (handle, tx) = start_op_loop(path, max_elements, max_inflight_ops, db_options); - Ok(Self((handle, tx))) - } - +impl NonBlockingQueueWrapper +where + O: Clone + Send + Sync, + R: Send, +{ pub fn is_healthy(&self) -> bool { !self.0 .0.as_ref().map(|t| t.is_finished()).unwrap_or(true) } + pub fn inflight_ops(&self) -> Result { + if !self.is_healthy() { + return Err(anyhow::anyhow!( + "Queue is unhealthy: cannot use it anymore." + )); + } + + Ok(self.0 .1.len()) + } + fn shutdown(&mut self) -> Result<()> { if self.is_healthy() { let (tx, rx) = crossbeam_channel::bounded(1); - self.0 .1.send((Operation::Stop, tx))?; + self.0 .1.send((self.1.clone(), tx))?; rx.recv()?; let thread_opt = self.0 .0.take(); if let Some(thread) = thread_opt { @@ -119,27 +119,74 @@ impl PersistentQueueWithCapacity { } Ok(()) } +} - pub fn len(&self) -> Result { - if !self.is_healthy() { - return Err(anyhow::anyhow!( - "Queue is unhealthy: cannot use it anymore." - )); - } +impl Drop for NonBlockingQueueWrapper +where + O: Clone + Send + Sync, + R: Send, +{ + fn drop(&mut self) { + self.shutdown().unwrap(); + } +} - let (tx, rx) = crossbeam_channel::bounded(1); - self.0 .1.send((Operation::Length, tx))?; - Ok(Response(rx)) +impl PersistentQueueWithCapacity { + pub fn new( + path: &str, + max_elements: usize, + max_inflight_ops: usize, + db_options: rocksdb::Options, + ) -> Result { + let mut queue = + crate::PersistentQueueWithCapacity::new(path, max_elements, db_options).unwrap(); + let f = move |rx: Receiver<(Operation, Sender)>| { + loop { + match rx.recv() { + Ok((Operation::Push(values), resp_tx)) => { + let value_slices = values.iter().map(|e| e.as_slice()).collect::>(); + let resp = queue.push(&value_slices); + resp_tx.send(ResponseVariant::Push(resp))?; + } + Ok((Operation::Pop(max_elements), resp_tx)) => { + let resp = queue.pop(max_elements); + resp_tx.send(ResponseVariant::Pop(resp))?; + } + Ok((Operation::Length, resp_tx)) => { + let resp = queue.len(); + resp_tx.send(ResponseVariant::Length(resp))?; + } + Ok((Operation::DiskSize, resp_tx)) => { + let resp = queue.disk_size(); + resp_tx.send(ResponseVariant::Size(resp))?; + } + Ok((Operation::PayloadSize, resp_tx)) => { + let resp = queue.payload_size(); + resp_tx.send(ResponseVariant::Size(Ok(resp as usize)))?; + } + Ok((Operation::Stop, resp_tx)) => { + resp_tx.send(ResponseVariant::Stop)?; + break; + } + Err(e) => return Err(anyhow::anyhow!("Error receiving operation: {}", e)), + } + } + Ok(()) + }; + let (handle, tx) = start_op_loop(f, max_inflight_ops); + Ok(Self((handle, tx), Operation::Stop)) } - pub fn inflight_ops(&self) -> Result { + pub fn len(&self) -> Result { if !self.is_healthy() { return Err(anyhow::anyhow!( "Queue is unhealthy: cannot use it anymore." )); } - Ok(self.0 .1.len()) + let (tx, rx) = crossbeam_channel::bounded(1); + self.0 .1.send((Operation::Length, tx))?; + Ok(TypedResponse(rx)) } pub fn disk_size(&self) -> Result { @@ -151,7 +198,7 @@ impl PersistentQueueWithCapacity { let (tx, rx) = crossbeam_channel::bounded(1); self.0 .1.send((Operation::DiskSize, tx))?; - Ok(Response(rx)) + Ok(TypedResponse(rx)) } pub fn payload_size(&self) -> Result { @@ -163,7 +210,7 @@ impl PersistentQueueWithCapacity { let (tx, rx) = crossbeam_channel::bounded(1); self.0 .1.send((Operation::PayloadSize, tx))?; - Ok(Response(rx)) + Ok(TypedResponse(rx)) } pub fn push(&self, values: &[&[u8]]) -> Result { @@ -178,7 +225,7 @@ impl PersistentQueueWithCapacity { Operation::Push(values.iter().map(|e| e.to_vec()).collect()), tx, ))?; - Ok(Response(rx)) + Ok(TypedResponse(rx)) } pub fn pop(&self, max_elements: usize) -> Result { @@ -190,20 +237,148 @@ impl PersistentQueueWithCapacity { let (tx, rx) = crossbeam_channel::bounded(1); self.0 .1.send((Operation::Pop(max_elements), tx))?; - Ok(Response(rx)) + Ok(TypedResponse(rx)) } } -impl Drop for PersistentQueueWithCapacity { - fn drop(&mut self) { - self.shutdown().unwrap(); +impl MpmcQueue { + pub fn new(path: &str, ttl: Duration, max_inflight_ops: usize) -> Result { + let mut queue = mpmc::MpmcQueue::new(path, ttl)?; + let f = move |rx: Receiver<(MpmcOperation, Sender)>| { + loop { + match rx.recv() { + Ok((MpmcOperation::Add(values), resp_tx)) => { + let value_slices = values.iter().map(|e| e.as_slice()).collect::>(); + let resp = queue.add(&value_slices); + resp_tx.send(MpmcResponseVariant::Add(resp))?; + } + Ok((MpmcOperation::Next(max_elements, label, start_position), resp_tx)) => { + let resp = queue.next(max_elements, label.as_str(), start_position); + resp_tx.send(MpmcResponseVariant::Next(resp))?; + } + Ok((MpmcOperation::Length, resp_tx)) => { + let resp = queue.len(); + resp_tx.send(MpmcResponseVariant::Length(resp))?; + } + Ok((MpmcOperation::DiskSize, resp_tx)) => { + let resp = queue.disk_size(); + resp_tx.send(MpmcResponseVariant::Size(resp))?; + } + Ok((MpmcOperation::GetLabels, resp_tx)) => { + let resp = queue.get_labels(); + resp_tx.send(MpmcResponseVariant::GetLabels(resp))?; + } + Ok((MpmcOperation::RemoveLabel(label), resp_tx)) => { + let resp = queue.remove_label(label.as_str()); + resp_tx.send(MpmcResponseVariant::RemoveLabel(resp))?; + } + Ok((MpmcOperation::Stop, resp_tx)) => { + resp_tx.send(MpmcResponseVariant::Stop)?; + break; + } + Err(e) => return Err(anyhow::anyhow!("Error receiving operation: {}", e)), + } + } + Ok(()) + }; + let (handle, tx) = start_op_loop(f, max_inflight_ops); + Ok(Self((handle, tx), MpmcOperation::Stop)) + } + + pub fn disk_size(&self) -> Result { + if !self.is_healthy() { + return Err(anyhow::anyhow!( + "Queue is unhealthy: cannot use it anymore." + )); + } + + let (tx, rx) = crossbeam_channel::bounded(1); + self.0 .1.send((MpmcOperation::DiskSize, tx))?; + Ok(TypedResponse(rx)) + } + + pub fn len(&self) -> Result { + if !self.is_healthy() { + return Err(anyhow::anyhow!( + "Queue is unhealthy: cannot use it anymore." + )); + } + + let (tx, rx) = crossbeam_channel::bounded(1); + self.0 .1.send((MpmcOperation::Length, tx))?; + Ok(TypedResponse(rx)) + } + + pub fn add(&self, values: &[&[u8]]) -> Result { + if !self.is_healthy() { + return Err(anyhow::anyhow!( + "Queue is unhealthy: cannot use it anymore." + )); + } + + let (tx, rx) = crossbeam_channel::bounded(1); + self.0 .1.send(( + MpmcOperation::Add(values.iter().map(|e| e.to_vec()).collect()), + tx, + ))?; + Ok(TypedResponse(rx)) + } + + pub fn next( + &self, + max_elts: usize, + label: &str, + start_position: StartPosition, + ) -> Result { + if !self.is_healthy() { + return Err(anyhow::anyhow!( + "Queue is unhealthy: cannot use it anymore." + )); + } + + let (tx, rx) = crossbeam_channel::bounded(1); + self.0 .1.send(( + MpmcOperation::Next(max_elts, label.to_string(), start_position), + tx, + ))?; + Ok(TypedResponse(rx)) + } + + pub fn get_labels(&self) -> Result { + if !self.is_healthy() { + return Err(anyhow::anyhow!( + "Queue is unhealthy: cannot use it anymore." + )); + } + + let (tx, rx) = crossbeam_channel::bounded(1); + self.0 .1.send((MpmcOperation::GetLabels, tx))?; + Ok(TypedResponse(rx)) + } + + pub fn remove_label(&self, label: &str) -> Result { + if !self.is_healthy() { + return Err(anyhow::anyhow!( + "Queue is unhealthy: cannot use it anymore." + )); + } + + let (tx, rx) = crossbeam_channel::bounded(1); + self.0 + .1 + .send((MpmcOperation::RemoveLabel(label.to_string()), tx))?; + Ok(TypedResponse(rx)) } } #[cfg(test)] mod tests { + use crate::mpmc; + use crate::mpmc::StartPosition; + use std::time::Duration; + #[test] - fn fresh_healthy() { + fn persistent_queue_fresh_healthy() { let path = "/tmp/test_fresh_healthy".to_string(); _ = crate::PersistentQueueWithCapacity::remove_db(&path); let queue = @@ -216,7 +391,7 @@ mod tests { } #[test] - fn push_pop() { + fn persistent_queue_push_pop() { let path = "/tmp/test_push_pop".to_string(); _ = crate::PersistentQueueWithCapacity::remove_db(&path); let queue = @@ -245,7 +420,7 @@ mod tests { } #[test] - fn size() { + fn persistent_queue_size() { let path = "/tmp/test_size".to_string(); _ = crate::PersistentQueueWithCapacity::remove_db(&path); let queue = @@ -255,4 +430,54 @@ mod tests { let size = size_query.get().unwrap(); assert!(matches!(size, super::ResponseVariant::Size(Ok(r)) if r > 0)); } + + #[test] + fn mpmc_queue_fresh_healthy() { + let path = "/tmp/test_mpmc_fresh_healthy".to_string(); + _ = mpmc::MpmcQueue::remove_db(&path); + let queue = super::MpmcQueue::new(&path, Duration::from_secs(10), 1000).unwrap(); + assert!(queue.is_healthy()); + let resp = queue.len().unwrap().get().unwrap(); + assert!(matches!(resp, super::MpmcResponseVariant::Length(0))); + _ = mpmc::MpmcQueue::remove_db(&path); + } + + #[test] + fn mpmc_queue_add_next() { + let path = "/tmp/test_mpmc_add_next".to_string(); + _ = mpmc::MpmcQueue::remove_db(&path); + let queue = super::MpmcQueue::new(&path, Duration::from_secs(1), 1000).unwrap(); + assert!(queue.is_healthy()); + + let resp = queue.len().unwrap().get().unwrap(); + assert!(matches!(resp, super::MpmcResponseVariant::Length(0))); + + let resp = queue.add(&[&[1u8, 2u8, 3u8]]).unwrap().get().unwrap(); + assert!(matches!(resp, super::MpmcResponseVariant::Add(Ok(())))); + + let resp = queue.len().unwrap().get().unwrap(); + assert!(matches!(resp, super::MpmcResponseVariant::Length(1))); + let resp = queue + .next(1, "label", StartPosition::Oldest) + .unwrap() + .get() + .unwrap(); + assert!( + matches!(resp, super::MpmcResponseVariant::Next(Ok(v)) if v == vec![vec![1u8, 2u8, 3u8]]) + ); + let resp = queue.len().unwrap().get().unwrap(); + assert!(matches!(resp, super::MpmcResponseVariant::Length(1))); + _ = mpmc::MpmcQueue::remove_db(&path); + } + + #[test] + fn mpmc_queue_size() { + let path = "/tmp/test_mpmc_size".to_string(); + _ = mpmc::MpmcQueue::remove_db(&path); + let queue = super::MpmcQueue::new(&path, Duration::from_secs(1), 1000).unwrap(); + let size_query = queue.disk_size().unwrap(); + let size = size_query.get().unwrap(); + assert!(matches!(size, super::MpmcResponseVariant::Size(Ok(r)) if r > 0)); + _ = mpmc::MpmcQueue::remove_db(&path); + } } diff --git a/queue_rs/src/utilities.rs b/queue_rs/src/utilities.rs new file mode 100644 index 0000000..c8d2931 --- /dev/null +++ b/queue_rs/src/utilities.rs @@ -0,0 +1,38 @@ +use crate::{MAX_ALLOWED_INDEX, U64_BYTE_LEN}; +use chrono::Utc; + +pub fn u64_from_byte_vec(v: &[u8]) -> u64 { + let mut buf = [0u8; U64_BYTE_LEN]; + buf.copy_from_slice(v); + u64::from_le_bytes(buf) +} + +pub fn index_to_key(index: u64) -> [u8; U64_BYTE_LEN] { + index.to_le_bytes() +} + +pub fn key_to_index(v: Box<[u8]>) -> u64 { + let mut buf = [0u8; U64_BYTE_LEN]; + buf.copy_from_slice(&v); + u64::from_le_bytes(buf) +} + +pub fn next_index(index: u64) -> u64 { + let mut next = index + 1; + if next == MAX_ALLOWED_INDEX { + next = 0; + } + next +} + +pub fn previous_index(index: u64) -> u64 { + if index == 0 { + MAX_ALLOWED_INDEX - 1 + } else { + index - 1 + } +} + +pub fn current_timestamp() -> u64 { + Utc::now().timestamp_nanos_opt().unwrap() as u64 +}