Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add single parameter (array<array<T>>) support for array_intersect #11305

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
173 changes: 168 additions & 5 deletions velox/functions/prestosql/ArrayIntersectExcept.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -138,19 +138,146 @@ void generateSet(
/// Decodes the elements vector of the array held by 'arrayDecoder' and
/// reports which element positions were selected.
///
/// @param arrayDecoder Decoder that already holds the decoded array vector;
/// its base vector is read as an ArrayVector.
/// @param elementsDecoder Decoder that receives the array's elements vector.
/// @param rows Array-level rows whose elements should be decoded.
/// @param elementRows Output. Overwritten with the element-level selection
/// computed by toElementRows() for 'rows'. It is dereferenced
/// unconditionally, so it must not be null.
/// @return The decoded elements vector, which is the one held by
/// 'elementsDecoder'.
DecodedVector* decodeArrayElements(
    exec::LocalDecodedVector& arrayDecoder,
    exec::LocalDecodedVector& elementsDecoder,
    const SelectivityVector& rows,
    SelectivityVector* elementRows) {
  auto decodedVector = arrayDecoder.get();
  auto baseArrayVector = arrayDecoder->base()->as<ArrayVector>();

  // Decode and acquire array elements vector.
  auto elementsVector = baseArrayVector->elements();
  *elementRows = toElementRows(
      elementsVector->size(), rows, baseArrayVector, decodedVector->indices());
  elementsDecoder.get()->decode(*elementsVector, *elementRows);
  auto decodedElementsVector = elementsDecoder.get();
  return decodedElementsVector;
}

/// Convenience overload for callers that only need the decoded elements and
/// have no use for the element-level selection.
///
/// @param arrayDecoder Decoder that already holds the decoded array vector.
/// @param elementsDecoder Decoder that receives the array's elements vector.
/// @param rows Array-level rows whose elements should be decoded.
/// @return The decoded elements vector produced by the four-argument overload.
DecodedVector* decodeArrayElements(
    exec::LocalDecodedVector& arrayDecoder,
    exec::LocalDecodedVector& elementsDecoder,
    const SelectivityVector& rows) {
  // The element selection is computed by the callee but discarded here.
  SelectivityVector discardedElementRows;
  auto* decodedElements = decodeArrayElements(
      arrayDecoder, elementsDecoder, rows, &discardedElementRows);
  return decodedElements;
}

/// Implements array_intersect for a single array(array(T)) argument: each
/// output row holds the distinct elements that appear in every inner array of
/// the corresponding input row.
///
/// Behavior, as implemented in apply():
///  - a null inner array makes the whole output row null;
///  - a null element is emitted (once) only if every inner array of the row
///    contains a null element;
///  - an outer array with no inner arrays produces an empty array.
///
/// The output elements are a dictionary over the innermost elements vector,
/// so element values themselves are not copied.
template <typename T>
class ArraysIntersectSingleParam : public exec::VectorFunction {
 public:
  /// @param rows Rows to evaluate.
  /// @param args Function arguments; only args[0], the array(array(T))
  /// input, is read.
  /// @param outputType Type given to the result ArrayVector.
  /// @param context Evaluation context; supplies the memory pool and receives
  /// the result.
  /// @param result Output vector, filled through
  /// context.moveOrCopyResult().
  void apply(
      const SelectivityVector& rows,
      std::vector<VectorPtr>& args,
      const TypePtr& outputType,
      exec::EvalCtx& context,
      VectorPtr& result) const override {
    memory::MemoryPool* pool = context.pool();

    // Level 1: the outer array(array(T)) argument.
    exec::LocalDecodedVector outerArrayDecoder(context, *args[0], rows);
    auto decodedOuterArray = outerArrayDecoder.get();
    auto outerArray = decodedOuterArray->base()->as<ArrayVector>();

    // Level 2: the inner arrays, i.e. the elements of the outer array.
    exec::LocalDecodedVector innerArrayDecoder(context);
    SelectivityVector innerRows;
    auto decodedInnerArray = decodeArrayElements(
        outerArrayDecoder, innerArrayDecoder, rows, &innerRows);
    auto innerArray = decodedInnerArray->base()->as<ArrayVector>();

    // Level 3: the elements of the inner arrays.
    exec::LocalDecodedVector elementDecoder(context);
    SelectivityVector elementRows;
    auto decodedInnerElement = decodeArrayElements(
        innerArrayDecoder, elementDecoder, innerRows, &elementRows);

    const auto elementCount =
        countElements<ArrayVector>(innerRows, *decodedInnerArray);
    const auto rowCount = args[0]->size();

    // Allocate new vectors for indices, nulls, length and offsets.
    BufferPtr newIndices = allocateIndices(elementCount, pool);
    BufferPtr newElementNulls =
        AlignedBuffer::allocate<bool>(elementCount, pool, bits::kNotNull);
    BufferPtr newLengths = allocateSizes(rowCount, pool);
    BufferPtr newOffsets = allocateOffsets(rowCount, pool);
    BufferPtr newNulls = allocateNulls(rowCount, pool);

    // Pointers and cursors to the raw data.
    auto rawNewIndices = newIndices->asMutable<vector_size_t>();
    auto rawNewElementNulls = newElementNulls->asMutable<uint64_t>();
    auto rawNewOffsets = newOffsets->asMutable<vector_size_t>();
    auto rawNewLengths = newLengths->asMutable<vector_size_t>();
    auto rawNewNulls = newNulls->asMutable<uint64_t>();
    vector_size_t indicesCursor = 0;

    rows.applyToSelected([&](vector_size_t row) {
      const vector_size_t rowStart = indicesCursor;
      rawNewOffsets[row] = rowStart;
      // Written explicitly so an outer array with no inner arrays yields a
      // zero-length result regardless of how the buffer was initialized.
      rawNewLengths[row] = 0;

      // Intersection of the inner arrays processed so far, and the output
      // slot (if any) that holds its null element.
      SetWithNull<T> finalSet;
      std::optional<vector_size_t> finalNullIndex;
      bool setInitialized = false;

      const auto idx = decodedOuterArray->index(row);
      const auto offset = outerArray->offsetAt(idx);
      const auto size = outerArray->sizeAt(idx);
      for (auto i = offset; i < (offset + size); ++i) {
        // Every pass rewrites this row's output slots from the beginning;
        // only the slots written by the last pass are kept.
        indicesCursor = rowStart;

        // A null inner array makes the whole row null. Check this before
        // reading the inner array's offset and size.
        if (decodedInnerArray->isNullAt(i)) {
          bits::setNull(rawNewNulls, row, true);
          finalNullIndex.reset();
          rawNewLengths[row] = 0;
          break;
        }

        const auto innerIdx = decodedInnerArray->index(i);
        const auto innerOffset = innerArray->offsetAt(innerIdx);
        const auto innerSize = innerArray->sizeAt(innerIdx);

        // Elements of this inner array that are also in 'finalSet' (or all
        // of its distinct elements when this is the first inner array).
        SetWithNull<T> intermediateSet;
        std::optional<vector_size_t> intermediateNullIndex;

        for (auto j = innerOffset; j < (innerOffset + innerSize); ++j) {
          if (decodedInnerElement->isNullAt(j)) {
            // Keep a null only if all previous inner arrays had one, and
            // reserve a single output slot for it.
            if ((!setInitialized || finalSet.hasNull) &&
                !intermediateNullIndex.has_value()) {
              intermediateSet.hasNull = true;
              intermediateNullIndex = indicesCursor++;
            }
            continue;
          }

          if (!setInitialized || finalSet.count(decodedInnerElement, j)) {
            // insert() reports whether the value is new, which de-duplicates
            // the output.
            if (intermediateSet.insert(decodedInnerElement, j)) {
              rawNewIndices[indicesCursor++] = j;
            }
          }
        }

        setInitialized = true;
        // 'intermediateSet' is not used again in this iteration, so hand its
        // storage over instead of copying the whole set.
        finalSet = std::move(intermediateSet);
        finalNullIndex = intermediateNullIndex;
        rawNewLengths[row] = indicesCursor - rowStart;
      }

      // The null bit is set only after the last pass because earlier passes
      // may have placed the null in a slot that was later overwritten.
      if (finalNullIndex.has_value()) {
        bits::setNull(rawNewElementNulls, finalNullIndex.value(), true);
      }
    });

    auto newElements = BaseVector::wrapInDictionary(
        newElementNulls, newIndices, indicesCursor, innerArray->elements());
    auto resultArray = std::make_shared<ArrayVector>(
        pool,
        outputType,
        std::move(newNulls),
        rowCount,
        newOffsets,
        newLengths,
        newElements);
    context.moveOrCopyResult(resultArray, rows, result);
  }
};

// See documentation at https://prestodb.io/docs/current/functions/array.html
template <bool isIntersect, typename T>
class ArrayIntersectExceptFunction : public exec::VectorFunction {
Expand Down Expand Up @@ -414,7 +541,7 @@ class ArraysOverlapFunction : public exec::VectorFunction {
void validateMatchingArrayTypes(
const std::vector<exec::VectorFunctionArg>& inputArgs,
const std::string& name,
vector_size_t expectedArgCount) {
size_t expectedArgCount) {
VELOX_USER_CHECK_EQ(
inputArgs.size(),
expectedArgCount,
Expand Down Expand Up @@ -504,10 +631,30 @@ std::shared_ptr<exec::VectorFunction> createTypedArraysIntersectExcept(
}
}

template <TypeKind kind>
std::shared_ptr<exec::VectorFunction> createArraysIntersectSingleParam(
const TypePtr& elementType) {
if (elementType->providesCustomComparison()) {
return std::make_shared<ArraysIntersectSingleParam<WrappedVectorEntry>>();
} else {
using T = std::conditional_t<
TypeTraits<kind>::isPrimitiveType,
typename TypeTraits<kind>::NativeType,
WrappedVectorEntry>;
return std::make_shared<ArraysIntersectSingleParam<T>>();
}
}

std::shared_ptr<exec::VectorFunction> createArrayIntersect(
const std::string& name,
const std::vector<exec::VectorFunctionArg>& inputArgs,
const core::QueryConfig& /*config*/) {
if (inputArgs.size() == 1) {
auto elementType = inputArgs.front().type->childAt(0)->childAt(0);
return VELOX_DYNAMIC_TYPE_DISPATCH(
createArraysIntersectSingleParam, elementType->kind(), elementType);
}

validateMatchingArrayTypes(inputArgs, name, 2);
auto elementType = inputArgs.front().type->childAt(0);

Expand All @@ -534,6 +681,22 @@ std::shared_ptr<exec::VectorFunction> createArrayExcept(
elementType);
}

/// Returns the signatures accepted by array_intersect, in this order:
///   (array(T), array(T)) -> array(T)
///   (array(array(T)))    -> array(T)
std::vector<std::shared_ptr<exec::FunctionSignature>>
arrayIntersectSignatures() {
  std::vector<std::shared_ptr<exec::FunctionSignature>> allSignatures;
  allSignatures.reserve(2);

  // Two-argument form: intersection of two arrays.
  allSignatures.push_back(exec::FunctionSignatureBuilder()
                              .typeVariable("T")
                              .returnType("array(T)")
                              .argumentType("array(T)")
                              .argumentType("array(T)")
                              .build());

  // Single-argument form: intersection of all arrays inside one array.
  allSignatures.push_back(exec::FunctionSignatureBuilder()
                              .typeVariable("T")
                              .returnType("array(T)")
                              .argumentType("array(array(T))")
                              .build());

  return allSignatures;
}

std::vector<std::shared_ptr<exec::FunctionSignature>> signatures(
const std::string& returnType) {
return std::vector<std::shared_ptr<exec::FunctionSignature>>{
Expand Down Expand Up @@ -600,7 +763,7 @@ VELOX_DECLARE_STATEFUL_VECTOR_FUNCTION(

// array_intersect is registered with arrayIntersectSignatures() so that, in
// addition to the two-argument form, it also accepts a single
// array(array(T)) argument.
VELOX_DECLARE_STATEFUL_VECTOR_FUNCTION(
    udf_array_intersect,
    arrayIntersectSignatures(),
    createArrayIntersect);

VELOX_DECLARE_STATEFUL_VECTOR_FUNCTION(
Expand Down
Loading
Loading