diff --git a/bindings_ffi/src/mls.rs b/bindings_ffi/src/mls.rs index 339c9c23e..e8b1068e9 100644 --- a/bindings_ffi/src/mls.rs +++ b/bindings_ffi/src/mls.rs @@ -1317,13 +1317,13 @@ impl FfiConversation { Ok(()) } - pub async fn find_messages( + pub fn find_messages( &self, opts: FfiListMessagesOptions, ) -> Result, GenericError> { let delivery_status = opts.delivery_status.map(|status| status.into()); let direction = opts.direction.map(|dir| dir.into()); - let kind = match self.conversation_type().await? { + let kind = match self.conversation_type()? { FfiConversationType::Group => None, FfiConversationType::Dm => Some(GroupMessageKind::Application), FfiConversationType::Sync => None, @@ -1445,7 +1445,7 @@ impl FfiConversation { pub fn group_image_url_square(&self) -> Result { let provider = self.inner.mls_provider()?; - Ok(self.inner.group_image_url_square(&provider)?) + Ok(self.inner.group_image_url_square(provider)?) } pub async fn update_group_description( @@ -1461,7 +1461,7 @@ impl FfiConversation { pub fn group_description(&self) -> Result { let provider = self.inner.mls_provider()?; - Ok(self.inner.group_description(&provider)?) + Ok(self.inner.group_description(provider)?) 
} pub async fn update_group_pinned_frame_url( @@ -1593,9 +1593,9 @@ impl FfiConversation { self.inner.added_by_inbox_id().map_err(Into::into) } - pub async fn group_metadata(&self) -> Result, GenericError> { + pub fn group_metadata(&self) -> Result, GenericError> { let provider = self.inner.mls_provider()?; - let metadata = self.inner.metadata(&provider).await?; + let metadata = self.inner.metadata(provider)?; Ok(Arc::new(FfiConversationMetadata { inner: Arc::new(metadata), })) @@ -1605,9 +1605,9 @@ impl FfiConversation { self.inner.dm_inbox_id().map_err(Into::into) } - pub async fn conversation_type(&self) -> Result { + pub fn conversation_type(&self) -> Result { let provider = self.inner.mls_provider()?; - let conversation_type = self.inner.conversation_type(&provider).await?; + let conversation_type = self.inner.conversation_type(&provider)?; Ok(conversation_type.into()) } } @@ -2104,9 +2104,6 @@ mod tests { .await .unwrap(); - let conn = client.inner_client.context().store().conn().unwrap(); - conn.register_triggers(); - register_client(&ffi_inbox_owner, &client).await; client } @@ -2598,8 +2595,6 @@ mod tests { async fn test_can_stream_group_messages_for_updates() { let alix = new_test_client().await; let bo = new_test_client().await; - let alix_provider = alix.inner_client.mls_provider().unwrap(); - let bo_provider = bo.inner_client.mls_provider().unwrap(); // Stream all group messages let message_callbacks = Arc::new(RustStreamCallback::default()); @@ -2632,21 +2627,14 @@ mod tests { .unwrap(); let bo_group = &bo_groups[0]; bo_group.sync().await.unwrap(); - - // alix published + processed group creation and name update - assert_eq!(alix_provider.conn_ref().intents_published(), 2); - assert_eq!(alix_provider.conn_ref().intents_deleted(), 2); - bo_group .update_group_name("Old Name2".to_string()) .await .unwrap(); message_callbacks.wait_for_delivery(None).await.unwrap(); - assert_eq!(bo_provider.conn_ref().intents_published(), 1); alix_group.send(b"Hello 
there".to_vec()).await.unwrap(); message_callbacks.wait_for_delivery(None).await.unwrap(); - assert_eq!(alix_provider.conn_ref().intents_published(), 3); let dm = bo .conversations() @@ -2654,7 +2642,6 @@ mod tests { .await .unwrap(); dm.send(b"Hello again".to_vec()).await.unwrap(); - assert_eq!(bo_provider.conn_ref().intents_published(), 3); message_callbacks.wait_for_delivery(None).await.unwrap(); // Uncomment the following lines to add more group name updates @@ -2663,8 +2650,6 @@ mod tests { .await .unwrap(); message_callbacks.wait_for_delivery(None).await.unwrap(); - message_callbacks.wait_for_delivery(None).await.unwrap(); - assert_eq!(bo_provider.conn_ref().intents_published(), 4); assert_eq!(message_callbacks.message_count(), 6); @@ -2708,11 +2693,9 @@ mod tests { let bo_messages1 = bo_group1 .find_messages(FfiListMessagesOptions::default()) - .await .unwrap(); let bo_messages5 = bo_group5 .find_messages(FfiListMessagesOptions::default()) - .await .unwrap(); assert_eq!(bo_messages1.len(), 0); assert_eq!(bo_messages5.len(), 0); @@ -2724,11 +2707,9 @@ mod tests { let bo_messages1 = bo_group1 .find_messages(FfiListMessagesOptions::default()) - .await .unwrap(); let bo_messages5 = bo_group5 .find_messages(FfiListMessagesOptions::default()) - .await .unwrap(); assert_eq!(bo_messages1.len(), 1); assert_eq!(bo_messages5.len(), 1); @@ -2847,13 +2828,11 @@ mod tests { alix_group.sync().await.unwrap(); let alix_messages = alix_group .find_messages(FfiListMessagesOptions::default()) - .await .unwrap(); bo_group.sync().await.unwrap(); let bo_messages = bo_group .find_messages(FfiListMessagesOptions::default()) - .await .unwrap(); assert_eq!(bo_messages.len(), 9); assert_eq!(alix_messages.len(), 10); @@ -3037,19 +3016,15 @@ mod tests { // Get the message count for all the clients let caro_messages = caro_group .find_messages(FfiListMessagesOptions::default()) - .await .unwrap(); let alix_messages = alix_group .find_messages(FfiListMessagesOptions::default()) - .await 
.unwrap(); let bo_messages = bo_group .find_messages(FfiListMessagesOptions::default()) - .await .unwrap(); let bo2_messages = bo2_group .find_messages(FfiListMessagesOptions::default()) - .await .unwrap(); assert_eq!(caro_messages.len(), 5); @@ -3105,11 +3080,9 @@ mod tests { let alix_messages = alix_group .find_messages(FfiListMessagesOptions::default()) - .await .unwrap(); let bo_messages = bo_group .find_messages(FfiListMessagesOptions::default()) - .await .unwrap(); let alix_can_see_bo_message = alix_messages @@ -3216,7 +3189,6 @@ mod tests { let bo_messages = bo_group .find_messages(FfiListMessagesOptions::default()) - .await .unwrap(); assert_eq!(bo_messages.len(), 0); @@ -3232,12 +3204,8 @@ mod tests { let bo_messages = bo_group .find_messages(FfiListMessagesOptions::default()) - .await .unwrap(); - assert_eq!( - bo_messages.first().unwrap().kind, - FfiConversationMessageKind::MembershipChange - ); + assert!(bo_messages.first().unwrap().kind == FfiConversationMessageKind::MembershipChange); assert_eq!(bo_messages.len(), 1); let bo_members = bo_group.list_members().await.unwrap(); @@ -3295,7 +3263,6 @@ mod tests { let bo_messages1 = bo_group .find_messages(FfiListMessagesOptions::default()) - .await .unwrap(); assert_eq!(bo_messages1.len(), first_msg_check); @@ -3308,7 +3275,6 @@ mod tests { let alix_messages = alix_group .find_messages(FfiListMessagesOptions::default()) - .await .unwrap(); assert_eq!(alix_messages.len(), second_msg_check); @@ -3318,7 +3284,6 @@ mod tests { let bo_messages2 = bo_group .find_messages(FfiListMessagesOptions::default()) - .await .unwrap(); assert_eq!(bo_messages2.len(), second_msg_check); assert_eq!(message_callbacks.message_count(), second_msg_check as u32); @@ -4564,19 +4529,15 @@ mod tests { // Get messages for both participants in both conversations let alix_dm_messages = alix_dm .find_messages(FfiListMessagesOptions::default()) - .await .unwrap(); let bo_dm_messages = bo_dm .find_messages(FfiListMessagesOptions::default()) 
- .await .unwrap(); let alix_group_messages = alix_group .find_messages(FfiListMessagesOptions::default()) - .await .unwrap(); let bo_group_messages = bo_group .find_messages(FfiListMessagesOptions::default()) - .await .unwrap(); // Verify DM messages @@ -4697,7 +4658,6 @@ mod tests { .await .unwrap()[0] .find_messages(FfiListMessagesOptions::default()) - .await .unwrap(); let bo_dm_messages = client_b .conversations() @@ -4705,7 +4665,6 @@ mod tests { .await .unwrap()[0] .find_messages(FfiListMessagesOptions::default()) - .await .unwrap(); assert_eq!(alix_dm_messages[0].content, "Hello in DM".as_bytes()); assert_eq!(bo_dm_messages[0].content, "Hello in DM".as_bytes()); diff --git a/bindings_node/src/conversation.rs b/bindings_node/src/conversation.rs index b520af793..3cd1328be 100644 --- a/bindings_node/src/conversation.rs +++ b/bindings_node/src/conversation.rs @@ -161,7 +161,7 @@ impl Conversation { } #[napi] - pub async fn find_messages(&self, opts: Option) -> Result> { + pub fn find_messages(&self, opts: Option) -> Result> { let opts = opts.unwrap_or_default(); let group = MlsGroup::new( self.inner_client.clone(), @@ -171,7 +171,6 @@ impl Conversation { let provider = group.mls_provider().map_err(ErrorWrapper::from)?; let conversation_type = group .conversation_type(&provider) - .await .map_err(ErrorWrapper::from)?; let kind = match conversation_type { ConversationType::Group => None, @@ -251,7 +250,7 @@ impl Conversation { ); let admin_list = group - .admin_list(&group.mls_provider().map_err(ErrorWrapper::from)?) + .admin_list(group.mls_provider().map_err(ErrorWrapper::from)?) .map_err(ErrorWrapper::from)?; Ok(admin_list) @@ -266,7 +265,7 @@ impl Conversation { ); let super_admin_list = group - .super_admin_list(&group.mls_provider().map_err(ErrorWrapper::from)?) + .super_admin_list(group.mls_provider().map_err(ErrorWrapper::from)?) 
.map_err(ErrorWrapper::from)?; Ok(super_admin_list) @@ -452,7 +451,7 @@ impl Conversation { ); let group_name = group - .group_name(&group.mls_provider().map_err(ErrorWrapper::from)?) + .group_name(group.mls_provider().map_err(ErrorWrapper::from)?) .map_err(ErrorWrapper::from)?; Ok(group_name) @@ -483,7 +482,7 @@ impl Conversation { ); let group_image_url_square = group - .group_image_url_square(&group.mls_provider().map_err(ErrorWrapper::from)?) + .group_image_url_square(group.mls_provider().map_err(ErrorWrapper::from)?) .map_err(ErrorWrapper::from)?; Ok(group_image_url_square) @@ -514,7 +513,7 @@ impl Conversation { ); let group_description = group - .group_description(&group.mls_provider().map_err(ErrorWrapper::from)?) + .group_description(group.mls_provider().map_err(ErrorWrapper::from)?) .map_err(ErrorWrapper::from)?; Ok(group_description) @@ -545,7 +544,7 @@ impl Conversation { ); let group_pinned_frame_url = group - .group_pinned_frame_url(&group.mls_provider().map_err(ErrorWrapper::from)?) + .group_pinned_frame_url(group.mls_provider().map_err(ErrorWrapper::from)?) .map_err(ErrorWrapper::from)?; Ok(group_pinned_frame_url) @@ -588,7 +587,7 @@ impl Conversation { Ok( group - .is_active(&group.mls_provider().map_err(ErrorWrapper::from)?) + .is_active(group.mls_provider().map_err(ErrorWrapper::from)?) .map_err(ErrorWrapper::from)?, ) } @@ -605,7 +604,7 @@ impl Conversation { } #[napi] - pub async fn group_metadata(&self) -> Result { + pub fn group_metadata(&self) -> Result { let group = MlsGroup::new( self.inner_client.clone(), self.group_id.clone(), @@ -613,8 +612,7 @@ impl Conversation { ); let metadata = group - .metadata(&group.mls_provider().map_err(ErrorWrapper::from)?) - .await + .metadata(group.mls_provider().map_err(ErrorWrapper::from)?) 
.map_err(ErrorWrapper::from)?; Ok(GroupMetadata { inner: metadata }) diff --git a/bindings_node/test/Conversations.test.ts b/bindings_node/test/Conversations.test.ts index c6123bc6d..ee40b431d 100644 --- a/bindings_node/test/Conversations.test.ts +++ b/bindings_node/test/Conversations.test.ts @@ -54,16 +54,14 @@ describe('Conversations', () => { updateGroupPinnedFrameUrlPolicy: 0, }) expect(group.addedByInboxId()).toBe(client1.inboxId()) - expect((await group.findMessages()).length).toBe(1) + expect(group.findMessages().length).toBe(1) const members = await group.listMembers() expect(members.length).toBe(2) const memberInboxIds = members.map((member) => member.inboxId) expect(memberInboxIds).toContain(client1.inboxId()) expect(memberInboxIds).toContain(client2.inboxId()) - expect((await group.groupMetadata()).conversationType()).toBe('group') - expect((await group.groupMetadata()).creatorInboxId()).toBe( - client1.inboxId() - ) + expect(group.groupMetadata().conversationType()).toBe('group') + expect(group.groupMetadata().creatorInboxId()).toBe(client1.inboxId()) expect(group.consentState()).toBe(ConsentState.Allowed) @@ -200,16 +198,14 @@ describe('Conversations', () => { updateGroupPinnedFrameUrlPolicy: 0, }) expect(group.addedByInboxId()).toBe(client1.inboxId()) - expect((await group.findMessages()).length).toBe(0) + expect(group.findMessages().length).toBe(0) const members = await group.listMembers() expect(members.length).toBe(2) const memberInboxIds = members.map((member) => member.inboxId) expect(memberInboxIds).toContain(client1.inboxId()) expect(memberInboxIds).toContain(client2.inboxId()) - expect((await group.groupMetadata()).conversationType()).toBe('dm') - expect((await group.groupMetadata()).creatorInboxId()).toBe( - client1.inboxId() - ) + expect(group.groupMetadata().conversationType()).toBe('dm') + expect(group.groupMetadata().creatorInboxId()).toBe(client1.inboxId()) expect(group.consentState()).toBe(ConsentState.Allowed) diff --git 
a/bindings_wasm/src/conversation.rs b/bindings_wasm/src/conversation.rs index 42ca2f06e..af09d83b2 100644 --- a/bindings_wasm/src/conversation.rs +++ b/bindings_wasm/src/conversation.rs @@ -183,10 +183,7 @@ impl Conversation { } #[wasm_bindgen(js_name = findMessages)] - pub async fn find_messages( - &self, - opts: Option, - ) -> Result, JsError> { + pub fn find_messages(&self, opts: Option) -> Result, JsError> { let opts = opts.unwrap_or_default(); let group = self.to_mls_group(); let provider = group @@ -194,7 +191,6 @@ impl Conversation { .map_err(|e| JsError::new(&format!("{e}")))?; let conversation_type = group .conversation_type(&provider) - .await .map_err(|e| JsError::new(&format!("{e}")))?; let kind = match conversation_type { ConversationType::Group => None, @@ -245,7 +241,7 @@ impl Conversation { let group = self.to_mls_group(); let admin_list = group .admin_list( - &group + group .mls_provider() .map_err(|e| JsError::new(&format!("{e}")))?, ) @@ -259,7 +255,7 @@ impl Conversation { let group = self.to_mls_group(); let super_admin_list = group .super_admin_list( - &group + group .mls_provider() .map_err(|e| JsError::new(&format!("{e}")))?, ) @@ -405,7 +401,7 @@ impl Conversation { let group_name = group .group_name( - &group + group .mls_provider() .map_err(|e| JsError::new(&format!("{e}")))?, ) @@ -435,7 +431,7 @@ impl Conversation { let group_image_url_square = group .group_image_url_square( - &group + group .mls_provider() .map_err(|e| JsError::new(&format!("{e}")))?, ) @@ -462,7 +458,7 @@ impl Conversation { let group_description = group .group_description( - &group + group .mls_provider() .map_err(|e| JsError::new(&format!("{e}")))?, ) @@ -492,7 +488,7 @@ impl Conversation { let group_pinned_frame_url = group .group_pinned_frame_url( - &group + group .mls_provider() .map_err(|e| JsError::new(&format!("{e}")))?, ) @@ -512,7 +508,7 @@ impl Conversation { group .is_active( - &group + group .mls_provider() .map_err(|e| JsError::new(&format!("{e}")))?, ) 
@@ -529,15 +525,14 @@ impl Conversation { } #[wasm_bindgen(js_name = groupMetadata)] - pub async fn group_metadata(&self) -> Result { + pub fn group_metadata(&self) -> Result { let group = self.to_mls_group(); let metadata = group .metadata( - &group + group .mls_provider() .map_err(|e| JsError::new(&format!("{e}")))?, ) - .await .map_err(|e| JsError::new(&format!("{e}")))?; Ok(GroupMetadata { inner: metadata }) diff --git a/common/src/test.rs b/common/src/test.rs index 4cfb2442d..c11c692cb 100644 --- a/common/src/test.rs +++ b/common/src/test.rs @@ -38,7 +38,7 @@ pub fn logger() { .from_env_lossy() }; - let _ = tracing_subscriber::registry() + tracing_subscriber::registry() // structured JSON logger only if STRUCTURED=true .with(is_structured.then(|| { tracing_subscriber::fmt::layer() @@ -61,7 +61,7 @@ pub fn logger() { }) .with_filter(filter()) })) - .try_init(); + .init(); }); } diff --git a/examples/cli/serializable.rs b/examples/cli/serializable.rs index 545081638..c6ee793ce 100644 --- a/examples/cli/serializable.rs +++ b/examples/cli/serializable.rs @@ -31,12 +31,11 @@ impl SerializableGroup { let metadata = group .metadata( - &group + group .mls_provider() .expect("MLS Provider could not be created"), ) - .await - .unwrap(); + .expect("could not load metadata"); let permissions = group.permissions().expect("could not load permissions"); Self { diff --git a/xmtp_mls/src/client.rs b/xmtp_mls/src/client.rs index 02327f11b..6db9747da 100644 --- a/xmtp_mls/src/client.rs +++ b/xmtp_mls/src/client.rs @@ -873,6 +873,7 @@ where .map(|group| { let active_group_count = Arc::clone(&active_group_count); async move { + let mls_group = group.load_mls_group(provider)?; tracing::info!( inbox_id = self.inbox_id(), "[{}] syncing group", @@ -880,15 +881,12 @@ where ); tracing::info!( inbox_id = self.inbox_id(), - "[{}] syncing group", - self.inbox_id() + group_epoch = mls_group.epoch().as_u64(), + "current epoch for [{}] in sync_all_groups() is Epoch: [{}]", + self.inbox_id(), 
+ mls_group.epoch() ); - let is_active = group - .load_mls_group_with_lock_async(provider, |mls_group| async move { - Ok::(mls_group.is_active()) - }) - .await?; - if is_active { + if mls_group.is_active() { group.maybe_update_installations(provider, None).await?; group.sync_with_conn(provider).await?; diff --git a/xmtp_mls/src/groups/intents.rs b/xmtp_mls/src/groups/intents.rs index 5c9863496..d756f204a 100644 --- a/xmtp_mls/src/groups/intents.rs +++ b/xmtp_mls/src/groups/intents.rs @@ -865,10 +865,9 @@ pub(crate) mod tests { }; let provider = group.client.mls_provider().unwrap(); - let decrypted_message = group - .load_mls_group_with_lock(&provider, |mut mls_group| { - Ok(mls_group.process_message(&provider, mls_message).unwrap()) - }) + let mut openmls_group = group.load_mls_group(&provider).unwrap(); + let decrypted_message = openmls_group + .process_message(&provider, mls_message) .unwrap(); let staged_commit = match decrypted_message.into_content() { diff --git a/xmtp_mls/src/groups/members.rs b/xmtp_mls/src/groups/members.rs index cfdf56e28..5ca53a40a 100644 --- a/xmtp_mls/src/groups/members.rs +++ b/xmtp_mls/src/groups/members.rs @@ -40,9 +40,9 @@ where &self, provider: &XmtpOpenMlsProvider, ) -> Result, GroupError> { - let group_membership = self.load_mls_group_with_lock(provider, |mls_group| { - Ok(extract_group_membership(mls_group.extensions())?) 
- })?; + let openmls_group = self.load_mls_group(provider)?; + // TODO: Replace with try_into from extensions + let group_membership = extract_group_membership(openmls_group.extensions())?; let requests = group_membership .members .into_iter() diff --git a/xmtp_mls/src/groups/mls_sync.rs b/xmtp_mls/src/groups/mls_sync.rs index 87109e81a..8b243002c 100644 --- a/xmtp_mls/src/groups/mls_sync.rs +++ b/xmtp_mls/src/groups/mls_sync.rs @@ -2,11 +2,11 @@ use super::{ build_extensions_for_admin_lists_update, build_extensions_for_metadata_update, build_extensions_for_permissions_update, build_group_membership_extension, intents::{ - Installation, IntentError, PostCommitAction, SendMessageIntentData, SendWelcomesAction, + Installation, PostCommitAction, SendMessageIntentData, SendWelcomesAction, UpdateAdminListIntentData, UpdateGroupMembershipIntentData, UpdatePermissionIntentData, }, validated_commit::{extract_group_membership, CommitValidationError}, - GroupError, HmacKey, MlsGroup, ScopedGroupClient, + GroupError, HmacKey, IntentError, MlsGroup, ScopedGroupClient, }; use crate::{ configuration::{ @@ -183,6 +183,7 @@ where inbox_id = self.client.inbox_id(), installation_id = %self.client.installation_id(), group_id = hex::encode(&self.group_id), + current_epoch = self.load_mls_group(&mls_provider)?.epoch().as_u64(), "[{}] syncing group", self.client.inbox_id() ); @@ -190,8 +191,10 @@ where inbox_id = self.client.inbox_id(), installation_id = %self.client.installation_id(), group_id = hex::encode(&self.group_id), - "current epoch for [{}] in sync()", + current_epoch = self.load_mls_group(&mls_provider)?.epoch().as_u64(), + "current epoch for [{}] in sync() is Epoch: [{}]", self.client.inbox_id(), + self.load_mls_group(&mls_provider)?.epoch() ); self.maybe_update_installations(&mls_provider, None).await?; @@ -355,265 +358,265 @@ where async fn process_own_message( &self, intent: StoredGroupIntent, + openmls_group: &mut OpenMlsGroup, provider: &XmtpOpenMlsProvider, message: 
ProtocolMessage, envelope: &GroupMessageV1, ) -> Result { - self.load_mls_group_with_lock_async(provider, |mut mls_group| async move { - let GroupMessageV1 { - created_ns: envelope_timestamp_ns, - id: ref msg_id, - .. - } = *envelope; - - if intent.state == IntentState::Committed { - return Ok(IntentState::Committed); - } - let message_epoch = message.epoch(); - let group_epoch = mls_group.epoch(); - debug!( - inbox_id = self.client.inbox_id(), - installation_id = %self.client.installation_id(), - group_id = hex::encode(&self.group_id), - msg_id, - intent.id, - intent.kind = %intent.kind, - "[{}]-[{}] processing own message for intent {} / {:?}, message_epoch: {}", - self.context().inbox_id(), - hex::encode(self.group_id.clone()), - intent.id, - intent.kind, - message_epoch - ); + let GroupMessageV1 { + created_ns: envelope_timestamp_ns, + id: ref msg_id, + .. + } = *envelope; + + if intent.state == IntentState::Committed { + return Ok(IntentState::Committed); + } + let message_epoch = message.epoch(); + let group_epoch = openmls_group.epoch(); + debug!( + inbox_id = self.client.inbox_id(), + installation_id = %self.client.installation_id(), + group_id = hex::encode(&self.group_id), + current_epoch = openmls_group.epoch().as_u64(), + msg_id, + intent.id, + intent.kind = %intent.kind, + "[{}]-[{}] processing own message for intent {} / {:?}, group epoch: {}, message_epoch: {}", + self.context().inbox_id(), + hex::encode(self.group_id.clone()), + intent.id, + intent.kind, + group_epoch, + message_epoch + ); - let conn = provider.conn_ref(); - match intent.kind { - IntentKind::KeyUpdate - | IntentKind::UpdateGroupMembership - | IntentKind::UpdateAdminList - | IntentKind::MetadataUpdate - | IntentKind::UpdatePermission => { - if let Some(published_in_epoch) = intent.published_in_epoch { - let published_in_epoch_u64 = published_in_epoch as u64; - let group_epoch_u64 = group_epoch.as_u64(); - - if published_in_epoch_u64 != group_epoch_u64 { - tracing::warn!( - inbox_id = 
self.client.inbox_id(), - installation_id = %self.client.installation_id(), + let conn = provider.conn_ref(); + match intent.kind { + IntentKind::KeyUpdate + | IntentKind::UpdateGroupMembership + | IntentKind::UpdateAdminList + | IntentKind::MetadataUpdate + | IntentKind::UpdatePermission => { + if let Some(published_in_epoch) = intent.published_in_epoch { + let published_in_epoch_u64 = published_in_epoch as u64; + let group_epoch_u64 = group_epoch.as_u64(); + + if published_in_epoch_u64 != group_epoch_u64 { + tracing::warn!( + inbox_id = self.client.inbox_id(), + installation_id = %self.client.installation_id(), group_id = hex::encode(&self.group_id), - msg_id, - intent.id, - intent.kind = %intent.kind, - "Intent was published in epoch {} but group is currently", - published_in_epoch_u64, - ); - return Ok(IntentState::ToPublish); - } + current_epoch = openmls_group.epoch().as_u64(), + msg_id, + intent.id, + intent.kind = %intent.kind, + "Intent was published in epoch {} but group is currently in epoch {}", + published_in_epoch_u64, + group_epoch_u64 + ); + return Ok(IntentState::ToPublish); } + } - let pending_commit = if let Some(staged_commit) = intent.staged_commit { - decode_staged_commit(staged_commit)? - } else { - return Err(GroupMessageProcessingError::IntentMissingStagedCommit); - }; + let pending_commit = if let Some(staged_commit) = intent.staged_commit { + decode_staged_commit(staged_commit)? + } else { + return Err(GroupMessageProcessingError::IntentMissingStagedCommit); + }; - tracing::info!( - "[{}] Validating commit for intent {}. Message timestamp: {}", - self.context().inbox_id(), - intent.id, - envelope_timestamp_ns - ); + tracing::info!( + "[{}] Validating commit for intent {}. 
Message timestamp: {}", + self.context().inbox_id(), + intent.id, + envelope_timestamp_ns + ); - let maybe_validated_commit = ValidatedCommit::from_staged_commit( - self.client.as_ref(), - conn, - &pending_commit, - &mls_group, - ) - .await; + let maybe_validated_commit = ValidatedCommit::from_staged_commit( + self.client.as_ref(), + conn, + &pending_commit, + openmls_group, + ) + .await; - if let Err(err) = maybe_validated_commit { - tracing::error!( - "Error validating commit for own message. Intent ID [{}]: {:?}", - intent.id, - err - ); - // Return before merging commit since it does not pass validation - // Return OK so that the group intent update is still written to the DB - return Ok(IntentState::Error); - } + if let Err(err) = maybe_validated_commit { + tracing::error!( + "Error validating commit for own message. Intent ID [{}]: {:?}", + intent.id, + err + ); + // Return before merging commit since it does not pass validation + // Return OK so that the group intent update is still written to the DB + return Ok(IntentState::Error); + } - let validated_commit = maybe_validated_commit.expect("Checked for error"); + let validated_commit = maybe_validated_commit.expect("Checked for error"); - tracing::info!( - "[{}] merging pending commit for intent {}", - self.context().inbox_id(), - intent.id - ); - if let Err(err) = mls_group.merge_staged_commit(&provider, pending_commit) { - tracing::error!("error merging commit: {}", err); - return Ok(IntentState::ToPublish); - } else { - // If no error committing the change, write a transcript message - self.save_transcript_message( - conn, - validated_commit, - envelope_timestamp_ns, - )?; - } + tracing::info!( + "[{}] merging pending commit for intent {}", + self.context().inbox_id(), + intent.id + ); + if let Err(err) = openmls_group.merge_staged_commit(&provider, pending_commit) { + tracing::error!("error merging commit: {}", err); + return Ok(IntentState::ToPublish); + } else { + // If no error committing the change, 
write a transcript message + self.save_transcript_message(conn, validated_commit, envelope_timestamp_ns)?; } - IntentKind::SendMessage => { - if !Self::is_valid_epoch( - self.context().inbox_id(), - intent.id, - group_epoch, - message_epoch, - MAX_PAST_EPOCHS, - ) { - return Ok(IntentState::ToPublish); - } - if let Some(id) = intent.message_id()? { - conn.set_delivery_status_to_published(&id, envelope_timestamp_ns)?; - } + } + IntentKind::SendMessage => { + if !Self::is_valid_epoch( + self.context().inbox_id(), + intent.id, + group_epoch, + message_epoch, + MAX_PAST_EPOCHS, + ) { + return Ok(IntentState::ToPublish); } - }; + if let Some(id) = intent.message_id()? { + conn.set_delivery_status_to_published(&id, envelope_timestamp_ns)?; + } + } + }; - Ok(IntentState::Committed) - }) - .await + Ok(IntentState::Committed) } #[tracing::instrument(level = "trace", skip_all)] async fn process_external_message( &self, + openmls_group: &mut OpenMlsGroup, provider: &XmtpOpenMlsProvider, message: PrivateMessageIn, envelope: &GroupMessageV1, ) -> Result<(), GroupMessageProcessingError> { - self.load_mls_group_with_lock_async(provider, |mut mls_group| async move { - let GroupMessageV1 { - created_ns: envelope_timestamp_ns, - id: ref msg_id, - .. - } = *envelope; + let GroupMessageV1 { + created_ns: envelope_timestamp_ns, + id: ref msg_id, + .. 
+ } = *envelope; - let decrypted_message = mls_group.process_message(provider, message)?; - let (sender_inbox_id, sender_installation_id) = - extract_message_sender(&mut mls_group, &decrypted_message, envelope_timestamp_ns)?; + let decrypted_message = openmls_group.process_message(provider, message)?; + let (sender_inbox_id, sender_installation_id) = + extract_message_sender(openmls_group, &decrypted_message, envelope_timestamp_ns)?; - tracing::info!( - inbox_id = self.client.inbox_id(), - installation_id = %self.client.installation_id(),sender_inbox_id = sender_inbox_id, - sender_installation_id = hex::encode(&sender_installation_id), - group_id = hex::encode(&self.group_id), - current_epoch = mls_group.epoch().as_u64(), - msg_epoch = decrypted_message.epoch().as_u64(), - msg_group_id = hex::encode(decrypted_message.group_id().as_slice()), - msg_id, - "[{}] extracted sender inbox id: {}", - self.client.inbox_id(), - sender_inbox_id - ); + tracing::info!( + inbox_id = self.client.inbox_id(), + installation_id = %self.client.installation_id(), + sender_inbox_id = sender_inbox_id, + sender_installation_id = hex::encode(&sender_installation_id), + group_id = hex::encode(&self.group_id), + current_epoch = openmls_group.epoch().as_u64(), + msg_epoch = decrypted_message.epoch().as_u64(), + msg_group_id = hex::encode(decrypted_message.group_id().as_slice()), + msg_id, + "[{}] extracted sender inbox id: {}", + self.client.inbox_id(), + sender_inbox_id + ); - let (msg_epoch, msg_group_id) = ( - decrypted_message.epoch().as_u64(), - hex::encode(decrypted_message.group_id().as_slice()), - ); - match decrypted_message.into_content() { - ProcessedMessageContent::ApplicationMessage(application_message) => { - tracing::info!( - inbox_id = self.client.inbox_id(), - sender_inbox_id = sender_inbox_id, - sender_installation_id = hex::encode(&sender_installation_id), - installation_id = %self.client.installation_id(),group_id = hex::encode(&self.group_id), - current_epoch = 
mls_group.epoch().as_u64(), - msg_epoch, - msg_group_id, - msg_id, - "[{}] decoding application message", - self.context().inbox_id() - ); - let message_bytes = application_message.into_bytes(); - - let mut bytes = Bytes::from(message_bytes.clone()); - let envelope = PlaintextEnvelope::decode(&mut bytes)?; - - match envelope.content { - Some(Content::V1(V1 { - idempotency_key, - content, - })) => { - let message_id = - calculate_message_id(&self.group_id, &content, &idempotency_key); - StoredGroupMessage { - id: message_id, - group_id: self.group_id.clone(), - decrypted_message_bytes: content, - sent_at_ns: envelope_timestamp_ns as i64, - kind: GroupMessageKind::Application, - sender_installation_id, - sender_inbox_id, - delivery_status: DeliveryStatus::Published, - } - .store_or_ignore(provider.conn_ref())? + let (msg_epoch, msg_group_id) = ( + decrypted_message.epoch().as_u64(), + hex::encode(decrypted_message.group_id().as_slice()), + ); + match decrypted_message.into_content() { + ProcessedMessageContent::ApplicationMessage(application_message) => { + tracing::info!( + inbox_id = self.client.inbox_id(), + sender_inbox_id = sender_inbox_id, + sender_installation_id = hex::encode(&sender_installation_id), + installation_id = %self.client.installation_id(), + group_id = hex::encode(&self.group_id), + current_epoch = openmls_group.epoch().as_u64(), + msg_epoch, + msg_group_id, + msg_id, + "[{}] decoding application message", + self.context().inbox_id() + ); + let message_bytes = application_message.into_bytes(); + + let mut bytes = Bytes::from(message_bytes.clone()); + let envelope = PlaintextEnvelope::decode(&mut bytes)?; + + match envelope.content { + Some(Content::V1(V1 { + idempotency_key, + content, + })) => { + let message_id = + calculate_message_id(&self.group_id, &content, &idempotency_key); + StoredGroupMessage { + id: message_id, + group_id: self.group_id.clone(), + decrypted_message_bytes: content, + sent_at_ns: envelope_timestamp_ns as i64, + kind: 
GroupMessageKind::Application, + sender_installation_id, + sender_inbox_id, + delivery_status: DeliveryStatus::Published, } - Some(Content::V2(V2 { - idempotency_key, - message_type, - })) => { - match message_type { - Some(MessageType::DeviceSyncRequest(history_request)) => { - let content: DeviceSyncContent = - DeviceSyncContent::Request(history_request); - let content_bytes = serde_json::to_vec(&content)?; - let message_id = calculate_message_id( - &self.group_id, - &content_bytes, - &idempotency_key, - ); - - // store the request message - StoredGroupMessage { - id: message_id.clone(), - group_id: self.group_id.clone(), - decrypted_message_bytes: content_bytes, - sent_at_ns: envelope_timestamp_ns as i64, - kind: GroupMessageKind::Application, - sender_installation_id, - sender_inbox_id: sender_inbox_id.clone(), - delivery_status: DeliveryStatus::Published, - } - .store_or_ignore(provider.conn_ref())?; - - tracing::info!("Received a history request."); - let _ = self.client.local_events().send(LocalEvents::SyncMessage( - SyncMessage::Request { message_id }, - )); + .store_or_ignore(provider.conn_ref())? 
+ } + Some(Content::V2(V2 { + idempotency_key, + message_type, + })) => { + match message_type { + Some(MessageType::DeviceSyncRequest(history_request)) => { + let content: DeviceSyncContent = + DeviceSyncContent::Request(history_request); + let content_bytes = serde_json::to_vec(&content)?; + let message_id = calculate_message_id( + &self.group_id, + &content_bytes, + &idempotency_key, + ); + + // store the request message + StoredGroupMessage { + id: message_id.clone(), + group_id: self.group_id.clone(), + decrypted_message_bytes: content_bytes, + sent_at_ns: envelope_timestamp_ns as i64, + kind: GroupMessageKind::Application, + sender_installation_id, + sender_inbox_id: sender_inbox_id.clone(), + delivery_status: DeliveryStatus::Published, } + .store_or_ignore(provider.conn_ref())?; - Some(MessageType::DeviceSyncReply(history_reply)) => { - let content: DeviceSyncContent = - DeviceSyncContent::Reply(history_reply); - let content_bytes = serde_json::to_vec(&content)?; - let message_id = calculate_message_id( - &self.group_id, - &content_bytes, - &idempotency_key, - ); - - // store the reply message - StoredGroupMessage { - id: message_id.clone(), - group_id: self.group_id.clone(), - decrypted_message_bytes: content_bytes, - sent_at_ns: envelope_timestamp_ns as i64, - kind: GroupMessageKind::Application, - sender_installation_id, - sender_inbox_id, - delivery_status: DeliveryStatus::Published, - } - .store_or_ignore(provider.conn_ref())?; + tracing::info!("Received a history request."); + let _ = self.client.local_events().send(LocalEvents::SyncMessage( + SyncMessage::Request { message_id }, + )); + } + + Some(MessageType::DeviceSyncReply(history_reply)) => { + let content: DeviceSyncContent = + DeviceSyncContent::Reply(history_reply); + let content_bytes = serde_json::to_vec(&content)?; + let message_id = calculate_message_id( + &self.group_id, + &content_bytes, + &idempotency_key, + ); + + // store the reply message + StoredGroupMessage { + id: 
message_id.clone(), + group_id: self.group_id.clone(), + decrypted_message_bytes: content_bytes, + sent_at_ns: envelope_timestamp_ns as i64, + kind: GroupMessageKind::Application, + sender_installation_id, + sender_inbox_id, + delivery_status: DeliveryStatus::Published, + } + .store_or_ignore(provider.conn_ref())?; tracing::info!("Received a history reply."); let _ = self.client.local_events().send(LocalEvents::SyncMessage( @@ -638,68 +641,70 @@ where return Err(GroupMessageProcessingError::InvalidPayload); } } - } - None => return Err(GroupMessageProcessingError::InvalidPayload), } + None => return Err(GroupMessageProcessingError::InvalidPayload), } - ProcessedMessageContent::ProposalMessage(_proposal_ptr) => { - // intentionally left blank. - } - ProcessedMessageContent::ExternalJoinProposalMessage(_external_proposal_ptr) => { - // intentionally left blank. - } - ProcessedMessageContent::StagedCommitMessage(staged_commit) => { - tracing::info!( - inbox_id = self.client.inbox_id(), - sender_inbox_id = sender_inbox_id, - installation_id = %self.client.installation_id(),sender_installation_id = hex::encode(&sender_installation_id), - group_id = hex::encode(&self.group_id), - current_epoch = mls_group.epoch().as_u64(), - msg_epoch, - msg_group_id, - msg_id, - "[{}] received staged commit. Merging and clearing any pending commits", - self.context().inbox_id() - ); + } + ProcessedMessageContent::ProposalMessage(_proposal_ptr) => { + // intentionally left blank. + } + ProcessedMessageContent::ExternalJoinProposalMessage(_external_proposal_ptr) => { + // intentionally left blank. 
+ } + ProcessedMessageContent::StagedCommitMessage(staged_commit) => { + tracing::info!( + inbox_id = self.client.inbox_id(), + sender_inbox_id = sender_inbox_id, + installation_id = %self.client.installation_id(), + sender_installation_id = hex::encode(&sender_installation_id), + group_id = hex::encode(&self.group_id), + current_epoch = openmls_group.epoch().as_u64(), + msg_epoch, + msg_group_id, + msg_id, + "[{}] received staged commit. Merging and clearing any pending commits", + self.context().inbox_id() + ); - let sc = *staged_commit; + let sc = *staged_commit; - // Validate the commit - let validated_commit = ValidatedCommit::from_staged_commit( - self.client.as_ref(), - provider.conn_ref(), - &sc, - &mls_group, - ) - .await?; - tracing::info!( - inbox_id = self.client.inbox_id(), - sender_inbox_id = sender_inbox_id, - installation_id = %self.client.installation_id(),sender_installation_id = hex::encode(&sender_installation_id), - group_id = hex::encode(&self.group_id), - current_epoch = mls_group.epoch().as_u64(), - msg_epoch, - msg_group_id, - msg_id, - "[{}] staged commit is valid, will attempt to merge", - self.context().inbox_id() - ); - mls_group.merge_staged_commit(provider, sc)?; - self.save_transcript_message( - provider.conn_ref(), - validated_commit, - envelope_timestamp_ns, - )?; - } - }; + // Validate the commit + let validated_commit = ValidatedCommit::from_staged_commit( + self.client.as_ref(), + provider.conn_ref(), + &sc, + openmls_group, + ) + .await?; + tracing::info!( + inbox_id = self.client.inbox_id(), + sender_inbox_id = sender_inbox_id, + installation_id = %self.client.installation_id(), + sender_installation_id = hex::encode(&sender_installation_id), + group_id = hex::encode(&self.group_id), + current_epoch = openmls_group.epoch().as_u64(), + msg_epoch, + msg_group_id, + msg_id, + "[{}] staged commit is valid, will attempt to merge", + self.context().inbox_id() + ); + openmls_group.merge_staged_commit(provider, sc)?; + 
self.save_transcript_message( + provider.conn_ref(), + validated_commit, + envelope_timestamp_ns, + )?; + } + }; - Ok(()) - }).await + Ok(()) } #[tracing::instrument(level = "trace", skip_all)] pub(super) async fn process_message( &self, + openmls_group: &mut OpenMlsGroup, provider: &XmtpOpenMlsProvider, envelope: &GroupMessageV1, allow_epoch_increment: bool, @@ -723,6 +728,7 @@ where inbox_id = self.client.inbox_id(), installation_id = %self.client.installation_id(), group_id = hex::encode(&self.group_id), + current_epoch = openmls_group.epoch().as_u64(), msg_id = envelope.id, "Processing envelope with hash {:?}", hex::encode(sha256(envelope.data.as_slice())) @@ -736,6 +742,7 @@ where inbox_id = self.client.inbox_id(), installation_id = %self.client.installation_id(), group_id = hex::encode(&self.group_id), + current_epoch = openmls_group.epoch().as_u64(), msg_id = envelope.id, intent_id, intent.kind = %intent.kind, @@ -745,7 +752,7 @@ where intent_id ); match self - .process_own_message(intent, provider, message.into(), envelope) + .process_own_message(intent, openmls_group, provider, message.into(), envelope) .await? 
{ IntentState::ToPublish => { @@ -770,12 +777,13 @@ where inbox_id = self.client.inbox_id(), installation_id = %self.client.installation_id(), group_id = hex::encode(&self.group_id), + current_epoch = openmls_group.epoch().as_u64(), msg_id = envelope.id, "client [{}] is about to process external envelope [{}]", self.client.inbox_id(), envelope.id ); - self.process_external_message(provider, message, envelope) + self.process_external_message(openmls_group, provider, message, envelope) .await } Err(err) => Err(GroupMessageProcessingError::Storage(err)), @@ -787,6 +795,7 @@ where &self, provider: &XmtpOpenMlsProvider, envelope: &GroupMessage, + openmls_group: &mut OpenMlsGroup, ) -> Result<(), GroupMessageProcessingError> { let msgv1 = match &envelope.version { Some(GroupMessageVersion::V1(value)) => value, @@ -802,6 +811,7 @@ where let last_cursor = provider .conn_ref() .get_last_cursor_for_id(&self.group_id, message_entity_kind)?; + tracing::info!("### last cursor --> [{:?}]", last_cursor); let should_skip_message = last_cursor > msgv1.id as i64; if should_skip_message { tracing::info!( @@ -827,7 +837,7 @@ where if !is_updated { return Err(ProcessIntentError::AlreadyProcessed(*cursor).into()); } - self.process_message(provider, msgv1, true).await?; + self.process_message(openmls_group, provider, msgv1, true).await?; Ok::<_, GroupMessageProcessingError>(()) }).await .inspect(|_| { @@ -855,11 +865,16 @@ where messages: Vec, provider: &XmtpOpenMlsProvider, ) -> Result<(), GroupError> { + let mut openmls_group = self.load_mls_group(provider)?; + let mut receive_errors: Vec = vec![]; for message in messages.into_iter() { let result = retry_async!( Retry::default(), - (async { self.consume_message(provider, &message).await }) + (async { + self.consume_message(provider, &message, &mut openmls_group) + .await + }) ); if let Err(e) = result { let is_retryable = e.is_retryable(); @@ -954,104 +969,103 @@ where &self, provider: &XmtpOpenMlsProvider, ) -> Result<(), GroupError> 
{ - self.load_mls_group_with_lock_async(provider, |mut mls_group| async move { - let intents = provider.conn_ref().find_group_intents( - self.group_id.clone(), - Some(vec![IntentState::ToPublish]), - None, - )?; - - for intent in intents { - let result = retry_async!( - Retry::default(), - (async { - self.get_publish_intent_data(provider, &mut mls_group, &intent) - .await - }) - ); + let mut openmls_group = self.load_mls_group(provider)?; - match result { - Err(err) => { - tracing::error!(error = %err, "error getting publish intent data {:?}", err); - if (intent.publish_attempts + 1) as usize >= MAX_INTENT_PUBLISH_ATTEMPTS { - tracing::error!( - intent.id, - intent.kind = %intent.kind, - inbox_id = self.client.inbox_id(), - installation_id = %self.client.installation_id(),group_id = hex::encode(&self.group_id), - "intent {} has reached max publish attempts", intent.id); - // TODO: Eventually clean up errored attempts - provider - .conn_ref() - .set_group_intent_error_and_fail_msg(&intent)?; - } else { - provider - .conn_ref() - .increment_intent_publish_attempt_count(intent.id)?; - } - - return Err(err); - } - Ok(Some(PublishIntentData { - payload_to_publish, - post_commit_action, - staged_commit, - })) => { - let payload_slice = payload_to_publish.as_slice(); - let has_staged_commit = staged_commit.is_some(); - provider.conn_ref().set_group_intent_published( - intent.id, - sha256(payload_slice), - post_commit_action, - staged_commit, - mls_group.epoch().as_u64() as i64, - )?; - tracing::debug!( - inbox_id = self.client.inbox_id(), - installation_id = %self.client.installation_id(), - intent.id, - intent.kind = %intent.kind, - group_id = hex::encode(&self.group_id), - "client [{}] set stored intent [{}] to state `published`", - self.client.inbox_id(), - intent.id - ); + let intents = provider.conn_ref().find_group_intents( + self.group_id.clone(), + Some(vec![IntentState::ToPublish]), + None, + )?; - let messages = 
self.prepare_group_messages(vec![payload_slice])?;self.client - .api() - .send_group_messages(messages) - .await?; + for intent in intents { + let result = retry_async!( + Retry::default(), + (async { + self.get_publish_intent_data(provider, &mut openmls_group, &intent) + .await + }) + ); - tracing::info!( + match result { + Err(err) => { + tracing::error!(error = %err, "error getting publish intent data {:?}", err); + if (intent.publish_attempts + 1) as usize >= MAX_INTENT_PUBLISH_ATTEMPTS { + tracing::error!( intent.id, intent.kind = %intent.kind, inbox_id = self.client.inbox_id(), installation_id = %self.client.installation_id(), group_id = hex::encode(&self.group_id), - "[{}] published intent [{}] of type [{}]", - self.client.inbox_id(), - intent.id, - intent.kind - ); - if has_staged_commit { - tracing::info!("Commit sent. Stopping further publishes for this round"); - return Ok(()); - } + "intent {} has reached max publish attempts", intent.id); + // TODO: Eventually clean up errored attempts + provider + .conn_ref() + .set_group_intent_error_and_fail_msg(&intent)?; + } else { + provider + .conn_ref() + .increment_intent_publish_attempt_count(intent.id)?; } - Ok(None) => { - tracing::info!( - inbox_id = self.client.inbox_id(), - installation_id = %self.client.installation_id(), - "Skipping intent because no publish data returned" - ); - let deleter: &dyn Delete = provider.conn_ref(); - deleter.delete(intent.id)?; + + return Err(err); + } + Ok(Some(PublishIntentData { + payload_to_publish, + post_commit_action, + staged_commit, + })) => { + let payload_slice = payload_to_publish.as_slice(); + let has_staged_commit = staged_commit.is_some(); + provider.conn_ref().set_group_intent_published( + intent.id, + sha256(payload_slice), + post_commit_action, + staged_commit, + openmls_group.epoch().as_u64() as i64, + )?; + tracing::debug!( + inbox_id = self.client.inbox_id(), + installation_id = %self.client.installation_id(), + intent.id, + intent.kind = %intent.kind, 
+ group_id = hex::encode(&self.group_id), + "client [{}] set stored intent [{}] to state `published`", + self.client.inbox_id(), + intent.id + ); + + let messages = self.prepare_group_messages(vec![payload_slice])?; + self.client.api().send_group_messages(messages).await?; + + tracing::info!( + intent.id, + intent.kind = %intent.kind, + inbox_id = self.client.inbox_id(), + installation_id = %self.client.installation_id(), + group_id = hex::encode(&self.group_id), + "[{}] published intent [{}] of type [{}]", + self.client.inbox_id(), + intent.id, + intent.kind + ); + if has_staged_commit { + tracing::info!("Commit sent. Stopping further publishes for this round"); + return Ok(()); } } + Ok(None) => { + tracing::info!( + inbox_id = self.client.inbox_id(), + installation_id = %self.client.installation_id(), + "Skipping intent because no publish data returned" + ); + let deleter: &dyn Delete = provider.conn_ref(); + deleter.delete(intent.id)?; + } } + } - Ok(()) - }).await + Ok(()) } // Takes a StoredGroupIntent and returns the payload and post commit data as a tuple @@ -1208,7 +1222,10 @@ where update_interval_ns: Option, ) -> Result<(), GroupError> { // determine how long of an interval in time to use before updating list - let interval_ns = update_interval_ns.unwrap_or(SYNC_UPDATE_INSTALLATIONS_INTERVAL_NS); + let interval_ns = match update_interval_ns { + Some(val) => val, + None => SYNC_UPDATE_INSTALLATIONS_INTERVAL_NS, + }; let now_ns = xmtp_common::time::now_ns(); let last_ns = provider @@ -1275,59 +1292,58 @@ where inbox_ids_to_add: &[InboxIdRef<'_>], inbox_ids_to_remove: &[InboxIdRef<'_>], ) -> Result { - self.load_mls_group_with_lock_async(provider, |mls_group| async move { - let existing_group_membership = extract_group_membership(mls_group.extensions())?; - // TODO:nm prevent querying for updates on members who are being removed - let mut inbox_ids = existing_group_membership.inbox_ids(); - inbox_ids.extend_from_slice(inbox_ids_to_add); - let conn = 
provider.conn_ref(); - // Load any missing updates from the network - load_identity_updates(self.client.api(), conn, &inbox_ids).await?; - - let latest_sequence_id_map = conn.get_latest_sequence_id(&inbox_ids as &[&str])?; - - // Get a list of all inbox IDs that have increased sequence_id for the group - let changed_inbox_ids = - inbox_ids - .iter() - .try_fold(HashMap::new(), |mut updates, inbox_id| { - match ( - latest_sequence_id_map.get(inbox_id as &str), - existing_group_membership.get(inbox_id), - ) { - // This is an update. We have a new sequence ID and an existing one - (Some(latest_sequence_id), Some(current_sequence_id)) => { - let latest_sequence_id_u64 = *latest_sequence_id as u64; - if latest_sequence_id_u64.gt(current_sequence_id) { - updates.insert(inbox_id.to_string(), latest_sequence_id_u64); - } - } - // This is for new additions to the group - (Some(latest_sequence_id), _) => { - // This is the case for net new members to the group - updates.insert(inbox_id.to_string(), *latest_sequence_id as u64); - } - (_, _) => { - tracing::warn!( - "Could not find existing sequence ID for inbox {}", - inbox_id - ); - return Err(GroupError::MissingSequenceId); + let mls_group = self.load_mls_group(provider)?; + let existing_group_membership = extract_group_membership(mls_group.extensions())?; + + // TODO:nm prevent querying for updates on members who are being removed + let mut inbox_ids = existing_group_membership.inbox_ids(); + inbox_ids.extend_from_slice(inbox_ids_to_add); + let conn = provider.conn_ref(); + // Load any missing updates from the network + load_identity_updates(self.client.api(), conn, &inbox_ids).await?; + + let latest_sequence_id_map = conn.get_latest_sequence_id(&inbox_ids as &[&str])?; + + // Get a list of all inbox IDs that have increased sequence_id for the group + let changed_inbox_ids = + inbox_ids + .iter() + .try_fold(HashMap::new(), |mut updates, inbox_id| { + match ( + latest_sequence_id_map.get(inbox_id as &str), + 
existing_group_membership.get(inbox_id), + ) { + // This is an update. We have a new sequence ID and an existing one + (Some(latest_sequence_id), Some(current_sequence_id)) => { + let latest_sequence_id_u64 = *latest_sequence_id as u64; + if latest_sequence_id_u64.gt(current_sequence_id) { + updates.insert(inbox_id.to_string(), latest_sequence_id_u64); } } + // This is for new additions to the group + (Some(latest_sequence_id), _) => { + // This is the case for net new members to the group + updates.insert(inbox_id.to_string(), *latest_sequence_id as u64); + } + (_, _) => { + tracing::warn!( + "Could not find existing sequence ID for inbox {}", + inbox_id + ); + return Err(GroupError::MissingSequenceId); + } + } + + Ok(updates) + })?; - Ok(updates) - })?; - - Ok(UpdateGroupMembershipIntentData::new( - changed_inbox_ids, - inbox_ids_to_remove - .iter() - .map(|s| s.to_string()) - .collect::>(), - )) - }) - .await + Ok(UpdateGroupMembershipIntentData::new( + changed_inbox_ids, + inbox_ids_to_remove + .iter() + .map(|s| s.to_string()) + .collect::>(), + )) } /** diff --git a/xmtp_mls/src/groups/mod.rs b/xmtp_mls/src/groups/mod.rs index 35662bc7d..9213c38ef 100644 --- a/xmtp_mls/src/groups/mod.rs +++ b/xmtp_mls/src/groups/mod.rs @@ -38,6 +38,7 @@ use tokio::sync::Mutex; use self::device_sync::DeviceSyncError; pub use self::group_permissions::PreconfiguredPolicies; +pub use self::intents::{AddressesOrInstallationIds, IntentError}; use self::scoped_client::ScopedGroupClient; use self::{ group_membership::GroupMembership, @@ -55,11 +56,12 @@ use self::{ use self::{ group_metadata::{GroupMetadata, GroupMetadataError}, group_permissions::PolicySet, - intents::IntentError, validated_commit::CommitValidationError, }; -use crate::storage::StorageError; +use std::{collections::HashSet, sync::Arc}; use xmtp_common::time::now_ns; +use xmtp_cryptography::signature::{sanitize_evm_addresses, AddressValidationError}; +use xmtp_id::{InboxId, InboxIdRef}; use xmtp_proto::xmtp::mls::{ 
api::v1::{ group_message::{Version as GroupMessageVersion, V1 as GroupMessageV1}, @@ -89,18 +91,13 @@ use crate::{ group::{ConversationType, GroupMembershipState, StoredGroup}, group_intent::IntentKind, group_message::{DeliveryStatus, GroupMessageKind, MsgQueryArgs, StoredGroupMessage}, - sql_key_store, + sql_key_store, StorageError, }, subscriptions::{LocalEventError, LocalEvents}, utils::id::calculate_message_id, xmtp_openmls_provider::XmtpOpenMlsProvider, - Store, MLS_COMMIT_LOCK, + Store, }; -use std::future::Future; -use std::{collections::HashSet, sync::Arc}; -use xmtp_cryptography::signature::{sanitize_evm_addresses, AddressValidationError}; -use xmtp_id::{InboxId, InboxIdRef}; - use xmtp_common::retry::RetryableError; #[derive(Debug, Error)] @@ -206,10 +203,6 @@ pub enum GroupError { IntentNotCommitted, #[error(transparent)] ProcessIntent(#[from] ProcessIntentError), - #[error("Failed to load lock")] - LockUnavailable, - #[error("Failed to acquire semaphore lock")] - LockFailedToAcquire, } impl RetryableError for GroupError { @@ -236,8 +229,6 @@ impl RetryableError for GroupError { Self::MessageHistory(err) => err.is_retryable(), Self::ProcessIntent(err) => err.is_retryable(), Self::LocalEvent(err) => err.is_retryable(), - Self::LockUnavailable => true, - Self::LockFailedToAcquire => true, Self::SyncFailedToWait => true, Self::GroupNotFound | Self::GroupMetadata(_) @@ -341,55 +332,16 @@ impl MlsGroup { // Load the stored OpenMLS group from the OpenMLS provider's keystore #[tracing::instrument(level = "trace", skip_all)] - pub(crate) fn load_mls_group_with_lock( + pub(crate) fn load_mls_group( &self, provider: impl OpenMlsProvider, - operation: F, - ) -> Result - where - F: FnOnce(OpenMlsGroup) -> Result, - { - // Get the group ID for locking - let group_id = self.group_id.clone(); - - // Acquire the lock synchronously using blocking_lock - let _lock = MLS_COMMIT_LOCK.get_lock_sync(group_id.clone()); - // Load the MLS group + ) -> Result { let mls_group = 
OpenMlsGroup::load(provider.storage(), &GroupId::from_slice(&self.group_id)) .map_err(|_| GroupError::GroupNotFound)? .ok_or(GroupError::GroupNotFound)?; - // Perform the operation with the MLS group - operation(mls_group) - } - - // Load the stored OpenMLS group from the OpenMLS provider's keystore - #[tracing::instrument(level = "trace", skip_all)] - pub(crate) async fn load_mls_group_with_lock_async( - &self, - provider: &XmtpOpenMlsProvider, - operation: F, - ) -> Result - where - F: FnOnce(OpenMlsGroup) -> Fut, - Fut: Future>, - E: From + From, - { - // Get the group ID for locking - let group_id = self.group_id.clone(); - - // Acquire the lock asynchronously - let _lock = MLS_COMMIT_LOCK.get_lock_async(group_id.clone()).await; - - // Load the MLS group - let mls_group = - OpenMlsGroup::load(provider.storage(), &GroupId::from_slice(&self.group_id)) - .map_err(crate::StorageError::from)? - .ok_or(crate::StorageError::NotFound("Group Not Found".into()))?; - - // Perform the operation with the MLS group - operation(mls_group).await.map_err(Into::into) + Ok(mls_group) } // Create a new group and save it to the DB @@ -896,7 +848,7 @@ impl MlsGroup { /// to perform these updates. 
pub async fn update_group_name(&self, group_name: String) -> Result<(), GroupError> { let provider = self.client.mls_provider()?; - if self.metadata(&provider).await?.conversation_type == ConversationType::Dm { + if self.metadata(&provider)?.conversation_type == ConversationType::Dm { return Err(GroupError::DmGroupMetadataForbidden); } let intent_data: Vec = @@ -914,7 +866,7 @@ impl MlsGroup { metadata_field: Option, ) -> Result<(), GroupError> { let provider = self.client.mls_provider()?; - if self.metadata(&provider).await?.conversation_type == ConversationType::Dm { + if self.metadata(&provider)?.conversation_type == ConversationType::Dm { return Err(GroupError::DmGroupMetadataForbidden); } if permission_update_type == PermissionUpdateType::UpdateMetadata @@ -936,7 +888,7 @@ impl MlsGroup { } /// Retrieves the group name from the group's mutable metadata extension. - pub fn group_name(&self, provider: &XmtpOpenMlsProvider) -> Result { + pub fn group_name(&self, provider: impl OpenMlsProvider) -> Result { let mutable_metadata = self.mutable_metadata(provider)?; match mutable_metadata .attributes @@ -955,7 +907,7 @@ impl MlsGroup { group_description: String, ) -> Result<(), GroupError> { let provider = self.client.mls_provider()?; - if self.metadata(&provider).await?.conversation_type == ConversationType::Dm { + if self.metadata(&provider)?.conversation_type == ConversationType::Dm { return Err(GroupError::DmGroupMetadataForbidden); } let intent_data: Vec = @@ -965,7 +917,7 @@ impl MlsGroup { self.sync_until_intent_resolved(&provider, intent.id).await } - pub fn group_description(&self, provider: &XmtpOpenMlsProvider) -> Result { + pub fn group_description(&self, provider: impl OpenMlsProvider) -> Result { let mutable_metadata = self.mutable_metadata(provider)?; match mutable_metadata .attributes @@ -984,7 +936,7 @@ impl MlsGroup { group_image_url_square: String, ) -> Result<(), GroupError> { let provider = self.client.mls_provider()?; - if 
self.metadata(&provider).await?.conversation_type == ConversationType::Dm { + if self.metadata(&provider)?.conversation_type == ConversationType::Dm { return Err(GroupError::DmGroupMetadataForbidden); } let intent_data: Vec = @@ -998,7 +950,7 @@ impl MlsGroup { /// Retrieves the image URL (square) of the group from the group's mutable metadata extension. pub fn group_image_url_square( &self, - provider: &XmtpOpenMlsProvider, + provider: impl OpenMlsProvider, ) -> Result { let mutable_metadata = self.mutable_metadata(provider)?; match mutable_metadata @@ -1017,7 +969,7 @@ impl MlsGroup { pinned_frame_url: String, ) -> Result<(), GroupError> { let provider = self.client.mls_provider()?; - if self.metadata(&provider).await?.conversation_type == ConversationType::Dm { + if self.metadata(&provider)?.conversation_type == ConversationType::Dm { return Err(GroupError::DmGroupMetadataForbidden); } let intent_data: Vec = @@ -1029,7 +981,7 @@ impl MlsGroup { pub fn group_pinned_frame_url( &self, - provider: &XmtpOpenMlsProvider, + provider: impl OpenMlsProvider, ) -> Result { let mutable_metadata = self.mutable_metadata(provider)?; match mutable_metadata @@ -1044,7 +996,7 @@ impl MlsGroup { } /// Retrieves the admin list of the group from the group's mutable metadata extension. - pub fn admin_list(&self, provider: &XmtpOpenMlsProvider) -> Result, GroupError> { + pub fn admin_list(&self, provider: impl OpenMlsProvider) -> Result, GroupError> { let mutable_metadata = self.mutable_metadata(provider)?; Ok(mutable_metadata.admin_list) } @@ -1052,7 +1004,7 @@ impl MlsGroup { /// Retrieves the super admin list of the group from the group's mutable metadata extension. 
pub fn super_admin_list( &self, - provider: &XmtpOpenMlsProvider, + provider: impl OpenMlsProvider, ) -> Result, GroupError> { let mutable_metadata = self.mutable_metadata(provider)?; Ok(mutable_metadata.super_admin_list) @@ -1062,7 +1014,7 @@ impl MlsGroup { pub fn is_admin( &self, inbox_id: String, - provider: &XmtpOpenMlsProvider, + provider: impl OpenMlsProvider, ) -> Result { let mutable_metadata = self.mutable_metadata(provider)?; Ok(mutable_metadata.admin_list.contains(&inbox_id)) @@ -1072,18 +1024,18 @@ impl MlsGroup { pub fn is_super_admin( &self, inbox_id: String, - provider: &XmtpOpenMlsProvider, + provider: impl OpenMlsProvider, ) -> Result { let mutable_metadata = self.mutable_metadata(provider)?; Ok(mutable_metadata.super_admin_list.contains(&inbox_id)) } /// Retrieves the conversation type of the group from the group's metadata extension. - pub async fn conversation_type( + pub fn conversation_type( &self, - provider: &XmtpOpenMlsProvider, + provider: impl OpenMlsProvider, ) -> Result { - let metadata = self.metadata(provider).await?; + let metadata = self.metadata(provider)?; Ok(metadata.conversation_type) } @@ -1094,7 +1046,7 @@ impl MlsGroup { inbox_id: String, ) -> Result<(), GroupError> { let provider = self.client.mls_provider()?; - if self.metadata(&provider).await?.conversation_type == ConversationType::Dm { + if self.metadata(&provider)?.conversation_type == ConversationType::Dm { return Err(GroupError::DmGroupMetadataForbidden); } let intent_action_type = match action_type { @@ -1175,40 +1127,35 @@ impl MlsGroup { self.sync_until_intent_resolved(&provider, intent.id).await } - /// Checks if the current user is active in the group. + /// Checks if the current user is active in the group.
/// /// If the current user has been kicked out of the group, `is_active` will return `false` - pub fn is_active(&self, provider: &XmtpOpenMlsProvider) -> Result { - self.load_mls_group_with_lock(provider, |mls_group| Ok(mls_group.is_active())) + pub fn is_active(&self, provider: impl OpenMlsProvider) -> Result { + let mls_group = self.load_mls_group(provider)?; + Ok(mls_group.is_active()) } /// Get the `GroupMetadata` of the group. - pub async fn metadata( - &self, - provider: &XmtpOpenMlsProvider, - ) -> Result { - self.load_mls_group_with_lock_async(provider, |mls_group| { - futures::future::ready(extract_group_metadata(&mls_group).map_err(Into::into)) - }) - .await + pub fn metadata(&self, provider: impl OpenMlsProvider) -> Result { + let mls_group = self.load_mls_group(provider)?; + Ok(extract_group_metadata(&mls_group)?) } /// Get the `GroupMutableMetadata` of the group. pub fn mutable_metadata( &self, - provider: &XmtpOpenMlsProvider, + provider: impl OpenMlsProvider, ) -> Result { - self.load_mls_group_with_lock(provider, |mls_group| { - Ok(GroupMutableMetadata::try_from(&mls_group)?) - }) + let mls_group = &self.load_mls_group(provider)?; + + Ok(mls_group.try_into()?) } pub fn permissions(&self) -> Result { let provider = self.mls_provider()?; + let mls_group = self.load_mls_group(&provider)?; - self.load_mls_group_with_lock(&provider, |mls_group| { - Ok(extract_group_permissions(&mls_group)?) - }) + Ok(extract_group_permissions(&mls_group)?) } /// Used for testing that dm group validation works as expected. 
@@ -1660,6 +1607,7 @@ pub(crate) mod tests { use diesel::connection::SimpleConnection; use futures::future::join_all; + use openmls::prelude::Member; use prost::Message; use std::sync::Arc; use wasm_bindgen_test::wasm_bindgen_test; @@ -1856,7 +1804,7 @@ pub(crate) mod tests { // Verify bola can see the group name let bola_group_name = bola_group - .group_name(&bola_group.mls_provider().unwrap()) + .group_name(bola_group.mls_provider().unwrap()) .unwrap(); assert_eq!(bola_group_name, ""); @@ -1928,19 +1876,15 @@ pub(crate) mod tests { // Check Amal's MLS group state. let amal_db = XmtpOpenMlsProvider::from(amal.context.store().conn().unwrap()); - let amal_members_len = amal_group - .load_mls_group_with_lock(&amal_db, |mls_group| Ok(mls_group.members().count())) - .unwrap(); - - assert_eq!(amal_members_len, 3); + let amal_mls_group = amal_group.load_mls_group(&amal_db).unwrap(); + let amal_members: Vec = amal_mls_group.members().collect(); + assert_eq!(amal_members.len(), 3); // Check Bola's MLS group state. 
let bola_db = XmtpOpenMlsProvider::from(bola.context.store().conn().unwrap()); - let bola_members_len = bola_group - .load_mls_group_with_lock(&bola_db, |mls_group| Ok(mls_group.members().count())) - .unwrap(); - - assert_eq!(bola_members_len, 3); + let bola_mls_group = bola_group.load_mls_group(&bola_db).unwrap(); + let bola_members: Vec = bola_mls_group.members().collect(); + assert_eq!(bola_members.len(), 3); let amal_uncommitted_intents = amal_db .conn_ref() @@ -2000,26 +1944,19 @@ pub(crate) mod tests { .unwrap(); let provider = alix.mls_provider().unwrap(); // Doctor the group membership - let mut mls_group = alix_group - .load_mls_group_with_lock(&provider, |mut mls_group| { - let mut existing_extensions = mls_group.extensions().clone(); - let mut group_membership = GroupMembership::new(); - group_membership.add("deadbeef".to_string(), 1); - existing_extensions - .add_or_replace(build_group_membership_extension(&group_membership)); - - mls_group - .update_group_context_extensions( - &provider, - existing_extensions.clone(), - &alix.identity().installation_keys, - ) - .unwrap(); - mls_group.merge_pending_commit(&provider).unwrap(); - - Ok(mls_group) // Return the updated group if necessary - }) + let mut mls_group = alix_group.load_mls_group(&provider).unwrap(); + let mut existing_extensions = mls_group.extensions().clone(); + let mut group_membership = GroupMembership::new(); + group_membership.add("deadbeef".to_string(), 1); + existing_extensions.add_or_replace(build_group_membership_extension(&group_membership)); + mls_group + .update_group_context_extensions( + &provider, + existing_extensions.clone(), + &alix.identity().installation_keys, + ) .unwrap(); + mls_group.merge_pending_commit(&provider).unwrap(); // Now add bo to the group force_add_member(&alix, &bo, &alix_group, &mut mls_group, &provider).await; @@ -2141,13 +2078,9 @@ pub(crate) mod tests { assert_eq!(messages.len(), 2); let provider: XmtpOpenMlsProvider = 
client.context.store().conn().unwrap().into(); - let pending_commit_is_none = group - .load_mls_group_with_lock(&provider, |mls_group| { - Ok(mls_group.pending_commit().is_none()) - }) - .unwrap(); - - assert!(pending_commit_is_none); + let mls_group = group.load_mls_group(&provider).unwrap(); + let pending_commit = mls_group.pending_commit(); + assert!(pending_commit.is_none()); group.send_message(b"hello").await.expect("send message"); @@ -2229,7 +2162,7 @@ pub(crate) mod tests { let bola_group = receive_group_invite(&bola).await; bola_group.sync().await.unwrap(); assert!(!bola_group - .is_active(&bola_group.mls_provider().unwrap()) + .is_active(bola_group.mls_provider().unwrap()) .unwrap()) } @@ -2322,12 +2255,8 @@ pub(crate) mod tests { assert!(new_installations_were_added.is_ok()); group.sync().await.unwrap(); - let num_members = group - .load_mls_group_with_lock(&provider, |mls_group| { - Ok(mls_group.members().collect::>().len()) - }) - .unwrap(); - + let mls_group = group.load_mls_group(&provider).unwrap(); + let num_members = mls_group.members().collect::>().len(); assert_eq!(num_members, 3); } @@ -2424,7 +2353,7 @@ pub(crate) mod tests { .unwrap(); let binding = amal_group - .mutable_metadata(&amal_group.mls_provider().unwrap()) + .mutable_metadata(amal_group.mls_provider().unwrap()) .expect("msg"); let amal_group_name: &String = binding .attributes @@ -2487,7 +2416,7 @@ pub(crate) mod tests { amal_group.sync().await.unwrap(); let group_mutable_metadata = amal_group - .mutable_metadata(&amal_group.mls_provider().unwrap()) + .mutable_metadata(amal_group.mls_provider().unwrap()) .unwrap(); assert!(group_mutable_metadata.attributes.len().eq(&4)); assert!(group_mutable_metadata @@ -2510,7 +2439,7 @@ pub(crate) mod tests { let bola_group = bola_groups.first().unwrap(); bola_group.sync().await.unwrap(); let group_mutable_metadata = bola_group - .mutable_metadata(&bola_group.mls_provider().unwrap()) + .mutable_metadata(bola_group.mls_provider().unwrap()) 
.unwrap(); assert!(group_mutable_metadata .attributes @@ -2529,7 +2458,7 @@ pub(crate) mod tests { // Verify amal group sees update amal_group.sync().await.unwrap(); let binding = amal_group - .mutable_metadata(&amal_group.mls_provider().unwrap()) + .mutable_metadata(amal_group.mls_provider().unwrap()) .expect("msg"); let amal_group_name: &String = binding .attributes @@ -2540,7 +2469,7 @@ pub(crate) mod tests { // Verify bola group sees update bola_group.sync().await.unwrap(); let binding = bola_group - .mutable_metadata(&bola_group.mls_provider().unwrap()) + .mutable_metadata(bola_group.mls_provider().unwrap()) .expect("msg"); let bola_group_name: &String = binding .attributes @@ -2557,7 +2486,7 @@ pub(crate) mod tests { // Verify bola group does not see an update bola_group.sync().await.unwrap(); let binding = bola_group - .mutable_metadata(&bola_group.mls_provider().unwrap()) + .mutable_metadata(bola_group.mls_provider().unwrap()) .expect("msg"); let bola_group_name: &String = binding .attributes @@ -2578,7 +2507,7 @@ pub(crate) mod tests { amal_group.sync().await.unwrap(); let group_mutable_metadata = amal_group - .mutable_metadata(&amal_group.mls_provider().unwrap()) + .mutable_metadata(amal_group.mls_provider().unwrap()) .unwrap(); assert!(group_mutable_metadata .attributes @@ -2595,7 +2524,7 @@ pub(crate) mod tests { // Verify amal group sees update amal_group.sync().await.unwrap(); let binding = amal_group - .mutable_metadata(&amal_group.mls_provider().unwrap()) + .mutable_metadata(amal_group.mls_provider().unwrap()) .expect("msg"); let amal_group_image_url: &String = binding .attributes @@ -2616,7 +2545,7 @@ pub(crate) mod tests { amal_group.sync().await.unwrap(); let group_mutable_metadata = amal_group - .mutable_metadata(&amal_group.mls_provider().unwrap()) + .mutable_metadata(amal_group.mls_provider().unwrap()) .unwrap(); assert!(group_mutable_metadata .attributes @@ -2633,7 +2562,7 @@ pub(crate) mod tests { // Verify amal group sees update 
amal_group.sync().await.unwrap(); let binding = amal_group - .mutable_metadata(&amal_group.mls_provider().unwrap()) + .mutable_metadata(amal_group.mls_provider().unwrap()) .expect("msg"); let amal_group_pinned_frame_url: &String = binding .attributes @@ -2656,7 +2585,7 @@ pub(crate) mod tests { amal_group.sync().await.unwrap(); let group_mutable_metadata = amal_group - .mutable_metadata(&amal_group.mls_provider().unwrap()) + .mutable_metadata(amal_group.mls_provider().unwrap()) .unwrap(); assert!(group_mutable_metadata .attributes @@ -2677,7 +2606,7 @@ pub(crate) mod tests { let bola_group = bola_groups.first().unwrap(); bola_group.sync().await.unwrap(); let group_mutable_metadata = bola_group - .mutable_metadata(&bola_group.mls_provider().unwrap()) + .mutable_metadata(bola_group.mls_provider().unwrap()) .unwrap(); assert!(group_mutable_metadata .attributes @@ -2694,7 +2623,7 @@ pub(crate) mod tests { // Verify amal group sees update amal_group.sync().await.unwrap(); let binding = amal_group - .mutable_metadata(&amal_group.mls_provider().unwrap()) + .mutable_metadata(amal_group.mls_provider().unwrap()) .unwrap(); let amal_group_name: &String = binding .attributes @@ -2705,7 +2634,7 @@ pub(crate) mod tests { // Verify bola group sees update bola_group.sync().await.unwrap(); let binding = bola_group - .mutable_metadata(&bola_group.mls_provider().unwrap()) + .mutable_metadata(bola_group.mls_provider().unwrap()) .expect("msg"); let bola_group_name: &String = binding .attributes @@ -2722,7 +2651,7 @@ pub(crate) mod tests { // Verify amal group sees an update amal_group.sync().await.unwrap(); let binding = amal_group - .mutable_metadata(&amal_group.mls_provider().unwrap()) + .mutable_metadata(amal_group.mls_provider().unwrap()) .expect("msg"); let amal_group_name: &String = binding .attributes @@ -2789,13 +2718,13 @@ pub(crate) mod tests { bola_group.sync().await.unwrap(); assert_eq!( bola_group - .admin_list(&bola_group.mls_provider().unwrap()) + 
.admin_list(bola_group.mls_provider().unwrap()) .unwrap() .len(), 1 ); assert!(bola_group - .admin_list(&bola_group.mls_provider().unwrap()) + .admin_list(bola_group.mls_provider().unwrap()) .unwrap() .contains(&bola.inbox_id().to_string())); @@ -2826,13 +2755,13 @@ pub(crate) mod tests { bola_group.sync().await.unwrap(); assert_eq!( bola_group - .admin_list(&bola_group.mls_provider().unwrap()) + .admin_list(bola_group.mls_provider().unwrap()) .unwrap() .len(), 0 ); assert!(!bola_group - .admin_list(&bola_group.mls_provider().unwrap()) + .admin_list(bola_group.mls_provider().unwrap()) .unwrap() .contains(&bola.inbox_id().to_string())); @@ -3110,14 +3039,13 @@ pub(crate) mod tests { amal_group.sync().await.unwrap(); let mutable_metadata = amal_group - .mutable_metadata(&amal_group.mls_provider().unwrap()) + .mutable_metadata(amal_group.mls_provider().unwrap()) .unwrap(); assert_eq!(mutable_metadata.super_admin_list.len(), 1); assert_eq!(mutable_metadata.super_admin_list[0], amal.inbox_id()); let protected_metadata: GroupMetadata = amal_group - .metadata(&amal_group.mls_provider().unwrap()) - .await + .metadata(amal_group.mls_provider().unwrap()) .unwrap(); assert_eq!( protected_metadata.conversation_type, @@ -3157,7 +3085,7 @@ pub(crate) mod tests { .unwrap(); amal_group.sync().await.unwrap(); let name = amal_group - .group_name(&amal_group.mls_provider().unwrap()) + .group_name(amal_group.mls_provider().unwrap()) .unwrap(); assert_eq!(name, "Name Update 1"); @@ -3178,7 +3106,7 @@ pub(crate) mod tests { amal_group.sync().await.unwrap(); bola_group.sync().await.unwrap(); let binding = amal_group - .mutable_metadata(&amal_group.mls_provider().unwrap()) + .mutable_metadata(amal_group.mls_provider().unwrap()) .expect("msg"); let amal_group_name: &String = binding .attributes @@ -3186,7 +3114,7 @@ pub(crate) mod tests { .unwrap(); assert_eq!(amal_group_name, "Name Update 2"); let binding = bola_group - .mutable_metadata(&bola_group.mls_provider().unwrap()) + 
.mutable_metadata(bola_group.mls_provider().unwrap()) .expect("msg"); let bola_group_name: &String = binding .attributes @@ -3397,16 +3325,16 @@ pub(crate) mod tests { amal_dm.sync().await.unwrap(); bola_dm.sync().await.unwrap(); let is_amal_admin = amal_dm - .is_admin(amal.inbox_id().to_string(), &amal.mls_provider().unwrap()) + .is_admin(amal.inbox_id().to_string(), amal.mls_provider().unwrap()) .unwrap(); let is_bola_admin = amal_dm - .is_admin(bola.inbox_id().to_string(), &bola.mls_provider().unwrap()) + .is_admin(bola.inbox_id().to_string(), bola.mls_provider().unwrap()) .unwrap(); let is_amal_super_admin = amal_dm - .is_super_admin(amal.inbox_id().to_string(), &amal.mls_provider().unwrap()) + .is_super_admin(amal.inbox_id().to_string(), amal.mls_provider().unwrap()) .unwrap(); let is_bola_super_admin = amal_dm - .is_super_admin(bola.inbox_id().to_string(), &bola.mls_provider().unwrap()) + .is_super_admin(bola.inbox_id().to_string(), bola.mls_provider().unwrap()) .unwrap(); assert!(!is_amal_admin); assert!(!is_bola_admin); @@ -3728,8 +3656,9 @@ pub(crate) mod tests { panic!("wrong message format") }; let provider = client.mls_provider().unwrap(); + let mut openmls_group = group.load_mls_group(&provider).unwrap(); let process_result = group - .process_message(&provider, &first_message, false) + .process_message(&mut openmls_group, &provider, &first_message, false) .await; assert_err!( @@ -3873,11 +3802,14 @@ pub(crate) mod tests { None, ) .unwrap(); - assert!(valid_dm_group - .load_mls_group_with_lock(client.mls_provider().unwrap(), |mls_group| { - validate_dm_group(&client, &mls_group, added_by_inbox) - }) - .is_ok()); + assert!(validate_dm_group( + &client, + &valid_dm_group + .load_mls_group(client.mls_provider().unwrap()) + .unwrap(), + added_by_inbox + ) + .is_ok()); // Test case 2: Invalid conversation type let invalid_protected_metadata = @@ -3892,11 +3824,10 @@ pub(crate) mod tests { ) .unwrap(); assert!(matches!( - 
invalid_type_group.load_mls_group_with_lock(client.mls_provider().unwrap(), |mls_group| - validate_dm_group(&client, &mls_group, added_by_inbox) - ), + validate_dm_group(&client, &invalid_type_group.load_mls_group(client.mls_provider().unwrap()).unwrap(), added_by_inbox), Err(GroupError::Generic(msg)) if msg.contains("Invalid conversation type") )); + // Test case 3: Missing DmMembers // This case is not easily testable with the current structure, as DmMembers are set in the protected metadata @@ -3914,9 +3845,7 @@ pub(crate) mod tests { ) .unwrap(); assert!(matches!( - mismatched_dm_members_group.load_mls_group_with_lock(client.mls_provider().unwrap(), |mls_group| - validate_dm_group(&client, &mls_group, added_by_inbox) - ), + validate_dm_group(&client, &mismatched_dm_members_group.load_mls_group(client.mls_provider().unwrap()).unwrap(), added_by_inbox), Err(GroupError::Generic(msg)) if msg.contains("DM members do not match expected inboxes") )); @@ -3936,9 +3865,7 @@ pub(crate) mod tests { ) .unwrap(); assert!(matches!( - non_empty_admin_list_group.load_mls_group_with_lock(client.mls_provider().unwrap(), |mls_group| - validate_dm_group(&client, &mls_group, added_by_inbox) - ), + validate_dm_group(&client, &non_empty_admin_list_group.load_mls_group(client.mls_provider().unwrap()).unwrap(), added_by_inbox), Err(GroupError::Generic(msg)) if msg.contains("DM group must have empty admin and super admin lists") )); @@ -3957,9 +3884,11 @@ pub(crate) mod tests { ) .unwrap(); assert!(matches!( - invalid_permissions_group.load_mls_group_with_lock(client.mls_provider().unwrap(), |mls_group| - validate_dm_group(&client, &mls_group, added_by_inbox) - ), + validate_dm_group( + &client, + &invalid_permissions_group.load_mls_group(client.mls_provider().unwrap()).unwrap(), + added_by_inbox + ), Err(GroupError::Generic(msg)) if msg.contains("Invalid permissions for DM group") )); } diff --git a/xmtp_mls/src/groups/subscriptions.rs b/xmtp_mls/src/groups/subscriptions.rs index 
616a0f6a3..eb354fe3c 100644 --- a/xmtp_mls/src/groups/subscriptions.rs +++ b/xmtp_mls/src/groups/subscriptions.rs @@ -48,14 +48,21 @@ impl MlsGroup { self.context() .store() .transaction_async(provider, |provider| async move { + let mut openmls_group = self.load_mls_group(provider)?; + + // Attempt processing immediately, but fail if the message is not an Application Message + // Returning an error should roll back the DB tx tracing::info!( inbox_id = self.client.inbox_id(), group_id = hex::encode(&self.group_id), + current_epoch = openmls_group.epoch().as_u64(), msg_id = msgv1.id, - "current epoch for [{}] in process_stream_entry()", + "current epoch for [{}] in process_stream_entry() is Epoch: [{}]", client_id, + openmls_group.epoch() ); - self.process_message(provider, msgv1, false) + + self.process_message(&mut openmls_group, provider, msgv1, false) .await // NOTE: We want to make sure we retry an error in process_message .map_err(SubscribeError::ReceiveGroup) diff --git a/xmtp_mls/src/lib.rs b/xmtp_mls/src/lib.rs index 164e79af6..d287fd676 100644 --- a/xmtp_mls/src/lib.rs +++ b/xmtp_mls/src/lib.rs @@ -20,100 +20,12 @@ pub mod verified_key_package_v2; mod xmtp_openmls_provider; pub use client::{Client, Network}; -use std::collections::HashMap; -use std::sync::{Arc, LazyLock, Mutex}; use storage::{DuplicateItem, StorageError}; -use tokio::sync::{OwnedSemaphorePermit, Semaphore}; pub use xmtp_openmls_provider::XmtpOpenMlsProvider; pub use xmtp_id::InboxOwner; pub use xmtp_proto::api_client::trait_impls::*; -/// A manager for group-specific semaphores -#[derive(Debug)] -pub struct GroupCommitLock { - // Storage for group-specific semaphores - locks: Mutex, Arc>>, -} - -impl Default for GroupCommitLock { - fn default() -> Self { - Self::new() - } -} -impl GroupCommitLock { - /// Create a new `GroupCommitLock` - pub fn new() -> Self { - Self { - locks: Mutex::new(HashMap::new()), - } - } - - /// Get or create a semaphore for a specific group and acquire it, 
returning a guard - pub async fn get_lock_async(&self, group_id: Vec) -> Result { - let semaphore = { - match self.locks.lock() { - Ok(mut locks) => locks - .entry(group_id) - .or_insert_with(|| Arc::new(Semaphore::new(1))) - .clone(), - Err(err) => { - eprintln!("Failed to lock the mutex: {}", err); - return Err(GroupError::LockUnavailable); - } - } - }; - - let semaphore_clone = semaphore.clone(); - let permit = match semaphore.acquire_owned().await { - Ok(permit) => permit, - Err(err) => { - eprintln!("Failed to acquire semaphore permit: {}", err); - return Err(GroupError::LockFailedToAcquire); - } - }; - Ok(SemaphoreGuard { - _permit: permit, - _semaphore: semaphore_clone, - }) - } - - /// Get or create a semaphore for a specific group and acquire it synchronously - pub fn get_lock_sync(&self, group_id: Vec) -> Result { - let semaphore = { - match self.locks.lock() { - Ok(mut locks) => locks - .entry(group_id) - .or_insert_with(|| Arc::new(Semaphore::new(1))) - .clone(), - Err(err) => { - eprintln!("Failed to lock the mutex: {}", err); - return Err(GroupError::LockUnavailable); - } - } - }; - - // Synchronously acquire the permit - let permit = semaphore - .clone() - .try_acquire_owned() - .map_err(|_| GroupError::LockUnavailable)?; - Ok(SemaphoreGuard { - _permit: permit, - _semaphore: semaphore, // semaphore is now valid because we cloned it earlier - }) - } -} - -/// A guard that releases the semaphore when dropped -pub struct SemaphoreGuard { - _permit: OwnedSemaphorePermit, - _semaphore: Arc, -} - -// Static instance of `GroupCommitLock` -pub static MLS_COMMIT_LOCK: LazyLock = LazyLock::new(GroupCommitLock::new); - /// Inserts a model to the underlying data store, erroring if it already exists pub trait Store { fn store(&self, into: &StorageConnection) -> Result<(), StorageError>; @@ -152,7 +64,6 @@ pub trait Delete { fn delete(&self, key: Self::Key) -> Result; } -use crate::groups::GroupError; pub use stream_handles::{ spawn, AbortHandle, 
GenericStreamHandle, StreamHandle, StreamHandleError, }; diff --git a/xmtp_mls/src/storage/encrypted_store/group_intent.rs b/xmtp_mls/src/storage/encrypted_store/group_intent.rs index 70edb0956..bfb3c1399 100644 --- a/xmtp_mls/src/storage/encrypted_store/group_intent.rs +++ b/xmtp_mls/src/storage/encrypted_store/group_intent.rs @@ -15,7 +15,7 @@ use super::{ Sqlite, }; use crate::{ - groups::intents::{IntentError, SendMessageIntentData}, + groups::{intents::SendMessageIntentData, IntentError}, impl_fetch, impl_store, storage::StorageError, utils::id::calculate_message_id, diff --git a/xmtp_mls/src/storage/errors.rs b/xmtp_mls/src/storage/errors.rs index 3cc3df2a8..de25850ab 100644 --- a/xmtp_mls/src/storage/errors.rs +++ b/xmtp_mls/src/storage/errors.rs @@ -3,7 +3,7 @@ use std::sync::PoisonError; use diesel::result::DatabaseErrorKind; use thiserror::Error; -use super::sql_key_store::{self, SqlKeyStoreError}; +use super::sql_key_store; use crate::groups::intents::IntentError; use xmtp_common::{retryable, RetryableError}; @@ -27,7 +27,6 @@ pub enum StorageError { Serialization(String), #[error("deserialization error")] Deserialization(String), - // TODO:insipx Make NotFound into an enum of possible items that may not be found #[error("{0} not found")] NotFound(String), #[error("lock")] @@ -46,8 +45,6 @@ pub enum StorageError { FromHex(#[from] hex::FromHexError), #[error(transparent)] Duplicate(DuplicateItem), - #[error(transparent)] - OpenMlsStorage(#[from] SqlKeyStoreError), } #[derive(Error, Debug)] diff --git a/xmtp_mls/src/subscriptions.rs b/xmtp_mls/src/subscriptions.rs index 97f538504..f9675d500 100644 --- a/xmtp_mls/src/subscriptions.rs +++ b/xmtp_mls/src/subscriptions.rs @@ -385,7 +385,7 @@ where } WelcomeOrGroup::Group(group) => group?, }; - let metadata = group.metadata(&provider).await?; + let metadata = group.metadata(provider)?; Ok((metadata, group)) } } @@ -658,11 +658,11 @@ pub(crate) mod tests { .add_members_by_inbox_id(&[bob.inbox_id()]) .await 
.unwrap(); - let bob_groups = bob + let bob_group = bob .sync_welcomes(&bob.mls_provider().unwrap()) .await .unwrap(); - let bob_group = bob_groups.first().unwrap(); + let bob_group = bob_group.first().unwrap(); let notify = Delivery::new(None); let notify_ptr = notify.clone(); @@ -968,7 +968,6 @@ pub(crate) mod tests { } #[wasm_bindgen_test(unsupported = tokio::test(flavor = "multi_thread"))] - #[cfg_attr(target_family = "wasm", ignore)] async fn test_dm_streaming() { let alix = Arc::new(ClientBuilder::new_test_client(&generate_local_wallet()).await); let bo = Arc::new(ClientBuilder::new_test_client(&generate_local_wallet()).await);