Skip to content

Commit

Permalink
[1/2] stream_all_messages: bindings and streaming messages from an un…
Browse files Browse the repository at this point in the history
…changing group list (#526)

Putting this up early, so that work on propagating this to SDK's can begin in parallel. This implements `stream_all_messages()`, but it doesn't pick up messages from any groups that get created after the method was called.

Some of the setup for picking up messages from *new* groups is in this PR, but that functionality won't be complete until the next PR (along with some cleanup). At that point, only the libxmtp binary will need to be updated - the binding interface won't be any different from the one in this PR.

Other changes:
1. Make a file called `subscriptions.rs` and move `process_streamed_welcome` and `stream_conversations` to it - `client.rs` was becoming a monster file.
2. Move the logic for streaming conversations with a callback from the `stream()` method in the bindings into a `stream_conversations_with_callback` method in xmtp_mls. The logic for running an infinite listener loop in a background thread, as well as closing the stream, will also be needed for the groups stream in `stream_all_messages`.
  • Loading branch information
richardhuaaa authored Feb 21, 2024
1 parent bfdda64 commit e460692
Show file tree
Hide file tree
Showing 6 changed files with 372 additions and 75 deletions.
76 changes: 69 additions & 7 deletions bindings_ffi/src/mls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -255,25 +255,42 @@ impl FfiConversations {
pub async fn stream(
&self,
callback: Box<dyn FfiConversationCallback>,
) -> Result<Arc<FfiStreamCloser>, GenericError> {
let client = self.inner_client.clone();
let stream_closer =
RustXmtpClient::stream_conversations_with_callback(client.clone(), move |convo| {
callback.on_conversation(Arc::new(FfiGroup {
inner_client: client.clone(),
group_id: convo.group_id,
created_at_ns: convo.created_at_ns,
}))
})?;

Ok(Arc::new(FfiStreamCloser {
close_fn: stream_closer.close_fn,
is_closed_atomic: stream_closer.is_closed_atomic,
}))
}

pub async fn stream_all_messages(
&self,
message_callback: Box<dyn FfiMessageCallback>,
) -> Result<Arc<FfiStreamCloser>, GenericError> {
let inner_client = Arc::clone(&self.inner_client);
let (close_sender, close_receiver) = oneshot::channel::<()>();
let is_closed = Arc::new(AtomicBool::new(false));
let is_closed_clone = is_closed.clone();

tokio::spawn(async move {
let client = inner_client.as_ref();
let mut stream = client.stream_conversations().await.unwrap();
let mut stream = RustXmtpClient::stream_all_messages(inner_client)
.await
.unwrap();
let mut close_receiver = close_receiver;
loop {
tokio::select! {
item = stream.next() => {
match item {
Some(convo) => callback.on_conversation(Arc::new(FfiGroup {
inner_client: inner_client.clone(),
group_id: convo.group_id,
created_at_ns: convo.created_at_ns,
})),
Some(message) => message_callback.on_message(message.into()),
None => break
}
}
Expand Down Expand Up @@ -934,6 +951,51 @@ mod tests {
assert!(stream.is_closed());
}

#[tokio::test(flavor = "multi_thread", worker_threads = 5)]
async fn test_stream_all_messages_unchanging_group_list() {
let alix = new_test_client().await;
let bo = new_test_client().await;
let caro = new_test_client().await;

let alix_group = alix
.conversations()
.create_group(vec![caro.account_address()], None)
.await
.unwrap();

tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;

let caro_group = caro
.conversations()
.create_group(vec![bo.account_address()], None)
.await
.unwrap();

tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;

let stream_callback = RustStreamCallback::new();
let stream = caro
.conversations()
.stream_all_messages(Box::new(stream_callback.clone()))
.await
.unwrap();
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;

alix_group.send("first".as_bytes().to_vec()).await.unwrap();
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
caro_group.send("second".as_bytes().to_vec()).await.unwrap();
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
alix_group.send("third".as_bytes().to_vec()).await.unwrap();
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
caro_group.send("fourth".as_bytes().to_vec()).await.unwrap();
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;

assert_eq!(stream_callback.message_count(), 4);
stream.end();
tokio::time::sleep(tokio::time::Duration::from_millis(5)).await;
assert!(stream.is_closed());
}

// Disabling this flakey test until it's reliable
#[tokio::test(flavor = "multi_thread", worker_threads = 5)]
async fn test_streaming() {
Expand Down
73 changes: 6 additions & 67 deletions xmtp_mls/src/client.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use std::{collections::HashSet, mem::Discriminant, pin::Pin};
use std::{collections::HashSet, mem::Discriminant};

use futures::{Stream, StreamExt};
use openmls::{
framing::{MlsMessageIn, MlsMessageInBody},
group::GroupEpoch,
Expand Down Expand Up @@ -69,6 +68,8 @@ pub enum ClientError {
KeyPackageVerification(#[from] KeyPackageVerificationError),
#[error("syncing errors: {0:?}")]
SyncingError(Vec<MessageProcessingError>),
#[error("Stream inconsistency error: {0}")]
StreamInconsistency(String),
#[error("generic:{0}")]
Generic(String),
}
Expand Down Expand Up @@ -468,55 +469,11 @@ where
})
.collect())
}

fn process_streamed_welcome(
&self,
welcome: WelcomeMessage,
) -> Result<MlsGroup<ApiClient>, ClientError> {
let welcome_v1 = extract_welcome_message(welcome)?;
let conn = self.store.conn()?;
let provider = self.mls_provider(&conn);

MlsGroup::create_from_encrypted_welcome(
self,
&provider,
welcome_v1.hpke_public_key.as_slice(),
welcome_v1.data,
)
.map_err(|e| ClientError::Generic(e.to_string()))
}

pub async fn stream_conversations(
&'a self,
) -> Result<Pin<Box<dyn Stream<Item = MlsGroup<ApiClient>> + Send + 'a>>, ClientError> {
let installation_key = self.installation_public_key();
let id_cursor = 0;

let subscription = self
.api_client
.subscribe_welcome_messages(installation_key, Some(id_cursor as u64))
.await?;

let stream = subscription
.map(|welcome_result| async {
let welcome = welcome_result?;
self.process_streamed_welcome(welcome)
})
.filter_map(|res| async {
match res.await {
Ok(group) => Some(group),
Err(err) => {
log::error!("Error processing stream entry: {:?}", err);
None
}
}
});

Ok(Box::pin(stream))
}
}

fn extract_welcome_message(welcome: WelcomeMessage) -> Result<WelcomeMessageV1, ClientError> {
pub(crate) fn extract_welcome_message(
welcome: WelcomeMessage,
) -> Result<WelcomeMessageV1, ClientError> {
match welcome.version {
Some(WelcomeMessageVersion::V1(welcome)) => Ok(welcome),
_ => Err(ClientError::Generic(
Expand Down Expand Up @@ -551,7 +508,6 @@ fn has_active_installation(updates: &Vec<IdentityUpdate>) -> bool {

#[cfg(test)]
mod tests {
use futures::StreamExt;
use xmtp_cryptography::utils::generate_local_wallet;

use crate::{
Expand Down Expand Up @@ -733,21 +689,4 @@ mod tests {
vec![1, 2, 3]
)
}

#[tokio::test]
async fn test_stream_welcomes() {
let alice = ClientBuilder::new_test_client(&generate_local_wallet()).await;
let bob = ClientBuilder::new_test_client(&generate_local_wallet()).await;

let alice_bob_group = alice.create_group(None).unwrap();

let mut bob_stream = bob.stream_conversations().await.unwrap();
alice_bob_group
.add_members(vec![bob.account_address()])
.await
.unwrap();

let bob_received_groups = bob_stream.next().await.unwrap();
assert_eq!(bob_received_groups.group_id, alice_bob_group.group_id);
}
}
17 changes: 17 additions & 0 deletions xmtp_mls/src/groups/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,16 @@ pub struct MlsGroup<'c, ApiClient> {
client: &'c Client<ApiClient>,
}

impl<'c, ApiClient> Clone for MlsGroup<'c, ApiClient> {
fn clone(&self) -> Self {
Self {
client: self.client,
group_id: self.group_id.clone(),
created_at_ns: self.created_at_ns,
}
}
}

impl<'c, ApiClient> MlsGroup<'c, ApiClient>
where
ApiClient: XmtpMlsClient,
Expand Down Expand Up @@ -356,6 +366,13 @@ fn extract_message_v1(message: GroupMessage) -> Result<GroupMessageV1, MessagePr
}
}

pub fn extract_group_id(message: &GroupMessage) -> Result<Vec<u8>, MessageProcessingError> {
match &message.version {
Some(GroupMessageVersion::V1(value)) => Ok(value.group_id.clone()),
_ => Err(MessageProcessingError::InvalidPayload),
}
}

fn validate_ed25519_keys(keys: &[Vec<u8>]) -> Result<(), GroupError> {
let mut invalid = keys
.iter()
Expand Down
2 changes: 1 addition & 1 deletion xmtp_mls/src/groups/subscriptions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ impl<'c, ApiClient> MlsGroup<'c, ApiClient>
where
ApiClient: XmtpMlsClient,
{
async fn process_stream_entry(
pub(crate) async fn process_stream_entry(
&self,
envelope: GroupMessage,
) -> Result<Option<StoredGroupMessage>, GroupError> {
Expand Down
1 change: 1 addition & 0 deletions xmtp_mls/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ pub mod identity;
pub mod owner;
pub mod retry;
pub mod storage;
pub mod subscriptions;
pub mod types;
pub mod utils;
pub mod verified_key_package;
Expand Down
Loading

0 comments on commit e460692

Please sign in to comment.