From bda540d71e5e5b9abaa2c423371ec117b78956d2 Mon Sep 17 00:00:00 2001 From: Nicholas Molnar <65710+neekolas@users.noreply.github.com> Date: Fri, 23 Aug 2024 15:53:54 -0700 Subject: [PATCH] Simplify aborting in streams (#989) ## Summary - When you receive a commit from a stream, we need to abort processing the message and sync instead. Previously we did that by decrypting the message and looking at the content to see what kind of message it was. - There is a better way. There was a private method on the `PrivateMessageIn` that I have now made public that will let us see what kind of message it is _before_ decrypting. - This should make streaming more performant and reliable, since we don't have to decrypt the message twice ## Related https://github.com/xmtp/openmls/pull/36 --- Cargo.lock | 12 ++--- Cargo.toml | 8 +-- bindings_ffi/Cargo.lock | 12 ++--- bindings_node/Cargo.lock | 12 ++--- xmtp_api_http/src/util.rs | 3 +- xmtp_mls/src/groups/mod.rs | 53 +++++++++++++++++-- xmtp_mls/src/groups/subscriptions.rs | 14 ++++- xmtp_mls/src/groups/sync.rs | 22 +++----- xmtp_mls/src/lib.rs | 2 +- .../storage/encrypted_store/refresh_state.rs | 8 +-- 10 files changed, 100 insertions(+), 46 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 23664eb4c..6e16dda64 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3247,7 +3247,7 @@ dependencies = [ [[package]] name = "openmls" version = "0.6.0-pre.2" -source = "git+https://github.com/xmtp/openmls?rev=cf42738018d093434c955a1b50a9de34cc12b8c5#cf42738018d093434c955a1b50a9de34cc12b8c5" +source = "git+https://github.com/xmtp/openmls?rev=87e7e257d8eb15d6662b104518becfc75ef6db76#87e7e257d8eb15d6662b104518becfc75ef6db76" dependencies = [ "backtrace", "itertools 0.10.5", @@ -3270,7 +3270,7 @@ dependencies = [ [[package]] name = "openmls_basic_credential" version = "0.3.0-pre.1" -source = "git+https://github.com/xmtp/openmls?rev=cf42738018d093434c955a1b50a9de34cc12b8c5#cf42738018d093434c955a1b50a9de34cc12b8c5" +source = "git+https://github.com/xmtp/openmls?rev=87e7e257d8eb15d6662b104518becfc75ef6db76#87e7e257d8eb15d6662b104518becfc75ef6db76" dependencies = [ "ed25519-dalek", "openmls_traits", @@ -3283,7 +3283,7 @@ dependencies = [ [[package]] name = "openmls_memory_storage" version = "0.3.0-pre.2" -source = "git+https://github.com/xmtp/openmls?rev=cf42738018d093434c955a1b50a9de34cc12b8c5#cf42738018d093434c955a1b50a9de34cc12b8c5" +source = "git+https://github.com/xmtp/openmls?rev=87e7e257d8eb15d6662b104518becfc75ef6db76#87e7e257d8eb15d6662b104518becfc75ef6db76" dependencies = [ "hex", "log", @@ -3296,7 +3296,7 @@ dependencies = [ [[package]] name = "openmls_rust_crypto" version = "0.3.0-pre.1" -source = "git+https://github.com/xmtp/openmls?rev=cf42738018d093434c955a1b50a9de34cc12b8c5#cf42738018d093434c955a1b50a9de34cc12b8c5" +source = "git+https://github.com/xmtp/openmls?rev=87e7e257d8eb15d6662b104518becfc75ef6db76#87e7e257d8eb15d6662b104518becfc75ef6db76" dependencies = [ "aes-gcm", "chacha20poly1305", @@ -3320,7 +3320,7 @@ dependencies = [ [[package]] name = "openmls_test" version = "0.1.0-pre.1" -source = "git+https://github.com/xmtp/openmls?rev=cf42738018d093434c955a1b50a9de34cc12b8c5#cf42738018d093434c955a1b50a9de34cc12b8c5" +source = "git+https://github.com/xmtp/openmls?rev=87e7e257d8eb15d6662b104518becfc75ef6db76#87e7e257d8eb15d6662b104518becfc75ef6db76" dependencies = [ "ansi_term", "openmls_rust_crypto", @@ -3335,7 +3335,7 @@ dependencies = [ [[package]] name = "openmls_traits" version = "0.3.0-pre.2" -source = "git+https://github.com/xmtp/openmls?rev=cf42738018d093434c955a1b50a9de34cc12b8c5#cf42738018d093434c955a1b50a9de34cc12b8c5" +source = "git+https://github.com/xmtp/openmls?rev=87e7e257d8eb15d6662b104518becfc75ef6db76#87e7e257d8eb15d6662b104518becfc75ef6db76" dependencies = [ "serde", "tls_codec", diff --git a/Cargo.toml b/Cargo.toml index fba795ca2..4048b2f7b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -38,10 +38,10 @@ futures = "0.3.30" futures-core = "0.3.30" hex = "0.4.3" log = { version = "0.4" } -openmls = { git = "https://github.com/xmtp/openmls", rev = "cf42738018d093434c955a1b50a9de34cc12b8c5", default-features = false } -openmls_basic_credential = { git = "https://github.com/xmtp/openmls", rev = "cf42738018d093434c955a1b50a9de34cc12b8c5" } -openmls_rust_crypto = { git = "https://github.com/xmtp/openmls", rev = "cf42738018d093434c955a1b50a9de34cc12b8c5" } -openmls_traits = { git = "https://github.com/xmtp/openmls", rev = "cf42738018d093434c955a1b50a9de34cc12b8c5" } +openmls = { git = "https://github.com/xmtp/openmls", rev = "87e7e257d8eb15d6662b104518becfc75ef6db76", default-features = false } +openmls_basic_credential = { git = "https://github.com/xmtp/openmls", rev = "87e7e257d8eb15d6662b104518becfc75ef6db76" } +openmls_rust_crypto = { git = "https://github.com/xmtp/openmls", rev = "87e7e257d8eb15d6662b104518becfc75ef6db76" } +openmls_traits = { git = "https://github.com/xmtp/openmls", rev = "87e7e257d8eb15d6662b104518becfc75ef6db76" } pbjson = "0.6.0" pbjson-types = "0.6.0" prost = "^0.12" diff --git a/bindings_ffi/Cargo.lock b/bindings_ffi/Cargo.lock index 7dc79d249..297c6ef45 100644 --- a/bindings_ffi/Cargo.lock +++ b/bindings_ffi/Cargo.lock @@ -2842,7 +2842,7 @@ dependencies = [ [[package]] name = "openmls" version = "0.6.0-pre.2" -source = "git+https://github.com/xmtp/openmls?rev=cf42738018d093434c955a1b50a9de34cc12b8c5#cf42738018d093434c955a1b50a9de34cc12b8c5" +source = "git+https://github.com/xmtp/openmls?rev=87e7e257d8eb15d6662b104518becfc75ef6db76#87e7e257d8eb15d6662b104518becfc75ef6db76" dependencies = [ "backtrace", "itertools 0.10.5", @@ -2865,7 +2865,7 @@ dependencies = [ [[package]] name = "openmls_basic_credential" version = "0.3.0-pre.1" -source = "git+https://github.com/xmtp/openmls?rev=cf42738018d093434c955a1b50a9de34cc12b8c5#cf42738018d093434c955a1b50a9de34cc12b8c5" +source = "git+https://github.com/xmtp/openmls?rev=87e7e257d8eb15d6662b104518becfc75ef6db76#87e7e257d8eb15d6662b104518becfc75ef6db76" dependencies = [ "ed25519-dalek", "openmls_traits", @@ -2878,7 +2878,7 @@ dependencies = [ [[package]] name = "openmls_memory_storage" version = "0.3.0-pre.2" -source = "git+https://github.com/xmtp/openmls?rev=cf42738018d093434c955a1b50a9de34cc12b8c5#cf42738018d093434c955a1b50a9de34cc12b8c5" +source = "git+https://github.com/xmtp/openmls?rev=87e7e257d8eb15d6662b104518becfc75ef6db76#87e7e257d8eb15d6662b104518becfc75ef6db76" dependencies = [ "hex", "log", @@ -2891,7 +2891,7 @@ dependencies = [ [[package]] name = "openmls_rust_crypto" version = "0.3.0-pre.1" -source = "git+https://github.com/xmtp/openmls?rev=cf42738018d093434c955a1b50a9de34cc12b8c5#cf42738018d093434c955a1b50a9de34cc12b8c5" +source = "git+https://github.com/xmtp/openmls?rev=87e7e257d8eb15d6662b104518becfc75ef6db76#87e7e257d8eb15d6662b104518becfc75ef6db76" dependencies = [ "aes-gcm", "chacha20poly1305", @@ -2915,7 +2915,7 @@ dependencies = [ [[package]] name = "openmls_test" version = "0.1.0-pre.1" -source = "git+https://github.com/xmtp/openmls?rev=cf42738018d093434c955a1b50a9de34cc12b8c5#cf42738018d093434c955a1b50a9de34cc12b8c5" +source = "git+https://github.com/xmtp/openmls?rev=87e7e257d8eb15d6662b104518becfc75ef6db76#87e7e257d8eb15d6662b104518becfc75ef6db76" dependencies = [ "ansi_term", "openmls_rust_crypto", @@ -2930,7 +2930,7 @@ dependencies = [ [[package]] name = "openmls_traits" version = "0.3.0-pre.2" -source = "git+https://github.com/xmtp/openmls?rev=cf42738018d093434c955a1b50a9de34cc12b8c5#cf42738018d093434c955a1b50a9de34cc12b8c5" +source = "git+https://github.com/xmtp/openmls?rev=87e7e257d8eb15d6662b104518becfc75ef6db76#87e7e257d8eb15d6662b104518becfc75ef6db76" dependencies = [ "serde", "tls_codec", diff --git a/bindings_node/Cargo.lock b/bindings_node/Cargo.lock index 9e2f2f37f..9a80ddd94 100644 --- a/bindings_node/Cargo.lock +++ b/bindings_node/Cargo.lock @@ -2676,7 +2676,7 @@ dependencies = [ [[package]] name = "openmls" version = "0.6.0-pre.2" -source = "git+https://github.com/xmtp/openmls?rev=cf42738018d093434c955a1b50a9de34cc12b8c5#cf42738018d093434c955a1b50a9de34cc12b8c5" +source = "git+https://github.com/xmtp/openmls?rev=87e7e257d8eb15d6662b104518becfc75ef6db76#87e7e257d8eb15d6662b104518becfc75ef6db76" dependencies = [ "backtrace", "itertools 0.10.5", @@ -2699,7 +2699,7 @@ dependencies = [ [[package]] name = "openmls_basic_credential" version = "0.3.0-pre.1" -source = "git+https://github.com/xmtp/openmls?rev=cf42738018d093434c955a1b50a9de34cc12b8c5#cf42738018d093434c955a1b50a9de34cc12b8c5" +source = "git+https://github.com/xmtp/openmls?rev=87e7e257d8eb15d6662b104518becfc75ef6db76#87e7e257d8eb15d6662b104518becfc75ef6db76" dependencies = [ "ed25519-dalek", "openmls_traits", @@ -2712,7 +2712,7 @@ dependencies = [ [[package]] name = "openmls_memory_storage" version = "0.3.0-pre.2" -source = "git+https://github.com/xmtp/openmls?rev=cf42738018d093434c955a1b50a9de34cc12b8c5#cf42738018d093434c955a1b50a9de34cc12b8c5" +source = "git+https://github.com/xmtp/openmls?rev=87e7e257d8eb15d6662b104518becfc75ef6db76#87e7e257d8eb15d6662b104518becfc75ef6db76" dependencies = [ "hex", "log", @@ -2725,7 +2725,7 @@ dependencies = [ [[package]] name = "openmls_rust_crypto" version = "0.3.0-pre.1" -source = "git+https://github.com/xmtp/openmls?rev=cf42738018d093434c955a1b50a9de34cc12b8c5#cf42738018d093434c955a1b50a9de34cc12b8c5" +source = "git+https://github.com/xmtp/openmls?rev=87e7e257d8eb15d6662b104518becfc75ef6db76#87e7e257d8eb15d6662b104518becfc75ef6db76" dependencies = [ "aes-gcm", "chacha20poly1305", @@ -2749,7 +2749,7 @@ dependencies = [ [[package]] name = "openmls_test" version = "0.1.0-pre.1" -source = "git+https://github.com/xmtp/openmls?rev=cf42738018d093434c955a1b50a9de34cc12b8c5#cf42738018d093434c955a1b50a9de34cc12b8c5" +source = "git+https://github.com/xmtp/openmls?rev=87e7e257d8eb15d6662b104518becfc75ef6db76#87e7e257d8eb15d6662b104518becfc75ef6db76" dependencies = [ "ansi_term", "openmls_rust_crypto", @@ -2764,7 +2764,7 @@ dependencies = [ [[package]] name = "openmls_traits" version = "0.3.0-pre.2" -source = "git+https://github.com/xmtp/openmls?rev=cf42738018d093434c955a1b50a9de34cc12b8c5#cf42738018d093434c955a1b50a9de34cc12b8c5" +source = "git+https://github.com/xmtp/openmls?rev=87e7e257d8eb15d6662b104518becfc75ef6db76#87e7e257d8eb15d6662b104518becfc75ef6db76" dependencies = [ "serde", "tls_codec", diff --git a/xmtp_api_http/src/util.rs b/xmtp_api_http/src/util.rs index f50fd7372..d3a4b00ca 100644 --- a/xmtp_api_http/src/util.rs +++ b/xmtp_api_http/src/util.rs @@ -48,8 +48,9 @@ pub async fn create_grpc_stream< endpoint: String, http_client: reqwest::Client, ) -> Result>, Error> { + log::info!("About to spawn stream"); let stream = async_stream::stream! { - log::debug!("Spawning grpc http stream"); + log::info!("Spawning grpc http stream"); let request = http_client .post(endpoint) .json(&request) diff --git a/xmtp_mls/src/groups/mod.rs b/xmtp_mls/src/groups/mod.rs index 2a4119cd6..09db449c5 100644 --- a/xmtp_mls/src/groups/mod.rs +++ b/xmtp_mls/src/groups/mod.rs @@ -1266,8 +1266,9 @@ mod tests { use xmtp_proto::xmtp::mls::message_contents::EncodedContent; use crate::{ - assert_logged, + assert_err, assert_logged, builder::ClientBuilder, + client::MessageProcessingError, codecs::{group_updated::GroupUpdatedCodec, ContentCodec}, groups::{ build_group_membership_extension, @@ -3088,11 +3089,11 @@ mod tests { alix1_group .publish_intents(&alix1_provider, &alix1) .await - .expect_err("Expected an error that publish was canceled"); + .expect("Expect publish to be OK"); alix1_group .publish_intents(&alix1_provider, &alix1) .await - .expect_err("Expected an error that publish was canceled"); + .expect("Expected publish to be OK"); // Now I am going to sync twice alix1_group @@ -3153,4 +3154,50 @@ mod tests { .iter() .any(|m| m.decrypted_message_bytes == "hi from alix1".as_bytes())); } + + #[tokio::test(flavor = "multi_thread", worker_threads = 5)] + async fn respect_allow_epoch_increment() { + let wallet = generate_local_wallet(); + let client = ClientBuilder::new_test_client(&wallet).await; + + let group = client + .create_group(None, GroupMetadataOptions::default()) + .unwrap(); + + let _client_2 = ClientBuilder::new_test_client(&wallet).await; + + // Sync the group to get the message adding client_2 published to the network + group.sync(&client).await.unwrap(); + + // Retrieve the envelope for the commit from the network + let messages = client + .api_client + .query_group_messages(group.group_id.clone(), None) + .await + .unwrap(); + + let first_envelope = messages.first().unwrap(); + + let Some(xmtp_proto::xmtp::mls::api::v1::group_message::Version::V1(first_message)) = + first_envelope.clone().version + else { + panic!("wrong message format") + }; + let provider = client.mls_provider().unwrap(); + let mut openmls_group = group.load_mls_group(&provider).unwrap(); + let process_result = group + .process_message( + &client, + &mut openmls_group, + &provider, + &first_message, + false, + ) + .await; + + assert_err!( + process_result, + MessageProcessingError::EpochIncrementNotAllowed + ); + } } diff --git a/xmtp_mls/src/groups/subscriptions.rs b/xmtp_mls/src/groups/subscriptions.rs index 14c907645..a7f0cb8a8 100644 --- a/xmtp_mls/src/groups/subscriptions.rs +++ b/xmtp_mls/src/groups/subscriptions.rs @@ -66,8 +66,16 @@ impl MlsGroup { ); if let Some(GroupError::ReceiveError(_)) = process_result.as_ref().err() { - self.sync_with_conn(&client.mls_provider()?, &client) - .await?; + // Swallow errors here, since another process may have successfully saved the message + // to the DB + match self.sync_with_conn(&client.mls_provider()?, &client).await { + Ok(_) => { + log::debug!("Sync triggered by streamed message successful") + } + Err(err) => { + log::warn!("Sync triggered by streamed message failed: {}", err); + } + }; } else if process_result.is_err() { log::error!("Process stream entry {:?}", process_result.err()); } @@ -309,6 +317,8 @@ mod tests { }); // just to make sure stream is started let _ = start_rx.await; + // Adding in a sleep, since the HTTP API client may acknowledge requests before they are ready + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; amal_group .add_members_by_inbox_id(&amal, vec![bola.inbox_id()]) diff --git a/xmtp_mls/src/groups/sync.rs b/xmtp_mls/src/groups/sync.rs index d44435c20..e817ef3eb 100644 --- a/xmtp_mls/src/groups/sync.rs +++ b/xmtp_mls/src/groups/sync.rs @@ -46,7 +46,7 @@ use log::debug; use openmls::{ credentials::BasicCredential, extensions::Extensions, - framing::ProtocolMessage, + framing::{ContentType, ProtocolMessage}, group::{GroupEpoch, StagedCommit}, prelude::{ tls_codec::{Deserialize, Serialize}, @@ -266,7 +266,6 @@ impl MlsGroup { provider: &XmtpOpenMlsProvider, message: ProtocolMessage, envelope_timestamp_ns: u64, - allow_epoch_increment: bool, ) -> Result { if intent.state == IntentState::Committed { return Ok(IntentState::Committed); @@ -290,9 +289,6 @@ impl MlsGroup { | IntentKind::UpdateAdminList | IntentKind::MetadataUpdate | IntentKind::UpdatePermission => { - if !allow_epoch_increment { - return Err(MessageProcessingError::EpochIncrementNotAllowed); - } if let Some(published_in_epoch) = intent.published_in_epoch { let published_in_epoch_u64 = published_in_epoch as u64; let group_epoch_u64 = group_epoch.as_u64(); @@ -376,7 +372,6 @@ impl MlsGroup { provider: &XmtpOpenMlsProvider, message: PrivateMessageIn, envelope_timestamp_ns: u64, - allow_epoch_increment: bool, ) -> Result<(), MessageProcessingError> { let decrypted_message = openmls_group.process_message(provider, message)?; let (sender_inbox_id, sender_installation_id) = @@ -548,9 +543,6 @@ impl MlsGroup { // intentionally left blank. } ProcessedMessageContent::StagedCommitMessage(staged_commit) => { - if !allow_epoch_increment { - return Err(MessageProcessingError::EpochIncrementNotAllowed); - } log::info!( "[{}] received staged commit. Merging and clearing any pending commits", self.context.inbox_id() @@ -600,6 +592,10 @@ impl MlsGroup { )), }?; + if !allow_epoch_increment && message.content_type() == ContentType::Commit { + return Err(MessageProcessingError::EpochIncrementNotAllowed); + } + let intent = provider .conn_ref() .find_group_intent_by_payload_hash(sha256(envelope.data.as_slice())); @@ -622,7 +618,6 @@ impl MlsGroup { provider, message.into(), envelope.created_ns, - allow_epoch_increment, ) .await? { @@ -654,7 +649,6 @@ impl MlsGroup { provider, message, envelope.created_ns, - allow_epoch_increment, ) .await } @@ -875,8 +869,8 @@ impl MlsGroup { intent.kind ); if has_staged_commit { - log::info!("Canceling all further publishes, since a commit was found"); - return Err(GroupError::PublishCancelled); + log::info!("Commit sent. Stopping further publishes for this round"); + return Ok(()); } } Ok(None) => { @@ -1011,7 +1005,7 @@ impl MlsGroup { } } - #[tracing::instrument(level = "trace", skip(conn, client))] + #[tracing::instrument(level = "trace", skip_all)] pub(crate) async fn post_commit( &self, conn: &DbConnection, diff --git a/xmtp_mls/src/lib.rs b/xmtp_mls/src/lib.rs index a80ac59dd..0257ddedf 100644 --- a/xmtp_mls/src/lib.rs +++ b/xmtp_mls/src/lib.rs @@ -123,7 +123,7 @@ mod tests { #[macro_export] macro_rules! assert_err { ( $x:expr , $y:pat $(,)? ) => { - assert!(matches!($x, Err($y))); + assert!(matches!($x, Err($y))) }; ( $x:expr, $y:pat $(,)?, $($msg:tt)+) => {{ diff --git a/xmtp_mls/src/storage/encrypted_store/refresh_state.rs b/xmtp_mls/src/storage/encrypted_store/refresh_state.rs index 1c9e03478..a72d92e32 100644 --- a/xmtp_mls/src/storage/encrypted_store/refresh_state.rs +++ b/xmtp_mls/src/storage/encrypted_store/refresh_state.rs @@ -9,7 +9,7 @@ use diesel::{ }; use super::{db_connection::DbConnection, schema::refresh_state}; -use crate::{impl_store, storage::StorageError, Store}; +use crate::{impl_store, impl_store_or_ignore, storage::StorageError, StoreOrIgnore}; #[repr(i32)] #[derive(Debug, Clone, Copy, PartialEq, Eq, AsExpression, Hash, FromSqlRow)] @@ -52,6 +52,7 @@ pub struct RefreshState { } impl_store!(RefreshState, refresh_state); +impl_store_or_ignore!(RefreshState, refresh_state); impl DbConnection { pub fn get_refresh_state>>( @@ -69,6 +70,7 @@ impl DbConnection { Ok(res) } + pub fn get_last_cursor_for_id>>( &self, id: IdType, @@ -83,7 +85,7 @@ impl DbConnection { entity_kind, cursor: 0, }; - new_state.store(self)?; + new_state.store_or_ignore(self)?; Ok(0) } } @@ -119,7 +121,7 @@ impl DbConnection { #[cfg(test)] pub(crate) mod tests { use super::*; - use crate::storage::encrypted_store::tests::with_connection; + use crate::{storage::encrypted_store::tests::with_connection, Store}; #[test] fn get_cursor_with_no_existing_state() {