From e499fd31c171c192ad1c92edadbcc7109da1e42f Mon Sep 17 00:00:00 2001 From: Andrew Plaza Date: Tue, 29 Oct 2024 15:20:52 -0400 Subject: [PATCH] create GroupQueryArgs for ergonomic paramters to find_groups --- xmtp_mls/src/client.rs | 37 ++---- xmtp_mls/src/groups/mod.rs | 40 +++--- xmtp_mls/src/storage/encrypted_store/group.rs | 123 +++++++++++++----- xmtp_mls/src/subscriptions.rs | 12 +- 4 files changed, 125 insertions(+), 87 deletions(-) diff --git a/xmtp_mls/src/client.rs b/xmtp_mls/src/client.rs index b133063b0..30155ab53 100644 --- a/xmtp_mls/src/client.rs +++ b/xmtp_mls/src/client.rs @@ -38,9 +38,8 @@ use xmtp_proto::xmtp::mls::api::v1::{ use crate::{ api::ApiClientWrapper, groups::{ - group_metadata::ConversationType, group_permissions::PolicySet, - validated_commit::CommitValidationError, GroupError, GroupMetadataOptions, IntentError, - MlsGroup, + group_permissions::PolicySet, validated_commit::CommitValidationError, GroupError, + GroupMetadataOptions, IntentError, MlsGroup, }, identity::{parse_credential, Identity, IdentityError}, identity_updates::{load_identity_updates, IdentityUpdateError}, @@ -48,6 +47,7 @@ use crate::{ mutex_registry::MutexRegistry, retry::Retry, retry_async, retryable, + storage::group::GroupQueryArgs, storage::{ consent_record::{ConsentState, ConsentType, StoredConsentRecord}, db_connection::DbConnection, @@ -217,16 +217,6 @@ impl From<&str> for ClientError { } } -#[derive(Debug, Default)] -pub struct FindGroupParams { - pub allowed_states: Option>, - pub created_after_ns: Option, - pub created_before_ns: Option, - pub limit: Option, - pub conversation_type: Option, - pub consent_state: Option, -} - /// Clients manage access to the network, identity, and data store pub struct Client> { pub(crate) api_client: ApiClientWrapper, @@ -682,18 +672,11 @@ where /// - created_after_ns: only return groups created after the given timestamp (in nanoseconds) /// - created_before_ns: only return groups created before the given timestamp (in nanoseconds) /// - limit: only return the first `limit` groups - pub fn find_groups(&self, params: FindGroupParams) -> Result>, ClientError> { + pub fn find_groups(&self, args: GroupQueryArgs) -> Result>, ClientError> { Ok(self .store() .conn()? - .find_groups( - params.allowed_states, - params.created_after_ns, - params.created_before_ns, - params.limit, - params.conversation_type, - params.consent_state, - )? + .find_groups(args)? .into_iter() .map(|stored_group| { MlsGroup::new(self.clone(), stored_group.id, stored_group.created_at_ns) @@ -981,12 +964,12 @@ pub(crate) mod tests { use crate::{ builder::ClientBuilder, - client::FindGroupParams, groups::GroupMetadataOptions, hpke::{decrypt_welcome, encrypt_welcome}, identity::serialize_key_package_hash_ref, storage::{ consent_record::{ConsentState, ConsentType, StoredConsentRecord}, + group::GroupQueryArgs, group_message::MsgQueryArgs, schema::identity_updates, }, @@ -1091,7 +1074,7 @@ pub(crate) mod tests { .create_group(None, GroupMetadataOptions::default()) .unwrap(); - let groups = client.find_groups(FindGroupParams::default()).unwrap(); + let groups = client.find_groups(GroupQueryArgs::default()).unwrap(); assert_eq!(groups.len(), 2); assert_eq!(groups[0].group_id, group_1.group_id); assert_eq!(groups[1].group_id, group_2.group_id); @@ -1166,7 +1149,7 @@ pub(crate) mod tests { let bob_received_groups = bo.sync_welcomes().await.unwrap(); assert_eq!(bob_received_groups.len(), 2); - let bo_groups = bo.find_groups(FindGroupParams::default()).unwrap(); + let bo_groups = bo.find_groups(GroupQueryArgs::default()).unwrap(); let bo_group1 = bo.group(alix_bo_group1.clone().group_id).unwrap(); let bo_messages1 = bo_group1.find_messages(&MsgQueryArgs::default()).unwrap(); assert_eq!(bo_messages1.len(), 0); @@ -1240,7 +1223,7 @@ pub(crate) mod tests { tracing::info!("Syncing bolas welcomes"); // See if Bola can see that they were added to the group bola.sync_welcomes().await.unwrap(); - let bola_groups = bola.find_groups(FindGroupParams::default()).unwrap(); + let bola_groups = bola.find_groups(GroupQueryArgs::default()).unwrap(); assert_eq!(bola_groups.len(), 1); let bola_group = bola_groups.first().unwrap(); tracing::info!("Syncing bolas messages"); @@ -1374,7 +1357,7 @@ pub(crate) mod tests { bo.sync_welcomes().await.unwrap(); // Bo should have two groups now - let bo_groups = bo.find_groups(FindGroupParams::default()).unwrap(); + let bo_groups = bo.find_groups(GroupQueryArgs::default()).unwrap(); assert_eq!(bo_groups.len(), 2); // Bo's original key should be deleted diff --git a/xmtp_mls/src/groups/mod.rs b/xmtp_mls/src/groups/mod.rs index 3f38628b0..5064fc688 100644 --- a/xmtp_mls/src/groups/mod.rs +++ b/xmtp_mls/src/groups/mod.rs @@ -1529,7 +1529,7 @@ pub(crate) mod tests { use crate::{ assert_err, builder::ClientBuilder, - client::{FindGroupParams, MessageProcessingError}, + client::MessageProcessingError, codecs::{group_updated::GroupUpdatedCodec, ContentCodec}, groups::{ build_dm_protected_metadata_extension, build_mutable_metadata_extension_default, @@ -1543,6 +1543,7 @@ pub(crate) mod tests { }, storage::{ consent_record::ConsentState, + group::GroupQueryArgs, group::Purpose, group_intent::{IntentKind, IntentState}, group_message::{GroupMessageKind, MsgQueryArgs, StoredGroupMessage}, @@ -1556,7 +1557,7 @@ pub(crate) mod tests { async fn receive_group_invite(client: &FullXmtpClient) -> MlsGroup { client.sync_welcomes().await.unwrap(); - let mut groups = client.find_groups(FindGroupParams::default()).unwrap(); + let mut groups = client.find_groups(GroupQueryArgs::default()).unwrap(); groups.remove(0) } @@ -1865,7 +1866,7 @@ pub(crate) mod tests { // Bo should not be able to actually read this group bo.sync_welcomes().await.unwrap(); - let groups = bo.find_groups(FindGroupParams::default()).unwrap(); + let groups = bo.find_groups(GroupQueryArgs::default()).unwrap(); assert_eq!(groups.len(), 0); assert_logged!("failed to create group from welcome", 1); }); @@ -1994,7 +1995,7 @@ pub(crate) mod tests { group.send_message(b"hello").await.expect("send message"); bola_client.sync_welcomes().await.unwrap(); - let bola_groups = bola_client.find_groups(FindGroupParams::default()).unwrap(); + let bola_groups = bola_client.find_groups(GroupQueryArgs::default()).unwrap(); let bola_group = bola_groups.first().unwrap(); bola_group.sync().await.unwrap(); let bola_messages = bola_group.find_messages(&MsgQueryArgs::default()).unwrap(); @@ -2355,7 +2356,7 @@ pub(crate) mod tests { .await .unwrap(); bola.sync_welcomes().await.unwrap(); - let bola_groups = bola.find_groups(FindGroupParams::default()).unwrap(); + let bola_groups = bola.find_groups(GroupQueryArgs::default()).unwrap(); assert_eq!(bola_groups.len(), 1); let bola_group = bola_groups.first().unwrap(); bola_group.sync().await.unwrap(); @@ -2523,7 +2524,7 @@ pub(crate) mod tests { .await .unwrap(); bola.sync_welcomes().await.unwrap(); - let bola_groups = bola.find_groups(FindGroupParams::default()).unwrap(); + let bola_groups = bola.find_groups(GroupQueryArgs::default()).unwrap(); assert_eq!(bola_groups.len(), 1); let bola_group = bola_groups.first().unwrap(); bola_group.sync().await.unwrap(); @@ -2603,7 +2604,7 @@ pub(crate) mod tests { .await .unwrap(); bola.sync_welcomes().await.unwrap(); - let bola_groups = bola.find_groups(FindGroupParams::default()).unwrap(); + let bola_groups = bola.find_groups(GroupQueryArgs::default()).unwrap(); assert_eq!(bola_groups.len(), 1); let bola_group = bola_groups.first().unwrap(); bola_group.sync().await.unwrap(); @@ -2619,7 +2620,7 @@ pub(crate) mod tests { // Verify that bola can not add caro because they are not an admin bola.sync_welcomes().await.unwrap(); - let bola_groups = bola.find_groups(FindGroupParams::default()).unwrap(); + let bola_groups = bola.find_groups(GroupQueryArgs::default()).unwrap(); assert_eq!(bola_groups.len(), 1); let bola_group: &MlsGroup<_> = bola_groups.first().unwrap(); bola_group.sync().await.unwrap(); @@ -2683,7 +2684,7 @@ pub(crate) mod tests { // Verify that bola can not add charlie because they are not an admin bola.sync_welcomes().await.unwrap(); - let bola_groups = bola.find_groups(FindGroupParams::default()).unwrap(); + let bola_groups = bola.find_groups(GroupQueryArgs::default()).unwrap(); assert_eq!(bola_groups.len(), 1); let bola_group: &MlsGroup<_> = bola_groups.first().unwrap(); bola_group.sync().await.unwrap(); @@ -2712,7 +2713,7 @@ pub(crate) mod tests { .await .unwrap(); bola.sync_welcomes().await.unwrap(); - let bola_groups = bola.find_groups(FindGroupParams::default()).unwrap(); + let bola_groups = bola.find_groups(GroupQueryArgs::default()).unwrap(); assert_eq!(bola_groups.len(), 1); let bola_group = bola_groups.first().unwrap(); bola_group.sync().await.unwrap(); @@ -2728,7 +2729,7 @@ pub(crate) mod tests { // Verify that bola can not add caro as an admin because they are not a super admin bola.sync_welcomes().await.unwrap(); - let bola_groups = bola.find_groups(FindGroupParams::default()).unwrap(); + let bola_groups = bola.find_groups(GroupQueryArgs::default()).unwrap(); assert_eq!(bola_groups.len(), 1); let bola_group: &MlsGroup<_> = bola_groups.first().unwrap(); bola_group.sync().await.unwrap(); @@ -2979,7 +2980,7 @@ pub(crate) mod tests { // Step 3: Verify that Bola can update the group name, and amal sees the update bola.sync_welcomes().await.unwrap(); - let bola_groups = bola.find_groups(FindGroupParams::default()).unwrap(); + let bola_groups = bola.find_groups(GroupQueryArgs::default()).unwrap(); let bola_group: &MlsGroup<_> = bola_groups.first().unwrap(); bola_group.sync().await.unwrap(); bola_group @@ -3045,7 +3046,7 @@ pub(crate) mod tests { // Step 3: Bola attemps to add Caro, but fails because group is admin only let caro = ClientBuilder::new_test_client(&generate_local_wallet()).await; bola.sync_welcomes().await.unwrap(); - let bola_groups = bola.find_groups(FindGroupParams::default()).unwrap(); + let bola_groups = bola.find_groups(GroupQueryArgs::default()).unwrap(); let bola_group: &MlsGroup<_> = bola_groups.first().unwrap(); bola_group.sync().await.unwrap(); let result = bola_group @@ -3204,12 +3205,7 @@ pub(crate) mod tests { // Bola can message amal let _ = bola.sync_welcomes().await; - let bola_groups = bola - .find_groups(FindGroupParams { - conversation_type: None, - ..FindGroupParams::default() - }) - .unwrap(); + let bola_groups = bola.find_groups(GroupQueryArgs::default()).unwrap(); let bola_dm: &MlsGroup<_> = bola_groups.first().unwrap(); bola_dm.send_message(b"test one").await.unwrap(); @@ -3556,7 +3552,7 @@ pub(crate) mod tests { .unwrap(); bola.sync_welcomes().await.unwrap(); - let bola_groups = bola.find_groups(FindGroupParams::default()).unwrap(); + let bola_groups = bola.find_groups(GroupQueryArgs::default()).unwrap(); let bola_group = bola_groups.first().unwrap(); // group consent state should default to unknown for users who did not create the group assert_eq!(bola_group.consent_state().unwrap(), ConsentState::Unknown); @@ -3575,7 +3571,7 @@ pub(crate) mod tests { .unwrap(); caro.sync_welcomes().await.unwrap(); - let caro_groups = caro.find_groups(FindGroupParams::default()).unwrap(); + let caro_groups = caro.find_groups(GroupQueryArgs::default()).unwrap(); let caro_group = caro_groups.first().unwrap(); caro_group @@ -3605,7 +3601,7 @@ pub(crate) mod tests { .await .unwrap(); bo.sync_welcomes().await.unwrap(); - let bo_groups = bo.find_groups(FindGroupParams::default()).unwrap(); + let bo_groups = bo.find_groups(GroupQueryArgs::default()).unwrap(); let bo_group = bo_groups.first().unwrap(); // Both members see the same amount of messages to start diff --git a/xmtp_mls/src/storage/encrypted_store/group.rs b/xmtp_mls/src/storage/encrypted_store/group.rs index feab3d719..f8102f4cc 100644 --- a/xmtp_mls/src/storage/encrypted_store/group.rs +++ b/xmtp_mls/src/storage/encrypted_store/group.rs @@ -1,6 +1,5 @@ //! The Group database table. Stored information surrounding group membership and ID's. -use diesel::expression::SqlLiteral; use diesel::query_dsl::QueryDsl; use diesel::{ backend::Backend, @@ -19,6 +18,7 @@ use super::{ }; use crate::{ groups::group_metadata::ConversationType, impl_fetch, impl_store, DuplicateItem, StorageError, + storage::schema::groups }; /// The Group ID type. @@ -118,37 +118,101 @@ impl StoredGroup { } } -pub struct FindGroupParams { - allowed_states: Option>, - created_after_ns: Option, - created_before_ns: Option, - limit: Option, - conversation_type: Option, +#[derive(Debug, Default)] +pub struct GroupQueryArgs { + pub allowed_states: Option>, + pub created_after_ns: Option, + pub created_before_ns: Option, + pub limit: Option, + pub conversation_type: Option, + pub consent_state: Option, } -use crate::storage::schema::consent_records; -use crate::storage::schema::groups; +impl AsRef for GroupQueryArgs { + fn as_ref(&self) -> &GroupQueryArgs { + self + } +} + +impl GroupQueryArgs { + pub fn allowed_states(self, allowed_states: Vec) -> Self { + self.maybe_allowed_states(Some(allowed_states)) + } + + pub fn maybe_allowed_states(mut self, allowed_states: Option>) -> Self { + self.allowed_states = allowed_states; + self + } + + pub fn created_after_ns(self, created_after_ns: i64) -> Self { + self.maybe_created_after_ns(Some(created_after_ns)) + } + + pub fn maybe_created_after_ns(mut self, created_after_ns: Option) -> Self { + self.created_after_ns = created_after_ns; + self + } + + pub fn created_before_ns(self, created_before_ns: i64) -> Self { + self.maybe_created_before_ns(Some(created_before_ns)) + } + + pub fn maybe_created_before_ns(mut self, created_before_ns: Option) -> Self { + self.created_before_ns = created_before_ns; + self + } + + pub fn limit(self, limit: i64) -> Self { + self.maybe_limit(Some(limit)) + } + + pub fn maybe_limit(mut self, limit: Option) -> Self { + self.limit = limit; + self + } + + pub fn conversation_type(self, conversation_type: ConversationType) -> Self { + self.maybe_conversation_type(Some(conversation_type)) + } + + pub fn maybe_conversation_type(mut self, conversation_type: Option) -> Self { + self.conversation_type = conversation_type; + self + } + + pub fn consent_state(self, consent_state: ConsentState) -> Self { + self.maybe_consent_state(Some(consent_state)) + } + + pub fn maybe_consent_state(mut self, consent_state: Option) -> Self { + self.consent_state = consent_state; + self + } +} impl DbConnection { /// Return regular [`Purpose::Conversation`] groups with additional optional filters - pub fn find_groups( + pub fn find_groups>( &self, - allowed_states: Option>, - created_after_ns: Option, - created_before_ns: Option, - limit: Option, - conversation_type: Option, - consent_state: Option, + args: A, ) -> Result, StorageError> { use crate::storage::schema::groups::dsl as groups_dsl; use crate::storage::schema::consent_records::dsl as consent_dsl; + let GroupQueryArgs { + allowed_states, + created_after_ns, + created_before_ns, + limit, + conversation_type, + consent_state + } = args.as_ref(); let mut query = groups_dsl::groups .order(groups_dsl::created_at_ns.asc()) .into_boxed(); if let Some(limit) = limit { - query = query.limit(limit); + query = query.limit(*limit); } if let Some(allowed_states) = allowed_states { @@ -552,18 +616,15 @@ pub(crate) mod tests { test_group_3.store(conn).unwrap(); let all_results = conn - .find_groups(None, None, None, None, Some(ConversationType::Group), None) + .find_groups(GroupQueryArgs::default().conversation_type(ConversationType::Group)) .unwrap(); assert_eq!(all_results.len(), 2); let pending_results = conn .find_groups( - Some(vec![GroupMembershipState::Pending]), - None, - None, - None, - Some(ConversationType::Group), - None, + GroupQueryArgs::default() + .allowed_states(vec![GroupMembershipState::Pending]) + .conversation_type(ConversationType::Group) ) .unwrap(); assert_eq!(pending_results[0].id, test_group_1.id); @@ -578,13 +639,11 @@ pub(crate) mod tests { let results_with_created_at_ns_after = conn .find_groups( - None, - Some(test_group_1.created_at_ns), - None, - Some(1), - Some(ConversationType::Group), - None, - ) + GroupQueryArgs::default() + .created_after_ns(test_group_1.created_at_ns) + .conversation_type(ConversationType::Group) + .limit(1) + ) .unwrap(); assert_eq!(results_with_created_at_ns_after.len(), 1); assert_eq!(results_with_created_at_ns_after[0].id, test_group_2.id); @@ -604,7 +663,7 @@ pub(crate) mod tests { // test only dms are returned let dm_results = conn - .find_groups(None, None, None, None, Some(ConversationType::Dm), None) + .find_groups(GroupQueryArgs::default().conversation_type(ConversationType::Dm)) .unwrap(); assert_eq!(dm_results.len(), 1); assert_eq!(dm_results[0].id, test_group_3.id); diff --git a/xmtp_mls/src/subscriptions.rs b/xmtp_mls/src/subscriptions.rs index 6998f11bc..aa4f43a0c 100644 --- a/xmtp_mls/src/subscriptions.rs +++ b/xmtp_mls/src/subscriptions.rs @@ -12,8 +12,9 @@ use crate::{ retry::Retry, retry::RetryableError, retry_async, retryable, - storage::StorageError, - storage::{group::StoredGroup, group_message::StoredGroupMessage}, + storage::{ + group::GroupQueryArgs, group::StoredGroup, group_message::StoredGroupMessage, StorageError, + }, Client, XmtpApi, }; @@ -272,7 +273,7 @@ where let mut group_id_to_info = self .store() .conn()? - .find_groups(None, None, None, None, conversation_type, None)? + .find_groups(GroupQueryArgs::default().maybe_conversation_type(conversation_type))? .into_iter() .map(Into::into) .collect::, MessagesStreamInfo>>(); @@ -384,9 +385,8 @@ pub(crate) mod tests { use crate::{ builder::ClientBuilder, - client::FindGroupParams, groups::{group_metadata::ConversationType, GroupMetadataOptions}, - storage::group_message::StoredGroupMessage, + storage::{group::GroupQueryArgs, group_message::StoredGroupMessage}, utils::test::{Delivery, FullXmtpClient, TestClient}, Client, StreamHandle, }; @@ -760,7 +760,7 @@ pub(crate) mod tests { // Verify syncing welcomes while streaming causes no issues alix.sync_welcomes().await.unwrap(); - let find_groups_results = alix.find_groups(FindGroupParams::default()).unwrap(); + let find_groups_results = alix.find_groups(GroupQueryArgs::default()).unwrap(); { let grps = groups.lock();