From c92a3af36b9fb3710c73e58efe4e26325349dce8 Mon Sep 17 00:00:00 2001 From: Lan Lou Date: Thu, 14 Nov 2024 10:11:25 -0500 Subject: [PATCH 01/51] Partial migration of filter --- optd-cost-model/src/common/nodes.rs | 27 +- .../src/common/predicates/attr_ref_pred.rs | 45 +++ .../src/common/predicates/cast_pred.rs | 44 +++ .../src/common/predicates/constant_pred.rs | 189 +++++++++++ .../src/common/predicates/data_type_pred.rs | 40 +++ optd-cost-model/src/common/predicates/mod.rs | 3 + optd-cost-model/src/cost/filter.rs | 306 ++++++++++++++++++ optd-cost-model/src/cost_model.rs | 4 +- optd-cost-model/src/lib.rs | 1 + 9 files changed, 656 insertions(+), 3 deletions(-) create mode 100644 optd-cost-model/src/common/predicates/attr_ref_pred.rs create mode 100644 optd-cost-model/src/common/predicates/cast_pred.rs create mode 100644 optd-cost-model/src/common/predicates/data_type_pred.rs diff --git a/optd-cost-model/src/common/nodes.rs b/optd-cost-model/src/common/nodes.rs index 38e2500..0bbcca1 100644 --- a/optd-cost-model/src/common/nodes.rs +++ b/optd-cost-model/src/common/nodes.rs @@ -77,7 +77,7 @@ pub struct PredicateNode { /// A generic predicate node type pub typ: PredicateType, /// Child predicate nodes, always materialized - pub children: Vec, + pub children: Vec, /// Data associated with the predicate, if any pub data: Option, } @@ -94,3 +94,28 @@ impl std::fmt::Display for PredicateNode { write!(f, ")") } } + +impl PredicateNode { + pub fn child(&self, idx: usize) -> ArcPredicateNode { + self.children[idx].clone() + } + + pub fn unwrap_data(&self) -> Value { + self.data.clone().unwrap() + } +} +pub trait ReprPredicateNode: 'static + Clone { + fn into_pred_node(self) -> ArcPredicateNode; + + fn from_pred_node(pred_node: ArcPredicateNode) -> Option; +} + +impl ReprPredicateNode for ArcPredicateNode { + fn into_pred_node(self) -> ArcPredicateNode { + self + } + + fn from_pred_node(pred_node: ArcPredicateNode) -> Option { + Some(pred_node) + } +} diff --git 
a/optd-cost-model/src/common/predicates/attr_ref_pred.rs b/optd-cost-model/src/common/predicates/attr_ref_pred.rs new file mode 100644 index 0000000..34e1901 --- /dev/null +++ b/optd-cost-model/src/common/predicates/attr_ref_pred.rs @@ -0,0 +1,45 @@ +use crate::common::{ + nodes::{ArcPredicateNode, PredicateNode, PredicateType, ReprPredicateNode}, + values::Value, +}; + +#[derive(Clone, Debug)] +pub struct AttributeRefPred(pub ArcPredicateNode); + +impl AttributeRefPred { + /// Creates a new `ColumnRef` expression. + pub fn new(column_idx: usize) -> AttributeRefPred { + // this conversion is always safe since usize is at most u64 + let u64_column_idx = column_idx as u64; + AttributeRefPred( + PredicateNode { + typ: PredicateType::AttributeRef, + children: vec![], + data: Some(Value::UInt64(u64_column_idx)), + } + .into(), + ) + } + + fn get_data_usize(&self) -> usize { + self.0.data.as_ref().unwrap().as_u64() as usize + } + + /// Gets the column index. + pub fn index(&self) -> usize { + self.get_data_usize() + } +} + +impl ReprPredicateNode for AttributeRefPred { + fn into_pred_node(self) -> ArcPredicateNode { + self.0 + } + + fn from_pred_node(pred_node: ArcPredicateNode) -> Option { + if pred_node.typ != PredicateType::AttributeRef { + return None; + } + Some(Self(pred_node)) + } +} diff --git a/optd-cost-model/src/common/predicates/cast_pred.rs b/optd-cost-model/src/common/predicates/cast_pred.rs new file mode 100644 index 0000000..eaafca9 --- /dev/null +++ b/optd-cost-model/src/common/predicates/cast_pred.rs @@ -0,0 +1,44 @@ +use arrow_schema::DataType; + +use crate::common::nodes::{ArcPredicateNode, PredicateNode, PredicateType, ReprPredicateNode}; + +use super::data_type_pred::DataTypePred; + +#[derive(Clone, Debug)] +pub struct CastPred(pub ArcPredicateNode); + +impl CastPred { + pub fn new(child: ArcPredicateNode, cast_to: DataType) -> Self { + CastPred( + PredicateNode { + typ: PredicateType::Cast, + children: vec![child, 
DataTypePred::new(cast_to).into_pred_node()], + data: None, + } + .into(), + ) + } + + pub fn child(&self) -> ArcPredicateNode { + self.0.child(0) + } + + pub fn cast_to(&self) -> DataType { + DataTypePred::from_pred_node(self.0.child(1)) + .unwrap() + .data_type() + } +} + +impl ReprPredicateNode for CastPred { + fn into_pred_node(self) -> ArcPredicateNode { + self.0 + } + + fn from_pred_node(pred_node: ArcPredicateNode) -> Option { + if !matches!(pred_node.typ, PredicateType::Cast) { + return None; + } + Some(Self(pred_node)) + } +} diff --git a/optd-cost-model/src/common/predicates/constant_pred.rs b/optd-cost-model/src/common/predicates/constant_pred.rs index 7923ae4..2fa06ae 100644 --- a/optd-cost-model/src/common/predicates/constant_pred.rs +++ b/optd-cost-model/src/common/predicates/constant_pred.rs @@ -1,5 +1,13 @@ +use std::sync::Arc; + +use arrow_schema::{DataType, IntervalUnit}; use serde::{Deserialize, Serialize}; +use crate::common::{ + nodes::{ArcPredicateNode, PredicateNode, PredicateType, ReprPredicateNode}, + values::{SerializableOrderedF64, Value}, +}; + /// TODO: documentation #[derive(Copy, Clone, PartialEq, Eq, Hash, Debug, Serialize, Deserialize)] pub enum ConstantType { @@ -19,3 +27,184 @@ pub enum ConstantType { Decimal, Binary, } + +impl ConstantType { + pub fn get_data_type_from_value(value: &Value) -> Self { + match value { + Value::Bool(_) => ConstantType::Bool, + Value::String(_) => ConstantType::Utf8String, + Value::UInt8(_) => ConstantType::UInt8, + Value::UInt16(_) => ConstantType::UInt16, + Value::UInt32(_) => ConstantType::UInt32, + Value::UInt64(_) => ConstantType::UInt64, + Value::Int8(_) => ConstantType::Int8, + Value::Int16(_) => ConstantType::Int16, + Value::Int32(_) => ConstantType::Int32, + Value::Int64(_) => ConstantType::Int64, + Value::Float(_) => ConstantType::Float64, + Value::Date32(_) => ConstantType::Date, + _ => unimplemented!("get_data_type_from_value() not implemented for value {value}"), + } + } + + // TODO: 
current DataType and ConstantType are not 1 to 1 mapping + // optd schema stores constantType from data type in catalog.get + // for decimal128, the precision is lost + pub fn from_data_type(data_type: DataType) -> Self { + match data_type { + DataType::Binary => ConstantType::Binary, + DataType::Boolean => ConstantType::Bool, + DataType::UInt8 => ConstantType::UInt8, + DataType::UInt16 => ConstantType::UInt16, + DataType::UInt32 => ConstantType::UInt32, + DataType::UInt64 => ConstantType::UInt64, + DataType::Int8 => ConstantType::Int8, + DataType::Int16 => ConstantType::Int16, + DataType::Int32 => ConstantType::Int32, + DataType::Int64 => ConstantType::Int64, + DataType::Float64 => ConstantType::Float64, + DataType::Date32 => ConstantType::Date, + DataType::Interval(IntervalUnit::MonthDayNano) => ConstantType::IntervalMonthDateNano, + DataType::Utf8 => ConstantType::Utf8String, + DataType::Decimal128(_, _) => ConstantType::Decimal, + _ => unimplemented!("no conversion to ConstantType for DataType {data_type}"), + } + } + + pub fn into_data_type(&self) -> DataType { + match self { + ConstantType::Binary => DataType::Binary, + ConstantType::Bool => DataType::Boolean, + ConstantType::UInt8 => DataType::UInt8, + ConstantType::UInt16 => DataType::UInt16, + ConstantType::UInt32 => DataType::UInt32, + ConstantType::UInt64 => DataType::UInt64, + ConstantType::Int8 => DataType::Int8, + ConstantType::Int16 => DataType::Int16, + ConstantType::Int32 => DataType::Int32, + ConstantType::Int64 => DataType::Int64, + ConstantType::Float64 => DataType::Float64, + ConstantType::Date => DataType::Date32, + ConstantType::IntervalMonthDateNano => DataType::Interval(IntervalUnit::MonthDayNano), + ConstantType::Decimal => DataType::Float64, + ConstantType::Utf8String => DataType::Utf8, + } + } +} + +#[derive(Clone, Debug)] +pub struct ConstantPred(pub ArcPredicateNode); + +impl ConstantPred { + pub fn new(value: Value) -> Self { + let typ = ConstantType::get_data_type_from_value(&value); 
+ Self::new_with_type(value, typ) + } + + pub fn new_with_type(value: Value, typ: ConstantType) -> Self { + ConstantPred( + PredicateNode { + typ: PredicateType::Constant(typ), + children: vec![], + data: Some(value), + } + .into(), + ) + } + + pub fn bool(value: bool) -> Self { + Self::new_with_type(Value::Bool(value), ConstantType::Bool) + } + + pub fn string(value: impl AsRef) -> Self { + Self::new_with_type( + Value::String(value.as_ref().into()), + ConstantType::Utf8String, + ) + } + + pub fn uint8(value: u8) -> Self { + Self::new_with_type(Value::UInt8(value), ConstantType::UInt8) + } + + pub fn uint16(value: u16) -> Self { + Self::new_with_type(Value::UInt16(value), ConstantType::UInt16) + } + + pub fn uint32(value: u32) -> Self { + Self::new_with_type(Value::UInt32(value), ConstantType::UInt32) + } + + pub fn uint64(value: u64) -> Self { + Self::new_with_type(Value::UInt64(value), ConstantType::UInt64) + } + + pub fn int8(value: i8) -> Self { + Self::new_with_type(Value::Int8(value), ConstantType::Int8) + } + + pub fn int16(value: i16) -> Self { + Self::new_with_type(Value::Int16(value), ConstantType::Int16) + } + + pub fn int32(value: i32) -> Self { + Self::new_with_type(Value::Int32(value), ConstantType::Int32) + } + + pub fn int64(value: i64) -> Self { + Self::new_with_type(Value::Int64(value), ConstantType::Int64) + } + + pub fn interval_month_day_nano(value: i128) -> Self { + Self::new_with_type(Value::Int128(value), ConstantType::IntervalMonthDateNano) + } + + pub fn float64(value: f64) -> Self { + Self::new_with_type( + Value::Float(SerializableOrderedF64(value.into())), + ConstantType::Float64, + ) + } + + pub fn date(value: i64) -> Self { + Self::new_with_type(Value::Int64(value), ConstantType::Date) + } + + pub fn decimal(value: f64) -> Self { + Self::new_with_type( + Value::Float(SerializableOrderedF64(value.into())), + ConstantType::Decimal, + ) + } + + pub fn serialized(value: Arc<[u8]>) -> Self { + Self::new_with_type(Value::Serialized(value), 
ConstantType::Binary) + } + + /// Gets the constant value. + pub fn value(&self) -> Value { + self.0.data.clone().unwrap() + } + + pub fn constant_type(&self) -> ConstantType { + if let PredicateType::Constant(typ) = self.0.typ { + typ + } else { + panic!("not a constant") + } + } +} + +impl ReprPredicateNode for ConstantPred { + fn into_pred_node(self) -> ArcPredicateNode { + self.0 + } + + fn from_pred_node(rel_node: ArcPredicateNode) -> Option { + if let PredicateType::Constant(_) = rel_node.typ { + Some(Self(rel_node)) + } else { + None + } + } +} diff --git a/optd-cost-model/src/common/predicates/data_type_pred.rs b/optd-cost-model/src/common/predicates/data_type_pred.rs new file mode 100644 index 0000000..fe29336 --- /dev/null +++ b/optd-cost-model/src/common/predicates/data_type_pred.rs @@ -0,0 +1,40 @@ +use arrow_schema::DataType; + +use crate::common::nodes::{ArcPredicateNode, PredicateNode, PredicateType, ReprPredicateNode}; + +#[derive(Clone, Debug)] +pub struct DataTypePred(pub ArcPredicateNode); + +impl DataTypePred { + pub fn new(typ: DataType) -> Self { + DataTypePred( + PredicateNode { + typ: PredicateType::DataType(typ), + children: vec![], + data: None, + } + .into(), + ) + } + + pub fn data_type(&self) -> DataType { + if let PredicateType::DataType(ref data_type) = self.0.typ { + data_type.clone() + } else { + panic!("not a data type") + } + } +} + +impl ReprPredicateNode for DataTypePred { + fn into_pred_node(self) -> ArcPredicateNode { + self.0 + } + + fn from_pred_node(pred_node: ArcPredicateNode) -> Option { + if !matches!(pred_node.typ, PredicateType::DataType(_)) { + return None; + } + Some(Self(pred_node)) + } +} diff --git a/optd-cost-model/src/common/predicates/mod.rs b/optd-cost-model/src/common/predicates/mod.rs index 87e6e94..d733198 100644 --- a/optd-cost-model/src/common/predicates/mod.rs +++ b/optd-cost-model/src/common/predicates/mod.rs @@ -1,5 +1,8 @@ +pub mod attr_ref_pred; pub mod bin_op_pred; +pub mod cast_pred; pub mod 
constant_pred; +pub mod data_type_pred; pub mod func_pred; pub mod log_op_pred; pub mod sort_order_pred; diff --git a/optd-cost-model/src/cost/filter.rs b/optd-cost-model/src/cost/filter.rs index 8b13789..056b0c1 100644 --- a/optd-cost-model/src/cost/filter.rs +++ b/optd-cost-model/src/cost/filter.rs @@ -1 +1,307 @@ +#![allow(unused_variables)] +use optd_persistent::CostModelStorageLayer; +use crate::{ + common::{ + nodes::{ArcPredicateNode, PredicateType, ReprPredicateNode}, + predicates::{ + attr_ref_pred::AttributeRefPred, + bin_op_pred::BinOpType, + cast_pred::CastPred, + constant_pred::{ConstantPred, ConstantType}, + un_op_pred::UnOpType, + }, + values::Value, + }, + cost_model::CostModelImpl, + CostModelResult, EstimatedStatistic, +}; + +// A placeholder for unimplemented!() for codepaths which are accessed by plannertest +const UNIMPLEMENTED_SEL: f64 = 0.01; +// Default statistics. All are from selfuncs.h in Postgres unless specified otherwise +// Default selectivity estimate for equalities such as "A = b" +const DEFAULT_EQ_SEL: f64 = 0.005; +// Default selectivity estimate for inequalities such as "A < b" +const DEFAULT_INEQ_SEL: f64 = 0.3333333333333333; + +impl CostModelImpl { + pub fn get_filter_row_cnt( + &self, + child_row_cnt: EstimatedStatistic, + table_id: i32, + cond: ArcPredicateNode, + ) -> CostModelResult { + let selectivity = { self.get_filter_selectivity(cond, table_id)? 
}; + Ok( + EstimatedStatistic((child_row_cnt.0 as f64 * selectivity) as u64) + .max(EstimatedStatistic(1)), + ) + } + + pub fn get_filter_selectivity( + &self, + expr_tree: ArcPredicateNode, + table_id: i32, + ) -> CostModelResult { + match &expr_tree.typ { + PredicateType::Constant(_) => Ok(Self::get_constant_selectivity(expr_tree)), + PredicateType::AttributeRef => unimplemented!("check bool type or else panic"), + PredicateType::UnOp(un_op_typ) => { + assert!(expr_tree.children.len() == 1); + let child = expr_tree.child(0); + match un_op_typ { + // not doesn't care about nulls so there's no complex logic. it just reverses + // the selectivity for instance, != _will not_ include nulls + // but "NOT ==" _will_ include nulls + UnOpType::Not => Ok(1.0 - self.get_filter_selectivity(child, table_id)?), + UnOpType::Neg => panic!( + "the selectivity of operations that return numerical values is undefined" + ), + } + } + PredicateType::BinOp(bin_op_typ) => { + assert!(expr_tree.children.len() == 2); + let left_child = expr_tree.child(0); + let right_child = expr_tree.child(1); + + if bin_op_typ.is_comparison() { + self.get_comp_op_selectivity(*bin_op_typ, left_child, right_child, table_id) + } else if bin_op_typ.is_numerical() { + panic!( + "the selectivity of operations that return numerical values is undefined" + ) + } else { + unreachable!("all BinOpTypes should be true for at least one is_*() function") + } + } + _ => unimplemented!("check bool type or else panic"), + } + } + + fn get_constant_selectivity(const_node: ArcPredicateNode) -> f64 { + if let PredicateType::Constant(const_typ) = const_node.typ { + if matches!(const_typ, ConstantType::Bool) { + let value = const_node + .as_ref() + .data + .as_ref() + .expect("constants should have data"); + if let Value::Bool(bool_value) = value { + if *bool_value { + 1.0 + } else { + 0.0 + } + } else { + unreachable!( + "if the typ is ConstantType::Bool, the value should be a Value::Bool" + ) + } + } else { + 
panic!("selectivity is not defined on constants which are not bools") + } + } else { + panic!("get_constant_selectivity must be called on a constant") + } + } + + /// Comparison operators are the base case for recursion in get_filter_selectivity() + fn get_comp_op_selectivity( + &self, + comp_bin_op_typ: BinOpType, + left: ArcPredicateNode, + right: ArcPredicateNode, + table_id: i32, + ) -> CostModelResult { + assert!(comp_bin_op_typ.is_comparison()); + + // I intentionally performed moves on left and right. This way, we don't accidentally use + // them after this block + let (col_ref_exprs, values, non_col_ref_exprs, is_left_col_ref) = + self.get_semantic_nodes(left, right, table_id)?; + + // Handle the different cases of semantic nodes. + if col_ref_exprs.is_empty() { + Ok(UNIMPLEMENTED_SEL) + } else if col_ref_exprs.len() == 1 { + let col_ref_expr = col_ref_exprs + .first() + .expect("we just checked that col_ref_exprs.len() == 1"); + let col_ref_idx = col_ref_expr.index(); + + todo!() + } else if col_ref_exprs.len() == 2 { + Ok(Self::get_default_comparison_op_selectivity(comp_bin_op_typ)) + } else { + unreachable!("we could have at most pushed left and right into col_ref_exprs") + } + } + + /// Convert the left and right child nodes of some operation to what they semantically are. + /// This is convenient to avoid repeating the same logic just with "left" and "right" swapped. + /// The last return value is true when the input node (left) is a ColumnRefPred. + #[allow(clippy::type_complexity)] + fn get_semantic_nodes( + &self, + left: ArcPredicateNode, + right: ArcPredicateNode, + table_id: i32, + ) -> CostModelResult<( + Vec, + Vec, + Vec, + bool, + )> { + let mut col_ref_exprs = vec![]; + let mut values = vec![]; + let mut non_col_ref_exprs = vec![]; + let is_left_col_ref; + + // Recursively unwrap casts as much as we can. 
+ let mut uncasted_left = left; + let mut uncasted_right = right; + loop { + // println!("loop {}, uncasted_left={:?}, uncasted_right={:?}", Local::now(), + // uncasted_left, uncasted_right); + if uncasted_left.as_ref().typ == PredicateType::Cast + && uncasted_right.as_ref().typ == PredicateType::Cast + { + let left_cast_expr = CastPred::from_pred_node(uncasted_left) + .expect("we already checked that the type is Cast"); + let right_cast_expr = CastPred::from_pred_node(uncasted_right) + .expect("we already checked that the type is Cast"); + assert!(left_cast_expr.cast_to() == right_cast_expr.cast_to()); + uncasted_left = left_cast_expr.child().into_pred_node(); + uncasted_right = right_cast_expr.child().into_pred_node(); + } else if uncasted_left.as_ref().typ == PredicateType::Cast + || uncasted_right.as_ref().typ == PredicateType::Cast + { + let is_left_cast = uncasted_left.as_ref().typ == PredicateType::Cast; + let (mut cast_node, mut non_cast_node) = if is_left_cast { + (uncasted_left, uncasted_right) + } else { + (uncasted_right, uncasted_left) + }; + + let cast_expr = CastPred::from_pred_node(cast_node) + .expect("we already checked that the type is Cast"); + let cast_expr_child = cast_expr.child().into_pred_node(); + let cast_expr_cast_to = cast_expr.cast_to(); + + let should_break = match cast_expr_child.typ { + PredicateType::Constant(_) => { + cast_node = ConstantPred::new( + ConstantPred::from_pred_node(cast_expr_child) + .expect("we already checked that the type is Constant") + .value() + .convert_to_type(cast_expr_cast_to), + ) + .into_pred_node(); + false + } + PredicateType::AttributeRef => { + let col_ref_expr = AttributeRefPred::from_pred_node(cast_expr_child) + .expect("we already checked that the type is ColumnRef"); + let col_ref_idx = col_ref_expr.index(); + cast_node = col_ref_expr.into_pred_node(); + // The "invert" cast is to invert the cast so that we're casting the + // non_cast_node to the column's original type. 
+ // TODO(migration): double check + let invert_cast_data_type = &(self + .storage_manager + .get_attribute_info(table_id, col_ref_idx as i32)? + .typ + .into_data_type()); + + match non_cast_node.typ { + PredicateType::AttributeRef => { + // In general, there's no way to remove the Cast here. We can't move + // the Cast to the other ColumnRef + // because that would lead to an infinite loop. Thus, we just leave + // the cast where it is and break. + true + } + _ => { + non_cast_node = + CastPred::new(non_cast_node, invert_cast_data_type.clone()) + .into_pred_node(); + false + } + } + } + _ => todo!(), + }; + + (uncasted_left, uncasted_right) = if is_left_cast { + (cast_node, non_cast_node) + } else { + (non_cast_node, cast_node) + }; + + if should_break { + break; + } + } else { + break; + } + } + + // Sort nodes into col_ref_exprs, values, and non_col_ref_exprs + match uncasted_left.as_ref().typ { + PredicateType::AttributeRef => { + is_left_col_ref = true; + col_ref_exprs.push( + AttributeRefPred::from_pred_node(uncasted_left) + .expect("we already checked that the type is ColumnRef"), + ); + } + PredicateType::Constant(_) => { + is_left_col_ref = false; + values.push( + ConstantPred::from_pred_node(uncasted_left) + .expect("we already checked that the type is Constant") + .value(), + ) + } + _ => { + is_left_col_ref = false; + non_col_ref_exprs.push(uncasted_left); + } + } + match uncasted_right.as_ref().typ { + PredicateType::AttributeRef => { + col_ref_exprs.push( + AttributeRefPred::from_pred_node(uncasted_right) + .expect("we already checked that the type is ColumnRef"), + ); + } + PredicateType::Constant(_) => values.push( + ConstantPred::from_pred_node(uncasted_right) + .expect("we already checked that the type is Constant") + .value(), + ), + _ => { + non_col_ref_exprs.push(uncasted_right); + } + } + + assert!(col_ref_exprs.len() + values.len() + non_col_ref_exprs.len() == 2); + Ok((col_ref_exprs, values, non_col_ref_exprs, is_left_col_ref)) + } + + /// 
The default selectivity of a comparison expression + /// Used when one side of the comparison is a column while the other side is something too + /// complex/impossible to evaluate (subquery, UDF, another column, we have no stats, etc.) + fn get_default_comparison_op_selectivity(comp_bin_op_typ: BinOpType) -> f64 { + assert!(comp_bin_op_typ.is_comparison()); + match comp_bin_op_typ { + BinOpType::Eq => DEFAULT_EQ_SEL, + BinOpType::Neq => 1.0 - DEFAULT_EQ_SEL, + BinOpType::Lt | BinOpType::Leq | BinOpType::Gt | BinOpType::Geq => DEFAULT_INEQ_SEL, + _ => unreachable!( + "all comparison BinOpTypes were enumerated. this should be unreachable" + ), + } + } +} diff --git a/optd-cost-model/src/cost_model.rs b/optd-cost-model/src/cost_model.rs index c0b0677..0b1760e 100644 --- a/optd-cost-model/src/cost_model.rs +++ b/optd-cost-model/src/cost_model.rs @@ -18,8 +18,8 @@ use crate::{ /// TODO: documentation pub struct CostModelImpl { - storage_manager: CostModelStorageManager, - default_catalog_source: CatalogSource, + pub storage_manager: CostModelStorageManager, + pub default_catalog_source: CatalogSource, } impl CostModelImpl { diff --git a/optd-cost-model/src/lib.rs b/optd-cost-model/src/lib.rs index a635b66..e18098f 100644 --- a/optd-cost-model/src/lib.rs +++ b/optd-cost-model/src/lib.rs @@ -31,6 +31,7 @@ pub struct Cost(pub Vec); /// Estimated statistic calculated by the cost model. /// It is the estimated output row count of the targeted expression. 
+#[derive(Eq, Ord, PartialEq, PartialOrd)] pub struct EstimatedStatistic(pub u64); pub type CostModelResult = Result; From 7b0158c6c53f76e954b334dcf6580aa42bb66a28 Mon Sep 17 00:00:00 2001 From: Lan Lou Date: Thu, 14 Nov 2024 10:28:22 -0500 Subject: [PATCH 02/51] Change col to attr --- .../src/common/predicates/attr_ref_pred.rs | 6 +- optd-cost-model/src/cost/filter.rs | 59 ++++++++++--------- 2 files changed, 33 insertions(+), 32 deletions(-) diff --git a/optd-cost-model/src/common/predicates/attr_ref_pred.rs b/optd-cost-model/src/common/predicates/attr_ref_pred.rs index 34e1901..b3b7814 100644 --- a/optd-cost-model/src/common/predicates/attr_ref_pred.rs +++ b/optd-cost-model/src/common/predicates/attr_ref_pred.rs @@ -8,14 +8,14 @@ pub struct AttributeRefPred(pub ArcPredicateNode); impl AttributeRefPred { /// Creates a new `ColumnRef` expression. - pub fn new(column_idx: usize) -> AttributeRefPred { + pub fn new(attribute_idx: usize) -> AttributeRefPred { // this conversion is always safe since usize is at most u64 - let u64_column_idx = column_idx as u64; + let u64_attribute_idx = attribute_idx as u64; AttributeRefPred( PredicateNode { typ: PredicateType::AttributeRef, children: vec![], - data: Some(Value::UInt64(u64_column_idx)), + data: Some(Value::UInt64(u64_attribute_idx)), } .into(), ) diff --git a/optd-cost-model/src/cost/filter.rs b/optd-cost-model/src/cost/filter.rs index 056b0c1..619b1a9 100644 --- a/optd-cost-model/src/cost/filter.rs +++ b/optd-cost-model/src/cost/filter.rs @@ -11,6 +11,7 @@ use crate::{ constant_pred::{ConstantPred, ConstantType}, un_op_pred::UnOpType, }, + types::TableId, values::Value, }, cost_model::CostModelImpl, @@ -29,7 +30,7 @@ impl CostModelImpl { pub fn get_filter_row_cnt( &self, child_row_cnt: EstimatedStatistic, - table_id: i32, + table_id: TableId, cond: ArcPredicateNode, ) -> CostModelResult { let selectivity = { self.get_filter_selectivity(cond, table_id)? 
}; @@ -42,7 +43,7 @@ impl CostModelImpl { pub fn get_filter_selectivity( &self, expr_tree: ArcPredicateNode, - table_id: i32, + table_id: TableId, ) -> CostModelResult { match &expr_tree.typ { PredicateType::Constant(_) => Ok(Self::get_constant_selectivity(expr_tree)), @@ -112,29 +113,29 @@ impl CostModelImpl { comp_bin_op_typ: BinOpType, left: ArcPredicateNode, right: ArcPredicateNode, - table_id: i32, + table_id: TableId, ) -> CostModelResult { assert!(comp_bin_op_typ.is_comparison()); // I intentionally performed moves on left and right. This way, we don't accidentally use // them after this block - let (col_ref_exprs, values, non_col_ref_exprs, is_left_col_ref) = + let (attr_ref_exprs, values, non_attr_ref_exprs, is_left_attr_ref) = self.get_semantic_nodes(left, right, table_id)?; // Handle the different cases of semantic nodes. - if col_ref_exprs.is_empty() { + if attr_ref_exprs.is_empty() { Ok(UNIMPLEMENTED_SEL) - } else if col_ref_exprs.len() == 1 { - let col_ref_expr = col_ref_exprs + } else if attr_ref_exprs.len() == 1 { + let attr_ref_expr = attr_ref_exprs .first() - .expect("we just checked that col_ref_exprs.len() == 1"); - let col_ref_idx = col_ref_expr.index(); + .expect("we just checked that attr_ref_exprs.len() == 1"); + let attr_ref_idx = attr_ref_expr.index(); todo!() - } else if col_ref_exprs.len() == 2 { + } else if attr_ref_exprs.len() == 2 { Ok(Self::get_default_comparison_op_selectivity(comp_bin_op_typ)) } else { - unreachable!("we could have at most pushed left and right into col_ref_exprs") + unreachable!("we could have at most pushed left and right into attr_ref_exprs") } } @@ -146,17 +147,17 @@ impl CostModelImpl { &self, left: ArcPredicateNode, right: ArcPredicateNode, - table_id: i32, + table_id: TableId, ) -> CostModelResult<( Vec, Vec, Vec, bool, )> { - let mut col_ref_exprs = vec![]; + let mut attr_ref_exprs = vec![]; let mut values = vec![]; - let mut non_col_ref_exprs = vec![]; - let is_left_col_ref; + let mut non_attr_ref_exprs = 
vec![]; + let is_left_attr_ref; // Recursively unwrap casts as much as we can. let mut uncasted_left = left; @@ -201,16 +202,16 @@ impl CostModelImpl { false } PredicateType::AttributeRef => { - let col_ref_expr = AttributeRefPred::from_pred_node(cast_expr_child) + let attr_ref_expr = AttributeRefPred::from_pred_node(cast_expr_child) .expect("we already checked that the type is ColumnRef"); - let col_ref_idx = col_ref_expr.index(); - cast_node = col_ref_expr.into_pred_node(); + let attr_ref_idx = attr_ref_expr.index(); + cast_node = attr_ref_expr.into_pred_node(); // The "invert" cast is to invert the cast so that we're casting the // non_cast_node to the column's original type. // TODO(migration): double check let invert_cast_data_type = &(self .storage_manager - .get_attribute_info(table_id, col_ref_idx as i32)? + .get_attribute_info(table_id, attr_ref_idx as i32)? .typ .into_data_type()); @@ -247,17 +248,17 @@ impl CostModelImpl { } } - // Sort nodes into col_ref_exprs, values, and non_col_ref_exprs + // Sort nodes into attr_ref_exprs, values, and non_attr_ref_exprs match uncasted_left.as_ref().typ { PredicateType::AttributeRef => { - is_left_col_ref = true; - col_ref_exprs.push( + is_left_attr_ref = true; + attr_ref_exprs.push( AttributeRefPred::from_pred_node(uncasted_left) .expect("we already checked that the type is ColumnRef"), ); } PredicateType::Constant(_) => { - is_left_col_ref = false; + is_left_attr_ref = false; values.push( ConstantPred::from_pred_node(uncasted_left) .expect("we already checked that the type is Constant") @@ -265,13 +266,13 @@ impl CostModelImpl { ) } _ => { - is_left_col_ref = false; - non_col_ref_exprs.push(uncasted_left); + is_left_attr_ref = false; + non_attr_ref_exprs.push(uncasted_left); } } match uncasted_right.as_ref().typ { PredicateType::AttributeRef => { - col_ref_exprs.push( + attr_ref_exprs.push( AttributeRefPred::from_pred_node(uncasted_right) .expect("we already checked that the type is ColumnRef"), ); @@ -282,12 
+283,12 @@ impl CostModelImpl { .value(), ), _ => { - non_col_ref_exprs.push(uncasted_right); + non_attr_ref_exprs.push(uncasted_right); } } - assert!(col_ref_exprs.len() + values.len() + non_col_ref_exprs.len() == 2); - Ok((col_ref_exprs, values, non_col_ref_exprs, is_left_col_ref)) + assert!(attr_ref_exprs.len() + values.len() + non_attr_ref_exprs.len() == 2); + Ok((attr_ref_exprs, values, non_attr_ref_exprs, is_left_attr_ref)) } /// The default selectivity of a comparison expression From 81f8d50c57beaff3c57e490c06b04d986eb3e718 Mon Sep 17 00:00:00 2001 From: Yuanxin Cao Date: Thu, 14 Nov 2024 10:33:09 -0500 Subject: [PATCH 03/51] implement cost computation for limit --- optd-cost-model/src/cost/limit.rs | 29 +++++++++++++++++++++++++++++ optd-cost-model/src/cost/mod.rs | 3 +++ 2 files changed, 32 insertions(+) create mode 100644 optd-cost-model/src/cost/limit.rs diff --git a/optd-cost-model/src/cost/limit.rs b/optd-cost-model/src/cost/limit.rs new file mode 100644 index 0000000..ce3e08e --- /dev/null +++ b/optd-cost-model/src/cost/limit.rs @@ -0,0 +1,29 @@ +use optd_persistent::CostModelStorageLayer; + +use crate::{ + common::{ + nodes::{ArcPredicateNode, ReprPredicateNode}, + predicates::constant_pred::ConstantPred, + }, + cost_model::CostModelImpl, + CostModelResult, EstimatedStatistic, +}; + +impl CostModelImpl { + pub(crate) fn get_limit_row_cnt( + &self, + child_row_cnt: EstimatedStatistic, + fetch_expr: ArcPredicateNode, + ) -> CostModelResult { + let fetch = ConstantPred::from_pred_node(fetch_expr) + .unwrap() + .value() + .as_u64(); + // u64::MAX represents None + if fetch == u64::MAX { + Ok(child_row_cnt) + } else { + Ok(EstimatedStatistic(child_row_cnt.0.min(fetch))) + } + } +} diff --git a/optd-cost-model/src/cost/mod.rs b/optd-cost-model/src/cost/mod.rs index 795ed3e..c98d7d7 100644 --- a/optd-cost-model/src/cost/mod.rs +++ b/optd-cost-model/src/cost/mod.rs @@ -1,3 +1,6 @@ +#![allow(unused)] + pub mod agg; pub mod filter; pub mod join; +pub mod 
limit; From 2a5740e9236a6ccc11c8aedea04a34a92f3915ce Mon Sep 17 00:00:00 2001 From: Yuanxin Cao Date: Thu, 14 Nov 2024 11:35:34 -0500 Subject: [PATCH 04/51] add author --- optd-cost-model/Cargo.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/optd-cost-model/Cargo.toml b/optd-cost-model/Cargo.toml index 1d41af7..4161749 100644 --- a/optd-cost-model/Cargo.toml +++ b/optd-cost-model/Cargo.toml @@ -2,6 +2,7 @@ name = "optd-cost-model" version = "0.1.0" edition = "2021" +authors = ["Yuanxin Cao", "Lan Lou", "Kunle Li"] [dependencies] optd-persistent = { path = "../optd-persistent", version = "0.1" } From f7f6857815a531d29d812c947ab91df8f81c4f2e Mon Sep 17 00:00:00 2001 From: Yuanxin Cao Date: Thu, 14 Nov 2024 12:12:47 -0500 Subject: [PATCH 05/51] introduce ColumnCombValueStats --- optd-cost-model/Cargo.lock | 32 +++++++++++------------ optd-cost-model/src/cost/mod.rs | 1 + optd-cost-model/src/cost/stats.rs | 43 +++++++++++++++++++++++++++++++ 3 files changed, 60 insertions(+), 16 deletions(-) create mode 100644 optd-cost-model/src/cost/stats.rs diff --git a/optd-cost-model/Cargo.lock b/optd-cost-model/Cargo.lock index a38097d..9464489 100644 --- a/optd-cost-model/Cargo.lock +++ b/optd-cost-model/Cargo.lock @@ -480,9 +480,9 @@ dependencies = [ [[package]] name = "borsh" -version = "1.5.2" +version = "1.5.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f5327f6c99920069d1fe374aa743be1af0031dea9f250852cdf1ae6a0861ee24" +checksum = "2506947f73ad44e344215ccd6403ac2ae18cd8e046e581a441bf8d199f257f03" dependencies = [ "borsh-derive", "cfg_aliases", @@ -490,9 +490,9 @@ dependencies = [ [[package]] name = "borsh-derive" -version = "1.5.2" +version = "1.5.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "10aedd8f1a81a8aafbfde924b0e3061cd6fedd6f6bbcfc6a76e6fd426d7bfe26" +checksum = "c2593a3b8b938bd68373196c9832f516be11fa487ef4ae745eb282e6a56a7244" dependencies = [ "once_cell", "proc-macro-crate", @@ -543,9 +543,9 
@@ checksum = "9ac0150caa2ae65ca5bd83f25c7de183dea78d4d366469f148435e2acfbad0da" [[package]] name = "cc" -version = "1.2.0" +version = "1.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1aeb932158bd710538c73702db6945cb68a8fb08c519e6e12706b94263b36db8" +checksum = "fd9de9f2205d5ef3fd67e685b0df337994ddd4495e2a28d185500d0e1edfea47" dependencies = [ "shlex", ] @@ -601,9 +601,9 @@ dependencies = [ [[package]] name = "clap" -version = "4.5.20" +version = "4.5.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b97f376d85a664d5837dbae44bf546e6477a679ff6610010f17276f686d867e8" +checksum = "fb3b4b9e5a7c7514dfa52869339ee98b3156b0bfb4e8a77c4ff4babb64b1604f" dependencies = [ "clap_builder", "clap_derive", @@ -611,9 +611,9 @@ dependencies = [ [[package]] name = "clap_builder" -version = "4.5.20" +version = "4.5.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "19bc80abd44e4bed93ca373a0704ccbd1b710dc5749406201bb018272808dc54" +checksum = "b17a95aa67cc7b5ebd32aa5370189aa0d79069ef1c64ce893bd30fb24bff20ec" dependencies = [ "anstream", "anstyle", @@ -635,9 +635,9 @@ dependencies = [ [[package]] name = "clap_lex" -version = "0.7.2" +version = "0.7.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1462739cb27611015575c0c11df5df7601141071f07518d56fcc1be504cbec97" +checksum = "afb84c814227b90d6895e01398aee0d8033c00e7466aca416fb6a8e0eb19d8a7" [[package]] name = "colorchoice" @@ -647,9 +647,9 @@ checksum = "5b63caa9aa9397e2d9480a9b13673856c78d8ac123288526c37d7839f2a86990" [[package]] name = "comfy-table" -version = "7.1.1" +version = "7.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b34115915337defe99b2aff5c2ce6771e5fbc4079f4b506301f5cf394c8452f7" +checksum = "24f165e7b643266ea80cb858aed492ad9280e3e05ce24d4a99d7d7b889b6a4d9" dependencies = [ "strum 0.26.3", "strum_macros 0.26.4", @@ -3336,9 +3336,9 @@ checksum = 
"e70f2a8b45122e719eb623c01822704c4e0907e7e426a05927e1a1cfff5b75d0" [[package]] name = "unicode-width" -version = "0.1.14" +version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7dd6e30e90baa6f72411720665d41d89b9a3d039dc45b8faea1ddd07f617f6af" +checksum = "1fc81956842c57dac11422a97c3b8195a1ff727f06e85c84ed2e8aa277c9a0fd" [[package]] name = "unicode_categories" diff --git a/optd-cost-model/src/cost/mod.rs b/optd-cost-model/src/cost/mod.rs index c98d7d7..8fa1bb8 100644 --- a/optd-cost-model/src/cost/mod.rs +++ b/optd-cost-model/src/cost/mod.rs @@ -4,3 +4,4 @@ pub mod agg; pub mod filter; pub mod join; pub mod limit; +pub mod stats; diff --git a/optd-cost-model/src/cost/stats.rs b/optd-cost-model/src/cost/stats.rs new file mode 100644 index 0000000..ee8e322 --- /dev/null +++ b/optd-cost-model/src/cost/stats.rs @@ -0,0 +1,43 @@ +use serde::{Deserialize, Serialize}; + +use crate::common::values::Value; + +pub type ColumnCombValue = Vec>; + +/// Ideally, MostCommonValues would have trait bounds for Serialize and Deserialize. However, I have +/// not figured out how to both have Deserialize as a trait bound and utilize the Deserialize +/// macro, because the Deserialize trait involves lifetimes. +pub trait MostCommonValues: 'static + Send + Sync { + // it is true that we could just expose freq_over_pred() and use that for freq() and + // total_freq() however, freq() and total_freq() each have potential optimizations (freq() + // is O(1) instead of O(n) and total_freq() can be cached) + // additionally, it makes sense to return an Option for freq() instead of just 0 if value + // doesn't exist thus, I expose three different functions + fn freq(&self, value: &ColumnCombValue) -> Option; + fn total_freq(&self) -> f64; + fn freq_over_pred(&self, pred: Box bool>) -> f64; + + // returns the # of entries (i.e. 
value + freq) in the most common values structure + fn cnt(&self) -> usize; +} + +/// A more general interface meant to perform the task of a histogram. +/// +/// This more general interface is still compatible with histograms but allows +/// more powerful statistics like TDigest. +/// +/// Ideally, Distribution would have trait bounds for Serialize and Deserialize. +/// However, I have not figured out how to both have Deserialize as a trait bound +/// and utilize the Deserialize macro, because the Deserialize trait involves lifetimes. +pub trait Distribution: 'static + Send + Sync { + // Give the probability of a random value sampled from the distribution being <= `value` + fn cdf(&self, value: &Value) -> f64; +} + +#[derive(Serialize, Deserialize, Debug)] +pub struct ColumnCombValueStats { + pub mcvs: M, // Does NOT contain full nulls. + pub distr: Option, // Does NOT contain mcvs; optional. + pub ndistinct: u64, // Does NOT contain full nulls. + pub null_frac: f64, // % of full nulls. 
+} From be430ac72079b10ce70f81cc191757f61948fb91 Mon Sep 17 00:00:00 2001 From: Yuanxin Cao Date: Thu, 14 Nov 2024 13:35:33 -0500 Subject: [PATCH 06/51] refactor AttributeCombValueStats and introduce statistic-related data structures --- optd-cost-model/src/cost/mod.rs | 1 - optd-cost-model/src/cost/stats.rs | 43 ------------------------------- 2 files changed, 44 deletions(-) delete mode 100644 optd-cost-model/src/cost/stats.rs diff --git a/optd-cost-model/src/cost/mod.rs b/optd-cost-model/src/cost/mod.rs index 8fa1bb8..c98d7d7 100644 --- a/optd-cost-model/src/cost/mod.rs +++ b/optd-cost-model/src/cost/mod.rs @@ -4,4 +4,3 @@ pub mod agg; pub mod filter; pub mod join; pub mod limit; -pub mod stats; diff --git a/optd-cost-model/src/cost/stats.rs b/optd-cost-model/src/cost/stats.rs deleted file mode 100644 index ee8e322..0000000 --- a/optd-cost-model/src/cost/stats.rs +++ /dev/null @@ -1,43 +0,0 @@ -use serde::{Deserialize, Serialize}; - -use crate::common::values::Value; - -pub type ColumnCombValue = Vec>; - -/// Ideally, MostCommonValues would have trait bounds for Serialize and Deserialize. However, I have -/// not figured out how to both have Deserialize as a trait bound and utilize the Deserialize -/// macro, because the Deserialize trait involves lifetimes. -pub trait MostCommonValues: 'static + Send + Sync { - // it is true that we could just expose freq_over_pred() and use that for freq() and - // total_freq() however, freq() and total_freq() each have potential optimizations (freq() - // is O(1) instead of O(n) and total_freq() can be cached) - // additionally, it makes sense to return an Option for freq() instead of just 0 if value - // doesn't exist thus, I expose three different functions - fn freq(&self, value: &ColumnCombValue) -> Option; - fn total_freq(&self) -> f64; - fn freq_over_pred(&self, pred: Box bool>) -> f64; - - // returns the # of entries (i.e. 
value + freq) in the most common values structure - fn cnt(&self) -> usize; -} - -/// A more general interface meant to perform the task of a histogram. -/// -/// This more general interface is still compatible with histograms but allows -/// more powerful statistics like TDigest. -/// -/// Ideally, Distribution would have trait bounds for Serialize and Deserialize. -/// However, I have not figured out how to both have Deserialize as a trait bound -/// and utilize the Deserialize macro, because the Deserialize trait involves lifetimes. -pub trait Distribution: 'static + Send + Sync { - // Give the probability of a random value sampled from the distribution being <= `value` - fn cdf(&self, value: &Value) -> f64; -} - -#[derive(Serialize, Deserialize, Debug)] -pub struct ColumnCombValueStats { - pub mcvs: M, // Does NOT contain full nulls. - pub distr: Option, // Does NOT contain mcvs; optional. - pub ndistinct: u64, // Does NOT contain full nulls. - pub null_frac: f64, // % of full nulls. -} From 089cfef78fb9d0cdc09f7c775879c774d7c2645f Mon Sep 17 00:00:00 2001 From: Lan Lou Date: Thu, 14 Nov 2024 12:25:33 -0500 Subject: [PATCH 07/51] Change col to attr in filter --- optd-cost-model/src/cost/filter.rs | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/optd-cost-model/src/cost/filter.rs b/optd-cost-model/src/cost/filter.rs index 619b1a9..3fefa18 100644 --- a/optd-cost-model/src/cost/filter.rs +++ b/optd-cost-model/src/cost/filter.rs @@ -141,7 +141,7 @@ impl CostModelImpl { /// Convert the left and right child nodes of some operation to what they semantically are. /// This is convenient to avoid repeating the same logic just with "left" and "right" swapped. - /// The last return value is true when the input node (left) is a ColumnRefPred. + /// The last return value is true when the input node (left) is a AttributeRefPred. 
#[allow(clippy::type_complexity)] fn get_semantic_nodes( &self, @@ -203,11 +203,11 @@ impl CostModelImpl { } PredicateType::AttributeRef => { let attr_ref_expr = AttributeRefPred::from_pred_node(cast_expr_child) - .expect("we already checked that the type is ColumnRef"); + .expect("we already checked that the type is AttributeRef"); let attr_ref_idx = attr_ref_expr.index(); cast_node = attr_ref_expr.into_pred_node(); // The "invert" cast is to invert the cast so that we're casting the - // non_cast_node to the column's original type. + // non_cast_node to the attribute's original type. // TODO(migration): double check let invert_cast_data_type = &(self .storage_manager @@ -218,7 +218,7 @@ impl CostModelImpl { match non_cast_node.typ { PredicateType::AttributeRef => { // In general, there's no way to remove the Cast here. We can't move - // the Cast to the other ColumnRef + // the Cast to the other AttributeRef // because that would lead to an infinite loop. Thus, we just leave // the cast where it is and break. true @@ -254,7 +254,7 @@ impl CostModelImpl { is_left_attr_ref = true; attr_ref_exprs.push( AttributeRefPred::from_pred_node(uncasted_left) - .expect("we already checked that the type is ColumnRef"), + .expect("we already checked that the type is AttributeRef"), ); } PredicateType::Constant(_) => { @@ -274,7 +274,7 @@ impl CostModelImpl { PredicateType::AttributeRef => { attr_ref_exprs.push( AttributeRefPred::from_pred_node(uncasted_right) - .expect("we already checked that the type is ColumnRef"), + .expect("we already checked that the type is AttributeRef"), ); } PredicateType::Constant(_) => values.push( @@ -292,8 +292,8 @@ impl CostModelImpl { } /// The default selectivity of a comparison expression - /// Used when one side of the comparison is a column while the other side is something too - /// complex/impossible to evaluate (subquery, UDF, another column, we have no stats, etc.) 
+ /// Used when one side of the comparison is a attribute while the other side is something too + /// complex/impossible to evaluate (subquery, UDF, another attribute, we have no stats, etc.) fn get_default_comparison_op_selectivity(comp_bin_op_typ: BinOpType) -> f64 { assert!(comp_bin_op_typ.is_comparison()); match comp_bin_op_typ { From 69607f1434cfb551f5861c0535bf058b06d8b510 Mon Sep 17 00:00:00 2001 From: Lan Lou Date: Thu, 14 Nov 2024 13:42:04 -0500 Subject: [PATCH 08/51] Complete partial implementation of filter --- .../src/common/predicates/in_list_pred.rs | 48 +++ .../src/common/predicates/like_pred.rs | 66 ++++ .../src/common/predicates/list_pred.rs | 47 +++ optd-cost-model/src/common/predicates/mod.rs | 3 + optd-cost-model/src/cost/filter.rs | 341 +++++++++++++++++- 5 files changed, 502 insertions(+), 3 deletions(-) create mode 100644 optd-cost-model/src/common/predicates/in_list_pred.rs create mode 100644 optd-cost-model/src/common/predicates/like_pred.rs create mode 100644 optd-cost-model/src/common/predicates/list_pred.rs diff --git a/optd-cost-model/src/common/predicates/in_list_pred.rs b/optd-cost-model/src/common/predicates/in_list_pred.rs new file mode 100644 index 0000000..8d3b511 --- /dev/null +++ b/optd-cost-model/src/common/predicates/in_list_pred.rs @@ -0,0 +1,48 @@ +use crate::common::{ + nodes::{ArcPredicateNode, PredicateNode, PredicateType, ReprPredicateNode}, + values::Value, +}; + +use super::list_pred::ListPred; + +#[derive(Clone, Debug)] +pub struct InListPred(pub ArcPredicateNode); + +impl InListPred { + pub fn new(child: ArcPredicateNode, list: ListPred, negated: bool) -> Self { + InListPred( + PredicateNode { + typ: PredicateType::InList, + children: vec![child, list.into_pred_node()], + data: Some(Value::Bool(negated)), + } + .into(), + ) + } + + pub fn child(&self) -> ArcPredicateNode { + self.0.child(0) + } + + pub fn list(&self) -> ListPred { + ListPred::from_pred_node(self.0.child(1)).unwrap() + } + + /// `true` for `NOT IN`. 
+ pub fn negated(&self) -> bool { + self.0.data.as_ref().unwrap().as_bool() + } +} + +impl ReprPredicateNode for InListPred { + fn into_pred_node(self) -> ArcPredicateNode { + self.0 + } + + fn from_pred_node(pred_node: ArcPredicateNode) -> Option { + if !matches!(pred_node.typ, PredicateType::InList) { + return None; + } + Some(Self(pred_node)) + } +} diff --git a/optd-cost-model/src/common/predicates/like_pred.rs b/optd-cost-model/src/common/predicates/like_pred.rs new file mode 100644 index 0000000..bf9fe31 --- /dev/null +++ b/optd-cost-model/src/common/predicates/like_pred.rs @@ -0,0 +1,66 @@ +use std::sync::Arc; + +use crate::common::{ + nodes::{ArcPredicateNode, PredicateNode, PredicateType, ReprPredicateNode}, + values::Value, +}; + +#[derive(Clone, Debug)] +pub struct LikePred(pub ArcPredicateNode); + +impl LikePred { + pub fn new( + negated: bool, + case_insensitive: bool, + child: ArcPredicateNode, + pattern: ArcPredicateNode, + ) -> Self { + // TODO: support multiple values in data. + let negated = if negated { 1 } else { 0 }; + let case_insensitive = if case_insensitive { 1 } else { 0 }; + LikePred( + PredicateNode { + typ: PredicateType::Like, + children: vec![child.into_pred_node(), pattern.into_pred_node()], + data: Some(Value::Serialized(Arc::new([negated, case_insensitive]))), + } + .into(), + ) + } + + pub fn child(&self) -> ArcPredicateNode { + self.0.child(0) + } + + pub fn pattern(&self) -> ArcPredicateNode { + self.0.child(1) + } + + /// `true` for `NOT LIKE`. 
+ pub fn negated(&self) -> bool { + match self.0.data.as_ref().unwrap() { + Value::Serialized(data) => data[0] != 0, + _ => panic!("not a serialized value"), + } + } + + pub fn case_insensitive(&self) -> bool { + match self.0.data.as_ref().unwrap() { + Value::Serialized(data) => data[1] != 0, + _ => panic!("not a serialized value"), + } + } +} + +impl ReprPredicateNode for LikePred { + fn into_pred_node(self) -> ArcPredicateNode { + self.0 + } + + fn from_pred_node(pred_node: ArcPredicateNode) -> Option { + if !matches!(pred_node.typ, PredicateType::Like) { + return None; + } + Some(Self(pred_node)) + } +} diff --git a/optd-cost-model/src/common/predicates/list_pred.rs b/optd-cost-model/src/common/predicates/list_pred.rs new file mode 100644 index 0000000..972598d --- /dev/null +++ b/optd-cost-model/src/common/predicates/list_pred.rs @@ -0,0 +1,47 @@ +use crate::common::nodes::{ArcPredicateNode, PredicateNode, PredicateType, ReprPredicateNode}; + +#[derive(Clone, Debug)] +pub struct ListPred(pub ArcPredicateNode); + +impl ListPred { + pub fn new(preds: Vec) -> Self { + ListPred( + PredicateNode { + typ: PredicateType::List, + children: preds, + data: None, + } + .into(), + ) + } + + /// Gets number of expressions in the list + pub fn len(&self) -> usize { + self.0.children.len() + } + + pub fn is_empty(&self) -> bool { + self.0.children.is_empty() + } + + pub fn child(&self, idx: usize) -> ArcPredicateNode { + self.0.child(idx) + } + + pub fn to_vec(&self) -> Vec { + self.0.children.clone() + } +} + +impl ReprPredicateNode for ListPred { + fn into_pred_node(self) -> ArcPredicateNode { + self.0 + } + + fn from_pred_node(pred_node: ArcPredicateNode) -> Option { + if pred_node.typ != PredicateType::List { + return None; + } + Some(Self(pred_node)) + } +} diff --git a/optd-cost-model/src/common/predicates/mod.rs b/optd-cost-model/src/common/predicates/mod.rs index d733198..877b78a 100644 --- a/optd-cost-model/src/common/predicates/mod.rs +++ 
b/optd-cost-model/src/common/predicates/mod.rs @@ -4,6 +4,9 @@ pub mod cast_pred; pub mod constant_pred; pub mod data_type_pred; pub mod func_pred; +pub mod in_list_pred; +pub mod like_pred; +pub mod list_pred; pub mod log_op_pred; pub mod sort_order_pred; pub mod un_op_pred; diff --git a/optd-cost-model/src/cost/filter.rs b/optd-cost-model/src/cost/filter.rs index 3fefa18..85724d4 100644 --- a/optd-cost-model/src/cost/filter.rs +++ b/optd-cost-model/src/cost/filter.rs @@ -1,5 +1,7 @@ #![allow(unused_variables)] -use optd_persistent::CostModelStorageLayer; +use std::ops::Bound; + +use optd_persistent::{cost_model::interface::Cost, CostModelStorageLayer}; use crate::{ common::{ @@ -9,6 +11,9 @@ use crate::{ bin_op_pred::BinOpType, cast_pred::CastPred, constant_pred::{ConstantPred, ConstantType}, + in_list_pred::InListPred, + like_pred::LikePred, + log_op_pred::LogOpType, un_op_pred::UnOpType, }, types::TableId, @@ -25,6 +30,13 @@ const UNIMPLEMENTED_SEL: f64 = 0.01; const DEFAULT_EQ_SEL: f64 = 0.005; // Default selectivity estimate for inequalities such as "A < b" const DEFAULT_INEQ_SEL: f64 = 0.3333333333333333; +// Used for estimating pattern selectivity character-by-character. These numbers +// are not used on their own. Depending on the characters in the pattern, the +// selectivity is multiplied by these factors. +// +// See `FULL_WILDCARD_SEL` and `FIXED_CHAR_SEL` in Postgres. 
+const FULL_WILDCARD_SEL_FACTOR: f64 = 5.0; +const FIXED_CHAR_SEL_FACTOR: f64 = 0.2; impl CostModelImpl { pub fn get_filter_row_cnt( @@ -76,7 +88,29 @@ impl CostModelImpl { unreachable!("all BinOpTypes should be true for at least one is_*() function") } } - _ => unimplemented!("check bool type or else panic"), + PredicateType::LogOp(log_op_typ) => { + self.get_log_op_selectivity(*log_op_typ, &expr_tree.children, table_id) + } + PredicateType::Func(_) => unimplemented!("check bool type or else panic"), + PredicateType::SortOrder(_) => { + panic!("the selectivity of sort order expressions is undefined") + } + PredicateType::Between => Ok(UNIMPLEMENTED_SEL), + PredicateType::Cast => unimplemented!("check bool type or else panic"), + PredicateType::Like => { + let like_expr = LikePred::from_pred_node(expr_tree).unwrap(); + self.get_like_selectivity(&like_expr) + } + PredicateType::DataType(_) => { + panic!("the selectivity of a data type is not defined") + } + PredicateType::InList => { + let in_list_expr = InListPred::from_pred_node(expr_tree).unwrap(); + self.get_in_list_selectivity(&in_list_expr, table_id) + } + _ => unreachable!( + "all expression DfPredType were enumerated. 
this should be unreachable" + ), } } @@ -107,6 +141,27 @@ impl CostModelImpl { } } + fn get_log_op_selectivity( + &self, + log_op_typ: LogOpType, + children: &[ArcPredicateNode], + table_id: TableId, + ) -> CostModelResult { + match log_op_typ { + LogOpType::And => children.iter().try_fold(1.0, |acc, child| { + let selectivity = self.get_filter_selectivity(child.clone(), table_id)?; + Ok(acc * selectivity) + }), + LogOpType::Or => { + let product = children.iter().try_fold(1.0, |acc, child| { + let selectivity = self.get_filter_selectivity(child.clone(), table_id)?; + Ok(acc * (1.0 - selectivity)) + })?; + Ok(1.0 - product) + } + } + } + /// Comparison operators are the base case for recursion in get_filter_selectivity() fn get_comp_op_selectivity( &self, @@ -131,7 +186,61 @@ impl CostModelImpl { .expect("we just checked that attr_ref_exprs.len() == 1"); let attr_ref_idx = attr_ref_expr.index(); - todo!() + // TODO: Consider attribute is a derived attribute + if values.len() == 1 { + let value = values + .first() + .expect("we just checked that values.len() == 1"); + match comp_bin_op_typ { + BinOpType::Eq => { + self.get_attribute_equality_selectivity(table_id, attr_ref_idx, value, true) + } + BinOpType::Neq => self.get_attribute_equality_selectivity( + table_id, + attr_ref_idx, + value, + false, + ), + BinOpType::Lt | BinOpType::Leq | BinOpType::Gt | BinOpType::Geq => { + let start = match (comp_bin_op_typ, is_left_attr_ref) { + (BinOpType::Lt, true) | (BinOpType::Geq, false) => Bound::Unbounded, + (BinOpType::Leq, true) | (BinOpType::Gt, false) => Bound::Unbounded, + (BinOpType::Gt, true) | (BinOpType::Leq, false) => Bound::Excluded(value), + (BinOpType::Geq, true) | (BinOpType::Lt, false) => Bound::Included(value), + _ => unreachable!("all comparison BinOpTypes were enumerated. 
this should be unreachable"), + }; + let end = match (comp_bin_op_typ, is_left_attr_ref) { + (BinOpType::Lt, true) | (BinOpType::Geq, false) => Bound::Excluded(value), + (BinOpType::Leq, true) | (BinOpType::Gt, false) => Bound::Included(value), + (BinOpType::Gt, true) | (BinOpType::Leq, false) => Bound::Unbounded, + (BinOpType::Geq, true) | (BinOpType::Lt, false) => Bound::Unbounded, + _ => unreachable!("all comparison BinOpTypes were enumerated. this should be unreachable"), + }; + self.get_attribute_range_selectivity(table_id, attr_ref_idx, start, end) + } + _ => unreachable!( + "all comparison BinOpTypes were enumerated. this should be unreachable" + ), + } + } else { + let non_attr_ref_expr = non_attr_ref_exprs.first().expect( + "non_attr_ref_exprs should have a value since attr_ref_exprs.len() == 1", + ); + + match non_attr_ref_expr.as_ref().typ { + PredicateType::BinOp(_) => { + Ok(Self::get_default_comparison_op_selectivity(comp_bin_op_typ)) + } + PredicateType::Cast => Ok(UNIMPLEMENTED_SEL), + PredicateType::Constant(_) => { + unreachable!("we should have handled this in the values.len() == 1 branch") + } + _ => unimplemented!( + "unhandled case of comparing a attribute ref node to {}", + non_attr_ref_expr.as_ref().typ + ), + } + } } else if attr_ref_exprs.len() == 2 { Ok(Self::get_default_comparison_op_selectivity(comp_bin_op_typ)) } else { @@ -139,6 +248,232 @@ impl CostModelImpl { } } + /// Get the selectivity of an expression of the form "attribute equals value" (or "value equals + /// attribute") Will handle the case of statistics missing + /// Equality predicates are handled entirely differently from range predicates so this is its + /// own function + /// Also, get_attribute_equality_selectivity is a subroutine when computing range + /// selectivity, which is another reason for separating these into two functions + /// is_eq means whether it's == or != + fn get_attribute_equality_selectivity( + &self, + table_id: TableId, + attr_base_index: usize, + 
value: &Value, + is_eq: bool, + ) -> CostModelResult { + // TODO: The attribute could be a derived attribute + todo!() + // let ret_sel = if let Some(attribute_stats) = + // self.get_attribute_comb_stats(table_id, &[attr_base_index]) + // { + // let eq_freq = if let Some(freq) = attribute_stats.mcvs.freq(&vec![Some(value.clone())]) { + // freq + // } else { + // let non_mcv_freq = 1.0 - attribute_stats.mcvs.total_freq(); + // // always safe because usize is at least as large as i32 + // let ndistinct_as_usize = attribute_stats.ndistinct as usize; + // let non_mcv_cnt = ndistinct_as_usize - attribute_stats.mcvs.cnt(); + // if non_mcv_cnt == 0 { + // return 0.0; + // } + // // note that nulls are not included in ndistinct so we don't need to do non_mcv_cnt + // // - 1 if null_frac > 0 + // (non_mcv_freq - attribute_stats.null_frac) / (non_mcv_cnt as f64) + // }; + // if is_eq { + // eq_freq + // } else { + // 1.0 - eq_freq - attribute_stats.null_frac + // } + // } else { + // #[allow(clippy::collapsible_else_if)] + // if is_eq { + // DEFAULT_EQ_SEL + // } else { + // 1.0 - DEFAULT_EQ_SEL + // } + // }; + // assert!( + // (0.0..=1.0).contains(&ret_sel), + // "ret_sel ({}) should be in [0, 1]", + // ret_sel + // ); + // ret_sel + } + + /// Get the selectivity of an expression of the form "attribute =/> value" (or "value + /// =/> attribute"). Computes selectivity based off of statistics. + /// Range predicates are handled entirely differently from equality predicates so this is its + /// own function. If it is unable to find the statistics, it returns DEFAULT_INEQ_SEL. + /// The selectivity is computed as quantile of the right bound minus quantile of the left bound. 
+ fn get_attribute_range_selectivity( + &self, + table_id: TableId, + attr_base_index: usize, + start: Bound<&Value>, + end: Bound<&Value>, + ) -> CostModelResult { + // TODO: Consider attribute is a derived attribute + todo!() + // if let Some(attribute_stats) = self.get_attribute_comb_stats(table, &[attr_idx]) { + // // Left and right quantile contain both Distribution and MCVs. + // let left_quantile = match start { + // Bound::Unbounded => 0.0, + // Bound::Included(value) => { + // self.get_attribute_lt_value_freq(attribute_stats, table, attr_idx, value) + // } + // Bound::Excluded(value) => Self::get_attribute_leq_value_freq(attribute_stats, value), + // }; + // let right_quantile = match end { + // Bound::Unbounded => 1.0, + // Bound::Included(value) => Self::get_attribute_leq_value_freq(attribute_stats, value), + // Bound::Excluded(value) => { + // self.get_attribute_lt_value_freq(attribute_stats, table, attr_idx, value) + // } + // }; + // assert!( + // left_quantile <= right_quantile, + // "left_quantile ({}) should be <= right_quantile ({})", + // left_quantile, + // right_quantile + // ); + // right_quantile - left_quantile + // } else { + // DEFAULT_INEQ_SEL + // } + } + + /// Compute the selectivity of a (NOT) LIKE expression. + /// + /// The logic is somewhat similar to Postgres but different. Postgres first estimates the + /// histogram part of the population and then add up data for any MCV values. If the + /// histogram is large enough, it just uses the number of matches in the histogram, + /// otherwise it estimates the fixed prefix and remainder of pattern separately and + /// combine them. + /// + /// Our approach is simpler and less selective. Firstly, we don't use histogram. The selectivity + /// is composed of MCV frequency and non-MCV selectivity. MCV frequency is computed by + /// adding up frequencies of MCVs that match the pattern. 
Non-MCV selectivity is computed + /// in the same way that Postgres computes selectivity for the wildcard part of the pattern. + fn get_like_selectivity(&self, like_expr: &LikePred) -> CostModelResult { + let child = like_expr.child(); + + // Check child is a attribute ref. + if !matches!(child.typ, PredicateType::AttributeRef) { + return Ok(UNIMPLEMENTED_SEL); + } + + // Check pattern is a constant. + let pattern = like_expr.pattern(); + if !matches!(pattern.typ, PredicateType::Constant(_)) { + return Ok(UNIMPLEMENTED_SEL); + } + + let attr_ref_idx = AttributeRefPred::from_pred_node(child).unwrap().index(); + + // TODO: Consider attribute is a derived attribute + let pattern = ConstantPred::from_pred_node(pattern) + .expect("we already checked pattern is a constant") + .value() + .as_str(); + + // Compute the selectivity exculuding MCVs. + // See Postgres `like_selectivity`. + let non_mcv_sel = pattern + .chars() + .fold(1.0, |acc, c| { + if c == '%' { + acc * FULL_WILDCARD_SEL_FACTOR + } else { + acc * FIXED_CHAR_SEL_FACTOR + } + }) + .min(1.0); + todo!() + + // // Compute the selectivity in MCVs. + // let attribute_stats = self.get_attribute_comb_stats(table, &[*attr_idx]); + // let (mcv_freq, null_frac) = if let Some(attribute_stats) = attribute_stats { + // let pred = Box::new(move |val: &AttributeCombValue| { + // let string = StringArray::from(vec![val[0].as_ref().unwrap().as_str().as_ref()]); + // let pattern = StringArray::from(vec![pattern.as_ref()]); + // like(&string, &pattern).unwrap().value(0) + // }); + // ( + // attribute_stats.mcvs.freq_over_pred(pred), + // attribute_stats.null_frac, + // ) + // } else { + // (0.0, 0.0) + // }; + + // let result = non_mcv_sel + mcv_freq; + + // if like_expr.negated() { + // 1.0 - result - null_frac + // } else { + // result + // } + // // Postgres clamps the result after histogram and before MCV. See Postgres + // // `patternsel_common`. 
+ // .clamp(0.0001, 0.9999) + } + + /// Only support colA in (val1, val2, val3) where colA is a attribute ref and + /// val1, val2, val3 are constants. + pub fn get_in_list_selectivity( + &self, + expr: &InListPred, + table_id: TableId, + ) -> CostModelResult { + let child = expr.child(); + + // Check child is a attribute ref. + if !matches!(child.typ, PredicateType::AttributeRef) { + return Ok(UNIMPLEMENTED_SEL); + } + + // Check all expressions in the list are constants. + let list_exprs = expr.list().to_vec(); + if list_exprs + .iter() + .any(|expr| !matches!(expr.typ, PredicateType::Constant(_))) + { + return Ok(UNIMPLEMENTED_SEL); + } + + // Convert child and const expressions to concrete types. + let attr_ref_idx = AttributeRefPred::from_pred_node(child).unwrap().index(); + let list_exprs = list_exprs + .into_iter() + .map(|expr| { + ConstantPred::from_pred_node(expr) + .expect("we already checked all list elements are constants") + }) + .collect::>(); + let negated = expr.negated(); + + // TODO: Consider attribute is a derived attribute + let in_sel = list_exprs + .iter() + .try_fold(0.0, |acc, expr| { + let selectivity = self.get_attribute_equality_selectivity( + table_id, + attr_ref_idx, + &expr.value(), + /* is_equality */ true, + )?; + Ok(acc + selectivity) + })? + .min(1.0); + if negated { + Ok(1.0 - in_sel) + } else { + Ok(in_sel) + } + } + /// Convert the left and right child nodes of some operation to what they semantically are. /// This is convenient to avoid repeating the same logic just with "left" and "right" swapped. /// The last return value is true when the input node (left) is a AttributeRefPred. 
From 6518e0034e1f5c8773e083affac3c11aac26654e Mon Sep 17 00:00:00 2001 From: Lan Lou Date: Thu, 14 Nov 2024 18:29:34 -0500 Subject: [PATCH 09/51] Add get_attribute_comb_stats --- optd-cost-model/src/common/nodes.rs | 3 + optd-cost-model/src/cost/filter.rs | 119 +++++++++++++--------------- optd-cost-model/src/cost_model.rs | 16 +++- 3 files changed, 74 insertions(+), 64 deletions(-) diff --git a/optd-cost-model/src/common/nodes.rs b/optd-cost-model/src/common/nodes.rs index 0bbcca1..b7a3728 100644 --- a/optd-cost-model/src/common/nodes.rs +++ b/optd-cost-model/src/common/nodes.rs @@ -79,6 +79,9 @@ pub struct PredicateNode { /// Child predicate nodes, always materialized pub children: Vec, /// Data associated with the predicate, if any + /// TODO: If it is PredicateType::AttributeRef, then + /// the data is attribute index. But we need more information + /// to represent this attribute in case it is a derived attribute. pub data: Option, } diff --git a/optd-cost-model/src/cost/filter.rs b/optd-cost-model/src/cost/filter.rs index 85724d4..39e16c8 100644 --- a/optd-cost-model/src/cost/filter.rs +++ b/optd-cost-model/src/cost/filter.rs @@ -263,43 +263,36 @@ impl CostModelImpl { is_eq: bool, ) -> CostModelResult { // TODO: The attribute could be a derived attribute - todo!() - // let ret_sel = if let Some(attribute_stats) = - // self.get_attribute_comb_stats(table_id, &[attr_base_index]) - // { - // let eq_freq = if let Some(freq) = attribute_stats.mcvs.freq(&vec![Some(value.clone())]) { - // freq - // } else { - // let non_mcv_freq = 1.0 - attribute_stats.mcvs.total_freq(); - // // always safe because usize is at least as large as i32 - // let ndistinct_as_usize = attribute_stats.ndistinct as usize; - // let non_mcv_cnt = ndistinct_as_usize - attribute_stats.mcvs.cnt(); - // if non_mcv_cnt == 0 { - // return 0.0; - // } - // // note that nulls are not included in ndistinct so we don't need to do non_mcv_cnt - // // - 1 if null_frac > 0 - // (non_mcv_freq - 
attribute_stats.null_frac) / (non_mcv_cnt as f64) - // }; - // if is_eq { - // eq_freq - // } else { - // 1.0 - eq_freq - attribute_stats.null_frac - // } - // } else { - // #[allow(clippy::collapsible_else_if)] - // if is_eq { - // DEFAULT_EQ_SEL - // } else { - // 1.0 - DEFAULT_EQ_SEL - // } - // }; - // assert!( - // (0.0..=1.0).contains(&ret_sel), - // "ret_sel ({}) should be in [0, 1]", - // ret_sel - // ); - // ret_sel + let ret_sel = { + let attribute_stats = self.get_attribute_comb_stats(table_id, &[attr_base_index])?; + let eq_freq = if let Some(freq) = attribute_stats.mcvs.freq(&vec![Some(value.clone())]) + { + freq + } else { + let non_mcv_freq = 1.0 - attribute_stats.mcvs.total_freq(); + // always safe because usize is at least as large as i32 + let ndistinct_as_usize = attribute_stats.ndistinct as usize; + let non_mcv_cnt = ndistinct_as_usize - attribute_stats.mcvs.cnt(); + if non_mcv_cnt == 0 { + return Ok(0.0); + } + // note that nulls are not included in ndistinct so we don't need to do non_mcv_cnt + // - 1 if null_frac > 0 + (non_mcv_freq - attribute_stats.null_frac) / (non_mcv_cnt as f64) + }; + if is_eq { + eq_freq + } else { + 1.0 - eq_freq - attribute_stats.null_frac + } + }; + + assert!( + (0.0..=1.0).contains(&ret_sel), + "ret_sel ({}) should be in [0, 1]", + ret_sel + ); + Ok(ret_sel) } /// Get the selectivity of an expression of the form "attribute =/> value" (or "value @@ -315,33 +308,33 @@ impl CostModelImpl { end: Bound<&Value>, ) -> CostModelResult { // TODO: Consider attribute is a derived attribute + let attribute_stats = self.get_attribute_comb_stats(table_id, &[attr_base_index])?; todo!() - // if let Some(attribute_stats) = self.get_attribute_comb_stats(table, &[attr_idx]) { - // // Left and right quantile contain both Distribution and MCVs. 
- // let left_quantile = match start { - // Bound::Unbounded => 0.0, - // Bound::Included(value) => { - // self.get_attribute_lt_value_freq(attribute_stats, table, attr_idx, value) - // } - // Bound::Excluded(value) => Self::get_attribute_leq_value_freq(attribute_stats, value), - // }; - // let right_quantile = match end { - // Bound::Unbounded => 1.0, - // Bound::Included(value) => Self::get_attribute_leq_value_freq(attribute_stats, value), - // Bound::Excluded(value) => { - // self.get_attribute_lt_value_freq(attribute_stats, table, attr_idx, value) - // } - // }; - // assert!( - // left_quantile <= right_quantile, - // "left_quantile ({}) should be <= right_quantile ({})", - // left_quantile, - // right_quantile - // ); - // right_quantile - left_quantile - // } else { - // DEFAULT_INEQ_SEL - // } + // let left_quantile = match start { + // Bound::Unbounded => 0.0, + // Bound::Included(value) => { + // self.get_attribute_lt_value_freq(attribute_stats, table, attr_idx, value) + // } + // Bound::Excluded(value) => { + // Self::get_attribute_leq_value_freq(attribute_stats, value) + // } + // }; + // let right_quantile = match end { + // Bound::Unbounded => 1.0, + // Bound::Included(value) => { + // Self::get_attribute_leq_value_freq(attribute_stats, value) + // } + // Bound::Excluded(value) => { + // self.get_attribute_lt_value_freq(attribute_stats, table, attr_idx, value) + // } + // }; + // assert!( + // left_quantile <= right_quantile, + // "left_quantile ({}) should be <= right_quantile ({})", + // left_quantile, + // right_quantile + // ); + // right_quantile - left_quantile } /// Compute the selectivity of a (NOT) LIKE expression. 
diff --git a/optd-cost-model/src/cost_model.rs b/optd-cost-model/src/cost_model.rs index 0b1760e..c5c6530 100644 --- a/optd-cost-model/src/cost_model.rs +++ b/optd-cost-model/src/cost_model.rs @@ -12,6 +12,7 @@ use crate::{ nodes::{ArcPredicateNode, PhysicalNodeType}, types::{AttrId, EpochId, ExprId, TableId}, }, + stats::AttributeCombValueStats, storage::CostModelStorageManager, ComputeCostContext, Cost, CostModel, CostModelResult, EstimatedStatistic, StatValue, }; @@ -67,7 +68,6 @@ impl CostModel for CostM fn get_table_statistic_for_analysis( &self, - // TODO: i32 should be changed to TableId. table_id: TableId, stat_type: StatType, epoch_id: Option, @@ -92,3 +92,17 @@ impl CostModel for CostM todo!() } } + +impl CostModelImpl { + /// TODO: documentation + /// TODO: if we have memory cache, + /// we should add the reference. (&AttributeCombValueStats) + pub(crate) fn get_attribute_comb_stats( + &self, + table_id: TableId, + attr_comb: &[usize], + ) -> CostModelResult { + self.storage_manager + .get_attributes_comb_statistics(table_id, attr_comb) + } +} From 59a8889590b5fb0156855aa104f20326ffa93130 Mon Sep 17 00:00:00 2001 From: Lan Lou Date: Thu, 14 Nov 2024 18:52:45 -0500 Subject: [PATCH 10/51] Finish first draft version of filter functionality --- optd-cost-model/Cargo.lock | 643 ++++++++++++++++++++++++++++- optd-cost-model/Cargo.toml | 1 + optd-cost-model/src/cost/filter.rs | 163 +++++--- 3 files changed, 747 insertions(+), 60 deletions(-) diff --git a/optd-cost-model/Cargo.lock b/optd-cost-model/Cargo.lock index 9464489..bf0b367 100644 --- a/optd-cost-model/Cargo.lock +++ b/optd-cost-model/Cargo.lock @@ -57,6 +57,21 @@ version = "0.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "250f629c0161ad8107cf89319e990051fae62832fd343083bea452d93e2205fd" +[[package]] +name = "alloc-no-stdlib" +version = "2.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"cc7bb162ec39d46ab1ca8c77bf72e890535becd1751bb45f64c597edb4c8c6b3" + +[[package]] +name = "alloc-stdlib" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94fb8275041c72129eb51b7d0322c29b8387a0386127718b096429201a5d6ece" +dependencies = [ + "alloc-no-stdlib", +] + [[package]] name = "allocator-api2" version = "0.2.20" @@ -127,6 +142,12 @@ dependencies = [ "windows-sys 0.59.0", ] +[[package]] +name = "arrayref" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "76a2e8124351fda1ef8aaaa3bbd7ebbcb486bbcd4225aca0aa0d84bb2db8fecb" + [[package]] name = "arrayvec" version = "0.7.6" @@ -353,6 +374,24 @@ dependencies = [ "regex-syntax 0.7.5", ] +[[package]] +name = "async-compression" +version = "0.4.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0cb8f1d480b0ea3783ab015936d2a55c87e219676f0c0b7dec61494043f21857" +dependencies = [ + "bzip2", + "flate2", + "futures-core", + "futures-io", + "memchr", + "pin-project-lite", + "tokio", + "xz2", + "zstd 0.13.2", + "zstd-safe 7.2.1", +] + [[package]] name = "async-stream" version = "0.3.6" @@ -416,6 +455,12 @@ dependencies = [ "windows-targets 0.52.6", ] +[[package]] +name = "base64" +version = "0.21.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d297deb1925b89f2ccc13d7635fa0714f12c87adce1c75356b39ca9b7178567" + [[package]] name = "base64" version = "0.22.1" @@ -469,6 +514,28 @@ dependencies = [ "wyz", ] +[[package]] +name = "blake2" +version = "0.10.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "46502ad458c9a52b69d4d4d32775c788b7a1b85e8bc9d482d92250fc0e3f8efe" +dependencies = [ + "digest", +] + +[[package]] +name = "blake3" +version = "1.5.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d82033247fd8e890df8f740e407ad4d038debb9eb1f40533fffb32e7d17dc6f7" +dependencies = [ + "arrayref", + "arrayvec", + 
"cc", + "cfg-if", + "constant_time_eq", +] + [[package]] name = "block-buffer" version = "0.10.4" @@ -501,6 +568,27 @@ dependencies = [ "syn 2.0.87", ] +[[package]] +name = "brotli" +version = "3.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d640d25bc63c50fb1f0b545ffd80207d2e10a4c965530809b40ba3386825c391" +dependencies = [ + "alloc-no-stdlib", + "alloc-stdlib", + "brotli-decompressor", +] + +[[package]] +name = "brotli-decompressor" +version = "2.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4e2e4afe60d7dd600fdd3de8d0f08c2b7ec039712e3b6137ff98b7004e82de4f" +dependencies = [ + "alloc-no-stdlib", + "alloc-stdlib", +] + [[package]] name = "bumpalo" version = "3.16.0" @@ -541,12 +629,35 @@ version = "1.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9ac0150caa2ae65ca5bd83f25c7de183dea78d4d366469f148435e2acfbad0da" +[[package]] +name = "bzip2" +version = "0.4.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bdb116a6ef3f6c3698828873ad02c3014b3c85cadb88496095628e3ef1e347f8" +dependencies = [ + "bzip2-sys", + "libc", +] + +[[package]] +name = "bzip2-sys" +version = "0.1.11+1.0.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "736a955f3fa7875102d57c82b8cac37ec45224a07fd32d58f9f7a186b6cd4cdc" +dependencies = [ + "cc", + "libc", + "pkg-config", +] + [[package]] name = "cc" version = "1.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fd9de9f2205d5ef3fd67e685b0df337994ddd4495e2a28d185500d0e1edfea47" dependencies = [ + "jobserver", + "libc", "shlex", ] @@ -691,6 +802,12 @@ dependencies = [ "tiny-keccak", ] +[[package]] +name = "constant_time_eq" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7c74b8349d32d297c9134b8c88677813a227df8f779daa29bfc29c183fe3dca6" + [[package]] name = "core-foundation-sys" version = "0.8.7" @@ 
-721,6 +838,15 @@ version = "2.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "19d374276b40fb8bbdee95aef7c7fa6b5316ec764510eb64b8dd0e2ed0d7e7f5" +[[package]] +name = "crc32fast" +version = "1.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a97769d94ddab943e4510d138150169a2758b5ef3eb191a9ee688de3e23ef7b3" +dependencies = [ + "cfg-if", +] + [[package]] name = "crossbeam" version = "0.8.4" @@ -849,6 +975,67 @@ dependencies = [ "syn 2.0.87", ] +[[package]] +name = "dashmap" +version = "5.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "978747c1d849a7d2ee5e8adc0159961c48fb7e5db2f06af6723b80123bb53856" +dependencies = [ + "cfg-if", + "hashbrown 0.14.5", + "lock_api", + "once_cell", + "parking_lot_core", +] + +[[package]] +name = "datafusion" +version = "32.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7014432223f4d721cb9786cd88bb89e7464e0ba984d4a7f49db7787f5f268674" +dependencies = [ + "ahash 0.8.11", + "arrow", + "arrow-array", + "arrow-schema 47.0.0", + "async-compression", + "async-trait", + "bytes", + "bzip2", + "chrono", + "dashmap", + "datafusion-common", + "datafusion-execution", + "datafusion-expr", + "datafusion-optimizer", + "datafusion-physical-expr", + "datafusion-physical-plan", + "datafusion-sql", + "flate2", + "futures", + "glob", + "half", + "hashbrown 0.14.5", + "indexmap 2.6.0", + "itertools 0.11.0", + "log", + "num_cpus", + "object_store", + "parking_lot", + "parquet", + "percent-encoding", + "pin-project-lite", + "rand", + "sqlparser", + "tempfile", + "tokio", + "tokio-util", + "url", + "uuid", + "xz2", + "zstd 0.12.4", +] + [[package]] name = "datafusion-common" version = "32.0.0" @@ -863,9 +1050,32 @@ dependencies = [ "chrono", "half", "num_cpus", + "object_store", + "parquet", "sqlparser", ] +[[package]] +name = "datafusion-execution" +version = "32.0.0" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "780b73b2407050e53f51a9781868593f694102c59e622de9a8aafc0343c4f237" +dependencies = [ + "arrow", + "chrono", + "dashmap", + "datafusion-common", + "datafusion-expr", + "futures", + "hashbrown 0.14.5", + "log", + "object_store", + "parking_lot", + "rand", + "tempfile", + "url", +] + [[package]] name = "datafusion-expr" version = "32.0.0" @@ -881,6 +1091,103 @@ dependencies = [ "strum_macros 0.25.3", ] +[[package]] +name = "datafusion-optimizer" +version = "32.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f2904a432f795484fd45e29ded4537152adb60f636c05691db34fcd94c92c96" +dependencies = [ + "arrow", + "async-trait", + "chrono", + "datafusion-common", + "datafusion-expr", + "datafusion-physical-expr", + "hashbrown 0.14.5", + "itertools 0.11.0", + "log", + "regex-syntax 0.7.5", +] + +[[package]] +name = "datafusion-physical-expr" +version = "32.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "57b4968e9a998dc0476c4db7a82f280e2026b25f464e4aa0c3bb9807ee63ddfd" +dependencies = [ + "ahash 0.8.11", + "arrow", + "arrow-array", + "arrow-buffer", + "arrow-schema 47.0.0", + "base64 0.21.7", + "blake2", + "blake3", + "chrono", + "datafusion-common", + "datafusion-expr", + "half", + "hashbrown 0.14.5", + "hex", + "indexmap 2.6.0", + "itertools 0.11.0", + "libc", + "log", + "md-5", + "paste", + "petgraph", + "rand", + "regex", + "sha2", + "unicode-segmentation", + "uuid", +] + +[[package]] +name = "datafusion-physical-plan" +version = "32.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "efd0d1fe54e37a47a2d58a1232c22786f2c28ad35805fdcd08f0253a8b0aaa90" +dependencies = [ + "ahash 0.8.11", + "arrow", + "arrow-array", + "arrow-buffer", + "arrow-schema 47.0.0", + "async-trait", + "chrono", + "datafusion-common", + "datafusion-execution", + "datafusion-expr", + "datafusion-physical-expr", + "futures", + "half", + 
"hashbrown 0.14.5", + "indexmap 2.6.0", + "itertools 0.11.0", + "log", + "once_cell", + "parking_lot", + "pin-project-lite", + "rand", + "tokio", + "uuid", +] + +[[package]] +name = "datafusion-sql" +version = "32.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b568d44c87ead99604d704f942e257c8a236ee1bbf890ee3e034ad659dcb2c21" +dependencies = [ + "arrow", + "arrow-schema 47.0.0", + "datafusion-common", + "datafusion-expr", + "log", + "sqlparser", +] + [[package]] name = "der" version = "0.7.9" @@ -925,6 +1232,12 @@ dependencies = [ "syn 2.0.87", ] +[[package]] +name = "doc-comment" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fea41bba32d969b513997752735605054bc0dfa92b4c56bf1189f2e174be7a10" + [[package]] name = "dotenvy" version = "0.15.7" @@ -984,6 +1297,12 @@ version = "2.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "486f806e73c5707928240ddc295403b1b93c96a02038563881c4a2fd84b81ac4" +[[package]] +name = "fixedbitset" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0ce7134b9999ecaf8bcd65542e436736ef32ddca1b3e06094cb6ec5755203b80" + [[package]] name = "flatbuffers" version = "23.5.26" @@ -994,6 +1313,16 @@ dependencies = [ "rustc_version", ] +[[package]] +name = "flate2" +version = "1.0.35" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c936bfdafb507ebbf50b8074c54fa31c5be9a1e7e5f467dd659697041407d07c" +dependencies = [ + "crc32fast", + "miniz_oxide", +] + [[package]] name = "flume" version = "0.11.1" @@ -1034,6 +1363,7 @@ checksum = "65bc07b1a8bc7c85c5f2e110c476c7389b4554ba72af57d8445ea63a576b0876" dependencies = [ "futures-channel", "futures-core", + "futures-executor", "futures-io", "futures-sink", "futures-task", @@ -1084,6 +1414,17 @@ version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"9e5c1b78ca4aae1ac06c48a526a655760685149f0d465d21f37abfe57ce075c6" +[[package]] +name = "futures-macro" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.87", +] + [[package]] name = "futures-sink" version = "0.3.31" @@ -1105,6 +1446,7 @@ dependencies = [ "futures-channel", "futures-core", "futures-io", + "futures-macro", "futures-sink", "futures-task", "memchr", @@ -1242,6 +1584,12 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "humantime" +version = "2.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4" + [[package]] name = "iana-time-zone" version = "0.1.61" @@ -1443,12 +1791,27 @@ dependencies = [ "syn 2.0.87", ] +[[package]] +name = "integer-encoding" +version = "3.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8bb03732005da905c88227371639bf1ad885cc712789c011c31c5fb3ab3ccf02" + [[package]] name = "is_terminal_polyfill" version = "1.70.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7943c866cc5cd64cbc25b2e01621d07fa8eb2a1a23160ee81ce38704e97b8ecf" +[[package]] +name = "itertools" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b1c173a5686ce8bfa551b3563d0c2170bf24ca44da99c7ca4bfdab5418c3fe57" +dependencies = [ + "either", +] + [[package]] name = "itertools" version = "0.12.1" @@ -1473,6 +1836,15 @@ version = "1.0.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "49f1f14873335454500d59611f1cf4a4b0f786f9ac11f4312a78e4cf2566695b" +[[package]] +name = "jobserver" +version = "0.1.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "48d1dbcbbeb6a7fec7e059840aa538bd62aaccf972c7346c4d9d2059312853d0" 
+dependencies = [ + "libc", +] + [[package]] name = "js-sys" version = "0.3.72" @@ -1606,6 +1978,36 @@ version = "0.4.22" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24" +[[package]] +name = "lz4" +version = "1.28.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4d1febb2b4a79ddd1980eede06a8f7902197960aa0383ffcfdd62fe723036725" +dependencies = [ + "lz4-sys", +] + +[[package]] +name = "lz4-sys" +version = "1.11.1+lz4-1.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6bd8c0d6c6ed0cd30b3652886bb8711dc4bb01d637a68105a3d5158039b418e6" +dependencies = [ + "cc", + "libc", +] + +[[package]] +name = "lzma-sys" +version = "0.1.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5fda04ab3764e6cde78b9974eec4f779acaba7c4e84b36eca3cf77c581b85d27" +dependencies = [ + "cc", + "libc", + "pkg-config", +] + [[package]] name = "matchers" version = "0.1.0" @@ -1784,6 +2186,27 @@ dependencies = [ "memchr", ] +[[package]] +name = "object_store" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f930c88a43b1c3f6e776dfe495b4afab89882dbc81530c632db2ed65451ebcb4" +dependencies = [ + "async-trait", + "bytes", + "chrono", + "futures", + "humantime", + "itertools 0.11.0", + "parking_lot", + "percent-encoding", + "snafu", + "tokio", + "tracing", + "url", + "walkdir", +] + [[package]] name = "once_cell" version = "1.20.2" @@ -1797,6 +2220,7 @@ dependencies = [ "arrow-schema 53.2.0", "chrono", "crossbeam", + "datafusion", "datafusion-expr", "itertools 0.13.0", "optd-persistent", @@ -1821,6 +2245,15 @@ dependencies = [ "trait-variant", ] +[[package]] +name = "ordered-float" +version = "2.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "68f19d67e5a2795c94e73e0bb1cc1a7edeb2e28efd39e2e1c9b7a40c1108b11c" +dependencies = [ + 
"num-traits", +] + [[package]] name = "ordered-float" version = "3.9.2" @@ -1893,6 +2326,40 @@ dependencies = [ "windows-targets 0.52.6", ] +[[package]] +name = "parquet" +version = "47.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0463cc3b256d5f50408c49a4be3a16674f4c8ceef60941709620a062b1f6bf4d" +dependencies = [ + "ahash 0.8.11", + "arrow-array", + "arrow-buffer", + "arrow-cast", + "arrow-data", + "arrow-ipc", + "arrow-schema 47.0.0", + "arrow-select", + "base64 0.21.7", + "brotli", + "bytes", + "chrono", + "flate2", + "futures", + "hashbrown 0.14.5", + "lz4", + "num", + "num-bigint", + "object_store", + "paste", + "seq-macro", + "snap", + "thrift", + "tokio", + "twox-hash", + "zstd 0.12.4", +] + [[package]] name = "parse-zoneinfo" version = "0.3.1" @@ -1923,6 +2390,16 @@ version = "2.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e" +[[package]] +name = "petgraph" +version = "0.6.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b4c5cc86750666a3ed20bdaf5ca2a0344f9c67674cae0515bec2da16fbaa47db" +dependencies = [ + "fixedbitset", + "indexmap 2.6.0", +] + [[package]] name = "phf" version = "0.11.2" @@ -2361,6 +2838,15 @@ version = "1.0.18" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f3cb5ba0dc43242ce17de99c180e96db90b235b8a9fdc9543c96d2209116bd9f" +[[package]] +name = "same-file" +version = "1.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93fc1dc3aaa9bfed95e02e6eadabb4baf7e3078b0bd1b4d7b6b0b68378900502" +dependencies = [ + "winapi-util", +] + [[package]] name = "scopeguard" version = "1.2.0" @@ -2538,6 +3024,12 @@ version = "1.0.23" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "61697e0a1c7e512e84a621326239844a24d8207b4669b41bc18b32ea5cbf988b" +[[package]] +name = "seq-macro" +version = "0.3.5" 
+source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a3f0bf26fd526d2a95683cd0f87bf103b8539e2ca1ef48ce002d67aad59aa0b4" + [[package]] name = "serde" version = "1.0.215" @@ -2588,7 +3080,7 @@ version = "3.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8e28bdad6db2b8340e449f7108f020b3b092e8583a9e3fb82713e1d4e71fe817" dependencies = [ - "base64", + "base64 0.22.1", "chrono", "hex", "indexmap 1.9.3", @@ -2689,6 +3181,34 @@ dependencies = [ "serde", ] +[[package]] +name = "snafu" +version = "0.7.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e4de37ad025c587a29e8f3f5605c00f70b98715ef90b9061a815b9e59e9042d6" +dependencies = [ + "doc-comment", + "snafu-derive", +] + +[[package]] +name = "snafu-derive" +version = "0.7.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "990079665f075b699031e9c08fd3ab99be5029b96f3b78dc0709e8f77e4efebf" +dependencies = [ + "heck 0.4.1", + "proc-macro2", + "quote", + "syn 1.0.109", +] + +[[package]] +name = "snap" +version = "1.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1b6b67fb9a61334225b5b790716f609cd58395f895b3fe8b328786812a40bc3b" + [[package]] name = "socket2" version = "0.5.7" @@ -2855,7 +3375,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "64bb4714269afa44aef2755150a0fc19d756fb580a67db8885608cf02f47d06a" dependencies = [ "atoi", - "base64", + "base64 0.22.1", "bigdecimal", "bitflags 2.6.0", "byteorder", @@ -2902,7 +3422,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6fa91a732d854c5d7726349bb4bb879bb9478993ceb764247660aee25f67c2f8" dependencies = [ "atoi", - "base64", + "base64 0.22.1", "bigdecimal", "bitflags 2.6.0", "byteorder", @@ -3123,6 +3643,17 @@ dependencies = [ "once_cell", ] +[[package]] +name = "thrift" +version = "0.17.0" +source = "registry+https://github.com/rust-lang/crates.io-index" 
+checksum = "7e54bc85fc7faa8bc175c4bab5b92ba8d9a3ce893d0e9f42cc455c8ab16a9e09" +dependencies = [ + "byteorder", + "integer-encoding", + "ordered-float 2.10.1", +] + [[package]] name = "time" version = "0.3.36" @@ -3198,6 +3729,7 @@ dependencies = [ "bytes", "libc", "mio", + "parking_lot", "pin-project-lite", "socket2", "tokio-macros", @@ -3226,6 +3758,19 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-util" +version = "0.7.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "61e7c3654c13bcd040d4a03abee2c75b1d14a37b423cf5a813ceae1cc903ec6a" +dependencies = [ + "bytes", + "futures-core", + "futures-sink", + "pin-project-lite", + "tokio", +] + [[package]] name = "toml_datetime" version = "0.6.8" @@ -3301,6 +3846,16 @@ dependencies = [ "syn 2.0.87", ] +[[package]] +name = "twox-hash" +version = "1.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "97fee6b57c6a41524a810daee9286c02d7752c4253064d0b05472833a438f675" +dependencies = [ + "cfg-if", + "static_assertions", +] + [[package]] name = "typenum" version = "1.17.0" @@ -3334,6 +3889,12 @@ version = "0.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e70f2a8b45122e719eb623c01822704c4e0907e7e426a05927e1a1cfff5b75d0" +[[package]] +name = "unicode-segmentation" +version = "1.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f6ccf251212114b54433ec949fd6a7841275f9ada20dddd2f29e9ceea4501493" + [[package]] name = "unicode-width" version = "0.2.0" @@ -3387,6 +3948,7 @@ version = "1.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f8c5f0a0af699448548ad1a2fbf920fb4bee257eae39953ba95cb84891a0446a" dependencies = [ + "getrandom", "serde", ] @@ -3402,6 +3964,16 @@ version = "0.9.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" +[[package]] +name = "walkdir" 
+version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "29790946404f91d9c5d06f9874efddea1dc06c5efe94541a7d6863108e3a5e4b" +dependencies = [ + "same-file", + "winapi-util", +] + [[package]] name = "wasi" version = "0.11.0+wasi-snapshot-preview1" @@ -3488,6 +4060,15 @@ dependencies = [ "wasite", ] +[[package]] +name = "winapi-util" +version = "0.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cf221c93e13a30d793f7645a0e7762c55d169dbb0a49671918a2319d289b10bb" +dependencies = [ + "windows-sys 0.59.0", +] + [[package]] name = "windows-core" version = "0.52.0" @@ -3675,6 +4256,15 @@ dependencies = [ "tap", ] +[[package]] +name = "xz2" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "388c44dc09d76f1536602ead6d325eb532f5c122f17782bd57fb47baeeb767e2" +dependencies = [ + "lzma-sys", +] + [[package]] name = "yansi" version = "1.0.1" @@ -3774,3 +4364,50 @@ dependencies = [ "quote", "syn 2.0.87", ] + +[[package]] +name = "zstd" +version = "0.12.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1a27595e173641171fc74a1232b7b1c7a7cb6e18222c11e9dfb9888fa424c53c" +dependencies = [ + "zstd-safe 6.0.6", +] + +[[package]] +name = "zstd" +version = "0.13.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fcf2b778a664581e31e389454a7072dab1647606d44f7feea22cd5abb9c9f3f9" +dependencies = [ + "zstd-safe 7.2.1", +] + +[[package]] +name = "zstd-safe" +version = "6.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ee98ffd0b48ee95e6c5168188e44a54550b1564d9d530ee21d5f0eaed1069581" +dependencies = [ + "libc", + "zstd-sys", +] + +[[package]] +name = "zstd-safe" +version = "7.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "54a3ab4db68cea366acc5c897c7b4d4d1b8994a9cd6e6f841f8964566a419059" +dependencies = [ + "zstd-sys", +] + +[[package]] +name = 
"zstd-sys" +version = "2.0.13+zstd.1.5.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "38ff0f21cfee8f97d94cef41359e0c89aa6113028ab0291aa8ca0038995a95aa" +dependencies = [ + "cc", + "pkg-config", +] diff --git a/optd-cost-model/Cargo.toml b/optd-cost-model/Cargo.toml index 4161749..d667fe4 100644 --- a/optd-cost-model/Cargo.toml +++ b/optd-cost-model/Cargo.toml @@ -11,6 +11,7 @@ serde_json = "1.0" serde_with = { version = "3.7.0", features = ["json"] } arrow-schema = "53.2.0" datafusion-expr = "32.0.0" +datafusion = "32.0.0" ordered-float = "4.0" chrono = "0.4" itertools = "0.13" diff --git a/optd-cost-model/src/cost/filter.rs b/optd-cost-model/src/cost/filter.rs index 39e16c8..b71dae4 100644 --- a/optd-cost-model/src/cost/filter.rs +++ b/optd-cost-model/src/cost/filter.rs @@ -1,6 +1,8 @@ #![allow(unused_variables)] use std::ops::Bound; +use datafusion::arrow::array::StringArray; +use datafusion::arrow::compute::like; use optd_persistent::{cost_model::interface::Cost, CostModelStorageLayer}; use crate::{ @@ -20,6 +22,7 @@ use crate::{ values::Value, }, cost_model::CostModelImpl, + stats::{AttributeCombValue, AttributeCombValueStats}, CostModelResult, EstimatedStatistic, }; @@ -39,6 +42,8 @@ const FULL_WILDCARD_SEL_FACTOR: f64 = 5.0; const FIXED_CHAR_SEL_FACTOR: f64 = 0.2; impl CostModelImpl { + // TODO: is it a good design to pass table_id here? I think it needs to be refactored. + // Consider to remove table_id. 
pub fn get_filter_row_cnt( &self, child_row_cnt: EstimatedStatistic, @@ -99,7 +104,7 @@ impl CostModelImpl { PredicateType::Cast => unimplemented!("check bool type or else panic"), PredicateType::Like => { let like_expr = LikePred::from_pred_node(expr_tree).unwrap(); - self.get_like_selectivity(&like_expr) + self.get_like_selectivity(&like_expr, table_id) } PredicateType::DataType(_) => { panic!("the selectivity of a data type is not defined") @@ -295,6 +300,48 @@ impl CostModelImpl { Ok(ret_sel) } + /// Compute the frequency of values in a attribute less than or equal to the given value. + fn get_attribute_leq_value_freq( + per_attribute_stats: &AttributeCombValueStats, + value: &Value, + ) -> f64 { + // because distr does not include the values in MCVs, we need to compute the CDFs there as + // well because nulls return false in any comparison, they are never included when + // computing range selectivity + let distr_leq_freq = per_attribute_stats.distr.as_ref().unwrap().cdf(value); + let value = value.clone(); + let pred = Box::new(move |val: &AttributeCombValue| *val[0].as_ref().unwrap() <= value); + let mcvs_leq_freq = per_attribute_stats.mcvs.freq_over_pred(pred); + let ret_freq = distr_leq_freq + mcvs_leq_freq; + assert!( + (0.0..=1.0).contains(&ret_freq), + "ret_freq ({}) should be in [0, 1]", + ret_freq + ); + ret_freq + } + + /// Compute the frequency of values in a attribute less than the given value. 
+ fn get_attribute_lt_value_freq( + &self, + attribute_stats: &AttributeCombValueStats, + table_id: TableId, + attr_base_index: usize, + value: &Value, + ) -> CostModelResult { + // depending on whether value is in mcvs or not, we use different logic to turn total_lt_cdf + // into total_leq_cdf this logic just so happens to be the exact same logic as + // get_attribute_equality_selectivity implements + let ret_freq = Self::get_attribute_leq_value_freq(attribute_stats, value) + - self.get_attribute_equality_selectivity(table_id, attr_base_index, value, true)?; + assert!( + (0.0..=1.0).contains(&ret_freq), + "ret_freq ({}) should be in [0, 1]", + ret_freq + ); + Ok(ret_freq) + } + /// Get the selectivity of an expression of the form "attribute =/> value" (or "value /// =/> attribute"). Computes selectivity based off of statistics. /// Range predicates are handled entirely differently from equality predicates so this is its @@ -309,32 +356,33 @@ impl CostModelImpl { ) -> CostModelResult { // TODO: Consider attribute is a derived attribute let attribute_stats = self.get_attribute_comb_stats(table_id, &[attr_base_index])?; - todo!() - // let left_quantile = match start { - // Bound::Unbounded => 0.0, - // Bound::Included(value) => { - // self.get_attribute_lt_value_freq(attribute_stats, table, attr_idx, value) - // } - // Bound::Excluded(value) => { - // Self::get_attribute_leq_value_freq(attribute_stats, value) - // } - // }; - // let right_quantile = match end { - // Bound::Unbounded => 1.0, - // Bound::Included(value) => { - // Self::get_attribute_leq_value_freq(attribute_stats, value) - // } - // Bound::Excluded(value) => { - // self.get_attribute_lt_value_freq(attribute_stats, table, attr_idx, value) - // } - // }; - // assert!( - // left_quantile <= right_quantile, - // "left_quantile ({}) should be <= right_quantile ({})", - // left_quantile, - // right_quantile - // ); - // right_quantile - left_quantile + let left_quantile = match start { + Bound::Unbounded => 
0.0, + Bound::Included(value) => self.get_attribute_lt_value_freq( + &attribute_stats, + table_id, + attr_base_index, + value, + )?, + Bound::Excluded(value) => Self::get_attribute_leq_value_freq(&attribute_stats, value), + }; + let right_quantile = match end { + Bound::Unbounded => 1.0, + Bound::Included(value) => Self::get_attribute_leq_value_freq(&attribute_stats, value), + Bound::Excluded(value) => self.get_attribute_lt_value_freq( + &attribute_stats, + table_id, + attr_base_index, + value, + )?, + }; + assert!( + left_quantile <= right_quantile, + "left_quantile ({}) should be <= right_quantile ({})", + left_quantile, + right_quantile + ); + Ok(right_quantile - left_quantile) } /// Compute the selectivity of a (NOT) LIKE expression. @@ -349,7 +397,11 @@ impl CostModelImpl { /// is composed of MCV frequency and non-MCV selectivity. MCV frequency is computed by /// adding up frequencies of MCVs that match the pattern. Non-MCV selectivity is computed /// in the same way that Postgres computes selectivity for the wildcard part of the pattern. - fn get_like_selectivity(&self, like_expr: &LikePred) -> CostModelResult { + fn get_like_selectivity( + &self, + like_expr: &LikePred, + table_id: TableId, + ) -> CostModelResult { let child = like_expr.child(); // Check child is a attribute ref. @@ -383,37 +435,34 @@ impl CostModelImpl { } }) .min(1.0); - todo!() - - // // Compute the selectivity in MCVs. 
- // let attribute_stats = self.get_attribute_comb_stats(table, &[*attr_idx]); - // let (mcv_freq, null_frac) = if let Some(attribute_stats) = attribute_stats { - // let pred = Box::new(move |val: &AttributeCombValue| { - // let string = StringArray::from(vec![val[0].as_ref().unwrap().as_str().as_ref()]); - // let pattern = StringArray::from(vec![pattern.as_ref()]); - // like(&string, &pattern).unwrap().value(0) - // }); - // ( - // attribute_stats.mcvs.freq_over_pred(pred), - // attribute_stats.null_frac, - // ) - // } else { - // (0.0, 0.0) - // }; - - // let result = non_mcv_sel + mcv_freq; - - // if like_expr.negated() { - // 1.0 - result - null_frac - // } else { - // result - // } - // // Postgres clamps the result after histogram and before MCV. See Postgres - // // `patternsel_common`. - // .clamp(0.0001, 0.9999) + + // Compute the selectivity in MCVs. + let attribute_stats = self.get_attribute_comb_stats(table_id, &[attr_ref_idx])?; + let (mcv_freq, null_frac) = { + let pred = Box::new(move |val: &AttributeCombValue| { + let string = StringArray::from(vec![val[0].as_ref().unwrap().as_str().as_ref()]); + let pattern = StringArray::from(vec![pattern.as_ref()]); + like(&string, &pattern).unwrap().value(0) + }); + ( + attribute_stats.mcvs.freq_over_pred(pred), + attribute_stats.null_frac, + ) + }; + + let result = non_mcv_sel + mcv_freq; + + Ok(if like_expr.negated() { + 1.0 - result - null_frac + } else { + result + } + // Postgres clamps the result after histogram and before MCV. See Postgres + // `patternsel_common`. + .clamp(0.0001, 0.9999)) } - /// Only support colA in (val1, val2, val3) where colA is a attribute ref and + /// Only support attrA in (val1, val2, val3) where attrA is a attribute ref and /// val1, val2, val3 are constants. 
pub fn get_in_list_selectivity( &self, From 5070a78d04c63cb39ddd20b2641041c0ae8c2c0c Mon Sep 17 00:00:00 2001 From: Lan Lou Date: Thu, 14 Nov 2024 19:23:36 -0500 Subject: [PATCH 11/51] Add comment for the guideline of re-designing PredicateNode --- optd-cost-model/src/common/nodes.rs | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/optd-cost-model/src/common/nodes.rs b/optd-cost-model/src/common/nodes.rs index b7a3728..aba1faf 100644 --- a/optd-cost-model/src/common/nodes.rs +++ b/optd-cost-model/src/common/nodes.rs @@ -82,6 +82,13 @@ pub struct PredicateNode { /// TODO: If it is PredicateType::AttributeRef, then /// the data is attribute index. But we need more information /// to represent this attribute in case it is a derived attribute. + /// 1. We can use Vec, but the disadvantage is that the optimizer + /// may need to do some memory copy. (However, if we want to provide a + /// general API, the memory copy is unavoidable. It applies for both 1 + /// and 2 designs). And Vec lacks readability. + /// 2. Also we can use enum, but if Rust uses something like `union` to + /// implement enum, then if some members are large, it will waste memory space, + /// also causing unnecessary memory copy. But enum provides better readability. 
pub data: Option, } From 740ab113e9e41ffee0fb19afa00cdd39db49b18f Mon Sep 17 00:00:00 2001 From: Yuanxin Cao Date: Fri, 15 Nov 2024 09:26:52 -0500 Subject: [PATCH 12/51] introduce IdPred and make AttributeRefPred store table id and attr index --- optd-cost-model/src/common/nodes.rs | 11 +---- .../src/common/predicates/attr_ref_pred.rs | 35 +++++++++------- .../src/common/predicates/cast_pred.rs | 5 +++ .../src/common/predicates/id_pred.rs | 40 +++++++++++++++++++ optd-cost-model/src/common/predicates/mod.rs | 1 + optd-cost-model/src/cost/filter.rs | 12 ++++-- 6 files changed, 75 insertions(+), 29 deletions(-) create mode 100644 optd-cost-model/src/common/predicates/id_pred.rs diff --git a/optd-cost-model/src/common/nodes.rs b/optd-cost-model/src/common/nodes.rs index aba1faf..87506d0 100644 --- a/optd-cost-model/src/common/nodes.rs +++ b/optd-cost-model/src/common/nodes.rs @@ -51,6 +51,7 @@ pub enum PredicateType { Constant(ConstantType), AttributeRef, ExternAttributeRef, + Id, UnOp(UnOpType), BinOp(BinOpType), LogOp(LogOpType), @@ -79,16 +80,6 @@ pub struct PredicateNode { /// Child predicate nodes, always materialized pub children: Vec, /// Data associated with the predicate, if any - /// TODO: If it is PredicateType::AttributeRef, then - /// the data is attribute index. But we need more information - /// to represent this attribute in case it is a derived attribute. - /// 1. We can use Vec, but the disadvantage is that the optimizer - /// may need to do some memory copy. (However, if we want to provide a - /// general API, the memory copy is unavoidable. It applies for both 1 - /// and 2 designs). And Vec lacks readability. - /// 2. Also we can use enum, but if Rust uses something like `union` to - /// implement enum, then if some members are large, it will waste memory space, - /// also causing unnecessary memory copy. But enum provides better readability. 
pub data: Option, } diff --git a/optd-cost-model/src/common/predicates/attr_ref_pred.rs b/optd-cost-model/src/common/predicates/attr_ref_pred.rs index b3b7814..bded38b 100644 --- a/optd-cost-model/src/common/predicates/attr_ref_pred.rs +++ b/optd-cost-model/src/common/predicates/attr_ref_pred.rs @@ -1,33 +1,38 @@ -use crate::common::{ - nodes::{ArcPredicateNode, PredicateNode, PredicateType, ReprPredicateNode}, - values::Value, -}; +use crate::common::nodes::{ArcPredicateNode, PredicateNode, PredicateType, ReprPredicateNode}; +use super::id_pred::IdPred; + +/// [`AttributeRefPred`] represents a reference to a column in a relation. +/// +/// An [`AttributeRefPred`] has two children: +/// 1. The table id, represented by an [`IdPred`]. +/// 2. The index of the column, represented by an [`IdPred`]. #[derive(Clone, Debug)] pub struct AttributeRefPred(pub ArcPredicateNode); impl AttributeRefPred { - /// Creates a new `ColumnRef` expression. - pub fn new(attribute_idx: usize) -> AttributeRefPred { - // this conversion is always safe since usize is at most u64 - let u64_attribute_idx = attribute_idx as u64; + pub fn new(table_id: usize, attribute_idx: usize) -> AttributeRefPred { AttributeRefPred( PredicateNode { typ: PredicateType::AttributeRef, - children: vec![], - data: Some(Value::UInt64(u64_attribute_idx)), + children: vec![ + IdPred::new(table_id).into_pred_node(), + IdPred::new(attribute_idx).into_pred_node(), + ], + data: None, } .into(), ) } - fn get_data_usize(&self) -> usize { - self.0.data.as_ref().unwrap().as_u64() as usize + /// Gets the table id. + pub fn table_id(&self) -> usize { + self.0.child(0).data.as_ref().unwrap().as_u64() as usize } - /// Gets the column index. - pub fn index(&self) -> usize { - self.get_data_usize() + /// Gets the attribute index. 
+ pub fn attr_index(&self) -> usize { + self.0.child(1).data.as_ref().unwrap().as_u64() as usize } } diff --git a/optd-cost-model/src/common/predicates/cast_pred.rs b/optd-cost-model/src/common/predicates/cast_pred.rs index eaafca9..2e1ef54 100644 --- a/optd-cost-model/src/common/predicates/cast_pred.rs +++ b/optd-cost-model/src/common/predicates/cast_pred.rs @@ -4,6 +4,11 @@ use crate::common::nodes::{ArcPredicateNode, PredicateNode, PredicateType, ReprP use super::data_type_pred::DataTypePred; +/// [`CastPred`] casts a column from one data type to another. +/// +/// A [`CastPred`] has two children: +/// 1. The original data to cast +/// 2. The target data type to cast to #[derive(Clone, Debug)] pub struct CastPred(pub ArcPredicateNode); diff --git a/optd-cost-model/src/common/predicates/id_pred.rs b/optd-cost-model/src/common/predicates/id_pred.rs new file mode 100644 index 0000000..8a19d20 --- /dev/null +++ b/optd-cost-model/src/common/predicates/id_pred.rs @@ -0,0 +1,40 @@ +use crate::common::{ + nodes::{ArcPredicateNode, PredicateNode, PredicateType, ReprPredicateNode}, + values::Value, +}; + +/// [`IdPred`] holds an id or an index, e.g. table id. +/// +/// The data is of uint64 type, because an id or an index can always be +/// represented by uint64. +#[derive(Clone, Debug)] +pub struct IdPred(pub ArcPredicateNode); + +impl IdPred { + pub fn new(id: usize) -> IdPred { + // This conversion is always safe since usize is at most u64. 
+ let u64_id = id as u64; + IdPred( + PredicateNode { + typ: PredicateType::Id, + children: vec![], + data: Some(Value::UInt64(u64_id)), + } + .into(), + ) + } +} + +impl ReprPredicateNode for IdPred { + fn into_pred_node(self) -> ArcPredicateNode { + self.0 + } + + fn from_pred_node(pred_node: ArcPredicateNode) -> Option { + if let PredicateType::Id = pred_node.typ { + Some(Self(pred_node)) + } else { + None + } + } +} diff --git a/optd-cost-model/src/common/predicates/mod.rs b/optd-cost-model/src/common/predicates/mod.rs index 877b78a..65d6ad0 100644 --- a/optd-cost-model/src/common/predicates/mod.rs +++ b/optd-cost-model/src/common/predicates/mod.rs @@ -4,6 +4,7 @@ pub mod cast_pred; pub mod constant_pred; pub mod data_type_pred; pub mod func_pred; +pub mod id_pred; pub mod in_list_pred; pub mod like_pred; pub mod list_pred; diff --git a/optd-cost-model/src/cost/filter.rs b/optd-cost-model/src/cost/filter.rs index b71dae4..36469b6 100644 --- a/optd-cost-model/src/cost/filter.rs +++ b/optd-cost-model/src/cost/filter.rs @@ -189,7 +189,7 @@ impl CostModelImpl { let attr_ref_expr = attr_ref_exprs .first() .expect("we just checked that attr_ref_exprs.len() == 1"); - let attr_ref_idx = attr_ref_expr.index(); + let attr_ref_idx = attr_ref_expr.attr_index(); // TODO: Consider attribute is a derived attribute if values.len() == 1 { @@ -415,7 +415,9 @@ impl CostModelImpl { return Ok(UNIMPLEMENTED_SEL); } - let attr_ref_idx = AttributeRefPred::from_pred_node(child).unwrap().index(); + let attr_ref_idx = AttributeRefPred::from_pred_node(child) + .unwrap() + .attr_index(); // TODO: Consider attribute is a derived attribute let pattern = ConstantPred::from_pred_node(pattern) @@ -486,7 +488,9 @@ impl CostModelImpl { } // Convert child and const expressions to concrete types. 
- let attr_ref_idx = AttributeRefPred::from_pred_node(child).unwrap().index(); + let attr_ref_idx = AttributeRefPred::from_pred_node(child) + .unwrap() + .attr_index(); let list_exprs = list_exprs .into_iter() .map(|expr| { @@ -581,7 +585,7 @@ impl CostModelImpl { PredicateType::AttributeRef => { let attr_ref_expr = AttributeRefPred::from_pred_node(cast_expr_child) .expect("we already checked that the type is AttributeRef"); - let attr_ref_idx = attr_ref_expr.index(); + let attr_ref_idx = attr_ref_expr.attr_index(); cast_node = attr_ref_expr.into_pred_node(); // The "invert" cast is to invert the cast so that we're casting the // non_cast_node to the attribute's original type. From 85cd0d174c6cbcb7e4ee8adf2d21c0407ba7bd43 Mon Sep 17 00:00:00 2001 From: Yuanxin Cao Date: Fri, 15 Nov 2024 09:41:54 -0500 Subject: [PATCH 13/51] add get method for id pred and add comments --- optd-cost-model/src/common/predicates/attr_ref_pred.rs | 4 ++++ optd-cost-model/src/common/predicates/id_pred.rs | 5 +++++ 2 files changed, 9 insertions(+) diff --git a/optd-cost-model/src/common/predicates/attr_ref_pred.rs b/optd-cost-model/src/common/predicates/attr_ref_pred.rs index bded38b..02b969e 100644 --- a/optd-cost-model/src/common/predicates/attr_ref_pred.rs +++ b/optd-cost-model/src/common/predicates/attr_ref_pred.rs @@ -7,6 +7,10 @@ use super::id_pred::IdPred; /// An [`AttributeRefPred`] has two children: /// 1. The table id, represented by an [`IdPred`]. /// 2. The index of the column, represented by an [`IdPred`]. +/// +/// Currently, [`AttributeRefPred`] only holds base table attributes, i.e. attributes +/// that already exist in the table. More complex structures may be introduced in the +/// future to represent derived attributes (e.g. t.v1 + t.v2). 
#[derive(Clone, Debug)] pub struct AttributeRefPred(pub ArcPredicateNode); diff --git a/optd-cost-model/src/common/predicates/id_pred.rs b/optd-cost-model/src/common/predicates/id_pred.rs index 8a19d20..962e526 100644 --- a/optd-cost-model/src/common/predicates/id_pred.rs +++ b/optd-cost-model/src/common/predicates/id_pred.rs @@ -23,6 +23,11 @@ impl IdPred { .into(), ) } + + /// Gets the id stored in the predicate. + pub fn id(&self) -> usize { + self.0.data.clone().unwrap().as_u64() as usize + } } impl ReprPredicateNode for IdPred { From 7775b889f8fcfd68a3245a722069f4e16d8d533f Mon Sep 17 00:00:00 2001 From: Yuanxin Cao Date: Fri, 15 Nov 2024 09:48:16 -0500 Subject: [PATCH 14/51] add check for derived column in AttributeRefPred --- optd-cost-model/src/common/predicates/attr_ref_pred.rs | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/optd-cost-model/src/common/predicates/attr_ref_pred.rs b/optd-cost-model/src/common/predicates/attr_ref_pred.rs index 02b969e..f36c07f 100644 --- a/optd-cost-model/src/common/predicates/attr_ref_pred.rs +++ b/optd-cost-model/src/common/predicates/attr_ref_pred.rs @@ -38,6 +38,12 @@ impl AttributeRefPred { pub fn attr_index(&self) -> usize { self.0.child(1).data.as_ref().unwrap().as_u64() as usize } + + /// Checks whether the attribute is a derived attribute. Currently, this will always return + /// false, since derived attribute is not yet supported. 
+ pub fn is_derived(&self) -> bool { + false + } } impl ReprPredicateNode for AttributeRefPred { From b60c6324ac5f753bee23c9d6e154e229117398a8 Mon Sep 17 00:00:00 2001 From: Yuanxin Cao Date: Fri, 15 Nov 2024 10:03:17 -0500 Subject: [PATCH 15/51] make get_attributes_comb_statistics return Option --- Cargo.lock | 643 ++++++++++++++++++++++++++++- optd-cost-model/src/cost/filter.rs | 15 +- optd-cost-model/src/cost_model.rs | 2 +- optd-cost-model/src/storage.rs | 4 - 4 files changed, 653 insertions(+), 11 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 9464489..bf0b367 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -57,6 +57,21 @@ version = "0.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "250f629c0161ad8107cf89319e990051fae62832fd343083bea452d93e2205fd" +[[package]] +name = "alloc-no-stdlib" +version = "2.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cc7bb162ec39d46ab1ca8c77bf72e890535becd1751bb45f64c597edb4c8c6b3" + +[[package]] +name = "alloc-stdlib" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94fb8275041c72129eb51b7d0322c29b8387a0386127718b096429201a5d6ece" +dependencies = [ + "alloc-no-stdlib", +] + [[package]] name = "allocator-api2" version = "0.2.20" @@ -127,6 +142,12 @@ dependencies = [ "windows-sys 0.59.0", ] +[[package]] +name = "arrayref" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "76a2e8124351fda1ef8aaaa3bbd7ebbcb486bbcd4225aca0aa0d84bb2db8fecb" + [[package]] name = "arrayvec" version = "0.7.6" @@ -353,6 +374,24 @@ dependencies = [ "regex-syntax 0.7.5", ] +[[package]] +name = "async-compression" +version = "0.4.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0cb8f1d480b0ea3783ab015936d2a55c87e219676f0c0b7dec61494043f21857" +dependencies = [ + "bzip2", + "flate2", + "futures-core", + "futures-io", + "memchr", + "pin-project-lite", + 
"tokio", + "xz2", + "zstd 0.13.2", + "zstd-safe 7.2.1", +] + [[package]] name = "async-stream" version = "0.3.6" @@ -416,6 +455,12 @@ dependencies = [ "windows-targets 0.52.6", ] +[[package]] +name = "base64" +version = "0.21.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d297deb1925b89f2ccc13d7635fa0714f12c87adce1c75356b39ca9b7178567" + [[package]] name = "base64" version = "0.22.1" @@ -469,6 +514,28 @@ dependencies = [ "wyz", ] +[[package]] +name = "blake2" +version = "0.10.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "46502ad458c9a52b69d4d4d32775c788b7a1b85e8bc9d482d92250fc0e3f8efe" +dependencies = [ + "digest", +] + +[[package]] +name = "blake3" +version = "1.5.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d82033247fd8e890df8f740e407ad4d038debb9eb1f40533fffb32e7d17dc6f7" +dependencies = [ + "arrayref", + "arrayvec", + "cc", + "cfg-if", + "constant_time_eq", +] + [[package]] name = "block-buffer" version = "0.10.4" @@ -501,6 +568,27 @@ dependencies = [ "syn 2.0.87", ] +[[package]] +name = "brotli" +version = "3.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d640d25bc63c50fb1f0b545ffd80207d2e10a4c965530809b40ba3386825c391" +dependencies = [ + "alloc-no-stdlib", + "alloc-stdlib", + "brotli-decompressor", +] + +[[package]] +name = "brotli-decompressor" +version = "2.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4e2e4afe60d7dd600fdd3de8d0f08c2b7ec039712e3b6137ff98b7004e82de4f" +dependencies = [ + "alloc-no-stdlib", + "alloc-stdlib", +] + [[package]] name = "bumpalo" version = "3.16.0" @@ -541,12 +629,35 @@ version = "1.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9ac0150caa2ae65ca5bd83f25c7de183dea78d4d366469f148435e2acfbad0da" +[[package]] +name = "bzip2" +version = "0.4.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum 
= "bdb116a6ef3f6c3698828873ad02c3014b3c85cadb88496095628e3ef1e347f8" +dependencies = [ + "bzip2-sys", + "libc", +] + +[[package]] +name = "bzip2-sys" +version = "0.1.11+1.0.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "736a955f3fa7875102d57c82b8cac37ec45224a07fd32d58f9f7a186b6cd4cdc" +dependencies = [ + "cc", + "libc", + "pkg-config", +] + [[package]] name = "cc" version = "1.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fd9de9f2205d5ef3fd67e685b0df337994ddd4495e2a28d185500d0e1edfea47" dependencies = [ + "jobserver", + "libc", "shlex", ] @@ -691,6 +802,12 @@ dependencies = [ "tiny-keccak", ] +[[package]] +name = "constant_time_eq" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7c74b8349d32d297c9134b8c88677813a227df8f779daa29bfc29c183fe3dca6" + [[package]] name = "core-foundation-sys" version = "0.8.7" @@ -721,6 +838,15 @@ version = "2.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "19d374276b40fb8bbdee95aef7c7fa6b5316ec764510eb64b8dd0e2ed0d7e7f5" +[[package]] +name = "crc32fast" +version = "1.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a97769d94ddab943e4510d138150169a2758b5ef3eb191a9ee688de3e23ef7b3" +dependencies = [ + "cfg-if", +] + [[package]] name = "crossbeam" version = "0.8.4" @@ -849,6 +975,67 @@ dependencies = [ "syn 2.0.87", ] +[[package]] +name = "dashmap" +version = "5.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "978747c1d849a7d2ee5e8adc0159961c48fb7e5db2f06af6723b80123bb53856" +dependencies = [ + "cfg-if", + "hashbrown 0.14.5", + "lock_api", + "once_cell", + "parking_lot_core", +] + +[[package]] +name = "datafusion" +version = "32.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7014432223f4d721cb9786cd88bb89e7464e0ba984d4a7f49db7787f5f268674" +dependencies = [ + "ahash 0.8.11", + 
"arrow", + "arrow-array", + "arrow-schema 47.0.0", + "async-compression", + "async-trait", + "bytes", + "bzip2", + "chrono", + "dashmap", + "datafusion-common", + "datafusion-execution", + "datafusion-expr", + "datafusion-optimizer", + "datafusion-physical-expr", + "datafusion-physical-plan", + "datafusion-sql", + "flate2", + "futures", + "glob", + "half", + "hashbrown 0.14.5", + "indexmap 2.6.0", + "itertools 0.11.0", + "log", + "num_cpus", + "object_store", + "parking_lot", + "parquet", + "percent-encoding", + "pin-project-lite", + "rand", + "sqlparser", + "tempfile", + "tokio", + "tokio-util", + "url", + "uuid", + "xz2", + "zstd 0.12.4", +] + [[package]] name = "datafusion-common" version = "32.0.0" @@ -863,9 +1050,32 @@ dependencies = [ "chrono", "half", "num_cpus", + "object_store", + "parquet", "sqlparser", ] +[[package]] +name = "datafusion-execution" +version = "32.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "780b73b2407050e53f51a9781868593f694102c59e622de9a8aafc0343c4f237" +dependencies = [ + "arrow", + "chrono", + "dashmap", + "datafusion-common", + "datafusion-expr", + "futures", + "hashbrown 0.14.5", + "log", + "object_store", + "parking_lot", + "rand", + "tempfile", + "url", +] + [[package]] name = "datafusion-expr" version = "32.0.0" @@ -881,6 +1091,103 @@ dependencies = [ "strum_macros 0.25.3", ] +[[package]] +name = "datafusion-optimizer" +version = "32.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f2904a432f795484fd45e29ded4537152adb60f636c05691db34fcd94c92c96" +dependencies = [ + "arrow", + "async-trait", + "chrono", + "datafusion-common", + "datafusion-expr", + "datafusion-physical-expr", + "hashbrown 0.14.5", + "itertools 0.11.0", + "log", + "regex-syntax 0.7.5", +] + +[[package]] +name = "datafusion-physical-expr" +version = "32.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"57b4968e9a998dc0476c4db7a82f280e2026b25f464e4aa0c3bb9807ee63ddfd" +dependencies = [ + "ahash 0.8.11", + "arrow", + "arrow-array", + "arrow-buffer", + "arrow-schema 47.0.0", + "base64 0.21.7", + "blake2", + "blake3", + "chrono", + "datafusion-common", + "datafusion-expr", + "half", + "hashbrown 0.14.5", + "hex", + "indexmap 2.6.0", + "itertools 0.11.0", + "libc", + "log", + "md-5", + "paste", + "petgraph", + "rand", + "regex", + "sha2", + "unicode-segmentation", + "uuid", +] + +[[package]] +name = "datafusion-physical-plan" +version = "32.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "efd0d1fe54e37a47a2d58a1232c22786f2c28ad35805fdcd08f0253a8b0aaa90" +dependencies = [ + "ahash 0.8.11", + "arrow", + "arrow-array", + "arrow-buffer", + "arrow-schema 47.0.0", + "async-trait", + "chrono", + "datafusion-common", + "datafusion-execution", + "datafusion-expr", + "datafusion-physical-expr", + "futures", + "half", + "hashbrown 0.14.5", + "indexmap 2.6.0", + "itertools 0.11.0", + "log", + "once_cell", + "parking_lot", + "pin-project-lite", + "rand", + "tokio", + "uuid", +] + +[[package]] +name = "datafusion-sql" +version = "32.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b568d44c87ead99604d704f942e257c8a236ee1bbf890ee3e034ad659dcb2c21" +dependencies = [ + "arrow", + "arrow-schema 47.0.0", + "datafusion-common", + "datafusion-expr", + "log", + "sqlparser", +] + [[package]] name = "der" version = "0.7.9" @@ -925,6 +1232,12 @@ dependencies = [ "syn 2.0.87", ] +[[package]] +name = "doc-comment" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fea41bba32d969b513997752735605054bc0dfa92b4c56bf1189f2e174be7a10" + [[package]] name = "dotenvy" version = "0.15.7" @@ -984,6 +1297,12 @@ version = "2.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "486f806e73c5707928240ddc295403b1b93c96a02038563881c4a2fd84b81ac4" +[[package]] +name = 
"fixedbitset" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0ce7134b9999ecaf8bcd65542e436736ef32ddca1b3e06094cb6ec5755203b80" + [[package]] name = "flatbuffers" version = "23.5.26" @@ -994,6 +1313,16 @@ dependencies = [ "rustc_version", ] +[[package]] +name = "flate2" +version = "1.0.35" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c936bfdafb507ebbf50b8074c54fa31c5be9a1e7e5f467dd659697041407d07c" +dependencies = [ + "crc32fast", + "miniz_oxide", +] + [[package]] name = "flume" version = "0.11.1" @@ -1034,6 +1363,7 @@ checksum = "65bc07b1a8bc7c85c5f2e110c476c7389b4554ba72af57d8445ea63a576b0876" dependencies = [ "futures-channel", "futures-core", + "futures-executor", "futures-io", "futures-sink", "futures-task", @@ -1084,6 +1414,17 @@ version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9e5c1b78ca4aae1ac06c48a526a655760685149f0d465d21f37abfe57ce075c6" +[[package]] +name = "futures-macro" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.87", +] + [[package]] name = "futures-sink" version = "0.3.31" @@ -1105,6 +1446,7 @@ dependencies = [ "futures-channel", "futures-core", "futures-io", + "futures-macro", "futures-sink", "futures-task", "memchr", @@ -1242,6 +1584,12 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "humantime" +version = "2.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4" + [[package]] name = "iana-time-zone" version = "0.1.61" @@ -1443,12 +1791,27 @@ dependencies = [ "syn 2.0.87", ] +[[package]] +name = "integer-encoding" +version = "3.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"8bb03732005da905c88227371639bf1ad885cc712789c011c31c5fb3ab3ccf02" + [[package]] name = "is_terminal_polyfill" version = "1.70.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7943c866cc5cd64cbc25b2e01621d07fa8eb2a1a23160ee81ce38704e97b8ecf" +[[package]] +name = "itertools" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b1c173a5686ce8bfa551b3563d0c2170bf24ca44da99c7ca4bfdab5418c3fe57" +dependencies = [ + "either", +] + [[package]] name = "itertools" version = "0.12.1" @@ -1473,6 +1836,15 @@ version = "1.0.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "49f1f14873335454500d59611f1cf4a4b0f786f9ac11f4312a78e4cf2566695b" +[[package]] +name = "jobserver" +version = "0.1.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "48d1dbcbbeb6a7fec7e059840aa538bd62aaccf972c7346c4d9d2059312853d0" +dependencies = [ + "libc", +] + [[package]] name = "js-sys" version = "0.3.72" @@ -1606,6 +1978,36 @@ version = "0.4.22" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24" +[[package]] +name = "lz4" +version = "1.28.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4d1febb2b4a79ddd1980eede06a8f7902197960aa0383ffcfdd62fe723036725" +dependencies = [ + "lz4-sys", +] + +[[package]] +name = "lz4-sys" +version = "1.11.1+lz4-1.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6bd8c0d6c6ed0cd30b3652886bb8711dc4bb01d637a68105a3d5158039b418e6" +dependencies = [ + "cc", + "libc", +] + +[[package]] +name = "lzma-sys" +version = "0.1.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5fda04ab3764e6cde78b9974eec4f779acaba7c4e84b36eca3cf77c581b85d27" +dependencies = [ + "cc", + "libc", + "pkg-config", +] + [[package]] name = "matchers" version = "0.1.0" @@ -1784,6 
+2186,27 @@ dependencies = [ "memchr", ] +[[package]] +name = "object_store" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f930c88a43b1c3f6e776dfe495b4afab89882dbc81530c632db2ed65451ebcb4" +dependencies = [ + "async-trait", + "bytes", + "chrono", + "futures", + "humantime", + "itertools 0.11.0", + "parking_lot", + "percent-encoding", + "snafu", + "tokio", + "tracing", + "url", + "walkdir", +] + [[package]] name = "once_cell" version = "1.20.2" @@ -1797,6 +2220,7 @@ dependencies = [ "arrow-schema 53.2.0", "chrono", "crossbeam", + "datafusion", "datafusion-expr", "itertools 0.13.0", "optd-persistent", @@ -1821,6 +2245,15 @@ dependencies = [ "trait-variant", ] +[[package]] +name = "ordered-float" +version = "2.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "68f19d67e5a2795c94e73e0bb1cc1a7edeb2e28efd39e2e1c9b7a40c1108b11c" +dependencies = [ + "num-traits", +] + [[package]] name = "ordered-float" version = "3.9.2" @@ -1893,6 +2326,40 @@ dependencies = [ "windows-targets 0.52.6", ] +[[package]] +name = "parquet" +version = "47.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0463cc3b256d5f50408c49a4be3a16674f4c8ceef60941709620a062b1f6bf4d" +dependencies = [ + "ahash 0.8.11", + "arrow-array", + "arrow-buffer", + "arrow-cast", + "arrow-data", + "arrow-ipc", + "arrow-schema 47.0.0", + "arrow-select", + "base64 0.21.7", + "brotli", + "bytes", + "chrono", + "flate2", + "futures", + "hashbrown 0.14.5", + "lz4", + "num", + "num-bigint", + "object_store", + "paste", + "seq-macro", + "snap", + "thrift", + "tokio", + "twox-hash", + "zstd 0.12.4", +] + [[package]] name = "parse-zoneinfo" version = "0.3.1" @@ -1923,6 +2390,16 @@ version = "2.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e" +[[package]] +name = "petgraph" +version = "0.6.5" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "b4c5cc86750666a3ed20bdaf5ca2a0344f9c67674cae0515bec2da16fbaa47db" +dependencies = [ + "fixedbitset", + "indexmap 2.6.0", +] + [[package]] name = "phf" version = "0.11.2" @@ -2361,6 +2838,15 @@ version = "1.0.18" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f3cb5ba0dc43242ce17de99c180e96db90b235b8a9fdc9543c96d2209116bd9f" +[[package]] +name = "same-file" +version = "1.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93fc1dc3aaa9bfed95e02e6eadabb4baf7e3078b0bd1b4d7b6b0b68378900502" +dependencies = [ + "winapi-util", +] + [[package]] name = "scopeguard" version = "1.2.0" @@ -2538,6 +3024,12 @@ version = "1.0.23" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "61697e0a1c7e512e84a621326239844a24d8207b4669b41bc18b32ea5cbf988b" +[[package]] +name = "seq-macro" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a3f0bf26fd526d2a95683cd0f87bf103b8539e2ca1ef48ce002d67aad59aa0b4" + [[package]] name = "serde" version = "1.0.215" @@ -2588,7 +3080,7 @@ version = "3.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8e28bdad6db2b8340e449f7108f020b3b092e8583a9e3fb82713e1d4e71fe817" dependencies = [ - "base64", + "base64 0.22.1", "chrono", "hex", "indexmap 1.9.3", @@ -2689,6 +3181,34 @@ dependencies = [ "serde", ] +[[package]] +name = "snafu" +version = "0.7.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e4de37ad025c587a29e8f3f5605c00f70b98715ef90b9061a815b9e59e9042d6" +dependencies = [ + "doc-comment", + "snafu-derive", +] + +[[package]] +name = "snafu-derive" +version = "0.7.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "990079665f075b699031e9c08fd3ab99be5029b96f3b78dc0709e8f77e4efebf" +dependencies = [ + "heck 0.4.1", + "proc-macro2", + "quote", + "syn 1.0.109", +] + 
+[[package]] +name = "snap" +version = "1.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1b6b67fb9a61334225b5b790716f609cd58395f895b3fe8b328786812a40bc3b" + [[package]] name = "socket2" version = "0.5.7" @@ -2855,7 +3375,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "64bb4714269afa44aef2755150a0fc19d756fb580a67db8885608cf02f47d06a" dependencies = [ "atoi", - "base64", + "base64 0.22.1", "bigdecimal", "bitflags 2.6.0", "byteorder", @@ -2902,7 +3422,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6fa91a732d854c5d7726349bb4bb879bb9478993ceb764247660aee25f67c2f8" dependencies = [ "atoi", - "base64", + "base64 0.22.1", "bigdecimal", "bitflags 2.6.0", "byteorder", @@ -3123,6 +3643,17 @@ dependencies = [ "once_cell", ] +[[package]] +name = "thrift" +version = "0.17.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7e54bc85fc7faa8bc175c4bab5b92ba8d9a3ce893d0e9f42cc455c8ab16a9e09" +dependencies = [ + "byteorder", + "integer-encoding", + "ordered-float 2.10.1", +] + [[package]] name = "time" version = "0.3.36" @@ -3198,6 +3729,7 @@ dependencies = [ "bytes", "libc", "mio", + "parking_lot", "pin-project-lite", "socket2", "tokio-macros", @@ -3226,6 +3758,19 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-util" +version = "0.7.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "61e7c3654c13bcd040d4a03abee2c75b1d14a37b423cf5a813ceae1cc903ec6a" +dependencies = [ + "bytes", + "futures-core", + "futures-sink", + "pin-project-lite", + "tokio", +] + [[package]] name = "toml_datetime" version = "0.6.8" @@ -3301,6 +3846,16 @@ dependencies = [ "syn 2.0.87", ] +[[package]] +name = "twox-hash" +version = "1.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "97fee6b57c6a41524a810daee9286c02d7752c4253064d0b05472833a438f675" +dependencies = [ + "cfg-if", + "static_assertions", +] + 
[[package]] name = "typenum" version = "1.17.0" @@ -3334,6 +3889,12 @@ version = "0.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e70f2a8b45122e719eb623c01822704c4e0907e7e426a05927e1a1cfff5b75d0" +[[package]] +name = "unicode-segmentation" +version = "1.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f6ccf251212114b54433ec949fd6a7841275f9ada20dddd2f29e9ceea4501493" + [[package]] name = "unicode-width" version = "0.2.0" @@ -3387,6 +3948,7 @@ version = "1.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f8c5f0a0af699448548ad1a2fbf920fb4bee257eae39953ba95cb84891a0446a" dependencies = [ + "getrandom", "serde", ] @@ -3402,6 +3964,16 @@ version = "0.9.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" +[[package]] +name = "walkdir" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "29790946404f91d9c5d06f9874efddea1dc06c5efe94541a7d6863108e3a5e4b" +dependencies = [ + "same-file", + "winapi-util", +] + [[package]] name = "wasi" version = "0.11.0+wasi-snapshot-preview1" @@ -3488,6 +4060,15 @@ dependencies = [ "wasite", ] +[[package]] +name = "winapi-util" +version = "0.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cf221c93e13a30d793f7645a0e7762c55d169dbb0a49671918a2319d289b10bb" +dependencies = [ + "windows-sys 0.59.0", +] + [[package]] name = "windows-core" version = "0.52.0" @@ -3675,6 +4256,15 @@ dependencies = [ "tap", ] +[[package]] +name = "xz2" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "388c44dc09d76f1536602ead6d325eb532f5c122f17782bd57fb47baeeb767e2" +dependencies = [ + "lzma-sys", +] + [[package]] name = "yansi" version = "1.0.1" @@ -3774,3 +4364,50 @@ dependencies = [ "quote", "syn 2.0.87", ] + +[[package]] +name = "zstd" 
+version = "0.12.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1a27595e173641171fc74a1232b7b1c7a7cb6e18222c11e9dfb9888fa424c53c" +dependencies = [ + "zstd-safe 6.0.6", +] + +[[package]] +name = "zstd" +version = "0.13.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fcf2b778a664581e31e389454a7072dab1647606d44f7feea22cd5abb9c9f3f9" +dependencies = [ + "zstd-safe 7.2.1", +] + +[[package]] +name = "zstd-safe" +version = "6.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ee98ffd0b48ee95e6c5168188e44a54550b1564d9d530ee21d5f0eaed1069581" +dependencies = [ + "libc", + "zstd-sys", +] + +[[package]] +name = "zstd-safe" +version = "7.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "54a3ab4db68cea366acc5c897c7b4d4d1b8994a9cd6e6f841f8964566a419059" +dependencies = [ + "zstd-sys", +] + +[[package]] +name = "zstd-sys" +version = "2.0.13+zstd.1.5.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "38ff0f21cfee8f97d94cef41359e0c89aa6113028ab0291aa8ca0038995a95aa" +dependencies = [ + "cc", + "pkg-config", +] diff --git a/optd-cost-model/src/cost/filter.rs b/optd-cost-model/src/cost/filter.rs index 36469b6..370507f 100644 --- a/optd-cost-model/src/cost/filter.rs +++ b/optd-cost-model/src/cost/filter.rs @@ -269,7 +269,10 @@ impl CostModelImpl { ) -> CostModelResult { // TODO: The attribute could be a derived attribute let ret_sel = { - let attribute_stats = self.get_attribute_comb_stats(table_id, &[attr_base_index])?; + // TODO: Handle the case where `attribute_stats` is None. + let attribute_stats = self + .get_attribute_comb_stats(table_id, &[attr_base_index])? 
+ .unwrap(); let eq_freq = if let Some(freq) = attribute_stats.mcvs.freq(&vec![Some(value.clone())]) { freq @@ -355,7 +358,10 @@ impl CostModelImpl { end: Bound<&Value>, ) -> CostModelResult { // TODO: Consider attribute is a derived attribute - let attribute_stats = self.get_attribute_comb_stats(table_id, &[attr_base_index])?; + // TODO: Handle the case where `attribute_stats` is None. + let attribute_stats = self + .get_attribute_comb_stats(table_id, &[attr_base_index])? + .unwrap(); let left_quantile = match start { Bound::Unbounded => 0.0, Bound::Included(value) => self.get_attribute_lt_value_freq( @@ -439,7 +445,10 @@ impl CostModelImpl { .min(1.0); // Compute the selectivity in MCVs. - let attribute_stats = self.get_attribute_comb_stats(table_id, &[attr_ref_idx])?; + // TODO: Handle the case where `attribute_stats` is None. + let attribute_stats = self + .get_attribute_comb_stats(table_id, &[attr_ref_idx])? + .unwrap(); let (mcv_freq, null_frac) = { let pred = Box::new(move |val: &AttributeCombValue| { let string = StringArray::from(vec![val[0].as_ref().unwrap().as_str().as_ref()]); diff --git a/optd-cost-model/src/cost_model.rs b/optd-cost-model/src/cost_model.rs index c5c6530..ebb45c2 100644 --- a/optd-cost-model/src/cost_model.rs +++ b/optd-cost-model/src/cost_model.rs @@ -101,7 +101,7 @@ impl CostModelImpl { &self, table_id: TableId, attr_comb: &[usize], - ) -> CostModelResult { + ) -> CostModelResult> { self.storage_manager .get_attributes_comb_statistics(table_id, attr_comb) } diff --git a/optd-cost-model/src/storage.rs b/optd-cost-model/src/storage.rs index 5538618..56b05e7 100644 --- a/optd-cost-model/src/storage.rs +++ b/optd-cost-model/src/storage.rs @@ -53,7 +53,6 @@ impl CostModelStorageManager { pub async fn get_attributes_comb_statistics( &self, table_id: TableId, - attr_base_indices: &[i32], ) -> CostModelResult> { let dist: Option = self .backend_manager @@ -126,6 +125,3 @@ impl CostModelStorageManager { mcvs, ndistinct, null_frac, dist, ))) } 
-} - -// TODO: add some tests, especially cover the error cases. From 3646ecafdfabfc689c948dab17d6ba2a60ba7e73 Mon Sep 17 00:00:00 2001 From: Yuanxin Cao Date: Fri, 15 Nov 2024 10:16:43 -0500 Subject: [PATCH 16/51] implement agg cost computation --- .../src/common/predicates/attr_ref_pred.rs | 2 + optd-cost-model/src/cost/agg.rs | 61 +++++++++++++++++++ optd-cost-model/src/lib.rs | 2 +- optd-cost-model/src/storage.rs | 2 + 4 files changed, 66 insertions(+), 1 deletion(-) diff --git a/optd-cost-model/src/common/predicates/attr_ref_pred.rs b/optd-cost-model/src/common/predicates/attr_ref_pred.rs index f36c07f..52db73d 100644 --- a/optd-cost-model/src/common/predicates/attr_ref_pred.rs +++ b/optd-cost-model/src/common/predicates/attr_ref_pred.rs @@ -11,6 +11,8 @@ use super::id_pred::IdPred; /// Currently, [`AttributeRefPred`] only holds base table attributes, i.e. attributes /// that already exist in the table. More complex structures may be introduced in the /// future to represent derived attributes (e.g. t.v1 + t.v2). +/// +/// TODO: Support derived column in `AttributeRefPred`. 
#[derive(Clone, Debug)] pub struct AttributeRefPred(pub ArcPredicateNode); diff --git a/optd-cost-model/src/cost/agg.rs b/optd-cost-model/src/cost/agg.rs index e69de29..34b259a 100644 --- a/optd-cost-model/src/cost/agg.rs +++ b/optd-cost-model/src/cost/agg.rs @@ -0,0 +1,61 @@ +use optd_persistent::CostModelStorageLayer; + +use crate::{ + common::{ + nodes::{ArcPredicateNode, PredicateType, ReprPredicateNode}, + predicates::{attr_ref_pred::AttributeRefPred, list_pred::ListPred}, + types::TableId, + }, + cost_model::CostModelImpl, + stats::DEFAULT_NUM_DISTINCT, + CostModelError, CostModelResult, EstimatedStatistic, +}; + +impl CostModelImpl { + pub fn get_agg_row_cnt( + &self, + group_by: ArcPredicateNode, + ) -> CostModelResult { + let group_by = ListPred::from_pred_node(group_by).unwrap(); + if group_by.is_empty() { + Ok(EstimatedStatistic(1)) + } else { + // Multiply the n-distinct of all the group by columns. + // TODO: improve with multi-dimensional n-distinct + let row_cnt = group_by.0.children.iter().try_fold(1, |acc, node| { + match node.typ { + PredicateType::AttributeRef => { + let attr_ref = + AttributeRefPred::from_pred_node(node.clone()).ok_or_else(|| { + CostModelError::InvalidPredicate( + "Expected AttributeRef predicate".to_string(), + ) + })?; + if attr_ref.is_derived() { + Ok(acc * DEFAULT_NUM_DISTINCT) + } else { + let table_id = attr_ref.table_id(); + let attr_idx = attr_ref.attr_index(); + let stats_option = + self.get_attribute_comb_stats(TableId(table_id), &vec![attr_idx])?; + + let ndistinct = match stats_option { + Some(stats) => stats.ndistinct, + None => { + // The column type is not supported or stats are missing. + DEFAULT_NUM_DISTINCT + } + }; + Ok(acc * ndistinct) + } + } + _ => { + // TODO: Consider the case where `GROUP BY 1`. 
+ panic!("GROUP BY must have attribute ref predicate") + } + } + })?; + Ok(EstimatedStatistic(row_cnt)) + } + } +} diff --git a/optd-cost-model/src/lib.rs b/optd-cost-model/src/lib.rs index e18098f..a2afcb8 100644 --- a/optd-cost-model/src/lib.rs +++ b/optd-cost-model/src/lib.rs @@ -46,9 +46,9 @@ pub enum SemanticError { #[derive(Debug)] pub enum CostModelError { - // TODO: Add more error types ORMError(BackendError), SemanticError(SemanticError), + InvalidPredicate(String), } impl From for CostModelError { diff --git a/optd-cost-model/src/storage.rs b/optd-cost-model/src/storage.rs index 56b05e7..1ee5d0e 100644 --- a/optd-cost-model/src/storage.rs +++ b/optd-cost-model/src/storage.rs @@ -53,6 +53,7 @@ impl CostModelStorageManager { pub async fn get_attributes_comb_statistics( &self, table_id: TableId, + attr_base_indices: &[i32], ) -> CostModelResult> { let dist: Option = self .backend_manager @@ -125,3 +126,4 @@ impl CostModelStorageManager { mcvs, ndistinct, null_frac, dist, ))) } +} From db555ff31f92543d67a5e8e3b447794408917f53 Mon Sep 17 00:00:00 2001 From: Yuanxin Cao Date: Fri, 15 Nov 2024 10:18:07 -0500 Subject: [PATCH 17/51] move filter-related constants to stats crate --- optd-cost-model/src/cost/filter.rs | 20 ++++---------------- 1 file changed, 4 insertions(+), 16 deletions(-) diff --git a/optd-cost-model/src/cost/filter.rs b/optd-cost-model/src/cost/filter.rs index 370507f..90272fa 100644 --- a/optd-cost-model/src/cost/filter.rs +++ b/optd-cost-model/src/cost/filter.rs @@ -22,25 +22,13 @@ use crate::{ values::Value, }, cost_model::CostModelImpl, - stats::{AttributeCombValue, AttributeCombValueStats}, + stats::{ + AttributeCombValue, AttributeCombValueStats, DEFAULT_EQ_SEL, DEFAULT_INEQ_SEL, + FIXED_CHAR_SEL_FACTOR, FULL_WILDCARD_SEL_FACTOR, UNIMPLEMENTED_SEL, + }, CostModelResult, EstimatedStatistic, }; -// A placeholder for unimplemented!() for codepaths which are accessed by plannertest -const UNIMPLEMENTED_SEL: f64 = 0.01; -// Default statistics. 
All are from selfuncs.h in Postgres unless specified otherwise -// Default selectivity estimate for equalities such as "A = b" -const DEFAULT_EQ_SEL: f64 = 0.005; -// Default selectivity estimate for inequalities such as "A < b" -const DEFAULT_INEQ_SEL: f64 = 0.3333333333333333; -// Used for estimating pattern selectivity character-by-character. These numbers -// are not used on their own. Depending on the characters in the pattern, the -// selectivity is multiplied by these factors. -// -// See `FULL_WILDCARD_SEL` and `FIXED_CHAR_SEL` in Postgres. -const FULL_WILDCARD_SEL_FACTOR: f64 = 5.0; -const FIXED_CHAR_SEL_FACTOR: f64 = 0.2; - impl CostModelImpl { // TODO: is it a good design to pass table_id here? I think it needs to be refactored. // Consider to remove table_id. From 64f4a10f9a92071e7318276532754a2d8f406118 Mon Sep 17 00:00:00 2001 From: Yuanxin Cao Date: Fri, 15 Nov 2024 10:18:26 -0500 Subject: [PATCH 18/51] fix clippy --- optd-cost-model/src/cost/agg.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/optd-cost-model/src/cost/agg.rs b/optd-cost-model/src/cost/agg.rs index 34b259a..3e5a525 100644 --- a/optd-cost-model/src/cost/agg.rs +++ b/optd-cost-model/src/cost/agg.rs @@ -37,7 +37,7 @@ impl CostModelImpl { let table_id = attr_ref.table_id(); let attr_idx = attr_ref.attr_index(); let stats_option = - self.get_attribute_comb_stats(TableId(table_id), &vec![attr_idx])?; + self.get_attribute_comb_stats(TableId(table_id), &[attr_idx])?; let ndistinct = match stats_option { Some(stats) => stats.ndistinct, From cafd01cea96d33c53f5ff6696fc4ba27f10b08a7 Mon Sep 17 00:00:00 2001 From: Lan Lou Date: Fri, 15 Nov 2024 11:44:25 -0500 Subject: [PATCH 19/51] Resolve the optional comb stats, remove table id in filter --- optd-cost-model/src/common/nodes.rs | 1 + .../src/common/predicates/attr_ref_pred.rs | 13 +- optd-cost-model/src/cost/agg.rs | 2 +- optd-cost-model/src/cost/filter.rs | 218 +++++++++--------- 4 files changed, 123 insertions(+), 111 
deletions(-) diff --git a/optd-cost-model/src/common/nodes.rs b/optd-cost-model/src/common/nodes.rs index 87506d0..8bfdabb 100644 --- a/optd-cost-model/src/common/nodes.rs +++ b/optd-cost-model/src/common/nodes.rs @@ -51,6 +51,7 @@ pub enum PredicateType { Constant(ConstantType), AttributeRef, ExternAttributeRef, + // TODO(lanlou): Id -> Id(IdType) Id, UnOp(UnOpType), BinOp(BinOpType), diff --git a/optd-cost-model/src/common/predicates/attr_ref_pred.rs b/optd-cost-model/src/common/predicates/attr_ref_pred.rs index 52db73d..8a670a5 100644 --- a/optd-cost-model/src/common/predicates/attr_ref_pred.rs +++ b/optd-cost-model/src/common/predicates/attr_ref_pred.rs @@ -1,4 +1,7 @@ -use crate::common::nodes::{ArcPredicateNode, PredicateNode, PredicateType, ReprPredicateNode}; +use crate::common::{ + nodes::{ArcPredicateNode, PredicateNode, PredicateType, ReprPredicateNode}, + types::TableId, +}; use super::id_pred::IdPred; @@ -8,11 +11,14 @@ use super::id_pred::IdPred; /// 1. The table id, represented by an [`IdPred`]. /// 2. The index of the column, represented by an [`IdPred`]. /// +/// **TODO**: Now we assume any IdPred is as same as the ones in the ORM layer. +/// /// Currently, [`AttributeRefPred`] only holds base table attributes, i.e. attributes /// that already exist in the table. More complex structures may be introduced in the /// future to represent derived attributes (e.g. t.v1 + t.v2). /// /// TODO: Support derived column in `AttributeRefPred`. +/// Proposal: Data field can store the column type (base or derived). #[derive(Clone, Debug)] pub struct AttributeRefPred(pub ArcPredicateNode); @@ -32,11 +38,12 @@ impl AttributeRefPred { } /// Gets the table id. - pub fn table_id(&self) -> usize { - self.0.child(0).data.as_ref().unwrap().as_u64() as usize + pub fn table_id(&self) -> TableId { + TableId(self.0.child(0).data.as_ref().unwrap().as_u64() as usize) } /// Gets the attribute index. 
+ /// Note: The attribute index is the **base** index, which is table specific. pub fn attr_index(&self) -> usize { self.0.child(1).data.as_ref().unwrap().as_u64() as usize } diff --git a/optd-cost-model/src/cost/agg.rs b/optd-cost-model/src/cost/agg.rs index 3e5a525..a3de5aa 100644 --- a/optd-cost-model/src/cost/agg.rs +++ b/optd-cost-model/src/cost/agg.rs @@ -37,7 +37,7 @@ impl CostModelImpl { let table_id = attr_ref.table_id(); let attr_idx = attr_ref.attr_index(); let stats_option = - self.get_attribute_comb_stats(TableId(table_id), &[attr_idx])?; + self.get_attribute_comb_stats(table_id, &[attr_idx])?; let ndistinct = match stats_option { Some(stats) => stats.ndistinct, diff --git a/optd-cost-model/src/cost/filter.rs b/optd-cost-model/src/cost/filter.rs index 90272fa..432261f 100644 --- a/optd-cost-model/src/cost/filter.rs +++ b/optd-cost-model/src/cost/filter.rs @@ -22,11 +22,14 @@ use crate::{ values::Value, }, cost_model::CostModelImpl, + // TODO: If we return the default value, consider tell the upper level that we cannot + // compute the selectivity. stats::{ AttributeCombValue, AttributeCombValueStats, DEFAULT_EQ_SEL, DEFAULT_INEQ_SEL, FIXED_CHAR_SEL_FACTOR, FULL_WILDCARD_SEL_FACTOR, UNIMPLEMENTED_SEL, }, - CostModelResult, EstimatedStatistic, + CostModelResult, + EstimatedStatistic, }; impl CostModelImpl { @@ -35,21 +38,16 @@ impl CostModelImpl { pub fn get_filter_row_cnt( &self, child_row_cnt: EstimatedStatistic, - table_id: TableId, cond: ArcPredicateNode, ) -> CostModelResult { - let selectivity = { self.get_filter_selectivity(cond, table_id)? }; + let selectivity = { self.get_filter_selectivity(cond)? 
}; Ok( EstimatedStatistic((child_row_cnt.0 as f64 * selectivity) as u64) .max(EstimatedStatistic(1)), ) } - pub fn get_filter_selectivity( - &self, - expr_tree: ArcPredicateNode, - table_id: TableId, - ) -> CostModelResult { + pub fn get_filter_selectivity(&self, expr_tree: ArcPredicateNode) -> CostModelResult { match &expr_tree.typ { PredicateType::Constant(_) => Ok(Self::get_constant_selectivity(expr_tree)), PredicateType::AttributeRef => unimplemented!("check bool type or else panic"), @@ -60,7 +58,7 @@ impl CostModelImpl { // not doesn't care about nulls so there's no complex logic. it just reverses // the selectivity for instance, != _will not_ include nulls // but "NOT ==" _will_ include nulls - UnOpType::Not => Ok(1.0 - self.get_filter_selectivity(child, table_id)?), + UnOpType::Not => Ok(1.0 - self.get_filter_selectivity(child)?), UnOpType::Neg => panic!( "the selectivity of operations that return numerical values is undefined" ), @@ -72,7 +70,7 @@ impl CostModelImpl { let right_child = expr_tree.child(1); if bin_op_typ.is_comparison() { - self.get_comp_op_selectivity(*bin_op_typ, left_child, right_child, table_id) + self.get_comp_op_selectivity(*bin_op_typ, left_child, right_child) } else if bin_op_typ.is_numerical() { panic!( "the selectivity of operations that return numerical values is undefined" @@ -82,7 +80,7 @@ impl CostModelImpl { } } PredicateType::LogOp(log_op_typ) => { - self.get_log_op_selectivity(*log_op_typ, &expr_tree.children, table_id) + self.get_log_op_selectivity(*log_op_typ, &expr_tree.children) } PredicateType::Func(_) => unimplemented!("check bool type or else panic"), PredicateType::SortOrder(_) => { @@ -92,14 +90,14 @@ impl CostModelImpl { PredicateType::Cast => unimplemented!("check bool type or else panic"), PredicateType::Like => { let like_expr = LikePred::from_pred_node(expr_tree).unwrap(); - self.get_like_selectivity(&like_expr, table_id) + self.get_like_selectivity(&like_expr) } PredicateType::DataType(_) => { panic!("the 
selectivity of a data type is not defined") } PredicateType::InList => { let in_list_expr = InListPred::from_pred_node(expr_tree).unwrap(); - self.get_in_list_selectivity(&in_list_expr, table_id) + self.get_in_list_selectivity(&in_list_expr) } _ => unreachable!( "all expression DfPredType were enumerated. this should be unreachable" @@ -138,16 +136,15 @@ impl CostModelImpl { &self, log_op_typ: LogOpType, children: &[ArcPredicateNode], - table_id: TableId, ) -> CostModelResult { match log_op_typ { LogOpType::And => children.iter().try_fold(1.0, |acc, child| { - let selectivity = self.get_filter_selectivity(child.clone(), table_id)?; + let selectivity = self.get_filter_selectivity(child.clone())?; Ok(acc * selectivity) }), LogOpType::Or => { let product = children.iter().try_fold(1.0, |acc, child| { - let selectivity = self.get_filter_selectivity(child.clone(), table_id)?; + let selectivity = self.get_filter_selectivity(child.clone())?; Ok(acc * (1.0 - selectivity)) })?; Ok(1.0 - product) @@ -161,14 +158,13 @@ impl CostModelImpl { comp_bin_op_typ: BinOpType, left: ArcPredicateNode, right: ArcPredicateNode, - table_id: TableId, ) -> CostModelResult { assert!(comp_bin_op_typ.is_comparison()); // I intentionally performed moves on left and right. This way, we don't accidentally use // them after this block let (attr_ref_exprs, values, non_attr_ref_exprs, is_left_attr_ref) = - self.get_semantic_nodes(left, right, table_id)?; + self.get_semantic_nodes(left, right)?; // Handle the different cases of semantic nodes. 
if attr_ref_exprs.is_empty() { @@ -178,6 +174,7 @@ impl CostModelImpl { .first() .expect("we just checked that attr_ref_exprs.len() == 1"); let attr_ref_idx = attr_ref_expr.attr_index(); + let table_id = attr_ref_expr.table_id(); // TODO: Consider attribute is a derived attribute if values.len() == 1 { @@ -257,29 +254,36 @@ impl CostModelImpl { ) -> CostModelResult { // TODO: The attribute could be a derived attribute let ret_sel = { - // TODO: Handle the case where `attribute_stats` is None. - let attribute_stats = self - .get_attribute_comb_stats(table_id, &[attr_base_index])? - .unwrap(); - let eq_freq = if let Some(freq) = attribute_stats.mcvs.freq(&vec![Some(value.clone())]) + if let Some(attribute_stats) = + self.get_attribute_comb_stats(table_id, &[attr_base_index])? { - freq - } else { - let non_mcv_freq = 1.0 - attribute_stats.mcvs.total_freq(); - // always safe because usize is at least as large as i32 - let ndistinct_as_usize = attribute_stats.ndistinct as usize; - let non_mcv_cnt = ndistinct_as_usize - attribute_stats.mcvs.cnt(); - if non_mcv_cnt == 0 { - return Ok(0.0); + let eq_freq = + if let Some(freq) = attribute_stats.mcvs.freq(&vec![Some(value.clone())]) { + freq + } else { + let non_mcv_freq = 1.0 - attribute_stats.mcvs.total_freq(); + // always safe because usize is at least as large as i32 + let ndistinct_as_usize = attribute_stats.ndistinct as usize; + let non_mcv_cnt = ndistinct_as_usize - attribute_stats.mcvs.cnt(); + if non_mcv_cnt == 0 { + return Ok(0.0); + } + // note that nulls are not included in ndistinct so we don't need to do non_mcv_cnt + // - 1 if null_frac > 0 + (non_mcv_freq - attribute_stats.null_frac) / (non_mcv_cnt as f64) + }; + if is_eq { + eq_freq + } else { + 1.0 - eq_freq - attribute_stats.null_frac } - // note that nulls are not included in ndistinct so we don't need to do non_mcv_cnt - // - 1 if null_frac > 0 - (non_mcv_freq - attribute_stats.null_frac) / (non_mcv_cnt as f64) - }; - if is_eq { - eq_freq } else { - 1.0 
- eq_freq - attribute_stats.null_frac + #[allow(clippy::collapsible_else_if)] + if is_eq { + DEFAULT_EQ_SEL + } else { + 1.0 - DEFAULT_EQ_SEL + } } }; @@ -346,37 +350,43 @@ impl CostModelImpl { end: Bound<&Value>, ) -> CostModelResult { // TODO: Consider attribute is a derived attribute - // TODO: Handle the case where `attribute_stats` is None. - let attribute_stats = self - .get_attribute_comb_stats(table_id, &[attr_base_index])? - .unwrap(); - let left_quantile = match start { - Bound::Unbounded => 0.0, - Bound::Included(value) => self.get_attribute_lt_value_freq( - &attribute_stats, - table_id, - attr_base_index, - value, - )?, - Bound::Excluded(value) => Self::get_attribute_leq_value_freq(&attribute_stats, value), - }; - let right_quantile = match end { - Bound::Unbounded => 1.0, - Bound::Included(value) => Self::get_attribute_leq_value_freq(&attribute_stats, value), - Bound::Excluded(value) => self.get_attribute_lt_value_freq( - &attribute_stats, - table_id, - attr_base_index, - value, - )?, - }; - assert!( - left_quantile <= right_quantile, - "left_quantile ({}) should be <= right_quantile ({})", - left_quantile, - right_quantile - ); - Ok(right_quantile - left_quantile) + if let Some(attribute_stats) = + self.get_attribute_comb_stats(table_id, &[attr_base_index])? 
+ { + let left_quantile = match start { + Bound::Unbounded => 0.0, + Bound::Included(value) => self.get_attribute_lt_value_freq( + &attribute_stats, + table_id, + attr_base_index, + value, + )?, + Bound::Excluded(value) => { + Self::get_attribute_leq_value_freq(&attribute_stats, value) + } + }; + let right_quantile = match end { + Bound::Unbounded => 1.0, + Bound::Included(value) => { + Self::get_attribute_leq_value_freq(&attribute_stats, value) + } + Bound::Excluded(value) => self.get_attribute_lt_value_freq( + &attribute_stats, + table_id, + attr_base_index, + value, + )?, + }; + assert!( + left_quantile <= right_quantile, + "left_quantile ({}) should be <= right_quantile ({})", + left_quantile, + right_quantile + ); + Ok(right_quantile - left_quantile) + } else { + Ok(DEFAULT_INEQ_SEL) + } } /// Compute the selectivity of a (NOT) LIKE expression. @@ -391,11 +401,7 @@ impl CostModelImpl { /// is composed of MCV frequency and non-MCV selectivity. MCV frequency is computed by /// adding up frequencies of MCVs that match the pattern. Non-MCV selectivity is computed /// in the same way that Postgres computes selectivity for the wildcard part of the pattern. - fn get_like_selectivity( - &self, - like_expr: &LikePred, - table_id: TableId, - ) -> CostModelResult { + fn get_like_selectivity(&self, like_expr: &LikePred) -> CostModelResult { let child = like_expr.child(); // Check child is a attribute ref. @@ -409,9 +415,9 @@ impl CostModelImpl { return Ok(UNIMPLEMENTED_SEL); } - let attr_ref_idx = AttributeRefPred::from_pred_node(child) - .unwrap() - .attr_index(); + let attr_ref_pred = AttributeRefPred::from_pred_node(child).unwrap(); + let attr_ref_idx = attr_ref_pred.attr_index(); + let table_id = attr_ref_pred.table_id(); // TODO: Consider attribute is a derived attribute let pattern = ConstantPred::from_pred_node(pattern) @@ -434,40 +440,38 @@ impl CostModelImpl { // Compute the selectivity in MCVs. // TODO: Handle the case where `attribute_stats` is None. 
- let attribute_stats = self - .get_attribute_comb_stats(table_id, &[attr_ref_idx])? - .unwrap(); - let (mcv_freq, null_frac) = { - let pred = Box::new(move |val: &AttributeCombValue| { - let string = StringArray::from(vec![val[0].as_ref().unwrap().as_str().as_ref()]); - let pattern = StringArray::from(vec![pattern.as_ref()]); - like(&string, &pattern).unwrap().value(0) - }); - ( - attribute_stats.mcvs.freq_over_pred(pred), - attribute_stats.null_frac, - ) - }; + if let Some(attribute_stats) = self.get_attribute_comb_stats(table_id, &[attr_ref_idx])? { + let (mcv_freq, null_frac) = { + let pred = Box::new(move |val: &AttributeCombValue| { + let string = + StringArray::from(vec![val[0].as_ref().unwrap().as_str().as_ref()]); + let pattern = StringArray::from(vec![pattern.as_ref()]); + like(&string, &pattern).unwrap().value(0) + }); + ( + attribute_stats.mcvs.freq_over_pred(pred), + attribute_stats.null_frac, + ) + }; - let result = non_mcv_sel + mcv_freq; + let result = non_mcv_sel + mcv_freq; - Ok(if like_expr.negated() { - 1.0 - result - null_frac + Ok(if like_expr.negated() { + 1.0 - result - null_frac + } else { + result + } + // Postgres clamps the result after histogram and before MCV. See Postgres + // `patternsel_common`. + .clamp(0.0001, 0.9999)) } else { - result + Ok(UNIMPLEMENTED_SEL) } - // Postgres clamps the result after histogram and before MCV. See Postgres - // `patternsel_common`. - .clamp(0.0001, 0.9999)) } /// Only support attrA in (val1, val2, val3) where attrA is a attribute ref and /// val1, val2, val3 are constants. - pub fn get_in_list_selectivity( - &self, - expr: &InListPred, - table_id: TableId, - ) -> CostModelResult { + pub fn get_in_list_selectivity(&self, expr: &InListPred) -> CostModelResult { let child = expr.child(); // Check child is a attribute ref. @@ -485,9 +489,9 @@ impl CostModelImpl { } // Convert child and const expressions to concrete types. 
- let attr_ref_idx = AttributeRefPred::from_pred_node(child) - .unwrap() - .attr_index(); + let attr_ref_pred = AttributeRefPred::from_pred_node(child).unwrap(); + let attr_ref_idx = attr_ref_pred.attr_index(); + let table_id = attr_ref_pred.table_id(); let list_exprs = list_exprs .into_iter() .map(|expr| { @@ -525,7 +529,6 @@ impl CostModelImpl { &self, left: ArcPredicateNode, right: ArcPredicateNode, - table_id: TableId, ) -> CostModelResult<( Vec, Vec, @@ -583,6 +586,7 @@ impl CostModelImpl { let attr_ref_expr = AttributeRefPred::from_pred_node(cast_expr_child) .expect("we already checked that the type is AttributeRef"); let attr_ref_idx = attr_ref_expr.attr_index(); + let table_id = attr_ref_expr.table_id(); cast_node = attr_ref_expr.into_pred_node(); // The "invert" cast is to invert the cast so that we're casting the // non_cast_node to the attribute's original type. From 5c5a40f68e30ea7970dbde4ff11866f50244a160 Mon Sep 17 00:00:00 2001 From: Lan Lou Date: Fri, 15 Nov 2024 12:51:05 -0500 Subject: [PATCH 20/51] Refactor filter implementation --- optd-cost-model/src/cost/filter.rs | 690 ------------------ optd-cost-model/src/cost/filter/attribute.rs | 165 +++++ optd-cost-model/src/cost/filter/comp_op.rs | 274 +++++++ optd-cost-model/src/cost/filter/constant.rs | 39 + optd-cost-model/src/cost/filter/controller.rs | 85 +++ optd-cost-model/src/cost/filter/in_list.rs | 67 ++ optd-cost-model/src/cost/filter/like.rs | 98 +++ optd-cost-model/src/cost/filter/log_op.rs | 29 + optd-cost-model/src/cost/filter/mod.rs | 7 + 9 files changed, 764 insertions(+), 690 deletions(-) delete mode 100644 optd-cost-model/src/cost/filter.rs create mode 100644 optd-cost-model/src/cost/filter/attribute.rs create mode 100644 optd-cost-model/src/cost/filter/comp_op.rs create mode 100644 optd-cost-model/src/cost/filter/constant.rs create mode 100644 optd-cost-model/src/cost/filter/controller.rs create mode 100644 optd-cost-model/src/cost/filter/in_list.rs create mode 100644 
optd-cost-model/src/cost/filter/like.rs create mode 100644 optd-cost-model/src/cost/filter/log_op.rs create mode 100644 optd-cost-model/src/cost/filter/mod.rs diff --git a/optd-cost-model/src/cost/filter.rs b/optd-cost-model/src/cost/filter.rs deleted file mode 100644 index 432261f..0000000 --- a/optd-cost-model/src/cost/filter.rs +++ /dev/null @@ -1,690 +0,0 @@ -#![allow(unused_variables)] -use std::ops::Bound; - -use datafusion::arrow::array::StringArray; -use datafusion::arrow::compute::like; -use optd_persistent::{cost_model::interface::Cost, CostModelStorageLayer}; - -use crate::{ - common::{ - nodes::{ArcPredicateNode, PredicateType, ReprPredicateNode}, - predicates::{ - attr_ref_pred::AttributeRefPred, - bin_op_pred::BinOpType, - cast_pred::CastPred, - constant_pred::{ConstantPred, ConstantType}, - in_list_pred::InListPred, - like_pred::LikePred, - log_op_pred::LogOpType, - un_op_pred::UnOpType, - }, - types::TableId, - values::Value, - }, - cost_model::CostModelImpl, - // TODO: If we return the default value, consider tell the upper level that we cannot - // compute the selectivity. - stats::{ - AttributeCombValue, AttributeCombValueStats, DEFAULT_EQ_SEL, DEFAULT_INEQ_SEL, - FIXED_CHAR_SEL_FACTOR, FULL_WILDCARD_SEL_FACTOR, UNIMPLEMENTED_SEL, - }, - CostModelResult, - EstimatedStatistic, -}; - -impl CostModelImpl { - // TODO: is it a good design to pass table_id here? I think it needs to be refactored. - // Consider to remove table_id. - pub fn get_filter_row_cnt( - &self, - child_row_cnt: EstimatedStatistic, - cond: ArcPredicateNode, - ) -> CostModelResult { - let selectivity = { self.get_filter_selectivity(cond)? 
}; - Ok( - EstimatedStatistic((child_row_cnt.0 as f64 * selectivity) as u64) - .max(EstimatedStatistic(1)), - ) - } - - pub fn get_filter_selectivity(&self, expr_tree: ArcPredicateNode) -> CostModelResult { - match &expr_tree.typ { - PredicateType::Constant(_) => Ok(Self::get_constant_selectivity(expr_tree)), - PredicateType::AttributeRef => unimplemented!("check bool type or else panic"), - PredicateType::UnOp(un_op_typ) => { - assert!(expr_tree.children.len() == 1); - let child = expr_tree.child(0); - match un_op_typ { - // not doesn't care about nulls so there's no complex logic. it just reverses - // the selectivity for instance, != _will not_ include nulls - // but "NOT ==" _will_ include nulls - UnOpType::Not => Ok(1.0 - self.get_filter_selectivity(child)?), - UnOpType::Neg => panic!( - "the selectivity of operations that return numerical values is undefined" - ), - } - } - PredicateType::BinOp(bin_op_typ) => { - assert!(expr_tree.children.len() == 2); - let left_child = expr_tree.child(0); - let right_child = expr_tree.child(1); - - if bin_op_typ.is_comparison() { - self.get_comp_op_selectivity(*bin_op_typ, left_child, right_child) - } else if bin_op_typ.is_numerical() { - panic!( - "the selectivity of operations that return numerical values is undefined" - ) - } else { - unreachable!("all BinOpTypes should be true for at least one is_*() function") - } - } - PredicateType::LogOp(log_op_typ) => { - self.get_log_op_selectivity(*log_op_typ, &expr_tree.children) - } - PredicateType::Func(_) => unimplemented!("check bool type or else panic"), - PredicateType::SortOrder(_) => { - panic!("the selectivity of sort order expressions is undefined") - } - PredicateType::Between => Ok(UNIMPLEMENTED_SEL), - PredicateType::Cast => unimplemented!("check bool type or else panic"), - PredicateType::Like => { - let like_expr = LikePred::from_pred_node(expr_tree).unwrap(); - self.get_like_selectivity(&like_expr) - } - PredicateType::DataType(_) => { - panic!("the selectivity 
of a data type is not defined") - } - PredicateType::InList => { - let in_list_expr = InListPred::from_pred_node(expr_tree).unwrap(); - self.get_in_list_selectivity(&in_list_expr) - } - _ => unreachable!( - "all expression DfPredType were enumerated. this should be unreachable" - ), - } - } - - fn get_constant_selectivity(const_node: ArcPredicateNode) -> f64 { - if let PredicateType::Constant(const_typ) = const_node.typ { - if matches!(const_typ, ConstantType::Bool) { - let value = const_node - .as_ref() - .data - .as_ref() - .expect("constants should have data"); - if let Value::Bool(bool_value) = value { - if *bool_value { - 1.0 - } else { - 0.0 - } - } else { - unreachable!( - "if the typ is ConstantType::Bool, the value should be a Value::Bool" - ) - } - } else { - panic!("selectivity is not defined on constants which are not bools") - } - } else { - panic!("get_constant_selectivity must be called on a constant") - } - } - - fn get_log_op_selectivity( - &self, - log_op_typ: LogOpType, - children: &[ArcPredicateNode], - ) -> CostModelResult { - match log_op_typ { - LogOpType::And => children.iter().try_fold(1.0, |acc, child| { - let selectivity = self.get_filter_selectivity(child.clone())?; - Ok(acc * selectivity) - }), - LogOpType::Or => { - let product = children.iter().try_fold(1.0, |acc, child| { - let selectivity = self.get_filter_selectivity(child.clone())?; - Ok(acc * (1.0 - selectivity)) - })?; - Ok(1.0 - product) - } - } - } - - /// Comparison operators are the base case for recursion in get_filter_selectivity() - fn get_comp_op_selectivity( - &self, - comp_bin_op_typ: BinOpType, - left: ArcPredicateNode, - right: ArcPredicateNode, - ) -> CostModelResult { - assert!(comp_bin_op_typ.is_comparison()); - - // I intentionally performed moves on left and right. 
This way, we don't accidentally use - // them after this block - let (attr_ref_exprs, values, non_attr_ref_exprs, is_left_attr_ref) = - self.get_semantic_nodes(left, right)?; - - // Handle the different cases of semantic nodes. - if attr_ref_exprs.is_empty() { - Ok(UNIMPLEMENTED_SEL) - } else if attr_ref_exprs.len() == 1 { - let attr_ref_expr = attr_ref_exprs - .first() - .expect("we just checked that attr_ref_exprs.len() == 1"); - let attr_ref_idx = attr_ref_expr.attr_index(); - let table_id = attr_ref_expr.table_id(); - - // TODO: Consider attribute is a derived attribute - if values.len() == 1 { - let value = values - .first() - .expect("we just checked that values.len() == 1"); - match comp_bin_op_typ { - BinOpType::Eq => { - self.get_attribute_equality_selectivity(table_id, attr_ref_idx, value, true) - } - BinOpType::Neq => self.get_attribute_equality_selectivity( - table_id, - attr_ref_idx, - value, - false, - ), - BinOpType::Lt | BinOpType::Leq | BinOpType::Gt | BinOpType::Geq => { - let start = match (comp_bin_op_typ, is_left_attr_ref) { - (BinOpType::Lt, true) | (BinOpType::Geq, false) => Bound::Unbounded, - (BinOpType::Leq, true) | (BinOpType::Gt, false) => Bound::Unbounded, - (BinOpType::Gt, true) | (BinOpType::Leq, false) => Bound::Excluded(value), - (BinOpType::Geq, true) | (BinOpType::Lt, false) => Bound::Included(value), - _ => unreachable!("all comparison BinOpTypes were enumerated. this should be unreachable"), - }; - let end = match (comp_bin_op_typ, is_left_attr_ref) { - (BinOpType::Lt, true) | (BinOpType::Geq, false) => Bound::Excluded(value), - (BinOpType::Leq, true) | (BinOpType::Gt, false) => Bound::Included(value), - (BinOpType::Gt, true) | (BinOpType::Leq, false) => Bound::Unbounded, - (BinOpType::Geq, true) | (BinOpType::Lt, false) => Bound::Unbounded, - _ => unreachable!("all comparison BinOpTypes were enumerated. 
this should be unreachable"), - }; - self.get_attribute_range_selectivity(table_id, attr_ref_idx, start, end) - } - _ => unreachable!( - "all comparison BinOpTypes were enumerated. this should be unreachable" - ), - } - } else { - let non_attr_ref_expr = non_attr_ref_exprs.first().expect( - "non_attr_ref_exprs should have a value since attr_ref_exprs.len() == 1", - ); - - match non_attr_ref_expr.as_ref().typ { - PredicateType::BinOp(_) => { - Ok(Self::get_default_comparison_op_selectivity(comp_bin_op_typ)) - } - PredicateType::Cast => Ok(UNIMPLEMENTED_SEL), - PredicateType::Constant(_) => { - unreachable!("we should have handled this in the values.len() == 1 branch") - } - _ => unimplemented!( - "unhandled case of comparing a attribute ref node to {}", - non_attr_ref_expr.as_ref().typ - ), - } - } - } else if attr_ref_exprs.len() == 2 { - Ok(Self::get_default_comparison_op_selectivity(comp_bin_op_typ)) - } else { - unreachable!("we could have at most pushed left and right into attr_ref_exprs") - } - } - - /// Get the selectivity of an expression of the form "attribute equals value" (or "value equals - /// attribute") Will handle the case of statistics missing - /// Equality predicates are handled entirely differently from range predicates so this is its - /// own function - /// Also, get_attribute_equality_selectivity is a subroutine when computing range - /// selectivity, which is another reason for separating these into two functions - /// is_eq means whether it's == or != - fn get_attribute_equality_selectivity( - &self, - table_id: TableId, - attr_base_index: usize, - value: &Value, - is_eq: bool, - ) -> CostModelResult { - // TODO: The attribute could be a derived attribute - let ret_sel = { - if let Some(attribute_stats) = - self.get_attribute_comb_stats(table_id, &[attr_base_index])? 
- { - let eq_freq = - if let Some(freq) = attribute_stats.mcvs.freq(&vec![Some(value.clone())]) { - freq - } else { - let non_mcv_freq = 1.0 - attribute_stats.mcvs.total_freq(); - // always safe because usize is at least as large as i32 - let ndistinct_as_usize = attribute_stats.ndistinct as usize; - let non_mcv_cnt = ndistinct_as_usize - attribute_stats.mcvs.cnt(); - if non_mcv_cnt == 0 { - return Ok(0.0); - } - // note that nulls are not included in ndistinct so we don't need to do non_mcv_cnt - // - 1 if null_frac > 0 - (non_mcv_freq - attribute_stats.null_frac) / (non_mcv_cnt as f64) - }; - if is_eq { - eq_freq - } else { - 1.0 - eq_freq - attribute_stats.null_frac - } - } else { - #[allow(clippy::collapsible_else_if)] - if is_eq { - DEFAULT_EQ_SEL - } else { - 1.0 - DEFAULT_EQ_SEL - } - } - }; - - assert!( - (0.0..=1.0).contains(&ret_sel), - "ret_sel ({}) should be in [0, 1]", - ret_sel - ); - Ok(ret_sel) - } - - /// Compute the frequency of values in a attribute less than or equal to the given value. - fn get_attribute_leq_value_freq( - per_attribute_stats: &AttributeCombValueStats, - value: &Value, - ) -> f64 { - // because distr does not include the values in MCVs, we need to compute the CDFs there as - // well because nulls return false in any comparison, they are never included when - // computing range selectivity - let distr_leq_freq = per_attribute_stats.distr.as_ref().unwrap().cdf(value); - let value = value.clone(); - let pred = Box::new(move |val: &AttributeCombValue| *val[0].as_ref().unwrap() <= value); - let mcvs_leq_freq = per_attribute_stats.mcvs.freq_over_pred(pred); - let ret_freq = distr_leq_freq + mcvs_leq_freq; - assert!( - (0.0..=1.0).contains(&ret_freq), - "ret_freq ({}) should be in [0, 1]", - ret_freq - ); - ret_freq - } - - /// Compute the frequency of values in a attribute less than the given value. 
- fn get_attribute_lt_value_freq( - &self, - attribute_stats: &AttributeCombValueStats, - table_id: TableId, - attr_base_index: usize, - value: &Value, - ) -> CostModelResult { - // depending on whether value is in mcvs or not, we use different logic to turn total_lt_cdf - // into total_leq_cdf this logic just so happens to be the exact same logic as - // get_attribute_equality_selectivity implements - let ret_freq = Self::get_attribute_leq_value_freq(attribute_stats, value) - - self.get_attribute_equality_selectivity(table_id, attr_base_index, value, true)?; - assert!( - (0.0..=1.0).contains(&ret_freq), - "ret_freq ({}) should be in [0, 1]", - ret_freq - ); - Ok(ret_freq) - } - - /// Get the selectivity of an expression of the form "attribute =/> value" (or "value - /// =/> attribute"). Computes selectivity based off of statistics. - /// Range predicates are handled entirely differently from equality predicates so this is its - /// own function. If it is unable to find the statistics, it returns DEFAULT_INEQ_SEL. - /// The selectivity is computed as quantile of the right bound minus quantile of the left bound. - fn get_attribute_range_selectivity( - &self, - table_id: TableId, - attr_base_index: usize, - start: Bound<&Value>, - end: Bound<&Value>, - ) -> CostModelResult { - // TODO: Consider attribute is a derived attribute - if let Some(attribute_stats) = - self.get_attribute_comb_stats(table_id, &[attr_base_index])? 
- { - let left_quantile = match start { - Bound::Unbounded => 0.0, - Bound::Included(value) => self.get_attribute_lt_value_freq( - &attribute_stats, - table_id, - attr_base_index, - value, - )?, - Bound::Excluded(value) => { - Self::get_attribute_leq_value_freq(&attribute_stats, value) - } - }; - let right_quantile = match end { - Bound::Unbounded => 1.0, - Bound::Included(value) => { - Self::get_attribute_leq_value_freq(&attribute_stats, value) - } - Bound::Excluded(value) => self.get_attribute_lt_value_freq( - &attribute_stats, - table_id, - attr_base_index, - value, - )?, - }; - assert!( - left_quantile <= right_quantile, - "left_quantile ({}) should be <= right_quantile ({})", - left_quantile, - right_quantile - ); - Ok(right_quantile - left_quantile) - } else { - Ok(DEFAULT_INEQ_SEL) - } - } - - /// Compute the selectivity of a (NOT) LIKE expression. - /// - /// The logic is somewhat similar to Postgres but different. Postgres first estimates the - /// histogram part of the population and then add up data for any MCV values. If the - /// histogram is large enough, it just uses the number of matches in the histogram, - /// otherwise it estimates the fixed prefix and remainder of pattern separately and - /// combine them. - /// - /// Our approach is simpler and less selective. Firstly, we don't use histogram. The selectivity - /// is composed of MCV frequency and non-MCV selectivity. MCV frequency is computed by - /// adding up frequencies of MCVs that match the pattern. Non-MCV selectivity is computed - /// in the same way that Postgres computes selectivity for the wildcard part of the pattern. - fn get_like_selectivity(&self, like_expr: &LikePred) -> CostModelResult { - let child = like_expr.child(); - - // Check child is a attribute ref. - if !matches!(child.typ, PredicateType::AttributeRef) { - return Ok(UNIMPLEMENTED_SEL); - } - - // Check pattern is a constant. 
- let pattern = like_expr.pattern(); - if !matches!(pattern.typ, PredicateType::Constant(_)) { - return Ok(UNIMPLEMENTED_SEL); - } - - let attr_ref_pred = AttributeRefPred::from_pred_node(child).unwrap(); - let attr_ref_idx = attr_ref_pred.attr_index(); - let table_id = attr_ref_pred.table_id(); - - // TODO: Consider attribute is a derived attribute - let pattern = ConstantPred::from_pred_node(pattern) - .expect("we already checked pattern is a constant") - .value() - .as_str(); - - // Compute the selectivity exculuding MCVs. - // See Postgres `like_selectivity`. - let non_mcv_sel = pattern - .chars() - .fold(1.0, |acc, c| { - if c == '%' { - acc * FULL_WILDCARD_SEL_FACTOR - } else { - acc * FIXED_CHAR_SEL_FACTOR - } - }) - .min(1.0); - - // Compute the selectivity in MCVs. - // TODO: Handle the case where `attribute_stats` is None. - if let Some(attribute_stats) = self.get_attribute_comb_stats(table_id, &[attr_ref_idx])? { - let (mcv_freq, null_frac) = { - let pred = Box::new(move |val: &AttributeCombValue| { - let string = - StringArray::from(vec![val[0].as_ref().unwrap().as_str().as_ref()]); - let pattern = StringArray::from(vec![pattern.as_ref()]); - like(&string, &pattern).unwrap().value(0) - }); - ( - attribute_stats.mcvs.freq_over_pred(pred), - attribute_stats.null_frac, - ) - }; - - let result = non_mcv_sel + mcv_freq; - - Ok(if like_expr.negated() { - 1.0 - result - null_frac - } else { - result - } - // Postgres clamps the result after histogram and before MCV. See Postgres - // `patternsel_common`. - .clamp(0.0001, 0.9999)) - } else { - Ok(UNIMPLEMENTED_SEL) - } - } - - /// Only support attrA in (val1, val2, val3) where attrA is a attribute ref and - /// val1, val2, val3 are constants. - pub fn get_in_list_selectivity(&self, expr: &InListPred) -> CostModelResult { - let child = expr.child(); - - // Check child is a attribute ref. 
- if !matches!(child.typ, PredicateType::AttributeRef) { - return Ok(UNIMPLEMENTED_SEL); - } - - // Check all expressions in the list are constants. - let list_exprs = expr.list().to_vec(); - if list_exprs - .iter() - .any(|expr| !matches!(expr.typ, PredicateType::Constant(_))) - { - return Ok(UNIMPLEMENTED_SEL); - } - - // Convert child and const expressions to concrete types. - let attr_ref_pred = AttributeRefPred::from_pred_node(child).unwrap(); - let attr_ref_idx = attr_ref_pred.attr_index(); - let table_id = attr_ref_pred.table_id(); - let list_exprs = list_exprs - .into_iter() - .map(|expr| { - ConstantPred::from_pred_node(expr) - .expect("we already checked all list elements are constants") - }) - .collect::>(); - let negated = expr.negated(); - - // TODO: Consider attribute is a derived attribute - let in_sel = list_exprs - .iter() - .try_fold(0.0, |acc, expr| { - let selectivity = self.get_attribute_equality_selectivity( - table_id, - attr_ref_idx, - &expr.value(), - /* is_equality */ true, - )?; - Ok(acc + selectivity) - })? - .min(1.0); - if negated { - Ok(1.0 - in_sel) - } else { - Ok(in_sel) - } - } - - /// Convert the left and right child nodes of some operation to what they semantically are. - /// This is convenient to avoid repeating the same logic just with "left" and "right" swapped. - /// The last return value is true when the input node (left) is a AttributeRefPred. - #[allow(clippy::type_complexity)] - fn get_semantic_nodes( - &self, - left: ArcPredicateNode, - right: ArcPredicateNode, - ) -> CostModelResult<( - Vec, - Vec, - Vec, - bool, - )> { - let mut attr_ref_exprs = vec![]; - let mut values = vec![]; - let mut non_attr_ref_exprs = vec![]; - let is_left_attr_ref; - - // Recursively unwrap casts as much as we can. 
- let mut uncasted_left = left; - let mut uncasted_right = right; - loop { - // println!("loop {}, uncasted_left={:?}, uncasted_right={:?}", Local::now(), - // uncasted_left, uncasted_right); - if uncasted_left.as_ref().typ == PredicateType::Cast - && uncasted_right.as_ref().typ == PredicateType::Cast - { - let left_cast_expr = CastPred::from_pred_node(uncasted_left) - .expect("we already checked that the type is Cast"); - let right_cast_expr = CastPred::from_pred_node(uncasted_right) - .expect("we already checked that the type is Cast"); - assert!(left_cast_expr.cast_to() == right_cast_expr.cast_to()); - uncasted_left = left_cast_expr.child().into_pred_node(); - uncasted_right = right_cast_expr.child().into_pred_node(); - } else if uncasted_left.as_ref().typ == PredicateType::Cast - || uncasted_right.as_ref().typ == PredicateType::Cast - { - let is_left_cast = uncasted_left.as_ref().typ == PredicateType::Cast; - let (mut cast_node, mut non_cast_node) = if is_left_cast { - (uncasted_left, uncasted_right) - } else { - (uncasted_right, uncasted_left) - }; - - let cast_expr = CastPred::from_pred_node(cast_node) - .expect("we already checked that the type is Cast"); - let cast_expr_child = cast_expr.child().into_pred_node(); - let cast_expr_cast_to = cast_expr.cast_to(); - - let should_break = match cast_expr_child.typ { - PredicateType::Constant(_) => { - cast_node = ConstantPred::new( - ConstantPred::from_pred_node(cast_expr_child) - .expect("we already checked that the type is Constant") - .value() - .convert_to_type(cast_expr_cast_to), - ) - .into_pred_node(); - false - } - PredicateType::AttributeRef => { - let attr_ref_expr = AttributeRefPred::from_pred_node(cast_expr_child) - .expect("we already checked that the type is AttributeRef"); - let attr_ref_idx = attr_ref_expr.attr_index(); - let table_id = attr_ref_expr.table_id(); - cast_node = attr_ref_expr.into_pred_node(); - // The "invert" cast is to invert the cast so that we're casting the - // non_cast_node to 
the attribute's original type. - // TODO(migration): double check - let invert_cast_data_type = &(self - .storage_manager - .get_attribute_info(table_id, attr_ref_idx as i32)? - .typ - .into_data_type()); - - match non_cast_node.typ { - PredicateType::AttributeRef => { - // In general, there's no way to remove the Cast here. We can't move - // the Cast to the other AttributeRef - // because that would lead to an infinite loop. Thus, we just leave - // the cast where it is and break. - true - } - _ => { - non_cast_node = - CastPred::new(non_cast_node, invert_cast_data_type.clone()) - .into_pred_node(); - false - } - } - } - _ => todo!(), - }; - - (uncasted_left, uncasted_right) = if is_left_cast { - (cast_node, non_cast_node) - } else { - (non_cast_node, cast_node) - }; - - if should_break { - break; - } - } else { - break; - } - } - - // Sort nodes into attr_ref_exprs, values, and non_attr_ref_exprs - match uncasted_left.as_ref().typ { - PredicateType::AttributeRef => { - is_left_attr_ref = true; - attr_ref_exprs.push( - AttributeRefPred::from_pred_node(uncasted_left) - .expect("we already checked that the type is AttributeRef"), - ); - } - PredicateType::Constant(_) => { - is_left_attr_ref = false; - values.push( - ConstantPred::from_pred_node(uncasted_left) - .expect("we already checked that the type is Constant") - .value(), - ) - } - _ => { - is_left_attr_ref = false; - non_attr_ref_exprs.push(uncasted_left); - } - } - match uncasted_right.as_ref().typ { - PredicateType::AttributeRef => { - attr_ref_exprs.push( - AttributeRefPred::from_pred_node(uncasted_right) - .expect("we already checked that the type is AttributeRef"), - ); - } - PredicateType::Constant(_) => values.push( - ConstantPred::from_pred_node(uncasted_right) - .expect("we already checked that the type is Constant") - .value(), - ), - _ => { - non_attr_ref_exprs.push(uncasted_right); - } - } - - assert!(attr_ref_exprs.len() + values.len() + non_attr_ref_exprs.len() == 2); - Ok((attr_ref_exprs, 
values, non_attr_ref_exprs, is_left_attr_ref)) - } - - /// The default selectivity of a comparison expression - /// Used when one side of the comparison is a attribute while the other side is something too - /// complex/impossible to evaluate (subquery, UDF, another attribute, we have no stats, etc.) - fn get_default_comparison_op_selectivity(comp_bin_op_typ: BinOpType) -> f64 { - assert!(comp_bin_op_typ.is_comparison()); - match comp_bin_op_typ { - BinOpType::Eq => DEFAULT_EQ_SEL, - BinOpType::Neq => 1.0 - DEFAULT_EQ_SEL, - BinOpType::Lt | BinOpType::Leq | BinOpType::Gt | BinOpType::Geq => DEFAULT_INEQ_SEL, - _ => unreachable!( - "all comparison BinOpTypes were enumerated. this should be unreachable" - ), - } - } -} diff --git a/optd-cost-model/src/cost/filter/attribute.rs b/optd-cost-model/src/cost/filter/attribute.rs new file mode 100644 index 0000000..b72802e --- /dev/null +++ b/optd-cost-model/src/cost/filter/attribute.rs @@ -0,0 +1,165 @@ +use std::ops::Bound; + +use optd_persistent::CostModelStorageLayer; + +use crate::{ + common::{types::TableId, values::Value}, + cost_model::CostModelImpl, + // TODO: If we return the default value, consider tell the upper level that we cannot + // compute the selectivity. 
+ stats::{AttributeCombValue, AttributeCombValueStats, DEFAULT_EQ_SEL, DEFAULT_INEQ_SEL}, + CostModelResult, +}; + +impl CostModelImpl { + /// Get the selectivity of an expression of the form "attribute equals value" (or "value equals + /// attribute") Will handle the case of statistics missing + /// Equality predicates are handled entirely differently from range predicates so this is its + /// own function + /// Also, get_attribute_equality_selectivity is a subroutine when computing range + /// selectivity, which is another reason for separating these into two functions + /// is_eq means whether it's == or != + pub(crate) fn get_attribute_equality_selectivity( + &self, + table_id: TableId, + attr_base_index: usize, + value: &Value, + is_eq: bool, + ) -> CostModelResult { + // TODO: The attribute could be a derived attribute + let ret_sel = { + if let Some(attribute_stats) = + self.get_attribute_comb_stats(table_id, &[attr_base_index])? + { + let eq_freq = + if let Some(freq) = attribute_stats.mcvs.freq(&vec![Some(value.clone())]) { + freq + } else { + let non_mcv_freq = 1.0 - attribute_stats.mcvs.total_freq(); + // always safe because usize is at least as large as i32 + let ndistinct_as_usize = attribute_stats.ndistinct as usize; + let non_mcv_cnt = ndistinct_as_usize - attribute_stats.mcvs.cnt(); + if non_mcv_cnt == 0 { + return Ok(0.0); + } + // note that nulls are not included in ndistinct so we don't need to do non_mcv_cnt + // - 1 if null_frac > 0 + (non_mcv_freq - attribute_stats.null_frac) / (non_mcv_cnt as f64) + }; + if is_eq { + eq_freq + } else { + 1.0 - eq_freq - attribute_stats.null_frac + } + } else { + #[allow(clippy::collapsible_else_if)] + if is_eq { + DEFAULT_EQ_SEL + } else { + 1.0 - DEFAULT_EQ_SEL + } + } + }; + + assert!( + (0.0..=1.0).contains(&ret_sel), + "ret_sel ({}) should be in [0, 1]", + ret_sel + ); + Ok(ret_sel) + } + + /// Compute the frequency of values in a attribute less than or equal to the given value. 
+ fn get_attribute_leq_value_freq( + per_attribute_stats: &AttributeCombValueStats, + value: &Value, + ) -> f64 { + // because distr does not include the values in MCVs, we need to compute the CDFs there as + // well because nulls return false in any comparison, they are never included when + // computing range selectivity + let distr_leq_freq = per_attribute_stats.distr.as_ref().unwrap().cdf(value); + let value = value.clone(); + let pred = Box::new(move |val: &AttributeCombValue| *val[0].as_ref().unwrap() <= value); + let mcvs_leq_freq = per_attribute_stats.mcvs.freq_over_pred(pred); + let ret_freq = distr_leq_freq + mcvs_leq_freq; + assert!( + (0.0..=1.0).contains(&ret_freq), + "ret_freq ({}) should be in [0, 1]", + ret_freq + ); + ret_freq + } + + /// Compute the frequency of values in a attribute less than the given value. + fn get_attribute_lt_value_freq( + &self, + attribute_stats: &AttributeCombValueStats, + table_id: TableId, + attr_base_index: usize, + value: &Value, + ) -> CostModelResult { + // depending on whether value is in mcvs or not, we use different logic to turn total_lt_cdf + // into total_leq_cdf this logic just so happens to be the exact same logic as + // get_attribute_equality_selectivity implements + let ret_freq = Self::get_attribute_leq_value_freq(attribute_stats, value) + - self.get_attribute_equality_selectivity(table_id, attr_base_index, value, true)?; + assert!( + (0.0..=1.0).contains(&ret_freq), + "ret_freq ({}) should be in [0, 1]", + ret_freq + ); + Ok(ret_freq) + } + + /// Get the selectivity of an expression of the form "attribute =/> value" (or "value + /// =/> attribute"). Computes selectivity based off of statistics. + /// Range predicates are handled entirely differently from equality predicates so this is its + /// own function. If it is unable to find the statistics, it returns DEFAULT_INEQ_SEL. + /// The selectivity is computed as quantile of the right bound minus quantile of the left bound. 
+ pub(crate) fn get_attribute_range_selectivity( + &self, + table_id: TableId, + attr_base_index: usize, + start: Bound<&Value>, + end: Bound<&Value>, + ) -> CostModelResult { + // TODO: Consider attribute is a derived attribute + if let Some(attribute_stats) = + self.get_attribute_comb_stats(table_id, &[attr_base_index])? + { + let left_quantile = match start { + Bound::Unbounded => 0.0, + Bound::Included(value) => self.get_attribute_lt_value_freq( + &attribute_stats, + table_id, + attr_base_index, + value, + )?, + Bound::Excluded(value) => { + Self::get_attribute_leq_value_freq(&attribute_stats, value) + } + }; + let right_quantile = match end { + Bound::Unbounded => 1.0, + Bound::Included(value) => { + Self::get_attribute_leq_value_freq(&attribute_stats, value) + } + Bound::Excluded(value) => self.get_attribute_lt_value_freq( + &attribute_stats, + table_id, + attr_base_index, + value, + )?, + }; + assert!( + left_quantile <= right_quantile, + "left_quantile ({}) should be <= right_quantile ({})", + left_quantile, + right_quantile + ); + Ok(right_quantile - left_quantile) + } else { + Ok(DEFAULT_INEQ_SEL) + } + } +} diff --git a/optd-cost-model/src/cost/filter/comp_op.rs b/optd-cost-model/src/cost/filter/comp_op.rs new file mode 100644 index 0000000..7bde869 --- /dev/null +++ b/optd-cost-model/src/cost/filter/comp_op.rs @@ -0,0 +1,274 @@ +use std::ops::Bound; + +use optd_persistent::CostModelStorageLayer; + +use crate::{ + common::{ + nodes::{ArcPredicateNode, PredicateType, ReprPredicateNode}, + predicates::{ + attr_ref_pred::AttributeRefPred, bin_op_pred::BinOpType, cast_pred::CastPred, + constant_pred::ConstantPred, + }, + values::Value, + }, + cost_model::CostModelImpl, + // TODO: If we return the default value, consider tell the upper level that we cannot + // compute the selectivity. 
+ stats::{DEFAULT_EQ_SEL, DEFAULT_INEQ_SEL, UNIMPLEMENTED_SEL}, + CostModelResult, +}; + +impl CostModelImpl { + /// Comparison operators are the base case for recursion in get_filter_selectivity() + pub(crate) fn get_comp_op_selectivity( + &self, + comp_bin_op_typ: BinOpType, + left: ArcPredicateNode, + right: ArcPredicateNode, + ) -> CostModelResult { + assert!(comp_bin_op_typ.is_comparison()); + + // I intentionally performed moves on left and right. This way, we don't accidentally use + // them after this block + let (attr_ref_exprs, values, non_attr_ref_exprs, is_left_attr_ref) = + self.get_semantic_nodes(left, right)?; + + // Handle the different cases of semantic nodes. + if attr_ref_exprs.is_empty() { + Ok(UNIMPLEMENTED_SEL) + } else if attr_ref_exprs.len() == 1 { + let attr_ref_expr = attr_ref_exprs + .first() + .expect("we just checked that attr_ref_exprs.len() == 1"); + let attr_ref_idx = attr_ref_expr.attr_index(); + let table_id = attr_ref_expr.table_id(); + + // TODO: Consider attribute is a derived attribute + if values.len() == 1 { + let value = values + .first() + .expect("we just checked that values.len() == 1"); + match comp_bin_op_typ { + BinOpType::Eq => { + self.get_attribute_equality_selectivity(table_id, attr_ref_idx, value, true) + } + BinOpType::Neq => self.get_attribute_equality_selectivity( + table_id, + attr_ref_idx, + value, + false, + ), + BinOpType::Lt | BinOpType::Leq | BinOpType::Gt | BinOpType::Geq => { + let start = match (comp_bin_op_typ, is_left_attr_ref) { + (BinOpType::Lt, true) | (BinOpType::Geq, false) => Bound::Unbounded, + (BinOpType::Leq, true) | (BinOpType::Gt, false) => Bound::Unbounded, + (BinOpType::Gt, true) | (BinOpType::Leq, false) => Bound::Excluded(value), + (BinOpType::Geq, true) | (BinOpType::Lt, false) => Bound::Included(value), + _ => unreachable!("all comparison BinOpTypes were enumerated. 
this should be unreachable"), + }; + let end = match (comp_bin_op_typ, is_left_attr_ref) { + (BinOpType::Lt, true) | (BinOpType::Geq, false) => Bound::Excluded(value), + (BinOpType::Leq, true) | (BinOpType::Gt, false) => Bound::Included(value), + (BinOpType::Gt, true) | (BinOpType::Leq, false) => Bound::Unbounded, + (BinOpType::Geq, true) | (BinOpType::Lt, false) => Bound::Unbounded, + _ => unreachable!("all comparison BinOpTypes were enumerated. this should be unreachable"), + }; + self.get_attribute_range_selectivity(table_id, attr_ref_idx, start, end) + } + _ => unreachable!( + "all comparison BinOpTypes were enumerated. this should be unreachable" + ), + } + } else { + let non_attr_ref_expr = non_attr_ref_exprs.first().expect( + "non_attr_ref_exprs should have a value since attr_ref_exprs.len() == 1", + ); + + match non_attr_ref_expr.as_ref().typ { + PredicateType::BinOp(_) => { + Ok(Self::get_default_comparison_op_selectivity(comp_bin_op_typ)) + } + PredicateType::Cast => Ok(UNIMPLEMENTED_SEL), + PredicateType::Constant(_) => { + unreachable!("we should have handled this in the values.len() == 1 branch") + } + _ => unimplemented!( + "unhandled case of comparing a attribute ref node to {}", + non_attr_ref_expr.as_ref().typ + ), + } + } + } else if attr_ref_exprs.len() == 2 { + Ok(Self::get_default_comparison_op_selectivity(comp_bin_op_typ)) + } else { + unreachable!("we could have at most pushed left and right into attr_ref_exprs") + } + } + + /// Convert the left and right child nodes of some operation to what they semantically are. + /// This is convenient to avoid repeating the same logic just with "left" and "right" swapped. + /// The last return value is true when the input node (left) is a AttributeRefPred. 
+ #[allow(clippy::type_complexity)] + fn get_semantic_nodes( + &self, + left: ArcPredicateNode, + right: ArcPredicateNode, + ) -> CostModelResult<( + Vec, + Vec, + Vec, + bool, + )> { + let mut attr_ref_exprs = vec![]; + let mut values = vec![]; + let mut non_attr_ref_exprs = vec![]; + let is_left_attr_ref; + + // Recursively unwrap casts as much as we can. + let mut uncasted_left = left; + let mut uncasted_right = right; + loop { + // println!("loop {}, uncasted_left={:?}, uncasted_right={:?}", Local::now(), + // uncasted_left, uncasted_right); + if uncasted_left.as_ref().typ == PredicateType::Cast + && uncasted_right.as_ref().typ == PredicateType::Cast + { + let left_cast_expr = CastPred::from_pred_node(uncasted_left) + .expect("we already checked that the type is Cast"); + let right_cast_expr = CastPred::from_pred_node(uncasted_right) + .expect("we already checked that the type is Cast"); + assert!(left_cast_expr.cast_to() == right_cast_expr.cast_to()); + uncasted_left = left_cast_expr.child().into_pred_node(); + uncasted_right = right_cast_expr.child().into_pred_node(); + } else if uncasted_left.as_ref().typ == PredicateType::Cast + || uncasted_right.as_ref().typ == PredicateType::Cast + { + let is_left_cast = uncasted_left.as_ref().typ == PredicateType::Cast; + let (mut cast_node, mut non_cast_node) = if is_left_cast { + (uncasted_left, uncasted_right) + } else { + (uncasted_right, uncasted_left) + }; + + let cast_expr = CastPred::from_pred_node(cast_node) + .expect("we already checked that the type is Cast"); + let cast_expr_child = cast_expr.child().into_pred_node(); + let cast_expr_cast_to = cast_expr.cast_to(); + + let should_break = match cast_expr_child.typ { + PredicateType::Constant(_) => { + cast_node = ConstantPred::new( + ConstantPred::from_pred_node(cast_expr_child) + .expect("we already checked that the type is Constant") + .value() + .convert_to_type(cast_expr_cast_to), + ) + .into_pred_node(); + false + } + PredicateType::AttributeRef => { + let 
attr_ref_expr = AttributeRefPred::from_pred_node(cast_expr_child) + .expect("we already checked that the type is AttributeRef"); + let attr_ref_idx = attr_ref_expr.attr_index(); + let table_id = attr_ref_expr.table_id(); + cast_node = attr_ref_expr.into_pred_node(); + // The "invert" cast is to invert the cast so that we're casting the + // non_cast_node to the attribute's original type. + // TODO(migration): double check + let invert_cast_data_type = &(self + .storage_manager + .get_attribute_info(table_id, attr_ref_idx as i32)? + .typ + .into_data_type()); + + match non_cast_node.typ { + PredicateType::AttributeRef => { + // In general, there's no way to remove the Cast here. We can't move + // the Cast to the other AttributeRef + // because that would lead to an infinite loop. Thus, we just leave + // the cast where it is and break. + true + } + _ => { + non_cast_node = + CastPred::new(non_cast_node, invert_cast_data_type.clone()) + .into_pred_node(); + false + } + } + } + _ => todo!(), + }; + + (uncasted_left, uncasted_right) = if is_left_cast { + (cast_node, non_cast_node) + } else { + (non_cast_node, cast_node) + }; + + if should_break { + break; + } + } else { + break; + } + } + + // Sort nodes into attr_ref_exprs, values, and non_attr_ref_exprs + match uncasted_left.as_ref().typ { + PredicateType::AttributeRef => { + is_left_attr_ref = true; + attr_ref_exprs.push( + AttributeRefPred::from_pred_node(uncasted_left) + .expect("we already checked that the type is AttributeRef"), + ); + } + PredicateType::Constant(_) => { + is_left_attr_ref = false; + values.push( + ConstantPred::from_pred_node(uncasted_left) + .expect("we already checked that the type is Constant") + .value(), + ) + } + _ => { + is_left_attr_ref = false; + non_attr_ref_exprs.push(uncasted_left); + } + } + match uncasted_right.as_ref().typ { + PredicateType::AttributeRef => { + attr_ref_exprs.push( + AttributeRefPred::from_pred_node(uncasted_right) + .expect("we already checked that the type is 
AttributeRef"), + ); + } + PredicateType::Constant(_) => values.push( + ConstantPred::from_pred_node(uncasted_right) + .expect("we already checked that the type is Constant") + .value(), + ), + _ => { + non_attr_ref_exprs.push(uncasted_right); + } + } + + assert!(attr_ref_exprs.len() + values.len() + non_attr_ref_exprs.len() == 2); + Ok((attr_ref_exprs, values, non_attr_ref_exprs, is_left_attr_ref)) + } + + /// The default selectivity of a comparison expression + /// Used when one side of the comparison is a attribute while the other side is something too + /// complex/impossible to evaluate (subquery, UDF, another attribute, we have no stats, etc.) + fn get_default_comparison_op_selectivity(comp_bin_op_typ: BinOpType) -> f64 { + assert!(comp_bin_op_typ.is_comparison()); + match comp_bin_op_typ { + BinOpType::Eq => DEFAULT_EQ_SEL, + BinOpType::Neq => 1.0 - DEFAULT_EQ_SEL, + BinOpType::Lt | BinOpType::Leq | BinOpType::Gt | BinOpType::Geq => DEFAULT_INEQ_SEL, + _ => unreachable!( + "all comparison BinOpTypes were enumerated. 
this should be unreachable" + ), + } + } +} diff --git a/optd-cost-model/src/cost/filter/constant.rs b/optd-cost-model/src/cost/filter/constant.rs new file mode 100644 index 0000000..b9c2cc8 --- /dev/null +++ b/optd-cost-model/src/cost/filter/constant.rs @@ -0,0 +1,39 @@ +use optd_persistent::CostModelStorageLayer; + +use crate::{ + common::{ + nodes::{ArcPredicateNode, PredicateType}, + predicates::constant_pred::ConstantType, + values::Value, + }, + cost_model::CostModelImpl, +}; + +impl CostModelImpl { + pub(crate) fn get_constant_selectivity(const_node: ArcPredicateNode) -> f64 { + if let PredicateType::Constant(const_typ) = const_node.typ { + if matches!(const_typ, ConstantType::Bool) { + let value = const_node + .as_ref() + .data + .as_ref() + .expect("constants should have data"); + if let Value::Bool(bool_value) = value { + if *bool_value { + 1.0 + } else { + 0.0 + } + } else { + unreachable!( + "if the typ is ConstantType::Bool, the value should be a Value::Bool" + ) + } + } else { + panic!("selectivity is not defined on constants which are not bools") + } + } else { + panic!("get_constant_selectivity must be called on a constant") + } + } +} diff --git a/optd-cost-model/src/cost/filter/controller.rs b/optd-cost-model/src/cost/filter/controller.rs new file mode 100644 index 0000000..c0ce8c9 --- /dev/null +++ b/optd-cost-model/src/cost/filter/controller.rs @@ -0,0 +1,85 @@ +use optd_persistent::CostModelStorageLayer; + +use crate::{ + common::{ + nodes::{ArcPredicateNode, PredicateType, ReprPredicateNode}, + predicates::{in_list_pred::InListPred, like_pred::LikePred, un_op_pred::UnOpType}, + }, + cost_model::CostModelImpl, + stats::UNIMPLEMENTED_SEL, + CostModelResult, EstimatedStatistic, +}; + +impl CostModelImpl { + // TODO: is it a good design to pass table_id here? I think it needs to be refactored. + // Consider to remove table_id. 
+ pub fn get_filter_row_cnt( + &self, + child_row_cnt: EstimatedStatistic, + cond: ArcPredicateNode, + ) -> CostModelResult { + let selectivity = { self.get_filter_selectivity(cond)? }; + Ok( + EstimatedStatistic((child_row_cnt.0 as f64 * selectivity) as u64) + .max(EstimatedStatistic(1)), + ) + } + + pub fn get_filter_selectivity(&self, expr_tree: ArcPredicateNode) -> CostModelResult { + match &expr_tree.typ { + PredicateType::Constant(_) => Ok(Self::get_constant_selectivity(expr_tree)), + PredicateType::AttributeRef => unimplemented!("check bool type or else panic"), + PredicateType::UnOp(un_op_typ) => { + assert!(expr_tree.children.len() == 1); + let child = expr_tree.child(0); + match un_op_typ { + // not doesn't care about nulls so there's no complex logic. it just reverses + // the selectivity for instance, != _will not_ include nulls + // but "NOT ==" _will_ include nulls + UnOpType::Not => Ok(1.0 - self.get_filter_selectivity(child)?), + UnOpType::Neg => panic!( + "the selectivity of operations that return numerical values is undefined" + ), + } + } + PredicateType::BinOp(bin_op_typ) => { + assert!(expr_tree.children.len() == 2); + let left_child = expr_tree.child(0); + let right_child = expr_tree.child(1); + + if bin_op_typ.is_comparison() { + self.get_comp_op_selectivity(*bin_op_typ, left_child, right_child) + } else if bin_op_typ.is_numerical() { + panic!( + "the selectivity of operations that return numerical values is undefined" + ) + } else { + unreachable!("all BinOpTypes should be true for at least one is_*() function") + } + } + PredicateType::LogOp(log_op_typ) => { + self.get_log_op_selectivity(*log_op_typ, &expr_tree.children) + } + PredicateType::Func(_) => unimplemented!("check bool type or else panic"), + PredicateType::SortOrder(_) => { + panic!("the selectivity of sort order expressions is undefined") + } + PredicateType::Between => Ok(UNIMPLEMENTED_SEL), + PredicateType::Cast => unimplemented!("check bool type or else panic"), + 
PredicateType::Like => { + let like_expr = LikePred::from_pred_node(expr_tree).unwrap(); + self.get_like_selectivity(&like_expr) + } + PredicateType::DataType(_) => { + panic!("the selectivity of a data type is not defined") + } + PredicateType::InList => { + let in_list_expr = InListPred::from_pred_node(expr_tree).unwrap(); + self.get_in_list_selectivity(&in_list_expr) + } + _ => unreachable!( + "all expression DfPredType were enumerated. this should be unreachable" + ), + } + } +} diff --git a/optd-cost-model/src/cost/filter/in_list.rs b/optd-cost-model/src/cost/filter/in_list.rs new file mode 100644 index 0000000..cc2a570 --- /dev/null +++ b/optd-cost-model/src/cost/filter/in_list.rs @@ -0,0 +1,67 @@ +use optd_persistent::CostModelStorageLayer; + +use crate::{ + common::{ + nodes::{PredicateType, ReprPredicateNode}, + predicates::{ + attr_ref_pred::AttributeRefPred, constant_pred::ConstantPred, in_list_pred::InListPred, + }, + }, + cost_model::CostModelImpl, + stats::UNIMPLEMENTED_SEL, + CostModelResult, +}; + +impl CostModelImpl { + /// Only support attrA in (val1, val2, val3) where attrA is a attribute ref and + /// val1, val2, val3 are constants. + pub(crate) fn get_in_list_selectivity(&self, expr: &InListPred) -> CostModelResult { + let child = expr.child(); + + // Check child is a attribute ref. + if !matches!(child.typ, PredicateType::AttributeRef) { + return Ok(UNIMPLEMENTED_SEL); + } + + // Check all expressions in the list are constants. + let list_exprs = expr.list().to_vec(); + if list_exprs + .iter() + .any(|expr| !matches!(expr.typ, PredicateType::Constant(_))) + { + return Ok(UNIMPLEMENTED_SEL); + } + + // Convert child and const expressions to concrete types. 
+ let attr_ref_pred = AttributeRefPred::from_pred_node(child).unwrap(); + let attr_ref_idx = attr_ref_pred.attr_index(); + let table_id = attr_ref_pred.table_id(); + let list_exprs = list_exprs + .into_iter() + .map(|expr| { + ConstantPred::from_pred_node(expr) + .expect("we already checked all list elements are constants") + }) + .collect::>(); + let negated = expr.negated(); + + // TODO: Consider attribute is a derived attribute + let in_sel = list_exprs + .iter() + .try_fold(0.0, |acc, expr| { + let selectivity = self.get_attribute_equality_selectivity( + table_id, + attr_ref_idx, + &expr.value(), + /* is_equality */ true, + )?; + Ok(acc + selectivity) + })? + .min(1.0); + if negated { + Ok(1.0 - in_sel) + } else { + Ok(in_sel) + } + } +} diff --git a/optd-cost-model/src/cost/filter/like.rs b/optd-cost-model/src/cost/filter/like.rs new file mode 100644 index 0000000..f8a1ab4 --- /dev/null +++ b/optd-cost-model/src/cost/filter/like.rs @@ -0,0 +1,98 @@ +use datafusion::arrow::{array::StringArray, compute::like}; +use optd_persistent::CostModelStorageLayer; + +use crate::{ + common::{ + nodes::{PredicateType, ReprPredicateNode}, + predicates::{ + attr_ref_pred::AttributeRefPred, constant_pred::ConstantPred, like_pred::LikePred, + }, + }, + cost_model::CostModelImpl, + stats::{ + AttributeCombValue, FIXED_CHAR_SEL_FACTOR, FULL_WILDCARD_SEL_FACTOR, UNIMPLEMENTED_SEL, + }, + CostModelResult, +}; + +impl CostModelImpl { + /// Compute the selectivity of a (NOT) LIKE expression. + /// + /// The logic is somewhat similar to Postgres but different. Postgres first estimates the + /// histogram part of the population and then add up data for any MCV values. If the + /// histogram is large enough, it just uses the number of matches in the histogram, + /// otherwise it estimates the fixed prefix and remainder of pattern separately and + /// combine them. + /// + /// Our approach is simpler and less selective. Firstly, we don't use histogram. 
The selectivity + /// is composed of MCV frequency and non-MCV selectivity. MCV frequency is computed by + /// adding up frequencies of MCVs that match the pattern. Non-MCV selectivity is computed + /// in the same way that Postgres computes selectivity for the wildcard part of the pattern. + pub(crate) fn get_like_selectivity(&self, like_expr: &LikePred) -> CostModelResult { + let child = like_expr.child(); + + // Check child is a attribute ref. + if !matches!(child.typ, PredicateType::AttributeRef) { + return Ok(UNIMPLEMENTED_SEL); + } + + // Check pattern is a constant. + let pattern = like_expr.pattern(); + if !matches!(pattern.typ, PredicateType::Constant(_)) { + return Ok(UNIMPLEMENTED_SEL); + } + + let attr_ref_pred = AttributeRefPred::from_pred_node(child).unwrap(); + let attr_ref_idx = attr_ref_pred.attr_index(); + let table_id = attr_ref_pred.table_id(); + + // TODO: Consider attribute is a derived attribute + let pattern = ConstantPred::from_pred_node(pattern) + .expect("we already checked pattern is a constant") + .value() + .as_str(); + + // Compute the selectivity exculuding MCVs. + // See Postgres `like_selectivity`. + let non_mcv_sel = pattern + .chars() + .fold(1.0, |acc, c| { + if c == '%' { + acc * FULL_WILDCARD_SEL_FACTOR + } else { + acc * FIXED_CHAR_SEL_FACTOR + } + }) + .min(1.0); + + // Compute the selectivity in MCVs. + // TODO: Handle the case where `attribute_stats` is None. + if let Some(attribute_stats) = self.get_attribute_comb_stats(table_id, &[attr_ref_idx])? 
{ + let (mcv_freq, null_frac) = { + let pred = Box::new(move |val: &AttributeCombValue| { + let string = + StringArray::from(vec![val[0].as_ref().unwrap().as_str().as_ref()]); + let pattern = StringArray::from(vec![pattern.as_ref()]); + like(&string, &pattern).unwrap().value(0) + }); + ( + attribute_stats.mcvs.freq_over_pred(pred), + attribute_stats.null_frac, + ) + }; + + let result = non_mcv_sel + mcv_freq; + + Ok(if like_expr.negated() { + 1.0 - result - null_frac + } else { + result + } + // Postgres clamps the result after histogram and before MCV. See Postgres + // `patternsel_common`. + .clamp(0.0001, 0.9999)) + } else { + Ok(UNIMPLEMENTED_SEL) + } + } +} diff --git a/optd-cost-model/src/cost/filter/log_op.rs b/optd-cost-model/src/cost/filter/log_op.rs new file mode 100644 index 0000000..46e6a21 --- /dev/null +++ b/optd-cost-model/src/cost/filter/log_op.rs @@ -0,0 +1,29 @@ +use optd_persistent::CostModelStorageLayer; + +use crate::{ + common::{nodes::ArcPredicateNode, predicates::log_op_pred::LogOpType}, + cost_model::CostModelImpl, + CostModelResult, +}; + +impl CostModelImpl { + pub(crate) fn get_log_op_selectivity( + &self, + log_op_typ: LogOpType, + children: &[ArcPredicateNode], + ) -> CostModelResult { + match log_op_typ { + LogOpType::And => children.iter().try_fold(1.0, |acc, child| { + let selectivity = self.get_filter_selectivity(child.clone())?; + Ok(acc * selectivity) + }), + LogOpType::Or => { + let product = children.iter().try_fold(1.0, |acc, child| { + let selectivity = self.get_filter_selectivity(child.clone())?; + Ok(acc * (1.0 - selectivity)) + })?; + Ok(1.0 - product) + } + } + } +} diff --git a/optd-cost-model/src/cost/filter/mod.rs b/optd-cost-model/src/cost/filter/mod.rs new file mode 100644 index 0000000..bf1d5ab --- /dev/null +++ b/optd-cost-model/src/cost/filter/mod.rs @@ -0,0 +1,7 @@ +pub mod attribute; +pub mod comp_op; +pub mod constant; +pub mod controller; +pub mod in_list; +pub mod like; +pub mod log_op; From 
dd6598a40237d13ab73ff426da0f05d3465fb48b Mon Sep 17 00:00:00 2001 From: Lan Lou Date: Sat, 16 Nov 2024 12:50:13 -0500 Subject: [PATCH 21/51] Resolve conflict with main --- Cargo.lock | 22 ++++ .../src/common/predicates/constant_pred.rs | 10 ++ optd-cost-model/src/cost/agg.rs | 20 ++-- optd-cost-model/src/cost/filter/attribute.rs | 50 ++++---- optd-cost-model/src/cost/filter/comp_op.rs | 42 ++++--- optd-cost-model/src/cost/filter/controller.rs | 113 +++++++++--------- optd-cost-model/src/cost/filter/in_list.rs | 19 +-- optd-cost-model/src/cost/filter/like.rs | 7 +- optd-cost-model/src/cost/filter/log_op.rs | 25 ++-- optd-cost-model/src/cost_model.rs | 3 +- optd-cost-model/src/lib.rs | 9 +- optd-cost-model/src/storage.rs | 36 ++++-- optd-persistent/Cargo.toml | 1 + optd-persistent/src/cost_model/interface.rs | 7 +- optd-persistent/src/cost_model/orm.rs | 32 +++-- 15 files changed, 250 insertions(+), 146 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index bf0b367..d47ecc6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2177,6 +2177,27 @@ dependencies = [ "libc", ] +[[package]] +name = "num_enum" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4e613fc340b2220f734a8595782c551f1250e969d87d3be1ae0579e8d4065179" +dependencies = [ + "num_enum_derive", +] + +[[package]] +name = "num_enum_derive" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "af1844ef2428cc3e1cb900be36181049ef3d3193c63e43026cfe202983b27a56" +dependencies = [ + "proc-macro-crate", + "proc-macro2", + "quote", + "syn 2.0.87", +] + [[package]] name = "object" version = "0.36.5" @@ -2237,6 +2258,7 @@ version = "0.1.0" dependencies = [ "async-stream", "async-trait", + "num_enum", "sea-orm", "sea-orm-migration", "serde_json", diff --git a/optd-cost-model/src/common/predicates/constant_pred.rs b/optd-cost-model/src/common/predicates/constant_pred.rs index 2fa06ae..61285f7 100644 --- 
a/optd-cost-model/src/common/predicates/constant_pred.rs +++ b/optd-cost-model/src/common/predicates/constant_pred.rs @@ -1,6 +1,7 @@ use std::sync::Arc; use arrow_schema::{DataType, IntervalUnit}; +use optd_persistent::cost_model::interface::AttrType; use serde::{Deserialize, Serialize}; use crate::common::{ @@ -90,6 +91,15 @@ impl ConstantType { ConstantType::Utf8String => DataType::Utf8, } } + + pub fn from_persistent_attr_type(attr_type: AttrType) -> Self { + match attr_type { + AttrType::Integer => ConstantType::Int32, + AttrType::Float => ConstantType::Float64, + AttrType::Varchar => ConstantType::Utf8String, + AttrType::Boolean => ConstantType::Bool, + } + } } #[derive(Clone, Debug)] diff --git a/optd-cost-model/src/cost/agg.rs b/optd-cost-model/src/cost/agg.rs index a3de5aa..ff82fec 100644 --- a/optd-cost-model/src/cost/agg.rs +++ b/optd-cost-model/src/cost/agg.rs @@ -8,11 +8,11 @@ use crate::{ }, cost_model::CostModelImpl, stats::DEFAULT_NUM_DISTINCT, - CostModelError, CostModelResult, EstimatedStatistic, + CostModelError, CostModelResult, EstimatedStatistic, SemanticError, }; impl CostModelImpl { - pub fn get_agg_row_cnt( + pub async fn get_agg_row_cnt( &self, group_by: ArcPredicateNode, ) -> CostModelResult { @@ -22,22 +22,24 @@ impl CostModelImpl { } else { // Multiply the n-distinct of all the group by columns. 
// TODO: improve with multi-dimensional n-distinct - let row_cnt = group_by.0.children.iter().try_fold(1, |acc, node| { + let mut row_cnt = 1; + + for node in &group_by.0.children { match node.typ { PredicateType::AttributeRef => { let attr_ref = AttributeRefPred::from_pred_node(node.clone()).ok_or_else(|| { - CostModelError::InvalidPredicate( + SemanticError::InvalidPredicate( "Expected AttributeRef predicate".to_string(), ) })?; if attr_ref.is_derived() { - Ok(acc * DEFAULT_NUM_DISTINCT) + row_cnt *= DEFAULT_NUM_DISTINCT; } else { let table_id = attr_ref.table_id(); let attr_idx = attr_ref.attr_index(); let stats_option = - self.get_attribute_comb_stats(table_id, &[attr_idx])?; + self.get_attribute_comb_stats(table_id, &[attr_idx]).await?; let ndistinct = match stats_option { Some(stats) => stats.ndistinct, @@ -46,15 +48,15 @@ impl CostModelImpl { DEFAULT_NUM_DISTINCT } }; - Ok(acc * ndistinct) + row_cnt *= ndistinct; } } _ => { // TODO: Consider the case where `GROUP BY 1`. - panic!("GROUP BY must have attribute ref predicate") + panic!("GROUP BY must have attribute ref predicate"); } } - })?; + } Ok(EstimatedStatistic(row_cnt)) } } diff --git a/optd-cost-model/src/cost/filter/attribute.rs b/optd-cost-model/src/cost/filter/attribute.rs index b72802e..c5ad90c 100644 --- a/optd-cost-model/src/cost/filter/attribute.rs +++ b/optd-cost-model/src/cost/filter/attribute.rs @@ -19,7 +19,7 @@ impl CostModelImpl { /// Also, get_attribute_equality_selectivity is a subroutine when computing range /// selectivity, which is another reason for separating these into two functions /// is_eq means whether it's == or != - pub(crate) fn get_attribute_equality_selectivity( + pub(crate) async fn get_attribute_equality_selectivity( &self, table_id: TableId, attr_base_index: usize, @@ -28,8 +28,9 @@ impl CostModelImpl { ) -> CostModelResult { // TODO: The attribute could be a derived attribute let ret_sel = { - if let Some(attribute_stats) = - self.get_attribute_comb_stats(table_id, 
&[attr_base_index])? + if let Some(attribute_stats) = self + .get_attribute_comb_stats(table_id, &[attr_base_index]) + .await? { let eq_freq = if let Some(freq) = attribute_stats.mcvs.freq(&vec![Some(value.clone())]) { @@ -91,7 +92,7 @@ impl CostModelImpl { } /// Compute the frequency of values in a attribute less than the given value. - fn get_attribute_lt_value_freq( + async fn get_attribute_lt_value_freq( &self, attribute_stats: &AttributeCombValueStats, table_id: TableId, @@ -102,7 +103,9 @@ impl CostModelImpl { // into total_leq_cdf this logic just so happens to be the exact same logic as // get_attribute_equality_selectivity implements let ret_freq = Self::get_attribute_leq_value_freq(attribute_stats, value) - - self.get_attribute_equality_selectivity(table_id, attr_base_index, value, true)?; + - self + .get_attribute_equality_selectivity(table_id, attr_base_index, value, true) + .await?; assert!( (0.0..=1.0).contains(&ret_freq), "ret_freq ({}) should be in [0, 1]", @@ -116,7 +119,7 @@ impl CostModelImpl { /// Range predicates are handled entirely differently from equality predicates so this is its /// own function. If it is unable to find the statistics, it returns DEFAULT_INEQ_SEL. /// The selectivity is computed as quantile of the right bound minus quantile of the left bound. - pub(crate) fn get_attribute_range_selectivity( + pub(crate) async fn get_attribute_range_selectivity( &self, table_id: TableId, attr_base_index: usize, @@ -124,17 +127,21 @@ impl CostModelImpl { end: Bound<&Value>, ) -> CostModelResult { // TODO: Consider attribute is a derived attribute - if let Some(attribute_stats) = - self.get_attribute_comb_stats(table_id, &[attr_base_index])? + if let Some(attribute_stats) = self + .get_attribute_comb_stats(table_id, &[attr_base_index]) + .await? 
{ let left_quantile = match start { Bound::Unbounded => 0.0, - Bound::Included(value) => self.get_attribute_lt_value_freq( - &attribute_stats, - table_id, - attr_base_index, - value, - )?, + Bound::Included(value) => { + self.get_attribute_lt_value_freq( + &attribute_stats, + table_id, + attr_base_index, + value, + ) + .await? + } Bound::Excluded(value) => { Self::get_attribute_leq_value_freq(&attribute_stats, value) } @@ -144,12 +151,15 @@ impl CostModelImpl { Bound::Included(value) => { Self::get_attribute_leq_value_freq(&attribute_stats, value) } - Bound::Excluded(value) => self.get_attribute_lt_value_freq( - &attribute_stats, - table_id, - attr_base_index, - value, - )?, + Bound::Excluded(value) => { + self.get_attribute_lt_value_freq( + &attribute_stats, + table_id, + attr_base_index, + value, + ) + .await? + } }; assert!( left_quantile <= right_quantile, diff --git a/optd-cost-model/src/cost/filter/comp_op.rs b/optd-cost-model/src/cost/filter/comp_op.rs index 7bde869..0a2092e 100644 --- a/optd-cost-model/src/cost/filter/comp_op.rs +++ b/optd-cost-model/src/cost/filter/comp_op.rs @@ -16,11 +16,12 @@ use crate::{ // compute the selectivity. stats::{DEFAULT_EQ_SEL, DEFAULT_INEQ_SEL, UNIMPLEMENTED_SEL}, CostModelResult, + SemanticError, }; impl CostModelImpl { /// Comparison operators are the base case for recursion in get_filter_selectivity() - pub(crate) fn get_comp_op_selectivity( + pub(crate) async fn get_comp_op_selectivity( &self, comp_bin_op_typ: BinOpType, left: ArcPredicateNode, @@ -30,8 +31,11 @@ impl CostModelImpl { // I intentionally performed moves on left and right. 
This way, we don't accidentally use // them after this block - let (attr_ref_exprs, values, non_attr_ref_exprs, is_left_attr_ref) = - self.get_semantic_nodes(left, right)?; + let semantic_res = self.get_semantic_nodes(left, right).await; + if semantic_res.is_err() { + return Ok(Self::get_default_comparison_op_selectivity(comp_bin_op_typ)); + } + let (attr_ref_exprs, values, non_attr_ref_exprs, is_left_attr_ref) = semantic_res.unwrap(); // Handle the different cases of semantic nodes. if attr_ref_exprs.is_empty() { @@ -51,13 +55,17 @@ impl CostModelImpl { match comp_bin_op_typ { BinOpType::Eq => { self.get_attribute_equality_selectivity(table_id, attr_ref_idx, value, true) + .await + } + BinOpType::Neq => { + self.get_attribute_equality_selectivity( + table_id, + attr_ref_idx, + value, + false, + ) + .await } - BinOpType::Neq => self.get_attribute_equality_selectivity( - table_id, - attr_ref_idx, - value, - false, - ), BinOpType::Lt | BinOpType::Leq | BinOpType::Gt | BinOpType::Geq => { let start = match (comp_bin_op_typ, is_left_attr_ref) { (BinOpType::Lt, true) | (BinOpType::Geq, false) => Bound::Unbounded, @@ -74,6 +82,7 @@ impl CostModelImpl { _ => unreachable!("all comparison BinOpTypes were enumerated. this should be unreachable"), }; self.get_attribute_range_selectivity(table_id, attr_ref_idx, start, end) + .await } _ => unreachable!( "all comparison BinOpTypes were enumerated. this should be unreachable" @@ -109,7 +118,7 @@ impl CostModelImpl { /// This is convenient to avoid repeating the same logic just with "left" and "right" swapped. /// The last return value is true when the input node (left) is a AttributeRefPred. #[allow(clippy::type_complexity)] - fn get_semantic_nodes( + async fn get_semantic_nodes( &self, left: ArcPredicateNode, right: ArcPredicateNode, @@ -175,11 +184,16 @@ impl CostModelImpl { // The "invert" cast is to invert the cast so that we're casting the // non_cast_node to the attribute's original type. 
// TODO(migration): double check - let invert_cast_data_type = &(self + // TODO: Consider attribute info is None. + let attribute_info = self .storage_manager - .get_attribute_info(table_id, attr_ref_idx as i32)? - .typ - .into_data_type()); + .get_attribute_info(table_id, attr_ref_idx as i32) + .await? + .ok_or({ + SemanticError::AttributeNotFound(table_id, attr_ref_idx as i32) + })?; + + let invert_cast_data_type = &attribute_info.typ.into_data_type(); match non_cast_node.typ { PredicateType::AttributeRef => { diff --git a/optd-cost-model/src/cost/filter/controller.rs b/optd-cost-model/src/cost/filter/controller.rs index c0ce8c9..5369fe5 100644 --- a/optd-cost-model/src/cost/filter/controller.rs +++ b/optd-cost-model/src/cost/filter/controller.rs @@ -13,73 +13,78 @@ use crate::{ impl CostModelImpl { // TODO: is it a good design to pass table_id here? I think it needs to be refactored. // Consider to remove table_id. - pub fn get_filter_row_cnt( + pub async fn get_filter_row_cnt( &self, child_row_cnt: EstimatedStatistic, cond: ArcPredicateNode, ) -> CostModelResult { - let selectivity = { self.get_filter_selectivity(cond)? }; + let selectivity = { self.get_filter_selectivity(cond).await? }; Ok( EstimatedStatistic((child_row_cnt.0 as f64 * selectivity) as u64) .max(EstimatedStatistic(1)), ) } - pub fn get_filter_selectivity(&self, expr_tree: ArcPredicateNode) -> CostModelResult { - match &expr_tree.typ { - PredicateType::Constant(_) => Ok(Self::get_constant_selectivity(expr_tree)), - PredicateType::AttributeRef => unimplemented!("check bool type or else panic"), - PredicateType::UnOp(un_op_typ) => { - assert!(expr_tree.children.len() == 1); - let child = expr_tree.child(0); - match un_op_typ { - // not doesn't care about nulls so there's no complex logic. 
it just reverses - // the selectivity for instance, != _will not_ include nulls - // but "NOT ==" _will_ include nulls - UnOpType::Not => Ok(1.0 - self.get_filter_selectivity(child)?), - UnOpType::Neg => panic!( - "the selectivity of operations that return numerical values is undefined" - ), + pub async fn get_filter_selectivity( + &self, + expr_tree: ArcPredicateNode, + ) -> CostModelResult { + Box::pin(async move { + match &expr_tree.typ { + PredicateType::Constant(_) => Ok(Self::get_constant_selectivity(expr_tree)), + PredicateType::AttributeRef => unimplemented!("check bool type or else panic"), + PredicateType::UnOp(un_op_typ) => { + assert!(expr_tree.children.len() == 1); + let child = expr_tree.child(0); + match un_op_typ { + // not doesn't care about nulls so there's no complex logic. it just reverses + // the selectivity for instance, != _will not_ include nulls + // but "NOT ==" _will_ include nulls + UnOpType::Not => Ok(1.0 - self.get_filter_selectivity(child).await?), + UnOpType::Neg => panic!( + "the selectivity of operations that return numerical values is undefined" + ), + } } - } - PredicateType::BinOp(bin_op_typ) => { - assert!(expr_tree.children.len() == 2); - let left_child = expr_tree.child(0); - let right_child = expr_tree.child(1); + PredicateType::BinOp(bin_op_typ) => { + assert!(expr_tree.children.len() == 2); + let left_child = expr_tree.child(0); + let right_child = expr_tree.child(1); - if bin_op_typ.is_comparison() { - self.get_comp_op_selectivity(*bin_op_typ, left_child, right_child) - } else if bin_op_typ.is_numerical() { - panic!( - "the selectivity of operations that return numerical values is undefined" - ) - } else { - unreachable!("all BinOpTypes should be true for at least one is_*() function") + if bin_op_typ.is_comparison() { + self.get_comp_op_selectivity(*bin_op_typ, left_child, right_child).await + } else if bin_op_typ.is_numerical() { + panic!( + "the selectivity of operations that return numerical values is undefined" + ) 
+ } else { + unreachable!("all BinOpTypes should be true for at least one is_*() function") + } } + PredicateType::LogOp(log_op_typ) => { + self.get_log_op_selectivity(*log_op_typ, &expr_tree.children).await + } + PredicateType::Func(_) => unimplemented!("check bool type or else panic"), + PredicateType::SortOrder(_) => { + panic!("the selectivity of sort order expressions is undefined") + } + PredicateType::Between => Ok(UNIMPLEMENTED_SEL), + PredicateType::Cast => unimplemented!("check bool type or else panic"), + PredicateType::Like => { + let like_expr = LikePred::from_pred_node(expr_tree).unwrap(); + self.get_like_selectivity(&like_expr).await + } + PredicateType::DataType(_) => { + panic!("the selectivity of a data type is not defined") + } + PredicateType::InList => { + let in_list_expr = InListPred::from_pred_node(expr_tree).unwrap(); + self.get_in_list_selectivity(&in_list_expr).await + } + _ => unreachable!( + "all expression DfPredType were enumerated. this should be unreachable" + ), } - PredicateType::LogOp(log_op_typ) => { - self.get_log_op_selectivity(*log_op_typ, &expr_tree.children) - } - PredicateType::Func(_) => unimplemented!("check bool type or else panic"), - PredicateType::SortOrder(_) => { - panic!("the selectivity of sort order expressions is undefined") - } - PredicateType::Between => Ok(UNIMPLEMENTED_SEL), - PredicateType::Cast => unimplemented!("check bool type or else panic"), - PredicateType::Like => { - let like_expr = LikePred::from_pred_node(expr_tree).unwrap(); - self.get_like_selectivity(&like_expr) - } - PredicateType::DataType(_) => { - panic!("the selectivity of a data type is not defined") - } - PredicateType::InList => { - let in_list_expr = InListPred::from_pred_node(expr_tree).unwrap(); - self.get_in_list_selectivity(&in_list_expr) - } - _ => unreachable!( - "all expression DfPredType were enumerated. 
this should be unreachable" - ), - } + }).await } } diff --git a/optd-cost-model/src/cost/filter/in_list.rs b/optd-cost-model/src/cost/filter/in_list.rs index cc2a570..f1eed06 100644 --- a/optd-cost-model/src/cost/filter/in_list.rs +++ b/optd-cost-model/src/cost/filter/in_list.rs @@ -15,7 +15,7 @@ use crate::{ impl CostModelImpl { /// Only support attrA in (val1, val2, val3) where attrA is a attribute ref and /// val1, val2, val3 are constants. - pub(crate) fn get_in_list_selectivity(&self, expr: &InListPred) -> CostModelResult { + pub(crate) async fn get_in_list_selectivity(&self, expr: &InListPred) -> CostModelResult { let child = expr.child(); // Check child is a attribute ref. @@ -46,18 +46,19 @@ impl CostModelImpl { let negated = expr.negated(); // TODO: Consider attribute is a derived attribute - let in_sel = list_exprs - .iter() - .try_fold(0.0, |acc, expr| { - let selectivity = self.get_attribute_equality_selectivity( + let mut in_sel = 0.0; + for expr in &list_exprs { + let selectivity = self + .get_attribute_equality_selectivity( table_id, attr_ref_idx, &expr.value(), /* is_equality */ true, - )?; - Ok(acc + selectivity) - })? - .min(1.0); + ) + .await?; + in_sel += selectivity; + } + in_sel = in_sel.min(1.0); if negated { Ok(1.0 - in_sel) } else { diff --git a/optd-cost-model/src/cost/filter/like.rs b/optd-cost-model/src/cost/filter/like.rs index f8a1ab4..fe9214b 100644 --- a/optd-cost-model/src/cost/filter/like.rs +++ b/optd-cost-model/src/cost/filter/like.rs @@ -28,7 +28,7 @@ impl CostModelImpl { /// is composed of MCV frequency and non-MCV selectivity. MCV frequency is computed by /// adding up frequencies of MCVs that match the pattern. Non-MCV selectivity is computed /// in the same way that Postgres computes selectivity for the wildcard part of the pattern. 
- pub(crate) fn get_like_selectivity(&self, like_expr: &LikePred) -> CostModelResult { + pub(crate) async fn get_like_selectivity(&self, like_expr: &LikePred) -> CostModelResult { let child = like_expr.child(); // Check child is a attribute ref. @@ -67,7 +67,10 @@ impl CostModelImpl { // Compute the selectivity in MCVs. // TODO: Handle the case where `attribute_stats` is None. - if let Some(attribute_stats) = self.get_attribute_comb_stats(table_id, &[attr_ref_idx])? { + if let Some(attribute_stats) = self + .get_attribute_comb_stats(table_id, &[attr_ref_idx]) + .await? + { let (mcv_freq, null_frac) = { let pred = Box::new(move |val: &AttributeCombValue| { let string = diff --git a/optd-cost-model/src/cost/filter/log_op.rs b/optd-cost-model/src/cost/filter/log_op.rs index 46e6a21..63e7cd1 100644 --- a/optd-cost-model/src/cost/filter/log_op.rs +++ b/optd-cost-model/src/cost/filter/log_op.rs @@ -7,22 +7,27 @@ use crate::{ }; impl CostModelImpl { - pub(crate) fn get_log_op_selectivity( + pub(crate) async fn get_log_op_selectivity( &self, log_op_typ: LogOpType, children: &[ArcPredicateNode], ) -> CostModelResult { match log_op_typ { - LogOpType::And => children.iter().try_fold(1.0, |acc, child| { - let selectivity = self.get_filter_selectivity(child.clone())?; - Ok(acc * selectivity) - }), + LogOpType::And => { + let mut and_sel = 1.0; + for child in children { + let selectivity = self.get_filter_selectivity(child.clone()).await?; + and_sel *= selectivity; + } + Ok(and_sel) + } LogOpType::Or => { - let product = children.iter().try_fold(1.0, |acc, child| { - let selectivity = self.get_filter_selectivity(child.clone())?; - Ok(acc * (1.0 - selectivity)) - })?; - Ok(1.0 - product) + let mut or_sel_neg = 1.0; + for child in children { + let selectivity = self.get_filter_selectivity(child.clone()).await?; + or_sel_neg *= (1.0 - selectivity); + } + Ok(1.0 - or_sel_neg) } } } diff --git a/optd-cost-model/src/cost_model.rs b/optd-cost-model/src/cost_model.rs index 
ebb45c2..1942558 100644 --- a/optd-cost-model/src/cost_model.rs +++ b/optd-cost-model/src/cost_model.rs @@ -97,12 +97,13 @@ impl CostModelImpl { /// TODO: documentation /// TODO: if we have memory cache, /// we should add the reference. (&AttributeCombValueStats) - pub(crate) fn get_attribute_comb_stats( + pub(crate) async fn get_attribute_comb_stats( &self, table_id: TableId, attr_comb: &[usize], ) -> CostModelResult> { self.storage_manager .get_attributes_comb_statistics(table_id, attr_comb) + .await } } diff --git a/optd-cost-model/src/lib.rs b/optd-cost-model/src/lib.rs index a2afcb8..d4a24d2 100644 --- a/optd-cost-model/src/lib.rs +++ b/optd-cost-model/src/lib.rs @@ -42,13 +42,14 @@ pub enum SemanticError { UnknownStatisticType, VersionedStatisticNotFound, AttributeNotFound(TableId, i32), // (table_id, attribute_base_index) + // FIXME: not sure if this should be put here + InvalidPredicate(String), } #[derive(Debug)] pub enum CostModelError { ORMError(BackendError), SemanticError(SemanticError), - InvalidPredicate(String), } impl From for CostModelError { @@ -57,6 +58,12 @@ impl From for CostModelError { } } +impl From for CostModelError { + fn from(err: SemanticError) -> Self { + CostModelError::SemanticError(err) + } +} + pub trait CostModel: 'static + Send + Sync { /// TODO: documentation fn compute_operation_cost( diff --git a/optd-cost-model/src/storage.rs b/optd-cost-model/src/storage.rs index 1ee5d0e..78b1b85 100644 --- a/optd-cost-model/src/storage.rs +++ b/optd-cost-model/src/storage.rs @@ -1,17 +1,22 @@ #![allow(unused_variables)] use std::sync::Arc; -use optd_persistent::{ - cost_model::interface::{Attr, StatType}, - CostModelStorageLayer, -}; +use optd_persistent::{cost_model::interface::StatType, CostModelStorageLayer}; +use serde::{Deserialize, Serialize}; use crate::{ - common::types::TableId, + common::{predicates::constant_pred::ConstantType, types::TableId}, stats::{counter::Counter, AttributeCombValueStats, Distribution, MostCommonValues}, 
CostModelResult, }; +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct Attribute { + pub name: String, + pub typ: ConstantType, + pub nullable: bool, +} + /// TODO: documentation pub struct CostModelStorageManager { pub backend_manager: Arc, @@ -31,11 +36,16 @@ impl CostModelStorageManager { &self, table_id: TableId, attr_base_index: i32, - ) -> CostModelResult> { + ) -> CostModelResult> { Ok(self .backend_manager .get_attribute(table_id.into(), attr_base_index) - .await?) + .await? + .map(|attr| Attribute { + name: attr.name, + typ: ConstantType::from_persistent_attr_type(attr.attr_type), + nullable: attr.nullable, + })) } /// Gets the latest statistics for a given table. @@ -53,13 +63,13 @@ impl CostModelStorageManager { pub async fn get_attributes_comb_statistics( &self, table_id: TableId, - attr_base_indices: &[i32], + attr_base_indices: &[usize], ) -> CostModelResult> { let dist: Option = self .backend_manager .get_stats_for_attr_indices_based( table_id.into(), - attr_base_indices.to_vec(), + attr_base_indices.iter().map(|&x| x as i32).collect(), StatType::Distribution, None, ) @@ -70,7 +80,7 @@ impl CostModelStorageManager { .backend_manager .get_stats_for_attr_indices_based( table_id.into(), - attr_base_indices.to_vec(), + attr_base_indices.iter().map(|&x| x as i32).collect(), StatType::MostCommonValues, None, ) @@ -82,7 +92,7 @@ impl CostModelStorageManager { .backend_manager .get_stats_for_attr_indices_based( table_id.into(), - attr_base_indices.to_vec(), + attr_base_indices.iter().map(|&x| x as i32).collect(), StatType::Cardinality, None, ) @@ -94,7 +104,7 @@ impl CostModelStorageManager { .backend_manager .get_stats_for_attr_indices_based( table_id.into(), - attr_base_indices.to_vec(), + attr_base_indices.iter().map(|&x| x as i32).collect(), StatType::TableRowCount, None, ) @@ -105,7 +115,7 @@ impl CostModelStorageManager { .backend_manager .get_stats_for_attr_indices_based( table_id.into(), - attr_base_indices.to_vec(), + 
attr_base_indices.iter().map(|&x| x as i32).collect(), StatType::NonNullCount, None, ) diff --git a/optd-persistent/Cargo.toml b/optd-persistent/Cargo.toml index e9b9905..5d03a14 100644 --- a/optd-persistent/Cargo.toml +++ b/optd-persistent/Cargo.toml @@ -21,3 +21,4 @@ trait-variant = "0.1.2" async-trait = "0.1.43" async-stream = "0.3.1" strum = "0.26.1" +num_enum = "0.7.3" diff --git a/optd-persistent/src/cost_model/interface.rs b/optd-persistent/src/cost_model/interface.rs index a03087f..598598d 100644 --- a/optd-persistent/src/cost_model/interface.rs +++ b/optd-persistent/src/cost_model/interface.rs @@ -4,6 +4,7 @@ use crate::entities::cascades_group; use crate::entities::logical_expression; use crate::entities::physical_expression; use crate::StorageResult; +use num_enum::{IntoPrimitive, TryFromPrimitive}; use sea_orm::prelude::Json; use sea_orm::*; use sea_orm_migration::prelude::*; @@ -24,8 +25,10 @@ pub enum CatalogSource { } /// TODO: documentation +#[repr(i32)] +#[derive(Copy, Clone, Debug, PartialEq, IntoPrimitive, TryFromPrimitive)] pub enum AttrType { - Integer, + Integer = 1, Float, Varchar, Boolean, @@ -96,7 +99,7 @@ pub struct Attr { pub table_id: i32, pub name: String, pub compression_method: String, - pub attr_type: i32, + pub attr_type: AttrType, pub base_index: i32, pub nullable: bool, } diff --git a/optd-persistent/src/cost_model/orm.rs b/optd-persistent/src/cost_model/orm.rs index d172c14..65d6035 100644 --- a/optd-persistent/src/cost_model/orm.rs +++ b/optd-persistent/src/cost_model/orm.rs @@ -14,7 +14,8 @@ use serde_json::json; use super::catalog::mock_catalog::{self, MockCatalog}; use super::interface::{ - Attr, AttrId, CatalogSource, EpochId, EpochOption, ExprId, Stat, StatId, StatType, TableId, + Attr, AttrId, AttrType, CatalogSource, EpochId, EpochOption, ExprId, Stat, StatId, StatType, + TableId, }; impl BackendManager { @@ -543,19 +544,28 @@ impl CostModelStorageLayer for BackendManager { table_id: TableId, attribute_base_index: i32, ) 
-> StorageResult> { - Ok(Attribute::find() + let attr_res = Attribute::find() .filter(attribute::Column::TableId.eq(table_id)) .filter(attribute::Column::BaseAttributeNumber.eq(attribute_base_index)) .one(&self.db) - .await? - .map(|attr| Attr { - table_id, - name: attr.name, - compression_method: attr.compression_method, - attr_type: attr.variant_tag, - base_index: attribute_base_index, - nullable: !attr.is_not_null, - })) + .await?; + match attr_res { + Some(attr) => match AttrType::try_from(attr.variant_tag) { + Ok(attr_type) => Ok(Some(Attr { + table_id: attr.table_id, + name: attr.name, + compression_method: attr.compression_method, + attr_type, + base_index: attr.base_attribute_number, + nullable: attr.is_not_null, + })), + Err(_) => Err(BackendError::BackendError(format!( + "Failed to convert variant tag {} to AttrType", + attr.variant_tag + ))), + }, + None => Ok(None), + } } } From 03b6ec3b3a5be521b3528291c28fbbd3fa0466ad Mon Sep 17 00:00:00 2001 From: Lan Lou Date: Sat, 16 Nov 2024 13:47:22 -0500 Subject: [PATCH 22/51] Refactor cost model storage --- Cargo.lock | 8 + optd-cost-model/Cargo.toml | 2 + optd-cost-model/src/common/types.rs | 3 + optd-cost-model/src/cost/agg.rs | 5 +- optd-cost-model/src/cost/filter/attribute.rs | 7 +- optd-cost-model/src/cost/filter/comp_op.rs | 10 +- optd-cost-model/src/cost/filter/constant.rs | 5 +- optd-cost-model/src/cost/filter/controller.rs | 5 +- optd-cost-model/src/cost/filter/in_list.rs | 5 +- optd-cost-model/src/cost/filter/like.rs | 4 +- optd-cost-model/src/cost/filter/log_op.rs | 5 +- optd-cost-model/src/cost/limit.rs | 5 +- optd-cost-model/src/cost_model.rs | 226 +++++++++++++++++- optd-cost-model/src/storage/mock.rs | 0 optd-cost-model/src/storage/mod.rs | 21 ++ .../src/{storage.rs => storage/persistent.rs} | 14 +- 16 files changed, 280 insertions(+), 45 deletions(-) create mode 100644 optd-cost-model/src/storage/mock.rs create mode 100644 optd-cost-model/src/storage/mod.rs rename optd-cost-model/src/{storage.rs 
=> storage/persistent.rs} (92%) diff --git a/Cargo.lock b/Cargo.lock index d47ecc6..8381193 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -374,6 +374,12 @@ dependencies = [ "regex-syntax 0.7.5", ] +[[package]] +name = "assert_approx_eq" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3c07dab4369547dbe5114677b33fbbf724971019f3818172d59a97a61c774ffd" + [[package]] name = "async-compression" version = "0.4.17" @@ -2239,6 +2245,7 @@ name = "optd-cost-model" version = "0.1.0" dependencies = [ "arrow-schema 53.2.0", + "assert_approx_eq", "chrono", "crossbeam", "datafusion", @@ -2250,6 +2257,7 @@ dependencies = [ "serde", "serde_json", "serde_with", + "trait-variant", ] [[package]] diff --git a/optd-cost-model/Cargo.toml b/optd-cost-model/Cargo.toml index d667fe4..c20c062 100644 --- a/optd-cost-model/Cargo.toml +++ b/optd-cost-model/Cargo.toml @@ -15,6 +15,8 @@ datafusion = "32.0.0" ordered-float = "4.0" chrono = "0.4" itertools = "0.13" +assert_approx_eq = "1.1.0" +trait-variant = "0.1.2" [dev-dependencies] crossbeam = "0.8" diff --git a/optd-cost-model/src/common/types.rs b/optd-cost-model/src/common/types.rs index 1e92355..e8aaf7b 100644 --- a/optd-cost-model/src/common/types.rs +++ b/optd-cost-model/src/common/types.rs @@ -1,5 +1,8 @@ use std::fmt::Display; +/// TODO: Implement from and to methods for the following types to enable conversion +/// to and from their persistent counterparts. 
+ /// TODO: documentation #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Debug, Default, Hash)] pub struct GroupId(pub usize); diff --git a/optd-cost-model/src/cost/agg.rs b/optd-cost-model/src/cost/agg.rs index ff82fec..f7e0034 100644 --- a/optd-cost-model/src/cost/agg.rs +++ b/optd-cost-model/src/cost/agg.rs @@ -1,5 +1,3 @@ -use optd_persistent::CostModelStorageLayer; - use crate::{ common::{ nodes::{ArcPredicateNode, PredicateType, ReprPredicateNode}, @@ -8,10 +6,11 @@ use crate::{ }, cost_model::CostModelImpl, stats::DEFAULT_NUM_DISTINCT, + storage::CostModelStorageManager, CostModelError, CostModelResult, EstimatedStatistic, SemanticError, }; -impl CostModelImpl { +impl CostModelImpl { pub async fn get_agg_row_cnt( &self, group_by: ArcPredicateNode, diff --git a/optd-cost-model/src/cost/filter/attribute.rs b/optd-cost-model/src/cost/filter/attribute.rs index c5ad90c..7eb77ce 100644 --- a/optd-cost-model/src/cost/filter/attribute.rs +++ b/optd-cost-model/src/cost/filter/attribute.rs @@ -1,17 +1,14 @@ use std::ops::Bound; -use optd_persistent::CostModelStorageLayer; - use crate::{ common::{types::TableId, values::Value}, cost_model::CostModelImpl, - // TODO: If we return the default value, consider tell the upper level that we cannot - // compute the selectivity. 
stats::{AttributeCombValue, AttributeCombValueStats, DEFAULT_EQ_SEL, DEFAULT_INEQ_SEL}, + storage::CostModelStorageManager, CostModelResult, }; -impl CostModelImpl { +impl CostModelImpl { /// Get the selectivity of an expression of the form "attribute equals value" (or "value equals /// attribute") Will handle the case of statistics missing /// Equality predicates are handled entirely differently from range predicates so this is its diff --git a/optd-cost-model/src/cost/filter/comp_op.rs b/optd-cost-model/src/cost/filter/comp_op.rs index 0a2092e..b8c2d99 100644 --- a/optd-cost-model/src/cost/filter/comp_op.rs +++ b/optd-cost-model/src/cost/filter/comp_op.rs @@ -1,7 +1,5 @@ use std::ops::Bound; -use optd_persistent::CostModelStorageLayer; - use crate::{ common::{ nodes::{ArcPredicateNode, PredicateType, ReprPredicateNode}, @@ -12,14 +10,12 @@ use crate::{ values::Value, }, cost_model::CostModelImpl, - // TODO: If we return the default value, consider tell the upper level that we cannot - // compute the selectivity. 
stats::{DEFAULT_EQ_SEL, DEFAULT_INEQ_SEL, UNIMPLEMENTED_SEL}, - CostModelResult, - SemanticError, + storage::CostModelStorageManager, + CostModelResult, SemanticError, }; -impl CostModelImpl { +impl CostModelImpl { /// Comparison operators are the base case for recursion in get_filter_selectivity() pub(crate) async fn get_comp_op_selectivity( &self, diff --git a/optd-cost-model/src/cost/filter/constant.rs b/optd-cost-model/src/cost/filter/constant.rs index b9c2cc8..e131bde 100644 --- a/optd-cost-model/src/cost/filter/constant.rs +++ b/optd-cost-model/src/cost/filter/constant.rs @@ -1,5 +1,3 @@ -use optd_persistent::CostModelStorageLayer; - use crate::{ common::{ nodes::{ArcPredicateNode, PredicateType}, @@ -7,9 +5,10 @@ use crate::{ values::Value, }, cost_model::CostModelImpl, + storage::CostModelStorageManager, }; -impl CostModelImpl { +impl CostModelImpl { pub(crate) fn get_constant_selectivity(const_node: ArcPredicateNode) -> f64 { if let PredicateType::Constant(const_typ) = const_node.typ { if matches!(const_typ, ConstantType::Bool) { diff --git a/optd-cost-model/src/cost/filter/controller.rs b/optd-cost-model/src/cost/filter/controller.rs index 5369fe5..35301c1 100644 --- a/optd-cost-model/src/cost/filter/controller.rs +++ b/optd-cost-model/src/cost/filter/controller.rs @@ -1,5 +1,3 @@ -use optd_persistent::CostModelStorageLayer; - use crate::{ common::{ nodes::{ArcPredicateNode, PredicateType, ReprPredicateNode}, @@ -7,10 +5,11 @@ use crate::{ }, cost_model::CostModelImpl, stats::UNIMPLEMENTED_SEL, + storage::CostModelStorageManager, CostModelResult, EstimatedStatistic, }; -impl CostModelImpl { +impl CostModelImpl { // TODO: is it a good design to pass table_id here? I think it needs to be refactored. // Consider to remove table_id. 
pub async fn get_filter_row_cnt( diff --git a/optd-cost-model/src/cost/filter/in_list.rs b/optd-cost-model/src/cost/filter/in_list.rs index f1eed06..d27b6f1 100644 --- a/optd-cost-model/src/cost/filter/in_list.rs +++ b/optd-cost-model/src/cost/filter/in_list.rs @@ -1,5 +1,3 @@ -use optd_persistent::CostModelStorageLayer; - use crate::{ common::{ nodes::{PredicateType, ReprPredicateNode}, @@ -9,10 +7,11 @@ use crate::{ }, cost_model::CostModelImpl, stats::UNIMPLEMENTED_SEL, + storage::CostModelStorageManager, CostModelResult, }; -impl CostModelImpl { +impl CostModelImpl { /// Only support attrA in (val1, val2, val3) where attrA is a attribute ref and /// val1, val2, val3 are constants. pub(crate) async fn get_in_list_selectivity(&self, expr: &InListPred) -> CostModelResult { diff --git a/optd-cost-model/src/cost/filter/like.rs b/optd-cost-model/src/cost/filter/like.rs index fe9214b..04517f2 100644 --- a/optd-cost-model/src/cost/filter/like.rs +++ b/optd-cost-model/src/cost/filter/like.rs @@ -1,5 +1,4 @@ use datafusion::arrow::{array::StringArray, compute::like}; -use optd_persistent::CostModelStorageLayer; use crate::{ common::{ @@ -12,10 +11,11 @@ use crate::{ stats::{ AttributeCombValue, FIXED_CHAR_SEL_FACTOR, FULL_WILDCARD_SEL_FACTOR, UNIMPLEMENTED_SEL, }, + storage::CostModelStorageManager, CostModelResult, }; -impl CostModelImpl { +impl CostModelImpl { /// Compute the selectivity of a (NOT) LIKE expression. /// /// The logic is somewhat similar to Postgres but different. 
Postgres first estimates the diff --git a/optd-cost-model/src/cost/filter/log_op.rs b/optd-cost-model/src/cost/filter/log_op.rs index 63e7cd1..66bab10 100644 --- a/optd-cost-model/src/cost/filter/log_op.rs +++ b/optd-cost-model/src/cost/filter/log_op.rs @@ -1,12 +1,11 @@ -use optd_persistent::CostModelStorageLayer; - use crate::{ common::{nodes::ArcPredicateNode, predicates::log_op_pred::LogOpType}, cost_model::CostModelImpl, + storage::CostModelStorageManager, CostModelResult, }; -impl CostModelImpl { +impl CostModelImpl { pub(crate) async fn get_log_op_selectivity( &self, log_op_typ: LogOpType, diff --git a/optd-cost-model/src/cost/limit.rs b/optd-cost-model/src/cost/limit.rs index ce3e08e..38e7550 100644 --- a/optd-cost-model/src/cost/limit.rs +++ b/optd-cost-model/src/cost/limit.rs @@ -1,15 +1,14 @@ -use optd_persistent::CostModelStorageLayer; - use crate::{ common::{ nodes::{ArcPredicateNode, ReprPredicateNode}, predicates::constant_pred::ConstantPred, }, cost_model::CostModelImpl, + storage::CostModelStorageManager, CostModelResult, EstimatedStatistic, }; -impl CostModelImpl { +impl CostModelImpl { pub(crate) fn get_limit_row_cnt( &self, child_row_cnt: EstimatedStatistic, diff --git a/optd-cost-model/src/cost_model.rs b/optd-cost-model/src/cost_model.rs index 1942558..c0b38e1 100644 --- a/optd-cost-model/src/cost_model.rs +++ b/optd-cost-model/src/cost_model.rs @@ -18,17 +18,14 @@ use crate::{ }; /// TODO: documentation -pub struct CostModelImpl { - pub storage_manager: CostModelStorageManager, +pub struct CostModelImpl { + pub storage_manager: S, pub default_catalog_source: CatalogSource, } -impl CostModelImpl { +impl CostModelImpl { /// TODO: documentation - pub fn new( - storage_manager: CostModelStorageManager, - default_catalog_source: CatalogSource, - ) -> Self { + pub fn new(storage_manager: S, default_catalog_source: CatalogSource) -> Self { Self { storage_manager, default_catalog_source, @@ -36,7 +33,9 @@ impl CostModelImpl { } } -impl CostModel for 
CostModelImpl { +impl CostModel + for CostModelImpl +{ fn compute_operation_cost( &self, node: &PhysicalNodeType, @@ -93,7 +92,7 @@ impl CostModel for CostM } } -impl CostModelImpl { +impl CostModelImpl { /// TODO: documentation /// TODO: if we have memory cache, /// we should add the reference. (&AttributeCombValueStats) @@ -107,3 +106,212 @@ impl CostModelImpl { .await } } + +// /// I thought about using the system's own parser and planner to generate these expression trees, +// /// but this is not currently feasible because it would create a cyclic dependency between +// /// optd-datafusion-bridge and optd-datafusion-repr +// #[cfg(test)] +// mod tests { +// use std::collections::HashMap; + +// use arrow_schema::DataType; +// use itertools::Itertools; +// use optd_persistent::BackendManager; +// use serde::{Deserialize, Serialize}; + +// use super::*; +// pub type TestPerColumnStats = AttributeCombValueStats; +// pub type TestOptCostModel = CostModelImpl; + +// pub const TABLE1_NAME: &str = "table1"; +// pub const TABLE2_NAME: &str = "table2"; +// pub const TABLE3_NAME: &str = "table3"; +// pub const TABLE4_NAME: &str = "table4"; + +// // one column is sufficient for all filter selectivity tests +// pub fn create_one_column_cost_model(per_column_stats: TestPerColumnStats) -> TestOptCostModel { +// AdvStats::new( +// vec![( +// String::from(TABLE1_NAME), +// TableStats::new(100, vec![(vec![0], per_column_stats)].into_iter().collect()), +// )] +// .into_iter() +// .collect(), +// ) +// } + +// /// Create a cost model with two columns, one for each table. Each column has 100 values. +// pub fn create_two_table_cost_model( +// tbl1_per_column_stats: TestPerColumnStats, +// tbl2_per_column_stats: TestPerColumnStats, +// ) -> TestOptCostModel { +// create_two_table_cost_model_custom_row_cnts( +// tbl1_per_column_stats, +// tbl2_per_column_stats, +// 100, +// 100, +// ) +// } + +// /// Create a cost model with three columns, one for each table. 
Each column has 100 values. +// pub fn create_three_table_cost_model( +// tbl1_per_column_stats: TestPerColumnStats, +// tbl2_per_column_stats: TestPerColumnStats, +// tbl3_per_column_stats: TestPerColumnStats, +// ) -> TestOptCostModel { +// AdvStats::new( +// vec![ +// ( +// String::from(TABLE1_NAME), +// TableStats::new( +// 100, +// vec![(vec![0], tbl1_per_column_stats)].into_iter().collect(), +// ), +// ), +// ( +// String::from(TABLE2_NAME), +// TableStats::new( +// 100, +// vec![(vec![0], tbl2_per_column_stats)].into_iter().collect(), +// ), +// ), +// ( +// String::from(TABLE3_NAME), +// TableStats::new( +// 100, +// vec![(vec![0], tbl3_per_column_stats)].into_iter().collect(), +// ), +// ), +// ] +// .into_iter() +// .collect(), +// ) +// } + +// /// Create a cost model with three columns, one for each table. Each column has 100 values. +// pub fn create_four_table_cost_model( +// tbl1_per_column_stats: TestPerColumnStats, +// tbl2_per_column_stats: TestPerColumnStats, +// tbl3_per_column_stats: TestPerColumnStats, +// tbl4_per_column_stats: TestPerColumnStats, +// ) -> TestOptCostModel { +// AdvStats::new( +// vec![ +// ( +// String::from(TABLE1_NAME), +// TableStats::new( +// 100, +// vec![(vec![0], tbl1_per_column_stats)].into_iter().collect(), +// ), +// ), +// ( +// String::from(TABLE2_NAME), +// TableStats::new( +// 100, +// vec![(vec![0], tbl2_per_column_stats)].into_iter().collect(), +// ), +// ), +// ( +// String::from(TABLE3_NAME), +// TableStats::new( +// 100, +// vec![(vec![0], tbl3_per_column_stats)].into_iter().collect(), +// ), +// ), +// ( +// String::from(TABLE4_NAME), +// TableStats::new( +// 100, +// vec![(vec![0], tbl4_per_column_stats)].into_iter().collect(), +// ), +// ), +// ] +// .into_iter() +// .collect(), +// ) +// } + +// /// We need custom row counts because some join algorithms rely on the row cnt +// pub fn create_two_table_cost_model_custom_row_cnts( +// tbl1_per_column_stats: TestPerColumnStats, +// tbl2_per_column_stats: 
TestPerColumnStats, +// tbl1_row_cnt: usize, +// tbl2_row_cnt: usize, +// ) -> TestOptCostModel { +// AdvStats::new( +// vec![ +// ( +// String::from(TABLE1_NAME), +// TableStats::new( +// tbl1_row_cnt, +// vec![(vec![0], tbl1_per_column_stats)].into_iter().collect(), +// ), +// ), +// ( +// String::from(TABLE2_NAME), +// TableStats::new( +// tbl2_row_cnt, +// vec![(vec![0], tbl2_per_column_stats)].into_iter().collect(), +// ), +// ), +// ] +// .into_iter() +// .collect(), +// ) +// } + +// pub fn col_ref(idx: u64) -> ArcDfPredNode { +// // this conversion is always safe because idx was originally a usize +// let idx_as_usize = idx as usize; +// ColumnRefPred::new(idx_as_usize).into_pred_node() +// } + +// pub fn cnst(value: Value) -> ArcDfPredNode { +// ConstantPred::new(value).into_pred_node() +// } + +// pub fn cast(child: ArcDfPredNode, cast_type: DataType) -> ArcDfPredNode { +// CastPred::new(child, cast_type).into_pred_node() +// } + +// pub fn bin_op(op_type: BinOpType, left: ArcDfPredNode, right: ArcDfPredNode) -> ArcDfPredNode { +// BinOpPred::new(left, right, op_type).into_pred_node() +// } + +// pub fn log_op(op_type: LogOpType, children: Vec) -> ArcDfPredNode { +// LogOpPred::new(op_type, children).into_pred_node() +// } + +// pub fn un_op(op_type: UnOpType, child: ArcDfPredNode) -> ArcDfPredNode { +// UnOpPred::new(child, op_type).into_pred_node() +// } + +// pub fn in_list(col_ref_idx: u64, list: Vec, negated: bool) -> InListPred { +// InListPred::new( +// col_ref(col_ref_idx), +// ListPred::new(list.into_iter().map(cnst).collect_vec()), +// negated, +// ) +// } + +// pub fn like(col_ref_idx: u64, pattern: &str, negated: bool) -> LikePred { +// LikePred::new( +// negated, +// false, +// col_ref(col_ref_idx), +// cnst(Value::String(pattern.into())), +// ) +// } + +// /// The reason this isn't an associated function of PerColumnStats is because that would require +// /// adding an empty() function to the trait definitions of MostCommonValues and 
Distribution, +// /// which I wanted to avoid +// pub(crate) fn get_empty_per_col_stats() -> TestPerColumnStats { +// TestPerColumnStats::new( +// TestMostCommonValues::empty(), +// 0, +// 0.0, +// Some(TestDistribution::empty()), +// ) +// } +// } diff --git a/optd-cost-model/src/storage/mock.rs b/optd-cost-model/src/storage/mock.rs new file mode 100644 index 0000000..e69de29 diff --git a/optd-cost-model/src/storage/mod.rs b/optd-cost-model/src/storage/mod.rs new file mode 100644 index 0000000..cf7baf5 --- /dev/null +++ b/optd-cost-model/src/storage/mod.rs @@ -0,0 +1,21 @@ +use persistent::Attribute; + +use crate::{common::types::TableId, stats::AttributeCombValueStats, CostModelResult}; + +pub mod mock; +pub mod persistent; + +#[trait_variant::make(Send)] +pub trait CostModelStorageManager { + async fn get_attribute_info( + &self, + table_id: TableId, + attr_base_index: i32, + ) -> CostModelResult>; + + async fn get_attributes_comb_statistics( + &self, + table_id: TableId, + attr_base_indices: &[usize], + ) -> CostModelResult>; +} diff --git a/optd-cost-model/src/storage.rs b/optd-cost-model/src/storage/persistent.rs similarity index 92% rename from optd-cost-model/src/storage.rs rename to optd-cost-model/src/storage/persistent.rs index 78b1b85..c789e7f 100644 --- a/optd-cost-model/src/storage.rs +++ b/optd-cost-model/src/storage/persistent.rs @@ -10,6 +10,8 @@ use crate::{ CostModelResult, }; +use super::CostModelStorageManager; + #[derive(Clone, Debug, Serialize, Deserialize)] pub struct Attribute { pub name: String, @@ -18,21 +20,25 @@ pub struct Attribute { } /// TODO: documentation -pub struct CostModelStorageManager { +pub struct CostModelStorageManagerImpl { pub backend_manager: Arc, // TODO: in-memory cache } -impl CostModelStorageManager { +impl CostModelStorageManagerImpl { pub fn new(backend_manager: Arc) -> Self { Self { backend_manager } } +} +impl CostModelStorageManager + for CostModelStorageManagerImpl +{ /// Gets the attribute information for a 
given table and attribute base index. /// /// TODO: if we have memory cache, /// we should add the reference. (&Attr) - pub async fn get_attribute_info( + async fn get_attribute_info( &self, table_id: TableId, attr_base_index: i32, @@ -60,7 +66,7 @@ impl CostModelStorageManager { /// /// TODO: Shall we pass in an epoch here to make sure that the statistics are from the same /// epoch? - pub async fn get_attributes_comb_statistics( + async fn get_attributes_comb_statistics( &self, table_id: TableId, attr_base_indices: &[usize], From a3b80888e8630037bd305e02fc4ddbd29f829035 Mon Sep 17 00:00:00 2001 From: Lan Lou Date: Sat, 16 Nov 2024 13:51:49 -0500 Subject: [PATCH 23/51] Move storage attribute to mod --- optd-cost-model/src/storage/mod.rs | 15 +++++++++++++-- optd-cost-model/src/storage/persistent.rs | 10 +--------- 2 files changed, 14 insertions(+), 11 deletions(-) diff --git a/optd-cost-model/src/storage/mod.rs b/optd-cost-model/src/storage/mod.rs index cf7baf5..78c75cd 100644 --- a/optd-cost-model/src/storage/mod.rs +++ b/optd-cost-model/src/storage/mod.rs @@ -1,10 +1,21 @@ -use persistent::Attribute; +use serde::{Deserialize, Serialize}; -use crate::{common::types::TableId, stats::AttributeCombValueStats, CostModelResult}; +use crate::{ + common::{predicates::constant_pred::ConstantType, types::TableId}, + stats::AttributeCombValueStats, + CostModelResult, +}; pub mod mock; pub mod persistent; +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct Attribute { + pub name: String, + pub typ: ConstantType, + pub nullable: bool, +} + #[trait_variant::make(Send)] pub trait CostModelStorageManager { async fn get_attribute_info( diff --git a/optd-cost-model/src/storage/persistent.rs b/optd-cost-model/src/storage/persistent.rs index c789e7f..53f37da 100644 --- a/optd-cost-model/src/storage/persistent.rs +++ b/optd-cost-model/src/storage/persistent.rs @@ -2,7 +2,6 @@ use std::sync::Arc; use optd_persistent::{cost_model::interface::StatType, CostModelStorageLayer}; 
-use serde::{Deserialize, Serialize}; use crate::{ common::{predicates::constant_pred::ConstantType, types::TableId}, @@ -10,14 +9,7 @@ use crate::{ CostModelResult, }; -use super::CostModelStorageManager; - -#[derive(Clone, Debug, Serialize, Deserialize)] -pub struct Attribute { - pub name: String, - pub typ: ConstantType, - pub nullable: bool, -} +use super::{Attribute, CostModelStorageManager}; /// TODO: documentation pub struct CostModelStorageManagerImpl { From c07b9fce0ce0763093576709060d8e582438a214 Mon Sep 17 00:00:00 2001 From: Lan Lou Date: Sat, 16 Nov 2024 14:38:40 -0500 Subject: [PATCH 24/51] Add initial test framework in cost_model.rs --- .../src/common/predicates/attr_ref_pred.rs | 11 +- .../src/common/predicates/bin_op_pred.rs | 47 ++ .../src/common/predicates/log_op_pred.rs | 71 +++ .../src/common/predicates/un_op_pred.rs | 43 ++ optd-cost-model/src/cost_model.rs | 408 ++++++++++-------- optd-cost-model/src/storage/mock.rs | 62 +++ 6 files changed, 448 insertions(+), 194 deletions(-) diff --git a/optd-cost-model/src/common/predicates/attr_ref_pred.rs b/optd-cost-model/src/common/predicates/attr_ref_pred.rs index 8a670a5..cdc7440 100644 --- a/optd-cost-model/src/common/predicates/attr_ref_pred.rs +++ b/optd-cost-model/src/common/predicates/attr_ref_pred.rs @@ -9,7 +9,12 @@ use super::id_pred::IdPred; /// /// An [`AttributeRefPred`] has two children: /// 1. The table id, represented by an [`IdPred`]. -/// 2. The index of the column, represented by an [`IdPred`]. +/// 2. The index of the attribute, represented by an [`IdPred`]. +/// +/// Although it may be strange at first glance (table id and attribute base index +/// aren't children of the attribute reference), but considering the attribute reference +/// can be represented as table_id.attr_base_index, and it enables the cost model to +/// obtain the information in a simple way without refactoring `data` field. /// /// **TODO**: Now we assume any IdPred is as same as the ones in the ORM layer. 
/// @@ -23,12 +28,12 @@ use super::id_pred::IdPred; pub struct AttributeRefPred(pub ArcPredicateNode); impl AttributeRefPred { - pub fn new(table_id: usize, attribute_idx: usize) -> AttributeRefPred { + pub fn new(table_id: TableId, attribute_idx: usize) -> AttributeRefPred { AttributeRefPred( PredicateNode { typ: PredicateType::AttributeRef, children: vec![ - IdPred::new(table_id).into_pred_node(), + IdPred::new(table_id.0).into_pred_node(), IdPred::new(attribute_idx).into_pred_node(), ], data: None, diff --git a/optd-cost-model/src/common/predicates/bin_op_pred.rs b/optd-cost-model/src/common/predicates/bin_op_pred.rs index 196d987..5c48688 100644 --- a/optd-cost-model/src/common/predicates/bin_op_pred.rs +++ b/optd-cost-model/src/common/predicates/bin_op_pred.rs @@ -1,3 +1,5 @@ +use crate::common::nodes::{ArcPredicateNode, PredicateNode, PredicateType, ReprPredicateNode}; + /// TODO: documentation #[derive(Copy, Clone, PartialEq, Eq, Hash, Debug)] pub enum BinOpType { @@ -38,3 +40,48 @@ impl BinOpType { ) } } + +#[derive(Clone, Debug)] +pub struct BinOpPred(pub ArcPredicateNode); + +impl BinOpPred { + pub fn new(left: ArcPredicateNode, right: ArcPredicateNode, op_type: BinOpType) -> Self { + BinOpPred( + PredicateNode { + typ: PredicateType::BinOp(op_type), + children: vec![left, right], + data: None, + } + .into(), + ) + } + + pub fn left_child(&self) -> ArcPredicateNode { + self.0.child(0) + } + + pub fn right_child(&self) -> ArcPredicateNode { + self.0.child(1) + } + + pub fn op_type(&self) -> BinOpType { + if let PredicateType::BinOp(op_type) = self.0.typ { + op_type + } else { + panic!("not a bin op") + } + } +} + +impl ReprPredicateNode for BinOpPred { + fn into_pred_node(self) -> ArcPredicateNode { + self.0 + } + + fn from_pred_node(pred_node: ArcPredicateNode) -> Option { + if !matches!(pred_node.typ, PredicateType::BinOp(_)) { + return None; + } + Some(Self(pred_node)) + } +} diff --git a/optd-cost-model/src/common/predicates/log_op_pred.rs 
b/optd-cost-model/src/common/predicates/log_op_pred.rs index 88c5746..1899cb1 100644 --- a/optd-cost-model/src/common/predicates/log_op_pred.rs +++ b/optd-cost-model/src/common/predicates/log_op_pred.rs @@ -1,5 +1,9 @@ use std::fmt::Display; +use crate::common::nodes::{ArcPredicateNode, PredicateNode, PredicateType, ReprPredicateNode}; + +use super::list_pred::ListPred; + /// TODO: documentation #[derive(Copy, Clone, PartialEq, Eq, Hash, Debug)] pub enum LogOpType { @@ -12,3 +16,70 @@ impl Display for LogOpType { write!(f, "{:?}", self) } } + +#[derive(Clone, Debug)] +pub struct LogOpPred(pub ArcPredicateNode); + +impl LogOpPred { + pub fn new(op_type: LogOpType, preds: Vec) -> Self { + LogOpPred( + PredicateNode { + typ: PredicateType::LogOp(op_type), + children: preds, + data: None, + } + .into(), + ) + } + + /// flatten_nested_logical is a helper function to flatten nested logical operators with same op + /// type eg. (a AND (b AND c)) => ExprList([a, b, c]) + /// (a OR (b OR c)) => ExprList([a, b, c]) + /// It assume the children of the input expr_list are already flattened + /// and can only be used in bottom up manner + pub fn new_flattened_nested_logical(op: LogOpType, expr_list: ListPred) -> Self { + // Since we assume that we are building the children bottom up, + // there is no need to call flatten_nested_logical recursively + let mut new_expr_list = Vec::new(); + for child in expr_list.to_vec() { + if let PredicateType::LogOp(child_op) = child.typ { + if child_op == op { + let child_log_op_expr = LogOpPred::from_pred_node(child).unwrap(); + new_expr_list.extend(child_log_op_expr.children().to_vec()); + continue; + } + } + new_expr_list.push(child.clone()); + } + LogOpPred::new(op, new_expr_list) + } + + pub fn children(&self) -> Vec { + self.0.children.clone() + } + + pub fn child(&self, idx: usize) -> ArcPredicateNode { + self.0.child(idx) + } + + pub fn op_type(&self) -> LogOpType { + if let PredicateType::LogOp(op_type) = self.0.typ { + op_type + } 
else { + panic!("not a log op") + } + } +} + +impl ReprPredicateNode for LogOpPred { + fn into_pred_node(self) -> ArcPredicateNode { + self.0 + } + + fn from_pred_node(pred_node: ArcPredicateNode) -> Option { + if !matches!(pred_node.typ, PredicateType::LogOp(_)) { + return None; + } + Some(Self(pred_node)) + } +} diff --git a/optd-cost-model/src/common/predicates/un_op_pred.rs b/optd-cost-model/src/common/predicates/un_op_pred.rs index d33158f..a3fc270 100644 --- a/optd-cost-model/src/common/predicates/un_op_pred.rs +++ b/optd-cost-model/src/common/predicates/un_op_pred.rs @@ -1,5 +1,7 @@ use std::fmt::Display; +use crate::common::nodes::{ArcPredicateNode, PredicateNode, PredicateType, ReprPredicateNode}; + /// TODO: documentation #[derive(Copy, Clone, PartialEq, Eq, Hash, Debug)] pub enum UnOpType { @@ -12,3 +14,44 @@ impl Display for UnOpType { write!(f, "{:?}", self) } } + +#[derive(Clone, Debug)] +pub struct UnOpPred(pub ArcPredicateNode); + +impl UnOpPred { + pub fn new(child: ArcPredicateNode, op_type: UnOpType) -> Self { + UnOpPred( + PredicateNode { + typ: PredicateType::UnOp(op_type), + children: vec![child], + data: None, + } + .into(), + ) + } + + pub fn child(&self) -> ArcPredicateNode { + self.0.child(0) + } + + pub fn op_type(&self) -> UnOpType { + if let PredicateType::UnOp(op_type) = self.0.typ { + op_type + } else { + panic!("not a un op") + } + } +} + +impl ReprPredicateNode for UnOpPred { + fn into_pred_node(self) -> ArcPredicateNode { + self.0 + } + + fn from_pred_node(pred_node: ArcPredicateNode) -> Option { + if !matches!(pred_node.typ, PredicateType::UnOp(_)) { + return None; + } + Some(Self(pred_node)) + } +} diff --git a/optd-cost-model/src/cost_model.rs b/optd-cost-model/src/cost_model.rs index c0b38e1..4700f17 100644 --- a/optd-cost-model/src/cost_model.rs +++ b/optd-cost-model/src/cost_model.rs @@ -107,211 +107,237 @@ impl CostModelImpl { } } -// /// I thought about using the system's own parser and planner to generate these expression 
trees, -// /// but this is not currently feasible because it would create a cyclic dependency between -// /// optd-datafusion-bridge and optd-datafusion-repr -// #[cfg(test)] -// mod tests { -// use std::collections::HashMap; +/// I thought about using the system's own parser and planner to generate these expression trees, +/// but this is not currently feasible because it would create a cyclic dependency between +/// optd-datafusion-bridge and optd-datafusion-repr +#[cfg(test)] +mod tests { + use std::collections::HashMap; -// use arrow_schema::DataType; -// use itertools::Itertools; -// use optd_persistent::BackendManager; -// use serde::{Deserialize, Serialize}; + use arrow_schema::DataType; + use itertools::Itertools; + use optd_persistent::cost_model::interface::CatalogSource; + use serde::{Deserialize, Serialize}; -// use super::*; -// pub type TestPerColumnStats = AttributeCombValueStats; -// pub type TestOptCostModel = CostModelImpl; + use crate::{ + common::{ + nodes::ReprPredicateNode, + predicates::{ + attr_ref_pred::AttributeRefPred, + bin_op_pred::{BinOpPred, BinOpType}, + cast_pred::CastPred, + constant_pred::ConstantPred, + in_list_pred::InListPred, + like_pred::LikePred, + list_pred::ListPred, + log_op_pred::{LogOpPred, LogOpType}, + un_op_pred::{UnOpPred, UnOpType}, + }, + values::Value, + }, + stats::{ + counter::Counter, tdigest::TDigest, AttributeCombValueStats, Distribution, + MostCommonValues, + }, + storage::mock::{CostModelStorageMockManagerImpl, TableStats}, + }; -// pub const TABLE1_NAME: &str = "table1"; -// pub const TABLE2_NAME: &str = "table2"; -// pub const TABLE3_NAME: &str = "table3"; -// pub const TABLE4_NAME: &str = "table4"; + use super::*; -// // one column is sufficient for all filter selectivity tests -// pub fn create_one_column_cost_model(per_column_stats: TestPerColumnStats) -> TestOptCostModel { -// AdvStats::new( -// vec![( -// String::from(TABLE1_NAME), -// TableStats::new(100, vec![(vec![0], 
per_column_stats)].into_iter().collect()), -// )] -// .into_iter() -// .collect(), -// ) -// } + pub type TestPerColumnStats = AttributeCombValueStats; + // TODO: add tests for non-mock storage manager + pub type TestOptCostModelMock = CostModelImpl; -// /// Create a cost model with two columns, one for each table. Each column has 100 values. -// pub fn create_two_table_cost_model( -// tbl1_per_column_stats: TestPerColumnStats, -// tbl2_per_column_stats: TestPerColumnStats, -// ) -> TestOptCostModel { -// create_two_table_cost_model_custom_row_cnts( -// tbl1_per_column_stats, -// tbl2_per_column_stats, -// 100, -// 100, -// ) -// } + pub const TABLE1_NAME: &str = "table1"; + pub const TABLE2_NAME: &str = "table2"; + pub const TABLE3_NAME: &str = "table3"; + pub const TABLE4_NAME: &str = "table4"; -// /// Create a cost model with three columns, one for each table. Each column has 100 values. -// pub fn create_three_table_cost_model( -// tbl1_per_column_stats: TestPerColumnStats, -// tbl2_per_column_stats: TestPerColumnStats, -// tbl3_per_column_stats: TestPerColumnStats, -// ) -> TestOptCostModel { -// AdvStats::new( -// vec![ -// ( -// String::from(TABLE1_NAME), -// TableStats::new( -// 100, -// vec![(vec![0], tbl1_per_column_stats)].into_iter().collect(), -// ), -// ), -// ( -// String::from(TABLE2_NAME), -// TableStats::new( -// 100, -// vec![(vec![0], tbl2_per_column_stats)].into_iter().collect(), -// ), -// ), -// ( -// String::from(TABLE3_NAME), -// TableStats::new( -// 100, -// vec![(vec![0], tbl3_per_column_stats)].into_iter().collect(), -// ), -// ), -// ] -// .into_iter() -// .collect(), -// ) -// } + // one column is sufficient for all filter selectivity tests + pub fn create_one_column_cost_model_mock_storage( + per_column_stats: TestPerColumnStats, + ) -> TestOptCostModelMock { + let storage_manager = CostModelStorageMockManagerImpl::new( + vec![( + String::from(TABLE1_NAME), + TableStats::new(100, vec![(vec![0], 
per_column_stats)].into_iter().collect()), + )] + .into_iter() + .collect(), + ); + CostModelImpl::new(storage_manager, CatalogSource::Mock) + } + + /// Create a cost model with two columns, one for each table. Each column has 100 values. + pub fn create_two_table_cost_model_mock_storage( + tbl1_per_column_stats: TestPerColumnStats, + tbl2_per_column_stats: TestPerColumnStats, + ) -> TestOptCostModelMock { + create_two_table_cost_model_custom_row_cnts_mock_storage( + tbl1_per_column_stats, + tbl2_per_column_stats, + 100, + 100, + ) + } -// /// Create a cost model with three columns, one for each table. Each column has 100 values. -// pub fn create_four_table_cost_model( -// tbl1_per_column_stats: TestPerColumnStats, -// tbl2_per_column_stats: TestPerColumnStats, -// tbl3_per_column_stats: TestPerColumnStats, -// tbl4_per_column_stats: TestPerColumnStats, -// ) -> TestOptCostModel { -// AdvStats::new( -// vec![ -// ( -// String::from(TABLE1_NAME), -// TableStats::new( -// 100, -// vec![(vec![0], tbl1_per_column_stats)].into_iter().collect(), -// ), -// ), -// ( -// String::from(TABLE2_NAME), -// TableStats::new( -// 100, -// vec![(vec![0], tbl2_per_column_stats)].into_iter().collect(), -// ), -// ), -// ( -// String::from(TABLE3_NAME), -// TableStats::new( -// 100, -// vec![(vec![0], tbl3_per_column_stats)].into_iter().collect(), -// ), -// ), -// ( -// String::from(TABLE4_NAME), -// TableStats::new( -// 100, -// vec![(vec![0], tbl4_per_column_stats)].into_iter().collect(), -// ), -// ), -// ] -// .into_iter() -// .collect(), -// ) -// } + /// Create a cost model with three columns, one for each table. Each column has 100 values. 
+ pub fn create_three_table_cost_model( + tbl1_per_column_stats: TestPerColumnStats, + tbl2_per_column_stats: TestPerColumnStats, + tbl3_per_column_stats: TestPerColumnStats, + ) -> TestOptCostModelMock { + let storage_manager = CostModelStorageMockManagerImpl::new( + vec![ + ( + String::from(TABLE1_NAME), + TableStats::new( + 100, + vec![(vec![0], tbl1_per_column_stats)].into_iter().collect(), + ), + ), + ( + String::from(TABLE2_NAME), + TableStats::new( + 100, + vec![(vec![0], tbl2_per_column_stats)].into_iter().collect(), + ), + ), + ( + String::from(TABLE3_NAME), + TableStats::new( + 100, + vec![(vec![0], tbl3_per_column_stats)].into_iter().collect(), + ), + ), + ] + .into_iter() + .collect(), + ); + CostModelImpl::new(storage_manager, CatalogSource::Mock) + } -// /// We need custom row counts because some join algorithms rely on the row cnt -// pub fn create_two_table_cost_model_custom_row_cnts( -// tbl1_per_column_stats: TestPerColumnStats, -// tbl2_per_column_stats: TestPerColumnStats, -// tbl1_row_cnt: usize, -// tbl2_row_cnt: usize, -// ) -> TestOptCostModel { -// AdvStats::new( -// vec![ -// ( -// String::from(TABLE1_NAME), -// TableStats::new( -// tbl1_row_cnt, -// vec![(vec![0], tbl1_per_column_stats)].into_iter().collect(), -// ), -// ), -// ( -// String::from(TABLE2_NAME), -// TableStats::new( -// tbl2_row_cnt, -// vec![(vec![0], tbl2_per_column_stats)].into_iter().collect(), -// ), -// ), -// ] -// .into_iter() -// .collect(), -// ) -// } + /// Create a cost model with three columns, one for each table. Each column has 100 values. 
+ pub fn create_four_table_cost_model_mock_storage( + tbl1_per_column_stats: TestPerColumnStats, + tbl2_per_column_stats: TestPerColumnStats, + tbl3_per_column_stats: TestPerColumnStats, + tbl4_per_column_stats: TestPerColumnStats, + ) -> TestOptCostModelMock { + let storage_manager = CostModelStorageMockManagerImpl::new( + vec![ + ( + String::from(TABLE1_NAME), + TableStats::new( + 100, + vec![(vec![0], tbl1_per_column_stats)].into_iter().collect(), + ), + ), + ( + String::from(TABLE2_NAME), + TableStats::new( + 100, + vec![(vec![0], tbl2_per_column_stats)].into_iter().collect(), + ), + ), + ( + String::from(TABLE3_NAME), + TableStats::new( + 100, + vec![(vec![0], tbl3_per_column_stats)].into_iter().collect(), + ), + ), + ( + String::from(TABLE4_NAME), + TableStats::new( + 100, + vec![(vec![0], tbl4_per_column_stats)].into_iter().collect(), + ), + ), + ] + .into_iter() + .collect(), + ); + CostModelImpl::new(storage_manager, CatalogSource::Mock) + } -// pub fn col_ref(idx: u64) -> ArcDfPredNode { -// // this conversion is always safe because idx was originally a usize -// let idx_as_usize = idx as usize; -// ColumnRefPred::new(idx_as_usize).into_pred_node() -// } + /// We need custom row counts because some join algorithms rely on the row cnt + pub fn create_two_table_cost_model_custom_row_cnts_mock_storage( + tbl1_per_column_stats: TestPerColumnStats, + tbl2_per_column_stats: TestPerColumnStats, + tbl1_row_cnt: usize, + tbl2_row_cnt: usize, + ) -> TestOptCostModelMock { + let storage_manager = CostModelStorageMockManagerImpl::new( + vec![ + ( + String::from(TABLE1_NAME), + TableStats::new( + tbl1_row_cnt, + vec![(vec![0], tbl1_per_column_stats)].into_iter().collect(), + ), + ), + ( + String::from(TABLE2_NAME), + TableStats::new( + tbl2_row_cnt, + vec![(vec![0], tbl2_per_column_stats)].into_iter().collect(), + ), + ), + ] + .into_iter() + .collect(), + ); + CostModelImpl::new(storage_manager, CatalogSource::Mock) + } -// pub fn cnst(value: Value) -> ArcDfPredNode 
{ -// ConstantPred::new(value).into_pred_node() -// } + pub fn attr_ref(table_id: TableId, attr_base_index: usize) -> ArcPredicateNode { + AttributeRefPred::new(table_id, attr_base_index).into_pred_node() + } -// pub fn cast(child: ArcDfPredNode, cast_type: DataType) -> ArcDfPredNode { -// CastPred::new(child, cast_type).into_pred_node() -// } + pub fn cnst(value: Value) -> ArcPredicateNode { + ConstantPred::new(value).into_pred_node() + } -// pub fn bin_op(op_type: BinOpType, left: ArcDfPredNode, right: ArcDfPredNode) -> ArcDfPredNode { -// BinOpPred::new(left, right, op_type).into_pred_node() -// } + pub fn cast(child: ArcPredicateNode, cast_type: DataType) -> ArcPredicateNode { + CastPred::new(child, cast_type).into_pred_node() + } -// pub fn log_op(op_type: LogOpType, children: Vec) -> ArcDfPredNode { -// LogOpPred::new(op_type, children).into_pred_node() -// } + pub fn bin_op( + op_type: BinOpType, + left: ArcPredicateNode, + right: ArcPredicateNode, + ) -> ArcPredicateNode { + BinOpPred::new(left, right, op_type).into_pred_node() + } -// pub fn un_op(op_type: UnOpType, child: ArcDfPredNode) -> ArcDfPredNode { -// UnOpPred::new(child, op_type).into_pred_node() -// } + pub fn log_op(op_type: LogOpType, children: Vec) -> ArcPredicateNode { + LogOpPred::new(op_type, children).into_pred_node() + } -// pub fn in_list(col_ref_idx: u64, list: Vec, negated: bool) -> InListPred { -// InListPred::new( -// col_ref(col_ref_idx), -// ListPred::new(list.into_iter().map(cnst).collect_vec()), -// negated, -// ) -// } + pub fn un_op(op_type: UnOpType, child: ArcPredicateNode) -> ArcPredicateNode { + UnOpPred::new(child, op_type).into_pred_node() + } -// pub fn like(col_ref_idx: u64, pattern: &str, negated: bool) -> LikePred { -// LikePred::new( -// negated, -// false, -// col_ref(col_ref_idx), -// cnst(Value::String(pattern.into())), -// ) -// } + pub fn in_list( + table_id: TableId, + attr_ref_idx: usize, + list: Vec, + negated: bool, + ) -> InListPred { + InListPred::new( + 
attr_ref(table_id, attr_ref_idx), + ListPred::new(list.into_iter().map(cnst).collect_vec()), + negated, + ) + } -// /// The reason this isn't an associated function of PerColumnStats is because that would require -// /// adding an empty() function to the trait definitions of MostCommonValues and Distribution, -// /// which I wanted to avoid -// pub(crate) fn get_empty_per_col_stats() -> TestPerColumnStats { -// TestPerColumnStats::new( -// TestMostCommonValues::empty(), -// 0, -// 0.0, -// Some(TestDistribution::empty()), -// ) -// } -// } + pub fn like(table_id: TableId, attr_ref_idx: usize, pattern: &str, negated: bool) -> LikePred { + LikePred::new( + negated, + false, + attr_ref(table_id, attr_ref_idx), + cnst(Value::String(pattern.into())), + ) + } +} diff --git a/optd-cost-model/src/storage/mock.rs b/optd-cost-model/src/storage/mock.rs index e69de29..8a1f962 100644 --- a/optd-cost-model/src/storage/mock.rs +++ b/optd-cost-model/src/storage/mock.rs @@ -0,0 +1,62 @@ +#![allow(unused_variables, dead_code)] +use std::collections::HashMap; + +use serde::{Deserialize, Serialize}; + +use crate::{common::types::TableId, stats::AttributeCombValueStats, CostModelResult}; + +use super::{Attribute, CostModelStorageManager}; + +pub type AttrsIdx = Vec; + +#[serde_with::serde_as] +#[derive(Serialize, Deserialize, Debug)] +pub struct TableStats { + pub row_cnt: usize, + #[serde_as(as = "HashMap")] + pub column_comb_stats: HashMap, +} + +impl TableStats { + pub fn new( + row_cnt: usize, + column_comb_stats: HashMap, + ) -> Self { + Self { + row_cnt, + column_comb_stats, + } + } +} + +pub type BaseTableStats = HashMap; + +pub struct CostModelStorageMockManagerImpl { + pub(crate) per_table_stats_map: BaseTableStats, +} + +impl CostModelStorageMockManagerImpl { + pub fn new(per_table_stats_map: BaseTableStats) -> Self { + Self { + per_table_stats_map, + } + } +} + +impl CostModelStorageManager for CostModelStorageMockManagerImpl { + async fn get_attribute_info( + &self, + 
table_id: TableId, + attr_base_index: i32, + ) -> CostModelResult> { + todo!() + } + + async fn get_attributes_comb_statistics( + &self, + table_id: TableId, + attr_base_indices: &[usize], + ) -> CostModelResult> { + todo!() + } +} From 86f6fc2f89c230405587d9bb766215baa4fbdd46 Mon Sep 17 00:00:00 2001 From: Lan Lou Date: Sat, 16 Nov 2024 14:47:37 -0500 Subject: [PATCH 25/51] Fix typo in initial test framework --- optd-cost-model/src/cost_model.rs | 88 ++++++++++++++++++++----------- 1 file changed, 58 insertions(+), 30 deletions(-) diff --git a/optd-cost-model/src/cost_model.rs b/optd-cost-model/src/cost_model.rs index 4700f17..fe3df42 100644 --- a/optd-cost-model/src/cost_model.rs +++ b/optd-cost-model/src/cost_model.rs @@ -144,7 +144,7 @@ mod tests { use super::*; - pub type TestPerColumnStats = AttributeCombValueStats; + pub type TestPerAttributeStats = AttributeCombValueStats; // TODO: add tests for non-mock storage manager pub type TestOptCostModelMock = CostModelImpl; @@ -153,14 +153,17 @@ mod tests { pub const TABLE3_NAME: &str = "table3"; pub const TABLE4_NAME: &str = "table4"; - // one column is sufficient for all filter selectivity tests - pub fn create_one_column_cost_model_mock_storage( - per_column_stats: TestPerColumnStats, + // one attribute is sufficient for all filter selectivity tests + pub fn create_one_attribute_cost_model_mock_storage( + per_attribute_stats: TestPerAttributeStats, ) -> TestOptCostModelMock { let storage_manager = CostModelStorageMockManagerImpl::new( vec![( String::from(TABLE1_NAME), - TableStats::new(100, vec![(vec![0], per_column_stats)].into_iter().collect()), + TableStats::new( + 100, + vec![(vec![0], per_attribute_stats)].into_iter().collect(), + ), )] .into_iter() .collect(), @@ -168,24 +171,24 @@ mod tests { CostModelImpl::new(storage_manager, CatalogSource::Mock) } - /// Create a cost model with two columns, one for each table. Each column has 100 values. 
+ /// Create a cost model with two attributes, one for each table. Each attribute has 100 values. pub fn create_two_table_cost_model_mock_storage( - tbl1_per_column_stats: TestPerColumnStats, - tbl2_per_column_stats: TestPerColumnStats, + tbl1_per_attribute_stats: TestPerAttributeStats, + tbl2_per_attribute_stats: TestPerAttributeStats, ) -> TestOptCostModelMock { create_two_table_cost_model_custom_row_cnts_mock_storage( - tbl1_per_column_stats, - tbl2_per_column_stats, + tbl1_per_attribute_stats, + tbl2_per_attribute_stats, 100, 100, ) } - /// Create a cost model with three columns, one for each table. Each column has 100 values. + /// Create a cost model with three attributes, one for each table. Each attribute has 100 values. pub fn create_three_table_cost_model( - tbl1_per_column_stats: TestPerColumnStats, - tbl2_per_column_stats: TestPerColumnStats, - tbl3_per_column_stats: TestPerColumnStats, + tbl1_per_attribute_stats: TestPerAttributeStats, + tbl2_per_attribute_stats: TestPerAttributeStats, + tbl3_per_attribute_stats: TestPerAttributeStats, ) -> TestOptCostModelMock { let storage_manager = CostModelStorageMockManagerImpl::new( vec![ @@ -193,21 +196,27 @@ mod tests { String::from(TABLE1_NAME), TableStats::new( 100, - vec![(vec![0], tbl1_per_column_stats)].into_iter().collect(), + vec![(vec![0], tbl1_per_attribute_stats)] + .into_iter() + .collect(), ), ), ( String::from(TABLE2_NAME), TableStats::new( 100, - vec![(vec![0], tbl2_per_column_stats)].into_iter().collect(), + vec![(vec![0], tbl2_per_attribute_stats)] + .into_iter() + .collect(), ), ), ( String::from(TABLE3_NAME), TableStats::new( 100, - vec![(vec![0], tbl3_per_column_stats)].into_iter().collect(), + vec![(vec![0], tbl3_per_attribute_stats)] + .into_iter() + .collect(), ), ), ] @@ -217,12 +226,12 @@ mod tests { CostModelImpl::new(storage_manager, CatalogSource::Mock) } - /// Create a cost model with three columns, one for each table. Each column has 100 values. 
+ /// Create a cost model with three attributes, one for each table. Each attribute has 100 values. pub fn create_four_table_cost_model_mock_storage( - tbl1_per_column_stats: TestPerColumnStats, - tbl2_per_column_stats: TestPerColumnStats, - tbl3_per_column_stats: TestPerColumnStats, - tbl4_per_column_stats: TestPerColumnStats, + tbl1_per_attribute_stats: TestPerAttributeStats, + tbl2_per_attribute_stats: TestPerAttributeStats, + tbl3_per_attribute_stats: TestPerAttributeStats, + tbl4_per_attribute_stats: TestPerAttributeStats, ) -> TestOptCostModelMock { let storage_manager = CostModelStorageMockManagerImpl::new( vec![ @@ -230,28 +239,36 @@ mod tests { String::from(TABLE1_NAME), TableStats::new( 100, - vec![(vec![0], tbl1_per_column_stats)].into_iter().collect(), + vec![(vec![0], tbl1_per_attribute_stats)] + .into_iter() + .collect(), ), ), ( String::from(TABLE2_NAME), TableStats::new( 100, - vec![(vec![0], tbl2_per_column_stats)].into_iter().collect(), + vec![(vec![0], tbl2_per_attribute_stats)] + .into_iter() + .collect(), ), ), ( String::from(TABLE3_NAME), TableStats::new( 100, - vec![(vec![0], tbl3_per_column_stats)].into_iter().collect(), + vec![(vec![0], tbl3_per_attribute_stats)] + .into_iter() + .collect(), ), ), ( String::from(TABLE4_NAME), TableStats::new( 100, - vec![(vec![0], tbl4_per_column_stats)].into_iter().collect(), + vec![(vec![0], tbl4_per_attribute_stats)] + .into_iter() + .collect(), ), ), ] @@ -263,8 +280,8 @@ mod tests { /// We need custom row counts because some join algorithms rely on the row cnt pub fn create_two_table_cost_model_custom_row_cnts_mock_storage( - tbl1_per_column_stats: TestPerColumnStats, - tbl2_per_column_stats: TestPerColumnStats, + tbl1_per_attribute_stats: TestPerAttributeStats, + tbl2_per_attribute_stats: TestPerAttributeStats, tbl1_row_cnt: usize, tbl2_row_cnt: usize, ) -> TestOptCostModelMock { @@ -274,14 +291,18 @@ mod tests { String::from(TABLE1_NAME), TableStats::new( tbl1_row_cnt, - vec![(vec![0], 
tbl1_per_column_stats)].into_iter().collect(), + vec![(vec![0], tbl1_per_attribute_stats)] + .into_iter() + .collect(), ), ), ( String::from(TABLE2_NAME), TableStats::new( tbl2_row_cnt, - vec![(vec![0], tbl2_per_column_stats)].into_iter().collect(), + vec![(vec![0], tbl2_per_attribute_stats)] + .into_iter() + .collect(), ), ), ] @@ -340,4 +361,11 @@ mod tests { cnst(Value::String(pattern.into())), ) } + + /// The reason this isn't an associated function of PerAttributeStats is because that would require + /// adding an empty type to the enum definitions of MostCommonValues and Distribution, + /// which I wanted to avoid + pub(crate) fn get_empty_per_col_stats() -> TestPerAttributeStats { + TestPerAttributeStats::new(MostCommonValues::Counter(Counter::default()), 0, 0.0, None) + } } From 2c1f09b8c479edb55f0535328a6d5fb370a2e254 Mon Sep 17 00:00:00 2001 From: Lan Lou Date: Sat, 16 Nov 2024 20:03:18 -0500 Subject: [PATCH 26/51] Modify initial test framework --- Cargo.lock | 1 + optd-cost-model/Cargo.toml | 1 + optd-cost-model/src/cost/filter/controller.rs | 579 ++++++++++++++++++ optd-cost-model/src/cost_model.rs | 180 +----- optd-cost-model/src/stats/counter.rs | 7 + optd-cost-model/src/storage/mock.rs | 2 +- 6 files changed, 609 insertions(+), 161 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 8381193..f6c0033 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2257,6 +2257,7 @@ dependencies = [ "serde", "serde_json", "serde_with", + "tokio", "trait-variant", ] diff --git a/optd-cost-model/Cargo.toml b/optd-cost-model/Cargo.toml index c20c062..4ede352 100644 --- a/optd-cost-model/Cargo.toml +++ b/optd-cost-model/Cargo.toml @@ -17,6 +17,7 @@ chrono = "0.4" itertools = "0.13" assert_approx_eq = "1.1.0" trait-variant = "0.1.2" +tokio = { version = "1.0.1", features = ["macros", "rt-multi-thread"] } [dev-dependencies] crossbeam = "0.8" diff --git a/optd-cost-model/src/cost/filter/controller.rs b/optd-cost-model/src/cost/filter/controller.rs index 35301c1..90735a7 
100644 --- a/optd-cost-model/src/cost/filter/controller.rs +++ b/optd-cost-model/src/cost/filter/controller.rs @@ -87,3 +87,582 @@ impl CostModelImpl { }).await } } + +#[cfg(test)] +mod tests { + use std::collections::HashMap; + + use crate::{ + common::{predicates::bin_op_pred::BinOpType, types::TableId, values::Value}, + cost_model::tests::*, + stats::{counter::Counter, MostCommonValues}, + }; + use arrow_schema::DataType; + + #[tokio::test] + async fn test_const() { + let cost_model = create_cost_model_mock_storage( + vec![TableId(0)], + vec![get_empty_per_attr_stats()], + vec![None], + ); + assert_approx_eq::assert_approx_eq!( + cost_model + .get_filter_selectivity(cnst(Value::Bool(true))) + .await + .unwrap(), + 1.0 + ); + assert_approx_eq::assert_approx_eq!( + cost_model + .get_filter_selectivity(cnst(Value::Bool(false))) + .await + .unwrap(), + 0.0 + ); + } + + // #[tokio::test] + // async fn test_attrref_eq_constint_in_mcv() { + // let mut mcvs_counts = HashMap::new(); + // mcvs_counts.insert(vec![Some(Value::Int32(1))], 3); + // let mcvs_total_count = 10; + // let per_attribute_stats = TestPerAttributeStats::new( + // MostCommonValues::Counter(Counter::new_from_existing(mcvs_counts, mcvs_total_count)), + // 0, + // 0.0, + // None, + // ); + // let table_id = TableId(0); + // let cost_model = + // create_cost_model_mock_storage(vec![table_id], vec![per_attribute_stats], vec![None]); + + // let expr_tree = bin_op(BinOpType::Eq, attr_ref(table_id, 0), cnst(Value::Int32(1))); + // let expr_tree_rev = bin_op(BinOpType::Eq, cnst(Value::Int32(1)), attr_ref(table_id, 0)); + // assert_approx_eq::assert_approx_eq!( + // cost_model.get_filter_selectivity(expr_tree).await.unwrap(), + // 0.3 + // ); + // assert_approx_eq::assert_approx_eq!( + // cost_model + // .get_filter_selectivity(expr_tree_rev) + // .await + // .unwrap(), + // 0.3 + // ); + // } + + // #[test] + // fn test_attrref_eq_constint_not_in_mcv() { + // let cost_model = 
create_one_attribute_cost_model(TestPerAttributeStats::new( + // TestMostCommonValues::new(vec![(Value::Int32(1), 0.2), (Value::Int32(3), 0.44)]), + // 5, + // 0.0, + // Some(TestDistribution::empty()), + // )); + // let expr_tree = bin_op(BinOpType::Eq, attr_ref(0), cnst(Value::Int32(2))); + // let expr_tree_rev = bin_op(BinOpType::Eq, cnst(Value::Int32(2)), attr_ref(0)); + // let schema = Schema::new(vec![]); + // let attribute_refs = vec![AttributeRef::base_table_attribute_ref( + // String::from(TABLE1_NAME), + // 0, + // )]; + // assert_approx_eq::assert_approx_eq!( + // cost_model.get_filter_selectivity(expr_tree, &schema, &attribute_refs), + // 0.12 + // ); + // assert_approx_eq::assert_approx_eq!( + // cost_model.get_filter_selectivity(expr_tree_rev, &schema, &attribute_refs), + // 0.12 + // ); + // } + + // /// I only have one test for NEQ since I'll assume that it uses the same underlying logic as EQ + // #[test] + // fn test_attrref_neq_constint_in_mcv() { + // let cost_model = create_one_attribute_cost_model(TestPerAttributeStats::new( + // TestMostCommonValues::new(vec![(Value::Int32(1), 0.3)]), + // 0, + // 0.0, + // Some(TestDistribution::empty()), + // )); + // let expr_tree = bin_op(BinOpType::Neq, attr_ref(0), cnst(Value::Int32(1))); + // let expr_tree_rev = bin_op(BinOpType::Neq, cnst(Value::Int32(1)), attr_ref(0)); + // let schema = Schema::new(vec![]); + // let attribute_refs = vec![AttributeRef::base_table_attribute_ref( + // String::from(TABLE1_NAME), + // 0, + // )]; + // assert_approx_eq::assert_approx_eq!( + // cost_model.get_filter_selectivity(expr_tree, &schema, &attribute_refs), + // 1.0 - 0.3 + // ); + // assert_approx_eq::assert_approx_eq!( + // cost_model.get_filter_selectivity(expr_tree_rev, &schema, &attribute_refs), + // 1.0 - 0.3 + // ); + // } + + // #[test] + // fn test_attrref_leq_constint_no_mcvs_in_range() { + // let cost_model = create_one_attribute_cost_model(TestPerAttributeStats::new( + // TestMostCommonValues::empty(), + 
// 10, + // 0.0, + // Some(TestDistribution::new(vec![(Value::Int32(15), 0.7)])), + // )); + // let expr_tree = bin_op(BinOpType::Leq, attr_ref(0), cnst(Value::Int32(15))); + // let expr_tree_rev = bin_op(BinOpType::Gt, cnst(Value::Int32(15)), attr_ref(0)); + // let schema = Schema::new(vec![]); + // let attribute_refs = vec![AttributeRef::base_table_attribute_ref( + // String::from(TABLE1_NAME), + // 0, + // )]; + // assert_approx_eq::assert_approx_eq!( + // cost_model.get_filter_selectivity(expr_tree, &schema, &attribute_refs), + // 0.7 + // ); + // assert_approx_eq::assert_approx_eq!( + // cost_model.get_filter_selectivity(expr_tree_rev, &schema, &attribute_refs), + // 0.7 + // ); + // } + + // #[test] + // fn test_attrref_leq_constint_with_mcvs_in_range_not_at_border() { + // let cost_model = create_one_attribute_cost_model(TestPerAttributeStats::new( + // TestMostCommonValues { + // mcvs: vec![ + // (vec![Some(Value::Int32(6))], 0.05), + // (vec![Some(Value::Int32(10))], 0.1), + // (vec![Some(Value::Int32(17))], 0.08), + // (vec![Some(Value::Int32(25))], 0.07), + // ] + // .into_iter() + // .attrlect(), + // }, + // 10, + // 0.0, + // Some(TestDistribution::new(vec![(Value::Int32(15), 0.7)])), + // )); + // let expr_tree = bin_op(BinOpType::Leq, attr_ref(0), cnst(Value::Int32(15))); + // let expr_tree_rev = bin_op(BinOpType::Gt, cnst(Value::Int32(15)), attr_ref(0)); + // let schema = Schema::new(vec![]); + // let attribute_refs = vec![AttributeRef::base_table_attribute_ref( + // String::from(TABLE1_NAME), + // 0, + // )]; + // assert_approx_eq::assert_approx_eq!( + // cost_model.get_filter_selectivity(expr_tree, &schema, &attribute_refs), + // 0.85 + // ); + // assert_approx_eq::assert_approx_eq!( + // cost_model.get_filter_selectivity(expr_tree_rev, &schema, &attribute_refs), + // 0.85 + // ); + // } + + // #[test] + // fn test_attrref_leq_constint_with_mcv_at_border() { + // let cost_model = create_one_attribute_cost_model(TestPerAttributeStats::new( + // 
TestMostCommonValues::new(vec![ + // (Value::Int32(6), 0.05), + // (Value::Int32(10), 0.1), + // (Value::Int32(15), 0.08), + // (Value::Int32(25), 0.07), + // ]), + // 10, + // 0.0, + // Some(TestDistribution::new(vec![(Value::Int32(15), 0.7)])), + // )); + // let expr_tree = bin_op(BinOpType::Leq, attr_ref(0), cnst(Value::Int32(15))); + // let expr_tree_rev = bin_op(BinOpType::Gt, cnst(Value::Int32(15)), attr_ref(0)); + // let schema = Schema::new(vec![]); + // let attribute_refs = vec![AttributeRef::base_table_attribute_ref( + // String::from(TABLE1_NAME), + // 0, + // )]; + // assert_approx_eq::assert_approx_eq!( + // cost_model.get_filter_selectivity(expr_tree, &schema, &attribute_refs), + // 0.93 + // ); + // assert_approx_eq::assert_approx_eq!( + // cost_model.get_filter_selectivity(expr_tree_rev, &schema, &attribute_refs), + // 0.93 + // ); + // } + + // #[test] + // fn test_attrref_lt_constint_no_mcvs_in_range() { + // let cost_model = create_one_attribute_cost_model(TestPerAttributeStats::new( + // TestMostCommonValues::empty(), + // 10, + // 0.0, + // Some(TestDistribution::new(vec![(Value::Int32(15), 0.7)])), + // )); + // let expr_tree = bin_op(BinOpType::Lt, attr_ref(0), cnst(Value::Int32(15))); + // let expr_tree_rev = bin_op(BinOpType::Geq, cnst(Value::Int32(15)), attr_ref(0)); + // let schema = Schema::new(vec![]); + // let attribute_refs = vec![AttributeRef::base_table_attribute_ref( + // String::from(TABLE1_NAME), + // 0, + // )]; + // assert_approx_eq::assert_approx_eq!( + // cost_model.get_filter_selectivity(expr_tree, &schema, &attribute_refs), + // 0.6 + // ); + // assert_approx_eq::assert_approx_eq!( + // cost_model.get_filter_selectivity(expr_tree_rev, &schema, &attribute_refs), + // 0.6 + // ); + // } + + // #[test] + // fn test_attrref_lt_constint_with_mcvs_in_range_not_at_border() { + // let cost_model = create_one_attribute_cost_model(TestPerAttributeStats::new( + // TestMostCommonValues { + // mcvs: vec![ + // 
(vec![Some(Value::Int32(6))], 0.05), + // (vec![Some(Value::Int32(10))], 0.1), + // (vec![Some(Value::Int32(17))], 0.08), + // (vec![Some(Value::Int32(25))], 0.07), + // ] + // .into_iter() + // .attrlect(), + // }, + // 11, /* there are 4 MCVs which together add up to 0.3. With 11 total ndistinct, each + // * remaining value has freq 0.1 */ + // 0.0, + // Some(TestDistribution::new(vec![(Value::Int32(15), 0.7)])), + // )); + // let expr_tree = bin_op(BinOpType::Lt, attr_ref(0), cnst(Value::Int32(15))); + // let expr_tree_rev = bin_op(BinOpType::Geq, cnst(Value::Int32(15)), attr_ref(0)); + // let schema = Schema::new(vec![]); + // let attribute_refs = vec![AttributeRef::base_table_attribute_ref( + // String::from(TABLE1_NAME), + // 0, + // )]; + // assert_approx_eq::assert_approx_eq!( + // cost_model.get_filter_selectivity(expr_tree, &schema, &attribute_refs), + // 0.75 + // ); + // assert_approx_eq::assert_approx_eq!( + // cost_model.get_filter_selectivity(expr_tree_rev, &schema, &attribute_refs), + // 0.75 + // ); + // } + + // #[test] + // fn test_attrref_lt_constint_with_mcv_at_border() { + // let cost_model = create_one_attribute_cost_model(TestPerAttributeStats::new( + // TestMostCommonValues { + // mcvs: vec![ + // (vec![Some(Value::Int32(6))], 0.05), + // (vec![Some(Value::Int32(10))], 0.1), + // (vec![Some(Value::Int32(15))], 0.08), + // (vec![Some(Value::Int32(25))], 0.07), + // ] + // .into_iter() + // .attrlect(), + // }, + // 11, /* there are 4 MCVs which together add up to 0.3. 
With 11 total ndistinct, each + // * remaining value has freq 0.1 */ + // 0.0, + // Some(TestDistribution::new(vec![(Value::Int32(15), 0.7)])), + // )); + // let expr_tree = bin_op(BinOpType::Lt, attr_ref(0), cnst(Value::Int32(15))); + // let expr_tree_rev = bin_op(BinOpType::Geq, cnst(Value::Int32(15)), attr_ref(0)); + // let schema = Schema::new(vec![]); + // let attribute_refs = vec![AttributeRef::base_table_attribute_ref( + // String::from(TABLE1_NAME), + // 0, + // )]; + // assert_approx_eq::assert_approx_eq!( + // cost_model.get_filter_selectivity(expr_tree, &schema, &attribute_refs), + // 0.85 + // ); + // assert_approx_eq::assert_approx_eq!( + // cost_model.get_filter_selectivity(expr_tree_rev, &schema, &attribute_refs), + // 0.85 + // ); + // } + + // /// I have fewer tests for GT since I'll assume that it uses the same underlying logic as LEQ + // /// The only interesting thing to test is that if there are nulls, those aren't included in GT + // #[test] + // fn test_attrref_gt_constint() { + // let cost_model = create_one_attribute_cost_model(TestPerAttributeStats::new( + // TestMostCommonValues::empty(), + // 10, + // 0.0, + // Some(TestDistribution::new(vec![(Value::Int32(15), 0.7)])), + // )); + // let expr_tree = bin_op(BinOpType::Gt, attr_ref(0), cnst(Value::Int32(15))); + // let expr_tree_rev = bin_op(BinOpType::Leq, cnst(Value::Int32(15)), attr_ref(0)); + // let schema = Schema::new(vec![]); + // let attribute_refs = vec![AttributeRef::base_table_attribute_ref( + // String::from(TABLE1_NAME), + // 0, + // )]; + // assert_approx_eq::assert_approx_eq!( + // cost_model.get_filter_selectivity(expr_tree, &schema, &attribute_refs), + // 1.0 - 0.7 + // ); + // assert_approx_eq::assert_approx_eq!( + // cost_model.get_filter_selectivity(expr_tree_rev, &schema, &attribute_refs), + // 1.0 - 0.7 + // ); + // } + + // #[test] + // fn test_attrref_geq_constint() { + // let cost_model = create_one_attribute_cost_model(TestPerAttributeStats::new( + // 
TestMostCommonValues::empty(), + // 10, + // 0.0, + // Some(TestDistribution::new(vec![(Value::Int32(15), 0.7)])), + // )); + // let expr_tree = bin_op(BinOpType::Geq, attr_ref(0), cnst(Value::Int32(15))); + // let expr_tree_rev = bin_op(BinOpType::Lt, cnst(Value::Int32(15)), attr_ref(0)); + // let schema = Schema::new(vec![]); + // let attribute_refs = vec![AttributeRef::base_table_attribute_ref( + // String::from(TABLE1_NAME), + // 0, + // )]; + // assert_approx_eq::assert_approx_eq!( + // cost_model.get_filter_selectivity(expr_tree, &schema, &attribute_refs), + // 1.0 - 0.6 + // ); + // assert_approx_eq::assert_approx_eq!( + // cost_model.get_filter_selectivity(expr_tree_rev, &schema, &attribute_refs), + // 1.0 - 0.6 + // ); + // } + + // #[test] + // fn test_and() { + // let cost_model = create_one_attribute_cost_model(TestPerAttributeStats::new( + // TestMostCommonValues { + // mcvs: vec![ + // (vec![Some(Value::Int32(1))], 0.3), + // (vec![Some(Value::Int32(5))], 0.5), + // (vec![Some(Value::Int32(8))], 0.2), + // ] + // .into_iter() + // .attrlect(), + // }, + // 0, + // 0.0, + // Some(TestDistribution::empty()), + // )); + // let eq1 = bin_op(BinOpType::Eq, attr_ref(0), cnst(Value::Int32(1))); + // let eq5 = bin_op(BinOpType::Eq, attr_ref(0), cnst(Value::Int32(5))); + // let eq8 = bin_op(BinOpType::Eq, attr_ref(0), cnst(Value::Int32(8))); + // let expr_tree = log_op(LogOpType::And, vec![eq1.clone(), eq5.clone(), eq8.clone()]); + // let expr_tree_shift1 = log_op(LogOpType::And, vec![eq5.clone(), eq8.clone(), eq1.clone()]); + // let expr_tree_shift2 = log_op(LogOpType::And, vec![eq8.clone(), eq1.clone(), eq5.clone()]); + // let schema = Schema::new(vec![]); + // let attribute_refs = vec![AttributeRef::base_table_attribute_ref( + // String::from(TABLE1_NAME), + // 0, + // )]; + // assert_approx_eq::assert_approx_eq!( + // cost_model.get_filter_selectivity(expr_tree, &schema, &attribute_refs), + // 0.03 + // ); + // assert_approx_eq::assert_approx_eq!( + // 
cost_model.get_filter_selectivity(expr_tree_shift1, &schema, &attribute_refs), + // 0.03 + // ); + // assert_approx_eq::assert_approx_eq!( + // cost_model.get_filter_selectivity(expr_tree_shift2, &schema, &attribute_refs), + // 0.03 + // ); + // } + + // #[test] + // fn test_or() { + // let cost_model = create_one_attribute_cost_model(TestPerAttributeStats::new( + // TestMostCommonValues { + // mcvs: vec![ + // (vec![Some(Value::Int32(1))], 0.3), + // (vec![Some(Value::Int32(5))], 0.5), + // (vec![Some(Value::Int32(8))], 0.2), + // ] + // .into_iter() + // .attrlect(), + // }, + // 0, + // 0.0, + // Some(TestDistribution::empty()), + // )); + // let eq1 = bin_op(BinOpType::Eq, attr_ref(0), cnst(Value::Int32(1))); + // let eq5 = bin_op(BinOpType::Eq, attr_ref(0), cnst(Value::Int32(5))); + // let eq8 = bin_op(BinOpType::Eq, attr_ref(0), cnst(Value::Int32(8))); + // let expr_tree = log_op(LogOpType::Or, vec![eq1.clone(), eq5.clone(), eq8.clone()]); + // let expr_tree_shift1 = log_op(LogOpType::Or, vec![eq5.clone(), eq8.clone(), eq1.clone()]); + // let expr_tree_shift2 = log_op(LogOpType::Or, vec![eq8.clone(), eq1.clone(), eq5.clone()]); + // let schema = Schema::new(vec![]); + // let attribute_refs = vec![AttributeRef::base_table_attribute_ref( + // String::from(TABLE1_NAME), + // 0, + // )]; + // assert_approx_eq::assert_approx_eq!( + // cost_model.get_filter_selectivity(expr_tree, &schema, &attribute_refs), + // 0.72 + // ); + // assert_approx_eq::assert_approx_eq!( + // cost_model.get_filter_selectivity(expr_tree_shift1, &schema, &attribute_refs), + // 0.72 + // ); + // assert_approx_eq::assert_approx_eq!( + // cost_model.get_filter_selectivity(expr_tree_shift2, &schema, &attribute_refs), + // 0.72 + // ); + // } + + // #[test] + // fn test_not() { + // let cost_model = create_one_attribute_cost_model(TestPerAttributeStats::new( + // TestMostCommonValues::new(vec![(Value::Int32(1), 0.3)]), + // 0, + // 0.0, + // Some(TestDistribution::empty()), + // )); + // let 
expr_tree = un_op( + // UnOpType::Not, + // bin_op(BinOpType::Eq, attr_ref(0), cnst(Value::Int32(1))), + // ); + // let schema = Schema::new(vec![]); + // let attribute_refs = vec![AttributeRef::base_table_attribute_ref( + // String::from(TABLE1_NAME), + // 0, + // )]; + // assert_approx_eq::assert_approx_eq!( + // cost_model.get_filter_selectivity(expr_tree, &schema, &attribute_refs), + // 0.7 + // ); + // } + + // // I didn't test any non-unique cases with filter. The non-unique tests without filter should + // // cover that + + // #[test] + // fn test_attrref_eq_cast_value() { + // let cost_model = create_one_attribute_cost_model(TestPerAttributeStats::new( + // TestMostCommonValues::new(vec![(Value::Int32(1), 0.3)]), + // 0, + // 0.1, + // Some(TestDistribution::empty()), + // )); + // let expr_tree = bin_op( + // BinOpType::Eq, + // attr_ref(0), + // cast(cnst(Value::Int64(1)), DataType::Int32), + // ); + // let expr_tree_rev = bin_op( + // BinOpType::Eq, + // cast(cnst(Value::Int64(1)), DataType::Int32), + // attr_ref(0), + // ); + // let schema = Schema::new(vec![]); + // let attribute_refs = vec![AttributeRef::base_table_attribute_ref( + // String::from(TABLE1_NAME), + // 0, + // )]; + // assert_approx_eq::assert_approx_eq!( + // cost_model.get_filter_selectivity(expr_tree, &schema, &attribute_refs), + // 0.3 + // ); + // assert_approx_eq::assert_approx_eq!( + // cost_model.get_filter_selectivity(expr_tree_rev, &schema, &attribute_refs), + // 0.3 + // ); + // } + + // #[test] + // fn test_cast_attrref_eq_value() { + // let cost_model = create_one_attribute_cost_model(TestPerAttributeStats::new( + // TestMostCommonValues::new(vec![(Value::Int32(1), 0.3)]), + // 0, + // 0.1, + // Some(TestDistribution::empty()), + // )); + // let expr_tree = bin_op( + // BinOpType::Eq, + // cast(attr_ref(0), DataType::Int64), + // cnst(Value::Int64(1)), + // ); + // let expr_tree_rev = bin_op( + // BinOpType::Eq, + // cnst(Value::Int64(1)), + // cast(attr_ref(0), 
DataType::Int64), + // ); + // let schema = Schema::new(vec![Field { + // name: String::from(""), + // typ: ConstantType::Int32, + // nullable: false, + // }]); + // let attribute_refs = vec![AttributeRef::base_table_attribute_ref( + // String::from(TABLE1_NAME), + // 0, + // )]; + // assert_approx_eq::assert_approx_eq!( + // cost_model.get_filter_selectivity(expr_tree, &schema, &attribute_refs), + // 0.3 + // ); + // assert_approx_eq::assert_approx_eq!( + // cost_model.get_filter_selectivity(expr_tree_rev, &schema, &attribute_refs), + // 0.3 + // ); + // } + + // /// In this case, we should leave the Cast as is. + // /// + // /// Note that the test only checks the selectivity and thus doesn't explicitly test that the + // /// Cast is indeed left as is. However, if get_filter_selectivity() doesn't crash, that's a + // /// pretty good signal that the Cast was left as is. + // #[test] + // fn test_cast_attrref_eq_attrref() { + // let cost_model = create_one_attribute_cost_model(TestPerAttributeStats::new( + // TestMostCommonValues::new(vec![]), + // 0, + // 0.0, + // Some(TestDistribution::empty()), + // )); + // let expr_tree = bin_op( + // BinOpType::Eq, + // cast(attr_ref(0), DataType::Int64), + // attr_ref(1), + // ); + // let expr_tree_rev = bin_op( + // BinOpType::Eq, + // attr_ref(1), + // cast(attr_ref(0), DataType::Int64), + // ); + // let schema = Schema::new(vec![ + // Field { + // name: String::from(""), + // typ: ConstantType::Int32, + // nullable: false, + // }, + // Field { + // name: String::from(""), + // typ: ConstantType::Int64, + // nullable: false, + // }, + // ]); + // let attribute_refs = vec![AttributeRef::base_table_attribute_ref( + // String::from(TABLE1_NAME), + // 0, + // )]; + // assert_approx_eq::assert_approx_eq!( + // cost_model.get_filter_selectivity(expr_tree, &schema, &attribute_refs), + // DEFAULT_EQ_SEL + // ); + // assert_approx_eq::assert_approx_eq!( + // cost_model.get_filter_selectivity(expr_tree_rev, &schema, 
&attribute_refs), + // DEFAULT_EQ_SEL + // ); + // } +} diff --git a/optd-cost-model/src/cost_model.rs b/optd-cost-model/src/cost_model.rs index fe3df42..03bcd60 100644 --- a/optd-cost-model/src/cost_model.rs +++ b/optd-cost-model/src/cost_model.rs @@ -111,7 +111,7 @@ impl CostModelImpl { /// but this is not currently feasible because it would create a cyclic dependency between /// optd-datafusion-bridge and optd-datafusion-repr #[cfg(test)] -mod tests { +pub mod tests { use std::collections::HashMap; use arrow_schema::DataType; @@ -148,166 +148,26 @@ mod tests { // TODO: add tests for non-mock storage manager pub type TestOptCostModelMock = CostModelImpl; - pub const TABLE1_NAME: &str = "table1"; - pub const TABLE2_NAME: &str = "table2"; - pub const TABLE3_NAME: &str = "table3"; - pub const TABLE4_NAME: &str = "table4"; - - // one attribute is sufficient for all filter selectivity tests - pub fn create_one_attribute_cost_model_mock_storage( - per_attribute_stats: TestPerAttributeStats, - ) -> TestOptCostModelMock { - let storage_manager = CostModelStorageMockManagerImpl::new( - vec![( - String::from(TABLE1_NAME), - TableStats::new( - 100, - vec![(vec![0], per_attribute_stats)].into_iter().collect(), - ), - )] - .into_iter() - .collect(), - ); - CostModelImpl::new(storage_manager, CatalogSource::Mock) - } - - /// Create a cost model with two attributes, one for each table. Each attribute has 100 values. - pub fn create_two_table_cost_model_mock_storage( - tbl1_per_attribute_stats: TestPerAttributeStats, - tbl2_per_attribute_stats: TestPerAttributeStats, - ) -> TestOptCostModelMock { - create_two_table_cost_model_custom_row_cnts_mock_storage( - tbl1_per_attribute_stats, - tbl2_per_attribute_stats, - 100, - 100, - ) - } - - /// Create a cost model with three attributes, one for each table. Each attribute has 100 values. 
- pub fn create_three_table_cost_model( - tbl1_per_attribute_stats: TestPerAttributeStats, - tbl2_per_attribute_stats: TestPerAttributeStats, - tbl3_per_attribute_stats: TestPerAttributeStats, - ) -> TestOptCostModelMock { - let storage_manager = CostModelStorageMockManagerImpl::new( - vec![ - ( - String::from(TABLE1_NAME), - TableStats::new( - 100, - vec![(vec![0], tbl1_per_attribute_stats)] - .into_iter() - .collect(), - ), - ), - ( - String::from(TABLE2_NAME), - TableStats::new( - 100, - vec![(vec![0], tbl2_per_attribute_stats)] - .into_iter() - .collect(), - ), - ), - ( - String::from(TABLE3_NAME), - TableStats::new( - 100, - vec![(vec![0], tbl3_per_attribute_stats)] - .into_iter() - .collect(), - ), - ), - ] - .into_iter() - .collect(), - ); - CostModelImpl::new(storage_manager, CatalogSource::Mock) - } - - /// Create a cost model with three attributes, one for each table. Each attribute has 100 values. - pub fn create_four_table_cost_model_mock_storage( - tbl1_per_attribute_stats: TestPerAttributeStats, - tbl2_per_attribute_stats: TestPerAttributeStats, - tbl3_per_attribute_stats: TestPerAttributeStats, - tbl4_per_attribute_stats: TestPerAttributeStats, - ) -> TestOptCostModelMock { - let storage_manager = CostModelStorageMockManagerImpl::new( - vec![ - ( - String::from(TABLE1_NAME), - TableStats::new( - 100, - vec![(vec![0], tbl1_per_attribute_stats)] - .into_iter() - .collect(), - ), - ), - ( - String::from(TABLE2_NAME), - TableStats::new( - 100, - vec![(vec![0], tbl2_per_attribute_stats)] - .into_iter() - .collect(), - ), - ), - ( - String::from(TABLE3_NAME), - TableStats::new( - 100, - vec![(vec![0], tbl3_per_attribute_stats)] - .into_iter() - .collect(), - ), - ), - ( - String::from(TABLE4_NAME), - TableStats::new( - 100, - vec![(vec![0], tbl4_per_attribute_stats)] - .into_iter() - .collect(), - ), - ), - ] - .into_iter() - .collect(), - ); - CostModelImpl::new(storage_manager, CatalogSource::Mock) - } - - /// We need custom row counts because some join 
algorithms rely on the row cnt - pub fn create_two_table_cost_model_custom_row_cnts_mock_storage( - tbl1_per_attribute_stats: TestPerAttributeStats, - tbl2_per_attribute_stats: TestPerAttributeStats, - tbl1_row_cnt: usize, - tbl2_row_cnt: usize, + pub fn create_cost_model_mock_storage( + table_id: Vec, + per_attribute_stats: Vec, + row_counts: Vec>, ) -> TestOptCostModelMock { let storage_manager = CostModelStorageMockManagerImpl::new( - vec![ - ( - String::from(TABLE1_NAME), - TableStats::new( - tbl1_row_cnt, - vec![(vec![0], tbl1_per_attribute_stats)] - .into_iter() - .collect(), - ), - ), - ( - String::from(TABLE2_NAME), - TableStats::new( - tbl2_row_cnt, - vec![(vec![0], tbl2_per_attribute_stats)] - .into_iter() - .collect(), - ), - ), - ] - .into_iter() - .collect(), + table_id + .into_iter() + .zip(per_attribute_stats) + .zip(row_counts) + .map(|((table_id, per_attr_stats), row_count)| { + ( + table_id, + TableStats::new( + row_count.unwrap_or(100), + vec![(vec![0], per_attr_stats)].into_iter().collect(), + ), + ) + }) + .collect(), ); CostModelImpl::new(storage_manager, CatalogSource::Mock) } @@ -365,7 +225,7 @@ mod tests { /// The reason this isn't an associated function of PerAttributeStats is because that would require /// adding an empty type to the enum definitions of MostCommonValues and Distribution, /// which I wanted to avoid - pub(crate) fn get_empty_per_col_stats() -> TestPerAttributeStats { + pub(crate) fn get_empty_per_attr_stats() -> TestPerAttributeStats { TestPerAttributeStats::new(MostCommonValues::Counter(Counter::default()), 0, 0.0, None) } } diff --git a/optd-cost-model/src/stats/counter.rs b/optd-cost-model/src/stats/counter.rs index 65a2d63..ddffb7a 100644 --- a/optd-cost-model/src/stats/counter.rs +++ b/optd-cost-model/src/stats/counter.rs @@ -32,6 +32,13 @@ where } } + pub fn new_from_existing(counts: HashMap, total_count: i32) -> Self { + Counter:: { + counts, + total_count, + } + } + // Inserts an element in the Counter if it is 
being tracked. fn insert_element(&mut self, elem: T, occ: i32) { if let Some(frequency) = self.counts.get_mut(&elem) { diff --git a/optd-cost-model/src/storage/mock.rs b/optd-cost-model/src/storage/mock.rs index 8a1f962..3f8a8bf 100644 --- a/optd-cost-model/src/storage/mock.rs +++ b/optd-cost-model/src/storage/mock.rs @@ -29,7 +29,7 @@ impl TableStats { } } -pub type BaseTableStats = HashMap; +pub type BaseTableStats = HashMap; pub struct CostModelStorageMockManagerImpl { pub(crate) per_table_stats_map: BaseTableStats, From ebab829605bad07782079a61cc09acdfea5be147 Mon Sep 17 00:00:00 2001 From: Lan Lou Date: Sat, 16 Nov 2024 21:28:29 -0500 Subject: [PATCH 27/51] Finish most tests for filter --- optd-cost-model/src/cost/filter/controller.rs | 1297 ++++++++++------- optd-cost-model/src/cost_model.rs | 10 +- optd-cost-model/src/stats/mod.rs | 30 +- .../src/stats/{ => utilities}/counter.rs | 3 +- optd-cost-model/src/stats/utilities/mod.rs | 3 + .../src/stats/utilities/simple_map.rs | 20 + .../src/stats/{ => utilities}/tdigest.rs | 4 +- optd-cost-model/src/storage/mock.rs | 26 +- optd-cost-model/src/storage/persistent.rs | 2 +- 9 files changed, 829 insertions(+), 566 deletions(-) rename optd-cost-model/src/stats/{ => utilities}/counter.rs (97%) create mode 100644 optd-cost-model/src/stats/utilities/mod.rs create mode 100644 optd-cost-model/src/stats/utilities/simple_map.rs rename optd-cost-model/src/stats/{ => utilities}/tdigest.rs (99%) diff --git a/optd-cost-model/src/cost/filter/controller.rs b/optd-cost-model/src/cost/filter/controller.rs index 90735a7..39462ae 100644 --- a/optd-cost-model/src/cost/filter/controller.rs +++ b/optd-cost-model/src/cost/filter/controller.rs @@ -93,9 +93,20 @@ mod tests { use std::collections::HashMap; use crate::{ - common::{predicates::bin_op_pred::BinOpType, types::TableId, values::Value}, + common::{ + predicates::{ + bin_op_pred::BinOpType, constant_pred::ConstantType, log_op_pred::LogOpType, + un_op_pred::UnOpType, + }, + 
types::TableId, + values::Value, + }, cost_model::tests::*, - stats::{counter::Counter, MostCommonValues}, + stats::{ + utilities::{counter::Counter, simple_map::SimpleMap}, + Distribution, MostCommonValues, DEFAULT_EQ_SEL, + }, + storage::Attribute, }; use arrow_schema::DataType; @@ -105,6 +116,7 @@ mod tests { vec![TableId(0)], vec![get_empty_per_attr_stats()], vec![None], + HashMap::new(), ); assert_approx_eq::assert_approx_eq!( cost_model @@ -122,547 +134,742 @@ mod tests { ); } - // #[tokio::test] - // async fn test_attrref_eq_constint_in_mcv() { - // let mut mcvs_counts = HashMap::new(); - // mcvs_counts.insert(vec![Some(Value::Int32(1))], 3); - // let mcvs_total_count = 10; - // let per_attribute_stats = TestPerAttributeStats::new( - // MostCommonValues::Counter(Counter::new_from_existing(mcvs_counts, mcvs_total_count)), - // 0, - // 0.0, - // None, - // ); - // let table_id = TableId(0); - // let cost_model = - // create_cost_model_mock_storage(vec![table_id], vec![per_attribute_stats], vec![None]); - - // let expr_tree = bin_op(BinOpType::Eq, attr_ref(table_id, 0), cnst(Value::Int32(1))); - // let expr_tree_rev = bin_op(BinOpType::Eq, cnst(Value::Int32(1)), attr_ref(table_id, 0)); - // assert_approx_eq::assert_approx_eq!( - // cost_model.get_filter_selectivity(expr_tree).await.unwrap(), - // 0.3 - // ); - // assert_approx_eq::assert_approx_eq!( - // cost_model - // .get_filter_selectivity(expr_tree_rev) - // .await - // .unwrap(), - // 0.3 - // ); - // } - - // #[test] - // fn test_attrref_eq_constint_not_in_mcv() { - // let cost_model = create_one_attribute_cost_model(TestPerAttributeStats::new( - // TestMostCommonValues::new(vec![(Value::Int32(1), 0.2), (Value::Int32(3), 0.44)]), - // 5, - // 0.0, - // Some(TestDistribution::empty()), - // )); - // let expr_tree = bin_op(BinOpType::Eq, attr_ref(0), cnst(Value::Int32(2))); - // let expr_tree_rev = bin_op(BinOpType::Eq, cnst(Value::Int32(2)), attr_ref(0)); - // let schema = Schema::new(vec![]); - // let 
attribute_refs = vec![AttributeRef::base_table_attribute_ref( - // String::from(TABLE1_NAME), - // 0, - // )]; - // assert_approx_eq::assert_approx_eq!( - // cost_model.get_filter_selectivity(expr_tree, &schema, &attribute_refs), - // 0.12 - // ); - // assert_approx_eq::assert_approx_eq!( - // cost_model.get_filter_selectivity(expr_tree_rev, &schema, &attribute_refs), - // 0.12 - // ); - // } - - // /// I only have one test for NEQ since I'll assume that it uses the same underlying logic as EQ - // #[test] - // fn test_attrref_neq_constint_in_mcv() { - // let cost_model = create_one_attribute_cost_model(TestPerAttributeStats::new( - // TestMostCommonValues::new(vec![(Value::Int32(1), 0.3)]), - // 0, - // 0.0, - // Some(TestDistribution::empty()), - // )); - // let expr_tree = bin_op(BinOpType::Neq, attr_ref(0), cnst(Value::Int32(1))); - // let expr_tree_rev = bin_op(BinOpType::Neq, cnst(Value::Int32(1)), attr_ref(0)); - // let schema = Schema::new(vec![]); - // let attribute_refs = vec![AttributeRef::base_table_attribute_ref( - // String::from(TABLE1_NAME), - // 0, - // )]; - // assert_approx_eq::assert_approx_eq!( - // cost_model.get_filter_selectivity(expr_tree, &schema, &attribute_refs), - // 1.0 - 0.3 - // ); - // assert_approx_eq::assert_approx_eq!( - // cost_model.get_filter_selectivity(expr_tree_rev, &schema, &attribute_refs), - // 1.0 - 0.3 - // ); - // } - - // #[test] - // fn test_attrref_leq_constint_no_mcvs_in_range() { - // let cost_model = create_one_attribute_cost_model(TestPerAttributeStats::new( - // TestMostCommonValues::empty(), - // 10, - // 0.0, - // Some(TestDistribution::new(vec![(Value::Int32(15), 0.7)])), - // )); - // let expr_tree = bin_op(BinOpType::Leq, attr_ref(0), cnst(Value::Int32(15))); - // let expr_tree_rev = bin_op(BinOpType::Gt, cnst(Value::Int32(15)), attr_ref(0)); - // let schema = Schema::new(vec![]); - // let attribute_refs = vec![AttributeRef::base_table_attribute_ref( - // String::from(TABLE1_NAME), - // 0, - // )]; - // 
assert_approx_eq::assert_approx_eq!( - // cost_model.get_filter_selectivity(expr_tree, &schema, &attribute_refs), - // 0.7 - // ); - // assert_approx_eq::assert_approx_eq!( - // cost_model.get_filter_selectivity(expr_tree_rev, &schema, &attribute_refs), - // 0.7 - // ); - // } - - // #[test] - // fn test_attrref_leq_constint_with_mcvs_in_range_not_at_border() { - // let cost_model = create_one_attribute_cost_model(TestPerAttributeStats::new( - // TestMostCommonValues { - // mcvs: vec![ - // (vec![Some(Value::Int32(6))], 0.05), - // (vec![Some(Value::Int32(10))], 0.1), - // (vec![Some(Value::Int32(17))], 0.08), - // (vec![Some(Value::Int32(25))], 0.07), - // ] - // .into_iter() - // .attrlect(), - // }, - // 10, - // 0.0, - // Some(TestDistribution::new(vec![(Value::Int32(15), 0.7)])), - // )); - // let expr_tree = bin_op(BinOpType::Leq, attr_ref(0), cnst(Value::Int32(15))); - // let expr_tree_rev = bin_op(BinOpType::Gt, cnst(Value::Int32(15)), attr_ref(0)); - // let schema = Schema::new(vec![]); - // let attribute_refs = vec![AttributeRef::base_table_attribute_ref( - // String::from(TABLE1_NAME), - // 0, - // )]; - // assert_approx_eq::assert_approx_eq!( - // cost_model.get_filter_selectivity(expr_tree, &schema, &attribute_refs), - // 0.85 - // ); - // assert_approx_eq::assert_approx_eq!( - // cost_model.get_filter_selectivity(expr_tree_rev, &schema, &attribute_refs), - // 0.85 - // ); - // } - - // #[test] - // fn test_attrref_leq_constint_with_mcv_at_border() { - // let cost_model = create_one_attribute_cost_model(TestPerAttributeStats::new( - // TestMostCommonValues::new(vec![ - // (Value::Int32(6), 0.05), - // (Value::Int32(10), 0.1), - // (Value::Int32(15), 0.08), - // (Value::Int32(25), 0.07), - // ]), - // 10, - // 0.0, - // Some(TestDistribution::new(vec![(Value::Int32(15), 0.7)])), - // )); - // let expr_tree = bin_op(BinOpType::Leq, attr_ref(0), cnst(Value::Int32(15))); - // let expr_tree_rev = bin_op(BinOpType::Gt, cnst(Value::Int32(15)), attr_ref(0)); - 
// let schema = Schema::new(vec![]); - // let attribute_refs = vec![AttributeRef::base_table_attribute_ref( - // String::from(TABLE1_NAME), - // 0, - // )]; - // assert_approx_eq::assert_approx_eq!( - // cost_model.get_filter_selectivity(expr_tree, &schema, &attribute_refs), - // 0.93 - // ); - // assert_approx_eq::assert_approx_eq!( - // cost_model.get_filter_selectivity(expr_tree_rev, &schema, &attribute_refs), - // 0.93 - // ); - // } - - // #[test] - // fn test_attrref_lt_constint_no_mcvs_in_range() { - // let cost_model = create_one_attribute_cost_model(TestPerAttributeStats::new( - // TestMostCommonValues::empty(), - // 10, - // 0.0, - // Some(TestDistribution::new(vec![(Value::Int32(15), 0.7)])), - // )); - // let expr_tree = bin_op(BinOpType::Lt, attr_ref(0), cnst(Value::Int32(15))); - // let expr_tree_rev = bin_op(BinOpType::Geq, cnst(Value::Int32(15)), attr_ref(0)); - // let schema = Schema::new(vec![]); - // let attribute_refs = vec![AttributeRef::base_table_attribute_ref( - // String::from(TABLE1_NAME), - // 0, - // )]; - // assert_approx_eq::assert_approx_eq!( - // cost_model.get_filter_selectivity(expr_tree, &schema, &attribute_refs), - // 0.6 - // ); - // assert_approx_eq::assert_approx_eq!( - // cost_model.get_filter_selectivity(expr_tree_rev, &schema, &attribute_refs), - // 0.6 - // ); - // } - - // #[test] - // fn test_attrref_lt_constint_with_mcvs_in_range_not_at_border() { - // let cost_model = create_one_attribute_cost_model(TestPerAttributeStats::new( - // TestMostCommonValues { - // mcvs: vec![ - // (vec![Some(Value::Int32(6))], 0.05), - // (vec![Some(Value::Int32(10))], 0.1), - // (vec![Some(Value::Int32(17))], 0.08), - // (vec![Some(Value::Int32(25))], 0.07), - // ] - // .into_iter() - // .attrlect(), - // }, - // 11, /* there are 4 MCVs which together add up to 0.3. 
With 11 total ndistinct, each - // * remaining value has freq 0.1 */ - // 0.0, - // Some(TestDistribution::new(vec![(Value::Int32(15), 0.7)])), - // )); - // let expr_tree = bin_op(BinOpType::Lt, attr_ref(0), cnst(Value::Int32(15))); - // let expr_tree_rev = bin_op(BinOpType::Geq, cnst(Value::Int32(15)), attr_ref(0)); - // let schema = Schema::new(vec![]); - // let attribute_refs = vec![AttributeRef::base_table_attribute_ref( - // String::from(TABLE1_NAME), - // 0, - // )]; - // assert_approx_eq::assert_approx_eq!( - // cost_model.get_filter_selectivity(expr_tree, &schema, &attribute_refs), - // 0.75 - // ); - // assert_approx_eq::assert_approx_eq!( - // cost_model.get_filter_selectivity(expr_tree_rev, &schema, &attribute_refs), - // 0.75 - // ); - // } - - // #[test] - // fn test_attrref_lt_constint_with_mcv_at_border() { - // let cost_model = create_one_attribute_cost_model(TestPerAttributeStats::new( - // TestMostCommonValues { - // mcvs: vec![ - // (vec![Some(Value::Int32(6))], 0.05), - // (vec![Some(Value::Int32(10))], 0.1), - // (vec![Some(Value::Int32(15))], 0.08), - // (vec![Some(Value::Int32(25))], 0.07), - // ] - // .into_iter() - // .attrlect(), - // }, - // 11, /* there are 4 MCVs which together add up to 0.3. 
With 11 total ndistinct, each - // * remaining value has freq 0.1 */ - // 0.0, - // Some(TestDistribution::new(vec![(Value::Int32(15), 0.7)])), - // )); - // let expr_tree = bin_op(BinOpType::Lt, attr_ref(0), cnst(Value::Int32(15))); - // let expr_tree_rev = bin_op(BinOpType::Geq, cnst(Value::Int32(15)), attr_ref(0)); - // let schema = Schema::new(vec![]); - // let attribute_refs = vec![AttributeRef::base_table_attribute_ref( - // String::from(TABLE1_NAME), - // 0, - // )]; - // assert_approx_eq::assert_approx_eq!( - // cost_model.get_filter_selectivity(expr_tree, &schema, &attribute_refs), - // 0.85 - // ); - // assert_approx_eq::assert_approx_eq!( - // cost_model.get_filter_selectivity(expr_tree_rev, &schema, &attribute_refs), - // 0.85 - // ); - // } - - // /// I have fewer tests for GT since I'll assume that it uses the same underlying logic as LEQ - // /// The only interesting thing to test is that if there are nulls, those aren't included in GT - // #[test] - // fn test_attrref_gt_constint() { - // let cost_model = create_one_attribute_cost_model(TestPerAttributeStats::new( - // TestMostCommonValues::empty(), - // 10, - // 0.0, - // Some(TestDistribution::new(vec![(Value::Int32(15), 0.7)])), - // )); - // let expr_tree = bin_op(BinOpType::Gt, attr_ref(0), cnst(Value::Int32(15))); - // let expr_tree_rev = bin_op(BinOpType::Leq, cnst(Value::Int32(15)), attr_ref(0)); - // let schema = Schema::new(vec![]); - // let attribute_refs = vec![AttributeRef::base_table_attribute_ref( - // String::from(TABLE1_NAME), - // 0, - // )]; - // assert_approx_eq::assert_approx_eq!( - // cost_model.get_filter_selectivity(expr_tree, &schema, &attribute_refs), - // 1.0 - 0.7 - // ); - // assert_approx_eq::assert_approx_eq!( - // cost_model.get_filter_selectivity(expr_tree_rev, &schema, &attribute_refs), - // 1.0 - 0.7 - // ); - // } - - // #[test] - // fn test_attrref_geq_constint() { - // let cost_model = create_one_attribute_cost_model(TestPerAttributeStats::new( - // 
TestMostCommonValues::empty(), - // 10, - // 0.0, - // Some(TestDistribution::new(vec![(Value::Int32(15), 0.7)])), - // )); - // let expr_tree = bin_op(BinOpType::Geq, attr_ref(0), cnst(Value::Int32(15))); - // let expr_tree_rev = bin_op(BinOpType::Lt, cnst(Value::Int32(15)), attr_ref(0)); - // let schema = Schema::new(vec![]); - // let attribute_refs = vec![AttributeRef::base_table_attribute_ref( - // String::from(TABLE1_NAME), - // 0, - // )]; - // assert_approx_eq::assert_approx_eq!( - // cost_model.get_filter_selectivity(expr_tree, &schema, &attribute_refs), - // 1.0 - 0.6 - // ); - // assert_approx_eq::assert_approx_eq!( - // cost_model.get_filter_selectivity(expr_tree_rev, &schema, &attribute_refs), - // 1.0 - 0.6 - // ); - // } - - // #[test] - // fn test_and() { - // let cost_model = create_one_attribute_cost_model(TestPerAttributeStats::new( - // TestMostCommonValues { - // mcvs: vec![ - // (vec![Some(Value::Int32(1))], 0.3), - // (vec![Some(Value::Int32(5))], 0.5), - // (vec![Some(Value::Int32(8))], 0.2), - // ] - // .into_iter() - // .attrlect(), - // }, - // 0, - // 0.0, - // Some(TestDistribution::empty()), - // )); - // let eq1 = bin_op(BinOpType::Eq, attr_ref(0), cnst(Value::Int32(1))); - // let eq5 = bin_op(BinOpType::Eq, attr_ref(0), cnst(Value::Int32(5))); - // let eq8 = bin_op(BinOpType::Eq, attr_ref(0), cnst(Value::Int32(8))); - // let expr_tree = log_op(LogOpType::And, vec![eq1.clone(), eq5.clone(), eq8.clone()]); - // let expr_tree_shift1 = log_op(LogOpType::And, vec![eq5.clone(), eq8.clone(), eq1.clone()]); - // let expr_tree_shift2 = log_op(LogOpType::And, vec![eq8.clone(), eq1.clone(), eq5.clone()]); - // let schema = Schema::new(vec![]); - // let attribute_refs = vec![AttributeRef::base_table_attribute_ref( - // String::from(TABLE1_NAME), - // 0, - // )]; - // assert_approx_eq::assert_approx_eq!( - // cost_model.get_filter_selectivity(expr_tree, &schema, &attribute_refs), - // 0.03 - // ); - // assert_approx_eq::assert_approx_eq!( - // 
cost_model.get_filter_selectivity(expr_tree_shift1, &schema, &attribute_refs), - // 0.03 - // ); - // assert_approx_eq::assert_approx_eq!( - // cost_model.get_filter_selectivity(expr_tree_shift2, &schema, &attribute_refs), - // 0.03 - // ); - // } - - // #[test] - // fn test_or() { - // let cost_model = create_one_attribute_cost_model(TestPerAttributeStats::new( - // TestMostCommonValues { - // mcvs: vec![ - // (vec![Some(Value::Int32(1))], 0.3), - // (vec![Some(Value::Int32(5))], 0.5), - // (vec![Some(Value::Int32(8))], 0.2), - // ] - // .into_iter() - // .attrlect(), - // }, - // 0, - // 0.0, - // Some(TestDistribution::empty()), - // )); - // let eq1 = bin_op(BinOpType::Eq, attr_ref(0), cnst(Value::Int32(1))); - // let eq5 = bin_op(BinOpType::Eq, attr_ref(0), cnst(Value::Int32(5))); - // let eq8 = bin_op(BinOpType::Eq, attr_ref(0), cnst(Value::Int32(8))); - // let expr_tree = log_op(LogOpType::Or, vec![eq1.clone(), eq5.clone(), eq8.clone()]); - // let expr_tree_shift1 = log_op(LogOpType::Or, vec![eq5.clone(), eq8.clone(), eq1.clone()]); - // let expr_tree_shift2 = log_op(LogOpType::Or, vec![eq8.clone(), eq1.clone(), eq5.clone()]); - // let schema = Schema::new(vec![]); - // let attribute_refs = vec![AttributeRef::base_table_attribute_ref( - // String::from(TABLE1_NAME), - // 0, - // )]; - // assert_approx_eq::assert_approx_eq!( - // cost_model.get_filter_selectivity(expr_tree, &schema, &attribute_refs), - // 0.72 - // ); - // assert_approx_eq::assert_approx_eq!( - // cost_model.get_filter_selectivity(expr_tree_shift1, &schema, &attribute_refs), - // 0.72 - // ); - // assert_approx_eq::assert_approx_eq!( - // cost_model.get_filter_selectivity(expr_tree_shift2, &schema, &attribute_refs), - // 0.72 - // ); - // } - - // #[test] - // fn test_not() { - // let cost_model = create_one_attribute_cost_model(TestPerAttributeStats::new( - // TestMostCommonValues::new(vec![(Value::Int32(1), 0.3)]), - // 0, - // 0.0, - // Some(TestDistribution::empty()), - // )); - // let 
expr_tree = un_op( - // UnOpType::Not, - // bin_op(BinOpType::Eq, attr_ref(0), cnst(Value::Int32(1))), - // ); - // let schema = Schema::new(vec![]); - // let attribute_refs = vec![AttributeRef::base_table_attribute_ref( - // String::from(TABLE1_NAME), - // 0, - // )]; - // assert_approx_eq::assert_approx_eq!( - // cost_model.get_filter_selectivity(expr_tree, &schema, &attribute_refs), - // 0.7 - // ); - // } - - // // I didn't test any non-unique cases with filter. The non-unique tests without filter should - // // cover that - - // #[test] - // fn test_attrref_eq_cast_value() { - // let cost_model = create_one_attribute_cost_model(TestPerAttributeStats::new( - // TestMostCommonValues::new(vec![(Value::Int32(1), 0.3)]), - // 0, - // 0.1, - // Some(TestDistribution::empty()), - // )); - // let expr_tree = bin_op( - // BinOpType::Eq, - // attr_ref(0), - // cast(cnst(Value::Int64(1)), DataType::Int32), - // ); - // let expr_tree_rev = bin_op( - // BinOpType::Eq, - // cast(cnst(Value::Int64(1)), DataType::Int32), - // attr_ref(0), - // ); - // let schema = Schema::new(vec![]); - // let attribute_refs = vec![AttributeRef::base_table_attribute_ref( - // String::from(TABLE1_NAME), - // 0, - // )]; - // assert_approx_eq::assert_approx_eq!( - // cost_model.get_filter_selectivity(expr_tree, &schema, &attribute_refs), - // 0.3 - // ); - // assert_approx_eq::assert_approx_eq!( - // cost_model.get_filter_selectivity(expr_tree_rev, &schema, &attribute_refs), - // 0.3 - // ); - // } - - // #[test] - // fn test_cast_attrref_eq_value() { - // let cost_model = create_one_attribute_cost_model(TestPerAttributeStats::new( - // TestMostCommonValues::new(vec![(Value::Int32(1), 0.3)]), - // 0, - // 0.1, - // Some(TestDistribution::empty()), - // )); - // let expr_tree = bin_op( - // BinOpType::Eq, - // cast(attr_ref(0), DataType::Int64), - // cnst(Value::Int64(1)), - // ); - // let expr_tree_rev = bin_op( - // BinOpType::Eq, - // cnst(Value::Int64(1)), - // cast(attr_ref(0), 
DataType::Int64), - // ); - // let schema = Schema::new(vec![Field { - // name: String::from(""), - // typ: ConstantType::Int32, - // nullable: false, - // }]); - // let attribute_refs = vec![AttributeRef::base_table_attribute_ref( - // String::from(TABLE1_NAME), - // 0, - // )]; - // assert_approx_eq::assert_approx_eq!( - // cost_model.get_filter_selectivity(expr_tree, &schema, &attribute_refs), - // 0.3 - // ); - // assert_approx_eq::assert_approx_eq!( - // cost_model.get_filter_selectivity(expr_tree_rev, &schema, &attribute_refs), - // 0.3 - // ); - // } - - // /// In this case, we should leave the Cast as is. - // /// - // /// Note that the test only checks the selectivity and thus doesn't explicitly test that the - // /// Cast is indeed left as is. However, if get_filter_selectivity() doesn't crash, that's a - // /// pretty good signal that the Cast was left as is. - // #[test] - // fn test_cast_attrref_eq_attrref() { - // let cost_model = create_one_attribute_cost_model(TestPerAttributeStats::new( - // TestMostCommonValues::new(vec![]), - // 0, - // 0.0, - // Some(TestDistribution::empty()), - // )); - // let expr_tree = bin_op( - // BinOpType::Eq, - // cast(attr_ref(0), DataType::Int64), - // attr_ref(1), - // ); - // let expr_tree_rev = bin_op( - // BinOpType::Eq, - // attr_ref(1), - // cast(attr_ref(0), DataType::Int64), - // ); - // let schema = Schema::new(vec![ - // Field { - // name: String::from(""), - // typ: ConstantType::Int32, - // nullable: false, - // }, - // Field { - // name: String::from(""), - // typ: ConstantType::Int64, - // nullable: false, - // }, - // ]); - // let attribute_refs = vec![AttributeRef::base_table_attribute_ref( - // String::from(TABLE1_NAME), - // 0, - // )]; - // assert_approx_eq::assert_approx_eq!( - // cost_model.get_filter_selectivity(expr_tree, &schema, &attribute_refs), - // DEFAULT_EQ_SEL - // ); - // assert_approx_eq::assert_approx_eq!( - // cost_model.get_filter_selectivity(expr_tree_rev, &schema, 
&attribute_refs), - // DEFAULT_EQ_SEL - // ); - // } + #[tokio::test] + async fn test_attr_ref_eq_constint_in_mcv() { + let mut mcvs_counts = HashMap::new(); + mcvs_counts.insert(vec![Some(Value::Int32(1))], 3); + let mcvs_total_count = 10; + let per_attribute_stats = TestPerAttributeStats::new( + MostCommonValues::Counter(Counter::new_from_existing(mcvs_counts, mcvs_total_count)), + 0, + 0.0, + None, + ); + let table_id = TableId(0); + let cost_model = create_cost_model_mock_storage( + vec![table_id], + vec![per_attribute_stats], + vec![None], + HashMap::new(), + ); + + let expr_tree = bin_op(BinOpType::Eq, attr_ref(table_id, 0), cnst(Value::Int32(1))); + let expr_tree_rev = bin_op(BinOpType::Eq, cnst(Value::Int32(1)), attr_ref(table_id, 0)); + assert_approx_eq::assert_approx_eq!( + cost_model.get_filter_selectivity(expr_tree).await.unwrap(), + 0.3 + ); + assert_approx_eq::assert_approx_eq!( + cost_model + .get_filter_selectivity(expr_tree_rev) + .await + .unwrap(), + 0.3 + ); + } + + #[tokio::test] + async fn test_attr_ref_eq_constint_not_in_mcv() { + let mut mcvs_counts = HashMap::new(); + mcvs_counts.insert(vec![Some(Value::Int32(1))], 20); + mcvs_counts.insert(vec![Some(Value::Int32(3))], 44); + let mcvs_total_count = 100; + let per_attribute_stats = TestPerAttributeStats::new( + MostCommonValues::Counter(Counter::new_from_existing(mcvs_counts, mcvs_total_count)), + 5, + 0.0, + None, + ); + let table_id = TableId(0); + let cost_model = create_cost_model_mock_storage( + vec![table_id], + vec![per_attribute_stats], + vec![None], + HashMap::new(), + ); + + let expr_tree = bin_op(BinOpType::Eq, attr_ref(table_id, 0), cnst(Value::Int32(2))); + let expr_tree_rev = bin_op(BinOpType::Eq, cnst(Value::Int32(2)), attr_ref(table_id, 0)); + assert_approx_eq::assert_approx_eq!( + cost_model.get_filter_selectivity(expr_tree).await.unwrap(), + 0.12 + ); + assert_approx_eq::assert_approx_eq!( + cost_model + .get_filter_selectivity(expr_tree_rev) + .await + .unwrap(), + 0.12 + 
); + } + + /// I only have one test for NEQ since I'll assume that it uses the same underlying logic as EQ + #[tokio::test] + async fn test_attr_ref_neq_constint_in_mcv() { + let mut mcvs_counts = HashMap::new(); + mcvs_counts.insert(vec![Some(Value::Int32(1))], 3); + let mcvs_total_count = 10; + let per_attribute_stats = TestPerAttributeStats::new( + MostCommonValues::Counter(Counter::new_from_existing(mcvs_counts, mcvs_total_count)), + 0, + 0.0, + None, + ); + let table_id = TableId(0); + let cost_model = create_cost_model_mock_storage( + vec![table_id], + vec![per_attribute_stats], + vec![None], + HashMap::new(), + ); + + let expr_tree = bin_op(BinOpType::Neq, attr_ref(table_id, 0), cnst(Value::Int32(1))); + let expr_tree_rev = bin_op(BinOpType::Neq, cnst(Value::Int32(1)), attr_ref(table_id, 0)); + assert_approx_eq::assert_approx_eq!( + cost_model.get_filter_selectivity(expr_tree).await.unwrap(), + 1.0 - 0.3 + ); + assert_approx_eq::assert_approx_eq!( + cost_model + .get_filter_selectivity(expr_tree_rev) + .await + .unwrap(), + 1.0 - 0.3 + ); + } + + #[tokio::test] + async fn test_attr_ref_leq_constint_no_mcvs_in_range() { + let mut mcvs_counts = HashMap::new(); + let mcvs_total_count = 10; + let per_attribute_stats = TestPerAttributeStats::new( + MostCommonValues::Counter(Counter::new_from_existing(mcvs_counts, mcvs_total_count)), + 10, + 0.0, + Some(Distribution::SimpleDistribution(SimpleMap::new(vec![( + Value::Int32(15), + 0.7, + )]))), + ); + let table_id = TableId(0); + let cost_model = create_cost_model_mock_storage( + vec![table_id], + vec![per_attribute_stats], + vec![None], + HashMap::new(), + ); + + let expr_tree = bin_op( + BinOpType::Leq, + attr_ref(table_id, 0), + cnst(Value::Int32(15)), + ); + let expr_tree_rev = bin_op(BinOpType::Gt, cnst(Value::Int32(15)), attr_ref(table_id, 0)); + assert_approx_eq::assert_approx_eq!( + cost_model.get_filter_selectivity(expr_tree).await.unwrap(), + 0.7 + ); + assert_approx_eq::assert_approx_eq!( + cost_model + 
.get_filter_selectivity(expr_tree_rev) + .await + .unwrap(), + 0.7 + ); + } + + #[tokio::test] + async fn test_attr_ref_leq_constint_with_mcvs_in_range_not_at_border() { + let mut mcvs_counts = HashMap::new(); + mcvs_counts.insert(vec![Some(Value::Int32(6))], 5); + mcvs_counts.insert(vec![Some(Value::Int32(10))], 10); + mcvs_counts.insert(vec![Some(Value::Int32(17))], 8); + mcvs_counts.insert(vec![Some(Value::Int32(25))], 7); + let mcvs_total_count = 100; + let per_attribute_stats = TestPerAttributeStats::new( + MostCommonValues::Counter(Counter::new_from_existing(mcvs_counts, mcvs_total_count)), + 10, + 0.0, + Some(Distribution::SimpleDistribution(SimpleMap::new(vec![( + Value::Int32(15), + 0.7, + )]))), + ); + let table_id = TableId(0); + let cost_model = create_cost_model_mock_storage( + vec![table_id], + vec![per_attribute_stats], + vec![None], + HashMap::new(), + ); + + let expr_tree = bin_op( + BinOpType::Leq, + attr_ref(table_id, 0), + cnst(Value::Int32(15)), + ); + let expr_tree_rev = bin_op(BinOpType::Gt, cnst(Value::Int32(15)), attr_ref(table_id, 0)); + assert_approx_eq::assert_approx_eq!( + cost_model.get_filter_selectivity(expr_tree).await.unwrap(), + 0.85 + ); + assert_approx_eq::assert_approx_eq!( + cost_model + .get_filter_selectivity(expr_tree_rev) + .await + .unwrap(), + 0.85 + ); + } + + #[tokio::test] + async fn test_attr_ref_leq_constint_with_mcv_at_border() { + let mut mcvs_counts = HashMap::new(); + mcvs_counts.insert(vec![Some(Value::Int32(6))], 5); + mcvs_counts.insert(vec![Some(Value::Int32(10))], 10); + mcvs_counts.insert(vec![Some(Value::Int32(15))], 8); + mcvs_counts.insert(vec![Some(Value::Int32(25))], 7); + let mcvs_total_count = 100; + let per_attribute_stats = TestPerAttributeStats::new( + MostCommonValues::Counter(Counter::new_from_existing(mcvs_counts, mcvs_total_count)), + 10, + 0.0, + Some(Distribution::SimpleDistribution(SimpleMap::new(vec![( + Value::Int32(15), + 0.7, + )]))), + ); + let table_id = TableId(0); + let cost_model 
= create_cost_model_mock_storage( + vec![table_id], + vec![per_attribute_stats], + vec![None], + HashMap::new(), + ); + + let expr_tree = bin_op( + BinOpType::Leq, + attr_ref(table_id, 0), + cnst(Value::Int32(15)), + ); + let expr_tree_rev = bin_op(BinOpType::Gt, cnst(Value::Int32(15)), attr_ref(table_id, 0)); + assert_approx_eq::assert_approx_eq!( + cost_model.get_filter_selectivity(expr_tree).await.unwrap(), + 0.93 + ); + assert_approx_eq::assert_approx_eq!( + cost_model + .get_filter_selectivity(expr_tree_rev) + .await + .unwrap(), + 0.93 + ); + } + + #[tokio::test] + async fn test_attr_ref_lt_constint_no_mcvs_in_range() { + let mut mcvs_counts = HashMap::new(); + let mcvs_total_count = 10; + let per_attribute_stats = TestPerAttributeStats::new( + MostCommonValues::Counter(Counter::new_from_existing(mcvs_counts, mcvs_total_count)), + 10, + 0.0, + Some(Distribution::SimpleDistribution(SimpleMap::new(vec![( + Value::Int32(15), + 0.7, + )]))), + ); + let table_id = TableId(0); + let cost_model = create_cost_model_mock_storage( + vec![table_id], + vec![per_attribute_stats], + vec![None], + HashMap::new(), + ); + + let expr_tree = bin_op(BinOpType::Lt, attr_ref(table_id, 0), cnst(Value::Int32(15))); + let expr_tree_rev = bin_op( + BinOpType::Geq, + cnst(Value::Int32(15)), + attr_ref(table_id, 0), + ); + assert_approx_eq::assert_approx_eq!( + cost_model.get_filter_selectivity(expr_tree).await.unwrap(), + 0.6 + ); + assert_approx_eq::assert_approx_eq!( + cost_model + .get_filter_selectivity(expr_tree_rev) + .await + .unwrap(), + 0.6 + ); + } + + #[tokio::test] + async fn test_attr_ef_lt_constint_with_mcvs_in_range_not_at_border() { + let mut mcvs_counts = HashMap::new(); + mcvs_counts.insert(vec![Some(Value::Int32(6))], 5); + mcvs_counts.insert(vec![Some(Value::Int32(10))], 10); + mcvs_counts.insert(vec![Some(Value::Int32(17))], 8); + mcvs_counts.insert(vec![Some(Value::Int32(25))], 7); + let mcvs_total_count = 100; + let per_attribute_stats = 
TestPerAttributeStats::new( + MostCommonValues::Counter(Counter::new_from_existing(mcvs_counts, mcvs_total_count)), + 11, /* there are 4 MCVs which together add up to 0.3. With 11 total ndistinct, each + * remaining value has freq 0.1 */ + 0.0, + Some(Distribution::SimpleDistribution(SimpleMap::new(vec![( + Value::Int32(15), + 0.7, + )]))), + ); + let table_id = TableId(0); + let cost_model = create_cost_model_mock_storage( + vec![table_id], + vec![per_attribute_stats], + vec![None], + HashMap::new(), + ); + + let expr_tree = bin_op(BinOpType::Lt, attr_ref(table_id, 0), cnst(Value::Int32(15))); + let expr_tree_rev = bin_op( + BinOpType::Geq, + cnst(Value::Int32(15)), + attr_ref(table_id, 0), + ); + assert_approx_eq::assert_approx_eq!( + cost_model.get_filter_selectivity(expr_tree).await.unwrap(), + 0.75 + ); + assert_approx_eq::assert_approx_eq!( + cost_model + .get_filter_selectivity(expr_tree_rev) + .await + .unwrap(), + 0.75 + ); + } + + #[tokio::test] + async fn test_attr_ref_lt_constint_with_mcv_at_border() { + let mut mcvs_counts = HashMap::new(); + mcvs_counts.insert(vec![Some(Value::Int32(6))], 5); + mcvs_counts.insert(vec![Some(Value::Int32(10))], 10); + mcvs_counts.insert(vec![Some(Value::Int32(15))], 8); + mcvs_counts.insert(vec![Some(Value::Int32(25))], 7); + let mcvs_total_count = 100; + let per_attribute_stats = TestPerAttributeStats::new( + MostCommonValues::Counter(Counter::new_from_existing(mcvs_counts, mcvs_total_count)), + 11, /* there are 4 MCVs which together add up to 0.3. 
With 11 total ndistinct, each + * remaining value has freq 0.1 */ + 0.0, + Some(Distribution::SimpleDistribution(SimpleMap::new(vec![( + Value::Int32(15), + 0.7, + )]))), + ); + let table_id = TableId(0); + let cost_model = create_cost_model_mock_storage( + vec![table_id], + vec![per_attribute_stats], + vec![None], + HashMap::new(), + ); + + let expr_tree = bin_op(BinOpType::Lt, attr_ref(table_id, 0), cnst(Value::Int32(15))); + let expr_tree_rev = bin_op( + BinOpType::Geq, + cnst(Value::Int32(15)), + attr_ref(table_id, 0), + ); + assert_approx_eq::assert_approx_eq!( + cost_model.get_filter_selectivity(expr_tree).await.unwrap(), + 0.85 + ); + assert_approx_eq::assert_approx_eq!( + cost_model + .get_filter_selectivity(expr_tree_rev) + .await + .unwrap(), + 0.85 + ); + } + + /// I have fewer tests for GT since I'll assume that it uses the same underlying logic as LEQ + /// The only interesting thing to test is that if there are nulls, those aren't included in GT + #[tokio::test] + async fn test_attr_ref_gt_constint() { + let mut mcvs_counts = HashMap::new(); + let mcvs_total_count = 100; + let per_attribute_stats = TestPerAttributeStats::new( + MostCommonValues::Counter(Counter::new_from_existing(mcvs_counts, mcvs_total_count)), + 10, + 0.0, + Some(Distribution::SimpleDistribution(SimpleMap::new(vec![( + Value::Int32(15), + 0.7, + )]))), + ); + let table_id = TableId(0); + let cost_model = create_cost_model_mock_storage( + vec![table_id], + vec![per_attribute_stats], + vec![None], + HashMap::new(), + ); + + let expr_tree = bin_op(BinOpType::Gt, attr_ref(table_id, 0), cnst(Value::Int32(15))); + let expr_tree_rev = bin_op( + BinOpType::Leq, + cnst(Value::Int32(15)), + attr_ref(table_id, 0), + ); + assert_approx_eq::assert_approx_eq!( + cost_model.get_filter_selectivity(expr_tree).await.unwrap(), + 1.0 - 0.7 + ); + assert_approx_eq::assert_approx_eq!( + cost_model + .get_filter_selectivity(expr_tree_rev) + .await + .unwrap(), + 1.0 - 0.7 + ); + } + + #[tokio::test] + 
async fn test_attr_ref_geq_constint() { + let mut mcvs_counts = HashMap::new(); + let mcvs_total_count = 100; + let per_attribute_stats = TestPerAttributeStats::new( + MostCommonValues::Counter(Counter::new_from_existing(mcvs_counts, mcvs_total_count)), + 10, + 0.0, + Some(Distribution::SimpleDistribution(SimpleMap::new(vec![( + Value::Int32(15), + 0.7, + )]))), + ); + let table_id = TableId(0); + let cost_model = create_cost_model_mock_storage( + vec![table_id], + vec![per_attribute_stats], + vec![None], + HashMap::new(), + ); + + let expr_tree = bin_op( + BinOpType::Geq, + attr_ref(table_id, 0), + cnst(Value::Int32(15)), + ); + let expr_tree_rev = bin_op(BinOpType::Lt, cnst(Value::Int32(15)), attr_ref(table_id, 0)); + + assert_approx_eq::assert_approx_eq!( + cost_model.get_filter_selectivity(expr_tree).await.unwrap(), + 1.0 - 0.6 + ); + assert_approx_eq::assert_approx_eq!( + cost_model + .get_filter_selectivity(expr_tree_rev) + .await + .unwrap(), + 1.0 - 0.6 + ); + } + + #[tokio::test] + async fn test_and() { + let mut mcvs_counts = HashMap::new(); + mcvs_counts.insert(vec![Some(Value::Int32(1))], 3); + mcvs_counts.insert(vec![Some(Value::Int32(5))], 5); + mcvs_counts.insert(vec![Some(Value::Int32(8))], 2); + let mcvs_total_count = 10; + let per_attribute_stats = TestPerAttributeStats::new( + MostCommonValues::Counter(Counter::new_from_existing(mcvs_counts, mcvs_total_count)), + 0, + 0.0, + None, + ); + let table_id = TableId(0); + let cost_model = create_cost_model_mock_storage( + vec![table_id], + vec![per_attribute_stats], + vec![None], + HashMap::new(), + ); + + let eq1 = bin_op(BinOpType::Eq, attr_ref(table_id, 0), cnst(Value::Int32(1))); + let eq5 = bin_op(BinOpType::Eq, attr_ref(table_id, 0), cnst(Value::Int32(5))); + let eq8 = bin_op(BinOpType::Eq, attr_ref(table_id, 0), cnst(Value::Int32(8))); + let expr_tree = log_op(LogOpType::And, vec![eq1.clone(), eq5.clone(), eq8.clone()]); + let expr_tree_shift1 = log_op(LogOpType::And, vec![eq5.clone(), 
eq8.clone(), eq1.clone()]); + let expr_tree_shift2 = log_op(LogOpType::And, vec![eq8.clone(), eq1.clone(), eq5.clone()]); + + assert_approx_eq::assert_approx_eq!( + cost_model.get_filter_selectivity(expr_tree).await.unwrap(), + 0.03 + ); + assert_approx_eq::assert_approx_eq!( + cost_model + .get_filter_selectivity(expr_tree_shift1) + .await + .unwrap(), + 0.03 + ); + assert_approx_eq::assert_approx_eq!( + cost_model + .get_filter_selectivity(expr_tree_shift2) + .await + .unwrap(), + 0.03 + ); + } + + #[tokio::test] + async fn test_or() { + let mut mcvs_counts = HashMap::new(); + mcvs_counts.insert(vec![Some(Value::Int32(1))], 3); + mcvs_counts.insert(vec![Some(Value::Int32(5))], 5); + mcvs_counts.insert(vec![Some(Value::Int32(8))], 2); + let mcvs_total_count = 10; + let per_attribute_stats = TestPerAttributeStats::new( + MostCommonValues::Counter(Counter::new_from_existing(mcvs_counts, mcvs_total_count)), + 0, + 0.0, + None, + ); + let table_id = TableId(0); + let cost_model = create_cost_model_mock_storage( + vec![table_id], + vec![per_attribute_stats], + vec![None], + HashMap::new(), + ); + + let eq1 = bin_op(BinOpType::Eq, attr_ref(table_id, 0), cnst(Value::Int32(1))); + let eq5 = bin_op(BinOpType::Eq, attr_ref(table_id, 0), cnst(Value::Int32(5))); + let eq8 = bin_op(BinOpType::Eq, attr_ref(table_id, 0), cnst(Value::Int32(8))); + let expr_tree = log_op(LogOpType::Or, vec![eq1.clone(), eq5.clone(), eq8.clone()]); + let expr_tree_shift1 = log_op(LogOpType::Or, vec![eq5.clone(), eq8.clone(), eq1.clone()]); + let expr_tree_shift2 = log_op(LogOpType::Or, vec![eq8.clone(), eq1.clone(), eq5.clone()]); + + assert_approx_eq::assert_approx_eq!( + cost_model.get_filter_selectivity(expr_tree).await.unwrap(), + 0.72 + ); + assert_approx_eq::assert_approx_eq!( + cost_model + .get_filter_selectivity(expr_tree_shift1) + .await + .unwrap(), + 0.72 + ); + assert_approx_eq::assert_approx_eq!( + cost_model + .get_filter_selectivity(expr_tree_shift2) + .await + .unwrap(), + 0.72 + 
); + } + + #[tokio::test] + async fn test_not() { + let mut mcvs_counts = HashMap::new(); + mcvs_counts.insert(vec![Some(Value::Int32(1))], 3); + let mcvs_total_count = 10; + let per_attribute_stats = TestPerAttributeStats::new( + MostCommonValues::Counter(Counter::new_from_existing(mcvs_counts, mcvs_total_count)), + 0, + 0.0, + None, + ); + let table_id = TableId(0); + let cost_model = create_cost_model_mock_storage( + vec![table_id], + vec![per_attribute_stats], + vec![None], + HashMap::new(), + ); + + let expr_tree = un_op( + UnOpType::Not, + bin_op(BinOpType::Eq, attr_ref(table_id, 0), cnst(Value::Int32(1))), + ); + + assert_approx_eq::assert_approx_eq!( + cost_model.get_filter_selectivity(expr_tree).await.unwrap(), + 0.7 + ); + } + + // I didn't test any non-unique cases with filter. The non-unique tests without filter should + // cover that + + #[tokio::test] + async fn test_attr_ref_eq_cast_value() { + let mut mcvs_counts = HashMap::new(); + mcvs_counts.insert(vec![Some(Value::Int32(1))], 3); + let mcvs_total_count = 10; + let per_attribute_stats = TestPerAttributeStats::new( + MostCommonValues::Counter(Counter::new_from_existing(mcvs_counts, mcvs_total_count)), + 0, + 0.0, + None, + ); + let table_id = TableId(0); + let cost_model = create_cost_model_mock_storage( + vec![table_id], + vec![per_attribute_stats], + vec![None], + HashMap::new(), + ); + + let expr_tree = bin_op( + BinOpType::Eq, + attr_ref(table_id, 0), + cast(cnst(Value::Int64(1)), DataType::Int32), + ); + let expr_tree_rev = bin_op( + BinOpType::Eq, + cast(cnst(Value::Int64(1)), DataType::Int32), + attr_ref(table_id, 0), + ); + + assert_approx_eq::assert_approx_eq!( + cost_model.get_filter_selectivity(expr_tree).await.unwrap(), + 0.3 + ); + assert_approx_eq::assert_approx_eq!( + cost_model + .get_filter_selectivity(expr_tree_rev) + .await + .unwrap(), + 0.3 + ); + } + + #[tokio::test] + async fn test_cast_attr_ref_eq_value() { + let mut mcvs_counts = HashMap::new(); + 
mcvs_counts.insert(vec![Some(Value::Int32(1))], 3); + let mcvs_total_count = 10; + let per_attribute_stats = TestPerAttributeStats::new( + MostCommonValues::Counter(Counter::new_from_existing(mcvs_counts, mcvs_total_count)), + 0, + 0.1, + None, + ); + let table_id = TableId(0); + let attr_infos = HashMap::from([( + table_id, + HashMap::from([( + 0, + Attribute { + name: String::from("attr1"), + typ: ConstantType::Int32, + nullable: false, + }, + )]), + )]); + let cost_model = create_cost_model_mock_storage( + vec![table_id], + vec![per_attribute_stats], + vec![None], + attr_infos, + ); + + let expr_tree = bin_op( + BinOpType::Eq, + cast(attr_ref(table_id, 0), DataType::Int64), + cnst(Value::Int64(1)), + ); + let expr_tree_rev = bin_op( + BinOpType::Eq, + cnst(Value::Int64(1)), + cast(attr_ref(table_id, 0), DataType::Int64), + ); + + assert_approx_eq::assert_approx_eq!( + cost_model.get_filter_selectivity(expr_tree).await.unwrap(), + 0.3 + ); + assert_approx_eq::assert_approx_eq!( + cost_model + .get_filter_selectivity(expr_tree_rev) + .await + .unwrap(), + 0.3 + ); + } + + /// In this case, we should leave the Cast as is. + /// + /// Note that the test only checks the selectivity and thus doesn't explicitly test that the + /// Cast is indeed left as is. However, if get_filter_selectivity() doesn't crash, that's a + /// pretty good signal that the Cast was left as is. 
+ #[tokio::test] + async fn test_cast_attr_ref_eq_attr_ref() { + let mut mcvs_counts = HashMap::new(); + let mcvs_total_count = 10; + let per_attribute_stats = TestPerAttributeStats::new( + MostCommonValues::Counter(Counter::new_from_existing(mcvs_counts, mcvs_total_count)), + 0, + 0.0, + None, + ); + let table_id = TableId(0); + let attr_infos = HashMap::from([( + table_id, + HashMap::from([ + ( + 0, + Attribute { + name: String::from("attr1"), + typ: ConstantType::Int32, + nullable: false, + }, + ), + ( + 1, + Attribute { + name: String::from("attr2"), + typ: ConstantType::Int64, + nullable: false, + }, + ), + ]), + )]); + let cost_model = create_cost_model_mock_storage( + vec![table_id], + vec![per_attribute_stats], + vec![None], + attr_infos, + ); + + let expr_tree = bin_op( + BinOpType::Eq, + cast(attr_ref(table_id, 0), DataType::Int64), + attr_ref(table_id, 1), + ); + let expr_tree_rev = bin_op( + BinOpType::Eq, + attr_ref(table_id, 1), + cast(attr_ref(table_id, 0), DataType::Int64), + ); + + assert_approx_eq::assert_approx_eq!( + cost_model.get_filter_selectivity(expr_tree).await.unwrap(), + DEFAULT_EQ_SEL + ); + assert_approx_eq::assert_approx_eq!( + cost_model + .get_filter_selectivity(expr_tree_rev) + .await + .unwrap(), + DEFAULT_EQ_SEL + ); + } } diff --git a/optd-cost-model/src/cost_model.rs b/optd-cost-model/src/cost_model.rs index 03bcd60..ebb6391 100644 --- a/optd-cost-model/src/cost_model.rs +++ b/optd-cost-model/src/cost_model.rs @@ -136,10 +136,9 @@ pub mod tests { values::Value, }, stats::{ - counter::Counter, tdigest::TDigest, AttributeCombValueStats, Distribution, - MostCommonValues, + utilities::counter::Counter, AttributeCombValueStats, Distribution, MostCommonValues, }, - storage::mock::{CostModelStorageMockManagerImpl, TableStats}, + storage::mock::{BaseTableAttrInfo, CostModelStorageMockManagerImpl, TableStats}, }; use super::*; @@ -152,6 +151,7 @@ pub mod tests { table_id: Vec, per_attribute_stats: Vec, row_counts: Vec>, + 
per_table_attr_infos: BaseTableAttrInfo, ) -> TestOptCostModelMock { let storage_manager = CostModelStorageMockManagerImpl::new( table_id @@ -168,6 +168,7 @@ pub mod tests { ) }) .collect(), + per_table_attr_infos, ); CostModelImpl::new(storage_manager, CatalogSource::Mock) } @@ -222,9 +223,6 @@ pub mod tests { ) } - /// The reason this isn't an associated function of PerAttributeStats is because that would require - /// adding an empty type to the enum definitions of MostCommonValues and Distribution, - /// which I wanted to avoid pub(crate) fn get_empty_per_attr_stats() -> TestPerAttributeStats { TestPerAttributeStats::new(MostCommonValues::Counter(Counter::default()), 0, 0.0, None) } diff --git a/optd-cost-model/src/stats/mod.rs b/optd-cost-model/src/stats/mod.rs index 0b1396a..5440ea1 100644 --- a/optd-cost-model/src/stats/mod.rs +++ b/optd-cost-model/src/stats/mod.rs @@ -1,12 +1,15 @@ #![allow(unused)] mod arith_encoder; -pub mod counter; -pub mod tdigest; +pub mod utilities; use crate::common::values::Value; -use counter::Counter; use serde::{Deserialize, Serialize}; +use utilities::counter::Counter; +use utilities::{ + simple_map::{self, SimpleMap}, + tdigest::TDigest, +}; // Default n-distinct estimate for derived columns or columns lacking statistics pub const DEFAULT_NUM_DISTINCT: u64 = 200; @@ -27,7 +30,8 @@ pub const FIXED_CHAR_SEL_FACTOR: f64 = 0.2; pub type AttributeCombValue = Vec>; -#[derive(Serialize, Deserialize, Debug)] +// TODO: remove the clone, see the comment in the [`AttributeCombValueStats`] +#[derive(Serialize, Deserialize, Debug, Clone)] #[serde(tag = "type")] pub enum MostCommonValues { Counter(Counter), @@ -71,10 +75,12 @@ impl MostCommonValues { } } -#[derive(Serialize, Deserialize, Debug)] +// TODO: remove the clone, see the comment in the [`AttributeCombValueStats`] +#[derive(Serialize, Deserialize, Debug, Clone)] #[serde(tag = "type")] pub enum Distribution { - TDigest(tdigest::TDigest), + TDigest(TDigest), + 
SimpleDistribution(SimpleMap), // Add more types here... } @@ -89,11 +95,21 @@ impl Distribution { tdigest.centroids.len() as f64 * tdigest.cdf(value) / nb_rows as f64 } } + Distribution::SimpleDistribution(simple_distribution) => { + *simple_distribution.m.get(value).unwrap_or(&0.0) + } } } } -#[derive(Serialize, Deserialize, Debug)] +// TODO: Remove the clone. Now I have to add this because +// persistent.rs doesn't have a memory cache, so we have to +// return AttributeCombValueStats rather than &AttributeCombValueStats. +// But this poses a problem for mock.rs when testing, since mock storage +// only has memory hash map, so we need to return a clone of AttributeCombValueStats. +// Later, if memory cache is added, we should change this to return a reference. +// **and** remove the clone. +#[derive(Serialize, Deserialize, Debug, Clone)] pub struct AttributeCombValueStats { pub mcvs: MostCommonValues, // Does NOT contain full nulls. pub distr: Option, // Does NOT contain mcvs; optional. diff --git a/optd-cost-model/src/stats/counter.rs b/optd-cost-model/src/stats/utilities/counter.rs similarity index 97% rename from optd-cost-model/src/stats/counter.rs rename to optd-cost-model/src/stats/utilities/counter.rs index ddffb7a..368700c 100644 --- a/optd-cost-model/src/stats/counter.rs +++ b/optd-cost-model/src/stats/utilities/counter.rs @@ -5,8 +5,9 @@ use serde::de::DeserializeOwned; use serde::{Deserialize, Serialize}; /// The Counter structure to track exact frequencies of fixed elements. +/// TODO: remove the clone, see the comment in the [`AttributeCombValueStats`] #[serde_with::serde_as] -#[derive(Default, Serialize, Deserialize, Debug)] +#[derive(Default, Serialize, Deserialize, Debug, Clone)] pub struct Counter { #[serde_as(as = "HashMap")] counts: HashMap, // The exact counts of an element T. 
diff --git a/optd-cost-model/src/stats/utilities/mod.rs b/optd-cost-model/src/stats/utilities/mod.rs new file mode 100644 index 0000000..0a7903b --- /dev/null +++ b/optd-cost-model/src/stats/utilities/mod.rs @@ -0,0 +1,3 @@ +pub mod counter; +pub mod simple_map; +pub mod tdigest; diff --git a/optd-cost-model/src/stats/utilities/simple_map.rs b/optd-cost-model/src/stats/utilities/simple_map.rs new file mode 100644 index 0000000..5503b2f --- /dev/null +++ b/optd-cost-model/src/stats/utilities/simple_map.rs @@ -0,0 +1,20 @@ +use std::collections::HashMap; + +use serde::{Deserialize, Serialize}; + +use crate::common::values::Value; + +/// TODO: documentation +/// Now it is mainly for testing purposes. +#[derive(Clone, Serialize, Deserialize, Debug)] +pub struct SimpleMap { + pub(crate) m: HashMap, +} + +impl SimpleMap { + pub fn new(v: Vec<(Value, f64)>) -> Self { + Self { + m: v.into_iter().collect(), + } + } +} diff --git a/optd-cost-model/src/stats/tdigest.rs b/optd-cost-model/src/stats/utilities/tdigest.rs similarity index 99% rename from optd-cost-model/src/stats/tdigest.rs rename to optd-cost-model/src/stats/utilities/tdigest.rs index 83dc9b5..96a2269 100644 --- a/optd-cost-model/src/stats/tdigest.rs +++ b/optd-cost-model/src/stats/utilities/tdigest.rs @@ -15,9 +15,7 @@ use std::marker::PhantomData; use itertools::Itertools; use serde::{Deserialize, Serialize}; -use crate::common::values::Value; - -use super::arith_encoder; +use crate::{common::values::Value, stats::arith_encoder}; pub const DEFAULT_COMPRESSION: f64 = 200.0; diff --git a/optd-cost-model/src/storage/mock.rs b/optd-cost-model/src/storage/mock.rs index 3f8a8bf..ed75a59 100644 --- a/optd-cost-model/src/storage/mock.rs +++ b/optd-cost-model/src/storage/mock.rs @@ -30,15 +30,21 @@ impl TableStats { } pub type BaseTableStats = HashMap; +pub type BaseTableAttrInfo = HashMap>; pub struct CostModelStorageMockManagerImpl { pub(crate) per_table_stats_map: BaseTableStats, + pub(crate) 
per_table_attr_infos_map: BaseTableAttrInfo, } impl CostModelStorageMockManagerImpl { - pub fn new(per_table_stats_map: BaseTableStats) -> Self { + pub fn new( + per_table_stats_map: BaseTableStats, + per_table_attr_infos_map: BaseTableAttrInfo, + ) -> Self { Self { per_table_stats_map, + per_table_attr_infos_map, } } } @@ -49,7 +55,14 @@ impl CostModelStorageManager for CostModelStorageMockManagerImpl { table_id: TableId, attr_base_index: i32, ) -> CostModelResult> { - todo!() + let table_attr_infos = self.per_table_attr_infos_map.get(&table_id); + match table_attr_infos { + None => Ok(None), + Some(table_attr_infos) => match table_attr_infos.get(&attr_base_index) { + None => Ok(None), + Some(attr) => Ok(Some(attr.clone())), + }, + } } async fn get_attributes_comb_statistics( @@ -57,6 +70,13 @@ impl CostModelStorageManager for CostModelStorageMockManagerImpl { table_id: TableId, attr_base_indices: &[usize], ) -> CostModelResult> { - todo!() + let table_stats = self.per_table_stats_map.get(&table_id); + match table_stats { + None => Ok(None), + Some(table_stats) => match table_stats.column_comb_stats.get(attr_base_indices) { + None => Ok(None), + Some(stats) => Ok(Some(stats.clone())), + }, + } } } diff --git a/optd-cost-model/src/storage/persistent.rs b/optd-cost-model/src/storage/persistent.rs index 53f37da..72ae430 100644 --- a/optd-cost-model/src/storage/persistent.rs +++ b/optd-cost-model/src/storage/persistent.rs @@ -5,7 +5,7 @@ use optd_persistent::{cost_model::interface::StatType, CostModelStorageLayer}; use crate::{ common::{predicates::constant_pred::ConstantType, types::TableId}, - stats::{counter::Counter, AttributeCombValueStats, Distribution, MostCommonValues}, + stats::{utilities::counter::Counter, AttributeCombValueStats, Distribution, MostCommonValues}, CostModelResult, }; From a8f92c36e5cd95ec96ba1fd825ee5ce215862323 Mon Sep 17 00:00:00 2001 From: Lan Lou Date: Sat, 16 Nov 2024 21:37:16 -0500 Subject: [PATCH 28/51] Finish all tests for filter --- 
optd-cost-model/src/cost/filter/in_list.rs | 85 +++++++++++++++++++ optd-cost-model/src/cost/filter/like.rs | 94 ++++++++++++++++++++++ 2 files changed, 179 insertions(+) diff --git a/optd-cost-model/src/cost/filter/in_list.rs b/optd-cost-model/src/cost/filter/in_list.rs index d27b6f1..16080de 100644 --- a/optd-cost-model/src/cost/filter/in_list.rs +++ b/optd-cost-model/src/cost/filter/in_list.rs @@ -65,3 +65,88 @@ impl CostModelImpl { } } } + +#[cfg(test)] +mod tests { + use std::collections::HashMap; + + use crate::{ + common::{types::TableId, values::Value}, + cost_model::tests::*, + stats::{utilities::counter::Counter, MostCommonValues}, + }; + + #[tokio::test] + async fn test_in_list() { + let mut mcvs_counts = HashMap::new(); + mcvs_counts.insert(vec![Some(Value::Int32(1))], 8); + mcvs_counts.insert(vec![Some(Value::Int32(2))], 2); + let mcvs_total_count = 10; + let per_attribute_stats = TestPerAttributeStats::new( + MostCommonValues::Counter(Counter::new_from_existing(mcvs_counts, mcvs_total_count)), + 2, + 0.0, + None, + ); + let table_id = TableId(0); + let cost_model = create_cost_model_mock_storage( + vec![table_id], + vec![per_attribute_stats], + vec![None], + HashMap::new(), + ); + + assert_approx_eq::assert_approx_eq!( + cost_model + .get_in_list_selectivity(&in_list(table_id, 0, vec![Value::Int32(1)], false)) + .await + .unwrap(), + 0.8 + ); + assert_approx_eq::assert_approx_eq!( + cost_model + .get_in_list_selectivity(&in_list( + table_id, + 0, + vec![Value::Int32(1), Value::Int32(2)], + false + )) + .await + .unwrap(), + 1.0 + ); + assert_approx_eq::assert_approx_eq!( + cost_model + .get_in_list_selectivity(&in_list(table_id, 0, vec![Value::Int32(3)], false)) + .await + .unwrap(), + 0.0 + ); + assert_approx_eq::assert_approx_eq!( + cost_model + .get_in_list_selectivity(&in_list(table_id, 0, vec![Value::Int32(1)], true)) + .await + .unwrap(), + 0.2 + ); + assert_approx_eq::assert_approx_eq!( + cost_model + .get_in_list_selectivity(&in_list( + 
table_id, + 0, + vec![Value::Int32(1), Value::Int32(2)], + true + )) + .await + .unwrap(), + 0.0 + ); + assert_approx_eq::assert_approx_eq!( + cost_model + .get_in_list_selectivity(&in_list(table_id, 0, vec![Value::Int32(3)], true)) + .await + .unwrap(), + 1.0 + ); + } +} diff --git a/optd-cost-model/src/cost/filter/like.rs b/optd-cost-model/src/cost/filter/like.rs index 04517f2..f49ca18 100644 --- a/optd-cost-model/src/cost/filter/like.rs +++ b/optd-cost-model/src/cost/filter/like.rs @@ -99,3 +99,97 @@ impl CostModelImpl { } } } + +#[cfg(test)] +mod tests { + use std::collections::HashMap; + + use crate::{ + common::{types::TableId, values::Value}, + cost_model::tests::*, + stats::{ + utilities::counter::Counter, MostCommonValues, FIXED_CHAR_SEL_FACTOR, + FULL_WILDCARD_SEL_FACTOR, + }, + }; + + #[tokio::test] + async fn test_like_no_nulls() { + let mut mcvs_counts = HashMap::new(); + mcvs_counts.insert(vec![Some(Value::String("abcd".into()))], 1); + mcvs_counts.insert(vec![Some(Value::String("abc".into()))], 1); + let mcvs_total_count = 10; + let per_attribute_stats = TestPerAttributeStats::new( + MostCommonValues::Counter(Counter::new_from_existing(mcvs_counts, mcvs_total_count)), + 2, + 0.0, + None, + ); + let table_id = TableId(0); + let cost_model = create_cost_model_mock_storage( + vec![table_id], + vec![per_attribute_stats], + vec![None], + HashMap::new(), + ); + + assert_approx_eq::assert_approx_eq!( + cost_model + .get_like_selectivity(&like(table_id, 0, "%abcd%", false)) + .await + .unwrap(), + 0.1 + FULL_WILDCARD_SEL_FACTOR.powi(2) * FIXED_CHAR_SEL_FACTOR.powi(4) + ); + assert_approx_eq::assert_approx_eq!( + cost_model + .get_like_selectivity(&like(table_id, 0, "%abc%", false)) + .await + .unwrap(), + 0.1 + 0.1 + FULL_WILDCARD_SEL_FACTOR.powi(2) * FIXED_CHAR_SEL_FACTOR.powi(3) + ); + assert_approx_eq::assert_approx_eq!( + cost_model + .get_like_selectivity(&like(table_id, 0, "%abc%", true)) + .await + .unwrap(), + 1.0 - (0.1 + 0.1 + 
FULL_WILDCARD_SEL_FACTOR.powi(2) * FIXED_CHAR_SEL_FACTOR.powi(3)) + ); + } + + #[tokio::test] + async fn test_like_with_nulls() { + let null_frac = 0.5; + let mut mcvs_counts = HashMap::new(); + mcvs_counts.insert(vec![Some(Value::String("abcd".into()))], 1); + let mcvs_total_count = 10; + let per_attribute_stats = TestPerAttributeStats::new( + MostCommonValues::Counter(Counter::new_from_existing(mcvs_counts, mcvs_total_count)), + 2, + null_frac, + None, + ); + let table_id = TableId(0); + let cost_model = create_cost_model_mock_storage( + vec![table_id], + vec![per_attribute_stats], + vec![None], + HashMap::new(), + ); + + assert_approx_eq::assert_approx_eq!( + cost_model + .get_like_selectivity(&like(table_id, 0, "%abcd%", false)) + .await + .unwrap(), + 0.1 + FULL_WILDCARD_SEL_FACTOR.powi(2) * FIXED_CHAR_SEL_FACTOR.powi(4) + ); + assert_approx_eq::assert_approx_eq!( + cost_model + .get_like_selectivity(&like(table_id, 0, "%abcd%", true)) + .await + .unwrap(), + 1.0 - (0.1 + FULL_WILDCARD_SEL_FACTOR.powi(2) * FIXED_CHAR_SEL_FACTOR.powi(4)) + - null_frac + ); + } +} From 2c9240f87bac45781f0276044d2a8078666d7dab Mon Sep 17 00:00:00 2001 From: Lan Lou Date: Sat, 16 Nov 2024 22:11:19 -0500 Subject: [PATCH 29/51] Add important tricky todo --- optd-cost-model/src/cost/filter/comp_op.rs | 2 +- optd-cost-model/src/storage/persistent.rs | 5 +++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/optd-cost-model/src/cost/filter/comp_op.rs b/optd-cost-model/src/cost/filter/comp_op.rs index b8c2d99..4ce2b3c 100644 --- a/optd-cost-model/src/cost/filter/comp_op.rs +++ b/optd-cost-model/src/cost/filter/comp_op.rs @@ -179,8 +179,8 @@ impl CostModelImpl { cast_node = attr_ref_expr.into_pred_node(); // The "invert" cast is to invert the cast so that we're casting the // non_cast_node to the attribute's original type. - // TODO(migration): double check // TODO: Consider attribute info is None. + // **TODO**: What if this attribute is a derived attribute? 
let attribute_info = self .storage_manager .get_attribute_info(table_id, attr_ref_idx as i32) diff --git a/optd-cost-model/src/storage/persistent.rs b/optd-cost-model/src/storage/persistent.rs index 72ae430..b574fe0 100644 --- a/optd-cost-model/src/storage/persistent.rs +++ b/optd-cost-model/src/storage/persistent.rs @@ -30,6 +30,8 @@ impl CostModelStorageManager /// /// TODO: if we have memory cache, /// we should add the reference. (&Attr) + /// TODO(IMPORTANT): what if table is a derived (temporary) table? And what if + /// the attribute is a derived attribute? async fn get_attribute_info( &self, table_id: TableId, @@ -58,6 +60,9 @@ impl CostModelStorageManager /// /// TODO: Shall we pass in an epoch here to make sure that the statistics are from the same /// epoch? + /// + /// TODO(IMPORTANT): what if table is a derived (temporary) table? And what if + /// the attribute is a derived attribute? async fn get_attributes_comb_statistics( &self, table_id: TableId, From d6e18257a3442f209734e5e6ec5245d5a27a4246 Mon Sep 17 00:00:00 2001 From: Lan Lou Date: Sat, 16 Nov 2024 23:19:39 -0500 Subject: [PATCH 30/51] Improve filter tests --- optd-cost-model/src/cost/filter/controller.rs | 143 ++++++++---------- optd-cost-model/src/cost/filter/in_list.rs | 14 +- optd-cost-model/src/cost/filter/like.rs | 13 +- optd-cost-model/src/stats/mod.rs | 12 +- .../src/stats/utilities/simple_map.rs | 9 +- 5 files changed, 93 insertions(+), 98 deletions(-) diff --git a/optd-cost-model/src/cost/filter/controller.rs b/optd-cost-model/src/cost/filter/controller.rs index 39462ae..c10ea1d 100644 --- a/optd-cost-model/src/cost/filter/controller.rs +++ b/optd-cost-model/src/cost/filter/controller.rs @@ -136,11 +136,11 @@ mod tests { #[tokio::test] async fn test_attr_ref_eq_constint_in_mcv() { - let mut mcvs_counts = HashMap::new(); - mcvs_counts.insert(vec![Some(Value::Int32(1))], 3); - let mcvs_total_count = 10; let per_attribute_stats = TestPerAttributeStats::new( - 
MostCommonValues::Counter(Counter::new_from_existing(mcvs_counts, mcvs_total_count)), + MostCommonValues::SimpleFrequency(SimpleMap::new(vec![( + vec![Some(Value::Int32(1))], + 0.3, + )])), 0, 0.0, None, @@ -170,12 +170,11 @@ mod tests { #[tokio::test] async fn test_attr_ref_eq_constint_not_in_mcv() { - let mut mcvs_counts = HashMap::new(); - mcvs_counts.insert(vec![Some(Value::Int32(1))], 20); - mcvs_counts.insert(vec![Some(Value::Int32(3))], 44); - let mcvs_total_count = 100; let per_attribute_stats = TestPerAttributeStats::new( - MostCommonValues::Counter(Counter::new_from_existing(mcvs_counts, mcvs_total_count)), + MostCommonValues::SimpleFrequency(SimpleMap::new(vec![ + (vec![Some(Value::Int32(1))], 0.2), + (vec![Some(Value::Int32(3))], 0.44), + ])), 5, 0.0, None, @@ -206,11 +205,11 @@ mod tests { /// I only have one test for NEQ since I'll assume that it uses the same underlying logic as EQ #[tokio::test] async fn test_attr_ref_neq_constint_in_mcv() { - let mut mcvs_counts = HashMap::new(); - mcvs_counts.insert(vec![Some(Value::Int32(1))], 3); - let mcvs_total_count = 10; let per_attribute_stats = TestPerAttributeStats::new( - MostCommonValues::Counter(Counter::new_from_existing(mcvs_counts, mcvs_total_count)), + MostCommonValues::SimpleFrequency(SimpleMap::new(vec![( + vec![Some(Value::Int32(1))], + 0.3, + )])), 0, 0.0, None, @@ -240,10 +239,8 @@ mod tests { #[tokio::test] async fn test_attr_ref_leq_constint_no_mcvs_in_range() { - let mut mcvs_counts = HashMap::new(); - let mcvs_total_count = 10; let per_attribute_stats = TestPerAttributeStats::new( - MostCommonValues::Counter(Counter::new_from_existing(mcvs_counts, mcvs_total_count)), + MostCommonValues::SimpleFrequency(SimpleMap::new(vec![])), 10, 0.0, Some(Distribution::SimpleDistribution(SimpleMap::new(vec![( @@ -280,14 +277,13 @@ mod tests { #[tokio::test] async fn test_attr_ref_leq_constint_with_mcvs_in_range_not_at_border() { - let mut mcvs_counts = HashMap::new(); - 
mcvs_counts.insert(vec![Some(Value::Int32(6))], 5); - mcvs_counts.insert(vec![Some(Value::Int32(10))], 10); - mcvs_counts.insert(vec![Some(Value::Int32(17))], 8); - mcvs_counts.insert(vec![Some(Value::Int32(25))], 7); - let mcvs_total_count = 100; let per_attribute_stats = TestPerAttributeStats::new( - MostCommonValues::Counter(Counter::new_from_existing(mcvs_counts, mcvs_total_count)), + MostCommonValues::SimpleFrequency(SimpleMap::new(vec![ + (vec![Some(Value::Int32(6))], 0.05), + (vec![Some(Value::Int32(10))], 0.1), + (vec![Some(Value::Int32(17))], 0.08), + (vec![Some(Value::Int32(25))], 0.07), + ])), 10, 0.0, Some(Distribution::SimpleDistribution(SimpleMap::new(vec![( @@ -324,14 +320,13 @@ mod tests { #[tokio::test] async fn test_attr_ref_leq_constint_with_mcv_at_border() { - let mut mcvs_counts = HashMap::new(); - mcvs_counts.insert(vec![Some(Value::Int32(6))], 5); - mcvs_counts.insert(vec![Some(Value::Int32(10))], 10); - mcvs_counts.insert(vec![Some(Value::Int32(15))], 8); - mcvs_counts.insert(vec![Some(Value::Int32(25))], 7); - let mcvs_total_count = 100; let per_attribute_stats = TestPerAttributeStats::new( - MostCommonValues::Counter(Counter::new_from_existing(mcvs_counts, mcvs_total_count)), + MostCommonValues::SimpleFrequency(SimpleMap::new(vec![ + (vec![Some(Value::Int32(6))], 0.05), + (vec![Some(Value::Int32(10))], 0.1), + (vec![Some(Value::Int32(15))], 0.08), + (vec![Some(Value::Int32(25))], 0.07), + ])), 10, 0.0, Some(Distribution::SimpleDistribution(SimpleMap::new(vec![( @@ -368,10 +363,8 @@ mod tests { #[tokio::test] async fn test_attr_ref_lt_constint_no_mcvs_in_range() { - let mut mcvs_counts = HashMap::new(); - let mcvs_total_count = 10; let per_attribute_stats = TestPerAttributeStats::new( - MostCommonValues::Counter(Counter::new_from_existing(mcvs_counts, mcvs_total_count)), + MostCommonValues::SimpleFrequency(SimpleMap::new(vec![])), 10, 0.0, Some(Distribution::SimpleDistribution(SimpleMap::new(vec![( @@ -408,14 +401,13 @@ mod tests { 
#[tokio::test] async fn test_attr_ef_lt_constint_with_mcvs_in_range_not_at_border() { - let mut mcvs_counts = HashMap::new(); - mcvs_counts.insert(vec![Some(Value::Int32(6))], 5); - mcvs_counts.insert(vec![Some(Value::Int32(10))], 10); - mcvs_counts.insert(vec![Some(Value::Int32(17))], 8); - mcvs_counts.insert(vec![Some(Value::Int32(25))], 7); - let mcvs_total_count = 100; let per_attribute_stats = TestPerAttributeStats::new( - MostCommonValues::Counter(Counter::new_from_existing(mcvs_counts, mcvs_total_count)), + MostCommonValues::SimpleFrequency(SimpleMap::new(vec![ + (vec![Some(Value::Int32(6))], 0.05), + (vec![Some(Value::Int32(10))], 0.1), + (vec![Some(Value::Int32(17))], 0.08), + (vec![Some(Value::Int32(25))], 0.07), + ])), 11, /* there are 4 MCVs which together add up to 0.3. With 11 total ndistinct, each * remaining value has freq 0.1 */ 0.0, @@ -453,14 +445,13 @@ mod tests { #[tokio::test] async fn test_attr_ref_lt_constint_with_mcv_at_border() { - let mut mcvs_counts = HashMap::new(); - mcvs_counts.insert(vec![Some(Value::Int32(6))], 5); - mcvs_counts.insert(vec![Some(Value::Int32(10))], 10); - mcvs_counts.insert(vec![Some(Value::Int32(15))], 8); - mcvs_counts.insert(vec![Some(Value::Int32(25))], 7); - let mcvs_total_count = 100; let per_attribute_stats = TestPerAttributeStats::new( - MostCommonValues::Counter(Counter::new_from_existing(mcvs_counts, mcvs_total_count)), + MostCommonValues::SimpleFrequency(SimpleMap::new(vec![ + (vec![Some(Value::Int32(6))], 0.05), + (vec![Some(Value::Int32(10))], 0.1), + (vec![Some(Value::Int32(15))], 0.08), + (vec![Some(Value::Int32(25))], 0.07), + ])), 11, /* there are 4 MCVs which together add up to 0.3. 
With 11 total ndistinct, each * remaining value has freq 0.1 */ 0.0, @@ -500,10 +491,8 @@ mod tests { /// The only interesting thing to test is that if there are nulls, those aren't included in GT #[tokio::test] async fn test_attr_ref_gt_constint() { - let mut mcvs_counts = HashMap::new(); - let mcvs_total_count = 100; let per_attribute_stats = TestPerAttributeStats::new( - MostCommonValues::Counter(Counter::new_from_existing(mcvs_counts, mcvs_total_count)), + MostCommonValues::SimpleFrequency(SimpleMap::new(vec![])), 10, 0.0, Some(Distribution::SimpleDistribution(SimpleMap::new(vec![( @@ -540,10 +529,8 @@ mod tests { #[tokio::test] async fn test_attr_ref_geq_constint() { - let mut mcvs_counts = HashMap::new(); - let mcvs_total_count = 100; let per_attribute_stats = TestPerAttributeStats::new( - MostCommonValues::Counter(Counter::new_from_existing(mcvs_counts, mcvs_total_count)), + MostCommonValues::SimpleFrequency(SimpleMap::new(vec![])), 10, 0.0, Some(Distribution::SimpleDistribution(SimpleMap::new(vec![( @@ -581,13 +568,12 @@ mod tests { #[tokio::test] async fn test_and() { - let mut mcvs_counts = HashMap::new(); - mcvs_counts.insert(vec![Some(Value::Int32(1))], 3); - mcvs_counts.insert(vec![Some(Value::Int32(5))], 5); - mcvs_counts.insert(vec![Some(Value::Int32(8))], 2); - let mcvs_total_count = 10; let per_attribute_stats = TestPerAttributeStats::new( - MostCommonValues::Counter(Counter::new_from_existing(mcvs_counts, mcvs_total_count)), + MostCommonValues::SimpleFrequency(SimpleMap::new(vec![ + (vec![Some(Value::Int32(1))], 0.3), + (vec![Some(Value::Int32(5))], 0.5), + (vec![Some(Value::Int32(8))], 0.2), + ])), 0, 0.0, None, @@ -629,13 +615,12 @@ mod tests { #[tokio::test] async fn test_or() { - let mut mcvs_counts = HashMap::new(); - mcvs_counts.insert(vec![Some(Value::Int32(1))], 3); - mcvs_counts.insert(vec![Some(Value::Int32(5))], 5); - mcvs_counts.insert(vec![Some(Value::Int32(8))], 2); - let mcvs_total_count = 10; let per_attribute_stats = 
TestPerAttributeStats::new( - MostCommonValues::Counter(Counter::new_from_existing(mcvs_counts, mcvs_total_count)), + MostCommonValues::SimpleFrequency(SimpleMap::new(vec![ + (vec![Some(Value::Int32(1))], 0.3), + (vec![Some(Value::Int32(5))], 0.5), + (vec![Some(Value::Int32(8))], 0.2), + ])), 0, 0.0, None, @@ -677,11 +662,11 @@ mod tests { #[tokio::test] async fn test_not() { - let mut mcvs_counts = HashMap::new(); - mcvs_counts.insert(vec![Some(Value::Int32(1))], 3); - let mcvs_total_count = 10; let per_attribute_stats = TestPerAttributeStats::new( - MostCommonValues::Counter(Counter::new_from_existing(mcvs_counts, mcvs_total_count)), + MostCommonValues::SimpleFrequency(SimpleMap::new(vec![( + vec![Some(Value::Int32(1))], + 0.3, + )])), 0, 0.0, None, @@ -710,11 +695,11 @@ mod tests { #[tokio::test] async fn test_attr_ref_eq_cast_value() { - let mut mcvs_counts = HashMap::new(); - mcvs_counts.insert(vec![Some(Value::Int32(1))], 3); - let mcvs_total_count = 10; let per_attribute_stats = TestPerAttributeStats::new( - MostCommonValues::Counter(Counter::new_from_existing(mcvs_counts, mcvs_total_count)), + MostCommonValues::SimpleFrequency(SimpleMap::new(vec![( + vec![Some(Value::Int32(1))], + 0.3, + )])), 0, 0.0, None, @@ -753,11 +738,11 @@ mod tests { #[tokio::test] async fn test_cast_attr_ref_eq_value() { - let mut mcvs_counts = HashMap::new(); - mcvs_counts.insert(vec![Some(Value::Int32(1))], 3); - let mcvs_total_count = 10; let per_attribute_stats = TestPerAttributeStats::new( - MostCommonValues::Counter(Counter::new_from_existing(mcvs_counts, mcvs_total_count)), + MostCommonValues::SimpleFrequency(SimpleMap::new(vec![( + vec![Some(Value::Int32(1))], + 0.3, + )])), 0, 0.1, None, @@ -812,10 +797,8 @@ mod tests { /// pretty good signal that the Cast was left as is. 
#[tokio::test] async fn test_cast_attr_ref_eq_attr_ref() { - let mut mcvs_counts = HashMap::new(); - let mcvs_total_count = 10; let per_attribute_stats = TestPerAttributeStats::new( - MostCommonValues::Counter(Counter::new_from_existing(mcvs_counts, mcvs_total_count)), + MostCommonValues::SimpleFrequency(SimpleMap::new(vec![])), 0, 0.0, None, diff --git a/optd-cost-model/src/cost/filter/in_list.rs b/optd-cost-model/src/cost/filter/in_list.rs index 16080de..2363d4a 100644 --- a/optd-cost-model/src/cost/filter/in_list.rs +++ b/optd-cost-model/src/cost/filter/in_list.rs @@ -73,17 +73,19 @@ mod tests { use crate::{ common::{types::TableId, values::Value}, cost_model::tests::*, - stats::{utilities::counter::Counter, MostCommonValues}, + stats::{ + utilities::{counter::Counter, simple_map::SimpleMap}, + MostCommonValues, + }, }; #[tokio::test] async fn test_in_list() { - let mut mcvs_counts = HashMap::new(); - mcvs_counts.insert(vec![Some(Value::Int32(1))], 8); - mcvs_counts.insert(vec![Some(Value::Int32(2))], 2); - let mcvs_total_count = 10; let per_attribute_stats = TestPerAttributeStats::new( - MostCommonValues::Counter(Counter::new_from_existing(mcvs_counts, mcvs_total_count)), + MostCommonValues::SimpleFrequency(SimpleMap::new(vec![ + (vec![Some(Value::Int32(1))], 0.8), + (vec![Some(Value::Int32(2))], 0.2), + ])), 2, 0.0, None, diff --git a/optd-cost-model/src/cost/filter/like.rs b/optd-cost-model/src/cost/filter/like.rs index f49ca18..03da4d1 100644 --- a/optd-cost-model/src/cost/filter/like.rs +++ b/optd-cost-model/src/cost/filter/like.rs @@ -108,19 +108,18 @@ mod tests { common::{types::TableId, values::Value}, cost_model::tests::*, stats::{ - utilities::counter::Counter, MostCommonValues, FIXED_CHAR_SEL_FACTOR, - FULL_WILDCARD_SEL_FACTOR, + utilities::{counter::Counter, simple_map::SimpleMap}, + MostCommonValues, FIXED_CHAR_SEL_FACTOR, FULL_WILDCARD_SEL_FACTOR, }, }; #[tokio::test] async fn test_like_no_nulls() { - let mut mcvs_counts = HashMap::new(); - 
mcvs_counts.insert(vec![Some(Value::String("abcd".into()))], 1); - mcvs_counts.insert(vec![Some(Value::String("abc".into()))], 1); - let mcvs_total_count = 10; let per_attribute_stats = TestPerAttributeStats::new( - MostCommonValues::Counter(Counter::new_from_existing(mcvs_counts, mcvs_total_count)), + MostCommonValues::SimpleFrequency(SimpleMap::new(vec![ + (vec![Some(Value::String("abcd".into()))], 0.1), + (vec![Some(Value::String("abc".into()))], 0.1), + ])), 2, 0.0, None, diff --git a/optd-cost-model/src/stats/mod.rs b/optd-cost-model/src/stats/mod.rs index 5440ea1..a20cdf2 100644 --- a/optd-cost-model/src/stats/mod.rs +++ b/optd-cost-model/src/stats/mod.rs @@ -35,6 +35,7 @@ pub type AttributeCombValue = Vec>; #[serde(tag = "type")] pub enum MostCommonValues { Counter(Counter), + SimpleFrequency(SimpleMap), // Add more types here... } @@ -47,12 +48,14 @@ impl MostCommonValues { pub fn freq(&self, value: &AttributeCombValue) -> Option { match self { MostCommonValues::Counter(counter) => counter.frequencies().get(value).copied(), + MostCommonValues::SimpleFrequency(simple_map) => simple_map.m.get(value).copied(), } } pub fn total_freq(&self) -> f64 { match self { MostCommonValues::Counter(counter) => counter.frequencies().values().sum(), + MostCommonValues::SimpleFrequency(simple_map) => simple_map.m.values().sum(), } } @@ -64,6 +67,12 @@ impl MostCommonValues { .filter(|(val, _)| pred(val)) .map(|(_, freq)| freq) .sum(), + MostCommonValues::SimpleFrequency(simple_map) => simple_map + .m + .iter() + .filter(|(val, _)| pred(val)) + .map(|(_, freq)| freq) + .sum(), } } @@ -71,6 +80,7 @@ impl MostCommonValues { pub fn cnt(&self) -> usize { match self { MostCommonValues::Counter(counter) => counter.frequencies().len(), + MostCommonValues::SimpleFrequency(simple_map) => simple_map.m.len(), } } } @@ -80,7 +90,7 @@ impl MostCommonValues { #[serde(tag = "type")] pub enum Distribution { TDigest(TDigest), - SimpleDistribution(SimpleMap), + SimpleDistribution(SimpleMap), // 
Add more types here... } diff --git a/optd-cost-model/src/stats/utilities/simple_map.rs b/optd-cost-model/src/stats/utilities/simple_map.rs index 5503b2f..f685fe6 100644 --- a/optd-cost-model/src/stats/utilities/simple_map.rs +++ b/optd-cost-model/src/stats/utilities/simple_map.rs @@ -1,4 +1,5 @@ use std::collections::HashMap; +use std::hash::Hash; use serde::{Deserialize, Serialize}; @@ -7,12 +8,12 @@ use crate::common::values::Value; /// TODO: documentation /// Now it is mainly for testing purposes. #[derive(Clone, Serialize, Deserialize, Debug)] -pub struct SimpleMap { - pub(crate) m: HashMap, +pub struct SimpleMap { + pub(crate) m: HashMap, } -impl SimpleMap { - pub fn new(v: Vec<(Value, f64)>) -> Self { +impl SimpleMap { + pub fn new(v: Vec<(K, f64)>) -> Self { Self { m: v.into_iter().collect(), } From 082f0be54c7111f88a3732548a10b63eb541e2d1 Mon Sep 17 00:00:00 2001 From: Yuanxin Cao Date: Sun, 17 Nov 2024 09:27:51 -0500 Subject: [PATCH 31/51] refine test infra --- .../src/common/predicates/attr_ref_pred.rs | 4 +- optd-cost-model/src/cost/agg.rs | 1 + optd-cost-model/src/cost/filter/attribute.rs | 6 +- optd-cost-model/src/cost/filter/controller.rs | 90 +++++++++---------- optd-cost-model/src/cost/filter/in_list.rs | 4 +- optd-cost-model/src/cost/filter/like.rs | 8 +- optd-cost-model/src/cost_model.rs | 21 +++-- optd-cost-model/src/lib.rs | 2 +- optd-cost-model/src/stats/mod.rs | 2 +- .../src/stats/utilities/simple_map.rs | 2 +- optd-cost-model/src/storage/mock.rs | 8 +- optd-cost-model/src/storage/mod.rs | 2 +- optd-cost-model/src/storage/persistent.rs | 6 +- 13 files changed, 85 insertions(+), 71 deletions(-) diff --git a/optd-cost-model/src/common/predicates/attr_ref_pred.rs b/optd-cost-model/src/common/predicates/attr_ref_pred.rs index cdc7440..e589cdf 100644 --- a/optd-cost-model/src/common/predicates/attr_ref_pred.rs +++ b/optd-cost-model/src/common/predicates/attr_ref_pred.rs @@ -49,8 +49,8 @@ impl AttributeRefPred { /// Gets the attribute index. 
/// Note: The attribute index is the **base** index, which is table specific. - pub fn attr_index(&self) -> usize { - self.0.child(1).data.as_ref().unwrap().as_u64() as usize + pub fn attr_index(&self) -> u64 { + self.0.child(1).data.as_ref().unwrap().as_u64() } /// Checks whether the attribute is a derived attribute. Currently, this will always return diff --git a/optd-cost-model/src/cost/agg.rs b/optd-cost-model/src/cost/agg.rs index f7e0034..858fe1b 100644 --- a/optd-cost-model/src/cost/agg.rs +++ b/optd-cost-model/src/cost/agg.rs @@ -37,6 +37,7 @@ impl CostModelImpl { } else { let table_id = attr_ref.table_id(); let attr_idx = attr_ref.attr_index(); + // TODO: Only query ndistinct instead of all kinds of stats. let stats_option = self.get_attribute_comb_stats(table_id, &[attr_idx]).await?; diff --git a/optd-cost-model/src/cost/filter/attribute.rs b/optd-cost-model/src/cost/filter/attribute.rs index 7eb77ce..e39d7b5 100644 --- a/optd-cost-model/src/cost/filter/attribute.rs +++ b/optd-cost-model/src/cost/filter/attribute.rs @@ -19,7 +19,7 @@ impl CostModelImpl { pub(crate) async fn get_attribute_equality_selectivity( &self, table_id: TableId, - attr_base_index: usize, + attr_base_index: u64, value: &Value, is_eq: bool, ) -> CostModelResult { @@ -93,7 +93,7 @@ impl CostModelImpl { &self, attribute_stats: &AttributeCombValueStats, table_id: TableId, - attr_base_index: usize, + attr_base_index: u64, value: &Value, ) -> CostModelResult { // depending on whether value is in mcvs or not, we use different logic to turn total_lt_cdf @@ -119,7 +119,7 @@ impl CostModelImpl { pub(crate) async fn get_attribute_range_selectivity( &self, table_id: TableId, - attr_base_index: usize, + attr_base_index: u64, start: Bound<&Value>, end: Bound<&Value>, ) -> CostModelResult { diff --git a/optd-cost-model/src/cost/filter/controller.rs b/optd-cost-model/src/cost/filter/controller.rs index c10ea1d..b319a48 100644 --- a/optd-cost-model/src/cost/filter/controller.rs +++ 
b/optd-cost-model/src/cost/filter/controller.rs @@ -114,7 +114,7 @@ mod tests { async fn test_const() { let cost_model = create_cost_model_mock_storage( vec![TableId(0)], - vec![get_empty_per_attr_stats()], + vec![HashMap::from([(0, empty_per_attr_stats())])], vec![None], HashMap::new(), ); @@ -141,14 +141,14 @@ mod tests { vec![Some(Value::Int32(1))], 0.3, )])), + None, 0, 0.0, - None, ); let table_id = TableId(0); let cost_model = create_cost_model_mock_storage( vec![table_id], - vec![per_attribute_stats], + vec![HashMap::from([(0, per_attribute_stats)])], vec![None], HashMap::new(), ); @@ -175,14 +175,14 @@ mod tests { (vec![Some(Value::Int32(1))], 0.2), (vec![Some(Value::Int32(3))], 0.44), ])), + None, 5, 0.0, - None, ); let table_id = TableId(0); let cost_model = create_cost_model_mock_storage( vec![table_id], - vec![per_attribute_stats], + vec![HashMap::from([(0, per_attribute_stats)])], vec![None], HashMap::new(), ); @@ -210,14 +210,14 @@ mod tests { vec![Some(Value::Int32(1))], 0.3, )])), + None, 0, 0.0, - None, ); let table_id = TableId(0); let cost_model = create_cost_model_mock_storage( vec![table_id], - vec![per_attribute_stats], + vec![HashMap::from([(0, per_attribute_stats)])], vec![None], HashMap::new(), ); @@ -241,17 +241,17 @@ mod tests { async fn test_attr_ref_leq_constint_no_mcvs_in_range() { let per_attribute_stats = TestPerAttributeStats::new( MostCommonValues::SimpleFrequency(SimpleMap::new(vec![])), - 10, - 0.0, Some(Distribution::SimpleDistribution(SimpleMap::new(vec![( Value::Int32(15), 0.7, )]))), + 10, + 0.0, ); let table_id = TableId(0); let cost_model = create_cost_model_mock_storage( vec![table_id], - vec![per_attribute_stats], + vec![HashMap::from([(0, per_attribute_stats)])], vec![None], HashMap::new(), ); @@ -284,17 +284,17 @@ mod tests { (vec![Some(Value::Int32(17))], 0.08), (vec![Some(Value::Int32(25))], 0.07), ])), - 10, - 0.0, Some(Distribution::SimpleDistribution(SimpleMap::new(vec![( Value::Int32(15), 0.7, )]))), + 10, + 0.0, 
); let table_id = TableId(0); let cost_model = create_cost_model_mock_storage( vec![table_id], - vec![per_attribute_stats], + vec![HashMap::from([(0, per_attribute_stats)])], vec![None], HashMap::new(), ); @@ -327,17 +327,17 @@ mod tests { (vec![Some(Value::Int32(15))], 0.08), (vec![Some(Value::Int32(25))], 0.07), ])), - 10, - 0.0, Some(Distribution::SimpleDistribution(SimpleMap::new(vec![( Value::Int32(15), 0.7, )]))), + 10, + 0.0, ); let table_id = TableId(0); let cost_model = create_cost_model_mock_storage( vec![table_id], - vec![per_attribute_stats], + vec![HashMap::from([(0, per_attribute_stats)])], vec![None], HashMap::new(), ); @@ -365,17 +365,17 @@ mod tests { async fn test_attr_ref_lt_constint_no_mcvs_in_range() { let per_attribute_stats = TestPerAttributeStats::new( MostCommonValues::SimpleFrequency(SimpleMap::new(vec![])), - 10, - 0.0, Some(Distribution::SimpleDistribution(SimpleMap::new(vec![( Value::Int32(15), 0.7, )]))), + 10, + 0.0, ); let table_id = TableId(0); let cost_model = create_cost_model_mock_storage( vec![table_id], - vec![per_attribute_stats], + vec![HashMap::from([(0, per_attribute_stats)])], vec![None], HashMap::new(), ); @@ -408,18 +408,18 @@ mod tests { (vec![Some(Value::Int32(17))], 0.08), (vec![Some(Value::Int32(25))], 0.07), ])), - 11, /* there are 4 MCVs which together add up to 0.3. With 11 total ndistinct, each - * remaining value has freq 0.1 */ - 0.0, Some(Distribution::SimpleDistribution(SimpleMap::new(vec![( Value::Int32(15), 0.7, )]))), + 11, /* there are 4 MCVs which together add up to 0.3. 
With 11 total ndistinct, each + * remaining value has freq 0.1 */ + 0.0, ); let table_id = TableId(0); let cost_model = create_cost_model_mock_storage( vec![table_id], - vec![per_attribute_stats], + vec![HashMap::from([(0, per_attribute_stats)])], vec![None], HashMap::new(), ); @@ -452,18 +452,18 @@ mod tests { (vec![Some(Value::Int32(15))], 0.08), (vec![Some(Value::Int32(25))], 0.07), ])), - 11, /* there are 4 MCVs which together add up to 0.3. With 11 total ndistinct, each - * remaining value has freq 0.1 */ - 0.0, Some(Distribution::SimpleDistribution(SimpleMap::new(vec![( Value::Int32(15), 0.7, )]))), + 11, /* there are 4 MCVs which together add up to 0.3. With 11 total ndistinct, each + * remaining value has freq 0.1 */ + 0.0, ); let table_id = TableId(0); let cost_model = create_cost_model_mock_storage( vec![table_id], - vec![per_attribute_stats], + vec![HashMap::from([(0, per_attribute_stats)])], vec![None], HashMap::new(), ); @@ -493,17 +493,17 @@ mod tests { async fn test_attr_ref_gt_constint() { let per_attribute_stats = TestPerAttributeStats::new( MostCommonValues::SimpleFrequency(SimpleMap::new(vec![])), - 10, - 0.0, Some(Distribution::SimpleDistribution(SimpleMap::new(vec![( Value::Int32(15), 0.7, )]))), + 10, + 0.0, ); let table_id = TableId(0); let cost_model = create_cost_model_mock_storage( vec![table_id], - vec![per_attribute_stats], + vec![HashMap::from([(0, per_attribute_stats)])], vec![None], HashMap::new(), ); @@ -531,17 +531,17 @@ mod tests { async fn test_attr_ref_geq_constint() { let per_attribute_stats = TestPerAttributeStats::new( MostCommonValues::SimpleFrequency(SimpleMap::new(vec![])), - 10, - 0.0, Some(Distribution::SimpleDistribution(SimpleMap::new(vec![( Value::Int32(15), 0.7, )]))), + 10, + 0.0, ); let table_id = TableId(0); let cost_model = create_cost_model_mock_storage( vec![table_id], - vec![per_attribute_stats], + vec![HashMap::from([(0, per_attribute_stats)])], vec![None], HashMap::new(), ); @@ -574,14 +574,14 @@ mod tests { 
(vec![Some(Value::Int32(5))], 0.5), (vec![Some(Value::Int32(8))], 0.2), ])), + None, 0, 0.0, - None, ); let table_id = TableId(0); let cost_model = create_cost_model_mock_storage( vec![table_id], - vec![per_attribute_stats], + vec![HashMap::from([(0, per_attribute_stats)])], vec![None], HashMap::new(), ); @@ -621,14 +621,14 @@ mod tests { (vec![Some(Value::Int32(5))], 0.5), (vec![Some(Value::Int32(8))], 0.2), ])), + None, 0, 0.0, - None, ); let table_id = TableId(0); let cost_model = create_cost_model_mock_storage( vec![table_id], - vec![per_attribute_stats], + vec![HashMap::from([(0, per_attribute_stats)])], vec![None], HashMap::new(), ); @@ -667,14 +667,14 @@ mod tests { vec![Some(Value::Int32(1))], 0.3, )])), + None, 0, 0.0, - None, ); let table_id = TableId(0); let cost_model = create_cost_model_mock_storage( vec![table_id], - vec![per_attribute_stats], + vec![HashMap::from([(0, per_attribute_stats)])], vec![None], HashMap::new(), ); @@ -700,14 +700,14 @@ mod tests { vec![Some(Value::Int32(1))], 0.3, )])), + None, 0, 0.0, - None, ); let table_id = TableId(0); let cost_model = create_cost_model_mock_storage( vec![table_id], - vec![per_attribute_stats], + vec![HashMap::from([(0, per_attribute_stats)])], vec![None], HashMap::new(), ); @@ -743,9 +743,9 @@ mod tests { vec![Some(Value::Int32(1))], 0.3, )])), + None, 0, 0.1, - None, ); let table_id = TableId(0); let attr_infos = HashMap::from([( @@ -761,7 +761,7 @@ mod tests { )]); let cost_model = create_cost_model_mock_storage( vec![table_id], - vec![per_attribute_stats], + vec![HashMap::from([(0, per_attribute_stats)])], vec![None], attr_infos, ); @@ -799,9 +799,9 @@ mod tests { async fn test_cast_attr_ref_eq_attr_ref() { let per_attribute_stats = TestPerAttributeStats::new( MostCommonValues::SimpleFrequency(SimpleMap::new(vec![])), + None, 0, 0.0, - None, ); let table_id = TableId(0); let attr_infos = HashMap::from([( @@ -827,7 +827,7 @@ mod tests { )]); let cost_model = create_cost_model_mock_storage( 
vec![table_id], - vec![per_attribute_stats], + vec![HashMap::from([(0, per_attribute_stats)])], vec![None], attr_infos, ); diff --git a/optd-cost-model/src/cost/filter/in_list.rs b/optd-cost-model/src/cost/filter/in_list.rs index 2363d4a..2c79ec2 100644 --- a/optd-cost-model/src/cost/filter/in_list.rs +++ b/optd-cost-model/src/cost/filter/in_list.rs @@ -86,14 +86,14 @@ mod tests { (vec![Some(Value::Int32(1))], 0.8), (vec![Some(Value::Int32(2))], 0.2), ])), + None, 2, 0.0, - None, ); let table_id = TableId(0); let cost_model = create_cost_model_mock_storage( vec![table_id], - vec![per_attribute_stats], + vec![HashMap::from([(0, per_attribute_stats)])], vec![None], HashMap::new(), ); diff --git a/optd-cost-model/src/cost/filter/like.rs b/optd-cost-model/src/cost/filter/like.rs index 03da4d1..92a519a 100644 --- a/optd-cost-model/src/cost/filter/like.rs +++ b/optd-cost-model/src/cost/filter/like.rs @@ -120,14 +120,14 @@ mod tests { (vec![Some(Value::String("abcd".into()))], 0.1), (vec![Some(Value::String("abc".into()))], 0.1), ])), + None, 2, 0.0, - None, ); let table_id = TableId(0); let cost_model = create_cost_model_mock_storage( vec![table_id], - vec![per_attribute_stats], + vec![HashMap::from([(0, per_attribute_stats)])], vec![None], HashMap::new(), ); @@ -163,14 +163,14 @@ mod tests { let mcvs_total_count = 10; let per_attribute_stats = TestPerAttributeStats::new( MostCommonValues::Counter(Counter::new_from_existing(mcvs_counts, mcvs_total_count)), + None, 2, null_frac, - None, ); let table_id = TableId(0); let cost_model = create_cost_model_mock_storage( vec![table_id], - vec![per_attribute_stats], + vec![HashMap::from([(0, per_attribute_stats)])], vec![None], HashMap::new(), ); diff --git a/optd-cost-model/src/cost_model.rs b/optd-cost-model/src/cost_model.rs index ebb6391..1943519 100644 --- a/optd-cost-model/src/cost_model.rs +++ b/optd-cost-model/src/cost_model.rs @@ -99,7 +99,7 @@ impl CostModelImpl { pub(crate) async fn get_attribute_comb_stats( &self, 
table_id: TableId, - attr_comb: &[usize], + attr_comb: &[u64], ) -> CostModelResult> { self.storage_manager .get_attributes_comb_statistics(table_id, attr_comb) @@ -149,7 +149,7 @@ pub mod tests { pub fn create_cost_model_mock_storage( table_id: Vec, - per_attribute_stats: Vec, + per_attribute_stats: Vec>, row_counts: Vec>, per_table_attr_infos: BaseTableAttrInfo, ) -> TestOptCostModelMock { @@ -163,7 +163,10 @@ pub mod tests { table_id, TableStats::new( row_count.unwrap_or(100), - vec![(vec![0], per_attr_stats)].into_iter().collect(), + per_attr_stats + .into_iter() + .map(|(attr_idx, stats)| (vec![attr_idx], stats)) + .collect(), ), ) }) @@ -201,6 +204,14 @@ pub mod tests { UnOpPred::new(child, op_type).into_pred_node() } + pub fn empty_list() -> ArcPredicateNode { + ListPred::new(vec![]).into_pred_node() + } + + pub fn list(children: Vec) -> ArcPredicateNode { + ListPred::new(children).into_pred_node() + } + pub fn in_list( table_id: TableId, attr_ref_idx: usize, @@ -223,7 +234,7 @@ pub mod tests { ) } - pub(crate) fn get_empty_per_attr_stats() -> TestPerAttributeStats { - TestPerAttributeStats::new(MostCommonValues::Counter(Counter::default()), 0, 0.0, None) + pub(crate) fn empty_per_attr_stats() -> TestPerAttributeStats { + TestPerAttributeStats::new(MostCommonValues::Counter(Counter::default()), None, 0, 0.0) } } diff --git a/optd-cost-model/src/lib.rs b/optd-cost-model/src/lib.rs index d4a24d2..a9b977a 100644 --- a/optd-cost-model/src/lib.rs +++ b/optd-cost-model/src/lib.rs @@ -31,7 +31,7 @@ pub struct Cost(pub Vec); /// Estimated statistic calculated by the cost model. /// It is the estimated output row count of the targeted expression. 
-#[derive(Eq, Ord, PartialEq, PartialOrd)] +#[derive(Eq, Ord, PartialEq, PartialOrd, Debug)] pub struct EstimatedStatistic(pub u64); pub type CostModelResult = Result; diff --git a/optd-cost-model/src/stats/mod.rs b/optd-cost-model/src/stats/mod.rs index a20cdf2..0fcc4c2 100644 --- a/optd-cost-model/src/stats/mod.rs +++ b/optd-cost-model/src/stats/mod.rs @@ -130,9 +130,9 @@ pub struct AttributeCombValueStats { impl AttributeCombValueStats { pub fn new( mcvs: MostCommonValues, + distr: Option, ndistinct: u64, null_frac: f64, - distr: Option, ) -> Self { Self { mcvs, diff --git a/optd-cost-model/src/stats/utilities/simple_map.rs b/optd-cost-model/src/stats/utilities/simple_map.rs index f685fe6..d04439e 100644 --- a/optd-cost-model/src/stats/utilities/simple_map.rs +++ b/optd-cost-model/src/stats/utilities/simple_map.rs @@ -7,7 +7,7 @@ use crate::common::values::Value; /// TODO: documentation /// Now it is mainly for testing purposes. -#[derive(Clone, Serialize, Deserialize, Debug)] +#[derive(Clone, Serialize, Deserialize, Debug, Default)] pub struct SimpleMap { pub(crate) m: HashMap, } diff --git a/optd-cost-model/src/storage/mock.rs b/optd-cost-model/src/storage/mock.rs index ed75a59..aba140b 100644 --- a/optd-cost-model/src/storage/mock.rs +++ b/optd-cost-model/src/storage/mock.rs @@ -7,20 +7,20 @@ use crate::{common::types::TableId, stats::AttributeCombValueStats, CostModelRes use super::{Attribute, CostModelStorageManager}; -pub type AttrsIdx = Vec; +pub type AttrIndices = Vec; #[serde_with::serde_as] #[derive(Serialize, Deserialize, Debug)] pub struct TableStats { pub row_cnt: usize, #[serde_as(as = "HashMap")] - pub column_comb_stats: HashMap, + pub column_comb_stats: HashMap, } impl TableStats { pub fn new( row_cnt: usize, - column_comb_stats: HashMap, + column_comb_stats: HashMap, ) -> Self { Self { row_cnt, @@ -68,7 +68,7 @@ impl CostModelStorageManager for CostModelStorageMockManagerImpl { async fn get_attributes_comb_statistics( &self, table_id: TableId, - 
attr_base_indices: &[usize], + attr_base_indices: &[u64], ) -> CostModelResult> { let table_stats = self.per_table_stats_map.get(&table_id); match table_stats { diff --git a/optd-cost-model/src/storage/mod.rs b/optd-cost-model/src/storage/mod.rs index 78c75cd..fcc5141 100644 --- a/optd-cost-model/src/storage/mod.rs +++ b/optd-cost-model/src/storage/mod.rs @@ -27,6 +27,6 @@ pub trait CostModelStorageManager { async fn get_attributes_comb_statistics( &self, table_id: TableId, - attr_base_indices: &[usize], + attr_base_indices: &[u64], ) -> CostModelResult>; } diff --git a/optd-cost-model/src/storage/persistent.rs b/optd-cost-model/src/storage/persistent.rs index b574fe0..49c4ff4 100644 --- a/optd-cost-model/src/storage/persistent.rs +++ b/optd-cost-model/src/storage/persistent.rs @@ -66,7 +66,7 @@ impl CostModelStorageManager async fn get_attributes_comb_statistics( &self, table_id: TableId, - attr_base_indices: &[usize], + attr_base_indices: &[u64], ) -> CostModelResult> { let dist: Option = self .backend_manager @@ -136,7 +136,9 @@ impl CostModelStorageManager }; Ok(Some(AttributeCombValueStats::new( - mcvs, ndistinct, null_frac, dist, + mcvs, dist, ndistinct, null_frac, ))) } + + // TODO: Support querying for a specific type of statistics. 
} From e183f02f6d55e39abf4a9bd035507ff0b1cf68fe Mon Sep 17 00:00:00 2001 From: Yuanxin Cao Date: Sun, 17 Nov 2024 09:43:27 -0500 Subject: [PATCH 32/51] add test for cost model agg --- .../src/common/predicates/attr_ref_pred.rs | 4 +- .../src/common/predicates/id_pred.rs | 6 +- optd-cost-model/src/cost/agg.rs | 163 ++++++++++++++++++ optd-cost-model/src/cost/filter/controller.rs | 10 +- optd-cost-model/src/cost_model.rs | 6 +- optd-cost-model/src/storage/mock.rs | 4 +- 6 files changed, 177 insertions(+), 16 deletions(-) diff --git a/optd-cost-model/src/common/predicates/attr_ref_pred.rs b/optd-cost-model/src/common/predicates/attr_ref_pred.rs index e589cdf..99638f2 100644 --- a/optd-cost-model/src/common/predicates/attr_ref_pred.rs +++ b/optd-cost-model/src/common/predicates/attr_ref_pred.rs @@ -28,12 +28,12 @@ use super::id_pred::IdPred; pub struct AttributeRefPred(pub ArcPredicateNode); impl AttributeRefPred { - pub fn new(table_id: TableId, attribute_idx: usize) -> AttributeRefPred { + pub fn new(table_id: TableId, attribute_idx: u64) -> AttributeRefPred { AttributeRefPred( PredicateNode { typ: PredicateType::AttributeRef, children: vec![ - IdPred::new(table_id.0).into_pred_node(), + IdPred::new(table_id.0 as u64).into_pred_node(), IdPred::new(attribute_idx).into_pred_node(), ], data: None, diff --git a/optd-cost-model/src/common/predicates/id_pred.rs b/optd-cost-model/src/common/predicates/id_pred.rs index 962e526..e502a48 100644 --- a/optd-cost-model/src/common/predicates/id_pred.rs +++ b/optd-cost-model/src/common/predicates/id_pred.rs @@ -11,14 +11,12 @@ use crate::common::{ pub struct IdPred(pub ArcPredicateNode); impl IdPred { - pub fn new(id: usize) -> IdPred { - // This conversion is always safe since usize is at most u64. 
- let u64_id = id as u64; + pub fn new(id: u64) -> IdPred { IdPred( PredicateNode { typ: PredicateType::Id, children: vec![], - data: Some(Value::UInt64(u64_id)), + data: Some(Value::UInt64(id)), } .into(), ) diff --git a/optd-cost-model/src/cost/agg.rs b/optd-cost-model/src/cost/agg.rs index 858fe1b..2372b08 100644 --- a/optd-cost-model/src/cost/agg.rs +++ b/optd-cost-model/src/cost/agg.rs @@ -61,3 +61,166 @@ impl CostModelImpl { } } } + +#[cfg(test)] +mod tests { + use std::collections::HashMap; + + use crate::{ + common::{predicates::constant_pred::ConstantType, types::TableId, values::Value}, + cost_model::tests::{ + attr_ref, cnst, create_cost_model_mock_storage, empty_list, empty_per_attr_stats, list, + TestPerAttributeStats, + }, + stats::{utilities::simple_map::SimpleMap, MostCommonValues, DEFAULT_NUM_DISTINCT}, + storage::Attribute, + EstimatedStatistic, + }; + + #[tokio::test] + async fn test_agg_no_stats() { + let table_id = TableId(0); + let attr_infos = HashMap::from([( + table_id, + HashMap::from([ + ( + 0, + Attribute { + name: String::from("attr1"), + typ: ConstantType::Int32, + nullable: false, + }, + ), + ( + 1, + Attribute { + name: String::from("attr2"), + typ: ConstantType::Int64, + nullable: false, + }, + ), + ]), + )]); + let cost_model = + create_cost_model_mock_storage(vec![table_id], vec![], vec![None], attr_infos); + + // Group by empty list should return 1. + let group_bys = empty_list(); + assert_eq!( + cost_model.get_agg_row_cnt(group_bys).await.unwrap(), + EstimatedStatistic(1) + ); + + // Group by single column should return the default value since there are no stats. + let group_bys = list(vec![attr_ref(table_id, 0)]); + assert_eq!( + cost_model.get_agg_row_cnt(group_bys).await.unwrap(), + EstimatedStatistic(DEFAULT_NUM_DISTINCT) + ); + + // Group by two columns should return the default value squared since there are no stats. 
+ let group_bys = list(vec![attr_ref(table_id, 0), attr_ref(table_id, 1)]); + assert_eq!( + cost_model.get_agg_row_cnt(group_bys).await.unwrap(), + EstimatedStatistic(DEFAULT_NUM_DISTINCT * DEFAULT_NUM_DISTINCT) + ); + } + + #[tokio::test] + async fn test_agg_with_stats() { + let table_id = TableId(0); + let attr1_base_idx = 0; + let attr2_base_idx = 1; + let attr3_base_idx = 2; + let attr_infos = HashMap::from([( + table_id, + HashMap::from([ + ( + attr1_base_idx, + Attribute { + name: String::from("attr1"), + typ: ConstantType::Int32, + nullable: false, + }, + ), + ( + attr2_base_idx, + Attribute { + name: String::from("attr2"), + typ: ConstantType::Int64, + nullable: false, + }, + ), + ( + attr3_base_idx, + Attribute { + name: String::from("attr3"), + typ: ConstantType::Int64, + nullable: false, + }, + ), + ]), + )]); + + let attr1_ndistinct = 12; + let attr2_ndistinct = 645; + let attr1_stats = TestPerAttributeStats::new( + MostCommonValues::SimpleFrequency(SimpleMap::default()), + None, + attr1_ndistinct, + 0.0, + ); + let attr2_stats = TestPerAttributeStats::new( + MostCommonValues::SimpleFrequency(SimpleMap::default()), + None, + attr2_ndistinct, + 0.0, + ); + + let cost_model = create_cost_model_mock_storage( + vec![table_id], + vec![HashMap::from([ + (attr1_base_idx, attr1_stats), + (attr2_base_idx, attr2_stats), + ])], + vec![None], + attr_infos, + ); + + // Group by empty list should return 1. + let group_bys = empty_list(); + assert_eq!( + cost_model.get_agg_row_cnt(group_bys).await.unwrap(), + EstimatedStatistic(1) + ); + + // Group by single column should return the n-distinct of the column. + let group_bys = list(vec![attr_ref(table_id, attr1_base_idx)]); + assert_eq!( + cost_model.get_agg_row_cnt(group_bys).await.unwrap(), + EstimatedStatistic(attr1_ndistinct) + ); + + // Group by two columns should return the product of the n-distinct of the columns. 
+ let group_bys = list(vec![ + attr_ref(table_id, attr1_base_idx), + attr_ref(table_id, attr2_base_idx), + ]); + assert_eq!( + cost_model.get_agg_row_cnt(group_bys).await.unwrap(), + EstimatedStatistic(attr1_ndistinct * attr2_ndistinct) + ); + + // Group by multiple columns should return the product of the n-distinct of the columns. If one of the columns + // does not have stats, it should use the default value instead. + let group_bys = list(vec![ + attr_ref(table_id, attr1_base_idx), + attr_ref(table_id, attr2_base_idx), + attr_ref(table_id, attr3_base_idx), + ]); + assert_eq!( + cost_model.get_agg_row_cnt(group_bys).await.unwrap(), + EstimatedStatistic(attr1_ndistinct * attr2_ndistinct * DEFAULT_NUM_DISTINCT) + ); + } +} diff --git a/optd-cost-model/src/cost/filter/controller.rs b/optd-cost-model/src/cost/filter/controller.rs index b319a48..dcfd039 100644 --- a/optd-cost-model/src/cost/filter/controller.rs +++ b/optd-cost-model/src/cost/filter/controller.rs @@ -240,7 +240,7 @@ mod tests { #[tokio::test] async fn test_attr_ref_leq_constint_no_mcvs_in_range() { let per_attribute_stats = TestPerAttributeStats::new( - MostCommonValues::SimpleFrequency(SimpleMap::new(vec![])), + MostCommonValues::SimpleFrequency(SimpleMap::default()), Some(Distribution::SimpleDistribution(SimpleMap::new(vec![( Value::Int32(15), 0.7, @@ -364,7 +364,7 @@ mod tests { #[tokio::test] async fn test_attr_ref_lt_constint_no_mcvs_in_range() { let per_attribute_stats = TestPerAttributeStats::new( - MostCommonValues::SimpleFrequency(SimpleMap::new(vec![])), + MostCommonValues::SimpleFrequency(SimpleMap::default()), Some(Distribution::SimpleDistribution(SimpleMap::new(vec![( Value::Int32(15), 0.7, @@ -492,7 +492,7 @@ mod tests { #[tokio::test] async fn test_attr_ref_gt_constint() { let per_attribute_stats = TestPerAttributeStats::new( - MostCommonValues::SimpleFrequency(SimpleMap::new(vec![])), + MostCommonValues::SimpleFrequency(SimpleMap::default()), 
Some(Distribution::SimpleDistribution(SimpleMap::new(vec![( Value::Int32(15), 0.7, @@ -530,7 +530,7 @@ mod tests { #[tokio::test] async fn test_attr_ref_geq_constint() { let per_attribute_stats = TestPerAttributeStats::new( - MostCommonValues::SimpleFrequency(SimpleMap::new(vec![])), + MostCommonValues::SimpleFrequency(SimpleMap::default()), Some(Distribution::SimpleDistribution(SimpleMap::new(vec![( Value::Int32(15), 0.7, @@ -798,7 +798,7 @@ mod tests { #[tokio::test] async fn test_cast_attr_ref_eq_attr_ref() { let per_attribute_stats = TestPerAttributeStats::new( - MostCommonValues::SimpleFrequency(SimpleMap::new(vec![])), + MostCommonValues::SimpleFrequency(SimpleMap::default()), None, 0, 0.0, diff --git a/optd-cost-model/src/cost_model.rs b/optd-cost-model/src/cost_model.rs index 1943519..32ea404 100644 --- a/optd-cost-model/src/cost_model.rs +++ b/optd-cost-model/src/cost_model.rs @@ -176,7 +176,7 @@ pub mod tests { CostModelImpl::new(storage_manager, CatalogSource::Mock) } - pub fn attr_ref(table_id: TableId, attr_base_index: usize) -> ArcPredicateNode { + pub fn attr_ref(table_id: TableId, attr_base_index: u64) -> ArcPredicateNode { AttributeRefPred::new(table_id, attr_base_index).into_pred_node() } @@ -214,7 +214,7 @@ pub mod tests { pub fn in_list( table_id: TableId, - attr_ref_idx: usize, + attr_ref_idx: u64, list: Vec, negated: bool, ) -> InListPred { @@ -225,7 +225,7 @@ pub mod tests { ) } - pub fn like(table_id: TableId, attr_ref_idx: usize, pattern: &str, negated: bool) -> LikePred { + pub fn like(table_id: TableId, attr_ref_idx: u64, pattern: &str, negated: bool) -> LikePred { LikePred::new( negated, false, diff --git a/optd-cost-model/src/storage/mock.rs b/optd-cost-model/src/storage/mock.rs index aba140b..91a2265 100644 --- a/optd-cost-model/src/storage/mock.rs +++ b/optd-cost-model/src/storage/mock.rs @@ -30,7 +30,7 @@ impl TableStats { } pub type BaseTableStats = HashMap; -pub type BaseTableAttrInfo = HashMap>; +pub type BaseTableAttrInfo = 
HashMap>; // (table_id, (attr_base_index, attr)) pub struct CostModelStorageMockManagerImpl { pub(crate) per_table_stats_map: BaseTableStats, @@ -58,7 +58,7 @@ impl CostModelStorageManager for CostModelStorageMockManagerImpl { let table_attr_infos = self.per_table_attr_infos_map.get(&table_id); match table_attr_infos { None => Ok(None), - Some(table_attr_infos) => match table_attr_infos.get(&attr_base_index) { + Some(table_attr_infos) => match table_attr_infos.get(&(attr_base_index as u64)) { None => Ok(None), Some(attr) => Ok(Some(attr.clone())), }, From 0059141e1121affb179026d67e41dd93e6f7dc5e Mon Sep 17 00:00:00 2001 From: Yuanxin Cao Date: Sun, 17 Nov 2024 09:49:10 -0500 Subject: [PATCH 33/51] make all data types u64 instead of usize --- optd-cost-model/src/common/predicates/attr_ref_pred.rs | 4 ++-- optd-cost-model/src/common/predicates/id_pred.rs | 4 ++-- optd-cost-model/src/common/types.rs | 10 +++++----- optd-cost-model/src/cost/filter/comp_op.rs | 6 ++---- optd-cost-model/src/cost_model.rs | 2 +- optd-cost-model/src/lib.rs | 2 +- optd-cost-model/src/storage/mock.rs | 8 ++++---- optd-cost-model/src/storage/mod.rs | 2 +- optd-cost-model/src/storage/persistent.rs | 4 ++-- optd-persistent/src/cost_model/interface.rs | 5 +++-- optd-persistent/src/cost_model/orm.rs | 8 ++++---- 11 files changed, 27 insertions(+), 28 deletions(-) diff --git a/optd-cost-model/src/common/predicates/attr_ref_pred.rs b/optd-cost-model/src/common/predicates/attr_ref_pred.rs index 99638f2..9f63ad7 100644 --- a/optd-cost-model/src/common/predicates/attr_ref_pred.rs +++ b/optd-cost-model/src/common/predicates/attr_ref_pred.rs @@ -33,7 +33,7 @@ impl AttributeRefPred { PredicateNode { typ: PredicateType::AttributeRef, children: vec![ - IdPred::new(table_id.0 as u64).into_pred_node(), + IdPred::new(table_id.0).into_pred_node(), IdPred::new(attribute_idx).into_pred_node(), ], data: None, @@ -44,7 +44,7 @@ impl AttributeRefPred { /// Gets the table id. 
pub fn table_id(&self) -> TableId { - TableId(self.0.child(0).data.as_ref().unwrap().as_u64() as usize) + TableId(self.0.child(0).data.as_ref().unwrap().as_u64()) } /// Gets the attribute index. diff --git a/optd-cost-model/src/common/predicates/id_pred.rs b/optd-cost-model/src/common/predicates/id_pred.rs index e502a48..13f557f 100644 --- a/optd-cost-model/src/common/predicates/id_pred.rs +++ b/optd-cost-model/src/common/predicates/id_pred.rs @@ -23,8 +23,8 @@ impl IdPred { } /// Gets the id stored in the predicate. - pub fn id(&self) -> usize { - self.0.data.clone().unwrap().as_u64() as usize + pub fn id(&self) -> u64 { + self.0.data.clone().unwrap().as_u64() } } diff --git a/optd-cost-model/src/common/types.rs b/optd-cost-model/src/common/types.rs index e8aaf7b..fecd143 100644 --- a/optd-cost-model/src/common/types.rs +++ b/optd-cost-model/src/common/types.rs @@ -5,23 +5,23 @@ use std::fmt::Display; /// TODO: documentation #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Debug, Default, Hash)] -pub struct GroupId(pub usize); +pub struct GroupId(pub u64); /// TODO: documentation #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Debug, Default, Hash)] -pub struct ExprId(pub usize); +pub struct ExprId(pub u64); /// TODO: documentation #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Debug, Default, Hash)] -pub struct TableId(pub usize); +pub struct TableId(pub u64); /// TODO: documentation #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Debug, Default, Hash)] -pub struct AttrId(pub usize); +pub struct AttrId(pub u64); /// TODO: documentation #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Debug, Default, Hash)] -pub struct EpochId(pub usize); +pub struct EpochId(pub u64); impl Display for GroupId { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { diff --git a/optd-cost-model/src/cost/filter/comp_op.rs b/optd-cost-model/src/cost/filter/comp_op.rs index 4ce2b3c..2e463bd 100644 --- 
a/optd-cost-model/src/cost/filter/comp_op.rs +++ b/optd-cost-model/src/cost/filter/comp_op.rs @@ -183,11 +183,9 @@ impl CostModelImpl { // **TODO**: What if this attribute is a derived attribute? let attribute_info = self .storage_manager - .get_attribute_info(table_id, attr_ref_idx as i32) + .get_attribute_info(table_id, attr_ref_idx) .await? - .ok_or({ - SemanticError::AttributeNotFound(table_id, attr_ref_idx as i32) - })?; + .ok_or({ SemanticError::AttributeNotFound(table_id, attr_ref_idx) })?; let invert_cast_data_type = &attribute_info.typ.into_data_type(); diff --git a/optd-cost-model/src/cost_model.rs b/optd-cost-model/src/cost_model.rs index 32ea404..992d4bc 100644 --- a/optd-cost-model/src/cost_model.rs +++ b/optd-cost-model/src/cost_model.rs @@ -150,7 +150,7 @@ pub mod tests { pub fn create_cost_model_mock_storage( table_id: Vec, per_attribute_stats: Vec>, - row_counts: Vec>, + row_counts: Vec>, per_table_attr_infos: BaseTableAttrInfo, ) -> TestOptCostModelMock { let storage_manager = CostModelStorageMockManagerImpl::new( diff --git a/optd-cost-model/src/lib.rs b/optd-cost-model/src/lib.rs index a9b977a..6f786e8 100644 --- a/optd-cost-model/src/lib.rs +++ b/optd-cost-model/src/lib.rs @@ -41,7 +41,7 @@ pub enum SemanticError { // TODO: Add more error types UnknownStatisticType, VersionedStatisticNotFound, - AttributeNotFound(TableId, i32), // (table_id, attribute_base_index) + AttributeNotFound(TableId, u64), // (table_id, attribute_base_index) // FIXME: not sure if this should be put here InvalidPredicate(String), } diff --git a/optd-cost-model/src/storage/mock.rs b/optd-cost-model/src/storage/mock.rs index 91a2265..1f369f9 100644 --- a/optd-cost-model/src/storage/mock.rs +++ b/optd-cost-model/src/storage/mock.rs @@ -12,14 +12,14 @@ pub type AttrIndices = Vec; #[serde_with::serde_as] #[derive(Serialize, Deserialize, Debug)] pub struct TableStats { - pub row_cnt: usize, + pub row_cnt: u64, #[serde_as(as = "HashMap")] pub column_comb_stats: HashMap, } impl 
TableStats { pub fn new( - row_cnt: usize, + row_cnt: u64, column_comb_stats: HashMap, ) -> Self { Self { @@ -53,12 +53,12 @@ impl CostModelStorageManager for CostModelStorageMockManagerImpl { async fn get_attribute_info( &self, table_id: TableId, - attr_base_index: i32, + attr_base_index: u64, ) -> CostModelResult> { let table_attr_infos = self.per_table_attr_infos_map.get(&table_id); match table_attr_infos { None => Ok(None), - Some(table_attr_infos) => match table_attr_infos.get(&(attr_base_index as u64)) { + Some(table_attr_infos) => match table_attr_infos.get(&attr_base_index) { None => Ok(None), Some(attr) => Ok(Some(attr.clone())), }, diff --git a/optd-cost-model/src/storage/mod.rs b/optd-cost-model/src/storage/mod.rs index fcc5141..3231107 100644 --- a/optd-cost-model/src/storage/mod.rs +++ b/optd-cost-model/src/storage/mod.rs @@ -21,7 +21,7 @@ pub trait CostModelStorageManager { async fn get_attribute_info( &self, table_id: TableId, - attr_base_index: i32, + attr_base_index: u64, ) -> CostModelResult>; async fn get_attributes_comb_statistics( diff --git a/optd-cost-model/src/storage/persistent.rs b/optd-cost-model/src/storage/persistent.rs index 49c4ff4..42312ff 100644 --- a/optd-cost-model/src/storage/persistent.rs +++ b/optd-cost-model/src/storage/persistent.rs @@ -35,11 +35,11 @@ impl CostModelStorageManager async fn get_attribute_info( &self, table_id: TableId, - attr_base_index: i32, + attr_base_index: u64, ) -> CostModelResult> { Ok(self .backend_manager - .get_attribute(table_id.into(), attr_base_index) + .get_attribute(table_id.into(), attr_base_index as i32) .await? 
.map(|attr| Attribute { name: attr.name, diff --git a/optd-persistent/src/cost_model/interface.rs b/optd-persistent/src/cost_model/interface.rs index 598598d..ee767d7 100644 --- a/optd-persistent/src/cost_model/interface.rs +++ b/optd-persistent/src/cost_model/interface.rs @@ -17,6 +17,7 @@ pub type AttrId = i32; pub type ExprId = i32; pub type EpochId = i32; pub type StatId = i32; +pub type AttrIndex = i32; /// TODO: documentation pub enum CatalogSource { @@ -152,7 +153,7 @@ pub trait CostModelStorageLayer { async fn get_stats_for_attr_indices_based( &self, table_id: TableId, - attr_base_indices: Vec, + attr_base_indices: Vec, stat_type: StatType, epoch_id: Option, ) -> StorageResult>; @@ -168,6 +169,6 @@ pub trait CostModelStorageLayer { async fn get_attribute( &self, table_id: TableId, - attribute_base_index: i32, + attribute_base_index: AttrIndex, ) -> StorageResult>; } diff --git a/optd-persistent/src/cost_model/orm.rs b/optd-persistent/src/cost_model/orm.rs index 65d6035..b503b7e 100644 --- a/optd-persistent/src/cost_model/orm.rs +++ b/optd-persistent/src/cost_model/orm.rs @@ -14,8 +14,8 @@ use serde_json::json; use super::catalog::mock_catalog::{self, MockCatalog}; use super::interface::{ - Attr, AttrId, AttrType, CatalogSource, EpochId, EpochOption, ExprId, Stat, StatId, StatType, - TableId, + Attr, AttrId, AttrIndex, AttrType, CatalogSource, EpochId, EpochOption, ExprId, Stat, StatId, + StatType, TableId, }; impl BackendManager { @@ -432,7 +432,7 @@ impl CostModelStorageLayer for BackendManager { async fn get_stats_for_attr_indices_based( &self, table_id: TableId, - attr_base_indices: Vec, + attr_base_indices: Vec, stat_type: StatType, epoch_id: Option, ) -> StorageResult> { @@ -542,7 +542,7 @@ impl CostModelStorageLayer for BackendManager { async fn get_attribute( &self, table_id: TableId, - attribute_base_index: i32, + attribute_base_index: AttrIndex, ) -> StorageResult> { let attr_res = Attribute::find() .filter(attribute::Column::TableId.eq(table_id)) 
From ec0afa6c55006b8dd32325908d33113e986a9bc2 Mon Sep 17 00:00:00 2001 From: Yuanxin Cao Date: Mon, 18 Nov 2024 11:57:35 -0500 Subject: [PATCH 34/51] copy paste join cardinality calculation --- optd-cost-model/src/cost/join.rs | 1422 ++++++++++++++++++++++++++++++ 1 file changed, 1422 insertions(+) diff --git a/optd-cost-model/src/cost/join.rs b/optd-cost-model/src/cost/join.rs index 8b13789..5aa8fb6 100644 --- a/optd-cost-model/src/cost/join.rs +++ b/optd-cost-model/src/cost/join.rs @@ -1 +1,1423 @@ +// Copyright (c) 2023-2024 CMU Database Group +// +// Use of this source code is governed by an MIT-style license that can be found in the LICENSE file or at +// https://opensource.org/licenses/MIT. +use std::collections::HashSet; + +use itertools::Itertools; +use optd_datafusion_repr::plan_nodes::{ + ArcDfPredNode, BinOpType, ColumnRefPred, DfPredType, DfReprPredNode, JoinType, ListPred, + LogOpPred, LogOpType, +}; +use optd_datafusion_repr::properties::column_ref::{ + BaseTableColumnRef, BaseTableColumnRefs, ColumnRef, EqBaseTableColumnSets, EqPredicate, + GroupColumnRefs, SemanticCorrelation, +}; +use optd_datafusion_repr::properties::schema::Schema; +use serde::de::DeserializeOwned; +use serde::Serialize; + +use super::AdvStats; +use crate::adv_stats::stats::{Distribution, MostCommonValues}; +use crate::adv_stats::DEFAULT_NUM_DISTINCT; + +impl< + M: MostCommonValues + Serialize + DeserializeOwned, + D: Distribution + Serialize + DeserializeOwned, + > AdvStats +{ + /// Estimates the output row count of an NLJ as + /// `left_row_cnt * right_row_cnt * selectivity`, clamped to be at least 1.0. + /// + /// The selectivity is derived from `join_cond` via `get_join_selectivity_from_expr_tree`. + #[allow(clippy::too_many_arguments)] + pub(crate) fn get_nlj_row_cnt( + &self, + join_typ: JoinType, + left_row_cnt: f64, + right_row_cnt: f64, + output_schema: Schema, + output_column_refs: GroupColumnRefs, + join_cond: ArcDfPredNode, + left_column_refs: GroupColumnRefs, + right_column_refs: GroupColumnRefs, + ) -> f64 { + let selectivity = { + let input_correlation = self.get_input_correlation(left_column_refs, right_column_refs); + self.get_join_selectivity_from_expr_tree( + join_typ, + join_cond, +
&output_schema, + output_column_refs.base_table_column_refs(), + input_correlation, + left_row_cnt, + right_row_cnt, + ) + }; + (left_row_cnt * right_row_cnt * selectivity).max(1.0) + } + + /// Estimates the output row count of a hash join as + /// `left_row_cnt * right_row_cnt * selectivity`, clamped to be at least 1.0. + /// + /// The selectivity is derived from the key lists via `get_join_selectivity_from_keys`; the + /// number of base table column refs of the left input is passed along as the offset of the + /// right-side column indices. + #[allow(clippy::too_many_arguments)] + pub(crate) fn get_hash_join_row_cnt( + &self, + join_typ: JoinType, + left_row_cnt: f64, + right_row_cnt: f64, + left_keys: ListPred, + right_keys: ListPred, + output_schema: Schema, + output_column_refs: GroupColumnRefs, + left_column_refs: GroupColumnRefs, + right_column_refs: GroupColumnRefs, + ) -> f64 { + let selectivity = { + let schema = output_schema; + let column_refs = output_column_refs; + let column_refs = column_refs.base_table_column_refs(); + let left_col_cnt = left_column_refs.base_table_column_refs().len(); + // there may be more than one expression tree in a group. + // see comment in DfPredType::PhysicalFilter(_) for more information + let input_correlation = self.get_input_correlation(left_column_refs, right_column_refs); + self.get_join_selectivity_from_keys( + join_typ, + left_keys, + right_keys, + &schema, + column_refs, + input_correlation, + left_row_cnt, + right_row_cnt, + left_col_cnt, + ) + }; + (left_row_cnt * right_row_cnt * selectivity).max(1.0) + } + + /// Merges the output correlations of the left and right inputs with + /// `SemanticCorrelation::merge`. + fn get_input_correlation( + &self, + left_prop: GroupColumnRefs, + right_prop: GroupColumnRefs, + ) -> Option { + SemanticCorrelation::merge( + left_prop.output_correlation().cloned(), + right_prop.output_correlation().cloned(), + ) + } + + /// A wrapper to convert the join keys to the format expected by get_join_selectivity_core() + #[allow(clippy::too_many_arguments)] + fn get_join_selectivity_from_keys( + &self, + join_typ: JoinType, + left_keys: ListPred, + right_keys: ListPred, + schema: &Schema, + column_refs: &BaseTableColumnRefs, + input_correlation: Option, + left_row_cnt: f64, + right_row_cnt: f64, + left_col_cnt: usize, + ) -> f64 { + assert!(left_keys.len() == right_keys.len()); + // I assume that the keys are already in the right order + // s.t.
the ith key of left_keys corresponds with the ith key of right_keys + let on_col_ref_pairs = left_keys + .to_vec() + .into_iter() + .zip(right_keys.to_vec()) + .map(|(left_key, right_key)| { + ( + ColumnRefPred::from_pred_node(left_key).expect("keys should be ColumnRefPreds"), + ColumnRefPred::from_pred_node(right_key) + .expect("keys should be ColumnRefPreds"), + ) + }) + .collect_vec(); + self.get_join_selectivity_core( + join_typ, + on_col_ref_pairs, + None, + schema, + column_refs, + input_correlation, + left_row_cnt, + right_row_cnt, + left_col_cnt, + ) + } + + /// The core logic of join selectivity which assumes we've already separated the expression + /// into the on conditions and the filters. + /// + /// Hash join and NLJ reference right table columns differently, hence the + /// `right_col_ref_offset` parameter. + /// + /// For hash join, the right table column indices are with respect to the right table, + /// which means #0 is the first column of the right table. + /// + /// For NLJ, the right table column indices are with respect to the output of the join. + /// For example, if the left table has 3 columns, the first column of the right table + /// is #3 instead of #0. + #[allow(clippy::too_many_arguments)] + fn get_join_selectivity_core( + &self, + join_typ: JoinType, + on_col_ref_pairs: Vec<(ColumnRefPred, ColumnRefPred)>, + filter_expr_tree: Option, + schema: &Schema, + column_refs: &BaseTableColumnRefs, + input_correlation: Option, + left_row_cnt: f64, + right_row_cnt: f64, + right_col_ref_offset: usize, + ) -> f64 { + let join_on_selectivity = self.get_join_on_selectivity( + &on_col_ref_pairs, + column_refs, + input_correlation, + right_col_ref_offset, + ); + // Currently, there is no difference in how we handle a join filter and a select filter, + // so we use the same function. + // + // One difference (that we *don't* care about right now) is that join filters can contain + // expressions from multiple different tables.
Currently, this doesn't affect the + // get_filter_selectivity() function, but this may change in the future. + let join_filter_selectivity = match filter_expr_tree { + Some(filter_expr_tree) => { + self.get_filter_selectivity(filter_expr_tree, schema, column_refs) + } + None => 1.0, + }; + let inner_join_selectivity = join_on_selectivity * join_filter_selectivity; + // Outer joins floor the inner selectivity at 1 / (row count of the opposite input). + // Cross joins use only the filter selectivity and assert that there are no on columns. + // Any other join type hits `unimplemented!`. + match join_typ { + JoinType::Inner => inner_join_selectivity, + JoinType::LeftOuter => f64::max(inner_join_selectivity, 1.0 / right_row_cnt), + JoinType::RightOuter => f64::max(inner_join_selectivity, 1.0 / left_row_cnt), + JoinType::Cross => { + assert!( + on_col_ref_pairs.is_empty(), + "Cross joins should not have on columns" + ); + join_filter_selectivity + } + _ => unimplemented!("join_typ={} is not implemented", join_typ), + } + } + + /// The expr_tree input must be a "mixed expression tree", just like with + /// `get_filter_selectivity`. + /// + /// This is a "wrapper" to separate the equality conditions from the filter conditions before + /// calling the "main" `get_join_selectivity_core` function. + /// + /// The right column ref offset passed to `get_join_selectivity_core` from here is always 0.
+ #[allow(clippy::too_many_arguments)] + fn get_join_selectivity_from_expr_tree( + &self, + join_typ: JoinType, + expr_tree: ArcDfPredNode, + schema: &Schema, + column_refs: &BaseTableColumnRefs, + input_correlation: Option, + left_row_cnt: f64, + right_row_cnt: f64, + ) -> f64 { + if expr_tree.typ == DfPredType::LogOp(LogOpType::And) { + // top-level AND: partition the children into on-conditions and residual filters + let mut on_col_ref_pairs = vec![]; + let mut filter_expr_trees = vec![]; + for child_expr_tree in &expr_tree.children { + if let Some(on_col_ref_pair) = + Self::get_on_col_ref_pair(child_expr_tree.clone(), column_refs) + { + on_col_ref_pairs.push(on_col_ref_pair) + } else { + let child_expr = child_expr_tree.clone(); + filter_expr_trees.push(child_expr); + } + } + assert!(on_col_ref_pairs.len() + filter_expr_trees.len() == expr_tree.children.len()); + let filter_expr_tree = if filter_expr_trees.is_empty() { + None + } else { + Some(LogOpPred::new(LogOpType::And, filter_expr_trees).into_pred_node()) + }; + self.get_join_selectivity_core( + join_typ, + on_col_ref_pairs, + filter_expr_tree, + schema, + column_refs, + input_correlation, + left_row_cnt, + right_row_cnt, + 0, + ) + } else { + // not an AND: the whole tree is either a single on-condition or a single filter + #[allow(clippy::collapsible_else_if)] + if let Some(on_col_ref_pair) = Self::get_on_col_ref_pair(expr_tree.clone(), column_refs) + { + self.get_join_selectivity_core( + join_typ, + vec![on_col_ref_pair], + None, + schema, + column_refs, + input_correlation, + left_row_cnt, + right_row_cnt, + 0, + ) + } else { + self.get_join_selectivity_core( + join_typ, + vec![], + Some(expr_tree), + schema, + column_refs, + input_correlation, + left_row_cnt, + right_row_cnt, + 0, + ) + } + } + } + + /// Check if an expr_tree is a join condition, returning the join on col ref pair if it is. + /// The reason the check and the info are in the same function is because their code is almost + /// identical.
It only picks out equality conditions between two column refs on different + /// tables. + fn get_on_col_ref_pair( + expr_tree: ArcDfPredNode, + column_refs: &BaseTableColumnRefs, + ) -> Option<(ColumnRefPred, ColumnRefPred)> { + // 1. Check that it's equality + if expr_tree.typ == DfPredType::BinOp(BinOpType::Eq) { + let left_child = expr_tree.child(0); + let right_child = expr_tree.child(1); + // 2. Check that both sides are column refs + if left_child.typ == DfPredType::ColumnRef && right_child.typ == DfPredType::ColumnRef { + // 3. Check that both sides don't belong to the same table (if either side is not a + // base table column ref, they are treated as not belonging to the same table) + let left_col_ref_expr = ColumnRefPred::from_pred_node(left_child) + .expect("we already checked that the type is ColumnRef"); + let right_col_ref_expr = ColumnRefPred::from_pred_node(right_child) + .expect("we already checked that the type is ColumnRef"); + let left_col_ref = &column_refs[left_col_ref_expr.index()]; + let right_col_ref = &column_refs[right_col_ref_expr.index()]; + let is_same_table = if let ( + ColumnRef::BaseTableColumnRef(BaseTableColumnRef { + table: left_table, .. + }), + ColumnRef::BaseTableColumnRef(BaseTableColumnRef { + table: right_table, .. + }), + ) = (left_col_ref, right_col_ref) + { + left_table == right_table + } else { + false + }; + if !is_same_table { + Some((left_col_ref_expr, right_col_ref_expr)) + } else { + None + } + } else { + None + } + } else { + None + } + } + + /// Get the selectivity of one column eq predicate, e.g. colA = colB.
+ fn get_join_selectivity_from_on_col_ref_pair( + &self, + left: &ColumnRef, + right: &ColumnRef, + ) -> f64 { + // the formula for each pair is min(1 / ndistinct1, 1 / ndistinct2) + // (see https://postgrespro.com/blog/pgsql/5969618) + // columns without stats fall back to DEFAULT_NUM_DISTINCT + let ndistincts = vec![left, right].into_iter().map(|col_ref| { + match self.get_single_column_stats_from_col_ref(col_ref) { + Some(per_col_stats) => per_col_stats.ndistinct, + None => DEFAULT_NUM_DISTINCT, + } + }); + // using reduce(f64::min) is the idiomatic workaround to min() because + // f64 does not implement Ord due to NaN + let selectivity = ndistincts.map(|ndistinct| 1.0 / ndistinct as f64).reduce(f64::min).expect("reduce() only returns None if the iterator is empty, which is impossible since col_ref_exprs.len() == 2"); + assert!( + !selectivity.is_nan(), + "it should be impossible for selectivity to be NaN since n-distinct is never 0" + ); + selectivity + } + + /// Given a set of N columns involved in a multi-equality, find the total selectivity + /// of the multi-equality. + /// + /// This is a generalization of get_join_selectivity_from_on_col_ref_pair(). + /// + /// Asserts that the set contains more than one column. + fn get_join_selectivity_from_most_selective_columns( + &self, + base_col_refs: HashSet, + ) -> f64 { + assert!(base_col_refs.len() > 1); + let num_base_col_refs = base_col_refs.len(); + base_col_refs + .into_iter() + .map(|base_col_ref| { + match self.get_column_comb_stats(&base_col_ref.table, &[base_col_ref.col_idx]) { + Some(per_col_stats) => per_col_stats.ndistinct, + None => DEFAULT_NUM_DISTINCT, + } + }) + .map(|ndistinct| 1.0 / ndistinct as f64) + // ascending order of 1 / ndistinct, so the take(N - 1) below keeps the N - 1 columns + // with the highest ndistinct + .sorted_by(|a, b| { + a.partial_cmp(b) + .expect("No floats should be NaN since n-distinct is never 0") + }) + .take(num_base_col_refs - 1) + .product() + } + + /// A predicate set defines a "multi-equality graph", which is an unweighted undirected graph. + /// The nodes are columns while edges are predicates. The old graph is defined by + /// `past_eq_columns` while the `predicate` is the new addition to this graph.
This + /// unweighted undirected graph consists of a number of connected components, where each + /// connected component represents columns that are set to be equal to each other. Single + /// nodes not connected to anything are considered standalone connected components. + /// + /// The selectivity of each connected component of N nodes is equal to the product of + /// 1/ndistinct of the N-1 nodes with the highest ndistinct values. You can see this if you + /// imagine that all columns being joined are unique columns and that they follow the + /// inclusion principle (every element of the smaller tables is present in the larger + /// tables). When these assumptions are not true, the selectivity may not be completely + /// accurate. However, it is still fairly accurate. + /// + /// However, we cannot simply add `predicate` to the multi-equality graph and compute the + /// selectivity of the entire connected component, because this would be "double counting" a + /// lot of nodes. The join(s) before this join would already have a selectivity value. Thus, + /// we compute the selectivity of the join(s) before this join (the first block of the + /// function) and then the selectivity of the connected component after this join. The + /// quotient is the "adjustment" factor. + /// + /// NOTE: This function modifies `past_eq_columns` by adding `predicate` to it. + fn get_join_selectivity_adjustment_when_adding_to_multi_equality_graph( + &self, + predicate: &EqPredicate, + past_eq_columns: &mut EqBaseTableColumnSets, + ) -> f64 { + if predicate.left == predicate.right { + // self-join, TODO: is this correct? + return 1.0; + } + // To find the adjustment, we need to know the selectivity of the graph before `predicate` + // is added. + // + // There are two cases: (1) adding `predicate` does not change the # of connected + // components, and (2) adding `predicate` reduces the # of connected components by 1.
Note that + // columns not involved in any predicates are considered a part of the graph and are + // a connected component on their own. + let children_pred_sel = { + if past_eq_columns.is_eq(&predicate.left, &predicate.right) { + // case (1): both columns are already in the same connected component + self.get_join_selectivity_from_most_selective_columns( + past_eq_columns.find_cols_for_eq_column_set(&predicate.left), + ) + } else { + // case (2): two separate components; a column that is not in `past_eq_columns` + // is a standalone component and contributes a selectivity of 1.0 + let left_sel = if past_eq_columns.contains(&predicate.left) { + self.get_join_selectivity_from_most_selective_columns( + past_eq_columns.find_cols_for_eq_column_set(&predicate.left), + ) + } else { + 1.0 + }; + let right_sel = if past_eq_columns.contains(&predicate.right) { + self.get_join_selectivity_from_most_selective_columns( + past_eq_columns.find_cols_for_eq_column_set(&predicate.right), + ) + } else { + 1.0 + }; + left_sel * right_sel + } + }; + + // Add predicate to past_eq_columns and compute the selectivity of the connected component + // it creates. + past_eq_columns.add_predicate(predicate.clone()); + let new_pred_sel = { + let cols = past_eq_columns.find_cols_for_eq_column_set(&predicate.left); + self.get_join_selectivity_from_most_selective_columns(cols) + }; + + // Compute the adjustment factor. + new_pred_sel / children_pred_sel + } + + /// Get the selectivity of the on conditions. + /// + /// Note that the selectivity of the on conditions does not depend on join type. + /// Join type is accounted for separately in get_join_selectivity_core(). + /// + /// We also check if each predicate is correlated with any of the previous predicates. + /// + /// More specifically, we are checking if the predicate can be expressed with other existing + /// predicates. E.g. the predicates A = B and B = C together are equivalent to A = C. + /// + /// However, we don't just throw away A = C, because we want to pick the most selective + /// predicates. For details on how we do this, see + /// `get_join_selectivity_adjustment_when_adding_to_multi_equality_graph`.
+ fn get_join_on_selectivity( + &self, + on_col_ref_pairs: &[(ColumnRefPred, ColumnRefPred)], + column_refs: &BaseTableColumnRefs, + input_correlation: Option, + right_col_ref_offset: usize, + ) -> f64 { + // start from the equality sets of the inputs (the default value if there is no input + // correlation); the unwrap() panics if the conversion fails + let mut past_eq_columns = input_correlation + .map(|c| EqBaseTableColumnSets::try_from(c).unwrap()) + .unwrap_or_default(); + + // multiply the selectivities of all individual conditions together + on_col_ref_pairs + .iter() + .map(|on_col_ref_pair| { + let left_col_ref = &column_refs[on_col_ref_pair.0.index()]; + let right_col_ref = &column_refs[on_col_ref_pair.1.index() + right_col_ref_offset]; + + if let (ColumnRef::BaseTableColumnRef(left), ColumnRef::BaseTableColumnRef(right)) = + (left_col_ref, right_col_ref) + { + let predicate = EqPredicate::new(left.clone(), right.clone()); + return self + .get_join_selectivity_adjustment_when_adding_to_multi_equality_graph( + &predicate, + &mut past_eq_columns, + ); + } + + // at least one side is not a base table column ref: use the pairwise estimate + self.get_join_selectivity_from_on_col_ref_pair(left_col_ref, right_col_ref) + }) + .product() + } +} + +#[cfg(test)] +mod tests { + use std::collections::HashSet; + + use optd_core::nodes::Value; + use optd_datafusion_repr::plan_nodes::{ArcDfPredNode, BinOpType, JoinType, LogOpType}; + use optd_datafusion_repr::properties::column_ref::{ + BaseTableColumnRef, BaseTableColumnRefs, ColumnRef, EqBaseTableColumnSets, EqPredicate, + SemanticCorrelation, + }; + use optd_datafusion_repr::properties::schema::Schema; + + use crate::adv_stats::tests::*; + use crate::adv_stats::DEFAULT_EQ_SEL; + + /// A wrapper around get_join_selectivity_from_expr_tree that extracts the + /// table row counts from the cost model.
+ fn test_get_join_selectivity( + cost_model: &TestOptCostModel, + reverse_tables: bool, + join_typ: JoinType, + expr_tree: ArcDfPredNode, + schema: &Schema, + column_refs: &BaseTableColumnRefs, + input_correlation: Option, + ) -> f64 { + let table1_row_cnt = cost_model.per_table_stats_map[TABLE1_NAME].row_cnt as f64; + let table2_row_cnt = cost_model.per_table_stats_map[TABLE2_NAME].row_cnt as f64; + if !reverse_tables { + cost_model.get_join_selectivity_from_expr_tree( + join_typ, + expr_tree, + schema, + column_refs, + input_correlation, + table1_row_cnt, + table2_row_cnt, + ) + } else { + cost_model.get_join_selectivity_from_expr_tree( + join_typ, + expr_tree, + schema, + column_refs, + input_correlation, + table2_row_cnt, + table1_row_cnt, + ) + } + } + + #[test] + fn test_inner_const() { + let cost_model = create_one_column_cost_model(get_empty_per_col_stats()); + assert_approx_eq::assert_approx_eq!( + cost_model.get_join_selectivity_from_expr_tree( + JoinType::Inner, + cnst(Value::Bool(true)), + &Schema::new(vec![]), + &vec![], + None, + f64::NAN, + f64::NAN + ), + 1.0 + ); + assert_approx_eq::assert_approx_eq!( + cost_model.get_join_selectivity_from_expr_tree( + JoinType::Inner, + cnst(Value::Bool(false)), + &Schema::new(vec![]), + &vec![], + None, + f64::NAN, + f64::NAN + ), + 0.0 + ); + } + + #[test] + fn test_inner_oncond() { + let cost_model = create_two_table_cost_model( + TestPerColumnStats::new( + TestMostCommonValues::empty(), + 5, + 0.0, + Some(TestDistribution::empty()), + ), + TestPerColumnStats::new( + TestMostCommonValues::empty(), + 4, + 0.0, + Some(TestDistribution::empty()), + ), + ); + let expr_tree = bin_op(BinOpType::Eq, col_ref(0), col_ref(1)); + let expr_tree_rev = bin_op(BinOpType::Eq, col_ref(1), col_ref(0)); + let schema = Schema::new(vec![]); + let column_refs = vec![ + ColumnRef::base_table_column_ref(String::from(TABLE1_NAME), 0), + ColumnRef::base_table_column_ref(String::from(TABLE2_NAME), 0), + ]; + 
assert_approx_eq::assert_approx_eq!( + test_get_join_selectivity( + &cost_model, + false, + JoinType::Inner, + expr_tree, + &schema, + &column_refs, + None, + ), + 0.2 + ); + assert_approx_eq::assert_approx_eq!( + test_get_join_selectivity( + &cost_model, + false, + JoinType::Inner, + expr_tree_rev, + &schema, + &column_refs, + None, + ), + 0.2 + ); + } + + #[test] + fn test_inner_and_of_onconds() { + let cost_model = create_two_table_cost_model( + TestPerColumnStats::new( + TestMostCommonValues::empty(), + 5, + 0.0, + Some(TestDistribution::empty()), + ), + TestPerColumnStats::new( + TestMostCommonValues::empty(), + 4, + 0.0, + Some(TestDistribution::empty()), + ), + ); + let eq0and1 = bin_op(BinOpType::Eq, col_ref(0), col_ref(1)); + let eq1and0 = bin_op(BinOpType::Eq, col_ref(1), col_ref(0)); + let expr_tree = log_op(LogOpType::And, vec![eq0and1.clone(), eq1and0.clone()]); + let expr_tree_rev = log_op(LogOpType::And, vec![eq1and0.clone(), eq0and1.clone()]); + let schema = Schema::new(vec![]); + let column_refs = vec![ + ColumnRef::base_table_column_ref(String::from(TABLE1_NAME), 0), + ColumnRef::base_table_column_ref(String::from(TABLE2_NAME), 0), + ]; + assert_approx_eq::assert_approx_eq!( + test_get_join_selectivity( + &cost_model, + false, + JoinType::Inner, + expr_tree, + &schema, + &column_refs, + None, + ), + 0.2 + ); + assert_approx_eq::assert_approx_eq!( + test_get_join_selectivity( + &cost_model, + false, + JoinType::Inner, + expr_tree_rev, + &schema, + &column_refs, + None + ), + 0.2 + ); + } + + #[test] + fn test_inner_and_of_oncond_and_filter() { + let cost_model = create_two_table_cost_model( + TestPerColumnStats::new( + TestMostCommonValues::empty(), + 5, + 0.0, + Some(TestDistribution::empty()), + ), + TestPerColumnStats::new( + TestMostCommonValues::empty(), + 4, + 0.0, + Some(TestDistribution::empty()), + ), + ); + let eq0and1 = bin_op(BinOpType::Eq, col_ref(0), col_ref(1)); + let eq100 = bin_op(BinOpType::Eq, col_ref(1), 
cnst(Value::Int32(100))); + let expr_tree = log_op(LogOpType::And, vec![eq0and1.clone(), eq100.clone()]); + let expr_tree_rev = log_op(LogOpType::And, vec![eq100.clone(), eq0and1.clone()]); + let schema = Schema::new(vec![]); + let column_refs = vec![ + ColumnRef::base_table_column_ref(String::from(TABLE1_NAME), 0), + ColumnRef::base_table_column_ref(String::from(TABLE2_NAME), 0), + ]; + assert_approx_eq::assert_approx_eq!( + test_get_join_selectivity( + &cost_model, + false, + JoinType::Inner, + expr_tree, + &schema, + &column_refs, + None + ), + 0.05 + ); + assert_approx_eq::assert_approx_eq!( + test_get_join_selectivity( + &cost_model, + false, + JoinType::Inner, + expr_tree_rev, + &schema, + &column_refs, + None + ), + 0.05 + ); + } + + #[test] + fn test_inner_and_of_filters() { + let cost_model = create_two_table_cost_model( + TestPerColumnStats::new( + TestMostCommonValues::empty(), + 5, + 0.0, + Some(TestDistribution::empty()), + ), + TestPerColumnStats::new( + TestMostCommonValues::empty(), + 4, + 0.0, + Some(TestDistribution::empty()), + ), + ); + let neq12 = bin_op(BinOpType::Neq, col_ref(0), cnst(Value::Int32(12))); + let eq100 = bin_op(BinOpType::Eq, col_ref(1), cnst(Value::Int32(100))); + let expr_tree = log_op(LogOpType::And, vec![neq12.clone(), eq100.clone()]); + let expr_tree_rev = log_op(LogOpType::And, vec![eq100.clone(), neq12.clone()]); + let schema = Schema::new(vec![]); + let column_refs = vec![ + ColumnRef::base_table_column_ref(String::from(TABLE1_NAME), 0), + ColumnRef::base_table_column_ref(String::from(TABLE2_NAME), 0), + ]; + assert_approx_eq::assert_approx_eq!( + test_get_join_selectivity( + &cost_model, + false, + JoinType::Inner, + expr_tree, + &schema, + &column_refs, + None, + ), + 0.2 + ); + assert_approx_eq::assert_approx_eq!( + test_get_join_selectivity( + &cost_model, + false, + JoinType::Inner, + expr_tree_rev, + &schema, + &column_refs, + None + ), + 0.2 + ); + } + + #[test] + fn 
test_inner_colref_eq_colref_same_table_is_not_oncond() { + let cost_model = create_two_table_cost_model( + TestPerColumnStats::new( + TestMostCommonValues::empty(), + 5, + 0.0, + Some(TestDistribution::empty()), + ), + TestPerColumnStats::new( + TestMostCommonValues::empty(), + 4, + 0.0, + Some(TestDistribution::empty()), + ), + ); + let expr_tree = bin_op(BinOpType::Eq, col_ref(0), col_ref(0)); + let schema = Schema::new(vec![]); + let column_refs = vec![ + ColumnRef::base_table_column_ref(String::from(TABLE1_NAME), 0), + ColumnRef::base_table_column_ref(String::from(TABLE2_NAME), 0), + ]; + assert_approx_eq::assert_approx_eq!( + test_get_join_selectivity( + &cost_model, + false, + JoinType::Inner, + expr_tree, + &schema, + &column_refs, + None + ), + DEFAULT_EQ_SEL + ); + } + + // We don't test joinsel or with oncond because if there is an oncond (on condition), the + // top-level operator must be an AND + + /// I made this helper function to avoid copying all eight lines over and over + fn assert_outer_selectivities( + cost_model: &TestOptCostModel, + expr_tree: ArcDfPredNode, + expr_tree_rev: ArcDfPredNode, + schema: &Schema, + column_refs: &BaseTableColumnRefs, + expected_table1_outer_sel: f64, + expected_table2_outer_sel: f64, + ) { + // all table 1 outer combinations + assert_approx_eq::assert_approx_eq!( + test_get_join_selectivity( + cost_model, + false, + JoinType::LeftOuter, + expr_tree.clone(), + schema, + column_refs, + None + ), + expected_table1_outer_sel + ); + assert_approx_eq::assert_approx_eq!( + test_get_join_selectivity( + cost_model, + false, + JoinType::LeftOuter, + expr_tree_rev.clone(), + schema, + column_refs, + None + ), + expected_table1_outer_sel + ); + assert_approx_eq::assert_approx_eq!( + test_get_join_selectivity( + cost_model, + true, + JoinType::RightOuter, + expr_tree.clone(), + schema, + column_refs, + None + ), + expected_table1_outer_sel + ); + assert_approx_eq::assert_approx_eq!( + test_get_join_selectivity( + cost_model, + 
true, + JoinType::RightOuter, + expr_tree_rev.clone(), + schema, + column_refs, + None + ), + expected_table1_outer_sel + ); + // all table 2 outer combinations + assert_approx_eq::assert_approx_eq!( + test_get_join_selectivity( + cost_model, + true, + JoinType::LeftOuter, + expr_tree.clone(), + schema, + column_refs, + None + ), + expected_table2_outer_sel + ); + assert_approx_eq::assert_approx_eq!( + test_get_join_selectivity( + cost_model, + true, + JoinType::LeftOuter, + expr_tree_rev.clone(), + schema, + column_refs, + None + ), + expected_table2_outer_sel + ); + assert_approx_eq::assert_approx_eq!( + test_get_join_selectivity( + cost_model, + false, + JoinType::RightOuter, + expr_tree.clone(), + schema, + column_refs, + None + ), + expected_table2_outer_sel + ); + assert_approx_eq::assert_approx_eq!( + test_get_join_selectivity( + cost_model, + false, + JoinType::RightOuter, + expr_tree_rev.clone(), + schema, + column_refs, + None + ), + expected_table2_outer_sel + ); + } + + /// Unique oncond means an oncondition on columns which are unique in both tables + /// There's only one case if both columns are unique and have different row counts: the inner + /// will be < 1 / row count of one table and = 1 / row count of another + #[test] + fn test_outer_unique_oncond() { + let cost_model = create_two_table_cost_model_custom_row_cnts( + TestPerColumnStats::new( + TestMostCommonValues::empty(), + 5, + 0.0, + Some(TestDistribution::empty()), + ), + TestPerColumnStats::new( + TestMostCommonValues::empty(), + 4, + 0.0, + Some(TestDistribution::empty()), + ), + 5, + 4, + ); + // the left/right of the join refers to the tables, not the order of columns in the + // predicate + let expr_tree = bin_op(BinOpType::Eq, col_ref(0), col_ref(1)); + let expr_tree_rev = bin_op(BinOpType::Eq, col_ref(1), col_ref(0)); + let schema = Schema::new(vec![]); + let column_refs = vec![ + ColumnRef::base_table_column_ref(String::from(TABLE1_NAME), 0), + 
ColumnRef::base_table_column_ref(String::from(TABLE2_NAME), 0), + ]; + // sanity check the expected inner sel + let expected_inner_sel = 0.2; + assert_approx_eq::assert_approx_eq!( + test_get_join_selectivity( + &cost_model, + false, + JoinType::Inner, + expr_tree.clone(), + &schema, + &column_refs, + None + ), + expected_inner_sel + ); + assert_approx_eq::assert_approx_eq!( + test_get_join_selectivity( + &cost_model, + false, + JoinType::Inner, + expr_tree_rev.clone(), + &schema, + &column_refs, + None + ), + expected_inner_sel + ); + // check the outer sels + assert_outer_selectivities( + &cost_model, + expr_tree, + expr_tree_rev, + &schema, + &column_refs, + 0.25, + 0.2, + ); + } + + /// Non-unique oncond means the column is not unique in either table + /// Inner always >= row count means that the inner join result is >= 1 / the row count of both + /// tables + #[test] + fn test_outer_nonunique_oncond_inner_always_geq_rowcnt() { + let cost_model = create_two_table_cost_model_custom_row_cnts( + TestPerColumnStats::new( + TestMostCommonValues::empty(), + 5, + 0.0, + Some(TestDistribution::empty()), + ), + TestPerColumnStats::new( + TestMostCommonValues::empty(), + 4, + 0.0, + Some(TestDistribution::empty()), + ), + 10, + 8, + ); + // the left/right of the join refers to the tables, not the order of columns in the + // predicate + let expr_tree = bin_op(BinOpType::Eq, col_ref(0), col_ref(1)); + let expr_tree_rev = bin_op(BinOpType::Eq, col_ref(1), col_ref(0)); + let schema = Schema::new(vec![]); + let column_refs = vec![ + ColumnRef::base_table_column_ref(String::from(TABLE1_NAME), 0), + ColumnRef::base_table_column_ref(String::from(TABLE2_NAME), 0), + ]; + // sanity check the expected inner sel + let expected_inner_sel = 0.2; + assert_approx_eq::assert_approx_eq!( + test_get_join_selectivity( + &cost_model, + false, + JoinType::Inner, + expr_tree.clone(), + &schema, + &column_refs, + None + ), + expected_inner_sel + ); + assert_approx_eq::assert_approx_eq!( + 
test_get_join_selectivity( + &cost_model, + false, + JoinType::Inner, + expr_tree_rev.clone(), + &schema, + &column_refs, + None + ), + expected_inner_sel + ); + // check the outer sels + assert_outer_selectivities( + &cost_model, + expr_tree, + expr_tree_rev, + &schema, + &column_refs, + 0.2, + 0.2, + ); + } + + /// Non-unique oncond means the column is not unique in either table + /// Inner sometimes < row count means that the inner join result < 1 / the row count of exactly + /// one table. Note that without a join filter, it's impossible to be less than the row + /// count of both tables + #[test] + fn test_outer_nonunique_oncond_inner_sometimes_lt_rowcnt() { + let cost_model = create_two_table_cost_model_custom_row_cnts( + TestPerColumnStats::new( + TestMostCommonValues::empty(), + 10, + 0.0, + Some(TestDistribution::empty()), + ), + TestPerColumnStats::new( + TestMostCommonValues::empty(), + 2, + 0.0, + Some(TestDistribution::empty()), + ), + 20, + 4, + ); + // the left/right of the join refers to the tables, not the order of columns in the + // predicate + let expr_tree = bin_op(BinOpType::Eq, col_ref(0), col_ref(1)); + let expr_tree_rev = bin_op(BinOpType::Eq, col_ref(1), col_ref(0)); + let schema = Schema::new(vec![]); + let column_refs = vec![ + ColumnRef::base_table_column_ref(String::from(TABLE1_NAME), 0), + ColumnRef::base_table_column_ref(String::from(TABLE2_NAME), 0), + ]; + // sanity check the expected inner sel + let expected_inner_sel = 0.1; + assert_approx_eq::assert_approx_eq!( + test_get_join_selectivity( + &cost_model, + false, + JoinType::Inner, + expr_tree.clone(), + &schema, + &column_refs, + None + ), + expected_inner_sel + ); + assert_approx_eq::assert_approx_eq!( + test_get_join_selectivity( + &cost_model, + false, + JoinType::Inner, + expr_tree_rev.clone(), + &schema, + &column_refs, + None + ), + expected_inner_sel + ); + // check the outer sels + assert_outer_selectivities( + &cost_model, + expr_tree, + expr_tree_rev, + &schema, + 
&column_refs, + 0.25, + 0.1, + ); + } + + /// Unique oncond means an oncondition on columns which are unique in both tables + /// Filter means we're adding a join filter + /// There's only one case if both columns are unique and there's a filter: + /// the inner will be < 1 / row count of both tables + #[test] + fn test_outer_unique_oncond_filter() { + let cost_model = create_two_table_cost_model_custom_row_cnts( + TestPerColumnStats::new( + TestMostCommonValues::empty(), + 50, + 0.0, + Some(TestDistribution::new(vec![(Value::Int32(128), 0.4)])), + ), + TestPerColumnStats::new( + TestMostCommonValues::empty(), + 4, + 0.0, + Some(TestDistribution::empty()), + ), + 50, + 4, + ); + // the left/right of the join refers to the tables, not the order of columns in the + // predicate + let eq0and1 = bin_op(BinOpType::Eq, col_ref(0), col_ref(1)); + let eq1and0 = bin_op(BinOpType::Eq, col_ref(1), col_ref(0)); + let filter = bin_op(BinOpType::Leq, col_ref(0), cnst(Value::Int32(128))); + let expr_tree = log_op(LogOpType::And, vec![eq0and1, filter.clone()]); + // inner rev means its the inner expr (the eq op) whose children are being reversed, as + // opposed to the and op + let expr_tree_inner_rev = log_op(LogOpType::And, vec![eq1and0, filter.clone()]); + let schema = Schema::new(vec![]); + let column_refs = vec![ + ColumnRef::base_table_column_ref(String::from(TABLE1_NAME), 0), + ColumnRef::base_table_column_ref(String::from(TABLE2_NAME), 0), + ]; + // sanity check the expected inner sel + let expected_inner_sel = 0.008; + assert_approx_eq::assert_approx_eq!( + test_get_join_selectivity( + &cost_model, + false, + JoinType::Inner, + expr_tree.clone(), + &schema, + &column_refs, + None + ), + expected_inner_sel + ); + assert_approx_eq::assert_approx_eq!( + test_get_join_selectivity( + &cost_model, + false, + JoinType::Inner, + expr_tree_inner_rev.clone(), + &schema, + &column_refs, + None + ), + expected_inner_sel + ); + // check the outer sels + assert_outer_selectivities( + 
&cost_model, + expr_tree, + expr_tree_inner_rev, + &schema, + &column_refs, + 0.25, + 0.02, + ); + } + + /// Test all possible permutations of three-table joins. + /// A three-table join consists of at least two joins. `join1_on_cond` is the condition of the + /// first join. There can only be one condition because only two tables are involved at + /// the time of the first join. + #[test_case::test_case(&[(0, 1)])] + #[test_case::test_case(&[(0, 2)])] + #[test_case::test_case(&[(1, 2)])] + #[test_case::test_case(&[(0, 1), (0, 2)])] + #[test_case::test_case(&[(0, 1), (1, 2)])] + #[test_case::test_case(&[(0, 2), (1, 2)])] + #[test_case::test_case(&[(0, 1), (0, 2), (1, 2)])] + fn test_three_table_join_for_initial_join_on_conds(initial_join_on_conds: &[(usize, usize)]) { + assert!( + !initial_join_on_conds.is_empty(), + "initial_join_on_conds should be non-empty" + ); + assert_eq!( + initial_join_on_conds.len(), + initial_join_on_conds.iter().collect::>().len(), + "initial_join_on_conds shouldn't contain duplicates" + ); + let cost_model = create_three_table_cost_model( + TestPerColumnStats::new( + TestMostCommonValues::empty(), + 2, + 0.0, + Some(TestDistribution::empty()), + ), + TestPerColumnStats::new( + TestMostCommonValues::empty(), + 3, + 0.0, + Some(TestDistribution::empty()), + ), + TestPerColumnStats::new( + TestMostCommonValues::empty(), + 4, + 0.0, + Some(TestDistribution::empty()), + ), + ); + let col_base_refs = vec![ + BaseTableColumnRef { + table: String::from(TABLE1_NAME), + col_idx: 0, + }, + BaseTableColumnRef { + table: String::from(TABLE2_NAME), + col_idx: 0, + }, + BaseTableColumnRef { + table: String::from(TABLE3_NAME), + col_idx: 0, + }, + ]; + let col_refs: BaseTableColumnRefs = col_base_refs + .clone() + .into_iter() + .map(|col_base_ref| col_base_ref.into()) + .collect(); + + let mut eq_columns = EqBaseTableColumnSets::new(); + for initial_join_on_cond in initial_join_on_conds { + eq_columns.add_predicate(EqPredicate::new( + 
col_base_refs[initial_join_on_cond.0].clone(), + col_base_refs[initial_join_on_cond.1].clone(), + )); + } + let initial_selectivity = { + if initial_join_on_conds.len() == 1 { + let initial_join_on_cond = initial_join_on_conds.first().unwrap(); + if initial_join_on_cond == &(0, 1) { + 1.0 / 3.0 + } else if initial_join_on_cond == &(0, 2) || initial_join_on_cond == &(1, 2) { + 1.0 / 4.0 + } else { + panic!(); + } + } else { + 1.0 / 12.0 + } + }; + let semantic_correlation = SemanticCorrelation::new(eq_columns); + let schema = Schema::new(vec![]); + let column_refs = col_refs; + let input_correlation = Some(semantic_correlation); + + // Try all join conditions of the final join which would lead to all three tables being + // joined. + let eq0and1 = bin_op(BinOpType::Eq, col_ref(0), col_ref(1)); + let eq0and2 = bin_op(BinOpType::Eq, col_ref(0), col_ref(2)); + let eq1and2 = bin_op(BinOpType::Eq, col_ref(1), col_ref(2)); + let and_01_02 = log_op(LogOpType::And, vec![eq0and1.clone(), eq0and2.clone()]); + let and_01_12 = log_op(LogOpType::And, vec![eq0and1.clone(), eq1and2.clone()]); + let and_02_12 = log_op(LogOpType::And, vec![eq0and2.clone(), eq1and2.clone()]); + let and_01_02_12 = log_op( + LogOpType::And, + vec![eq0and1.clone(), eq0and2.clone(), eq1and2.clone()], + ); + let mut join2_expr_trees = vec![and_01_02, and_01_12, and_02_12, and_01_02_12]; + if initial_join_on_conds.len() == 1 { + let initial_join_on_cond = initial_join_on_conds.first().unwrap(); + if initial_join_on_cond == &(0, 1) { + join2_expr_trees.push(eq0and2); + join2_expr_trees.push(eq1and2); + } else if initial_join_on_cond == &(0, 2) { + join2_expr_trees.push(eq0and1); + join2_expr_trees.push(eq1and2); + } else if initial_join_on_cond == &(1, 2) { + join2_expr_trees.push(eq0and1); + join2_expr_trees.push(eq0and2); + } else { + panic!(); + } + } + for expr_tree in join2_expr_trees { + let overall_selectivity = initial_selectivity + * test_get_join_selectivity( + &cost_model, + false, + 
JoinType::Inner, + expr_tree.clone(), + &schema, + &column_refs, + input_correlation.clone(), + ); + assert_approx_eq::assert_approx_eq!(overall_selectivity, 1.0 / 12.0); + } + } + + #[test] + fn test_join_which_connects_two_components_together() { + let cost_model = create_four_table_cost_model( + TestPerColumnStats::new( + TestMostCommonValues::empty(), + 2, + 0.0, + Some(TestDistribution::empty()), + ), + TestPerColumnStats::new( + TestMostCommonValues::empty(), + 3, + 0.0, + Some(TestDistribution::empty()), + ), + TestPerColumnStats::new( + TestMostCommonValues::empty(), + 4, + 0.0, + Some(TestDistribution::empty()), + ), + TestPerColumnStats::new( + TestMostCommonValues::empty(), + 5, + 0.0, + Some(TestDistribution::empty()), + ), + ); + let col_base_refs = vec![ + BaseTableColumnRef { + table: String::from(TABLE1_NAME), + col_idx: 0, + }, + BaseTableColumnRef { + table: String::from(TABLE2_NAME), + col_idx: 0, + }, + BaseTableColumnRef { + table: String::from(TABLE3_NAME), + col_idx: 0, + }, + BaseTableColumnRef { + table: String::from(TABLE4_NAME), + col_idx: 0, + }, + ]; + let col_refs: BaseTableColumnRefs = col_base_refs + .clone() + .into_iter() + .map(|col_base_ref| col_base_ref.into()) + .collect(); + + let mut eq_columns = EqBaseTableColumnSets::new(); + eq_columns.add_predicate(EqPredicate::new( + col_base_refs[0].clone(), + col_base_refs[1].clone(), + )); + eq_columns.add_predicate(EqPredicate::new( + col_base_refs[2].clone(), + col_base_refs[3].clone(), + )); + let initial_selectivity = 1.0 / (3.0 * 5.0); + let semantic_correlation = SemanticCorrelation::new(eq_columns); + let schema = Schema::new(vec![]); + let column_refs = col_refs; + let input_correlation = Some(semantic_correlation); + + let eq1and2 = bin_op(BinOpType::Eq, col_ref(1), col_ref(2)); + let overall_selectivity = initial_selectivity + * test_get_join_selectivity( + &cost_model, + false, + JoinType::Inner, + eq1and2.clone(), + &schema, + &column_refs, + input_correlation, + ); + 
assert_approx_eq::assert_approx_eq!(overall_selectivity, 1.0 / (3.0 * 4.0 * 5.0)); + } +} From 6d508434284a84154482d8055636610ac192b3b3 Mon Sep 17 00:00:00 2001 From: Yuanxin Cao Date: Mon, 18 Nov 2024 11:57:57 -0500 Subject: [PATCH 35/51] make join compile --- optd-cost-model/src/common/nodes.rs | 11 +- .../src/common/predicates/attr_ref_pred.rs | 14 +- .../src/common/properties/attr_ref.rs | 2 +- optd-cost-model/src/cost/agg.rs | 6 +- optd-cost-model/src/cost/filter/comp_op.rs | 23 +- optd-cost-model/src/cost/filter/controller.rs | 2 +- optd-cost-model/src/cost/filter/in_list.rs | 6 +- optd-cost-model/src/cost/filter/like.rs | 6 +- optd-cost-model/src/cost/join.rs | 1248 +++-------------- optd-cost-model/src/cost_model.rs | 4 +- 10 files changed, 218 insertions(+), 1104 deletions(-) diff --git a/optd-cost-model/src/common/nodes.rs b/optd-cost-model/src/common/nodes.rs index 8bfdabb..8ad98e4 100644 --- a/optd-cost-model/src/common/nodes.rs +++ b/optd-cost-model/src/common/nodes.rs @@ -1,4 +1,5 @@ -use std::sync::Arc; +use core::fmt; +use std::{fmt::Display, sync::Arc}; use arrow_schema::DataType; @@ -24,6 +25,12 @@ pub enum JoinType { RightAnti, } +impl Display for JoinType { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{:?}", self) + } +} + /// TODO: documentation #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub enum PhysicalNodeType { @@ -49,7 +56,7 @@ impl std::fmt::Display for PhysicalNodeType { pub enum PredicateType { List, Constant(ConstantType), - AttributeRef, + AttrRef, ExternAttributeRef, // TODO(lanlou): Id -> Id(IdType) Id, diff --git a/optd-cost-model/src/common/predicates/attr_ref_pred.rs b/optd-cost-model/src/common/predicates/attr_ref_pred.rs index 9f63ad7..9afe6a0 100644 --- a/optd-cost-model/src/common/predicates/attr_ref_pred.rs +++ b/optd-cost-model/src/common/predicates/attr_ref_pred.rs @@ -25,13 +25,13 @@ use super::id_pred::IdPred; /// TODO: Support derived column in `AttributeRefPred`. 
/// Proposal: Data field can store the column type (base or derived). #[derive(Clone, Debug)] -pub struct AttributeRefPred(pub ArcPredicateNode); +pub struct AttrRefPred(pub ArcPredicateNode); -impl AttributeRefPred { - pub fn new(table_id: TableId, attribute_idx: u64) -> AttributeRefPred { - AttributeRefPred( +impl AttrRefPred { + pub fn new(table_id: TableId, attribute_idx: u64) -> AttrRefPred { + AttrRefPred( PredicateNode { - typ: PredicateType::AttributeRef, + typ: PredicateType::AttrRef, children: vec![ IdPred::new(table_id.0).into_pred_node(), IdPred::new(attribute_idx).into_pred_node(), @@ -60,13 +60,13 @@ impl AttributeRefPred { } } -impl ReprPredicateNode for AttributeRefPred { +impl ReprPredicateNode for AttrRefPred { fn into_pred_node(self) -> ArcPredicateNode { self.0 } fn from_pred_node(pred_node: ArcPredicateNode) -> Option { - if pred_node.typ != PredicateType::AttributeRef { + if pred_node.typ != PredicateType::AttrRef { return None; } Some(Self(pred_node)) diff --git a/optd-cost-model/src/common/properties/attr_ref.rs b/optd-cost-model/src/common/properties/attr_ref.rs index eb10fbb..fea3270 100644 --- a/optd-cost-model/src/common/properties/attr_ref.rs +++ b/optd-cost-model/src/common/properties/attr_ref.rs @@ -176,7 +176,7 @@ impl GroupAttrRefs { } } - pub fn base_table_attribute_refs(&self) -> &AttrRefs { + pub fn base_table_attr_refs(&self) -> &AttrRefs { &self.attribute_refs } diff --git a/optd-cost-model/src/cost/agg.rs b/optd-cost-model/src/cost/agg.rs index e2b94b1..f288ebb 100644 --- a/optd-cost-model/src/cost/agg.rs +++ b/optd-cost-model/src/cost/agg.rs @@ -1,7 +1,7 @@ use crate::{ common::{ nodes::{ArcPredicateNode, PredicateType, ReprPredicateNode}, - predicates::{attr_ref_pred::AttributeRefPred, list_pred::ListPred}, + predicates::{attr_ref_pred::AttrRefPred, list_pred::ListPred}, types::TableId, }, cost_model::CostModelImpl, @@ -25,9 +25,9 @@ impl CostModelImpl { for node in &group_by.0.children { match node.typ { - 
PredicateType::AttributeRef => { + PredicateType::AttrRef => { let attr_ref = - AttributeRefPred::from_pred_node(node.clone()).ok_or_else(|| { + AttrRefPred::from_pred_node(node.clone()).ok_or_else(|| { SemanticError::InvalidPredicate( "Expected AttributeRef predicate".to_string(), ) diff --git a/optd-cost-model/src/cost/filter/comp_op.rs b/optd-cost-model/src/cost/filter/comp_op.rs index 2e463bd..9712c82 100644 --- a/optd-cost-model/src/cost/filter/comp_op.rs +++ b/optd-cost-model/src/cost/filter/comp_op.rs @@ -4,7 +4,7 @@ use crate::{ common::{ nodes::{ArcPredicateNode, PredicateType, ReprPredicateNode}, predicates::{ - attr_ref_pred::AttributeRefPred, bin_op_pred::BinOpType, cast_pred::CastPred, + attr_ref_pred::AttrRefPred, bin_op_pred::BinOpType, cast_pred::CastPred, constant_pred::ConstantPred, }, values::Value, @@ -118,12 +118,7 @@ impl CostModelImpl { &self, left: ArcPredicateNode, right: ArcPredicateNode, - ) -> CostModelResult<( - Vec, - Vec, - Vec, - bool, - )> { + ) -> CostModelResult<(Vec, Vec, Vec, bool)> { let mut attr_ref_exprs = vec![]; let mut values = vec![]; let mut non_attr_ref_exprs = vec![]; @@ -171,8 +166,8 @@ impl CostModelImpl { .into_pred_node(); false } - PredicateType::AttributeRef => { - let attr_ref_expr = AttributeRefPred::from_pred_node(cast_expr_child) + PredicateType::AttrRef => { + let attr_ref_expr = AttrRefPred::from_pred_node(cast_expr_child) .expect("we already checked that the type is AttributeRef"); let attr_ref_idx = attr_ref_expr.attr_index(); let table_id = attr_ref_expr.table_id(); @@ -190,7 +185,7 @@ impl CostModelImpl { let invert_cast_data_type = &attribute_info.typ.into_data_type(); match non_cast_node.typ { - PredicateType::AttributeRef => { + PredicateType::AttrRef => { // In general, there's no way to remove the Cast here. We can't move // the Cast to the other AttributeRef // because that would lead to an infinite loop. 
Thus, we just leave @@ -224,10 +219,10 @@ impl CostModelImpl { // Sort nodes into attr_ref_exprs, values, and non_attr_ref_exprs match uncasted_left.as_ref().typ { - PredicateType::AttributeRef => { + PredicateType::AttrRef => { is_left_attr_ref = true; attr_ref_exprs.push( - AttributeRefPred::from_pred_node(uncasted_left) + AttrRefPred::from_pred_node(uncasted_left) .expect("we already checked that the type is AttributeRef"), ); } @@ -245,9 +240,9 @@ impl CostModelImpl { } } match uncasted_right.as_ref().typ { - PredicateType::AttributeRef => { + PredicateType::AttrRef => { attr_ref_exprs.push( - AttributeRefPred::from_pred_node(uncasted_right) + AttrRefPred::from_pred_node(uncasted_right) .expect("we already checked that the type is AttributeRef"), ); } diff --git a/optd-cost-model/src/cost/filter/controller.rs b/optd-cost-model/src/cost/filter/controller.rs index 80ec2ee..3f6ef21 100644 --- a/optd-cost-model/src/cost/filter/controller.rs +++ b/optd-cost-model/src/cost/filter/controller.rs @@ -31,7 +31,7 @@ impl CostModelImpl { Box::pin(async move { match &expr_tree.typ { PredicateType::Constant(_) => Ok(Self::get_constant_selectivity(expr_tree)), - PredicateType::AttributeRef => unimplemented!("check bool type or else panic"), + PredicateType::AttrRef => unimplemented!("check bool type or else panic"), PredicateType::UnOp(un_op_typ) => { assert!(expr_tree.children.len() == 1); let child = expr_tree.child(0); diff --git a/optd-cost-model/src/cost/filter/in_list.rs b/optd-cost-model/src/cost/filter/in_list.rs index 2c79ec2..8c11bcd 100644 --- a/optd-cost-model/src/cost/filter/in_list.rs +++ b/optd-cost-model/src/cost/filter/in_list.rs @@ -2,7 +2,7 @@ use crate::{ common::{ nodes::{PredicateType, ReprPredicateNode}, predicates::{ - attr_ref_pred::AttributeRefPred, constant_pred::ConstantPred, in_list_pred::InListPred, + attr_ref_pred::AttrRefPred, constant_pred::ConstantPred, in_list_pred::InListPred, }, }, cost_model::CostModelImpl, @@ -18,7 +18,7 @@ impl 
CostModelImpl { let child = expr.child(); // Check child is a attribute ref. - if !matches!(child.typ, PredicateType::AttributeRef) { + if !matches!(child.typ, PredicateType::AttrRef) { return Ok(UNIMPLEMENTED_SEL); } @@ -32,7 +32,7 @@ impl CostModelImpl { } // Convert child and const expressions to concrete types. - let attr_ref_pred = AttributeRefPred::from_pred_node(child).unwrap(); + let attr_ref_pred = AttrRefPred::from_pred_node(child).unwrap(); let attr_ref_idx = attr_ref_pred.attr_index(); let table_id = attr_ref_pred.table_id(); let list_exprs = list_exprs diff --git a/optd-cost-model/src/cost/filter/like.rs b/optd-cost-model/src/cost/filter/like.rs index 92a519a..997e389 100644 --- a/optd-cost-model/src/cost/filter/like.rs +++ b/optd-cost-model/src/cost/filter/like.rs @@ -4,7 +4,7 @@ use crate::{ common::{ nodes::{PredicateType, ReprPredicateNode}, predicates::{ - attr_ref_pred::AttributeRefPred, constant_pred::ConstantPred, like_pred::LikePred, + attr_ref_pred::AttrRefPred, constant_pred::ConstantPred, like_pred::LikePred, }, }, cost_model::CostModelImpl, @@ -32,7 +32,7 @@ impl CostModelImpl { let child = like_expr.child(); // Check child is a attribute ref. 
- if !matches!(child.typ, PredicateType::AttributeRef) { + if !matches!(child.typ, PredicateType::AttrRef) { return Ok(UNIMPLEMENTED_SEL); } @@ -42,7 +42,7 @@ impl CostModelImpl { return Ok(UNIMPLEMENTED_SEL); } - let attr_ref_pred = AttributeRefPred::from_pred_node(child).unwrap(); + let attr_ref_pred = AttrRefPred::from_pred_node(child).unwrap(); let attr_ref_idx = attr_ref_pred.attr_index(); let table_id = attr_ref_pred.table_id(); diff --git a/optd-cost-model/src/cost/join.rs b/optd-cost-model/src/cost/join.rs index 5aa8fb6..e313603 100644 --- a/optd-cost-model/src/cost/join.rs +++ b/optd-cost-model/src/cost/join.rs @@ -1,61 +1,61 @@ -// Copyright (c) 2023-2024 CMU Database Group -// -// Use of this source code is governed by an MIT-style license that can be found in the LICENSE file or at -// https://opensource.org/licenses/MIT. - use std::collections::HashSet; use itertools::Itertools; -use optd_datafusion_repr::plan_nodes::{ - ArcDfPredNode, BinOpType, ColumnRefPred, DfPredType, DfReprPredNode, JoinType, ListPred, - LogOpPred, LogOpType, -}; -use optd_datafusion_repr::properties::column_ref::{ - BaseTableColumnRef, BaseTableColumnRefs, ColumnRef, EqBaseTableColumnSets, EqPredicate, - GroupColumnRefs, SemanticCorrelation, -}; -use optd_datafusion_repr::properties::schema::Schema; -use serde::de::DeserializeOwned; -use serde::Serialize; -use super::AdvStats; -use crate::adv_stats::stats::{Distribution, MostCommonValues}; -use crate::adv_stats::DEFAULT_NUM_DISTINCT; +use crate::{ + common::{ + nodes::{ArcPredicateNode, JoinType, PredicateType, ReprPredicateNode}, + predicates::{ + attr_ref_pred::AttrRefPred, + bin_op_pred::BinOpType, + list_pred::ListPred, + log_op_pred::{LogOpPred, LogOpType}, + }, + properties::{ + attr_ref::{ + self, AttrRef, AttrRefs, BaseTableAttrRef, EqPredicate, GroupAttrRefs, + SemanticCorrelation, + }, + schema::Schema, + }, + }, + cost_model::CostModelImpl, + stats::DEFAULT_NUM_DISTINCT, + storage::CostModelStorageManager, + 
CostModelResult, +}; -impl< - M: MostCommonValues + Serialize + DeserializeOwned, - D: Distribution + Serialize + DeserializeOwned, - > AdvStats -{ +impl CostModelImpl { #[allow(clippy::too_many_arguments)] - pub(crate) fn get_nlj_row_cnt( + pub async fn get_nlj_row_cnt( &self, join_typ: JoinType, left_row_cnt: f64, right_row_cnt: f64, output_schema: Schema, - output_column_refs: GroupColumnRefs, - join_cond: ArcDfPredNode, - left_column_refs: GroupColumnRefs, - right_column_refs: GroupColumnRefs, - ) -> f64 { + output_column_refs: GroupAttrRefs, + join_cond: ArcPredicateNode, + left_column_refs: GroupAttrRefs, + right_column_refs: GroupAttrRefs, + ) -> CostModelResult { let selectivity = { let input_correlation = self.get_input_correlation(left_column_refs, right_column_refs); self.get_join_selectivity_from_expr_tree( join_typ, join_cond, &output_schema, - output_column_refs.base_table_column_refs(), + output_column_refs.base_table_attr_refs(), input_correlation, left_row_cnt, right_row_cnt, ) + .await? 
}; - (left_row_cnt * right_row_cnt * selectivity).max(1.0) + Ok((left_row_cnt * right_row_cnt * selectivity).max(1.0)) } #[allow(clippy::too_many_arguments)] - pub(crate) fn get_hash_join_row_cnt( + pub async fn get_hash_join_row_cnt( &self, join_typ: JoinType, left_row_cnt: f64, @@ -63,17 +63,17 @@ impl< left_keys: ListPred, right_keys: ListPred, output_schema: Schema, - output_column_refs: GroupColumnRefs, - left_column_refs: GroupColumnRefs, - right_column_refs: GroupColumnRefs, - ) -> f64 { + output_column_refs: GroupAttrRefs, + left_column_refs: GroupAttrRefs, + right_column_refs: GroupAttrRefs, + ) -> CostModelResult { let selectivity = { let schema = output_schema; let column_refs = output_column_refs; - let column_refs = column_refs.base_table_column_refs(); - let left_col_cnt = left_column_refs.base_table_column_refs().len(); + let column_refs = column_refs.base_table_attr_refs(); + let left_col_cnt = left_column_refs.base_table_attr_refs().len(); // there may be more than one expression tree in a group. - // see comment in DfPredType::PhysicalFilter(_) for more information + // see comment in PredicateType::PhysicalFilter(_) for more information let input_correlation = self.get_input_correlation(left_column_refs, right_column_refs); self.get_join_selectivity_from_keys( join_typ, @@ -86,14 +86,15 @@ impl< right_row_cnt, left_col_cnt, ) + .await? 
}; - (left_row_cnt * right_row_cnt * selectivity).max(1.0) + Ok((left_row_cnt * right_row_cnt * selectivity).max(1.0)) } fn get_input_correlation( &self, - left_prop: GroupColumnRefs, - right_prop: GroupColumnRefs, + left_prop: GroupAttrRefs, + right_prop: GroupAttrRefs, ) -> Option { SemanticCorrelation::merge( left_prop.output_correlation().cloned(), @@ -103,18 +104,18 @@ impl< /// A wrapper to convert the join keys to the format expected by get_join_selectivity_core() #[allow(clippy::too_many_arguments)] - fn get_join_selectivity_from_keys( + async fn get_join_selectivity_from_keys( &self, join_typ: JoinType, left_keys: ListPred, right_keys: ListPred, schema: &Schema, - column_refs: &BaseTableColumnRefs, + column_refs: &AttrRefs, input_correlation: Option, left_row_cnt: f64, right_row_cnt: f64, left_col_cnt: usize, - ) -> f64 { + ) -> CostModelResult { assert!(left_keys.len() == right_keys.len()); // I assume that the keys are already in the right order // s.t. the ith key of left_keys corresponds with the ith key of right_keys @@ -124,9 +125,8 @@ impl< .zip(right_keys.to_vec()) .map(|(left_key, right_key)| { ( - ColumnRefPred::from_pred_node(left_key).expect("keys should be ColumnRefPreds"), - ColumnRefPred::from_pred_node(right_key) - .expect("keys should be ColumnRefPreds"), + AttrRefPred::from_pred_node(left_key).expect("keys should be AttrRefPreds"), + AttrRefPred::from_pred_node(right_key).expect("keys should be AttrRefPreds"), ) }) .collect_vec(); @@ -141,6 +141,7 @@ impl< right_row_cnt, left_col_cnt, ) + .await } /// The core logic of join selectivity which assumes we've already separated the expression @@ -156,24 +157,26 @@ impl< /// For example, if the left table has 3 columns, the first column of the right table /// is #3 instead of #0. 
#[allow(clippy::too_many_arguments)] - fn get_join_selectivity_core( + async fn get_join_selectivity_core( &self, join_typ: JoinType, - on_col_ref_pairs: Vec<(ColumnRefPred, ColumnRefPred)>, - filter_expr_tree: Option, + on_col_ref_pairs: Vec<(AttrRefPred, AttrRefPred)>, + filter_expr_tree: Option, schema: &Schema, - column_refs: &BaseTableColumnRefs, + column_refs: &AttrRefs, input_correlation: Option, left_row_cnt: f64, right_row_cnt: f64, right_col_ref_offset: usize, - ) -> f64 { - let join_on_selectivity = self.get_join_on_selectivity( - &on_col_ref_pairs, - column_refs, - input_correlation, - right_col_ref_offset, - ); + ) -> CostModelResult { + let join_on_selectivity = self + .get_join_on_selectivity( + &on_col_ref_pairs, + column_refs, + input_correlation, + right_col_ref_offset, + ) + .await?; // Currently, there is no difference in how we handle a join filter and a select filter, // so we use the same function. // @@ -182,12 +185,14 @@ impl< // get_filter_selectivity() function, but this may change in the future. let join_filter_selectivity = match filter_expr_tree { Some(filter_expr_tree) => { - self.get_filter_selectivity(filter_expr_tree, schema, column_refs) + // FIXME: Pass in group id or schema & attr_refs + self.get_filter_selectivity(filter_expr_tree).await? 
} None => 1.0, }; let inner_join_selectivity = join_on_selectivity * join_filter_selectivity; - match join_typ { + + Ok(match join_typ { JoinType::Inner => inner_join_selectivity, JoinType::LeftOuter => f64::max(inner_join_selectivity, 1.0 / right_row_cnt), JoinType::RightOuter => f64::max(inner_join_selectivity, 1.0 / left_row_cnt), @@ -199,7 +204,7 @@ impl< join_filter_selectivity } _ => unimplemented!("join_typ={} is not implemented", join_typ), - } + }) } /// The expr_tree input must be a "mixed expression tree", just like with @@ -208,17 +213,17 @@ impl< /// This is a "wrapper" to separate the equality conditions from the filter conditions before /// calling the "main" `get_join_selectivity_core` function. #[allow(clippy::too_many_arguments)] - fn get_join_selectivity_from_expr_tree( + async fn get_join_selectivity_from_expr_tree( &self, join_typ: JoinType, - expr_tree: ArcDfPredNode, + expr_tree: ArcPredicateNode, schema: &Schema, - column_refs: &BaseTableColumnRefs, + column_refs: &AttrRefs, input_correlation: Option, left_row_cnt: f64, right_row_cnt: f64, - ) -> f64 { - if expr_tree.typ == DfPredType::LogOp(LogOpType::And) { + ) -> CostModelResult { + if expr_tree.typ == PredicateType::LogOp(LogOpType::And) { let mut on_col_ref_pairs = vec![]; let mut filter_expr_trees = vec![]; for child_expr_tree in &expr_tree.children { @@ -248,6 +253,7 @@ impl< right_row_cnt, 0, ) + .await } else { #[allow(clippy::collapsible_else_if)] if let Some(on_col_ref_pair) = Self::get_on_col_ref_pair(expr_tree.clone(), column_refs) @@ -263,6 +269,7 @@ impl< right_row_cnt, 0, ) + .await } else { self.get_join_selectivity_core( join_typ, @@ -275,6 +282,7 @@ impl< right_row_cnt, 0, ) + .await } } } @@ -284,33 +292,36 @@ impl< /// identical. 
It only picks out equality conditions between two column refs on different /// tables fn get_on_col_ref_pair( - expr_tree: ArcDfPredNode, - column_refs: &BaseTableColumnRefs, - ) -> Option<(ColumnRefPred, ColumnRefPred)> { + expr_tree: ArcPredicateNode, + column_refs: &AttrRefs, + ) -> Option<(AttrRefPred, AttrRefPred)> { // 1. Check that it's equality - if expr_tree.typ == DfPredType::BinOp(BinOpType::Eq) { + if expr_tree.typ == PredicateType::BinOp(BinOpType::Eq) { let left_child = expr_tree.child(0); let right_child = expr_tree.child(1); // 2. Check that both sides are column refs - if left_child.typ == DfPredType::ColumnRef && right_child.typ == DfPredType::ColumnRef { + if left_child.typ == PredicateType::AttrRef && right_child.typ == PredicateType::AttrRef + { // 3. Check that both sides don't belong to the same table (if we don't know, that // means they don't belong) - let left_col_ref_expr = ColumnRefPred::from_pred_node(left_child) - .expect("we already checked that the type is ColumnRef"); - let right_col_ref_expr = ColumnRefPred::from_pred_node(right_child) - .expect("we already checked that the type is ColumnRef"); - let left_col_ref = &column_refs[left_col_ref_expr.index()]; - let right_col_ref = &column_refs[right_col_ref_expr.index()]; + let left_col_ref_expr = AttrRefPred::from_pred_node(left_child) + .expect("we already checked that the type is AttrRef"); + let right_col_ref_expr = AttrRefPred::from_pred_node(right_child) + .expect("we already checked that the type is AttrRef"); + let left_col_ref = &column_refs[left_col_ref_expr.attr_index() as usize]; + let right_col_ref = &column_refs[right_col_ref_expr.attr_index() as usize]; let is_same_table = if let ( - ColumnRef::BaseTableColumnRef(BaseTableColumnRef { - table: left_table, .. + AttrRef::BaseTableAttrRef(BaseTableAttrRef { + table_id: left_table_id, + .. }), - ColumnRef::BaseTableColumnRef(BaseTableColumnRef { - table: right_table, .. 
+ AttrRef::BaseTableAttrRef(BaseTableAttrRef { + table_id: right_table_id, + .. }), ) = (left_col_ref, right_col_ref) { - left_table == right_table + left_table_id == right_table_id } else { false }; @@ -328,54 +339,72 @@ impl< } /// Get the selectivity of one column eq predicate, e.g. colA = colB. - fn get_join_selectivity_from_on_col_ref_pair( + async fn get_join_selectivity_from_on_col_ref_pair( &self, - left: &ColumnRef, - right: &ColumnRef, - ) -> f64 { + left: &AttrRef, + right: &AttrRef, + ) -> CostModelResult { // the formula for each pair is min(1 / ndistinct1, 1 / ndistinct2) // (see https://postgrespro.com/blog/pgsql/5969618) - let ndistincts = vec![left, right].into_iter().map(|col_ref| { - match self.get_single_column_stats_from_col_ref(col_ref) { - Some(per_col_stats) => per_col_stats.ndistinct, - None => DEFAULT_NUM_DISTINCT, - } - }); + let mut ndistincts = vec![]; + for attr_ref in [left, right] { + let ndistinct = match attr_ref { + AttrRef::BaseTableAttrRef(base_attr_ref) => { + match self + .get_attribute_comb_stats(base_attr_ref.table_id, &[base_attr_ref.attr_idx]) + .await? 
+ { + Some(per_col_stats) => per_col_stats.ndistinct, + None => DEFAULT_NUM_DISTINCT, + } + } + AttrRef::Derived => DEFAULT_NUM_DISTINCT, + }; + ndistincts.push(ndistinct); + } + // using reduce(f64::min) is the idiomatic workaround to min() because // f64 does not implement Ord due to NaN - let selectivity = ndistincts.map(|ndistinct| 1.0 / ndistinct as f64).reduce(f64::min).expect("reduce() only returns None if the iterator is empty, which is impossible since col_ref_exprs.len() == 2"); + let selectivity = ndistincts.into_iter().map(|ndistinct| 1.0 / ndistinct as f64).reduce(f64::min).expect("reduce() only returns None if the iterator is empty, which is impossible since col_ref_exprs.len() == 2"); assert!( !selectivity.is_nan(), "it should be impossible for selectivity to be NaN since n-distinct is never 0" ); - selectivity + Ok(selectivity) } /// Given a set of N columns involved in a multi-equality, find the total selectivity /// of the multi-equality. /// /// This is a generalization of get_join_selectivity_from_on_col_ref_pair(). - fn get_join_selectivity_from_most_selective_columns( + async fn get_join_selectivity_from_most_selective_columns( &self, - base_col_refs: HashSet, - ) -> f64 { - assert!(base_col_refs.len() > 1); - let num_base_col_refs = base_col_refs.len(); - base_col_refs + base_attr_refs: HashSet, + ) -> CostModelResult { + assert!(base_attr_refs.len() > 1); + let num_base_attr_refs = base_attr_refs.len(); + + let mut ndistincts = vec![]; + for base_attr_ref in base_attr_refs.iter() { + let ndistinct = match self + .get_attribute_comb_stats(base_attr_ref.table_id, &[base_attr_ref.attr_idx]) + .await? 
+ { + Some(per_col_stats) => per_col_stats.ndistinct, + None => DEFAULT_NUM_DISTINCT, + }; + ndistincts.push(ndistinct); + } + + Ok(ndistincts .into_iter() - .map(|base_col_ref| { - match self.get_column_comb_stats(&base_col_ref.table, &[base_col_ref.col_idx]) { - Some(per_col_stats) => per_col_stats.ndistinct, - None => DEFAULT_NUM_DISTINCT, - } - }) .map(|ndistinct| 1.0 / ndistinct as f64) .sorted_by(|a, b| { a.partial_cmp(b) .expect("No floats should be NaN since n-distinct is never 0") }) - .take(num_base_col_refs - 1) - .product() + .take(num_base_attr_refs - 1) + .product()) } /// A predicate set defines a "multi-equality graph", which is an unweighted undirected graph. @@ -400,14 +429,14 @@ impl< /// quotient is the "adjustment" factor. /// /// NOTE: This function modifies `past_eq_columns` by adding `predicate` to it. - fn get_join_selectivity_adjustment_when_adding_to_multi_equality_graph( + async fn get_join_selectivity_adjustment_when_adding_to_multi_equality_graph( &self, predicate: &EqPredicate, - past_eq_columns: &mut EqBaseTableColumnSets, - ) -> f64 { + past_eq_columns: &mut SemanticCorrelation, + ) -> CostModelResult { if predicate.left == predicate.right { // self-join, TODO: is this correct? - return 1.0; + return Ok(1.0); } // To find the adjustment, we need to know the selectivity of the graph before `predicate` // is added. @@ -419,20 +448,23 @@ impl< let children_pred_sel = { if past_eq_columns.is_eq(&predicate.left, &predicate.right) { self.get_join_selectivity_from_most_selective_columns( - past_eq_columns.find_cols_for_eq_column_set(&predicate.left), + past_eq_columns.find_attrs_for_eq_attribute_set(&predicate.left), ) + .await? } else { let left_sel = if past_eq_columns.contains(&predicate.left) { self.get_join_selectivity_from_most_selective_columns( - past_eq_columns.find_cols_for_eq_column_set(&predicate.left), + past_eq_columns.find_attrs_for_eq_attribute_set(&predicate.left), ) + .await? 
} else { 1.0 }; let right_sel = if past_eq_columns.contains(&predicate.right) { self.get_join_selectivity_from_most_selective_columns( - past_eq_columns.find_cols_for_eq_column_set(&predicate.right), + past_eq_columns.find_attrs_for_eq_attribute_set(&predicate.right), ) + .await? } else { 1.0 }; @@ -444,12 +476,13 @@ impl< // it creates. past_eq_columns.add_predicate(predicate.clone()); let new_pred_sel = { - let cols = past_eq_columns.find_cols_for_eq_column_set(&predicate.left); + let cols = past_eq_columns.find_attrs_for_eq_attribute_set(&predicate.left); self.get_join_selectivity_from_most_selective_columns(cols) - }; + } + .await?; // Compute the adjustment factor. - new_pred_sel / children_pred_sel + Ok(new_pred_sel / children_pred_sel) } /// Get the selectivity of the on conditions. @@ -465,959 +498,38 @@ impl< /// However, we don't just throw away A = C, because we want to pick the most selective /// predicates. For details on how we do this, see /// `get_join_selectivity_from_redundant_predicates`. 
- fn get_join_on_selectivity( + async fn get_join_on_selectivity( &self, - on_col_ref_pairs: &[(ColumnRefPred, ColumnRefPred)], - column_refs: &BaseTableColumnRefs, + on_col_ref_pairs: &[(AttrRefPred, AttrRefPred)], + column_refs: &AttrRefs, input_correlation: Option, right_col_ref_offset: usize, - ) -> f64 { - let mut past_eq_columns = input_correlation - .map(|c| EqBaseTableColumnSets::try_from(c).unwrap()) - .unwrap_or_default(); - - // multiply the selectivities of all individual conditions together - on_col_ref_pairs - .iter() - .map(|on_col_ref_pair| { - let left_col_ref = &column_refs[on_col_ref_pair.0.index()]; - let right_col_ref = &column_refs[on_col_ref_pair.1.index() + right_col_ref_offset]; - - if let (ColumnRef::BaseTableColumnRef(left), ColumnRef::BaseTableColumnRef(right)) = - (left_col_ref, right_col_ref) - { - let predicate = EqPredicate::new(left.clone(), right.clone()); - return self - .get_join_selectivity_adjustment_when_adding_to_multi_equality_graph( - &predicate, - &mut past_eq_columns, - ); - } - - self.get_join_selectivity_from_on_col_ref_pair(left_col_ref, right_col_ref) - }) - .product() - } -} - -#[cfg(test)] -mod tests { - use std::collections::HashSet; - - use optd_core::nodes::Value; - use optd_datafusion_repr::plan_nodes::{ArcDfPredNode, BinOpType, JoinType, LogOpType}; - use optd_datafusion_repr::properties::column_ref::{ - BaseTableColumnRef, BaseTableColumnRefs, ColumnRef, EqBaseTableColumnSets, EqPredicate, - SemanticCorrelation, - }; - use optd_datafusion_repr::properties::schema::Schema; - - use crate::adv_stats::tests::*; - use crate::adv_stats::DEFAULT_EQ_SEL; - - /// A wrapper around get_join_selectivity_from_expr_tree that extracts the - /// table row counts from the cost model. 
- fn test_get_join_selectivity( - cost_model: &TestOptCostModel, - reverse_tables: bool, - join_typ: JoinType, - expr_tree: ArcDfPredNode, - schema: &Schema, - column_refs: &BaseTableColumnRefs, - input_correlation: Option, - ) -> f64 { - let table1_row_cnt = cost_model.per_table_stats_map[TABLE1_NAME].row_cnt as f64; - let table2_row_cnt = cost_model.per_table_stats_map[TABLE2_NAME].row_cnt as f64; - if !reverse_tables { - cost_model.get_join_selectivity_from_expr_tree( - join_typ, - expr_tree, - schema, - column_refs, - input_correlation, - table1_row_cnt, - table2_row_cnt, - ) - } else { - cost_model.get_join_selectivity_from_expr_tree( - join_typ, - expr_tree, - schema, - column_refs, - input_correlation, - table2_row_cnt, - table1_row_cnt, - ) - } - } - - #[test] - fn test_inner_const() { - let cost_model = create_one_column_cost_model(get_empty_per_col_stats()); - assert_approx_eq::assert_approx_eq!( - cost_model.get_join_selectivity_from_expr_tree( - JoinType::Inner, - cnst(Value::Bool(true)), - &Schema::new(vec![]), - &vec![], - None, - f64::NAN, - f64::NAN - ), - 1.0 - ); - assert_approx_eq::assert_approx_eq!( - cost_model.get_join_selectivity_from_expr_tree( - JoinType::Inner, - cnst(Value::Bool(false)), - &Schema::new(vec![]), - &vec![], - None, - f64::NAN, - f64::NAN - ), - 0.0 - ); - } - - #[test] - fn test_inner_oncond() { - let cost_model = create_two_table_cost_model( - TestPerColumnStats::new( - TestMostCommonValues::empty(), - 5, - 0.0, - Some(TestDistribution::empty()), - ), - TestPerColumnStats::new( - TestMostCommonValues::empty(), - 4, - 0.0, - Some(TestDistribution::empty()), - ), - ); - let expr_tree = bin_op(BinOpType::Eq, col_ref(0), col_ref(1)); - let expr_tree_rev = bin_op(BinOpType::Eq, col_ref(1), col_ref(0)); - let schema = Schema::new(vec![]); - let column_refs = vec![ - ColumnRef::base_table_column_ref(String::from(TABLE1_NAME), 0), - ColumnRef::base_table_column_ref(String::from(TABLE2_NAME), 0), - ]; - 
assert_approx_eq::assert_approx_eq!( - test_get_join_selectivity( - &cost_model, - false, - JoinType::Inner, - expr_tree, - &schema, - &column_refs, - None, - ), - 0.2 - ); - assert_approx_eq::assert_approx_eq!( - test_get_join_selectivity( - &cost_model, - false, - JoinType::Inner, - expr_tree_rev, - &schema, - &column_refs, - None, - ), - 0.2 - ); - } - - #[test] - fn test_inner_and_of_onconds() { - let cost_model = create_two_table_cost_model( - TestPerColumnStats::new( - TestMostCommonValues::empty(), - 5, - 0.0, - Some(TestDistribution::empty()), - ), - TestPerColumnStats::new( - TestMostCommonValues::empty(), - 4, - 0.0, - Some(TestDistribution::empty()), - ), - ); - let eq0and1 = bin_op(BinOpType::Eq, col_ref(0), col_ref(1)); - let eq1and0 = bin_op(BinOpType::Eq, col_ref(1), col_ref(0)); - let expr_tree = log_op(LogOpType::And, vec![eq0and1.clone(), eq1and0.clone()]); - let expr_tree_rev = log_op(LogOpType::And, vec![eq1and0.clone(), eq0and1.clone()]); - let schema = Schema::new(vec![]); - let column_refs = vec![ - ColumnRef::base_table_column_ref(String::from(TABLE1_NAME), 0), - ColumnRef::base_table_column_ref(String::from(TABLE2_NAME), 0), - ]; - assert_approx_eq::assert_approx_eq!( - test_get_join_selectivity( - &cost_model, - false, - JoinType::Inner, - expr_tree, - &schema, - &column_refs, - None, - ), - 0.2 - ); - assert_approx_eq::assert_approx_eq!( - test_get_join_selectivity( - &cost_model, - false, - JoinType::Inner, - expr_tree_rev, - &schema, - &column_refs, - None - ), - 0.2 - ); - } - - #[test] - fn test_inner_and_of_oncond_and_filter() { - let cost_model = create_two_table_cost_model( - TestPerColumnStats::new( - TestMostCommonValues::empty(), - 5, - 0.0, - Some(TestDistribution::empty()), - ), - TestPerColumnStats::new( - TestMostCommonValues::empty(), - 4, - 0.0, - Some(TestDistribution::empty()), - ), - ); - let eq0and1 = bin_op(BinOpType::Eq, col_ref(0), col_ref(1)); - let eq100 = bin_op(BinOpType::Eq, col_ref(1), 
cnst(Value::Int32(100))); - let expr_tree = log_op(LogOpType::And, vec![eq0and1.clone(), eq100.clone()]); - let expr_tree_rev = log_op(LogOpType::And, vec![eq100.clone(), eq0and1.clone()]); - let schema = Schema::new(vec![]); - let column_refs = vec![ - ColumnRef::base_table_column_ref(String::from(TABLE1_NAME), 0), - ColumnRef::base_table_column_ref(String::from(TABLE2_NAME), 0), - ]; - assert_approx_eq::assert_approx_eq!( - test_get_join_selectivity( - &cost_model, - false, - JoinType::Inner, - expr_tree, - &schema, - &column_refs, - None - ), - 0.05 - ); - assert_approx_eq::assert_approx_eq!( - test_get_join_selectivity( - &cost_model, - false, - JoinType::Inner, - expr_tree_rev, - &schema, - &column_refs, - None - ), - 0.05 - ); - } - - #[test] - fn test_inner_and_of_filters() { - let cost_model = create_two_table_cost_model( - TestPerColumnStats::new( - TestMostCommonValues::empty(), - 5, - 0.0, - Some(TestDistribution::empty()), - ), - TestPerColumnStats::new( - TestMostCommonValues::empty(), - 4, - 0.0, - Some(TestDistribution::empty()), - ), - ); - let neq12 = bin_op(BinOpType::Neq, col_ref(0), cnst(Value::Int32(12))); - let eq100 = bin_op(BinOpType::Eq, col_ref(1), cnst(Value::Int32(100))); - let expr_tree = log_op(LogOpType::And, vec![neq12.clone(), eq100.clone()]); - let expr_tree_rev = log_op(LogOpType::And, vec![eq100.clone(), neq12.clone()]); - let schema = Schema::new(vec![]); - let column_refs = vec![ - ColumnRef::base_table_column_ref(String::from(TABLE1_NAME), 0), - ColumnRef::base_table_column_ref(String::from(TABLE2_NAME), 0), - ]; - assert_approx_eq::assert_approx_eq!( - test_get_join_selectivity( - &cost_model, - false, - JoinType::Inner, - expr_tree, - &schema, - &column_refs, - None, - ), - 0.2 - ); - assert_approx_eq::assert_approx_eq!( - test_get_join_selectivity( - &cost_model, - false, - JoinType::Inner, - expr_tree_rev, - &schema, - &column_refs, - None - ), - 0.2 - ); - } - - #[test] - fn 
test_inner_colref_eq_colref_same_table_is_not_oncond() { - let cost_model = create_two_table_cost_model( - TestPerColumnStats::new( - TestMostCommonValues::empty(), - 5, - 0.0, - Some(TestDistribution::empty()), - ), - TestPerColumnStats::new( - TestMostCommonValues::empty(), - 4, - 0.0, - Some(TestDistribution::empty()), - ), - ); - let expr_tree = bin_op(BinOpType::Eq, col_ref(0), col_ref(0)); - let schema = Schema::new(vec![]); - let column_refs = vec![ - ColumnRef::base_table_column_ref(String::from(TABLE1_NAME), 0), - ColumnRef::base_table_column_ref(String::from(TABLE2_NAME), 0), - ]; - assert_approx_eq::assert_approx_eq!( - test_get_join_selectivity( - &cost_model, - false, - JoinType::Inner, - expr_tree, - &schema, - &column_refs, - None - ), - DEFAULT_EQ_SEL - ); - } - - // We don't test joinsel or with oncond because if there is an oncond (on condition), the - // top-level operator must be an AND - - /// I made this helper function to avoid copying all eight lines over and over - fn assert_outer_selectivities( - cost_model: &TestOptCostModel, - expr_tree: ArcDfPredNode, - expr_tree_rev: ArcDfPredNode, - schema: &Schema, - column_refs: &BaseTableColumnRefs, - expected_table1_outer_sel: f64, - expected_table2_outer_sel: f64, - ) { - // all table 1 outer combinations - assert_approx_eq::assert_approx_eq!( - test_get_join_selectivity( - cost_model, - false, - JoinType::LeftOuter, - expr_tree.clone(), - schema, - column_refs, - None - ), - expected_table1_outer_sel - ); - assert_approx_eq::assert_approx_eq!( - test_get_join_selectivity( - cost_model, - false, - JoinType::LeftOuter, - expr_tree_rev.clone(), - schema, - column_refs, - None - ), - expected_table1_outer_sel - ); - assert_approx_eq::assert_approx_eq!( - test_get_join_selectivity( - cost_model, - true, - JoinType::RightOuter, - expr_tree.clone(), - schema, - column_refs, - None - ), - expected_table1_outer_sel - ); - assert_approx_eq::assert_approx_eq!( - test_get_join_selectivity( - cost_model, - 
true, - JoinType::RightOuter, - expr_tree_rev.clone(), - schema, - column_refs, - None - ), - expected_table1_outer_sel - ); - // all table 2 outer combinations - assert_approx_eq::assert_approx_eq!( - test_get_join_selectivity( - cost_model, - true, - JoinType::LeftOuter, - expr_tree.clone(), - schema, - column_refs, - None - ), - expected_table2_outer_sel - ); - assert_approx_eq::assert_approx_eq!( - test_get_join_selectivity( - cost_model, - true, - JoinType::LeftOuter, - expr_tree_rev.clone(), - schema, - column_refs, - None - ), - expected_table2_outer_sel - ); - assert_approx_eq::assert_approx_eq!( - test_get_join_selectivity( - cost_model, - false, - JoinType::RightOuter, - expr_tree.clone(), - schema, - column_refs, - None - ), - expected_table2_outer_sel - ); - assert_approx_eq::assert_approx_eq!( - test_get_join_selectivity( - cost_model, - false, - JoinType::RightOuter, - expr_tree_rev.clone(), - schema, - column_refs, - None - ), - expected_table2_outer_sel - ); - } - - /// Unique oncond means an oncondition on columns which are unique in both tables - /// There's only one case if both columns are unique and have different row counts: the inner - /// will be < 1 / row count of one table and = 1 / row count of another - #[test] - fn test_outer_unique_oncond() { - let cost_model = create_two_table_cost_model_custom_row_cnts( - TestPerColumnStats::new( - TestMostCommonValues::empty(), - 5, - 0.0, - Some(TestDistribution::empty()), - ), - TestPerColumnStats::new( - TestMostCommonValues::empty(), - 4, - 0.0, - Some(TestDistribution::empty()), - ), - 5, - 4, - ); - // the left/right of the join refers to the tables, not the order of columns in the - // predicate - let expr_tree = bin_op(BinOpType::Eq, col_ref(0), col_ref(1)); - let expr_tree_rev = bin_op(BinOpType::Eq, col_ref(1), col_ref(0)); - let schema = Schema::new(vec![]); - let column_refs = vec![ - ColumnRef::base_table_column_ref(String::from(TABLE1_NAME), 0), - 
ColumnRef::base_table_column_ref(String::from(TABLE2_NAME), 0), - ]; - // sanity check the expected inner sel - let expected_inner_sel = 0.2; - assert_approx_eq::assert_approx_eq!( - test_get_join_selectivity( - &cost_model, - false, - JoinType::Inner, - expr_tree.clone(), - &schema, - &column_refs, - None - ), - expected_inner_sel - ); - assert_approx_eq::assert_approx_eq!( - test_get_join_selectivity( - &cost_model, - false, - JoinType::Inner, - expr_tree_rev.clone(), - &schema, - &column_refs, - None - ), - expected_inner_sel - ); - // check the outer sels - assert_outer_selectivities( - &cost_model, - expr_tree, - expr_tree_rev, - &schema, - &column_refs, - 0.25, - 0.2, - ); - } - - /// Non-unique oncond means the column is not unique in either table - /// Inner always >= row count means that the inner join result is >= 1 / the row count of both - /// tables - #[test] - fn test_outer_nonunique_oncond_inner_always_geq_rowcnt() { - let cost_model = create_two_table_cost_model_custom_row_cnts( - TestPerColumnStats::new( - TestMostCommonValues::empty(), - 5, - 0.0, - Some(TestDistribution::empty()), - ), - TestPerColumnStats::new( - TestMostCommonValues::empty(), - 4, - 0.0, - Some(TestDistribution::empty()), - ), - 10, - 8, - ); - // the left/right of the join refers to the tables, not the order of columns in the - // predicate - let expr_tree = bin_op(BinOpType::Eq, col_ref(0), col_ref(1)); - let expr_tree_rev = bin_op(BinOpType::Eq, col_ref(1), col_ref(0)); - let schema = Schema::new(vec![]); - let column_refs = vec![ - ColumnRef::base_table_column_ref(String::from(TABLE1_NAME), 0), - ColumnRef::base_table_column_ref(String::from(TABLE2_NAME), 0), - ]; - // sanity check the expected inner sel - let expected_inner_sel = 0.2; - assert_approx_eq::assert_approx_eq!( - test_get_join_selectivity( - &cost_model, - false, - JoinType::Inner, - expr_tree.clone(), - &schema, - &column_refs, - None - ), - expected_inner_sel - ); - assert_approx_eq::assert_approx_eq!( - 
test_get_join_selectivity( - &cost_model, - false, - JoinType::Inner, - expr_tree_rev.clone(), - &schema, - &column_refs, - None - ), - expected_inner_sel - ); - // check the outer sels - assert_outer_selectivities( - &cost_model, - expr_tree, - expr_tree_rev, - &schema, - &column_refs, - 0.2, - 0.2, - ); - } - - /// Non-unique oncond means the column is not unique in either table - /// Inner sometimes < row count means that the inner join result < 1 / the row count of exactly - /// one table. Note that without a join filter, it's impossible to be less than the row - /// count of both tables - #[test] - fn test_outer_nonunique_oncond_inner_sometimes_lt_rowcnt() { - let cost_model = create_two_table_cost_model_custom_row_cnts( - TestPerColumnStats::new( - TestMostCommonValues::empty(), - 10, - 0.0, - Some(TestDistribution::empty()), - ), - TestPerColumnStats::new( - TestMostCommonValues::empty(), - 2, - 0.0, - Some(TestDistribution::empty()), - ), - 20, - 4, - ); - // the left/right of the join refers to the tables, not the order of columns in the - // predicate - let expr_tree = bin_op(BinOpType::Eq, col_ref(0), col_ref(1)); - let expr_tree_rev = bin_op(BinOpType::Eq, col_ref(1), col_ref(0)); - let schema = Schema::new(vec![]); - let column_refs = vec![ - ColumnRef::base_table_column_ref(String::from(TABLE1_NAME), 0), - ColumnRef::base_table_column_ref(String::from(TABLE2_NAME), 0), - ]; - // sanity check the expected inner sel - let expected_inner_sel = 0.1; - assert_approx_eq::assert_approx_eq!( - test_get_join_selectivity( - &cost_model, - false, - JoinType::Inner, - expr_tree.clone(), - &schema, - &column_refs, - None - ), - expected_inner_sel - ); - assert_approx_eq::assert_approx_eq!( - test_get_join_selectivity( - &cost_model, - false, - JoinType::Inner, - expr_tree_rev.clone(), - &schema, - &column_refs, - None - ), - expected_inner_sel - ); - // check the outer sels - assert_outer_selectivities( - &cost_model, - expr_tree, - expr_tree_rev, - &schema, - 
&column_refs, - 0.25, - 0.1, - ); - } - - /// Unique oncond means an oncondition on columns which are unique in both tables - /// Filter means we're adding a join filter - /// There's only one case if both columns are unique and there's a filter: - /// the inner will be < 1 / row count of both tables - #[test] - fn test_outer_unique_oncond_filter() { - let cost_model = create_two_table_cost_model_custom_row_cnts( - TestPerColumnStats::new( - TestMostCommonValues::empty(), - 50, - 0.0, - Some(TestDistribution::new(vec![(Value::Int32(128), 0.4)])), - ), - TestPerColumnStats::new( - TestMostCommonValues::empty(), - 4, - 0.0, - Some(TestDistribution::empty()), - ), - 50, - 4, - ); - // the left/right of the join refers to the tables, not the order of columns in the - // predicate - let eq0and1 = bin_op(BinOpType::Eq, col_ref(0), col_ref(1)); - let eq1and0 = bin_op(BinOpType::Eq, col_ref(1), col_ref(0)); - let filter = bin_op(BinOpType::Leq, col_ref(0), cnst(Value::Int32(128))); - let expr_tree = log_op(LogOpType::And, vec![eq0and1, filter.clone()]); - // inner rev means its the inner expr (the eq op) whose children are being reversed, as - // opposed to the and op - let expr_tree_inner_rev = log_op(LogOpType::And, vec![eq1and0, filter.clone()]); - let schema = Schema::new(vec![]); - let column_refs = vec![ - ColumnRef::base_table_column_ref(String::from(TABLE1_NAME), 0), - ColumnRef::base_table_column_ref(String::from(TABLE2_NAME), 0), - ]; - // sanity check the expected inner sel - let expected_inner_sel = 0.008; - assert_approx_eq::assert_approx_eq!( - test_get_join_selectivity( - &cost_model, - false, - JoinType::Inner, - expr_tree.clone(), - &schema, - &column_refs, - None - ), - expected_inner_sel - ); - assert_approx_eq::assert_approx_eq!( - test_get_join_selectivity( - &cost_model, - false, - JoinType::Inner, - expr_tree_inner_rev.clone(), - &schema, - &column_refs, - None - ), - expected_inner_sel - ); - // check the outer sels - assert_outer_selectivities( - 
&cost_model, - expr_tree, - expr_tree_inner_rev, - &schema, - &column_refs, - 0.25, - 0.02, - ); - } - - /// Test all possible permutations of three-table joins. - /// A three-table join consists of at least two joins. `join1_on_cond` is the condition of the - /// first join. There can only be one condition because only two tables are involved at - /// the time of the first join. - #[test_case::test_case(&[(0, 1)])] - #[test_case::test_case(&[(0, 2)])] - #[test_case::test_case(&[(1, 2)])] - #[test_case::test_case(&[(0, 1), (0, 2)])] - #[test_case::test_case(&[(0, 1), (1, 2)])] - #[test_case::test_case(&[(0, 2), (1, 2)])] - #[test_case::test_case(&[(0, 1), (0, 2), (1, 2)])] - fn test_three_table_join_for_initial_join_on_conds(initial_join_on_conds: &[(usize, usize)]) { - assert!( - !initial_join_on_conds.is_empty(), - "initial_join_on_conds should be non-empty" - ); - assert_eq!( - initial_join_on_conds.len(), - initial_join_on_conds.iter().collect::>().len(), - "initial_join_on_conds shouldn't contain duplicates" - ); - let cost_model = create_three_table_cost_model( - TestPerColumnStats::new( - TestMostCommonValues::empty(), - 2, - 0.0, - Some(TestDistribution::empty()), - ), - TestPerColumnStats::new( - TestMostCommonValues::empty(), - 3, - 0.0, - Some(TestDistribution::empty()), - ), - TestPerColumnStats::new( - TestMostCommonValues::empty(), - 4, - 0.0, - Some(TestDistribution::empty()), - ), - ); - let col_base_refs = vec![ - BaseTableColumnRef { - table: String::from(TABLE1_NAME), - col_idx: 0, - }, - BaseTableColumnRef { - table: String::from(TABLE2_NAME), - col_idx: 0, - }, - BaseTableColumnRef { - table: String::from(TABLE3_NAME), - col_idx: 0, - }, - ]; - let col_refs: BaseTableColumnRefs = col_base_refs - .clone() - .into_iter() - .map(|col_base_ref| col_base_ref.into()) - .collect(); - - let mut eq_columns = EqBaseTableColumnSets::new(); - for initial_join_on_cond in initial_join_on_conds { - eq_columns.add_predicate(EqPredicate::new( - 
col_base_refs[initial_join_on_cond.0].clone(), - col_base_refs[initial_join_on_cond.1].clone(), - )); - } - let initial_selectivity = { - if initial_join_on_conds.len() == 1 { - let initial_join_on_cond = initial_join_on_conds.first().unwrap(); - if initial_join_on_cond == &(0, 1) { - 1.0 / 3.0 - } else if initial_join_on_cond == &(0, 2) || initial_join_on_cond == &(1, 2) { - 1.0 / 4.0 - } else { - panic!(); - } - } else { - 1.0 / 12.0 + ) -> CostModelResult { + let mut past_eq_columns = input_correlation.unwrap_or_default(); + + // Multiply the selectivities of all individual conditions together + let mut selectivity = 1.0; + for on_col_ref_pair in on_col_ref_pairs { + let left_col_ref = &column_refs[on_col_ref_pair.0.attr_index() as usize]; + let right_col_ref = + &column_refs[on_col_ref_pair.1.attr_index() as usize + right_col_ref_offset]; + + if let (AttrRef::BaseTableAttrRef(left), AttrRef::BaseTableAttrRef(right)) = + (left_col_ref, right_col_ref) + { + let predicate = EqPredicate::new(left.clone(), right.clone()); + return self + .get_join_selectivity_adjustment_when_adding_to_multi_equality_graph( + &predicate, + &mut past_eq_columns, + ) + .await; } - }; - let semantic_correlation = SemanticCorrelation::new(eq_columns); - let schema = Schema::new(vec![]); - let column_refs = col_refs; - let input_correlation = Some(semantic_correlation); - // Try all join conditions of the final join which would lead to all three tables being - // joined. 
- let eq0and1 = bin_op(BinOpType::Eq, col_ref(0), col_ref(1)); - let eq0and2 = bin_op(BinOpType::Eq, col_ref(0), col_ref(2)); - let eq1and2 = bin_op(BinOpType::Eq, col_ref(1), col_ref(2)); - let and_01_02 = log_op(LogOpType::And, vec![eq0and1.clone(), eq0and2.clone()]); - let and_01_12 = log_op(LogOpType::And, vec![eq0and1.clone(), eq1and2.clone()]); - let and_02_12 = log_op(LogOpType::And, vec![eq0and2.clone(), eq1and2.clone()]); - let and_01_02_12 = log_op( - LogOpType::And, - vec![eq0and1.clone(), eq0and2.clone(), eq1and2.clone()], - ); - let mut join2_expr_trees = vec![and_01_02, and_01_12, and_02_12, and_01_02_12]; - if initial_join_on_conds.len() == 1 { - let initial_join_on_cond = initial_join_on_conds.first().unwrap(); - if initial_join_on_cond == &(0, 1) { - join2_expr_trees.push(eq0and2); - join2_expr_trees.push(eq1and2); - } else if initial_join_on_cond == &(0, 2) { - join2_expr_trees.push(eq0and1); - join2_expr_trees.push(eq1and2); - } else if initial_join_on_cond == &(1, 2) { - join2_expr_trees.push(eq0and1); - join2_expr_trees.push(eq0and2); - } else { - panic!(); - } - } - for expr_tree in join2_expr_trees { - let overall_selectivity = initial_selectivity - * test_get_join_selectivity( - &cost_model, - false, - JoinType::Inner, - expr_tree.clone(), - &schema, - &column_refs, - input_correlation.clone(), - ); - assert_approx_eq::assert_approx_eq!(overall_selectivity, 1.0 / 12.0); + selectivity *= self + .get_join_selectivity_from_on_col_ref_pair(left_col_ref, right_col_ref) + .await?; } - } - - #[test] - fn test_join_which_connects_two_components_together() { - let cost_model = create_four_table_cost_model( - TestPerColumnStats::new( - TestMostCommonValues::empty(), - 2, - 0.0, - Some(TestDistribution::empty()), - ), - TestPerColumnStats::new( - TestMostCommonValues::empty(), - 3, - 0.0, - Some(TestDistribution::empty()), - ), - TestPerColumnStats::new( - TestMostCommonValues::empty(), - 4, - 0.0, - Some(TestDistribution::empty()), - ), - 
TestPerColumnStats::new( - TestMostCommonValues::empty(), - 5, - 0.0, - Some(TestDistribution::empty()), - ), - ); - let col_base_refs = vec![ - BaseTableColumnRef { - table: String::from(TABLE1_NAME), - col_idx: 0, - }, - BaseTableColumnRef { - table: String::from(TABLE2_NAME), - col_idx: 0, - }, - BaseTableColumnRef { - table: String::from(TABLE3_NAME), - col_idx: 0, - }, - BaseTableColumnRef { - table: String::from(TABLE4_NAME), - col_idx: 0, - }, - ]; - let col_refs: BaseTableColumnRefs = col_base_refs - .clone() - .into_iter() - .map(|col_base_ref| col_base_ref.into()) - .collect(); - - let mut eq_columns = EqBaseTableColumnSets::new(); - eq_columns.add_predicate(EqPredicate::new( - col_base_refs[0].clone(), - col_base_refs[1].clone(), - )); - eq_columns.add_predicate(EqPredicate::new( - col_base_refs[2].clone(), - col_base_refs[3].clone(), - )); - let initial_selectivity = 1.0 / (3.0 * 5.0); - let semantic_correlation = SemanticCorrelation::new(eq_columns); - let schema = Schema::new(vec![]); - let column_refs = col_refs; - let input_correlation = Some(semantic_correlation); - - let eq1and2 = bin_op(BinOpType::Eq, col_ref(1), col_ref(2)); - let overall_selectivity = initial_selectivity - * test_get_join_selectivity( - &cost_model, - false, - JoinType::Inner, - eq1and2.clone(), - &schema, - &column_refs, - input_correlation, - ); - assert_approx_eq::assert_approx_eq!(overall_selectivity, 1.0 / (3.0 * 4.0 * 5.0)); + Ok(selectivity) } } diff --git a/optd-cost-model/src/cost_model.rs b/optd-cost-model/src/cost_model.rs index 4064562..b1b9aac 100644 --- a/optd-cost-model/src/cost_model.rs +++ b/optd-cost-model/src/cost_model.rs @@ -128,7 +128,7 @@ pub mod tests { common::{ nodes::ReprPredicateNode, predicates::{ - attr_ref_pred::AttributeRefPred, + attr_ref_pred::AttrRefPred, bin_op_pred::{BinOpPred, BinOpType}, cast_pred::CastPred, constant_pred::ConstantPred, @@ -183,7 +183,7 @@ pub mod tests { } pub fn attr_ref(table_id: TableId, attr_base_index: u64) -> 
ArcPredicateNode { - AttributeRefPred::new(table_id, attr_base_index).into_pred_node() + AttrRefPred::new(table_id, attr_base_index).into_pred_node() } pub fn cnst(value: Value) -> ArcPredicateNode { From a4ff5269d0d568edf66be6b6e404606b4cd8332f Mon Sep 17 00:00:00 2001 From: Yuanxin Cao Date: Mon, 18 Nov 2024 12:03:39 -0500 Subject: [PATCH 36/51] rename col -> attr --- optd-cost-model/src/cost/join.rs | 192 +++++++++++++++---------------- 1 file changed, 96 insertions(+), 96 deletions(-) diff --git a/optd-cost-model/src/cost/join.rs b/optd-cost-model/src/cost/join.rs index e313603..e55dc40 100644 --- a/optd-cost-model/src/cost/join.rs +++ b/optd-cost-model/src/cost/join.rs @@ -33,18 +33,18 @@ impl CostModelImpl { left_row_cnt: f64, right_row_cnt: f64, output_schema: Schema, - output_column_refs: GroupAttrRefs, + output_attr_refs: GroupAttrRefs, join_cond: ArcPredicateNode, - left_column_refs: GroupAttrRefs, - right_column_refs: GroupAttrRefs, + left_attr_refs: GroupAttrRefs, + right_attr_refs: GroupAttrRefs, ) -> CostModelResult { let selectivity = { - let input_correlation = self.get_input_correlation(left_column_refs, right_column_refs); + let input_correlation = self.get_input_correlation(left_attr_refs, right_attr_refs); self.get_join_selectivity_from_expr_tree( join_typ, join_cond, &output_schema, - output_column_refs.base_table_attr_refs(), + output_attr_refs.base_table_attr_refs(), input_correlation, left_row_cnt, right_row_cnt, @@ -63,28 +63,28 @@ impl CostModelImpl { left_keys: ListPred, right_keys: ListPred, output_schema: Schema, - output_column_refs: GroupAttrRefs, - left_column_refs: GroupAttrRefs, - right_column_refs: GroupAttrRefs, + output_attr_refs: GroupAttrRefs, + left_attr_refs: GroupAttrRefs, + right_attr_refs: GroupAttrRefs, ) -> CostModelResult { let selectivity = { let schema = output_schema; - let column_refs = output_column_refs; - let column_refs = column_refs.base_table_attr_refs(); - let left_col_cnt = 
left_column_refs.base_table_attr_refs().len(); + let attr_refs = output_attr_refs; + let attr_refs = attr_refs.base_table_attr_refs(); + let left_attr_cnt = left_attr_refs.base_table_attr_refs().len(); // there may be more than one expression tree in a group. // see comment in PredicateType::PhysicalFilter(_) for more information - let input_correlation = self.get_input_correlation(left_column_refs, right_column_refs); + let input_correlation = self.get_input_correlation(left_attr_refs, right_attr_refs); self.get_join_selectivity_from_keys( join_typ, left_keys, right_keys, &schema, - column_refs, + attr_refs, input_correlation, left_row_cnt, right_row_cnt, - left_col_cnt, + left_attr_cnt, ) .await? }; @@ -110,16 +110,16 @@ impl CostModelImpl { left_keys: ListPred, right_keys: ListPred, schema: &Schema, - column_refs: &AttrRefs, + attr_refs: &AttrRefs, input_correlation: Option, left_row_cnt: f64, right_row_cnt: f64, - left_col_cnt: usize, + left_attr_cnt: usize, ) -> CostModelResult { assert!(left_keys.len() == right_keys.len()); // I assume that the keys are already in the right order // s.t. the ith key of left_keys corresponds with the ith key of right_keys - let on_col_ref_pairs = left_keys + let on_attr_ref_pairs = left_keys .to_vec() .into_iter() .zip(right_keys.to_vec()) @@ -132,14 +132,14 @@ impl CostModelImpl { .collect_vec(); self.get_join_selectivity_core( join_typ, - on_col_ref_pairs, + on_attr_ref_pairs, None, schema, - column_refs, + attr_refs, input_correlation, left_row_cnt, right_row_cnt, - left_col_cnt, + left_attr_cnt, ) .await } @@ -147,34 +147,34 @@ impl CostModelImpl { /// The core logic of join selectivity which assumes we've already separated the expression /// into the on conditions and the filters. /// - /// Hash join and NLJ reference right table columns differently, hence the - /// `right_col_ref_offset` parameter. + /// Hash join and NLJ reference right table attributes differently, hence the + /// `right_attr_ref_offset` parameter. 
/// - /// For hash join, the right table columns indices are with respect to the right table, - /// which means #0 is the first column of the right table. + /// For hash join, the right table attributes indices are with respect to the right table, + /// which means #0 is the first attribute of the right table. /// - /// For NLJ, the right table columns indices are with respect to the output of the join. - /// For example, if the left table has 3 columns, the first column of the right table + /// For NLJ, the right table attributes indices are with respect to the output of the join. + /// For example, if the left table has 3 attributes, the first attribute of the right table /// is #3 instead of #0. #[allow(clippy::too_many_arguments)] async fn get_join_selectivity_core( &self, join_typ: JoinType, - on_col_ref_pairs: Vec<(AttrRefPred, AttrRefPred)>, + on_attr_ref_pairs: Vec<(AttrRefPred, AttrRefPred)>, filter_expr_tree: Option, schema: &Schema, - column_refs: &AttrRefs, + attr_refs: &AttrRefs, input_correlation: Option, left_row_cnt: f64, right_row_cnt: f64, - right_col_ref_offset: usize, + right_attr_ref_offset: usize, ) -> CostModelResult { let join_on_selectivity = self .get_join_on_selectivity( - &on_col_ref_pairs, - column_refs, + &on_attr_ref_pairs, + attr_refs, input_correlation, - right_col_ref_offset, + right_attr_ref_offset, ) .await?; // Currently, there is no difference in how we handle a join filter and a select filter, @@ -198,8 +198,8 @@ impl CostModelImpl { JoinType::RightOuter => f64::max(inner_join_selectivity, 1.0 / left_row_cnt), JoinType::Cross => { assert!( - on_col_ref_pairs.is_empty(), - "Cross joins should not have on columns" + on_attr_ref_pairs.is_empty(), + "Cross joins should not have on attributes" ); join_filter_selectivity } @@ -218,25 +218,25 @@ impl CostModelImpl { join_typ: JoinType, expr_tree: ArcPredicateNode, schema: &Schema, - column_refs: &AttrRefs, + attr_refs: &AttrRefs, input_correlation: Option, left_row_cnt: f64, 
right_row_cnt: f64, ) -> CostModelResult { if expr_tree.typ == PredicateType::LogOp(LogOpType::And) { - let mut on_col_ref_pairs = vec![]; + let mut on_attr_ref_pairs = vec![]; let mut filter_expr_trees = vec![]; for child_expr_tree in &expr_tree.children { - if let Some(on_col_ref_pair) = - Self::get_on_col_ref_pair(child_expr_tree.clone(), column_refs) + if let Some(on_attr_ref_pair) = + Self::get_on_attr_ref_pair(child_expr_tree.clone(), attr_refs) { - on_col_ref_pairs.push(on_col_ref_pair) + on_attr_ref_pairs.push(on_attr_ref_pair) } else { let child_expr = child_expr_tree.clone(); filter_expr_trees.push(child_expr); } } - assert!(on_col_ref_pairs.len() + filter_expr_trees.len() == expr_tree.children.len()); + assert!(on_attr_ref_pairs.len() + filter_expr_trees.len() == expr_tree.children.len()); let filter_expr_tree = if filter_expr_trees.is_empty() { None } else { @@ -244,10 +244,10 @@ impl CostModelImpl { }; self.get_join_selectivity_core( join_typ, - on_col_ref_pairs, + on_attr_ref_pairs, filter_expr_tree, schema, - column_refs, + attr_refs, input_correlation, left_row_cnt, right_row_cnt, @@ -256,14 +256,14 @@ impl CostModelImpl { .await } else { #[allow(clippy::collapsible_else_if)] - if let Some(on_col_ref_pair) = Self::get_on_col_ref_pair(expr_tree.clone(), column_refs) + if let Some(on_attr_ref_pair) = Self::get_on_attr_ref_pair(expr_tree.clone(), attr_refs) { self.get_join_selectivity_core( join_typ, - vec![on_col_ref_pair], + vec![on_attr_ref_pair], None, schema, - column_refs, + attr_refs, input_correlation, left_row_cnt, right_row_cnt, @@ -276,7 +276,7 @@ impl CostModelImpl { vec![], Some(expr_tree), schema, - column_refs, + attr_refs, input_correlation, left_row_cnt, right_row_cnt, @@ -287,29 +287,29 @@ impl CostModelImpl { } } - /// Check if an expr_tree is a join condition, returning the join on col ref pair if it is. + /// Check if an expr_tree is a join condition, returning the join on attr ref pair if it is. 
/// The reason the check and the info are in the same function is because their code is almost - /// identical. It only picks out equality conditions between two column refs on different + /// identical. It only picks out equality conditions between two attribute refs on different /// tables - fn get_on_col_ref_pair( + fn get_on_attr_ref_pair( expr_tree: ArcPredicateNode, - column_refs: &AttrRefs, + attr_refs: &AttrRefs, ) -> Option<(AttrRefPred, AttrRefPred)> { // 1. Check that it's equality if expr_tree.typ == PredicateType::BinOp(BinOpType::Eq) { let left_child = expr_tree.child(0); let right_child = expr_tree.child(1); - // 2. Check that both sides are column refs + // 2. Check that both sides are attribute refs if left_child.typ == PredicateType::AttrRef && right_child.typ == PredicateType::AttrRef { // 3. Check that both sides don't belong to the same table (if we don't know, that // means they don't belong) - let left_col_ref_expr = AttrRefPred::from_pred_node(left_child) + let left_attr_ref_expr = AttrRefPred::from_pred_node(left_child) .expect("we already checked that the type is AttrRef"); - let right_col_ref_expr = AttrRefPred::from_pred_node(right_child) + let right_attr_ref_expr = AttrRefPred::from_pred_node(right_child) .expect("we already checked that the type is AttrRef"); - let left_col_ref = &column_refs[left_col_ref_expr.attr_index() as usize]; - let right_col_ref = &column_refs[right_col_ref_expr.attr_index() as usize]; + let left_attr_ref = &attr_refs[left_attr_ref_expr.attr_index() as usize]; + let right_attr_ref = &attr_refs[right_attr_ref_expr.attr_index() as usize]; let is_same_table = if let ( AttrRef::BaseTableAttrRef(BaseTableAttrRef { table_id: left_table_id, @@ -319,14 +319,14 @@ impl CostModelImpl { table_id: right_table_id, .. 
}), - ) = (left_col_ref, right_col_ref) + ) = (left_attr_ref, right_attr_ref) { left_table_id == right_table_id } else { false }; if !is_same_table { - Some((left_col_ref_expr, right_col_ref_expr)) + Some((left_attr_ref_expr, right_attr_ref_expr)) } else { None } @@ -338,8 +338,8 @@ impl CostModelImpl { } } - /// Get the selectivity of one column eq predicate, e.g. colA = colB. - async fn get_join_selectivity_from_on_col_ref_pair( + /// Get the selectivity of one attribute eq predicate, e.g. attrA = attrB. + async fn get_join_selectivity_from_on_attr_ref_pair( &self, left: &AttrRef, right: &AttrRef, @@ -354,7 +354,7 @@ impl CostModelImpl { .get_attribute_comb_stats(base_attr_ref.table_id, &[base_attr_ref.attr_idx]) .await? { - Some(per_col_stats) => per_col_stats.ndistinct, + Some(per_attr_stats) => per_attr_stats.ndistinct, None => DEFAULT_NUM_DISTINCT, } } @@ -365,7 +365,7 @@ impl CostModelImpl { // using reduce(f64::min) is the idiomatic workaround to min() because // f64 does not implement Ord due to NaN - let selectivity = ndistincts.into_iter().map(|ndistinct| 1.0 / ndistinct as f64).reduce(f64::min).expect("reduce() only returns None if the iterator is empty, which is impossible since col_ref_exprs.len() == 2"); + let selectivity = ndistincts.into_iter().map(|ndistinct| 1.0 / ndistinct as f64).reduce(f64::min).expect("reduce() only returns None if the iterator is empty, which is impossible since attr_ref_exprs.len() == 2"); assert!( !selectivity.is_nan(), "it should be impossible for selectivity to be NaN since n-distinct is never 0" @@ -373,11 +373,11 @@ impl CostModelImpl { Ok(selectivity) } - /// Given a set of N columns involved in a multi-equality, find the total selectivity + /// Given a set of N attributes involved in a multi-equality, find the total selectivity /// of the multi-equality. /// - /// This is a generalization of get_join_selectivity_from_on_col_ref_pair(). 
- async fn get_join_selectivity_from_most_selective_columns( + /// This is a generalization of get_join_selectivity_from_on_attr_ref_pair(). + async fn get_join_selectivity_from_most_selective_attrs( &self, base_attr_refs: HashSet, ) -> CostModelResult { @@ -390,7 +390,7 @@ impl CostModelImpl { .get_attribute_comb_stats(base_attr_ref.table_id, &[base_attr_ref.attr_idx]) .await? { - Some(per_col_stats) => per_col_stats.ndistinct, + Some(per_attr_stats) => per_attr_stats.ndistinct, None => DEFAULT_NUM_DISTINCT, }; ndistincts.push(ndistinct); @@ -408,15 +408,15 @@ impl CostModelImpl { } /// A predicate set defines a "multi-equality graph", which is an unweighted undirected graph. - /// The nodes are columns while edges are predicates. The old graph is defined by - /// `past_eq_columns` while the `predicate` is the new addition to this graph. This + /// The nodes are attributes while edges are predicates. The old graph is defined by + /// `past_eq_attrs` while the `predicate` is the new addition to this graph. This /// unweighted undirected graph consists of a number of connected components, where each - /// connected component represents columns that are set to be equal to each other. Single + /// connected component represents attributes that are set to be equal to each other. Single /// nodes not connected to anything are considered standalone connected components. /// /// The selectivity of each connected component of N nodes is equal to the product of /// 1/ndistinct of the N-1 nodes with the highest ndistinct values. You can see this if you - /// imagine that all columns being joined are unique columns and that they follow the + /// imagine that all attributes being joined are unique attributes and that they follow the /// inclusion principle (every element of the smaller tables is present in the larger /// tables). When these assumptions are not true, the selectivity may not be completely /// accurate. However, it is still fairly accurate. 
@@ -428,11 +428,11 @@ impl CostModelImpl { /// function) and then the selectivity of the connected component after this join. The /// quotient is the "adjustment" factor. /// - /// NOTE: This function modifies `past_eq_columns` by adding `predicate` to it. + /// NOTE: This function modifies `past_eq_attrs` by adding `predicate` to it. async fn get_join_selectivity_adjustment_when_adding_to_multi_equality_graph( &self, predicate: &EqPredicate, - past_eq_columns: &mut SemanticCorrelation, + past_eq_attrs: &mut SemanticCorrelation, ) -> CostModelResult { if predicate.left == predicate.right { // self-join, TODO: is this correct? @@ -443,26 +443,26 @@ impl CostModelImpl { // // There are two cases: (1) adding `predicate` does not change the # of connected // components, and (2) adding `predicate` reduces the # of connected by 1. Note that - // columns not involved in any predicates are considered a part of the graph and are + // attributes not involved in any predicates are considered a part of the graph and are // a connected component on their own. let children_pred_sel = { - if past_eq_columns.is_eq(&predicate.left, &predicate.right) { - self.get_join_selectivity_from_most_selective_columns( - past_eq_columns.find_attrs_for_eq_attribute_set(&predicate.left), + if past_eq_attrs.is_eq(&predicate.left, &predicate.right) { + self.get_join_selectivity_from_most_selective_attrs( + past_eq_attrs.find_attrs_for_eq_attribute_set(&predicate.left), ) .await? } else { - let left_sel = if past_eq_columns.contains(&predicate.left) { - self.get_join_selectivity_from_most_selective_columns( - past_eq_columns.find_attrs_for_eq_attribute_set(&predicate.left), + let left_sel = if past_eq_attrs.contains(&predicate.left) { + self.get_join_selectivity_from_most_selective_attrs( + past_eq_attrs.find_attrs_for_eq_attribute_set(&predicate.left), ) .await? 
} else { 1.0 }; - let right_sel = if past_eq_columns.contains(&predicate.right) { - self.get_join_selectivity_from_most_selective_columns( - past_eq_columns.find_attrs_for_eq_attribute_set(&predicate.right), + let right_sel = if past_eq_attrs.contains(&predicate.right) { + self.get_join_selectivity_from_most_selective_attrs( + past_eq_attrs.find_attrs_for_eq_attribute_set(&predicate.right), ) .await? } else { @@ -472,12 +472,12 @@ impl CostModelImpl { } }; - // Add predicate to past_eq_columns and compute the selectivity of the connected component + // Add predicate to past_eq_attrs and compute the selectivity of the connected component // it creates. - past_eq_columns.add_predicate(predicate.clone()); + past_eq_attrs.add_predicate(predicate.clone()); let new_pred_sel = { - let cols = past_eq_columns.find_attrs_for_eq_attribute_set(&predicate.left); - self.get_join_selectivity_from_most_selective_columns(cols) + let attrs = past_eq_attrs.find_attrs_for_eq_attribute_set(&predicate.left); + self.get_join_selectivity_from_most_selective_attrs(attrs) } .await?; @@ -500,34 +500,34 @@ impl CostModelImpl { /// `get_join_selectivity_from_redundant_predicates`. 
async fn get_join_on_selectivity( &self, - on_col_ref_pairs: &[(AttrRefPred, AttrRefPred)], - column_refs: &AttrRefs, + on_attr_ref_pairs: &[(AttrRefPred, AttrRefPred)], + attr_refs: &AttrRefs, input_correlation: Option, - right_col_ref_offset: usize, + right_attr_ref_offset: usize, ) -> CostModelResult { - let mut past_eq_columns = input_correlation.unwrap_or_default(); + let mut past_eq_attrs = input_correlation.unwrap_or_default(); // Multiply the selectivities of all individual conditions together let mut selectivity = 1.0; - for on_col_ref_pair in on_col_ref_pairs { - let left_col_ref = &column_refs[on_col_ref_pair.0.attr_index() as usize]; - let right_col_ref = - &column_refs[on_col_ref_pair.1.attr_index() as usize + right_col_ref_offset]; + for on_attr_ref_pair in on_attr_ref_pairs { + let left_attr_ref = &attr_refs[on_attr_ref_pair.0.attr_index() as usize]; + let right_attr_ref = + &attr_refs[on_attr_ref_pair.1.attr_index() as usize + right_attr_ref_offset]; if let (AttrRef::BaseTableAttrRef(left), AttrRef::BaseTableAttrRef(right)) = - (left_col_ref, right_col_ref) + (left_attr_ref, right_attr_ref) { let predicate = EqPredicate::new(left.clone(), right.clone()); return self .get_join_selectivity_adjustment_when_adding_to_multi_equality_graph( &predicate, - &mut past_eq_columns, + &mut past_eq_attrs, ) .await; } selectivity *= self - .get_join_selectivity_from_on_col_ref_pair(left_col_ref, right_col_ref) + .get_join_selectivity_from_on_attr_ref_pair(left_attr_ref, right_attr_ref) .await?; } Ok(selectivity) From 0ba4132ab84e89673e83e4aa90d78e3c6b97bdc4 Mon Sep 17 00:00:00 2001 From: Yuanxin Cao Date: Mon, 18 Nov 2024 12:30:14 -0500 Subject: [PATCH 37/51] refactor join to not pass in logical props --- .../src/common/properties/attr_ref.rs | 8 +-- optd-cost-model/src/cost/join.rs | 49 ++++++++----------- optd-cost-model/src/cost_model.rs | 4 +- 3 files changed, 26 insertions(+), 35 deletions(-) diff --git a/optd-cost-model/src/common/properties/attr_ref.rs 
b/optd-cost-model/src/common/properties/attr_ref.rs index fea3270..5c73961 100644 --- a/optd-cost-model/src/common/properties/attr_ref.rs +++ b/optd-cost-model/src/common/properties/attr_ref.rs @@ -163,7 +163,7 @@ impl SemanticCorrelation { /// [`GroupAttrRefs`] represents the attributes of a group in a query. #[derive(Clone, Debug)] pub struct GroupAttrRefs { - attribute_refs: AttrRefs, + attr_refs: AttrRefs, /// Correlation of the output attributes of the group. output_correlation: Option, } @@ -171,13 +171,13 @@ pub struct GroupAttrRefs { impl GroupAttrRefs { pub fn new(attribute_refs: AttrRefs, output_correlation: Option) -> Self { Self { - attribute_refs, + attr_refs: attribute_refs, output_correlation, } } - pub fn base_table_attr_refs(&self) -> &AttrRefs { - &self.attribute_refs + pub fn attr_refs(&self) -> &AttrRefs { + &self.attr_refs } pub fn output_correlation(&self) -> Option<&SemanticCorrelation> { diff --git a/optd-cost-model/src/cost/join.rs b/optd-cost-model/src/cost/join.rs index e55dc40..cb11a6a 100644 --- a/optd-cost-model/src/cost/join.rs +++ b/optd-cost-model/src/cost/join.rs @@ -11,13 +11,11 @@ use crate::{ list_pred::ListPred, log_op_pred::{LogOpPred, LogOpType}, }, - properties::{ - attr_ref::{ - self, AttrRef, AttrRefs, BaseTableAttrRef, EqPredicate, GroupAttrRefs, - SemanticCorrelation, - }, - schema::Schema, + properties::attr_ref::{ + self, AttrRef, AttrRefs, BaseTableAttrRef, EqPredicate, GroupAttrRefs, + SemanticCorrelation, }, + types::GroupId, }, cost_model::CostModelImpl, stats::DEFAULT_NUM_DISTINCT, @@ -30,21 +28,23 @@ impl CostModelImpl { pub async fn get_nlj_row_cnt( &self, join_typ: JoinType, + group_id: GroupId, left_row_cnt: f64, right_row_cnt: f64, - output_schema: Schema, - output_attr_refs: GroupAttrRefs, + left_group_id: GroupId, + right_group_id: GroupId, join_cond: ArcPredicateNode, - left_attr_refs: GroupAttrRefs, - right_attr_refs: GroupAttrRefs, ) -> CostModelResult { let selectivity = { + let output_attr_refs = 
self.memo.get_attribute_ref(group_id); + let left_attr_refs = self.memo.get_attribute_ref(left_group_id); + let right_attr_refs = self.memo.get_attribute_ref(right_group_id); let input_correlation = self.get_input_correlation(left_attr_refs, right_attr_refs); + self.get_join_selectivity_from_expr_tree( join_typ, join_cond, - &output_schema, - output_attr_refs.base_table_attr_refs(), + output_attr_refs.attr_refs(), input_correlation, left_row_cnt, right_row_cnt, @@ -58,20 +58,19 @@ impl CostModelImpl { pub async fn get_hash_join_row_cnt( &self, join_typ: JoinType, + group_id: GroupId, left_row_cnt: f64, right_row_cnt: f64, + left_group_id: GroupId, + right_group_id: GroupId, left_keys: ListPred, right_keys: ListPred, - output_schema: Schema, - output_attr_refs: GroupAttrRefs, - left_attr_refs: GroupAttrRefs, - right_attr_refs: GroupAttrRefs, ) -> CostModelResult { let selectivity = { - let schema = output_schema; - let attr_refs = output_attr_refs; - let attr_refs = attr_refs.base_table_attr_refs(); - let left_attr_cnt = left_attr_refs.base_table_attr_refs().len(); + let output_attr_refs = self.memo.get_attribute_ref(group_id); + let left_attr_refs = self.memo.get_attribute_ref(left_group_id); + let right_attr_refs = self.memo.get_attribute_ref(right_group_id); + let left_attr_cnt = left_attr_refs.attr_refs().len(); // there may be more than one expression tree in a group. 
// see comment in PredicateType::PhysicalFilter(_) for more information let input_correlation = self.get_input_correlation(left_attr_refs, right_attr_refs); @@ -79,8 +78,7 @@ impl CostModelImpl { join_typ, left_keys, right_keys, - &schema, - attr_refs, + output_attr_refs.attr_refs(), input_correlation, left_row_cnt, right_row_cnt, @@ -109,7 +107,6 @@ impl CostModelImpl { join_typ: JoinType, left_keys: ListPred, right_keys: ListPred, - schema: &Schema, attr_refs: &AttrRefs, input_correlation: Option, left_row_cnt: f64, @@ -134,7 +131,6 @@ impl CostModelImpl { join_typ, on_attr_ref_pairs, None, - schema, attr_refs, input_correlation, left_row_cnt, @@ -162,7 +158,6 @@ impl CostModelImpl { join_typ: JoinType, on_attr_ref_pairs: Vec<(AttrRefPred, AttrRefPred)>, filter_expr_tree: Option, - schema: &Schema, attr_refs: &AttrRefs, input_correlation: Option, left_row_cnt: f64, @@ -217,7 +212,6 @@ impl CostModelImpl { &self, join_typ: JoinType, expr_tree: ArcPredicateNode, - schema: &Schema, attr_refs: &AttrRefs, input_correlation: Option, left_row_cnt: f64, @@ -246,7 +240,6 @@ impl CostModelImpl { join_typ, on_attr_ref_pairs, filter_expr_tree, - schema, attr_refs, input_correlation, left_row_cnt, @@ -262,7 +255,6 @@ impl CostModelImpl { join_typ, vec![on_attr_ref_pair], None, - schema, attr_refs, input_correlation, left_row_cnt, @@ -275,7 +267,6 @@ impl CostModelImpl { join_typ, vec![], Some(expr_tree), - schema, attr_refs, input_correlation, left_row_cnt, diff --git a/optd-cost-model/src/cost_model.rs b/optd-cost-model/src/cost_model.rs index b1b9aac..4d36c38 100644 --- a/optd-cost-model/src/cost_model.rs +++ b/optd-cost-model/src/cost_model.rs @@ -22,7 +22,7 @@ use crate::{ pub struct CostModelImpl { pub storage_manager: S, pub default_catalog_source: CatalogSource, - _memo: Arc, + pub memo: Arc, } impl CostModelImpl { @@ -35,7 +35,7 @@ impl CostModelImpl { Self { storage_manager, default_catalog_source, - _memo: memo, + memo, } } } From 
ab15f05df7cfc9e9dfb1818e3d4c80b51893e660 Mon Sep 17 00:00:00 2001 From: Yuanxin Cao Date: Mon, 18 Nov 2024 12:45:21 -0500 Subject: [PATCH 38/51] make statistics f64 instead of u64 --- optd-cost-model/src/cost/agg.rs | 18 +++++++++--------- optd-cost-model/src/cost/filter/controller.rs | 5 +---- optd-cost-model/src/cost/limit.rs | 2 +- optd-cost-model/src/lib.rs | 4 ++-- 4 files changed, 13 insertions(+), 16 deletions(-) diff --git a/optd-cost-model/src/cost/agg.rs b/optd-cost-model/src/cost/agg.rs index f288ebb..1bbf155 100644 --- a/optd-cost-model/src/cost/agg.rs +++ b/optd-cost-model/src/cost/agg.rs @@ -17,7 +17,7 @@ impl CostModelImpl { ) -> CostModelResult { let group_by = ListPred::from_pred_node(group_by).unwrap(); if group_by.is_empty() { - Ok(EstimatedStatistic(1)) + Ok(EstimatedStatistic(1.0)) } else { // Multiply the n-distinct of all the group by columns. // TODO: improve with multi-dimensional n-distinct @@ -57,7 +57,7 @@ impl CostModelImpl { } } } - Ok(EstimatedStatistic(row_cnt)) + Ok(EstimatedStatistic(row_cnt as f64)) } } } @@ -110,21 +110,21 @@ mod tests { let group_bys = empty_list(); assert_eq!( cost_model.get_agg_row_cnt(group_bys).await.unwrap(), - EstimatedStatistic(1) + EstimatedStatistic(1.0) ); // Group by single column should return the default value since there are no stats. let group_bys = list(vec![attr_ref(table_id, 0)]); assert_eq!( cost_model.get_agg_row_cnt(group_bys).await.unwrap(), - EstimatedStatistic(DEFAULT_NUM_DISTINCT) + EstimatedStatistic(DEFAULT_NUM_DISTINCT as f64) ); // Group by two columns should return the default value squared since there are no stats. 
let group_bys = list(vec![attr_ref(table_id, 0), attr_ref(table_id, 1)]); assert_eq!( cost_model.get_agg_row_cnt(group_bys).await.unwrap(), - EstimatedStatistic(DEFAULT_NUM_DISTINCT * DEFAULT_NUM_DISTINCT) + EstimatedStatistic((DEFAULT_NUM_DISTINCT * DEFAULT_NUM_DISTINCT) as f64) ); } @@ -193,14 +193,14 @@ mod tests { let group_bys = empty_list(); assert_eq!( cost_model.get_agg_row_cnt(group_bys).await.unwrap(), - EstimatedStatistic(1) + EstimatedStatistic(1.0) ); // Group by single column should return the n-distinct of the column. let group_bys = list(vec![attr_ref(table_id, attr1_base_idx)]); assert_eq!( cost_model.get_agg_row_cnt(group_bys).await.unwrap(), - EstimatedStatistic(attr1_ndistinct) + EstimatedStatistic(attr1_ndistinct as f64) ); // Group by two columns should return the product of the n-distinct of the columns. @@ -210,7 +210,7 @@ mod tests { ]); assert_eq!( cost_model.get_agg_row_cnt(group_bys).await.unwrap(), - EstimatedStatistic(attr1_ndistinct * attr2_ndistinct) + EstimatedStatistic((attr1_ndistinct * attr2_ndistinct) as f64) ); // Group by multiple columns should return the product of the n-distinct of the columns. If one of the columns @@ -222,7 +222,7 @@ mod tests { ]); assert_eq!( cost_model.get_agg_row_cnt(group_bys).await.unwrap(), - EstimatedStatistic(attr1_ndistinct * attr2_ndistinct * DEFAULT_NUM_DISTINCT) + EstimatedStatistic((attr1_ndistinct * attr2_ndistinct * DEFAULT_NUM_DISTINCT) as f64) ); } } diff --git a/optd-cost-model/src/cost/filter/controller.rs b/optd-cost-model/src/cost/filter/controller.rs index 3f6ef21..fd9769f 100644 --- a/optd-cost-model/src/cost/filter/controller.rs +++ b/optd-cost-model/src/cost/filter/controller.rs @@ -18,10 +18,7 @@ impl CostModelImpl { cond: ArcPredicateNode, ) -> CostModelResult { let selectivity = { self.get_filter_selectivity(cond).await? 
}; - Ok( - EstimatedStatistic((child_row_cnt.0 as f64 * selectivity) as u64) - .max(EstimatedStatistic(1)), - ) + Ok(EstimatedStatistic((child_row_cnt.0 * selectivity).max(1.0))) } pub async fn get_filter_selectivity( diff --git a/optd-cost-model/src/cost/limit.rs b/optd-cost-model/src/cost/limit.rs index 38e7550..c63c0e0 100644 --- a/optd-cost-model/src/cost/limit.rs +++ b/optd-cost-model/src/cost/limit.rs @@ -22,7 +22,7 @@ impl CostModelImpl { if fetch == u64::MAX { Ok(child_row_cnt) } else { - Ok(EstimatedStatistic(child_row_cnt.0.min(fetch))) + Ok(EstimatedStatistic(child_row_cnt.0.min(fetch as f64))) } } } diff --git a/optd-cost-model/src/lib.rs b/optd-cost-model/src/lib.rs index e6002e6..13774b2 100644 --- a/optd-cost-model/src/lib.rs +++ b/optd-cost-model/src/lib.rs @@ -33,8 +33,8 @@ pub struct Cost(pub Vec); /// Estimated statistic calculated by the cost model. /// It is the estimated output row count of the targeted expression. -#[derive(Eq, Ord, PartialEq, PartialOrd, Debug)] -pub struct EstimatedStatistic(pub u64); +#[derive(PartialEq, PartialOrd, Debug)] +pub struct EstimatedStatistic(pub f64); pub type CostModelResult = Result; From b682c73541e99e2e924436d60a4a5bbcd3d1fbb3 Mon Sep 17 00:00:00 2001 From: Yuanxin Cao Date: Mon, 18 Nov 2024 13:16:09 -0500 Subject: [PATCH 39/51] split join into multiple files --- optd-cost-model/src/cost/join/hash_join.rs | 93 ++++++ optd-cost-model/src/cost/{ => join}/join.rs | 301 ++++-------------- optd-cost-model/src/cost/join/mod.rs | 3 + .../src/cost/join/nested_loop_join.rs | 123 +++++++ 4 files changed, 278 insertions(+), 242 deletions(-) create mode 100644 optd-cost-model/src/cost/join/hash_join.rs rename optd-cost-model/src/cost/{ => join}/join.rs (58%) create mode 100644 optd-cost-model/src/cost/join/mod.rs create mode 100644 optd-cost-model/src/cost/join/nested_loop_join.rs diff --git a/optd-cost-model/src/cost/join/hash_join.rs b/optd-cost-model/src/cost/join/hash_join.rs new file mode 100644 index 
0000000..a0f39a7 --- /dev/null +++ b/optd-cost-model/src/cost/join/hash_join.rs @@ -0,0 +1,93 @@ +use itertools::Itertools; + +use crate::{ + common::{ + nodes::{JoinType, ReprPredicateNode}, + predicates::{attr_ref_pred::AttrRefPred, list_pred::ListPred}, + properties::attr_ref::{AttrRefs, SemanticCorrelation}, + types::GroupId, + }, + cost_model::CostModelImpl, + storage::CostModelStorageManager, + CostModelResult, EstimatedStatistic, +}; + +use super::join::get_input_correlation; + +impl CostModelImpl { + #[allow(clippy::too_many_arguments)] + pub async fn get_hash_join_row_cnt( + &self, + join_typ: JoinType, + group_id: GroupId, + left_row_cnt: f64, + right_row_cnt: f64, + left_group_id: GroupId, + right_group_id: GroupId, + left_keys: ListPred, + right_keys: ListPred, + ) -> CostModelResult { + let selectivity = { + let output_attr_refs = self.memo.get_attribute_ref(group_id); + let left_attr_refs = self.memo.get_attribute_ref(left_group_id); + let right_attr_refs = self.memo.get_attribute_ref(right_group_id); + let left_attr_cnt = left_attr_refs.attr_refs().len(); + // there may be more than one expression tree in a group. + // see comment in PredicateType::PhysicalFilter(_) for more information + let input_correlation = get_input_correlation(left_attr_refs, right_attr_refs); + self.get_hash_join_selectivity( + join_typ, + left_keys, + right_keys, + output_attr_refs.attr_refs(), + input_correlation, + left_row_cnt, + right_row_cnt, + left_attr_cnt, + ) + .await? 
+ }; + Ok(EstimatedStatistic( + (left_row_cnt * right_row_cnt * selectivity).max(1.0), + )) + } + + #[allow(clippy::too_many_arguments)] + async fn get_hash_join_selectivity( + &self, + join_typ: JoinType, + left_keys: ListPred, + right_keys: ListPred, + attr_refs: &AttrRefs, + input_correlation: Option, + left_row_cnt: f64, + right_row_cnt: f64, + left_attr_cnt: usize, + ) -> CostModelResult { + assert!(left_keys.len() == right_keys.len()); + // I assume that the keys are already in the right order + // s.t. the ith key of left_keys corresponds with the ith key of right_keys + let on_attr_ref_pairs = left_keys + .to_vec() + .into_iter() + .zip(right_keys.to_vec()) + .map(|(left_key, right_key)| { + ( + AttrRefPred::from_pred_node(left_key).expect("keys should be AttrRefPreds"), + AttrRefPred::from_pred_node(right_key).expect("keys should be AttrRefPreds"), + ) + }) + .collect_vec(); + self.get_join_selectivity_core( + join_typ, + on_attr_ref_pairs, + None, + attr_refs, + input_correlation, + left_row_cnt, + right_row_cnt, + left_attr_cnt, + ) + .await + } +} diff --git a/optd-cost-model/src/cost/join.rs b/optd-cost-model/src/cost/join/join.rs similarity index 58% rename from optd-cost-model/src/cost/join.rs rename to optd-cost-model/src/cost/join/join.rs index cb11a6a..b2b7e49 100644 --- a/optd-cost-model/src/cost/join.rs +++ b/optd-cost-model/src/cost/join/join.rs @@ -23,123 +23,67 @@ use crate::{ CostModelResult, }; -impl CostModelImpl { - #[allow(clippy::too_many_arguments)] - pub async fn get_nlj_row_cnt( - &self, - join_typ: JoinType, - group_id: GroupId, - left_row_cnt: f64, - right_row_cnt: f64, - left_group_id: GroupId, - right_group_id: GroupId, - join_cond: ArcPredicateNode, - ) -> CostModelResult { - let selectivity = { - let output_attr_refs = self.memo.get_attribute_ref(group_id); - let left_attr_refs = self.memo.get_attribute_ref(left_group_id); - let right_attr_refs = self.memo.get_attribute_ref(right_group_id); - let input_correlation = 
self.get_input_correlation(left_attr_refs, right_attr_refs); - - self.get_join_selectivity_from_expr_tree( - join_typ, - join_cond, - output_attr_refs.attr_refs(), - input_correlation, - left_row_cnt, - right_row_cnt, - ) - .await? - }; - Ok((left_row_cnt * right_row_cnt * selectivity).max(1.0)) - } - - #[allow(clippy::too_many_arguments)] - pub async fn get_hash_join_row_cnt( - &self, - join_typ: JoinType, - group_id: GroupId, - left_row_cnt: f64, - right_row_cnt: f64, - left_group_id: GroupId, - right_group_id: GroupId, - left_keys: ListPred, - right_keys: ListPred, - ) -> CostModelResult { - let selectivity = { - let output_attr_refs = self.memo.get_attribute_ref(group_id); - let left_attr_refs = self.memo.get_attribute_ref(left_group_id); - let right_attr_refs = self.memo.get_attribute_ref(right_group_id); - let left_attr_cnt = left_attr_refs.attr_refs().len(); - // there may be more than one expression tree in a group. - // see comment in PredicateType::PhysicalFilter(_) for more information - let input_correlation = self.get_input_correlation(left_attr_refs, right_attr_refs); - self.get_join_selectivity_from_keys( - join_typ, - left_keys, - right_keys, - output_attr_refs.attr_refs(), - input_correlation, - left_row_cnt, - right_row_cnt, - left_attr_cnt, - ) - .await? 
- }; - Ok((left_row_cnt * right_row_cnt * selectivity).max(1.0)) - } - - fn get_input_correlation( - &self, - left_prop: GroupAttrRefs, - right_prop: GroupAttrRefs, - ) -> Option { - SemanticCorrelation::merge( - left_prop.output_correlation().cloned(), - right_prop.output_correlation().cloned(), - ) - } +pub(crate) fn get_input_correlation( + left_prop: GroupAttrRefs, + right_prop: GroupAttrRefs, +) -> Option { + SemanticCorrelation::merge( + left_prop.output_correlation().cloned(), + right_prop.output_correlation().cloned(), + ) +} - /// A wrapper to convert the join keys to the format expected by get_join_selectivity_core() - #[allow(clippy::too_many_arguments)] - async fn get_join_selectivity_from_keys( - &self, - join_typ: JoinType, - left_keys: ListPred, - right_keys: ListPred, - attr_refs: &AttrRefs, - input_correlation: Option, - left_row_cnt: f64, - right_row_cnt: f64, - left_attr_cnt: usize, - ) -> CostModelResult { - assert!(left_keys.len() == right_keys.len()); - // I assume that the keys are already in the right order - // s.t. the ith key of left_keys corresponds with the ith key of right_keys - let on_attr_ref_pairs = left_keys - .to_vec() - .into_iter() - .zip(right_keys.to_vec()) - .map(|(left_key, right_key)| { - ( - AttrRefPred::from_pred_node(left_key).expect("keys should be AttrRefPreds"), - AttrRefPred::from_pred_node(right_key).expect("keys should be AttrRefPreds"), - ) - }) - .collect_vec(); - self.get_join_selectivity_core( - join_typ, - on_attr_ref_pairs, - None, - attr_refs, - input_correlation, - left_row_cnt, - right_row_cnt, - left_attr_cnt, - ) - .await +/// Check if an expr_tree is a join condition, returning the join on attr ref pair if it is. +/// The reason the check and the info are in the same function is because their code is almost +/// identical. 
It only picks out equality conditions between two attribute refs on different +/// tables +pub(crate) fn get_on_attr_ref_pair( + expr_tree: ArcPredicateNode, + attr_refs: &AttrRefs, +) -> Option<(AttrRefPred, AttrRefPred)> { + // 1. Check that it's equality + if expr_tree.typ == PredicateType::BinOp(BinOpType::Eq) { + let left_child = expr_tree.child(0); + let right_child = expr_tree.child(1); + // 2. Check that both sides are attribute refs + if left_child.typ == PredicateType::AttrRef && right_child.typ == PredicateType::AttrRef { + // 3. Check that both sides don't belong to the same table (if we don't know, that + // means they don't belong) + let left_attr_ref_expr = AttrRefPred::from_pred_node(left_child) + .expect("we already checked that the type is AttrRef"); + let right_attr_ref_expr = AttrRefPred::from_pred_node(right_child) + .expect("we already checked that the type is AttrRef"); + let left_attr_ref = &attr_refs[left_attr_ref_expr.attr_index() as usize]; + let right_attr_ref = &attr_refs[right_attr_ref_expr.attr_index() as usize]; + let is_same_table = if let ( + AttrRef::BaseTableAttrRef(BaseTableAttrRef { + table_id: left_table_id, + .. + }), + AttrRef::BaseTableAttrRef(BaseTableAttrRef { + table_id: right_table_id, + .. + }), + ) = (left_attr_ref, right_attr_ref) + { + left_table_id == right_table_id + } else { + false + }; + if !is_same_table { + Some((left_attr_ref_expr, right_attr_ref_expr)) + } else { + None + } + } else { + None + } + } else { + None } +} +impl CostModelImpl { /// The core logic of join selectivity which assumes we've already separated the expression /// into the on conditions and the filters. /// @@ -153,7 +97,7 @@ impl CostModelImpl { /// For example, if the left table has 3 attributes, the first attribute of the right table /// is #3 instead of #0. 
#[allow(clippy::too_many_arguments)] - async fn get_join_selectivity_core( + pub(crate) async fn get_join_selectivity_core( &self, join_typ: JoinType, on_attr_ref_pairs: Vec<(AttrRefPred, AttrRefPred)>, @@ -202,133 +146,6 @@ impl CostModelImpl { }) } - /// The expr_tree input must be a "mixed expression tree", just like with - /// `get_filter_selectivity`. - /// - /// This is a "wrapper" to separate the equality conditions from the filter conditions before - /// calling the "main" `get_join_selectivity_core` function. - #[allow(clippy::too_many_arguments)] - async fn get_join_selectivity_from_expr_tree( - &self, - join_typ: JoinType, - expr_tree: ArcPredicateNode, - attr_refs: &AttrRefs, - input_correlation: Option, - left_row_cnt: f64, - right_row_cnt: f64, - ) -> CostModelResult { - if expr_tree.typ == PredicateType::LogOp(LogOpType::And) { - let mut on_attr_ref_pairs = vec![]; - let mut filter_expr_trees = vec![]; - for child_expr_tree in &expr_tree.children { - if let Some(on_attr_ref_pair) = - Self::get_on_attr_ref_pair(child_expr_tree.clone(), attr_refs) - { - on_attr_ref_pairs.push(on_attr_ref_pair) - } else { - let child_expr = child_expr_tree.clone(); - filter_expr_trees.push(child_expr); - } - } - assert!(on_attr_ref_pairs.len() + filter_expr_trees.len() == expr_tree.children.len()); - let filter_expr_tree = if filter_expr_trees.is_empty() { - None - } else { - Some(LogOpPred::new(LogOpType::And, filter_expr_trees).into_pred_node()) - }; - self.get_join_selectivity_core( - join_typ, - on_attr_ref_pairs, - filter_expr_tree, - attr_refs, - input_correlation, - left_row_cnt, - right_row_cnt, - 0, - ) - .await - } else { - #[allow(clippy::collapsible_else_if)] - if let Some(on_attr_ref_pair) = Self::get_on_attr_ref_pair(expr_tree.clone(), attr_refs) - { - self.get_join_selectivity_core( - join_typ, - vec![on_attr_ref_pair], - None, - attr_refs, - input_correlation, - left_row_cnt, - right_row_cnt, - 0, - ) - .await - } else { - self.get_join_selectivity_core( 
- join_typ, - vec![], - Some(expr_tree), - attr_refs, - input_correlation, - left_row_cnt, - right_row_cnt, - 0, - ) - .await - } - } - } - - /// Check if an expr_tree is a join condition, returning the join on attr ref pair if it is. - /// The reason the check and the info are in the same function is because their code is almost - /// identical. It only picks out equality conditions between two attribute refs on different - /// tables - fn get_on_attr_ref_pair( - expr_tree: ArcPredicateNode, - attr_refs: &AttrRefs, - ) -> Option<(AttrRefPred, AttrRefPred)> { - // 1. Check that it's equality - if expr_tree.typ == PredicateType::BinOp(BinOpType::Eq) { - let left_child = expr_tree.child(0); - let right_child = expr_tree.child(1); - // 2. Check that both sides are attribute refs - if left_child.typ == PredicateType::AttrRef && right_child.typ == PredicateType::AttrRef - { - // 3. Check that both sides don't belong to the same table (if we don't know, that - // means they don't belong) - let left_attr_ref_expr = AttrRefPred::from_pred_node(left_child) - .expect("we already checked that the type is AttrRef"); - let right_attr_ref_expr = AttrRefPred::from_pred_node(right_child) - .expect("we already checked that the type is AttrRef"); - let left_attr_ref = &attr_refs[left_attr_ref_expr.attr_index() as usize]; - let right_attr_ref = &attr_refs[right_attr_ref_expr.attr_index() as usize]; - let is_same_table = if let ( - AttrRef::BaseTableAttrRef(BaseTableAttrRef { - table_id: left_table_id, - .. - }), - AttrRef::BaseTableAttrRef(BaseTableAttrRef { - table_id: right_table_id, - .. - }), - ) = (left_attr_ref, right_attr_ref) - { - left_table_id == right_table_id - } else { - false - }; - if !is_same_table { - Some((left_attr_ref_expr, right_attr_ref_expr)) - } else { - None - } - } else { - None - } - } else { - None - } - } - /// Get the selectivity of one attribute eq predicate, e.g. attrA = attrB. 
async fn get_join_selectivity_from_on_attr_ref_pair( &self, diff --git a/optd-cost-model/src/cost/join/mod.rs b/optd-cost-model/src/cost/join/mod.rs new file mode 100644 index 0000000..54ba481 --- /dev/null +++ b/optd-cost-model/src/cost/join/mod.rs @@ -0,0 +1,3 @@ +pub mod hash_join; +pub mod join; +pub mod nested_loop_join; diff --git a/optd-cost-model/src/cost/join/nested_loop_join.rs b/optd-cost-model/src/cost/join/nested_loop_join.rs new file mode 100644 index 0000000..58e0ae1 --- /dev/null +++ b/optd-cost-model/src/cost/join/nested_loop_join.rs @@ -0,0 +1,123 @@ +use crate::{ + common::{ + nodes::{ArcPredicateNode, JoinType, PredicateType, ReprPredicateNode}, + predicates::log_op_pred::{LogOpPred, LogOpType}, + properties::attr_ref::{AttrRefs, SemanticCorrelation}, + types::GroupId, + }, + cost::join::join::get_on_attr_ref_pair, + cost_model::CostModelImpl, + storage::CostModelStorageManager, + CostModelResult, EstimatedStatistic, +}; + +use super::join::get_input_correlation; + +impl CostModelImpl { + #[allow(clippy::too_many_arguments)] + pub async fn get_nlj_row_cnt( + &self, + join_typ: JoinType, + group_id: GroupId, + left_row_cnt: f64, + right_row_cnt: f64, + left_group_id: GroupId, + right_group_id: GroupId, + join_cond: ArcPredicateNode, + ) -> CostModelResult { + let selectivity = { + let output_attr_refs = self.memo.get_attribute_ref(group_id); + let left_attr_refs = self.memo.get_attribute_ref(left_group_id); + let right_attr_refs = self.memo.get_attribute_ref(right_group_id); + let input_correlation = get_input_correlation(left_attr_refs, right_attr_refs); + + self.get_nlj_join_selectivity( + join_typ, + join_cond, + output_attr_refs.attr_refs(), + input_correlation, + left_row_cnt, + right_row_cnt, + ) + .await? + }; + Ok(EstimatedStatistic( + (left_row_cnt * right_row_cnt * selectivity).max(1.0), + )) + } + + /// The expr_tree input must be a "mixed expression tree", just like with + /// `get_filter_selectivity`. 
+ /// + /// This is a "wrapper" to separate the equality conditions from the filter conditions before + /// calling the "main" `get_join_selectivity_core` function. + #[allow(clippy::too_many_arguments)] + async fn get_nlj_join_selectivity( + &self, + join_typ: JoinType, + expr_tree: ArcPredicateNode, + attr_refs: &AttrRefs, + input_correlation: Option, + left_row_cnt: f64, + right_row_cnt: f64, + ) -> CostModelResult { + if expr_tree.typ == PredicateType::LogOp(LogOpType::And) { + let mut on_attr_ref_pairs = vec![]; + let mut filter_expr_trees = vec![]; + for child_expr_tree in &expr_tree.children { + if let Some(on_attr_ref_pair) = + get_on_attr_ref_pair(child_expr_tree.clone(), attr_refs) + { + on_attr_ref_pairs.push(on_attr_ref_pair) + } else { + let child_expr = child_expr_tree.clone(); + filter_expr_trees.push(child_expr); + } + } + assert!(on_attr_ref_pairs.len() + filter_expr_trees.len() == expr_tree.children.len()); + let filter_expr_tree = if filter_expr_trees.is_empty() { + None + } else { + Some(LogOpPred::new(LogOpType::And, filter_expr_trees).into_pred_node()) + }; + self.get_join_selectivity_core( + join_typ, + on_attr_ref_pairs, + filter_expr_tree, + attr_refs, + input_correlation, + left_row_cnt, + right_row_cnt, + 0, + ) + .await + } else { + #[allow(clippy::collapsible_else_if)] + if let Some(on_attr_ref_pair) = get_on_attr_ref_pair(expr_tree.clone(), attr_refs) { + self.get_join_selectivity_core( + join_typ, + vec![on_attr_ref_pair], + None, + attr_refs, + input_correlation, + left_row_cnt, + right_row_cnt, + 0, + ) + .await + } else { + self.get_join_selectivity_core( + join_typ, + vec![], + Some(expr_tree), + attr_refs, + input_correlation, + left_row_cnt, + right_row_cnt, + 0, + ) + .await + } + } + } +} From 5d7314100bdff9407018cabf627411529f9f6d8b Mon Sep 17 00:00:00 2001 From: Yuanxin Cao Date: Mon, 18 Nov 2024 13:40:32 -0500 Subject: [PATCH 40/51] reorganize join --- optd-cost-model/src/cost/join/hash_join.rs | 43 +---- 
optd-cost-model/src/cost/join/join.rs | 166 ++++++++++++------ optd-cost-model/src/cost/join/mod.rs | 68 +++++++ .../src/cost/join/nested_loop_join.rs | 80 +-------- 4 files changed, 183 insertions(+), 174 deletions(-) diff --git a/optd-cost-model/src/cost/join/hash_join.rs b/optd-cost-model/src/cost/join/hash_join.rs index a0f39a7..c4049db 100644 --- a/optd-cost-model/src/cost/join/hash_join.rs +++ b/optd-cost-model/src/cost/join/hash_join.rs @@ -12,7 +12,7 @@ use crate::{ CostModelResult, EstimatedStatistic, }; -use super::join::get_input_correlation; +use super::get_input_correlation; impl CostModelImpl { #[allow(clippy::too_many_arguments)] @@ -35,7 +35,7 @@ impl CostModelImpl { // there may be more than one expression tree in a group. // see comment in PredicateType::PhysicalFilter(_) for more information let input_correlation = get_input_correlation(left_attr_refs, right_attr_refs); - self.get_hash_join_selectivity( + self.get_join_selectivity_from_keys( join_typ, left_keys, right_keys, @@ -51,43 +51,4 @@ impl CostModelImpl { (left_row_cnt * right_row_cnt * selectivity).max(1.0), )) } - - #[allow(clippy::too_many_arguments)] - async fn get_hash_join_selectivity( - &self, - join_typ: JoinType, - left_keys: ListPred, - right_keys: ListPred, - attr_refs: &AttrRefs, - input_correlation: Option, - left_row_cnt: f64, - right_row_cnt: f64, - left_attr_cnt: usize, - ) -> CostModelResult { - assert!(left_keys.len() == right_keys.len()); - // I assume that the keys are already in the right order - // s.t. 
the ith key of left_keys corresponds with the ith key of right_keys - let on_attr_ref_pairs = left_keys - .to_vec() - .into_iter() - .zip(right_keys.to_vec()) - .map(|(left_key, right_key)| { - ( - AttrRefPred::from_pred_node(left_key).expect("keys should be AttrRefPreds"), - AttrRefPred::from_pred_node(right_key).expect("keys should be AttrRefPreds"), - ) - }) - .collect_vec(); - self.get_join_selectivity_core( - join_typ, - on_attr_ref_pairs, - None, - attr_refs, - input_correlation, - left_row_cnt, - right_row_cnt, - left_attr_cnt, - ) - .await - } } diff --git a/optd-cost-model/src/cost/join/join.rs b/optd-cost-model/src/cost/join/join.rs index b2b7e49..28d3029 100644 --- a/optd-cost-model/src/cost/join/join.rs +++ b/optd-cost-model/src/cost/join/join.rs @@ -17,73 +17,129 @@ use crate::{ }, types::GroupId, }, + cost::join::get_on_attr_ref_pair, cost_model::CostModelImpl, stats::DEFAULT_NUM_DISTINCT, storage::CostModelStorageManager, CostModelResult, }; -pub(crate) fn get_input_correlation( - left_prop: GroupAttrRefs, - right_prop: GroupAttrRefs, -) -> Option { - SemanticCorrelation::merge( - left_prop.output_correlation().cloned(), - right_prop.output_correlation().cloned(), - ) -} - -/// Check if an expr_tree is a join condition, returning the join on attr ref pair if it is. -/// The reason the check and the info are in the same function is because their code is almost -/// identical. It only picks out equality conditions between two attribute refs on different -/// tables -pub(crate) fn get_on_attr_ref_pair( - expr_tree: ArcPredicateNode, - attr_refs: &AttrRefs, -) -> Option<(AttrRefPred, AttrRefPred)> { - // 1. Check that it's equality - if expr_tree.typ == PredicateType::BinOp(BinOpType::Eq) { - let left_child = expr_tree.child(0); - let right_child = expr_tree.child(1); - // 2. Check that both sides are attribute refs - if left_child.typ == PredicateType::AttrRef && right_child.typ == PredicateType::AttrRef { - // 3. 
Check that both sides don't belong to the same table (if we don't know, that - // means they don't belong) - let left_attr_ref_expr = AttrRefPred::from_pred_node(left_child) - .expect("we already checked that the type is AttrRef"); - let right_attr_ref_expr = AttrRefPred::from_pred_node(right_child) - .expect("we already checked that the type is AttrRef"); - let left_attr_ref = &attr_refs[left_attr_ref_expr.attr_index() as usize]; - let right_attr_ref = &attr_refs[right_attr_ref_expr.attr_index() as usize]; - let is_same_table = if let ( - AttrRef::BaseTableAttrRef(BaseTableAttrRef { - table_id: left_table_id, - .. - }), - AttrRef::BaseTableAttrRef(BaseTableAttrRef { - table_id: right_table_id, - .. - }), - ) = (left_attr_ref, right_attr_ref) - { - left_table_id == right_table_id +impl CostModelImpl { + /// The expr_tree input must be a "mixed expression tree", just like with + /// `get_filter_selectivity`. + /// + /// This is a "wrapper" to separate the equality conditions from the filter conditions before + /// calling the "main" `get_join_selectivity_core` function. 
+ #[allow(clippy::too_many_arguments)] + pub(crate) async fn get_join_selectivity_from_expr_tree( + &self, + join_typ: JoinType, + expr_tree: ArcPredicateNode, + attr_refs: &AttrRefs, + input_correlation: Option, + left_row_cnt: f64, + right_row_cnt: f64, + ) -> CostModelResult { + if expr_tree.typ == PredicateType::LogOp(LogOpType::And) { + let mut on_attr_ref_pairs = vec![]; + let mut filter_expr_trees = vec![]; + for child_expr_tree in &expr_tree.children { + if let Some(on_attr_ref_pair) = + get_on_attr_ref_pair(child_expr_tree.clone(), attr_refs) + { + on_attr_ref_pairs.push(on_attr_ref_pair) + } else { + let child_expr = child_expr_tree.clone(); + filter_expr_trees.push(child_expr); + } + } + assert!(on_attr_ref_pairs.len() + filter_expr_trees.len() == expr_tree.children.len()); + let filter_expr_tree = if filter_expr_trees.is_empty() { + None } else { - false + Some(LogOpPred::new(LogOpType::And, filter_expr_trees).into_pred_node()) }; - if !is_same_table { - Some((left_attr_ref_expr, right_attr_ref_expr)) + self.get_join_selectivity_core( + join_typ, + on_attr_ref_pairs, + filter_expr_tree, + attr_refs, + input_correlation, + left_row_cnt, + right_row_cnt, + 0, + ) + .await + } else { + #[allow(clippy::collapsible_else_if)] + if let Some(on_attr_ref_pair) = get_on_attr_ref_pair(expr_tree.clone(), attr_refs) { + self.get_join_selectivity_core( + join_typ, + vec![on_attr_ref_pair], + None, + attr_refs, + input_correlation, + left_row_cnt, + right_row_cnt, + 0, + ) + .await } else { - None + self.get_join_selectivity_core( + join_typ, + vec![], + Some(expr_tree), + attr_refs, + input_correlation, + left_row_cnt, + right_row_cnt, + 0, + ) + .await } - } else { - None } - } else { - None } -} -impl CostModelImpl { + /// A wrapper to convert the join keys to the format expected by get_join_selectivity_core() + #[allow(clippy::too_many_arguments)] + pub(crate) async fn get_join_selectivity_from_keys( + &self, + join_typ: JoinType, + left_keys: ListPred, + 
right_keys: ListPred, + attr_refs: &AttrRefs, + input_correlation: Option, + left_row_cnt: f64, + right_row_cnt: f64, + left_attr_cnt: usize, + ) -> CostModelResult { + assert!(left_keys.len() == right_keys.len()); + // I assume that the keys are already in the right order + // s.t. the ith key of left_keys corresponds with the ith key of right_keys + let on_attr_ref_pairs = left_keys + .to_vec() + .into_iter() + .zip(right_keys.to_vec()) + .map(|(left_key, right_key)| { + ( + AttrRefPred::from_pred_node(left_key).expect("keys should be AttrRefPreds"), + AttrRefPred::from_pred_node(right_key).expect("keys should be AttrRefPreds"), + ) + }) + .collect_vec(); + self.get_join_selectivity_core( + join_typ, + on_attr_ref_pairs, + None, + attr_refs, + input_correlation, + left_row_cnt, + right_row_cnt, + left_attr_cnt, + ) + .await + } + /// The core logic of join selectivity which assumes we've already separated the expression /// into the on conditions and the filters. /// @@ -97,7 +153,7 @@ impl CostModelImpl { /// For example, if the left table has 3 attributes, the first attribute of the right table /// is #3 instead of #0. 
#[allow(clippy::too_many_arguments)] - pub(crate) async fn get_join_selectivity_core( + async fn get_join_selectivity_core( &self, join_typ: JoinType, on_attr_ref_pairs: Vec<(AttrRefPred, AttrRefPred)>, diff --git a/optd-cost-model/src/cost/join/mod.rs b/optd-cost-model/src/cost/join/mod.rs index 54ba481..8b29661 100644 --- a/optd-cost-model/src/cost/join/mod.rs +++ b/optd-cost-model/src/cost/join/mod.rs @@ -1,3 +1,71 @@ +use crate::common::{ + nodes::{ArcPredicateNode, PredicateType, ReprPredicateNode}, + predicates::{attr_ref_pred::AttrRefPred, bin_op_pred::BinOpType}, + properties::attr_ref::{ + AttrRef, AttrRefs, BaseTableAttrRef, GroupAttrRefs, SemanticCorrelation, + }, +}; + pub mod hash_join; pub mod join; pub mod nested_loop_join; + +pub(crate) fn get_input_correlation( + left_prop: GroupAttrRefs, + right_prop: GroupAttrRefs, +) -> Option { + SemanticCorrelation::merge( + left_prop.output_correlation().cloned(), + right_prop.output_correlation().cloned(), + ) +} + +/// Check if an expr_tree is a join condition, returning the join on attr ref pair if it is. +/// The reason the check and the info are in the same function is because their code is almost +/// identical. It only picks out equality conditions between two attribute refs on different +/// tables +pub(crate) fn get_on_attr_ref_pair( + expr_tree: ArcPredicateNode, + attr_refs: &AttrRefs, +) -> Option<(AttrRefPred, AttrRefPred)> { + // 1. Check that it's equality + if expr_tree.typ == PredicateType::BinOp(BinOpType::Eq) { + let left_child = expr_tree.child(0); + let right_child = expr_tree.child(1); + // 2. Check that both sides are attribute refs + if left_child.typ == PredicateType::AttrRef && right_child.typ == PredicateType::AttrRef { + // 3. 
Check that both sides don't belong to the same table (if we don't know, that + // means they don't belong) + let left_attr_ref_expr = AttrRefPred::from_pred_node(left_child) + .expect("we already checked that the type is AttrRef"); + let right_attr_ref_expr = AttrRefPred::from_pred_node(right_child) + .expect("we already checked that the type is AttrRef"); + let left_attr_ref = &attr_refs[left_attr_ref_expr.attr_index() as usize]; + let right_attr_ref = &attr_refs[right_attr_ref_expr.attr_index() as usize]; + let is_same_table = if let ( + AttrRef::BaseTableAttrRef(BaseTableAttrRef { + table_id: left_table_id, + .. + }), + AttrRef::BaseTableAttrRef(BaseTableAttrRef { + table_id: right_table_id, + .. + }), + ) = (left_attr_ref, right_attr_ref) + { + left_table_id == right_table_id + } else { + false + }; + if !is_same_table { + Some((left_attr_ref_expr, right_attr_ref_expr)) + } else { + None + } + } else { + None + } + } else { + None + } +} diff --git a/optd-cost-model/src/cost/join/nested_loop_join.rs b/optd-cost-model/src/cost/join/nested_loop_join.rs index 58e0ae1..0c9102f 100644 --- a/optd-cost-model/src/cost/join/nested_loop_join.rs +++ b/optd-cost-model/src/cost/join/nested_loop_join.rs @@ -5,13 +5,12 @@ use crate::{ properties::attr_ref::{AttrRefs, SemanticCorrelation}, types::GroupId, }, - cost::join::join::get_on_attr_ref_pair, cost_model::CostModelImpl, storage::CostModelStorageManager, CostModelResult, EstimatedStatistic, }; -use super::join::get_input_correlation; +use super::get_input_correlation; impl CostModelImpl { #[allow(clippy::too_many_arguments)] @@ -31,7 +30,7 @@ impl CostModelImpl { let right_attr_refs = self.memo.get_attribute_ref(right_group_id); let input_correlation = get_input_correlation(left_attr_refs, right_attr_refs); - self.get_nlj_join_selectivity( + self.get_join_selectivity_from_expr_tree( join_typ, join_cond, output_attr_refs.attr_refs(), @@ -45,79 +44,4 @@ impl CostModelImpl { (left_row_cnt * right_row_cnt * 
selectivity).max(1.0), )) } - - /// The expr_tree input must be a "mixed expression tree", just like with - /// `get_filter_selectivity`. - /// - /// This is a "wrapper" to separate the equality conditions from the filter conditions before - /// calling the "main" `get_join_selectivity_core` function. - #[allow(clippy::too_many_arguments)] - async fn get_nlj_join_selectivity( - &self, - join_typ: JoinType, - expr_tree: ArcPredicateNode, - attr_refs: &AttrRefs, - input_correlation: Option, - left_row_cnt: f64, - right_row_cnt: f64, - ) -> CostModelResult { - if expr_tree.typ == PredicateType::LogOp(LogOpType::And) { - let mut on_attr_ref_pairs = vec![]; - let mut filter_expr_trees = vec![]; - for child_expr_tree in &expr_tree.children { - if let Some(on_attr_ref_pair) = - get_on_attr_ref_pair(child_expr_tree.clone(), attr_refs) - { - on_attr_ref_pairs.push(on_attr_ref_pair) - } else { - let child_expr = child_expr_tree.clone(); - filter_expr_trees.push(child_expr); - } - } - assert!(on_attr_ref_pairs.len() + filter_expr_trees.len() == expr_tree.children.len()); - let filter_expr_tree = if filter_expr_trees.is_empty() { - None - } else { - Some(LogOpPred::new(LogOpType::And, filter_expr_trees).into_pred_node()) - }; - self.get_join_selectivity_core( - join_typ, - on_attr_ref_pairs, - filter_expr_tree, - attr_refs, - input_correlation, - left_row_cnt, - right_row_cnt, - 0, - ) - .await - } else { - #[allow(clippy::collapsible_else_if)] - if let Some(on_attr_ref_pair) = get_on_attr_ref_pair(expr_tree.clone(), attr_refs) { - self.get_join_selectivity_core( - join_typ, - vec![on_attr_ref_pair], - None, - attr_refs, - input_correlation, - left_row_cnt, - right_row_cnt, - 0, - ) - .await - } else { - self.get_join_selectivity_core( - join_typ, - vec![], - Some(expr_tree), - attr_refs, - input_correlation, - left_row_cnt, - right_row_cnt, - 0, - ) - .await - } - } - } } From 51f917d9caed09a0b6e5c089eb210e37af394b66 Mon Sep 17 00:00:00 2001 From: Yuanxin Cao Date: Mon, 18 
Nov 2024 14:31:22 -0500 Subject: [PATCH 41/51] refine test infra --- optd-cost-model/src/cost/agg.rs | 80 ++++++++----------- optd-cost-model/src/cost/filter/controller.rs | 79 ++++++++---------- optd-cost-model/src/cost/filter/in_list.rs | 2 +- optd-cost-model/src/cost/filter/like.rs | 4 +- optd-cost-model/src/cost_model.rs | 45 ++++++++++- optd-cost-model/src/memo_ext.rs | 33 ++++++-- optd-cost-model/src/storage/mock.rs | 14 ++-- 7 files changed, 142 insertions(+), 115 deletions(-) diff --git a/optd-cost-model/src/cost/agg.rs b/optd-cost-model/src/cost/agg.rs index 1bbf155..fd9141c 100644 --- a/optd-cost-model/src/cost/agg.rs +++ b/optd-cost-model/src/cost/agg.rs @@ -72,7 +72,7 @@ mod tests { values::Value, }, cost_model::tests::{ - attr_ref, cnst, create_cost_model_mock_storage, empty_list, empty_per_attr_stats, list, + attr_ref, cnst, create_mock_cost_model, empty_list, empty_per_attr_stats, list, TestPerAttributeStats, }, stats::{utilities::simple_map::SimpleMap, MostCommonValues, DEFAULT_NUM_DISTINCT}, @@ -84,27 +84,20 @@ mod tests { let table_id = TableId(0); let attr_infos = HashMap::from([( table_id, - HashMap::from([ - ( - 0, - Attribute { - name: String::from("attr1"), - typ: ConstantType::Int32, - nullable: false, - }, - ), - ( - 1, - Attribute { - name: String::from("attr2"), - typ: ConstantType::Int64, - nullable: false, - }, - ), - ]), + vec![ + Attribute { + name: String::from("attr1"), + typ: ConstantType::Int32, + nullable: false, + }, + Attribute { + name: String::from("attr2"), + typ: ConstantType::Int64, + nullable: false, + }, + ], )]); - let cost_model = - create_cost_model_mock_storage(vec![table_id], vec![], vec![None], attr_infos); + let cost_model = create_mock_cost_model(vec![table_id], vec![], vec![None], attr_infos); // Group by empty list should return 1. 
let group_bys = empty_list(); @@ -136,32 +129,23 @@ mod tests { let attr3_base_idx = 2; let attr_infos = HashMap::from([( table_id, - HashMap::from([ - ( - attr1_base_idx, - Attribute { - name: String::from("attr1"), - typ: ConstantType::Int32, - nullable: false, - }, - ), - ( - attr2_base_idx, - Attribute { - name: String::from("attr2"), - typ: ConstantType::Int64, - nullable: false, - }, - ), - ( - attr3_base_idx, - Attribute { - name: String::from("attr3"), - typ: ConstantType::Int64, - nullable: false, - }, - ), - ]), + vec![ + Attribute { + name: String::from("attr1"), + typ: ConstantType::Int32, + nullable: false, + }, + Attribute { + name: String::from("attr2"), + typ: ConstantType::Int64, + nullable: false, + }, + Attribute { + name: String::from("attr3"), + typ: ConstantType::Int64, + nullable: false, + }, + ], )]); let attr1_ndistinct = 12; @@ -179,7 +163,7 @@ mod tests { 0.0, ); - let cost_model = create_cost_model_mock_storage( + let cost_model = create_mock_cost_model( vec![table_id], vec![HashMap::from([ (attr1_base_idx, attr1_stats), diff --git a/optd-cost-model/src/cost/filter/controller.rs b/optd-cost-model/src/cost/filter/controller.rs index fd9769f..40cf969 100644 --- a/optd-cost-model/src/cost/filter/controller.rs +++ b/optd-cost-model/src/cost/filter/controller.rs @@ -109,7 +109,7 @@ mod tests { #[tokio::test] async fn test_const() { - let cost_model = create_cost_model_mock_storage( + let cost_model = create_mock_cost_model( vec![TableId(0)], vec![HashMap::from([(0, empty_per_attr_stats())])], vec![None], @@ -143,7 +143,7 @@ mod tests { 0.0, ); let table_id = TableId(0); - let cost_model = create_cost_model_mock_storage( + let cost_model = create_mock_cost_model( vec![table_id], vec![HashMap::from([(0, per_attribute_stats)])], vec![None], @@ -177,7 +177,7 @@ mod tests { 0.0, ); let table_id = TableId(0); - let cost_model = create_cost_model_mock_storage( + let cost_model = create_mock_cost_model( vec![table_id], vec![HashMap::from([(0, 
per_attribute_stats)])], vec![None], @@ -212,7 +212,7 @@ mod tests { 0.0, ); let table_id = TableId(0); - let cost_model = create_cost_model_mock_storage( + let cost_model = create_mock_cost_model( vec![table_id], vec![HashMap::from([(0, per_attribute_stats)])], vec![None], @@ -246,7 +246,7 @@ mod tests { 0.0, ); let table_id = TableId(0); - let cost_model = create_cost_model_mock_storage( + let cost_model = create_mock_cost_model( vec![table_id], vec![HashMap::from([(0, per_attribute_stats)])], vec![None], @@ -289,7 +289,7 @@ mod tests { 0.0, ); let table_id = TableId(0); - let cost_model = create_cost_model_mock_storage( + let cost_model = create_mock_cost_model( vec![table_id], vec![HashMap::from([(0, per_attribute_stats)])], vec![None], @@ -332,7 +332,7 @@ mod tests { 0.0, ); let table_id = TableId(0); - let cost_model = create_cost_model_mock_storage( + let cost_model = create_mock_cost_model( vec![table_id], vec![HashMap::from([(0, per_attribute_stats)])], vec![None], @@ -370,7 +370,7 @@ mod tests { 0.0, ); let table_id = TableId(0); - let cost_model = create_cost_model_mock_storage( + let cost_model = create_mock_cost_model( vec![table_id], vec![HashMap::from([(0, per_attribute_stats)])], vec![None], @@ -414,7 +414,7 @@ mod tests { 0.0, ); let table_id = TableId(0); - let cost_model = create_cost_model_mock_storage( + let cost_model = create_mock_cost_model( vec![table_id], vec![HashMap::from([(0, per_attribute_stats)])], vec![None], @@ -458,7 +458,7 @@ mod tests { 0.0, ); let table_id = TableId(0); - let cost_model = create_cost_model_mock_storage( + let cost_model = create_mock_cost_model( vec![table_id], vec![HashMap::from([(0, per_attribute_stats)])], vec![None], @@ -498,7 +498,7 @@ mod tests { 0.0, ); let table_id = TableId(0); - let cost_model = create_cost_model_mock_storage( + let cost_model = create_mock_cost_model( vec![table_id], vec![HashMap::from([(0, per_attribute_stats)])], vec![None], @@ -536,7 +536,7 @@ mod tests { 0.0, ); let table_id = 
TableId(0); - let cost_model = create_cost_model_mock_storage( + let cost_model = create_mock_cost_model( vec![table_id], vec![HashMap::from([(0, per_attribute_stats)])], vec![None], @@ -576,7 +576,7 @@ mod tests { 0.0, ); let table_id = TableId(0); - let cost_model = create_cost_model_mock_storage( + let cost_model = create_mock_cost_model( vec![table_id], vec![HashMap::from([(0, per_attribute_stats)])], vec![None], @@ -623,7 +623,7 @@ mod tests { 0.0, ); let table_id = TableId(0); - let cost_model = create_cost_model_mock_storage( + let cost_model = create_mock_cost_model( vec![table_id], vec![HashMap::from([(0, per_attribute_stats)])], vec![None], @@ -669,7 +669,7 @@ mod tests { 0.0, ); let table_id = TableId(0); - let cost_model = create_cost_model_mock_storage( + let cost_model = create_mock_cost_model( vec![table_id], vec![HashMap::from([(0, per_attribute_stats)])], vec![None], @@ -702,7 +702,7 @@ mod tests { 0.0, ); let table_id = TableId(0); - let cost_model = create_cost_model_mock_storage( + let cost_model = create_mock_cost_model( vec![table_id], vec![HashMap::from([(0, per_attribute_stats)])], vec![None], @@ -747,16 +747,13 @@ mod tests { let table_id = TableId(0); let attr_infos = HashMap::from([( table_id, - HashMap::from([( - 0, - Attribute { - name: String::from("attr1"), - typ: ConstantType::Int32, - nullable: false, - }, - )]), + vec![Attribute { + name: String::from("attr1"), + typ: ConstantType::Int32, + nullable: false, + }], )]); - let cost_model = create_cost_model_mock_storage( + let cost_model = create_mock_cost_model( vec![table_id], vec![HashMap::from([(0, per_attribute_stats)])], vec![None], @@ -803,26 +800,20 @@ mod tests { let table_id = TableId(0); let attr_infos = HashMap::from([( table_id, - HashMap::from([ - ( - 0, - Attribute { - name: String::from("attr1"), - typ: ConstantType::Int32, - nullable: false, - }, - ), - ( - 1, - Attribute { - name: String::from("attr2"), - typ: ConstantType::Int64, - nullable: false, - }, - ), - ]), + 
vec![ + Attribute { + name: String::from("attr1"), + typ: ConstantType::Int32, + nullable: false, + }, + Attribute { + name: String::from("attr2"), + typ: ConstantType::Int64, + nullable: false, + }, + ], )]); - let cost_model = create_cost_model_mock_storage( + let cost_model = create_mock_cost_model( vec![table_id], vec![HashMap::from([(0, per_attribute_stats)])], vec![None], diff --git a/optd-cost-model/src/cost/filter/in_list.rs b/optd-cost-model/src/cost/filter/in_list.rs index 8c11bcd..4c6cf3a 100644 --- a/optd-cost-model/src/cost/filter/in_list.rs +++ b/optd-cost-model/src/cost/filter/in_list.rs @@ -91,7 +91,7 @@ mod tests { 0.0, ); let table_id = TableId(0); - let cost_model = create_cost_model_mock_storage( + let cost_model = create_mock_cost_model( vec![table_id], vec![HashMap::from([(0, per_attribute_stats)])], vec![None], diff --git a/optd-cost-model/src/cost/filter/like.rs b/optd-cost-model/src/cost/filter/like.rs index 997e389..b1b3e98 100644 --- a/optd-cost-model/src/cost/filter/like.rs +++ b/optd-cost-model/src/cost/filter/like.rs @@ -125,7 +125,7 @@ mod tests { 0.0, ); let table_id = TableId(0); - let cost_model = create_cost_model_mock_storage( + let cost_model = create_mock_cost_model( vec![table_id], vec![HashMap::from([(0, per_attribute_stats)])], vec![None], @@ -168,7 +168,7 @@ mod tests { null_frac, ); let table_id = TableId(0); - let cost_model = create_cost_model_mock_storage( + let cost_model = create_mock_cost_model( vec![table_id], vec![HashMap::from([(0, per_attribute_stats)])], vec![None], diff --git a/optd-cost-model/src/cost_model.rs b/optd-cost-model/src/cost_model.rs index 4d36c38..441fa1a 100644 --- a/optd-cost-model/src/cost_model.rs +++ b/optd-cost-model/src/cost_model.rs @@ -138,9 +138,10 @@ pub mod tests { log_op_pred::{LogOpPred, LogOpType}, un_op_pred::{UnOpPred, UnOpType}, }, + types::GroupId, values::Value, }, - memo_ext::tests::MockMemoExt, + memo_ext::tests::{MemoGroupInfo, MockMemoExtImpl}, stats::{ 
utilities::counter::Counter, AttributeCombValueStats, Distribution, MostCommonValues, }, @@ -153,7 +154,7 @@ pub mod tests { // TODO: add tests for non-mock storage manager pub type TestOptCostModelMock = CostModelImpl; - pub fn create_cost_model_mock_storage( + pub fn create_mock_cost_model( table_id: Vec, per_attribute_stats: Vec>, row_counts: Vec>, @@ -179,7 +180,45 @@ pub mod tests { .collect(), per_table_attr_infos, ); - CostModelImpl::new(storage_manager, CatalogSource::Mock, Arc::new(MockMemoExt)) + CostModelImpl::new( + storage_manager, + CatalogSource::Mock, + Arc::new(MockMemoExtImpl::default()), + ) + } + + pub fn create_mock_cost_model_with_memo( + table_id: Vec, + per_attribute_stats: Vec>, + row_counts: Vec>, + per_table_attr_infos: BaseTableAttrInfo, + group_info: HashMap, + ) -> TestOptCostModelMock { + let storage_manager = CostModelStorageMockManagerImpl::new( + table_id + .into_iter() + .zip(per_attribute_stats) + .zip(row_counts) + .map(|((table_id, per_attr_stats), row_count)| { + ( + table_id, + TableStats::new( + row_count.unwrap_or(100), + per_attr_stats + .into_iter() + .map(|(attr_idx, stats)| (vec![attr_idx], stats)) + .collect(), + ), + ) + }) + .collect(), + per_table_attr_infos, + ); + CostModelImpl::new( + storage_manager, + CatalogSource::Mock, + Arc::new(MockMemoExtImpl::from(group_info)), + ) } pub fn attr_ref(table_id: TableId, attr_base_index: u64) -> ArcPredicateNode { diff --git a/optd-cost-model/src/memo_ext.rs b/optd-cost-model/src/memo_ext.rs index 4ca052d..09e2c31 100644 --- a/optd-cost-model/src/memo_ext.rs +++ b/optd-cost-model/src/memo_ext.rs @@ -21,25 +21,42 @@ pub trait MemoExt: Send + Sync + 'static { // TODO: Figure out what other information is needed to compute the cost... 
} +#[cfg(test)] pub mod tests { + use std::collections::HashMap; + use crate::common::{ properties::{attr_ref::GroupAttrRefs, schema::Schema, Attribute}, types::GroupId, }; - pub struct MockMemoExt; + pub struct MemoGroupInfo { + pub schema: Schema, + pub attr_ref: GroupAttrRefs, + } + + #[derive(Default)] + pub struct MockMemoExtImpl { + memo: HashMap, + } - impl super::MemoExt for MockMemoExt { - fn get_schema(&self, _group_id: GroupId) -> Schema { - unimplemented!() + impl super::MemoExt for MockMemoExtImpl { + fn get_schema(&self, group_id: GroupId) -> Schema { + self.memo.get(&group_id).unwrap().schema.clone() } - fn get_attribute_ref(&self, _group_id: GroupId) -> GroupAttrRefs { - unimplemented!() + fn get_attribute_ref(&self, group_id: GroupId) -> GroupAttrRefs { + self.memo.get(&group_id).unwrap().attr_ref.clone() } - fn get_attribute_info(&self, _group_id: GroupId, _attr_ref_idx: u64) -> Attribute { - unimplemented!() + fn get_attribute_info(&self, group_id: GroupId, attr_ref_idx: u64) -> Attribute { + self.memo.get(&group_id).unwrap().schema.attributes[attr_ref_idx as usize].clone() + } + } + + impl From> for MockMemoExtImpl { + fn from(memo: HashMap) -> Self { + Self { memo } } } } diff --git a/optd-cost-model/src/storage/mock.rs b/optd-cost-model/src/storage/mock.rs index 1f369f9..c67c6a3 100644 --- a/optd-cost-model/src/storage/mock.rs +++ b/optd-cost-model/src/storage/mock.rs @@ -30,7 +30,7 @@ impl TableStats { } pub type BaseTableStats = HashMap; -pub type BaseTableAttrInfo = HashMap>; // (table_id, (attr_base_index, attr)) +pub type BaseTableAttrInfo = HashMap>; pub struct CostModelStorageMockManagerImpl { pub(crate) per_table_stats_map: BaseTableStats, @@ -55,14 +55,10 @@ impl CostModelStorageManager for CostModelStorageMockManagerImpl { table_id: TableId, attr_base_index: u64, ) -> CostModelResult> { - let table_attr_infos = self.per_table_attr_infos_map.get(&table_id); - match table_attr_infos { - None => Ok(None), - Some(table_attr_infos) => 
match table_attr_infos.get(&attr_base_index) { - None => Ok(None), - Some(attr) => Ok(Some(attr.clone())), - }, - } + Ok(self + .per_table_attr_infos_map + .get(&table_id) + .map(|table_attr_infos| table_attr_infos[attr_base_index as usize].clone())) } async fn get_attributes_comb_statistics( From 5197090d760b1c5020dc0a85cea37f9a177c812b Mon Sep 17 00:00:00 2001 From: Yuanxin Cao Date: Mon, 18 Nov 2024 15:30:57 -0500 Subject: [PATCH 42/51] add test infra for join --- .../src/common/properties/attr_ref.rs | 6 +- .../src/common/properties/schema.rs | 6 + optd-cost-model/src/cost/agg.rs | 74 ++-- optd-cost-model/src/cost/filter/controller.rs | 65 ++-- optd-cost-model/src/cost/filter/in_list.rs | 1 - optd-cost-model/src/cost/filter/like.rs | 2 - optd-cost-model/src/cost_model.rs | 327 ++++++++++++++++-- optd-cost-model/src/memo_ext.rs | 18 + optd-cost-model/src/storage/mock.rs | 13 +- 9 files changed, 390 insertions(+), 122 deletions(-) diff --git a/optd-cost-model/src/common/properties/attr_ref.rs b/optd-cost-model/src/common/properties/attr_ref.rs index 5c73961..d6105b6 100644 --- a/optd-cost-model/src/common/properties/attr_ref.rs +++ b/optd-cost-model/src/common/properties/attr_ref.rs @@ -23,6 +23,10 @@ pub enum AttrRef { } impl AttrRef { + pub fn new_base_table_attr_ref(table_id: TableId, attr_idx: u64) -> Self { + AttrRef::BaseTableAttrRef(BaseTableAttrRef { table_id, attr_idx }) + } + pub fn base_table_attr_ref(table_id: TableId, attr_idx: u64) -> Self { AttrRef::BaseTableAttrRef(BaseTableAttrRef { table_id, attr_idx }) } @@ -161,7 +165,7 @@ impl SemanticCorrelation { } /// [`GroupAttrRefs`] represents the attributes of a group in a query. -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Default)] pub struct GroupAttrRefs { attr_refs: AttrRefs, /// Correlation of the output attributes of the group. 
diff --git a/optd-cost-model/src/common/properties/schema.rs b/optd-cost-model/src/common/properties/schema.rs index 4ee4fce..d25a23a 100644 --- a/optd-cost-model/src/common/properties/schema.rs +++ b/optd-cost-model/src/common/properties/schema.rs @@ -33,3 +33,9 @@ impl Schema { self.len() == 0 } } + +impl From> for Schema { + fn from(attributes: Vec) -> Self { + Self::new(attributes) + } +} diff --git a/optd-cost-model/src/cost/agg.rs b/optd-cost-model/src/cost/agg.rs index fd9141c..71c9d36 100644 --- a/optd-cost-model/src/cost/agg.rs +++ b/optd-cost-model/src/cost/agg.rs @@ -82,22 +82,22 @@ mod tests { #[tokio::test] async fn test_agg_no_stats() { let table_id = TableId(0); - let attr_infos = HashMap::from([( - table_id, - vec![ - Attribute { - name: String::from("attr1"), - typ: ConstantType::Int32, - nullable: false, - }, - Attribute { - name: String::from("attr2"), - typ: ConstantType::Int64, - nullable: false, - }, - ], - )]); - let cost_model = create_mock_cost_model(vec![table_id], vec![], vec![None], attr_infos); + // let attr_infos = HashMap::from([( + // table_id, + // vec![ + // Attribute { + // name: String::from("attr1"), + // typ: ConstantType::Int32, + // nullable: false, + // }, + // Attribute { + // name: String::from("attr2"), + // typ: ConstantType::Int64, + // nullable: false, + // }, + // ], + // )]); + let cost_model = create_mock_cost_model(vec![table_id], vec![], vec![None]); // Group by empty list should return 1. 
let group_bys = empty_list(); @@ -127,26 +127,26 @@ mod tests { let attr1_base_idx = 0; let attr2_base_idx = 1; let attr3_base_idx = 2; - let attr_infos = HashMap::from([( - table_id, - vec![ - Attribute { - name: String::from("attr1"), - typ: ConstantType::Int32, - nullable: false, - }, - Attribute { - name: String::from("attr2"), - typ: ConstantType::Int64, - nullable: false, - }, - Attribute { - name: String::from("attr3"), - typ: ConstantType::Int64, - nullable: false, - }, - ], - )]); + // let attr_infos = HashMap::from([( + // table_id, + // vec![ + // Attribute { + // name: String::from("attr1"), + // typ: ConstantType::Int32, + // nullable: false, + // }, + // Attribute { + // name: String::from("attr2"), + // typ: ConstantType::Int64, + // nullable: false, + // }, + // Attribute { + // name: String::from("attr3"), + // typ: ConstantType::Int64, + // nullable: false, + // }, + // ], + // )]); let attr1_ndistinct = 12; let attr2_ndistinct = 645; @@ -170,7 +170,7 @@ mod tests { (attr2_base_idx, attr2_stats), ])], vec![None], - attr_infos, + // attr_infos, ); // Group by empty list should return 1. 
diff --git a/optd-cost-model/src/cost/filter/controller.rs b/optd-cost-model/src/cost/filter/controller.rs index 40cf969..d8fe9b5 100644 --- a/optd-cost-model/src/cost/filter/controller.rs +++ b/optd-cost-model/src/cost/filter/controller.rs @@ -100,6 +100,7 @@ mod tests { values::Value, }, cost_model::tests::*, + memo_ext::tests::MemoGroupInfo, stats::{ utilities::{counter::Counter, simple_map::SimpleMap}, Distribution, MostCommonValues, DEFAULT_EQ_SEL, @@ -113,7 +114,6 @@ mod tests { vec![TableId(0)], vec![HashMap::from([(0, empty_per_attr_stats())])], vec![None], - HashMap::new(), ); assert_approx_eq::assert_approx_eq!( cost_model @@ -147,7 +147,6 @@ mod tests { vec![table_id], vec![HashMap::from([(0, per_attribute_stats)])], vec![None], - HashMap::new(), ); let expr_tree = bin_op(BinOpType::Eq, attr_ref(table_id, 0), cnst(Value::Int32(1))); @@ -181,7 +180,6 @@ mod tests { vec![table_id], vec![HashMap::from([(0, per_attribute_stats)])], vec![None], - HashMap::new(), ); let expr_tree = bin_op(BinOpType::Eq, attr_ref(table_id, 0), cnst(Value::Int32(2))); @@ -216,7 +214,6 @@ mod tests { vec![table_id], vec![HashMap::from([(0, per_attribute_stats)])], vec![None], - HashMap::new(), ); let expr_tree = bin_op(BinOpType::Neq, attr_ref(table_id, 0), cnst(Value::Int32(1))); @@ -250,7 +247,6 @@ mod tests { vec![table_id], vec![HashMap::from([(0, per_attribute_stats)])], vec![None], - HashMap::new(), ); let expr_tree = bin_op( @@ -293,7 +289,6 @@ mod tests { vec![table_id], vec![HashMap::from([(0, per_attribute_stats)])], vec![None], - HashMap::new(), ); let expr_tree = bin_op( @@ -336,7 +331,6 @@ mod tests { vec![table_id], vec![HashMap::from([(0, per_attribute_stats)])], vec![None], - HashMap::new(), ); let expr_tree = bin_op( @@ -374,7 +368,6 @@ mod tests { vec![table_id], vec![HashMap::from([(0, per_attribute_stats)])], vec![None], - HashMap::new(), ); let expr_tree = bin_op(BinOpType::Lt, attr_ref(table_id, 0), cnst(Value::Int32(15))); @@ -418,7 +411,6 @@ mod tests { 
vec![table_id], vec![HashMap::from([(0, per_attribute_stats)])], vec![None], - HashMap::new(), ); let expr_tree = bin_op(BinOpType::Lt, attr_ref(table_id, 0), cnst(Value::Int32(15))); @@ -462,7 +454,6 @@ mod tests { vec![table_id], vec![HashMap::from([(0, per_attribute_stats)])], vec![None], - HashMap::new(), ); let expr_tree = bin_op(BinOpType::Lt, attr_ref(table_id, 0), cnst(Value::Int32(15))); @@ -502,7 +493,6 @@ mod tests { vec![table_id], vec![HashMap::from([(0, per_attribute_stats)])], vec![None], - HashMap::new(), ); let expr_tree = bin_op(BinOpType::Gt, attr_ref(table_id, 0), cnst(Value::Int32(15))); @@ -540,7 +530,6 @@ mod tests { vec![table_id], vec![HashMap::from([(0, per_attribute_stats)])], vec![None], - HashMap::new(), ); let expr_tree = bin_op( @@ -580,7 +569,6 @@ mod tests { vec![table_id], vec![HashMap::from([(0, per_attribute_stats)])], vec![None], - HashMap::new(), ); let eq1 = bin_op(BinOpType::Eq, attr_ref(table_id, 0), cnst(Value::Int32(1))); @@ -627,7 +615,6 @@ mod tests { vec![table_id], vec![HashMap::from([(0, per_attribute_stats)])], vec![None], - HashMap::new(), ); let eq1 = bin_op(BinOpType::Eq, attr_ref(table_id, 0), cnst(Value::Int32(1))); @@ -673,7 +660,6 @@ mod tests { vec![table_id], vec![HashMap::from([(0, per_attribute_stats)])], vec![None], - HashMap::new(), ); let expr_tree = un_op( @@ -706,7 +692,6 @@ mod tests { vec![table_id], vec![HashMap::from([(0, per_attribute_stats)])], vec![None], - HashMap::new(), ); let expr_tree = bin_op( @@ -745,19 +730,18 @@ mod tests { 0.1, ); let table_id = TableId(0); - let attr_infos = HashMap::from([( - table_id, - vec![Attribute { - name: String::from("attr1"), - typ: ConstantType::Int32, - nullable: false, - }], - )]); + // let attr_infos = HashMap::from([( + // table_id, + // vec![Attribute { + // name: String::from("attr1"), + // typ: ConstantType::Int32, + // nullable: false, + // }], + // )]); let cost_model = create_mock_cost_model( vec![table_id], vec![HashMap::from([(0, 
per_attribute_stats)])], vec![None], - attr_infos, ); let expr_tree = bin_op( @@ -798,26 +782,25 @@ mod tests { 0.0, ); let table_id = TableId(0); - let attr_infos = HashMap::from([( - table_id, - vec![ - Attribute { - name: String::from("attr1"), - typ: ConstantType::Int32, - nullable: false, - }, - Attribute { - name: String::from("attr2"), - typ: ConstantType::Int64, - nullable: false, - }, - ], - )]); + // let attr_infos = HashMap::from([( + // table_id, + // vec![ + // Attribute { + // name: String::from("attr1"), + // typ: ConstantType::Int32, + // nullable: false, + // }, + // Attribute { + // name: String::from("attr2"), + // typ: ConstantType::Int64, + // nullable: false, + // }, + // ], + // )]); let cost_model = create_mock_cost_model( vec![table_id], vec![HashMap::from([(0, per_attribute_stats)])], vec![None], - attr_infos, ); let expr_tree = bin_op( diff --git a/optd-cost-model/src/cost/filter/in_list.rs b/optd-cost-model/src/cost/filter/in_list.rs index 4c6cf3a..36a3791 100644 --- a/optd-cost-model/src/cost/filter/in_list.rs +++ b/optd-cost-model/src/cost/filter/in_list.rs @@ -95,7 +95,6 @@ mod tests { vec![table_id], vec![HashMap::from([(0, per_attribute_stats)])], vec![None], - HashMap::new(), ); assert_approx_eq::assert_approx_eq!( diff --git a/optd-cost-model/src/cost/filter/like.rs b/optd-cost-model/src/cost/filter/like.rs index b1b3e98..7f62fcc 100644 --- a/optd-cost-model/src/cost/filter/like.rs +++ b/optd-cost-model/src/cost/filter/like.rs @@ -129,7 +129,6 @@ mod tests { vec![table_id], vec![HashMap::from([(0, per_attribute_stats)])], vec![None], - HashMap::new(), ); assert_approx_eq::assert_approx_eq!( @@ -172,7 +171,6 @@ mod tests { vec![table_id], vec![HashMap::from([(0, per_attribute_stats)])], vec![None], - HashMap::new(), ); assert_approx_eq::assert_approx_eq!( diff --git a/optd-cost-model/src/cost_model.rs b/optd-cost-model/src/cost_model.rs index 441fa1a..03319f1 100644 --- a/optd-cost-model/src/cost_model.rs +++ 
b/optd-cost-model/src/cost_model.rs @@ -131,13 +131,18 @@ pub mod tests { attr_ref_pred::AttrRefPred, bin_op_pred::{BinOpPred, BinOpType}, cast_pred::CastPred, - constant_pred::ConstantPred, + constant_pred::{ConstantPred, ConstantType}, in_list_pred::InListPred, like_pred::LikePred, list_pred::ListPred, log_op_pred::{LogOpPred, LogOpType}, un_op_pred::{UnOpPred, UnOpType}, }, + properties::{ + attr_ref::{AttrRef, GroupAttrRefs}, + schema::Schema, + Attribute, + }, types::GroupId, values::Value, }, @@ -145,11 +150,21 @@ pub mod tests { stats::{ utilities::counter::Counter, AttributeCombValueStats, Distribution, MostCommonValues, }, - storage::mock::{BaseTableAttrInfo, CostModelStorageMockManagerImpl, TableStats}, + storage::mock::{CostModelStorageMockManagerImpl, TableStats}, }; use super::*; + const TEST_TABLE1_ID: TableId = TableId(0); + const TEST_TABLE2_ID: TableId = TableId(1); + const TEST_TABLE3_ID: TableId = TableId(2); + const TEST_TABLE4_ID: TableId = TableId(3); + + const TEST_GROUP1_ID: GroupId = GroupId(0); + const TEST_GROUP2_ID: GroupId = GroupId(1); + const TEST_GROUP3_ID: GroupId = GroupId(2); + const TEST_GROUP4_ID: GroupId = GroupId(3); + pub type TestPerAttributeStats = AttributeCombValueStats; // TODO: add tests for non-mock storage manager pub type TestOptCostModelMock = CostModelImpl; @@ -158,7 +173,6 @@ pub mod tests { table_id: Vec, per_attribute_stats: Vec>, row_counts: Vec>, - per_table_attr_infos: BaseTableAttrInfo, ) -> TestOptCostModelMock { let storage_manager = CostModelStorageMockManagerImpl::new( table_id @@ -178,7 +192,6 @@ pub mod tests { ) }) .collect(), - per_table_attr_infos, ); CostModelImpl::new( storage_manager, @@ -187,40 +200,296 @@ pub mod tests { ) } - pub fn create_mock_cost_model_with_memo( - table_id: Vec, - per_attribute_stats: Vec>, - row_counts: Vec>, - per_table_attr_infos: BaseTableAttrInfo, - group_info: HashMap, + /// Create a cost model two tables, each with one attribute. Each attribute has 100 values. 
+ pub fn create_two_table_mock_cost_model( + tbl1_per_attr_stats: TestPerAttributeStats, + tbl2_per_attr_stats: TestPerAttributeStats, + ) -> TestOptCostModelMock { + create_two_table_cost_model_custom_row_cnts( + tbl1_per_attr_stats, + tbl2_per_attr_stats, + 100, + 100, + ) + } + + /// Create a cost model with three columns, one for each table. Each column has 100 values. + pub fn create_three_table_cost_model( + tbl1_per_column_stats: TestPerAttributeStats, + tbl2_per_column_stats: TestPerAttributeStats, + tbl3_per_column_stats: TestPerAttributeStats, ) -> TestOptCostModelMock { let storage_manager = CostModelStorageMockManagerImpl::new( - table_id - .into_iter() - .zip(per_attribute_stats) - .zip(row_counts) - .map(|((table_id, per_attr_stats), row_count)| { - ( - table_id, - TableStats::new( - row_count.unwrap_or(100), - per_attr_stats - .into_iter() - .map(|(attr_idx, stats)| (vec![attr_idx], stats)) - .collect(), - ), - ) - }) - .collect(), - per_table_attr_infos, + vec![ + ( + TEST_TABLE1_ID, + TableStats::new( + 100, + vec![(vec![0], tbl1_per_column_stats)].into_iter().collect(), + ), + ), + ( + TEST_TABLE2_ID, + TableStats::new( + 100, + vec![(vec![0], tbl2_per_column_stats)].into_iter().collect(), + ), + ), + ( + TEST_TABLE3_ID, + TableStats::new( + 100, + vec![(vec![0], tbl3_per_column_stats)].into_iter().collect(), + ), + ), + ] + .into_iter() + .collect(), ); + let memo = HashMap::from([ + ( + TEST_GROUP1_ID, + MemoGroupInfo::new( + vec![Attribute { + name: "attr1".to_string(), + typ: ConstantType::Int64, + nullable: false, + }] + .into(), + GroupAttrRefs::new( + vec![AttrRef::new_base_table_attr_ref(TEST_TABLE1_ID, 0)], + None, + ), + ), + ), + ( + TEST_GROUP2_ID, + MemoGroupInfo::new( + vec![Attribute { + name: "attr2".to_string(), + typ: ConstantType::Int64, + nullable: false, + }] + .into(), + GroupAttrRefs::new( + vec![AttrRef::new_base_table_attr_ref(TEST_TABLE2_ID, 0)], + None, + ), + ), + ), + ( + TEST_GROUP3_ID, + MemoGroupInfo::new( + 
vec![Attribute { + name: "attr3".to_string(), + typ: ConstantType::Int64, + nullable: false, + }] + .into(), + GroupAttrRefs::new( + vec![AttrRef::new_base_table_attr_ref(TEST_TABLE3_ID, 0)], + None, + ), + ), + ), + ]); CostModelImpl::new( storage_manager, CatalogSource::Mock, - Arc::new(MockMemoExtImpl::from(group_info)), + Arc::new(MockMemoExtImpl::from(memo)), ) } + /// Create a cost model with three columns, one for each table. Each column has 100 values. + pub fn create_four_table_cost_model( + tbl1_per_column_stats: TestPerAttributeStats, + tbl2_per_column_stats: TestPerAttributeStats, + tbl3_per_column_stats: TestPerAttributeStats, + tbl4_per_column_stats: TestPerAttributeStats, + ) -> TestOptCostModelMock { + let storage_manager = CostModelStorageMockManagerImpl::new( + vec![ + ( + TEST_TABLE1_ID, + TableStats::new( + 100, + vec![(vec![0], tbl1_per_column_stats)].into_iter().collect(), + ), + ), + ( + TEST_TABLE2_ID, + TableStats::new( + 100, + vec![(vec![0], tbl2_per_column_stats)].into_iter().collect(), + ), + ), + ( + TEST_TABLE3_ID, + TableStats::new( + 100, + vec![(vec![0], tbl3_per_column_stats)].into_iter().collect(), + ), + ), + ( + TEST_TABLE4_ID, + TableStats::new( + 100, + vec![(vec![0], tbl4_per_column_stats)].into_iter().collect(), + ), + ), + ] + .into_iter() + .collect(), + ); + let memo = HashMap::from([ + ( + TEST_GROUP1_ID, + MemoGroupInfo::new( + vec![Attribute { + name: "attr1".to_string(), + typ: ConstantType::Int64, + nullable: false, + }] + .into(), + GroupAttrRefs::new( + vec![AttrRef::new_base_table_attr_ref(TEST_TABLE1_ID, 0)], + None, + ), + ), + ), + ( + TEST_GROUP2_ID, + MemoGroupInfo::new( + vec![Attribute { + name: "attr2".to_string(), + typ: ConstantType::Int64, + nullable: false, + }] + .into(), + GroupAttrRefs::new( + vec![AttrRef::new_base_table_attr_ref(TEST_TABLE2_ID, 0)], + None, + ), + ), + ), + ( + TEST_GROUP3_ID, + MemoGroupInfo::new( + vec![Attribute { + name: "attr3".to_string(), + typ: ConstantType::Int64, + 
nullable: false, + }] + .into(), + GroupAttrRefs::new( + vec![AttrRef::new_base_table_attr_ref(TEST_TABLE3_ID, 0)], + None, + ), + ), + ), + ( + TEST_GROUP4_ID, + MemoGroupInfo::new( + vec![Attribute { + name: "attr4".to_string(), + typ: ConstantType::Int64, + nullable: false, + }] + .into(), + GroupAttrRefs::new( + vec![AttrRef::new_base_table_attr_ref(TEST_TABLE4_ID, 0)], + None, + ), + ), + ), + ]); + CostModelImpl::new( + storage_manager, + CatalogSource::Mock, + Arc::new(MockMemoExtImpl::from(memo)), + ) + } + + /// We need custom row counts because some join algorithms rely on the row cnt + pub fn create_two_table_cost_model_custom_row_cnts( + tbl1_per_column_stats: TestPerAttributeStats, + tbl2_per_column_stats: TestPerAttributeStats, + tbl1_row_cnt: u64, + tbl2_row_cnt: u64, + ) -> TestOptCostModelMock { + let storage_manager = CostModelStorageMockManagerImpl::new( + vec![ + ( + TEST_TABLE1_ID, + TableStats::new( + tbl1_row_cnt, + vec![(vec![0], tbl1_per_column_stats)].into_iter().collect(), + ), + ), + ( + TEST_TABLE2_ID, + TableStats::new( + tbl2_row_cnt, + vec![(vec![0], tbl2_per_column_stats)].into_iter().collect(), + ), + ), + ] + .into_iter() + .collect(), + ); + let memo = HashMap::from([ + ( + TEST_GROUP1_ID, + MemoGroupInfo::new( + vec![Attribute { + name: "attr1".to_string(), + typ: ConstantType::Int64, + nullable: false, + }] + .into(), + GroupAttrRefs::new( + vec![AttrRef::new_base_table_attr_ref(TEST_TABLE1_ID, 0)], + None, + ), + ), + ), + ( + TEST_GROUP2_ID, + MemoGroupInfo::new( + vec![Attribute { + name: "attr2".to_string(), + typ: ConstantType::Int64, + nullable: false, + }] + .into(), + GroupAttrRefs::new( + vec![AttrRef::new_base_table_attr_ref(TEST_TABLE2_ID, 0)], + None, + ), + ), + ), + ]); + CostModelImpl::new( + storage_manager, + CatalogSource::Mock, + Arc::new(MockMemoExtImpl::from(memo)), + ) + } + + impl TestOptCostModelMock { + pub fn get_row_count(&self, table_id: TableId) -> u64 { + self.storage_manager + .per_table_stats_map 
+ .get(&table_id) + .map(|stats| stats.row_cnt) + .unwrap_or(0) + } + + pub fn get_attr_refs(&self, group_id: GroupId) -> GroupAttrRefs { + self.memo.get_attribute_ref(group_id) + } + } + pub fn attr_ref(table_id: TableId, attr_base_index: u64) -> ArcPredicateNode { AttrRefPred::new(table_id, attr_base_index).into_pred_node() } diff --git a/optd-cost-model/src/memo_ext.rs b/optd-cost-model/src/memo_ext.rs index 09e2c31..559ba6b 100644 --- a/optd-cost-model/src/memo_ext.rs +++ b/optd-cost-model/src/memo_ext.rs @@ -35,11 +35,29 @@ pub mod tests { pub attr_ref: GroupAttrRefs, } + impl MemoGroupInfo { + pub fn new(schema: Schema, attr_ref: GroupAttrRefs) -> Self { + Self { schema, attr_ref } + } + } + #[derive(Default)] pub struct MockMemoExtImpl { memo: HashMap, } + impl MockMemoExtImpl { + pub fn add_group_info( + &mut self, + group_id: GroupId, + schema: Schema, + attr_ref: GroupAttrRefs, + ) { + self.memo + .insert(group_id, MemoGroupInfo::new(schema, attr_ref)); + } + } + impl super::MemoExt for MockMemoExtImpl { fn get_schema(&self, group_id: GroupId) -> Schema { self.memo.get(&group_id).unwrap().schema.clone() diff --git a/optd-cost-model/src/storage/mock.rs b/optd-cost-model/src/storage/mock.rs index c67c6a3..4894859 100644 --- a/optd-cost-model/src/storage/mock.rs +++ b/optd-cost-model/src/storage/mock.rs @@ -30,21 +30,15 @@ impl TableStats { } pub type BaseTableStats = HashMap; -pub type BaseTableAttrInfo = HashMap>; pub struct CostModelStorageMockManagerImpl { pub(crate) per_table_stats_map: BaseTableStats, - pub(crate) per_table_attr_infos_map: BaseTableAttrInfo, } impl CostModelStorageMockManagerImpl { - pub fn new( - per_table_stats_map: BaseTableStats, - per_table_attr_infos_map: BaseTableAttrInfo, - ) -> Self { + pub fn new(per_table_stats_map: BaseTableStats) -> Self { Self { per_table_stats_map, - per_table_attr_infos_map, } } } @@ -55,10 +49,7 @@ impl CostModelStorageManager for CostModelStorageMockManagerImpl { table_id: TableId, attr_base_index: 
u64, ) -> CostModelResult> { - Ok(self - .per_table_attr_infos_map - .get(&table_id) - .map(|table_attr_infos| table_attr_infos[attr_base_index as usize].clone())) + unimplemented!() } async fn get_attributes_comb_statistics( From 68b288518f7c5198fb6ce30a108149764c1872f8 Mon Sep 17 00:00:00 2001 From: Yuanxin Cao Date: Mon, 18 Nov 2024 15:40:38 -0500 Subject: [PATCH 43/51] refine mock interface --- optd-cost-model/src/cost/agg.rs | 40 ++++--------------------------- optd-cost-model/src/cost_model.rs | 11 ++++++++- 2 files changed, 14 insertions(+), 37 deletions(-) diff --git a/optd-cost-model/src/cost/agg.rs b/optd-cost-model/src/cost/agg.rs index 71c9d36..abe9d54 100644 --- a/optd-cost-model/src/cost/agg.rs +++ b/optd-cost-model/src/cost/agg.rs @@ -68,7 +68,9 @@ mod tests { use crate::{ common::{ - predicates::constant_pred::ConstantType, properties::Attribute, types::TableId, + predicates::constant_pred::ConstantType, + properties::Attribute, + types::{GroupId, TableId}, values::Value, }, cost_model::tests::{ @@ -82,21 +84,6 @@ mod tests { #[tokio::test] async fn test_agg_no_stats() { let table_id = TableId(0); - // let attr_infos = HashMap::from([( - // table_id, - // vec![ - // Attribute { - // name: String::from("attr1"), - // typ: ConstantType::Int32, - // nullable: false, - // }, - // Attribute { - // name: String::from("attr2"), - // typ: ConstantType::Int64, - // nullable: false, - // }, - // ], - // )]); let cost_model = create_mock_cost_model(vec![table_id], vec![], vec![None]); // Group by empty list should return 1. 
@@ -124,29 +111,10 @@ mod tests { #[tokio::test] async fn test_agg_with_stats() { let table_id = TableId(0); + let group_id = GroupId(0); let attr1_base_idx = 0; let attr2_base_idx = 1; let attr3_base_idx = 2; - // let attr_infos = HashMap::from([( - // table_id, - // vec![ - // Attribute { - // name: String::from("attr1"), - // typ: ConstantType::Int32, - // nullable: false, - // }, - // Attribute { - // name: String::from("attr2"), - // typ: ConstantType::Int64, - // nullable: false, - // }, - // Attribute { - // name: String::from("attr3"), - // typ: ConstantType::Int64, - // nullable: false, - // }, - // ], - // )]); let attr1_ndistinct = 12; let attr2_ndistinct = 645; diff --git a/optd-cost-model/src/cost_model.rs b/optd-cost-model/src/cost_model.rs index 03319f1..9126b6b 100644 --- a/optd-cost-model/src/cost_model.rs +++ b/optd-cost-model/src/cost_model.rs @@ -173,6 +173,15 @@ pub mod tests { table_id: Vec, per_attribute_stats: Vec>, row_counts: Vec>, + ) -> TestOptCostModelMock { + create_mock_cost_model_with_memo(table_id, per_attribute_stats, row_counts, HashMap::new()) + } + + pub fn create_mock_cost_model_with_memo( + table_id: Vec, + per_attribute_stats: Vec>, + row_counts: Vec>, + memo: HashMap, ) -> TestOptCostModelMock { let storage_manager = CostModelStorageMockManagerImpl::new( table_id @@ -196,7 +205,7 @@ pub mod tests { CostModelImpl::new( storage_manager, CatalogSource::Mock, - Arc::new(MockMemoExtImpl::default()), + Arc::new(MockMemoExtImpl::from(memo)), ) } From 36b93b930bd29b21c50d3fae04c8b99cc99daecd Mon Sep 17 00:00:00 2001 From: Yuanxin Cao Date: Mon, 18 Nov 2024 15:43:36 -0500 Subject: [PATCH 44/51] make CostModelStorageManagerImpl::get_attribute_info unimplemented --- optd-cost-model/src/storage/persistent.rs | 18 ++---------------- 1 file changed, 2 insertions(+), 16 deletions(-) diff --git a/optd-cost-model/src/storage/persistent.rs b/optd-cost-model/src/storage/persistent.rs index aba3ce2..ed4f07a 100644 --- 
a/optd-cost-model/src/storage/persistent.rs +++ b/optd-cost-model/src/storage/persistent.rs @@ -4,7 +4,7 @@ use std::sync::Arc; use optd_persistent::{cost_model::interface::StatType, CostModelStorageLayer}; use crate::{ - common::{predicates::constant_pred::ConstantType, properties::Attribute, types::TableId}, + common::{properties::Attribute, types::TableId}, stats::{utilities::counter::Counter, AttributeCombValueStats, Distribution, MostCommonValues}, CostModelResult, }; @@ -26,26 +26,12 @@ impl CostModelStorageManagerImpl { impl CostModelStorageManager for CostModelStorageManagerImpl { - /// Gets the attribute information for a given table and attribute base index. - /// - /// TODO: if we have memory cache, - /// we should add the reference. (&Attr) - /// TODO(IMPORTANT): what if table is a derived (temporary) table? And what if - /// the attribute is a derived attribute? async fn get_attribute_info( &self, table_id: TableId, attr_base_index: u64, ) -> CostModelResult> { - Ok(self - .backend_manager - .get_attribute(table_id.into(), attr_base_index as i32) - .await? - .map(|attr| Attribute { - name: attr.name, - typ: ConstantType::from_persistent_attr_type(attr.attr_type), - nullable: attr.nullable, - })) + unimplemented!() } /// Gets the latest statistics for a given table. 
From 11a3a4e94ed5f05d459b040cef21f57b57751f14 Mon Sep 17 00:00:00 2001 From: Yuanxin Cao Date: Mon, 18 Nov 2024 16:20:20 -0500 Subject: [PATCH 45/51] modify MemoExt interface --- optd-cost-model/src/cost/join/hash_join.rs | 6 +-- .../src/cost/join/nested_loop_join.rs | 6 +-- optd-cost-model/src/cost_model.rs | 18 ++++----- optd-cost-model/src/memo_ext.rs | 38 +++++++++++++------ 4 files changed, 41 insertions(+), 27 deletions(-) diff --git a/optd-cost-model/src/cost/join/hash_join.rs b/optd-cost-model/src/cost/join/hash_join.rs index c4049db..b912108 100644 --- a/optd-cost-model/src/cost/join/hash_join.rs +++ b/optd-cost-model/src/cost/join/hash_join.rs @@ -28,9 +28,9 @@ impl CostModelImpl { right_keys: ListPred, ) -> CostModelResult { let selectivity = { - let output_attr_refs = self.memo.get_attribute_ref(group_id); - let left_attr_refs = self.memo.get_attribute_ref(left_group_id); - let right_attr_refs = self.memo.get_attribute_ref(right_group_id); + let output_attr_refs = self.memo.get_attribute_refs(group_id); + let left_attr_refs = self.memo.get_attribute_refs(left_group_id); + let right_attr_refs = self.memo.get_attribute_refs(right_group_id); let left_attr_cnt = left_attr_refs.attr_refs().len(); // there may be more than one expression tree in a group. 
// see comment in PredicateType::PhysicalFilter(_) for more information diff --git a/optd-cost-model/src/cost/join/nested_loop_join.rs b/optd-cost-model/src/cost/join/nested_loop_join.rs index 0c9102f..7f99e34 100644 --- a/optd-cost-model/src/cost/join/nested_loop_join.rs +++ b/optd-cost-model/src/cost/join/nested_loop_join.rs @@ -25,9 +25,9 @@ impl CostModelImpl { join_cond: ArcPredicateNode, ) -> CostModelResult { let selectivity = { - let output_attr_refs = self.memo.get_attribute_ref(group_id); - let left_attr_refs = self.memo.get_attribute_ref(left_group_id); - let right_attr_refs = self.memo.get_attribute_ref(right_group_id); + let output_attr_refs = self.memo.get_attribute_refs(group_id); + let left_attr_refs = self.memo.get_attribute_refs(left_group_id); + let right_attr_refs = self.memo.get_attribute_refs(right_group_id); let input_correlation = get_input_correlation(left_attr_refs, right_attr_refs); self.get_join_selectivity_from_expr_tree( diff --git a/optd-cost-model/src/cost_model.rs b/optd-cost-model/src/cost_model.rs index 9126b6b..4e12a51 100644 --- a/optd-cost-model/src/cost_model.rs +++ b/optd-cost-model/src/cost_model.rs @@ -155,15 +155,15 @@ pub mod tests { use super::*; - const TEST_TABLE1_ID: TableId = TableId(0); - const TEST_TABLE2_ID: TableId = TableId(1); - const TEST_TABLE3_ID: TableId = TableId(2); - const TEST_TABLE4_ID: TableId = TableId(3); + pub const TEST_TABLE1_ID: TableId = TableId(0); + pub const TEST_TABLE2_ID: TableId = TableId(1); + pub const TEST_TABLE3_ID: TableId = TableId(2); + pub const TEST_TABLE4_ID: TableId = TableId(3); - const TEST_GROUP1_ID: GroupId = GroupId(0); - const TEST_GROUP2_ID: GroupId = GroupId(1); - const TEST_GROUP3_ID: GroupId = GroupId(2); - const TEST_GROUP4_ID: GroupId = GroupId(3); + pub const TEST_GROUP1_ID: GroupId = GroupId(0); + pub const TEST_GROUP2_ID: GroupId = GroupId(1); + pub const TEST_GROUP3_ID: GroupId = GroupId(2); + pub const TEST_GROUP4_ID: GroupId = GroupId(3); pub type 
TestPerAttributeStats = AttributeCombValueStats; // TODO: add tests for non-mock storage manager @@ -495,7 +495,7 @@ pub mod tests { } pub fn get_attr_refs(&self, group_id: GroupId) -> GroupAttrRefs { - self.memo.get_attribute_ref(group_id) + self.memo.get_attribute_refs(group_id) } } diff --git a/optd-cost-model/src/memo_ext.rs b/optd-cost-model/src/memo_ext.rs index 559ba6b..c7827c5 100644 --- a/optd-cost-model/src/memo_ext.rs +++ b/optd-cost-model/src/memo_ext.rs @@ -1,5 +1,9 @@ use crate::common::{ - properties::{attr_ref::GroupAttrRefs, schema::Schema, Attribute}, + properties::{ + attr_ref::{AttrRef, GroupAttrRefs}, + schema::Schema, + Attribute, + }, types::GroupId, }; @@ -13,10 +17,12 @@ use crate::common::{ pub trait MemoExt: Send + Sync + 'static { /// Get the schema of a group in the memo. fn get_schema(&self, group_id: GroupId) -> Schema; - /// Get the attribute reference of a group in the memo. - fn get_attribute_ref(&self, group_id: GroupId) -> GroupAttrRefs; - /// Get the attribute information of a given attribute in a group in the memo. + /// Get the attribute info of a given attribute in a group in the memo. fn get_attribute_info(&self, group_id: GroupId, attr_ref_idx: u64) -> Attribute; + /// Get the attribute reference of a group in the memo. + fn get_attribute_refs(&self, group_id: GroupId) -> GroupAttrRefs; + /// Get the attribute reference of a given attribute in a group in the memo. + fn get_attribute_ref(&self, group_id: GroupId, attr_ref_idx: u64) -> AttrRef; // TODO: Figure out what other information is needed to compute the cost... 
} @@ -26,18 +32,22 @@ pub mod tests { use std::collections::HashMap; use crate::common::{ - properties::{attr_ref::GroupAttrRefs, schema::Schema, Attribute}, + properties::{ + attr_ref::{AttrRef, GroupAttrRefs}, + schema::Schema, + Attribute, + }, types::GroupId, }; pub struct MemoGroupInfo { pub schema: Schema, - pub attr_ref: GroupAttrRefs, + pub attr_refs: GroupAttrRefs, } impl MemoGroupInfo { - pub fn new(schema: Schema, attr_ref: GroupAttrRefs) -> Self { - Self { schema, attr_ref } + pub fn new(schema: Schema, attr_refs: GroupAttrRefs) -> Self { + Self { schema, attr_refs } } } @@ -63,13 +73,17 @@ pub mod tests { self.memo.get(&group_id).unwrap().schema.clone() } - fn get_attribute_ref(&self, group_id: GroupId) -> GroupAttrRefs { - self.memo.get(&group_id).unwrap().attr_ref.clone() - } - fn get_attribute_info(&self, group_id: GroupId, attr_ref_idx: u64) -> Attribute { self.memo.get(&group_id).unwrap().schema.attributes[attr_ref_idx as usize].clone() } + + fn get_attribute_refs(&self, group_id: GroupId) -> GroupAttrRefs { + self.memo.get(&group_id).unwrap().attr_refs.clone() + } + + fn get_attribute_ref(&self, group_id: GroupId, attr_ref_idx: u64) -> AttrRef { + self.memo.get(&group_id).unwrap().attr_refs.attr_refs()[attr_ref_idx as usize].clone() + } } impl From> for MockMemoExtImpl { From 8c4191f273207b0c66511ca07e46d1462515eb94 Mon Sep 17 00:00:00 2001 From: Yuanxin Cao Date: Mon, 18 Nov 2024 17:10:51 -0500 Subject: [PATCH 46/51] rename AttrRefPred -> AttrIndexPred and revert back to initial design --- optd-cost-model/src/common/nodes.rs | 5 +- .../src/common/predicates/attr_index_pred.rs | 42 +++++ .../src/common/predicates/attr_ref_pred.rs | 74 -------- .../src/common/predicates/id_pred.rs | 43 ----- optd-cost-model/src/common/predicates/mod.rs | 3 +- optd-cost-model/src/cost/agg.rs | 31 ++-- optd-cost-model/src/cost/filter/comp_op.rs | 22 +-- optd-cost-model/src/cost/filter/controller.rs | 160 +++++++++++++----- optd-cost-model/src/cost/filter/in_list.rs 
| 20 +-- optd-cost-model/src/cost/filter/like.rs | 18 +- optd-cost-model/src/cost/join/hash_join.rs | 2 +- optd-cost-model/src/cost/join/join.rs | 10 +- optd-cost-model/src/cost/join/mod.rs | 11 +- optd-cost-model/src/cost_model.rs | 19 +-- 14 files changed, 230 insertions(+), 230 deletions(-) create mode 100644 optd-cost-model/src/common/predicates/attr_index_pred.rs delete mode 100644 optd-cost-model/src/common/predicates/attr_ref_pred.rs delete mode 100644 optd-cost-model/src/common/predicates/id_pred.rs diff --git a/optd-cost-model/src/common/nodes.rs b/optd-cost-model/src/common/nodes.rs index 8ad98e4..79a47f7 100644 --- a/optd-cost-model/src/common/nodes.rs +++ b/optd-cost-model/src/common/nodes.rs @@ -56,10 +56,7 @@ impl std::fmt::Display for PhysicalNodeType { pub enum PredicateType { List, Constant(ConstantType), - AttrRef, - ExternAttributeRef, - // TODO(lanlou): Id -> Id(IdType) - Id, + AttrIndex, UnOp(UnOpType), BinOp(BinOpType), LogOp(LogOpType), diff --git a/optd-cost-model/src/common/predicates/attr_index_pred.rs b/optd-cost-model/src/common/predicates/attr_index_pred.rs new file mode 100644 index 0000000..412c7a3 --- /dev/null +++ b/optd-cost-model/src/common/predicates/attr_index_pred.rs @@ -0,0 +1,42 @@ +use crate::common::{ + nodes::{ArcPredicateNode, PredicateNode, PredicateType, ReprPredicateNode}, + values::Value, +}; + +/// [`AttributeIndexPred`] represents the position of an attribute in a schema or +/// [`GroupAttrRefs`]. +/// +/// The `data` field holds the index of the attribute in the schema or [`GroupAttrRefs`]. +#[derive(Clone, Debug)] +pub struct AttrIndexPred(pub ArcPredicateNode); + +impl AttrIndexPred { + pub fn new(attr_idx: u64) -> AttrIndexPred { + AttrIndexPred( + PredicateNode { + typ: PredicateType::AttrIndex, + children: vec![], + data: Some(Value::UInt64(attr_idx)), + } + .into(), + ) + } + + /// Gets the attribute index. 
+ pub fn attr_index(&self) -> u64 { + self.0.data.as_ref().unwrap().as_u64() + } +} + +impl ReprPredicateNode for AttrIndexPred { + fn into_pred_node(self) -> ArcPredicateNode { + self.0 + } + + fn from_pred_node(pred_node: ArcPredicateNode) -> Option { + if pred_node.typ != PredicateType::AttrIndex { + return None; + } + Some(Self(pred_node)) + } +} diff --git a/optd-cost-model/src/common/predicates/attr_ref_pred.rs b/optd-cost-model/src/common/predicates/attr_ref_pred.rs deleted file mode 100644 index 9afe6a0..0000000 --- a/optd-cost-model/src/common/predicates/attr_ref_pred.rs +++ /dev/null @@ -1,74 +0,0 @@ -use crate::common::{ - nodes::{ArcPredicateNode, PredicateNode, PredicateType, ReprPredicateNode}, - types::TableId, -}; - -use super::id_pred::IdPred; - -/// [`AttributeRefPred`] represents a reference to a column in a relation. -/// -/// An [`AttributeRefPred`] has two children: -/// 1. The table id, represented by an [`IdPred`]. -/// 2. The index of the attribute, represented by an [`IdPred`]. -/// -/// Although it may be strange at first glance (table id and attribute base index -/// aren't children of the attribute reference), but considering the attribute reference -/// can be represented as table_id.attr_base_index, and it enables the cost model to -/// obtain the information in a simple way without refactoring `data` field. -/// -/// **TODO**: Now we assume any IdPred is as same as the ones in the ORM layer. -/// -/// Currently, [`AttributeRefPred`] only holds base table attributes, i.e. attributes -/// that already exist in the table. More complex structures may be introduced in the -/// future to represent derived attributes (e.g. t.v1 + t.v2). -/// -/// TODO: Support derived column in `AttributeRefPred`. -/// Proposal: Data field can store the column type (base or derived). 
-#[derive(Clone, Debug)] -pub struct AttrRefPred(pub ArcPredicateNode); - -impl AttrRefPred { - pub fn new(table_id: TableId, attribute_idx: u64) -> AttrRefPred { - AttrRefPred( - PredicateNode { - typ: PredicateType::AttrRef, - children: vec![ - IdPred::new(table_id.0).into_pred_node(), - IdPred::new(attribute_idx).into_pred_node(), - ], - data: None, - } - .into(), - ) - } - - /// Gets the table id. - pub fn table_id(&self) -> TableId { - TableId(self.0.child(0).data.as_ref().unwrap().as_u64()) - } - - /// Gets the attribute index. - /// Note: The attribute index is the **base** index, which is table specific. - pub fn attr_index(&self) -> u64 { - self.0.child(1).data.as_ref().unwrap().as_u64() - } - - /// Checks whether the attribute is a derived attribute. Currently, this will always return - /// false, since derived attribute is not yet supported. - pub fn is_derived(&self) -> bool { - false - } -} - -impl ReprPredicateNode for AttrRefPred { - fn into_pred_node(self) -> ArcPredicateNode { - self.0 - } - - fn from_pred_node(pred_node: ArcPredicateNode) -> Option { - if pred_node.typ != PredicateType::AttrRef { - return None; - } - Some(Self(pred_node)) - } -} diff --git a/optd-cost-model/src/common/predicates/id_pred.rs b/optd-cost-model/src/common/predicates/id_pred.rs deleted file mode 100644 index 13f557f..0000000 --- a/optd-cost-model/src/common/predicates/id_pred.rs +++ /dev/null @@ -1,43 +0,0 @@ -use crate::common::{ - nodes::{ArcPredicateNode, PredicateNode, PredicateType, ReprPredicateNode}, - values::Value, -}; - -/// [`IdPred`] holds an id or an index, e.g. table id. -/// -/// The data is of uint64 type, because an id or an index can always be -/// represented by uint64. 
-#[derive(Clone, Debug)] -pub struct IdPred(pub ArcPredicateNode); - -impl IdPred { - pub fn new(id: u64) -> IdPred { - IdPred( - PredicateNode { - typ: PredicateType::Id, - children: vec![], - data: Some(Value::UInt64(id)), - } - .into(), - ) - } - - /// Gets the id stored in the predicate. - pub fn id(&self) -> u64 { - self.0.data.clone().unwrap().as_u64() - } -} - -impl ReprPredicateNode for IdPred { - fn into_pred_node(self) -> ArcPredicateNode { - self.0 - } - - fn from_pred_node(pred_node: ArcPredicateNode) -> Option { - if let PredicateType::Id = pred_node.typ { - Some(Self(pred_node)) - } else { - None - } - } -} diff --git a/optd-cost-model/src/common/predicates/mod.rs b/optd-cost-model/src/common/predicates/mod.rs index 65d6ad0..40c64cf 100644 --- a/optd-cost-model/src/common/predicates/mod.rs +++ b/optd-cost-model/src/common/predicates/mod.rs @@ -1,10 +1,9 @@ -pub mod attr_ref_pred; +pub mod attr_index_pred; pub mod bin_op_pred; pub mod cast_pred; pub mod constant_pred; pub mod data_type_pred; pub mod func_pred; -pub mod id_pred; pub mod in_list_pred; pub mod like_pred; pub mod list_pred; diff --git a/optd-cost-model/src/cost/agg.rs b/optd-cost-model/src/cost/agg.rs index abe9d54..254a6ad 100644 --- a/optd-cost-model/src/cost/agg.rs +++ b/optd-cost-model/src/cost/agg.rs @@ -1,7 +1,7 @@ use crate::{ common::{ nodes::{ArcPredicateNode, PredicateType, ReprPredicateNode}, - predicates::{attr_ref_pred::AttrRefPred, list_pred::ListPred}, + predicates::{attr_index_pred::AttrIndexPred, list_pred::ListPred}, types::TableId, }, cost_model::CostModelImpl, @@ -25,17 +25,18 @@ impl CostModelImpl { for node in &group_by.0.children { match node.typ { - PredicateType::AttrRef => { + PredicateType::AttrIndex => { let attr_ref = - AttrRefPred::from_pred_node(node.clone()).ok_or_else(|| { + AttrIndexPred::from_pred_node(node.clone()).ok_or_else(|| { SemanticError::InvalidPredicate( "Expected AttributeRef predicate".to_string(), ) })?; - if attr_ref.is_derived() { + let 
is_derived = todo!(); + if is_derived { row_cnt *= DEFAULT_NUM_DISTINCT; } else { - let table_id = attr_ref.table_id(); + let table_id = todo!(); let attr_idx = attr_ref.attr_index(); // TODO: Only query ndistinct instead of all kinds of stats. let stats_option = @@ -74,7 +75,7 @@ mod tests { values::Value, }, cost_model::tests::{ - attr_ref, cnst, create_mock_cost_model, empty_list, empty_per_attr_stats, list, + attr_index, cnst, create_mock_cost_model, empty_list, empty_per_attr_stats, list, TestPerAttributeStats, }, stats::{utilities::simple_map::SimpleMap, MostCommonValues, DEFAULT_NUM_DISTINCT}, @@ -94,14 +95,14 @@ mod tests { ); // Group by single column should return the default value since there are no stats. - let group_bys = list(vec![attr_ref(table_id, 0)]); + let group_bys = list(vec![attr_index(0)]); assert_eq!( cost_model.get_agg_row_cnt(group_bys).await.unwrap(), EstimatedStatistic(DEFAULT_NUM_DISTINCT as f64) ); // Group by two columns should return the default value squared since there are no stats. - let group_bys = list(vec![attr_ref(table_id, 0), attr_ref(table_id, 1)]); + let group_bys = list(vec![attr_index(0), attr_index(1)]); assert_eq!( cost_model.get_agg_row_cnt(group_bys).await.unwrap(), EstimatedStatistic((DEFAULT_NUM_DISTINCT * DEFAULT_NUM_DISTINCT) as f64) @@ -149,17 +150,14 @@ mod tests { ); // Group by single column should return the n-distinct of the column. - let group_bys = list(vec![attr_ref(table_id, attr1_base_idx)]); + let group_bys = list(vec![attr_index(attr1_base_idx)]); // TODO: Fix this assert_eq!( cost_model.get_agg_row_cnt(group_bys).await.unwrap(), EstimatedStatistic(attr1_ndistinct as f64) ); // Group by two columns should return the product of the n-distinct of the columns. 
- let group_bys = list(vec![ - attr_ref(table_id, attr1_base_idx), - attr_ref(table_id, attr2_base_idx), - ]); + let group_bys = list(vec![attr_index(attr1_base_idx), attr_index(attr2_base_idx)]); // TODO: Fix this assert_eq!( cost_model.get_agg_row_cnt(group_bys).await.unwrap(), EstimatedStatistic((attr1_ndistinct * attr2_ndistinct) as f64) @@ -168,9 +166,10 @@ mod tests { // Group by multiple columns should return the product of the n-distinct of the columns. If one of the columns // does not have stats, it should use the default value instead. let group_bys = list(vec![ - attr_ref(table_id, attr1_base_idx), - attr_ref(table_id, attr2_base_idx), - attr_ref(table_id, attr3_base_idx), + // TODO: Fix this + attr_index(attr1_base_idx), + attr_index(attr2_base_idx), + attr_index(attr3_base_idx), ]); assert_eq!( cost_model.get_agg_row_cnt(group_bys).await.unwrap(), diff --git a/optd-cost-model/src/cost/filter/comp_op.rs b/optd-cost-model/src/cost/filter/comp_op.rs index 9712c82..6d062d3 100644 --- a/optd-cost-model/src/cost/filter/comp_op.rs +++ b/optd-cost-model/src/cost/filter/comp_op.rs @@ -4,7 +4,7 @@ use crate::{ common::{ nodes::{ArcPredicateNode, PredicateType, ReprPredicateNode}, predicates::{ - attr_ref_pred::AttrRefPred, bin_op_pred::BinOpType, cast_pred::CastPred, + attr_index_pred::AttrIndexPred, bin_op_pred::BinOpType, cast_pred::CastPred, constant_pred::ConstantPred, }, values::Value, @@ -41,7 +41,7 @@ impl CostModelImpl { .first() .expect("we just checked that attr_ref_exprs.len() == 1"); let attr_ref_idx = attr_ref_expr.attr_index(); - let table_id = attr_ref_expr.table_id(); + let table_id = todo!(); // TODO: Consider attribute is a derived attribute if values.len() == 1 { @@ -118,7 +118,7 @@ impl CostModelImpl { &self, left: ArcPredicateNode, right: ArcPredicateNode, - ) -> CostModelResult<(Vec, Vec, Vec, bool)> { + ) -> CostModelResult<(Vec, Vec, Vec, bool)> { let mut attr_ref_exprs = vec![]; let mut values = vec![]; let mut non_attr_ref_exprs = 
vec![]; @@ -166,11 +166,11 @@ impl CostModelImpl { .into_pred_node(); false } - PredicateType::AttrRef => { - let attr_ref_expr = AttrRefPred::from_pred_node(cast_expr_child) + PredicateType::AttrIndex => { + let attr_ref_expr = AttrIndexPred::from_pred_node(cast_expr_child) .expect("we already checked that the type is AttributeRef"); let attr_ref_idx = attr_ref_expr.attr_index(); - let table_id = attr_ref_expr.table_id(); + let table_id = todo!(); cast_node = attr_ref_expr.into_pred_node(); // The "invert" cast is to invert the cast so that we're casting the // non_cast_node to the attribute's original type. @@ -185,7 +185,7 @@ impl CostModelImpl { let invert_cast_data_type = &attribute_info.typ.into_data_type(); match non_cast_node.typ { - PredicateType::AttrRef => { + PredicateType::AttrIndex => { // In general, there's no way to remove the Cast here. We can't move // the Cast to the other AttributeRef // because that would lead to an infinite loop. Thus, we just leave @@ -219,10 +219,10 @@ impl CostModelImpl { // Sort nodes into attr_ref_exprs, values, and non_attr_ref_exprs match uncasted_left.as_ref().typ { - PredicateType::AttrRef => { + PredicateType::AttrIndex => { is_left_attr_ref = true; attr_ref_exprs.push( - AttrRefPred::from_pred_node(uncasted_left) + AttrIndexPred::from_pred_node(uncasted_left) .expect("we already checked that the type is AttributeRef"), ); } @@ -240,9 +240,9 @@ impl CostModelImpl { } } match uncasted_right.as_ref().typ { - PredicateType::AttrRef => { + PredicateType::AttrIndex => { attr_ref_exprs.push( - AttrRefPred::from_pred_node(uncasted_right) + AttrIndexPred::from_pred_node(uncasted_right) .expect("we already checked that the type is AttributeRef"), ); } diff --git a/optd-cost-model/src/cost/filter/controller.rs b/optd-cost-model/src/cost/filter/controller.rs index d8fe9b5..54bdc7d 100644 --- a/optd-cost-model/src/cost/filter/controller.rs +++ b/optd-cost-model/src/cost/filter/controller.rs @@ -28,7 +28,7 @@ impl CostModelImpl 
{ Box::pin(async move { match &expr_tree.typ { PredicateType::Constant(_) => Ok(Self::get_constant_selectivity(expr_tree)), - PredicateType::AttrRef => unimplemented!("check bool type or else panic"), + PredicateType::AttrIndex => unimplemented!("check bool type or else panic"), PredicateType::UnOp(un_op_typ) => { assert!(expr_tree.children.len() == 1); let child = expr_tree.child(0); @@ -149,8 +149,16 @@ mod tests { vec![None], ); - let expr_tree = bin_op(BinOpType::Eq, attr_ref(table_id, 0), cnst(Value::Int32(1))); - let expr_tree_rev = bin_op(BinOpType::Eq, cnst(Value::Int32(1)), attr_ref(table_id, 0)); + let expr_tree = bin_op( + BinOpType::Eq, + attr_index(0), // TODO: Fix this + cnst(Value::Int32(1)), + ); + let expr_tree_rev = bin_op( + BinOpType::Eq, + cnst(Value::Int32(1)), + attr_index(0), // TODO: Fix this + ); assert_approx_eq::assert_approx_eq!( cost_model.get_filter_selectivity(expr_tree).await.unwrap(), 0.3 @@ -182,8 +190,16 @@ mod tests { vec![None], ); - let expr_tree = bin_op(BinOpType::Eq, attr_ref(table_id, 0), cnst(Value::Int32(2))); - let expr_tree_rev = bin_op(BinOpType::Eq, cnst(Value::Int32(2)), attr_ref(table_id, 0)); + let expr_tree = bin_op( + BinOpType::Eq, + attr_index(0), // TODO: Fix this + cnst(Value::Int32(2)), + ); + let expr_tree_rev = bin_op( + BinOpType::Eq, + cnst(Value::Int32(2)), + attr_index(0), // TODO: Fix this + ); assert_approx_eq::assert_approx_eq!( cost_model.get_filter_selectivity(expr_tree).await.unwrap(), 0.12 @@ -216,8 +232,16 @@ mod tests { vec![None], ); - let expr_tree = bin_op(BinOpType::Neq, attr_ref(table_id, 0), cnst(Value::Int32(1))); - let expr_tree_rev = bin_op(BinOpType::Neq, cnst(Value::Int32(1)), attr_ref(table_id, 0)); + let expr_tree = bin_op( + BinOpType::Neq, + attr_index(0), // TODO: Fix this + cnst(Value::Int32(1)), + ); + let expr_tree_rev = bin_op( + BinOpType::Neq, + cnst(Value::Int32(1)), + attr_index(0), // TODO: Fix this + ); assert_approx_eq::assert_approx_eq!( 
cost_model.get_filter_selectivity(expr_tree).await.unwrap(), 1.0 - 0.3 @@ -251,10 +275,14 @@ mod tests { let expr_tree = bin_op( BinOpType::Leq, - attr_ref(table_id, 0), + attr_index(0), // TODO: Fix this cnst(Value::Int32(15)), ); - let expr_tree_rev = bin_op(BinOpType::Gt, cnst(Value::Int32(15)), attr_ref(table_id, 0)); + let expr_tree_rev = bin_op( + BinOpType::Gt, + cnst(Value::Int32(15)), + attr_index(0), // TODO: Fix this + ); assert_approx_eq::assert_approx_eq!( cost_model.get_filter_selectivity(expr_tree).await.unwrap(), 0.7 @@ -293,10 +321,14 @@ mod tests { let expr_tree = bin_op( BinOpType::Leq, - attr_ref(table_id, 0), + attr_index(0), // TODO: Fix this cnst(Value::Int32(15)), ); - let expr_tree_rev = bin_op(BinOpType::Gt, cnst(Value::Int32(15)), attr_ref(table_id, 0)); + let expr_tree_rev = bin_op( + BinOpType::Gt, + cnst(Value::Int32(15)), + attr_index(0), // TODO: Fix this + ); assert_approx_eq::assert_approx_eq!( cost_model.get_filter_selectivity(expr_tree).await.unwrap(), 0.85 @@ -335,10 +367,14 @@ mod tests { let expr_tree = bin_op( BinOpType::Leq, - attr_ref(table_id, 0), + attr_index(0), // TODO: Fix this + cnst(Value::Int32(15)), + ); + let expr_tree_rev = bin_op( + BinOpType::Gt, cnst(Value::Int32(15)), + attr_index(0), // TODO: Fix this ); - let expr_tree_rev = bin_op(BinOpType::Gt, cnst(Value::Int32(15)), attr_ref(table_id, 0)); assert_approx_eq::assert_approx_eq!( cost_model.get_filter_selectivity(expr_tree).await.unwrap(), 0.93 @@ -370,11 +406,15 @@ mod tests { vec![None], ); - let expr_tree = bin_op(BinOpType::Lt, attr_ref(table_id, 0), cnst(Value::Int32(15))); + let expr_tree = bin_op( + BinOpType::Lt, + attr_index(0), // TODO: Fix this + cnst(Value::Int32(15)), + ); let expr_tree_rev = bin_op( BinOpType::Geq, cnst(Value::Int32(15)), - attr_ref(table_id, 0), + attr_index(0), // TODO: Fix this ); assert_approx_eq::assert_approx_eq!( cost_model.get_filter_selectivity(expr_tree).await.unwrap(), @@ -413,11 +453,15 @@ mod tests { vec![None], 
); - let expr_tree = bin_op(BinOpType::Lt, attr_ref(table_id, 0), cnst(Value::Int32(15))); + let expr_tree = bin_op( + BinOpType::Lt, + attr_index(0), // TODO: Fix this + cnst(Value::Int32(15)), + ); let expr_tree_rev = bin_op( BinOpType::Geq, cnst(Value::Int32(15)), - attr_ref(table_id, 0), + attr_index(0), // TODO: Fix this ); assert_approx_eq::assert_approx_eq!( cost_model.get_filter_selectivity(expr_tree).await.unwrap(), @@ -456,11 +500,15 @@ mod tests { vec![None], ); - let expr_tree = bin_op(BinOpType::Lt, attr_ref(table_id, 0), cnst(Value::Int32(15))); + let expr_tree = bin_op( + BinOpType::Lt, + attr_index(0), // TODO: Fix this + cnst(Value::Int32(15)), + ); let expr_tree_rev = bin_op( BinOpType::Geq, cnst(Value::Int32(15)), - attr_ref(table_id, 0), + attr_index(0), // TODO: Fix this ); assert_approx_eq::assert_approx_eq!( cost_model.get_filter_selectivity(expr_tree).await.unwrap(), @@ -495,11 +543,15 @@ mod tests { vec![None], ); - let expr_tree = bin_op(BinOpType::Gt, attr_ref(table_id, 0), cnst(Value::Int32(15))); + let expr_tree = bin_op( + BinOpType::Gt, + attr_index(0), // TODO: Fix this + cnst(Value::Int32(15)), + ); let expr_tree_rev = bin_op( BinOpType::Leq, cnst(Value::Int32(15)), - attr_ref(table_id, 0), + attr_index(0), // TODO: Fix this ); assert_approx_eq::assert_approx_eq!( cost_model.get_filter_selectivity(expr_tree).await.unwrap(), @@ -534,10 +586,14 @@ mod tests { let expr_tree = bin_op( BinOpType::Geq, - attr_ref(table_id, 0), + attr_index(0), // TODO: Fix this + cnst(Value::Int32(15)), + ); + let expr_tree_rev = bin_op( + BinOpType::Lt, cnst(Value::Int32(15)), + attr_index(0), // TODO: Fix this ); - let expr_tree_rev = bin_op(BinOpType::Lt, cnst(Value::Int32(15)), attr_ref(table_id, 0)); assert_approx_eq::assert_approx_eq!( cost_model.get_filter_selectivity(expr_tree).await.unwrap(), @@ -571,9 +627,21 @@ mod tests { vec![None], ); - let eq1 = bin_op(BinOpType::Eq, attr_ref(table_id, 0), cnst(Value::Int32(1))); - let eq5 = 
bin_op(BinOpType::Eq, attr_ref(table_id, 0), cnst(Value::Int32(5))); - let eq8 = bin_op(BinOpType::Eq, attr_ref(table_id, 0), cnst(Value::Int32(8))); + let eq1 = bin_op( + BinOpType::Eq, + attr_index(0), // TODO: Fix this + cnst(Value::Int32(1)), + ); + let eq5 = bin_op( + BinOpType::Eq, + attr_index(0), // TODO: Fix this + cnst(Value::Int32(5)), + ); + let eq8 = bin_op( + BinOpType::Eq, + attr_index(0), // TODO: Fix this + cnst(Value::Int32(8)), + ); let expr_tree = log_op(LogOpType::And, vec![eq1.clone(), eq5.clone(), eq8.clone()]); let expr_tree_shift1 = log_op(LogOpType::And, vec![eq5.clone(), eq8.clone(), eq1.clone()]); let expr_tree_shift2 = log_op(LogOpType::And, vec![eq8.clone(), eq1.clone(), eq5.clone()]); @@ -617,9 +685,21 @@ mod tests { vec![None], ); - let eq1 = bin_op(BinOpType::Eq, attr_ref(table_id, 0), cnst(Value::Int32(1))); - let eq5 = bin_op(BinOpType::Eq, attr_ref(table_id, 0), cnst(Value::Int32(5))); - let eq8 = bin_op(BinOpType::Eq, attr_ref(table_id, 0), cnst(Value::Int32(8))); + let eq1 = bin_op( + BinOpType::Eq, + attr_index(0), // TODO: Fix this + cnst(Value::Int32(1)), + ); + let eq5 = bin_op( + BinOpType::Eq, + attr_index(0), // TODO: Fix this + cnst(Value::Int32(5)), + ); + let eq8 = bin_op( + BinOpType::Eq, + attr_index(0), // TODO: Fix this + cnst(Value::Int32(8)), + ); let expr_tree = log_op(LogOpType::Or, vec![eq1.clone(), eq5.clone(), eq8.clone()]); let expr_tree_shift1 = log_op(LogOpType::Or, vec![eq5.clone(), eq8.clone(), eq1.clone()]); let expr_tree_shift2 = log_op(LogOpType::Or, vec![eq8.clone(), eq1.clone(), eq5.clone()]); @@ -664,7 +744,11 @@ mod tests { let expr_tree = un_op( UnOpType::Not, - bin_op(BinOpType::Eq, attr_ref(table_id, 0), cnst(Value::Int32(1))), + bin_op( + BinOpType::Eq, + attr_index(0), // TODO: Fix this + cnst(Value::Int32(1)), + ), ); assert_approx_eq::assert_approx_eq!( @@ -696,13 +780,13 @@ mod tests { let expr_tree = bin_op( BinOpType::Eq, - attr_ref(table_id, 0), + attr_index(0), // TODO: Fix this 
cast(cnst(Value::Int64(1)), DataType::Int32), ); let expr_tree_rev = bin_op( BinOpType::Eq, cast(cnst(Value::Int64(1)), DataType::Int32), - attr_ref(table_id, 0), + attr_index(0), // TODO: Fix this ); assert_approx_eq::assert_approx_eq!( @@ -746,13 +830,13 @@ mod tests { let expr_tree = bin_op( BinOpType::Eq, - cast(attr_ref(table_id, 0), DataType::Int64), + cast(attr_index(0), DataType::Int64), // TODO: Fix this cnst(Value::Int64(1)), ); let expr_tree_rev = bin_op( BinOpType::Eq, cnst(Value::Int64(1)), - cast(attr_ref(table_id, 0), DataType::Int64), + cast(attr_index(0), DataType::Int64), // TODO: Fix this ); assert_approx_eq::assert_approx_eq!( @@ -805,13 +889,13 @@ mod tests { let expr_tree = bin_op( BinOpType::Eq, - cast(attr_ref(table_id, 0), DataType::Int64), - attr_ref(table_id, 1), + cast(attr_index(0), DataType::Int64), // TODO: Fix this + attr_index(1), // TODO: Fix this ); let expr_tree_rev = bin_op( BinOpType::Eq, - attr_ref(table_id, 1), - cast(attr_ref(table_id, 0), DataType::Int64), + attr_index(1), // TODO: Fix this + cast(attr_index(0), DataType::Int64), // TODO: Fix this ); assert_approx_eq::assert_approx_eq!( diff --git a/optd-cost-model/src/cost/filter/in_list.rs b/optd-cost-model/src/cost/filter/in_list.rs index 36a3791..69b9b4f 100644 --- a/optd-cost-model/src/cost/filter/in_list.rs +++ b/optd-cost-model/src/cost/filter/in_list.rs @@ -2,7 +2,7 @@ use crate::{ common::{ nodes::{PredicateType, ReprPredicateNode}, predicates::{ - attr_ref_pred::AttrRefPred, constant_pred::ConstantPred, in_list_pred::InListPred, + attr_index_pred::AttrIndexPred, constant_pred::ConstantPred, in_list_pred::InListPred, }, }, cost_model::CostModelImpl, @@ -18,7 +18,7 @@ impl CostModelImpl { let child = expr.child(); // Check child is a attribute ref. 
- if !matches!(child.typ, PredicateType::AttrRef) { + if !matches!(child.typ, PredicateType::AttrIndex) { return Ok(UNIMPLEMENTED_SEL); } @@ -32,9 +32,9 @@ impl CostModelImpl { } // Convert child and const expressions to concrete types. - let attr_ref_pred = AttrRefPred::from_pred_node(child).unwrap(); + let attr_ref_pred = AttrIndexPred::from_pred_node(child).unwrap(); let attr_ref_idx = attr_ref_pred.attr_index(); - let table_id = attr_ref_pred.table_id(); + let table_id = todo!(); // TODO: Fix this let list_exprs = list_exprs .into_iter() .map(|expr| { @@ -99,7 +99,7 @@ mod tests { assert_approx_eq::assert_approx_eq!( cost_model - .get_in_list_selectivity(&in_list(table_id, 0, vec![Value::Int32(1)], false)) + .get_in_list_selectivity(&in_list(0, vec![Value::Int32(1)], false)) // TODO: Fix this .await .unwrap(), 0.8 @@ -107,7 +107,7 @@ mod tests { assert_approx_eq::assert_approx_eq!( cost_model .get_in_list_selectivity(&in_list( - table_id, + // TODO: Fix this 0, vec![Value::Int32(1), Value::Int32(2)], false @@ -118,14 +118,14 @@ mod tests { ); assert_approx_eq::assert_approx_eq!( cost_model - .get_in_list_selectivity(&in_list(table_id, 0, vec![Value::Int32(3)], false)) + .get_in_list_selectivity(&in_list(0, vec![Value::Int32(3)], false)) // TODO: Fix this .await .unwrap(), 0.0 ); assert_approx_eq::assert_approx_eq!( cost_model - .get_in_list_selectivity(&in_list(table_id, 0, vec![Value::Int32(1)], true)) + .get_in_list_selectivity(&in_list(0, vec![Value::Int32(1)], true)) // TODO: Fix this .await .unwrap(), 0.2 @@ -133,7 +133,7 @@ mod tests { assert_approx_eq::assert_approx_eq!( cost_model .get_in_list_selectivity(&in_list( - table_id, + // TODO: Fix this 0, vec![Value::Int32(1), Value::Int32(2)], true @@ -144,7 +144,7 @@ mod tests { ); assert_approx_eq::assert_approx_eq!( cost_model - .get_in_list_selectivity(&in_list(table_id, 0, vec![Value::Int32(3)], true)) + .get_in_list_selectivity(&in_list(0, vec![Value::Int32(3)], true)) // TODO: Fix this .await 
.unwrap(), 1.0 diff --git a/optd-cost-model/src/cost/filter/like.rs b/optd-cost-model/src/cost/filter/like.rs index 7f62fcc..fb11833 100644 --- a/optd-cost-model/src/cost/filter/like.rs +++ b/optd-cost-model/src/cost/filter/like.rs @@ -4,7 +4,7 @@ use crate::{ common::{ nodes::{PredicateType, ReprPredicateNode}, predicates::{ - attr_ref_pred::AttrRefPred, constant_pred::ConstantPred, like_pred::LikePred, + attr_index_pred::AttrIndexPred, constant_pred::ConstantPred, like_pred::LikePred, }, }, cost_model::CostModelImpl, @@ -32,7 +32,7 @@ impl CostModelImpl { let child = like_expr.child(); // Check child is a attribute ref. - if !matches!(child.typ, PredicateType::AttrRef) { + if !matches!(child.typ, PredicateType::AttrIndex) { return Ok(UNIMPLEMENTED_SEL); } @@ -42,9 +42,9 @@ impl CostModelImpl { return Ok(UNIMPLEMENTED_SEL); } - let attr_ref_pred = AttrRefPred::from_pred_node(child).unwrap(); + let attr_ref_pred = AttrIndexPred::from_pred_node(child).unwrap(); let attr_ref_idx = attr_ref_pred.attr_index(); - let table_id = attr_ref_pred.table_id(); + let table_id = todo!(); // TODO: Fix this // TODO: Consider attribute is a derived attribute let pattern = ConstantPred::from_pred_node(pattern) @@ -133,21 +133,21 @@ mod tests { assert_approx_eq::assert_approx_eq!( cost_model - .get_like_selectivity(&like(table_id, 0, "%abcd%", false)) + .get_like_selectivity(&like(0, "%abcd%", false)) // TODO: Fix this .await .unwrap(), 0.1 + FULL_WILDCARD_SEL_FACTOR.powi(2) * FIXED_CHAR_SEL_FACTOR.powi(4) ); assert_approx_eq::assert_approx_eq!( cost_model - .get_like_selectivity(&like(table_id, 0, "%abc%", false)) + .get_like_selectivity(&like(0, "%abc%", false)) // TODO: Fix this .await .unwrap(), 0.1 + 0.1 + FULL_WILDCARD_SEL_FACTOR.powi(2) * FIXED_CHAR_SEL_FACTOR.powi(3) ); assert_approx_eq::assert_approx_eq!( cost_model - .get_like_selectivity(&like(table_id, 0, "%abc%", true)) + .get_like_selectivity(&like(0, "%abc%", true)) // TODO: Fix this .await .unwrap(), 1.0 - (0.1 + 0.1 
+ FULL_WILDCARD_SEL_FACTOR.powi(2) * FIXED_CHAR_SEL_FACTOR.powi(3)) @@ -175,14 +175,14 @@ mod tests { assert_approx_eq::assert_approx_eq!( cost_model - .get_like_selectivity(&like(table_id, 0, "%abcd%", false)) + .get_like_selectivity(&like(0, "%abcd%", false)) // TODO: Fix this .await .unwrap(), 0.1 + FULL_WILDCARD_SEL_FACTOR.powi(2) * FIXED_CHAR_SEL_FACTOR.powi(4) ); assert_approx_eq::assert_approx_eq!( cost_model - .get_like_selectivity(&like(table_id, 0, "%abcd%", true)) + .get_like_selectivity(&like(0, "%abcd%", true)) // TODO: Fix this .await .unwrap(), 1.0 - (0.1 + FULL_WILDCARD_SEL_FACTOR.powi(2) * FIXED_CHAR_SEL_FACTOR.powi(4)) diff --git a/optd-cost-model/src/cost/join/hash_join.rs b/optd-cost-model/src/cost/join/hash_join.rs index b912108..fec65b1 100644 --- a/optd-cost-model/src/cost/join/hash_join.rs +++ b/optd-cost-model/src/cost/join/hash_join.rs @@ -3,7 +3,7 @@ use itertools::Itertools; use crate::{ common::{ nodes::{JoinType, ReprPredicateNode}, - predicates::{attr_ref_pred::AttrRefPred, list_pred::ListPred}, + predicates::{attr_index_pred::AttrIndexPred, list_pred::ListPred}, properties::attr_ref::{AttrRefs, SemanticCorrelation}, types::GroupId, }, diff --git a/optd-cost-model/src/cost/join/join.rs b/optd-cost-model/src/cost/join/join.rs index 28d3029..eeb5847 100644 --- a/optd-cost-model/src/cost/join/join.rs +++ b/optd-cost-model/src/cost/join/join.rs @@ -6,7 +6,7 @@ use crate::{ common::{ nodes::{ArcPredicateNode, JoinType, PredicateType, ReprPredicateNode}, predicates::{ - attr_ref_pred::AttrRefPred, + attr_index_pred::AttrIndexPred, bin_op_pred::BinOpType, list_pred::ListPred, log_op_pred::{LogOpPred, LogOpType}, @@ -122,8 +122,8 @@ impl CostModelImpl { .zip(right_keys.to_vec()) .map(|(left_key, right_key)| { ( - AttrRefPred::from_pred_node(left_key).expect("keys should be AttrRefPreds"), - AttrRefPred::from_pred_node(right_key).expect("keys should be AttrRefPreds"), + AttrIndexPred::from_pred_node(left_key).expect("keys should be 
AttrRefPreds"), + AttrIndexPred::from_pred_node(right_key).expect("keys should be AttrRefPreds"), ) }) .collect_vec(); @@ -156,7 +156,7 @@ impl CostModelImpl { async fn get_join_selectivity_core( &self, join_typ: JoinType, - on_attr_ref_pairs: Vec<(AttrRefPred, AttrRefPred)>, + on_attr_ref_pairs: Vec<(AttrIndexPred, AttrIndexPred)>, filter_expr_tree: Option, attr_refs: &AttrRefs, input_correlation: Option, @@ -364,7 +364,7 @@ impl CostModelImpl { /// `get_join_selectivity_from_redundant_predicates`. async fn get_join_on_selectivity( &self, - on_attr_ref_pairs: &[(AttrRefPred, AttrRefPred)], + on_attr_ref_pairs: &[(AttrIndexPred, AttrIndexPred)], attr_refs: &AttrRefs, input_correlation: Option, right_attr_ref_offset: usize, diff --git a/optd-cost-model/src/cost/join/mod.rs b/optd-cost-model/src/cost/join/mod.rs index 8b29661..65ab5bb 100644 --- a/optd-cost-model/src/cost/join/mod.rs +++ b/optd-cost-model/src/cost/join/mod.rs @@ -1,6 +1,6 @@ use crate::common::{ nodes::{ArcPredicateNode, PredicateType, ReprPredicateNode}, - predicates::{attr_ref_pred::AttrRefPred, bin_op_pred::BinOpType}, + predicates::{attr_index_pred::AttrIndexPred, bin_op_pred::BinOpType}, properties::attr_ref::{ AttrRef, AttrRefs, BaseTableAttrRef, GroupAttrRefs, SemanticCorrelation, }, @@ -27,18 +27,19 @@ pub(crate) fn get_input_correlation( pub(crate) fn get_on_attr_ref_pair( expr_tree: ArcPredicateNode, attr_refs: &AttrRefs, -) -> Option<(AttrRefPred, AttrRefPred)> { +) -> Option<(AttrIndexPred, AttrIndexPred)> { // 1. Check that it's equality if expr_tree.typ == PredicateType::BinOp(BinOpType::Eq) { let left_child = expr_tree.child(0); let right_child = expr_tree.child(1); // 2. Check that both sides are attribute refs - if left_child.typ == PredicateType::AttrRef && right_child.typ == PredicateType::AttrRef { + if left_child.typ == PredicateType::AttrIndex && right_child.typ == PredicateType::AttrIndex + { // 3. 
Check that both sides don't belong to the same table (if we don't know, that // means they don't belong) - let left_attr_ref_expr = AttrRefPred::from_pred_node(left_child) + let left_attr_ref_expr = AttrIndexPred::from_pred_node(left_child) .expect("we already checked that the type is AttrRef"); - let right_attr_ref_expr = AttrRefPred::from_pred_node(right_child) + let right_attr_ref_expr = AttrIndexPred::from_pred_node(right_child) .expect("we already checked that the type is AttrRef"); let left_attr_ref = &attr_refs[left_attr_ref_expr.attr_index() as usize]; let right_attr_ref = &attr_refs[right_attr_ref_expr.attr_index() as usize]; diff --git a/optd-cost-model/src/cost_model.rs b/optd-cost-model/src/cost_model.rs index 4e12a51..552a272 100644 --- a/optd-cost-model/src/cost_model.rs +++ b/optd-cost-model/src/cost_model.rs @@ -128,7 +128,7 @@ pub mod tests { common::{ nodes::ReprPredicateNode, predicates::{ - attr_ref_pred::AttrRefPred, + attr_index_pred::AttrIndexPred, bin_op_pred::{BinOpPred, BinOpType}, cast_pred::CastPred, constant_pred::{ConstantPred, ConstantType}, @@ -499,8 +499,8 @@ pub mod tests { } } - pub fn attr_ref(table_id: TableId, attr_base_index: u64) -> ArcPredicateNode { - AttrRefPred::new(table_id, attr_base_index).into_pred_node() + pub fn attr_index(attr_index: u64) -> ArcPredicateNode { + AttrIndexPred::new(attr_index).into_pred_node() } pub fn cnst(value: Value) -> ArcPredicateNode { @@ -535,24 +535,19 @@ pub mod tests { ListPred::new(children).into_pred_node() } - pub fn in_list( - table_id: TableId, - attr_ref_idx: u64, - list: Vec, - negated: bool, - ) -> InListPred { + pub fn in_list(attr_idx: u64, list: Vec, negated: bool) -> InListPred { InListPred::new( - attr_ref(table_id, attr_ref_idx), + attr_index(attr_idx), ListPred::new(list.into_iter().map(cnst).collect_vec()), negated, ) } - pub fn like(table_id: TableId, attr_ref_idx: u64, pattern: &str, negated: bool) -> LikePred { + pub fn like(attr_idx: u64, pattern: &str, negated: bool) 
-> LikePred { LikePred::new( negated, false, - attr_ref(table_id, attr_ref_idx), + attr_index(attr_idx), cnst(Value::String(pattern.into())), ) } From 1569fc5976710d17f4f7852c61f25f2681b4bde4 Mon Sep 17 00:00:00 2001 From: Lan Lou Date: Mon, 18 Nov 2024 20:09:14 -0500 Subject: [PATCH 47/51] Modify the tests of filter and agg --- optd-cost-model/src/cost/agg.rs | 100 ++-- optd-cost-model/src/cost/filter/attribute.rs | 13 +- optd-cost-model/src/cost/filter/comp_op.rs | 133 ++--- optd-cost-model/src/cost/filter/controller.rs | 476 ++++++++---------- optd-cost-model/src/cost/filter/in_list.rs | 92 ++-- optd-cost-model/src/cost/filter/like.rs | 120 +++-- optd-cost-model/src/cost/filter/log_op.rs | 7 +- optd-cost-model/src/cost/join/join.rs | 6 +- optd-cost-model/src/cost/join/mod.rs | 1 + optd-cost-model/src/cost_model.rs | 92 +++- optd-cost-model/src/storage/mock.rs | 10 +- optd-cost-model/src/storage/mod.rs | 12 +- optd-cost-model/src/storage/persistent.rs | 10 +- 13 files changed, 581 insertions(+), 491 deletions(-) diff --git a/optd-cost-model/src/cost/agg.rs b/optd-cost-model/src/cost/agg.rs index 254a6ad..f5edc7a 100644 --- a/optd-cost-model/src/cost/agg.rs +++ b/optd-cost-model/src/cost/agg.rs @@ -2,7 +2,8 @@ use crate::{ common::{ nodes::{ArcPredicateNode, PredicateType, ReprPredicateNode}, predicates::{attr_index_pred::AttrIndexPred, list_pred::ListPred}, - types::TableId, + properties::attr_ref::{AttrRef, BaseTableAttrRef}, + types::GroupId, }, cost_model::CostModelImpl, stats::DEFAULT_NUM_DISTINCT, @@ -13,6 +14,7 @@ use crate::{ impl CostModelImpl { pub async fn get_agg_row_cnt( &self, + group_id: GroupId, group_by: ArcPredicateNode, ) -> CostModelResult { let group_by = ListPred::from_pred_node(group_by).unwrap(); @@ -32,12 +34,9 @@ impl CostModelImpl { "Expected AttributeRef predicate".to_string(), ) })?; - let is_derived = todo!(); - if is_derived { - row_cnt *= DEFAULT_NUM_DISTINCT; - } else { - let table_id = todo!(); - let attr_idx = 
attr_ref.attr_index(); + if let AttrRef::BaseTableAttrRef(BaseTableAttrRef { table_id, attr_idx }) = + self.memo.get_attribute_ref(group_id, attr_ref.attr_index()) + { // TODO: Only query ndistinct instead of all kinds of stats. let stats_option = self.get_attribute_comb_stats(table_id, &[attr_idx]).await?; @@ -50,6 +49,9 @@ impl CostModelImpl { } }; row_cnt *= ndistinct; + } else { + // TODO: Handle derived attributes. + row_cnt *= DEFAULT_NUM_DISTINCT; } } _ => { @@ -65,7 +67,7 @@ impl CostModelImpl { #[cfg(test)] mod tests { - use std::collections::HashMap; + use std::{collections::HashMap, ops::Deref}; use crate::{ common::{ @@ -75,8 +77,9 @@ mod tests { values::Value, }, cost_model::tests::{ - attr_index, cnst, create_mock_cost_model, empty_list, empty_per_attr_stats, list, - TestPerAttributeStats, + attr_index, cnst, create_mock_cost_model, create_mock_cost_model_with_attr_types, + empty_list, empty_per_attr_stats, list, TestPerAttributeStats, TEST_ATTR1_BASE_INDEX, + TEST_ATTR2_BASE_INDEX, TEST_ATTR3_BASE_INDEX, TEST_GROUP1_ID, TEST_TABLE1_ID, }, stats::{utilities::simple_map::SimpleMap, MostCommonValues, DEFAULT_NUM_DISTINCT}, EstimatedStatistic, @@ -84,39 +87,49 @@ mod tests { #[tokio::test] async fn test_agg_no_stats() { - let table_id = TableId(0); - let cost_model = create_mock_cost_model(vec![table_id], vec![], vec![None]); + let cost_model = create_mock_cost_model_with_attr_types( + vec![TEST_TABLE1_ID], + vec![], + vec![HashMap::from([ + (TEST_ATTR1_BASE_INDEX, ConstantType::Int32), + (TEST_ATTR2_BASE_INDEX, ConstantType::Int32), + ])], + vec![None], + ); // Group by empty list should return 1. let group_bys = empty_list(); assert_eq!( - cost_model.get_agg_row_cnt(group_bys).await.unwrap(), + cost_model + .get_agg_row_cnt(TEST_GROUP1_ID, group_bys) + .await + .unwrap(), EstimatedStatistic(1.0) ); // Group by single column should return the default value since there are no stats. 
let group_bys = list(vec![attr_index(0)]); assert_eq!( - cost_model.get_agg_row_cnt(group_bys).await.unwrap(), + cost_model + .get_agg_row_cnt(TEST_GROUP1_ID, group_bys) + .await + .unwrap(), EstimatedStatistic(DEFAULT_NUM_DISTINCT as f64) ); // Group by two columns should return the default value squared since there are no stats. let group_bys = list(vec![attr_index(0), attr_index(1)]); assert_eq!( - cost_model.get_agg_row_cnt(group_bys).await.unwrap(), + cost_model + .get_agg_row_cnt(TEST_GROUP1_ID, group_bys) + .await + .unwrap(), EstimatedStatistic((DEFAULT_NUM_DISTINCT * DEFAULT_NUM_DISTINCT) as f64) ); } #[tokio::test] async fn test_agg_with_stats() { - let table_id = TableId(0); - let group_id = GroupId(0); - let attr1_base_idx = 0; - let attr2_base_idx = 1; - let attr3_base_idx = 2; - let attr1_ndistinct = 12; let attr2_ndistinct = 645; let attr1_stats = TestPerAttributeStats::new( @@ -132,47 +145,58 @@ mod tests { 0.0, ); - let cost_model = create_mock_cost_model( - vec![table_id], + let cost_model = create_mock_cost_model_with_attr_types( + vec![TEST_TABLE1_ID], + vec![HashMap::from([ + (TEST_ATTR1_BASE_INDEX, attr1_stats), + (TEST_ATTR2_BASE_INDEX, attr2_stats), + ])], vec![HashMap::from([ - (attr1_base_idx, attr1_stats), - (attr2_base_idx, attr2_stats), + (TEST_ATTR1_BASE_INDEX, ConstantType::Int32), + (TEST_ATTR2_BASE_INDEX, ConstantType::Int32), + (TEST_ATTR3_BASE_INDEX, ConstantType::Int32), ])], vec![None], - // attr_infos, ); // Group by empty list should return 1. let group_bys = empty_list(); assert_eq!( - cost_model.get_agg_row_cnt(group_bys).await.unwrap(), + cost_model + .get_agg_row_cnt(TEST_GROUP1_ID, group_bys) + .await + .unwrap(), EstimatedStatistic(1.0) ); // Group by single column should return the n-distinct of the column. 
- let group_bys = list(vec![attr_index(attr1_base_idx)]); // TODO: Fix this + let group_bys = list(vec![attr_index(0)]); assert_eq!( - cost_model.get_agg_row_cnt(group_bys).await.unwrap(), + cost_model + .get_agg_row_cnt(TEST_GROUP1_ID, group_bys) + .await + .unwrap(), EstimatedStatistic(attr1_ndistinct as f64) ); // Group by two columns should return the product of the n-distinct of the columns. - let group_bys = list(vec![attr_index(attr1_base_idx), attr_index(attr2_base_idx)]); // TODO: Fix this + let group_bys = list(vec![attr_index(0), attr_index(1)]); assert_eq!( - cost_model.get_agg_row_cnt(group_bys).await.unwrap(), + cost_model + .get_agg_row_cnt(TEST_GROUP1_ID, group_bys) + .await + .unwrap(), EstimatedStatistic((attr1_ndistinct * attr2_ndistinct) as f64) ); // Group by multiple columns should return the product of the n-distinct of the columns. If one of the columns // does not have stats, it should use the default value instead. - let group_bys = list(vec![ - // TODO: Fix this - attr_index(attr1_base_idx), - attr_index(attr2_base_idx), - attr_index(attr3_base_idx), - ]); + let group_bys = list(vec![attr_index(0), attr_index(1), attr_index(2)]); assert_eq!( - cost_model.get_agg_row_cnt(group_bys).await.unwrap(), + cost_model + .get_agg_row_cnt(TEST_GROUP1_ID, group_bys) + .await + .unwrap(), EstimatedStatistic((attr1_ndistinct * attr2_ndistinct * DEFAULT_NUM_DISTINCT) as f64) ); } diff --git a/optd-cost-model/src/cost/filter/attribute.rs b/optd-cost-model/src/cost/filter/attribute.rs index e39d7b5..7a082b7 100644 --- a/optd-cost-model/src/cost/filter/attribute.rs +++ b/optd-cost-model/src/cost/filter/attribute.rs @@ -16,6 +16,10 @@ impl CostModelImpl { /// Also, get_attribute_equality_selectivity is a subroutine when computing range /// selectivity, which is another reason for separating these into two functions /// is_eq means whether it's == or != + /// + /// Currently, we only support calculating the equality selectivity for an existed attribute, + 
/// not a derived attribute. + /// TODO: Support derived attributes. pub(crate) async fn get_attribute_equality_selectivity( &self, table_id: TableId, @@ -23,7 +27,6 @@ impl CostModelImpl { value: &Value, is_eq: bool, ) -> CostModelResult { - // TODO: The attribute could be a derived attribute let ret_sel = { if let Some(attribute_stats) = self .get_attribute_comb_stats(table_id, &[attr_base_index]) @@ -89,6 +92,10 @@ impl CostModelImpl { } /// Compute the frequency of values in a attribute less than the given value. + /// + /// Currently, we only support calculating the less-than frequency for an existing attribute, + /// not a derived attribute. + /// TODO: Support derived attributes. async fn get_attribute_lt_value_freq( &self, attribute_stats: &AttributeCombValueStats, @@ -116,6 +123,10 @@ impl CostModelImpl { /// Range predicates are handled entirely differently from equality predicates so this is its /// own function. If it is unable to find the statistics, it returns DEFAULT_INEQ_SEL. /// The selectivity is computed as quantile of the right bound minus quantile of the left bound. + /// + /// Currently, we only support calculating the range selectivity for an existing attribute, + /// not a derived attribute. + /// TODO: Support derived attributes. 
pub(crate) async fn get_attribute_range_selectivity( &self, table_id: TableId, diff --git a/optd-cost-model/src/cost/filter/comp_op.rs b/optd-cost-model/src/cost/filter/comp_op.rs index 6d062d3..5270819 100644 --- a/optd-cost-model/src/cost/filter/comp_op.rs +++ b/optd-cost-model/src/cost/filter/comp_op.rs @@ -7,6 +7,8 @@ use crate::{ attr_index_pred::AttrIndexPred, bin_op_pred::BinOpType, cast_pred::CastPred, constant_pred::ConstantPred, }, + properties::attr_ref::{AttrRef, BaseTableAttrRef}, + types::GroupId, values::Value, }, cost_model::CostModelImpl, @@ -19,6 +21,7 @@ impl CostModelImpl { /// Comparison operators are the base case for recursion in get_filter_selectivity() pub(crate) async fn get_comp_op_selectivity( &self, + group_id: GroupId, comp_bin_op_typ: BinOpType, left: ArcPredicateNode, right: ArcPredicateNode, @@ -27,7 +30,7 @@ impl CostModelImpl { // I intentionally performed moves on left and right. This way, we don't accidentally use // them after this block - let semantic_res = self.get_semantic_nodes(left, right).await; + let semantic_res = self.get_semantic_nodes(group_id, left, right).await; if semantic_res.is_err() { return Ok(Self::get_default_comparison_op_selectivity(comp_bin_op_typ)); } @@ -41,67 +44,74 @@ impl CostModelImpl { .first() .expect("we just checked that attr_ref_exprs.len() == 1"); let attr_ref_idx = attr_ref_expr.attr_index(); - let table_id = todo!(); - // TODO: Consider attribute is a derived attribute - if values.len() == 1 { - let value = values - .first() - .expect("we just checked that values.len() == 1"); - match comp_bin_op_typ { - BinOpType::Eq => { - self.get_attribute_equality_selectivity(table_id, attr_ref_idx, value, true) - .await - } - BinOpType::Neq => { - self.get_attribute_equality_selectivity( - table_id, - attr_ref_idx, - value, - false, - ) - .await - } - BinOpType::Lt | BinOpType::Leq | BinOpType::Gt | BinOpType::Geq => { - let start = match (comp_bin_op_typ, is_left_attr_ref) { - (BinOpType::Lt, true) | 
(BinOpType::Geq, false) => Bound::Unbounded, - (BinOpType::Leq, true) | (BinOpType::Gt, false) => Bound::Unbounded, - (BinOpType::Gt, true) | (BinOpType::Leq, false) => Bound::Excluded(value), - (BinOpType::Geq, true) | (BinOpType::Lt, false) => Bound::Included(value), - _ => unreachable!("all comparison BinOpTypes were enumerated. this should be unreachable"), - }; - let end = match (comp_bin_op_typ, is_left_attr_ref) { - (BinOpType::Lt, true) | (BinOpType::Geq, false) => Bound::Excluded(value), - (BinOpType::Leq, true) | (BinOpType::Gt, false) => Bound::Included(value), - (BinOpType::Gt, true) | (BinOpType::Leq, false) => Bound::Unbounded, - (BinOpType::Geq, true) | (BinOpType::Lt, false) => Bound::Unbounded, - _ => unreachable!("all comparison BinOpTypes were enumerated. this should be unreachable"), - }; - self.get_attribute_range_selectivity(table_id, attr_ref_idx, start, end) + if let AttrRef::BaseTableAttrRef(BaseTableAttrRef { table_id, attr_idx }) = + self.memo.get_attribute_ref(group_id, attr_ref_idx) + { + if values.len() == 1 { + let value = values + .first() + .expect("we just checked that values.len() == 1"); + match comp_bin_op_typ { + BinOpType::Eq => { + self.get_attribute_equality_selectivity(table_id, attr_idx, value, true) + .await + } + BinOpType::Neq => { + self.get_attribute_equality_selectivity( + table_id, + attr_ref_idx, + value, + false, + ) .await + } + BinOpType::Lt | BinOpType::Leq | BinOpType::Gt | BinOpType::Geq => { + let start = match (comp_bin_op_typ, is_left_attr_ref) { + (BinOpType::Lt, true) | (BinOpType::Geq, false) => Bound::Unbounded, + (BinOpType::Leq, true) | (BinOpType::Gt, false) => Bound::Unbounded, + (BinOpType::Gt, true) | (BinOpType::Leq, false) => Bound::Excluded(value), + (BinOpType::Geq, true) | (BinOpType::Lt, false) => Bound::Included(value), + _ => unreachable!("all comparison BinOpTypes were enumerated. 
this should be unreachable"), + }; + let end = match (comp_bin_op_typ, is_left_attr_ref) { + (BinOpType::Lt, true) | (BinOpType::Geq, false) => Bound::Excluded(value), + (BinOpType::Leq, true) | (BinOpType::Gt, false) => Bound::Included(value), + (BinOpType::Gt, true) | (BinOpType::Leq, false) => Bound::Unbounded, + (BinOpType::Geq, true) | (BinOpType::Lt, false) => Bound::Unbounded, + _ => unreachable!("all comparison BinOpTypes were enumerated. this should be unreachable"), + }; + self.get_attribute_range_selectivity(table_id, attr_ref_idx, start, end) + .await + } + _ => unreachable!( + "all comparison BinOpTypes were enumerated. this should be unreachable" + ), } - _ => unreachable!( - "all comparison BinOpTypes were enumerated. this should be unreachable" - ), - } - } else { - let non_attr_ref_expr = non_attr_ref_exprs.first().expect( - "non_attr_ref_exprs should have a value since attr_ref_exprs.len() == 1", - ); + } else { + let non_attr_ref_expr = non_attr_ref_exprs.first().expect( + "non_attr_ref_exprs should have a value since attr_ref_exprs.len() == 1", + ); - match non_attr_ref_expr.as_ref().typ { - PredicateType::BinOp(_) => { - Ok(Self::get_default_comparison_op_selectivity(comp_bin_op_typ)) - } - PredicateType::Cast => Ok(UNIMPLEMENTED_SEL), - PredicateType::Constant(_) => { - unreachable!("we should have handled this in the values.len() == 1 branch") + match non_attr_ref_expr.as_ref().typ { + PredicateType::BinOp(_) => { + Ok(Self::get_default_comparison_op_selectivity(comp_bin_op_typ)) + } + PredicateType::Cast => Ok(UNIMPLEMENTED_SEL), + PredicateType::Constant(_) => { + unreachable!( + "we should have handled this in the values.len() == 1 branch" + ) + } + _ => unimplemented!( + "unhandled case of comparing a attribute ref node to {}", + non_attr_ref_expr.as_ref().typ + ), } - _ => unimplemented!( - "unhandled case of comparing a attribute ref node to {}", - non_attr_ref_expr.as_ref().typ - ), } + } else { + // TODO: attribute is derived + 
Ok(Self::get_default_comparison_op_selectivity(comp_bin_op_typ)) } } else if attr_ref_exprs.len() == 2 { Ok(Self::get_default_comparison_op_selectivity(comp_bin_op_typ)) @@ -116,6 +126,7 @@ impl CostModelImpl { #[allow(clippy::type_complexity)] async fn get_semantic_nodes( &self, + group_id: GroupId, left: ArcPredicateNode, right: ArcPredicateNode, ) -> CostModelResult<(Vec, Vec, Vec, bool)> { @@ -170,18 +181,10 @@ impl CostModelImpl { let attr_ref_expr = AttrIndexPred::from_pred_node(cast_expr_child) .expect("we already checked that the type is AttributeRef"); let attr_ref_idx = attr_ref_expr.attr_index(); - let table_id = todo!(); cast_node = attr_ref_expr.into_pred_node(); // The "invert" cast is to invert the cast so that we're casting the // non_cast_node to the attribute's original type. - // TODO: Consider attribute info is None. - // **TODO**: What if this attribute is a derived attribute? - let attribute_info = self - .storage_manager - .get_attribute_info(table_id, attr_ref_idx) - .await? 
- .ok_or({ SemanticError::AttributeNotFound(table_id, attr_ref_idx) })?; - + let attribute_info = self.memo.get_attribute_info(group_id, attr_ref_idx); let invert_cast_data_type = &attribute_info.typ.into_data_type(); match non_cast_node.typ { diff --git a/optd-cost-model/src/cost/filter/controller.rs b/optd-cost-model/src/cost/filter/controller.rs index 54bdc7d..05363e4 100644 --- a/optd-cost-model/src/cost/filter/controller.rs +++ b/optd-cost-model/src/cost/filter/controller.rs @@ -2,6 +2,7 @@ use crate::{ common::{ nodes::{ArcPredicateNode, PredicateType, ReprPredicateNode}, predicates::{in_list_pred::InListPred, like_pred::LikePred, un_op_pred::UnOpType}, + types::GroupId, }, cost_model::CostModelImpl, stats::UNIMPLEMENTED_SEL, @@ -15,14 +16,16 @@ impl CostModelImpl { pub async fn get_filter_row_cnt( &self, child_row_cnt: EstimatedStatistic, + group_id: GroupId, cond: ArcPredicateNode, ) -> CostModelResult { - let selectivity = { self.get_filter_selectivity(cond).await? }; + let selectivity = { self.get_filter_selectivity(group_id, cond).await? }; Ok(EstimatedStatistic((child_row_cnt.0 * selectivity).max(1.0))) } pub async fn get_filter_selectivity( &self, + group_id: GroupId, expr_tree: ArcPredicateNode, ) -> CostModelResult { Box::pin(async move { @@ -36,7 +39,7 @@ impl CostModelImpl { // not doesn't care about nulls so there's no complex logic. 
it just reverses // the selectivity for instance, != _will not_ include nulls // but "NOT ==" _will_ include nulls - UnOpType::Not => Ok(1.0 - self.get_filter_selectivity(child).await?), + UnOpType::Not => Ok(1.0 - self.get_filter_selectivity(group_id, child).await?), UnOpType::Neg => panic!( "the selectivity of operations that return numerical values is undefined" ), @@ -48,7 +51,7 @@ impl CostModelImpl { let right_child = expr_tree.child(1); if bin_op_typ.is_comparison() { - self.get_comp_op_selectivity(*bin_op_typ, left_child, right_child).await + self.get_comp_op_selectivity(group_id, *bin_op_typ, left_child, right_child).await } else if bin_op_typ.is_numerical() { panic!( "the selectivity of operations that return numerical values is undefined" @@ -58,7 +61,7 @@ impl CostModelImpl { } } PredicateType::LogOp(log_op_typ) => { - self.get_log_op_selectivity(*log_op_typ, &expr_tree.children).await + self.get_log_op_selectivity(group_id, *log_op_typ, &expr_tree.children).await } PredicateType::Func(_) => unimplemented!("check bool type or else panic"), PredicateType::SortOrder(_) => { @@ -68,14 +71,14 @@ impl CostModelImpl { PredicateType::Cast => unimplemented!("check bool type or else panic"), PredicateType::Like => { let like_expr = LikePred::from_pred_node(expr_tree).unwrap(); - self.get_like_selectivity(&like_expr).await + self.get_like_selectivity(group_id, &like_expr).await } PredicateType::DataType(_) => { panic!("the selectivity of a data type is not defined") } PredicateType::InList => { let in_list_expr = InListPred::from_pred_node(expr_tree).unwrap(); - self.get_in_list_selectivity(&in_list_expr).await + self.get_in_list_selectivity(group_id, &in_list_expr).await } _ => unreachable!( "all expression DfPredType were enumerated. 
this should be unreachable" @@ -117,14 +120,14 @@ mod tests { ); assert_approx_eq::assert_approx_eq!( cost_model - .get_filter_selectivity(cnst(Value::Bool(true))) + .get_filter_selectivity(TEST_GROUP1_ID, cnst(Value::Bool(true))) .await .unwrap(), 1.0 ); assert_approx_eq::assert_approx_eq!( cost_model - .get_filter_selectivity(cnst(Value::Bool(false))) + .get_filter_selectivity(TEST_GROUP1_ID, cnst(Value::Bool(false))) .await .unwrap(), 0.0 @@ -160,12 +163,15 @@ mod tests { attr_index(0), // TODO: Fix this ); assert_approx_eq::assert_approx_eq!( - cost_model.get_filter_selectivity(expr_tree).await.unwrap(), + cost_model + .get_filter_selectivity(TEST_GROUP1_ID, expr_tree) + .await + .unwrap(), 0.3 ); assert_approx_eq::assert_approx_eq!( cost_model - .get_filter_selectivity(expr_tree_rev) + .get_filter_selectivity(TEST_GROUP1_ID, expr_tree_rev) .await .unwrap(), 0.3 @@ -183,30 +189,27 @@ mod tests { 5, 0.0, ); - let table_id = TableId(0); let cost_model = create_mock_cost_model( - vec![table_id], - vec![HashMap::from([(0, per_attribute_stats)])], + vec![TEST_TABLE1_ID], + vec![HashMap::from([( + TEST_ATTR1_BASE_INDEX, + per_attribute_stats, + )])], vec![None], ); - let expr_tree = bin_op( - BinOpType::Eq, - attr_index(0), // TODO: Fix this - cnst(Value::Int32(2)), - ); - let expr_tree_rev = bin_op( - BinOpType::Eq, - cnst(Value::Int32(2)), - attr_index(0), // TODO: Fix this - ); + let expr_tree = bin_op(BinOpType::Eq, attr_index(0), cnst(Value::Int32(2))); + let expr_tree_rev = bin_op(BinOpType::Eq, cnst(Value::Int32(2)), attr_index(0)); assert_approx_eq::assert_approx_eq!( - cost_model.get_filter_selectivity(expr_tree).await.unwrap(), + cost_model + .get_filter_selectivity(TEST_GROUP1_ID, expr_tree) + .await + .unwrap(), 0.12 ); assert_approx_eq::assert_approx_eq!( cost_model - .get_filter_selectivity(expr_tree_rev) + .get_filter_selectivity(TEST_GROUP1_ID, expr_tree_rev) .await .unwrap(), 0.12 @@ -225,30 +228,27 @@ mod tests { 0, 0.0, ); - let table_id = 
TableId(0); let cost_model = create_mock_cost_model( - vec![table_id], - vec![HashMap::from([(0, per_attribute_stats)])], + vec![TEST_TABLE1_ID], + vec![HashMap::from([( + TEST_ATTR1_BASE_INDEX, + per_attribute_stats, + )])], vec![None], ); - let expr_tree = bin_op( - BinOpType::Neq, - attr_index(0), // TODO: Fix this - cnst(Value::Int32(1)), - ); - let expr_tree_rev = bin_op( - BinOpType::Neq, - cnst(Value::Int32(1)), - attr_index(0), // TODO: Fix this - ); + let expr_tree = bin_op(BinOpType::Neq, attr_index(0), cnst(Value::Int32(1))); + let expr_tree_rev = bin_op(BinOpType::Neq, cnst(Value::Int32(1)), attr_index(0)); assert_approx_eq::assert_approx_eq!( - cost_model.get_filter_selectivity(expr_tree).await.unwrap(), + cost_model + .get_filter_selectivity(TEST_GROUP1_ID, expr_tree) + .await + .unwrap(), 1.0 - 0.3 ); assert_approx_eq::assert_approx_eq!( cost_model - .get_filter_selectivity(expr_tree_rev) + .get_filter_selectivity(TEST_GROUP1_ID, expr_tree_rev) .await .unwrap(), 1.0 - 0.3 @@ -266,30 +266,27 @@ mod tests { 10, 0.0, ); - let table_id = TableId(0); let cost_model = create_mock_cost_model( - vec![table_id], - vec![HashMap::from([(0, per_attribute_stats)])], + vec![TEST_TABLE1_ID], + vec![HashMap::from([( + TEST_ATTR1_BASE_INDEX, + per_attribute_stats, + )])], vec![None], ); - let expr_tree = bin_op( - BinOpType::Leq, - attr_index(0), // TODO: Fix this - cnst(Value::Int32(15)), - ); - let expr_tree_rev = bin_op( - BinOpType::Gt, - cnst(Value::Int32(15)), - attr_index(0), // TODO: Fix this - ); + let expr_tree = bin_op(BinOpType::Leq, attr_index(0), cnst(Value::Int32(15))); + let expr_tree_rev = bin_op(BinOpType::Gt, cnst(Value::Int32(15)), attr_index(0)); assert_approx_eq::assert_approx_eq!( - cost_model.get_filter_selectivity(expr_tree).await.unwrap(), + cost_model + .get_filter_selectivity(TEST_GROUP1_ID, expr_tree) + .await + .unwrap(), 0.7 ); assert_approx_eq::assert_approx_eq!( cost_model - .get_filter_selectivity(expr_tree_rev) + 
.get_filter_selectivity(TEST_GROUP1_ID, expr_tree_rev) .await .unwrap(), 0.7 @@ -312,30 +309,27 @@ mod tests { 10, 0.0, ); - let table_id = TableId(0); let cost_model = create_mock_cost_model( - vec![table_id], - vec![HashMap::from([(0, per_attribute_stats)])], + vec![TEST_TABLE1_ID], + vec![HashMap::from([( + TEST_ATTR1_BASE_INDEX, + per_attribute_stats, + )])], vec![None], ); - let expr_tree = bin_op( - BinOpType::Leq, - attr_index(0), // TODO: Fix this - cnst(Value::Int32(15)), - ); - let expr_tree_rev = bin_op( - BinOpType::Gt, - cnst(Value::Int32(15)), - attr_index(0), // TODO: Fix this - ); + let expr_tree = bin_op(BinOpType::Leq, attr_index(0), cnst(Value::Int32(15))); + let expr_tree_rev = bin_op(BinOpType::Gt, cnst(Value::Int32(15)), attr_index(0)); assert_approx_eq::assert_approx_eq!( - cost_model.get_filter_selectivity(expr_tree).await.unwrap(), + cost_model + .get_filter_selectivity(TEST_GROUP1_ID, expr_tree) + .await + .unwrap(), 0.85 ); assert_approx_eq::assert_approx_eq!( cost_model - .get_filter_selectivity(expr_tree_rev) + .get_filter_selectivity(TEST_GROUP1_ID, expr_tree_rev) .await .unwrap(), 0.85 @@ -358,30 +352,27 @@ mod tests { 10, 0.0, ); - let table_id = TableId(0); let cost_model = create_mock_cost_model( - vec![table_id], - vec![HashMap::from([(0, per_attribute_stats)])], + vec![TEST_TABLE1_ID], + vec![HashMap::from([( + TEST_ATTR1_BASE_INDEX, + per_attribute_stats, + )])], vec![None], ); - let expr_tree = bin_op( - BinOpType::Leq, - attr_index(0), // TODO: Fix this - cnst(Value::Int32(15)), - ); - let expr_tree_rev = bin_op( - BinOpType::Gt, - cnst(Value::Int32(15)), - attr_index(0), // TODO: Fix this - ); + let expr_tree = bin_op(BinOpType::Leq, attr_index(0), cnst(Value::Int32(15))); + let expr_tree_rev = bin_op(BinOpType::Gt, cnst(Value::Int32(15)), attr_index(0)); assert_approx_eq::assert_approx_eq!( - cost_model.get_filter_selectivity(expr_tree).await.unwrap(), + cost_model + .get_filter_selectivity(TEST_GROUP1_ID, expr_tree) + 
.await + .unwrap(), 0.93 ); assert_approx_eq::assert_approx_eq!( cost_model - .get_filter_selectivity(expr_tree_rev) + .get_filter_selectivity(TEST_GROUP1_ID, expr_tree_rev) .await .unwrap(), 0.93 @@ -399,30 +390,27 @@ mod tests { 10, 0.0, ); - let table_id = TableId(0); let cost_model = create_mock_cost_model( - vec![table_id], - vec![HashMap::from([(0, per_attribute_stats)])], + vec![TEST_TABLE1_ID], + vec![HashMap::from([( + TEST_ATTR1_BASE_INDEX, + per_attribute_stats, + )])], vec![None], ); - let expr_tree = bin_op( - BinOpType::Lt, - attr_index(0), // TODO: Fix this - cnst(Value::Int32(15)), - ); - let expr_tree_rev = bin_op( - BinOpType::Geq, - cnst(Value::Int32(15)), - attr_index(0), // TODO: Fix this - ); + let expr_tree = bin_op(BinOpType::Lt, attr_index(0), cnst(Value::Int32(15))); + let expr_tree_rev = bin_op(BinOpType::Geq, cnst(Value::Int32(15)), attr_index(0)); assert_approx_eq::assert_approx_eq!( - cost_model.get_filter_selectivity(expr_tree).await.unwrap(), + cost_model + .get_filter_selectivity(TEST_GROUP1_ID, expr_tree) + .await + .unwrap(), 0.6 ); assert_approx_eq::assert_approx_eq!( cost_model - .get_filter_selectivity(expr_tree_rev) + .get_filter_selectivity(TEST_GROUP1_ID, expr_tree_rev) .await .unwrap(), 0.6 @@ -446,30 +434,27 @@ mod tests { * remaining value has freq 0.1 */ 0.0, ); - let table_id = TableId(0); let cost_model = create_mock_cost_model( - vec![table_id], - vec![HashMap::from([(0, per_attribute_stats)])], + vec![TEST_TABLE1_ID], + vec![HashMap::from([( + TEST_ATTR1_BASE_INDEX, + per_attribute_stats, + )])], vec![None], ); - let expr_tree = bin_op( - BinOpType::Lt, - attr_index(0), // TODO: Fix this - cnst(Value::Int32(15)), - ); - let expr_tree_rev = bin_op( - BinOpType::Geq, - cnst(Value::Int32(15)), - attr_index(0), // TODO: Fix this - ); + let expr_tree = bin_op(BinOpType::Lt, attr_index(0), cnst(Value::Int32(15))); + let expr_tree_rev = bin_op(BinOpType::Geq, cnst(Value::Int32(15)), attr_index(0)); 
assert_approx_eq::assert_approx_eq!( - cost_model.get_filter_selectivity(expr_tree).await.unwrap(), + cost_model + .get_filter_selectivity(TEST_GROUP1_ID, expr_tree) + .await + .unwrap(), 0.75 ); assert_approx_eq::assert_approx_eq!( cost_model - .get_filter_selectivity(expr_tree_rev) + .get_filter_selectivity(TEST_GROUP1_ID, expr_tree_rev) .await .unwrap(), 0.75 @@ -493,30 +478,27 @@ mod tests { * remaining value has freq 0.1 */ 0.0, ); - let table_id = TableId(0); let cost_model = create_mock_cost_model( - vec![table_id], - vec![HashMap::from([(0, per_attribute_stats)])], + vec![TEST_TABLE1_ID], + vec![HashMap::from([( + TEST_ATTR1_BASE_INDEX, + per_attribute_stats, + )])], vec![None], ); - let expr_tree = bin_op( - BinOpType::Lt, - attr_index(0), // TODO: Fix this - cnst(Value::Int32(15)), - ); - let expr_tree_rev = bin_op( - BinOpType::Geq, - cnst(Value::Int32(15)), - attr_index(0), // TODO: Fix this - ); + let expr_tree = bin_op(BinOpType::Lt, attr_index(0), cnst(Value::Int32(15))); + let expr_tree_rev = bin_op(BinOpType::Geq, cnst(Value::Int32(15)), attr_index(0)); assert_approx_eq::assert_approx_eq!( - cost_model.get_filter_selectivity(expr_tree).await.unwrap(), + cost_model + .get_filter_selectivity(TEST_GROUP1_ID, expr_tree) + .await + .unwrap(), 0.85 ); assert_approx_eq::assert_approx_eq!( cost_model - .get_filter_selectivity(expr_tree_rev) + .get_filter_selectivity(TEST_GROUP1_ID, expr_tree_rev) .await .unwrap(), 0.85 @@ -536,30 +518,27 @@ mod tests { 10, 0.0, ); - let table_id = TableId(0); let cost_model = create_mock_cost_model( - vec![table_id], - vec![HashMap::from([(0, per_attribute_stats)])], + vec![TEST_TABLE1_ID], + vec![HashMap::from([( + TEST_ATTR1_BASE_INDEX, + per_attribute_stats, + )])], vec![None], ); - let expr_tree = bin_op( - BinOpType::Gt, - attr_index(0), // TODO: Fix this - cnst(Value::Int32(15)), - ); - let expr_tree_rev = bin_op( - BinOpType::Leq, - cnst(Value::Int32(15)), - attr_index(0), // TODO: Fix this - ); + let expr_tree = 
bin_op(BinOpType::Gt, attr_index(0), cnst(Value::Int32(15))); + let expr_tree_rev = bin_op(BinOpType::Leq, cnst(Value::Int32(15)), attr_index(0)); assert_approx_eq::assert_approx_eq!( - cost_model.get_filter_selectivity(expr_tree).await.unwrap(), + cost_model + .get_filter_selectivity(TEST_GROUP1_ID, expr_tree) + .await + .unwrap(), 1.0 - 0.7 ); assert_approx_eq::assert_approx_eq!( cost_model - .get_filter_selectivity(expr_tree_rev) + .get_filter_selectivity(TEST_GROUP1_ID, expr_tree_rev) .await .unwrap(), 1.0 - 0.7 @@ -577,31 +556,28 @@ mod tests { 10, 0.0, ); - let table_id = TableId(0); let cost_model = create_mock_cost_model( - vec![table_id], - vec![HashMap::from([(0, per_attribute_stats)])], + vec![TEST_TABLE1_ID], + vec![HashMap::from([( + TEST_ATTR1_BASE_INDEX, + per_attribute_stats, + )])], vec![None], ); - let expr_tree = bin_op( - BinOpType::Geq, - attr_index(0), // TODO: Fix this - cnst(Value::Int32(15)), - ); - let expr_tree_rev = bin_op( - BinOpType::Lt, - cnst(Value::Int32(15)), - attr_index(0), // TODO: Fix this - ); + let expr_tree = bin_op(BinOpType::Geq, attr_index(0), cnst(Value::Int32(15))); + let expr_tree_rev = bin_op(BinOpType::Lt, cnst(Value::Int32(15)), attr_index(0)); assert_approx_eq::assert_approx_eq!( - cost_model.get_filter_selectivity(expr_tree).await.unwrap(), + cost_model + .get_filter_selectivity(TEST_GROUP1_ID, expr_tree) + .await + .unwrap(), 1.0 - 0.6 ); assert_approx_eq::assert_approx_eq!( cost_model - .get_filter_selectivity(expr_tree_rev) + .get_filter_selectivity(TEST_GROUP1_ID, expr_tree_rev) .await .unwrap(), 1.0 - 0.6 @@ -620,46 +596,39 @@ mod tests { 0, 0.0, ); - let table_id = TableId(0); let cost_model = create_mock_cost_model( - vec![table_id], - vec![HashMap::from([(0, per_attribute_stats)])], + vec![TEST_TABLE1_ID], + vec![HashMap::from([( + TEST_ATTR1_BASE_INDEX, + per_attribute_stats, + )])], vec![None], ); - let eq1 = bin_op( - BinOpType::Eq, - attr_index(0), // TODO: Fix this - cnst(Value::Int32(1)), - ); - let 
eq5 = bin_op( - BinOpType::Eq, - attr_index(0), // TODO: Fix this - cnst(Value::Int32(5)), - ); - let eq8 = bin_op( - BinOpType::Eq, - attr_index(0), // TODO: Fix this - cnst(Value::Int32(8)), - ); + let eq1 = bin_op(BinOpType::Eq, attr_index(0), cnst(Value::Int32(1))); + let eq5 = bin_op(BinOpType::Eq, attr_index(0), cnst(Value::Int32(5))); + let eq8 = bin_op(BinOpType::Eq, attr_index(0), cnst(Value::Int32(8))); let expr_tree = log_op(LogOpType::And, vec![eq1.clone(), eq5.clone(), eq8.clone()]); let expr_tree_shift1 = log_op(LogOpType::And, vec![eq5.clone(), eq8.clone(), eq1.clone()]); let expr_tree_shift2 = log_op(LogOpType::And, vec![eq8.clone(), eq1.clone(), eq5.clone()]); assert_approx_eq::assert_approx_eq!( - cost_model.get_filter_selectivity(expr_tree).await.unwrap(), + cost_model + .get_filter_selectivity(TEST_GROUP1_ID, expr_tree) + .await + .unwrap(), 0.03 ); assert_approx_eq::assert_approx_eq!( cost_model - .get_filter_selectivity(expr_tree_shift1) + .get_filter_selectivity(TEST_GROUP1_ID, expr_tree_shift1) .await .unwrap(), 0.03 ); assert_approx_eq::assert_approx_eq!( cost_model - .get_filter_selectivity(expr_tree_shift2) + .get_filter_selectivity(TEST_GROUP1_ID, expr_tree_shift2) .await .unwrap(), 0.03 @@ -678,46 +647,39 @@ mod tests { 0, 0.0, ); - let table_id = TableId(0); let cost_model = create_mock_cost_model( - vec![table_id], - vec![HashMap::from([(0, per_attribute_stats)])], + vec![TEST_TABLE1_ID], + vec![HashMap::from([( + TEST_ATTR1_BASE_INDEX, + per_attribute_stats, + )])], vec![None], ); - let eq1 = bin_op( - BinOpType::Eq, - attr_index(0), // TODO: Fix this - cnst(Value::Int32(1)), - ); - let eq5 = bin_op( - BinOpType::Eq, - attr_index(0), // TODO: Fix this - cnst(Value::Int32(5)), - ); - let eq8 = bin_op( - BinOpType::Eq, - attr_index(0), // TODO: Fix this - cnst(Value::Int32(8)), - ); + let eq1 = bin_op(BinOpType::Eq, attr_index(0), cnst(Value::Int32(1))); + let eq5 = bin_op(BinOpType::Eq, attr_index(0), cnst(Value::Int32(5))); + let eq8 
= bin_op(BinOpType::Eq, attr_index(0), cnst(Value::Int32(8))); let expr_tree = log_op(LogOpType::Or, vec![eq1.clone(), eq5.clone(), eq8.clone()]); let expr_tree_shift1 = log_op(LogOpType::Or, vec![eq5.clone(), eq8.clone(), eq1.clone()]); let expr_tree_shift2 = log_op(LogOpType::Or, vec![eq8.clone(), eq1.clone(), eq5.clone()]); assert_approx_eq::assert_approx_eq!( - cost_model.get_filter_selectivity(expr_tree).await.unwrap(), + cost_model + .get_filter_selectivity(TEST_GROUP1_ID, expr_tree) + .await + .unwrap(), 0.72 ); assert_approx_eq::assert_approx_eq!( cost_model - .get_filter_selectivity(expr_tree_shift1) + .get_filter_selectivity(TEST_GROUP1_ID, expr_tree_shift1) .await .unwrap(), 0.72 ); assert_approx_eq::assert_approx_eq!( cost_model - .get_filter_selectivity(expr_tree_shift2) + .get_filter_selectivity(TEST_GROUP1_ID, expr_tree_shift2) .await .unwrap(), 0.72 @@ -735,24 +697,25 @@ mod tests { 0, 0.0, ); - let table_id = TableId(0); let cost_model = create_mock_cost_model( - vec![table_id], - vec![HashMap::from([(0, per_attribute_stats)])], + vec![TEST_TABLE1_ID], + vec![HashMap::from([( + TEST_ATTR1_BASE_INDEX, + per_attribute_stats, + )])], vec![None], ); let expr_tree = un_op( UnOpType::Not, - bin_op( - BinOpType::Eq, - attr_index(0), // TODO: Fix this - cnst(Value::Int32(1)), - ), + bin_op(BinOpType::Eq, attr_index(0), cnst(Value::Int32(1))), ); assert_approx_eq::assert_approx_eq!( - cost_model.get_filter_selectivity(expr_tree).await.unwrap(), + cost_model + .get_filter_selectivity(TEST_GROUP1_ID, expr_tree) + .await + .unwrap(), 0.7 ); } @@ -771,31 +734,36 @@ mod tests { 0, 0.0, ); - let table_id = TableId(0); let cost_model = create_mock_cost_model( - vec![table_id], - vec![HashMap::from([(0, per_attribute_stats)])], + vec![TEST_TABLE1_ID], + vec![HashMap::from([( + TEST_ATTR1_BASE_INDEX, + per_attribute_stats, + )])], vec![None], ); let expr_tree = bin_op( BinOpType::Eq, - attr_index(0), // TODO: Fix this + attr_index(0), cast(cnst(Value::Int64(1)), 
DataType::Int32), ); let expr_tree_rev = bin_op( BinOpType::Eq, cast(cnst(Value::Int64(1)), DataType::Int32), - attr_index(0), // TODO: Fix this + attr_index(0), ); assert_approx_eq::assert_approx_eq!( - cost_model.get_filter_selectivity(expr_tree).await.unwrap(), + cost_model + .get_filter_selectivity(TEST_GROUP1_ID, expr_tree) + .await + .unwrap(), 0.3 ); assert_approx_eq::assert_approx_eq!( cost_model - .get_filter_selectivity(expr_tree_rev) + .get_filter_selectivity(TEST_GROUP1_ID, expr_tree_rev) .await .unwrap(), 0.3 @@ -813,18 +781,16 @@ mod tests { 0, 0.1, ); - let table_id = TableId(0); - // let attr_infos = HashMap::from([( - // table_id, - // vec![Attribute { - // name: String::from("attr1"), - // typ: ConstantType::Int32, - // nullable: false, - // }], - // )]); - let cost_model = create_mock_cost_model( - vec![table_id], - vec![HashMap::from([(0, per_attribute_stats)])], + let cost_model = create_mock_cost_model_with_attr_types( + vec![TEST_TABLE1_ID], + vec![HashMap::from([( + TEST_ATTR1_BASE_INDEX, + per_attribute_stats, + )])], + vec![HashMap::from([( + TEST_ATTR1_BASE_INDEX, + ConstantType::Int32, + )])], vec![None], ); @@ -840,12 +806,15 @@ mod tests { ); assert_approx_eq::assert_approx_eq!( - cost_model.get_filter_selectivity(expr_tree).await.unwrap(), + cost_model + .get_filter_selectivity(TEST_GROUP1_ID, expr_tree) + .await + .unwrap(), 0.3 ); assert_approx_eq::assert_approx_eq!( cost_model - .get_filter_selectivity(expr_tree_rev) + .get_filter_selectivity(TEST_GROUP1_ID, expr_tree_rev) .await .unwrap(), 0.3 @@ -866,45 +835,40 @@ mod tests { 0.0, ); let table_id = TableId(0); - // let attr_infos = HashMap::from([( - // table_id, - // vec![ - // Attribute { - // name: String::from("attr1"), - // typ: ConstantType::Int32, - // nullable: false, - // }, - // Attribute { - // name: String::from("attr2"), - // typ: ConstantType::Int64, - // nullable: false, - // }, - // ], - // )]); - let cost_model = create_mock_cost_model( - vec![table_id], - 
vec![HashMap::from([(0, per_attribute_stats)])], + let cost_model = create_mock_cost_model_with_attr_types( + vec![TEST_TABLE1_ID], + vec![HashMap::from([( + TEST_ATTR1_BASE_INDEX, + per_attribute_stats, + )])], + vec![HashMap::from([ + (TEST_ATTR1_BASE_INDEX, ConstantType::Int32), + (TEST_ATTR2_BASE_INDEX, ConstantType::Int64), + ])], vec![None], ); let expr_tree = bin_op( BinOpType::Eq, - cast(attr_index(0), DataType::Int64), // TODO: Fix this - attr_index(1), // TODO: Fix this + cast(attr_index(0), DataType::Int64), + attr_index(1), ); let expr_tree_rev = bin_op( BinOpType::Eq, - attr_index(1), // TODO: Fix this - cast(attr_index(0), DataType::Int64), // TODO: Fix this + attr_index(1), + cast(attr_index(0), DataType::Int64), ); assert_approx_eq::assert_approx_eq!( - cost_model.get_filter_selectivity(expr_tree).await.unwrap(), + cost_model + .get_filter_selectivity(TEST_GROUP1_ID, expr_tree) + .await + .unwrap(), DEFAULT_EQ_SEL ); assert_approx_eq::assert_approx_eq!( cost_model - .get_filter_selectivity(expr_tree_rev) + .get_filter_selectivity(TEST_GROUP1_ID, expr_tree_rev) .await .unwrap(), DEFAULT_EQ_SEL diff --git a/optd-cost-model/src/cost/filter/in_list.rs b/optd-cost-model/src/cost/filter/in_list.rs index 69b9b4f..f056fb1 100644 --- a/optd-cost-model/src/cost/filter/in_list.rs +++ b/optd-cost-model/src/cost/filter/in_list.rs @@ -4,6 +4,8 @@ use crate::{ predicates::{ attr_index_pred::AttrIndexPred, constant_pred::ConstantPred, in_list_pred::InListPred, }, + properties::attr_ref::{AttrRef, BaseTableAttrRef}, + types::GroupId, }, cost_model::CostModelImpl, stats::UNIMPLEMENTED_SEL, @@ -14,7 +16,11 @@ use crate::{ impl CostModelImpl { /// Only support attrA in (val1, val2, val3) where attrA is a attribute ref and /// val1, val2, val3 are constants. 
- pub(crate) async fn get_in_list_selectivity(&self, expr: &InListPred) -> CostModelResult { + pub(crate) async fn get_in_list_selectivity( + &self, + group_id: GroupId, + expr: &InListPred, + ) -> CostModelResult { let child = expr.child(); // Check child is a attribute ref. @@ -34,7 +40,7 @@ impl CostModelImpl { // Convert child and const expressions to concrete types. let attr_ref_pred = AttrIndexPred::from_pred_node(child).unwrap(); let attr_ref_idx = attr_ref_pred.attr_index(); - let table_id = todo!(); // TODO: Fix this + let list_exprs = list_exprs .into_iter() .map(|expr| { @@ -44,24 +50,30 @@ impl CostModelImpl { .collect::>(); let negated = expr.negated(); - // TODO: Consider attribute is a derived attribute - let mut in_sel = 0.0; - for expr in &list_exprs { - let selectivity = self - .get_attribute_equality_selectivity( - table_id, - attr_ref_idx, - &expr.value(), - /* is_equality */ true, - ) - .await?; - in_sel += selectivity; - } - in_sel = in_sel.min(1.0); - if negated { - Ok(1.0 - in_sel) + if let AttrRef::BaseTableAttrRef(BaseTableAttrRef { table_id, attr_idx }) = + self.memo.get_attribute_ref(group_id, attr_ref_idx) + { + let mut in_sel = 0.0; + for expr in &list_exprs { + let selectivity = self + .get_attribute_equality_selectivity( + table_id, + attr_idx, + &expr.value(), + /* is_equality */ true, + ) + .await?; + in_sel += selectivity; + } + in_sel = in_sel.min(1.0); + if negated { + Ok(1.0 - in_sel) + } else { + Ok(in_sel) + } } else { - Ok(in_sel) + // TODO: Child is a derived attribute. 
+ Ok(UNIMPLEMENTED_SEL) } } } @@ -71,8 +83,12 @@ mod tests { use std::collections::HashMap; use crate::{ - common::{types::TableId, values::Value}, + common::{ + types::{GroupId, TableId}, + values::Value, + }, cost_model::tests::*, + memo_ext::tests::MemoGroupInfo, stats::{ utilities::{counter::Counter, simple_map::SimpleMap}, MostCommonValues, @@ -90,61 +106,59 @@ mod tests { 2, 0.0, ); - let table_id = TableId(0); let cost_model = create_mock_cost_model( - vec![table_id], - vec![HashMap::from([(0, per_attribute_stats)])], + vec![TEST_TABLE1_ID], + vec![HashMap::from([( + TEST_ATTR1_BASE_INDEX, + per_attribute_stats, + )])], vec![None], ); assert_approx_eq::assert_approx_eq!( cost_model - .get_in_list_selectivity(&in_list(0, vec![Value::Int32(1)], false)) // TODO: Fix this + .get_in_list_selectivity(TEST_GROUP1_ID, &in_list(0, vec![Value::Int32(1)], false)) .await .unwrap(), 0.8 ); assert_approx_eq::assert_approx_eq!( cost_model - .get_in_list_selectivity(&in_list( - // TODO: Fix this - 0, - vec![Value::Int32(1), Value::Int32(2)], - false - )) + .get_in_list_selectivity( + TEST_GROUP1_ID, + &in_list(0, vec![Value::Int32(1), Value::Int32(2)], false) + ) .await .unwrap(), 1.0 ); assert_approx_eq::assert_approx_eq!( cost_model - .get_in_list_selectivity(&in_list(0, vec![Value::Int32(3)], false)) // TODO: Fix this + .get_in_list_selectivity(TEST_GROUP1_ID, &in_list(0, vec![Value::Int32(3)], false)) .await .unwrap(), 0.0 ); assert_approx_eq::assert_approx_eq!( cost_model - .get_in_list_selectivity(&in_list(0, vec![Value::Int32(1)], true)) // TODO: Fix this + .get_in_list_selectivity(TEST_GROUP1_ID, &in_list(0, vec![Value::Int32(1)], true)) .await .unwrap(), 0.2 ); assert_approx_eq::assert_approx_eq!( cost_model - .get_in_list_selectivity(&in_list( - // TODO: Fix this - 0, - vec![Value::Int32(1), Value::Int32(2)], - true - )) + .get_in_list_selectivity( + TEST_GROUP1_ID, + &in_list(0, vec![Value::Int32(1), Value::Int32(2)], true) + ) .await .unwrap(), 0.0 ); 
assert_approx_eq::assert_approx_eq!( cost_model - .get_in_list_selectivity(&in_list(0, vec![Value::Int32(3)], true)) // TODO: Fix this + .get_in_list_selectivity(TEST_GROUP1_ID, &in_list(0, vec![Value::Int32(3)], true)) // TODO: Fix this .await .unwrap(), 1.0 diff --git a/optd-cost-model/src/cost/filter/like.rs b/optd-cost-model/src/cost/filter/like.rs index fb11833..32800e4 100644 --- a/optd-cost-model/src/cost/filter/like.rs +++ b/optd-cost-model/src/cost/filter/like.rs @@ -6,6 +6,8 @@ use crate::{ predicates::{ attr_index_pred::AttrIndexPred, constant_pred::ConstantPred, like_pred::LikePred, }, + properties::attr_ref::{AttrRef, BaseTableAttrRef}, + types::GroupId, }, cost_model::CostModelImpl, stats::{ @@ -28,7 +30,11 @@ impl CostModelImpl { /// is composed of MCV frequency and non-MCV selectivity. MCV frequency is computed by /// adding up frequencies of MCVs that match the pattern. Non-MCV selectivity is computed /// in the same way that Postgres computes selectivity for the wildcard part of the pattern. - pub(crate) async fn get_like_selectivity(&self, like_expr: &LikePred) -> CostModelResult { + pub(crate) async fn get_like_selectivity( + &self, + group_id: GroupId, + like_expr: &LikePred, + ) -> CostModelResult { let child = like_expr.child(); // Check child is a attribute ref. @@ -44,46 +50,47 @@ impl CostModelImpl { let attr_ref_pred = AttrIndexPred::from_pred_node(child).unwrap(); let attr_ref_idx = attr_ref_pred.attr_index(); - let table_id = todo!(); // TODO: Fix this - // TODO: Consider attribute is a derived attribute - let pattern = ConstantPred::from_pred_node(pattern) - .expect("we already checked pattern is a constant") - .value() - .as_str(); - - // Compute the selectivity exculuding MCVs. - // See Postgres `like_selectivity`. - let non_mcv_sel = pattern - .chars() - .fold(1.0, |acc, c| { - if c == '%' { - acc * FULL_WILDCARD_SEL_FACTOR - } else { - acc * FIXED_CHAR_SEL_FACTOR - } - }) - .min(1.0); - - // Compute the selectivity in MCVs. 
- // TODO: Handle the case where `attribute_stats` is None. - if let Some(attribute_stats) = self - .get_attribute_comb_stats(table_id, &[attr_ref_idx]) - .await? + if let AttrRef::BaseTableAttrRef(BaseTableAttrRef { table_id, attr_idx }) = + self.memo.get_attribute_ref(group_id, attr_ref_idx) { - let (mcv_freq, null_frac) = { - let pred = Box::new(move |val: &AttributeCombValue| { - let string = - StringArray::from(vec![val[0].as_ref().unwrap().as_str().as_ref()]); - let pattern = StringArray::from(vec![pattern.as_ref()]); - like(&string, &pattern).unwrap().value(0) - }); - ( - attribute_stats.mcvs.freq_over_pred(pred), - attribute_stats.null_frac, - ) - }; - + let pattern = ConstantPred::from_pred_node(pattern) + .expect("we already checked pattern is a constant") + .value() + .as_str(); + + // Compute the selectivity exculuding MCVs. + // See Postgres `like_selectivity`. + let non_mcv_sel = pattern + .chars() + .fold(1.0, |acc, c| { + if c == '%' { + acc * FULL_WILDCARD_SEL_FACTOR + } else { + acc * FIXED_CHAR_SEL_FACTOR + } + }) + .min(1.0); + + // Compute the selectivity in MCVs. + // TODO: Handle the case where `attribute_stats` is None. + let (mut mcv_freq, mut null_frac) = (0.0, 0.0); + if let Some(attribute_stats) = + self.get_attribute_comb_stats(table_id, &[attr_idx]).await? + { + (mcv_freq, null_frac) = { + let pred = Box::new(move |val: &AttributeCombValue| { + let string = + StringArray::from(vec![val[0].as_ref().unwrap().as_str().as_ref()]); + let pattern = StringArray::from(vec![pattern.as_ref()]); + like(&string, &pattern).unwrap().value(0) + }); + ( + attribute_stats.mcvs.freq_over_pred(pred), + attribute_stats.null_frac, + ) + }; + } let result = non_mcv_sel + mcv_freq; Ok(if like_expr.negated() { @@ -95,6 +102,7 @@ impl CostModelImpl { // `patternsel_common`. 
.clamp(0.0001, 0.9999)) } else { + // TOOD: derived attribute Ok(UNIMPLEMENTED_SEL) } } @@ -105,7 +113,10 @@ mod tests { use std::collections::HashMap; use crate::{ - common::{types::TableId, values::Value}, + common::{ + types::{GroupId, TableId}, + values::Value, + }, cost_model::tests::*, stats::{ utilities::{counter::Counter, simple_map::SimpleMap}, @@ -124,30 +135,35 @@ mod tests { 2, 0.0, ); - let table_id = TableId(0); let cost_model = create_mock_cost_model( - vec![table_id], - vec![HashMap::from([(0, per_attribute_stats)])], + vec![TEST_TABLE1_ID], + vec![HashMap::from([( + TEST_ATTR1_BASE_INDEX, + per_attribute_stats, + )])], vec![None], ); assert_approx_eq::assert_approx_eq!( cost_model - .get_like_selectivity(&like(0, "%abcd%", false)) // TODO: Fix this + .get_like_selectivity( + TEST_GROUP1_ID, + &like(TEST_ATTR1_BASE_INDEX, "%abcd%", false) + ) // TODO: Fix this .await .unwrap(), 0.1 + FULL_WILDCARD_SEL_FACTOR.powi(2) * FIXED_CHAR_SEL_FACTOR.powi(4) ); assert_approx_eq::assert_approx_eq!( cost_model - .get_like_selectivity(&like(0, "%abc%", false)) // TODO: Fix this + .get_like_selectivity(TEST_GROUP1_ID, &like(TEST_ATTR1_BASE_INDEX, "%abc%", false)) // TODO: Fix this .await .unwrap(), 0.1 + 0.1 + FULL_WILDCARD_SEL_FACTOR.powi(2) * FIXED_CHAR_SEL_FACTOR.powi(3) ); assert_approx_eq::assert_approx_eq!( cost_model - .get_like_selectivity(&like(0, "%abc%", true)) // TODO: Fix this + .get_like_selectivity(TEST_GROUP1_ID, &like(TEST_ATTR1_BASE_INDEX, "%abc%", true)) // TODO: Fix this .await .unwrap(), 1.0 - (0.1 + 0.1 + FULL_WILDCARD_SEL_FACTOR.powi(2) * FIXED_CHAR_SEL_FACTOR.powi(3)) @@ -166,23 +182,25 @@ mod tests { 2, null_frac, ); - let table_id = TableId(0); let cost_model = create_mock_cost_model( - vec![table_id], - vec![HashMap::from([(0, per_attribute_stats)])], + vec![TEST_TABLE1_ID], + vec![HashMap::from([( + TEST_ATTR1_BASE_INDEX, + per_attribute_stats, + )])], vec![None], ); assert_approx_eq::assert_approx_eq!( cost_model - 
.get_like_selectivity(&like(0, "%abcd%", false)) // TODO: Fix this + .get_like_selectivity(TEST_GROUP1_ID, &like(0, "%abcd%", false)) // TODO: Fix this .await .unwrap(), 0.1 + FULL_WILDCARD_SEL_FACTOR.powi(2) * FIXED_CHAR_SEL_FACTOR.powi(4) ); assert_approx_eq::assert_approx_eq!( cost_model - .get_like_selectivity(&like(0, "%abcd%", true)) // TODO: Fix this + .get_like_selectivity(TEST_GROUP1_ID, &like(0, "%abcd%", true)) // TODO: Fix this .await .unwrap(), 1.0 - (0.1 + FULL_WILDCARD_SEL_FACTOR.powi(2) * FIXED_CHAR_SEL_FACTOR.powi(4)) diff --git a/optd-cost-model/src/cost/filter/log_op.rs b/optd-cost-model/src/cost/filter/log_op.rs index 66bab10..61862a2 100644 --- a/optd-cost-model/src/cost/filter/log_op.rs +++ b/optd-cost-model/src/cost/filter/log_op.rs @@ -1,5 +1,5 @@ use crate::{ - common::{nodes::ArcPredicateNode, predicates::log_op_pred::LogOpType}, + common::{nodes::ArcPredicateNode, predicates::log_op_pred::LogOpType, types::GroupId}, cost_model::CostModelImpl, storage::CostModelStorageManager, CostModelResult, @@ -8,6 +8,7 @@ use crate::{ impl CostModelImpl { pub(crate) async fn get_log_op_selectivity( &self, + group_id: GroupId, log_op_typ: LogOpType, children: &[ArcPredicateNode], ) -> CostModelResult { @@ -15,7 +16,7 @@ impl CostModelImpl { LogOpType::And => { let mut and_sel = 1.0; for child in children { - let selectivity = self.get_filter_selectivity(child.clone()).await?; + let selectivity = self.get_filter_selectivity(group_id, child.clone()).await?; and_sel *= selectivity; } Ok(and_sel) @@ -23,7 +24,7 @@ impl CostModelImpl { LogOpType::Or => { let mut or_sel_neg = 1.0; for child in children { - let selectivity = self.get_filter_selectivity(child.clone()).await?; + let selectivity = self.get_filter_selectivity(group_id, child.clone()).await?; or_sel_neg *= (1.0 - selectivity); } Ok(1.0 - or_sel_neg) diff --git a/optd-cost-model/src/cost/join/join.rs b/optd-cost-model/src/cost/join/join.rs index eeb5847..ec98ace 100644 --- 
a/optd-cost-model/src/cost/join/join.rs +++ b/optd-cost-model/src/cost/join/join.rs @@ -180,8 +180,10 @@ impl CostModelImpl { // get_filter_selectivity() function, but this may change in the future. let join_filter_selectivity = match filter_expr_tree { Some(filter_expr_tree) => { - // FIXME: Pass in group id or schema & attr_refs - self.get_filter_selectivity(filter_expr_tree).await? + // FIXME(group_id): Pass in group id or schema & attr_refs + let group_id = GroupId(0); + self.get_filter_selectivity(group_id, filter_expr_tree) + .await? } None => 1.0, }; diff --git a/optd-cost-model/src/cost/join/mod.rs b/optd-cost-model/src/cost/join/mod.rs index 65ab5bb..b752aa8 100644 --- a/optd-cost-model/src/cost/join/mod.rs +++ b/optd-cost-model/src/cost/join/mod.rs @@ -7,6 +7,7 @@ use crate::common::{ }; pub mod hash_join; +// FIXME: module has the same name as its containing module pub mod join; pub mod nested_loop_join; diff --git a/optd-cost-model/src/cost_model.rs b/optd-cost-model/src/cost_model.rs index 552a272..11fe2ec 100644 --- a/optd-cost-model/src/cost_model.rs +++ b/optd-cost-model/src/cost_model.rs @@ -165,23 +165,76 @@ pub mod tests { pub const TEST_GROUP3_ID: GroupId = GroupId(2); pub const TEST_GROUP4_ID: GroupId = GroupId(3); + // This is base index rather than ref index. + pub const TEST_ATTR1_BASE_INDEX: u64 = 0; + pub const TEST_ATTR2_BASE_INDEX: u64 = 1; + pub const TEST_ATTR3_BASE_INDEX: u64 = 2; + pub type TestPerAttributeStats = AttributeCombValueStats; // TODO: add tests for non-mock storage manager pub type TestOptCostModelMock = CostModelImpl; + // Use this method, we only create one group `TEST_GROUP1_ID` in the memo. + // We put the first attribute in the first table as the ref index 0 in the group. + // And put the second attribute in the first table as the ref index 1 in the group. + // etc. + // The orders of attributes and tables are defined by the order of their ids (smaller first). 
pub fn create_mock_cost_model( table_id: Vec, + // u64 should be base attribute index. + per_attribute_stats: Vec>, + row_counts: Vec>, + ) -> TestOptCostModelMock { + let attr_ids: Vec<(TableId, u64, Option)> = per_attribute_stats + .iter() + .enumerate() + .map(|(idx, m)| (table_id[idx], m)) + .flat_map(|(table_id, m)| { + m.iter() + .map(|(attr_idx, _)| (table_id, *attr_idx, None)) + .collect_vec() + }) + .sorted_by_key(|(table_id, attr_idx, _)| (*table_id, *attr_idx)) + .collect(); + create_mock_cost_model_with_memo( + table_id.clone(), + per_attribute_stats, + row_counts, + create_one_group_all_base_attributes_mock_memo(attr_ids), + ) + } + + pub fn create_mock_cost_model_with_attr_types( + table_id: Vec, + // u64 should be base attribute index. per_attribute_stats: Vec>, + attributes: Vec>, row_counts: Vec>, ) -> TestOptCostModelMock { - create_mock_cost_model_with_memo(table_id, per_attribute_stats, row_counts, HashMap::new()) + let attr_ids: Vec<(TableId, u64, Option)> = attributes + .iter() + .enumerate() + .map(|(idx, m)| (table_id[idx], m)) + .flat_map(|(table_id, m)| { + m.iter() + .map(|(attr_idx, typ)| (table_id, *attr_idx, Some(*typ))) + .collect_vec() + }) + .sorted_by_key(|(table_id, attr_idx, _)| (*table_id, *attr_idx)) + .collect(); + create_mock_cost_model_with_memo( + table_id.clone(), + per_attribute_stats, + row_counts, + create_one_group_all_base_attributes_mock_memo(attr_ids), + ) } pub fn create_mock_cost_model_with_memo( table_id: Vec, per_attribute_stats: Vec>, row_counts: Vec>, - memo: HashMap, + memo: MockMemoExtImpl, ) -> TestOptCostModelMock { let storage_manager = CostModelStorageMockManagerImpl::new( table_id @@ -202,11 +255,36 @@ pub mod tests { }) .collect(), ); - CostModelImpl::new( - storage_manager, - CatalogSource::Mock, - Arc::new(MockMemoExtImpl::from(memo)), - ) + CostModelImpl::new(storage_manager, CatalogSource::Mock, Arc::new(memo)) + } + + // attributes: Vec<(TableId, AttrBaseIndex)> + pub fn 
create_one_group_all_base_attributes_mock_memo( + attr_ids: Vec<(TableId, u64, Option)>, + ) -> MockMemoExtImpl { + let group_info = MemoGroupInfo::new( + Schema::new( + attr_ids + .clone() + .into_iter() + .map(|(_, _, typ)| Attribute { + name: "attr".to_string(), + typ: typ.unwrap_or(ConstantType::Int64), + nullable: false, + }) + .collect(), + ), + GroupAttrRefs::new( + attr_ids + .into_iter() + .map(|(table_id, attr_base_index, _)| { + AttrRef::new_base_table_attr_ref(table_id, attr_base_index) + }) + .collect(), + None, + ), + ); + MockMemoExtImpl::from(HashMap::from([(TEST_GROUP1_ID, group_info)])) } /// Create a cost model two tables, each with one attribute. Each attribute has 100 values. diff --git a/optd-cost-model/src/storage/mock.rs b/optd-cost-model/src/storage/mock.rs index 4894859..d878bcb 100644 --- a/optd-cost-model/src/storage/mock.rs +++ b/optd-cost-model/src/storage/mock.rs @@ -5,7 +5,7 @@ use serde::{Deserialize, Serialize}; use crate::{common::types::TableId, stats::AttributeCombValueStats, CostModelResult}; -use super::{Attribute, CostModelStorageManager}; +use super::CostModelStorageManager; pub type AttrIndices = Vec; @@ -44,14 +44,6 @@ impl CostModelStorageMockManagerImpl { } impl CostModelStorageManager for CostModelStorageMockManagerImpl { - async fn get_attribute_info( - &self, - table_id: TableId, - attr_base_index: u64, - ) -> CostModelResult> { - unimplemented!() - } - async fn get_attributes_comb_statistics( &self, table_id: TableId, diff --git a/optd-cost-model/src/storage/mod.rs b/optd-cost-model/src/storage/mod.rs index 1dc86fc..d3d26cd 100644 --- a/optd-cost-model/src/storage/mod.rs +++ b/optd-cost-model/src/storage/mod.rs @@ -1,20 +1,10 @@ -use crate::{ - common::{properties::Attribute, types::TableId}, - stats::AttributeCombValueStats, - CostModelResult, -}; +use crate::{common::types::TableId, stats::AttributeCombValueStats, CostModelResult}; pub mod mock; pub mod persistent; #[trait_variant::make(Send)] pub trait 
CostModelStorageManager { - async fn get_attribute_info( - &self, - table_id: TableId, - attr_base_index: u64, - ) -> CostModelResult>; - async fn get_attributes_comb_statistics( &self, table_id: TableId, diff --git a/optd-cost-model/src/storage/persistent.rs b/optd-cost-model/src/storage/persistent.rs index ed4f07a..dede7f3 100644 --- a/optd-cost-model/src/storage/persistent.rs +++ b/optd-cost-model/src/storage/persistent.rs @@ -4,7 +4,7 @@ use std::sync::Arc; use optd_persistent::{cost_model::interface::StatType, CostModelStorageLayer}; use crate::{ - common::{properties::Attribute, types::TableId}, + common::types::TableId, stats::{utilities::counter::Counter, AttributeCombValueStats, Distribution, MostCommonValues}, CostModelResult, }; @@ -26,14 +26,6 @@ impl CostModelStorageManagerImpl { impl CostModelStorageManager for CostModelStorageManagerImpl { - async fn get_attribute_info( - &self, - table_id: TableId, - attr_base_index: u64, - ) -> CostModelResult> { - unimplemented!() - } - /// Gets the latest statistics for a given table. /// /// TODO: Currently, in `AttributeCombValueStats`, only `Distribution` is optional. 
From 489ff480b78968c3a0a4126b115ed56a832759ba Mon Sep 17 00:00:00 2001 From: Yuanxin Cao Date: Mon, 18 Nov 2024 20:22:59 -0500 Subject: [PATCH 48/51] add join test --- Cargo.lock | 34 + optd-cost-model/Cargo.toml | 1 + optd-cost-model/src/cost/join/core.rs | 1201 +++++++++++++++++++++++++ optd-cost-model/src/cost/join/join.rs | 401 --------- optd-cost-model/src/cost/join/mod.rs | 3 +- optd-cost-model/src/cost_model.rs | 43 +- optd-cost-model/src/stats/mod.rs | 8 + 7 files changed, 1280 insertions(+), 411 deletions(-) create mode 100644 optd-cost-model/src/cost/join/core.rs delete mode 100644 optd-cost-model/src/cost/join/join.rs diff --git a/Cargo.lock b/Cargo.lock index f6c0033..b54d901 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2257,6 +2257,7 @@ dependencies = [ "serde", "serde_json", "serde_with", + "test-case", "tokio", "trait-variant", ] @@ -3644,6 +3645,39 @@ dependencies = [ "windows-sys 0.59.0", ] +[[package]] +name = "test-case" +version = "3.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eb2550dd13afcd286853192af8601920d959b14c401fcece38071d53bf0768a8" +dependencies = [ + "test-case-macros", +] + +[[package]] +name = "test-case-core" +version = "3.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "adcb7fd841cd518e279be3d5a3eb0636409487998a4aff22f3de87b81e88384f" +dependencies = [ + "cfg-if", + "proc-macro2", + "quote", + "syn 2.0.87", +] + +[[package]] +name = "test-case-macros" +version = "3.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c89e72a01ed4c579669add59014b9a524d609c0c88c6a585ce37485879f6ffb" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.87", + "test-case-core", +] + [[package]] name = "thiserror" version = "1.0.69" diff --git a/optd-cost-model/Cargo.toml b/optd-cost-model/Cargo.toml index 4ede352..e8b22aa 100644 --- a/optd-cost-model/Cargo.toml +++ b/optd-cost-model/Cargo.toml @@ -22,3 +22,4 @@ tokio = { version = "1.0.1", 
features = ["macros", "rt-multi-thread"] } [dev-dependencies] crossbeam = "0.8" rand = "0.8" +test-case = "3.3" diff --git a/optd-cost-model/src/cost/join/core.rs b/optd-cost-model/src/cost/join/core.rs new file mode 100644 index 0000000..6693b0c --- /dev/null +++ b/optd-cost-model/src/cost/join/core.rs @@ -0,0 +1,1201 @@ +use std::collections::HashSet; + +use itertools::Itertools; + +use crate::{ + common::{ + nodes::{ArcPredicateNode, JoinType, PredicateType, ReprPredicateNode}, + predicates::{ + attr_index_pred::AttrIndexPred, + bin_op_pred::BinOpType, + list_pred::ListPred, + log_op_pred::{LogOpPred, LogOpType}, + }, + properties::attr_ref::{ + self, AttrRef, AttrRefs, BaseTableAttrRef, EqPredicate, GroupAttrRefs, + SemanticCorrelation, + }, + types::GroupId, + }, + cost::join::get_on_attr_ref_pair, + cost_model::CostModelImpl, + stats::DEFAULT_NUM_DISTINCT, + storage::CostModelStorageManager, + CostModelResult, +}; + +impl CostModelImpl { + /// The expr_tree input must be a "mixed expression tree", just like with + /// `get_filter_selectivity`. + /// + /// This is a "wrapper" to separate the equality conditions from the filter conditions before + /// calling the "main" `get_join_selectivity_core` function. 
+ #[allow(clippy::too_many_arguments)] + pub(crate) async fn get_join_selectivity_from_expr_tree( + &self, + join_typ: JoinType, + expr_tree: ArcPredicateNode, + attr_refs: &AttrRefs, + input_correlation: Option, + left_row_cnt: f64, + right_row_cnt: f64, + ) -> CostModelResult { + if expr_tree.typ == PredicateType::LogOp(LogOpType::And) { + let mut on_attr_ref_pairs = vec![]; + let mut filter_expr_trees = vec![]; + for child_expr_tree in &expr_tree.children { + if let Some(on_attr_ref_pair) = + get_on_attr_ref_pair(child_expr_tree.clone(), attr_refs) + { + on_attr_ref_pairs.push(on_attr_ref_pair) + } else { + let child_expr = child_expr_tree.clone(); + filter_expr_trees.push(child_expr); + } + } + assert!(on_attr_ref_pairs.len() + filter_expr_trees.len() == expr_tree.children.len()); + let filter_expr_tree = if filter_expr_trees.is_empty() { + None + } else { + Some(LogOpPred::new(LogOpType::And, filter_expr_trees).into_pred_node()) + }; + self.get_join_selectivity_core( + join_typ, + on_attr_ref_pairs, + filter_expr_tree, + attr_refs, + input_correlation, + left_row_cnt, + right_row_cnt, + 0, + ) + .await + } else { + #[allow(clippy::collapsible_else_if)] + if let Some(on_attr_ref_pair) = get_on_attr_ref_pair(expr_tree.clone(), attr_refs) { + self.get_join_selectivity_core( + join_typ, + vec![on_attr_ref_pair], + None, + attr_refs, + input_correlation, + left_row_cnt, + right_row_cnt, + 0, + ) + .await + } else { + self.get_join_selectivity_core( + join_typ, + vec![], + Some(expr_tree), + attr_refs, + input_correlation, + left_row_cnt, + right_row_cnt, + 0, + ) + .await + } + } + } + + /// A wrapper to convert the join keys to the format expected by get_join_selectivity_core() + #[allow(clippy::too_many_arguments)] + pub(crate) async fn get_join_selectivity_from_keys( + &self, + join_typ: JoinType, + left_keys: ListPred, + right_keys: ListPred, + attr_refs: &AttrRefs, + input_correlation: Option, + left_row_cnt: f64, + right_row_cnt: f64, + left_attr_cnt: usize, 
+ ) -> CostModelResult { + assert!(left_keys.len() == right_keys.len()); + // I assume that the keys are already in the right order + // s.t. the ith key of left_keys corresponds with the ith key of right_keys + let on_attr_ref_pairs = left_keys + .to_vec() + .into_iter() + .zip(right_keys.to_vec()) + .map(|(left_key, right_key)| { + ( + AttrIndexPred::from_pred_node(left_key).expect("keys should be AttrRefPreds"), + AttrIndexPred::from_pred_node(right_key).expect("keys should be AttrRefPreds"), + ) + }) + .collect_vec(); + self.get_join_selectivity_core( + join_typ, + on_attr_ref_pairs, + None, + attr_refs, + input_correlation, + left_row_cnt, + right_row_cnt, + left_attr_cnt, + ) + .await + } + + /// The core logic of join selectivity which assumes we've already separated the expression + /// into the on conditions and the filters. + /// + /// Hash join and NLJ reference right table attributes differently, hence the + /// `right_attr_ref_offset` parameter. + /// + /// For hash join, the right table attributes indices are with respect to the right table, + /// which means #0 is the first attribute of the right table. + /// + /// For NLJ, the right table attributes indices are with respect to the output of the join. + /// For example, if the left table has 3 attributes, the first attribute of the right table + /// is #3 instead of #0. + #[allow(clippy::too_many_arguments)] + async fn get_join_selectivity_core( + &self, + join_typ: JoinType, + on_attr_ref_pairs: Vec<(AttrIndexPred, AttrIndexPred)>, + filter_expr_tree: Option, + attr_refs: &AttrRefs, + input_correlation: Option, + left_row_cnt: f64, + right_row_cnt: f64, + right_attr_ref_offset: usize, + ) -> CostModelResult { + let join_on_selectivity = self + .get_join_on_selectivity( + &on_attr_ref_pairs, + attr_refs, + input_correlation, + right_attr_ref_offset, + ) + .await?; + // Currently, there is no difference in how we handle a join filter and a select filter, + // so we use the same function. 
+ // + // One difference (that we *don't* care about right now) is that join filters can contain + // expressions from multiple different tables. Currently, this doesn't affect the + // get_filter_selectivity() function, but this may change in the future. + let join_filter_selectivity = match filter_expr_tree { + Some(filter_expr_tree) => { + // FIXME(group_id): Pass in group id or schema & attr_refs + let group_id = GroupId(0); + self.get_filter_selectivity(group_id, filter_expr_tree) + .await? + } + None => 1.0, + }; + let inner_join_selectivity = join_on_selectivity * join_filter_selectivity; + + Ok(match join_typ { + JoinType::Inner => inner_join_selectivity, + JoinType::LeftOuter => f64::max(inner_join_selectivity, 1.0 / right_row_cnt), + JoinType::RightOuter => f64::max(inner_join_selectivity, 1.0 / left_row_cnt), + JoinType::Cross => { + assert!( + on_attr_ref_pairs.is_empty(), + "Cross joins should not have on attributes" + ); + join_filter_selectivity + } + _ => unimplemented!("join_typ={} is not implemented", join_typ), + }) + } + + /// Get the selectivity of one attribute eq predicate, e.g. attrA = attrB. + async fn get_join_selectivity_from_on_attr_ref_pair( + &self, + left: &AttrRef, + right: &AttrRef, + ) -> CostModelResult { + // the formula for each pair is min(1 / ndistinct1, 1 / ndistinct2) + // (see https://postgrespro.com/blog/pgsql/5969618) + let mut ndistincts = vec![]; + for attr_ref in [left, right] { + let ndistinct = match attr_ref { + AttrRef::BaseTableAttrRef(base_attr_ref) => { + match self + .get_attribute_comb_stats(base_attr_ref.table_id, &[base_attr_ref.attr_idx]) + .await? 
+ { + Some(per_attr_stats) => per_attr_stats.ndistinct, + None => DEFAULT_NUM_DISTINCT, + } + } + AttrRef::Derived => DEFAULT_NUM_DISTINCT, + }; + ndistincts.push(ndistinct); + } + + // using reduce(f64::min) is the idiomatic workaround to min() because + // f64 does not implement Ord due to NaN + let selectivity = ndistincts.into_iter().map(|ndistinct| 1.0 / ndistinct as f64).reduce(f64::min).expect("reduce() only returns None if the iterator is empty, which is impossible since attr_ref_exprs.len() == 2"); + assert!( + !selectivity.is_nan(), + "it should be impossible for selectivity to be NaN since n-distinct is never 0" + ); + Ok(selectivity) + } + + /// Given a set of N attributes involved in a multi-equality, find the total selectivity + /// of the multi-equality. + /// + /// This is a generalization of get_join_selectivity_from_on_attr_ref_pair(). + async fn get_join_selectivity_from_most_selective_attrs( + &self, + base_attr_refs: HashSet, + ) -> CostModelResult { + assert!(base_attr_refs.len() > 1); + let num_base_attr_refs = base_attr_refs.len(); + + let mut ndistincts = vec![]; + for base_attr_ref in base_attr_refs.iter() { + let ndistinct = match self + .get_attribute_comb_stats(base_attr_ref.table_id, &[base_attr_ref.attr_idx]) + .await? + { + Some(per_attr_stats) => per_attr_stats.ndistinct, + None => DEFAULT_NUM_DISTINCT, + }; + ndistincts.push(ndistinct); + } + + Ok(ndistincts + .into_iter() + .map(|ndistinct| 1.0 / ndistinct as f64) + .sorted_by(|a, b| { + a.partial_cmp(b) + .expect("No floats should be NaN since n-distinct is never 0") + }) + .take(num_base_attr_refs - 1) + .product()) + } + + /// A predicate set defines a "multi-equality graph", which is an unweighted undirected graph. + /// The nodes are attributes while edges are predicates. The old graph is defined by + /// `past_eq_attrs` while the `predicate` is the new addition to this graph. 
This + /// unweighted undirected graph consists of a number of connected components, where each + /// connected component represents attributes that are set to be equal to each other. Single + /// nodes not connected to anything are considered standalone connected components. + /// + /// The selectivity of each connected component of N nodes is equal to the product of + /// 1/ndistinct of the N-1 nodes with the highest ndistinct values. You can see this if you + /// imagine that all attributes being joined are unique attributes and that they follow the + /// inclusion principle (every element of the smaller tables is present in the larger + /// tables). When these assumptions are not true, the selectivity may not be completely + /// accurate. However, it is still fairly accurate. + /// + /// However, we cannot simply add `predicate` to the multi-equality graph and compute the + /// selectivity of the entire connected component, because this would be "double counting" a + /// lot of nodes. The join(s) before this join would already have a selectivity value. Thus, + /// we compute the selectivity of the join(s) before this join (the first block of the + /// function) and then the selectivity of the connected component after this join. The + /// quotient is the "adjustment" factor. + /// + /// NOTE: This function modifies `past_eq_attrs` by adding `predicate` to it. + async fn get_join_selectivity_adjustment_when_adding_to_multi_equality_graph( + &self, + predicate: &EqPredicate, + past_eq_attrs: &mut SemanticCorrelation, + ) -> CostModelResult { + if predicate.left == predicate.right { + // self-join, TODO: is this correct? + return Ok(1.0); + } + // To find the adjustment, we need to know the selectivity of the graph before `predicate` + // is added. + // + // There are two cases: (1) adding `predicate` does not change the # of connected + // components, and (2) adding `predicate` reduces the # of connected by 1. 
Note that + // attributes not involved in any predicates are considered a part of the graph and are + // a connected component on their own. + let children_pred_sel = { + if past_eq_attrs.is_eq(&predicate.left, &predicate.right) { + self.get_join_selectivity_from_most_selective_attrs( + past_eq_attrs.find_attrs_for_eq_attribute_set(&predicate.left), + ) + .await? + } else { + let left_sel = if past_eq_attrs.contains(&predicate.left) { + self.get_join_selectivity_from_most_selective_attrs( + past_eq_attrs.find_attrs_for_eq_attribute_set(&predicate.left), + ) + .await? + } else { + 1.0 + }; + let right_sel = if past_eq_attrs.contains(&predicate.right) { + self.get_join_selectivity_from_most_selective_attrs( + past_eq_attrs.find_attrs_for_eq_attribute_set(&predicate.right), + ) + .await? + } else { + 1.0 + }; + left_sel * right_sel + } + }; + + // Add predicate to past_eq_attrs and compute the selectivity of the connected component + // it creates. + past_eq_attrs.add_predicate(predicate.clone()); + let new_pred_sel = { + let attrs = past_eq_attrs.find_attrs_for_eq_attribute_set(&predicate.left); + self.get_join_selectivity_from_most_selective_attrs(attrs) + } + .await?; + + // Compute the adjustment factor. + Ok(new_pred_sel / children_pred_sel) + } + + /// Get the selectivity of the on conditions. + /// + /// Note that the selectivity of the on conditions does not depend on join type. + /// Join type is accounted for separately in get_join_selectivity_core(). + /// + /// We also check if each predicate is correlated with any of the previous predicates. + /// + /// More specifically, we are checking if the predicate can be expressed with other existing + /// predicates. E.g. if we have a predicate like A = B and B = C is equivalent to A = C. + // + /// However, we don't just throw away A = C, because we want to pick the most selective + /// predicates. For details on how we do this, see + /// `get_join_selectivity_from_redundant_predicates`. 
+ async fn get_join_on_selectivity( + &self, + on_attr_ref_pairs: &[(AttrIndexPred, AttrIndexPred)], + attr_refs: &AttrRefs, + input_correlation: Option, + right_attr_ref_offset: usize, + ) -> CostModelResult { + let mut past_eq_attrs = input_correlation.unwrap_or_default(); + + // Multiply the selectivities of all individual conditions together + let mut selectivity = 1.0; + for on_attr_ref_pair in on_attr_ref_pairs { + let left_attr_ref = &attr_refs[on_attr_ref_pair.0.attr_index() as usize]; + let right_attr_ref = + &attr_refs[on_attr_ref_pair.1.attr_index() as usize + right_attr_ref_offset]; + + if let (AttrRef::BaseTableAttrRef(left), AttrRef::BaseTableAttrRef(right)) = + (left_attr_ref, right_attr_ref) + { + let predicate = EqPredicate::new(left.clone(), right.clone()); + return self + .get_join_selectivity_adjustment_when_adding_to_multi_equality_graph( + &predicate, + &mut past_eq_attrs, + ) + .await; + } + + selectivity *= self + .get_join_selectivity_from_on_attr_ref_pair(left_attr_ref, right_attr_ref) + .await?; + } + Ok(selectivity) + } +} + +#[cfg(test)] +mod tests { + use std::collections::HashMap; + + use crate::{ + common::{predicates::attr_index_pred, types::TableId, values::Value}, + cost_model::tests::{ + attr_index, bin_op, cnst, create_four_table_mock_cost_model, create_mock_cost_model, + create_three_table_mock_cost_model, create_two_table_mock_cost_model, + create_two_table_mock_cost_model_custom_row_cnts, empty_per_attr_stats, log_op, + per_attr_stats_with_dist_and_ndistinct, per_attr_stats_with_ndistinct, + TestOptCostModelMock, TestPerAttributeStats, TEST_TABLE1_ID, TEST_TABLE2_ID, + TEST_TABLE3_ID, TEST_TABLE4_ID, + }, + stats::DEFAULT_EQ_SEL, + }; + + use super::*; + + /// A wrapper around get_join_selectivity_from_expr_tree that extracts the + /// table row counts from the cost model. 
+ async fn test_get_join_selectivity( + cost_model: &TestOptCostModelMock, + reverse_tables: bool, + join_typ: JoinType, + expr_tree: ArcPredicateNode, + attr_refs: &AttrRefs, + input_correlation: Option, + ) -> f64 { + let table1_row_cnt = cost_model.get_row_count(TEST_TABLE1_ID) as f64; + let table2_row_cnt = cost_model.get_row_count(TEST_TABLE2_ID) as f64; + + if !reverse_tables { + cost_model + .get_join_selectivity_from_expr_tree( + join_typ, + expr_tree, + attr_refs, + input_correlation, + table1_row_cnt, + table2_row_cnt, + ) + .await + .unwrap() + } else { + cost_model + .get_join_selectivity_from_expr_tree( + join_typ, + expr_tree, + attr_refs, + input_correlation, + table2_row_cnt, + table1_row_cnt, + ) + .await + .unwrap() + } + } + + #[tokio::test] + async fn test_inner_const() { + let cost_model = create_mock_cost_model( + vec![TEST_TABLE1_ID], + vec![HashMap::from([(0, empty_per_attr_stats())])], + vec![None], + ); + assert_approx_eq::assert_approx_eq!( + cost_model + .get_join_selectivity_from_expr_tree( + JoinType::Inner, + cnst(Value::Bool(true)), + &vec![], + None, + f64::NAN, + f64::NAN + ) + .await + .unwrap(), + 1.0 + ); + assert_approx_eq::assert_approx_eq!( + cost_model + .get_join_selectivity_from_expr_tree( + JoinType::Inner, + cnst(Value::Bool(false)), + &vec![], + None, + f64::NAN, + f64::NAN + ) + .await + .unwrap(), + 0.0 + ); + } + + #[tokio::test] + async fn test_inner_oncond() { + let cost_model = create_two_table_mock_cost_model( + per_attr_stats_with_ndistinct(5), + per_attr_stats_with_ndistinct(4), + ); + + let attr_refs = vec![ + AttrRef::base_table_attr_ref(TEST_TABLE1_ID, 0), + AttrRef::base_table_attr_ref(TEST_TABLE2_ID, 0), + ]; + let expr_tree = bin_op(BinOpType::Eq, attr_index(0), attr_index(1)); + let expr_tree_rev = bin_op(BinOpType::Eq, attr_index(1), attr_index(0)); + assert_approx_eq::assert_approx_eq!( + test_get_join_selectivity( + &cost_model, + false, + JoinType::Inner, + expr_tree, + &attr_refs, + None, + ) + 
.await, + 0.2 + ); + assert_approx_eq::assert_approx_eq!( + test_get_join_selectivity( + &cost_model, + false, + JoinType::Inner, + expr_tree_rev, + &attr_refs, + None, + ) + .await, + 0.2 + ); + } + + #[tokio::test] + async fn test_inner_and_of_onconds() { + let cost_model = create_two_table_mock_cost_model( + per_attr_stats_with_ndistinct(5), + per_attr_stats_with_ndistinct(4), + ); + + let attr_refs = vec![ + AttrRef::base_table_attr_ref(TEST_TABLE1_ID, 0), + AttrRef::base_table_attr_ref(TEST_TABLE2_ID, 0), + ]; + let eq0and1 = bin_op(BinOpType::Eq, attr_index(0), attr_index(1)); + let eq1and0 = bin_op(BinOpType::Eq, attr_index(1), attr_index(0)); + let expr_tree = log_op(LogOpType::And, vec![eq0and1.clone(), eq1and0.clone()]); + let expr_tree_rev = log_op(LogOpType::And, vec![eq1and0.clone(), eq0and1.clone()]); + + assert_approx_eq::assert_approx_eq!( + test_get_join_selectivity( + &cost_model, + false, + JoinType::Inner, + expr_tree, + &attr_refs, + None, + ) + .await, + 0.2 + ); + assert_approx_eq::assert_approx_eq!( + test_get_join_selectivity( + &cost_model, + false, + JoinType::Inner, + expr_tree_rev, + &attr_refs, + None + ) + .await, + 0.2 + ); + } + + #[tokio::test] + #[ignore = "index out of bounds: the len is 1 but the index is 1"] + async fn test_inner_and_of_oncond_and_filter() { + let cost_model = create_two_table_mock_cost_model( + per_attr_stats_with_ndistinct(5), + per_attr_stats_with_ndistinct(4), + ); + + let attr_refs = vec![ + AttrRef::base_table_attr_ref(TEST_TABLE1_ID, 0), + AttrRef::base_table_attr_ref(TEST_TABLE2_ID, 0), + ]; + let eq0and1 = bin_op(BinOpType::Eq, attr_index(0), attr_index(1)); + let eq100 = bin_op(BinOpType::Eq, attr_index(1), cnst(Value::Int32(100))); + let expr_tree = log_op(LogOpType::And, vec![eq0and1.clone(), eq100.clone()]); + let expr_tree_rev = log_op(LogOpType::And, vec![eq100.clone(), eq0and1.clone()]); + + assert_approx_eq::assert_approx_eq!( + test_get_join_selectivity( + &cost_model, + false, + 
JoinType::Inner, + expr_tree, + &attr_refs, + None + ) + .await, + 0.05 + ); + assert_approx_eq::assert_approx_eq!( + test_get_join_selectivity( + &cost_model, + false, + JoinType::Inner, + expr_tree_rev, + &attr_refs, + None + ) + .await, + 0.05 + ); + } + + #[tokio::test] + #[ignore = "filter todo"] + async fn test_inner_and_of_filters() { + let cost_model = create_two_table_mock_cost_model( + per_attr_stats_with_ndistinct(5), + per_attr_stats_with_ndistinct(4), + ); + + let attr_refs = vec![ + AttrRef::base_table_attr_ref(TEST_TABLE1_ID, 0), + AttrRef::base_table_attr_ref(TEST_TABLE2_ID, 0), + ]; + let neq12 = bin_op(BinOpType::Neq, attr_index(0), cnst(Value::Int32(12))); + let eq100 = bin_op(BinOpType::Eq, attr_index(1), cnst(Value::Int32(100))); + let expr_tree = log_op(LogOpType::And, vec![neq12.clone(), eq100.clone()]); + let expr_tree_rev = log_op(LogOpType::And, vec![eq100.clone(), neq12.clone()]); + + assert_approx_eq::assert_approx_eq!( + test_get_join_selectivity( + &cost_model, + false, + JoinType::Inner, + expr_tree, + &attr_refs, + None, + ) + .await, + 0.2 + ); + assert_approx_eq::assert_approx_eq!( + test_get_join_selectivity( + &cost_model, + false, + JoinType::Inner, + expr_tree_rev, + &attr_refs, + None + ) + .await, + 0.2 + ); + } + + #[tokio::test] + async fn test_inner_colref_eq_colref_same_table_is_not_oncond() { + let cost_model = create_two_table_mock_cost_model( + per_attr_stats_with_ndistinct(5), + per_attr_stats_with_ndistinct(4), + ); + + let attr_refs = vec![ + AttrRef::base_table_attr_ref(TEST_TABLE1_ID, 0), + AttrRef::base_table_attr_ref(TEST_TABLE2_ID, 0), + ]; + let expr_tree = bin_op(BinOpType::Eq, attr_index(0), attr_index(0)); + + assert_approx_eq::assert_approx_eq!( + test_get_join_selectivity( + &cost_model, + false, + JoinType::Inner, + expr_tree, + &attr_refs, + None + ) + .await, + DEFAULT_EQ_SEL + ); + } + + // We don't test joinsel or with oncond because if there is an oncond (on condition), the + // top-level operator 
must be an AND + + /// I made this helper function to avoid copying all eight lines over and over + async fn assert_outer_selectivities( + cost_model: &TestOptCostModelMock, + expr_tree: ArcPredicateNode, + expr_tree_rev: ArcPredicateNode, + attr_refs: &AttrRefs, + expected_table1_outer_sel: f64, + expected_table2_outer_sel: f64, + ) { + // all table 1 outer combinations + assert_approx_eq::assert_approx_eq!( + test_get_join_selectivity( + cost_model, + false, + JoinType::LeftOuter, + expr_tree.clone(), + attr_refs, + None + ) + .await, + expected_table1_outer_sel + ); + assert_approx_eq::assert_approx_eq!( + test_get_join_selectivity( + cost_model, + false, + JoinType::LeftOuter, + expr_tree_rev.clone(), + attr_refs, + None + ) + .await, + expected_table1_outer_sel + ); + assert_approx_eq::assert_approx_eq!( + test_get_join_selectivity( + cost_model, + true, + JoinType::RightOuter, + expr_tree.clone(), + attr_refs, + None + ) + .await, + expected_table1_outer_sel + ); + assert_approx_eq::assert_approx_eq!( + test_get_join_selectivity( + cost_model, + true, + JoinType::RightOuter, + expr_tree_rev.clone(), + attr_refs, + None + ) + .await, + expected_table1_outer_sel + ); + // all table 2 outer combinations + assert_approx_eq::assert_approx_eq!( + test_get_join_selectivity( + cost_model, + true, + JoinType::LeftOuter, + expr_tree.clone(), + attr_refs, + None + ) + .await, + expected_table2_outer_sel + ); + assert_approx_eq::assert_approx_eq!( + test_get_join_selectivity( + cost_model, + true, + JoinType::LeftOuter, + expr_tree_rev.clone(), + attr_refs, + None + ) + .await, + expected_table2_outer_sel + ); + assert_approx_eq::assert_approx_eq!( + test_get_join_selectivity( + cost_model, + false, + JoinType::RightOuter, + expr_tree.clone(), + attr_refs, + None + ) + .await, + expected_table2_outer_sel + ); + assert_approx_eq::assert_approx_eq!( + test_get_join_selectivity( + cost_model, + false, + JoinType::RightOuter, + expr_tree_rev.clone(), + attr_refs, + None + ) 
+ .await, + expected_table2_outer_sel + ); + } + + /// Unique oncond means an oncondition on columns which are unique in both tables + /// There's only one case if both columns are unique and have different row counts: the inner + /// will be < 1 / row count of one table and = 1 / row count of another + #[tokio::test] + async fn test_outer_unique_oncond() { + let cost_model = create_two_table_mock_cost_model_custom_row_cnts( + per_attr_stats_with_ndistinct(5), + per_attr_stats_with_ndistinct(4), + 5, + 4, + ); + + let attr_refs = vec![ + AttrRef::base_table_attr_ref(TEST_TABLE1_ID, 0), + AttrRef::base_table_attr_ref(TEST_TABLE2_ID, 0), + ]; + // the left/right of the join refers to the tables, not the order of columns in the + // predicate + let expr_tree = bin_op(BinOpType::Eq, attr_index(0), attr_index(1)); + let expr_tree_rev = bin_op(BinOpType::Eq, attr_index(1), attr_index(0)); + + // sanity check the expected inner sel + let expected_inner_sel = 0.2; + assert_approx_eq::assert_approx_eq!( + test_get_join_selectivity( + &cost_model, + false, + JoinType::Inner, + expr_tree.clone(), + &attr_refs, + None + ) + .await, + expected_inner_sel + ); + assert_approx_eq::assert_approx_eq!( + test_get_join_selectivity( + &cost_model, + false, + JoinType::Inner, + expr_tree_rev.clone(), + &attr_refs, + None + ) + .await, + expected_inner_sel + ); + // check the outer sels + assert_outer_selectivities(&cost_model, expr_tree, expr_tree_rev, &attr_refs, 0.25, 0.2); + } + + /// Non-unique oncond means the column is not unique in either table + /// Inner always >= row count means that the inner join result is >= 1 / the row count of both + /// tables + #[tokio::test] + async fn test_outer_nonunique_oncond_inner_always_geq_rowcnt() { + let cost_model = create_two_table_mock_cost_model_custom_row_cnts( + per_attr_stats_with_ndistinct(5), + per_attr_stats_with_ndistinct(4), + 10, + 8, + ); + + let attr_refs = vec![ + AttrRef::base_table_attr_ref(TEST_TABLE1_ID, 0), + 
AttrRef::base_table_attr_ref(TEST_TABLE2_ID, 0), + ]; + // the left/right of the join refers to the tables, not the order of columns in the + // predicate + let expr_tree = bin_op(BinOpType::Eq, attr_index(0), attr_index(1)); + let expr_tree_rev = bin_op(BinOpType::Eq, attr_index(1), attr_index(0)); + + // sanity check the expected inner sel + let expected_inner_sel = 0.2; + assert_approx_eq::assert_approx_eq!( + test_get_join_selectivity( + &cost_model, + false, + JoinType::Inner, + expr_tree.clone(), + &attr_refs, + None + ) + .await, + expected_inner_sel + ); + assert_approx_eq::assert_approx_eq!( + test_get_join_selectivity( + &cost_model, + false, + JoinType::Inner, + expr_tree_rev.clone(), + &attr_refs, + None + ) + .await, + expected_inner_sel + ); + // check the outer sels + assert_outer_selectivities(&cost_model, expr_tree, expr_tree_rev, &attr_refs, 0.2, 0.2) + .await; + } + + /// Non-unique oncond means the column is not unique in either table + /// Inner sometimes < row count means that the inner join result < 1 / the row count of exactly + /// one table. 
Note that without a join filter, it's impossible to be less than the row + /// count of both tables + #[tokio::test] + async fn test_outer_nonunique_oncond_inner_sometimes_lt_rowcnt() { + let cost_model = create_two_table_mock_cost_model_custom_row_cnts( + per_attr_stats_with_ndistinct(10), + per_attr_stats_with_ndistinct(2), + 20, + 4, + ); + + let attr_refs = vec![ + AttrRef::base_table_attr_ref(TEST_TABLE1_ID, 0), + AttrRef::base_table_attr_ref(TEST_TABLE2_ID, 0), + ]; + // the left/right of the join refers to the tables, not the order of columns in the + // predicate + let expr_tree = bin_op(BinOpType::Eq, attr_index(0), attr_index(1)); + let expr_tree_rev = bin_op(BinOpType::Eq, attr_index(1), attr_index(0)); + + // sanity check the expected inner sel + let expected_inner_sel = 0.1; + assert_approx_eq::assert_approx_eq!( + test_get_join_selectivity( + &cost_model, + false, + JoinType::Inner, + expr_tree.clone(), + &attr_refs, + None + ) + .await, + expected_inner_sel + ); + assert_approx_eq::assert_approx_eq!( + test_get_join_selectivity( + &cost_model, + false, + JoinType::Inner, + expr_tree_rev.clone(), + &attr_refs, + None + ) + .await, + expected_inner_sel + ); + // check the outer sels + assert_outer_selectivities(&cost_model, expr_tree, expr_tree_rev, &attr_refs, 0.25, 0.1) + .await; + } + + /// Unique oncond means an oncondition on columns which are unique in both tables + /// Filter means we're adding a join filter + /// There's only one case if both columns are unique and there's a filter: + /// the inner will be < 1 / row count of both tables + #[tokio::test] + async fn test_outer_unique_oncond_filter() { + let cost_model = create_two_table_mock_cost_model_custom_row_cnts( + per_attr_stats_with_dist_and_ndistinct(vec![(Value::Int32(128), 0.4)], 50), + per_attr_stats_with_ndistinct(4), + 50, + 4, + ); + + let attr_refs = vec![ + AttrRef::base_table_attr_ref(TEST_TABLE1_ID, 0), + AttrRef::base_table_attr_ref(TEST_TABLE2_ID, 0), + ]; + // the left/right 
of the join refers to the tables, not the order of columns in the + // predicate + let eq0and1 = bin_op(BinOpType::Eq, attr_index(0), attr_index(1)); + let eq1and0 = bin_op(BinOpType::Eq, attr_index(1), attr_index(0)); + let filter = bin_op(BinOpType::Leq, attr_index(0), cnst(Value::Int32(128))); + let expr_tree = log_op(LogOpType::And, vec![eq0and1, filter.clone()]); + // inner rev means its the inner expr (the eq op) whose children are being reversed, as + // opposed to the and op + let expr_tree_inner_rev = log_op(LogOpType::And, vec![eq1and0, filter.clone()]); + + // sanity check the expected inner sel + let expected_inner_sel = 0.008; + assert_approx_eq::assert_approx_eq!( + test_get_join_selectivity( + &cost_model, + false, + JoinType::Inner, + expr_tree.clone(), + &attr_refs, + None + ) + .await, + expected_inner_sel + ); + assert_approx_eq::assert_approx_eq!( + test_get_join_selectivity( + &cost_model, + false, + JoinType::Inner, + expr_tree_inner_rev.clone(), + &attr_refs, + None + ) + .await, + expected_inner_sel + ); + // check the outer sels + assert_outer_selectivities( + &cost_model, + expr_tree, + expr_tree_inner_rev, + &attr_refs, + 0.25, + 0.02, + ) + .await; + } + + /// Test all possible permutations of three-table joins. + /// A three-table join consists of at least two joins. `join1_on_cond` is the condition of the + /// first join. There can only be one condition because only two tables are involved at + /// the time of the first join. 
+ #[tokio::test] + #[ignore = "fail"] + #[test_case::test_case(&[(0, 1)])] + #[test_case::test_case(&[(0, 2)])] + #[test_case::test_case(&[(1, 2)])] + #[test_case::test_case(&[(0, 1), (0, 2)])] + #[test_case::test_case(&[(0, 1), (1, 2)])] + #[test_case::test_case(&[(0, 2), (1, 2)])] + #[test_case::test_case(&[(0, 1), (0, 2), (1, 2)])] + async fn test_three_table_join_for_initial_join_on_conds( + initial_join_on_conds: &[(usize, usize)], + ) { + assert!( + !initial_join_on_conds.is_empty(), + "initial_join_on_conds should be non-empty" + ); + assert_eq!( + initial_join_on_conds.len(), + initial_join_on_conds.iter().collect::>().len(), + "initial_join_on_conds shouldn't contain duplicates" + ); + let cost_model = create_three_table_mock_cost_model( + per_attr_stats_with_ndistinct(2), + per_attr_stats_with_ndistinct(3), + per_attr_stats_with_ndistinct(4), + ); + + let base_attr_refs = vec![ + BaseTableAttrRef { + table_id: TEST_TABLE1_ID, + attr_idx: 0, + }, + BaseTableAttrRef { + table_id: TEST_TABLE2_ID, + attr_idx: 0, + }, + BaseTableAttrRef { + table_id: TEST_TABLE3_ID, + attr_idx: 0, + }, + ]; + let attr_refs = base_attr_refs + .clone() + .into_iter() + .map(AttrRef::BaseTableAttrRef) + .collect(); + + let mut eq_columns = SemanticCorrelation::new(); + for initial_join_on_cond in initial_join_on_conds { + eq_columns.add_predicate(EqPredicate::new( + base_attr_refs[initial_join_on_cond.0].clone(), + base_attr_refs[initial_join_on_cond.1].clone(), + )); + } + let initial_selectivity = { + if initial_join_on_conds.len() == 1 { + let initial_join_on_cond = initial_join_on_conds.first().unwrap(); + if initial_join_on_cond == &(0, 1) { + 1.0 / 3.0 + } else if initial_join_on_cond == &(0, 2) || initial_join_on_cond == &(1, 2) { + 1.0 / 4.0 + } else { + panic!(); + } + } else { + 1.0 / 12.0 + } + }; + + let input_correlation = Some(eq_columns); + + // Try all join conditions of the final join which would lead to all three tables being + // joined. 
+ let eq0and1 = bin_op(BinOpType::Eq, attr_index(0), attr_index(1)); + let eq0and2 = bin_op(BinOpType::Eq, attr_index(0), attr_index(2)); + let eq1and2 = bin_op(BinOpType::Eq, attr_index(1), attr_index(2)); + let and_01_02 = log_op(LogOpType::And, vec![eq0and1.clone(), eq0and2.clone()]); + let and_01_12 = log_op(LogOpType::And, vec![eq0and1.clone(), eq1and2.clone()]); + let and_02_12 = log_op(LogOpType::And, vec![eq0and2.clone(), eq1and2.clone()]); + let and_01_02_12 = log_op( + LogOpType::And, + vec![eq0and1.clone(), eq0and2.clone(), eq1and2.clone()], + ); + let mut join2_expr_trees = vec![and_01_02, and_01_12, and_02_12, and_01_02_12]; + if initial_join_on_conds.len() == 1 { + let initial_join_on_cond = initial_join_on_conds.first().unwrap(); + if initial_join_on_cond == &(0, 1) { + join2_expr_trees.push(eq0and2); + join2_expr_trees.push(eq1and2); + } else if initial_join_on_cond == &(0, 2) { + join2_expr_trees.push(eq0and1); + join2_expr_trees.push(eq1and2); + } else if initial_join_on_cond == &(1, 2) { + join2_expr_trees.push(eq0and1); + join2_expr_trees.push(eq0and2); + } else { + panic!(); + } + } + for expr_tree in join2_expr_trees { + let overall_selectivity = initial_selectivity + * test_get_join_selectivity( + &cost_model, + false, + JoinType::Inner, + expr_tree.clone(), + &attr_refs, + input_correlation.clone(), + ) + .await; + assert_approx_eq::assert_approx_eq!(overall_selectivity, 1.0 / 12.0); + } + } + + #[tokio::test] + async fn test_join_which_connects_two_components_together() { + let cost_model = create_four_table_mock_cost_model( + per_attr_stats_with_ndistinct(2), + per_attr_stats_with_ndistinct(3), + per_attr_stats_with_ndistinct(4), + per_attr_stats_with_ndistinct(5), + ); + let base_attr_refs = vec![ + BaseTableAttrRef { + table_id: TEST_TABLE1_ID, + attr_idx: 0, + }, + BaseTableAttrRef { + table_id: TEST_TABLE2_ID, + attr_idx: 0, + }, + BaseTableAttrRef { + table_id: TEST_TABLE3_ID, + attr_idx: 0, + }, + BaseTableAttrRef { + table_id: 
TEST_TABLE4_ID, + attr_idx: 0, + }, + ]; + let attr_refs = base_attr_refs + .clone() + .into_iter() + .map(AttrRef::BaseTableAttrRef) + .collect(); + + let mut eq_columns = SemanticCorrelation::new(); + eq_columns.add_predicate(EqPredicate::new( + base_attr_refs[0].clone(), + base_attr_refs[1].clone(), + )); + eq_columns.add_predicate(EqPredicate::new( + base_attr_refs[2].clone(), + base_attr_refs[3].clone(), + )); + let initial_selectivity = 1.0 / (3.0 * 5.0); + let input_correlation = Some(eq_columns); + + let eq1and2 = bin_op(BinOpType::Eq, attr_index(1), attr_index(2)); + let overall_selectivity = initial_selectivity + * test_get_join_selectivity( + &cost_model, + false, + JoinType::Inner, + eq1and2.clone(), + &attr_refs, + input_correlation, + ) + .await; + assert_approx_eq::assert_approx_eq!(overall_selectivity, 1.0 / (3.0 * 4.0 * 5.0)); + } +} diff --git a/optd-cost-model/src/cost/join/join.rs b/optd-cost-model/src/cost/join/join.rs deleted file mode 100644 index ec98ace..0000000 --- a/optd-cost-model/src/cost/join/join.rs +++ /dev/null @@ -1,401 +0,0 @@ -use std::collections::HashSet; - -use itertools::Itertools; - -use crate::{ - common::{ - nodes::{ArcPredicateNode, JoinType, PredicateType, ReprPredicateNode}, - predicates::{ - attr_index_pred::AttrIndexPred, - bin_op_pred::BinOpType, - list_pred::ListPred, - log_op_pred::{LogOpPred, LogOpType}, - }, - properties::attr_ref::{ - self, AttrRef, AttrRefs, BaseTableAttrRef, EqPredicate, GroupAttrRefs, - SemanticCorrelation, - }, - types::GroupId, - }, - cost::join::get_on_attr_ref_pair, - cost_model::CostModelImpl, - stats::DEFAULT_NUM_DISTINCT, - storage::CostModelStorageManager, - CostModelResult, -}; - -impl CostModelImpl { - /// The expr_tree input must be a "mixed expression tree", just like with - /// `get_filter_selectivity`. - /// - /// This is a "wrapper" to separate the equality conditions from the filter conditions before - /// calling the "main" `get_join_selectivity_core` function. 
- #[allow(clippy::too_many_arguments)] - pub(crate) async fn get_join_selectivity_from_expr_tree( - &self, - join_typ: JoinType, - expr_tree: ArcPredicateNode, - attr_refs: &AttrRefs, - input_correlation: Option, - left_row_cnt: f64, - right_row_cnt: f64, - ) -> CostModelResult { - if expr_tree.typ == PredicateType::LogOp(LogOpType::And) { - let mut on_attr_ref_pairs = vec![]; - let mut filter_expr_trees = vec![]; - for child_expr_tree in &expr_tree.children { - if let Some(on_attr_ref_pair) = - get_on_attr_ref_pair(child_expr_tree.clone(), attr_refs) - { - on_attr_ref_pairs.push(on_attr_ref_pair) - } else { - let child_expr = child_expr_tree.clone(); - filter_expr_trees.push(child_expr); - } - } - assert!(on_attr_ref_pairs.len() + filter_expr_trees.len() == expr_tree.children.len()); - let filter_expr_tree = if filter_expr_trees.is_empty() { - None - } else { - Some(LogOpPred::new(LogOpType::And, filter_expr_trees).into_pred_node()) - }; - self.get_join_selectivity_core( - join_typ, - on_attr_ref_pairs, - filter_expr_tree, - attr_refs, - input_correlation, - left_row_cnt, - right_row_cnt, - 0, - ) - .await - } else { - #[allow(clippy::collapsible_else_if)] - if let Some(on_attr_ref_pair) = get_on_attr_ref_pair(expr_tree.clone(), attr_refs) { - self.get_join_selectivity_core( - join_typ, - vec![on_attr_ref_pair], - None, - attr_refs, - input_correlation, - left_row_cnt, - right_row_cnt, - 0, - ) - .await - } else { - self.get_join_selectivity_core( - join_typ, - vec![], - Some(expr_tree), - attr_refs, - input_correlation, - left_row_cnt, - right_row_cnt, - 0, - ) - .await - } - } - } - - /// A wrapper to convert the join keys to the format expected by get_join_selectivity_core() - #[allow(clippy::too_many_arguments)] - pub(crate) async fn get_join_selectivity_from_keys( - &self, - join_typ: JoinType, - left_keys: ListPred, - right_keys: ListPred, - attr_refs: &AttrRefs, - input_correlation: Option, - left_row_cnt: f64, - right_row_cnt: f64, - left_attr_cnt: usize, 
- ) -> CostModelResult { - assert!(left_keys.len() == right_keys.len()); - // I assume that the keys are already in the right order - // s.t. the ith key of left_keys corresponds with the ith key of right_keys - let on_attr_ref_pairs = left_keys - .to_vec() - .into_iter() - .zip(right_keys.to_vec()) - .map(|(left_key, right_key)| { - ( - AttrIndexPred::from_pred_node(left_key).expect("keys should be AttrRefPreds"), - AttrIndexPred::from_pred_node(right_key).expect("keys should be AttrRefPreds"), - ) - }) - .collect_vec(); - self.get_join_selectivity_core( - join_typ, - on_attr_ref_pairs, - None, - attr_refs, - input_correlation, - left_row_cnt, - right_row_cnt, - left_attr_cnt, - ) - .await - } - - /// The core logic of join selectivity which assumes we've already separated the expression - /// into the on conditions and the filters. - /// - /// Hash join and NLJ reference right table attributes differently, hence the - /// `right_attr_ref_offset` parameter. - /// - /// For hash join, the right table attributes indices are with respect to the right table, - /// which means #0 is the first attribute of the right table. - /// - /// For NLJ, the right table attributes indices are with respect to the output of the join. - /// For example, if the left table has 3 attributes, the first attribute of the right table - /// is #3 instead of #0. - #[allow(clippy::too_many_arguments)] - async fn get_join_selectivity_core( - &self, - join_typ: JoinType, - on_attr_ref_pairs: Vec<(AttrIndexPred, AttrIndexPred)>, - filter_expr_tree: Option, - attr_refs: &AttrRefs, - input_correlation: Option, - left_row_cnt: f64, - right_row_cnt: f64, - right_attr_ref_offset: usize, - ) -> CostModelResult { - let join_on_selectivity = self - .get_join_on_selectivity( - &on_attr_ref_pairs, - attr_refs, - input_correlation, - right_attr_ref_offset, - ) - .await?; - // Currently, there is no difference in how we handle a join filter and a select filter, - // so we use the same function. 
- // - // One difference (that we *don't* care about right now) is that join filters can contain - // expressions from multiple different tables. Currently, this doesn't affect the - // get_filter_selectivity() function, but this may change in the future. - let join_filter_selectivity = match filter_expr_tree { - Some(filter_expr_tree) => { - // FIXME(group_id): Pass in group id or schema & attr_refs - let group_id = GroupId(0); - self.get_filter_selectivity(group_id, filter_expr_tree) - .await? - } - None => 1.0, - }; - let inner_join_selectivity = join_on_selectivity * join_filter_selectivity; - - Ok(match join_typ { - JoinType::Inner => inner_join_selectivity, - JoinType::LeftOuter => f64::max(inner_join_selectivity, 1.0 / right_row_cnt), - JoinType::RightOuter => f64::max(inner_join_selectivity, 1.0 / left_row_cnt), - JoinType::Cross => { - assert!( - on_attr_ref_pairs.is_empty(), - "Cross joins should not have on attributes" - ); - join_filter_selectivity - } - _ => unimplemented!("join_typ={} is not implemented", join_typ), - }) - } - - /// Get the selectivity of one attribute eq predicate, e.g. attrA = attrB. - async fn get_join_selectivity_from_on_attr_ref_pair( - &self, - left: &AttrRef, - right: &AttrRef, - ) -> CostModelResult { - // the formula for each pair is min(1 / ndistinct1, 1 / ndistinct2) - // (see https://postgrespro.com/blog/pgsql/5969618) - let mut ndistincts = vec![]; - for attr_ref in [left, right] { - let ndistinct = match attr_ref { - AttrRef::BaseTableAttrRef(base_attr_ref) => { - match self - .get_attribute_comb_stats(base_attr_ref.table_id, &[base_attr_ref.attr_idx]) - .await? 
- { - Some(per_attr_stats) => per_attr_stats.ndistinct, - None => DEFAULT_NUM_DISTINCT, - } - } - AttrRef::Derived => DEFAULT_NUM_DISTINCT, - }; - ndistincts.push(ndistinct); - } - - // using reduce(f64::min) is the idiomatic workaround to min() because - // f64 does not implement Ord due to NaN - let selectivity = ndistincts.into_iter().map(|ndistinct| 1.0 / ndistinct as f64).reduce(f64::min).expect("reduce() only returns None if the iterator is empty, which is impossible since attr_ref_exprs.len() == 2"); - assert!( - !selectivity.is_nan(), - "it should be impossible for selectivity to be NaN since n-distinct is never 0" - ); - Ok(selectivity) - } - - /// Given a set of N attributes involved in a multi-equality, find the total selectivity - /// of the multi-equality. - /// - /// This is a generalization of get_join_selectivity_from_on_attr_ref_pair(). - async fn get_join_selectivity_from_most_selective_attrs( - &self, - base_attr_refs: HashSet, - ) -> CostModelResult { - assert!(base_attr_refs.len() > 1); - let num_base_attr_refs = base_attr_refs.len(); - - let mut ndistincts = vec![]; - for base_attr_ref in base_attr_refs.iter() { - let ndistinct = match self - .get_attribute_comb_stats(base_attr_ref.table_id, &[base_attr_ref.attr_idx]) - .await? - { - Some(per_attr_stats) => per_attr_stats.ndistinct, - None => DEFAULT_NUM_DISTINCT, - }; - ndistincts.push(ndistinct); - } - - Ok(ndistincts - .into_iter() - .map(|ndistinct| 1.0 / ndistinct as f64) - .sorted_by(|a, b| { - a.partial_cmp(b) - .expect("No floats should be NaN since n-distinct is never 0") - }) - .take(num_base_attr_refs - 1) - .product()) - } - - /// A predicate set defines a "multi-equality graph", which is an unweighted undirected graph. - /// The nodes are attributes while edges are predicates. The old graph is defined by - /// `past_eq_attrs` while the `predicate` is the new addition to this graph. 
This - /// unweighted undirected graph consists of a number of connected components, where each - /// connected component represents attributes that are set to be equal to each other. Single - /// nodes not connected to anything are considered standalone connected components. - /// - /// The selectivity of each connected component of N nodes is equal to the product of - /// 1/ndistinct of the N-1 nodes with the highest ndistinct values. You can see this if you - /// imagine that all attributes being joined are unique attributes and that they follow the - /// inclusion principle (every element of the smaller tables is present in the larger - /// tables). When these assumptions are not true, the selectivity may not be completely - /// accurate. However, it is still fairly accurate. - /// - /// However, we cannot simply add `predicate` to the multi-equality graph and compute the - /// selectivity of the entire connected component, because this would be "double counting" a - /// lot of nodes. The join(s) before this join would already have a selectivity value. Thus, - /// we compute the selectivity of the join(s) before this join (the first block of the - /// function) and then the selectivity of the connected component after this join. The - /// quotient is the "adjustment" factor. - /// - /// NOTE: This function modifies `past_eq_attrs` by adding `predicate` to it. - async fn get_join_selectivity_adjustment_when_adding_to_multi_equality_graph( - &self, - predicate: &EqPredicate, - past_eq_attrs: &mut SemanticCorrelation, - ) -> CostModelResult { - if predicate.left == predicate.right { - // self-join, TODO: is this correct? - return Ok(1.0); - } - // To find the adjustment, we need to know the selectivity of the graph before `predicate` - // is added. - // - // There are two cases: (1) adding `predicate` does not change the # of connected - // components, and (2) adding `predicate` reduces the # of connected by 1. 
Note that - // attributes not involved in any predicates are considered a part of the graph and are - // a connected component on their own. - let children_pred_sel = { - if past_eq_attrs.is_eq(&predicate.left, &predicate.right) { - self.get_join_selectivity_from_most_selective_attrs( - past_eq_attrs.find_attrs_for_eq_attribute_set(&predicate.left), - ) - .await? - } else { - let left_sel = if past_eq_attrs.contains(&predicate.left) { - self.get_join_selectivity_from_most_selective_attrs( - past_eq_attrs.find_attrs_for_eq_attribute_set(&predicate.left), - ) - .await? - } else { - 1.0 - }; - let right_sel = if past_eq_attrs.contains(&predicate.right) { - self.get_join_selectivity_from_most_selective_attrs( - past_eq_attrs.find_attrs_for_eq_attribute_set(&predicate.right), - ) - .await? - } else { - 1.0 - }; - left_sel * right_sel - } - }; - - // Add predicate to past_eq_attrs and compute the selectivity of the connected component - // it creates. - past_eq_attrs.add_predicate(predicate.clone()); - let new_pred_sel = { - let attrs = past_eq_attrs.find_attrs_for_eq_attribute_set(&predicate.left); - self.get_join_selectivity_from_most_selective_attrs(attrs) - } - .await?; - - // Compute the adjustment factor. - Ok(new_pred_sel / children_pred_sel) - } - - /// Get the selectivity of the on conditions. - /// - /// Note that the selectivity of the on conditions does not depend on join type. - /// Join type is accounted for separately in get_join_selectivity_core(). - /// - /// We also check if each predicate is correlated with any of the previous predicates. - /// - /// More specifically, we are checking if the predicate can be expressed with other existing - /// predicates. E.g. if we have a predicate like A = B and B = C is equivalent to A = C. - // - /// However, we don't just throw away A = C, because we want to pick the most selective - /// predicates. For details on how we do this, see - /// `get_join_selectivity_from_redundant_predicates`. 
- async fn get_join_on_selectivity( - &self, - on_attr_ref_pairs: &[(AttrIndexPred, AttrIndexPred)], - attr_refs: &AttrRefs, - input_correlation: Option, - right_attr_ref_offset: usize, - ) -> CostModelResult { - let mut past_eq_attrs = input_correlation.unwrap_or_default(); - - // Multiply the selectivities of all individual conditions together - let mut selectivity = 1.0; - for on_attr_ref_pair in on_attr_ref_pairs { - let left_attr_ref = &attr_refs[on_attr_ref_pair.0.attr_index() as usize]; - let right_attr_ref = - &attr_refs[on_attr_ref_pair.1.attr_index() as usize + right_attr_ref_offset]; - - if let (AttrRef::BaseTableAttrRef(left), AttrRef::BaseTableAttrRef(right)) = - (left_attr_ref, right_attr_ref) - { - let predicate = EqPredicate::new(left.clone(), right.clone()); - return self - .get_join_selectivity_adjustment_when_adding_to_multi_equality_graph( - &predicate, - &mut past_eq_attrs, - ) - .await; - } - - selectivity *= self - .get_join_selectivity_from_on_attr_ref_pair(left_attr_ref, right_attr_ref) - .await?; - } - Ok(selectivity) - } -} diff --git a/optd-cost-model/src/cost/join/mod.rs b/optd-cost-model/src/cost/join/mod.rs index b752aa8..71b991b 100644 --- a/optd-cost-model/src/cost/join/mod.rs +++ b/optd-cost-model/src/cost/join/mod.rs @@ -6,9 +6,8 @@ use crate::common::{ }, }; +pub mod core; pub mod hash_join; -// FIXME: module has the same name as its containing module -pub mod join; pub mod nested_loop_join; pub(crate) fn get_input_correlation( diff --git a/optd-cost-model/src/cost_model.rs b/optd-cost-model/src/cost_model.rs index 11fe2ec..4e7ef7f 100644 --- a/optd-cost-model/src/cost_model.rs +++ b/optd-cost-model/src/cost_model.rs @@ -148,7 +148,8 @@ pub mod tests { }, memo_ext::tests::{MemoGroupInfo, MockMemoExtImpl}, stats::{ - utilities::counter::Counter, AttributeCombValueStats, Distribution, MostCommonValues, + utilities::{counter::Counter, simple_map::SimpleMap}, + AttributeCombValueStats, Distribution, MostCommonValues, }, 
storage::mock::{CostModelStorageMockManagerImpl, TableStats}, }; @@ -292,7 +293,7 @@ pub mod tests { tbl1_per_attr_stats: TestPerAttributeStats, tbl2_per_attr_stats: TestPerAttributeStats, ) -> TestOptCostModelMock { - create_two_table_cost_model_custom_row_cnts( + create_two_table_mock_cost_model_custom_row_cnts( tbl1_per_attr_stats, tbl2_per_attr_stats, 100, @@ -300,8 +301,8 @@ pub mod tests { ) } - /// Create a cost model with three columns, one for each table. Each column has 100 values. - pub fn create_three_table_cost_model( + /// Create a cost model with three tables, each with one attribute. Each attribute has 100 values. + pub fn create_three_table_mock_cost_model( tbl1_per_column_stats: TestPerAttributeStats, tbl2_per_column_stats: TestPerAttributeStats, tbl3_per_column_stats: TestPerAttributeStats, @@ -387,8 +388,8 @@ pub mod tests { ) } - /// Create a cost model with three columns, one for each table. Each column has 100 values. - pub fn create_four_table_cost_model( + /// Create a cost model with four tables, each with one attribute. Each attribute has 100 values. 
+ pub fn create_four_table_mock_cost_model( tbl1_per_column_stats: TestPerAttributeStats, tbl2_per_column_stats: TestPerAttributeStats, tbl3_per_column_stats: TestPerAttributeStats, @@ -498,7 +499,7 @@ pub mod tests { } /// We need custom row counts because some join algorithms rely on the row cnt - pub fn create_two_table_cost_model_custom_row_cnts( + pub fn create_two_table_mock_cost_model_custom_row_cnts( tbl1_per_column_stats: TestPerAttributeStats, tbl2_per_column_stats: TestPerAttributeStats, tbl1_row_cnt: u64, @@ -631,6 +632,32 @@ pub mod tests { } pub(crate) fn empty_per_attr_stats() -> TestPerAttributeStats { - TestPerAttributeStats::new(MostCommonValues::Counter(Counter::default()), None, 0, 0.0) + TestPerAttributeStats::new( + MostCommonValues::empty(), + Some(Distribution::empty()), + 0, + 0.0, + ) + } + + pub(crate) fn per_attr_stats_with_ndistinct(ndistinct: u64) -> TestPerAttributeStats { + TestPerAttributeStats::new( + MostCommonValues::empty(), + Some(Distribution::empty()), + ndistinct, + 0.0, + ) + } + + pub(crate) fn per_attr_stats_with_dist_and_ndistinct( + dist: Vec<(Value, f64)>, + ndistinct: u64, + ) -> TestPerAttributeStats { + TestPerAttributeStats::new( + MostCommonValues::empty(), + Some(Distribution::SimpleDistribution(SimpleMap::new(dist))), + ndistinct, + 0.0, + ) } } diff --git a/optd-cost-model/src/stats/mod.rs b/optd-cost-model/src/stats/mod.rs index 0fcc4c2..7ec2510 100644 --- a/optd-cost-model/src/stats/mod.rs +++ b/optd-cost-model/src/stats/mod.rs @@ -83,6 +83,10 @@ impl MostCommonValues { MostCommonValues::SimpleFrequency(simple_map) => simple_map.m.len(), } } + + pub fn empty() -> Self { + MostCommonValues::SimpleFrequency(SimpleMap::new(vec![])) + } } // TODO: remove the clone, see the comment in the [`AttributeCombValueStats`] @@ -110,6 +114,10 @@ impl Distribution { } } } + + pub fn empty() -> Self { + Distribution::SimpleDistribution(SimpleMap::new(vec![])) + } } // TODO: Remove the clone. 
Now I have to add this because From be71afb6fb03e99962d16469c0b7979fee9cacbe Mon Sep 17 00:00:00 2001 From: Yuanxin Cao Date: Mon, 18 Nov 2024 20:57:13 -0500 Subject: [PATCH 49/51] pass group id to join and fix filter-related tests --- optd-cost-model/src/common/properties/mod.rs | 18 ++++ optd-cost-model/src/cost/join/core.rs | 89 +++++++++++++++++-- optd-cost-model/src/cost/join/hash_join.rs | 1 + .../src/cost/join/nested_loop_join.rs | 1 + optd-cost-model/src/cost_model.rs | 78 +++++----------- 5 files changed, 124 insertions(+), 63 deletions(-) diff --git a/optd-cost-model/src/common/properties/mod.rs b/optd-cost-model/src/common/properties/mod.rs index c9acbd1..a90d634 100644 --- a/optd-cost-model/src/common/properties/mod.rs +++ b/optd-cost-model/src/common/properties/mod.rs @@ -21,3 +21,21 @@ impl std::fmt::Display for Attribute { } } } + +impl Attribute { + pub fn new(name: String, typ: ConstantType, nullable: bool) -> Self { + Self { + name, + typ, + nullable, + } + } + + pub fn new_non_null_int64(name: String) -> Self { + Self { + name, + typ: ConstantType::Int64, + nullable: false, + } + } +} diff --git a/optd-cost-model/src/cost/join/core.rs b/optd-cost-model/src/cost/join/core.rs index 6693b0c..cbc684e 100644 --- a/optd-cost-model/src/cost/join/core.rs +++ b/optd-cost-model/src/cost/join/core.rs @@ -34,6 +34,7 @@ impl CostModelImpl { pub(crate) async fn get_join_selectivity_from_expr_tree( &self, join_typ: JoinType, + group_id: GroupId, expr_tree: ArcPredicateNode, attr_refs: &AttrRefs, input_correlation: Option, @@ -61,6 +62,7 @@ impl CostModelImpl { }; self.get_join_selectivity_core( join_typ, + group_id, on_attr_ref_pairs, filter_expr_tree, attr_refs, @@ -75,6 +77,7 @@ impl CostModelImpl { if let Some(on_attr_ref_pair) = get_on_attr_ref_pair(expr_tree.clone(), attr_refs) { self.get_join_selectivity_core( join_typ, + group_id, vec![on_attr_ref_pair], None, attr_refs, @@ -87,6 +90,7 @@ impl CostModelImpl { } else { self.get_join_selectivity_core( 
join_typ, + group_id, vec![], Some(expr_tree), attr_refs, @@ -105,6 +109,7 @@ impl CostModelImpl { pub(crate) async fn get_join_selectivity_from_keys( &self, join_typ: JoinType, + group_id: GroupId, left_keys: ListPred, right_keys: ListPred, attr_refs: &AttrRefs, @@ -129,6 +134,7 @@ impl CostModelImpl { .collect_vec(); self.get_join_selectivity_core( join_typ, + group_id, on_attr_ref_pairs, None, attr_refs, @@ -156,6 +162,7 @@ impl CostModelImpl { async fn get_join_selectivity_core( &self, join_typ: JoinType, + group_id: GroupId, on_attr_ref_pairs: Vec<(AttrIndexPred, AttrIndexPred)>, filter_expr_tree: Option, attr_refs: &AttrRefs, @@ -180,8 +187,6 @@ impl CostModelImpl { // get_filter_selectivity() function, but this may change in the future. let join_filter_selectivity = match filter_expr_tree { Some(filter_expr_tree) => { - // FIXME(group_id): Pass in group id or schema & attr_refs - let group_id = GroupId(0); self.get_filter_selectivity(group_id, filter_expr_tree) .await? } @@ -405,20 +410,28 @@ mod tests { use std::collections::HashMap; use crate::{ - common::{predicates::attr_index_pred, types::TableId, values::Value}, + common::{ + predicates::{attr_index_pred, constant_pred::ConstantType}, + properties::Attribute, + types::TableId, + values::Value, + }, cost_model::tests::{ attr_index, bin_op, cnst, create_four_table_mock_cost_model, create_mock_cost_model, create_three_table_mock_cost_model, create_two_table_mock_cost_model, create_two_table_mock_cost_model_custom_row_cnts, empty_per_attr_stats, log_op, per_attr_stats_with_dist_and_ndistinct, per_attr_stats_with_ndistinct, - TestOptCostModelMock, TestPerAttributeStats, TEST_TABLE1_ID, TEST_TABLE2_ID, - TEST_TABLE3_ID, TEST_TABLE4_ID, + TestOptCostModelMock, TestPerAttributeStats, TEST_ATTR1_NAME, TEST_ATTR2_NAME, + TEST_TABLE1_ID, TEST_TABLE2_ID, TEST_TABLE3_ID, TEST_TABLE4_ID, }, + memo_ext::tests::MemoGroupInfo, stats::DEFAULT_EQ_SEL, }; use super::*; + const JOIN_GROUP_ID: GroupId = GroupId(10); + /// A 
wrapper around get_join_selectivity_from_expr_tree that extracts the /// table row counts from the cost model. async fn test_get_join_selectivity( @@ -436,6 +449,7 @@ mod tests { cost_model .get_join_selectivity_from_expr_tree( join_typ, + JOIN_GROUP_ID, expr_tree, attr_refs, input_correlation, @@ -448,6 +462,7 @@ mod tests { cost_model .get_join_selectivity_from_expr_tree( join_typ, + JOIN_GROUP_ID, expr_tree, attr_refs, input_correlation, @@ -470,6 +485,7 @@ mod tests { cost_model .get_join_selectivity_from_expr_tree( JoinType::Inner, + JOIN_GROUP_ID, cnst(Value::Bool(true)), &vec![], None, @@ -484,6 +500,7 @@ mod tests { cost_model .get_join_selectivity_from_expr_tree( JoinType::Inner, + JOIN_GROUP_ID, cnst(Value::Bool(false)), &vec![], None, @@ -501,6 +518,7 @@ mod tests { let cost_model = create_two_table_mock_cost_model( per_attr_stats_with_ndistinct(5), per_attr_stats_with_ndistinct(4), + None, ); let attr_refs = vec![ @@ -540,6 +558,7 @@ mod tests { let cost_model = create_two_table_mock_cost_model( per_attr_stats_with_ndistinct(5), per_attr_stats_with_ndistinct(4), + None, ); let attr_refs = vec![ @@ -578,11 +597,28 @@ mod tests { } #[tokio::test] - #[ignore = "index out of bounds: the len is 1 but the index is 1"] async fn test_inner_and_of_oncond_and_filter() { + let join_memo = HashMap::from([( + JOIN_GROUP_ID, + MemoGroupInfo::new( + vec![ + Attribute::new_non_null_int64(TEST_ATTR1_NAME.to_string()), + Attribute::new_non_null_int64(TEST_ATTR2_NAME.to_string()), + ] + .into(), + GroupAttrRefs::new( + vec![ + AttrRef::new_base_table_attr_ref(TEST_TABLE1_ID, 0), + AttrRef::new_base_table_attr_ref(TEST_TABLE2_ID, 0), + ], + None, + ), + ), + )]); let cost_model = create_two_table_mock_cost_model( per_attr_stats_with_ndistinct(5), per_attr_stats_with_ndistinct(4), + Some(join_memo), ); let attr_refs = vec![ @@ -621,11 +657,28 @@ mod tests { } #[tokio::test] - #[ignore = "filter todo"] async fn test_inner_and_of_filters() { + let join_memo = HashMap::from([( 
+ JOIN_GROUP_ID, + MemoGroupInfo::new( + vec![ + Attribute::new_non_null_int64(TEST_ATTR1_NAME.to_string()), + Attribute::new_non_null_int64(TEST_ATTR2_NAME.to_string()), + ] + .into(), + GroupAttrRefs::new( + vec![ + AttrRef::new_base_table_attr_ref(TEST_TABLE1_ID, 0), + AttrRef::new_base_table_attr_ref(TEST_TABLE2_ID, 0), + ], + None, + ), + ), + )]); let cost_model = create_two_table_mock_cost_model( per_attr_stats_with_ndistinct(5), per_attr_stats_with_ndistinct(4), + Some(join_memo), ); let attr_refs = vec![ @@ -668,6 +721,7 @@ mod tests { let cost_model = create_two_table_mock_cost_model( per_attr_stats_with_ndistinct(5), per_attr_stats_with_ndistinct(4), + None, ); let attr_refs = vec![ @@ -812,6 +866,7 @@ mod tests { per_attr_stats_with_ndistinct(4), 5, 4, + None, ); let attr_refs = vec![ @@ -863,6 +918,7 @@ mod tests { per_attr_stats_with_ndistinct(4), 10, 8, + None, ); let attr_refs = vec![ @@ -916,6 +972,7 @@ mod tests { per_attr_stats_with_ndistinct(2), 20, 4, + None, ); let attr_refs = vec![ @@ -964,11 +1021,29 @@ mod tests { /// the inner will be < 1 / row count of both tables #[tokio::test] async fn test_outer_unique_oncond_filter() { + let join_memo = HashMap::from([( + JOIN_GROUP_ID, + MemoGroupInfo::new( + vec![ + Attribute::new_non_null_int64(TEST_ATTR1_NAME.to_string()), + Attribute::new_non_null_int64(TEST_ATTR2_NAME.to_string()), + ] + .into(), + GroupAttrRefs::new( + vec![ + AttrRef::new_base_table_attr_ref(TEST_TABLE1_ID, 0), + AttrRef::new_base_table_attr_ref(TEST_TABLE2_ID, 0), + ], + None, + ), + ), + )]); let cost_model = create_two_table_mock_cost_model_custom_row_cnts( per_attr_stats_with_dist_and_ndistinct(vec![(Value::Int32(128), 0.4)], 50), per_attr_stats_with_ndistinct(4), 50, 4, + Some(join_memo), ); let attr_refs = vec![ diff --git a/optd-cost-model/src/cost/join/hash_join.rs b/optd-cost-model/src/cost/join/hash_join.rs index fec65b1..47c9ebd 100644 --- a/optd-cost-model/src/cost/join/hash_join.rs +++ 
b/optd-cost-model/src/cost/join/hash_join.rs @@ -37,6 +37,7 @@ impl CostModelImpl { let input_correlation = get_input_correlation(left_attr_refs, right_attr_refs); self.get_join_selectivity_from_keys( join_typ, + group_id, left_keys, right_keys, output_attr_refs.attr_refs(), diff --git a/optd-cost-model/src/cost/join/nested_loop_join.rs b/optd-cost-model/src/cost/join/nested_loop_join.rs index 7f99e34..ebb70c9 100644 --- a/optd-cost-model/src/cost/join/nested_loop_join.rs +++ b/optd-cost-model/src/cost/join/nested_loop_join.rs @@ -32,6 +32,7 @@ impl CostModelImpl { self.get_join_selectivity_from_expr_tree( join_typ, + group_id, join_cond, output_attr_refs.attr_refs(), input_correlation, diff --git a/optd-cost-model/src/cost_model.rs b/optd-cost-model/src/cost_model.rs index 4e7ef7f..4583484 100644 --- a/optd-cost-model/src/cost_model.rs +++ b/optd-cost-model/src/cost_model.rs @@ -117,7 +117,7 @@ impl CostModelImpl { /// optd-datafusion-bridge and optd-datafusion-repr #[cfg(test)] pub mod tests { - use std::collections::HashMap; + use std::{collections::HashMap, hash::Hash}; use arrow_schema::DataType; use itertools::Itertools; @@ -171,6 +171,11 @@ pub mod tests { pub const TEST_ATTR2_BASE_INDEX: u64 = 1; pub const TEST_ATTR3_BASE_INDEX: u64 = 2; + pub const TEST_ATTR1_NAME: &str = "attr1"; + pub const TEST_ATTR2_NAME: &str = "attr2"; + pub const TEST_ATTR3_NAME: &str = "attr3"; + pub const TEST_ATTR4_NAME: &str = "attr4"; + pub type TestPerAttributeStats = AttributeCombValueStats; // TODO: add tests for non-mock storage manager pub type TestOptCostModelMock = CostModelImpl; @@ -292,12 +297,14 @@ pub mod tests { pub fn create_two_table_mock_cost_model( tbl1_per_attr_stats: TestPerAttributeStats, tbl2_per_attr_stats: TestPerAttributeStats, + additional_memo: Option>, ) -> TestOptCostModelMock { create_two_table_mock_cost_model_custom_row_cnts( tbl1_per_attr_stats, tbl2_per_attr_stats, 100, 100, + additional_memo, ) } @@ -338,12 +345,7 @@ pub mod tests { ( 
TEST_GROUP1_ID, MemoGroupInfo::new( - vec![Attribute { - name: "attr1".to_string(), - typ: ConstantType::Int64, - nullable: false, - }] - .into(), + vec![Attribute::new_non_null_int64(TEST_ATTR1_NAME.to_string())].into(), GroupAttrRefs::new( vec![AttrRef::new_base_table_attr_ref(TEST_TABLE1_ID, 0)], None, @@ -353,12 +355,7 @@ pub mod tests { ( TEST_GROUP2_ID, MemoGroupInfo::new( - vec![Attribute { - name: "attr2".to_string(), - typ: ConstantType::Int64, - nullable: false, - }] - .into(), + vec![Attribute::new_non_null_int64(TEST_ATTR2_NAME.to_string())].into(), GroupAttrRefs::new( vec![AttrRef::new_base_table_attr_ref(TEST_TABLE2_ID, 0)], None, @@ -368,12 +365,7 @@ pub mod tests { ( TEST_GROUP3_ID, MemoGroupInfo::new( - vec![Attribute { - name: "attr3".to_string(), - typ: ConstantType::Int64, - nullable: false, - }] - .into(), + vec![Attribute::new_non_null_int64(TEST_ATTR3_NAME.to_string())].into(), GroupAttrRefs::new( vec![AttrRef::new_base_table_attr_ref(TEST_TABLE3_ID, 0)], None, @@ -433,12 +425,7 @@ pub mod tests { ( TEST_GROUP1_ID, MemoGroupInfo::new( - vec![Attribute { - name: "attr1".to_string(), - typ: ConstantType::Int64, - nullable: false, - }] - .into(), + vec![Attribute::new_non_null_int64(TEST_ATTR1_NAME.to_string())].into(), GroupAttrRefs::new( vec![AttrRef::new_base_table_attr_ref(TEST_TABLE1_ID, 0)], None, @@ -448,12 +435,7 @@ pub mod tests { ( TEST_GROUP2_ID, MemoGroupInfo::new( - vec![Attribute { - name: "attr2".to_string(), - typ: ConstantType::Int64, - nullable: false, - }] - .into(), + vec![Attribute::new_non_null_int64(TEST_ATTR2_NAME.to_string())].into(), GroupAttrRefs::new( vec![AttrRef::new_base_table_attr_ref(TEST_TABLE2_ID, 0)], None, @@ -463,12 +445,7 @@ pub mod tests { ( TEST_GROUP3_ID, MemoGroupInfo::new( - vec![Attribute { - name: "attr3".to_string(), - typ: ConstantType::Int64, - nullable: false, - }] - .into(), + vec![Attribute::new_non_null_int64(TEST_ATTR3_NAME.to_string())].into(), GroupAttrRefs::new( 
vec![AttrRef::new_base_table_attr_ref(TEST_TABLE3_ID, 0)], None, @@ -478,12 +455,7 @@ pub mod tests { ( TEST_GROUP4_ID, MemoGroupInfo::new( - vec![Attribute { - name: "attr4".to_string(), - typ: ConstantType::Int64, - nullable: false, - }] - .into(), + vec![Attribute::new_non_null_int64(TEST_ATTR4_NAME.to_string())].into(), GroupAttrRefs::new( vec![AttrRef::new_base_table_attr_ref(TEST_TABLE4_ID, 0)], None, @@ -504,6 +476,7 @@ pub mod tests { tbl2_per_column_stats: TestPerAttributeStats, tbl1_row_cnt: u64, tbl2_row_cnt: u64, + additional_memo: Option>, ) -> TestOptCostModelMock { let storage_manager = CostModelStorageMockManagerImpl::new( vec![ @@ -525,16 +498,11 @@ pub mod tests { .into_iter() .collect(), ); - let memo = HashMap::from([ + let mut memo = HashMap::from([ ( TEST_GROUP1_ID, MemoGroupInfo::new( - vec![Attribute { - name: "attr1".to_string(), - typ: ConstantType::Int64, - nullable: false, - }] - .into(), + vec![Attribute::new_non_null_int64(TEST_ATTR1_NAME.to_string())].into(), GroupAttrRefs::new( vec![AttrRef::new_base_table_attr_ref(TEST_TABLE1_ID, 0)], None, @@ -544,12 +512,7 @@ pub mod tests { ( TEST_GROUP2_ID, MemoGroupInfo::new( - vec![Attribute { - name: "attr2".to_string(), - typ: ConstantType::Int64, - nullable: false, - }] - .into(), + vec![Attribute::new_non_null_int64(TEST_ATTR2_NAME.to_string())].into(), GroupAttrRefs::new( vec![AttrRef::new_base_table_attr_ref(TEST_TABLE2_ID, 0)], None, @@ -557,6 +520,9 @@ pub mod tests { ), ), ]); + if let Some(additional_memo) = additional_memo { + memo.extend(additional_memo); + } CostModelImpl::new( storage_manager, CatalogSource::Mock, From 624d040e32c866807588cccc4416aedeb7c3e6d4 Mon Sep 17 00:00:00 2001 From: Yuanxin Cao Date: Mon, 18 Nov 2024 21:23:43 -0500 Subject: [PATCH 50/51] fix all join tests --- optd-cost-model/src/cost/join/core.rs | 25 ++++++++++++------------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/optd-cost-model/src/cost/join/core.rs 
b/optd-cost-model/src/cost/join/core.rs index cbc684e..c68c1db 100644 --- a/optd-cost-model/src/cost/join/core.rs +++ b/optd-cost-model/src/cost/join/core.rs @@ -385,22 +385,22 @@ impl CostModelImpl { let right_attr_ref = &attr_refs[on_attr_ref_pair.1.attr_index() as usize + right_attr_ref_offset]; - if let (AttrRef::BaseTableAttrRef(left), AttrRef::BaseTableAttrRef(right)) = - (left_attr_ref, right_attr_ref) - { - let predicate = EqPredicate::new(left.clone(), right.clone()); - return self - .get_join_selectivity_adjustment_when_adding_to_multi_equality_graph( + selectivity *= + if let (AttrRef::BaseTableAttrRef(left), AttrRef::BaseTableAttrRef(right)) = + (left_attr_ref, right_attr_ref) + { + let predicate = EqPredicate::new(left.clone(), right.clone()); + self.get_join_selectivity_adjustment_when_adding_to_multi_equality_graph( &predicate, &mut past_eq_attrs, ) - .await; - } - - selectivity *= self - .get_join_selectivity_from_on_attr_ref_pair(left_attr_ref, right_attr_ref) - .await?; + .await? + } else { + self.get_join_selectivity_from_on_attr_ref_pair(left_attr_ref, right_attr_ref) + .await? + }; } + Ok(selectivity) } } @@ -1103,7 +1103,6 @@ mod tests { /// first join. There can only be one condition because only two tables are involved at /// the time of the first join. 
#[tokio::test] - #[ignore = "fail"] #[test_case::test_case(&[(0, 1)])] #[test_case::test_case(&[(0, 2)])] #[test_case::test_case(&[(1, 2)])] From f8a0e706fa5ac1715095a99cc4c2647f9180bc00 Mon Sep 17 00:00:00 2001 From: Lan Lou Date: Tue, 19 Nov 2024 15:03:30 -0500 Subject: [PATCH 51/51] Change filter controller name --- optd-cost-model/src/cost/filter/{controller.rs => core.rs} | 0 optd-cost-model/src/cost/filter/mod.rs | 2 +- 2 files changed, 1 insertion(+), 1 deletion(-) rename optd-cost-model/src/cost/filter/{controller.rs => core.rs} (100%) diff --git a/optd-cost-model/src/cost/filter/controller.rs b/optd-cost-model/src/cost/filter/core.rs similarity index 100% rename from optd-cost-model/src/cost/filter/controller.rs rename to optd-cost-model/src/cost/filter/core.rs diff --git a/optd-cost-model/src/cost/filter/mod.rs b/optd-cost-model/src/cost/filter/mod.rs index bf1d5ab..00ea653 100644 --- a/optd-cost-model/src/cost/filter/mod.rs +++ b/optd-cost-model/src/cost/filter/mod.rs @@ -1,7 +1,7 @@ pub mod attribute; pub mod comp_op; pub mod constant; -pub mod controller; +pub mod core; pub mod in_list; pub mod like; pub mod log_op;