Skip to content

Commit

Permalink
[GLUTEN-3676][CH] Enable TPCH Deicmal Test (#3677)
Browse files Browse the repository at this point in the history
* [CH] Enable tpch deicmal sql

* fix ut

* add decimal import

* FFix cast or null

* fix ci error

* fix ci error
  • Loading branch information
loneylee authored Dec 13, 2023
1 parent 0f5d9e1 commit a36c961
Show file tree
Hide file tree
Showing 12 changed files with 158 additions and 93 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,10 @@ class RangePartitionerBoundsGenerator[K: Ordering: ClassTag, V](
case _: DoubleType => node.put("value", row.getDouble(i))
case _: StringType => node.put("value", row.getString(i))
case _: DateType => node.put("value", row.getInt(i))
case d =>
case d: DecimalType =>
val decimal = row.getDecimal(i, d.precision, d.scale).toString()
node.put("value", decimal)
case _ =>
throw new IllegalArgumentException(
s"Unsupported data type ${ordering.dataType.toString}")
}
Expand Down Expand Up @@ -244,6 +247,7 @@ object RangePartitionerBoundsGenerator {
case _: DoubleType => true
case _: StringType => true
case _: DateType => true
case _: DecimalType => true
case _ => false
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,12 +63,18 @@ class GlutenClickHouseDecimalSuite
}

private val decimalTable: String = "decimal_table"
private val decimalTPCHTables: Seq[DecimalType] = Seq.apply(DecimalType.apply(18, 8))
private val decimalTPCHTables: Seq[(DecimalType, Seq[Int])] = Seq.apply(
(DecimalType.apply(9, 4), Seq()),
// 1: ch decimal avg is float
(DecimalType.apply(18, 8), Seq(1)),
// 1: ch decimal avg is float, 3/10: all value is null and compare with limit
(DecimalType.apply(38, 19), Seq(1, 3, 10))
)

override protected val createNullableTables = true

override protected def createTPCHNullableTables(): Unit = {
decimalTPCHTables.foreach(createDecimalTables)
decimalTPCHTables.foreach(t => createDecimalTables(t._1))
}

private def createDecimalTables(dataType: DecimalType): Unit = {
Expand All @@ -85,7 +91,7 @@ class GlutenClickHouseDecimalSuite
.map(
tableName => {
val originTablePath = tablesPath + "/" + tableName
spark.read.parquet(originTablePath).createTempView(tableName + "_ori")
spark.read.parquet(originTablePath).createOrReplaceTempView(tableName + "_ori")

val sql = tableName match {
case "customer" =>
Expand Down Expand Up @@ -292,17 +298,42 @@ class GlutenClickHouseDecimalSuite
queriesResults: String = queriesResults,
compareResult: Boolean = true,
noFallBack: Boolean = true)(customCheck: DataFrame => Unit): Unit = {
decimalTPCHTables.foreach(
decimalType => {
spark.sql(s"use decimal_${decimalType.precision}_${decimalType.scale}")
compareTPCHQueryAgainstVanillaSpark(queryNum, tpchQueries, customCheck, noFallBack)
spark.sql(s"use default")
})
compareTPCHQueryAgainstVanillaSpark(
queryNum,
tpchQueries,
compareResult = compareResult,
customCheck = customCheck,
noFallBack = noFallBack)
}

test("TPCH Q20") {
runTPCHQuery(20)(_ => {})
}
Range
.inclusive(1, 22)
.foreach(
sql_num => {
decimalTPCHTables.foreach(
dt => {
val decimalType = dt._1
test(s"TPCH Decimal(${decimalType.precision},${decimalType.scale}) Q$sql_num") {
var noFallBack = true
var compareResult = true
if (sql_num == 16 || sql_num == 21) {
noFallBack = false
}

if (dt._2.contains(sql_num)) {
compareResult = false
}

spark.sql(s"use decimal_${decimalType.precision}_${decimalType.scale}")
runTPCHQuery(
sql_num,
tpchQueries,
compareResult = compareResult,
noFallBack = noFallBack) { _ => {} }
spark.sql(s"use default")
}
})
})

test("fix decimal precision overflow") {
val sql =
Expand Down
13 changes: 7 additions & 6 deletions cpp-ch/local-engine/Common/CHUtil.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
#include <DataTypes/DataTypeTuple.h>
#include <DataTypes/DataTypesNumber.h>
#include <DataTypes/NestedUtils.h>
#include <Functions/CastOverloadResolver.h>
#include <Functions/FunctionFactory.h>
#include <Functions/FunctionsConversion.h>
#include <Functions/registerFunctions.h>
Expand Down Expand Up @@ -390,7 +389,11 @@ const DB::ColumnWithTypeAndName * NestedColumnExtractHelper::findColumn(const DB
}

const DB::ActionsDAG::Node * ActionsDAGUtil::convertNodeType(
DB::ActionsDAGPtr & actions_dag, const DB::ActionsDAG::Node * node, const std::string & type_name, const std::string & result_name)
DB::ActionsDAGPtr & actions_dag,
const DB::ActionsDAG::Node * node,
const std::string & type_name,
const std::string & result_name,
CastType cast_type)
{
DB::ColumnWithTypeAndName type_name_col;
type_name_col.name = type_name;
Expand All @@ -399,11 +402,9 @@ const DB::ActionsDAG::Node * ActionsDAGUtil::convertNodeType(
const auto * right_arg = &actions_dag->addColumn(std::move(type_name_col));
const auto * left_arg = node;
DB::CastDiagnostic diagnostic = {node->result_name, node->result_name};
DB::FunctionOverloadResolverPtr func_builder_cast
= DB::createInternalCastOverloadResolver(DB::CastType::nonAccurate, std::move(diagnostic));

DB::ActionsDAG::NodeRawConstPtrs children = {left_arg, right_arg};
return &actions_dag->addFunction(func_builder_cast, std::move(children), result_name);
return &actions_dag->addFunction(
DB::createInternalCastOverloadResolver(cast_type, std::move(diagnostic)), std::move(children), result_name);
}

String QueryPipelineUtil::explainPipeline(DB::QueryPipeline & pipeline)
Expand Down
4 changes: 3 additions & 1 deletion cpp-ch/local-engine/Common/CHUtil.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include <Core/ColumnWithTypeAndName.h>
#include <Core/NamesAndTypes.h>
#include <DataTypes/Serializations/ISerialization.h>
#include <Functions/CastOverloadResolver.h>
#include <Interpreters/ActionsDAG.h>
#include <Interpreters/Context.h>
#include <Processors/Chunk.h>
Expand Down Expand Up @@ -99,7 +100,8 @@ class ActionsDAGUtil
DB::ActionsDAGPtr & actions_dag,
const DB::ActionsDAG::Node * node,
const std::string & type_name,
const std::string & result_name = "");
const std::string & result_name = "",
DB::CastType cast_type = DB::CastType::nonAccurate);
};

class QueryPipelineUtil
Expand Down
18 changes: 13 additions & 5 deletions cpp-ch/local-engine/Parser/AggregateFunctionParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -151,15 +151,23 @@ const DB::ActionsDAG::Node * AggregateFunctionParser::convertNodeTypeIfNeeded(
const auto & output_type = func_info.output_type;
if (!TypeParser::isTypeMatched(output_type, func_node->result_type))
{
auto ret_node = ActionsDAGUtil::convertNodeType(
func_node = ActionsDAGUtil::convertNodeType(
actions_dag, func_node, TypeParser::parseType(output_type)->getName(), func_node->result_name);
actions_dag->addOrReplaceInOutputs(*ret_node);
return ret_node;
actions_dag->addOrReplaceInOutputs(*func_node);
}
else

if (output_type.has_decimal())
{
return func_node;
String checkDecimalOverflowSparkOrNull = "checkDecimalOverflowSparkOrNull";
DB::ActionsDAG::NodeRawConstPtrs overflow_args
= {func_node,
plan_parser->addColumn(actions_dag, std::make_shared<DataTypeInt32>(), output_type.decimal().precision()),
plan_parser->addColumn(actions_dag, std::make_shared<DataTypeInt32>(), output_type.decimal().scale())};
func_node = toFunctionNode(actions_dag, checkDecimalOverflowSparkOrNull, func_node->result_name, overflow_args);
actions_dag->addOrReplaceInOutputs(*func_node);
}

return func_node;
}

AggregateFunctionParserFactory & AggregateFunctionParserFactory::instance()
Expand Down
Loading

0 comments on commit a36c961

Please sign in to comment.