diff --git a/Cargo.lock b/Cargo.lock index cf5f4a245..9795538da 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3211,7 +3211,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fc2f4eb4bc735547cfed7c0a4922cbd04a4655978c09b54f1f7b228750664c34" dependencies = [ "cfg-if", - "windows-targets 0.48.5", + "windows-targets 0.52.6", ] [[package]] @@ -6899,7 +6899,7 @@ version = "0.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cf221c93e13a30d793f7645a0e7762c55d169dbb0a49671918a2319d289b10bb" dependencies = [ - "windows-sys 0.48.0", + "windows-sys 0.59.0", ] [[package]] @@ -7246,9 +7246,10 @@ dependencies = [ name = "xmtp_api_http" version = "0.1.0" dependencies = [ - "async-stream", "async-trait", + "bytes", "futures", + "pin-project-lite", "reqwest 0.12.9", "serde", "serde_json", @@ -7424,6 +7425,7 @@ dependencies = [ "openssl", "openssl-sys", "parking_lot 0.12.3", + "pin-project-lite", "prost", "rand", "reqwest 0.12.9", diff --git a/Cargo.toml b/Cargo.toml index b46b42352..d34d8d219 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -37,8 +37,7 @@ ctor = "0.2" ed25519 = "2.2.3" ed25519-dalek = { version = "2.1.1", features = ["zeroize"] } ethers = { version = "2.0", default-features = false } -futures = "0.3.30" -futures-core = "0.3.30" +futures = { version = "0.3.30", default-features = false } getrandom = { version = "0.2", default-features = false } hex = "0.4.3" hkdf = "0.12.3" @@ -61,16 +60,7 @@ thiserror = "2.0" tls_codec = "0.4.1" tokio = { version = "1.35.1", default-features = false } uuid = "1.10" -wasm-timer = "0.2" web-time = "1.1" -# Changing this version and rustls may potentially break the android build. Use Caution. -# Test with Android and Swift first. -# Its probably preferable to one day use https://github.com/rustls/rustls-platform-verifier -# Until then, always test agains iOS/Android after updating these dependencies & making a PR -# Related Issues: -# - https://github.com/seanmonstar/reqwest/issues/2159 -# - https://github.com/hyperium/tonic/pull/1974 -# - https://github.com/rustls/rustls-platform-verifier/issues/58 bincode = "1.3" console_error_panic_hook = "0.1" const_format = "0.2" @@ -87,6 +77,14 @@ openssl = { version = "0.10", features = ["vendored"] } openssl-sys = { version = "0.9", features = ["vendored"] } parking_lot = "0.12.3" sqlite-web = "0.0.1" +# Changing this version and rustls may potentially break the android build. Use Caution. +# Test with Android and Swift first. +# Its probably preferable to one day use https://github.com/rustls/rustls-platform-verifier +# Until then, always test agains iOS/Android after updating these dependencies & making a PR +# Related Issues: +# - https://github.com/seanmonstar/reqwest/issues/2159 +# - https://github.com/hyperium/tonic/pull/1974 +# - https://github.com/rustls/rustls-platform-verifier/issues/58 tonic = { version = "0.12", default-features = false } tracing = { version = "0.1", features = ["log"] } tracing-subscriber = { version = "0.3", default-features = false } @@ -101,7 +99,8 @@ criterion = { version = "0.5", features = [ "html_reports", "async_tokio", ]} - once_cell = "1.2" +once_cell = "1.2" +pin-project-lite = "0.2" # Internal Crate Dependencies xmtp_api_grpc = { path = "xmtp_api_grpc" } diff --git a/common/src/test.rs b/common/src/test.rs index c11c692cb..4e2474ae4 100644 --- a/common/src/test.rs +++ b/common/src/test.rs @@ -108,6 +108,10 @@ pub fn rand_u64() -> u64 { crypto_utils::rng().gen() } +pub fn rand_i64() -> i64 { + crypto_utils::rng().gen() +} + #[cfg(not(target_arch = "wasm32"))] pub fn tmp_path() -> String { let db_name = crate::rand_string::<24>(); diff --git a/xmtp_api_grpc/Cargo.toml b/xmtp_api_grpc/Cargo.toml index b69a0d0c4..67ea6fb16 100644 --- a/xmtp_api_grpc/Cargo.toml +++ b/xmtp_api_grpc/Cargo.toml @@ -8,7 +8,7 @@ version.workspace = true async-stream.workspace = true async-trait = "0.1" base64.workspace = true -futures.workspace = true +futures = { workspace = true, features = ["alloc"] } hex.workspace = true prost = { workspace = true, features = ["prost-derive"] } tokio = { workspace = true, features = ["macros", "time"] } diff --git a/xmtp_api_http/Cargo.toml b/xmtp_api_http/Cargo.toml index b26a414a9..09a6a9214 100644 --- a/xmtp_api_http/Cargo.toml +++ b/xmtp_api_http/Cargo.toml @@ -8,16 +8,17 @@ license.workspace = true crate-type = ["cdylib", "rlib"] [dependencies] -async-stream.workspace = true futures = { workspace = true } tracing.workspace = true reqwest = { version = "0.12.5", features = ["json", "stream"] } serde = { workspace = true } serde_json = { workspace = true } -thiserror = "2.0" +thiserror.workspace = true tokio = { workspace = true, features = ["sync", "rt", "macros"] } xmtp_proto = { path = "../xmtp_proto", features = ["proto_full"] } async-trait = "0.1" +bytes = "1.9" +pin-project-lite = "0.2.15" [dev-dependencies] xmtp_proto = { path = "../xmtp_proto", features = ["test-utils"] } diff --git a/xmtp_api_http/src/http_stream.rs b/xmtp_api_http/src/http_stream.rs new file mode 100644 index 000000000..8e969b0c4 --- /dev/null +++ b/xmtp_api_http/src/http_stream.rs @@ -0,0 +1,231 @@ +//! Streams that work with HTTP POST requests + +use crate::util::GrpcResponse; +use futures::{ + stream::{self, Stream, StreamExt}, + Future, +}; +use reqwest::Response; +use serde::{de::DeserializeOwned, Deserialize, Serialize}; +use serde_json::Deserializer; +use std::{marker::PhantomData, pin::Pin, task::Poll}; +use xmtp_proto::{Error, ErrorKind}; + +#[derive(Deserialize, Serialize, Debug)] +pub(crate) struct SubscriptionItem { + pub result: T, +} + +#[cfg(target_arch = "wasm32")] +pub type BytesStream = stream::LocalBoxStream<'static, Result>; + +// #[cfg(not(target_arch = "wasm32"))] +// pub type BytesStream = Pin> + Send>>; + +#[cfg(not(target_arch = "wasm32"))] +pub type BytesStream = stream::BoxStream<'static, Result>; + +pin_project_lite::pin_project! { + #[project = PostStreamProject] + enum HttpPostStream { + NotStarted{#[pin] fut: F}, + // `Reqwest::bytes_stream` returns `impl Stream` rather than a type generic, + // so we can't use a type generic here + // this makes wasm a bit tricky. + Started { + #[pin] http: BytesStream, + remaining: Vec, + _marker: PhantomData, + }, + } +} + +impl Stream for HttpPostStream +where + F: Future>, + for<'de> R: Send + Deserialize<'de>, +{ + type Item = Result; + + fn poll_next( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + use std::task::Poll::*; + match self.as_mut().project() { + PostStreamProject::NotStarted { fut } => match fut.poll(cx) { + Ready(response) => { + let s = response.unwrap().bytes_stream(); + self.set(Self::started(s)); + self.as_mut().poll_next(cx) + } + Pending => { + cx.waker().wake_by_ref(); + Pending + } + }, + PostStreamProject::Started { + ref mut http, + ref mut remaining, + .. + } => { + let mut pinned = std::pin::pin!(http); + let next = pinned.as_mut().poll_next(cx); + Self::on_bytes(next, remaining, cx) + } + } + } +} + +impl HttpPostStream +where + R: Send, +{ + #[cfg(not(target_arch = "wasm32"))] + fn started( + http: impl Stream> + Send + 'static, + ) -> Self { + Self::Started { + http: http.boxed(), + remaining: Vec::new(), + _marker: PhantomData, + } + } + + #[cfg(target_arch = "wasm32")] + fn started(http: impl Stream> + 'static) -> Self { + Self::Started { + http: http.boxed_local(), + remaining: Vec::new(), + _marker: PhantomData, + } + } +} + +impl HttpPostStream +where + F: Future>, + for<'de> R: Deserialize<'de> + DeserializeOwned + Send, +{ + fn new(request: F) -> Self { + Self::NotStarted { fut: request } + } + + fn on_bytes( + p: Poll>>, + remaining: &mut Vec, + cx: &mut std::task::Context<'_>, + ) -> Poll::Item>> { + use futures::task::Poll::*; + match p { + Ready(Some(bytes)) => { + let bytes = bytes.map_err(|e| { + Error::new(ErrorKind::SubscriptionUpdateError).with(e.to_string()) + })?; + let bytes = &[remaining.as_ref(), bytes.as_ref()].concat(); + let de = Deserializer::from_slice(bytes); + let mut stream = de.into_iter::>(); + 'messages: loop { + tracing::debug!("Waiting on next response ..."); + let response = stream.next(); + let res = match response { + Some(Ok(GrpcResponse::Ok(response))) => Ok(response), + Some(Ok(GrpcResponse::SubscriptionItem(item))) => Ok(item.result), + Some(Ok(GrpcResponse::Err(e))) => { + Err(Error::new(ErrorKind::MlsError).with(e.message)) + } + Some(Err(e)) => { + if e.is_eof() { + *remaining = (&**bytes)[stream.byte_offset()..].to_vec(); + return Pending; + } else { + Err(Error::new(ErrorKind::MlsError).with(e.to_string())) + } + } + Some(Ok(GrpcResponse::Empty {})) => continue 'messages, + None => return Ready(None), + }; + return Ready(Some(res)); + } + } + Ready(None) => Ready(None), + Pending => { + cx.waker().wake_by_ref(); + Pending + } + } + } +} + +#[cfg(not(target_arch = "wasm32"))] +impl HttpPostStream +where + F: Future> + Unpin, + for<'de> R: Deserialize<'de> + DeserializeOwned + Send, +{ + /// Establish the initial HTTP Stream connection + fn establish(&mut self) -> () { + // we need to poll the future once to progress the future state & + // establish the initial POST request. + // It should always be pending + let noop_waker = futures::task::noop_waker(); + let mut cx = std::task::Context::from_waker(&noop_waker); + // let mut this = Pin::new(self); + let mut this = Pin::new(self); + let _ = this.poll_next_unpin(&mut cx); + } +} + +#[cfg(target_arch = "wasm32")] +impl HttpPostStream +where + F: Future>, + for<'de> R: Deserialize<'de> + DeserializeOwned + Send, +{ + fn establish(&mut self) -> () { + // we need to poll the future once to progress the future state & + // establish the initial POST request. + // It should always be pending + let noop_waker = futures::task::noop_waker(); + let mut cx = std::task::Context::from_waker(&noop_waker); + let mut this = unsafe { Pin::new_unchecked(self) }; + let _ = this.poll_next_unpin(&mut cx); + } +} + +#[cfg(target_arch = "wasm32")] +pub fn create_grpc_stream( + request: T, + endpoint: String, + http_client: reqwest::Client, +) -> stream::LocalBoxStream<'static, Result> { + create_grpc_stream_inner(request, endpoint, http_client).boxed_local() +} + +#[cfg(not(target_arch = "wasm32"))] +pub fn create_grpc_stream( + request: T, + endpoint: String, + http_client: reqwest::Client, +) -> stream::BoxStream<'static, Result> +where + T: Serialize + 'static, + R: DeserializeOwned + Send + 'static, +{ + create_grpc_stream_inner(request, endpoint, http_client).boxed() +} + +fn create_grpc_stream_inner( + request: T, + endpoint: String, + http_client: reqwest::Client, +) -> impl Stream> +where + T: Serialize + 'static, + R: DeserializeOwned + Send + 'static, +{ + let request = http_client.post(endpoint).json(&request).send(); + let mut http = HttpPostStream::new(request); + http.establish(); + http +} diff --git a/xmtp_api_http/src/lib.rs b/xmtp_api_http/src/lib.rs index 80489fb3c..8a3f972c4 100755 --- a/xmtp_api_http/src/lib.rs +++ b/xmtp_api_http/src/lib.rs @@ -1,11 +1,13 @@ #![warn(clippy::unwrap_used)] pub mod constants; +mod http_stream; mod util; use futures::stream; +use http_stream::create_grpc_stream; use reqwest::header; -use util::{create_grpc_stream, handle_error}; +use util::handle_error; use xmtp_proto::api_client::{ClientWithMetadata, XmtpIdentityClient}; use xmtp_proto::xmtp::identity::api::v1::{ GetIdentityUpdatesRequest as GetIdentityUpdatesV2Request, diff --git a/xmtp_api_http/src/util.rs b/xmtp_api_http/src/util.rs index 8a839fc56..34c878c4a 100644 --- a/xmtp_api_http/src/util.rs +++ b/xmtp_api_http/src/util.rs @@ -1,9 +1,5 @@ -use futures::{ - stream::{self, StreamExt}, - Stream, -}; +use crate::http_stream::SubscriptionItem; use serde::{de::DeserializeOwned, Deserialize, Serialize}; -use serde_json::Deserializer; use std::io::Read; use xmtp_proto::{Error, ErrorKind}; @@ -23,11 +19,6 @@ pub(crate) struct ErrorResponse { details: Vec, } -#[derive(Deserialize, Serialize, Debug)] -pub(crate) struct SubscriptionItem { - pub result: T, -} - /// handle JSON response from gRPC, returning either /// the expected deserialized response object or a gRPC [`Error`] pub fn handle_error(reader: R) -> Result @@ -43,78 +34,6 @@ where } } -#[cfg(target_arch = "wasm32")] -pub fn create_grpc_stream< - T: Serialize + Send + 'static, - R: DeserializeOwned + Send + std::fmt::Debug + 'static, ->( - request: T, - endpoint: String, - http_client: reqwest::Client, -) -> stream::LocalBoxStream<'static, Result> { - create_grpc_stream_inner(request, endpoint, http_client).boxed_local() -} - -#[cfg(not(target_arch = "wasm32"))] -pub fn create_grpc_stream< - T: Serialize + Send + 'static, - R: DeserializeOwned + Send + std::fmt::Debug + 'static, ->( - request: T, - endpoint: String, - http_client: reqwest::Client, -) -> stream::BoxStream<'static, Result> { - create_grpc_stream_inner(request, endpoint, http_client).boxed() -} - -pub fn create_grpc_stream_inner< - T: Serialize + Send + 'static, - R: DeserializeOwned + Send + std::fmt::Debug + 'static, ->( - request: T, - endpoint: String, - http_client: reqwest::Client, -) -> impl Stream> { - async_stream::stream! { - let request = http_client - .post(endpoint) - .json(&request) - .send() - .await - .map_err(|e| Error::new(ErrorKind::MlsError).with(e))?; - - let mut remaining = vec![]; - for await bytes in request.bytes_stream() { - let bytes = bytes - .map_err(|e| Error::new(ErrorKind::SubscriptionUpdateError).with(e.to_string()))?; - let bytes = &[remaining.as_ref(), bytes.as_ref()].concat(); - let de = Deserializer::from_slice(bytes); - let mut stream = de.into_iter::>(); - 'messages: loop { - let response = stream.next(); - let res = match response { - Some(Ok(GrpcResponse::Ok(response))) => Ok(response), - Some(Ok(GrpcResponse::SubscriptionItem(item))) => Ok(item.result), - Some(Ok(GrpcResponse::Err(e))) => { - Err(Error::new(ErrorKind::MlsError).with(e.message)) - } - Some(Err(e)) => { - if e.is_eof() { - remaining = (&**bytes)[stream.byte_offset()..].to_vec(); - break 'messages; - } else { - Err(Error::new(ErrorKind::MlsError).with(e.to_string())) - } - } - Some(Ok(GrpcResponse::Empty {})) => continue 'messages, - None => break 'messages, - }; - yield res; - } - } - } -} - #[cfg(feature = "test-utils")] #[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)] #[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))] diff --git a/xmtp_debug/src/args.rs b/xmtp_debug/src/args.rs index 3c2cf67c5..a0241216b 100644 --- a/xmtp_debug/src/args.rs +++ b/xmtp_debug/src/args.rs @@ -154,6 +154,17 @@ pub enum EntityKind { Identity, } +impl std::fmt::Display for EntityKind { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + use EntityKind::*; + match self { + Group => write!(f, "group"), + Message => write!(f, "message"), + Identity => write!(f, "identity"), + } + } +} + /// specify the log output #[derive(Args, Debug)] pub struct LogOptions { diff --git a/xmtp_mls/Cargo.toml b/xmtp_mls/Cargo.toml index 0e160e485..ffd6f042e 100644 --- a/xmtp_mls/Cargo.toml +++ b/xmtp_mls/Cargo.toml @@ -49,7 +49,7 @@ async-stream.workspace = true async-trait.workspace = true bincode.workspace = true diesel_migrations.workspace = true -futures.workspace = true +futures = { workspace = true, features = ["alloc"] } hex.workspace = true hkdf.workspace = true openmls_rust_crypto = { workspace = true } @@ -70,6 +70,7 @@ tracing.workspace = true trait-variant.workspace = true xmtp_common.workspace = true zeroize.workspace = true +pin-project-lite.workspace = true # XMTP/Local xmtp_content_types = { path = "../xmtp_content_types" } diff --git a/xmtp_mls/src/api/mls.rs b/xmtp_mls/src/api/mls.rs index 3994cd8fa..86b206121 100644 --- a/xmtp_mls/src/api/mls.rs +++ b/xmtp_mls/src/api/mls.rs @@ -274,10 +274,10 @@ where Ok(()) } - pub async fn subscribe_group_messages( - &self, + pub(crate) async fn subscribe_group_messages<'a>( + &'a self, filters: Vec, - ) -> Result> + '_, ApiError> + ) -> Result<::GroupMessageStream<'a>, ApiError> where ApiClient: XmtpMlsStreams, { @@ -289,11 +289,11 @@ where .await } - pub async fn subscribe_welcome_messages( - &self, + pub(crate) async fn subscribe_welcome_messages<'a>( + &'a self, installation_key: &[u8], id_cursor: Option, - ) -> Result> + '_, ApiError> + ) -> Result<::WelcomeMessageStream<'a>, ApiError> where ApiClient: XmtpMlsStreams, { diff --git a/xmtp_mls/src/client.rs b/xmtp_mls/src/client.rs index 6db9747da..33f38699f 100644 --- a/xmtp_mls/src/client.rs +++ b/xmtp_mls/src/client.rs @@ -50,7 +50,7 @@ use crate::{ group_message::StoredGroupMessage, refresh_state::EntityKind, wallet_addresses::WalletEntry, - EncryptedMessageStore, StorageError, + EncryptedMessageStore, NotFound, StorageError, }, subscriptions::{LocalEventError, LocalEvents}, types::InstallationId, @@ -108,6 +108,12 @@ pub enum ClientError { Generic(String), } +impl From for ClientError { + fn from(value: NotFound) -> Self { + ClientError::Storage(StorageError::NotFound(value)) + } +} + impl From for ClientError { fn from(err: GroupError) -> ClientError { ClientError::Group(Box::new(err)) @@ -309,11 +315,7 @@ where address: String, ) -> Result, ClientError> { let results = self.find_inbox_ids_from_addresses(conn, &[address]).await?; - if let Some(first_result) = results.into_iter().next() { - Ok(first_result) - } else { - Ok(None) - } + Ok(results.into_iter().next().flatten()) } /// Calls the server to look up the `inbox_id`s` associated with a list of addresses. @@ -556,10 +558,9 @@ where { Some(id) => id, None => { - return Err(ClientError::Storage(StorageError::NotFound(format!( - "inbox id for address {} not found", - account_address - )))) + return Err(ClientError::Storage(StorageError::NotFound( + NotFound::InboxIdForAddress(account_address), + ))); } }; @@ -610,13 +611,10 @@ where group_id: Vec, ) -> Result, ClientError> { let stored_group: Option = conn.fetch(&group_id)?; - match stored_group { - Some(group) => Ok(MlsGroup::new(self.clone(), group.id, group.created_at_ns)), - None => Err(ClientError::Storage(StorageError::NotFound(format!( - "group {}", - hex::encode(group_id) - )))), - } + stored_group + .map(|g| MlsGroup::new(self.clone(), g.id, g.created_at_ns)) + .ok_or(NotFound::GroupById(group_id)) + .map_err(Into::into) } /// Look up a group by its ID @@ -638,17 +636,10 @@ where target_inbox_id: String, ) -> Result, ClientError> { let conn = self.store().conn()?; - match conn.find_dm_group(&target_inbox_id)? { - Some(dm_group) => Ok(MlsGroup::new( - self.clone(), - dm_group.id, - dm_group.created_at_ns, - )), - None => Err(ClientError::Storage(StorageError::NotFound(format!( - "dm_target_inbox_id {}", - hex::encode(target_inbox_id) - )))), - } + let group = conn + .find_dm_group(&target_inbox_id)? + .ok_or(NotFound::DmByInbox(target_inbox_id))?; + Ok(MlsGroup::new(self.clone(), group.id, group.created_at_ns)) } /// Look up a message by its ID @@ -656,13 +647,7 @@ where pub fn message(&self, message_id: Vec) -> Result { let conn = &mut self.store().conn()?; let message = conn.get_group_message(&message_id)?; - match message { - Some(message) => Ok(message), - None => Err(ClientError::Storage(StorageError::NotFound(format!( - "message {}", - hex::encode(message_id) - )))), - } + Ok(message.ok_or(NotFound::MessageById(message_id))?) } /// Query for groups with optional filters diff --git a/xmtp_mls/src/groups/device_sync.rs b/xmtp_mls/src/groups/device_sync.rs index 070caabe6..43660e643 100644 --- a/xmtp_mls/src/groups/device_sync.rs +++ b/xmtp_mls/src/groups/device_sync.rs @@ -6,11 +6,9 @@ use crate::{ configuration::NS_IN_HOUR, storage::{ consent_record::StoredConsentRecord, - group::StoredGroup, - group::{ConversationType, GroupQueryArgs}, - group_message::MsgQueryArgs, - group_message::{GroupMessageKind, StoredGroupMessage}, - DbConnection, StorageError, + group::{ConversationType, GroupQueryArgs, StoredGroup}, + group_message::{GroupMessageKind, MsgQueryArgs, StoredGroupMessage}, + DbConnection, NotFound, StorageError, }, subscriptions::{LocalEvents, StreamMessages, SubscribeError, SyncMessage}, xmtp_openmls_provider::XmtpOpenMlsProvider, @@ -115,6 +113,12 @@ impl RetryableError for DeviceSyncError { } } +impl From for DeviceSyncError { + fn from(value: NotFound) -> Self { + DeviceSyncError::Storage(StorageError::NotFound(value)) + } +} + impl Client where ApiClient: XmtpApi + Send + Sync + 'static, @@ -211,9 +215,9 @@ where retry, (async { conn.get_group_message(&message_id)? - .ok_or(DeviceSyncError::Storage(StorageError::NotFound(format!( - "Message id {message_id:?} not found." - )))) + .ok_or(DeviceSyncError::from(NotFound::MessageById( + message_id.clone(), + ))) }) )?; @@ -240,9 +244,9 @@ where retry, (async { conn.get_group_message(&message_id)? - .ok_or(DeviceSyncError::Storage(StorageError::NotFound(format!( - "Message id {message_id:?} not found." - )))) + .ok_or(DeviceSyncError::from(NotFound::MessageById( + message_id.clone(), + ))) }) )?; diff --git a/xmtp_mls/src/storage/encrypted_store/group.rs b/xmtp_mls/src/storage/encrypted_store/group.rs index 010eab713..93f3e8e86 100644 --- a/xmtp_mls/src/storage/encrypted_store/group.rs +++ b/xmtp_mls/src/storage/encrypted_store/group.rs @@ -5,7 +5,7 @@ use super::{ schema::groups::{self, dsl}, Sqlite, }; -use crate::{impl_fetch, impl_store, DuplicateItem, StorageError}; +use crate::{impl_fetch, impl_store, storage::NotFound, DuplicateItem, StorageError}; use diesel::{ backend::Backend, deserialize::{self, FromSql, FromSqlRow}, @@ -379,9 +379,8 @@ impl DbConnection { Ok::, StorageError>(ts) })?; - last_ts.ok_or(StorageError::NotFound(format!( - "installation time for group {}", - hex::encode(group_id) + last_ts.ok_or(StorageError::NotFound(NotFound::InstallationTimeForGroup( + group_id, ))) } @@ -407,10 +406,7 @@ impl DbConnection { Ok::<_, StorageError>(ts) })?; - last_ts.ok_or(StorageError::NotFound(format!( - "installation time for group {}", - hex::encode(group_id) - ))) + last_ts.ok_or(NotFound::InstallationTimeForGroup(group_id).into()) } /// Updates the 'last time checked' we checked for new installations. @@ -458,6 +454,22 @@ impl DbConnection { Ok(stored_group) } + + /// Get all the welcome ids turned into groups + pub(crate) fn group_welcome_ids(&self) -> Result, StorageError> { + self.raw_query(|conn| { + Ok::<_, StorageError>( + dsl::groups + .filter(dsl::welcome_id.is_not_null()) + .select(dsl::welcome_id) + .load::>(conn)? + .into_iter() + .map(|id| id.expect("SQL explicity filters for none")) + .collect(), + ) + }) + .map_err(Into::into) + } } #[repr(i32)] @@ -570,6 +582,25 @@ pub(crate) mod tests { ) } + /// Generate a test group with welcome + pub fn generate_group_with_welcome( + state: Option, + welcome_id: Option, + ) -> StoredGroup { + let id = rand_vec::<24>(); + let created_at_ns = now_ns(); + let membership_state = state.unwrap_or(GroupMembershipState::Allowed); + StoredGroup::new_from_welcome( + id, + created_at_ns, + membership_state, + "placeholder_address".to_string(), + welcome_id.unwrap_or(xmtp_common::rand_i64()), + ConversationType::Group, + None, + ) + } + /// Generate a test consent fn generate_consent_record( entity_type: ConsentType, @@ -856,4 +887,22 @@ pub(crate) mod tests { }) .await } + + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[cfg_attr(not(target_arch = "wasm32"), tokio::test)] + async fn test_get_group_welcome_ids() { + with_connection(|conn| { + let mls_groups = vec![ + generate_group_with_welcome(None, Some(30)), + generate_group(None), + generate_group(None), + generate_group_with_welcome(None, Some(10)), + ]; + for g in mls_groups.iter() { + g.store(conn).unwrap(); + } + assert_eq!(vec![30, 10], conn.group_welcome_ids().unwrap()); + }) + .await + } } diff --git a/xmtp_mls/src/storage/encrypted_store/group_intent.rs b/xmtp_mls/src/storage/encrypted_store/group_intent.rs index 40db6fd6c..cb781ab83 100644 --- a/xmtp_mls/src/storage/encrypted_store/group_intent.rs +++ b/xmtp_mls/src/storage/encrypted_store/group_intent.rs @@ -17,7 +17,7 @@ use super::{ use crate::{ groups::{intents::SendMessageIntentData, IntentError}, impl_fetch, impl_store, - storage::StorageError, + storage::{NotFound, StorageError}, utils::id::calculate_message_id, Delete, }; @@ -197,7 +197,7 @@ impl DbConnection { staged_commit: Option>, published_in_epoch: i64, ) -> Result<(), StorageError> { - let res = self.raw_query(|conn| { + let rows_changed = self.raw_query(|conn| { diesel::update(dsl::group_intents) .filter(dsl::id.eq(intent_id)) // State machine requires that the only valid state transition to Published is from @@ -213,18 +213,15 @@ impl DbConnection { .execute(conn) })?; - match res { - // If nothing matched the query, return an error. Either ID or state was wrong - 0 => Err(StorageError::NotFound(format!( - "ToPublish intent {intent_id} for publish" - ))), - _ => Ok(()), + if rows_changed == 0 { + return Err(NotFound::IntentForToPublish(intent_id).into()); } + Ok(()) } // Set the intent with the given ID to `Committed` pub fn set_group_intent_committed(&self, intent_id: ID) -> Result<(), StorageError> { - let res = self.raw_query(|conn| { + let rows_changed = self.raw_query(|conn| { diesel::update(dsl::group_intents) .filter(dsl::id.eq(intent_id)) // State machine requires that the only valid state transition to Committed is from @@ -234,19 +231,18 @@ impl DbConnection { .execute(conn) })?; - match res { - // If nothing matched the query, return an error. Either ID or state was wrong - 0 => Err(StorageError::NotFound(format!( - "Published intent {intent_id} for commit" - ))), - _ => Ok(()), + // If nothing matched the query, return an error. Either ID or state was wrong + if rows_changed == 0 { + return Err(NotFound::IntentForCommitted(intent_id).into()); } + + Ok(()) } // Set the intent with the given ID to `ToPublish`. Wipe any values for `payload_hash` and // `post_commit_data` pub fn set_group_intent_to_publish(&self, intent_id: ID) -> Result<(), StorageError> { - let res = self.raw_query(|conn| { + let rows_changed = self.raw_query(|conn| { diesel::update(dsl::group_intents) .filter(dsl::id.eq(intent_id)) // State machine requires that the only valid state transition to ToPublish is from @@ -263,32 +259,27 @@ impl DbConnection { .execute(conn) })?; - match res { - // If nothing matched the query, return an error. Either ID or state was wrong - 0 => Err(StorageError::NotFound(format!( - "Published intent {intent_id} for ToPublish" - ))), - _ => Ok(()), + if rows_changed == 0 { + return Err(NotFound::IntentForPublish(intent_id).into()); } + Ok(()) } /// Set the intent with the given ID to `Error` #[tracing::instrument(level = "trace", skip(self))] pub fn set_group_intent_error(&self, intent_id: ID) -> Result<(), StorageError> { - let res = self.raw_query(|conn| { + let rows_changed = self.raw_query(|conn| { diesel::update(dsl::group_intents) .filter(dsl::id.eq(intent_id)) .set(dsl::state.eq(IntentState::Error)) .execute(conn) })?; - match res { - // If nothing matched the query, return an error. Either ID or state was wrong - 0 => Err(StorageError::NotFound(format!( - "state for intent {intent_id}" - ))), - _ => Ok(()), + if rows_changed == 0 { + return Err(NotFound::IntentById(intent_id).into()); } + + Ok(()) } // Simple lookup of intents by payload hash, meant to be used when processing messages off the diff --git a/xmtp_mls/src/storage/encrypted_store/refresh_state.rs b/xmtp_mls/src/storage/encrypted_store/refresh_state.rs index f5cf0ba33..b1cfefcb0 100644 --- a/xmtp_mls/src/storage/encrypted_store/refresh_state.rs +++ b/xmtp_mls/src/storage/encrypted_store/refresh_state.rs @@ -8,7 +8,11 @@ use diesel::{ }; use super::{db_connection::DbConnection, schema::refresh_state, Sqlite}; -use crate::{impl_store, impl_store_or_ignore, storage::StorageError, StoreOrIgnore}; +use crate::{ + impl_store, impl_store_or_ignore, + storage::{NotFound, StorageError}, + StoreOrIgnore, +}; #[repr(i32)] #[derive(Debug, Clone, Copy, PartialEq, Eq, AsExpression, Hash, FromSqlRow)] @@ -18,6 +22,16 @@ pub enum EntityKind { Group = 2, } +impl std::fmt::Display for EntityKind { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + use EntityKind::*; + match self { + Welcome => write!(f, "welcome"), + Group => write!(f, "group"), + } + } +} + impl ToSql for EntityKind where i32: ToSql, @@ -96,24 +110,18 @@ impl DbConnection { entity_kind: EntityKind, cursor: i64, ) -> Result { - let state: Option = self.get_refresh_state(&entity_id, entity_kind)?; - match state { - Some(state) => { - use super::schema::refresh_state::dsl; - let num_updated = self.raw_query(|conn| { - diesel::update(&state) - .filter(dsl::cursor.lt(cursor)) - .set(dsl::cursor.eq(cursor)) - .execute(conn) - })?; - Ok(num_updated == 1) - } - None => Err(StorageError::NotFound(format!( - "state for entity ID {} with kind {:?}", - hex::encode(entity_id.as_ref()), - entity_kind - ))), - } + use super::schema::refresh_state::dsl; + let state: RefreshState = self.get_refresh_state(&entity_id, entity_kind)?.ok_or( + NotFound::RefreshStateByIdAndKind(entity_id.as_ref().to_vec(), entity_kind), + )?; + + let num_updated = self.raw_query(|conn| { + diesel::update(&state) + .filter(dsl::cursor.lt(cursor)) + .set(dsl::cursor.eq(cursor)) + .execute(conn) + })?; + Ok(num_updated == 1) } } diff --git a/xmtp_mls/src/storage/encrypted_store/sqlcipher_connection.rs b/xmtp_mls/src/storage/encrypted_store/sqlcipher_connection.rs index fe9350b26..6723f0df1 100644 --- a/xmtp_mls/src/storage/encrypted_store/sqlcipher_connection.rs +++ b/xmtp_mls/src/storage/encrypted_store/sqlcipher_connection.rs @@ -12,7 +12,7 @@ use std::{ path::{Path, PathBuf}, }; -use crate::storage::StorageError; +use crate::storage::{NotFound, StorageError}; use super::{EncryptionKey, StorageOption}; @@ -165,9 +165,9 @@ impl EncryptedConnection { ) -> Result<(), StorageError> { let mut row_iter = conn.load(sql_query("PRAGMA cipher_salt"))?; // cipher salt should always exist. if it doesn't SQLCipher is misconfigured. - let row = row_iter.next().ok_or(StorageError::NotFound( - "Cipher salt doesn't exist in database".into(), - ))??; + let row = row_iter + .next() + .ok_or(NotFound::CipherSalt(path.to_string()))??; let salt = >::build_from_row(&row)?; tracing::debug!( salt, diff --git a/xmtp_mls/src/storage/errors.rs b/xmtp_mls/src/storage/errors.rs index de25850ab..6a5a00905 100644 --- a/xmtp_mls/src/storage/errors.rs +++ b/xmtp_mls/src/storage/errors.rs @@ -3,7 +3,7 @@ use std::sync::PoisonError; use diesel::result::DatabaseErrorKind; use thiserror::Error; -use super::sql_key_store; +use super::{refresh_state::EntityKind, sql_key_store}; use crate::groups::intents::IntentError; use xmtp_common::{retryable, RetryableError}; @@ -27,9 +27,9 @@ pub enum StorageError { Serialization(String), #[error("deserialization error")] Deserialization(String), - #[error("{0} not found")] - NotFound(String), - #[error("lock")] + #[error(transparent)] + NotFound(#[from] NotFound), + #[error("lock {0}")] Lock(String), #[error("Pool needs to reconnect before use")] PoolNeedsConnection, @@ -47,6 +47,35 @@ pub enum StorageError { Duplicate(DuplicateItem), } +#[derive(Error, Debug)] +// Monolithic enum for all things lost +pub enum NotFound { + #[error("group with welcome id {0} not found")] + GroupByWelcome(i64), + #[error("group with id {id} not found", id = hex::encode(_0))] + GroupById(Vec), + #[error("installation time for group {id}", id = hex::encode(_0))] + InstallationTimeForGroup(Vec), + #[error("inbox id for address {0} not found")] + InboxIdForAddress(String), + #[error("message id {id} not found", id = hex::encode(_0))] + MessageById(Vec), + #[error("dm by dm_target_inbox_id {0} not found")] + DmByInbox(String), + #[error("intent with id {0} for state Publish from ToPublish not found")] + IntentForToPublish(i32), + #[error("intent with id {0} for state ToPublish from Published not found")] + IntentForPublish(i32), + #[error("intent with id {0} for state Committed from Published not found")] + IntentForCommitted(i32), + #[error("Intent with id {0} not found")] + IntentById(i32), + #[error("refresh state with id {id} and kind {1} not found", id = hex::encode(_0))] + RefreshStateByIdAndKind(Vec, EntityKind), + #[error("Cipher salt for db at [`{0}`] not found")] + CipherSalt(String), +} + #[derive(Error, Debug)] pub enum DuplicateItem { #[error("the welcome id {0:?} already exists")] @@ -102,6 +131,14 @@ impl RetryableError for StorageError { } } +impl RetryableError for NotFound { + fn is_retryable(&self) -> bool { + match self { + _ => true, + } + } +} + // OpenMLS KeyStore errors impl RetryableError for openmls::group::AddMembersError { fn is_retryable(&self) -> bool { diff --git a/xmtp_mls/src/subscriptions.rs b/xmtp_mls/src/subscriptions/mod.rs similarity index 88% rename from xmtp_mls/src/subscriptions.rs rename to xmtp_mls/src/subscriptions/mod.rs index f9675d500..531e41e0c 100644 --- a/xmtp_mls/src/subscriptions.rs +++ b/xmtp_mls/src/subscriptions/mod.rs @@ -10,6 +10,9 @@ use tracing::instrument; use xmtp_id::scw_verifier::SmartContractSignatureVerifier; use xmtp_proto::{api_client::XmtpMlsStreams, xmtp::mls::api::v1::WelcomeMessage}; +// mod stream_all; +// mod stream_conversations; + use crate::{ client::{extract_welcome_message, ClientError}, groups::{ @@ -454,9 +457,9 @@ where futures::pin_mut!(messages_stream); let convo_stream = self.stream_conversations(conversation_type).await?; - futures::pin_mut!(convo_stream); + tracing::info!("\n\n Waiting on messages \n\n"); let mut extra_messages = Vec::new(); loop { @@ -609,6 +612,45 @@ pub(crate) mod tests { use xmtp_cryptography::utils::generate_local_wallet; use xmtp_id::InboxOwner; + /// A macro for asserting that a stream yields a specific decrypted message. + /// + /// # Example + /// ```rust + /// assert_msg!(stream, b"first"); + /// ``` + #[macro_export] + macro_rules! assert_msg { + ($stream:expr, $expected:expr) => { + assert_eq!( + $stream + .next() + .await + .unwrap() + .unwrap() + .decrypted_message_bytes, + $expected.as_bytes() + ); + }; + } + + /// A macro for asserting that a stream yields a specific decrypted message. + /// + /// # Example + /// ```rust + /// assert_msg!(stream, b"first"); + /// ``` + #[macro_export] + macro_rules! assert_msg_exists { + ($stream:expr) => { + assert!(!$stream + .next() + .await + .unwrap() + .unwrap() + .decrypted_message_bytes + .is_empty()); + }; + } #[wasm_bindgen_test(unsupported = tokio::test(flavor = "multi_thread", worker_threads = 10))] async fn test_stream_welcomes() { let alice = Arc::new(ClientBuilder::new_test_client(&generate_local_wallet()).await); @@ -617,23 +659,8 @@ pub(crate) mod tests { .create_group(None, GroupMetadataOptions::default()) .unwrap(); - // FIXME:insipx we run into an issue where the reqwest::post().send() request - // blocks the executor and we cannot progress the runtime if we dont `tokio::spawn` this. - // A solution might be to use `hyper` instead, and implement a custom connection pool with - // `deadpool`. This is a bit more work but shouldn't be too complicated since - // we're only using `post` requests. It would be nice for all streams to work - // w/o spawning a separate task. - let (tx, rx) = tokio::sync::mpsc::unbounded_channel(); - let mut stream = tokio_stream::wrappers::UnboundedReceiverStream::new(rx); - let bob_ptr = bob.clone(); - crate::spawn(None, async move { - let bob_stream = bob_ptr.stream_conversations(None).await.unwrap(); - futures::pin_mut!(bob_stream); - while let Some(item) = bob_stream.next().await { - let _ = tx.send(item); - } - }); - + let stream = bob.stream_conversations(None).await.unwrap(); + futures::pin_mut!(stream); let group_id = alice_bob_group.group_id.clone(); alice_bob_group .add_members_by_inbox_id(&[bob.inbox_id()]) @@ -644,7 +671,7 @@ pub(crate) mod tests { assert_eq!(bob_received_groups.group_id, group_id); } - #[wasm_bindgen_test(unsupported = tokio::test(flavor = "multi_thread", worker_threads = 10))] + #[wasm_bindgen_test(unsupported = tokio::test(flavor = "current_thread"))] async fn test_stream_messages() { xmtp_common::logger(); let alice = Arc::new(ClientBuilder::new_test_client(&generate_local_wallet()).await); @@ -653,6 +680,7 @@ pub(crate) mod tests { let alice_group = alice .create_group(None, GroupMetadataOptions::default()) .unwrap(); + tracing::info!("Group Id = [{}]", hex::encode(&alice_group.group_id)); alice_group .add_members_by_inbox_id(&[bob.inbox_id()]) @@ -664,33 +692,16 @@ pub(crate) mod tests { .unwrap(); let bob_group = bob_group.first().unwrap(); - let notify = Delivery::new(None); - let notify_ptr = notify.clone(); - let (tx, rx) = tokio::sync::mpsc::unbounded_channel(); - crate::spawn(None, async move { - let stream = alice_group.stream().await.unwrap(); - futures::pin_mut!(stream); - while let Some(item) = stream.next().await { - let _ = tx.send(item); - notify_ptr.notify_one(); - } - }); - - let stream = tokio_stream::wrappers::UnboundedReceiverStream::new(rx); - // let stream = alice_group.stream().await.unwrap(); + let stream = alice_group.stream().await.unwrap(); futures::pin_mut!(stream); bob_group.send_message(b"hello").await.unwrap(); - tracing::debug!("Bob Sent Message!, waiting for delivery"); - // notify.wait_for_delivery().await.unwrap(); + let message = stream.next().await.unwrap().unwrap(); assert_eq!(message.decrypted_message_bytes, b"hello"); bob_group.send_message(b"hello2").await.unwrap(); - // notify.wait_for_delivery().await.unwrap(); let message = stream.next().await.unwrap().unwrap(); assert_eq!(message.decrypted_message_bytes, b"hello2"); - - // assert_eq!(bob_received_groups.group_id, alice_bob_group.group_id); } #[wasm_bindgen_test(unsupported = tokio::test(flavor = "multi_thread", worker_threads = 10))] @@ -714,40 +725,20 @@ pub(crate) mod tests { .add_members_by_inbox_id(&[caro.inbox_id()]) .await .unwrap(); - xmtp_common::time::sleep(core::time::Duration::from_millis(100)).await; - let messages: Arc>> = Arc::new(Mutex::new(Vec::new())); - let messages_clone = messages.clone(); + let stream = caro.stream_all_messages(None).await.unwrap(); + futures::pin_mut!(stream); + bo_group.send_message(b"first").await.unwrap(); + assert_msg!(stream, "first"); - let notify = Delivery::new(None); - let notify_pointer = notify.clone(); - let mut handle = Client::::stream_all_messages_with_callback( - Arc::new(caro), - None, - move |message| { - (*messages_clone.lock()).push(message.unwrap()); - notify_pointer.notify_one(); - }, - ); - handle.wait_for_ready().await; + bo_group.send_message(b"second").await.unwrap(); + assert_msg!(stream, "second"); - alix_group.send_message("first".as_bytes()).await.unwrap(); - notify - .wait_for_delivery() - .await - .expect("didn't get `first`"); - bo_group.send_message("second".as_bytes()).await.unwrap(); - notify.wait_for_delivery().await.unwrap(); - alix_group.send_message("third".as_bytes()).await.unwrap(); - notify.wait_for_delivery().await.unwrap(); - bo_group.send_message("fourth".as_bytes()).await.unwrap(); - notify.wait_for_delivery().await.unwrap(); + alix_group.send_message(b"third").await.unwrap(); + assert_msg!(stream, "third"); - let messages = messages.lock(); - assert_eq!(messages[0].decrypted_message_bytes, b"first"); - assert_eq!(messages[1].decrypted_message_bytes, b"second"); - assert_eq!(messages[2].decrypted_message_bytes, b"third"); - assert_eq!(messages[3].decrypted_message_bytes, b"fourth"); + bo_group.send_message(b"fourth").await.unwrap(); + assert_msg!(stream, "fourth"); } #[wasm_bindgen_test(unsupported = tokio::test(flavor = "multi_thread", worker_threads = 10))] @@ -765,39 +756,21 @@ pub(crate) mod tests { .await .unwrap(); - let messages: Arc>> = Arc::new(Mutex::new(Vec::new())); - let messages_clone = messages.clone(); - let delivery = Delivery::new(None); - let delivery_pointer = delivery.clone(); - let mut handle = Client::::stream_all_messages_with_callback( - caro.clone(), - None, - move |message| { - delivery_pointer.notify_one(); - (*messages_clone.lock()).push(message.unwrap()); - }, - ); - handle.wait_for_ready().await; + let stream = caro.stream_all_messages(None).await.unwrap(); + futures::pin_mut!(stream); + tracing::info!("\n\nSENDING FIRST MESSAGE\n\n"); alix_group.send_message(b"first").await.unwrap(); - delivery - .wait_for_delivery() - .await - .expect("timed out waiting for `first`"); + assert_msg!(stream, "first"); let bo_group = bo.create_dm(caro_wallet.get_address()).await.unwrap(); + assert_msg_exists!(stream); bo_group.send_message(b"second").await.unwrap(); - delivery - .wait_for_delivery() - .await - .expect("timed out waiting for `second`"); + assert_msg!(stream, "second"); alix_group.send_message(b"third").await.unwrap(); - delivery - .wait_for_delivery() - .await - .expect("timed out waiting for `third`"); + assert_msg!(stream, "third"); let alix_group_2 = alix .create_group(None, GroupMetadataOptions::default()) @@ -808,36 +781,10 @@ pub(crate) mod tests { .unwrap(); alix_group.send_message(b"fourth").await.unwrap(); - delivery - .wait_for_delivery() - .await - .expect("timed out waiting for `fourth`"); + assert_msg!(stream, "fourth"); alix_group_2.send_message(b"fifth").await.unwrap(); - delivery - .wait_for_delivery() - .await - .expect("timed out waiting for `fifth`"); - - { - let messages = messages.lock(); - assert_eq!(messages.len(), 5); - } - - let a = handle.abort_handle(); - a.end(); - let _ = handle.join().await; - assert!(a.is_finished()); - - alix_group - .send_message("should not show up".as_bytes()) - .await - .unwrap(); - xmtp_common::time::sleep(core::time::Duration::from_millis(100)).await; - - let messages = messages.lock(); - - assert_eq!(messages.len(), 5); + assert_msg!(stream, "fifth"); } #[ignore] diff --git a/xmtp_mls/src/subscriptions/stream_all.rs b/xmtp_mls/src/subscriptions/stream_all.rs new file mode 100644 index 000000000..a6b0c3913 --- /dev/null +++ b/xmtp_mls/src/subscriptions/stream_all.rs @@ -0,0 +1,87 @@ +use std::{collections::HashMap, sync::Arc}; + +use crate::{ + client::ClientError, + groups::scoped_client::ScopedGroupClient, + groups::subscriptions, + storage::{ + group::{ConversationType, GroupQueryArgs}, + group_message::StoredGroupMessage, + }, + Client, +}; +use futures::{ + stream::{self, Stream, StreamExt}, + Future, +}; +use xmtp_id::scw_verifier::SmartContractSignatureVerifier; +use xmtp_proto::api_client::{trait_impls::XmtpApi, XmtpMlsStreams}; + +use super::{MessagesStreamInfo, SubscribeError}; +pub struct StreamAllMessages<'a, C, Welcomes, Messages> { + /// The monolithic XMTP Client + client: &'a C, + /// Type of conversation to stream + conversation_type: Option, + /// Conversations that are being actively streamed + active_conversations: HashMap, MessagesStreamInfo>, + /// Welcomes Stream + welcomes: Welcomes, + /// Messages Stream + messages: Messages, + /// Extra messages from message stream, when the stream switches because + /// of a new group received. + extra_messages: Vec, +} + +impl<'a, A, V, Welcomes, Messages> StreamAllMessages<'a, Client, Welcomes, Messages> +where + A: XmtpApi + XmtpMlsStreams + Send + Sync + 'static, + V: SmartContractSignatureVerifier + Send + Sync + 'static, +{ + pub async fn new( + client: &'a Client, + conversation_type: Option, + ) -> Result { + let mut active_conversations = async { + let provider = client.mls_provider()?; + client.sync_welcomes(&provider).await?; + + let active_conversations = provider + .conn_ref() + .find_groups(GroupQueryArgs::default().maybe_conversation_type(conversation_type))? + .into_iter() + .map(Into::into) + .collect::, MessagesStreamInfo>>(); + Ok::<_, ClientError>(active_conversations) + } + .await?; + + let messages = + subscriptions::stream_messages(client, Arc::new(active_conversations.clone())).await?; + let welcomes = client.stream_conversations(conversation_type).await?; + + Self { + client, + conversation_type, + messages, + welcomes, + active_conversations, + extra_messages: Vec::new(), + } + } +} + +impl<'a, C, Welcomes, Messages> Stream for StreamAllMessages<'a, C, Welcomes, Messages> +where + C: ScopedGroupClient, +{ + type Item = Result; + + fn poll_next( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + todo!() + } +} diff --git a/xmtp_mls/src/subscriptions/stream_conversations.rs b/xmtp_mls/src/subscriptions/stream_conversations.rs new file mode 100644 index 000000000..3f3f3abb5 --- /dev/null +++ b/xmtp_mls/src/subscriptions/stream_conversations.rs @@ -0,0 +1,241 @@ +use std::{collections::HashSet, marker::PhantomData, sync::Arc, task::Poll}; + +use futures::{prelude::stream::Select, Stream}; +use pin_project_lite::pin_project; +use tokio_stream::wrappers::BroadcastStream; +use xmtp_common::{retry_async, Retry}; +use xmtp_id::scw_verifier::SmartContractSignatureVerifier; +use xmtp_proto::{ + api_client::{trait_impls::XmtpApi, XmtpMlsStreams}, + xmtp::mls::api::v1::WelcomeMessage, +}; + +use crate::{ + groups::{scoped_client::ScopedGroupClient, MlsGroup}, + storage::{group::ConversationType, DbConnection}, + Client, XmtpOpenMlsProvider, +}; + +use super::{LocalEvents, SubscribeError}; + +enum WelcomeOrGroup { + Group(Result, SubscribeError>), + Welcome(Result), +} + +pin_project! { + /// Broadcast stream filtered + mapped to WelcomeOrGroup + struct BroadcastGroupStream { + #[pin] inner: BroadcastStream>, + } +} + +impl BroadcastGroupStream { + fn new(inner: BroadcastStream>) -> Self { + Self { inner } + } +} + +impl Stream for BroadcastGroupStream +where + C: Clone + Send + Sync + 'static, // required by tokio::BroadcastStream +{ + type Item = WelcomeOrGroup; + + fn poll_next( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + use std::task::Poll::*; + let this = self.project(); + + match this.inner.poll_next(cx) { + Ready(Some(event)) => { + let ev = xmtp_common::optify!(event, "Missed messages due to event queue lag") + .and_then(LocalEvents::group_filter); + if let Some(g) = ev { + Ready(Some(WelcomeOrGroup::::Group(Ok(g)))) + } else { + // skip this item since it was either missed due to lag, or not a group + Pending + } + } + Pending => Pending, + Ready(None) => Ready(None), + } + } +} + +pin_project! { + /// Subscription Stream mapped to WelcomeOrGroup + struct SubscriptionStream { + #[pin] inner: S, + _marker: PhantomData, + } +} + +impl SubscriptionStream { + fn new(inner: S) -> Self { + Self { + inner, + _marker: PhantomData, + } + } +} + +impl Stream for SubscriptionStream +where + S: Stream>, +{ + type Item = WelcomeOrGroup; + + fn poll_next( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + use std::task::Poll::*; + let this = self.project(); + + match this.inner.poll_next(cx) { + Ready(Some(welcome)) => Ready(Some(WelcomeOrGroup::Welcome(welcome))), + Pending => Pending, + Ready(None) => Ready(None), + } + } +} + +pin_project! { + pub struct StreamConversations<'a, C, Subscription> { + client: &'a C, + #[pin] inner: Subscription, + conversation_type: Option, + known_welcome_ids: HashSet + } +} + +type MultiplexedSelect = Select, SubscriptionStream>; + +impl<'a, A, V> + StreamConversations< + 'a, + Client, + MultiplexedSelect, ::WelcomeMessageStream<'a>>, + > +where + A: XmtpApi + XmtpMlsStreams + Send + Sync + 'static, + V: SmartContractSignatureVerifier + Send + Sync + 'static, +{ + pub async fn new( + client: &'a Client, + conversation_type: Option, + conn: &DbConnection, + ) -> Result { + let installation_key = client.installation_public_key(); + let id_cursor = 0; + tracing::info!( + inbox_id = client.inbox_id(), + "Setting up conversation stream" + ); + + let events = + BroadcastGroupStream::new(BroadcastStream::new(client.local_events.subscribe())); + + let subscription = client + .api_client + .subscribe_welcome_messages(installation_key.as_ref(), Some(id_cursor)) + .await?; + let subscription = SubscriptionStream::new(subscription); + let known_welcome_ids = HashSet::from_iter(conn.group_welcome_ids()?.into_iter()); + + let stream = futures::stream::select(events, subscription); + + Ok(Self { + client, + inner: stream, + known_welcome_ids, + conversation_type, + }) + } +} + +impl<'a, C, Subscription> Stream for StreamConversations<'a, C, Subscription> +where + C: ScopedGroupClient + Clone, + Subscription: Stream, SubscribeError>>, +{ + type Item = Result, SubscribeError>; + + fn poll_next( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + use std::task::Poll::*; + let this = self.project(); + + match this.inner.poll_next(cx) { + Ready(Some(msg)) => { + todo!() + } + // stream ended + Ready(None) => Ready(None), + Pending => { + cx.waker().wake_by_ref(); + Pending + } + } + } +} + +impl<'a, C, Subscription> StreamConversations<'a, C, Subscription> +where + C: ScopedGroupClient + Clone, +{ + async fn process_streamed_welcome( + &mut self, + client: C, + provider: &XmtpOpenMlsProvider, + welcome: WelcomeMessage, + ) -> Result, SubscribeError> { + let welcome_v1 = crate::client::extract_welcome_message(welcome)?; + if self.known_welcome_ids.contains(&(welcome_v1.id as i64)) { + let conn = provider.conn_ref(); + self.known_welcome_ids.insert(welcome_v1.id as i64); + let group = conn.find_group_by_welcome_id(welcome_v1.id as i64)?; + tracing::info!( + inbox_id = client.inbox_id(), + group_id = hex::encode(&group.id), + welcome_id = ?group.welcome_id, + "Loading existing group for welcome_id: {:?}", + group.welcome_id + ); + return Ok(MlsGroup::new(client.clone(), group.id, group.created_at_ns)); + } + + let creation_result = retry_async!( + Retry::default(), + (async { + tracing::info!( + installation_id = &welcome_v1.id, + "Trying to process streamed welcome" + ); + let welcome_v1 = &welcome_v1; + client + .context + .store() + .transaction_async(provider, |provider| async move { + MlsGroup::create_from_encrypted_welcome( + Arc::new(client.clone()), + provider, + welcome_v1.hpke_public_key.as_slice(), + &welcome_v1.data, + welcome_v1.id as i64, + ) + .await + }) + .await + }) + ); + + Ok(creation_result?) + } +}