diff --git a/core/src/execution/datafusion/expressions/mod.rs b/core/src/execution/datafusion/expressions/mod.rs index 084fef2df..05230b4c2 100644 --- a/core/src/execution/datafusion/expressions/mod.rs +++ b/core/src/execution/datafusion/expressions/mod.rs @@ -36,5 +36,6 @@ pub mod strings; pub mod subquery; pub mod sum_decimal; pub mod temporal; +pub mod unbound; mod utils; pub mod variance; diff --git a/core/src/execution/datafusion/expressions/unbound.rs b/core/src/execution/datafusion/expressions/unbound.rs new file mode 100644 index 000000000..5387b1012 --- /dev/null +++ b/core/src/execution/datafusion/expressions/unbound.rs @@ -0,0 +1,110 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::execution::datafusion::expressions::utils::down_cast_any_ref; +use arrow_array::RecordBatch; +use arrow_schema::{DataType, Schema}; +use datafusion::physical_plan::ColumnarValue; +use datafusion_common::{internal_err, Result}; +use datafusion_physical_expr::PhysicalExpr; +use std::{ + any::Any, + hash::{Hash, Hasher}, + sync::Arc, +}; + +/// This is similar to `UnKnownColumn` in DataFusion, but it has data type. +/// This is only used when the column is not bound to a schema, for example, the +/// inputs to aggregation functions in final aggregation. In the case, we cannot +/// bind the aggregation functions to the input schema which is grouping columns +/// and aggregate buffer attributes in Spark (DataFusion has different design). +/// But when creating certain aggregation functions, we need to know its input +/// data types. As `UnKnownColumn` doesn't have data type, we implement this +/// `UnboundColumn` to carry the data type. +#[derive(Debug, Hash, PartialEq, Eq, Clone)] +pub struct UnboundColumn { + name: String, + datatype: DataType, +} + +impl UnboundColumn { + /// Create a new unbound column expression + pub fn new(name: &str, datatype: DataType) -> Self { + Self { + name: name.to_owned(), + datatype, + } + } + + /// Get the column name + pub fn name(&self) -> &str { + &self.name + } +} + +impl std::fmt::Display for UnboundColumn { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "{}, datatype: {}", self.name, self.datatype) + } +} + +impl PhysicalExpr for UnboundColumn { + /// Return a reference to Any that can be used for downcasting + fn as_any(&self) -> &dyn std::any::Any { + self + } + + /// Get the data type of this expression, given the schema of the input + fn data_type(&self, _input_schema: &Schema) -> Result { + Ok(self.datatype.clone()) + } + + /// Decide whether this expression is nullable, given the schema of the input + fn nullable(&self, _input_schema: &Schema) -> Result { + Ok(true) + } + + /// Evaluate the expression + fn evaluate(&self, _batch: &RecordBatch) -> Result { + internal_err!("UnboundColumn::evaluate() should not be called") + } + + fn children(&self) -> Vec> { + vec![] + } + + fn with_new_children( + self: Arc, + _children: Vec>, + ) -> Result> { + Ok(self) + } + + fn dyn_hash(&self, state: &mut dyn Hasher) { + let mut s = state; + self.hash(&mut s); + } +} + +impl PartialEq for UnboundColumn { + fn eq(&self, other: &dyn Any) -> bool { + down_cast_any_ref(other) + .downcast_ref::() + .map(|x| self == x) + .unwrap_or(false) + } +} diff --git a/core/src/execution/datafusion/planner.rs b/core/src/execution/datafusion/planner.rs index a5bcf5654..7af5f6838 100644 --- a/core/src/execution/datafusion/planner.rs +++ b/core/src/execution/datafusion/planner.rs @@ -33,7 +33,7 @@ use datafusion::{ expressions::{ in_list, BinaryExpr, BitAnd, BitOr, BitXor, CaseExpr, CastExpr, Column, Count, FirstValue, InListExpr, IsNotNullExpr, IsNullExpr, LastValue, - Literal as DataFusionLiteral, Max, Min, NotExpr, Sum, UnKnownColumn, + Literal as DataFusionLiteral, Max, Min, NotExpr, Sum, }, AggregateExpr, PhysicalExpr, PhysicalSortExpr, ScalarFunctionExpr, }, @@ -78,6 +78,7 @@ use crate::{ subquery::Subquery, sum_decimal::SumDecimal, temporal::{DateTruncExec, HourExec, MinuteExec, SecondExec, TimestampTruncExec}, + unbound::UnboundColumn, variance::Variance, NormalizeNaNAndZero, }, @@ -239,7 +240,13 @@ impl PhysicalPlanner { let field = input_schema.field(idx); Ok(Arc::new(Column::new(field.name().as_str(), idx))) } - ExprStruct::Unbound(unbound) => Ok(Arc::new(UnKnownColumn::new(unbound.name.as_str()))), + ExprStruct::Unbound(unbound) => { + let data_type = to_arrow_datatype(unbound.datatype.as_ref().unwrap()); + Ok(Arc::new(UnboundColumn::new( + unbound.name.as_str(), + data_type, + ))) + } ExprStruct::IsNotNull(is_notnull) => { let child = self.create_expr(is_notnull.child.as_ref().unwrap(), input_schema)?; Ok(Arc::new(IsNotNullExpr::new(child)))