Skip to content

Commit

Permalink
support arrays_overlap
Browse files Browse the repository at this point in the history
  • Loading branch information
KevinyhZou committed Aug 15, 2024
1 parent f126f3c commit f123ff5
Show file tree
Hide file tree
Showing 7 changed files with 154 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,6 @@ object CHExpressionUtil {
UNIX_TIMESTAMP -> UnixTimeStampValidator(),
SEQUENCE -> SequenceValidator(),
GET_JSON_OBJECT -> GetJsonObjectValidator(),
ARRAYS_OVERLAP -> DefaultValidator(),
SPLIT -> StringSplitValidator(),
SUBSTRING_INDEX -> SubstringIndexValidator(),
LPAD -> StringLPadValidator(),
Expand Down
153 changes: 153 additions & 0 deletions cpp-ch/local-engine/Functions/SparkFunctionArraysOverlap.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <Columns/ColumnString.h>
#include <Columns/ColumnNullable.h>
#include <Functions/IFunction.h>
#include <Functions/FunctionFactory.h>
#include <DataTypes/DataTypeString.h>
#include <DataTypes/DataTypesNumber.h>
#include <iostream>

using namespace DB;

namespace DB
{
namespace ErrorCodes
{
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
}
}

namespace local_engine
{
class SparkFunctionArraysOverlap : public IFunction
{
public:
static constexpr auto name = "sparkArraysOverlap";
static FunctionPtr create(ContextPtr) { return std::make_shared<SparkFunctionArraysOverlap>(); }
SparkFunctionArraysOverlap() = default;
~SparkFunctionArraysOverlap() override = default;
bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo &) const override { return true; }
size_t getNumberOfArguments() const override { return 2; }
String getName() const override { return name; }
bool useDefaultImplementationForNulls() const override { return false; }
bool useDefaultImplementationForConstants() const override { return false; }

DB::DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName &) const override
{
auto data_type = std::make_shared<DataTypeUInt8>();
return makeNullable(data_type);
}

ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t input_rows_count) const override
{
if (arguments.size() != 2)
throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} must have 2 arguments", getName());

auto res = ColumnUInt8::create(input_rows_count, 0);
auto null_map = ColumnUInt8::create(input_rows_count, 0);
PaddedPODArray<UInt8> & res_data = res->getData();
PaddedPODArray<UInt8> & null_map_data = null_map->getData();
if (input_rows_count == 0)
return ColumnNullable::create(std::move(res), std::move(null_map));

const ColumnArray * array_col_1 = nullptr, * array_col_2 = nullptr;
const ColumnConst * const_col_1 = checkAndGetColumn<ColumnConst>(arguments[0].column.get());
const ColumnConst * const_col_2 = checkAndGetColumn<ColumnConst>(arguments[1].column.get());
if (const_col_1)
array_col_1 = checkAndGetColumn<ColumnArray>(const_col_1->getDataColumnPtr().get());
else
{
const auto * null_col_1 = checkAndGetColumn<ColumnNullable>(arguments[0].column.get());
array_col_1 = checkAndGetColumn<ColumnArray>(null_col_1->getNestedColumnPtr().get());
}
if (const_col_2)
array_col_2 = checkAndGetColumn<ColumnArray>(const_col_2->getDataColumnPtr().get());
else
{
const auto * null_col_2 = checkAndGetColumn<ColumnNullable>(arguments[1].column.get());
array_col_2 = checkAndGetColumn<ColumnArray>(null_col_2->getNestedColumnPtr().get());
}
if (!array_col_1 || !array_col_2)
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Function {} 1st/2nd argument must be array type", getName());

const ColumnArray::Offsets & array_offsets_1 = array_col_1->getOffsets();
const ColumnArray::Offsets & array_offsets_2 = array_col_2->getOffsets();

size_t current_offset_1 = 0, current_offset_2 = 0;
for (size_t i = 0; i < array_col_1->size(); ++i)
{
if (arguments[0].column->isNullAt(i) || arguments[1].column->isNullAt(i))
{
null_map_data[i] = 1;
continue;
}
size_t array_size_1 = array_offsets_1[i] - current_offset_1;
size_t array_size_2 = array_offsets_2[i] - current_offset_2;
bool has_null_equals = false;
auto executeCompare = [&](const IColumn & col1, const IColumn & col2) -> void
{
for (size_t j = 0; j < array_size_1 && !res_data[i]; ++j)
{
for (size_t k = 0; k < array_size_2; ++k)
{
if (col1.compareAt(j, k, col2, -1) == 0)
{
if (!col1.isNullAt(j))
{
res_data[i] = 1;
break;
}
else
has_null_equals = true;
}
}
}
};
if (array_col_1->getData().getDataType() == array_col_2->getData().getDataType())
{
executeCompare(array_col_1->getData(), array_col_2->getData());
}
else if (array_col_1->getData().isNullable() || array_col_2->getData().isNullable())
{
if (array_col_1->getData().isNullable())
{
const ColumnNullable * array_null_col_1 = assert_cast<const ColumnNullable *>(&array_col_1->getData());
executeCompare(array_null_col_1->getNestedColumn(), array_col_2->getData());
}
if (array_col_2->getData().isNullable())
{
const ColumnNullable * array_null_col_2 = assert_cast<const ColumnNullable *>(&array_col_2->getData());
executeCompare(array_col_1->getData(), array_null_col_2->getNestedColumn());
}
}
if (!res_data[i] && has_null_equals)
null_map_data[i] = 1;
current_offset_1 = array_offsets_1[i];
current_offset_2 = array_offsets_2[i];
}
return ColumnNullable::create(std::move(res), std::move(null_map));
}
};

REGISTER_FUNCTION(SparkArraysOverlap)
{
factory.registerFunction<SparkFunctionArraysOverlap>();
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@ REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Shuffle, shuffle, arrayShuffle);
REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Range, range, range);
REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Flatten, flatten, sparkArrayFlatten);
REGISTER_COMMON_SCALAR_FUNCTION_PARSER(ArrayJoin, array_join, sparkArrayJoin);
REGISTER_COMMON_SCALAR_FUNCTION_PARSER(ArraysOverlap, arrays_overlap, sparkArraysOverlap);

// map functions
REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Map, map, map);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -664,7 +664,6 @@ class ClickHouseTestSettings extends BackendTestSettings {
.exclude("MapEntries")
.exclude("Map Concat")
.exclude("MapFromEntries")
.exclude("ArraysOverlap")
.exclude("ArraysZip")
.exclude("Sequence of numbers")
.exclude("Sequence of timestamps")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -655,7 +655,6 @@ class ClickHouseTestSettings extends BackendTestSettings {
.exclude("MapEntries")
.exclude("Map Concat")
.exclude("MapFromEntries")
.exclude("ArraysOverlap")
.exclude("ArraysZip")
.exclude("Sequence of numbers")
.exclude("Sequence of timestamps")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -543,7 +543,6 @@ class ClickHouseTestSettings extends BackendTestSettings {
.exclude("MapEntries")
.exclude("Map Concat")
.exclude("MapFromEntries")
.exclude("ArraysOverlap")
.exclude("ArraysZip")
.exclude("Sequence of numbers")
.exclude("Sequence of timestamps")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -543,7 +543,6 @@ class ClickHouseTestSettings extends BackendTestSettings {
.exclude("MapEntries")
.exclude("Map Concat")
.exclude("MapFromEntries")
.exclude("ArraysOverlap")
.exclude("ArraysZip")
.exclude("Sequence of numbers")
.exclude("Sequence of timestamps")
Expand Down

0 comments on commit f123ff5

Please sign in to comment.