Skip to content

Commit

Permalink
feat(ic-http-gateway): add compatibility with http-body crate
Browse files Browse the repository at this point in the history
remove unnecessary lifetime parameter, make unsafe_allow_skip_verification a setter, extract path and query on behalf of consumer
  • Loading branch information
nathanosdev committed May 29, 2024
1 parent b5367a4 commit da864b6
Show file tree
Hide file tree
Showing 7 changed files with 89 additions and 27 deletions.
2 changes: 2 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ codegen-units = 1
thiserror = "1"
futures = "0.3"
http = "1"
http-body = "1"
bytes = "1"
base64 = "0.22"
lazy_static = "1"
serde = "1"
Expand Down
2 changes: 2 additions & 0 deletions packages/ic-http-gateway/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ homepage.workspace = true
thiserror.workspace = true
futures.workspace = true
http.workspace = true
http-body.workspace = true
bytes.workspace = true

ic-agent.workspace = true
ic-utils.workspace = true
Expand Down
13 changes: 10 additions & 3 deletions packages/ic-http-gateway/src/protocol/handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,16 @@ use ic_utils::{
};

fn convert_request(request: CanisterRequest) -> HttpGatewayResult<HttpRequest> {
let uri = request.uri();
let mut url = uri.path().to_string();
if let Some(query) = uri.query() {
url.push('?');
url.push_str(query);
}

Ok(HttpRequest {
method: request.method().to_string(),
url: request.uri().to_string(),
url,
headers: request
.headers()
.into_iter()
Expand All @@ -46,7 +53,7 @@ pub async fn process_request(
request: CanisterRequest,
canister_id: Principal,
allow_skip_verification: bool,
) -> HttpGatewayResult<HttpGatewayResponse<'_>> {
) -> HttpGatewayResult<HttpGatewayResponse> {
let http_request = convert_request(request)?;

let canister = HttpRequestCanister::create(agent, canister_id);
Expand Down Expand Up @@ -246,7 +253,7 @@ pub async fn process_request(
})
}

fn handle_agent_error<'a>(error: AgentError) -> HttpGatewayResult<CanisterResponse<'a>> {
fn handle_agent_error(error: AgentError) -> HttpGatewayResult<CanisterResponse> {
match error {
// Turn all `DestinationInvalid`s into 404
AgentError::CertifiedReject(RejectResponse {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,16 @@ impl<'a> HttpGatewayRequestBuilder<'a> {
}
}

pub fn unsafe_allow_skip_verification(mut self) -> Self {
self.allow_skip_verification = true;
pub fn unsafe_set_allow_skip_verification(
&mut self,
allow_skip_verification: bool,
) -> &mut Self {
self.allow_skip_verification = allow_skip_verification;

self
}

pub async fn send(self) -> HttpGatewayResult<HttpGatewayResponse<'a>> {
pub async fn send(self) -> HttpGatewayResult<HttpGatewayResponse> {
process_request(
self.args.agent,
self.args.request_args.canister_request,
Expand Down
68 changes: 56 additions & 12 deletions packages/ic-http-gateway/src/response/http_gateway_response.rs
Original file line number Diff line number Diff line change
@@ -1,61 +1,105 @@
use bytes::Bytes;
use futures::Stream;
use http::Response;
use http_body::{Body, Frame, SizeHint};
use ic_agent::AgentError;
use std::{
fmt::{Debug, Formatter},
pin::Pin,
task::{Context, Poll},
};

pub type CanisterResponse<'a> = Response<HttpGatewayResponseBody<'a>>;
pub type CanisterResponse = Response<HttpGatewayResponseBody>;

/// A response from the HTTP gateway.
#[derive(Debug)]
pub struct HttpGatewayResponse<'a> {
pub struct HttpGatewayResponse {
/// The certified response, excluding uncertified headers.
/// If response verification v1 is used, the original, uncertified headers are returned.
pub canister_response: CanisterResponse<'a>,
pub canister_response: CanisterResponse,

/// Additional metadata regarding the response.
pub metadata: HttpGatewayResponseMetadata,
}

/// The body of an HTTP gateway response.
#[derive(Debug)]
pub enum HttpGatewayResponseBody<'a> {
pub enum HttpGatewayResponseBody {
/// A byte array representing the response body.
Bytes(Vec<u8>),

/// A stream of response body chunks.
Stream(ResponseBodyStream<'a>),
Stream(ResponseBodyStream),
}

impl Body for HttpGatewayResponseBody {
type Data = Bytes;
type Error = AgentError;

fn poll_frame(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
match self.get_mut() {
HttpGatewayResponseBody::Bytes(bytes) => {
Poll::Ready(Some(Ok(Frame::data(Bytes::from(bytes.clone())))))
}
HttpGatewayResponseBody::Stream(stream) => Stream::poll_next(Pin::new(stream), cx),
}
}

fn is_end_stream(&self) -> bool {
match self {
HttpGatewayResponseBody::Bytes(_) => true,
HttpGatewayResponseBody::Stream(_) => false,
}
}

fn size_hint(&self) -> SizeHint {
match self {
HttpGatewayResponseBody::Bytes(bytes) => SizeHint::with_exact(bytes.len() as u64),
HttpGatewayResponseBody::Stream(stream) => {
let (lower, upper) = stream.size_hint();

let mut size_hint = SizeHint::new();
size_hint.set_lower(lower as u64);

if let Some(upper) = upper {
size_hint.set_upper(upper as u64);
}

size_hint
}
}
}
}

/// An item in a response body stream.
pub type ResponseBodyStreamItem = Result<Vec<u8>, AgentError>;
pub type ResponseBodyStreamItem = Result<Frame<Bytes>, AgentError>;

/// A stream of response body chunks.
pub struct ResponseBodyStream<'a> {
inner: Pin<Box<dyn Stream<Item = ResponseBodyStreamItem> + 'a>>,
pub struct ResponseBodyStream {
inner: Pin<Box<dyn Stream<Item = ResponseBodyStreamItem> + 'static>>,
}

// Trait bound added for cloning.
impl<'a> ResponseBodyStream<'a> {
pub fn new(stream: impl Stream<Item = ResponseBodyStreamItem> + 'a) -> Self {
impl ResponseBodyStream {
pub fn new(stream: impl Stream<Item = ResponseBodyStreamItem> + 'static) -> Self {
Self {
inner: Box::pin(stream),
}
}
}

// Debug implementation remains the same
impl<'a> Debug for ResponseBodyStream<'a> {
impl Debug for ResponseBodyStream {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
f.debug_struct("ResponseBodyStream").finish()
}
}

// Stream implementation remains the same
impl<'a> Stream for ResponseBodyStream<'a> {
impl Stream for ResponseBodyStream {
type Item = ResponseBodyStreamItem;

fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
Expand Down
20 changes: 11 additions & 9 deletions packages/ic-http-gateway/src/response/response_handler.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use crate::{HttpGatewayResponseBody, ResponseBodyStream};
use bytes::Bytes;
use futures::{stream, Stream, StreamExt, TryStreamExt};
use http_body::Frame;
use ic_agent::{Agent, AgentError};
use ic_utils::{
call::SyncCall,
Expand All @@ -20,10 +22,10 @@ static STREAM_CALLBACK_BUFFER: usize = 2;

pub type AgentResponseAny = AgentResponse<Token, HttpRequestStreamingCallbackAny>;

pub async fn get_body_and_streaming_body<'a, 'b>(
agent: &'a Agent,
response: &'b AgentResponseAny,
) -> Result<HttpGatewayResponseBody<'a>, AgentError> {
pub async fn get_body_and_streaming_body(
agent: &Agent,
response: &AgentResponseAny,
) -> Result<HttpGatewayResponseBody, AgentError> {
// if we already have the full body, we can return it early
let Some(StreamingStrategy::Callback(callback_strategy)) = response.streaming_strategy.clone()
else {
Expand Down Expand Up @@ -71,16 +73,16 @@ pub async fn get_body_and_streaming_body<'a, 'b>(
Ok(HttpGatewayResponseBody::Bytes(streamed_body))
}

fn create_body_stream<'a>(
fn create_body_stream(
agent: Agent,
callback: HttpRequestStreamingCallbackAny,
token: Option<Token>,
initial_body: Vec<u8>,
) -> ResponseBodyStream<'a> {
let chunks_stream =
create_stream(agent, callback, token).map(|chunk| chunk.map(|(body, _)| body));
) -> ResponseBodyStream {
let chunks_stream = create_stream(agent, callback, token)
.map(|chunk| chunk.map(|(body, _)| Frame::data(Bytes::from(body))));

let body_stream = stream::once(async move { Ok(initial_body) })
let body_stream = stream::once(async move { Ok(Frame::data(Bytes::from(initial_body))) })
.chain(chunks_stream)
.take(MAX_HTTP_REQUEST_STREAM_CALLBACK_CALL_COUNT)
.map(|x| async move { x })
Expand Down

0 comments on commit da864b6

Please sign in to comment.