Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
mchenani committed Dec 9, 2024
1 parent 653f6c9 commit e5b2bf1
Show file tree
Hide file tree
Showing 7 changed files with 67 additions and 58 deletions.
2 changes: 1 addition & 1 deletion bindings_node/src/conversation.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use std::{ops::Deref, sync::Arc};
use futures::TryFutureExt;

use napi::{
bindgen_prelude::{Result, Uint8Array},
threadsafe_function::{ErrorStrategy, ThreadsafeFunction, ThreadsafeFunctionCallMode},
Expand Down
4 changes: 2 additions & 2 deletions xmtp_mls/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -845,8 +845,8 @@ where
async move {
tracing::info!(
inbox_id = self.inbox_id(),
"current epoch for [{}] in sync_all_groups()",
self.inbox_id(),
"[{}] syncing group",
self.inbox_id()
);
tracing::info!(
inbox_id = self.inbox_id(),
Expand Down
6 changes: 2 additions & 4 deletions xmtp_mls/src/groups/intents.rs
Original file line number Diff line number Diff line change
Expand Up @@ -725,7 +725,6 @@ impl TryFrom<Vec<u8>> for PostCommitAction {
pub(crate) mod tests {
#[cfg(target_arch = "wasm32")]
wasm_bindgen_test::wasm_bindgen_test_configure!(run_in_dedicated_worker);

use openmls::prelude::{MlsMessageBodyIn, MlsMessageIn, ProcessedMessageContent};
use tls_codec::Deserialize;
use xmtp_cryptography::utils::generate_local_wallet;
Expand Down Expand Up @@ -866,9 +865,8 @@ pub(crate) mod tests {
let provider = group.client.mls_provider().unwrap();
let decrypted_message = match group
.load_mls_group_with_lock(&provider, |mut mls_group| {
mls_group
.process_message(&provider, mls_message)
.map_err(|e| GroupError::Generic(e.to_string()))
Ok(mls_group
.process_message(&provider, mls_message).unwrap())
}) {
Ok(message) => message,
Err(err) => panic!("Error: {:?}", err),
Expand Down
1 change: 0 additions & 1 deletion xmtp_mls/src/groups/members.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ where
provider: &XmtpOpenMlsProvider,
) -> Result<Vec<GroupMember>, GroupError> {
let group_membership = self.load_mls_group_with_lock(provider, |mls_group| {
// Extract group membership from extensions
Ok(extract_group_membership(mls_group.extensions())?)
})?;
let requests = group_membership
Expand Down
29 changes: 22 additions & 7 deletions xmtp_mls/src/groups/mls_sync.rs
Original file line number Diff line number Diff line change
Expand Up @@ -360,9 +360,8 @@ where
if intent.state == IntentState::Committed {
return Ok(IntentState::Committed);
}
let group_epoch = mls_group.epoch();

let message_epoch = message.epoch();
let group_epoch = mls_group.epoch();
debug!(
inbox_id = self.client.inbox_id(),
installation_id = hex::encode(self.client.installation_id()),
Expand Down Expand Up @@ -702,6 +701,14 @@ where
let intent = provider
.conn_ref()
.find_group_intent_by_payload_hash(sha256(envelope.data.as_slice()));
tracing::info!(
inbox_id = self.client.inbox_id(),
group_id = hex::encode(&self.group_id),
msg_id = envelope.id,
"Processing envelope with hash {:?}",
hex::encode(sha256(envelope.data.as_slice()))
);

match intent {
// Intent with the payload hash matches
Ok(Some(intent)) => {
Expand Down Expand Up @@ -729,6 +736,14 @@ where
}
// No matching intent found
Ok(None) => {
tracing::info!(
inbox_id = self.client.inbox_id(),
group_id = hex::encode(&self.group_id),
msg_id = envelope.id,
"client [{}] is about to process external envelope [{}]",
self.client.inbox_id(),
envelope.id
);
self.process_external_message(provider, message, envelope)
.await
}
Expand Down Expand Up @@ -792,7 +807,10 @@ where
for message in messages.into_iter() {
let result = retry_async!(
Retry::default(),
(async { self.consume_message(&message, provider.conn_ref()).await })
(async {
self.consume_message(&message, provider.conn_ref())
.await
})
);
if let Err(e) = result {
let is_retryable = e.is_retryable();
Expand Down Expand Up @@ -1139,10 +1157,7 @@ where
return Ok(());
}
// determine how long of an interval in time to use before updating list
let interval_ns = match update_interval_ns {
Some(val) => val,
None => SYNC_UPDATE_INSTALLATIONS_INTERVAL_NS,
};
let interval_ns = update_interval_ns.unwrap_or_else(|| SYNC_UPDATE_INSTALLATIONS_INTERVAL_NS);

let now_ns = crate::utils::time::now_ns();
let last_ns = provider
Expand Down
73 changes: 31 additions & 42 deletions xmtp_mls/src/groups/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -332,19 +332,6 @@ impl<ScopedClient: ScopedGroupClient> MlsGroup<ScopedClient> {
}

// Load the stored OpenMLS group from the OpenMLS provider's keystore
#[tracing::instrument(level = "trace", skip_all)]
pub(crate) fn load_mls_group(
&self,
provider: impl OpenMlsProvider,
) -> Result<OpenMlsGroup, GroupError> {
let mls_group =
OpenMlsGroup::load(provider.storage(), &GroupId::from_slice(&self.group_id))
.map_err(|_| GroupError::GroupNotFound)?
.ok_or(GroupError::GroupNotFound)?;

Ok(mls_group)
}

#[tracing::instrument(level = "trace", skip_all)]
pub(crate) fn load_mls_group_with_lock<F, R>(
&self,
Expand All @@ -358,7 +345,6 @@ impl<ScopedClient: ScopedGroupClient> MlsGroup<ScopedClient> {
let group_id = self.group_id.clone();

// Acquire the lock synchronously using blocking_lock

let _lock = MLS_COMMIT_LOCK.get_lock_sync(group_id.clone())?;
// Load the MLS group
let mls_group =
Expand All @@ -370,6 +356,7 @@ impl<ScopedClient: ScopedGroupClient> MlsGroup<ScopedClient> {
operation(mls_group)
}

// Load the stored OpenMLS group from the OpenMLS provider's keystore
#[tracing::instrument(level = "trace", skip_all)]
pub(crate) async fn load_mls_group_with_lock_async<F, E, R, Fut>(
&self,
Expand Down Expand Up @@ -1173,33 +1160,35 @@ impl<ScopedClient: ScopedGroupClient> MlsGroup<ScopedClient> {
///
/// If the current user has been kicked out of the group, `is_active` will return `false`
pub fn is_active(&self, provider: impl OpenMlsProvider) -> Result<bool, GroupError> {
self.load_mls_group_with_lock(provider, |mls_group| Ok(mls_group.is_active()))
self.load_mls_group_with_lock(provider, |mls_group|
Ok(mls_group.is_active())
)
}

/// Get the `GroupMetadata` of the group.
pub fn metadata(&self, provider: impl OpenMlsProvider) -> Result<GroupMetadata, GroupError> {
self.load_mls_group_with_lock(provider, |mls_group| {
self.load_mls_group_with_lock(provider, |mls_group|
Ok(extract_group_metadata(&mls_group)?)
})
)
}

/// Get the `GroupMutableMetadata` of the group.
pub fn mutable_metadata(
&self,
provider: impl OpenMlsProvider,
) -> Result<GroupMutableMetadata, GroupError> {
self.load_mls_group_with_lock(provider, |mls_group| {
self.load_mls_group_with_lock(provider, |mls_group|
Ok(GroupMutableMetadata::try_from(&mls_group)?)
})
)
}

pub fn permissions(&self) -> Result<GroupMutablePermissions, GroupError> {
let conn = self.context().store().conn()?;
let provider = XmtpOpenMlsProvider::new(conn);

self.load_mls_group_with_lock(&provider, |mls_group| {
self.load_mls_group_with_lock(&provider, |mls_group|
Ok(extract_group_permissions(&mls_group)?)
})
)
}

/// Used for testing that dm group validation works as expected.
Expand Down Expand Up @@ -1920,17 +1909,17 @@ pub(crate) mod tests {

// Check Amal's MLS group state.
let amal_db = XmtpOpenMlsProvider::from(amal.context.store().conn().unwrap());
let amal_members_len = amal_group.load_mls_group_with_lock(&amal_db, |amal_mls_group| {
Ok(amal_mls_group.members().count())
}).unwrap();
let amal_members_len = amal_group.load_mls_group_with_lock(&amal_db, |mls_group|
Ok(mls_group.members().count())
).unwrap();

assert_eq!(amal_members_len, 3);

// Check Bola's MLS group state.
let bola_db = XmtpOpenMlsProvider::from(bola.context.store().conn().unwrap());
let bola_members_len = bola_group.load_mls_group_with_lock(&bola_db, |bola_mls_group| {
Ok(bola_mls_group.members().count())
}).unwrap();
let bola_members_len = bola_group.load_mls_group_with_lock(&bola_db, |mls_group|
Ok(mls_group.members().count())
).unwrap();

assert_eq!(bola_members_len, 3);

Expand Down Expand Up @@ -2009,8 +1998,8 @@ pub(crate) mod tests {
Ok(mls_group) // Return the updated group if necessary
}).unwrap();

force_add_member(&alix, &bo, &alix_group, &mut mls_group, &provider).await;
// Now add bo to the group
force_add_member(&alix, &bo, &alix_group, &mut mls_group, &provider).await;

// Bo should not be able to actually read this group
bo.sync_welcomes(&bo.store().conn().unwrap()).await.unwrap();
Expand Down Expand Up @@ -2134,9 +2123,9 @@ pub(crate) mod tests {
assert_eq!(messages.len(), 2);

let provider: XmtpOpenMlsProvider = client.context.store().conn().unwrap().into();
let pending_commit_is_none = group.load_mls_group_with_lock(&provider, |mls_group| {
let pending_commit_is_none = group.load_mls_group_with_lock(&provider, |mls_group|
Ok(mls_group.pending_commit().is_none())
}).unwrap();
).unwrap();

assert!(pending_commit_is_none);

Expand Down Expand Up @@ -2317,9 +2306,9 @@ pub(crate) mod tests {
assert!(new_installations_were_added.is_ok());

group.sync().await.unwrap();
let num_members = group.load_mls_group_with_lock(&provider, |mls_group| {
let num_members = group.load_mls_group_with_lock(&provider, |mls_group|
Ok(mls_group.members().collect::<Vec<_>>().len())
}).unwrap();
).unwrap();

assert_eq!(num_members, 3);
}
Expand Down Expand Up @@ -3899,9 +3888,9 @@ pub(crate) mod tests {
)
.unwrap();
assert!(valid_dm_group
.load_mls_group_with_lock(client.mls_provider().unwrap(), |mls_group| {
.load_mls_group_with_lock(client.mls_provider().unwrap(), |mls_group|
validate_dm_group(&client, &mls_group, added_by_inbox)
})
)
.is_ok());

// Test case 2: Invalid conversation type
Expand All @@ -3917,9 +3906,9 @@ pub(crate) mod tests {
)
.unwrap();
assert!(matches!(
invalid_type_group.load_mls_group_with_lock(client.mls_provider().unwrap(), |mls_group| {
invalid_type_group.load_mls_group_with_lock(client.mls_provider().unwrap(), |mls_group|
validate_dm_group(&client, &mls_group, added_by_inbox)
}),
),
Err(GroupError::Generic(msg)) if msg.contains("Invalid conversation type")
));
// Test case 3: Missing DmMembers
Expand All @@ -3939,9 +3928,9 @@ pub(crate) mod tests {
)
.unwrap();
assert!(matches!(
mismatched_dm_members_group.load_mls_group_with_lock(client.mls_provider().unwrap(), |mls_group| {
mismatched_dm_members_group.load_mls_group_with_lock(client.mls_provider().unwrap(), |mls_group|
validate_dm_group(&client, &mls_group, added_by_inbox)
}),
),
Err(GroupError::Generic(msg)) if msg.contains("DM members do not match expected inboxes")
));

Expand All @@ -3961,9 +3950,9 @@ pub(crate) mod tests {
)
.unwrap();
assert!(matches!(
non_empty_admin_list_group.load_mls_group_with_lock(client.mls_provider().unwrap(), |mls_group| {
non_empty_admin_list_group.load_mls_group_with_lock(client.mls_provider().unwrap(), |mls_group|
validate_dm_group(&client, &mls_group, added_by_inbox)
}),
),
Err(GroupError::Generic(msg)) if msg.contains("DM group must have empty admin and super admin lists")
));

Expand All @@ -3982,9 +3971,9 @@ pub(crate) mod tests {
)
.unwrap();
assert!(matches!(
invalid_permissions_group.load_mls_group_with_lock(client.mls_provider().unwrap(), |mls_group| {
invalid_permissions_group.load_mls_group_with_lock(client.mls_provider().unwrap(), |mls_group|
validate_dm_group(&client, &mls_group, added_by_inbox)
}),
),
Err(GroupError::Generic(msg)) if msg.contains("Invalid permissions for DM group")
));
}
Expand Down
10 changes: 9 additions & 1 deletion xmtp_mls/src/groups/subscriptions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,16 +41,24 @@ impl<ScopedClient: ScopedGroupClient> MlsGroup<ScopedClient> {
let process_result = retry_async!(
Retry::default(),
(async {
let client_id = &client_id;
let msgv1 = &msgv1;
self.context()
.store()
.transaction_async(|provider| async move {
let prov_ref = &provider; // Borrow provider instead of moving it
self.load_mls_group_with_lock_async(
prov_ref,
|mut mls_group| async move {
|mls_group| async move {
// Attempt processing immediately, but fail if the message is not an Application Message
// Returning an error should roll back the DB tx
tracing::info!(
inbox_id = self.client.inbox_id(),
group_id = hex::encode(&self.group_id),
msg_id = msgv1.id,
"current epoch for [{}] in process_stream_entry()",
client_id,
);
self.process_message(&prov_ref, msgv1, false)
.await
// NOTE: We want to make sure we retry an error in process_message
Expand Down

0 comments on commit e5b2bf1

Please sign in to comment.