Simplify windows builtin functions return type (apache#8920)
* Simplify windows builtin functions

* add field comments
comphead authored Jan 22, 2024
1 parent c0a69a7 commit 2b218be
Showing 12 changed files with 186 additions and 63 deletions.
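
At its core, the change follows one pattern across physical_planner.rs, the fuzz test, and every built-in window expression (row_number, rank, dense_rank, percent_rank, ntile, cume_dist, lead/lag, nth_value): the output DataType is no longer hardcoded inside each `field()` implementation, but is passed into the constructor after being derived from the planner's extended input schema. The sketch below is a minimal, self-contained illustration of that pattern; it uses stand-in `DataType`/`Field`/`BuiltInWindowFunctionExpr` types rather than the real arrow-schema and DataFusion ones, so the signatures are approximations of the diff, not the actual API.

```rust
// Stand-in types for illustration; the real code uses arrow_schema::{DataType, Field}
// and DataFusion's `BuiltInWindowFunctionExpr` trait.
#[derive(Clone, Debug)]
enum DataType {
    UInt64,
}

#[derive(Debug)]
struct Field {
    name: String,
    data_type: DataType,
    nullable: bool,
}

trait BuiltInWindowFunctionExpr {
    fn name(&self) -> &str;
    fn field(&self) -> Field;
}

// After this commit, the output type is stored at construction time
// (the planner derives it from the extended input schema) instead of
// being hardcoded inside `field()`.
struct RowNumber {
    name: String,
    data_type: DataType,
}

impl RowNumber {
    fn new(name: impl Into<String>, data_type: &DataType) -> Self {
        Self {
            name: name.into(),
            data_type: data_type.clone(),
        }
    }
}

impl BuiltInWindowFunctionExpr for RowNumber {
    fn name(&self) -> &str {
        &self.name
    }

    fn field(&self) -> Field {
        // Previously a hardcoded `let data_type = DataType::UInt64;` lived here.
        Field {
            name: self.name().to_string(),
            data_type: self.data_type.clone(),
            nullable: false,
        }
    }
}

fn main() {
    let row_number = RowNumber::new("row_number", &DataType::UInt64);
    println!("{:?}", row_number.field());
}
```

Keeping the type on the struct makes `field()` trivial and leaves the planner as the single place where output types are decided.
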
25 changes: 11 additions & 14 deletions datafusion/core/src/physical_planner.rs
@@ -86,6 +86,7 @@ use datafusion_expr::expr::{
};
use datafusion_expr::expr_rewriter::unnormalize_cols;
use datafusion_expr::logical_plan::builder::wrap_projection_for_join_if_necessary;
use datafusion_expr::utils::exprlist_to_fields;
use datafusion_expr::{
DescribeTable, DmlStatement, RecursiveQuery, ScalarFunctionDefinition,
StringifiedPlan, WindowFrame, WindowFrameBound, WriteOp,
@@ -719,14 +720,16 @@ impl DefaultPhysicalPlanner {
}

let logical_input_schema = input.schema();
let physical_input_schema = input_exec.schema();
// Extend the schema to include the window expression fields, as built-in window functions derive their data types from the incoming schema
let mut window_fields = logical_input_schema.fields().clone();
window_fields.extend_from_slice(&exprlist_to_fields(window_expr.iter(), input)?);
let extended_schema = &DFSchema::new_with_metadata(window_fields, HashMap::new())?;
let window_expr = window_expr
.iter()
.map(|e| {
create_window_expr(
e,
logical_input_schema,
&physical_input_schema,
extended_schema,
session_state.execution_props(),
)
})
@@ -1529,7 +1532,7 @@ fn get_physical_expr_pair(
/// queries like:
/// OVER (ORDER BY a RANGES BETWEEN 3 PRECEDING AND 5 PRECEDING)
/// OVER (ORDER BY a RANGES BETWEEN INTERVAL '3 DAY' PRECEDING AND '5 DAY' PRECEDING) are rejected
pub fn is_window_valid(window_frame: &WindowFrame) -> bool {
pub fn is_window_frame_bound_valid(window_frame: &WindowFrame) -> bool {
match (&window_frame.start_bound, &window_frame.end_bound) {
(WindowFrameBound::Following(_), WindowFrameBound::Preceding(_))
| (WindowFrameBound::Following(_), WindowFrameBound::CurrentRow)
@@ -1549,10 +1552,10 @@ pub fn create_window_expr_with_name(
e: &Expr,
name: impl Into<String>,
logical_input_schema: &DFSchema,
physical_input_schema: &Schema,
execution_props: &ExecutionProps,
) -> Result<Arc<dyn WindowExpr>> {
let name = name.into();
let physical_input_schema: &Schema = &logical_input_schema.into();
match e {
Expr::WindowFunction(WindowFunction {
fun,
@@ -1575,7 +1578,8 @@
create_physical_sort_expr(e, logical_input_schema, execution_props)
})
.collect::<Result<Vec<_>>>()?;
if !is_window_valid(window_frame) {

if !is_window_frame_bound_valid(window_frame) {
return plan_err!(
"Invalid window frame: start bound ({}) cannot be larger than end bound ({})",
window_frame.start_bound, window_frame.end_bound
@@ -1601,21 +1605,14 @@
pub fn create_window_expr(
e: &Expr,
logical_input_schema: &DFSchema,
physical_input_schema: &Schema,
execution_props: &ExecutionProps,
) -> Result<Arc<dyn WindowExpr>> {
// unpack aliased logical expressions, e.g. "sum(col) over () as total"
let (name, e) = match e {
Expr::Alias(Alias { expr, name, .. }) => (name.clone(), expr.as_ref()),
_ => (e.display_name()?, e),
};
create_window_expr_with_name(
e,
name,
logical_input_schema,
physical_input_schema,
execution_props,
)
create_window_expr_with_name(e, name, logical_input_schema, execution_props)
}

type AggregateExprWithOptionalArgs = (
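
The physical_planner.rs hunk above extends the logical input schema with the window expressions' own output fields (`exprlist_to_fields`) before calling `create_window_expr`, so a built-in window function can derive its data type from the schema it is given rather than hardcoding it. A rough, self-contained sketch of that idea follows; `window_expr_fields` is a hypothetical stand-in for `exprlist_to_fields`, and a plain `Vec<Field>` stands in for `DFSchema`.

```rust
// Plain-Rust stand-ins for DFSchema/Field; `window_expr_fields` is a
// hypothetical replacement for `exprlist_to_fields(window_expr.iter(), input)`.
#[derive(Clone, Debug, PartialEq)]
enum DataType {
    Int64,
    UInt64,
}

#[derive(Clone, Debug)]
struct Field {
    name: String,
    data_type: DataType,
}

// One output field per window expression in the plan.
fn window_expr_fields() -> Vec<Field> {
    vec![Field {
        name: "ROW_NUMBER()".to_string(),
        data_type: DataType::UInt64,
    }]
}

fn main() {
    // Fields of the logical input schema.
    let mut window_fields = vec![Field {
        name: "a".to_string(),
        data_type: DataType::Int64,
    }];

    // Extend the schema with the window expressions' output fields,
    // mirroring the planner hunk above.
    window_fields.extend_from_slice(&window_expr_fields());
    let extended_schema = window_fields;

    // A built-in window expression can now derive its output type by looking
    // up its own field in the extended schema instead of hardcoding it.
    let output_type = extended_schema
        .iter()
        .find(|f| f.name == "ROW_NUMBER()")
        .map(|f| f.data_type.clone());

    assert_eq!(output_type, Some(DataType::UInt64));
    println!("{output_type:?}");
}
```
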
41 changes: 38 additions & 3 deletions datafusion/core/tests/fuzz_cases/window_fuzz.rs
@@ -22,6 +22,7 @@ use arrow::compute::{concat_batches, SortOptions};
use arrow::datatypes::SchemaRef;
use arrow::record_batch::RecordBatch;
use arrow::util::pretty::pretty_format_batches;
use arrow_schema::{Field, Schema};
use datafusion::physical_plan::memory::MemoryExec;
use datafusion::physical_plan::sorts::sort::SortExec;
use datafusion::physical_plan::windows::{
@@ -37,6 +38,7 @@ use datafusion_expr::{
};
use datafusion_physical_expr::expressions::{cast, col, lit};
use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr};
use itertools::Itertools;
use test_utils::add_empty_batches;

use hashbrown::HashMap;
@@ -482,7 +484,6 @@ async fn run_window_test(
let session_config = SessionConfig::new().with_batch_size(50);
let ctx = SessionContext::new_with_config(session_config);
let (window_fn, args, fn_name) = get_random_function(&schema, &mut rng, is_linear);

let window_frame = get_random_window_frame(&mut rng, is_linear);
let mut orderby_exprs = vec![];
for column in &orderby_columns {
@@ -532,6 +533,40 @@ async fn run_window_test(
if is_linear {
exec1 = Arc::new(SortExec::new(sort_keys.clone(), exec1)) as _;
}

// The schema needs to be enriched before calling `create_window_expr`,
// because window expression data types are derived from the schema.
// DataFusion enriches the schema in the physical planner; this test replicates that behavior manually.
// Some functions take no input arguments, so an empty vec of argument types is passed for them.
let data_types = if [
"row_number",
"rank",
"dense_rank",
"percent_rank",
"ntile",
"cume_dist",
]
.contains(&fn_name.as_str())
{
vec![]
} else {
args.iter()
.map(|e| e.clone().as_ref().data_type(&schema))
.collect::<Result<Vec<_>>>()?
};
let window_expr_return_type = window_fn.return_type(&data_types)?;
let mut window_fields = schema
.fields()
.iter()
.map(|f| f.as_ref().clone())
.collect_vec();
window_fields.extend_from_slice(&[Field::new(
&fn_name,
window_expr_return_type,
true,
)]);
let extended_schema = Arc::new(Schema::new(window_fields));

let usual_window_exec = Arc::new(
WindowAggExec::try_new(
vec![create_window_expr(
&partitionby_exprs,
&orderby_exprs,
Arc::new(window_frame.clone()),
schema.as_ref(),
&extended_schema,
)
.unwrap()],
exec1,
&partitionby_exprs,
&orderby_exprs,
Arc::new(window_frame.clone()),
schema.as_ref(),
extended_schema.as_ref(),
)
.unwrap()],
exec2,
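
The window_fuzz.rs changes above replicate the planner's schema enrichment inside the test: argument data types are collected (an empty list for the no-argument ranking functions), the window function reports its return type for those arguments, and a field of that type is appended to the schema passed to `create_window_expr`. The sketch below walks through just the type-derivation step; its `return_type` function is a simplified stand-in for `WindowFunction::return_type` in datafusion_expr, not the real implementation.

```rust
// Simplified, self-contained sketch of how the test derives the window
// function's return type before building `extended_schema`.
#[derive(Clone, Debug)]
enum DataType {
    Int32,
    UInt64,
    Float64,
}

/// Functions that take no input arguments, so their argument-type list is empty.
const NO_ARG_FUNCTIONS: &[&str] = &[
    "row_number",
    "rank",
    "dense_rank",
    "percent_rank",
    "ntile",
    "cume_dist",
];

/// Stand-in for `window_fn.return_type(&data_types)`.
fn return_type(fn_name: &str, arg_types: &[DataType]) -> DataType {
    match fn_name {
        "row_number" | "rank" | "dense_rank" | "ntile" => DataType::UInt64,
        "percent_rank" | "cume_dist" => DataType::Float64,
        // lead/lag and first/last/nth_value return their argument's type.
        _ => arg_types.first().cloned().unwrap_or(DataType::Int32),
    }
}

fn main() {
    for fn_name in ["row_number", "cume_dist", "lead"] {
        // No-arg functions get an empty type list, mirroring the test above.
        let arg_types: Vec<DataType> = if NO_ARG_FUNCTIONS.contains(&fn_name) {
            vec![]
        } else {
            vec![DataType::Int32] // argument types taken from the input schema
        };
        let out = return_type(fn_name, &arg_types);
        println!("{fn_name} -> {out:?}");
    }
}
```
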
4 changes: 2 additions & 2 deletions datafusion/expr/src/built_in_window_function.rs
@@ -133,11 +133,11 @@ impl BuiltInWindowFunction {
match self {
BuiltInWindowFunction::RowNumber
| BuiltInWindowFunction::Rank
| BuiltInWindowFunction::DenseRank => Ok(DataType::UInt64),
| BuiltInWindowFunction::DenseRank
| BuiltInWindowFunction::Ntile => Ok(DataType::UInt64),
BuiltInWindowFunction::PercentRank | BuiltInWindowFunction::CumeDist => {
Ok(DataType::Float64)
}
BuiltInWindowFunction::Ntile => Ok(DataType::UInt64),
BuiltInWindowFunction::Lag
| BuiltInWindowFunction::Lead
| BuiltInWindowFunction::FirstValue
14 changes: 9 additions & 5 deletions datafusion/physical-expr/src/window/cume_dist.rs
@@ -34,11 +34,16 @@ use std::sync::Arc;
#[derive(Debug)]
pub struct CumeDist {
name: String,
/// Output data type
data_type: DataType,
}

/// Create a cume_dist window function
pub fn cume_dist(name: String) -> CumeDist {
CumeDist { name }
pub fn cume_dist(name: String, data_type: &DataType) -> CumeDist {
CumeDist {
name,
data_type: data_type.clone(),
}
}

impl BuiltInWindowFunctionExpr for CumeDist {
@@ -49,8 +54,7 @@

fn field(&self) -> Result<Field> {
let nullable = false;
let data_type = DataType::Float64;
Ok(Field::new(self.name(), data_type, nullable))
Ok(Field::new(self.name(), self.data_type.clone(), nullable))
}

fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
@@ -119,7 +123,7 @@ mod tests {
#[test]
#[allow(clippy::single_range_in_vec_init)]
fn test_cume_dist() -> Result<()> {
let r = cume_dist("arr".into());
let r = cume_dist("arr".into(), &DataType::Float64);

let expected = vec![0.0; 0];
test_i32_result(&r, 0, vec![], expected)?;
1 change: 1 addition & 0 deletions datafusion/physical-expr/src/window/lead_lag.rs
@@ -35,6 +35,7 @@ use std::sync::Arc;
#[derive(Debug)]
pub struct WindowShift {
name: String,
/// Output data type
data_type: DataType,
shift_offset: i64,
expr: Arc<dyn PhysicalExpr>,
1 change: 1 addition & 0 deletions datafusion/physical-expr/src/window/nth_value.rs
@@ -39,6 +39,7 @@ use datafusion_expr::PartitionEvaluator;
pub struct NthValue {
name: String,
expr: Arc<dyn PhysicalExpr>,
/// Output data type
data_type: DataType,
kind: NthValueKind,
}
13 changes: 9 additions & 4 deletions datafusion/physical-expr/src/window/ntile.rs
@@ -35,11 +35,17 @@ use std::sync::Arc;
pub struct Ntile {
name: String,
n: u64,
/// Output data type
data_type: DataType,
}

impl Ntile {
pub fn new(name: String, n: u64) -> Self {
Self { name, n }
pub fn new(name: String, n: u64, data_type: &DataType) -> Self {
Self {
name,
n,
data_type: data_type.clone(),
}
}

pub fn get_n(&self) -> u64 {
@@ -54,8 +60,7 @@ impl BuiltInWindowFunctionExpr for Ntile {

fn field(&self) -> Result<Field> {
let nullable = false;
let data_type = DataType::UInt64;
Ok(Field::new(self.name(), data_type, nullable))
Ok(Field::new(self.name(), self.data_type.clone(), nullable))
}

fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
23 changes: 12 additions & 11 deletions datafusion/physical-expr/src/window/rank.rs
@@ -41,6 +41,8 @@ use std::sync::Arc;
pub struct Rank {
name: String,
rank_type: RankType,
/// Output data type
data_type: DataType,
}

impl Rank {
@@ -58,26 +60,29 @@ pub enum RankType {
}

/// Create a rank window function
pub fn rank(name: String) -> Rank {
pub fn rank(name: String, data_type: &DataType) -> Rank {
Rank {
name,
rank_type: RankType::Basic,
data_type: data_type.clone(),
}
}

/// Create a dense rank window function
pub fn dense_rank(name: String) -> Rank {
pub fn dense_rank(name: String, data_type: &DataType) -> Rank {
Rank {
name,
rank_type: RankType::Dense,
data_type: data_type.clone(),
}
}

/// Create a percent rank window function
pub fn percent_rank(name: String) -> Rank {
pub fn percent_rank(name: String, data_type: &DataType) -> Rank {
Rank {
name,
rank_type: RankType::Percent,
data_type: data_type.clone(),
}
}

@@ -89,11 +94,7 @@ impl BuiltInWindowFunctionExpr for Rank {

fn field(&self) -> Result<Field> {
let nullable = false;
let data_type = match self.rank_type {
RankType::Basic | RankType::Dense => DataType::UInt64,
RankType::Percent => DataType::Float64,
};
Ok(Field::new(self.name(), data_type, nullable))
Ok(Field::new(self.name(), self.data_type.clone(), nullable))
}

fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
@@ -268,15 +269,15 @@ mod tests {

#[test]
fn test_dense_rank() -> Result<()> {
let r = dense_rank("arr".into());
let r = dense_rank("arr".into(), &DataType::UInt64);
test_without_rank(&r, vec![1; 8])?;
test_with_rank(&r, vec![1, 1, 2, 3, 3, 3, 4, 5])?;
Ok(())
}

#[test]
fn test_rank() -> Result<()> {
let r = rank("arr".into());
let r = rank("arr".into(), &DataType::UInt64);
test_without_rank(&r, vec![1; 8])?;
test_with_rank(&r, vec![1, 1, 3, 4, 4, 4, 7, 8])?;
Ok(())
Expand All @@ -285,7 +286,7 @@ mod tests {
#[test]
#[allow(clippy::single_range_in_vec_init)]
fn test_percent_rank() -> Result<()> {
let r = percent_rank("arr".into());
let r = percent_rank("arr".into(), &DataType::Float64);

// empty case
let expected = vec![0.0; 0];
16 changes: 10 additions & 6 deletions datafusion/physical-expr/src/window/row_number.rs
@@ -36,12 +36,17 @@ use std::sync::Arc;
#[derive(Debug)]
pub struct RowNumber {
name: String,
/// Output data type
data_type: DataType,
}

impl RowNumber {
/// Create a new ROW_NUMBER function
pub fn new(name: impl Into<String>) -> Self {
Self { name: name.into() }
pub fn new(name: impl Into<String>, data_type: &DataType) -> Self {
Self {
name: name.into(),
data_type: data_type.clone(),
}
}
}

@@ -53,8 +58,7 @@ impl BuiltInWindowFunctionExpr for RowNumber {

fn field(&self) -> Result<Field> {
let nullable = false;
let data_type = DataType::UInt64;
Ok(Field::new(self.name(), data_type, nullable))
Ok(Field::new(self.name(), self.data_type.clone(), nullable))
}

fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
@@ -127,7 +131,7 @@ mod tests {
]));
let schema = Schema::new(vec![Field::new("arr", DataType::Boolean, true)]);
let batch = RecordBatch::try_new(Arc::new(schema), vec![arr])?;
let row_number = RowNumber::new("row_number".to_owned());
let row_number = RowNumber::new("row_number".to_owned(), &DataType::UInt64);
let values = row_number.evaluate_args(&batch)?;
let result = row_number
.create_evaluator()?
Expand All @@ -145,7 +149,7 @@ mod tests {
]));
let schema = Schema::new(vec![Field::new("arr", DataType::Boolean, false)]);
let batch = RecordBatch::try_new(Arc::new(schema), vec![arr])?;
let row_number = RowNumber::new("row_number".to_owned());
let row_number = RowNumber::new("row_number".to_owned(), &DataType::UInt64);
let values = row_number.evaluate_args(&batch)?;
let result = row_number
.create_evaluator()?