diff --git a/ensemble/src/connection.rs b/ensemble/src/connection.rs index 237af04..25bf510 100644 --- a/ensemble/src/connection.rs +++ b/ensemble/src/connection.rs @@ -80,3 +80,32 @@ pub async fn get() -> Result { Some(rb) => Ok(rb.get_pool()?.get().await?), } } + +#[cfg(any(feature = "mysql", feature = "postgres"))] +pub enum Database { + MySQL, + PostgreSQL, +} + +#[cfg(any(feature = "mysql", feature = "postgres"))] +impl Database { + pub fn is_mysql(&self) -> bool { + matches!(self, Database::MySQL) + } + + pub fn is_postgres(&self) -> bool { + matches!(self, Database::PostgreSQL) + } +} + +#[cfg(any(feature = "mysql", feature = "postgres"))] +pub const fn which_db() -> Database { + #[cfg(all(feature = "mysql", feature = "postgres"))] + panic!("Both the `mysql` and `postgres` features are enabled. Please enable only one of them."); + + if cfg!(feature = "mysql") { + Database::MySQL + } else { + Database::PostgreSQL + } +} diff --git a/ensemble/src/migrations/mod.rs b/ensemble/src/migrations/mod.rs index 5f601b5..e7af011 100644 --- a/ensemble/src/migrations/mod.rs +++ b/ensemble/src/migrations/mod.rs @@ -4,9 +4,11 @@ use std::fmt::Debug; use crate::connection::ConnectError; pub use migrator::Migrator; +#[cfg(any(feature = "mysql", feature = "postgres"))] pub use schema::Schema; mod migrator; +#[cfg(any(feature = "mysql", feature = "postgres"))] mod schema; #[derive(Debug, thiserror::Error)] diff --git a/ensemble/src/migrations/schema/column.rs b/ensemble/src/migrations/schema/column.rs index 8886579..452b50b 100644 --- a/ensemble/src/migrations/schema/column.rs +++ b/ensemble/src/migrations/schema/column.rs @@ -4,6 +4,7 @@ use rbs::Value; use std::{fmt::Display, sync::mpsc}; use super::Schemable; +use crate::connection; #[derive(Debug, Clone, PartialEq, Eq)] pub enum Type { @@ -114,7 +115,16 @@ impl Column { } pub(crate) fn to_sql(&self) -> String { - let mut sql = format!("{} {}", self.name, self.r#type); + let db_type = if connection::which_db().is_postgres() + && self.r#type == Type::BigInteger + && self.auto_increment + { + "bigserial".to_string() + } else { + self.r#type.to_string() + }; + + let mut sql = format!("{} {db_type}", self.name); #[cfg(feature = "mysql")] if self.unsigned { @@ -162,9 +172,6 @@ impl Column { if self.auto_increment { #[cfg(feature = "mysql")] sql.push_str(" AUTO_INCREMENT"); - - #[cfg(feature = "postgres")] - sql.push_str(" SERIAL"); } if let Some(index) = &self.index { diff --git a/ensemble/src/migrations/schema/command.rs b/ensemble/src/migrations/schema/command.rs index f19df99..9c6dbd6 100644 --- a/ensemble/src/migrations/schema/command.rs +++ b/ensemble/src/migrations/schema/command.rs @@ -2,17 +2,14 @@ use std::{fmt::Display, sync::mpsc}; use ensemble_derive::Column; +use crate::connection::{self, Database}; + use super::Schemable; #[derive(Debug)] pub struct Command { - sql: String, -} - -impl Command { - pub fn to_sql(&self) -> String { - self.sql.clone() - } + pub(crate) inline_sql: String, + pub(crate) post_sql: Option, } #[derive(Debug, Clone, Column)] @@ -20,6 +17,8 @@ impl Command { pub struct ForeignIndex { #[builder(init)] column: String, + #[builder(init)] + origin_table: String, name: Option, #[builder(rename = "references")] foreign_column: Option, @@ -35,7 +34,7 @@ pub struct ForeignIndex { } impl ForeignIndex { - fn to_sql(&self) -> String { + fn to_sql(&self) -> (String, Option) { let foreign_column = &self .foreign_column .as_ref() @@ -46,9 +45,15 @@ impl ForeignIndex { ToString::to_string, ); - let mut sql = format!( - "KEY {index_name} ({}), CONSTRAINT {index_name} FOREIGN KEY ({}) REFERENCES {}({foreign_column})", self.column, self.column, self.table, - ); + let mut sql = match connection::which_db() { + Database::MySQL => format!( + "KEY {index_name} ({}), CONSTRAINT {index_name} FOREIGN KEY ({}) REFERENCES {}({foreign_column})", self.column, self.column, self.table, + ), + Database::PostgreSQL => format!( + "FOREIGN KEY ({}) REFERENCES {}({foreign_column})", + self.column, self.table, + ) + }; if let Some(on_delete) = &self.on_delete { sql.push_str(&format!(" ON DELETE {on_delete}")); @@ -58,7 +63,16 @@ impl ForeignIndex { sql.push_str(&format!(" ON UPDATE {on_update}")); } - sql + match connection::which_db() { + Database::MySQL => (sql, None), + Database::PostgreSQL => ( + sql, + Some(format!( + "CREATE INDEX {index_name} ON {}({});", + self.origin_table, self.column + )), + ), + } } } @@ -67,8 +81,13 @@ impl ForeignIndex { impl Drop for ForeignIndex { fn drop(&mut self) { if let Some(tx) = self.tx.take() { - tx.send(Schemable::Command(Command { sql: self.to_sql() })) - .unwrap(); + let (inline_sql, post_sql) = self.to_sql(); + + tx.send(Schemable::Command(Command { + inline_sql, + post_sql, + })) + .unwrap(); drop(tx); } } diff --git a/ensemble/src/migrations/schema/mod.rs b/ensemble/src/migrations/schema/mod.rs index 2af356c..a54303d 100644 --- a/ensemble/src/migrations/schema/mod.rs +++ b/ensemble/src/migrations/schema/mod.rs @@ -7,7 +7,6 @@ use self::{ column::{Column, Type}, command::{Command, ForeignIndex}, }; - use super::{migrator::MIGRATE_CONN, Error}; use crate::{connection, Model}; @@ -15,6 +14,7 @@ mod column; mod command; pub struct Schema {} + pub enum Schemable { Column(Column), Command(Command), @@ -31,23 +31,27 @@ impl Schema { where F: FnOnce(&mut Table) + Send, { - let (columns, commands) = Self::get_schema(callback)?; + let (columns, commands) = Self::get_schema(table_name.to_string(), callback)?; let mut conn_lock = MIGRATE_CONN.try_lock().map_err(|_| Error::Lock)?; let mut conn = conn_lock.take().ok_or(Error::Lock)?; let sql = format!( - "CREATE TABLE {} ({}) {}", + "CREATE TABLE {} ({}) {}; {}", table_name, columns .iter() .map(Column::to_sql) - .chain(commands.iter().map(Command::to_sql)) + .chain(commands.iter().map(|cmd| cmd.inline_sql.clone())) .join(", "), - if cfg!(feature = "mysql") { + if connection::which_db().is_mysql() { "ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci" } else { "" - } + }, + commands + .iter() + .filter_map(|cmd| cmd.post_sql.clone()) + .join("\n") ); tracing::debug!(sql = sql.as_str(), "Running CREATE TABLE SQL query"); @@ -80,12 +84,15 @@ impl Schema { Ok(()) } - fn get_schema(callback: F) -> Result<(Vec, Vec), Error> + fn get_schema(table_name: String, callback: F) -> Result<(Vec, Vec), Error> where F: FnOnce(&mut Table), { let (tx, rx) = mpsc::channel(); - let mut table = Table { sender: Some(tx) }; + let mut table = Table { + name: table_name, + sender: Some(tx), + }; let ret = std::thread::spawn(move || { let mut schema = vec![]; @@ -114,6 +121,7 @@ impl Schema { #[derive(Debug)] pub struct Table { + name: String, sender: Option>, } @@ -174,7 +182,7 @@ impl Table { /// Specify a foreign key for the table. pub fn foreign(&mut self, column: &str) -> ForeignIndex { - ForeignIndex::new(column.to_string(), self.sender.clone()) + ForeignIndex::new(column.to_string(), self.name.clone(), self.sender.clone()) } /// Create a new enum column on the table. @@ -202,7 +210,7 @@ impl Table { Column::new(column.clone(), Type::String(255), self.sender.clone()); } - let index = ForeignIndex::new(column, self.sender.clone()); + let index = ForeignIndex::new(column, self.name.clone(), self.sender.clone()); index.on(M::TABLE_NAME).references(M::PRIMARY_KEY) } @@ -216,7 +224,7 @@ impl Table { column.unsigned(true); }; - let index = ForeignIndex::new(name.to_string(), self.sender.clone()); + let index = ForeignIndex::new(name.to_string(), self.name.clone(), self.sender.clone()); // if the column name is of the form `resource_id`, we extract and set the table name and foreign column name if let Some((resource, column)) = name.split_once('_') { diff --git a/examples/migrations/Cargo.lock b/examples/migrations/Cargo.lock index 57571f8..f114620 100644 --- a/examples/migrations/Cargo.lock +++ b/examples/migrations/Cargo.lock @@ -450,6 +450,7 @@ version = "0.0.0" dependencies = [ "ensemble", "tokio", + "tracing-subscriber", ] [[package]] @@ -825,6 +826,15 @@ version = "0.4.20" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b5e6163cb8c49088c2c36f57875e58ccd8c87c7427f7fbd50ea6710b2f3f2e8f" +[[package]] +name = "matchers" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8263075bb86c5a1b1427b5ae862e8889656f126e9f77c484496e8b47cf5c5558" +dependencies = [ + "regex-automata 0.1.10", +] + [[package]] name = "md-5" version = "0.10.5" @@ -887,6 +897,16 @@ dependencies = [ "tempfile", ] +[[package]] +name = "nu-ansi-term" +version = "0.46.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "77a8165726e8236064dbb45459242600304b42a5ea24ee2948e18e023bf7ba84" +dependencies = [ + "overload", + "winapi", +] + [[package]] name = "num-bigint" version = "0.4.3" @@ -1021,6 +1041,12 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "04744f49eae99ab78e0d5c0b603ab218f515ea8cfe5a456d7629ad883a3b6e7d" +[[package]] +name = "overload" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39" + [[package]] name = "parking_lot" version = "0.12.1" @@ -1449,8 +1475,17 @@ checksum = "81bc1d4caf89fac26a70747fe603c130093b53c773888797a6329091246d651a" dependencies = [ "aho-corasick", "memchr", - "regex-automata", - "regex-syntax", + "regex-automata 0.3.6", + "regex-syntax 0.7.4", +] + +[[package]] +name = "regex-automata" +version = "0.1.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c230d73fb8d8c1b9c0b3135c5142a8acee3a0558fb8db5cf1cb65f8d7862132" +dependencies = [ + "regex-syntax 0.6.29", ] [[package]] @@ -1461,9 +1496,15 @@ checksum = "fed1ceff11a1dddaee50c9dc8e4938bd106e9d89ae372f192311e7da498e3b69" dependencies = [ "aho-corasick", "memchr", - "regex-syntax", + "regex-syntax 0.7.4", ] +[[package]] +name = "regex-syntax" +version = "0.6.29" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f162c6dd7b008981e4d40210aca20b4bd0f9b60ca9271061b07f78537722f2e1" + [[package]] name = "regex-syntax" version = "0.7.4" @@ -1682,6 +1723,15 @@ dependencies = [ "tokio", ] +[[package]] +name = "sharded-slab" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "900fba806f70c630b0a382d0d825e17a0f19fcd059a2ade1ff237bcddf446b31" +dependencies = [ + "lazy_static", +] + [[package]] name = "signal-hook-registry" version = "1.4.1" @@ -1828,6 +1878,16 @@ dependencies = [ "syn 2.0.29", ] +[[package]] +name = "thread_local" +version = "1.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3fdd6f064ccff2d6567adcb3873ca630700f00b5ad3f060c25b5dcfd9a4ce152" +dependencies = [ + "cfg-if", + "once_cell", +] + [[package]] name = "time" version = "0.3.25" @@ -1947,6 +2007,36 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0955b8137a1df6f1a2e9a37d8a6656291ff0297c1a97c24e0d8425fe2312f79a" dependencies = [ "once_cell", + "valuable", +] + +[[package]] +name = "tracing-log" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "78ddad33d2d10b1ed7eb9d1f518a5674713876e97e5bb9b7345a7984fbb4f922" +dependencies = [ + "lazy_static", + "log", + "tracing-core", +] + +[[package]] +name = "tracing-subscriber" +version = "0.3.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "30a651bc37f915e81f087d86e62a18eec5f79550c7faff886f7090b4ea757c77" +dependencies = [ + "matchers", + "nu-ansi-term", + "once_cell", + "regex", + "sharded-slab", + "smallvec", + "thread_local", + "tracing", + "tracing-core", + "tracing-log", ] [[package]] @@ -2002,6 +2092,12 @@ dependencies = [ "getrandom", ] +[[package]] +name = "valuable" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "830b7e5d4d90034032940e4ace0d9a9a057e7a45cd94e6c007832e39edb82f6d" + [[package]] name = "vcpkg" version = "0.2.15"