diff --git a/bindings_ffi/src/lib.rs b/bindings_ffi/src/lib.rs index 670ec176e..7cd5a5b65 100755 --- a/bindings_ffi/src/lib.rs +++ b/bindings_ffi/src/lib.rs @@ -63,6 +63,8 @@ pub enum GenericError { pub enum FfiSubscribeError { #[error("Subscribe Error {0}")] Subscribe(#[from] xmtp_mls::subscriptions::SubscribeError), + #[error("Storage error: {0}")] + Storage(#[from] xmtp_mls::storage::StorageError), } impl From for GenericError { diff --git a/bindings_ffi/src/mls.rs b/bindings_ffi/src/mls.rs index 6c58ee573..b1d5ca395 100644 --- a/bindings_ffi/src/mls.rs +++ b/bindings_ffi/src/mls.rs @@ -881,8 +881,8 @@ impl FfiConversations { pub async fn sync(&self) -> Result<(), GenericError> { let inner = self.inner_client.as_ref(); - let conn = inner.store().conn()?; - inner.sync_welcomes(&conn).await?; + let provider = inner.mls_provider()?; + inner.sync_welcomes(&provider).await?; Ok(()) } @@ -895,9 +895,9 @@ impl FfiConversations { pub async fn sync_all_conversations(&self) -> Result { let inner = self.inner_client.as_ref(); - let conn = inner.store().conn()?; + let provider = inner.mls_provider()?; - let num_groups_synced: usize = inner.sync_all_welcomes_and_groups(&conn).await?; + let num_groups_synced: usize = inner.sync_all_welcomes_and_groups(&provider).await?; // Convert usize to u32 for compatibility with Uniffi let num_groups_synced: u32 = num_groups_synced @@ -1262,9 +1262,10 @@ impl FfiConversation { &self, envelope_bytes: Vec, ) -> Result { + let provider = self.inner.mls_provider()?; let message = self .inner - .process_streamed_group_message(envelope_bytes) + .process_streamed_group_message(&provider, envelope_bytes) .await?; let ffi_message = message.into(); @@ -1848,6 +1849,8 @@ mod tests { conversations: Mutex>>, consent_updates: Mutex>, notify: Notify, + inbox_id: Option, + installation_id: Option, } impl RustStreamCallback { @@ -1867,12 +1870,22 @@ mod tests { .await?; Ok(()) } + + pub fn from_client(client: &FfiXmtpClient) -> Self { + RustStreamCallback { + inbox_id: Some(client.inner_client.inbox_id().to_string()), + installation_id: Some(hex::encode(client.inner_client.installation_public_key())), + ..Default::default() + } + } } impl FfiMessageCallback for RustStreamCallback { fn on_message(&self, message: FfiMessage) { let mut messages = self.messages.lock().unwrap(); log::info!( + inbox_id = self.inbox_id, + installation_id = self.installation_id, "ON MESSAGE Received\n-------- \n{}\n----------", String::from_utf8_lossy(&message.content) ); @@ -1888,7 +1901,11 @@ mod tests { impl FfiConversationCallback for RustStreamCallback { fn on_conversation(&self, group: Arc) { - log::debug!("received conversation"); + log::debug!( + inbox_id = self.inbox_id, + installation_id = self.installation_id, + "received conversation" + ); let _ = self.num_messages.fetch_add(1, Ordering::SeqCst); let mut convos = self.conversations.lock().unwrap(); convos.push(group); @@ -1902,7 +1919,11 @@ mod tests { impl FfiConsentCallback for RustStreamCallback { fn on_consent_update(&self, mut consent: Vec) { - log::debug!("received consent update"); + log::debug!( + inbox_id = self.inbox_id, + installation_id = self.installation_id, + "received consent update" + ); let mut consent_updates = self.consent_updates.lock().unwrap(); consent_updates.append(&mut consent); self.notify.notify_one(); @@ -2751,7 +2772,7 @@ mod tests { let caro = new_test_client().await; // Alix begins a stream for all messages - let message_callbacks = Arc::new(RustStreamCallback::default()); + let message_callbacks = Arc::new(RustStreamCallback::from_client(&alix)); let stream_messages = alix .conversations() .stream_all_messages(message_callbacks.clone()) @@ -2798,12 +2819,12 @@ mod tests { let bo2 = new_test_client_with_wallet(bo_wallet).await; // Bo begins a stream for all messages - let bo_message_callbacks = Arc::new(RustStreamCallback::default()); - let bo_stream_messages = bo2 + let bo2_message_callbacks = Arc::new(RustStreamCallback::from_client(&bo2)); + let bo2_stream_messages = bo2 .conversations() - .stream_all_messages(bo_message_callbacks.clone()) + .stream_all_messages(bo2_message_callbacks.clone()) .await; - bo_stream_messages.wait_for_ready().await; + bo2_stream_messages.wait_for_ready().await; alix_group.update_installations().await.unwrap(); @@ -3167,7 +3188,7 @@ mod tests { let bo = new_test_client().await; let caro = new_test_client().await; - let caro_conn = caro.inner_client.store().conn().unwrap(); + let caro_provider = caro.inner_client.mls_provider().unwrap(); let alix_group = alix .conversations() @@ -3197,7 +3218,11 @@ mod tests { ) .await .unwrap(); - let _ = caro.inner_client.sync_welcomes(&caro_conn).await.unwrap(); + let _ = caro + .inner_client + .sync_welcomes(&caro_provider) + .await + .unwrap(); bo_group.send("second".as_bytes().to_vec()).await.unwrap(); stream_callback.wait_for_delivery(None).await.unwrap(); @@ -3216,7 +3241,7 @@ mod tests { let amal = new_test_client().await; let bola = new_test_client().await; - let bola_conn = bola.inner_client.store().conn().unwrap(); + let bola_provider = bola.inner_client.mls_provider().unwrap(); let amal_group: Arc = amal .conversations() @@ -3227,7 +3252,10 @@ mod tests { .await .unwrap(); - bola.inner_client.sync_welcomes(&bola_conn).await.unwrap(); + bola.inner_client + .sync_welcomes(&bola_provider) + .await + .unwrap(); let bola_group = bola.conversation(amal_group.id()).unwrap(); let stream_callback = Arc::new(RustStreamCallback::default()); diff --git a/bindings_node/src/conversation.rs b/bindings_node/src/conversation.rs index aeeffa6b1..897662463 100644 --- a/bindings_node/src/conversation.rs +++ b/bindings_node/src/conversation.rs @@ -197,8 +197,9 @@ impl Conversation { self.created_at_ns, ); let envelope_bytes: Vec = envelope_bytes.deref().to_vec(); + let provider = group.mls_provider().map_err(ErrorWrapper::from)?; let message = group - .process_streamed_group_message(envelope_bytes) + .process_streamed_group_message(&provider, envelope_bytes) .await .map_err(ErrorWrapper::from)?; diff --git a/bindings_node/src/conversations.rs b/bindings_node/src/conversations.rs index fb35a4d83..ac1111a0a 100644 --- a/bindings_node/src/conversations.rs +++ b/bindings_node/src/conversations.rs @@ -235,14 +235,13 @@ impl Conversations { #[napi] pub async fn sync(&self) -> Result<()> { - let conn = self + let provider = self .inner_client - .store() - .conn() + .mls_provider() .map_err(ErrorWrapper::from)?; self .inner_client - .sync_welcomes(&conn) + .sync_welcomes(&provider) .await .map_err(ErrorWrapper::from)?; Ok(()) @@ -250,15 +249,14 @@ impl Conversations { #[napi] pub async fn sync_all_conversations(&self) -> Result { - let conn = self + let provider = self .inner_client - .store() - .conn() + .mls_provider() .map_err(ErrorWrapper::from)?; let num_groups_synced = self .inner_client - .sync_all_welcomes_and_groups(&conn) + .sync_all_welcomes_and_groups(&provider) .await .map_err(ErrorWrapper::from)?; diff --git a/bindings_wasm/src/conversations.rs b/bindings_wasm/src/conversations.rs index c33cd901d..eb11621a7 100644 --- a/bindings_wasm/src/conversations.rs +++ b/bindings_wasm/src/conversations.rs @@ -269,14 +269,13 @@ impl Conversations { #[wasm_bindgen] pub async fn sync(&self) -> Result<(), JsError> { - let conn = self + let provider = self .inner_client - .store() - .conn() + .mls_provider() .map_err(|e| JsError::new(format!("{}", e).as_str()))?; self .inner_client - .sync_welcomes(&conn) + .sync_welcomes(&provider) .await .map_err(|e| JsError::new(format!("{}", e).as_str()))?; @@ -285,15 +284,14 @@ impl Conversations { #[wasm_bindgen(js_name = syncAllConversations)] pub async fn sync_all_conversations(&self) -> Result { - let conn = self + let provider = self .inner_client - .store() - .conn() + .mls_provider() .map_err(|e| JsError::new(format!("{}", e).as_str()))?; let num_groups_synced = self .inner_client - .sync_all_welcomes_and_groups(&conn) + .sync_all_welcomes_and_groups(&provider) .await .map_err(|e| JsError::new(format!("{}", e).as_str()))?; diff --git a/examples/cli/cli-client.rs b/examples/cli/cli-client.rs index b371edef1..a9513bec5 100755 --- a/examples/cli/cli-client.rs +++ b/examples/cli/cli-client.rs @@ -290,9 +290,9 @@ async fn main() -> color_eyre::eyre::Result<()> { } Commands::ListGroups {} => { info!("List Groups"); - let conn = client.store().conn()?; + let provider = client.mls_provider()?; client - .sync_welcomes(&conn) + .sync_welcomes(&provider) .await .expect("failed to sync welcomes"); @@ -440,9 +440,8 @@ async fn main() -> color_eyre::eyre::Result<()> { ); } Commands::RequestHistorySync {} => { - let conn = client.store().conn().unwrap(); let provider = client.mls_provider().unwrap(); - client.sync_welcomes(&conn).await.unwrap(); + client.sync_welcomes(&provider).await.unwrap(); client.start_sync_worker(); client .send_sync_request(&provider, DeviceSyncKind::MessageHistory) @@ -451,9 +450,9 @@ async fn main() -> color_eyre::eyre::Result<()> { info!("Sent history sync request in sync group.") } Commands::ListHistorySyncMessages {} => { - let conn = client.store().conn()?; - client.sync_welcomes(&conn).await?; - let group = client.get_sync_group(&conn)?; + let provider = client.mls_provider()?; + client.sync_welcomes(&provider).await?; + let group = client.get_sync_group(provider.conn_ref())?; let group_id_str = hex::encode(group.group_id.clone()); group.sync().await?; let messages = group @@ -574,8 +573,8 @@ where } async fn get_group(client: &Client, group_id: Vec) -> Result { - let conn = client.store().conn().unwrap(); - client.sync_welcomes(&conn).await?; + let provider = client.mls_provider().unwrap(); + client.sync_welcomes(&provider).await?; let group = client.group(group_id)?; group .sync() diff --git a/xmtp_debug/src/app/generate/messages.rs b/xmtp_debug/src/app/generate/messages.rs index aeb118ff7..fc6006d85 100644 --- a/xmtp_debug/src/app/generate/messages.rs +++ b/xmtp_debug/src/app/generate/messages.rs @@ -8,7 +8,6 @@ use crate::{ use color_eyre::eyre::{self, eyre, Result}; use rand::{rngs::SmallRng, seq::SliceRandom, Rng, SeedableRng}; use std::sync::Arc; -use xmtp_mls::XmtpOpenMlsProvider; mod content_type; @@ -118,10 +117,9 @@ impl GenerateMessages { hex::encode(inbox_id) ))?; let client = app::client_from_identity(&identity, &network).await?; - let conn = client.store().conn()?; - client.sync_welcomes(&conn).await?; + let provider = client.mls_provider()?; + client.sync_welcomes(&provider).await?; let group = client.group(group.id.into())?; - let provider: XmtpOpenMlsProvider = conn.into(); group.maybe_update_installations(&provider, None).await?; group.sync_with_conn(&provider).await?; let words = rng.gen_range(0..*max_message_size); diff --git a/xmtp_debug/src/app/send.rs b/xmtp_debug/src/app/send.rs index 59601d4fe..35a95cbf7 100644 --- a/xmtp_debug/src/app/send.rs +++ b/xmtp_debug/src/app/send.rs @@ -49,8 +49,8 @@ impl Send { .ok_or(eyre!("No Identity with inbox_id [{}]", hex::encode(member)))?; let client = crate::app::client_from_identity(&identity, network).await?; - let conn = client.store().conn()?; - client.sync_welcomes(&conn).await?; + let provider = client.mls_provider()?; + client.sync_welcomes(&provider).await?; let xmtp_group = client.group(group.id.to_vec())?; xmtp_group.send_message(data.as_bytes()).await?; Ok(()) diff --git a/xmtp_mls/src/api/mls.rs b/xmtp_mls/src/api/mls.rs index 4000467e6..33a42a577 100644 --- a/xmtp_mls/src/api/mls.rs +++ b/xmtp_mls/src/api/mls.rs @@ -115,9 +115,9 @@ where } #[tracing::instrument(level = "trace", skip_all)] - pub async fn query_welcome_messages( + pub async fn query_welcome_messages + Copy>( &self, - installation_id: &[u8], + installation_id: Id, id_cursor: Option, ) -> Result, ApiError> { tracing::debug!( @@ -135,7 +135,7 @@ where (async { self.api_client .query_welcome_messages(QueryWelcomeMessagesRequest { - installation_key: installation_id.to_vec(), + installation_key: installation_id.as_ref().to_vec(), paging_info: Some(PagingInfo { id_cursor: id_cursor.unwrap_or(0), limit: page_size, @@ -297,7 +297,7 @@ where pub async fn subscribe_welcome_messages( &self, - installation_key: Vec, + installation_key: &[u8], id_cursor: Option, ) -> Result> + '_, ApiError> where @@ -307,7 +307,7 @@ where self.api_client .subscribe_welcome_messages(SubscribeWelcomeMessagesRequest { filters: vec![WelcomeFilterProto { - installation_key, + installation_key: installation_key.to_vec(), id_cursor: id_cursor.unwrap_or(0), }], }) diff --git a/xmtp_mls/src/builder.rs b/xmtp_mls/src/builder.rs index 930c81506..c03462161 100644 --- a/xmtp_mls/src/builder.rs +++ b/xmtp_mls/src/builder.rs @@ -187,15 +187,16 @@ where .take() .ok_or(ClientBuilderError::MissingParameter { parameter: "store" })?; - debug!( - inbox_id = identity_strategy.inbox_id(), - "Initializing identity" - ); - let identity = identity_strategy .initialize_identity(&api_client_wrapper, &store, &scw_verifier) .await?; + debug!( + inbox_id = identity.inbox_id(), + installation_id = hex::encode(identity.installation_keys.public_bytes()), + "Initialized identity" + ); + // get sequence_id from identity updates and loaded into the DB load_identity_updates( &api_client_wrapper, diff --git a/xmtp_mls/src/client.rs b/xmtp_mls/src/client.rs index eb6a86ab1..ef7af740a 100644 --- a/xmtp_mls/src/client.rs +++ b/xmtp_mls/src/client.rs @@ -12,7 +12,6 @@ use openmls::{ messages::Welcome, prelude::tls_codec::{Deserialize, Error as TlsCodecError}, }; -use openmls_traits::OpenMlsProvider; use thiserror::Error; use tokio::sync::broadcast; @@ -50,6 +49,7 @@ use crate::{ EncryptedMessageStore, StorageError, }, subscriptions::{LocalEventError, LocalEvents}, + types::InstallationId, verified_key_package_v2::{KeyPackageVerificationError, VerifiedKeyPackageV2}, xmtp_openmls_provider::XmtpOpenMlsProvider, Fetch, Store, XmtpApi, @@ -172,8 +172,8 @@ pub struct XmtpMlsLocalContext { impl XmtpMlsLocalContext { /// The installation public key is the primary identifier for an installation - pub fn installation_public_key(&self) -> &[u8; 32] { - self.identity.installation_keys.public_bytes() + pub fn installation_public_key(&self) -> InstallationId { + (*self.identity.installation_keys.public_bytes()).into() } /// Get the account address of the blockchain account associated with this client @@ -191,7 +191,7 @@ impl XmtpMlsLocalContext { } /// Pulls a new database connection and creates a new provider - pub fn mls_provider(&self) -> Result { + pub fn mls_provider(&self) -> Result { Ok(self.store.conn()?.into()) } @@ -278,7 +278,7 @@ where V: SmartContractSignatureVerifier, { /// Retrieves the client's installation public key, sometimes also called `installation_id` - pub fn installation_public_key(&self) -> &[u8; 32] { + pub fn installation_public_key(&self) -> InstallationId { self.context.installation_public_key() } /// Retrieves the client's inbox ID @@ -291,7 +291,7 @@ where } /// Pulls a connection and creates a new MLS Provider - pub fn mls_provider(&self) -> Result { + pub fn mls_provider(&self) -> Result { self.context.mls_provider() } @@ -689,9 +689,12 @@ where /// Upload a new key package to the network replacing an existing key package /// This is expected to be run any time the client receives new Welcome messages - pub async fn rotate_key_package(&self) -> Result<(), ClientError> { + pub async fn rotate_key_package( + &self, + provider: &XmtpOpenMlsProvider, + ) -> Result<(), ClientError> { self.store() - .transaction_async(|provider| async move { + .transaction_async(provider, |provider| async move { self.identity() .rotate_key_package(&provider, &self.api_client) .await @@ -729,7 +732,7 @@ where let welcomes = self .api_client - .query_welcome_messages(installation_id, Some(id_cursor as u64)) + .query_welcome_messages(installation_id.as_ref(), Some(id_cursor as u64)) .await?; Ok(welcomes) @@ -743,10 +746,10 @@ where ) -> Result, ClientError> { let key_package_results = self.api_client.fetch_key_packages(installation_ids).await?; - let mls_provider = self.mls_provider()?; + let crypto_provider = XmtpOpenMlsProvider::new_crypto(); Ok(key_package_results .values() - .map(|bytes| VerifiedKeyPackageV2::from_bytes(mls_provider.crypto(), bytes.as_slice())) + .map(|bytes| VerifiedKeyPackageV2::from_bytes(&crypto_provider, bytes.as_slice())) .collect::>()?) } @@ -755,9 +758,9 @@ where #[tracing::instrument(level = "debug", skip_all)] pub async fn sync_welcomes( &self, - conn: &DbConnection, + provider: &XmtpOpenMlsProvider, ) -> Result>, ClientError> { - let envelopes = self.query_welcome_messages(conn).await?; + let envelopes = self.query_welcome_messages(provider.conn_ref()).await?; let num_envelopes = envelopes.len(); let id = self.installation_public_key(); @@ -775,7 +778,8 @@ where (async { let welcome_v1 = &welcome_v1; self.intents.process_for_id( - id, + provider, + id.as_ref(), EntityKind::Welcome, welcome_v1.id, |provider| async move { @@ -816,7 +820,7 @@ where // If any welcomes were found, rotate your key package if num_envelopes > 0 { - self.rotate_key_package().await?; + self.rotate_key_package(provider).await?; } Ok(groups) @@ -876,9 +880,9 @@ where /// Returns the total number of active groups synced. pub async fn sync_all_welcomes_and_groups( &self, - conn: &DbConnection, + provider: &XmtpOpenMlsProvider, ) -> Result { - self.sync_welcomes(conn).await?; + self.sync_welcomes(provider).await?; let groups = self.find_groups(GroupQueryArgs::default().include_sync_groups())?; let active_groups_count = self.sync_all_groups(groups).await?; @@ -1056,7 +1060,10 @@ pub(crate) mod tests { let init1 = kp1[0].inner.hpke_init_key(); // Rotate and fetch again. - client.rotate_key_package().await.unwrap(); + client + .rotate_key_package(&client.mls_provider().unwrap()) + .await + .unwrap(); let kp2 = client .get_key_packages_for_installation_ids(vec![client.installation_public_key().to_vec()]) @@ -1117,7 +1124,7 @@ pub(crate) mod tests { .unwrap(); let bob_received_groups = bob - .sync_welcomes(&bob.store().conn().unwrap()) + .sync_welcomes(&bob.mls_provider().unwrap()) .await .unwrap(); assert_eq!(bob_received_groups.len(), 1); @@ -1127,7 +1134,7 @@ pub(crate) mod tests { ); let duplicate_received_groups = bob - .sync_welcomes(&bob.store().conn().unwrap()) + .sync_welcomes(&bob.mls_provider().unwrap()) .await .unwrap(); assert_eq!(duplicate_received_groups.len(), 0); @@ -1157,7 +1164,7 @@ pub(crate) mod tests { .await .unwrap(); - let bob_received_groups = bo.sync_welcomes(&bo.store().conn().unwrap()).await.unwrap(); + let bob_received_groups = bo.sync_welcomes(&bo.mls_provider().unwrap()).await.unwrap(); assert_eq!(bob_received_groups.len(), 2); let bo_groups = bo.find_groups(GroupQueryArgs::default()).unwrap(); @@ -1233,7 +1240,7 @@ pub(crate) mod tests { assert_eq!(amal_group.members().await.unwrap().len(), 1); tracing::info!("Syncing bolas welcomes"); // See if Bola can see that they were added to the group - bola.sync_welcomes(&bola.store().conn().unwrap()) + bola.sync_welcomes(&bola.mls_provider().unwrap()) .await .unwrap(); let bola_groups = bola.find_groups(Default::default()).unwrap(); @@ -1254,7 +1261,7 @@ pub(crate) mod tests { .add_members_by_inbox_id(&[bola.inbox_id()]) .await .unwrap(); - bola.sync_welcomes(&bola.store().conn().unwrap()) + bola.sync_welcomes(&bola.mls_provider().unwrap()) .await .unwrap(); @@ -1304,12 +1311,13 @@ pub(crate) mod tests { async fn get_key_package_init_key< ApiClient: XmtpApi, Verifier: SmartContractSignatureVerifier, + Id: AsRef<[u8]>, >( client: &Client, - installation_id: &[u8], + installation_id: Id, ) -> Vec { let kps = client - .get_key_packages_for_installation_ids(vec![installation_id.to_vec()]) + .get_key_packages_for_installation_ids(vec![installation_id.as_ref().to_vec()]) .await .unwrap(); let kp = kps.first().unwrap(); @@ -1346,18 +1354,18 @@ pub(crate) mod tests { .await .unwrap(); - bo.sync_welcomes(&bo.store().conn().unwrap()).await.unwrap(); + bo.sync_welcomes(&bo.mls_provider().unwrap()).await.unwrap(); let bo_new_key = get_key_package_init_key(&bo, bo.installation_public_key()).await; // Bo's key should have changed assert_ne!(bo_original_init_key, bo_new_key); - bo.sync_welcomes(&bo.store().conn().unwrap()).await.unwrap(); + bo.sync_welcomes(&bo.mls_provider().unwrap()).await.unwrap(); let bo_new_key_2 = get_key_package_init_key(&bo, bo.installation_public_key()).await; // Bo's key should not have changed syncing the second time. assert_eq!(bo_new_key, bo_new_key_2); - alix.sync_welcomes(&alix.store().conn().unwrap()) + alix.sync_welcomes(&alix.mls_provider().unwrap()) .await .unwrap(); let alix_key_2 = get_key_package_init_key(&alix, alix.installation_public_key()).await; @@ -1371,7 +1379,7 @@ pub(crate) mod tests { ) .await .unwrap(); - bo.sync_welcomes(&bo.store().conn().unwrap()).await.unwrap(); + bo.sync_welcomes(&bo.mls_provider().unwrap()).await.unwrap(); // Bo should have two groups now let bo_groups = bo.find_groups(GroupQueryArgs::default()).unwrap(); diff --git a/xmtp_mls/src/groups/device_sync.rs b/xmtp_mls/src/groups/device_sync.rs index 2a8d65983..4288e18af 100644 --- a/xmtp_mls/src/groups/device_sync.rs +++ b/xmtp_mls/src/groups/device_sync.rs @@ -385,7 +385,7 @@ where let content = DeviceSyncContent::Request(request.clone()); let content_bytes = serde_json::to_vec(&content)?; - let _message_id = sync_group.prepare_message(&content_bytes, provider.conn_ref(), { + let _message_id = sync_group.prepare_message(&content_bytes, provider, { let request = request.clone(); move |_time_ns| PlaintextEnvelope { content: Some(Content::V2(V2 { @@ -429,7 +429,6 @@ where provider: &XmtpOpenMlsProvider, contents: DeviceSyncReplyProto, ) -> Result<(), DeviceSyncError> { - let conn = provider.conn_ref(); // find the sync group let sync_group = self.get_sync_group(provider.conn_ref())?; @@ -441,7 +440,7 @@ where .await?; // add original sender to all groups on this device on the node - self.ensure_member_of_all_groups(conn, &msg.sender_inbox_id) + self.ensure_member_of_all_groups(provider.conn_ref(), &msg.sender_inbox_id) .await?; // the reply message @@ -455,7 +454,7 @@ where (content_bytes, contents) }; - sync_group.prepare_message(&content_bytes, conn, |_time_ns| PlaintextEnvelope { + sync_group.prepare_message(&content_bytes, provider, |_time_ns| PlaintextEnvelope { content: Some(Content::V2(V2 { idempotency_key: new_request_id(), message_type: Some(MessageType::DeviceSyncReply(contents)), @@ -549,7 +548,7 @@ where self.insert_encrypted_syncables(provider, enc_payload, &enc_key.try_into()?) .await?; - self.sync_welcomes(provider.conn_ref()).await?; + self.sync_welcomes(provider).await?; let groups = conn.find_groups(GroupQueryArgs::default().conversation_type(ConversationType::Group))?; diff --git a/xmtp_mls/src/groups/device_sync/consent_sync.rs b/xmtp_mls/src/groups/device_sync/consent_sync.rs index 0075c1230..da9f446a2 100644 --- a/xmtp_mls/src/groups/device_sync/consent_sync.rs +++ b/xmtp_mls/src/groups/device_sync/consent_sync.rs @@ -24,7 +24,6 @@ where "Streaming consent update. {:?}", record ); - let conn = provider.conn_ref(); let consent_update_proto = ConsentUpdateProto { entity: record.entity.clone(), @@ -42,7 +41,7 @@ where let sync_group = self.ensure_sync_group(provider).await?; let content_bytes = serde_json::to_vec(&consent_update_proto)?; - sync_group.prepare_message(&content_bytes, conn, |_time_ns| PlaintextEnvelope { + sync_group.prepare_message(&content_bytes, provider, |_time_ns| PlaintextEnvelope { content: Some(Content::V2(V2 { idempotency_key: new_request_id(), message_type: Some(MessageType::ConsentUpdate(consent_update_proto)), @@ -126,7 +125,7 @@ pub(crate) mod tests { let old_group_id = amal_a.get_sync_group(amal_a_conn).unwrap().group_id; tracing::info!("Old Group Id: {}", hex::encode(&old_group_id)); // Check for new welcomes to new groups in the first installation (should be welcomed to a new sync group from amal_b). - amal_a.sync_welcomes(amal_a_conn).await.unwrap(); + amal_a.sync_welcomes(&amal_a_provider).await.unwrap(); let new_group_id = amal_a.get_sync_group(amal_a_conn).unwrap().group_id; tracing::info!("New Group Id: {}", hex::encode(&new_group_id)); // group id should have changed to the new sync group created by the second installation diff --git a/xmtp_mls/src/groups/device_sync/message_sync.rs b/xmtp_mls/src/groups/device_sync/message_sync.rs index 056f0c3e7..66e63e261 100644 --- a/xmtp_mls/src/groups/device_sync/message_sync.rs +++ b/xmtp_mls/src/groups/device_sync/message_sync.rs @@ -106,7 +106,7 @@ pub(crate) mod tests { let old_group_id = amal_a.get_sync_group(amal_a_conn).unwrap().group_id; // Check for new welcomes to new groups in the first installation (should be welcomed to a new sync group from amal_b). amal_a - .sync_welcomes(amal_a_conn) + .sync_welcomes(&amal_a_provider) .await .expect("sync_welcomes"); let new_group_id = amal_a.get_sync_group(amal_a_conn).unwrap().group_id; @@ -205,7 +205,7 @@ pub(crate) mod tests { // Check for new welcomes to new groups in the first installation (should be welcomed to a new sync group from amal_b). amal_a - .sync_welcomes(amal_a_conn) + .sync_welcomes(&amal_a_provider) .await .expect("sync_welcomes"); let new_group_id = amal_a.get_sync_group(amal_a_conn).unwrap().group_id; @@ -262,7 +262,7 @@ pub(crate) mod tests { async fn test_externals_cant_join_sync_group() { let wallet = generate_local_wallet(); let amal = ClientBuilder::new_test_client_with_history(&wallet, HISTORY_SYNC_URL).await; - amal.sync_welcomes(&amal.store().conn().unwrap()) + amal.sync_welcomes(&amal.mls_provider().unwrap()) .await .expect("sync welcomes"); @@ -271,7 +271,7 @@ pub(crate) mod tests { ClientBuilder::new_test_client_with_history(&bo_wallet, HISTORY_SYNC_URL).await; bo_client - .sync_welcomes(&bo_client.store().conn().unwrap()) + .sync_welcomes(&bo_client.mls_provider().unwrap()) .await .expect("sync welcomes"); diff --git a/xmtp_mls/src/groups/intents.rs b/xmtp_mls/src/groups/intents.rs index 80f5eb167..4802afa94 100644 --- a/xmtp_mls/src/groups/intents.rs +++ b/xmtp_mls/src/groups/intents.rs @@ -33,6 +33,7 @@ use crate::{ }, types::Address, verified_key_package_v2::{KeyPackageVerificationError, VerifiedKeyPackageV2}, + XmtpOpenMlsProvider, }; use super::{ @@ -58,16 +59,18 @@ pub enum IntentError { impl MlsGroup { pub fn queue_intent( &self, + provider: &XmtpOpenMlsProvider, intent_kind: IntentKind, intent_data: Vec, ) -> Result { - self.context().store().transaction(|provider| { + self.context().store().transaction(provider, |provider| { let conn = provider.conn_ref(); self.queue_intent_with_conn(conn, intent_kind, intent_data) }) } - pub fn queue_intent_with_conn( + /// NOTE: Dangerous to use without a transaction + fn queue_intent_with_conn( &self, conn: &DbConnection, intent_kind: IntentKind, @@ -800,7 +803,7 @@ pub(crate) mod tests { // Client B sends a message to Client A let groups_b = client_b - .sync_welcomes(&client_b.store().conn().unwrap()) + .sync_welcomes(&client_b.mls_provider().unwrap()) .await .unwrap(); assert_eq!(groups_b.len(), 1); diff --git a/xmtp_mls/src/groups/mls_sync.rs b/xmtp_mls/src/groups/mls_sync.rs index a47ec3f9a..7cdb692b1 100644 --- a/xmtp_mls/src/groups/mls_sync.rs +++ b/xmtp_mls/src/groups/mls_sync.rs @@ -170,6 +170,7 @@ where let mls_provider = XmtpOpenMlsProvider::from(conn); tracing::info!( inbox_id = self.client.inbox_id(), + installation_id = %self.client.installation_id(), group_id = hex::encode(&self.group_id), current_epoch = self.load_mls_group(&mls_provider)?.epoch().as_u64(), "[{}] syncing group", @@ -177,6 +178,7 @@ where ); tracing::info!( inbox_id = self.client.inbox_id(), + installation_id = %self.client.installation_id(), group_id = hex::encode(&self.group_id), current_epoch = self.load_mls_group(&mls_provider)?.epoch().as_u64(), "current epoch for [{}] in sync() is Epoch: [{}]", @@ -363,7 +365,7 @@ where let group_epoch = openmls_group.epoch(); debug!( inbox_id = self.client.inbox_id(), - installation_id = hex::encode(self.client.installation_id()), + installation_id = %self.client.installation_id(), group_id = hex::encode(&self.group_id), current_epoch = openmls_group.epoch().as_u64(), msg_id, @@ -392,6 +394,7 @@ where if published_in_epoch_u64 != group_epoch_u64 { tracing::warn!( inbox_id = self.client.inbox_id(), + installation_id = %self.client.installation_id(), group_id = hex::encode(&self.group_id), current_epoch = openmls_group.epoch().as_u64(), msg_id, @@ -491,6 +494,7 @@ where tracing::info!( inbox_id = self.client.inbox_id(), + installation_id = %self.client.installation_id(), sender_inbox_id = sender_inbox_id, sender_installation_id = hex::encode(&sender_installation_id), group_id = hex::encode(&self.group_id), @@ -512,6 +516,8 @@ where tracing::info!( inbox_id = self.client.inbox_id(), sender_inbox_id = sender_inbox_id, + sender_installation_id = hex::encode(&sender_installation_id), + installation_id = %self.client.installation_id(), group_id = hex::encode(&self.group_id), current_epoch = openmls_group.epoch().as_u64(), msg_epoch, @@ -636,6 +642,7 @@ where tracing::info!( inbox_id = self.client.inbox_id(), sender_inbox_id = sender_inbox_id, + installation_id = %self.client.installation_id(), sender_installation_id = hex::encode(&sender_installation_id), group_id = hex::encode(&self.group_id), current_epoch = openmls_group.epoch().as_u64(), @@ -659,6 +666,7 @@ where tracing::info!( inbox_id = self.client.inbox_id(), sender_inbox_id = sender_inbox_id, + installation_id = %self.client.installation_id(), sender_installation_id = hex::encode(&sender_installation_id), group_id = hex::encode(&self.group_id), current_epoch = openmls_group.epoch().as_u64(), @@ -705,6 +713,7 @@ where .find_group_intent_by_payload_hash(sha256(envelope.data.as_slice())); tracing::info!( inbox_id = self.client.inbox_id(), + installation_id = %self.client.installation_id(), group_id = hex::encode(&self.group_id), current_epoch = openmls_group.epoch().as_u64(), msg_id = envelope.id, @@ -718,6 +727,7 @@ where let intent_id = intent.id; tracing::info!( inbox_id = self.client.inbox_id(), + installation_id = %self.client.installation_id(), group_id = hex::encode(&self.group_id), current_epoch = openmls_group.epoch().as_u64(), msg_id = envelope.id, @@ -752,6 +762,7 @@ where Ok(None) => { tracing::info!( inbox_id = self.client.inbox_id(), + installation_id = %self.client.installation_id(), group_id = hex::encode(&self.group_id), current_epoch = openmls_group.epoch().as_u64(), msg_id = envelope.id, @@ -769,9 +780,9 @@ where #[tracing::instrument(level = "trace", skip_all)] async fn consume_message( &self, + provider: &XmtpOpenMlsProvider, envelope: &GroupMessage, openmls_group: &mut OpenMlsGroup, - conn: &DbConnection, ) -> Result<(), GroupMessageProcessingError> { let msgv1 = match &envelope.version { Some(GroupMessageVersion::V1(value)) => value, @@ -784,12 +795,15 @@ where _ => EntityKind::Group, }; - let last_cursor = conn.get_last_cursor_for_id(&self.group_id, message_entity_kind)?; + let last_cursor = provider + .conn_ref() + .get_last_cursor_for_id(&self.group_id, message_entity_kind)?; tracing::info!("### last cursor --> [{:?}]", last_cursor); let should_skip_message = last_cursor > msgv1.id as i64; if should_skip_message { tracing::info!( inbox_id = "self.inbox_id()", + installation_id = %self.client.installation_id(), group_id = hex::encode(&self.group_id), "Message already processed: skipped msgId:[{}] entity kind:[{:?}] last cursor in db: [{}]", msgv1.id, @@ -801,6 +815,7 @@ where self.client .intents() .process_for_id( + provider, &msgv1.group_id, EntityKind::Group, msgv1.id, @@ -828,7 +843,7 @@ where let result = retry_async!( Retry::default(), (async { - self.consume_message(&message, &mut openmls_group, provider.conn_ref()) + self.consume_message(provider, &message, &mut openmls_group) .await }) ); @@ -852,7 +867,13 @@ where if receive_errors.is_empty() { Ok(()) } else { - tracing::error!("Message processing errors: {:?}", receive_errors); + tracing::error!( + group_id = hex::encode(&self.group_id), + inbox_id = self.client.inbox_id(), + installation_id = hex::encode(self.client.installation_id()), + "Message processing errors: {:?}", + receive_errors + ); Err(GroupError::ReceiveErrors(receive_errors)) } } @@ -944,6 +965,7 @@ where intent.id, intent.kind = %intent.kind, inbox_id = self.client.inbox_id(), + installation_id = %self.client.installation_id(), group_id = hex::encode(&self.group_id), "intent {} has reached max publish attempts", intent.id); // TODO: Eventually clean up errored attempts @@ -974,7 +996,7 @@ where )?; tracing::debug!( inbox_id = self.client.inbox_id(), - installation_id = hex::encode(self.client.installation_id()), + installation_id = %self.client.installation_id(), intent.id, intent.kind = %intent.kind, group_id = hex::encode(&self.group_id), @@ -992,7 +1014,7 @@ where intent.id, intent.kind = %intent.kind, inbox_id = self.client.inbox_id(), - installation_id = hex::encode(self.client.installation_id()), + installation_id = %self.client.installation_id(), group_id = hex::encode(&self.group_id), "[{}] published intent [{}] of type [{}]", self.client.inbox_id(), @@ -1007,7 +1029,7 @@ where Ok(None) => { tracing::info!( inbox_id = self.client.inbox_id(), - installation_id = hex::encode(self.client.installation_id()), + installation_id = %self.client.installation_id(), "Skipping intent because no publish data returned" ); let deleter: &dyn Delete = provider.conn_ref(); @@ -1148,7 +1170,7 @@ where if let Some(post_commit_data) = intent.post_commit_data { tracing::debug!( inbox_id = self.client.inbox_id(), - installation_id = hex::encode(self.client.installation_id()), + installation_id = %self.client.installation_id(), intent.id, intent.kind = %intent.kind, "taking post commit action" ); @@ -1216,13 +1238,13 @@ where debug!( inbox_id = self.client.inbox_id(), - installation_id = hex::encode(self.client.installation_id()), + installation_id = %self.client.installation_id(), "Adding missing installations {:?}", intent_data ); - let intent = self.queue_intent_with_conn( - provider.conn_ref(), + let intent = self.queue_intent( + provider, IntentKind::UpdateGroupMembership, intent_data.into(), )?; diff --git a/xmtp_mls/src/groups/mod.rs b/xmtp_mls/src/groups/mod.rs index 12440397c..9c5e59510 100644 --- a/xmtp_mls/src/groups/mod.rs +++ b/xmtp_mls/src/groups/mod.rs @@ -90,7 +90,7 @@ use crate::{ group::{ConversationType, GroupMembershipState, StoredGroup}, group_intent::IntentKind, group_message::{DeliveryStatus, GroupMessageKind, MsgQueryArgs, StoredGroupMessage}, - sql_key_store, + sql_key_store, StorageError, }, subscriptions::{LocalEventError, LocalEvents}, utils::{id::calculate_message_id, time::now_ns}, @@ -262,6 +262,7 @@ pub struct MlsGroup { pub group_id: Vec, pub created_at_ns: i64, pub client: Arc, + // provider: XmtpOpenMlsProvider, mutex: Arc>, } @@ -318,8 +319,8 @@ impl MlsGroup { /// Instantiate a new [`XmtpOpenMlsProvider`] pulling a connection from the database. /// prefer to use an already-instantiated mls provider if possible. - pub fn mls_provider(&self) -> Result { - Ok(self.context().mls_provider()?) + pub fn mls_provider(&self) -> Result { + self.context().mls_provider() } // Load the stored OpenMLS group from the OpenMLS provider's keystore @@ -607,9 +608,8 @@ impl MlsGroup { self.maybe_update_installations(provider, update_interval_ns) .await?; - let message_id = self.prepare_message(message, provider.conn_ref(), |now| { - Self::into_envelope(message, now) - }); + let message_id = + self.prepare_message(message, provider, |now| Self::into_envelope(message, now)); self.sync_until_last_intent_resolved(provider).await?; @@ -647,9 +647,9 @@ impl MlsGroup { /// Send a message, optimistically returning the ID of the message before the result of a message publish. pub fn send_message_optimistic(&self, message: &[u8]) -> Result, GroupError> { - let conn = self.context().store().conn()?; + let provider = self.mls_provider()?; let message_id = - self.prepare_message(message, &conn, |now| Self::into_envelope(message, now))?; + self.prepare_message(message, &provider, |now| Self::into_envelope(message, now))?; Ok(message_id) } @@ -663,7 +663,7 @@ impl MlsGroup { fn prepare_message( &self, message: &[u8], - conn: &DbConnection, + provider: &XmtpOpenMlsProvider, envelope: F, ) -> Result, GroupError> where @@ -677,7 +677,7 @@ impl MlsGroup { .map_err(GroupError::EncodeError)?; let intent_data: Vec = SendMessageIntentData::new(encoded_envelope).into(); - self.queue_intent_with_conn(conn, IntentKind::SendMessage, intent_data)?; + self.queue_intent(provider, IntentKind::SendMessage, intent_data)?; // store this unpublished message locally before sending let message_id = calculate_message_id(&self.group_id, message, &now.to_string()); @@ -691,7 +691,7 @@ impl MlsGroup { sender_inbox_id: self.context().inbox_id().to_string(), delivery_status: DeliveryStatus::Unpublished, }; - group_message.store(conn)?; + group_message.store(provider.conn_ref())?; Ok(message_id) } @@ -767,8 +767,8 @@ impl MlsGroup { return Ok(()); } - let intent = self.queue_intent_with_conn( - provider.conn_ref(), + let intent = self.queue_intent( + &provider, IntentKind::UpdateGroupMembership, intent_data.into(), )?; @@ -816,8 +816,8 @@ impl MlsGroup { .get_membership_update_intent(&provider, &[], inbox_ids) .await?; - let intent = self.queue_intent_with_conn( - provider.conn_ref(), + let intent = self.queue_intent( + &provider, IntentKind::UpdateGroupMembership, intent_data.into(), )?; @@ -834,7 +834,7 @@ impl MlsGroup { } let intent_data: Vec = UpdateMetadataIntentData::new_update_group_name(group_name).into(); - let intent = self.queue_intent(IntentKind::MetadataUpdate, intent_data)?; + let intent = self.queue_intent(&provider, IntentKind::MetadataUpdate, intent_data)?; self.sync_until_intent_resolved(&provider, intent.id).await } @@ -863,7 +863,7 @@ impl MlsGroup { ) .into(); - let intent = self.queue_intent(IntentKind::UpdatePermission, intent_data)?; + let intent = self.queue_intent(&provider, IntentKind::UpdatePermission, intent_data)?; self.sync_until_intent_resolved(&provider, intent.id).await } @@ -893,7 +893,7 @@ impl MlsGroup { } let intent_data: Vec = UpdateMetadataIntentData::new_update_group_description(group_description).into(); - let intent = self.queue_intent(IntentKind::MetadataUpdate, intent_data)?; + let intent = self.queue_intent(&provider, IntentKind::MetadataUpdate, intent_data)?; self.sync_until_intent_resolved(&provider, intent.id).await } @@ -923,7 +923,7 @@ impl MlsGroup { let intent_data: Vec = UpdateMetadataIntentData::new_update_group_image_url_square(group_image_url_square) .into(); - let intent = self.queue_intent(IntentKind::MetadataUpdate, intent_data)?; + let intent = self.queue_intent(&provider, IntentKind::MetadataUpdate, intent_data)?; self.sync_until_intent_resolved(&provider, intent.id).await } @@ -955,7 +955,7 @@ impl MlsGroup { } let intent_data: Vec = UpdateMetadataIntentData::new_update_group_pinned_frame_url(pinned_frame_url).into(); - let intent = self.queue_intent(IntentKind::MetadataUpdate, intent_data)?; + let intent = self.queue_intent(&provider, IntentKind::MetadataUpdate, intent_data)?; self.sync_until_intent_resolved(&provider, intent.id).await } @@ -1038,7 +1038,7 @@ impl MlsGroup { }; let intent_data: Vec = UpdateAdminListIntentData::new(intent_action_type, inbox_id).into(); - let intent = self.queue_intent(IntentKind::UpdateAdminList, intent_data)?; + let intent = self.queue_intent(&provider, IntentKind::UpdateAdminList, intent_data)?; self.sync_until_intent_resolved(&provider, intent.id).await } @@ -1101,9 +1101,9 @@ impl MlsGroup { /// Update this installation's leaf key in the group by creating a key update commit pub async fn key_update(&self) -> Result<(), GroupError> { - let intent = self.queue_intent(IntentKind::KeyUpdate, vec![])?; - self.sync_until_intent_resolved(&self.client.mls_provider()?, intent.id) - .await + let provider = self.client.mls_provider()?; + let intent = self.queue_intent(&provider, IntentKind::KeyUpdate, vec![])?; + self.sync_until_intent_resolved(&provider, intent.id).await } /// Checks if the the current user is active in the group. @@ -1624,7 +1624,7 @@ pub(crate) mod tests { async fn receive_group_invite(client: &FullXmtpClient) -> MlsGroup { client - .sync_welcomes(&client.store().conn().unwrap()) + .sync_welcomes(&client.mls_provider().unwrap()) .await .unwrap(); let mut groups = client.find_groups(GroupQueryArgs::default()).unwrap(); @@ -1772,7 +1772,7 @@ pub(crate) mod tests { // Get bola's version of the same group let bola_groups = bola - .sync_welcomes(&bola.store().conn().unwrap()) + .sync_welcomes(&bola.mls_provider().unwrap()) .await .unwrap(); let bola_group = bola_groups.first().unwrap(); @@ -1831,7 +1831,7 @@ pub(crate) mod tests { // Get bola's version of the same group let bola_groups = bola - .sync_welcomes(&bola.store().conn().unwrap()) + .sync_welcomes(&bola.mls_provider().unwrap()) .await .unwrap(); let bola_group = bola_groups.first().unwrap(); @@ -1941,7 +1941,7 @@ pub(crate) mod tests { force_add_member(&alix, &bo, &alix_group, &mut mls_group, &provider).await; // Bo should not be able to actually read this group - bo.sync_welcomes(&bo.store().conn().unwrap()).await.unwrap(); + bo.sync_welcomes(&bo.mls_provider().unwrap()).await.unwrap(); let groups = bo.find_groups(GroupQueryArgs::default()).unwrap(); assert_eq!(groups.len(), 0); assert_logged!("failed to create group from welcome", 1); @@ -2069,7 +2069,7 @@ pub(crate) mod tests { group.send_message(b"hello").await.expect("send message"); bola_client - .sync_welcomes(&bola_client.store().conn().unwrap()) + .sync_welcomes(&bola_client.mls_provider().unwrap()) .await .unwrap(); let bola_groups = bola_client.find_groups(GroupQueryArgs::default()).unwrap(); @@ -2426,7 +2426,7 @@ pub(crate) mod tests { .add_members_by_inbox_id(&[bola.inbox_id()]) .await .unwrap(); - bola.sync_welcomes(&bola.store().conn().unwrap()) + bola.sync_welcomes(&bola.mls_provider().unwrap()) .await .unwrap(); @@ -2597,7 +2597,7 @@ pub(crate) mod tests { .add_members(&[bola_wallet.get_address()]) .await .unwrap(); - bola.sync_welcomes(&bola.store().conn().unwrap()) + bola.sync_welcomes(&bola.mls_provider().unwrap()) .await .unwrap(); let bola_groups = bola.find_groups(GroupQueryArgs::default()).unwrap(); @@ -2679,7 +2679,7 @@ pub(crate) mod tests { .add_members(&[bola_wallet.get_address()]) .await .unwrap(); - bola.sync_welcomes(&bola.store().conn().unwrap()) + bola.sync_welcomes(&bola.mls_provider().unwrap()) .await .unwrap(); let bola_groups = bola.find_groups(GroupQueryArgs::default()).unwrap(); @@ -2697,7 +2697,7 @@ pub(crate) mod tests { assert!(super_admin_list.contains(&amal.inbox_id().to_string())); // Verify that bola can not add caro because they are not an admin - bola.sync_welcomes(&bola.store().conn().unwrap()) + bola.sync_welcomes(&bola.mls_provider().unwrap()) .await .unwrap(); let bola_groups = bola.find_groups(GroupQueryArgs::default()).unwrap(); @@ -2766,7 +2766,7 @@ pub(crate) mod tests { .contains(&bola.inbox_id().to_string())); // Verify that bola can not add charlie because they are not an admin - bola.sync_welcomes(&bola.store().conn().unwrap()) + bola.sync_welcomes(&bola.mls_provider().unwrap()) .await .unwrap(); let bola_groups = bola.find_groups(GroupQueryArgs::default()).unwrap(); @@ -2797,7 +2797,7 @@ pub(crate) mod tests { .add_members_by_inbox_id(&[bola.inbox_id()]) .await .unwrap(); - bola.sync_welcomes(&bola.store().conn().unwrap()) + bola.sync_welcomes(&bola.mls_provider().unwrap()) .await .unwrap(); let bola_groups = bola.find_groups(GroupQueryArgs::default()).unwrap(); @@ -2815,7 +2815,7 @@ pub(crate) mod tests { assert!(super_admin_list.contains(&amal.inbox_id().to_string())); // Verify that bola can not add caro as an admin because they are not a super admin - bola.sync_welcomes(&bola.store().conn().unwrap()) + bola.sync_welcomes(&bola.mls_provider().unwrap()) .await .unwrap(); let bola_groups = bola.find_groups(GroupQueryArgs::default()).unwrap(); @@ -3009,7 +3009,7 @@ pub(crate) mod tests { // Bola syncs groups - this will decrypt the Welcome, identify who added Bola // and then store that value on the group and insert into the database let bola_groups = bola - .sync_welcomes(&bola.store().conn().unwrap()) + .sync_welcomes(&bola.mls_provider().unwrap()) .await .unwrap(); @@ -3078,7 +3078,7 @@ pub(crate) mod tests { .unwrap(); // Step 3: Verify that Bola can update the group name, and amal sees the update - bola.sync_welcomes(&bola.store().conn().unwrap()) + bola.sync_welcomes(&bola.mls_provider().unwrap()) .await .unwrap(); let bola_groups = bola.find_groups(GroupQueryArgs::default()).unwrap(); @@ -3146,7 +3146,7 @@ pub(crate) mod tests { // Step 3: Bola attemps to add Caro, but fails because group is admin only let caro = ClientBuilder::new_test_client(&generate_local_wallet()).await; - bola.sync_welcomes(&bola.store().conn().unwrap()) + bola.sync_welcomes(&bola.mls_provider().unwrap()) .await .unwrap(); let bola_groups = bola.find_groups(GroupQueryArgs::default()).unwrap(); @@ -3309,7 +3309,7 @@ pub(crate) mod tests { assert_eq!(members.len(), 2); // Bola can message amal - let _ = bola.sync_welcomes(&bola.store().conn().unwrap()).await; + let _ = bola.sync_welcomes(&bola.mls_provider().unwrap()).await; let bola_groups = bola.find_groups(GroupQueryArgs::default()).unwrap(); let bola_dm: &MlsGroup<_> = bola_groups.first().unwrap(); @@ -3425,7 +3425,7 @@ pub(crate) mod tests { let alix_message = vec![1]; alix_group.send_message(&alix_message).await.unwrap(); bo_client - .sync_welcomes(&bo_client.store().conn().unwrap()) + .sync_welcomes(&bo_client.mls_provider().unwrap()) .await .unwrap(); let bo_groups = bo_client.find_groups(GroupQueryArgs::default()).unwrap(); @@ -3549,7 +3549,11 @@ pub(crate) mod tests { } group - .queue_intent(IntentKind::UpdateGroupMembership, intent_data.into()) + .queue_intent( + provider, + IntentKind::UpdateGroupMembership, + intent_data.into(), + ) .unwrap(); } @@ -3708,7 +3712,7 @@ pub(crate) mod tests { .await .unwrap(); - bola.sync_welcomes(&bola.store().conn().unwrap()) + bola.sync_welcomes(&bola.mls_provider().unwrap()) .await .unwrap(); let bola_groups = bola.find_groups(GroupQueryArgs::default()).unwrap(); @@ -3729,7 +3733,7 @@ pub(crate) mod tests { .await .unwrap(); - caro.sync_welcomes(&caro.store().conn().unwrap()) + caro.sync_welcomes(&caro.mls_provider().unwrap()) .await .unwrap(); let caro_groups = caro.find_groups(GroupQueryArgs::default()).unwrap(); @@ -3762,7 +3766,7 @@ pub(crate) mod tests { .await .unwrap(); - bo.sync_welcomes(&bo.store().conn().unwrap()).await.unwrap(); + bo.sync_welcomes(&bo.mls_provider().unwrap()).await.unwrap(); let bo_groups = bo.find_groups(GroupQueryArgs::default()).unwrap(); let bo_group = bo_groups.first().unwrap(); diff --git a/xmtp_mls/src/groups/scoped_client.rs b/xmtp_mls/src/groups/scoped_client.rs index 808fee0e0..df358c2cf 100644 --- a/xmtp_mls/src/groups/scoped_client.rs +++ b/xmtp_mls/src/groups/scoped_client.rs @@ -4,8 +4,9 @@ use crate::{ client::{ClientError, XmtpMlsLocalContext}, identity_updates::{InstallationDiff, InstallationDiffError}, intents::Intents, - storage::{DbConnection, EncryptedMessageStore}, + storage::{DbConnection, EncryptedMessageStore, StorageError}, subscriptions::LocalEvents, + types::InstallationId, verified_key_package_v2::VerifiedKeyPackageV2, xmtp_openmls_provider::XmtpOpenMlsProvider, Client, @@ -37,11 +38,11 @@ pub trait LocalScopedGroupClient: Send + Sync + Sized { self.context_ref().inbox_id() } - fn installation_id(&self) -> &[u8] { + fn installation_id(&self) -> InstallationId { self.context_ref().installation_public_key() } - fn mls_provider(&self) -> Result { + fn mls_provider(&self) -> Result { self.context_ref().mls_provider() } @@ -109,7 +110,7 @@ pub trait ScopedGroupClient: Sized { self.context_ref().installation_public_key() } - fn mls_provider(&self) -> Result { + fn mls_provider(&self) -> Result { self.context_ref().mls_provider() } @@ -276,7 +277,7 @@ where (**self).context_ref() } - fn mls_provider(&self) -> Result { + fn mls_provider(&self) -> Result { (**self).mls_provider() } @@ -370,7 +371,7 @@ where (**self).context_ref() } - fn mls_provider(&self) -> Result { + fn mls_provider(&self) -> Result { (**self).mls_provider() } diff --git a/xmtp_mls/src/groups/subscriptions.rs b/xmtp_mls/src/groups/subscriptions.rs index 2caf9bec9..2b5740de2 100644 --- a/xmtp_mls/src/groups/subscriptions.rs +++ b/xmtp_mls/src/groups/subscriptions.rs @@ -14,6 +14,7 @@ use crate::storage::refresh_state::EntityKind; use crate::storage::StorageError; use crate::subscriptions::MessagesStreamInfo; use crate::subscriptions::SubscribeError; +use crate::XmtpOpenMlsProvider; use crate::{retry::Retry, retry_async}; use prost::Message; use xmtp_proto::xmtp::mls::api::v1::GroupMessage; @@ -22,6 +23,7 @@ impl MlsGroup { /// Internal stream processing function pub(crate) async fn process_stream_entry( &self, + provider: &XmtpOpenMlsProvider, envelope: GroupMessage, ) -> Result { let msgv1 = extract_message_v1(envelope)?; @@ -45,7 +47,7 @@ impl MlsGroup { let msgv1 = &msgv1; self.context() .store() - .transaction_async(|provider| async move { + .transaction_async(provider, |provider| async move { let mut openmls_group = self.load_mls_group(&provider)?; // Attempt processing immediately, but fail if the message is not an Application Message @@ -78,7 +80,7 @@ impl MlsGroup { ); // Swallow errors here, since another process may have successfully saved the message // to the DB - if let Err(err) = self.sync_with_conn(&self.client.mls_provider()?).await { + if let Err(err) = self.sync_with_conn(provider).await { tracing::warn!( inbox_id = self.client.inbox_id(), group_id = hex::encode(&self.group_id), @@ -108,10 +110,8 @@ impl MlsGroup { // Load the message from the DB to handle cases where it may have been already processed in // another thread - let new_message = self - .context() - .store() - .conn()? + let new_message = provider + .conn_ref() .get_group_message_by_timestamp(&self.group_id, created_ns as i64)? .ok_or(SubscribeError::GroupMessageNotFound)?; @@ -133,20 +133,22 @@ impl MlsGroup { /// Converts some `SubscribeError` variants to an Option, if they are inconsequential. pub async fn process_streamed_group_message( &self, + provider: &XmtpOpenMlsProvider, envelope_bytes: Vec, ) -> Result { let envelope = GroupMessage::decode(envelope_bytes.as_slice())?; - self.process_stream_entry(envelope).await + self.process_stream_entry(provider, envelope).await } - pub async fn stream( - &self, + pub async fn stream<'a>( + &'a self, + provider: Option<&'a XmtpOpenMlsProvider>, ) -> Result< - impl Stream> + use<'_, ScopedClient>, + impl Stream> + use<'a, ScopedClient>, ClientError, > where - ::ApiClient: XmtpMlsStreams + 'static, + ::ApiClient: XmtpMlsStreams + 'a, { let group_list = HashMap::from([( self.group_id.clone(), @@ -155,7 +157,7 @@ impl MlsGroup { cursor: 0, }, )]); - stream_messages(&*self.client, Arc::new(group_list)).await + stream_messages(provider, &*self.client, Arc::new(group_list)).await } pub fn stream_with_callback( @@ -180,14 +182,16 @@ impl MlsGroup { } /// Stream messages from groups in `group_id_to_info` +// TODO: Note when to use a None provider #[tracing::instrument(level = "debug", skip_all)] -pub(crate) async fn stream_messages( - client: &ScopedClient, +pub(crate) async fn stream_messages<'a, ScopedClient>( + provider: Option<&'a XmtpOpenMlsProvider>, + client: &'a ScopedClient, group_id_to_info: Arc, MessagesStreamInfo>>, -) -> Result> + '_, ClientError> +) -> Result> + 'a, ClientError> where ScopedClient: ScopedGroupClient, - ::ApiClient: XmtpApi + XmtpMlsStreams + 'static, + ::ApiClient: XmtpApi + XmtpMlsStreams + 'a, { let filters: Vec = group_id_to_info .iter() @@ -214,7 +218,13 @@ where "Received message for a non-subscribed group".to_string(), ))?; let mls_group = MlsGroup::new(client, group_id, stream_info.convo_created_at_ns); - mls_group.process_stream_entry(envelope).await + + if let Some(p) = provider { + mls_group.process_stream_entry(p, envelope).await + } else { + let provider = mls_group.mls_provider()?; + mls_group.process_stream_entry(&provider, envelope).await + } } }) .inspect(|e| { @@ -243,7 +253,7 @@ where let (tx, rx) = oneshot::channel(); crate::spawn(Some(rx), async move { - let stream = stream_messages(&client, Arc::new(group_id_to_info)).await?; + let stream = stream_messages(None, &client, Arc::new(group_id_to_info)).await?; futures::pin_mut!(stream); let _ = tx.send(()); while let Some(message) = stream.next().await { @@ -296,8 +306,9 @@ pub(crate) mod tests { let message = messages.first().unwrap(); let mut message_bytes: Vec = Vec::new(); message.encode(&mut message_bytes).unwrap(); + let provider = amal.mls_provider().unwrap(); let message_again = amal_group - .process_streamed_group_message(message_bytes) + .process_streamed_group_message(&provider, message_bytes) .await; if let Ok(message) = message_again { @@ -327,7 +338,7 @@ pub(crate) mod tests { // Get bola's version of the same group let bola_groups = bola - .sync_welcomes(&bola.store().conn().unwrap()) + .sync_welcomes(&bola.mls_provider().unwrap()) .await .unwrap(); let bola_group = Arc::new(bola_groups.first().unwrap().clone()); @@ -338,7 +349,7 @@ pub(crate) mod tests { let (tx, rx) = tokio::sync::mpsc::unbounded_channel(); let mut stream = UnboundedReceiverStream::new(rx); crate::spawn(None, async move { - let stream = bola_group_ptr.stream().await.unwrap(); + let stream = bola_group_ptr.stream(None).await.unwrap(); futures::pin_mut!(stream); while let Some(item) = stream.next().await { let _ = tx.send(item); @@ -380,7 +391,7 @@ pub(crate) mod tests { let stream = tokio_stream::wrappers::UnboundedReceiverStream::new(rx); let group_ptr = group.clone(); crate::spawn(None, async move { - let stream = group_ptr.stream().await.unwrap(); + let stream = group_ptr.stream(None).await.unwrap(); futures::pin_mut!(stream); while let Some(item) = stream.next().await { let _ = tx.send(item); @@ -427,7 +438,7 @@ pub(crate) mod tests { let (start_tx, start_rx) = tokio::sync::oneshot::channel(); let mut stream = UnboundedReceiverStream::new(rx); crate::spawn(None, async move { - let stream = amal_group_ptr.stream().await.unwrap(); + let stream = amal_group_ptr.stream(None).await.unwrap(); let _ = start_tx.send(()); futures::pin_mut!(stream); while let Some(item) = stream.next().await { diff --git a/xmtp_mls/src/intents.rs b/xmtp_mls/src/intents.rs index 7f54daa6d..f1000c56d 100644 --- a/xmtp_mls/src/intents.rs +++ b/xmtp_mls/src/intents.rs @@ -52,6 +52,7 @@ impl Intents { /// apply the update after the provided `ProcessingFn` has completed successfully. pub(crate) async fn process_for_id( &self, + provider: &XmtpOpenMlsProvider, entity_id: &[u8], entity_kind: EntityKind, cursor: u64, @@ -66,7 +67,7 @@ impl Intents { + std::fmt::Display, { self.store() - .transaction_async(|provider| async move { + .transaction_async(provider, |provider| async move { let is_updated = provider .conn_ref() diff --git a/xmtp_mls/src/storage/encrypted_store/db_connection.rs b/xmtp_mls/src/storage/encrypted_store/db_connection.rs index e2ef9a970..d0ef3942a 100644 --- a/xmtp_mls/src/storage/encrypted_store/db_connection.rs +++ b/xmtp_mls/src/storage/encrypted_store/db_connection.rs @@ -51,6 +51,8 @@ where self.inner.lock() } + /// Internal-only API to get the underlying `diesel::Connection` reference + /// without a scope pub(super) fn inner_ref(&self) -> Arc> { self.inner.clone() } diff --git a/xmtp_mls/src/storage/encrypted_store/mod.rs b/xmtp_mls/src/storage/encrypted_store/mod.rs index 58ff1eff8..cf6117c50 100644 --- a/xmtp_mls/src/storage/encrypted_store/mod.rs +++ b/xmtp_mls/src/storage/encrypted_store/mod.rs @@ -198,6 +198,7 @@ pub mod private { } /// Start a new database transaction with the OpenMLS Provider from XMTP + /// with the provided connection /// # Arguments /// `fun`: Scoped closure providing a MLSProvider to carry out the transaction /// @@ -210,22 +211,25 @@ pub mod private { /// provider.conn().db_operation()?; /// }) /// ``` - pub fn transaction(&self, fun: F) -> Result + pub fn transaction( + &self, + provider: &XmtpOpenMlsProviderPrivate<::Connection>, + fun: F, + ) -> Result where F: FnOnce(&XmtpOpenMlsProviderPrivate<::Connection>) -> Result, E: From + From, { tracing::debug!("Transaction beginning"); - let connection = self.db.conn()?; { + let connection = provider.conn_ref(); let mut connection = connection.inner_mut_ref(); ::TransactionManager::begin_transaction(&mut *connection)?; } - let provider = XmtpOpenMlsProviderPrivate::new(connection); let conn = provider.conn_ref(); - match fun(&provider) { + match fun(provider) { Ok(value) => { conn.raw_query(|conn| { ::TransactionManager::commit_transaction(&mut *conn) @@ -260,20 +264,27 @@ pub mod private { /// provider.conn().db_operation()?; /// }).await /// ``` - pub async fn transaction_async(&self, fun: F) -> Result + pub async fn transaction_async( + &self, + provider: &XmtpOpenMlsProviderPrivate<::Connection>, + fun: F, + ) -> Result where F: FnOnce(XmtpOpenMlsProviderPrivate<::Connection>) -> Fut, Fut: futures::Future>, E: From + From, { tracing::debug!("Transaction async beginning"); - let db_connection = self.db.conn()?; { - let mut connection = db_connection.inner_mut_ref(); + let connection = provider.conn_ref(); + let mut connection = connection.inner_mut_ref(); ::TransactionManager::begin_transaction(&mut *connection)?; } - let local_connection = db_connection.inner_ref(); - let provider = XmtpOpenMlsProviderPrivate::new(db_connection); + + // essentially manually cloning Provider + let connection = DbConnectionPrivate::from_arc_mutex(provider.conn_ref().inner_ref()); + let local_connection = connection.inner_ref(); + let provider = XmtpOpenMlsProviderPrivate::new(connection); // the other connection is dropped in the closure // ensuring we have only one strong reference @@ -467,7 +478,7 @@ pub(crate) mod tests { identity::StoredIdentity, }, utils::test::{rand_vec, tmp_path}, - Fetch, Store, StreamHandle as _, + Fetch, Store, StreamHandle as _, XmtpOpenMlsProvider, }; /// Test harness that loads an Ephemeral store. @@ -651,11 +662,11 @@ pub(crate) mod tests { .unwrap(); let barrier = Arc::new(Barrier::new(2)); - + let provider = XmtpOpenMlsProvider::new(store.conn().unwrap()); let store_pointer = store.clone(); let barrier_pointer = barrier.clone(); let handle = std::thread::spawn(move || { - store_pointer.transaction(|provider| { + store_pointer.transaction(&provider, |provider| { let conn1 = provider.conn_ref(); StoredIdentity::new("correct".to_string(), rand_vec(), rand_vec()) .store(conn1) @@ -669,20 +680,22 @@ pub(crate) mod tests { }); let store_pointer = store.clone(); + let provider = XmtpOpenMlsProvider::new(store.conn().unwrap()); let handle2 = std::thread::spawn(move || { barrier.wait(); - let result = store_pointer.transaction(|provider| -> Result<(), anyhow::Error> { - let connection = provider.conn_ref(); - let group = StoredGroup::new( - b"should not exist".to_vec(), - 0, - GroupMembershipState::Allowed, - "goodbye".to_string(), - None, - ); - group.store(connection)?; - Ok(()) - }); + let result = + store_pointer.transaction(&provider, |provider| -> Result<(), anyhow::Error> { + let connection = provider.conn_ref(); + let group = StoredGroup::new( + b"should not exist".to_vec(), + 0, + GroupMembershipState::Allowed, + "goodbye".to_string(), + None, + ); + group.store(connection)?; + Ok(()) + }); barrier.wait(); result }); @@ -718,10 +731,10 @@ pub(crate) mod tests { .unwrap(); let store_pointer = store.clone(); - + let provider = XmtpOpenMlsProvider::new(store_pointer.conn().unwrap()); let handle = crate::spawn(None, async move { store_pointer - .transaction_async(|provider| async move { + .transaction_async(&provider, |provider| async move { let conn1 = provider.conn_ref(); StoredIdentity::new("crab".to_string(), rand_vec(), rand_vec()) .store(conn1) diff --git a/xmtp_mls/src/subscriptions.rs b/xmtp_mls/src/subscriptions.rs index 67606dae8..c8987a910 100644 --- a/xmtp_mls/src/subscriptions.rs +++ b/xmtp_mls/src/subscriptions.rs @@ -12,7 +12,10 @@ use xmtp_proto::{api_client::XmtpMlsStreams, xmtp::mls::api::v1::WelcomeMessage} use crate::{ client::{extract_welcome_message, ClientError}, - groups::{mls_sync::GroupMessageProcessingError, subscriptions, GroupError, MlsGroup}, + groups::{ + group_metadata::GroupMetadata, mls_sync::GroupMessageProcessingError, + scoped_client::ScopedGroupClient as _, subscriptions, GroupError, MlsGroup, + }, retry::{Retry, RetryableError}, retry_async, retryable, storage::{ @@ -21,7 +24,7 @@ use crate::{ group_message::StoredGroupMessage, StorageError, }, - Client, XmtpApi, + Client, XmtpApi, XmtpOpenMlsProvider, }; use thiserror::Error; @@ -204,6 +207,7 @@ where { async fn process_streamed_welcome( &self, + provider: &XmtpOpenMlsProvider, welcome: WelcomeMessage, ) -> Result, ClientError> { let welcome_v1 = extract_welcome_message(welcome)?; @@ -217,7 +221,7 @@ where let welcome_v1 = &welcome_v1; self.context .store() - .transaction_async(|provider| async move { + .transaction_async(provider, |provider| async move { MlsGroup::create_from_encrypted_welcome( Arc::new(self.clone()), &provider, @@ -232,7 +236,7 @@ where ); if let Some(err) = creation_result.as_ref().err() { - let conn = self.context.store().conn()?; + let conn = provider.conn_ref(); let result = conn.find_group_by_welcome_id(welcome_v1.id as i64); match result { Ok(Some(group)) => { @@ -257,71 +261,88 @@ where &self, envelope_bytes: Vec, ) -> Result, ClientError> { + let provider = self.mls_provider()?; let envelope = WelcomeMessage::decode(envelope_bytes.as_slice()) .map_err(|e| ClientError::Generic(e.to_string()))?; - let welcome = self.process_streamed_welcome(envelope).await?; + let welcome = self.process_streamed_welcome(&provider, envelope).await?; Ok(welcome) } #[tracing::instrument(level = "debug", skip_all)] - pub async fn stream_conversations( - &self, + pub async fn stream_conversations<'a>( + &'a self, + provider: Option<&'a XmtpOpenMlsProvider>, conversation_type: Option, - ) -> Result, SubscribeError>> + '_, ClientError> + ) -> Result, SubscribeError>> + 'a, ClientError> where ApiClient: XmtpMlsStreams, { - let event_queue = tokio_stream::wrappers::BroadcastStream::new( - self.local_events.subscribe(), - ) - .filter_map(|event| async { - crate::optify!(event, "Missed messages due to event queue lag") - .and_then(LocalEvents::group_filter) - .map(Result::Ok) - }); - - // Helper function for filtering Dm groups - let filter_group = move |group: Result, ClientError>| { - let conversation_type = &conversation_type; - // take care of any possible errors - let result = || -> Result<_, _> { - let group = group?; - let provider = group.client.context().mls_provider()?; - let metadata = group.metadata(&provider)?; - Ok((metadata, group)) - }; - let filtered = result().map(|(metadata, group)| { - conversation_type - .map_or(true, |ct| ct == metadata.conversation_type) - .then_some(group) - }); - futures::future::ready(filtered.transpose()) - }; - let installation_key = self.installation_public_key(); let id_cursor = 0; tracing::info!(inbox_id = self.inbox_id(), "Setting up conversation stream"); let subscription = self .api_client - .subscribe_welcome_messages(installation_key.into(), Some(id_cursor)) - .await?; + .subscribe_welcome_messages(installation_key.as_ref(), Some(id_cursor)) + .await? + .map(WelcomeOrGroup::::Welcome); + + let event_queue = + tokio_stream::wrappers::BroadcastStream::new(self.local_events.subscribe()) + .filter_map(|event| async { + crate::optify!(event, "Missed messages due to event queue lag") + .and_then(LocalEvents::group_filter) + .map(Result::Ok) + }) + .map(WelcomeOrGroup::::Group); + + let stream = futures::stream::select(event_queue, subscription); + let stream = stream.filter_map(move |either| async move { + tracing::info!( + inbox_id = self.inbox_id(), + installation_id = %self.installation_id(), + "Received conversation streaming payload" + ); + + let filtered = if let Some(p) = provider { + self.process_streamed_convo(p, either).await + } else { + let provider = self.mls_provider().ok()?; + self.process_streamed_convo(&provider, either).await + }; + let filtered = filtered.map(|(metadata, group)| { + conversation_type + .map_or(true, |ct| ct == metadata.conversation_type) + .then_some(group) + }); + filtered.transpose() + }); - let stream = subscription - .map(|welcome| async { - tracing::info!( - inbox_id = self.inbox_id(), - "Received conversation streaming payload" - ); - self.process_streamed_welcome(welcome?).await - }) - .filter_map(|v| async { Some(v.await) }); + Ok(stream) + } - Ok(futures::stream::select(stream, event_queue).filter_map(filter_group)) + async fn process_streamed_convo<'a, 'b>( + &'a self, + provider: &'b XmtpOpenMlsProvider, + welcome_or_group: WelcomeOrGroup, + ) -> Result<(GroupMetadata, MlsGroup>), SubscribeError> { + let group = match welcome_or_group { + WelcomeOrGroup::Welcome(welcome) => { + self.process_streamed_welcome(provider, welcome?).await? + } + WelcomeOrGroup::Group(group) => group?, + }; + let metadata = group.metadata(provider)?; + Ok((metadata, group)) } } +enum WelcomeOrGroup { + Group(Result>, SubscribeError>), + Welcome(Result), +} + impl Client where ApiClient: XmtpApi + XmtpMlsStreams + Send + Sync + 'static, @@ -335,7 +356,7 @@ where let (tx, rx) = oneshot::channel(); crate::spawn(Some(rx), async move { - let stream = client.stream_conversations(conversation_type).await?; + let stream = client.stream_conversations(None, conversation_type).await?; futures::pin_mut!(stream); let _ = tx.send(()); while let Some(convo) = stream.next().await { @@ -359,12 +380,11 @@ where "stream all messages" ); - let conn = self.store().conn()?; - self.sync_welcomes(&conn).await?; + let provider = self.mls_provider()?; + self.sync_welcomes(&provider).await?; - let mut group_id_to_info = self - .store() - .conn()? + let mut group_id_to_info = provider + .conn_ref() .find_groups(GroupQueryArgs::default().maybe_conversation_type(conversation_type))? .into_iter() .map(Into::into) @@ -372,13 +392,14 @@ where let stream = async_stream::stream! { let messages_stream = subscriptions::stream_messages( + Some(&provider), self, Arc::new(group_id_to_info.clone()) ) .await?; futures::pin_mut!(messages_stream); - let convo_stream = self.stream_conversations(conversation_type).await?; + let convo_stream = self.stream_conversations(Some(&provider), conversation_type).await?; futures::pin_mut!(convo_stream); @@ -419,6 +440,7 @@ where }, ); let new_messages_stream = match subscriptions::stream_messages( + Some(&provider), self, Arc::new(group_id_to_info.clone()) ).await { @@ -531,7 +553,7 @@ pub(crate) mod tests { let mut stream = tokio_stream::wrappers::UnboundedReceiverStream::new(rx); let bob_ptr = bob.clone(); crate::spawn(None, async move { - let bob_stream = bob_ptr.stream_conversations(None).await.unwrap(); + let bob_stream = bob_ptr.stream_conversations(None, None).await.unwrap(); futures::pin_mut!(bob_stream); while let Some(item) = bob_stream.next().await { let _ = tx.send(item); @@ -567,7 +589,7 @@ pub(crate) mod tests { .await .unwrap(); let bob_group = bob - .sync_welcomes(&bob.store().conn().unwrap()) + .sync_welcomes(&bob.mls_provider().unwrap()) .await .unwrap(); let bob_group = bob_group.first().unwrap(); @@ -576,7 +598,7 @@ pub(crate) mod tests { let notify_ptr = notify.clone(); let (tx, rx) = tokio::sync::mpsc::unbounded_channel(); crate::spawn(None, async move { - let stream = alice_group.stream().await.unwrap(); + let stream = alice_group.stream(None).await.unwrap(); futures::pin_mut!(stream); while let Some(item) = stream.next().await { let _ = tx.send(item); @@ -876,7 +898,7 @@ pub(crate) mod tests { } // Verify syncing welcomes while streaming causes no issues - alix.sync_welcomes(&alix.store().conn().unwrap()) + alix.sync_welcomes(&alix.mls_provider().unwrap()) .await .unwrap(); let find_groups_results = alix.find_groups(GroupQueryArgs::default()).unwrap(); diff --git a/xmtp_mls/src/types.rs b/xmtp_mls/src/types.rs index 0a73a89b1..5a97a0ca3 100644 --- a/xmtp_mls/src/types.rs +++ b/xmtp_mls/src/types.rs @@ -1,2 +1,86 @@ pub type Address = String; -pub type InstallationId = String; + +use std::fmt; + +#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord)] +pub struct InstallationId([u8; 32]); + +impl fmt::Display for InstallationId { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", hex::encode(self.0)) + } +} + +impl std::ops::Deref for InstallationId { + type Target = [u8; 32]; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl AsRef<[u8]> for InstallationId { + fn as_ref(&self) -> &[u8] { + &self.0 + } +} + +impl From for Vec { + fn from(value: InstallationId) -> Self { + value.0.to_vec() + } +} + +impl From<[u8; 32]> for InstallationId { + fn from(value: [u8; 32]) -> Self { + InstallationId(value) + } +} + +impl PartialEq> for InstallationId { + fn eq(&self, other: &Vec) -> bool { + self.0.eq(&other[..]) + } +} + +impl PartialEq for Vec { + fn eq(&self, other: &InstallationId) -> bool { + other.0.eq(&self[..]) + } +} + +impl PartialEq<&Vec> for InstallationId { + fn eq(&self, other: &&Vec) -> bool { + self.0.eq(&other[..]) + } +} + +impl PartialEq for &Vec { + fn eq(&self, other: &InstallationId) -> bool { + other.0.eq(&self[..]) + } +} + +impl PartialEq<[u8]> for InstallationId { + fn eq(&self, other: &[u8]) -> bool { + self.0.eq(other) + } +} + +impl PartialEq for [u8] { + fn eq(&self, other: &InstallationId) -> bool { + other.0.eq(self) + } +} + +impl PartialEq<[u8; 32]> for InstallationId { + fn eq(&self, other: &[u8; 32]) -> bool { + self.0.eq(other) + } +} + +impl PartialEq for [u8; 32] { + fn eq(&self, other: &InstallationId) -> bool { + other.0.eq(&self[..]) + } +} diff --git a/xmtp_mls/src/xmtp_openmls_provider.rs b/xmtp_mls/src/xmtp_openmls_provider.rs index 77dd51494..a5597ee31 100644 --- a/xmtp_mls/src/xmtp_openmls_provider.rs +++ b/xmtp_mls/src/xmtp_openmls_provider.rs @@ -19,7 +19,11 @@ impl XmtpOpenMlsProviderPrivate { } } - pub(crate) fn conn_ref(&self) -> &DbConnectionPrivate { + pub fn new_crypto() -> RustCrypto { + RustCrypto::default() + } + + pub fn conn_ref(&self) -> &DbConnectionPrivate { self.key_store.conn_ref() } }