Skip to content

Commit

Permalink
Enable spilling for partial aggregation (7558)
Browse files Browse the repository at this point in the history
  • Loading branch information
JkSelf authored and zhztheplayer committed Dec 29, 2023
1 parent 22132d9 commit d6ed32e
Show file tree
Hide file tree
Showing 6 changed files with 110 additions and 17 deletions.
12 changes: 10 additions & 2 deletions velox/core/PlanNode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -233,8 +233,16 @@ bool AggregationNode::canSpill(const QueryConfig& queryConfig) const {
}
// TODO: add spilling for pre-grouped aggregation later:
// https://github.com/facebookincubator/velox/issues/3264
return (isFinal() || isSingle()) && preGroupedKeys().empty() &&
queryConfig.aggregationSpillEnabled();
if ((isFinal() || isSingle()) && queryConfig.aggregationSpillEnabled()) {
return preGroupedKeys().empty();
}

if ((isIntermediate() || isPartial()) &&
queryConfig.partialAggregationSpillEnabled()) {
return preGroupedKeys().empty();
}

return false;
}

void AggregationNode::addDetails(std::stringstream& stream) const {
Expand Down
8 changes: 8 additions & 0 deletions velox/core/PlanNode.h
Original file line number Diff line number Diff line change
Expand Up @@ -620,6 +620,14 @@ class AggregationNode : public PlanNode {
return step_ == Step::kSingle;
}

bool isIntermediate() const {
return step_ == Step::kIntermediate;
}

bool isPartial() const {
return step_ == Step::kPartial;
}

folly::dynamic serialize() const override;

static PlanNodePtr create(const folly::dynamic& obj, void* context);
Expand Down
13 changes: 12 additions & 1 deletion velox/core/QueryConfig.h
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,11 @@ class QueryConfig {
static constexpr const char* kAggregationSpillEnabled =
"aggregation_spill_enabled";

/// Partial aggregation spilling flag, only applies if "spill_enabled" flag is
/// set.
static constexpr const char* kPartialAggregationSpillEnabled =
"partial_aggregation_spill_enabled";

/// Join spilling flag, only applies if "spill_enabled" flag is set.
static constexpr const char* kJoinSpillEnabled = "join_spill_enabled";

Expand Down Expand Up @@ -542,11 +547,17 @@ class QueryConfig {
}

/// Returns 'is aggregation spilling enabled' flag. Must also check the
/// spillEnabled()!g
/// spillEnabled()!
bool aggregationSpillEnabled() const {
return get<bool>(kAggregationSpillEnabled, true);
}

/// Returns 'is partial aggregation spilling enabled' flag. Must also check
/// the spillEnabled()!
bool partialAggregationSpillEnabled() const {
return get<bool>(kPartialAggregationSpillEnabled, false);
}

/// Returns 'is join spilling enabled' flag. Must also check the
/// spillEnabled()!
bool joinSpillEnabled() const {
Expand Down
17 changes: 13 additions & 4 deletions velox/exec/GroupingSet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -702,6 +702,7 @@ bool GroupingSet::getOutput(
}

if (hasSpilled()) {
spill();
return getOutputWithSpill(maxOutputRows, maxOutputBytes, result);
}
VELOX_CHECK(!isDistinct());
Expand Down Expand Up @@ -803,7 +804,7 @@ const HashLookup& GroupingSet::hashLookup() const {
void GroupingSet::ensureInputFits(const RowVectorPtr& input) {
// Spilling is considered if this is a final or single aggregation and
// spillPath is set.
if (isPartial_ || spillConfig_ == nullptr) {
if (spillConfig_ == nullptr) {
return;
}

Expand Down Expand Up @@ -888,7 +889,7 @@ void GroupingSet::ensureOutputFits() {
// to reserve memory for the output as we can't reclaim much memory from this
// operator itself. The output processing can reclaim memory from the other
// operator or query through memory arbitration.
if (isPartial_ || spillConfig_ == nullptr || hasSpilled()) {
if (spillConfig_ == nullptr || hasSpilled()) {
return;
}

Expand Down Expand Up @@ -938,7 +939,6 @@ void GroupingSet::spill() {
if (table_ == nullptr || table_->numDistinct() == 0) {
return;
}

if (!hasSpilled()) {
auto rows = table_->rows();
VELOX_DCHECK(pool_.trackUsage());
Expand Down Expand Up @@ -1013,7 +1013,16 @@ bool GroupingSet::getOutputWithSpill(
if (merge_ == nullptr) {
return false;
}
return mergeNext(maxOutputRows, maxOutputBytes, result);
bool hasData = mergeNext(maxOutputRows, maxOutputBytes, result);
if (!hasData) {
// If spill has been finalized, reset merge stream and spiller. This would
// help partial aggregation replay the spilling procedure once needed again.
merge_ = nullptr;
mergeRows_ = nullptr;
mergeArgs_.clear();
spiller_ = nullptr;
}
return hasData;
}

bool GroupingSet::mergeNext(
Expand Down
44 changes: 44 additions & 0 deletions velox/exec/tests/SharedArbitratorTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1175,6 +1175,50 @@ TEST_F(SharedArbitrationTest, reclaimFromDistinctAggregation) {
waitForAllTasksToBeDeleted();
}

TEST_F(SharedArbitrationTest, reclaimFromPartialAggregation) {
const uint64_t maxQueryCapacity = 20L << 20;
std::vector<RowVectorPtr> vectors = newVectors(1024, maxQueryCapacity * 2);
createDuckDbTable(vectors);
const auto spillDirectory = exec::test::TempDirectoryPath::create();
core::PlanNodeId partialAggNodeId;
core::PlanNodeId finalAggNodeId;
std::shared_ptr<core::QueryCtx> queryCtx = newQueryCtx(maxQueryCapacity);
auto task =
AssertQueryBuilder(duckDbQueryRunner_)
.spillDirectory(spillDirectory->path)
.config(core::QueryConfig::kSpillEnabled, "true")
.config(core::QueryConfig::kPartialAggregationSpillEnabled, "true")
.config(core::QueryConfig::kAggregationSpillEnabled, "true")
.config(
core::QueryConfig::kMaxPartialAggregationMemory,
std::to_string(1LL << 30)) // disable flush
.config(
core::QueryConfig::kMaxExtendedPartialAggregationMemory,
std::to_string(1LL << 30)) // disable flush
.config(
core::QueryConfig::kAbandonPartialAggregationMinPct,
"200") // avoid abandoning
.config(
core::QueryConfig::kAbandonPartialAggregationMinRows,
std::to_string(1LL << 30)) // avoid abandoning
.queryCtx(queryCtx)
.plan(PlanBuilder()
.values(vectors)
.partialAggregation({"c0"}, {"count(1)"})
.capturePlanNodeId(partialAggNodeId)
.finalAggregation()
.capturePlanNodeId(finalAggNodeId)
.planNode())
.assertResults("SELECT c0, count(1) FROM tmp GROUP BY c0");
auto taskStats = exec::toPlanStats(task->taskStats());
auto& partialStats = taskStats.at(partialAggNodeId);
auto& finalStats = taskStats.at(finalAggNodeId);
ASSERT_GT(partialStats.spilledBytes, 0);
ASSERT_GT(finalStats.spilledBytes, 0);
task.reset();
waitForAllTasksToBeDeleted();
}

DEBUG_ONLY_TEST_F(SharedArbitrationTest, reclaimFromAggregationOnNoMoreInput) {
const int numVectors = 32;
std::vector<RowVectorPtr> vectors;
Expand Down
33 changes: 23 additions & 10 deletions velox/functions/prestosql/aggregates/ApproxPercentileAggregate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -639,6 +639,19 @@ class ApproxPercentileAggregate : public exec::Aggregate {
DecodedVector decodedDigest_;

private:
bool isConstantVector(const VectorPtr& vec) {
if (vec->isConstantEncoding()) {
return true;
}
VELOX_USER_CHECK(vec->size() > 0);
for (vector_size_t i = 1; i < vec->size(); ++i) {
if (!vec->equalValueAt(vec.get(), i, 0)) {
return false;
}
}
return true;
}

template <bool kSingleGroup, bool checkIntermediateInputs>
void addIntermediateImpl(
std::conditional_t<kSingleGroup, char*, char**> group,
Expand All @@ -650,7 +663,8 @@ class ApproxPercentileAggregate : public exec::Aggregate {
if constexpr (checkIntermediateInputs) {
VELOX_USER_CHECK(rowVec);
for (int i = kPercentiles; i <= kAccuracy; ++i) {
VELOX_USER_CHECK(rowVec->childAt(i)->isConstantEncoding());
VELOX_USER_CHECK(isConstantVector(
rowVec->childAt(i))); // spilling flats constant encoding
}
for (int i = kK; i <= kMaxValue; ++i) {
VELOX_USER_CHECK(rowVec->childAt(i)->isFlatEncoding());
Expand All @@ -677,10 +691,9 @@ class ApproxPercentileAggregate : public exec::Aggregate {
}

DecodedVector percentiles(*rowVec->childAt(kPercentiles), *baseRows);
auto percentileIsArray =
rowVec->childAt(kPercentilesIsArray)->asUnchecked<SimpleVector<bool>>();
auto accuracy =
rowVec->childAt(kAccuracy)->asUnchecked<SimpleVector<double>>();
DecodedVector percentileIsArray(
*rowVec->childAt(kPercentilesIsArray), *baseRows);
DecodedVector accuracy(*rowVec->childAt(kAccuracy), *baseRows);
auto k = rowVec->childAt(kK)->asUnchecked<SimpleVector<int32_t>>();
auto n = rowVec->childAt(kN)->asUnchecked<SimpleVector<int64_t>>();
auto minValue = rowVec->childAt(kMinValue)->asUnchecked<SimpleVector<T>>();
Expand Down Expand Up @@ -710,7 +723,7 @@ class ApproxPercentileAggregate : public exec::Aggregate {
return;
}
int i = decoded.index(row);
if (percentileIsArray->isNullAt(i)) {
if (percentileIsArray.isNullAt(i)) {
return;
}
if (!accumulator) {
Expand All @@ -720,19 +733,19 @@ class ApproxPercentileAggregate : public exec::Aggregate {
percentilesBase->elements()->asFlatVector<double>();
if constexpr (checkIntermediateInputs) {
VELOX_USER_CHECK(percentileBaseElements);
VELOX_USER_CHECK(!percentilesBase->isNullAt(indexInBaseVector));
VELOX_USER_CHECK(!percentiles.isNullAt(indexInBaseVector));
}

bool isArray = percentileIsArray->valueAt(i);
bool isArray = percentileIsArray.valueAt<bool>(i);
const double* data;
vector_size_t len;
std::vector<bool> isNull;
extractPercentiles(
percentilesBase, indexInBaseVector, data, len, isNull);
checkSetPercentile(isArray, data, len, isNull);

if (!accuracy->isNullAt(i)) {
checkSetAccuracy(accuracy->valueAt(i));
if (!accuracy.isNullAt(i)) {
checkSetAccuracy(accuracy.valueAt<double>(i));
}
}
if constexpr (kSingleGroup) {
Expand Down

0 comments on commit d6ed32e

Please sign in to comment.