diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHExpressionUtil.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHExpressionUtil.scala
index cf45c1118f13..e9bee84396f8 100644
--- a/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHExpressionUtil.scala
+++ b/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHExpressionUtil.scala
@@ -209,7 +209,6 @@ object CHExpressionUtil {
     UNIX_MICROS -> DefaultValidator(),
     TIMESTAMP_MILLIS -> DefaultValidator(),
     TIMESTAMP_MICROS -> DefaultValidator(),
-    FLATTEN -> DefaultValidator(),
     STACK -> DefaultValidator()
   )
 }
diff --git a/cpp-ch/clickhouse.version b/cpp-ch/clickhouse.version
index 4a3088e54309..54d0a74c5bb4 100644
--- a/cpp-ch/clickhouse.version
+++ b/cpp-ch/clickhouse.version
@@ -1,3 +1,4 @@
 CH_ORG=Kyligence
 CH_BRANCH=rebase_ch/20240621
-CH_COMMIT=acf666c1c4f
+CH_COMMIT=c811cbb985f
+
diff --git a/cpp-ch/local-engine/Functions/SparkArrayFlatten.cpp b/cpp-ch/local-engine/Functions/SparkArrayFlatten.cpp
new file mode 100644
index 000000000000..d39bca5ea104
--- /dev/null
+++ b/cpp-ch/local-engine/Functions/SparkArrayFlatten.cpp
@@ -0,0 +1,160 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <Functions/IFunction.h>
+#include <Functions/FunctionFactory.h>
+#include <Functions/FunctionHelpers.h>
+#include <DataTypes/DataTypeArray.h>
+#include <Columns/ColumnArray.h>
+#include <Columns/ColumnNullable.h>
+
+
+namespace DB
+{
+
+namespace ErrorCodes
+{
+    extern const int ILLEGAL_TYPE_OF_ARGUMENT;
+    extern const int ILLEGAL_COLUMN;
+}
+
+/// arrayFlatten([[1, 2, 3], [4, 5]]) = [1, 2, 3, 4, 5] - flatten array.
+class SparkArrayFlatten : public IFunction
+{
+public:
+    static constexpr auto name = "sparkArrayFlatten";
+
+    static FunctionPtr create(ContextPtr) { return std::make_shared<SparkArrayFlatten>(); }
+
+    size_t getNumberOfArguments() const override { return 1; }
+    bool useDefaultImplementationForConstants() const override { return true; }
+    bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo & /*arguments*/) const override { return true; }
+
+    DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
+    {
+        if (!isArray(arguments[0]))
+            throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Illegal type {} of argument of function {}, expected Array",
+                arguments[0]->getName(), getName());
+
+        DataTypePtr nested_type = arguments[0];
+        nested_type = checkAndGetDataType<DataTypeArray>(removeNullable(nested_type).get())->getNestedType();
+        return nested_type;
+    }
+
+    ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t input_rows_count) const override
+    {
+        /** We create an array column with array elements as the most deep elements of nested arrays,
+          * and construct offsets by selecting elements of most deep offsets by values of ancestor offsets.
+          *
+Example 1:
+
+Source column: Array(Array(UInt8)):
+Row 1: [[1, 2, 3], [4, 5]], Row 2: [[6], [7, 8]]
+data: [1, 2, 3], [4, 5], [6], [7, 8]
+offsets: 2, 4
+data.data: 1 2 3 4 5 6 7 8
+data.offsets: 3 5 6 8
+
+Result column: Array(UInt8):
+Row 1: [1, 2, 3, 4, 5], Row 2: [6, 7, 8]
+data: 1 2 3 4 5 6 7 8
+offsets: 5 8
+
+Result offsets are selected from the most deep (data.offsets) by previous deep (offsets) (and values are decremented by one):
+3 5 6 8
+  ^ ^
+
+Example 2:
+
+Source column: Array(Array(Array(UInt8))):
+Row 1: [[], [[1], [], [2, 3]]], Row 2: [[[4]]]
+
+most deep data: 1 2 3 4
+
+offsets1: 2 3
+offsets2: 0 3 4
+-           ^ ^ - select by prev offsets
+offsets3: 1 1 3 4
+-             ^ ^ - select by prev offsets
+
+result offsets: 3, 4
+result: Row 1: [1, 2, 3], Row2: [4]
+          */
+
+        const ColumnArray * src_col = checkAndGetColumn<ColumnArray>(arguments[0].column.get());
+
+        if (!src_col)
+            throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Illegal column {} in argument of function 'arrayFlatten'",
+                arguments[0].column->getName());
+
+        const IColumn::Offsets & src_offsets = src_col->getOffsets();
+
+        ColumnArray::ColumnOffsets::MutablePtr result_offsets_column;
+        const IColumn::Offsets * prev_offsets = &src_offsets;
+        const IColumn * prev_data = &src_col->getData();
+        bool nullable = prev_data->isNullable();
+        // when array has null element, return null
+        if (nullable)
+        {
+            const ColumnNullable * nullable_column = checkAndGetColumn<ColumnNullable>(prev_data);
+            prev_data = nullable_column->getNestedColumnPtr().get();
+            for (size_t i = 0; i < nullable_column->size(); i++)
+            {
+                if (nullable_column->isNullAt(i))
+                {
+                    auto res= nullable_column->cloneEmpty();
+                    res->insertManyDefaults(input_rows_count);
+                    return res;
+                }
+            }
+        }
+        if (isNothing(prev_data->getDataType()))
+            return prev_data->cloneResized(input_rows_count);
+        // only flatten one dimension
+        if (const ColumnArray * next_col = checkAndGetColumn<ColumnArray>(prev_data))
+        {
+            result_offsets_column = ColumnArray::ColumnOffsets::create(input_rows_count);
+
+            IColumn::Offsets & result_offsets = result_offsets_column->getData();
+
+            const IColumn::Offsets * next_offsets = &next_col->getOffsets();
+
+            for (size_t i = 0; i < input_rows_count; ++i)
+                result_offsets[i] = (*next_offsets)[(*prev_offsets)[i] - 1]; /// -1 array subscript is Ok, see PaddedPODArray
+            prev_data = &next_col->getData();
+        }
+
+        auto res = ColumnArray::create(
+            prev_data->getPtr(),
+            result_offsets_column ? std::move(result_offsets_column) : src_col->getOffsetsPtr());
+        if (nullable)
+            return makeNullable(res);
+        return res;
+    }
+
+private:
+    String getName() const override
+    {
+        return name;
+    }
+};
+
+REGISTER_FUNCTION(SparkArrayFlatten)
+{
+    factory.registerFunction<SparkArrayFlatten>();
+}
+
+}
diff --git a/cpp-ch/local-engine/Parser/SerializedPlanParser.h b/cpp-ch/local-engine/Parser/SerializedPlanParser.h
index 71cdca58a6ce..a120c8fca01c 100644
--- a/cpp-ch/local-engine/Parser/SerializedPlanParser.h
+++ b/cpp-ch/local-engine/Parser/SerializedPlanParser.h
@@ -180,6 +180,7 @@ static const std::map<std::string, std::string> SCALAR_FUNCTIONS
     {"array", "array"},
     {"shuffle", "arrayShuffle"},
     {"range", "range"}, /// dummy mapping
+    {"flatten", "sparkArrayFlatten"},
 
     // map functions
     {"map", "map"},
diff --git a/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala b/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
index 8572ef54d5c8..1626716805cb 100644
--- a/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
+++ b/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
@@ -172,6 +172,7 @@ class ClickHouseTestSettings extends BackendTestSettings {
       .exclude("shuffle function - array for primitive type not containing null")
       .exclude("shuffle function - array for primitive type containing null")
       .exclude("shuffle function - array for non-primitive type")
+      .exclude("flatten function")
     enableSuite[GlutenDataFrameHintSuite]
enableSuite[GlutenDataFrameImplicitsSuite] enableSuite[GlutenDataFrameJoinSuite].exclude( @@ -674,7 +675,6 @@ class ClickHouseTestSettings extends BackendTestSettings { .exclude("Sequence with default step") .exclude("Reverse") .exclude("elementAt") - .exclude("Flatten") .exclude("ArrayRepeat") .exclude("Array remove") .exclude("Array Distinct") diff --git a/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/GlutenDataFrameFunctionsSuite.scala b/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/GlutenDataFrameFunctionsSuite.scala index 2b0b40790a76..e64f760ab55f 100644 --- a/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/GlutenDataFrameFunctionsSuite.scala +++ b/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/GlutenDataFrameFunctionsSuite.scala @@ -49,4 +49,86 @@ class GlutenDataFrameFunctionsSuite extends DataFrameFunctionsSuite with GlutenS false ) } + + testGluten("flatten function") { + // Test cases with a primitive type + val intDF = Seq( + (Seq(Seq(1, 2, 3), Seq(4, 5), Seq(6))), + (Seq(Seq(1, 2))), + (Seq(Seq(1), Seq.empty)), + (Seq(Seq.empty, Seq(1))) + ).toDF("i") + + val intDFResult = Seq(Row(Seq(1, 2, 3, 4, 5, 6)), Row(Seq(1, 2)), Row(Seq(1)), Row(Seq(1))) + + def testInt(): Unit = { + checkAnswer(intDF.select(flatten($"i")), intDFResult) + checkAnswer(intDF.selectExpr("flatten(i)"), intDFResult) + } + + // Test with local relation, the Project will be evaluated without codegen + testInt() + // Test with cached relation, the Project will be evaluated with codegen + intDF.cache() + testInt() + + // Test cases with non-primitive types + val strDF = Seq( + (Seq(Seq("a", "b"), Seq("c"), Seq("d", "e", "f"))), + (Seq(Seq("a", "b"))), + (Seq(Seq("a", null), Seq(null, "b"), Seq(null, null))), + (Seq(Seq("a"), Seq.empty)), + (Seq(Seq.empty, Seq("a"))) + ).toDF("s") + + val strDFResult = Seq( + Row(Seq("a", "b", "c", "d", "e", "f")), + Row(Seq("a", "b")), + Row(Seq("a", null, null, "b", null, null)), + Row(Seq("a")), + Row(Seq("a"))) + + def 
testString(): Unit = { + checkAnswer(strDF.select(flatten($"s")), strDFResult) + checkAnswer(strDF.selectExpr("flatten(s)"), strDFResult) + } + + // Test with local relation, the Project will be evaluated without codegen + testString() + // Test with cached relation, the Project will be evaluated with codegen + strDF.cache() + testString() + + val arrDF = Seq((1, "a", Seq(1, 2, 3))).toDF("i", "s", "arr") + + def testArray(): Unit = { + checkAnswer( + arrDF.selectExpr("flatten(array(arr, array(null, 5), array(6, null)))"), + Seq(Row(Seq(1, 2, 3, null, 5, 6, null)))) + checkAnswer( + arrDF.selectExpr("flatten(array(array(arr, arr), array(arr)))"), + Seq(Row(Seq(Seq(1, 2, 3), Seq(1, 2, 3), Seq(1, 2, 3))))) + } + + // Test with local relation, the Project will be evaluated without codegen + testArray() + // Test with cached relation, the Project will be evaluated with codegen + arrDF.cache() + testArray() + + // Error test cases + val oneRowDF = Seq((1, "a", Seq(1, 2, 3))).toDF("i", "s", "arr") + intercept[AnalysisException] { + oneRowDF.select(flatten($"arr")) + } + intercept[AnalysisException] { + oneRowDF.select(flatten($"i")) + } + intercept[AnalysisException] { + oneRowDF.select(flatten($"s")) + } + intercept[AnalysisException] { + oneRowDF.selectExpr("flatten(null)") + } + } } diff --git a/gluten-ut/spark33/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala b/gluten-ut/spark33/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala index 50e7929e4619..3147c7c3dbf3 100644 --- a/gluten-ut/spark33/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala +++ b/gluten-ut/spark33/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala @@ -190,6 +190,7 @@ class ClickHouseTestSettings extends BackendTestSettings { .exclude("shuffle function - array for primitive type not containing null") .exclude("shuffle function - array for primitive type containing null") .exclude("shuffle 
function - array for non-primitive type") + .exclude("flatten function") enableSuite[GlutenDataFrameHintSuite] enableSuite[GlutenDataFrameImplicitsSuite] enableSuite[GlutenDataFrameJoinSuite].exclude( @@ -714,7 +715,6 @@ class ClickHouseTestSettings extends BackendTestSettings { .exclude("Sequence with default step") .exclude("Reverse") .exclude("elementAt") - .exclude("Flatten") .exclude("ArrayRepeat") .exclude("Array remove") .exclude("Array Distinct") diff --git a/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/GlutenDataFrameFunctionsSuite.scala b/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/GlutenDataFrameFunctionsSuite.scala index 2b0b40790a76..e64f760ab55f 100644 --- a/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/GlutenDataFrameFunctionsSuite.scala +++ b/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/GlutenDataFrameFunctionsSuite.scala @@ -49,4 +49,86 @@ class GlutenDataFrameFunctionsSuite extends DataFrameFunctionsSuite with GlutenS false ) } + + testGluten("flatten function") { + // Test cases with a primitive type + val intDF = Seq( + (Seq(Seq(1, 2, 3), Seq(4, 5), Seq(6))), + (Seq(Seq(1, 2))), + (Seq(Seq(1), Seq.empty)), + (Seq(Seq.empty, Seq(1))) + ).toDF("i") + + val intDFResult = Seq(Row(Seq(1, 2, 3, 4, 5, 6)), Row(Seq(1, 2)), Row(Seq(1)), Row(Seq(1))) + + def testInt(): Unit = { + checkAnswer(intDF.select(flatten($"i")), intDFResult) + checkAnswer(intDF.selectExpr("flatten(i)"), intDFResult) + } + + // Test with local relation, the Project will be evaluated without codegen + testInt() + // Test with cached relation, the Project will be evaluated with codegen + intDF.cache() + testInt() + + // Test cases with non-primitive types + val strDF = Seq( + (Seq(Seq("a", "b"), Seq("c"), Seq("d", "e", "f"))), + (Seq(Seq("a", "b"))), + (Seq(Seq("a", null), Seq(null, "b"), Seq(null, null))), + (Seq(Seq("a"), Seq.empty)), + (Seq(Seq.empty, Seq("a"))) + ).toDF("s") + + val strDFResult = Seq( + Row(Seq("a", "b", "c", "d", "e", "f")), + 
Row(Seq("a", "b")), + Row(Seq("a", null, null, "b", null, null)), + Row(Seq("a")), + Row(Seq("a"))) + + def testString(): Unit = { + checkAnswer(strDF.select(flatten($"s")), strDFResult) + checkAnswer(strDF.selectExpr("flatten(s)"), strDFResult) + } + + // Test with local relation, the Project will be evaluated without codegen + testString() + // Test with cached relation, the Project will be evaluated with codegen + strDF.cache() + testString() + + val arrDF = Seq((1, "a", Seq(1, 2, 3))).toDF("i", "s", "arr") + + def testArray(): Unit = { + checkAnswer( + arrDF.selectExpr("flatten(array(arr, array(null, 5), array(6, null)))"), + Seq(Row(Seq(1, 2, 3, null, 5, 6, null)))) + checkAnswer( + arrDF.selectExpr("flatten(array(array(arr, arr), array(arr)))"), + Seq(Row(Seq(Seq(1, 2, 3), Seq(1, 2, 3), Seq(1, 2, 3))))) + } + + // Test with local relation, the Project will be evaluated without codegen + testArray() + // Test with cached relation, the Project will be evaluated with codegen + arrDF.cache() + testArray() + + // Error test cases + val oneRowDF = Seq((1, "a", Seq(1, 2, 3))).toDF("i", "s", "arr") + intercept[AnalysisException] { + oneRowDF.select(flatten($"arr")) + } + intercept[AnalysisException] { + oneRowDF.select(flatten($"i")) + } + intercept[AnalysisException] { + oneRowDF.select(flatten($"s")) + } + intercept[AnalysisException] { + oneRowDF.selectExpr("flatten(null)") + } + } }