Skip to content

Commit

Permalink
fix array_join diff
Browse files: browse the repository at this point in the history
  • Loading branch information
KevinyhZou committed Aug 19, 2024
1 parent b6c1902 commit a69a591
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 13 deletions.
Original file line number | Diff line number | Diff line change
Expand Up @@ -181,7 +181,6 @@ object CHExpressionUtil {
)

final val CH_BLACKLIST_SCALAR_FUNCTION: Map[String, FunctionValidator] = Map(
ARRAY_JOIN -> ArrayJoinValidator(),
SPLIT_PART -> DefaultValidator(),
TO_UNIX_TIMESTAMP -> UnixTimeStampValidator(),
UNIX_TIMESTAMP -> UnixTimeStampValidator(),
Expand Down
44 changes: 32 additions & 12 deletions cpp-ch/local-engine/Functions/SparkFunctionArrayJoin.cpp
Original file line number | Diff line number | Diff line change
Expand Up @@ -21,6 +21,7 @@
#include <Functions/IFunction.h>
#include <Functions/FunctionFactory.h>
#include <DataTypes/DataTypeString.h>
#include <iostream>

using namespace DB;

Expand Down Expand Up @@ -54,23 +55,39 @@ class SparkFunctionArrayJoin : public IFunction
return makeNullable(data_type);
}

ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t) const override
ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t input_rows_count) const override
{
if (arguments.size() != 2 && arguments.size() != 3)
throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} must have 2 or 3 arguments", getName());

const auto * arg_null_col = checkAndGetColumn<ColumnNullable>(arguments[0].column.get());
const ColumnArray * array_col;
if (!arg_null_col)
array_col = checkAndGetColumn<ColumnArray>(arguments[0].column.get());
auto res_col = ColumnString::create();
auto null_col = ColumnUInt8::create(input_rows_count, 0);
PaddedPODArray<UInt8> & null_result = null_col->getData();
if (input_rows_count == 0)
return ColumnNullable::create(std::move(res_col), std::move(null_col));

const auto * arg_const_col = checkAndGetColumn<ColumnConst>(arguments[0].column.get());
const ColumnArray * array_col = nullptr;
const ColumnNullable * arg_null_col = nullptr;
if (arg_const_col)
{
if (arg_const_col->onlyNull())
{
null_result[0] = 1;
return ColumnNullable::create(std::move(res_col), std::move(null_col));
}
array_col = checkAndGetColumn<ColumnArray>(arg_const_col->getDataColumnPtr().get());
}
else
array_col = checkAndGetColumn<ColumnArray>(arg_null_col->getNestedColumnPtr().get());
{
arg_null_col = checkAndGetColumn<ColumnNullable>(arguments[0].column.get());
if (!arg_null_col)
array_col = checkAndGetColumn<ColumnArray>(arguments[0].column.get());
else
array_col = checkAndGetColumn<ColumnArray>(arg_null_col->getNestedColumnPtr().get());
}
if (!array_col)
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Function {} 1st argument must be array type", getName());

auto res_col = ColumnString::create();
auto null_col = ColumnUInt8::create(array_col->size(), 0);
PaddedPODArray<UInt8> & null_result = null_col->getData();
std::pair<bool, StringRef> delim_p, null_replacement_p;
bool return_result = false;
auto checkAndGetConstString = [&](const ColumnPtr & col) -> std::pair<bool, StringRef>
Expand Down Expand Up @@ -145,7 +162,7 @@ class SparkFunctionArrayJoin : public IFunction
}
}
};
if (arg_null_col->isNullAt(i))
if (arg_null_col && arg_null_col->isNullAt(i))
{
setResultNull();
continue;
Expand All @@ -166,9 +183,9 @@ class SparkFunctionArrayJoin : public IFunction
continue;
}
}

size_t array_size = array_offsets[i] - current_offset;
size_t data_pos = array_pos == 0 ? 0 : string_offsets[array_pos - 1];
size_t last_not_null_pos = 0;
for (size_t j = 0; j < array_size; ++j)
{
if (array_nested_col && array_nested_col->isNullAt(j + array_pos))
Expand All @@ -179,11 +196,14 @@ class SparkFunctionArrayJoin : public IFunction
if (j != array_size - 1)
res += delim.toString();
}
else if (j == array_size - 1)
res = res.substr(0, last_not_null_pos);
}
else
{
const StringRef s(&string_data[data_pos], string_offsets[j + array_pos] - data_pos - 1);
res += s.toString();
last_not_null_pos = res.size();
if (j != array_size - 1)
res += delim.toString();
}
Expand Down

0 comments on commit a69a591

Please sign in to comment.