diff --git a/bindings_ffi/src/mls.rs b/bindings_ffi/src/mls.rs index 7d2b2cafd..a66ad905f 100644 --- a/bindings_ffi/src/mls.rs +++ b/bindings_ffi/src/mls.rs @@ -2,12 +2,13 @@ pub use crate::inbox_owner::SigningError; use crate::logger::init_logger; use crate::logger::FfiLogger; use crate::GenericError; +use futures::StreamExt; use std::convert::TryInto; use std::sync::{ atomic::{AtomicBool, Ordering}, Arc, Mutex, }; -use tokio::sync::oneshot::Sender; +use tokio::sync::{oneshot, oneshot::Sender}; use xmtp_api_grpc::grpc_api_helper::Client as TonicApiClient; use xmtp_mls::builder::IdentityStrategy; use xmtp_mls::builder::LegacyIdentity; @@ -256,17 +257,14 @@ impl FfiConversations { callback: Box, ) -> Result, GenericError> { let client = self.inner_client.clone(); - let stream_closer = RustXmtpClient::stream_conversations_with_callback( - client.clone(), - move |convo| { + let stream_closer = + RustXmtpClient::stream_conversations_with_callback(client.clone(), move |convo| { callback.on_conversation(Arc::new(FfiGroup { inner_client: client.clone(), group_id: convo.group_id, created_at_ns: convo.created_at_ns, })) - }, - || {}, // on_close_callback - )?; + })?; Ok(Arc::new(FfiStreamCloser { close_fn: stream_closer.close_fn, @@ -278,15 +276,36 @@ impl FfiConversations { &self, message_callback: Box, ) -> Result, GenericError> { - let stream_closer = RustXmtpClient::stream_all_messages_with_callback( - self.inner_client.clone(), - move |message| message_callback.on_message(message.into()), - ) - .await?; + let inner_client = Arc::clone(&self.inner_client); + let (close_sender, close_receiver) = oneshot::channel::<()>(); + let is_closed = Arc::new(AtomicBool::new(false)); + let is_closed_clone = is_closed.clone(); + + tokio::spawn(async move { + let mut stream = RustXmtpClient::stream_all_messages(inner_client) + .await + .unwrap(); + let mut close_receiver = close_receiver; + loop { + tokio::select! { + item = stream.next() => { + match item { + Some(message) => message_callback.on_message(message.into()), + None => break + } + } + _ = &mut close_receiver => { + break; + } + } + } + is_closed_clone.store(true, Ordering::Relaxed); + log::info!("closing stream"); + }); Ok(Arc::new(FfiStreamCloser { - close_fn: stream_closer.close_fn, - is_closed_atomic: stream_closer.is_closed_atomic, + close_fn: Arc::new(Mutex::new(Some(close_sender))), + is_closed_atomic: is_closed, })) } } @@ -406,17 +425,37 @@ impl FfiGroup { message_callback: Box, ) -> Result, GenericError> { let inner_client = Arc::clone(&self.inner_client); - let stream_closer = MlsGroup::stream_with_callback( - inner_client, - self.group_id.clone(), - self.created_at_ns, - move |message| message_callback.on_message(message.into()), - ) - .await?; + let group_id = self.group_id.clone(); + let created_at_ns = self.created_at_ns; + let (close_sender, close_receiver) = oneshot::channel::<()>(); + let is_closed = Arc::new(AtomicBool::new(false)); + let is_closed_clone = is_closed.clone(); + + tokio::spawn(async move { + let client = inner_client.as_ref(); + let group = MlsGroup::new(&client, group_id, created_at_ns); + let mut stream = group.stream().await.unwrap(); + let mut close_receiver = close_receiver; + loop { + tokio::select! { + item = stream.next() => { + match item { + Some(message) => message_callback.on_message(message.into()), + None => break + } + } + _ = &mut close_receiver => { + break; + } + } + } + is_closed_clone.store(true, Ordering::Relaxed); + log::info!("closing stream"); + }); Ok(Arc::new(FfiStreamCloser { - close_fn: stream_closer.close_fn, - is_closed_atomic: stream_closer.is_closed_atomic, + close_fn: Arc::new(Mutex::new(Some(close_sender))), + is_closed_atomic: is_closed, })) } @@ -605,9 +644,7 @@ mod tests { } impl FfiMessageCallback for RustStreamCallback { - fn on_message(&self, message: FfiMessage) { - let message = String::from_utf8(message.content).unwrap_or("".to_string()); - log::info!("Received: {}", message); + fn on_message(&self, _: FfiMessage) { *self.num_messages.lock().unwrap() += 1; } } @@ -915,7 +952,7 @@ mod tests { } #[tokio::test(flavor = "multi_thread", worker_threads = 5)] - async fn test_stream_all_messages() { + async fn test_stream_all_messages_unchanging_group_list() { let alix = new_test_client().await; let bo = new_test_client().await; let caro = new_test_client().await; @@ -925,6 +962,15 @@ mod tests { .create_group(vec![caro.account_address()], None) .await .unwrap(); + + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + + let caro_group = caro + .conversations() + .create_group(vec![bo.account_address()], None) + .await + .unwrap(); + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; let stream_callback = RustStreamCallback::new(); @@ -937,17 +983,11 @@ mod tests { alix_group.send("first".as_bytes().to_vec()).await.unwrap(); tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; - let bo_group = bo - .conversations() - .create_group(vec![caro.account_address()], None) - .await - .unwrap(); - tokio::time::sleep(tokio::time::Duration::from_millis(200)).await; - bo_group.send("second".as_bytes().to_vec()).await.unwrap(); + caro_group.send("second".as_bytes().to_vec()).await.unwrap(); tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; alix_group.send("third".as_bytes().to_vec()).await.unwrap(); tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; - bo_group.send("fourth".as_bytes().to_vec()).await.unwrap(); + caro_group.send("fourth".as_bytes().to_vec()).await.unwrap(); tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; assert_eq!(stream_callback.message_count(), 4); @@ -956,8 +996,9 @@ mod tests { assert!(stream.is_closed()); } + // Disabling this flakey test until it's reliable #[tokio::test(flavor = "multi_thread", worker_threads = 5)] - async fn test_message_streaming() { + async fn test_streaming() { let amal = new_test_client().await; let bola = new_test_client().await; @@ -985,67 +1026,4 @@ mod tests { stream_closer.end(); } - - #[tokio::test(flavor = "multi_thread", worker_threads = 5)] - async fn test_message_streaming_when_removed_then_added() { - let amal = new_test_client().await; - let bola = new_test_client().await; - log::info!( - "Created addresses {} and {}", - amal.account_address(), - bola.account_address() - ); - - let amal_group = amal - .conversations() - .create_group(vec![bola.account_address()], None) - .await - .unwrap(); - tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; - - let stream_callback = RustStreamCallback::new(); - let stream_closer = bola - .conversations() - .stream_all_messages(Box::new(stream_callback.clone())) - .await - .unwrap(); - - tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; - - amal_group.send("hello1".as_bytes().to_vec()).await.unwrap(); - amal_group.send("hello2".as_bytes().to_vec()).await.unwrap(); - - tokio::time::sleep(tokio::time::Duration::from_millis(200)).await; - assert_eq!(stream_callback.message_count(), 2); - assert!(!stream_closer.is_closed()); - - amal_group - .remove_members(vec![bola.account_address()]) - .await - .unwrap(); - tokio::time::sleep(std::time::Duration::from_millis(60000)).await; - assert_eq!(stream_callback.message_count(), 3); // Member removal transcript message - - amal_group.send("hello3".as_bytes().to_vec()).await.unwrap(); - tokio::time::sleep(tokio::time::Duration::from_millis(200)).await; - assert_eq!(stream_callback.message_count(), 3); // Don't receive messages while removed - assert!(!stream_closer.is_closed()); - - tokio::time::sleep(tokio::time::Duration::from_millis(200)).await; - amal_group - .add_members(vec![bola.account_address()]) - .await - .unwrap(); - tokio::time::sleep(std::time::Duration::from_millis(500)).await; - assert_eq!(stream_callback.message_count(), 3); // Don't receive transcript messages while removed - - amal_group.send("hello4".as_bytes().to_vec()).await.unwrap(); - tokio::time::sleep(tokio::time::Duration::from_millis(200)).await; - assert_eq!(stream_callback.message_count(), 4); // Receiving messages again - assert!(!stream_closer.is_closed()); - - stream_closer.end(); - tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; - assert!(stream_closer.is_closed()); - } } diff --git a/xmtp_mls/src/groups/subscriptions.rs b/xmtp_mls/src/groups/subscriptions.rs index 0c5c0351e..fb09d440d 100644 --- a/xmtp_mls/src/groups/subscriptions.rs +++ b/xmtp_mls/src/groups/subscriptions.rs @@ -1,15 +1,12 @@ -use std::collections::HashMap; use std::pin::Pin; -use std::sync::Arc; -use futures::Stream; +use futures::{Stream, StreamExt}; use xmtp_proto::{api_client::XmtpMlsClient, xmtp::mls::api::v1::GroupMessage}; use super::{extract_message_v1, GroupError, MlsGroup}; +use crate::api_client_wrapper::GroupFilter; use crate::storage::group_message::StoredGroupMessage; -use crate::subscriptions::{MessagesStreamInfo, StreamCloser}; -use crate::Client; impl<'c, ApiClient> MlsGroup<'c, ApiClient> where @@ -21,29 +18,18 @@ where ) -> Result, GroupError> { let msgv1 = extract_message_v1(envelope)?; - log::info!("Running transaction {}", self.client.account_address()); let process_result = self.client.store.transaction(|provider| { let mut openmls_group = self.load_mls_group(&provider)?; // Attempt processing immediately, but fail if the message is not an Application Message // Returning an error should roll back the DB tx - log::info!( - "Processing message in process_stream_entry {}", - self.client.account_address() - ); - let res = self - .process_message(&mut openmls_group, &provider, &msgv1, false) - .map_err(GroupError::ReceiveError); - log::info!("Got process message result {:?}", res); - res + self.process_message(&mut openmls_group, &provider, &msgv1, false) + .map_err(GroupError::ReceiveError) }); - log::info!("Got process_result {:?}", process_result); if let Some(GroupError::ReceiveError(_)) = process_result.err() { - log::info!("Re-syncing due to unreadable messaging stream payload"); self.sync().await?; } - log::info!("Storing message"); // Load the message from the DB to handle cases where it may have been already processed in // another thread let new_message = self @@ -51,7 +37,6 @@ where .store .conn()? .get_group_message_by_timestamp(&self.group_id, msgv1.created_ns as i64)?; - log::info!("Stored message"); Ok(new_message) } @@ -59,35 +44,35 @@ where pub async fn stream( &'c self, ) -> Result + 'c + Send>>, GroupError> { - Ok(self + let last_cursor = 0; + + let subscription = self .client - .stream_messages(HashMap::from([( + .api_client + .subscribe_group_messages(vec![GroupFilter::new( self.group_id.clone(), - MessagesStreamInfo { - convo_created_at_ns: self.created_at_ns, - cursor: 0, - }, - )])) - .await?) - } - - pub async fn stream_with_callback( - client: Arc>, - group_id: Vec, - created_at_ns: i64, - mut callback: impl FnMut(StoredGroupMessage) + Send + 'static, - ) -> Result { - Ok(Client::::stream_messages_with_callback( - client, - HashMap::from([( - group_id, - MessagesStreamInfo { - convo_created_at_ns: created_at_ns, - cursor: 0, - }, - )]), - move |message| callback(message), - )?) + Some(last_cursor as u64), + )]) + .await?; + let stream = subscription + .map(|res| async { + match res { + Ok(envelope) => self.process_stream_entry(envelope).await, + Err(err) => Err(GroupError::Api(err)), + } + }) + .filter_map(move |res| async { + match res.await { + Ok(Some(message)) => Some(message), + Ok(None) => None, + Err(err) => { + log::error!("Error processing stream entry: {:?}", err); + None + } + } + }); + + Ok(Box::pin(stream)) } } diff --git a/xmtp_mls/src/groups/sync.rs b/xmtp_mls/src/groups/sync.rs index 560326a9e..395ead7bd 100644 --- a/xmtp_mls/src/groups/sync.rs +++ b/xmtp_mls/src/groups/sync.rs @@ -74,7 +74,6 @@ where conn: &'a DbConnection<'a>, ) -> Result<(), GroupError> { let mut errors: Vec = vec![]; - log::info!("Sync1"); // Even if publish fails, continue to receiving if let Err(publish_error) = self.publish_intents(conn).await { @@ -82,8 +81,6 @@ where errors.push(publish_error); } - log::info!("Sync2"); - // Even if receiving fails, continue to post_commit if let Err(receive_error) = self.receive(conn).await { log::error!("receive error {:?}", receive_error); @@ -92,22 +89,16 @@ where // added to the group } - log::info!("Sync3"); - if let Err(post_commit_err) = self.post_commit(conn).await { log::error!("post commit error {:?}", post_commit_err); errors.push(post_commit_err); } - log::info!("Sync4"); - // Return a combination of publish and post_commit errors if !errors.is_empty() { return Err(GroupError::Sync(errors)); } - log::info!("Sync5"); - Ok(()) } @@ -166,7 +157,6 @@ where envelope_timestamp_ns: u64, allow_epoch_increment: bool, ) -> Result<(), MessageProcessingError> { - log::info!("Process own message {}", self.client.account_address()); if intent.state == IntentState::Committed { return Ok(()); } @@ -253,23 +243,12 @@ where "[{}] processing private message", self.client.account_address() ); - log::info!("process_external_message {}", self.client.account_address()); let decrypted_message = openmls_group.process_message(provider, message)?; - log::info!( - "process_external_message a {}", - self.client.account_address() - ); let (sender_account_address, sender_installation_id) = validate_message_sender(openmls_group, &decrypted_message, envelope_timestamp_ns)?; - log::info!( - "process_external_message 2 {}", - self.client.account_address() - ); - match decrypted_message.into_content() { ProcessedMessageContent::ApplicationMessage(application_message) => { - log::info!("Application message {}", self.client.account_address()); let message_bytes = application_message.into_bytes(); let message_id = get_message_id(&message_bytes, &self.group_id, envelope_timestamp_ns); @@ -285,21 +264,12 @@ where .store(provider.conn())?; } ProcessedMessageContent::ProposalMessage(_proposal_ptr) => { - log::info!( - "process_external_message 3 {}", - self.client.account_address() - ); // intentionally left blank. } ProcessedMessageContent::ExternalJoinProposalMessage(_external_proposal_ptr) => { - log::info!( - "process_external_message 4 {}", - self.client.account_address() - ); // intentionally left blank. } ProcessedMessageContent::StagedCommitMessage(staged_commit) => { - log::info!("Staged commit {}", self.client.account_address()); if !allow_epoch_increment { return Err(MessageProcessingError::EpochIncrementNotAllowed); } @@ -309,18 +279,14 @@ where ); let sc = *staged_commit; - log::info!("Staged 1 {}", self.client.account_address()); // Validate the commit let validated_commit = ValidatedCommit::from_staged_commit(&sc, openmls_group)?; - log::info!("Staged 2 {}", self.client.account_address()); openmls_group.merge_staged_commit(provider, sc)?; - log::info!("Staged 3 {}", self.client.account_address()); self.save_transcript_message( provider.conn(), validated_commit, envelope_timestamp_ns, )?; - log::info!("Staged 4 {}", self.client.account_address()); } }; @@ -334,7 +300,6 @@ where envelope: &GroupMessageV1, allow_epoch_increment: bool, ) -> Result<(), MessageProcessingError> { - log::info!("process message 1 {}", self.client.account_address()); let mls_message_in = MlsMessageIn::tls_deserialize_exact(&envelope.data)?; let message = match mls_message_in.extract() { @@ -344,13 +309,9 @@ where )), }?; - log::info!("process message 2 {}", self.client.account_address()); - let intent = provider .conn() .find_group_intent_by_payload_hash(sha256(envelope.data.as_slice())); - - log::info!("process message 3 {}", self.client.account_address()); match intent { // Intent with the payload hash matches Ok(Some(intent)) => self.process_own_message( @@ -361,30 +322,15 @@ where envelope.created_ns, allow_epoch_increment, ), - Err(err) => { - log::info!("Storage error {}", self.client.account_address()); - Err(MessageProcessingError::Storage(err)) - } + Err(err) => Err(MessageProcessingError::Storage(err)), // No matching intent found - Ok(None) => { - log::info!( - "Calling process_external_message {}", - self.client.account_address() - ); - let res = self.process_external_message( - openmls_group, - provider, - message, - envelope.created_ns, - allow_epoch_increment, - ); - log::info!( - "process_external_message return {}", - self.client.account_address() - ); - log::info!("process_external_message result {:?}", res); - res - } + Ok(None) => self.process_external_message( + openmls_group, + provider, + message, + envelope.created_ns, + allow_epoch_increment, + ), } } @@ -456,22 +402,13 @@ where timestamp_ns: u64, ) -> Result, MessageProcessingError> { let mut transcript_message = None; - log::info!("Transcript {}", self.client.account_address()); if let Some(validated_commit) = maybe_validated_commit { - log::info!("Transcript 2 {}", self.client.account_address()); // If there are no members added or removed, don't write a transcript message if validated_commit.members_added.is_empty() && validated_commit.members_removed.is_empty() { - log::info!("Transcript 3 {}", self.client.account_address()); return Ok(None); } - log::info!( - "Storing a transcript message with {} members added and {} members removed for address {}", - validated_commit.members_added.len(), - validated_commit.members_removed.len(), - self.client.account_address() - ); let sender_installation_id = validated_commit.actor_installation_id(); let sender_account_address = validated_commit.actor_account_address(); let payload: GroupMembershipChanges = validated_commit.into(); diff --git a/xmtp_mls/src/subscriptions.rs b/xmtp_mls/src/subscriptions.rs index cf08e716b..d87b37cfc 100644 --- a/xmtp_mls/src/subscriptions.rs +++ b/xmtp_mls/src/subscriptions.rs @@ -15,11 +15,10 @@ use crate::{ api_client_wrapper::GroupFilter, client::{extract_welcome_message, ClientError}, groups::{extract_group_id, GroupError, MlsGroup}, - storage::group_message::StoredGroupMessage, + storage::{group::StoredGroup, group_message::StoredGroupMessage}, Client, }; -// TODO simplify FfiStreamCloser + StreamCloser duplication pub struct StreamCloser { pub close_fn: Arc>>>, pub is_closed_atomic: Arc, @@ -41,16 +40,23 @@ impl StreamCloser { self.is_closed_atomic.load(Ordering::Relaxed) } } +struct MessagesStreamInfo { + group: StoredGroup, + cursor: u64, +} -#[derive(Clone)] -pub(crate) struct MessagesStreamInfo { - pub convo_created_at_ns: i64, - pub cursor: u64, +impl Clone for MessagesStreamInfo { + fn clone(&self) -> Self { + Self { + group: self.group.clone(), + cursor: self.cursor, + } + } } impl<'a, ApiClient> Client where - ApiClient: XmtpMlsClient, + ApiClient: XmtpMlsClient + 'static, { fn process_streamed_welcome( &self, @@ -82,7 +88,6 @@ where let stream = subscription .map(|welcome_result| async { - log::info!("Received conversation streaming payload"); let welcome = welcome_result?; self.process_streamed_welcome(welcome) }) @@ -98,74 +103,15 @@ where Ok(Box::pin(stream)) } - - pub(crate) async fn stream_messages( - &'a self, - group_id_to_info: HashMap, MessagesStreamInfo>, - ) -> Result + Send + 'a>>, ClientError> { - let filters: Vec = group_id_to_info - .iter() - .map(|(group_id, info)| GroupFilter::new(group_id.clone(), Some(info.cursor))) - .collect(); - let messages_subscription = self.api_client.subscribe_group_messages(filters).await?; - let stream = messages_subscription - .map(move |res| { - let group_id_to_info = group_id_to_info.clone(); - async move { - match res { - Ok(envelope) => { - log::info!("Received message streaming payload"); - let group_id = extract_group_id(&envelope)?; - let stream_info = group_id_to_info.get(&group_id).ok_or( - ClientError::StreamInconsistency( - "Received message for a non-subscribed group".to_string(), - ), - )?; - log::info!("Processing stream entry {}", self.account_address()); - // TODO update cursor - let result = - MlsGroup::new(self, group_id, stream_info.convo_created_at_ns) - .process_stream_entry(envelope) - .await; - log::info!("Finished processing stream entry"); - result - } - Err(err) => { - log::info!("Got API error"); - Err(GroupError::Api(err)) - } - } - } - }) - .filter_map(move |res| async { - match res.await { - Ok(Some(message)) => { - log::info!("Message processed successfully"); - Some(message) - } - Ok(None) => { - log::info!("Skipped message streaming payload"); - None - } - Err(err) => { - log::info!("Error processing stream entry: {:?}", err); - None - } - } - }); - - Ok(Box::pin(stream)) - } } impl Client where - ApiClient: XmtpMlsClient, + ApiClient: XmtpMlsClient + 'static, { pub fn stream_conversations_with_callback( client: Arc>, - mut convo_callback: impl FnMut(MlsGroup) + Send + 'static, - mut on_close_callback: impl FnMut() + Send + 'static, + callback: impl Fn(MlsGroup) + Send + 'static, ) -> Result { let (close_sender, close_receiver) = oneshot::channel::<()>(); let is_closed = Arc::new(AtomicBool::new(false)); @@ -178,12 +124,11 @@ where tokio::select! { item = stream.next() => { match item { - Some(convo) => { convo_callback(convo) }, + Some(convo) => { callback(convo) }, None => break } } _ = &mut close_receiver => { - on_close_callback(); break; } } @@ -198,124 +143,77 @@ where }) } - pub(crate) fn stream_messages_with_callback( + pub async fn stream_all_messages( client: Arc>, - group_id_to_info: HashMap, MessagesStreamInfo>, - mut callback: impl FnMut(StoredGroupMessage) + Send + 'static, - ) -> Result { - let (close_sender, close_receiver) = oneshot::channel::<()>(); - let is_closed = Arc::new(AtomicBool::new(false)); - - let is_closed_clone = is_closed.clone(); - tokio::spawn(async move { - let mut stream = Self::stream_messages(client.as_ref(), group_id_to_info) - .await - .unwrap(); - let mut close_receiver = close_receiver; - loop { - tokio::select! { - item = stream.next() => { - match item { - Some(message) => callback(message), - None => break - } - } - _ = &mut close_receiver => { - break; - } - } - } - is_closed_clone.store(true, Ordering::Relaxed); - log::info!("closing stream"); - }); - - Ok(StreamCloser { - close_fn: Arc::new(Mutex::new(Some(close_sender))), - is_closed_atomic: is_closed, - }) - } - - pub async fn stream_all_messages_with_callback( - client: Arc>, - callback: impl FnMut(StoredGroupMessage) + Send + Sync + 'static, - ) -> Result { - let callback = Arc::new(Mutex::new(callback)); - + ) -> Result + Send>>, ClientError> { client.sync_welcomes().await?; // TODO pipe cursor from welcomes sync into groups_stream - let mut group_id_to_info: HashMap, MessagesStreamInfo> = client + let group_id_to_info: HashMap, MessagesStreamInfo> = client .store .conn()? .find_groups(None, None, None, None)? .into_iter() - .map(|group| { - ( - group.id.clone(), - MessagesStreamInfo { - convo_created_at_ns: group.created_at_ns, - cursor: 0, - }, - ) - }) + .map(|group| (group.id.clone(), MessagesStreamInfo { group, cursor: 0 })) .collect(); + // let groups_stream = Self::stream_conversations_with_callback(client.clone(), |convo| {}); + // TODO update messages_stream based on groups_stream - let callback_clone = callback.clone(); - let messages_stream_closer_mutex = - Arc::new(Mutex::new(Self::stream_messages_with_callback( - client.clone(), - group_id_to_info.clone(), - move |message| callback_clone.lock().unwrap()(message), // TODO fix unwrap - )?)); - let messages_stream_closer_mutex_clone = messages_stream_closer_mutex.clone(); - let groups_stream_closer = Self::stream_conversations_with_callback( - client.clone(), - move |convo| { - // TODO make sure key comparison works correctly - if group_id_to_info.contains_key(&convo.group_id) { - return; + let filters: Vec = group_id_to_info + .iter() + .map(|(group_id, info)| GroupFilter::new(group_id.clone(), Some(info.cursor))) + .collect(); + let messages_subscription = client.api_client.subscribe_group_messages(filters).await?; + let group_id_to_info_clone = group_id_to_info.clone(); + let client_clone = client.clone(); + let stream = messages_subscription + .map(move |res| { + let group_id_to_info_clone = group_id_to_info_clone.clone(); + let client_clone = client_clone.clone(); + async move { + match res { + Ok(envelope) => { + let group_id = extract_group_id(&envelope)?; + let stored_group = &group_id_to_info_clone + .get(&group_id) + .ok_or(ClientError::StreamInconsistency( + "Received message for a non-subscribed group".to_string(), + ))? + .group; + MlsGroup::new( + client_clone.as_ref(), + group_id, + stored_group.created_at_ns, + ) + .process_stream_entry(envelope) + .await + } + Err(err) => Err(GroupError::Api(err)), + } } - // Close existing message stream - // TODO remove unwrap - let mut messages_stream_closer = messages_stream_closer_mutex.lock().unwrap(); - messages_stream_closer.end(); - - // Set up new stream. For existing groups, stream new messages only by unsetting the cursor - for info in group_id_to_info.values_mut() { - info.cursor = 0; + }) + .filter_map(move |res| async { + match res.await { + Ok(Some(message)) => Some(message), + Ok(None) => None, + Err(err) => { + log::error!("Error processing stream entry: {:?}", err); + None + } } - group_id_to_info.insert( - convo.group_id, - MessagesStreamInfo { - convo_created_at_ns: convo.created_at_ns, - cursor: 1, // For the new group, stream all messages since the group was created - }, - ); - - // Open new message stream - let callback_clone = callback.clone(); - *messages_stream_closer = Self::stream_messages_with_callback( - client.clone(), - group_id_to_info.clone(), - move |message| callback_clone.lock().unwrap()(message), // TODO fix unwrap - ) - .unwrap(); // TODO fix unwrap - }, - move || { - messages_stream_closer_mutex_clone.lock().unwrap().end(); - }, - )?; + }); - Ok(groups_stream_closer) + Ok(Box::pin(stream)) } } #[cfg(test)] mod tests { - use crate::{builder::ClientBuilder, storage::group_message::StoredGroupMessage, Client}; + use std::sync::Arc; + use futures::StreamExt; - use std::sync::{Arc, Mutex}; - use xmtp_api_grpc::grpc_api_helper::Client as GrpcClient; use xmtp_cryptography::utils::generate_local_wallet; + use crate::{builder::ClientBuilder, Client}; + #[tokio::test] async fn test_stream_welcomes() { let alice = ClientBuilder::new_test_client(&generate_local_wallet()).await; @@ -350,18 +248,8 @@ mod tests { .add_members_by_installation_id(vec![caro.installation_public_key()]) .await .unwrap(); - tokio::time::sleep(std::time::Duration::from_millis(100)).await; - let messages: Arc>> = Arc::new(Mutex::new(Vec::new())); - let messages_clone = messages.clone(); - let stream = Client::::stream_all_messages_with_callback( - Arc::new(caro), - move |message| { - (*messages_clone.lock().unwrap()).push(message); - }, - ) - .await - .unwrap(); + let mut stream = Client::stream_all_messages(Arc::new(caro)).await.unwrap(); tokio::time::sleep(std::time::Duration::from_millis(50)).await; alix_group.send_message("first".as_bytes()).await.unwrap(); @@ -370,82 +258,21 @@ mod tests { bo_group.send_message("fourth".as_bytes()).await.unwrap(); tokio::time::sleep(std::time::Duration::from_millis(200)).await; - let messages = messages.lock().unwrap(); - assert_eq!(messages[0].decrypted_message_bytes, "first".as_bytes()); - assert_eq!(messages[1].decrypted_message_bytes, "second".as_bytes()); - assert_eq!(messages[2].decrypted_message_bytes, "third".as_bytes()); - assert_eq!(messages[3].decrypted_message_bytes, "fourth".as_bytes()); - - stream.end(); - } - - #[tokio::test(flavor = "multi_thread", worker_threads = 10)] - async fn test_stream_all_messages_changing_group_list() { - let alix = ClientBuilder::new_test_client(&generate_local_wallet()).await; - let bo = ClientBuilder::new_test_client(&generate_local_wallet()).await; - let caro = Arc::new(ClientBuilder::new_test_client(&generate_local_wallet()).await); - - let alix_group = alix.create_group(None).unwrap(); - alix_group - .add_members_by_installation_id(vec![caro.installation_public_key()]) - .await - .unwrap(); - - tokio::time::sleep(std::time::Duration::from_millis(100)).await; - - let messages: Arc>> = Arc::new(Mutex::new(Vec::new())); - let messages_clone = messages.clone(); - let stream = - Client::::stream_all_messages_with_callback(caro.clone(), move |message| { - let text = String::from_utf8(message.decrypted_message_bytes.clone()) - .unwrap_or("".to_string()); - println!("Received: {}", text); - (*messages_clone.lock().unwrap()).push(message); - }) - .await - .unwrap(); - - tokio::time::sleep(std::time::Duration::from_millis(50)).await; - - alix_group.send_message("first".as_bytes()).await.unwrap(); - tokio::time::sleep(std::time::Duration::from_millis(100)).await; - - let bo_group = bo.create_group(None).unwrap(); - bo_group - .add_members_by_installation_id(vec![caro.installation_public_key()]) - .await - .unwrap(); - tokio::time::sleep(std::time::Duration::from_millis(300)).await; - - bo_group.send_message("second".as_bytes()).await.unwrap(); - tokio::time::sleep(std::time::Duration::from_millis(100)).await; - - alix_group.send_message("third".as_bytes()).await.unwrap(); - tokio::time::sleep(std::time::Duration::from_millis(100)).await; - - let alix_group_2 = alix.create_group(None).unwrap(); - alix_group_2 - .add_members_by_installation_id(vec![caro.installation_public_key()]) - .await - .unwrap(); - tokio::time::sleep(std::time::Duration::from_millis(300)).await; - - alix_group.send_message("fourth".as_bytes()).await.unwrap(); - tokio::time::sleep(std::time::Duration::from_millis(100)).await; - alix_group_2.send_message("fifth".as_bytes()).await.unwrap(); - - tokio::time::sleep(std::time::Duration::from_millis(100)).await; - - let messages = messages.lock().unwrap(); - assert_eq!(messages.len(), 5); - - stream.end(); - tokio::time::sleep(std::time::Duration::from_millis(50)).await; - assert!(stream.is_closed()); - - alix_group.send_message("first".as_bytes()).await.unwrap(); - tokio::time::sleep(std::time::Duration::from_millis(100)).await; - - assert_eq!(messages.len(), 5); + assert_eq!( + stream.next().await.unwrap().decrypted_message_bytes, + "first".as_bytes() + ); + assert_eq!( + stream.next().await.unwrap().decrypted_message_bytes, + "second".as_bytes() + ); + assert_eq!( + stream.next().await.unwrap().decrypted_message_bytes, + "third".as_bytes() + ); + assert_eq!( + stream.next().await.unwrap().decrypted_message_bytes, + "fourth".as_bytes() + ); } } diff --git a/xmtp_proto/src/api_client.rs b/xmtp_proto/src/api_client.rs index ca164fda0..f49fa3500 100644 --- a/xmtp_proto/src/api_client.rs +++ b/xmtp_proto/src/api_client.rs @@ -132,7 +132,7 @@ pub type WelcomeMessageStream = Pin