diff --git a/optd-persistent/src/orm_manager.rs b/optd-persistent/src/orm_manager.rs index bc07274..35c8885 100644 --- a/optd-persistent/src/orm_manager.rs +++ b/optd-persistent/src/orm_manager.rs @@ -1,11 +1,11 @@ #![allow(dead_code, unused_imports, unused_variables)] -use crate::entities::{prelude::*, *}; -use crate::orm_manager::{Cost, Event}; -use crate::storage_layer::{self, EpochId, StorageLayer}; +use crate::entities::physical_expression; +use crate::storage_layer::{self, EpochId, StorageLayer, StorageResult}; use crate::DATABASE_URL; +use sea_orm::DatabaseConnection; use sea_orm::*; -use sqlx::types::chrono::Utc; +use sea_orm_migration::prelude::*; pub struct ORMManager { db_conn: DatabaseConnection, @@ -31,126 +31,85 @@ impl StorageLayer for ORMManager { &mut self, source: String, data: String, - ) -> Result { - let new_event = event::ActiveModel { - source_variant: sea_orm::ActiveValue::Set(source), - create_timestamp: sea_orm::ActiveValue::Set(Utc::now()), - data: sea_orm::ActiveValue::Set(sea_orm::JsonValue::String(data)), - ..Default::default() - }; - let res = Event::insert(new_event).exec(&self.db_conn).await; - match res { - Ok(insert_res) => { - self.latest_epoch_id = insert_res.last_insert_id; - Ok(self.latest_epoch_id) - } - Err(_) => Err(()), - } + ) -> StorageResult { + todo!() } async fn update_stats_from_catalog( &self, c: storage_layer::CatalogSource, epoch_id: storage_layer::EpochId, - ) -> Result<(), ()> { + ) -> StorageResult<()> { todo!() } - async fn update_stats(&self, stats: i32, epoch_id: storage_layer::EpochId) -> Result<(), ()> { + async fn update_stats( + &self, + stats: i32, + epoch_id: storage_layer::EpochId, + ) -> StorageResult<()> { todo!() } - async fn get_stats_analysis( + async fn store_cost( &self, - table_id: i32, - attr_id: Option, + expr_id: storage_layer::ExprId, + cost: i32, epoch_id: storage_layer::EpochId, - ) -> Option { + ) -> StorageResult<()> { todo!() } - async fn get_stats(&self, table_id: i32, attr_id: Option) -> Option { + async fn get_stats_for_table( + &self, + table_id: i32, + stat_type: i32, + epoch_id: Option, + ) -> StorageResult> { todo!() } - async fn store_cost( + async fn get_stats_for_attr( &self, - expr_id: storage_layer::ExprId, - cost: i32, - epoch_id: storage_layer::EpochId, - ) -> Result<(), DbErr> { - // TODO: update PhysicalExpression and Event tables - // Check if expr_id exists in PhysicalExpression table - let expr_exists = PhysicalExpression::find_by_id(expr_id) - .one(&self.db_conn) - .await?; - if expr_exists.is_none() { - return Err(DbErr::RecordNotFound( - "ExprId not found in PhysicalExpression table".to_string(), - )); - } - - // Check if epoch_id exists in Event table - let epoch_exists = Event::find() - .filter(event::Column::EpochId.eq(epoch_id)) - .one(&self.db_conn) - .await - .unwrap(); + attr_id: i32, + stat_type: i32, + epoch_id: Option, + ) -> StorageResult> { + todo!() + } - let new_cost = cost::ActiveModel { - expr_id: ActiveValue::Set(expr_id), - epoch_id: ActiveValue::Set(epoch_id), - cost: ActiveValue::Set(cost), - valid: ActiveValue::Set(true), - ..Default::default() - }; - let res = Cost::insert(new_cost).exec(&self.db_conn).await; - match res { - Ok(_) => Ok(()), - Err(e) => Err(DbErr::Custom(e.to_string())), - } + async fn get_stats_for_attrs( + &self, + attr_ids: Vec, + stat_type: i32, + epoch_id: Option, + ) -> StorageResult> { + todo!() } async fn get_cost_analysis( &self, expr_id: storage_layer::ExprId, epoch_id: storage_layer::EpochId, - ) -> Option { - let cost = Cost::find() - .filter(cost::Column::ExprId.eq(expr_id)) - .filter(cost::Column::EpochId.eq(epoch_id)) - .one(&self.db_conn) - .await - .unwrap(); - assert!(cost.is_some(), "Cost not found in Cost table"); - assert!(cost.clone().unwrap().valid, "Cost is not valid"); - cost.map(|c| c.cost) + ) -> StorageResult> { + todo!() } - /// Get the latest cost for an expression - async fn get_cost(&self, expr_id: storage_layer::ExprId) -> Option { - let cost = Cost::find() - .filter(cost::Column::ExprId.eq(expr_id)) - .order_by_desc(cost::Column::EpochId) - .one(&self.db_conn) - .await - .unwrap(); - assert!(cost.is_some(), "Cost not found in Cost table"); - assert!(cost.clone().unwrap().valid, "Cost is not valid"); - cost.map(|c| c.cost) + async fn get_cost(&self, expr_id: storage_layer::ExprId) -> StorageResult> { + todo!() } async fn get_group_winner_from_group_id( &self, group_id: i32, - ) -> Option { + ) -> StorageResult> { todo!() } async fn add_new_expr( &mut self, expr: storage_layer::Expression, - ) -> (storage_layer::GroupId, storage_layer::ExprId) { + ) -> StorageResult<(storage_layer::GroupId, storage_layer::ExprId)> { todo!() } @@ -158,26 +117,32 @@ impl StorageLayer for ORMManager { &mut self, expr: storage_layer::Expression, group_id: storage_layer::GroupId, - ) -> Option { + ) -> StorageResult> { todo!() } - async fn get_group_id(&self, expr_id: storage_layer::ExprId) -> storage_layer::GroupId { + async fn get_group_id( + &self, + expr_id: storage_layer::ExprId, + ) -> StorageResult { todo!() } - async fn get_expr_memoed(&self, expr_id: storage_layer::ExprId) -> storage_layer::Expression { + async fn get_expr_memoed( + &self, + expr_id: storage_layer::ExprId, + ) -> StorageResult { todo!() } - async fn get_all_group_ids(&self) -> Vec { + async fn get_all_group_ids(&self) -> StorageResult> { todo!() } async fn get_group( &self, group_id: storage_layer::GroupId, - ) -> crate::entities::cascades_group::ActiveModel { + ) -> StorageResult { todo!() } @@ -185,125 +150,35 @@ impl StorageLayer for ORMManager { &mut self, group_id: storage_layer::GroupId, latest_winner: Option, - ) { + ) -> StorageResult<()> { todo!() } async fn get_all_exprs_in_group( &self, group_id: storage_layer::GroupId, - ) -> Vec { + ) -> StorageResult> { todo!() } async fn get_group_info( &self, group_id: storage_layer::GroupId, - ) -> &Option { + ) -> StorageResult<&Option> { todo!() } async fn get_predicate_binding( &self, group_id: storage_layer::GroupId, - ) -> Option { + ) -> StorageResult> { todo!() } async fn try_get_predicate_binding( &self, group_id: storage_layer::GroupId, - ) -> Option { + ) -> StorageResult> { todo!() } } - -#[cfg(test)] -mod tests { - use crate::migrate; - use sea_orm::{ConnectionTrait, Database, EntityTrait, ModelTrait}; - use serde_json::de; - - use crate::entities::event::Entity as Event; - use crate::storage_layer::StorageLayer; - use crate::TEST_DATABASE_URL; - - async fn run_migration() { - let _ = std::fs::remove_file(TEST_DATABASE_URL); - - let db = Database::connect(TEST_DATABASE_URL) - .await - .expect("Unable to connect to the database"); - - migrate(&db) - .await - .expect("Something went wrong during migration"); - - db.execute(sea_orm::Statement::from_string( - sea_orm::DatabaseBackend::Sqlite, - "PRAGMA foreign_keys = ON;".to_owned(), - )) - .await - .expect("Unable to enable foreign keys"); - } - - #[tokio::test] - async fn test_create_new_epoch() { - run_migration().await; - let mut orm_manager = super::ORMManager::new(Some(TEST_DATABASE_URL)).await; - let res = orm_manager - .create_new_epoch("source".to_string(), "data".to_string()) - .await; - println!("{:?}", res); - assert!(res.is_ok()); - assert_eq!( - super::Event::find() - .all(&orm_manager.db_conn) - .await - .unwrap() - .len(), - 1 - ); - println!( - "{:?}", - super::Event::find() - .all(&orm_manager.db_conn) - .await - .unwrap()[0] - ); - assert_eq!( - super::Event::find() - .all(&orm_manager.db_conn) - .await - .unwrap()[0] - .epoch_id, - res.unwrap() - ); - } - - #[tokio::test] - #[ignore] // Need to update all tables - async fn test_store_cost() { - run_migration().await; - let mut orm_manager = super::ORMManager::new(Some(TEST_DATABASE_URL)).await; - let epoch_id = orm_manager - .create_new_epoch("source".to_string(), "data".to_string()) - .await - .unwrap(); - let expr_id = 1; - let cost = 42; - let res = orm_manager.store_cost(expr_id, cost, epoch_id).await; - match res { - Ok(_) => assert!(true), - Err(e) => { - println!("Error: {:?}", e); - assert!(false); - } - } - let costs = super::Cost::find().all(&orm_manager.db_conn).await.unwrap(); - assert_eq!(costs.len(), 1); - assert_eq!(costs[0].epoch_id, epoch_id); - assert_eq!(costs[0].expr_id, expr_id); - assert_eq!(costs[0].cost, cost); - } -} diff --git a/optd-persistent/src/storage_layer.rs b/optd-persistent/src/storage_layer.rs index 1d35eda..93590f6 100644 --- a/optd-persistent/src/storage_layer.rs +++ b/optd-persistent/src/storage_layer.rs @@ -13,6 +13,8 @@ pub type GroupId = i32; pub type ExprId = i32; pub type EpochId = i32; +pub type StorageResult = Result; + pub enum CatalogSource { Iceberg(), } @@ -37,63 +39,102 @@ pub struct WinnerInfo {} pub trait StorageLayer { // TODO: Change EpochId to event::Model::epoch_id - async fn create_new_epoch(&mut self, source: String, data: String) -> Result; + async fn create_new_epoch(&mut self, source: String, data: String) -> StorageResult; + async fn update_stats_from_catalog( &self, c: CatalogSource, epoch_id: EpochId, - ) -> Result<(), ()>; + ) -> StorageResult<()>; + // i32 in `stats:i32` is a placeholder for the stats type - async fn update_stats(&self, stats: i32, epoch_id: EpochId) -> Result<(), ()>; - async fn store_cost(&self, expr_id: ExprId, cost: i32, epoch_id: EpochId) -> Result<(), DbErr>; - // table_id, attr_id OR expr_id and return a vector? - async fn get_stats_analysis( + async fn update_stats(&self, stats: i32, epoch_id: EpochId) -> StorageResult<()>; + + async fn store_cost(&self, expr_id: ExprId, cost: i32, epoch_id: EpochId) -> StorageResult<()>; + + /// Get the statistics for a given table. + /// + /// If `epoch_id` is None, it will return the latest statistics. + async fn get_stats_for_table( &self, table_id: i32, - attr_id: Option, + stat_type: i32, + epoch_id: Option, + ) -> StorageResult>; + + /// Get the statistics for a given attribute. + /// + /// If `epoch_id` is None, it will return the latest statistics. + async fn get_stats_for_attr( + &self, + attr_id: i32, + stat_type: i32, + epoch_id: Option, + ) -> StorageResult>; + + /// Get the joint statistics for a list of attributes. + /// + /// If `epoch_id` is None, it will return the latest statistics. + async fn get_stats_for_attrs( + &self, + attr_ids: Vec, + stat_type: i32, + epoch_id: Option, + ) -> StorageResult>; + + async fn get_cost_analysis( + &self, + expr_id: ExprId, epoch_id: EpochId, - ) -> Option; - async fn get_stats(&self, table_id: i32, attr_id: Option) -> Option; - async fn get_cost_analysis(&self, expr_id: ExprId, epoch_id: EpochId) -> Option; - async fn get_cost(&self, expr_id: ExprId) -> Option; + ) -> StorageResult>; + + async fn get_cost(&self, expr_id: ExprId) -> StorageResult>; async fn get_group_winner_from_group_id( &self, group_id: i32, - ) -> Option; + ) -> StorageResult>; /// Add an expression to the memo table. If the expression already exists, it will return the existing group id and /// expr id. Otherwise, a new group and expr will be created. - async fn add_new_expr(&mut self, expr: Expression) -> (GroupId, ExprId); + async fn add_new_expr(&mut self, expr: Expression) -> StorageResult<(GroupId, ExprId)>; /// Add a new expression to an existing group. If the expression is a group, it will merge the two groups. Otherwise, /// it will add the expression to the group. Returns the expr id if the expression is not a group. - async fn add_expr_to_group(&mut self, expr: Expression, group_id: GroupId) -> Option; + async fn add_expr_to_group( + &mut self, + expr: Expression, + group_id: GroupId, + ) -> StorageResult>; /// Get the group id of an expression. /// The group id is volatile, depending on whether the groups are merged. - async fn get_group_id(&self, expr_id: ExprId) -> GroupId; + async fn get_group_id(&self, expr_id: ExprId) -> StorageResult; /// Get the memoized representation of a node. - async fn get_expr_memoed(&self, expr_id: ExprId) -> Expression; + async fn get_expr_memoed(&self, expr_id: ExprId) -> StorageResult; /// Get all groups IDs in the memo table. - async fn get_all_group_ids(&self) -> Vec; + async fn get_all_group_ids(&self) -> StorageResult>; /// Get a group by ID - async fn get_group(&self, group_id: GroupId) -> cascades_group::ActiveModel; + async fn get_group(&self, group_id: GroupId) -> StorageResult; /// Update the group winner. - async fn update_group_winner(&mut self, group_id: GroupId, latest_winner: Option); + async fn update_group_winner( + &mut self, + group_id: GroupId, + latest_winner: Option, + ) -> StorageResult<()>; // The below functions can be overwritten by the memo table implementation if there // are more efficient way to retrieve the information. /// Get all expressions in the group. - async fn get_all_exprs_in_group(&self, group_id: GroupId) -> Vec; + async fn get_all_exprs_in_group(&self, group_id: GroupId) -> StorageResult>; /// Get winner info for a group id - async fn get_group_info(&self, group_id: GroupId) -> &Option; + async fn get_group_info(&self, group_id: GroupId) -> StorageResult<&Option>; // TODO: /// Get the best group binding based on the cost @@ -125,8 +166,11 @@ pub trait StorageLayer { // }; /// Get all bindings of a predicate group. Will panic if the group contains more than one bindings. - async fn get_predicate_binding(&self, group_id: GroupId) -> Option; + async fn get_predicate_binding(&self, group_id: GroupId) -> StorageResult>; /// Get all bindings of a predicate group. Returns None if the group contains zero or more than one bindings. - async fn try_get_predicate_binding(&self, group_id: GroupId) -> Option; + async fn try_get_predicate_binding( + &self, + group_id: GroupId, + ) -> StorageResult>; }