From 795ca1ff2ad5efffec9e4ea869b41ebf4ac1ea9c Mon Sep 17 00:00:00 2001 From: Lan Lou Date: Tue, 19 Nov 2024 16:22:26 -0500 Subject: [PATCH] Add filter implementation --- optd-cost-model/src/cost/filter.rs | 0 optd-cost-model/src/cost/filter/attribute.rs | 183 ++++ optd-cost-model/src/cost/filter/comp_op.rs | 280 ++++++ optd-cost-model/src/cost/filter/constant.rs | 38 + optd-cost-model/src/cost/filter/core.rs | 877 +++++++++++++++++++ optd-cost-model/src/cost/filter/in_list.rs | 167 ++++ optd-cost-model/src/cost/filter/like.rs | 210 +++++ optd-cost-model/src/cost/filter/log_op.rs | 34 + optd-cost-model/src/cost/filter/mod.rs | 7 + 9 files changed, 1796 insertions(+) delete mode 100644 optd-cost-model/src/cost/filter.rs create mode 100644 optd-cost-model/src/cost/filter/attribute.rs create mode 100644 optd-cost-model/src/cost/filter/comp_op.rs create mode 100644 optd-cost-model/src/cost/filter/constant.rs create mode 100644 optd-cost-model/src/cost/filter/core.rs create mode 100644 optd-cost-model/src/cost/filter/in_list.rs create mode 100644 optd-cost-model/src/cost/filter/like.rs create mode 100644 optd-cost-model/src/cost/filter/log_op.rs create mode 100644 optd-cost-model/src/cost/filter/mod.rs diff --git a/optd-cost-model/src/cost/filter.rs b/optd-cost-model/src/cost/filter.rs deleted file mode 100644 index e69de29..0000000 diff --git a/optd-cost-model/src/cost/filter/attribute.rs b/optd-cost-model/src/cost/filter/attribute.rs new file mode 100644 index 0000000..7a082b7 --- /dev/null +++ b/optd-cost-model/src/cost/filter/attribute.rs @@ -0,0 +1,183 @@ +use std::ops::Bound; + +use crate::{ + common::{types::TableId, values::Value}, + cost_model::CostModelImpl, + stats::{AttributeCombValue, AttributeCombValueStats, DEFAULT_EQ_SEL, DEFAULT_INEQ_SEL}, + storage::CostModelStorageManager, + CostModelResult, +}; + +impl CostModelImpl { + /// Get the selectivity of an expression of the form "attribute equals value" (or "value equals + /// attribute") Will handle the case of statistics missing + /// Equality predicates are handled entirely differently from range predicates so this is its + /// own function + /// Also, get_attribute_equality_selectivity is a subroutine when computing range + /// selectivity, which is another reason for separating these into two functions + /// is_eq means whether it's == or != + /// + /// Currently, we only support calculating the equality selectivity for an existed attribute, + /// not a derived attribute. + /// TODO: Support derived attributes. + pub(crate) async fn get_attribute_equality_selectivity( + &self, + table_id: TableId, + attr_base_index: u64, + value: &Value, + is_eq: bool, + ) -> CostModelResult { + let ret_sel = { + if let Some(attribute_stats) = self + .get_attribute_comb_stats(table_id, &[attr_base_index]) + .await? + { + let eq_freq = + if let Some(freq) = attribute_stats.mcvs.freq(&vec![Some(value.clone())]) { + freq + } else { + let non_mcv_freq = 1.0 - attribute_stats.mcvs.total_freq(); + // always safe because usize is at least as large as i32 + let ndistinct_as_usize = attribute_stats.ndistinct as usize; + let non_mcv_cnt = ndistinct_as_usize - attribute_stats.mcvs.cnt(); + if non_mcv_cnt == 0 { + return Ok(0.0); + } + // note that nulls are not included in ndistinct so we don't need to do non_mcv_cnt + // - 1 if null_frac > 0 + (non_mcv_freq - attribute_stats.null_frac) / (non_mcv_cnt as f64) + }; + if is_eq { + eq_freq + } else { + 1.0 - eq_freq - attribute_stats.null_frac + } + } else { + #[allow(clippy::collapsible_else_if)] + if is_eq { + DEFAULT_EQ_SEL + } else { + 1.0 - DEFAULT_EQ_SEL + } + } + }; + + assert!( + (0.0..=1.0).contains(&ret_sel), + "ret_sel ({}) should be in [0, 1]", + ret_sel + ); + Ok(ret_sel) + } + + /// Compute the frequency of values in a attribute less than or equal to the given value. + fn get_attribute_leq_value_freq( + per_attribute_stats: &AttributeCombValueStats, + value: &Value, + ) -> f64 { + // because distr does not include the values in MCVs, we need to compute the CDFs there as + // well because nulls return false in any comparison, they are never included when + // computing range selectivity + let distr_leq_freq = per_attribute_stats.distr.as_ref().unwrap().cdf(value); + let value = value.clone(); + let pred = Box::new(move |val: &AttributeCombValue| *val[0].as_ref().unwrap() <= value); + let mcvs_leq_freq = per_attribute_stats.mcvs.freq_over_pred(pred); + let ret_freq = distr_leq_freq + mcvs_leq_freq; + assert!( + (0.0..=1.0).contains(&ret_freq), + "ret_freq ({}) should be in [0, 1]", + ret_freq + ); + ret_freq + } + + /// Compute the frequency of values in a attribute less than the given value. + /// + /// Currently, we only support calculating the equality selectivity for an existed attribute, + /// not a derived attribute. + /// TODO: Support derived attributes. + async fn get_attribute_lt_value_freq( + &self, + attribute_stats: &AttributeCombValueStats, + table_id: TableId, + attr_base_index: u64, + value: &Value, + ) -> CostModelResult { + // depending on whether value is in mcvs or not, we use different logic to turn total_lt_cdf + // into total_leq_cdf this logic just so happens to be the exact same logic as + // get_attribute_equality_selectivity implements + let ret_freq = Self::get_attribute_leq_value_freq(attribute_stats, value) + - self + .get_attribute_equality_selectivity(table_id, attr_base_index, value, true) + .await?; + assert!( + (0.0..=1.0).contains(&ret_freq), + "ret_freq ({}) should be in [0, 1]", + ret_freq + ); + Ok(ret_freq) + } + + /// Get the selectivity of an expression of the form "attribute =/> value" (or "value + /// =/> attribute"). Computes selectivity based off of statistics. + /// Range predicates are handled entirely differently from equality predicates so this is its + /// own function. If it is unable to find the statistics, it returns DEFAULT_INEQ_SEL. + /// The selectivity is computed as quantile of the right bound minus quantile of the left bound. + /// + /// Currently, we only support calculating the equality selectivity for an existed attribute, + /// not a derived attribute. + /// TODO: Support derived attributes. + pub(crate) async fn get_attribute_range_selectivity( + &self, + table_id: TableId, + attr_base_index: u64, + start: Bound<&Value>, + end: Bound<&Value>, + ) -> CostModelResult { + // TODO: Consider attribute is a derived attribute + if let Some(attribute_stats) = self + .get_attribute_comb_stats(table_id, &[attr_base_index]) + .await? + { + let left_quantile = match start { + Bound::Unbounded => 0.0, + Bound::Included(value) => { + self.get_attribute_lt_value_freq( + &attribute_stats, + table_id, + attr_base_index, + value, + ) + .await? + } + Bound::Excluded(value) => { + Self::get_attribute_leq_value_freq(&attribute_stats, value) + } + }; + let right_quantile = match end { + Bound::Unbounded => 1.0, + Bound::Included(value) => { + Self::get_attribute_leq_value_freq(&attribute_stats, value) + } + Bound::Excluded(value) => { + self.get_attribute_lt_value_freq( + &attribute_stats, + table_id, + attr_base_index, + value, + ) + .await? + } + }; + assert!( + left_quantile <= right_quantile, + "left_quantile ({}) should be <= right_quantile ({})", + left_quantile, + right_quantile + ); + Ok(right_quantile - left_quantile) + } else { + Ok(DEFAULT_INEQ_SEL) + } + } +} diff --git a/optd-cost-model/src/cost/filter/comp_op.rs b/optd-cost-model/src/cost/filter/comp_op.rs new file mode 100644 index 0000000..5270819 --- /dev/null +++ b/optd-cost-model/src/cost/filter/comp_op.rs @@ -0,0 +1,280 @@ +use std::ops::Bound; + +use crate::{ + common::{ + nodes::{ArcPredicateNode, PredicateType, ReprPredicateNode}, + predicates::{ + attr_index_pred::AttrIndexPred, bin_op_pred::BinOpType, cast_pred::CastPred, + constant_pred::ConstantPred, + }, + properties::attr_ref::{AttrRef, BaseTableAttrRef}, + types::GroupId, + values::Value, + }, + cost_model::CostModelImpl, + stats::{DEFAULT_EQ_SEL, DEFAULT_INEQ_SEL, UNIMPLEMENTED_SEL}, + storage::CostModelStorageManager, + CostModelResult, SemanticError, +}; + +impl CostModelImpl { + /// Comparison operators are the base case for recursion in get_filter_selectivity() + pub(crate) async fn get_comp_op_selectivity( + &self, + group_id: GroupId, + comp_bin_op_typ: BinOpType, + left: ArcPredicateNode, + right: ArcPredicateNode, + ) -> CostModelResult { + assert!(comp_bin_op_typ.is_comparison()); + + // I intentionally performed moves on left and right. This way, we don't accidentally use + // them after this block + let semantic_res = self.get_semantic_nodes(group_id, left, right).await; + if semantic_res.is_err() { + return Ok(Self::get_default_comparison_op_selectivity(comp_bin_op_typ)); + } + let (attr_ref_exprs, values, non_attr_ref_exprs, is_left_attr_ref) = semantic_res.unwrap(); + + // Handle the different cases of semantic nodes. + if attr_ref_exprs.is_empty() { + Ok(UNIMPLEMENTED_SEL) + } else if attr_ref_exprs.len() == 1 { + let attr_ref_expr = attr_ref_exprs + .first() + .expect("we just checked that attr_ref_exprs.len() == 1"); + let attr_ref_idx = attr_ref_expr.attr_index(); + + if let AttrRef::BaseTableAttrRef(BaseTableAttrRef { table_id, attr_idx }) = + self.memo.get_attribute_ref(group_id, attr_ref_idx) + { + if values.len() == 1 { + let value = values + .first() + .expect("we just checked that values.len() == 1"); + match comp_bin_op_typ { + BinOpType::Eq => { + self.get_attribute_equality_selectivity(table_id, attr_idx, value, true) + .await + } + BinOpType::Neq => { + self.get_attribute_equality_selectivity( + table_id, + attr_ref_idx, + value, + false, + ) + .await + } + BinOpType::Lt | BinOpType::Leq | BinOpType::Gt | BinOpType::Geq => { + let start = match (comp_bin_op_typ, is_left_attr_ref) { + (BinOpType::Lt, true) | (BinOpType::Geq, false) => Bound::Unbounded, + (BinOpType::Leq, true) | (BinOpType::Gt, false) => Bound::Unbounded, + (BinOpType::Gt, true) | (BinOpType::Leq, false) => Bound::Excluded(value), + (BinOpType::Geq, true) | (BinOpType::Lt, false) => Bound::Included(value), + _ => unreachable!("all comparison BinOpTypes were enumerated. this should be unreachable"), + }; + let end = match (comp_bin_op_typ, is_left_attr_ref) { + (BinOpType::Lt, true) | (BinOpType::Geq, false) => Bound::Excluded(value), + (BinOpType::Leq, true) | (BinOpType::Gt, false) => Bound::Included(value), + (BinOpType::Gt, true) | (BinOpType::Leq, false) => Bound::Unbounded, + (BinOpType::Geq, true) | (BinOpType::Lt, false) => Bound::Unbounded, + _ => unreachable!("all comparison BinOpTypes were enumerated. this should be unreachable"), + }; + self.get_attribute_range_selectivity(table_id, attr_ref_idx, start, end) + .await + } + _ => unreachable!( + "all comparison BinOpTypes were enumerated. this should be unreachable" + ), + } + } else { + let non_attr_ref_expr = non_attr_ref_exprs.first().expect( + "non_attr_ref_exprs should have a value since attr_ref_exprs.len() == 1", + ); + + match non_attr_ref_expr.as_ref().typ { + PredicateType::BinOp(_) => { + Ok(Self::get_default_comparison_op_selectivity(comp_bin_op_typ)) + } + PredicateType::Cast => Ok(UNIMPLEMENTED_SEL), + PredicateType::Constant(_) => { + unreachable!( + "we should have handled this in the values.len() == 1 branch" + ) + } + _ => unimplemented!( + "unhandled case of comparing a attribute ref node to {}", + non_attr_ref_expr.as_ref().typ + ), + } + } + } else { + // TODO: attribute is derived + Ok(Self::get_default_comparison_op_selectivity(comp_bin_op_typ)) + } + } else if attr_ref_exprs.len() == 2 { + Ok(Self::get_default_comparison_op_selectivity(comp_bin_op_typ)) + } else { + unreachable!("we could have at most pushed left and right into attr_ref_exprs") + } + } + + /// Convert the left and right child nodes of some operation to what they semantically are. + /// This is convenient to avoid repeating the same logic just with "left" and "right" swapped. + /// The last return value is true when the input node (left) is a AttributeRefPred. + #[allow(clippy::type_complexity)] + async fn get_semantic_nodes( + &self, + group_id: GroupId, + left: ArcPredicateNode, + right: ArcPredicateNode, + ) -> CostModelResult<(Vec, Vec, Vec, bool)> { + let mut attr_ref_exprs = vec![]; + let mut values = vec![]; + let mut non_attr_ref_exprs = vec![]; + let is_left_attr_ref; + + // Recursively unwrap casts as much as we can. + let mut uncasted_left = left; + let mut uncasted_right = right; + loop { + // println!("loop {}, uncasted_left={:?}, uncasted_right={:?}", Local::now(), + // uncasted_left, uncasted_right); + if uncasted_left.as_ref().typ == PredicateType::Cast + && uncasted_right.as_ref().typ == PredicateType::Cast + { + let left_cast_expr = CastPred::from_pred_node(uncasted_left) + .expect("we already checked that the type is Cast"); + let right_cast_expr = CastPred::from_pred_node(uncasted_right) + .expect("we already checked that the type is Cast"); + assert!(left_cast_expr.cast_to() == right_cast_expr.cast_to()); + uncasted_left = left_cast_expr.child().into_pred_node(); + uncasted_right = right_cast_expr.child().into_pred_node(); + } else if uncasted_left.as_ref().typ == PredicateType::Cast + || uncasted_right.as_ref().typ == PredicateType::Cast + { + let is_left_cast = uncasted_left.as_ref().typ == PredicateType::Cast; + let (mut cast_node, mut non_cast_node) = if is_left_cast { + (uncasted_left, uncasted_right) + } else { + (uncasted_right, uncasted_left) + }; + + let cast_expr = CastPred::from_pred_node(cast_node) + .expect("we already checked that the type is Cast"); + let cast_expr_child = cast_expr.child().into_pred_node(); + let cast_expr_cast_to = cast_expr.cast_to(); + + let should_break = match cast_expr_child.typ { + PredicateType::Constant(_) => { + cast_node = ConstantPred::new( + ConstantPred::from_pred_node(cast_expr_child) + .expect("we already checked that the type is Constant") + .value() + .convert_to_type(cast_expr_cast_to), + ) + .into_pred_node(); + false + } + PredicateType::AttrIndex => { + let attr_ref_expr = AttrIndexPred::from_pred_node(cast_expr_child) + .expect("we already checked that the type is AttributeRef"); + let attr_ref_idx = attr_ref_expr.attr_index(); + cast_node = attr_ref_expr.into_pred_node(); + // The "invert" cast is to invert the cast so that we're casting the + // non_cast_node to the attribute's original type. + let attribute_info = self.memo.get_attribute_info(group_id, attr_ref_idx); + let invert_cast_data_type = &attribute_info.typ.into_data_type(); + + match non_cast_node.typ { + PredicateType::AttrIndex => { + // In general, there's no way to remove the Cast here. We can't move + // the Cast to the other AttributeRef + // because that would lead to an infinite loop. Thus, we just leave + // the cast where it is and break. + true + } + _ => { + non_cast_node = + CastPred::new(non_cast_node, invert_cast_data_type.clone()) + .into_pred_node(); + false + } + } + } + _ => todo!(), + }; + + (uncasted_left, uncasted_right) = if is_left_cast { + (cast_node, non_cast_node) + } else { + (non_cast_node, cast_node) + }; + + if should_break { + break; + } + } else { + break; + } + } + + // Sort nodes into attr_ref_exprs, values, and non_attr_ref_exprs + match uncasted_left.as_ref().typ { + PredicateType::AttrIndex => { + is_left_attr_ref = true; + attr_ref_exprs.push( + AttrIndexPred::from_pred_node(uncasted_left) + .expect("we already checked that the type is AttributeRef"), + ); + } + PredicateType::Constant(_) => { + is_left_attr_ref = false; + values.push( + ConstantPred::from_pred_node(uncasted_left) + .expect("we already checked that the type is Constant") + .value(), + ) + } + _ => { + is_left_attr_ref = false; + non_attr_ref_exprs.push(uncasted_left); + } + } + match uncasted_right.as_ref().typ { + PredicateType::AttrIndex => { + attr_ref_exprs.push( + AttrIndexPred::from_pred_node(uncasted_right) + .expect("we already checked that the type is AttributeRef"), + ); + } + PredicateType::Constant(_) => values.push( + ConstantPred::from_pred_node(uncasted_right) + .expect("we already checked that the type is Constant") + .value(), + ), + _ => { + non_attr_ref_exprs.push(uncasted_right); + } + } + + assert!(attr_ref_exprs.len() + values.len() + non_attr_ref_exprs.len() == 2); + Ok((attr_ref_exprs, values, non_attr_ref_exprs, is_left_attr_ref)) + } + + /// The default selectivity of a comparison expression + /// Used when one side of the comparison is a attribute while the other side is something too + /// complex/impossible to evaluate (subquery, UDF, another attribute, we have no stats, etc.) + fn get_default_comparison_op_selectivity(comp_bin_op_typ: BinOpType) -> f64 { + assert!(comp_bin_op_typ.is_comparison()); + match comp_bin_op_typ { + BinOpType::Eq => DEFAULT_EQ_SEL, + BinOpType::Neq => 1.0 - DEFAULT_EQ_SEL, + BinOpType::Lt | BinOpType::Leq | BinOpType::Gt | BinOpType::Geq => DEFAULT_INEQ_SEL, + _ => unreachable!( + "all comparison BinOpTypes were enumerated. this should be unreachable" + ), + } + } +} diff --git a/optd-cost-model/src/cost/filter/constant.rs b/optd-cost-model/src/cost/filter/constant.rs new file mode 100644 index 0000000..e131bde --- /dev/null +++ b/optd-cost-model/src/cost/filter/constant.rs @@ -0,0 +1,38 @@ +use crate::{ + common::{ + nodes::{ArcPredicateNode, PredicateType}, + predicates::constant_pred::ConstantType, + values::Value, + }, + cost_model::CostModelImpl, + storage::CostModelStorageManager, +}; + +impl CostModelImpl { + pub(crate) fn get_constant_selectivity(const_node: ArcPredicateNode) -> f64 { + if let PredicateType::Constant(const_typ) = const_node.typ { + if matches!(const_typ, ConstantType::Bool) { + let value = const_node + .as_ref() + .data + .as_ref() + .expect("constants should have data"); + if let Value::Bool(bool_value) = value { + if *bool_value { + 1.0 + } else { + 0.0 + } + } else { + unreachable!( + "if the typ is ConstantType::Bool, the value should be a Value::Bool" + ) + } + } else { + panic!("selectivity is not defined on constants which are not bools") + } + } else { + panic!("get_constant_selectivity must be called on a constant") + } + } +} diff --git a/optd-cost-model/src/cost/filter/core.rs b/optd-cost-model/src/cost/filter/core.rs new file mode 100644 index 0000000..05363e4 --- /dev/null +++ b/optd-cost-model/src/cost/filter/core.rs @@ -0,0 +1,877 @@ +use crate::{ + common::{ + nodes::{ArcPredicateNode, PredicateType, ReprPredicateNode}, + predicates::{in_list_pred::InListPred, like_pred::LikePred, un_op_pred::UnOpType}, + types::GroupId, + }, + cost_model::CostModelImpl, + stats::UNIMPLEMENTED_SEL, + storage::CostModelStorageManager, + CostModelResult, EstimatedStatistic, +}; + +impl CostModelImpl { + // TODO: is it a good design to pass table_id here? I think it needs to be refactored. + // Consider to remove table_id. + pub async fn get_filter_row_cnt( + &self, + child_row_cnt: EstimatedStatistic, + group_id: GroupId, + cond: ArcPredicateNode, + ) -> CostModelResult { + let selectivity = { self.get_filter_selectivity(group_id, cond).await? }; + Ok(EstimatedStatistic((child_row_cnt.0 * selectivity).max(1.0))) + } + + pub async fn get_filter_selectivity( + &self, + group_id: GroupId, + expr_tree: ArcPredicateNode, + ) -> CostModelResult { + Box::pin(async move { + match &expr_tree.typ { + PredicateType::Constant(_) => Ok(Self::get_constant_selectivity(expr_tree)), + PredicateType::AttrIndex => unimplemented!("check bool type or else panic"), + PredicateType::UnOp(un_op_typ) => { + assert!(expr_tree.children.len() == 1); + let child = expr_tree.child(0); + match un_op_typ { + // not doesn't care about nulls so there's no complex logic. it just reverses + // the selectivity for instance, != _will not_ include nulls + // but "NOT ==" _will_ include nulls + UnOpType::Not => Ok(1.0 - self.get_filter_selectivity(group_id, child).await?), + UnOpType::Neg => panic!( + "the selectivity of operations that return numerical values is undefined" + ), + } + } + PredicateType::BinOp(bin_op_typ) => { + assert!(expr_tree.children.len() == 2); + let left_child = expr_tree.child(0); + let right_child = expr_tree.child(1); + + if bin_op_typ.is_comparison() { + self.get_comp_op_selectivity(group_id, *bin_op_typ, left_child, right_child).await + } else if bin_op_typ.is_numerical() { + panic!( + "the selectivity of operations that return numerical values is undefined" + ) + } else { + unreachable!("all BinOpTypes should be true for at least one is_*() function") + } + } + PredicateType::LogOp(log_op_typ) => { + self.get_log_op_selectivity(group_id, *log_op_typ, &expr_tree.children).await + } + PredicateType::Func(_) => unimplemented!("check bool type or else panic"), + PredicateType::SortOrder(_) => { + panic!("the selectivity of sort order expressions is undefined") + } + PredicateType::Between => Ok(UNIMPLEMENTED_SEL), + PredicateType::Cast => unimplemented!("check bool type or else panic"), + PredicateType::Like => { + let like_expr = LikePred::from_pred_node(expr_tree).unwrap(); + self.get_like_selectivity(group_id, &like_expr).await + } + PredicateType::DataType(_) => { + panic!("the selectivity of a data type is not defined") + } + PredicateType::InList => { + let in_list_expr = InListPred::from_pred_node(expr_tree).unwrap(); + self.get_in_list_selectivity(group_id, &in_list_expr).await + } + _ => unreachable!( + "all expression DfPredType were enumerated. this should be unreachable" + ), + } + }).await + } +} + +#[cfg(test)] +mod tests { + use std::collections::HashMap; + + use crate::{ + common::{ + predicates::{ + bin_op_pred::BinOpType, constant_pred::ConstantType, log_op_pred::LogOpType, + un_op_pred::UnOpType, + }, + properties::Attribute, + types::TableId, + values::Value, + }, + cost_model::tests::*, + memo_ext::tests::MemoGroupInfo, + stats::{ + utilities::{counter::Counter, simple_map::SimpleMap}, + Distribution, MostCommonValues, DEFAULT_EQ_SEL, + }, + }; + use arrow_schema::DataType; + + #[tokio::test] + async fn test_const() { + let cost_model = create_mock_cost_model( + vec![TableId(0)], + vec![HashMap::from([(0, empty_per_attr_stats())])], + vec![None], + ); + assert_approx_eq::assert_approx_eq!( + cost_model + .get_filter_selectivity(TEST_GROUP1_ID, cnst(Value::Bool(true))) + .await + .unwrap(), + 1.0 + ); + assert_approx_eq::assert_approx_eq!( + cost_model + .get_filter_selectivity(TEST_GROUP1_ID, cnst(Value::Bool(false))) + .await + .unwrap(), + 0.0 + ); + } + + #[tokio::test] + async fn test_attr_ref_eq_constint_in_mcv() { + let per_attribute_stats = TestPerAttributeStats::new( + MostCommonValues::SimpleFrequency(SimpleMap::new(vec![( + vec![Some(Value::Int32(1))], + 0.3, + )])), + None, + 0, + 0.0, + ); + let table_id = TableId(0); + let cost_model = create_mock_cost_model( + vec![table_id], + vec![HashMap::from([(0, per_attribute_stats)])], + vec![None], + ); + + let expr_tree = bin_op( + BinOpType::Eq, + attr_index(0), // TODO: Fix this + cnst(Value::Int32(1)), + ); + let expr_tree_rev = bin_op( + BinOpType::Eq, + cnst(Value::Int32(1)), + attr_index(0), // TODO: Fix this + ); + assert_approx_eq::assert_approx_eq!( + cost_model + .get_filter_selectivity(TEST_GROUP1_ID, expr_tree) + .await + .unwrap(), + 0.3 + ); + assert_approx_eq::assert_approx_eq!( + cost_model + .get_filter_selectivity(TEST_GROUP1_ID, expr_tree_rev) + .await + .unwrap(), + 0.3 + ); + } + + #[tokio::test] + async fn test_attr_ref_eq_constint_not_in_mcv() { + let per_attribute_stats = TestPerAttributeStats::new( + MostCommonValues::SimpleFrequency(SimpleMap::new(vec![ + (vec![Some(Value::Int32(1))], 0.2), + (vec![Some(Value::Int32(3))], 0.44), + ])), + None, + 5, + 0.0, + ); + let cost_model = create_mock_cost_model( + vec![TEST_TABLE1_ID], + vec![HashMap::from([( + TEST_ATTR1_BASE_INDEX, + per_attribute_stats, + )])], + vec![None], + ); + + let expr_tree = bin_op(BinOpType::Eq, attr_index(0), cnst(Value::Int32(2))); + let expr_tree_rev = bin_op(BinOpType::Eq, cnst(Value::Int32(2)), attr_index(0)); + assert_approx_eq::assert_approx_eq!( + cost_model + .get_filter_selectivity(TEST_GROUP1_ID, expr_tree) + .await + .unwrap(), + 0.12 + ); + assert_approx_eq::assert_approx_eq!( + cost_model + .get_filter_selectivity(TEST_GROUP1_ID, expr_tree_rev) + .await + .unwrap(), + 0.12 + ); + } + + /// I only have one test for NEQ since I'll assume that it uses the same underlying logic as EQ + #[tokio::test] + async fn test_attr_ref_neq_constint_in_mcv() { + let per_attribute_stats = TestPerAttributeStats::new( + MostCommonValues::SimpleFrequency(SimpleMap::new(vec![( + vec![Some(Value::Int32(1))], + 0.3, + )])), + None, + 0, + 0.0, + ); + let cost_model = create_mock_cost_model( + vec![TEST_TABLE1_ID], + vec![HashMap::from([( + TEST_ATTR1_BASE_INDEX, + per_attribute_stats, + )])], + vec![None], + ); + + let expr_tree = bin_op(BinOpType::Neq, attr_index(0), cnst(Value::Int32(1))); + let expr_tree_rev = bin_op(BinOpType::Neq, cnst(Value::Int32(1)), attr_index(0)); + assert_approx_eq::assert_approx_eq!( + cost_model + .get_filter_selectivity(TEST_GROUP1_ID, expr_tree) + .await + .unwrap(), + 1.0 - 0.3 + ); + assert_approx_eq::assert_approx_eq!( + cost_model + .get_filter_selectivity(TEST_GROUP1_ID, expr_tree_rev) + .await + .unwrap(), + 1.0 - 0.3 + ); + } + + #[tokio::test] + async fn test_attr_ref_leq_constint_no_mcvs_in_range() { + let per_attribute_stats = TestPerAttributeStats::new( + MostCommonValues::SimpleFrequency(SimpleMap::default()), + Some(Distribution::SimpleDistribution(SimpleMap::new(vec![( + Value::Int32(15), + 0.7, + )]))), + 10, + 0.0, + ); + let cost_model = create_mock_cost_model( + vec![TEST_TABLE1_ID], + vec![HashMap::from([( + TEST_ATTR1_BASE_INDEX, + per_attribute_stats, + )])], + vec![None], + ); + + let expr_tree = bin_op(BinOpType::Leq, attr_index(0), cnst(Value::Int32(15))); + let expr_tree_rev = bin_op(BinOpType::Gt, cnst(Value::Int32(15)), attr_index(0)); + assert_approx_eq::assert_approx_eq!( + cost_model + .get_filter_selectivity(TEST_GROUP1_ID, expr_tree) + .await + .unwrap(), + 0.7 + ); + assert_approx_eq::assert_approx_eq!( + cost_model + .get_filter_selectivity(TEST_GROUP1_ID, expr_tree_rev) + .await + .unwrap(), + 0.7 + ); + } + + #[tokio::test] + async fn test_attr_ref_leq_constint_with_mcvs_in_range_not_at_border() { + let per_attribute_stats = TestPerAttributeStats::new( + MostCommonValues::SimpleFrequency(SimpleMap::new(vec![ + (vec![Some(Value::Int32(6))], 0.05), + (vec![Some(Value::Int32(10))], 0.1), + (vec![Some(Value::Int32(17))], 0.08), + (vec![Some(Value::Int32(25))], 0.07), + ])), + Some(Distribution::SimpleDistribution(SimpleMap::new(vec![( + Value::Int32(15), + 0.7, + )]))), + 10, + 0.0, + ); + let cost_model = create_mock_cost_model( + vec![TEST_TABLE1_ID], + vec![HashMap::from([( + TEST_ATTR1_BASE_INDEX, + per_attribute_stats, + )])], + vec![None], + ); + + let expr_tree = bin_op(BinOpType::Leq, attr_index(0), cnst(Value::Int32(15))); + let expr_tree_rev = bin_op(BinOpType::Gt, cnst(Value::Int32(15)), attr_index(0)); + assert_approx_eq::assert_approx_eq!( + cost_model + .get_filter_selectivity(TEST_GROUP1_ID, expr_tree) + .await + .unwrap(), + 0.85 + ); + assert_approx_eq::assert_approx_eq!( + cost_model + .get_filter_selectivity(TEST_GROUP1_ID, expr_tree_rev) + .await + .unwrap(), + 0.85 + ); + } + + #[tokio::test] + async fn test_attr_ref_leq_constint_with_mcv_at_border() { + let per_attribute_stats = TestPerAttributeStats::new( + MostCommonValues::SimpleFrequency(SimpleMap::new(vec![ + (vec![Some(Value::Int32(6))], 0.05), + (vec![Some(Value::Int32(10))], 0.1), + (vec![Some(Value::Int32(15))], 0.08), + (vec![Some(Value::Int32(25))], 0.07), + ])), + Some(Distribution::SimpleDistribution(SimpleMap::new(vec![( + Value::Int32(15), + 0.7, + )]))), + 10, + 0.0, + ); + let cost_model = create_mock_cost_model( + vec![TEST_TABLE1_ID], + vec![HashMap::from([( + TEST_ATTR1_BASE_INDEX, + per_attribute_stats, + )])], + vec![None], + ); + + let expr_tree = bin_op(BinOpType::Leq, attr_index(0), cnst(Value::Int32(15))); + let expr_tree_rev = bin_op(BinOpType::Gt, cnst(Value::Int32(15)), attr_index(0)); + assert_approx_eq::assert_approx_eq!( + cost_model + .get_filter_selectivity(TEST_GROUP1_ID, expr_tree) + .await + .unwrap(), + 0.93 + ); + assert_approx_eq::assert_approx_eq!( + cost_model + .get_filter_selectivity(TEST_GROUP1_ID, expr_tree_rev) + .await + .unwrap(), + 0.93 + ); + } + + #[tokio::test] + async fn test_attr_ref_lt_constint_no_mcvs_in_range() { + let per_attribute_stats = TestPerAttributeStats::new( + MostCommonValues::SimpleFrequency(SimpleMap::default()), + Some(Distribution::SimpleDistribution(SimpleMap::new(vec![( + Value::Int32(15), + 0.7, + )]))), + 10, + 0.0, + ); + let cost_model = create_mock_cost_model( + vec![TEST_TABLE1_ID], + vec![HashMap::from([( + TEST_ATTR1_BASE_INDEX, + per_attribute_stats, + )])], + vec![None], + ); + + let expr_tree = bin_op(BinOpType::Lt, attr_index(0), cnst(Value::Int32(15))); + let expr_tree_rev = bin_op(BinOpType::Geq, cnst(Value::Int32(15)), attr_index(0)); + assert_approx_eq::assert_approx_eq!( + cost_model + .get_filter_selectivity(TEST_GROUP1_ID, expr_tree) + .await + .unwrap(), + 0.6 + ); + assert_approx_eq::assert_approx_eq!( + cost_model + .get_filter_selectivity(TEST_GROUP1_ID, expr_tree_rev) + .await + .unwrap(), + 0.6 + ); + } + + #[tokio::test] + async fn test_attr_ef_lt_constint_with_mcvs_in_range_not_at_border() { + let per_attribute_stats = TestPerAttributeStats::new( + MostCommonValues::SimpleFrequency(SimpleMap::new(vec![ + (vec![Some(Value::Int32(6))], 0.05), + (vec![Some(Value::Int32(10))], 0.1), + (vec![Some(Value::Int32(17))], 0.08), + (vec![Some(Value::Int32(25))], 0.07), + ])), + Some(Distribution::SimpleDistribution(SimpleMap::new(vec![( + Value::Int32(15), + 0.7, + )]))), + 11, /* there are 4 MCVs which together add up to 0.3. With 11 total ndistinct, each + * remaining value has freq 0.1 */ + 0.0, + ); + let cost_model = create_mock_cost_model( + vec![TEST_TABLE1_ID], + vec![HashMap::from([( + TEST_ATTR1_BASE_INDEX, + per_attribute_stats, + )])], + vec![None], + ); + + let expr_tree = bin_op(BinOpType::Lt, attr_index(0), cnst(Value::Int32(15))); + let expr_tree_rev = bin_op(BinOpType::Geq, cnst(Value::Int32(15)), attr_index(0)); + assert_approx_eq::assert_approx_eq!( + cost_model + .get_filter_selectivity(TEST_GROUP1_ID, expr_tree) + .await + .unwrap(), + 0.75 + ); + assert_approx_eq::assert_approx_eq!( + cost_model + .get_filter_selectivity(TEST_GROUP1_ID, expr_tree_rev) + .await + .unwrap(), + 0.75 + ); + } + + #[tokio::test] + async fn test_attr_ref_lt_constint_with_mcv_at_border() { + let per_attribute_stats = TestPerAttributeStats::new( + MostCommonValues::SimpleFrequency(SimpleMap::new(vec![ + (vec![Some(Value::Int32(6))], 0.05), + (vec![Some(Value::Int32(10))], 0.1), + (vec![Some(Value::Int32(15))], 0.08), + (vec![Some(Value::Int32(25))], 0.07), + ])), + Some(Distribution::SimpleDistribution(SimpleMap::new(vec![( + Value::Int32(15), + 0.7, + )]))), + 11, /* there are 4 MCVs which together add up to 0.3. With 11 total ndistinct, each + * remaining value has freq 0.1 */ + 0.0, + ); + let cost_model = create_mock_cost_model( + vec![TEST_TABLE1_ID], + vec![HashMap::from([( + TEST_ATTR1_BASE_INDEX, + per_attribute_stats, + )])], + vec![None], + ); + + let expr_tree = bin_op(BinOpType::Lt, attr_index(0), cnst(Value::Int32(15))); + let expr_tree_rev = bin_op(BinOpType::Geq, cnst(Value::Int32(15)), attr_index(0)); + assert_approx_eq::assert_approx_eq!( + cost_model + .get_filter_selectivity(TEST_GROUP1_ID, expr_tree) + .await + .unwrap(), + 0.85 + ); + assert_approx_eq::assert_approx_eq!( + cost_model + .get_filter_selectivity(TEST_GROUP1_ID, expr_tree_rev) + .await + .unwrap(), + 0.85 + ); + } + + /// I have fewer tests for GT since I'll assume that it uses the same underlying logic as LEQ + /// The only interesting thing to test is that if there are nulls, those aren't included in GT + #[tokio::test] + async fn test_attr_ref_gt_constint() { + let per_attribute_stats = TestPerAttributeStats::new( + MostCommonValues::SimpleFrequency(SimpleMap::default()), + Some(Distribution::SimpleDistribution(SimpleMap::new(vec![( + Value::Int32(15), + 0.7, + )]))), + 10, + 0.0, + ); + let cost_model = create_mock_cost_model( + vec![TEST_TABLE1_ID], + vec![HashMap::from([( + TEST_ATTR1_BASE_INDEX, + per_attribute_stats, + )])], + vec![None], + ); + + let expr_tree = bin_op(BinOpType::Gt, attr_index(0), cnst(Value::Int32(15))); + let expr_tree_rev = bin_op(BinOpType::Leq, cnst(Value::Int32(15)), attr_index(0)); + assert_approx_eq::assert_approx_eq!( + cost_model + .get_filter_selectivity(TEST_GROUP1_ID, expr_tree) + .await + .unwrap(), + 1.0 - 0.7 + ); + assert_approx_eq::assert_approx_eq!( + cost_model + .get_filter_selectivity(TEST_GROUP1_ID, expr_tree_rev) + .await + .unwrap(), + 1.0 - 0.7 + ); + } + + #[tokio::test] + async fn test_attr_ref_geq_constint() { + let per_attribute_stats = TestPerAttributeStats::new( + MostCommonValues::SimpleFrequency(SimpleMap::default()), + Some(Distribution::SimpleDistribution(SimpleMap::new(vec![( + Value::Int32(15), + 0.7, + )]))), + 10, + 0.0, + ); + let cost_model = create_mock_cost_model( + vec![TEST_TABLE1_ID], + vec![HashMap::from([( + TEST_ATTR1_BASE_INDEX, + per_attribute_stats, + )])], + vec![None], + ); + + let expr_tree = bin_op(BinOpType::Geq, attr_index(0), cnst(Value::Int32(15))); + let expr_tree_rev = bin_op(BinOpType::Lt, cnst(Value::Int32(15)), attr_index(0)); + + assert_approx_eq::assert_approx_eq!( + cost_model + .get_filter_selectivity(TEST_GROUP1_ID, expr_tree) + .await + .unwrap(), + 1.0 - 0.6 + ); + assert_approx_eq::assert_approx_eq!( + cost_model + .get_filter_selectivity(TEST_GROUP1_ID, expr_tree_rev) + .await + .unwrap(), + 1.0 - 0.6 + ); + } + + #[tokio::test] + async fn test_and() { + let per_attribute_stats = TestPerAttributeStats::new( + MostCommonValues::SimpleFrequency(SimpleMap::new(vec![ + (vec![Some(Value::Int32(1))], 0.3), + (vec![Some(Value::Int32(5))], 0.5), + (vec![Some(Value::Int32(8))], 0.2), + ])), + None, + 0, + 0.0, + ); + let cost_model = create_mock_cost_model( + vec![TEST_TABLE1_ID], + vec![HashMap::from([( + TEST_ATTR1_BASE_INDEX, + per_attribute_stats, + )])], + vec![None], + ); + + let eq1 = bin_op(BinOpType::Eq, attr_index(0), cnst(Value::Int32(1))); + let eq5 = bin_op(BinOpType::Eq, attr_index(0), cnst(Value::Int32(5))); + let eq8 = bin_op(BinOpType::Eq, attr_index(0), cnst(Value::Int32(8))); + let expr_tree = log_op(LogOpType::And, vec![eq1.clone(), eq5.clone(), eq8.clone()]); + let expr_tree_shift1 = log_op(LogOpType::And, vec![eq5.clone(), eq8.clone(), eq1.clone()]); + let expr_tree_shift2 = log_op(LogOpType::And, vec![eq8.clone(), eq1.clone(), eq5.clone()]); + + assert_approx_eq::assert_approx_eq!( + cost_model + .get_filter_selectivity(TEST_GROUP1_ID, expr_tree) + .await + .unwrap(), + 0.03 + ); + assert_approx_eq::assert_approx_eq!( + cost_model + .get_filter_selectivity(TEST_GROUP1_ID, expr_tree_shift1) + .await + .unwrap(), + 0.03 + ); + assert_approx_eq::assert_approx_eq!( + cost_model + .get_filter_selectivity(TEST_GROUP1_ID, expr_tree_shift2) + .await + .unwrap(), + 0.03 + ); + } + + #[tokio::test] + async fn test_or() { + let per_attribute_stats = TestPerAttributeStats::new( + MostCommonValues::SimpleFrequency(SimpleMap::new(vec![ + (vec![Some(Value::Int32(1))], 0.3), + (vec![Some(Value::Int32(5))], 0.5), + (vec![Some(Value::Int32(8))], 0.2), + ])), + None, + 0, + 0.0, + ); + let cost_model = create_mock_cost_model( + vec![TEST_TABLE1_ID], + vec![HashMap::from([( + TEST_ATTR1_BASE_INDEX, + per_attribute_stats, + )])], + vec![None], + ); + + let eq1 = bin_op(BinOpType::Eq, attr_index(0), cnst(Value::Int32(1))); + let eq5 = bin_op(BinOpType::Eq, attr_index(0), cnst(Value::Int32(5))); + let eq8 = bin_op(BinOpType::Eq, attr_index(0), cnst(Value::Int32(8))); + let expr_tree = log_op(LogOpType::Or, vec![eq1.clone(), eq5.clone(), eq8.clone()]); + let expr_tree_shift1 = log_op(LogOpType::Or, vec![eq5.clone(), eq8.clone(), eq1.clone()]); + let expr_tree_shift2 = log_op(LogOpType::Or, vec![eq8.clone(), eq1.clone(), eq5.clone()]); + + assert_approx_eq::assert_approx_eq!( + cost_model + .get_filter_selectivity(TEST_GROUP1_ID, expr_tree) + .await + .unwrap(), + 0.72 + ); + assert_approx_eq::assert_approx_eq!( + cost_model + .get_filter_selectivity(TEST_GROUP1_ID, expr_tree_shift1) + .await + .unwrap(), + 0.72 + ); + assert_approx_eq::assert_approx_eq!( + cost_model + .get_filter_selectivity(TEST_GROUP1_ID, expr_tree_shift2) + .await + .unwrap(), + 0.72 + ); + } + + #[tokio::test] + async fn test_not() { + let per_attribute_stats = TestPerAttributeStats::new( + MostCommonValues::SimpleFrequency(SimpleMap::new(vec![( + vec![Some(Value::Int32(1))], + 0.3, + )])), + None, + 0, + 0.0, + ); + let cost_model = create_mock_cost_model( + vec![TEST_TABLE1_ID], + vec![HashMap::from([( + TEST_ATTR1_BASE_INDEX, + per_attribute_stats, + )])], + vec![None], + ); + + let expr_tree = un_op( + UnOpType::Not, + bin_op(BinOpType::Eq, attr_index(0), cnst(Value::Int32(1))), + ); + + assert_approx_eq::assert_approx_eq!( + cost_model + .get_filter_selectivity(TEST_GROUP1_ID, expr_tree) + .await + .unwrap(), + 0.7 + ); + } + + // I didn't test any non-unique cases with filter. The non-unique tests without filter should + // cover that + + #[tokio::test] + async fn test_attr_ref_eq_cast_value() { + let per_attribute_stats = TestPerAttributeStats::new( + MostCommonValues::SimpleFrequency(SimpleMap::new(vec![( + vec![Some(Value::Int32(1))], + 0.3, + )])), + None, + 0, + 0.0, + ); + let cost_model = create_mock_cost_model( + vec![TEST_TABLE1_ID], + vec![HashMap::from([( + TEST_ATTR1_BASE_INDEX, + per_attribute_stats, + )])], + vec![None], + ); + + let expr_tree = bin_op( + BinOpType::Eq, + attr_index(0), + cast(cnst(Value::Int64(1)), DataType::Int32), + ); + let expr_tree_rev = bin_op( + BinOpType::Eq, + cast(cnst(Value::Int64(1)), DataType::Int32), + attr_index(0), + ); + + assert_approx_eq::assert_approx_eq!( + cost_model + .get_filter_selectivity(TEST_GROUP1_ID, expr_tree) + .await + .unwrap(), + 0.3 + ); + assert_approx_eq::assert_approx_eq!( + cost_model + .get_filter_selectivity(TEST_GROUP1_ID, expr_tree_rev) + .await + .unwrap(), + 0.3 + ); + } + + #[tokio::test] + async fn test_cast_attr_ref_eq_value() { + let per_attribute_stats = TestPerAttributeStats::new( + MostCommonValues::SimpleFrequency(SimpleMap::new(vec![( + vec![Some(Value::Int32(1))], + 0.3, + )])), + None, + 0, + 0.1, + ); + let cost_model = create_mock_cost_model_with_attr_types( + vec![TEST_TABLE1_ID], + vec![HashMap::from([( + TEST_ATTR1_BASE_INDEX, + per_attribute_stats, + )])], + vec![HashMap::from([( + TEST_ATTR1_BASE_INDEX, + ConstantType::Int32, + )])], + vec![None], + ); + + let expr_tree = bin_op( + BinOpType::Eq, + cast(attr_index(0), DataType::Int64), // TODO: Fix this + cnst(Value::Int64(1)), + ); + let expr_tree_rev = bin_op( + BinOpType::Eq, + cnst(Value::Int64(1)), + cast(attr_index(0), DataType::Int64), // TODO: Fix this + ); + + assert_approx_eq::assert_approx_eq!( + cost_model + .get_filter_selectivity(TEST_GROUP1_ID, expr_tree) + .await + .unwrap(), + 0.3 + ); + assert_approx_eq::assert_approx_eq!( + cost_model + .get_filter_selectivity(TEST_GROUP1_ID, expr_tree_rev) + .await + .unwrap(), + 0.3 + ); + } + + /// In this case, we should leave the Cast as is. + /// + /// Note that the test only checks the selectivity and thus doesn't explicitly test that the + /// Cast is indeed left as is. However, if get_filter_selectivity() doesn't crash, that's a + /// pretty good signal that the Cast was left as is. + #[tokio::test] + async fn test_cast_attr_ref_eq_attr_ref() { + let per_attribute_stats = TestPerAttributeStats::new( + MostCommonValues::SimpleFrequency(SimpleMap::default()), + None, + 0, + 0.0, + ); + let table_id = TableId(0); + let cost_model = create_mock_cost_model_with_attr_types( + vec![TEST_TABLE1_ID], + vec![HashMap::from([( + TEST_ATTR1_BASE_INDEX, + per_attribute_stats, + )])], + vec![HashMap::from([ + (TEST_ATTR1_BASE_INDEX, ConstantType::Int32), + (TEST_ATTR2_BASE_INDEX, ConstantType::Int64), + ])], + vec![None], + ); + + let expr_tree = bin_op( + BinOpType::Eq, + cast(attr_index(0), DataType::Int64), + attr_index(1), + ); + let expr_tree_rev = bin_op( + BinOpType::Eq, + attr_index(1), + cast(attr_index(0), DataType::Int64), + ); + + assert_approx_eq::assert_approx_eq!( + cost_model + .get_filter_selectivity(TEST_GROUP1_ID, expr_tree) + .await + .unwrap(), + DEFAULT_EQ_SEL + ); + assert_approx_eq::assert_approx_eq!( + cost_model + .get_filter_selectivity(TEST_GROUP1_ID, expr_tree_rev) + .await + .unwrap(), + DEFAULT_EQ_SEL + ); + } +} diff --git a/optd-cost-model/src/cost/filter/in_list.rs b/optd-cost-model/src/cost/filter/in_list.rs new file mode 100644 index 0000000..f056fb1 --- /dev/null +++ b/optd-cost-model/src/cost/filter/in_list.rs @@ -0,0 +1,167 @@ +use crate::{ + common::{ + nodes::{PredicateType, ReprPredicateNode}, + predicates::{ + attr_index_pred::AttrIndexPred, constant_pred::ConstantPred, in_list_pred::InListPred, + }, + properties::attr_ref::{AttrRef, BaseTableAttrRef}, + types::GroupId, + }, + cost_model::CostModelImpl, + stats::UNIMPLEMENTED_SEL, + storage::CostModelStorageManager, + CostModelResult, +}; + +impl CostModelImpl { + /// Only support attrA in (val1, val2, val3) where attrA is a attribute ref and + /// val1, val2, val3 are constants. + pub(crate) async fn get_in_list_selectivity( + &self, + group_id: GroupId, + expr: &InListPred, + ) -> CostModelResult { + let child = expr.child(); + + // Check child is a attribute ref. + if !matches!(child.typ, PredicateType::AttrIndex) { + return Ok(UNIMPLEMENTED_SEL); + } + + // Check all expressions in the list are constants. + let list_exprs = expr.list().to_vec(); + if list_exprs + .iter() + .any(|expr| !matches!(expr.typ, PredicateType::Constant(_))) + { + return Ok(UNIMPLEMENTED_SEL); + } + + // Convert child and const expressions to concrete types. + let attr_ref_pred = AttrIndexPred::from_pred_node(child).unwrap(); + let attr_ref_idx = attr_ref_pred.attr_index(); + + let list_exprs = list_exprs + .into_iter() + .map(|expr| { + ConstantPred::from_pred_node(expr) + .expect("we already checked all list elements are constants") + }) + .collect::>(); + let negated = expr.negated(); + + if let AttrRef::BaseTableAttrRef(BaseTableAttrRef { table_id, attr_idx }) = + self.memo.get_attribute_ref(group_id, attr_ref_idx) + { + let mut in_sel = 0.0; + for expr in &list_exprs { + let selectivity = self + .get_attribute_equality_selectivity( + table_id, + attr_idx, + &expr.value(), + /* is_equality */ true, + ) + .await?; + in_sel += selectivity; + } + in_sel = in_sel.min(1.0); + if negated { + Ok(1.0 - in_sel) + } else { + Ok(in_sel) + } + } else { + // TODO: Child is a derived attribute. + Ok(UNIMPLEMENTED_SEL) + } + } +} + +#[cfg(test)] +mod tests { + use std::collections::HashMap; + + use crate::{ + common::{ + types::{GroupId, TableId}, + values::Value, + }, + cost_model::tests::*, + memo_ext::tests::MemoGroupInfo, + stats::{ + utilities::{counter::Counter, simple_map::SimpleMap}, + MostCommonValues, + }, + }; + + #[tokio::test] + async fn test_in_list() { + let per_attribute_stats = TestPerAttributeStats::new( + MostCommonValues::SimpleFrequency(SimpleMap::new(vec![ + (vec![Some(Value::Int32(1))], 0.8), + (vec![Some(Value::Int32(2))], 0.2), + ])), + None, + 2, + 0.0, + ); + let cost_model = create_mock_cost_model( + vec![TEST_TABLE1_ID], + vec![HashMap::from([( + TEST_ATTR1_BASE_INDEX, + per_attribute_stats, + )])], + vec![None], + ); + + assert_approx_eq::assert_approx_eq!( + cost_model + .get_in_list_selectivity(TEST_GROUP1_ID, &in_list(0, vec![Value::Int32(1)], false)) + .await + .unwrap(), + 0.8 + ); + assert_approx_eq::assert_approx_eq!( + cost_model + .get_in_list_selectivity( + TEST_GROUP1_ID, + &in_list(0, vec![Value::Int32(1), Value::Int32(2)], false) + ) + .await + .unwrap(), + 1.0 + ); + assert_approx_eq::assert_approx_eq!( + cost_model + .get_in_list_selectivity(TEST_GROUP1_ID, &in_list(0, vec![Value::Int32(3)], false)) + .await + .unwrap(), + 0.0 + ); + assert_approx_eq::assert_approx_eq!( + cost_model + .get_in_list_selectivity(TEST_GROUP1_ID, &in_list(0, vec![Value::Int32(1)], true)) + .await + .unwrap(), + 0.2 + ); + assert_approx_eq::assert_approx_eq!( + cost_model + .get_in_list_selectivity( + TEST_GROUP1_ID, + &in_list(0, vec![Value::Int32(1), Value::Int32(2)], true) + ) + .await + .unwrap(), + 0.0 + ); + assert_approx_eq::assert_approx_eq!( + cost_model + .get_in_list_selectivity(TEST_GROUP1_ID, &in_list(0, vec![Value::Int32(3)], true)) // TODO: Fix this + .await + .unwrap(), + 1.0 + ); + } +} diff --git a/optd-cost-model/src/cost/filter/like.rs b/optd-cost-model/src/cost/filter/like.rs new file mode 100644 index 0000000..32800e4 --- /dev/null +++ b/optd-cost-model/src/cost/filter/like.rs @@ -0,0 +1,210 @@ +use datafusion::arrow::{array::StringArray, compute::like}; + +use crate::{ + common::{ + nodes::{PredicateType, ReprPredicateNode}, + predicates::{ + attr_index_pred::AttrIndexPred, constant_pred::ConstantPred, like_pred::LikePred, + }, + properties::attr_ref::{AttrRef, BaseTableAttrRef}, + types::GroupId, + }, + cost_model::CostModelImpl, + stats::{ + AttributeCombValue, FIXED_CHAR_SEL_FACTOR, FULL_WILDCARD_SEL_FACTOR, UNIMPLEMENTED_SEL, + }, + storage::CostModelStorageManager, + CostModelResult, +}; + +impl CostModelImpl { + /// Compute the selectivity of a (NOT) LIKE expression. + /// + /// The logic is somewhat similar to Postgres but different. Postgres first estimates the + /// histogram part of the population and then add up data for any MCV values. If the + /// histogram is large enough, it just uses the number of matches in the histogram, + /// otherwise it estimates the fixed prefix and remainder of pattern separately and + /// combine them. + /// + /// Our approach is simpler and less selective. Firstly, we don't use histogram. The selectivity + /// is composed of MCV frequency and non-MCV selectivity. MCV frequency is computed by + /// adding up frequencies of MCVs that match the pattern. Non-MCV selectivity is computed + /// in the same way that Postgres computes selectivity for the wildcard part of the pattern. + pub(crate) async fn get_like_selectivity( + &self, + group_id: GroupId, + like_expr: &LikePred, + ) -> CostModelResult { + let child = like_expr.child(); + + // Check child is a attribute ref. + if !matches!(child.typ, PredicateType::AttrIndex) { + return Ok(UNIMPLEMENTED_SEL); + } + + // Check pattern is a constant. + let pattern = like_expr.pattern(); + if !matches!(pattern.typ, PredicateType::Constant(_)) { + return Ok(UNIMPLEMENTED_SEL); + } + + let attr_ref_pred = AttrIndexPred::from_pred_node(child).unwrap(); + let attr_ref_idx = attr_ref_pred.attr_index(); + + if let AttrRef::BaseTableAttrRef(BaseTableAttrRef { table_id, attr_idx }) = + self.memo.get_attribute_ref(group_id, attr_ref_idx) + { + let pattern = ConstantPred::from_pred_node(pattern) + .expect("we already checked pattern is a constant") + .value() + .as_str(); + + // Compute the selectivity exculuding MCVs. + // See Postgres `like_selectivity`. + let non_mcv_sel = pattern + .chars() + .fold(1.0, |acc, c| { + if c == '%' { + acc * FULL_WILDCARD_SEL_FACTOR + } else { + acc * FIXED_CHAR_SEL_FACTOR + } + }) + .min(1.0); + + // Compute the selectivity in MCVs. + // TODO: Handle the case where `attribute_stats` is None. + let (mut mcv_freq, mut null_frac) = (0.0, 0.0); + if let Some(attribute_stats) = + self.get_attribute_comb_stats(table_id, &[attr_idx]).await? + { + (mcv_freq, null_frac) = { + let pred = Box::new(move |val: &AttributeCombValue| { + let string = + StringArray::from(vec![val[0].as_ref().unwrap().as_str().as_ref()]); + let pattern = StringArray::from(vec![pattern.as_ref()]); + like(&string, &pattern).unwrap().value(0) + }); + ( + attribute_stats.mcvs.freq_over_pred(pred), + attribute_stats.null_frac, + ) + }; + } + let result = non_mcv_sel + mcv_freq; + + Ok(if like_expr.negated() { + 1.0 - result - null_frac + } else { + result + } + // Postgres clamps the result after histogram and before MCV. See Postgres + // `patternsel_common`. + .clamp(0.0001, 0.9999)) + } else { + // TOOD: derived attribute + Ok(UNIMPLEMENTED_SEL) + } + } +} + +#[cfg(test)] +mod tests { + use std::collections::HashMap; + + use crate::{ + common::{ + types::{GroupId, TableId}, + values::Value, + }, + cost_model::tests::*, + stats::{ + utilities::{counter::Counter, simple_map::SimpleMap}, + MostCommonValues, FIXED_CHAR_SEL_FACTOR, FULL_WILDCARD_SEL_FACTOR, + }, + }; + + #[tokio::test] + async fn test_like_no_nulls() { + let per_attribute_stats = TestPerAttributeStats::new( + MostCommonValues::SimpleFrequency(SimpleMap::new(vec![ + (vec![Some(Value::String("abcd".into()))], 0.1), + (vec![Some(Value::String("abc".into()))], 0.1), + ])), + None, + 2, + 0.0, + ); + let cost_model = create_mock_cost_model( + vec![TEST_TABLE1_ID], + vec![HashMap::from([( + TEST_ATTR1_BASE_INDEX, + per_attribute_stats, + )])], + vec![None], + ); + + assert_approx_eq::assert_approx_eq!( + cost_model + .get_like_selectivity( + TEST_GROUP1_ID, + &like(TEST_ATTR1_BASE_INDEX, "%abcd%", false) + ) // TODO: Fix this + .await + .unwrap(), + 0.1 + FULL_WILDCARD_SEL_FACTOR.powi(2) * FIXED_CHAR_SEL_FACTOR.powi(4) + ); + assert_approx_eq::assert_approx_eq!( + cost_model + .get_like_selectivity(TEST_GROUP1_ID, &like(TEST_ATTR1_BASE_INDEX, "%abc%", false)) // TODO: Fix this + .await + .unwrap(), + 0.1 + 0.1 + FULL_WILDCARD_SEL_FACTOR.powi(2) * FIXED_CHAR_SEL_FACTOR.powi(3) + ); + assert_approx_eq::assert_approx_eq!( + cost_model + .get_like_selectivity(TEST_GROUP1_ID, &like(TEST_ATTR1_BASE_INDEX, "%abc%", true)) // TODO: Fix this + .await + .unwrap(), + 1.0 - (0.1 + 0.1 + FULL_WILDCARD_SEL_FACTOR.powi(2) * FIXED_CHAR_SEL_FACTOR.powi(3)) + ); + } + + #[tokio::test] + async fn test_like_with_nulls() { + let null_frac = 0.5; + let mut mcvs_counts = HashMap::new(); + mcvs_counts.insert(vec![Some(Value::String("abcd".into()))], 1); + let mcvs_total_count = 10; + let per_attribute_stats = TestPerAttributeStats::new( + MostCommonValues::Counter(Counter::new_from_existing(mcvs_counts, mcvs_total_count)), + None, + 2, + null_frac, + ); + let cost_model = create_mock_cost_model( + vec![TEST_TABLE1_ID], + vec![HashMap::from([( + TEST_ATTR1_BASE_INDEX, + per_attribute_stats, + )])], + vec![None], + ); + + assert_approx_eq::assert_approx_eq!( + cost_model + .get_like_selectivity(TEST_GROUP1_ID, &like(0, "%abcd%", false)) // TODO: Fix this + .await + .unwrap(), + 0.1 + FULL_WILDCARD_SEL_FACTOR.powi(2) * FIXED_CHAR_SEL_FACTOR.powi(4) + ); + assert_approx_eq::assert_approx_eq!( + cost_model + .get_like_selectivity(TEST_GROUP1_ID, &like(0, "%abcd%", true)) // TODO: Fix this + .await + .unwrap(), + 1.0 - (0.1 + FULL_WILDCARD_SEL_FACTOR.powi(2) * FIXED_CHAR_SEL_FACTOR.powi(4)) + - null_frac + ); + } +} diff --git a/optd-cost-model/src/cost/filter/log_op.rs b/optd-cost-model/src/cost/filter/log_op.rs new file mode 100644 index 0000000..61862a2 --- /dev/null +++ b/optd-cost-model/src/cost/filter/log_op.rs @@ -0,0 +1,34 @@ +use crate::{ + common::{nodes::ArcPredicateNode, predicates::log_op_pred::LogOpType, types::GroupId}, + cost_model::CostModelImpl, + storage::CostModelStorageManager, + CostModelResult, +}; + +impl CostModelImpl { + pub(crate) async fn get_log_op_selectivity( + &self, + group_id: GroupId, + log_op_typ: LogOpType, + children: &[ArcPredicateNode], + ) -> CostModelResult { + match log_op_typ { + LogOpType::And => { + let mut and_sel = 1.0; + for child in children { + let selectivity = self.get_filter_selectivity(group_id, child.clone()).await?; + and_sel *= selectivity; + } + Ok(and_sel) + } + LogOpType::Or => { + let mut or_sel_neg = 1.0; + for child in children { + let selectivity = self.get_filter_selectivity(group_id, child.clone()).await?; + or_sel_neg *= (1.0 - selectivity); + } + Ok(1.0 - or_sel_neg) + } + } + } +} diff --git a/optd-cost-model/src/cost/filter/mod.rs b/optd-cost-model/src/cost/filter/mod.rs new file mode 100644 index 0000000..00ea653 --- /dev/null +++ b/optd-cost-model/src/cost/filter/mod.rs @@ -0,0 +1,7 @@ +pub mod attribute; +pub mod comp_op; +pub mod constant; +pub mod core; +pub mod in_list; +pub mod like; +pub mod log_op;