Skip to content

Commit

Permalink
Merge branch 'main' of github.com:xmtp/libxmtp into insipx/357/key-pa…
Browse files Browse the repository at this point in the history
…ckage-to-wallet-address
  • Loading branch information
insipx committed Dec 11, 2023
2 parents cea2db4 + 9e4525a commit 8b9a82d
Show file tree
Hide file tree
Showing 5 changed files with 256 additions and 9 deletions.
2 changes: 1 addition & 1 deletion examples/cli/cli-client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ async fn main() {
.members()
.unwrap()
.into_iter()
.map(|m| m.wallet_address)
.map(|m| m.account_address)
.collect::<Vec<String>>()
.join("\n"),
);
Expand Down
34 changes: 33 additions & 1 deletion xmtp_mls/src/codecs/membership_change.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,4 +47,36 @@ impl ContentCodec<GroupMembershipChanges> for GroupMembershipChangeCodec {
}

#[cfg(test)]
mod tests {}

mod tests {
use xmtp_proto::xmtp::mls::message_contents::MembershipChange;

use crate::utils::test::{rand_string, rand_vec};

use super::*;

#[test]
fn test_encode_decode() {
let new_member = MembershipChange {
installation_ids: vec![rand_vec()],
account_address: rand_string(),
initiated_by_account_address: "".to_string(),
};
let data = GroupMembershipChanges {
members_added: vec![new_member.clone()],
members_removed: vec![],
installations_added: vec![],
installations_removed: vec![],
};

let encoded = GroupMembershipChangeCodec::encode(data).unwrap();
assert_eq!(
encoded.clone().r#type.unwrap().type_id,
"group_membership_change"
);
assert!(!encoded.content.is_empty());

let decoded = GroupMembershipChangeCodec::decode(encoded).unwrap();
assert_eq!(decoded.members_added[0], new_member);
}
}
12 changes: 6 additions & 6 deletions xmtp_mls/src/groups/members.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use super::{GroupError, MlsGroup};

#[derive(Debug, Clone)]
pub struct GroupMember {
pub wallet_address: String,
pub account_address: String,
pub installation_ids: Vec<Vec<u8>>,
}

Expand All @@ -33,11 +33,11 @@ where
})
.fold(
HashMap::new(),
|mut acc, (wallet_address, signature_key)| {
acc.entry(wallet_address.clone())
|mut acc, (account_address, signature_key)| {
acc.entry(account_address.clone())
.and_modify(|e| e.installation_ids.push(signature_key.clone()))
.or_insert(GroupMember {
wallet_address,
account_address,
installation_ids: vec![signature_key],
});
acc
Expand Down Expand Up @@ -79,10 +79,10 @@ mod tests {
assert_eq!(members.len(), 2);

for member in members {
if member.wallet_address.eq(&amal.account_address()) {
if member.account_address.eq(&amal.account_address()) {
assert_eq!(member.installation_ids.len(), 1);
}
if member.wallet_address.eq(&bola_a.account_address()) {
if member.account_address.eq(&bola_a.account_address()) {
assert_eq!(member.installation_ids.len(), 2);
}
}
Expand Down
212 changes: 212 additions & 0 deletions xmtp_mls/src/groups/membership_change.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,212 @@
use std::collections::HashMap;

use openmls::{
group::{QueuedAddProposal, QueuedRemoveProposal},
prelude::{MlsGroup as OpenMlsGroup, StagedCommit},
};
use xmtp_proto::{
api_client::{XmtpApiClient, XmtpMlsClient},
xmtp::mls::message_contents::{GroupMembershipChanges, MembershipChange as MemberProto},
};

use crate::identity::Identity;

use super::{GroupError, MlsGroup};

// Take a QueuedAddProposal and extract the wallet address and installation_id
fn extract_identity_from_add(proposal: QueuedAddProposal) -> Option<(String, Vec<u8>)> {
let leaf_node = proposal.add_proposal().key_package().leaf_node();
let signature_key = leaf_node.signature_key().as_slice();
match Identity::get_validated_account_address(leaf_node.credential().identity(), signature_key)
{
Ok(account_address) => Some((account_address, signature_key.to_vec())),
Err(err) => {
log::warn!("error extracting identity {}", err);
None
}
}
}

// Take a QueuedRemoveProposal and extract the wallet address and installation_id
fn extract_identity_from_remove(
proposal: QueuedRemoveProposal,
group: &OpenMlsGroup,
) -> Option<(String, Vec<u8>)> {
let leaf_index = proposal.remove_proposal().removed();
let maybe_member = group.member_at(leaf_index);
if maybe_member.is_none() {
log::warn!("could not find removed member");
return None;
}
let member = maybe_member.expect("already checked");
let signature_key = member.signature_key.as_slice();
match Identity::get_validated_account_address(member.credential.identity(), signature_key) {
Ok(account_address) => Some((account_address, signature_key.to_vec())),
Err(err) => {
log::warn!("error extracting identity {}", err);
None
}
}
}

// Reducer function for merging members into a map, with all installation_ids collected per member
fn merge_members(
mut acc: HashMap<String, MemberProto>,
(account_address, signature_key): (String, Vec<u8>),
) -> HashMap<String, MemberProto> {
acc.entry(account_address.clone())
.and_modify(|entry| entry.installation_ids.push(signature_key.clone()))
.or_insert(MemberProto {
account_address,
installation_ids: vec![signature_key],
initiated_by_account_address: "".to_string(),
});
acc
}

// Get a tuple of (new_members, new_installations), each formatted as a Member object with all installation_ids grouped
fn get_new_members(
staged_commit: &StagedCommit,
existing_installation_ids: &HashMap<String, Vec<Vec<u8>>>,
) -> (Vec<MemberProto>, Vec<MemberProto>) {
let new_installations: HashMap<String, MemberProto> = staged_commit
.add_proposals()
.filter_map(extract_identity_from_add)
.fold(HashMap::new(), merge_members);

// Partition the list. If no existing member found, it is a new member. Otherwise it is just new installations
new_installations
.into_values()
.partition(|member| !existing_installation_ids.contains_key(&member.account_address))
}

// Get a tuple of (removed_members, removed_installations)
fn get_removed_members(
staged_commit: &StagedCommit,
existing_installation_ids: &HashMap<String, Vec<Vec<u8>>>,
openmls_group: &OpenMlsGroup,
) -> (Vec<MemberProto>, Vec<MemberProto>) {
let removed_installations: HashMap<String, MemberProto> = staged_commit
.remove_proposals()
.filter_map(|proposal| extract_identity_from_remove(proposal, openmls_group))
.fold(HashMap::new(), merge_members);

// Separate the fully removed members (where all installation ids were removed in the commit) from partial removals
removed_installations.into_values().partition(|member| {
match existing_installation_ids.get(&member.account_address) {
Some(entry) => entry.len() == member.installation_ids.len(),
None => true,
}
})
}

impl<'c, ApiClient> MlsGroup<'c, ApiClient>
where
ApiClient: XmtpApiClient + XmtpMlsClient,
{
#[allow(dead_code)]
pub(crate) fn build_group_membership_change(
&self,
staged_commit: &StagedCommit,
openmls_group: &OpenMlsGroup,
) -> Result<GroupMembershipChanges, GroupError> {
// Existing installation IDs keyed by wallet address
let existing_installation_ids: HashMap<String, Vec<Vec<u8>>> = self
.members()?
.into_iter()
.fold(HashMap::new(), |mut acc, curr| {
acc.insert(curr.account_address, curr.installation_ids);
acc
});

let (members_added, installations_added) =
get_new_members(staged_commit, &existing_installation_ids);

let (members_removed, installations_removed) =
get_removed_members(staged_commit, &existing_installation_ids, openmls_group);

Ok(GroupMembershipChanges {
members_added,
members_removed,
installations_added,
installations_removed,
})
}
}

#[cfg(test)]
mod tests {
use openmls::prelude_test::KeyPackage;
use xmtp_api_grpc::Client as GrpcClient;
use xmtp_cryptography::utils::generate_local_wallet;

use crate::{builder::ClientBuilder, Client};

fn get_key_package(client: &Client<GrpcClient>) -> KeyPackage {
client
.identity
.new_key_package(&client.mls_provider(&mut client.store.conn().unwrap()))
.unwrap()
}

#[tokio::test]
async fn test_membership_changes() {
let amal = ClientBuilder::new_test_client(generate_local_wallet().into()).await;
let bola = ClientBuilder::new_test_client(generate_local_wallet().into()).await;
let bola_key_package = get_key_package(&bola);

let amal_group = amal.create_group().unwrap();
let mut amal_conn = amal.store.conn().unwrap();
let amal_provider = amal.mls_provider(&mut amal_conn);
let mut mls_group = amal_group.load_mls_group(&amal_provider).unwrap();
// Create a pending commit to add bola to the group
mls_group
.add_members(
&amal_provider,
&amal.identity.installation_keys,
&[bola_key_package],
)
.unwrap();

let mut staged_commit = mls_group.pending_commit().unwrap();

let message = amal_group
.build_group_membership_change(staged_commit, &mls_group)
.unwrap();

assert_eq!(message.installations_added.len(), 0);
assert_eq!(message.members_added.len(), 1);
assert_eq!(
message.members_added[0].account_address,
bola.account_address()
);

// Merge the commit adding bola
mls_group.merge_pending_commit(&amal_provider).unwrap();
// Now we are going to remove bola

let bola_leaf_node = mls_group
.members()
.find(|m| {
m.signature_key
.eq(&bola.identity.installation_keys.public())
})
.unwrap()
.index;
mls_group
.remove_members(
&amal_provider,
&amal.identity.installation_keys,
&[bola_leaf_node],
)
.unwrap();

staged_commit = mls_group.pending_commit().unwrap();
let remove_message = amal_group
.build_group_membership_change(staged_commit, &mls_group)
.unwrap();

assert_eq!(remove_message.members_removed.len(), 1);
assert_eq!(remove_message.installations_removed.len(), 0);
}
}
5 changes: 4 additions & 1 deletion xmtp_mls/src/groups/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
mod intents;
mod members;
mod membership_change;
use intents::SendMessageIntentData;
use log::debug;
use openmls::{
Expand Down Expand Up @@ -323,7 +324,9 @@ where
"[{}] received staged commit. Merging and clearing any pending commits",
self.client.account_address()
);
openmls_group.merge_staged_commit(provider, *staged_commit)?;

let sc = *staged_commit;
openmls_group.merge_staged_commit(provider, sc)?;
}
};

Expand Down

0 comments on commit 8b9a82d

Please sign in to comment.