Skip to content

Commit

Permalink
[GLUTEN-6561][CH] Fix incompatiable type exception throw in capture f…
Browse files Browse the repository at this point in the history
…unction while processing array literal with `transform` (#6601)

* fix style

* fix issue #6561

* add uts

* add uts

* fix uts

* fix style

* ignore some checks when spark 3.3
  • Loading branch information
taiyang-li authored Jul 30, 2024
1 parent e8dd172 commit dba4439
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 31 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -903,33 +903,17 @@ class GlutenClickHouseNativeWriteTableSuite
| ) partitioned by (day string)
| stored as $format""".stripMargin

// FIXME:
// Spark analyzer(>=3.4) will resolve map type to
// map_from_arrays(transform(map_keys(map('t1','a','t2','b')), v->v),
// transform(map_values(map('t1','a','t2','b')), v->v))
// which cause core dump. see https://github.com/apache/incubator-gluten/issues/6561
// for details.
val insert_sql =
if (isSparkVersionLE("3.3")) {
s"""insert overwrite $table_name partition (day)
|select id as a,
| str_to_map(concat('t1:','a','&t2:','b'),'&',':'),
| struct('1', null) as c,
| '2024-01-08' as day
|from range(10)""".stripMargin
} else {
s"""insert overwrite $table_name partition (day)
|select id as a,
| map('t1', 'a', 't2', 'b'),
| struct('1', null) as c,
| '2024-01-08' as day
|from range(10)""".stripMargin
}
s"""insert overwrite $table_name partition (day)
|select id as a,
| str_to_map(concat('t1:','a','&t2:','b'),'&',':'),
| struct('1', null) as c,
| '2024-01-08' as day
|from range(10)""".stripMargin
(table_name, create_sql, insert_sql)
},
(table_name, _) =>
if (isSparkVersionGE("3.4")) {
// FIXME: Don't Know Why Failed
compareResultsAgainstVanillaSpark(
s"select * from $table_name",
compareResult = true,
Expand Down
15 changes: 14 additions & 1 deletion cpp-ch/local-engine/Common/CHUtil.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -477,6 +477,19 @@ const DB::ActionsDAG::Node * ActionsDAGUtil::convertNodeType(
DB::createInternalCastOverloadResolver(cast_type, std::move(diagnostic)), std::move(children), result_name);
}

const DB::ActionsDAG::Node * ActionsDAGUtil::convertNodeTypeIfNeeded(
DB::ActionsDAGPtr & actions_dag,
const DB::ActionsDAG::Node * node,
const DB::DataTypePtr & dst_type,
const std::string & result_name,
CastType cast_type)
{
if (node->result_type->equals(*dst_type))
return node;

return convertNodeType(actions_dag, node, dst_type->getName(), result_name, cast_type);
}

String QueryPipelineUtil::explainPipeline(DB::QueryPipeline & pipeline)
{
DB::WriteBufferFromOwnString buf;
Expand Down Expand Up @@ -844,7 +857,7 @@ void BackendInitializerUtil::initContexts(DB::Context::ConfigurationPtr config)
size_t index_uncompressed_cache_size = config->getUInt64("index_uncompressed_cache_size", DEFAULT_INDEX_UNCOMPRESSED_CACHE_MAX_SIZE);
double index_uncompressed_cache_size_ratio = config->getDouble("index_uncompressed_cache_size_ratio", DEFAULT_INDEX_UNCOMPRESSED_CACHE_SIZE_RATIO);
global_context->setIndexUncompressedCache(index_uncompressed_cache_policy, index_uncompressed_cache_size, index_uncompressed_cache_size_ratio);

String index_mark_cache_policy = config->getString("index_mark_cache_policy", DEFAULT_INDEX_MARK_CACHE_POLICY);
size_t index_mark_cache_size = config->getUInt64("index_mark_cache_size", DEFAULT_INDEX_MARK_CACHE_MAX_SIZE);
double index_mark_cache_size_ratio = config->getDouble("index_mark_cache_size_ratio", DEFAULT_INDEX_MARK_CACHE_SIZE_RATIO);
Expand Down
7 changes: 7 additions & 0 deletions cpp-ch/local-engine/Common/CHUtil.h
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,13 @@ class ActionsDAGUtil
const std::string & type_name,
const std::string & result_name = "",
DB::CastType cast_type = DB::CastType::nonAccurate);

static const DB::ActionsDAG::Node * convertNodeTypeIfNeeded(
DB::ActionsDAGPtr & actions_dag,
const DB::ActionsDAG::Node * node,
const DB::DataTypePtr & dst_type,
const std::string & result_name = "",
DB::CastType cast_type = DB::CastType::nonAccurate);
};

class QueryPipelineUtil
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ REGISTER_COMMON_SCALAR_FUNCTION_PARSER(ToUnixTimestamp, to_unix_timestamp, parse
REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Position, positive, identity);
REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Negative, negative, negate);
REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Pmod, pmod, pmod);
REGISTER_COMMON_SCALAR_FUNCTION_PARSER(abs, abs, abs);
REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Abs, abs, abs);
REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Ceil, ceil, ceil);
REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Round, round, roundHalfUp);
REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Bround, bround, roundBankers);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,17 @@
* limitations under the License.
*/

#include <Parser/FunctionParser.h>
#include <Common/Exception.h>
#include <Poco/Logger.h>
#include <Common/logger_useful.h>
#include <Common/CHUtil.h>
#include <Core/Types.h>
#include <DataTypes/DataTypeArray.h>
#include <DataTypes/DataTypeFunction.h>
#include <DataTypes/DataTypeNullable.h>
#include <Core/Types.h>
#include <Parser/FunctionParser.h>
#include <Parser/TypeParser.h>
#include <Parser/scalar_function_parser/lambdaFunction.h>
#include <Poco/Logger.h>
#include <Common/CHUtil.h>
#include <Common/Exception.h>
#include <Common/logger_useful.h>

namespace DB::ErrorCodes
{
Expand Down Expand Up @@ -90,7 +91,16 @@ class ArrayTransform : public FunctionParser
assert(parsed_args.size() == 2);
if (lambda_args.size() == 1)
{
return toFunctionNode(actions_dag, ch_func_name, {parsed_args[1], parsed_args[0]});
/// Convert Array(T) to Array(U) if needed, Array(T) is the type of the first argument of transform.
/// U is the argument type of lambda function. In some cases Array(T) is not equal to Array(U).
/// e.g. in the second query of https://github.com/apache/incubator-gluten/issues/6561, T is String, and U is Nullable(String)
/// The difference of both types will result in runtime exceptions in function capture.
const auto & src_array_type = parsed_args[0]->result_type;
DataTypePtr dst_array_type = std::make_shared<DataTypeArray>(lambda_args.front().type);
if (isNullableOrLowCardinalityNullable(src_array_type))
dst_array_type = std::make_shared<DataTypeNullable>(dst_array_type);
const auto * dst_array_arg = ActionsDAGUtil::convertNodeTypeIfNeeded(actions_dag, parsed_args[0], dst_array_type);
return toFunctionNode(actions_dag, ch_func_name, {parsed_args[1], dst_array_arg});
}

/// transform with index argument.
Expand Down

0 comments on commit dba4439

Please sign in to comment.