diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHExpressionUtil.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHExpressionUtil.scala index cf45c1118f139..e9bee84396f83 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHExpressionUtil.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHExpressionUtil.scala @@ -209,7 +209,6 @@ object CHExpressionUtil { UNIX_MICROS -> DefaultValidator(), TIMESTAMP_MILLIS -> DefaultValidator(), TIMESTAMP_MICROS -> DefaultValidator(), - FLATTEN -> DefaultValidator(), STACK -> DefaultValidator() ) } diff --git a/cpp-ch/clickhouse.version b/cpp-ch/clickhouse.version index 4a3088e54309b..54d0a74c5bb40 100644 --- a/cpp-ch/clickhouse.version +++ b/cpp-ch/clickhouse.version @@ -1,3 +1,4 @@ CH_ORG=Kyligence CH_BRANCH=rebase_ch/20240621 -CH_COMMIT=acf666c1c4f +CH_COMMIT=c811cbb985f + diff --git a/cpp-ch/local-engine/Functions/SparkArrayFlatten.cpp b/cpp-ch/local-engine/Functions/SparkArrayFlatten.cpp new file mode 100644 index 0000000000000..dc720976dff3f --- /dev/null +++ b/cpp-ch/local-engine/Functions/SparkArrayFlatten.cpp @@ -0,0 +1,166 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include +#include +#include +#include + + +namespace DB +{ + +namespace ErrorCodes +{ + extern const int ILLEGAL_TYPE_OF_ARGUMENT; + extern const int ILLEGAL_COLUMN; +} + +/// arrayFlatten([[1, 2, 3], [4, 5]]) = [1, 2, 3, 4, 5] - flatten array. +class SparkArrayFlatten : public IFunction +{ +public: + static constexpr auto name = "sparkArrayFlatten"; + + static FunctionPtr create(ContextPtr) { return std::make_shared(); } + + size_t getNumberOfArguments() const override { return 1; } + bool useDefaultImplementationForConstants() const override { return true; } + bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo & /*arguments*/) const override { return true; } + + DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override + { + if (!isArray(arguments[0])) + throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Illegal type {} of argument of function {}, expected Array", + arguments[0]->getName(), getName()); + + DataTypePtr nested_type = arguments[0]; + std::cerr << "flatten input type: " << nested_type->getName() << std::endl; + // bool nest_nullable_array = checkAndGetDataType(removeNullable(arguments[0]).get())->getNestedType()->isNullable(); + nested_type = checkAndGetDataType(removeNullable(nested_type).get())->getNestedType(); + std::cerr << "fallten result type: " << nested_type->getName() << std::endl; + return nested_type; + } + + ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t input_rows_count) const override + { + /** We create an array column with array elements as the most deep elements of nested arrays, + * and construct offsets by selecting elements of most deep offsets by values of ancestor offsets. + * +Example 1: + +Source column: Array(Array(UInt8)): +Row 1: [[1, 2, 3], [4, 5]], Row 2: [[6], [7, 8]] +data: [1, 2, 3], [4, 5], [6], [7, 8] +offsets: 2, 4 +data.data: 1 2 3 4 5 6 7 8 +data.offsets: 3 5 6 8 + +Result column: Array(UInt8): +Row 1: [1, 2, 3, 4, 5], Row 2: [6, 7, 8] +data: 1 2 3 4 5 6 7 8 +offsets: 5 8 + +Result offsets are selected from the most deep (data.offsets) by previous deep (offsets) (and values are decremented by one): +3 5 6 8 + ^ ^ + +Example 2: + +Source column: Array(Array(Array(UInt8))): +Row 1: [[], [[1], [], [2, 3]]], Row 2: [[[4]]] + +most deep data: 1 2 3 4 + +offsets1: 2 3 +offsets2: 0 3 4 +- ^ ^ - select by prev offsets +offsets3: 1 1 3 4 +- ^ ^ - select by prev offsets + +result offsets: 3, 4 +result: Row 1: [1, 2, 3], Row2: [4] + */ + + const ColumnArray * src_col = checkAndGetColumn(arguments[0].column.get()); + + if (!src_col) + throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Illegal column {} in argument of function 'arrayFlatten'", + arguments[0].column->getName()); + + const IColumn::Offsets & src_offsets = src_col->getOffsets(); + + ColumnArray::ColumnOffsets::MutablePtr result_offsets_column; + const IColumn::Offsets * prev_offsets = &src_offsets; + const IColumn * prev_data = &src_col->getData(); + bool nullable = prev_data->isNullable(); + // when array has null element, return null + if (nullable) + { + const ColumnNullable * nullable_column = checkAndGetColumn(prev_data); + prev_data = nullable_column->getNestedColumnPtr().get(); + for (size_t i = 0; i < nullable_column->size(); i++) + { + std::cerr << i <<" : " << nullable_column->isNullAt(i) << ": " << nullable_column->getNestedColumnPtr()->getName() << std::endl; + if (nullable_column->isNullAt(i)) + { + auto res= nullable_column->cloneEmpty(); + res->insertDefault(); + return res; + } + } + } + if (isNothing(prev_data->getDataType())) + return prev_data->cloneResized(input_rows_count); + // only flatten one dimension + if (const ColumnArray * next_col = checkAndGetColumn(prev_data)) + { + result_offsets_column = ColumnArray::ColumnOffsets::create(input_rows_count); + + IColumn::Offsets & result_offsets = result_offsets_column->getData(); + + const IColumn::Offsets * next_offsets = &next_col->getOffsets(); + + for (size_t i = 0; i < input_rows_count; ++i) + result_offsets[i] = (*next_offsets)[(*prev_offsets)[i] - 1]; /// -1 array subscript is Ok, see PaddedPODArray + prev_data = &next_col->getData(); + } + + + auto res = ColumnArray::create( + prev_data->getPtr(), + result_offsets_column ? std::move(result_offsets_column) : src_col->getOffsetsPtr()); + if (nullable) + return makeNullable(res); + return res; + } + +private: + String getName() const override + { + return name; + } +}; + + +REGISTER_FUNCTION(SparkArrayFlatten) +{ + factory.registerFunction(); +} + +} diff --git a/cpp-ch/local-engine/Parser/SerializedPlanParser.h b/cpp-ch/local-engine/Parser/SerializedPlanParser.h index 71cdca58a6ce1..a120c8fca01c7 100644 --- a/cpp-ch/local-engine/Parser/SerializedPlanParser.h +++ b/cpp-ch/local-engine/Parser/SerializedPlanParser.h @@ -180,6 +180,7 @@ static const std::map SCALAR_FUNCTIONS {"array", "array"}, {"shuffle", "arrayShuffle"}, {"range", "range"}, /// dummy mapping + {"flatten", "sparkArrayFlatten"}, // map functions {"map", "map"}, diff --git a/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala b/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala index 8572ef54d5c8a..f4fe29854aec3 100644 --- a/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala +++ b/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala @@ -674,7 +674,6 @@ class ClickHouseTestSettings extends BackendTestSettings { .exclude("Sequence with default step") .exclude("Reverse") .exclude("elementAt") - .exclude("Flatten") .exclude("ArrayRepeat") .exclude("Array remove") .exclude("Array Distinct") diff --git a/gluten-ut/spark33/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala b/gluten-ut/spark33/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala index 50e7929e46199..b8a3f35994564 100644 --- a/gluten-ut/spark33/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala +++ b/gluten-ut/spark33/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala @@ -714,7 +714,6 @@ class ClickHouseTestSettings extends BackendTestSettings { .exclude("Sequence with default step") .exclude("Reverse") .exclude("elementAt") - .exclude("Flatten") .exclude("ArrayRepeat") .exclude("Array remove") .exclude("Array Distinct")