From 523947b60e99fd3f8c5ae939fcad358b9bf8c893 Mon Sep 17 00:00:00 2001 From: Will Nelson Date: Sat, 5 Feb 2022 10:23:28 -0800 Subject: [PATCH] Refactor for RPC call timeouts --- Cargo.lock | 1 + brokers/Cargo.toml | 4 +- brokers/src/error.rs | 2 +- brokers/src/redis.rs | 252 +++++++++++++++++++---------------- brokers/src/redis/message.rs | 94 +++++++++++++ brokers/src/redis/pubsub.rs | 95 +++++++++++++ brokers/src/redis/rpc.rs | 38 ++++++ 7 files changed, 371 insertions(+), 115 deletions(-) create mode 100644 brokers/src/redis/message.rs create mode 100644 brokers/src/redis/pubsub.rs create mode 100644 brokers/src/redis/rpc.rs diff --git a/Cargo.lock b/Cargo.lock index 2e14019..b1d6405 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1733,6 +1733,7 @@ dependencies = [ "lapin", "log", "nanoid", + "pin-project 1.0.8", "redis", "redis-subscribe", "rmp-serde", diff --git a/brokers/Cargo.toml b/brokers/Cargo.toml index 51259a6..3d86ae2 100644 --- a/brokers/Cargo.toml +++ b/brokers/Cargo.toml @@ -17,6 +17,7 @@ env_logger = "0.7" futures = "0.3" log = "0.4" nanoid = "0.3" +pin-project = "1.0" redis = { version = "0.21", optional = true, default-features = false, features = ["streams"] } redis-subscribe = { git = "https://github.com/appellation/redis-subscribe", branch = "feat/impl-error", optional = true } rmp-serde = "0.15" @@ -27,13 +28,14 @@ tokio-stream = { version = "0.1", features = ["sync"] } [dependencies.tokio] version = "1.0" features = ["rt-multi-thread"] +optional = true [dev-dependencies.tokio] version = "1.0" features = ["rt-multi-thread", "macros"] [features] -amqp-broker = ["lapin"] +amqp-broker = ["lapin", "tokio"] redis-broker = ["deadpool-redis", "redis", "redis-subscribe"] [[example]] diff --git a/brokers/src/error.rs b/brokers/src/error.rs index ea796ed..bab1c59 100644 --- a/brokers/src/error.rs +++ b/brokers/src/error.rs @@ -2,11 +2,11 @@ use deadpool_redis::{redis::RedisError, PoolError}; #[cfg(feature = "amqp-broker")] use lapin::Error as LapinError; -use tokio_stream::wrappers::errors::BroadcastStreamRecvError; use std::{io::Error as IoError, result::Result as StdResult}; use thiserror::Error; #[cfg(feature = "amqp-broker")] use tokio::sync::oneshot::error::RecvError; +use tokio_stream::wrappers::errors::BroadcastStreamRecvError; pub type Result = StdResult; diff --git a/brokers/src/redis.rs b/brokers/src/redis.rs index ffd17e7..f9d3f6c 100644 --- a/brokers/src/redis.rs +++ b/brokers/src/redis.rs @@ -1,91 +1,41 @@ -use std::{borrow::Cow, sync::Arc}; +use std::{ + borrow::Cow, + fmt::{self, Debug}, + time::{SystemTime, UNIX_EPOCH}, +}; pub use deadpool_redis; use deadpool_redis::{ redis::{ - streams::{StreamId, StreamRangeReply, StreamReadOptions, StreamReadReply}, + streams::{StreamRangeReply, StreamReadOptions, StreamReadReply}, AsyncCommands, FromRedisValue, RedisError, Value, }, Connection, Pool, }; use futures::{ stream::{iter, select_all}, - stream_select, StreamExt, TryStream, TryStreamExt, + stream_select, TryStream, TryStreamExt, }; use nanoid::nanoid; -use redis_subscribe::RedisSub; use serde::{de::DeserializeOwned, Serialize}; -use tokio::{spawn, sync::broadcast}; -use tokio_stream::wrappers::BroadcastStream; use crate::{ error::{Error, Result}, util::stream::repeat_fn, }; +use self::{message::Message, pubsub::BroadcastSub, rpc::Rpc}; + +pub mod message; +pub mod pubsub; +pub mod rpc; + const DEFAULT_MAX_CHUNK: usize = 10; const DEFAULT_BLOCK_INTERVAL: usize = 5000; const STREAM_DATA_KEY: &'static str = "data"; +const STREAM_TIMEOUT_KEY: &'static str = "timeout_at"; -/// A message received from the broker. -#[derive(Clone)] -pub struct Message<'a, V> { - /// The group this message belongs to. - pub group: &'a str, - /// The event this message signals. - pub event: Cow<'a, str>, - /// The ID of this message (generated by Redis). - pub id: String, - /// The data of this message. Always present unless there is a bug with a client implementation. - pub data: Option, - pool: &'a Pool, -} - -impl<'a, V> Message<'a, V> -where - V: DeserializeOwned, -{ - fn new(id: StreamId, group: &'a str, event: Cow<'a, str>, pool: &'a Pool) -> Self { - let data = id - .get(STREAM_DATA_KEY) - .and_then(|data: Vec| rmp_serde::from_read_ref(&data).ok()); - - Message { - group, - event, - id: id.id, - pool, - data, - } - } -} - -impl<'a, V> Message<'a, V> { - /// Acknowledge receipt of the message. This should always be called, since un-acked messages - /// will be reclaimed by other clients. - pub async fn ack(&self) -> Result<()> { - self.pool - .get() - .await? - .xack(&*self.event, self.group, &[&self.id]) - .await?; - - Ok(()) - } - - /// Reply to this message. - pub async fn reply(&self, data: &impl Serialize) -> Result<()> { - let key = format!("{}:{}", self.event, self.id); - let serialized = rmp_serde::to_vec(data)?; - self.pool.get().await?.publish(key, serialized).await?; - - Ok(()) - } -} - -pub struct Rpc; - -// #[derive(Debug)] +/// RedisBroker is internally reference counted and can be safely cloned. pub struct RedisBroker<'a> { /// The consumer name of this broker. Should be unique to the container/machine consuming /// messages. @@ -99,32 +49,52 @@ pub struct RedisBroker<'a> { /// time period will be reclaimed by other clients. pub max_operation_time: usize, pool: Pool, - pubsub: Arc, - pubsub_msgs: broadcast::Sender>, + pubsub: BroadcastSub, read_opts: StreamReadOptions, } +impl<'a> Clone for RedisBroker<'a> { + fn clone(&self) -> Self { + Self { + name: self.name.clone(), + group: self.group.clone(), + max_chunk: self.max_chunk, + max_operation_time: self.max_operation_time, + pool: self.pool.clone(), + pubsub: self.pubsub.clone(), + read_opts: Self::make_read_opts(&*self.group, &*self.name), + } + } +} + +impl<'a> Debug for RedisBroker<'a> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("RedisBroker") + .field("name", &self.name) + .field("group", &self.group) + .field("max_chunk", &self.max_chunk) + .field("max_operation_time", &self.max_operation_time) + .field("pubsub", &self.pubsub) + .field("read_opts", &self.read_opts) + .finish_non_exhaustive() + } +} + impl<'a> RedisBroker<'a> { + fn make_read_opts(group: &str, name: &str) -> StreamReadOptions { + StreamReadOptions::default() + .group(group, name) + .count(DEFAULT_MAX_CHUNK) + .block(DEFAULT_BLOCK_INTERVAL) + } + /// Creates a new broker with sensible defaults. pub fn new(group: impl Into>, pool: Pool, address: &str) -> RedisBroker<'a> { let group = group.into(); let name = nanoid!(); - let read_opts = StreamReadOptions::default() - .group(&*group, &name) - .count(DEFAULT_MAX_CHUNK) - .block(DEFAULT_BLOCK_INTERVAL); + let read_opts = RedisBroker::make_read_opts(&*group, &name); - let pubsub = Arc::new(RedisSub::new(&address)); - - let (tx, _) = broadcast::channel(1); - let task_pubsub = Arc::clone(&pubsub); - let task_tx = tx.clone(); - spawn(async move { - let mut stream = task_pubsub.listen().await.unwrap(); - while let Some(msg) = stream.next().await { - task_tx.send(Arc::new(msg)).unwrap(); - } - }); + let pubsub = BroadcastSub::new(address); Self { name: Cow::Owned(name), @@ -133,47 +103,56 @@ impl<'a> RedisBroker<'a> { max_operation_time: DEFAULT_BLOCK_INTERVAL, pool, pubsub, - pubsub_msgs: tx, read_opts, } } /// Publishes an event to the broker. Returned value is the ID of the message. pub async fn publish(&self, event: &str, data: &impl Serialize) -> Result { - let serialized = rmp_serde::to_vec(data)?; - Ok(self - .get_conn() - .await? - .xadd(event, "*", &[(STREAM_DATA_KEY, serialized)]) - .await?) + self.publish_timeout(event, data, None).await } - pub async fn call(&self, event: &str, data: &impl Serialize) -> Result> - where - V: DeserializeOwned, - { - let id = self.publish(event, data).await?; + pub async fn call( + &self, + event: &str, + data: &impl Serialize, + timeout: Option, + ) -> Result> { + let id = self.publish_timeout(event, data, timeout).await?; let name = format!("{}:{}", event, id); - self.pubsub.subscribe(name.clone()).await?; + Ok(Rpc { + name, + broker: &self, + }) + } - let data = BroadcastStream::new(self.pubsub_msgs.subscribe()) - .err_into::() - .try_filter_map(|msg| async move { - match &*msg { - redis_subscribe::Message::Message { message, .. } => { - Ok(rmp_serde::from_read(message.as_bytes())?) - } - _ => Ok(None), - } - }) - .boxed() - .next() - .await - .transpose()?; + async fn publish_timeout( + &self, + event: &str, + data: &impl Serialize, + maybe_timeout: Option, + ) -> Result { + let serialized_data = rmp_serde::to_vec(data)?; + let mut conn = self.get_conn().await?; + + let args = match maybe_timeout { + Some(timeout) => vec![ + (STREAM_DATA_KEY, serialized_data), + ( + STREAM_TIMEOUT_KEY, + timeout + .duration_since(UNIX_EPOCH) + .unwrap() + .as_nanos() + .to_string() + .into_bytes(), + ), + ], + None => vec![(STREAM_DATA_KEY, serialized_data)], + }; - self.pubsub.unsubscribe(name).await?; - Ok(data) + Ok(conn.xadd(event, "*", &args).await?) } pub async fn subscribe(&self, events: &[&str]) -> Result<()> { @@ -230,7 +209,7 @@ impl<'a> RedisBroker<'a> { id, &group, Cow::Borrowed(event), - pool, + self, )) }); @@ -257,7 +236,7 @@ impl<'a> RedisBroker<'a> { move |event| { let key = Cow::from(event.key); event.ids.into_iter().map(move |id| { - Ok(Message::::new(id, group, key.clone(), pool)) + Ok(Message::::new(id, group, key.clone(), self)) }) }, ); @@ -278,8 +257,12 @@ impl<'a> RedisBroker<'a> { #[cfg(test)] mod test { + use std::time::{Duration, SystemTime}; + use deadpool_redis::{Manager, Pool}; use futures::TryStreamExt; + use redis::cmd; + use tokio::{spawn, try_join}; use super::RedisBroker; @@ -308,4 +291,47 @@ mod test { assert_eq!(msg.data.expect("data"), vec![1, 2, 3]); } + + #[tokio::test] + async fn rpc_timeout() { + let group = "foo"; + let manager = Manager::new("redis://localhost:6379").expect("create manager"); + let pool = Pool::new(manager, 32); + + let _: () = cmd("FLUSHDB") + .query_async(&mut pool.get().await.expect("redis connection")) + .await + .expect("flush db"); + + let broker1 = RedisBroker::new(group, pool, "localhost:6379"); + let broker2 = broker1.clone(); + + let events = ["def"]; + broker1.subscribe(&events).await.expect("subscribed"); + + let timeout = Some(SystemTime::now() + Duration::from_millis(500)); + + let call_fut = spawn(async move { + broker2 + .call("def", &[1u8, 2, 3], timeout) + .await + .expect("published"); + }); + + let consume_fut = spawn(async move { + let mut consumer = broker1.consume::>(&events); + let msg = consumer + .try_next() + .await + .expect("message") + .expect("message"); + + msg.ack().await.expect("ack"); + + assert_eq!(msg.data.as_ref().expect("data"), &[1, 2, 3]); + assert_eq!(msg.timeout_at, timeout); + }); + + try_join!(consume_fut, call_fut).expect("cancelation futures"); + } } diff --git a/brokers/src/redis/message.rs b/brokers/src/redis/message.rs new file mode 100644 index 0000000..69f23c5 --- /dev/null +++ b/brokers/src/redis/message.rs @@ -0,0 +1,94 @@ +use std::{ + borrow::Cow, + time::{Duration, SystemTime, UNIX_EPOCH}, +}; + +use redis::{streams::StreamId, AsyncCommands}; +use serde::{de::DeserializeOwned, Serialize}; + +use crate::error::Result; + +use super::{RedisBroker, STREAM_DATA_KEY, STREAM_TIMEOUT_KEY}; + +/// A message received from the broker. +#[derive(Debug, Clone)] +pub struct Message<'broker, V> { + /// The group this message belongs to. + pub group: &'broker str, + /// The event this message signals. + pub event: Cow<'broker, str>, + /// The ID of this message (generated by Redis). + pub id: String, + /// The data of this message. Always present unless there is a bug with a client implementation. + pub data: Option, + /// When this message times out. Clients should cancel work if it is still in progress after + /// this instant. + pub timeout_at: Option, + broker: &'broker RedisBroker<'broker>, +} + +impl<'broker, V> PartialEq for Message<'broker, V> { + fn eq(&self, other: &Self) -> bool { + self.id == other.id + } +} + +impl<'broker, V> Eq for Message<'broker, V> {} + +impl<'broker, V> Message<'broker, V> +where + V: DeserializeOwned, +{ + pub(crate) fn new( + id: StreamId, + group: &'broker str, + event: Cow<'broker, str>, + broker: &'broker RedisBroker, + ) -> Self { + let data = id + .get(STREAM_DATA_KEY) + .and_then(|data: Vec| rmp_serde::from_read_ref(&data).ok()); + + let timeout_at = id + .get(STREAM_TIMEOUT_KEY) + .map(|timeout| UNIX_EPOCH + Duration::from_nanos(timeout)); + + Message { + group, + event, + id: id.id, + data, + timeout_at, + broker, + } + } +} + +impl<'broker, V> Message<'broker, V> { + /// Acknowledge receipt of the message. This should always be called, since un-acked messages + /// will be reclaimed by other clients. + pub async fn ack(&self) -> Result<()> { + self.broker + .pool + .get() + .await? + .xack(&*self.event, self.group, &[&self.id]) + .await?; + + Ok(()) + } + + /// Reply to this message. + pub async fn reply(&self, data: &impl Serialize) -> Result<()> { + let key = format!("{}:{}", self.event, self.id); + let serialized = rmp_serde::to_vec(data)?; + self.broker + .pool + .get() + .await? + .publish(key, serialized) + .await?; + + Ok(()) + } +} diff --git a/brokers/src/redis/pubsub.rs b/brokers/src/redis/pubsub.rs new file mode 100644 index 0000000..95e3560 --- /dev/null +++ b/brokers/src/redis/pubsub.rs @@ -0,0 +1,95 @@ +use std::{ + pin::Pin, + sync::Arc, + task::{Context, Poll}, +}; + +use futures::{future::ready, Stream, StreamExt, TryStream, TryStreamExt}; +use pin_project::{pin_project, pinned_drop}; +use redis_subscribe::{Message, RedisSub}; +use tokio::{spawn, sync::broadcast}; +use tokio_stream::wrappers::{errors::BroadcastStreamRecvError, BroadcastStream}; + +use crate::error::{Error, Result}; + +#[derive(Debug, Clone)] +pub struct BroadcastSub { + pubsub: Arc, + pubsub_msgs: broadcast::Sender>, +} + +impl BroadcastSub { + pub fn new(addr: &str) -> Self { + let pubsub = Arc::new(RedisSub::new(addr)); + + let (tx, _) = broadcast::channel(1); + let task_pubsub = Arc::clone(&pubsub); + let task_tx = tx.clone(); + spawn(async move { + let mut stream = task_pubsub.listen().await.unwrap(); + while let Some(msg) = stream.next().await { + let _ = task_tx.send(Arc::new(msg)); + } + }); + + Self { + pubsub, + pubsub_msgs: tx, + } + } + + pub async fn subscribe( + &self, + channel: String, + ) -> Result> { + self.pubsub.subscribe(channel.clone()).await?; + + let stream = SubStream { + pubsub: Arc::clone(&self.pubsub), + channel: channel.clone(), + msgs: BroadcastStream::new(self.pubsub_msgs.subscribe()), + }; + + Ok(stream.err_into::().try_filter_map(move |msg| { + ready(match &*msg { + Message::Message { + channel: new_ch, + message, + } if &channel == new_ch => Ok(Some(message.clone())), + _ => Ok(None), + }) + })) + } +} + +#[derive(Debug)] +#[pin_project(PinnedDrop)] +struct SubStream { + pubsub: Arc, + channel: String, + #[pin] + msgs: BroadcastStream, +} + +impl Stream for SubStream +where + T: 'static + Clone + Send, +{ + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.project(); + this.msgs.poll_next(cx) + } +} + +#[pinned_drop] +impl PinnedDrop for SubStream { + fn drop(self: Pin<&mut Self>) { + let client = Arc::clone(&self.pubsub); + let channel = self.channel.clone(); + spawn(async move { + client.unsubscribe(channel).await.unwrap(); + }); + } +} diff --git a/brokers/src/redis/rpc.rs b/brokers/src/redis/rpc.rs new file mode 100644 index 0000000..8e18e28 --- /dev/null +++ b/brokers/src/redis/rpc.rs @@ -0,0 +1,38 @@ +use futures::TryStreamExt; +use serde::de::DeserializeOwned; + +use crate::error::Result; + +use super::RedisBroker; + +/// A Remote Procedure Call. Poll the future returned by `response` to get the response value. +#[derive(Debug, Clone)] +pub struct Rpc<'broker> { + pub(crate) name: String, + pub(crate) broker: &'broker RedisBroker<'broker>, +} + +impl<'broker> PartialEq for Rpc<'broker> { + fn eq(&self, other: &Self) -> bool { + self.name == other.name + } +} + +impl<'broker> Eq for Rpc<'broker> {} + +impl<'broker> Rpc<'broker> { + pub async fn response(&self) -> Result> + where + V: DeserializeOwned, + { + Ok(self + .broker + .pubsub + .subscribe(self.name.clone()) + .await? + .try_next() + .await? + .map(|msg| rmp_serde::from_read(msg.as_bytes())) + .transpose()?) + } +}