Skip to content

Commit

Permalink
feat(wasm): unblock streams in the browser
Browse files Browse the repository at this point in the history
  • Loading branch information
insipx committed Jan 6, 2025
1 parent 7fe4b38 commit fb4e7dd
Show file tree
Hide file tree
Showing 14 changed files with 714 additions and 227 deletions.
4 changes: 3 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

23 changes: 11 additions & 12 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,7 @@ ctor = "0.2"
ed25519 = "2.2.3"
ed25519-dalek = { version = "2.1.1", features = ["zeroize"] }
ethers = { version = "2.0", default-features = false }
futures = "0.3.30"
futures-core = "0.3.30"
futures = { version = "0.3.30", default-features = false }
getrandom = { version = "0.2", default-features = false }
hex = "0.4.3"
hkdf = "0.12.3"
Expand All @@ -62,16 +61,7 @@ tls_codec = "0.4.1"
tokio = { version = "1.35.1", default-features = false }
uuid = "1.10"
vergen-git2 = "1.0.2"
wasm-timer = "0.2"
web-time = "1.1"
# Changing this version and rustls may potentially break the android build. Use Caution.
# Test with Android and Swift first.
# Its probably preferable to one day use https://github.com/rustls/rustls-platform-verifier
# Until then, always test agains iOS/Android after updating these dependencies & making a PR
# Related Issues:
# - https://github.com/seanmonstar/reqwest/issues/2159
# - https://github.com/hyperium/tonic/pull/1974
# - https://github.com/rustls/rustls-platform-verifier/issues/58
bincode = "1.3"
console_error_panic_hook = "0.1"
const_format = "0.2"
Expand All @@ -88,6 +78,14 @@ openssl = { version = "0.10", features = ["vendored"] }
openssl-sys = { version = "0.9", features = ["vendored"] }
parking_lot = "0.12.3"
sqlite-web = "0.0.1"
# Changing this version and rustls may potentially break the android build. Use Caution.
# Test with Android and Swift first.
# Its probably preferable to one day use https://github.com/rustls/rustls-platform-verifier
# Until then, always test agains iOS/Android after updating these dependencies & making a PR
# Related Issues:
# - https://github.com/seanmonstar/reqwest/issues/2159
# - https://github.com/hyperium/tonic/pull/1974
# - https://github.com/rustls/rustls-platform-verifier/issues/58
tonic = { version = "0.12", default-features = false }
tracing = { version = "0.1", features = ["log"] }
tracing-subscriber = { version = "0.3", default-features = false }
Expand All @@ -102,7 +100,8 @@ criterion = { version = "0.5", features = [
"html_reports",
"async_tokio",
]}
once_cell = "1.2"
once_cell = "1.2"
pin-project-lite = "0.2"

# Internal Crate Dependencies
xmtp_api_grpc = { path = "xmtp_api_grpc" }
Expand Down
4 changes: 4 additions & 0 deletions common/src/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,10 @@ pub fn rand_u64() -> u64 {
crypto_utils::rng().gen()
}

pub fn rand_i64() -> i64 {
crypto_utils::rng().gen()
}

#[cfg(not(target_arch = "wasm32"))]
pub fn tmp_path() -> String {
let db_name = crate::rand_string::<24>();
Expand Down
2 changes: 1 addition & 1 deletion xmtp_api_grpc/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ version.workspace = true
async-stream.workspace = true
async-trait = "0.1"
base64.workspace = true
futures.workspace = true
futures = { workspace = true, features = ["alloc"] }
hex.workspace = true
prost = { workspace = true, features = ["prost-derive"] }
tokio = { workspace = true, features = ["macros", "time"] }
Expand Down
5 changes: 3 additions & 2 deletions xmtp_api_http/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,17 @@ license.workspace = true
crate-type = ["cdylib", "rlib"]

[dependencies]
async-stream.workspace = true
futures = { workspace = true }
tracing.workspace = true
reqwest = { version = "0.12.5", features = ["json", "stream"] }
serde = { workspace = true }
serde_json = { workspace = true }
thiserror = "2.0"
thiserror.workspace = true
tokio = { workspace = true, features = ["sync", "rt", "macros"] }
xmtp_proto = { path = "../xmtp_proto", features = ["proto_full"] }
async-trait = "0.1"
bytes = "1.9"
pin-project-lite = "0.2.15"

[dev-dependencies]
xmtp_proto = { path = "../xmtp_proto", features = ["test-utils"] }
Expand Down
231 changes: 231 additions & 0 deletions xmtp_api_http/src/http_stream.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,231 @@
//! Streams that work with HTTP POST requests
use crate::util::GrpcResponse;
use futures::{
stream::{self, Stream, StreamExt},
Future,
};
use reqwest::Response;
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use serde_json::Deserializer;
use std::{marker::PhantomData, pin::Pin, task::Poll};
use xmtp_proto::{Error, ErrorKind};

#[derive(Deserialize, Serialize, Debug)]
pub(crate) struct SubscriptionItem<T> {
pub result: T,
}

#[cfg(target_arch = "wasm32")]
pub type BytesStream = stream::LocalBoxStream<'static, Result<bytes::Bytes, reqwest::Error>>;

// #[cfg(not(target_arch = "wasm32"))]
// pub type BytesStream = Pin<Box<dyn Stream<Item = Result<bytes::Bytes, reqwest::Error>> + Send>>;

#[cfg(not(target_arch = "wasm32"))]
pub type BytesStream = stream::BoxStream<'static, Result<bytes::Bytes, reqwest::Error>>;

pin_project_lite::pin_project! {
#[project = PostStreamProject]
enum HttpPostStream<F, R> {
NotStarted{#[pin] fut: F},
// `Reqwest::bytes_stream` returns `impl Stream` rather than a type generic,
// so we can't use a type generic here
// this makes wasm a bit tricky.
Started {
#[pin] http: BytesStream,
remaining: Vec<u8>,
_marker: PhantomData<R>,
},
}
}

impl<F, R> Stream for HttpPostStream<F, R>
where
F: Future<Output = Result<Response, reqwest::Error>>,
for<'de> R: Send + Deserialize<'de>,
{
type Item = Result<R, Error>;

fn poll_next(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Option<Self::Item>> {
use std::task::Poll::*;
match self.as_mut().project() {
PostStreamProject::NotStarted { fut } => match fut.poll(cx) {
Ready(response) => {
let s = response.unwrap().bytes_stream();
self.set(Self::started(s));
self.as_mut().poll_next(cx)
}
Pending => {
cx.waker().wake_by_ref();
Pending
}
},
PostStreamProject::Started {
ref mut http,
ref mut remaining,
..
} => {
let mut pinned = std::pin::pin!(http);
let next = pinned.as_mut().poll_next(cx);
Self::on_bytes(next, remaining, cx)
}
}
}
}

impl<F, R> HttpPostStream<F, R>
where
R: Send,
{
#[cfg(not(target_arch = "wasm32"))]
fn started(
http: impl Stream<Item = Result<bytes::Bytes, reqwest::Error>> + Send + 'static,
) -> Self {
Self::Started {
http: http.boxed(),
remaining: Vec::new(),
_marker: PhantomData,
}
}

#[cfg(target_arch = "wasm32")]
fn started(http: impl Stream<Item = Result<bytes::Bytes, reqwest::Error>> + 'static) -> Self {
Self::Started {
http: http.boxed_local(),
remaining: Vec::new(),
_marker: PhantomData,
}
}
}

impl<F, R> HttpPostStream<F, R>
where
F: Future<Output = Result<Response, reqwest::Error>>,
for<'de> R: Deserialize<'de> + DeserializeOwned + Send,
{
fn new(request: F) -> Self {
Self::NotStarted { fut: request }
}

fn on_bytes(
p: Poll<Option<Result<bytes::Bytes, reqwest::Error>>>,
remaining: &mut Vec<u8>,
cx: &mut std::task::Context<'_>,
) -> Poll<Option<<Self as Stream>::Item>> {
use futures::task::Poll::*;
match p {
Ready(Some(bytes)) => {
let bytes = bytes.map_err(|e| {
Error::new(ErrorKind::SubscriptionUpdateError).with(e.to_string())
})?;
let bytes = &[remaining.as_ref(), bytes.as_ref()].concat();
let de = Deserializer::from_slice(bytes);
let mut stream = de.into_iter::<GrpcResponse<R>>();
'messages: loop {
tracing::debug!("Waiting on next response ...");
let response = stream.next();
let res = match response {
Some(Ok(GrpcResponse::Ok(response))) => Ok(response),
Some(Ok(GrpcResponse::SubscriptionItem(item))) => Ok(item.result),
Some(Ok(GrpcResponse::Err(e))) => {
Err(Error::new(ErrorKind::MlsError).with(e.message))
}
Some(Err(e)) => {
if e.is_eof() {
*remaining = (&**bytes)[stream.byte_offset()..].to_vec();
return Pending;
} else {
Err(Error::new(ErrorKind::MlsError).with(e.to_string()))
}
}
Some(Ok(GrpcResponse::Empty {})) => continue 'messages,
None => return Ready(None),
};
return Ready(Some(res));
}
}
Ready(None) => Ready(None),
Pending => {
cx.waker().wake_by_ref();
Pending
}
}
}
}

#[cfg(not(target_arch = "wasm32"))]
impl<F, R> HttpPostStream<F, R>
where
F: Future<Output = Result<Response, reqwest::Error>> + Unpin,
for<'de> R: Deserialize<'de> + DeserializeOwned + Send,
{
/// Establish the initial HTTP Stream connection
fn establish(&mut self) -> () {
// we need to poll the future once to progress the future state &
// establish the initial POST request.
// It should always be pending
let noop_waker = futures::task::noop_waker();
let mut cx = std::task::Context::from_waker(&noop_waker);
// let mut this = Pin::new(self);
let mut this = Pin::new(self);
let _ = this.poll_next_unpin(&mut cx);
}
}

#[cfg(target_arch = "wasm32")]
impl<F, R> HttpPostStream<F, R>
where
F: Future<Output = Result<Response, reqwest::Error>>,
for<'de> R: Deserialize<'de> + DeserializeOwned + Send,
{
fn establish(&mut self) -> () {
// we need to poll the future once to progress the future state &
// establish the initial POST request.
// It should always be pending
let noop_waker = futures::task::noop_waker();
let mut cx = std::task::Context::from_waker(&noop_waker);
let mut this = unsafe { Pin::new_unchecked(self) };
let _ = this.poll_next_unpin(&mut cx);
}
}

#[cfg(target_arch = "wasm32")]
pub fn create_grpc_stream<T: Serialize + Send + 'static, R: DeserializeOwned + Send + 'static>(
request: T,
endpoint: String,
http_client: reqwest::Client,
) -> stream::LocalBoxStream<'static, Result<R, Error>> {
create_grpc_stream_inner(request, endpoint, http_client).boxed_local()
}

#[cfg(not(target_arch = "wasm32"))]
pub fn create_grpc_stream<T, R>(
request: T,
endpoint: String,
http_client: reqwest::Client,
) -> stream::BoxStream<'static, Result<R, Error>>
where
T: Serialize + 'static,
R: DeserializeOwned + Send + 'static,
{
create_grpc_stream_inner(request, endpoint, http_client).boxed()
}

fn create_grpc_stream_inner<T, R>(
request: T,
endpoint: String,
http_client: reqwest::Client,
) -> impl Stream<Item = Result<R, Error>>
where
T: Serialize + 'static,
R: DeserializeOwned + Send + 'static,
{
let request = http_client.post(endpoint).json(&request).send();
let mut http = HttpPostStream::new(request);
http.establish();
http
}
4 changes: 3 additions & 1 deletion xmtp_api_http/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
#![warn(clippy::unwrap_used)]

pub mod constants;
mod http_stream;
mod util;

use futures::stream;
use http_stream::create_grpc_stream;
use reqwest::header;
use util::{create_grpc_stream, handle_error};
use util::handle_error;
use xmtp_proto::api_client::{ClientWithMetadata, XmtpIdentityClient};
use xmtp_proto::xmtp::identity::api::v1::{
GetIdentityUpdatesRequest as GetIdentityUpdatesV2Request,
Expand Down
Loading

0 comments on commit fb4e7dd

Please sign in to comment.