Skip to content

Commit

Permalink
feat(wasm): unblock streams in the browser
Browse files Browse the repository at this point in the history
  • Loading branch information
insipx committed Dec 4, 2024
1 parent cf10919 commit e4af7cb
Show file tree
Hide file tree
Showing 6 changed files with 138 additions and 87 deletions.
5 changes: 3 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions xmtp_api_http/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ thiserror = "2.0"
tokio = { workspace = true, features = ["sync", "rt", "macros"] }
xmtp_proto = { path = "../xmtp_proto", features = ["proto_full"] }
async-trait = "0.1"
bytes = "1.9"

[dev-dependencies]
xmtp_proto = { path = "../xmtp_proto", features = ["test-utils"] }
Expand Down
129 changes: 129 additions & 0 deletions xmtp_api_http/src/http_stream.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
//! Streams that work with HTTP POST requests
use crate::util::GrpcResponse;
use futures::{
stream::{self, Stream, StreamExt},
Future,
};
use reqwest::Response;
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use serde_json::Deserializer;
use std::pin::Pin;
use xmtp_proto::{Error, ErrorKind};

#[derive(Deserialize, Serialize, Debug)]
pub(crate) struct SubscriptionItem<T> {
pub result: T,
}

enum HttpPostStream<F>
where
F: Future<Output = Result<Response, reqwest::Error>>,
{
NotStarted(F),
// NotStarted(Box<dyn Future<Output = Result<Response, Error>>>),
Started(Pin<Box<dyn Stream<Item = Result<bytes::Bytes, reqwest::Error>> + Unpin + Send>>),
}

impl<F> Stream for HttpPostStream<F>
where
F: Future<Output = Result<Response, reqwest::Error>> + Unpin,
{
type Item = Result<bytes::Bytes, reqwest::Error>;

fn poll_next(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Option<Self::Item>> {
use futures::task::Poll::*;
use HttpPostStream::*;
match self.as_mut().get_mut() {
NotStarted(ref mut f) => {
tracing::info!("Polling");
let f = std::pin::pin!(f);
match f.poll(cx) {
Ready(response) => {
let s = response.unwrap().bytes_stream();
self.set(Self::Started(Box::pin(s.boxed())));
self.poll_next(cx)
}
Pending => {
// cx.waker().wake_by_ref();
Pending
}
}
}
Started(s) => s.poll_next_unpin(cx),
}
}
}

#[cfg(target_arch = "wasm32")]
pub fn create_grpc_stream<
T: Serialize + Send + 'static,
R: DeserializeOwned + Send + std::fmt::Debug + 'static,
>(
request: T,
endpoint: String,
http_client: reqwest::Client,
) -> stream::LocalBoxStream<'static, Result<R, Error>> {
create_grpc_stream_inner(request, endpoint, http_client).boxed_local()
}

#[cfg(not(target_arch = "wasm32"))]
pub fn create_grpc_stream<
T: Serialize + Send + 'static,
R: DeserializeOwned + Send + std::fmt::Debug + 'static,
>(
request: T,
endpoint: String,
http_client: reqwest::Client,
) -> stream::BoxStream<'static, Result<R, Error>> {
create_grpc_stream_inner(request, endpoint, http_client).boxed()
}

pub fn create_grpc_stream_inner<
T: Serialize + Send + 'static,
R: DeserializeOwned + Send + std::fmt::Debug + 'static,
>(
request: T,
endpoint: String,
http_client: reqwest::Client,
) -> impl Stream<Item = Result<R, Error>> {
let request = http_client.post(endpoint).json(&request).send();
let http_stream = HttpPostStream::NotStarted(request);

async_stream::stream! {
tracing::info!("spawning grpc http stream");
let mut remaining = vec![];
for await bytes in http_stream {
let bytes = bytes
.map_err(|e| Error::new(ErrorKind::SubscriptionUpdateError).with(e.to_string()))?;
let bytes = &[remaining.as_ref(), bytes.as_ref()].concat();
let de = Deserializer::from_slice(bytes);
let mut stream = de.into_iter::<GrpcResponse<R>>();
'messages: loop {
tracing::debug!("Waiting on next response ...");
let response = stream.next();
let res = match response {
Some(Ok(GrpcResponse::Ok(response))) => Ok(response),
Some(Ok(GrpcResponse::SubscriptionItem(item))) => Ok(item.result),
Some(Ok(GrpcResponse::Err(e))) => {
Err(Error::new(ErrorKind::MlsError).with(e.message))
}
Some(Err(e)) => {
if e.is_eof() {
remaining = (&**bytes)[stream.byte_offset()..].to_vec();
break 'messages;
} else {
Err(Error::new(ErrorKind::MlsError).with(e.to_string()))
}
}
Some(Ok(GrpcResponse::Empty {})) => continue 'messages,
None => break 'messages,
};
yield res;
}
}
}
}
4 changes: 3 additions & 1 deletion xmtp_api_http/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
#![warn(clippy::unwrap_used)]

pub mod constants;
mod http_stream;
mod util;

use futures::stream;
use http_stream::create_grpc_stream;
use reqwest::header;
use util::{create_grpc_stream, handle_error};
use util::handle_error;
use xmtp_proto::api_client::{ClientWithMetadata, XmtpIdentityClient};
use xmtp_proto::xmtp::identity::api::v1::{
GetIdentityUpdatesRequest as GetIdentityUpdatesV2Request,
Expand Down
85 changes: 1 addition & 84 deletions xmtp_api_http/src/util.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,5 @@
use futures::{
stream::{self, StreamExt},
Stream,
};
use crate::http_stream::SubscriptionItem;
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use serde_json::Deserializer;
use std::io::Read;
use xmtp_proto::{Error, ErrorKind};

Expand All @@ -23,11 +19,6 @@ pub(crate) struct ErrorResponse {
details: Vec<String>,
}

#[derive(Deserialize, Serialize, Debug)]
pub(crate) struct SubscriptionItem<T> {
pub result: T,
}

/// handle JSON response from gRPC, returning either
/// the expected deserialized response object or a gRPC [`Error`]
pub fn handle_error<R: Read, T>(reader: R) -> Result<T, Error>
Expand All @@ -43,80 +34,6 @@ where
}
}

#[cfg(target_arch = "wasm32")]
pub fn create_grpc_stream<
T: Serialize + Send + 'static,
R: DeserializeOwned + Send + std::fmt::Debug + 'static,
>(
request: T,
endpoint: String,
http_client: reqwest::Client,
) -> stream::LocalBoxStream<'static, Result<R, Error>> {
create_grpc_stream_inner(request, endpoint, http_client).boxed_local()
}

#[cfg(not(target_arch = "wasm32"))]
pub fn create_grpc_stream<
T: Serialize + Send + 'static,
R: DeserializeOwned + Send + std::fmt::Debug + 'static,
>(
request: T,
endpoint: String,
http_client: reqwest::Client,
) -> stream::BoxStream<'static, Result<R, Error>> {
create_grpc_stream_inner(request, endpoint, http_client).boxed()
}

pub fn create_grpc_stream_inner<
T: Serialize + Send + 'static,
R: DeserializeOwned + Send + std::fmt::Debug + 'static,
>(
request: T,
endpoint: String,
http_client: reqwest::Client,
) -> impl Stream<Item = Result<R, Error>> {
async_stream::stream! {
tracing::info!("Spawning grpc http stream");
let request = http_client
.post(endpoint)
.json(&request)
.send()
.await
.map_err(|e| Error::new(ErrorKind::MlsError).with(e))?;
tracing::debug!("Got Request, getting byte stream");
let mut remaining = vec![];
for await bytes in request.bytes_stream() {
let bytes = bytes
.map_err(|e| Error::new(ErrorKind::SubscriptionUpdateError).with(e.to_string()))?;
let bytes = &[remaining.as_ref(), bytes.as_ref()].concat();
let de = Deserializer::from_slice(bytes);
let mut stream = de.into_iter::<GrpcResponse<R>>();
'messages: loop {
tracing::debug!("Waiting on next response ...");
let response = stream.next();
let res = match response {
Some(Ok(GrpcResponse::Ok(response))) => Ok(response),
Some(Ok(GrpcResponse::SubscriptionItem(item))) => Ok(item.result),
Some(Ok(GrpcResponse::Err(e))) => {
Err(Error::new(ErrorKind::MlsError).with(e.message))
}
Some(Err(e)) => {
if e.is_eof() {
remaining = (&**bytes)[stream.byte_offset()..].to_vec();
break 'messages;
} else {
Err(Error::new(ErrorKind::MlsError).with(e.to_string()))
}
}
Some(Ok(GrpcResponse::Empty {})) => continue 'messages,
None => break 'messages,
};
yield res;
}
}
}
}

#[cfg(feature = "test-utils")]
#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
Expand Down
1 change: 1 addition & 0 deletions xmtp_mls/src/subscriptions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -556,6 +556,7 @@ pub(crate) mod tests {
let alice_group = alice
.create_group(None, GroupMetadataOptions::default())
.unwrap();
tracing::info!("Group Id = [{}]", hex::encode(&alice_group.group_id));

alice_group
.add_members_by_inbox_id(&[bob.inbox_id()])
Expand Down

0 comments on commit e4af7cb

Please sign in to comment.