Skip to content

Commit

Permalink
add memo trait
Browse files Browse the repository at this point in the history
This commit adds a `Memo` trait and a first draft of an implementation
of the `Memo` trait via the backed ORM-mapped database.
  • Loading branch information
connortsui20 committed Nov 17, 2024
1 parent b96ee5a commit 1fcef01
Show file tree
Hide file tree
Showing 18 changed files with 806 additions and 31 deletions.
1 change: 1 addition & 0 deletions optd-cost-model/src/cost/agg.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@

40 changes: 24 additions & 16 deletions optd-persistent/src/cost_model/orm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -238,10 +238,13 @@ impl CostModelStorageLayer for BackendManager {
match res {
Ok(insert_res) => insert_res.last_insert_id,
Err(_) => {
return Err(BackendError::BackendError(format!(
"failed to insert statistic {:?} into statistic table",
stat
)))
return Err(BackendError::CostModel(
format!(
"failed to insert statistic {:?} into statistic table",
stat
)
.into(),
))
}
}
}
Expand Down Expand Up @@ -450,10 +453,13 @@ impl CostModelStorageLayer for BackendManager {
.collect::<Vec<_>>();

if attr_ids.len() != attr_base_indices.len() {
return Err(BackendError::BackendError(format!(
"Not all attributes found for table_id {} and base indices {:?}",
table_id, attr_base_indices
)));
return Err(BackendError::CostModel(
format!(
"Not all attributes found for table_id {} and base indices {:?}",
table_id, attr_base_indices
)
.into(),
));
}

self.get_stats_for_attr(attr_ids, stat_type, epoch_id).await
Expand Down Expand Up @@ -505,10 +511,13 @@ impl CostModelStorageLayer for BackendManager {
.one(&self.db)
.await?;
if expr_exists.is_none() {
return Err(BackendError::BackendError(format!(
"physical expression id {} not found when storing cost",
physical_expression_id
)));
return Err(BackendError::CostModel(
format!(
"physical expression id {} not found when storing cost",
physical_expression_id
)
.into(),
));
}

// Check if epoch_id exists in Event table
Expand All @@ -518,10 +527,9 @@ impl CostModelStorageLayer for BackendManager {
.await
.unwrap();
if epoch_exists.is_none() {
return Err(BackendError::BackendError(format!(
"epoch id {} not found when storing cost",
epoch_id
)));
return Err(BackendError::CostModel(
format!("epoch id {} not found when storing cost", epoch_id).into(),
));
}

let new_cost = plan_cost::ActiveModel {
Expand Down
43 changes: 40 additions & 3 deletions optd-persistent/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ mod migrator;
pub mod cost_model;
pub use cost_model::interface::CostModelStorageLayer;

mod memo;
pub use memo::interface::Memo;

/// The filename of the SQLite database for migration.
pub const DATABASE_FILENAME: &str = "sqlite.db";
/// The URL of the SQLite database for migration.
Expand All @@ -39,17 +42,48 @@ fn get_sqlite_url(file: &str) -> String {
format!("sqlite:{}?mode=rwc", file)
}

pub type StorageResult<T> = Result<T, BackendError>;
#[derive(Debug)]
pub enum CostModelError {
// TODO: Add more error types
UnknownStatisticType,
VersionedStatisticNotFound,
CustomError(String),
}

/// TODO convert this to `thiserror`
#[derive(Debug)]
/// The different kinds of errors that might occur while running operations on a memo table.
pub enum MemoError {
UnknownGroup,
UnknownLogicalExpression,
UnknownPhysicalExpression,
InvalidExpression,
}

/// TODO convert this to `thiserror`
#[derive(Debug)]
pub enum BackendError {
Memo(MemoError),
DatabaseError(DbErr),
CostModel(CostModelError),
BackendError(String),
}

impl From<String> for BackendError {
impl From<String> for CostModelError {
fn from(value: String) -> Self {
BackendError::BackendError(value)
CostModelError::CustomError(value)
}
}

impl From<CostModelError> for BackendError {
fn from(value: CostModelError) -> Self {
BackendError::CostModel(value)
}
}

impl From<MemoError> for BackendError {
fn from(value: MemoError) -> Self {
BackendError::Memo(value)
}
}

Expand All @@ -59,6 +93,9 @@ impl From<DbErr> for BackendError {
}
}

/// A type alias for a result with [`BackendError`] as the error type.
pub type StorageResult<T> = Result<T, BackendError>;

pub struct BackendManager {
db: DatabaseConnection,
}
Expand Down
11 changes: 11 additions & 0 deletions optd-persistent/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,17 @@ use optd_persistent::DATABASE_URL;

#[tokio::main]
async fn main() {
basic_demo().await;
memo_demo().await;
}

async fn memo_demo() {
let _db = Database::connect(DATABASE_URL).await.unwrap();

todo!()
}

async fn basic_demo() {
let db = Database::connect(DATABASE_URL).await.unwrap();

// Create a new `CascadesGroup`.
Expand Down
73 changes: 73 additions & 0 deletions optd-persistent/src/memo/expression.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
use crate::entities::*;
use std::hash::{DefaultHasher, Hash, Hasher};

/// All of the different types of fixed logical operators.
///
/// Note that there could be more operators that the memo table must support that are not enumerated
/// in this enum, as there can be up to `2^16` different types of operators.
#[derive(Debug, Clone, Copy, Hash, Eq, PartialEq)]
#[non_exhaustive]
#[repr(i16)]
pub enum LogicalOperator {
Scan,
Join,
}

/// All of the different types of fixed physical operators.
///
/// Note that there could be more operators that the memo table must support that are not enumerated
/// in this enum, as there can be up to `2^16` different types of operators.
#[derive(Debug, Clone, Copy, Hash, Eq, PartialEq)]
#[non_exhaustive]
#[repr(i16)]
pub enum PhysicalOperator {
TableScan,
IndexScan,
NestedLoopJoin,
HashJoin,
}

/// A method to generate a fingerprint used to efficiently check if two
/// expressions are equivalent.
///
/// TODO actually make efficient.
fn fingerprint(variant_tag: i16, data: &serde_json::Value) -> i64 {
let mut hasher = DefaultHasher::new();

variant_tag.hash(&mut hasher);
data.hash(&mut hasher);

hasher.finish() as i64
}

impl logical_expression::Model {
/// Creates a new logical expression with an unset `id` and `group_id`.
pub fn new(variant_tag: LogicalOperator, data: serde_json::Value) -> Self {
let tag = variant_tag as i16;
let fingerprint = fingerprint(tag, &data);

Self {
id: 0,
group_id: 0,
fingerprint,
variant_tag: tag,
data,
}
}
}

impl physical_expression::Model {
/// Creates a new physical expression with an unset `id` and `group_id`.
pub fn new(variant_tag: PhysicalOperator, data: serde_json::Value) -> Self {
let tag = variant_tag as i16;
let fingerprint = fingerprint(tag, &data);

Self {
id: 0,
group_id: 0,
fingerprint,
variant_tag: tag,
data,
}
}
}
141 changes: 141 additions & 0 deletions optd-persistent/src/memo/interface.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
use crate::StorageResult;

/// A trait representing an implementation of a memoization table.
///
/// Note that we use [`trait_variant`] here in order to add bounds on every method.
/// See this [blog post](
/// https://blog.rust-lang.org/2023/12/21/async-fn-rpit-in-traits.html#async-fn-in-public-traits)
/// for more information.
///
/// TODO Figure out for each when to get the ID of a record or the entire record itself.
#[trait_variant::make(Send)]
pub trait Memo {
/// A type representing a group in the Cascades framework.
type Group;
/// A type representing a unique identifier for a group.
type GroupId;
/// A type representing a logical expression.
type LogicalExpression;
/// A type representing a unique identifier for a logical expression.
type LogicalExpressionId;
/// A type representing a physical expression.
type PhysicalExpression;
/// A type representing a unique identifier for a physical expression.
type PhysicalExpressionId;

/// Retrieves a [`Self::Group`] given a [`Self::GroupId`].
///
/// If the group does not exist, returns a [`MemoError::UnknownGroup`] error.
async fn get_group(&self, group_id: Self::GroupId) -> StorageResult<Self::Group>;

/// Retrieves all group IDs that are stored in the memo table.
async fn get_all_groups(&self) -> StorageResult<Vec<Self::Group>>;

/// Retrieves a [`Self::LogicalExpression`] given a [`Self::LogicalExpressionId`].
///
/// If the logical expression does not exist, returns a [`MemoError::UnknownLogicalExpression`]
/// error.
async fn get_logical_expression(
&self,
logical_expression_id: Self::LogicalExpressionId,
) -> StorageResult<Self::LogicalExpression>;

/// Retrieves a [`Self::PhysicalExpression`] given a [`Self::PhysicalExpressionId`].
///
/// If the physical expression does not exist, returns a
/// [`MemoError::UnknownPhysicalExpression`] error.
async fn get_physical_expression(
&self,
physical_expression_id: Self::PhysicalExpressionId,
) -> StorageResult<Self::PhysicalExpression>;

/// Retrieves the parent group ID of a logical expression given its expression ID.
///
/// If the logical expression does not exist, returns a [`MemoError::UnknownLogicalExpression`]
/// error.
async fn get_group_from_logical_expression(
&self,
logical_expression_id: Self::LogicalExpressionId,
) -> StorageResult<Self::GroupId>;

/// Retrieves the parent group ID of a logical expression given its expression ID.
///
/// If the physical expression does not exist, returns a
/// [`MemoError::UnknownPhysicalExpression`] error.
async fn get_group_from_physical_expression(
&self,
physical_expression_id: Self::PhysicalExpressionId,
) -> StorageResult<Self::GroupId>;

/// Retrieves all of the logical expression "children" of a group.
///
/// If the group does not exist, returns a [`MemoError::UnknownGroup`] error.
async fn get_group_logical_expressions(
&self,
group_id: Self::GroupId,
) -> StorageResult<Vec<Self::LogicalExpression>>;

/// Retrieves all of the physical expression "children" of a group.
///
/// If the group does not exist, returns a [`MemoError::UnknownGroup`] error.
async fn get_group_physical_expressions(
&self,
group_id: Self::GroupId,
) -> StorageResult<Vec<Self::PhysicalExpression>>;

/// Retrieves the best physical query plan (winner) for a given group.
///
/// If the group does not exist, returns a [`MemoError::UnknownGroup`] error.
async fn get_winner(
&self,
group_id: Self::GroupId,
) -> StorageResult<Option<Self::PhysicalExpressionId>>;

/// Updates / replaces a group's best physical plan (winner). Optionally returns the previous
/// winner's physical expression ID.
///
/// If the group does not exist, returns a [`MemoError::UnknownGroup`] error.
async fn update_group_winner(
&self,
group_id: Self::GroupId,
physical_expression_id: Self::PhysicalExpressionId,
) -> StorageResult<Option<Self::PhysicalExpressionId>>;

/// Adds a logical expression to an existing group via its [`Self::GroupId`].
///
/// If the group does not exist, returns a [`MemoError::UnknownGroup`] error.
async fn add_logical_expression_to_group(
&self,
group_id: Self::GroupId,
logical_expression: Self::LogicalExpression,
children: Vec<Self::LogicalExpressionId>,
) -> StorageResult<()>;

/// Adds a physical expression to an existing group via its [`Self::GroupId`].
///
/// If the group does not exist, returns a [`MemoError::UnknownGroup`] error.
async fn add_physical_expression_to_group(
&self,
group_id: Self::GroupId,
physical_expression: Self::PhysicalExpression,
children: Vec<Self::LogicalExpressionId>,
) -> StorageResult<()>;

/// Adds a new logical expression into the memo table, creating a new group if the expression
/// does not already exist.
///
/// The [`Self::LogicalExpression`] type should have some sort of mechanism for checking if
/// the expression has been seen before, and if it has already been created, then the parent
/// group ID should also be retrievable.
///
/// If the expression already exists, then this function will return the [`Self::GroupId`] of
/// the parent group and the corresponding (already existing) [`Self::LogicalExpressionId`].
///
/// If the expression does not exist, this function will create a new group and a new
/// expression, returning brand new IDs for both.
async fn add_logical_expression(
&self,
expression: Self::LogicalExpression,
children: Vec<Self::LogicalExpressionId>,
) -> StorageResult<(Self::GroupId, Self::LogicalExpressionId)>;
}
4 changes: 4 additions & 0 deletions optd-persistent/src/memo/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
mod expression;

pub mod interface;
pub mod orm;
Loading

0 comments on commit 1fcef01

Please sign in to comment.