Skip to content

Commit

Permalink
introduce StorageResult and make methods return StorageResult
Browse files Browse the repository at this point in the history
  • Loading branch information
xx01cyx authored and lanlou1554 committed Nov 6, 2024
1 parent 0aef9ef commit 436ffd4
Show file tree
Hide file tree
Showing 2 changed files with 125 additions and 206 deletions.
241 changes: 58 additions & 183 deletions optd-persistent/src/orm_manager.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
#![allow(dead_code, unused_imports, unused_variables)]

use crate::entities::{prelude::*, *};
use crate::orm_manager::{Cost, Event};
use crate::storage_layer::{self, EpochId, StorageLayer};
use crate::entities::physical_expression;
use crate::storage_layer::{self, EpochId, StorageLayer, StorageResult};
use crate::DATABASE_URL;
use sea_orm::DatabaseConnection;
use sea_orm::*;
use sqlx::types::chrono::Utc;
use sea_orm_migration::prelude::*;

pub struct ORMManager {
db_conn: DatabaseConnection,
Expand All @@ -31,279 +31,154 @@ impl StorageLayer for ORMManager {
&mut self,
source: String,
data: String,
) -> Result<storage_layer::EpochId, ()> {
let new_event = event::ActiveModel {
source_variant: sea_orm::ActiveValue::Set(source),
create_timestamp: sea_orm::ActiveValue::Set(Utc::now()),
data: sea_orm::ActiveValue::Set(sea_orm::JsonValue::String(data)),
..Default::default()
};
let res = Event::insert(new_event).exec(&self.db_conn).await;
match res {
Ok(insert_res) => {
self.latest_epoch_id = insert_res.last_insert_id;
Ok(self.latest_epoch_id)
}
Err(_) => Err(()),
}
) -> StorageResult<storage_layer::EpochId> {
todo!()
}

async fn update_stats_from_catalog(
&self,
c: storage_layer::CatalogSource,
epoch_id: storage_layer::EpochId,
) -> Result<(), ()> {
) -> StorageResult<()> {
todo!()
}

async fn update_stats(&self, stats: i32, epoch_id: storage_layer::EpochId) -> Result<(), ()> {
async fn update_stats(
&self,
stats: i32,
epoch_id: storage_layer::EpochId,
) -> StorageResult<()> {
todo!()
}

async fn get_stats_analysis(
async fn store_cost(
&self,
table_id: i32,
attr_id: Option<i32>,
expr_id: storage_layer::ExprId,
cost: i32,
epoch_id: storage_layer::EpochId,
) -> Option<i32> {
) -> StorageResult<()> {
todo!()
}

async fn get_stats(&self, table_id: i32, attr_id: Option<i32>) -> Option<i32> {
async fn get_stats_for_table(
&self,
table_id: i32,
stat_type: i32,
epoch_id: Option<EpochId>,
) -> StorageResult<Option<f32>> {
todo!()
}

async fn store_cost(
async fn get_stats_for_attr(
&self,
expr_id: storage_layer::ExprId,
cost: i32,
epoch_id: storage_layer::EpochId,
) -> Result<(), DbErr> {
// TODO: update PhysicalExpression and Event tables
// Check if expr_id exists in PhysicalExpression table
let expr_exists = PhysicalExpression::find_by_id(expr_id)
.one(&self.db_conn)
.await?;
if expr_exists.is_none() {
return Err(DbErr::RecordNotFound(
"ExprId not found in PhysicalExpression table".to_string(),
));
}

// Check if epoch_id exists in Event table
let epoch_exists = Event::find()
.filter(event::Column::EpochId.eq(epoch_id))
.one(&self.db_conn)
.await
.unwrap();
attr_id: i32,
stat_type: i32,
epoch_id: Option<EpochId>,
) -> StorageResult<Option<f32>> {
todo!()
}

let new_cost = cost::ActiveModel {
expr_id: ActiveValue::Set(expr_id),
epoch_id: ActiveValue::Set(epoch_id),
cost: ActiveValue::Set(cost),
valid: ActiveValue::Set(true),
..Default::default()
};
let res = Cost::insert(new_cost).exec(&self.db_conn).await;
match res {
Ok(_) => Ok(()),
Err(e) => Err(DbErr::Custom(e.to_string())),
}
async fn get_stats_for_attrs(
&self,
attr_ids: Vec<i32>,
stat_type: i32,
epoch_id: Option<EpochId>,
) -> StorageResult<Option<f32>> {
todo!()
}

async fn get_cost_analysis(
&self,
expr_id: storage_layer::ExprId,
epoch_id: storage_layer::EpochId,
) -> Option<i32> {
let cost = Cost::find()
.filter(cost::Column::ExprId.eq(expr_id))
.filter(cost::Column::EpochId.eq(epoch_id))
.one(&self.db_conn)
.await
.unwrap();
assert!(cost.is_some(), "Cost not found in Cost table");
assert!(cost.clone().unwrap().valid, "Cost is not valid");
cost.map(|c| c.cost)
) -> StorageResult<Option<i32>> {
todo!()
}

/// Get the latest cost for an expression
async fn get_cost(&self, expr_id: storage_layer::ExprId) -> Option<i32> {
let cost = Cost::find()
.filter(cost::Column::ExprId.eq(expr_id))
.order_by_desc(cost::Column::EpochId)
.one(&self.db_conn)
.await
.unwrap();
assert!(cost.is_some(), "Cost not found in Cost table");
assert!(cost.clone().unwrap().valid, "Cost is not valid");
cost.map(|c| c.cost)
async fn get_cost(&self, expr_id: storage_layer::ExprId) -> StorageResult<Option<i32>> {
todo!()
}

async fn get_group_winner_from_group_id(
&self,
group_id: i32,
) -> Option<physical_expression::ActiveModel> {
) -> StorageResult<Option<physical_expression::ActiveModel>> {
todo!()
}

async fn add_new_expr(
&mut self,
expr: storage_layer::Expression,
) -> (storage_layer::GroupId, storage_layer::ExprId) {
) -> StorageResult<(storage_layer::GroupId, storage_layer::ExprId)> {
todo!()
}

async fn add_expr_to_group(
&mut self,
expr: storage_layer::Expression,
group_id: storage_layer::GroupId,
) -> Option<storage_layer::ExprId> {
) -> StorageResult<Option<storage_layer::ExprId>> {
todo!()
}

async fn get_group_id(&self, expr_id: storage_layer::ExprId) -> storage_layer::GroupId {
async fn get_group_id(
&self,
expr_id: storage_layer::ExprId,
) -> StorageResult<storage_layer::GroupId> {
todo!()
}

async fn get_expr_memoed(&self, expr_id: storage_layer::ExprId) -> storage_layer::Expression {
async fn get_expr_memoed(
&self,
expr_id: storage_layer::ExprId,
) -> StorageResult<storage_layer::Expression> {
todo!()
}

async fn get_all_group_ids(&self) -> Vec<storage_layer::GroupId> {
async fn get_all_group_ids(&self) -> StorageResult<Vec<storage_layer::GroupId>> {
todo!()
}

async fn get_group(
&self,
group_id: storage_layer::GroupId,
) -> crate::entities::cascades_group::ActiveModel {
) -> StorageResult<crate::entities::cascades_group::ActiveModel> {
todo!()
}

async fn update_group_winner(
&mut self,
group_id: storage_layer::GroupId,
latest_winner: Option<storage_layer::ExprId>,
) {
) -> StorageResult<()> {
todo!()
}

async fn get_all_exprs_in_group(
&self,
group_id: storage_layer::GroupId,
) -> Vec<storage_layer::ExprId> {
) -> StorageResult<Vec<storage_layer::ExprId>> {
todo!()
}

async fn get_group_info(
&self,
group_id: storage_layer::GroupId,
) -> &Option<storage_layer::ExprId> {
) -> StorageResult<&Option<storage_layer::ExprId>> {
todo!()
}

async fn get_predicate_binding(
&self,
group_id: storage_layer::GroupId,
) -> Option<storage_layer::Expression> {
) -> StorageResult<Option<storage_layer::Expression>> {
todo!()
}

async fn try_get_predicate_binding(
&self,
group_id: storage_layer::GroupId,
) -> Option<storage_layer::Expression> {
) -> StorageResult<Option<storage_layer::Expression>> {
todo!()
}
}

#[cfg(test)]
mod tests {
use crate::migrate;
use sea_orm::{ConnectionTrait, Database, EntityTrait, ModelTrait};
use serde_json::de;

use crate::entities::event::Entity as Event;
use crate::storage_layer::StorageLayer;
use crate::TEST_DATABASE_URL;

async fn run_migration() {
let _ = std::fs::remove_file(TEST_DATABASE_URL);

let db = Database::connect(TEST_DATABASE_URL)
.await
.expect("Unable to connect to the database");

migrate(&db)
.await
.expect("Something went wrong during migration");

db.execute(sea_orm::Statement::from_string(
sea_orm::DatabaseBackend::Sqlite,
"PRAGMA foreign_keys = ON;".to_owned(),
))
.await
.expect("Unable to enable foreign keys");
}

#[tokio::test]
async fn test_create_new_epoch() {
run_migration().await;
let mut orm_manager = super::ORMManager::new(Some(TEST_DATABASE_URL)).await;
let res = orm_manager
.create_new_epoch("source".to_string(), "data".to_string())
.await;
println!("{:?}", res);
assert!(res.is_ok());
assert_eq!(
super::Event::find()
.all(&orm_manager.db_conn)
.await
.unwrap()
.len(),
1
);
println!(
"{:?}",
super::Event::find()
.all(&orm_manager.db_conn)
.await
.unwrap()[0]
);
assert_eq!(
super::Event::find()
.all(&orm_manager.db_conn)
.await
.unwrap()[0]
.epoch_id,
res.unwrap()
);
}

#[tokio::test]
#[ignore] // Need to update all tables
async fn test_store_cost() {
run_migration().await;
let mut orm_manager = super::ORMManager::new(Some(TEST_DATABASE_URL)).await;
let epoch_id = orm_manager
.create_new_epoch("source".to_string(), "data".to_string())
.await
.unwrap();
let expr_id = 1;
let cost = 42;
let res = orm_manager.store_cost(expr_id, cost, epoch_id).await;
match res {
Ok(_) => assert!(true),
Err(e) => {
println!("Error: {:?}", e);
assert!(false);
}
}
let costs = super::Cost::find().all(&orm_manager.db_conn).await.unwrap();
assert_eq!(costs.len(), 1);
assert_eq!(costs[0].epoch_id, epoch_id);
assert_eq!(costs[0].expr_id, expr_id);
assert_eq!(costs[0].cost, cost);
}
}
Loading

0 comments on commit 436ffd4

Please sign in to comment.