Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[GLUTEN-3676][CH] Enable TPCH Deicmal Test #3677

Merged
merged 6 commits into from
Dec 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,10 @@ class RangePartitionerBoundsGenerator[K: Ordering: ClassTag, V](
case _: DoubleType => node.put("value", row.getDouble(i))
case _: StringType => node.put("value", row.getString(i))
case _: DateType => node.put("value", row.getInt(i))
case d =>
case d: DecimalType =>
val decimal = row.getDecimal(i, d.precision, d.scale).toString()
node.put("value", decimal)
case _ =>
throw new IllegalArgumentException(
s"Unsupported data type ${ordering.dataType.toString}")
}
Expand Down Expand Up @@ -244,6 +247,7 @@ object RangePartitionerBoundsGenerator {
case _: DoubleType => true
case _: StringType => true
case _: DateType => true
case _: DecimalType => true
case _ => false
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,12 +63,18 @@ class GlutenClickHouseDecimalSuite
}

private val decimalTable: String = "decimal_table"
private val decimalTPCHTables: Seq[DecimalType] = Seq.apply(DecimalType.apply(18, 8))
private val decimalTPCHTables: Seq[(DecimalType, Seq[Int])] = Seq.apply(
(DecimalType.apply(9, 4), Seq()),
// 1: ch decimal avg is float
(DecimalType.apply(18, 8), Seq(1)),
// 1: ch decimal avg is float, 3/10: all value is null and compare with limit
(DecimalType.apply(38, 19), Seq(1, 3, 10))
)

override protected val createNullableTables = true

override protected def createTPCHNullableTables(): Unit = {
decimalTPCHTables.foreach(createDecimalTables)
decimalTPCHTables.foreach(t => createDecimalTables(t._1))
}

private def createDecimalTables(dataType: DecimalType): Unit = {
Expand All @@ -85,7 +91,7 @@ class GlutenClickHouseDecimalSuite
.map(
tableName => {
val originTablePath = tablesPath + "/" + tableName
spark.read.parquet(originTablePath).createTempView(tableName + "_ori")
spark.read.parquet(originTablePath).createOrReplaceTempView(tableName + "_ori")

val sql = tableName match {
case "customer" =>
Expand Down Expand Up @@ -292,17 +298,42 @@ class GlutenClickHouseDecimalSuite
queriesResults: String = queriesResults,
compareResult: Boolean = true,
noFallBack: Boolean = true)(customCheck: DataFrame => Unit): Unit = {
decimalTPCHTables.foreach(
decimalType => {
spark.sql(s"use decimal_${decimalType.precision}_${decimalType.scale}")
compareTPCHQueryAgainstVanillaSpark(queryNum, tpchQueries, customCheck, noFallBack)
spark.sql(s"use default")
})
compareTPCHQueryAgainstVanillaSpark(
queryNum,
tpchQueries,
compareResult = compareResult,
customCheck = customCheck,
noFallBack = noFallBack)
}

test("TPCH Q20") {
runTPCHQuery(20)(_ => {})
}
Range
.inclusive(1, 22)
.foreach(
sql_num => {
decimalTPCHTables.foreach(
dt => {
val decimalType = dt._1
test(s"TPCH Decimal(${decimalType.precision},${decimalType.scale}) Q$sql_num") {
var noFallBack = true
var compareResult = true
if (sql_num == 16 || sql_num == 21) {
noFallBack = false
}

if (dt._2.contains(sql_num)) {
compareResult = false
}

spark.sql(s"use decimal_${decimalType.precision}_${decimalType.scale}")
runTPCHQuery(
sql_num,
tpchQueries,
compareResult = compareResult,
noFallBack = noFallBack) { _ => {} }
spark.sql(s"use default")
}
})
})

test("fix decimal precision overflow") {
val sql =
Expand Down
13 changes: 7 additions & 6 deletions cpp-ch/local-engine/Common/CHUtil.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
#include <DataTypes/DataTypeTuple.h>
#include <DataTypes/DataTypesNumber.h>
#include <DataTypes/NestedUtils.h>
#include <Functions/CastOverloadResolver.h>
#include <Functions/FunctionFactory.h>
#include <Functions/FunctionsConversion.h>
#include <Functions/registerFunctions.h>
Expand Down Expand Up @@ -390,7 +389,11 @@ const DB::ColumnWithTypeAndName * NestedColumnExtractHelper::findColumn(const DB
}

const DB::ActionsDAG::Node * ActionsDAGUtil::convertNodeType(
DB::ActionsDAGPtr & actions_dag, const DB::ActionsDAG::Node * node, const std::string & type_name, const std::string & result_name)
DB::ActionsDAGPtr & actions_dag,
const DB::ActionsDAG::Node * node,
const std::string & type_name,
const std::string & result_name,
CastType cast_type)
{
DB::ColumnWithTypeAndName type_name_col;
type_name_col.name = type_name;
Expand All @@ -399,11 +402,9 @@ const DB::ActionsDAG::Node * ActionsDAGUtil::convertNodeType(
const auto * right_arg = &actions_dag->addColumn(std::move(type_name_col));
const auto * left_arg = node;
DB::CastDiagnostic diagnostic = {node->result_name, node->result_name};
DB::FunctionOverloadResolverPtr func_builder_cast
= DB::createInternalCastOverloadResolver(DB::CastType::nonAccurate, std::move(diagnostic));

DB::ActionsDAG::NodeRawConstPtrs children = {left_arg, right_arg};
return &actions_dag->addFunction(func_builder_cast, std::move(children), result_name);
return &actions_dag->addFunction(
DB::createInternalCastOverloadResolver(cast_type, std::move(diagnostic)), std::move(children), result_name);
}

String QueryPipelineUtil::explainPipeline(DB::QueryPipeline & pipeline)
Expand Down
4 changes: 3 additions & 1 deletion cpp-ch/local-engine/Common/CHUtil.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include <Core/ColumnWithTypeAndName.h>
#include <Core/NamesAndTypes.h>
#include <DataTypes/Serializations/ISerialization.h>
#include <Functions/CastOverloadResolver.h>
#include <Interpreters/ActionsDAG.h>
#include <Interpreters/Context.h>
#include <Processors/Chunk.h>
Expand Down Expand Up @@ -99,7 +100,8 @@ class ActionsDAGUtil
DB::ActionsDAGPtr & actions_dag,
const DB::ActionsDAG::Node * node,
const std::string & type_name,
const std::string & result_name = "");
const std::string & result_name = "",
DB::CastType cast_type = DB::CastType::nonAccurate);
};

class QueryPipelineUtil
Expand Down
18 changes: 13 additions & 5 deletions cpp-ch/local-engine/Parser/AggregateFunctionParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -151,15 +151,23 @@ const DB::ActionsDAG::Node * AggregateFunctionParser::convertNodeTypeIfNeeded(
const auto & output_type = func_info.output_type;
if (!TypeParser::isTypeMatched(output_type, func_node->result_type))
{
auto ret_node = ActionsDAGUtil::convertNodeType(
func_node = ActionsDAGUtil::convertNodeType(
actions_dag, func_node, TypeParser::parseType(output_type)->getName(), func_node->result_name);
actions_dag->addOrReplaceInOutputs(*ret_node);
return ret_node;
actions_dag->addOrReplaceInOutputs(*func_node);
}
else

if (output_type.has_decimal())
{
return func_node;
String checkDecimalOverflowSparkOrNull = "checkDecimalOverflowSparkOrNull";
DB::ActionsDAG::NodeRawConstPtrs overflow_args
= {func_node,
plan_parser->addColumn(actions_dag, std::make_shared<DataTypeInt32>(), output_type.decimal().precision()),
plan_parser->addColumn(actions_dag, std::make_shared<DataTypeInt32>(), output_type.decimal().scale())};
func_node = toFunctionNode(actions_dag, checkDecimalOverflowSparkOrNull, func_node->result_name, overflow_args);
actions_dag->addOrReplaceInOutputs(*func_node);
}

return func_node;
}

AggregateFunctionParserFactory & AggregateFunctionParserFactory::instance()
Expand Down
Loading
Loading