Commit: impl
desmondcheongzx committed Feb 19, 2025
1 parent a42e55a commit a8c4fcc
Showing 4 changed files with 205 additions and 30 deletions.
16 changes: 8 additions & 8 deletions src/daft-local-execution/src/pipeline.rs
@@ -344,8 +344,8 @@ pub fn physical_plan_to_pipeline(
StatsState::Materialized(left_stats),
StatsState::Materialized(right_stats),
) => {
- let left_size = left_stats.approx_stats.size_bytes;
- let right_size = right_stats.approx_stats.size_bytes;
+ let left_size = left_stats.approx_stats.num_rows;
+ let right_size = right_stats.approx_stats.num_rows;
left_size <= right_size
}
// If stats are only available on the right side of the join, and the upper bound bytes on the
@@ -363,8 +363,8 @@ pub fn physical_plan_to_pipeline(
StatsState::Materialized(left_stats),
StatsState::Materialized(right_stats),
) => {
- let left_size = left_stats.approx_stats.size_bytes;
- let right_size = right_stats.approx_stats.size_bytes;
+ let left_size = left_stats.approx_stats.num_rows;
+ let right_size = right_stats.approx_stats.num_rows;
right_size as f64 >= left_size as f64 * 1.5
}
// If stats are only available on the left side of the join, and the upper bound bytes on the left
@@ -382,8 +382,8 @@ pub fn physical_plan_to_pipeline(
StatsState::Materialized(left_stats),
StatsState::Materialized(right_stats),
) => {
- let left_size = left_stats.approx_stats.size_bytes;
- let right_size = right_stats.approx_stats.size_bytes;
+ let left_size = left_stats.approx_stats.num_rows;
+ let right_size = right_stats.approx_stats.num_rows;
(right_size as f64 * 1.5) >= left_size as f64
}
// If stats are only available on the right side of the join, and the upper bound bytes on the
@@ -401,8 +401,8 @@ pub fn physical_plan_to_pipeline(
StatsState::Materialized(left_stats),
StatsState::Materialized(right_stats),
) => {
- let left_size = left_stats.approx_stats.size_bytes;
- let right_size = right_stats.approx_stats.size_bytes;
+ let left_size = left_stats.approx_stats.num_rows;
+ let right_size = right_stats.approx_stats.num_rows;
right_size as f64 > left_size as f64 * 1.5
}
// If stats are only available on the left side of the join, and the upper bound bytes on the left
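Note on the hunks above: the build-side heuristic now compares approximate row counts instead of byte sizes. The following standalone sketch restates that comparison; the function name and the `require_margin` flag are illustrative stand-ins rather than Daft's actual API, while the 1.5x factor is taken directly from the conditions shown in the diff.

```rust
/// Illustrative sketch (not Daft's API) of the build-side heuristic above:
/// when both sides have materialized stats, compare `approx_stats.num_rows`
/// (previously `size_bytes`) and prefer building on the smaller side, with an
/// optional 1.5x margin as in `right_size as f64 >= left_size as f64 * 1.5`.
fn prefer_left_as_build_side(left_rows: usize, right_rows: usize, require_margin: bool) -> bool {
    if require_margin {
        // Require the right (probe) side to be at least 1.5x larger.
        right_rows as f64 >= left_rows as f64 * 1.5
    } else {
        // Plain comparison: build on the left if it is no larger than the right.
        left_rows <= right_rows
    }
}

fn main() {
    assert!(prefer_left_as_build_side(1_000, 10_000, true));
    assert!(!prefer_left_as_build_side(10_000, 1_000, false));
    assert!(prefer_left_as_build_side(1_000, 1_400, false));
    assert!(!prefer_left_as_build_side(1_000, 1_400, true)); // 1_400 < 1_000 * 1.5
}
```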
6 changes: 1 addition & 5 deletions src/daft-logical-plan/src/builder/mod.rs
@@ -796,11 +796,7 @@ impl LogicalPlanBuilder {
let unoptimized_plan = self.build();

let optimizer = OptimizerBuilder::default()
- .when(
-     cfg.as_ref()
-         .map_or(false, |conf| conf.enable_join_reordering),
-     |builder| builder.reorder_joins(),
- )
+ .reorder_joins()
.simplify_expressions()
.build();

@@ -241,7 +241,7 @@ mod tests {
let order = $orderer.order(&graph);
assert!(JoinOrderTree::order_eq(&order, &$optimal_order));
// Check that the number of join conditions does not increase due to join edge inference.
- assert_eq!(JoinOrderTree::num_join_conditions(&order), num_edges);
+ assert!(JoinOrderTree::num_join_conditions(&order) <= num_edges);
};
}

@@ -403,6 +403,138 @@ mod tests {
create_and_test_join_order!(nodes, edges, BruteForceJoinOrderer {}, optimal_order);
}

#[test]
fn test_brute_force_order_mock_tpch_sub_q9() {
let nodes = vec![
("nation", 25),
("supplier", 100_000),
("part", 100_000),
("partsupp", 8_000_000),
];
let name_to_id = node_to_id_map(nodes.clone());
let edges = vec![
JoinEdge {
node1: name_to_id["partsupp"],
node1_col_name: "ps_partkey".to_string(),
node2: name_to_id["part"],
node2_col_name: "p_partkey".to_string(),
total_domain: 2_000_000,
},
JoinEdge {
node1: name_to_id["partsupp"],
node1_col_name: "ps_suppkey".to_string(),
node2: name_to_id["supplier"],
node2_col_name: "s_suppkey".to_string(),
total_domain: 100_000,
},
JoinEdge {
node1: name_to_id["supplier"],
node1_col_name: "s_nationkey".to_string(),
node2: name_to_id["nation"],
node2_col_name: "n_nationkey".to_string(),
total_domain: 25,
},
];
let optimal_order = test_join(
test_join(
test_relation(name_to_id["nation"]),
test_relation(name_to_id["supplier"]),
),
test_join(
test_relation(name_to_id["part"]),
test_relation(name_to_id["partsupp"]),
),
);
create_and_test_join_order!(nodes, edges, BruteForceJoinOrderer {}, optimal_order);
}

#[test]
fn test_brute_force_order_mock_tpch_q9() {
let nodes = vec![
("nation", 22),
("orders", 1_350_000),
("lineitem", 4_374_885),
("supplier", 8_100),
("part", 18_000),
("partsupp", 648_000),
];
let name_to_id = node_to_id_map(nodes.clone());
let edges = vec![
JoinEdge {
node1: name_to_id["partsupp"],
node1_col_name: "ps_partkey".to_string(),
node2: name_to_id["part"],
node2_col_name: "p_partkey".to_string(),
total_domain: 200_000,
},
JoinEdge {
node1: name_to_id["partsupp"],
node1_col_name: "ps_partkey".to_string(),
node2: name_to_id["lineitem"],
node2_col_name: "l_partkey".to_string(),
total_domain: 200_000,
},
JoinEdge {
node1: name_to_id["partsupp"],
node1_col_name: "ps_suppkey".to_string(),
node2: name_to_id["lineitem"],
node2_col_name: "l_suppkey".to_string(),
total_domain: 10_000,
},
JoinEdge {
node1: name_to_id["partsupp"],
node1_col_name: "ps_suppkey".to_string(),
node2: name_to_id["supplier"],
node2_col_name: "s_suppkey".to_string(),
total_domain: 10_000,
},
JoinEdge {
node1: name_to_id["orders"],
node1_col_name: "o_orderkey".to_string(),
node2: name_to_id["lineitem"],
node2_col_name: "l_orderkey".to_string(),
total_domain: 1_500_000,
},
JoinEdge {
node1: name_to_id["lineitem"],
node1_col_name: "l_partkey".to_string(),
node2: name_to_id["part"],
node2_col_name: "p_partkey".to_string(),
total_domain: 200_000,
},
JoinEdge {
node1: name_to_id["lineitem"],
node1_col_name: "l_suppkey".to_string(),
node2: name_to_id["supplier"],
node2_col_name: "s_suppkey".to_string(),
total_domain: 10_000,
},
JoinEdge {
node1: name_to_id["supplier"],
node1_col_name: "s_nationkey".to_string(),
node2: name_to_id["nation"],
node2_col_name: "n_nationkey".to_string(),
total_domain: 25,
},
];
let optimal_order = test_join(
test_relation(name_to_id["orders"]),
test_join(
test_relation(name_to_id["lineitem"]),
test_join(
test_join(
test_relation(name_to_id["nation"]),
test_relation(name_to_id["supplier"]),
),
test_join(
test_relation(name_to_id["part"]),
test_relation(name_to_id["partsupp"]),
),
),
),
);
create_and_test_join_order!(nodes, edges, BruteForceJoinOrderer {}, optimal_order);
}
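As a rough sanity check on the expected order in test_brute_force_order_mock_tpch_sub_q9 above, the snippet below compares plan costs under an assumed model: each join estimated as |L| * |R| / total_domain, and plan cost taken as the sum of intermediate cardinalities. The model is an illustration only and may not match Daft's exact cost function; the row counts and total domains are copied from the test.

```rust
// Assumed cost model for illustration: join estimate = |L| * |R| / total_domain,
// plan cost = sum of intermediate cardinalities. Numbers come from
// test_brute_force_order_mock_tpch_sub_q9.
fn est_join(left: f64, right: f64, total_domain: f64) -> f64 {
    (left * right / total_domain).max(1.0)
}

fn main() {
    let (nation, supplier, part, partsupp) = (25.0, 100_000.0, 100_000.0, 8_000_000.0);

    // Expected optimal (bushy) order: (nation ⋈ supplier) ⋈ (part ⋈ partsupp).
    let ns = est_join(nation, supplier, 25.0); // 100_000
    let pp = est_join(part, partsupp, 2_000_000.0); // 400_000
    let bushy_cost = ns + pp + est_join(ns, pp, 100_000.0); // 900_000

    // A left-deep alternative: ((partsupp ⋈ part) ⋈ supplier) ⋈ nation.
    let a = est_join(partsupp, part, 2_000_000.0); // 400_000
    let b = est_join(a, supplier, 100_000.0); // 400_000
    let left_deep_cost = a + b + est_join(b, nation, 25.0); // 1_200_000

    // Under this model, the bushy plan the test expects is indeed cheaper.
    assert!(bushy_cost < left_deep_cost);
    println!("bushy = {bushy_cost}, left-deep = {left_deep_cost}");
}
```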
#[test]
fn test_brute_force_order_star_schema() {
let nodes = vec![
@@ -61,7 +61,7 @@ impl JoinOrderTree {
#[cfg(test)]
// Check if the join structure is the same, regardless of cardinality or join conditions.
pub(super) fn order_eq(this: &Self, other: &Self) -> bool {
- match (this, other) {
+ let res = match (this, other) {
(JoinOrderTree::Relation(id1, _), JoinOrderTree::Relation(id2, _)) => id1 == id2,
(
JoinOrderTree::Join(left1, right1, _, _),
@@ -71,7 +71,8 @@
|| (Self::order_eq(left1, right2) && Self::order_eq(right1, left2))
}
_ => false,
- }
+ };
+ res
}

#[cfg(test)]
@@ -144,7 +145,7 @@ pub(super) struct JoinAdjList {
// The total domain is the number of distinct values in the columns that are part of the equivalence set. For pk-fk joins,
// this would be the number of primary keys. In the absence of ndv statistics, we take the smallest table in the equivalence set,
// assume it's the primary key table, and use its cardinality as the total domain.
- total_domains: Vec<usize>,
+ pub total_domains: Vec<usize>,
}

impl std::fmt::Display for JoinAdjList {
@@ -251,20 +252,21 @@ impl JoinAdjList {
}
}

// Helper function that estimates the total domain for a join between two relations.
fn get_estimated_total_domain(&self, left_plan: &LogicalPlanRef, right_plan: &LogicalPlanRef) -> usize {
let left_stats = left_plan.materialized_stats();
let right_stats = right_plan.materialized_stats();
// We multiply the number of rows by the reciprocal of the selectivity to get the original total domain.
let left_rows = left_stats.approx_stats.num_rows as f64 / left_stats.approx_stats.acc_selectivity.max(0.01);
let right_rows = right_stats.approx_stats.num_rows as f64 / right_stats.approx_stats.acc_selectivity.max(0.01);
left_rows.min(right_rows).max(1.0) as usize
}

pub(super) fn add_bidirectional_edge(&mut self, node1: JoinNode, node2: JoinNode) {
let node1_id = self.get_or_create_plan_id(&node1.plan);
let node2_id = self.get_or_create_plan_id(&node2.plan);
// Find the minimal total domain for the join columns, either from the current nodes or from the existing total domains.
- let mut td = {
-     let node1_stats = node1.plan.materialized_stats();
-     let node2_stats = node2.plan.materialized_stats();
-     // We multiply the number of rows by the reciprocal of the selectivity to get the original total domain.
-     let node1_rows = node1_stats.approx_stats.num_rows as f64
-         / node1_stats.approx_stats.acc_selectivity.max(0.01);
-     let node2_rows = node2_stats.approx_stats.num_rows as f64
-         / node2_stats.approx_stats.acc_selectivity.max(0.01);
-     node1_rows.min(node2_rows).max(1.0) as usize
- };
+ let mut td = self.get_estimated_total_domain(&node1.plan, &node2.plan);
if let Some(equivalence_set_id) = self
.equivalence_set_map
.get(&(node1_id, node1.relation_name.clone()))
@@ -365,26 +367,71 @@ impl JoinAdjList {
// Grab the minimum spanning tree of join conditions that connect the left and right trees, i.e. we take at most one join condition
// from each equivalence set of join conditions.
let mut conds = vec![];
- let mut seen_equivalence_set_ids = HashSet::new();
+ let mut added_equivalence_set_id_for_td = HashSet::new();
+ let mut added_equivalence_set_id_for_conds = HashSet::new();
+ let mut double_counted_equivalence_set_ids = HashSet::new();
let mut td = 1;
for left_node in left.iter() {
if let Some(neighbors) = self.edges.get(&left_node) {
for right_node in right.iter() {
if let Some(edges) = neighbors.get(&right_node) {
- for edge in edges {
// conds.extend(edges.iter().cloned());
// for edge in edges {
// let equivalence_set_id = self
// .equivalence_set_map
// .get(&(left_node, edge.left_on.clone()))
// .expect("Left join condition should be part of an equivalence set");
// if seen_equivalence_set_ids.insert(*equivalence_set_id) {
// if edges.len() == 1 {
// td *= self.total_domains[*equivalence_set_id];
// }
// conds.push(edge.clone());
// }
// }
if edges.len() == 1 {
let edge = edges[0].clone();
let equivalence_set_id = self
.equivalence_set_map
.get(&(left_node, edge.left_on.clone()))
.expect("Left join condition should be part of an equivalence set");
- if seen_equivalence_set_ids.insert(*equivalence_set_id) {
if added_equivalence_set_id_for_td.insert(*equivalence_set_id) {
td *= self.total_domains[*equivalence_set_id];
}
if added_equivalence_set_id_for_conds.insert(*equivalence_set_id) {
conds.push(edge.clone());
}
}
if edges.len() > 1 {
let node1_plan = self
.id_to_plan
.get(&left_node)
.expect("left id not found in adj list");
let node2_plan = self
.id_to_plan
.get(&right_node)
.expect("right id not found in adj list");
td *= self.get_estimated_total_domain(node1_plan, node2_plan);
for edge in edges {
let equivalence_set_id = self
.equivalence_set_map
.get(&(left_node, edge.left_on.clone()))
.expect("Left join condition should be part of an equivalence set");
if added_equivalence_set_id_for_conds.insert(*equivalence_set_id) {
conds.push(edge.clone());
}
double_counted_equivalence_set_ids.insert(*equivalence_set_id);
}
}
}
}
}
}
for equivalence_set_id in double_counted_equivalence_set_ids {
if added_equivalence_set_id_for_td.contains(&equivalence_set_id) {
td /= self.total_domains[equivalence_set_id].max(1);
}
}
td = td.max(1);
(conds, td)
}
}
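The `edges.len() > 1` branch above handles relation pairs connected by more than one join condition (e.g. lineitem and partsupp joining on both partkey and suppkey in the mock Q9 test). Multiplying every per-column total domain into `td` would shrink the cardinality estimate far too aggressively, so the code instead falls back to a single pairwise domain estimate and later divides out any per-equivalence-set domains that were double counted. A rough numerical illustration, assuming `td` divides the product of the input cardinalities in the cost estimate (an assumption for this sketch) and using made-up row counts:

```rust
// Hypothetical numbers: lineitem has 6M rows, partsupp has 800K rows, and they
// join on both partkey (domain 200_000) and suppkey (domain 10_000).
fn main() {
    let (lineitem, partsupp): (f64, f64) = (6_000_000.0, 800_000.0);

    // Treating the two join conditions as independent multiplies both domains
    // into td and collapses the estimate to almost nothing:
    let naive = lineitem * partsupp / (200_000.0 * 10_000.0); // ~2_400 rows

    // Using one pairwise domain — the smaller input's (selectivity-adjusted)
    // cardinality, as in `get_estimated_total_domain` — stays plausible:
    let pairwise = lineitem * partsupp / 800_000.0; // 6_000_000 rows

    assert!(pairwise > naive);
    println!("naive = {naive}, pairwise = {pairwise}");
}
```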
