From 81d6842920a502cc3b78c470d2620e26326a4419 Mon Sep 17 00:00:00 2001 From: Jimmy Lu Date: Tue, 29 Oct 2024 09:01:14 -0700 Subject: [PATCH] Handle UNKNOWN type input in approx_distinct (#11367) Summary: bypass-github-export-checks Reviewed By: amitkdutta Differential Revision: D65100134 --- velox/exec/DistinctAggregations.cpp | 3 + velox/exec/SetAccumulator.h | 62 +++++++++++++++++++ .../aggregates/ApproxDistinctAggregate.cpp | 32 ++++++---- .../prestosql/aggregates/SetAggregates.cpp | 3 + .../aggregates/tests/ApproxDistinctTest.cpp | 39 ++++++++++++ .../aggregates/tests/CountAggregationTest.cpp | 18 ++++++ 6 files changed, 144 insertions(+), 13 deletions(-) diff --git a/velox/exec/DistinctAggregations.cpp b/velox/exec/DistinctAggregations.cpp index 4167b4ef2d99d..f6f9db5a24547 100644 --- a/velox/exec/DistinctAggregations.cpp +++ b/velox/exec/DistinctAggregations.cpp @@ -281,6 +281,9 @@ std::unique_ptr DistinctAggregations::create( case TypeKind::ROW: return std::make_unique>( aggregates, inputType, pool); + case TypeKind::UNKNOWN: + return std::make_unique>( + aggregates, inputType, pool); default: VELOX_UNREACHABLE("Unexpected type {}", type->toString()); } diff --git a/velox/exec/SetAccumulator.h b/velox/exec/SetAccumulator.h index 9f0c05b951923..3783ff5d843ba 100644 --- a/velox/exec/SetAccumulator.h +++ b/velox/exec/SetAccumulator.h @@ -339,6 +339,62 @@ struct ComplexTypeSetAccumulator { } }; +class UnknownTypeSetAccumulator { + public: + UnknownTypeSetAccumulator( + const TypePtr& /*type*/, + HashStringAllocator* /*allocator*/) {} + + void addValue( + const DecodedVector& decoded, + vector_size_t index, + HashStringAllocator* /*allocator*/) { + VELOX_DCHECK(decoded.isNullAt(index)); + hasNull_ = true; + } + + void addValues( + const ArrayVector& arrayVector, + vector_size_t index, + const DecodedVector& values, + HashStringAllocator* allocator) { + VELOX_DCHECK(!arrayVector.isNullAt(index)); + const auto size = arrayVector.sizeAt(index); + const auto offset = arrayVector.offsetAt(index); + for (auto i = 0; i < size; ++i) { + addValue(values, offset + i, allocator); + } + } + + void addNonNullValue( + const DecodedVector& /*decoded*/, + vector_size_t /*index*/, + HashStringAllocator* /*allocator*/) {} + + void addNonNullValues( + const ArrayVector& /*arrayVector*/, + vector_size_t /*index*/, + const DecodedVector& /*values*/, + HashStringAllocator* /*/allocator*/) {} + + size_t size() const { + return hasNull_ ? 1 : 0; + } + + vector_size_t extractValues(BaseVector& values, vector_size_t offset) { + if (!hasNull_) { + return 0; + } + values.setNull(offset, true); + return 1; + } + + void free(HashStringAllocator& /*allocator*/) {} + + private: + bool hasNull_ = false; +}; + template struct SetAccumulatorTypeTraits { using AccumulatorType = SetAccumulator; @@ -369,6 +425,12 @@ template <> struct SetAccumulatorTypeTraits { using AccumulatorType = ComplexTypeSetAccumulator; }; + +template <> +struct SetAccumulatorTypeTraits { + using AccumulatorType = UnknownTypeSetAccumulator; +}; + } // namespace detail // A wrapper around SetAccumulator that overrides hash and equal_to functions to diff --git a/velox/functions/prestosql/aggregates/ApproxDistinctAggregate.cpp b/velox/functions/prestosql/aggregates/ApproxDistinctAggregate.cpp index 52ddbf628eaf5..6728a552c3463 100644 --- a/velox/functions/prestosql/aggregates/ApproxDistinctAggregate.cpp +++ b/velox/functions/prestosql/aggregates/ApproxDistinctAggregate.cpp @@ -453,19 +453,21 @@ exec::AggregateRegistrationResult registerApproxDistinct( .argumentType("hyperloglog") .build()); } else { - for (const auto& inputType : - {"boolean", - "tinyint", - "smallint", - "integer", - "bigint", - "hugeint", - "real", - "double", - "varchar", - "varbinary", - "timestamp", - "date"}) { + for (const auto& inputType : { + "boolean", + "tinyint", + "smallint", + "integer", + "bigint", + "hugeint", + "real", + "double", + "varchar", + "varbinary", + "timestamp", + "date", + "unknown", + }) { signatures.push_back(exec::AggregateFunctionSignatureBuilder() .returnType(returnType) .intermediateType("varbinary") @@ -505,6 +507,10 @@ exec::AggregateRegistrationResult registerApproxDistinct( const TypePtr& resultType, const core::QueryConfig& /*config*/) -> std::unique_ptr { + if (argTypes[0]->isUnKnown()) { + return std::make_unique>( + resultType, hllAsFinalResult, hllAsRawInput, defaultError); + } return VELOX_DYNAMIC_SCALAR_TYPE_DISPATCH( createApproxDistinct, argTypes[0]->kind(), diff --git a/velox/functions/prestosql/aggregates/SetAggregates.cpp b/velox/functions/prestosql/aggregates/SetAggregates.cpp index 4e2ff4c4715cb..09220bd0eeefc 100644 --- a/velox/functions/prestosql/aggregates/SetAggregates.cpp +++ b/velox/functions/prestosql/aggregates/SetAggregates.cpp @@ -419,6 +419,9 @@ void registerCountDistinctAggregate( case TypeKind::ROW: return std::make_unique>( resultType, argTypes[0]); + case TypeKind::UNKNOWN: + return std::make_unique>( + resultType, argTypes[0]); default: VELOX_UNREACHABLE( "Unexpected type {}", mapTypeKindToName(typeKind)); diff --git a/velox/functions/prestosql/aggregates/tests/ApproxDistinctTest.cpp b/velox/functions/prestosql/aggregates/tests/ApproxDistinctTest.cpp index 47c3ac0f8cae4..cb5d3d9bfb783 100644 --- a/velox/functions/prestosql/aggregates/tests/ApproxDistinctTest.cpp +++ b/velox/functions/prestosql/aggregates/tests/ApproxDistinctTest.cpp @@ -432,5 +432,44 @@ TEST_F(ApproxDistinctTest, toIntermediate) { digests, {"c0"}, {"merge(a0)"}, {"c0", "cardinality(a0)"}, {input}); } +TEST_F(ApproxDistinctTest, unknownType) { + constexpr int kSize = 10; + auto input = makeRowVector({ + makeFlatVector(kSize, [](auto i) { return i % 2; }), + makeAllNullFlatVector(kSize), + }); + testAggregations( + {input}, + {}, + {"approx_distinct(c1)", "approx_distinct(c1, 0.023)"}, + {makeRowVector(std::vector(2, makeConstant(0, 1)))}); + testAggregations( + {input}, + {}, + {"approx_set(c1)", "approx_set(c1, 0.01625)"}, + {"cardinality(a0)", "cardinality(a1)"}, + {makeRowVector( + std::vector(2, makeNullConstant(TypeKind::BIGINT, 1)))}); + testAggregations( + {input}, + {"c0"}, + {"approx_distinct(c1)", "approx_distinct(c1, 0.023)"}, + {makeRowVector({ + makeFlatVector({0, 1}), + makeFlatVector({0, 0}), + makeFlatVector({0, 0}), + })}); + testAggregations( + {input}, + {"c0"}, + {"approx_set(c1)", "approx_set(c1, 0.01625)"}, + {"c0", "cardinality(a0)", "cardinality(a1)"}, + {makeRowVector({ + makeFlatVector({0, 1}), + makeNullConstant(TypeKind::BIGINT, 2), + makeNullConstant(TypeKind::BIGINT, 2), + })}); +} + } // namespace } // namespace facebook::velox::aggregate::test diff --git a/velox/functions/prestosql/aggregates/tests/CountAggregationTest.cpp b/velox/functions/prestosql/aggregates/tests/CountAggregationTest.cpp index 0209b21770aa2..1b4ca57329cf1 100644 --- a/velox/functions/prestosql/aggregates/tests/CountAggregationTest.cpp +++ b/velox/functions/prestosql/aggregates/tests/CountAggregationTest.cpp @@ -373,5 +373,23 @@ TEST_F(CountAggregationTest, timestampWithTimeZone) { testSingleAggregation({"c0"}, "c2", data, expected); } +TEST_F(CountAggregationTest, unknownType) { + constexpr int kSize = 10; + auto input = makeRowVector({ + makeFlatVector(kSize, [](auto i) { return i % 2; }), + makeAllNullFlatVector(kSize), + }); + testGlobalAggregation( + "c1", input, makeRowVector({makeConstant(0, 1)})); + testSingleAggregation( + {"c0"}, + "c1", + input, + makeRowVector({ + makeFlatVector({0, 1}), + makeFlatVector({0, 0}), + })); +} + } // namespace } // namespace facebook::velox::aggregate::test