diff --git a/Cargo.lock b/Cargo.lock index 40e2f8c7..900918e8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -6370,6 +6370,7 @@ dependencies = [ "chrono", "cron", "derive_builder", + "eventsource-stream", "flow-component", "flow-graph", "flow-graph-interpreter", diff --git a/crates/components/wick-http-client/src/component.rs b/crates/components/wick-http-client/src/component.rs index 30690fac..da0e4496 100644 --- a/crates/components/wick-http-client/src/component.rs +++ b/crates/components/wick-http-client/src/component.rs @@ -16,7 +16,7 @@ use wick_config::config::components::{ HttpClientOperationDefinition, OperationConfig, }; -use wick_config::config::{Codec, HttpMethod, LiquidJsonConfig, Metadata, UrlResource}; +use wick_config::config::{Codec, HttpEvent, HttpMethod, LiquidJsonConfig, Metadata, UrlResource}; use wick_config::{ConfigValidation, Resolver}; use wick_interface_types::{ComponentSignature, OperationSignatures}; use wick_packet::{Base64Bytes, FluxChannel, Invocation, Observer, Packet, PacketSender, PacketStream, RuntimeConfig}; @@ -355,18 +355,6 @@ async fn handle( Ok(()) } -#[derive(Debug, serde::Serialize, serde::Deserialize, Clone)] -struct WickEvent { - /// The event name if given - event: String, - /// The event data - data: String, - /// The event id if given - id: String, - /// Retry duration if given - retry: Option, -} - fn output_task( span: Span, codec: Codec, @@ -380,12 +368,7 @@ fn output_task( while let Some(event) = stream.next().await { match event { Ok(event) => { - let wick_event = WickEvent { - event: event.event, - data: event.data, - id: event.id, - retry: event.retry, - }; + let wick_event = HttpEvent::new(Some(event.event), event.data, Some(event.id), event.retry); span.in_scope(|| debug!("{} {}", format!("{:?}", wick_event), "http:client:response_body")); let _ = tx.send(Packet::encode("body", wick_event)); } @@ -704,12 +687,10 @@ mod test { let (app_config, component_config) = get_config(); let comp = get_component(&app_config, component_config); - // Simulate an event stream let event_stream = "data: {\"id\":\"1\",\"object\":\"event1\"}\n\n\ data: {\"id\":\"2\",\"object\":\"event2\"}\n\n"; let packets = packet_stream!(("input", event_stream)); - // Replace "event_stream_op" with the actual operation id for event stream let invocation = Invocation::test( "test_event_stream", Entity::local("event_stream_op"), @@ -726,8 +707,10 @@ mod test { let packets = stream.into_iter().collect::, _>>()?; for packet in packets { if packet.port() == "body" { - let response: WickEvent = packet.decode().unwrap(); - assert!(response.id == "1" && response.event == "event1" || response.id == "2" && response.event == "event2"); + let response: HttpEvent = packet.decode().unwrap(); + let response_id = response.get_id().as_ref().unwrap(); + let response_event = response.get_event().as_ref().unwrap(); + assert!(response_id == "1" && response_event == "event1" || response_id == "2" && response_event == "event2"); } else { let response: HttpResponse = packet.decode().unwrap(); assert_eq!(response.version, HttpVersion::Http11); diff --git a/crates/wick/wick-config/src/config/common.rs b/crates/wick/wick-config/src/config/common.rs index 2f9d9083..c1e0bb53 100644 --- a/crates/wick/wick-config/src/config/common.rs +++ b/crates/wick/wick-config/src/config/common.rs @@ -29,7 +29,7 @@ pub use self::error_behavior::ErrorBehavior; pub use self::exposed_resources::{ExposedVolume, ExposedVolumeBuilder}; pub use self::glob::Glob; pub use self::host_definition::{HostConfig, HostConfigBuilder, HttpConfig, HttpConfigBuilder}; -pub use self::http::{Codec, HttpMethod}; +pub use self::http::{Codec, HttpEvent, HttpMethod}; pub use self::import_definition::ImportDefinition; pub use self::interface::InterfaceDefinition; pub use self::liquid_json_config::LiquidJsonConfig; diff --git a/crates/wick/wick-config/src/config/common/http.rs b/crates/wick/wick-config/src/config/common/http.rs index 1ee415f2..66951828 100644 --- a/crates/wick/wick-config/src/config/common/http.rs +++ b/crates/wick/wick-config/src/config/common/http.rs @@ -1,3 +1,5 @@ +use std::fmt::Write; + #[derive(Debug, Clone, Copy, PartialEq, serde::Serialize)] /// Supported HTTP methods #[serde(rename_all = "kebab-case")] @@ -30,6 +32,76 @@ impl Default for Codec { } } +#[derive(Debug, serde::Serialize, serde::Deserialize, Clone, PartialEq)] +pub struct HttpEvent { + /// The event name if given + event: Option, + /// The event data + data: String, + /// The event id if given + id: Option, + /// Retry duration if given + retry: Option, +} + +impl HttpEvent { + #[must_use] + pub const fn new( + event: Option, + data: String, + id: Option, + retry: Option, + ) -> Self { + Self { event, data, id, retry } + } + + #[must_use] + pub const fn get_event(&self) -> &Option { + &self.event + } + #[must_use] + pub const fn get_data(&self) -> &String { + &self.data + } + #[must_use] + pub const fn get_id(&self) -> &Option { + &self.id + } + #[must_use] + pub const fn get_retry(&self) -> &Option { + &self.retry + } + + #[must_use] + pub fn to_sse_string(&self) -> String { + let mut sse_string = String::new(); + + if let Some(ref event) = self.event { + writeln!(sse_string, "event: {}", event).unwrap(); + } + + // Splitting data by newline to ensure each line is prefixed with "data: " + for line in self.data.split('\n') { + writeln!(sse_string, "data: {}", line).unwrap(); + } + + if let Some(ref id) = self.id { + writeln!(sse_string, "id: {}", id).unwrap(); + } + + if let Some(ref retry) = self.retry { + // Converting retry duration to milliseconds + let millis = retry.as_millis(); + writeln!(sse_string, "retry: {}", millis).unwrap(); + } + + // Adding the required empty line to separate events + sse_string.push_str("\n"); + + sse_string + } +} + impl std::fmt::Display for Codec { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { diff --git a/crates/wick/wick-runtime/Cargo.toml b/crates/wick/wick-runtime/Cargo.toml index 1b6b2477..9794f02f 100644 --- a/crates/wick/wick-runtime/Cargo.toml +++ b/crates/wick/wick-runtime/Cargo.toml @@ -47,6 +47,7 @@ hyper-staticfile = { workspace = true } hyper-reverse-proxy = { workspace = true } url = { workspace = true } bytes = { workspace = true } +eventsource-stream = { workspace = true } openapiv3 = { workspace = true } percent-encoding = { workspace = true } liquid = { workspace = true } diff --git a/crates/wick/wick-runtime/src/triggers/http/component_utils.rs b/crates/wick/wick-runtime/src/triggers/http/component_utils.rs index 0ab59d0e..5b1edc06 100644 --- a/crates/wick/wick-runtime/src/triggers/http/component_utils.rs +++ b/crates/wick/wick-runtime/src/triggers/http/component_utils.rs @@ -5,10 +5,12 @@ use hyper::http::response::Builder; use hyper::http::{HeaderName, HeaderValue}; use hyper::{Body, Response, StatusCode}; use serde_json::{Map, Value}; +use tokio::sync::mpsc::unbounded_channel; +use tokio::sync::oneshot; use tokio_stream::StreamExt; use tracing::Span; use uuid::Uuid; -use wick_config::config::Codec; +use wick_config::config::{Codec, HttpEvent}; use wick_interface_http::types as wick_http; use wick_packet::{ packets, @@ -132,21 +134,12 @@ pub(super) async fn handle_response_middleware( } } -pub(super) async fn respond( +async fn stream_response( codec: Codec, - stream: Result, -) -> Result, HttpError> { - if let Err(e) = stream { - return Ok( - Builder::new() - .status(StatusCode::INTERNAL_SERVER_ERROR) - .body(Body::from(e.to_string())) - .unwrap(), - ); - } - let mut stream = stream.unwrap(); - let mut builder = Response::builder(); - let mut body = bytes::BytesMut::new(); + mut stream: PacketStream, + mut oneshot_channel: Option>, + tx_channel: tokio::sync::mpsc::UnboundedSender, HttpError>>, +) -> Result<(), HttpError> { while let Some(packet) = stream.next().await { match packet { Ok(p) => { @@ -159,8 +152,11 @@ pub(super) async fn respond( } let response: wick_interface_http::types::HttpResponse = p .decode() - .map_err(|e| HttpError::Deserialize("response".to_owned(), e.to_string()))?; - builder = convert_response(builder, response)?; + .map_err(|e| HttpError::Deserialize("response".to_owned(), e.to_string())) + .unwrap(); + let mut builder = Response::builder(); + builder = convert_response(builder, response).unwrap(); + let _ = oneshot_channel.take().unwrap().send(builder); } else if p.port() == "body" { if let PacketPayload::Err(e) = p.payload() { return Err(HttpError::OutputStream(p.port().to_owned(), e.msg().to_owned())); @@ -168,22 +164,103 @@ pub(super) async fn respond( if !p.has_data() { continue; } - if codec == Codec::Json { - let response: Value = p.decode().map_err(|e| HttpError::Codec(codec, e.to_string()))?; - let as_str = response.to_string(); - let bytes = as_str.as_bytes(); - body.extend_from_slice(bytes); - } else { - let response: Base64Bytes = p.decode().map_err(|e| HttpError::Bytes(e.to_string()))?; - body.extend_from_slice(&response); - } + let response: Value = p + .decode_value() + .map_err(|e| HttpError::Codec(codec, e.to_string())) + .unwrap(); + let http_event: HttpEvent = serde_json::from_value(response).unwrap(); + let as_str = http_event.to_sse_string(); + let bytes = as_str.as_bytes(); + let _ = tx_channel.send(Ok::<_, HttpError>(bytes.to_vec())); } } Err(e) => return Err(HttpError::OperationError(e.to_string())), } } - builder = reset_header(builder, CONTENT_LENGTH, body.len()); - Ok(builder.body(body.freeze().into()).unwrap()) + Ok(()) +} + +pub(super) async fn respond( + codec: Codec, + stream: Result, +) -> Result, HttpError> { + if let Err(e) = stream { + return Ok( + Builder::new() + .status(StatusCode::INTERNAL_SERVER_ERROR) + .body(Body::from(e.to_string())) + .unwrap(), + ); + } + let mut stream = stream.unwrap(); + if codec == Codec::EventStream { + let (tx, rx) = unbounded_channel(); + let (tx_one, rx_one) = oneshot::channel(); + let tx_one = Some(tx_one); + + tokio::spawn(async move { + let _ = stream_response(codec, stream, tx_one, tx).await; + }); + + let response = rx_one.await; + response.map_or_else( + |_| { + Ok( + Builder::new() + .status(StatusCode::INTERNAL_SERVER_ERROR) + .body(Body::from("data: No response received\n\ndata: [DONE]")) + .unwrap(), + ) + }, + |response| { + let mut builder = response; + builder = reset_header(builder, CONTENT_LENGTH, 0); + builder = builder.header("content-type", "text/event-stream"); + let body = Body::wrap_stream(tokio_stream::wrappers::UnboundedReceiverStream::new(rx)); + Ok(builder.body(body).unwrap()) + }, + ) + } else { + let mut body = bytes::BytesMut::new(); + let mut builder = Response::builder(); + while let Some(packet) = stream.next().await { + match packet { + Ok(p) => { + if p.port() == "response" { + if let PacketPayload::Err(e) = p.payload() { + return Err(HttpError::OutputStream(p.port().to_owned(), e.msg().to_owned())); + } + if p.is_done() { + continue; + } + let response: wick_interface_http::types::HttpResponse = p + .decode() + .map_err(|e| HttpError::Deserialize("response".to_owned(), e.to_string()))?; + builder = convert_response(builder, response)?; + } else if p.port() == "body" { + if let PacketPayload::Err(e) = p.payload() { + return Err(HttpError::OutputStream(p.port().to_owned(), e.msg().to_owned())); + } + if !p.has_data() { + continue; + } + if codec == Codec::Json { + let response: Value = p.decode().map_err(|e| HttpError::Codec(codec, e.to_string()))?; + let as_str: String = response.to_string(); + let bytes = as_str.as_bytes(); + body.extend_from_slice(bytes); + } else { + let response: Base64Bytes = p.decode().map_err(|e| HttpError::Bytes(e.to_string()))?; + body.extend_from_slice(&response); + } + } + } + Err(e) => return Err(HttpError::OperationError(e.to_string())), + } + } + builder = reset_header(builder, CONTENT_LENGTH, body.len()); + Ok(builder.body(body.freeze().into()).unwrap()) + } } fn reset_header(mut builder: Builder, header: HeaderName, value: impl Into) -> Builder {