diff --git a/Cargo.lock b/Cargo.lock index b7d5b938a..8e784cf9d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -6275,6 +6275,7 @@ dependencies = [ "thiserror", "tls_codec 0.4.1", "tokio", + "tokio-stream", "toml 0.8.13", "tracing", "tracing-log", diff --git a/bindings_ffi/Cargo.lock b/bindings_ffi/Cargo.lock index 8bce151aa..d2325e0f0 100644 --- a/bindings_ffi/Cargo.lock +++ b/bindings_ffi/Cargo.lock @@ -5823,6 +5823,7 @@ dependencies = [ "thiserror", "tls_codec 0.4.0", "tokio", + "tokio-stream", "toml 0.8.8", "xmtp_cryptography", "xmtp_id", diff --git a/bindings_node/Cargo.lock b/bindings_node/Cargo.lock index 0de3f366b..62081d3de 100644 --- a/bindings_node/Cargo.lock +++ b/bindings_node/Cargo.lock @@ -2616,7 +2616,7 @@ dependencies = [ [[package]] name = "openmls" version = "0.5.0" -source = "git+https://github.com/xmtp/openmls?rev=606bf92#606bf929e133422fe9737ba7089f6e63a4738300" +source = "git+https://github.com/xmtp/openmls?rev=99b2d5e7d0e034ac57644395e2194c5a102afb9a#99b2d5e7d0e034ac57644395e2194c5a102afb9a" dependencies = [ "backtrace", "itertools 0.10.5", @@ -2638,7 +2638,7 @@ dependencies = [ [[package]] name = "openmls_basic_credential" version = "0.2.0" -source = "git+https://github.com/xmtp/openmls?rev=606bf92#606bf929e133422fe9737ba7089f6e63a4738300" +source = "git+https://github.com/xmtp/openmls?rev=99b2d5e7d0e034ac57644395e2194c5a102afb9a#99b2d5e7d0e034ac57644395e2194c5a102afb9a" dependencies = [ "ed25519-dalek", "openmls_traits", @@ -2651,7 +2651,7 @@ dependencies = [ [[package]] name = "openmls_memory_storage" version = "0.2.0" -source = "git+https://github.com/xmtp/openmls?rev=606bf92#606bf929e133422fe9737ba7089f6e63a4738300" +source = "git+https://github.com/xmtp/openmls?rev=99b2d5e7d0e034ac57644395e2194c5a102afb9a#99b2d5e7d0e034ac57644395e2194c5a102afb9a" dependencies = [ "hex", "log", @@ -2664,7 +2664,7 @@ dependencies = [ [[package]] name = "openmls_rust_crypto" version = "0.2.0" -source = "git+https://github.com/xmtp/openmls?rev=606bf92#606bf929e133422fe9737ba7089f6e63a4738300" +source = "git+https://github.com/xmtp/openmls?rev=99b2d5e7d0e034ac57644395e2194c5a102afb9a#99b2d5e7d0e034ac57644395e2194c5a102afb9a" dependencies = [ "aes-gcm", "chacha20poly1305", @@ -2688,7 +2688,7 @@ dependencies = [ [[package]] name = "openmls_test" version = "0.1.0" -source = "git+https://github.com/xmtp/openmls?rev=606bf92#606bf929e133422fe9737ba7089f6e63a4738300" +source = "git+https://github.com/xmtp/openmls?rev=99b2d5e7d0e034ac57644395e2194c5a102afb9a#99b2d5e7d0e034ac57644395e2194c5a102afb9a" dependencies = [ "ansi_term", "openmls_rust_crypto", @@ -2703,7 +2703,7 @@ dependencies = [ [[package]] name = "openmls_traits" version = "0.2.0" -source = "git+https://github.com/xmtp/openmls?rev=606bf92#606bf929e133422fe9737ba7089f6e63a4738300" +source = "git+https://github.com/xmtp/openmls?rev=99b2d5e7d0e034ac57644395e2194c5a102afb9a#99b2d5e7d0e034ac57644395e2194c5a102afb9a" dependencies = [ "serde", "tls_codec 0.4.2-pre.1", diff --git a/xmtp_mls/Cargo.toml b/xmtp_mls/Cargo.toml index bbd2e5959..972109433 100644 --- a/xmtp_mls/Cargo.toml +++ b/xmtp_mls/Cargo.toml @@ -47,6 +47,7 @@ smart-default = "0.7.1" thiserror = { workspace = true } tls_codec = { workspace = true } tokio = { workspace = true, features = ["rt-multi-thread"] } +tokio-stream = "0.1" toml = "0.8.4" xmtp_cryptography = { workspace = true } xmtp_id = { path = "../xmtp_id" } diff --git a/xmtp_mls/src/subscriptions.rs b/xmtp_mls/src/subscriptions.rs index 7affc1855..11adeadb7 100644 --- a/xmtp_mls/src/subscriptions.rs +++ b/xmtp_mls/src/subscriptions.rs @@ -9,7 +9,14 @@ use std::{ use futures::{Stream, StreamExt}; use prost::Message; -use tokio::sync::oneshot::{self, Sender}; +use tokio::{ + sync::{ + mpsc::{self, UnboundedSender}, + oneshot::{self, Sender}, + }, + task::JoinHandle, +}; +use tokio_stream::wrappers::UnboundedReceiverStream; use xmtp_proto::xmtp::mls::api::v1::WelcomeMessage; use crate::{ @@ -242,6 +249,58 @@ where }) } + pub async fn stream_all_messages( + client: Arc>, + ) -> Result, ClientError> { + let mut handle; + + let (tx, rx) = mpsc::unbounded_channel(); + + client.sync_welcomes().await?; + + let current_groups = client.store().conn()?.find_groups(None, None, None, None)?; + + let mut group_id_to_info: HashMap, MessagesStreamInfo> = current_groups + .into_iter() + .map(|group| { + ( + group.id.clone(), + MessagesStreamInfo { + convo_created_at_ns: group.created_at_ns, + cursor: 0, + }, + ) + }) + .collect(); + handle = Self::relay_messages(client.clone(), tx.clone(), group_id_to_info.clone()); + + tokio::spawn(async move { + let client_pointer = client.clone(); + let mut convo_stream = Self::stream_conversations(&client_pointer).await?; + + while let Some(new_group) = convo_stream.next().await { + if group_id_to_info.contains_key(&new_group.group_id) { + continue; + } + + handle.abort(); + for info in group_id_to_info.values_mut() { + info.cursor = 0; + } + group_id_to_info.insert( + new_group.group_id, + MessagesStreamInfo { + convo_created_at_ns: new_group.created_at_ns, + cursor: 1, + }, + ); + handle = Self::relay_messages(client.clone(), tx.clone(), group_id_to_info.clone()); + } + }); + + Ok(UnboundedReceiverStream::new(rx)) + } + pub async fn stream_all_messages_with_callback( client: Arc>, callback: impl FnMut(StoredGroupMessage) + Send + Sync + 'static, @@ -313,6 +372,23 @@ where Ok(groups_stream_closer) } + + fn relay_messages( + client: Arc>, + tx: UnboundedSender, + group_id_to_info: HashMap, MessagesStreamInfo>, + ) -> JoinHandle> { + tokio::spawn(async move { + let mut stream = client.stream_messages(group_id_to_info).await?; + while let Some(message) = stream.next().await { + // an error can only mean the receiver has been dropped or closed + if tx.send(message).is_err() { + break; + } + } + Ok::<_, ClientError>(()) + }) + } } #[cfg(test)]