Skip to content

Commit

Permalink
feat(substrait): add support for insert roundtrip in append mode
Browse files Browse the repository at this point in the history
  • Loading branch information
tokoko committed Oct 25, 2024
1 parent 813220d commit 7e6252b
Show file tree
Hide file tree
Showing 3 changed files with 137 additions and 21 deletions.
104 changes: 85 additions & 19 deletions datafusion/substrait/src/logical_plan/consumer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ use datafusion::logical_expr::{
};
use substrait::proto::expression::subquery::set_predicate::PredicateOp;
use substrait::proto::expression_reference::ExprType;
use substrait::proto::write_rel::{self, WriteType};
use url::Url;

use crate::extensions::Extensions;
Expand All @@ -57,8 +58,9 @@ use datafusion::common::scalar::ScalarStructBuilder;
use datafusion::dataframe::DataFrame;
use datafusion::logical_expr::expr::InList;
use datafusion::logical_expr::{
col, expr, Cast, Extension, GroupingSet, Like, LogicalPlanBuilder, Partitioning,
Repartition, Subquery, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition,
col, dml, expr, Cast, DmlStatement, Extension, GroupingSet, Like, LogicalPlanBuilder,
Partitioning, Repartition, Subquery, WindowFrameBound, WindowFrameUnits,
WindowFunctionDefinition,
};
use datafusion::prelude::JoinType;
use datafusion::sql::TableReference;
Expand Down Expand Up @@ -255,6 +257,27 @@ async fn except_rels(
Ok(rel)
}

/// Builds a [`TableReference`] from the name parts of a Substrait named table.
///
/// The number of parts selects the kind of reference:
/// - one part: a bare table name,
/// - two parts: `schema`, `table`,
/// - three or more parts: `catalog`, `schema`, `table` (only the first three
///   parts are read; anything after them is ignored).
///
/// # Errors
///
/// Returns a plan error when `names` is empty.
fn from_substrait_names(names: &[String]) -> Result<TableReference> {
    match names {
        [] => plan_err!("No table name found in NamedTable"),
        [table] => Ok(TableReference::Bare {
            table: table.clone().into(),
        }),
        [schema, table] => Ok(TableReference::Partial {
            schema: schema.clone().into(),
            table: table.clone().into(),
        }),
        [catalog, schema, table, ..] => Ok(TableReference::Full {
            catalog: catalog.clone().into(),
            schema: schema.clone().into(),
            table: table.clone().into(),
        }),
    }
}

/// Convert Substrait Plan to DataFusion LogicalPlan
pub async fn from_substrait_plan(
ctx: &SessionContext,
Expand Down Expand Up @@ -825,23 +848,7 @@ pub async fn from_substrait_rel(

match &read.as_ref().read_type {
Some(ReadType::NamedTable(nt)) => {
let table_reference = match nt.names.len() {
0 => {
return plan_err!("No table name found in NamedTable");
}
1 => TableReference::Bare {
table: nt.names[0].clone().into(),
},
2 => TableReference::Partial {
schema: nt.names[0].clone().into(),
table: nt.names[1].clone().into(),
},
_ => TableReference::Full {
catalog: nt.names[0].clone().into(),
schema: nt.names[1].clone().into(),
table: nt.names[2].clone().into(),
},
};
let table_reference = from_substrait_names(&nt.names)?;

let t = ctx.table(table_reference.clone()).await?;

Expand Down Expand Up @@ -1058,6 +1065,65 @@ pub async fn from_substrait_rel(
partitioning_scheme,
}))
}
Some(RelType::Write(write)) => {
    // Converts a Substrait `WriteRel` into a `LogicalPlan::Dml`.
    // Only named-table targets, the `Insert` write op (mapped to an append
    // insert) and the `Unspecified` output mode are handled; everything else
    // is reported as an error rather than guessed at.
    let table_name = match &write.write_type {
        Some(WriteType::NamedTable(now)) => from_substrait_names(&now.names)?,
        Some(WriteType::ExtensionTable(_)) => {
            return not_impl_err!("Unsupported WriteType: ExtensionTable");
        }
        _ => {
            return plan_err!("No WriteType specified in WriteRel");
        }
    };

    // `table_schema` is an optional field of the message, so a plan produced
    // elsewhere may omit it. Report that as a plan error instead of panicking.
    let table_schema = match &write.table_schema {
        Some(named_struct) => from_substrait_named_struct(named_struct, extensions)?,
        None => {
            return plan_err!("No table schema specified in WriteRel");
        }
    };

    let op = match write.op() {
        write_rel::WriteOp::Insert => dml::WriteOp::Insert(dml::InsertOp::Append),
        _ => {
            return not_impl_err!("Unsupported WriteOp: {:?}", write.op());
        }
    };

    // For an insert with unspecified output mode the statement yields a
    // single non-nullable UInt64 column named "count".
    let output_schema = match write.output() {
        write_rel::OutputMode::Unspecified => match write.op() {
            write_rel::WriteOp::Insert => {
                let fields: Fields = vec![Field::new(
                    "count".to_string(),
                    DataType::UInt64,
                    false,
                )]
                .into();

                DFSchema::try_from(Schema::new(fields))?
            }
            _ => {
                return not_impl_err!("Unsupported WriteOp: {:?}", write.op());
            }
        },
        write_rel::OutputMode::NoOutput => {
            return not_impl_err!("Unsupported OutputMode: NoOutput");
        }
        write_rel::OutputMode::ModifiedRecords => {
            return not_impl_err!("Unsupported OutputMode: ModifiedRecords");
        }
    };

    // `input` is optional as well; a WriteRel without an input relation has
    // nothing to write, so reject it instead of unwrapping.
    let input = match &write.input {
        Some(input_rel) => from_substrait_rel(ctx, input_rel, extensions).await?,
        None => {
            return plan_err!("No input specified in WriteRel");
        }
    };

    Ok(LogicalPlan::Dml(DmlStatement {
        table_name,
        table_schema: table_schema.into(),
        op,
        input: input.into(),
        output_schema: output_schema.into(),
    }))
}
_ => not_impl_err!("Unsupported RelType: {:?}", rel.rel_type),
}
}
Expand Down
41 changes: 39 additions & 2 deletions datafusion/substrait/src/logical_plan/producer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,17 @@
// under the License.

use datafusion::config::ConfigOptions;
use datafusion::logical_expr::dml::InsertOp;
use datafusion::optimizer::analyzer::expand_wildcard_rule::ExpandWildcardRule;
use datafusion::optimizer::AnalyzerRule;
use std::sync::Arc;
use substrait::proto::expression_reference::ExprType;
use substrait::proto::write_rel::{OutputMode, WriteOp, WriteType};

use arrow_buffer::ToByteSlice;
use datafusion::arrow::datatypes::{Field, IntervalUnit};
use datafusion::logical_expr::{
Distinct, FetchType, Like, Partitioning, SkipType, WindowFrameUnits,
dml, Distinct, FetchType, Like, Partitioning, SkipType, WindowFrameUnits,
};
use datafusion::{
arrow::datatypes::{DataType, TimeUnit},
Expand Down Expand Up @@ -67,7 +69,8 @@ use substrait::proto::read_rel::VirtualTable;
use substrait::proto::rel_common::EmitKind;
use substrait::proto::rel_common::EmitKind::Emit;
use substrait::proto::{
rel_common, ExchangeRel, ExpressionReference, ExtendedExpression, RelCommon,
rel_common, ExchangeRel, ExpressionReference, ExtendedExpression, NamedObjectWrite,
RelCommon, WriteRel,
};
use substrait::{
proto::{
Expand Down Expand Up @@ -584,6 +587,40 @@ pub fn to_substrait_rel(
rel_type: Some(RelType::Exchange(Box::new(exchange_rel))),
}))
}
LogicalPlan::Dml(stmt) => {
    // Serializes a DML statement as a Substrait `WriteRel` targeting a named
    // table. The binding is called `stmt` so it does not read like the
    // imported `dml` module used in the paths below.
    let input_rel = to_substrait_rel(stmt.input.as_ref(), ctx, extensions)?;

    // Translate the DataFusion write op; only append-mode inserts have a
    // Substrait counterpart among the insert flavours.
    let write_op = match stmt.op {
        dml::WriteOp::Insert(InsertOp::Append) => WriteOp::Insert,
        dml::WriteOp::Insert(InsertOp::Overwrite | InsertOp::Replace) => {
            return not_impl_err!(
                "Substrait does not support InsertOp::Overwrite and InsertOp::Replace"
            )
        }
        dml::WriteOp::Delete => WriteOp::Delete,
        dml::WriteOp::Update => WriteOp::Update,
        dml::WriteOp::Ctas => WriteOp::Ctas,
    };

    let schema = to_substrait_named_struct(&stmt.table_schema, extensions)?;

    // The table reference is flattened into its name parts.
    let target = NamedObjectWrite {
        names: stmt.table_name.to_vec(),
        advanced_extension: None,
    };

    let rel = WriteRel {
        common: None,
        input: Some(input_rel),
        write_type: Some(WriteType::NamedTable(target)),
        op: write_op.into(),
        table_schema: Some(schema),
        output: OutputMode::Unspecified.into(),
    };

    Ok(Box::new(Rel {
        rel_type: Some(RelType::Write(Box::new(rel))),
    }))
}
LogicalPlan::Extension(extension_plan) => {
let extension_bytes = ctx
.state()
Expand Down
13 changes: 13 additions & 0 deletions datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1093,6 +1093,19 @@ async fn roundtrip_repartition_hash() -> Result<()> {
Ok(())
}

// Round-trips an `INSERT INTO ... SELECT` statement and checks the
// resulting plan text against the expected `Dml` plan.
#[tokio::test]
async fn roundtrip_insert() -> Result<()> {
    let sql = "INSERT INTO data SELECT * FROM data";
    let expected_plan = "Dml: op=[Insert Into] table=[data]\
    \n Projection: data.a, data.b, data.c, data.d, data.e, data.f\
    \n Projection: data.a, data.b, data.c, data.d, data.e, data.f\
    \n TableScan: data";

    // NOTE(review): the meaning of the trailing `true` flag is defined by the
    // helper, which is not visible here — confirm against its signature.
    assert_expected_plan_unoptimized(sql, expected_plan, true).await
}

fn check_post_join_filters(rel: &Rel) -> Result<()> {
// search for target_rel and field value in proto
match &rel.rel_type {
Expand Down

0 comments on commit 7e6252b

Please sign in to comment.