From 2fde46c901c19a95262b9b94389b64ccbf1e65f7 Mon Sep 17 00:00:00 2001 From: Sean Smith Date: Mon, 16 Dec 2024 19:03:10 -0600 Subject: [PATCH 1/2] feat: Natural join --- .../logical/binder/bind_query/bind_from.rs | 50 ++++++++++- crates/rayexec_parser/src/ast/from.rs | 36 +++++++- slt/standard/join/natural_join.slt | 85 +++++++++++++++++++ 3 files changed, 166 insertions(+), 5 deletions(-) create mode 100644 slt/standard/join/natural_join.slt diff --git a/crates/rayexec_execution/src/logical/binder/bind_query/bind_from.rs b/crates/rayexec_execution/src/logical/binder/bind_query/bind_from.rs index b17a78a8d..87d07f0aa 100644 --- a/crates/rayexec_execution/src/logical/binder/bind_query/bind_from.rs +++ b/crates/rayexec_execution/src/logical/binder/bind_query/bind_from.rs @@ -1,4 +1,4 @@ -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use std::sync::Arc; use rayexec_bullet::datatype::DataType; @@ -550,7 +550,41 @@ impl<'a> FromBinder<'a> { (Vec::new(), using_cols) } ast::JoinCondition::Natural => { - not_implemented!("NATURAL join") + // Get tables refs from the left. + // + // We want to prune these tables out from the right. Tables are + // implicitly in scope on the right for lateral references. + let left_tables: HashSet<_> = bind_context + .iter_tables_in_scope(left_idx)? + .map(|table| table.reference) + .collect(); + + // Get columns from the left. + let left_cols: HashSet<_> = bind_context + .iter_tables_in_scope(left_idx)? + .flat_map(|table| table.column_names.iter()) + .collect(); + + // Get columns from the right, skipping columns from tables that + // would generate a lateral reference. + let right_cols = bind_context + .iter_tables_in_scope(right_idx)? + .filter(|table| !left_tables.contains(&table.reference)) + .flat_map(|table| table.column_names.iter()); + + let mut common = Vec::new(); + + // Now collect the columns that are common in both. + // + // Manually iterate over using a hash set intersection to keep + // the order of columns consistent. + for right_col in right_cols { + if left_cols.contains(right_col) { + common.push(right_col.clone()); + } + } + + (Vec::new(), common) } ast::JoinCondition::None => (Vec::new(), Vec::new()), }; @@ -612,8 +646,16 @@ impl<'a> FromBinder<'a> { }, }; - // Add USING column to _current_ scope. - bind_context.append_using_column(self.current, using_column)?; + // Add USING column to _current_ scope if we don't already have an + // equivalent column in our using set. + let already_using = bind_context + .get_using_columns(self.current)? + .iter() + .any(|c| c.column == using_column.column); + + if !already_using { + bind_context.append_using_column(self.current, using_column)?; + } // Generate additional equality condition. // TODO: Probably make this a method on the expr binder. Easy to miss the cast. diff --git a/crates/rayexec_parser/src/ast/from.rs b/crates/rayexec_parser/src/ast/from.rs index 90d98f146..ec273879a 100644 --- a/crates/rayexec_parser/src/ast/from.rs +++ b/crates/rayexec_parser/src/ast/from.rs @@ -60,6 +60,9 @@ impl AstParseable for FromNode { }), } } else { + // Optional NATURAL prefixing the join type. + let natural = parser.parse_keyword(Keyword::NATURAL); + let kw = match parser.peek() { Some(tok) => match tok.keyword() { Some(kw) => kw, @@ -166,7 +169,13 @@ impl AstParseable for FromNode { parser.parse_parenthesized_comma_separated(Ident::parse)?, ) } - _ => JoinCondition::None, + _ => { + if natural { + JoinCondition::Natural + } else { + JoinCondition::None + } + } }; node = FromNode { @@ -688,4 +697,29 @@ mod tests { }; assert_eq!(expected, node, "left:\n{expected:#?}\nright:\n{node:#?}"); } + + #[test] + fn natural_inner_join_lateral() { + let node: FromNode<_> = parse_ast("t1 NATURAL INNER JOIN t2").unwrap(); + let expected = FromNode { + alias: None, + body: FromNodeBody::Join(FromJoin { + left: Box::new(FromNode { + alias: None, + body: FromNodeBody::BaseTable(FromBaseTable { + reference: ObjectReference::from_strings(["t1"]), + }), + }), + right: Box::new(FromNode { + alias: None, + body: FromNodeBody::BaseTable(FromBaseTable { + reference: ObjectReference::from_strings(["t2"]), + }), + }), + join_type: JoinType::Inner, + join_condition: JoinCondition::Natural, + }), + }; + assert_eq!(expected, node, "left:\n{expected:#?}\nright:\n{node:#?}"); + } } diff --git a/slt/standard/join/natural_join.slt b/slt/standard/join/natural_join.slt new file mode 100644 index 000000000..51ec6aa62 --- /dev/null +++ b/slt/standard/join/natural_join.slt @@ -0,0 +1,85 @@ +# Natural join tests + +statement ok +CREATE TEMP TABLE t1 (num INT, name TEXT); + +statement ok +CREATE TEMP TABLE t2 (num INT, value TEXT); + +statement ok +INSERT INTO t1 VALUES (1, 'a'), (2, 'b'), (3, 'c'); + +statement ok +INSERT INTO t2 VALUES (1, 'xxx'), (3, 'yyy'), (5, 'zzz'); + +query ITT +SELECT * FROM t1 NATURAL INNER JOIN t2 ORDER BY num; +---- +1 a xxx +3 c yyy + +# Order by qualified name. +query ITT +SELECT * FROM t1 NATURAL INNER JOIN t2 ORDER BY t1.num; +---- +1 a xxx +3 c yyy + +query ITT +SELECT * FROM t1 NATURAL INNER JOIN t2 ORDER BY t2.num; +---- +1 a xxx +3 c yyy + +statement ok +CREATE TEMP TABLE t3 (extra TEXT, num INT); + +statement ok +INSERT INTO t3 VALUES ('cat', 3), ('dog', 4), ('goose', 5); + +query ITTT +SELECT * FROM t1 NATURAL INNER JOIN t2 NATURAL INNER JOIN t3; +---- +3 c yyy cat + +query IITT +SELECT t1.num, * FROM t1 NATURAL INNER JOIN t2 NATURAL INNER JOIN t3; +---- +3 3 c yyy cat + +query I +SELECT t1.num FROM t1 NATURAL INNER JOIN t2 NATURAL INNER JOIN t3; +---- +3 + +query ITTT +SELECT * FROM t1 NATURAL INNER JOIN t2 NATURAL LEFT JOIN t3 ORDER BY 1; +---- +1 a xxx NULL +3 c yyy cat + +query ITTT +SELECT * FROM t1 NATURAL RIGHT JOIN t2 NATURAL LEFT JOIN t3 ORDER BY value; +---- +1 a xxx NULL +3 c yyy cat +5 NULL zzz goose + +query ITTT +SELECT * FROM t1 NATURAL RIGHT JOIN t2 NATURAL LEFT JOIN t3 ORDER BY t3.num, t1.num; +---- +3 c yyy cat +5 NULL zzz goose +1 a xxx NULL + +# USING columns from subqueries +query ITT +SELECT * FROM t1 NATURAL JOIN (SELECT * FROM t2) ORDER BY 1; +---- +1 a xxx +3 c yyy + +query IT +SELECT * FROM t1 NATURAL JOIN (SELECT 3) s(num) ORDER BY 1; +---- +3 c From e5e7a928e8315bb56a540b0c62be397efa491f2c Mon Sep 17 00:00:00 2001 From: Sean Smith Date: Mon, 16 Dec 2024 19:07:25 -0600 Subject: [PATCH 2/2] fixup! feat: Natural join --- slt/standard/join/natural_join.slt | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/slt/standard/join/natural_join.slt b/slt/standard/join/natural_join.slt index 51ec6aa62..68a3f12ed 100644 --- a/slt/standard/join/natural_join.slt +++ b/slt/standard/join/natural_join.slt @@ -79,7 +79,17 @@ SELECT * FROM t1 NATURAL JOIN (SELECT * FROM t2) ORDER BY 1; 1 a xxx 3 c yyy +query IT +SELECT * FROM t1 NATURAL JOIN (SELECT 3) s(num); +---- +3 c + query IT SELECT * FROM t1 NATURAL JOIN (SELECT 3) s(num) ORDER BY 1; ---- 3 c + +query IT +SELECT * FROM t1 NATURAL JOIN (SELECT 3 AS num); +---- +3 c