From 9689c046292debf33b5461deb13dd4b4056f1de3 Mon Sep 17 00:00:00 2001 From: Connor Tsui Date: Thu, 28 Nov 2024 14:43:32 -0500 Subject: [PATCH] add expression representation and refactor memo This commit adds the `src/expression` module which contains a very simple representation of Cascades expressions. The `Memo` trait interface and implemenation has also changed, where it now correctly detects exact match duplicates. TODO: Add the duplicate detection to the other methods that need them. TODO: Add more tests. TODO: Figure out how to test in CI. --- optd-mvp/src/expression/logical_expression.rs | 121 ++++++++++++++++++ optd-mvp/src/expression/mod.rs | 62 +++++++++ .../src/expression/physical_expression.rs | 121 ++++++++++++++++++ optd-mvp/src/lib.rs | 2 + optd-mvp/src/memo/interface.rs | 9 ++ .../implementation.rs} | 38 ++---- optd-mvp/src/memo/persistent/mod.rs | 62 +++++++++ optd-mvp/src/memo/persistent/tests.rs | 29 +++++ 8 files changed, 417 insertions(+), 27 deletions(-) create mode 100644 optd-mvp/src/expression/logical_expression.rs create mode 100644 optd-mvp/src/expression/mod.rs create mode 100644 optd-mvp/src/expression/physical_expression.rs rename optd-mvp/src/memo/{persistent.rs => persistent/implementation.rs} (89%) create mode 100644 optd-mvp/src/memo/persistent/mod.rs create mode 100644 optd-mvp/src/memo/persistent/tests.rs diff --git a/optd-mvp/src/expression/logical_expression.rs b/optd-mvp/src/expression/logical_expression.rs new file mode 100644 index 0000000..354b3c7 --- /dev/null +++ b/optd-mvp/src/expression/logical_expression.rs @@ -0,0 +1,121 @@ +//! Definition of logical expressions / relations in the Cascades query optimization framework. +//! +//! FIXME: All fields are placeholders, and group IDs are just represented as i32 for now. +//! +//! TODO figure out if each relation should be in a different submodule. + +use crate::entities::*; +use serde::{Deserialize, Serialize}; +use std::hash::{DefaultHasher, Hash, Hasher}; + +#[derive(Clone, Debug)] +pub enum LogicalExpression { + Scan(Scan), + Filter(Filter), + Join(Join), +} + +#[derive(Serialize, Deserialize, Clone, Debug)] +pub struct Scan { + table_schema: String, +} + +#[derive(Serialize, Deserialize, Clone, Debug)] +pub struct Filter { + child: i32, + expression: String, +} + +#[derive(Serialize, Deserialize, Clone, Debug)] +pub struct Join { + left: i32, + right: i32, + expression: String, +} + +/// TODO Use a macro instead. +impl From for LogicalExpression { + fn from(value: logical_expression::Model) -> Self { + match value.kind { + 0 => Self::Scan( + serde_json::from_value(value.data) + .expect("unable to deserialize data into a logical `Scan`"), + ), + 1 => Self::Filter( + serde_json::from_value(value.data) + .expect("Unable to deserialize data into a logical `Filter`"), + ), + 2 => Self::Join( + serde_json::from_value(value.data) + .expect("Unable to deserialize data into a logical `Join`"), + ), + _ => panic!(), + } + } +} + +/// TODO Use a macro instead. +impl From for logical_expression::Model { + fn from(value: LogicalExpression) -> logical_expression::Model { + fn create_logical_expression( + kind: i16, + data: serde_json::Value, + ) -> logical_expression::Model { + let mut hasher = DefaultHasher::new(); + kind.hash(&mut hasher); + data.hash(&mut hasher); + let fingerprint = hasher.finish() as i64; + + logical_expression::Model { + id: -1, + group_id: -1, + fingerprint, + kind, + data, + } + } + + match value { + LogicalExpression::Scan(scan) => create_logical_expression( + 0, + serde_json::to_value(scan).expect("unable to serialize logical `Scan`"), + ), + LogicalExpression::Filter(filter) => create_logical_expression( + 1, + serde_json::to_value(filter).expect("unable to serialize logical `Filter`"), + ), + LogicalExpression::Join(join) => create_logical_expression( + 2, + serde_json::to_value(join).expect("unable to serialize logical `Join`"), + ), + } + } +} + +#[cfg(test)] +pub use build::*; + +#[cfg(test)] +mod build { + use super::*; + use crate::expression::Expression; + + pub fn scan(table_schema: String) -> Expression { + Expression::Logical(LogicalExpression::Scan(Scan { table_schema })) + } + + pub fn filter(child_group: i32, expression: String) -> Expression { + Expression::Logical(LogicalExpression::Filter(Filter { + child: child_group, + expression, + })) + } + + pub fn join(left_group: i32, right_group: i32, expression: String) -> Expression { + Expression::Logical(LogicalExpression::Join(Join { + left: left_group, + right: right_group, + expression, + })) + } +} diff --git a/optd-mvp/src/expression/mod.rs b/optd-mvp/src/expression/mod.rs new file mode 100644 index 0000000..459e13b --- /dev/null +++ b/optd-mvp/src/expression/mod.rs @@ -0,0 +1,62 @@ +//! In-memory representation of Cascades logical and physical expression / operators / relations. +//! +//! TODO more docs. + +mod logical_expression; +pub use logical_expression::*; + +mod physical_expression; +pub use physical_expression::*; + +/// The representation of a Cascades expression. +/// +/// TODO more docs. +#[derive(Clone, Debug)] +pub enum Expression { + Logical(LogicalExpression), + Physical(PhysicalExpression), +} + +/// Converts the database / JSON representation of a logical expression into an in-memory one. +impl From for Expression { + fn from(value: crate::entities::logical_expression::Model) -> Self { + Self::Logical(value.into()) + } +} + +/// Converts the in-memory representation of a logical expression into the database / JSON version. +/// +/// # Panics +/// +/// This will panic if the [`Expression`] is [`Expression::Physical`]. +impl From for crate::entities::logical_expression::Model { + fn from(value: Expression) -> Self { + let Expression::Logical(expr) = value else { + panic!("Attempted to convert an in-memory physical expression into a logical database / JSON expression"); + }; + + expr.into() + } +} + +/// Converts the database / JSON representation of a physical expression into an in-memory one. +impl From for Expression { + fn from(value: crate::entities::physical_expression::Model) -> Self { + Self::Physical(value.into()) + } +} + +/// Converts the in-memory representation of a physical expression into the database / JSON version. +/// +/// # Panics +/// +/// This will panic if the [`Expression`] is [`Expression::Physical`]. +impl From for crate::entities::physical_expression::Model { + fn from(value: Expression) -> Self { + let Expression::Physical(expr) = value else { + panic!("Attempted to convert an in-memory logical expression into a physical database / JSON expression"); + }; + + expr.into() + } +} diff --git a/optd-mvp/src/expression/physical_expression.rs b/optd-mvp/src/expression/physical_expression.rs new file mode 100644 index 0000000..fced9d1 --- /dev/null +++ b/optd-mvp/src/expression/physical_expression.rs @@ -0,0 +1,121 @@ +//! Definition of physical expressions / operators in the Cascades query optimization framework. +//! +//! FIXME: All fields are placeholders, and group IDs are just represented as i32 for now. +//! +//! TODO figure out if each operator should be in a different submodule. + +use crate::entities::*; +use serde::{Deserialize, Serialize}; +use std::hash::{DefaultHasher, Hash, Hasher}; + +#[derive(Clone, Debug)] +pub enum PhysicalExpression { + TableScan(TableScan), + Filter(PhysicalFilter), + HashJoin(HashJoin), +} + +#[derive(Serialize, Deserialize, Clone, Debug)] +pub struct TableScan { + table_schema: String, +} + +#[derive(Serialize, Deserialize, Clone, Debug)] +pub struct PhysicalFilter { + child: i32, + expression: String, +} + +#[derive(Serialize, Deserialize, Clone, Debug)] +pub struct HashJoin { + left: i32, + right: i32, + expression: String, +} + +/// TODO Use a macro instead. +impl From for PhysicalExpression { + fn from(value: physical_expression::Model) -> Self { + match value.kind { + 0 => Self::TableScan( + serde_json::from_value(value.data) + .expect("unable to deserialize data into a physical `TableScan`"), + ), + 1 => Self::Filter( + serde_json::from_value(value.data) + .expect("Unable to deserialize data into a physical `Filter`"), + ), + 2 => Self::HashJoin( + serde_json::from_value(value.data) + .expect("Unable to deserialize data into a physical `HashJoin`"), + ), + _ => panic!(), + } + } +} + +/// TODO Use a macro instead. +impl From for physical_expression::Model { + fn from(value: PhysicalExpression) -> physical_expression::Model { + fn create_physical_expression( + kind: i16, + data: serde_json::Value, + ) -> physical_expression::Model { + let mut hasher = DefaultHasher::new(); + kind.hash(&mut hasher); + data.hash(&mut hasher); + let fingerprint = hasher.finish() as i64; + + physical_expression::Model { + id: -1, + group_id: -1, + fingerprint, + kind, + data, + } + } + + match value { + PhysicalExpression::TableScan(scan) => create_physical_expression( + 0, + serde_json::to_value(scan).expect("unable to serialize physical `TableScan`"), + ), + PhysicalExpression::Filter(filter) => create_physical_expression( + 1, + serde_json::to_value(filter).expect("unable to serialize physical `Filter`"), + ), + PhysicalExpression::HashJoin(join) => create_physical_expression( + 2, + serde_json::to_value(join).expect("unable to serialize physical `HashJoin`"), + ), + } + } +} + +#[cfg(test)] +pub use build::*; + +#[cfg(test)] +mod build { + use super::*; + use crate::expression::Expression; + + pub fn table_scan(table_schema: String) -> Expression { + Expression::Physical(PhysicalExpression::TableScan(TableScan { table_schema })) + } + + pub fn filter(child_group: i32, expression: String) -> Expression { + Expression::Physical(PhysicalExpression::Filter(PhysicalFilter { + child: child_group, + expression, + })) + } + + pub fn hash_join(left_group: i32, right_group: i32, expression: String) -> Expression { + Expression::Physical(PhysicalExpression::HashJoin(HashJoin { + left: left_group, + right: right_group, + expression, + })) + } +} diff --git a/optd-mvp/src/lib.rs b/optd-mvp/src/lib.rs index c5185cd..506eee4 100644 --- a/optd-mvp/src/lib.rs +++ b/optd-mvp/src/lib.rs @@ -10,6 +10,8 @@ mod entities; mod memo; use memo::MemoError; +mod expression; + /// The filename of the SQLite database for migration. pub const DATABASE_FILENAME: &str = "sqlite.db"; /// The URL of the SQLite database for migration. diff --git a/optd-mvp/src/memo/interface.rs b/optd-mvp/src/memo/interface.rs index a88740e..54abe21 100644 --- a/optd-mvp/src/memo/interface.rs +++ b/optd-mvp/src/memo/interface.rs @@ -1,3 +1,6 @@ +//! This module defines the [`Memo`] trait, which defines shared behavior of all memo table that can +//! be used for query optimization in the Cascades framework. + use crate::OptimizerResult; use thiserror::Error; @@ -96,6 +99,9 @@ pub trait Memo { /// [`MemoError::InvalidExpression`] error. /// /// If the group does not exist, returns a [`MemoError::UnknownGroup`] error. + /// + /// FIXME: This needs to have a mechanism of reporting that a duplicate expression was found in + /// another group. async fn add_logical_expression_to_group( &self, group_id: Self::GroupId, @@ -114,6 +120,9 @@ pub trait Memo { /// [`MemoError::InvalidExpression`] error. /// /// If the group does not exist, returns a [`MemoError::UnknownGroup`] error. + /// + /// FIXME: This needs to have a mechanism of reporting that a duplicate expression was found in + /// another group. async fn add_physical_expression_to_group( &self, group_id: Self::GroupId, diff --git a/optd-mvp/src/memo/persistent.rs b/optd-mvp/src/memo/persistent/implementation.rs similarity index 89% rename from optd-mvp/src/memo/persistent.rs rename to optd-mvp/src/memo/persistent/implementation.rs index 445ee6c..02bd8dd 100644 --- a/optd-mvp/src/memo/persistent.rs +++ b/optd-mvp/src/memo/persistent/implementation.rs @@ -1,28 +1,10 @@ +//! This module contains the implementation of the [`Memo`] trait for [`PersistentMemo`]. + +use super::*; use crate::{ - entities::{prelude::*, *}, memo::{Memo, MemoError}, - OptimizerResult, DATABASE_URL, + OptimizerResult, }; -use sea_orm::*; - -/// A persistent memo table, backed by a database on disk. -/// -/// TODO more docs. -pub struct PersistentMemo { - /// This `PersistentMemo` is reliant on the SeaORM [`DatabaseConnection`] that stores all of the - /// objects needed for query optimization. - db: DatabaseConnection, -} - -impl PersistentMemo { - /// TODO remove dead code and write docs. - #[allow(dead_code)] - pub async fn new() -> Self { - Self { - db: Database::connect(DATABASE_URL).await.unwrap(), - } - } -} impl Memo for PersistentMemo { type Group = cascades_group::Model; @@ -183,19 +165,21 @@ impl Memo for PersistentMemo { logical_expression: Self::LogicalExpression, children: &[Self::GroupId], ) -> OptimizerResult<(Self::GroupId, Self::LogicalExpressionId)> { - // Lookup all expressions that have the same fingerprint. There may be false positives, but - // we will check for those later. + // Lookup all expressions that have the same fingerprint and kind. There may be false + // positives, but we will check for those next. let fingerprint = logical_expression.fingerprint; + let kind = logical_expression.kind; let potential_matches = LogicalExpression::find() .filter(logical_expression::Column::Fingerprint.eq(fingerprint)) + .filter(logical_expression::Column::Kind.eq(kind)) .all(&self.db) .await?; - // Of the expressions that have the same fingerprint, check if there already exists an - // expression that is exactly identical to the input expression. + // Of the expressions that have the same fingerprint and kind, check if there already exists + // an expression that is exactly identical to the input expression. let mut matches: Vec<_> = potential_matches .into_iter() - .filter(|expr| expr == &logical_expression) + .filter(|expr| expr.data == logical_expression.data) .collect(); assert!( matches.len() <= 1, diff --git a/optd-mvp/src/memo/persistent/mod.rs b/optd-mvp/src/memo/persistent/mod.rs new file mode 100644 index 0000000..8f3f0a7 --- /dev/null +++ b/optd-mvp/src/memo/persistent/mod.rs @@ -0,0 +1,62 @@ +//! This module contains the definition and implementation of the [`PersistentMemo`] type, which +//! implements the `Memo` trait and supports memo table operations necessary for query optimization. + +use crate::{ + entities::{prelude::*, *}, + DATABASE_URL, +}; +use sea_orm::*; + +#[cfg(test)] +mod tests; + +/// A persistent memo table, backed by a database on disk. +/// +/// TODO more docs. +pub struct PersistentMemo { + /// This `PersistentMemo` is reliant on the SeaORM [`DatabaseConnection`] that stores all of the + /// objects needed for query optimization. + db: DatabaseConnection, +} + +impl PersistentMemo { + /// Creates a new `PersistentMemo` struct by connecting to a database defined at + /// [`DATABASE_URL`]. + /// + /// TODO remove dead code and write docs. + #[allow(dead_code)] + pub async fn new() -> Self { + Self { + db: Database::connect(DATABASE_URL).await.unwrap(), + } + } + + /// Since there is no asynchronous drop yet in Rust, we must do this manually. + /// + /// TODO remove dead code and write docs. + #[allow(dead_code)] + pub async fn cleanup(&self) { + cascades_group::Entity::delete_many() + .exec(&self.db) + .await + .unwrap(); + logical_expression::Entity::delete_many() + .exec(&self.db) + .await + .unwrap(); + logical_children::Entity::delete_many() + .exec(&self.db) + .await + .unwrap(); + physical_expression::Entity::delete_many() + .exec(&self.db) + .await + .unwrap(); + physical_children::Entity::delete_many() + .exec(&self.db) + .await + .unwrap(); + } +} + +mod implementation; diff --git a/optd-mvp/src/memo/persistent/tests.rs b/optd-mvp/src/memo/persistent/tests.rs new file mode 100644 index 0000000..cf79507 --- /dev/null +++ b/optd-mvp/src/memo/persistent/tests.rs @@ -0,0 +1,29 @@ +use super::*; +use crate::{expression::*, memo::Memo}; + +/// Tests is exact expression matches are detected and handled by the memo table. +#[ignore] +#[tokio::test] +async fn test_simple_duplicates() { + let memo = PersistentMemo::new().await; + + let scan = scan("(a int, b int)".to_string()); + let scan1 = scan.clone(); + let scan2 = scan.clone(); + + let res0 = memo.add_logical_expression(scan.into(), &[]).await.unwrap(); + let res1 = memo + .add_logical_expression(scan1.into(), &[]) + .await + .unwrap(); + let res2 = memo + .add_logical_expression(scan2.into(), &[]) + .await + .unwrap(); + + assert_eq!(res0, res1); + assert_eq!(res0, res2); + assert_eq!(res1, res2); + + memo.cleanup().await; +}