diff --git a/Cargo.lock b/Cargo.lock index 04a6f822f..4c88bbeed 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -580,9 +580,9 @@ checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" [[package]] name = "bytes" -version = "1.8.0" +version = "1.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9ac0150caa2ae65ca5bd83f25c7de183dea78d4d366469f148435e2acfbad0da" +checksum = "325918d6fe32f23b19878fe4b34794ae41fc19ddbe53b10571a4874d44ffd39b" dependencies = [ "serde", ] @@ -7232,6 +7232,7 @@ version = "0.1.0" dependencies = [ "async-stream", "async-trait", + "bytes", "futures", "reqwest 0.12.9", "serde", diff --git a/xmtp_api_http/Cargo.toml b/xmtp_api_http/Cargo.toml index b26a414a9..dae2a490c 100644 --- a/xmtp_api_http/Cargo.toml +++ b/xmtp_api_http/Cargo.toml @@ -18,6 +18,7 @@ thiserror = "2.0" tokio = { workspace = true, features = ["sync", "rt", "macros"] } xmtp_proto = { path = "../xmtp_proto", features = ["proto_full"] } async-trait = "0.1" +bytes = "1.9" [dev-dependencies] xmtp_proto = { path = "../xmtp_proto", features = ["test-utils"] } diff --git a/xmtp_api_http/src/http_stream.rs b/xmtp_api_http/src/http_stream.rs new file mode 100644 index 000000000..0a5f83014 --- /dev/null +++ b/xmtp_api_http/src/http_stream.rs @@ -0,0 +1,129 @@ +//! Streams that work with HTTP POST requests + +use crate::util::GrpcResponse; +use futures::{ + stream::{self, Stream, StreamExt}, + Future, +}; +use reqwest::Response; +use serde::{de::DeserializeOwned, Deserialize, Serialize}; +use serde_json::Deserializer; +use std::pin::Pin; +use xmtp_proto::{Error, ErrorKind}; + +#[derive(Deserialize, Serialize, Debug)] +pub(crate) struct SubscriptionItem { + pub result: T, +} + +enum HttpPostStream +where + F: Future>, +{ + NotStarted(F), + // NotStarted(Box>>), + Started(Pin> + Unpin + Send>>), +} + +impl Stream for HttpPostStream +where + F: Future> + Unpin, +{ + type Item = Result; + + fn poll_next( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + use futures::task::Poll::*; + use HttpPostStream::*; + match self.as_mut().get_mut() { + NotStarted(ref mut f) => { + tracing::info!("Polling"); + let f = std::pin::pin!(f); + match f.poll(cx) { + Ready(response) => { + let s = response.unwrap().bytes_stream(); + self.set(Self::Started(Box::pin(s.boxed()))); + self.poll_next(cx) + } + Pending => { + // cx.waker().wake_by_ref(); + Pending + } + } + } + Started(s) => s.poll_next_unpin(cx), + } + } +} + +#[cfg(target_arch = "wasm32")] +pub fn create_grpc_stream< + T: Serialize + Send + 'static, + R: DeserializeOwned + Send + std::fmt::Debug + 'static, +>( + request: T, + endpoint: String, + http_client: reqwest::Client, +) -> stream::LocalBoxStream<'static, Result> { + create_grpc_stream_inner(request, endpoint, http_client).boxed_local() +} + +#[cfg(not(target_arch = "wasm32"))] +pub fn create_grpc_stream< + T: Serialize + Send + 'static, + R: DeserializeOwned + Send + std::fmt::Debug + 'static, +>( + request: T, + endpoint: String, + http_client: reqwest::Client, +) -> stream::BoxStream<'static, Result> { + create_grpc_stream_inner(request, endpoint, http_client).boxed() +} + +pub fn create_grpc_stream_inner< + T: Serialize + Send + 'static, + R: DeserializeOwned + Send + std::fmt::Debug + 'static, +>( + request: T, + endpoint: String, + http_client: reqwest::Client, +) -> impl Stream> { + let request = http_client.post(endpoint).json(&request).send(); + let http_stream = HttpPostStream::NotStarted(request); + + async_stream::stream! { + tracing::info!("spawning grpc http stream"); + let mut remaining = vec![]; + for await bytes in http_stream { + let bytes = bytes + .map_err(|e| Error::new(ErrorKind::SubscriptionUpdateError).with(e.to_string()))?; + let bytes = &[remaining.as_ref(), bytes.as_ref()].concat(); + let de = Deserializer::from_slice(bytes); + let mut stream = de.into_iter::>(); + 'messages: loop { + tracing::debug!("Waiting on next response ..."); + let response = stream.next(); + let res = match response { + Some(Ok(GrpcResponse::Ok(response))) => Ok(response), + Some(Ok(GrpcResponse::SubscriptionItem(item))) => Ok(item.result), + Some(Ok(GrpcResponse::Err(e))) => { + Err(Error::new(ErrorKind::MlsError).with(e.message)) + } + Some(Err(e)) => { + if e.is_eof() { + remaining = (&**bytes)[stream.byte_offset()..].to_vec(); + break 'messages; + } else { + Err(Error::new(ErrorKind::MlsError).with(e.to_string())) + } + } + Some(Ok(GrpcResponse::Empty {})) => continue 'messages, + None => break 'messages, + }; + yield res; + } + } + } +} diff --git a/xmtp_api_http/src/lib.rs b/xmtp_api_http/src/lib.rs index 80489fb3c..8a3f972c4 100755 --- a/xmtp_api_http/src/lib.rs +++ b/xmtp_api_http/src/lib.rs @@ -1,11 +1,13 @@ #![warn(clippy::unwrap_used)] pub mod constants; +mod http_stream; mod util; use futures::stream; +use http_stream::create_grpc_stream; use reqwest::header; -use util::{create_grpc_stream, handle_error}; +use util::handle_error; use xmtp_proto::api_client::{ClientWithMetadata, XmtpIdentityClient}; use xmtp_proto::xmtp::identity::api::v1::{ GetIdentityUpdatesRequest as GetIdentityUpdatesV2Request, diff --git a/xmtp_api_http/src/util.rs b/xmtp_api_http/src/util.rs index ab151c9f4..34c878c4a 100644 --- a/xmtp_api_http/src/util.rs +++ b/xmtp_api_http/src/util.rs @@ -1,9 +1,5 @@ -use futures::{ - stream::{self, StreamExt}, - Stream, -}; +use crate::http_stream::SubscriptionItem; use serde::{de::DeserializeOwned, Deserialize, Serialize}; -use serde_json::Deserializer; use std::io::Read; use xmtp_proto::{Error, ErrorKind}; @@ -23,11 +19,6 @@ pub(crate) struct ErrorResponse { details: Vec, } -#[derive(Deserialize, Serialize, Debug)] -pub(crate) struct SubscriptionItem { - pub result: T, -} - /// handle JSON response from gRPC, returning either /// the expected deserialized response object or a gRPC [`Error`] pub fn handle_error(reader: R) -> Result @@ -43,80 +34,6 @@ where } } -#[cfg(target_arch = "wasm32")] -pub fn create_grpc_stream< - T: Serialize + Send + 'static, - R: DeserializeOwned + Send + std::fmt::Debug + 'static, ->( - request: T, - endpoint: String, - http_client: reqwest::Client, -) -> stream::LocalBoxStream<'static, Result> { - create_grpc_stream_inner(request, endpoint, http_client).boxed_local() -} - -#[cfg(not(target_arch = "wasm32"))] -pub fn create_grpc_stream< - T: Serialize + Send + 'static, - R: DeserializeOwned + Send + std::fmt::Debug + 'static, ->( - request: T, - endpoint: String, - http_client: reqwest::Client, -) -> stream::BoxStream<'static, Result> { - create_grpc_stream_inner(request, endpoint, http_client).boxed() -} - -pub fn create_grpc_stream_inner< - T: Serialize + Send + 'static, - R: DeserializeOwned + Send + std::fmt::Debug + 'static, ->( - request: T, - endpoint: String, - http_client: reqwest::Client, -) -> impl Stream> { - async_stream::stream! { - tracing::info!("Spawning grpc http stream"); - let request = http_client - .post(endpoint) - .json(&request) - .send() - .await - .map_err(|e| Error::new(ErrorKind::MlsError).with(e))?; - tracing::debug!("Got Request, getting byte stream"); - let mut remaining = vec![]; - for await bytes in request.bytes_stream() { - let bytes = bytes - .map_err(|e| Error::new(ErrorKind::SubscriptionUpdateError).with(e.to_string()))?; - let bytes = &[remaining.as_ref(), bytes.as_ref()].concat(); - let de = Deserializer::from_slice(bytes); - let mut stream = de.into_iter::>(); - 'messages: loop { - tracing::debug!("Waiting on next response ..."); - let response = stream.next(); - let res = match response { - Some(Ok(GrpcResponse::Ok(response))) => Ok(response), - Some(Ok(GrpcResponse::SubscriptionItem(item))) => Ok(item.result), - Some(Ok(GrpcResponse::Err(e))) => { - Err(Error::new(ErrorKind::MlsError).with(e.message)) - } - Some(Err(e)) => { - if e.is_eof() { - remaining = (&**bytes)[stream.byte_offset()..].to_vec(); - break 'messages; - } else { - Err(Error::new(ErrorKind::MlsError).with(e.to_string())) - } - } - Some(Ok(GrpcResponse::Empty {})) => continue 'messages, - None => break 'messages, - }; - yield res; - } - } - } -} - #[cfg(feature = "test-utils")] #[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)] #[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))] diff --git a/xmtp_mls/src/subscriptions.rs b/xmtp_mls/src/subscriptions.rs index bac20877d..bd3c583f1 100644 --- a/xmtp_mls/src/subscriptions.rs +++ b/xmtp_mls/src/subscriptions.rs @@ -556,6 +556,7 @@ pub(crate) mod tests { let alice_group = alice .create_group(None, GroupMetadataOptions::default()) .unwrap(); + tracing::info!("Group Id = [{}]", hex::encode(&alice_group.group_id)); alice_group .add_members_by_inbox_id(&[bob.inbox_id()])