From 5ee524ec70b1f10a078caca62954ce37b2dc3cc6 Mon Sep 17 00:00:00 2001 From: Filippo Rossi Date: Wed, 20 Nov 2024 18:11:18 +0100 Subject: [PATCH] feat(substrait): replace SessionContext with a trait (#13343) * feat(substrait): replace SessionContext with SessionState * feat(substrait): add logical plan context * chore(substrait): add apache header * docs: fix code in docs * docs(substrait): rename and document context * chore(substrait): context -> state * chore: fmt --- .../core/src/execution/session_state.rs | 4 +- datafusion/substrait/Cargo.toml | 1 + datafusion/substrait/src/lib.rs | 4 +- .../substrait/src/logical_plan/consumer.rs | 286 ++++++++++-------- datafusion/substrait/src/logical_plan/mod.rs | 1 + .../substrait/src/logical_plan/producer.rs | 196 ++++++------ .../substrait/src/logical_plan/state.rs | 63 ++++ datafusion/substrait/src/serializer.rs | 2 +- .../tests/cases/consumer_integration.rs | 2 +- .../substrait/tests/cases/emit_kind_tests.rs | 12 +- .../substrait/tests/cases/function_test.rs | 2 +- .../substrait/tests/cases/logical_plans.rs | 6 +- .../tests/cases/roundtrip_logical_plan.rs | 40 +-- datafusion/substrait/tests/cases/serialize.rs | 12 +- .../tests/cases/substrait_validations.rs | 10 +- 15 files changed, 379 insertions(+), 262 deletions(-) create mode 100644 datafusion/substrait/src/logical_plan/state.rs diff --git a/datafusion/core/src/execution/session_state.rs b/datafusion/core/src/execution/session_state.rs index 9fc081dd5329..e99cf8222381 100644 --- a/datafusion/core/src/execution/session_state.rs +++ b/datafusion/core/src/execution/session_state.rs @@ -296,7 +296,9 @@ impl SessionState { .resolve(&catalog.default_catalog, &catalog.default_schema) } - pub(crate) fn schema_for_ref( + /// Retrieve the [`SchemaProvider`] for a specific [`TableReference`], if it + /// esists. + pub fn schema_for_ref( &self, table_ref: impl Into, ) -> datafusion_common::Result> { diff --git a/datafusion/substrait/Cargo.toml b/datafusion/substrait/Cargo.toml index 192fe26d6cef..61cdf3e91e3c 100644 --- a/datafusion/substrait/Cargo.toml +++ b/datafusion/substrait/Cargo.toml @@ -34,6 +34,7 @@ workspace = true [dependencies] arrow-buffer = { workspace = true } async-recursion = "1.0" +async-trait = { workspace = true } chrono = { workspace = true } datafusion = { workspace = true, default-features = true } itertools = { workspace = true } diff --git a/datafusion/substrait/src/lib.rs b/datafusion/substrait/src/lib.rs index a6f7c033f9d0..1389cac75b99 100644 --- a/datafusion/substrait/src/lib.rs +++ b/datafusion/substrait/src/lib.rs @@ -64,10 +64,10 @@ //! let plan = df.into_optimized_plan()?; //! //! // Convert the plan into a substrait (protobuf) Plan -//! let substrait_plan = logical_plan::producer::to_substrait_plan(&plan, &ctx)?; +//! let substrait_plan = logical_plan::producer::to_substrait_plan(&plan, &ctx.state())?; //! //! // Receive a substrait protobuf from somewhere, and turn it into a LogicalPlan -//! let logical_round_trip = logical_plan::consumer::from_substrait_plan(&ctx, &substrait_plan).await?; +//! let logical_round_trip = logical_plan::consumer::from_substrait_plan(&ctx.state(), &substrait_plan).await?; //! let logical_round_trip = ctx.state().optimize(&logical_round_trip)?; //! assert_eq!(format!("{:?}", plan), format!("{:?}", logical_round_trip)); //! # Ok(()) diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index 1cce228527ec..77e9eb81f546 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -26,7 +26,7 @@ use datafusion::common::{ not_impl_err, plan_datafusion_err, substrait_datafusion_err, substrait_err, DFSchema, DFSchemaRef, }; -use datafusion::execution::FunctionRegistry; +use datafusion::datasource::provider_as_source; use datafusion::logical_expr::expr::{Exists, InSubquery, Sort}; use datafusion::logical_expr::{ @@ -56,7 +56,6 @@ use crate::variation_const::{ use datafusion::arrow::array::{new_empty_array, AsArray}; use datafusion::arrow::temporal_conversions::NANOSECONDS; use datafusion::common::scalar::ScalarStructBuilder; -use datafusion::dataframe::DataFrame; use datafusion::logical_expr::builder::project; use datafusion::logical_expr::expr::InList; use datafusion::logical_expr::{ @@ -66,9 +65,7 @@ use datafusion::logical_expr::{ use datafusion::prelude::JoinType; use datafusion::sql::TableReference; use datafusion::{ - error::Result, - logical_expr::utils::split_conjunction, - prelude::{Column, SessionContext}, + error::Result, logical_expr::utils::split_conjunction, prelude::Column, scalar::ScalarValue, }; use std::collections::HashSet; @@ -102,6 +99,8 @@ use substrait::proto::{ }; use substrait::proto::{ExtendedExpression, FunctionArgument, SortField}; +use super::state::SubstraitPlanningState; + // Substrait PrecisionTimestampTz indicates that the timestamp is relative to UTC, which // is the same as the expectation for any non-empty timezone in DF, so any non-empty timezone // results in correct points on the timeline, and we pick UTC as a reasonable default. @@ -203,15 +202,15 @@ fn split_eq_and_noneq_join_predicate_with_nulls_equality( async fn union_rels( rels: &[Rel], - ctx: &SessionContext, + state: &dyn SubstraitPlanningState, extensions: &Extensions, is_all: bool, ) -> Result { let mut union_builder = Ok(LogicalPlanBuilder::from( - from_substrait_rel(ctx, &rels[0], extensions).await?, + from_substrait_rel(state, &rels[0], extensions).await?, )); for input in &rels[1..] { - let rel_plan = from_substrait_rel(ctx, input, extensions).await?; + let rel_plan = from_substrait_rel(state, input, extensions).await?; union_builder = if is_all { union_builder?.union(rel_plan) @@ -224,16 +223,16 @@ async fn union_rels( async fn intersect_rels( rels: &[Rel], - ctx: &SessionContext, + state: &dyn SubstraitPlanningState, extensions: &Extensions, is_all: bool, ) -> Result { - let mut rel = from_substrait_rel(ctx, &rels[0], extensions).await?; + let mut rel = from_substrait_rel(state, &rels[0], extensions).await?; for input in &rels[1..] { rel = LogicalPlanBuilder::intersect( rel, - from_substrait_rel(ctx, input, extensions).await?, + from_substrait_rel(state, input, extensions).await?, is_all, )? } @@ -243,16 +242,16 @@ async fn intersect_rels( async fn except_rels( rels: &[Rel], - ctx: &SessionContext, + state: &dyn SubstraitPlanningState, extensions: &Extensions, is_all: bool, ) -> Result { - let mut rel = from_substrait_rel(ctx, &rels[0], extensions).await?; + let mut rel = from_substrait_rel(state, &rels[0], extensions).await?; for input in &rels[1..] { rel = LogicalPlanBuilder::except( rel, - from_substrait_rel(ctx, input, extensions).await?, + from_substrait_rel(state, input, extensions).await?, is_all, )? } @@ -262,7 +261,7 @@ async fn except_rels( /// Convert Substrait Plan to DataFusion LogicalPlan pub async fn from_substrait_plan( - ctx: &SessionContext, + state: &dyn SubstraitPlanningState, plan: &Plan, ) -> Result { // Register function extension @@ -277,10 +276,10 @@ pub async fn from_substrait_plan( match plan.relations[0].rel_type.as_ref() { Some(rt) => match rt { plan_rel::RelType::Rel(rel) => { - Ok(from_substrait_rel(ctx, rel, &extensions).await?) + Ok(from_substrait_rel(state, rel, &extensions).await?) }, plan_rel::RelType::Root(root) => { - let plan = from_substrait_rel(ctx, root.input.as_ref().unwrap(), &extensions).await?; + let plan = from_substrait_rel(state, root.input.as_ref().unwrap(), &extensions).await?; if root.names.is_empty() { // Backwards compatibility for plans missing names return Ok(plan); @@ -341,7 +340,7 @@ pub struct ExprContainer { /// between systems. This is often useful for scenarios like pushdown where filter /// expressions need to be sent to remote systems. pub async fn from_substrait_extended_expr( - ctx: &SessionContext, + state: &dyn SubstraitPlanningState, extended_expr: &ExtendedExpression, ) -> Result { // Register function extension @@ -370,7 +369,7 @@ pub async fn from_substrait_extended_expr( } }?; let expr = - from_substrait_rex(ctx, scalar_expr, &input_schema, &extensions).await?; + from_substrait_rex(state, scalar_expr, &input_schema, &extensions).await?; let (output_type, expected_nullability) = expr.data_type_and_nullable(&input_schema)?; let output_field = Field::new("", output_type, expected_nullability); @@ -561,7 +560,7 @@ fn make_renamed_schema( #[allow(deprecated)] #[async_recursion] pub async fn from_substrait_rel( - ctx: &SessionContext, + state: &dyn SubstraitPlanningState, rel: &Rel, extensions: &Extensions, ) -> Result { @@ -569,7 +568,7 @@ pub async fn from_substrait_rel( Some(RelType::Project(p)) => { if let Some(input) = p.input.as_ref() { let mut input = LogicalPlanBuilder::from( - from_substrait_rel(ctx, input, extensions).await?, + from_substrait_rel(state, input, extensions).await?, ); let original_schema = input.schema().clone(); @@ -587,9 +586,13 @@ pub async fn from_substrait_rel( let mut explicit_exprs: Vec = vec![]; for expr in &p.expressions { - let e = - from_substrait_rex(ctx, expr, input.clone().schema(), extensions) - .await?; + let e = from_substrait_rex( + state, + expr, + input.clone().schema(), + extensions, + ) + .await?; // if the expression is WindowFunction, wrap in a Window relation if let Expr::WindowFunction(_) = &e { // Adding the same expression here and in the project below @@ -617,11 +620,11 @@ pub async fn from_substrait_rel( Some(RelType::Filter(filter)) => { if let Some(input) = filter.input.as_ref() { let input = LogicalPlanBuilder::from( - from_substrait_rel(ctx, input, extensions).await?, + from_substrait_rel(state, input, extensions).await?, ); if let Some(condition) = filter.condition.as_ref() { let expr = - from_substrait_rex(ctx, condition, input.schema(), extensions) + from_substrait_rex(state, condition, input.schema(), extensions) .await?; input.filter(expr)?.build() } else { @@ -634,7 +637,7 @@ pub async fn from_substrait_rel( Some(RelType::Fetch(fetch)) => { if let Some(input) = fetch.input.as_ref() { let input = LogicalPlanBuilder::from( - from_substrait_rel(ctx, input, extensions).await?, + from_substrait_rel(state, input, extensions).await?, ); let offset = fetch.offset as usize; // -1 means that ALL records should be returned @@ -651,10 +654,10 @@ pub async fn from_substrait_rel( Some(RelType::Sort(sort)) => { if let Some(input) = sort.input.as_ref() { let input = LogicalPlanBuilder::from( - from_substrait_rel(ctx, input, extensions).await?, + from_substrait_rel(state, input, extensions).await?, ); let sorts = - from_substrait_sorts(ctx, &sort.sorts, input.schema(), extensions) + from_substrait_sorts(state, &sort.sorts, input.schema(), extensions) .await?; input.sort(sorts)?.build() } else { @@ -664,13 +667,13 @@ pub async fn from_substrait_rel( Some(RelType::Aggregate(agg)) => { if let Some(input) = agg.input.as_ref() { let input = LogicalPlanBuilder::from( - from_substrait_rel(ctx, input, extensions).await?, + from_substrait_rel(state, input, extensions).await?, ); let mut ref_group_exprs = vec![]; for e in &agg.grouping_expressions { let x = - from_substrait_rex(ctx, e, input.schema(), extensions).await?; + from_substrait_rex(state, e, input.schema(), extensions).await?; ref_group_exprs.push(x); } @@ -681,7 +684,7 @@ pub async fn from_substrait_rel( 1 => { group_exprs.extend_from_slice( &from_substrait_grouping( - ctx, + state, &agg.groupings[0], &ref_group_exprs, input.schema(), @@ -694,7 +697,7 @@ pub async fn from_substrait_rel( let mut grouping_sets = vec![]; for grouping in &agg.groupings { let grouping_set = from_substrait_grouping( - ctx, + state, grouping, &ref_group_exprs, input.schema(), @@ -716,7 +719,7 @@ pub async fn from_substrait_rel( for m in &agg.measures { let filter = match &m.filter { Some(fil) => Some(Box::new( - from_substrait_rex(ctx, fil, input.schema(), extensions) + from_substrait_rex(state, fil, input.schema(), extensions) .await?, )), None => None, @@ -739,7 +742,7 @@ pub async fn from_substrait_rel( let order_by = if !f.sorts.is_empty() { Some( from_substrait_sorts( - ctx, + state, &f.sorts, input.schema(), extensions, @@ -751,7 +754,7 @@ pub async fn from_substrait_rel( }; from_substrait_agg_func( - ctx, + state, f, input.schema(), extensions, @@ -780,10 +783,12 @@ pub async fn from_substrait_rel( } let left: LogicalPlanBuilder = LogicalPlanBuilder::from( - from_substrait_rel(ctx, join.left.as_ref().unwrap(), extensions).await?, + from_substrait_rel(state, join.left.as_ref().unwrap(), extensions) + .await?, ); let right = LogicalPlanBuilder::from( - from_substrait_rel(ctx, join.right.as_ref().unwrap(), extensions).await?, + from_substrait_rel(state, join.right.as_ref().unwrap(), extensions) + .await?, ); let (left, right) = requalify_sides_if_needed(left, right)?; @@ -796,7 +801,7 @@ pub async fn from_substrait_rel( // Otherwise, build join with only the filter, without join keys match &join.expression.as_ref() { Some(expr) => { - let on = from_substrait_rex(ctx, expr, &in_join_schema, extensions) + let on = from_substrait_rex(state, expr, &in_join_schema, extensions) .await?; // The join expression can contain both equal and non-equal ops. // As of datafusion 31.0.0, the equal and non equal join conditions are in separate fields. @@ -831,26 +836,44 @@ pub async fn from_substrait_rel( } Some(RelType::Cross(cross)) => { let left = LogicalPlanBuilder::from( - from_substrait_rel(ctx, cross.left.as_ref().unwrap(), extensions).await?, + from_substrait_rel(state, cross.left.as_ref().unwrap(), extensions) + .await?, ); let right = LogicalPlanBuilder::from( - from_substrait_rel(ctx, cross.right.as_ref().unwrap(), extensions) + from_substrait_rel(state, cross.right.as_ref().unwrap(), extensions) .await?, ); let (left, right) = requalify_sides_if_needed(left, right)?; left.cross_join(right.build()?)?.build() } Some(RelType::Read(read)) => { - fn read_with_schema( - df: DataFrame, + async fn read_with_schema( + state: &dyn SubstraitPlanningState, + table_ref: TableReference, schema: DFSchema, projection: &Option, ) -> Result { - ensure_schema_compatability(df.schema().to_owned(), schema.clone())?; + let schema = schema.replace_qualifier(table_ref.clone()); + + let plan = { + let provider = match state.table(&table_ref).await? { + Some(ref provider) => Arc::clone(provider), + _ => return plan_err!("No table named '{table_ref}'"), + }; + + LogicalPlanBuilder::scan( + table_ref, + provider_as_source(Arc::clone(&provider)), + None, + )? + .build()? + }; + + ensure_schema_compatability(plan.schema(), schema.clone())?; let schema = apply_masking(schema, projection)?; - apply_projection(df, schema) + apply_projection(plan, schema) } let named_struct = read.base_schema.as_ref().ok_or_else(|| { @@ -879,12 +902,13 @@ pub async fn from_substrait_rel( }, }; - let t = ctx.table(table_reference.clone()).await?; - - let substrait_schema = - substrait_schema.replace_qualifier(table_reference); - - read_with_schema(t, substrait_schema, &read.projection) + read_with_schema( + state, + table_reference, + substrait_schema, + &read.projection, + ) + .await } Some(ReadType::VirtualTable(vt)) => { if vt.values.is_empty() { @@ -960,12 +984,14 @@ pub async fn from_substrait_rel( let name = filename.unwrap(); // directly use unwrap here since we could determine it is a valid one let table_reference = TableReference::Bare { table: name.into() }; - let t = ctx.table(table_reference.clone()).await?; - - let substrait_schema = - substrait_schema.replace_qualifier(table_reference); - read_with_schema(t, substrait_schema, &read.projection) + read_with_schema( + state, + table_reference, + substrait_schema, + &read.projection, + ) + .await } _ => { not_impl_err!("Unsupported ReadType: {:?}", &read.as_ref().read_type) @@ -979,31 +1005,31 @@ pub async fn from_substrait_rel( } else { match set_op { set_rel::SetOp::UnionAll => { - union_rels(&set.inputs, ctx, extensions, true).await + union_rels(&set.inputs, state, extensions, true).await } set_rel::SetOp::UnionDistinct => { - union_rels(&set.inputs, ctx, extensions, false).await + union_rels(&set.inputs, state, extensions, false).await } set_rel::SetOp::IntersectionPrimary => { LogicalPlanBuilder::intersect( - from_substrait_rel(ctx, &set.inputs[0], extensions) + from_substrait_rel(state, &set.inputs[0], extensions) .await?, - union_rels(&set.inputs[1..], ctx, extensions, true) + union_rels(&set.inputs[1..], state, extensions, true) .await?, false, ) } set_rel::SetOp::IntersectionMultiset => { - intersect_rels(&set.inputs, ctx, extensions, false).await + intersect_rels(&set.inputs, state, extensions, false).await } set_rel::SetOp::IntersectionMultisetAll => { - intersect_rels(&set.inputs, ctx, extensions, true).await + intersect_rels(&set.inputs, state, extensions, true).await } set_rel::SetOp::MinusPrimary => { - except_rels(&set.inputs, ctx, extensions, false).await + except_rels(&set.inputs, state, extensions, false).await } set_rel::SetOp::MinusPrimaryAll => { - except_rels(&set.inputs, ctx, extensions, true).await + except_rels(&set.inputs, state, extensions, true).await } _ => not_impl_err!("Unsupported set operator: {set_op:?}"), } @@ -1015,8 +1041,7 @@ pub async fn from_substrait_rel( let Some(ext_detail) = &extension.detail else { return substrait_err!("Unexpected empty detail in ExtensionLeafRel"); }; - let plan = ctx - .state() + let plan = state .serializer_registry() .deserialize_logical_plan(&ext_detail.type_url, &ext_detail.value)?; Ok(LogicalPlan::Extension(Extension { node: plan })) @@ -1025,8 +1050,7 @@ pub async fn from_substrait_rel( let Some(ext_detail) = &extension.detail else { return substrait_err!("Unexpected empty detail in ExtensionSingleRel"); }; - let plan = ctx - .state() + let plan = state .serializer_registry() .deserialize_logical_plan(&ext_detail.type_url, &ext_detail.value)?; let Some(input_rel) = &extension.input else { @@ -1034,7 +1058,7 @@ pub async fn from_substrait_rel( "ExtensionSingleRel doesn't contains input rel. Try use ExtensionLeafRel instead" ); }; - let input_plan = from_substrait_rel(ctx, input_rel, extensions).await?; + let input_plan = from_substrait_rel(state, input_rel, extensions).await?; let plan = plan.with_exprs_and_inputs(plan.expressions(), vec![input_plan])?; Ok(LogicalPlan::Extension(Extension { node: plan })) @@ -1043,13 +1067,12 @@ pub async fn from_substrait_rel( let Some(ext_detail) = &extension.detail else { return substrait_err!("Unexpected empty detail in ExtensionSingleRel"); }; - let plan = ctx - .state() + let plan = state .serializer_registry() .deserialize_logical_plan(&ext_detail.type_url, &ext_detail.value)?; let mut inputs = Vec::with_capacity(extension.inputs.len()); for input in &extension.inputs { - let input_plan = from_substrait_rel(ctx, input, extensions).await?; + let input_plan = from_substrait_rel(state, input, extensions).await?; inputs.push(input_plan); } let plan = plan.with_exprs_and_inputs(plan.expressions(), inputs)?; @@ -1059,7 +1082,7 @@ pub async fn from_substrait_rel( let Some(input) = exchange.input.as_ref() else { return substrait_err!("Unexpected empty input in ExchangeRel"); }; - let input = Arc::new(from_substrait_rel(ctx, input, extensions).await?); + let input = Arc::new(from_substrait_rel(state, input, extensions).await?); let Some(exchange_kind) = &exchange.exchange_kind else { return substrait_err!("Unexpected empty input in ExchangeRel"); @@ -1237,7 +1260,7 @@ impl NameTracker { /// DataFusion schema may have MORE fields, but not the other way around. /// 2. All fields are compatible. See [`ensure_field_compatability`] for details fn ensure_schema_compatability( - table_schema: DFSchema, + table_schema: &DFSchema, substrait_schema: DFSchema, ) -> Result<()> { substrait_schema @@ -1253,16 +1276,19 @@ fn ensure_schema_compatability( /// This function returns a DataFrame with fields adjusted if necessary in the event that the /// Substrait schema is a subset of the DataFusion schema. -fn apply_projection(table: DataFrame, substrait_schema: DFSchema) -> Result { - let df_schema = table.schema().to_owned(); - - let t = table.into_unoptimized_plan(); +fn apply_projection( + plan: LogicalPlan, + substrait_schema: DFSchema, +) -> Result { + let df_schema = plan.schema(); if df_schema.logically_equivalent_names_and_types(&substrait_schema) { - return Ok(t); + return Ok(plan); } - match t { + let df_schema = df_schema.to_owned(); + + match plan { LogicalPlan::TableScan(mut scan) => { let column_indices: Vec = substrait_schema .strip_qualifiers() @@ -1389,7 +1415,7 @@ fn from_substrait_jointype(join_type: i32) -> Result { /// Convert Substrait Sorts to DataFusion Exprs pub async fn from_substrait_sorts( - ctx: &SessionContext, + state: &dyn SubstraitPlanningState, substrait_sorts: &Vec, input_schema: &DFSchema, extensions: &Extensions, @@ -1397,7 +1423,7 @@ pub async fn from_substrait_sorts( let mut sorts: Vec = vec![]; for s in substrait_sorts { let expr = - from_substrait_rex(ctx, s.expr.as_ref().unwrap(), input_schema, extensions) + from_substrait_rex(state, s.expr.as_ref().unwrap(), input_schema, extensions) .await?; let asc_nullfirst = match &s.sort_kind { Some(k) => match k { @@ -1439,14 +1465,15 @@ pub async fn from_substrait_sorts( /// Convert Substrait Expressions to DataFusion Exprs pub async fn from_substrait_rex_vec( - ctx: &SessionContext, + state: &dyn SubstraitPlanningState, exprs: &Vec, input_schema: &DFSchema, extensions: &Extensions, ) -> Result> { let mut expressions: Vec = vec![]; for expr in exprs { - let expression = from_substrait_rex(ctx, expr, input_schema, extensions).await?; + let expression = + from_substrait_rex(state, expr, input_schema, extensions).await?; expressions.push(expression); } Ok(expressions) @@ -1454,7 +1481,7 @@ pub async fn from_substrait_rex_vec( /// Convert Substrait FunctionArguments to DataFusion Exprs pub async fn from_substrait_func_args( - ctx: &SessionContext, + state: &dyn SubstraitPlanningState, arguments: &Vec, input_schema: &DFSchema, extensions: &Extensions, @@ -1463,7 +1490,7 @@ pub async fn from_substrait_func_args( for arg in arguments { let arg_expr = match &arg.arg_type { Some(ArgType::Value(e)) => { - from_substrait_rex(ctx, e, input_schema, extensions).await + from_substrait_rex(state, e, input_schema, extensions).await } _ => not_impl_err!("Function argument non-Value type not supported"), }; @@ -1474,7 +1501,7 @@ pub async fn from_substrait_func_args( /// Convert Substrait AggregateFunction to DataFusion Expr pub async fn from_substrait_agg_func( - ctx: &SessionContext, + state: &dyn SubstraitPlanningState, f: &AggregateFunction, input_schema: &DFSchema, extensions: &Extensions, @@ -1483,7 +1510,7 @@ pub async fn from_substrait_agg_func( distinct: bool, ) -> Result> { let args = - from_substrait_func_args(ctx, &f.arguments, input_schema, extensions).await?; + from_substrait_func_args(state, &f.arguments, input_schema, extensions).await?; let Some(function_name) = extensions.functions.get(&f.function_reference) else { return plan_err!( @@ -1494,7 +1521,7 @@ pub async fn from_substrait_agg_func( let function_name = substrait_fun_name(function_name); // try udaf first, then built-in aggr fn. - if let Ok(fun) = ctx.udaf(function_name) { + if let Ok(fun) = state.udaf(function_name) { // deal with situation that count(*) got no arguments let args = if fun.name() == "count" && args.is_empty() { vec![Expr::Literal(ScalarValue::Int64(Some(1)))] @@ -1517,7 +1544,7 @@ pub async fn from_substrait_agg_func( /// Convert Substrait Rex to DataFusion Expr #[async_recursion] pub async fn from_substrait_rex( - ctx: &SessionContext, + state: &dyn SubstraitPlanningState, e: &Expression, input_schema: &DFSchema, extensions: &Extensions, @@ -1528,11 +1555,11 @@ pub async fn from_substrait_rex( let substrait_list = s.options.as_ref(); Ok(Expr::InList(InList { expr: Box::new( - from_substrait_rex(ctx, substrait_expr, input_schema, extensions) + from_substrait_rex(state, substrait_expr, input_schema, extensions) .await?, ), list: from_substrait_rex_vec( - ctx, + state, substrait_list, input_schema, extensions, @@ -1555,7 +1582,7 @@ pub async fn from_substrait_rex( if if_expr.then.is_none() { expr = Some(Box::new( from_substrait_rex( - ctx, + state, if_expr.r#if.as_ref().unwrap(), input_schema, extensions, @@ -1568,7 +1595,7 @@ pub async fn from_substrait_rex( when_then_expr.push(( Box::new( from_substrait_rex( - ctx, + state, if_expr.r#if.as_ref().unwrap(), input_schema, extensions, @@ -1577,7 +1604,7 @@ pub async fn from_substrait_rex( ), Box::new( from_substrait_rex( - ctx, + state, if_expr.then.as_ref().unwrap(), input_schema, extensions, @@ -1589,7 +1616,7 @@ pub async fn from_substrait_rex( // Parse `else` let else_expr = match &if_then.r#else { Some(e) => Some(Box::new( - from_substrait_rex(ctx, e, input_schema, extensions).await?, + from_substrait_rex(state, e, input_schema, extensions).await?, )), None => None, }; @@ -1609,12 +1636,12 @@ pub async fn from_substrait_rex( let fn_name = substrait_fun_name(fn_name); let args = - from_substrait_func_args(ctx, &f.arguments, input_schema, extensions) + from_substrait_func_args(state, &f.arguments, input_schema, extensions) .await?; // try to first match the requested function into registered udfs, then built-in ops // and finally built-in expressions - if let Some(func) = ctx.state().scalar_functions().get(fn_name) { + if let Ok(func) = state.udf(fn_name) { Ok(Expr::ScalarFunction(expr::ScalarFunction::new_udf( func.to_owned(), args, @@ -1644,7 +1671,7 @@ pub async fn from_substrait_rex( Ok(combined_expr) } else if let Some(builder) = BuiltinExprBuilder::try_from_name(fn_name) { - builder.build(ctx, f, input_schema, extensions).await + builder.build(state, f, input_schema, extensions).await } else { not_impl_err!("Unsupported function name: {fn_name:?}") } @@ -1657,7 +1684,7 @@ pub async fn from_substrait_rex( Some(output_type) => Ok(Expr::Cast(Cast::new( Box::new( from_substrait_rex( - ctx, + state, cast.as_ref().input.as_ref().unwrap().as_ref(), input_schema, extensions, @@ -1679,9 +1706,9 @@ pub async fn from_substrait_rex( let fn_name = substrait_fun_name(fn_name); // check udwf first, then udaf, then built-in window and aggregate functions - let fun = if let Ok(udwf) = ctx.udwf(fn_name) { + let fun = if let Ok(udwf) = state.udwf(fn_name) { Ok(WindowFunctionDefinition::WindowUDF(udwf)) - } else if let Ok(udaf) = ctx.udaf(fn_name) { + } else if let Ok(udaf) = state.udaf(fn_name) { Ok(WindowFunctionDefinition::AggregateUDF(udaf)) } else { not_impl_err!( @@ -1692,7 +1719,7 @@ pub async fn from_substrait_rex( }?; let order_by = - from_substrait_sorts(ctx, &window.sorts, input_schema, extensions) + from_substrait_sorts(state, &window.sorts, input_schema, extensions) .await?; let bound_units = @@ -1715,14 +1742,14 @@ pub async fn from_substrait_rex( Ok(Expr::WindowFunction(expr::WindowFunction { fun, args: from_substrait_func_args( - ctx, + state, &window.arguments, input_schema, extensions, ) .await?, partition_by: from_substrait_rex_vec( - ctx, + state, &window.partitions, input_schema, extensions, @@ -1747,13 +1774,13 @@ pub async fn from_substrait_rex( let haystack_expr = &in_predicate.haystack; if let Some(haystack_expr) = haystack_expr { let haystack_expr = - from_substrait_rel(ctx, haystack_expr, extensions) + from_substrait_rel(state, haystack_expr, extensions) .await?; let outer_refs = haystack_expr.all_out_ref_exprs(); Ok(Expr::InSubquery(InSubquery { expr: Box::new( from_substrait_rex( - ctx, + state, needle_expr, input_schema, extensions, @@ -1773,7 +1800,7 @@ pub async fn from_substrait_rex( } SubqueryType::Scalar(query) => { let plan = from_substrait_rel( - ctx, + state, &(query.input.clone()).unwrap_or_default(), extensions, ) @@ -1790,7 +1817,7 @@ pub async fn from_substrait_rex( PredicateOp::Exists => { let relation = &predicate.tuples; let plan = from_substrait_rel( - ctx, + state, &relation.clone().unwrap_or_default(), extensions, ) @@ -2772,7 +2799,7 @@ fn from_substrait_null( #[allow(deprecated)] async fn from_substrait_grouping( - ctx: &SessionContext, + state: &dyn SubstraitPlanningState, grouping: &Grouping, expressions: &[Expr], input_schema: &DFSchemaRef, @@ -2781,7 +2808,7 @@ async fn from_substrait_grouping( let mut group_exprs = vec![]; if !grouping.grouping_expressions.is_empty() { for e in &grouping.grouping_expressions { - let expr = from_substrait_rex(ctx, e, input_schema, extensions).await?; + let expr = from_substrait_rex(state, e, input_schema, extensions).await?; group_exprs.push(expr); } return Ok(group_exprs); @@ -2834,23 +2861,29 @@ impl BuiltinExprBuilder { pub async fn build( self, - ctx: &SessionContext, + state: &dyn SubstraitPlanningState, f: &ScalarFunction, input_schema: &DFSchema, extensions: &Extensions, ) -> Result { match self.expr_name.as_str() { "like" => { - Self::build_like_expr(ctx, false, f, input_schema, extensions).await + Self::build_like_expr(state, false, f, input_schema, extensions).await } "ilike" => { - Self::build_like_expr(ctx, true, f, input_schema, extensions).await + Self::build_like_expr(state, true, f, input_schema, extensions).await } "not" | "negative" | "negate" | "is_null" | "is_not_null" | "is_true" | "is_false" | "is_not_true" | "is_not_false" | "is_unknown" | "is_not_unknown" => { - Self::build_unary_expr(ctx, &self.expr_name, f, input_schema, extensions) - .await + Self::build_unary_expr( + state, + &self.expr_name, + f, + input_schema, + extensions, + ) + .await } _ => { not_impl_err!("Unsupported builtin expression: {}", self.expr_name) @@ -2859,7 +2892,7 @@ impl BuiltinExprBuilder { } async fn build_unary_expr( - ctx: &SessionContext, + state: &dyn SubstraitPlanningState, fn_name: &str, f: &ScalarFunction, input_schema: &DFSchema, @@ -2872,7 +2905,7 @@ impl BuiltinExprBuilder { return substrait_err!("Invalid arguments type for {fn_name} expr"); }; let arg = - from_substrait_rex(ctx, expr_substrait, input_schema, extensions).await?; + from_substrait_rex(state, expr_substrait, input_schema, extensions).await?; let arg = Box::new(arg); let expr = match fn_name { @@ -2893,7 +2926,7 @@ impl BuiltinExprBuilder { } async fn build_like_expr( - ctx: &SessionContext, + state: &dyn SubstraitPlanningState, case_insensitive: bool, f: &ScalarFunction, input_schema: &DFSchema, @@ -2908,12 +2941,13 @@ impl BuiltinExprBuilder { return substrait_err!("Invalid arguments type for `{fn_name}` expr"); }; let expr = - from_substrait_rex(ctx, expr_substrait, input_schema, extensions).await?; + from_substrait_rex(state, expr_substrait, input_schema, extensions).await?; let Some(ArgType::Value(pattern_substrait)) = &f.arguments[1].arg_type else { return substrait_err!("Invalid arguments type for `{fn_name}` expr"); }; let pattern = - from_substrait_rex(ctx, pattern_substrait, input_schema, extensions).await?; + from_substrait_rex(state, pattern_substrait, input_schema, extensions) + .await?; // Default case: escape character is Literal(Utf8(None)) let escape_char = if f.arguments.len() == 3 { @@ -2922,9 +2956,13 @@ impl BuiltinExprBuilder { return substrait_err!("Invalid arguments type for `{fn_name}` expr"); }; - let escape_char_expr = - from_substrait_rex(ctx, escape_char_substrait, input_schema, extensions) - .await?; + let escape_char_expr = from_substrait_rex( + state, + escape_char_substrait, + input_schema, + extensions, + ) + .await?; match escape_char_expr { Expr::Literal(ScalarValue::Utf8(escape_char_string)) => { diff --git a/datafusion/substrait/src/logical_plan/mod.rs b/datafusion/substrait/src/logical_plan/mod.rs index 6f8b8e493f52..9e2fa9fa49de 100644 --- a/datafusion/substrait/src/logical_plan/mod.rs +++ b/datafusion/substrait/src/logical_plan/mod.rs @@ -17,3 +17,4 @@ pub mod consumer; pub mod producer; +pub mod state; diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index 4d864e4334ce..29019dfd74f3 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -29,7 +29,7 @@ use datafusion::{ arrow::datatypes::{DataType, TimeUnit}, error::{DataFusionError, Result}, logical_expr::{WindowFrame, WindowFrameBound}, - prelude::{JoinType, SessionContext}, + prelude::JoinType, scalar::ScalarValue, }; @@ -100,8 +100,13 @@ use substrait::{ version, }; +use super::state::SubstraitPlanningState; + /// Convert DataFusion LogicalPlan to Substrait Plan -pub fn to_substrait_plan(plan: &LogicalPlan, ctx: &SessionContext) -> Result> { +pub fn to_substrait_plan( + plan: &LogicalPlan, + state: &dyn SubstraitPlanningState, +) -> Result> { let mut extensions = Extensions::default(); // Parse relation nodes // Generate PlanRel(s) @@ -113,7 +118,7 @@ pub fn to_substrait_plan(plan: &LogicalPlan, ctx: &SessionContext) -> Result Result Result> { let mut extensions = Extensions::default(); @@ -152,7 +157,7 @@ pub fn to_substrait_extended_expr( .iter() .map(|(expr, field)| { let substrait_expr = to_substrait_rex( - ctx, + state, expr, schema, /*col_ref_offset=*/ 0, @@ -183,7 +188,7 @@ pub fn to_substrait_extended_expr( #[allow(deprecated)] pub fn to_substrait_rel( plan: &LogicalPlan, - ctx: &SessionContext, + state: &dyn SubstraitPlanningState, extensions: &mut Extensions, ) -> Result> { match plan { @@ -284,7 +289,7 @@ pub fn to_substrait_rel( let expressions = p .expr .iter() - .map(|e| to_substrait_rex(ctx, e, p.input.schema(), 0, extensions)) + .map(|e| to_substrait_rex(state, e, p.input.schema(), 0, extensions)) .collect::>>()?; let emit_kind = create_project_remapping( @@ -300,16 +305,16 @@ pub fn to_substrait_rel( Ok(Box::new(Rel { rel_type: Some(RelType::Project(Box::new(ProjectRel { common: Some(common), - input: Some(to_substrait_rel(p.input.as_ref(), ctx, extensions)?), + input: Some(to_substrait_rel(p.input.as_ref(), state, extensions)?), expressions, advanced_extension: None, }))), })) } LogicalPlan::Filter(filter) => { - let input = to_substrait_rel(filter.input.as_ref(), ctx, extensions)?; + let input = to_substrait_rel(filter.input.as_ref(), state, extensions)?; let filter_expr = to_substrait_rex( - ctx, + state, &filter.predicate, filter.input.schema(), 0, @@ -325,7 +330,7 @@ pub fn to_substrait_rel( })) } LogicalPlan::Limit(limit) => { - let input = to_substrait_rel(limit.input.as_ref(), ctx, extensions)?; + let input = to_substrait_rel(limit.input.as_ref(), state, extensions)?; let FetchType::Literal(fetch) = limit.get_fetch_type()? else { return not_impl_err!("Non-literal limit fetch"); }; @@ -344,11 +349,11 @@ pub fn to_substrait_rel( })) } LogicalPlan::Sort(sort) => { - let input = to_substrait_rel(sort.input.as_ref(), ctx, extensions)?; + let input = to_substrait_rel(sort.input.as_ref(), state, extensions)?; let sort_fields = sort .expr .iter() - .map(|e| substrait_sort_field(ctx, e, sort.input.schema(), extensions)) + .map(|e| substrait_sort_field(state, e, sort.input.schema(), extensions)) .collect::>>()?; Ok(Box::new(Rel { rel_type: Some(RelType::Sort(Box::new(SortRel { @@ -360,9 +365,9 @@ pub fn to_substrait_rel( })) } LogicalPlan::Aggregate(agg) => { - let input = to_substrait_rel(agg.input.as_ref(), ctx, extensions)?; + let input = to_substrait_rel(agg.input.as_ref(), state, extensions)?; let (grouping_expressions, groupings) = to_substrait_groupings( - ctx, + state, &agg.group_expr, agg.input.schema(), extensions, @@ -370,7 +375,9 @@ pub fn to_substrait_rel( let measures = agg .aggr_expr .iter() - .map(|e| to_substrait_agg_measure(ctx, e, agg.input.schema(), extensions)) + .map(|e| { + to_substrait_agg_measure(state, e, agg.input.schema(), extensions) + }) .collect::>>()?; Ok(Box::new(Rel { @@ -386,7 +393,7 @@ pub fn to_substrait_rel( } LogicalPlan::Distinct(Distinct::All(plan)) => { // Use Substrait's AggregateRel with empty measures to represent `select distinct` - let input = to_substrait_rel(plan.as_ref(), ctx, extensions)?; + let input = to_substrait_rel(plan.as_ref(), state, extensions)?; // Get grouping keys from the input relation's number of output fields let grouping = (0..plan.schema().fields().len()) .map(substrait_field_ref) @@ -407,8 +414,8 @@ pub fn to_substrait_rel( })) } LogicalPlan::Join(join) => { - let left = to_substrait_rel(join.left.as_ref(), ctx, extensions)?; - let right = to_substrait_rel(join.right.as_ref(), ctx, extensions)?; + let left = to_substrait_rel(join.left.as_ref(), state, extensions)?; + let right = to_substrait_rel(join.right.as_ref(), state, extensions)?; let join_type = to_substrait_jointype(join.join_type); // we only support basic joins so return an error for anything not yet supported match join.join_constraint { @@ -421,7 +428,7 @@ pub fn to_substrait_rel( let in_join_schema = join.left.schema().join(join.right.schema())?; let join_filter = match &join.filter { Some(filter) => Some(to_substrait_rex( - ctx, + state, filter, &Arc::new(in_join_schema), 0, @@ -438,7 +445,7 @@ pub fn to_substrait_rel( Operator::Eq }; let join_on = to_substrait_join_expr( - ctx, + state, &join.on, eq_op, join.left.schema(), @@ -479,13 +486,13 @@ pub fn to_substrait_rel( LogicalPlan::SubqueryAlias(alias) => { // Do nothing if encounters SubqueryAlias // since there is no corresponding relation type in Substrait - to_substrait_rel(alias.input.as_ref(), ctx, extensions) + to_substrait_rel(alias.input.as_ref(), state, extensions) } LogicalPlan::Union(union) => { let input_rels = union .inputs .iter() - .map(|input| to_substrait_rel(input.as_ref(), ctx, extensions)) + .map(|input| to_substrait_rel(input.as_ref(), state, extensions)) .collect::>>()? .into_iter() .map(|ptr| *ptr) @@ -500,7 +507,7 @@ pub fn to_substrait_rel( })) } LogicalPlan::Window(window) => { - let input = to_substrait_rel(window.input.as_ref(), ctx, extensions)?; + let input = to_substrait_rel(window.input.as_ref(), state, extensions)?; // create a field reference for each input field let mut expressions = (0..window.input.schema().fields().len()) @@ -510,7 +517,7 @@ pub fn to_substrait_rel( // process and add each window function expression for expr in &window.window_expr { expressions.push(to_substrait_rex( - ctx, + state, expr, window.input.schema(), 0, @@ -539,7 +546,7 @@ pub fn to_substrait_rel( })) } LogicalPlan::Repartition(repartition) => { - let input = to_substrait_rel(repartition.input.as_ref(), ctx, extensions)?; + let input = to_substrait_rel(repartition.input.as_ref(), state, extensions)?; let partition_count = match repartition.partitioning_scheme { Partitioning::RoundRobinBatch(num) => num, Partitioning::Hash(_, num) => num, @@ -585,8 +592,7 @@ pub fn to_substrait_rel( })) } LogicalPlan::Extension(extension_plan) => { - let extension_bytes = ctx - .state() + let extension_bytes = state .serializer_registry() .serialize_logical_plan(extension_plan.node.as_ref())?; let detail = ProtoAny { @@ -597,7 +603,7 @@ pub fn to_substrait_rel( .node .inputs() .into_iter() - .map(|plan| to_substrait_rel(plan, ctx, extensions)) + .map(|plan| to_substrait_rel(plan, state, extensions)) .collect::>>()?; let rel_type = match inputs_rel.len() { 0 => RelType::ExtensionLeaf(ExtensionLeafRel { @@ -687,7 +693,7 @@ fn to_substrait_named_struct(schema: &DFSchemaRef) -> Result { } fn to_substrait_join_expr( - ctx: &SessionContext, + state: &dyn SubstraitPlanningState, join_conditions: &Vec<(Expr, Expr)>, eq_op: Operator, left_schema: &DFSchemaRef, @@ -698,10 +704,10 @@ fn to_substrait_join_expr( let mut exprs: Vec = vec![]; for (left, right) in join_conditions { // Parse left - let l = to_substrait_rex(ctx, left, left_schema, 0, extensions)?; + let l = to_substrait_rex(state, left, left_schema, 0, extensions)?; // Parse right let r = to_substrait_rex( - ctx, + state, right, right_schema, left_schema.fields().len(), // offset to return the correct index @@ -770,7 +776,7 @@ pub fn operator_to_name(op: Operator) -> &'static str { #[allow(deprecated)] pub fn parse_flat_grouping_exprs( - ctx: &SessionContext, + state: &dyn SubstraitPlanningState, exprs: &[Expr], schema: &DFSchemaRef, extensions: &mut Extensions, @@ -780,7 +786,7 @@ pub fn parse_flat_grouping_exprs( let mut grouping_expressions = vec![]; for e in exprs { - let rex = to_substrait_rex(ctx, e, schema, 0, extensions)?; + let rex = to_substrait_rex(state, e, schema, 0, extensions)?; grouping_expressions.push(rex.clone()); ref_group_exprs.push(rex); expression_references.push((ref_group_exprs.len() - 1) as u32); @@ -792,7 +798,7 @@ pub fn parse_flat_grouping_exprs( } pub fn to_substrait_groupings( - ctx: &SessionContext, + state: &dyn SubstraitPlanningState, exprs: &[Expr], schema: &DFSchemaRef, extensions: &mut Extensions, @@ -808,7 +814,7 @@ pub fn to_substrait_groupings( .iter() .map(|set| { parse_flat_grouping_exprs( - ctx, + state, set, schema, extensions, @@ -826,7 +832,7 @@ pub fn to_substrait_groupings( .rev() .map(|set| { parse_flat_grouping_exprs( - ctx, + state, set, schema, extensions, @@ -837,7 +843,7 @@ pub fn to_substrait_groupings( } }, _ => Ok(vec![parse_flat_grouping_exprs( - ctx, + state, exprs, schema, extensions, @@ -845,7 +851,7 @@ pub fn to_substrait_groupings( )?]), }, _ => Ok(vec![parse_flat_grouping_exprs( - ctx, + state, exprs, schema, extensions, @@ -857,7 +863,7 @@ pub fn to_substrait_groupings( #[allow(deprecated)] pub fn to_substrait_agg_measure( - ctx: &SessionContext, + state: &dyn SubstraitPlanningState, expr: &Expr, schema: &DFSchemaRef, extensions: &mut Extensions, @@ -865,13 +871,13 @@ pub fn to_substrait_agg_measure( match expr { Expr::AggregateFunction(expr::AggregateFunction { func, args, distinct, filter, order_by, null_treatment: _, }) => { let sorts = if let Some(order_by) = order_by { - order_by.iter().map(|expr| to_substrait_sort_field(ctx, expr, schema, extensions)).collect::>>()? + order_by.iter().map(|expr| to_substrait_sort_field(state, expr, schema, extensions)).collect::>>()? } else { vec![] }; let mut arguments: Vec = vec![]; for arg in args { - arguments.push(FunctionArgument { arg_type: Some(ArgType::Value(to_substrait_rex(ctx, arg, schema, 0, extensions)?)) }); + arguments.push(FunctionArgument { arg_type: Some(ArgType::Value(to_substrait_rex(state, arg, schema, 0, extensions)?)) }); } let function_anchor = extensions.register_function(func.name().to_string()); Ok(Measure { @@ -889,14 +895,14 @@ pub fn to_substrait_agg_measure( options: vec![], }), filter: match filter { - Some(f) => Some(to_substrait_rex(ctx, f, schema, 0, extensions)?), + Some(f) => Some(to_substrait_rex(state, f, schema, 0, extensions)?), None => None } }) } Expr::Alias(Alias{expr,..})=> { - to_substrait_agg_measure(ctx, expr, schema, extensions) + to_substrait_agg_measure(state, expr, schema, extensions) } _ => internal_err!( "Expression must be compatible with aggregation. Unsupported expression: {:?}. ExpressionType: {:?}", @@ -908,7 +914,7 @@ pub fn to_substrait_agg_measure( /// Converts sort expression to corresponding substrait `SortField` fn to_substrait_sort_field( - ctx: &SessionContext, + state: &dyn SubstraitPlanningState, sort: &Sort, schema: &DFSchemaRef, extensions: &mut Extensions, @@ -920,7 +926,7 @@ fn to_substrait_sort_field( (false, false) => SortDirection::DescNullsLast, }; Ok(SortField { - expr: Some(to_substrait_rex(ctx, &sort.expr, schema, 0, extensions)?), + expr: Some(to_substrait_rex(state, &sort.expr, schema, 0, extensions)?), sort_kind: Some(SortKind::Direction(sort_kind.into())), }) } @@ -977,7 +983,7 @@ pub fn make_binary_op_scalar_func( /// * `extensions` - Substrait extension info. Contains registered function information #[allow(deprecated)] pub fn to_substrait_rex( - ctx: &SessionContext, + state: &dyn SubstraitPlanningState, expr: &Expr, schema: &DFSchemaRef, col_ref_offset: usize, @@ -991,10 +997,10 @@ pub fn to_substrait_rex( }) => { let substrait_list = list .iter() - .map(|x| to_substrait_rex(ctx, x, schema, col_ref_offset, extensions)) + .map(|x| to_substrait_rex(state, x, schema, col_ref_offset, extensions)) .collect::>>()?; let substrait_expr = - to_substrait_rex(ctx, expr, schema, col_ref_offset, extensions)?; + to_substrait_rex(state, expr, schema, col_ref_offset, extensions)?; let substrait_or_list = Expression { rex_type: Some(RexType::SingularOrList(Box::new(SingularOrList { @@ -1026,7 +1032,7 @@ pub fn to_substrait_rex( for arg in &fun.args { arguments.push(FunctionArgument { arg_type: Some(ArgType::Value(to_substrait_rex( - ctx, + state, arg, schema, col_ref_offset, @@ -1055,11 +1061,11 @@ pub fn to_substrait_rex( if *negated { // `expr NOT BETWEEN low AND high` can be translated into (expr < low OR high < expr) let substrait_expr = - to_substrait_rex(ctx, expr, schema, col_ref_offset, extensions)?; + to_substrait_rex(state, expr, schema, col_ref_offset, extensions)?; let substrait_low = - to_substrait_rex(ctx, low, schema, col_ref_offset, extensions)?; + to_substrait_rex(state, low, schema, col_ref_offset, extensions)?; let substrait_high = - to_substrait_rex(ctx, high, schema, col_ref_offset, extensions)?; + to_substrait_rex(state, high, schema, col_ref_offset, extensions)?; let l_expr = make_binary_op_scalar_func( &substrait_expr, @@ -1083,11 +1089,11 @@ pub fn to_substrait_rex( } else { // `expr BETWEEN low AND high` can be translated into (low <= expr AND expr <= high) let substrait_expr = - to_substrait_rex(ctx, expr, schema, col_ref_offset, extensions)?; + to_substrait_rex(state, expr, schema, col_ref_offset, extensions)?; let substrait_low = - to_substrait_rex(ctx, low, schema, col_ref_offset, extensions)?; + to_substrait_rex(state, low, schema, col_ref_offset, extensions)?; let substrait_high = - to_substrait_rex(ctx, high, schema, col_ref_offset, extensions)?; + to_substrait_rex(state, high, schema, col_ref_offset, extensions)?; let l_expr = make_binary_op_scalar_func( &substrait_low, @@ -1115,8 +1121,8 @@ pub fn to_substrait_rex( substrait_field_ref(index + col_ref_offset) } Expr::BinaryExpr(BinaryExpr { left, op, right }) => { - let l = to_substrait_rex(ctx, left, schema, col_ref_offset, extensions)?; - let r = to_substrait_rex(ctx, right, schema, col_ref_offset, extensions)?; + let l = to_substrait_rex(state, left, schema, col_ref_offset, extensions)?; + let r = to_substrait_rex(state, right, schema, col_ref_offset, extensions)?; Ok(make_binary_op_scalar_func(&l, &r, *op, extensions)) } @@ -1131,7 +1137,7 @@ pub fn to_substrait_rex( // Base expression exists ifs.push(IfClause { r#if: Some(to_substrait_rex( - ctx, + state, e, schema, col_ref_offset, @@ -1144,14 +1150,14 @@ pub fn to_substrait_rex( for (r#if, then) in when_then_expr { ifs.push(IfClause { r#if: Some(to_substrait_rex( - ctx, + state, r#if, schema, col_ref_offset, extensions, )?), then: Some(to_substrait_rex( - ctx, + state, then, schema, col_ref_offset, @@ -1163,7 +1169,7 @@ pub fn to_substrait_rex( // Parse outer `else` let r#else: Option> = match else_expr { Some(e) => Some(Box::new(to_substrait_rex( - ctx, + state, e, schema, col_ref_offset, @@ -1182,7 +1188,7 @@ pub fn to_substrait_rex( substrait::proto::expression::Cast { r#type: Some(to_substrait_type(data_type, true)?), input: Some(Box::new(to_substrait_rex( - ctx, + state, expr, schema, col_ref_offset, @@ -1195,7 +1201,7 @@ pub fn to_substrait_rex( } Expr::Literal(value) => to_substrait_literal_expr(value, extensions), Expr::Alias(Alias { expr, .. }) => { - to_substrait_rex(ctx, expr, schema, col_ref_offset, extensions) + to_substrait_rex(state, expr, schema, col_ref_offset, extensions) } Expr::WindowFunction(WindowFunction { fun, @@ -1212,7 +1218,7 @@ pub fn to_substrait_rex( for arg in args { arguments.push(FunctionArgument { arg_type: Some(ArgType::Value(to_substrait_rex( - ctx, + state, arg, schema, col_ref_offset, @@ -1223,12 +1229,12 @@ pub fn to_substrait_rex( // partition by expressions let partition_by = partition_by .iter() - .map(|e| to_substrait_rex(ctx, e, schema, col_ref_offset, extensions)) + .map(|e| to_substrait_rex(state, e, schema, col_ref_offset, extensions)) .collect::>>()?; // order by expressions let order_by = order_by .iter() - .map(|e| substrait_sort_field(ctx, e, schema, extensions)) + .map(|e| substrait_sort_field(state, e, schema, extensions)) .collect::>>()?; // window frame let bounds = to_substrait_bounds(window_frame)?; @@ -1249,7 +1255,7 @@ pub fn to_substrait_rex( escape_char, case_insensitive, }) => make_substrait_like_expr( - ctx, + state, *case_insensitive, *negated, expr, @@ -1265,10 +1271,10 @@ pub fn to_substrait_rex( negated, }) => { let substrait_expr = - to_substrait_rex(ctx, expr, schema, col_ref_offset, extensions)?; + to_substrait_rex(state, expr, schema, col_ref_offset, extensions)?; let subquery_plan = - to_substrait_rel(subquery.subquery.as_ref(), ctx, extensions)?; + to_substrait_rel(subquery.subquery.as_ref(), state, extensions)?; let substrait_subquery = Expression { rex_type: Some(RexType::Subquery(Box::new(Subquery { @@ -1301,7 +1307,7 @@ pub fn to_substrait_rex( } } Expr::Not(arg) => to_substrait_unary_scalar_fn( - ctx, + state, "not", arg, schema, @@ -1309,7 +1315,7 @@ pub fn to_substrait_rex( extensions, ), Expr::IsNull(arg) => to_substrait_unary_scalar_fn( - ctx, + state, "is_null", arg, schema, @@ -1317,7 +1323,7 @@ pub fn to_substrait_rex( extensions, ), Expr::IsNotNull(arg) => to_substrait_unary_scalar_fn( - ctx, + state, "is_not_null", arg, schema, @@ -1325,7 +1331,7 @@ pub fn to_substrait_rex( extensions, ), Expr::IsTrue(arg) => to_substrait_unary_scalar_fn( - ctx, + state, "is_true", arg, schema, @@ -1333,7 +1339,7 @@ pub fn to_substrait_rex( extensions, ), Expr::IsFalse(arg) => to_substrait_unary_scalar_fn( - ctx, + state, "is_false", arg, schema, @@ -1341,7 +1347,7 @@ pub fn to_substrait_rex( extensions, ), Expr::IsUnknown(arg) => to_substrait_unary_scalar_fn( - ctx, + state, "is_unknown", arg, schema, @@ -1349,7 +1355,7 @@ pub fn to_substrait_rex( extensions, ), Expr::IsNotTrue(arg) => to_substrait_unary_scalar_fn( - ctx, + state, "is_not_true", arg, schema, @@ -1357,7 +1363,7 @@ pub fn to_substrait_rex( extensions, ), Expr::IsNotFalse(arg) => to_substrait_unary_scalar_fn( - ctx, + state, "is_not_false", arg, schema, @@ -1365,7 +1371,7 @@ pub fn to_substrait_rex( extensions, ), Expr::IsNotUnknown(arg) => to_substrait_unary_scalar_fn( - ctx, + state, "is_not_unknown", arg, schema, @@ -1373,7 +1379,7 @@ pub fn to_substrait_rex( extensions, ), Expr::Negative(arg) => to_substrait_unary_scalar_fn( - ctx, + state, "negate", arg, schema, @@ -1674,7 +1680,7 @@ fn make_substrait_window_function( #[allow(deprecated)] #[allow(clippy::too_many_arguments)] fn make_substrait_like_expr( - ctx: &SessionContext, + state: &dyn SubstraitPlanningState, ignore_case: bool, negated: bool, expr: &Expr, @@ -1689,8 +1695,8 @@ fn make_substrait_like_expr( } else { extensions.register_function("like".to_string()) }; - let expr = to_substrait_rex(ctx, expr, schema, col_ref_offset, extensions)?; - let pattern = to_substrait_rex(ctx, pattern, schema, col_ref_offset, extensions)?; + let expr = to_substrait_rex(state, expr, schema, col_ref_offset, extensions)?; + let pattern = to_substrait_rex(state, pattern, schema, col_ref_offset, extensions)?; let escape_char = to_substrait_literal_expr( &ScalarValue::Utf8(escape_char.map(|c| c.to_string())), extensions, @@ -2088,7 +2094,7 @@ fn to_substrait_literal_expr( /// Util to generate substrait [RexType::ScalarFunction] with one argument fn to_substrait_unary_scalar_fn( - ctx: &SessionContext, + state: &dyn SubstraitPlanningState, fn_name: &str, arg: &Expr, schema: &DFSchemaRef, @@ -2096,7 +2102,8 @@ fn to_substrait_unary_scalar_fn( extensions: &mut Extensions, ) -> Result { let function_anchor = extensions.register_function(fn_name.to_string()); - let substrait_expr = to_substrait_rex(ctx, arg, schema, col_ref_offset, extensions)?; + let substrait_expr = + to_substrait_rex(state, arg, schema, col_ref_offset, extensions)?; Ok(Expression { rex_type: Some(RexType::ScalarFunction(ScalarFunction { @@ -2137,7 +2144,7 @@ fn try_to_substrait_field_reference( } fn substrait_sort_field( - ctx: &SessionContext, + state: &dyn SubstraitPlanningState, sort: &Sort, schema: &DFSchemaRef, extensions: &mut Extensions, @@ -2147,7 +2154,7 @@ fn substrait_sort_field( asc, nulls_first, } = sort; - let e = to_substrait_rex(ctx, expr, schema, 0, extensions)?; + let e = to_substrait_rex(state, expr, schema, 0, extensions)?; let d = match (asc, nulls_first) { (true, true) => SortDirection::AscNullsFirst, (true, false) => SortDirection::AscNullsLast, @@ -2190,6 +2197,7 @@ mod test { use datafusion::arrow::datatypes::{Field, Fields, Schema}; use datafusion::common::scalar::ScalarStructBuilder; use datafusion::common::DFSchema; + use datafusion::execution::SessionStateBuilder; #[test] fn round_trip_literals() -> Result<()> { @@ -2433,15 +2441,15 @@ mod test { #[tokio::test] async fn extended_expressions() -> Result<()> { - let ctx = SessionContext::new(); + let state = SessionStateBuilder::default().build(); // One expression, empty input schema let expr = Expr::Literal(ScalarValue::Int32(Some(42))); let field = Field::new("out", DataType::Int32, false); let empty_schema = DFSchemaRef::new(DFSchema::empty()); let substrait = - to_substrait_extended_expr(&[(&expr, &field)], &empty_schema, &ctx)?; - let roundtrip_expr = from_substrait_extended_expr(&ctx, &substrait).await?; + to_substrait_extended_expr(&[(&expr, &field)], &empty_schema, &state)?; + let roundtrip_expr = from_substrait_extended_expr(&state, &substrait).await?; assert_eq!(roundtrip_expr.input_schema, empty_schema); assert_eq!(roundtrip_expr.exprs.len(), 1); @@ -2463,9 +2471,9 @@ mod test { let substrait = to_substrait_extended_expr( &[(&expr1, &out1), (&expr2, &out2)], &input_schema, - &ctx, + &state, )?; - let roundtrip_expr = from_substrait_extended_expr(&ctx, &substrait).await?; + let roundtrip_expr = from_substrait_extended_expr(&state, &substrait).await?; assert_eq!(roundtrip_expr.input_schema, input_schema); assert_eq!(roundtrip_expr.exprs.len(), 2); @@ -2485,14 +2493,14 @@ mod test { #[tokio::test] async fn invalid_extended_expression() { - let ctx = SessionContext::new(); + let state = SessionStateBuilder::default().build(); // Not ok if input schema is missing field referenced by expr let expr = Expr::Column("missing".into()); let field = Field::new("out", DataType::Int32, false); let empty_schema = DFSchemaRef::new(DFSchema::empty()); - let err = to_substrait_extended_expr(&[(&expr, &field)], &empty_schema, &ctx); + let err = to_substrait_extended_expr(&[(&expr, &field)], &empty_schema, &state); assert!(matches!(err, Err(DataFusionError::SchemaError(_, _)))); } diff --git a/datafusion/substrait/src/logical_plan/state.rs b/datafusion/substrait/src/logical_plan/state.rs new file mode 100644 index 000000000000..0bd749c1105d --- /dev/null +++ b/datafusion/substrait/src/logical_plan/state.rs @@ -0,0 +1,63 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; + +use async_trait::async_trait; +use datafusion::{ + catalog::TableProvider, + error::{DataFusionError, Result}, + execution::{registry::SerializerRegistry, FunctionRegistry, SessionState}, + sql::TableReference, +}; + +/// This trait provides the context needed to transform a substrait plan into a +/// [`datafusion::logical_expr::LogicalPlan`] (via [`super::consumer::from_substrait_plan`]) +/// and back again into a substrait plan (via [`super::producer::to_substrait_plan`]). +/// +/// The context is declared as a trait to decouple the substrait plan encoder / +/// decoder from the [`SessionState`], potentially allowing users to define +/// their own slimmer context just for serializing and deserializing substrait. +/// +/// [`SessionState`] implements this trait. +#[async_trait] +pub trait SubstraitPlanningState: Sync + Send + FunctionRegistry { + /// Return [SerializerRegistry] for extensions + fn serializer_registry(&self) -> &Arc; + + async fn table( + &self, + reference: &TableReference, + ) -> Result>>; +} + +#[async_trait] +impl SubstraitPlanningState for SessionState { + fn serializer_registry(&self) -> &Arc { + self.serializer_registry() + } + + async fn table( + &self, + reference: &TableReference, + ) -> Result>, DataFusionError> { + let table = reference.table().to_string(); + let schema = self.schema_for_ref(reference.clone())?; + let table_provider = schema.table(&table).await?; + Ok(table_provider) + } +} diff --git a/datafusion/substrait/src/serializer.rs b/datafusion/substrait/src/serializer.rs index 6b81e33dfc37..4278671777fd 100644 --- a/datafusion/substrait/src/serializer.rs +++ b/datafusion/substrait/src/serializer.rs @@ -38,7 +38,7 @@ pub async fn serialize(sql: &str, ctx: &SessionContext, path: &str) -> Result<() pub async fn serialize_bytes(sql: &str, ctx: &SessionContext) -> Result> { let df = ctx.sql(sql).await?; let plan = df.into_optimized_plan()?; - let proto = producer::to_substrait_plan(&plan, ctx)?; + let proto = producer::to_substrait_plan(&plan, &ctx.state())?; let mut protobuf_out = Vec::::new(); proto.encode(&mut protobuf_out).map_err(|e| { diff --git a/datafusion/substrait/tests/cases/consumer_integration.rs b/datafusion/substrait/tests/cases/consumer_integration.rs index bc38ef82977f..219f656bb471 100644 --- a/datafusion/substrait/tests/cases/consumer_integration.rs +++ b/datafusion/substrait/tests/cases/consumer_integration.rs @@ -41,7 +41,7 @@ mod tests { .expect("failed to parse json"); let ctx = add_plan_schemas_to_ctx(SessionContext::new(), &proto)?; - let plan = from_substrait_plan(&ctx, &proto).await?; + let plan = from_substrait_plan(&ctx.state(), &proto).await?; Ok(format!("{}", plan)) } diff --git a/datafusion/substrait/tests/cases/emit_kind_tests.rs b/datafusion/substrait/tests/cases/emit_kind_tests.rs index ac66177ed796..08537d0d110f 100644 --- a/datafusion/substrait/tests/cases/emit_kind_tests.rs +++ b/datafusion/substrait/tests/cases/emit_kind_tests.rs @@ -33,7 +33,7 @@ mod tests { "tests/testdata/test_plans/emit_kind/direct_on_project.substrait.json", ); let ctx = add_plan_schemas_to_ctx(SessionContext::new(), &proto_plan)?; - let plan = from_substrait_plan(&ctx, &proto_plan).await?; + let plan = from_substrait_plan(&ctx.state(), &proto_plan).await?; let plan_str = format!("{}", plan); @@ -51,7 +51,7 @@ mod tests { "tests/testdata/test_plans/emit_kind/emit_on_filter.substrait.json", ); let ctx = add_plan_schemas_to_ctx(SessionContext::new(), &proto_plan)?; - let plan = from_substrait_plan(&ctx, &proto_plan).await?; + let plan = from_substrait_plan(&ctx.state(), &proto_plan).await?; let plan_str = format!("{}", plan); @@ -91,8 +91,8 @@ mod tests { \n TableScan: data" ); - let proto = to_substrait_plan(&plan, &ctx)?; - let plan2 = from_substrait_plan(&ctx, &proto).await?; + let proto = to_substrait_plan(&plan, &ctx.state())?; + let plan2 = from_substrait_plan(&ctx.state(), &proto).await?; // note how the Projections are not flattened assert_eq!( format!("{}", plan2), @@ -115,8 +115,8 @@ mod tests { \n TableScan: data" ); - let proto = to_substrait_plan(&plan, &ctx)?; - let plan2 = from_substrait_plan(&ctx, &proto).await?; + let proto = to_substrait_plan(&plan, &ctx.state())?; + let plan2 = from_substrait_plan(&ctx.state(), &proto).await?; let plan1str = format!("{plan}"); let plan2str = format!("{plan2}"); diff --git a/datafusion/substrait/tests/cases/function_test.rs b/datafusion/substrait/tests/cases/function_test.rs index b136b0af19c2..043808456176 100644 --- a/datafusion/substrait/tests/cases/function_test.rs +++ b/datafusion/substrait/tests/cases/function_test.rs @@ -29,7 +29,7 @@ mod tests { async fn contains_function_test() -> Result<()> { let proto_plan = read_json("tests/testdata/contains_plan.substrait.json"); let ctx = add_plan_schemas_to_ctx(SessionContext::new(), &proto_plan)?; - let plan = from_substrait_plan(&ctx, &proto_plan).await?; + let plan = from_substrait_plan(&ctx.state(), &proto_plan).await?; let plan_str = format!("{}", plan); diff --git a/datafusion/substrait/tests/cases/logical_plans.rs b/datafusion/substrait/tests/cases/logical_plans.rs index f4e34af35d78..65f404bbda55 100644 --- a/datafusion/substrait/tests/cases/logical_plans.rs +++ b/datafusion/substrait/tests/cases/logical_plans.rs @@ -38,7 +38,7 @@ mod tests { let proto_plan = read_json("tests/testdata/test_plans/select_not_bool.substrait.json"); let ctx = add_plan_schemas_to_ctx(SessionContext::new(), &proto_plan)?; - let plan = from_substrait_plan(&ctx, &proto_plan).await?; + let plan = from_substrait_plan(&ctx.state(), &proto_plan).await?; assert_eq!( format!("{}", plan), @@ -63,7 +63,7 @@ mod tests { let proto_plan = read_json("tests/testdata/test_plans/select_window.substrait.json"); let ctx = add_plan_schemas_to_ctx(SessionContext::new(), &proto_plan)?; - let plan = from_substrait_plan(&ctx, &proto_plan).await?; + let plan = from_substrait_plan(&ctx.state(), &proto_plan).await?; assert_eq!( format!("{}", plan), @@ -82,7 +82,7 @@ mod tests { let proto_plan = read_json("tests/testdata/test_plans/non_nullable_lists.substrait.json"); let ctx = add_plan_schemas_to_ctx(SessionContext::new(), &proto_plan)?; - let plan = from_substrait_plan(&ctx, &proto_plan).await?; + let plan = from_substrait_plan(&ctx.state(), &proto_plan).await?; assert_eq!(format!("{}", &plan), "Values: (List([1, 2]))"); diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs index d4e2d48885ae..d03ab5182028 100644 --- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs @@ -979,8 +979,8 @@ async fn extension_logical_plan() -> Result<()> { }), }); - let proto = to_substrait_plan(&ext_plan, &ctx)?; - let plan2 = from_substrait_plan(&ctx, &proto).await?; + let proto = to_substrait_plan(&ext_plan, &ctx.state())?; + let plan2 = from_substrait_plan(&ctx.state(), &proto).await?; let plan1str = format!("{ext_plan}"); let plan2str = format!("{plan2}"); @@ -1081,8 +1081,8 @@ async fn roundtrip_repartition_roundrobin() -> Result<()> { partitioning_scheme: Partitioning::RoundRobinBatch(8), }); - let proto = to_substrait_plan(&plan, &ctx)?; - let plan2 = from_substrait_plan(&ctx, &proto).await?; + let proto = to_substrait_plan(&plan, &ctx.state())?; + let plan2 = from_substrait_plan(&ctx.state(), &proto).await?; let plan2 = ctx.state().optimize(&plan2)?; assert_eq!(format!("{plan}"), format!("{plan2}")); @@ -1098,8 +1098,8 @@ async fn roundtrip_repartition_hash() -> Result<()> { partitioning_scheme: Partitioning::Hash(vec![col("data.a")], 8), }); - let proto = to_substrait_plan(&plan, &ctx)?; - let plan2 = from_substrait_plan(&ctx, &proto).await?; + let proto = to_substrait_plan(&plan, &ctx.state())?; + let plan2 = from_substrait_plan(&ctx.state(), &proto).await?; let plan2 = ctx.state().optimize(&plan2)?; assert_eq!(format!("{plan}"), format!("{plan2}")); @@ -1199,8 +1199,8 @@ async fn assert_expected_plan_unoptimized( let ctx = create_context().await?; let df = ctx.sql(sql).await?; let plan = df.into_unoptimized_plan(); - let proto = to_substrait_plan(&plan, &ctx)?; - let plan2 = from_substrait_plan(&ctx, &proto).await?; + let proto = to_substrait_plan(&plan, &ctx.state())?; + let plan2 = from_substrait_plan(&ctx.state(), &proto).await?; println!("{plan}"); println!("{plan2}"); @@ -1225,8 +1225,8 @@ async fn assert_expected_plan( let ctx = create_context().await?; let df = ctx.sql(sql).await?; let plan = df.into_optimized_plan()?; - let proto = to_substrait_plan(&plan, &ctx)?; - let plan2 = from_substrait_plan(&ctx, &proto).await?; + let proto = to_substrait_plan(&plan, &ctx.state())?; + let plan2 = from_substrait_plan(&ctx.state(), &proto).await?; let plan2 = ctx.state().optimize(&plan2)?; println!("{plan}"); @@ -1250,7 +1250,7 @@ async fn assert_expected_plan_substrait( ) -> Result<()> { let ctx = create_context().await?; - let plan = from_substrait_plan(&ctx, &substrait_plan).await?; + let plan = from_substrait_plan(&ctx.state(), &substrait_plan).await?; let plan = ctx.state().optimize(&plan)?; @@ -1265,7 +1265,7 @@ async fn assert_substrait_sql(substrait_plan: Plan, sql: &str) -> Result<()> { let expected = ctx.sql(sql).await?.into_optimized_plan()?; - let plan = from_substrait_plan(&ctx, &substrait_plan).await?; + let plan = from_substrait_plan(&ctx.state(), &substrait_plan).await?; let plan = ctx.state().optimize(&plan)?; @@ -1280,8 +1280,8 @@ async fn roundtrip_fill_na(sql: &str) -> Result<()> { let ctx = create_context().await?; let df = ctx.sql(sql).await?; let plan = df.into_optimized_plan()?; - let proto = to_substrait_plan(&plan, &ctx)?; - let plan2 = from_substrait_plan(&ctx, &proto).await?; + let proto = to_substrait_plan(&plan, &ctx.state())?; + let plan2 = from_substrait_plan(&ctx.state(), &proto).await?; let plan2 = ctx.state().optimize(&plan2)?; // Format plan string and replace all None's with 0 @@ -1301,12 +1301,12 @@ async fn test_alias(sql_with_alias: &str, sql_no_alias: &str) -> Result<()> { let ctx = create_context().await?; let df_a = ctx.sql(sql_with_alias).await?; - let proto_a = to_substrait_plan(&df_a.into_optimized_plan()?, &ctx)?; - let plan_with_alias = from_substrait_plan(&ctx, &proto_a).await?; + let proto_a = to_substrait_plan(&df_a.into_optimized_plan()?, &ctx.state())?; + let plan_with_alias = from_substrait_plan(&ctx.state(), &proto_a).await?; let df = ctx.sql(sql_no_alias).await?; - let proto = to_substrait_plan(&df.into_optimized_plan()?, &ctx)?; - let plan = from_substrait_plan(&ctx, &proto).await?; + let proto = to_substrait_plan(&df.into_optimized_plan()?, &ctx.state())?; + let plan = from_substrait_plan(&ctx.state(), &proto).await?; println!("{plan_with_alias}"); println!("{plan}"); @@ -1323,8 +1323,8 @@ async fn roundtrip_logical_plan_with_ctx( plan: LogicalPlan, ctx: SessionContext, ) -> Result> { - let proto = to_substrait_plan(&plan, &ctx)?; - let plan2 = from_substrait_plan(&ctx, &proto).await?; + let proto = to_substrait_plan(&plan, &ctx.state())?; + let plan2 = from_substrait_plan(&ctx.state(), &proto).await?; let plan2 = ctx.state().optimize(&plan2)?; println!("{plan}"); diff --git a/datafusion/substrait/tests/cases/serialize.rs b/datafusion/substrait/tests/cases/serialize.rs index 54d55d1b6f10..e28c63312788 100644 --- a/datafusion/substrait/tests/cases/serialize.rs +++ b/datafusion/substrait/tests/cases/serialize.rs @@ -45,7 +45,7 @@ mod tests { // Read substrait plan from file let proto = serializer::deserialize(path).await?; // Check plan equality - let plan = from_substrait_plan(&ctx, &proto).await?; + let plan = from_substrait_plan(&ctx.state(), &proto).await?; let plan_str_ref = format!("{plan_ref}"); let plan_str = format!("{plan}"); assert_eq!(plan_str_ref, plan_str); @@ -60,7 +60,7 @@ mod tests { let ctx = create_context().await?; let table = provider_as_source(ctx.table_provider("data").await?); let table_scan = LogicalPlanBuilder::scan("data", table, None)?.build()?; - let convert_result = to_substrait_plan(&table_scan, &ctx); + let convert_result = to_substrait_plan(&table_scan, &ctx.state()); assert!(convert_result.is_ok()); Ok(()) @@ -78,7 +78,9 @@ mod tests { \n TableScan: data projection=[a, b]", ); - let plan = to_substrait_plan(&datafusion_plan, &ctx)?.as_ref().clone(); + let plan = to_substrait_plan(&datafusion_plan, &ctx.state())? + .as_ref() + .clone(); let relation = plan.relations.first().unwrap().rel_type.as_ref(); let root_rel = match relation { @@ -121,7 +123,9 @@ mod tests { \n TableScan: data projection=[a, b, c]", ); - let plan = to_substrait_plan(&datafusion_plan, &ctx)?.as_ref().clone(); + let plan = to_substrait_plan(&datafusion_plan, &ctx.state())? + .as_ref() + .clone(); let relation = plan.relations.first().unwrap().rel_type.as_ref(); let root_rel = match relation { diff --git a/datafusion/substrait/tests/cases/substrait_validations.rs b/datafusion/substrait/tests/cases/substrait_validations.rs index 5ae586afe56f..c77bf1489f4e 100644 --- a/datafusion/substrait/tests/cases/substrait_validations.rs +++ b/datafusion/substrait/tests/cases/substrait_validations.rs @@ -65,7 +65,7 @@ mod tests { vec![("a", DataType::Int32, false), ("b", DataType::Int32, true)]; let ctx = generate_context_with_table("DATA", df_schema)?; - let plan = from_substrait_plan(&ctx, &proto_plan).await?; + let plan = from_substrait_plan(&ctx.state(), &proto_plan).await?; assert_eq!( format!("{}", plan), @@ -86,7 +86,7 @@ mod tests { ("c", DataType::Int32, false), ]; let ctx = generate_context_with_table("DATA", df_schema)?; - let plan = from_substrait_plan(&ctx, &proto_plan).await?; + let plan = from_substrait_plan(&ctx.state(), &proto_plan).await?; assert_eq!( format!("{}", plan), @@ -109,7 +109,7 @@ mod tests { ("b", DataType::Int32, false), ]; let ctx = generate_context_with_table("DATA", df_schema)?; - let plan = from_substrait_plan(&ctx, &proto_plan).await?; + let plan = from_substrait_plan(&ctx.state(), &proto_plan).await?; assert_eq!( format!("{}", plan), @@ -128,7 +128,7 @@ mod tests { vec![("a", DataType::Int32, false), ("c", DataType::Int32, true)]; let ctx = generate_context_with_table("DATA", df_schema)?; - let res = from_substrait_plan(&ctx, &proto_plan).await; + let res = from_substrait_plan(&ctx.state(), &proto_plan).await; assert!(res.is_err()); Ok(()) } @@ -140,7 +140,7 @@ mod tests { let ctx = generate_context_with_table("DATA", vec![("a", DataType::Date32, true)])?; - let res = from_substrait_plan(&ctx, &proto_plan).await; + let res = from_substrait_plan(&ctx.state(), &proto_plan).await; assert!(res.is_err()); Ok(()) }