Skip to content

Commit

Permalink
Add documentation for prepared parameters + make it eaiser to use
Browse files Browse the repository at this point in the history
  • Loading branch information
alamb committed Oct 10, 2023
1 parent b6f87ed commit f79bbe4
Show file tree
Hide file tree
Showing 6 changed files with 175 additions and 70 deletions.
37 changes: 36 additions & 1 deletion datafusion/core/src/dataframe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1218,7 +1218,42 @@ impl DataFrame {
Ok(DataFrame::new(self.session_state, project_plan))
}

/// Convert a prepare logical plan into its inner logical plan with all params replaced with their corresponding values
/// Replace all parameters in logical plan with the specified
/// values, in preparation for execution.
///
/// # Example
///
/// ```
/// use datafusion::prelude::*;
/// # use datafusion::{error::Result, assert_batches_eq};
/// # #[tokio::main]
/// # async fn main() -> Result<()> {
/// # use datafusion_common::ScalarValue;
/// let mut ctx = SessionContext::new();
/// # ctx.register_csv("example", "tests/data/example.csv", CsvReadOptions::new()).await?;
/// let results = ctx
/// .sql("SELECT a FROM example WHERE b = $1")
/// .await?
/// // replace $1 with value 2
/// .with_param_values(vec![
/// // value at index 0 --> $1
/// ScalarValue::from(2i64)
/// ])?
/// .collect()
/// .await?;
/// assert_batches_eq!(
/// &[
/// "+---+",
/// "| a |",
/// "+---+",
/// "| 1 |",
/// "+---+",
/// ],
/// &results
/// );
/// # Ok(())
/// # }
/// ```
pub fn with_param_values(self, param_values: Vec<ScalarValue>) -> Result<Self> {
let plan = self.plan.with_param_values(param_values)?;
Ok(Self::new(self.session_state, plan))
Expand Down
58 changes: 54 additions & 4 deletions datafusion/expr/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

//! Expr module contains core type definition for `Expr`.
use crate::aggregate_function;
use crate::built_in_function;
use crate::expr_fn::binary_expr;
use crate::logical_plan::Subquery;
Expand All @@ -26,8 +25,10 @@ use crate::utils::{expr_to_columns, find_out_reference_exprs};
use crate::window_frame;
use crate::window_function;
use crate::Operator;
use crate::{aggregate_function, ExprSchemable};
use arrow::datatypes::DataType;
use datafusion_common::internal_err;
use datafusion_common::tree_node::{Transformed, TreeNode};
use datafusion_common::{internal_err, DFSchema};
use datafusion_common::{plan_err, Column, DataFusionError, Result, ScalarValue};
use std::collections::HashSet;
use std::fmt;
Expand Down Expand Up @@ -599,10 +600,13 @@ impl InSubquery {
}
}

/// Placeholder
/// Placeholder, representing bind parameter values such as `$1`.
///
/// The type of these parameters is inferred using [`Expr::infer_placeholder_types`]
/// or can be specified directly using `PREPARE` statements.
#[derive(Clone, PartialEq, Eq, Hash, Debug)]
pub struct Placeholder {
/// The identifier of the parameter (e.g, $1 or $foo)
/// The identifier of the parameter, including the leading `$` (e.g, `"$1"` or `"$foo'`)
pub id: String,
/// The type the parameter will be filled in with
pub data_type: Option<DataType>,
Expand Down Expand Up @@ -1030,6 +1034,52 @@ impl Expr {
pub fn contains_outer(&self) -> bool {
!find_out_reference_exprs(self).is_empty()
}

/// Recursively find all [`Expr::Placeholder`] expressions, and
/// to infer their [`DataType`] from the context of their use.
///
/// For example, gicen an expression like `<int32> = $0` will infer `$0` to
/// have type `int32`.
pub fn infer_placeholder_types(self, schema: &DFSchema) -> Result<Expr> {
self.transform(&|mut expr| {
// Default to assuming the arguments are the same type
if let Expr::BinaryExpr(BinaryExpr { left, op: _, right }) = &mut expr {
rewrite_placeholder(left.as_mut(), right.as_ref(), schema)?;
rewrite_placeholder(right.as_mut(), left.as_ref(), schema)?;
};
if let Expr::Between(Between {
expr,
negated: _,
low,
high,
}) = &mut expr
{
rewrite_placeholder(low.as_mut(), expr.as_ref(), schema)?;
rewrite_placeholder(high.as_mut(), expr.as_ref(), schema)?;
}
Ok(Transformed::Yes(expr))
})
}
}

// modifies expr if it is a placeholder with datatype of right
fn rewrite_placeholder(expr: &mut Expr, other: &Expr, schema: &DFSchema) -> Result<()> {
if let Expr::Placeholder(Placeholder { id: _, data_type }) = expr {
if data_type.is_none() {
let other_dt = other.get_type(schema);
match other_dt {
Err(e) => {
Err(e.context(format!(
"Can not find type of {other} needed to infer type of {expr}"
)))?;
}
Ok(dt) => {
*data_type = Some(dt);
}
}
};
}
Ok(())
}

#[macro_export]
Expand Down
20 changes: 19 additions & 1 deletion datafusion/expr/src/expr_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
use crate::expr::{
AggregateFunction, BinaryExpr, Cast, Exists, GroupingSet, InList, InSubquery,
ScalarFunction, TryCast,
Placeholder, ScalarFunction, TryCast,
};
use crate::function::PartitionEvaluatorFactory;
use crate::WindowUDF;
Expand Down Expand Up @@ -80,6 +80,24 @@ pub fn ident(name: impl Into<String>) -> Expr {
Expr::Column(Column::from_name(name))
}

/// Create placeholder value that will be filled in (such as `$1`)
///
/// Note the parameter type can be inferred using [`Expr::infer_placeholder_types`]
///
/// # Example
///
/// ```rust
/// # use datafusion_expr::{placeholder};
/// let p = placeholder("$0"); // $0, refers to parameter 1
/// assert_eq!(p.to_string(), "$0")
/// ```
pub fn placeholder(id: impl Into<String>) -> Expr {
Expr::Placeholder(Placeholder {
id: id.into(),
data_type: None,
})
}

/// Return a new expression `left <op> right`
pub fn binary_expr(left: Expr, op: Operator, right: Expr) -> Expr {
Expr::BinaryExpr(BinaryExpr::new(Box::new(left), op, Box::new(right)))
Expand Down
69 changes: 51 additions & 18 deletions datafusion/expr/src/logical_plan/plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -928,8 +928,40 @@ impl LogicalPlan {
}
}
}
/// Convert a prepared [`LogicalPlan`] into its inner logical plan
/// with all params replaced with their corresponding values
/// Replaces placeholder param values (like `$1`, `$2`) in [`LogicalPlan`]
/// with the specified `param_values`.
///
/// [`LogicalPlan::Prepare`] are
/// converted to their inner logical plan for execution.
///
/// # Example
/// ```
/// # use arrow::datatypes::{Field, Schema, DataType};
/// use datafusion_common::ScalarValue;
/// # use datafusion_expr::{lit, col, LogicalPlanBuilder, logical_plan::table_scan, placeholder};
/// # let schema = Schema::new(vec![
/// # Field::new("id", DataType::Int32, false),
/// # ]);
/// // Build SELECT * FROM t1 WHRERE id = $1
/// let plan = table_scan(Some("t1"), &schema, None).unwrap()
/// .filter(col("id").eq(placeholder("$1"))).unwrap()
/// .build().unwrap();
///
/// assert_eq!("Filter: t1.id = $1\
/// \n TableScan: t1",
/// plan.display_indent().to_string()
/// );
///
/// // Fill in the parameter $1 with a literal 3
/// let plan = plan.with_param_values(vec![
/// ScalarValue::from(3i32) // value at index 0 --> $1
/// ]).unwrap();
///
/// assert_eq!("Filter: t1.id = Int32(3)\
/// \n TableScan: t1",
/// plan.display_indent().to_string()
/// );
/// ```
pub fn with_param_values(
self,
param_values: Vec<ScalarValue>,
Expand Down Expand Up @@ -961,7 +993,7 @@ impl LogicalPlan {
let input_plan = prepare_lp.input;
input_plan.replace_params_with_values(&param_values)
}
_ => Ok(self),
_ => self.replace_params_with_values(&param_values),
}
}

Expand Down Expand Up @@ -1060,7 +1092,7 @@ impl LogicalPlan {
}

impl LogicalPlan {
/// applies collect to any subqueries in the plan
/// applies `op` to any subqueries in the plan
pub(crate) fn apply_subqueries<F>(&self, op: &mut F) -> datafusion_common::Result<()>
where
F: FnMut(&Self) -> datafusion_common::Result<VisitRecursion>,
Expand Down Expand Up @@ -1112,17 +1144,22 @@ impl LogicalPlan {
Ok(())
}

/// Return a logical plan with all placeholders/params (e.g $1 $2,
/// ...) replaced with corresponding values provided in the
/// params_values
/// Return a `LogicalPlan` with all placeholders (e.g $1 $2,
/// ...) replaced with corresponding values provided in
/// `params_values`
///
/// See [`Self::with_param_values`] for examples and usage
pub fn replace_params_with_values(
&self,
param_values: &[ScalarValue],
) -> Result<LogicalPlan> {
let new_exprs = self
.expressions()
.into_iter()
.map(|e| Self::replace_placeholders_with_values(e, param_values))
.map(|e| {
let e = e.infer_placeholder_types(self.schema())?;
Self::replace_placeholders_with_values(e, param_values)
})
.collect::<Result<Vec<_>>>()?;

let new_inputs_with_values = self
Expand Down Expand Up @@ -1219,7 +1256,9 @@ impl LogicalPlan {
// Various implementations for printing out LogicalPlans
impl LogicalPlan {
/// Return a `format`able structure that produces a single line
/// per node. For example:
/// per node.
///
/// # Example
///
/// ```text
/// Projection: employee.id
Expand Down Expand Up @@ -2321,7 +2360,7 @@ pub struct Unnest {
mod tests {
use super::*;
use crate::logical_plan::table_scan;
use crate::{col, exists, in_subquery, lit};
use crate::{col, exists, in_subquery, lit, placeholder};
use arrow::datatypes::{DataType, Field, Schema};
use datafusion_common::tree_node::TreeNodeVisitor;
use datafusion_common::{not_impl_err, DFSchema, TableReference};
Expand Down Expand Up @@ -2767,10 +2806,7 @@ digraph {

let plan = table_scan(TableReference::none(), &schema, None)
.unwrap()
.filter(col("id").eq(Expr::Placeholder(Placeholder::new(
"".into(),
Some(DataType::Int32),
))))
.filter(col("id").eq(placeholder("")))
.unwrap()
.build()
.unwrap();
Expand All @@ -2783,10 +2819,7 @@ digraph {

let plan = table_scan(TableReference::none(), &schema, None)
.unwrap()
.filter(col("id").eq(Expr::Placeholder(Placeholder::new(
"$0".into(),
Some(DataType::Int32),
))))
.filter(col("id").eq(placeholder("$0")))
.unwrap()
.build()
.unwrap();
Expand Down
48 changes: 2 additions & 46 deletions datafusion/sql/src/expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,12 @@ mod value;

use crate::planner::{ContextProvider, PlannerContext, SqlToRel};
use arrow_schema::DataType;
use datafusion_common::tree_node::{Transformed, TreeNode};
use datafusion_common::{
internal_err, not_impl_err, plan_err, Column, DFSchema, DataFusionError, Result,
ScalarValue,
};
use datafusion_expr::expr::InList;
use datafusion_expr::expr::ScalarFunction;
use datafusion_expr::expr::{InList, Placeholder};
use datafusion_expr::{
col, expr, lit, AggregateFunction, Between, BinaryExpr, BuiltinScalarFunction, Cast,
Expr, ExprSchemable, GetFieldAccess, GetIndexedField, Like, Operator, TryCast,
Expand Down Expand Up @@ -122,7 +121,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
let mut expr = self.sql_expr_to_logical_expr(sql, schema, planner_context)?;
expr = self.rewrite_partial_qualifier(expr, schema);
self.validate_schema_satisfies_exprs(schema, &[expr.clone()])?;
let expr = infer_placeholder_types(expr, schema)?;
let expr = expr.infer_placeholder_types(schema)?;
Ok(expr)
}

Expand Down Expand Up @@ -712,49 +711,6 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
}
}

// modifies expr if it is a placeholder with datatype of right
fn rewrite_placeholder(expr: &mut Expr, other: &Expr, schema: &DFSchema) -> Result<()> {
if let Expr::Placeholder(Placeholder { id: _, data_type }) = expr {
if data_type.is_none() {
let other_dt = other.get_type(schema);
match other_dt {
Err(e) => {
Err(e.context(format!(
"Can not find type of {other} needed to infer type of {expr}"
)))?;
}
Ok(dt) => {
*data_type = Some(dt);
}
}
};
}
Ok(())
}

/// Find all [`Expr::Placeholder`] tokens in a logical plan, and try
/// to infer their [`DataType`] from the context of their use.
fn infer_placeholder_types(expr: Expr, schema: &DFSchema) -> Result<Expr> {
expr.transform(&|mut expr| {
// Default to assuming the arguments are the same type
if let Expr::BinaryExpr(BinaryExpr { left, op: _, right }) = &mut expr {
rewrite_placeholder(left.as_mut(), right.as_ref(), schema)?;
rewrite_placeholder(right.as_mut(), left.as_ref(), schema)?;
};
if let Expr::Between(Between {
expr,
negated: _,
low,
high,
}) = &mut expr
{
rewrite_placeholder(low.as_mut(), expr.as_ref(), schema)?;
rewrite_placeholder(high.as_mut(), expr.as_ref(), schema)?;
}
Ok(Transformed::Yes(expr))
})
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down
13 changes: 13 additions & 0 deletions datafusion/sql/tests/sql_integration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3684,6 +3684,19 @@ fn test_prepare_statement_should_infer_types() {
assert_eq!(actual_types, expected_types);
}

#[test]
fn test_non_prepare_statement_should_infer_types() {
// Non prepared statements (like SELECT) should also have their parameter types inferred
let sql = "SELECT 1 + $1";
let plan = logical_plan(sql).unwrap();
let actual_types = plan.get_parameter_types().unwrap();
let expected_types = HashMap::from([
// constant 1 is inferred to be int64
("$1".to_string(), Some(DataType::Int64)),
]);
assert_eq!(actual_types, expected_types);
}

#[test]
#[should_panic(
expected = "value: SQL(ParserError(\"Expected [NOT] NULL or TRUE|FALSE or [NOT] DISTINCT FROM after IS, found: $1\""
Expand Down

0 comments on commit f79bbe4

Please sign in to comment.