From a69a5918081e2d0e785e0a620c2422b202767a85 Mon Sep 17 00:00:00 2001 From: kevinyhzou Date: Mon, 19 Aug 2024 10:38:31 +0800 Subject: [PATCH] fix array_join diff --- .../gluten/utils/CHExpressionUtil.scala | 1 - .../Functions/SparkFunctionArrayJoin.cpp | 44 ++++++++++++++----- 2 files changed, 32 insertions(+), 13 deletions(-) diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHExpressionUtil.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHExpressionUtil.scala index 55d70904f2e1..2b1b7ba2f084 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHExpressionUtil.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHExpressionUtil.scala @@ -181,7 +181,6 @@ object CHExpressionUtil { ) final val CH_BLACKLIST_SCALAR_FUNCTION: Map[String, FunctionValidator] = Map( - ARRAY_JOIN -> ArrayJoinValidator(), SPLIT_PART -> DefaultValidator(), TO_UNIX_TIMESTAMP -> UnixTimeStampValidator(), UNIX_TIMESTAMP -> UnixTimeStampValidator(), diff --git a/cpp-ch/local-engine/Functions/SparkFunctionArrayJoin.cpp b/cpp-ch/local-engine/Functions/SparkFunctionArrayJoin.cpp index ed99c0904272..0ac694cf81a9 100644 --- a/cpp-ch/local-engine/Functions/SparkFunctionArrayJoin.cpp +++ b/cpp-ch/local-engine/Functions/SparkFunctionArrayJoin.cpp @@ -21,6 +21,7 @@ #include #include #include +#include using namespace DB; @@ -54,23 +55,39 @@ class SparkFunctionArrayJoin : public IFunction return makeNullable(data_type); } - ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t) const override + ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t input_rows_count) const override { if (arguments.size() != 2 && arguments.size() != 3) throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} must have 2 or 3 arguments", getName()); - - const auto * arg_null_col = checkAndGetColumn(arguments[0].column.get()); - const ColumnArray * array_col; - if (!arg_null_col) - array_col = checkAndGetColumn(arguments[0].column.get()); + auto res_col = ColumnString::create(); + auto null_col = ColumnUInt8::create(input_rows_count, 0); + PaddedPODArray & null_result = null_col->getData(); + if (input_rows_count == 0) + return ColumnNullable::create(std::move(res_col), std::move(null_col)); + + const auto * arg_const_col = checkAndGetColumn(arguments[0].column.get()); + const ColumnArray * array_col = nullptr; + const ColumnNullable * arg_null_col = nullptr; + if (arg_const_col) + { + if (arg_const_col->onlyNull()) + { + null_result[0] = 1; + return ColumnNullable::create(std::move(res_col), std::move(null_col)); + } + array_col = checkAndGetColumn(arg_const_col->getDataColumnPtr().get()); + } else - array_col = checkAndGetColumn(arg_null_col->getNestedColumnPtr().get()); + { + arg_null_col = checkAndGetColumn(arguments[0].column.get()); + if (!arg_null_col) + array_col = checkAndGetColumn(arguments[0].column.get()); + else + array_col = checkAndGetColumn(arg_null_col->getNestedColumnPtr().get()); + } if (!array_col) throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Function {} 1st argument must be array type", getName()); - auto res_col = ColumnString::create(); - auto null_col = ColumnUInt8::create(array_col->size(), 0); - PaddedPODArray & null_result = null_col->getData(); std::pair delim_p, null_replacement_p; bool return_result = false; auto checkAndGetConstString = [&](const ColumnPtr & col) -> std::pair @@ -145,7 +162,7 @@ class SparkFunctionArrayJoin : public IFunction } } }; - if (arg_null_col->isNullAt(i)) + if (arg_null_col && arg_null_col->isNullAt(i)) { setResultNull(); continue; @@ -166,9 +183,9 @@ class SparkFunctionArrayJoin : public IFunction continue; } } - size_t array_size = array_offsets[i] - current_offset; size_t data_pos = array_pos == 0 ? 0 : string_offsets[array_pos - 1]; + size_t last_not_null_pos = 0; for (size_t j = 0; j < array_size; ++j) { if (array_nested_col && array_nested_col->isNullAt(j + array_pos)) @@ -179,11 +196,14 @@ class SparkFunctionArrayJoin : public IFunction if (j != array_size - 1) res += delim.toString(); } + else if (j == array_size - 1) + res = res.substr(0, last_not_null_pos); } else { const StringRef s(&string_data[data_pos], string_offsets[j + array_pos] - data_pos - 1); res += s.toString(); + last_not_null_pos = res.size(); if (j != array_size - 1) res += delim.toString(); }