Skip to content

Commit

Permalink
support flatten
Browse files Browse the repository at this point in the history
  • Loading branch information
liuneng1994 committed Jun 24, 2024
1 parent e0fcfe5 commit fb018ca
Show file tree
Hide file tree
Showing 6 changed files with 169 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,6 @@ object CHExpressionUtil {
UNIX_MICROS -> DefaultValidator(),
TIMESTAMP_MILLIS -> DefaultValidator(),
TIMESTAMP_MICROS -> DefaultValidator(),
FLATTEN -> DefaultValidator(),
STACK -> DefaultValidator()
)
}
3 changes: 2 additions & 1 deletion cpp-ch/clickhouse.version
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
CH_ORG=Kyligence
CH_BRANCH=rebase_ch/20240621
CH_COMMIT=acf666c1c4f
CH_COMMIT=c811cbb985f

166 changes: 166 additions & 0 deletions cpp-ch/local-engine/Functions/SparkArrayFlatten.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <Functions/IFunction.h>
#include <Functions/FunctionFactory.h>
#include <Functions/FunctionHelpers.h>
#include <DataTypes/DataTypeArray.h>
#include <Columns/ColumnArray.h>
#include <Columns/ColumnNullable.h>


namespace DB
{

namespace ErrorCodes
{
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
extern const int ILLEGAL_COLUMN;
}

/// arrayFlatten([[1, 2, 3], [4, 5]]) = [1, 2, 3, 4, 5] - flatten array.
class SparkArrayFlatten : public IFunction
{
public:
static constexpr auto name = "sparkArrayFlatten";

static FunctionPtr create(ContextPtr) { return std::make_shared<SparkArrayFlatten>(); }

size_t getNumberOfArguments() const override { return 1; }
bool useDefaultImplementationForConstants() const override { return true; }
bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo & /*arguments*/) const override { return true; }

DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
{
if (!isArray(arguments[0]))
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Illegal type {} of argument of function {}, expected Array",
arguments[0]->getName(), getName());

DataTypePtr nested_type = arguments[0];
std::cerr << "flatten input type: " << nested_type->getName() << std::endl;
// bool nest_nullable_array = checkAndGetDataType<DataTypeArray>(removeNullable(arguments[0]).get())->getNestedType()->isNullable();
nested_type = checkAndGetDataType<DataTypeArray>(removeNullable(nested_type).get())->getNestedType();
std::cerr << "fallten result type: " << nested_type->getName() << std::endl;
return nested_type;
}

ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t input_rows_count) const override
{
/** We create an array column with array elements as the most deep elements of nested arrays,
* and construct offsets by selecting elements of most deep offsets by values of ancestor offsets.
*
Example 1:
Source column: Array(Array(UInt8)):
Row 1: [[1, 2, 3], [4, 5]], Row 2: [[6], [7, 8]]
data: [1, 2, 3], [4, 5], [6], [7, 8]
offsets: 2, 4
data.data: 1 2 3 4 5 6 7 8
data.offsets: 3 5 6 8
Result column: Array(UInt8):
Row 1: [1, 2, 3, 4, 5], Row 2: [6, 7, 8]
data: 1 2 3 4 5 6 7 8
offsets: 5 8
Result offsets are selected from the most deep (data.offsets) by previous deep (offsets) (and values are decremented by one):
3 5 6 8
^ ^
Example 2:
Source column: Array(Array(Array(UInt8))):
Row 1: [[], [[1], [], [2, 3]]], Row 2: [[[4]]]
most deep data: 1 2 3 4
offsets1: 2 3
offsets2: 0 3 4
- ^ ^ - select by prev offsets
offsets3: 1 1 3 4
- ^ ^ - select by prev offsets
result offsets: 3, 4
result: Row 1: [1, 2, 3], Row2: [4]
*/

const ColumnArray * src_col = checkAndGetColumn<ColumnArray>(arguments[0].column.get());

if (!src_col)
throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Illegal column {} in argument of function 'arrayFlatten'",
arguments[0].column->getName());

const IColumn::Offsets & src_offsets = src_col->getOffsets();

ColumnArray::ColumnOffsets::MutablePtr result_offsets_column;
const IColumn::Offsets * prev_offsets = &src_offsets;
const IColumn * prev_data = &src_col->getData();
bool nullable = prev_data->isNullable();
// when array has null element, return null
if (nullable)
{
const ColumnNullable * nullable_column = checkAndGetColumn<ColumnNullable>(prev_data);
prev_data = nullable_column->getNestedColumnPtr().get();
for (size_t i = 0; i < nullable_column->size(); i++)
{
std::cerr << i <<" : " << nullable_column->isNullAt(i) << ": " << nullable_column->getNestedColumnPtr()->getName() << std::endl;
if (nullable_column->isNullAt(i))
{
auto res= nullable_column->cloneEmpty();
res->insertDefault();
return res;
}
}
}
if (isNothing(prev_data->getDataType()))
return prev_data->cloneResized(input_rows_count);
// only flatten one dimension
if (const ColumnArray * next_col = checkAndGetColumn<ColumnArray>(prev_data))
{
result_offsets_column = ColumnArray::ColumnOffsets::create(input_rows_count);

IColumn::Offsets & result_offsets = result_offsets_column->getData();

const IColumn::Offsets * next_offsets = &next_col->getOffsets();

for (size_t i = 0; i < input_rows_count; ++i)
result_offsets[i] = (*next_offsets)[(*prev_offsets)[i] - 1]; /// -1 array subscript is Ok, see PaddedPODArray
prev_data = &next_col->getData();
}


auto res = ColumnArray::create(
prev_data->getPtr(),
result_offsets_column ? std::move(result_offsets_column) : src_col->getOffsetsPtr());
if (nullable)
return makeNullable(res);
return res;
}

private:
String getName() const override
{
return name;
}
};


REGISTER_FUNCTION(SparkArrayFlatten)
{
factory.registerFunction<SparkArrayFlatten>();
}

}
1 change: 1 addition & 0 deletions cpp-ch/local-engine/Parser/SerializedPlanParser.h
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,7 @@ static const std::map<std::string, std::string> SCALAR_FUNCTIONS
{"array", "array"},
{"shuffle", "arrayShuffle"},
{"range", "range"}, /// dummy mapping
{"flatten", "sparkArrayFlatten"},

// map functions
{"map", "map"},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -674,7 +674,6 @@ class ClickHouseTestSettings extends BackendTestSettings {
.exclude("Sequence with default step")
.exclude("Reverse")
.exclude("elementAt")
.exclude("Flatten")
.exclude("ArrayRepeat")
.exclude("Array remove")
.exclude("Array Distinct")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -714,7 +714,6 @@ class ClickHouseTestSettings extends BackendTestSettings {
.exclude("Sequence with default step")
.exclude("Reverse")
.exclude("elementAt")
.exclude("Flatten")
.exclude("ArrayRepeat")
.exclude("Array remove")
.exclude("Array Distinct")
Expand Down

0 comments on commit fb018ca

Please sign in to comment.