diff --git a/xmtp_mls/src/groups/mod.rs b/xmtp_mls/src/groups/mod.rs index a6b8734d7..2bce885dc 100644 --- a/xmtp_mls/src/groups/mod.rs +++ b/xmtp_mls/src/groups/mod.rs @@ -8,11 +8,12 @@ pub mod validated_commit; use intents::SendMessageIntentData; use openmls::{ + credentials::BasicCredential, extensions::{Extension, Extensions, Metadata}, group::{MlsGroupCreateConfig, MlsGroupJoinConfig}, prelude::{ - CredentialWithKey, CryptoConfig, Error as TlsCodecError, GroupId, MlsGroup as OpenMlsGroup, - StagedWelcome, Welcome as MlsWelcome, WireFormatPolicy, + BasicCredentialError, CredentialWithKey, CryptoConfig, Error as TlsCodecError, GroupId, + MlsGroup as OpenMlsGroup, StagedWelcome, Welcome as MlsWelcome, WireFormatPolicy, }, }; use openmls_traits::OpenMlsProvider; @@ -118,6 +119,8 @@ pub enum GroupError { Identity(#[from] IdentityError), #[error("serialization error: {0}")] EncodeError(#[from] prost::EncodeError), + #[error("Credential error")] + CredentialError(#[from] BasicCredentialError), } impl RetryableError for GroupError { @@ -139,6 +142,7 @@ impl RetryableError for GroupError { pub struct MlsGroup<'c, ApiClient> { pub group_id: Vec, pub created_at_ns: i64, + pub host_id: Option>, client: &'c Client, } @@ -148,6 +152,7 @@ impl<'c, ApiClient> Clone for MlsGroup<'c, ApiClient> { client: self.client, group_id: self.group_id.clone(), created_at_ns: self.created_at_ns, + host_id: self.host_id.clone(), } } } @@ -162,6 +167,17 @@ where client, group_id, created_at_ns, + host_id: None, + } + } + + // Creates a new group instance with the Welcome host_id. Does not validate that the group exists in the DB + pub fn new_with_host_id(client: &'c Client, group_id: Vec, created_at_ns: i64, host_id: Option>) -> Self { + Self { + client, + group_id, + created_at_ns, + host_id, } } @@ -214,6 +230,7 @@ where client: &'c Client, provider: &XmtpOpenMlsProvider, welcome: MlsWelcome, + host_id: Option>, ) -> Result { let mls_welcome = StagedWelcome::new_from_welcome(provider, &build_group_join_config(), welcome, None)?; @@ -225,10 +242,11 @@ where let to_store = StoredGroup::new(group_id, now_ns(), GroupMembershipState::Pending); let stored_group = provider.conn().insert_or_ignore_group(to_store)?; - Ok(Self::new( + Ok(Self::new_with_host_id( client, stored_group.id, stored_group.created_at_ns, + host_id, )) } @@ -259,9 +277,10 @@ where .welcome_sender() .expect("couldn't determine the sender of welcome"); - println!("{:?}", welcome_sender.credential()); + let host_credential = BasicCredential::try_from(welcome_sender.credential())?; + let host_id = host_credential.identity().to_vec(); - Self::create_from_welcome(client, provider, welcome) + Self::create_from_welcome(client, provider, welcome, Some(host_id)) } fn add_idempotency_key(encoded_msg: &[u8], idempotency_key: &str) -> PlaintextEnvelope { @@ -506,7 +525,10 @@ fn build_group_join_config() -> MlsGroupJoinConfig { #[cfg(test)] mod tests { - use openmls::prelude::Member; + use openmls::{ + credentials::BasicCredential, + prelude::{Credential, Member} + }; use prost::Message; use xmtp_api_grpc::grpc_api_helper::Client as GrpcClient; use xmtp_cryptography::utils::generate_local_wallet; @@ -982,6 +1004,7 @@ mod tests { #[tokio::test] async fn test_staged_welcome() { + // Create Clients let amal = ClientBuilder::new_test_client(&generate_local_wallet()).await; let bola = ClientBuilder::new_test_client(&generate_local_wallet()).await; @@ -993,7 +1016,21 @@ mod tests { .unwrap(); // Get bola's version of the same group - let bola_groups = bola.sync_welcomes().await.unwrap(); + let bola_groups = bola + .sync_welcomes() + .await + .unwrap(); let bola_group = bola_groups.first().unwrap(); + + // Check Bola's group for the welcome host_id + let host_id: Vec = bola_group.host_id.clone().unwrap(); + let host_basic_credential = BasicCredential::new(host_id).unwrap(); + let host_credential = Credential::from(host_basic_credential); + + // Verify the welcome host_credential is equal to Amal's + assert_eq!(amal.identity + .credential() + .unwrap(), + host_credential); } }