Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[GLUTEN-6387][CH] support percentile function #6396

Merged
merged 6 commits into from
Oct 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,15 @@ case class CHHashAggregateExecTransformer(
approxPercentile.percentageExpression.dataType,
approxPercentile.percentageExpression.nullable)
(makeStructType(fields), attr.nullable)
case percentile: Percentile =>
var fields = Seq[(DataType, Boolean)]()
// Use percentile.nullable as the nullable of the struct type
// to make sure it returns null when input is empty
fields = fields :+ (percentile.child.dataType, percentile.nullable)
fields = fields :+ (
percentile.percentageExpression.dataType,
percentile.percentageExpression.nullable)
(makeStructType(fields), attr.nullable)
case _ =>
(makeStructTypeSingleOne(attr.dataType, attr.nullable), attr.nullable)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,6 @@ case class FormatStringValidator() extends FunctionValidator {
}

object CHExpressionUtil {

final val CH_AGGREGATE_FUNC_BLACKLIST: Map[String, FunctionValidator] = Map(
MAX_BY -> DefaultValidator(),
MIN_BY -> DefaultValidator()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2599,6 +2599,19 @@ class GlutenClickHouseTPCHSaltNullParquetSuite extends GlutenClickHouseTPCHAbstr
runQueryAndCompare(sql2)({ _ => })
}

test("aggregate function percentile") {
  // Both queries group lineitem rows by l_linenumber % 10; only the percentile call differs.
  val percentileQueries = Seq(
    // A single percentage literal.
    "select l_linenumber % 10, percentile(l_extendedprice, 0.5) " +
      "from lineitem group by l_linenumber % 10",
    // An array of percentages.
    "select l_linenumber % 10, percentile(l_extendedprice, array(0.1, 0.2, 0.3)) " +
      "from lineitem group by l_linenumber % 10"
  )

  // Run the queries in order, passing a callback that does nothing with its argument.
  percentileQueries.foreach(query => runQueryAndCompare(query)({ _ => }))
}

test("GLUTEN-5096: Bug fix regexp_extract diff") {
val tbl_create_sql = "create table test_tbl_5096(id bigint, data string) using parquet"
val tbl_insert_sql = "insert into test_tbl_5096 values(1, 'abc'), (2, 'abc\n')"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ import org.apache.gluten.utils._

import org.apache.spark.sql.catalyst.catalog.BucketSpec
import org.apache.spark.sql.catalyst.expressions.{Alias, CumeDist, DenseRank, Descending, Expression, Lag, Lead, NamedExpression, NthValue, NTile, PercentRank, RangeFrame, Rank, RowNumber, SortOrder, SpecialFrameBoundary, SpecifiedWindowFrame}
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, ApproximatePercentile}
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, ApproximatePercentile, Percentile}
import org.apache.spark.sql.catalyst.plans.{JoinType, LeftOuter, RightOuter}
import org.apache.spark.sql.catalyst.util.CharVarcharUtils
import org.apache.spark.sql.execution.{ColumnarCachedBatchSerializer, SparkPlan}
Expand Down Expand Up @@ -371,7 +371,8 @@ object VeloxBackendSettings extends BackendSettingsApi {
case _: RowNumber | _: Rank | _: CumeDist | _: DenseRank | _: PercentRank |
_: NthValue | _: NTile | _: Lag | _: Lead =>
case aggrExpr: AggregateExpression
if !aggrExpr.aggregateFunction.isInstanceOf[ApproximatePercentile] =>
if !aggrExpr.aggregateFunction.isInstanceOf[ApproximatePercentile]
&& !aggrExpr.aggregateFunction.isInstanceOf[Percentile] =>
case _ =>
allSupported = false
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <AggregateFunctions/Combinators/AggregateFunctionCombinatorFactory.h>
#include <AggregateFunctions/AggregateFunctionPartialMerge.h>
#include <AggregateFunctions/Combinators/AggregateFunctionCombinatorFactory.h>
#include <DataTypes/DataTypeAggregateFunction.h>


Expand All @@ -25,88 +25,86 @@ namespace DB
{
namespace ErrorCodes
{
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
}
}

namespace local_engine
{
namespace
{
class AggregateFunctionCombinatorPartialMerge final : public IAggregateFunctionCombinator
{
public:
String getName() const override { return "PartialMerge"; }

DataTypes transformArguments(const DataTypes & arguments) const override
{
if (arguments.size() != 1)
throw Exception(
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH,
"Incorrect number of arguments for aggregate function with {} suffix",
getName());

const DataTypePtr & argument = arguments[0];
class AggregateFunctionCombinatorPartialMerge final : public IAggregateFunctionCombinator
{
public:
String getName() const override { return "PartialMerge"; }

const DataTypeAggregateFunction * function = typeid_cast<const DataTypeAggregateFunction *>(argument.get());
if (!function)
throw Exception(
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
"Illegal type {} of argument for aggregate function with {} suffix must be AggregateFunction(...)",
argument->getName(),
getName());
DataTypes transformArguments(const DataTypes & arguments) const override
{
if (arguments.size() != 1)
throw Exception(
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH,
"Incorrect number of arguments for aggregate function with {} suffix",
getName());

const DataTypePtr & argument = arguments[0];

const DataTypeAggregateFunction * function = typeid_cast<const DataTypeAggregateFunction *>(argument.get());
if (!function)
throw Exception(
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
"Illegal type {} of argument for aggregate function with {} suffix must be AggregateFunction(...)",
argument->getName(),
getName());

const DataTypeAggregateFunction * function2
= typeid_cast<const DataTypeAggregateFunction *>(function->getArgumentsDataTypes()[0].get());
if (function2)
return transformArguments(function->getArgumentsDataTypes());
return function->getArgumentsDataTypes();
}

AggregateFunctionPtr transformAggregateFunction(
const AggregateFunctionPtr & nested_function,
const AggregateFunctionProperties &,
const DataTypes & arguments,
const Array & params) const override
{
DataTypePtr & argument = const_cast<DataTypePtr &>(arguments[0]);

const DataTypeAggregateFunction * function2
= typeid_cast<const DataTypeAggregateFunction *>(function->getArgumentsDataTypes()[0].get());
if (function2)
{
return transformArguments(function->getArgumentsDataTypes());
}
return function->getArgumentsDataTypes();
}
const DataTypeAggregateFunction * function = typeid_cast<const DataTypeAggregateFunction *>(argument.get());
if (!function)
throw Exception(
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
"Illegal type {} of argument for aggregate function with {} suffix must be AggregateFunction(...)",
argument->getName(),
getName());

AggregateFunctionPtr transformAggregateFunction(
const AggregateFunctionPtr & nested_function,
const AggregateFunctionProperties &,
const DataTypes & arguments,
const Array & params) const override
while (nested_function->getName() != function->getFunctionName())
{
DataTypePtr & argument = const_cast<DataTypePtr &>(arguments[0]);

const DataTypeAggregateFunction * function = typeid_cast<const DataTypeAggregateFunction *>(argument.get());
argument = function->getArgumentsDataTypes()[0];
function = typeid_cast<const DataTypeAggregateFunction *>(function->getArgumentsDataTypes()[0].get());
if (!function)
throw Exception(
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
"Illegal type {} of argument for aggregate function with {} suffix must be AggregateFunction(...)",
argument->getName(),
getName());

while (nested_function->getName() != function->getFunctionName())
{
argument = function->getArgumentsDataTypes()[0];
function = typeid_cast<const DataTypeAggregateFunction *>(function->getArgumentsDataTypes()[0].get());
if (!function)
throw Exception(
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
"Illegal type {} of argument for aggregate function with {} suffix must be AggregateFunction(...)",
argument->getName(),
getName());
}

if (nested_function->getName() != function->getFunctionName())
throw Exception(
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
"Illegal type {} of argument for aggregate function with {} suffix, because it corresponds to different aggregate "
"function: {} instead of {}",
argument->getName(),
getName(),
function->getFunctionName(),
nested_function->getName());

return std::make_shared<AggregateFunctionPartialMerge>(nested_function, argument, params);
}
};

if (nested_function->getName() != function->getFunctionName())
throw Exception(
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
"Illegal type {} of argument for aggregate function with {} suffix, because it corresponds to different aggregate "
"function: {} instead of {}",
argument->getName(),
getName(),
function->getFunctionName(),
nested_function->getName());

return std::make_shared<AggregateFunctionPartialMerge>(nested_function, argument, params);
}
};

}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
*/
#pragma once

#include <AggregateFunctions/IAggregateFunction_fwd.h>
#include <AggregateFunctions/IAggregateFunction.h>
#include <Columns/ColumnAggregateFunction.h>
#include <DataTypes/DataTypeAggregateFunction.h>
#include <Common/assert_cast.h>
Expand All @@ -41,8 +41,6 @@ struct Settings;
* This class is copied from AggregateFunctionMerge with minor enhancements.
* We use this PartialMerge for both Spark PartialMerge and Final.
*/


class AggregateFunctionPartialMerge final : public IAggregateFunctionHelper<AggregateFunctionPartialMerge>
{
private:
Expand Down
53 changes: 26 additions & 27 deletions cpp-ch/local-engine/Parser/AggregateFunctionParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
* limitations under the License.
*/
#include "AggregateFunctionParser.h"
#include <type_traits>
#include <AggregateFunctions/AggregateFunctionFactory.h>
#include <DataTypes/DataTypeAggregateFunction.h>
#include <DataTypes/DataTypeTuple.h>
Expand Down Expand Up @@ -105,6 +104,7 @@ AggregateFunctionParser::parseFunctionArguments(const CommonFunctionInfo & func_

collected_args.push_back(arg_node);
}

if (func_info.has_filter)
{
// With `If` combinator, the function take one more argument which refers to the condition.
Expand All @@ -115,47 +115,46 @@ AggregateFunctionParser::parseFunctionArguments(const CommonFunctionInfo & func_
}

std::pair<String, DB::DataTypes> AggregateFunctionParser::tryApplyCHCombinator(
const CommonFunctionInfo & func_info, const String & ch_func_name, const DB::DataTypes & arg_column_types) const
const CommonFunctionInfo & func_info, const String & ch_func_name, const DB::DataTypes & argument_types) const
{
auto get_aggregate_function = [](const String & name, const DB::DataTypes & arg_types) -> DB::AggregateFunctionPtr
auto get_aggregate_function
= [](const String & name, const DB::DataTypes & argument_types, const DB::Array & parameters) -> DB::AggregateFunctionPtr
{
DB::AggregateFunctionProperties properties;
auto func = RelParser::getAggregateFunction(name, arg_types, properties);
auto func = RelParser::getAggregateFunction(name, argument_types, properties, parameters);
if (!func)
{
throw Exception(DB::ErrorCodes::BAD_ARGUMENTS, "Unknown aggregate function {}", name);
}

return func;
};

String combinator_function_name = ch_func_name;
DB::DataTypes combinator_arg_column_types = arg_column_types;
DB::DataTypes combinator_argument_types = argument_types;

if (func_info.phase != substrait::AggregationPhase::AGGREGATION_PHASE_INITIAL_TO_INTERMEDIATE
&& func_info.phase != substrait::AggregationPhase::AGGREGATION_PHASE_INITIAL_TO_RESULT)
{
if (arg_column_types.size() != 1)
{
if (argument_types.size() != 1)
throw Exception(DB::ErrorCodes::BAD_ARGUMENTS, "Only support one argument aggregate function in phase {}", func_info.phase);
}

// Add a check here for safety.
if (func_info.has_filter)
{
throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Unspport apply filter in phase {}", func_info.phase);
}
throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Apply filter in phase {} not supported", func_info.phase);

const auto * agg_function_data = DB::checkAndGetDataType<DB::DataTypeAggregateFunction>(arg_column_types[0].get());
if (!agg_function_data)
const auto * aggr_func_type = DB::checkAndGetDataType<DB::DataTypeAggregateFunction>(argument_types[0].get());
if (!aggr_func_type)
{
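// FIXME: This should be fixed. It is the case of count(distinct(xxx)) combined with other aggregate functions.
// Gluten breaks the rule that an intermediate result should have a special format name here.
// NOTE(review): aggr_func_type is null in this branch, so it must not be dereferenced here — confirm the
// get_aggregate_function call below does not rely on aggr_func_type->getParameters().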
LOG_INFO(logger, "Intermediate aggregate function data is expected in phase {} for {}", func_info.phase, ch_func_name);
auto arg_type = DB::removeNullable(arg_column_types[0]);

auto arg_type = DB::removeNullable(argument_types[0]);
if (auto * tupe_type = typeid_cast<const DB::DataTypeTuple *>(arg_type.get()))
{
combinator_arg_column_types = tupe_type->getElements();
}
auto agg_function = get_aggregate_function(ch_func_name, arg_column_types);
combinator_argument_types = tupe_type->getElements();

auto agg_function = get_aggregate_function(ch_func_name, argument_types, aggr_func_type->getParameters());
auto agg_intermediate_result_type = agg_function->getStateType();
combinator_arg_column_types = {agg_intermediate_result_type};
combinator_argument_types = {agg_intermediate_result_type};
}
else
{
Expand All @@ -167,12 +166,12 @@ std::pair<String, DB::DataTypes> AggregateFunctionParser::tryApplyCHCombinator(
// count(a),count(b), count(1), count(distinct(a)), count(distinct(b))
// from values (1, null), (2,2) as data(a,b)
// with `first_value` enable
if (endsWith(agg_function_data->getFunction()->getName(), "If") && ch_func_name != agg_function_data->getFunction()->getName())
if (endsWith(aggr_func_type->getFunction()->getName(), "If") && ch_func_name != aggr_func_type->getFunction()->getName())
{
auto original_args_types = agg_function_data->getArgumentsDataTypes();
combinator_arg_column_types = DataTypes(original_args_types.begin(), std::prev(original_args_types.end()));
auto agg_function = get_aggregate_function(ch_func_name, combinator_arg_column_types);
combinator_arg_column_types = {agg_function->getStateType()};
auto original_args_types = aggr_func_type->getArgumentsDataTypes();
combinator_argument_types = DataTypes(original_args_types.begin(), std::prev(original_args_types.end()));
auto agg_function = get_aggregate_function(ch_func_name, combinator_argument_types, aggr_func_type->getParameters());
combinator_argument_types = {agg_function->getStateType()};
}
}
combinator_function_name += "PartialMerge";
Expand All @@ -182,7 +181,7 @@ std::pair<String, DB::DataTypes> AggregateFunctionParser::tryApplyCHCombinator(
// Apply `If` aggregate function combinator on the original aggregate function.
combinator_function_name += "If";
}
return {combinator_function_name, combinator_arg_column_types};
return {combinator_function_name, combinator_argument_types};
}

const DB::ActionsDAG::Node * AggregateFunctionParser::convertNodeTypeIfNeeded(
Expand Down
6 changes: 3 additions & 3 deletions cpp-ch/local-engine/Parser/AggregateFunctionParser.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ class AggregateFunctionParser

/// In most cases, arguments size and types are enough to determine the CH function implementation.
/// It is only used in TypeParser::buildBlockFromNamedStruct
/// Users are allowed to modify arg types to make it fit for ggregateFunctionFactory::instance().get(...) in TypeParser::buildBlockFromNamedStruct
/// Users are allowed to modify arg types to make it fit for AggregateFunctionFactory::instance().get(...) in TypeParser::buildBlockFromNamedStruct
virtual String getCHFunctionName(DB::DataTypes & args) const = 0;

/// Do some preprojections for the function arguments, and return the necessary arguments for the CH function.
Expand All @@ -114,8 +114,8 @@ class AggregateFunctionParser

/// Parameters are only used in aggregate functions at present. e.g. percentiles(0.5)(x).
/// 0.5 is the parameter of percentiles function.
virtual DB::Array
parseFunctionParameters(const CommonFunctionInfo & /*func_info*/, DB::ActionsDAG::NodeRawConstPtrs & /*arg_nodes*/) const
virtual DB::Array parseFunctionParameters(
const CommonFunctionInfo & /*func_info*/, DB::ActionsDAG::NodeRawConstPtrs & /*arg_nodes*/, DB::ActionsDAG & /*actions_dag*/) const
{
return DB::Array();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ void AggregateRelParser::addPreProjection()
{
auto arg_nodes = agg_info.function_parser->parseFunctionArguments(agg_info.parser_func_info, projection_action);
// This may remove elements from arg_nodes, because some of them are converted to CH func parameters.
agg_info.params = agg_info.function_parser->parseFunctionParameters(agg_info.parser_func_info, arg_nodes);
agg_info.params = agg_info.function_parser->parseFunctionParameters(agg_info.parser_func_info, arg_nodes, projection_action);
for (auto & arg_node : arg_nodes)
{
agg_info.arg_column_names.emplace_back(arg_node->result_name);
Expand Down
2 changes: 1 addition & 1 deletion cpp-ch/local-engine/Parser/RelParsers/WindowRelParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,7 @@ void WindowRelParser::tryAddProjectionBeforeWindow()
{
auto arg_nodes = win_info.function_parser->parseFunctionArguments(win_info.parser_func_info, actions_dag);
// This may remove elements from arg_nodes, because some of them are converted to CH func parameters.
win_info.params = win_info.function_parser->parseFunctionParameters(win_info.parser_func_info, arg_nodes);
win_info.params = win_info.function_parser->parseFunctionParameters(win_info.parser_func_info, arg_nodes, actions_dag);
for (auto & arg_node : arg_nodes)
{
win_info.arg_column_names.emplace_back(arg_node->result_name);
Expand Down
Loading
Loading