Skip to content

Commit

Permalink
Store Welcome Sender’s Identity on Group
Browse files Browse the repository at this point in the history
  • Loading branch information
zombieobject committed Apr 5, 2024
1 parent 8e1a9f2 commit 043f9f9
Showing 1 changed file with 44 additions and 7 deletions.
51 changes: 44 additions & 7 deletions xmtp_mls/src/groups/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,12 @@ pub mod validated_commit;

use intents::SendMessageIntentData;
use openmls::{
credentials::BasicCredential,
extensions::{Extension, Extensions, Metadata},
group::{MlsGroupCreateConfig, MlsGroupJoinConfig},
prelude::{
CredentialWithKey, CryptoConfig, Error as TlsCodecError, GroupId, MlsGroup as OpenMlsGroup,
StagedWelcome, Welcome as MlsWelcome, WireFormatPolicy,
BasicCredentialError, CredentialWithKey, CryptoConfig, Error as TlsCodecError, GroupId,
MlsGroup as OpenMlsGroup, StagedWelcome, Welcome as MlsWelcome, WireFormatPolicy,
},
};
use openmls_traits::OpenMlsProvider;
Expand Down Expand Up @@ -118,6 +119,8 @@ pub enum GroupError {
Identity(#[from] IdentityError),
#[error("serialization error: {0}")]
EncodeError(#[from] prost::EncodeError),
#[error("Credential error")]
CredentialError(#[from] BasicCredentialError),
}

impl RetryableError for GroupError {
Expand All @@ -139,6 +142,7 @@ impl RetryableError for GroupError {
pub struct MlsGroup<'c, ApiClient> {
pub group_id: Vec<u8>,
pub created_at_ns: i64,
pub host_id: Option<Vec<u8>>,
client: &'c Client<ApiClient>,
}

Expand All @@ -148,6 +152,7 @@ impl<'c, ApiClient> Clone for MlsGroup<'c, ApiClient> {
client: self.client,
group_id: self.group_id.clone(),
created_at_ns: self.created_at_ns,
host_id: self.host_id.clone(),
}
}
}
Expand All @@ -162,6 +167,17 @@ where
client,
group_id,
created_at_ns,
host_id: None,
}
}

// Creates a new group instance with the Welcome host_id. Does not validate that the group exists in the DB
pub fn new_with_host_id(client: &'c Client<ApiClient>, group_id: Vec<u8>, created_at_ns: i64, host_id: Option<Vec<u8>>) -> Self {
Self {
client,
group_id,
created_at_ns,
host_id,
}
}

Expand Down Expand Up @@ -214,6 +230,7 @@ where
client: &'c Client<ApiClient>,
provider: &XmtpOpenMlsProvider,
welcome: MlsWelcome,
host_id: Option<Vec<u8>>,
) -> Result<Self, GroupError> {
let mls_welcome =
StagedWelcome::new_from_welcome(provider, &build_group_join_config(), welcome, None)?;
Expand All @@ -225,10 +242,11 @@ where
let to_store = StoredGroup::new(group_id, now_ns(), GroupMembershipState::Pending);
let stored_group = provider.conn().insert_or_ignore_group(to_store)?;

Ok(Self::new(
Ok(Self::new_with_host_id(
client,
stored_group.id,
stored_group.created_at_ns,
host_id,
))
}

Expand Down Expand Up @@ -259,9 +277,10 @@ where
.welcome_sender()
.expect("couldn't determine the sender of welcome");

println!("{:?}", welcome_sender.credential());
let host_credential = BasicCredential::try_from(welcome_sender.credential())?;
let host_id = host_credential.identity().to_vec();

Self::create_from_welcome(client, provider, welcome)
Self::create_from_welcome(client, provider, welcome, Some(host_id))
}

fn add_idempotency_key(encoded_msg: &[u8], idempotency_key: &str) -> PlaintextEnvelope {
Expand Down Expand Up @@ -506,7 +525,10 @@ fn build_group_join_config() -> MlsGroupJoinConfig {

#[cfg(test)]
mod tests {
use openmls::prelude::Member;
use openmls::{
credentials::BasicCredential,
prelude::{Credential, Member}
};
use prost::Message;
use xmtp_api_grpc::grpc_api_helper::Client as GrpcClient;
use xmtp_cryptography::utils::generate_local_wallet;
Expand Down Expand Up @@ -982,6 +1004,7 @@ mod tests {

#[tokio::test]
async fn test_staged_welcome() {
// Create Clients
let amal = ClientBuilder::new_test_client(&generate_local_wallet()).await;
let bola = ClientBuilder::new_test_client(&generate_local_wallet()).await;

Expand All @@ -993,7 +1016,21 @@ mod tests {
.unwrap();

// Get bola's version of the same group
let bola_groups = bola.sync_welcomes().await.unwrap();
let bola_groups = bola
.sync_welcomes()
.await
.unwrap();
let bola_group = bola_groups.first().unwrap();

// Check Bola's group for the welcome host_id
let host_id: Vec<u8> = bola_group.host_id.clone().unwrap();
let host_basic_credential = BasicCredential::new(host_id).unwrap();
let host_credential = Credential::from(host_basic_credential);

// Verify the welcome host_credential is equal to Amal's
assert_eq!(amal.identity
.credential()
.unwrap(),
host_credential);
}
}

0 comments on commit 043f9f9

Please sign in to comment.