diff --git a/HISTORY.rst b/HISTORY.rst index 4514f73..aaa9449 100644 --- a/HISTORY.rst +++ b/HISTORY.rst @@ -3,6 +3,11 @@ History ======= +v0.17.28 (2023-01-26) + +* Add support to filter secrets in async_db_wrapper using the argument secret_json_key. + + v0.17.27 (2023-01-20) * Raise errors instead of printing in file_ingestion function zipfile_to_tsv. diff --git a/aioradio/file_ingestion.py b/aioradio/file_ingestion.py index 7ec46c8..d0cc294 100644 --- a/aioradio/file_ingestion.py +++ b/aioradio/file_ingestion.py @@ -1306,7 +1306,12 @@ async def child_wrapper(*args, **kwargs) -> Any: secret = await get_secret(item['secret'], item['region'], item['aws_creds']) else: secret = await get_secret(item['secret'], item['region']) - creds = {**json.loads(secret), **{'database': item.get('database', '')}} + + secret = json.loads(secret) + if 'secret_json_key' in item: + secret = secret[item['secret_json_key']] + + creds = {**secret, **{'database': item.get('database', '')}} if item['db'] == 'pyodbc': # Add import here because it requires extra dependencies many systems # don't have out of the box so only import when explicitly being used diff --git a/aioradio/tests/aws_secrets_test.py b/aioradio/tests/aws_secrets_test.py index 1ac5476..9b3a3a3 100644 --- a/aioradio/tests/aws_secrets_test.py +++ b/aioradio/tests/aws_secrets_test.py @@ -10,8 +10,6 @@ from aioradio.aws.secrets import get_secret -pytestmark = pytest.mark.asyncio - @mock_secretsmanager def test_secrets_get_secret(): @@ -27,6 +25,7 @@ def test_secrets_get_secret(): @pytest.mark.xfail +@pytest.mark.asyncio async def test_secrets_get_secret_with_bad_key(): """Test exception raised when using a bad key retrieving from Secrets Manager.""" diff --git a/aioradio/tests/file_ingestion_test.py b/aioradio/tests/file_ingestion_test.py index efd632a..f9c5222 100644 --- a/aioradio/tests/file_ingestion_test.py +++ b/aioradio/tests/file_ingestion_test.py @@ -200,31 +200,25 @@ async def func(): assert result == 'Hello World' - -def test_async_db_wrapper(user): +@pytest.mark.asyncio +async def test_async_db_wrapper(user): """Test async_db_wrapper with database connections.""" if user != 'tim.reichard': pytest.skip('Skip test_async_db_wrapper since user is not Tim Reichard') - db_info=[ - { - 'name': 'test1', - 'db': 'pyodbc', - 'secret': 'production/airflowCluster/sqloltp', - 'region': 'us-east-1', - 'rollback': True - }, - { - 'name': 'test2', - 'db': 'psycopg2', - 'secret': 'datalab/dev/classplanner_db', - 'region': 'us-east-1', - 'database': 'student', - 'is_audit': False, - 'rollback': True - } - ] + db_info=[{ + 'db': 'pyodbc', + 'name': 'test1', + 'database': 'DataStage', + 'secret': 'efi/sandbox/all', + 'secret_json_key': 'mssql', + 'region': 'us-east-1', + 'rollback': True, + 'trusted_connection': 'no', + 'application_intent': 'ReadOnly', + 'tds_version': '7.4' + }] @async_db_wrapper(db_info=db_info) async def func(**kwargs): @@ -232,4 +226,4 @@ async def func(**kwargs): for name, conn in conns.items(): print(f"Connection name: {name}\tConnection object: {conn}") - func() + await func() diff --git a/setup.py b/setup.py index 42b9835..fb91328 100644 --- a/setup.py +++ b/setup.py @@ -7,7 +7,7 @@ long_description = fileobj.read() setup(name='aioradio', - version='0.17.27', + version='0.17.28', description='Generic asynchronous i/o python utilities for AWS services (SQS, S3, DynamoDB, Secrets Manager), Redis, MSSQL (pyodbc), JIRA and more', long_description=long_description, long_description_content_type="text/markdown",