From e07d74d027eea20dc989be343b43016d16f1116e Mon Sep 17 00:00:00 2001 From: Joseph Perez Date: Wed, 11 Dec 2024 18:03:59 +0100 Subject: [PATCH] refactor: extract the caching logic into dedicated types --- zenoh/src/api/builders/publisher.rs | 95 ++++++++++------------ zenoh/src/api/publisher.rs | 120 +++++++++++++++++++++------- zenoh/src/api/session.rs | 60 +++++--------- 3 files changed, 157 insertions(+), 118 deletions(-) diff --git a/zenoh/src/api/builders/publisher.rs b/zenoh/src/api/builders/publisher.rs index 1264cf39d6..8d6931260e 100644 --- a/zenoh/src/api/builders/publisher.rs +++ b/zenoh/src/api/builders/publisher.rs @@ -11,10 +11,7 @@ // Contributors: // ZettaScale Zenoh Team, // -use std::{ - future::{IntoFuture, Ready}, - sync::atomic::AtomicU64, -}; +use std::future::{IntoFuture, Ready}; use itertools::Itertools; use zenoh_config::qos::PublisherQoSConfig; @@ -34,8 +31,8 @@ use crate::{ bytes::{OptionZBytes, ZBytes}, encoding::Encoding, key_expr::KeyExpr, - publisher::{Priority, Publisher}, - sample::{Locality, SampleKind}, + publisher::{Priority, Publisher, PublisherCache, PublisherCacheValue}, + sample::Locality, }, Session, }; @@ -212,12 +209,10 @@ impl Wait for PublicationBuilder, PublicationBuilderPut #[inline] fn wait(mut self) -> ::To { self.publisher = self.publisher.apply_qos_overwrites(); - self.publisher.session.0.resolve_put( - None, + self.publisher.session.0.resolve_push( + &mut PublisherCacheValue::default(), &self.publisher.key_expr?, - self.kind.payload, - SampleKind::Put, - self.kind.encoding, + Some(self.kind), self.publisher.congestion_control, self.publisher.priority, self.publisher.is_express, @@ -236,12 +231,10 @@ impl Wait for PublicationBuilder, PublicationBuilderDel #[inline] fn wait(mut self) -> ::To { self.publisher = self.publisher.apply_qos_overwrites(); - self.publisher.session.0.resolve_put( - None, + self.publisher.session.0.resolve_push( + &mut PublisherCacheValue::default(), &self.publisher.key_expr?, - ZBytes::new(), - SampleKind::Delete, - Encoding::ZENOH_BYTES, + None, self.publisher.congestion_control, self.publisher.priority, self.publisher.is_express, @@ -446,7 +439,7 @@ impl Wait for PublisherBuilder<'_, '_> { .declare_publisher_inner(key_expr.clone(), self.destination)?; Ok(Publisher { session: self.session.downgrade(), - cache: AtomicU64::new(0), + cache: PublisherCache::default(), id, key_expr, encoding: self.encoding, @@ -474,45 +467,45 @@ impl IntoFuture for PublisherBuilder<'_, '_> { impl Wait for PublicationBuilder<&Publisher<'_>, PublicationBuilderPut> { fn wait(self) -> ::To { - self.publisher.session.resolve_put( - Some(&self.publisher.cache), - &self.publisher.key_expr, - self.kind.payload, - SampleKind::Put, - self.kind.encoding, - self.publisher.congestion_control, - self.publisher.priority, - self.publisher.is_express, - self.publisher.destination, - #[cfg(feature = "unstable")] - self.publisher.reliability, - self.timestamp, - #[cfg(feature = "unstable")] - self.source_info, - self.attachment, - ) + self.publisher.cache.with_cache(|cached| { + self.publisher.session.resolve_push( + cached, + &self.publisher.key_expr, + Some(self.kind), + self.publisher.congestion_control, + self.publisher.priority, + self.publisher.is_express, + self.publisher.destination, + #[cfg(feature = "unstable")] + self.publisher.reliability, + self.timestamp, + #[cfg(feature = "unstable")] + self.source_info, + self.attachment, + ) + }) } } impl Wait for PublicationBuilder<&Publisher<'_>, PublicationBuilderDelete> { fn wait(self) -> ::To { - self.publisher.session.resolve_put( - Some(&self.publisher.cache), - &self.publisher.key_expr, - ZBytes::new(), - SampleKind::Delete, - Encoding::ZENOH_BYTES, - self.publisher.congestion_control, - self.publisher.priority, - self.publisher.is_express, - self.publisher.destination, - #[cfg(feature = "unstable")] - self.publisher.reliability, - self.timestamp, - #[cfg(feature = "unstable")] - self.source_info, - self.attachment, - ) + self.publisher.cache.with_cache(|cached| { + self.publisher.session.resolve_push( + cached, + &self.publisher.key_expr, + None, + self.publisher.congestion_control, + self.publisher.priority, + self.publisher.is_express, + self.publisher.destination, + #[cfg(feature = "unstable")] + self.publisher.reliability, + self.timestamp, + #[cfg(feature = "unstable")] + self.source_info, + self.attachment, + ) + }) } } diff --git a/zenoh/src/api/publisher.rs b/zenoh/src/api/publisher.rs index 5d9bb4f13f..4aefe96952 100644 --- a/zenoh/src/api/publisher.rs +++ b/zenoh/src/api/publisher.rs @@ -17,7 +17,7 @@ use std::{ fmt, future::{IntoFuture, Ready}, pin::Pin, - sync::atomic::AtomicU64, + sync::atomic::{AtomicU64, Ordering}, task::{Context, Poll}, }; @@ -41,17 +41,20 @@ use { zenoh_protocol::core::Reliability, }; -use crate::api::{ - builders::publisher::{ - PublicationBuilder, PublicationBuilderDelete, PublicationBuilderPut, - PublisherDeleteBuilder, PublisherPutBuilder, +use crate::{ + api::{ + builders::publisher::{ + PublicationBuilder, PublicationBuilderDelete, PublicationBuilderPut, + PublisherDeleteBuilder, PublisherPutBuilder, + }, + bytes::ZBytes, + encoding::Encoding, + key_expr::KeyExpr, + sample::{Locality, Sample, SampleFields}, + session::{UndeclarableSealed, WeakSession}, + Id, }, - bytes::ZBytes, - encoding::Encoding, - key_expr::KeyExpr, - sample::{Locality, Sample, SampleFields}, - session::{UndeclarableSealed, WeakSession}, - Id, + sample::SampleKind, }; pub(crate) struct PublisherState { @@ -70,6 +73,74 @@ impl fmt::Debug for PublisherState { } } +#[derive(Default)] +pub(crate) struct PublisherCache(AtomicU64); + +impl PublisherCache { + pub(crate) fn with_cache(&self, f: impl FnOnce(&mut PublisherCacheValue) -> R) -> R { + let cached = self.0.load(Ordering::Relaxed); + let mut to_cache = PublisherCacheValue(cached); + let res = f(&mut to_cache); + if to_cache.0 != cached { + let _ = self.0.compare_exchange_weak( + cached, + to_cache.0, + Ordering::Relaxed, + Ordering::Relaxed, + ); + } + res + } +} + +impl fmt::Debug for PublisherCache { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_tuple("PublisherCache") + .field(&PublisherCacheValue(self.0.load(Ordering::Relaxed))) + .finish() + } +} +#[derive(Default, PartialEq, Eq)] +pub(crate) struct PublisherCacheValue(u64); + +impl PublisherCacheValue { + const VERSION_SHIFT: usize = 2; + const NO_REMOTE: u64 = 0b01; + const NO_LOCAL: u64 = 0b10; + + pub(crate) fn match_subscription_version(&mut self, version: u64) { + if self.0 >> Self::VERSION_SHIFT != version { + self.0 = version << Self::VERSION_SHIFT; + } + } + + pub(crate) fn has_remote_sub(&self) -> bool { + self.0 & Self::NO_REMOTE == 0 + } + + pub(crate) fn set_no_remote_sub(&mut self) { + self.0 |= Self::NO_REMOTE; + } + + pub(crate) fn has_local_sub(&self) -> bool { + self.0 & Self::NO_LOCAL == 0 + } + + pub(crate) fn set_no_local_sub(&mut self) { + self.0 |= Self::NO_LOCAL; + } +} + +impl fmt::Debug for PublisherCacheValue { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("PublisherCacheValue") + .field("subscription_version", &(self.0 >> Self::VERSION_SHIFT)) + .field("has_remote_sub", &self.has_remote_sub()) + .field("has_local_sub", &self.has_local_sub()) + .finish() + } +} + /// A publisher that allows to send data through a stream. /// /// Publishers are automatically undeclared when dropped. @@ -102,7 +173,7 @@ impl fmt::Debug for PublisherState { #[derive(Debug)] pub struct Publisher<'a> { pub(crate) session: WeakSession, - pub(crate) cache: AtomicU64, + pub(crate) cache: PublisherCache, pub(crate) id: Id, pub(crate) key_expr: KeyExpr<'a>, pub(crate) encoding: Encoding, @@ -387,23 +458,14 @@ impl Sink for Publisher<'_> { attachment, .. } = item.into(); - self.session.resolve_put( - Some(&self.cache), - &self.key_expr, - payload, - kind, - encoding, - self.congestion_control, - self.priority, - self.is_express, - self.destination, - #[cfg(feature = "unstable")] - self.reliability, - None, - #[cfg(feature = "unstable")] - SourceInfo::empty(), - attachment, - ) + match kind { + SampleKind::Put => self + .put(payload) + .encoding(encoding) + .attachment(attachment) + .wait(), + SampleKind::Delete => self.delete().attachment(attachment).wait(), + } } #[inline] diff --git a/zenoh/src/api/session.rs b/zenoh/src/api/session.rs index bdc0be2495..6c87382976 100644 --- a/zenoh/src/api/session.rs +++ b/zenoh/src/api/session.rs @@ -19,7 +19,7 @@ use std::{ fmt, ops::Deref, sync::{ - atomic::{AtomicU16, AtomicU64, Ordering}, + atomic::{AtomicU16, Ordering}, Arc, Mutex, RwLock, }, time::{Duration, SystemTime, UNIX_EPOCH}, @@ -99,7 +99,7 @@ use crate::{ handlers::{Callback, DefaultHandler}, info::SessionInfo, key_expr::{KeyExpr, KeyExprInner}, - publisher::{Priority, PublisherState}, + publisher::{Priority, PublisherCacheValue, PublisherState}, query::{ConsolidationMode, QueryConsolidation, QueryState, QueryTarget, Reply}, queryable::{Query, QueryInner, QueryableState}, sample::{DataInfo, DataInfoIntoSample, Locality, QoS, Sample, SampleKind}, @@ -2137,13 +2137,11 @@ impl SessionInner { #[allow(clippy::too_many_arguments)] // TODO fixme #[inline(always)] - pub(crate) fn resolve_put( + pub(crate) fn resolve_push( &self, - cache: Option<&AtomicU64>, + cache: &mut PublisherCacheValue, key_expr: &KeyExpr, - payload: ZBytes, - kind: SampleKind, - encoding: Encoding, + mut put: Option, congestion_control: CongestionControl, priority: Priority, is_express: bool, @@ -2153,9 +2151,6 @@ impl SessionInner { #[cfg(feature = "unstable")] source_info: SourceInfo, attachment: Option, ) -> ZResult<()> { - const NO_REMOTE_FLAG: u64 = 0b01; - const NO_LOCAL_FLAG: u64 = 0b10; - const VERSION_SHIFT: u64 = 2; trace!("write({:?}, [...])", key_expr); let state = zread!(self.state); let primitives = state @@ -2163,26 +2158,14 @@ impl SessionInner { .as_ref() .cloned() .ok_or(SessionClosedError)?; - let version = state.subscription_version; + cache.match_subscription_version(state.subscription_version); drop(state); - let mut cached = 0; - let mut update_cache = None; - if let Some(cache) = cache { - let c = cache.load(Ordering::Relaxed); - if (c >> VERSION_SHIFT) == version { - cached = c; - } else { - cached = version << VERSION_SHIFT; - } - update_cache = Some(move |cached| { - if cached != c { - let _ = cache.compare_exchange(c, cached, Ordering::Relaxed, Ordering::Relaxed); - } - }); - } let timestamp = timestamp.or_else(|| self.runtime.new_timestamp()); let wire_expr = key_expr.to_wire(self); - if (cached & NO_REMOTE_FLAG) == 0 && destination != Locality::SessionLocal { + let push_remote = cache.has_remote_sub() && destination != Locality::SessionLocal; + let push_local = cache.has_local_sub() && destination != Locality::Remote; + if push_remote { + let put = if push_local { put.clone() } else { put.take() }; let remote = primitives.route_data( Push { wire_expr: wire_expr.to_owned(), @@ -2193,10 +2176,10 @@ impl SessionInner { ), ext_tstamp: None, ext_nodeid: push::ext::NodeIdType::DEFAULT, - payload: match kind { - SampleKind::Put => PushBody::Put(Put { + payload: match put.clone() { + Some(put) => PushBody::Put(Put { timestamp, - encoding: encoding.clone().into(), + encoding: put.encoding.into(), #[cfg(feature = "unstable")] ext_sinfo: source_info.into(), #[cfg(not(feature = "unstable"))] @@ -2205,9 +2188,9 @@ impl SessionInner { ext_shm: None, ext_attachment: attachment.clone().map(|a| a.into()), ext_unknown: vec![], - payload: payload.clone().into(), + payload: put.payload.into(), }), - SampleKind::Delete => PushBody::Del(Del { + None => PushBody::Del(Del { timestamp, #[cfg(feature = "unstable")] ext_sinfo: source_info.into(), @@ -2224,10 +2207,14 @@ impl SessionInner { Reliability::DEFAULT, ); if !remote { - cached |= NO_REMOTE_FLAG + cache.set_no_remote_sub(); } } - if (cached & NO_LOCAL_FLAG) == 0 && destination != Locality::Remote { + if push_local { + let (kind, payload, encoding) = match put { + Some(put) => (SampleKind::Put, put.payload, put.encoding), + None => (SampleKind::Delete, ZBytes::default(), Encoding::default()), + }; let data_info = DataInfo { kind, encoding: Some(encoding), @@ -2252,12 +2239,9 @@ impl SessionInner { attachment, ); if !local { - cached |= NO_LOCAL_FLAG; + cache.set_no_local_sub(); } } - if let Some(update) = update_cache { - update(cached); - } Ok(()) }