From 09c355803de4c2c3de19fedf3bf6d2f0f1c43610 Mon Sep 17 00:00:00 2001 From: "tim.reichard" Date: Mon, 8 Feb 2021 08:33:41 -0600 Subject: [PATCH 1/3] Adding DB decorator for DAG usage primarily --- .pre-commit-config.yaml | 4 +- HISTORY.rst | 6 ++ aioradio/file_ingestion.py | 86 ++++++++++++++++++++++++++- aioradio/psycopg2.py | 36 +++++++++++ aioradio/pyodbc.py | 28 ++++++--- aioradio/requirements.txt | 11 ++-- aioradio/tests/file_ingestion_test.py | 37 +++++++++++- aioradio/tests/psycopg2_test.py | 24 ++++++++ aioradio/tests/pyodbc_test.py | 2 +- aioradio/tests/sqs_test.py | 4 +- conftest.py | 12 ++++ setup.py | 7 ++- 12 files changed, 230 insertions(+), 27 deletions(-) create mode 100644 aioradio/psycopg2.py create mode 100644 aioradio/tests/psycopg2_test.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 8f59b26..f97f27d 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -15,10 +15,10 @@ repos: - id: requirements-txt-fixer - id: trailing-whitespace - repo: https://github.com/PyCQA/isort - rev: 5.6.4 + rev: 5.7.0 hooks: - id: isort - repo: https://github.com/myint/docformatter - rev: v1.3.1 + rev: v1.4 hooks: - id: docformatter diff --git a/HISTORY.rst b/HISTORY.rst index 59d8b40..79f081c 100644 --- a/HISTORY.rst +++ b/HISTORY.rst @@ -3,6 +3,12 @@ History ======= +v0.10.0 (2021-02-08) +----------------------- + +* Add decorator to manage DB connections and using SQL transactions. 
+ + v0.9.8 (2021-02-01) ----------------------- diff --git a/aioradio/file_ingestion.py b/aioradio/file_ingestion.py index 035220a..4590ce5 100644 --- a/aioradio/file_ingestion.py +++ b/aioradio/file_ingestion.py @@ -1,11 +1,13 @@ """Generic functions related to working with files or the file system.""" +# pylint: disable=broad-except # pylint: disable=invalid-name # pylint: disable=too-many-arguments # pylint: disable=too-many-boolean-expressions import asyncio import functools +import json import os import re import time @@ -22,6 +24,10 @@ from smb.smb_structs import OperationFailure from smb.SMBConnection import SMBConnection +from aioradio.aws.secrets import get_secret +from aioradio.psycopg2 import establish_psycopg2_connection +from aioradio.pyodbc import establish_pyodbc_connection + DIRECTORY = Path(__file__).parent.absolute() @@ -31,7 +37,6 @@ def async_wrapper(func: coroutine) -> Any: Args: func (coroutine): async coroutine - Returns: Any: any """ @@ -44,12 +49,87 @@ def wrapper(*args, **kwargs) -> Any: Any: any """ - loop = asyncio.get_event_loop() - return loop.run_until_complete(func(*args, **kwargs)) + return asyncio.get_event_loop().run_until_complete(func(*args, **kwargs)) return wrapper +def async_db_wrapper(db_info: List[Dict[str, Any]]) -> Any: + """Decorator to run functions using async that handles database connection + creation and closure. Pulls database creds from AWS secret manager. + + Args: + db_info (List[Dict[str, str]], optional): Database info {'name', 'db', 'secret', 'region'}. Defaults to []. + + Returns: + Any: any + """ + + def parent_wrapper(func: coroutine) -> Any: + """Decorator parent wrapper. + + Args: + func (coroutine): async coroutine + + Returns: + Any: any + """ + + @functools.wraps(func) + def child_wrapper(*args, **kwargs) -> Any: + """Decorator child wrapper. All DB established/closed connections + and commits or rollbacks take place in the decorator and should + never happen within the inner function. 
+ + Returns: + Any: any + """ + + async_run = asyncio.get_event_loop().run_until_complete + conns = {} + rollback = {} + + # create connections + for item in db_info: + + if item['db'] in ['pyodbc', 'psycopg2']: + creds = {**json.loads(async_run(get_secret(item['secret'], item['region']))), **{'database': item.get('database', '')}} + if item['db'] == 'pyodbc': + conns[item['name']] = async_run(establish_pyodbc_connection(**creds, autocommit=False)) + elif item['db'] == 'psycopg2': + conns[item['name']] = async_run(establish_psycopg2_connection(**creds)) + rollback[item['name']] = item['rollback'] + print(f"ESTABLISHED CONNECTION for {item['name']}") + + result = None + error = None + try: + # run main function + result = async_run(func(*args, **kwargs, conns=conns)) if conns else async_run(func(*args, **kwargs)) + except Exception as err: + error = err + + # close connections + for name, conn in conns.items(): + + if rollback[name]: + conn.rollback() + + conn.commit() + conn.close() + print(f"CLOSED CONNECTION for {name}") + + # if we caught an exception raise it again + if error is not None: + raise error + + return result + + return child_wrapper + + return parent_wrapper + + def async_wrapper_using_new_loop(func: coroutine) -> Any: """Decorator to run functions using async. Found this handy to use with DAG tasks. diff --git a/aioradio/psycopg2.py b/aioradio/psycopg2.py new file mode 100644 index 0000000..ecb9e3f --- /dev/null +++ b/aioradio/psycopg2.py @@ -0,0 +1,36 @@ +"""Psycopg2 functions for connecting and sending queries.""" + +# pylint: disable=c-extension-no-member +# pylint: disable=too-many-arguments + +import psycopg2 + + +async def establish_psycopg2_connection( + host: str, + user: str, + password: str, + database: str, + port: int=5432, + is_audit: bool=False +): + """Acquire the psycopg2 connection object. + + Args: + host (str): Host + user (str): User + password (str): Password + database (str): Database + port (int, optional): Port. Defaults to 5432. 
+ is_audit (bool, optional): Audit queries. Defaults to False. + + Returns: + psycopg2 Connection object + """ + + conn = psycopg2.connect(host=host, port=port, user=user, password=password, dbname=database) + + if is_audit: + conn.autocommit=True + + return conn diff --git a/aioradio/pyodbc.py b/aioradio/pyodbc.py index 3f4f1be..aca17f7 100644 --- a/aioradio/pyodbc.py +++ b/aioradio/pyodbc.py @@ -1,6 +1,7 @@ """Pyodbc functions for connecting and send queries.""" # pylint: disable=c-extension-no-member +# pylint: disable=too-many-arguments import os from typing import Any, List, Union @@ -34,7 +35,15 @@ async def get_unixodbc_driver_path(paths: List[str]) -> Union[str, None]: return driver_path -async def establish_pyodbc_connection(host: str, user: str, pwd: str, driver: str = None) -> pyodbc.Connection: +async def establish_pyodbc_connection( + host: str, + user: str, + pwd: str, + port: int=1433, + database: str='', + driver: str='', + autocommit: bool=False +) -> pyodbc.Connection: """Acquire and return pyodbc.Connection object else raise FileNotFoundError. @@ -42,7 +51,9 @@ async def establish_pyodbc_connection(host: str, user: str, pwd: str, driver: st host (str): hostname user (str): username pwd (str): password - driver (str, optional): unixodbc driver. Defaults to None. + port (int, optional): port. Defaults to 1433. + database (str, optional): database. Defaults to ''. + driver (str, optional): unixodbc driver. Defaults to ''. 
Raises: FileNotFoundError: unable to locate unixodbc driver @@ -51,16 +62,15 @@ async def establish_pyodbc_connection(host: str, user: str, pwd: str, driver: st pyodbc.Connection: database connection object """ - if driver is None: - verified_driver = await get_unixodbc_driver_path(UNIXODBC_DRIVER_PATHS) - else: - verified_driver = await get_unixodbc_driver_path([driver]) - + verified_driver = await get_unixodbc_driver_path([driver]) if driver else await get_unixodbc_driver_path(UNIXODBC_DRIVER_PATHS) if verified_driver is None: raise FileNotFoundError('Unable to locate unixodbc driver file: libtdsodbc.so') - return pyodbc.connect( - f'DRIVER={verified_driver};SERVER={host};PORT=1433;UID={user};PWD={pwd};TDS_Version=8.0') + conn_string = f'DRIVER={verified_driver};SERVER={host};PORT={port};UID={user};PWD={pwd};TDS_Version=8.0' + if database: + conn_string += f';DATABASE={database}' + + return pyodbc.connect(conn_string, autocommit=autocommit) async def pyodbc_query_fetchone(conn: pyodbc.Connection, query: str) -> Union[List[Any], None]: diff --git a/aioradio/requirements.txt b/aioradio/requirements.txt index 4aedbdd..ef5219c 100644 --- a/aioradio/requirements.txt +++ b/aioradio/requirements.txt @@ -2,20 +2,21 @@ aioboto3==8.2.0 aiobotocore==1.1.2 aiojobs==0.3.0 aioredis==1.3.1 -ddtrace==0.45.0 +ddtrace==0.46.0 fakeredis==1.4.5 flask==1.1.2 httpx==0.16.1 mandrill==1.0.59 moto==1.3.16 -orjson==3.4.6 -pre-commit==2.9.3 +orjson==3.4.8 +pre-commit==2.10.1 +psycopg2-binary==2.8.6 pylint==2.6.0 pyodbc==4.0.30 pysmb==1.2.6 -pytest==6.2.1 +pytest==6.2.2 pytest-asyncio==0.14.0 -pytest-cov==2.10.1 +pytest-cov==2.11.1 python-json-logger==2.0.1 twine==3.3.0 wheel==0.36.2 diff --git a/aioradio/tests/file_ingestion_test.py b/aioradio/tests/file_ingestion_test.py index a747006..ea4bdf9 100644 --- a/aioradio/tests/file_ingestion_test.py +++ b/aioradio/tests/file_ingestion_test.py @@ -11,7 +11,8 @@ import pytest -from aioradio.file_ingestion import (delete_ftp_file, 
establish_ftp_connection, +from aioradio.file_ingestion import (async_db_wrapper, delete_ftp_file, + establish_ftp_connection, get_current_datetime_from_timestamp, list_ftp_objects, send_emails_via_mandrill, @@ -180,3 +181,37 @@ async def test_delete_ftp_file(github_action): result = await delete_ftp_file(conn=conn, service_name='EnrollmentFunnel', ftp_path='pytest/is/great/test_file_ingestion.zip') assert result is True + + +def test_async_wrapper(user): + """Test async_wrapper with database connections.""" + + if user != 'tim.reichard': + pytest.skip('Skip test_async_wrapper_factory since user is not Tim Reichard') + + db_info=[ + { + 'name': 'test1', + 'db': 'pyodbc', + 'secret': 'production/airflowCluster/sqloltp', + 'region': 'us-east-1', + 'rollback': True + }, + { + 'name': 'test2', + 'db': 'psycopg2', + 'secret': 'datalab/dev/classplanner_db', + 'region': 'us-east-1', + 'database': 'student', + 'is_audit': False, + 'rollback': True + } + ] + + @async_db_wrapper(db_info=db_info) + async def func(**kwargs): + conns = kwargs['conns'] + for name, conn in conns.items(): + print(f"Connection name: {name}\tConnection object: {conn}") + + func() diff --git a/aioradio/tests/psycopg2_test.py b/aioradio/tests/psycopg2_test.py new file mode 100644 index 0000000..496fbbc --- /dev/null +++ b/aioradio/tests/psycopg2_test.py @@ -0,0 +1,24 @@ +"""pytest psycopg2 script.""" + +import json + +import pytest + +from aioradio.aws.secrets import get_secret +from aioradio.psycopg2 import establish_psycopg2_connection + +pytestmark = pytest.mark.asyncio + + +async def test_establish_psycopg2_connection(github_action, user): + """Test establish_psycopg2_connection.""" + + if github_action: + pytest.skip('Skip test_establish_psycopg2_connection when running via Github Action') + elif user != 'tim.reichard': + pytest.skip('Skip test_establish_psycopg2_connection since user is not Tim Reichard') + + creds = json.loads(await get_secret('datalab/dev/classplanner_db', 'us-east-1')) + conn 
= await establish_psycopg2_connection(**creds, database='student') + assert conn.closed == 0 + conn.close() diff --git a/aioradio/tests/pyodbc_test.py b/aioradio/tests/pyodbc_test.py index e9b22d1..e051dad 100644 --- a/aioradio/tests/pyodbc_test.py +++ b/aioradio/tests/pyodbc_test.py @@ -1,4 +1,4 @@ -"""pytest file_ingestion script.""" +"""pytest pyodbc script.""" import os diff --git a/aioradio/tests/sqs_test.py b/aioradio/tests/sqs_test.py index 1ed5c5c..f2f7b69 100644 --- a/aioradio/tests/sqs_test.py +++ b/aioradio/tests/sqs_test.py @@ -2,7 +2,6 @@ # pylint: disable=c-extension-no-member -from asyncio import sleep from uuid import uuid4 import orjson @@ -21,7 +20,7 @@ async def test_add_regions(): """Add us-east-2 region.""" - await add_regions(['us-east-2']) + await add_regions([REGION]) async def test_sqs_creating_queue(sqs_queue_url): @@ -82,5 +81,4 @@ async def test_sqs_purge_messages(): # accept err: "Only one PurgeQueue operation on pytest is allowed every 60 seconds." assert 'PurgeQueue' in err else: - await sleep(3) assert not await get_messages(queue=QUEUE, region=REGION, wait_time=1) diff --git a/conftest.py b/conftest.py index 5070acc..cf687af 100644 --- a/conftest.py +++ b/conftest.py @@ -1,6 +1,7 @@ """pytest configuration.""" import asyncio +import os from itertools import chain import aioboto3 @@ -25,6 +26,17 @@ def event_loop(): loop.close() +@pytest.fixture(scope='session') +def user(): + """Get the current USER environment variable value. + + Some tests need to be skipped if the user doesn't have access to an + AWS service. 
+ """ + + return os.getenv('USER') + + @pytest.fixture(scope='module') def payload(): """Test payload to reuse.""" diff --git a/setup.py b/setup.py index 77b7c74..9268b30 100644 --- a/setup.py +++ b/setup.py @@ -7,7 +7,7 @@ long_description = fileobj.read() setup(name='aioradio', - version='0.9.8', + version='0.10.0', description='Generic asynchronous i/o python utilities for AWS services (SQS, S3, DynamoDB, Secrets Manager), Redis, MSSQL (pyodbc), JIRA and more', long_description=long_description, long_description_content_type="text/markdown", @@ -28,10 +28,11 @@ 'ddtrace', 'fakeredis', 'httpx', - 'python-json-logger', 'mandrill', - 'pysmb', 'orjson', + 'psycopg2-binary', + 'pysmb', + 'python-json-logger', 'xlrd' ], include_package_data=True, From 1c65b833626ac2d5c63ed73ac494a50af77341e6 Mon Sep 17 00:00:00 2001 From: "tim.reichard" Date: Mon, 8 Feb 2021 08:57:35 -0600 Subject: [PATCH 2/3] Refine logic around rollback and commit --- aioradio/file_ingestion.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/aioradio/file_ingestion.py b/aioradio/file_ingestion.py index 4590ce5..f82255b 100644 --- a/aioradio/file_ingestion.py +++ b/aioradio/file_ingestion.py @@ -114,8 +114,9 @@ def child_wrapper(*args, **kwargs) -> Any: if rollback[name]: conn.rollback() + else: + conn.commit() - conn.commit() conn.close() print(f"CLOSED CONNECTION for {name}") From e464b0355c97c52cdf9944dfb76512bfeb770c1b Mon Sep 17 00:00:00 2001 From: "tim.reichard" Date: Mon, 8 Feb 2021 09:15:12 -0600 Subject: [PATCH 3/3] Update logic not to commit if exception is thrown --- aioradio/file_ingestion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aioradio/file_ingestion.py b/aioradio/file_ingestion.py index f82255b..cc2a7e7 100644 --- a/aioradio/file_ingestion.py +++ b/aioradio/file_ingestion.py @@ -114,7 +114,7 @@ def child_wrapper(*args, **kwargs) -> Any: if rollback[name]: conn.rollback() - else: + elif error is None: conn.commit() conn.close()