From 60255240fd26d176f64241403776377670b30adb Mon Sep 17 00:00:00 2001 From: NathanosDev Date: Thu, 13 Jun 2024 18:01:12 +0200 Subject: [PATCH] fix(ic-http-gateway): impl body trait correctly --- Cargo.lock | 21 ++++ Cargo.toml | 4 + examples/http-gateway/rust/Cargo.toml | 19 +++ examples/http-gateway/rust/src/main.rs | 116 ++++++++++++++++++ packages/ic-http-gateway/Cargo.toml | 1 + .../src/client/http_gateway_client.rs | 2 + .../ic-http-gateway/src/protocol/handler.rs | 18 ++- .../src/response/http_gateway_response.rs | 91 +------------- .../src/response/response_handler.rs | 11 +- .../ic-http-gateway/tests/custom_assets.rs | 22 ++-- 10 files changed, 203 insertions(+), 102 deletions(-) create mode 100644 examples/http-gateway/rust/Cargo.toml create mode 100644 examples/http-gateway/rust/src/main.rs diff --git a/Cargo.lock b/Cargo.lock index 05fe1ab..faf1024 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -927,12 +927,31 @@ dependencies = [ "serde_cbor", ] +[[package]] +name = "http_gateway_rust" +version = "0.0.0" +dependencies = [ + "http-body-util", + "hyper", + "hyper-util", + "ic-agent", + "ic-http-gateway", + "pocket-ic", + "tokio", +] + [[package]] name = "httparse" version = "1.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d897f394bad6a705d5f4104762e116a75639e470d80901eed05a860a95cb1904" +[[package]] +name = "httpdate" +version = "1.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9" + [[package]] name = "hyper" version = "1.3.1" @@ -946,6 +965,7 @@ dependencies = [ "http 1.1.0", "http-body", "httparse", + "httpdate", "itoa", "pin-project-lite", "smallvec", @@ -1145,6 +1165,7 @@ dependencies = [ "futures", "http 1.1.0", "http-body", + "http-body-util", "ic-agent", "ic-http-certification", "ic-response-verification", diff --git a/Cargo.toml b/Cargo.toml index e351acc..17fd875 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -3,6 +3,7 @@ resolver = "2" members = [ "examples/http-gateway/canister/src/custom_assets", + "examples/http-gateway/rust", "packages/ic-http-gateway", ] @@ -28,12 +29,15 @@ thiserror = "1" futures = "0.3" http = "1" http-body = "1" +http-body-util = "0.1" bytes = "1" base64 = "0.22" lazy_static = "1" serde = "1" serde_cbor = "0.11" tokio = { version = "1", features = ["full"] } +hyper = { version = "1", features = ["full"] } +hyper-util = "0.1" ic-cdk = "0.13" ic-cdk-macros = "0.13" diff --git a/examples/http-gateway/rust/Cargo.toml b/examples/http-gateway/rust/Cargo.toml new file mode 100644 index 0000000..645b07a --- /dev/null +++ b/examples/http-gateway/rust/Cargo.toml @@ -0,0 +1,19 @@ +[package] +name = "http_gateway_rust" +version.workspace = true +authors.workspace = true +edition.workspace = true +repository.workspace = true +homepage.workspace = true +license.workspace = true + +[dependencies] +tokio.workspace = true +hyper.workspace = true +hyper-util.workspace = true +http-body-util.workspace = true + +ic-http-gateway.workspace = true +ic-agent.workspace = true + +pocket-ic.workspace = true diff --git a/examples/http-gateway/rust/src/main.rs b/examples/http-gateway/rust/src/main.rs new file mode 100644 index 0000000..8589d62 --- /dev/null +++ b/examples/http-gateway/rust/src/main.rs @@ -0,0 +1,116 @@ +use http_body_util::BodyExt; +use hyper::{body::Incoming, server::conn::http2, service::service_fn, Request, Response}; +use hyper_util::rt::TokioIo; +use ic_agent::Agent; +use ic_http_gateway::{HttpGatewayClient, HttpGatewayRequestArgs, HttpGatewayResponseBody}; +use pocket_ic::PocketIcBuilder; +use std::{convert::Infallible, net::SocketAddr, path::PathBuf, sync::Arc}; +use tokio::{fs::File, io::AsyncReadExt, net::TcpListener, task}; + +pub async fn load_custom_assets_wasm() -> Vec { + load_wasm("http_gateway_canister_custom_assets").await +} + +async fn load_wasm(canister: &str) -> Vec { + let file_path = PathBuf::from(env!("CARGO_MANIFEST_DIR")) + .join("../../../.dfx/local/canisters") + .join(canister) + .join(format!("{}.wasm.gz", canister)); + + load_file(file_path).await +} + +async fn load_file(file_path: PathBuf) -> Vec { + let mut file = File::open(&file_path).await.unwrap(); + + let mut buffer = Vec::new(); + file.read_to_end(&mut buffer).await.unwrap(); + + buffer +} + +fn main() { + let rt = tokio::runtime::Runtime::new().unwrap(); + let wasm_bytes = rt.block_on(async { load_custom_assets_wasm().await }); + + let pic = PocketIcBuilder::new() + .with_nns_subnet() + .with_application_subnet() + .build(); + + let canister_id = pic.create_canister(); + pic.add_cycles(canister_id, 2_000_000_000_000); + pic.install_canister(canister_id, wasm_bytes, vec![], None); + + let url = pic.auto_progress(); + + let agent = Agent::builder().with_url(url).build().unwrap(); + rt.block_on(async { + agent.fetch_root_key().await.unwrap(); + }); + + let http_gateway = HttpGatewayClient::builder() + .with_agent(agent) + .build() + .unwrap(); + + rt.block_on(async { + let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); + let listener = TcpListener::bind(addr).await.unwrap(); + + println!("Listening on: {}", addr); + + loop { + let (stream, _) = listener.accept().await.unwrap(); + let io = TokioIo::new(stream); + + let http_gateway_clone = Arc::new(http_gateway.clone()); + + let service = service_fn(move |req: Request| { + let http_gateway_clone = Arc::clone(&http_gateway_clone); + + async move { + let canister_request = Request::builder().uri(req.uri()).method(req.method()); + let collected_req = req.collect().await.unwrap().to_bytes().to_vec(); + let canister_request = canister_request.body(collected_req).unwrap(); + + let gateway_response = http_gateway_clone + .request(HttpGatewayRequestArgs { + canister_id, + canister_request, + }) + .send() + .await; + + Ok::, Infallible>( + gateway_response.canister_response, + ) + } + }); + + let local = task::LocalSet::new(); + local + .run_until(async move { + if let Err(err) = http2::Builder::new(LocalExec) + .serve_connection(io, service) + .await + { + eprintln!("Error serving connection: {:?}", err); + } + }) + .await; + } + }); +} + +#[derive(Clone, Copy, Debug)] +struct LocalExec; + +impl hyper::rt::Executor for LocalExec +where + F: std::future::Future + 'static, +{ + fn execute(&self, fut: F) { + tokio::task::spawn_local(fut); + } +} diff --git a/packages/ic-http-gateway/Cargo.toml b/packages/ic-http-gateway/Cargo.toml index f1a474b..0a43812 100644 --- a/packages/ic-http-gateway/Cargo.toml +++ b/packages/ic-http-gateway/Cargo.toml @@ -24,6 +24,7 @@ thiserror.workspace = true futures.workspace = true http.workspace = true http-body.workspace = true +http-body-util.workspace = true bytes.workspace = true ic-agent.workspace = true diff --git a/packages/ic-http-gateway/src/client/http_gateway_client.rs b/packages/ic-http-gateway/src/client/http_gateway_client.rs index 405e9a4..71a4b75 100644 --- a/packages/ic-http-gateway/src/client/http_gateway_client.rs +++ b/packages/ic-http-gateway/src/client/http_gateway_client.rs @@ -4,10 +4,12 @@ use crate::{ }; use ic_agent::Agent; +#[derive(Clone)] pub struct HttpGatewayClientArgs { pub agent: Agent, } +#[derive(Clone)] pub struct HttpGatewayClient { agent: Agent, } diff --git a/packages/ic-http-gateway/src/protocol/handler.rs b/packages/ic-http-gateway/src/protocol/handler.rs index 98b4545..8602d64 100644 --- a/packages/ic-http-gateway/src/protocol/handler.rs +++ b/packages/ic-http-gateway/src/protocol/handler.rs @@ -6,6 +6,7 @@ use crate::{ }; use candid::Principal; use http::{Response, StatusCode}; +use http_body_util::{BodyExt, Either, Full}; use ic_agent::{ agent::{RejectCode, RejectResponse}, Agent, AgentError, @@ -18,7 +19,9 @@ use ic_utils::{ }; fn create_err_response(status_code: StatusCode, msg: &str) -> CanisterResponse { - let mut response = Response::new(HttpGatewayResponseBody::Bytes(msg.as_bytes().to_vec())); + let mut response = Response::new(HttpGatewayResponseBody::Right(Full::from( + msg.as_bytes().to_vec(), + ))); *response.status_mut() = status_code; response @@ -72,7 +75,7 @@ pub async fn process_request( metadata: HttpGatewayResponseMetadata { upgraded_to_update_call: false, response_verification_version: None, - internal_error: Some(e.into()), + internal_error: Some(e), }, } } @@ -176,7 +179,10 @@ pub async fn process_request( // strategy. Performing verification for those requests would required to join all the chunks // and this could cause memory issues and possibly create DOS attack vectors. match &response_body { - HttpGatewayResponseBody::Bytes(body) => { + Either::Right(body) => { + // this unwrap should never panic because `Either::Right` will always have a full body + let body = body.clone().collect().await.unwrap().to_bytes().to_vec(); + let validation_result = validate( agent, &canister_id, @@ -188,7 +194,7 @@ pub async fn process_request( .iter() .map(|HeaderField(k, v)| (k.to_string(), v.to_string())) .collect(), - body: body.to_owned(), + body, upgrade: None, }, allow_skip_verification, @@ -327,7 +333,7 @@ fn handle_agent_error(error: &AgentError) -> CanisterResponse { reject_code: RejectCode::DestinationInvalid, reject_message, .. - }) => create_err_response(StatusCode::NOT_FOUND, &reject_message), + }) => create_err_response(StatusCode::NOT_FOUND, reject_message), // If the result is a Replica error, returns the 500 code and message. There is no information // leak here because a user could use `dfx` to get the same reply. @@ -343,7 +349,7 @@ fn handle_agent_error(error: &AgentError) -> CanisterResponse { reject_code: RejectCode::DestinationInvalid, reject_message, .. - }) => create_err_response(StatusCode::NOT_FOUND, &reject_message), + }) => create_err_response(StatusCode::NOT_FOUND, reject_message), // If the result is a Replica error, returns the 500 code and message. There is no information // leak here because a user could use `dfx` to get the same reply. diff --git a/packages/ic-http-gateway/src/response/http_gateway_response.rs b/packages/ic-http-gateway/src/response/http_gateway_response.rs index a809381..8459cf6 100644 --- a/packages/ic-http-gateway/src/response/http_gateway_response.rs +++ b/packages/ic-http-gateway/src/response/http_gateway_response.rs @@ -1,20 +1,16 @@ use bytes::Bytes; -use futures::Stream; +use futures::stream::BoxStream; use http::Response; -use http_body::{Body, Frame, SizeHint}; +use http_body::Frame; +use http_body_util::{Either, Full, StreamBody}; use ic_agent::AgentError; -use std::{ - fmt::{Debug, Formatter}, - pin::Pin, - task::{Context, Poll}, -}; +use std::fmt::Debug; use crate::HttpGatewayError; pub type CanisterResponse = Response; /// A response from the HTTP gateway. -#[derive(Debug)] pub struct HttpGatewayResponse { /// The certified response, excluding uncertified headers. /// If response verification v1 is used, the original, uncertified headers are returned. @@ -39,84 +35,9 @@ pub struct HttpGatewayResponseMetadata { pub internal_error: Option, } -/// The body of an HTTP gateway response. -#[derive(Debug)] -pub enum HttpGatewayResponseBody { - /// A byte array representing the response body. - Bytes(Vec), +pub type HttpGatewayResponseBody = Either>; - /// A stream of response body chunks. - Stream(ResponseBodyStream), -} - -impl Body for HttpGatewayResponseBody { - type Data = Bytes; - type Error = AgentError; - - fn poll_frame( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll, Self::Error>>> { - match self.get_mut() { - HttpGatewayResponseBody::Bytes(bytes) => { - Poll::Ready(Some(Ok(Frame::data(Bytes::from(bytes.clone()))))) - } - HttpGatewayResponseBody::Stream(stream) => Stream::poll_next(Pin::new(stream), cx), - } - } - - fn is_end_stream(&self) -> bool { - match self { - HttpGatewayResponseBody::Bytes(_) => true, - HttpGatewayResponseBody::Stream(_) => false, - } - } - - fn size_hint(&self) -> SizeHint { - match self { - HttpGatewayResponseBody::Bytes(bytes) => SizeHint::with_exact(bytes.len() as u64), - HttpGatewayResponseBody::Stream(stream) => { - let (lower, upper) = stream.size_hint(); - - let mut size_hint = SizeHint::new(); - size_hint.set_lower(lower as u64); - - if let Some(upper) = upper { - size_hint.set_upper(upper as u64); - } - - size_hint - } - } - } -} +pub type ResponseBodyStream = StreamBody>; /// An item in a response body stream. pub type ResponseBodyStreamItem = Result, AgentError>; - -/// A stream of response body chunks. -pub struct ResponseBodyStream { - inner: Pin + 'static>>, -} - -impl ResponseBodyStream { - pub fn new(stream: impl Stream + 'static) -> Self { - Self { - inner: Box::pin(stream), - } - } -} - -impl Debug for ResponseBodyStream { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - f.debug_struct("ResponseBodyStream").finish() - } -} - -impl Stream for ResponseBodyStream { - type Item = ResponseBodyStreamItem; - - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.inner.as_mut().poll_next(cx) - } -} diff --git a/packages/ic-http-gateway/src/response/response_handler.rs b/packages/ic-http-gateway/src/response/response_handler.rs index a6f8711..04f0ecc 100644 --- a/packages/ic-http-gateway/src/response/response_handler.rs +++ b/packages/ic-http-gateway/src/response/response_handler.rs @@ -2,6 +2,7 @@ use crate::{HttpGatewayResponseBody, ResponseBodyStream}; use bytes::Bytes; use futures::{stream, Stream, StreamExt, TryStreamExt}; use http_body::Frame; +use http_body_util::Full; use ic_agent::{Agent, AgentError}; use ic_utils::{ call::SyncCall, @@ -29,7 +30,9 @@ pub async fn get_body_and_streaming_body( // if we already have the full body, we can return it early let Some(StreamingStrategy::Callback(callback_strategy)) = response.streaming_strategy.clone() else { - return Ok(HttpGatewayResponseBody::Bytes(response.body.clone())); + return Ok(HttpGatewayResponseBody::Right(Full::from( + response.body.clone(), + ))); }; let (streamed_body, token) = create_stream( @@ -64,13 +67,13 @@ pub async fn get_body_and_streaming_body( streamed_body, ); - return Ok(HttpGatewayResponseBody::Stream(body_stream)); + return Ok(HttpGatewayResponseBody::Left(body_stream)); }; // if we no longer have a token at this point, // we were able to collect the response within the allow certified callback limit, // return this collected response as a standard response body so it will be verified - Ok(HttpGatewayResponseBody::Bytes(streamed_body)) + Ok(HttpGatewayResponseBody::Right(Full::from(streamed_body))) } fn create_body_stream( @@ -88,7 +91,7 @@ fn create_body_stream( .map(|x| async move { x }) .buffered(STREAM_CALLBACK_BUFFER); - ResponseBodyStream::new(body_stream) + ResponseBodyStream::new(Box::pin(body_stream)) } fn create_stream( diff --git a/packages/ic-http-gateway/tests/custom_assets.rs b/packages/ic-http-gateway/tests/custom_assets.rs index ecaaa11..0eccfa5 100644 --- a/packages/ic-http-gateway/tests/custom_assets.rs +++ b/packages/ic-http-gateway/tests/custom_assets.rs @@ -1,8 +1,7 @@ use http::Request; +use http_body_util::BodyExt; use ic_agent::Agent; -use ic_http_gateway::{ - HttpGatewayClient, HttpGatewayRequestArgs, HttpGatewayResponseBody, HttpGatewayResponseMetadata, -}; +use ic_http_gateway::{HttpGatewayClient, HttpGatewayRequestArgs, HttpGatewayResponseMetadata}; use pocket_ic::PocketIcBuilder; mod utils; @@ -68,10 +67,19 @@ fn test_custom_assets_index_html() { ("content-type", "text/html"), ] ); - matches!( - response.canister_response.body(), - HttpGatewayResponseBody::Bytes(body) if body == index_html - ); + + rt.block_on(async { + let body = response + .canister_response + .into_body() + .collect() + .await + .unwrap() + .to_bytes() + .to_vec(); + + assert_eq!(body, index_html); + }); assert_response_metadata( response.metadata,