diff --git a/bindings_ffi/src/mls.rs b/bindings_ffi/src/mls.rs index 9fd6399ab..527efc096 100644 --- a/bindings_ffi/src/mls.rs +++ b/bindings_ffi/src/mls.rs @@ -216,11 +216,11 @@ impl FfiConversations { if !account_addresses.is_empty() { convo.add_members(account_addresses).await?; } - let out = Arc::new(FfiGroup { inner_client: self.inner_client.clone(), group_id: convo.group_id, created_at_ns: convo.created_at_ns, + added_by_address: convo.added_by_address, }); Ok(out) @@ -232,13 +232,12 @@ impl FfiConversations { ) -> Result, GenericError> { let inner = self.inner_client.as_ref(); let group = inner.process_streamed_welcome_message(envelope_bytes)?; - let out = Arc::new(FfiGroup { inner_client: self.inner_client.clone(), group_id: group.group_id, created_at_ns: group.created_at_ns, + added_by_address: group.added_by_address, }); - Ok(out) } @@ -266,6 +265,7 @@ impl FfiConversations { inner_client: self.inner_client.clone(), group_id: group.group_id, created_at_ns: group.created_at_ns, + added_by_address: group.added_by_address, }) }) .collect(); @@ -285,6 +285,7 @@ impl FfiConversations { inner_client: client.clone(), group_id: convo.group_id, created_at_ns: convo.created_at_ns, + added_by_address: convo.added_by_address, })) }, || {}, // on_close_callback @@ -318,6 +319,7 @@ pub struct FfiGroup { inner_client: Arc, group_id: Vec, created_at_ns: i64, + added_by_address: Option, } #[derive(uniffi::Record)] @@ -340,6 +342,7 @@ impl FfiGroup { self.inner_client.as_ref(), self.group_id.clone(), self.created_at_ns, + self.added_by_address.clone(), ); group.send_message(content_bytes.as_slice()).await?; @@ -352,6 +355,7 @@ impl FfiGroup { self.inner_client.as_ref(), self.group_id.clone(), self.created_at_ns, + self.added_by_address.clone(), ); group.sync().await?; @@ -367,6 +371,7 @@ impl FfiGroup { self.inner_client.as_ref(), self.group_id.clone(), self.created_at_ns, + self.added_by_address.clone(), ); let messages: Vec = group @@ -392,6 +397,7 @@ impl FfiGroup { self.inner_client.as_ref(), self.group_id.clone(), self.created_at_ns, + self.added_by_address.clone(), ); let message = group.process_streamed_group_message(envelope_bytes).await?; let ffi_message = message.into(); @@ -404,6 +410,7 @@ impl FfiGroup { self.inner_client.as_ref(), self.group_id.clone(), self.created_at_ns, + self.added_by_address.clone(), ); let members: Vec = group @@ -425,6 +432,7 @@ impl FfiGroup { self.inner_client.as_ref(), self.group_id.clone(), self.created_at_ns, + self.added_by_address.clone(), ); group.add_members(account_addresses).await?; @@ -437,6 +445,7 @@ impl FfiGroup { self.inner_client.as_ref(), self.group_id.clone(), self.created_at_ns, + self.added_by_address.clone(), ); group.remove_members(account_addresses).await?; @@ -472,6 +481,7 @@ impl FfiGroup { self.inner_client.as_ref(), self.group_id.clone(), self.created_at_ns, + self.added_by_address.clone(), ); Ok(group.is_active()?) @@ -482,6 +492,7 @@ impl FfiGroup { self.inner_client.as_ref(), self.group_id.clone(), self.created_at_ns, + self.added_by_address.clone(), ); let metadata = group.metadata()?; @@ -1108,4 +1119,45 @@ mod tests { tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; assert!(stream_closer.is_closed()); } + + #[tokio::test(flavor = "multi_thread", worker_threads = 1)] + async fn test_group_who_added_me() { + // Create Clients + let amal = new_test_client().await; + let bola = new_test_client().await; + + // Amal creates a group and adds Bola to the group + amal.conversations() + .create_group(vec![bola.account_address()], None) + .await + .unwrap(); + + // Bola syncs groups - this will decrypt the Welcome, identify who added Bola + // and then store that value on the group and insert into the database + let bola_conversations = bola.conversations(); + let _ = bola_conversations.sync().await; + + // Bola gets the group id. This will be needed to fetch the group from + // the database. + let bola_groups = bola_conversations + .list(crate::FfiListConversationsOptions { + created_after_ns: None, + created_before_ns: None, + limit: None, + }) + .await + .unwrap(); + + let bola_group = bola_groups.first().unwrap(); + + // Check Bola's group for the added_by_address of the inviter + let added_by_address = bola_group.added_by_address.clone().unwrap(); + + // // Verify the welcome host_credential is equal to Amal's + assert_eq!( + amal.account_address(), + added_by_address, + "The Inviter and added_by_address do not match!" + ); + } } diff --git a/xmtp_mls/migrations/2024-04-08-180113_group_added_by_address/down.sql b/xmtp_mls/migrations/2024-04-08-180113_group_added_by_address/down.sql new file mode 100644 index 000000000..79eca70df --- /dev/null +++ b/xmtp_mls/migrations/2024-04-08-180113_group_added_by_address/down.sql @@ -0,0 +1,9 @@ +-- As SQLite does not support ALTER, we play this game of move, repopulate, drop. Here we recreate without the 'added_by_address' column. +BEGIN TRANSACTION; +CREATE TEMPORARY TABLE backup_group(id BLOB PRIMARY KEY NOT NULL, created_at_ns BIGINT NOT NULL, membership_state INT NOT NULL, installations_last_checked BIGINT NOT NULL, purpose INT NOT NULL DEFAULT 1); +INSERT INTO backup_group SELECT id, created_at_ns, membership_state, installations_last_checked, pupose FROM groups; +DROP TABLE groups; +CREATE TABLE groups(id BLOB PRIMARY KEY NOT NULL, created_at_ns BIGINT NOT NULL, membership_state INT NOT NULL, installations_last_checked BIGINT NOT NULL, purpose INT NOT NULL DEFAULT 1); +INSERT INTO groups SELECT id, created_at_ns, membership_state, installations_last_checked, purpose FROM backup_group; +DROP TABLE backup_group; +COMMIT; diff --git a/xmtp_mls/migrations/2024-04-08-180113_group_added_by_address/up.sql b/xmtp_mls/migrations/2024-04-08-180113_group_added_by_address/up.sql new file mode 100644 index 000000000..9d32617d6 --- /dev/null +++ b/xmtp_mls/migrations/2024-04-08-180113_group_added_by_address/up.sql @@ -0,0 +1,2 @@ +ALTER TABLE groups +ADD COLUMN added_by_address TEXT diff --git a/xmtp_mls/src/client.rs b/xmtp_mls/src/client.rs index 6972476e0..91f367374 100644 --- a/xmtp_mls/src/client.rs +++ b/xmtp_mls/src/client.rs @@ -206,8 +206,13 @@ where ) -> Result, ClientError> { log::info!("creating group"); - let group = MlsGroup::create_and_insert(self, GroupMembershipState::Allowed, permissions) - .map_err(|e| ClientError::Generic(format!("group create error {}", e)))?; + let group = MlsGroup::create_and_insert( + self, + GroupMembershipState::Allowed, + permissions, + Some(self.account_address()), + ) + .map_err(|e| ClientError::Generic(format!("group create error {}", e)))?; Ok(group) } @@ -218,7 +223,12 @@ where let conn = &mut self.store.conn()?; let stored_group: Option = conn.fetch(&group_id)?; match stored_group { - Some(group) => Ok(MlsGroup::new(self, group.id, group.created_at_ns)), + Some(group) => Ok(MlsGroup::new( + self, + group.id, + group.created_at_ns, + group.added_by_address, + )), None => Err(ClientError::Generic("group not found".to_string())), } } @@ -242,7 +252,14 @@ where .conn()? .find_groups(allowed_states, created_after_ns, created_before_ns, limit)? .into_iter() - .map(|stored_group| MlsGroup::new(self, stored_group.id, stored_group.created_at_ns)) + .map(|stored_group| { + MlsGroup::new( + self, + stored_group.id, + stored_group.created_at_ns, + stored_group.added_by_address, + ) + }) .collect()) } diff --git a/xmtp_mls/src/groups/mod.rs b/xmtp_mls/src/groups/mod.rs index adcdb5e90..04eaca535 100644 --- a/xmtp_mls/src/groups/mod.rs +++ b/xmtp_mls/src/groups/mod.rs @@ -9,11 +9,13 @@ pub mod validated_commit; use intents::SendMessageIntentData; use openmls::{ + credentials::BasicCredential, + error::LibraryError, extensions::{Extension, Extensions, Metadata}, group::{MlsGroupCreateConfig, MlsGroupJoinConfig}, prelude::{ - CredentialWithKey, CryptoConfig, Error as TlsCodecError, GroupId, MlsGroup as OpenMlsGroup, - StagedWelcome, Welcome as MlsWelcome, WireFormatPolicy, + BasicCredentialError, CredentialWithKey, CryptoConfig, Error as TlsCodecError, GroupId, + MlsGroup as OpenMlsGroup, StagedWelcome, Welcome as MlsWelcome, WireFormatPolicy, }, }; use openmls_traits::OpenMlsProvider; @@ -119,6 +121,10 @@ pub enum GroupError { Identity(#[from] IdentityError), #[error("serialization error: {0}")] EncodeError(#[from] prost::EncodeError), + #[error("Credential error")] + CredentialError(#[from] BasicCredentialError), + #[error("LeafNode error")] + LeafNodeError(#[from] LibraryError), } impl RetryableError for GroupError { @@ -140,6 +146,7 @@ impl RetryableError for GroupError { pub struct MlsGroup<'c, ApiClient> { pub group_id: Vec, pub created_at_ns: i64, + pub added_by_address: Option, client: &'c Client, } @@ -149,6 +156,7 @@ impl<'c, ApiClient> Clone for MlsGroup<'c, ApiClient> { client: self.client, group_id: self.group_id.clone(), created_at_ns: self.created_at_ns, + added_by_address: self.added_by_address.clone(), } } } @@ -158,11 +166,17 @@ where ApiClient: XmtpMlsClient, { // Creates a new group instance. Does not validate that the group exists in the DB - pub fn new(client: &'c Client, group_id: Vec, created_at_ns: i64) -> Self { + pub fn new( + client: &'c Client, + group_id: Vec, + created_at_ns: i64, + added_by_address: Option, + ) -> Self { Self { client, group_id, created_at_ns, + added_by_address, } } @@ -180,6 +194,7 @@ where client: &'c Client, membership_state: GroupMembershipState, permissions: Option, + added_by_address: Option, ) -> Result { let conn = client.store.conn()?; let provider = XmtpOpenMlsProvider::new(&conn); @@ -201,9 +216,20 @@ where mls_group.save(provider.key_store())?; let group_id = mls_group.group_id().to_vec(); - let stored_group = StoredGroup::new(group_id.clone(), now_ns(), membership_state); + let stored_group = StoredGroup::new( + group_id.clone(), + now_ns(), + membership_state, + added_by_address.clone(), + ); + stored_group.store(provider.conn())?; - Ok(Self::new(client, group_id, stored_group.created_at_ns)) + Ok(Self::new( + client, + group_id, + stored_group.created_at_ns, + added_by_address, + )) } // Create a group from a decrypted and decoded welcome message @@ -212,6 +238,7 @@ where client: &'c Client, provider: &XmtpOpenMlsProvider, welcome: MlsWelcome, + added_by_address: Option, ) -> Result { let mls_welcome = StagedWelcome::new_from_welcome(provider, &build_group_join_config(), welcome, None)?; @@ -220,13 +247,19 @@ where mls_group.save(provider.key_store())?; let group_id = mls_group.group_id().to_vec(); - let to_store = StoredGroup::new(group_id, now_ns(), GroupMembershipState::Pending); + let to_store = StoredGroup::new( + group_id, + now_ns(), + GroupMembershipState::Pending, + added_by_address.clone(), + ); let stored_group = provider.conn().insert_or_ignore_group(to_store)?; Ok(Self::new( client, stored_group.id, stored_group.created_at_ns, + added_by_address, )) } @@ -241,7 +274,18 @@ where let welcome = deserialize_welcome(&welcome_bytes)?; - Self::create_from_welcome(client, provider, welcome) + let join_config = build_group_join_config(); + let staged_welcome = + StagedWelcome::new_from_welcome(provider, &join_config, welcome.clone(), None)?; + + let added_by_node = staged_welcome.welcome_sender()?; + + let added_by_credential = BasicCredential::try_from(added_by_node.credential())?; + let pub_key_bytes = added_by_node.signature_key().as_slice(); + let account_address = + Identity::get_validated_account_address(added_by_credential.identity(), pub_key_bytes)?; + + Self::create_from_welcome(client, provider, welcome, Some(account_address)) } fn into_envelope(encoded_msg: &[u8], idempotency_key: &str) -> PlaintextEnvelope { @@ -960,4 +1004,42 @@ mod tests { .await .is_err(),); } + + #[tokio::test] + async fn test_staged_welcome() { + // Create Clients + let amal = ClientBuilder::new_test_client(&generate_local_wallet()).await; + let bola = ClientBuilder::new_test_client(&generate_local_wallet()).await; + + // Amal creates a group + let amal_group = amal.create_group(None).unwrap(); + + // Amal adds Bola to the group + amal_group + .add_members_by_installation_id(vec![bola.installation_public_key()]) + .await + .unwrap(); + + // Bola syncs groups - this will decrypt the Welcome, identify who added Bola + // and then store that value on the group and insert into the database + let bola_groups = bola.sync_welcomes().await.unwrap(); + + // Bola gets the group id. This will be needed to fetch the group from + // the database. + let bola_group = bola_groups.first().unwrap(); + let bola_group_id = bola_group.group_id.clone(); + + // Bola fetches group from the database + let bola_fetched_group = bola.group(bola_group_id).unwrap(); + + // Check Bola's group for the added_by_address of the inviter + let added_by_address = bola_fetched_group.added_by_address.clone().unwrap(); + + // Verify the welcome host_credential is equal to Amal's + assert_eq!( + amal.account_address(), + added_by_address, + "The Inviter and added_by_address do not match!" + ); + } } diff --git a/xmtp_mls/src/storage/encrypted_store/group.rs b/xmtp_mls/src/storage/encrypted_store/group.rs index fae5f7a03..5a0a318dc 100644 --- a/xmtp_mls/src/storage/encrypted_store/group.rs +++ b/xmtp_mls/src/storage/encrypted_store/group.rs @@ -34,6 +34,8 @@ pub struct StoredGroup { pub installations_last_checked: i64, /// Enum, [`Purpose`] signifies the group purpose which extends to who can access it. pub purpose: Purpose, + /// String representing the wallet address of the who added the user to a group. + pub added_by_address: Option, } impl_fetch!(StoredGroup, groups, Vec); @@ -41,13 +43,19 @@ impl_store!(StoredGroup, groups); impl StoredGroup { /// Create a new [`Purpose::Conversation`] group. This is the default type of group. - pub fn new(id: ID, created_at_ns: i64, membership_state: GroupMembershipState) -> Self { + pub fn new( + id: ID, + created_at_ns: i64, + membership_state: GroupMembershipState, + added_by_address: Option, + ) -> Self { Self { id, created_at_ns, membership_state, installations_last_checked: 0, purpose: Purpose::Conversation, + added_by_address, } } @@ -63,6 +71,7 @@ impl StoredGroup { membership_state, installations_last_checked: 0, purpose: Purpose::Sync, + added_by_address: None, } } } @@ -249,7 +258,7 @@ pub(crate) mod tests { let id = rand_vec(); let created_at_ns = now_ns(); let membership_state = state.unwrap_or(GroupMembershipState::Allowed); - StoredGroup::new(id, created_at_ns, membership_state) + StoredGroup::new(id, created_at_ns, membership_state, None) } #[test] diff --git a/xmtp_mls/src/storage/encrypted_store/group_intent.rs b/xmtp_mls/src/storage/encrypted_store/group_intent.rs index a7bf59eed..ac0bca57d 100644 --- a/xmtp_mls/src/storage/encrypted_store/group_intent.rs +++ b/xmtp_mls/src/storage/encrypted_store/group_intent.rs @@ -306,7 +306,7 @@ mod tests { }; fn insert_group(conn: &DbConnection, group_id: Vec) { - let group = StoredGroup::new(group_id, 100, GroupMembershipState::Allowed); + let group = StoredGroup::new(group_id, 100, GroupMembershipState::Allowed, None); group.store(conn).unwrap(); } diff --git a/xmtp_mls/src/storage/encrypted_store/schema.rs b/xmtp_mls/src/storage/encrypted_store/schema.rs index fddd657c9..e259ab2a4 100644 --- a/xmtp_mls/src/storage/encrypted_store/schema.rs +++ b/xmtp_mls/src/storage/encrypted_store/schema.rs @@ -33,6 +33,7 @@ diesel::table! { membership_state -> Integer, installations_last_checked -> BigInt, purpose -> Integer, + added_by_address -> Nullable, } } diff --git a/xmtp_mls/src/subscriptions.rs b/xmtp_mls/src/subscriptions.rs index 885927bf4..6b178e26b 100644 --- a/xmtp_mls/src/subscriptions.rs +++ b/xmtp_mls/src/subscriptions.rs @@ -134,7 +134,7 @@ where ), )?; // TODO update cursor - MlsGroup::new(self, group_id, stream_info.convo_created_at_ns) + MlsGroup::new(self, group_id, stream_info.convo_created_at_ns, None) .process_stream_entry(envelope) .await }