Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify aborting in streams #989

Merged
merged 1 commit into from
Aug 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 4 additions & 4 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,10 @@ futures = "0.3.30"
futures-core = "0.3.30"
hex = "0.4.3"
log = { version = "0.4" }
openmls = { git = "https://github.com/xmtp/openmls", rev = "cf42738018d093434c955a1b50a9de34cc12b8c5", default-features = false }
openmls_basic_credential = { git = "https://github.com/xmtp/openmls", rev = "cf42738018d093434c955a1b50a9de34cc12b8c5" }
openmls_rust_crypto = { git = "https://github.com/xmtp/openmls", rev = "cf42738018d093434c955a1b50a9de34cc12b8c5" }
openmls_traits = { git = "https://github.com/xmtp/openmls", rev = "cf42738018d093434c955a1b50a9de34cc12b8c5" }
openmls = { git = "https://github.com/xmtp/openmls", rev = "87e7e257d8eb15d6662b104518becfc75ef6db76", default-features = false }
openmls_basic_credential = { git = "https://github.com/xmtp/openmls", rev = "87e7e257d8eb15d6662b104518becfc75ef6db76" }
openmls_rust_crypto = { git = "https://github.com/xmtp/openmls", rev = "87e7e257d8eb15d6662b104518becfc75ef6db76" }
openmls_traits = { git = "https://github.com/xmtp/openmls", rev = "87e7e257d8eb15d6662b104518becfc75ef6db76" }
pbjson = "0.6.0"
pbjson-types = "0.6.0"
prost = "^0.12"
Expand Down
12 changes: 6 additions & 6 deletions bindings_ffi/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

12 changes: 6 additions & 6 deletions bindings_node/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion xmtp_api_http/src/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,9 @@ pub async fn create_grpc_stream<
endpoint: String,
http_client: reqwest::Client,
) -> Result<BoxStream<'static, Result<R, Error>>, Error> {
log::info!("About to spawn stream");
let stream = async_stream::stream! {
log::debug!("Spawning grpc http stream");
log::info!("Spawning grpc http stream");
let request = http_client
.post(endpoint)
.json(&request)
Expand Down
53 changes: 50 additions & 3 deletions xmtp_mls/src/groups/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1266,8 +1266,9 @@ mod tests {
use xmtp_proto::xmtp::mls::message_contents::EncodedContent;

use crate::{
assert_logged,
assert_err, assert_logged,
builder::ClientBuilder,
client::MessageProcessingError,
codecs::{group_updated::GroupUpdatedCodec, ContentCodec},
groups::{
build_group_membership_extension,
Expand Down Expand Up @@ -3088,11 +3089,11 @@ mod tests {
alix1_group
.publish_intents(&alix1_provider, &alix1)
.await
.expect_err("Expected an error that publish was canceled");
.expect("Expect publish to be OK");
alix1_group
.publish_intents(&alix1_provider, &alix1)
.await
.expect_err("Expected an error that publish was canceled");
.expect("Expected publish to be OK");

// Now I am going to sync twice
alix1_group
Expand Down Expand Up @@ -3153,4 +3154,50 @@ mod tests {
.iter()
.any(|m| m.decrypted_message_bytes == "hi from alix1".as_bytes()));
}

#[tokio::test(flavor = "multi_thread", worker_threads = 5)]
async fn respect_allow_epoch_increment() {
let wallet = generate_local_wallet();
let client = ClientBuilder::new_test_client(&wallet).await;

let group = client
.create_group(None, GroupMetadataOptions::default())
.unwrap();

let _client_2 = ClientBuilder::new_test_client(&wallet).await;

// Sync the group to get the message adding client_2 published to the network
group.sync(&client).await.unwrap();

// Retrieve the envelope for the commit from the network
let messages = client
.api_client
.query_group_messages(group.group_id.clone(), None)
.await
.unwrap();

let first_envelope = messages.first().unwrap();

let Some(xmtp_proto::xmtp::mls::api::v1::group_message::Version::V1(first_message)) =
first_envelope.clone().version
else {
panic!("wrong message format")
};
let provider = client.mls_provider().unwrap();
let mut openmls_group = group.load_mls_group(&provider).unwrap();
let process_result = group
.process_message(
&client,
&mut openmls_group,
&provider,
&first_message,
false,
)
.await;

assert_err!(
process_result,
MessageProcessingError::EpochIncrementNotAllowed
);
}
}
14 changes: 12 additions & 2 deletions xmtp_mls/src/groups/subscriptions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,16 @@ impl MlsGroup {
);

if let Some(GroupError::ReceiveError(_)) = process_result.as_ref().err() {
self.sync_with_conn(&client.mls_provider()?, &client)
.await?;
// Swallow errors here, since another process may have successfully saved the message
// to the DB
match self.sync_with_conn(&client.mls_provider()?, &client).await {
Ok(_) => {
log::debug!("Sync triggered by streamed message successful")
}
Err(err) => {
log::warn!("Sync triggered by streamed message failed: {}", err);
}
};
} else if process_result.is_err() {
log::error!("Process stream entry {:?}", process_result.err());
}
Expand Down Expand Up @@ -309,6 +317,8 @@ mod tests {
});
// just to make sure stream is started
let _ = start_rx.await;
// Adding in a sleep, since the HTTP API client may acknowledge requests before they are ready
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;

amal_group
.add_members_by_inbox_id(&amal, vec![bola.inbox_id()])
Expand Down
22 changes: 8 additions & 14 deletions xmtp_mls/src/groups/sync.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ use log::debug;
use openmls::{
credentials::BasicCredential,
extensions::Extensions,
framing::ProtocolMessage,
framing::{ContentType, ProtocolMessage},
group::{GroupEpoch, StagedCommit},
prelude::{
tls_codec::{Deserialize, Serialize},
Expand Down Expand Up @@ -266,7 +266,6 @@ impl MlsGroup {
provider: &XmtpOpenMlsProvider,
message: ProtocolMessage,
envelope_timestamp_ns: u64,
allow_epoch_increment: bool,
) -> Result<IntentState, MessageProcessingError> {
if intent.state == IntentState::Committed {
return Ok(IntentState::Committed);
Expand All @@ -290,9 +289,6 @@ impl MlsGroup {
| IntentKind::UpdateAdminList
| IntentKind::MetadataUpdate
| IntentKind::UpdatePermission => {
if !allow_epoch_increment {
return Err(MessageProcessingError::EpochIncrementNotAllowed);
}
if let Some(published_in_epoch) = intent.published_in_epoch {
let published_in_epoch_u64 = published_in_epoch as u64;
let group_epoch_u64 = group_epoch.as_u64();
Expand Down Expand Up @@ -376,7 +372,6 @@ impl MlsGroup {
provider: &XmtpOpenMlsProvider,
message: PrivateMessageIn,
envelope_timestamp_ns: u64,
allow_epoch_increment: bool,
) -> Result<(), MessageProcessingError> {
let decrypted_message = openmls_group.process_message(provider, message)?;
let (sender_inbox_id, sender_installation_id) =
Expand Down Expand Up @@ -548,9 +543,6 @@ impl MlsGroup {
// intentionally left blank.
}
ProcessedMessageContent::StagedCommitMessage(staged_commit) => {
if !allow_epoch_increment {
return Err(MessageProcessingError::EpochIncrementNotAllowed);
}
log::info!(
"[{}] received staged commit. Merging and clearing any pending commits",
self.context.inbox_id()
Expand Down Expand Up @@ -600,6 +592,10 @@ impl MlsGroup {
)),
}?;

if !allow_epoch_increment && message.content_type() == ContentType::Commit {
return Err(MessageProcessingError::EpochIncrementNotAllowed);
}

let intent = provider
.conn_ref()
.find_group_intent_by_payload_hash(sha256(envelope.data.as_slice()));
Expand All @@ -622,7 +618,6 @@ impl MlsGroup {
provider,
message.into(),
envelope.created_ns,
allow_epoch_increment,
)
.await?
{
Expand Down Expand Up @@ -654,7 +649,6 @@ impl MlsGroup {
provider,
message,
envelope.created_ns,
allow_epoch_increment,
)
.await
}
Expand Down Expand Up @@ -875,8 +869,8 @@ impl MlsGroup {
intent.kind
);
if has_staged_commit {
log::info!("Canceling all further publishes, since a commit was found");
return Err(GroupError::PublishCancelled);
log::info!("Commit sent. Stopping further publishes for this round");
return Ok(());
}
}
Ok(None) => {
Expand Down Expand Up @@ -1011,7 +1005,7 @@ impl MlsGroup {
}
}

#[tracing::instrument(level = "trace", skip(conn, client))]
#[tracing::instrument(level = "trace", skip_all)]
pub(crate) async fn post_commit<ApiClient>(
&self,
conn: &DbConnection,
Expand Down
2 changes: 1 addition & 1 deletion xmtp_mls/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ mod tests {
#[macro_export]
macro_rules! assert_err {
( $x:expr , $y:pat $(,)? ) => {
assert!(matches!($x, Err($y)));
assert!(matches!($x, Err($y)))
};

( $x:expr, $y:pat $(,)?, $($msg:tt)+) => {{
Expand Down
Loading