From 6797962f4e928a1cc141fd998c97aab8e119e384 Mon Sep 17 00:00:00 2001 From: Andrew Plaza Date: Fri, 6 Dec 2024 13:27:35 -0500 Subject: [PATCH] share provider with transactions & everywhere else --- bindings_ffi/src/lib.rs | 2 + bindings_ffi/src/mls.rs | 69 +++-- bindings_node/src/client.rs | 7 +- bindings_node/src/conversation.rs | 3 +- bindings_node/src/conversations.rs | 14 +- bindings_wasm/src/client.rs | 8 +- bindings_wasm/src/conversations.rs | 14 +- examples/cli/cli-client.rs | 17 +- xmtp_debug/src/app/generate/messages.rs | 6 +- xmtp_debug/src/app/send.rs | 4 +- xmtp_mls/src/api/mls.rs | 10 +- xmtp_mls/src/builder.rs | 26 +- xmtp_mls/src/client.rs | 247 ++++++++++-------- xmtp_mls/src/groups/device_sync.rs | 30 ++- .../src/groups/device_sync/consent_sync.rs | 6 +- .../src/groups/device_sync/message_sync.rs | 8 +- xmtp_mls/src/groups/intents.rs | 8 +- xmtp_mls/src/groups/mls_sync.rs | 88 +++++-- xmtp_mls/src/groups/mod.rs | 134 +++++----- xmtp_mls/src/groups/scoped_client.rs | 32 +-- xmtp_mls/src/groups/subscriptions.rs | 45 ++-- xmtp_mls/src/identity.rs | 7 +- xmtp_mls/src/intents.rs | 66 +---- .../storage/encrypted_store/db_connection.rs | 2 + xmtp_mls/src/storage/encrypted_store/mod.rs | 73 ++++-- .../storage/encrypted_store/refresh_state.rs | 8 +- xmtp_mls/src/subscriptions.rs | 148 ++++++----- xmtp_mls/src/types.rs | 86 +++++- xmtp_mls/src/xmtp_openmls_provider.rs | 6 +- 29 files changed, 668 insertions(+), 506 deletions(-) diff --git a/bindings_ffi/src/lib.rs b/bindings_ffi/src/lib.rs index 670ec176e..7cd5a5b65 100755 --- a/bindings_ffi/src/lib.rs +++ b/bindings_ffi/src/lib.rs @@ -63,6 +63,8 @@ pub enum GenericError { pub enum FfiSubscribeError { #[error("Subscribe Error {0}")] Subscribe(#[from] xmtp_mls::subscriptions::SubscribeError), + #[error("Storage error: {0}")] + Storage(#[from] xmtp_mls::storage::StorageError), } impl From for GenericError { diff --git a/bindings_ffi/src/mls.rs b/bindings_ffi/src/mls.rs index 4861deb56..f5014e982 100644 --- a/bindings_ffi/src/mls.rs +++ b/bindings_ffi/src/mls.rs @@ -295,8 +295,8 @@ impl FfiXmtpClient { pub async fn find_inbox_id(&self, address: String) -> Result, GenericError> { let inner = self.inner_client.as_ref(); - - let result = inner.find_inbox_id_from_address(address).await?; + let conn = self.inner_client.store().conn()?; + let result = inner.find_inbox_id_from_address(&conn, address).await?; Ok(result) } @@ -881,8 +881,8 @@ impl FfiConversations { pub async fn sync(&self) -> Result<(), GenericError> { let inner = self.inner_client.as_ref(); - let conn = inner.store().conn()?; - inner.sync_welcomes(&conn).await?; + let provider = inner.mls_provider()?; + inner.sync_welcomes(&provider).await?; Ok(()) } @@ -898,12 +898,11 @@ impl FfiConversations { consent_state: Option, ) -> Result { let inner = self.inner_client.as_ref(); - let conn = inner.store().conn()?; - + let provider = inner.mls_provider()?; let consent: Option = consent_state.map(|state| state.into()); - - let num_groups_synced: usize = inner.sync_all_welcomes_and_groups(&conn, consent).await?; - + let num_groups_synced: usize = inner + .sync_all_welcomes_and_groups(&provider, consent) + .await?; // Convert usize to u32 for compatibility with Uniffi let num_groups_synced: u32 = num_groups_synced .try_into() @@ -1267,9 +1266,10 @@ impl FfiConversation { &self, envelope_bytes: Vec, ) -> Result { + let provider = self.inner.mls_provider()?; let message = self .inner - .process_streamed_group_message(envelope_bytes) + .process_streamed_group_message(&provider, envelope_bytes) .await?; let ffi_message = message.into(); @@ -1853,6 +1853,8 @@ mod tests { conversations: Mutex>>, consent_updates: Mutex>, notify: Notify, + inbox_id: Option, + installation_id: Option, } impl RustStreamCallback { @@ -1872,12 +1874,22 @@ mod tests { .await?; Ok(()) } + + pub fn from_client(client: &FfiXmtpClient) -> Self { + RustStreamCallback { + inbox_id: Some(client.inner_client.inbox_id().to_string()), + installation_id: Some(hex::encode(client.inner_client.installation_public_key())), + ..Default::default() + } + } } impl FfiMessageCallback for RustStreamCallback { fn on_message(&self, message: FfiMessage) { let mut messages = self.messages.lock().unwrap(); log::info!( + inbox_id = self.inbox_id, + installation_id = self.installation_id, "ON MESSAGE Received\n-------- \n{}\n----------", String::from_utf8_lossy(&message.content) ); @@ -1893,7 +1905,11 @@ mod tests { impl FfiConversationCallback for RustStreamCallback { fn on_conversation(&self, group: Arc) { - log::debug!("received conversation"); + log::debug!( + inbox_id = self.inbox_id, + installation_id = self.installation_id, + "received conversation" + ); let _ = self.num_messages.fetch_add(1, Ordering::SeqCst); let mut convos = self.conversations.lock().unwrap(); convos.push(group); @@ -1907,7 +1923,11 @@ mod tests { impl FfiConsentCallback for RustStreamCallback { fn on_consent_update(&self, mut consent: Vec) { - log::debug!("received consent update"); + log::debug!( + inbox_id = self.inbox_id, + installation_id = self.installation_id, + "received consent update" + ); let mut consent_updates = self.consent_updates.lock().unwrap(); consent_updates.append(&mut consent); self.notify.notify_one(); @@ -2774,7 +2794,7 @@ mod tests { let caro = new_test_client().await; // Alix begins a stream for all messages - let message_callbacks = Arc::new(RustStreamCallback::default()); + let message_callbacks = Arc::new(RustStreamCallback::from_client(&alix)); let stream_messages = alix .conversations() .stream_all_messages(message_callbacks.clone()) @@ -2821,12 +2841,12 @@ mod tests { let bo2 = new_test_client_with_wallet(bo_wallet).await; // Bo begins a stream for all messages - let bo_message_callbacks = Arc::new(RustStreamCallback::default()); - let bo_stream_messages = bo2 + let bo2_message_callbacks = Arc::new(RustStreamCallback::from_client(&bo2)); + let bo2_stream_messages = bo2 .conversations() - .stream_all_messages(bo_message_callbacks.clone()) + .stream_all_messages(bo2_message_callbacks.clone()) .await; - bo_stream_messages.wait_for_ready().await; + bo2_stream_messages.wait_for_ready().await; alix_group.update_installations().await.unwrap(); @@ -3190,7 +3210,7 @@ mod tests { let bo = new_test_client().await; let caro = new_test_client().await; - let caro_conn = caro.inner_client.store().conn().unwrap(); + let caro_provider = caro.inner_client.mls_provider().unwrap(); let alix_group = alix .conversations() @@ -3220,7 +3240,11 @@ mod tests { ) .await .unwrap(); - let _ = caro.inner_client.sync_welcomes(&caro_conn).await.unwrap(); + let _ = caro + .inner_client + .sync_welcomes(&caro_provider) + .await + .unwrap(); bo_group.send("second".as_bytes().to_vec()).await.unwrap(); stream_callback.wait_for_delivery(None).await.unwrap(); @@ -3239,7 +3263,7 @@ mod tests { let amal = new_test_client().await; let bola = new_test_client().await; - let bola_conn = bola.inner_client.store().conn().unwrap(); + let bola_provider = bola.inner_client.mls_provider().unwrap(); let amal_group: Arc = amal .conversations() @@ -3250,7 +3274,10 @@ mod tests { .await .unwrap(); - bola.inner_client.sync_welcomes(&bola_conn).await.unwrap(); + bola.inner_client + .sync_welcomes(&bola_provider) + .await + .unwrap(); let bola_group = bola.conversation(amal_group.id()).unwrap(); let stream_callback = Arc::new(RustStreamCallback::default()); diff --git a/bindings_node/src/client.rs b/bindings_node/src/client.rs index 5190525f5..cbe595bd9 100644 --- a/bindings_node/src/client.rs +++ b/bindings_node/src/client.rs @@ -275,9 +275,14 @@ impl Client { #[napi] pub async fn find_inbox_id_by_address(&self, address: String) -> Result> { + let conn = self + .inner_client() + .store() + .conn() + .map_err(ErrorWrapper::from)?; let inbox_id = self .inner_client - .find_inbox_id_from_address(address) + .find_inbox_id_from_address(&conn, address) .await .map_err(ErrorWrapper::from)?; diff --git a/bindings_node/src/conversation.rs b/bindings_node/src/conversation.rs index aeeffa6b1..897662463 100644 --- a/bindings_node/src/conversation.rs +++ b/bindings_node/src/conversation.rs @@ -197,8 +197,9 @@ impl Conversation { self.created_at_ns, ); let envelope_bytes: Vec = envelope_bytes.deref().to_vec(); + let provider = group.mls_provider().map_err(ErrorWrapper::from)?; let message = group - .process_streamed_group_message(envelope_bytes) + .process_streamed_group_message(&provider, envelope_bytes) .await .map_err(ErrorWrapper::from)?; diff --git a/bindings_node/src/conversations.rs b/bindings_node/src/conversations.rs index 2ee01f5c6..f957a9807 100644 --- a/bindings_node/src/conversations.rs +++ b/bindings_node/src/conversations.rs @@ -235,14 +235,13 @@ impl Conversations { #[napi] pub async fn sync(&self) -> Result<()> { - let conn = self + let provider = self .inner_client - .store() - .conn() + .mls_provider() .map_err(ErrorWrapper::from)?; self .inner_client - .sync_welcomes(&conn) + .sync_welcomes(&provider) .await .map_err(ErrorWrapper::from)?; Ok(()) @@ -250,15 +249,14 @@ impl Conversations { #[napi] pub async fn sync_all_conversations(&self) -> Result { - let conn = self + let provider = self .inner_client - .store() - .conn() + .mls_provider() .map_err(ErrorWrapper::from)?; let num_groups_synced = self .inner_client - .sync_all_welcomes_and_groups(&conn, None) + .sync_all_welcomes_and_groups(&provider, None) .await .map_err(ErrorWrapper::from)?; diff --git a/bindings_wasm/src/client.rs b/bindings_wasm/src/client.rs index 1db20bf67..f00088619 100644 --- a/bindings_wasm/src/client.rs +++ b/bindings_wasm/src/client.rs @@ -12,7 +12,6 @@ use xmtp_api_http::XmtpHttpApiClient; use xmtp_cryptography::signature::ed25519_public_key_to_address; use xmtp_id::associations::builder::SignatureRequest; use xmtp_mls::builder::ClientBuilder; -use xmtp_mls::groups::scoped_client::ScopedGroupClient; use xmtp_mls::identity::IdentityStrategy; use xmtp_mls::storage::{EncryptedMessageStore, EncryptionKey, StorageOption}; use xmtp_mls::Client as MlsClient; @@ -273,9 +272,14 @@ impl Client { #[wasm_bindgen(js_name = findInboxIdByAddress)] pub async fn find_inbox_id_by_address(&self, address: String) -> Result, JsError> { + let conn = self + .inner_client + .store() + .conn() + .map_err(|e| JsError::new(format!("{}", e).as_str()))?; let inbox_id = self .inner_client - .find_inbox_id_from_address(address) + .find_inbox_id_from_address(&conn, address) .await .map_err(|e| JsError::new(format!("{}", e).as_str()))?; diff --git a/bindings_wasm/src/conversations.rs b/bindings_wasm/src/conversations.rs index 790868d29..50f0790a2 100644 --- a/bindings_wasm/src/conversations.rs +++ b/bindings_wasm/src/conversations.rs @@ -269,14 +269,13 @@ impl Conversations { #[wasm_bindgen] pub async fn sync(&self) -> Result<(), JsError> { - let conn = self + let provider = self .inner_client - .store() - .conn() + .mls_provider() .map_err(|e| JsError::new(format!("{}", e).as_str()))?; self .inner_client - .sync_welcomes(&conn) + .sync_welcomes(&provider) .await .map_err(|e| JsError::new(format!("{}", e).as_str()))?; @@ -285,15 +284,14 @@ impl Conversations { #[wasm_bindgen(js_name = syncAllConversations)] pub async fn sync_all_conversations(&self) -> Result { - let conn = self + let provider = self .inner_client - .store() - .conn() + .mls_provider() .map_err(|e| JsError::new(format!("{}", e).as_str()))?; let num_groups_synced = self .inner_client - .sync_all_welcomes_and_groups(&conn, None) + .sync_all_welcomes_and_groups(&provider, None) .await .map_err(|e| JsError::new(format!("{}", e).as_str()))?; diff --git a/examples/cli/cli-client.rs b/examples/cli/cli-client.rs index b371edef1..a9513bec5 100755 --- a/examples/cli/cli-client.rs +++ b/examples/cli/cli-client.rs @@ -290,9 +290,9 @@ async fn main() -> color_eyre::eyre::Result<()> { } Commands::ListGroups {} => { info!("List Groups"); - let conn = client.store().conn()?; + let provider = client.mls_provider()?; client - .sync_welcomes(&conn) + .sync_welcomes(&provider) .await .expect("failed to sync welcomes"); @@ -440,9 +440,8 @@ async fn main() -> color_eyre::eyre::Result<()> { ); } Commands::RequestHistorySync {} => { - let conn = client.store().conn().unwrap(); let provider = client.mls_provider().unwrap(); - client.sync_welcomes(&conn).await.unwrap(); + client.sync_welcomes(&provider).await.unwrap(); client.start_sync_worker(); client .send_sync_request(&provider, DeviceSyncKind::MessageHistory) @@ -451,9 +450,9 @@ async fn main() -> color_eyre::eyre::Result<()> { info!("Sent history sync request in sync group.") } Commands::ListHistorySyncMessages {} => { - let conn = client.store().conn()?; - client.sync_welcomes(&conn).await?; - let group = client.get_sync_group(&conn)?; + let provider = client.mls_provider()?; + client.sync_welcomes(&provider).await?; + let group = client.get_sync_group(provider.conn_ref())?; let group_id_str = hex::encode(group.group_id.clone()); group.sync().await?; let messages = group @@ -574,8 +573,8 @@ where } async fn get_group(client: &Client, group_id: Vec) -> Result { - let conn = client.store().conn().unwrap(); - client.sync_welcomes(&conn).await?; + let provider = client.mls_provider().unwrap(); + client.sync_welcomes(&provider).await?; let group = client.group(group_id)?; group .sync() diff --git a/xmtp_debug/src/app/generate/messages.rs b/xmtp_debug/src/app/generate/messages.rs index aeb118ff7..fc6006d85 100644 --- a/xmtp_debug/src/app/generate/messages.rs +++ b/xmtp_debug/src/app/generate/messages.rs @@ -8,7 +8,6 @@ use crate::{ use color_eyre::eyre::{self, eyre, Result}; use rand::{rngs::SmallRng, seq::SliceRandom, Rng, SeedableRng}; use std::sync::Arc; -use xmtp_mls::XmtpOpenMlsProvider; mod content_type; @@ -118,10 +117,9 @@ impl GenerateMessages { hex::encode(inbox_id) ))?; let client = app::client_from_identity(&identity, &network).await?; - let conn = client.store().conn()?; - client.sync_welcomes(&conn).await?; + let provider = client.mls_provider()?; + client.sync_welcomes(&provider).await?; let group = client.group(group.id.into())?; - let provider: XmtpOpenMlsProvider = conn.into(); group.maybe_update_installations(&provider, None).await?; group.sync_with_conn(&provider).await?; let words = rng.gen_range(0..*max_message_size); diff --git a/xmtp_debug/src/app/send.rs b/xmtp_debug/src/app/send.rs index 59601d4fe..35a95cbf7 100644 --- a/xmtp_debug/src/app/send.rs +++ b/xmtp_debug/src/app/send.rs @@ -49,8 +49,8 @@ impl Send { .ok_or(eyre!("No Identity with inbox_id [{}]", hex::encode(member)))?; let client = crate::app::client_from_identity(&identity, network).await?; - let conn = client.store().conn()?; - client.sync_welcomes(&conn).await?; + let provider = client.mls_provider()?; + client.sync_welcomes(&provider).await?; let xmtp_group = client.group(group.id.to_vec())?; xmtp_group.send_message(data.as_bytes()).await?; Ok(()) diff --git a/xmtp_mls/src/api/mls.rs b/xmtp_mls/src/api/mls.rs index 4000467e6..33a42a577 100644 --- a/xmtp_mls/src/api/mls.rs +++ b/xmtp_mls/src/api/mls.rs @@ -115,9 +115,9 @@ where } #[tracing::instrument(level = "trace", skip_all)] - pub async fn query_welcome_messages( + pub async fn query_welcome_messages + Copy>( &self, - installation_id: &[u8], + installation_id: Id, id_cursor: Option, ) -> Result, ApiError> { tracing::debug!( @@ -135,7 +135,7 @@ where (async { self.api_client .query_welcome_messages(QueryWelcomeMessagesRequest { - installation_key: installation_id.to_vec(), + installation_key: installation_id.as_ref().to_vec(), paging_info: Some(PagingInfo { id_cursor: id_cursor.unwrap_or(0), limit: page_size, @@ -297,7 +297,7 @@ where pub async fn subscribe_welcome_messages( &self, - installation_key: Vec, + installation_key: &[u8], id_cursor: Option, ) -> Result> + '_, ApiError> where @@ -307,7 +307,7 @@ where self.api_client .subscribe_welcome_messages(SubscribeWelcomeMessagesRequest { filters: vec![WelcomeFilterProto { - installation_key, + installation_key: installation_key.to_vec(), id_cursor: id_cursor.unwrap_or(0), }], }) diff --git a/xmtp_mls/src/builder.rs b/xmtp_mls/src/builder.rs index 930c81506..997a92ce8 100644 --- a/xmtp_mls/src/builder.rs +++ b/xmtp_mls/src/builder.rs @@ -13,7 +13,7 @@ use crate::{ identity_updates::load_identity_updates, retry::Retry, storage::EncryptedMessageStore, - StorageError, XmtpApi, + StorageError, XmtpApi, XmtpOpenMlsProvider, }; #[derive(Error, Debug)] @@ -186,20 +186,22 @@ where let store = store .take() .ok_or(ClientBuilderError::MissingParameter { parameter: "store" })?; + let conn = store.conn()?; + let provider = XmtpOpenMlsProvider::new(conn); + let identity = identity_strategy + .initialize_identity(&api_client_wrapper, &provider, &scw_verifier) + .await?; debug!( - inbox_id = identity_strategy.inbox_id(), - "Initializing identity" + inbox_id = identity.inbox_id(), + installation_id = hex::encode(identity.installation_keys.public_bytes()), + "Initialized identity" ); - let identity = identity_strategy - .initialize_identity(&api_client_wrapper, &store, &scw_verifier) - .await?; - // get sequence_id from identity updates and loaded into the DB load_identity_updates( &api_client_wrapper, - &store.conn()?, + provider.conn_ref(), vec![identity.inbox_id.as_str()].as_slice(), ) .await?; @@ -557,7 +559,7 @@ pub(crate) mod tests { let identity = IdentityStrategy::new("other_inbox_id".to_string(), address, nonce, None); assert!(matches!( identity - .initialize_identity(&wrapper, &store, &scw_verifier) + .initialize_identity(&wrapper, &store.mls_provider().unwrap(), &scw_verifier) .await .unwrap_err(), IdentityError::NewIdentity(msg) if msg == "Inbox ID mismatch" @@ -598,7 +600,7 @@ pub(crate) mod tests { let identity = IdentityStrategy::new(inbox_id.clone(), address, nonce, None); assert!(dbg!( identity - .initialize_identity(&wrapper, &store, &scw_verifier) + .initialize_identity(&wrapper, &store.mls_provider().unwrap(), &scw_verifier) .await ) .is_ok()); @@ -636,7 +638,7 @@ pub(crate) mod tests { let wrapper = ApiClientWrapper::new(mock_api.into(), Retry::default()); let identity = IdentityStrategy::new(inbox_id.clone(), address, nonce, None); assert!(identity - .initialize_identity(&wrapper, &store, &scw_verifier) + .initialize_identity(&wrapper, &store.mls_provider().unwrap(), &scw_verifier) .await .is_ok()); } @@ -676,7 +678,7 @@ pub(crate) mod tests { let inbox_id = "inbox_id".to_string(); let identity = IdentityStrategy::new(inbox_id.clone(), address.clone(), nonce, None); let err = identity - .initialize_identity(&wrapper, &store, &scw_verifier) + .initialize_identity(&wrapper, &store.mls_provider().unwrap(), &scw_verifier) .await .unwrap_err(); diff --git a/xmtp_mls/src/client.rs b/xmtp_mls/src/client.rs index 051c12500..81751d5a8 100644 --- a/xmtp_mls/src/client.rs +++ b/xmtp_mls/src/client.rs @@ -12,7 +12,6 @@ use openmls::{ messages::Welcome, prelude::tls_codec::{Deserialize, Error as TlsCodecError}, }; -use openmls_traits::OpenMlsProvider; use thiserror::Error; use tokio::sync::broadcast; @@ -36,7 +35,7 @@ use crate::{ groups::{group_permissions::PolicySet, GroupError, GroupMetadataOptions, MlsGroup}, identity::{parse_credential, Identity, IdentityError}, identity_updates::{load_identity_updates, IdentityUpdateError}, - intents::Intents, + intents::ProcessIntentError, mutex_registry::MutexRegistry, preferences::UserPreferenceUpdate, retry::Retry, @@ -51,6 +50,7 @@ use crate::{ EncryptedMessageStore, StorageError, }, subscriptions::{LocalEventError, LocalEvents}, + types::InstallationId, verified_key_package_v2::{KeyPackageVerificationError, VerifiedKeyPackageV2}, xmtp_openmls_provider::XmtpOpenMlsProvider, Fetch, Store, XmtpApi, @@ -138,7 +138,6 @@ impl From<&str> for ClientError { /// Clients manage access to the network, identity, and data store pub struct Client> { pub(crate) api_client: Arc>, - pub(crate) intents: Arc, pub(crate) context: Arc, pub(crate) history_sync_url: Option, pub(crate) local_events: broadcast::Sender>, @@ -155,7 +154,6 @@ impl Clone for Client { history_sync_url: self.history_sync_url.clone(), local_events: self.local_events.clone(), scw_verifier: self.scw_verifier.clone(), - intents: self.intents.clone(), } } } @@ -173,8 +171,8 @@ pub struct XmtpMlsLocalContext { impl XmtpMlsLocalContext { /// The installation public key is the primary identifier for an installation - pub fn installation_public_key(&self) -> &[u8; 32] { - self.identity.installation_keys.public_bytes() + pub fn installation_public_key(&self) -> InstallationId { + (*self.identity.installation_keys.public_bytes()).into() } /// Get the account address of the blockchain account associated with this client @@ -192,7 +190,7 @@ impl XmtpMlsLocalContext { } /// Pulls a new database connection and creates a new provider - pub fn mls_provider(&self) -> Result { + pub fn mls_provider(&self) -> Result { Ok(self.store.conn()?.into()) } @@ -234,9 +232,6 @@ where store, mutexes: MutexRegistry::new(), }); - let intents = Arc::new(Intents { - context: context.clone(), - }); let (tx, _) = broadcast::channel(32); Self { @@ -245,7 +240,6 @@ where history_sync_url, local_events: tx, scw_verifier: scw_verifier.into(), - intents, } } @@ -279,7 +273,7 @@ where V: SmartContractSignatureVerifier, { /// Retrieves the client's installation public key, sometimes also called `installation_id` - pub fn installation_public_key(&self) -> &[u8; 32] { + pub fn installation_public_key(&self) -> InstallationId { self.context.installation_public_key() } /// Retrieves the client's inbox ID @@ -287,12 +281,8 @@ where self.context.inbox_id() } - pub fn intents(&self) -> &Arc { - &self.intents - } - /// Pulls a connection and creates a new MLS Provider - pub fn mls_provider(&self) -> Result { + pub fn mls_provider(&self) -> Result { self.context.mls_provider() } @@ -303,9 +293,10 @@ where /// Calls the server to look up the `inbox_id` associated with a given address pub async fn find_inbox_id_from_address( &self, + conn: &DbConnection, address: String, ) -> Result, ClientError> { - let results = self.find_inbox_ids_from_addresses(&[address]).await?; + let results = self.find_inbox_ids_from_addresses(conn, &[address]).await?; if let Some(first_result) = results.into_iter().next() { Ok(first_result) } else { @@ -315,12 +306,12 @@ where /// Calls the server to look up the `inbox_id`s` associated with a list of addresses. /// If no `inbox_id` is found, returns None. - pub async fn find_inbox_ids_from_addresses( + pub(crate) async fn find_inbox_ids_from_addresses( &self, + conn: &DbConnection, addresses: &[String], ) -> Result>, ClientError> { let sanitized_addresses = sanitize_evm_addresses(addresses)?; - let conn = self.store().conn()?; let local_results: Vec = conn.fetch_wallets_list_with_key(&sanitized_addresses)?; @@ -354,7 +345,7 @@ where inbox_id: InboxId::from(inbox_id), wallet_address: address, }; - new_entry.store(&conn).ok(); + new_entry.store(conn).ok(); } let inbox_ids: Vec> = sanitized_addresses @@ -424,7 +415,7 @@ where } let inbox_ids = self - .find_inbox_ids_from_addresses(&addresses_to_lookup) + .find_inbox_ids_from_addresses(&conn, &addresses_to_lookup) .await?; for (i, inbox_id_opt) in inbox_ids.into_iter().enumerate() { @@ -464,7 +455,10 @@ where ) -> Result { let conn = self.store().conn()?; let record = if entity_type == ConsentType::Address { - if let Some(inbox_id) = self.find_inbox_id_from_address(entity.clone()).await? { + if let Some(inbox_id) = self + .find_inbox_id_from_address(&conn, entity.clone()) + .await? + { conn.get_consent_record(inbox_id, ConsentType::InboxId)? } else { conn.get_consent_record(entity, entity_type)? @@ -509,9 +503,10 @@ where opts: GroupMetadataOptions, ) -> Result, ClientError> { tracing::info!("creating group"); - + let provider = self.mls_provider()?; let group: MlsGroup> = MlsGroup::create_and_insert( Arc::new(self.clone()), + &provider, GroupMembershipState::Allowed, permissions_policy_set.unwrap_or_default(), opts, @@ -541,9 +536,10 @@ where /// Create a new Direct Message with the default settings pub async fn create_dm(&self, account_address: String) -> Result, ClientError> { tracing::info!("creating dm with address: {}", account_address); + let provider = self.mls_provider()?; let inbox_id = match self - .find_inbox_id_from_address(account_address.clone()) + .find_inbox_id_from_address(provider.conn_ref(), account_address.clone()) .await? { Some(id) => id, @@ -555,23 +551,26 @@ where } }; - self.create_dm_by_inbox_id(inbox_id).await + self.create_dm_by_inbox_id(&provider, inbox_id).await } /// Create a new Direct Message with the default settings - pub async fn create_dm_by_inbox_id( + pub(crate) async fn create_dm_by_inbox_id( &self, + provider: &XmtpOpenMlsProvider, dm_target_inbox_id: InboxId, ) -> Result, ClientError> { tracing::info!("creating dm with {}", dm_target_inbox_id); - let group: MlsGroup> = MlsGroup::create_dm_and_insert( + provider, Arc::new(self.clone()), GroupMembershipState::Allowed, dm_target_inbox_id.clone(), )?; - group.add_members_by_inbox_id(&[dm_target_inbox_id]).await?; + group + .add_members_by_inbox_id_with_provider(provider, &[dm_target_inbox_id]) + .await?; // notify any streams of the new group let _ = self.local_events.send(LocalEvents::NewGroup(group.clone())); @@ -694,12 +693,19 @@ where /// Upload a new key package to the network replacing an existing key package /// This is expected to be run any time the client receives new Welcome messages - pub async fn rotate_key_package(&self) -> Result<(), ClientError> { + pub async fn rotate_key_package( + &self, + provider: &XmtpOpenMlsProvider, + ) -> Result<(), ClientError> { self.store() - .transaction_async(|provider| async move { - self.identity() - .rotate_key_package(&provider, &self.api_client) - .await + .transaction_async(provider, move |provider| { + let provider = &provider; + async { + self.identity() + .rotate_key_package(provider, &self.api_client) + .await?; + Ok::<_, IdentityError>(()) + } }) .await?; @@ -734,7 +740,7 @@ where let welcomes = self .api_client - .query_welcome_messages(installation_id, Some(id_cursor as u64)) + .query_welcome_messages(installation_id.as_ref(), Some(id_cursor as u64)) .await?; Ok(welcomes) @@ -748,10 +754,10 @@ where ) -> Result, ClientError> { let key_package_results = self.api_client.fetch_key_packages(installation_ids).await?; - let mls_provider = self.mls_provider()?; + let crypto_provider = XmtpOpenMlsProvider::new_crypto(); Ok(key_package_results .values() - .map(|bytes| VerifiedKeyPackageV2::from_bytes(mls_provider.crypto(), bytes.as_slice())) + .map(|bytes| VerifiedKeyPackageV2::from_bytes(&crypto_provider, bytes.as_slice())) .collect::>()?) } @@ -760,11 +766,10 @@ where #[tracing::instrument(level = "debug", skip_all)] pub async fn sync_welcomes( &self, - conn: &DbConnection, + provider: &XmtpOpenMlsProvider, ) -> Result>, ClientError> { - let envelopes = self.query_welcome_messages(conn).await?; + let envelopes = self.query_welcome_messages(provider.conn_ref()).await?; let num_envelopes = envelopes.len(); - let id = self.installation_public_key(); let groups: Vec> = stream::iter(envelopes.into_iter()) .filter_map(|envelope: WelcomeMessage| async { @@ -777,73 +782,86 @@ where }; retry_async!( Retry::default(), - (async { - let welcome_v1 = &welcome_v1; - self.intents.process_for_id( - id, - EntityKind::Welcome, - welcome_v1.id, - |provider| async move { - let result = MlsGroup::create_from_encrypted_welcome( - Arc::new(self.clone()), - &provider, - welcome_v1.hpke_public_key.as_slice(), - &welcome_v1.data, - welcome_v1.id as i64, - ) - .await; - - match result { - Ok(mls_group) => Ok(Some(mls_group)), - Err(err) => { - use crate::StorageError::*; - use crate::DuplicateItem::*; - - if matches!(err, GroupError::Storage(Duplicate(WelcomeId(_)))) { - tracing::warn!("failed to create group from welcome due to duplicate welcome ID: {}", err); - } else { - tracing::error!("failed to create group from welcome: {}", err); - } - - Err(err) - } - } - }, - ) - .await - }) + (async { self.process_new_welcome(provider, &welcome_v1).await }) ) .ok() - .flatten() }) .collect() .await; // If any welcomes were found, rotate your key package if num_envelopes > 0 { - self.rotate_key_package().await?; + self.rotate_key_package(provider).await?; } Ok(groups) } + /// Internal API to process a unread welcome message and convert to a group. + /// In a database transaction, increments the cursor for a given installation and + /// applies the update after the welcome processed succesfully. + async fn process_new_welcome( + &self, + provider: &XmtpOpenMlsProvider, + welcome: &WelcomeMessageV1, + ) -> Result, GroupError> { + self.store() + .transaction_async(provider, |provider| async move { + let cursor = welcome.id; + let is_updated = provider.conn_ref().update_cursor( + self.installation_public_key(), + EntityKind::Welcome, + welcome.id as i64, + )?; + if !is_updated { + return Err(ProcessIntentError::AlreadyProcessed(cursor).into()); + } + let result = MlsGroup::create_from_encrypted_welcome( + Arc::new(self.clone()), + provider, + welcome.hpke_public_key.as_slice(), + &welcome.data, + welcome.id as i64, + ) + .await; + + match result { + Ok(mls_group) => Ok(mls_group), + Err(err) => { + use crate::DuplicateItem::*; + use crate::StorageError::*; + + if matches!(err, GroupError::Storage(Duplicate(WelcomeId(_)))) { + tracing::warn!( + "failed to create group from welcome due to duplicate welcome ID: {}", + err + ); + } else { + tracing::error!("failed to create group from welcome: {}", err); + } + + Err(err) + } + } + }) + .await + } + /// Sync all groups for the current installation and return the number of groups that were synced. /// Only active groups will be synced. - pub async fn sync_all_groups(&self, groups: Vec>) -> Result { - // Acquire a single connection to be reused - let provider: XmtpOpenMlsProvider = self.mls_provider()?; - + pub async fn sync_all_groups( + &self, + groups: Vec>, + provider: &XmtpOpenMlsProvider, + ) -> Result { let active_group_count = Arc::new(AtomicUsize::new(0)); let sync_futures = groups .into_iter() .map(|group| { - // create new provider ref that gets moved, leaving original - // provider alone. - let provider_ref = &provider; let active_group_count = Arc::clone(&active_group_count); async move { - let mls_group = group.load_mls_group(provider_ref)?; + let mls_group = group.load_mls_group(provider)?; tracing::info!( inbox_id = self.inbox_id(), "[{}] syncing group", @@ -857,9 +875,9 @@ where mls_group.epoch() ); if mls_group.is_active() { - group.maybe_update_installations(provider_ref, None).await?; + group.maybe_update_installations(provider, None).await?; - group.sync_with_conn(provider_ref).await?; + group.sync_with_conn(provider).await?; active_group_count.fetch_add(1, Ordering::SeqCst); } @@ -881,19 +899,22 @@ where /// Returns the total number of active groups synced. pub async fn sync_all_welcomes_and_groups( &self, - conn: &DbConnection, + provider: &XmtpOpenMlsProvider, consent_state: Option, ) -> Result { - self.sync_welcomes(conn).await?; - + self.sync_welcomes(provider).await?; let query_args = GroupQueryArgs { consent_state, include_sync_groups: true, ..GroupQueryArgs::default() }; - - let groups = self.find_groups(query_args)?; - let active_groups_count = self.sync_all_groups(groups).await?; + let groups = provider + .conn_ref() + .find_groups(query_args)? + .into_iter() + .map(|g| MlsGroup::new(self.clone(), g.id, g.created_at_ns)) + .collect(); + let active_groups_count = self.sync_all_groups(groups, provider).await?; Ok(active_groups_count) } @@ -1069,7 +1090,10 @@ pub(crate) mod tests { let init1 = kp1[0].inner.hpke_init_key(); // Rotate and fetch again. - client.rotate_key_package().await.unwrap(); + client + .rotate_key_package(&client.mls_provider().unwrap()) + .await + .unwrap(); let kp2 = client .get_key_packages_for_installation_ids(vec![client.installation_public_key().to_vec()]) @@ -1105,7 +1129,7 @@ pub(crate) mod tests { let client = ClientBuilder::new_test_client(&wallet).await; assert_eq!( client - .find_inbox_id_from_address(wallet.get_address()) + .find_inbox_id_from_address(&client.store().conn().unwrap(), wallet.get_address()) .await .unwrap(), Some(client.inbox_id().to_string()) @@ -1130,7 +1154,7 @@ pub(crate) mod tests { .unwrap(); let bob_received_groups = bob - .sync_welcomes(&bob.store().conn().unwrap()) + .sync_welcomes(&bob.mls_provider().unwrap()) .await .unwrap(); assert_eq!(bob_received_groups.len(), 1); @@ -1140,7 +1164,7 @@ pub(crate) mod tests { ); let duplicate_received_groups = bob - .sync_welcomes(&bob.store().conn().unwrap()) + .sync_welcomes(&bob.mls_provider().unwrap()) .await .unwrap(); assert_eq!(duplicate_received_groups.len(), 0); @@ -1170,7 +1194,7 @@ pub(crate) mod tests { .await .unwrap(); - let bob_received_groups = bo.sync_welcomes(&bo.store().conn().unwrap()).await.unwrap(); + let bob_received_groups = bo.sync_welcomes(&bo.mls_provider().unwrap()).await.unwrap(); assert_eq!(bob_received_groups.len(), 2); let bo_groups = bo.find_groups(GroupQueryArgs::default()).unwrap(); @@ -1189,7 +1213,9 @@ pub(crate) mod tests { .await .unwrap(); - bo.sync_all_groups(bo_groups).await.unwrap(); + bo.sync_all_groups(bo_groups, &bo.mls_provider().unwrap()) + .await + .unwrap(); let bo_messages1 = bo_group1.find_messages(&MsgQueryArgs::default()).unwrap(); assert_eq!(bo_messages1.len(), 1); @@ -1226,7 +1252,7 @@ pub(crate) mod tests { // Initial sync (None): Bob should fetch both groups let bob_received_groups = bo - .sync_all_welcomes_and_groups(&bo.store().conn().unwrap(), None) + .sync_all_welcomes_and_groups(&bo.mls_provider().unwrap(), None) .await .unwrap(); assert_eq!(bob_received_groups, 2); @@ -1261,7 +1287,7 @@ pub(crate) mod tests { // Sync with `Unknown`: Bob should not fetch new messages let bob_received_groups_unknown = bo - .sync_all_welcomes_and_groups(&bo.store().conn().unwrap(), Some(ConsentState::Allowed)) + .sync_all_welcomes_and_groups(&bo.mls_provider().unwrap(), Some(ConsentState::Allowed)) .await .unwrap(); assert_eq!(bob_received_groups_unknown, 0); @@ -1294,7 +1320,7 @@ pub(crate) mod tests { // Sync with `None`: Bob should fetch all messages let bob_received_groups_all = bo - .sync_all_welcomes_and_groups(&bo.store().conn().unwrap(), Some(ConsentState::Unknown)) + .sync_all_welcomes_and_groups(&bo.mls_provider().unwrap(), Some(ConsentState::Unknown)) .await .unwrap(); assert_eq!(bob_received_groups_all, 2); @@ -1355,7 +1381,7 @@ pub(crate) mod tests { assert_eq!(amal_group.members().await.unwrap().len(), 1); tracing::info!("Syncing bolas welcomes"); // See if Bola can see that they were added to the group - bola.sync_welcomes(&bola.store().conn().unwrap()) + bola.sync_welcomes(&bola.mls_provider().unwrap()) .await .unwrap(); let bola_groups = bola.find_groups(Default::default()).unwrap(); @@ -1376,7 +1402,7 @@ pub(crate) mod tests { .add_members_by_inbox_id(&[bola.inbox_id()]) .await .unwrap(); - bola.sync_welcomes(&bola.store().conn().unwrap()) + bola.sync_welcomes(&bola.mls_provider().unwrap()) .await .unwrap(); @@ -1426,12 +1452,13 @@ pub(crate) mod tests { async fn get_key_package_init_key< ApiClient: XmtpApi, Verifier: SmartContractSignatureVerifier, + Id: AsRef<[u8]>, >( client: &Client, - installation_id: &[u8], + installation_id: Id, ) -> Vec { let kps = client - .get_key_packages_for_installation_ids(vec![installation_id.to_vec()]) + .get_key_packages_for_installation_ids(vec![installation_id.as_ref().to_vec()]) .await .unwrap(); let kp = kps.first().unwrap(); @@ -1468,18 +1495,18 @@ pub(crate) mod tests { .await .unwrap(); - bo.sync_welcomes(&bo.store().conn().unwrap()).await.unwrap(); + bo.sync_welcomes(&bo.mls_provider().unwrap()).await.unwrap(); let bo_new_key = get_key_package_init_key(&bo, bo.installation_public_key()).await; // Bo's key should have changed assert_ne!(bo_original_init_key, bo_new_key); - bo.sync_welcomes(&bo.store().conn().unwrap()).await.unwrap(); + bo.sync_welcomes(&bo.mls_provider().unwrap()).await.unwrap(); let bo_new_key_2 = get_key_package_init_key(&bo, bo.installation_public_key()).await; // Bo's key should not have changed syncing the second time. assert_eq!(bo_new_key, bo_new_key_2); - alix.sync_welcomes(&alix.store().conn().unwrap()) + alix.sync_welcomes(&alix.mls_provider().unwrap()) .await .unwrap(); let alix_key_2 = get_key_package_init_key(&alix, alix.installation_public_key()).await; @@ -1493,7 +1520,7 @@ pub(crate) mod tests { ) .await .unwrap(); - bo.sync_welcomes(&bo.store().conn().unwrap()).await.unwrap(); + bo.sync_welcomes(&bo.mls_provider().unwrap()).await.unwrap(); // Bo should have two groups now let bo_groups = bo.find_groups(GroupQueryArgs::default()).unwrap(); diff --git a/xmtp_mls/src/groups/device_sync.rs b/xmtp_mls/src/groups/device_sync.rs index 89c86cb7a..c0af9ac84 100644 --- a/xmtp_mls/src/groups/device_sync.rs +++ b/xmtp_mls/src/groups/device_sync.rs @@ -401,7 +401,7 @@ where let content = DeviceSyncContent::Request(request.clone()); let content_bytes = serde_json::to_vec(&content)?; - let _message_id = sync_group.prepare_message(&content_bytes, provider.conn_ref(), { + let _message_id = sync_group.prepare_message(&content_bytes, provider, { let request = request.clone(); move |_time_ns| PlaintextEnvelope { content: Some(Content::V2(V2 { @@ -445,7 +445,6 @@ where provider: &XmtpOpenMlsProvider, contents: DeviceSyncReplyProto, ) -> Result<(), DeviceSyncError> { - let conn = provider.conn_ref(); // find the sync group let sync_group = self.get_sync_group(provider.conn_ref())?; @@ -457,7 +456,7 @@ where .await?; // add original sender to all groups on this device on the node - self.ensure_member_of_all_groups(conn, &msg.sender_inbox_id) + self.ensure_member_of_all_groups(provider, &msg.sender_inbox_id) .await?; // the reply message @@ -471,7 +470,7 @@ where (content_bytes, contents) }; - sync_group.prepare_message(&content_bytes, conn, |_time_ns| PlaintextEnvelope { + sync_group.prepare_message(&content_bytes, provider, |_time_ns| PlaintextEnvelope { content: Some(Content::V2(V2 { idempotency_key: new_request_id(), message_type: Some(MessageType::DeviceSyncReply(contents)), @@ -491,8 +490,10 @@ where let sync_group = self.get_sync_group(provider.conn_ref())?; sync_group.sync_with_conn(provider).await?; - let messages = sync_group - .find_messages(&MsgQueryArgs::default().kind(GroupMessageKind::Application))?; + let messages = provider.conn_ref().get_group_messages( + &sync_group.group_id, + &MsgQueryArgs::default().kind(GroupMessageKind::Application), + )?; for msg in messages.into_iter().rev() { let Ok(msg_content) = @@ -565,13 +566,14 @@ where self.insert_encrypted_syncables(provider, enc_payload, &enc_key.try_into()?) .await?; - self.sync_welcomes(provider.conn_ref()).await?; + self.sync_welcomes(provider).await?; let groups = conn.find_groups(GroupQueryArgs::default().conversation_type(ConversationType::Group))?; for crate::storage::group::StoredGroup { id, .. } in groups.into_iter() { - let group = self.group(id)?; - Box::pin(group.sync()).await?; + let group = self.group_with_conn(provider.conn_ref(), id)?; + group.maybe_update_installations(provider, None).await?; + Box::pin(group.sync_with_conn(provider)).await?; } Ok(()) @@ -579,14 +581,18 @@ where async fn ensure_member_of_all_groups( &self, - conn: &DbConnection, + provider: &XmtpOpenMlsProvider, inbox_id: &str, ) -> Result<(), GroupError> { + let conn = provider.conn_ref(); let groups = conn.find_groups(GroupQueryArgs::default().conversation_type(ConversationType::Group))?; for group in groups { - let group = self.group(group.id)?; - Box::pin(group.add_members_by_inbox_id(&[inbox_id.to_string()])).await?; + let group = self.group_with_conn(conn, group.id)?; + Box::pin( + group.add_members_by_inbox_id_with_provider(provider, &[inbox_id.to_string()]), + ) + .await?; } Ok(()) diff --git a/xmtp_mls/src/groups/device_sync/consent_sync.rs b/xmtp_mls/src/groups/device_sync/consent_sync.rs index a4053288b..4003322b8 100644 --- a/xmtp_mls/src/groups/device_sync/consent_sync.rs +++ b/xmtp_mls/src/groups/device_sync/consent_sync.rs @@ -19,15 +19,13 @@ where "Streaming consent update. {:?}", record ); - let conn = provider.conn_ref(); let sync_group = self.ensure_sync_group(provider).await?; - let update_proto: UserPreferenceUpdateProto = UserPreferenceUpdate::ConsentUpdate(record) .try_into() .map_err(|e| DeviceSyncError::Bincode(format!("{e:?}")))?; let content_bytes = serde_json::to_vec(&update_proto)?; - sync_group.prepare_message(&content_bytes, conn, |_time_ns| PlaintextEnvelope { + sync_group.prepare_message(&content_bytes, provider, |_time_ns| PlaintextEnvelope { content: Some(Content::V2(V2 { idempotency_key: new_request_id(), message_type: Some(MessageType::UserPreferenceUpdate(update_proto)), @@ -111,7 +109,7 @@ pub(crate) mod tests { let old_group_id = amal_a.get_sync_group(amal_a_conn).unwrap().group_id; tracing::info!("Old Group Id: {}", hex::encode(&old_group_id)); // Check for new welcomes to new groups in the first installation (should be welcomed to a new sync group from amal_b). - amal_a.sync_welcomes(amal_a_conn).await.unwrap(); + amal_a.sync_welcomes(&amal_a_provider).await.unwrap(); let new_group_id = amal_a.get_sync_group(amal_a_conn).unwrap().group_id; tracing::info!("New Group Id: {}", hex::encode(&new_group_id)); // group id should have changed to the new sync group created by the second installation diff --git a/xmtp_mls/src/groups/device_sync/message_sync.rs b/xmtp_mls/src/groups/device_sync/message_sync.rs index 056f0c3e7..66e63e261 100644 --- a/xmtp_mls/src/groups/device_sync/message_sync.rs +++ b/xmtp_mls/src/groups/device_sync/message_sync.rs @@ -106,7 +106,7 @@ pub(crate) mod tests { let old_group_id = amal_a.get_sync_group(amal_a_conn).unwrap().group_id; // Check for new welcomes to new groups in the first installation (should be welcomed to a new sync group from amal_b). amal_a - .sync_welcomes(amal_a_conn) + .sync_welcomes(&amal_a_provider) .await .expect("sync_welcomes"); let new_group_id = amal_a.get_sync_group(amal_a_conn).unwrap().group_id; @@ -205,7 +205,7 @@ pub(crate) mod tests { // Check for new welcomes to new groups in the first installation (should be welcomed to a new sync group from amal_b). amal_a - .sync_welcomes(amal_a_conn) + .sync_welcomes(&amal_a_provider) .await .expect("sync_welcomes"); let new_group_id = amal_a.get_sync_group(amal_a_conn).unwrap().group_id; @@ -262,7 +262,7 @@ pub(crate) mod tests { async fn test_externals_cant_join_sync_group() { let wallet = generate_local_wallet(); let amal = ClientBuilder::new_test_client_with_history(&wallet, HISTORY_SYNC_URL).await; - amal.sync_welcomes(&amal.store().conn().unwrap()) + amal.sync_welcomes(&amal.mls_provider().unwrap()) .await .expect("sync welcomes"); @@ -271,7 +271,7 @@ pub(crate) mod tests { ClientBuilder::new_test_client_with_history(&bo_wallet, HISTORY_SYNC_URL).await; bo_client - .sync_welcomes(&bo_client.store().conn().unwrap()) + .sync_welcomes(&bo_client.mls_provider().unwrap()) .await .expect("sync welcomes"); diff --git a/xmtp_mls/src/groups/intents.rs b/xmtp_mls/src/groups/intents.rs index 80f5eb167..b3faf11b4 100644 --- a/xmtp_mls/src/groups/intents.rs +++ b/xmtp_mls/src/groups/intents.rs @@ -33,6 +33,7 @@ use crate::{ }, types::Address, verified_key_package_v2::{KeyPackageVerificationError, VerifiedKeyPackageV2}, + XmtpOpenMlsProvider, }; use super::{ @@ -58,16 +59,17 @@ pub enum IntentError { impl MlsGroup { pub fn queue_intent( &self, + provider: &XmtpOpenMlsProvider, intent_kind: IntentKind, intent_data: Vec, ) -> Result { - self.context().store().transaction(|provider| { + self.context().store().transaction(provider, |provider| { let conn = provider.conn_ref(); self.queue_intent_with_conn(conn, intent_kind, intent_data) }) } - pub fn queue_intent_with_conn( + fn queue_intent_with_conn( &self, conn: &DbConnection, intent_kind: IntentKind, @@ -800,7 +802,7 @@ pub(crate) mod tests { // Client B sends a message to Client A let groups_b = client_b - .sync_welcomes(&client_b.store().conn().unwrap()) + .sync_welcomes(&client_b.mls_provider().unwrap()) .await .unwrap(); assert_eq!(groups_b.len(), 1); diff --git a/xmtp_mls/src/groups/mls_sync.rs b/xmtp_mls/src/groups/mls_sync.rs index 60a609cd4..91426988d 100644 --- a/xmtp_mls/src/groups/mls_sync.rs +++ b/xmtp_mls/src/groups/mls_sync.rs @@ -170,6 +170,7 @@ where let mls_provider = XmtpOpenMlsProvider::from(conn); tracing::info!( inbox_id = self.client.inbox_id(), + installation_id = %self.client.installation_id(), group_id = hex::encode(&self.group_id), current_epoch = self.load_mls_group(&mls_provider)?.epoch().as_u64(), "[{}] syncing group", @@ -177,6 +178,7 @@ where ); tracing::info!( inbox_id = self.client.inbox_id(), + installation_id = %self.client.installation_id(), group_id = hex::encode(&self.group_id), current_epoch = self.load_mls_group(&mls_provider)?.epoch().as_u64(), "current epoch for [{}] in sync() is Epoch: [{}]", @@ -363,7 +365,7 @@ where let group_epoch = openmls_group.epoch(); debug!( inbox_id = self.client.inbox_id(), - installation_id = hex::encode(self.client.installation_id()), + installation_id = %self.client.installation_id(), group_id = hex::encode(&self.group_id), current_epoch = openmls_group.epoch().as_u64(), msg_id, @@ -392,6 +394,7 @@ where if published_in_epoch_u64 != group_epoch_u64 { tracing::warn!( inbox_id = self.client.inbox_id(), + installation_id = %self.client.installation_id(), group_id = hex::encode(&self.group_id), current_epoch = openmls_group.epoch().as_u64(), msg_id, @@ -491,6 +494,7 @@ where tracing::info!( inbox_id = self.client.inbox_id(), + installation_id = %self.client.installation_id(), sender_inbox_id = sender_inbox_id, sender_installation_id = hex::encode(&sender_installation_id), group_id = hex::encode(&self.group_id), @@ -512,6 +516,8 @@ where tracing::info!( inbox_id = self.client.inbox_id(), sender_inbox_id = sender_inbox_id, + sender_installation_id = hex::encode(&sender_installation_id), + installation_id = %self.client.installation_id(), group_id = hex::encode(&self.group_id), current_epoch = openmls_group.epoch().as_u64(), msg_epoch, @@ -636,6 +642,7 @@ where tracing::info!( inbox_id = self.client.inbox_id(), sender_inbox_id = sender_inbox_id, + installation_id = %self.client.installation_id(), sender_installation_id = hex::encode(&sender_installation_id), group_id = hex::encode(&self.group_id), current_epoch = openmls_group.epoch().as_u64(), @@ -659,6 +666,7 @@ where tracing::info!( inbox_id = self.client.inbox_id(), sender_inbox_id = sender_inbox_id, + installation_id = %self.client.installation_id(), sender_installation_id = hex::encode(&sender_installation_id), group_id = hex::encode(&self.group_id), current_epoch = openmls_group.epoch().as_u64(), @@ -705,6 +713,7 @@ where .find_group_intent_by_payload_hash(sha256(envelope.data.as_slice())); tracing::info!( inbox_id = self.client.inbox_id(), + installation_id = %self.client.installation_id(), group_id = hex::encode(&self.group_id), current_epoch = openmls_group.epoch().as_u64(), msg_id = envelope.id, @@ -718,6 +727,7 @@ where let intent_id = intent.id; tracing::info!( inbox_id = self.client.inbox_id(), + installation_id = %self.client.installation_id(), group_id = hex::encode(&self.group_id), current_epoch = openmls_group.epoch().as_u64(), msg_id = envelope.id, @@ -752,6 +762,7 @@ where Ok(None) => { tracing::info!( inbox_id = self.client.inbox_id(), + installation_id = %self.client.installation_id(), group_id = hex::encode(&self.group_id), current_epoch = openmls_group.epoch().as_u64(), msg_id = envelope.id, @@ -769,9 +780,9 @@ where #[tracing::instrument(level = "trace", skip_all)] async fn consume_message( &self, + provider: &XmtpOpenMlsProvider, envelope: &GroupMessage, openmls_group: &mut OpenMlsGroup, - conn: &DbConnection, ) -> Result<(), GroupMessageProcessingError> { let msgv1 = match &envelope.version { Some(GroupMessageVersion::V1(value)) => value, @@ -784,12 +795,15 @@ where _ => EntityKind::Group, }; - let last_cursor = conn.get_last_cursor_for_id(&self.group_id, message_entity_kind)?; + let last_cursor = provider + .conn_ref() + .get_last_cursor_for_id(&self.group_id, message_entity_kind)?; tracing::info!("### last cursor --> [{:?}]", last_cursor); let should_skip_message = last_cursor > msgv1.id as i64; if should_skip_message { tracing::info!( inbox_id = "self.inbox_id()", + installation_id = %self.client.installation_id(), group_id = hex::encode(&self.group_id), "Message already processed: skipped msgId:[{}] entity kind:[{:?}] last cursor in db: [{}]", msgv1.id, @@ -798,19 +812,36 @@ where ); Err(GroupMessageProcessingError::AlreadyProcessed(msgv1.id)) } else { - self.client - .intents() - .process_for_id( - &msgv1.group_id, - EntityKind::Group, - msgv1.id, - |provider| async move { - self.process_message(openmls_group, &provider, msgv1, true) - .await?; - Ok::<(), GroupMessageProcessingError>(()) - }, - ) - .await?; + let cursor = &msgv1.id; + // Download all unread welcome messages and convert to groups. + // In a database transaction, increment the cursor for a given entity and + // apply the update after the provided `ProcessingFn` has completed successfully. + self.client.store().transaction_async(provider, |provider| async move { + let is_updated = + provider + .conn_ref() + .update_cursor(&msgv1.group_id, EntityKind::Group, *cursor as i64)?; + if !is_updated { + return Err(ProcessIntentError::AlreadyProcessed(*cursor).into()); + } + self.process_message(openmls_group, provider, msgv1, true).await?; + Ok::<_, GroupMessageProcessingError>(()) + }).await + .inspect(|_| { + tracing::info!( + "Transaction completed successfully: process for group [{}] envelope cursor[{}]", + hex::encode(&msgv1.group_id), + cursor + ); + }) + .inspect_err(|err| { + tracing::info!( + "Transaction failed: process for group [{}] envelope cursor [{}] error:[{}]", + hex::encode(&msgv1.group_id), + cursor, + err + ); + })?; Ok(()) } } @@ -828,7 +859,7 @@ where let result = retry_async!( Retry::default(), (async { - self.consume_message(&message, &mut openmls_group, provider.conn_ref()) + self.consume_message(provider, &message, &mut openmls_group) .await }) ); @@ -852,7 +883,13 @@ where if receive_errors.is_empty() { Ok(()) } else { - tracing::error!("Message processing errors: {:?}", receive_errors); + tracing::error!( + group_id = hex::encode(&self.group_id), + inbox_id = self.client.inbox_id(), + installation_id = hex::encode(self.client.installation_id()), + "Message processing errors: {:?}", + receive_errors + ); Err(GroupError::ReceiveErrors(receive_errors)) } } @@ -944,6 +981,7 @@ where intent.id, intent.kind = %intent.kind, inbox_id = self.client.inbox_id(), + installation_id = %self.client.installation_id(), group_id = hex::encode(&self.group_id), "intent {} has reached max publish attempts", intent.id); // TODO: Eventually clean up errored attempts @@ -974,7 +1012,7 @@ where )?; tracing::debug!( inbox_id = self.client.inbox_id(), - installation_id = hex::encode(self.client.installation_id()), + installation_id = %self.client.installation_id(), intent.id, intent.kind = %intent.kind, group_id = hex::encode(&self.group_id), @@ -992,7 +1030,7 @@ where intent.id, intent.kind = %intent.kind, inbox_id = self.client.inbox_id(), - installation_id = hex::encode(self.client.installation_id()), + installation_id = %self.client.installation_id(), group_id = hex::encode(&self.group_id), "[{}] published intent [{}] of type [{}]", self.client.inbox_id(), @@ -1007,7 +1045,7 @@ where Ok(None) => { tracing::info!( inbox_id = self.client.inbox_id(), - installation_id = hex::encode(self.client.installation_id()), + installation_id = %self.client.installation_id(), "Skipping intent because no publish data returned" ); let deleter: &dyn Delete = provider.conn_ref(); @@ -1148,7 +1186,7 @@ where if let Some(post_commit_data) = intent.post_commit_data { tracing::debug!( inbox_id = self.client.inbox_id(), - installation_id = hex::encode(self.client.installation_id()), + installation_id = %self.client.installation_id(), intent.id, intent.kind = %intent.kind, "taking post commit action" ); @@ -1216,13 +1254,13 @@ where debug!( inbox_id = self.client.inbox_id(), - installation_id = hex::encode(self.client.installation_id()), + installation_id = %self.client.installation_id(), "Adding missing installations {:?}", intent_data ); - let intent = self.queue_intent_with_conn( - provider.conn_ref(), + let intent = self.queue_intent( + provider, IntentKind::UpdateGroupMembership, intent_data.into(), )?; diff --git a/xmtp_mls/src/groups/mod.rs b/xmtp_mls/src/groups/mod.rs index ad9d92d34..f23bb8b6b 100644 --- a/xmtp_mls/src/groups/mod.rs +++ b/xmtp_mls/src/groups/mod.rs @@ -91,7 +91,7 @@ use crate::{ group::{ConversationType, GroupMembershipState, StoredGroup}, group_intent::IntentKind, group_message::{DeliveryStatus, GroupMessageKind, MsgQueryArgs, StoredGroupMessage}, - sql_key_store, + sql_key_store, StorageError, }, subscriptions::{LocalEventError, LocalEvents}, utils::{id::calculate_message_id, time::now_ns}, @@ -319,8 +319,8 @@ impl MlsGroup { /// Instantiate a new [`XmtpOpenMlsProvider`] pulling a connection from the database. /// prefer to use an already-instantiated mls provider if possible. - pub fn mls_provider(&self) -> Result { - Ok(self.context().mls_provider()?) + pub fn mls_provider(&self) -> Result { + self.context().mls_provider() } // Load the stored OpenMLS group from the OpenMLS provider's keystore @@ -338,15 +338,14 @@ impl MlsGroup { } // Create a new group and save it to the DB - pub fn create_and_insert( + pub(crate) fn create_and_insert( client: Arc, + provider: &XmtpOpenMlsProvider, membership_state: GroupMembershipState, permissions_policy_set: PolicySet, opts: GroupMetadataOptions, ) -> Result { let context = client.context(); - let conn = context.store().conn()?; - let provider = XmtpOpenMlsProvider::new(conn); let creator_inbox_id = context.inbox_id(); let protected_metadata = build_protected_metadata_extension(creator_inbox_id, ConversationType::Group)?; @@ -361,7 +360,7 @@ impl MlsGroup { )?; let mls_group = OpenMlsGroup::new( - &provider, + provider, &context.identity.installation_keys, &group_config, CredentialWithKey { @@ -388,14 +387,13 @@ impl MlsGroup { } // Create a new DM and save it to the DB - pub fn create_dm_and_insert( + pub(crate) fn create_dm_and_insert( + provider: &XmtpOpenMlsProvider, client: Arc, membership_state: GroupMembershipState, dm_target_inbox_id: InboxId, ) -> Result { let context = client.context(); - let conn = context.store().conn()?; - let provider = XmtpOpenMlsProvider::new(conn); let protected_metadata = build_dm_protected_metadata_extension(context.inbox_id(), dm_target_inbox_id.clone())?; let mutable_metadata = @@ -608,9 +606,8 @@ impl MlsGroup { self.maybe_update_installations(provider, update_interval_ns) .await?; - let message_id = self.prepare_message(message, provider.conn_ref(), |now| { - Self::into_envelope(message, now) - }); + let message_id = + self.prepare_message(message, provider, |now| Self::into_envelope(message, now)); self.sync_until_last_intent_resolved(provider).await?; @@ -648,9 +645,9 @@ impl MlsGroup { /// Send a message, optimistically returning the ID of the message before the result of a message publish. pub fn send_message_optimistic(&self, message: &[u8]) -> Result, GroupError> { - let conn = self.context().store().conn()?; + let provider = self.mls_provider()?; let message_id = - self.prepare_message(message, &conn, |now| Self::into_envelope(message, now))?; + self.prepare_message(message, &provider, |now| Self::into_envelope(message, now))?; Ok(message_id) } @@ -664,7 +661,7 @@ impl MlsGroup { fn prepare_message( &self, message: &[u8], - conn: &DbConnection, + provider: &XmtpOpenMlsProvider, envelope: F, ) -> Result, GroupError> where @@ -678,7 +675,7 @@ impl MlsGroup { .map_err(GroupError::EncodeError)?; let intent_data: Vec = SendMessageIntentData::new(encoded_envelope).into(); - self.queue_intent_with_conn(conn, IntentKind::SendMessage, intent_data)?; + self.queue_intent(provider, IntentKind::SendMessage, intent_data)?; // store this unpublished message locally before sending let message_id = calculate_message_id(&self.group_id, message, &now.to_string()); @@ -692,7 +689,7 @@ impl MlsGroup { sender_inbox_id: self.context().inbox_id().to_string(), delivery_status: DeliveryStatus::Unpublished, }; - group_message.store(conn)?; + group_message.store(provider.conn_ref())?; Ok(message_id) } @@ -730,8 +727,9 @@ impl MlsGroup { .api() .get_inbox_ids(account_addresses.clone()) .await?; + let provider = self.mls_provider()?; // get current number of users in group - let member_count = self.members().await?.len(); + let member_count = self.members_with_provider(&provider).await?.len(); if member_count + inbox_id_map.len() > MAX_GROUP_SIZE { return Err(GroupError::UserLimitExceeded); } @@ -745,8 +743,11 @@ impl MlsGroup { )); } - self.add_members_by_inbox_id(&inbox_id_map.into_values().collect::>()) - .await + self.add_members_by_inbox_id_with_provider( + &provider, + &inbox_id_map.into_values().collect::>(), + ) + .await } #[tracing::instrument(level = "trace", skip_all)] @@ -755,9 +756,19 @@ impl MlsGroup { inbox_ids: &[S], ) -> Result<(), GroupError> { let provider = self.client.mls_provider()?; + self.add_members_by_inbox_id_with_provider(&provider, inbox_ids) + .await + } + + #[tracing::instrument(level = "trace", skip_all)] + pub async fn add_members_by_inbox_id_with_provider>( + &self, + provider: &XmtpOpenMlsProvider, + inbox_ids: &[S], + ) -> Result<(), GroupError> { let ids = inbox_ids.iter().map(AsRef::as_ref).collect::>(); let intent_data = self - .get_membership_update_intent(&provider, ids.as_slice(), &[]) + .get_membership_update_intent(provider, ids.as_slice(), &[]) .await?; // TODO:nm this isn't the best test for whether the request is valid @@ -768,13 +779,13 @@ impl MlsGroup { return Ok(()); } - let intent = self.queue_intent_with_conn( - provider.conn_ref(), + let intent = self.queue_intent( + provider, IntentKind::UpdateGroupMembership, intent_data.into(), )?; - self.sync_until_intent_resolved(&provider, intent.id).await + self.sync_until_intent_resolved(provider, intent.id).await } /// Removes members from the group by their account addresses. @@ -817,8 +828,8 @@ impl MlsGroup { .get_membership_update_intent(&provider, &[], inbox_ids) .await?; - let intent = self.queue_intent_with_conn( - provider.conn_ref(), + let intent = self.queue_intent( + &provider, IntentKind::UpdateGroupMembership, intent_data.into(), )?; @@ -835,7 +846,7 @@ impl MlsGroup { } let intent_data: Vec = UpdateMetadataIntentData::new_update_group_name(group_name).into(); - let intent = self.queue_intent(IntentKind::MetadataUpdate, intent_data)?; + let intent = self.queue_intent(&provider, IntentKind::MetadataUpdate, intent_data)?; self.sync_until_intent_resolved(&provider, intent.id).await } @@ -864,7 +875,7 @@ impl MlsGroup { ) .into(); - let intent = self.queue_intent(IntentKind::UpdatePermission, intent_data)?; + let intent = self.queue_intent(&provider, IntentKind::UpdatePermission, intent_data)?; self.sync_until_intent_resolved(&provider, intent.id).await } @@ -894,7 +905,7 @@ impl MlsGroup { } let intent_data: Vec = UpdateMetadataIntentData::new_update_group_description(group_description).into(); - let intent = self.queue_intent(IntentKind::MetadataUpdate, intent_data)?; + let intent = self.queue_intent(&provider, IntentKind::MetadataUpdate, intent_data)?; self.sync_until_intent_resolved(&provider, intent.id).await } @@ -924,7 +935,7 @@ impl MlsGroup { let intent_data: Vec = UpdateMetadataIntentData::new_update_group_image_url_square(group_image_url_square) .into(); - let intent = self.queue_intent(IntentKind::MetadataUpdate, intent_data)?; + let intent = self.queue_intent(&provider, IntentKind::MetadataUpdate, intent_data)?; self.sync_until_intent_resolved(&provider, intent.id).await } @@ -956,7 +967,7 @@ impl MlsGroup { } let intent_data: Vec = UpdateMetadataIntentData::new_update_group_pinned_frame_url(pinned_frame_url).into(); - let intent = self.queue_intent(IntentKind::MetadataUpdate, intent_data)?; + let intent = self.queue_intent(&provider, IntentKind::MetadataUpdate, intent_data)?; self.sync_until_intent_resolved(&provider, intent.id).await } @@ -1039,7 +1050,7 @@ impl MlsGroup { }; let intent_data: Vec = UpdateAdminListIntentData::new(intent_action_type, inbox_id).into(); - let intent = self.queue_intent(IntentKind::UpdateAdminList, intent_data)?; + let intent = self.queue_intent(&provider, IntentKind::UpdateAdminList, intent_data)?; self.sync_until_intent_resolved(&provider, intent.id).await } @@ -1104,9 +1115,9 @@ impl MlsGroup { /// Update this installation's leaf key in the group by creating a key update commit pub async fn key_update(&self) -> Result<(), GroupError> { - let intent = self.queue_intent(IntentKind::KeyUpdate, vec![])?; - self.sync_until_intent_resolved(&self.client.mls_provider()?, intent.id) - .await + let provider = self.client.mls_provider()?; + let intent = self.queue_intent(&provider, IntentKind::KeyUpdate, vec![])?; + self.sync_until_intent_resolved(&provider, intent.id).await } /// Checks if the the current user is active in the group. @@ -1134,8 +1145,7 @@ impl MlsGroup { } pub fn permissions(&self) -> Result { - let conn = self.context().store().conn()?; - let provider = XmtpOpenMlsProvider::new(conn); + let provider = self.mls_provider()?; let mls_group = self.load_mls_group(&provider)?; Ok(extract_group_permissions(&mls_group)?) @@ -1627,7 +1637,7 @@ pub(crate) mod tests { async fn receive_group_invite(client: &FullXmtpClient) -> MlsGroup { client - .sync_welcomes(&client.store().conn().unwrap()) + .sync_welcomes(&client.mls_provider().unwrap()) .await .unwrap(); let mut groups = client.find_groups(GroupQueryArgs::default()).unwrap(); @@ -1775,7 +1785,7 @@ pub(crate) mod tests { // Get bola's version of the same group let bola_groups = bola - .sync_welcomes(&bola.store().conn().unwrap()) + .sync_welcomes(&bola.mls_provider().unwrap()) .await .unwrap(); let bola_group = bola_groups.first().unwrap(); @@ -1834,7 +1844,7 @@ pub(crate) mod tests { // Get bola's version of the same group let bola_groups = bola - .sync_welcomes(&bola.store().conn().unwrap()) + .sync_welcomes(&bola.mls_provider().unwrap()) .await .unwrap(); let bola_group = bola_groups.first().unwrap(); @@ -1944,7 +1954,7 @@ pub(crate) mod tests { force_add_member(&alix, &bo, &alix_group, &mut mls_group, &provider).await; // Bo should not be able to actually read this group - bo.sync_welcomes(&bo.store().conn().unwrap()).await.unwrap(); + bo.sync_welcomes(&bo.mls_provider().unwrap()).await.unwrap(); let groups = bo.find_groups(GroupQueryArgs::default()).unwrap(); assert_eq!(groups.len(), 0); assert_logged!("failed to create group from welcome", 1); @@ -2072,7 +2082,7 @@ pub(crate) mod tests { group.send_message(b"hello").await.expect("send message"); bola_client - .sync_welcomes(&bola_client.store().conn().unwrap()) + .sync_welcomes(&bola_client.mls_provider().unwrap()) .await .unwrap(); let bola_groups = bola_client.find_groups(GroupQueryArgs::default()).unwrap(); @@ -2429,7 +2439,7 @@ pub(crate) mod tests { .add_members_by_inbox_id(&[bola.inbox_id()]) .await .unwrap(); - bola.sync_welcomes(&bola.store().conn().unwrap()) + bola.sync_welcomes(&bola.mls_provider().unwrap()) .await .unwrap(); @@ -2600,7 +2610,7 @@ pub(crate) mod tests { .add_members(&[bola_wallet.get_address()]) .await .unwrap(); - bola.sync_welcomes(&bola.store().conn().unwrap()) + bola.sync_welcomes(&bola.mls_provider().unwrap()) .await .unwrap(); let bola_groups = bola.find_groups(GroupQueryArgs::default()).unwrap(); @@ -2682,7 +2692,7 @@ pub(crate) mod tests { .add_members(&[bola_wallet.get_address()]) .await .unwrap(); - bola.sync_welcomes(&bola.store().conn().unwrap()) + bola.sync_welcomes(&bola.mls_provider().unwrap()) .await .unwrap(); let bola_groups = bola.find_groups(GroupQueryArgs::default()).unwrap(); @@ -2700,7 +2710,7 @@ pub(crate) mod tests { assert!(super_admin_list.contains(&amal.inbox_id().to_string())); // Verify that bola can not add caro because they are not an admin - bola.sync_welcomes(&bola.store().conn().unwrap()) + bola.sync_welcomes(&bola.mls_provider().unwrap()) .await .unwrap(); let bola_groups = bola.find_groups(GroupQueryArgs::default()).unwrap(); @@ -2769,7 +2779,7 @@ pub(crate) mod tests { .contains(&bola.inbox_id().to_string())); // Verify that bola can not add charlie because they are not an admin - bola.sync_welcomes(&bola.store().conn().unwrap()) + bola.sync_welcomes(&bola.mls_provider().unwrap()) .await .unwrap(); let bola_groups = bola.find_groups(GroupQueryArgs::default()).unwrap(); @@ -2800,7 +2810,7 @@ pub(crate) mod tests { .add_members_by_inbox_id(&[bola.inbox_id()]) .await .unwrap(); - bola.sync_welcomes(&bola.store().conn().unwrap()) + bola.sync_welcomes(&bola.mls_provider().unwrap()) .await .unwrap(); let bola_groups = bola.find_groups(GroupQueryArgs::default()).unwrap(); @@ -2818,7 +2828,7 @@ pub(crate) mod tests { assert!(super_admin_list.contains(&amal.inbox_id().to_string())); // Verify that bola can not add caro as an admin because they are not a super admin - bola.sync_welcomes(&bola.store().conn().unwrap()) + bola.sync_welcomes(&bola.mls_provider().unwrap()) .await .unwrap(); let bola_groups = bola.find_groups(GroupQueryArgs::default()).unwrap(); @@ -3012,7 +3022,7 @@ pub(crate) mod tests { // Bola syncs groups - this will decrypt the Welcome, identify who added Bola // and then store that value on the group and insert into the database let bola_groups = bola - .sync_welcomes(&bola.store().conn().unwrap()) + .sync_welcomes(&bola.mls_provider().unwrap()) .await .unwrap(); @@ -3081,7 +3091,7 @@ pub(crate) mod tests { .unwrap(); // Step 3: Verify that Bola can update the group name, and amal sees the update - bola.sync_welcomes(&bola.store().conn().unwrap()) + bola.sync_welcomes(&bola.mls_provider().unwrap()) .await .unwrap(); let bola_groups = bola.find_groups(GroupQueryArgs::default()).unwrap(); @@ -3149,7 +3159,7 @@ pub(crate) mod tests { // Step 3: Bola attemps to add Caro, but fails because group is admin only let caro = ClientBuilder::new_test_client(&generate_local_wallet()).await; - bola.sync_welcomes(&bola.store().conn().unwrap()) + bola.sync_welcomes(&bola.mls_provider().unwrap()) .await .unwrap(); let bola_groups = bola.find_groups(GroupQueryArgs::default()).unwrap(); @@ -3294,7 +3304,7 @@ pub(crate) mod tests { // Amal creates a dm group targetting bola let amal_dm = amal - .create_dm_by_inbox_id(bola.inbox_id().to_string()) + .create_dm_by_inbox_id(&amal.mls_provider().unwrap(), bola.inbox_id().to_string()) .await .unwrap(); @@ -3312,7 +3322,7 @@ pub(crate) mod tests { assert_eq!(members.len(), 2); // Bola can message amal - let _ = bola.sync_welcomes(&bola.store().conn().unwrap()).await; + let _ = bola.sync_welcomes(&bola.mls_provider().unwrap()).await; let bola_groups = bola.find_groups(GroupQueryArgs::default()).unwrap(); let bola_dm: &MlsGroup<_> = bola_groups.first().unwrap(); @@ -3428,7 +3438,7 @@ pub(crate) mod tests { let alix_message = vec![1]; alix_group.send_message(&alix_message).await.unwrap(); bo_client - .sync_welcomes(&bo_client.store().conn().unwrap()) + .sync_welcomes(&bo_client.mls_provider().unwrap()) .await .unwrap(); let bo_groups = bo_client.find_groups(GroupQueryArgs::default()).unwrap(); @@ -3552,7 +3562,11 @@ pub(crate) mod tests { } group - .queue_intent(IntentKind::UpdateGroupMembership, intent_data.into()) + .queue_intent( + provider, + IntentKind::UpdateGroupMembership, + intent_data.into(), + ) .unwrap(); } @@ -3711,7 +3725,7 @@ pub(crate) mod tests { .await .unwrap(); - bola.sync_welcomes(&bola.store().conn().unwrap()) + bola.sync_welcomes(&bola.mls_provider().unwrap()) .await .unwrap(); let bola_groups = bola.find_groups(GroupQueryArgs::default()).unwrap(); @@ -3732,7 +3746,7 @@ pub(crate) mod tests { .await .unwrap(); - caro.sync_welcomes(&caro.store().conn().unwrap()) + caro.sync_welcomes(&caro.mls_provider().unwrap()) .await .unwrap(); let caro_groups = caro.find_groups(GroupQueryArgs::default()).unwrap(); @@ -3765,7 +3779,7 @@ pub(crate) mod tests { .await .unwrap(); - bo.sync_welcomes(&bo.store().conn().unwrap()).await.unwrap(); + bo.sync_welcomes(&bo.mls_provider().unwrap()).await.unwrap(); let bo_groups = bo.find_groups(GroupQueryArgs::default()).unwrap(); let bo_group = bo_groups.first().unwrap(); diff --git a/xmtp_mls/src/groups/scoped_client.rs b/xmtp_mls/src/groups/scoped_client.rs index 808fee0e0..557761f6e 100644 --- a/xmtp_mls/src/groups/scoped_client.rs +++ b/xmtp_mls/src/groups/scoped_client.rs @@ -3,9 +3,9 @@ use crate::{ api::ApiClientWrapper, client::{ClientError, XmtpMlsLocalContext}, identity_updates::{InstallationDiff, InstallationDiffError}, - intents::Intents, - storage::{DbConnection, EncryptedMessageStore}, + storage::{DbConnection, EncryptedMessageStore, StorageError}, subscriptions::LocalEvents, + types::InstallationId, verified_key_package_v2::VerifiedKeyPackageV2, xmtp_openmls_provider::XmtpOpenMlsProvider, Client, @@ -37,16 +37,14 @@ pub trait LocalScopedGroupClient: Send + Sync + Sized { self.context_ref().inbox_id() } - fn installation_id(&self) -> &[u8] { + fn installation_id(&self) -> InstallationId { self.context_ref().installation_public_key() } - fn mls_provider(&self) -> Result { + fn mls_provider(&self) -> Result { self.context_ref().mls_provider() } - fn intents(&self) -> &Arc; - fn context_ref(&self) -> &Arc; fn context(&self) -> Arc { @@ -105,16 +103,14 @@ pub trait ScopedGroupClient: Sized { self.context_ref().inbox_id() } - fn installation_id(&self) -> &[u8] { + fn installation_id(&self) -> InstallationId { self.context_ref().installation_public_key() } - fn mls_provider(&self) -> Result { + fn mls_provider(&self) -> Result { self.context_ref().mls_provider() } - fn intents(&self) -> &Arc; - fn context_ref(&self) -> &Arc; fn context(&self) -> Arc { @@ -173,10 +169,6 @@ where Client::::context(self) } - fn intents(&self) -> &Arc { - crate::Client::::intents(self) - } - fn history_sync_url(&self) -> &Option { &self.history_sync_url } @@ -264,10 +256,6 @@ where (**self).store() } - fn intents(&self) -> &Arc { - (**self).intents() - } - fn inbox_id(&self) -> InboxIdRef<'_> { (**self).inbox_id() } @@ -276,7 +264,7 @@ where (**self).context_ref() } - fn mls_provider(&self) -> Result { + fn mls_provider(&self) -> Result { (**self).mls_provider() } @@ -358,10 +346,6 @@ where (**self).history_sync_url() } - fn intents(&self) -> &Arc { - (**self).intents() - } - fn inbox_id(&self) -> InboxIdRef<'_> { (**self).inbox_id() } @@ -370,7 +354,7 @@ where (**self).context_ref() } - fn mls_provider(&self) -> Result { + fn mls_provider(&self) -> Result { (**self).mls_provider() } diff --git a/xmtp_mls/src/groups/subscriptions.rs b/xmtp_mls/src/groups/subscriptions.rs index 2caf9bec9..fbaaa20d6 100644 --- a/xmtp_mls/src/groups/subscriptions.rs +++ b/xmtp_mls/src/groups/subscriptions.rs @@ -14,6 +14,7 @@ use crate::storage::refresh_state::EntityKind; use crate::storage::StorageError; use crate::subscriptions::MessagesStreamInfo; use crate::subscriptions::SubscribeError; +use crate::XmtpOpenMlsProvider; use crate::{retry::Retry, retry_async}; use prost::Message; use xmtp_proto::xmtp::mls::api::v1::GroupMessage; @@ -22,6 +23,7 @@ impl MlsGroup { /// Internal stream processing function pub(crate) async fn process_stream_entry( &self, + provider: &XmtpOpenMlsProvider, envelope: GroupMessage, ) -> Result { let msgv1 = extract_message_v1(envelope)?; @@ -45,8 +47,8 @@ impl MlsGroup { let msgv1 = &msgv1; self.context() .store() - .transaction_async(|provider| async move { - let mut openmls_group = self.load_mls_group(&provider)?; + .transaction_async(provider, |provider| async move { + let mut openmls_group = self.load_mls_group(provider)?; // Attempt processing immediately, but fail if the message is not an Application Message // Returning an error should roll back the DB tx @@ -60,7 +62,7 @@ impl MlsGroup { openmls_group.epoch() ); - self.process_message(&mut openmls_group, &provider, msgv1, false) + self.process_message(&mut openmls_group, provider, msgv1, false) .await // NOTE: We want to make sure we retry an error in process_message .map_err(SubscribeError::ReceiveGroup) @@ -78,7 +80,7 @@ impl MlsGroup { ); // Swallow errors here, since another process may have successfully saved the message // to the DB - if let Err(err) = self.sync_with_conn(&self.client.mls_provider()?).await { + if let Err(err) = self.sync_with_conn(provider).await { tracing::warn!( inbox_id = self.client.inbox_id(), group_id = hex::encode(&self.group_id), @@ -108,10 +110,8 @@ impl MlsGroup { // Load the message from the DB to handle cases where it may have been already processed in // another thread - let new_message = self - .context() - .store() - .conn()? + let new_message = provider + .conn_ref() .get_group_message_by_timestamp(&self.group_id, created_ns as i64)? .ok_or(SubscribeError::GroupMessageNotFound)?; @@ -133,20 +133,21 @@ impl MlsGroup { /// Converts some `SubscribeError` variants to an Option, if they are inconsequential. pub async fn process_streamed_group_message( &self, + provider: &XmtpOpenMlsProvider, envelope_bytes: Vec, ) -> Result { let envelope = GroupMessage::decode(envelope_bytes.as_slice())?; - self.process_stream_entry(envelope).await + self.process_stream_entry(provider, envelope).await } - pub async fn stream( - &self, + pub async fn stream<'a>( + &'a self, ) -> Result< - impl Stream> + use<'_, ScopedClient>, + impl Stream> + use<'a, ScopedClient>, ClientError, > where - ::ApiClient: XmtpMlsStreams + 'static, + ::ApiClient: XmtpMlsStreams + 'a, { let group_list = HashMap::from([( self.group_id.clone(), @@ -180,14 +181,15 @@ impl MlsGroup { } /// Stream messages from groups in `group_id_to_info` +// TODO: Note when to use a None provider #[tracing::instrument(level = "debug", skip_all)] -pub(crate) async fn stream_messages( - client: &ScopedClient, +pub(crate) async fn stream_messages<'a, ScopedClient>( + client: &'a ScopedClient, group_id_to_info: Arc, MessagesStreamInfo>>, -) -> Result> + '_, ClientError> +) -> Result> + 'a, ClientError> where ScopedClient: ScopedGroupClient, - ::ApiClient: XmtpApi + XmtpMlsStreams + 'static, + ::ApiClient: XmtpApi + XmtpMlsStreams + 'a, { let filters: Vec = group_id_to_info .iter() @@ -200,6 +202,7 @@ where .then(move |res| { let group_id_to_info = group_id_to_info.clone(); async move { + let provider = client.mls_provider()?; let envelope = res.map_err(GroupError::from)?; let group_id = extract_group_id(&envelope)?; tracing::info!( @@ -214,7 +217,8 @@ where "Received message for a non-subscribed group".to_string(), ))?; let mls_group = MlsGroup::new(client, group_id, stream_info.convo_created_at_ns); - mls_group.process_stream_entry(envelope).await + + mls_group.process_stream_entry(&provider, envelope).await } }) .inspect(|e| { @@ -296,8 +300,9 @@ pub(crate) mod tests { let message = messages.first().unwrap(); let mut message_bytes: Vec = Vec::new(); message.encode(&mut message_bytes).unwrap(); + let provider = amal.mls_provider().unwrap(); let message_again = amal_group - .process_streamed_group_message(message_bytes) + .process_streamed_group_message(&provider, message_bytes) .await; if let Ok(message) = message_again { @@ -327,7 +332,7 @@ pub(crate) mod tests { // Get bola's version of the same group let bola_groups = bola - .sync_welcomes(&bola.store().conn().unwrap()) + .sync_welcomes(&bola.mls_provider().unwrap()) .await .unwrap(); let bola_group = Arc::new(bola_groups.first().unwrap().clone()); diff --git a/xmtp_mls/src/identity.rs b/xmtp_mls/src/identity.rs index 0b73a784d..d80ebf852 100644 --- a/xmtp_mls/src/identity.rs +++ b/xmtp_mls/src/identity.rs @@ -5,7 +5,6 @@ use crate::retry::RetryableError; use crate::storage::db_connection::DbConnection; use crate::storage::identity::StoredIdentity; use crate::storage::sql_key_store::{SqlKeyStore, SqlKeyStoreError, KEY_PACKAGE_REFERENCES}; -use crate::storage::EncryptedMessageStore; use crate::{ api::{ApiClientWrapper, WrappedApiError}, configuration::{CIPHERSUITE, GROUP_MEMBERSHIP_EXTENSION_ID, MUTABLE_METADATA_EXTENSION_ID}, @@ -111,14 +110,12 @@ impl IdentityStrategy { pub(crate) async fn initialize_identity( self, api_client: &ApiClientWrapper, - store: &EncryptedMessageStore, + provider: &XmtpOpenMlsProvider, scw_signature_verifier: impl SmartContractSignatureVerifier, ) -> Result { use IdentityStrategy::*; info!("Initializing identity"); - let conn = store.conn()?; - let provider = XmtpOpenMlsProvider::new(conn); let stored_identity: Option = provider .conn_ref() .fetch(&())? @@ -150,7 +147,7 @@ impl IdentityStrategy { nonce, legacy_signed_private_key, api_client, - &provider, + provider, scw_signature_verifier, ) .await diff --git a/xmtp_mls/src/intents.rs b/xmtp_mls/src/intents.rs index 7f54daa6d..6664152cf 100644 --- a/xmtp_mls/src/intents.rs +++ b/xmtp_mls/src/intents.rs @@ -8,13 +8,7 @@ //! Intents are written to local storage (SQLite), before being published to the delivery service via gRPC. An //! intent is fully resolved (success or failure) once it -use crate::{ - client::XmtpMlsLocalContext, - retry::RetryableError, - storage::{refresh_state::EntityKind, EncryptedMessageStore}, - xmtp_openmls_provider::XmtpOpenMlsProvider, -}; -use std::{future::Future, sync::Arc}; +use crate::retry::RetryableError; use thiserror::Error; #[derive(Debug, Error)] @@ -36,61 +30,3 @@ impl RetryableError for ProcessIntentError { } } } - -/// Intents holding the context of this Client -pub struct Intents { - pub(crate) context: Arc, -} - -impl Intents { - pub(crate) fn store(&self) -> &EncryptedMessageStore { - self.context.store() - } - - /// Download all unread welcome messages and convert to groups. - /// In a database transaction, increment the cursor for a given entity and - /// apply the update after the provided `ProcessingFn` has completed successfully. - pub(crate) async fn process_for_id( - &self, - entity_id: &[u8], - entity_kind: EntityKind, - cursor: u64, - process_envelope: ProcessingFn, - ) -> Result - where - Fut: Future>, - ProcessingFn: FnOnce(XmtpOpenMlsProvider) -> Fut, - ErrorType: From - + From - + From - + std::fmt::Display, - { - self.store() - .transaction_async(|provider| async move { - let is_updated = - provider - .conn_ref() - .update_cursor(entity_id, entity_kind, cursor as i64)?; - if !is_updated { - return Err(ProcessIntentError::AlreadyProcessed(cursor).into()); - } - process_envelope(provider).await - }) - .await - .inspect(|_| { - tracing::info!( - "Transaction completed successfully: process for entity [{:?}] envelope cursor[{}]", - entity_id, - cursor - ); - }) - .inspect_err(|err| { - tracing::info!( - "Transaction failed: process for entity [{:?}] envelope cursor[{}] error:[{}]", - entity_id, - cursor, - err - ); - }) - } -} diff --git a/xmtp_mls/src/storage/encrypted_store/db_connection.rs b/xmtp_mls/src/storage/encrypted_store/db_connection.rs index e2ef9a970..d0ef3942a 100644 --- a/xmtp_mls/src/storage/encrypted_store/db_connection.rs +++ b/xmtp_mls/src/storage/encrypted_store/db_connection.rs @@ -51,6 +51,8 @@ where self.inner.lock() } + /// Internal-only API to get the underlying `diesel::Connection` reference + /// without a scope pub(super) fn inner_ref(&self) -> Arc> { self.inner.clone() } diff --git a/xmtp_mls/src/storage/encrypted_store/mod.rs b/xmtp_mls/src/storage/encrypted_store/mod.rs index 58ff1eff8..9646a7611 100644 --- a/xmtp_mls/src/storage/encrypted_store/mod.rs +++ b/xmtp_mls/src/storage/encrypted_store/mod.rs @@ -190,6 +190,13 @@ pub mod private { Ok::<_, StorageError>(()) } + pub fn mls_provider( + &self, + ) -> Result::Connection>, StorageError> { + let conn = self.conn()?; + Ok(crate::xmtp_openmls_provider::XmtpOpenMlsProviderPrivate::new(conn)) + } + /// Pulls a new connection from the store pub fn conn( &self, @@ -198,6 +205,7 @@ pub mod private { } /// Start a new database transaction with the OpenMLS Provider from XMTP + /// with the provided connection /// # Arguments /// `fun`: Scoped closure providing a MLSProvider to carry out the transaction /// @@ -210,22 +218,25 @@ pub mod private { /// provider.conn().db_operation()?; /// }) /// ``` - pub fn transaction(&self, fun: F) -> Result + pub fn transaction( + &self, + provider: &XmtpOpenMlsProviderPrivate<::Connection>, + fun: F, + ) -> Result where F: FnOnce(&XmtpOpenMlsProviderPrivate<::Connection>) -> Result, E: From + From, { tracing::debug!("Transaction beginning"); - let connection = self.db.conn()?; { + let connection = provider.conn_ref(); let mut connection = connection.inner_mut_ref(); ::TransactionManager::begin_transaction(&mut *connection)?; } - let provider = XmtpOpenMlsProviderPrivate::new(connection); let conn = provider.conn_ref(); - match fun(&provider) { + match fun(provider) { Ok(value) => { conn.raw_query(|conn| { ::TransactionManager::commit_transaction(&mut *conn) @@ -260,27 +271,29 @@ pub mod private { /// provider.conn().db_operation()?; /// }).await /// ``` - pub async fn transaction_async(&self, fun: F) -> Result + pub async fn transaction_async<'a, T, F, E, Fut>( + &self, + provider: &'a XmtpOpenMlsProviderPrivate<::Connection>, + fun: F, + ) -> Result where - F: FnOnce(XmtpOpenMlsProviderPrivate<::Connection>) -> Fut, + F: FnOnce(&'a XmtpOpenMlsProviderPrivate<::Connection>) -> Fut, Fut: futures::Future>, E: From + From, { tracing::debug!("Transaction async beginning"); - let db_connection = self.db.conn()?; { - let mut connection = db_connection.inner_mut_ref(); + let connection = provider.conn_ref(); + let mut connection = connection.inner_mut_ref(); ::TransactionManager::begin_transaction(&mut *connection)?; } - let local_connection = db_connection.inner_ref(); - let provider = XmtpOpenMlsProviderPrivate::new(db_connection); - // the other connection is dropped in the closure // ensuring we have only one strong reference let result = fun(provider).await; + let local_connection = provider.conn_ref().inner_ref(); if Arc::strong_count(&local_connection) > 1 { tracing::warn!( - "More than 1 strong connection references still exist during transaction" + "More than 1 strong connection references still exist during async transaction" ); } @@ -467,7 +480,7 @@ pub(crate) mod tests { identity::StoredIdentity, }, utils::test::{rand_vec, tmp_path}, - Fetch, Store, StreamHandle as _, + Fetch, Store, StreamHandle as _, XmtpOpenMlsProvider, }; /// Test harness that loads an Ephemeral store. @@ -651,11 +664,11 @@ pub(crate) mod tests { .unwrap(); let barrier = Arc::new(Barrier::new(2)); - + let provider = XmtpOpenMlsProvider::new(store.conn().unwrap()); let store_pointer = store.clone(); let barrier_pointer = barrier.clone(); let handle = std::thread::spawn(move || { - store_pointer.transaction(|provider| { + store_pointer.transaction(&provider, |provider| { let conn1 = provider.conn_ref(); StoredIdentity::new("correct".to_string(), rand_vec(), rand_vec()) .store(conn1) @@ -669,20 +682,22 @@ pub(crate) mod tests { }); let store_pointer = store.clone(); + let provider = XmtpOpenMlsProvider::new(store.conn().unwrap()); let handle2 = std::thread::spawn(move || { barrier.wait(); - let result = store_pointer.transaction(|provider| -> Result<(), anyhow::Error> { - let connection = provider.conn_ref(); - let group = StoredGroup::new( - b"should not exist".to_vec(), - 0, - GroupMembershipState::Allowed, - "goodbye".to_string(), - None, - ); - group.store(connection)?; - Ok(()) - }); + let result = + store_pointer.transaction(&provider, |provider| -> Result<(), anyhow::Error> { + let connection = provider.conn_ref(); + let group = StoredGroup::new( + b"should not exist".to_vec(), + 0, + GroupMembershipState::Allowed, + "goodbye".to_string(), + None, + ); + group.store(connection)?; + Ok(()) + }); barrier.wait(); result }); @@ -718,10 +733,10 @@ pub(crate) mod tests { .unwrap(); let store_pointer = store.clone(); - + let provider = XmtpOpenMlsProvider::new(store_pointer.conn().unwrap()); let handle = crate::spawn(None, async move { store_pointer - .transaction_async(|provider| async move { + .transaction_async(&provider, |provider| async move { let conn1 = provider.conn_ref(); StoredIdentity::new("crab".to_string(), rand_vec(), rand_vec()) .store(conn1) diff --git a/xmtp_mls/src/storage/encrypted_store/refresh_state.rs b/xmtp_mls/src/storage/encrypted_store/refresh_state.rs index 34efa6220..f5cf0ba33 100644 --- a/xmtp_mls/src/storage/encrypted_store/refresh_state.rs +++ b/xmtp_mls/src/storage/encrypted_store/refresh_state.rs @@ -90,13 +90,13 @@ impl DbConnection { } } - pub fn update_cursor( + pub fn update_cursor>( &self, - entity_id: &[u8], + entity_id: Id, entity_kind: EntityKind, cursor: i64, ) -> Result { - let state: Option = self.get_refresh_state(entity_id, entity_kind)?; + let state: Option = self.get_refresh_state(&entity_id, entity_kind)?; match state { Some(state) => { use super::schema::refresh_state::dsl; @@ -110,7 +110,7 @@ impl DbConnection { } None => Err(StorageError::NotFound(format!( "state for entity ID {} with kind {:?}", - hex::encode(entity_id), + hex::encode(entity_id.as_ref()), entity_kind ))), } diff --git a/xmtp_mls/src/subscriptions.rs b/xmtp_mls/src/subscriptions.rs index 9a89fb014..b14a9291c 100644 --- a/xmtp_mls/src/subscriptions.rs +++ b/xmtp_mls/src/subscriptions.rs @@ -12,7 +12,10 @@ use xmtp_proto::{api_client::XmtpMlsStreams, xmtp::mls::api::v1::WelcomeMessage} use crate::{ client::{extract_welcome_message, ClientError}, - groups::{mls_sync::GroupMessageProcessingError, subscriptions, GroupError, MlsGroup}, + groups::{ + group_metadata::GroupMetadata, mls_sync::GroupMessageProcessingError, + scoped_client::ScopedGroupClient as _, subscriptions, GroupError, MlsGroup, + }, preferences::UserPreferenceUpdate, retry::{Retry, RetryableError}, retry_async, retryable, @@ -22,7 +25,7 @@ use crate::{ group_message::StoredGroupMessage, StorageError, }, - Client, XmtpApi, + Client, XmtpApi, XmtpOpenMlsProvider, }; use thiserror::Error; @@ -223,6 +226,7 @@ where { async fn process_streamed_welcome( &self, + provider: &XmtpOpenMlsProvider, welcome: WelcomeMessage, ) -> Result, ClientError> { let welcome_v1 = extract_welcome_message(welcome)?; @@ -236,10 +240,10 @@ where let welcome_v1 = &welcome_v1; self.context .store() - .transaction_async(|provider| async move { + .transaction_async(provider, |provider| async move { MlsGroup::create_from_encrypted_welcome( Arc::new(self.clone()), - &provider, + provider, welcome_v1.hpke_public_key.as_slice(), &welcome_v1.data, welcome_v1.id as i64, @@ -251,7 +255,7 @@ where ); if let Some(err) = creation_result.as_ref().err() { - let conn = self.context.store().conn()?; + let conn = provider.conn_ref(); let result = conn.find_group_by_welcome_id(welcome_v1.id as i64); match result { Ok(Some(group)) => { @@ -276,71 +280,81 @@ where &self, envelope_bytes: Vec, ) -> Result, ClientError> { + let provider = self.mls_provider()?; let envelope = WelcomeMessage::decode(envelope_bytes.as_slice()) .map_err(|e| ClientError::Generic(e.to_string()))?; - let welcome = self.process_streamed_welcome(envelope).await?; + let welcome = self.process_streamed_welcome(&provider, envelope).await?; Ok(welcome) } #[tracing::instrument(level = "debug", skip_all)] - pub async fn stream_conversations( - &self, + pub async fn stream_conversations<'a>( + &'a self, conversation_type: Option, - ) -> Result, SubscribeError>> + '_, ClientError> + ) -> Result, SubscribeError>> + 'a, ClientError> where ApiClient: XmtpMlsStreams, { - let event_queue = tokio_stream::wrappers::BroadcastStream::new( - self.local_events.subscribe(), - ) - .filter_map(|event| async { - crate::optify!(event, "Missed messages due to event queue lag") - .and_then(LocalEvents::group_filter) - .map(Result::Ok) - }); - - // Helper function for filtering Dm groups - let filter_group = move |group: Result, ClientError>| { - let conversation_type = &conversation_type; - // take care of any possible errors - let result = || -> Result<_, _> { - let group = group?; - let provider = group.client.context().mls_provider()?; - let metadata = group.metadata(&provider)?; - Ok((metadata, group)) - }; - let filtered = result().map(|(metadata, group)| { - conversation_type - .map_or(true, |ct| ct == metadata.conversation_type) - .then_some(group) - }); - futures::future::ready(filtered.transpose()) - }; - let installation_key = self.installation_public_key(); let id_cursor = 0; tracing::info!(inbox_id = self.inbox_id(), "Setting up conversation stream"); let subscription = self .api_client - .subscribe_welcome_messages(installation_key.into(), Some(id_cursor)) - .await?; + .subscribe_welcome_messages(installation_key.as_ref(), Some(id_cursor)) + .await? + .map(WelcomeOrGroup::::Welcome); + + let event_queue = + tokio_stream::wrappers::BroadcastStream::new(self.local_events.subscribe()) + .filter_map(|event| async { + crate::optify!(event, "Missed messages due to event queue lag") + .and_then(LocalEvents::group_filter) + .map(Result::Ok) + }) + .map(WelcomeOrGroup::::Group); + + let stream = futures::stream::select(event_queue, subscription); + let stream = stream.filter_map(move |group_or_welcome| async move { + tracing::info!( + inbox_id = self.inbox_id(), + installation_id = %self.installation_id(), + "Received conversation streaming payload" + ); + let filtered = self.process_streamed_convo(group_or_welcome).await; + let filtered = filtered.map(|(metadata, group)| { + conversation_type + .map_or(true, |ct| ct == metadata.conversation_type) + .then_some(group) + }); + filtered.transpose() + }); - let stream = subscription - .map(|welcome| async { - tracing::info!( - inbox_id = self.inbox_id(), - "Received conversation streaming payload" - ); - self.process_streamed_welcome(welcome?).await - }) - .filter_map(|v| async { Some(v.await) }); + Ok(stream) + } - Ok(futures::stream::select(stream, event_queue).filter_map(filter_group)) + async fn process_streamed_convo( + &self, + welcome_or_group: WelcomeOrGroup, + ) -> Result<(GroupMetadata, MlsGroup>), SubscribeError> { + let provider = self.mls_provider()?; + let group = match welcome_or_group { + WelcomeOrGroup::Welcome(welcome) => { + self.process_streamed_welcome(&provider, welcome?).await? + } + WelcomeOrGroup::Group(group) => group?, + }; + let metadata = group.metadata(provider)?; + Ok((metadata, group)) } } +enum WelcomeOrGroup { + Group(Result>, SubscribeError>), + Welcome(Result), +} + impl Client where ApiClient: XmtpApi + XmtpMlsStreams + Send + Sync + 'static, @@ -377,17 +391,19 @@ where conversation_type = ?conversation_type, "stream all messages" ); - - let conn = self.store().conn()?; - self.sync_welcomes(&conn).await?; - - let mut group_id_to_info = self - .store() - .conn()? - .find_groups(GroupQueryArgs::default().maybe_conversation_type(conversation_type))? - .into_iter() - .map(Into::into) - .collect::, MessagesStreamInfo>>(); + let mut group_id_to_info = async { + let provider = self.mls_provider()?; + self.sync_welcomes(&provider).await?; + + let group_id_to_info = provider + .conn_ref() + .find_groups(GroupQueryArgs::default().maybe_conversation_type(conversation_type))? + .into_iter() + .map(Into::into) + .collect::, MessagesStreamInfo>>(); + Ok::<_, ClientError>(group_id_to_info) + } + .await?; let stream = async_stream::stream! { let messages_stream = subscriptions::stream_messages( @@ -586,7 +602,7 @@ pub(crate) mod tests { .await .unwrap(); let bob_group = bob - .sync_welcomes(&bob.store().conn().unwrap()) + .sync_welcomes(&bob.mls_provider().unwrap()) .await .unwrap(); let bob_group = bob_group.first().unwrap(); @@ -895,7 +911,7 @@ pub(crate) mod tests { } // Verify syncing welcomes while streaming causes no issues - alix.sync_welcomes(&alix.store().conn().unwrap()) + alix.sync_welcomes(&alix.mls_provider().unwrap()) .await .unwrap(); let find_groups_results = alix.find_groups(GroupQueryArgs::default()).unwrap(); @@ -930,7 +946,7 @@ pub(crate) mod tests { }, ); - alix.create_dm_by_inbox_id(bo.inbox_id().to_string()) + alix.create_dm_by_inbox_id(&alix.mls_provider().unwrap(), bo.inbox_id().to_string()) .await .unwrap(); @@ -980,7 +996,7 @@ pub(crate) mod tests { let result = notify.wait_for_delivery().await; assert!(result.is_err(), "Stream unexpectedly received a Group"); - alix.create_dm_by_inbox_id(bo.inbox_id().to_string()) + alix.create_dm_by_inbox_id(&alix.mls_provider().unwrap(), bo.inbox_id().to_string()) .await .unwrap(); notify.wait_for_delivery().await.unwrap(); @@ -1003,7 +1019,7 @@ pub(crate) mod tests { notify_pointer.notify_one(); }); - alix.create_dm_by_inbox_id(bo.inbox_id().to_string()) + alix.create_dm_by_inbox_id(&alix.mls_provider().unwrap(), bo.inbox_id().to_string()) .await .unwrap(); notify.wait_for_delivery().await.unwrap(); @@ -1013,7 +1029,7 @@ pub(crate) mod tests { } let dm = bo - .create_dm_by_inbox_id(alix.inbox_id().to_string()) + .create_dm_by_inbox_id(&bo.mls_provider().unwrap(), alix.inbox_id().to_string()) .await .unwrap(); dm.add_members_by_inbox_id(&[alix.inbox_id()]) @@ -1057,7 +1073,7 @@ pub(crate) mod tests { .unwrap(); let alix_dm = alix - .create_dm_by_inbox_id(bo.inbox_id().to_string()) + .create_dm_by_inbox_id(&alix.mls_provider().unwrap(), bo.inbox_id().to_string()) .await .unwrap(); diff --git a/xmtp_mls/src/types.rs b/xmtp_mls/src/types.rs index 0a73a89b1..5a97a0ca3 100644 --- a/xmtp_mls/src/types.rs +++ b/xmtp_mls/src/types.rs @@ -1,2 +1,86 @@ pub type Address = String; -pub type InstallationId = String; + +use std::fmt; + +#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord)] +pub struct InstallationId([u8; 32]); + +impl fmt::Display for InstallationId { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", hex::encode(self.0)) + } +} + +impl std::ops::Deref for InstallationId { + type Target = [u8; 32]; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl AsRef<[u8]> for InstallationId { + fn as_ref(&self) -> &[u8] { + &self.0 + } +} + +impl From for Vec { + fn from(value: InstallationId) -> Self { + value.0.to_vec() + } +} + +impl From<[u8; 32]> for InstallationId { + fn from(value: [u8; 32]) -> Self { + InstallationId(value) + } +} + +impl PartialEq> for InstallationId { + fn eq(&self, other: &Vec) -> bool { + self.0.eq(&other[..]) + } +} + +impl PartialEq for Vec { + fn eq(&self, other: &InstallationId) -> bool { + other.0.eq(&self[..]) + } +} + +impl PartialEq<&Vec> for InstallationId { + fn eq(&self, other: &&Vec) -> bool { + self.0.eq(&other[..]) + } +} + +impl PartialEq for &Vec { + fn eq(&self, other: &InstallationId) -> bool { + other.0.eq(&self[..]) + } +} + +impl PartialEq<[u8]> for InstallationId { + fn eq(&self, other: &[u8]) -> bool { + self.0.eq(other) + } +} + +impl PartialEq for [u8] { + fn eq(&self, other: &InstallationId) -> bool { + other.0.eq(self) + } +} + +impl PartialEq<[u8; 32]> for InstallationId { + fn eq(&self, other: &[u8; 32]) -> bool { + self.0.eq(other) + } +} + +impl PartialEq for [u8; 32] { + fn eq(&self, other: &InstallationId) -> bool { + other.0.eq(&self[..]) + } +} diff --git a/xmtp_mls/src/xmtp_openmls_provider.rs b/xmtp_mls/src/xmtp_openmls_provider.rs index 77dd51494..a5597ee31 100644 --- a/xmtp_mls/src/xmtp_openmls_provider.rs +++ b/xmtp_mls/src/xmtp_openmls_provider.rs @@ -19,7 +19,11 @@ impl XmtpOpenMlsProviderPrivate { } } - pub(crate) fn conn_ref(&self) -> &DbConnectionPrivate { + pub fn new_crypto() -> RustCrypto { + RustCrypto::default() + } + + pub fn conn_ref(&self) -> &DbConnectionPrivate { self.key_store.conn_ref() } }