diff --git a/optd-persistent/src/lib.rs b/optd-persistent/src/lib.rs index 2638940..340a07a 100644 --- a/optd-persistent/src/lib.rs +++ b/optd-persistent/src/lib.rs @@ -13,6 +13,9 @@ mod migrator; pub mod cost_model; pub use cost_model::interface::CostModelStorageLayer; +mod memo; +pub use memo::interface::Memo; + /// The filename of the SQLite database for migration. pub const DATABASE_FILENAME: &str = "sqlite.db"; /// The URL of the SQLite database for migration. @@ -39,8 +42,6 @@ fn get_sqlite_url(file: &str) -> String { format!("sqlite:{}?mode=rwc", file) } -pub type StorageResult = Result; - #[derive(Debug)] pub enum CostModelError { // TODO: Add more error types @@ -48,9 +49,22 @@ pub enum CostModelError { VersionedStatisticNotFound, } +/// TODO convert this to `thiserror` +#[derive(Debug)] +/// The different kinds of errors that might occur while running operations on a memo table. +pub enum MemoError { + UnknownGroup, + UnknownLogicalExpression, + UnknownPhysicalExpression, + InvalidExpression, + Database(DbErr), +} + +/// TODO convert this to `thiserror` #[derive(Debug)] pub enum BackendError { CostModel(CostModelError), + Memo(MemoError), Database(DbErr), // TODO: Add other variants as needed for different error types } @@ -61,12 +75,27 @@ impl From for BackendError { } } +impl From for BackendError { + fn from(value: MemoError) -> Self { + BackendError::Memo(value) + } +} + impl From for BackendError { fn from(value: DbErr) -> Self { BackendError::Database(value) } } +impl From for MemoError { + fn from(value: DbErr) -> Self { + MemoError::Database(value) + } +} + +/// A type alias for a result with [`BackendError`] as the error type. 
+pub type StorageResult = Result; + pub struct BackendManager { db: DatabaseConnection, } diff --git a/optd-persistent/src/main.rs b/optd-persistent/src/main.rs index e05fb59..165189e 100644 --- a/optd-persistent/src/main.rs +++ b/optd-persistent/src/main.rs @@ -17,6 +17,17 @@ use optd_persistent::DATABASE_URL; #[tokio::main] async fn main() { + basic_demo().await; + memo_demo().await; +} + +async fn memo_demo() { + let _db = Database::connect(DATABASE_URL).await.unwrap(); + + todo!() +} + +async fn basic_demo() { let db = Database::connect(DATABASE_URL).await.unwrap(); // Create a new `CascadesGroup`. diff --git a/optd-persistent/src/memo/expression.rs b/optd-persistent/src/memo/expression.rs new file mode 100644 index 0000000..ff1590c --- /dev/null +++ b/optd-persistent/src/memo/expression.rs @@ -0,0 +1,73 @@ +use crate::entities::*; +use std::hash::{DefaultHasher, Hash, Hasher}; + +/// All of the different types of fixed logical operators. +/// +/// Note that there could be more operators that the memo table must support that are not enumerated +/// in this enum, as there can be up to `2^16` different types of operators. +#[derive(Debug, Clone, Copy, Hash, Eq, PartialEq)] +#[non_exhaustive] +#[repr(i16)] +pub enum LogicalOperator { + Scan, + Join, +} + +/// All of the different types of fixed physical operators. +/// +/// Note that there could be more operators that the memo table must support that are not enumerated +/// in this enum, as there can be up to `2^16` different types of operators. +#[derive(Debug, Clone, Copy, Hash, Eq, PartialEq)] +#[non_exhaustive] +#[repr(i16)] +pub enum PhysicalOperator { + TableScan, + IndexScan, + NestedLoopJoin, + HashJoin, +} + +/// A method to generate a fingerprint used to efficiently check if two +/// expressions are equivalent. +/// +/// TODO actually make efficient. 
+fn fingerprint(variant_tag: i16, data: &serde_json::Value) -> i64 { + let mut hasher = DefaultHasher::new(); + + variant_tag.hash(&mut hasher); + data.hash(&mut hasher); + + hasher.finish() as i64 +} + +impl logical_expression::Model { + /// Creates a new logical expression with an unset `id` and `group_id`. + pub fn new(variant_tag: LogicalOperator, data: serde_json::Value) -> Self { + let tag = variant_tag as i16; + let fingerprint = fingerprint(tag, &data); + + Self { + id: 0, + group_id: 0, + fingerprint, + variant_tag: tag, + data, + } + } +} + +impl physical_expression::Model { + /// Creates a new physical expression with an unset `id` and `group_id`. + pub fn new(variant_tag: PhysicalOperator, data: serde_json::Value) -> Self { + let tag = variant_tag as i16; + let fingerprint = fingerprint(tag, &data); + + Self { + id: 0, + group_id: 0, + fingerprint, + variant_tag: tag, + data, + } + } +} diff --git a/optd-persistent/src/memo/interface.rs b/optd-persistent/src/memo/interface.rs new file mode 100644 index 0000000..b9451f9 --- /dev/null +++ b/optd-persistent/src/memo/interface.rs @@ -0,0 +1,138 @@ +use crate::StorageResult; + +/// A trait representing an implementation of a memoization table. +/// +/// Note that we use [`trait_variant`] here in order to add bounds on every method. +/// See this [blog post]( +/// https://blog.rust-lang.org/2023/12/21/async-fn-rpit-in-traits.html#async-fn-in-public-traits) +/// for more information. +/// +/// TODO Figure out for each when to get the ID of a record or the entire record itself. +#[trait_variant::make(Send)] +pub trait Memo { + /// A type representing a group in the Cascades framework. + type Group; + /// A type representing a unique identifier for a group. + type GroupId; + /// A type representing a logical expression. + type LogicalExpression; + /// A type representing a unique identifier for a logical expression. + type LogicalExpressionId; + /// A type representing a physical expression. 
+ type PhysicalExpression; + /// A type representing a unique identifier for a physical expression. + type PhysicalExpressionId; + + /// Retrieves a [`Self::Group`] given a [`Self::GroupId`]. + /// + /// If the group does not exist, returns a [`MemoError::UnknownGroup`] error. + async fn get_group(&self, group_id: Self::GroupId) -> StorageResult; + + /// Retrieves all group IDs that are stored in the memo table. + async fn get_all_groups(&self) -> StorageResult>; + + /// Retrieves a [`Self::LogicalExpression`] given a [`Self::LogicalExpressionId`]. + /// + /// If the logical expression does not exist, returns a [`MemoError::UnknownLogicalExpression`] + /// error. + async fn get_logical_expression( + &self, + logical_expression_id: Self::LogicalExpressionId, + ) -> StorageResult; + + /// Retrieves a [`Self::PhysicalExpression`] given a [`Self::PhysicalExpressionId`]. + /// + /// If the physical expression does not exist, returns a + /// [`MemoError::UnknownPhysicalExpression`] error. + async fn get_physical_expression( + &self, + physical_expression_id: Self::PhysicalExpressionId, + ) -> StorageResult; + + /// Retrieves the parent group ID of a logical expression given its expression ID. + /// + /// If the logical expression does not exist, returns a [`MemoError::UnknownLogicalExpression`] + /// error. + async fn get_group_from_logical_expression( + &self, + logical_expression_id: Self::LogicalExpressionId, + ) -> StorageResult; + + /// Retrieves the parent group ID of a physical expression given its expression ID. + /// + /// If the physical expression does not exist, returns a + /// [`MemoError::UnknownPhysicalExpression`] error. + async fn get_group_from_physical_expression( + &self, + physical_expression_id: Self::PhysicalExpressionId, + ) -> StorageResult; + + /// Retrieves all of the logical expression "children" of a group. + /// + /// If the group does not exist, returns a [`MemoError::UnknownGroup`] error. 
+ async fn get_group_logical_expressions( + &self, + group_id: Self::GroupId, + ) -> StorageResult>; + + /// Retrieves all of the physical expression "children" of a group. + /// + /// If the group does not exist, returns a [`MemoError::UnknownGroup`] error. + async fn get_group_physical_expressions( + &self, + group_id: Self::GroupId, + ) -> StorageResult>; + + /// Retrieves the best physical query plan (winner) for a given group. + /// + /// If the group does not exist, returns a [`MemoError::UnknownGroup`] error. + async fn get_winner( + &self, + group_id: Self::GroupId, + ) -> StorageResult>; + + /// Updates / replaces a group's best physical plan (winner). Optionally returns the previous + /// winner's physical expression ID. + /// + /// If the group does not exist, returns a [`MemoError::UnknownGroup`] error. + async fn update_group_winner( + &self, + group_id: Self::GroupId, + physical_expression_id: Self::PhysicalExpressionId, + ) -> StorageResult>; + + /// Adds a logical expression to an existing group via its [`Self::GroupId`]. + /// + /// If the group does not exist, returns a [`MemoError::UnknownGroup`] error. + async fn add_logical_expression_to_group( + &self, + group_id: Self::GroupId, + logical_expression: Self::LogicalExpression, + ) -> StorageResult<()>; + + /// Adds a physical expression to an existing group via its [`Self::GroupId`]. + /// + /// If the group does not exist, returns a [`MemoError::UnknownGroup`] error. + async fn add_physical_expression_to_group( + &self, + group_id: Self::GroupId, + physical_expression: Self::PhysicalExpression, + ) -> StorageResult<()>; + + /// Adds a new logical expression into the memo table, creating a new group if the expression + /// does not already exist. + /// + /// The [`Self::LogicalExpression`] type should have some sort of mechanism for checking if + /// the expression has been seen before, and if it has already been created, then the parent + /// group ID should also be retrievable. 
+ /// + /// If the expression already exists, then this function will return the [`Self::GroupId`] of + /// the parent group and the corresponding (already existing) [`Self::LogicalExpressionId`]. + /// + /// If the expression does not exist, this function will create a new group and a new + /// expression, returning brand new IDs for both. + async fn add_logical_expression( + &self, + expression: Self::LogicalExpression, + ) -> StorageResult<(Self::GroupId, Self::LogicalExpressionId)>; +} diff --git a/optd-persistent/src/memo/mod.rs b/optd-persistent/src/memo/mod.rs new file mode 100644 index 0000000..c455782 --- /dev/null +++ b/optd-persistent/src/memo/mod.rs @@ -0,0 +1,4 @@ +mod expression; + +pub mod interface; +pub mod orm; diff --git a/optd-persistent/src/memo/orm.rs b/optd-persistent/src/memo/orm.rs new file mode 100644 index 0000000..e7cc78b --- /dev/null +++ b/optd-persistent/src/memo/orm.rs @@ -0,0 +1,210 @@ +use crate::{ + entities::{prelude::*, *}, + BackendManager, {Memo, MemoError, StorageResult}, +}; +use sea_orm::*; + +impl Memo for BackendManager { + type Group = cascades_group::Model; + type GroupId = i32; + type LogicalExpression = logical_expression::Model; + type LogicalExpressionId = i32; + type PhysicalExpression = physical_expression::Model; + type PhysicalExpressionId = i32; + + async fn get_group(&self, group_id: Self::GroupId) -> StorageResult { + Ok(CascadesGroup::find_by_id(group_id) + .one(&self.db) + .await? + .ok_or(MemoError::UnknownGroup)?) + } + + async fn get_all_groups(&self) -> StorageResult> { + Ok(CascadesGroup::find().all(&self.db).await?) + } + + async fn get_logical_expression( + &self, + logical_expression_id: Self::LogicalExpressionId, + ) -> StorageResult { + Ok(LogicalExpression::find_by_id(logical_expression_id) + .one(&self.db) + .await? + .ok_or(MemoError::UnknownLogicalExpression)?) 
+ } + + async fn get_physical_expression( + &self, + physical_expression_id: Self::PhysicalExpressionId, + ) -> StorageResult { + Ok(PhysicalExpression::find_by_id(physical_expression_id) + .one(&self.db) + .await? + .ok_or(MemoError::UnknownPhysicalExpression)?) + } + + async fn get_group_from_logical_expression( + &self, + logical_expression_id: Self::LogicalExpressionId, + ) -> StorageResult { + // Find the logical expression and then look up the field. + Ok(self + .get_logical_expression(logical_expression_id) + .await? + .group_id) + } + + async fn get_group_from_physical_expression( + &self, + physical_expression_id: Self::PhysicalExpressionId, + ) -> StorageResult { + Ok(self + .get_physical_expression(physical_expression_id) + .await? + .group_id) + } + + async fn get_group_logical_expressions( + &self, + group_id: Self::GroupId, + ) -> StorageResult> { + // First retrieve the group record, and then find all related logical expressions. + Ok(self + .get_group(group_id) + .await? + .find_related(LogicalExpression) + .all(&self.db) + .await?) + } + + async fn get_group_physical_expressions( + &self, + group_id: Self::GroupId, + ) -> StorageResult> { + // First retrieve the group record, and then find all related physical expressions. + Ok(self + .get_group(group_id) + .await? + .find_related(PhysicalExpression) + .all(&self.db) + .await?) + } + + async fn get_winner( + &self, + group_id: Self::GroupId, + ) -> StorageResult> { + Ok(self.get_group(group_id).await?.latest_winner) + } + + async fn update_group_winner( + &self, + group_id: Self::GroupId, + physical_expression_id: Self::PhysicalExpressionId, + ) -> StorageResult> { + // First retrieve the group record, and then use an `ActiveModel` to update it. 
+ let mut group = self.get_group(group_id).await?.into_active_model(); + let old_id = group.latest_winner; + + group.latest_winner = Set(Some(physical_expression_id)); + group.update(&self.db).await?; + + // The old value must be set (`None` still means it has been set). + let old = old_id.unwrap(); + Ok(old) + } + + async fn add_logical_expression_to_group( + &self, + group_id: Self::GroupId, + logical_expression: Self::LogicalExpression, + ) -> StorageResult<()> { + if logical_expression.group_id != group_id { + Err(MemoError::InvalidExpression)? + } + + // Check if the group actually exists. + let _ = self.get_group(group_id).await?; + + // Insert the expression. + let _ = logical_expression + .into_active_model() + .insert(&self.db) + .await?; + + todo!("add the children of the logical expression into the children table") + } + + async fn add_physical_expression_to_group( + &self, + group_id: Self::GroupId, + physical_expression: Self::PhysicalExpression, + ) -> StorageResult<()> { + if physical_expression.group_id != group_id { + Err(MemoError::InvalidExpression)? + } + + // Check if the group actually exists. + let _ = self.get_group(group_id).await?; + + // Insert the expression. + let _ = physical_expression + .into_active_model() + .insert(&self.db) + .await?; + + todo!("add the children of the logical expression into the children table") + } + + /// Note that in this function, we ignore the group ID that the logical expression contains. + async fn add_logical_expression( + &self, + expression: Self::LogicalExpression, + ) -> StorageResult<(Self::GroupId, Self::LogicalExpressionId)> { + // Lookup all expressions that have the same fingerprint. There may be false positives, but + // we will check for those later. 
+ let fingerprint = expression.fingerprint; + let potential_matches = LogicalExpression::find() + .filter(logical_expression::Column::Fingerprint.eq(fingerprint)) + .all(&self.db) + .await?; + + // Of the expressions that have the same fingerprint, check if there already exists an + // expression that is exactly identical to the input expression. + let mut matches: Vec<_> = potential_matches + .into_iter() + .filter(|expr| expr == &expression) + .collect(); + assert!( + matches.len() <= 1, + "there cannot be more than 1 exact logical expression match" + ); + + // The expression already exists, so return its data. + if !matches.is_empty() { + let existing_expression = matches + .pop() + .expect("we just checked that an element exists"); + + return Ok((existing_expression.group_id, existing_expression.id)); + } + + // The expression does not exist yet, so we need to create a new group and new expression. + let group = cascades_group::ActiveModel { + latest_winner: Set(None), + in_progress: Set(false), + is_optimized: Set(false), + ..Default::default() + }; + + // Insert a new group. + let res = cascades_group::Entity::insert(group).exec(&self.db).await?; + + // Insert the input expression with the correct `group_id`. + let mut new_expr = expression.into_active_model(); + new_expr.group_id = Set(res.last_insert_id); + let new_expr = new_expr.insert(&self.db).await?; + + Ok((new_expr.group_id, new_expr.id)) + } +}