From 59a0b6e0e3207f35ad85dd09a59ec2822021c3d9 Mon Sep 17 00:00:00 2001 From: Mojtaba Chenani Date: Mon, 9 Dec 2024 17:43:58 +0100 Subject: [PATCH 01/28] wip --- bindings_node/src/conversation.rs | 2 +- xmtp_mls/src/client.rs | 20 +- xmtp_mls/src/groups/intents.rs | 14 +- xmtp_mls/src/groups/mls_sync.rs | 681 +++++++++++++-------------- xmtp_mls/src/groups/mod.rs | 70 ++- xmtp_mls/src/groups/subscriptions.rs | 29 +- xmtp_mls/src/lib.rs | 62 +++ 7 files changed, 478 insertions(+), 400 deletions(-) diff --git a/bindings_node/src/conversation.rs b/bindings_node/src/conversation.rs index aeeffa6b1..d7df6a4fa 100644 --- a/bindings_node/src/conversation.rs +++ b/bindings_node/src/conversation.rs @@ -1,5 +1,5 @@ use std::{ops::Deref, sync::Arc}; - +use futures::TryFutureExt; use napi::{ bindgen_prelude::{Result, Uint8Array}, threadsafe_function::{ErrorStrategy, ThreadsafeFunction, ThreadsafeFunctionCallMode}, diff --git a/xmtp_mls/src/client.rs b/xmtp_mls/src/client.rs index 051c12500..82a317dc6 100644 --- a/xmtp_mls/src/client.rs +++ b/xmtp_mls/src/client.rs @@ -843,25 +843,19 @@ where let provider_ref = &provider; let active_group_count = Arc::clone(&active_group_count); async move { - let mls_group = group.load_mls_group(provider_ref)?; tracing::info!( inbox_id = self.inbox_id(), - "[{}] syncing group", - self.inbox_id() + "current epoch for [{}] in sync_all_groups()", + self.inbox_id(), ); tracing::info!( inbox_id = self.inbox_id(), - group_epoch = mls_group.epoch().as_u64(), - "current epoch for [{}] in sync_all_groups() is Epoch: [{}]", - self.inbox_id(), - mls_group.epoch() + "[{}] syncing group", + self.inbox_id() ); - if mls_group.is_active() { - group.maybe_update_installations(provider_ref, None).await?; - - group.sync_with_conn(provider_ref).await?; - active_group_count.fetch_add(1, Ordering::SeqCst); - } + group.maybe_update_installations(provider_ref, None).await?; + group.sync_with_conn(provider_ref).await?; + active_group_count.fetch_add(1, Ordering::SeqCst); Ok::<(), GroupError>(()) } diff --git a/xmtp_mls/src/groups/intents.rs b/xmtp_mls/src/groups/intents.rs index 80f5eb167..bbb4e06c0 100644 --- a/xmtp_mls/src/groups/intents.rs +++ b/xmtp_mls/src/groups/intents.rs @@ -725,6 +725,7 @@ impl TryFrom> for PostCommitAction { pub(crate) mod tests { #[cfg(target_arch = "wasm32")] wasm_bindgen_test::wasm_bindgen_test_configure!(run_in_dedicated_worker); + use openmls::prelude::{MlsMessageBodyIn, MlsMessageIn, ProcessedMessageContent}; use tls_codec::Deserialize; use xmtp_cryptography::utils::generate_local_wallet; @@ -863,10 +864,15 @@ pub(crate) mod tests { }; let provider = group.client.mls_provider().unwrap(); - let mut openmls_group = group.load_mls_group(&provider).unwrap(); - let decrypted_message = openmls_group - .process_message(&provider, mls_message) - .unwrap(); + let decrypted_message = match group + .load_mls_group_with_lock(&provider, |mut mls_group| { + mls_group + .process_message(&provider, mls_message) + .map_err(|e| GroupError::Generic(e.to_string())) + }) { + Ok(message) => message, + Err(err) => panic!("Error: {:?}", err), + }; let staged_commit = match decrypted_message.into_content() { ProcessedMessageContent::StagedCommitMessage(staged_commit) => *staged_commit, diff --git a/xmtp_mls/src/groups/mls_sync.rs b/xmtp_mls/src/groups/mls_sync.rs index 60a609cd4..dc6c2048d 100644 --- a/xmtp_mls/src/groups/mls_sync.rs +++ b/xmtp_mls/src/groups/mls_sync.rs @@ -171,17 +171,14 @@ where tracing::info!( inbox_id = self.client.inbox_id(), group_id = 
hex::encode(&self.group_id), - current_epoch = self.load_mls_group(&mls_provider)?.epoch().as_u64(), "[{}] syncing group", self.client.inbox_id() ); tracing::info!( inbox_id = self.client.inbox_id(), group_id = hex::encode(&self.group_id), - current_epoch = self.load_mls_group(&mls_provider)?.epoch().as_u64(), - "current epoch for [{}] in sync() is Epoch: [{}]", + "current epoch for [{}] in sync()", self.client.inbox_id(), - self.load_mls_group(&mls_provider)?.epoch() ); self.maybe_update_installations(&mls_provider, None).await?; @@ -191,6 +188,10 @@ where // TODO: Should probably be renamed to `sync_with_provider` #[tracing::instrument(skip_all)] pub async fn sync_with_conn(&self, provider: &XmtpOpenMlsProvider) -> Result<(), GroupError> { + // Check if we're still part of the group + if !self.is_active(provider)? { + return Ok(()); + } let _mutex = self.mutex.lock().await; let mut errors: Vec = vec![]; @@ -345,345 +346,343 @@ where async fn process_own_message( &self, intent: StoredGroupIntent, - openmls_group: &mut OpenMlsGroup, provider: &XmtpOpenMlsProvider, message: ProtocolMessage, envelope: &GroupMessageV1, ) -> Result { - let GroupMessageV1 { - created_ns: envelope_timestamp_ns, - id: ref msg_id, - .. - } = *envelope; - - if intent.state == IntentState::Committed { - return Ok(IntentState::Committed); - } - let message_epoch = message.epoch(); - let group_epoch = openmls_group.epoch(); - debug!( - inbox_id = self.client.inbox_id(), - installation_id = hex::encode(self.client.installation_id()), - group_id = hex::encode(&self.group_id), - current_epoch = openmls_group.epoch().as_u64(), - msg_id, - intent.id, - intent.kind = %intent.kind, - "[{}]-[{}] processing own message for intent {} / {:?}, group epoch: {}, message_epoch: {}", - self.context().inbox_id(), - hex::encode(self.group_id.clone()), - intent.id, - intent.kind, - group_epoch, - message_epoch - ); - - let conn = provider.conn_ref(); - match intent.kind { - IntentKind::KeyUpdate - | IntentKind::UpdateGroupMembership - | IntentKind::UpdateAdminList - | IntentKind::MetadataUpdate - | IntentKind::UpdatePermission => { - if let Some(published_in_epoch) = intent.published_in_epoch { - let published_in_epoch_u64 = published_in_epoch as u64; - let group_epoch_u64 = group_epoch.as_u64(); - - if published_in_epoch_u64 != group_epoch_u64 { - tracing::warn!( - inbox_id = self.client.inbox_id(), - group_id = hex::encode(&self.group_id), - current_epoch = openmls_group.epoch().as_u64(), - msg_id, - intent.id, - intent.kind = %intent.kind, - "Intent was published in epoch {} but group is currently in epoch {}", - published_in_epoch_u64, - group_epoch_u64 - ); - return Ok(IntentState::ToPublish); - } - } + self.load_mls_group_with_lock_async(provider, |mut mls_group| async move { + let GroupMessageV1 { + created_ns: envelope_timestamp_ns, + id: ref msg_id, + .. + } = *envelope; + + if intent.state == IntentState::Committed { + return Ok(IntentState::Committed); + } + let group_epoch = mls_group.epoch(); - let pending_commit = if let Some(staged_commit) = intent.staged_commit { - decode_staged_commit(staged_commit)? 
-                } else {
-                    return Err(GroupMessageProcessingError::IntentMissingStagedCommit);
-                };
+            let message_epoch = message.epoch();
+            debug!(
+                inbox_id = self.client.inbox_id(),
+                installation_id = hex::encode(self.client.installation_id()),
+                group_id = hex::encode(&self.group_id),
+                msg_id,
+                intent.id,
+                intent.kind = %intent.kind,
+                "[{}]-[{}] processing own message for intent {} / {:?}, message_epoch: {}",
+                self.context().inbox_id(),
+                hex::encode(self.group_id.clone()),
+                intent.id,
+                intent.kind,
+                message_epoch
+            );
 
-            tracing::info!(
-                "[{}] Validating commit for intent {}. Message timestamp: {}",
-                self.context().inbox_id(),
-                intent.id,
-                envelope_timestamp_ns
-            );
+            let conn = provider.conn_ref();
+            match intent.kind {
+                IntentKind::KeyUpdate
+                | IntentKind::UpdateGroupMembership
+                | IntentKind::UpdateAdminList
+                | IntentKind::MetadataUpdate
+                | IntentKind::UpdatePermission => {
+                    if let Some(published_in_epoch) = intent.published_in_epoch {
+                        let published_in_epoch_u64 = published_in_epoch as u64;
+                        let group_epoch_u64 = group_epoch.as_u64();
+
+                        if published_in_epoch_u64 != group_epoch_u64 {
+                            tracing::warn!(
+                                inbox_id = self.client.inbox_id(),
+                                group_id = hex::encode(&self.group_id),
+                                msg_id,
+                                intent.id,
+                                intent.kind = %intent.kind,
+                                "Intent was published in epoch {} but group is currently in epoch {}",
+                                published_in_epoch_u64,
+                                group_epoch_u64
+                            );
+                            return Ok(IntentState::ToPublish);
+                        }
+                    }
 
-            let maybe_validated_commit = ValidatedCommit::from_staged_commit(
-                self.client.as_ref(),
-                conn,
-                &pending_commit,
-                openmls_group,
-            )
-            .await;
+                    let pending_commit = if let Some(staged_commit) = intent.staged_commit {
+                        decode_staged_commit(staged_commit)?
+                    } else {
+                        return Err(GroupMessageProcessingError::IntentMissingStagedCommit);
+                    };
 
-            if let Err(err) = maybe_validated_commit {
-                tracing::error!(
-                    "Error validating commit for own message. Intent ID [{}]: {:?}",
+                    tracing::info!(
+                        "[{}] Validating commit for intent {}. Message timestamp: {}",
+                        self.context().inbox_id(),
                         intent.id,
-                    err
+                        envelope_timestamp_ns
                     );
-                // Return before merging commit since it does not pass validation
-                // Return OK so that the group intent update is still written to the DB
-                return Ok(IntentState::Error);
-            }
 
-            let validated_commit = maybe_validated_commit.expect("Checked for error");
+                    let maybe_validated_commit = ValidatedCommit::from_staged_commit(
+                        self.client.as_ref(),
+                        conn,
+                        &pending_commit,
+                        &mls_group,
+                    )
+                    .await;
 
-            tracing::info!(
-                "[{}] merging pending commit for intent {}",
-                self.context().inbox_id(),
-                intent.id
-            );
-            if let Err(err) = openmls_group.merge_staged_commit(&provider, pending_commit) {
-                tracing::error!("error merging commit: {}", err);
-                return Ok(IntentState::ToPublish);
-            } else {
-                // If no error committing the change, write a transcript message
-                self.save_transcript_message(conn, validated_commit, envelope_timestamp_ns)?;
-            }
-        }
-        IntentKind::SendMessage => {
-            if !Self::is_valid_epoch(
-                self.context().inbox_id(),
-                intent.id,
-                group_epoch,
-                message_epoch,
-                MAX_PAST_EPOCHS,
-            ) {
-                return Ok(IntentState::ToPublish);
+                    if let Err(err) = maybe_validated_commit {
+                        tracing::error!(
+                            "Error validating commit for own message. 
Intent ID [{}]: {:?}", + intent.id, + err + ); + // Return before merging commit since it does not pass validation + // Return OK so that the group intent update is still written to the DB + return Ok(IntentState::Error); + } + + let validated_commit = maybe_validated_commit.expect("Checked for error"); + + tracing::info!( + "[{}] merging pending commit for intent {}", + self.context().inbox_id(), + intent.id + ); + if let Err(err) = mls_group.merge_staged_commit(&provider, pending_commit) { + tracing::error!("error merging commit: {}", err); + return Ok(IntentState::ToPublish); + } else { + // If no error committing the change, write a transcript message + self.save_transcript_message(conn, validated_commit, envelope_timestamp_ns)?; + } } - if let Some(id) = intent.message_id()? { - conn.set_delivery_status_to_published(&id, envelope_timestamp_ns)?; + IntentKind::SendMessage => { + if !Self::is_valid_epoch( + self.context().inbox_id(), + intent.id, + group_epoch, + message_epoch, + MAX_PAST_EPOCHS, + ) { + return Ok(IntentState::ToPublish); + } + if let Some(id) = intent.message_id()? { + conn.set_delivery_status_to_published(&id, envelope_timestamp_ns)?; + } } - } - }; + }; - Ok(IntentState::Committed) + Ok(IntentState::Committed) + }).await } #[tracing::instrument(level = "trace", skip_all)] async fn process_external_message( &self, - openmls_group: &mut OpenMlsGroup, provider: &XmtpOpenMlsProvider, message: PrivateMessageIn, envelope: &GroupMessageV1, ) -> Result<(), GroupMessageProcessingError> { - let GroupMessageV1 { - created_ns: envelope_timestamp_ns, - id: ref msg_id, - .. - } = *envelope; + self.load_mls_group_with_lock_async(provider, |mut mls_group| async move { + let GroupMessageV1 { + created_ns: envelope_timestamp_ns, + id: ref msg_id, + .. 
+ } = *envelope; - let decrypted_message = openmls_group.process_message(provider, message)?; - let (sender_inbox_id, sender_installation_id) = - extract_message_sender(openmls_group, &decrypted_message, envelope_timestamp_ns)?; + let decrypted_message = mls_group.process_message(provider, message)?; + let (sender_inbox_id, sender_installation_id) = + extract_message_sender(&mut mls_group, &decrypted_message, envelope_timestamp_ns)?; - tracing::info!( - inbox_id = self.client.inbox_id(), - sender_inbox_id = sender_inbox_id, - sender_installation_id = hex::encode(&sender_installation_id), - group_id = hex::encode(&self.group_id), - current_epoch = openmls_group.epoch().as_u64(), - msg_epoch = decrypted_message.epoch().as_u64(), - msg_group_id = hex::encode(decrypted_message.group_id().as_slice()), - msg_id, - "[{}] extracted sender inbox id: {}", - self.client.inbox_id(), - sender_inbox_id - ); + tracing::info!( + inbox_id = self.client.inbox_id(), + sender_inbox_id = sender_inbox_id, + sender_installation_id = hex::encode(&sender_installation_id), + group_id = hex::encode(&self.group_id), + current_epoch = mls_group.epoch().as_u64(), + msg_epoch = decrypted_message.epoch().as_u64(), + msg_group_id = hex::encode(decrypted_message.group_id().as_slice()), + msg_id, + "[{}] extracted sender inbox id: {}", + self.client.inbox_id(), + sender_inbox_id + ); - let (msg_epoch, msg_group_id) = ( - decrypted_message.epoch().as_u64(), - hex::encode(decrypted_message.group_id().as_slice()), - ); - match decrypted_message.into_content() { - ProcessedMessageContent::ApplicationMessage(application_message) => { - tracing::info!( - inbox_id = self.client.inbox_id(), - sender_inbox_id = sender_inbox_id, - group_id = hex::encode(&self.group_id), - current_epoch = openmls_group.epoch().as_u64(), - msg_epoch, - msg_group_id, - msg_id, - "[{}] decoding application message", - self.context().inbox_id() - ); - let message_bytes = application_message.into_bytes(); - - let mut bytes = Bytes::from(message_bytes.clone()); - let envelope = PlaintextEnvelope::decode(&mut bytes)?; - - match envelope.content { - Some(Content::V1(V1 { - idempotency_key, - content, - })) => { - let message_id = - calculate_message_id(&self.group_id, &content, &idempotency_key); - StoredGroupMessage { - id: message_id, - group_id: self.group_id.clone(), - decrypted_message_bytes: content, - sent_at_ns: envelope_timestamp_ns as i64, - kind: GroupMessageKind::Application, - sender_installation_id, - sender_inbox_id, - delivery_status: DeliveryStatus::Published, + let (msg_epoch, msg_group_id) = ( + decrypted_message.epoch().as_u64(), + hex::encode(decrypted_message.group_id().as_slice()), + ); + match decrypted_message.into_content() { + ProcessedMessageContent::ApplicationMessage(application_message) => { + tracing::info!( + inbox_id = self.client.inbox_id(), + sender_inbox_id = sender_inbox_id, + group_id = hex::encode(&self.group_id), + current_epoch = mls_group.epoch().as_u64(), + msg_epoch, + msg_group_id, + msg_id, + "[{}] decoding application message", + self.context().inbox_id() + ); + let message_bytes = application_message.into_bytes(); + + let mut bytes = Bytes::from(message_bytes.clone()); + let envelope = PlaintextEnvelope::decode(&mut bytes)?; + + match envelope.content { + Some(Content::V1(V1 { + idempotency_key, + content, + })) => { + let message_id = + calculate_message_id(&self.group_id, &content, &idempotency_key); + StoredGroupMessage { + id: message_id, + group_id: self.group_id.clone(), + decrypted_message_bytes: 
content, + sent_at_ns: envelope_timestamp_ns as i64, + kind: GroupMessageKind::Application, + sender_installation_id, + sender_inbox_id, + delivery_status: DeliveryStatus::Published, + } + .store_or_ignore(provider.conn_ref())? } - .store_or_ignore(provider.conn_ref())? - } - Some(Content::V2(V2 { - idempotency_key, - message_type, - })) => { - match message_type { - Some(MessageType::DeviceSyncRequest(history_request)) => { - let content: DeviceSyncContent = - DeviceSyncContent::Request(history_request); - let content_bytes = serde_json::to_vec(&content)?; - let message_id = calculate_message_id( - &self.group_id, - &content_bytes, - &idempotency_key, - ); - - // store the request message - StoredGroupMessage { - id: message_id.clone(), - group_id: self.group_id.clone(), - decrypted_message_bytes: content_bytes, - sent_at_ns: envelope_timestamp_ns as i64, - kind: GroupMessageKind::Application, - sender_installation_id, - sender_inbox_id: sender_inbox_id.clone(), - delivery_status: DeliveryStatus::Published, + Some(Content::V2(V2 { + idempotency_key, + message_type, + })) => { + match message_type { + Some(MessageType::DeviceSyncRequest(history_request)) => { + let content: DeviceSyncContent = + DeviceSyncContent::Request(history_request); + let content_bytes = serde_json::to_vec(&content)?; + let message_id = calculate_message_id( + &self.group_id, + &content_bytes, + &idempotency_key, + ); + + // store the request message + StoredGroupMessage { + id: message_id.clone(), + group_id: self.group_id.clone(), + decrypted_message_bytes: content_bytes, + sent_at_ns: envelope_timestamp_ns as i64, + kind: GroupMessageKind::Application, + sender_installation_id, + sender_inbox_id: sender_inbox_id.clone(), + delivery_status: DeliveryStatus::Published, + } + .store_or_ignore(provider.conn_ref())?; + + tracing::info!("Received a history request."); + let _ = self.client.local_events().send(LocalEvents::SyncMessage( + SyncMessage::Request { message_id }, + )); } - .store_or_ignore(provider.conn_ref())?; - tracing::info!("Received a history request."); - let _ = self.client.local_events().send(LocalEvents::SyncMessage( - SyncMessage::Request { message_id }, - )); - } - - Some(MessageType::DeviceSyncReply(history_reply)) => { - let content: DeviceSyncContent = - DeviceSyncContent::Reply(history_reply); - let content_bytes = serde_json::to_vec(&content)?; - let message_id = calculate_message_id( - &self.group_id, - &content_bytes, - &idempotency_key, - ); - - // store the reply message - StoredGroupMessage { - id: message_id.clone(), - group_id: self.group_id.clone(), - decrypted_message_bytes: content_bytes, - sent_at_ns: envelope_timestamp_ns as i64, - kind: GroupMessageKind::Application, - sender_installation_id, - sender_inbox_id, - delivery_status: DeliveryStatus::Published, + Some(MessageType::DeviceSyncReply(history_reply)) => { + let content: DeviceSyncContent = + DeviceSyncContent::Reply(history_reply); + let content_bytes = serde_json::to_vec(&content)?; + let message_id = calculate_message_id( + &self.group_id, + &content_bytes, + &idempotency_key, + ); + + // store the reply message + StoredGroupMessage { + id: message_id.clone(), + group_id: self.group_id.clone(), + decrypted_message_bytes: content_bytes, + sent_at_ns: envelope_timestamp_ns as i64, + kind: GroupMessageKind::Application, + sender_installation_id, + sender_inbox_id, + delivery_status: DeliveryStatus::Published, + } + .store_or_ignore(provider.conn_ref())?; + + tracing::info!("Received a history reply."); + let _ = 
self.client.local_events().send(LocalEvents::SyncMessage( + SyncMessage::Reply { message_id }, + )); } - .store_or_ignore(provider.conn_ref())?; - - tracing::info!("Received a history reply."); - let _ = self.client.local_events().send(LocalEvents::SyncMessage( - SyncMessage::Reply { message_id }, - )); - } - Some(MessageType::UserPreferenceUpdate(update)) => { - // Ignore errors since this may come from a newer version of the lib - // that has new update types. - if let Ok(update) = update.try_into() { - let _ = self - .client - .local_events() - .send(LocalEvents::IncomingPreferenceUpdate(vec![update])); - } else { - tracing::warn!("Failed to deserialize preference update. Is this libxmtp version old?"); + Some(MessageType::UserPreferenceUpdate(update)) => { + // Ignore errors since this may come from a newer version of the lib + // that has new update types. + if let Ok(update) = update.try_into() { + let _ = self + .client + .local_events() + .send(LocalEvents::IncomingPreferenceUpdate(vec![update])); + } else { + tracing::warn!("Failed to deserialize preference update. Is this libxmtp version old?"); + } + } + _ => { + return Err(GroupMessageProcessingError::InvalidPayload); } - } - _ => { - return Err(GroupMessageProcessingError::InvalidPayload); } } + None => return Err(GroupMessageProcessingError::InvalidPayload), } - None => return Err(GroupMessageProcessingError::InvalidPayload), } - } - ProcessedMessageContent::ProposalMessage(_proposal_ptr) => { - // intentionally left blank. - } - ProcessedMessageContent::ExternalJoinProposalMessage(_external_proposal_ptr) => { - // intentionally left blank. - } - ProcessedMessageContent::StagedCommitMessage(staged_commit) => { - tracing::info!( - inbox_id = self.client.inbox_id(), - sender_inbox_id = sender_inbox_id, - sender_installation_id = hex::encode(&sender_installation_id), - group_id = hex::encode(&self.group_id), - current_epoch = openmls_group.epoch().as_u64(), - msg_epoch, - msg_group_id, - msg_id, - "[{}] received staged commit. Merging and clearing any pending commits", - self.context().inbox_id() - ); + ProcessedMessageContent::ProposalMessage(_proposal_ptr) => { + // intentionally left blank. + } + ProcessedMessageContent::ExternalJoinProposalMessage(_external_proposal_ptr) => { + // intentionally left blank. + } + ProcessedMessageContent::StagedCommitMessage(staged_commit) => { + tracing::info!( + inbox_id = self.client.inbox_id(), + sender_inbox_id = sender_inbox_id, + sender_installation_id = hex::encode(&sender_installation_id), + group_id = hex::encode(&self.group_id), + current_epoch = mls_group.epoch().as_u64(), + msg_epoch, + msg_group_id, + msg_id, + "[{}] received staged commit. 
Merging and clearing any pending commits", + self.context().inbox_id() + ); - let sc = *staged_commit; + let sc = *staged_commit; - // Validate the commit - let validated_commit = ValidatedCommit::from_staged_commit( - self.client.as_ref(), - provider.conn_ref(), - &sc, - openmls_group, - ) - .await?; - tracing::info!( - inbox_id = self.client.inbox_id(), - sender_inbox_id = sender_inbox_id, - sender_installation_id = hex::encode(&sender_installation_id), - group_id = hex::encode(&self.group_id), - current_epoch = openmls_group.epoch().as_u64(), - msg_epoch, - msg_group_id, - msg_id, - "[{}] staged commit is valid, will attempt to merge", - self.context().inbox_id() - ); - openmls_group.merge_staged_commit(provider, sc)?; - self.save_transcript_message( - provider.conn_ref(), - validated_commit, - envelope_timestamp_ns, - )?; - } - }; + // Validate the commit + let validated_commit = ValidatedCommit::from_staged_commit( + self.client.as_ref(), + provider.conn_ref(), + &sc, + &mls_group, + ) + .await?; + tracing::info!( + inbox_id = self.client.inbox_id(), + sender_inbox_id = sender_inbox_id, + sender_installation_id = hex::encode(&sender_installation_id), + group_id = hex::encode(&self.group_id), + current_epoch = mls_group.epoch().as_u64(), + msg_epoch, + msg_group_id, + msg_id, + "[{}] staged commit is valid, will attempt to merge", + self.context().inbox_id() + ); + mls_group.merge_staged_commit(provider, sc)?; + self.save_transcript_message( + provider.conn_ref(), + validated_commit, + envelope_timestamp_ns, + )?; + } + }; - Ok(()) + Ok(()) + }).await } #[tracing::instrument(level = "trace", skip_all)] pub(super) async fn process_message( &self, - openmls_group: &mut OpenMlsGroup, provider: &XmtpOpenMlsProvider, envelope: &GroupMessageV1, allow_epoch_increment: bool, @@ -703,63 +702,34 @@ where let intent = provider .conn_ref() .find_group_intent_by_payload_hash(sha256(envelope.data.as_slice())); - tracing::info!( - inbox_id = self.client.inbox_id(), - group_id = hex::encode(&self.group_id), - current_epoch = openmls_group.epoch().as_u64(), - msg_id = envelope.id, - "Processing envelope with hash {:?}", - hex::encode(sha256(envelope.data.as_slice())) - ); - match intent { // Intent with the payload hash matches Ok(Some(intent)) => { let intent_id = intent.id; - tracing::info!( - inbox_id = self.client.inbox_id(), - group_id = hex::encode(&self.group_id), - current_epoch = openmls_group.epoch().as_u64(), - msg_id = envelope.id, - intent_id, - intent.kind = %intent.kind, - "client [{}] is about to process own envelope [{}] for intent [{}]", - self.client.inbox_id(), - envelope.id, - intent_id - ); - match self - .process_own_message(intent, openmls_group, provider, message.into(), envelope) - .await? - { - IntentState::ToPublish => { - Ok(provider.conn_ref().set_group_intent_to_publish(intent_id)?) - } - IntentState::Committed => { - Ok(provider.conn_ref().set_group_intent_committed(intent_id)?) - } - IntentState::Published => { - tracing::error!("Unexpected behaviour: returned intent state published from process_own_message"); - Ok(()) - } - IntentState::Error => { - tracing::warn!("Intent [{}] moved to error status", intent_id); - Ok(provider.conn_ref().set_group_intent_error(intent_id)?) + match self + .process_own_message(intent, provider, message.into(), envelope) + .await? + { + IntentState::ToPublish => { + Ok(provider.conn_ref().set_group_intent_to_publish(intent_id)?) + } + IntentState::Committed => { + Ok(provider.conn_ref().set_group_intent_committed(intent_id)?) 
+ } + IntentState::Published => { + tracing::error!("Unexpected behaviour: returned intent state published from process_own_message"); + Ok(()) + } + IntentState::Error => { + tracing::warn!("Intent [{}] moved to error status", intent_id); + Ok(provider.conn_ref().set_group_intent_error(intent_id)?) + } } - } + } // No matching intent found Ok(None) => { - tracing::info!( - inbox_id = self.client.inbox_id(), - group_id = hex::encode(&self.group_id), - current_epoch = openmls_group.epoch().as_u64(), - msg_id = envelope.id, - "client [{}] is about to process external envelope [{}]", - self.client.inbox_id(), - envelope.id - ); - self.process_external_message(openmls_group, provider, message, envelope) + self.process_external_message(provider, message, envelope) .await } Err(err) => Err(GroupMessageProcessingError::Storage(err)), @@ -770,7 +740,6 @@ where async fn consume_message( &self, envelope: &GroupMessage, - openmls_group: &mut OpenMlsGroup, conn: &DbConnection, ) -> Result<(), GroupMessageProcessingError> { let msgv1 = match &envelope.version { @@ -805,8 +774,7 @@ where EntityKind::Group, msgv1.id, |provider| async move { - self.process_message(openmls_group, &provider, msgv1, true) - .await?; + self.process_message(&provider, msgv1, true).await?; Ok::<(), GroupMessageProcessingError>(()) }, ) @@ -821,16 +789,11 @@ where messages: Vec, provider: &XmtpOpenMlsProvider, ) -> Result<(), GroupError> { - let mut openmls_group = self.load_mls_group(provider)?; - let mut receive_errors: Vec = vec![]; for message in messages.into_iter() { let result = retry_async!( Retry::default(), - (async { - self.consume_message(&message, &mut openmls_group, provider.conn_ref()) - .await - }) + (async { self.consume_message(&message, provider.conn_ref()).await }) ); if let Err(e) = result { let is_retryable = e.is_retryable(); @@ -1172,6 +1135,10 @@ where provider: &XmtpOpenMlsProvider, update_interval_ns: Option, ) -> Result<(), GroupError> { + // Check if we're still part of the group + if !self.is_active(provider)? 
{ + return Ok(()); + } // determine how long of an interval in time to use before updating list let interval_ns = match update_interval_ns { Some(val) => val, diff --git a/xmtp_mls/src/groups/mod.rs b/xmtp_mls/src/groups/mod.rs index ad9d92d34..48d90f168 100644 --- a/xmtp_mls/src/groups/mod.rs +++ b/xmtp_mls/src/groups/mod.rs @@ -57,6 +57,7 @@ use self::{ group_permissions::PolicySet, validated_commit::CommitValidationError, }; +use std::future::Future; use std::{collections::HashSet, sync::Arc}; use xmtp_cryptography::signature::{sanitize_evm_addresses, AddressValidationError}; use xmtp_id::{InboxId, InboxIdRef}; @@ -79,6 +80,7 @@ use crate::{ MAX_PAST_EPOCHS, MUTABLE_METADATA_EXTENSION_ID, SEND_MESSAGE_UPDATE_INSTALLATIONS_INTERVAL_NS, }, + groups, hpke::{decrypt_welcome, HpkeError}, identity::{parse_credential, IdentityError}, identity_updates::{load_identity_updates, InstallationDiffError}, @@ -96,8 +98,10 @@ use crate::{ subscriptions::{LocalEventError, LocalEvents}, utils::{id::calculate_message_id, time::now_ns}, xmtp_openmls_provider::XmtpOpenMlsProvider, - Store, + GroupCommitLock, Store, MLS_COMMIT_LOCK, }; +use crate::hpke::HpkeError::StorageError; +use crate::storage::sql_key_store::SqlKeyStoreError; #[derive(Debug, Error)] pub enum GroupError { @@ -202,6 +206,8 @@ pub enum GroupError { IntentNotCommitted, #[error(transparent)] ProcessIntent(#[from] ProcessIntentError), + #[error("Failed to acquire lock for group operation")] + LockUnavailable, } impl RetryableError for GroupError { @@ -228,6 +234,7 @@ impl RetryableError for GroupError { Self::MessageHistory(err) => err.is_retryable(), Self::ProcessIntent(err) => err.is_retryable(), Self::LocalEvent(err) => err.is_retryable(), + Self::LockUnavailable => true, Self::SyncFailedToWait => true, Self::GroupNotFound | Self::GroupMetadata(_) @@ -337,6 +344,59 @@ impl MlsGroup { Ok(mls_group) } + #[tracing::instrument(level = "trace", skip_all)] + pub(crate) fn load_mls_group_with_lock( + &self, + provider: impl OpenMlsProvider, + operation: F, + ) -> Result + where + F: FnOnce(OpenMlsGroup) -> Result, + { + // Get the group ID for locking + let group_id = self.group_id.clone(); + + // Acquire the lock synchronously using blocking_lock + + let _lock = MLS_COMMIT_LOCK.get_lock_sync(group_id.clone())?; + // Load the MLS group + let mls_group = + OpenMlsGroup::load(provider.storage(), &GroupId::from_slice(&self.group_id)) + .map_err(|_| GroupError::GroupNotFound)? + .ok_or(GroupError::GroupNotFound)?; + + // Perform the operation with the MLS group + operation(mls_group) + } + + #[tracing::instrument(level = "trace", skip_all)] + pub(crate) async fn load_mls_group_with_lock_async( + &self, + provider: impl OpenMlsProvider, + operation: F, + ) -> Result + where + F: FnOnce(OpenMlsGroup) -> Fut + Send + 'static, + Fut: Future>, + E: From, + E: Into, + E: From + { + // Get the group ID for locking + let group_id = self.group_id.clone(); + + // Acquire the lock asynchronously + let _lock = MLS_COMMIT_LOCK.get_lock_async(group_id.clone()).await; + + // Load the MLS group + let mls_group = + OpenMlsGroup::load(provider.storage(), &GroupId::from_slice(&self.group_id))? + .ok_or(|e|GroupMessageProcessingError::Storage)?; + + // Perform the operation with the MLS group + operation(mls_group).await.map_err(Into::into) + } + // Create a new group and save it to the DB pub fn create_and_insert( client: Arc, @@ -1109,12 +1169,11 @@ impl MlsGroup { .await } - /// Checks if the the current user is active in the group. 
+    /// Checks if the current user is active in the group.
     ///
     /// If the current user has been kicked out of the group, `is_active` will return `false`
     pub fn is_active(&self, provider: impl OpenMlsProvider) -> Result<bool, GroupError> {
-        let mls_group = self.load_mls_group(provider)?;
-        Ok(mls_group.is_active())
+        self.load_mls_group_with_lock(provider, |mls_group| Ok(mls_group.is_active()))
     }
 
     /// Get the `GroupMetadata` of the group.
     pub fn metadata(&self, provider: impl OpenMlsProvider) -> Result<GroupMetadata, GroupError> {
-        let mls_group = self.load_mls_group(provider)?;
-        Ok(extract_group_metadata(&mls_group)?)
+        self.load_mls_group_with_lock(provider, |mls_group| {
+            Ok(extract_group_metadata(&mls_group)?)
+        })
     }
 
     /// Get the `GroupMutableMetadata` of the group.
@@ -1187,17 +1188,18 @@ impl MlsGroup {
         &self,
         provider: impl OpenMlsProvider,
     ) -> Result<GroupMutableMetadata, GroupError> {
-        let mls_group = &self.load_mls_group(provider)?;
-
-        Ok(mls_group.try_into()?)
+        self.load_mls_group_with_lock(provider, |mls_group| {
+            Ok(GroupMutableMetadata::try_from(&mls_group)?)
+        })
     }
 
     pub fn permissions(&self) -> Result<GroupMutablePermissions, GroupError> {
         let conn = self.context().store().conn()?;
         let provider = XmtpOpenMlsProvider::new(conn);
 
-        let mls_group = self.load_mls_group(&provider)?;
-        Ok(extract_group_permissions(&mls_group)?)
+        self.load_mls_group_with_lock(&provider, |mls_group| {
+            Ok(extract_group_permissions(&mls_group)?)
+        })
     }
 
     /// Used for testing that dm group validation works as expected.
@@ -3677,9 +3736,8 @@ pub(crate) mod tests {
             panic!("wrong message format")
         };
         let provider = client.mls_provider().unwrap();
-        let mut openmls_group = group.load_mls_group(&provider).unwrap();
         let process_result = group
-            .process_message(&mut openmls_group, &provider, &first_message, false)
+            .process_message(&provider, &first_message, false)
             .await;
 
         assert_err!(
diff --git a/xmtp_mls/src/groups/subscriptions.rs b/xmtp_mls/src/groups/subscriptions.rs
index 2caf9bec9..8e76d437c 100644
--- a/xmtp_mls/src/groups/subscriptions.rs
+++ b/xmtp_mls/src/groups/subscriptions.rs
@@ -41,31 +41,22 @@ impl MlsGroup {
         let process_result = retry_async!(
             Retry::default(),
             (async {
-                let client_id = &client_id;
                 let msgv1 = &msgv1;
                 self.context()
                     .store()
                     .transaction_async(|provider| async move {
-                        let mut openmls_group = self.load_mls_group(&provider)?;
-
-                        // Attempt processing immediately, but fail if the message is not an Application Message
-                        // Returning an error should roll back the DB tx
-                        tracing::info!(
-                            inbox_id = self.client.inbox_id(),
-                            group_id = hex::encode(&self.group_id),
-                            current_epoch = openmls_group.epoch().as_u64(),
-                            msg_id = msgv1.id,
-                            "current epoch for [{}] in process_stream_entry() is Epoch: [{}]",
-                            client_id,
-                            openmls_group.epoch()
-                        );
-
-                        self.process_message(&mut openmls_group, &provider, msgv1, false)
-                            .await
-                            // NOTE: We want to make sure we retry an error in process_message
-                            .map_err(SubscribeError::ReceiveGroup)
+                        let prov_ref = &provider; // Borrow provider instead of moving it
+                        self.load_mls_group_with_lock_async(prov_ref, |mut mls_group| async move {
+                            // Attempt processing immediately, but fail if the message is not an Application Message
+                            // Returning an error should roll back the DB tx
+                            self.process_message(&mut mls_group, &prov_ref, msgv1, false)
+                                .await
+                                // NOTE: We want to make sure we retry an error in process_message
+                                .map_err(SubscribeError::ReceiveGroup)
+                        }).await
                     })
                     .await
+
             })
         );
 
diff --git a/xmtp_mls/src/lib.rs b/xmtp_mls/src/lib.rs
index cf0a08702..84c9279a6 100644
--- a/xmtp_mls/src/lib.rs
+++ b/xmtp_mls/src/lib.rs
@@ -22,15 +22,76 @@ pub mod utils;
 pub mod verified_key_package_v2;
 mod xmtp_openmls_provider;
 
+use std::collections::HashMap;
+use std::sync::{Arc, Mutex};
+use tokio::sync::{Semaphore, OwnedSemaphorePermit};
 pub use client::{Client, Network};
 use storage::{DuplicateItem, StorageError};
 pub use xmtp_openmls_provider::XmtpOpenMlsProvider;
+use std::sync::LazyLock;
 
 pub use xmtp_id::InboxOwner;
 pub use xmtp_proto::api_client::trait_impls::*;
 
 #[macro_use]
 extern crate tracing;
 
+/// A manager for group-specific semaphores
+#[derive(Debug)]
+pub struct GroupCommitLock {
+    // Storage for group-specific semaphores
+    locks: Mutex<HashMap<Vec<u8>, Arc<Semaphore>>>,
+}
+
+impl GroupCommitLock {
+    /// Create a new `GroupCommitLock`
+    pub fn new() -> Self {
+        Self {
+            locks: Mutex::new(HashMap::new()),
+        }
+    }
+
+    /// Get or create a semaphore for a specific group and acquire it, returning a guard
+    pub async fn get_lock_async(&self, group_id: Vec<u8>) -> SemaphoreGuard {
+        let semaphore = {
+            let mut locks = self.locks.lock().unwrap();
+            locks
+                .entry(group_id)
+                .or_insert_with(|| Arc::new(Semaphore::new(1)))
+                .clone()
+        };
+
+        let semaphore_clone = semaphore.clone();
+        let permit = semaphore.acquire_owned().await.unwrap();
+        SemaphoreGuard { _permit: permit, _semaphore: semaphore_clone }
+    }
+
+    /// Get or create a semaphore for a specific group and acquire it synchronously
+    pub fn get_lock_sync(&self, group_id: Vec<u8>) -> Result<SemaphoreGuard, GroupError> {
+        let semaphore = {
+            let mut locks = self.locks.lock().unwrap();
+            locks
+                .entry(group_id)
+                .or_insert_with(|| Arc::new(Semaphore::new(1)))
+                .clone() // Clone here to retain ownership for later use
+        };
+
+        // Synchronously acquire the permit
+        let permit = semaphore.clone().try_acquire_owned().map_err(|_| GroupError::LockUnavailable)?;
+        Ok(SemaphoreGuard {
+            _permit: permit,
+            _semaphore: semaphore, // semaphore is now valid because we cloned it earlier
+        })
+    }
+}
+
+/// A guard that releases the semaphore when dropped
+pub struct SemaphoreGuard {
+    _permit: OwnedSemaphorePermit,
+    _semaphore: Arc<Semaphore>,
+}
+
+// Static instance of `GroupCommitLock`
+pub static MLS_COMMIT_LOCK: LazyLock<GroupCommitLock> = LazyLock::new(GroupCommitLock::new);
 
 /// Global Marker trait for WebAssembly
 #[cfg(target_arch = "wasm32")]
@@ -79,6 +140,7 @@ pub trait Delete {
 pub use stream_handles::{
     spawn, AbortHandle, GenericStreamHandle, StreamHandle, StreamHandleError,
 };
+use crate::groups::GroupError;
 
 #[cfg(target_arch = "wasm32")]
 #[doc(hidden)]

From 653f6c9b2cc17289b8dcb0cc228a030ec9174043 Mon Sep 17 00:00:00 2001
From: Mojtaba Chenani
Date: Mon, 9 Dec 2024 18:53:28 +0100
Subject: [PATCH 02/28] wip

---
 xmtp_mls/src/groups/members.rs       |   7 +-
 xmtp_mls/src/groups/mls_sync.rs      | 269 +++++++++++++--------------
 xmtp_mls/src/groups/mod.rs           | 134 +++++++------
 xmtp_mls/src/groups/subscriptions.rs |  21 ++-
 xmtp_mls/src/storage/errors.rs       |   5 +-
 5 files changed, 228 insertions(+), 208 deletions(-)

diff --git a/xmtp_mls/src/groups/members.rs b/xmtp_mls/src/groups/members.rs
index 5ca53a40a..730529fb3 100644
--- a/xmtp_mls/src/groups/members.rs
+++ b/xmtp_mls/src/groups/members.rs
@@ -40,9 +40,10 @@ where
         &self,
         provider: &XmtpOpenMlsProvider,
     ) -> Result<Vec<GroupMember>, GroupError> {
-        let openmls_group = self.load_mls_group(provider)?;
-        // TODO: Replace with try_into from extensions
-        let group_membership = extract_group_membership(openmls_group.extensions())?;
+        let group_membership = self.load_mls_group_with_lock(provider, |mls_group| {
+            // Extract group membership from extensions
+            Ok(extract_group_membership(mls_group.extensions())?)
+ })?; let requests = group_membership .members .into_iter() diff --git a/xmtp_mls/src/groups/mls_sync.rs b/xmtp_mls/src/groups/mls_sync.rs index dc6c2048d..1a1f50b2d 100644 --- a/xmtp_mls/src/groups/mls_sync.rs +++ b/xmtp_mls/src/groups/mls_sync.rs @@ -754,7 +754,6 @@ where }; let last_cursor = conn.get_last_cursor_for_id(&self.group_id, message_entity_kind)?; - tracing::info!("### last cursor --> [{:?}]", last_cursor); let should_skip_message = last_cursor > msgv1.id as i64; if should_skip_message { tracing::info!( @@ -882,104 +881,104 @@ where &self, provider: &XmtpOpenMlsProvider, ) -> Result<(), GroupError> { - let mut openmls_group = self.load_mls_group(provider)?; + self.load_mls_group_with_lock_async(provider, |mut mls_group| async move { + let intents = provider.conn_ref().find_group_intents( + self.group_id.clone(), + Some(vec![IntentState::ToPublish]), + None, + )?; + + for intent in intents { + let result = retry_async!( + Retry::default(), + (async { + self.get_publish_intent_data(provider, &mut mls_group, &intent) + .await + }) + ); - let intents = provider.conn_ref().find_group_intents( - self.group_id.clone(), - Some(vec![IntentState::ToPublish]), - None, - )?; + match result { + Err(err) => { + tracing::error!(error = %err, "error getting publish intent data {:?}", err); + if (intent.publish_attempts + 1) as usize >= MAX_INTENT_PUBLISH_ATTEMPTS { + tracing::error!( + intent.id, + intent.kind = %intent.kind, + inbox_id = self.client.inbox_id(), + group_id = hex::encode(&self.group_id), + "intent {} has reached max publish attempts", intent.id); + // TODO: Eventually clean up errored attempts + provider + .conn_ref() + .set_group_intent_error_and_fail_msg(&intent)?; + } else { + provider + .conn_ref() + .increment_intent_publish_attempt_count(intent.id)?; + } - for intent in intents { - let result = retry_async!( - Retry::default(), - (async { - self.get_publish_intent_data(provider, &mut openmls_group, &intent) - .await - }) - ); + return Err(err); + } + Ok(Some(PublishIntentData { + payload_to_publish, + post_commit_action, + staged_commit, + })) => { + let payload_slice = payload_to_publish.as_slice(); + let has_staged_commit = staged_commit.is_some(); + provider.conn_ref().set_group_intent_published( + intent.id, + sha256(payload_slice), + post_commit_action, + staged_commit, + mls_group.epoch().as_u64() as i64, + )?; + tracing::debug!( + inbox_id = self.client.inbox_id(), + installation_id = hex::encode(self.client.installation_id()), + intent.id, + intent.kind = %intent.kind, + group_id = hex::encode(&self.group_id), + "client [{}] set stored intent [{}] to state `published`", + self.client.inbox_id(), + intent.id + ); - match result { - Err(err) => { - tracing::error!(error = %err, "error getting publish intent data {:?}", err); - if (intent.publish_attempts + 1) as usize >= MAX_INTENT_PUBLISH_ATTEMPTS { - tracing::error!( + self.client + .api() + .send_group_messages(vec![payload_slice]) + .await?; + + tracing::info!( intent.id, intent.kind = %intent.kind, inbox_id = self.client.inbox_id(), + installation_id = hex::encode(self.client.installation_id()), group_id = hex::encode(&self.group_id), - "intent {} has reached max publish attempts", intent.id); - // TODO: Eventually clean up errored attempts - provider - .conn_ref() - .set_group_intent_error_and_fail_msg(&intent)?; - } else { - provider - .conn_ref() - .increment_intent_publish_attempt_count(intent.id)?; + "[{}] published intent [{}] of type [{}]", + self.client.inbox_id(), + intent.id, + intent.kind + ); + if 
has_staged_commit { + tracing::info!("Commit sent. Stopping further publishes for this round"); + return Ok(()); + } } - - return Err(err); - } - Ok(Some(PublishIntentData { - payload_to_publish, - post_commit_action, - staged_commit, - })) => { - let payload_slice = payload_to_publish.as_slice(); - let has_staged_commit = staged_commit.is_some(); - provider.conn_ref().set_group_intent_published( - intent.id, - sha256(payload_slice), - post_commit_action, - staged_commit, - openmls_group.epoch().as_u64() as i64, - )?; - tracing::debug!( - inbox_id = self.client.inbox_id(), - installation_id = hex::encode(self.client.installation_id()), - intent.id, - intent.kind = %intent.kind, - group_id = hex::encode(&self.group_id), - "client [{}] set stored intent [{}] to state `published`", - self.client.inbox_id(), - intent.id - ); - - self.client - .api() - .send_group_messages(vec![payload_slice]) - .await?; - - tracing::info!( - intent.id, - intent.kind = %intent.kind, - inbox_id = self.client.inbox_id(), - installation_id = hex::encode(self.client.installation_id()), - group_id = hex::encode(&self.group_id), - "[{}] published intent [{}] of type [{}]", - self.client.inbox_id(), - intent.id, - intent.kind - ); - if has_staged_commit { - tracing::info!("Commit sent. Stopping further publishes for this round"); - return Ok(()); + Ok(None) => { + tracing::info!( + inbox_id = self.client.inbox_id(), + installation_id = hex::encode(self.client.installation_id()), + "Skipping intent because no publish data returned" + ); + let deleter: &dyn Delete = provider.conn_ref(); + deleter.delete(intent.id)?; } } - Ok(None) => { - tracing::info!( - inbox_id = self.client.inbox_id(), - installation_id = hex::encode(self.client.installation_id()), - "Skipping intent because no publish data returned" - ); - let deleter: &dyn Delete = provider.conn_ref(); - deleter.delete(intent.id)?; - } } - } - Ok(()) + Ok(()) + }).await } // Takes a StoredGroupIntent and returns the payload and post commit data as a tuple @@ -1210,58 +1209,58 @@ where inbox_ids_to_add: &[InboxIdRef<'_>], inbox_ids_to_remove: &[InboxIdRef<'_>], ) -> Result { - let mls_group = self.load_mls_group(provider)?; - let existing_group_membership = extract_group_membership(mls_group.extensions())?; - - // TODO:nm prevent querying for updates on members who are being removed - let mut inbox_ids = existing_group_membership.inbox_ids(); - inbox_ids.extend_from_slice(inbox_ids_to_add); - let conn = provider.conn_ref(); - // Load any missing updates from the network - load_identity_updates(self.client.api(), conn, &inbox_ids).await?; - - let latest_sequence_id_map = conn.get_latest_sequence_id(&inbox_ids as &[&str])?; - - // Get a list of all inbox IDs that have increased sequence_id for the group - let changed_inbox_ids = - inbox_ids - .iter() - .try_fold(HashMap::new(), |mut updates, inbox_id| { - match ( - latest_sequence_id_map.get(inbox_id as &str), - existing_group_membership.get(inbox_id), - ) { - // This is an update. 
We have a new sequence ID and an existing one - (Some(latest_sequence_id), Some(current_sequence_id)) => { - let latest_sequence_id_u64 = *latest_sequence_id as u64; - if latest_sequence_id_u64.gt(current_sequence_id) { - updates.insert(inbox_id.to_string(), latest_sequence_id_u64); + self.load_mls_group_with_lock_async(provider, |mut mls_group| async move { + let existing_group_membership = extract_group_membership(mls_group.extensions())?; + // TODO:nm prevent querying for updates on members who are being removed + let mut inbox_ids = existing_group_membership.inbox_ids(); + inbox_ids.extend_from_slice(inbox_ids_to_add); + let conn = provider.conn_ref(); + // Load any missing updates from the network + load_identity_updates(self.client.api(), conn, &inbox_ids).await?; + + let latest_sequence_id_map = conn.get_latest_sequence_id(&inbox_ids as &[&str])?; + + // Get a list of all inbox IDs that have increased sequence_id for the group + let changed_inbox_ids = + inbox_ids + .iter() + .try_fold(HashMap::new(), |mut updates, inbox_id| { + match ( + latest_sequence_id_map.get(inbox_id as &str), + existing_group_membership.get(inbox_id), + ) { + // This is an update. We have a new sequence ID and an existing one + (Some(latest_sequence_id), Some(current_sequence_id)) => { + let latest_sequence_id_u64 = *latest_sequence_id as u64; + if latest_sequence_id_u64.gt(current_sequence_id) { + updates.insert(inbox_id.to_string(), latest_sequence_id_u64); + } + } + // This is for new additions to the group + (Some(latest_sequence_id), _) => { + // This is the case for net new members to the group + updates.insert(inbox_id.to_string(), *latest_sequence_id as u64); + } + (_, _) => { + tracing::warn!( + "Could not find existing sequence ID for inbox {}", + inbox_id + ); + return Err(GroupError::MissingSequenceId); } } - // This is for new additions to the group - (Some(latest_sequence_id), _) => { - // This is the case for net new members to the group - updates.insert(inbox_id.to_string(), *latest_sequence_id as u64); - } - (_, _) => { - tracing::warn!( - "Could not find existing sequence ID for inbox {}", - inbox_id - ); - return Err(GroupError::MissingSequenceId); - } - } - Ok(updates) - })?; + Ok(updates) + })?; - Ok(UpdateGroupMembershipIntentData::new( - changed_inbox_ids, - inbox_ids_to_remove - .iter() - .map(|s| s.to_string()) - .collect::>(), - )) + Ok(UpdateGroupMembershipIntentData::new( + changed_inbox_ids, + inbox_ids_to_remove + .iter() + .map(|s| s.to_string()) + .collect::>(), + )) + }).await } /** diff --git a/xmtp_mls/src/groups/mod.rs b/xmtp_mls/src/groups/mod.rs index 48d90f168..285832b89 100644 --- a/xmtp_mls/src/groups/mod.rs +++ b/xmtp_mls/src/groups/mod.rs @@ -57,6 +57,7 @@ use self::{ group_permissions::PolicySet, validated_commit::CommitValidationError, }; +use futures::TryFutureExt; use std::future::Future; use std::{collections::HashSet, sync::Arc}; use xmtp_cryptography::signature::{sanitize_evm_addresses, AddressValidationError}; @@ -72,6 +73,8 @@ use xmtp_proto::xmtp::mls::{ }, }; +use crate::hpke::HpkeError::StorageError; +use crate::storage::sql_key_store::SqlKeyStoreError; use crate::{ api::WrappedApiError, client::{deserialize_welcome, ClientError, XmtpMlsLocalContext}, @@ -100,8 +103,6 @@ use crate::{ xmtp_openmls_provider::XmtpOpenMlsProvider, GroupCommitLock, Store, MLS_COMMIT_LOCK, }; -use crate::hpke::HpkeError::StorageError; -use crate::storage::sql_key_store::SqlKeyStoreError; #[derive(Debug, Error)] pub enum GroupError { @@ -370,17 +371,15 @@ impl MlsGroup { } 
#[tracing::instrument(level = "trace", skip_all)] - pub(crate) async fn load_mls_group_with_lock_async( + pub(crate) async fn load_mls_group_with_lock_async( &self, - provider: impl OpenMlsProvider, + provider: &XmtpOpenMlsProvider, operation: F, ) -> Result where - F: FnOnce(OpenMlsGroup) -> Fut + Send + 'static, - Fut: Future>, - E: From, - E: Into, - E: From + F: FnOnce(OpenMlsGroup) -> Fut + Send, + Fut: Future>, + E: From + From, { // Get the group ID for locking let group_id = self.group_id.clone(); @@ -390,8 +389,9 @@ impl MlsGroup { // Load the MLS group let mls_group = - OpenMlsGroup::load(provider.storage(), &GroupId::from_slice(&self.group_id))? - .ok_or(|e|GroupMessageProcessingError::Storage)?; + OpenMlsGroup::load(provider.storage(), &GroupId::from_slice(&self.group_id)) + .map_err(crate::StorageError::from)? + .ok_or(crate::StorageError::NotFound("Group Not Found".into()))?; // Perform the operation with the MLS group operation(mls_group).await.map_err(Into::into) @@ -1178,8 +1178,9 @@ impl MlsGroup { /// Get the `GroupMetadata` of the group. pub fn metadata(&self, provider: impl OpenMlsProvider) -> Result { - let mls_group = self.load_mls_group(provider)?; - Ok(extract_group_metadata(&mls_group)?) + self.load_mls_group_with_lock(provider, |mls_group| { + Ok(extract_group_metadata(&mls_group)?) + }) } /// Get the `GroupMutableMetadata` of the group. @@ -1187,17 +1188,18 @@ impl MlsGroup { &self, provider: impl OpenMlsProvider, ) -> Result { - let mls_group = &self.load_mls_group(provider)?; - - Ok(mls_group.try_into()?) + self.load_mls_group_with_lock(provider, |mls_group| { + Ok(GroupMutableMetadata::try_from(&mls_group)?) + }) } pub fn permissions(&self) -> Result { let conn = self.context().store().conn()?; let provider = XmtpOpenMlsProvider::new(conn); - let mls_group = self.load_mls_group(&provider)?; - Ok(extract_group_permissions(&mls_group)?) + self.load_mls_group_with_lock(&provider, |mls_group| { + Ok(extract_group_permissions(&mls_group)?) + }) } /// Used for testing that dm group validation works as expected. @@ -1918,15 +1920,19 @@ pub(crate) mod tests { // Check Amal's MLS group state. let amal_db = XmtpOpenMlsProvider::from(amal.context.store().conn().unwrap()); - let amal_mls_group = amal_group.load_mls_group(&amal_db).unwrap(); - let amal_members: Vec = amal_mls_group.members().collect(); - assert_eq!(amal_members.len(), 3); + let amal_members_len = amal_group.load_mls_group_with_lock(&amal_db, |amal_mls_group| { + Ok(amal_mls_group.members().count()) + }).unwrap(); + + assert_eq!(amal_members_len, 3); // Check Bola's MLS group state. 
let bola_db = XmtpOpenMlsProvider::from(bola.context.store().conn().unwrap()); - let bola_mls_group = bola_group.load_mls_group(&bola_db).unwrap(); - let bola_members: Vec = bola_mls_group.members().collect(); - assert_eq!(bola_members.len(), 3); + let bola_members_len = bola_group.load_mls_group_with_lock(&bola_db, |bola_mls_group| { + Ok(bola_mls_group.members().count()) + }).unwrap(); + + assert_eq!(bola_members_len, 3); let amal_uncommitted_intents = amal_db .conn_ref() @@ -1985,22 +1991,26 @@ pub(crate) mod tests { .unwrap(); let provider = alix.mls_provider().unwrap(); // Doctor the group membership - let mut mls_group = alix_group.load_mls_group(&provider).unwrap(); - let mut existing_extensions = mls_group.extensions().clone(); - let mut group_membership = GroupMembership::new(); - group_membership.add("deadbeef".to_string(), 1); - existing_extensions.add_or_replace(build_group_membership_extension(&group_membership)); - mls_group - .update_group_context_extensions( - &provider, - existing_extensions.clone(), - &alix.identity().installation_keys, - ) - .unwrap(); - mls_group.merge_pending_commit(&provider).unwrap(); + let mut mls_group = alix_group.load_mls_group_with_lock(&provider, |mut mls_group| { + let mut existing_extensions = mls_group.extensions().clone(); + let mut group_membership = GroupMembership::new(); + group_membership.add("deadbeef".to_string(), 1); + existing_extensions.add_or_replace(build_group_membership_extension(&group_membership)); + + mls_group + .update_group_context_extensions( + &provider, + existing_extensions.clone(), + &alix.identity().installation_keys, + ) + .unwrap(); + mls_group.merge_pending_commit(&provider).unwrap(); + + Ok(mls_group) // Return the updated group if necessary + }).unwrap(); - // Now add bo to the group force_add_member(&alix, &bo, &alix_group, &mut mls_group, &provider).await; + // Now add bo to the group // Bo should not be able to actually read this group bo.sync_welcomes(&bo.store().conn().unwrap()).await.unwrap(); @@ -2124,9 +2134,11 @@ pub(crate) mod tests { assert_eq!(messages.len(), 2); let provider: XmtpOpenMlsProvider = client.context.store().conn().unwrap().into(); - let mls_group = group.load_mls_group(&provider).unwrap(); - let pending_commit = mls_group.pending_commit(); - assert!(pending_commit.is_none()); + let pending_commit_is_none = group.load_mls_group_with_lock(&provider, |mls_group| { + Ok(mls_group.pending_commit().is_none()) + }).unwrap(); + + assert!(pending_commit_is_none); group.send_message(b"hello").await.expect("send message"); @@ -2305,8 +2317,10 @@ pub(crate) mod tests { assert!(new_installations_were_added.is_ok()); group.sync().await.unwrap(); - let mls_group = group.load_mls_group(&provider).unwrap(); - let num_members = mls_group.members().collect::>().len(); + let num_members = group.load_mls_group_with_lock(&provider, |mls_group| { + Ok(mls_group.members().collect::>().len()) + }).unwrap(); + assert_eq!(num_members, 3); } @@ -3884,14 +3898,11 @@ pub(crate) mod tests { None, ) .unwrap(); - assert!(validate_dm_group( - &client, - &valid_dm_group - .load_mls_group(client.mls_provider().unwrap()) - .unwrap(), - added_by_inbox - ) - .is_ok()); + assert!(valid_dm_group + .load_mls_group_with_lock(client.mls_provider().unwrap(), |mls_group| { + validate_dm_group(&client, &mls_group, added_by_inbox) + }) + .is_ok()); // Test case 2: Invalid conversation type let invalid_protected_metadata = @@ -3906,10 +3917,11 @@ pub(crate) mod tests { ) .unwrap(); assert!(matches!( - validate_dm_group(&client, 
&invalid_type_group.load_mls_group(client.mls_provider().unwrap()).unwrap(), added_by_inbox), + invalid_type_group.load_mls_group_with_lock(client.mls_provider().unwrap(), |mls_group| { + validate_dm_group(&client, &mls_group, added_by_inbox) + }), Err(GroupError::Generic(msg)) if msg.contains("Invalid conversation type") )); - // Test case 3: Missing DmMembers // This case is not easily testable with the current structure, as DmMembers are set in the protected metadata @@ -3927,7 +3939,9 @@ pub(crate) mod tests { ) .unwrap(); assert!(matches!( - validate_dm_group(&client, &mismatched_dm_members_group.load_mls_group(client.mls_provider().unwrap()).unwrap(), added_by_inbox), + mismatched_dm_members_group.load_mls_group_with_lock(client.mls_provider().unwrap(), |mls_group| { + validate_dm_group(&client, &mls_group, added_by_inbox) + }), Err(GroupError::Generic(msg)) if msg.contains("DM members do not match expected inboxes") )); @@ -3947,7 +3961,9 @@ pub(crate) mod tests { ) .unwrap(); assert!(matches!( - validate_dm_group(&client, &non_empty_admin_list_group.load_mls_group(client.mls_provider().unwrap()).unwrap(), added_by_inbox), + non_empty_admin_list_group.load_mls_group_with_lock(client.mls_provider().unwrap(), |mls_group| { + validate_dm_group(&client, &mls_group, added_by_inbox) + }), Err(GroupError::Generic(msg)) if msg.contains("DM group must have empty admin and super admin lists") )); @@ -3966,11 +3982,9 @@ pub(crate) mod tests { ) .unwrap(); assert!(matches!( - validate_dm_group( - &client, - &invalid_permissions_group.load_mls_group(client.mls_provider().unwrap()).unwrap(), - added_by_inbox - ), + invalid_permissions_group.load_mls_group_with_lock(client.mls_provider().unwrap(), |mls_group| { + validate_dm_group(&client, &mls_group, added_by_inbox) + }), Err(GroupError::Generic(msg)) if msg.contains("Invalid permissions for DM group") )); } diff --git a/xmtp_mls/src/groups/subscriptions.rs b/xmtp_mls/src/groups/subscriptions.rs index 8e76d437c..b99821487 100644 --- a/xmtp_mls/src/groups/subscriptions.rs +++ b/xmtp_mls/src/groups/subscriptions.rs @@ -46,17 +46,20 @@ impl MlsGroup { .store() .transaction_async(|provider| async move { let prov_ref = &provider; // Borrow provider instead of moving it - self.load_mls_group_with_lock_async(prov_ref, |mut mls_group| async move { - // Attempt processing immediately, but fail if the message is not an Application Message - // Returning an error should roll back the DB tx - self.process_message(&mut mls_group, &prov_ref, msgv1, false) - .await - // NOTE: We want to make sure we retry an error in process_message - .map_err(SubscribeError::ReceiveGroup) - }).await + self.load_mls_group_with_lock_async( + prov_ref, + |mut mls_group| async move { + // Attempt processing immediately, but fail if the message is not an Application Message + // Returning an error should roll back the DB tx + self.process_message(&prov_ref, msgv1, false) + .await + // NOTE: We want to make sure we retry an error in process_message + .map_err(SubscribeError::ReceiveGroup) + }, + ) + .await }) .await - }) ); diff --git a/xmtp_mls/src/storage/errors.rs b/xmtp_mls/src/storage/errors.rs index f109b0248..50c630df2 100644 --- a/xmtp_mls/src/storage/errors.rs +++ b/xmtp_mls/src/storage/errors.rs @@ -5,7 +5,7 @@ use thiserror::Error; use crate::{groups::intents::IntentError, retry::RetryableError, retryable}; -use super::sql_key_store; +use super::sql_key_store::{self, SqlKeyStoreError}; #[derive(Debug, Error)] pub enum StorageError { @@ -25,6 +25,7 @@ pub enum StorageError 
{ Serialization(String), #[error("deserialization error")] Deserialization(String), + // TODO:insipx Make NotFound into an enum of possible items that may not be found #[error("{0} not found")] NotFound(String), #[error("lock")] @@ -43,6 +44,8 @@ pub enum StorageError { FromHex(#[from] hex::FromHexError), #[error(transparent)] Duplicate(DuplicateItem), + #[error(transparent)] + OpenMlsStorage(#[from] SqlKeyStoreError), } #[derive(Error, Debug)] From e5b2bf1bff451e7d5eb949fdbed75edb7402cb4d Mon Sep 17 00:00:00 2001 From: Mojtaba Chenani Date: Mon, 9 Dec 2024 22:17:34 +0100 Subject: [PATCH 03/28] wip --- bindings_node/src/conversation.rs | 2 +- xmtp_mls/src/client.rs | 4 +- xmtp_mls/src/groups/intents.rs | 6 +-- xmtp_mls/src/groups/members.rs | 1 - xmtp_mls/src/groups/mls_sync.rs | 29 ++++++++--- xmtp_mls/src/groups/mod.rs | 73 ++++++++++++---------------- xmtp_mls/src/groups/subscriptions.rs | 10 +++- 7 files changed, 67 insertions(+), 58 deletions(-) diff --git a/bindings_node/src/conversation.rs b/bindings_node/src/conversation.rs index d7df6a4fa..aeeffa6b1 100644 --- a/bindings_node/src/conversation.rs +++ b/bindings_node/src/conversation.rs @@ -1,5 +1,5 @@ use std::{ops::Deref, sync::Arc}; -use futures::TryFutureExt; + use napi::{ bindgen_prelude::{Result, Uint8Array}, threadsafe_function::{ErrorStrategy, ThreadsafeFunction, ThreadsafeFunctionCallMode}, diff --git a/xmtp_mls/src/client.rs b/xmtp_mls/src/client.rs index 82a317dc6..4410f8c63 100644 --- a/xmtp_mls/src/client.rs +++ b/xmtp_mls/src/client.rs @@ -845,8 +845,8 @@ where async move { tracing::info!( inbox_id = self.inbox_id(), - "current epoch for [{}] in sync_all_groups()", - self.inbox_id(), + "[{}] syncing group", + self.inbox_id() ); tracing::info!( inbox_id = self.inbox_id(), diff --git a/xmtp_mls/src/groups/intents.rs b/xmtp_mls/src/groups/intents.rs index bbb4e06c0..65fdd32f7 100644 --- a/xmtp_mls/src/groups/intents.rs +++ b/xmtp_mls/src/groups/intents.rs @@ -725,7 +725,6 @@ impl TryFrom> for PostCommitAction { pub(crate) mod tests { #[cfg(target_arch = "wasm32")] wasm_bindgen_test::wasm_bindgen_test_configure!(run_in_dedicated_worker); - use openmls::prelude::{MlsMessageBodyIn, MlsMessageIn, ProcessedMessageContent}; use tls_codec::Deserialize; use xmtp_cryptography::utils::generate_local_wallet; @@ -866,9 +865,8 @@ pub(crate) mod tests { let provider = group.client.mls_provider().unwrap(); let decrypted_message = match group .load_mls_group_with_lock(&provider, |mut mls_group| { - mls_group - .process_message(&provider, mls_message) - .map_err(|e| GroupError::Generic(e.to_string())) + Ok(mls_group + .process_message(&provider, mls_message).unwrap()) }) { Ok(message) => message, Err(err) => panic!("Error: {:?}", err), diff --git a/xmtp_mls/src/groups/members.rs b/xmtp_mls/src/groups/members.rs index 730529fb3..cfdf56e28 100644 --- a/xmtp_mls/src/groups/members.rs +++ b/xmtp_mls/src/groups/members.rs @@ -41,7 +41,6 @@ where provider: &XmtpOpenMlsProvider, ) -> Result, GroupError> { let group_membership = self.load_mls_group_with_lock(provider, |mls_group| { - // Extract group membership from extensions Ok(extract_group_membership(mls_group.extensions())?) 
})?; let requests = group_membership diff --git a/xmtp_mls/src/groups/mls_sync.rs b/xmtp_mls/src/groups/mls_sync.rs index 1a1f50b2d..cd450f11f 100644 --- a/xmtp_mls/src/groups/mls_sync.rs +++ b/xmtp_mls/src/groups/mls_sync.rs @@ -360,9 +360,8 @@ where if intent.state == IntentState::Committed { return Ok(IntentState::Committed); } - let group_epoch = mls_group.epoch(); - let message_epoch = message.epoch(); + let group_epoch = mls_group.epoch(); debug!( inbox_id = self.client.inbox_id(), installation_id = hex::encode(self.client.installation_id()), @@ -702,6 +701,14 @@ where let intent = provider .conn_ref() .find_group_intent_by_payload_hash(sha256(envelope.data.as_slice())); + tracing::info!( + inbox_id = self.client.inbox_id(), + group_id = hex::encode(&self.group_id), + msg_id = envelope.id, + "Processing envelope with hash {:?}", + hex::encode(sha256(envelope.data.as_slice())) + ); + match intent { // Intent with the payload hash matches Ok(Some(intent)) => { @@ -729,6 +736,14 @@ where } // No matching intent found Ok(None) => { + tracing::info!( + inbox_id = self.client.inbox_id(), + group_id = hex::encode(&self.group_id), + msg_id = envelope.id, + "client [{}] is about to process external envelope [{}]", + self.client.inbox_id(), + envelope.id + ); self.process_external_message(provider, message, envelope) .await } @@ -792,7 +807,10 @@ where for message in messages.into_iter() { let result = retry_async!( Retry::default(), - (async { self.consume_message(&message, provider.conn_ref()).await }) + (async { + self.consume_message(&message, provider.conn_ref()) + .await + }) ); if let Err(e) = result { let is_retryable = e.is_retryable(); @@ -1139,10 +1157,7 @@ where return Ok(()); } // determine how long of an interval in time to use before updating list - let interval_ns = match update_interval_ns { - Some(val) => val, - None => SYNC_UPDATE_INSTALLATIONS_INTERVAL_NS, - }; + let interval_ns = update_interval_ns.unwrap_or_else(|| SYNC_UPDATE_INSTALLATIONS_INTERVAL_NS); let now_ns = crate::utils::time::now_ns(); let last_ns = provider diff --git a/xmtp_mls/src/groups/mod.rs b/xmtp_mls/src/groups/mod.rs index 285832b89..87739dde9 100644 --- a/xmtp_mls/src/groups/mod.rs +++ b/xmtp_mls/src/groups/mod.rs @@ -332,19 +332,6 @@ impl MlsGroup { } // Load the stored OpenMLS group from the OpenMLS provider's keystore - #[tracing::instrument(level = "trace", skip_all)] - pub(crate) fn load_mls_group( - &self, - provider: impl OpenMlsProvider, - ) -> Result { - let mls_group = - OpenMlsGroup::load(provider.storage(), &GroupId::from_slice(&self.group_id)) - .map_err(|_| GroupError::GroupNotFound)? 
- .ok_or(GroupError::GroupNotFound)?; - - Ok(mls_group) - } - #[tracing::instrument(level = "trace", skip_all)] pub(crate) fn load_mls_group_with_lock( &self, @@ -358,7 +345,6 @@ impl MlsGroup { let group_id = self.group_id.clone(); // Acquire the lock synchronously using blocking_lock - let _lock = MLS_COMMIT_LOCK.get_lock_sync(group_id.clone())?; // Load the MLS group let mls_group = @@ -370,6 +356,7 @@ impl MlsGroup { operation(mls_group) } + // Load the stored OpenMLS group from the OpenMLS provider's keystore #[tracing::instrument(level = "trace", skip_all)] pub(crate) async fn load_mls_group_with_lock_async( &self, @@ -1173,14 +1160,16 @@ impl MlsGroup { /// /// If the current user has been kicked out of the group, `is_active` will return `false` pub fn is_active(&self, provider: impl OpenMlsProvider) -> Result { - self.load_mls_group_with_lock(provider, |mls_group| Ok(mls_group.is_active())) + self.load_mls_group_with_lock(provider, |mls_group| + Ok(mls_group.is_active()) + ) } /// Get the `GroupMetadata` of the group. pub fn metadata(&self, provider: impl OpenMlsProvider) -> Result { - self.load_mls_group_with_lock(provider, |mls_group| { + self.load_mls_group_with_lock(provider, |mls_group| Ok(extract_group_metadata(&mls_group)?) - }) + ) } /// Get the `GroupMutableMetadata` of the group. @@ -1188,18 +1177,18 @@ impl MlsGroup { &self, provider: impl OpenMlsProvider, ) -> Result { - self.load_mls_group_with_lock(provider, |mls_group| { + self.load_mls_group_with_lock(provider, |mls_group| Ok(GroupMutableMetadata::try_from(&mls_group)?) - }) + ) } pub fn permissions(&self) -> Result { let conn = self.context().store().conn()?; let provider = XmtpOpenMlsProvider::new(conn); - self.load_mls_group_with_lock(&provider, |mls_group| { + self.load_mls_group_with_lock(&provider, |mls_group| Ok(extract_group_permissions(&mls_group)?) - }) + ) } /// Used for testing that dm group validation works as expected. @@ -1920,17 +1909,17 @@ pub(crate) mod tests { // Check Amal's MLS group state. let amal_db = XmtpOpenMlsProvider::from(amal.context.store().conn().unwrap()); - let amal_members_len = amal_group.load_mls_group_with_lock(&amal_db, |amal_mls_group| { - Ok(amal_mls_group.members().count()) - }).unwrap(); + let amal_members_len = amal_group.load_mls_group_with_lock(&amal_db, |mls_group| + Ok(mls_group.members().count()) + ).unwrap(); assert_eq!(amal_members_len, 3); // Check Bola's MLS group state. 
        let bola_db = XmtpOpenMlsProvider::from(bola.context.store().conn().unwrap());
-        let bola_members_len = bola_group.load_mls_group_with_lock(&bola_db, |bola_mls_group| {
-            Ok(bola_mls_group.members().count())
-        }).unwrap();
+        let bola_members_len = bola_group.load_mls_group_with_lock(&bola_db, |mls_group|
+            Ok(mls_group.members().count())
+        ).unwrap();

        assert_eq!(bola_members_len, 3);

@@ -2009,8 +1998,8 @@ pub(crate) mod tests {
             Ok(mls_group) // Return the updated group if necessary
         }).unwrap();

-        force_add_member(&alix, &bo, &alix_group, &mut mls_group, &provider).await;
         // Now add bo to the group
+        force_add_member(&alix, &bo, &alix_group, &mut mls_group, &provider).await;

         // Bo should not be able to actually read this group
         bo.sync_welcomes(&bo.store().conn().unwrap()).await.unwrap();
@@ -2134,9 +2123,9 @@ pub(crate) mod tests {
         assert_eq!(messages.len(), 2);

         let provider: XmtpOpenMlsProvider = client.context.store().conn().unwrap().into();
-        let pending_commit_is_none = group.load_mls_group_with_lock(&provider, |mls_group| {
+        let pending_commit_is_none = group.load_mls_group_with_lock(&provider, |mls_group|
             Ok(mls_group.pending_commit().is_none())
-        }).unwrap();
+        ).unwrap();

         assert!(pending_commit_is_none);

@@ -2317,9 +2306,9 @@ pub(crate) mod tests {
         assert!(new_installations_were_added.is_ok());

         group.sync().await.unwrap();
-        let num_members = group.load_mls_group_with_lock(&provider, |mls_group| {
+        let num_members = group.load_mls_group_with_lock(&provider, |mls_group|
             Ok(mls_group.members().collect::<Vec<_>>().len())
-        }).unwrap();
+        ).unwrap();

         assert_eq!(num_members, 3);
     }
@@ -3899,9 +3888,9 @@ pub(crate) mod tests {
         )
         .unwrap();
         assert!(valid_dm_group
-            .load_mls_group_with_lock(client.mls_provider().unwrap(), |mls_group| {
+            .load_mls_group_with_lock(client.mls_provider().unwrap(), |mls_group|
                 validate_dm_group(&client, &mls_group, added_by_inbox)
-            })
+            )
             .is_ok());

         // Test case 2: Invalid conversation type
@@ -3917,9 +3906,9 @@ pub(crate) mod tests {
         )
         .unwrap();
         assert!(matches!(
-            invalid_type_group.load_mls_group_with_lock(client.mls_provider().unwrap(), |mls_group| {
+            invalid_type_group.load_mls_group_with_lock(client.mls_provider().unwrap(), |mls_group|
                 validate_dm_group(&client, &mls_group, added_by_inbox)
-            }),
+            ),
             Err(GroupError::Generic(msg)) if msg.contains("Invalid conversation type")
         ));
         // Test case 3: Missing DmMembers
@@ -3939,9 +3928,9 @@ pub(crate) mod tests {
         )
         .unwrap();
         assert!(matches!(
-            mismatched_dm_members_group.load_mls_group_with_lock(client.mls_provider().unwrap(), |mls_group| {
+            mismatched_dm_members_group.load_mls_group_with_lock(client.mls_provider().unwrap(), |mls_group|
                 validate_dm_group(&client, &mls_group, added_by_inbox)
-            }),
+            ),
             Err(GroupError::Generic(msg)) if msg.contains("DM members do not match expected inboxes")
         ));

@@ -3961,9 +3950,9 @@ pub(crate) mod tests {
         )
         .unwrap();
         assert!(matches!(
-            non_empty_admin_list_group.load_mls_group_with_lock(client.mls_provider().unwrap(), |mls_group| {
+            non_empty_admin_list_group.load_mls_group_with_lock(client.mls_provider().unwrap(), |mls_group|
                 validate_dm_group(&client, &mls_group, added_by_inbox)
-            }),
+            ),
             Err(GroupError::Generic(msg)) if msg.contains("DM group must have empty admin and super admin lists")
         ));

@@ -3982,9 +3971,9 @@ pub(crate) mod tests {
         )
         .unwrap();
         assert!(matches!(
-            invalid_permissions_group.load_mls_group_with_lock(client.mls_provider().unwrap(), |mls_group| {
+            invalid_permissions_group.load_mls_group_with_lock(client.mls_provider().unwrap(), |mls_group|
                validate_dm_group(&client, &mls_group, added_by_inbox)
-            }),
+            ),
             Err(GroupError::Generic(msg)) if msg.contains("Invalid permissions for DM group")
         ));
     }
diff --git a/xmtp_mls/src/groups/subscriptions.rs b/xmtp_mls/src/groups/subscriptions.rs
index b99821487..64d0a6f86 100644
--- a/xmtp_mls/src/groups/subscriptions.rs
+++ b/xmtp_mls/src/groups/subscriptions.rs
@@ -41,6 +41,7 @@ impl MlsGroup {
         let process_result = retry_async!(
             Retry::default(),
             (async {
+                let client_id = &client_id;
                 let msgv1 = &msgv1;
                 self.context()
                     .store()
@@ -48,9 +49,16 @@ impl MlsGroup {
                     .transaction_async(|provider| async move {
                         let prov_ref = &provider; // Borrow provider instead of moving it
                         self.load_mls_group_with_lock_async(
                             prov_ref,
-                            |mut mls_group| async move {
+                            |mls_group| async move {
                                 // Attempt processing immediately, but fail if the message is not an Application Message
                                 // Returning an error should roll back the DB tx
+                                tracing::info!(
+                                    inbox_id = self.client.inbox_id(),
+                                    group_id = hex::encode(&self.group_id),
+                                    msg_id = msgv1.id,
+                                    "current epoch for [{}] in process_stream_entry()",
+                                    client_id,
+                                );
                                 self.process_message(&prov_ref, msgv1, false)
                                     .await
                                     // NOTE: We want to make sure we retry an error in process_message

From ab27ee354865920bf9948bd218bc49877027d385 Mon Sep 17 00:00:00 2001
From: Mojtaba Chenani
Date: Wed, 11 Dec 2024 15:47:22 +0100
Subject: [PATCH 04/28] fixed tests

---
 xmtp_mls/src/client.rs               | 14 ++++-
 xmtp_mls/src/groups/mls_sync.rs      |  7 ---
 xmtp_mls/src/groups/mod.rs           | 85 +++++++++++++++-------------
 xmtp_mls/src/groups/subscriptions.rs | 15 +----
 xmtp_mls/src/subscriptions.rs        |  4 +-
 5 files changed, 59 insertions(+), 66 deletions(-)

diff --git a/xmtp_mls/src/client.rs b/xmtp_mls/src/client.rs
index 4410f8c63..ca1c2bdec 100644
--- a/xmtp_mls/src/client.rs
+++ b/xmtp_mls/src/client.rs
@@ -853,9 +853,17 @@ where
                         "[{}] syncing group",
                         self.inbox_id()
                     );
-                    group.maybe_update_installations(provider_ref, None).await?;
-                    group.sync_with_conn(provider_ref).await?;
-                    active_group_count.fetch_add(1, Ordering::SeqCst);
+                    let is_active = group
+                        .load_mls_group_with_lock_async(provider_ref, |mls_group| async move {
+                            Ok::<bool, GroupError>(mls_group.is_active())
+                        })
+                        .await?;
+                    if is_active {
+                        group.maybe_update_installations(provider_ref, None).await?;
+
+                        group.sync_with_conn(provider_ref).await?;
+                        active_group_count.fetch_add(1, Ordering::SeqCst);
+                    }

                     Ok::<(), GroupError>(())
                 }
diff --git a/xmtp_mls/src/groups/mls_sync.rs b/xmtp_mls/src/groups/mls_sync.rs
index cd450f11f..b3e79e834 100644
--- a/xmtp_mls/src/groups/mls_sync.rs
+++ b/xmtp_mls/src/groups/mls_sync.rs
@@ -189,9 +189,6 @@ where
     #[tracing::instrument(skip_all)]
     pub async fn sync_with_conn(&self, provider: &XmtpOpenMlsProvider) -> Result<(), GroupError> {
         // Check if we're still part of the group
-        if !self.is_active(provider)? {
-            return Ok(());
-        }
         let _mutex = self.mutex.lock().await;
         let mut errors: Vec<GroupError> = vec![];
@@ -1152,10 +1149,6 @@ where
         provider: &XmtpOpenMlsProvider,
         update_interval_ns: Option<i64>,
     ) -> Result<(), GroupError> {
-        // Check if we're still part of the group
-        if !self.is_active(provider)?
{ - return Ok(()); - } // determine how long of an interval in time to use before updating list let interval_ns = update_interval_ns.unwrap_or_else(|| SYNC_UPDATE_INSTALLATIONS_INTERVAL_NS); diff --git a/xmtp_mls/src/groups/mod.rs b/xmtp_mls/src/groups/mod.rs index 87739dde9..76cdca006 100644 --- a/xmtp_mls/src/groups/mod.rs +++ b/xmtp_mls/src/groups/mod.rs @@ -1160,16 +1160,14 @@ impl MlsGroup { /// /// If the current user has been kicked out of the group, `is_active` will return `false` pub fn is_active(&self, provider: impl OpenMlsProvider) -> Result { - self.load_mls_group_with_lock(provider, |mls_group| - Ok(mls_group.is_active()) - ) + self.load_mls_group_with_lock(provider, |mls_group| Ok(mls_group.is_active())) } /// Get the `GroupMetadata` of the group. pub fn metadata(&self, provider: impl OpenMlsProvider) -> Result { - self.load_mls_group_with_lock(provider, |mls_group| + self.load_mls_group_with_lock(provider, |mls_group| { Ok(extract_group_metadata(&mls_group)?) - ) + }) } /// Get the `GroupMutableMetadata` of the group. @@ -1177,18 +1175,18 @@ impl MlsGroup { &self, provider: impl OpenMlsProvider, ) -> Result { - self.load_mls_group_with_lock(provider, |mls_group| + self.load_mls_group_with_lock(provider, |mls_group| { Ok(GroupMutableMetadata::try_from(&mls_group)?) - ) + }) } pub fn permissions(&self) -> Result { let conn = self.context().store().conn()?; let provider = XmtpOpenMlsProvider::new(conn); - self.load_mls_group_with_lock(&provider, |mls_group| + self.load_mls_group_with_lock(&provider, |mls_group| { Ok(extract_group_permissions(&mls_group)?) - ) + }) } /// Used for testing that dm group validation works as expected. @@ -1909,17 +1907,17 @@ pub(crate) mod tests { // Check Amal's MLS group state. let amal_db = XmtpOpenMlsProvider::from(amal.context.store().conn().unwrap()); - let amal_members_len = amal_group.load_mls_group_with_lock(&amal_db, |mls_group| - Ok(mls_group.members().count()) - ).unwrap(); + let amal_members_len = amal_group + .load_mls_group_with_lock(&amal_db, |mls_group| Ok(mls_group.members().count())) + .unwrap(); assert_eq!(amal_members_len, 3); // Check Bola's MLS group state. 
        let bola_db = XmtpOpenMlsProvider::from(bola.context.store().conn().unwrap());
-        let bola_members_len = bola_group.load_mls_group_with_lock(&bola_db, |mls_group|
-            Ok(mls_group.members().count())
-        ).unwrap();
+        let bola_members_len = bola_group
+            .load_mls_group_with_lock(&bola_db, |mls_group| Ok(mls_group.members().count()))
+            .unwrap();

         assert_eq!(bola_members_len, 3);

@@ -1980,23 +1978,26 @@ pub(crate) mod tests {
             .unwrap();
         let provider = alix.mls_provider().unwrap();
         // Doctor the group membership
-        let mut mls_group = alix_group.load_mls_group_with_lock(&provider, |mut mls_group| {
-            let mut existing_extensions = mls_group.extensions().clone();
-            let mut group_membership = GroupMembership::new();
-            group_membership.add("deadbeef".to_string(), 1);
-            existing_extensions.add_or_replace(build_group_membership_extension(&group_membership));
-
-            mls_group
-                .update_group_context_extensions(
-                    &provider,
-                    existing_extensions.clone(),
-                    &alix.identity().installation_keys,
-                )
-                .unwrap();
-            mls_group.merge_pending_commit(&provider).unwrap();
-
-            Ok(mls_group) // Return the updated group if necessary
-        }).unwrap();
+        let mut mls_group = alix_group
+            .load_mls_group_with_lock(&provider, |mut mls_group| {
+                let mut existing_extensions = mls_group.extensions().clone();
+                let mut group_membership = GroupMembership::new();
+                group_membership.add("deadbeef".to_string(), 1);
+                existing_extensions
+                    .add_or_replace(build_group_membership_extension(&group_membership));
+
+                mls_group
+                    .update_group_context_extensions(
+                        &provider,
+                        existing_extensions.clone(),
+                        &alix.identity().installation_keys,
+                    )
+                    .unwrap();
+                mls_group.merge_pending_commit(&provider).unwrap();
+
+                Ok(mls_group) // Return the updated group if necessary
+            })
+            .unwrap();

         // Now add bo to the group
         force_add_member(&alix, &bo, &alix_group, &mut mls_group, &provider).await;
@@ -2123,9 +2124,11 @@ pub(crate) mod tests {
         assert_eq!(messages.len(), 2);

         let provider: XmtpOpenMlsProvider = client.context.store().conn().unwrap().into();
-        let pending_commit_is_none = group.load_mls_group_with_lock(&provider, |mls_group|
-            Ok(mls_group.pending_commit().is_none())
-        ).unwrap();
+        let pending_commit_is_none = group
+            .load_mls_group_with_lock(&provider, |mls_group| {
+                Ok(mls_group.pending_commit().is_none())
+            })
+            .unwrap();

         assert!(pending_commit_is_none);

@@ -2306,9 +2309,11 @@ pub(crate) mod tests {
         assert!(new_installations_were_added.is_ok());

         group.sync().await.unwrap();
-        let num_members = group.load_mls_group_with_lock(&provider, |mls_group|
-            Ok(mls_group.members().collect::<Vec<_>>().len())
-        ).unwrap();
+        let num_members = group
+            .load_mls_group_with_lock(&provider, |mls_group| {
+                Ok(mls_group.members().collect::<Vec<_>>().len())
+            })
+            .unwrap();

         assert_eq!(num_members, 3);
     }
@@ -3888,9 +3893,9 @@ pub(crate) mod tests {
         )
         .unwrap();
         assert!(valid_dm_group
-            .load_mls_group_with_lock(client.mls_provider().unwrap(), |mls_group|
+            .load_mls_group_with_lock(client.mls_provider().unwrap(), |mls_group| {
                 validate_dm_group(&client, &mls_group, added_by_inbox)
-            )
+            })
             .is_ok());

         // Test case 2: Invalid conversation type
diff --git a/xmtp_mls/src/groups/subscriptions.rs b/xmtp_mls/src/groups/subscriptions.rs
index 64d0a6f86..9cb7d7653 100644
--- a/xmtp_mls/src/groups/subscriptions.rs
+++ b/xmtp_mls/src/groups/subscriptions.rs
@@ -47,18 +47,7 @@ impl MlsGroup {
                 .store()
                 .transaction_async(|provider| async move {
                     let prov_ref = &provider; // Borrow provider instead of moving it
-                    self.load_mls_group_with_lock_async(
-                        prov_ref,
-                        |mls_group|
async move { - // Attempt processing immediately, but fail if the message is not an Application Message - // Returning an error should roll back the DB tx - tracing::info!( - inbox_id = self.client.inbox_id(), - group_id = hex::encode(&self.group_id), - msg_id = msgv1.id, - "current epoch for [{}] in process_stream_entry()", - client_id, - ); + self.process_message(&prov_ref, msgv1, false) .await // NOTE: We want to make sure we retry an error in process_message @@ -66,8 +55,6 @@ impl MlsGroup { }, ) .await - }) - .await }) ); diff --git a/xmtp_mls/src/subscriptions.rs b/xmtp_mls/src/subscriptions.rs index 9a89fb014..8a7d6aeac 100644 --- a/xmtp_mls/src/subscriptions.rs +++ b/xmtp_mls/src/subscriptions.rs @@ -585,11 +585,11 @@ pub(crate) mod tests { .add_members_by_inbox_id(&[bob.inbox_id()]) .await .unwrap(); - let bob_group = bob + let bob_groups = bob .sync_welcomes(&bob.store().conn().unwrap()) .await .unwrap(); - let bob_group = bob_group.first().unwrap(); + let bob_group = bob_groups.first().unwrap(); let notify = Delivery::new(None); let notify_ptr = notify.clone(); From a8a39f6f80f9a498a94615e18e7c74c90b11b50b Mon Sep 17 00:00:00 2001 From: Mojtaba Chenani Date: Mon, 9 Dec 2024 17:43:58 +0100 Subject: [PATCH 05/28] wip --- bindings_node/src/conversation.rs | 2 +- xmtp_mls/src/client.rs | 20 +- xmtp_mls/src/groups/intents.rs | 14 +- xmtp_mls/src/groups/mls_sync.rs | 656 +++++++++++++-------------- xmtp_mls/src/groups/mod.rs | 70 ++- xmtp_mls/src/groups/subscriptions.rs | 26 +- xmtp_mls/src/lib.rs | 62 +++ 7 files changed, 481 insertions(+), 369 deletions(-) diff --git a/bindings_node/src/conversation.rs b/bindings_node/src/conversation.rs index 897662463..ee52eb9d2 100644 --- a/bindings_node/src/conversation.rs +++ b/bindings_node/src/conversation.rs @@ -1,5 +1,5 @@ use std::{ops::Deref, sync::Arc}; - +use futures::TryFutureExt; use napi::{ bindgen_prelude::{Result, Uint8Array}, threadsafe_function::{ErrorStrategy, ThreadsafeFunction, ThreadsafeFunctionCallMode}, diff --git a/xmtp_mls/src/client.rs b/xmtp_mls/src/client.rs index 819356dc4..5f9fdc82c 100644 --- a/xmtp_mls/src/client.rs +++ b/xmtp_mls/src/client.rs @@ -863,25 +863,19 @@ where .map(|group| { let active_group_count = Arc::clone(&active_group_count); async move { - let mls_group = group.load_mls_group(provider)?; tracing::info!( inbox_id = self.inbox_id(), - "[{}] syncing group", - self.inbox_id() + "current epoch for [{}] in sync_all_groups()", + self.inbox_id(), ); tracing::info!( inbox_id = self.inbox_id(), - group_epoch = mls_group.epoch().as_u64(), - "current epoch for [{}] in sync_all_groups() is Epoch: [{}]", - self.inbox_id(), - mls_group.epoch() + "[{}] syncing group", + self.inbox_id() ); - if mls_group.is_active() { - group.maybe_update_installations(provider, None).await?; - - group.sync_with_conn(provider).await?; - active_group_count.fetch_add(1, Ordering::SeqCst); - } + group.maybe_update_installations(provider, None).await?; + group.sync_with_conn(provider).await?; + active_group_count.fetch_add(1, Ordering::SeqCst); Ok::<(), GroupError>(()) } diff --git a/xmtp_mls/src/groups/intents.rs b/xmtp_mls/src/groups/intents.rs index b3faf11b4..7e4276cf0 100644 --- a/xmtp_mls/src/groups/intents.rs +++ b/xmtp_mls/src/groups/intents.rs @@ -727,6 +727,7 @@ impl TryFrom> for PostCommitAction { pub(crate) mod tests { #[cfg(target_arch = "wasm32")] wasm_bindgen_test::wasm_bindgen_test_configure!(run_in_dedicated_worker); + use openmls::prelude::{MlsMessageBodyIn, MlsMessageIn, ProcessedMessageContent}; use 
tls_codec::Deserialize; use xmtp_cryptography::utils::generate_local_wallet; @@ -865,10 +866,15 @@ pub(crate) mod tests { }; let provider = group.client.mls_provider().unwrap(); - let mut openmls_group = group.load_mls_group(&provider).unwrap(); - let decrypted_message = openmls_group - .process_message(&provider, mls_message) - .unwrap(); + let decrypted_message = match group + .load_mls_group_with_lock(&provider, |mut mls_group| { + mls_group + .process_message(&provider, mls_message) + .map_err(|e| GroupError::Generic(e.to_string())) + }) { + Ok(message) => message, + Err(err) => panic!("Error: {:?}", err), + }; let staged_commit = match decrypted_message.into_content() { ProcessedMessageContent::StagedCommitMessage(staged_commit) => *staged_commit, diff --git a/xmtp_mls/src/groups/mls_sync.rs b/xmtp_mls/src/groups/mls_sync.rs index d1131d40c..3be8e70b6 100644 --- a/xmtp_mls/src/groups/mls_sync.rs +++ b/xmtp_mls/src/groups/mls_sync.rs @@ -179,7 +179,6 @@ where inbox_id = self.client.inbox_id(), installation_id = %self.client.installation_id(), group_id = hex::encode(&self.group_id), - current_epoch = self.load_mls_group(&mls_provider)?.epoch().as_u64(), "[{}] syncing group", self.client.inbox_id() ); @@ -187,10 +186,8 @@ where inbox_id = self.client.inbox_id(), installation_id = %self.client.installation_id(), group_id = hex::encode(&self.group_id), - current_epoch = self.load_mls_group(&mls_provider)?.epoch().as_u64(), - "current epoch for [{}] in sync() is Epoch: [{}]", + "current epoch for [{}] in sync()", self.client.inbox_id(), - self.load_mls_group(&mls_provider)?.epoch() ); self.maybe_update_installations(&mls_provider, None).await?; @@ -200,6 +197,10 @@ where // TODO: Should probably be renamed to `sync_with_provider` #[tracing::instrument(skip_all)] pub async fn sync_with_conn(&self, provider: &XmtpOpenMlsProvider) -> Result<(), GroupError> { + // Check if we're still part of the group + if !self.is_active(provider)? { + return Ok(()); + } let _mutex = self.mutex.lock().await; let mut errors: Vec = vec![]; @@ -354,351 +355,345 @@ where async fn process_own_message( &self, intent: StoredGroupIntent, - openmls_group: &mut OpenMlsGroup, provider: &XmtpOpenMlsProvider, message: ProtocolMessage, envelope: &GroupMessageV1, ) -> Result { - let GroupMessageV1 { - created_ns: envelope_timestamp_ns, - id: ref msg_id, - .. - } = *envelope; - - if intent.state == IntentState::Committed { - return Ok(IntentState::Committed); - } - let message_epoch = message.epoch(); - let group_epoch = openmls_group.epoch(); - debug!( - inbox_id = self.client.inbox_id(), - installation_id = %self.client.installation_id(), - group_id = hex::encode(&self.group_id), - current_epoch = openmls_group.epoch().as_u64(), - msg_id, - intent.id, - intent.kind = %intent.kind, - "[{}]-[{}] processing own message for intent {} / {:?}, group epoch: {}, message_epoch: {}", - self.context().inbox_id(), - hex::encode(self.group_id.clone()), - intent.id, - intent.kind, - group_epoch, - message_epoch - ); + self.load_mls_group_with_lock_async(provider, |mut mls_group| async move { + let GroupMessageV1 { + created_ns: envelope_timestamp_ns, + id: ref msg_id, + .. 
+            } = *envelope;
+
+            if intent.state == IntentState::Committed {
+                return Ok(IntentState::Committed);
+            }
+            let group_epoch = mls_group.epoch();

-        let conn = provider.conn_ref();
-        match intent.kind {
-            IntentKind::KeyUpdate
-            | IntentKind::UpdateGroupMembership
-            | IntentKind::UpdateAdminList
-            | IntentKind::MetadataUpdate
-            | IntentKind::UpdatePermission => {
-                if let Some(published_in_epoch) = intent.published_in_epoch {
-                    let published_in_epoch_u64 = published_in_epoch as u64;
-                    let group_epoch_u64 = group_epoch.as_u64();
-
-                    if published_in_epoch_u64 != group_epoch_u64 {
-                        tracing::warn!(
-                            inbox_id = self.client.inbox_id(),
-                            installation_id = %self.client.installation_id(),
+            let message_epoch = message.epoch();
+            debug!(
+                inbox_id = self.client.inbox_id(),
+                installation_id = %self.client.installation_id(),
+                group_id = hex::encode(&self.group_id),
+                msg_id,
+                intent.id,
+                intent.kind = %intent.kind,
+                "[{}]-[{}] processing own message for intent {} / {:?}, message_epoch: {}",
+                self.context().inbox_id(),
+                hex::encode(self.group_id.clone()),
+                intent.id,
+                intent.kind,
+                message_epoch
+            );
+
+            let conn = provider.conn_ref();
+            match intent.kind {
+                IntentKind::KeyUpdate
+                | IntentKind::UpdateGroupMembership
+                | IntentKind::UpdateAdminList
+                | IntentKind::MetadataUpdate
+                | IntentKind::UpdatePermission => {
+                    if let Some(published_in_epoch) = intent.published_in_epoch {
+                        let published_in_epoch_u64 = published_in_epoch as u64;
+                        let group_epoch_u64 = group_epoch.as_u64();
+
+                        if published_in_epoch_u64 != group_epoch_u64 {
+                            tracing::warn!(
+                                inbox_id = self.client.inbox_id(),
+                                installation_id = %self.client.installation_id(),
                                 group_id = hex::encode(&self.group_id),
-                            current_epoch = openmls_group.epoch().as_u64(),
-                            msg_id,
-                            intent.id,
-                            intent.kind = %intent.kind,
-                            "Intent was published in epoch {} but group is currently in epoch {}",
-                            published_in_epoch_u64,
-                            group_epoch_u64
-                        );
-                        return Ok(IntentState::ToPublish);
+                                msg_id,
+                                intent.id,
+                                intent.kind = %intent.kind,
+                                "Intent was published in epoch {} but group is currently in epoch {}",
+                                published_in_epoch_u64,
+                                group_epoch_u64
+                            );
+                            return Ok(IntentState::ToPublish);
+                        }
                     }
-                }

-                let pending_commit = if let Some(staged_commit) = intent.staged_commit {
-                    decode_staged_commit(staged_commit)?
-                } else {
-                    return Err(GroupMessageProcessingError::IntentMissingStagedCommit);
-                };
+                    let pending_commit = if let Some(staged_commit) = intent.staged_commit {
+                        decode_staged_commit(staged_commit)?
+                    } else {
+                        return Err(GroupMessageProcessingError::IntentMissingStagedCommit);
+                    };

-                tracing::info!(
-                    "[{}] Validating commit for intent {}. Message timestamp: {}",
-                    self.context().inbox_id(),
+                    tracing::info!(
+                        "[{}] Validating commit for intent {}.
Message timestamp: {}", + self.context().inbox_id(), intent.id, - err + envelope_timestamp_ns ); - // Return before merging commit since it does not pass validation - // Return OK so that the group intent update is still written to the DB - return Ok(IntentState::Error); - } - let validated_commit = maybe_validated_commit.expect("Checked for error"); + let maybe_validated_commit = ValidatedCommit::from_staged_commit( + self.client.as_ref(), + conn, + &pending_commit, + &mls_group, + ) + .await; - tracing::info!( - "[{}] merging pending commit for intent {}", - self.context().inbox_id(), - intent.id - ); - if let Err(err) = openmls_group.merge_staged_commit(&provider, pending_commit) { - tracing::error!("error merging commit: {}", err); - return Ok(IntentState::ToPublish); - } else { - // If no error committing the change, write a transcript message - self.save_transcript_message(conn, validated_commit, envelope_timestamp_ns)?; - } - } - IntentKind::SendMessage => { - if !Self::is_valid_epoch( - self.context().inbox_id(), - intent.id, - group_epoch, - message_epoch, - MAX_PAST_EPOCHS, - ) { - return Ok(IntentState::ToPublish); + if let Err(err) = maybe_validated_commit { + tracing::error!( + "Error validating commit for own message. Intent ID [{}]: {:?}", + intent.id, + err + ); + // Return before merging commit since it does not pass validation + // Return OK so that the group intent update is still written to the DB + return Ok(IntentState::Error); + } + + let validated_commit = maybe_validated_commit.expect("Checked for error"); + + tracing::info!( + "[{}] merging pending commit for intent {}", + self.context().inbox_id(), + intent.id + ); + if let Err(err) = mls_group.merge_staged_commit(&provider, pending_commit) { + tracing::error!("error merging commit: {}", err); + return Ok(IntentState::ToPublish); + } else { + // If no error committing the change, write a transcript message + self.save_transcript_message(conn, validated_commit, envelope_timestamp_ns)?; + } } - if let Some(id) = intent.message_id()? { - conn.set_delivery_status_to_published(&id, envelope_timestamp_ns)?; + IntentKind::SendMessage => { + if !Self::is_valid_epoch( + self.context().inbox_id(), + intent.id, + group_epoch, + message_epoch, + MAX_PAST_EPOCHS, + ) { + return Ok(IntentState::ToPublish); + } + if let Some(id) = intent.message_id()? { + conn.set_delivery_status_to_published(&id, envelope_timestamp_ns)?; + } } - } - }; + }; - Ok(IntentState::Committed) + Ok(IntentState::Committed) + }).await } #[tracing::instrument(level = "trace", skip_all)] async fn process_external_message( &self, - openmls_group: &mut OpenMlsGroup, provider: &XmtpOpenMlsProvider, message: PrivateMessageIn, envelope: &GroupMessageV1, ) -> Result<(), GroupMessageProcessingError> { - let GroupMessageV1 { - created_ns: envelope_timestamp_ns, - id: ref msg_id, - .. - } = *envelope; + self.load_mls_group_with_lock_async(provider, |mut mls_group| async move { + let GroupMessageV1 { + created_ns: envelope_timestamp_ns, + id: ref msg_id, + .. 
+            } = *envelope;

-        let decrypted_message = openmls_group.process_message(provider, message)?;
-        let (sender_inbox_id, sender_installation_id) =
-            extract_message_sender(openmls_group, &decrypted_message, envelope_timestamp_ns)?;
+            let decrypted_message = mls_group.process_message(provider, message)?;
+            let (sender_inbox_id, sender_installation_id) =
+                extract_message_sender(&mut mls_group, &decrypted_message, envelope_timestamp_ns)?;

-        tracing::info!(
-            inbox_id = self.client.inbox_id(),
-            installation_id = %self.client.installation_id(),
-            sender_inbox_id = sender_inbox_id,
-            sender_installation_id = hex::encode(&sender_installation_id),
-            group_id = hex::encode(&self.group_id),
-            current_epoch = openmls_group.epoch().as_u64(),
-            msg_epoch = decrypted_message.epoch().as_u64(),
-            msg_group_id = hex::encode(decrypted_message.group_id().as_slice()),
-            msg_id,
-            "[{}] extracted sender inbox id: {}",
-            self.client.inbox_id(),
-            sender_inbox_id
-        );
+            tracing::info!(
+                inbox_id = self.client.inbox_id(),
+                installation_id = %self.client.installation_id(),
+                sender_inbox_id = sender_inbox_id,
+                sender_installation_id = hex::encode(&sender_installation_id),
+                group_id = hex::encode(&self.group_id),
+                current_epoch = mls_group.epoch().as_u64(),
+                msg_epoch = decrypted_message.epoch().as_u64(),
+                msg_group_id = hex::encode(decrypted_message.group_id().as_slice()),
+                msg_id,
+                "[{}] extracted sender inbox id: {}",
+                self.client.inbox_id(),
+                sender_inbox_id
+            );

-        let (msg_epoch, msg_group_id) = (
-            decrypted_message.epoch().as_u64(),
-            hex::encode(decrypted_message.group_id().as_slice()),
-        );
-        match decrypted_message.into_content() {
-            ProcessedMessageContent::ApplicationMessage(application_message) => {
-                tracing::info!(
-                    inbox_id = self.client.inbox_id(),
-                    sender_inbox_id = sender_inbox_id,
-                    sender_installation_id = hex::encode(&sender_installation_id),
-                    installation_id = %self.client.installation_id(),
-                    group_id = hex::encode(&self.group_id),
-                    current_epoch = openmls_group.epoch().as_u64(),
-                    msg_epoch,
-                    msg_group_id,
-                    msg_id,
-                    "[{}] decoding application message",
-                    self.context().inbox_id()
-                );
-                let message_bytes = application_message.into_bytes();
-
-                let mut bytes = Bytes::from(message_bytes.clone());
-                let envelope = PlaintextEnvelope::decode(&mut bytes)?;
-
-                match envelope.content {
-                    Some(Content::V1(V1 {
-                        idempotency_key,
-                        content,
-                    })) => {
-                        let message_id =
-                            calculate_message_id(&self.group_id, &content, &idempotency_key);
-                        StoredGroupMessage {
-                            id: message_id,
-                            group_id: self.group_id.clone(),
-                            decrypted_message_bytes: content,
-                            sent_at_ns: envelope_timestamp_ns as i64,
-                            kind: GroupMessageKind::Application,
-                            sender_installation_id,
-                            sender_inbox_id,
-                            delivery_status: DeliveryStatus::Published,
+            let (msg_epoch, msg_group_id) = (
+                decrypted_message.epoch().as_u64(),
+                hex::encode(decrypted_message.group_id().as_slice()),
+            );
+            match decrypted_message.into_content() {
+                ProcessedMessageContent::ApplicationMessage(application_message) => {
+                    tracing::info!(
+                        inbox_id = self.client.inbox_id(),
+                        sender_inbox_id = sender_inbox_id,
+                        sender_installation_id = hex::encode(&sender_installation_id),
+                        installation_id = %self.client.installation_id(),
+                        group_id = hex::encode(&self.group_id),
+                        current_epoch = mls_group.epoch().as_u64(),
+                        msg_epoch,
+                        msg_group_id,
+                        msg_id,
+                        "[{}] decoding application message",
+                        self.context().inbox_id()
+                    );
+                    let message_bytes = application_message.into_bytes();
+
+                    let mut bytes = Bytes::from(message_bytes.clone());
+                    let envelope =
PlaintextEnvelope::decode(&mut bytes)?; + + match envelope.content { + Some(Content::V1(V1 { + idempotency_key, + content, + })) => { + let message_id = + calculate_message_id(&self.group_id, &content, &idempotency_key); + StoredGroupMessage { + id: message_id, + group_id: self.group_id.clone(), + decrypted_message_bytes: content, + sent_at_ns: envelope_timestamp_ns as i64, + kind: GroupMessageKind::Application, + sender_installation_id, + sender_inbox_id, + delivery_status: DeliveryStatus::Published, + } + .store_or_ignore(provider.conn_ref())? } - .store_or_ignore(provider.conn_ref())? - } - Some(Content::V2(V2 { - idempotency_key, - message_type, - })) => { - match message_type { - Some(MessageType::DeviceSyncRequest(history_request)) => { - let content: DeviceSyncContent = - DeviceSyncContent::Request(history_request); - let content_bytes = serde_json::to_vec(&content)?; - let message_id = calculate_message_id( - &self.group_id, - &content_bytes, - &idempotency_key, - ); - - // store the request message - StoredGroupMessage { - id: message_id.clone(), - group_id: self.group_id.clone(), - decrypted_message_bytes: content_bytes, - sent_at_ns: envelope_timestamp_ns as i64, - kind: GroupMessageKind::Application, - sender_installation_id, - sender_inbox_id: sender_inbox_id.clone(), - delivery_status: DeliveryStatus::Published, + Some(Content::V2(V2 { + idempotency_key, + message_type, + })) => { + match message_type { + Some(MessageType::DeviceSyncRequest(history_request)) => { + let content: DeviceSyncContent = + DeviceSyncContent::Request(history_request); + let content_bytes = serde_json::to_vec(&content)?; + let message_id = calculate_message_id( + &self.group_id, + &content_bytes, + &idempotency_key, + ); + + // store the request message + StoredGroupMessage { + id: message_id.clone(), + group_id: self.group_id.clone(), + decrypted_message_bytes: content_bytes, + sent_at_ns: envelope_timestamp_ns as i64, + kind: GroupMessageKind::Application, + sender_installation_id, + sender_inbox_id: sender_inbox_id.clone(), + delivery_status: DeliveryStatus::Published, + } + .store_or_ignore(provider.conn_ref())?; + + tracing::info!("Received a history request."); + let _ = self.client.local_events().send(LocalEvents::SyncMessage( + SyncMessage::Request { message_id }, + )); } - .store_or_ignore(provider.conn_ref())?; - - tracing::info!("Received a history request."); - let _ = self.client.local_events().send(LocalEvents::SyncMessage( - SyncMessage::Request { message_id }, - )); - } - Some(MessageType::DeviceSyncReply(history_reply)) => { - let content: DeviceSyncContent = - DeviceSyncContent::Reply(history_reply); - let content_bytes = serde_json::to_vec(&content)?; - let message_id = calculate_message_id( - &self.group_id, - &content_bytes, - &idempotency_key, - ); - - // store the reply message - StoredGroupMessage { - id: message_id.clone(), - group_id: self.group_id.clone(), - decrypted_message_bytes: content_bytes, - sent_at_ns: envelope_timestamp_ns as i64, - kind: GroupMessageKind::Application, - sender_installation_id, - sender_inbox_id, - delivery_status: DeliveryStatus::Published, + Some(MessageType::DeviceSyncReply(history_reply)) => { + let content: DeviceSyncContent = + DeviceSyncContent::Reply(history_reply); + let content_bytes = serde_json::to_vec(&content)?; + let message_id = calculate_message_id( + &self.group_id, + &content_bytes, + &idempotency_key, + ); + + // store the reply message + StoredGroupMessage { + id: message_id.clone(), + group_id: self.group_id.clone(), + 
decrypted_message_bytes: content_bytes, + sent_at_ns: envelope_timestamp_ns as i64, + kind: GroupMessageKind::Application, + sender_installation_id, + sender_inbox_id, + delivery_status: DeliveryStatus::Published, + } + .store_or_ignore(provider.conn_ref())?; + + tracing::info!("Received a history reply."); + let _ = self.client.local_events().send(LocalEvents::SyncMessage( + SyncMessage::Reply { message_id }, + )); } - .store_or_ignore(provider.conn_ref())?; - - tracing::info!("Received a history reply."); - let _ = self.client.local_events().send(LocalEvents::SyncMessage( - SyncMessage::Reply { message_id }, - )); - } - Some(MessageType::UserPreferenceUpdate(update)) => { - // Ignore errors since this may come from a newer version of the lib - // that has new update types. - if let Ok(update) = update.try_into() { - let _ = self - .client - .local_events() - .send(LocalEvents::IncomingPreferenceUpdate(vec![update])); - } else { - tracing::warn!("Failed to deserialize preference update. Is this libxmtp version old?"); + Some(MessageType::UserPreferenceUpdate(update)) => { + // Ignore errors since this may come from a newer version of the lib + // that has new update types. + if let Ok(update) = update.try_into() { + let _ = self + .client + .local_events() + .send(LocalEvents::IncomingPreferenceUpdate(vec![update])); + } else { + tracing::warn!("Failed to deserialize preference update. Is this libxmtp version old?"); + } + } + _ => { + return Err(GroupMessageProcessingError::InvalidPayload); } - } - _ => { - return Err(GroupMessageProcessingError::InvalidPayload); } } + None => return Err(GroupMessageProcessingError::InvalidPayload), } - None => return Err(GroupMessageProcessingError::InvalidPayload), } - } - ProcessedMessageContent::ProposalMessage(_proposal_ptr) => { - // intentionally left blank. - } - ProcessedMessageContent::ExternalJoinProposalMessage(_external_proposal_ptr) => { - // intentionally left blank. - } - ProcessedMessageContent::StagedCommitMessage(staged_commit) => { - tracing::info!( - inbox_id = self.client.inbox_id(), - sender_inbox_id = sender_inbox_id, - installation_id = %self.client.installation_id(), - sender_installation_id = hex::encode(&sender_installation_id), - group_id = hex::encode(&self.group_id), - current_epoch = openmls_group.epoch().as_u64(), - msg_epoch, - msg_group_id, - msg_id, - "[{}] received staged commit. Merging and clearing any pending commits", - self.context().inbox_id() - ); + ProcessedMessageContent::ProposalMessage(_proposal_ptr) => { + // intentionally left blank. + } + ProcessedMessageContent::ExternalJoinProposalMessage(_external_proposal_ptr) => { + // intentionally left blank. + } + ProcessedMessageContent::StagedCommitMessage(staged_commit) => { + tracing::info!( + inbox_id = self.client.inbox_id(), + sender_inbox_id = sender_inbox_id, + installation_id = %self.client.installation_id(),sender_installation_id = hex::encode(&sender_installation_id), + group_id = hex::encode(&self.group_id), + current_epoch = mls_group.epoch().as_u64(), + msg_epoch, + msg_group_id, + msg_id, + "[{}] received staged commit. 
Merging and clearing any pending commits", + self.context().inbox_id() + ); - let sc = *staged_commit; + let sc = *staged_commit; - // Validate the commit - let validated_commit = ValidatedCommit::from_staged_commit( - self.client.as_ref(), - provider.conn_ref(), - &sc, - openmls_group, - ) - .await?; - tracing::info!( - inbox_id = self.client.inbox_id(), - sender_inbox_id = sender_inbox_id, - installation_id = %self.client.installation_id(), - sender_installation_id = hex::encode(&sender_installation_id), - group_id = hex::encode(&self.group_id), - current_epoch = openmls_group.epoch().as_u64(), - msg_epoch, - msg_group_id, - msg_id, - "[{}] staged commit is valid, will attempt to merge", - self.context().inbox_id() - ); - openmls_group.merge_staged_commit(provider, sc)?; - self.save_transcript_message( - provider.conn_ref(), - validated_commit, - envelope_timestamp_ns, - )?; - } - }; + // Validate the commit + let validated_commit = ValidatedCommit::from_staged_commit( + self.client.as_ref(), + provider.conn_ref(), + &sc, + &mls_group, + ) + .await?; + tracing::info!( + inbox_id = self.client.inbox_id(), + sender_inbox_id = sender_inbox_id, + installation_id = %self.client.installation_id(),sender_installation_id = hex::encode(&sender_installation_id), + group_id = hex::encode(&self.group_id), + current_epoch = mls_group.epoch().as_u64(), + msg_epoch, + msg_group_id, + msg_id, + "[{}] staged commit is valid, will attempt to merge", + self.context().inbox_id() + ); + mls_group.merge_staged_commit(provider, sc)?; + self.save_transcript_message( + provider.conn_ref(), + validated_commit, + envelope_timestamp_ns, + )?; + } + }; - Ok(()) + Ok(()) + }).await } #[tracing::instrument(level = "trace", skip_all)] pub(super) async fn process_message( &self, - openmls_group: &mut OpenMlsGroup, provider: &XmtpOpenMlsProvider, envelope: &GroupMessageV1, allow_epoch_increment: bool, @@ -722,7 +717,6 @@ where inbox_id = self.client.inbox_id(), installation_id = %self.client.installation_id(), group_id = hex::encode(&self.group_id), - current_epoch = openmls_group.epoch().as_u64(), msg_id = envelope.id, "Processing envelope with hash {:?}", hex::encode(sha256(envelope.data.as_slice())) @@ -736,7 +730,6 @@ where inbox_id = self.client.inbox_id(), installation_id = %self.client.installation_id(), group_id = hex::encode(&self.group_id), - current_epoch = openmls_group.epoch().as_u64(), msg_id = envelope.id, intent_id, intent.kind = %intent.kind, @@ -745,25 +738,26 @@ where envelope.id, intent_id ); - match self - .process_own_message(intent, openmls_group, provider, message.into(), envelope) - .await? - { - IntentState::ToPublish => { - Ok(provider.conn_ref().set_group_intent_to_publish(intent_id)?) - } - IntentState::Committed => { - Ok(provider.conn_ref().set_group_intent_committed(intent_id)?) - } - IntentState::Published => { - tracing::error!("Unexpected behaviour: returned intent state published from process_own_message"); - Ok(()) - } - IntentState::Error => { - tracing::warn!("Intent [{}] moved to error status", intent_id); - Ok(provider.conn_ref().set_group_intent_error(intent_id)?) + match self + .process_own_message(intent, provider, message.into(), envelope) + .await? + { + IntentState::ToPublish => { + Ok(provider.conn_ref().set_group_intent_to_publish(intent_id)?) + } + IntentState::Committed => { + Ok(provider.conn_ref().set_group_intent_committed(intent_id)?) 
+ } + IntentState::Published => { + tracing::error!("Unexpected behaviour: returned intent state published from process_own_message"); + Ok(()) + } + IntentState::Error => { + tracing::warn!("Intent [{}] moved to error status", intent_id); + Ok(provider.conn_ref().set_group_intent_error(intent_id)?) + } } - } + } // No matching intent found Ok(None) => { @@ -771,13 +765,12 @@ where inbox_id = self.client.inbox_id(), installation_id = %self.client.installation_id(), group_id = hex::encode(&self.group_id), - current_epoch = openmls_group.epoch().as_u64(), msg_id = envelope.id, "client [{}] is about to process external envelope [{}]", self.client.inbox_id(), envelope.id ); - self.process_external_message(openmls_group, provider, message, envelope) + self.process_external_message(provider, message, envelope) .await } Err(err) => Err(GroupMessageProcessingError::Storage(err)), @@ -789,7 +782,6 @@ where &self, provider: &XmtpOpenMlsProvider, envelope: &GroupMessage, - openmls_group: &mut OpenMlsGroup, ) -> Result<(), GroupMessageProcessingError> { let msgv1 = match &envelope.version { Some(GroupMessageVersion::V1(value)) => value, @@ -831,7 +823,7 @@ where if !is_updated { return Err(ProcessIntentError::AlreadyProcessed(*cursor).into()); } - self.process_message(openmls_group, provider, msgv1, true).await?; + self.process_message(provider, msgv1, true).await?; Ok::<_, GroupMessageProcessingError>(()) }).await .inspect(|_| { @@ -859,14 +851,12 @@ where messages: Vec, provider: &XmtpOpenMlsProvider, ) -> Result<(), GroupError> { - let mut openmls_group = self.load_mls_group(provider)?; - let mut receive_errors: Vec = vec![]; for message in messages.into_iter() { let result = retry_async!( Retry::default(), (async { - self.consume_message(provider, &message, &mut openmls_group) + self.consume_message(provider, &message) .await }) ); @@ -1215,6 +1205,10 @@ where provider: &XmtpOpenMlsProvider, update_interval_ns: Option, ) -> Result<(), GroupError> { + // Check if we're still part of the group + if !self.is_active(provider)? 
{
+            return Ok(());
+        }
         // determine how long of an interval in time to use before updating list
         let interval_ns = match update_interval_ns {
             Some(val) => val,
             None => SYNC_UPDATE_INSTALLATIONS_INTERVAL_NS,
         };
diff --git a/xmtp_mls/src/groups/mod.rs b/xmtp_mls/src/groups/mod.rs
index 360245fd1..b11385322 100644
--- a/xmtp_mls/src/groups/mod.rs
+++ b/xmtp_mls/src/groups/mod.rs
@@ -58,6 +58,7 @@ use self::{
     group_permissions::PolicySet,
     validated_commit::CommitValidationError,
 };
+use std::future::Future;
 use std::{collections::HashSet, sync::Arc};
 use xmtp_cryptography::signature::{sanitize_evm_addresses, AddressValidationError};
 use xmtp_id::{InboxId, InboxIdRef};
@@ -80,6 +81,7 @@ use crate::{
         MAX_PAST_EPOCHS, MUTABLE_METADATA_EXTENSION_ID,
         SEND_MESSAGE_UPDATE_INSTALLATIONS_INTERVAL_NS,
     },
+    groups,
     hpke::{decrypt_welcome, HpkeError},
     identity::{parse_credential, IdentityError},
     identity_updates::{load_identity_updates, InstallationDiffError},
@@ -96,8 +98,10 @@ use crate::{
     subscriptions::{LocalEventError, LocalEvents},
     utils::{id::calculate_message_id, time::now_ns},
     xmtp_openmls_provider::XmtpOpenMlsProvider,
-    Store,
+    GroupCommitLock, Store, MLS_COMMIT_LOCK,
 };
+use crate::hpke::HpkeError::StorageError;
+use crate::storage::sql_key_store::SqlKeyStoreError;

 #[derive(Debug, Error)]
 pub enum GroupError {
@@ -202,6 +206,8 @@ pub enum GroupError {
     IntentNotCommitted,
     #[error(transparent)]
     ProcessIntent(#[from] ProcessIntentError),
+    #[error("Failed to acquire lock for group operation")]
+    LockUnavailable,
 }

 impl RetryableError for GroupError {
@@ -228,6 +234,7 @@ impl RetryableError for GroupError {
             Self::MessageHistory(err) => err.is_retryable(),
             Self::ProcessIntent(err) => err.is_retryable(),
             Self::LocalEvent(err) => err.is_retryable(),
+            Self::LockUnavailable => true,
             Self::SyncFailedToWait => true,
             Self::GroupNotFound
             | Self::GroupMetadata(_)
@@ -343,6 +350,59 @@ impl MlsGroup {
         Ok(mls_group)
     }

+    #[tracing::instrument(level = "trace", skip_all)]
+    pub(crate) fn load_mls_group_with_lock<F, R>(
+        &self,
+        provider: impl OpenMlsProvider,
+        operation: F,
+    ) -> Result<R, GroupError>
+    where
+        F: FnOnce(OpenMlsGroup) -> Result<R, GroupError>,
+    {
+        // Get the group ID for locking
+        let group_id = self.group_id.clone();
+
+        // Acquire the lock synchronously using blocking_lock
+
+        let _lock = MLS_COMMIT_LOCK.get_lock_sync(group_id.clone())?;
+        // Load the MLS group
+        let mls_group =
+            OpenMlsGroup::load(provider.storage(), &GroupId::from_slice(&self.group_id))
+                .map_err(|_| GroupError::GroupNotFound)?
+                .ok_or(GroupError::GroupNotFound)?;
+
+        // Perform the operation with the MLS group
+        operation(mls_group)
+    }
+
+    #[tracing::instrument(level = "trace", skip_all)]
+    pub(crate) async fn load_mls_group_with_lock_async<F, E, R, Fut>(
+        &self,
+        provider: impl OpenMlsProvider,
+        operation: F,
+    ) -> Result<R, E>
+    where
+        F: FnOnce(OpenMlsGroup) -> Fut + Send + 'static,
+        Fut: Future<Output = Result<R, E>>,
+        E: From<GroupError>,
+        E: Into<GroupError>,
+        E: From<StorageError>
+    {
+        // Get the group ID for locking
+        let group_id = self.group_id.clone();
+
+        // Acquire the lock asynchronously
+        let _lock = MLS_COMMIT_LOCK.get_lock_async(group_id.clone()).await;
+
+        // Load the MLS group
+        let mls_group =
+            OpenMlsGroup::load(provider.storage(), &GroupId::from_slice(&self.group_id))?
+                .ok_or(GroupError::GroupNotFound)?;
+
+        // Perform the operation with the MLS group
+        operation(mls_group).await.map_err(Into::into)
+    }
+
     // Create a new group and save it to the DB
     pub(crate) fn create_and_insert(
         client: Arc,
@@ -1126,12 +1186,11 @@ impl MlsGroup {
         self.sync_until_intent_resolved(&provider, intent.id).await
     }

-    /// Checks if the the current user is active in the group.
+    /// Checks if the current user is active in the group.
     ///
     /// If the current user has been kicked out of the group, `is_active` will return `false`
     pub fn is_active(&self, provider: impl OpenMlsProvider) -> Result<bool, GroupError> {
-        let mls_group = self.load_mls_group(provider)?;
-        Ok(mls_group.is_active())
+        self.load_mls_group_with_lock(provider, |mls_group| Ok(mls_group.is_active()))
     }

     /// Get the `GroupMetadata` of the group.
@@ -3700,9 +3759,8 @@ pub(crate) mod tests {
             panic!("wrong message format")
         };
         let provider = client.mls_provider().unwrap();
-        let mut openmls_group = group.load_mls_group(&provider).unwrap();
         let process_result = group
-            .process_message(&mut openmls_group, &provider, &first_message, false)
+            .process_message(&provider, &first_message, false)
             .await;

         assert_err!(
diff --git a/xmtp_mls/src/groups/subscriptions.rs b/xmtp_mls/src/groups/subscriptions.rs
index fbaaa20d6..c011ccd22 100644
--- a/xmtp_mls/src/groups/subscriptions.rs
+++ b/xmtp_mls/src/groups/subscriptions.rs
@@ -43,31 +43,29 @@ impl MlsGroup {
         let process_result = retry_async!(
             Retry::default(),
             (async {
-                let client_id = &client_id;
                 let msgv1 = &msgv1;
                 self.context()
                     .store()
                     .transaction_async(provider, |provider| async move {
-                        let mut openmls_group = self.load_mls_group(provider)?;
-
-                        // Attempt processing immediately, but fail if the message is not an Application Message
-                        // Returning an error should roll back the DB tx
-                        tracing::info!(
+                        // let prov_ref = &provider; // Borrow provider instead of moving it
+                        self.load_mls_group_with_lock_async(provider, |mut mls_group| async move {
+                            // Attempt processing immediately, but fail if the message is not an Application Message
+                            // Returning an error should roll back the DB tx
+                            tracing::info!(
                                 inbox_id = self.client.inbox_id(),
                                 group_id = hex::encode(&self.group_id),
-                                current_epoch = openmls_group.epoch().as_u64(),
                                 msg_id = msgv1.id,
-                                "current epoch for [{}] in process_stream_entry() is Epoch: [{}]",
+                                "current epoch for [{}] in process_stream_entry()",
                                 client_id,
-                                openmls_group.epoch()
                             );
-
-                        self.process_message(&mut openmls_group, provider, msgv1, false)
-                            .await
-                            // NOTE: We want to make sure we retry an error in process_message
-                            .map_err(SubscribeError::ReceiveGroup)
+                            self.process_message(&mut mls_group, provider, msgv1, false)
+                                .await
+                                // NOTE: We want to make sure we retry an error in process_message
+                                .map_err(SubscribeError::ReceiveGroup)
+                        }).await
                     })
                     .await
+            })
         );

diff --git a/xmtp_mls/src/lib.rs b/xmtp_mls/src/lib.rs
index fb21ba10b..6b82a8f6a 100644
--- a/xmtp_mls/src/lib.rs
+++ b/xmtp_mls/src/lib.rs
@@ -21,15 +21,76 @@ pub mod utils;
 pub mod verified_key_package_v2;
 mod xmtp_openmls_provider;

+use std::collections::HashMap;
+use std::sync::{Arc, Mutex};
+use tokio::sync::{Semaphore, OwnedSemaphorePermit};
 pub use client::{Client, Network};
 use storage::{DuplicateItem, StorageError};
 pub use xmtp_openmls_provider::XmtpOpenMlsProvider;
+use std::sync::LazyLock;

 pub use xmtp_id::InboxOwner;
 pub use xmtp_proto::api_client::trait_impls::*;

 #[macro_use]
 extern crate tracing;

+/// A manager for group-specific semaphores
+#[derive(Debug)]
+pub struct GroupCommitLock {
+    // Storage for group-specific semaphores
+    locks: Mutex<HashMap<Vec<u8>, Arc<Semaphore>>>,
+}
+
+impl GroupCommitLock {
+    /// Create a new `GroupCommitLock`
+    pub fn new() -> Self {
+        Self {
+            locks: Mutex::new(HashMap::new()),
+        }
+    }
+
+    /// Get or create a semaphore for a specific group and acquire it, returning a guard
+    pub async fn get_lock_async(&self, group_id: Vec<u8>) -> SemaphoreGuard {
+        let semaphore = {
+            let mut locks = self.locks.lock().unwrap();
+            locks
+                .entry(group_id)
+                .or_insert_with(|| Arc::new(Semaphore::new(1)))
+                .clone()
+        };
+
+        let semaphore_clone = semaphore.clone();
+        let permit = semaphore.acquire_owned().await.unwrap();
+        SemaphoreGuard { _permit: permit, _semaphore: semaphore_clone }
+    }
+
+    /// Get or create a semaphore for a specific group and acquire it synchronously
+    pub fn get_lock_sync(&self, group_id: Vec<u8>) -> Result<SemaphoreGuard, GroupError> {
+        let semaphore = {
+            let mut locks = self.locks.lock().unwrap();
+            locks
+                .entry(group_id)
+                .or_insert_with(|| Arc::new(Semaphore::new(1)))
+                .clone() // Clone here to retain ownership for later use
+        };
+
+        // Synchronously acquire the permit
+        let permit = semaphore.clone().try_acquire_owned().map_err(|_| GroupError::LockUnavailable)?;
+        Ok(SemaphoreGuard {
+            _permit: permit,
+            _semaphore: semaphore, // semaphore is now valid because we cloned it earlier
+        })
+    }
+}
+
+/// A guard that releases the semaphore when dropped
+pub struct SemaphoreGuard {
+    _permit: OwnedSemaphorePermit,
+    _semaphore: Arc<Semaphore>,
+}
+
+// Static instance of `GroupCommitLock`
+pub static MLS_COMMIT_LOCK: LazyLock<GroupCommitLock> = LazyLock::new(GroupCommitLock::new);

 /// Global Marker trait for WebAssembly
 #[cfg(target_arch = "wasm32")]
@@ -78,6 +139,7 @@ pub trait Delete {
 pub use stream_handles::{
     spawn, AbortHandle, GenericStreamHandle, StreamHandle, StreamHandleError,
 };
+use crate::groups::GroupError;

 #[cfg(target_arch = "wasm32")]
 #[doc(hidden)]

From 2b57721cebd4d5a85f1aed95fdf3c8592cbe4d75 Mon Sep 17 00:00:00 2001
From: Mojtaba Chenani
Date: Mon, 9 Dec 2024 18:53:28 +0100
Subject: [PATCH 06/28] wip

---
 xmtp_mls/src/groups/members.rs       |   7 +-
 xmtp_mls/src/groups/mls_sync.rs      | 266 +++++++++++++--------------
 xmtp_mls/src/groups/mod.rs           | 134 ++++++++------
 xmtp_mls/src/groups/subscriptions.rs |   1 -
 xmtp_mls/src/storage/errors.rs       |   5 +-
 5 files changed, 215 insertions(+), 198 deletions(-)

diff --git a/xmtp_mls/src/groups/members.rs b/xmtp_mls/src/groups/members.rs
index 5ca53a40a..730529fb3 100644
--- a/xmtp_mls/src/groups/members.rs
+++ b/xmtp_mls/src/groups/members.rs
@@ -40,9 +40,10 @@ where
         &self,
         provider: &XmtpOpenMlsProvider,
     ) -> Result<Vec<GroupMember>, GroupError> {
-        let openmls_group = self.load_mls_group(provider)?;
-        // TODO: Replace with try_into from extensions
-        let group_membership = extract_group_membership(openmls_group.extensions())?;
+        let group_membership = self.load_mls_group_with_lock(provider, |mls_group| {
+            // Extract group membership from extensions
+            Ok(extract_group_membership(mls_group.extensions())?)
+ })?; let requests = group_membership .members .into_iter() diff --git a/xmtp_mls/src/groups/mls_sync.rs b/xmtp_mls/src/groups/mls_sync.rs index 3be8e70b6..bfad0d255 100644 --- a/xmtp_mls/src/groups/mls_sync.rs +++ b/xmtp_mls/src/groups/mls_sync.rs @@ -797,7 +797,6 @@ where let last_cursor = provider .conn_ref() .get_last_cursor_for_id(&self.group_id, message_entity_kind)?; - tracing::info!("### last cursor --> [{:?}]", last_cursor); let should_skip_message = last_cursor > msgv1.id as i64; if should_skip_message { tracing::info!( @@ -953,103 +952,104 @@ where &self, provider: &XmtpOpenMlsProvider, ) -> Result<(), GroupError> { - let mut openmls_group = self.load_mls_group(provider)?; + self.load_mls_group_with_lock_async(provider, |mut mls_group| async move { + let intents = provider.conn_ref().find_group_intents( + self.group_id.clone(), + Some(vec![IntentState::ToPublish]), + None, + )?; + + for intent in intents { + let result = retry_async!( + Retry::default(), + (async { + self.get_publish_intent_data(provider, &mut mls_group, &intent) + .await + }) + ); - let intents = provider.conn_ref().find_group_intents( - self.group_id.clone(), - Some(vec![IntentState::ToPublish]), - None, - )?; + match result { + Err(err) => { + tracing::error!(error = %err, "error getting publish intent data {:?}", err); + if (intent.publish_attempts + 1) as usize >= MAX_INTENT_PUBLISH_ATTEMPTS { + tracing::error!( + intent.id, + intent.kind = %intent.kind, + inbox_id = self.client.inbox_id(), + installation_id = %self.client.installation_id(),group_id = hex::encode(&self.group_id), + "intent {} has reached max publish attempts", intent.id); + // TODO: Eventually clean up errored attempts + provider + .conn_ref() + .set_group_intent_error_and_fail_msg(&intent)?; + } else { + provider + .conn_ref() + .increment_intent_publish_attempt_count(intent.id)?; + } - for intent in intents { - let result = retry_async!( - Retry::default(), - (async { - self.get_publish_intent_data(provider, &mut openmls_group, &intent) - .await - }) - ); + return Err(err); + } + Ok(Some(PublishIntentData { + payload_to_publish, + post_commit_action, + staged_commit, + })) => { + let payload_slice = payload_to_publish.as_slice(); + let has_staged_commit = staged_commit.is_some(); + provider.conn_ref().set_group_intent_published( + intent.id, + sha256(payload_slice), + post_commit_action, + staged_commit, + mls_group.epoch().as_u64() as i64, + )?; + tracing::debug!( + inbox_id = self.client.inbox_id(), + installation_id = %self.client.installation_id(), + intent.id, + intent.kind = %intent.kind, + group_id = hex::encode(&self.group_id), + "client [{}] set stored intent [{}] to state `published`", + self.client.inbox_id(), + intent.id + ); - match result { - Err(err) => { - tracing::error!(error = %err, "error getting publish intent data {:?}", err); - if (intent.publish_attempts + 1) as usize >= MAX_INTENT_PUBLISH_ATTEMPTS { - tracing::error!( + let messages = self.prepare_group_messages(vec![payload_slice])?;self.client + .api() + .send_group_messages(messages) + .await?; + + tracing::info!( intent.id, intent.kind = %intent.kind, inbox_id = self.client.inbox_id(), installation_id = %self.client.installation_id(), group_id = hex::encode(&self.group_id), - "intent {} has reached max publish attempts", intent.id); - // TODO: Eventually clean up errored attempts - provider - .conn_ref() - .set_group_intent_error_and_fail_msg(&intent)?; - } else { - provider - .conn_ref() - .increment_intent_publish_attempt_count(intent.id)?; + "[{}] 
published intent [{}] of type [{}]", + self.client.inbox_id(), + intent.id, + intent.kind + ); + if has_staged_commit { + tracing::info!("Commit sent. Stopping further publishes for this round"); + return Ok(()); + } } - - return Err(err); - } - Ok(Some(PublishIntentData { - payload_to_publish, - post_commit_action, - staged_commit, - })) => { - let payload_slice = payload_to_publish.as_slice(); - let has_staged_commit = staged_commit.is_some(); - provider.conn_ref().set_group_intent_published( - intent.id, - sha256(payload_slice), - post_commit_action, - staged_commit, - openmls_group.epoch().as_u64() as i64, - )?; - tracing::debug!( - inbox_id = self.client.inbox_id(), - installation_id = %self.client.installation_id(), - intent.id, - intent.kind = %intent.kind, - group_id = hex::encode(&self.group_id), - "client [{}] set stored intent [{}] to state `published`", - self.client.inbox_id(), - intent.id - ); - - let messages = self.prepare_group_messages(vec![payload_slice])?; - self.client.api().send_group_messages(messages).await?; - - tracing::info!( - intent.id, - intent.kind = %intent.kind, - inbox_id = self.client.inbox_id(), - installation_id = %self.client.installation_id(), - group_id = hex::encode(&self.group_id), - "[{}] published intent [{}] of type [{}]", - self.client.inbox_id(), - intent.id, - intent.kind - ); - if has_staged_commit { - tracing::info!("Commit sent. Stopping further publishes for this round"); - return Ok(()); + Ok(None) => { + tracing::info!( + inbox_id = self.client.inbox_id(), + installation_id = %self.client.installation_id(), + "Skipping intent because no publish data returned" + ); + let deleter: &dyn Delete = provider.conn_ref(); + deleter.delete(intent.id)?; } } - Ok(None) => { - tracing::info!( - inbox_id = self.client.inbox_id(), - installation_id = %self.client.installation_id(), - "Skipping intent because no publish data returned" - ); - let deleter: &dyn Delete = provider.conn_ref(); - deleter.delete(intent.id)?; - } } - } - Ok(()) + Ok(()) + }).await } // Takes a StoredGroupIntent and returns the payload and post commit data as a tuple @@ -1280,58 +1280,58 @@ where inbox_ids_to_add: &[InboxIdRef<'_>], inbox_ids_to_remove: &[InboxIdRef<'_>], ) -> Result { - let mls_group = self.load_mls_group(provider)?; - let existing_group_membership = extract_group_membership(mls_group.extensions())?; - - // TODO:nm prevent querying for updates on members who are being removed - let mut inbox_ids = existing_group_membership.inbox_ids(); - inbox_ids.extend_from_slice(inbox_ids_to_add); - let conn = provider.conn_ref(); - // Load any missing updates from the network - load_identity_updates(self.client.api(), conn, &inbox_ids).await?; - - let latest_sequence_id_map = conn.get_latest_sequence_id(&inbox_ids as &[&str])?; - - // Get a list of all inbox IDs that have increased sequence_id for the group - let changed_inbox_ids = - inbox_ids - .iter() - .try_fold(HashMap::new(), |mut updates, inbox_id| { - match ( - latest_sequence_id_map.get(inbox_id as &str), - existing_group_membership.get(inbox_id), - ) { - // This is an update. 
We have a new sequence ID and an existing one - (Some(latest_sequence_id), Some(current_sequence_id)) => { - let latest_sequence_id_u64 = *latest_sequence_id as u64; - if latest_sequence_id_u64.gt(current_sequence_id) { - updates.insert(inbox_id.to_string(), latest_sequence_id_u64); + self.load_mls_group_with_lock_async(provider, |mut mls_group| async move { + let existing_group_membership = extract_group_membership(mls_group.extensions())?; + // TODO:nm prevent querying for updates on members who are being removed + let mut inbox_ids = existing_group_membership.inbox_ids(); + inbox_ids.extend_from_slice(inbox_ids_to_add); + let conn = provider.conn_ref(); + // Load any missing updates from the network + load_identity_updates(self.client.api(), conn, &inbox_ids).await?; + + let latest_sequence_id_map = conn.get_latest_sequence_id(&inbox_ids as &[&str])?; + + // Get a list of all inbox IDs that have increased sequence_id for the group + let changed_inbox_ids = + inbox_ids + .iter() + .try_fold(HashMap::new(), |mut updates, inbox_id| { + match ( + latest_sequence_id_map.get(inbox_id as &str), + existing_group_membership.get(inbox_id), + ) { + // This is an update. We have a new sequence ID and an existing one + (Some(latest_sequence_id), Some(current_sequence_id)) => { + let latest_sequence_id_u64 = *latest_sequence_id as u64; + if latest_sequence_id_u64.gt(current_sequence_id) { + updates.insert(inbox_id.to_string(), latest_sequence_id_u64); + } + } + // This is for new additions to the group + (Some(latest_sequence_id), _) => { + // This is the case for net new members to the group + updates.insert(inbox_id.to_string(), *latest_sequence_id as u64); + } + (_, _) => { + tracing::warn!( + "Could not find existing sequence ID for inbox {}", + inbox_id + ); + return Err(GroupError::MissingSequenceId); } } - // This is for new additions to the group - (Some(latest_sequence_id), _) => { - // This is the case for net new members to the group - updates.insert(inbox_id.to_string(), *latest_sequence_id as u64); - } - (_, _) => { - tracing::warn!( - "Could not find existing sequence ID for inbox {}", - inbox_id - ); - return Err(GroupError::MissingSequenceId); - } - } - Ok(updates) - })?; + Ok(updates) + })?; - Ok(UpdateGroupMembershipIntentData::new( - changed_inbox_ids, - inbox_ids_to_remove - .iter() - .map(|s| s.to_string()) - .collect::>(), - )) + Ok(UpdateGroupMembershipIntentData::new( + changed_inbox_ids, + inbox_ids_to_remove + .iter() + .map(|s| s.to_string()) + .collect::>(), + )) + }).await } /** diff --git a/xmtp_mls/src/groups/mod.rs b/xmtp_mls/src/groups/mod.rs index b11385322..73f814f11 100644 --- a/xmtp_mls/src/groups/mod.rs +++ b/xmtp_mls/src/groups/mod.rs @@ -58,6 +58,7 @@ use self::{ group_permissions::PolicySet, validated_commit::CommitValidationError, }; +use futures::TryFutureExt; use std::future::Future; use std::{collections::HashSet, sync::Arc}; use xmtp_cryptography::signature::{sanitize_evm_addresses, AddressValidationError}; @@ -73,6 +74,8 @@ use xmtp_proto::xmtp::mls::{ }, }; +use crate::hpke::HpkeError::StorageError; +use crate::storage::sql_key_store::SqlKeyStoreError; use crate::{ api::WrappedApiError, client::{deserialize_welcome, ClientError, XmtpMlsLocalContext}, @@ -100,8 +103,6 @@ use crate::{ xmtp_openmls_provider::XmtpOpenMlsProvider, GroupCommitLock, Store, MLS_COMMIT_LOCK, }; -use crate::hpke::HpkeError::StorageError; -use crate::storage::sql_key_store::SqlKeyStoreError; #[derive(Debug, Error)] pub enum GroupError { @@ -376,17 +377,15 @@ impl MlsGroup { } 
#[tracing::instrument(level = "trace", skip_all)] - pub(crate) async fn load_mls_group_with_lock_async( + pub(crate) async fn load_mls_group_with_lock_async( &self, - provider: impl OpenMlsProvider, + provider: &XmtpOpenMlsProvider, operation: F, ) -> Result<R, E> where - F: FnOnce(OpenMlsGroup) -> Fut + Send + 'static, - Fut: Future<Output = Result<R, E>>, - E: From, - E: Into, - E: From + F: FnOnce(OpenMlsGroup) -> Fut + Send, + Fut: Future<Output = Result<R, E>>, + E: From<GroupMessageProcessingError> + From<crate::StorageError>, { // Get the group ID for locking let group_id = self.group_id.clone(); @@ -396,8 +395,9 @@ impl MlsGroup { // Load the MLS group let mls_group = - OpenMlsGroup::load(provider.storage(), &GroupId::from_slice(&self.group_id))? - .ok_or(|e|GroupMessageProcessingError::Storage)?; + OpenMlsGroup::load(provider.storage(), &GroupId::from_slice(&self.group_id)) + .map_err(crate::StorageError::from)? + .ok_or(crate::StorageError::NotFound("Group Not Found".into()))?; // Perform the operation with the MLS group operation(mls_group).await.map_err(Into::into) @@ -1195,8 +1195,9 @@ impl MlsGroup { /// Get the `GroupMetadata` of the group. pub fn metadata(&self, provider: impl OpenMlsProvider) -> Result<GroupMetadata, GroupError> { - let mls_group = self.load_mls_group(provider)?; - Ok(extract_group_metadata(&mls_group)?) + self.load_mls_group_with_lock(provider, |mls_group| { + Ok(extract_group_metadata(&mls_group)?) + }) } /// Get the `GroupMutableMetadata` of the group. @@ -1204,16 +1205,17 @@ &self, provider: impl OpenMlsProvider, ) -> Result<GroupMutableMetadata, GroupError> { - let mls_group = &self.load_mls_group(provider)?; - - Ok(mls_group.try_into()?) + self.load_mls_group_with_lock(provider, |mls_group| { + Ok(GroupMutableMetadata::try_from(&mls_group)?) + }) } pub fn permissions(&self) -> Result<GroupMutablePermissions, GroupError> { let provider = self.mls_provider()?; - let mls_group = self.load_mls_group(&provider)?; - Ok(extract_group_permissions(&mls_group)?) + self.load_mls_group_with_lock(&provider, |mls_group| { + Ok(extract_group_permissions(&mls_group)?) + }) } /// Used for testing that dm group validation works as expected. @@ -1937,15 +1939,19 @@ // Check Amal's MLS group state. let amal_db = XmtpOpenMlsProvider::from(amal.context.store().conn().unwrap()); - let amal_mls_group = amal_group.load_mls_group(&amal_db).unwrap(); - let amal_members: Vec<Member> = amal_mls_group.members().collect(); - assert_eq!(amal_members.len(), 3); + let amal_members_len = amal_group.load_mls_group_with_lock(&amal_db, |amal_mls_group| { + Ok(amal_mls_group.members().count()) + }).unwrap(); + + assert_eq!(amal_members_len, 3); // Check Bola's MLS group state.
let bola_db = XmtpOpenMlsProvider::from(bola.context.store().conn().unwrap()); - let bola_mls_group = bola_group.load_mls_group(&bola_db).unwrap(); - let bola_members: Vec = bola_mls_group.members().collect(); - assert_eq!(bola_members.len(), 3); + let bola_members_len = bola_group.load_mls_group_with_lock(&bola_db, |bola_mls_group| { + Ok(bola_mls_group.members().count()) + }).unwrap(); + + assert_eq!(bola_members_len, 3); let amal_uncommitted_intents = amal_db .conn_ref() @@ -2004,22 +2010,26 @@ pub(crate) mod tests { .unwrap(); let provider = alix.mls_provider().unwrap(); // Doctor the group membership - let mut mls_group = alix_group.load_mls_group(&provider).unwrap(); - let mut existing_extensions = mls_group.extensions().clone(); - let mut group_membership = GroupMembership::new(); - group_membership.add("deadbeef".to_string(), 1); - existing_extensions.add_or_replace(build_group_membership_extension(&group_membership)); - mls_group - .update_group_context_extensions( - &provider, - existing_extensions.clone(), - &alix.identity().installation_keys, - ) - .unwrap(); - mls_group.merge_pending_commit(&provider).unwrap(); + let mut mls_group = alix_group.load_mls_group_with_lock(&provider, |mut mls_group| { + let mut existing_extensions = mls_group.extensions().clone(); + let mut group_membership = GroupMembership::new(); + group_membership.add("deadbeef".to_string(), 1); + existing_extensions.add_or_replace(build_group_membership_extension(&group_membership)); + + mls_group + .update_group_context_extensions( + &provider, + existing_extensions.clone(), + &alix.identity().installation_keys, + ) + .unwrap(); + mls_group.merge_pending_commit(&provider).unwrap(); + + Ok(mls_group) // Return the updated group if necessary + }).unwrap(); - // Now add bo to the group force_add_member(&alix, &bo, &alix_group, &mut mls_group, &provider).await; + // Now add bo to the group // Bo should not be able to actually read this group bo.sync_welcomes(&bo.mls_provider().unwrap()).await.unwrap(); @@ -2143,9 +2153,11 @@ pub(crate) mod tests { assert_eq!(messages.len(), 2); let provider: XmtpOpenMlsProvider = client.context.store().conn().unwrap().into(); - let mls_group = group.load_mls_group(&provider).unwrap(); - let pending_commit = mls_group.pending_commit(); - assert!(pending_commit.is_none()); + let pending_commit_is_none = group.load_mls_group_with_lock(&provider, |mls_group| { + Ok(mls_group.pending_commit().is_none()) + }).unwrap(); + + assert!(pending_commit_is_none); group.send_message(b"hello").await.expect("send message"); @@ -2324,8 +2336,10 @@ pub(crate) mod tests { assert!(new_installations_were_added.is_ok()); group.sync().await.unwrap(); - let mls_group = group.load_mls_group(&provider).unwrap(); - let num_members = mls_group.members().collect::>().len(); + let num_members = group.load_mls_group_with_lock(&provider, |mls_group| { + Ok(mls_group.members().collect::>().len()) + }).unwrap(); + assert_eq!(num_members, 3); } @@ -3907,14 +3921,11 @@ pub(crate) mod tests { None, ) .unwrap(); - assert!(validate_dm_group( - &client, - &valid_dm_group - .load_mls_group(client.mls_provider().unwrap()) - .unwrap(), - added_by_inbox - ) - .is_ok()); + assert!(valid_dm_group + .load_mls_group_with_lock(client.mls_provider().unwrap(), |mls_group| { + validate_dm_group(&client, &mls_group, added_by_inbox) + }) + .is_ok()); // Test case 2: Invalid conversation type let invalid_protected_metadata = @@ -3929,10 +3940,11 @@ pub(crate) mod tests { ) .unwrap(); assert!(matches!( - validate_dm_group(&client, 
&invalid_type_group.load_mls_group(client.mls_provider().unwrap()).unwrap(), added_by_inbox), + invalid_type_group.load_mls_group_with_lock(client.mls_provider().unwrap(), |mls_group| { + validate_dm_group(&client, &mls_group, added_by_inbox) + }), Err(GroupError::Generic(msg)) if msg.contains("Invalid conversation type") )); - // Test case 3: Missing DmMembers // This case is not easily testable with the current structure, as DmMembers are set in the protected metadata @@ -3950,7 +3962,9 @@ pub(crate) mod tests { ) .unwrap(); assert!(matches!( - validate_dm_group(&client, &mismatched_dm_members_group.load_mls_group(client.mls_provider().unwrap()).unwrap(), added_by_inbox), + mismatched_dm_members_group.load_mls_group_with_lock(client.mls_provider().unwrap(), |mls_group| { + validate_dm_group(&client, &mls_group, added_by_inbox) + }), Err(GroupError::Generic(msg)) if msg.contains("DM members do not match expected inboxes") )); @@ -3970,7 +3984,9 @@ pub(crate) mod tests { ) .unwrap(); assert!(matches!( - validate_dm_group(&client, &non_empty_admin_list_group.load_mls_group(client.mls_provider().unwrap()).unwrap(), added_by_inbox), + non_empty_admin_list_group.load_mls_group_with_lock(client.mls_provider().unwrap(), |mls_group| { + validate_dm_group(&client, &mls_group, added_by_inbox) + }), Err(GroupError::Generic(msg)) if msg.contains("DM group must have empty admin and super admin lists") )); @@ -3989,11 +4005,9 @@ pub(crate) mod tests { ) .unwrap(); assert!(matches!( - validate_dm_group( - &client, - &invalid_permissions_group.load_mls_group(client.mls_provider().unwrap()).unwrap(), - added_by_inbox - ), + invalid_permissions_group.load_mls_group_with_lock(client.mls_provider().unwrap(), |mls_group| { + validate_dm_group(&client, &mls_group, added_by_inbox) + }), Err(GroupError::Generic(msg)) if msg.contains("Invalid permissions for DM group") )); } diff --git a/xmtp_mls/src/groups/subscriptions.rs b/xmtp_mls/src/groups/subscriptions.rs index c011ccd22..8034398bc 100644 --- a/xmtp_mls/src/groups/subscriptions.rs +++ b/xmtp_mls/src/groups/subscriptions.rs @@ -65,7 +65,6 @@ impl MlsGroup { }).await }) .await - }) ); diff --git a/xmtp_mls/src/storage/errors.rs b/xmtp_mls/src/storage/errors.rs index f109b0248..50c630df2 100644 --- a/xmtp_mls/src/storage/errors.rs +++ b/xmtp_mls/src/storage/errors.rs @@ -5,7 +5,7 @@ use thiserror::Error; use crate::{groups::intents::IntentError, retry::RetryableError, retryable}; -use super::sql_key_store; +use super::sql_key_store::{self, SqlKeyStoreError}; #[derive(Debug, Error)] pub enum StorageError { @@ -25,6 +25,7 @@ pub enum StorageError { Serialization(String), #[error("deserialization error")] Deserialization(String), + // TODO:insipx Make NotFound into an enum of possible items that may not be found #[error("{0} not found")] NotFound(String), #[error("lock")] @@ -43,6 +44,8 @@ pub enum StorageError { FromHex(#[from] hex::FromHexError), #[error(transparent)] Duplicate(DuplicateItem), + #[error(transparent)] + OpenMlsStorage(#[from] SqlKeyStoreError), } #[derive(Error, Debug)] From 5b352a72806940cd6a03bbbc02a6dcced4530ba3 Mon Sep 17 00:00:00 2001 From: Mojtaba Chenani Date: Mon, 9 Dec 2024 22:17:34 +0100 Subject: [PATCH 07/28] wip --- bindings_node/src/conversation.rs | 2 +- xmtp_mls/src/client.rs | 4 +- xmtp_mls/src/groups/intents.rs | 6 +-- xmtp_mls/src/groups/members.rs | 1 - xmtp_mls/src/groups/mls_sync.rs | 10 ++-- xmtp_mls/src/groups/mod.rs | 73 ++++++++++++---------------- xmtp_mls/src/groups/subscriptions.rs | 10 +++- 7 files changed, 48 
insertions(+), 58 deletions(-) diff --git a/bindings_node/src/conversation.rs b/bindings_node/src/conversation.rs index ee52eb9d2..897662463 100644 --- a/bindings_node/src/conversation.rs +++ b/bindings_node/src/conversation.rs @@ -1,5 +1,5 @@ use std::{ops::Deref, sync::Arc}; -use futures::TryFutureExt; + use napi::{ bindgen_prelude::{Result, Uint8Array}, threadsafe_function::{ErrorStrategy, ThreadsafeFunction, ThreadsafeFunctionCallMode}, diff --git a/xmtp_mls/src/client.rs b/xmtp_mls/src/client.rs index 5f9fdc82c..98bf00b63 100644 --- a/xmtp_mls/src/client.rs +++ b/xmtp_mls/src/client.rs @@ -865,8 +865,8 @@ where async move { tracing::info!( inbox_id = self.inbox_id(), - "current epoch for [{}] in sync_all_groups()", - self.inbox_id(), + "[{}] syncing group", + self.inbox_id() ); tracing::info!( inbox_id = self.inbox_id(), diff --git a/xmtp_mls/src/groups/intents.rs b/xmtp_mls/src/groups/intents.rs index 7e4276cf0..d0c599862 100644 --- a/xmtp_mls/src/groups/intents.rs +++ b/xmtp_mls/src/groups/intents.rs @@ -727,7 +727,6 @@ impl TryFrom> for PostCommitAction { pub(crate) mod tests { #[cfg(target_arch = "wasm32")] wasm_bindgen_test::wasm_bindgen_test_configure!(run_in_dedicated_worker); - use openmls::prelude::{MlsMessageBodyIn, MlsMessageIn, ProcessedMessageContent}; use tls_codec::Deserialize; use xmtp_cryptography::utils::generate_local_wallet; @@ -868,9 +867,8 @@ pub(crate) mod tests { let provider = group.client.mls_provider().unwrap(); let decrypted_message = match group .load_mls_group_with_lock(&provider, |mut mls_group| { - mls_group - .process_message(&provider, mls_message) - .map_err(|e| GroupError::Generic(e.to_string())) + Ok(mls_group + .process_message(&provider, mls_message).unwrap()) }) { Ok(message) => message, Err(err) => panic!("Error: {:?}", err), diff --git a/xmtp_mls/src/groups/members.rs b/xmtp_mls/src/groups/members.rs index 730529fb3..cfdf56e28 100644 --- a/xmtp_mls/src/groups/members.rs +++ b/xmtp_mls/src/groups/members.rs @@ -41,7 +41,6 @@ where provider: &XmtpOpenMlsProvider, ) -> Result, GroupError> { let group_membership = self.load_mls_group_with_lock(provider, |mls_group| { - // Extract group membership from extensions Ok(extract_group_membership(mls_group.extensions())?) 
})?; let requests = group_membership diff --git a/xmtp_mls/src/groups/mls_sync.rs b/xmtp_mls/src/groups/mls_sync.rs index bfad0d255..a20c40d1d 100644 --- a/xmtp_mls/src/groups/mls_sync.rs +++ b/xmtp_mls/src/groups/mls_sync.rs @@ -369,9 +369,8 @@ where if intent.state == IntentState::Committed { return Ok(IntentState::Committed); } - let group_epoch = mls_group.epoch(); - let message_epoch = message.epoch(); + let group_epoch = mls_group.epoch(); debug!( inbox_id = self.client.inbox_id(), installation_id = %self.client.installation_id(), @@ -856,7 +855,7 @@ where Retry::default(), (async { self.consume_message(provider, &message) - .await + .await }) ); if let Err(e) = result { @@ -1210,10 +1209,7 @@ where return Ok(()); } // determine how long of an interval in time to use before updating list - let interval_ns = match update_interval_ns { - Some(val) => val, - None => SYNC_UPDATE_INSTALLATIONS_INTERVAL_NS, - }; + let interval_ns = update_interval_ns.unwrap_or_else(|| SYNC_UPDATE_INSTALLATIONS_INTERVAL_NS); let now_ns = crate::utils::time::now_ns(); let last_ns = provider diff --git a/xmtp_mls/src/groups/mod.rs b/xmtp_mls/src/groups/mod.rs index 73f814f11..19fe34442 100644 --- a/xmtp_mls/src/groups/mod.rs +++ b/xmtp_mls/src/groups/mod.rs @@ -338,19 +338,6 @@ impl MlsGroup { } // Load the stored OpenMLS group from the OpenMLS provider's keystore - #[tracing::instrument(level = "trace", skip_all)] - pub(crate) fn load_mls_group( - &self, - provider: impl OpenMlsProvider, - ) -> Result { - let mls_group = - OpenMlsGroup::load(provider.storage(), &GroupId::from_slice(&self.group_id)) - .map_err(|_| GroupError::GroupNotFound)? - .ok_or(GroupError::GroupNotFound)?; - - Ok(mls_group) - } - #[tracing::instrument(level = "trace", skip_all)] pub(crate) fn load_mls_group_with_lock( &self, @@ -364,7 +351,6 @@ impl MlsGroup { let group_id = self.group_id.clone(); // Acquire the lock synchronously using blocking_lock - let _lock = MLS_COMMIT_LOCK.get_lock_sync(group_id.clone())?; // Load the MLS group let mls_group = @@ -376,6 +362,7 @@ impl MlsGroup { operation(mls_group) } + // Load the stored OpenMLS group from the OpenMLS provider's keystore #[tracing::instrument(level = "trace", skip_all)] pub(crate) async fn load_mls_group_with_lock_async( &self, @@ -1190,14 +1177,16 @@ impl MlsGroup { /// /// If the current user has been kicked out of the group, `is_active` will return `false` pub fn is_active(&self, provider: impl OpenMlsProvider) -> Result { - self.load_mls_group_with_lock(provider, |mls_group| Ok(mls_group.is_active())) + self.load_mls_group_with_lock(provider, |mls_group| + Ok(mls_group.is_active()) + ) } /// Get the `GroupMetadata` of the group. pub fn metadata(&self, provider: impl OpenMlsProvider) -> Result { - self.load_mls_group_with_lock(provider, |mls_group| { + self.load_mls_group_with_lock(provider, |mls_group| Ok(extract_group_metadata(&mls_group)?) - }) + ) } /// Get the `GroupMutableMetadata` of the group. @@ -1205,17 +1194,17 @@ impl MlsGroup { &self, provider: impl OpenMlsProvider, ) -> Result { - self.load_mls_group_with_lock(provider, |mls_group| { + self.load_mls_group_with_lock(provider, |mls_group| Ok(GroupMutableMetadata::try_from(&mls_group)?) - }) + ) } pub fn permissions(&self) -> Result { let provider = self.mls_provider()?; - self.load_mls_group_with_lock(&provider, |mls_group| { + self.load_mls_group_with_lock(&provider, |mls_group| Ok(extract_group_permissions(&mls_group)?) - }) + ) } /// Used for testing that dm group validation works as expected. 
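
The hunks above settle on the access pattern used for the rest of the series: callers never hold a raw `OpenMlsGroup` directly; they hand a closure to `load_mls_group_with_lock` (or its async twin), which loads the group from the keystore and runs the closure while the per-group commit lock is held. A minimal self-contained sketch of that pattern, with illustrative stand-in names (`CommitLock`, `GroupState`, and `with_group` are not the crate's types):

```rust
use std::collections::HashMap;
use std::sync::{Arc, Mutex};

/// Hypothetical stand-in for the state loaded from the MLS key store.
struct GroupState {
    epoch: u64,
}

#[derive(Default)]
struct CommitLock {
    // One mutex per group id, created lazily on first use.
    locks: Mutex<HashMap<Vec<u8>, Arc<Mutex<()>>>>,
}

impl CommitLock {
    fn lock_for(&self, group_id: &[u8]) -> Arc<Mutex<()>> {
        self.locks
            .lock()
            .expect("lock registry poisoned")
            .entry(group_id.to_vec())
            .or_insert_with(|| Arc::new(Mutex::new(())))
            .clone()
    }
}

/// Closure-based accessor: the group state is only reachable while the
/// per-group lock is held, so two commits cannot interleave on one group.
fn with_group<R>(
    locks: &CommitLock,
    group_id: &[u8],
    operation: impl FnOnce(&mut GroupState) -> R,
) -> R {
    let lock = locks.lock_for(group_id);
    let _guard = lock.lock().expect("group lock poisoned");
    // The real code loads `OpenMlsGroup` from the provider's storage here.
    let mut state = GroupState { epoch: 0 };
    operation(&mut state)
}

fn main() {
    let locks = CommitLock::default();
    let epoch = with_group(&locks, b"group-1", |group| {
        group.epoch += 1;
        group.epoch
    });
    println!("epoch after commit: {epoch}");
}
```

Returning the closure's result, rather than the loaded group itself, is what guarantees the lock outlives every access to the group state.
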
@@ -1939,17 +1928,17 @@ pub(crate) mod tests { // Check Amal's MLS group state. let amal_db = XmtpOpenMlsProvider::from(amal.context.store().conn().unwrap()); - let amal_members_len = amal_group.load_mls_group_with_lock(&amal_db, |amal_mls_group| { - Ok(amal_mls_group.members().count()) - }).unwrap(); + let amal_members_len = amal_group.load_mls_group_with_lock(&amal_db, |mls_group| + Ok(mls_group.members().count()) + ).unwrap(); assert_eq!(amal_members_len, 3); // Check Bola's MLS group state. let bola_db = XmtpOpenMlsProvider::from(bola.context.store().conn().unwrap()); - let bola_members_len = bola_group.load_mls_group_with_lock(&bola_db, |bola_mls_group| { - Ok(bola_mls_group.members().count()) - }).unwrap(); + let bola_members_len = bola_group.load_mls_group_with_lock(&bola_db, |mls_group| + Ok(mls_group.members().count()) + ).unwrap(); assert_eq!(bola_members_len, 3); @@ -2028,8 +2017,8 @@ pub(crate) mod tests { Ok(mls_group) // Return the updated group if necessary }).unwrap(); - force_add_member(&alix, &bo, &alix_group, &mut mls_group, &provider).await; // Now add bo to the group + force_add_member(&alix, &bo, &alix_group, &mut mls_group, &provider).await; // Bo should not be able to actually read this group bo.sync_welcomes(&bo.mls_provider().unwrap()).await.unwrap(); @@ -2153,9 +2142,9 @@ pub(crate) mod tests { assert_eq!(messages.len(), 2); let provider: XmtpOpenMlsProvider = client.context.store().conn().unwrap().into(); - let pending_commit_is_none = group.load_mls_group_with_lock(&provider, |mls_group| { + let pending_commit_is_none = group.load_mls_group_with_lock(&provider, |mls_group| Ok(mls_group.pending_commit().is_none()) - }).unwrap(); + ).unwrap(); assert!(pending_commit_is_none); @@ -2336,9 +2325,9 @@ pub(crate) mod tests { assert!(new_installations_were_added.is_ok()); group.sync().await.unwrap(); - let num_members = group.load_mls_group_with_lock(&provider, |mls_group| { + let num_members = group.load_mls_group_with_lock(&provider, |mls_group| Ok(mls_group.members().collect::>().len()) - }).unwrap(); + ).unwrap(); assert_eq!(num_members, 3); } @@ -3922,9 +3911,9 @@ pub(crate) mod tests { ) .unwrap(); assert!(valid_dm_group - .load_mls_group_with_lock(client.mls_provider().unwrap(), |mls_group| { + .load_mls_group_with_lock(client.mls_provider().unwrap(), |mls_group| validate_dm_group(&client, &mls_group, added_by_inbox) - }) + ) .is_ok()); // Test case 2: Invalid conversation type @@ -3940,9 +3929,9 @@ pub(crate) mod tests { ) .unwrap(); assert!(matches!( - invalid_type_group.load_mls_group_with_lock(client.mls_provider().unwrap(), |mls_group| { + invalid_type_group.load_mls_group_with_lock(client.mls_provider().unwrap(), |mls_group| validate_dm_group(&client, &mls_group, added_by_inbox) - }), + ), Err(GroupError::Generic(msg)) if msg.contains("Invalid conversation type") )); // Test case 3: Missing DmMembers @@ -3962,9 +3951,9 @@ pub(crate) mod tests { ) .unwrap(); assert!(matches!( - mismatched_dm_members_group.load_mls_group_with_lock(client.mls_provider().unwrap(), |mls_group| { + mismatched_dm_members_group.load_mls_group_with_lock(client.mls_provider().unwrap(), |mls_group| validate_dm_group(&client, &mls_group, added_by_inbox) - }), + ), Err(GroupError::Generic(msg)) if msg.contains("DM members do not match expected inboxes") )); @@ -3984,9 +3973,9 @@ pub(crate) mod tests { ) .unwrap(); assert!(matches!( - non_empty_admin_list_group.load_mls_group_with_lock(client.mls_provider().unwrap(), |mls_group| { + 
non_empty_admin_list_group.load_mls_group_with_lock(client.mls_provider().unwrap(), |mls_group| validate_dm_group(&client, &mls_group, added_by_inbox) - }), + ), Err(GroupError::Generic(msg)) if msg.contains("DM group must have empty admin and super admin lists") )); @@ -4005,9 +3994,9 @@ pub(crate) mod tests { ) .unwrap(); assert!(matches!( - invalid_permissions_group.load_mls_group_with_lock(client.mls_provider().unwrap(), |mls_group| { + invalid_permissions_group.load_mls_group_with_lock(client.mls_provider().unwrap(), |mls_group| validate_dm_group(&client, &mls_group, added_by_inbox) - }), + ), Err(GroupError::Generic(msg)) if msg.contains("Invalid permissions for DM group") )); } diff --git a/xmtp_mls/src/groups/subscriptions.rs b/xmtp_mls/src/groups/subscriptions.rs index 8034398bc..fae5ea0d6 100644 --- a/xmtp_mls/src/groups/subscriptions.rs +++ b/xmtp_mls/src/groups/subscriptions.rs @@ -43,14 +43,22 @@ impl MlsGroup { let process_result = retry_async!( Retry::default(), (async { + let client_id = &client_id; let msgv1 = &msgv1; self.context() .store() .transaction_async(provider, |provider| async move { // let prov_ref = &provider; // Borrow provider instead of moving it - self.load_mls_group_with_lock_async(provider, |mut mls_group| async move { + self.load_mls_group_with_lock_async(provider, |mls_group| async move { // Attempt processing immediately, but fail if the message is not an Application Message // Returning an error should roll back the DB tx + tracing::info!( + inbox_id = self.client.inbox_id(), + group_id = hex::encode(&self.group_id), + msg_id = msgv1.id, + "current epoch for [{}] in process_stream_entry()", + client_id, + ); tracing::info!( inbox_id = self.client.inbox_id(), group_id = hex::encode(&self.group_id), From 250eabb833ba7bf1b0629826f4bd0372d8a6470d Mon Sep 17 00:00:00 2001 From: Mojtaba Chenani Date: Wed, 11 Dec 2024 15:47:22 +0100 Subject: [PATCH 08/28] fixed tests --- xmtp_mls/src/client.rs | 14 ++++-- xmtp_mls/src/groups/mls_sync.rs | 7 --- xmtp_mls/src/groups/mod.rs | 85 +++++++++++++++++---------------- xmtp_mls/src/subscriptions.rs | 4 +- 4 files changed, 58 insertions(+), 52 deletions(-) diff --git a/xmtp_mls/src/client.rs b/xmtp_mls/src/client.rs index 98bf00b63..d62820672 100644 --- a/xmtp_mls/src/client.rs +++ b/xmtp_mls/src/client.rs @@ -873,9 +873,17 @@ where "[{}] syncing group", self.inbox_id() ); - group.maybe_update_installations(provider, None).await?; - group.sync_with_conn(provider).await?; - active_group_count.fetch_add(1, Ordering::SeqCst); + let is_active = group + .load_mls_group_with_lock_async(provider_ref, |mls_group| async move { + Ok::(mls_group.is_active()) + }) + .await?; + if is_active { + group.maybe_update_installations(provider, None).await?; + + group.sync_with_conn(provider).await?; + active_group_count.fetch_add(1, Ordering::SeqCst); + } Ok::<(), GroupError>(()) } diff --git a/xmtp_mls/src/groups/mls_sync.rs b/xmtp_mls/src/groups/mls_sync.rs index a20c40d1d..c2d351145 100644 --- a/xmtp_mls/src/groups/mls_sync.rs +++ b/xmtp_mls/src/groups/mls_sync.rs @@ -198,9 +198,6 @@ where #[tracing::instrument(skip_all)] pub async fn sync_with_conn(&self, provider: &XmtpOpenMlsProvider) -> Result<(), GroupError> { // Check if we're still part of the group - if !self.is_active(provider)? 
{ - return Ok(()); - } let _mutex = self.mutex.lock().await; let mut errors: Vec = vec![]; @@ -1204,10 +1201,6 @@ where provider: &XmtpOpenMlsProvider, update_interval_ns: Option, ) -> Result<(), GroupError> { - // Check if we're still part of the group - if !self.is_active(provider)? { - return Ok(()); - } // determine how long of an interval in time to use before updating list let interval_ns = update_interval_ns.unwrap_or_else(|| SYNC_UPDATE_INSTALLATIONS_INTERVAL_NS); diff --git a/xmtp_mls/src/groups/mod.rs b/xmtp_mls/src/groups/mod.rs index 19fe34442..1e5384f5c 100644 --- a/xmtp_mls/src/groups/mod.rs +++ b/xmtp_mls/src/groups/mod.rs @@ -1177,16 +1177,14 @@ impl MlsGroup { /// /// If the current user has been kicked out of the group, `is_active` will return `false` pub fn is_active(&self, provider: impl OpenMlsProvider) -> Result { - self.load_mls_group_with_lock(provider, |mls_group| - Ok(mls_group.is_active()) - ) + self.load_mls_group_with_lock(provider, |mls_group| Ok(mls_group.is_active())) } /// Get the `GroupMetadata` of the group. pub fn metadata(&self, provider: impl OpenMlsProvider) -> Result { - self.load_mls_group_with_lock(provider, |mls_group| + self.load_mls_group_with_lock(provider, |mls_group| { Ok(extract_group_metadata(&mls_group)?) - ) + }) } /// Get the `GroupMutableMetadata` of the group. @@ -1194,17 +1192,17 @@ impl MlsGroup { &self, provider: impl OpenMlsProvider, ) -> Result { - self.load_mls_group_with_lock(provider, |mls_group| + self.load_mls_group_with_lock(provider, |mls_group| { Ok(GroupMutableMetadata::try_from(&mls_group)?) - ) + }) } pub fn permissions(&self) -> Result { let provider = self.mls_provider()?; - self.load_mls_group_with_lock(&provider, |mls_group| + self.load_mls_group_with_lock(&provider, |mls_group| { Ok(extract_group_permissions(&mls_group)?) - ) + }) } /// Used for testing that dm group validation works as expected. @@ -1928,17 +1926,17 @@ pub(crate) mod tests { // Check Amal's MLS group state. let amal_db = XmtpOpenMlsProvider::from(amal.context.store().conn().unwrap()); - let amal_members_len = amal_group.load_mls_group_with_lock(&amal_db, |mls_group| - Ok(mls_group.members().count()) - ).unwrap(); + let amal_members_len = amal_group + .load_mls_group_with_lock(&amal_db, |mls_group| Ok(mls_group.members().count())) + .unwrap(); assert_eq!(amal_members_len, 3); // Check Bola's MLS group state. 
let bola_db = XmtpOpenMlsProvider::from(bola.context.store().conn().unwrap()); - let bola_members_len = bola_group.load_mls_group_with_lock(&bola_db, |mls_group| - Ok(mls_group.members().count()) - ).unwrap(); + let bola_members_len = bola_group + .load_mls_group_with_lock(&bola_db, |mls_group| Ok(mls_group.members().count())) + .unwrap(); assert_eq!(bola_members_len, 3); @@ -1999,23 +1997,26 @@ pub(crate) mod tests { .unwrap(); let provider = alix.mls_provider().unwrap(); // Doctor the group membership - let mut mls_group = alix_group.load_mls_group_with_lock(&provider, |mut mls_group| { - let mut existing_extensions = mls_group.extensions().clone(); - let mut group_membership = GroupMembership::new(); - group_membership.add("deadbeef".to_string(), 1); - existing_extensions.add_or_replace(build_group_membership_extension(&group_membership)); - - mls_group - .update_group_context_extensions( - &provider, - existing_extensions.clone(), - &alix.identity().installation_keys, - ) - .unwrap(); - mls_group.merge_pending_commit(&provider).unwrap(); - - Ok(mls_group) // Return the updated group if necessary - }).unwrap(); + let mut mls_group = alix_group + .load_mls_group_with_lock(&provider, |mut mls_group| { + let mut existing_extensions = mls_group.extensions().clone(); + let mut group_membership = GroupMembership::new(); + group_membership.add("deadbeef".to_string(), 1); + existing_extensions + .add_or_replace(build_group_membership_extension(&group_membership)); + + mls_group + .update_group_context_extensions( + &provider, + existing_extensions.clone(), + &alix.identity().installation_keys, + ) + .unwrap(); + mls_group.merge_pending_commit(&provider).unwrap(); + + Ok(mls_group) // Return the updated group if necessary + }) + .unwrap(); // Now add bo to the group force_add_member(&alix, &bo, &alix_group, &mut mls_group, &provider).await; @@ -2142,9 +2143,11 @@ pub(crate) mod tests { assert_eq!(messages.len(), 2); let provider: XmtpOpenMlsProvider = client.context.store().conn().unwrap().into(); - let pending_commit_is_none = group.load_mls_group_with_lock(&provider, |mls_group| - Ok(mls_group.pending_commit().is_none()) - ).unwrap(); + let pending_commit_is_none = group + .load_mls_group_with_lock(&provider, |mls_group| { + Ok(mls_group.pending_commit().is_none()) + }) + .unwrap(); assert!(pending_commit_is_none); @@ -2325,9 +2328,11 @@ pub(crate) mod tests { assert!(new_installations_were_added.is_ok()); group.sync().await.unwrap(); - let num_members = group.load_mls_group_with_lock(&provider, |mls_group| - Ok(mls_group.members().collect::>().len()) - ).unwrap(); + let num_members = group + .load_mls_group_with_lock(&provider, |mls_group| { + Ok(mls_group.members().collect::>().len()) + }) + .unwrap(); assert_eq!(num_members, 3); } @@ -3911,9 +3916,9 @@ pub(crate) mod tests { ) .unwrap(); assert!(valid_dm_group - .load_mls_group_with_lock(client.mls_provider().unwrap(), |mls_group| + .load_mls_group_with_lock(client.mls_provider().unwrap(), |mls_group| { validate_dm_group(&client, &mls_group, added_by_inbox) - ) + }) .is_ok()); // Test case 2: Invalid conversation type diff --git a/xmtp_mls/src/subscriptions.rs b/xmtp_mls/src/subscriptions.rs index 94b93c537..2b945145e 100644 --- a/xmtp_mls/src/subscriptions.rs +++ b/xmtp_mls/src/subscriptions.rs @@ -602,11 +602,11 @@ pub(crate) mod tests { .add_members_by_inbox_id(&[bob.inbox_id()]) .await .unwrap(); - let bob_group = bob + let bob_groups = bob .sync_welcomes(&bob.mls_provider().unwrap()) .await .unwrap(); - let bob_group = 
bob_group.first().unwrap(); + let bob_group = bob_groups.first().unwrap(); let notify = Delivery::new(None); let notify_ptr = notify.clone(); From 4adf5cfd384379edd2ad6b9336b28e30f6ce01fb Mon Sep 17 00:00:00 2001 From: Mojtaba Chenani Date: Wed, 11 Dec 2024 18:11:37 +0100 Subject: [PATCH 09/28] fix after rebase --- xmtp_mls/src/client.rs | 2 +- xmtp_mls/src/groups/mls_sync.rs | 2 +- xmtp_mls/src/groups/mod.rs | 10 +++------- xmtp_mls/src/groups/subscriptions.rs | 22 +++++----------------- 4 files changed, 10 insertions(+), 26 deletions(-) diff --git a/xmtp_mls/src/client.rs b/xmtp_mls/src/client.rs index d62820672..9f695de3c 100644 --- a/xmtp_mls/src/client.rs +++ b/xmtp_mls/src/client.rs @@ -874,7 +874,7 @@ where self.inbox_id() ); let is_active = group - .load_mls_group_with_lock_async(provider_ref, |mls_group| async move { + .load_mls_group_with_lock_async(provider, |mls_group| async move { Ok::(mls_group.is_active()) }) .await?; diff --git a/xmtp_mls/src/groups/mls_sync.rs b/xmtp_mls/src/groups/mls_sync.rs index c2d351145..d34d79a85 100644 --- a/xmtp_mls/src/groups/mls_sync.rs +++ b/xmtp_mls/src/groups/mls_sync.rs @@ -1269,7 +1269,7 @@ where inbox_ids_to_add: &[InboxIdRef<'_>], inbox_ids_to_remove: &[InboxIdRef<'_>], ) -> Result { - self.load_mls_group_with_lock_async(provider, |mut mls_group| async move { + self.load_mls_group_with_lock_async(provider, | mls_group| async move { let existing_group_membership = extract_group_membership(mls_group.extensions())?; // TODO:nm prevent querying for updates on members who are being removed let mut inbox_ids = existing_group_membership.inbox_ids(); diff --git a/xmtp_mls/src/groups/mod.rs b/xmtp_mls/src/groups/mod.rs index 1e5384f5c..69b7a7083 100644 --- a/xmtp_mls/src/groups/mod.rs +++ b/xmtp_mls/src/groups/mod.rs @@ -58,7 +58,6 @@ use self::{ group_permissions::PolicySet, validated_commit::CommitValidationError, }; -use futures::TryFutureExt; use std::future::Future; use std::{collections::HashSet, sync::Arc}; use xmtp_cryptography::signature::{sanitize_evm_addresses, AddressValidationError}; @@ -73,9 +72,6 @@ use xmtp_proto::xmtp::mls::{ PlaintextEnvelope, }, }; - -use crate::hpke::HpkeError::StorageError; -use crate::storage::sql_key_store::SqlKeyStoreError; use crate::{ api::WrappedApiError, client::{deserialize_welcome, ClientError, XmtpMlsLocalContext}, @@ -84,7 +80,6 @@ use crate::{ MAX_PAST_EPOCHS, MUTABLE_METADATA_EXTENSION_ID, SEND_MESSAGE_UPDATE_INSTALLATIONS_INTERVAL_NS, }, - groups, hpke::{decrypt_welcome, HpkeError}, identity::{parse_credential, IdentityError}, identity_updates::{load_identity_updates, InstallationDiffError}, @@ -96,13 +91,14 @@ use crate::{ group::{ConversationType, GroupMembershipState, StoredGroup}, group_intent::IntentKind, group_message::{DeliveryStatus, GroupMessageKind, MsgQueryArgs, StoredGroupMessage}, - sql_key_store, StorageError, + sql_key_store, }, subscriptions::{LocalEventError, LocalEvents}, utils::{id::calculate_message_id, time::now_ns}, xmtp_openmls_provider::XmtpOpenMlsProvider, - GroupCommitLock, Store, MLS_COMMIT_LOCK, + Store, MLS_COMMIT_LOCK, }; +use crate::storage::StorageError; #[derive(Debug, Error)] pub enum GroupError { diff --git a/xmtp_mls/src/groups/subscriptions.rs b/xmtp_mls/src/groups/subscriptions.rs index fae5ea0d6..b12987358 100644 --- a/xmtp_mls/src/groups/subscriptions.rs +++ b/xmtp_mls/src/groups/subscriptions.rs @@ -48,29 +48,17 @@ impl MlsGroup { self.context() .store() .transaction_async(provider, |provider| async move { - // let prov_ref = &provider; // Borrow provider 
instead of moving it - self.load_mls_group_with_lock_async(provider, |mls_group| async move { - // Attempt processing immediately, but fail if the message is not an Application Message - // Returning an error should roll back the DB tx - tracing::info!( + tracing::info!( inbox_id = self.client.inbox_id(), group_id = hex::encode(&self.group_id), msg_id = msgv1.id, "current epoch for [{}] in process_stream_entry()", client_id, ); - tracing::info!( - inbox_id = self.client.inbox_id(), - group_id = hex::encode(&self.group_id), - msg_id = msgv1.id, - "current epoch for [{}] in process_stream_entry()", - client_id, - ); - self.process_message(&mut mls_group, provider, msgv1, false) - .await - // NOTE: We want to make sure we retry an error in process_message - .map_err(SubscribeError::ReceiveGroup) - }).await + self.process_message(provider, msgv1, false) + .await + // NOTE: We want to make sure we retry an error in process_message + .map_err(SubscribeError::ReceiveGroup) }) .await }) From 1472473eea46de71290efe7f7a199532c04e3e26 Mon Sep 17 00:00:00 2001 From: Mojtaba Chenani Date: Wed, 11 Dec 2024 18:15:10 +0100 Subject: [PATCH 10/28] remove unneeded changes --- xmtp_mls/src/groups/intents.rs | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/xmtp_mls/src/groups/intents.rs b/xmtp_mls/src/groups/intents.rs index d0c599862..16a32dbda 100644 --- a/xmtp_mls/src/groups/intents.rs +++ b/xmtp_mls/src/groups/intents.rs @@ -865,14 +865,11 @@ pub(crate) mod tests { }; let provider = group.client.mls_provider().unwrap(); - let decrypted_message = match group + let decrypted_message = group .load_mls_group_with_lock(&provider, |mut mls_group| { Ok(mls_group .process_message(&provider, mls_message).unwrap()) - }) { - Ok(message) => message, - Err(err) => panic!("Error: {:?}", err), - }; + }).unwrap(); let staged_commit = match decrypted_message.into_content() { ProcessedMessageContent::StagedCommitMessage(staged_commit) => *staged_commit, From fbc057c2464ddf29e8f2ba33e8f5d57f1ee0a344 Mon Sep 17 00:00:00 2001 From: Mojtaba Chenani Date: Wed, 11 Dec 2024 19:19:18 +0100 Subject: [PATCH 11/28] fix clippy issues --- xmtp_mls/src/groups/intents.rs | 6 +-- xmtp_mls/src/groups/mls_sync.rs | 60 +++++++++++++++--------------- xmtp_mls/src/groups/mod.rs | 31 ++++++++-------- xmtp_mls/src/lib.rs | 66 +++++++++++++++++++++++---------- 4 files changed, 95 insertions(+), 68 deletions(-) diff --git a/xmtp_mls/src/groups/intents.rs b/xmtp_mls/src/groups/intents.rs index 16a32dbda..ab3ca869f 100644 --- a/xmtp_mls/src/groups/intents.rs +++ b/xmtp_mls/src/groups/intents.rs @@ -867,9 +867,9 @@ pub(crate) mod tests { let provider = group.client.mls_provider().unwrap(); let decrypted_message = group .load_mls_group_with_lock(&provider, |mut mls_group| { - Ok(mls_group - .process_message(&provider, mls_message).unwrap()) - }).unwrap(); + Ok(mls_group.process_message(&provider, mls_message).unwrap()) + }) + .unwrap(); let staged_commit = match decrypted_message.into_content() { ProcessedMessageContent::StagedCommitMessage(staged_commit) => *staged_commit, diff --git a/xmtp_mls/src/groups/mls_sync.rs b/xmtp_mls/src/groups/mls_sync.rs index d34d79a85..6464702bb 100644 --- a/xmtp_mls/src/groups/mls_sync.rs +++ b/xmtp_mls/src/groups/mls_sync.rs @@ -428,7 +428,7 @@ where &pending_commit, &mls_group, ) - .await; + .await; if let Err(err) = maybe_validated_commit { tracing::error!( @@ -453,7 +453,11 @@ where return Ok(IntentState::ToPublish); } else { // If no error committing the change, write a transcript message - 
self.save_transcript_message(conn, validated_commit, envelope_timestamp_ns)?; + self.save_transcript_message( + conn, + validated_commit, + envelope_timestamp_ns, + )?; } } IntentKind::SendMessage => { @@ -473,7 +477,8 @@ where }; Ok(IntentState::Committed) - }).await + }) + .await } #[tracing::instrument(level = "trace", skip_all)] @@ -734,26 +739,25 @@ where envelope.id, intent_id ); - match self - .process_own_message(intent, provider, message.into(), envelope) - .await? - { - IntentState::ToPublish => { - Ok(provider.conn_ref().set_group_intent_to_publish(intent_id)?) - } - IntentState::Committed => { - Ok(provider.conn_ref().set_group_intent_committed(intent_id)?) - } - IntentState::Published => { - tracing::error!("Unexpected behaviour: returned intent state published from process_own_message"); - Ok(()) - } - IntentState::Error => { - tracing::warn!("Intent [{}] moved to error status", intent_id); - Ok(provider.conn_ref().set_group_intent_error(intent_id)?) - } + match self + .process_own_message(intent, provider, message.into(), envelope) + .await? + { + IntentState::ToPublish => { + Ok(provider.conn_ref().set_group_intent_to_publish(intent_id)?) } - + IntentState::Committed => { + Ok(provider.conn_ref().set_group_intent_committed(intent_id)?) + } + IntentState::Published => { + tracing::error!("Unexpected behaviour: returned intent state published from process_own_message"); + Ok(()) + } + IntentState::Error => { + tracing::warn!("Intent [{}] moved to error status", intent_id); + Ok(provider.conn_ref().set_group_intent_error(intent_id)?) + } + } } // No matching intent found Ok(None) => { @@ -850,10 +854,7 @@ where for message in messages.into_iter() { let result = retry_async!( Retry::default(), - (async { - self.consume_message(provider, &message) - .await - }) + (async { self.consume_message(provider, &message).await }) ); if let Err(e) = result { let is_retryable = e.is_retryable(); @@ -1202,7 +1203,7 @@ where update_interval_ns: Option, ) -> Result<(), GroupError> { // determine how long of an interval in time to use before updating list - let interval_ns = update_interval_ns.unwrap_or_else(|| SYNC_UPDATE_INSTALLATIONS_INTERVAL_NS); + let interval_ns = update_interval_ns.unwrap_or(SYNC_UPDATE_INSTALLATIONS_INTERVAL_NS); let now_ns = crate::utils::time::now_ns(); let last_ns = provider @@ -1269,7 +1270,7 @@ where inbox_ids_to_add: &[InboxIdRef<'_>], inbox_ids_to_remove: &[InboxIdRef<'_>], ) -> Result { - self.load_mls_group_with_lock_async(provider, | mls_group| async move { + self.load_mls_group_with_lock_async(provider, |mls_group| async move { let existing_group_membership = extract_group_membership(mls_group.extensions())?; // TODO:nm prevent querying for updates on members who are being removed let mut inbox_ids = existing_group_membership.inbox_ids(); @@ -1320,7 +1321,8 @@ where .map(|s| s.to_string()) .collect::>(), )) - }).await + }) + .await } /** diff --git a/xmtp_mls/src/groups/mod.rs b/xmtp_mls/src/groups/mod.rs index 69b7a7083..427689cd8 100644 --- a/xmtp_mls/src/groups/mod.rs +++ b/xmtp_mls/src/groups/mod.rs @@ -58,20 +58,7 @@ use self::{ group_permissions::PolicySet, validated_commit::CommitValidationError, }; -use std::future::Future; -use std::{collections::HashSet, sync::Arc}; -use xmtp_cryptography::signature::{sanitize_evm_addresses, AddressValidationError}; -use xmtp_id::{InboxId, InboxIdRef}; -use xmtp_proto::xmtp::mls::{ - api::v1::{ - group_message::{Version as GroupMessageVersion, V1 as GroupMessageV1}, - GroupMessage, - }, - message_contents::{ - 
plaintext_envelope::{Content, V1}, - PlaintextEnvelope, - }, -}; +use crate::storage::StorageError; use crate::{ api::WrappedApiError, client::{deserialize_welcome, ClientError, XmtpMlsLocalContext}, @@ -98,7 +85,20 @@ use crate::{ xmtp_openmls_provider::XmtpOpenMlsProvider, Store, MLS_COMMIT_LOCK, }; -use crate::storage::StorageError; +use std::future::Future; +use std::{collections::HashSet, sync::Arc}; +use xmtp_cryptography::signature::{sanitize_evm_addresses, AddressValidationError}; +use xmtp_id::{InboxId, InboxIdRef}; +use xmtp_proto::xmtp::mls::{ + api::v1::{ + group_message::{Version as GroupMessageVersion, V1 as GroupMessageV1}, + GroupMessage, + }, + message_contents::{ + plaintext_envelope::{Content, V1}, + PlaintextEnvelope, + }, +}; #[derive(Debug, Error)] pub enum GroupError { @@ -1650,7 +1650,6 @@ pub(crate) mod tests { use diesel::connection::SimpleConnection; use futures::future::join_all; - use openmls::prelude::Member; use prost::Message; use std::sync::Arc; use xmtp_cryptography::utils::generate_local_wallet; diff --git a/xmtp_mls/src/lib.rs b/xmtp_mls/src/lib.rs index 6b82a8f6a..e7df53f6f 100644 --- a/xmtp_mls/src/lib.rs +++ b/xmtp_mls/src/lib.rs @@ -1,6 +1,8 @@ #![recursion_limit = "256"] #![warn(clippy::unwrap_used)] +#[macro_use] +extern crate tracing; pub mod api; pub mod builder; pub mod client; @@ -21,19 +23,17 @@ pub mod utils; pub mod verified_key_package_v2; mod xmtp_openmls_provider; +pub use client::{Client, Network}; use std::collections::HashMap; +use std::sync::LazyLock; use std::sync::{Arc, Mutex}; -use tokio::sync::{Semaphore, OwnedSemaphorePermit}; -pub use client::{Client, Network}; use storage::{DuplicateItem, StorageError}; +use tokio::sync::{OwnedSemaphorePermit, Semaphore}; pub use xmtp_openmls_provider::XmtpOpenMlsProvider; -use std::sync::LazyLock; pub use xmtp_id::InboxOwner; pub use xmtp_proto::api_client::trait_impls::*; -#[macro_use] -extern crate tracing; /// A manager for group-specific semaphores #[derive(Debug)] pub struct GroupCommitLock { @@ -41,6 +41,11 @@ pub struct GroupCommitLock { locks: Mutex, Arc>>, } +impl Default for GroupCommitLock { + fn default() -> Self { + Self::new() + } +} impl GroupCommitLock { /// Create a new `GroupCommitLock` pub fn new() -> Self { @@ -50,32 +55,53 @@ impl GroupCommitLock { } /// Get or create a semaphore for a specific group and acquire it, returning a guard - pub async fn get_lock_async(&self, group_id: Vec) -> SemaphoreGuard { + pub async fn get_lock_async(&self, group_id: Vec) -> Result { let semaphore = { - let mut locks = self.locks.lock().unwrap(); - locks - .entry(group_id) - .or_insert_with(|| Arc::new(Semaphore::new(1))) - .clone() + match self.locks.lock() { + Ok(mut locks) => locks + .entry(group_id) + .or_insert_with(|| Arc::new(Semaphore::new(1))) + .clone(), + Err(err) => { + eprintln!("Failed to lock the mutex: {}", err); + return Err(GroupError::LockUnavailable); + } + } }; let semaphore_clone = semaphore.clone(); - let permit = semaphore.acquire_owned().await.unwrap(); - SemaphoreGuard { _permit: permit, _semaphore: semaphore_clone } + let permit = match semaphore.acquire_owned().await { + Ok(permit) => permit, + Err(err) => { + eprintln!("Failed to acquire semaphore permit: {}", err); + return Err(GroupError::LockUnavailable); + } + }; Ok(SemaphoreGuard { + _permit: permit, + _semaphore: semaphore_clone, + }) } /// Get or create a semaphore for a specific group and acquire it synchronously pub fn get_lock_sync(&self, group_id: Vec) -> Result { let semaphore = { - let mut locks = 
self.locks.lock().unwrap(); - locks - .entry(group_id) - .or_insert_with(|| Arc::new(Semaphore::new(1))) - .clone() // Clone here to retain ownership for later use + match self.locks.lock() { + Ok(mut locks) => locks + .entry(group_id) + .or_insert_with(|| Arc::new(Semaphore::new(1))) + .clone(), + Err(err) => { + eprintln!("Failed to lock the mutex: {}", err); + return Err(GroupError::LockUnavailable); + } + } }; // Synchronously acquire the permit - let permit = semaphore.clone().try_acquire_owned().map_err(|_| GroupError::LockUnavailable)?; + let permit = semaphore + .clone() + .try_acquire_owned() + .map_err(|_| GroupError::LockUnavailable)?; Ok(SemaphoreGuard { _permit: permit, _semaphore: semaphore, // semaphore is now valid because we cloned it earlier @@ -136,10 +162,10 @@ pub trait Delete { fn delete(&self, key: Self::Key) -> Result; } +use crate::groups::GroupError; pub use stream_handles::{ spawn, AbortHandle, GenericStreamHandle, StreamHandle, StreamHandleError, }; -use crate::groups::GroupError; #[cfg(target_arch = "wasm32")] #[doc(hidden)] From dbd57e2b3c91c0764516493f9aaba3be30a53dc9 Mon Sep 17 00:00:00 2001 From: Mojtaba Chenani Date: Wed, 11 Dec 2024 19:23:19 +0100 Subject: [PATCH 12/28] fix fmt --- xmtp_mls/src/lib.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/xmtp_mls/src/lib.rs b/xmtp_mls/src/lib.rs index e7df53f6f..bc6e29009 100644 --- a/xmtp_mls/src/lib.rs +++ b/xmtp_mls/src/lib.rs @@ -76,7 +76,8 @@ impl GroupCommitLock { eprintln!("Failed to acquire semaphore permit: {}", err); return Err(GroupError::LockUnavailable); } - }; Ok(SemaphoreGuard { + }; + Ok(SemaphoreGuard { _permit: permit, _semaphore: semaphore_clone, }) From 261f30023ddbc2ca1280b85b47d75947447fac69 Mon Sep 17 00:00:00 2001 From: Andrew Plaza Date: Thu, 12 Dec 2024 11:41:25 -0500 Subject: [PATCH 13/28] fix webassembly compile --- xmtp_mls/src/groups/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xmtp_mls/src/groups/mod.rs b/xmtp_mls/src/groups/mod.rs index b4fed0ded..0ded70adb 100644 --- a/xmtp_mls/src/groups/mod.rs +++ b/xmtp_mls/src/groups/mod.rs @@ -366,7 +366,7 @@ impl MlsGroup { operation: F, ) -> Result where - F: FnOnce(OpenMlsGroup) -> Fut + Send, + F: FnOnce(OpenMlsGroup) -> Fut, Fut: Future>, E: From + From, { From 6511f091a03e2c90a4ea86e7a7efe98684080523 Mon Sep 17 00:00:00 2001 From: Mojtaba Chenani Date: Thu, 12 Dec 2024 22:22:56 +0100 Subject: [PATCH 14/28] fix tests --- bindings_ffi/src/mls.rs | 14 +++- bindings_node/src/conversation.rs | 28 +++---- bindings_wasm/src/conversation.rs | 25 ++++--- common/src/test.rs | 4 +- examples/cli/serializable.rs | 5 +- xmtp_mls/src/groups/mod.rs | 118 ++++++++++++++++-------------- xmtp_mls/src/lib.rs | 50 ++++--------- xmtp_mls/src/subscriptions.rs | 2 +- 8 files changed, 124 insertions(+), 122 deletions(-) diff --git a/bindings_ffi/src/mls.rs b/bindings_ffi/src/mls.rs index a9ae8232b..492e42b78 100644 --- a/bindings_ffi/src/mls.rs +++ b/bindings_ffi/src/mls.rs @@ -1396,7 +1396,7 @@ impl FfiConversation { pub fn group_image_url_square(&self) -> Result { let provider = self.inner.mls_provider()?; - Ok(self.inner.group_image_url_square(provider)?) + Ok(self.inner.group_image_url_square(&provider)?) } pub async fn update_group_description( @@ -1412,7 +1412,7 @@ impl FfiConversation { pub fn group_description(&self) -> Result { let provider = self.inner.mls_provider()?; - Ok(self.inner.group_description(provider)?) + Ok(self.inner.group_description(&provider)?) 
} pub async fn update_group_pinned_frame_url( @@ -1546,7 +1546,10 @@ impl FfiConversation { pub fn group_metadata(&self) -> Result, GenericError> { let provider = self.inner.mls_provider()?; - let metadata = self.inner.metadata(provider)?; + // blocking is OK b/c not wasm + let metadata = tokio::task::block_in_place(|| { + futures::executor::block_on(self.inner.metadata(&provider)) + })?; Ok(Arc::new(FfiConversationMetadata { inner: Arc::new(metadata), })) @@ -1558,7 +1561,10 @@ impl FfiConversation { pub fn conversation_type(&self) -> Result { let provider = self.inner.mls_provider()?; - let conversation_type = self.inner.conversation_type(&provider)?; + // blocking OK b/c not wasm + let conversation_type = tokio::task::block_in_place(|| { + futures::executor::block_on(self.inner.conversation_type(&provider)) + })?; Ok(conversation_type.into()) } } diff --git a/bindings_node/src/conversation.rs b/bindings_node/src/conversation.rs index 897662463..897276889 100644 --- a/bindings_node/src/conversation.rs +++ b/bindings_node/src/conversation.rs @@ -167,9 +167,10 @@ impl Conversation { self.created_at_ns, ); let provider = group.mls_provider().map_err(ErrorWrapper::from)?; - let conversation_type = group - .conversation_type(&provider) - .map_err(ErrorWrapper::from)?; + let conversation_type = tokio::task::block_in_place(|| { + futures::executor::block_on(group.conversation_type(&provider)) + }) + .map_err(ErrorWrapper::from)?; let kind = match conversation_type { ConversationType::Group => None, ConversationType::Dm => Some(XmtpGroupMessageKind::Application), @@ -248,7 +249,7 @@ impl Conversation { ); let admin_list = group - .admin_list(group.mls_provider().map_err(ErrorWrapper::from)?) + .admin_list(&group.mls_provider().map_err(ErrorWrapper::from)?) .map_err(ErrorWrapper::from)?; Ok(admin_list) @@ -263,7 +264,7 @@ impl Conversation { ); let super_admin_list = group - .super_admin_list(group.mls_provider().map_err(ErrorWrapper::from)?) + .super_admin_list(&group.mls_provider().map_err(ErrorWrapper::from)?) .map_err(ErrorWrapper::from)?; Ok(super_admin_list) @@ -449,7 +450,7 @@ impl Conversation { ); let group_name = group - .group_name(group.mls_provider().map_err(ErrorWrapper::from)?) + .group_name(&group.mls_provider().map_err(ErrorWrapper::from)?) .map_err(ErrorWrapper::from)?; Ok(group_name) @@ -480,7 +481,7 @@ impl Conversation { ); let group_image_url_square = group - .group_image_url_square(group.mls_provider().map_err(ErrorWrapper::from)?) + .group_image_url_square(&group.mls_provider().map_err(ErrorWrapper::from)?) .map_err(ErrorWrapper::from)?; Ok(group_image_url_square) @@ -511,7 +512,7 @@ impl Conversation { ); let group_description = group - .group_description(group.mls_provider().map_err(ErrorWrapper::from)?) + .group_description(&group.mls_provider().map_err(ErrorWrapper::from)?) .map_err(ErrorWrapper::from)?; Ok(group_description) @@ -542,7 +543,7 @@ impl Conversation { ); let group_pinned_frame_url = group - .group_pinned_frame_url(group.mls_provider().map_err(ErrorWrapper::from)?) + .group_pinned_frame_url(&group.mls_provider().map_err(ErrorWrapper::from)?) .map_err(ErrorWrapper::from)?; Ok(group_pinned_frame_url) @@ -585,7 +586,7 @@ impl Conversation { Ok( group - .is_active(group.mls_provider().map_err(ErrorWrapper::from)?) + .is_active(&group.mls_provider().map_err(ErrorWrapper::from)?) .map_err(ErrorWrapper::from)?, ) } @@ -609,9 +610,10 @@ impl Conversation { self.created_at_ns, ); - let metadata = group - .metadata(group.mls_provider().map_err(ErrorWrapper::from)?) 
-      .map_err(ErrorWrapper::from)?;
+    let metadata = tokio::task::block_in_place(|| {
+      futures::executor::block_on(group.metadata(&group.mls_provider()?))
+    })
+    .map_err(ErrorWrapper::from)?;
 
     Ok(GroupMetadata { inner: metadata })
   }
diff --git a/bindings_wasm/src/conversation.rs b/bindings_wasm/src/conversation.rs
index 9c7a1e090..147c63b94 100644
--- a/bindings_wasm/src/conversation.rs
+++ b/bindings_wasm/src/conversation.rs
@@ -180,7 +180,10 @@ impl Conversation {
   }
 
   #[wasm_bindgen(js_name = findMessages)]
-  pub fn find_messages(&self, opts: Option<ListMessagesOptions>) -> Result<Vec<Message>, JsError> {
+  pub async fn find_messages(
+    &self,
+    opts: Option<ListMessagesOptions>,
+  ) -> Result<Vec<Message>, JsError> {
     let opts = opts.unwrap_or_default();
     let group = self.to_mls_group();
     let provider = group
@@ -188,6 +191,7 @@ impl Conversation {
      .map_err(|e| JsError::new(&format!("{e}")))?;
     let conversation_type = group
      .conversation_type(&provider)
+      .await
      .map_err(|e| JsError::new(&format!("{e}")))?;
     let kind = match conversation_type {
       ConversationType::Group => None,
@@ -238,7 +242,7 @@ impl Conversation {
     let group = self.to_mls_group();
     let admin_list = group
      .admin_list(
-        group
+        &group
          .mls_provider()
          .map_err(|e| JsError::new(&format!("{e}")))?,
      )
@@ -252,7 +256,7 @@ impl Conversation {
     let group = self.to_mls_group();
     let super_admin_list = group
      .super_admin_list(
-        group
+        &group
          .mls_provider()
          .map_err(|e| JsError::new(&format!("{e}")))?,
      )
@@ -398,7 +402,7 @@ impl Conversation {
 
     let group_name = group
      .group_name(
-        group
+        &group
          .mls_provider()
          .map_err(|e| JsError::new(&format!("{e}")))?,
      )
@@ -428,7 +432,7 @@ impl Conversation {
 
     let group_image_url_square = group
      .group_image_url_square(
-        group
+        &group
          .mls_provider()
          .map_err(|e| JsError::new(&format!("{e}")))?,
      )
@@ -455,7 +459,7 @@ impl Conversation {
 
     let group_description = group
      .group_description(
-        group
+        &group
          .mls_provider()
          .map_err(|e| JsError::new(&format!("{e}")))?,
      )
@@ -485,7 +489,7 @@ impl Conversation {
 
     let group_pinned_frame_url = group
      .group_pinned_frame_url(
-        group
+        &group
          .mls_provider()
          .map_err(|e| JsError::new(&format!("{e}")))?,
      )
@@ -505,7 +509,7 @@ impl Conversation {
 
     group
      .is_active(
-        group
+        &group
          .mls_provider()
          .map_err(|e| JsError::new(&format!("{e}")))?,
      )
@@ -522,14 +526,15 @@ impl Conversation {
   }
 
   #[wasm_bindgen(js_name = groupMetadata)]
-  pub fn group_metadata(&self) -> Result<GroupMetadata, JsError> {
+  pub async fn group_metadata(&self) -> Result<GroupMetadata, JsError> {
     let group = self.to_mls_group();
     let metadata = group
      .metadata(
-        group
+        &group
          .mls_provider()
          .map_err(|e| JsError::new(&format!("{e}")))?,
      )
+      .await
      .map_err(|e| JsError::new(&format!("{e}")))?;
 
     Ok(GroupMetadata { inner: metadata })
diff --git a/common/src/test.rs b/common/src/test.rs
index c11c692cb..4cfb2442d 100644
--- a/common/src/test.rs
+++ b/common/src/test.rs
@@ -38,7 +38,7 @@ pub fn logger() {
                 .from_env_lossy()
         };
 
-        tracing_subscriber::registry()
+        let _ = tracing_subscriber::registry()
             // structured JSON logger only if STRUCTURED=true
             .with(is_structured.then(|| {
                 tracing_subscriber::fmt::layer()
@@ -61,7 +61,7 @@ pub fn logger() {
                     })
                     .with_filter(filter())
             }))
-            .init();
+            .try_init();
     });
 }
diff --git a/examples/cli/serializable.rs b/examples/cli/serializable.rs
index c6ee793ce..545081638 100644
--- a/examples/cli/serializable.rs
+++ b/examples/cli/serializable.rs
@@ -31,11 +31,12 @@ impl SerializableGroup {
 
         let metadata = group
             .metadata(
-                group
+                &group
                     .mls_provider()
                     .expect("MLS Provider could not be created"),
             )
-            .expect("could not load metadata");
+            .await
+            .unwrap();
         let permissions = group.permissions().expect("could not load permissions");
 
         Self {
diff --git a/xmtp_mls/src/groups/mod.rs b/xmtp_mls/src/groups/mod.rs
index 6704246a8..6c2377314 100644
--- a/xmtp_mls/src/groups/mod.rs
+++ b/xmtp_mls/src/groups/mod.rs
@@ -206,8 +206,10 @@ pub enum GroupError {
     IntentNotCommitted,
     #[error(transparent)]
     ProcessIntent(#[from] ProcessIntentError),
-    #[error("Failed to acquire lock for group operation")]
-    LockUnavailable,
+    #[error(transparent)]
+    LockUnavailable(#[from] tokio::sync::AcquireError),
+    #[error(transparent)]
+    LockFailedToAcquire(#[from] tokio::sync::TryAcquireError),
 }
 
 impl RetryableError for GroupError {
@@ -234,7 +236,8 @@ impl RetryableError for GroupError {
             Self::MessageHistory(err) => err.is_retryable(),
             Self::ProcessIntent(err) => err.is_retryable(),
             Self::LocalEvent(err) => err.is_retryable(),
-            Self::LockUnavailable => true,
+            Self::LockUnavailable(_) => true,
+            Self::LockFailedToAcquire(_) => true,
             Self::SyncFailedToWait => true,
             Self::GroupNotFound
             | Self::GroupMetadata(_)
@@ -893,7 +896,7 @@ impl MlsGroup {
     /// to perform these updates.
     pub async fn update_group_name(&self, group_name: String) -> Result<(), GroupError> {
         let provider = self.client.mls_provider()?;
-        if self.metadata(&provider)?.conversation_type == ConversationType::Dm {
+        if self.metadata(&provider).await?.conversation_type == ConversationType::Dm {
             return Err(GroupError::DmGroupMetadataForbidden);
         }
         let intent_data: Vec<u8> =
@@ -911,7 +914,7 @@ impl MlsGroup {
         metadata_field: Option<MetadataField>,
     ) -> Result<(), GroupError> {
         let provider = self.client.mls_provider()?;
-        if self.metadata(&provider)?.conversation_type == ConversationType::Dm {
+        if self.metadata(&provider).await?.conversation_type == ConversationType::Dm {
             return Err(GroupError::DmGroupMetadataForbidden);
         }
         if permission_update_type == PermissionUpdateType::UpdateMetadata
@@ -933,7 +936,7 @@ impl MlsGroup {
     }
 
     /// Retrieves the group name from the group's mutable metadata extension.
-    pub fn group_name(&self, provider: impl OpenMlsProvider) -> Result<String, GroupError> {
+    pub fn group_name(&self, provider: &XmtpOpenMlsProvider) -> Result<String, GroupError> {
         let mutable_metadata = self.mutable_metadata(provider)?;
         match mutable_metadata
             .attributes
@@ -952,7 +955,7 @@ impl MlsGroup {
         group_description: String,
     ) -> Result<(), GroupError> {
         let provider = self.client.mls_provider()?;
-        if self.metadata(&provider)?.conversation_type == ConversationType::Dm {
+        if self.metadata(&provider).await?.conversation_type == ConversationType::Dm {
             return Err(GroupError::DmGroupMetadataForbidden);
         }
         let intent_data: Vec<u8> =
@@ -962,7 +965,7 @@ impl MlsGroup {
         self.sync_until_intent_resolved(&provider, intent.id).await
     }
 
-    pub fn group_description(&self, provider: impl OpenMlsProvider) -> Result<String, GroupError> {
+    pub fn group_description(&self, provider: &XmtpOpenMlsProvider) -> Result<String, GroupError> {
         let mutable_metadata = self.mutable_metadata(provider)?;
         match mutable_metadata
             .attributes
@@ -981,7 +984,7 @@ impl MlsGroup {
         group_image_url_square: String,
     ) -> Result<(), GroupError> {
         let provider = self.client.mls_provider()?;
-        if self.metadata(&provider)?.conversation_type == ConversationType::Dm {
+        if self.metadata(&provider).await?.conversation_type == ConversationType::Dm {
             return Err(GroupError::DmGroupMetadataForbidden);
         }
         let intent_data: Vec<u8> =
@@ -995,7 +998,7 @@ impl MlsGroup {
 
     /// Retrieves the image URL (square) of the group from the group's mutable metadata extension.
     pub fn group_image_url_square(
         &self,
-        provider: impl OpenMlsProvider,
+        provider: &XmtpOpenMlsProvider,
     ) -> Result<String, GroupError> {
         let mutable_metadata = self.mutable_metadata(provider)?;
         match mutable_metadata
@@ -1014,7 +1017,7 @@ impl MlsGroup {
         pinned_frame_url: String,
     ) -> Result<(), GroupError> {
         let provider = self.client.mls_provider()?;
-        if self.metadata(&provider)?.conversation_type == ConversationType::Dm {
+        if self.metadata(&provider).await?.conversation_type == ConversationType::Dm {
             return Err(GroupError::DmGroupMetadataForbidden);
         }
         let intent_data: Vec<u8> =
@@ -1026,7 +1029,7 @@ impl MlsGroup {
 
     pub fn group_pinned_frame_url(
         &self,
-        provider: impl OpenMlsProvider,
+        provider: &XmtpOpenMlsProvider,
     ) -> Result<String, GroupError> {
         let mutable_metadata = self.mutable_metadata(provider)?;
         match mutable_metadata
@@ -1041,7 +1044,7 @@ impl MlsGroup {
     }
 
     /// Retrieves the admin list of the group from the group's mutable metadata extension.
-    pub fn admin_list(&self, provider: impl OpenMlsProvider) -> Result<Vec<String>, GroupError> {
+    pub fn admin_list(&self, provider: &XmtpOpenMlsProvider) -> Result<Vec<String>, GroupError> {
         let mutable_metadata = self.mutable_metadata(provider)?;
         Ok(mutable_metadata.admin_list)
     }
@@ -1049,7 +1052,7 @@ impl MlsGroup {
     /// Retrieves the super admin list of the group from the group's mutable metadata extension.
     pub fn super_admin_list(
         &self,
-        provider: impl OpenMlsProvider,
+        provider: &XmtpOpenMlsProvider,
     ) -> Result<Vec<String>, GroupError> {
         let mutable_metadata = self.mutable_metadata(provider)?;
         Ok(mutable_metadata.super_admin_list)
@@ -1059,7 +1062,7 @@ impl MlsGroup {
     pub fn is_admin(
         &self,
         inbox_id: String,
-        provider: impl OpenMlsProvider,
+        provider: &XmtpOpenMlsProvider,
     ) -> Result<bool, GroupError> {
         let mutable_metadata = self.mutable_metadata(provider)?;
         Ok(mutable_metadata.admin_list.contains(&inbox_id))
@@ -1069,18 +1072,18 @@ impl MlsGroup {
     pub fn is_super_admin(
         &self,
         inbox_id: String,
-        provider: impl OpenMlsProvider,
+        provider: &XmtpOpenMlsProvider,
     ) -> Result<bool, GroupError> {
         let mutable_metadata = self.mutable_metadata(provider)?;
         Ok(mutable_metadata.super_admin_list.contains(&inbox_id))
     }
 
     /// Retrieves the conversation type of the group from the group's metadata extension.
-    pub fn conversation_type(
+    pub async fn conversation_type(
         &self,
-        provider: impl OpenMlsProvider,
+        provider: &XmtpOpenMlsProvider,
     ) -> Result<ConversationType, GroupError> {
-        let metadata = self.metadata(provider)?;
+        let metadata = self.metadata(provider).await?;
         Ok(metadata.conversation_type)
     }
 
@@ -1091,7 +1094,7 @@ impl MlsGroup {
         inbox_id: String,
     ) -> Result<(), GroupError> {
         let provider = self.client.mls_provider()?;
-        if self.metadata(&provider)?.conversation_type == ConversationType::Dm {
+        if self.metadata(&provider).await?.conversation_type == ConversationType::Dm {
             return Err(GroupError::DmGroupMetadataForbidden);
         }
         let intent_action_type = match action_type {
@@ -1175,21 +1178,25 @@ impl MlsGroup {
     /// Checks if the current user is active in the group.
     ///
     /// If the current user has been kicked out of the group, `is_active` will return `false`
-    pub fn is_active(&self, provider: impl OpenMlsProvider) -> Result<bool, GroupError> {
+    pub fn is_active(&self, provider: &XmtpOpenMlsProvider) -> Result<bool, GroupError> {
         self.load_mls_group_with_lock(provider, |mls_group| Ok(mls_group.is_active()))
     }
 
     /// Get the `GroupMetadata` of the group.
-    pub fn metadata(&self, provider: impl OpenMlsProvider) -> Result<GroupMetadata, GroupError> {
-        self.load_mls_group_with_lock(provider, |mls_group| {
-            Ok(extract_group_metadata(&mls_group)?)
+    pub async fn metadata(
+        &self,
+        provider: &XmtpOpenMlsProvider,
+    ) -> Result<GroupMetadata, GroupError> {
+        self.load_mls_group_with_lock_async(provider, |mls_group| {
+            futures::future::ready(extract_group_metadata(&mls_group).map_err(Into::into))
         })
+        .await
     }
 
     /// Get the `GroupMutableMetadata` of the group.
     pub fn mutable_metadata(
         &self,
-        provider: impl OpenMlsProvider,
+        provider: &XmtpOpenMlsProvider,
     ) -> Result<GroupMutableMetadata, GroupError> {
         self.load_mls_group_with_lock(provider, |mls_group| {
             Ok(GroupMutableMetadata::try_from(&mls_group)?)
@@ -1849,7 +1856,7 @@ pub(crate) mod tests {
 
         // Verify bola can see the group name
         let bola_group_name = bola_group
-            .group_name(bola_group.mls_provider().unwrap())
+            .group_name(&bola_group.mls_provider().unwrap())
            .unwrap();
         assert_eq!(bola_group_name, "");
 
@@ -2222,7 +2229,7 @@ pub(crate) mod tests {
         let bola_group = receive_group_invite(&bola).await;
         bola_group.sync().await.unwrap();
         assert!(!bola_group
-            .is_active(bola_group.mls_provider().unwrap())
+            .is_active(&bola_group.mls_provider().unwrap())
            .unwrap())
     }
 
@@ -2417,7 +2424,7 @@ pub(crate) mod tests {
            .unwrap();
 
         let binding = amal_group
-            .mutable_metadata(amal_group.mls_provider().unwrap())
+            .mutable_metadata(&amal_group.mls_provider().unwrap())
            .expect("msg");
         let amal_group_name: &String = binding
            .attributes
@@ -2480,7 +2487,7 @@ pub(crate) mod tests {
         amal_group.sync().await.unwrap();
 
         let group_mutable_metadata = amal_group
-            .mutable_metadata(amal_group.mls_provider().unwrap())
+            .mutable_metadata(&amal_group.mls_provider().unwrap())
            .unwrap();
         assert!(group_mutable_metadata.attributes.len().eq(&4));
         assert!(group_mutable_metadata
@@ -2503,7 +2510,7 @@ pub(crate) mod tests {
         let bola_group = bola_groups.first().unwrap();
         bola_group.sync().await.unwrap();
         let group_mutable_metadata = bola_group
-            .mutable_metadata(bola_group.mls_provider().unwrap())
+            .mutable_metadata(&bola_group.mls_provider().unwrap())
            .unwrap();
         assert!(group_mutable_metadata
            .attributes
@@ -2522,7 +2529,7 @@ pub(crate) mod tests {
         // Verify amal group sees update
         amal_group.sync().await.unwrap();
         let binding = amal_group
-            .mutable_metadata(amal_group.mls_provider().unwrap())
+            .mutable_metadata(&amal_group.mls_provider().unwrap())
            .expect("msg");
         let amal_group_name: &String = binding
            .attributes
@@ -2533,7 +2540,7 @@ pub(crate) mod tests {
         // Verify bola group sees update
         bola_group.sync().await.unwrap();
         let binding = bola_group
-            .mutable_metadata(bola_group.mls_provider().unwrap())
+            .mutable_metadata(&bola_group.mls_provider().unwrap())
            .expect("msg");
         let bola_group_name: &String = binding
            .attributes
@@ -2550,7 +2557,7 @@ pub(crate) mod tests {
         // Verify bola group does not see an update
         bola_group.sync().await.unwrap();
         let binding = bola_group
-            .mutable_metadata(bola_group.mls_provider().unwrap())
+            .mutable_metadata(&bola_group.mls_provider().unwrap())
            .expect("msg");
         let bola_group_name: &String = binding
            .attributes
@@ -2571,7 +2578,7 @@ pub(crate) mod tests {
         amal_group.sync().await.unwrap();
 
         let group_mutable_metadata = amal_group
-            .mutable_metadata(amal_group.mls_provider().unwrap())
+            .mutable_metadata(&amal_group.mls_provider().unwrap())
            .unwrap();
         assert!(group_mutable_metadata
            .attributes
@@ -2588,7 +2595,7 @@ pub(crate) mod tests {
         // Verify amal group sees update
         amal_group.sync().await.unwrap();
         let binding = amal_group
-            .mutable_metadata(amal_group.mls_provider().unwrap())
+            .mutable_metadata(&amal_group.mls_provider().unwrap())
            .expect("msg");
         let amal_group_image_url: &String = binding
            .attributes
@@ -2609,7 +2616,7 @@ pub(crate) mod tests {
         amal_group.sync().await.unwrap();
 
         let group_mutable_metadata = amal_group
-            .mutable_metadata(amal_group.mls_provider().unwrap())
+            .mutable_metadata(&amal_group.mls_provider().unwrap())
            .unwrap();
         assert!(group_mutable_metadata
            .attributes
@@ -2626,7 +2633,7 @@ pub(crate) mod tests {
         // Verify amal group sees update
         amal_group.sync().await.unwrap();
         let binding = amal_group
-            .mutable_metadata(amal_group.mls_provider().unwrap())
+            .mutable_metadata(&amal_group.mls_provider().unwrap())
            .expect("msg");
         let amal_group_pinned_frame_url: &String = binding
            .attributes
@@ -2649,7 +2656,7 @@ pub(crate) mod tests {
         amal_group.sync().await.unwrap();
 
         let group_mutable_metadata = amal_group
-            .mutable_metadata(amal_group.mls_provider().unwrap())
+            .mutable_metadata(&amal_group.mls_provider().unwrap())
            .unwrap();
         assert!(group_mutable_metadata
            .attributes
@@ -2670,7 +2677,7 @@ pub(crate) mod tests {
         let bola_group = bola_groups.first().unwrap();
         bola_group.sync().await.unwrap();
         let group_mutable_metadata = bola_group
-            .mutable_metadata(bola_group.mls_provider().unwrap())
+            .mutable_metadata(&bola_group.mls_provider().unwrap())
            .unwrap();
         assert!(group_mutable_metadata
            .attributes
@@ -2687,7 +2694,7 @@ pub(crate) mod tests {
         // Verify amal group sees update
         amal_group.sync().await.unwrap();
         let binding = amal_group
-            .mutable_metadata(amal_group.mls_provider().unwrap())
+            .mutable_metadata(&amal_group.mls_provider().unwrap())
            .unwrap();
         let amal_group_name: &String = binding
            .attributes
@@ -2698,7 +2705,7 @@ pub(crate) mod tests {
         // Verify bola group sees update
         bola_group.sync().await.unwrap();
         let binding = bola_group
-            .mutable_metadata(bola_group.mls_provider().unwrap())
+            .mutable_metadata(&bola_group.mls_provider().unwrap())
            .expect("msg");
         let bola_group_name: &String = binding
            .attributes
@@ -2715,7 +2722,7 @@ pub(crate) mod tests {
         // Verify amal group sees an update
         amal_group.sync().await.unwrap();
         let binding = amal_group
-            .mutable_metadata(amal_group.mls_provider().unwrap())
+            .mutable_metadata(&amal_group.mls_provider().unwrap())
            .expect("msg");
         let amal_group_name: &String = binding
            .attributes
@@ -2782,13 +2789,13 @@ pub(crate) mod tests {
         bola_group.sync().await.unwrap();
         assert_eq!(
             bola_group
-                .admin_list(bola_group.mls_provider().unwrap())
+                .admin_list(&bola_group.mls_provider().unwrap())
                .unwrap()
                .len(),
             1
         );
         assert!(bola_group
-            .admin_list(bola_group.mls_provider().unwrap())
+            .admin_list(&bola_group.mls_provider().unwrap())
            .unwrap()
            .contains(&bola.inbox_id().to_string()));
 
@@ -2819,13 +2826,13 @@ pub(crate) mod tests {
         bola_group.sync().await.unwrap();
         assert_eq!(
             bola_group
-                .admin_list(bola_group.mls_provider().unwrap())
+                .admin_list(&bola_group.mls_provider().unwrap())
                .unwrap()
                .len(),
             0
         );
         assert!(!bola_group
-            .admin_list(bola_group.mls_provider().unwrap())
+            .admin_list(&bola_group.mls_provider().unwrap())
            .unwrap()
            .contains(&bola.inbox_id().to_string()));
 
@@ -3103,13 +3110,14 @@ pub(crate) mod tests {
         amal_group.sync().await.unwrap();
 
         let mutable_metadata = amal_group
-            .mutable_metadata(amal_group.mls_provider().unwrap())
+            .mutable_metadata(&amal_group.mls_provider().unwrap())
            .unwrap();
         assert_eq!(mutable_metadata.super_admin_list.len(), 1);
         assert_eq!(mutable_metadata.super_admin_list[0], amal.inbox_id());
 
         let protected_metadata: GroupMetadata = amal_group
-            .metadata(amal_group.mls_provider().unwrap())
+            .metadata(&amal_group.mls_provider().unwrap())
+            .await
            .unwrap();
         assert_eq!(
             protected_metadata.conversation_type,
@@ -3149,7 +3157,7 @@ pub(crate) mod tests {
            .unwrap();
         amal_group.sync().await.unwrap();
         let name = amal_group
-            .group_name(amal_group.mls_provider().unwrap())
+            .group_name(&amal_group.mls_provider().unwrap())
            .unwrap();
         assert_eq!(name, "Name Update 1");
 
@@ -3170,7 +3178,7 @@ pub(crate) mod tests {
         amal_group.sync().await.unwrap();
         bola_group.sync().await.unwrap();
         let binding = amal_group
-            .mutable_metadata(amal_group.mls_provider().unwrap())
+            .mutable_metadata(&amal_group.mls_provider().unwrap())
            .expect("msg");
         let amal_group_name: &String = binding
            .attributes
@@ -3178,7 +3186,7 @@ pub(crate) mod tests {
            .unwrap();
         assert_eq!(amal_group_name, "Name Update 2");
         let binding = bola_group
-            .mutable_metadata(bola_group.mls_provider().unwrap())
+            .mutable_metadata(&bola_group.mls_provider().unwrap())
            .expect("msg");
         let bola_group_name: &String = binding
            .attributes
@@ -3389,16 +3397,16 @@ pub(crate) mod tests {
         amal_dm.sync().await.unwrap();
         bola_dm.sync().await.unwrap();
         let is_amal_admin = amal_dm
-            .is_admin(amal.inbox_id().to_string(), amal.mls_provider().unwrap())
+            .is_admin(amal.inbox_id().to_string(), &amal.mls_provider().unwrap())
            .unwrap();
         let is_bola_admin = amal_dm
-            .is_admin(bola.inbox_id().to_string(), bola.mls_provider().unwrap())
+            .is_admin(bola.inbox_id().to_string(), &bola.mls_provider().unwrap())
            .unwrap();
         let is_amal_super_admin = amal_dm
-            .is_super_admin(amal.inbox_id().to_string(), amal.mls_provider().unwrap())
+            .is_super_admin(amal.inbox_id().to_string(), &amal.mls_provider().unwrap())
            .unwrap();
         let is_bola_super_admin = amal_dm
-            .is_super_admin(bola.inbox_id().to_string(), bola.mls_provider().unwrap())
+            .is_super_admin(bola.inbox_id().to_string(), &bola.mls_provider().unwrap())
            .unwrap();
         assert!(!is_amal_admin);
         assert!(!is_bola_admin);
diff --git a/xmtp_mls/src/lib.rs b/xmtp_mls/src/lib.rs
index c2993fe0b..62b5057ae 100644
--- a/xmtp_mls/src/lib.rs
+++ b/xmtp_mls/src/lib.rs
@@ -20,9 +20,9 @@ pub mod verified_key_package_v2;
 mod xmtp_openmls_provider;
 
 pub use client::{Client, Network};
+use parking_lot::Mutex;
 use std::collections::HashMap;
-use std::sync::LazyLock;
-use std::sync::{Arc, Mutex};
+use std::sync::{Arc, LazyLock};
 use storage::{DuplicateItem, StorageError};
 use tokio::sync::{OwnedSemaphorePermit, Semaphore};
 pub use xmtp_openmls_provider::XmtpOpenMlsProvider;
@@ -53,52 +53,32 @@ impl GroupCommitLock {
     /// Get or create a semaphore for a specific group and acquire it, returning a guard
     pub async fn get_lock_async(&self, group_id: Vec<u8>) -> Result<SemaphoreGuard, GroupError> {
         let semaphore = {
-            match self.locks.lock() {
-                Ok(mut locks) => locks
-                    .entry(group_id)
-                    .or_insert_with(|| Arc::new(Semaphore::new(1)))
-                    .clone(),
-                Err(err) => {
-                    eprintln!("Failed to lock the mutex: {}", err);
-                    return Err(GroupError::LockUnavailable);
-                }
-            }
+            let mut locks = self.locks.lock();
+            locks
+                .entry(group_id)
+                .or_insert_with(|| Arc::new(Semaphore::new(1)))
+                .clone()
         };
 
-        let semaphore_clone = semaphore.clone();
-        let permit = match semaphore.acquire_owned().await {
-            Ok(permit) => permit,
-            Err(err) => {
-                eprintln!("Failed to acquire semaphore permit: {}", err);
-                return Err(GroupError::LockUnavailable);
-            }
-        };
+        let permit = semaphore.clone().acquire_owned().await?;
         Ok(SemaphoreGuard {
             _permit: permit,
-            _semaphore: semaphore_clone,
+            _semaphore: semaphore,
         })
     }
 
     /// Get or create a semaphore for a specific group and acquire it synchronously
     pub fn get_lock_sync(&self, group_id: Vec<u8>) -> Result<SemaphoreGuard, GroupError> {
         let semaphore = {
-            match self.locks.lock() {
-                Ok(mut locks) => locks
-                    .entry(group_id)
-                    .or_insert_with(|| Arc::new(Semaphore::new(1)))
-                    .clone(),
-                Err(err) => {
-                    eprintln!("Failed to lock the mutex: {}", err);
-                    return Err(GroupError::LockUnavailable);
-                }
-            }
+            let mut locks = self.locks.lock();
+            locks
+                .entry(group_id)
+                .or_insert_with(|| Arc::new(Semaphore::new(1)))
+                .clone()
         };
 
         // Synchronously acquire the permit
-        let permit = semaphore
-            .clone()
-            .try_acquire_owned()
-            .map_err(|_| GroupError::LockUnavailable)?;
+        let permit = semaphore.clone().try_acquire_owned()?;
         Ok(SemaphoreGuard {
             _permit: permit,
             _semaphore: semaphore, // semaphore is now valid because we cloned it earlier
diff --git a/xmtp_mls/src/subscriptions.rs b/xmtp_mls/src/subscriptions.rs
index 5082bdce9..2efdd4a07 100644
--- a/xmtp_mls/src/subscriptions.rs
+++ b/xmtp_mls/src/subscriptions.rs
@@ -344,7 +344,7 @@ where
             }
             WelcomeOrGroup::Group(group) => group?,
         };
-        let metadata = group.metadata(provider)?;
+        let metadata = group.metadata(&provider).await?;
         Ok((metadata, group))
     }
 }

From 7473a2462b98bd79afa469ea8940cc1c029bda69 Mon Sep 17 00:00:00 2001
From: Mojtaba Chenani
Date: Thu, 12 Dec 2024 22:25:30 +0100
Subject: [PATCH 15/28] remove unneeded comments

---
 bindings_ffi/src/mls.rs | 2 --
 1 file changed, 2 deletions(-)

diff --git a/bindings_ffi/src/mls.rs b/bindings_ffi/src/mls.rs
index 492e42b78..c5c9b819c 100644
--- a/bindings_ffi/src/mls.rs
+++ b/bindings_ffi/src/mls.rs
@@ -1546,7 +1546,6 @@ impl FfiConversation {
 
     pub fn group_metadata(&self) -> Result<Arc<FfiConversationMetadata>, GenericError> {
         let provider = self.inner.mls_provider()?;
-        // blocking is OK b/c not wasm
         let metadata = tokio::task::block_in_place(|| {
             futures::executor::block_on(self.inner.metadata(&provider))
         })?;
@@ -1561,7 +1560,6 @@ impl FfiConversation {
 
     pub fn conversation_type(&self) -> Result<FfiConversationType, GenericError> {
         let provider = self.inner.mls_provider()?;
-        // blocking OK b/c not wasm
         let conversation_type = tokio::task::block_in_place(|| {
             futures::executor::block_on(self.inner.conversation_type(&provider))
         })?;

From 7ef4f3cc0f87b0deaf503aebe53d459b1139f5da Mon Sep 17 00:00:00 2001
From: Mojtaba Chenani
Date: Fri, 13 Dec 2024 17:07:26 +0100
Subject: [PATCH 16/28] use mutex instead of semaphore

---
 bindings_ffi/src/mls.rs           |  8 ++-----
 bindings_node/src/conversation.rs |  7 +++---
 bindings_wasm/src/conversation.rs |  3 +--
 examples/cli/serializable.rs      |  3 +--
 xmtp_mls/src/groups/mls_sync.rs   |  1 -
 xmtp_mls/src/groups/mod.rs        | 28 +++++++++++-----------
 xmtp_mls/src/lib.rs               | 39 +++++--------------------------
 xmtp_mls/src/subscriptions.rs     |  2 +-
 8 files changed, 27 insertions(+), 64 deletions(-)

diff --git a/bindings_ffi/src/mls.rs b/bindings_ffi/src/mls.rs
index c5c9b819c..65d4baed8 100644
--- a/bindings_ffi/src/mls.rs
+++ b/bindings_ffi/src/mls.rs
@@ -1546,9 +1546,7 @@ impl FfiConversation {
 
     pub fn group_metadata(&self) -> Result<Arc<FfiConversationMetadata>, GenericError> {
         let provider = self.inner.mls_provider()?;
-        let metadata = tokio::task::block_in_place(|| {
-            futures::executor::block_on(self.inner.metadata(&provider))
-        })?;
+        let metadata = self.inner.metadata(&provider)?;
         Ok(Arc::new(FfiConversationMetadata {
             inner: Arc::new(metadata),
         }))
@@ -1560,9 +1558,7 @@ impl FfiConversation {
 
     pub fn conversation_type(&self) -> Result<FfiConversationType, GenericError> {
         let provider = self.inner.mls_provider()?;
-        let conversation_type = tokio::task::block_in_place(|| {
-            futures::executor::block_on(self.inner.conversation_type(&provider))
-        })?;
+        let conversation_type = self.inner.conversation_type(&provider)?;
         Ok(conversation_type.into())
     }
 }
diff --git a/bindings_node/src/conversation.rs b/bindings_node/src/conversation.rs
index 897276889..9e67a6670 100644
--- a/bindings_node/src/conversation.rs
+++ b/bindings_node/src/conversation.rs
@@ -610,10 +610,9 @@ impl Conversation {
       self.created_at_ns,
     );
 
-    let metadata = tokio::task::block_in_place(|| {
-      futures::executor::block_on(group.metadata(&group.mls_provider()?))
-    })
-    .map_err(ErrorWrapper::from)?;
+    let metadata = group
+      .metadata(&group.mls_provider().map_err(ErrorWrapper::from)?)
+      .map_err(ErrorWrapper::from)?;
 
     Ok(GroupMetadata { inner: metadata })
   }
diff --git a/bindings_wasm/src/conversation.rs b/bindings_wasm/src/conversation.rs
index 147c63b94..ae446603e 100644
--- a/bindings_wasm/src/conversation.rs
+++ b/bindings_wasm/src/conversation.rs
@@ -526,7 +526,7 @@ impl Conversation {
   }
 
   #[wasm_bindgen(js_name = groupMetadata)]
-  pub async fn group_metadata(&self) -> Result<GroupMetadata, JsError> {
+  pub fn group_metadata(&self) -> Result<GroupMetadata, JsError> {
     let group = self.to_mls_group();
     let metadata = group
      .metadata(
@@ -534,7 +534,6 @@ impl Conversation {
          .mls_provider()
          .map_err(|e| JsError::new(&format!("{e}")))?,
      )
-      .await
      .map_err(|e| JsError::new(&format!("{e}")))?;
 
     Ok(GroupMetadata { inner: metadata })
diff --git a/examples/cli/serializable.rs b/examples/cli/serializable.rs
index 545081638..429876608 100644
--- a/examples/cli/serializable.rs
+++ b/examples/cli/serializable.rs
@@ -35,8 +35,7 @@ impl SerializableGroup {
                     .mls_provider()
                     .expect("MLS Provider could not be created"),
             )
-            .await
-            .unwrap();
+            .expect("could not load metadata");
         let permissions = group.permissions().expect("could not load permissions");
 
         Self {
diff --git a/xmtp_mls/src/groups/mls_sync.rs b/xmtp_mls/src/groups/mls_sync.rs
index 9c319be46..88736cbc4 100644
--- a/xmtp_mls/src/groups/mls_sync.rs
+++ b/xmtp_mls/src/groups/mls_sync.rs
@@ -198,7 +198,6 @@ where
     // TODO: Should probably be renamed to `sync_with_provider`
     #[tracing::instrument(skip_all)]
     pub async fn sync_with_conn(&self, provider: &XmtpOpenMlsProvider) -> Result<(), GroupError> {
-        // Check if we're still part of the group
         let _mutex = self.mutex.lock().await;
         let mut errors: Vec<GroupError> = vec![];
 
diff --git a/xmtp_mls/src/groups/mod.rs b/xmtp_mls/src/groups/mod.rs
index 6c2377314..48b032b88 100644
--- a/xmtp_mls/src/groups/mod.rs
+++ b/xmtp_mls/src/groups/mod.rs
@@ -353,7 +353,7 @@ impl MlsGroup {
         let group_id = self.group_id.clone();
 
         // Acquire the lock synchronously using blocking_lock
-        let _lock = MLS_COMMIT_LOCK.get_lock_sync(group_id.clone())?;
+        let _lock = MLS_COMMIT_LOCK.get_lock_sync(group_id.clone());
         // Load the MLS group
         let mls_group =
             OpenMlsGroup::load(provider.storage(), &GroupId::from_slice(&self.group_id))
@@ -380,7 +380,7 @@ impl MlsGroup {
         let group_id = self.group_id.clone();
 
         // Acquire the lock asynchronously
-        let _lock = MLS_COMMIT_LOCK.get_lock_async(group_id.clone()).await;
+        let _lock = MLS_COMMIT_LOCK.get_lock_sync(group_id.clone()).await;
 
         // Load the MLS group
         let mls_group =
@@ -896,7 +896,7 @@ impl MlsGroup {
     /// to perform these updates.
     pub async fn update_group_name(&self, group_name: String) -> Result<(), GroupError> {
         let provider = self.client.mls_provider()?;
-        if self.metadata(&provider).await?.conversation_type == ConversationType::Dm {
+        if self.metadata(&provider)?.conversation_type == ConversationType::Dm {
             return Err(GroupError::DmGroupMetadataForbidden);
         }
         let intent_data: Vec<u8> =
@@ -914,7 +914,7 @@ impl MlsGroup {
         metadata_field: Option<MetadataField>,
     ) -> Result<(), GroupError> {
         let provider = self.client.mls_provider()?;
-        if self.metadata(&provider).await?.conversation_type == ConversationType::Dm {
+        if self.metadata(&provider)?.conversation_type == ConversationType::Dm {
             return Err(GroupError::DmGroupMetadataForbidden);
         }
         if permission_update_type == PermissionUpdateType::UpdateMetadata
@@ -955,7 +955,7 @@ impl MlsGroup {
         group_description: String,
     ) -> Result<(), GroupError> {
         let provider = self.client.mls_provider()?;
-        if self.metadata(&provider).await?.conversation_type == ConversationType::Dm {
+        if self.metadata(&provider)?.conversation_type == ConversationType::Dm {
             return Err(GroupError::DmGroupMetadataForbidden);
         }
         let intent_data: Vec<u8> =
@@ -984,7 +984,7 @@ impl MlsGroup {
         group_image_url_square: String,
     ) -> Result<(), GroupError> {
         let provider = self.client.mls_provider()?;
-        if self.metadata(&provider).await?.conversation_type == ConversationType::Dm {
+        if self.metadata(&provider)?.conversation_type == ConversationType::Dm {
             return Err(GroupError::DmGroupMetadataForbidden);
         }
         let intent_data: Vec<u8> =
@@ -1017,7 +1017,7 @@ impl MlsGroup {
         pinned_frame_url: String,
     ) -> Result<(), GroupError> {
         let provider = self.client.mls_provider()?;
-        if self.metadata(&provider).await?.conversation_type == ConversationType::Dm {
+        if self.metadata(&provider)?.conversation_type == ConversationType::Dm {
             return Err(GroupError::DmGroupMetadataForbidden);
         }
         let intent_data: Vec<u8> =
@@ -1079,11 +1079,11 @@ impl MlsGroup {
     }
 
     /// Retrieves the conversation type of the group from the group's metadata extension.
-    pub async fn conversation_type(
+    pub fn conversation_type(
         &self,
         provider: &XmtpOpenMlsProvider,
     ) -> Result<ConversationType, GroupError> {
-        let metadata = self.metadata(provider).await?;
+        let metadata = self.metadata(provider)?;
         Ok(metadata.conversation_type)
     }
 
@@ -1094,7 +1094,7 @@ impl MlsGroup {
         inbox_id: String,
     ) -> Result<(), GroupError> {
         let provider = self.client.mls_provider()?;
-        if self.metadata(&provider).await?.conversation_type == ConversationType::Dm {
+        if self.metadata(&provider)?.conversation_type == ConversationType::Dm {
             return Err(GroupError::DmGroupMetadataForbidden);
         }
         let intent_action_type = match action_type {
@@ -1183,14 +1183,13 @@ impl MlsGroup {
     }
 
     /// Get the `GroupMetadata` of the group.
-    pub async fn metadata(
+    pub fn metadata(
         &self,
         provider: &XmtpOpenMlsProvider,
     ) -> Result<GroupMetadata, GroupError> {
-        self.load_mls_group_with_lock_async(provider, |mls_group| {
-            futures::future::ready(extract_group_metadata(&mls_group).map_err(Into::into))
+        self.load_mls_group_with_lock(provider, |mls_group| {
+            extract_group_metadata(&mls_group).map_err(Into::into)
         })
-        .await
     }
 
     /// Get the `GroupMutableMetadata` of the group.
@@ -3117,7 +3116,6 @@ pub(crate) mod tests {
 
         let protected_metadata: GroupMetadata = amal_group
            .metadata(&amal_group.mls_provider().unwrap())
-            .await
            .unwrap();
         assert_eq!(
             protected_metadata.conversation_type,
diff --git a/xmtp_mls/src/lib.rs b/xmtp_mls/src/lib.rs
index 62b5057ae..762ee85df 100644
--- a/xmtp_mls/src/lib.rs
+++ b/xmtp_mls/src/lib.rs
@@ -34,7 +34,7 @@ pub use xmtp_proto::api_client::trait_impls::*;
 #[derive(Debug)]
 pub struct GroupCommitLock {
     // Storage for group-specific semaphores
-    locks: Mutex<HashMap<Vec<u8>, Arc<Semaphore>>>,
+    locks: Mutex<HashMap<Vec<u8>, Arc<Mutex<()>>>>,
 }
 
 impl Default for GroupCommitLock {
@@ -51,45 +51,17 @@ impl GroupCommitLock {
     }
 
     /// Get or create a semaphore for a specific group and acquire it, returning a guard
-    pub async fn get_lock_async(&self, group_id: Vec<u8>) -> Result<SemaphoreGuard, GroupError> {
-        let semaphore = {
+    pub async fn get_lock_sync(&self, group_id: Vec<u8>) -> Result<Arc<Mutex<()>>, GroupError> {
+        let mutex = {
             let mut locks = self.locks.lock();
             locks
                 .entry(group_id)
-                .or_insert_with(|| Arc::new(Semaphore::new(1)))
+                .or_insert_with(|| Arc::new(Mutex::new(())))
                 .clone()
         };
 
-        let permit = semaphore.clone().acquire_owned().await?;
-        Ok(SemaphoreGuard {
-            _permit: permit,
-            _semaphore: semaphore,
-        })
+        Ok(mutex)
     }
-
-    /// Get or create a semaphore for a specific group and acquire it synchronously
-    pub fn get_lock_sync(&self, group_id: Vec<u8>) -> Result<SemaphoreGuard, GroupError> {
-        let semaphore = {
-            let mut locks = self.locks.lock();
-            locks
-                .entry(group_id)
-                .or_insert_with(|| Arc::new(Semaphore::new(1)))
-                .clone()
-        };
-
-        // Synchronously acquire the permit
-        let permit = semaphore.clone().try_acquire_owned()?;
-        Ok(SemaphoreGuard {
-            _permit: permit,
-            _semaphore: semaphore, // semaphore is now valid because we cloned it earlier
-        })
-    }
-}
-
-/// A guard that releases the semaphore when dropped
-pub struct SemaphoreGuard {
-    _permit: OwnedSemaphorePermit,
-    _semaphore: Arc<Semaphore>,
 }
 
 // Static instance of `GroupCommitLock`
@@ -137,6 +109,7 @@ use crate::groups::GroupError;
 pub use stream_handles::{
     spawn, AbortHandle, GenericStreamHandle, StreamHandle, StreamHandleError,
 };
+use crate::groups::GroupError::LockUnavailable;
 
 #[cfg(test)]
 pub(crate) mod tests {
diff --git a/xmtp_mls/src/subscriptions.rs b/xmtp_mls/src/subscriptions.rs
index 2efdd4a07..23179946b 100644
--- a/xmtp_mls/src/subscriptions.rs
+++ b/xmtp_mls/src/subscriptions.rs
@@ -344,7 +344,7 @@ where
             }
             WelcomeOrGroup::Group(group) => group?,
         };
-        let metadata = group.metadata(&provider).await?;
+        let metadata = group.metadata(&provider)?;
         Ok((metadata, group))
     }
 }

From d288c6bbc1e689f45126d81e641585647e8acca3 Mon Sep 17 00:00:00 2001
From: Mojtaba Chenani
Date: Fri, 13 Dec 2024 17:09:41 +0100
Subject: [PATCH 17/28] fix fmt

---
 bindings_node/src/conversation.rs | 11 +++++------
 bindings_wasm/src/conversation.rs |  1 -
 xmtp_mls/src/groups/mod.rs        |  5 +----
 xmtp_mls/src/lib.rs               |  2 +-
 4 files changed, 7 insertions(+), 12 deletions(-)

diff --git a/bindings_node/src/conversation.rs b/bindings_node/src/conversation.rs
index 9e67a6670..20ffb2386 100644
--- a/bindings_node/src/conversation.rs
+++ b/bindings_node/src/conversation.rs
@@ -167,10 +167,9 @@ impl Conversation {
       self.created_at_ns,
     );
     let provider = group.mls_provider().map_err(ErrorWrapper::from)?;
-    let conversation_type = tokio::task::block_in_place(|| {
-      futures::executor::block_on(group.conversation_type(&provider))
-    })
-    .map_err(ErrorWrapper::from)?;
+    let conversation_type = group
+      .conversation_type(&provider)
+      .map_err(ErrorWrapper::from)?;
     let kind = match conversation_type {
       ConversationType::Group => None,
       ConversationType::Dm => Some(XmtpGroupMessageKind::Application),
@@ -611,8 +610,8 @@ impl Conversation {
     );
 
     let metadata = group
-        .metadata(&group.mls_provider().map_err(ErrorWrapper::from)?)
-        .map_err(ErrorWrapper::from)?;
+      .metadata(&group.mls_provider().map_err(ErrorWrapper::from)?)
+      .map_err(ErrorWrapper::from)?;
 
     Ok(GroupMetadata { inner: metadata })
   }
diff --git a/bindings_wasm/src/conversation.rs b/bindings_wasm/src/conversation.rs
index ae446603e..e7ded1a19 100644
--- a/bindings_wasm/src/conversation.rs
+++ b/bindings_wasm/src/conversation.rs
@@ -191,7 +191,6 @@ impl Conversation {
      .map_err(|e| JsError::new(&format!("{e}")))?;
     let conversation_type = group
      .conversation_type(&provider)
-      .await
      .map_err(|e| JsError::new(&format!("{e}")))?;
     let kind = match conversation_type {
       ConversationType::Group => None,
diff --git a/xmtp_mls/src/groups/mod.rs b/xmtp_mls/src/groups/mod.rs
index 48b032b88..bd6661514 100644
--- a/xmtp_mls/src/groups/mod.rs
+++ b/xmtp_mls/src/groups/mod.rs
@@ -1183,10 +1183,7 @@ impl MlsGroup {
     }
 
     /// Get the `GroupMetadata` of the group.
-    pub fn metadata(
-        &self,
-        provider: &XmtpOpenMlsProvider,
-    ) -> Result<GroupMetadata, GroupError> {
+    pub fn metadata(&self, provider: &XmtpOpenMlsProvider) -> Result<GroupMetadata, GroupError> {
         self.load_mls_group_with_lock(provider, |mls_group| {
             extract_group_metadata(&mls_group).map_err(Into::into)
         })
diff --git a/xmtp_mls/src/lib.rs b/xmtp_mls/src/lib.rs
index 762ee85df..35133d7a0 100644
--- a/xmtp_mls/src/lib.rs
+++ b/xmtp_mls/src/lib.rs
@@ -106,10 +106,10 @@ pub trait Delete {
 }
 
 use crate::groups::GroupError;
+use crate::groups::GroupError::LockUnavailable;
 pub use stream_handles::{
     spawn, AbortHandle, GenericStreamHandle, StreamHandle, StreamHandleError,
 };
-use crate::groups::GroupError::LockUnavailable;

From 822236dc23434266039ff62c2121305523947daa Mon Sep 17 00:00:00 2001
From: Mojtaba Chenani
Date: Fri, 13 Dec 2024 17:29:34 +0100
Subject: [PATCH 18/28] fix after conflicts

---
 xmtp_mls/src/groups/mls_sync.rs | 8 +-------
 1 file changed, 1 insertion(+), 7 deletions(-)

diff --git a/xmtp_mls/src/groups/mls_sync.rs b/xmtp_mls/src/groups/mls_sync.rs
index 5cb1f9fbb..0889c05ca 100644
--- a/xmtp_mls/src/groups/mls_sync.rs
+++ b/xmtp_mls/src/groups/mls_sync.rs
@@ -615,13 +615,6 @@ where
                         }
                         .store_or_ignore(provider.conn_ref())?;
 
-                        tracing::info!("Received a history reply.");
-                        let _ = self.client.local_events().send(LocalEvents::SyncMessage(
-                            SyncMessage::Reply { message_id },
-                        ));
-                    }
-                    .store_or_ignore(provider.conn_ref())?;
-
                         tracing::info!("Received a history reply.");
                         let _ = self.client.local_events().send(LocalEvents::SyncMessage(
                             SyncMessage::Reply { message_id },
                         ));
@@ -645,6 +638,7 @@ where
                                 return Err(GroupMessageProcessingError::InvalidPayload);
                             }
                         }
+                    }
                     None => return Err(GroupMessageProcessingError::InvalidPayload),
                 }
             }

From ba0b09cfe635958a71c4bd15f126e289ae8fe473 Mon Sep 17 00:00:00 2001
From: Mojtaba Chenani
Date: Fri, 13 Dec 2024 17:39:08 +0100
Subject: [PATCH 19/28] fix linter

---
 xmtp_mls/src/lib.rs | 2 --
 1 file changed, 2 deletions(-)

diff --git a/xmtp_mls/src/lib.rs b/xmtp_mls/src/lib.rs
index 35133d7a0..8a769a57b 100644
--- a/xmtp_mls/src/lib.rs
+++ b/xmtp_mls/src/lib.rs
@@ -24,7 +24,6 @@ use parking_lot::Mutex;
 use std::collections::HashMap;
 use std::sync::{Arc, LazyLock};
 use storage::{DuplicateItem, StorageError};
-use tokio::sync::{OwnedSemaphorePermit, Semaphore};
 pub use xmtp_openmls_provider::XmtpOpenMlsProvider;
 
 pub use xmtp_id::InboxOwner;
@@ -106,7 +105,6 @@ pub trait Delete {
 }
 
 use crate::groups::GroupError;
-use crate::groups::GroupError::LockUnavailable;
 pub use stream_handles::{
     spawn, AbortHandle, GenericStreamHandle, StreamHandle, StreamHandleError,
 };

From a9bbb5d28c941e1460f4eb7a788f009dfbf56315 Mon Sep 17 00:00:00 2001
From: Mojtaba Chenani
Date: Mon, 16 Dec 2024 19:19:45 +0100
Subject: [PATCH 20/28] revert to semaphore

---
 xmtp_mls/src/groups/mod.rs |  2 +-
 xmtp_mls/src/lib.rs        | 44 ++++++++++++++++++++++++++++++------
 2 files changed, 38 insertions(+), 8 deletions(-)

diff --git a/xmtp_mls/src/groups/mod.rs b/xmtp_mls/src/groups/mod.rs
index bd6661514..b5e238790 100644
--- a/xmtp_mls/src/groups/mod.rs
+++ b/xmtp_mls/src/groups/mod.rs
@@ -380,7 +380,7 @@ impl MlsGroup {
         let group_id = self.group_id.clone();
 
         // Acquire the lock asynchronously
-        let _lock = MLS_COMMIT_LOCK.get_lock_sync(group_id.clone()).await;
+        let _lock = MLS_COMMIT_LOCK.get_lock_async(group_id.clone()).await;
 
         // Load the MLS group
         let mls_group =
diff --git a/xmtp_mls/src/lib.rs b/xmtp_mls/src/lib.rs
index 8a769a57b..0a401a4f1 100644
--- a/xmtp_mls/src/lib.rs
+++ b/xmtp_mls/src/lib.rs
@@ -20,10 +20,10 @@ pub mod verified_key_package_v2;
 mod xmtp_openmls_provider;
 
 pub use client::{Client, Network};
-use parking_lot::Mutex;
 use std::collections::HashMap;
-use std::sync::{Arc, LazyLock};
+use std::sync::{Arc, LazyLock, Mutex};
 use storage::{DuplicateItem, StorageError};
+use tokio::sync::{OwnedSemaphorePermit, Semaphore};
 pub use xmtp_openmls_provider::XmtpOpenMlsProvider;
 
 pub use xmtp_id::InboxOwner;
@@ -33,7 +33,7 @@ pub use xmtp_proto::api_client::trait_impls::*;
 #[derive(Debug)]
 pub struct GroupCommitLock {
     // Storage for group-specific semaphores
-    locks: Mutex<HashMap<Vec<u8>, Arc<Mutex<()>>>>,
+    locks: Mutex<HashMap<Vec<u8>, Arc<Semaphore>>>,
 }
 
 impl Default for GroupCommitLock {
@@ -50,17 +50,47 @@ impl GroupCommitLock {
     }
 
     /// Get or create a semaphore for a specific group and acquire it, returning a guard
-    pub async fn get_lock_sync(&self, group_id: Vec<u8>) -> Result<Arc<Mutex<()>>, GroupError> {
-        let mutex = {
+    pub async fn get_lock_async(&self, group_id: Vec<u8>) -> Result<SemaphoreGuard, GroupError> {
+        let semaphore = {
             let mut locks = self.locks.lock();
             locks
+                .unwrap()
                 .entry(group_id)
-                .or_insert_with(|| Arc::new(Mutex::new(())))
+                .or_insert_with(|| Arc::new(Semaphore::new(1)))
                 .clone()
         };
 
-        Ok(mutex)
+        let permit = semaphore.clone().acquire_owned().await?;
+        Ok(SemaphoreGuard {
+            _permit: permit,
+            _semaphore: semaphore,
+        })
     }
+
+    /// Get or create a semaphore for a specific group and acquire it synchronously
+    pub fn get_lock_sync(&self, group_id: Vec<u8>) -> Result<SemaphoreGuard, GroupError> {
+        let semaphore = {
+            let locks = self.locks.lock();
+            locks
+                .unwrap()
+                .entry(group_id)
+                .or_insert_with(|| Arc::new(Semaphore::new(1)))
+                .clone()
+        };
+
+        // Synchronously acquire the permit
+        let permit = semaphore.clone().try_acquire_owned()?;
+        Ok(SemaphoreGuard {
+            _permit: permit,
+            _semaphore: semaphore, // semaphore is now valid because we cloned it earlier
+        })
+    }
+}
+
+/// A guard that releases the semaphore when dropped
+pub struct SemaphoreGuard {
+    _permit: OwnedSemaphorePermit,
+    _semaphore: Arc<Semaphore>,
 }
 
 // Static instance of `GroupCommitLock`

From 7189666adcb974d60493d1b564b8c9709ba01509 Mon Sep 17 00:00:00 2001
From: Mojtaba Chenani
Date: Mon, 16 Dec 2024 21:30:03 +0100
Subject: [PATCH 21/28] fix clippy

---
 xmtp_mls/src/groups/device_sync.rs |  7 +----
 xmtp_mls/src/groups/mod.rs         | 14 ++++-----
 xmtp_mls/src/lib.rs                | 48 ++++++++++++++++++++----------
 3 files changed, 41 insertions(+), 28 deletions(-)

diff --git a/xmtp_mls/src/groups/device_sync.rs b/xmtp_mls/src/groups/device_sync.rs
index 65f233256..9649f3b53 100644
--- a/xmtp_mls/src/groups/device_sync.rs
+++ b/xmtp_mls/src/groups/device_sync.rs
@@ -24,14 +24,9 @@ use futures::{Stream, StreamExt};
 use preference_sync::UserPreferenceUpdate;
 use rand::{Rng, RngCore};
 use serde::{Deserialize, Serialize};
-use std::future::Future;
 use std::pin::Pin;
-use std::sync::atomic::{AtomicUsize, Ordering};
-use std::sync::Arc;
 use thiserror::Error;
-use tokio::sync::{Notify, OnceCell};
-use tokio::time::error::Elapsed;
-use tokio::time::timeout;
+use tokio::sync::{OnceCell};
 use tracing::{instrument, warn};
 use xmtp_common::time::{now_ns, Duration};
 use xmtp_common::{retry_async, Retry, RetryableError};
diff --git a/xmtp_mls/src/groups/mod.rs b/xmtp_mls/src/groups/mod.rs
index b5e238790..676d6eb96 100644
--- a/xmtp_mls/src/groups/mod.rs
+++ b/xmtp_mls/src/groups/mod.rs
@@ -34,7 +34,7 @@ use openmls::{
 use openmls_traits::OpenMlsProvider;
 use prost::Message;
 use thiserror::Error;
-use tokio::sync::Mutex;
+use tokio::sync::{Mutex};
 
 use self::device_sync::DeviceSyncError;
 pub use self::group_permissions::PreconfiguredPolicies;
@@ -206,10 +206,10 @@ pub enum GroupError {
     IntentNotCommitted,
     #[error(transparent)]
     ProcessIntent(#[from] ProcessIntentError),
-    #[error(transparent)]
-    LockUnavailable(#[from] tokio::sync::AcquireError),
-    #[error(transparent)]
-    LockFailedToAcquire(#[from] tokio::sync::TryAcquireError),
+    #[error("Failed to load lock")]
+    LockUnavailable,
+    #[error("Failed to acquire semaphore lock")]
+    LockFailedToAcquire,
 }
 
 impl RetryableError for GroupError {
@@ -236,8 +236,8 @@ impl RetryableError for GroupError {
             Self::MessageHistory(err) => err.is_retryable(),
             Self::ProcessIntent(err) => err.is_retryable(),
             Self::LocalEvent(err) => err.is_retryable(),
-            Self::LockUnavailable(_) => true,
-            Self::LockFailedToAcquire(_) => true,
+            Self::LockUnavailable => true,
+            Self::LockFailedToAcquire => true,
             Self::SyncFailedToWait => true,
             Self::GroupNotFound
             | Self::GroupMetadata(_)
diff --git a/xmtp_mls/src/lib.rs b/xmtp_mls/src/lib.rs
index 0a401a4f1..164e79af6 100644
--- a/xmtp_mls/src/lib.rs
+++ b/xmtp_mls/src/lib.rs
@@ -52,34 +52,52 @@ impl GroupCommitLock {
     /// Get or create a semaphore for a specific group and acquire it, returning a guard
     pub async fn get_lock_async(&self, group_id: Vec<u8>) -> Result<SemaphoreGuard, GroupError> {
         let semaphore = {
-            let mut locks = self.locks.lock();
-            locks
-                .unwrap()
-                .entry(group_id)
-                .or_insert_with(|| Arc::new(Semaphore::new(1)))
-                .clone()
+            match self.locks.lock() {
+                Ok(mut locks) => locks
+                    .entry(group_id)
+                    .or_insert_with(|| Arc::new(Semaphore::new(1)))
+                    .clone(),
+                Err(err) => {
+                    eprintln!("Failed to lock the mutex: {}", err);
+                    return Err(GroupError::LockUnavailable);
+                }
+            }
         };
 
-        let permit = semaphore.clone().acquire_owned().await?;
-        Ok(SemaphoreGuard {
-            _permit: permit,
-            _semaphore: semaphore,
-        })
+        let semaphore_clone = semaphore.clone();
+        let permit = match semaphore.acquire_owned().await {
+            Ok(permit) => permit,
+            Err(err) => {
+                eprintln!("Failed to acquire semaphore permit: {}", err);
+                return Err(GroupError::LockFailedToAcquire);
+            }
+        };
+        Ok(SemaphoreGuard {
+            _permit: permit,
+            _semaphore: semaphore_clone,
+        })
     }
 
     /// Get or create a semaphore for a specific group and acquire it synchronously
     pub fn get_lock_sync(&self, group_id: Vec<u8>) -> Result<SemaphoreGuard, GroupError> {
         let semaphore = {
-            let locks = self.locks.lock();
-            locks
-                .unwrap()
-                .entry(group_id)
-                .or_insert_with(|| Arc::new(Semaphore::new(1)))
-                .clone()
+            match self.locks.lock() {
+                Ok(mut locks) => locks
+                    .entry(group_id)
+                    .or_insert_with(|| Arc::new(Semaphore::new(1)))
+                    .clone(),
+                Err(err) => {
+                    eprintln!("Failed to lock the mutex: {}", err);
+                    return Err(GroupError::LockUnavailable);
+                }
+            }
         };
 
         // Synchronously acquire the permit
-        let permit = semaphore.clone().try_acquire_owned()?;
+        let permit = semaphore
+            .clone()
+            .try_acquire_owned()
+            .map_err(|_| GroupError::LockUnavailable)?;
         Ok(SemaphoreGuard {
             _permit: permit,
             _semaphore: semaphore, // semaphore is now valid because we cloned it earlier

From e7e454c1633fc225520408b95d62a58db96eae11 Mon Sep 17 00:00:00 2001
From: Mojtaba Chenani
Date: Mon, 16 Dec 2024 21:35:56 +0100
Subject: [PATCH 22/28] fix linter

---
 xmtp_mls/src/groups/mod.rs | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/xmtp_mls/src/groups/mod.rs b/xmtp_mls/src/groups/mod.rs
index 676d6eb96..b1fb6f92c 100644
--- a/xmtp_mls/src/groups/mod.rs
+++ b/xmtp_mls/src/groups/mod.rs
@@ -34,7 +34,7 @@ use openmls::{
 use openmls_traits::OpenMlsProvider;
 use prost::Message;
 use thiserror::Error;
-use tokio::sync::{Mutex};
+use tokio::sync::Mutex;
 
 use self::device_sync::DeviceSyncError;
 pub use self::group_permissions::PreconfiguredPolicies;

From 494097c276073dd64c6a9f1032f143388a0ef63e Mon Sep 17 00:00:00 2001
From: Mojtaba Chenani
Date: Mon, 16 Dec 2024 23:15:46 +0100
Subject: [PATCH 23/28] make group.metadata async

---
 bindings_ffi/src/mls.rs           | 12 ++++++------
 bindings_node/src/conversation.rs |  6 ++++--
 bindings_wasm/src/conversation.rs |  4 +++-
 examples/cli/serializable.rs      |  3 ++-
 xmtp_mls/src/groups/mod.rs        | 27 ++++++++++++++++-----------
 xmtp_mls/src/subscriptions.rs     |  2 +-
 6 files changed, 32 insertions(+), 22 deletions(-)

diff --git a/bindings_ffi/src/mls.rs b/bindings_ffi/src/mls.rs
index 65d4baed8..ad9914533 100644
--- a/bindings_ffi/src/mls.rs
+++ b/bindings_ffi/src/mls.rs
@@ -1268,13 +1268,13 @@ impl FfiConversation {
         Ok(())
     }
 
-    pub fn find_messages(
+    pub async fn find_messages(
         &self,
         opts: FfiListMessagesOptions,
     ) -> Result<Vec<FfiMessage>, GenericError> {
         let delivery_status = opts.delivery_status.map(|status| status.into());
         let direction = opts.direction.map(|dir| dir.into());
-        let kind = match self.conversation_type()? {
+        let kind = match self.conversation_type().await? {
             FfiConversationType::Group => None,
             FfiConversationType::Dm => Some(GroupMessageKind::Application),
             FfiConversationType::Sync => None,
@@ -1544,9 +1544,9 @@ impl FfiConversation {
         self.inner.added_by_inbox_id().map_err(Into::into)
     }
 
-    pub fn group_metadata(&self) -> Result<Arc<FfiConversationMetadata>, GenericError> {
+    pub async fn group_metadata(&self) -> Result<Arc<FfiConversationMetadata>, GenericError> {
         let provider = self.inner.mls_provider()?;
-        let metadata = self.inner.metadata(&provider)?;
+        let metadata = self.inner.metadata(&provider).await?;
         Ok(Arc::new(FfiConversationMetadata {
             inner: Arc::new(metadata),
         }))
@@ -1556,9 +1556,9 @@ impl FfiConversation {
         self.inner.dm_inbox_id().map_err(Into::into)
     }
 
-    pub fn conversation_type(&self) -> Result<FfiConversationType, GenericError> {
+    pub async fn conversation_type(&self) -> Result<FfiConversationType, GenericError> {
         let provider = self.inner.mls_provider()?;
-        let conversation_type = self.inner.conversation_type(&provider)?;
+        let conversation_type = self.inner.conversation_type(&provider).await?;
         Ok(conversation_type.into())
     }
 }
diff --git a/bindings_node/src/conversation.rs b/bindings_node/src/conversation.rs
index e8a4066d0..b520af793 100644
--- a/bindings_node/src/conversation.rs
+++ b/bindings_node/src/conversation.rs
@@ -161,7 +161,7 @@ impl Conversation {
   }
 
   #[napi]
-  pub fn find_messages(&self, opts: Option<ListMessagesOptions>) -> Result<Vec<Message>> {
+  pub async fn find_messages(&self, opts: Option<ListMessagesOptions>) -> Result<Vec<Message>> {
     let opts = opts.unwrap_or_default();
     let group = MlsGroup::new(
       self.inner_client.clone(),
@@ -171,6 +171,7 @@ impl Conversation {
     let provider = group.mls_provider().map_err(ErrorWrapper::from)?;
     let conversation_type = group
      .conversation_type(&provider)
+      .await
      .map_err(ErrorWrapper::from)?;
     let kind = match conversation_type {
       ConversationType::Group => None,
@@ -604,7 +605,7 @@ impl Conversation {
   }
 
   #[napi]
-  pub fn group_metadata(&self) -> Result<GroupMetadata> {
+  pub async fn group_metadata(&self) -> Result<GroupMetadata> {
     let group = MlsGroup::new(
       self.inner_client.clone(),
       self.group_id.clone(),
@@ -613,6 +614,7 @@ impl Conversation {
 
     let metadata = group
      .metadata(&group.mls_provider().map_err(ErrorWrapper::from)?)
+      .await
      .map_err(ErrorWrapper::from)?;
 
     Ok(GroupMetadata { inner: metadata })
diff --git a/bindings_wasm/src/conversation.rs b/bindings_wasm/src/conversation.rs
index d211d5cae..42ca2f06e 100644
--- a/bindings_wasm/src/conversation.rs
+++ b/bindings_wasm/src/conversation.rs
@@ -194,6 +194,7 @@ impl Conversation {
      .map_err(|e| JsError::new(&format!("{e}")))?;
     let conversation_type = group
      .conversation_type(&provider)
+      .await
      .map_err(|e| JsError::new(&format!("{e}")))?;
     let kind = match conversation_type {
       ConversationType::Group => None,
@@ -528,7 +529,7 @@ impl Conversation {
   }
 
   #[wasm_bindgen(js_name = groupMetadata)]
-  pub fn group_metadata(&self) -> Result<GroupMetadata, JsError> {
+  pub async fn group_metadata(&self) -> Result<GroupMetadata, JsError> {
     let group = self.to_mls_group();
     let metadata = group
      .metadata(
@@ -536,6 +537,7 @@ impl Conversation {
          .mls_provider()
          .map_err(|e| JsError::new(&format!("{e}")))?,
      )
+      .await
      .map_err(|e| JsError::new(&format!("{e}")))?;
 
     Ok(GroupMetadata { inner: metadata })
diff --git a/examples/cli/serializable.rs b/examples/cli/serializable.rs
index 429876608..545081638 100644
--- a/examples/cli/serializable.rs
+++ b/examples/cli/serializable.rs
@@ -35,7 +35,8 @@ impl SerializableGroup {
                     .mls_provider()
                     .expect("MLS Provider could not be created"),
             )
-            .expect("could not load metadata");
+            .await
+            .unwrap();
         let permissions = group.permissions().expect("could not load permissions");
 
         Self {
diff --git a/xmtp_mls/src/groups/mod.rs b/xmtp_mls/src/groups/mod.rs
index b1fb6f92c..35662bc7d 100644
--- a/xmtp_mls/src/groups/mod.rs
+++ b/xmtp_mls/src/groups/mod.rs
@@ -896,7 +896,7 @@ impl MlsGroup {
     /// to perform these updates.
     pub async fn update_group_name(&self, group_name: String) -> Result<(), GroupError> {
         let provider = self.client.mls_provider()?;
-        if self.metadata(&provider)?.conversation_type == ConversationType::Dm {
+        if self.metadata(&provider).await?.conversation_type == ConversationType::Dm {
             return Err(GroupError::DmGroupMetadataForbidden);
         }
         let intent_data: Vec<u8> =
@@ -914,7 +914,7 @@ impl MlsGroup {
         metadata_field: Option<MetadataField>,
     ) -> Result<(), GroupError> {
         let provider = self.client.mls_provider()?;
-        if self.metadata(&provider)?.conversation_type == ConversationType::Dm {
+        if self.metadata(&provider).await?.conversation_type == ConversationType::Dm {
             return Err(GroupError::DmGroupMetadataForbidden);
         }
         if permission_update_type == PermissionUpdateType::UpdateMetadata
@@ -955,7 +955,7 @@ impl MlsGroup {
         group_description: String,
     ) -> Result<(), GroupError> {
         let provider = self.client.mls_provider()?;
-        if self.metadata(&provider)?.conversation_type == ConversationType::Dm {
+        if self.metadata(&provider).await?.conversation_type == ConversationType::Dm {
             return Err(GroupError::DmGroupMetadataForbidden);
         }
         let intent_data: Vec<u8> =
@@ -984,7 +984,7 @@ impl MlsGroup {
         group_image_url_square: String,
     ) -> Result<(), GroupError> {
         let provider = self.client.mls_provider()?;
-        if self.metadata(&provider)?.conversation_type == ConversationType::Dm {
+        if self.metadata(&provider).await?.conversation_type == ConversationType::Dm {
             return Err(GroupError::DmGroupMetadataForbidden);
         }
         let intent_data: Vec<u8> =
@@ -1017,7 +1017,7 @@ impl MlsGroup {
         pinned_frame_url: String,
     ) -> Result<(), GroupError> {
         let provider = self.client.mls_provider()?;
-        if self.metadata(&provider)?.conversation_type == ConversationType::Dm {
+        if self.metadata(&provider).await?.conversation_type == ConversationType::Dm {
             return Err(GroupError::DmGroupMetadataForbidden);
         }
         let intent_data: Vec<u8> =
@@ -1079,11 +1079,11 @@ impl MlsGroup {
     }
 
     /// Retrieves the conversation type of the group from the group's metadata extension.
-    pub fn conversation_type(
+    pub async fn conversation_type(
         &self,
         provider: &XmtpOpenMlsProvider,
     ) -> Result<ConversationType, GroupError> {
-        let metadata = self.metadata(provider)?;
+        let metadata = self.metadata(provider).await?;
         Ok(metadata.conversation_type)
     }
 
@@ -1094,7 +1094,7 @@ impl MlsGroup {
         inbox_id: String,
     ) -> Result<(), GroupError> {
         let provider = self.client.mls_provider()?;
-        if self.metadata(&provider)?.conversation_type == ConversationType::Dm {
+        if self.metadata(&provider).await?.conversation_type == ConversationType::Dm {
             return Err(GroupError::DmGroupMetadataForbidden);
         }
         let intent_action_type = match action_type {
@@ -1183,10 +1183,14 @@ impl MlsGroup {
     }
 
     /// Get the `GroupMetadata` of the group.
-    pub fn metadata(&self, provider: &XmtpOpenMlsProvider) -> Result<GroupMetadata, GroupError> {
-        self.load_mls_group_with_lock(provider, |mls_group| {
-            extract_group_metadata(&mls_group).map_err(Into::into)
+    pub async fn metadata(
+        &self,
+        provider: &XmtpOpenMlsProvider,
+    ) -> Result<GroupMetadata, GroupError> {
+        self.load_mls_group_with_lock_async(provider, |mls_group| {
+            futures::future::ready(extract_group_metadata(&mls_group).map_err(Into::into))
         })
+        .await
     }
 
     /// Get the `GroupMutableMetadata` of the group.
@@ -3113,6 +3117,7 @@ pub(crate) mod tests {
 
         let protected_metadata: GroupMetadata = amal_group
            .metadata(&amal_group.mls_provider().unwrap())
+            .await
            .unwrap();
         assert_eq!(
             protected_metadata.conversation_type,
diff --git a/xmtp_mls/src/subscriptions.rs b/xmtp_mls/src/subscriptions.rs
index 23179946b..2efdd4a07 100644
--- a/xmtp_mls/src/subscriptions.rs
+++ b/xmtp_mls/src/subscriptions.rs
@@ -344,7 +344,7 @@ where
             }
             WelcomeOrGroup::Group(group) => group?,
         };
-        let metadata = group.metadata(&provider)?;
+        let metadata = group.metadata(&provider).await?;
         Ok((metadata, group))
     }
 }

From 2011115810f9b0d2da12c1bbaa70f0cb4af9b8a1 Mon Sep 17 00:00:00 2001
From: Mojtaba Chenani
Date: Mon, 16 Dec 2024 23:26:40 +0100
Subject: [PATCH 24/28] fix tests

---
 bindings_ffi/src/mls.rs | 28 +++++++++++++++++++++++++++-
 1 file changed, 27 insertions(+), 1 deletion(-)

diff --git a/bindings_ffi/src/mls.rs b/bindings_ffi/src/mls.rs
index ad9914533..f173c7a17 100644
--- a/bindings_ffi/src/mls.rs
+++ b/bindings_ffi/src/mls.rs
@@ -2605,9 +2605,11 @@ mod tests {
 
         let bo_messages1 = bo_group1
            .find_messages(FfiListMessagesOptions::default())
+            .await
            .unwrap();
         let bo_messages5 = bo_group5
            .find_messages(FfiListMessagesOptions::default())
+            .await
            .unwrap();
         assert_eq!(bo_messages1.len(), 0);
         assert_eq!(bo_messages5.len(), 0);
@@ -2619,9 +2621,11 @@ mod tests {
 
         let bo_messages1 = bo_group1
            .find_messages(FfiListMessagesOptions::default())
+            .await
            .unwrap();
         let bo_messages5 = bo_group5
            .find_messages(FfiListMessagesOptions::default())
+            .await
            .unwrap();
         assert_eq!(bo_messages1.len(), 1);
         assert_eq!(bo_messages5.len(), 1);
@@ -2740,11 +2744,13 @@ mod tests {
         alix_group.sync().await.unwrap();
         let alix_messages = alix_group
            .find_messages(FfiListMessagesOptions::default())
+            .await
            .unwrap();
 
         bo_group.sync().await.unwrap();
         let bo_messages = bo_group
            .find_messages(FfiListMessagesOptions::default())
+            .await
            .unwrap();
         assert_eq!(bo_messages.len(), 9);
         assert_eq!(alix_messages.len(), 10);
@@ -2928,15 +2934,19 @@ mod tests {
         // Get the message count for all the clients
         let caro_messages = caro_group
            .find_messages(FfiListMessagesOptions::default())
+            .await
            .unwrap();
         let alix_messages = alix_group
            .find_messages(FfiListMessagesOptions::default())
+            .await
            .unwrap();
         let bo_messages = bo_group
            .find_messages(FfiListMessagesOptions::default())
+            .await
            .unwrap();
         let bo2_messages = bo2_group
            .find_messages(FfiListMessagesOptions::default())
+            .await
            .unwrap();
 
         assert_eq!(caro_messages.len(), 5);
@@ -2992,9 +3002,11 @@ mod tests {
 
         let alix_messages = alix_group
            .find_messages(FfiListMessagesOptions::default())
+            .await
            .unwrap();
         let bo_messages = bo_group
            .find_messages(FfiListMessagesOptions::default())
+            .await
            .unwrap();
 
         let alix_can_see_bo_message = alix_messages
@@ -3101,6 +3113,7 @@ mod tests {
 
         let bo_messages = bo_group
            .find_messages(FfiListMessagesOptions::default())
+            .await
            .unwrap();
         assert_eq!(bo_messages.len(), 0);
 
@@ -3116,8 +3129,12 @@ mod tests {
 
         let bo_messages = bo_group
            .find_messages(FfiListMessagesOptions::default())
+            .await
            .unwrap();
-        assert!(bo_messages.first().unwrap().kind == FfiConversationMessageKind::MembershipChange);
+        assert_eq!(
+            bo_messages.first().unwrap().kind,
+            FfiConversationMessageKind::MembershipChange
+        );
         assert_eq!(bo_messages.len(), 1);
 
         let bo_members = bo_group.list_members().await.unwrap();
@@ -3175,6 +3192,7 @@ mod tests {
 
         let bo_messages1 = bo_group
            .find_messages(FfiListMessagesOptions::default())
+            .await
            .unwrap();
         assert_eq!(bo_messages1.len(), first_msg_check);
 
@@ -3187,6 +3205,7 @@ mod tests {
 
         let alix_messages = alix_group
            .find_messages(FfiListMessagesOptions::default())
+            .await
            .unwrap();
         assert_eq!(alix_messages.len(), second_msg_check);
 
@@ -3196,6 +3215,7 @@ mod tests {
 
         let bo_messages2 = bo_group
            .find_messages(FfiListMessagesOptions::default())
+            .await
            .unwrap();
         assert_eq!(bo_messages2.len(), second_msg_check);
         assert_eq!(message_callbacks.message_count(), second_msg_check as u32);
@@ -4418,15 +4438,19 @@ mod tests {
         // Get messages for both participants in both conversations
         let alix_dm_messages = alix_dm
            .find_messages(FfiListMessagesOptions::default())
+            .await
            .unwrap();
         let bo_dm_messages = bo_dm
            .find_messages(FfiListMessagesOptions::default())
+            .await
            .unwrap();
         let alix_group_messages = alix_group
            .find_messages(FfiListMessagesOptions::default())
+            .await
            .unwrap();
         let bo_group_messages = bo_group
            .find_messages(FfiListMessagesOptions::default())
+            .await
            .unwrap();
 
         // Verify DM messages
@@ -4545,6 +4569,7 @@ mod tests {
            .await
            .unwrap()[0]
            .find_messages(FfiListMessagesOptions::default())
+            .await
            .unwrap();
         let bo_dm_messages = client_b
            .conversations()
            .await
            .unwrap()[0]
            .find_messages(FfiListMessagesOptions::default())
+            .await
            .unwrap();
         assert_eq!(alix_dm_messages[0].content, "Hello in DM".as_bytes());
         assert_eq!(bo_dm_messages[0].content, "Hello in DM".as_bytes());

From e02b70ba2b5e263a8b3efb226187e744a213627b Mon Sep 17 00:00:00 2001
From: Mojtaba Chenani
Date: Tue, 17 Dec 2024 17:22:45 +0100
Subject: [PATCH 25/28] pull changes from
 https://github.com/xmtp/libxmtp/tree/insipx/troubleshoot-test

---
 bindings_ffi/src/mls.rs | 14 ++++++++++++++
 1 file changed, 14 insertions(+)

diff --git a/bindings_ffi/src/mls.rs b/bindings_ffi/src/mls.rs
index c3a33e36a..5893c24db 100644
--- a/bindings_ffi/src/mls.rs
+++ b/bindings_ffi/src/mls.rs
@@ -2091,6 +2091,9 @@ mod tests {
            .await
            .unwrap();
 
+        let conn = client.inner_client.context().store().conn().unwrap();
+        conn.register_triggers();
+
         register_client(&ffi_inbox_owner, &client).await;
         client
     }
@@ -2572,6 +2575,8 @@ mod tests {
     async fn test_can_stream_group_messages_for_updates() {
         let alix = new_test_client().await;
         let bo = new_test_client().await;
+        let alix_provider = alix.inner_client.mls_provider().unwrap();
+        let bo_provider = bo.inner_client.mls_provider().unwrap();
 
         // Stream all group messages
         let message_callbacks = Arc::new(RustStreamCallback::default());
@@ -2604,14 +2609,21 @@ mod tests {
            .unwrap();
         let bo_group = &bo_groups[0];
         bo_group.sync().await.unwrap();
+
+        // alix published + processed group creation and name update
+        assert_eq!(alix_provider.conn_ref().intents_published(), 2);
+        assert_eq!(alix_provider.conn_ref().intents_deleted(), 2);
+
         bo_group
            .update_group_name("Old Name2".to_string())
            .await
            .unwrap();
         message_callbacks.wait_for_delivery(None).await.unwrap();
+        assert_eq!(bo_provider.conn_ref().intents_published(), 1);
 
         alix_group.send(b"Hello there".to_vec()).await.unwrap();
         message_callbacks.wait_for_delivery(None).await.unwrap();
+        assert_eq!(alix_provider.conn_ref().intents_published(), 3);
 
         let dm = bo
            .conversations()
            .await
            .unwrap();
         dm.send(b"Hello again".to_vec()).await.unwrap();
+        assert_eq!(bo_provider.conn_ref().intents_published(), 3);
         message_callbacks.wait_for_delivery(None).await.unwrap();
 
         // Uncomment the following lines to add more group name updates
            .await
            .unwrap();
         message_callbacks.wait_for_delivery(None).await.unwrap();
+        assert_eq!(bo_provider.conn_ref().intents_published(), 4);
 
         assert_eq!(message_callbacks.message_count(), 6);

From 9144b6588d688ddd2868307f8df1c70b17b2d6ce Mon Sep 17 00:00:00 2001
From: Mojtaba Chenani
Date: Tue, 17 Dec 2024 17:55:30 +0100
Subject: [PATCH 26/28] fix tests

---
 bindings_ffi/src/mls.rs | 1 +
 1 file changed, 1 insertion(+)

diff --git a/bindings_ffi/src/mls.rs b/bindings_ffi/src/mls.rs
index 5893c24db..cc9624dde 100644
--- a/bindings_ffi/src/mls.rs
+++ b/bindings_ffi/src/mls.rs
@@ -2640,6 +2640,7 @@ mod tests {
            .await
            .unwrap();
         message_callbacks.wait_for_delivery(None).await.unwrap();
+        message_callbacks.wait_for_delivery(None).await.unwrap();
         assert_eq!(bo_provider.conn_ref().intents_published(), 4);
 
         assert_eq!(message_callbacks.message_count(), 6);

From 7e7410f565eb42892cad2f4a91b61a0cf0eb801b Mon Sep 17 00:00:00 2001
From: Ry Racherbaumer
Date: Tue, 17 Dec 2024 12:30:19 -0600
Subject: [PATCH 27/28] Fix node bindings tests

---
 bindings_node/test/Conversations.test.ts | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/bindings_node/test/Conversations.test.ts b/bindings_node/test/Conversations.test.ts
index ee40b431d..c6c6e2a1b 100644
--- a/bindings_node/test/Conversations.test.ts
+++ b/bindings_node/test/Conversations.test.ts
@@ -54,14 +54,14 @@ describe('Conversations', () => {
       updateGroupPinnedFrameUrlPolicy: 0,
     })
     expect(group.addedByInboxId()).toBe(client1.inboxId())
-    expect(group.findMessages().length).toBe(1)
+    expect((await group.findMessages()).length).toBe(1)
     const members = await group.listMembers()
     expect(members.length).toBe(2)
     const memberInboxIds = members.map((member) => member.inboxId)
     expect(memberInboxIds).toContain(client1.inboxId())
     expect(memberInboxIds).toContain(client2.inboxId())
-    expect(group.groupMetadata().conversationType()).toBe('group')
-    expect(group.groupMetadata().creatorInboxId()).toBe(client1.inboxId())
+    expect((await group.groupMetadata()).conversationType()).toBe('group')
+    expect((await group.groupMetadata()).creatorInboxId()).toBe(client1.inboxId())
 
     expect(group.consentState()).toBe(ConsentState.Allowed)
 
@@ -198,14 +198,14 @@ describe('Conversations', () => {
       updateGroupPinnedFrameUrlPolicy: 0,
     })
     expect(group.addedByInboxId()).toBe(client1.inboxId())
-    expect(group.findMessages().length).toBe(0)
+    expect((await group.findMessages()).length).toBe(0)
     const members = await group.listMembers()
     expect(members.length).toBe(2)
     const memberInboxIds = members.map((member) => member.inboxId)
     expect(memberInboxIds).toContain(client1.inboxId())
     expect(memberInboxIds).toContain(client2.inboxId())
-    expect(group.groupMetadata().conversationType()).toBe('dm')
-    expect(group.groupMetadata().creatorInboxId()).toBe(client1.inboxId())
+    expect((await group.groupMetadata()).conversationType()).toBe('dm')
+    expect((await group.groupMetadata()).creatorInboxId()).toBe(client1.inboxId())
 
     expect(group.consentState()).toBe(ConsentState.Allowed)

From 37ed3eccf59ea88aa2f4e1942f49146c3d10560e Mon Sep 17 00:00:00 2001
From: Mojtaba Chenani
Date: Tue, 17 Dec 2024 19:37:53 +0100
Subject: [PATCH 28/28] fix fmt

---
 bindings_node/test/Conversations.test.ts | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/bindings_node/test/Conversations.test.ts b/bindings_node/test/Conversations.test.ts
index c6c6e2a1b..c6123bc6d 100644
--- a/bindings_node/test/Conversations.test.ts
+++ b/bindings_node/test/Conversations.test.ts
@@ -61,7 +61,9 @@ describe('Conversations', () => {
     expect(memberInboxIds).toContain(client1.inboxId())
     expect(memberInboxIds).toContain(client2.inboxId())
     expect((await group.groupMetadata()).conversationType()).toBe('group')
-    expect((await group.groupMetadata()).creatorInboxId()).toBe(client1.inboxId())
+    expect((await group.groupMetadata()).creatorInboxId()).toBe(
+      client1.inboxId()
+    )
 
     expect(group.consentState()).toBe(ConsentState.Allowed)
 
@@ -205,7 +207,9 @@ describe('Conversations', () => {
     expect(memberInboxIds).toContain(client1.inboxId())
     expect(memberInboxIds).toContain(client2.inboxId())
     expect((await group.groupMetadata()).conversationType()).toBe('dm')
-    expect((await group.groupMetadata()).creatorInboxId()).toBe(client1.inboxId())
+    expect((await group.groupMetadata()).creatorInboxId()).toBe(
+      client1.inboxId()
+    )
 
     expect(group.consentState()).toBe(ConsentState.Allowed)