From dcb9db48dba5ab8553d502af4da0a7312441d241 Mon Sep 17 00:00:00 2001 From: Victor Barua Date: Thu, 26 Dec 2024 15:08:49 -0800 Subject: [PATCH 01/13] feat(substrait): modular substrait producer --- .../substrait/src/logical_plan/producer.rs | 2294 ++++++++++------- 1 file changed, 1305 insertions(+), 989 deletions(-) diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index b73d246e1989..b905ee6c35df 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -22,7 +22,11 @@ use std::sync::Arc; use substrait::proto::expression_reference::ExprType; use datafusion::arrow::datatypes::{Field, IntervalUnit}; -use datafusion::logical_expr::{Distinct, Like, Partitioning, TryCast, WindowFrameUnits}; +use datafusion::logical_expr::{ + Aggregate, Distinct, EmptyRelation, Extension, Filter, Join, Like, Limit, + Partitioning, Projection, Repartition, Sort, SortExpr, SubqueryAlias, TableScan, + TryCast, Union, Values, Window, WindowFrameUnits, +}; use datafusion::{ arrow::datatypes::{DataType, TimeUnit}, error::{DataFusionError, Result}, @@ -43,11 +47,11 @@ use datafusion::arrow::array::{Array, GenericListArray, OffsetSizeTrait}; use datafusion::arrow::temporal_conversions::NANOSECONDS; use datafusion::common::{ exec_err, internal_err, not_impl_err, plan_err, substrait_datafusion_err, - substrait_err, DFSchema, DFSchemaRef, ToDFSchema, + substrait_err, Column, DFSchema, DFSchemaRef, ToDFSchema, }; -#[allow(unused_imports)] +use datafusion::execution::SessionState; use datafusion::logical_expr::expr::{ - Alias, BinaryExpr, Case, Cast, GroupingSet, InList, InSubquery, Sort, WindowFunction, + Alias, BinaryExpr, Case, Cast, GroupingSet, InList, InSubquery, WindowFunction, }; use datafusion::logical_expr::{expr, Between, JoinConstraint, LogicalPlan, Operator}; use datafusion::prelude::Expr; @@ -63,6 +67,7 @@ use substrait::proto::expression::literal::{ }; use substrait::proto::expression::subquery::InPredicate; use substrait::proto::expression::window_function::BoundsType; +use substrait::proto::expression::ScalarFunction; use substrait::proto::read_rel::VirtualTable; use substrait::proto::rel_common::EmitKind; use substrait::proto::rel_common::EmitKind::Emit; @@ -84,8 +89,7 @@ use substrait::{ window_function::bound::Kind as BoundKind, window_function::Bound, FieldReference, IfThen, Literal, MaskExpression, ReferenceSegment, RexType, - ScalarFunction, SingularOrList, Subquery, - WindowFunction as SubstraitWindowFunction, + SingularOrList, WindowFunction as SubstraitWindowFunction, }, function_argument::ArgType, join_rel, plan_rel, r#type, @@ -101,14 +105,279 @@ use substrait::{ version, }; -use super::state::SubstraitPlanningState; +pub trait SubstraitProducer: Send + Sync + Sized { + fn get_extensions(self) -> Extensions; + + fn register_function(&mut self, signature: String) -> u32; + + // Logical Plans + fn consume_plan(&mut self, plan: &LogicalPlan) -> Result> { + to_substrait_rel(self, plan) + } + + fn consume_projection(&mut self, plan: &Projection) -> Result> { + from_projection(self, plan) + } + + fn consume_filter(&mut self, plan: &Filter) -> Result> { + from_filter(self, plan) + } + + fn consume_window(&mut self, plan: &Window) -> Result> { + from_window(self, plan) + } + + fn consume_aggregate(&mut self, plan: &Aggregate) -> Result> { + from_aggregate(self, plan) + } + + fn consume_sort(&mut self, plan: &Sort) -> Result> { + from_sort(self, plan) + } + + fn consume_join(&mut self, plan: &Join) -> Result> { + from_join(self, plan) + } + + fn consume_repartition(&mut self, plan: &Repartition) -> Result> { + from_repartition(self, plan) + } + + fn consume_union(&mut self, plan: &Union) -> Result> { + from_union(self, plan) + } + + fn consume_table_scan(&mut self, plan: &TableScan) -> Result> { + from_table_scan(self, plan) + } + + fn consume_empty_relation(&mut self, plan: &EmptyRelation) -> Result> { + from_empty_relation(plan) + } + + fn consume_subquery_alias(&mut self, plan: &SubqueryAlias) -> Result> { + from_subquery_alias(self, plan) + } + + fn consume_limit(&mut self, plan: &Limit) -> Result> { + from_limit(self, plan) + } + + fn consume_values(&mut self, plan: &Values) -> Result> { + from_values(self, plan) + } + + fn consume_distinct(&mut self, plan: &Distinct) -> Result> { + from_distinct(self, plan) + } + + fn consume_extension(&mut self, _plan: &Extension) -> Result> { + substrait_err!("Specify handling for LogicalPlan::Extension by implementing the SubstraitProducer trait") + } + + // Expressions + fn consume_expr( + &mut self, + expr: &Expr, + schema: &DFSchemaRef, + col_ref_offset: usize, + ) -> Result { + to_substrait_rex(self, expr, schema, col_ref_offset) + } + + fn consume_alias( + &mut self, + alias: &Alias, + schema: &DFSchemaRef, + col_ref_offset: usize, + ) -> Result { + from_alias(self, alias, schema, col_ref_offset) + } + + fn consume_column( + &mut self, + column: &Column, + schema: &DFSchemaRef, + col_ref_offset: usize, + ) -> Result { + from_column(column, schema, col_ref_offset) + } + + fn consume_literal(&mut self, value: &ScalarValue) -> Result { + from_literal(self, value) + } + + fn consume_binary_expr( + &mut self, + expr: &BinaryExpr, + schema: &DFSchemaRef, + col_ref_offset: usize, + ) -> Result { + from_binary_expr(self, expr, schema, col_ref_offset) + } + + fn consume_like( + &mut self, + like: &Like, + schema: &DFSchemaRef, + col_ref_offset: usize, + ) -> Result { + from_like(self, like, schema, col_ref_offset) + } + + /// Handles: Not, IsNotNull, IsNull, IsTrue, IsFalse, IsUnknown, IsNotTrue, IsNotFalse, IsNotUnknown, Negative + fn consume_unary_expr( + &mut self, + expr: &Expr, + schema: &DFSchemaRef, + col_ref_offset: usize, + ) -> Result { + from_unary_expr(self, expr, schema, col_ref_offset) + } + + fn consume_between( + &mut self, + between: &Between, + schema: &DFSchemaRef, + col_ref_offset: usize, + ) -> Result { + from_between(self, between, schema, col_ref_offset) + } + + fn consume_case( + &mut self, + case: &Case, + schema: &DFSchemaRef, + col_ref_offset: usize, + ) -> Result { + from_case(self, case, schema, col_ref_offset) + } + + fn consume_cast( + &mut self, + cast: &Cast, + schema: &DFSchemaRef, + col_ref_offset: usize, + ) -> Result { + from_cast(self, cast, schema, col_ref_offset) + } + + fn consume_try_cast( + &mut self, + cast: &TryCast, + schema: &DFSchemaRef, + col_ref_offset: usize, + ) -> Result { + from_try_cast(self, cast, schema, col_ref_offset) + } + + fn consume_scalar_function( + &mut self, + scalar_fn: &expr::ScalarFunction, + schema: &DFSchemaRef, + col_ref_offset: usize, + ) -> Result { + from_scalar_function(self, scalar_fn, schema, col_ref_offset) + } + + fn consume_agg_function( + &mut self, + agg_fn: &expr::AggregateFunction, + schema: &DFSchemaRef, + ) -> Result { + from_aggregate_function(self, agg_fn, schema) + } + + fn consume_window_function( + &mut self, + window_fn: &WindowFunction, + schema: &DFSchemaRef, + col_ref_offset: usize, + ) -> Result { + from_window_function(self, window_fn, schema, col_ref_offset) + } + + fn consume_in_list( + &mut self, + in_list: &InList, + schema: &DFSchemaRef, + col_ref_offset: usize, + ) -> Result { + from_in_list(self, in_list, schema, col_ref_offset) + } + + fn consume_in_subquery( + &mut self, + in_subquery: &InSubquery, + schema: &DFSchemaRef, + col_ref_offset: usize, + ) -> Result { + from_in_subquery(self, in_subquery, schema, col_ref_offset) + } +} + +struct DefaultSubstraitProducer<'a> { + extensions: Extensions, + state: &'a SessionState, +} + +impl<'a> DefaultSubstraitProducer<'a> { + pub fn new(state: &'a SessionState) -> Self { + DefaultSubstraitProducer { + extensions: Extensions::default(), + state, + } + } +} + +impl SubstraitProducer for DefaultSubstraitProducer<'_> { + fn get_extensions(self) -> Extensions { + self.extensions + } + + fn register_function(&mut self, fn_name: String) -> u32 { + self.extensions.register_function(fn_name) + } + + fn consume_extension(&mut self, plan: &Extension) -> Result> { + let extension_bytes = self + .state + .serializer_registry() + .serialize_logical_plan(plan.node.as_ref())?; + let detail = ProtoAny { + type_url: plan.node.name().to_string(), + value: extension_bytes.into(), + }; + let mut inputs_rel = plan + .node + .inputs() + .into_iter() + .map(|plan| self.consume_plan(plan)) + .collect::>>()?; + let rel_type = match inputs_rel.len() { + 0 => RelType::ExtensionLeaf(ExtensionLeafRel { + common: None, + detail: Some(detail), + }), + 1 => RelType::ExtensionSingle(Box::new(ExtensionSingleRel { + common: None, + detail: Some(detail), + input: Some(inputs_rel.pop().unwrap()), + })), + _ => RelType::ExtensionMulti(ExtensionMultiRel { + common: None, + detail: Some(detail), + inputs: inputs_rel.into_iter().map(|r| *r).collect(), + }), + }; + Ok(Box::new(Rel { + rel_type: Some(rel_type), + })) + } +} /// Convert DataFusion LogicalPlan to Substrait Plan -pub fn to_substrait_plan( - plan: &LogicalPlan, - state: &dyn SubstraitPlanningState, -) -> Result> { - let mut extensions = Extensions::default(); +pub fn to_substrait_plan(plan: &LogicalPlan, state: &SessionState) -> Result> { // Parse relation nodes // Generate PlanRel(s) // Note: Only 1 relation tree is currently supported @@ -117,14 +386,16 @@ pub fn to_substrait_plan( let plan = Arc::new(ExpandWildcardRule::new()) .analyze(plan.clone(), &ConfigOptions::default())?; + let mut producer: DefaultSubstraitProducer = DefaultSubstraitProducer::new(state); let plan_rels = vec![PlanRel { rel_type: Some(plan_rel::RelType::Root(RelRoot { - input: Some(*to_substrait_rel(&plan, state, &mut extensions)?), + input: Some(*producer.consume_plan(&plan)?), names: to_substrait_named_struct(plan.schema())?.names, })), }]; // Return parsed plan + let extensions = producer.get_extensions(); Ok(Box::new(Plan { version: Some(version::version_with_producer("datafusion")), extension_uris: vec![], @@ -150,20 +421,14 @@ pub fn to_substrait_plan( pub fn to_substrait_extended_expr( exprs: &[(&Expr, &Field)], schema: &DFSchemaRef, - state: &dyn SubstraitPlanningState, + state: &SessionState, ) -> Result> { - let mut extensions = Extensions::default(); - + let mut producer = DefaultSubstraitProducer::new(state); let substrait_exprs = exprs .iter() .map(|(expr, field)| { - let substrait_expr = to_substrait_rex( - state, - expr, - schema, - /*col_ref_offset=*/ 0, - &mut extensions, - )?; + let substrait_expr = + producer.consume_expr(expr, schema, /*col_ref_offset=*/ 0)?; let mut output_names = Vec::new(); flatten_names(field, false, &mut output_names)?; Ok(ExpressionReference { @@ -174,6 +439,7 @@ pub fn to_substrait_extended_expr( .collect::>>()?; let substrait_schema = to_substrait_named_struct(schema)?; + let extensions = producer.get_extensions(); Ok(Box::new(ExtendedExpression { advanced_extensions: None, expected_type_urls: vec![], @@ -185,257 +451,291 @@ pub fn to_substrait_extended_expr( })) } -/// Convert DataFusion LogicalPlan to Substrait Rel -#[allow(deprecated)] pub fn to_substrait_rel( + producer: &mut impl SubstraitProducer, plan: &LogicalPlan, - state: &dyn SubstraitPlanningState, - extensions: &mut Extensions, ) -> Result> { match plan { - LogicalPlan::TableScan(scan) => { - let projection = scan.projection.as_ref().map(|p| { - p.iter() - .map(|i| StructItem { - field: *i as i32, - child: None, - }) - .collect() - }); + LogicalPlan::Projection(plan) => producer.consume_projection(plan), + LogicalPlan::Filter(plan) => producer.consume_filter(plan), + LogicalPlan::Window(plan) => producer.consume_window(plan), + LogicalPlan::Aggregate(plan) => producer.consume_aggregate(plan), + LogicalPlan::Sort(plan) => producer.consume_sort(plan), + LogicalPlan::Join(plan) => producer.consume_join(plan), + LogicalPlan::Repartition(plan) => producer.consume_repartition(plan), + LogicalPlan::Union(plan) => producer.consume_union(plan), + LogicalPlan::TableScan(plan) => producer.consume_table_scan(plan), + LogicalPlan::EmptyRelation(plan) => producer.consume_empty_relation(plan), + LogicalPlan::SubqueryAlias(plan) => producer.consume_subquery_alias(plan), + LogicalPlan::Limit(plan) => producer.consume_limit(plan), + LogicalPlan::Values(plan) => producer.consume_values(plan), + LogicalPlan::Distinct(plan) => producer.consume_distinct(plan), + LogicalPlan::Extension(plan) => producer.consume_extension(plan), + _ => not_impl_err!("Unsupported plan type: {plan:?}")?, + } +} - let projection = projection.map(|struct_items| MaskExpression { - select: Some(StructSelect { struct_items }), - maintain_singular_struct: false, - }); +pub fn from_table_scan( + _producer: &mut impl SubstraitProducer, + scan: &TableScan, +) -> Result> { + let projection = scan.projection.as_ref().map(|p| { + p.iter() + .map(|i| StructItem { + field: *i as i32, + child: None, + }) + .collect() + }); + + let projection = projection.map(|struct_items| MaskExpression { + select: Some(StructSelect { struct_items }), + maintain_singular_struct: false, + }); + + let table_schema = scan.source.schema().to_dfschema_ref()?; + let base_schema = to_substrait_named_struct(&table_schema)?; + + Ok(Box::new(Rel { + rel_type: Some(RelType::Read(Box::new(ReadRel { + common: None, + base_schema: Some(base_schema), + filter: None, + best_effort_filter: None, + projection, + advanced_extension: None, + read_type: Some(ReadType::NamedTable(NamedTable { + names: scan.table_name.to_vec(), + advanced_extension: None, + })), + }))), + })) +} - let table_schema = scan.source.schema().to_dfschema_ref()?; - let base_schema = to_substrait_named_struct(&table_schema)?; +pub fn from_empty_relation(e: &EmptyRelation) -> Result> { + if e.produce_one_row { + return not_impl_err!("Producing a row from empty relation is unsupported"); + } + #[allow(deprecated)] + Ok(Box::new(Rel { + rel_type: Some(RelType::Read(Box::new(ReadRel { + common: None, + base_schema: Some(to_substrait_named_struct(&e.schema)?), + filter: None, + best_effort_filter: None, + projection: None, + advanced_extension: None, + read_type: Some(ReadType::VirtualTable(VirtualTable { + values: vec![], + expressions: vec![], + })), + }))), + })) +} - Ok(Box::new(Rel { - rel_type: Some(RelType::Read(Box::new(ReadRel { - common: None, - base_schema: Some(base_schema), - filter: None, - best_effort_filter: None, - projection, - advanced_extension: None, - read_type: Some(ReadType::NamedTable(NamedTable { - names: scan.table_name.to_vec(), - advanced_extension: None, - })), - }))), - })) - } - LogicalPlan::EmptyRelation(e) => { - if e.produce_one_row { - return not_impl_err!( - "Producing a row from empty relation is unsupported" - ); - } - Ok(Box::new(Rel { - rel_type: Some(RelType::Read(Box::new(ReadRel { - common: None, - base_schema: Some(to_substrait_named_struct(&e.schema)?), - filter: None, - best_effort_filter: None, - projection: None, - advanced_extension: None, - read_type: Some(ReadType::VirtualTable(VirtualTable { - values: vec![], - expressions: vec![], - })), - }))), - })) - } - LogicalPlan::Values(v) => { - let values = v - .values +pub fn from_values( + producer: &mut impl SubstraitProducer, + v: &Values, +) -> Result> { + let values = v + .values + .iter() + .map(|row| { + let fields = row .iter() - .map(|row| { - let fields = row - .iter() - .map(|v| match v { - Expr::Literal(sv) => to_substrait_literal(sv, extensions), - Expr::Alias(alias) => match alias.expr.as_ref() { - // The schema gives us the names, so we can skip aliases - Expr::Literal(sv) => to_substrait_literal(sv, extensions), - _ => Err(substrait_datafusion_err!( + .map(|v| match v { + Expr::Literal(sv) => to_substrait_literal(producer, sv), + Expr::Alias(alias) => match alias.expr.as_ref() { + // The schema gives us the names, so we can skip aliases + Expr::Literal(sv) => to_substrait_literal(producer, sv), + _ => Err(substrait_datafusion_err!( "Only literal types can be aliased in Virtual Tables, got: {}", alias.expr.variant_name() )), - }, - _ => Err(substrait_datafusion_err!( + }, + _ => Err(substrait_datafusion_err!( "Only literal types and aliases are supported in Virtual Tables, got: {}", v.variant_name() )), - }) - .collect::>()?; - Ok(Struct { fields }) }) .collect::>()?; - Ok(Box::new(Rel { - rel_type: Some(RelType::Read(Box::new(ReadRel { - common: None, - base_schema: Some(to_substrait_named_struct(&v.schema)?), - filter: None, - best_effort_filter: None, - projection: None, - advanced_extension: None, - read_type: Some(ReadType::VirtualTable(VirtualTable { - values, - expressions: vec![], - })), - }))), - })) - } - LogicalPlan::Projection(p) => { - let expressions = p - .expr - .iter() - .map(|e| to_substrait_rex(state, e, p.input.schema(), 0, extensions)) - .collect::>>()?; + Ok(Struct { fields }) + }) + .collect::>()?; + #[allow(deprecated)] + Ok(Box::new(Rel { + rel_type: Some(RelType::Read(Box::new(ReadRel { + common: None, + base_schema: Some(to_substrait_named_struct(&v.schema)?), + filter: None, + best_effort_filter: None, + projection: None, + advanced_extension: None, + read_type: Some(ReadType::VirtualTable(VirtualTable { + values, + expressions: vec![], + })), + }))), + })) +} - let emit_kind = create_project_remapping( - expressions.len(), - p.input.as_ref().schema().fields().len(), - ); - let common = RelCommon { - emit_kind: Some(emit_kind), - hint: None, - advanced_extension: None, - }; +pub fn from_projection( + producer: &mut impl SubstraitProducer, + p: &Projection, +) -> Result> { + let expressions = p + .expr + .iter() + .map(|e| producer.consume_expr(e, p.input.schema(), 0)) + .collect::>>()?; - Ok(Box::new(Rel { - rel_type: Some(RelType::Project(Box::new(ProjectRel { - common: Some(common), - input: Some(to_substrait_rel(p.input.as_ref(), state, extensions)?), - expressions, - advanced_extension: None, - }))), - })) - } - LogicalPlan::Filter(filter) => { - let input = to_substrait_rel(filter.input.as_ref(), state, extensions)?; - let filter_expr = to_substrait_rex( - state, - &filter.predicate, - filter.input.schema(), - 0, - extensions, - )?; - Ok(Box::new(Rel { - rel_type: Some(RelType::Filter(Box::new(FilterRel { - common: None, - input: Some(input), - condition: Some(Box::new(filter_expr)), - advanced_extension: None, - }))), - })) - } - LogicalPlan::Limit(limit) => { - let input = to_substrait_rel(limit.input.as_ref(), state, extensions)?; - let empty_schema = Arc::new(DFSchema::empty()); - let offset_mode = limit - .skip - .as_ref() - .map(|expr| { - to_substrait_rex(state, expr.as_ref(), &empty_schema, 0, extensions) - }) - .transpose()? - .map(Box::new) - .map(fetch_rel::OffsetMode::OffsetExpr); - let count_mode = limit - .fetch - .as_ref() - .map(|expr| { - to_substrait_rex(state, expr.as_ref(), &empty_schema, 0, extensions) - }) - .transpose()? - .map(Box::new) - .map(fetch_rel::CountMode::CountExpr); + let emit_kind = create_project_remapping( + expressions.len(), + p.input.as_ref().schema().fields().len(), + ); + let common = RelCommon { + emit_kind: Some(emit_kind), + hint: None, + advanced_extension: None, + }; + + Ok(Box::new(Rel { + rel_type: Some(RelType::Project(Box::new(ProjectRel { + common: Some(common), + input: Some(producer.consume_plan(p.input.as_ref())?), + expressions, + advanced_extension: None, + }))), + })) +} + +pub fn from_filter( + producer: &mut impl SubstraitProducer, + filter: &Filter, +) -> Result> { + let input = producer.consume_plan(filter.input.as_ref())?; + let filter_expr = + producer.consume_expr(&filter.predicate, filter.input.schema(), 0)?; + Ok(Box::new(Rel { + rel_type: Some(RelType::Filter(Box::new(FilterRel { + common: None, + input: Some(input), + condition: Some(Box::new(filter_expr)), + advanced_extension: None, + }))), + })) +} + +pub fn from_limit( + producer: &mut impl SubstraitProducer, + limit: &Limit, +) -> Result> { + let input = producer.consume_plan(limit.input.as_ref())?; + let empty_schema = Arc::new(DFSchema::empty()); + let offset_mode = limit + .skip + .as_ref() + .map(|expr| producer.consume_expr(expr.as_ref(), &empty_schema, 0)) + .transpose()? + .map(Box::new) + .map(fetch_rel::OffsetMode::OffsetExpr); + let count_mode = limit + .fetch + .as_ref() + .map(|expr| producer.consume_expr(expr.as_ref(), &empty_schema, 0)) + .transpose()? + .map(Box::new) + .map(fetch_rel::CountMode::CountExpr); + Ok(Box::new(Rel { + rel_type: Some(RelType::Fetch(Box::new(FetchRel { + common: None, + input: Some(input), + offset_mode, + count_mode, + advanced_extension: None, + }))), + })) +} + +pub fn from_sort(producer: &mut impl SubstraitProducer, sort: &Sort) -> Result> { + let Sort { expr, input, fetch } = sort; + let sort_fields = expr + .iter() + .map(|e| substrait_sort_field(producer, e, input.schema())) + .collect::>>()?; + + let input = producer.consume_plan(input.as_ref())?; + + let sort_rel = Box::new(Rel { + rel_type: Some(RelType::Sort(Box::new(SortRel { + common: None, + input: Some(input), + sorts: sort_fields, + advanced_extension: None, + }))), + }); + + match fetch { + Some(amount) => { + let count_mode = + Some(fetch_rel::CountMode::CountExpr(Box::new(Expression { + rex_type: Some(RexType::Literal(Literal { + nullable: false, + type_variation_reference: DEFAULT_TYPE_VARIATION_REF, + literal_type: Some(LiteralType::I64(*amount as i64)), + })), + }))); Ok(Box::new(Rel { rel_type: Some(RelType::Fetch(Box::new(FetchRel { common: None, - input: Some(input), - offset_mode, + input: Some(sort_rel), + offset_mode: None, count_mode, advanced_extension: None, }))), })) } - LogicalPlan::Sort(datafusion::logical_expr::Sort { expr, input, fetch }) => { - let sort_fields = expr - .iter() - .map(|e| substrait_sort_field(state, e, input.schema(), extensions)) - .collect::>>()?; + None => Ok(sort_rel), + } +} - let input = to_substrait_rel(input.as_ref(), state, extensions)?; +pub fn from_aggregate( + producer: &mut impl SubstraitProducer, + agg: &Aggregate, +) -> Result> { + let input = producer.consume_plan(agg.input.as_ref())?; + let (grouping_expressions, groupings) = + to_substrait_groupings(producer, &agg.group_expr, agg.input.schema())?; + let measures = agg + .aggr_expr + .iter() + .map(|e| to_substrait_agg_measure(producer, e, agg.input.schema())) + .collect::>>()?; - let sort_rel = Box::new(Rel { - rel_type: Some(RelType::Sort(Box::new(SortRel { - common: None, - input: Some(input), - sorts: sort_fields, - advanced_extension: None, - }))), - }); - - match fetch { - Some(amount) => { - let count_mode = - Some(fetch_rel::CountMode::CountExpr(Box::new(Expression { - rex_type: Some(RexType::Literal(Literal { - nullable: false, - type_variation_reference: DEFAULT_TYPE_VARIATION_REF, - literal_type: Some(LiteralType::I64(*amount as i64)), - })), - }))); - Ok(Box::new(Rel { - rel_type: Some(RelType::Fetch(Box::new(FetchRel { - common: None, - input: Some(sort_rel), - offset_mode: None, - count_mode, - advanced_extension: None, - }))), - })) - } - None => Ok(sort_rel), - } - } - LogicalPlan::Aggregate(agg) => { - let input = to_substrait_rel(agg.input.as_ref(), state, extensions)?; - let (grouping_expressions, groupings) = to_substrait_groupings( - state, - &agg.group_expr, - agg.input.schema(), - extensions, - )?; - let measures = agg - .aggr_expr - .iter() - .map(|e| { - to_substrait_agg_measure(state, e, agg.input.schema(), extensions) - }) - .collect::>>()?; + Ok(Box::new(Rel { + rel_type: Some(RelType::Aggregate(Box::new(AggregateRel { + common: None, + input: Some(input), + grouping_expressions, + groupings, + measures, + advanced_extension: None, + }))), + })) +} - Ok(Box::new(Rel { - rel_type: Some(RelType::Aggregate(Box::new(AggregateRel { - common: None, - input: Some(input), - grouping_expressions, - groupings, - measures, - advanced_extension: None, - }))), - })) - } - LogicalPlan::Distinct(Distinct::All(plan)) => { +pub fn from_distinct( + producer: &mut impl SubstraitProducer, + distinct: &Distinct, +) -> Result> { + match distinct { + Distinct::All(plan) => { // Use Substrait's AggregateRel with empty measures to represent `select distinct` - let input = to_substrait_rel(plan.as_ref(), state, extensions)?; + let input = producer.consume_plan(plan.as_ref())?; // Get grouping keys from the input relation's number of output fields let grouping = (0..plan.schema().fields().len()) .map(substrait_field_ref) .collect::>>()?; + #[allow(deprecated)] Ok(Box::new(Rel { rel_type: Some(RelType::Aggregate(Box::new(AggregateRel { common: None, @@ -450,220 +750,186 @@ pub fn to_substrait_rel( }))), })) } - LogicalPlan::Join(join) => { - let left = to_substrait_rel(join.left.as_ref(), state, extensions)?; - let right = to_substrait_rel(join.right.as_ref(), state, extensions)?; - let join_type = to_substrait_jointype(join.join_type); - // we only support basic joins so return an error for anything not yet supported - match join.join_constraint { - JoinConstraint::On => {} - JoinConstraint::Using => { - return not_impl_err!("join constraint: `using`") - } - } - // parse filter if exists - let in_join_schema = join.left.schema().join(join.right.schema())?; - let join_filter = match &join.filter { - Some(filter) => Some(to_substrait_rex( - state, - filter, - &Arc::new(in_join_schema), - 0, - extensions, - )?), - None => None, - }; + Distinct::On(_) => not_impl_err!("Cannot convert Distinct::On"), + } +} - // map the left and right columns to binary expressions in the form `l = r` - // build a single expression for the ON condition, such as `l.a = r.a AND l.b = r.b` - let eq_op = if join.null_equals_null { - Operator::IsNotDistinctFrom - } else { - Operator::Eq - }; - let join_on = to_substrait_join_expr( - state, - &join.on, - eq_op, - join.left.schema(), - join.right.schema(), - extensions, - )?; - - // create conjunction between `join_on` and `join_filter` to embed all join conditions, - // whether equal or non-equal in a single expression - let join_expr = match &join_on { - Some(on_expr) => match &join_filter { - Some(filter) => Some(Box::new(make_binary_op_scalar_func( - on_expr, - filter, - Operator::And, - extensions, - ))), - None => join_on.map(Box::new), // the join expression will only contain `join_on` if filter doesn't exist - }, - None => match &join_filter { - Some(_) => join_filter.map(Box::new), // the join expression will only contain `join_filter` if the `on` condition doesn't exist - None => None, - }, - }; +pub fn from_join(producer: &mut impl SubstraitProducer, join: &Join) -> Result> { + let left = producer.consume_plan(join.left.as_ref())?; + let right = producer.consume_plan(join.right.as_ref())?; + let join_type = to_substrait_jointype(join.join_type); + // we only support basic joins so return an error for anything not yet supported + match join.join_constraint { + JoinConstraint::On => {} + JoinConstraint::Using => return not_impl_err!("join constraint: `using`"), + } + // parse filter if exists + let in_join_schema = join.left.schema().join(join.right.schema())?; + let join_filter = match &join.filter { + Some(filter) => Some(to_substrait_rex( + producer, + filter, + &Arc::new(in_join_schema), + 0, + )?), + None => None, + }; - Ok(Box::new(Rel { - rel_type: Some(RelType::Join(Box::new(JoinRel { - common: None, - left: Some(left), - right: Some(right), - r#type: join_type as i32, - expression: join_expr, - post_join_filter: None, - advanced_extension: None, - }))), - })) - } - LogicalPlan::SubqueryAlias(alias) => { - // Do nothing if encounters SubqueryAlias - // since there is no corresponding relation type in Substrait - to_substrait_rel(alias.input.as_ref(), state, extensions) - } - LogicalPlan::Union(union) => { - let input_rels = union - .inputs - .iter() - .map(|input| to_substrait_rel(input.as_ref(), state, extensions)) - .collect::>>()? - .into_iter() - .map(|ptr| *ptr) - .collect(); - Ok(Box::new(Rel { - rel_type: Some(RelType::Set(SetRel { - common: None, - inputs: input_rels, - op: set_rel::SetOp::UnionAll as i32, // UNION DISTINCT gets translated to AGGREGATION + UNION ALL - advanced_extension: None, - })), - })) - } - LogicalPlan::Window(window) => { - let input = to_substrait_rel(window.input.as_ref(), state, extensions)?; + // map the left and right columns to binary expressions in the form `l = r` + // build a single expression for the ON condition, such as `l.a = r.a AND l.b = r.b` + let eq_op = if join.null_equals_null { + Operator::IsNotDistinctFrom + } else { + Operator::Eq + }; + let join_on = to_substrait_join_expr( + producer, + &join.on, + eq_op, + join.left.schema(), + join.right.schema(), + )?; - // create a field reference for each input field - let mut expressions = (0..window.input.schema().fields().len()) - .map(substrait_field_ref) - .collect::>>()?; + // create conjunction between `join_on` and `join_filter` to embed all join conditions, + // whether equal or non-equal in a single expression + let join_expr = match &join_on { + Some(on_expr) => match &join_filter { + Some(filter) => Some(Box::new(make_binary_op_scalar_func( + producer, + on_expr, + filter, + Operator::And, + ))), + None => join_on.map(Box::new), // the join expression will only contain `join_on` if filter doesn't exist + }, + None => match &join_filter { + Some(_) => join_filter.map(Box::new), // the join expression will only contain `join_filter` if the `on` condition doesn't exist + None => None, + }, + }; - // process and add each window function expression - for expr in &window.window_expr { - expressions.push(to_substrait_rex( - state, - expr, - window.input.schema(), - 0, - extensions, - )?); - } + Ok(Box::new(Rel { + rel_type: Some(RelType::Join(Box::new(JoinRel { + common: None, + left: Some(left), + right: Some(right), + r#type: join_type as i32, + expression: join_expr, + post_join_filter: None, + advanced_extension: None, + }))), + })) +} - let emit_kind = create_project_remapping( - expressions.len(), - window.input.schema().fields().len(), - ); - let common = RelCommon { - emit_kind: Some(emit_kind), - hint: None, - advanced_extension: None, - }; - let project_rel = Box::new(ProjectRel { - common: Some(common), - input: Some(input), - expressions, - advanced_extension: None, - }); +pub fn from_subquery_alias( + producer: &mut impl SubstraitProducer, + alias: &SubqueryAlias, +) -> Result> { + // Do nothing if encounters SubqueryAlias + // since there is no corresponding relation type in Substrait + producer.consume_plan(alias.input.as_ref()) +} - Ok(Box::new(Rel { - rel_type: Some(RelType::Project(project_rel)), - })) +pub fn from_union( + producer: &mut impl SubstraitProducer, + union: &Union, +) -> Result> { + let input_rels = union + .inputs + .iter() + .map(|input| producer.consume_plan(input.as_ref())) + .collect::>>()? + .into_iter() + .map(|ptr| *ptr) + .collect(); + Ok(Box::new(Rel { + rel_type: Some(RelType::Set(SetRel { + common: None, + inputs: input_rels, + op: set_rel::SetOp::UnionAll as i32, // UNION DISTINCT gets translated to AGGREGATION + UNION ALL + advanced_extension: None, + })), + })) +} + +pub fn from_window( + producer: &mut impl SubstraitProducer, + window: &Window, +) -> Result> { + let input = producer.consume_plan(window.input.as_ref())?; + + // create a field reference for each input field + let mut expressions = (0..window.input.schema().fields().len()) + .map(substrait_field_ref) + .collect::>>()?; + + // process and add each window function expression + for expr in &window.window_expr { + expressions.push(producer.consume_expr(expr, window.input.schema(), 0)?); + } + + let emit_kind = + create_project_remapping(expressions.len(), window.input.schema().fields().len()); + let common = RelCommon { + emit_kind: Some(emit_kind), + hint: None, + advanced_extension: None, + }; + let project_rel = Box::new(ProjectRel { + common: Some(common), + input: Some(input), + expressions, + advanced_extension: None, + }); + + Ok(Box::new(Rel { + rel_type: Some(RelType::Project(project_rel)), + })) +} + +pub fn from_repartition( + producer: &mut impl SubstraitProducer, + repartition: &Repartition, +) -> Result> { + let input = producer.consume_plan(repartition.input.as_ref())?; + let partition_count = match repartition.partitioning_scheme { + Partitioning::RoundRobinBatch(num) => num, + Partitioning::Hash(_, num) => num, + Partitioning::DistributeBy(_) => { + return not_impl_err!( + "Physical plan does not support DistributeBy partitioning" + ) } - LogicalPlan::Repartition(repartition) => { - let input = to_substrait_rel(repartition.input.as_ref(), state, extensions)?; - let partition_count = match repartition.partitioning_scheme { - Partitioning::RoundRobinBatch(num) => num, - Partitioning::Hash(_, num) => num, - Partitioning::DistributeBy(_) => { - return not_impl_err!( - "Physical plan does not support DistributeBy partitioning" - ) - } - }; - // ref: https://substrait.io/relations/physical_relations/#exchange-types - let exchange_kind = match &repartition.partitioning_scheme { - Partitioning::RoundRobinBatch(_) => { - ExchangeKind::RoundRobin(RoundRobin::default()) - } - Partitioning::Hash(exprs, _) => { - let fields = exprs - .iter() - .map(|e| { - try_to_substrait_field_reference( - e, - repartition.input.schema(), - ) - }) - .collect::>>()?; - ExchangeKind::ScatterByFields(ScatterFields { fields }) - } - Partitioning::DistributeBy(_) => { - return not_impl_err!( - "Physical plan does not support DistributeBy partitioning" - ) - } - }; - let exchange_rel = ExchangeRel { - common: None, - input: Some(input), - exchange_kind: Some(exchange_kind), - advanced_extension: None, - partition_count: partition_count as i32, - targets: vec![], - }; - Ok(Box::new(Rel { - rel_type: Some(RelType::Exchange(Box::new(exchange_rel))), - })) + }; + // ref: https://substrait.io/relations/physical_relations/#exchange-types + let exchange_kind = match &repartition.partitioning_scheme { + Partitioning::RoundRobinBatch(_) => { + ExchangeKind::RoundRobin(RoundRobin::default()) } - LogicalPlan::Extension(extension_plan) => { - let extension_bytes = state - .serializer_registry() - .serialize_logical_plan(extension_plan.node.as_ref())?; - let detail = ProtoAny { - type_url: extension_plan.node.name().to_string(), - value: extension_bytes.into(), - }; - let mut inputs_rel = extension_plan - .node - .inputs() - .into_iter() - .map(|plan| to_substrait_rel(plan, state, extensions)) + Partitioning::Hash(exprs, _) => { + let fields = exprs + .iter() + .map(|e| try_to_substrait_field_reference(e, repartition.input.schema())) .collect::>>()?; - let rel_type = match inputs_rel.len() { - 0 => RelType::ExtensionLeaf(ExtensionLeafRel { - common: None, - detail: Some(detail), - }), - 1 => RelType::ExtensionSingle(Box::new(ExtensionSingleRel { - common: None, - detail: Some(detail), - input: Some(inputs_rel.pop().unwrap()), - })), - _ => RelType::ExtensionMulti(ExtensionMultiRel { - common: None, - detail: Some(detail), - inputs: inputs_rel.into_iter().map(|r| *r).collect(), - }), - }; - Ok(Box::new(Rel { - rel_type: Some(rel_type), - })) + ExchangeKind::ScatterByFields(ScatterFields { fields }) } - _ => not_impl_err!("Unsupported operator: {plan}"), - } + Partitioning::DistributeBy(_) => { + return not_impl_err!( + "Physical plan does not support DistributeBy partitioning" + ) + } + }; + let exchange_rel = ExchangeRel { + common: None, + input: Some(input), + exchange_kind: Some(exchange_kind), + advanced_extension: None, + partition_count: partition_count as i32, + targets: vec![], + }; + Ok(Box::new(Rel { + rel_type: Some(RelType::Exchange(Box::new(exchange_rel))), + })) } /// By default, a Substrait Project outputs all input fields followed by all expressions. @@ -730,32 +996,30 @@ fn to_substrait_named_struct(schema: &DFSchemaRef) -> Result { } fn to_substrait_join_expr( - state: &dyn SubstraitPlanningState, + producer: &mut impl SubstraitProducer, join_conditions: &Vec<(Expr, Expr)>, eq_op: Operator, left_schema: &DFSchemaRef, right_schema: &DFSchemaRef, - extensions: &mut Extensions, ) -> Result> { // Only support AND conjunction for each binary expression in join conditions let mut exprs: Vec = vec![]; for (left, right) in join_conditions { // Parse left - let l = to_substrait_rex(state, left, left_schema, 0, extensions)?; + let l = producer.consume_expr(left, left_schema, 0)?; // Parse right let r = to_substrait_rex( - state, + producer, right, right_schema, left_schema.fields().len(), // offset to return the correct index - extensions, )?; // AND with existing expression - exprs.push(make_binary_op_scalar_func(&l, &r, eq_op, extensions)); + exprs.push(make_binary_op_scalar_func(producer, &l, &r, eq_op)); } let join_expr: Option = exprs.into_iter().reduce(|acc: Expression, e: Expression| { - make_binary_op_scalar_func(&acc, &e, Operator::And, extensions) + make_binary_op_scalar_func(producer, &acc, &e, Operator::And) }); Ok(join_expr) } @@ -811,23 +1075,22 @@ pub fn operator_to_name(op: Operator) -> &'static str { } } -#[allow(deprecated)] pub fn parse_flat_grouping_exprs( - state: &dyn SubstraitPlanningState, + producer: &mut impl SubstraitProducer, exprs: &[Expr], schema: &DFSchemaRef, - extensions: &mut Extensions, ref_group_exprs: &mut Vec, ) -> Result { let mut expression_references = vec![]; let mut grouping_expressions = vec![]; for e in exprs { - let rex = to_substrait_rex(state, e, schema, 0, extensions)?; + let rex = producer.consume_expr(e, schema, 0)?; grouping_expressions.push(rex.clone()); ref_group_exprs.push(rex); expression_references.push((ref_group_exprs.len() - 1) as u32); } + #[allow(deprecated)] Ok(Grouping { grouping_expressions, expression_references, @@ -835,10 +1098,9 @@ pub fn parse_flat_grouping_exprs( } pub fn to_substrait_groupings( - state: &dyn SubstraitPlanningState, + producer: &mut impl SubstraitProducer, exprs: &[Expr], schema: &DFSchemaRef, - extensions: &mut Extensions, ) -> Result<(Vec, Vec)> { let mut ref_group_exprs = vec![]; let groupings = match exprs.len() { @@ -851,10 +1113,9 @@ pub fn to_substrait_groupings( .iter() .map(|set| { parse_flat_grouping_exprs( - state, + producer, set, schema, - extensions, &mut ref_group_exprs, ) }) @@ -869,10 +1130,9 @@ pub fn to_substrait_groupings( .rev() .map(|set| { parse_flat_grouping_exprs( - state, + producer, set, schema, - extensions, &mut ref_group_exprs, ) }) @@ -880,66 +1140,81 @@ pub fn to_substrait_groupings( } }, _ => Ok(vec![parse_flat_grouping_exprs( - state, + producer, exprs, schema, - extensions, &mut ref_group_exprs, )?]), }, _ => Ok(vec![parse_flat_grouping_exprs( - state, + producer, exprs, schema, - extensions, &mut ref_group_exprs, )?]), }?; Ok((ref_group_exprs, groupings)) } -#[allow(deprecated)] +pub fn from_aggregate_function( + producer: &mut impl SubstraitProducer, + agg_fn: &expr::AggregateFunction, + schema: &DFSchemaRef, +) -> Result { + let expr::AggregateFunction { + func, + args, + distinct, + filter, + order_by, + null_treatment: _null_treatment, + } = agg_fn; + let sorts = if let Some(order_by) = order_by { + order_by + .iter() + .map(|expr| to_substrait_sort_field(producer, expr, schema)) + .collect::>>()? + } else { + vec![] + }; + let mut arguments: Vec = vec![]; + for arg in args { + arguments.push(FunctionArgument { + arg_type: Some(ArgType::Value(producer.consume_expr(arg, schema, 0)?)), + }); + } + let function_anchor = producer.register_function(func.name().to_string()); + #[allow(deprecated)] + Ok(Measure { + measure: Some(AggregateFunction { + function_reference: function_anchor, + arguments, + sorts, + output_type: None, + invocation: match distinct { + true => AggregationInvocation::Distinct as i32, + false => AggregationInvocation::All as i32, + }, + phase: AggregationPhase::Unspecified as i32, + args: vec![], + options: vec![], + }), + filter: match filter { + Some(f) => Some(producer.consume_expr(f, schema, 0)?), + None => None, + }, + }) +} + pub fn to_substrait_agg_measure( - state: &dyn SubstraitPlanningState, + producer: &mut impl SubstraitProducer, expr: &Expr, schema: &DFSchemaRef, - extensions: &mut Extensions, ) -> Result { match expr { - Expr::AggregateFunction(expr::AggregateFunction { func, args, distinct, filter, order_by, null_treatment: _, }) => { - let sorts = if let Some(order_by) = order_by { - order_by.iter().map(|expr| to_substrait_sort_field(state, expr, schema, extensions)).collect::>>()? - } else { - vec![] - }; - let mut arguments: Vec = vec![]; - for arg in args { - arguments.push(FunctionArgument { arg_type: Some(ArgType::Value(to_substrait_rex(state, arg, schema, 0, extensions)?)) }); - } - let function_anchor = extensions.register_function(func.name().to_string()); - Ok(Measure { - measure: Some(AggregateFunction { - function_reference: function_anchor, - arguments, - sorts, - output_type: None, - invocation: match distinct { - true => AggregationInvocation::Distinct as i32, - false => AggregationInvocation::All as i32, - }, - phase: AggregationPhase::Unspecified as i32, - args: vec![], - options: vec![], - }), - filter: match filter { - Some(f) => Some(to_substrait_rex(state, f, schema, 0, extensions)?), - None => None - } - }) - - } - Expr::Alias(Alias{expr,..})=> { - to_substrait_agg_measure(state, expr, schema, extensions) + Expr::AggregateFunction(agg_fn) => from_aggregate_function(producer, agg_fn, schema), + Expr::Alias(Alias{expr,..}) => { + to_substrait_agg_measure(producer, expr, schema) } _ => internal_err!( "Expression must be compatible with aggregation. Unsupported expression: {:?}. ExpressionType: {:?}", @@ -951,10 +1226,9 @@ pub fn to_substrait_agg_measure( /// Converts sort expression to corresponding substrait `SortField` fn to_substrait_sort_field( - state: &dyn SubstraitPlanningState, - sort: &Sort, + producer: &mut impl SubstraitProducer, + sort: &expr::Sort, schema: &DFSchemaRef, - extensions: &mut Extensions, ) -> Result { let sort_kind = match (sort.asc, sort.nulls_first) { (true, true) => SortDirection::AscNullsFirst, @@ -963,485 +1237,529 @@ fn to_substrait_sort_field( (false, false) => SortDirection::DescNullsLast, }; Ok(SortField { - expr: Some(to_substrait_rex(state, &sort.expr, schema, 0, extensions)?), + expr: Some(producer.consume_expr(&sort.expr, schema, 0)?), sort_kind: Some(SortKind::Direction(sort_kind.into())), }) } -/// Return Substrait scalar function with two arguments -#[allow(deprecated)] -pub fn make_binary_op_scalar_func( - lhs: &Expression, - rhs: &Expression, - op: Operator, - extensions: &mut Extensions, -) -> Expression { - let function_anchor = extensions.register_function(operator_to_name(op).to_string()); - Expression { +/// Return Substrait scalar function with two arguments +pub fn make_binary_op_scalar_func( + producer: &mut impl SubstraitProducer, + lhs: &Expression, + rhs: &Expression, + op: Operator, +) -> Expression { + let function_anchor = producer.register_function(operator_to_name(op).to_string()); + #[allow(deprecated)] + Expression { + rex_type: Some(RexType::ScalarFunction(ScalarFunction { + function_reference: function_anchor, + arguments: vec![ + FunctionArgument { + arg_type: Some(ArgType::Value(lhs.clone())), + }, + FunctionArgument { + arg_type: Some(ArgType::Value(rhs.clone())), + }, + ], + output_type: None, + args: vec![], + options: vec![], + })), + } +} + +/// Convert DataFusion Expr to Substrait Rex +/// +/// # Arguments +/// +/// * `expr` - DataFusion expression to be parse into a Substrait expression +/// * `schema` - DataFusion input schema for looking up field qualifiers +/// * `col_ref_offset` - Offset for calculating Substrait field reference indices. +/// This should only be set by caller with more than one input relations i.e. Join. +/// Substrait expects one set of indices when joining two relations. +/// Let's say `left` and `right` have `m` and `n` columns, respectively. The `right` +/// relation will have column indices from `0` to `n-1`, however, Substrait will expect +/// the `right` indices to be offset by the `left`. This means Substrait will expect to +/// evaluate the join condition expression on indices [0 .. n-1, n .. n+m-1]. For example: +/// ```SELECT * +/// FROM t1 +/// JOIN t2 +/// ON t1.c1 = t2.c0;``` +/// where t1 consists of columns [c0, c1, c2], and t2 = columns [c0, c1] +/// the join condition should become +/// `col_ref(1) = col_ref(3 + 0)` +/// , where `3` is the number of `left` columns (`col_ref_offset`) and `0` is the index +/// of the join key column from `right` +/// * `extensions` - Substrait extension info. Contains registered function information +pub fn to_substrait_rex( + producer: &mut impl SubstraitProducer, + expr: &Expr, + schema: &DFSchemaRef, + col_ref_offset: usize, +) -> Result { + match expr { + Expr::Alias(expr) => producer.consume_alias(expr, schema, col_ref_offset), + Expr::Column(expr) => producer.consume_column(expr, schema, col_ref_offset), + Expr::Literal(expr) => producer.consume_literal(expr), + Expr::BinaryExpr(expr) => { + producer.consume_binary_expr(expr, schema, col_ref_offset) + } + Expr::Like(expr) => producer.consume_like(expr, schema, col_ref_offset), + Expr::SimilarTo(_) => not_impl_err!("SimilarTo is not supported"), + Expr::Not(_) => producer.consume_unary_expr(expr, schema, col_ref_offset), + Expr::IsNotNull(_) => producer.consume_unary_expr(expr, schema, col_ref_offset), + Expr::IsNull(_) => producer.consume_unary_expr(expr, schema, col_ref_offset), + Expr::IsTrue(_) => producer.consume_unary_expr(expr, schema, col_ref_offset), + Expr::IsFalse(_) => producer.consume_unary_expr(expr, schema, col_ref_offset), + Expr::IsUnknown(_) => producer.consume_unary_expr(expr, schema, col_ref_offset), + Expr::IsNotTrue(_) => producer.consume_unary_expr(expr, schema, col_ref_offset), + Expr::IsNotFalse(_) => producer.consume_unary_expr(expr, schema, col_ref_offset), + Expr::IsNotUnknown(_) => { + producer.consume_unary_expr(expr, schema, col_ref_offset) + } + Expr::Negative(_) => producer.consume_unary_expr(expr, schema, col_ref_offset), + Expr::Between(expr) => producer.consume_between(expr, schema, col_ref_offset), + Expr::Case(expr) => producer.consume_case(expr, schema, col_ref_offset), + Expr::Cast(expr) => producer.consume_cast(expr, schema, col_ref_offset), + Expr::TryCast(expr) => producer.consume_try_cast(expr, schema, col_ref_offset), + Expr::ScalarFunction(expr) => { + producer.consume_scalar_function(expr, schema, col_ref_offset) + } + Expr::AggregateFunction(_) => { + internal_err!( + "AggregateFunction should only be encountered as part of a LogicalPlan::Aggregate" + ) + } + Expr::WindowFunction(expr) => { + producer.consume_window_function(expr, schema, col_ref_offset) + } + Expr::InList(expr) => producer.consume_in_list(expr, schema, col_ref_offset), + Expr::InSubquery(expr) => { + producer.consume_in_subquery(expr, schema, col_ref_offset) + } + _ => not_impl_err!("Cannot convert {expr:?} to Substrait"), + } +} + +pub fn from_in_list( + producer: &mut impl SubstraitProducer, + in_list: &InList, + schema: &DFSchemaRef, + col_ref_offset: usize, +) -> Result { + let InList { + expr, + list, + negated, + } = in_list; + let substrait_list = list + .iter() + .map(|x| producer.consume_expr(x, schema, col_ref_offset)) + .collect::>>()?; + let substrait_expr = producer.consume_expr(expr, schema, col_ref_offset)?; + + let substrait_or_list = Expression { + rex_type: Some(RexType::SingularOrList(Box::new(SingularOrList { + value: Some(Box::new(substrait_expr)), + options: substrait_list, + }))), + }; + + if *negated { + let function_anchor = producer.register_function("not".to_string()); + + #[allow(deprecated)] + Ok(Expression { + rex_type: Some(RexType::ScalarFunction(ScalarFunction { + function_reference: function_anchor, + arguments: vec![FunctionArgument { + arg_type: Some(ArgType::Value(substrait_or_list)), + }], + output_type: None, + args: vec![], + options: vec![], + })), + }) + } else { + Ok(substrait_or_list) + } +} + +pub fn from_scalar_function( + producer: &mut impl SubstraitProducer, + fun: &expr::ScalarFunction, + schema: &DFSchemaRef, + col_ref_offset: usize, +) -> Result { + let mut arguments: Vec = vec![]; + for arg in &fun.args { + arguments.push(FunctionArgument { + arg_type: Some(ArgType::Value(to_substrait_rex( + producer, + arg, + schema, + col_ref_offset, + )?)), + }); + } + + let function_anchor = producer.register_function(fun.name().to_string()); + #[allow(deprecated)] + Ok(Expression { rex_type: Some(RexType::ScalarFunction(ScalarFunction { function_reference: function_anchor, - arguments: vec![ - FunctionArgument { - arg_type: Some(ArgType::Value(lhs.clone())), - }, - FunctionArgument { - arg_type: Some(ArgType::Value(rhs.clone())), - }, - ], + arguments, output_type: None, - args: vec![], options: vec![], + args: vec![], })), - } + }) } -/// Convert DataFusion Expr to Substrait Rex -/// -/// # Arguments -/// -/// * `expr` - DataFusion expression to be parse into a Substrait expression -/// * `schema` - DataFusion input schema for looking up field qualifiers -/// * `col_ref_offset` - Offset for calculating Substrait field reference indices. -/// This should only be set by caller with more than one input relations i.e. Join. -/// Substrait expects one set of indices when joining two relations. -/// Let's say `left` and `right` have `m` and `n` columns, respectively. The `right` -/// relation will have column indices from `0` to `n-1`, however, Substrait will expect -/// the `right` indices to be offset by the `left`. This means Substrait will expect to -/// evaluate the join condition expression on indices [0 .. n-1, n .. n+m-1]. For example: -/// ```SELECT * -/// FROM t1 -/// JOIN t2 -/// ON t1.c1 = t2.c0;``` -/// where t1 consists of columns [c0, c1, c2], and t2 = columns [c0, c1] -/// the join condition should become -/// `col_ref(1) = col_ref(3 + 0)` -/// , where `3` is the number of `left` columns (`col_ref_offset`) and `0` is the index -/// of the join key column from `right` -/// * `extensions` - Substrait extension info. Contains registered function information -#[allow(deprecated)] -pub fn to_substrait_rex( - state: &dyn SubstraitPlanningState, - expr: &Expr, +pub fn from_between( + producer: &mut impl SubstraitProducer, + between: &Between, schema: &DFSchemaRef, col_ref_offset: usize, - extensions: &mut Extensions, ) -> Result { - match expr { - Expr::InList(InList { - expr, - list, - negated, - }) => { - let substrait_list = list - .iter() - .map(|x| to_substrait_rex(state, x, schema, col_ref_offset, extensions)) - .collect::>>()?; - let substrait_expr = - to_substrait_rex(state, expr, schema, col_ref_offset, extensions)?; + let Between { + expr, + negated, + low, + high, + } = between; + if *negated { + // `expr NOT BETWEEN low AND high` can be translated into (expr < low OR high < expr) + let substrait_expr = + producer.consume_expr(expr.as_ref(), schema, col_ref_offset)?; + let substrait_low = + producer.consume_expr(low.as_ref(), schema, col_ref_offset)?; + let substrait_high = + producer.consume_expr(high.as_ref(), schema, col_ref_offset)?; + + let l_expr = make_binary_op_scalar_func( + producer, + &substrait_expr, + &substrait_low, + Operator::Lt, + ); + let r_expr = make_binary_op_scalar_func( + producer, + &substrait_high, + &substrait_expr, + Operator::Lt, + ); - let substrait_or_list = Expression { - rex_type: Some(RexType::SingularOrList(Box::new(SingularOrList { - value: Some(Box::new(substrait_expr)), - options: substrait_list, - }))), - }; + Ok(make_binary_op_scalar_func( + producer, + &l_expr, + &r_expr, + Operator::Or, + )) + } else { + // `expr BETWEEN low AND high` can be translated into (low <= expr AND expr <= high) + let substrait_expr = + producer.consume_expr(expr.as_ref(), schema, col_ref_offset)?; + let substrait_low = + producer.consume_expr(low.as_ref(), schema, col_ref_offset)?; + let substrait_high = + producer.consume_expr(high.as_ref(), schema, col_ref_offset)?; + + let l_expr = make_binary_op_scalar_func( + producer, + &substrait_low, + &substrait_expr, + Operator::LtEq, + ); + let r_expr = make_binary_op_scalar_func( + producer, + &substrait_expr, + &substrait_high, + Operator::LtEq, + ); - if *negated { - let function_anchor = extensions.register_function("not".to_string()); - - Ok(Expression { - rex_type: Some(RexType::ScalarFunction(ScalarFunction { - function_reference: function_anchor, - arguments: vec![FunctionArgument { - arg_type: Some(ArgType::Value(substrait_or_list)), - }], - output_type: None, - args: vec![], - options: vec![], - })), - }) - } else { - Ok(substrait_or_list) - } - } - Expr::ScalarFunction(fun) => { - let mut arguments: Vec = vec![]; - for arg in &fun.args { - arguments.push(FunctionArgument { - arg_type: Some(ArgType::Value(to_substrait_rex( - state, - arg, - schema, - col_ref_offset, - extensions, - )?)), - }); - } + Ok(make_binary_op_scalar_func( + producer, + &l_expr, + &r_expr, + Operator::And, + )) + } +} +pub fn from_column( + col: &Column, + schema: &DFSchemaRef, + col_ref_offset: usize, +) -> Result { + let index = schema.index_of_column(col)?; + substrait_field_ref(index + col_ref_offset) +} - let function_anchor = extensions.register_function(fun.name().to_string()); - Ok(Expression { - rex_type: Some(RexType::ScalarFunction(ScalarFunction { - function_reference: function_anchor, - arguments, - output_type: None, - args: vec![], - options: vec![], - })), - }) - } - Expr::Between(Between { - expr, - negated, - low, - high, - }) => { - if *negated { - // `expr NOT BETWEEN low AND high` can be translated into (expr < low OR high < expr) - let substrait_expr = - to_substrait_rex(state, expr, schema, col_ref_offset, extensions)?; - let substrait_low = - to_substrait_rex(state, low, schema, col_ref_offset, extensions)?; - let substrait_high = - to_substrait_rex(state, high, schema, col_ref_offset, extensions)?; - - let l_expr = make_binary_op_scalar_func( - &substrait_expr, - &substrait_low, - Operator::Lt, - extensions, - ); - let r_expr = make_binary_op_scalar_func( - &substrait_high, - &substrait_expr, - Operator::Lt, - extensions, - ); +pub fn from_binary_expr( + producer: &mut impl SubstraitProducer, + expr: &BinaryExpr, + schema: &DFSchemaRef, + col_ref_offset: usize, +) -> Result { + let BinaryExpr { left, op, right } = expr; + let l = producer.consume_expr(left, schema, col_ref_offset)?; + let r = producer.consume_expr(right, schema, col_ref_offset)?; + Ok(make_binary_op_scalar_func(producer, &l, &r, *op)) +} +pub fn from_case( + producer: &mut impl SubstraitProducer, + case: &Case, + schema: &DFSchemaRef, + col_ref_offset: usize, +) -> Result { + let Case { + expr, + when_then_expr, + else_expr, + } = case; + let mut ifs: Vec = vec![]; + // Parse base + if let Some(e) = expr { + // Base expression exists + ifs.push(IfClause { + r#if: Some(producer.consume_expr(e, schema, col_ref_offset)?), + then: None, + }); + } + // Parse `when`s + for (r#if, then) in when_then_expr { + ifs.push(IfClause { + r#if: Some(producer.consume_expr(r#if, schema, col_ref_offset)?), + then: Some(producer.consume_expr(then, schema, col_ref_offset)?), + }); + } - Ok(make_binary_op_scalar_func( - &l_expr, - &r_expr, - Operator::Or, - extensions, - )) - } else { - // `expr BETWEEN low AND high` can be translated into (low <= expr AND expr <= high) - let substrait_expr = - to_substrait_rex(state, expr, schema, col_ref_offset, extensions)?; - let substrait_low = - to_substrait_rex(state, low, schema, col_ref_offset, extensions)?; - let substrait_high = - to_substrait_rex(state, high, schema, col_ref_offset, extensions)?; - - let l_expr = make_binary_op_scalar_func( - &substrait_low, - &substrait_expr, - Operator::LtEq, - extensions, - ); - let r_expr = make_binary_op_scalar_func( - &substrait_expr, - &substrait_high, - Operator::LtEq, - extensions, - ); + // Parse outer `else` + let r#else: Option> = match else_expr { + Some(e) => Some(Box::new(to_substrait_rex( + producer, + e, + schema, + col_ref_offset, + )?)), + None => None, + }; - Ok(make_binary_op_scalar_func( - &l_expr, - &r_expr, - Operator::And, - extensions, - )) - } - } - Expr::Column(col) => { - let index = schema.index_of_column(col)?; - substrait_field_ref(index + col_ref_offset) - } - Expr::BinaryExpr(BinaryExpr { left, op, right }) => { - let l = to_substrait_rex(state, left, schema, col_ref_offset, extensions)?; - let r = to_substrait_rex(state, right, schema, col_ref_offset, extensions)?; + Ok(Expression { + rex_type: Some(RexType::IfThen(Box::new(IfThen { ifs, r#else }))), + }) +} - Ok(make_binary_op_scalar_func(&l, &r, *op, extensions)) - } - Expr::Case(Case { - expr, - when_then_expr, - else_expr, - }) => { - let mut ifs: Vec = vec![]; - // Parse base - if let Some(e) = expr { - // Base expression exists - ifs.push(IfClause { - r#if: Some(to_substrait_rex( - state, - e, - schema, - col_ref_offset, - extensions, - )?), - then: None, - }); - } - // Parse `when`s - for (r#if, then) in when_then_expr { - ifs.push(IfClause { - r#if: Some(to_substrait_rex( - state, - r#if, - schema, - col_ref_offset, - extensions, - )?), - then: Some(to_substrait_rex( - state, - then, - schema, - col_ref_offset, - extensions, - )?), - }); - } +pub fn from_cast( + producer: &mut impl SubstraitProducer, + cast: &Cast, + schema: &DFSchemaRef, + col_ref_offset: usize, +) -> Result { + let Cast { expr, data_type } = cast; + Ok(Expression { + rex_type: Some(RexType::Cast(Box::new( + substrait::proto::expression::Cast { + r#type: Some(to_substrait_type(data_type, true)?), + input: Some(Box::new(to_substrait_rex( + producer, + expr, + schema, + col_ref_offset, + )?)), + failure_behavior: FailureBehavior::ThrowException.into(), + }, + ))), + }) +} - // Parse outer `else` - let r#else: Option> = match else_expr { - Some(e) => Some(Box::new(to_substrait_rex( - state, - e, +pub fn from_try_cast( + producer: &mut impl SubstraitProducer, + cast: &TryCast, + schema: &DFSchemaRef, + col_ref_offset: usize, +) -> Result { + let TryCast { expr, data_type } = cast; + Ok(Expression { + rex_type: Some(RexType::Cast(Box::new( + substrait::proto::expression::Cast { + r#type: Some(to_substrait_type(data_type, true)?), + input: Some(Box::new(to_substrait_rex( + producer, + expr, schema, col_ref_offset, - extensions, )?)), - None => None, - }; + failure_behavior: FailureBehavior::ReturnNull.into(), + }, + ))), + }) +} - Ok(Expression { - rex_type: Some(RexType::IfThen(Box::new(IfThen { ifs, r#else }))), - }) - } - Expr::Cast(Cast { expr, data_type }) => Ok(Expression { - rex_type: Some(RexType::Cast(Box::new( - substrait::proto::expression::Cast { - r#type: Some(to_substrait_type(data_type, true)?), - input: Some(Box::new(to_substrait_rex( - state, - expr, - schema, - col_ref_offset, - extensions, - )?)), - failure_behavior: FailureBehavior::ThrowException.into(), - }, - ))), - }), - Expr::TryCast(TryCast { expr, data_type }) => Ok(Expression { - rex_type: Some(RexType::Cast(Box::new( - substrait::proto::expression::Cast { - r#type: Some(to_substrait_type(data_type, true)?), - input: Some(Box::new(to_substrait_rex( - state, - expr, - schema, - col_ref_offset, - extensions, - )?)), - failure_behavior: FailureBehavior::ReturnNull.into(), - }, - ))), - }), - Expr::Literal(value) => to_substrait_literal_expr(value, extensions), - Expr::Alias(Alias { expr, .. }) => { - to_substrait_rex(state, expr, schema, col_ref_offset, extensions) - } - Expr::WindowFunction(WindowFunction { - fun, - args, - partition_by, - order_by, - window_frame, - null_treatment: _, - }) => { - // function reference - let function_anchor = extensions.register_function(fun.to_string()); - // arguments - let mut arguments: Vec = vec![]; - for arg in args { - arguments.push(FunctionArgument { - arg_type: Some(ArgType::Value(to_substrait_rex( - state, - arg, - schema, - col_ref_offset, - extensions, - )?)), - }); - } - // partition by expressions - let partition_by = partition_by - .iter() - .map(|e| to_substrait_rex(state, e, schema, col_ref_offset, extensions)) - .collect::>>()?; - // order by expressions - let order_by = order_by - .iter() - .map(|e| substrait_sort_field(state, e, schema, extensions)) - .collect::>>()?; - // window frame - let bounds = to_substrait_bounds(window_frame)?; - let bound_type = to_substrait_bound_type(window_frame)?; - Ok(make_substrait_window_function( - function_anchor, - arguments, - partition_by, - order_by, - bounds, - bound_type, - )) - } - Expr::Like(Like { - negated, - expr, - pattern, - escape_char, - case_insensitive, - }) => make_substrait_like_expr( - state, - *case_insensitive, - *negated, - expr, - pattern, - *escape_char, - schema, - col_ref_offset, - extensions, - ), - Expr::InSubquery(InSubquery { - expr, - subquery, - negated, - }) => { - let substrait_expr = - to_substrait_rex(state, expr, schema, col_ref_offset, extensions)?; - - let subquery_plan = - to_substrait_rel(subquery.subquery.as_ref(), state, extensions)?; - - let substrait_subquery = Expression { - rex_type: Some(RexType::Subquery(Box::new(Subquery { - subquery_type: Some( - substrait::proto::expression::subquery::SubqueryType::InPredicate( - Box::new(InPredicate { - needles: (vec![substrait_expr]), - haystack: Some(subquery_plan), - }), - ), +pub fn from_literal( + producer: &mut impl SubstraitProducer, + value: &ScalarValue, +) -> Result { + to_substrait_literal_expr(producer, value) +} + +pub fn from_alias( + producer: &mut impl SubstraitProducer, + alias: &Alias, + schema: &DFSchemaRef, + col_ref_offset: usize, +) -> Result { + producer.consume_expr(alias.expr.as_ref(), schema, col_ref_offset) +} + +pub fn from_window_function( + producer: &mut impl SubstraitProducer, + window_fn: &WindowFunction, + schema: &DFSchemaRef, + col_ref_offset: usize, +) -> Result { + let WindowFunction { + fun, + args, + partition_by, + order_by, + window_frame, + null_treatment: _, + } = window_fn; + // function reference + let function_anchor = producer.register_function(fun.to_string()); + // arguments + let mut arguments: Vec = vec![]; + for arg in args { + arguments.push(FunctionArgument { + arg_type: Some(ArgType::Value(to_substrait_rex( + producer, + arg, + schema, + col_ref_offset, + )?)), + }); + } + // partition by expressions + let partition_by = partition_by + .iter() + .map(|e| producer.consume_expr(e, schema, col_ref_offset)) + .collect::>>()?; + // order by expressions + let order_by = order_by + .iter() + .map(|e| substrait_sort_field(producer, e, schema)) + .collect::>>()?; + // window frame + let bounds = to_substrait_bounds(window_frame)?; + let bound_type = to_substrait_bound_type(window_frame)?; + Ok(make_substrait_window_function( + function_anchor, + arguments, + partition_by, + order_by, + bounds, + bound_type, + )) +} + +pub fn from_like( + producer: &mut impl SubstraitProducer, + like: &Like, + schema: &DFSchemaRef, + col_ref_offset: usize, +) -> Result { + let Like { + negated, + expr, + pattern, + escape_char, + case_insensitive, + } = like; + make_substrait_like_expr( + producer, + *case_insensitive, + *negated, + expr, + pattern, + *escape_char, + schema, + col_ref_offset, + ) +} + +pub fn from_in_subquery( + producer: &mut impl SubstraitProducer, + subquery: &InSubquery, + schema: &DFSchemaRef, + col_ref_offset: usize, +) -> Result { + let InSubquery { + expr, + subquery, + negated, + } = subquery; + let substrait_expr = producer.consume_expr(expr, schema, col_ref_offset)?; + + let subquery_plan = producer.consume_plan(subquery.subquery.as_ref())?; + + let substrait_subquery = Expression { + rex_type: Some(RexType::Subquery(Box::new( + substrait::proto::expression::Subquery { + subquery_type: Some( + substrait::proto::expression::subquery::SubqueryType::InPredicate( + Box::new(InPredicate { + needles: (vec![substrait_expr]), + haystack: Some(subquery_plan), + }), ), - }))), - }; - if *negated { - let function_anchor = extensions.register_function("not".to_string()); - - Ok(Expression { - rex_type: Some(RexType::ScalarFunction(ScalarFunction { - function_reference: function_anchor, - arguments: vec![FunctionArgument { - arg_type: Some(ArgType::Value(substrait_subquery)), - }], - output_type: None, - args: vec![], - options: vec![], - })), - }) - } else { - Ok(substrait_subquery) - } - } - Expr::Not(arg) => to_substrait_unary_scalar_fn( - state, - "not", - arg, - schema, - col_ref_offset, - extensions, - ), - Expr::IsNull(arg) => to_substrait_unary_scalar_fn( - state, - "is_null", - arg, - schema, - col_ref_offset, - extensions, - ), - Expr::IsNotNull(arg) => to_substrait_unary_scalar_fn( - state, - "is_not_null", - arg, - schema, - col_ref_offset, - extensions, - ), - Expr::IsTrue(arg) => to_substrait_unary_scalar_fn( - state, - "is_true", - arg, - schema, - col_ref_offset, - extensions, - ), - Expr::IsFalse(arg) => to_substrait_unary_scalar_fn( - state, - "is_false", - arg, - schema, - col_ref_offset, - extensions, - ), - Expr::IsUnknown(arg) => to_substrait_unary_scalar_fn( - state, - "is_unknown", - arg, - schema, - col_ref_offset, - extensions, - ), - Expr::IsNotTrue(arg) => to_substrait_unary_scalar_fn( - state, - "is_not_true", - arg, - schema, - col_ref_offset, - extensions, - ), - Expr::IsNotFalse(arg) => to_substrait_unary_scalar_fn( - state, - "is_not_false", - arg, - schema, - col_ref_offset, - extensions, - ), - Expr::IsNotUnknown(arg) => to_substrait_unary_scalar_fn( - state, - "is_not_unknown", - arg, - schema, - col_ref_offset, - extensions, - ), - Expr::Negative(arg) => to_substrait_unary_scalar_fn( - state, - "negate", - arg, - schema, - col_ref_offset, - extensions, - ), - _ => { - not_impl_err!("Unsupported expression: {expr:?}") - } + ), + }, + ))), + }; + if *negated { + let function_anchor = producer.register_function("not".to_string()); + + #[allow(deprecated)] + Ok(Expression { + rex_type: Some(RexType::ScalarFunction(ScalarFunction { + function_reference: function_anchor, + arguments: vec![FunctionArgument { + arg_type: Some(ArgType::Value(substrait_subquery)), + }], + output_type: None, + args: vec![], + options: vec![], + })), + }) + } else { + Ok(substrait_subquery) } } +pub fn from_unary_expr( + producer: &mut impl SubstraitProducer, + expr: &Expr, + schema: &DFSchemaRef, + col_ref_offset: usize, +) -> Result { + let (fn_name, arg) = match expr { + Expr::Not(arg) => ("not", arg), + Expr::IsNull(arg) => ("is_null", arg), + Expr::IsNotNull(arg) => ("is_not_null", arg), + Expr::IsTrue(arg) => ("is_true", arg), + Expr::IsFalse(arg) => ("is_false", arg), + Expr::IsUnknown(arg) => ("is_unknown", arg), + Expr::IsNotTrue(arg) => ("is_not_true", arg), + Expr::IsNotFalse(arg) => ("is_not_false", arg), + Expr::IsNotUnknown(arg) => ("is_not_unknown", arg), + Expr::Negative(arg) => ("negate", arg), + expr => not_impl_err!("Unsupported expression: {expr:?}")?, + }; + to_substrait_unary_scalar_fn(producer, fn_name, arg, schema, col_ref_offset) +} + fn to_substrait_type(dt: &DataType, nullable: bool) -> Result { let nullability = if nullable { r#type::Nullability::Nullable as i32 @@ -1700,7 +2018,6 @@ fn to_substrait_type(dt: &DataType, nullable: bool) -> Result, @@ -1709,6 +2026,7 @@ fn make_substrait_window_function( bounds: (Bound, Bound), bounds_type: BoundsType, ) -> Expression { + #[allow(deprecated)] Expression { rex_type: Some(RexType::WindowFunction(SubstraitWindowFunction { function_reference, @@ -1727,10 +2045,9 @@ fn make_substrait_window_function( } } -#[allow(deprecated)] #[allow(clippy::too_many_arguments)] fn make_substrait_like_expr( - state: &dyn SubstraitPlanningState, + producer: &mut impl SubstraitProducer, ignore_case: bool, negated: bool, expr: &Expr, @@ -1738,18 +2055,18 @@ fn make_substrait_like_expr( escape_char: Option, schema: &DFSchemaRef, col_ref_offset: usize, - extensions: &mut Extensions, ) -> Result { + // let mut extensions = producer.get_extensions(); let function_anchor = if ignore_case { - extensions.register_function("ilike".to_string()) + producer.register_function("ilike".to_string()) } else { - extensions.register_function("like".to_string()) + producer.register_function("like".to_string()) }; - let expr = to_substrait_rex(state, expr, schema, col_ref_offset, extensions)?; - let pattern = to_substrait_rex(state, pattern, schema, col_ref_offset, extensions)?; + let expr = producer.consume_expr(expr, schema, col_ref_offset)?; + let pattern = producer.consume_expr(pattern, schema, col_ref_offset)?; let escape_char = to_substrait_literal_expr( + producer, &ScalarValue::Utf8(escape_char.map(|c| c.to_string())), - extensions, )?; let arguments = vec![ FunctionArgument { @@ -1763,6 +2080,7 @@ fn make_substrait_like_expr( }, ]; + #[allow(deprecated)] let substrait_like = Expression { rex_type: Some(RexType::ScalarFunction(ScalarFunction { function_reference: function_anchor, @@ -1774,8 +2092,9 @@ fn make_substrait_like_expr( }; if negated { - let function_anchor = extensions.register_function("not".to_string()); + let function_anchor = producer.register_function("not".to_string()); + #[allow(deprecated)] Ok(Expression { rex_type: Some(RexType::ScalarFunction(ScalarFunction { function_reference: function_anchor, @@ -1847,8 +2166,8 @@ fn to_substrait_bounds(window_frame: &WindowFrame) -> Result<(Bound, Bound)> { } fn to_substrait_literal( + producer: &mut impl SubstraitProducer, value: &ScalarValue, - extensions: &mut Extensions, ) -> Result { if value.is_null() { return Ok(Literal { @@ -2026,11 +2345,11 @@ fn to_substrait_literal( DECIMAL_128_TYPE_VARIATION_REF, ), ScalarValue::List(l) => ( - convert_array_to_literal_list(l, extensions)?, + convert_array_to_literal_list(producer, l)?, DEFAULT_CONTAINER_TYPE_VARIATION_REF, ), ScalarValue::LargeList(l) => ( - convert_array_to_literal_list(l, extensions)?, + convert_array_to_literal_list(producer, l)?, LARGE_CONTAINER_TYPE_VARIATION_REF, ), ScalarValue::Map(m) => { @@ -2047,16 +2366,16 @@ fn to_substrait_literal( let keys = (0..m.keys().len()) .map(|i| { to_substrait_literal( + producer, &ScalarValue::try_from_array(&m.keys(), i)?, - extensions, ) }) .collect::>>()?; let values = (0..m.values().len()) .map(|i| { to_substrait_literal( + producer, &ScalarValue::try_from_array(&m.values(), i)?, - extensions, ) }) .collect::>>()?; @@ -2082,8 +2401,8 @@ fn to_substrait_literal( .iter() .map(|col| { to_substrait_literal( + producer, &ScalarValue::try_from_array(col, 0)?, - extensions, ) }) .collect::>>()?, @@ -2104,8 +2423,8 @@ fn to_substrait_literal( } fn convert_array_to_literal_list( + producer: &mut impl SubstraitProducer, array: &GenericListArray, - extensions: &mut Extensions, ) -> Result { assert_eq!(array.len(), 1); let nested_array = array.value(0); @@ -2113,8 +2432,8 @@ fn convert_array_to_literal_list( let values = (0..nested_array.len()) .map(|i| { to_substrait_literal( + producer, &ScalarValue::try_from_array(&nested_array, i)?, - extensions, ) }) .collect::>>()?; @@ -2133,10 +2452,10 @@ fn convert_array_to_literal_list( } fn to_substrait_literal_expr( + producer: &mut impl SubstraitProducer, value: &ScalarValue, - extensions: &mut Extensions, ) -> Result { - let literal = to_substrait_literal(value, extensions)?; + let literal = to_substrait_literal(producer, value)?; Ok(Expression { rex_type: Some(RexType::Literal(literal)), }) @@ -2144,16 +2463,14 @@ fn to_substrait_literal_expr( /// Util to generate substrait [RexType::ScalarFunction] with one argument fn to_substrait_unary_scalar_fn( - state: &dyn SubstraitPlanningState, + producer: &mut impl SubstraitProducer, fn_name: &str, arg: &Expr, schema: &DFSchemaRef, col_ref_offset: usize, - extensions: &mut Extensions, ) -> Result { - let function_anchor = extensions.register_function(fn_name.to_string()); - let substrait_expr = - to_substrait_rex(state, arg, schema, col_ref_offset, extensions)?; + let function_anchor = producer.register_function(fn_name.to_string()); + let substrait_expr = producer.consume_expr(arg, schema, col_ref_offset)?; Ok(Expression { rex_type: Some(RexType::ScalarFunction(ScalarFunction { @@ -2194,17 +2511,16 @@ fn try_to_substrait_field_reference( } fn substrait_sort_field( - state: &dyn SubstraitPlanningState, - sort: &Sort, + producer: &mut impl SubstraitProducer, + sort: &SortExpr, schema: &DFSchemaRef, - extensions: &mut Extensions, ) -> Result { - let Sort { + let SortExpr { expr, asc, nulls_first, } = sort; - let e = to_substrait_rex(state, expr, schema, 0, extensions)?; + let e = producer.consume_expr(expr, schema, 0)?; let d = match (asc, nulls_first) { (true, true) => SortDirection::AscNullsFirst, (true, false) => SortDirection::AscNullsLast, @@ -2380,9 +2696,9 @@ mod test { fn round_trip_literal(scalar: ScalarValue) -> Result<()> { println!("Checking round trip of {scalar:?}"); - - let mut extensions = Extensions::default(); - let substrait_literal = to_substrait_literal(&scalar, &mut extensions)?; + let state = SessionContext::default().state(); + let mut producer = DefaultSubstraitProducer::new(&state); + let substrait_literal = to_substrait_literal(&mut producer, &scalar)?; let roundtrip_scalar = from_substrait_literal_without_names(&test_consumer(), &substrait_literal)?; assert_eq!(scalar, roundtrip_scalar); From 1183d457dac14f542aef6042f259c2cd7dc3b4b8 Mon Sep 17 00:00:00 2001 From: Victor Barua Date: Thu, 26 Dec 2024 19:35:27 -0800 Subject: [PATCH 02/13] refactor(substrait): simplify col_ref_offset handling in producer --- .../substrait/src/logical_plan/producer.rs | 346 ++++++++---------- 1 file changed, 150 insertions(+), 196 deletions(-) diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index b905ee6c35df..04e2c67ad716 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -110,6 +110,53 @@ pub trait SubstraitProducer: Send + Sync + Sized { fn register_function(&mut self, signature: String) -> u32; + /// Offset for calculating Substrait field reference indices. + /// + /// See [SubstraitProducer::set_col_offset] for more details. + fn get_col_offset(&self) -> usize; + + /// Sets the offset for calculating Substrait field reference indices. + /// + /// This is only needed when handling relations with more than 1 input relation, and only when + /// converting column references contained in fields of the relation, which are indexed relative + /// to the *output schema* of the relation. + /// + /// An example of this is the [JoinRel] which takes 2 inputs and has 2 fields: + /// * [expression](JoinRel:: expression) + /// * [post_join_filter](JoinRel:: post_join_filter) + /// + /// These two fields may contain column references. DataFusion references these columns by name, + /// whereas Substrait uses indices. + /// + /// The output schema of the JoinRel consists of all columns of the left input, followed by all + /// columns of the right input. If `JoinRel::left` has `m` columns, and `JoinRel::right` has `n` + /// columns, then + /// * The `left` input will have column indices from `0` to `m-1` + /// * The `right` input will have column indices from `0` to `n-1` + /// * The [JoinRel] output has column indices from `0` to `m + n - 1` + /// + /// Putting it all together. Given a query + /// ```sql + /// SELECT * + /// FROM t1 + /// JOIN t2 + /// ON t1.l1 = t2.r1; + /// ``` + /// with the following tables: + /// * t1: (l0, l1, l2) + /// * t2: (r0, r1) + /// + /// the output schema is + /// ```text + /// 0, 1, 2, 3, 4 : Output Schema Index + /// (l0, l1, l2, r0, r1) + /// ``` + /// and as such the join condition becomes + /// ```col_ref(1) = col_ref(3 + 1)``` + /// + /// This function can be used to set the offset used when computing the ... + fn set_col_offset(&mut self, offset: usize); + // Logical Plans fn consume_plan(&mut self, plan: &LogicalPlan) -> Result> { to_substrait_rel(self, plan) @@ -176,31 +223,24 @@ pub trait SubstraitProducer: Send + Sync + Sized { } // Expressions - fn consume_expr( - &mut self, - expr: &Expr, - schema: &DFSchemaRef, - col_ref_offset: usize, - ) -> Result { - to_substrait_rex(self, expr, schema, col_ref_offset) + fn consume_expr(&mut self, expr: &Expr, schema: &DFSchemaRef) -> Result { + to_substrait_rex(self, expr, schema) } fn consume_alias( &mut self, alias: &Alias, schema: &DFSchemaRef, - col_ref_offset: usize, ) -> Result { - from_alias(self, alias, schema, col_ref_offset) + from_alias(self, alias, schema) } fn consume_column( &mut self, column: &Column, schema: &DFSchemaRef, - col_ref_offset: usize, ) -> Result { - from_column(column, schema, col_ref_offset) + from_column(self, column, schema) } fn consume_literal(&mut self, value: &ScalarValue) -> Result { @@ -211,18 +251,12 @@ pub trait SubstraitProducer: Send + Sync + Sized { &mut self, expr: &BinaryExpr, schema: &DFSchemaRef, - col_ref_offset: usize, ) -> Result { - from_binary_expr(self, expr, schema, col_ref_offset) + from_binary_expr(self, expr, schema) } - fn consume_like( - &mut self, - like: &Like, - schema: &DFSchemaRef, - col_ref_offset: usize, - ) -> Result { - from_like(self, like, schema, col_ref_offset) + fn consume_like(&mut self, like: &Like, schema: &DFSchemaRef) -> Result { + from_like(self, like, schema) } /// Handles: Not, IsNotNull, IsNull, IsTrue, IsFalse, IsUnknown, IsNotTrue, IsNotFalse, IsNotUnknown, Negative @@ -230,54 +264,40 @@ pub trait SubstraitProducer: Send + Sync + Sized { &mut self, expr: &Expr, schema: &DFSchemaRef, - col_ref_offset: usize, ) -> Result { - from_unary_expr(self, expr, schema, col_ref_offset) + from_unary_expr(self, expr, schema) } fn consume_between( &mut self, between: &Between, schema: &DFSchemaRef, - col_ref_offset: usize, ) -> Result { - from_between(self, between, schema, col_ref_offset) + from_between(self, between, schema) } - fn consume_case( - &mut self, - case: &Case, - schema: &DFSchemaRef, - col_ref_offset: usize, - ) -> Result { - from_case(self, case, schema, col_ref_offset) + fn consume_case(&mut self, case: &Case, schema: &DFSchemaRef) -> Result { + from_case(self, case, schema) } - fn consume_cast( - &mut self, - cast: &Cast, - schema: &DFSchemaRef, - col_ref_offset: usize, - ) -> Result { - from_cast(self, cast, schema, col_ref_offset) + fn consume_cast(&mut self, cast: &Cast, schema: &DFSchemaRef) -> Result { + from_cast(self, cast, schema) } fn consume_try_cast( &mut self, cast: &TryCast, schema: &DFSchemaRef, - col_ref_offset: usize, ) -> Result { - from_try_cast(self, cast, schema, col_ref_offset) + from_try_cast(self, cast, schema) } fn consume_scalar_function( &mut self, scalar_fn: &expr::ScalarFunction, schema: &DFSchemaRef, - col_ref_offset: usize, ) -> Result { - from_scalar_function(self, scalar_fn, schema, col_ref_offset) + from_scalar_function(self, scalar_fn, schema) } fn consume_agg_function( @@ -292,33 +312,31 @@ pub trait SubstraitProducer: Send + Sync + Sized { &mut self, window_fn: &WindowFunction, schema: &DFSchemaRef, - col_ref_offset: usize, ) -> Result { - from_window_function(self, window_fn, schema, col_ref_offset) + from_window_function(self, window_fn, schema) } fn consume_in_list( &mut self, in_list: &InList, schema: &DFSchemaRef, - col_ref_offset: usize, ) -> Result { - from_in_list(self, in_list, schema, col_ref_offset) + from_in_list(self, in_list, schema) } fn consume_in_subquery( &mut self, in_subquery: &InSubquery, schema: &DFSchemaRef, - col_ref_offset: usize, ) -> Result { - from_in_subquery(self, in_subquery, schema, col_ref_offset) + from_in_subquery(self, in_subquery, schema) } } struct DefaultSubstraitProducer<'a> { extensions: Extensions, state: &'a SessionState, + col_offset: usize, } impl<'a> DefaultSubstraitProducer<'a> { @@ -326,6 +344,7 @@ impl<'a> DefaultSubstraitProducer<'a> { DefaultSubstraitProducer { extensions: Extensions::default(), state, + col_offset: 0, } } } @@ -339,6 +358,14 @@ impl SubstraitProducer for DefaultSubstraitProducer<'_> { self.extensions.register_function(fn_name) } + fn get_col_offset(&self) -> usize { + self.col_offset + } + + fn set_col_offset(&mut self, offset: usize) { + self.col_offset = offset + } + fn consume_extension(&mut self, plan: &Extension) -> Result> { let extension_bytes = self .state @@ -427,8 +454,7 @@ pub fn to_substrait_extended_expr( let substrait_exprs = exprs .iter() .map(|(expr, field)| { - let substrait_expr = - producer.consume_expr(expr, schema, /*col_ref_offset=*/ 0)?; + let substrait_expr = producer.consume_expr(expr, schema)?; let mut output_names = Vec::new(); flatten_names(field, false, &mut output_names)?; Ok(ExpressionReference { @@ -584,7 +610,7 @@ pub fn from_projection( let expressions = p .expr .iter() - .map(|e| producer.consume_expr(e, p.input.schema(), 0)) + .map(|e| producer.consume_expr(e, p.input.schema())) .collect::>>()?; let emit_kind = create_project_remapping( @@ -612,8 +638,7 @@ pub fn from_filter( filter: &Filter, ) -> Result> { let input = producer.consume_plan(filter.input.as_ref())?; - let filter_expr = - producer.consume_expr(&filter.predicate, filter.input.schema(), 0)?; + let filter_expr = producer.consume_expr(&filter.predicate, filter.input.schema())?; Ok(Box::new(Rel { rel_type: Some(RelType::Filter(Box::new(FilterRel { common: None, @@ -633,14 +658,14 @@ pub fn from_limit( let offset_mode = limit .skip .as_ref() - .map(|expr| producer.consume_expr(expr.as_ref(), &empty_schema, 0)) + .map(|expr| producer.consume_expr(expr.as_ref(), &empty_schema)) .transpose()? .map(Box::new) .map(fetch_rel::OffsetMode::OffsetExpr); let count_mode = limit .fetch .as_ref() - .map(|expr| producer.consume_expr(expr.as_ref(), &empty_schema, 0)) + .map(|expr| producer.consume_expr(expr.as_ref(), &empty_schema)) .transpose()? .map(Box::new) .map(fetch_rel::CountMode::CountExpr); @@ -770,7 +795,6 @@ pub fn from_join(producer: &mut impl SubstraitProducer, join: &Join) -> Result None, }; @@ -865,7 +889,7 @@ pub fn from_window( // process and add each window function expression for expr in &window.window_expr { - expressions.push(producer.consume_expr(expr, window.input.schema(), 0)?); + expressions.push(producer.consume_expr(expr, window.input.schema())?); } let emit_kind = @@ -1004,19 +1028,22 @@ fn to_substrait_join_expr( ) -> Result> { // Only support AND conjunction for each binary expression in join conditions let mut exprs: Vec = vec![]; + + // store current column offset + let current_offset = producer.get_col_offset(); for (left, right) in join_conditions { - // Parse left - let l = producer.consume_expr(left, left_schema, 0)?; - // Parse right - let r = to_substrait_rex( - producer, - right, - right_schema, - left_schema.fields().len(), // offset to return the correct index - )?; + // column references to the left input start at 0 in the JoinRel schema + producer.set_col_offset(0); + let l = producer.consume_expr(left, left_schema)?; + // column references to the right input start after all fields in the left schema + producer.set_col_offset(left_schema.fields().len()); + let r = to_substrait_rex(producer, right, right_schema)?; // AND with existing expression exprs.push(make_binary_op_scalar_func(producer, &l, &r, eq_op)); } + // restore column offset + producer.set_col_offset(current_offset); + let join_expr: Option = exprs.into_iter().reduce(|acc: Expression, e: Expression| { make_binary_op_scalar_func(producer, &acc, &e, Operator::And) @@ -1085,7 +1112,7 @@ pub fn parse_flat_grouping_exprs( let mut grouping_expressions = vec![]; for e in exprs { - let rex = producer.consume_expr(e, schema, 0)?; + let rex = producer.consume_expr(e, schema)?; grouping_expressions.push(rex.clone()); ref_group_exprs.push(rex); expression_references.push((ref_group_exprs.len() - 1) as u32); @@ -1180,7 +1207,7 @@ pub fn from_aggregate_function( let mut arguments: Vec = vec![]; for arg in args { arguments.push(FunctionArgument { - arg_type: Some(ArgType::Value(producer.consume_expr(arg, schema, 0)?)), + arg_type: Some(ArgType::Value(producer.consume_expr(arg, schema)?)), }); } let function_anchor = producer.register_function(func.name().to_string()); @@ -1200,7 +1227,7 @@ pub fn from_aggregate_function( options: vec![], }), filter: match filter { - Some(f) => Some(producer.consume_expr(f, schema, 0)?), + Some(f) => Some(producer.consume_expr(f, schema)?), None => None, }, }) @@ -1237,7 +1264,7 @@ fn to_substrait_sort_field( (false, false) => SortDirection::DescNullsLast, }; Ok(SortField { - expr: Some(producer.consume_expr(&sort.expr, schema, 0)?), + expr: Some(producer.consume_expr(&sort.expr, schema)?), sort_kind: Some(SortKind::Direction(sort_kind.into())), }) } @@ -1273,71 +1300,43 @@ pub fn make_binary_op_scalar_func( /// /// # Arguments /// -/// * `expr` - DataFusion expression to be parse into a Substrait expression -/// * `schema` - DataFusion input schema for looking up field qualifiers -/// * `col_ref_offset` - Offset for calculating Substrait field reference indices. -/// This should only be set by caller with more than one input relations i.e. Join. -/// Substrait expects one set of indices when joining two relations. -/// Let's say `left` and `right` have `m` and `n` columns, respectively. The `right` -/// relation will have column indices from `0` to `n-1`, however, Substrait will expect -/// the `right` indices to be offset by the `left`. This means Substrait will expect to -/// evaluate the join condition expression on indices [0 .. n-1, n .. n+m-1]. For example: -/// ```SELECT * -/// FROM t1 -/// JOIN t2 -/// ON t1.c1 = t2.c0;``` -/// where t1 consists of columns [c0, c1, c2], and t2 = columns [c0, c1] -/// the join condition should become -/// `col_ref(1) = col_ref(3 + 0)` -/// , where `3` is the number of `left` columns (`col_ref_offset`) and `0` is the index -/// of the join key column from `right` -/// * `extensions` - Substrait extension info. Contains registered function information +/// * `expr` - DataFusion expression to convert into a Substrait expression +/// * `schema` - DataFusion input schema for looking up columns pub fn to_substrait_rex( producer: &mut impl SubstraitProducer, expr: &Expr, schema: &DFSchemaRef, - col_ref_offset: usize, ) -> Result { match expr { - Expr::Alias(expr) => producer.consume_alias(expr, schema, col_ref_offset), - Expr::Column(expr) => producer.consume_column(expr, schema, col_ref_offset), + Expr::Alias(expr) => producer.consume_alias(expr, schema), + Expr::Column(expr) => producer.consume_column(expr, schema), Expr::Literal(expr) => producer.consume_literal(expr), - Expr::BinaryExpr(expr) => { - producer.consume_binary_expr(expr, schema, col_ref_offset) - } - Expr::Like(expr) => producer.consume_like(expr, schema, col_ref_offset), + Expr::BinaryExpr(expr) => producer.consume_binary_expr(expr, schema), + Expr::Like(expr) => producer.consume_like(expr, schema), Expr::SimilarTo(_) => not_impl_err!("SimilarTo is not supported"), - Expr::Not(_) => producer.consume_unary_expr(expr, schema, col_ref_offset), - Expr::IsNotNull(_) => producer.consume_unary_expr(expr, schema, col_ref_offset), - Expr::IsNull(_) => producer.consume_unary_expr(expr, schema, col_ref_offset), - Expr::IsTrue(_) => producer.consume_unary_expr(expr, schema, col_ref_offset), - Expr::IsFalse(_) => producer.consume_unary_expr(expr, schema, col_ref_offset), - Expr::IsUnknown(_) => producer.consume_unary_expr(expr, schema, col_ref_offset), - Expr::IsNotTrue(_) => producer.consume_unary_expr(expr, schema, col_ref_offset), - Expr::IsNotFalse(_) => producer.consume_unary_expr(expr, schema, col_ref_offset), - Expr::IsNotUnknown(_) => { - producer.consume_unary_expr(expr, schema, col_ref_offset) - } - Expr::Negative(_) => producer.consume_unary_expr(expr, schema, col_ref_offset), - Expr::Between(expr) => producer.consume_between(expr, schema, col_ref_offset), - Expr::Case(expr) => producer.consume_case(expr, schema, col_ref_offset), - Expr::Cast(expr) => producer.consume_cast(expr, schema, col_ref_offset), - Expr::TryCast(expr) => producer.consume_try_cast(expr, schema, col_ref_offset), - Expr::ScalarFunction(expr) => { - producer.consume_scalar_function(expr, schema, col_ref_offset) - } + Expr::Not(_) => producer.consume_unary_expr(expr, schema), + Expr::IsNotNull(_) => producer.consume_unary_expr(expr, schema), + Expr::IsNull(_) => producer.consume_unary_expr(expr, schema), + Expr::IsTrue(_) => producer.consume_unary_expr(expr, schema), + Expr::IsFalse(_) => producer.consume_unary_expr(expr, schema), + Expr::IsUnknown(_) => producer.consume_unary_expr(expr, schema), + Expr::IsNotTrue(_) => producer.consume_unary_expr(expr, schema), + Expr::IsNotFalse(_) => producer.consume_unary_expr(expr, schema), + Expr::IsNotUnknown(_) => producer.consume_unary_expr(expr, schema), + Expr::Negative(_) => producer.consume_unary_expr(expr, schema), + Expr::Between(expr) => producer.consume_between(expr, schema), + Expr::Case(expr) => producer.consume_case(expr, schema), + Expr::Cast(expr) => producer.consume_cast(expr, schema), + Expr::TryCast(expr) => producer.consume_try_cast(expr, schema), + Expr::ScalarFunction(expr) => producer.consume_scalar_function(expr, schema), Expr::AggregateFunction(_) => { internal_err!( "AggregateFunction should only be encountered as part of a LogicalPlan::Aggregate" ) } - Expr::WindowFunction(expr) => { - producer.consume_window_function(expr, schema, col_ref_offset) - } - Expr::InList(expr) => producer.consume_in_list(expr, schema, col_ref_offset), - Expr::InSubquery(expr) => { - producer.consume_in_subquery(expr, schema, col_ref_offset) - } + Expr::WindowFunction(expr) => producer.consume_window_function(expr, schema), + Expr::InList(expr) => producer.consume_in_list(expr, schema), + Expr::InSubquery(expr) => producer.consume_in_subquery(expr, schema), _ => not_impl_err!("Cannot convert {expr:?} to Substrait"), } } @@ -1346,7 +1345,6 @@ pub fn from_in_list( producer: &mut impl SubstraitProducer, in_list: &InList, schema: &DFSchemaRef, - col_ref_offset: usize, ) -> Result { let InList { expr, @@ -1355,9 +1353,9 @@ pub fn from_in_list( } = in_list; let substrait_list = list .iter() - .map(|x| producer.consume_expr(x, schema, col_ref_offset)) + .map(|x| producer.consume_expr(x, schema)) .collect::>>()?; - let substrait_expr = producer.consume_expr(expr, schema, col_ref_offset)?; + let substrait_expr = producer.consume_expr(expr, schema)?; let substrait_or_list = Expression { rex_type: Some(RexType::SingularOrList(Box::new(SingularOrList { @@ -1390,17 +1388,11 @@ pub fn from_scalar_function( producer: &mut impl SubstraitProducer, fun: &expr::ScalarFunction, schema: &DFSchemaRef, - col_ref_offset: usize, ) -> Result { let mut arguments: Vec = vec![]; for arg in &fun.args { arguments.push(FunctionArgument { - arg_type: Some(ArgType::Value(to_substrait_rex( - producer, - arg, - schema, - col_ref_offset, - )?)), + arg_type: Some(ArgType::Value(to_substrait_rex(producer, arg, schema)?)), }); } @@ -1421,7 +1413,6 @@ pub fn from_between( producer: &mut impl SubstraitProducer, between: &Between, schema: &DFSchemaRef, - col_ref_offset: usize, ) -> Result { let Between { expr, @@ -1431,12 +1422,9 @@ pub fn from_between( } = between; if *negated { // `expr NOT BETWEEN low AND high` can be translated into (expr < low OR high < expr) - let substrait_expr = - producer.consume_expr(expr.as_ref(), schema, col_ref_offset)?; - let substrait_low = - producer.consume_expr(low.as_ref(), schema, col_ref_offset)?; - let substrait_high = - producer.consume_expr(high.as_ref(), schema, col_ref_offset)?; + let substrait_expr = producer.consume_expr(expr.as_ref(), schema)?; + let substrait_low = producer.consume_expr(low.as_ref(), schema)?; + let substrait_high = producer.consume_expr(high.as_ref(), schema)?; let l_expr = make_binary_op_scalar_func( producer, @@ -1459,12 +1447,9 @@ pub fn from_between( )) } else { // `expr BETWEEN low AND high` can be translated into (low <= expr AND expr <= high) - let substrait_expr = - producer.consume_expr(expr.as_ref(), schema, col_ref_offset)?; - let substrait_low = - producer.consume_expr(low.as_ref(), schema, col_ref_offset)?; - let substrait_high = - producer.consume_expr(high.as_ref(), schema, col_ref_offset)?; + let substrait_expr = producer.consume_expr(expr.as_ref(), schema)?; + let substrait_low = producer.consume_expr(low.as_ref(), schema)?; + let substrait_high = producer.consume_expr(high.as_ref(), schema)?; let l_expr = make_binary_op_scalar_func( producer, @@ -1488,30 +1473,29 @@ pub fn from_between( } } pub fn from_column( + producer: &impl SubstraitProducer, col: &Column, schema: &DFSchemaRef, - col_ref_offset: usize, ) -> Result { let index = schema.index_of_column(col)?; - substrait_field_ref(index + col_ref_offset) + let col_offset = producer.get_col_offset(); + substrait_field_ref(index + col_offset) } pub fn from_binary_expr( producer: &mut impl SubstraitProducer, expr: &BinaryExpr, schema: &DFSchemaRef, - col_ref_offset: usize, ) -> Result { let BinaryExpr { left, op, right } = expr; - let l = producer.consume_expr(left, schema, col_ref_offset)?; - let r = producer.consume_expr(right, schema, col_ref_offset)?; + let l = producer.consume_expr(left, schema)?; + let r = producer.consume_expr(right, schema)?; Ok(make_binary_op_scalar_func(producer, &l, &r, *op)) } pub fn from_case( producer: &mut impl SubstraitProducer, case: &Case, schema: &DFSchemaRef, - col_ref_offset: usize, ) -> Result { let Case { expr, @@ -1523,26 +1507,21 @@ pub fn from_case( if let Some(e) = expr { // Base expression exists ifs.push(IfClause { - r#if: Some(producer.consume_expr(e, schema, col_ref_offset)?), + r#if: Some(producer.consume_expr(e, schema)?), then: None, }); } // Parse `when`s for (r#if, then) in when_then_expr { ifs.push(IfClause { - r#if: Some(producer.consume_expr(r#if, schema, col_ref_offset)?), - then: Some(producer.consume_expr(then, schema, col_ref_offset)?), + r#if: Some(producer.consume_expr(r#if, schema)?), + then: Some(producer.consume_expr(then, schema)?), }); } // Parse outer `else` let r#else: Option> = match else_expr { - Some(e) => Some(Box::new(to_substrait_rex( - producer, - e, - schema, - col_ref_offset, - )?)), + Some(e) => Some(Box::new(to_substrait_rex(producer, e, schema)?)), None => None, }; @@ -1555,19 +1534,13 @@ pub fn from_cast( producer: &mut impl SubstraitProducer, cast: &Cast, schema: &DFSchemaRef, - col_ref_offset: usize, ) -> Result { let Cast { expr, data_type } = cast; Ok(Expression { rex_type: Some(RexType::Cast(Box::new( substrait::proto::expression::Cast { r#type: Some(to_substrait_type(data_type, true)?), - input: Some(Box::new(to_substrait_rex( - producer, - expr, - schema, - col_ref_offset, - )?)), + input: Some(Box::new(to_substrait_rex(producer, expr, schema)?)), failure_behavior: FailureBehavior::ThrowException.into(), }, ))), @@ -1578,19 +1551,13 @@ pub fn from_try_cast( producer: &mut impl SubstraitProducer, cast: &TryCast, schema: &DFSchemaRef, - col_ref_offset: usize, ) -> Result { let TryCast { expr, data_type } = cast; Ok(Expression { rex_type: Some(RexType::Cast(Box::new( substrait::proto::expression::Cast { r#type: Some(to_substrait_type(data_type, true)?), - input: Some(Box::new(to_substrait_rex( - producer, - expr, - schema, - col_ref_offset, - )?)), + input: Some(Box::new(to_substrait_rex(producer, expr, schema)?)), failure_behavior: FailureBehavior::ReturnNull.into(), }, ))), @@ -1608,16 +1575,14 @@ pub fn from_alias( producer: &mut impl SubstraitProducer, alias: &Alias, schema: &DFSchemaRef, - col_ref_offset: usize, ) -> Result { - producer.consume_expr(alias.expr.as_ref(), schema, col_ref_offset) + producer.consume_expr(alias.expr.as_ref(), schema) } pub fn from_window_function( producer: &mut impl SubstraitProducer, window_fn: &WindowFunction, schema: &DFSchemaRef, - col_ref_offset: usize, ) -> Result { let WindowFunction { fun, @@ -1633,18 +1598,13 @@ pub fn from_window_function( let mut arguments: Vec = vec![]; for arg in args { arguments.push(FunctionArgument { - arg_type: Some(ArgType::Value(to_substrait_rex( - producer, - arg, - schema, - col_ref_offset, - )?)), + arg_type: Some(ArgType::Value(to_substrait_rex(producer, arg, schema)?)), }); } // partition by expressions let partition_by = partition_by .iter() - .map(|e| producer.consume_expr(e, schema, col_ref_offset)) + .map(|e| producer.consume_expr(e, schema)) .collect::>>()?; // order by expressions let order_by = order_by @@ -1668,7 +1628,6 @@ pub fn from_like( producer: &mut impl SubstraitProducer, like: &Like, schema: &DFSchemaRef, - col_ref_offset: usize, ) -> Result { let Like { negated, @@ -1685,7 +1644,6 @@ pub fn from_like( pattern, *escape_char, schema, - col_ref_offset, ) } @@ -1693,14 +1651,13 @@ pub fn from_in_subquery( producer: &mut impl SubstraitProducer, subquery: &InSubquery, schema: &DFSchemaRef, - col_ref_offset: usize, ) -> Result { let InSubquery { expr, subquery, negated, } = subquery; - let substrait_expr = producer.consume_expr(expr, schema, col_ref_offset)?; + let substrait_expr = producer.consume_expr(expr, schema)?; let subquery_plan = producer.consume_plan(subquery.subquery.as_ref())?; @@ -1742,7 +1699,6 @@ pub fn from_unary_expr( producer: &mut impl SubstraitProducer, expr: &Expr, schema: &DFSchemaRef, - col_ref_offset: usize, ) -> Result { let (fn_name, arg) = match expr { Expr::Not(arg) => ("not", arg), @@ -1757,7 +1713,7 @@ pub fn from_unary_expr( Expr::Negative(arg) => ("negate", arg), expr => not_impl_err!("Unsupported expression: {expr:?}")?, }; - to_substrait_unary_scalar_fn(producer, fn_name, arg, schema, col_ref_offset) + to_substrait_unary_scalar_fn(producer, fn_name, arg, schema) } fn to_substrait_type(dt: &DataType, nullable: bool) -> Result { @@ -2054,7 +2010,6 @@ fn make_substrait_like_expr( pattern: &Expr, escape_char: Option, schema: &DFSchemaRef, - col_ref_offset: usize, ) -> Result { // let mut extensions = producer.get_extensions(); let function_anchor = if ignore_case { @@ -2062,8 +2017,8 @@ fn make_substrait_like_expr( } else { producer.register_function("like".to_string()) }; - let expr = producer.consume_expr(expr, schema, col_ref_offset)?; - let pattern = producer.consume_expr(pattern, schema, col_ref_offset)?; + let expr = producer.consume_expr(expr, schema)?; + let pattern = producer.consume_expr(pattern, schema)?; let escape_char = to_substrait_literal_expr( producer, &ScalarValue::Utf8(escape_char.map(|c| c.to_string())), @@ -2467,10 +2422,9 @@ fn to_substrait_unary_scalar_fn( fn_name: &str, arg: &Expr, schema: &DFSchemaRef, - col_ref_offset: usize, ) -> Result { let function_anchor = producer.register_function(fn_name.to_string()); - let substrait_expr = producer.consume_expr(arg, schema, col_ref_offset)?; + let substrait_expr = producer.consume_expr(arg, schema)?; Ok(Expression { rex_type: Some(RexType::ScalarFunction(ScalarFunction { @@ -2520,7 +2474,7 @@ fn substrait_sort_field( asc, nulls_first, } = sort; - let e = producer.consume_expr(expr, schema, 0)?; + let e = producer.consume_expr(expr, schema)?; let d = match (asc, nulls_first) { (true, true) => SortDirection::AscNullsFirst, (true, false) => SortDirection::AscNullsLast, From 8a14160454c7bd3a063edc95303a8f9827c7d5b1 Mon Sep 17 00:00:00 2001 From: Victor Barua Date: Fri, 27 Dec 2024 09:39:33 -0800 Subject: [PATCH 03/13] refactor(substrait): remove column offset tracking from producer --- .../substrait/src/logical_plan/producer.rs | 97 ++----------------- .../tests/cases/roundtrip_logical_plan.rs | 15 +++ 2 files changed, 25 insertions(+), 87 deletions(-) diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index 04e2c67ad716..0f4a062e2b1b 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -110,53 +110,6 @@ pub trait SubstraitProducer: Send + Sync + Sized { fn register_function(&mut self, signature: String) -> u32; - /// Offset for calculating Substrait field reference indices. - /// - /// See [SubstraitProducer::set_col_offset] for more details. - fn get_col_offset(&self) -> usize; - - /// Sets the offset for calculating Substrait field reference indices. - /// - /// This is only needed when handling relations with more than 1 input relation, and only when - /// converting column references contained in fields of the relation, which are indexed relative - /// to the *output schema* of the relation. - /// - /// An example of this is the [JoinRel] which takes 2 inputs and has 2 fields: - /// * [expression](JoinRel:: expression) - /// * [post_join_filter](JoinRel:: post_join_filter) - /// - /// These two fields may contain column references. DataFusion references these columns by name, - /// whereas Substrait uses indices. - /// - /// The output schema of the JoinRel consists of all columns of the left input, followed by all - /// columns of the right input. If `JoinRel::left` has `m` columns, and `JoinRel::right` has `n` - /// columns, then - /// * The `left` input will have column indices from `0` to `m-1` - /// * The `right` input will have column indices from `0` to `n-1` - /// * The [JoinRel] output has column indices from `0` to `m + n - 1` - /// - /// Putting it all together. Given a query - /// ```sql - /// SELECT * - /// FROM t1 - /// JOIN t2 - /// ON t1.l1 = t2.r1; - /// ``` - /// with the following tables: - /// * t1: (l0, l1, l2) - /// * t2: (r0, r1) - /// - /// the output schema is - /// ```text - /// 0, 1, 2, 3, 4 : Output Schema Index - /// (l0, l1, l2, r0, r1) - /// ``` - /// and as such the join condition becomes - /// ```col_ref(1) = col_ref(3 + 1)``` - /// - /// This function can be used to set the offset used when computing the ... - fn set_col_offset(&mut self, offset: usize); - // Logical Plans fn consume_plan(&mut self, plan: &LogicalPlan) -> Result> { to_substrait_rel(self, plan) @@ -336,7 +289,6 @@ pub trait SubstraitProducer: Send + Sync + Sized { struct DefaultSubstraitProducer<'a> { extensions: Extensions, state: &'a SessionState, - col_offset: usize, } impl<'a> DefaultSubstraitProducer<'a> { @@ -344,7 +296,6 @@ impl<'a> DefaultSubstraitProducer<'a> { DefaultSubstraitProducer { extensions: Extensions::default(), state, - col_offset: 0, } } } @@ -358,14 +309,6 @@ impl SubstraitProducer for DefaultSubstraitProducer<'_> { self.extensions.register_function(fn_name) } - fn get_col_offset(&self) -> usize { - self.col_offset - } - - fn set_col_offset(&mut self, offset: usize) { - self.col_offset = offset - } - fn consume_extension(&mut self, plan: &Extension) -> Result> { let extension_bytes = self .state @@ -788,14 +731,11 @@ pub fn from_join(producer: &mut impl SubstraitProducer, join: &Join) -> Result {} JoinConstraint::Using => return not_impl_err!("join constraint: `using`"), } - // parse filter if exists - let in_join_schema = join.left.schema().join(join.right.schema())?; + let in_join_schema = Arc::new(join.left.schema().join(join.right.schema())?); + + // convert filter if present let join_filter = match &join.filter { - Some(filter) => Some(to_substrait_rex( - producer, - filter, - &Arc::new(in_join_schema), - )?), + Some(filter) => Some(to_substrait_rex(producer, filter, &in_join_schema)?), None => None, }; @@ -806,13 +746,7 @@ pub fn from_join(producer: &mut impl SubstraitProducer, join: &Join) -> Result, eq_op: Operator, - left_schema: &DFSchemaRef, - right_schema: &DFSchemaRef, + join_schema: &DFSchemaRef, ) -> Result> { // Only support AND conjunction for each binary expression in join conditions let mut exprs: Vec = vec![]; - - // store current column offset - let current_offset = producer.get_col_offset(); for (left, right) in join_conditions { - // column references to the left input start at 0 in the JoinRel schema - producer.set_col_offset(0); - let l = producer.consume_expr(left, left_schema)?; - // column references to the right input start after all fields in the left schema - producer.set_col_offset(left_schema.fields().len()); - let r = to_substrait_rex(producer, right, right_schema)?; + let l = producer.consume_expr(left, join_schema)?; + let r = producer.consume_expr(right, join_schema)?; // AND with existing expression exprs.push(make_binary_op_scalar_func(producer, &l, &r, eq_op)); } - // restore column offset - producer.set_col_offset(current_offset); let join_expr: Option = exprs.into_iter().reduce(|acc: Expression, e: Expression| { @@ -1473,13 +1397,12 @@ pub fn from_between( } } pub fn from_column( - producer: &impl SubstraitProducer, + _producer: &impl SubstraitProducer, col: &Column, schema: &DFSchemaRef, ) -> Result { let index = schema.index_of_column(col)?; - let col_offset = producer.get_col_offset(); - substrait_field_ref(index + col_offset) + substrait_field_ref(index) } pub fn from_binary_expr( diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs index 383fe44be507..772bf2e7ad8e 100644 --- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs @@ -571,6 +571,21 @@ async fn roundtrip_self_implicit_cross_join() -> Result<()> { roundtrip("SELECT left.a left_a, left.b, right.a right_a, right.c FROM data AS left, data AS right").await } +#[tokio::test] +async fn self_join_introduces_aliases() -> Result<()> { + assert_expected_plan( + "SELECT d1.b, d2.c FROM data d1 JOIN data d2 ON d1.b = d2.b", + "Projection: left.b, right.c\ + \n Inner Join: left.b = right.b\ + \n SubqueryAlias: left\ + \n TableScan: data projection=[b]\ + \n SubqueryAlias: right\ + \n TableScan: data projection=[b, c]", + false, + ) + .await +} + #[tokio::test] async fn roundtrip_arithmetic_ops() -> Result<()> { roundtrip("SELECT a - a FROM data").await?; From 4a464cb5bea7d52712ca36ff48985cf81a145d71 Mon Sep 17 00:00:00 2001 From: Victor Barua Date: Fri, 27 Dec 2024 11:01:59 -0800 Subject: [PATCH 04/13] docs(substrait): document SubstraitProducer --- .../substrait/src/logical_plan/consumer.rs | 3 + .../substrait/src/logical_plan/producer.rs | 105 ++++++++++++++++-- .../tests/cases/roundtrip_logical_plan.rs | 2 +- 3 files changed, 97 insertions(+), 13 deletions(-) diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index 515553152659..d82237298436 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -114,6 +114,9 @@ use substrait::proto::{ /// This trait is used to consume Substrait plans, converting them into DataFusion Logical Plans. /// It can be implemented by users to allow for custom handling of relations, expressions, etc. /// +/// Combined with the [crate::logical_plan::producer::SubstraitProducer] this allows for fully +/// customizable Substrait serde. +/// /// # Example Usage /// /// ``` diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index 0f4a062e2b1b..cf879ad65629 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -105,12 +105,89 @@ use substrait::{ version, }; +/// This trait is used to produce Substrait plans, converting them from DataFusion Logical Plans. +/// It can be implemented by users to allow for custom handling of relations, expressions, etc. +/// +/// Combined with the [crate::logical_plan::consumer::SubstraitConsumer] this allows for fully +/// customizable Substrait serde. +/// +/// # Example Usage +/// +/// ``` +/// # use std::sync::Arc; +/// # use substrait::proto::{Expression, Rel}; +/// # use substrait::proto::rel::RelType; +/// # use datafusion::common::DFSchemaRef; +/// # use datafusion::error::Result; +/// # use datafusion::execution::SessionState; +/// # use datafusion::logical_expr::{Between, Extension, Projection}; +/// # use datafusion_substrait::extensions::Extensions; +/// # use datafusion_substrait::logical_plan::producer::{from_projection, SubstraitProducer}; +/// +/// struct CustomSubstraitProducer { +/// extensions: Extensions, +/// state: Arc, +/// } +/// +/// impl SubstraitProducer for CustomSubstraitProducer { +/// +/// fn register_function(&mut self, signature: String) -> u32 { +/// self.extensions.register_function(signature) +/// } +/// +/// fn get_extensions(self) -> Extensions { +/// self.extensions +/// } +/// +/// // You can set additional metadata on the Rels you produce +/// fn consume_projection(&mut self, plan: &Projection) -> Result> { +/// let mut rel = from_projection(self, plan)?; +/// match rel.rel_type { +/// Some(RelType::Project(mut project)) => { +/// let mut project = project.clone(); +/// // set common metadata or advanced extension +/// project.common = None; +/// project.advanced_extension = None; +/// Ok(Box::new(Rel { +/// rel_type: Some(RelType::Project(project)), +/// })) +/// } +/// rel_type => Ok(Box::new(Rel { rel_type })), +/// } +/// } +/// +/// // You can tweak how you convert expressions for your target system +/// fn consume_between(&mut self, between: &Between, schema: &DFSchemaRef) -> Result { +/// // add your own encoding for Between +/// todo!() +/// } +/// +/// // You can fully control how you convert UserDefinedLogicalNodes into Substrait +/// fn consume_extension(&mut self, _plan: &Extension) -> Result> { +/// // implement your own serializer into Substrait +/// todo!() +/// } +/// } +/// ``` pub trait SubstraitProducer: Send + Sync + Sized { + /// Within a Substrait plan, functions are referenced using function anchors that are stored at + /// the top level of the [Plan] within + /// [ExtensionFunction](substrait::proto::extensions::simple_extension_declaration::ExtensionFunction) + /// messages. + /// + /// When given a function signature, this method should return the existing anchor for it if + /// there is one. Otherwise, it should generate a new anchor. + fn register_function(&mut self, signature: String) -> u32; + + /// Consume the producer to generate the [Extensions] for the Substrait plan based on the + /// functions that have been registered fn get_extensions(self) -> Extensions; - fn register_function(&mut self, signature: String) -> u32; + // Logical Plan Methods + // There is one method per LogicalPlan to allow for easy overriding of producer behaviour. + // These methods have default implementations calling the common handler code, to allow for users + // to re-use common handling logic. - // Logical Plans fn consume_plan(&mut self, plan: &LogicalPlan) -> Result> { to_substrait_rel(self, plan) } @@ -175,7 +252,11 @@ pub trait SubstraitProducer: Send + Sync + Sized { substrait_err!("Specify handling for LogicalPlan::Extension by implementing the SubstraitProducer trait") } - // Expressions + // Expression Methods + // There is one method per DataFusion Expr to allow for easy overriding of producer behaviour + // These methods have default implementations calling the common handler code, to allow for users + // to re-use common handling logic. + fn consume_expr(&mut self, expr: &Expr, schema: &DFSchemaRef) -> Result { to_substrait_rex(self, expr, schema) } @@ -212,7 +293,7 @@ pub trait SubstraitProducer: Send + Sync + Sized { from_like(self, like, schema) } - /// Handles: Not, IsNotNull, IsNull, IsTrue, IsFalse, IsUnknown, IsNotTrue, IsNotFalse, IsNotUnknown, Negative + /// For handling Not, IsNotNull, IsNull, IsTrue, IsFalse, IsUnknown, IsNotTrue, IsNotFalse, IsNotUnknown, Negative fn consume_unary_expr( &mut self, expr: &Expr, @@ -253,7 +334,7 @@ pub trait SubstraitProducer: Send + Sync + Sized { from_scalar_function(self, scalar_fn, schema) } - fn consume_agg_function( + fn consume_aggregate_function( &mut self, agg_fn: &expr::AggregateFunction, schema: &DFSchemaRef, @@ -301,14 +382,14 @@ impl<'a> DefaultSubstraitProducer<'a> { } impl SubstraitProducer for DefaultSubstraitProducer<'_> { - fn get_extensions(self) -> Extensions { - self.extensions - } - fn register_function(&mut self, fn_name: String) -> u32 { self.extensions.register_function(fn_name) } + fn get_extensions(self) -> Extensions { + self.extensions + } + fn consume_extension(&mut self, plan: &Extension) -> Result> { let extension_bytes = self .state @@ -1164,7 +1245,7 @@ pub fn to_substrait_agg_measure( ) -> Result { match expr { Expr::AggregateFunction(agg_fn) => from_aggregate_function(producer, agg_fn, schema), - Expr::Alias(Alias{expr,..}) => { + Expr::Alias(Alias { expr, .. }) => { to_substrait_agg_measure(producer, expr, schema) } _ => internal_err!( @@ -2631,7 +2712,7 @@ mod test { ], false, ) - .into(), + .into(), false, ))?; @@ -2640,7 +2721,7 @@ mod test { Field::new("c0", DataType::Int32, true), Field::new("c1", DataType::Utf8, true), ] - .into(), + .into(), ))?; round_trip_type(DataType::Interval(IntervalUnit::YearMonth))?; diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs index 772bf2e7ad8e..7045729493b1 100644 --- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs @@ -583,7 +583,7 @@ async fn self_join_introduces_aliases() -> Result<()> { \n TableScan: data projection=[b, c]", false, ) - .await + .await } #[tokio::test] From 22bcc94837951a477845fadf97d7f92969c47f3e Mon Sep 17 00:00:00 2001 From: Victor Barua Date: Fri, 27 Dec 2024 12:21:08 -0800 Subject: [PATCH 05/13] refactor: minor cleanup --- datafusion/substrait/src/logical_plan/producer.rs | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index cf879ad65629..ae020aeb8100 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -274,7 +274,7 @@ pub trait SubstraitProducer: Send + Sync + Sized { column: &Column, schema: &DFSchemaRef, ) -> Result { - from_column(self, column, schema) + from_column(column, schema) } fn consume_literal(&mut self, value: &ScalarValue) -> Result { @@ -1304,7 +1304,7 @@ pub fn make_binary_op_scalar_func( /// Convert DataFusion Expr to Substrait Rex /// /// # Arguments -/// +/// * `producer` - SubstraitProducer implementation which the handles the actual conversion /// * `expr` - DataFusion expression to convert into a Substrait expression /// * `schema` - DataFusion input schema for looking up columns pub fn to_substrait_rex( @@ -1478,7 +1478,6 @@ pub fn from_between( } } pub fn from_column( - _producer: &impl SubstraitProducer, col: &Column, schema: &DFSchemaRef, ) -> Result { @@ -2015,7 +2014,6 @@ fn make_substrait_like_expr( escape_char: Option, schema: &DFSchemaRef, ) -> Result { - // let mut extensions = producer.get_extensions(); let function_anchor = if ignore_case { producer.register_function("ilike".to_string()) } else { From 6839d333270c71bd222fa9a04466d5d9412914f4 Mon Sep 17 00:00:00 2001 From: Victor Barua Date: Fri, 27 Dec 2024 12:26:45 -0800 Subject: [PATCH 06/13] feature: remove unused SubstraitPlanningState BREAKING CHANGE: SubstraitPlanningState is no longer available --- datafusion/substrait/src/logical_plan/mod.rs | 1 - .../substrait/src/logical_plan/state.rs | 63 ------------------- 2 files changed, 64 deletions(-) delete mode 100644 datafusion/substrait/src/logical_plan/state.rs diff --git a/datafusion/substrait/src/logical_plan/mod.rs b/datafusion/substrait/src/logical_plan/mod.rs index 9e2fa9fa49de..6f8b8e493f52 100644 --- a/datafusion/substrait/src/logical_plan/mod.rs +++ b/datafusion/substrait/src/logical_plan/mod.rs @@ -17,4 +17,3 @@ pub mod consumer; pub mod producer; -pub mod state; diff --git a/datafusion/substrait/src/logical_plan/state.rs b/datafusion/substrait/src/logical_plan/state.rs deleted file mode 100644 index 0bd749c1105d..000000000000 --- a/datafusion/substrait/src/logical_plan/state.rs +++ /dev/null @@ -1,63 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use std::sync::Arc; - -use async_trait::async_trait; -use datafusion::{ - catalog::TableProvider, - error::{DataFusionError, Result}, - execution::{registry::SerializerRegistry, FunctionRegistry, SessionState}, - sql::TableReference, -}; - -/// This trait provides the context needed to transform a substrait plan into a -/// [`datafusion::logical_expr::LogicalPlan`] (via [`super::consumer::from_substrait_plan`]) -/// and back again into a substrait plan (via [`super::producer::to_substrait_plan`]). -/// -/// The context is declared as a trait to decouple the substrait plan encoder / -/// decoder from the [`SessionState`], potentially allowing users to define -/// their own slimmer context just for serializing and deserializing substrait. -/// -/// [`SessionState`] implements this trait. -#[async_trait] -pub trait SubstraitPlanningState: Sync + Send + FunctionRegistry { - /// Return [SerializerRegistry] for extensions - fn serializer_registry(&self) -> &Arc; - - async fn table( - &self, - reference: &TableReference, - ) -> Result>>; -} - -#[async_trait] -impl SubstraitPlanningState for SessionState { - fn serializer_registry(&self) -> &Arc { - self.serializer_registry() - } - - async fn table( - &self, - reference: &TableReference, - ) -> Result>, DataFusionError> { - let table = reference.table().to_string(); - let schema = self.schema_for_ref(reference.clone())?; - let table_provider = schema.table(&table).await?; - Ok(table_provider) - } -} From 1f547dc990f8ea082d2f931e77dd20aa3b80c5c4 Mon Sep 17 00:00:00 2001 From: Victor Barua Date: Fri, 27 Dec 2024 12:29:34 -0800 Subject: [PATCH 07/13] refactor: cargo fmt --- datafusion/substrait/src/logical_plan/producer.rs | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index ae020aeb8100..fb46cb66bb99 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -1477,10 +1477,7 @@ pub fn from_between( )) } } -pub fn from_column( - col: &Column, - schema: &DFSchemaRef, -) -> Result { +pub fn from_column(col: &Column, schema: &DFSchemaRef) -> Result { let index = schema.index_of_column(col)?; substrait_field_ref(index) } @@ -2710,7 +2707,7 @@ mod test { ], false, ) - .into(), + .into(), false, ))?; @@ -2719,7 +2716,7 @@ mod test { Field::new("c0", DataType::Int32, true), Field::new("c1", DataType::Utf8, true), ] - .into(), + .into(), ))?; round_trip_type(DataType::Interval(IntervalUnit::YearMonth))?; From d3400f489a44815816341d4667d61dd61d62d7af Mon Sep 17 00:00:00 2001 From: Victor Barua Date: Mon, 30 Dec 2024 09:24:55 -0800 Subject: [PATCH 08/13] refactor(substrait): consume_ -> handle_ --- .../substrait/src/logical_plan/producer.rs | 242 +++++++++--------- 1 file changed, 121 insertions(+), 121 deletions(-) diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index fb46cb66bb99..f62ef030f2e1 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -140,7 +140,7 @@ use substrait::{ /// } /// /// // You can set additional metadata on the Rels you produce -/// fn consume_projection(&mut self, plan: &Projection) -> Result> { +/// fn handle_projection(&mut self, plan: &Projection) -> Result> { /// let mut rel = from_projection(self, plan)?; /// match rel.rel_type { /// Some(RelType::Project(mut project)) => { @@ -157,13 +157,13 @@ use substrait::{ /// } /// /// // You can tweak how you convert expressions for your target system -/// fn consume_between(&mut self, between: &Between, schema: &DFSchemaRef) -> Result { +/// fn handle_between(&mut self, between: &Between, schema: &DFSchemaRef) -> Result { /// // add your own encoding for Between /// todo!() /// } /// /// // You can fully control how you convert UserDefinedLogicalNodes into Substrait -/// fn consume_extension(&mut self, _plan: &Extension) -> Result> { +/// fn handle_extension(&mut self, _plan: &Extension) -> Result> { /// // implement your own serializer into Substrait /// todo!() /// } @@ -188,67 +188,67 @@ pub trait SubstraitProducer: Send + Sync + Sized { // These methods have default implementations calling the common handler code, to allow for users // to re-use common handling logic. - fn consume_plan(&mut self, plan: &LogicalPlan) -> Result> { + fn handle_plan(&mut self, plan: &LogicalPlan) -> Result> { to_substrait_rel(self, plan) } - fn consume_projection(&mut self, plan: &Projection) -> Result> { + fn handle_projection(&mut self, plan: &Projection) -> Result> { from_projection(self, plan) } - fn consume_filter(&mut self, plan: &Filter) -> Result> { + fn handle_filter(&mut self, plan: &Filter) -> Result> { from_filter(self, plan) } - fn consume_window(&mut self, plan: &Window) -> Result> { + fn handle_window(&mut self, plan: &Window) -> Result> { from_window(self, plan) } - fn consume_aggregate(&mut self, plan: &Aggregate) -> Result> { + fn handle_aggregate(&mut self, plan: &Aggregate) -> Result> { from_aggregate(self, plan) } - fn consume_sort(&mut self, plan: &Sort) -> Result> { + fn handle_sort(&mut self, plan: &Sort) -> Result> { from_sort(self, plan) } - fn consume_join(&mut self, plan: &Join) -> Result> { + fn handle_join(&mut self, plan: &Join) -> Result> { from_join(self, plan) } - fn consume_repartition(&mut self, plan: &Repartition) -> Result> { + fn handle_repartition(&mut self, plan: &Repartition) -> Result> { from_repartition(self, plan) } - fn consume_union(&mut self, plan: &Union) -> Result> { + fn handle_union(&mut self, plan: &Union) -> Result> { from_union(self, plan) } - fn consume_table_scan(&mut self, plan: &TableScan) -> Result> { + fn handle_table_scan(&mut self, plan: &TableScan) -> Result> { from_table_scan(self, plan) } - fn consume_empty_relation(&mut self, plan: &EmptyRelation) -> Result> { + fn handle_empty_relation(&mut self, plan: &EmptyRelation) -> Result> { from_empty_relation(plan) } - fn consume_subquery_alias(&mut self, plan: &SubqueryAlias) -> Result> { + fn handle_subquery_alias(&mut self, plan: &SubqueryAlias) -> Result> { from_subquery_alias(self, plan) } - fn consume_limit(&mut self, plan: &Limit) -> Result> { + fn handle_limit(&mut self, plan: &Limit) -> Result> { from_limit(self, plan) } - fn consume_values(&mut self, plan: &Values) -> Result> { + fn handle_values(&mut self, plan: &Values) -> Result> { from_values(self, plan) } - fn consume_distinct(&mut self, plan: &Distinct) -> Result> { + fn handle_distinct(&mut self, plan: &Distinct) -> Result> { from_distinct(self, plan) } - fn consume_extension(&mut self, _plan: &Extension) -> Result> { + fn handle_extension(&mut self, _plan: &Extension) -> Result> { substrait_err!("Specify handling for LogicalPlan::Extension by implementing the SubstraitProducer trait") } @@ -257,11 +257,11 @@ pub trait SubstraitProducer: Send + Sync + Sized { // These methods have default implementations calling the common handler code, to allow for users // to re-use common handling logic. - fn consume_expr(&mut self, expr: &Expr, schema: &DFSchemaRef) -> Result { + fn handle_expr(&mut self, expr: &Expr, schema: &DFSchemaRef) -> Result { to_substrait_rex(self, expr, schema) } - fn consume_alias( + fn handle_alias( &mut self, alias: &Alias, schema: &DFSchemaRef, @@ -269,7 +269,7 @@ pub trait SubstraitProducer: Send + Sync + Sized { from_alias(self, alias, schema) } - fn consume_column( + fn handle_column( &mut self, column: &Column, schema: &DFSchemaRef, @@ -277,11 +277,11 @@ pub trait SubstraitProducer: Send + Sync + Sized { from_column(column, schema) } - fn consume_literal(&mut self, value: &ScalarValue) -> Result { + fn handle_literal(&mut self, value: &ScalarValue) -> Result { from_literal(self, value) } - fn consume_binary_expr( + fn handle_binary_expr( &mut self, expr: &BinaryExpr, schema: &DFSchemaRef, @@ -289,12 +289,12 @@ pub trait SubstraitProducer: Send + Sync + Sized { from_binary_expr(self, expr, schema) } - fn consume_like(&mut self, like: &Like, schema: &DFSchemaRef) -> Result { + fn handle_like(&mut self, like: &Like, schema: &DFSchemaRef) -> Result { from_like(self, like, schema) } /// For handling Not, IsNotNull, IsNull, IsTrue, IsFalse, IsUnknown, IsNotTrue, IsNotFalse, IsNotUnknown, Negative - fn consume_unary_expr( + fn handle_unary_expr( &mut self, expr: &Expr, schema: &DFSchemaRef, @@ -302,7 +302,7 @@ pub trait SubstraitProducer: Send + Sync + Sized { from_unary_expr(self, expr, schema) } - fn consume_between( + fn handle_between( &mut self, between: &Between, schema: &DFSchemaRef, @@ -310,15 +310,15 @@ pub trait SubstraitProducer: Send + Sync + Sized { from_between(self, between, schema) } - fn consume_case(&mut self, case: &Case, schema: &DFSchemaRef) -> Result { + fn handle_case(&mut self, case: &Case, schema: &DFSchemaRef) -> Result { from_case(self, case, schema) } - fn consume_cast(&mut self, cast: &Cast, schema: &DFSchemaRef) -> Result { + fn handle_cast(&mut self, cast: &Cast, schema: &DFSchemaRef) -> Result { from_cast(self, cast, schema) } - fn consume_try_cast( + fn handle_try_cast( &mut self, cast: &TryCast, schema: &DFSchemaRef, @@ -326,7 +326,7 @@ pub trait SubstraitProducer: Send + Sync + Sized { from_try_cast(self, cast, schema) } - fn consume_scalar_function( + fn handle_scalar_function( &mut self, scalar_fn: &expr::ScalarFunction, schema: &DFSchemaRef, @@ -334,7 +334,7 @@ pub trait SubstraitProducer: Send + Sync + Sized { from_scalar_function(self, scalar_fn, schema) } - fn consume_aggregate_function( + fn handle_aggregate_function( &mut self, agg_fn: &expr::AggregateFunction, schema: &DFSchemaRef, @@ -342,7 +342,7 @@ pub trait SubstraitProducer: Send + Sync + Sized { from_aggregate_function(self, agg_fn, schema) } - fn consume_window_function( + fn handle_window_function( &mut self, window_fn: &WindowFunction, schema: &DFSchemaRef, @@ -350,7 +350,7 @@ pub trait SubstraitProducer: Send + Sync + Sized { from_window_function(self, window_fn, schema) } - fn consume_in_list( + fn handle_in_list( &mut self, in_list: &InList, schema: &DFSchemaRef, @@ -358,7 +358,7 @@ pub trait SubstraitProducer: Send + Sync + Sized { from_in_list(self, in_list, schema) } - fn consume_in_subquery( + fn handle_in_subquery( &mut self, in_subquery: &InSubquery, schema: &DFSchemaRef, @@ -390,7 +390,7 @@ impl SubstraitProducer for DefaultSubstraitProducer<'_> { self.extensions } - fn consume_extension(&mut self, plan: &Extension) -> Result> { + fn handle_extension(&mut self, plan: &Extension) -> Result> { let extension_bytes = self .state .serializer_registry() @@ -403,7 +403,7 @@ impl SubstraitProducer for DefaultSubstraitProducer<'_> { .node .inputs() .into_iter() - .map(|plan| self.consume_plan(plan)) + .map(|plan| self.handle_plan(plan)) .collect::>>()?; let rel_type = match inputs_rel.len() { 0 => RelType::ExtensionLeaf(ExtensionLeafRel { @@ -440,7 +440,7 @@ pub fn to_substrait_plan(plan: &LogicalPlan, state: &SessionState) -> Result Result> { match plan { - LogicalPlan::Projection(plan) => producer.consume_projection(plan), - LogicalPlan::Filter(plan) => producer.consume_filter(plan), - LogicalPlan::Window(plan) => producer.consume_window(plan), - LogicalPlan::Aggregate(plan) => producer.consume_aggregate(plan), - LogicalPlan::Sort(plan) => producer.consume_sort(plan), - LogicalPlan::Join(plan) => producer.consume_join(plan), - LogicalPlan::Repartition(plan) => producer.consume_repartition(plan), - LogicalPlan::Union(plan) => producer.consume_union(plan), - LogicalPlan::TableScan(plan) => producer.consume_table_scan(plan), - LogicalPlan::EmptyRelation(plan) => producer.consume_empty_relation(plan), - LogicalPlan::SubqueryAlias(plan) => producer.consume_subquery_alias(plan), - LogicalPlan::Limit(plan) => producer.consume_limit(plan), - LogicalPlan::Values(plan) => producer.consume_values(plan), - LogicalPlan::Distinct(plan) => producer.consume_distinct(plan), - LogicalPlan::Extension(plan) => producer.consume_extension(plan), + LogicalPlan::Projection(plan) => producer.handle_projection(plan), + LogicalPlan::Filter(plan) => producer.handle_filter(plan), + LogicalPlan::Window(plan) => producer.handle_window(plan), + LogicalPlan::Aggregate(plan) => producer.handle_aggregate(plan), + LogicalPlan::Sort(plan) => producer.handle_sort(plan), + LogicalPlan::Join(plan) => producer.handle_join(plan), + LogicalPlan::Repartition(plan) => producer.handle_repartition(plan), + LogicalPlan::Union(plan) => producer.handle_union(plan), + LogicalPlan::TableScan(plan) => producer.handle_table_scan(plan), + LogicalPlan::EmptyRelation(plan) => producer.handle_empty_relation(plan), + LogicalPlan::SubqueryAlias(plan) => producer.handle_subquery_alias(plan), + LogicalPlan::Limit(plan) => producer.handle_limit(plan), + LogicalPlan::Values(plan) => producer.handle_values(plan), + LogicalPlan::Distinct(plan) => producer.handle_distinct(plan), + LogicalPlan::Extension(plan) => producer.handle_extension(plan), _ => not_impl_err!("Unsupported plan type: {plan:?}")?, } } @@ -634,7 +634,7 @@ pub fn from_projection( let expressions = p .expr .iter() - .map(|e| producer.consume_expr(e, p.input.schema())) + .map(|e| producer.handle_expr(e, p.input.schema())) .collect::>>()?; let emit_kind = create_project_remapping( @@ -650,7 +650,7 @@ pub fn from_projection( Ok(Box::new(Rel { rel_type: Some(RelType::Project(Box::new(ProjectRel { common: Some(common), - input: Some(producer.consume_plan(p.input.as_ref())?), + input: Some(producer.handle_plan(p.input.as_ref())?), expressions, advanced_extension: None, }))), @@ -661,8 +661,8 @@ pub fn from_filter( producer: &mut impl SubstraitProducer, filter: &Filter, ) -> Result> { - let input = producer.consume_plan(filter.input.as_ref())?; - let filter_expr = producer.consume_expr(&filter.predicate, filter.input.schema())?; + let input = producer.handle_plan(filter.input.as_ref())?; + let filter_expr = producer.handle_expr(&filter.predicate, filter.input.schema())?; Ok(Box::new(Rel { rel_type: Some(RelType::Filter(Box::new(FilterRel { common: None, @@ -677,19 +677,19 @@ pub fn from_limit( producer: &mut impl SubstraitProducer, limit: &Limit, ) -> Result> { - let input = producer.consume_plan(limit.input.as_ref())?; + let input = producer.handle_plan(limit.input.as_ref())?; let empty_schema = Arc::new(DFSchema::empty()); let offset_mode = limit .skip .as_ref() - .map(|expr| producer.consume_expr(expr.as_ref(), &empty_schema)) + .map(|expr| producer.handle_expr(expr.as_ref(), &empty_schema)) .transpose()? .map(Box::new) .map(fetch_rel::OffsetMode::OffsetExpr); let count_mode = limit .fetch .as_ref() - .map(|expr| producer.consume_expr(expr.as_ref(), &empty_schema)) + .map(|expr| producer.handle_expr(expr.as_ref(), &empty_schema)) .transpose()? .map(Box::new) .map(fetch_rel::CountMode::CountExpr); @@ -711,7 +711,7 @@ pub fn from_sort(producer: &mut impl SubstraitProducer, sort: &Sort) -> Result>>()?; - let input = producer.consume_plan(input.as_ref())?; + let input = producer.handle_plan(input.as_ref())?; let sort_rel = Box::new(Rel { rel_type: Some(RelType::Sort(Box::new(SortRel { @@ -750,7 +750,7 @@ pub fn from_aggregate( producer: &mut impl SubstraitProducer, agg: &Aggregate, ) -> Result> { - let input = producer.consume_plan(agg.input.as_ref())?; + let input = producer.handle_plan(agg.input.as_ref())?; let (grouping_expressions, groupings) = to_substrait_groupings(producer, &agg.group_expr, agg.input.schema())?; let measures = agg @@ -778,7 +778,7 @@ pub fn from_distinct( match distinct { Distinct::All(plan) => { // Use Substrait's AggregateRel with empty measures to represent `select distinct` - let input = producer.consume_plan(plan.as_ref())?; + let input = producer.handle_plan(plan.as_ref())?; // Get grouping keys from the input relation's number of output fields let grouping = (0..plan.schema().fields().len()) .map(substrait_field_ref) @@ -804,8 +804,8 @@ pub fn from_distinct( } pub fn from_join(producer: &mut impl SubstraitProducer, join: &Join) -> Result> { - let left = producer.consume_plan(join.left.as_ref())?; - let right = producer.consume_plan(join.right.as_ref())?; + let left = producer.handle_plan(join.left.as_ref())?; + let right = producer.handle_plan(join.right.as_ref())?; let join_type = to_substrait_jointype(join.join_type); // we only support basic joins so return an error for anything not yet supported match join.join_constraint { @@ -866,7 +866,7 @@ pub fn from_subquery_alias( ) -> Result> { // Do nothing if encounters SubqueryAlias // since there is no corresponding relation type in Substrait - producer.consume_plan(alias.input.as_ref()) + producer.handle_plan(alias.input.as_ref()) } pub fn from_union( @@ -876,7 +876,7 @@ pub fn from_union( let input_rels = union .inputs .iter() - .map(|input| producer.consume_plan(input.as_ref())) + .map(|input| producer.handle_plan(input.as_ref())) .collect::>>()? .into_iter() .map(|ptr| *ptr) @@ -895,7 +895,7 @@ pub fn from_window( producer: &mut impl SubstraitProducer, window: &Window, ) -> Result> { - let input = producer.consume_plan(window.input.as_ref())?; + let input = producer.handle_plan(window.input.as_ref())?; // create a field reference for each input field let mut expressions = (0..window.input.schema().fields().len()) @@ -904,7 +904,7 @@ pub fn from_window( // process and add each window function expression for expr in &window.window_expr { - expressions.push(producer.consume_expr(expr, window.input.schema())?); + expressions.push(producer.handle_expr(expr, window.input.schema())?); } let emit_kind = @@ -930,7 +930,7 @@ pub fn from_repartition( producer: &mut impl SubstraitProducer, repartition: &Repartition, ) -> Result> { - let input = producer.consume_plan(repartition.input.as_ref())?; + let input = producer.handle_plan(repartition.input.as_ref())?; let partition_count = match repartition.partitioning_scheme { Partitioning::RoundRobinBatch(num) => num, Partitioning::Hash(_, num) => num, @@ -1043,8 +1043,8 @@ fn to_substrait_join_expr( // Only support AND conjunction for each binary expression in join conditions let mut exprs: Vec = vec![]; for (left, right) in join_conditions { - let l = producer.consume_expr(left, join_schema)?; - let r = producer.consume_expr(right, join_schema)?; + let l = producer.handle_expr(left, join_schema)?; + let r = producer.handle_expr(right, join_schema)?; // AND with existing expression exprs.push(make_binary_op_scalar_func(producer, &l, &r, eq_op)); } @@ -1117,7 +1117,7 @@ pub fn parse_flat_grouping_exprs( let mut grouping_expressions = vec![]; for e in exprs { - let rex = producer.consume_expr(e, schema)?; + let rex = producer.handle_expr(e, schema)?; grouping_expressions.push(rex.clone()); ref_group_exprs.push(rex); expression_references.push((ref_group_exprs.len() - 1) as u32); @@ -1212,7 +1212,7 @@ pub fn from_aggregate_function( let mut arguments: Vec = vec![]; for arg in args { arguments.push(FunctionArgument { - arg_type: Some(ArgType::Value(producer.consume_expr(arg, schema)?)), + arg_type: Some(ArgType::Value(producer.handle_expr(arg, schema)?)), }); } let function_anchor = producer.register_function(func.name().to_string()); @@ -1232,7 +1232,7 @@ pub fn from_aggregate_function( options: vec![], }), filter: match filter { - Some(f) => Some(producer.consume_expr(f, schema)?), + Some(f) => Some(producer.handle_expr(f, schema)?), None => None, }, }) @@ -1269,7 +1269,7 @@ fn to_substrait_sort_field( (false, false) => SortDirection::DescNullsLast, }; Ok(SortField { - expr: Some(producer.consume_expr(&sort.expr, schema)?), + expr: Some(producer.handle_expr(&sort.expr, schema)?), sort_kind: Some(SortKind::Direction(sort_kind.into())), }) } @@ -1313,35 +1313,35 @@ pub fn to_substrait_rex( schema: &DFSchemaRef, ) -> Result { match expr { - Expr::Alias(expr) => producer.consume_alias(expr, schema), - Expr::Column(expr) => producer.consume_column(expr, schema), - Expr::Literal(expr) => producer.consume_literal(expr), - Expr::BinaryExpr(expr) => producer.consume_binary_expr(expr, schema), - Expr::Like(expr) => producer.consume_like(expr, schema), + Expr::Alias(expr) => producer.handle_alias(expr, schema), + Expr::Column(expr) => producer.handle_column(expr, schema), + Expr::Literal(expr) => producer.handle_literal(expr), + Expr::BinaryExpr(expr) => producer.handle_binary_expr(expr, schema), + Expr::Like(expr) => producer.handle_like(expr, schema), Expr::SimilarTo(_) => not_impl_err!("SimilarTo is not supported"), - Expr::Not(_) => producer.consume_unary_expr(expr, schema), - Expr::IsNotNull(_) => producer.consume_unary_expr(expr, schema), - Expr::IsNull(_) => producer.consume_unary_expr(expr, schema), - Expr::IsTrue(_) => producer.consume_unary_expr(expr, schema), - Expr::IsFalse(_) => producer.consume_unary_expr(expr, schema), - Expr::IsUnknown(_) => producer.consume_unary_expr(expr, schema), - Expr::IsNotTrue(_) => producer.consume_unary_expr(expr, schema), - Expr::IsNotFalse(_) => producer.consume_unary_expr(expr, schema), - Expr::IsNotUnknown(_) => producer.consume_unary_expr(expr, schema), - Expr::Negative(_) => producer.consume_unary_expr(expr, schema), - Expr::Between(expr) => producer.consume_between(expr, schema), - Expr::Case(expr) => producer.consume_case(expr, schema), - Expr::Cast(expr) => producer.consume_cast(expr, schema), - Expr::TryCast(expr) => producer.consume_try_cast(expr, schema), - Expr::ScalarFunction(expr) => producer.consume_scalar_function(expr, schema), + Expr::Not(_) => producer.handle_unary_expr(expr, schema), + Expr::IsNotNull(_) => producer.handle_unary_expr(expr, schema), + Expr::IsNull(_) => producer.handle_unary_expr(expr, schema), + Expr::IsTrue(_) => producer.handle_unary_expr(expr, schema), + Expr::IsFalse(_) => producer.handle_unary_expr(expr, schema), + Expr::IsUnknown(_) => producer.handle_unary_expr(expr, schema), + Expr::IsNotTrue(_) => producer.handle_unary_expr(expr, schema), + Expr::IsNotFalse(_) => producer.handle_unary_expr(expr, schema), + Expr::IsNotUnknown(_) => producer.handle_unary_expr(expr, schema), + Expr::Negative(_) => producer.handle_unary_expr(expr, schema), + Expr::Between(expr) => producer.handle_between(expr, schema), + Expr::Case(expr) => producer.handle_case(expr, schema), + Expr::Cast(expr) => producer.handle_cast(expr, schema), + Expr::TryCast(expr) => producer.handle_try_cast(expr, schema), + Expr::ScalarFunction(expr) => producer.handle_scalar_function(expr, schema), Expr::AggregateFunction(_) => { internal_err!( "AggregateFunction should only be encountered as part of a LogicalPlan::Aggregate" ) } - Expr::WindowFunction(expr) => producer.consume_window_function(expr, schema), - Expr::InList(expr) => producer.consume_in_list(expr, schema), - Expr::InSubquery(expr) => producer.consume_in_subquery(expr, schema), + Expr::WindowFunction(expr) => producer.handle_window_function(expr, schema), + Expr::InList(expr) => producer.handle_in_list(expr, schema), + Expr::InSubquery(expr) => producer.handle_in_subquery(expr, schema), _ => not_impl_err!("Cannot convert {expr:?} to Substrait"), } } @@ -1358,9 +1358,9 @@ pub fn from_in_list( } = in_list; let substrait_list = list .iter() - .map(|x| producer.consume_expr(x, schema)) + .map(|x| producer.handle_expr(x, schema)) .collect::>>()?; - let substrait_expr = producer.consume_expr(expr, schema)?; + let substrait_expr = producer.handle_expr(expr, schema)?; let substrait_or_list = Expression { rex_type: Some(RexType::SingularOrList(Box::new(SingularOrList { @@ -1427,9 +1427,9 @@ pub fn from_between( } = between; if *negated { // `expr NOT BETWEEN low AND high` can be translated into (expr < low OR high < expr) - let substrait_expr = producer.consume_expr(expr.as_ref(), schema)?; - let substrait_low = producer.consume_expr(low.as_ref(), schema)?; - let substrait_high = producer.consume_expr(high.as_ref(), schema)?; + let substrait_expr = producer.handle_expr(expr.as_ref(), schema)?; + let substrait_low = producer.handle_expr(low.as_ref(), schema)?; + let substrait_high = producer.handle_expr(high.as_ref(), schema)?; let l_expr = make_binary_op_scalar_func( producer, @@ -1452,9 +1452,9 @@ pub fn from_between( )) } else { // `expr BETWEEN low AND high` can be translated into (low <= expr AND expr <= high) - let substrait_expr = producer.consume_expr(expr.as_ref(), schema)?; - let substrait_low = producer.consume_expr(low.as_ref(), schema)?; - let substrait_high = producer.consume_expr(high.as_ref(), schema)?; + let substrait_expr = producer.handle_expr(expr.as_ref(), schema)?; + let substrait_low = producer.handle_expr(low.as_ref(), schema)?; + let substrait_high = producer.handle_expr(high.as_ref(), schema)?; let l_expr = make_binary_op_scalar_func( producer, @@ -1488,8 +1488,8 @@ pub fn from_binary_expr( schema: &DFSchemaRef, ) -> Result { let BinaryExpr { left, op, right } = expr; - let l = producer.consume_expr(left, schema)?; - let r = producer.consume_expr(right, schema)?; + let l = producer.handle_expr(left, schema)?; + let r = producer.handle_expr(right, schema)?; Ok(make_binary_op_scalar_func(producer, &l, &r, *op)) } pub fn from_case( @@ -1507,15 +1507,15 @@ pub fn from_case( if let Some(e) = expr { // Base expression exists ifs.push(IfClause { - r#if: Some(producer.consume_expr(e, schema)?), + r#if: Some(producer.handle_expr(e, schema)?), then: None, }); } // Parse `when`s for (r#if, then) in when_then_expr { ifs.push(IfClause { - r#if: Some(producer.consume_expr(r#if, schema)?), - then: Some(producer.consume_expr(then, schema)?), + r#if: Some(producer.handle_expr(r#if, schema)?), + then: Some(producer.handle_expr(then, schema)?), }); } @@ -1576,7 +1576,7 @@ pub fn from_alias( alias: &Alias, schema: &DFSchemaRef, ) -> Result { - producer.consume_expr(alias.expr.as_ref(), schema) + producer.handle_expr(alias.expr.as_ref(), schema) } pub fn from_window_function( @@ -1604,7 +1604,7 @@ pub fn from_window_function( // partition by expressions let partition_by = partition_by .iter() - .map(|e| producer.consume_expr(e, schema)) + .map(|e| producer.handle_expr(e, schema)) .collect::>>()?; // order by expressions let order_by = order_by @@ -1657,9 +1657,9 @@ pub fn from_in_subquery( subquery, negated, } = subquery; - let substrait_expr = producer.consume_expr(expr, schema)?; + let substrait_expr = producer.handle_expr(expr, schema)?; - let subquery_plan = producer.consume_plan(subquery.subquery.as_ref())?; + let subquery_plan = producer.handle_plan(subquery.subquery.as_ref())?; let substrait_subquery = Expression { rex_type: Some(RexType::Subquery(Box::new( @@ -2016,8 +2016,8 @@ fn make_substrait_like_expr( } else { producer.register_function("like".to_string()) }; - let expr = producer.consume_expr(expr, schema)?; - let pattern = producer.consume_expr(pattern, schema)?; + let expr = producer.handle_expr(expr, schema)?; + let pattern = producer.handle_expr(pattern, schema)?; let escape_char = to_substrait_literal_expr( producer, &ScalarValue::Utf8(escape_char.map(|c| c.to_string())), @@ -2423,7 +2423,7 @@ fn to_substrait_unary_scalar_fn( schema: &DFSchemaRef, ) -> Result { let function_anchor = producer.register_function(fn_name.to_string()); - let substrait_expr = producer.consume_expr(arg, schema)?; + let substrait_expr = producer.handle_expr(arg, schema)?; Ok(Expression { rex_type: Some(RexType::ScalarFunction(ScalarFunction { @@ -2473,7 +2473,7 @@ fn substrait_sort_field( asc, nulls_first, } = sort; - let e = producer.consume_expr(expr, schema)?; + let e = producer.handle_expr(expr, schema)?; let d = match (asc, nulls_first) { (true, true) => SortDirection::AscNullsFirst, (true, false) => SortDirection::AscNullsLast, From d962dc3c952c1574a497ec13941182e6fa96af85 Mon Sep 17 00:00:00 2001 From: Victor Barua Date: Mon, 30 Dec 2024 09:58:30 -0800 Subject: [PATCH 09/13] refactor(substrait): expand match blocks --- .../substrait/src/logical_plan/producer.rs | 24 +++++++++++++++---- 1 file changed, 20 insertions(+), 4 deletions(-) diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index f62ef030f2e1..4d6a6b762681 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -516,12 +516,21 @@ pub fn to_substrait_rel( LogicalPlan::Union(plan) => producer.handle_union(plan), LogicalPlan::TableScan(plan) => producer.handle_table_scan(plan), LogicalPlan::EmptyRelation(plan) => producer.handle_empty_relation(plan), + LogicalPlan::Subquery(plan) => not_impl_err!("Unsupported plan type: {plan:?}")?, LogicalPlan::SubqueryAlias(plan) => producer.handle_subquery_alias(plan), LogicalPlan::Limit(plan) => producer.handle_limit(plan), + LogicalPlan::Statement(plan) => not_impl_err!("Unsupported plan type: {plan:?}")?, LogicalPlan::Values(plan) => producer.handle_values(plan), - LogicalPlan::Distinct(plan) => producer.handle_distinct(plan), + LogicalPlan::Explain(plan) => not_impl_err!("Unsupported plan type: {plan:?}")?, + LogicalPlan::Analyze(plan) => not_impl_err!("Unsupported plan type: {plan:?}")?, LogicalPlan::Extension(plan) => producer.handle_extension(plan), - _ => not_impl_err!("Unsupported plan type: {plan:?}")?, + LogicalPlan::Distinct(plan) => producer.handle_distinct(plan), + LogicalPlan::Dml(plan) => not_impl_err!("Unsupported plan type: {plan:?}")?, + LogicalPlan::Ddl(plan) => not_impl_err!("Unsupported plan type: {plan:?}")?, + LogicalPlan::Copy(plan) => not_impl_err!("Unsupported plan type: {plan:?}")?, + LogicalPlan::DescribeTable(plan) => not_impl_err!("Unsupported plan type: {plan:?}")?, + LogicalPlan::Unnest(plan) => not_impl_err!("Unsupported plan type: {plan:?}")?, + LogicalPlan::RecursiveQuery(plan) => not_impl_err!("Unsupported plan type: {plan:?}")?, } } @@ -1315,10 +1324,11 @@ pub fn to_substrait_rex( match expr { Expr::Alias(expr) => producer.handle_alias(expr, schema), Expr::Column(expr) => producer.handle_column(expr, schema), + Expr::ScalarVariable(_, _) => not_impl_err!("Cannot convert {expr:?} to Substrait"), Expr::Literal(expr) => producer.handle_literal(expr), Expr::BinaryExpr(expr) => producer.handle_binary_expr(expr, schema), Expr::Like(expr) => producer.handle_like(expr, schema), - Expr::SimilarTo(_) => not_impl_err!("SimilarTo is not supported"), + Expr::SimilarTo(_) => not_impl_err!("Cannot convert {expr:?} to Substrait"), Expr::Not(_) => producer.handle_unary_expr(expr, schema), Expr::IsNotNull(_) => producer.handle_unary_expr(expr, schema), Expr::IsNull(_) => producer.handle_unary_expr(expr, schema), @@ -1341,8 +1351,14 @@ pub fn to_substrait_rex( } Expr::WindowFunction(expr) => producer.handle_window_function(expr, schema), Expr::InList(expr) => producer.handle_in_list(expr, schema), + Expr::Exists(expr) => not_impl_err!("Cannot convert {expr:?} to Substrait"), Expr::InSubquery(expr) => producer.handle_in_subquery(expr, schema), - _ => not_impl_err!("Cannot convert {expr:?} to Substrait"), + Expr::ScalarSubquery(expr) => not_impl_err!("Cannot convert {expr:?} to Substrait"), + Expr::Wildcard { .. } => not_impl_err!("Cannot convert {expr:?} to Substrait"), + Expr::GroupingSet(expr) => not_impl_err!("Cannot convert {expr:?} to Substrait"), + Expr::Placeholder(expr) => not_impl_err!("Cannot convert {expr:?} to Substrait"), + Expr::OuterReferenceColumn(_, _) => not_impl_err!("Cannot convert {expr:?} to Substrait"), + Expr::Unnest(expr) => not_impl_err!("Cannot convert {expr:?} to Substrait"), } } From aa9e6f377f4cec11ad3c2cb510e810eb5293e898 Mon Sep 17 00:00:00 2001 From: Victor Barua Date: Mon, 30 Dec 2024 10:04:14 -0800 Subject: [PATCH 10/13] refactor: DefaultSubstraitProducer only needs serializer_registry --- datafusion/substrait/src/logical_plan/producer.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index 4d6a6b762681..053a4f7b0af2 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -104,6 +104,7 @@ use substrait::{ }, version, }; +use datafusion::execution::registry::SerializerRegistry; /// This trait is used to produce Substrait plans, converting them from DataFusion Logical Plans. /// It can be implemented by users to allow for custom handling of relations, expressions, etc. @@ -369,14 +370,14 @@ pub trait SubstraitProducer: Send + Sync + Sized { struct DefaultSubstraitProducer<'a> { extensions: Extensions, - state: &'a SessionState, + serializer_registry: &'a dyn SerializerRegistry, } impl<'a> DefaultSubstraitProducer<'a> { pub fn new(state: &'a SessionState) -> Self { DefaultSubstraitProducer { extensions: Extensions::default(), - state, + serializer_registry: state.serializer_registry().as_ref(), } } } @@ -392,8 +393,7 @@ impl SubstraitProducer for DefaultSubstraitProducer<'_> { fn handle_extension(&mut self, plan: &Extension) -> Result> { let extension_bytes = self - .state - .serializer_registry() + .serializer_registry .serialize_logical_plan(plan.node.as_ref())?; let detail = ProtoAny { type_url: plan.node.name().to_string(), From af9c8a5c125615f782e9514aa65d58303281de36 Mon Sep 17 00:00:00 2001 From: Victor Barua Date: Mon, 30 Dec 2024 10:09:14 -0800 Subject: [PATCH 11/13] refactor: remove unnecessary warning suppression --- datafusion/substrait/src/logical_plan/producer.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index 053a4f7b0af2..c72007a16837 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -2017,7 +2017,6 @@ fn make_substrait_window_function( } } -#[allow(clippy::too_many_arguments)] fn make_substrait_like_expr( producer: &mut impl SubstraitProducer, ignore_case: bool, From cf762a29633431dd9742ecf5998331ab6b37c80e Mon Sep 17 00:00:00 2001 From: Victor Barua Date: Mon, 30 Dec 2024 10:11:39 -0800 Subject: [PATCH 12/13] fix(substrait): route expr conversion through handle_expr --- datafusion/substrait/src/logical_plan/producer.rs | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index c72007a16837..3e1bea587508 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -825,7 +825,7 @@ pub fn from_join(producer: &mut impl SubstraitProducer, join: &Join) -> Result Some(to_substrait_rex(producer, filter, &in_join_schema)?), + Some(filter) => Some(producer.handle_expr(filter, &in_join_schema)?), None => None, }; @@ -1413,7 +1413,7 @@ pub fn from_scalar_function( let mut arguments: Vec = vec![]; for arg in &fun.args { arguments.push(FunctionArgument { - arg_type: Some(ArgType::Value(to_substrait_rex(producer, arg, schema)?)), + arg_type: Some(ArgType::Value(producer.handle_expr(arg, schema)?)), }); } @@ -1537,7 +1537,7 @@ pub fn from_case( // Parse outer `else` let r#else: Option> = match else_expr { - Some(e) => Some(Box::new(to_substrait_rex(producer, e, schema)?)), + Some(e) => Some(Box::new(producer.handle_expr(e, schema)?)), None => None, }; @@ -1556,7 +1556,7 @@ pub fn from_cast( rex_type: Some(RexType::Cast(Box::new( substrait::proto::expression::Cast { r#type: Some(to_substrait_type(data_type, true)?), - input: Some(Box::new(to_substrait_rex(producer, expr, schema)?)), + input: Some(Box::new(producer.handle_expr(expr, schema)?)), failure_behavior: FailureBehavior::ThrowException.into(), }, ))), @@ -1573,7 +1573,7 @@ pub fn from_try_cast( rex_type: Some(RexType::Cast(Box::new( substrait::proto::expression::Cast { r#type: Some(to_substrait_type(data_type, true)?), - input: Some(Box::new(to_substrait_rex(producer, expr, schema)?)), + input: Some(Box::new(producer.handle_expr(expr, schema)?)), failure_behavior: FailureBehavior::ReturnNull.into(), }, ))), @@ -1614,7 +1614,7 @@ pub fn from_window_function( let mut arguments: Vec = vec![]; for arg in args { arguments.push(FunctionArgument { - arg_type: Some(ArgType::Value(to_substrait_rex(producer, arg, schema)?)), + arg_type: Some(ArgType::Value(producer.handle_expr(arg, schema)?)), }); } // partition by expressions From 85106f3e9fcdb2fc5ce6966b93112a0b1adf2e48 Mon Sep 17 00:00:00 2001 From: Victor Barua Date: Mon, 30 Dec 2024 10:12:27 -0800 Subject: [PATCH 13/13] cargo fmt --- .../substrait/src/logical_plan/producer.rs | 22 ++++++++++++++----- 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index 3e1bea587508..e501ddf5c698 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -49,6 +49,7 @@ use datafusion::common::{ exec_err, internal_err, not_impl_err, plan_err, substrait_datafusion_err, substrait_err, Column, DFSchema, DFSchemaRef, ToDFSchema, }; +use datafusion::execution::registry::SerializerRegistry; use datafusion::execution::SessionState; use datafusion::logical_expr::expr::{ Alias, BinaryExpr, Case, Cast, GroupingSet, InList, InSubquery, WindowFunction, @@ -104,7 +105,6 @@ use substrait::{ }, version, }; -use datafusion::execution::registry::SerializerRegistry; /// This trait is used to produce Substrait plans, converting them from DataFusion Logical Plans. /// It can be implemented by users to allow for custom handling of relations, expressions, etc. @@ -528,9 +528,13 @@ pub fn to_substrait_rel( LogicalPlan::Dml(plan) => not_impl_err!("Unsupported plan type: {plan:?}")?, LogicalPlan::Ddl(plan) => not_impl_err!("Unsupported plan type: {plan:?}")?, LogicalPlan::Copy(plan) => not_impl_err!("Unsupported plan type: {plan:?}")?, - LogicalPlan::DescribeTable(plan) => not_impl_err!("Unsupported plan type: {plan:?}")?, + LogicalPlan::DescribeTable(plan) => { + not_impl_err!("Unsupported plan type: {plan:?}")? + } LogicalPlan::Unnest(plan) => not_impl_err!("Unsupported plan type: {plan:?}")?, - LogicalPlan::RecursiveQuery(plan) => not_impl_err!("Unsupported plan type: {plan:?}")?, + LogicalPlan::RecursiveQuery(plan) => { + not_impl_err!("Unsupported plan type: {plan:?}")? + } } } @@ -1324,7 +1328,9 @@ pub fn to_substrait_rex( match expr { Expr::Alias(expr) => producer.handle_alias(expr, schema), Expr::Column(expr) => producer.handle_column(expr, schema), - Expr::ScalarVariable(_, _) => not_impl_err!("Cannot convert {expr:?} to Substrait"), + Expr::ScalarVariable(_, _) => { + not_impl_err!("Cannot convert {expr:?} to Substrait") + } Expr::Literal(expr) => producer.handle_literal(expr), Expr::BinaryExpr(expr) => producer.handle_binary_expr(expr, schema), Expr::Like(expr) => producer.handle_like(expr, schema), @@ -1353,11 +1359,15 @@ pub fn to_substrait_rex( Expr::InList(expr) => producer.handle_in_list(expr, schema), Expr::Exists(expr) => not_impl_err!("Cannot convert {expr:?} to Substrait"), Expr::InSubquery(expr) => producer.handle_in_subquery(expr, schema), - Expr::ScalarSubquery(expr) => not_impl_err!("Cannot convert {expr:?} to Substrait"), + Expr::ScalarSubquery(expr) => { + not_impl_err!("Cannot convert {expr:?} to Substrait") + } Expr::Wildcard { .. } => not_impl_err!("Cannot convert {expr:?} to Substrait"), Expr::GroupingSet(expr) => not_impl_err!("Cannot convert {expr:?} to Substrait"), Expr::Placeholder(expr) => not_impl_err!("Cannot convert {expr:?} to Substrait"), - Expr::OuterReferenceColumn(_, _) => not_impl_err!("Cannot convert {expr:?} to Substrait"), + Expr::OuterReferenceColumn(_, _) => { + not_impl_err!("Cannot convert {expr:?} to Substrait") + } Expr::Unnest(expr) => not_impl_err!("Cannot convert {expr:?} to Substrait"), } }