Graceful shutdown handling (#69)
ameyc authored Dec 17, 2024
1 parent d86a124 commit b856f51
Showing 5 changed files with 117 additions and 87 deletions.
2 changes: 2 additions & 0 deletions crates/common/src/error/mod.rs
@@ -17,6 +17,8 @@ pub enum DenormalizedError {
// #[allow(clippy::disallowed_types)]
#[error("DataFusion error")]
DataFusion(#[from] DataFusionError),
#[error("Shutdown")]
Shutdown(),
#[error("RocksDB error: {0}")]
RocksDB(String),
#[error("Kafka error")]
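The new unit variant gives pipelines a dedicated signal to surface when a shutdown is requested, so callers can tell an intentional stop from a real failure. A minimal sketch of that match, using a hypothetical `PipelineError` stand-in and assuming the `thiserror` crate that the enum above already derives from:

```rust
use thiserror::Error;

// Hypothetical stand-in for DenormalizedError, showing the same pattern.
#[derive(Debug, Error)]
enum PipelineError {
    #[error("Shutdown")]
    Shutdown(),
    #[error("Kafka error: {0}")]
    Kafka(String),
}

fn report(result: Result<(), PipelineError>) {
    match result {
        // A graceful shutdown is expected, not a failure.
        Err(PipelineError::Shutdown()) => println!("pipeline stopped cleanly"),
        Err(e) => eprintln!("pipeline failed: {e}"),
        Ok(()) => println!("pipeline finished"),
    }
}

fn main() {
    report(Err(PipelineError::Shutdown()));
}
```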
161 changes: 94 additions & 67 deletions crates/core/src/datastream.rs
@@ -230,7 +230,6 @@ impl DataStream {
let (session_state, plan) = self.df.as_ref().clone().into_parts();
let physical_plan = self.df.as_ref().clone().create_physical_plan().await?;
let node_id = physical_plan.node_id();
debug!("topline node id = {:?}", node_id);
let displayable_plan = DisplayableExecutionPlan::new(physical_plan.as_ref());

println!("{}", displayable_plan.indent(true));
@@ -243,62 +242,49 @@
})
}

/// execute the stream and print the results to stdout.
/// Mainly used for development and debugging
pub async fn print_stream(mut self) -> Result<()> {
async fn with_orchestrator<F, Fut, T>(&mut self, stream_fn: F) -> Result<T>
where
F: FnOnce(watch::Receiver<bool>) -> Fut,
Fut: std::future::Future<Output = Result<T>>,
{
self.start_shutdown_listener();

let mut maybe_orchestrator_handle = None;

let config = self.context.session_context.copied_config();
let config_options = config.options().extensions.get::<DenormalizedConfig>();

let should_checkpoint = config_options.map_or(false, |c| c.checkpoint);

let mut maybe_orchestrator_handle = None;

// Start orchestrator if checkpointing is enabled
if should_checkpoint {
let mut orchestrator = Orchestrator::default();
let cloned_shutdown_rx = self.shutdown_rx.clone();
let orchestrator_handle =
SpawnedTask::spawn_blocking(move || orchestrator.run(10, cloned_shutdown_rx));

maybe_orchestrator_handle = Some(orchestrator_handle)
maybe_orchestrator_handle = Some(orchestrator_handle);
}

let mut stream: SendableRecordBatchStream =
self.df.as_ref().clone().execute_stream().await?;

// Stream loop with shutdown check
loop {
tokio::select! {
// Check if shutdown signal has changed
_ = self.shutdown_rx.changed() => {
info!("Graceful shutdown initiated, exiting stream loop...");

break;
}
// Handle the next batch from the DataFusion stream
next_batch = stream.next() => {
match next_batch.transpose() {
Ok(Some(batch)) => {
println!(
"{}",
datafusion::common::arrow::util::pretty::pretty_format_batches(&[batch])
.unwrap()
);
}
Ok(None) => {
info!("No more RecordBatch in stream");
break; // End of stream
}
Err(err) => {
log::error!("Error reading stream: {:?}", err);
return Err(err.into());
}
}
}
// Run the stream processing function

let mut shutdown_rx = self.shutdown_rx.clone();

let result = tokio::select! {
res = stream_fn(shutdown_rx.clone()) => {
// `stream_fn` completed first
res
},
_ = shutdown_rx.changed() => {
// Shutdown signal received first
log::info!("Shutdown signal received while the pipeline was running, cancelling...");
// Surface the shutdown as a dedicated error so callers can unwind cleanly.
return Err(denormalized_common::DenormalizedError::Shutdown());
}
}
};

//let result = stream_fn(self.shutdown_rx.clone()).await;

// Cleanup
log::info!("Stream processing stopped. Cleaning up...");

if should_checkpoint {
@@ -309,43 +295,84 @@
}
}

// Join the orchestrator handle if it exists, ensuring it is joined and awaited
// Join orchestrator if it was started
if let Some(orchestrator_handle) = maybe_orchestrator_handle {
log::info!("Waiting for orchestrator task to complete...");
match orchestrator_handle.join_unwind().await {
Ok(_) => log::info!("Orchestrator task completed successfully."),
Err(e) => log::error!("Error joining orchestrator task: {:?}", e),
}
}
Ok(())
}

/// execute the stream and write the results to a given Kafka topic
pub async fn sink_kafka(self, bootstrap_servers: String, topic: String) -> Result<()> {
let processed_schema = Arc::new(datafusion::common::arrow::datatypes::Schema::from(
self.df.schema(),
));

let sink_topic = KafkaTopicBuilder::new(bootstrap_servers.clone())
.with_timestamp(String::from("occurred_at_ms"), TimestampUnit::Int64Millis)
.with_encoding("json")?
.with_topic(topic.clone())
.with_schema(processed_schema)
.build_writer(ConnectionOpts::new())
.await?;

self.context
.register_table(topic.clone(), Arc::new(sink_topic))
.await?;
result
}

self.df
.as_ref()
.clone()
.write_table(topic.as_str(), DataFrameWriteOptions::default())
/// execute the stream and print the results to stdout.
/// Mainly used for development and debugging
pub async fn print_stream(self) -> Result<()> {
self.clone()
.with_orchestrator(|_shutdown_rx| async move {
let mut stream: SendableRecordBatchStream =
self.df.as_ref().clone().execute_stream().await?;

loop {
match stream.next().await.transpose() {
Ok(Some(batch)) => {
if batch.num_rows() > 0 {
println!(
"{}",
datafusion::common::arrow::util::pretty::pretty_format_batches(
&[batch]
)
.unwrap()
);
}
}
Ok(None) => {
info!("No more RecordBatches in stream");
break; // End of stream
}
Err(err) => {
log::error!("Error reading stream: {:?}", err);
return Err(err.into());
}
}
}
Ok(())
})
.await?;

Ok(())
}
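Both `print_stream` and `sink_kafka` now start with `self.clone()`. A plausible reading: `with_orchestrator` needs `&mut self` for the shutdown plumbing while the closure also has to take ownership of `self` for the DataFusion calls, so the clone is borrowed and the original moves into the closure. A compiling sketch of that shape, with a hypothetical `Stream` type and assuming the `tokio` runtime:

```rust
use std::future::Future;

#[derive(Clone)]
struct Stream {
    name: String,
}

impl Stream {
    // Stand-in for DataStream::with_orchestrator: borrows self mutably,
    // then runs the supplied future to completion.
    async fn with_orchestrator<F, Fut, T>(&mut self, f: F) -> T
    where
        F: FnOnce() -> Fut,
        Fut: Future<Output = T>,
    {
        // Shutdown/orchestrator plumbing elided in this sketch.
        f().await
    }

    async fn print(self) {
        // The clone is borrowed mutably by the wrapper; the original
        // `self` moves into the async closure that does the real work.
        self.clone()
            .with_orchestrator(|| async move {
                println!("processing {}", self.name);
            })
            .await;
    }
}

#[tokio::main]
async fn main() {
    Stream { name: "demo".into() }.print().await;
}
```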

pub async fn sink_kafka(self, bootstrap_servers: String, topic: String) -> Result<()> {
self.clone()
.with_orchestrator(|_shutdown_rx| async move {
let processed_schema = Arc::new(
datafusion::common::arrow::datatypes::Schema::from(self.df.schema()),
);

let sink_topic = KafkaTopicBuilder::new(bootstrap_servers.clone())
.with_timestamp(String::from("occurred_at_ms"), TimestampUnit::Int64Millis)
.with_encoding("json")?
.with_topic(topic.clone())
.with_schema(processed_schema)
.build_writer(ConnectionOpts::new())
.await?;

self.context
.register_table(topic.clone(), Arc::new(sink_topic))
.await?;

self.df
.as_ref()
.clone()
.write_table(topic.as_str(), DataFrameWriteOptions::default())
.await?;

Ok(())
})
.await
}
}

/// Trait that allows both DataStream and DataFrame objects to be joined to
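The heart of the change above is the race between the user-supplied future and the shutdown watch channel. A self-contained sketch of that pattern, assuming the `tokio` crate with the `sync`, `time`, `rt`, and `macros` features (a simplification of `with_orchestrator`, not the actual implementation):

```rust
use std::time::Duration;
use tokio::sync::watch;

#[tokio::main]
async fn main() {
    // A watch channel broadcasts a single "shutdown requested" flag.
    let (shutdown_tx, mut shutdown_rx) = watch::channel(false);

    // Simulate a Ctrl-C handler flipping the flag after 100 ms.
    tokio::spawn(async move {
        tokio::time::sleep(Duration::from_millis(100)).await;
        let _ = shutdown_tx.send(true);
    });

    // Race the long-running work against the shutdown notification,
    // the same shape with_orchestrator uses for stream_fn.
    tokio::select! {
        _ = tokio::time::sleep(Duration::from_secs(60)) => {
            println!("work finished first");
        }
        _ = shutdown_rx.changed() => {
            println!("shutdown signal received, cancelling work");
        }
    }
}
```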
@@ -632,18 +632,18 @@ impl GroupedAggWindowFrame {
&mut self,
state: &CheckpointedGroupedWindowFrame,
) -> Result<(), DataFusionError> {
let _ = self
.accumulators
self.accumulators
.iter_mut()
.zip(state.accumulators.iter())
.map(|(acc, checkpointed_acc)| {
.for_each(|(acc, checkpointed_acc)| {
let group_indices = (0..checkpointed_acc.num_groups).collect::<Vec<usize>>();
acc.merge_batch(
&checkpointed_acc.states.arrays,
&group_indices,
None,
checkpointed_acc.num_groups,
)
.unwrap();
});
Ok(())
}
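For context, `merge_batch` on a DataFusion `GroupsAccumulator` folds partial aggregate state into the accumulator, one target group per entry in `group_indices`. A hedged sketch of the restore step above as a free function (names are illustrative; the real code iterates paired accumulators as shown in the hunk):

```rust
use datafusion::common::arrow::array::ArrayRef;
use datafusion::error::Result;
use datafusion::logical_expr::GroupsAccumulator;

// Merge checkpointed partial state back into a fresh accumulator.
// `states` holds the arrays produced by `state()` at checkpoint time;
// row i carries the partial state for group i.
fn restore(
    acc: &mut dyn GroupsAccumulator,
    states: &[ArrayRef],
    num_groups: usize,
) -> Result<()> {
    // One target slot per checkpointed group, in order.
    let group_indices: Vec<usize> = (0..num_groups).collect();
    acc.merge_batch(states, &group_indices, None, num_groups)
}
```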
2 changes: 1 addition & 1 deletion crates/core/src/utils/serialization.rs
@@ -297,7 +297,7 @@ mod tests {
use arrow_schema::{Field, Fields};
use datafusion::{
functions_aggregate::average::AvgAccumulator, logical_expr::Accumulator,
scalar::ScalarValue,
physical_expr::GroupsAccumulatorAdapter, scalar::ScalarValue,
};
use std::sync::Arc;

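The only change here is pulling `GroupsAccumulatorAdapter` into the test imports. That adapter wraps a plain row-oriented `Accumulator` so it can be driven through the `GroupsAccumulator` interface; a hedged sketch of its factory-closure constructor, assuming `AvgAccumulator::default()` as in DataFusion's own usage:

```rust
use datafusion::error::Result;
use datafusion::functions_aggregate::average::AvgAccumulator;
use datafusion::logical_expr::Accumulator;
use datafusion::physical_expr::GroupsAccumulatorAdapter;

// Build a grouped adapter over a per-row accumulator: the factory is
// invoked to create one inner accumulator per group as groups appear.
fn adapted_avg() -> GroupsAccumulatorAdapter {
    GroupsAccumulatorAdapter::new(|| {
        Ok(Box::new(AvgAccumulator::default()) as Box<dyn Accumulator>)
    })
}
```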
33 changes: 17 additions & 16 deletions examples/examples/simple_aggregation.rs
@@ -1,8 +1,8 @@
use std::time::Duration;

use datafusion::functions_aggregate::count::count;
use datafusion::functions_aggregate::expr_fn::{max, min};
use datafusion::logical_expr::{col, lit};
use datafusion::functions_aggregate::expr_fn::{avg, max, min};
use datafusion::logical_expr::col;

use denormalized::datasource::kafka::{ConnectionOpts, KafkaTopicBuilder};
use denormalized::prelude::*;
@@ -17,27 +17,29 @@ async fn main() -> Result<()> {
.filter_level(log::LevelFilter::Debug)
.init();

let bootstrap_servers = String::from("localhost:9092");
let bootstrap_servers = String::from("localhost:19092");

let config = Context::default_config().set_bool("denormalized_config.checkpoint", false);

let mut topic_builder = KafkaTopicBuilder::new(bootstrap_servers);
let mut topic_builder = KafkaTopicBuilder::new(bootstrap_servers.clone());

// Connect to source topic
let source_topic = topic_builder
.with_topic(String::from("temperature"))
.infer_schema_from_json(get_sample_json().as_str())?
.with_encoding("json")?
.with_timestamp(String::from("occurred_at_ms"), TimestampUnit::Int64Millis)
//.with_timestamp(String::from("occurred_at_ms"), TimestampUnit::Int64Millis)
.build_reader(ConnectionOpts::from([
("auto.offset.reset".to_string(), "latest".to_string()),
("group.id".to_string(), "sample_pipeline".to_string()),
]))
.await?;

let _ctx = Context::with_config(config)?
//.with_slatedb_backend(String::from("/tmp/checkpoints/simple-agg-checkpoint-1"))
//.await
let ds = Context::with_config(config)?
// .with_slatedb_backend(String::from(
// "/tmp/checkpoints/simple-aggregation-example",
// ))
// .await
.from_topic(source_topic)
.await?
.window(
Expand All @@ -46,14 +48,13 @@ async fn main() -> Result<()> {
count(col("reading")).alias("count"),
min(col("reading")).alias("min"),
max(col("reading")).alias("max"),
//avg(col("reading")).alias("average"),
avg(col("reading")).alias("average"),
],
Duration::from_millis(1_000), // aggregate every 1 second
None, // None means tumbling window
)?
.filter(col("max").gt(lit(113)))?
.print_stream() // Print out the results
.await?;

Duration::from_millis(5_000), // Window length
None, // Slide duration. None defaults to a tumbling window.
)?;
ds.print_stream().await?;
//ds.sink_kafka(bootstrap_servers, String::from("checkpointed-output"))
// .await?;
Ok(())
}
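As a reference point, the fields the pipeline touches suggest roughly this event shape on the `temperature` topic; the real `get_sample_json()` may carry more fields, so treat this as an inferred illustration:

```rust
// Hypothetical single event; only "reading" and "occurred_at_ms" are
// actually referenced by the example above.
fn sample_event() -> &'static str {
    r#"{"reading": 72.5, "occurred_at_ms": 1734400000000}"#
}
```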
