Skip to content

Commit

Permalink
Add functions for parsing commits (#354)
Browse files Browse the repository at this point in the history
* Add functions for parsing commits

* Cleanup

* Remove useless clone

* Lint

* Use allow(dead_code) directive

* Remove commented out code

* Rename to account_address

* Update to latest protos

* Remove dead code
  • Loading branch information
neekolas authored Dec 11, 2023
1 parent bdbe0ab commit 9e4525a
Show file tree
Hide file tree
Showing 9 changed files with 435 additions and 169 deletions.
2 changes: 1 addition & 1 deletion examples/cli/cli-client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ async fn main() {
.members()
.unwrap()
.into_iter()
.map(|m| m.wallet_address)
.map(|m| m.account_address)
.collect::<Vec<String>>()
.join("\n"),
);
Expand Down
19 changes: 10 additions & 9 deletions xmtp_mls/src/codecs/membership_change.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::collections::HashMap;

use prost::Message;
use xmtp_proto::xmtp::mls::message_contents::{
ContentTypeId, EncodedContent, GroupMembershipChange,
ContentTypeId, EncodedContent, GroupMembershipChanges,
};

use super::{CodecError, ContentCodec};
Expand All @@ -14,7 +14,7 @@ impl GroupMembershipChangeCodec {
const TYPE_ID: &'static str = "group_membership_change";
}

impl ContentCodec<GroupMembershipChange> for GroupMembershipChangeCodec {
impl ContentCodec<GroupMembershipChanges> for GroupMembershipChangeCodec {
fn content_type() -> ContentTypeId {
ContentTypeId {
authority_id: GroupMembershipChangeCodec::AUTHORITY_ID.to_string(),
Expand All @@ -24,7 +24,7 @@ impl ContentCodec<GroupMembershipChange> for GroupMembershipChangeCodec {
}
}

fn encode(data: GroupMembershipChange) -> Result<EncodedContent, CodecError> {
fn encode(data: GroupMembershipChanges) -> Result<EncodedContent, CodecError> {
let mut buf = Vec::new();
data.encode(&mut buf)
.map_err(|e| CodecError::Encode(e.to_string()))?;
Expand All @@ -38,8 +38,8 @@ impl ContentCodec<GroupMembershipChange> for GroupMembershipChangeCodec {
})
}

fn decode(content: EncodedContent) -> Result<GroupMembershipChange, CodecError> {
let decoded = GroupMembershipChange::decode(content.content.as_slice())
fn decode(content: EncodedContent) -> Result<GroupMembershipChanges, CodecError> {
let decoded = GroupMembershipChanges::decode(content.content.as_slice())
.map_err(|e| CodecError::Decode(e.to_string()))?;

Ok(decoded)
Expand All @@ -48,19 +48,20 @@ impl ContentCodec<GroupMembershipChange> for GroupMembershipChangeCodec {

#[cfg(test)]
mod tests {
use xmtp_proto::xmtp::mls::message_contents::Member;
use xmtp_proto::xmtp::mls::message_contents::MembershipChange;

use crate::utils::test::{rand_string, rand_vec};

use super::*;

#[test]
fn test_encode_decode() {
let new_member = Member {
let new_member = MembershipChange {
installation_ids: vec![rand_vec()],
wallet_address: rand_string(),
account_address: rand_string(),
initiated_by_account_address: "".to_string(),
};
let data = GroupMembershipChange {
let data = GroupMembershipChanges {
members_added: vec![new_member.clone()],
members_removed: vec![],
installations_added: vec![],
Expand Down
12 changes: 6 additions & 6 deletions xmtp_mls/src/groups/members.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use super::{GroupError, MlsGroup};

#[derive(Debug, Clone)]
pub struct GroupMember {
pub wallet_address: String,
pub account_address: String,
pub installation_ids: Vec<Vec<u8>>,
}

Expand All @@ -33,11 +33,11 @@ where
})
.fold(
HashMap::new(),
|mut acc, (wallet_address, signature_key)| {
acc.entry(wallet_address.clone())
|mut acc, (account_address, signature_key)| {
acc.entry(account_address.clone())
.and_modify(|e| e.installation_ids.push(signature_key.clone()))
.or_insert(GroupMember {
wallet_address,
account_address,
installation_ids: vec![signature_key],
});
acc
Expand Down Expand Up @@ -79,10 +79,10 @@ mod tests {
assert_eq!(members.len(), 2);

for member in members {
if member.wallet_address.eq(&amal.account_address()) {
if member.account_address.eq(&amal.account_address()) {
assert_eq!(member.installation_ids.len(), 1);
}
if member.wallet_address.eq(&bola_a.account_address()) {
if member.account_address.eq(&bola_a.account_address()) {
assert_eq!(member.installation_ids.len(), 2);
}
}
Expand Down
212 changes: 212 additions & 0 deletions xmtp_mls/src/groups/membership_change.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,212 @@
use std::collections::HashMap;

use openmls::{
group::{QueuedAddProposal, QueuedRemoveProposal},
prelude::{MlsGroup as OpenMlsGroup, StagedCommit},
};
use xmtp_proto::{
api_client::{XmtpApiClient, XmtpMlsClient},
xmtp::mls::message_contents::{GroupMembershipChanges, MembershipChange as MemberProto},
};

use crate::identity::Identity;

use super::{GroupError, MlsGroup};

// Take a QueuedAddProposal and extract the wallet address and installation_id
fn extract_identity_from_add(proposal: QueuedAddProposal) -> Option<(String, Vec<u8>)> {
let leaf_node = proposal.add_proposal().key_package().leaf_node();
let signature_key = leaf_node.signature_key().as_slice();
match Identity::get_validated_account_address(leaf_node.credential().identity(), signature_key)
{
Ok(account_address) => Some((account_address, signature_key.to_vec())),
Err(err) => {
log::warn!("error extracting identity {}", err);
None
}
}
}

// Take a QueuedRemoveProposal and extract the wallet address and installation_id
fn extract_identity_from_remove(
proposal: QueuedRemoveProposal,
group: &OpenMlsGroup,
) -> Option<(String, Vec<u8>)> {
let leaf_index = proposal.remove_proposal().removed();
let maybe_member = group.member_at(leaf_index);
if maybe_member.is_none() {
log::warn!("could not find removed member");
return None;
}
let member = maybe_member.expect("already checked");
let signature_key = member.signature_key.as_slice();
match Identity::get_validated_account_address(member.credential.identity(), signature_key) {
Ok(account_address) => Some((account_address, signature_key.to_vec())),
Err(err) => {
log::warn!("error extracting identity {}", err);
None
}
}
}

// Reducer function for merging members into a map, with all installation_ids collected per member
fn merge_members(
mut acc: HashMap<String, MemberProto>,
(account_address, signature_key): (String, Vec<u8>),
) -> HashMap<String, MemberProto> {
acc.entry(account_address.clone())
.and_modify(|entry| entry.installation_ids.push(signature_key.clone()))
.or_insert(MemberProto {
account_address,
installation_ids: vec![signature_key],
initiated_by_account_address: "".to_string(),
});
acc
}

// Get a tuple of (new_members, new_installations), each formatted as a Member object with all installation_ids grouped
fn get_new_members(
staged_commit: &StagedCommit,
existing_installation_ids: &HashMap<String, Vec<Vec<u8>>>,
) -> (Vec<MemberProto>, Vec<MemberProto>) {
let new_installations: HashMap<String, MemberProto> = staged_commit
.add_proposals()
.filter_map(extract_identity_from_add)
.fold(HashMap::new(), merge_members);

// Partition the list. If no existing member found, it is a new member. Otherwise it is just new installations
new_installations
.into_values()
.partition(|member| !existing_installation_ids.contains_key(&member.account_address))
}

// Get a tuple of (removed_members, removed_installations)
fn get_removed_members(
staged_commit: &StagedCommit,
existing_installation_ids: &HashMap<String, Vec<Vec<u8>>>,
openmls_group: &OpenMlsGroup,
) -> (Vec<MemberProto>, Vec<MemberProto>) {
let removed_installations: HashMap<String, MemberProto> = staged_commit
.remove_proposals()
.filter_map(|proposal| extract_identity_from_remove(proposal, openmls_group))
.fold(HashMap::new(), merge_members);

// Separate the fully removed members (where all installation ids were removed in the commit) from partial removals
removed_installations.into_values().partition(|member| {
match existing_installation_ids.get(&member.account_address) {
Some(entry) => entry.len() == member.installation_ids.len(),
None => true,
}
})
}

impl<'c, ApiClient> MlsGroup<'c, ApiClient>
where
ApiClient: XmtpApiClient + XmtpMlsClient,
{
#[allow(dead_code)]
pub(crate) fn build_group_membership_change(
&self,
staged_commit: &StagedCommit,
openmls_group: &OpenMlsGroup,
) -> Result<GroupMembershipChanges, GroupError> {
// Existing installation IDs keyed by wallet address
let existing_installation_ids: HashMap<String, Vec<Vec<u8>>> = self
.members()?
.into_iter()
.fold(HashMap::new(), |mut acc, curr| {
acc.insert(curr.account_address, curr.installation_ids);
acc
});

let (members_added, installations_added) =
get_new_members(staged_commit, &existing_installation_ids);

let (members_removed, installations_removed) =
get_removed_members(staged_commit, &existing_installation_ids, openmls_group);

Ok(GroupMembershipChanges {
members_added,
members_removed,
installations_added,
installations_removed,
})
}
}

#[cfg(test)]
mod tests {
use openmls::prelude_test::KeyPackage;
use xmtp_api_grpc::Client as GrpcClient;
use xmtp_cryptography::utils::generate_local_wallet;

use crate::{builder::ClientBuilder, Client};

fn get_key_package(client: &Client<GrpcClient>) -> KeyPackage {
client
.identity
.new_key_package(&client.mls_provider(&mut client.store.conn().unwrap()))
.unwrap()
}

#[tokio::test]
async fn test_membership_changes() {
let amal = ClientBuilder::new_test_client(generate_local_wallet().into()).await;
let bola = ClientBuilder::new_test_client(generate_local_wallet().into()).await;
let bola_key_package = get_key_package(&bola);

let amal_group = amal.create_group().unwrap();
let mut amal_conn = amal.store.conn().unwrap();
let amal_provider = amal.mls_provider(&mut amal_conn);
let mut mls_group = amal_group.load_mls_group(&amal_provider).unwrap();
// Create a pending commit to add bola to the group
mls_group
.add_members(
&amal_provider,
&amal.identity.installation_keys,
&[bola_key_package],
)
.unwrap();

let mut staged_commit = mls_group.pending_commit().unwrap();

let message = amal_group
.build_group_membership_change(staged_commit, &mls_group)
.unwrap();

assert_eq!(message.installations_added.len(), 0);
assert_eq!(message.members_added.len(), 1);
assert_eq!(
message.members_added[0].account_address,
bola.account_address()
);

// Merge the commit adding bola
mls_group.merge_pending_commit(&amal_provider).unwrap();
// Now we are going to remove bola

let bola_leaf_node = mls_group
.members()
.find(|m| {
m.signature_key
.eq(&bola.identity.installation_keys.public())
})
.unwrap()
.index;
mls_group
.remove_members(
&amal_provider,
&amal.identity.installation_keys,
&[bola_leaf_node],
)
.unwrap();

staged_commit = mls_group.pending_commit().unwrap();
let remove_message = amal_group
.build_group_membership_change(staged_commit, &mls_group)
.unwrap();

assert_eq!(remove_message.members_removed.len(), 1);
assert_eq!(remove_message.installations_removed.len(), 0);
}
}
7 changes: 5 additions & 2 deletions xmtp_mls/src/groups/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
mod intents;
mod members;
mod membership_change;
use intents::SendMessageIntentData;
#[cfg(not(test))]
use log::debug;
Expand Down Expand Up @@ -321,7 +322,9 @@ where
"[{}] received staged commit. Merging and clearing any pending commits",
self.client.account_address()
);
openmls_group.merge_staged_commit(provider, *staged_commit)?;

let sc = *staged_commit;
openmls_group.merge_staged_commit(provider, sc)?;
}
};

Expand Down Expand Up @@ -478,7 +481,7 @@ where
let installation_ids = self
.members()?
.into_iter()
.filter(|member| wallet_addresses.contains(&member.wallet_address))
.filter(|member| wallet_addresses.contains(&member.account_address))
.fold(vec![], |mut acc, member| {
acc.extend(member.installation_ids);
acc
Expand Down
Loading

0 comments on commit 9e4525a

Please sign in to comment.