Commit
Merge pull request #4 from SFTtech/milo/fix-db-migration-tools
fix db migration tools
mikonse authored Dec 27, 2024
2 parents 21ef8d0 + b0c63c9 commit 39c887c
Showing 21 changed files with 526 additions and 896 deletions.
54 changes: 50 additions & 4 deletions .github/workflows/ci.yaml
@@ -31,8 +31,33 @@ jobs:
- name: Lint
run: npx nx run-many --target=lint

test:
runs-on: ubuntu-latest
test-py:
runs-on: ${{ matrix.os }}
strategy:
matrix:
python-version: [ 3.11 ]
os: [ ubuntu-latest ]
postgres-version: [ 15-bookworm ]
services:
postgres:
image: postgres:${{ matrix.postgres-version }}
env:
POSTGRES_PASSWORD: "password"
POSTGRES_USER: "sftkit"
POSTGRES_DB: "sftkit_test"
options: >-
--health-cmd pg_isready
--health-interval 10s
--health-timeout 5s
--health-retries 5
ports:
- 5432:5432
env:
SFTKIT_TEST_DB_USER: "sftkit"
SFTKIT_TEST_DB_HOST: "localhost"
SFTKIT_TEST_DB_PORT: "5432"
SFTKIT_TEST_DB_DBNAME: "sftkit_test"
SFTKIT_TEST_DB_PASSWORD: "password"
steps:
- uses: actions/checkout@v4

@@ -49,13 +74,34 @@ jobs:
- name: Set up Python with PDM
uses: pdm-project/setup-pdm@v3
with:
python-version: "3.11"
python-version: ${{ matrix.python-version }}

- name: Install Python dependencies
run: pdm sync -d

- name: Test
run: npx nx run-many --target=test
run: npx nx run-many --target=test --projects=tag:lang:python

test-js:
runs-on: ubuntu-latest
strategy:
matrix:
node-version: [ 20 ]
steps:
- uses: actions/checkout@v4

- name: Set up Nodejs
uses: actions/setup-node@v4
with:
node-version: ${{ matrix.node-version}}
cache: "npm"
cache-dependency-path: package-lock.json

- name: Install node dependencies
run: npm ci

- name: Test
run: npx nx run-many --target=test --projects=tag:lang:javascript

build:
runs-on: ubuntu-latest
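The new `test-py` job provisions a Postgres service container and exposes it to the test run through the `SFTKIT_TEST_DB_*` environment variables. As a rough illustration (not part of this commit), a test helper could build its connection from those variables with plain asyncpg; `connect_test_db` is a hypothetical name:

```python
import asyncio
import os

import asyncpg


async def connect_test_db() -> asyncpg.Connection:
    """Open a connection to the CI test database using the SFTKIT_TEST_DB_* variables."""
    return await asyncpg.connect(
        host=os.environ.get("SFTKIT_TEST_DB_HOST", "localhost"),
        port=int(os.environ.get("SFTKIT_TEST_DB_PORT", "5432")),
        user=os.environ.get("SFTKIT_TEST_DB_USER", "sftkit"),
        password=os.environ.get("SFTKIT_TEST_DB_PASSWORD", "password"),
        database=os.environ.get("SFTKIT_TEST_DB_DBNAME", "sftkit_test"),
    )


async def main() -> None:
    conn = await connect_test_db()
    try:
        print(await conn.fetchval("select version()"))
    finally:
        await conn.close()


if __name__ == "__main__":
    asyncio.run(main())
```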
977 changes: 179 additions & 798 deletions pdm.lock

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions sftkit/project.json
@@ -1,9 +1,9 @@
{
"name": "sftkit",
"$schema": "../../node_modules/nx/schemas/project-schema.json",
"$schema": "../node_modules/nx/schemas/project-schema.json",
"sourceRoot": "sftkit/sftkit",
"projectType": "library",
"tags": [],
"tags": ["lang:python"],
"targets": {
"typecheck": {
"executor": "nx:run-commands",
12 changes: 6 additions & 6 deletions sftkit/pyproject.toml
@@ -1,4 +1,3 @@

[build-system]
requires = ["pdm-backend"]
build-backend = "pdm.backend"
@@ -17,11 +16,11 @@ classifiers = [
]
requires-python = ">=3.11"
dependencies = [
"fastapi>=0.111.0",
"typer>=0.12.3",
"uvicorn[standard]>=0.22.0",
"asyncpg>=0.29.0",
"pydantic[email]==2.7.4",
"fastapi>=0.115.6",
"typer>=0.15.1",
"uvicorn>=0.34.0",
"asyncpg>=0.30.0",
"pydantic[email]==2.10.4",
]

[project.urls]
@@ -35,6 +34,7 @@ source = ["sftkit"]

[tool.pytest.ini_options]
asyncio_mode = "auto"
asyncio_default_fixture_loop_scope = "session"
minversion = "6.0"
testpaths = ["tests"]

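The pytest configuration gains `asyncio_default_fixture_loop_scope = "session"`, which newer pytest-asyncio releases ask to be set explicitly and which makes async fixtures default to one session-wide event loop. A minimal sketch of what that setup allows, assuming asyncpg and a locally reachable test database; the fixture name, table, and DSN below are illustrative only:

```python
import asyncpg
import pytest


@pytest.fixture(scope="session")
async def database_dsn() -> str:
    # Runs once on the session-scoped event loop (see
    # asyncio_default_fixture_loop_scope = "session") and prepares the test
    # database before any test touches it.  The DSN is illustrative.
    dsn = "postgresql://sftkit:password@localhost:5432/sftkit_test"
    conn = await asyncpg.connect(dsn)
    try:
        await conn.execute("create table if not exists example (id bigserial primary key)")
    finally:
        await conn.close()
    return dsn


async def test_example_table_exists(database_dsn: str):
    # With asyncio_mode = "auto", async tests need no extra decorator.
    conn = await asyncpg.connect(database_dsn)
    try:
        assert await conn.fetchval("select to_regclass('example') is not null")
    finally:
        await conn.close()
```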
88 changes: 34 additions & 54 deletions sftkit/sftkit/database/_migrations.py
@@ -6,14 +6,17 @@

import asyncpg

from sftkit.database import Connection
from sftkit.database.introspection import list_constraints, list_functions, list_triggers, list_views

logger = logging.getLogger(__name__)

MIGRATION_VERSION_RE = re.compile(r"^-- migration: (?P<version>\w+)$")
MIGRATION_REQURES_RE = re.compile(r"^-- requires: (?P<version>\w+)$")
MIGRATION_TABLE = "schema_revision"


async def _run_postgres_code(conn: asyncpg.Connection, code: str, file_name: Path):
async def _run_postgres_code(conn: Connection, code: str, file_name: Path):
if all(line.startswith("--") for line in code.splitlines()):
return
try:
@@ -32,33 +35,23 @@ async def _run_postgres_code(conn: asyncpg.Connection, code: str, file_name: Pat
raise ValueError(f"Syntax or Access error when executing SQL code ({file_name!s}): {message!r}") from exc


async def _drop_all_views(conn: asyncpg.Connection, schema: str):
async def _drop_all_views(conn: Connection, schema: str):
# TODO: we might have to find out the dependency order of the views if drop cascade does not work
result = await conn.fetch(
"select table_name from information_schema.views where table_schema = $1 and table_name !~ '^pg_';",
schema,
)
views = [row["table_name"] for row in result]
views = await list_views(conn, schema)
if len(views) == 0:
return

# we use drop if exists here as the cascade dropping might lead the view to being already dropped
# due to being a dependency of another view
drop_statements = "\n".join([f"drop view if exists {view} cascade;" for view in views])
drop_statements = "\n".join([f"drop view if exists {view.table_name} cascade;" for view in views])
await conn.execute(drop_statements)


async def _drop_all_triggers(conn: asyncpg.Connection, schema: str):
result = await conn.fetch(
"select distinct on (trigger_name, event_object_table) trigger_name, event_object_table "
"from information_schema.triggers where trigger_schema = $1",
schema,
)
async def _drop_all_triggers(conn: Connection, schema: str):
triggers = await list_triggers(conn, schema)
statements = []
for row in result:
trigger_name = row["trigger_name"]
table = row["event_object_table"]
statements.append(f"drop trigger {trigger_name} on {table};")
for trigger in triggers:
statements.append(f'drop trigger "{trigger.trigger_name}" on "{trigger.event_object_table}";')

if len(statements) == 0:
return
@@ -67,27 +60,20 @@ async def _drop_all_triggers(conn: asyncpg.Connection, schema: str):
await conn.execute(drop_statements)


async def _drop_all_functions(conn: asyncpg.Connection, schema: str):
result = await conn.fetch(
"select proname, pg_get_function_identity_arguments(oid) as signature, prokind from pg_proc "
"where pronamespace = $1::regnamespace;",
schema,
)
async def _drop_all_functions(conn: Connection, schema: str):
funcs = await list_functions(conn, schema)
drop_statements = []
for row in result:
kind = row["prokind"].decode("utf-8")
name = row["proname"]
signature = row["signature"]
if kind in ("f", "w"):
for func in funcs:
if func.prokind in ("f", "w"):
drop_type = "function"
elif kind == "a":
elif func.prokind == "a":
drop_type = "aggregate"
elif kind == "p":
elif func.prokind == "p":
drop_type = "procedure"
else:
raise RuntimeError(f'Unknown postgres function type "{kind}"')
raise RuntimeError(f'Unknown postgres function type "{func.prokind}"')

drop_statements.append(f"drop {drop_type} {name}({signature}) cascade;")
drop_statements.append(f'drop {drop_type} "{func.proname}"({func.signature}) cascade;')

if len(drop_statements) == 0:
return
@@ -96,37 +82,31 @@ async def _drop_all_functions(conn: asyncpg.Connection, schema: str):
await conn.execute(drop_code)


async def _drop_all_constraints(conn: asyncpg.Connection, schema: str):
async def _drop_all_constraints(conn: Connection, schema: str):
"""drop all constraints in the given schema which are not unique, primary or foreign key constraints"""
result = await conn.fetch(
"select con.conname as constraint_name, rel.relname as table_name, con.contype as constraint_type "
"from pg_catalog.pg_constraint con "
" join pg_catalog.pg_namespace nsp on nsp.oid = con.connamespace "
" left join pg_catalog.pg_class rel on rel.oid = con.conrelid "
"where nsp.nspname = $1 and con.conname !~ '^pg_' "
" and con.contype != 'p' and con.contype != 'f' and con.contype != 'u';",
schema,
)
constraints = []
for row in result:
constraint_name = row["constraint_name"]
constraint_type = row["constraint_type"].decode("utf-8")
table_name = row["table_name"]
constraints = await list_constraints(conn, schema)
drop_statements = []
for constraint in constraints:
constraint_name = constraint.conname
constraint_type = constraint.contype
table_name = constraint.relname
if constraint_type in ("p", "f", "u"):
continue
if constraint_type == "c":
constraints.append(f"alter table {table_name} drop constraint {constraint_name};")
drop_statements.append(f'alter table "{table_name}" drop constraint "{constraint_name}";')
elif constraint_type == "t":
constraints.append(f"drop constraint trigger {constraint_name};")
drop_statements.append(f"drop constraint trigger {constraint_name};")
else:
raise RuntimeError(f'Unknown constraint type "{constraint_type}" for constraint "{constraint_name}"')

if len(constraints) == 0:
if len(drop_statements) == 0:
return

drop_statements = "\n".join(constraints)
await conn.execute(drop_statements)
drop_cmd = "\n".join(drop_statements)
await conn.execute(drop_cmd)


async def _drop_db_code(conn: asyncpg.Connection, schema: str):
async def _drop_db_code(conn: Connection, schema: str):
await _drop_all_triggers(conn, schema=schema)
await _drop_all_functions(conn, schema=schema)
await _drop_all_views(conn, schema=schema)
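The rewritten drop helpers now take the shared `Connection` type and build their statements from the typed introspection results, quoting identifiers as they go. A rough usage sketch, assuming a `Connection` obtained elsewhere in sftkit (this excerpt does not show a connection factory), and keeping in mind that `_drop_db_code` is an internal helper:

```python
from sftkit.database import Connection
from sftkit.database._migrations import _drop_db_code


async def reset_schema_code(conn: Connection, schema: str = "public") -> None:
    # Drops triggers, functions and views in the schema (the remainder of
    # _drop_db_code is truncated in this diff), so the SQL code files can be
    # re-applied from a clean slate.
    await _drop_db_code(conn, schema=schema)
```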
89 changes: 89 additions & 0 deletions sftkit/sftkit/database/introspection/__init__.py
@@ -0,0 +1,89 @@
from pydantic import BaseModel

from sftkit.database import Connection


class PgFunctionDef(BaseModel):
proname: str
pronamespace: int # oid
proowner: int # oid
prolang: int # oid
procost: float
prorows: int
provariadic: int # oid
prosupport: str
prokind: str
prosecdef: bool
proleakproof: bool
proisstrict: bool
proretset: bool
provolatile: str
proparallel: str
pronargs: int
pronargdefaults: int
prorettype: int # oid
proargtypes: list[int] # oid
proallargtypes: list[int] | None # oid
proargmodes: list[str] | None
proargnames: list[str] | None
# proargdefaults: pg_node_tree | None
protrftypes: list[str] | None
prosrc: str
probin: str | None
# prosqlbody: pg_node_tree | None
proconfig: list[str] | None
proacl: list[str] | None
signature: str


async def list_functions(conn: Connection, schema: str) -> list[PgFunctionDef]:
return await conn.fetch_many(
PgFunctionDef,
"select pg_proc.*, pg_get_function_identity_arguments(oid) as signature from pg_proc "
"where pronamespace = $1::regnamespace and pg_proc.proname !~ '^pg_';",
schema,
)


class PgViewDef(BaseModel):
table_name: str


async def list_views(conn: Connection, schema: str) -> list[PgViewDef]:
return await conn.fetch_many(
PgViewDef,
"select table_name from information_schema.views where table_schema = $1 and table_name !~ '^pg_';",
schema,
)


class PgTriggerDef(BaseModel):
trigger_name: str
event_object_table: str


async def list_triggers(conn: Connection, schema: str) -> list[PgTriggerDef]:
return await conn.fetch_many(
PgTriggerDef,
"select distinct on (trigger_name, event_object_table) trigger_name, event_object_table "
"from information_schema.triggers where trigger_schema = $1",
schema,
)


class PgConstraintDef(BaseModel):
conname: str
relname: str
contype: str


async def list_constraints(conn: Connection, schema: str) -> list[PgConstraintDef]:
return await conn.fetch_many(
PgConstraintDef,
"select con.conname, rel.relname, con.contype "
"from pg_catalog.pg_constraint con "
" join pg_catalog.pg_namespace nsp on nsp.oid = con.connamespace "
" left join pg_catalog.pg_class rel on rel.oid = con.conrelid "
"where nsp.nspname = $1 and con.conname !~ '^pg_';",
schema,
)
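The new introspection helpers return typed pydantic rows instead of raw records, so callers use attribute access rather than `row["column"]` lookups. A short sketch of how they might be combined, again assuming a `Connection` obtained elsewhere in sftkit:

```python
from sftkit.database import Connection
from sftkit.database.introspection import list_functions, list_triggers, list_views


async def print_schema_code(conn: Connection, schema: str = "public") -> None:
    # Each helper issues one catalog query and maps the rows onto the
    # corresponding pydantic model defined above.
    for view in await list_views(conn, schema):
        print(f"view: {view.table_name}")
    for trigger in await list_triggers(conn, schema):
        print(f"trigger: {trigger.trigger_name} on {trigger.event_object_table}")
    for func in await list_functions(conn, schema):
        print(f"{func.prokind}: {func.proname}({func.signature})")
```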
1 change: 1 addition & 0 deletions sftkit/tests/assets/minimal_db/code/constraints.sql
@@ -0,0 +1 @@
alter table "user" add constraint username_allowlist check (name != 'exclusion');
13 changes: 13 additions & 0 deletions sftkit/tests/assets/minimal_db/code/functions.sql
@@ -0,0 +1,13 @@
create or replace function test_func(
arg1 bigint,
arg2 text
) returns boolean as
$$
<<locals>> declare
tmp_var double precision;
begin
tmp_var = arg1 > 0 and arg2 != 'bla';
return tmp_var;
end;
$$ language plpgsql
set search_path = "$user", public;
14 changes: 14 additions & 0 deletions sftkit/tests/assets/minimal_db/code/triggers.sql
@@ -0,0 +1,14 @@
create or replace function user_trigger() returns trigger as
$$
begin
return NEW;
end
$$ language plpgsql
stable
set search_path = "$user", public;

create trigger create_user_trigger
before insert
on "user"
for each row
execute function user_trigger();
6 changes: 6 additions & 0 deletions sftkit/tests/assets/minimal_db/code/views.sql
@@ -0,0 +1,6 @@
create view user_with_post_count as
select u.*, author_counts.count
from "user" as u
join (
select p.author_id, count(*) as count from post as p group by p.author_id
) as author_counts on u.id = author_counts.author_id;
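The view simply joins per-author post counts onto the `user` rows. A quick sketch of reading it from Python, assuming an asyncpg connection to a database where the test schema (a `user` table with `id` and `name`, a `post` table with `author_id`) and this view exist:

```python
import asyncpg


async def users_with_post_counts(conn: asyncpg.Connection) -> dict[str, int]:
    # The view exposes every "user" column plus the per-author post count.
    rows = await conn.fetch("select name, count from user_with_post_count")
    return {row["name"]: row["count"] for row in rows}
```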