From 5a08b592a68763937bb9a0cdba3aaf938f1125f3 Mon Sep 17 00:00:00 2001 From: Connor Tsui Date: Wed, 6 Nov 2024 18:37:50 -0500 Subject: [PATCH] add memo trait This commit adds a `Memo` trait and a first draft of an implementation of the `Memo` trait via the backed ORM-mapped database. pushing entities --- optd-cost-model/src/cost/agg.rs | 1 + optd-persistent/src/cost_model/orm.rs | 40 ++-- optd-persistent/src/entities/predicate.rs | 60 +++++ .../src/entities/predicate_children.rs | 34 +++ .../predicate_logical_expression_junction.rs | 46 ++++ .../predicate_physical_expression_junction.rs | 46 ++++ optd-persistent/src/lib.rs | 43 +++- optd-persistent/src/main.rs | 11 + optd-persistent/src/memo/expression.rs | 73 ++++++ optd-persistent/src/memo/interface.rs | 141 ++++++++++++ optd-persistent/src/memo/mod.rs | 4 + optd-persistent/src/memo/orm.rs | 213 ++++++++++++++++++ .../memo/m20241029_000001_cascades_group.rs | 9 + .../m20241029_000001_logical_expression.rs | 2 - .../m20241029_000001_physical_expression.rs | 2 - .../memo/m20241029_000001_predicate.rs | 46 ++++ .../m20241029_000001_predicate_children.rs | 61 +++++ ...1_predicate_logical_expression_junction.rs | 72 ++++++ ..._predicate_physical_expression_junction.rs | 74 ++++++ optd-persistent/src/migrator/memo/mod.rs | 8 + optd-persistent/src/migrator/mod.rs | 4 + schema/all_tables.dbml | 33 ++- 22 files changed, 992 insertions(+), 31 deletions(-) create mode 100644 optd-persistent/src/entities/predicate.rs create mode 100644 optd-persistent/src/entities/predicate_children.rs create mode 100644 optd-persistent/src/entities/predicate_logical_expression_junction.rs create mode 100644 optd-persistent/src/entities/predicate_physical_expression_junction.rs create mode 100644 optd-persistent/src/memo/expression.rs create mode 100644 optd-persistent/src/memo/interface.rs create mode 100644 optd-persistent/src/memo/mod.rs create mode 100644 optd-persistent/src/memo/orm.rs create mode 100644 optd-persistent/src/migrator/memo/m20241029_000001_predicate.rs create mode 100644 optd-persistent/src/migrator/memo/m20241029_000001_predicate_children.rs create mode 100644 optd-persistent/src/migrator/memo/m20241029_000001_predicate_logical_expression_junction.rs create mode 100644 optd-persistent/src/migrator/memo/m20241029_000001_predicate_physical_expression_junction.rs diff --git a/optd-cost-model/src/cost/agg.rs b/optd-cost-model/src/cost/agg.rs index e69de29..8b13789 100644 --- a/optd-cost-model/src/cost/agg.rs +++ b/optd-cost-model/src/cost/agg.rs @@ -0,0 +1 @@ + diff --git a/optd-persistent/src/cost_model/orm.rs b/optd-persistent/src/cost_model/orm.rs index d172c14..5b56476 100644 --- a/optd-persistent/src/cost_model/orm.rs +++ b/optd-persistent/src/cost_model/orm.rs @@ -238,10 +238,13 @@ impl CostModelStorageLayer for BackendManager { match res { Ok(insert_res) => insert_res.last_insert_id, Err(_) => { - return Err(BackendError::BackendError(format!( - "failed to insert statistic {:?} into statistic table", - stat - ))) + return Err(BackendError::CostModel( + format!( + "failed to insert statistic {:?} into statistic table", + stat + ) + .into(), + )) } } } @@ -450,10 +453,13 @@ impl CostModelStorageLayer for BackendManager { .collect::>(); if attr_ids.len() != attr_base_indices.len() { - return Err(BackendError::BackendError(format!( - "Not all attributes found for table_id {} and base indices {:?}", - table_id, attr_base_indices - ))); + return Err(BackendError::CostModel( + format!( + "Not all attributes found for table_id {} and base indices {:?}", + table_id, attr_base_indices + ) + .into(), + )); } self.get_stats_for_attr(attr_ids, stat_type, epoch_id).await @@ -505,10 +511,13 @@ impl CostModelStorageLayer for BackendManager { .one(&self.db) .await?; if expr_exists.is_none() { - return Err(BackendError::BackendError(format!( - "physical expression id {} not found when storing cost", - physical_expression_id - ))); + return Err(BackendError::CostModel( + format!( + "physical expression id {} not found when storing cost", + physical_expression_id + ) + .into(), + )); } // Check if epoch_id exists in Event table @@ -518,10 +527,9 @@ impl CostModelStorageLayer for BackendManager { .await .unwrap(); if epoch_exists.is_none() { - return Err(BackendError::BackendError(format!( - "epoch id {} not found when storing cost", - epoch_id - ))); + return Err(BackendError::CostModel( + format!("epoch id {} not found when storing cost", epoch_id).into(), + )); } let new_cost = plan_cost::ActiveModel { diff --git a/optd-persistent/src/entities/predicate.rs b/optd-persistent/src/entities/predicate.rs new file mode 100644 index 0000000..e142a21 --- /dev/null +++ b/optd-persistent/src/entities/predicate.rs @@ -0,0 +1,60 @@ +//! `SeaORM` Entity, @generated by sea-orm-codegen 1.1.1 + +use sea_orm::entity::prelude::*; + +#[derive(Clone, Debug, PartialEq, DeriveEntityModel, Eq)] +#[sea_orm(table_name = "predicate")] +pub struct Model { + #[sea_orm(primary_key)] + pub id: i32, + pub data: Json, + pub variant: i32, +} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] +pub enum Relation { + #[sea_orm(has_many = "super::predicate_logical_expression_junction::Entity")] + PredicateLogicalExpressionJunction, + #[sea_orm(has_many = "super::predicate_physical_expression_junction::Entity")] + PredicatePhysicalExpressionJunction, +} + +impl Related for Entity { + fn to() -> RelationDef { + Relation::PredicateLogicalExpressionJunction.def() + } +} + +impl Related for Entity { + fn to() -> RelationDef { + Relation::PredicatePhysicalExpressionJunction.def() + } +} + +impl Related for Entity { + fn to() -> RelationDef { + super::predicate_logical_expression_junction::Relation::LogicalExpression.def() + } + fn via() -> Option { + Some( + super::predicate_logical_expression_junction::Relation::Predicate + .def() + .rev(), + ) + } +} + +impl Related for Entity { + fn to() -> RelationDef { + super::predicate_physical_expression_junction::Relation::PhysicalExpression.def() + } + fn via() -> Option { + Some( + super::predicate_physical_expression_junction::Relation::Predicate + .def() + .rev(), + ) + } +} + +impl ActiveModelBehavior for ActiveModel {} diff --git a/optd-persistent/src/entities/predicate_children.rs b/optd-persistent/src/entities/predicate_children.rs new file mode 100644 index 0000000..93ef3eb --- /dev/null +++ b/optd-persistent/src/entities/predicate_children.rs @@ -0,0 +1,34 @@ +//! `SeaORM` Entity, @generated by sea-orm-codegen 1.1.1 + +use sea_orm::entity::prelude::*; + +#[derive(Clone, Debug, PartialEq, DeriveEntityModel, Eq)] +#[sea_orm(table_name = "predicate_children")] +pub struct Model { + #[sea_orm(primary_key, auto_increment = false)] + pub parent_id: i32, + #[sea_orm(primary_key, auto_increment = false)] + pub child_id: i32, +} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] +pub enum Relation { + #[sea_orm( + belongs_to = "super::predicate::Entity", + from = "Column::ChildId", + to = "super::predicate::Column::Id", + on_update = "Cascade", + on_delete = "Cascade" + )] + Predicate2, + #[sea_orm( + belongs_to = "super::predicate::Entity", + from = "Column::ParentId", + to = "super::predicate::Column::Id", + on_update = "Cascade", + on_delete = "Cascade" + )] + Predicate1, +} + +impl ActiveModelBehavior for ActiveModel {} diff --git a/optd-persistent/src/entities/predicate_logical_expression_junction.rs b/optd-persistent/src/entities/predicate_logical_expression_junction.rs new file mode 100644 index 0000000..520da03 --- /dev/null +++ b/optd-persistent/src/entities/predicate_logical_expression_junction.rs @@ -0,0 +1,46 @@ +//! `SeaORM` Entity, @generated by sea-orm-codegen 1.1.1 + +use sea_orm::entity::prelude::*; + +#[derive(Clone, Debug, PartialEq, DeriveEntityModel, Eq)] +#[sea_orm(table_name = "predicate_logical_expression_junction")] +pub struct Model { + #[sea_orm(primary_key, auto_increment = false)] + pub logical_expr_id: i32, + #[sea_orm(primary_key, auto_increment = false)] + pub predicate_id: i32, +} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] +pub enum Relation { + #[sea_orm( + belongs_to = "super::logical_expression::Entity", + from = "Column::LogicalExprId", + to = "super::logical_expression::Column::Id", + on_update = "Cascade", + on_delete = "Cascade" + )] + LogicalExpression, + #[sea_orm( + belongs_to = "super::predicate::Entity", + from = "Column::PredicateId", + to = "super::predicate::Column::Id", + on_update = "Cascade", + on_delete = "Cascade" + )] + Predicate, +} + +impl Related for Entity { + fn to() -> RelationDef { + Relation::LogicalExpression.def() + } +} + +impl Related for Entity { + fn to() -> RelationDef { + Relation::Predicate.def() + } +} + +impl ActiveModelBehavior for ActiveModel {} diff --git a/optd-persistent/src/entities/predicate_physical_expression_junction.rs b/optd-persistent/src/entities/predicate_physical_expression_junction.rs new file mode 100644 index 0000000..263abc8 --- /dev/null +++ b/optd-persistent/src/entities/predicate_physical_expression_junction.rs @@ -0,0 +1,46 @@ +//! `SeaORM` Entity, @generated by sea-orm-codegen 1.1.1 + +use sea_orm::entity::prelude::*; + +#[derive(Clone, Debug, PartialEq, DeriveEntityModel, Eq)] +#[sea_orm(table_name = "predicate_physical_expression_junction")] +pub struct Model { + #[sea_orm(primary_key, auto_increment = false)] + pub physical_expr_id: i32, + #[sea_orm(primary_key, auto_increment = false)] + pub predicate_id: i32, +} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] +pub enum Relation { + #[sea_orm( + belongs_to = "super::physical_expression::Entity", + from = "Column::PhysicalExprId", + to = "super::physical_expression::Column::Id", + on_update = "Cascade", + on_delete = "Cascade" + )] + PhysicalExpression, + #[sea_orm( + belongs_to = "super::predicate::Entity", + from = "Column::PredicateId", + to = "super::predicate::Column::Id", + on_update = "Cascade", + on_delete = "Cascade" + )] + Predicate, +} + +impl Related for Entity { + fn to() -> RelationDef { + Relation::PhysicalExpression.def() + } +} + +impl Related for Entity { + fn to() -> RelationDef { + Relation::Predicate.def() + } +} + +impl ActiveModelBehavior for ActiveModel {} diff --git a/optd-persistent/src/lib.rs b/optd-persistent/src/lib.rs index 9bac1b6..dd95da2 100644 --- a/optd-persistent/src/lib.rs +++ b/optd-persistent/src/lib.rs @@ -13,6 +13,9 @@ mod migrator; pub mod cost_model; pub use cost_model::interface::CostModelStorageLayer; +mod memo; +pub use memo::interface::Memo; + /// The filename of the SQLite database for migration. pub const DATABASE_FILENAME: &str = "sqlite.db"; /// The URL of the SQLite database for migration. @@ -39,17 +42,48 @@ fn get_sqlite_url(file: &str) -> String { format!("sqlite:{}?mode=rwc", file) } -pub type StorageResult = Result; +#[derive(Debug)] +pub enum CostModelError { + // TODO: Add more error types + UnknownStatisticType, + VersionedStatisticNotFound, + CustomError(String), +} +/// TODO convert this to `thiserror` +#[derive(Debug)] +/// The different kinds of errors that might occur while running operations on a memo table. +pub enum MemoError { + UnknownGroup, + UnknownLogicalExpression, + UnknownPhysicalExpression, + InvalidExpression, +} + +/// TODO convert this to `thiserror` #[derive(Debug)] pub enum BackendError { + Memo(MemoError), DatabaseError(DbErr), + CostModel(CostModelError), BackendError(String), } -impl From for BackendError { +impl From for CostModelError { fn from(value: String) -> Self { - BackendError::BackendError(value) + CostModelError::CustomError(value) + } +} + +impl From for BackendError { + fn from(value: CostModelError) -> Self { + BackendError::CostModel(value) + } +} + +impl From for BackendError { + fn from(value: MemoError) -> Self { + BackendError::Memo(value) } } @@ -59,6 +93,9 @@ impl From for BackendError { } } +/// A type alias for a result with [`BackendError`] as the error type. +pub type StorageResult = Result; + pub struct BackendManager { db: DatabaseConnection, } diff --git a/optd-persistent/src/main.rs b/optd-persistent/src/main.rs index e05fb59..165189e 100644 --- a/optd-persistent/src/main.rs +++ b/optd-persistent/src/main.rs @@ -17,6 +17,17 @@ use optd_persistent::DATABASE_URL; #[tokio::main] async fn main() { + basic_demo().await; + memo_demo().await; +} + +async fn memo_demo() { + let _db = Database::connect(DATABASE_URL).await.unwrap(); + + todo!() +} + +async fn basic_demo() { let db = Database::connect(DATABASE_URL).await.unwrap(); // Create a new `CascadesGroup`. diff --git a/optd-persistent/src/memo/expression.rs b/optd-persistent/src/memo/expression.rs new file mode 100644 index 0000000..ff1590c --- /dev/null +++ b/optd-persistent/src/memo/expression.rs @@ -0,0 +1,73 @@ +use crate::entities::*; +use std::hash::{DefaultHasher, Hash, Hasher}; + +/// All of the different types of fixed logical operators. +/// +/// Note that there could be more operators that the memo table must support that are not enumerated +/// in this enum, as there can be up to `2^16` different types of operators. +#[derive(Debug, Clone, Copy, Hash, Eq, PartialEq)] +#[non_exhaustive] +#[repr(i16)] +pub enum LogicalOperator { + Scan, + Join, +} + +/// All of the different types of fixed physical operators. +/// +/// Note that there could be more operators that the memo table must support that are not enumerated +/// in this enum, as there can be up to `2^16` different types of operators. +#[derive(Debug, Clone, Copy, Hash, Eq, PartialEq)] +#[non_exhaustive] +#[repr(i16)] +pub enum PhysicalOperator { + TableScan, + IndexScan, + NestedLoopJoin, + HashJoin, +} + +/// A method to generate a fingerprint used to efficiently check if two +/// expressions are equivalent. +/// +/// TODO actually make efficient. +fn fingerprint(variant_tag: i16, data: &serde_json::Value) -> i64 { + let mut hasher = DefaultHasher::new(); + + variant_tag.hash(&mut hasher); + data.hash(&mut hasher); + + hasher.finish() as i64 +} + +impl logical_expression::Model { + /// Creates a new logical expression with an unset `id` and `group_id`. + pub fn new(variant_tag: LogicalOperator, data: serde_json::Value) -> Self { + let tag = variant_tag as i16; + let fingerprint = fingerprint(tag, &data); + + Self { + id: 0, + group_id: 0, + fingerprint, + variant_tag: tag, + data, + } + } +} + +impl physical_expression::Model { + /// Creates a new physical expression with an unset `id` and `group_id`. + pub fn new(variant_tag: PhysicalOperator, data: serde_json::Value) -> Self { + let tag = variant_tag as i16; + let fingerprint = fingerprint(tag, &data); + + Self { + id: 0, + group_id: 0, + fingerprint, + variant_tag: tag, + data, + } + } +} diff --git a/optd-persistent/src/memo/interface.rs b/optd-persistent/src/memo/interface.rs new file mode 100644 index 0000000..2d2240e --- /dev/null +++ b/optd-persistent/src/memo/interface.rs @@ -0,0 +1,141 @@ +use crate::StorageResult; + +/// A trait representing an implementation of a memoization table. +/// +/// Note that we use [`trait_variant`] here in order to add bounds on every method. +/// See this [blog post]( +/// https://blog.rust-lang.org/2023/12/21/async-fn-rpit-in-traits.html#async-fn-in-public-traits) +/// for more information. +/// +/// TODO Figure out for each when to get the ID of a record or the entire record itself. +#[trait_variant::make(Send)] +pub trait Memo { + /// A type representing a group in the Cascades framework. + type Group; + /// A type representing a unique identifier for a group. + type GroupId; + /// A type representing a logical expression. + type LogicalExpression; + /// A type representing a unique identifier for a logical expression. + type LogicalExpressionId; + /// A type representing a physical expression. + type PhysicalExpression; + /// A type representing a unique identifier for a physical expression. + type PhysicalExpressionId; + + /// Retrieves a [`Self::Group`] given a [`Self::GroupId`]. + /// + /// If the group does not exist, returns a [`MemoError::UnknownGroup`] error. + async fn get_group(&self, group_id: Self::GroupId) -> StorageResult; + + /// Retrieves all group IDs that are stored in the memo table. + async fn get_all_groups(&self) -> StorageResult>; + + /// Retrieves a [`Self::LogicalExpression`] given a [`Self::LogicalExpressionId`]. + /// + /// If the logical expression does not exist, returns a [`MemoError::UnknownLogicalExpression`] + /// error. + async fn get_logical_expression( + &self, + logical_expression_id: Self::LogicalExpressionId, + ) -> StorageResult; + + /// Retrieves a [`Self::PhysicalExpression`] given a [`Self::PhysicalExpressionId`]. + /// + /// If the physical expression does not exist, returns a + /// [`MemoError::UnknownPhysicalExpression`] error. + async fn get_physical_expression( + &self, + physical_expression_id: Self::PhysicalExpressionId, + ) -> StorageResult; + + /// Retrieves the parent group ID of a logical expression given its expression ID. + /// + /// If the logical expression does not exist, returns a [`MemoError::UnknownLogicalExpression`] + /// error. + async fn get_group_from_logical_expression( + &self, + logical_expression_id: Self::LogicalExpressionId, + ) -> StorageResult; + + /// Retrieves the parent group ID of a logical expression given its expression ID. + /// + /// If the physical expression does not exist, returns a + /// [`MemoError::UnknownPhysicalExpression`] error. + async fn get_group_from_physical_expression( + &self, + physical_expression_id: Self::PhysicalExpressionId, + ) -> StorageResult; + + /// Retrieves all of the logical expression "children" of a group. + /// + /// If the group does not exist, returns a [`MemoError::UnknownGroup`] error. + async fn get_group_logical_expressions( + &self, + group_id: Self::GroupId, + ) -> StorageResult>; + + /// Retrieves all of the physical expression "children" of a group. + /// + /// If the group does not exist, returns a [`MemoError::UnknownGroup`] error. + async fn get_group_physical_expressions( + &self, + group_id: Self::GroupId, + ) -> StorageResult>; + + /// Retrieves the best physical query plan (winner) for a given group. + /// + /// If the group does not exist, returns a [`MemoError::UnknownGroup`] error. + async fn get_winner( + &self, + group_id: Self::GroupId, + ) -> StorageResult>; + + /// Updates / replaces a group's best physical plan (winner). Optionally returns the previous + /// winner's physical expression ID. + /// + /// If the group does not exist, returns a [`MemoError::UnknownGroup`] error. + async fn update_group_winner( + &self, + group_id: Self::GroupId, + physical_expression_id: Self::PhysicalExpressionId, + ) -> StorageResult>; + + /// Adds a logical expression to an existing group via its [`Self::GroupId`]. + /// + /// If the group does not exist, returns a [`MemoError::UnknownGroup`] error. + async fn add_logical_expression_to_group( + &self, + group_id: Self::GroupId, + logical_expression: Self::LogicalExpression, + children: Vec, + ) -> StorageResult<()>; + + /// Adds a physical expression to an existing group via its [`Self::GroupId`]. + /// + /// If the group does not exist, returns a [`MemoError::UnknownGroup`] error. + async fn add_physical_expression_to_group( + &self, + group_id: Self::GroupId, + physical_expression: Self::PhysicalExpression, + children: Vec, + ) -> StorageResult<()>; + + /// Adds a new logical expression into the memo table, creating a new group if the expression + /// does not already exist. + /// + /// The [`Self::LogicalExpression`] type should have some sort of mechanism for checking if + /// the expression has been seen before, and if it has already been created, then the parent + /// group ID should also be retrievable. + /// + /// If the expression already exists, then this function will return the [`Self::GroupId`] of + /// the parent group and the corresponding (already existing) [`Self::LogicalExpressionId`]. + /// + /// If the expression does not exist, this function will create a new group and a new + /// expression, returning brand new IDs for both. + async fn add_logical_expression( + &self, + expression: Self::LogicalExpression, + children: Vec, + ) -> StorageResult<(Self::GroupId, Self::LogicalExpressionId)>; +} diff --git a/optd-persistent/src/memo/mod.rs b/optd-persistent/src/memo/mod.rs new file mode 100644 index 0000000..c455782 --- /dev/null +++ b/optd-persistent/src/memo/mod.rs @@ -0,0 +1,4 @@ +mod expression; + +pub mod interface; +pub mod orm; diff --git a/optd-persistent/src/memo/orm.rs b/optd-persistent/src/memo/orm.rs new file mode 100644 index 0000000..6d4ec31 --- /dev/null +++ b/optd-persistent/src/memo/orm.rs @@ -0,0 +1,213 @@ +use crate::{ + entities::{prelude::*, *}, + BackendManager, {Memo, MemoError, StorageResult}, +}; +use sea_orm::*; + +impl Memo for BackendManager { + type Group = cascades_group::Model; + type GroupId = i32; + type LogicalExpression = logical_expression::Model; + type LogicalExpressionId = i32; + type PhysicalExpression = physical_expression::Model; + type PhysicalExpressionId = i32; + + async fn get_group(&self, group_id: Self::GroupId) -> StorageResult { + Ok(CascadesGroup::find_by_id(group_id) + .one(&self.db) + .await? + .ok_or(MemoError::UnknownGroup)?) + } + + async fn get_all_groups(&self) -> StorageResult> { + Ok(CascadesGroup::find().all(&self.db).await?) + } + + async fn get_logical_expression( + &self, + logical_expression_id: Self::LogicalExpressionId, + ) -> StorageResult { + Ok(LogicalExpression::find_by_id(logical_expression_id) + .one(&self.db) + .await? + .ok_or(MemoError::UnknownLogicalExpression)?) + } + + async fn get_physical_expression( + &self, + physical_expression_id: Self::PhysicalExpressionId, + ) -> StorageResult { + Ok(PhysicalExpression::find_by_id(physical_expression_id) + .one(&self.db) + .await? + .ok_or(MemoError::UnknownPhysicalExpression)?) + } + + async fn get_group_from_logical_expression( + &self, + logical_expression_id: Self::LogicalExpressionId, + ) -> StorageResult { + // Find the logical expression and then look up the field. + Ok(self + .get_logical_expression(logical_expression_id) + .await? + .group_id) + } + + async fn get_group_from_physical_expression( + &self, + physical_expression_id: Self::PhysicalExpressionId, + ) -> StorageResult { + Ok(self + .get_physical_expression(physical_expression_id) + .await? + .group_id) + } + + async fn get_group_logical_expressions( + &self, + group_id: Self::GroupId, + ) -> StorageResult> { + // First retrieve the group record, and then find all related logical expressions. + Ok(self + .get_group(group_id) + .await? + .find_related(LogicalExpression) + .all(&self.db) + .await?) + } + + async fn get_group_physical_expressions( + &self, + group_id: Self::GroupId, + ) -> StorageResult> { + // First retrieve the group record, and then find all related physical expressions. + Ok(self + .get_group(group_id) + .await? + .find_related(PhysicalExpression) + .all(&self.db) + .await?) + } + + async fn get_winner( + &self, + group_id: Self::GroupId, + ) -> StorageResult> { + Ok(self.get_group(group_id).await?.latest_winner) + } + + async fn update_group_winner( + &self, + group_id: Self::GroupId, + physical_expression_id: Self::PhysicalExpressionId, + ) -> StorageResult> { + // First retrieve the group record, and then use an `ActiveModel` to update it. + let mut group = self.get_group(group_id).await?.into_active_model(); + let old_id = group.latest_winner; + + group.latest_winner = Set(Some(physical_expression_id)); + group.update(&self.db).await?; + + // The old value must be set (`None` still means it has been set). + let old = old_id.unwrap(); + Ok(old) + } + + async fn add_logical_expression_to_group( + &self, + group_id: Self::GroupId, + logical_expression: Self::LogicalExpression, + _children: Vec, + ) -> StorageResult<()> { + if logical_expression.group_id != group_id { + Err(MemoError::InvalidExpression)? + } + + // Check if the group actually exists. + let _ = self.get_group(group_id).await?; + + // Insert the expression. + let _ = logical_expression + .into_active_model() + .insert(&self.db) + .await?; + + todo!("add the children of the logical expression into the children table") + } + + async fn add_physical_expression_to_group( + &self, + group_id: Self::GroupId, + physical_expression: Self::PhysicalExpression, + _children: Vec, + ) -> StorageResult<()> { + if physical_expression.group_id != group_id { + Err(MemoError::InvalidExpression)? + } + + // Check if the group actually exists. + let _ = self.get_group(group_id).await?; + + // Insert the expression. + let _ = physical_expression + .into_active_model() + .insert(&self.db) + .await?; + + todo!("add the children of the logical expression into the children table") + } + + /// Note that in this function, we ignore the group ID that the logical expression contains. + async fn add_logical_expression( + &self, + expression: Self::LogicalExpression, + _children: Vec, + ) -> StorageResult<(Self::GroupId, Self::LogicalExpressionId)> { + // Lookup all expressions that have the same fingerprint. There may be false positives, but + // we will check for those later. + let fingerprint = expression.fingerprint; + let potential_matches = LogicalExpression::find() + .filter(logical_expression::Column::Fingerprint.eq(fingerprint)) + .all(&self.db) + .await?; + + // Of the expressions that have the same fingerprint, check if there already exists an + // expression that is exactly identical to the input expression. + let mut matches: Vec<_> = potential_matches + .into_iter() + .filter(|expr| expr == &expression) + .collect(); + assert!( + matches.len() <= 1, + "there cannot be more than 1 exact logical expression match" + ); + + // The expression already exists, so return its data. + if !matches.is_empty() { + let existing_expression = matches + .pop() + .expect("we just checked that an element exists"); + + return Ok((existing_expression.group_id, existing_expression.id)); + } + + // The expression does not exist yet, so we need to create a new group and new expression. + let group = cascades_group::ActiveModel { + latest_winner: Set(None), + in_progress: Set(false), + is_optimized: Set(false), + ..Default::default() + }; + + // Insert a new group. + let res = cascades_group::Entity::insert(group).exec(&self.db).await?; + + // Insert the input expression with the correct `group_id`. + let mut new_expr = expression.into_active_model(); + new_expr.group_id = Set(res.last_insert_id); + let new_expr = new_expr.insert(&self.db).await?; + + Ok((new_expr.group_id, new_expr.id)) + } +} diff --git a/optd-persistent/src/migrator/memo/m20241029_000001_cascades_group.rs b/optd-persistent/src/migrator/memo/m20241029_000001_cascades_group.rs index fc7b198..1700772 100644 --- a/optd-persistent/src/migrator/memo/m20241029_000001_cascades_group.rs +++ b/optd-persistent/src/migrator/memo/m20241029_000001_cascades_group.rs @@ -75,6 +75,7 @@ pub enum CascadesGroup { LatestWinner, InProgress, IsOptimized, + ParentId, } #[derive(DeriveMigrationName)] @@ -99,6 +100,14 @@ impl MigrationTrait for Migration { ) .col(boolean(CascadesGroup::InProgress)) .col(boolean(CascadesGroup::IsOptimized)) + .col(integer_null(CascadesGroup::ParentId)) + .foreign_key( + ForeignKey::create() + .from(CascadesGroup::Table, CascadesGroup::ParentId) + .to(CascadesGroup::Table, CascadesGroup::Id) + .on_delete(ForeignKeyAction::SetNull) + .on_update(ForeignKeyAction::Cascade), + ) .to_owned(), ) .await diff --git a/optd-persistent/src/migrator/memo/m20241029_000001_logical_expression.rs b/optd-persistent/src/migrator/memo/m20241029_000001_logical_expression.rs index a8369b0..71e28d3 100644 --- a/optd-persistent/src/migrator/memo/m20241029_000001_logical_expression.rs +++ b/optd-persistent/src/migrator/memo/m20241029_000001_logical_expression.rs @@ -45,7 +45,6 @@ pub enum LogicalExpression { GroupId, Fingerprint, VariantTag, - Data, } #[derive(DeriveMigrationName)] @@ -70,7 +69,6 @@ impl MigrationTrait for Migration { ) .col(big_unsigned(LogicalExpression::Fingerprint)) .col(small_integer(LogicalExpression::VariantTag)) - .col(json(LogicalExpression::Data)) .to_owned(), ) .await diff --git a/optd-persistent/src/migrator/memo/m20241029_000001_physical_expression.rs b/optd-persistent/src/migrator/memo/m20241029_000001_physical_expression.rs index dc9a3ab..8f7cb96 100644 --- a/optd-persistent/src/migrator/memo/m20241029_000001_physical_expression.rs +++ b/optd-persistent/src/migrator/memo/m20241029_000001_physical_expression.rs @@ -46,7 +46,6 @@ pub enum PhysicalExpression { GroupId, Fingerprint, VariantTag, - Data, } #[derive(DeriveMigrationName)] @@ -71,7 +70,6 @@ impl MigrationTrait for Migration { ) .col(big_unsigned(PhysicalExpression::Fingerprint)) .col(small_integer(PhysicalExpression::VariantTag)) - .col(json(PhysicalExpression::Data)) .to_owned(), ) .await diff --git a/optd-persistent/src/migrator/memo/m20241029_000001_predicate.rs b/optd-persistent/src/migrator/memo/m20241029_000001_predicate.rs new file mode 100644 index 0000000..b4a58b9 --- /dev/null +++ b/optd-persistent/src/migrator/memo/m20241029_000001_predicate.rs @@ -0,0 +1,46 @@ +/* +Table predicate { + id integer [pk] + data json + variant integer +} +*/ + +use sea_orm_migration::{ + prelude::*, + schema::{integer, json, pk_auto}, +}; + +#[derive(Iden)] +pub enum Predicate { + Table, + Id, + Data, + Variant, +} + +#[derive(DeriveMigrationName)] +pub struct Migration; + +#[async_trait::async_trait] +impl MigrationTrait for Migration { + async fn up(&self, manager: &SchemaManager) -> Result<(), DbErr> { + manager + .create_table( + Table::create() + .table(Predicate::Table) + .if_not_exists() + .col(pk_auto(Predicate::Id)) + .col(json(Predicate::Data)) + .col(integer(Predicate::Variant)) + .to_owned(), + ) + .await + } + + async fn down(&self, manager: &SchemaManager) -> Result<(), DbErr> { + manager + .drop_table(Table::drop().table(Predicate::Table).to_owned()) + .await + } +} diff --git a/optd-persistent/src/migrator/memo/m20241029_000001_predicate_children.rs b/optd-persistent/src/migrator/memo/m20241029_000001_predicate_children.rs new file mode 100644 index 0000000..f3c0d10 --- /dev/null +++ b/optd-persistent/src/migrator/memo/m20241029_000001_predicate_children.rs @@ -0,0 +1,61 @@ +/* +Table predicate_children { + parent_id integer [ref: > predicate.id] + child_id integer [ref: > predicate.id] +} + */ + +use sea_orm_migration::{prelude::*, schema::integer}; + +use super::m20241029_000001_predicate::Predicate; + +#[derive(Iden)] +pub enum PredicateChildren { + Table, + ParentId, + ChildId, +} + +#[derive(DeriveMigrationName)] +pub struct Migration; + +#[async_trait::async_trait] +impl MigrationTrait for Migration { + async fn up(&self, manager: &SchemaManager) -> Result<(), DbErr> { + manager + .create_table( + Table::create() + .table(PredicateChildren::Table) + .if_not_exists() + .col(integer(PredicateChildren::ParentId)) + .foreign_key( + ForeignKey::create() + .from(PredicateChildren::Table, PredicateChildren::ParentId) + .to(Predicate::Table, Predicate::Id) + .on_delete(ForeignKeyAction::Cascade) + .on_update(ForeignKeyAction::Cascade), + ) + .col(integer(PredicateChildren::ChildId)) + .foreign_key( + ForeignKey::create() + .from(PredicateChildren::Table, PredicateChildren::ChildId) + .to(Predicate::Table, Predicate::Id) + .on_delete(ForeignKeyAction::Cascade) + .on_update(ForeignKeyAction::Cascade), + ) + .primary_key( + Index::create() + .col(PredicateChildren::ParentId) + .col(PredicateChildren::ChildId), + ) + .to_owned(), + ) + .await + } + + async fn down(&self, manager: &SchemaManager) -> Result<(), DbErr> { + manager + .drop_table(Table::drop().table(PredicateChildren::Table).to_owned()) + .await + } +} diff --git a/optd-persistent/src/migrator/memo/m20241029_000001_predicate_logical_expression_junction.rs b/optd-persistent/src/migrator/memo/m20241029_000001_predicate_logical_expression_junction.rs new file mode 100644 index 0000000..ab901d6 --- /dev/null +++ b/optd-persistent/src/migrator/memo/m20241029_000001_predicate_logical_expression_junction.rs @@ -0,0 +1,72 @@ +/* +Table predicate_logical_expression_junction { + logical_expr_id integer [ref: > logical_expression.id] + predicate_id integer [ref: > predicate.id] +} + */ + +use sea_orm_migration::{prelude::*, schema::integer}; + +use super::{ + m20241029_000001_logical_expression::LogicalExpression, m20241029_000001_predicate::Predicate, +}; + +#[derive(Iden)] +pub enum PredicateLogicalExpressionJunction { + Table, + LogicalExprId, + PredicateId, +} + +#[derive(DeriveMigrationName)] +pub struct Migration; + +#[async_trait::async_trait] +impl MigrationTrait for Migration { + async fn up(&self, manager: &SchemaManager) -> Result<(), DbErr> { + manager + .create_table( + Table::create() + .table(PredicateLogicalExpressionJunction::Table) + .col(integer(PredicateLogicalExpressionJunction::LogicalExprId)) + .foreign_key( + ForeignKey::create() + .from( + PredicateLogicalExpressionJunction::Table, + PredicateLogicalExpressionJunction::LogicalExprId, + ) + .to(LogicalExpression::Table, LogicalExpression::Id) + .on_delete(ForeignKeyAction::Cascade) + .on_update(ForeignKeyAction::Cascade), + ) + .col(integer(PredicateLogicalExpressionJunction::PredicateId)) + .foreign_key( + ForeignKey::create() + .from( + PredicateLogicalExpressionJunction::Table, + PredicateLogicalExpressionJunction::PredicateId, + ) + .to(Predicate::Table, Predicate::Id) + .on_delete(ForeignKeyAction::Cascade) + .on_update(ForeignKeyAction::Cascade), + ) + .primary_key( + Index::create() + .col(PredicateLogicalExpressionJunction::LogicalExprId) + .col(PredicateLogicalExpressionJunction::PredicateId), + ) + .to_owned(), + ) + .await + } + + async fn down(&self, manager: &SchemaManager) -> Result<(), DbErr> { + manager + .drop_table( + Table::drop() + .table(PredicateLogicalExpressionJunction::Table) + .to_owned(), + ) + .await + } +} diff --git a/optd-persistent/src/migrator/memo/m20241029_000001_predicate_physical_expression_junction.rs b/optd-persistent/src/migrator/memo/m20241029_000001_predicate_physical_expression_junction.rs new file mode 100644 index 0000000..8d82831 --- /dev/null +++ b/optd-persistent/src/migrator/memo/m20241029_000001_predicate_physical_expression_junction.rs @@ -0,0 +1,74 @@ +/* +Table predicate_physical_expression_junction { + physical_expr_id integer [ref: > physical_expression.id] + predicate_id integer [ref: > predicate.id] +} + */ + +use sea_orm_migration::{prelude::*, schema::integer}; + +use super::{ + m20241029_000001_physical_expression::PhysicalExpression, m20241029_000001_predicate::Predicate, +}; + +#[derive(Iden)] +pub enum PredicatePhysicalExpressionJunction { + Table, + PhysicalExprId, + PredicateId, +} + +#[derive(DeriveMigrationName)] +pub struct Migration; + +#[async_trait::async_trait] +impl MigrationTrait for Migration { + async fn up(&self, manager: &SchemaManager) -> Result<(), DbErr> { + manager + .create_table( + Table::create() + .table(PredicatePhysicalExpressionJunction::Table) + .col(integer(PredicatePhysicalExpressionJunction::PhysicalExprId)) + .foreign_key( + ForeignKey::create() + .name("predicate_physical_expression_junction_physical_expr_id_fkey") + .from( + PredicatePhysicalExpressionJunction::Table, + PredicatePhysicalExpressionJunction::PhysicalExprId, + ) + .to(PhysicalExpression::Table, PhysicalExpression::Id) + .on_delete(ForeignKeyAction::Cascade) + .on_update(ForeignKeyAction::Cascade), + ) + .col(integer(PredicatePhysicalExpressionJunction::PredicateId)) + .foreign_key( + ForeignKey::create() + .name("predicate_physical_expression_junction_predicate_id_fkey") + .from( + PredicatePhysicalExpressionJunction::Table, + PredicatePhysicalExpressionJunction::PredicateId, + ) + .to(Predicate::Table, Predicate::Id) + .on_delete(ForeignKeyAction::Cascade) + .on_update(ForeignKeyAction::Cascade), + ) + .primary_key( + Index::create() + .col(PredicatePhysicalExpressionJunction::PhysicalExprId) + .col(PredicatePhysicalExpressionJunction::PredicateId), + ) + .to_owned(), + ) + .await + } + + async fn down(&self, manager: &SchemaManager) -> Result<(), DbErr> { + manager + .drop_table( + Table::drop() + .table(PredicatePhysicalExpressionJunction::Table) + .to_owned(), + ) + .await + } +} diff --git a/optd-persistent/src/migrator/memo/mod.rs b/optd-persistent/src/migrator/memo/mod.rs index f15b437..b4caabf 100644 --- a/optd-persistent/src/migrator/memo/mod.rs +++ b/optd-persistent/src/migrator/memo/mod.rs @@ -9,6 +9,10 @@ pub(crate) mod m20241029_000001_logical_property; pub(crate) mod m20241029_000001_physical_children; pub(crate) mod m20241029_000001_physical_expression; pub(crate) mod m20241029_000001_physical_property; +pub(crate) mod m20241029_000001_predicate; +pub(crate) mod m20241029_000001_predicate_children; +pub(crate) mod m20241029_000001_predicate_logical_expression_junction; +pub(crate) mod m20241029_000001_predicate_physical_expression_junction; pub(crate) use m20241029_000001_cascades_group as cascades_group; pub(crate) use m20241029_000001_group_winner as group_winner; @@ -18,3 +22,7 @@ pub(crate) use m20241029_000001_logical_property as logical_property; pub(crate) use m20241029_000001_physical_children as physical_children; pub(crate) use m20241029_000001_physical_expression as physical_expression; pub(crate) use m20241029_000001_physical_property as physical_property; +pub(crate) use m20241029_000001_predicate as predicate; +pub(crate) use m20241029_000001_predicate_children as predicate_children; +pub(crate) use m20241029_000001_predicate_logical_expression_junction as predicate_logical_expression_junction; +pub(crate) use m20241029_000001_predicate_physical_expression_junction as predicate_physical_expression_junction; diff --git a/optd-persistent/src/migrator/mod.rs b/optd-persistent/src/migrator/mod.rs index 2571143..468f7d0 100644 --- a/optd-persistent/src/migrator/mod.rs +++ b/optd-persistent/src/migrator/mod.rs @@ -33,6 +33,10 @@ impl MigratorTrait for Migrator { Box::new(memo::physical_expression::Migration), Box::new(memo::physical_children::Migration), Box::new(memo::physical_property::Migration), + Box::new(memo::predicate::Migration), + Box::new(memo::predicate_children::Migration), + Box::new(memo::predicate_logical_expression_junction::Migration), + Box::new(memo::predicate_physical_expression_junction::Migration), ] } } diff --git a/schema/all_tables.dbml b/schema/all_tables.dbml index 091115d..305075a 100644 --- a/schema/all_tables.dbml +++ b/schema/all_tables.dbml @@ -137,7 +137,7 @@ Table logical_expression { group_id integer [ref: > cascades_group.id] fingerprint integer variant_tag integer - data json + predicate integer [ref: > predicate.id] } Table cascades_group { @@ -145,6 +145,7 @@ Table cascades_group { latest_winner integer [ref: > physical_expression.id, null] in_progress boolean is_optimized boolean + parent integer [ref: > cascades_group.id] } // Physical expressions and properties @@ -153,7 +154,7 @@ Table physical_expression { group_id integer [ref: > cascades_group.id] fingerprint integer variant_tag integer - data json + predicate integer [ref: > predicate.id] } Table physical_property { @@ -187,13 +188,29 @@ Table group_winner { id integer [pk] group_id integer [ref: > cascades_group.id] physical_expression_id integer [ref: > physical_expression.id] - cost_id integer [ref: > plan_cost.id] + cost integer epoch_id integer [ref: > event.epoch_id] } +Table predicate { + id integer [pk] + data json + variant integer +} + +Table predicate_children { + parent_id integer [ref: > predicate.id] + child_id integer [ref: > predicate.id] +} + +Table predicate_logical_expression_junction { + logical_expr_id integer [ref: > logical_expression.id] + predicate_id integer [ref: > predicate.id] +} + +Table predicate_physical_expression_junction { + physical_expr_id integer [ref: > physical_expression.id] + predicate_id integer [ref: > predicate.id] +} + -// Notes: -// - All columns are NOT NULL unless specified otherwise -// - fingerprint represents a hash of the actual data in the logical and physical expression tuple -// - Each new event inserted into events table has its own epoch_id -// - `cascades_group.latest_winner` is an optimization to avoid querying entire group_winner table