Skip to content

Commit

Permalink
progress from friday
Browse files Browse the repository at this point in the history
  • Loading branch information
insipx committed Jan 6, 2025
1 parent fb4e7dd commit ccbcb3b
Show file tree
Hide file tree
Showing 3 changed files with 150 additions and 54 deletions.
4 changes: 1 addition & 3 deletions xmtp_mls/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -559,9 +559,7 @@ where
{
Some(id) => id,
None => {
return Err(ClientError::Storage(StorageError::NotFound(
NotFound::InboxIdForAddress(account_address),
)));
return Err(NotFound::InboxIdForAddress(account_address).into());
}
};

Expand Down
15 changes: 13 additions & 2 deletions xmtp_mls/src/subscriptions/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use xmtp_id::scw_verifier::SmartContractSignatureVerifier;
use xmtp_proto::{api_client::XmtpMlsStreams, xmtp::mls::api::v1::WelcomeMessage};

// mod stream_all;
// mod stream_conversations;
mod stream_conversations;

use crate::{
client::{extract_welcome_message, ClientError},
Expand All @@ -24,7 +24,7 @@ use crate::{
consent_record::StoredConsentRecord,
group::{ConversationType, GroupQueryArgs, StoredGroup},
group_message::StoredGroupMessage,
StorageError,
NotFound, StorageError,
},
Client, XmtpApi, XmtpOpenMlsProvider,
};
Expand Down Expand Up @@ -223,6 +223,13 @@ impl From<StoredGroup> for (Vec<u8>, MessagesStreamInfo) {
}
}

// TODO: REMOVE BEFORE MERGING
// TODO: REMOVE BEFORE MERGING
// TODO: REMOVE BEFORE MERGING
pub(self) mod temp {
pub(super) type Result<T> = std::result::Result<T, super::SubscribeError>;
}

#[derive(thiserror::Error, Debug)]
pub enum SubscribeError {
#[error("failed to start new messages stream {0}")]
Expand All @@ -231,6 +238,9 @@ pub enum SubscribeError {
Client(#[from] ClientError),
#[error(transparent)]
Group(#[from] GroupError),
#[error(transparent)]
NotFound(#[from] NotFound),
// TODO: Add this to `NotFound`
#[error("group message expected in database but is missing")]
GroupMessageNotFound,
#[error("processing group message in stream: {0}")]
Expand Down Expand Up @@ -258,6 +268,7 @@ impl RetryableError for SubscribeError {
Storage(e) => retryable!(e),
Api(e) => retryable!(e),
Decode(_) => false,
NotFound(e) => retryable!(e),
}
}
}
Expand Down
185 changes: 136 additions & 49 deletions xmtp_mls/src/subscriptions/stream_conversations.rs
Original file line number Diff line number Diff line change
@@ -1,26 +1,28 @@
use std::{collections::HashSet, marker::PhantomData, sync::Arc, task::Poll};
use std::{
collections::HashSet, future::Future, marker::PhantomData, pin::Pin,
sync::Arc, task::Poll,
};

use futures::{prelude::stream::Select, Stream};
use crate::{
groups::{scoped_client::ScopedGroupClient, MlsGroup},
storage::{group::ConversationType, DbConnection, NotFound},
Client, XmtpOpenMlsProvider,
};
use futures::{future::FutureExt, prelude::stream::Select, Stream};
use pin_project_lite::pin_project;
use tokio_stream::wrappers::BroadcastStream;
use xmtp_common::{retry_async, Retry};
use xmtp_id::scw_verifier::SmartContractSignatureVerifier;
use xmtp_proto::{
api_client::{trait_impls::XmtpApi, XmtpMlsStreams},
xmtp::mls::api::v1::WelcomeMessage,
};

use crate::{
groups::{scoped_client::ScopedGroupClient, MlsGroup},
storage::{group::ConversationType, DbConnection},
Client, XmtpOpenMlsProvider,
xmtp::mls::api::v1::{welcome_message::V1 as WelcomeMessageV1, WelcomeMessage},
};

use super::{LocalEvents, SubscribeError};
use super::{temp::Result, LocalEvents, SubscribeError};

enum WelcomeOrGroup<C> {
Group(Result<MlsGroup<C>, SubscribeError>),
Welcome(Result<WelcomeMessage, xmtp_proto::Error>),
Group(Result<MlsGroup<C>>),
Welcome(Result<WelcomeMessage>),
}

pin_project! {
Expand Down Expand Up @@ -85,7 +87,7 @@ impl<S, C> SubscriptionStream<S, C> {

impl<S, C> Stream for SubscriptionStream<S, C>
where
S: Stream<Item = Result<WelcomeMessage, xmtp_proto::Error>>,
S: Stream<Item = std::result::Result<WelcomeMessage, xmtp_proto::Error>>,
{
type Item = WelcomeOrGroup<C>;

Expand All @@ -97,7 +99,10 @@ where
let this = self.project();

match this.inner.poll_next(cx) {
Ready(Some(welcome)) => Ready(Some(WelcomeOrGroup::Welcome(welcome))),
Ready(Some(welcome)) => {
let welcome = welcome.map_err(SubscribeError::from);
Ready(Some(WelcomeOrGroup::Welcome(welcome)))
}
Pending => Pending,
Ready(None) => Ready(None),
}
Expand All @@ -109,7 +114,27 @@ pin_project! {
client: &'a C,
#[pin] inner: Subscription,
conversation_type: Option<ConversationType>,
known_welcome_ids: HashSet<i64>
known_welcome_ids: HashSet<i64>,
#[pin] state: ProcessState<'a, C>,
}
}

pin_project! {
#[project = ProcessProject]
enum ProcessState<'a, C> {
/// State where we are waiting on the next Message from the network
Waiting,
/// State where we are waiting on an IO/Network future to finish processing the current message
/// before moving on to the next one
Processing {
#[pin] future: Pin<Box<dyn Future<Output = Result< (MlsGroup<C>, Option<i64>) >> + 'a >>
}
}
}

impl<'a, C> Default for ProcessState<'a, C> {
fn default() -> Self {
ProcessState::Waiting
}
}

Expand All @@ -129,7 +154,7 @@ where
client: &'a Client<A, V>,
conversation_type: Option<ConversationType>,
conn: &DbConnection,
) -> Result<Self, SubscribeError> {
) -> Result<Self> {
let installation_key = client.installation_public_key();
let id_cursor = 0;
tracing::info!(
Expand All @@ -154,34 +179,66 @@ where
inner: stream,
known_welcome_ids,
conversation_type,
state: ProcessState::Waiting,
})
}
}

impl<'a, C, Subscription> Stream for StreamConversations<'a, C, Subscription>
where
C: ScopedGroupClient + Clone,
Subscription: Stream<Item = Result<WelcomeOrGroup<C>, SubscribeError>>,
Subscription: Stream<Item = Result<WelcomeOrGroup<C>>> + 'a,
{
type Item = Result<MlsGroup<C>, SubscribeError>;
type Item = Result<MlsGroup<C>>;

fn poll_next(
self: std::pin::Pin<&mut Self>,
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Option<Self::Item>> {
use std::task::Poll::*;
let this = self.project();
use ProcessState::*;
let mut this = self.as_mut().project();

match this.inner.poll_next(cx) {
Ready(Some(msg)) => {
todo!()
}
// stream ended
Ready(None) => Ready(None),
Pending => {
cx.waker().wake_by_ref();
Pending
match this.state.as_mut().project() {
ProcessProject::Waiting => {
match this.inner.poll_next(cx) {
Ready(Some(item)) => {
let future =
// need to clone client into Arc<> here b/c:
// otherwise the `'1` ref for `Pin<&mut Self>` in arg to `poll_next` needs to
// live as long as `'a` ref for `Client`.
// This is because we're boxing this future (i.e `Box<dyn Future + 'a>`).
// There maybe a way to avoid it, but we need to `Box<>` the type
// b/c there's no way to get the anonymous future type on the stack generated by an
// `async fn`. If we can somehow store `impl Trait` on a struct (or
// something similar), we could avoid the `Clone` + `Arc`ing.
Self::process_new_item(this.known_welcome_ids.clone(), Arc::new(this.client.clone()), item);

this.state.set(ProcessState::Processing {
future: future.boxed(),
});
Pending
}
// stream ended
Ready(None) => Ready(None),
Pending => {
cx.waker().wake_by_ref();
Pending
}
}
}
/// We're processing a message we received
ProcessProject::Processing { future } => match future.poll(cx) {
Ready(Ok((group, welcome_id))) => {
if let Some(id) = welcome_id {
this.known_welcome_ids.insert(id);
}
this.state.set(ProcessState::Waiting);
Ready(Some(Ok(group)))
}
Ready(Err(e)) => Ready(Some(Err(e))),
Pending => Pending,
},
}
}
}
Expand All @@ -190,52 +247,82 @@ impl<'a, C, Subscription> StreamConversations<'a, C, Subscription>
where
C: ScopedGroupClient + Clone,
{
async fn process_streamed_welcome(
&mut self,
client: C,
async fn process_new_item(
known_welcome_ids: HashSet<i64>,
client: Arc<C>,
item: Result<WelcomeOrGroup<C>>,
) -> Result<(MlsGroup<C>, Option<i64>)> {
use WelcomeOrGroup::*;
let provider = client.context().mls_provider()?;
match item? {
Welcome(w) => Self::on_welcome(&known_welcome_ids, client, &provider, w?).await,
Group(g) => {
todo!()
}
}
}

// process a new welcome, returning the new welcome ID
async fn on_welcome(
known_welcome_ids: &HashSet<i64>,
client: Arc<C>,
provider: &XmtpOpenMlsProvider,
welcome: WelcomeMessage,
) -> Result<MlsGroup<C>, SubscribeError> {
let welcome_v1 = crate::client::extract_welcome_message(welcome)?;
if self.known_welcome_ids.contains(&(welcome_v1.id as i64)) {
) -> Result<(MlsGroup<C>, Option<i64>)> {
let WelcomeMessageV1 {
id,
ref created_ns,
ref installation_key,
ref data,
ref hpke_public_key,
} = crate::client::extract_welcome_message(welcome)?;
let id = id as i64;

if known_welcome_ids.contains(&(id)) {
let conn = provider.conn_ref();
self.known_welcome_ids.insert(welcome_v1.id as i64);
let group = conn.find_group_by_welcome_id(welcome_v1.id as i64)?;
let group = conn
.find_group_by_welcome_id(id)?
.ok_or(NotFound::GroupByWelcome(id))?;
tracing::info!(
inbox_id = client.inbox_id(),
group_id = hex::encode(&group.id),
welcome_id = ?group.welcome_id,
"Loading existing group for welcome_id: {:?}",
group.welcome_id
);
return Ok(MlsGroup::new(client.clone(), group.id, group.created_at_ns));
return Ok((
MlsGroup::new(Arc::unwrap_or_clone(client), group.id, group.created_at_ns),
Some(id),
));
}

let creation_result = retry_async!(
let c = &client;
let mls_group = retry_async!(
Retry::default(),
(async {
tracing::info!(
installation_id = &welcome_v1.id,
installation_id = hex::encode(installation_key),
welcome_id = &id,
"Trying to process streamed welcome"
);
let welcome_v1 = &welcome_v1;
client
.context

(*client)
.context()
.store()
.transaction_async(provider, |provider| async move {
MlsGroup::create_from_encrypted_welcome(
Arc::new(client.clone()),
Arc::clone(c),
provider,
welcome_v1.hpke_public_key.as_slice(),
&welcome_v1.data,
welcome_v1.id as i64,
hpke_public_key.as_slice(),
data,
id,
)
.await
})
.await
})
);
)?;

Ok(creation_result?)
Ok((mls_group, Some(id)))
}
}

0 comments on commit ccbcb3b

Please sign in to comment.