Skip to content

Commit

Permalink
Merge pull request #26 from GSA-TTS/app-context-refactor
Browse files Browse the repository at this point in the history
App context refactor
  • Loading branch information
akuny authored Feb 6, 2024
2 parents 4888761 + 0c9ccd0 commit 2c4e21f
Show file tree
Hide file tree
Showing 10 changed files with 166 additions and 125 deletions.
23 changes: 23 additions & 0 deletions nad_ch/application/interfaces.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Optional, Protocol
from nad_ch.application.dtos import DownloadResult
from nad_ch.domain.repositories import DataProviderRepository, DataSubmissionRepository


class Logger(Protocol):
Expand Down Expand Up @@ -30,3 +31,25 @@ def cleanup_temp_dir(self, temp_dir: str) -> bool:
class TaskQueue(Protocol):
def run_load_and_validate(self, path: str):
...


class ApplicationContext:
@property
def providers(self) -> DataProviderRepository:
return self._providers

@property
def submissions(self) -> DataSubmissionRepository:
return self._submissions

@property
def logger(self) -> Logger:
return self._logger

@property
def storage(self) -> Storage:
return self._storage

@property
def task_queue(self) -> TaskQueue:
return self._task_queue
2 changes: 1 addition & 1 deletion nad_ch/application/use_cases.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os
from typing import List
from nad_ch.application.dtos import DownloadResult
from nad_ch.application_context import ApplicationContext
from nad_ch.application.interfaces import ApplicationContext
from nad_ch.domain.entities import DataProvider, DataSubmission


Expand Down
111 changes: 0 additions & 111 deletions nad_ch/application_context.py

This file was deleted.

2 changes: 1 addition & 1 deletion nad_ch/cli.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from nad_ch.controllers.cli import cli
from nad_ch.application_context import create_app_context
from nad_ch.config import create_app_context


def main():
Expand Down
4 changes: 3 additions & 1 deletion nad_ch/config/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from .base import APP_ENV


if APP_ENV == "dev_local" or APP_ENV == "test":
if APP_ENV == "dev_local":
from .development_local import *
elif APP_ENV == "dev_remote":
from .development_remote import *
elif APP_ENV == "test":
from .test import *
52 changes: 45 additions & 7 deletions nad_ch/config/development_local.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
import logging
import os
from .base import *
from nad_ch.application.interfaces import ApplicationContext
from nad_ch.infrastructure.database import (
create_session_factory,
SqlAlchemyDataProviderRepository,
SqlAlchemyDataSubmissionRepository,
)
from nad_ch.infrastructure.logger import BasicLogger
from nad_ch.infrastructure.storage import MinioStorage


# Local development config
APP_ENV = os.getenv("APP_ENV")
STORAGE_PATH = os.getenv("STORAGE_PATH")


postgres_user = os.getenv("POSTGRES_USER")
postgres_password = os.getenv("POSTGRES_PASSWORD")
postgres_host = os.getenv("POSTGRES_HOST")
Expand All @@ -16,13 +21,46 @@
f"postgresql+psycopg2://{postgres_user}:{postgres_password}"
f"@{postgres_host}:{postgres_port}/{postgres_db}"
)


QUEUE_BROKER_URL = os.getenv("QUEUE_BROKER_URL")
QUEUE_BACKEND_URL = os.getenv("QUEUE_BACKEND_URL")

S3_BUCKET_NAME = os.getenv("S3_BUCKET_NAME")
S3_ENDPOINT = os.getenv("S3_ENDPOINT")
S3_ACCESS_KEY = os.getenv("S3_ACCESS_KEY")
S3_SECRET_ACCESS_KEY = os.getenv("S3_SECRET_ACCESS_KEY")
S3_REGION = os.getenv("S3_REGION")


class DevLocalApplicationContext(ApplicationContext):
def __init__(self):
self._session = create_session_factory(DATABASE_URL)
self._providers = self.create_provider_repository()
self._submissions = self.create_submission_repository()
self._logger = self.create_logger()
self._storage = self.create_storage()
self._task_queue = self.create_task_queue()

def create_provider_repository(self):
return SqlAlchemyDataProviderRepository(self._session)

def create_submission_repository(self):
return SqlAlchemyDataSubmissionRepository(self._session)

def create_logger(self):
return BasicLogger(__name__, logging.DEBUG)

def create_storage(self):
return MinioStorage(
S3_ENDPOINT,
S3_ACCESS_KEY,
S3_SECRET_ACCESS_KEY,
S3_BUCKET_NAME,
)

def create_task_queue(self):
from nad_ch.infrastructure.task_queue import celery_app, CeleryTaskQueue

return CeleryTaskQueue(celery_app)


def create_app_context():
return DevLocalApplicationContext()
45 changes: 44 additions & 1 deletion nad_ch/config/development_remote.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,21 @@
import json
import os
from .base import *
from nad_ch.application.interfaces import ApplicationContext
from nad_ch.infrastructure.database import (
create_session_factory,
SqlAlchemyDataProviderRepository,
SqlAlchemyDataSubmissionRepository,
)
from nad_ch.infrastructure.logger import BasicLogger
from nad_ch.infrastructure.storage import S3Storage


def get_credentials(service_name, default={}):
service = vcap_services.get(service_name, [default])
return service[0].get("credentials", default) if service else default


# Remote development config
vcap_services = json.loads(os.getenv("VCAP_SERVICES", "{}"))


Expand All @@ -28,3 +35,39 @@ def get_credentials(service_name, default={}):
"secret_access_key", os.getenv("S3_SECRET_ACCESS_KEY")
)
S3_REGION = s3_credentials.get("region", os.getenv("S3_REGION"))


class DevRemoteApplicationContext(ApplicationContext):
def __init__(self):
self._session = create_session_factory(DATABASE_URL)
self._providers = self.create_provider_repository()
self._submissions = self.create_submission_repository()
self._logger = self.create_logger()
self._storage = self.create_storage()
self._task_queue = self.create_task_queue()

def create_provider_repository(self):
return SqlAlchemyDataProviderRepository(self._session)

def create_submission_repository(self):
return SqlAlchemyDataSubmissionRepository(self._session)

def create_logger(self):
return BasicLogger(__name__)

def create_storage(self):
return S3Storage(
S3_ACCESS_KEY,
S3_SECRET_ACCESS_KEY,
S3_REGION,
S3_BUCKET_NAME,
)

def create_task_queue(self):
from nad_ch.infrastructure.task_queue import celery_app, CeleryTaskQueue

return CeleryTaskQueue(celery_app)


def create_app_context():
return DevRemoteApplicationContext()
45 changes: 45 additions & 0 deletions nad_ch/config/test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import logging
import os
from nad_ch.application.interfaces import ApplicationContext
from nad_ch.infrastructure.logger import BasicLogger
from tests.fakes_and_mocks import (
FakeDataProviderRepository,
FakeDataSubmissionRepository,
FakeStorage,
)


DATABASE_URL = os.getenv("DATABASE_URL")
QUEUE_BROKER_URL = os.getenv("QUEUE_BROKER_URL")
QUEUE_BACKEND_URL = os.getenv("QUEUE_BACKEND_URL")


class TestApplicationContext(ApplicationContext):
def __init__(self):
self._session = None
self._providers = self.create_provider_repository()
self._submissions = self.create_submission_repository()
self._logger = self.create_logger()
self._storage = self.create_storage()
self._task_queue = self.create_task_queue()

def create_provider_repository(self):
return FakeDataProviderRepository()

def create_submission_repository(self):
return FakeDataSubmissionRepository()

def create_logger(self):
return BasicLogger(__name__, logging.DEBUG)

def create_storage(self):
return FakeStorage()

def create_task_queue(self):
from nad_ch.infrastructure.task_queue import celery_app, CeleryTaskQueue

return CeleryTaskQueue(celery_app)


def create_app_context():
return TestApplicationContext()
1 change: 1 addition & 0 deletions nad_ch/infrastructure/task_queue.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
from celery import Celery
import geopandas as gpd
from nad_ch.application.interfaces import TaskQueue
Expand Down
6 changes: 3 additions & 3 deletions tests/application/test_use_cases.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest
import re
from nad_ch.application_context import create_app_context
from nad_ch.config import create_app_context
from nad_ch.domain.entities import DataProvider, DataSubmission
from nad_ch.application.use_cases import (
add_data_provider,
Expand All @@ -27,13 +27,13 @@ def test_add_data_provider(app_context):


def test_add_data_provider_logs_error_if_no_provider_name_given(mocker):
mock_context = mocker.patch("nad_ch.application_context.create_app_context")
mock_context = mocker.patch("nad_ch.config.create_app_context")
add_data_provider(mock_context, "")
mock_context.logger.error.assert_called_once_with("Provider name required")


def test_add_data_provider_logs_error_if_provider_name_not_unique(mocker):
mock_context = mocker.patch("nad_ch.application_context.create_app_context")
mock_context = mocker.patch("nad_ch.config.create_app_context")
mock_context.providers.get_by_name.return_value("State X")
add_data_provider(mock_context, "State X")

Expand Down

0 comments on commit 2c4e21f

Please sign in to comment.