Skip to content

Commit

Permalink
Fix indexes on Postgres
Browse files Browse the repository at this point in the history
  • Loading branch information
m1guelpf committed Sep 2, 2023
1 parent 6aad029 commit fee71c7
Show file tree
Hide file tree
Showing 6 changed files with 193 additions and 32 deletions.
29 changes: 29 additions & 0 deletions ensemble/src/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,3 +80,32 @@ pub async fn get() -> Result<Connection, ConnectError> {
Some(rb) => Ok(rb.get_pool()?.get().await?),
}
}

#[cfg(any(feature = "mysql", feature = "postgres"))]
pub enum Database {
MySQL,
PostgreSQL,
}

#[cfg(any(feature = "mysql", feature = "postgres"))]
impl Database {
pub fn is_mysql(&self) -> bool {
matches!(self, Database::MySQL)
}

pub fn is_postgres(&self) -> bool {
matches!(self, Database::PostgreSQL)
}
}

#[cfg(any(feature = "mysql", feature = "postgres"))]
pub const fn which_db() -> Database {
#[cfg(all(feature = "mysql", feature = "postgres"))]
panic!("Both the `mysql` and `postgres` features are enabled. Please enable only one of them.");

if cfg!(feature = "mysql") {
Database::MySQL
} else {
Database::PostgreSQL
}
}
2 changes: 2 additions & 0 deletions ensemble/src/migrations/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@ use std::fmt::Debug;
use crate::connection::ConnectError;

pub use migrator::Migrator;
#[cfg(any(feature = "mysql", feature = "postgres"))]
pub use schema::Schema;

mod migrator;
#[cfg(any(feature = "mysql", feature = "postgres"))]
mod schema;

#[derive(Debug, thiserror::Error)]
Expand Down
15 changes: 11 additions & 4 deletions ensemble/src/migrations/schema/column.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use rbs::Value;
use std::{fmt::Display, sync::mpsc};

use super::Schemable;
use crate::connection;

#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Type {
Expand Down Expand Up @@ -114,7 +115,16 @@ impl Column {
}

pub(crate) fn to_sql(&self) -> String {
let mut sql = format!("{} {}", self.name, self.r#type);
let db_type = if connection::which_db().is_postgres()
&& self.r#type == Type::BigInteger
&& self.auto_increment
{
"bigserial".to_string()
} else {
self.r#type.to_string()
};

let mut sql = format!("{} {db_type}", self.name);

#[cfg(feature = "mysql")]
if self.unsigned {
Expand Down Expand Up @@ -162,9 +172,6 @@ impl Column {
if self.auto_increment {
#[cfg(feature = "mysql")]
sql.push_str(" AUTO_INCREMENT");

#[cfg(feature = "postgres")]
sql.push_str(" SERIAL");
}

if let Some(index) = &self.index {
Expand Down
47 changes: 33 additions & 14 deletions ensemble/src/migrations/schema/command.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,24 +2,23 @@ use std::{fmt::Display, sync::mpsc};

use ensemble_derive::Column;

use crate::connection::{self, Database};

use super::Schemable;

#[derive(Debug)]
pub struct Command {
sql: String,
}

impl Command {
pub fn to_sql(&self) -> String {
self.sql.clone()
}
pub(crate) inline_sql: String,
pub(crate) post_sql: Option<String>,
}

#[derive(Debug, Clone, Column)]
#[allow(dead_code)]
pub struct ForeignIndex {
#[builder(init)]
column: String,
#[builder(init)]
origin_table: String,
name: Option<String>,
#[builder(rename = "references")]
foreign_column: Option<String>,
Expand All @@ -35,7 +34,7 @@ pub struct ForeignIndex {
}

impl ForeignIndex {
fn to_sql(&self) -> String {
fn to_sql(&self) -> (String, Option<String>) {
let foreign_column = &self
.foreign_column
.as_ref()
Expand All @@ -46,9 +45,15 @@ impl ForeignIndex {
ToString::to_string,
);

let mut sql = format!(
"KEY {index_name} ({}), CONSTRAINT {index_name} FOREIGN KEY ({}) REFERENCES {}({foreign_column})", self.column, self.column, self.table,
);
let mut sql = match connection::which_db() {
Database::MySQL => format!(
"KEY {index_name} ({}), CONSTRAINT {index_name} FOREIGN KEY ({}) REFERENCES {}({foreign_column})", self.column, self.column, self.table,
),
Database::PostgreSQL => format!(
"FOREIGN KEY ({}) REFERENCES {}({foreign_column})",
self.column, self.table,
)
};

if let Some(on_delete) = &self.on_delete {
sql.push_str(&format!(" ON DELETE {on_delete}"));
Expand All @@ -58,7 +63,16 @@ impl ForeignIndex {
sql.push_str(&format!(" ON UPDATE {on_update}"));
}

sql
match connection::which_db() {
Database::MySQL => (sql, None),
Database::PostgreSQL => (
sql,
Some(format!(
"CREATE INDEX {index_name} ON {}({});",
self.origin_table, self.column
)),
),
}
}
}

Expand All @@ -67,8 +81,13 @@ impl ForeignIndex {
impl Drop for ForeignIndex {
fn drop(&mut self) {
if let Some(tx) = self.tx.take() {
tx.send(Schemable::Command(Command { sql: self.to_sql() }))
.unwrap();
let (inline_sql, post_sql) = self.to_sql();

tx.send(Schemable::Command(Command {
inline_sql,
post_sql,
}))
.unwrap();
drop(tx);
}
}
Expand Down
30 changes: 19 additions & 11 deletions ensemble/src/migrations/schema/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,14 @@ use self::{
column::{Column, Type},
command::{Command, ForeignIndex},
};

use super::{migrator::MIGRATE_CONN, Error};
use crate::{connection, Model};

mod column;
mod command;

pub struct Schema {}

pub enum Schemable {
Column(Column),
Command(Command),
Expand All @@ -31,23 +31,27 @@ impl Schema {
where
F: FnOnce(&mut Table) + Send,
{
let (columns, commands) = Self::get_schema(callback)?;
let (columns, commands) = Self::get_schema(table_name.to_string(), callback)?;
let mut conn_lock = MIGRATE_CONN.try_lock().map_err(|_| Error::Lock)?;
let mut conn = conn_lock.take().ok_or(Error::Lock)?;

let sql = format!(
"CREATE TABLE {} ({}) {}",
"CREATE TABLE {} ({}) {}; {}",
table_name,
columns
.iter()
.map(Column::to_sql)
.chain(commands.iter().map(Command::to_sql))
.chain(commands.iter().map(|cmd| cmd.inline_sql.clone()))
.join(", "),
if cfg!(feature = "mysql") {
if connection::which_db().is_mysql() {
"ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci"
} else {
""
}
},
commands
.iter()
.filter_map(|cmd| cmd.post_sql.clone())
.join("\n")
);

tracing::debug!(sql = sql.as_str(), "Running CREATE TABLE SQL query");
Expand Down Expand Up @@ -80,12 +84,15 @@ impl Schema {
Ok(())
}

fn get_schema<F>(callback: F) -> Result<(Vec<Column>, Vec<Command>), Error>
fn get_schema<F>(table_name: String, callback: F) -> Result<(Vec<Column>, Vec<Command>), Error>
where
F: FnOnce(&mut Table),
{
let (tx, rx) = mpsc::channel();
let mut table = Table { sender: Some(tx) };
let mut table = Table {
name: table_name,
sender: Some(tx),
};

let ret = std::thread::spawn(move || {
let mut schema = vec![];
Expand Down Expand Up @@ -114,6 +121,7 @@ impl Schema {

#[derive(Debug)]
pub struct Table {
name: String,
sender: Option<mpsc::Sender<Schemable>>,
}

Expand Down Expand Up @@ -174,7 +182,7 @@ impl Table {

/// Specify a foreign key for the table.
pub fn foreign(&mut self, column: &str) -> ForeignIndex {
ForeignIndex::new(column.to_string(), self.sender.clone())
ForeignIndex::new(column.to_string(), self.name.clone(), self.sender.clone())
}

/// Create a new enum column on the table.
Expand Down Expand Up @@ -202,7 +210,7 @@ impl Table {
Column::new(column.clone(), Type::String(255), self.sender.clone());
}

let index = ForeignIndex::new(column, self.sender.clone());
let index = ForeignIndex::new(column, self.name.clone(), self.sender.clone());
index.on(M::TABLE_NAME).references(M::PRIMARY_KEY)
}

Expand All @@ -216,7 +224,7 @@ impl Table {
column.unsigned(true);
};

let index = ForeignIndex::new(name.to_string(), self.sender.clone());
let index = ForeignIndex::new(name.to_string(), self.name.clone(), self.sender.clone());

// if the column name is of the form `resource_id`, we extract and set the table name and foreign column name
if let Some((resource, column)) = name.split_once('_') {
Expand Down
Loading

0 comments on commit fee71c7

Please sign in to comment.