diff --git a/bindings_ffi/src/mls.rs b/bindings_ffi/src/mls.rs index 4cc709f0d..a66ad905f 100644 --- a/bindings_ffi/src/mls.rs +++ b/bindings_ffi/src/mls.rs @@ -255,6 +255,26 @@ impl FfiConversations { pub async fn stream( &self, callback: Box, + ) -> Result, GenericError> { + let client = self.inner_client.clone(); + let stream_closer = + RustXmtpClient::stream_conversations_with_callback(client.clone(), move |convo| { + callback.on_conversation(Arc::new(FfiGroup { + inner_client: client.clone(), + group_id: convo.group_id, + created_at_ns: convo.created_at_ns, + })) + })?; + + Ok(Arc::new(FfiStreamCloser { + close_fn: stream_closer.close_fn, + is_closed_atomic: stream_closer.is_closed_atomic, + })) + } + + pub async fn stream_all_messages( + &self, + message_callback: Box, ) -> Result, GenericError> { let inner_client = Arc::clone(&self.inner_client); let (close_sender, close_receiver) = oneshot::channel::<()>(); @@ -262,18 +282,15 @@ impl FfiConversations { let is_closed_clone = is_closed.clone(); tokio::spawn(async move { - let client = inner_client.as_ref(); - let mut stream = client.stream_conversations().await.unwrap(); + let mut stream = RustXmtpClient::stream_all_messages(inner_client) + .await + .unwrap(); let mut close_receiver = close_receiver; loop { tokio::select! { item = stream.next() => { match item { - Some(convo) => callback.on_conversation(Arc::new(FfiGroup { - inner_client: inner_client.clone(), - group_id: convo.group_id, - created_at_ns: convo.created_at_ns, - })), + Some(message) => message_callback.on_message(message.into()), None => break } } @@ -934,6 +951,51 @@ mod tests { assert!(stream.is_closed()); } + #[tokio::test(flavor = "multi_thread", worker_threads = 5)] + async fn test_stream_all_messages_unchanging_group_list() { + let alix = new_test_client().await; + let bo = new_test_client().await; + let caro = new_test_client().await; + + let alix_group = alix + .conversations() + .create_group(vec![caro.account_address()], None) + .await + .unwrap(); + + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + + let caro_group = caro + .conversations() + .create_group(vec![bo.account_address()], None) + .await + .unwrap(); + + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + + let stream_callback = RustStreamCallback::new(); + let stream = caro + .conversations() + .stream_all_messages(Box::new(stream_callback.clone())) + .await + .unwrap(); + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + + alix_group.send("first".as_bytes().to_vec()).await.unwrap(); + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + caro_group.send("second".as_bytes().to_vec()).await.unwrap(); + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + alix_group.send("third".as_bytes().to_vec()).await.unwrap(); + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + caro_group.send("fourth".as_bytes().to_vec()).await.unwrap(); + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + + assert_eq!(stream_callback.message_count(), 4); + stream.end(); + tokio::time::sleep(tokio::time::Duration::from_millis(5)).await; + assert!(stream.is_closed()); + } + // Disabling this flakey test until it's reliable #[tokio::test(flavor = "multi_thread", worker_threads = 5)] async fn test_streaming() { diff --git a/xmtp_mls/src/client.rs b/xmtp_mls/src/client.rs index 253ff10ea..8c91dda0f 100644 --- a/xmtp_mls/src/client.rs +++ b/xmtp_mls/src/client.rs @@ -1,6 +1,5 @@ -use std::{collections::HashSet, mem::Discriminant, pin::Pin}; +use std::{collections::HashSet, mem::Discriminant}; -use futures::{Stream, StreamExt}; use openmls::{ framing::{MlsMessageIn, MlsMessageInBody}, group::GroupEpoch, @@ -69,6 +68,8 @@ pub enum ClientError { KeyPackageVerification(#[from] KeyPackageVerificationError), #[error("syncing errors: {0:?}")] SyncingError(Vec), + #[error("Stream inconsistency error: {0}")] + StreamInconsistency(String), #[error("generic:{0}")] Generic(String), } @@ -468,55 +469,11 @@ where }) .collect()) } - - fn process_streamed_welcome( - &self, - welcome: WelcomeMessage, - ) -> Result, ClientError> { - let welcome_v1 = extract_welcome_message(welcome)?; - let conn = self.store.conn()?; - let provider = self.mls_provider(&conn); - - MlsGroup::create_from_encrypted_welcome( - self, - &provider, - welcome_v1.hpke_public_key.as_slice(), - welcome_v1.data, - ) - .map_err(|e| ClientError::Generic(e.to_string())) - } - - pub async fn stream_conversations( - &'a self, - ) -> Result> + Send + 'a>>, ClientError> { - let installation_key = self.installation_public_key(); - let id_cursor = 0; - - let subscription = self - .api_client - .subscribe_welcome_messages(installation_key, Some(id_cursor as u64)) - .await?; - - let stream = subscription - .map(|welcome_result| async { - let welcome = welcome_result?; - self.process_streamed_welcome(welcome) - }) - .filter_map(|res| async { - match res.await { - Ok(group) => Some(group), - Err(err) => { - log::error!("Error processing stream entry: {:?}", err); - None - } - } - }); - - Ok(Box::pin(stream)) - } } -fn extract_welcome_message(welcome: WelcomeMessage) -> Result { +pub(crate) fn extract_welcome_message( + welcome: WelcomeMessage, +) -> Result { match welcome.version { Some(WelcomeMessageVersion::V1(welcome)) => Ok(welcome), _ => Err(ClientError::Generic( @@ -551,7 +508,6 @@ fn has_active_installation(updates: &Vec) -> bool { #[cfg(test)] mod tests { - use futures::StreamExt; use xmtp_cryptography::utils::generate_local_wallet; use crate::{ @@ -733,21 +689,4 @@ mod tests { vec![1, 2, 3] ) } - - #[tokio::test] - async fn test_stream_welcomes() { - let alice = ClientBuilder::new_test_client(&generate_local_wallet()).await; - let bob = ClientBuilder::new_test_client(&generate_local_wallet()).await; - - let alice_bob_group = alice.create_group(None).unwrap(); - - let mut bob_stream = bob.stream_conversations().await.unwrap(); - alice_bob_group - .add_members(vec![bob.account_address()]) - .await - .unwrap(); - - let bob_received_groups = bob_stream.next().await.unwrap(); - assert_eq!(bob_received_groups.group_id, alice_bob_group.group_id); - } } diff --git a/xmtp_mls/src/groups/mod.rs b/xmtp_mls/src/groups/mod.rs index 846d1b8bd..93d965881 100644 --- a/xmtp_mls/src/groups/mod.rs +++ b/xmtp_mls/src/groups/mod.rs @@ -131,6 +131,16 @@ pub struct MlsGroup<'c, ApiClient> { client: &'c Client, } +impl<'c, ApiClient> Clone for MlsGroup<'c, ApiClient> { + fn clone(&self) -> Self { + Self { + client: self.client, + group_id: self.group_id.clone(), + created_at_ns: self.created_at_ns, + } + } +} + impl<'c, ApiClient> MlsGroup<'c, ApiClient> where ApiClient: XmtpMlsClient, @@ -356,6 +366,13 @@ fn extract_message_v1(message: GroupMessage) -> Result Result, MessageProcessingError> { + match &message.version { + Some(GroupMessageVersion::V1(value)) => Ok(value.group_id.clone()), + _ => Err(MessageProcessingError::InvalidPayload), + } +} + fn validate_ed25519_keys(keys: &[Vec]) -> Result<(), GroupError> { let mut invalid = keys .iter() diff --git a/xmtp_mls/src/groups/subscriptions.rs b/xmtp_mls/src/groups/subscriptions.rs index 2f0f6355f..fb09d440d 100644 --- a/xmtp_mls/src/groups/subscriptions.rs +++ b/xmtp_mls/src/groups/subscriptions.rs @@ -12,7 +12,7 @@ impl<'c, ApiClient> MlsGroup<'c, ApiClient> where ApiClient: XmtpMlsClient, { - async fn process_stream_entry( + pub(crate) async fn process_stream_entry( &self, envelope: GroupMessage, ) -> Result, GroupError> { diff --git a/xmtp_mls/src/lib.rs b/xmtp_mls/src/lib.rs index 3461f610a..7f24b51d4 100644 --- a/xmtp_mls/src/lib.rs +++ b/xmtp_mls/src/lib.rs @@ -10,6 +10,7 @@ pub mod identity; pub mod owner; pub mod retry; pub mod storage; +pub mod subscriptions; pub mod types; pub mod utils; pub mod verified_key_package; diff --git a/xmtp_mls/src/subscriptions.rs b/xmtp_mls/src/subscriptions.rs new file mode 100644 index 000000000..d87b37cfc --- /dev/null +++ b/xmtp_mls/src/subscriptions.rs @@ -0,0 +1,278 @@ +use std::{ + collections::HashMap, + pin::Pin, + sync::{ + atomic::{AtomicBool, Ordering}, + Arc, Mutex, + }, +}; + +use futures::{Stream, StreamExt}; +use tokio::sync::oneshot::{self, Sender}; +use xmtp_proto::{api_client::XmtpMlsClient, xmtp::mls::api::v1::WelcomeMessage}; + +use crate::{ + api_client_wrapper::GroupFilter, + client::{extract_welcome_message, ClientError}, + groups::{extract_group_id, GroupError, MlsGroup}, + storage::{group::StoredGroup, group_message::StoredGroupMessage}, + Client, +}; + +pub struct StreamCloser { + pub close_fn: Arc>>>, + pub is_closed_atomic: Arc, +} + +impl StreamCloser { + pub fn end(&self) { + match self.close_fn.lock() { + Ok(mut close_fn_option) => { + let _ = close_fn_option.take().map(|close_fn| close_fn.send(())); + } + _ => { + log::warn!("close_fn already closed"); + } + } + } + + pub fn is_closed(&self) -> bool { + self.is_closed_atomic.load(Ordering::Relaxed) + } +} +struct MessagesStreamInfo { + group: StoredGroup, + cursor: u64, +} + +impl Clone for MessagesStreamInfo { + fn clone(&self) -> Self { + Self { + group: self.group.clone(), + cursor: self.cursor, + } + } +} + +impl<'a, ApiClient> Client +where + ApiClient: XmtpMlsClient + 'static, +{ + fn process_streamed_welcome( + &self, + welcome: WelcomeMessage, + ) -> Result, ClientError> { + let welcome_v1 = extract_welcome_message(welcome)?; + let conn = self.store.conn()?; + let provider = self.mls_provider(&conn); + + MlsGroup::create_from_encrypted_welcome( + self, + &provider, + welcome_v1.hpke_public_key.as_slice(), + welcome_v1.data, + ) + .map_err(|e| ClientError::Generic(e.to_string())) + } + + pub async fn stream_conversations( + &'a self, + ) -> Result> + Send + 'a>>, ClientError> { + let installation_key = self.installation_public_key(); + let id_cursor = 0; + + let subscription = self + .api_client + .subscribe_welcome_messages(installation_key, Some(id_cursor as u64)) + .await?; + + let stream = subscription + .map(|welcome_result| async { + let welcome = welcome_result?; + self.process_streamed_welcome(welcome) + }) + .filter_map(|res| async { + match res.await { + Ok(group) => Some(group), + Err(err) => { + log::error!("Error processing stream entry: {:?}", err); + None + } + } + }); + + Ok(Box::pin(stream)) + } +} + +impl Client +where + ApiClient: XmtpMlsClient + 'static, +{ + pub fn stream_conversations_with_callback( + client: Arc>, + callback: impl Fn(MlsGroup) + Send + 'static, + ) -> Result { + let (close_sender, close_receiver) = oneshot::channel::<()>(); + let is_closed = Arc::new(AtomicBool::new(false)); + let is_closed_clone = is_closed.clone(); + + tokio::spawn(async move { + let mut stream = client.stream_conversations().await.unwrap(); + let mut close_receiver = close_receiver; + loop { + tokio::select! { + item = stream.next() => { + match item { + Some(convo) => { callback(convo) }, + None => break + } + } + _ = &mut close_receiver => { + break; + } + } + } + is_closed_clone.store(true, Ordering::Relaxed); + log::info!("closing stream"); + }); + + Ok(StreamCloser { + close_fn: Arc::new(Mutex::new(Some(close_sender))), + is_closed_atomic: is_closed, + }) + } + + pub async fn stream_all_messages( + client: Arc>, + ) -> Result + Send>>, ClientError> { + client.sync_welcomes().await?; // TODO pipe cursor from welcomes sync into groups_stream + let group_id_to_info: HashMap, MessagesStreamInfo> = client + .store + .conn()? + .find_groups(None, None, None, None)? + .into_iter() + .map(|group| (group.id.clone(), MessagesStreamInfo { group, cursor: 0 })) + .collect(); + // let groups_stream = Self::stream_conversations_with_callback(client.clone(), |convo| {}); + // TODO update messages_stream based on groups_stream + + let filters: Vec = group_id_to_info + .iter() + .map(|(group_id, info)| GroupFilter::new(group_id.clone(), Some(info.cursor))) + .collect(); + let messages_subscription = client.api_client.subscribe_group_messages(filters).await?; + let group_id_to_info_clone = group_id_to_info.clone(); + let client_clone = client.clone(); + let stream = messages_subscription + .map(move |res| { + let group_id_to_info_clone = group_id_to_info_clone.clone(); + let client_clone = client_clone.clone(); + async move { + match res { + Ok(envelope) => { + let group_id = extract_group_id(&envelope)?; + let stored_group = &group_id_to_info_clone + .get(&group_id) + .ok_or(ClientError::StreamInconsistency( + "Received message for a non-subscribed group".to_string(), + ))? + .group; + MlsGroup::new( + client_clone.as_ref(), + group_id, + stored_group.created_at_ns, + ) + .process_stream_entry(envelope) + .await + } + Err(err) => Err(GroupError::Api(err)), + } + } + }) + .filter_map(move |res| async { + match res.await { + Ok(Some(message)) => Some(message), + Ok(None) => None, + Err(err) => { + log::error!("Error processing stream entry: {:?}", err); + None + } + } + }); + + Ok(Box::pin(stream)) + } +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use futures::StreamExt; + use xmtp_cryptography::utils::generate_local_wallet; + + use crate::{builder::ClientBuilder, Client}; + + #[tokio::test] + async fn test_stream_welcomes() { + let alice = ClientBuilder::new_test_client(&generate_local_wallet()).await; + let bob = ClientBuilder::new_test_client(&generate_local_wallet()).await; + + let alice_bob_group = alice.create_group(None).unwrap(); + + let mut bob_stream = bob.stream_conversations().await.unwrap(); + alice_bob_group + .add_members(vec![bob.account_address()]) + .await + .unwrap(); + + let bob_received_groups = bob_stream.next().await.unwrap(); + assert_eq!(bob_received_groups.group_id, alice_bob_group.group_id); + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 10)] + async fn test_stream_all_messages_unchanging_group_list() { + let alix = ClientBuilder::new_test_client(&generate_local_wallet()).await; + let bo = ClientBuilder::new_test_client(&generate_local_wallet()).await; + let caro = ClientBuilder::new_test_client(&generate_local_wallet()).await; + + let alix_group = alix.create_group(None).unwrap(); + alix_group + .add_members_by_installation_id(vec![caro.installation_public_key()]) + .await + .unwrap(); + + let bo_group = bo.create_group(None).unwrap(); + bo_group + .add_members_by_installation_id(vec![caro.installation_public_key()]) + .await + .unwrap(); + + let mut stream = Client::stream_all_messages(Arc::new(caro)).await.unwrap(); + tokio::time::sleep(std::time::Duration::from_millis(50)).await; + + alix_group.send_message("first".as_bytes()).await.unwrap(); + bo_group.send_message("second".as_bytes()).await.unwrap(); + alix_group.send_message("third".as_bytes()).await.unwrap(); + bo_group.send_message("fourth".as_bytes()).await.unwrap(); + tokio::time::sleep(std::time::Duration::from_millis(200)).await; + + assert_eq!( + stream.next().await.unwrap().decrypted_message_bytes, + "first".as_bytes() + ); + assert_eq!( + stream.next().await.unwrap().decrypted_message_bytes, + "second".as_bytes() + ); + assert_eq!( + stream.next().await.unwrap().decrypted_message_bytes, + "third".as_bytes() + ); + assert_eq!( + stream.next().await.unwrap().decrypted_message_bytes, + "fourth".as_bytes() + ); + } +}