Skip to content

Commit

Permalink
feat(cost-model): introduce MemoExt trait for property info (#38)
Browse files Browse the repository at this point in the history
To compute the cost for an expression, we need information about the schema & attribute ref (including attribute correlations). In the current optd, this is done by calling the optimizer core's methods. We were against this approach in previous discussions because we thought this makes the core and the cost model coupled too much -- we thereby eliminated the optimizer parameter and intended to get all these information from the storage/ORM.

However, there is a performance drawback of getting everything from ORM: the core should have all the information (schema & attribute ref) we need in memory -- it would be more efficient for them to be passed in by the core than querying the underlying external database. This also aligns more with the way the cascades optimizer works: building the memo table in a bottom-up approach and remembering everything.

Therefore, to avoid getting everything from ORM and still use one general interface for all types of node, we would need the core to implement a trait provided by the cost model, and the cost model will call the corresponding methods to get the information, i.e. MemoExt in this PR. This allows the core to remain ignorant of what the cost model needs for computing the cost.
  • Loading branch information
xx01cyx authored Nov 18, 2024
1 parent b5aed2b commit 9ba03e6
Show file tree
Hide file tree
Showing 8 changed files with 455 additions and 5 deletions.
1 change: 1 addition & 0 deletions optd-cost-model/src/common/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
pub mod nodes;
pub mod predicates;
pub mod properties;
pub mod types;
pub mod values;
245 changes: 245 additions & 0 deletions optd-cost-model/src/common/properties/attr_ref.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,245 @@
use std::collections::HashSet;

use crate::{common::types::TableId, utils::DisjointSets};

pub type AttrRefs = Vec<AttrRef>;

/// [`BaseTableAttrRef`] represents a reference to an attribute in a base table,
/// i.e. a table existing in the catalog.
#[derive(Clone, Debug, Default, Eq, Hash, PartialEq)]
pub struct BaseTableAttrRef {
pub table_id: TableId,
pub attr_idx: u64,
}

/// [`AttrRef`] represents a reference to an attribute in a query.
#[derive(Clone, Debug)]
pub enum AttrRef {
/// Reference to a base table attribute.
BaseTableAttrRef(BaseTableAttrRef),
/// Reference to a derived attribute (e.g. t.v1 + t.v2).
/// TODO: Better representation of derived attributes.
Derived,
}

impl AttrRef {
pub fn base_table_attr_ref(table_id: TableId, attr_idx: u64) -> Self {
AttrRef::BaseTableAttrRef(BaseTableAttrRef { table_id, attr_idx })
}
}

impl From<BaseTableAttrRef> for AttrRef {
fn from(attr: BaseTableAttrRef) -> Self {
AttrRef::BaseTableAttrRef(attr)
}
}

/// [`EqPredicate`] represents an equality predicate between two attributes.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct EqPredicate {
pub left: BaseTableAttrRef,
pub right: BaseTableAttrRef,
}

impl EqPredicate {
pub fn new(left: BaseTableAttrRef, right: BaseTableAttrRef) -> Self {
Self { left, right }
}
}

/// [`SemanticCorrelation`] represents the semantic correlation between attributes in a
/// query. "Semantic" means that the attributes are correlated based on the
/// semantics of the query, not the statistics.
///
/// [`SemanticCorrelation`] contains equal attributes denoted by disjoint sets of base
/// table attributes, e.g. {{ t1.c1 = t2.c1 = t3.c1 }, { t1.c2 = t2.c2 }}.
#[derive(Clone, Debug, Default)]
pub struct SemanticCorrelation {
/// A disjoint set of base table attributes with equal values in the same row.
disjoint_eq_attr_sets: DisjointSets<BaseTableAttrRef>,
/// The predicates that define the equalities.
eq_predicates: HashSet<EqPredicate>,
}

impl SemanticCorrelation {
pub fn new() -> Self {
Self {
disjoint_eq_attr_sets: DisjointSets::new(),
eq_predicates: HashSet::new(),
}
}

pub fn add_predicate(&mut self, predicate: EqPredicate) {
let left = &predicate.left;
let right = &predicate.right;

// Add the indices to the set if they do not exist.
if !self.disjoint_eq_attr_sets.contains(left) {
self.disjoint_eq_attr_sets
.make_set(left.clone())
.expect("just checked left attribute index does not exist");
}
if !self.disjoint_eq_attr_sets.contains(right) {
self.disjoint_eq_attr_sets
.make_set(right.clone())
.expect("just checked right attribute index does not exist");
}
// Union the attributes.
self.disjoint_eq_attr_sets
.union(left, right)
.expect("both attribute indices should exist");

// Keep track of the predicate.
self.eq_predicates.insert(predicate);
}

/// Determine if two attributes are in the same set.
pub fn is_eq(&mut self, left: &BaseTableAttrRef, right: &BaseTableAttrRef) -> bool {
self.disjoint_eq_attr_sets
.same_set(left, right)
.unwrap_or(false)
}

pub fn contains(&self, base_attr_ref: &BaseTableAttrRef) -> bool {
self.disjoint_eq_attr_sets.contains(base_attr_ref)
}

/// Get the number of attributes that are equal to `attr`, including `attr` itself.
pub fn num_eq_attributes(&mut self, attr: &BaseTableAttrRef) -> usize {
self.disjoint_eq_attr_sets.set_size(attr).unwrap()
}

/// Find the set of predicates that define the equality of the set of attributes `attr` belongs to.
pub fn find_predicates_for_eq_attr_set(&mut self, attr: &BaseTableAttrRef) -> Vec<EqPredicate> {
let mut predicates = Vec::new();
for predicate in &self.eq_predicates {
let left = &predicate.left;
let right = &predicate.right;
if (left != attr && self.disjoint_eq_attr_sets.same_set(attr, left).unwrap())
|| (right != attr && self.disjoint_eq_attr_sets.same_set(attr, right).unwrap())
{
predicates.push(predicate.clone());
}
}
predicates
}

/// Find the set of attributes that define the equality of the set of attributes `attr` belongs to.
pub fn find_attrs_for_eq_attribute_set(
&mut self,
attr: &BaseTableAttrRef,
) -> HashSet<BaseTableAttrRef> {
let predicates = self.find_predicates_for_eq_attr_set(attr);
predicates
.into_iter()
.flat_map(|predicate| vec![predicate.left, predicate.right])
.collect()
}

/// Union two `EqBaseTableattributesets` to produce a new disjoint sets.
pub fn union(x: Self, y: Self) -> Self {
let mut eq_attr_sets = Self::new();
for predicate in x
.eq_predicates
.into_iter()
.chain(y.eq_predicates.into_iter())
{
eq_attr_sets.add_predicate(predicate);
}
eq_attr_sets
}

pub fn merge(x: Option<Self>, y: Option<Self>) -> Option<Self> {
let eq_attr_sets = match (x, y) {
(Some(x), Some(y)) => Self::union(x, y),
(Some(x), None) => x.clone(),
(None, Some(y)) => y.clone(),
_ => return None,
};
Some(eq_attr_sets)
}
}

/// [`GroupAttrRefs`] represents the attributes of a group in a query.
#[derive(Clone, Debug)]
pub struct GroupAttrRefs {
attribute_refs: AttrRefs,
/// Correlation of the output attributes of the group.
output_correlation: Option<SemanticCorrelation>,
}

impl GroupAttrRefs {
pub fn new(attribute_refs: AttrRefs, output_correlation: Option<SemanticCorrelation>) -> Self {
Self {
attribute_refs,
output_correlation,
}
}

pub fn base_table_attribute_refs(&self) -> &AttrRefs {
&self.attribute_refs
}

pub fn output_correlation(&self) -> Option<&SemanticCorrelation> {
self.output_correlation.as_ref()
}
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_eq_base_table_attribute_sets() {
let attr1 = BaseTableAttrRef {
table_id: TableId(1),
attr_idx: 1,
};
let attr2 = BaseTableAttrRef {
table_id: TableId(2),
attr_idx: 2,
};
let attr3 = BaseTableAttrRef {
table_id: TableId(3),
attr_idx: 3,
};
let attr4 = BaseTableAttrRef {
table_id: TableId(4),
attr_idx: 4,
};
let pred1 = EqPredicate::new(attr1.clone(), attr2.clone());
let pred2 = EqPredicate::new(attr3.clone(), attr4.clone());
let pred3 = EqPredicate::new(attr1.clone(), attr3.clone());

let mut eq_attr_sets = SemanticCorrelation::new();

// (1, 2)
eq_attr_sets.add_predicate(pred1.clone());
assert!(eq_attr_sets.is_eq(&attr1, &attr2));

// (1, 2), (3, 4)
eq_attr_sets.add_predicate(pred2.clone());
assert!(eq_attr_sets.is_eq(&attr3, &attr4));
assert!(!eq_attr_sets.is_eq(&attr2, &attr3));

let predicates = eq_attr_sets.find_predicates_for_eq_attr_set(&attr1);
assert_eq!(predicates.len(), 1);
assert!(predicates.contains(&pred1));

let predicates = eq_attr_sets.find_predicates_for_eq_attr_set(&attr3);
assert_eq!(predicates.len(), 1);
assert!(predicates.contains(&pred2));

// (1, 2, 3, 4)
eq_attr_sets.add_predicate(pred3.clone());
assert!(eq_attr_sets.is_eq(&attr1, &attr3));
assert!(eq_attr_sets.is_eq(&attr2, &attr4));
assert!(eq_attr_sets.is_eq(&attr1, &attr4));

let predicates = eq_attr_sets.find_predicates_for_eq_attr_set(&attr1);
assert_eq!(predicates.len(), 3);
assert!(predicates.contains(&pred1));
assert!(predicates.contains(&pred2));
assert!(predicates.contains(&pred3));
}
}
23 changes: 23 additions & 0 deletions optd-cost-model/src/common/properties/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
use serde::{Deserialize, Serialize};

use super::predicates::constant_pred::ConstantType;

pub mod attr_ref;
pub mod schema;

#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Attribute {
pub name: String,
pub typ: ConstantType,
pub nullable: bool,
}

impl std::fmt::Display for Attribute {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
if self.nullable {
write!(f, "{}:{:?}", self.name, self.typ)
} else {
write!(f, "{}:{:?}(non-null)", self.name, self.typ)
}
}
}
35 changes: 35 additions & 0 deletions optd-cost-model/src/common/properties/schema.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
use itertools::Itertools;

use serde::{Deserialize, Serialize};

use super::Attribute;

/// [`Schema`] represents the schema of a group in the memo. It contains a list of attributes.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Schema {
pub attributes: Vec<Attribute>,
}

impl std::fmt::Display for Schema {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"[{}]",
self.attributes.iter().map(|x| x.to_string()).join(", ")
)
}
}

impl Schema {
pub fn new(attributes: Vec<Attribute>) -> Self {
Self { attributes }
}

pub fn len(&self) -> usize {
self.attributes.len()
}

pub fn is_empty(&self) -> bool {
self.len() == 0
}
}
10 changes: 7 additions & 3 deletions optd-cost-model/src/cost_model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use crate::{
nodes::{ArcPredicateNode, PhysicalNodeType},
types::{AttrId, EpochId, ExprId, TableId},
},
memo_ext::MemoExt,
storage::CostModelStorageManager,
ComputeCostContext, Cost, CostModel, CostModelResult, EstimatedStatistic, StatValue,
};
Expand All @@ -20,28 +21,31 @@ use crate::{
pub struct CostModelImpl<S: CostModelStorageLayer> {
storage_manager: CostModelStorageManager<S>,
default_catalog_source: CatalogSource,
_memo: Arc<dyn MemoExt>,
}

impl<S: CostModelStorageLayer> CostModelImpl<S> {
/// TODO: documentation
pub fn new(
storage_manager: CostModelStorageManager<S>,
default_catalog_source: CatalogSource,
memo: Arc<dyn MemoExt>,
) -> Self {
Self {
storage_manager,
default_catalog_source,
_memo: memo,
}
}
}

impl<S: CostModelStorageLayer + std::marker::Sync + 'static> CostModel for CostModelImpl<S> {
impl<S: CostModelStorageLayer + Sync + 'static> CostModel for CostModelImpl<S> {
fn compute_operation_cost(
&self,
node: &PhysicalNodeType,
predicates: &[ArcPredicateNode],
children_stats: &[Option<&EstimatedStatistic>],
context: Option<ComputeCostContext>,
context: ComputeCostContext,
) -> CostModelResult<Cost> {
todo!()
}
Expand All @@ -51,7 +55,7 @@ impl<S: CostModelStorageLayer + std::marker::Sync + 'static> CostModel for CostM
node: PhysicalNodeType,
predicates: &[ArcPredicateNode],
children_statistics: &[Option<&EstimatedStatistic>],
context: Option<ComputeCostContext>,
context: ComputeCostContext,
) -> CostModelResult<EstimatedStatistic> {
todo!()
}
Expand Down
6 changes: 4 additions & 2 deletions optd-cost-model/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,10 @@ use optd_persistent::{
pub mod common;
pub mod cost;
pub mod cost_model;
pub mod memo_ext;
pub mod stats;
pub mod storage;
pub mod utils;

pub enum StatValue {
Int(i64),
Expand Down Expand Up @@ -63,7 +65,7 @@ pub trait CostModel: 'static + Send + Sync {
node: &PhysicalNodeType,
predicates: &[ArcPredicateNode],
children_stats: &[Option<&EstimatedStatistic>],
context: Option<ComputeCostContext>,
context: ComputeCostContext,
) -> CostModelResult<Cost>;

/// TODO: documentation
Expand All @@ -76,7 +78,7 @@ pub trait CostModel: 'static + Send + Sync {
node: PhysicalNodeType,
predicates: &[ArcPredicateNode],
children_statistics: &[Option<&EstimatedStatistic>],
context: Option<ComputeCostContext>,
context: ComputeCostContext,
) -> CostModelResult<EstimatedStatistic>;

/// TODO: documentation
Expand Down
Loading

0 comments on commit 9ba03e6

Please sign in to comment.