diff --git a/velox/functions/prestosql/aggregates/ApproxDistinctAggregate.cpp b/velox/functions/prestosql/aggregates/ApproxDistinctAggregate.cpp index 52ddbf628eaf5..6728a552c3463 100644 --- a/velox/functions/prestosql/aggregates/ApproxDistinctAggregate.cpp +++ b/velox/functions/prestosql/aggregates/ApproxDistinctAggregate.cpp @@ -453,19 +453,21 @@ exec::AggregateRegistrationResult registerApproxDistinct( .argumentType("hyperloglog") .build()); } else { - for (const auto& inputType : - {"boolean", - "tinyint", - "smallint", - "integer", - "bigint", - "hugeint", - "real", - "double", - "varchar", - "varbinary", - "timestamp", - "date"}) { + for (const auto& inputType : { + "boolean", + "tinyint", + "smallint", + "integer", + "bigint", + "hugeint", + "real", + "double", + "varchar", + "varbinary", + "timestamp", + "date", + "unknown", + }) { signatures.push_back(exec::AggregateFunctionSignatureBuilder() .returnType(returnType) .intermediateType("varbinary") @@ -505,6 +507,10 @@ exec::AggregateRegistrationResult registerApproxDistinct( const TypePtr& resultType, const core::QueryConfig& /*config*/) -> std::unique_ptr { + if (argTypes[0]->isUnKnown()) { + return std::make_unique>( + resultType, hllAsFinalResult, hllAsRawInput, defaultError); + } return VELOX_DYNAMIC_SCALAR_TYPE_DISPATCH( createApproxDistinct, argTypes[0]->kind(), diff --git a/velox/functions/prestosql/aggregates/tests/ApproxDistinctTest.cpp b/velox/functions/prestosql/aggregates/tests/ApproxDistinctTest.cpp index 47c3ac0f8cae4..cb5d3d9bfb783 100644 --- a/velox/functions/prestosql/aggregates/tests/ApproxDistinctTest.cpp +++ b/velox/functions/prestosql/aggregates/tests/ApproxDistinctTest.cpp @@ -432,5 +432,44 @@ TEST_F(ApproxDistinctTest, toIntermediate) { digests, {"c0"}, {"merge(a0)"}, {"c0", "cardinality(a0)"}, {input}); } +TEST_F(ApproxDistinctTest, unknownType) { + constexpr int kSize = 10; + auto input = makeRowVector({ + makeFlatVector(kSize, [](auto i) { return i % 2; }), + makeAllNullFlatVector(kSize), + }); + testAggregations( + {input}, + {}, + {"approx_distinct(c1)", "approx_distinct(c1, 0.023)"}, + {makeRowVector(std::vector(2, makeConstant(0, 1)))}); + testAggregations( + {input}, + {}, + {"approx_set(c1)", "approx_set(c1, 0.01625)"}, + {"cardinality(a0)", "cardinality(a1)"}, + {makeRowVector( + std::vector(2, makeNullConstant(TypeKind::BIGINT, 1)))}); + testAggregations( + {input}, + {"c0"}, + {"approx_distinct(c1)", "approx_distinct(c1, 0.023)"}, + {makeRowVector({ + makeFlatVector({0, 1}), + makeFlatVector({0, 0}), + makeFlatVector({0, 0}), + })}); + testAggregations( + {input}, + {"c0"}, + {"approx_set(c1)", "approx_set(c1, 0.01625)"}, + {"c0", "cardinality(a0)", "cardinality(a1)"}, + {makeRowVector({ + makeFlatVector({0, 1}), + makeNullConstant(TypeKind::BIGINT, 2), + makeNullConstant(TypeKind::BIGINT, 2), + })}); +} + } // namespace } // namespace facebook::velox::aggregate::test