Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add database create and drop commands #371

Draft
wants to merge 10 commits into
base: main
Choose a base branch
from
58 changes: 58 additions & 0 deletions advanced_alchemy/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,4 +399,62 @@ async def _dump_tables() -> None:

return run(_dump_tables)

@database_group.command(name="create", help="Create a new database.")
@bind_key_option
@no_prompt_option
@click.option(
"--encoding",
"encoding",
help="Set the encoding for the created database",
type=str,
required=False,
default="utf8",
)
def create_database(bind_key: str | None, no_prompt: bool, encoding: str) -> None: # pyright: ignore[reportUnusedFunction]
from anyio import run
from rich.prompt import Confirm

from advanced_alchemy.utils.databases import create_database as _create_database

ctx = click.get_current_context()
sqlalchemy_config = get_config_by_bind_key(ctx, bind_key)

db_name = sqlalchemy_config.get_engine().url.database

console.rule("[yellow]Starting database creation process[/]", align="left")
input_confirmed = (
True if no_prompt else Confirm.ask(f"[bold]Are you sure you want to create a new database `{db_name}`?[/]")
)
if input_confirmed:

async def _create_database_wrapper() -> None:
await _create_database(sqlalchemy_config, encoding)

run(_create_database_wrapper)

@database_group.command(name="drop", help="Drop the current database.")
@bind_key_option
@no_prompt_option
def drop_database(bind_key: str | None, no_prompt: bool) -> None: # pyright: ignore[reportUnusedFunction]
from anyio import run
from rich.prompt import Confirm

from advanced_alchemy.utils.databases import drop_database as _drop_database

ctx = click.get_current_context()
sqlalchemy_config = get_config_by_bind_key(ctx, bind_key)

db_name = sqlalchemy_config.get_engine().url.database

console.rule("[yellow]Starting database deletion process[/]", align="left")
input_confirmed = (
True if no_prompt else Confirm.ask(f"[bold]Are you sure you want to drop database `{db_name}`?[/]")
)
if input_confirmed:

async def _drop_database_wrapper() -> None:
await _drop_database(sqlalchemy_config)

run(_drop_database_wrapper)

return database_group
167 changes: 167 additions & 0 deletions advanced_alchemy/utils/databases.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
from __future__ import annotations

from pathlib import Path
from typing import TYPE_CHECKING, Protocol

from sqlalchemy import URL, Engine, create_engine, text
from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine

from advanced_alchemy.config.sync import SQLAlchemySyncConfig

if TYPE_CHECKING:
from advanced_alchemy.config.asyncio import SQLAlchemyAsyncConfig


class Adapter(Protocol):
supported_drivers: set[str] = set()
dialect: str
config: SQLAlchemyAsyncConfig | SQLAlchemySyncConfig
encoding: str | None = None

engine: Engine | None = None
async_engine: AsyncEngine | None = None
original_database_name: str | None = None

def __init__(self, config: SQLAlchemyAsyncConfig | SQLAlchemySyncConfig, encoding: str = "utf8") -> None:
self.config = config
self.encoding = encoding

if isinstance(config, SQLAlchemySyncConfig):
self.setup(config)
else:
self.setup_async(config)

def setup_async(self, config: SQLAlchemyAsyncConfig) -> None: ...
def setup(self, config: SQLAlchemySyncConfig) -> None: ...

async def create_async(self) -> None: ...
def create(self) -> None: ...

async def drop_async(self) -> None: ...
def drop(self) -> None: ...


class SQLiteAdapter(Adapter):
supported_drivers: set[str] = {"pysqlite", "aiosqlite"}
dialect: str = "sqlite"

def setup(self, config: SQLAlchemySyncConfig) -> None:
self.engine = config.get_engine()
self.original_database_name = self.engine.url.database

def setup_async(self, config: SQLAlchemyAsyncConfig) -> None:
self.async_engine = config.get_engine()
self.original_database_name = self.async_engine.url.database

def create(self) -> None:
if self.engine is not None and self.original_database_name and self.original_database_name != ":memory:":
with self.engine.begin() as conn:
conn.execute(text("CREATE TABLE DB(id int)"))
conn.execute(text("DROP TABLE DB"))

async def create_async(self) -> None:
if self.async_engine is not None:
async with self.async_engine.begin() as conn:
await conn.execute(text("CREATE TABLE DB(id int)"))
await conn.execute(text("DROP TABLE DB"))

def drop(self) -> None:
return self._drop()

async def drop_async(self) -> None:
return self._drop()

def _drop(self) -> None:
if self.original_database_name and self.original_database_name != ":memory:":
Path(self.original_database_name).unlink()


class PostgresAdapter(Adapter):
supported_drivers: set[str] = {"asyncpg", "pg8000", "psycopg", "psycopg2", "psycopg2cffi"}
dialect: str = "postgresql"

def _set_url(self, config: SQLAlchemyAsyncConfig | SQLAlchemySyncConfig) -> URL:
original_url = self.config.get_engine().url
self.original_database_name = original_url.database
return original_url._replace(database="postgres")

def setup(self, config: SQLAlchemySyncConfig) -> None:
updated_url = self._set_url(config)
self.engine = create_engine(updated_url, isolation_level="AUTOCOMMIT")

def setup_async(self, config: SQLAlchemyAsyncConfig) -> None:
updated_url = self._set_url(config)
self.async_engine = create_async_engine(updated_url, isolation_level="AUTOCOMMIT")

def create(self) -> None:
if self.engine:
with self.engine.begin() as conn:
sql = f"CREATE DATABASE {self.original_database_name} ENCODING '{self.encoding}'"
conn.execute(text(sql))

async def create_async(self) -> None:
if self.async_engine:
async with self.async_engine.begin() as conn:
sql = f"CREATE DATABASE {self.original_database_name} ENCODING '{self.encoding}'"
await conn.execute(text(sql))

def drop(self) -> None:
if self.engine:
with self.engine.begin() as conn:
# Disconnect all users from the database we are dropping.
version = conn.dialect.server_version_info
sql = self._disconnect_users_sql(version, self.original_database_name)
conn.execute(text(sql))
conn.execute(text(f"DROP DATABASE {self.original_database_name}"))

async def drop_async(self) -> None:
if self.async_engine:
async with self.async_engine.begin() as conn:
# Disconnect all users from the database we are dropping.
version = conn.dialect.server_version_info
sql = self._disconnect_users_sql(version, self.original_database_name)
await conn.execute(text(sql))
await conn.execute(text(f"DROP DATABASE {self.original_database_name}"))

def _disconnect_users_sql(self, version: tuple[int, int] | None, database: str | None) -> str:
pid_column = ("pid" if version >= (9, 2) else "procpid") if version else "procpid"
return f"""
SELECT pg_terminate_backend(pg_stat_activity.{pid_column})
FROM pg_stat_activity
WHERE pg_stat_activity.datname = '{database}'
AND {pid_column} <> pg_backend_pid();""" # noqa: S608


ADAPTERS = {"sqlite": SQLiteAdapter, "postgresql": PostgresAdapter}


def get_adapter(config: SQLAlchemyAsyncConfig | SQLAlchemySyncConfig, encoding: str = "utf8") -> Adapter:
dialect_name = config.get_engine().url.get_dialect().name
adapter_class = ADAPTERS.get(dialect_name)

if not adapter_class:
msg = f"No adapter available for {dialect_name}"
raise ValueError(msg)

driver = config.get_engine().url.get_dialect().driver
if driver not in adapter_class.supported_drivers:
msg = f"{dialect_name} adapter does not support the {driver} driver"
raise ValueError(msg)

return adapter_class(config, encoding=encoding)


async def create_database(config: SQLAlchemySyncConfig | SQLAlchemyAsyncConfig, encoding: str = "utf8") -> None:
adapter = get_adapter(config, encoding)
if isinstance(config.get_engine(), AsyncEngine):
await adapter.create_async()
else:
adapter.create()


async def drop_database(config: SQLAlchemySyncConfig | SQLAlchemyAsyncConfig) -> None:
adapter = get_adapter(config)
if isinstance(config.get_engine(), AsyncEngine):
await adapter.drop_async()
else:
adapter.drop()
159 changes: 159 additions & 0 deletions tests/integration/test_database_commands.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
from __future__ import annotations

from pathlib import Path
from typing import AsyncGenerator, Generator

import pytest
from sqlalchemy import URL, Engine, NullPool, create_engine, select, text
from sqlalchemy.ext.asyncio import AsyncEngine, async_sessionmaker, create_async_engine
from sqlalchemy.orm import sessionmaker

from advanced_alchemy import base
from advanced_alchemy.extensions.litestar import SQLAlchemyAsyncConfig, SQLAlchemySyncConfig
from advanced_alchemy.utils.databases import create_database, drop_database

pytestmark = [
pytest.mark.integration,
]


@pytest.fixture()
def sqlite_engine_cd(tmp_path: Path) -> Generator[Engine, None, None]:
"""SQLite engine for end-to-end testing.

Returns:
Async SQLAlchemy engine instance.
"""
engine = create_engine(f"sqlite:///{tmp_path}/test-cd.db", poolclass=NullPool)
try:
yield engine
finally:
engine.dispose()


@pytest.fixture()
async def aiosqlite_engine_cd(tmp_path: Path) -> AsyncGenerator[AsyncEngine, None]:
"""SQLite engine for end-to-end testing.

Returns:
Async SQLAlchemy engine instance.
"""
engine = create_async_engine(f"sqlite+aiosqlite:///{tmp_path}/test-cd-async.db", poolclass=NullPool)
try:
yield engine
finally:
await engine.dispose()


@pytest.fixture()
async def asyncpg_engine_cd(docker_ip: str, postgres_service: None) -> AsyncGenerator[AsyncEngine, None]:
"""Postgresql instance for end-to-end testing."""
yield create_async_engine(
URL(
drivername="postgresql+asyncpg",
username="postgres",
password="super-secret",
host=docker_ip,
port=5423,
database="testing_create_delete",
query={}, # type:ignore[arg-type]
),
poolclass=NullPool,
)


@pytest.fixture()
def psycopg_engine_cd(docker_ip: str, postgres_service: None) -> Generator[Engine, None, None]:
"""Postgresql instance for end-to-end testing."""
yield create_engine(
URL(
drivername="postgresql+psycopg",
username="postgres",
password="super-secret",
host=docker_ip,
port=5423,
database="postgres",
query={}, # type:ignore[arg-type]
),
poolclass=NullPool,
)


async def test_create_and_drop_sqlite_sync(sqlite_engine_cd: Engine, tmp_path: Path) -> None:
orm_registry = base.create_registry()
cfg = SQLAlchemySyncConfig(
engine_instance=sqlite_engine_cd,
session_maker=sessionmaker(bind=sqlite_engine_cd, expire_on_commit=False),
metadata=orm_registry.metadata,
)
file_path = f"{tmp_path}/test-cd.db"
assert not Path(f"{tmp_path}/test-cd.db").exists()
try:
await create_database(cfg)
assert Path(file_path).exists()
with cfg.get_session() as sess:
result = sess.execute(select(text("1")))
assert result.scalar_one() == 1
await drop_database(cfg)
assert not Path(file_path).exists()
finally:
# always clean up
if Path(file_path).exists():
Path(file_path).unlink()


async def test_create_and_drop_sqlite_async(aiosqlite_engine_cd: AsyncEngine, tmp_path: Path) -> None:
orm_registry = base.create_registry()
cfg = SQLAlchemyAsyncConfig(
engine_instance=aiosqlite_engine_cd,
session_maker=async_sessionmaker(bind=aiosqlite_engine_cd, expire_on_commit=False),
metadata=orm_registry.metadata,
)
file_path = f"{tmp_path}/test-cd-async.db"
assert not Path(file_path).exists()
try:
await create_database(cfg)
assert Path(file_path).exists()
async with cfg.get_session() as sess:
result = await sess.execute(select(text("1")))
assert result.scalar_one() == 1
await drop_database(cfg)
assert not Path(file_path).exists()
finally:
# always clean up
if Path(file_path).exists():
Path(file_path).unlink()


async def test_create_and_drop_postgres_async(asyncpg_engine_cd: AsyncEngine, asyncpg_engine: AsyncEngine) -> None:
orm_registry = base.create_registry()
cfg = SQLAlchemyAsyncConfig(
engine_instance=asyncpg_engine_cd,
session_maker=async_sessionmaker(bind=asyncpg_engine_cd, expire_on_commit=False),
metadata=orm_registry.metadata,
)

dbname = asyncpg_engine_cd.url.database
exists_sql = f"""
select exists(
SELECT datname FROM pg_catalog.pg_database WHERE lower(datname) = lower('{dbname}')
);
"""

# ensure database does not exist
async with asyncpg_engine.begin() as conn:
result = await conn.execute(text(exists_sql))
assert not result.scalar_one()

await create_database(cfg)
async with asyncpg_engine.begin() as conn:
result = await conn.execute(text(exists_sql))
assert result.scalar_one()

await drop_database(cfg)

async with asyncpg_engine.begin() as conn:
result = await conn.execute(text(exists_sql))
assert not result.scalar_one()

await asyncpg_engine.dispose()
Loading
Loading