Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(wasm): unblock streams in the browser #1444

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

23 changes: 11 additions & 12 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,7 @@ ctor = "0.2"
ed25519 = "2.2.3"
ed25519-dalek = { version = "2.1.1", features = ["zeroize"] }
ethers = { version = "2.0", default-features = false }
futures = "0.3.30"
futures-core = "0.3.30"
futures = { version = "0.3.30", default-features = false }
getrandom = { version = "0.2", default-features = false }
hex = "0.4.3"
hkdf = "0.12.3"
Expand All @@ -61,16 +60,7 @@ thiserror = "2.0"
tls_codec = "0.4.1"
tokio = { version = "1.35.1", default-features = false }
uuid = "1.10"
wasm-timer = "0.2"
web-time = "1.1"
# Changing this version and rustls may potentially break the android build. Use Caution.
# Test with Android and Swift first.
# Its probably preferable to one day use https://github.com/rustls/rustls-platform-verifier
# Until then, always test agains iOS/Android after updating these dependencies & making a PR
# Related Issues:
# - https://github.com/seanmonstar/reqwest/issues/2159
# - https://github.com/hyperium/tonic/pull/1974
# - https://github.com/rustls/rustls-platform-verifier/issues/58
bincode = "1.3"
console_error_panic_hook = "0.1"
const_format = "0.2"
Expand All @@ -87,6 +77,14 @@ openssl = { version = "0.10", features = ["vendored"] }
openssl-sys = { version = "0.9", features = ["vendored"] }
parking_lot = "0.12.3"
sqlite-web = "0.0.1"
# Changing this version and rustls may potentially break the android build. Use Caution.
# Test with Android and Swift first.
# Its probably preferable to one day use https://github.com/rustls/rustls-platform-verifier
# Until then, always test agains iOS/Android after updating these dependencies & making a PR
# Related Issues:
# - https://github.com/seanmonstar/reqwest/issues/2159
# - https://github.com/hyperium/tonic/pull/1974
# - https://github.com/rustls/rustls-platform-verifier/issues/58
tonic = { version = "0.12", default-features = false }
tracing = { version = "0.1", features = ["log"] }
tracing-subscriber = { version = "0.3", default-features = false }
Expand All @@ -101,7 +99,8 @@ criterion = { version = "0.5", features = [
"html_reports",
"async_tokio",
]}
once_cell = "1.2"
once_cell = "1.2"
pin-project-lite = "0.2"

# Internal Crate Dependencies
xmtp_api_grpc = { path = "xmtp_api_grpc" }
Expand Down
4 changes: 4 additions & 0 deletions common/src/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,10 @@ pub fn rand_u64() -> u64 {
crypto_utils::rng().gen()
}

pub fn rand_i64() -> i64 {
crypto_utils::rng().gen()
}

#[cfg(not(target_arch = "wasm32"))]
pub fn tmp_path() -> String {
let db_name = crate::rand_string::<24>();
Expand Down
2 changes: 1 addition & 1 deletion xmtp_api_grpc/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ version.workspace = true
async-stream.workspace = true
async-trait = "0.1"
base64.workspace = true
futures.workspace = true
futures = { workspace = true, features = ["alloc"] }
hex.workspace = true
prost = { workspace = true, features = ["prost-derive"] }
tokio = { workspace = true, features = ["macros", "time"] }
Expand Down
5 changes: 3 additions & 2 deletions xmtp_api_http/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,17 @@ license.workspace = true
crate-type = ["cdylib", "rlib"]

[dependencies]
async-stream.workspace = true
futures = { workspace = true }
tracing.workspace = true
reqwest = { version = "0.12.5", features = ["json", "stream"] }
serde = { workspace = true }
serde_json = { workspace = true }
thiserror = "2.0"
thiserror.workspace = true
tokio = { workspace = true, features = ["sync", "rt", "macros"] }
xmtp_proto = { path = "../xmtp_proto", features = ["proto_full"] }
async-trait = "0.1"
bytes = "1.9"
pin-project-lite = "0.2.15"

[dev-dependencies]
xmtp_proto = { path = "../xmtp_proto", features = ["test-utils"] }
Expand Down
231 changes: 231 additions & 0 deletions xmtp_api_http/src/http_stream.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,231 @@
//! Streams that work with HTTP POST requests

use crate::util::GrpcResponse;
use futures::{
stream::{self, Stream, StreamExt},
Future,
};
use reqwest::Response;
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use serde_json::Deserializer;
use std::{marker::PhantomData, pin::Pin, task::Poll};
use xmtp_proto::{Error, ErrorKind};

#[derive(Deserialize, Serialize, Debug)]
pub(crate) struct SubscriptionItem<T> {
pub result: T,
}

#[cfg(target_arch = "wasm32")]
pub type BytesStream = stream::LocalBoxStream<'static, Result<bytes::Bytes, reqwest::Error>>;

// #[cfg(not(target_arch = "wasm32"))]
// pub type BytesStream = Pin<Box<dyn Stream<Item = Result<bytes::Bytes, reqwest::Error>> + Send>>;

#[cfg(not(target_arch = "wasm32"))]
pub type BytesStream = stream::BoxStream<'static, Result<bytes::Bytes, reqwest::Error>>;

pin_project_lite::pin_project! {
#[project = PostStreamProject]
enum HttpPostStream<F, R> {
NotStarted{#[pin] fut: F},
// `Reqwest::bytes_stream` returns `impl Stream` rather than a type generic,
// so we can't use a type generic here
// this makes wasm a bit tricky.
Started {
#[pin] http: BytesStream,
remaining: Vec<u8>,
_marker: PhantomData<R>,
},
}
}

impl<F, R> Stream for HttpPostStream<F, R>
where
F: Future<Output = Result<Response, reqwest::Error>>,
for<'de> R: Send + Deserialize<'de>,
{
type Item = Result<R, Error>;

fn poll_next(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Option<Self::Item>> {
use std::task::Poll::*;
match self.as_mut().project() {
PostStreamProject::NotStarted { fut } => match fut.poll(cx) {
Ready(response) => {
let s = response.unwrap().bytes_stream();
self.set(Self::started(s));
self.as_mut().poll_next(cx)
}
Pending => {
cx.waker().wake_by_ref();
Pending
}
},
PostStreamProject::Started {
ref mut http,
ref mut remaining,
..
} => {
let mut pinned = std::pin::pin!(http);
let next = pinned.as_mut().poll_next(cx);
Self::on_bytes(next, remaining, cx)
}
}
}
}

impl<F, R> HttpPostStream<F, R>
where
R: Send,
{
#[cfg(not(target_arch = "wasm32"))]
fn started(
http: impl Stream<Item = Result<bytes::Bytes, reqwest::Error>> + Send + 'static,
) -> Self {
Self::Started {
http: http.boxed(),
remaining: Vec::new(),
_marker: PhantomData,
}
}

#[cfg(target_arch = "wasm32")]
fn started(http: impl Stream<Item = Result<bytes::Bytes, reqwest::Error>> + 'static) -> Self {
Self::Started {
http: http.boxed_local(),
remaining: Vec::new(),
_marker: PhantomData,
}
}
}

impl<F, R> HttpPostStream<F, R>
where
F: Future<Output = Result<Response, reqwest::Error>>,
for<'de> R: Deserialize<'de> + DeserializeOwned + Send,
{
fn new(request: F) -> Self {
Self::NotStarted { fut: request }
}

fn on_bytes(
p: Poll<Option<Result<bytes::Bytes, reqwest::Error>>>,
remaining: &mut Vec<u8>,
cx: &mut std::task::Context<'_>,
) -> Poll<Option<<Self as Stream>::Item>> {
use futures::task::Poll::*;
match p {
Ready(Some(bytes)) => {
let bytes = bytes.map_err(|e| {
Error::new(ErrorKind::SubscriptionUpdateError).with(e.to_string())
})?;
let bytes = &[remaining.as_ref(), bytes.as_ref()].concat();
let de = Deserializer::from_slice(bytes);
let mut stream = de.into_iter::<GrpcResponse<R>>();
'messages: loop {
tracing::debug!("Waiting on next response ...");
let response = stream.next();
let res = match response {
Some(Ok(GrpcResponse::Ok(response))) => Ok(response),
Some(Ok(GrpcResponse::SubscriptionItem(item))) => Ok(item.result),
Some(Ok(GrpcResponse::Err(e))) => {
Err(Error::new(ErrorKind::MlsError).with(e.message))
}
Some(Err(e)) => {
if e.is_eof() {
*remaining = (&**bytes)[stream.byte_offset()..].to_vec();
return Pending;
} else {
Err(Error::new(ErrorKind::MlsError).with(e.to_string()))
}
}
Some(Ok(GrpcResponse::Empty {})) => continue 'messages,
None => return Ready(None),
};
return Ready(Some(res));
}
}
Ready(None) => Ready(None),
Pending => {
cx.waker().wake_by_ref();
Pending
}
}
}
}

#[cfg(not(target_arch = "wasm32"))]
impl<F, R> HttpPostStream<F, R>
where
F: Future<Output = Result<Response, reqwest::Error>> + Unpin,
for<'de> R: Deserialize<'de> + DeserializeOwned + Send,
{
/// Establish the initial HTTP Stream connection
fn establish(&mut self) -> () {
// we need to poll the future once to progress the future state &
// establish the initial POST request.
// It should always be pending
let noop_waker = futures::task::noop_waker();
let mut cx = std::task::Context::from_waker(&noop_waker);
// let mut this = Pin::new(self);
let mut this = Pin::new(self);
let _ = this.poll_next_unpin(&mut cx);
}
}

#[cfg(target_arch = "wasm32")]
impl<F, R> HttpPostStream<F, R>
where
F: Future<Output = Result<Response, reqwest::Error>>,
for<'de> R: Deserialize<'de> + DeserializeOwned + Send,
{
fn establish(&mut self) -> () {
// we need to poll the future once to progress the future state &
// establish the initial POST request.
// It should always be pending
let noop_waker = futures::task::noop_waker();
let mut cx = std::task::Context::from_waker(&noop_waker);
let mut this = unsafe { Pin::new_unchecked(self) };
let _ = this.poll_next_unpin(&mut cx);
}
}

#[cfg(target_arch = "wasm32")]
pub fn create_grpc_stream<T: Serialize + Send + 'static, R: DeserializeOwned + Send + 'static>(
request: T,
endpoint: String,
http_client: reqwest::Client,
) -> stream::LocalBoxStream<'static, Result<R, Error>> {
create_grpc_stream_inner(request, endpoint, http_client).boxed_local()
}

#[cfg(not(target_arch = "wasm32"))]
pub fn create_grpc_stream<T, R>(
request: T,
endpoint: String,
http_client: reqwest::Client,
) -> stream::BoxStream<'static, Result<R, Error>>
where
T: Serialize + 'static,
R: DeserializeOwned + Send + 'static,
{
create_grpc_stream_inner(request, endpoint, http_client).boxed()
}

fn create_grpc_stream_inner<T, R>(
request: T,
endpoint: String,
http_client: reqwest::Client,
) -> impl Stream<Item = Result<R, Error>>
where
T: Serialize + 'static,
R: DeserializeOwned + Send + 'static,
{
let request = http_client.post(endpoint).json(&request).send();
let mut http = HttpPostStream::new(request);
http.establish();
http
}
4 changes: 3 additions & 1 deletion xmtp_api_http/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
#![warn(clippy::unwrap_used)]

pub mod constants;
mod http_stream;
mod util;

use futures::stream;
use http_stream::create_grpc_stream;
use reqwest::header;
use util::{create_grpc_stream, handle_error};
use util::handle_error;
use xmtp_proto::api_client::{ClientWithMetadata, XmtpIdentityClient};
use xmtp_proto::xmtp::identity::api::v1::{
GetIdentityUpdatesRequest as GetIdentityUpdatesV2Request,
Expand Down
Loading