Adding Avro decoder #40

Merged · 1 commit · Sep 12, 2024
1 change: 1 addition & 0 deletions Cargo.lock

(Generated file; diff not rendered.)

3 changes: 2 additions & 1 deletion crates/common/Cargo.toml
@@ -11,8 +11,9 @@ default = ["python"]
 
 [dependencies]
 anyhow = "1.0.86"
-datafusion = { workspace = true }
+datafusion = { workspace = true, features = ["avro"] }
 arrow = { workspace = true }
 thiserror = "1.0.63"
 pyo3 = { workspace = true, optional = true }
 serde_json.workspace = true
+apache-avro = "0.16.0"
3 changes: 3 additions & 0 deletions crates/common/src/error/mod.rs
@@ -1,6 +1,7 @@
 use std::result;
 use thiserror::Error;
 
+use apache_avro::Error as AvroError;
 use arrow::error::ArrowError;
 use datafusion::error::DataFusionError;
 use serde_json::Error as JsonError;
@@ -22,6 +23,8 @@ pub enum DenormalizedError {
     KafkaConfig(String),
     #[error("Arrow Error")]
     Arrow(#[from] ArrowError),
+    #[error("Avro Error")]
+    AvroError(#[from] AvroError),
     #[error("Json Error")]
     Json(#[from] JsonError),
     #[error(transparent)]
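
With the `#[from]` conversion in place, fallible apache-avro calls can use `?` anywhere a `DenormalizedError` is returned. A minimal sketch; the helper below is illustrative and not part of this PR:

use apache_avro::Schema as AvroSchema;
use denormalized_common::DenormalizedError;

// Illustrative: `?` lifts apache_avro::Error into DenormalizedError::AvroError
// through the new #[from] attribute.
fn parse_avro_schema(raw: &str) -> Result<AvroSchema, DenormalizedError> {
    let schema = AvroSchema::parse_str(raw)?;
    Ok(schema)
}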
2 changes: 1 addition & 1 deletion crates/core/Cargo.toml
@@ -7,7 +7,7 @@ edition = { workspace = true }
 denormalized-common = { workspace = true }
 denormalized-orchestrator = { workspace = true }
 
-datafusion = { workspace = true }
+datafusion = { workspace = true, features = ["avro"] }
 
 arrow = { workspace = true }
 arrow-schema = { workspace = true }
21 changes: 21 additions & 0 deletions crates/core/src/datasource/kafka/kafka_config.rs
@@ -2,10 +2,15 @@ use std::collections::HashMap;
 use std::str::FromStr;
 use std::{sync::Arc, time::Duration};
 
+use apache_avro::Schema as AvroSchema;
 use arrow_schema::{DataType, Field, Fields, Schema, SchemaRef, TimeUnit};
 
 use datafusion::logical_expr::SortExpr;
 
+use crate::formats::decoders::avro::AvroDecoder;
+use crate::formats::decoders::json::JsonDecoder;
+use crate::formats::decoders::utils::to_arrow_schema;
+use crate::formats::decoders::Decoder;
 use crate::formats::StreamEncoding;
 use crate::physical_plan::utils::time::TimestampUnit;
 use crate::utils::arrow_helpers::infer_arrow_schema_from_json_value;
@@ -53,6 +58,13 @@ impl KafkaReadConfig {
         let consumer: StreamConsumer = client_config.create().expect("Consumer creation failed");
         Ok(consumer)
     }
+
+    pub fn build_decoder(&self) -> Box<dyn Decoder> {
+        match self.encoding {
+            StreamEncoding::Avro => Box::new(AvroDecoder::new(self.original_schema.clone())),
+            StreamEncoding::Json => Box::new(JsonDecoder::new(self.original_schema.clone())),
+        }
+    }
 }
 
 #[derive(Debug)]
@@ -146,6 +158,15 @@ impl KafkaTopicBuilder {
         Ok(self)
     }
 
+    pub fn infer_schema_from_avro(&mut self, avro_schema_str: &str) -> Result<&mut Self> {
+        self.infer_schema = false;
+        let avro_schema: AvroSchema =
+            AvroSchema::parse_str(avro_schema_str).expect("Invalid schema!");
+        let arrow_schema = to_arrow_schema(&avro_schema)?;
+        self.schema = Some(Arc::new(arrow_schema));
+        Ok(self)
+    }
+
     pub fn with_timestamp(
         &mut self,
         timestamp_column: String,
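
Together, the two additions let a topic derive its Arrow schema from an Avro schema string and then decode payloads uniformly. A hypothetical usage sketch: `topic_builder` and `kafka_read_config` are assumed to exist elsewhere; only `infer_schema_from_avro` and `build_decoder` come from this PR.

// Hypothetical usage sketch; the surrounding setup is assumed.
let avro_schema_str = r#"{
    "type": "record",
    "name": "sensor",
    "fields": [
        {"name": "id", "type": "int"},
        {"name": "reading", "type": "double"}
    ]
}"#;

// Disables JSON inference and derives the Arrow schema from the Avro one.
topic_builder.infer_schema_from_avro(avro_schema_str)?;

// On the read path, the decoder is chosen by the configured StreamEncoding:
let mut decoder = kafka_read_config.build_decoder(); // Box<dyn Decoder>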
10 changes: 4 additions & 6 deletions crates/core/src/datasource/kafka/kafka_stream_read.rs
@@ -11,8 +11,6 @@ use log::{debug, error};
 use serde::{Deserialize, Serialize};
 
 use crate::config_extensions::denormalized_config::DenormalizedConfig;
-use crate::formats::decoders::json::JsonDecoder;
-use crate::formats::decoders::Decoder;
 use crate::physical_plan::stream_table::PartitionStreamExt;
 use crate::physical_plan::utils::time::array_to_timestamp_array;
 use crate::state_backend::rocksdb_backend::get_global_rocksdb;
@@ -133,7 +131,6 @@ impl PartitionStream for KafkaStreamRead {
         let mut builder = RecordBatchReceiverStreamBuilder::new(self.config.schema.clone(), 1);
         let tx = builder.tx();
         let canonical_schema = self.config.schema.clone();
-        let arrow_schema = self.config.original_schema.clone();
         let timestamp_column: String = self.config.timestamp_column.clone();
         let timestamp_unit = self.config.timestamp_unit.clone();
         let batch_timeout = Duration::from_millis(100);
@@ -143,14 +140,15 @@
             channel_tag = format!("{}_{}", node_id, partition_tag);
             create_channel(channel_tag.as_str(), 10);
         }
+        let mut decoder = self.config.build_decoder();
+
         builder.spawn(async move {
             let mut epoch = 0;
             if orchestrator::SHOULD_CHECKPOINT {
                 let orchestrator_sender = get_sender("orchestrator");
                 let msg = OrchestrationMessage::RegisterStream(channel_tag.clone());
                 orchestrator_sender.as_ref().unwrap().send(msg).unwrap();
             }
-            let mut json_decoder: JsonDecoder = JsonDecoder::new(arrow_schema.clone());
             loop {
                 let mut last_offsets = HashMap::new();
                 if let Some(backend) = &state_backend {
@@ -192,7 +190,7 @@
                 {
                     Ok(Ok(m)) => {
                         let payload = m.payload().expect("Message payload is empty");
-                        json_decoder.push_to_buffer(payload.to_owned());
+                        decoder.push_to_buffer(payload.to_owned());
                         offsets_read.insert(m.partition(), m.offset());
                     }
                     Ok(Err(err)) => {
@@ -207,7 +205,7 @@
             }
 
             if !offsets_read.is_empty() {
-                let record_batch = json_decoder.to_record_batch().unwrap();
+                let record_batch = decoder.to_record_batch().unwrap();
                 let ts_column = record_batch
                     .column_by_name(timestamp_column.as_str())
                     .map(|ts_col| {
159 changes: 159 additions & 0 deletions crates/core/src/formats/decoders/avro.rs
@@ -0,0 +1,159 @@
use std::{io::Cursor, sync::Arc};

use arrow_array::RecordBatch;
use arrow_schema::Schema;
use datafusion::datasource::avro_to_arrow::ReaderBuilder;
use denormalized_common::DenormalizedError;

use super::Decoder;

#[derive(Clone)]
pub struct AvroDecoder {
    schema: Arc<Schema>,
    cache: Vec<Vec<u8>>,
    size: usize,
}

impl Decoder for AvroDecoder {
    fn push_to_buffer(&mut self, bytes: Vec<u8>) {
        self.cache.push(bytes);
        self.size += 1;
    }

    fn to_record_batch(&mut self) -> Result<arrow_array::RecordBatch, DenormalizedError> {
        if self.size == 0 {
            return Ok(RecordBatch::new_empty(self.schema.clone()));
        }
        let all_bytes: Vec<u8> = self.cache.iter().flatten().cloned().collect();
        // Create a cursor from the concatenated bytes
        let cursor = Cursor::new(all_bytes);

        // Build the reader
        let mut reader = ReaderBuilder::new()
            .with_batch_size(self.size)
            .with_schema(self.schema.clone())
            .build(cursor)?;

        // Read the batch
        match reader.next() {
            Some(Ok(batch)) => Ok(batch),
            Some(Err(e)) => Err(DenormalizedError::Arrow(e)),
            None => Ok(RecordBatch::new_empty(self.schema.clone())),
        }
    }
}

impl AvroDecoder {
    pub fn new(schema: Arc<Schema>) -> Self {
        AvroDecoder {
            schema,
            cache: Vec::new(),
            size: 0,
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use apache_avro::{types::Record, Schema as AvroSchema, Writer};
    use arrow_array::{Int32Array, StringArray};
    use arrow_schema::{DataType, Field};

    fn create_test_schema() -> Arc<Schema> {
        Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("name", DataType::Utf8, false),
        ]))
    }

    fn create_avro_data(records: Vec<(i32, &str)>) -> Vec<u8> {
        let avro_schema = AvroSchema::parse_str(
            r#"
            {
                "type": "record",
                "name": "test",
                "fields": [
                    {"name": "id", "type": "int"},
                    {"name": "name", "type": "string"}
                ]
            }
            "#,
        )
        .unwrap();

        let mut writer = Writer::new(&avro_schema, Vec::new());

        for (id, name) in records {
            let mut record: Record<'_> = Record::new(writer.schema()).unwrap();
            record.put("id", id);
            record.put("name", name);
            writer.append(record).unwrap();
        }

        writer.into_inner().unwrap()
    }

    #[test]
    fn test_push_to_buffer() {
        let schema = create_test_schema();
        let mut decoder = AvroDecoder::new(schema);

        decoder.push_to_buffer(vec![1, 2, 3]);
        decoder.push_to_buffer(vec![4, 5, 6]);

        assert_eq!(decoder.size, 2);
        assert_eq!(decoder.cache, vec![vec![1, 2, 3], vec![4, 5, 6]]);
    }

    #[test]
    fn test_empty_record_batch() {
        let schema = create_test_schema();
        let mut decoder = AvroDecoder::new(schema.clone());

        let result = decoder.to_record_batch().unwrap();

        assert_eq!(result.schema(), schema);
        assert_eq!(result.num_rows(), 0);
    }

    #[test]
    fn test_record_batch_with_data() {
        let schema = create_test_schema();
        let mut decoder = AvroDecoder::new(schema.clone());

        let avro_data = create_avro_data(vec![(1, "Alice")]);
        decoder.push_to_buffer(avro_data);

        let result = decoder.to_record_batch().unwrap();

        assert_eq!(result.schema(), schema);
        assert_eq!(result.num_rows(), 1);

        let id_array = result
            .column(0)
            .as_any()
            .downcast_ref::<Int32Array>()
            .unwrap();
        let name_array = result
            .column(1)
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();

        assert_eq!(id_array.value(0), 1);
        assert_eq!(name_array.value(0), "Alice");
    }

    #[test]
    fn test_invalid_avro_data() {
        let schema = create_test_schema();
        let mut decoder = AvroDecoder::new(schema);

        decoder.push_to_buffer(vec![1, 2, 3]);

        let result = decoder.to_record_batch();

        assert!(matches!(result, Err(DenormalizedError::DataFusion(_))));
    }
}
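
For reference, a minimal sketch of driving `AvroDecoder` by hand, reusing the test helpers above:

// Mirrors test_record_batch_with_data: one Avro payload in, one RecordBatch out.
let mut decoder = AvroDecoder::new(create_test_schema());
decoder.push_to_buffer(create_avro_data(vec![(42, "Bob")]));
let batch = decoder.to_record_batch().unwrap();
assert_eq!(batch.num_rows(), 1);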
1 change: 1 addition & 0 deletions crates/core/src/formats/decoders/json.rs
@@ -7,6 +7,7 @@ use crate::utils::arrow_helpers::json_records_to_arrow_record_batch;
 
 use super::Decoder;
 
+#[derive(Clone)]
 pub struct JsonDecoder {
     schema: Arc<Schema>,
     cache: Vec<Vec<u8>>,
4 changes: 3 additions & 1 deletion crates/core/src/formats/decoders/mod.rs
@@ -1,10 +1,12 @@
 use arrow_array::RecordBatch;
 use denormalized_common::DenormalizedError;
 
-pub trait Decoder {
+pub trait Decoder: Send + Sync {
     fn push_to_buffer(&mut self, bytes: Vec<u8>);
 
     fn to_record_batch(&mut self) -> Result<RecordBatch, DenormalizedError>;
 }
 
+pub mod avro;
 pub mod json;
+pub mod utils;
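
The new `Send + Sync` bound is what lets `kafka_stream_read.rs` move the boxed decoder into the task spawned via `builder.spawn`. A minimal sketch of the constraint, assuming a tokio runtime (which datafusion's stream builder uses internally):

// Box<dyn Decoder> can move into a spawned task only because the
// trait is declared Send + Sync.
fn spawn_decode_task(mut decoder: Box<dyn Decoder>, payload: Vec<u8>) {
    tokio::spawn(async move {
        decoder.push_to_buffer(payload);
        let _batch = decoder.to_record_batch();
    });
}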