opt: do not remap columns under a grouping operator
The `TryRemapOuterCols` rules attempt to replace an outer column reference
within a relational expression with a non-outer column. The transformation is
valid because the rule only fires when a parent expression filters out all rows
for which the outer and non-outer columns are not equal. For example:
```
SELECT * FROM xy INNER JOIN (SELECT a + x FROM ab) ON x = a;
```
In the above query, the reference to `x` in the subquery could be replaced
with `a`, and the rows for which `a = x` does not hold would be later
removed by the join.

However, the rule also performs this transformation for GroupBy and
DistinctOn, and in particular, the replacement column could be one of the
`ConstAgg` columns. This could cause incorrect results because input rows
that satisfy the equality and rows that don't could be aggregated together
(note that de-duplication is a special form of aggregation).

Note that the transformation is correct when the replacement column is one
of the grouping columns; this is because the grouping then separates rows
that satisfy the equality from those that don't. The groups that don't
satisfy the equality are then filtered later in the query, preserving the
correct result. This also applies to the distinct variants of set operators
(Union, Except, Intersect), since they de-duplicate across all columns.
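
The difference between the two cases above can be illustrated with a toy model. This is only a sketch and not the optimizer's code: the `row` type, the helper names, and the sample data are all hypothetical, and "take the first value seen in the group" is used as a stand-in for a `ConstAgg` column.

```go
package main

import "fmt"

// row is a hypothetical input row: g is a grouping column and a is the
// non-outer column that would replace the outer column x.
type row struct {
	g, a int
}

// filterThenGroup is the reference behavior in this toy model: only rows
// with a == x are counted, per value of g.
func filterThenGroup(rows []row, x int) map[int]int {
	out := map[int]int{}
	for _, r := range rows {
		if r.a == x {
			out[r.g]++
		}
	}
	return out
}

// groupThenFilter groups first and applies a == x to each group afterwards.
// If aIsGroupingCol is true, rows are grouped on (g, a). Otherwise they are
// grouped on g alone and the group carries the a of its first row, which is
// this sketch's stand-in for a ConstAgg column.
func groupThenFilter(rows []row, x int, aIsGroupingCol bool) map[int]int {
	type key struct{ g, a int }
	type group struct{ a, count int }
	groups := map[key]*group{}
	for _, r := range rows {
		k := key{g: r.g}
		if aIsGroupingCol {
			k.a = r.a
		}
		grp, ok := groups[k]
		if !ok {
			grp = &group{a: r.a}
			groups[k] = grp
		}
		grp.count++
	}
	out := map[int]int{}
	for k, grp := range groups {
		if grp.a == x {
			out[k.g] += grp.count
		}
	}
	return out
}

func main() {
	rows := []row{{g: 1, a: 10}, {g: 1, a: 20}, {g: 2, a: 20}}
	x := 10
	fmt.Println("filter, then group:        ", filterThenGroup(rows, x))
	fmt.Println("group on g, then filter:   ", groupThenFilter(rows, x, false))
	fmt.Println("group on g, a, then filter:", groupThenFilter(rows, x, true))
}
```

In this model, grouping on `g` alone counts the row with `a = 20` in group 1 even though it does not satisfy the equality, so the count for that group is 2 instead of 1. Grouping on `(g, a)` keeps that row in its own group, which the later filter removes.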

This patch fixes the bug by restricting the rule to only make column
replacements under a GroupBy or DistinctOn for columns that are part of
the grouping column set.
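
As a rough illustration of the restriction, the sketch below intersects the candidate replacement columns with two different column sets. The `colSet` type and its helpers are hypothetical stand-ins rather than the optimizer's own set type; the column IDs (`u:5` as the grouping column, `v:6` as a `ConstAgg` column) are taken from the GroupBy regression test in this commit.

```go
package main

import (
	"fmt"
	"sort"
)

// colSet is a toy stand-in for a set of column IDs.
type colSet map[int]struct{}

// newColSet builds a colSet from the given column IDs.
func newColSet(cols ...int) colSet {
	s := colSet{}
	for _, c := range cols {
		s[c] = struct{}{}
	}
	return s
}

// intersect returns the column IDs present in both s and other, sorted.
func (s colSet) intersect(other colSet) []int {
	out := []int{}
	for c := range s {
		if _, ok := other[c]; ok {
			out = append(out, c)
		}
	}
	sort.Ints(out)
	return out
}

func main() {
	grouping := newColSet(5)            // u:5
	groupingAndConst := newColSet(5, 6) // u:5 plus the ConstAgg column v:6
	candidates := newColSet(6)          // v:6 is equated with the outer column x

	// Old behavior: ConstAgg columns remain eligible for remapping.
	fmt.Println("grouping and ConstAgg cols:", candidates.intersect(groupingAndConst))
	// New behavior: only grouping columns remain eligible.
	fmt.Println("grouping cols only:        ", candidates.intersect(grouping))
}
```

With only the grouping columns allowed, the candidate set for this query is empty, which is consistent with the regression test expecting the rule not to fire.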

Fixes cockroachdb#130001

Release note (bug fix): Fixed a bug introduced in v23.1 that can cause
incorrect results when:
1. The query contains a correlated subquery.
2. The correlated subquery has a GroupBy or DistinctOn operator with an
   outer-column reference in its input.
3. The correlated subquery is in the input of a Select or Join operator.
4. The Select or Join has a filter that sets the outer-column reference from
   (2) equal to a non-outer column in the input of the grouping operator.
5. The grouping column set does not include the replacement column, and
   functionally determines the replacement column.
DrewKimball committed Sep 19, 2024
1 parent f09d877 commit f38e791
Showing 5 changed files with 213 additions and 59 deletions.
53 changes: 53 additions & 0 deletions pkg/sql/logictest/testdata/logic_test/subquery
@@ -970,6 +970,10 @@ SELECT (SELECT 2, 2) IN (SELECT x+1 FROM xy)
query error pgcode 42601 subquery has too many columns
SELECT (SELECT 2, 2) IN (SELECT x+1, y+1, x+y FROM xy)

subtest end

subtest regression_100561

# Regression test for #100561.
statement ok
CREATE TABLE t100561a (a INT);
@@ -987,3 +991,52 @@ SELECT * FROM (
WHERE tmp.x = bc.b + 1;
----
NULL 2 1 NULL

subtest end

subtest regression_130001

# Regression test for #130001.

# Adding a redundant filter to the LEFT JOIN should not change the result.
query TTTI rowsort
WITH a (colA) AS (
VALUES ('row-1'), ('row-2')
),
b (colB) AS (
VALUES ('row-1'), ('row-2')
)
SELECT a.colA, l.colB, l.colB_agg, l.count
FROM a
LEFT JOIN LATERAL (
SELECT colB, array_agg(colB) AS colB_agg, count(*) AS count
FROM b
WHERE colB = a.colA
GROUP BY colB
) l ON true;
----
row-1 row-1 {row-1} 1
row-2 row-2 {row-2} 1

query TTTI rowsort
WITH a (colA) AS (
VALUES ('row-1'), ('row-2')
),
b (colB) AS (
VALUES ('row-1'), ('row-2')
)
SELECT a.colA, l.colB, l.colB_agg, l.count
FROM a
LEFT JOIN LATERAL (
SELECT colB, array_agg(colB) AS colB_agg, count(*) AS count
FROM b
WHERE colB = a.colA
GROUP BY colB
) l ON true
-- redundant filter
AND l.colB = a.colA;
----
row-1 row-1 {row-1} 1
row-2 row-2 {row-2} 1

subtest end
23 changes: 20 additions & 3 deletions pkg/sql/opt/norm/decorrelate_funcs.go
@@ -1419,6 +1419,10 @@ func (c *CustomFuncs) tryRemapOuterCols(
// cycles. Any other rules that reuse this logic should reconsider the
// simplification made in getSubstituteColsSetOp.
//
// NOTE: care must be taken for operators that may aggregate or "group" rows.
// If rows for which the outer-column equality holds are grouped together with
// those for which it does not, the result set will be incorrect (see #130001).
//
// getSubstituteColsRelExpr copies substituteCols before performing any
// modifications, so the original ColSet is not mutated.
func (c *CustomFuncs) getSubstituteColsRelExpr(
@@ -1454,16 +1458,29 @@ func (c *CustomFuncs) getSubstituteColsRelExpr(
*memo.SemiJoinApplyExpr, *memo.AntiJoinExpr, *memo.AntiJoinApplyExpr:
// [PushSelectIntoJoinLeft]
// [PushSelectCondLeftIntoJoinLeftAndRight]
// NOTE: These join variants do perform "grouping" operations, but only on
// the right input, for which we do not push down the equality.
substituteCols = getSubstituteColsLeftSemiAntiJoin(t, substituteCols)
case *memo.GroupByExpr, *memo.DistinctOnExpr:
// [PushSelectIntoGroupBy]
- // Filters must refer only to grouping and ConstAgg columns.
+ // Filters must refer only to grouping columns. This ensures that the rows
+ // that satisfy the outer-column equality are grouped separately from those
+ // that do not. The rows that do not satisfy the equality will therefore not
+ // affect the values of the rows that do, and they will be filtered out
+ // later, ensuring that the transformation does not change the result set.
+ // See also #130001.
+ //
+ // NOTE: this is more restrictive than PushSelectIntoGroupBy, which also
+ // allows references to ConstAgg columns.
private := t.Private().(*memo.GroupingPrivate)
- aggs := t.Child(1).(*memo.AggregationsExpr)
- substituteCols.IntersectionWith(c.GroupingAndConstCols(private, *aggs))
+ substituteCols.IntersectionWith(private.GroupingCols)
case *memo.UnionExpr, *memo.UnionAllExpr, *memo.IntersectExpr,
*memo.IntersectAllExpr, *memo.ExceptExpr, *memo.ExceptAllExpr:
// [PushFilterIntoSetOp]
// NOTE: the distinct variants (Union, Intersect, Except) de-duplicate
// across all columns, so the requirement that filters only reference
// grouping columns is always satisfied. See the comment for DistinctOn
// above.
substituteCols = getSubstituteColsSetOp(t, substituteCols)
default:
// Filter push-down through this expression is not supported.
116 changes: 95 additions & 21 deletions pkg/sql/opt/norm/testdata/rules/decorrelate
@@ -6940,27 +6940,6 @@ inner-join (hash)
└── filters
└── a = xy.x

# Case with a GroupBy. Equality references a ConstAgg column.
norm expect=TryRemapJoinOuterColsRight format=hide-all
SELECT * FROM xy INNER JOIN LATERAL (SELECT v, corr(u, x) FROM uv GROUP BY u, v) ON v = x
----
project
└── inner-join (hash)
├── scan xy
├── group-by (hash)
│ ├── project
│ │ ├── scan uv
│ │ └── projections
│ │ └── v
│ └── aggregations
│ ├── corr
│ │ ├── u
│ │ └── x
│ └── const-agg
│ └── v
└── filters
└── v = xy.x

# Case with a DistinctOn.
norm expect=TryRemapJoinOuterColsRight format=hide-all
SELECT * FROM xy INNER JOIN LATERAL (SELECT DISTINCT ON (a) * FROM (SELECT *, b+x FROM ab)) ON a = x
@@ -7202,6 +7181,101 @@ project
└── projections
└── a:5 + x:1 [as="?column?":10, outer=(1,5), immutable]

# Regression test for #130001. Do not remap columns under a GroupBy or
# DistinctOn unless the grouping columns include the equivalent non-outer
# column.
#
# Case with a GroupBy.
norm expect-not=TryRemapJoinOuterColsRight
SELECT * FROM xy INNER JOIN LATERAL (SELECT v, corr(u, x) FROM uv GROUP BY u, v) ON v = x
----
project
├── columns: x:1!null y:2 v:6!null corr:10
├── fd: (1)-->(2), (1)==(6), (6)==(1)
└── group-by (hash)
├── columns: xy.x:1!null y:2 u:5!null v:6!null corr:10
├── grouping columns: u:5!null
├── key: (5)
├── fd: (1)-->(2), (5)-->(1,2,6,10), (1)==(6), (6)==(1)
├── project
│ ├── columns: x:9!null xy.x:1!null y:2 u:5!null v:6!null
│ ├── key: (5)
│ ├── fd: (1)-->(2), (5)-->(6), (1)==(6,9), (6)==(1,9), (9)==(1,6)
│ ├── inner-join (hash)
│ │ ├── columns: xy.x:1!null y:2 u:5!null v:6!null
│ │ ├── multiplicity: left-rows(zero-or-more), right-rows(zero-or-one)
│ │ ├── key: (5)
│ │ ├── fd: (1)-->(2), (5)-->(6), (1)==(6), (6)==(1)
│ │ ├── scan xy
│ │ │ ├── columns: xy.x:1!null y:2
│ │ │ ├── key: (1)
│ │ │ └── fd: (1)-->(2)
│ │ ├── scan uv
│ │ │ ├── columns: u:5!null v:6
│ │ │ ├── key: (5)
│ │ │ └── fd: (5)-->(6)
│ │ └── filters
│ │ └── v:6 = xy.x:1 [outer=(1,6), constraints=(/1: (/NULL - ]; /6: (/NULL - ]), fd=(1)==(6), (6)==(1)]
│ └── projections
│ └── xy.x:1 [as=x:9, outer=(1)]
└── aggregations
├── corr [as=corr:10, outer=(5,9)]
│ ├── u:5
│ └── x:9
├── const-agg [as=v:6, outer=(6)]
│ └── v:6
├── const-agg [as=y:2, outer=(2)]
│ └── y:2
└── const-agg [as=xy.x:1, outer=(1)]
└── xy.x:1

# Case with a DistinctOn.
norm expect-not=TryRemapJoinOuterColsRight
SELECT * FROM xy INNER JOIN LATERAL (SELECT DISTINCT ON (a, foo) * FROM (SELECT 1 AS foo, *, b+x FROM ab)) ON foo = x
----
distinct-on
├── columns: x:1!null y:2 foo:10!null a:5 b:6 "?column?":11
├── grouping columns: a:5
├── immutable
├── key: (5)
├── fd: ()-->(1,2,10), (6)-->(11), (5)-->(1,2,6,10,11)
├── project
│ ├── columns: foo:10!null "?column?":11 x:1!null y:2 a:5 b:6
│ ├── immutable
│ ├── fd: ()-->(1,2,10), (6)-->(11)
│ ├── inner-join (cross)
│ │ ├── columns: x:1!null y:2 a:5 b:6
│ │ ├── multiplicity: left-rows(zero-or-more), right-rows(zero-or-one)
│ │ ├── fd: ()-->(1,2)
│ │ ├── select
│ │ │ ├── columns: x:1!null y:2
│ │ │ ├── cardinality: [0 - 1]
│ │ │ ├── key: ()
│ │ │ ├── fd: ()-->(1,2)
│ │ │ ├── scan xy
│ │ │ │ ├── columns: x:1!null y:2
│ │ │ │ ├── key: (1)
│ │ │ │ └── fd: (1)-->(2)
│ │ │ └── filters
│ │ │ └── x:1 = 1 [outer=(1), constraints=(/1: [/1 - /1]; tight), fd=()-->(1)]
│ │ ├── scan ab
│ │ │ └── columns: a:5 b:6
│ │ └── filters (true)
│ └── projections
│ ├── 1 [as=foo:10]
│ └── b:6 + x:1 [as="?column?":11, outer=(1,6), immutable]
└── aggregations
├── first-agg [as=b:6, outer=(6)]
│ └── b:6
├── first-agg [as="?column?":11, outer=(11)]
│ └── "?column?":11
├── const-agg [as=foo:10, outer=(10)]
│ └── foo:10
├── const-agg [as=y:2, outer=(2)]
│ └── y:2
└── const-agg [as=x:1, outer=(1)]
└── x:1

# --------------------------------------------------
# TryRemapJoinOuterColsLeft
# --------------------------------------------------
78 changes: 44 additions & 34 deletions pkg/sql/opt/xform/testdata/external/hibernate
@@ -1226,30 +1226,30 @@ where
project
├── columns: id1_2_:1!null address2_2_:2 createdo3_2_:3 name4_2_:4 nickname5_2_:5 version6_2_:6!null
├── fd: (1)-->(2-6)
└── inner-join (lookup person [as=person0_])
├── columns: person0_.id:1!null address:2 createdon:3 name:4 nickname:5 version:6!null phones2_.id:9!null phones2_.person_id:12!null phones2_.order_id:13!null max:23!null
├── key columns: [12] = [1]
├── lookup columns are key
└── select
├── columns: person0_.id:1!null address:2 createdon:3 name:4 nickname:5 version:6!null phones2_.id:9!null phones2_.order_id:13!null max:23!null
├── key: (9)
├── fd: (1)-->(2-6), (9)-->(12,13,23), (13)==(23), (23)==(13), (1)==(12), (12)==(1)
├── select
│ ├── columns: phones2_.id:9!null phones2_.person_id:12!null phones2_.order_id:13!null max:23!null
├── fd: (1)-->(2-6), (9)-->(1-6,13,23), (13)==(23), (23)==(13)
├── group-by (hash)
│ ├── columns: person0_.id:1!null address:2 createdon:3 name:4 nickname:5 version:6!null phones2_.id:9!null phones2_.order_id:13 max:23!null
│ ├── grouping columns: phones2_.id:9!null
│ ├── key: (9)
│ ├── fd: (9)-->(12,13,23), (13)==(23), (23)==(13)
│ ├── group-by (hash)
│ │ ├── columns: phones2_.id:9!null phones2_.person_id:12!null phones2_.order_id:13 max:23!null
│ │ ├── grouping columns: phones2_.id:9!null
│ │ ├── key: (9)
│ │ ├── fd: (9)-->(12,13,23)
│ │ ├── inner-join (hash)
│ │ │ ├── columns: phones2_.id:9!null phones2_.phone_type:11!null phones2_.person_id:12!null phones2_.order_id:13 phones1_.person_id:19!null phones1_.order_id:20!null
│ │ │ ├── fd: ()-->(11), (9)-->(12,13), (12)==(19), (19)==(12)
│ │ │ ├── select
│ │ │ │ ├── columns: phones1_.person_id:19 phones1_.order_id:20!null
│ │ │ │ ├── scan phone [as=phones1_]
│ │ │ │ │ └── columns: phones1_.person_id:19 phones1_.order_id:20
│ │ │ │ └── filters
│ │ │ │ └── phones1_.order_id:20 IS NOT NULL [outer=(20), constraints=(/20: (/NULL - ]; tight)]
│ ├── fd: (1)-->(2-6), (9)-->(1-6,13,23)
│ ├── inner-join (hash)
│ │ ├── columns: person0_.id:1!null address:2 createdon:3 name:4 nickname:5 version:6!null phones2_.id:9!null phones2_.phone_type:11!null phones2_.person_id:12!null phones2_.order_id:13 phones1_.person_id:19!null phones1_.order_id:20!null
│ │ ├── fd: ()-->(11), (1)-->(2-6), (9)-->(12,13), (1)==(12,19), (12)==(1,19), (19)==(1,12)
│ │ ├── select
│ │ │ ├── columns: phones1_.person_id:19 phones1_.order_id:20!null
│ │ │ ├── scan phone [as=phones1_]
│ │ │ │ └── columns: phones1_.person_id:19 phones1_.order_id:20
│ │ │ └── filters
│ │ │ └── phones1_.order_id:20 IS NOT NULL [outer=(20), constraints=(/20: (/NULL - ]; tight)]
│ │ ├── inner-join (lookup person [as=person0_])
│ │ │ ├── columns: person0_.id:1!null address:2 createdon:3 name:4 nickname:5 version:6!null phones2_.id:9!null phones2_.phone_type:11!null phones2_.person_id:12!null phones2_.order_id:13
│ │ │ ├── key columns: [12] = [1]
│ │ │ ├── lookup columns are key
│ │ │ ├── key: (9)
│ │ │ ├── fd: ()-->(11), (1)-->(2-6), (9)-->(12,13), (1)==(12), (12)==(1)
│ │ │ ├── select
│ │ │ │ ├── columns: phones2_.id:9!null phones2_.phone_type:11!null phones2_.person_id:12 phones2_.order_id:13
│ │ │ │ ├── key: (9)
@@ -1260,18 +1260,28 @@ project
│ │ │ │ │ └── fd: (9)-->(11-13)
│ │ │ │ └── filters
│ │ │ │ └── phones2_.phone_type:11 = 'LAND_LINE' [outer=(11), constraints=(/11: [/'LAND_LINE' - /'LAND_LINE']; tight), fd=()-->(11)]
│ │ │ └── filters
│ │ │ └── phones2_.person_id:12 = phones1_.person_id:19 [outer=(12,19), constraints=(/12: (/NULL - ]; /19: (/NULL - ]), fd=(12)==(19), (19)==(12)]
│ │ └── aggregations
│ │ ├── max [as=max:23, outer=(20)]
│ │ │ └── phones1_.order_id:20
│ │ ├── const-agg [as=phones2_.person_id:12, outer=(12)]
│ │ │ └── phones2_.person_id:12
│ │ └── const-agg [as=phones2_.order_id:13, outer=(13)]
│ │ └── phones2_.order_id:13
│ └── filters
│ └── phones2_.order_id:13 = max:23 [outer=(13,23), constraints=(/13: (/NULL - ]; /23: (/NULL - ]), fd=(13)==(23), (23)==(13)]
└── filters (true)
│ │ │ └── filters (true)
│ │ └── filters
│ │ └── person0_.id:1 = phones1_.person_id:19 [outer=(1,19), constraints=(/1: (/NULL - ]; /19: (/NULL - ]), fd=(1)==(19), (19)==(1)]
│ └── aggregations
│ ├── max [as=max:23, outer=(20)]
│ │ └── phones1_.order_id:20
│ ├── const-agg [as=phones2_.order_id:13, outer=(13)]
│ │ └── phones2_.order_id:13
│ ├── const-agg [as=address:2, outer=(2)]
│ │ └── address:2
│ ├── const-agg [as=createdon:3, outer=(3)]
│ │ └── createdon:3
│ ├── const-agg [as=name:4, outer=(4)]
│ │ └── name:4
│ ├── const-agg [as=nickname:5, outer=(5)]
│ │ └── nickname:5
│ ├── const-agg [as=version:6, outer=(6)]
│ │ └── version:6
│ └── const-agg [as=person0_.id:1, outer=(1)]
│ └── person0_.id:1
└── filters
└── phones2_.order_id:13 = max:23 [outer=(13,23), constraints=(/13: (/NULL - ]; /23: (/NULL - ]), fd=(13)==(23), (23)==(13)]

opt
select
2 changes: 1 addition & 1 deletion pkg/sql/opt/xform/testdata/rules/disjunction_in_join
@@ -2743,7 +2743,7 @@ project
│ │ ├── (a1:1 = c1:7) OR (c1:7 = c2:8) [outer=(1,7,8), constraints=(/7: (/NULL - ])]
│ │ └── a2:2 = c2:8 [outer=(2,8), constraints=(/2: (/NULL - ]; /8: (/NULL - ]), fd=(2)==(8), (8)==(2)]
│ └── filters
│ └── (a1:1 = b1:14) OR (a1:1 = c2:8) [outer=(1,8,14), constraints=(/1: (/NULL - ])]
│ └── (a1:1 = b1:14) OR (a1:1 = a2:2) [outer=(1,2,14), constraints=(/1: (/NULL - ])]
└── aggregations
├── const-agg [as=c1:7, outer=(7)]
│ └── c1:7