Skip to content

Commit

Permalink
Revert auto-merged PR "[2/n] stream_all_messages that detects message…
Browse files Browse the repository at this point in the history
…s in new groups (#533)"

This reverts commit ac22179.
  • Loading branch information
richardhuaaa committed Feb 26, 2024
1 parent ac22179 commit 3f9ef58
Show file tree
Hide file tree
Showing 5 changed files with 202 additions and 475 deletions.
176 changes: 77 additions & 99 deletions bindings_ffi/src/mls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@ pub use crate::inbox_owner::SigningError;
use crate::logger::init_logger;
use crate::logger::FfiLogger;
use crate::GenericError;
use futures::StreamExt;
use std::convert::TryInto;
use std::sync::{
atomic::{AtomicBool, Ordering},
Arc, Mutex,
};
use tokio::sync::oneshot::Sender;
use tokio::sync::{oneshot, oneshot::Sender};
use xmtp_api_grpc::grpc_api_helper::Client as TonicApiClient;
use xmtp_mls::builder::IdentityStrategy;
use xmtp_mls::builder::LegacyIdentity;
Expand Down Expand Up @@ -256,17 +257,14 @@ impl FfiConversations {
callback: Box<dyn FfiConversationCallback>,
) -> Result<Arc<FfiStreamCloser>, GenericError> {
let client = self.inner_client.clone();
let stream_closer = RustXmtpClient::stream_conversations_with_callback(
client.clone(),
move |convo| {
let stream_closer =
RustXmtpClient::stream_conversations_with_callback(client.clone(), move |convo| {
callback.on_conversation(Arc::new(FfiGroup {
inner_client: client.clone(),
group_id: convo.group_id,
created_at_ns: convo.created_at_ns,
}))
},
|| {}, // on_close_callback
)?;
})?;

Ok(Arc::new(FfiStreamCloser {
close_fn: stream_closer.close_fn,
Expand All @@ -278,15 +276,36 @@ impl FfiConversations {
&self,
message_callback: Box<dyn FfiMessageCallback>,
) -> Result<Arc<FfiStreamCloser>, GenericError> {
let stream_closer = RustXmtpClient::stream_all_messages_with_callback(
self.inner_client.clone(),
move |message| message_callback.on_message(message.into()),
)
.await?;
let inner_client = Arc::clone(&self.inner_client);
let (close_sender, close_receiver) = oneshot::channel::<()>();
let is_closed = Arc::new(AtomicBool::new(false));
let is_closed_clone = is_closed.clone();

tokio::spawn(async move {
let mut stream = RustXmtpClient::stream_all_messages(inner_client)
.await
.unwrap();
let mut close_receiver = close_receiver;
loop {
tokio::select! {
item = stream.next() => {
match item {
Some(message) => message_callback.on_message(message.into()),
None => break
}
}
_ = &mut close_receiver => {
break;
}
}
}
is_closed_clone.store(true, Ordering::Relaxed);
log::info!("closing stream");
});

Ok(Arc::new(FfiStreamCloser {
close_fn: stream_closer.close_fn,
is_closed_atomic: stream_closer.is_closed_atomic,
close_fn: Arc::new(Mutex::new(Some(close_sender))),
is_closed_atomic: is_closed,
}))
}
}
Expand Down Expand Up @@ -406,17 +425,37 @@ impl FfiGroup {
message_callback: Box<dyn FfiMessageCallback>,
) -> Result<Arc<FfiStreamCloser>, GenericError> {
let inner_client = Arc::clone(&self.inner_client);
let stream_closer = MlsGroup::stream_with_callback(
inner_client,
self.group_id.clone(),
self.created_at_ns,
move |message| message_callback.on_message(message.into()),
)
.await?;
let group_id = self.group_id.clone();
let created_at_ns = self.created_at_ns;
let (close_sender, close_receiver) = oneshot::channel::<()>();
let is_closed = Arc::new(AtomicBool::new(false));
let is_closed_clone = is_closed.clone();

tokio::spawn(async move {
let client = inner_client.as_ref();
let group = MlsGroup::new(&client, group_id, created_at_ns);
let mut stream = group.stream().await.unwrap();
let mut close_receiver = close_receiver;
loop {
tokio::select! {
item = stream.next() => {
match item {
Some(message) => message_callback.on_message(message.into()),
None => break
}
}
_ = &mut close_receiver => {
break;
}
}
}
is_closed_clone.store(true, Ordering::Relaxed);
log::info!("closing stream");
});

Ok(Arc::new(FfiStreamCloser {
close_fn: stream_closer.close_fn,
is_closed_atomic: stream_closer.is_closed_atomic,
close_fn: Arc::new(Mutex::new(Some(close_sender))),
is_closed_atomic: is_closed,
}))
}

Expand Down Expand Up @@ -605,9 +644,7 @@ mod tests {
}

impl FfiMessageCallback for RustStreamCallback {
fn on_message(&self, message: FfiMessage) {
let message = String::from_utf8(message.content).unwrap_or("<not UTF8>".to_string());
log::info!("Received: {}", message);
fn on_message(&self, _: FfiMessage) {
*self.num_messages.lock().unwrap() += 1;
}
}
Expand Down Expand Up @@ -915,7 +952,7 @@ mod tests {
}

#[tokio::test(flavor = "multi_thread", worker_threads = 5)]
async fn test_stream_all_messages() {
async fn test_stream_all_messages_unchanging_group_list() {
let alix = new_test_client().await;
let bo = new_test_client().await;
let caro = new_test_client().await;
Expand All @@ -925,6 +962,15 @@ mod tests {
.create_group(vec![caro.account_address()], None)
.await
.unwrap();

tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;

let caro_group = caro
.conversations()
.create_group(vec![bo.account_address()], None)
.await
.unwrap();

tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;

let stream_callback = RustStreamCallback::new();
Expand All @@ -937,17 +983,11 @@ mod tests {

alix_group.send("first".as_bytes().to_vec()).await.unwrap();
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
let bo_group = bo
.conversations()
.create_group(vec![caro.account_address()], None)
.await
.unwrap();
tokio::time::sleep(tokio::time::Duration::from_millis(200)).await;
bo_group.send("second".as_bytes().to_vec()).await.unwrap();
caro_group.send("second".as_bytes().to_vec()).await.unwrap();
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
alix_group.send("third".as_bytes().to_vec()).await.unwrap();
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
bo_group.send("fourth".as_bytes().to_vec()).await.unwrap();
caro_group.send("fourth".as_bytes().to_vec()).await.unwrap();
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;

assert_eq!(stream_callback.message_count(), 4);
Expand All @@ -956,8 +996,9 @@ mod tests {
assert!(stream.is_closed());
}

// Disabling this flakey test until it's reliable
#[tokio::test(flavor = "multi_thread", worker_threads = 5)]
async fn test_message_streaming() {
async fn test_streaming() {
let amal = new_test_client().await;
let bola = new_test_client().await;

Expand Down Expand Up @@ -985,67 +1026,4 @@ mod tests {

stream_closer.end();
}

#[tokio::test(flavor = "multi_thread", worker_threads = 5)]
async fn test_message_streaming_when_removed_then_added() {
let amal = new_test_client().await;
let bola = new_test_client().await;
log::info!(
"Created addresses {} and {}",
amal.account_address(),
bola.account_address()
);

let amal_group = amal
.conversations()
.create_group(vec![bola.account_address()], None)
.await
.unwrap();
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;

let stream_callback = RustStreamCallback::new();
let stream_closer = bola
.conversations()
.stream_all_messages(Box::new(stream_callback.clone()))
.await
.unwrap();

tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;

amal_group.send("hello1".as_bytes().to_vec()).await.unwrap();
amal_group.send("hello2".as_bytes().to_vec()).await.unwrap();

tokio::time::sleep(tokio::time::Duration::from_millis(200)).await;
assert_eq!(stream_callback.message_count(), 2);
assert!(!stream_closer.is_closed());

amal_group
.remove_members(vec![bola.account_address()])
.await
.unwrap();
tokio::time::sleep(std::time::Duration::from_millis(60000)).await;
assert_eq!(stream_callback.message_count(), 3); // Member removal transcript message

amal_group.send("hello3".as_bytes().to_vec()).await.unwrap();
tokio::time::sleep(tokio::time::Duration::from_millis(200)).await;
assert_eq!(stream_callback.message_count(), 3); // Don't receive messages while removed
assert!(!stream_closer.is_closed());

tokio::time::sleep(tokio::time::Duration::from_millis(200)).await;
amal_group
.add_members(vec![bola.account_address()])
.await
.unwrap();
tokio::time::sleep(std::time::Duration::from_millis(500)).await;
assert_eq!(stream_callback.message_count(), 3); // Don't receive transcript messages while removed

amal_group.send("hello4".as_bytes().to_vec()).await.unwrap();
tokio::time::sleep(tokio::time::Duration::from_millis(200)).await;
assert_eq!(stream_callback.message_count(), 4); // Receiving messages again
assert!(!stream_closer.is_closed());

stream_closer.end();
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
assert!(stream_closer.is_closed());
}
}
77 changes: 31 additions & 46 deletions xmtp_mls/src/groups/subscriptions.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,12 @@
use std::collections::HashMap;
use std::pin::Pin;
use std::sync::Arc;

use futures::Stream;
use futures::{Stream, StreamExt};

use xmtp_proto::{api_client::XmtpMlsClient, xmtp::mls::api::v1::GroupMessage};

use super::{extract_message_v1, GroupError, MlsGroup};
use crate::api_client_wrapper::GroupFilter;
use crate::storage::group_message::StoredGroupMessage;
use crate::subscriptions::{MessagesStreamInfo, StreamCloser};
use crate::Client;

impl<'c, ApiClient> MlsGroup<'c, ApiClient>
where
Expand All @@ -21,73 +18,61 @@ where
) -> Result<Option<StoredGroupMessage>, GroupError> {
let msgv1 = extract_message_v1(envelope)?;

log::info!("Running transaction {}", self.client.account_address());
let process_result = self.client.store.transaction(|provider| {
let mut openmls_group = self.load_mls_group(&provider)?;
// Attempt processing immediately, but fail if the message is not an Application Message
// Returning an error should roll back the DB tx
log::info!(
"Processing message in process_stream_entry {}",
self.client.account_address()
);
let res = self
.process_message(&mut openmls_group, &provider, &msgv1, false)
.map_err(GroupError::ReceiveError);
log::info!("Got process message result {:?}", res);
res
self.process_message(&mut openmls_group, &provider, &msgv1, false)
.map_err(GroupError::ReceiveError)
});
log::info!("Got process_result {:?}", process_result);

if let Some(GroupError::ReceiveError(_)) = process_result.err() {
log::info!("Re-syncing due to unreadable messaging stream payload");
self.sync().await?;
}

log::info!("Storing message");
// Load the message from the DB to handle cases where it may have been already processed in
// another thread
let new_message = self
.client
.store
.conn()?
.get_group_message_by_timestamp(&self.group_id, msgv1.created_ns as i64)?;
log::info!("Stored message");

Ok(new_message)
}

pub async fn stream(
&'c self,
) -> Result<Pin<Box<dyn Stream<Item = StoredGroupMessage> + 'c + Send>>, GroupError> {
Ok(self
let last_cursor = 0;

let subscription = self
.client
.stream_messages(HashMap::from([(
.api_client
.subscribe_group_messages(vec![GroupFilter::new(
self.group_id.clone(),
MessagesStreamInfo {
convo_created_at_ns: self.created_at_ns,
cursor: 0,
},
)]))
.await?)
}

pub async fn stream_with_callback(
client: Arc<Client<ApiClient>>,
group_id: Vec<u8>,
created_at_ns: i64,
mut callback: impl FnMut(StoredGroupMessage) + Send + 'static,
) -> Result<StreamCloser, GroupError> {
Ok(Client::<ApiClient>::stream_messages_with_callback(
client,
HashMap::from([(
group_id,
MessagesStreamInfo {
convo_created_at_ns: created_at_ns,
cursor: 0,
},
)]),
move |message| callback(message),
)?)
Some(last_cursor as u64),
)])
.await?;
let stream = subscription
.map(|res| async {
match res {
Ok(envelope) => self.process_stream_entry(envelope).await,
Err(err) => Err(GroupError::Api(err)),
}
})
.filter_map(move |res| async {
match res.await {
Ok(Some(message)) => Some(message),
Ok(None) => None,
Err(err) => {
log::error!("Error processing stream entry: {:?}", err);
None
}
}
});

Ok(Box::pin(stream))
}
}

Expand Down
Loading

0 comments on commit 3f9ef58

Please sign in to comment.