Skip to content

Commit

Permalink
fix(cost-model): add more cost types and raw estimated statistic into…
Browse files Browse the repository at this point in the history
… ORM (#30)

* fix(cost-model): add more cost types and raw estimated statistic into ORM

* Add comments
  • Loading branch information
lanlou1554 authored Nov 13, 2024
1 parent 1dd21ac commit db8829d
Show file tree
Hide file tree
Showing 9 changed files with 107 additions and 142 deletions.
3 changes: 2 additions & 1 deletion optd-persistent/src/bin/init.rs
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,8 @@ async fn init_all_tables() -> Result<(), sea_orm::error::DbErr> {
id: Set(1),
physical_expression_id: Set(1),
epoch_id: Set(1),
cost: Set(10),
cost: Set(json!({"compute_cost":10, "io_cost":10})),
estimated_statistic: Set(10),
is_valid: Set(true),
};
plan_cost::Entity::insert(plan_cost)
Expand Down
15 changes: 12 additions & 3 deletions optd-persistent/src/cost_model/interface.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,15 @@ pub struct Stat {
pub name: String,
}

/// Cost of a physical expression as exchanged with the cost-model storage layer.
///
/// This is the value type taken by `store_cost` and returned by `get_cost` /
/// `get_cost_analysis`. In the ORM backend, `compute_cost` and `io_cost` are
/// persisted together as a JSON object with the keys `"compute_cost"` and
/// `"io_cost"`, while `estimated_statistic` is written to its own column.
#[derive(Clone, Debug, PartialEq)]
pub struct Cost {
/// Compute component of the cost.
/// NOTE(review): the unit/scale is not defined in this file — confirm with the cost model.
pub compute_cost: i32,
/// I/O component of the cost.
/// NOTE(review): the unit/scale is not defined in this file — confirm with the cost model.
pub io_cost: i32,
/// Raw estimated output row count of the targeted expression.
pub estimated_statistic: i32,
}

/// TODO: documentation
#[trait_variant::make(Send)]
pub trait CostModelStorageLayer {
Expand All @@ -91,7 +100,7 @@ pub trait CostModelStorageLayer {
async fn store_cost(
&self,
expr_id: Self::ExprId,
cost: i32,
cost: Cost,
epoch_id: Self::EpochId,
) -> StorageResult<()>;

Expand Down Expand Up @@ -126,7 +135,7 @@ pub trait CostModelStorageLayer {
&self,
expr_id: Self::ExprId,
epoch_id: Self::EpochId,
) -> StorageResult<Option<i32>>;
) -> StorageResult<Option<Cost>>;

async fn get_cost(&self, expr_id: Self::ExprId) -> StorageResult<Option<i32>>;
async fn get_cost(&self, expr_id: Self::ExprId) -> StorageResult<Option<Cost>>;
}
101 changes: 82 additions & 19 deletions optd-persistent/src/cost_model/orm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

use std::ptr::null;

use crate::cost_model::interface::Cost;
use crate::entities::{prelude::*, *};
use crate::{BackendError, BackendManager, CostModelError, CostModelStorageLayer, StorageResult};
use sea_orm::prelude::{Expr, Json};
Expand All @@ -11,6 +12,7 @@ use sea_orm::{
ActiveModelTrait, ColumnTrait, DbBackend, DbErr, DeleteResult, EntityOrSelect, ModelTrait,
QueryFilter, QueryOrder, QuerySelect, QueryTrait, RuntimeErr, TransactionTrait,
};
use serde_json::json;

use super::catalog::mock_catalog::{self, MockCatalog};
use super::interface::{CatalogSource, EpochOption, Stat};
Expand Down Expand Up @@ -443,33 +445,41 @@ impl CostModelStorageLayer for BackendManager {
&self,
expr_id: Self::ExprId,
epoch_id: Self::EpochId,
) -> StorageResult<Option<i32>> {
) -> StorageResult<Option<Cost>> {
let cost = PlanCost::find()
.filter(plan_cost::Column::PhysicalExpressionId.eq(expr_id))
.filter(plan_cost::Column::EpochId.eq(epoch_id))
.one(&self.db)
.await?;
assert!(cost.is_some(), "Cost not found in Cost table");
assert!(cost.clone().unwrap().is_valid, "Cost is not valid");
Ok(cost.map(|c| c.cost))
Ok(cost.map(|c| Cost {
compute_cost: c.cost.get("compute_cost").unwrap().as_i64().unwrap() as i32,
io_cost: c.cost.get("io_cost").unwrap().as_i64().unwrap() as i32,
estimated_statistic: c.estimated_statistic,
}))
}

/// Returns the cost recorded for `expr_id` in the row with the highest epoch id.
///
/// Queries `PlanCost` for rows whose physical expression id equals `expr_id`,
/// orders them by epoch id descending and takes the first one. The JSON `cost`
/// column is decoded into `compute_cost` / `io_cost`, and `estimated_statistic`
/// is copied from its own column.
///
/// # Errors
/// Propagates the database error if the query fails.
///
/// # Panics
/// - If no row exists for `expr_id` ("Cost not found in Cost table").
/// - If the selected row has `is_valid == false` ("Cost is not valid").
/// - If the JSON value lacks the `"compute_cost"` or `"io_cost"` key, or either
///   value cannot be read as an `i64`.
///
/// NOTE(review): because of the `is_some()` assertion, `Ok(None)` is never
/// actually returned — a missing row panics instead. Confirm whether callers
/// expect `None` here.
async fn get_cost(&self, expr_id: Self::ExprId) -> StorageResult<Option<i32>> {
async fn get_cost(&self, expr_id: Self::ExprId) -> StorageResult<Option<Cost>> {
let cost = PlanCost::find()
.filter(plan_cost::Column::PhysicalExpressionId.eq(expr_id))
// Highest epoch id first, so `.one` yields the latest stored cost.
.order_by_desc(plan_cost::Column::EpochId)
.one(&self.db)
.await?;
assert!(cost.is_some(), "Cost not found in Cost table");
// The clone is only needed because `cost` is consumed by `map` below.
assert!(cost.clone().unwrap().is_valid, "Cost is not valid");
Ok(cost.map(|c| c.cost))
Ok(cost.map(|c| Cost {
// NOTE(review): `as i32` silently truncates values outside the i32 range.
compute_cost: c.cost.get("compute_cost").unwrap().as_i64().unwrap() as i32,
io_cost: c.cost.get("io_cost").unwrap().as_i64().unwrap() as i32,
estimated_statistic: c.estimated_statistic,
}))
}

/// TODO: documentation
async fn store_cost(
&self,
physical_expression_id: Self::ExprId,
cost: i32,
cost: Cost,
epoch_id: Self::EpochId,
) -> StorageResult<()> {
let expr_exists = PhysicalExpression::find_by_id(physical_expression_id)
Expand All @@ -496,7 +506,10 @@ impl CostModelStorageLayer for BackendManager {
let new_cost = plan_cost::ActiveModel {
physical_expression_id: sea_orm::ActiveValue::Set(physical_expression_id),
epoch_id: sea_orm::ActiveValue::Set(epoch_id),
cost: sea_orm::ActiveValue::Set(cost),
cost: sea_orm::ActiveValue::Set(
json!({"compute_cost": cost.compute_cost, "io_cost": cost.io_cost}),
),
estimated_statistic: sea_orm::ActiveValue::Set(cost.estimated_statistic),
is_valid: sea_orm::ActiveValue::Set(true),
..Default::default()
};
Expand All @@ -507,7 +520,7 @@ impl CostModelStorageLayer for BackendManager {

#[cfg(test)]
mod tests {
use crate::cost_model::interface::{EpochOption, StatType};
use crate::cost_model::interface::{Cost, EpochOption, StatType};
use crate::{cost_model::interface::Stat, migrate, CostModelStorageLayer};
use crate::{get_sqlite_url, TEST_DATABASE_FILE};
use sea_orm::sqlx::database;
Expand Down Expand Up @@ -681,7 +694,17 @@ mod tests {
.await
.unwrap();
backend_manager
.store_cost(expr_id, 42, versioned_stat_res[0].epoch_id)
.store_cost(
expr_id,
{
Cost {
compute_cost: 42,
io_cost: 42,
estimated_statistic: 42,
}
},
versioned_stat_res[0].epoch_id,
)
.await
.unwrap();
let cost_res = PlanCost::find()
Expand Down Expand Up @@ -744,7 +767,7 @@ mod tests {
.await
.unwrap();
assert_eq!(cost_res.len(), 1);
assert_eq!(cost_res[0].cost, 42);
assert_eq!(cost_res[0].cost, json!({"compute_cost": 42, "io_cost": 42}));
assert_eq!(cost_res[0].epoch_id, epoch_id1);
assert!(!cost_res[0].is_valid);

Expand Down Expand Up @@ -875,9 +898,13 @@ mod tests {
.await
.unwrap();
let physical_expression_id = 1;
let cost = 42;
let cost = Cost {
compute_cost: 42,
io_cost: 42,
estimated_statistic: 42,
};
backend_manager
.store_cost(physical_expression_id, cost, epoch_id)
.store_cost(physical_expression_id, cost.clone(), epoch_id)
.await
.unwrap();
let costs = super::PlanCost::find()
Expand All @@ -887,7 +914,14 @@ mod tests {
assert_eq!(costs.len(), 2); // The first row one is the initialized data
assert_eq!(costs[1].epoch_id, epoch_id);
assert_eq!(costs[1].physical_expression_id, physical_expression_id);
assert_eq!(costs[1].cost, cost);
assert_eq!(
costs[1].cost,
json!({"compute_cost": cost.compute_cost, "io_cost": cost.io_cost})
);
assert_eq!(
costs[1].estimated_statistic as i32,
cost.estimated_statistic
);

remove_db_file(DATABASE_FILE);
}
Expand All @@ -903,9 +937,13 @@ mod tests {
.await
.unwrap();
let physical_expression_id = 1;
let cost = 42;
let cost = Cost {
compute_cost: 42,
io_cost: 42,
estimated_statistic: 42,
};
let _ = backend_manager
.store_cost(physical_expression_id, cost, epoch_id)
.store_cost(physical_expression_id, cost.clone(), epoch_id)
.await;
let costs = super::PlanCost::find()
.all(&backend_manager.db)
Expand All @@ -914,7 +952,14 @@ mod tests {
assert_eq!(costs.len(), 2); // The first row one is the initialized data
assert_eq!(costs[1].epoch_id, epoch_id);
assert_eq!(costs[1].physical_expression_id, physical_expression_id);
assert_eq!(costs[1].cost, cost);
assert_eq!(
costs[1].cost,
json!({"compute_cost": cost.compute_cost, "io_cost": cost.io_cost})
);
assert_eq!(
costs[1].estimated_statistic as i32,
cost.estimated_statistic
);

let res = backend_manager
.get_cost(physical_expression_id)
Expand All @@ -936,9 +981,13 @@ mod tests {
.await
.unwrap();
let physical_expression_id = 1;
let cost = 42;
let cost = Cost {
compute_cost: 1420,
io_cost: 42,
estimated_statistic: 42,
};
let _ = backend_manager
.store_cost(physical_expression_id, cost, epoch_id)
.store_cost(physical_expression_id, cost.clone(), epoch_id)
.await;
let costs = super::PlanCost::find()
.all(&backend_manager.db)
Expand All @@ -947,7 +996,14 @@ mod tests {
assert_eq!(costs.len(), 2); // The first row one is the initialized data
assert_eq!(costs[1].epoch_id, epoch_id);
assert_eq!(costs[1].physical_expression_id, physical_expression_id);
assert_eq!(costs[1].cost, cost);
assert_eq!(
costs[1].cost,
json!({"compute_cost": cost.compute_cost, "io_cost": cost.io_cost})
);
assert_eq!(
costs[1].estimated_statistic as i32,
cost.estimated_statistic
);
println!("{:?}", costs);

// Retrieve physical_expression_id 1 and epoch_id 1
Expand All @@ -957,7 +1013,14 @@ mod tests {
.unwrap();

// The cost in the dummy data is 10
assert_eq!(res.unwrap(), 10);
assert_eq!(
res.unwrap(),
Cost {
compute_cost: 10,
io_cost: 10,
estimated_statistic: 10,
}
);

remove_db_file(DATABASE_FILE);
}
Expand Down
Binary file modified optd-persistent/src/db/init.db
Binary file not shown.
68 changes: 0 additions & 68 deletions optd-persistent/src/entities/constraint.rs

This file was deleted.

48 changes: 0 additions & 48 deletions optd-persistent/src/entities/index.rs

This file was deleted.

3 changes: 2 additions & 1 deletion optd-persistent/src/entities/plan_cost.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ pub struct Model {
pub id: i32,
pub physical_expression_id: i32,
pub epoch_id: i32,
pub cost: i32,
pub cost: Json,
pub estimated_statistic: i32,
pub is_valid: bool,
}

Expand Down
Loading

0 comments on commit db8829d

Please sign in to comment.