diff --git a/datafusion/core/src/dataframe.rs b/datafusion/core/src/dataframe.rs index f6490801126f..730638495811 100644 --- a/datafusion/core/src/dataframe.rs +++ b/datafusion/core/src/dataframe.rs @@ -1218,7 +1218,42 @@ impl DataFrame { Ok(DataFrame::new(self.session_state, project_plan)) } - /// Convert a prepare logical plan into its inner logical plan with all params replaced with their corresponding values + /// Replace all parameters in logical plan with the specified + /// values, in preparation for execution. + /// + /// # Example + /// + /// ``` + /// use datafusion::prelude::*; + /// # use datafusion::{error::Result, assert_batches_eq}; + /// # #[tokio::main] + /// # async fn main() -> Result<()> { + /// # use datafusion_common::ScalarValue; + /// let mut ctx = SessionContext::new(); + /// # ctx.register_csv("example", "tests/data/example.csv", CsvReadOptions::new()).await?; + /// let results = ctx + /// .sql("SELECT a FROM example WHERE b = $1") + /// .await? + /// // replace $1 with value 2 + /// .with_param_values(vec![ + /// // value at index 0 --> $1 + /// ScalarValue::from(2i64) + /// ])? + /// .collect() + /// .await?; + /// assert_batches_eq!( + /// &[ + /// "+---+", + /// "| a |", + /// "+---+", + /// "| 1 |", + /// "+---+", + /// ], + /// &results + /// ); + /// # Ok(()) + /// # } + /// ``` pub fn with_param_values(self, param_values: Vec) -> Result { let plan = self.plan.with_param_values(param_values)?; Ok(Self::new(self.session_state, plan)) diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 3e4e3068977c..f688f7371e9f 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -17,7 +17,6 @@ //! Expr module contains core type definition for `Expr`. -use crate::aggregate_function; use crate::built_in_function; use crate::expr_fn::binary_expr; use crate::logical_plan::Subquery; @@ -26,8 +25,10 @@ use crate::utils::{expr_to_columns, find_out_reference_exprs}; use crate::window_frame; use crate::window_function; use crate::Operator; +use crate::{aggregate_function, ExprSchemable}; use arrow::datatypes::DataType; -use datafusion_common::internal_err; +use datafusion_common::tree_node::{Transformed, TreeNode}; +use datafusion_common::{internal_err, DFSchema}; use datafusion_common::{plan_err, Column, DataFusionError, Result, ScalarValue}; use std::collections::HashSet; use std::fmt; @@ -599,10 +600,13 @@ impl InSubquery { } } -/// Placeholder +/// Placeholder, representing bind parameter values such as `$1`. +/// +/// The type of these parameters is inferred using [`Expr::infer_placeholder_types`] +/// or can be specified directly using `PREPARE` statements. #[derive(Clone, PartialEq, Eq, Hash, Debug)] pub struct Placeholder { - /// The identifier of the parameter (e.g, $1 or $foo) + /// The identifier of the parameter, including the leading `$` (e.g, `"$1"` or `"$foo'`) pub id: String, /// The type the parameter will be filled in with pub data_type: Option, @@ -1030,6 +1034,52 @@ impl Expr { pub fn contains_outer(&self) -> bool { !find_out_reference_exprs(self).is_empty() } + + /// Recursively find all [`Expr::Placeholder`] expressions, and + /// to infer their [`DataType`] from the context of their use. + /// + /// For example, gicen an expression like ` = $0` will infer `$0` to + /// have type `int32`. + pub fn infer_placeholder_types(self, schema: &DFSchema) -> Result { + self.transform(&|mut expr| { + // Default to assuming the arguments are the same type + if let Expr::BinaryExpr(BinaryExpr { left, op: _, right }) = &mut expr { + rewrite_placeholder(left.as_mut(), right.as_ref(), schema)?; + rewrite_placeholder(right.as_mut(), left.as_ref(), schema)?; + }; + if let Expr::Between(Between { + expr, + negated: _, + low, + high, + }) = &mut expr + { + rewrite_placeholder(low.as_mut(), expr.as_ref(), schema)?; + rewrite_placeholder(high.as_mut(), expr.as_ref(), schema)?; + } + Ok(Transformed::Yes(expr)) + }) + } +} + +// modifies expr if it is a placeholder with datatype of right +fn rewrite_placeholder(expr: &mut Expr, other: &Expr, schema: &DFSchema) -> Result<()> { + if let Expr::Placeholder(Placeholder { id: _, data_type }) = expr { + if data_type.is_none() { + let other_dt = other.get_type(schema); + match other_dt { + Err(e) => { + Err(e.context(format!( + "Can not find type of {other} needed to infer type of {expr}" + )))?; + } + Ok(dt) => { + *data_type = Some(dt); + } + } + }; + } + Ok(()) } #[macro_export] diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 711dc123a4a4..79a43c2353db 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -19,7 +19,7 @@ use crate::expr::{ AggregateFunction, BinaryExpr, Cast, Exists, GroupingSet, InList, InSubquery, - ScalarFunction, TryCast, + Placeholder, ScalarFunction, TryCast, }; use crate::function::PartitionEvaluatorFactory; use crate::WindowUDF; @@ -80,6 +80,24 @@ pub fn ident(name: impl Into) -> Expr { Expr::Column(Column::from_name(name)) } +/// Create placeholder value that will be filled in (such as `$1`) +/// +/// Note the parameter type can be inferred using [`Expr::infer_placeholder_types`] +/// +/// # Example +/// +/// ```rust +/// # use datafusion_expr::{placeholder}; +/// let p = placeholder("$0"); // $0, refers to parameter 1 +/// assert_eq!(p.to_string(), "$0") +/// ``` +pub fn placeholder(id: impl Into) -> Expr { + Expr::Placeholder(Placeholder { + id: id.into(), + data_type: None, + }) +} + /// Return a new expression `left right` pub fn binary_expr(left: Expr, op: Operator, right: Expr) -> Expr { Expr::BinaryExpr(BinaryExpr::new(Box::new(left), op, Box::new(right))) diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index dfc83f9eec76..0bb359a9e91b 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -928,8 +928,40 @@ impl LogicalPlan { } } } - /// Convert a prepared [`LogicalPlan`] into its inner logical plan - /// with all params replaced with their corresponding values + /// Replaces placeholder param values (like `$1`, `$2`) in [`LogicalPlan`] + /// with the specified `param_values`. + /// + /// [`LogicalPlan::Prepare`] are + /// converted to their inner logical plan for execution. + /// + /// # Example + /// ``` + /// # use arrow::datatypes::{Field, Schema, DataType}; + /// use datafusion_common::ScalarValue; + /// # use datafusion_expr::{lit, col, LogicalPlanBuilder, logical_plan::table_scan, placeholder}; + /// # let schema = Schema::new(vec![ + /// # Field::new("id", DataType::Int32, false), + /// # ]); + /// // Build SELECT * FROM t1 WHRERE id = $1 + /// let plan = table_scan(Some("t1"), &schema, None).unwrap() + /// .filter(col("id").eq(placeholder("$1"))).unwrap() + /// .build().unwrap(); + /// + /// assert_eq!("Filter: t1.id = $1\ + /// \n TableScan: t1", + /// plan.display_indent().to_string() + /// ); + /// + /// // Fill in the parameter $1 with a literal 3 + /// let plan = plan.with_param_values(vec![ + /// ScalarValue::from(3i32) // value at index 0 --> $1 + /// ]).unwrap(); + /// + /// assert_eq!("Filter: t1.id = Int32(3)\ + /// \n TableScan: t1", + /// plan.display_indent().to_string() + /// ); + /// ``` pub fn with_param_values( self, param_values: Vec, @@ -961,7 +993,7 @@ impl LogicalPlan { let input_plan = prepare_lp.input; input_plan.replace_params_with_values(¶m_values) } - _ => Ok(self), + _ => self.replace_params_with_values(¶m_values), } } @@ -1060,7 +1092,7 @@ impl LogicalPlan { } impl LogicalPlan { - /// applies collect to any subqueries in the plan + /// applies `op` to any subqueries in the plan pub(crate) fn apply_subqueries(&self, op: &mut F) -> datafusion_common::Result<()> where F: FnMut(&Self) -> datafusion_common::Result, @@ -1112,9 +1144,11 @@ impl LogicalPlan { Ok(()) } - /// Return a logical plan with all placeholders/params (e.g $1 $2, - /// ...) replaced with corresponding values provided in the - /// params_values + /// Return a `LogicalPlan` with all placeholders (e.g $1 $2, + /// ...) replaced with corresponding values provided in + /// `params_values` + /// + /// See [`Self::with_param_values`] for examples and usage pub fn replace_params_with_values( &self, param_values: &[ScalarValue], @@ -1122,7 +1156,10 @@ impl LogicalPlan { let new_exprs = self .expressions() .into_iter() - .map(|e| Self::replace_placeholders_with_values(e, param_values)) + .map(|e| { + let e = e.infer_placeholder_types(self.schema())?; + Self::replace_placeholders_with_values(e, param_values) + }) .collect::>>()?; let new_inputs_with_values = self @@ -1219,7 +1256,9 @@ impl LogicalPlan { // Various implementations for printing out LogicalPlans impl LogicalPlan { /// Return a `format`able structure that produces a single line - /// per node. For example: + /// per node. + /// + /// # Example /// /// ```text /// Projection: employee.id @@ -2321,7 +2360,7 @@ pub struct Unnest { mod tests { use super::*; use crate::logical_plan::table_scan; - use crate::{col, exists, in_subquery, lit}; + use crate::{col, exists, in_subquery, lit, placeholder}; use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::tree_node::TreeNodeVisitor; use datafusion_common::{not_impl_err, DFSchema, TableReference}; @@ -2767,10 +2806,7 @@ digraph { let plan = table_scan(TableReference::none(), &schema, None) .unwrap() - .filter(col("id").eq(Expr::Placeholder(Placeholder::new( - "".into(), - Some(DataType::Int32), - )))) + .filter(col("id").eq(placeholder(""))) .unwrap() .build() .unwrap(); @@ -2783,10 +2819,7 @@ digraph { let plan = table_scan(TableReference::none(), &schema, None) .unwrap() - .filter(col("id").eq(Expr::Placeholder(Placeholder::new( - "$0".into(), - Some(DataType::Int32), - )))) + .filter(col("id").eq(placeholder("$0"))) .unwrap() .build() .unwrap(); diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs index cb34b6ca36e8..a90a0f121f26 100644 --- a/datafusion/sql/src/expr/mod.rs +++ b/datafusion/sql/src/expr/mod.rs @@ -29,13 +29,12 @@ mod value; use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; use arrow_schema::DataType; -use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_common::{ internal_err, not_impl_err, plan_err, Column, DFSchema, DataFusionError, Result, ScalarValue, }; +use datafusion_expr::expr::InList; use datafusion_expr::expr::ScalarFunction; -use datafusion_expr::expr::{InList, Placeholder}; use datafusion_expr::{ col, expr, lit, AggregateFunction, Between, BinaryExpr, BuiltinScalarFunction, Cast, Expr, ExprSchemable, GetFieldAccess, GetIndexedField, Like, Operator, TryCast, @@ -122,7 +121,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let mut expr = self.sql_expr_to_logical_expr(sql, schema, planner_context)?; expr = self.rewrite_partial_qualifier(expr, schema); self.validate_schema_satisfies_exprs(schema, &[expr.clone()])?; - let expr = infer_placeholder_types(expr, schema)?; + let expr = expr.infer_placeholder_types(schema)?; Ok(expr) } @@ -712,49 +711,6 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } } -// modifies expr if it is a placeholder with datatype of right -fn rewrite_placeholder(expr: &mut Expr, other: &Expr, schema: &DFSchema) -> Result<()> { - if let Expr::Placeholder(Placeholder { id: _, data_type }) = expr { - if data_type.is_none() { - let other_dt = other.get_type(schema); - match other_dt { - Err(e) => { - Err(e.context(format!( - "Can not find type of {other} needed to infer type of {expr}" - )))?; - } - Ok(dt) => { - *data_type = Some(dt); - } - } - }; - } - Ok(()) -} - -/// Find all [`Expr::Placeholder`] tokens in a logical plan, and try -/// to infer their [`DataType`] from the context of their use. -fn infer_placeholder_types(expr: Expr, schema: &DFSchema) -> Result { - expr.transform(&|mut expr| { - // Default to assuming the arguments are the same type - if let Expr::BinaryExpr(BinaryExpr { left, op: _, right }) = &mut expr { - rewrite_placeholder(left.as_mut(), right.as_ref(), schema)?; - rewrite_placeholder(right.as_mut(), left.as_ref(), schema)?; - }; - if let Expr::Between(Between { - expr, - negated: _, - low, - high, - }) = &mut expr - { - rewrite_placeholder(low.as_mut(), expr.as_ref(), schema)?; - rewrite_placeholder(high.as_mut(), expr.as_ref(), schema)?; - } - Ok(Transformed::Yes(expr)) - }) -} - #[cfg(test)] mod tests { use super::*; diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs index b1de5a12bcd0..a39384502ff2 100644 --- a/datafusion/sql/tests/sql_integration.rs +++ b/datafusion/sql/tests/sql_integration.rs @@ -3684,6 +3684,19 @@ fn test_prepare_statement_should_infer_types() { assert_eq!(actual_types, expected_types); } +#[test] +fn test_non_prepare_statement_should_infer_types() { + // Non prepared statements (like SELECT) should also have their parameter types inferred + let sql = "SELECT 1 + $1"; + let plan = logical_plan(sql).unwrap(); + let actual_types = plan.get_parameter_types().unwrap(); + let expected_types = HashMap::from([ + // constant 1 is inferred to be int64 + ("$1".to_string(), Some(DataType::Int64)), + ]); + assert_eq!(actual_types, expected_types); +} + #[test] #[should_panic( expected = "value: SQL(ParserError(\"Expected [NOT] NULL or TRUE|FALSE or [NOT] DISTINCT FROM after IS, found: $1\""