Skip to content

Commit

Permalink
Merge pull request #13 from nrccua/ARCH-532-add-decorator-to-manage-d…
Browse files Browse the repository at this point in the history
…b-connections

Adding DB decorator for DAG usage primarily
  • Loading branch information
nrccua-timr authored Feb 8, 2021
2 parents 23ea7d2 + e464b03 commit 2ecf218
Show file tree
Hide file tree
Showing 12 changed files with 231 additions and 27 deletions.
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@ repos:
- id: requirements-txt-fixer
- id: trailing-whitespace
- repo: https://github.com/PyCQA/isort
rev: 5.6.4
rev: 5.7.0
hooks:
- id: isort
- repo: https://github.com/myint/docformatter
rev: v1.3.1
rev: v1.4
hooks:
- id: docformatter
6 changes: 6 additions & 0 deletions HISTORY.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,12 @@ History
=======


v0.10.0 (2021-02-08)
-----------------------

* Add decorator to manage DB connections and using SQL transactions.


v0.9.8 (2021-02-01)
-----------------------

Expand Down
87 changes: 84 additions & 3 deletions aioradio/file_ingestion.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
"""Generic functions related to working with files or the file system."""

# pylint: disable=broad-except
# pylint: disable=invalid-name
# pylint: disable=too-many-arguments
# pylint: disable=too-many-boolean-expressions

import asyncio
import functools
import json
import os
import re
import time
Expand All @@ -22,6 +24,10 @@
from smb.smb_structs import OperationFailure
from smb.SMBConnection import SMBConnection

from aioradio.aws.secrets import get_secret
from aioradio.psycopg2 import establish_psycopg2_connection
from aioradio.pyodbc import establish_pyodbc_connection

DIRECTORY = Path(__file__).parent.absolute()


Expand All @@ -31,7 +37,6 @@ def async_wrapper(func: coroutine) -> Any:
Args:
func (coroutine): async coroutine
Returns:
Any: any
"""
Expand All @@ -44,12 +49,88 @@ def wrapper(*args, **kwargs) -> Any:
Any: any
"""

loop = asyncio.get_event_loop()
return loop.run_until_complete(func(*args, **kwargs))
return asyncio.get_event_loop().run_until_complete(func(*args, **kwargs))

return wrapper


def async_db_wrapper(db_info: List[Dict[str, Any]]) -> Any:
"""Decorator to run functions using async that handles database connection
creation and closure. Pulls database creds from AWS secret manager.
Args:
db_info (List[Dict[str, str]], optional): Database info {'name', 'db', 'secret', 'region'}. Defaults to [].
Returns:
Any: any
"""

def parent_wrapper(func: coroutine) -> Any:
"""Decorator parent wrapper.
Args:
func (coroutine): async coroutine
Returns:
Any: any
"""

@functools.wraps(func)
def child_wrapper(*args, **kwargs) -> Any:
"""Decorator child wrapper. All DB established/closed connections
and commits or rollbacks take place in the decorator and should
never happen within the inner function.
Returns:
Any: any
"""

async_run = asyncio.get_event_loop().run_until_complete
conns = {}
rollback = {}

# create connections
for item in db_info:

if item['db'] in ['pyodbc', 'psycopg2']:
creds = {**json.loads(async_run(get_secret(item['secret'], item['region']))), **{'database': item.get('database', '')}}
if item['db'] == 'pyodbc':
conns[item['name']] = async_run(establish_pyodbc_connection(**creds, autocommit=False))
elif item['db'] == 'psycopg2':
conns[item['name']] = async_run(establish_psycopg2_connection(**creds))
rollback[item['name']] = item['rollback']
print(f"ESTABLISHED CONNECTION for {item['name']}")

result = None
error = None
try:
# run main function
result = async_run(func(*args, **kwargs, conns=conns)) if conns else async_run(func(*args, **kwargs))
except Exception as err:
error = err

# close connections
for name, conn in conns.items():

if rollback[name]:
conn.rollback()
elif error is None:
conn.commit()

conn.close()
print(f"CLOSED CONNECTION for {name}")

# if we caught an exception raise it again
if error is not None:
raise error

return result

return child_wrapper

return parent_wrapper


def async_wrapper_using_new_loop(func: coroutine) -> Any:
"""Decorator to run functions using async. Found this handy to use with DAG
tasks.
Expand Down
36 changes: 36 additions & 0 deletions aioradio/psycopg2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
"""Pyodbc functions for connecting and send queries."""

# pylint: disable=c-extension-no-member
# pylint: disable=too-many-arguments

import psycopg2


async def establish_psycopg2_connection(
host: str,
user: str,
password: str,
database: str,
port: int=5432,
is_audit: bool=False
):
"""Acquire the psycopg2 connection object.
Args:
host (str): Host
user (str): User
password (str): Password
database (str): Database
port (int, optional): Port. Defaults to 5432.
is_audit (bool, optional): Audit queries. Defaults to False.
Returns:
pyscopg2 Connection object
"""

conn = psycopg2.connect(host=host, port=port, user=user, password=password, dbname=database)

if is_audit:
conn.autocommit=True

return conn
28 changes: 19 additions & 9 deletions aioradio/pyodbc.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Pyodbc functions for connecting and send queries."""

# pylint: disable=c-extension-no-member
# pylint: disable=too-many-arguments

import os
from typing import Any, List, Union
Expand Down Expand Up @@ -34,15 +35,25 @@ async def get_unixodbc_driver_path(paths: List[str]) -> Union[str, None]:
return driver_path


async def establish_pyodbc_connection(host: str, user: str, pwd: str, driver: str = None) -> pyodbc.Connection:
async def establish_pyodbc_connection(
host: str,
user: str,
pwd: str,
port: int=1433,
database: str='',
driver: str='',
autocommit: bool=False
) -> pyodbc.Connection:
"""Acquire and return pyodbc.Connection object else raise
FileNotFoundError.
Args:
host (str): hostname
user (str): username
pwd (str): password
driver (str, optional): unixodbc driver. Defaults to None.
post (int, optional): port. Defaults to 1433.
database (str, optional): database. Defaults to ''.
driver (str, optional): unixodbc driver. Defaults to ''.
Raises:
FileNotFoundError: unable to locate unixodbc driver
Expand All @@ -51,16 +62,15 @@ async def establish_pyodbc_connection(host: str, user: str, pwd: str, driver: st
pyodbc.Connection: database connection object
"""

if driver is None:
verified_driver = await get_unixodbc_driver_path(UNIXODBC_DRIVER_PATHS)
else:
verified_driver = await get_unixodbc_driver_path([driver])

verified_driver = await get_unixodbc_driver_path([driver]) if driver else await get_unixodbc_driver_path(UNIXODBC_DRIVER_PATHS)
if verified_driver is None:
raise FileNotFoundError('Unable to locate unixodbc driver file: libtdsodbc.so')

return pyodbc.connect(
f'DRIVER={verified_driver};SERVER={host};PORT=1433;UID={user};PWD={pwd};TDS_Version=8.0')
conn_string = f'DRIVER={verified_driver};SERVER={host};PORT={port};UID={user};PWD={pwd};TDS_Version=8.0'
if database:
conn_string += f';DATABASE={database}'

return pyodbc.connect(conn_string, autocommit=autocommit)


async def pyodbc_query_fetchone(conn: pyodbc.Connection, query: str) -> Union[List[Any], None]:
Expand Down
11 changes: 6 additions & 5 deletions aioradio/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,21 @@ aioboto3==8.2.0
aiobotocore==1.1.2
aiojobs==0.3.0
aioredis==1.3.1
ddtrace==0.45.0
ddtrace==0.46.0
fakeredis==1.4.5
flask==1.1.2
httpx==0.16.1
mandrill==1.0.59
moto==1.3.16
orjson==3.4.6
pre-commit==2.9.3
orjson==3.4.8
pre-commit==2.10.1
psycopg2-binary==2.8.6
pylint==2.6.0
pyodbc==4.0.30
pysmb==1.2.6
pytest==6.2.1
pytest==6.2.2
pytest-asyncio==0.14.0
pytest-cov==2.10.1
pytest-cov==2.11.1
python-json-logger==2.0.1
twine==3.3.0
wheel==0.36.2
37 changes: 36 additions & 1 deletion aioradio/tests/file_ingestion_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@

import pytest

from aioradio.file_ingestion import (delete_ftp_file, establish_ftp_connection,
from aioradio.file_ingestion import (async_db_wrapper, delete_ftp_file,
establish_ftp_connection,
get_current_datetime_from_timestamp,
list_ftp_objects,
send_emails_via_mandrill,
Expand Down Expand Up @@ -180,3 +181,37 @@ async def test_delete_ftp_file(github_action):

result = await delete_ftp_file(conn=conn, service_name='EnrollmentFunnel', ftp_path='pytest/is/great/test_file_ingestion.zip')
assert result is True


def test_async_wrapper(user):
"""Test async_wrapper with database connections."""

if user != 'tim.reichard':
pytest.skip('Skip test_async_wrapper_factory since user is not Tim Reichard')

db_info=[
{
'name': 'test1',
'db': 'pyodbc',
'secret': 'production/airflowCluster/sqloltp',
'region': 'us-east-1',
'rollback': True
},
{
'name': 'test2',
'db': 'psycopg2',
'secret': 'datalab/dev/classplanner_db',
'region': 'us-east-1',
'database': 'student',
'is_audit': False,
'rollback': True
}
]

@async_db_wrapper(db_info=db_info)
async def func(**kwargs):
conns = kwargs['conns']
for name, conn in conns.items():
print(f"Connection name: {name}\tConnection object: {conn}")

func()
24 changes: 24 additions & 0 deletions aioradio/tests/psycopg2_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
"""pytest psycopg2 script."""

import json

import pytest

from aioradio.aws.secrets import get_secret
from aioradio.psycopg2 import establish_psycopg2_connection

pytestmark = pytest.mark.asyncio


async def test_establish_psycopg2_connection(github_action, user):
"""Test establish_psycopg2_connection."""

if github_action:
pytest.skip('Skip test_establish_psycopg2_connection when running via Github Action')
elif user != 'tim.reichard':
pytest.skip('Skip test_establish_psycopg2_connection since user is not Tim Reichard')

creds = json.loads(await get_secret('datalab/dev/classplanner_db', 'us-east-1'))
conn = await establish_psycopg2_connection(**creds, database='student')
assert conn.closed == 0
conn.close()
2 changes: 1 addition & 1 deletion aioradio/tests/pyodbc_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""pytest file_ingestion script."""
"""pytest pyodbc script."""

import os

Expand Down
4 changes: 1 addition & 3 deletions aioradio/tests/sqs_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

# pylint: disable=c-extension-no-member

from asyncio import sleep
from uuid import uuid4

import orjson
Expand All @@ -21,7 +20,7 @@
async def test_add_regions():
"""Add us-east-2 region."""

await add_regions(['us-east-2'])
await add_regions([REGION])


async def test_sqs_creating_queue(sqs_queue_url):
Expand Down Expand Up @@ -82,5 +81,4 @@ async def test_sqs_purge_messages():
# accept err: "Only one PurgeQueue operation on pytest is allowed every 60 seconds."
assert 'PurgeQueue' in err
else:
await sleep(3)
assert not await get_messages(queue=QUEUE, region=REGION, wait_time=1)
Loading

0 comments on commit 2ecf218

Please sign in to comment.