From ccbcb3bed42b36033e228f8f1de87ba9f260f1c7 Mon Sep 17 00:00:00 2001 From: Andrew Plaza Date: Sun, 22 Dec 2024 18:56:20 -0500 Subject: [PATCH] progress from friday --- xmtp_mls/src/client.rs | 4 +- xmtp_mls/src/subscriptions/mod.rs | 15 +- .../src/subscriptions/stream_conversations.rs | 185 +++++++++++++----- 3 files changed, 150 insertions(+), 54 deletions(-) diff --git a/xmtp_mls/src/client.rs b/xmtp_mls/src/client.rs index 69be71c2e..e4187ed94 100644 --- a/xmtp_mls/src/client.rs +++ b/xmtp_mls/src/client.rs @@ -559,9 +559,7 @@ where { Some(id) => id, None => { - return Err(ClientError::Storage(StorageError::NotFound( - NotFound::InboxIdForAddress(account_address), - ))); + return Err(NotFound::InboxIdForAddress(account_address).into()); } }; diff --git a/xmtp_mls/src/subscriptions/mod.rs b/xmtp_mls/src/subscriptions/mod.rs index e9b8dcad0..c0f6d1307 100644 --- a/xmtp_mls/src/subscriptions/mod.rs +++ b/xmtp_mls/src/subscriptions/mod.rs @@ -11,7 +11,7 @@ use xmtp_id::scw_verifier::SmartContractSignatureVerifier; use xmtp_proto::{api_client::XmtpMlsStreams, xmtp::mls::api::v1::WelcomeMessage}; // mod stream_all; -// mod stream_conversations; +mod stream_conversations; use crate::{ client::{extract_welcome_message, ClientError}, @@ -24,7 +24,7 @@ use crate::{ consent_record::StoredConsentRecord, group::{ConversationType, GroupQueryArgs, StoredGroup}, group_message::StoredGroupMessage, - StorageError, + NotFound, StorageError, }, Client, XmtpApi, XmtpOpenMlsProvider, }; @@ -223,6 +223,13 @@ impl From for (Vec, MessagesStreamInfo) { } } +// TODO: REMOVE BEFORE MERGING +// TODO: REMOVE BEFORE MERGING +// TODO: REMOVE BEFORE MERGING +pub(self) mod temp { + pub(super) type Result = std::result::Result; +} + #[derive(thiserror::Error, Debug)] pub enum SubscribeError { #[error("failed to start new messages stream {0}")] @@ -231,6 +238,9 @@ pub enum SubscribeError { Client(#[from] ClientError), #[error(transparent)] Group(#[from] GroupError), + #[error(transparent)] + NotFound(#[from] NotFound), + // TODO: Add this to `NotFound` #[error("group message expected in database but is missing")] GroupMessageNotFound, #[error("processing group message in stream: {0}")] @@ -258,6 +268,7 @@ impl RetryableError for SubscribeError { Storage(e) => retryable!(e), Api(e) => retryable!(e), Decode(_) => false, + NotFound(e) => retryable!(e), } } } diff --git a/xmtp_mls/src/subscriptions/stream_conversations.rs b/xmtp_mls/src/subscriptions/stream_conversations.rs index 3f3f3abb5..3bc47894c 100644 --- a/xmtp_mls/src/subscriptions/stream_conversations.rs +++ b/xmtp_mls/src/subscriptions/stream_conversations.rs @@ -1,26 +1,28 @@ -use std::{collections::HashSet, marker::PhantomData, sync::Arc, task::Poll}; +use std::{ + collections::HashSet, future::Future, marker::PhantomData, pin::Pin, + sync::Arc, task::Poll, +}; -use futures::{prelude::stream::Select, Stream}; +use crate::{ + groups::{scoped_client::ScopedGroupClient, MlsGroup}, + storage::{group::ConversationType, DbConnection, NotFound}, + Client, XmtpOpenMlsProvider, +}; +use futures::{future::FutureExt, prelude::stream::Select, Stream}; use pin_project_lite::pin_project; use tokio_stream::wrappers::BroadcastStream; use xmtp_common::{retry_async, Retry}; use xmtp_id::scw_verifier::SmartContractSignatureVerifier; use xmtp_proto::{ api_client::{trait_impls::XmtpApi, XmtpMlsStreams}, - xmtp::mls::api::v1::WelcomeMessage, -}; - -use crate::{ - groups::{scoped_client::ScopedGroupClient, MlsGroup}, - storage::{group::ConversationType, DbConnection}, - Client, XmtpOpenMlsProvider, + xmtp::mls::api::v1::{welcome_message::V1 as WelcomeMessageV1, WelcomeMessage}, }; -use super::{LocalEvents, SubscribeError}; +use super::{temp::Result, LocalEvents, SubscribeError}; enum WelcomeOrGroup { - Group(Result, SubscribeError>), - Welcome(Result), + Group(Result>), + Welcome(Result), } pin_project! { @@ -85,7 +87,7 @@ impl SubscriptionStream { impl Stream for SubscriptionStream where - S: Stream>, + S: Stream>, { type Item = WelcomeOrGroup; @@ -97,7 +99,10 @@ where let this = self.project(); match this.inner.poll_next(cx) { - Ready(Some(welcome)) => Ready(Some(WelcomeOrGroup::Welcome(welcome))), + Ready(Some(welcome)) => { + let welcome = welcome.map_err(SubscribeError::from); + Ready(Some(WelcomeOrGroup::Welcome(welcome))) + } Pending => Pending, Ready(None) => Ready(None), } @@ -109,7 +114,27 @@ pin_project! { client: &'a C, #[pin] inner: Subscription, conversation_type: Option, - known_welcome_ids: HashSet + known_welcome_ids: HashSet, + #[pin] state: ProcessState<'a, C>, + } +} + +pin_project! { + #[project = ProcessProject] + enum ProcessState<'a, C> { + /// State where we are waiting on the next Message from the network + Waiting, + /// State where we are waiting on an IO/Network future to finish processing the current message + /// before moving on to the next one + Processing { + #[pin] future: Pin, Option) >> + 'a >> + } + } +} + +impl<'a, C> Default for ProcessState<'a, C> { + fn default() -> Self { + ProcessState::Waiting } } @@ -129,7 +154,7 @@ where client: &'a Client, conversation_type: Option, conn: &DbConnection, - ) -> Result { + ) -> Result { let installation_key = client.installation_public_key(); let id_cursor = 0; tracing::info!( @@ -154,6 +179,7 @@ where inner: stream, known_welcome_ids, conversation_type, + state: ProcessState::Waiting, }) } } @@ -161,27 +187,58 @@ where impl<'a, C, Subscription> Stream for StreamConversations<'a, C, Subscription> where C: ScopedGroupClient + Clone, - Subscription: Stream, SubscribeError>>, + Subscription: Stream>> + 'a, { - type Item = Result, SubscribeError>; + type Item = Result>; fn poll_next( - self: std::pin::Pin<&mut Self>, + mut self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, ) -> std::task::Poll> { use std::task::Poll::*; - let this = self.project(); + use ProcessState::*; + let mut this = self.as_mut().project(); - match this.inner.poll_next(cx) { - Ready(Some(msg)) => { - todo!() - } - // stream ended - Ready(None) => Ready(None), - Pending => { - cx.waker().wake_by_ref(); - Pending + match this.state.as_mut().project() { + ProcessProject::Waiting => { + match this.inner.poll_next(cx) { + Ready(Some(item)) => { + let future = + // need to clone client into Arc<> here b/c: + // otherwise the `'1` ref for `Pin<&mut Self>` in arg to `poll_next` needs to + // live as long as `'a` ref for `Client`. + // This is because we're boxing this future (i.e `Box`). + // There maybe a way to avoid it, but we need to `Box<>` the type + // b/c there's no way to get the anonymous future type on the stack generated by an + // `async fn`. If we can somehow store `impl Trait` on a struct (or + // something similar), we could avoid the `Clone` + `Arc`ing. + Self::process_new_item(this.known_welcome_ids.clone(), Arc::new(this.client.clone()), item); + + this.state.set(ProcessState::Processing { + future: future.boxed(), + }); + Pending + } + // stream ended + Ready(None) => Ready(None), + Pending => { + cx.waker().wake_by_ref(); + Pending + } + } } + /// We're processing a message we received + ProcessProject::Processing { future } => match future.poll(cx) { + Ready(Ok((group, welcome_id))) => { + if let Some(id) = welcome_id { + this.known_welcome_ids.insert(id); + } + this.state.set(ProcessState::Waiting); + Ready(Some(Ok(group))) + } + Ready(Err(e)) => Ready(Some(Err(e))), + Pending => Pending, + }, } } } @@ -190,17 +247,42 @@ impl<'a, C, Subscription> StreamConversations<'a, C, Subscription> where C: ScopedGroupClient + Clone, { - async fn process_streamed_welcome( - &mut self, - client: C, + async fn process_new_item( + known_welcome_ids: HashSet, + client: Arc, + item: Result>, + ) -> Result<(MlsGroup, Option)> { + use WelcomeOrGroup::*; + let provider = client.context().mls_provider()?; + match item? { + Welcome(w) => Self::on_welcome(&known_welcome_ids, client, &provider, w?).await, + Group(g) => { + todo!() + } + } + } + + // process a new welcome, returning the new welcome ID + async fn on_welcome( + known_welcome_ids: &HashSet, + client: Arc, provider: &XmtpOpenMlsProvider, welcome: WelcomeMessage, - ) -> Result, SubscribeError> { - let welcome_v1 = crate::client::extract_welcome_message(welcome)?; - if self.known_welcome_ids.contains(&(welcome_v1.id as i64)) { + ) -> Result<(MlsGroup, Option)> { + let WelcomeMessageV1 { + id, + ref created_ns, + ref installation_key, + ref data, + ref hpke_public_key, + } = crate::client::extract_welcome_message(welcome)?; + let id = id as i64; + + if known_welcome_ids.contains(&(id)) { let conn = provider.conn_ref(); - self.known_welcome_ids.insert(welcome_v1.id as i64); - let group = conn.find_group_by_welcome_id(welcome_v1.id as i64)?; + let group = conn + .find_group_by_welcome_id(id)? + .ok_or(NotFound::GroupByWelcome(id))?; tracing::info!( inbox_id = client.inbox_id(), group_id = hex::encode(&group.id), @@ -208,34 +290,39 @@ where "Loading existing group for welcome_id: {:?}", group.welcome_id ); - return Ok(MlsGroup::new(client.clone(), group.id, group.created_at_ns)); + return Ok(( + MlsGroup::new(Arc::unwrap_or_clone(client), group.id, group.created_at_ns), + Some(id), + )); } - let creation_result = retry_async!( + let c = &client; + let mls_group = retry_async!( Retry::default(), (async { tracing::info!( - installation_id = &welcome_v1.id, + installation_id = hex::encode(installation_key), + welcome_id = &id, "Trying to process streamed welcome" ); - let welcome_v1 = &welcome_v1; - client - .context + + (*client) + .context() .store() .transaction_async(provider, |provider| async move { MlsGroup::create_from_encrypted_welcome( - Arc::new(client.clone()), + Arc::clone(c), provider, - welcome_v1.hpke_public_key.as_slice(), - &welcome_v1.data, - welcome_v1.id as i64, + hpke_public_key.as_slice(), + data, + id, ) .await }) .await }) - ); + )?; - Ok(creation_result?) + Ok((mls_group, Some(id))) } }