checkpointing final #48

Merged · 3 commits · Oct 21, 2024
1,520 changes: 1,256 additions & 264 deletions Cargo.lock

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions Cargo.toml
@@ -58,3 +58,4 @@ base64 = "0.22.1"
chrono = { version = "0.4.38" }
itertools = "0.13"
pyo3 = { version = "0.22.2" }
slatedb = "0.2.0"
12 changes: 12 additions & 0 deletions README.md
@@ -94,6 +94,18 @@ Details about developing the python bindings can be found in [py-denormalized/RE
2. Start emitting some sample data: `cargo run --example emit_measurements`
3. Run a [simple streaming aggregation](./examples/examples/simple_aggregation.rs) on the data using denormalized: `cargo run --example simple_aggregation`

### Checkpointing

We use SlateDB as the state backend. Initialize your job `Context` with a path to a local directory:

```
let ctx = Context::new()?
.with_slatedb_backend(String::from("/tmp/checkpoints/simple-agg-checkpoint-1"))
.await;
```

The job will automatically recover its state if a previous checkpoint exists.
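
For illustration, here is a minimal sketch of the recovery flow. The `topic` variable is a placeholder for a `TopicReader` built as in the simple aggregation example above; everything else follows the API shown in this PR.

```
// Restarting the job against the same checkpoint path resumes where it left off:
// the Kafka source seeks back to the offsets recorded in SlateDB.
let ctx = Context::new()?
    .with_slatedb_backend(String::from("/tmp/checkpoints/simple-agg-checkpoint-1"))
    .await;

// `topic` stands in for a TopicReader configured as in simple_aggregation.rs
let ds = ctx.from_topic(topic).await?;
// ...define and run the streaming query as usual
```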

## More examples

A more powerful example can be seen in our [Kafka ridesharing example](./docs/kafka_rideshare_example.md)
1 change: 1 addition & 0 deletions crates/common/Cargo.toml
@@ -17,3 +17,4 @@ thiserror = "1.0.63"
pyo3 = { workspace = true, optional = true }
serde_json.workspace = true
apache-avro = "0.16.0"
slatedb = { workspace = true }
2 changes: 2 additions & 0 deletions crates/common/src/error/mod.rs
@@ -27,6 +27,8 @@ pub enum DenormalizedError {
AvroError(#[from] AvroError),
#[error("Json Error")]
Json(#[from] JsonError),
//#[error("SlateDB Error")]
//SlateDBError(#[from] SlateDBError),
#[error(transparent)]
Other(#[from] anyhow::Error),
}
4 changes: 4 additions & 0 deletions crates/core/Cargo.toml
@@ -38,3 +38,7 @@ delegate = "0.12.0"
ahash = "0.8.11"
hashbrown = "0.14.5"
flatbuffers = "24.3.25"
crossbeam = "0.8.4"
slatedb = { workspace = true } # "0.2.0"
object_store = "0.11.0"
bytes = "1.7.2"
14 changes: 7 additions & 7 deletions crates/core/src/context.rs
@@ -10,6 +10,7 @@ use crate::datasource::kafka::TopicReader;
use crate::datastream::DataStream;
use crate::physical_optimizer::EnsureHashPartititionOnGroupByForStreamingAggregates;
use crate::query_planner::StreamingQueryPlanner;
use crate::state_backend::slatedb::initialize_global_slatedb;
use crate::utils::get_default_optimizer_rules;

use denormalized_common::error::{DenormalizedError, Result};
@@ -53,16 +54,10 @@ impl Context {

pub async fn from_topic(&self, topic: TopicReader) -> Result<DataStream, DenormalizedError> {
let topic_name = topic.0.topic.clone();

self.register_table(topic_name.clone(), Arc::new(topic))
.await?;

let df = self.session_conext.table(topic_name.as_str()).await?;

let ds = DataStream {
df: Arc::new(df),
context: Arc::new(self.clone()),
};
let ds = DataStream::new(Arc::new(df), Arc::new(self.clone()));
Ok(ds)
}

@@ -76,4 +71,9 @@ impl Context {

Ok(())
}

pub async fn with_slatedb_backend(self, path: String) -> Self {
let _ = initialize_global_slatedb(path.as_str()).await;
self
}
}
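
For reference, a minimal sketch of how the global handle initialized by `with_slatedb_backend` is consumed downstream. The `get` signature is inferred from the `block_on(state_backend.clone().get(...))` call in `kafka_stream_read.rs` below; it is an assumption about the wrapper, not its documented API.

```rust
use crate::state_backend::slatedb::get_global_slatedb;

/// Fetch the serialized checkpoint for a stream tag, if one exists.
/// Sketch only: the wrapper's exact types are assumed from the calls in this PR.
async fn load_checkpoint(tag: &str) -> Option<Vec<u8>> {
    // Panics if `with_slatedb_backend` was never awaited on the Context.
    let backend = get_global_slatedb().unwrap();
    backend.get(tag.as_bytes().to_vec()).await
}
```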
169 changes: 84 additions & 85 deletions crates/core/src/datasource/kafka/kafka_stream_read.rs
@@ -5,23 +5,25 @@ use std::time::Duration;
use arrow::datatypes::TimestampMillisecondType;
use arrow_array::{Array, ArrayRef, PrimitiveArray, RecordBatch, StringArray, StructArray};
use arrow_schema::{DataType, Field, SchemaRef, TimeUnit};
use denormalized_orchestrator::channel_manager::{create_channel, get_sender};
use crossbeam::channel;
use denormalized_orchestrator::channel_manager::{create_channel, get_sender, take_receiver};
use denormalized_orchestrator::orchestrator::{self, OrchestrationMessage};
use futures::executor::block_on;
use log::{debug, error};
use serde::{Deserialize, Serialize};

use crate::config_extensions::denormalized_config::DenormalizedConfig;
use crate::physical_plan::stream_table::PartitionStreamExt;
use crate::physical_plan::utils::time::array_to_timestamp_array;
use crate::state_backend::rocksdb_backend::get_global_rocksdb;
use crate::state_backend::slatedb::get_global_slatedb;

use arrow::compute::{max, min};
use datafusion::execution::{SendableRecordBatchStream, TaskContext};
use datafusion::physical_plan::stream::RecordBatchReceiverStreamBuilder;
use datafusion::physical_plan::streaming::PartitionStream;

use rdkafka::consumer::{Consumer, StreamConsumer};
use rdkafka::{ClientConfig, Message, TopicPartitionList};
use rdkafka::{ClientConfig, Message, Offset, TopicPartitionList};

use super::KafkaReadConfig;

@@ -44,7 +46,7 @@ impl KafkaStreamRead {

#[derive(Debug, Serialize, Deserialize)]
struct BatchReadMetadata {
epoch: i32,
epoch: u128,
min_timestamp: Option<i64>,
max_timestamp: Option<i64>,
offsets_read: HashMap<i32, i64>,
@@ -81,120 +83,120 @@ impl PartitionStream for KafkaStreamRead {
}

fn execute(&self, ctx: Arc<TaskContext>) -> SendableRecordBatchStream {
let mut assigned_partitions = TopicPartitionList::new();

let config_options = ctx
let _config_options = ctx
.session_config()
.options()
.extensions
.get::<DenormalizedConfig>();

let should_checkpoint = config_options.map_or(false, |c| c.checkpoint);
let mut should_checkpoint = false; //config_options.map_or(false, |c| c.checkpoint);

let topic = self.config.topic.clone();
for partition in self.assigned_partitions.clone() {
assigned_partitions.add_partition(self.config.topic.as_str(), partition);
}
let node_id = self.exec_node_id.unwrap();
let partition_tag = self
.assigned_partitions
.iter()
.map(|&x| x.to_string())
.collect::<Vec<String>>()
.join("_");

let state_backend = if should_checkpoint {
Some(get_global_rocksdb().unwrap())
} else {
None
};
let channel_tag = format!("{}_{}", node_id, partition_tag);
let mut serialized_state: Option<Vec<u8>> = None;
let state_backend = get_global_slatedb().unwrap();

let mut starting_offsets: HashMap<i32, i64> = HashMap::new();
if orchestrator::SHOULD_CHECKPOINT {
create_channel(channel_tag.as_str(), 10);
debug!("checking for last checkpointed offsets");
serialized_state = block_on(state_backend.clone().get(channel_tag.as_bytes().to_vec()));
}

if let Some(serialized_state) = serialized_state {
let last_batch_metadata = BatchReadMetadata::from_bytes(&serialized_state).unwrap();
debug!(
"recovering from checkpointed offsets. epoch was {} max timestamp {:?}",
last_batch_metadata.epoch, last_batch_metadata.max_timestamp
);
starting_offsets = last_batch_metadata.offsets_read.clone();
}

let mut assigned_partitions = TopicPartitionList::new();

for partition in self.assigned_partitions.clone() {
assigned_partitions.add_partition(self.config.topic.as_str(), partition);
if starting_offsets.contains_key(&partition) {
let offset = starting_offsets.get(&partition).unwrap();
debug!("setting partition {} to offset {}", partition, offset);
let _ = assigned_partitions.set_partition_offset(
self.config.topic.as_str(),
partition,
Offset::from_raw(*offset),
);
}
}

let consumer: StreamConsumer = create_consumer(self.config.clone());

consumer
.assign(&assigned_partitions)
.expect("Partition assignment failed.");

let state_namespace = format!("kafka_source_{}", topic);

if let Some(backend) = &state_backend {
let _ = match backend.get_cf(&state_namespace) {
Ok(cf) => {
debug!("cf for this already exists");
Ok(cf)
}
Err(..) => {
let _ = backend.create_cf(&state_namespace);
backend.get_cf(&state_namespace)
}
};
}

let mut builder = RecordBatchReceiverStreamBuilder::new(self.config.schema.clone(), 1);
let tx = builder.tx();
let canonical_schema = self.config.schema.clone();
let timestamp_column: String = self.config.timestamp_column.clone();
let timestamp_unit = self.config.timestamp_unit.clone();
let batch_timeout = Duration::from_millis(100);
let mut channel_tag: String = String::from("");
if orchestrator::SHOULD_CHECKPOINT {
let node_id = self.exec_node_id.unwrap();
channel_tag = format!("{}_{}", node_id, partition_tag);
create_channel(channel_tag.as_str(), 10);
}
let batch_timeout: Duration = Duration::from_millis(100);
let mut decoder = self.config.build_decoder();

builder.spawn(async move {
let mut epoch = 0;
let mut receiver: Option<channel::Receiver<OrchestrationMessage>> = None;
if orchestrator::SHOULD_CHECKPOINT {
let orchestrator_sender = get_sender("orchestrator");
let msg = OrchestrationMessage::RegisterStream(channel_tag.clone());
let msg: OrchestrationMessage =
OrchestrationMessage::RegisterStream(channel_tag.clone());
orchestrator_sender.as_ref().unwrap().send(msg).unwrap();
receiver = take_receiver(channel_tag.as_str());
}

loop {
let mut last_offsets = HashMap::new();
if let Some(backend) = &state_backend {
if let Some(offsets) = backend
.get_state(&state_namespace, partition_tag.clone().into_bytes())
.unwrap()
{
let last_batch_metadata = BatchReadMetadata::from_bytes(&offsets).unwrap();
last_offsets = last_batch_metadata.offsets_read;
debug!(
"epoch is {} and last read offsets are {:?}",
epoch, last_offsets
);
} else {
debug!("epoch is {} and no prior offsets were found.", epoch);
//let mut checkpoint_barrier: Option<String> = None;
let mut _checkpoint_barrier: Option<i64> = None;

if orchestrator::SHOULD_CHECKPOINT {
let r = receiver.as_ref().unwrap();
for message in r.try_iter() {
debug!("received checkpoint barrier for {:?}", message);
if let OrchestrationMessage::CheckpointBarrier(epoch_ts) = message {
epoch = epoch_ts;
should_checkpoint = true;
}
}
}

for (partition, offset) in &last_offsets {
consumer
.seek(
&topic,
*partition,
rdkafka::Offset::Offset(*offset + 1),
Duration::from_secs(10),
)
.expect("Failed to seek to stored offset");
}

let mut offsets_read: HashMap<i32, i64> = HashMap::new();
let start_time = datafusion::common::instant::Instant::now();

while start_time.elapsed() < batch_timeout {
loop {
match tokio::time::timeout(batch_timeout, consumer.recv()).await {
Ok(Ok(m)) => {
let payload = m.payload().expect("Message payload is empty");
decoder.push_to_buffer(payload.to_owned());
offsets_read.insert(m.partition(), m.offset());
offsets_read
.entry(m.partition())
.and_modify(|existing_value| {
*existing_value = (*existing_value).max(m.offset())
})
.or_insert(m.offset());
break;
}
Ok(Err(err)) => {
error!("Error reading from Kafka {:?}", err);
// TODO: Implement a retry mechanism here
}
Err(_) => {
// Timeout reached
break;
error!("timeout reached.");
//break;
}
}
}
@@ -220,7 +222,6 @@

let max_timestamp: Option<_> = max::<TimestampMillisecondType>(ts_array);
let min_timestamp: Option<_> = min::<TimestampMillisecondType>(ts_array);
debug!("min: {:?}, max: {:?}", min_timestamp, max_timestamp);
let mut columns: Vec<Arc<dyn Array>> = record_batch.columns().to_vec();

let metadata_column = StructArray::from(vec![
Expand All @@ -245,23 +246,21 @@ impl PartitionStream for KafkaStreamRead {
match tx_result {
Ok(_) => {
if should_checkpoint {
let _ = state_backend.as_ref().map(|backend| {
backend.put_state(
&state_namespace,
partition_tag.clone().into_bytes(),
BatchReadMetadata {
epoch,
min_timestamp,
max_timestamp,
offsets_read,
}
.to_bytes()
.unwrap(),
)
});
debug!("about to checkpoint offsets");
let off = BatchReadMetadata {
epoch,
min_timestamp,
max_timestamp,
offsets_read,
};
let _ = state_backend
.as_ref()
.put(channel_tag.as_bytes().to_vec(), off.to_bytes().unwrap());
debug!("checkpointed offsets {:?}", off);
should_checkpoint = false;
}
}
Err(err) => error!("result err {:?}", err),
Err(err) => error!("result err {:?}. shutdown signal detected.", err),
}
epoch += 1;
}
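
The checkpoint payload round-trips through `BatchReadMetadata::to_bytes()` / `from_bytes()`, whose bodies are collapsed in this view. Below is a minimal serde-based sketch of helpers with that shape; the choice of `serde_json` is an assumption for illustration only, since the PR does not show the actual encoding.

```rust
use std::collections::HashMap;

use serde::{Deserialize, Serialize};

#[derive(Debug, Serialize, Deserialize)]
struct BatchReadMetadata {
    epoch: u128,
    min_timestamp: Option<i64>,
    max_timestamp: Option<i64>,
    offsets_read: HashMap<i32, i64>,
}

impl BatchReadMetadata {
    // Serialize the checkpoint record; the encoding is assumed here, any serde format works.
    fn to_bytes(&self) -> Result<Vec<u8>, serde_json::Error> {
        serde_json::to_vec(self)
    }

    // Restore a checkpoint record previously written by `to_bytes`.
    fn from_bytes(bytes: &[u8]) -> Result<Self, serde_json::Error> {
        serde_json::from_slice(bytes)
    }
}
```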