Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Draft v2] Another Multi group by optimization #10976

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 1 addition & 41 deletions benchmarks/queries/clickbench/queries.sql
Original file line number Diff line number Diff line change
@@ -1,43 +1,3 @@
SELECT COUNT(*) FROM hits;
SELECT COUNT(*) FROM hits WHERE "AdvEngineID" <> 0;
SELECT SUM("AdvEngineID"), COUNT(*), AVG("ResolutionWidth") FROM hits;
SELECT AVG("UserID") FROM hits;
SELECT COUNT(DISTINCT "UserID") FROM hits;
SELECT COUNT(DISTINCT "SearchPhrase") FROM hits;
SELECT MIN("EventDate"::INT::DATE), MAX("EventDate"::INT::DATE) FROM hits;
SELECT "AdvEngineID", COUNT(*) FROM hits WHERE "AdvEngineID" <> 0 GROUP BY "AdvEngineID" ORDER BY COUNT(*) DESC;
SELECT "RegionID", COUNT(DISTINCT "UserID") AS u FROM hits GROUP BY "RegionID" ORDER BY u DESC LIMIT 10;
SELECT "RegionID", SUM("AdvEngineID"), COUNT(*) AS c, AVG("ResolutionWidth"), COUNT(DISTINCT "UserID") FROM hits GROUP BY "RegionID" ORDER BY c DESC LIMIT 10;
SELECT "MobilePhoneModel", COUNT(DISTINCT "UserID") AS u FROM hits WHERE "MobilePhoneModel" <> '' GROUP BY "MobilePhoneModel" ORDER BY u DESC LIMIT 10;
SELECT "MobilePhone", "MobilePhoneModel", COUNT(DISTINCT "UserID") AS u FROM hits WHERE "MobilePhoneModel" <> '' GROUP BY "MobilePhone", "MobilePhoneModel" ORDER BY u DESC LIMIT 10;
SELECT "SearchPhrase", COUNT(*) AS c FROM hits WHERE "SearchPhrase" <> '' GROUP BY "SearchPhrase" ORDER BY c DESC LIMIT 10;
SELECT "SearchPhrase", COUNT(DISTINCT "UserID") AS u FROM hits WHERE "SearchPhrase" <> '' GROUP BY "SearchPhrase" ORDER BY u DESC LIMIT 10;
SELECT "SearchEngineID", "SearchPhrase", COUNT(*) AS c FROM hits WHERE "SearchPhrase" <> '' GROUP BY "SearchEngineID", "SearchPhrase" ORDER BY c DESC LIMIT 10;
SELECT "UserID", COUNT(*) FROM hits GROUP BY "UserID" ORDER BY COUNT(*) DESC LIMIT 10;
SELECT "UserID", "SearchPhrase", COUNT(*) FROM hits GROUP BY "UserID", "SearchPhrase" ORDER BY COUNT(*) DESC LIMIT 10;
SELECT "UserID", "SearchPhrase", COUNT(*) FROM hits GROUP BY "UserID", "SearchPhrase" LIMIT 10;
SELECT "UserID", extract(minute FROM to_timestamp_seconds("EventTime")) AS m, "SearchPhrase", COUNT(*) FROM hits GROUP BY "UserID", m, "SearchPhrase" ORDER BY COUNT(*) DESC LIMIT 10;
SELECT "UserID" FROM hits WHERE "UserID" = 435090932899640449;
SELECT COUNT(*) FROM hits WHERE "URL" LIKE '%google%';
SELECT "SearchPhrase", MIN("URL"), COUNT(*) AS c FROM hits WHERE "URL" LIKE '%google%' AND "SearchPhrase" <> '' GROUP BY "SearchPhrase" ORDER BY c DESC LIMIT 10;
SELECT "SearchPhrase", MIN("URL"), MIN("Title"), COUNT(*) AS c, COUNT(DISTINCT "UserID") FROM hits WHERE "Title" LIKE '%Google%' AND "URL" NOT LIKE '%.google.%' AND "SearchPhrase" <> '' GROUP BY "SearchPhrase" ORDER BY c DESC LIMIT 10;
SELECT * FROM hits WHERE "URL" LIKE '%google%' ORDER BY to_timestamp_seconds("EventTime") LIMIT 10;
SELECT "SearchPhrase" FROM hits WHERE "SearchPhrase" <> '' ORDER BY to_timestamp_seconds("EventTime") LIMIT 10;
SELECT "SearchPhrase" FROM hits WHERE "SearchPhrase" <> '' ORDER BY "SearchPhrase" LIMIT 10;
SELECT "SearchPhrase" FROM hits WHERE "SearchPhrase" <> '' ORDER BY to_timestamp_seconds("EventTime"), "SearchPhrase" LIMIT 10;
SELECT "CounterID", AVG(length("URL")) AS l, COUNT(*) AS c FROM hits WHERE "URL" <> '' GROUP BY "CounterID" HAVING COUNT(*) > 100000 ORDER BY l DESC LIMIT 25;
SELECT REGEXP_REPLACE("Referer", '^https?://(?:www\.)?([^/]+)/.*$', '\1') AS k, AVG(length("Referer")) AS l, COUNT(*) AS c, MIN("Referer") FROM hits WHERE "Referer" <> '' GROUP BY k HAVING COUNT(*) > 100000 ORDER BY l DESC LIMIT 25;
SELECT SUM("ResolutionWidth"), SUM("ResolutionWidth" + 1), SUM("ResolutionWidth" + 2), SUM("ResolutionWidth" + 3), SUM("ResolutionWidth" + 4), SUM("ResolutionWidth" + 5), SUM("ResolutionWidth" + 6), SUM("ResolutionWidth" + 7), SUM("ResolutionWidth" + 8), SUM("ResolutionWidth" + 9), SUM("ResolutionWidth" + 10), SUM("ResolutionWidth" + 11), SUM("ResolutionWidth" + 12), SUM("ResolutionWidth" + 13), SUM("ResolutionWidth" + 14), SUM("ResolutionWidth" + 15), SUM("ResolutionWidth" + 16), SUM("ResolutionWidth" + 17), SUM("ResolutionWidth" + 18), SUM("ResolutionWidth" + 19), SUM("ResolutionWidth" + 20), SUM("ResolutionWidth" + 21), SUM("ResolutionWidth" + 22), SUM("ResolutionWidth" + 23), SUM("ResolutionWidth" + 24), SUM("ResolutionWidth" + 25), SUM("ResolutionWidth" + 26), SUM("ResolutionWidth" + 27), SUM("ResolutionWidth" + 28), SUM("ResolutionWidth" + 29), SUM("ResolutionWidth" + 30), SUM("ResolutionWidth" + 31), SUM("ResolutionWidth" + 32), SUM("ResolutionWidth" + 33), SUM("ResolutionWidth" + 34), SUM("ResolutionWidth" + 35), SUM("ResolutionWidth" + 36), SUM("ResolutionWidth" + 37), SUM("ResolutionWidth" + 38), SUM("ResolutionWidth" + 39), SUM("ResolutionWidth" + 40), SUM("ResolutionWidth" + 41), SUM("ResolutionWidth" + 42), SUM("ResolutionWidth" + 43), SUM("ResolutionWidth" + 44), SUM("ResolutionWidth" + 45), SUM("ResolutionWidth" + 46), SUM("ResolutionWidth" + 47), SUM("ResolutionWidth" + 48), SUM("ResolutionWidth" + 49), SUM("ResolutionWidth" + 50), SUM("ResolutionWidth" + 51), SUM("ResolutionWidth" + 52), SUM("ResolutionWidth" + 53), SUM("ResolutionWidth" + 54), SUM("ResolutionWidth" + 55), SUM("ResolutionWidth" + 56), SUM("ResolutionWidth" + 57), SUM("ResolutionWidth" + 58), SUM("ResolutionWidth" + 59), SUM("ResolutionWidth" + 60), SUM("ResolutionWidth" + 61), SUM("ResolutionWidth" + 62), SUM("ResolutionWidth" + 63), SUM("ResolutionWidth" + 64), SUM("ResolutionWidth" + 65), SUM("ResolutionWidth" + 66), SUM("ResolutionWidth" + 67), SUM("ResolutionWidth" + 68), SUM("ResolutionWidth" + 69), SUM("ResolutionWidth" + 70), SUM("ResolutionWidth" + 71), SUM("ResolutionWidth" + 72), SUM("ResolutionWidth" + 73), SUM("ResolutionWidth" + 74), SUM("ResolutionWidth" + 75), SUM("ResolutionWidth" + 76), SUM("ResolutionWidth" + 77), SUM("ResolutionWidth" + 78), SUM("ResolutionWidth" + 79), SUM("ResolutionWidth" + 80), SUM("ResolutionWidth" + 81), SUM("ResolutionWidth" + 82), SUM("ResolutionWidth" + 83), SUM("ResolutionWidth" + 84), SUM("ResolutionWidth" + 85), SUM("ResolutionWidth" + 86), SUM("ResolutionWidth" + 87), SUM("ResolutionWidth" + 88), SUM("ResolutionWidth" + 89) FROM hits;
SELECT "SearchEngineID", "ClientIP", COUNT(*) AS c, SUM("IsRefresh"), AVG("ResolutionWidth") FROM hits WHERE "SearchPhrase" <> '' GROUP BY "SearchEngineID", "ClientIP" ORDER BY c DESC LIMIT 10;
SELECT "WatchID", "ClientIP", COUNT(*) AS c, SUM("IsRefresh"), AVG("ResolutionWidth") FROM hits WHERE "SearchPhrase" <> '' GROUP BY "WatchID", "ClientIP" ORDER BY c DESC LIMIT 10;
SELECT "WatchID", "ClientIP", COUNT(*) AS c, SUM("IsRefresh"), AVG("ResolutionWidth") FROM hits GROUP BY "WatchID", "ClientIP" ORDER BY c DESC LIMIT 10;
SELECT "URL", COUNT(*) AS c FROM hits GROUP BY "URL" ORDER BY c DESC LIMIT 10;
SELECT 1, "URL", COUNT(*) AS c FROM hits GROUP BY 1, "URL" ORDER BY c DESC LIMIT 10;
SELECT "ClientIP", "ClientIP" - 1, "ClientIP" - 2, "ClientIP" - 3, COUNT(*) AS c FROM hits GROUP BY "ClientIP", "ClientIP" - 1, "ClientIP" - 2, "ClientIP" - 3 ORDER BY c DESC LIMIT 10;
SELECT "URL", COUNT(*) AS PageViews FROM hits WHERE "CounterID" = 62 AND "EventDate"::INT::DATE >= '2013-07-01' AND "EventDate"::INT::DATE <= '2013-07-31' AND "DontCountHits" = 0 AND "IsRefresh" = 0 AND "URL" <> '' GROUP BY "URL" ORDER BY PageViews DESC LIMIT 10;
SELECT "Title", COUNT(*) AS PageViews FROM hits WHERE "CounterID" = 62 AND "EventDate"::INT::DATE >= '2013-07-01' AND "EventDate"::INT::DATE <= '2013-07-31' AND "DontCountHits" = 0 AND "IsRefresh" = 0 AND "Title" <> '' GROUP BY "Title" ORDER BY PageViews DESC LIMIT 10;
SELECT "URL", COUNT(*) AS PageViews FROM hits WHERE "CounterID" = 62 AND "EventDate"::INT::DATE >= '2013-07-01' AND "EventDate"::INT::DATE <= '2013-07-31' AND "IsRefresh" = 0 AND "IsLink" <> 0 AND "IsDownload" = 0 GROUP BY "URL" ORDER BY PageViews DESC LIMIT 10 OFFSET 1000;
SELECT "TraficSourceID", "SearchEngineID", "AdvEngineID", CASE WHEN ("SearchEngineID" = 0 AND "AdvEngineID" = 0) THEN "Referer" ELSE '' END AS Src, "URL" AS Dst, COUNT(*) AS PageViews FROM hits WHERE "CounterID" = 62 AND "EventDate"::INT::DATE >= '2013-07-01' AND "EventDate"::INT::DATE <= '2013-07-31' AND "IsRefresh" = 0 GROUP BY "TraficSourceID", "SearchEngineID", "AdvEngineID", Src, Dst ORDER BY PageViews DESC LIMIT 10 OFFSET 1000;
SELECT "URLHash", "EventDate"::INT::DATE, COUNT(*) AS PageViews FROM hits WHERE "CounterID" = 62 AND "EventDate"::INT::DATE >= '2013-07-01' AND "EventDate"::INT::DATE <= '2013-07-31' AND "IsRefresh" = 0 AND "TraficSourceID" IN (-1, 6) AND "RefererHash" = 3594120000172545465 GROUP BY "URLHash", "EventDate"::INT::DATE ORDER BY PageViews DESC LIMIT 10 OFFSET 100;
SELECT "WindowClientWidth", "WindowClientHeight", COUNT(*) AS PageViews FROM hits WHERE "CounterID" = 62 AND "EventDate"::INT::DATE >= '2013-07-01' AND "EventDate"::INT::DATE <= '2013-07-31' AND "IsRefresh" = 0 AND "DontCountHits" = 0 AND "URLHash" = 2868770270353813622 GROUP BY "WindowClientWidth", "WindowClientHeight" ORDER BY PageViews DESC LIMIT 10 OFFSET 10000;
SELECT DATE_TRUNC('minute', to_timestamp_seconds("EventTime")) AS M, COUNT(*) AS PageViews FROM hits WHERE "CounterID" = 62 AND "EventDate"::INT::DATE >= '2013-07-14' AND "EventDate"::INT::DATE <= '2013-07-15' AND "IsRefresh" = 0 AND "DontCountHits" = 0 GROUP BY DATE_TRUNC('minute', to_timestamp_seconds("EventTime")) ORDER BY DATE_TRUNC('minute', M) LIMIT 10 OFFSET 1000;
SELECT "UserID", concat("SearchPhrase", repeat('hello', 100)) as s, COUNT(*) FROM hits GROUP BY "UserID", s LIMIT 10;
96 changes: 90 additions & 6 deletions datafusion/physical-plan/src/aggregates/group_values/row.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,23 @@
// specific language governing permissions and limitations
// under the License.

use std::sync::Arc;

use crate::aggregates::group_values::GroupValues;
use ahash::RandomState;
use arrow::array::{AsArray, StringBuilder};
use arrow::compute::cast;
use arrow::datatypes::UInt64Type;
use arrow::record_batch::RecordBatch;
use arrow::row::{RowConverter, Rows, SortField};
use arrow_array::{Array, ArrayRef};
use arrow_array::{Array, ArrayRef, UInt64Array};
use arrow_schema::{DataType, SchemaRef};
use datafusion_common::hash_utils::create_hashes;
use datafusion_common::{DataFusionError, Result};
use datafusion_execution::memory_pool::proxy::{RawTableAllocExt, VecAllocExt};
use datafusion_expr::EmitTo;
use datafusion_physical_expr::binary_map::OutputType;
use datafusion_physical_expr_common::binary_map::ArrowBytesMap;
use hashbrown::raw::RawTable;

/// A [`GroupValues`] making use of [`Rows`]
Expand Down Expand Up @@ -67,6 +73,10 @@ pub struct GroupValuesRows {

/// Random state for creating hashes
random_state: RandomState,

// variable length column map
var_map: ArrowBytesMap<i32, usize>,
num_groups: usize,
}

impl GroupValuesRows {
Expand All @@ -75,7 +85,13 @@ impl GroupValuesRows {
schema
.fields()
.iter()
.map(|f| SortField::new(f.data_type().clone()))
.map(|f| {
if f.data_type() == &DataType::Utf8 {
SortField::new(DataType::UInt64)
} else {
SortField::new(f.data_type().clone())
}
})
.collect(),
)?;

Expand All @@ -94,16 +110,41 @@ impl GroupValuesRows {
hashes_buffer: Default::default(),
rows_buffer,
random_state: Default::default(),
var_map: ArrowBytesMap::new(OutputType::Utf8),
num_groups: 0,
})
}
}

impl GroupValues for GroupValuesRows {
fn intern(&mut self, cols: &[ArrayRef], groups: &mut Vec<usize>) -> Result<()> {
let cols_var = &cols[1];
groups.clear();
self.var_map.insert_if_new(
cols_var,
// called for each new group
|_value| {
// assign new group index on each insert
let group_idx = self.num_groups;
self.num_groups += 1;
group_idx
},
// called for each group
|group_idx| {
groups.push(group_idx);
},
);

let u64_vec: Vec<u64> = groups.iter().map(|&x| x as u64).collect();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It might be faster to write into u64_vec directly rather than write to groups and then copy over

let arr = Arc::new(UInt64Array::from(u64_vec)) as ArrayRef;

let mut cols = cols[0..1].to_vec();
cols.push(arr);

// Convert the group keys into the row format
let group_rows = &mut self.rows_buffer;
group_rows.clear();
self.row_converter.append(group_rows, cols)?;
self.row_converter.append(group_rows, &cols)?;
let n_rows = group_rows.num_rows();

let mut group_values = match self.group_values.take() {
Expand All @@ -118,7 +159,7 @@ impl GroupValues for GroupValuesRows {
let batch_hashes = &mut self.hashes_buffer;
batch_hashes.clear();
batch_hashes.resize(n_rows, 0);
create_hashes(cols, &self.random_state, batch_hashes)?;
create_hashes(&cols, &self.random_state, batch_hashes)?;

for (row, &hash) in batch_hashes.iter().enumerate() {
let entry = self.map.get_mut(hash, |(_hash, group_idx)| {
Expand Down Expand Up @@ -180,15 +221,39 @@ impl GroupValues for GroupValuesRows {
.take()
.expect("Can not emit from empty rows");

let map_contents = self.var_map.take().into_state();
let map_contents = map_contents.as_string::<i32>();

let mut output = match emit_to {
EmitTo::All => {
let output = self.row_converter.convert_rows(&group_values)?;
let mut output = self.row_converter.convert_rows(&group_values)?;

// Index Array: [0, 1, 1, 0]
// Data Array: ['a', 'c']
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This code copies all the strings again, which likely takes significant time

One way to make this faster is likey to special case the "has no nulls" path (so we can avoid if let Some(.) check each row

The only other ways I can think to make this faster is to avoid copying the strings all together. The only way to do so I can figure are:

  1. Return a DictionaryArray as output
  2. Return a StringViewArray (similar to DictionayArray) [Epic] Implement support for StringView in DataFusion #10918

🤔

Maybe we can try to hack in the ability to have the HashAggregateExec return Dictionary(Int32, String) for these multi-column groups and see if we can show significant performance improvements

If so then we can figure out how to thread that ability through the engine.

// Result Array: ['a', 'c', 'c', 'a']

let mut string_build = StringBuilder::new();

let arr = output[1].as_primitive::<UInt64Type>();
for v in arr.iter() {
if let Some(index) = v {
let value = map_contents.value(index as usize);
string_build.append_value(value)
} else {
string_build.append_null();
}
}

let output_str = string_build.finish();

output[1] = Arc::new(output_str);

group_values.clear();
output
}
EmitTo::First(n) => {
let groups_rows = group_values.iter().take(n);
let output = self.row_converter.convert_rows(groups_rows)?;
let mut output = self.row_converter.convert_rows(groups_rows)?;
// Clear out first n group keys by copying them to a new Rows.
// TODO file some ticket in arrow-rs to make this more efficent?
let mut new_group_values = self.row_converter.empty_rows(0, 0);
Expand All @@ -209,6 +274,25 @@ impl GroupValues for GroupValuesRows {
}
}
}

let mut string_build = StringBuilder::new();

let arr = output[1].as_primitive::<UInt64Type>();
for v in arr.iter() {
if let Some(index) = v {
let value = map_contents.value(index as usize);
string_build.append_value(value)
} else {
string_build.append_null();
}
}

let output_str = string_build.finish();

output[1] = Arc::new(output_str);

// self.num_groups -= map_contents.len();
// output.push(map_contents);
output
}
};
Expand Down
29 changes: 29 additions & 0 deletions datafusion/sqllogictest/test_files/test1.slt
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
statement ok
create table t(a int, b varchar, c int) as values (1, 'a', 3), (2, 'abcabcabcabcabc', 1), (1, 'c', 2), (1, 'a', 4);

query ITI
select a, concat(b, 'abcdabcdabcd') as b_v, count(*) from t group by a, b_v order by count(*) desc;
----
1 a 2
2 c 1
1 c 1

query TT
explain select a, b, count(*) from t group by a, b order by count(*) desc;
----
logical_plan
01)Sort: COUNT(*) DESC NULLS FIRST
02)--Aggregate: groupBy=[[t.a, t.b]], aggr=[[COUNT(Int64(1)) AS COUNT(*)]]
03)----TableScan: t projection=[a, b]
physical_plan
01)SortPreservingMergeExec: [COUNT(*)@2 DESC]
02)--SortExec: expr=[COUNT(*)@2 DESC], preserve_partitioning=[true]
03)----AggregateExec: mode=FinalPartitioned, gby=[a@0 as a, b@1 as b], aggr=[COUNT(*)]
04)------CoalesceBatchesExec: target_batch_size=8192
05)--------RepartitionExec: partitioning=Hash([a@0, b@1], 4), input_partitions=4
06)----------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
07)------------AggregateExec: mode=Partial, gby=[a@0 as a, b@1 as b], aggr=[COUNT(*)]
08)--------------MemoryExec: partitions=1, partition_sizes=[1]

statement ok
drop table t;
Loading