From fb4e7dd4b1801429a8ac5c4b8ed9aea5efd19cab Mon Sep 17 00:00:00 2001 From: Andrew Plaza Date: Fri, 20 Dec 2024 12:43:32 -0500 Subject: [PATCH] feat(wasm): unblock streams in the browser --- Cargo.lock | 4 +- Cargo.toml | 23 +- common/src/test.rs | 4 + xmtp_api_grpc/Cargo.toml | 2 +- xmtp_api_http/Cargo.toml | 5 +- xmtp_api_http/src/http_stream.rs | 231 +++++++++++++++++ xmtp_api_http/src/lib.rs | 4 +- xmtp_api_http/src/util.rs | 83 +----- xmtp_mls/Cargo.toml | 3 +- xmtp_mls/src/api/mls.rs | 12 +- xmtp_mls/src/storage/encrypted_store/group.rs | 53 ++++ .../mod.rs} | 189 +++++--------- xmtp_mls/src/subscriptions/stream_all.rs | 87 +++++++ .../src/subscriptions/stream_conversations.rs | 241 ++++++++++++++++++ 14 files changed, 714 insertions(+), 227 deletions(-) create mode 100644 xmtp_api_http/src/http_stream.rs rename xmtp_mls/src/{subscriptions.rs => subscriptions/mod.rs} (88%) create mode 100644 xmtp_mls/src/subscriptions/stream_all.rs create mode 100644 xmtp_mls/src/subscriptions/stream_conversations.rs diff --git a/Cargo.lock b/Cargo.lock index fce55ab9e..7ac00d962 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -7305,9 +7305,10 @@ dependencies = [ name = "xmtp_api_http" version = "0.1.0" dependencies = [ - "async-stream", "async-trait", + "bytes", "futures", + "pin-project-lite", "reqwest 0.12.9", "serde", "serde_json", @@ -7483,6 +7484,7 @@ dependencies = [ "openssl", "openssl-sys", "parking_lot 0.12.3", + "pin-project-lite", "prost", "rand", "reqwest 0.12.9", diff --git a/Cargo.toml b/Cargo.toml index 77828087e..88b92f316 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -37,8 +37,7 @@ ctor = "0.2" ed25519 = "2.2.3" ed25519-dalek = { version = "2.1.1", features = ["zeroize"] } ethers = { version = "2.0", default-features = false } -futures = "0.3.30" -futures-core = "0.3.30" +futures = { version = "0.3.30", default-features = false } getrandom = { version = "0.2", default-features = false } hex = "0.4.3" hkdf = "0.12.3" @@ -62,16 +61,7 @@ tls_codec = "0.4.1" tokio = { version = "1.35.1", default-features = false } uuid = "1.10" vergen-git2 = "1.0.2" -wasm-timer = "0.2" web-time = "1.1" -# Changing this version and rustls may potentially break the android build. Use Caution. -# Test with Android and Swift first. -# Its probably preferable to one day use https://github.com/rustls/rustls-platform-verifier -# Until then, always test agains iOS/Android after updating these dependencies & making a PR -# Related Issues: -# - https://github.com/seanmonstar/reqwest/issues/2159 -# - https://github.com/hyperium/tonic/pull/1974 -# - https://github.com/rustls/rustls-platform-verifier/issues/58 bincode = "1.3" console_error_panic_hook = "0.1" const_format = "0.2" @@ -88,6 +78,14 @@ openssl = { version = "0.10", features = ["vendored"] } openssl-sys = { version = "0.9", features = ["vendored"] } parking_lot = "0.12.3" sqlite-web = "0.0.1" +# Changing this version and rustls may potentially break the android build. Use Caution. +# Test with Android and Swift first. +# Its probably preferable to one day use https://github.com/rustls/rustls-platform-verifier +# Until then, always test agains iOS/Android after updating these dependencies & making a PR +# Related Issues: +# - https://github.com/seanmonstar/reqwest/issues/2159 +# - https://github.com/hyperium/tonic/pull/1974 +# - https://github.com/rustls/rustls-platform-verifier/issues/58 tonic = { version = "0.12", default-features = false } tracing = { version = "0.1", features = ["log"] } tracing-subscriber = { version = "0.3", default-features = false } @@ -102,7 +100,8 @@ criterion = { version = "0.5", features = [ "html_reports", "async_tokio", ]} - once_cell = "1.2" +once_cell = "1.2" +pin-project-lite = "0.2" # Internal Crate Dependencies xmtp_api_grpc = { path = "xmtp_api_grpc" } diff --git a/common/src/test.rs b/common/src/test.rs index 4cfb2442d..e8ae377a6 100644 --- a/common/src/test.rs +++ b/common/src/test.rs @@ -108,6 +108,10 @@ pub fn rand_u64() -> u64 { crypto_utils::rng().gen() } +pub fn rand_i64() -> i64 { + crypto_utils::rng().gen() +} + #[cfg(not(target_arch = "wasm32"))] pub fn tmp_path() -> String { let db_name = crate::rand_string::<24>(); diff --git a/xmtp_api_grpc/Cargo.toml b/xmtp_api_grpc/Cargo.toml index b69a0d0c4..67ea6fb16 100644 --- a/xmtp_api_grpc/Cargo.toml +++ b/xmtp_api_grpc/Cargo.toml @@ -8,7 +8,7 @@ version.workspace = true async-stream.workspace = true async-trait = "0.1" base64.workspace = true -futures.workspace = true +futures = { workspace = true, features = ["alloc"] } hex.workspace = true prost = { workspace = true, features = ["prost-derive"] } tokio = { workspace = true, features = ["macros", "time"] } diff --git a/xmtp_api_http/Cargo.toml b/xmtp_api_http/Cargo.toml index b26a414a9..09a6a9214 100644 --- a/xmtp_api_http/Cargo.toml +++ b/xmtp_api_http/Cargo.toml @@ -8,16 +8,17 @@ license.workspace = true crate-type = ["cdylib", "rlib"] [dependencies] -async-stream.workspace = true futures = { workspace = true } tracing.workspace = true reqwest = { version = "0.12.5", features = ["json", "stream"] } serde = { workspace = true } serde_json = { workspace = true } -thiserror = "2.0" +thiserror.workspace = true tokio = { workspace = true, features = ["sync", "rt", "macros"] } xmtp_proto = { path = "../xmtp_proto", features = ["proto_full"] } async-trait = "0.1" +bytes = "1.9" +pin-project-lite = "0.2.15" [dev-dependencies] xmtp_proto = { path = "../xmtp_proto", features = ["test-utils"] } diff --git a/xmtp_api_http/src/http_stream.rs b/xmtp_api_http/src/http_stream.rs new file mode 100644 index 000000000..8e969b0c4 --- /dev/null +++ b/xmtp_api_http/src/http_stream.rs @@ -0,0 +1,231 @@ +//! Streams that work with HTTP POST requests + +use crate::util::GrpcResponse; +use futures::{ + stream::{self, Stream, StreamExt}, + Future, +}; +use reqwest::Response; +use serde::{de::DeserializeOwned, Deserialize, Serialize}; +use serde_json::Deserializer; +use std::{marker::PhantomData, pin::Pin, task::Poll}; +use xmtp_proto::{Error, ErrorKind}; + +#[derive(Deserialize, Serialize, Debug)] +pub(crate) struct SubscriptionItem { + pub result: T, +} + +#[cfg(target_arch = "wasm32")] +pub type BytesStream = stream::LocalBoxStream<'static, Result>; + +// #[cfg(not(target_arch = "wasm32"))] +// pub type BytesStream = Pin> + Send>>; + +#[cfg(not(target_arch = "wasm32"))] +pub type BytesStream = stream::BoxStream<'static, Result>; + +pin_project_lite::pin_project! { + #[project = PostStreamProject] + enum HttpPostStream { + NotStarted{#[pin] fut: F}, + // `Reqwest::bytes_stream` returns `impl Stream` rather than a type generic, + // so we can't use a type generic here + // this makes wasm a bit tricky. + Started { + #[pin] http: BytesStream, + remaining: Vec, + _marker: PhantomData, + }, + } +} + +impl Stream for HttpPostStream +where + F: Future>, + for<'de> R: Send + Deserialize<'de>, +{ + type Item = Result; + + fn poll_next( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + use std::task::Poll::*; + match self.as_mut().project() { + PostStreamProject::NotStarted { fut } => match fut.poll(cx) { + Ready(response) => { + let s = response.unwrap().bytes_stream(); + self.set(Self::started(s)); + self.as_mut().poll_next(cx) + } + Pending => { + cx.waker().wake_by_ref(); + Pending + } + }, + PostStreamProject::Started { + ref mut http, + ref mut remaining, + .. + } => { + let mut pinned = std::pin::pin!(http); + let next = pinned.as_mut().poll_next(cx); + Self::on_bytes(next, remaining, cx) + } + } + } +} + +impl HttpPostStream +where + R: Send, +{ + #[cfg(not(target_arch = "wasm32"))] + fn started( + http: impl Stream> + Send + 'static, + ) -> Self { + Self::Started { + http: http.boxed(), + remaining: Vec::new(), + _marker: PhantomData, + } + } + + #[cfg(target_arch = "wasm32")] + fn started(http: impl Stream> + 'static) -> Self { + Self::Started { + http: http.boxed_local(), + remaining: Vec::new(), + _marker: PhantomData, + } + } +} + +impl HttpPostStream +where + F: Future>, + for<'de> R: Deserialize<'de> + DeserializeOwned + Send, +{ + fn new(request: F) -> Self { + Self::NotStarted { fut: request } + } + + fn on_bytes( + p: Poll>>, + remaining: &mut Vec, + cx: &mut std::task::Context<'_>, + ) -> Poll::Item>> { + use futures::task::Poll::*; + match p { + Ready(Some(bytes)) => { + let bytes = bytes.map_err(|e| { + Error::new(ErrorKind::SubscriptionUpdateError).with(e.to_string()) + })?; + let bytes = &[remaining.as_ref(), bytes.as_ref()].concat(); + let de = Deserializer::from_slice(bytes); + let mut stream = de.into_iter::>(); + 'messages: loop { + tracing::debug!("Waiting on next response ..."); + let response = stream.next(); + let res = match response { + Some(Ok(GrpcResponse::Ok(response))) => Ok(response), + Some(Ok(GrpcResponse::SubscriptionItem(item))) => Ok(item.result), + Some(Ok(GrpcResponse::Err(e))) => { + Err(Error::new(ErrorKind::MlsError).with(e.message)) + } + Some(Err(e)) => { + if e.is_eof() { + *remaining = (&**bytes)[stream.byte_offset()..].to_vec(); + return Pending; + } else { + Err(Error::new(ErrorKind::MlsError).with(e.to_string())) + } + } + Some(Ok(GrpcResponse::Empty {})) => continue 'messages, + None => return Ready(None), + }; + return Ready(Some(res)); + } + } + Ready(None) => Ready(None), + Pending => { + cx.waker().wake_by_ref(); + Pending + } + } + } +} + +#[cfg(not(target_arch = "wasm32"))] +impl HttpPostStream +where + F: Future> + Unpin, + for<'de> R: Deserialize<'de> + DeserializeOwned + Send, +{ + /// Establish the initial HTTP Stream connection + fn establish(&mut self) -> () { + // we need to poll the future once to progress the future state & + // establish the initial POST request. + // It should always be pending + let noop_waker = futures::task::noop_waker(); + let mut cx = std::task::Context::from_waker(&noop_waker); + // let mut this = Pin::new(self); + let mut this = Pin::new(self); + let _ = this.poll_next_unpin(&mut cx); + } +} + +#[cfg(target_arch = "wasm32")] +impl HttpPostStream +where + F: Future>, + for<'de> R: Deserialize<'de> + DeserializeOwned + Send, +{ + fn establish(&mut self) -> () { + // we need to poll the future once to progress the future state & + // establish the initial POST request. + // It should always be pending + let noop_waker = futures::task::noop_waker(); + let mut cx = std::task::Context::from_waker(&noop_waker); + let mut this = unsafe { Pin::new_unchecked(self) }; + let _ = this.poll_next_unpin(&mut cx); + } +} + +#[cfg(target_arch = "wasm32")] +pub fn create_grpc_stream( + request: T, + endpoint: String, + http_client: reqwest::Client, +) -> stream::LocalBoxStream<'static, Result> { + create_grpc_stream_inner(request, endpoint, http_client).boxed_local() +} + +#[cfg(not(target_arch = "wasm32"))] +pub fn create_grpc_stream( + request: T, + endpoint: String, + http_client: reqwest::Client, +) -> stream::BoxStream<'static, Result> +where + T: Serialize + 'static, + R: DeserializeOwned + Send + 'static, +{ + create_grpc_stream_inner(request, endpoint, http_client).boxed() +} + +fn create_grpc_stream_inner( + request: T, + endpoint: String, + http_client: reqwest::Client, +) -> impl Stream> +where + T: Serialize + 'static, + R: DeserializeOwned + Send + 'static, +{ + let request = http_client.post(endpoint).json(&request).send(); + let mut http = HttpPostStream::new(request); + http.establish(); + http +} diff --git a/xmtp_api_http/src/lib.rs b/xmtp_api_http/src/lib.rs index 80489fb3c..8a3f972c4 100755 --- a/xmtp_api_http/src/lib.rs +++ b/xmtp_api_http/src/lib.rs @@ -1,11 +1,13 @@ #![warn(clippy::unwrap_used)] pub mod constants; +mod http_stream; mod util; use futures::stream; +use http_stream::create_grpc_stream; use reqwest::header; -use util::{create_grpc_stream, handle_error}; +use util::handle_error; use xmtp_proto::api_client::{ClientWithMetadata, XmtpIdentityClient}; use xmtp_proto::xmtp::identity::api::v1::{ GetIdentityUpdatesRequest as GetIdentityUpdatesV2Request, diff --git a/xmtp_api_http/src/util.rs b/xmtp_api_http/src/util.rs index 8a839fc56..34c878c4a 100644 --- a/xmtp_api_http/src/util.rs +++ b/xmtp_api_http/src/util.rs @@ -1,9 +1,5 @@ -use futures::{ - stream::{self, StreamExt}, - Stream, -}; +use crate::http_stream::SubscriptionItem; use serde::{de::DeserializeOwned, Deserialize, Serialize}; -use serde_json::Deserializer; use std::io::Read; use xmtp_proto::{Error, ErrorKind}; @@ -23,11 +19,6 @@ pub(crate) struct ErrorResponse { details: Vec, } -#[derive(Deserialize, Serialize, Debug)] -pub(crate) struct SubscriptionItem { - pub result: T, -} - /// handle JSON response from gRPC, returning either /// the expected deserialized response object or a gRPC [`Error`] pub fn handle_error(reader: R) -> Result @@ -43,78 +34,6 @@ where } } -#[cfg(target_arch = "wasm32")] -pub fn create_grpc_stream< - T: Serialize + Send + 'static, - R: DeserializeOwned + Send + std::fmt::Debug + 'static, ->( - request: T, - endpoint: String, - http_client: reqwest::Client, -) -> stream::LocalBoxStream<'static, Result> { - create_grpc_stream_inner(request, endpoint, http_client).boxed_local() -} - -#[cfg(not(target_arch = "wasm32"))] -pub fn create_grpc_stream< - T: Serialize + Send + 'static, - R: DeserializeOwned + Send + std::fmt::Debug + 'static, ->( - request: T, - endpoint: String, - http_client: reqwest::Client, -) -> stream::BoxStream<'static, Result> { - create_grpc_stream_inner(request, endpoint, http_client).boxed() -} - -pub fn create_grpc_stream_inner< - T: Serialize + Send + 'static, - R: DeserializeOwned + Send + std::fmt::Debug + 'static, ->( - request: T, - endpoint: String, - http_client: reqwest::Client, -) -> impl Stream> { - async_stream::stream! { - let request = http_client - .post(endpoint) - .json(&request) - .send() - .await - .map_err(|e| Error::new(ErrorKind::MlsError).with(e))?; - - let mut remaining = vec![]; - for await bytes in request.bytes_stream() { - let bytes = bytes - .map_err(|e| Error::new(ErrorKind::SubscriptionUpdateError).with(e.to_string()))?; - let bytes = &[remaining.as_ref(), bytes.as_ref()].concat(); - let de = Deserializer::from_slice(bytes); - let mut stream = de.into_iter::>(); - 'messages: loop { - let response = stream.next(); - let res = match response { - Some(Ok(GrpcResponse::Ok(response))) => Ok(response), - Some(Ok(GrpcResponse::SubscriptionItem(item))) => Ok(item.result), - Some(Ok(GrpcResponse::Err(e))) => { - Err(Error::new(ErrorKind::MlsError).with(e.message)) - } - Some(Err(e)) => { - if e.is_eof() { - remaining = (&**bytes)[stream.byte_offset()..].to_vec(); - break 'messages; - } else { - Err(Error::new(ErrorKind::MlsError).with(e.to_string())) - } - } - Some(Ok(GrpcResponse::Empty {})) => continue 'messages, - None => break 'messages, - }; - yield res; - } - } - } -} - #[cfg(feature = "test-utils")] #[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)] #[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))] diff --git a/xmtp_mls/Cargo.toml b/xmtp_mls/Cargo.toml index f933391fd..506e0b698 100644 --- a/xmtp_mls/Cargo.toml +++ b/xmtp_mls/Cargo.toml @@ -49,7 +49,7 @@ async-stream.workspace = true async-trait.workspace = true bincode.workspace = true diesel_migrations.workspace = true -futures.workspace = true +futures = { workspace = true, features = ["alloc"] } hex.workspace = true hkdf.workspace = true openmls_rust_crypto = { workspace = true } @@ -70,6 +70,7 @@ tracing.workspace = true trait-variant.workspace = true xmtp_common.workspace = true zeroize.workspace = true +pin-project-lite.workspace = true # XMTP/Local xmtp_content_types = { path = "../xmtp_content_types" } diff --git a/xmtp_mls/src/api/mls.rs b/xmtp_mls/src/api/mls.rs index 3994cd8fa..86b206121 100644 --- a/xmtp_mls/src/api/mls.rs +++ b/xmtp_mls/src/api/mls.rs @@ -274,10 +274,10 @@ where Ok(()) } - pub async fn subscribe_group_messages( - &self, + pub(crate) async fn subscribe_group_messages<'a>( + &'a self, filters: Vec, - ) -> Result> + '_, ApiError> + ) -> Result<::GroupMessageStream<'a>, ApiError> where ApiClient: XmtpMlsStreams, { @@ -289,11 +289,11 @@ where .await } - pub async fn subscribe_welcome_messages( - &self, + pub(crate) async fn subscribe_welcome_messages<'a>( + &'a self, installation_key: &[u8], id_cursor: Option, - ) -> Result> + '_, ApiError> + ) -> Result<::WelcomeMessageStream<'a>, ApiError> where ApiClient: XmtpMlsStreams, { diff --git a/xmtp_mls/src/storage/encrypted_store/group.rs b/xmtp_mls/src/storage/encrypted_store/group.rs index f8307ca3d..02b51e8ca 100644 --- a/xmtp_mls/src/storage/encrypted_store/group.rs +++ b/xmtp_mls/src/storage/encrypted_store/group.rs @@ -478,6 +478,22 @@ impl DbConnection { Ok(stored_group) } + + /// Get all the welcome ids turned into groups + pub(crate) fn group_welcome_ids(&self) -> Result, StorageError> { + self.raw_query(|conn| { + Ok::<_, StorageError>( + dsl::groups + .filter(dsl::welcome_id.is_not_null()) + .select(dsl::welcome_id) + .load::>(conn)? + .into_iter() + .map(|id| id.expect("SQL explicity filters for none")) + .collect(), + ) + }) + .map_err(Into::into) + } } #[repr(i32)] @@ -619,6 +635,25 @@ pub(crate) mod tests { ) } + /// Generate a test group with welcome + pub fn generate_group_with_welcome( + state: Option, + welcome_id: Option, + ) -> StoredGroup { + let id = rand_vec::<24>(); + let created_at_ns = now_ns(); + let membership_state = state.unwrap_or(GroupMembershipState::Allowed); + StoredGroup::new_from_welcome( + id, + created_at_ns, + membership_state, + "placeholder_address".to_string(), + welcome_id.unwrap_or(xmtp_common::rand_i64()), + ConversationType::Group, + None, + ) + } + /// Generate a test consent fn generate_consent_record( entity_type: ConsentType, @@ -952,4 +987,22 @@ pub(crate) mod tests { }) .await } + + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[cfg_attr(not(target_arch = "wasm32"), tokio::test)] + async fn test_get_group_welcome_ids() { + with_connection(|conn| { + let mls_groups = vec![ + generate_group_with_welcome(None, Some(30)), + generate_group(None), + generate_group(None), + generate_group_with_welcome(None, Some(10)), + ]; + for g in mls_groups.iter() { + g.store(conn).unwrap(); + } + assert_eq!(vec![30, 10], conn.group_welcome_ids().unwrap()); + }) + .await + } } diff --git a/xmtp_mls/src/subscriptions.rs b/xmtp_mls/src/subscriptions/mod.rs similarity index 88% rename from xmtp_mls/src/subscriptions.rs rename to xmtp_mls/src/subscriptions/mod.rs index 97f538504..e9b8dcad0 100644 --- a/xmtp_mls/src/subscriptions.rs +++ b/xmtp_mls/src/subscriptions/mod.rs @@ -10,6 +10,9 @@ use tracing::instrument; use xmtp_id::scw_verifier::SmartContractSignatureVerifier; use xmtp_proto::{api_client::XmtpMlsStreams, xmtp::mls::api::v1::WelcomeMessage}; +// mod stream_all; +// mod stream_conversations; + use crate::{ client::{extract_welcome_message, ClientError}, groups::{ @@ -454,9 +457,9 @@ where futures::pin_mut!(messages_stream); let convo_stream = self.stream_conversations(conversation_type).await?; - futures::pin_mut!(convo_stream); + tracing::info!("\n\n Waiting on messages \n\n"); let mut extra_messages = Vec::new(); loop { @@ -609,6 +612,45 @@ pub(crate) mod tests { use xmtp_cryptography::utils::generate_local_wallet; use xmtp_id::InboxOwner; + /// A macro for asserting that a stream yields a specific decrypted message. + /// + /// # Example + /// ```rust + /// assert_msg!(stream, b"first"); + /// ``` + #[macro_export] + macro_rules! assert_msg { + ($stream:expr, $expected:expr) => { + assert_eq!( + $stream + .next() + .await + .unwrap() + .unwrap() + .decrypted_message_bytes, + $expected.as_bytes() + ); + }; + } + + /// A macro for asserting that a stream yields a specific decrypted message. + /// + /// # Example + /// ```rust + /// assert_msg!(stream, b"first"); + /// ``` + #[macro_export] + macro_rules! assert_msg_exists { + ($stream:expr) => { + assert!(!$stream + .next() + .await + .unwrap() + .unwrap() + .decrypted_message_bytes + .is_empty()); + }; + } #[wasm_bindgen_test(unsupported = tokio::test(flavor = "multi_thread", worker_threads = 10))] async fn test_stream_welcomes() { let alice = Arc::new(ClientBuilder::new_test_client(&generate_local_wallet()).await); @@ -617,23 +659,8 @@ pub(crate) mod tests { .create_group(None, GroupMetadataOptions::default()) .unwrap(); - // FIXME:insipx we run into an issue where the reqwest::post().send() request - // blocks the executor and we cannot progress the runtime if we dont `tokio::spawn` this. - // A solution might be to use `hyper` instead, and implement a custom connection pool with - // `deadpool`. This is a bit more work but shouldn't be too complicated since - // we're only using `post` requests. It would be nice for all streams to work - // w/o spawning a separate task. - let (tx, rx) = tokio::sync::mpsc::unbounded_channel(); - let mut stream = tokio_stream::wrappers::UnboundedReceiverStream::new(rx); - let bob_ptr = bob.clone(); - crate::spawn(None, async move { - let bob_stream = bob_ptr.stream_conversations(None).await.unwrap(); - futures::pin_mut!(bob_stream); - while let Some(item) = bob_stream.next().await { - let _ = tx.send(item); - } - }); - + let stream = bob.stream_conversations(None).await.unwrap(); + futures::pin_mut!(stream); let group_id = alice_bob_group.group_id.clone(); alice_bob_group .add_members_by_inbox_id(&[bob.inbox_id()]) @@ -644,7 +671,7 @@ pub(crate) mod tests { assert_eq!(bob_received_groups.group_id, group_id); } - #[wasm_bindgen_test(unsupported = tokio::test(flavor = "multi_thread", worker_threads = 10))] + #[wasm_bindgen_test(unsupported = tokio::test(flavor = "current_thread"))] async fn test_stream_messages() { xmtp_common::logger(); let alice = Arc::new(ClientBuilder::new_test_client(&generate_local_wallet()).await); @@ -653,6 +680,7 @@ pub(crate) mod tests { let alice_group = alice .create_group(None, GroupMetadataOptions::default()) .unwrap(); + tracing::info!("Group Id = [{}]", hex::encode(&alice_group.group_id)); alice_group .add_members_by_inbox_id(&[bob.inbox_id()]) @@ -664,33 +692,16 @@ pub(crate) mod tests { .unwrap(); let bob_group = bob_groups.first().unwrap(); - let notify = Delivery::new(None); - let notify_ptr = notify.clone(); - let (tx, rx) = tokio::sync::mpsc::unbounded_channel(); - crate::spawn(None, async move { - let stream = alice_group.stream().await.unwrap(); - futures::pin_mut!(stream); - while let Some(item) = stream.next().await { - let _ = tx.send(item); - notify_ptr.notify_one(); - } - }); - - let stream = tokio_stream::wrappers::UnboundedReceiverStream::new(rx); - // let stream = alice_group.stream().await.unwrap(); + let stream = alice_group.stream().await.unwrap(); futures::pin_mut!(stream); bob_group.send_message(b"hello").await.unwrap(); - tracing::debug!("Bob Sent Message!, waiting for delivery"); - // notify.wait_for_delivery().await.unwrap(); + let message = stream.next().await.unwrap().unwrap(); assert_eq!(message.decrypted_message_bytes, b"hello"); bob_group.send_message(b"hello2").await.unwrap(); - // notify.wait_for_delivery().await.unwrap(); let message = stream.next().await.unwrap().unwrap(); assert_eq!(message.decrypted_message_bytes, b"hello2"); - - // assert_eq!(bob_received_groups.group_id, alice_bob_group.group_id); } #[wasm_bindgen_test(unsupported = tokio::test(flavor = "multi_thread", worker_threads = 10))] @@ -714,40 +725,20 @@ pub(crate) mod tests { .add_members_by_inbox_id(&[caro.inbox_id()]) .await .unwrap(); - xmtp_common::time::sleep(core::time::Duration::from_millis(100)).await; - let messages: Arc>> = Arc::new(Mutex::new(Vec::new())); - let messages_clone = messages.clone(); + let stream = caro.stream_all_messages(None).await.unwrap(); + futures::pin_mut!(stream); + bo_group.send_message(b"first").await.unwrap(); + assert_msg!(stream, "first"); - let notify = Delivery::new(None); - let notify_pointer = notify.clone(); - let mut handle = Client::::stream_all_messages_with_callback( - Arc::new(caro), - None, - move |message| { - (*messages_clone.lock()).push(message.unwrap()); - notify_pointer.notify_one(); - }, - ); - handle.wait_for_ready().await; + bo_group.send_message(b"second").await.unwrap(); + assert_msg!(stream, "second"); - alix_group.send_message("first".as_bytes()).await.unwrap(); - notify - .wait_for_delivery() - .await - .expect("didn't get `first`"); - bo_group.send_message("second".as_bytes()).await.unwrap(); - notify.wait_for_delivery().await.unwrap(); - alix_group.send_message("third".as_bytes()).await.unwrap(); - notify.wait_for_delivery().await.unwrap(); - bo_group.send_message("fourth".as_bytes()).await.unwrap(); - notify.wait_for_delivery().await.unwrap(); + alix_group.send_message(b"third").await.unwrap(); + assert_msg!(stream, "third"); - let messages = messages.lock(); - assert_eq!(messages[0].decrypted_message_bytes, b"first"); - assert_eq!(messages[1].decrypted_message_bytes, b"second"); - assert_eq!(messages[2].decrypted_message_bytes, b"third"); - assert_eq!(messages[3].decrypted_message_bytes, b"fourth"); + bo_group.send_message(b"fourth").await.unwrap(); + assert_msg!(stream, "fourth"); } #[wasm_bindgen_test(unsupported = tokio::test(flavor = "multi_thread", worker_threads = 10))] @@ -765,39 +756,21 @@ pub(crate) mod tests { .await .unwrap(); - let messages: Arc>> = Arc::new(Mutex::new(Vec::new())); - let messages_clone = messages.clone(); - let delivery = Delivery::new(None); - let delivery_pointer = delivery.clone(); - let mut handle = Client::::stream_all_messages_with_callback( - caro.clone(), - None, - move |message| { - delivery_pointer.notify_one(); - (*messages_clone.lock()).push(message.unwrap()); - }, - ); - handle.wait_for_ready().await; + let stream = caro.stream_all_messages(None).await.unwrap(); + futures::pin_mut!(stream); + tracing::info!("\n\nSENDING FIRST MESSAGE\n\n"); alix_group.send_message(b"first").await.unwrap(); - delivery - .wait_for_delivery() - .await - .expect("timed out waiting for `first`"); + assert_msg!(stream, "first"); let bo_group = bo.create_dm(caro_wallet.get_address()).await.unwrap(); + assert_msg_exists!(stream); bo_group.send_message(b"second").await.unwrap(); - delivery - .wait_for_delivery() - .await - .expect("timed out waiting for `second`"); + assert_msg!(stream, "second"); alix_group.send_message(b"third").await.unwrap(); - delivery - .wait_for_delivery() - .await - .expect("timed out waiting for `third`"); + assert_msg!(stream, "third"); let alix_group_2 = alix .create_group(None, GroupMetadataOptions::default()) @@ -808,36 +781,10 @@ pub(crate) mod tests { .unwrap(); alix_group.send_message(b"fourth").await.unwrap(); - delivery - .wait_for_delivery() - .await - .expect("timed out waiting for `fourth`"); + assert_msg!(stream, "fourth"); alix_group_2.send_message(b"fifth").await.unwrap(); - delivery - .wait_for_delivery() - .await - .expect("timed out waiting for `fifth`"); - - { - let messages = messages.lock(); - assert_eq!(messages.len(), 5); - } - - let a = handle.abort_handle(); - a.end(); - let _ = handle.join().await; - assert!(a.is_finished()); - - alix_group - .send_message("should not show up".as_bytes()) - .await - .unwrap(); - xmtp_common::time::sleep(core::time::Duration::from_millis(100)).await; - - let messages = messages.lock(); - - assert_eq!(messages.len(), 5); + assert_msg!(stream, "fifth"); } #[ignore] diff --git a/xmtp_mls/src/subscriptions/stream_all.rs b/xmtp_mls/src/subscriptions/stream_all.rs new file mode 100644 index 000000000..a6b0c3913 --- /dev/null +++ b/xmtp_mls/src/subscriptions/stream_all.rs @@ -0,0 +1,87 @@ +use std::{collections::HashMap, sync::Arc}; + +use crate::{ + client::ClientError, + groups::scoped_client::ScopedGroupClient, + groups::subscriptions, + storage::{ + group::{ConversationType, GroupQueryArgs}, + group_message::StoredGroupMessage, + }, + Client, +}; +use futures::{ + stream::{self, Stream, StreamExt}, + Future, +}; +use xmtp_id::scw_verifier::SmartContractSignatureVerifier; +use xmtp_proto::api_client::{trait_impls::XmtpApi, XmtpMlsStreams}; + +use super::{MessagesStreamInfo, SubscribeError}; +pub struct StreamAllMessages<'a, C, Welcomes, Messages> { + /// The monolithic XMTP Client + client: &'a C, + /// Type of conversation to stream + conversation_type: Option, + /// Conversations that are being actively streamed + active_conversations: HashMap, MessagesStreamInfo>, + /// Welcomes Stream + welcomes: Welcomes, + /// Messages Stream + messages: Messages, + /// Extra messages from message stream, when the stream switches because + /// of a new group received. + extra_messages: Vec, +} + +impl<'a, A, V, Welcomes, Messages> StreamAllMessages<'a, Client, Welcomes, Messages> +where + A: XmtpApi + XmtpMlsStreams + Send + Sync + 'static, + V: SmartContractSignatureVerifier + Send + Sync + 'static, +{ + pub async fn new( + client: &'a Client, + conversation_type: Option, + ) -> Result { + let mut active_conversations = async { + let provider = client.mls_provider()?; + client.sync_welcomes(&provider).await?; + + let active_conversations = provider + .conn_ref() + .find_groups(GroupQueryArgs::default().maybe_conversation_type(conversation_type))? + .into_iter() + .map(Into::into) + .collect::, MessagesStreamInfo>>(); + Ok::<_, ClientError>(active_conversations) + } + .await?; + + let messages = + subscriptions::stream_messages(client, Arc::new(active_conversations.clone())).await?; + let welcomes = client.stream_conversations(conversation_type).await?; + + Self { + client, + conversation_type, + messages, + welcomes, + active_conversations, + extra_messages: Vec::new(), + } + } +} + +impl<'a, C, Welcomes, Messages> Stream for StreamAllMessages<'a, C, Welcomes, Messages> +where + C: ScopedGroupClient, +{ + type Item = Result; + + fn poll_next( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + todo!() + } +} diff --git a/xmtp_mls/src/subscriptions/stream_conversations.rs b/xmtp_mls/src/subscriptions/stream_conversations.rs new file mode 100644 index 000000000..3f3f3abb5 --- /dev/null +++ b/xmtp_mls/src/subscriptions/stream_conversations.rs @@ -0,0 +1,241 @@ +use std::{collections::HashSet, marker::PhantomData, sync::Arc, task::Poll}; + +use futures::{prelude::stream::Select, Stream}; +use pin_project_lite::pin_project; +use tokio_stream::wrappers::BroadcastStream; +use xmtp_common::{retry_async, Retry}; +use xmtp_id::scw_verifier::SmartContractSignatureVerifier; +use xmtp_proto::{ + api_client::{trait_impls::XmtpApi, XmtpMlsStreams}, + xmtp::mls::api::v1::WelcomeMessage, +}; + +use crate::{ + groups::{scoped_client::ScopedGroupClient, MlsGroup}, + storage::{group::ConversationType, DbConnection}, + Client, XmtpOpenMlsProvider, +}; + +use super::{LocalEvents, SubscribeError}; + +enum WelcomeOrGroup { + Group(Result, SubscribeError>), + Welcome(Result), +} + +pin_project! { + /// Broadcast stream filtered + mapped to WelcomeOrGroup + struct BroadcastGroupStream { + #[pin] inner: BroadcastStream>, + } +} + +impl BroadcastGroupStream { + fn new(inner: BroadcastStream>) -> Self { + Self { inner } + } +} + +impl Stream for BroadcastGroupStream +where + C: Clone + Send + Sync + 'static, // required by tokio::BroadcastStream +{ + type Item = WelcomeOrGroup; + + fn poll_next( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + use std::task::Poll::*; + let this = self.project(); + + match this.inner.poll_next(cx) { + Ready(Some(event)) => { + let ev = xmtp_common::optify!(event, "Missed messages due to event queue lag") + .and_then(LocalEvents::group_filter); + if let Some(g) = ev { + Ready(Some(WelcomeOrGroup::::Group(Ok(g)))) + } else { + // skip this item since it was either missed due to lag, or not a group + Pending + } + } + Pending => Pending, + Ready(None) => Ready(None), + } + } +} + +pin_project! { + /// Subscription Stream mapped to WelcomeOrGroup + struct SubscriptionStream { + #[pin] inner: S, + _marker: PhantomData, + } +} + +impl SubscriptionStream { + fn new(inner: S) -> Self { + Self { + inner, + _marker: PhantomData, + } + } +} + +impl Stream for SubscriptionStream +where + S: Stream>, +{ + type Item = WelcomeOrGroup; + + fn poll_next( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + use std::task::Poll::*; + let this = self.project(); + + match this.inner.poll_next(cx) { + Ready(Some(welcome)) => Ready(Some(WelcomeOrGroup::Welcome(welcome))), + Pending => Pending, + Ready(None) => Ready(None), + } + } +} + +pin_project! { + pub struct StreamConversations<'a, C, Subscription> { + client: &'a C, + #[pin] inner: Subscription, + conversation_type: Option, + known_welcome_ids: HashSet + } +} + +type MultiplexedSelect = Select, SubscriptionStream>; + +impl<'a, A, V> + StreamConversations< + 'a, + Client, + MultiplexedSelect, ::WelcomeMessageStream<'a>>, + > +where + A: XmtpApi + XmtpMlsStreams + Send + Sync + 'static, + V: SmartContractSignatureVerifier + Send + Sync + 'static, +{ + pub async fn new( + client: &'a Client, + conversation_type: Option, + conn: &DbConnection, + ) -> Result { + let installation_key = client.installation_public_key(); + let id_cursor = 0; + tracing::info!( + inbox_id = client.inbox_id(), + "Setting up conversation stream" + ); + + let events = + BroadcastGroupStream::new(BroadcastStream::new(client.local_events.subscribe())); + + let subscription = client + .api_client + .subscribe_welcome_messages(installation_key.as_ref(), Some(id_cursor)) + .await?; + let subscription = SubscriptionStream::new(subscription); + let known_welcome_ids = HashSet::from_iter(conn.group_welcome_ids()?.into_iter()); + + let stream = futures::stream::select(events, subscription); + + Ok(Self { + client, + inner: stream, + known_welcome_ids, + conversation_type, + }) + } +} + +impl<'a, C, Subscription> Stream for StreamConversations<'a, C, Subscription> +where + C: ScopedGroupClient + Clone, + Subscription: Stream, SubscribeError>>, +{ + type Item = Result, SubscribeError>; + + fn poll_next( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + use std::task::Poll::*; + let this = self.project(); + + match this.inner.poll_next(cx) { + Ready(Some(msg)) => { + todo!() + } + // stream ended + Ready(None) => Ready(None), + Pending => { + cx.waker().wake_by_ref(); + Pending + } + } + } +} + +impl<'a, C, Subscription> StreamConversations<'a, C, Subscription> +where + C: ScopedGroupClient + Clone, +{ + async fn process_streamed_welcome( + &mut self, + client: C, + provider: &XmtpOpenMlsProvider, + welcome: WelcomeMessage, + ) -> Result, SubscribeError> { + let welcome_v1 = crate::client::extract_welcome_message(welcome)?; + if self.known_welcome_ids.contains(&(welcome_v1.id as i64)) { + let conn = provider.conn_ref(); + self.known_welcome_ids.insert(welcome_v1.id as i64); + let group = conn.find_group_by_welcome_id(welcome_v1.id as i64)?; + tracing::info!( + inbox_id = client.inbox_id(), + group_id = hex::encode(&group.id), + welcome_id = ?group.welcome_id, + "Loading existing group for welcome_id: {:?}", + group.welcome_id + ); + return Ok(MlsGroup::new(client.clone(), group.id, group.created_at_ns)); + } + + let creation_result = retry_async!( + Retry::default(), + (async { + tracing::info!( + installation_id = &welcome_v1.id, + "Trying to process streamed welcome" + ); + let welcome_v1 = &welcome_v1; + client + .context + .store() + .transaction_async(provider, |provider| async move { + MlsGroup::create_from_encrypted_welcome( + Arc::new(client.clone()), + provider, + welcome_v1.hpke_public_key.as_slice(), + &welcome_v1.data, + welcome_v1.id as i64, + ) + .await + }) + .await + }) + ); + + Ok(creation_result?) + } +}