diff --git a/internal/planner/average.go b/internal/planner/average.go index c5274b5b6f..76bbfc107d 100644 --- a/internal/planner/average.go +++ b/internal/planner/average.go @@ -28,6 +28,8 @@ type averageNode struct { virtualFieldIndex int execInfo averageExecInfo + + aggregateFilter *mapper.Filter } type averageExecInfo struct { @@ -37,6 +39,7 @@ type averageExecInfo struct { func (p *Planner) Average( field *mapper.Aggregate, + filter *mapper.Filter, ) (*averageNode, error) { var sumField *mapper.Aggregate var countField *mapper.Aggregate @@ -57,6 +60,7 @@ func (p *Planner) Average( countFieldIndex: countField.Index, virtualFieldIndex: field.Index, docMapper: docMapper{field.DocumentMapping}, + aggregateFilter: filter, }, nil } @@ -102,7 +106,7 @@ func (n *averageNode) Next() (bool, error) { return false, client.NewErrUnhandledType("sum", sumProp) } - return true, nil + return mapper.RunFilter(n.currentValue, n.aggregateFilter) } func (n *averageNode) SetPlan(p planNode) { n.plan = p } diff --git a/internal/planner/count.go b/internal/planner/count.go index b71fcab1e5..1b58109749 100644 --- a/internal/planner/count.go +++ b/internal/planner/count.go @@ -35,6 +35,7 @@ type countNode struct { virtualFieldIndex int aggregateMapping []mapper.AggregateTarget + aggregateFilter *mapper.Filter execInfo countExecInfo } @@ -44,11 +45,12 @@ type countExecInfo struct { iterations uint64 } -func (p *Planner) Count(field *mapper.Aggregate, host *mapper.Select) (*countNode, error) { +func (p *Planner) Count(field *mapper.Aggregate, host *mapper.Select, filter *mapper.Filter) (*countNode, error) { return &countNode{ p: p, virtualFieldIndex: field.Index, aggregateMapping: field.AggregateTargets, + aggregateFilter: filter, docMapper: docMapper{field.DocumentMapping}, }, nil } @@ -181,7 +183,7 @@ func (n *countNode) Next() (bool, error) { } n.currentValue.Fields[n.virtualFieldIndex] = count - return true, nil + return mapper.RunFilter(n.currentValue, n.aggregateFilter) } // countDocs counts the number of documents in a slice, skipping over hidden items diff --git a/internal/planner/max.go b/internal/planner/max.go index c3eb6b488e..530e60e25e 100644 --- a/internal/planner/max.go +++ b/internal/planner/max.go @@ -33,6 +33,7 @@ type maxNode struct { // that contains the result of the aggregate. virtualFieldIndex int aggregateMapping []mapper.AggregateTarget + aggregateFilter *mapper.Filter execInfo maxExecInfo } @@ -45,11 +46,13 @@ type maxExecInfo struct { func (p *Planner) Max( field *mapper.Aggregate, parent *mapper.Select, + filter *mapper.Filter, ) (*maxNode, error) { return &maxNode{ p: p, parent: parent, aggregateMapping: field.AggregateTargets, + aggregateFilter: filter, virtualFieldIndex: field.Index, docMapper: docMapper{field.DocumentMapping}, }, nil @@ -252,5 +255,5 @@ func (n *maxNode) Next() (bool, error) { res, _ := max.Int64() n.currentValue.Fields[n.virtualFieldIndex] = res } - return true, nil + return mapper.RunFilter(n.currentValue, n.aggregateFilter) } diff --git a/internal/planner/min.go b/internal/planner/min.go index 99278785bc..be70a8ccb9 100644 --- a/internal/planner/min.go +++ b/internal/planner/min.go @@ -33,6 +33,7 @@ type minNode struct { // that contains the result of the aggregate. virtualFieldIndex int aggregateMapping []mapper.AggregateTarget + aggregateFilter *mapper.Filter execInfo minExecInfo } @@ -45,11 +46,13 @@ type minExecInfo struct { func (p *Planner) Min( field *mapper.Aggregate, parent *mapper.Select, + filter *mapper.Filter, ) (*minNode, error) { return &minNode{ p: p, parent: parent, aggregateMapping: field.AggregateTargets, + aggregateFilter: filter, virtualFieldIndex: field.Index, docMapper: docMapper{field.DocumentMapping}, }, nil @@ -252,5 +255,5 @@ func (n *minNode) Next() (bool, error) { res, _ := min.Int64() n.currentValue.Fields[n.virtualFieldIndex] = res } - return true, nil + return mapper.RunFilter(n.currentValue, n.aggregateFilter) } diff --git a/internal/planner/select.go b/internal/planner/select.go index d0e816cfb9..d3bcbb910d 100644 --- a/internal/planner/select.go +++ b/internal/planner/select.go @@ -19,6 +19,7 @@ import ( "github.com/sourcenetwork/defradb/internal/core" "github.com/sourcenetwork/defradb/internal/db/base" "github.com/sourcenetwork/defradb/internal/keys" + "github.com/sourcenetwork/defradb/internal/planner/filter" "github.com/sourcenetwork/defradb/internal/planner/mapper" ) @@ -344,18 +345,21 @@ func (n *selectNode) initFields(selectReq *mapper.Select) ([]aggregateNode, erro case *mapper.Aggregate: var plan aggregateNode var aggregateError error + var aggregateFilter *mapper.Filter + // extract aggregate filters from the select + selectReq.Filter, aggregateFilter = filter.SplitByFields(selectReq.Filter, f.Field) switch f.Name { case request.CountFieldName: - plan, aggregateError = n.planner.Count(f, selectReq) + plan, aggregateError = n.planner.Count(f, selectReq, aggregateFilter) case request.SumFieldName: - plan, aggregateError = n.planner.Sum(f, selectReq) + plan, aggregateError = n.planner.Sum(f, selectReq, aggregateFilter) case request.AverageFieldName: - plan, aggregateError = n.planner.Average(f) + plan, aggregateError = n.planner.Average(f, aggregateFilter) case request.MaxFieldName: - plan, aggregateError = n.planner.Max(f, selectReq) + plan, aggregateError = n.planner.Max(f, selectReq, aggregateFilter) case request.MinFieldName: - plan, aggregateError = n.planner.Min(f, selectReq) + plan, aggregateError = n.planner.Min(f, selectReq, aggregateFilter) } if aggregateError != nil { diff --git a/internal/planner/sum.go b/internal/planner/sum.go index c790cba60d..a77e56da3d 100644 --- a/internal/planner/sum.go +++ b/internal/planner/sum.go @@ -30,6 +30,7 @@ type sumNode struct { isFloat bool virtualFieldIndex int aggregateMapping []mapper.AggregateTarget + aggregateFilter *mapper.Filter execInfo sumExecInfo } @@ -42,6 +43,7 @@ type sumExecInfo struct { func (p *Planner) Sum( field *mapper.Aggregate, parent *mapper.Select, + filter *mapper.Filter, ) (*sumNode, error) { isFloat := false for _, target := range field.AggregateTargets { @@ -60,6 +62,7 @@ func (p *Planner) Sum( p: p, isFloat: isFloat, aggregateMapping: field.AggregateTargets, + aggregateFilter: filter, virtualFieldIndex: field.Index, docMapper: docMapper{field.DocumentMapping}, }, nil @@ -310,8 +313,7 @@ func (n *sumNode) Next() (bool, error) { typedSum = int64(sum) } n.currentValue.Fields[n.virtualFieldIndex] = typedSum - - return true, nil + return mapper.RunFilter(n.currentValue, n.aggregateFilter) } func (n *sumNode) SetPlan(p planNode) { n.plan = p } diff --git a/internal/planner/top.go b/internal/planner/top.go index 6224b6d62d..658dc66dd8 100644 --- a/internal/planner/top.go +++ b/internal/planner/top.go @@ -199,15 +199,15 @@ func (p *Planner) Top(m *mapper.Select) (*topLevelNode, error) { var err error switch field.GetName() { case request.CountFieldName: - child, err = p.Count(f, m) + child, err = p.Count(f, m, nil) case request.SumFieldName: - child, err = p.Sum(f, m) + child, err = p.Sum(f, m, nil) case request.AverageFieldName: - child, err = p.Average(f) + child, err = p.Average(f, nil) case request.MaxFieldName: - child, err = p.Max(f, m) + child, err = p.Max(f, m, nil) case request.MinFieldName: - child, err = p.Min(f, m) + child, err = p.Min(f, m, nil) } if err != nil { return nil, err diff --git a/tests/integration/query/one_to_many/with_count_test.go b/tests/integration/query/one_to_many/with_count_test.go index 77d4e754f3..77905ed748 100644 --- a/tests/integration/query/one_to_many/with_count_test.go +++ b/tests/integration/query/one_to_many/with_count_test.go @@ -119,11 +119,9 @@ func TestQueryOneToManyWithCount(t *testing.T) { } } -// This test documents the behavior of aggregate alias targeting which is not yet implemented. -// https://github.com/sourcenetwork/defradb/issues/3195 -func TestQueryOneToMany_WithCountAliasFilter_ShouldFilterAll(t *testing.T) { +func TestQueryOneToMany_WithCountAliasFilter_ShouldMatchAll(t *testing.T) { test := testUtils.TestCase{ - Description: "One-to-many relation query from many side with count", + Description: "One-to-many relation query from many side with count alias", Actions: []any{ testUtils.CreateDoc{ CollectionID: 1, @@ -173,7 +171,16 @@ func TestQueryOneToMany_WithCountAliasFilter_ShouldFilterAll(t *testing.T) { } }`, Results: map[string]any{ - "Author": []map[string]any{}, + "Author": []map[string]any{ + { + "name": "Cornelia Funke", + "publishedCount": 1, + }, + { + "name": "John Grisham", + "publishedCount": 2, + }, + }, }, }, }, diff --git a/tests/integration/query/simple/with_group_aggregate_alias_filter_test.go b/tests/integration/query/simple/with_group_aggregate_alias_filter_test.go new file mode 100644 index 0000000000..037a187cad --- /dev/null +++ b/tests/integration/query/simple/with_group_aggregate_alias_filter_test.go @@ -0,0 +1,303 @@ +// Copyright 2024 Democratized Data Foundation +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package simple + +import ( + "testing" + + testUtils "github.com/sourcenetwork/defradb/tests/integration" +) + +func TestQuerySimple_WithGroupAverageAliasFilter_FiltersResults(t *testing.T) { + test := testUtils.TestCase{ + Description: "Simple query with group average alias filter", + Actions: []any{ + testUtils.SchemaUpdate{ + Schema: `type Users { + Name: String + Score: Int + }`, + }, + testUtils.CreateDoc{ + Doc: `{ + "Name": "Bob", + "Score": 10 + }`, + }, + testUtils.CreateDoc{ + Doc: `{ + "Name": "Bob", + "Score": 20 + }`, + }, + testUtils.CreateDoc{ + Doc: `{ + "Name": "Alice", + "Score": 40 + }`, + }, + testUtils.CreateDoc{ + Doc: `{ + "Name": "Alice", + "Score": 0 + }`, + }, + testUtils.Request{ + Request: `query { + Users(groupBy: [Name], filter: {_alias: {averageScore: {_eq: 20}}}) { + Name + averageScore: _avg(_group: {field: Score}) + } + }`, + Results: map[string]any{ + "Users": []map[string]any{ + { + "Name": "Alice", + "averageScore": float64(20), + }, + }, + }, + }, + }, + } + + testUtils.ExecuteTestCase(t, test) +} + +func TestQuerySimple_WithGroupSumAliasFilter_FiltersResults(t *testing.T) { + test := testUtils.TestCase{ + Description: "Simple query with group sum alias filter", + Actions: []any{ + testUtils.SchemaUpdate{ + Schema: `type Users { + Name: String + Score: Int + }`, + }, + testUtils.CreateDoc{ + Doc: `{ + "Name": "Bob", + "Score": 10 + }`, + }, + testUtils.CreateDoc{ + Doc: `{ + "Name": "Bob", + "Score": 20 + }`, + }, + testUtils.CreateDoc{ + Doc: `{ + "Name": "Alice", + "Score": 40 + }`, + }, + testUtils.CreateDoc{ + Doc: `{ + "Name": "Alice", + "Score": 0 + }`, + }, + testUtils.Request{ + Request: `query { + Users(groupBy: [Name], filter: {_alias: {totalScore: {_eq: 40}}}) { + Name + totalScore: _sum(_group: {field: Score}) + } + }`, + Results: map[string]any{ + "Users": []map[string]any{ + { + "Name": "Alice", + "totalScore": float64(40), + }, + }, + }, + }, + }, + } + + testUtils.ExecuteTestCase(t, test) +} + +func TestQuerySimple_WithGroupMinAliasFilter_FiltersResults(t *testing.T) { + test := testUtils.TestCase{ + Description: "Simple query with group min alias filter", + Actions: []any{ + testUtils.SchemaUpdate{ + Schema: `type Users { + Name: String + Score: Int + }`, + }, + testUtils.CreateDoc{ + Doc: `{ + "Name": "Bob", + "Score": 10 + }`, + }, + testUtils.CreateDoc{ + Doc: `{ + "Name": "Bob", + "Score": 20 + }`, + }, + testUtils.CreateDoc{ + Doc: `{ + "Name": "Alice", + "Score": 40 + }`, + }, + testUtils.CreateDoc{ + Doc: `{ + "Name": "Alice", + "Score": 0 + }`, + }, + testUtils.Request{ + Request: `query { + Users(groupBy: [Name], filter: {_alias: {minScore: {_eq: 0}}}) { + Name + minScore: _min(_group: {field: Score}) + } + }`, + Results: map[string]any{ + "Users": []map[string]any{ + { + "Name": "Alice", + "minScore": int64(0), + }, + }, + }, + }, + }, + } + + testUtils.ExecuteTestCase(t, test) +} + +func TestQuerySimple_WithGroupMaxAliasFilter_FiltersResults(t *testing.T) { + test := testUtils.TestCase{ + Description: "Simple query with group max alias filter", + Actions: []any{ + testUtils.SchemaUpdate{ + Schema: `type Users { + Name: String + Score: Int + }`, + }, + testUtils.CreateDoc{ + Doc: `{ + "Name": "Bob", + "Score": 10 + }`, + }, + testUtils.CreateDoc{ + Doc: `{ + "Name": "Bob", + "Score": 20 + }`, + }, + testUtils.CreateDoc{ + Doc: `{ + "Name": "Alice", + "Score": 40 + }`, + }, + testUtils.CreateDoc{ + Doc: `{ + "Name": "Alice", + "Score": 0 + }`, + }, + testUtils.Request{ + Request: `query { + Users(groupBy: [Name], filter: {_alias: {maxScore: {_eq: 40}}}) { + Name + maxScore: _max(_group: {field: Score}) + } + }`, + Results: map[string]any{ + "Users": []map[string]any{ + { + "Name": "Alice", + "maxScore": int64(40), + }, + }, + }, + }, + }, + } + + testUtils.ExecuteTestCase(t, test) +} + +func TestQuerySimple_WithGroupCountAliasFilter_FiltersResults(t *testing.T) { + test := testUtils.TestCase{ + Description: "Simple query with group count alias filter", + Actions: []any{ + testUtils.SchemaUpdate{ + Schema: `type Users { + Name: String + Score: Int + }`, + }, + testUtils.CreateDoc{ + Doc: `{ + "Name": "Bob", + "Score": 10 + }`, + }, + testUtils.CreateDoc{ + Doc: `{ + "Name": "Bob", + "Score": 20 + }`, + }, + testUtils.CreateDoc{ + Doc: `{ + "Name": "Alice", + "Score": 40 + }`, + }, + testUtils.CreateDoc{ + Doc: `{ + "Name": "Alice", + "Score": 0 + }`, + }, + testUtils.CreateDoc{ + Doc: `{ + "Name": "Alice", + "Score": 5 + }`, + }, + testUtils.Request{ + Request: `query { + Users(groupBy: [Name], filter: {_alias: {scores: {_eq: 3}}}) { + Name + scores: _count(_group: {}) + } + }`, + Results: map[string]any{ + "Users": []map[string]any{ + { + "Name": "Alice", + "scores": int64(3), + }, + }, + }, + }, + }, + } + + testUtils.ExecuteTestCase(t, test) +}