make ingest_gis configurable, update db mock, refactor sql tool
Ingest GIS Configurable
- This is just so I can disable GIS for instance 4

Update DB Mock & DbService Interface
- Looks like I forgot to add copy & abort to the db mock's cursor and
  connection types
- I actually moved ConnectionLike::info out of the interface and added
  a method to DatabaseService that does that logic without exposing
  psycopg's API to users of the service (see the sketch below)
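  For clarity, a minimal sketch of the new split; the protocol shapes
  mirror the diff below, while the concrete class name
  PsycopgDatabaseService is made up:

    from typing import Protocol, Self

    import psycopg


    class ConnectionLike(Protocol):
        async def __aexit__(self: Self, *args, **kwargs): ...


    class DatabaseService(Protocol):
        def is_idle(self: Self, conn: ConnectionLike) -> bool: ...


    class PsycopgDatabaseService:
        def is_idle(self: Self, conn: ConnectionLike) -> bool:
            # Only the service knows the connection is really psycopg's;
            # callers never touch conn.info directly anymore.
            if isinstance(conn, psycopg.AsyncConnection):
                return conn.info.transaction_status == psycopg.pq.TransactionStatus.IDLE
            raise TypeError('unknown connection kind')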

Refactor SQL Tool
- I moved the file reading logic into a new class called FileDiscovery
  - This stuff was moved out of SchemaReader (see the usage sketch below)
- Added a comment for the transaction-unsafe stuff; I feel like maybe I
  should lift this out or something, but I'm not sure yet
- I deleted a test, but the truth is I just moved it under a new name;
  I'm not ready to commit its contents yet
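
  A hedged usage sketch of the new class; the 'nsw_vg' namespace and the
  IoService construction are assumptions, not committed code:

    import asyncio

    from lib.service.io import IoService
    from lib.tooling.schema.file_discovery import FileDiscovery, create_file_regex


    async def main() -> None:
        io = IoService()  # assumption: the real factory may take arguments
        discovery = FileDiscovery(io, create_file_regex('sql'), 'sql')
        # Yields (path, match) pairs for every NNN_APPLY*.sql under sql/<ns>/schema
        for path, match in await discovery.ns_matches('nsw_vg'):
            print(path, match.ns, match.step, match.name)

    asyncio.run(main())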
AKST committed Feb 8, 2025
1 parent 9acad19 commit fd395a9
Showing 11 changed files with 131 additions and 108 deletions.
1 change: 1 addition & 0 deletions lib/defaults/config.py
@@ -18,6 +18,7 @@ class InstanceCfg:
docker_volume: str
docker_container: ContainerConfig
docker_image: ImageConfig
enable_gis: bool
enable_gnaf: bool
gnaf_states: Set[GnafState]
nswvg_lv_discovery_mode: NswVgLvCsvDiscoveryMode.T
4 changes: 4 additions & 0 deletions lib/defaults/defaults.py
@@ -71,6 +71,7 @@ def _create_mounted_dirs(
nswvg_psi_min_pub_year=None,
gnaf_states=ALL_STATES,
enable_gnaf=True,
enable_gis=True,
database=DatabaseConfig(
dbname=_db_name_1,
user='postgres',
@@ -106,6 +107,7 @@ def _create_mounted_dirs(
nswvg_psi_min_pub_year=2024,
gnaf_states={'NSW'},
enable_gnaf=True,
enable_gis=True,
database=DatabaseConfig(
dbname=_db_name_2,
user='postgres',
@@ -143,6 +145,7 @@ def _create_mounted_dirs(
nswvg_psi_min_pub_year=2024,
gnaf_states={'NSW'},
enable_gnaf=True,
enable_gis=True,
database=DatabaseConfig(
dbname=_db_name_3,
user='postgres',
@@ -179,6 +182,7 @@ def _create_mounted_dirs(
nswvg_psi_min_pub_year=2024,
gnaf_states={'NSW'},
enable_gnaf=True,
enable_gis=False,
database=DatabaseConfig(
dbname=_db_name_4,
user='postgres',
20 changes: 15 additions & 5 deletions lib/service/database/mock.py
@@ -2,7 +2,7 @@
import re
from typing import Any, Self, Protocol, Sequence

from .type import DatabaseService, CursorLike, ConnectionLike
from .type import DatabaseService, CursorLike, ConnectionLike, CopyLike

def clean_sql(sql: str) -> str:
return re.sub(r'(\s|\n)+', ' ', sql).strip()
@@ -42,15 +42,16 @@ async def fetchall(self: Self) -> list[list[Any]]:
self.state.fetchall_i += 1
return result

async def abort(self: Self):
return

def copy(self: Self, statement: str, params: list[str] | None = None) -> CopyLike:
raise NotImplementedError('idk I didn\'t get around to it')

@dataclass
class MockConnection(ConnectionLike):
state: MockDbState

@property
def info(self: Self) -> Any:
raise Exception()

async def __aexit__(self: Self, *args, **kwargs):
return

@@ -71,6 +72,12 @@ async def execute(self: Self, sql: str, args: Sequence[Any] = []) -> CursorLike:
self.state.execute_args.append((sql, args))
return MockCursor(self.state)

async def abort(self: Self):
return

def copy(self: Self, statement: str, params: list[str] | None = None) -> CopyLike:
raise NotImplementedError('idk I didn\'t get around to it')

@dataclass
class MockDatabaseService(DatabaseService):
state: MockDbState = field(default_factory=lambda: MockDbState())
@@ -86,3 +93,6 @@ async def wait_till_running(self):

def async_connect(self) -> ConnectionLike:
return MockConnection(self.state)

def is_idle(self: Self, conn: ConnectionLike) -> bool:
return True
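
For reference, a hedged sketch of driving the mock in a test; it assumes
MockDbState initialises execute_args to an empty list, which the
collapsed portion of the file suggests:

    import asyncio

    from lib.service.database.mock import MockDatabaseService


    async def demo() -> None:
        db = MockDatabaseService()
        conn = db.async_connect()
        await conn.execute('SELECT 1', [])
        # The mock records every statement it sees, so tests can assert on it.
        assert db.state.execute_args == [('SELECT 1', [])]
        assert db.is_idle(conn)  # the mock always reports idle

    asyncio.run(demo())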
5 changes: 5 additions & 0 deletions lib/service/database/service.py
@@ -66,3 +66,8 @@ async def wait_till_running(self: Self, interval=5, timeout=60):
raise e
await asyncio.sleep(interval)

def is_idle(self: Self, conn: ConnectionLike) -> bool:
if isinstance(conn, psycopg.AsyncConnection):
return conn.info.transaction_status == psycopg.pq.TransactionStatus.IDLE
raise TypeError('unknown connection kind')

7 changes: 3 additions & 4 deletions lib/service/database/type.py
@@ -54,10 +54,6 @@ def copy(self: Self, statement: str, params: list[str] | None = None) -> CopyLike:
...

class ConnectionLike(Protocol):
@property
def info(self: Self) -> Any:
...

async def __aexit__(self: Self, *args, **kwargs):
...

@@ -91,3 +87,6 @@ async def wait_till_running(self: Self, interval: int = 5, timeout: int = 60):
def async_connect(self: Self) -> ConnectionLike:
...

def is_idle(self: Self, conn: ConnectionLike) -> bool:
...

43 changes: 23 additions & 20 deletions lib/tasks/ingest.py
@@ -40,6 +40,7 @@ class IngestConfig:
docker_volume: str
docker_image_config: ImageConfig
docker_container_config: ContainerConfig
enable_gis: bool
enable_gnaf: bool
enable_clean_staging_data: bool

@@ -173,27 +174,28 @@ async def ingest_all(config: IngestConfig):
io_service,
)

await ingest_gis(
io_service,
db_service,
uuid,
clock,
GisTaskConfig.Ingestion(
deduplication=GisTaskConfig.Deduplication(
run_from=None,
run_till=None,
truncate=False,
),
staging=GisTaskConfig.StageApiData(
db_workers=config.db_connections,
db_mode='write',
gis_params=[],
exp_backoff_attempts=8,
disable_cache=False,
projections=GisTaskConfig.projection_kinds,
),
if config.enable_gis:
await ingest_gis(
io_service,
db_service,
uuid,
clock,
GisTaskConfig.Ingestion(
deduplication=GisTaskConfig.Deduplication(
run_from=None,
run_till=None,
truncate=False,
),
staging=GisTaskConfig.StageApiData(
db_workers=config.db_connections,
db_mode='write',
gis_params=[],
exp_backoff_attempts=8,
disable_cache=False,
projections=GisTaskConfig.projection_kinds,
),
)
)
)

await run_count_for_schemas(db_service_config, ns_dependency_order)

@@ -238,6 +240,7 @@ async def ingest_all(config: IngestConfig):
docker_container_config=instance_cfg.docker_container,
gnaf_states=instance_cfg.gnaf_states,
enable_gnaf=instance_cfg.enable_gnaf,
enable_gis=instance_cfg.enable_gis,
enable_clean_staging_data=instance_cfg.clean_staging_data,
)

10 changes: 8 additions & 2 deletions lib/tooling/schema/controller.py
@@ -51,11 +51,17 @@ async def create(self: Self, command: Command, t: Transform.Create) -> None:

async with self._db.async_connect() as conn, conn.cursor() as cursor:
for file in file_list:
# TODO look into this
"""
WHY DOES THIS EXIST?
In practice this isn't actually used! It's to handle instances
of concurrently creating indexes, which really shouldn't appear
in our schemas.
"""
if file.is_known_to_be_transaction_unsafe:
await conn.commit()
await conn.set_autocommit(True)
elif conn.info.transaction_status == psycopg.pq.TransactionStatus.IDLE:
elif self._db.is_idle(conn):
await conn.set_autocommit(False)

if t.run_raw_schema:
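For context on the transaction-unsafe branch above, a hedged sketch of
the failure it works around; the connection string, table, and index
names are made up:

    import asyncio

    import psycopg


    async def main() -> None:
        async with await psycopg.AsyncConnection.connect('dbname=example') as conn:
            # Without this, psycopg opens a transaction implicitly and postgres
            # raises ActiveSqlTransaction: CREATE INDEX CONCURRENTLY cannot run
            # inside a transaction block.
            await conn.set_autocommit(True)
            async with conn.cursor() as cur:
                await cur.execute('CREATE INDEX CONCURRENTLY idx_foo ON foo (bar)')

    asyncio.run(main())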
12 changes: 11 additions & 1 deletion lib/tooling/schema/create.py
@@ -1,14 +1,24 @@
from os.path import basename
import re
from lib.service.database import DatabaseService
from lib.service.io import IoService

from .controller import SchemaController
from .file_discovery import FileDiscovery
from .reader import SchemaReader

_ROOT_DIR = './sql'

def create_file_regex(root_dir: str) -> re.Pattern:
path_root = re.escape(root_dir)
path_ns = r'(?P<ns>[_a-zA-Z][_a-zA-Z0-9]*)'
path_file = r'(?P<step>\d{3})_APPLY(_(?P<name>[_a-zA-Z][_a-zA-Z0-9]*))?.sql'
return re.compile(rf'^{path_root}/{path_ns}/schema/{path_file}$')

def create(io: IoService, db: DatabaseService) -> SchemaController:
reader = SchemaReader.create(io, basename(_ROOT_DIR))
pattern = create_file_regex(basename(_ROOT_DIR))
file_discovery = FileDiscovery(io, pattern, basename(_ROOT_DIR))
reader = SchemaReader(file_discovery, io)
controller = SchemaController(io, db, reader)
return controller

52 changes: 52 additions & 0 deletions lib/tooling/schema/file_discovery.py
@@ -0,0 +1,52 @@
from dataclasses import dataclass
from logging import getLogger
import re
from typing import cast, Self, Optional

from lib.service.io import IoService

from .config import schema_ns
from .type import SchemaNamespace

def create_file_regex(root_dir: str) -> re.Pattern:
path_root = re.escape(root_dir)
path_ns = r'(?P<ns>[_a-zA-Z][_a-zA-Z0-9]*)'
path_file = r'(?P<step>\d{3})_APPLY(_(?P<name>[_a-zA-Z][_a-zA-Z0-9]*))?.sql'
return re.compile(rf'^{path_root}/{path_ns}/schema/{path_file}$')

@dataclass
class FileDiscoveryMatch:
ns: SchemaNamespace
step: int
name: str


class FileDiscovery:
logger = getLogger(__name__)

def __init__(self: Self, io: IoService, file_regex: re.Pattern, root_dir: str):
self._io = io
self.file_regex = file_regex
self.root_dir = root_dir

async def ns_matches(self: Self, ns: SchemaNamespace) -> list[tuple[str, FileDiscoveryMatch]]:
return [(f, self.match_file(f)) for f in await self.ns_sql_files(ns)]

async def ns_sql_files(self: Self, ns: SchemaNamespace) -> list[str]:
glob_s = '*_APPLY*.sql'
root_d = f'{self.root_dir}/{ns}/schema'
return [f async for f in self._io.grep_dir(root_d, glob_s)]

def match_file(self: Self, file: str) -> FileDiscoveryMatch:
match self.file_regex.match(file):
case None:
raise ValueError(f'invalid file {file}')
case match:
ns_str = match.group('ns')
step = int(match.group('step'))
name = match.group('name')
if ns_str not in schema_ns:
raise TypeError(f'unknown namespace {ns_str}')
ns: SchemaNamespace = cast(SchemaNamespace, ns_str)
return FileDiscoveryMatch(ns, step, name)
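
A quick, hedged check of what the pattern accepts; the paths and the
'nsw_vg' namespace are made up:

    import re


    def create_file_regex(root_dir: str) -> re.Pattern:
        path_root = re.escape(root_dir)
        path_ns = r'(?P<ns>[_a-zA-Z][_a-zA-Z0-9]*)'
        path_file = r'(?P<step>\d{3})_APPLY(_(?P<name>[_a-zA-Z][_a-zA-Z0-9]*))?.sql'
        return re.compile(rf'^{path_root}/{path_ns}/schema/{path_file}$')


    pattern = create_file_regex('sql')
    m = pattern.match('sql/nsw_vg/schema/001_APPLY_land_values.sql')
    assert m and m.group('ns') == 'nsw_vg' and m.group('step') == '001'
    assert pattern.match('sql/nsw_vg/schema/01_APPLY.sql') is None  # step is three digits
    # Note the '.' before 'sql' is unescaped, so the pattern is a touch looser
    # than intended ('sql/nsw_vg/schema/001_APPLYxsql' would also match).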

54 changes: 9 additions & 45 deletions lib/tooling/schema/reader.py
@@ -10,14 +10,14 @@
Optional,
Self,
Type,
Tuple,
)

from lib.service.io import IoService
from lib.utility.concurrent import fmap
from lib.utility.iteration import partition

from .config import schema_ns
from .file_discovery import FileDiscovery, FileDiscoveryMatch
from .type import (
AlterTableAction,
Stmt,
@@ -28,41 +28,22 @@
Ref,
)

@dataclass
class _FileMeta:
ns: SchemaNamespace
step: int
name: Optional[str]

class SchemaReader:
logger = getLogger(f'{__name__}.SchemaReader')
file_regex: re.Pattern
root_dir: str
_io: IoService

def __init__(self: Self,
root_dir: str,
file_regex: re.Pattern[str],
file_discovery: FileDiscovery,
io: IoService) -> None:
self.root_dir = root_dir
self.file_regex = file_regex
self.file_discovery = file_discovery
self._io = io

@staticmethod
def create(io: IoService, root_dir: str) -> 'SchemaReader':
path_root = re.escape(root_dir)
path_ns = r'(?P<ns>[_a-zA-Z][_a-zA-Z0-9]*)'
path_file = r'(?P<step>\d{3})_APPLY(_(?P<name>[_a-zA-Z][_a-zA-Z0-9]*))?.sql'
pattern = re.compile(rf'^{path_root}/{path_ns}/schema/{path_file}$')
return SchemaReader(root_dir, pattern, io)

async def files(
self: Self,
name: SchemaNamespace,
maybe_range: Optional[range] = None,
load_syn=False,
) -> list[SqlFileMetaData]:
metas = [(f, self.__f_meta_data(f)) for f in await self.__ns_sql(name)]
metas = await self.file_discovery.ns_matches(name)

return sorted([
await self.__f_sql_meta_data(f, meta, load_syn)
@@ -80,35 +61,18 @@ async def all_files(
namespace: sorted([
await self.__f_sql_meta_data(f, meta, load_syn)
for f, meta in [
(f, self.__f_meta_data(f))
for f in await self.__ns_sql(namespace)
(f, self.file_discovery.match_file(f))
for f in await self.file_discovery.ns_sql_files(namespace)
]
for f in await self.__ns_sql(namespace)
for f in await self.file_discovery.ns_sql_files(namespace)
], key=lambda it: it.step)
for namespace in (names or schema_ns)
}

async def __ns_sql(self: Self, ns: SchemaNamespace) -> list[str]:
glob_s = '*_APPLY*.sql'
root_d = f'{self.root_dir}/{ns}/schema'
return [f async for f in self._io.grep_dir(root_d, glob_s)]

def __f_meta_data(self: Self, f: str) -> _FileMeta:
match self.file_regex.match(f):
case None: raise ValueError(f'invalid file {f}')
case match:
ns_str = match.group('ns')
step = int(match.group('step'))
name = match.group('name')
if ns_str not in schema_ns:
raise TypeError(f'unknown namespace {ns_str}')
ns: SchemaNamespace = cast(SchemaNamespace, ns_str)
return _FileMeta(ns, step, name)

async def __f_sql_meta_data(self: Self, f: str, meta: _FileMeta, load_syn: bool) -> SqlFileMetaData:
async def __f_sql_meta_data(self: Self, f: str, meta: FileDiscoveryMatch, load_syn: bool) -> SqlFileMetaData:
try:
contents = await fmap(sql_as_operations, self._io.f_read(f)) if load_syn else None
return SqlFileMetaData(f, self.root_dir, meta.ns, meta.step, meta.name, contents)
return SqlFileMetaData(f, self.file_discovery.root_dir, meta.ns, meta.step, meta.name, contents)
except Exception as e:
self.logger.error(f'failed on {f}')
raise e