diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index 515553152659a..d822372984362 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -114,6 +114,9 @@ use substrait::proto::{ /// This trait is used to consume Substrait plans, converting them into DataFusion Logical Plans. /// It can be implemented by users to allow for custom handling of relations, expressions, etc. /// +/// Combined with the [crate::logical_plan::producer::SubstraitProducer] this allows for fully +/// customizable Substrait serde. +/// /// # Example Usage /// /// ``` diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index 0f4a062e2b1b8..007e04040126b 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -105,11 +105,84 @@ use substrait::{ version, }; +/// This trait is used to produce Substrait plans, converting them from DataFusion Logical Plans. +/// It can be implemented by users to allow for custom handling of relations, expressions, etc. +/// +/// Combined with the [crate::logical_plan::consumer::SubstraitConsumer] this allows for fully +/// customizable Substrait serde. +/// +/// # Example Usage +/// +/// ``` +/// # use std::sync::Arc; +/// # use substrait::proto::{Expression, Rel}; +/// # use substrait::proto::rel::RelType; +/// # use datafusion::common::DFSchemaRef; +/// # use datafusion::error::Result; +/// # use datafusion::execution::SessionState; +/// # use datafusion::logical_expr::{Between, Extension, Projection}; +/// # use datafusion_substrait::extensions::Extensions; +/// # use datafusion_substrait::logical_plan::producer::{from_projection, SubstraitProducer}; +/// +/// struct CustomSubstraitProducer { +/// extensions: Extensions, +/// state: Arc, +/// } +/// +/// impl SubstraitProducer for CustomSubstraitProducer { +/// +/// fn register_function(&mut self, signature: String) -> u32 { +/// self.extensions.register_function(signature) +/// } +/// +/// fn get_extensions(self) -> Extensions { +/// self.extensions +/// } +/// +/// // You can set additional metadata on the Rels you produce +/// fn consume_projection(&mut self, plan: &Projection) -> Result> { +/// let mut rel = from_projection(self, plan)?; +/// match rel.rel_type { +/// Some(RelType::Project(mut project)) => { +/// let mut project = project.clone(); +/// // set common metadata or advanced extension +/// project.common = None; +/// project.advanced_extension = None; +/// Ok(Box::new(Rel { +/// rel_type: Some(RelType::Project(project)), +/// })) +/// } +/// rel_type => Ok(Box::new(Rel { rel_type })), +/// } +/// } +/// +/// // You can tweak how you convert expressions for your target system +/// fn consume_between(&mut self, between: &Between, schema: &DFSchemaRef) -> Result { +/// // add your own encoding for Between +/// todo!() +/// } +/// +/// // You can fully control how you convert UserDefinedLogicalNode into Substrait +/// fn consume_extension(&mut self, _plan: &Extension) -> Result> { +/// // implement your own serializer into Substrait +/// todo!() +/// } +/// } +/// ``` pub trait SubstraitProducer: Send + Sync + Sized { - fn get_extensions(self) -> Extensions; - + /// Within a Substrait plan, functions are referenced using function anchors that are stored at + /// the top level of the [Plan] within + /// [ExtensionFunction](substrait::proto::extensions::simple_extension_declaration::ExtensionFunction) + /// messages. + /// + /// When given a function signature, this method should return the existing anchor for it if + /// there is one. Otherwise, it should generate a new anchor. fn register_function(&mut self, signature: String) -> u32; + /// Consume the producer to generate the [Extensions] for the Substrait plan based on the + /// functions that have been registered + fn get_extensions(self) -> Extensions; + // Logical Plans fn consume_plan(&mut self, plan: &LogicalPlan) -> Result> { to_substrait_rel(self, plan) @@ -301,14 +374,14 @@ impl<'a> DefaultSubstraitProducer<'a> { } impl SubstraitProducer for DefaultSubstraitProducer<'_> { - fn get_extensions(self) -> Extensions { - self.extensions - } - fn register_function(&mut self, fn_name: String) -> u32 { self.extensions.register_function(fn_name) } + fn get_extensions(self) -> Extensions { + self.extensions + } + fn consume_extension(&mut self, plan: &Extension) -> Result> { let extension_bytes = self .state @@ -1164,7 +1237,7 @@ pub fn to_substrait_agg_measure( ) -> Result { match expr { Expr::AggregateFunction(agg_fn) => from_aggregate_function(producer, agg_fn, schema), - Expr::Alias(Alias{expr,..}) => { + Expr::Alias(Alias { expr, .. }) => { to_substrait_agg_measure(producer, expr, schema) } _ => internal_err!( @@ -2631,7 +2704,7 @@ mod test { ], false, ) - .into(), + .into(), false, ))?; @@ -2640,7 +2713,7 @@ mod test { Field::new("c0", DataType::Int32, true), Field::new("c1", DataType::Utf8, true), ] - .into(), + .into(), ))?; round_trip_type(DataType::Interval(IntervalUnit::YearMonth))?; diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs index 772bf2e7ad8e6..7045729493b11 100644 --- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs @@ -583,7 +583,7 @@ async fn self_join_introduces_aliases() -> Result<()> { \n TableScan: data projection=[b, c]", false, ) - .await + .await } #[tokio::test]