Skip to content

Commit

Permalink
add white list for agg opt
Browse files Browse the repository at this point in the history
  • Loading branch information
liuneng1994 committed Oct 31, 2023
1 parent 6503c27 commit 6ebb3f2
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -288,4 +288,43 @@ class GlutenClickHouseTPCHColumnarShuffleParquetAQESuite
}
}
}

test("collect_set") {
val sql =
"""
|select a, b from (
|select n_regionkey as a, collect_set(if(n_regionkey=0, n_name, null)) as set from nation group by n_regionkey)
|lateral view explode(set) as b
|order by a, b
|""".stripMargin
runQueryAndCompare(sql)(checkOperatorMatch[CHHashAggregateExecTransformer])
}

test("test 'aggregate function collect_list'") {
val df = runQueryAndCompare(
"select l_orderkey,from_unixtime(l_orderkey, 'yyyy-MM-dd HH:mm:ss') " +
"from lineitem order by l_orderkey desc limit 10"
)(checkOperatorMatch[ProjectExecTransformer])
checkLengthAndPlan(df, 10)
}

test("test max string") {
withSQLConf(("spark.gluten.sql.columnar.force.hashagg", "true")) {
val sql =
"""
|SELECT
| l_returnflag,
| l_linestatus,
| max(l_comment)
|FROM
| lineitem
|WHERE
| l_shipdate <= date'1998-09-02' - interval 1 day
|GROUP BY
| l_returnflag,
| l_linestatus
|""".stripMargin
runQueryAndCompare(sql, noFallBack = false) { df => }
}
}
}
2 changes: 1 addition & 1 deletion cpp-ch/local-engine/Builder/SerializedPlanBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ std::shared_ptr<substrait::Type> SerializedPlanBuilder::buildType(const DB::Data
res->mutable_i32()->set_nullability(type_nullability);
else if (which.isInt64())
res->mutable_i64()->set_nullability(type_nullability);
else if (which.isString() || which.isAggregateFunction() || which.isFixedString())
else if (which.isStringOrFixedString() || which.isAggregateFunction())
res->mutable_binary()->set_nullability(type_nullability); /// Spark Binary type is more similiar to CH String type
else if (which.isFloat32())
res->mutable_fp32()->set_nullability(type_nullability);
Expand Down
19 changes: 17 additions & 2 deletions cpp-ch/local-engine/Storages/IO/AggregateSerializationUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include <Common/Arena.h>

#include <Columns/ColumnAggregateFunction.h>
#include <Columns/ColumnFixedString.h>
#include <DataTypes/DataTypeAggregateFunction.h>
#include <DataTypes/DataTypeFixedString.h>

Expand All @@ -26,20 +27,34 @@ using namespace DB;

namespace local_engine
{

bool isFixedSizeStateAggregateFunction(const String& name)
{
// TODO max(String) should exclude, but fallback now
static const std::set<String> function_set = {"min", "max", "sum", "count", "avg"};
return function_set.contains(name);
}

DB::ColumnWithTypeAndName convertAggregateStateToFixedString(DB::ColumnWithTypeAndName col)
{
if (!isAggregateFunction(col.type))
{
return col;
}
const auto *aggregate_col = checkAndGetColumn<ColumnAggregateFunction>(*col.column);
// only support known fixed size aggregate function
if (!isFixedSizeStateAggregateFunction(aggregate_col->getAggregateFunction()->getName()))
{
return col;
}
size_t state_size = aggregate_col->getAggregateFunction()->sizeOfData();
auto res_type = std::make_shared<DataTypeFixedString>(state_size);
auto res_col = res_type->createColumn();
res_col->reserve(aggregate_col->size());
PaddedPODArray<UInt8> & column_chars_t = assert_cast<ColumnFixedString &>(*res_col).getChars();
column_chars_t.reserve(aggregate_col->size() * state_size);
for (const auto & item : aggregate_col->getData())
{
res_col->insertData(item, state_size);
column_chars_t.insert_assume_reserved(item, item + state_size);
}
return DB::ColumnWithTypeAndName(std::move(res_col), res_type, col.name);
}
Expand Down

0 comments on commit 6ebb3f2

Please sign in to comment.