fixes #85
AKST committed Feb 10, 2025
1 parent 0e6d765 commit 6f035f9
Showing 7 changed files with 136 additions and 58 deletions.
1 change: 1 addition & 0 deletions lib/pipeline/nsw_vg/property_description/__init__.py
@@ -7,3 +7,4 @@
TelemetryListener as ProcDescTelemetryListener,
)
from .type import ParentMessage as ProcUpdateMessage
from .work_partitioner import WorkPartitioner as PropDescWorkPartitioner
71 changes: 22 additions & 49 deletions lib/pipeline/nsw_vg/property_description/ingest.py
@@ -11,17 +11,13 @@
from lib.service.database import DatabaseService
from lib.pipeline.nsw_lrs.property_description.parse import parse_property_description_data
from lib.pipeline.nsw_lrs.property_description import PropertyDescription
from .type import ParentMessage

@dataclass
class QuantileRange:
start: Optional[int]
end: Optional[int]
from .type import ParentMessage, PartitionSlice
from .work_partitioner import WorkPartitioner

@dataclass
class WorkerProcessConfig:
worker_no: int
quantiles: List[QuantileRange]
quantiles: List[PartitionSlice]

@dataclass
class _WorkerClient:
@@ -30,7 +26,7 @@ class _WorkerClient:
async def join(self: Self) -> None:
await asyncio.to_thread(self.proc.join)

SpawnWorkerFn = Callable[[int, WorkerProcessConfig], Process]
SpawnWorkerFn = Callable[[WorkerProcessConfig], Process]

class PropDescIngestionWorkerPool:
_logger = getLogger(f'{__name__}.Pool')
@@ -45,9 +41,9 @@ def __init__(self: Self,
self._semaphore = semaphore
self._spawn_worker_fn = spawn_worker_fn

def spawn(self: Self, worker_no: int, quantiles: List[QuantileRange]) -> None:
def spawn(self: Self, worker_no: int, quantiles: List[PartitionSlice]) -> None:
worker_conf = WorkerProcessConfig(worker_no=worker_no, quantiles=quantiles)
process = self._spawn_worker_fn(worker_no, worker_conf)
process = self._spawn_worker_fn(worker_conf)
self._pool[worker_no] = _WorkerClient(process)
process.start()

@@ -63,49 +59,26 @@ class PropDescIngestionSupervisor:
_db: DatabaseService
_worker_pool: PropDescIngestionWorkerPool

def __init__(self: Self, db: DatabaseService, pool: PropDescIngestionWorkerPool) -> None:
def __init__(self: Self,
db: DatabaseService,
pool: PropDescIngestionWorkerPool,
partitioner: WorkPartitioner) -> None:
self._db = db
self._worker_pool = pool
self._partitioner = partitioner

async def ingest(self: Self, workers: int, sub_workers: int) -> None:
no_of_quantiles = workers * sub_workers
quantiles = await self._find_table_quantiles(workers, sub_workers)
slices = await self._partitioner.find_partitions()

for q_id, quantile in quantiles.items():
for q_id, partition_slice in slices.items():
self._logger.debug(f"spawning {q_id}")
self._worker_pool.spawn(q_id, quantile)
self._worker_pool.spawn(q_id, partition_slice)

self._logger.debug(f"Awaiting workers")
await self._worker_pool.join_all()
self._logger.debug(f"Done")

async def _find_table_quantiles(self: Self, workers: int, sub_workers: int) -> Dict[int, List[QuantileRange]]:

async with self._db.async_connect() as c, c.cursor() as cursor:
no_of_quantiles = workers * sub_workers
self._logger.info(f"Finding quantiles (count {no_of_quantiles})")
await cursor.execute(f"""
SELECT segment, MIN(property_id), MAX(property_id)
FROM (SELECT property_id, NTILE({no_of_quantiles}) OVER (ORDER BY property_id) AS segment
FROM nsw_lrs.legal_description
WHERE legal_description_kind = '> 2004-08-17'
AND strata_lot_number IS NULL) t
GROUP BY segment
ORDER BY segment
""")
items = await cursor.fetchall()
q_start = [None, *(row[1] for row in items[1:])]
q_end = [*(row[2] for row in items[:-1]), None]

qs = [
QuantileRange(start=q_start[i], end=q_end[i])
for i, row in enumerate(items)
]

return {
i: qs[i * sub_workers:(i + 1) * sub_workers]
for i in range(workers)
}

class PropDescIngestionWorker:
_logger = getLogger(f'{__name__}.Worker')
@@ -120,13 +93,13 @@ def __init__(self: Self,
self._semaphore = semaphore
self._db = db

async def ingest(self: Self, quantiles: List[QuantileRange]) -> None:
async def ingest(self: Self, partitions: List[PartitionSlice]) -> None:
self._logger.debug("Starting sub workers")
tasks = [asyncio.create_task(self.worker(i, q)) for i, q in enumerate(quantiles)]
tasks = [asyncio.create_task(self.worker(i, q)) for i, q in enumerate(partitions)]
await asyncio.gather(*tasks)
self._logger.debug("Finished ingesting")

async def worker(self: Self, worker_id: int, quantile: QuantileRange) -> None:
async def worker(self: Self, worker_id: int, partition: PartitionSlice) -> None:
limit = 100
temp_table_name = f"q_{uuid.uuid4().hex[:8]}"

@@ -140,7 +113,7 @@ def on_ingest_queued(amount: int):

async with self._db.async_connect() as conn, conn.cursor() as cursor:
self._logger.debug(f'creating temp table {temp_table_name}')
await self.create_temp_table(quantile, temp_table_name, cursor)
await self.create_temp_table(partition, temp_table_name, cursor)

await cursor.execute(f"SELECT count(*) FROM pg_temp.{temp_table_name}")
count = (await cursor.fetchone())[0]
@@ -254,7 +227,7 @@ async def ingest_page(self: Self,


async def create_temp_table(self: Self,
q: QuantileRange,
p: PartitionSlice,
temp_table_name: str,
cursor) -> None:
loop = asyncio.get_running_loop()
@@ -268,12 +241,12 @@ async def create_temp_table(self: Self,
legal_description_id,
property_id,
effective_date
FROM nsw_lrs.legal_description
FROM {p.src_table_name}
LEFT JOIN meta.source_byte_position USING (source_id)
LEFT JOIN meta.file_source USING (file_source_id)
WHERE legal_description_kind = '> 2004-08-17'
{f"AND property_id >= {q.start}" if q.start else ''}
{f"AND property_id < {q.end}" if q.end else ''}
{f"AND property_id >= {p.start}" if p.start else ''}
{f"AND property_id < {p.end}" if p.end else ''}
AND strata_lot_number IS NULL;
""")
self._semaphore.release()
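For readers skimming the hunk above: the worker's temp-table predicate is now driven by a PartitionSlice rather than a QuantileRange, and the half-open [start, end) bounds are emitted only when present, so the first and last slices stay unbounded. A minimal sketch of that composition, with an invented table name and invented id values:

from dataclasses import dataclass
from typing import Optional

@dataclass
class PartitionSlice:
    src_table_name: str
    start: Optional[int]
    end: Optional[int]
    count: int

def where_clause(p: PartitionSlice) -> str:
    # Mirrors the f-string SQL in create_temp_table above: bounds are
    # appended only when truthy, giving half-open [start, end) slices.
    parts = ["legal_description_kind = '> 2004-08-17'"]
    if p.start:
        parts.append(f"property_id >= {p.start}")
    if p.end:
        parts.append(f"property_id < {p.end}")
    parts.append("strata_lot_number IS NULL")
    return " AND ".join(parts)

# Invented slice: the lowest segment of a hypothetical partition table.
print(where_clause(PartitionSlice("nsw_lrs.legal_description_2021", None, 5000, 4200)))
# legal_description_kind = '> 2004-08-17' AND property_id < 5000 AND strata_lot_number IS NULL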
6 changes: 3 additions & 3 deletions lib/pipeline/nsw_vg/property_description/telemetry.py
@@ -6,7 +6,7 @@
from typing import Self, Optional
import queue

from lib.service.clock import ClockService
from lib.service.clock import AbstractClockService
from lib.utility.format import fmt_time_elapsed

from .type import ParentMessage
@@ -43,14 +43,14 @@ class Telemetry:

def __init__(self: Self,
start_time: float,
clock: ClockService,
clock: AbstractClockService,
processes: dict[int, Process]) -> None:
self.start_time = start_time
self._clock = clock
self._processes = processes

@staticmethod
def create(clock: ClockService, workers: int, subworkers: int) -> 'Telemetry':
def create(clock: AbstractClockService, workers: int, subworkers: int) -> 'Telemetry':
procs = {
j: Process({
i: WorkerState(0, 0)
13 changes: 13 additions & 0 deletions lib/pipeline/nsw_vg/property_description/type.py
@@ -1,5 +1,18 @@
from abc import ABC
from dataclasses import dataclass
from typing import Optional

@dataclass
class PartitionSlice:
src_table_name: str
start: Optional[int]
end: Optional[int]
count: int

@dataclass
class QuantileRange:
start: Optional[int]
end: Optional[int]

class ParentMessage:
@dataclass
89 changes: 89 additions & 0 deletions lib/pipeline/nsw_vg/property_description/work_partitioner.py
@@ -0,0 +1,89 @@
import asyncio
import heapq
import logging
from typing import Self, Optional

from lib.service.database import *
from .type import PartitionSlice

class WorkPartitioner:
"""
The goal of this class is to break the work up into
efficient portions. This is done by ensuring inputs do
not span partitions.
"""
_logger = logging.getLogger(__name__)
_db_partitions: Optional[int]

def __init__(self: Self,
db: DatabaseService,
workers: int,
subdivide: int):
self._db = db
self._workers = workers
self._subdivide = subdivide

async def find_partitions(self: Self) -> dict[int, list[PartitionSlice]]:
async def _partition_count(c: DbCursorLike, name: str) -> tuple[str, int]:
await c.execute(f"""
SELECT count(*)::bigint AS row_count
FROM {name}
WHERE legal_description_kind = '> 2004-08-17'
AND strata_lot_number IS NULL
""")
return name, (await c.fetchone())[0]

async def _partition_ntiles(c: DbCursorLike, name: str, segs: int) -> list[PartitionSlice]:
await c.execute(f"""
SELECT segment, MIN(property_id), MAX(property_id), COUNT(*)
FROM (SELECT property_id,
NTILE({segs}) OVER (ORDER BY property_id) AS segment
FROM {name}
WHERE legal_description_kind = '> 2004-08-17'
AND strata_lot_number IS NULL) t
GROUP BY segment
ORDER BY segment
""")
return [
PartitionSlice(name, mn, mx, count)
for segment, mn, mx, count in await c.fetchall()
]

def distribute(items: list[PartitionSlice], n: int) -> list[list[PartitionSlice]]:
bins: list[tuple[float, int, list[PartitionSlice]]] = [(0, i, []) for i in range(n)]
heapq.heapify(bins)

for item in sorted(items, key=lambda x: x.count, reverse=True):
current_sum, bin_index, assigned_list = heapq.heappop(bins)
assigned_list.append(item)
heapq.heappush(bins, (current_sum + item.count, bin_index, assigned_list))

return [b[2] for b in sorted(bins, key=lambda b: b[1])]

async with self._db.async_connect() as conn, conn.cursor() as c:
await c.execute("""
SELECT inhrelid::regclass AS partition_name FROM pg_inherits
WHERE inhparent = 'nsw_lrs.legal_description'::regclass;
""")
name_w_count = await asyncio.gather(*[
_partition_count(c, row[0]) for row in await c.fetchall()
])
total_count = sum([count for _, count in name_w_count])
total_segments = self._subdivide * self._workers
segments_per_partition = [
(name, count, max(1, round(count / total_count * total_segments)))
for name, count in name_w_count
if count > 0
]
slices: list[PartitionSlice] = [s
for name, count, segments in segments_per_partition
for s in await _partition_ntiles(c, name, segments)
]

return { i: ps for i, ps in enumerate(distribute(slices, self._workers)) }






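The distribute helper above is a greedy longest-processing-time scheduler: slices are taken largest row-count first, and each is pushed onto the worker bin with the smallest running total (the bin index breaks ties so the heap never compares the lists). A self-contained sketch with invented counts, using a stand-in for PartitionSlice:

import heapq
from dataclasses import dataclass

@dataclass
class Slice:  # stand-in for PartitionSlice; only `count` matters here
    name: str
    count: int

def distribute(items: list[Slice], n: int) -> list[list[Slice]]:
    # (running_total, bin_index, assigned) — the index breaks ties so
    # the heap never has to compare the lists in the third slot.
    bins: list[tuple[int, int, list[Slice]]] = [(0, i, []) for i in range(n)]
    heapq.heapify(bins)
    for item in sorted(items, key=lambda s: s.count, reverse=True):
        total, idx, assigned = heapq.heappop(bins)
        assigned.append(item)
        heapq.heappush(bins, (total + item.count, idx, assigned))
    return [b[2] for b in sorted(bins, key=lambda b: b[1])]

slices = [Slice("p2021", 900), Slice("p2022", 400),
          Slice("p2023", 350), Slice("p2024", 300)]
print([(sum(s.count for s in w), [s.name for s in w])
       for w in distribute(slices, 2)])
# [(900, ['p2021']), (1050, ['p2022', 'p2023', 'p2024'])]

With two workers the large 900-row slice lands alone while the three smaller slices share the other bin, which is the balance the supervisor wants before spawning one process per key.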
3 changes: 2 additions & 1 deletion lib/tasks/nsw_vg/ingest.py
@@ -42,7 +42,7 @@ async def ingest_nswvg(

if config.property_descriptions:
try:
await ingest_property_description(db, config.property_descriptions)
await ingest_property_description(db, clock, config.property_descriptions)
except Exception as e:
_logger.exception(e)
_logger.error('failed to ingest all property descriptions')
@@ -176,6 +176,7 @@ async def ingest_nswvg(
worker_debug=args.nswlrs_propdesc_child_debug,
workers=args.nswlrs_propdesc_workers,
sub_workers=args.nswlrs_propdesc_subworkers,
truncate_earlier=False,
)

config = NswVgTaskConfig.Ingestion(
11 changes: 6 additions & 5 deletions lib/tasks/nsw_vg/ingest_property_descriptions.py
@@ -11,6 +11,7 @@
PropDescIngestionWorkerPool,
ProcDescTelemetry,
ProcDescTelemetryListener,
PropDescWorkPartitioner,
ProcUpdateMessage,
WorkerProcessConfig,
)
@@ -37,6 +38,7 @@ async def ingest_property_description(
semaphore = MpSemaphore(1)
telemetry = ProcDescTelemetry.create(clock, config.workers, config.sub_workers)
queue, telemetry_listener = ProcDescTelemetryListener.create(telemetry)
partitioner = PropDescWorkPartitioner(db, config.workers, config.sub_workers)

try:
if config.truncate_earlier:
@@ -48,13 +50,13 @@ async def ingest_property_description(
await conn.execute(query)

spawn_worker_with_worker_config = \
lambda process_id, w_config: Process(target=spawn_worker, args=(
process_id, queue, w_config, semaphore,
lambda w_config: Process(target=spawn_worker, args=(
queue, w_config, semaphore,
config.worker_debug, config.db_config))

telemetry_listener.listen()
pool = PropDescIngestionWorkerPool(semaphore, spawn_worker_with_worker_config)
parent = PropDescIngestionSupervisor(db, pool)
parent = PropDescIngestionSupervisor(db, pool, partitioner)
await parent.ingest(config.workers, config.sub_workers)
telemetry_listener.stop()
except Exception as e:
@@ -63,7 +65,6 @@
raise e

def spawn_worker(
process_id: int,
queue: MpQueue,
config: WorkerProcessConfig,
semaphore: SemaphoreT,
@@ -74,7 +75,7 @@ async def worker_runtime(config: WorkerProcessConfig, semaphore: SemaphoreT, db_
config_vendor_logging({'sqlglot', 'psycopg.pool'})
config_logging(config.worker_no, worker_debug)
db = DatabaseServiceImpl.create(db_config, len(config.quantiles))
worker = PropDescIngestionWorker(process_id, queue, semaphore, db)
worker = PropDescIngestionWorker(config.worker_no, queue, semaphore, db)
await worker.ingest(config.quantiles)
asyncio.run(worker_runtime(config, semaphore, db_config))

