Skip to content

Commit

Permalink
Add database and celery task tests
Browse files Browse the repository at this point in the history
  • Loading branch information
= committed Mar 12, 2024
1 parent 763fe4f commit 8a83209
Show file tree
Hide file tree
Showing 30 changed files with 788 additions and 236 deletions.
17 changes: 17 additions & 0 deletions conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
"""Root pytest configuration: shared fixtures and Celery test setup."""
import pytest

from tests.fixtures import *
from tests.test_data.baselines import *
from nad_ch.config import QUEUE_BACKEND_URL, QUEUE_BROKER_URL


# Enable the official Celery pytest plugin (provides celery_app/celery_worker
# fixtures that consume the celery_config fixture below).
pytest_plugins = ("celery.contrib.pytest",)


@pytest.fixture(scope="session")
def celery_config():
    """Session-wide Celery settings pointing the test app at the configured
    broker/backend, with connection retries so slow broker startup does not
    fail the suite immediately."""
    return {
        "broker_url": QUEUE_BROKER_URL,
        "result_backend": QUEUE_BACKEND_URL,
        "broker_connection_retry": True,
        "broker_connection_retry_delay": 5,  # seconds between retries
        "broker_connection_retry_max": 3,  # give up after three attempts
        "broker_connection_retry_on_startup": True,
    }
9 changes: 7 additions & 2 deletions nad_ch/application/interfaces.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from typing import Optional, Protocol
from typing import Optional, Protocol, Dict
from nad_ch.application.dtos import DownloadResult
from nad_ch.domain.repositories import (
DataProducerRepository,
DataSubmissionRepository,
UserRepository,
ColumnMapRepository,
)


Expand Down Expand Up @@ -38,7 +39,7 @@ def run_load_and_validate(
submissions: DataSubmissionRepository,
submission_id: int,
path: str,
config_name: str,
column_map: Dict[str, str],
):
...

Expand Down Expand Up @@ -78,6 +79,10 @@ def submissions(self) -> DataSubmissionRepository:
def users(self) -> UserRepository:
return self._users

@property
def column_maps(self) -> ColumnMapRepository:
return self._column_maps

@property
def logger(self) -> Logger:
return self._logger
Expand Down
12 changes: 10 additions & 2 deletions nad_ch/application/use_cases/data_submissions.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,9 @@ def list_data_submissions_by_producer(
return get_view_model(submissions)


def validate_data_submission(ctx: ApplicationContext, filename: str, config_name: str):
def validate_data_submission(
ctx: ApplicationContext, filename: str, column_map_name: str
):
submission = ctx.submissions.get_by_filename(filename)
if not submission:
ctx.logger.error("Data submission with that filename does not exist")
Expand All @@ -81,8 +83,14 @@ def validate_data_submission(ctx: ApplicationContext, filename: str, config_name
ctx.logger.error("Data extration error")
return

# Using version 1 for column maps for now, may add feature for user to select
# version later
column_map = ctx.column_maps.get_by_name_and_version(column_map_name, 1)
report = ctx.task_queue.run_load_and_validate(
ctx.submissions, submission.id, download_result.extracted_dir, config_name
ctx.submissions,
submission.id,
download_result.extracted_dir,
column_map.mapping,
)

ctx.logger.info(f"Total number of features: {report.overview.feature_count}")
Expand Down
6 changes: 3 additions & 3 deletions nad_ch/application/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,14 @@ def get_features_flagged(features: Dict[str, DataSubmissionReportFeature]) -> in


def initialize_overview_details(
    gdf: GeoDataFrame, column_map: Dict[str, str]
) -> Tuple[DataSubmissionReportOverview, Dict[str, DataSubmissionReportFeature]]:
    """Build the initial report overview and per-feature report entries.

    Args:
        gdf: The loaded submission data; only its feature count is read here.
        column_map: Mapping of provider column names to NAD column names.

    Returns:
        A tuple of (overview seeded with the feature count, dict keyed by NAD
        feature name with one report entry per mapped column).
    """
    report_overview = DataSubmissionReportOverview(feature_count=get_feature_count(gdf))
    # One report feature per mapped column, keyed by its NAD name.
    report_features = {
        nad_name: DataSubmissionReportFeature(
            provided_feature_name=provided_name, nad_feature_name=nad_name
        )
        for provided_name, nad_name in column_map.items()
    }
    return report_overview, report_features

Expand Down Expand Up @@ -60,8 +60,8 @@ def update_overview_details(


def finalize_overview_details(
features: Dict[str, DataSubmissionReportFeature],
overview: DataSubmissionReportOverview,
features: Dict[str, DataSubmissionReportFeature],
) -> DataSubmissionReportOverview:
overview.features_flagged += get_features_flagged(features)
# TODO: Add logic for etl_update_required & data_update_required
Expand Down
5 changes: 5 additions & 0 deletions nad_ch/config/development_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
SqlAlchemyDataProducerRepository,
SqlAlchemyDataSubmissionRepository,
SqlAlchemyUserRepository,
SqlAlchemyColumnMapRepository,
)
from nad_ch.infrastructure.auth import AuthenticationImplementation
from nad_ch.infrastructure.logger import BasicLogger
Expand Down Expand Up @@ -46,6 +47,7 @@ def __init__(self):
self._producers = self.create_producer_repository()
self._submissions = self.create_submission_repository()
self._users = self.create_user_repository()
self._column_maps = self.create_column_map_repository()
self._logger = self.create_logger()
self._storage = self.create_storage()
self._task_queue = self.create_task_queue()
Expand All @@ -60,6 +62,9 @@ def create_submission_repository(self):
def create_user_repository(self):
return SqlAlchemyUserRepository(self._session_factory)

def create_column_map_repository(self):
return SqlAlchemyColumnMapRepository(self._session_factory)

def create_logger(self):
return BasicLogger(__name__, logging.DEBUG)

Expand Down
5 changes: 5 additions & 0 deletions nad_ch/config/development_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
SqlAlchemyDataProducerRepository,
SqlAlchemyDataSubmissionRepository,
SqlAlchemyUserRepository,
SqlAlchemyColumnMapRepository,
)
from nad_ch.infrastructure.auth import AuthenticationImplementation
from nad_ch.infrastructure.logger import BasicLogger
Expand Down Expand Up @@ -45,6 +46,7 @@ def __init__(self):
self._producers = self.create_producer_repository()
self._submissions = self.create_submission_repository()
self._users = self.create_user_repository()
self._column_maps = self.create_column_map_repository()
self._logger = self.create_logger()
self._storage = self.create_storage()
self._task_queue = self.create_task_queue()
Expand All @@ -59,6 +61,9 @@ def create_submission_repository(self):
def create_user_repository(self):
return SqlAlchemyUserRepository(self._session_factory)

def create_column_map_repository(self):
return SqlAlchemyColumnMapRepository(self._session_factory)

def create_logger(self):
return BasicLogger(__name__)

Expand Down
5 changes: 5 additions & 0 deletions nad_ch/config/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
FakeDataProducerRepository,
FakeDataSubmissionRepository,
FakeUserRepository,
FakeColumnMapRepository,
FakeStorage,
)

Expand All @@ -23,6 +24,7 @@ def __init__(self):
self._producers = self.create_producer_repository()
self._submissions = self.create_submission_repository()
self._users = self.create_user_repository()
self._column_maps = self.create_column_map_repository()
self._logger = self.create_logger()
self._storage = self.create_storage()
self._task_queue = self.create_task_queue()
Expand All @@ -37,6 +39,9 @@ def create_submission_repository(self):
def create_user_repository(self):
return FakeUserRepository()

def create_column_map_repository(self):
return FakeColumnMapRepository()

def create_logger(self):
return BasicLogger(__name__, logging.DEBUG)

Expand Down
3 changes: 3 additions & 0 deletions nad_ch/domain/repositories.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@ def get_by_email(self, email: str) -> Optional[User]:
def get_by_id(self, id: int) -> Optional[User]:
...

def get_all(self) -> Iterable[User]:
    """Return every user known to the repository."""
    ...


class ColumnMapRepository(Protocol):
def add(self, column_map: ColumnMap) -> ColumnMap:
Expand Down
48 changes: 25 additions & 23 deletions nad_ch/infrastructure/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def from_entity(submission: DataSubmission, producer_id: int, column_map_id: int

def to_entity(self):
producer = self.data_producer.to_entity()
column_map = self.column_map.to_entity(producer)
column_map = self.column_map.to_entity()
entity = DataSubmission(
id=self.id,
filename=self.filename,
Expand Down Expand Up @@ -190,13 +190,14 @@ def from_entity(column_map: ColumnMap, producer_id: int):
)
return model

def to_entity(self, producer: DataProducer):
def to_entity(self):
producer_entity = self.data_producer.to_entity()
entity = ColumnMap(
id=self.id,
name=self.name,
version_id=self.version_id,
mapping=self.mapping,
producer=producer,
producer=producer_entity,
)

if self.created_at is not None:
Expand Down Expand Up @@ -290,32 +291,32 @@ def get_by_producer(self, producer: DataProducer) -> List[DataSubmission]:

def get_by_filename(self, filename: str) -> Optional[DataSubmission]:
    """Return the submission stored under `filename`, or None if absent.

    The related producer/column map are resolved by the model's own
    relationships inside `to_entity`, so no explicit join is needed here.
    """
    with session_scope(self.session_factory) as session:
        submission_model = (
            session.query(DataSubmissionModel)
            .filter(DataSubmissionModel.filename == filename)
            .first()
        )
        # Convert inside the session so lazy-loaded relationships resolve.
        if submission_model:
            return submission_model.to_entity()
        return None

def update_report(self, id: int, report) -> Optional[DataSubmission]:
    """Attach `report` to the submission with the given id.

    Returns the refreshed submission entity, or None when no submission
    with that id exists. (Annotation fixed: the body returns an entity,
    not None-only as previously declared.)
    """
    with session_scope(self.session_factory) as session:
        submission_model = (
            session.query(DataSubmissionModel)
            .filter(DataSubmissionModel.id == id)
            .first()
        )
        if submission_model:
            submission_model.report = report
            session.commit()
            # Refresh so the returned entity reflects DB-side state.
            session.refresh(submission_model)
            return submission_model.to_entity()
        return None


class SqlAlchemyUserRepository(UserRepository):
Expand Down Expand Up @@ -350,6 +351,12 @@ def get_by_id(self, id: int) -> Optional[User]:
else:
return None

def get_all(self) -> List[User]:
    """Return every persisted user converted to a domain entity."""
    with session_scope(self.session_factory) as session:
        models = session.query(UserModel).all()
        return [model.to_entity() for model in models]


class SqlAlchemyColumnMapRepository(ColumnMapRepository):
def __init__(self, session_factory):
Expand All @@ -366,8 +373,7 @@ def add(self, column_map: ColumnMap) -> ColumnMap:
session.add(column_map_model)
session.commit()
session.refresh(column_map_model)
producer_model_entity = producer_model.to_entity()
return column_map_model.to_entity(producer_model_entity)
return column_map_model.to_entity()

def get_all(self) -> List[ColumnMap]:
with session_scope(self.session_factory) as session:
Expand All @@ -387,10 +393,7 @@ def get_by_data_submission(
.first()
)
if submission_model:
producer_entity = submission_model.producer.to_entity()
column_map_entity = submission_model.column_map.to_entity(
producer_entity
)
column_map_entity = submission_model.column_map.to_entity()
return column_map_entity
else:
return None
Expand All @@ -406,8 +409,7 @@ def get_by_name_and_version(
)
.first()
)
producer_entity = column_map_model.data_producer.to_entity()
if column_map_model:
return column_map_model.to_entity(producer_entity)
return column_map_model.to_entity()
else:
return None
19 changes: 13 additions & 6 deletions nad_ch/infrastructure/task_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,17 @@
)
from nad_ch.config import QUEUE_BROKER_URL, QUEUE_BACKEND_URL
from nad_ch.domain.repositories import DataSubmissionRepository
from typing import Dict


# Celery application using the configured queue for both broker and results.
# Connection retries keep the worker from dying if the broker is slow to start.
celery_app = Celery(
    "redis-task-queue",
    broker=QUEUE_BROKER_URL,
    backend=QUEUE_BACKEND_URL,
    broker_connection_retry=True,
    broker_connection_retry_delay=5,  # seconds between retry attempts
    broker_connection_retry_max=3,  # give up after three attempts
    broker_connection_retry_on_startup=True,  # also retry during startup
)


Expand All @@ -31,13 +38,13 @@


@celery_app.task
def load_and_validate(gdb_file_path: str, config_name: str) -> dict:
data_reader = DataReader(config_name)
def load_and_validate(gdb_file_path: str, column_map: Dict[str, str]) -> dict:
data_reader = DataReader(column_map)
first_batch = True
for gdf in data_reader.read_file_in_batches(path=gdb_file_path):
if first_batch:
overview, feature_details = initialize_overview_details(
data_reader.valid_renames
gdf, data_reader.valid_renames
)
feature_details = update_feature_details(gdf, feature_details)
overview = update_overview_details(gdf, overview)
Expand All @@ -56,9 +63,9 @@ def run_load_and_validate(
submissions: DataSubmissionRepository,
submission_id: int,
path: str,
config_name: str,
column_map: Dict[str, str],
):
task_result = load_and_validate.apply_async(args=[path, config_name])
task_result = load_and_validate.apply_async(args=[path, column_map])
report_dict = task_result.get()
submissions.update_report(submission_id, report_dict)
return report_from_dict(report_dict)
2 changes: 1 addition & 1 deletion scripts/seed.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def main():
)
ctx.users.add(new_user)

new_column_map = ColumnMap(name="New Jersey Mapping v1", producer=saved_producer)
# new_column_map = ColumnMap(name="New Jersey Mapping v1", producer=saved_producer)
# TODO save column map once ApplicationContext can provide a repository
# saved_column_map = ctx.column_maps.add(new_column_map)

Expand Down
Loading

0 comments on commit 8a83209

Please sign in to comment.