Skip to content

Commit

Permalink
Add single parameter (array<array<T>>) support for array_intersect (f…
Browse files Browse the repository at this point in the history
…acebookincubator#11305)

Summary: Pull Request resolved: facebookincubator#11305

Reviewed By: kevinwilfong

Differential Revision: D64698174

Pulled By: kewang1024

fbshipit-source-id: 6a9cefb24565c004fad956d8ba265172723fb29c
  • Loading branch information
kewang1024 authored and athmaja-n committed Jan 10, 2025
1 parent 1338994 commit 536ffe2
Show file tree
Hide file tree
Showing 2 changed files with 403 additions and 5 deletions.
173 changes: 168 additions & 5 deletions velox/functions/prestosql/ArrayIntersectExcept.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -138,19 +138,146 @@ void generateSet(
DecodedVector* decodeArrayElements(
exec::LocalDecodedVector& arrayDecoder,
exec::LocalDecodedVector& elementsDecoder,
const SelectivityVector& rows) {
const SelectivityVector& rows,
SelectivityVector* elementRows) {
auto decodedVector = arrayDecoder.get();
auto baseArrayVector = arrayDecoder->base()->as<ArrayVector>();

// Decode and acquire array elements vector.
auto elementsVector = baseArrayVector->elements();
auto elementsSelectivityRows = toElementRows(
*elementRows = toElementRows(
elementsVector->size(), rows, baseArrayVector, decodedVector->indices());
elementsDecoder.get()->decode(*elementsVector, elementsSelectivityRows);
elementsDecoder.get()->decode(*elementsVector, *elementRows);
auto decodedElementsVector = elementsDecoder.get();
return decodedElementsVector;
}

DecodedVector* decodeArrayElements(
exec::LocalDecodedVector& arrayDecoder,
exec::LocalDecodedVector& elementsDecoder,
const SelectivityVector& rows) {
SelectivityVector elementRows;
return decodeArrayElements(arrayDecoder, elementsDecoder, rows, &elementRows);
}

template <typename T>
class ArraysIntersectSingleParam : public exec::VectorFunction {
public:
/// This class is used for array_intersect function with single parameter.
void apply(
const SelectivityVector& rows,
std::vector<VectorPtr>& args,
const TypePtr& outputType,
exec::EvalCtx& context,
VectorPtr& result) const override {
memory::MemoryPool* pool = context.pool();

exec::LocalDecodedVector outerArrayDecoder(context, *args[0], rows);
auto decodedOuterArray = outerArrayDecoder.get();
auto outerArray = decodedOuterArray->base()->as<ArrayVector>();

exec::LocalDecodedVector innerArrayDecoder(context);
SelectivityVector innerRows;
auto decodedInnerArray = decodeArrayElements(
outerArrayDecoder, innerArrayDecoder, rows, &innerRows);
auto innerArray = decodedInnerArray->base()->as<ArrayVector>();

exec::LocalDecodedVector elementDecoder(context);
SelectivityVector elementRows;
auto decodedInnerElement = decodeArrayElements(
innerArrayDecoder, elementDecoder, innerRows, &elementRows);

const auto elementCount =
countElements<ArrayVector>(innerRows, *decodedInnerArray);
const auto rowCount = args[0]->size();

// Allocate new vectors for indices, nulls, length and offsets.
BufferPtr newIndices = allocateIndices(elementCount, pool);
BufferPtr newElementNulls =
AlignedBuffer::allocate<bool>(elementCount, pool, bits::kNotNull);
BufferPtr newLengths = allocateSizes(rowCount, pool);
BufferPtr newOffsets = allocateOffsets(rowCount, pool);
BufferPtr newNulls = allocateNulls(rowCount, pool);

// Pointers and cursors to the raw data.
auto rawNewIndices = newIndices->asMutable<vector_size_t>();
auto rawNewElementNulls = newElementNulls->asMutable<uint64_t>();
auto rawNewOffsets = newOffsets->asMutable<vector_size_t>();
auto rawNewLengths = newLengths->asMutable<vector_size_t>();
auto rawNewNulls = newNulls->asMutable<uint64_t>();
auto indicesCursor = 0;

rows.applyToSelected([&](vector_size_t row) {
rawNewOffsets[row] = indicesCursor;
std::optional<vector_size_t> finalNullIndex;
SetWithNull<T> finalSet;

auto idx = decodedOuterArray->index(row);
auto offset = outerArray->offsetAt(idx);
auto size = outerArray->sizeAt(idx);
bool setInitialized = false;
for (auto i = offset; i < (offset + size); ++i) {
auto innerIdx = decodedInnerArray->index(i);
auto innerOffset = innerArray->offsetAt(innerIdx);
auto innerSize = innerArray->sizeAt(innerIdx);

// 1. prepare for next iteration
indicesCursor = rawNewOffsets[row];
SetWithNull<T> intermediateSet;
std::optional<vector_size_t> intermediateNullIndex;

// 2. Null array
if (decodedInnerArray->isNullAt(i)) {
bits::setNull(rawNewNulls, row, true);
finalNullIndex = intermediateNullIndex;
rawNewLengths[row] = indicesCursor - rawNewOffsets[row];
break;
}

// 3. Regular array
for (auto j = innerOffset; j < (innerOffset + innerSize); ++j) {
// null element
if (decodedInnerElement->isNullAt(j)) {
if ((!setInitialized || finalSet.hasNull) &&
!intermediateNullIndex.has_value()) {
intermediateSet.hasNull = true;
intermediateNullIndex = std::optional(indicesCursor++);
}
continue;
}
// regular element
if (!setInitialized || finalSet.count(decodedInnerElement, j)) {
auto success = intermediateSet.insert(decodedInnerElement, j);
if (success) {
rawNewIndices[indicesCursor++] = j;
}
}
}
setInitialized = true;
finalSet = intermediateSet;
finalNullIndex = intermediateNullIndex;
rawNewLengths[row] = indicesCursor - rawNewOffsets[row];
}

if (finalNullIndex.has_value()) {
bits::setNull(rawNewElementNulls, finalNullIndex.value(), true);
}
});

auto newElements = BaseVector::wrapInDictionary(
newElementNulls, newIndices, indicesCursor, innerArray->elements());
auto resultArray = std::make_shared<ArrayVector>(
pool,
outputType,
std::move(newNulls),
rowCount,
newOffsets,
newLengths,
newElements);
context.moveOrCopyResult(resultArray, rows, result);
}
};

// See documentation at https://prestodb.io/docs/current/functions/array.html
template <bool isIntersect, typename T>
class ArrayIntersectExceptFunction : public exec::VectorFunction {
Expand Down Expand Up @@ -414,7 +541,7 @@ class ArraysOverlapFunction : public exec::VectorFunction {
void validateMatchingArrayTypes(
const std::vector<exec::VectorFunctionArg>& inputArgs,
const std::string& name,
vector_size_t expectedArgCount) {
size_t expectedArgCount) {
VELOX_USER_CHECK_EQ(
inputArgs.size(),
expectedArgCount,
Expand Down Expand Up @@ -504,10 +631,30 @@ std::shared_ptr<exec::VectorFunction> createTypedArraysIntersectExcept(
}
}

template <TypeKind kind>
std::shared_ptr<exec::VectorFunction> createArraysIntersectSingleParam(
const TypePtr& elementType) {
if (elementType->providesCustomComparison()) {
return std::make_shared<ArraysIntersectSingleParam<WrappedVectorEntry>>();
} else {
using T = std::conditional_t<
TypeTraits<kind>::isPrimitiveType,
typename TypeTraits<kind>::NativeType,
WrappedVectorEntry>;
return std::make_shared<ArraysIntersectSingleParam<T>>();
}
}

std::shared_ptr<exec::VectorFunction> createArrayIntersect(
const std::string& name,
const std::vector<exec::VectorFunctionArg>& inputArgs,
const core::QueryConfig& /*config*/) {
if (inputArgs.size() == 1) {
auto elementType = inputArgs.front().type->childAt(0)->childAt(0);
return VELOX_DYNAMIC_TYPE_DISPATCH(
createArraysIntersectSingleParam, elementType->kind(), elementType);
}

validateMatchingArrayTypes(inputArgs, name, 2);
auto elementType = inputArgs.front().type->childAt(0);

Expand All @@ -534,6 +681,22 @@ std::shared_ptr<exec::VectorFunction> createArrayExcept(
elementType);
}

std::vector<std::shared_ptr<exec::FunctionSignature>>
arrayIntersectSignatures() {
return std::vector<std::shared_ptr<exec::FunctionSignature>>{
exec::FunctionSignatureBuilder()
.typeVariable("T")
.returnType("array(T)")
.argumentType("array(T)")
.argumentType("array(T)")
.build(),
exec::FunctionSignatureBuilder()
.typeVariable("T")
.returnType("array(T)")
.argumentType("array(array(T))")
.build()};
}

std::vector<std::shared_ptr<exec::FunctionSignature>> signatures(
const std::string& returnType) {
return std::vector<std::shared_ptr<exec::FunctionSignature>>{
Expand Down Expand Up @@ -600,7 +763,7 @@ VELOX_DECLARE_STATEFUL_VECTOR_FUNCTION(

VELOX_DECLARE_STATEFUL_VECTOR_FUNCTION(
udf_array_intersect,
signatures("array(T)"),
arrayIntersectSignatures(),
createArrayIntersect);

VELOX_DECLARE_STATEFUL_VECTOR_FUNCTION(
Expand Down
Loading

0 comments on commit 536ffe2

Please sign in to comment.