Skip to content

Commit

Permalink
Handle UNKNOWN type input in approx_distinct (#11367)
Browse files Browse the repository at this point in the history
Summary:
bypass-github-export-checks

Reviewed By: amitkdutta

Differential Revision: D65100134
  • Loading branch information
Yuhta authored and facebook-github-bot committed Oct 29, 2024
1 parent e67f11b commit 81d6842
Show file tree
Hide file tree
Showing 6 changed files with 144 additions and 13 deletions.
3 changes: 3 additions & 0 deletions velox/exec/DistinctAggregations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,9 @@ std::unique_ptr<DistinctAggregations> DistinctAggregations::create(
case TypeKind::ROW:
return std::make_unique<TypedDistinctAggregations<ComplexType>>(
aggregates, inputType, pool);
case TypeKind::UNKNOWN:
return std::make_unique<TypedDistinctAggregations<UnknownValue>>(
aggregates, inputType, pool);
default:
VELOX_UNREACHABLE("Unexpected type {}", type->toString());
}
Expand Down
62 changes: 62 additions & 0 deletions velox/exec/SetAccumulator.h
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,62 @@ struct ComplexTypeSetAccumulator {
}
};

class UnknownTypeSetAccumulator {
public:
UnknownTypeSetAccumulator(
const TypePtr& /*type*/,
HashStringAllocator* /*allocator*/) {}

void addValue(
const DecodedVector& decoded,
vector_size_t index,
HashStringAllocator* /*allocator*/) {
VELOX_DCHECK(decoded.isNullAt(index));
hasNull_ = true;
}

void addValues(
const ArrayVector& arrayVector,
vector_size_t index,
const DecodedVector& values,
HashStringAllocator* allocator) {
VELOX_DCHECK(!arrayVector.isNullAt(index));
const auto size = arrayVector.sizeAt(index);
const auto offset = arrayVector.offsetAt(index);
for (auto i = 0; i < size; ++i) {
addValue(values, offset + i, allocator);
}
}

void addNonNullValue(
const DecodedVector& /*decoded*/,
vector_size_t /*index*/,
HashStringAllocator* /*allocator*/) {}

void addNonNullValues(
const ArrayVector& /*arrayVector*/,
vector_size_t /*index*/,
const DecodedVector& /*values*/,
HashStringAllocator* /*/allocator*/) {}

size_t size() const {
return hasNull_ ? 1 : 0;
}

vector_size_t extractValues(BaseVector& values, vector_size_t offset) {
if (!hasNull_) {
return 0;
}
values.setNull(offset, true);
return 1;
}

void free(HashStringAllocator& /*allocator*/) {}

private:
bool hasNull_ = false;
};

template <typename T>
struct SetAccumulatorTypeTraits {
using AccumulatorType = SetAccumulator<T>;
Expand Down Expand Up @@ -369,6 +425,12 @@ template <>
struct SetAccumulatorTypeTraits<ComplexType> {
using AccumulatorType = ComplexTypeSetAccumulator;
};

template <>
struct SetAccumulatorTypeTraits<UnknownValue> {
using AccumulatorType = UnknownTypeSetAccumulator;
};

} // namespace detail

// A wrapper around SetAccumulator that overrides hash and equal_to functions to
Expand Down
32 changes: 19 additions & 13 deletions velox/functions/prestosql/aggregates/ApproxDistinctAggregate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -453,19 +453,21 @@ exec::AggregateRegistrationResult registerApproxDistinct(
.argumentType("hyperloglog")
.build());
} else {
for (const auto& inputType :
{"boolean",
"tinyint",
"smallint",
"integer",
"bigint",
"hugeint",
"real",
"double",
"varchar",
"varbinary",
"timestamp",
"date"}) {
for (const auto& inputType : {
"boolean",
"tinyint",
"smallint",
"integer",
"bigint",
"hugeint",
"real",
"double",
"varchar",
"varbinary",
"timestamp",
"date",
"unknown",
}) {
signatures.push_back(exec::AggregateFunctionSignatureBuilder()
.returnType(returnType)
.intermediateType("varbinary")
Expand Down Expand Up @@ -505,6 +507,10 @@ exec::AggregateRegistrationResult registerApproxDistinct(
const TypePtr& resultType,
const core::QueryConfig& /*config*/)
-> std::unique_ptr<exec::Aggregate> {
if (argTypes[0]->isUnKnown()) {
return std::make_unique<ApproxDistinctAggregate<UnknownValue>>(
resultType, hllAsFinalResult, hllAsRawInput, defaultError);
}
return VELOX_DYNAMIC_SCALAR_TYPE_DISPATCH(
createApproxDistinct,
argTypes[0]->kind(),
Expand Down
3 changes: 3 additions & 0 deletions velox/functions/prestosql/aggregates/SetAggregates.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -419,6 +419,9 @@ void registerCountDistinctAggregate(
case TypeKind::ROW:
return std::make_unique<CountDistinctAggregate<ComplexType>>(
resultType, argTypes[0]);
case TypeKind::UNKNOWN:
return std::make_unique<CountDistinctAggregate<UnknownValue>>(
resultType, argTypes[0]);
default:
VELOX_UNREACHABLE(
"Unexpected type {}", mapTypeKindToName(typeKind));
Expand Down
39 changes: 39 additions & 0 deletions velox/functions/prestosql/aggregates/tests/ApproxDistinctTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -432,5 +432,44 @@ TEST_F(ApproxDistinctTest, toIntermediate) {
digests, {"c0"}, {"merge(a0)"}, {"c0", "cardinality(a0)"}, {input});
}

TEST_F(ApproxDistinctTest, unknownType) {
constexpr int kSize = 10;
auto input = makeRowVector({
makeFlatVector<int32_t>(kSize, [](auto i) { return i % 2; }),
makeAllNullFlatVector<UnknownValue>(kSize),
});
testAggregations(
{input},
{},
{"approx_distinct(c1)", "approx_distinct(c1, 0.023)"},
{makeRowVector(std::vector<VectorPtr>(2, makeConstant<int64_t>(0, 1)))});
testAggregations(
{input},
{},
{"approx_set(c1)", "approx_set(c1, 0.01625)"},
{"cardinality(a0)", "cardinality(a1)"},
{makeRowVector(
std::vector<VectorPtr>(2, makeNullConstant(TypeKind::BIGINT, 1)))});
testAggregations(
{input},
{"c0"},
{"approx_distinct(c1)", "approx_distinct(c1, 0.023)"},
{makeRowVector({
makeFlatVector<int32_t>({0, 1}),
makeFlatVector<int64_t>({0, 0}),
makeFlatVector<int64_t>({0, 0}),
})});
testAggregations(
{input},
{"c0"},
{"approx_set(c1)", "approx_set(c1, 0.01625)"},
{"c0", "cardinality(a0)", "cardinality(a1)"},
{makeRowVector({
makeFlatVector<int32_t>({0, 1}),
makeNullConstant(TypeKind::BIGINT, 2),
makeNullConstant(TypeKind::BIGINT, 2),
})});
}

} // namespace
} // namespace facebook::velox::aggregate::test
Original file line number Diff line number Diff line change
Expand Up @@ -373,5 +373,23 @@ TEST_F(CountAggregationTest, timestampWithTimeZone) {
testSingleAggregation({"c0"}, "c2", data, expected);
}

TEST_F(CountAggregationTest, unknownType) {
constexpr int kSize = 10;
auto input = makeRowVector({
makeFlatVector<int32_t>(kSize, [](auto i) { return i % 2; }),
makeAllNullFlatVector<UnknownValue>(kSize),
});
testGlobalAggregation(
"c1", input, makeRowVector({makeConstant<int64_t>(0, 1)}));
testSingleAggregation(
{"c0"},
"c1",
input,
makeRowVector({
makeFlatVector<int32_t>({0, 1}),
makeFlatVector<int64_t>({0, 0}),
}));
}

} // namespace
} // namespace facebook::velox::aggregate::test

0 comments on commit 81d6842

Please sign in to comment.