diff --git a/crates/interfaces/wick-interface-http/component.yaml b/crates/interfaces/wick-interface-http/component.yaml index 032c4f75..705b5e36 100644 --- a/crates/interfaces/wick-interface-http/component.yaml +++ b/crates/interfaces/wick-interface-http/component.yaml @@ -3,7 +3,7 @@ name: http kind: wick/types@v1 metadata: - version: 0.4.0 + version: 0.5.0 package: registry: host: registry.candle.dev diff --git a/crates/wick/wick-trigger-http/src/http/component_utils.rs b/crates/wick/wick-trigger-http/src/http/component_utils.rs index 4edea58d..831b9890 100644 --- a/crates/wick/wick-trigger-http/src/http/component_utils.rs +++ b/crates/wick/wick-trigger-http/src/http/component_utils.rs @@ -1,12 +1,13 @@ use std::collections::HashMap; +use std::sync::Arc; use futures::stream::StreamExt; use hyper::header::{CONTENT_LENGTH, CONTENT_TYPE}; use hyper::http::response::Builder; use hyper::http::{HeaderName, HeaderValue}; use hyper::{Body, Response, StatusCode}; +use parking_lot::Mutex; use serde_json::{Map, Value}; -use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver}; use tokio::sync::oneshot; use tracing::Span; use uuid::Uuid; @@ -15,6 +16,7 @@ use wick_interface_http::types::{self as wick_http}; use wick_packet::{ packets, Base64Bytes, + BoxStream, Entity, InherentData, Invocation, @@ -158,12 +160,11 @@ pub(super) async fn respond( let stream = stream.unwrap(); let builder = Response::builder(); - let (handle, response, mut body_stream) = split_stream(stream); + let (response, mut body_stream) = split_stream(stream); let response = match response.await { Ok(response) => response?, Err(e) => { - handle.abort(); return Ok( Builder::new() .status(StatusCode::INTERNAL_SERVER_ERROR) @@ -180,43 +181,31 @@ pub(super) async fn respond( .map_or(false, |v| v == "text/event-stream"); let res = if event_stream { - let (tx, rx) = unbounded_channel(); - let _output_handle = tokio::spawn(async move { - while let Some(p) = body_stream.recv().await { - if !p.has_data() { - continue; - } - match codec { - Codec::Json => { - let chunk = p - .decode::() - .map_err(|e| HttpError::Bytes(e.to_string())) - .map(|v| to_sse_string_bytes(&v)); - let _ = tx.send(chunk); - } - Codec::Raw => { - let chunk = p - .decode::() - .map_err(|e| HttpError::Bytes(e.to_string())) - .map(Into::into); - let _ = tx.send(chunk); - } - Codec::Text => { - let chunk = p - .decode::() - .map_err(|e| HttpError::Utf8Text(e.to_string())) - .map(Into::into); - let _ = tx.send(chunk); - } - Codec::FormData => unreachable!("FormData is not supported as a decoder for HTTP responses"), - } + let body_stream = body_stream.filter_map(move |p| async move { + if !p.has_data() { + return None; } + Some(match codec { + Codec::Json => p + .decode::() + .map_err(|e| HttpError::Bytes(e.to_string())) + .map(|v| to_sse_string_bytes(&v)), + Codec::Raw => p + .decode::() + .map_err(|e| HttpError::Bytes(e.to_string())) + .map(Into::into), + Codec::Text => p + .decode::() + .map_err(|e| HttpError::Utf8Text(e.to_string())) + .map(Into::into), + Codec::FormData => unreachable!("FormData is not supported as a decoder for HTTP responses"), + }) }); - let body = Body::wrap_stream(tokio_stream::wrappers::UnboundedReceiverStream::new(rx)); + let body = Body::wrap_stream(body_stream); builder.body(body).unwrap() } else { let mut body = bytes::BytesMut::new(); - while let Some(p) = body_stream.recv().await { + while let Some(p) = body_stream.next().await { if let PacketPayload::Err(e) = p.payload() { return Err(HttpError::OutputStream(p.port().to_owned(), e.msg().to_owned())); } @@ -241,52 +230,60 @@ pub(super) async fn respond( } fn split_stream( - mut stream: PacketStream, + stream: PacketStream, ) -> ( - tokio::task::JoinHandle<()>, oneshot::Receiver>, - UnboundedReceiver, + BoxStream, ) { - let (body_tx, body_rx) = unbounded_channel(); let (res_tx, res_rx) = oneshot::channel(); - let mut res_tx = Some(res_tx); - let handle = tokio::spawn(async move { - while let Some(packet) = stream.next().await { - match packet { + let res_tx = Arc::new(Mutex::new(Some(res_tx))); + + let body = stream.filter_map(move |p| { + let res_tx = Arc::clone(&res_tx); + async move { + match p { Ok(p) => { if p.port() == "response" { if p.is_done() { - continue; + return None; } - let Some(sender) = res_tx.take() else { + let Some(sender) = res_tx.lock().take() else { // we only respect the first packet to the response port. - continue; + return None; }; if let PacketPayload::Err(e) = p.payload() { let _ = sender.send(Err(HttpError::OutputStream(p.port().to_owned(), e.msg().to_owned()))); - break; + return None; } let response: Result = p .decode() .map_err(|e| HttpError::Deserialize("response".to_owned(), e.to_string())); let _ = sender.send(response); + return None; } else if p.port() == "body" { - let _ = body_tx.send(p); + return Some(p); } + if let Some(sender) = res_tx.lock().take() { + if let PacketPayload::Err(e) = p.payload { + error!(error=%e,"http:stream:error"); + let _ = sender.send(Err(HttpError::OperationError(e.to_string()))); + } + }; + None } Err(e) => { - if let Some(sender) = res_tx.take() { + if let Some(sender) = res_tx.lock().take() { let _ = sender.send(Err(HttpError::OperationError(e.to_string()))); - } - warn!(?e, "http:stream:error"); - break; + }; + error!(error=%e,"http:stream:error"); + None } } } }); - (handle, res_rx, body_rx) + (res_rx, body.boxed()) } fn to_sse_string_bytes(event: &wick_http::HttpEvent) -> Vec {