diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index de75ff5..8f59b26 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -2,10 +2,11 @@ default_language_version:
     python: python3.8
 repos:
 - repo: https://github.com/pre-commit/pre-commit-hooks
-  rev: v3.1.0
+  rev: v3.4.0
   hooks:
   - id: check-added-large-files
   - id: check-ast
+  - id: check-docstring-first
   - id: check-json
   - id: check-yaml
   - id: debug-statements
@@ -13,3 +14,11 @@ repos:
   - id: name-tests-test
   - id: requirements-txt-fixer
   - id: trailing-whitespace
+- repo: https://github.com/PyCQA/isort
+  rev: 5.6.4
+  hooks:
+  - id: isort
+- repo: https://github.com/myint/docformatter
+  rev: v1.3.1
+  hooks:
+  - id: docformatter
diff --git a/HISTORY.rst b/HISTORY.rst
index 9ccf0a9..3ef8a60 100644
--- a/HISTORY.rst
+++ b/HISTORY.rst
@@ -3,6 +3,13 @@ History
 =======
 
+v0.9.6 (2020-12-22)
+-----------------------
+
+* Apply pydoc to repository.
+* Add isort and docformatter to pre-commit.
+
+
 v0.9.5 (2020-12-14)
 -----------------------
 
diff --git a/aioradio/aws/dynamodb.py b/aioradio/aws/dynamodb.py
index 4623b89..a102197 100644
--- a/aioradio/aws/dynamodb.py
+++ b/aioradio/aws/dynamodb.py
@@ -1,10 +1,8 @@
-'''Generic async AWS functions for DynamoDB.'''
+"""Generic async AWS functions for DynamoDB."""
 
 # pylint: disable=too-many-arguments
 
-from typing import Any
-from typing import Dict
-from typing import List
+from typing import Any, Dict, List
 
 from aioradio.aws.utils import AwsServiceManager
 
@@ -19,7 +17,18 @@ async def create_dynamo_table(
         attribute_definitions: List[Dict[str, str]],
         key_schema: List[Dict[str, str]],
         provisioned_throughput: Dict[str, int]) -> str:
-    '''Create dynamo table. '''
+    """Create dynamo table.
+
+    Args:
+        table_name (str): dynamo table name
+        region (str): AWS region
+        attribute_definitions (List[Dict[str, str]]): attributes describing the key schema for the table
+        key_schema (List[Dict[str, str]]): attributes that make up the primary key of a table, or the key attributes of an index
+        provisioned_throughput (Dict[str, int]): throughput (ReadCapacityUnits & WriteCapacityUnits) for the dynamo table
+
+    Returns:
+        str: error message if any
+    """
 
     error = ''
 
@@ -39,7 +48,14 @@ async def create_dynamo_table(
 
 @AWS_SERVICE.active
 async def get_list_of_dynamo_tables(region: str) -> List[str]:
-    '''Get list of Dynamo tables in a particular region.'''
+    """Get list of Dynamo tables in a particular region.
+
+    Args:
+        region (str): AWS region
+
+    Returns:
+        List[str]: list of dynamo tables
+    """
 
     tables = []
     result = await DYNAMO[region]['client']['obj'].list_tables()
@@ -50,7 +66,16 @@ async def get_list_of_dynamo_tables(region: str) -> List[str]:
 
 @AWS_SERVICE.active
 async def scan_dynamo(table_name: str, region: str, key: Any=None) -> List[Any]:
-    '''Scan dynamo table using a filter_expression if supplied.'''
+    """Scan dynamo table using a filter_expression if supplied.
+
+    Args:
+        table_name (str): dynamo table name
+        region (str): AWS region
+        key (Any, optional): filter expression to reduce items scanned. Defaults to None.
+
+    Returns:
+        List[Any]: list of scanned items
+    """
 
     result = []
     scan_kwargs = {'FilterExpression': key} if key is not None else {}
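The `key` parameter documented above takes a boto3 condition object rather than a raw string. A minimal usage sketch, assuming aioradio is installed and AWS access is configured; the table name, region, and filter are borrowed from this repo's test suite and are otherwise illustrative:

```python
import asyncio

from boto3.dynamodb.conditions import Attr

from aioradio.aws.dynamodb import scan_dynamo

async def main():
    # Scan with an optional filter expression; omit `key` to scan every item.
    items = await scan_dynamo(
        table_name='pytest',
        region='us-east-2',
        key=Attr('data.unique_id.records').between(-10, -1))
    print(items)

asyncio.get_event_loop().run_until_complete(main())
```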
@@ -68,7 +93,16 @@ async def scan_dynamo(table_name: str, region: str, key: Any=None) -> List[Any]:
 
 @AWS_SERVICE.active
 async def put_item_in_dynamo(table_name: str, region: str, item: Dict[str, Any]) -> Dict[str, Any]:
-    '''Put item in dynamo table.'''
+    """Put item in dynamo table.
+
+    Args:
+        table_name (str): dynamo table name
+        region (str): AWS region
+        item (Dict[str, Any]): items to add/modify in dynamo table
+
+    Returns:
+        Dict[str, Any]: response of operation
+    """
 
     result = {}
     table = await DYNAMO[region]['resource']['obj'].Table(table_name)
@@ -77,8 +111,17 @@ async def put_item_in_dynamo(table_name: str, region: str, item: Dict[str, Any])
 
 @AWS_SERVICE.active
-async def query_dynamo(table_name: str, region: str, key: Any):
-    '''Query dynamo for with specific key_condition_expression.'''
+async def query_dynamo(table_name: str, region: str, key: Any) -> List[Any]:
+    """Query dynamo with a specific key_condition_expression.
+
+    Args:
+        table_name (str): dynamo table name
+        region (str): AWS region
+        key (Any): KeyConditionExpression parameter to provide a specific value for the partition key
+
+    Returns:
+        List[Any]: list of queried items
+    """
 
     result = []
     query_kwargs = {'KeyConditionExpression': key}
@@ -104,7 +147,21 @@ async def update_item_in_dynamo(
         expression_attribute_values: str='',
         condition_expression: str='',
         return_values: str='UPDATED_NEW') -> Dict[str, Any]:
-    '''Update an item in Dynamo without overwriting the entire item.'''
+    """Update an item in Dynamo without overwriting the entire item.
+
+    Args:
+        table_name (str): dynamo table name
+        region (str): AWS region
+        key (Dict[str, Any]): partition key and sort key if applicable
+        update_expression (str): attributes to be updated, the action to be performed on them, and new value(s) for them
+        expression_attribute_names (str): one or more substitution tokens for attribute names in an expression
+        expression_attribute_values (str, optional): one or more values that can be substituted in an expression. Defaults to ''.
+        condition_expression (str, optional): condition that must be satisfied in order for a conditional update to succeed. Defaults to ''.
+        return_values (str, optional): items to return in response. Defaults to 'UPDATED_NEW'.
+
+    Returns:
+        Dict[str, Any]: response of operation
+    """
 
     result = {}
     update_kwargs = {
@@ -127,7 +184,16 @@ async def update_item_in_dynamo(
 
 @AWS_SERVICE.active
 async def batch_write_to_dynamo(table_name: str, region: str, items: List[Dict[str, Any]]) -> bool:
-    '''Write batch of items to dynamo table.'''
+    """Write batch of items to dynamo table.
+
+    Args:
+        table_name (str): dynamo table name
+        region (str): AWS region
+        items (List[Dict[str, Any]]): items to write to dynamo table
+
+    Returns:
+        bool: success status of writing items
+    """
 
     batch_writer_successful = False
 
@@ -146,7 +212,16 @@ async def batch_get_items_from_dynamo(
         table_name: str,
         region: str,
         items: List[Dict[str, Any]]) -> Dict[str, Any]:
-    '''Get batch of items from dynamo.'''
+    """Get batch of items from dynamo.
+ + Args: + table_name (str): dynamo table name + region (str): AWS region + items (List[Dict[str, Any]]): list of items to fetch from dynamo table + + Returns: + Dict[str, Any]: response of operation + """ response = await DYNAMO[region]['resource']['obj'].batch_get_item(RequestItems={table_name: {'Keys': items}}) diff --git a/aioradio/aws/moto_server.py b/aioradio/aws/moto_server.py index 6270939..4753a89 100644 --- a/aioradio/aws/moto_server.py +++ b/aioradio/aws/moto_server.py @@ -1,4 +1,4 @@ -'''moto server for pytesting aws services.''' +"""moto server for pytesting aws services.""" # pylint: disable=too-many-instance-attributes # pylint: disable=unused-variable @@ -6,29 +6,35 @@ import asyncio import functools import logging +import os import socket import threading import time -import os # Third Party import aiohttp import moto.server import werkzeug.serving - HOST = '127.0.0.1' _PYCHARM_HOSTED = os.environ.get('PYCHARM_HOSTED') == '1' _CONNECT_TIMEOUT = 90 if _PYCHARM_HOSTED else 10 -def get_free_tcp_port(release_socket: bool = False): - '''Get an available TCP port.''' +def get_free_tcp_port(release_socket: bool = False) -> tuple: + """Get an available TCP port. + + Args: + release_socket (bool, optional): release socket. Defaults to False. + + Returns: + tuple: socket and port + """ sckt = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sckt.bind((HOST, 0)) - addr, port = sckt.getsockname() + _, port = sckt.getsockname() if release_socket: sckt.close() return port @@ -37,8 +43,11 @@ def get_free_tcp_port(release_socket: bool = False): class MotoService: - """ Will Create MotoService. Service is ref-counted so there will only be one per process. - Real Service will be returned by `__aenter__`.""" + """Will Create MotoService. + + Service is ref-counted so there will only be one per process. Real + Service will be returned by `__aenter__`. + """ _services = dict() # {name: instance} _main_app: moto.server.DomainDispatcherApplication = None @@ -59,8 +68,12 @@ def __init__(self, service_name: str, port: int = None): self._server = None @property - def endpoint_url(self): - '''Get the server endpoint url.''' + def endpoint_url(self) -> str: + """Get the server endpoint url. + + Returns: + str: url + """ return f'http://{self._ip_address}:{self._port}' diff --git a/aioradio/aws/s3.py b/aioradio/aws/s3.py index f574533..f51ccb1 100644 --- a/aioradio/aws/s3.py +++ b/aioradio/aws/s3.py @@ -1,9 +1,7 @@ -'''Generic async AWS functions for S3.''' +"""Generic async AWS functions for S3.""" import logging -from typing import Any -from typing import Dict -from typing import List +from typing import Any, Dict, List from aioradio.aws.utils import AwsServiceManager @@ -14,14 +12,30 @@ @AWS_SERVICE.active async def create_bucket(bucket: str) -> Dict[str, str]: - '''Create an s3 bucket.''' + """Create an s3 bucket. + + Args: + bucket (str): s3 bucket + + Returns: + Dict[str, str]: response of operation + """ return await S3['client']['obj'].create_bucket(Bucket=bucket) @AWS_SERVICE.active async def upload_file(bucket: str, filepath: str, s3_key: str) -> Dict[str, Any]: - '''Upload file to s3.''' + """Upload file to s3. 
+
+    Args:
+        bucket (str): s3 bucket
+        filepath (str): local filepath to upload
+        s3_key (str): destination s3 key for uploaded file
+
+    Returns:
+        Dict[str, Any]: response of operation
+    """
 
     response = {}
     with open(filepath, 'rb') as fileobj:
@@ -31,8 +45,14 @@ async def upload_file(bucket: str, filepath: str, s3_key: str) -> Dict[str, Any]
 
 @AWS_SERVICE.active
-async def download_file(bucket: str, filepath: str, s3_key: str) -> None:
-    '''Download file to s3.'''
+async def download_file(bucket: str, filepath: str, s3_key: str):
+    """Download file from s3.
+
+    Args:
+        bucket (str): s3 bucket
+        filepath (str): local filepath for downloaded file
+        s3_key (str): s3 key to download
+    """
 
     with open(filepath, 'wb') as fileobj:
         data = await get_object(bucket=bucket, s3_key=s3_key)
@@ -41,7 +61,16 @@ async def download_file(bucket: str, filepath: str, s3_key: str) -> None:
 
 @AWS_SERVICE.active
 async def list_s3_objects(bucket: str, s3_prefix: str, with_attributes: bool=False) -> List[str]:
-    '''List objects in s3 path.'''
+    """List objects in s3 path.
+
+    Args:
+        bucket (str): s3 bucket
+        s3_prefix (str): s3 prefix
+        with_attributes (bool, optional): return all file attributes in addition to s3 keys. Defaults to False.
+
+    Returns:
+        List[str]: list of s3 keys, or of attribute dicts when with_attributes is True
+    """
 
     arr = []
     paginator = S3['client']['obj'].get_paginator('list_objects')
@@ -57,7 +86,15 @@ async def list_s3_objects(bucket: str, s3_prefix: str, with_attributes: bool=Fal
 
 @AWS_SERVICE.active
 async def get_s3_file_attributes(bucket: str, s3_key: str) -> Dict[str, Any]:
-    '''Get s3 objects metadata attributes.'''
+    """Get s3 object's metadata attributes.
+
+    Args:
+        bucket (str): s3 bucket
+        s3_key (str): s3 key
+
+    Returns:
+        Dict[str, Any]: response of operation
+    """
 
     s3_object = await S3['client']['obj'].get_object(Bucket=bucket, Key=s3_key)
     del s3_object['Body']
@@ -67,7 +104,15 @@ async def get_s3_file_attributes(bucket: str, s3_key: str) -> Dict[str, Any]:
 
 @AWS_SERVICE.active
 async def get_object(bucket: str, s3_key: str) -> bytes:
-    '''Directly download contents of s3 object.'''
+    """Directly download contents of s3 object.
+
+    Args:
+        bucket (str): s3 bucket
+        s3_key (str): s3 key
+
+    Returns:
+        bytes: s3 object content as bytes
+    """
 
     data = None
     s3_object = await S3['client']['obj'].get_object(Bucket=bucket, Key=s3_key)
@@ -79,7 +124,15 @@ async def get_object(bucket: str, s3_key: str) -> bytes:
 
 @AWS_SERVICE.active
 async def delete_s3_object(bucket: str, s3_prefix: str) -> Dict[str, Any]:
-    '''Delete object(s) from s3.'''
+    """Delete object(s) from s3.
+
+    Args:
+        bucket (str): s3 bucket
+        s3_prefix (str): s3 prefix
+
+    Returns:
+        Dict[str, Any]: response of operation
+    """
 
     response = await S3['client']['obj'].delete_object(Bucket=bucket, Key=s3_prefix)
 
diff --git a/aioradio/aws/secrets.py b/aioradio/aws/secrets.py
index cafcc18..3a573d2 100644
--- a/aioradio/aws/secrets.py
+++ b/aioradio/aws/secrets.py
@@ -1,4 +1,4 @@
-'''Generic async AWS functions for Secrets Manager.'''
+"""Generic async AWS functions for Secrets Manager."""
 
 import base64
 
@@ -10,7 +10,15 @@
 
 @AWS_SERVICE.active
 async def get_secret(secret_name: str, region: str) -> str:
-    '''Get secret from AWS Secrets Manager.'''
+    """Get secret from AWS Secrets Manager.
+ + Args: + secret_name (str): secret name + region (str): AWS region + + Returns: + str: secret value + """ secret = '' response = await SECRETS[region]['client']['obj'].get_secret_value(SecretId=secret_name) diff --git a/aioradio/aws/sqs.py b/aioradio/aws/sqs.py index bd54546..ba93929 100644 --- a/aioradio/aws/sqs.py +++ b/aioradio/aws/sqs.py @@ -1,11 +1,10 @@ -'''Generic async AWS functions for SQS.''' +"""Generic async AWS functions for SQS.""" # pylint: disable=dangerous-default-value # pylint: disable=too-many-arguments import logging -from typing import Dict -from typing import List +from typing import Any, Dict, List from botocore.exceptions import ClientError @@ -17,8 +16,17 @@ @AWS_SERVICE.active -async def create_queue(queue: str, region: str, attributes: Dict[str, str]) -> Dict[str, str]: - '''Create SQS queue in region defined.''' +async def create_queue(queue: str, region: str, attributes: Dict[str, str]) -> Dict[str, Any]: + """Create SQS queue in region defined. + + Args: + queue (str): sqs queue + region (str): AWS region + attributes (Dict[str, str]): sqs queue attributes + + Returns: + Dict[str, str]: response of operation + """ return await SQS[region]['client']['obj'].create_queue(QueueName=queue, Attributes=attributes) @@ -31,42 +39,18 @@ async def get_messages( max_messages: int=10, visibility_timeout: int=30, attribute_names: List[str]=[]) -> List[dict]: - """ - Get up to 10 messages from an SQS queue. Returns a list of dicts where each dict contains - the message information, Here is an example of a message produce from an s3 -> sqs event: - { - 'MessageId': '0050daf1-313b-4a5c-a4e7-8e5596085fa8', - 'ReceiptHandle': '', - 'MD5OfBody': 'ec3212bbe0cf0239ba54eefd206338ef', - 'Body': '{ - "Records": [{ - "eventVersion": "2.1", - "eventSource": "aws:s3", - "awsRegion": "us-east-2", - "eventTime": "2020-07-07T19:02:45.192Z", - "eventName": "ObjectCreated:CompleteMultipartUpload", - "userIdentity": {"principalId":"AWS:AIDATIJBHOZJSHFN3H2KY"}, - "requestParameters": {"sourceIPAddress":"52.249.199.100"}, - "responseElements": {"x-amz-request-id":"625BA0F414478E41", - "x-amz-id-2": ""}, - "s3": { - "s3SchemaVersion": "1.0", - "configurationId": "tf-s3-queue-20191002201742888700000006", - "bucket": { - "name": "nrccua-datalab-efi-input-sandbox.us-east-2", - "ownerIdentity": {"principalId":"A1MQ0EIGU3DVVT"}, - "arn": "arn:aws:s3:::nrccua-datalab-efi-input-sandbox.us-east-2" - }, - "object": { - "key": "XXXXXX/hello_world.txt", - "size":23, "eTag":"bccd05d8b202eba5e812bcf501c1682a-1", - "versionId": "ZLMCPDFMu6W865WWyfiVsaWZU8pJRUb3", - "sequencer":"005F04C6DA2DF81A97" - } - } - }] - }' - } + """Get up to 10 messages from an SQS queue. + + Args: + queue (str): sqs queue + region (str): AWS region + wait_time (int, optional): time to wait polling for messages. Defaults to 20. + max_messages (int, optional): max messages polled. Defaults to 10. + visibility_timeout (int, optional): timeout for when message will return to queue if not deleted. Defaults to 30. + attribute_names (List[str], optional): list of attributes for which to retrieve information. Defaults to []. + + Returns: + List[dict]: list of dicts where each dict contains the message information """ messages = [] @@ -89,31 +73,16 @@ async def send_messages( queue: str, region: str, entries: List[Dict[str, str]]) -> Dict[str, list]: - ''' - Send up to 10 messages to an SQS queue. Each dict in entries must have - the keys: Id and MessageBody populated. 
The returned data is a dict with two keys, either - Successful or Failed, for example: - { - 'Successful': [ - { - 'Id': 'string', - 'MessageId': 'string', - 'MD5OfMessageBody': 'string', - 'MD5OfMessageAttributes': 'string', - 'MD5OfMessageSystemAttributes': 'string', - 'SequenceNumber': 'string' - }, - ], - 'Failed': [ - { - 'Id': 'string', - 'SenderFault': True|False, - 'Code': 'string', - 'Message': 'string' - }, - ] - } - ''' + """Send up to 10 messages to an SQS queue. + + Args: + queue (str): sqs queue + region (str): AWS region + entries (List[Dict[str, str]]): List of dicts containing the keys: Id and MessageBody + + Returns: + Dict[str, list]: dict with two keys, either Successful or Failed + """ resp = await SQS[region]['client']['obj'].get_queue_url(QueueName=queue) queue_url = resp['QueueUrl'] @@ -127,26 +96,16 @@ async def delete_messages( queue: str, region: str, entries: List[Dict[str, str]]) -> Dict[str, list]: - ''' - Delete up to 10 messages from an SQS queue. Each dict in entries must have - the keys: Id and ReceiptHandle populated. The returned data is a dict with two keys, either - Successful or Failed, for example: - { - 'Successful': [ - { - 'Id': 'string' - }, - ], - 'Failed': [ - { - 'Id': 'string', - 'SenderFault': True|False, - 'Code': 'string', - 'Message': 'string' - }, - ] - } - ''' + """Delete up to 10 messages from an SQS queue. + + Args: + queue (str): sqs queue + region (str): AWS region + entries (List[Dict[str, str]]): List of dicts containing the keys: Id and ReceiptHandle + + Returns: + Dict[str, list]: dict with two keys, either Successful or Failed + """ resp = await SQS[region]['client']['obj'].get_queue_url(QueueName=queue) queue_url = resp['QueueUrl'] @@ -157,7 +116,15 @@ async def delete_messages( @AWS_SERVICE.active async def purge_messages(queue: str, region: str) -> str: - '''Purge messages from queue in region defined.''' + """Purge messages from queue in region defined. 
+ + Args: + queue (str): sqs queue + region (str): AWS region + + Returns: + str: error message if any + """ error = '' try: diff --git a/aioradio/aws/utils.py b/aioradio/aws/utils.py index 4f585f1..51b60cd 100644 --- a/aioradio/aws/utils.py +++ b/aioradio/aws/utils.py @@ -1,5 +1,5 @@ -'''Generic async utility functions.''' +"""Generic async utility functions.""" # pylint: disable=broad-except # pylint: disable=logging-fstring-interpolation @@ -9,10 +9,10 @@ import logging from asyncio import sleep from copy import deepcopy -from dataclasses import dataclass -from dataclasses import field +from dataclasses import dataclass, field from time import time -from typing import List +from types import coroutine +from typing import Any, Dict, List import aioboto3 import aiobotocore @@ -23,7 +23,7 @@ @dataclass class AwsServiceManager: - '''AWS Service Manager''' + """AWS Service Manager.""" service: str module: str = 'aiobotocore' @@ -32,8 +32,6 @@ class AwsServiceManager: scheduler: aiojobs._scheduler = None def __post_init__(self): - '''Post constructor.''' - services = ['s3', 'sqs', 'secretsmanager', 'dynamodb'] if self.service not in services: raise ValueError(f'service parameter must be one of the following: {services}') @@ -60,7 +58,7 @@ def __del__(self): self.scheduler.close() async def create_scheduler(self): - '''Schedule jobs.''' + """Schedule jobs.""" self.scheduler = await aiojobs.create_scheduler() if self.regions: @@ -73,10 +71,15 @@ async def create_scheduler(self): if self.module == 'aioboto3': await self.scheduler.spawn(self.aio_server(item='resource')) - async def aio_server(self, item, region=None): - '''Begin long running server establishing modules service_dict object.''' + async def aio_server(self, item: str, region: str=''): + """Begin long running server establishing modules service_dict object. + + Args: + item (str): either 'client' or 'resource' depending on the aws service and python package + region (str, optional): AWS region. Defaults to ''. + """ - service_dict = self.service_dict if region is None else self.service_dict[region] + service_dict = self.service_dict if region == '' else self.service_dict[region] await self.establish_client_resource(service_dict[item], item, region=region) while True: @@ -90,11 +93,19 @@ async def aio_server(self, item, region=None): await self.establish_client_resource(service_dict[item], item=item, region=region, reestablish=True) - async def establish_client_resource(self, service_dict, item, region=None, reestablish=False): - '''Establish the AioSession client or resource, then re-establish every self.sleep_interval seconds.''' + async def establish_client_resource(self, service_dict: Dict[str, Any], item: str, region: str='', reestablish: bool=False): + """Establish the AioSession client or resource, then re-establish every + self.sleep_interval seconds. + + Args: + service_dict (Dict[str, Any]): dict containing info about the service requested + item (str): either 'client' or 'resource' depending on the aws service and python package + region (str, optional): AWS region. Defaults to ''. + reestablish (bool, optional): should async context manager be reinstantiated. Defaults to False. 
+
+        """
 
         kwargs = {'service_name': self.service, 'verify': False}
-        if region is None:
+        if region == '':
             phrase = f're-establish {self.service}' if reestablish else f'establish {self.service}'
         else:
             kwargs['region_name'] = region
@@ -115,10 +126,18 @@ async def establish_client_resource(self, service_dict, item, region=None, reest
             service_dict['busy'] = False
             LOG.info(f'Successfully {phrase} service object!')
 
-    def get_region(self, args, kwargs):
-        '''Attempt to detect the region from the kwargs or args.'''
+    def get_region(self, args: List[Any], kwargs: Dict[str, Any]) -> str:
+        """Attempt to detect the region from the kwargs or args.
+
+        Args:
+            args (List[Any]): list of arguments
+            kwargs (Dict[str, Any]): dict of keyword arguments
+
+        Returns:
+            str: AWS region
+        """
 
-        region = None
+        region = ''
 
         if 'region' in kwargs:
             region = kwargs['region']
@@ -130,12 +149,27 @@ def get_region(self, args, kwargs):
 
         return region
 
-    def active(self, func):
-        '''Decorator to keep track of currently running functions, allowing the AioSession client
-        to only be re-establish when the count is zero to avoid functions using a stale client.'''
+    def active(self, func: coroutine) -> Any:
+        """Decorator to keep track of currently running functions, allowing the
+        AioSession client to only be re-established when the count is zero to
+        avoid functions using a stale client.
+
+        Args:
+            func (coroutine): async coroutine
+
+        Returns:
+            Any: the decorated function
+        """
+
+        async def wrapper(*args, **kwargs) -> Any:
+            """Decorator wrapper.
+
+            Raises:
+                error: exception raised during function execution
 
-        async def wrapper(*args, **kwargs):
-            '''Decorator wrapper.'''
+            Returns:
+                Any: result of the wrapped coroutine
+            """
 
             obj = self.service_dict[self.get_region(args, kwargs)] if self.regions else self.service_dict
 
@@ -143,6 +177,7 @@ async def wrapper(*args, **kwargs):
             while obj['client']['busy'] or ('resource' in obj and obj['resource']['busy']):
                 await sleep(0.01)
 
+            result = None
             error = None
             obj['active'] += 1
             try:
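The `active` decorator above is the subtle part of this module: it ref-counts in-flight calls so the shared client is only re-established once nothing is using it. A self-contained toy model of that pattern, using none of aioradio's internals:

```python
import asyncio

# Toy model of AwsServiceManager.active: every call bumps a counter, and a
# client swap (the 'busy' flag) only proceeds once nothing is in flight.
state = {'active': 0, 'busy': False}

def active(func):
    async def wrapper(*args, **kwargs):
        while state['busy']:        # wait out a client re-establish
            await asyncio.sleep(0.01)
        state['active'] += 1        # register this in-flight call
        try:
            return await func(*args, **kwargs)
        finally:
            state['active'] -= 1    # once zero, a re-establish may proceed

    return wrapper

@active
async def fetch() -> str:
    await asyncio.sleep(0.1)
    return 'done'

print(asyncio.run(fetch()))
```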
diff --git a/aioradio/file_ingestion.py b/aioradio/file_ingestion.py
index 1a88026..46e9ab0 100644
--- a/aioradio/file_ingestion.py
+++ b/aioradio/file_ingestion.py
@@ -1,4 +1,4 @@
-'''Generic functions related to working with files or the file system.'''
+"""Generic functions related to working with files or the file system."""
 
 # pylint: disable=invalid-name
 # pylint: disable=too-many-arguments
@@ -10,29 +10,37 @@
 import time
 import zipfile
 from asyncio import sleep
-from concurrent.futures import as_completed
-from concurrent.futures import ThreadPoolExecutor
-from datetime import datetime
-from datetime import timezone
-from datetime import tzinfo
+from concurrent.futures import ThreadPoolExecutor, as_completed
+from datetime import datetime, timezone, tzinfo
 from pathlib import Path
-from typing import Any
-from typing import Dict
-from typing import List
+from types import coroutine
+from typing import Any, Dict, List
 
 import mandrill
 from smb.base import SharedFile
-from smb.SMBConnection import SMBConnection
 from smb.smb_structs import OperationFailure
+from smb.SMBConnection import SMBConnection
 
 DIRECTORY = Path(__file__).parent.absolute()
 
 
-def async_wrapper(func):
-    '''Decorator to run functions using async. Found this handy to use with DAG tasks.'''
+def async_wrapper(func: coroutine) -> Any:
+    """Decorator to run functions using async. Found this handy to use with DAG
+    tasks.
+
+    Args:
+        func (coroutine): async coroutine
+
+    Returns:
+        Any: the decorated function
+    """
 
-    def wrapper(*args, **kwargs):
-        '''Decorator wrapper.'''
+    def wrapper(*args, **kwargs) -> Any:
+        """Decorator wrapper.
+
+        Returns:
+            Any: result of the wrapped coroutine
+        """
 
         loop = asyncio.get_event_loop()
         return loop.run_until_complete(func(*args, **kwargs))
@@ -40,11 +48,23 @@ def wrapper(*args, **kwargs):
 
     return wrapper
 
-def async_wrapper_using_new_loop(func):
-    '''Decorator to run functions using async. Found this handy to use with DAG tasks.'''
+def async_wrapper_using_new_loop(func: coroutine) -> Any:
+    """Decorator to run functions using async. Found this handy to use with DAG
+    tasks.
+
+    Args:
+        func (coroutine): async coroutine
+
+    Returns:
+        Any: the decorated function
+    """
 
-    def wrapper(*args, **kwargs):
-        '''Decorator wrapper.'''
+    def wrapper(*args, **kwargs) -> Any:
+        """Decorator wrapper.
+
+        Returns:
+            Any: result of the wrapped coroutine
+        """
 
         return asyncio.run(func(*args, **kwargs))
 
@@ -52,12 +72,23 @@ def wrapper(*args, **kwargs):
 
 async def async_process_manager(
-        function: asyncio.coroutine,
+        function: coroutine,
         list_of_kwargs: List[Dict[str, Any]],
         chunk_size: int,
         use_threads=True) -> List[Any]:
-    '''Process manager to run fixed number of functions, usually the same function expressed as
-    coroutines in an array. Use case is sending many http requests or iterating files.'''
+    """Process manager to run a fixed number of functions, usually the same
+    function expressed as coroutines in an array. Use case is sending many
+    http requests or iterating files.
+
+    Args:
+        function (coroutine): async coroutine
+        list_of_kwargs (List[Dict[str, Any]]): list of kwargs to pass into function
+        chunk_size (int): number of functions to run concurrently
+        use_threads (bool, optional): should threads be used. Defaults to True.
+
+    Returns:
+        List[Any]: List of function results
+    """
 
     results = []
     if use_threads:
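Since `async_process_manager` is the fan-out workhorse here, a short usage sketch may help review; the worker coroutine and its kwargs are made up for illustration:

```python
import asyncio

from aioradio.file_ingestion import async_process_manager

async def worker(msg: str) -> str:
    # stand-in for an http request or a file operation
    await asyncio.sleep(0.1)
    return msg.upper()

async def main():
    list_of_kwargs = [{'msg': f'item-{i}'} for i in range(10)]
    # run at most 3 workers concurrently per chunk
    results = await async_process_manager(worker, list_of_kwargs, chunk_size=3, use_threads=False)
    print(results)

asyncio.run(main())
```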
@@ -74,7 +105,15 @@ async def async_process_manager(
 
 async def unzip_file(filepath: str, directory: str) -> List[str]:
-    '''Unzip supplied filepath in the supplied directory returning list of filenames.'''
+    """Unzip supplied filepath in the supplied directory.
+
+    Args:
+        filepath (str): filepath to unzip
+        directory (str): directory to write unzipped files
+
+    Returns:
+        List[str]: List of filenames
+    """
 
     zipped = zipfile.ZipFile(filepath)
 
@@ -92,10 +131,19 @@ async def unzip_file_get_filepaths(
         directory: str,
         include_extensions: List[str] = None,
         exclude_extensions: List[str] = None) -> List[str]:
-    '''Get all the filepaths after unzipping supplied filepath in the supplied directory.
-    If the zipfile contains zipfiles, those files will also be unzipped. If include_extensions
-    is supplied then add those file types to the result. If exclude_extensions is supplied
-    then skip adding those filepaths to the result.'''
+    """Get all the filepaths after unzipping supplied filepath in the supplied
+    directory. If the zipfile contains zipfiles, those files will also be
+    unzipped.
+
+    Args:
+        filepath (str): filepath to unzip
+        directory (str): directory to write unzipped files
+        include_extensions (List[str], optional): list of file types to add to result, if None add all. Defaults to None.
+        exclude_extensions (List[str], optional): list of file types to exclude from result. Defaults to None.
+
+    Returns:
+        List[str]: list of filepaths
+    """
 
     paths = []
     zipfile_filepaths = [filepath]
@@ -123,10 +171,16 @@ async def unzip_file_get_filepaths(
 
     return paths
 
-async def get_current_datetime_from_timestamp(
-        dt_format: str = '%Y-%m-%d %H_%M_%S.%f',
-        time_zone: tzinfo = timezone.utc) -> str:
-    '''Get the datetime from the timestamp in the format and timezone desired.'''
+async def get_current_datetime_from_timestamp(dt_format: str = '%Y-%m-%d %H_%M_%S.%f', time_zone: tzinfo = timezone.utc) -> str:
+    """Get the datetime from the timestamp in the format and timezone desired.
+
+    Args:
+        dt_format (str, optional): date format desired. Defaults to '%Y-%m-%d %H_%M_%S.%f'.
+        time_zone (tzinfo, optional): timezone desired. Defaults to timezone.utc.
+
+    Returns:
+        str: current datetime
+    """
 
     return datetime.fromtimestamp(time.time(), time_zone).strftime(dt_format)
 
@@ -139,7 +193,19 @@ async def send_emails_via_mandrill(
         template_name: str,
         template_content: List[Dict[str, Any]] = None
 ) -> Any:
-    '''Send emails via Mailchimp mandrill API.'''
+    """Send emails via Mailchimp mandrill API.
+
+    Args:
+        mandrill_api_key (str): mandrill API key
+        emails (List[str]): recipient emails
+        subject (str): email subject
+        global_merge_vars (List[Dict[str, Any]]): List of dicts used to dynamically populate email template with data
+        template_name (str): mandrill template name
+        template_content (List[Dict[str, Any]], optional): mandrill template content. Defaults to None.
+
+    Returns:
+        Any: any
+    """
 
     message = {
         'to': [{'email': email} for email in emails],
@@ -164,7 +230,21 @@ async def establish_ftp_connection(
         port: int = 139,
         use_ntlm_v2: bool = True,
         is_direct_tcp: bool = False) -> SMBConnection:
-    '''Establish FTP connection'''
+    """Establish FTP connection.
+
+    Args:
+        user (str): ftp username
+        pwd (str): ftp password
+        name (str): connection name
+        server (str): ftp server
+        dns (str): DNS
+        port (int, optional): port. Defaults to 139.
+        use_ntlm_v2 (bool, optional): use NTLMv1 (False) or NTLMv2 (True) authentication algorithm. Defaults to True.
+        is_direct_tcp (bool, optional): use NetBIOS over TCP (False) or SMB over TCP (True) for communication. Defaults to False.
+
+    Returns:
+        SMBConnection: SMB connection object
+    """
 
     conn = SMBConnection(
         username=user,
@@ -185,7 +265,19 @@ async def list_ftp_objects(
         exclude_directories: bool = False,
         exclude_files: bool = False,
         regex_pattern: str = None) -> List[SharedFile]:
-    '''List all files and directories in an FTP directory.'''
+    """List all files and directories in an FTP directory.
+
+    Args:
+        conn (SMBConnection): SMB connection object
+        service_name (str): FTP service name
+        ftp_path (str): FTP directory path
+        exclude_directories (bool, optional): exclude directories from results. Defaults to False.
+        exclude_files (bool, optional): exclude files from results. Defaults to False.
+        regex_pattern (str, optional): regex pattern used to filter results. Defaults to None.
+
+    Returns:
+        List[SharedFile]: List of files with their attribute info
+    """
 
     results = []
     for item in conn.listPath(service_name, ftp_path):
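How the FTP helpers above compose, as a hedged sketch; every credential, server, and share name below is a placeholder, not a value from this repo:

```python
import asyncio

from aioradio.file_ingestion import establish_ftp_connection, list_ftp_objects

async def main():
    # hypothetical connection values
    conn = await establish_ftp_connection(
        user='svc-account', pwd='secret', name='example',
        server='fileserver', dns='10.0.0.5')
    try:
        # only files whose names end in .zip
        for item in await list_ftp_objects(
                conn, service_name='share', ftp_path='/inbound',
                exclude_directories=True, regex_pattern=r'.*\.zip$'):
            print(item.filename)
    finally:
        conn.close()

asyncio.get_event_loop().run_until_complete(main())
```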
@@ -200,7 +292,16 @@ async def list_ftp_objects(
 
 async def delete_ftp_file(conn: SMBConnection, service_name: str, ftp_path: str) -> bool:
-    '''Remove a file from FTP and verify deletion.'''
+    """Remove a file from FTP and verify deletion.
+
+    Args:
+        conn (SMBConnection): SMB connection object
+        service_name (str): FTP service name
+        ftp_path (str): FTP directory path
+
+    Returns:
+        bool: deletion status
+    """
 
     status = False
     conn.deleteFiles(service_name, ftp_path)
@@ -217,7 +318,17 @@ async def write_file_to_ftp(
         service_name: str,
         ftp_path: str,
         local_filepath) -> SharedFile:
-    '''Write file to FTP creating missing FTP directories if necessary.'''
+    """Write file to FTP creating missing FTP directories if necessary.
+
+    Args:
+        conn (SMBConnection): SMB connection object
+        service_name (str): FTP service name
+        ftp_path (str): FTP directory path
+        local_filepath (str): local filepath
+
+    Returns:
+        SharedFile: ftp file attribute info
+    """
 
     # steps to create missing directories
     path = ''
@@ -237,7 +348,16 @@ async def write_file_to_ftp(
 
     return await get_ftp_file_attributes(conn, service_name, ftp_path)
 
-async def get_ftp_file_attributes(conn: SMBConnection, service_name: str, ftp_path: str):
-    '''GET FTP file attributes.'''
+async def get_ftp_file_attributes(conn: SMBConnection, service_name: str, ftp_path: str) -> SharedFile:
+    """Get FTP file attributes.
+
+    Args:
+        conn (SMBConnection): SMB connection object
+        service_name (str): FTP service name
+        ftp_path (str): FTP directory path
+
+    Returns:
+        SharedFile: ftp file attribute info
+    """
 
     return conn.getAttributes(service_name=service_name, path=ftp_path)
diff --git a/aioradio/jira.py b/aioradio/jira.py
index 42c49fa..f71afb6 100644
--- a/aioradio/jira.py
+++ b/aioradio/jira.py
@@ -1,13 +1,22 @@
-'''Generic functions related to Jira.'''
+"""Generic functions related to Jira."""
 
-from typing import Any
-from typing import Dict
+from typing import Any, Dict
 
 import httpx
 
 
-async def post_jira_issue(url: str, jira_user: str, jira_token: str, payload: Dict[str, Any]):
-    '''Post payload to create jira issue.'''
+async def post_jira_issue(url: str, jira_user: str, jira_token: str, payload: Dict[str, Any]) -> Dict[str, Any]:
+    """Post payload to create jira issue.
+
+    Args:
+        url (str): url
+        jira_user (str): jira username
+        jira_token (str): jira token
+        payload (Dict[str, Any]): jira payload describing ticket info
+
+    Returns:
+        Dict[str, Any]: response of operation
+    """
 
     headers = {'Content-Type': 'application/json'}
     auth = (jira_user, jira_token)
@@ -15,9 +24,18 @@ async def post_jira_issue(url: str, jira_user: str, jira_token: str, payload: Di
     return await client.post(url=url, json=payload, auth=auth, headers=headers)
 
-async def get_jira_issue(url: str, jira_user: str, jira_token: str):
-    '''Get Jira issue using jira_link built with the expected jira_id,
-    an example: https://nrccua.atlassian.net/rest/api/2/issue/.'''
+async def get_jira_issue(url: str, jira_user: str, jira_token: str) -> Dict[str, Any]:
+    """Get Jira issue using jira_link built with the expected jira_id, an
+    example: https://nrccua.atlassian.net/rest/api/2/issue/.
+ + Args: + url (str): url + jira_user (str): jira username + jira_token (str): jira token + + Returns: + Dict[str, Any]: response of operation + """ headers = {'Content-Type': 'application/json'} auth = (jira_user, jira_token) @@ -25,8 +43,21 @@ async def get_jira_issue(url: str, jira_user: str, jira_token: str): return await client.get(url=url, auth=auth, headers=headers) -async def add_comment_to_jira(url: str, jira_user: str, jira_token: str, comment: str): - '''Add Jira comment to an existing issue.''' +async def add_comment_to_jira(url: str, jira_user: str, jira_token: str, comment: str) -> Dict[str, Any]: + """Add Jira comment to an existing issue. + + Args: + url (str): url + jira_user (str): jira username + jira_token (str): jira token + comment (str): comment to add to jira ticket + + Raises: + ValueError: problem with url + + Returns: + Dict[str, Any]: response of operation + """ if not url.endswith('comment'): msg = 'Check url value! Good example is https://nrccua.atlassian.net/rest/api/2/issue//comment' diff --git a/aioradio/logger.py b/aioradio/logger.py index 46080a8..653548d 100644 --- a/aioradio/logger.py +++ b/aioradio/logger.py @@ -1,21 +1,27 @@ -'''Generic logger logging to console or using json_log_formatter when logging in docker -for cleaner datadog logging.''' +"""Generic logger logging to console or using json_log_formatter when logging +in docker for cleaner datadog logging.""" # pylint: disable=too-few-public-methods import logging import sys from datetime import datetime -from typing import List +from typing import Any, Dict, List from pythonjsonlogger import jsonlogger class CustomJsonFormatter(jsonlogger.JsonFormatter): - '''Custom Json Formatter''' + """Custom Json Formatter.""" - def add_fields(self, log_record, record, message_dict): - '''normalize default set of fields.''' + def add_fields(self, log_record: Dict[str, Any], record: logging.LogRecord, message_dict: Dict[str, Any]): + """normalize default set of fields. + + Args: + log_record (Dict[str, Any]): dict object containing log record info + record (logging.LogRecord): contains all the info pertinent to the event being logged + message_dict (Dict[str, Any]): message dict + """ ddtags = self.get_ddtags(record, reserved=self._skip_fields) if ddtags: @@ -27,14 +33,22 @@ def add_fields(self, log_record, record, message_dict): log_record["level"] = log_record["level"].upper() if log_record.get("level") else record.levelname @staticmethod - def get_ddtags(record, reserved): - '''Add datadog tags in the format datadog expects.''' + def get_ddtags(record: logging.LogRecord, reserved: Dict[str, Any]) -> str: + """Add datadog tags in the format datadog expects. 
+ + Args: + record (logging.LogRecord): contains all the info pertinent to the event being logged + reserved (dict[str, Any]): reserved logging keys + + Returns: + str: concatenated string of k-v pairs + """ tags = {k: v for k, v in record.__dict__.items() if k not in reserved} return ','.join([f"{k}:{v}" for k, v in tags.items()]) class DatadogLogger(): - '''Custom class for JSON Formatter to include level and name.''' + """Custom class for JSON Formatter to include level and name.""" def __init__( self, @@ -52,7 +66,7 @@ def __init__( self.add_handlers() def add_handlers(self): - '''Create log handlers.''' + """Create log handlers.""" for name in self.datadog_loggers: logger = logging.getLogger(name) diff --git a/aioradio/pyodbc.py b/aioradio/pyodbc.py index 8ea55df..3f4f1be 100644 --- a/aioradio/pyodbc.py +++ b/aioradio/pyodbc.py @@ -1,11 +1,9 @@ -'''Pyodbc functions for connecting and send queries.''' +"""Pyodbc functions for connecting and send queries.""" # pylint: disable=c-extension-no-member import os -from typing import Any -from typing import List -from typing import Union +from typing import Any, List, Union import pyodbc @@ -17,8 +15,15 @@ ] -async def get_unixodbc_driver_path(paths) -> Union[str, None]: - '''Check the file system for the unixodbc driver.''' +async def get_unixodbc_driver_path(paths: List[str]) -> Union[str, None]: + """Check the file system for the unixodbc driver. + + Args: + paths (List[str]): List of filepaths + + Returns: + Union[str, None]: driver path + """ driver_path = None for path in paths: @@ -29,12 +34,22 @@ async def get_unixodbc_driver_path(paths) -> Union[str, None]: return driver_path -async def establish_pyodbc_connection( - host: str, - user: str, - pwd: str, - driver: str = None) -> pyodbc.Connection: - '''Acquire and return pyodbc.Connection object else raise FileNotFoundError.''' +async def establish_pyodbc_connection(host: str, user: str, pwd: str, driver: str = None) -> pyodbc.Connection: + """Acquire and return pyodbc.Connection object else raise + FileNotFoundError. + + Args: + host (str): hostname + user (str): username + pwd (str): password + driver (str, optional): unixodbc driver. Defaults to None. + + Raises: + FileNotFoundError: unable to locate unixodbc driver + + Returns: + pyodbc.Connection: database connection object + """ if driver is None: verified_driver = await get_unixodbc_driver_path(UNIXODBC_DRIVER_PATHS) @@ -48,8 +63,17 @@ async def establish_pyodbc_connection( f'DRIVER={verified_driver};SERVER={host};PORT=1433;UID={user};PWD={pwd};TDS_Version=8.0') -async def pyodbc_query_fetchone(conn: pyodbc.Connection, query: str) -> List[Any]: - '''Execute pyodbc query and fetchone, see https://github.com/mkleehammer/pyodbc/wiki/Cursor''' +async def pyodbc_query_fetchone(conn: pyodbc.Connection, query: str) -> Union[List[Any], None]: + """Execute pyodbc query and fetchone, see + https://github.com/mkleehammer/pyodbc/wiki/Cursor. 
+
+    Args:
+        conn (pyodbc.Connection): database connection object
+        query (str): sql query
+
+    Returns:
+        Union[List[Any], None]: list of one result
+    """
 
     cursor = conn.cursor()
     result = cursor.execute(query).fetchone()
@@ -57,8 +81,17 @@ async def pyodbc_query_fetchone(conn: pyodbc.Connection, query: str) -> List[Any
     return result
 
-async def pyodbc_query_fetchall(conn: pyodbc.Connection, query: str) -> List[Any]:
-    '''Execute pyodbc query and fetchone, see https://github.com/mkleehammer/pyodbc/wiki/Cursor'''
+async def pyodbc_query_fetchall(conn: pyodbc.Connection, query: str) -> Union[List[Any], None]:
+    """Execute pyodbc query and fetchall, see
+    https://github.com/mkleehammer/pyodbc/wiki/Cursor.
+
+    Args:
+        conn (pyodbc.Connection): database connection object
+        query (str): sql query
+
+    Returns:
+        Union[List[Any], None]: list of one to many results
+    """
 
     cursor = conn.cursor()
     result = cursor.execute(query).fetchall()
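A quick sketch of how these pyodbc helpers chain together; the host and credentials are placeholders, and unixodbc/FreeTDS must be installed as the docstrings note:

```python
import asyncio

from aioradio.pyodbc import establish_pyodbc_connection, pyodbc_query_fetchone

async def main():
    # hypothetical MSSQL host and credentials
    conn = await establish_pyodbc_connection(host='dbserver', user='sa', pwd='secret')
    row = await pyodbc_query_fetchone(conn, 'SELECT @@VERSION')
    print(row)
    conn.close()

asyncio.get_event_loop().run_until_complete(main())
```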
diff --git a/aioradio/redis.py b/aioradio/redis.py
index f4dd051..3604f56 100644
--- a/aioradio/redis.py
+++ b/aioradio/redis.py
@@ -1,19 +1,16 @@
-'''aioradio redis cache script.'''
+"""aioradio redis cache script."""
 
 # pylint: disable=c-extension-no-member
 # pylint: disable=too-many-instance-attributes
 
 import asyncio
 import hashlib
-from dataclasses import dataclass
-from dataclasses import field
-from typing import Any
-from typing import Dict
-from typing import List
+from dataclasses import dataclass, field
+from typing import Any, Dict, List
 
 import aioredis
-from fakeredis.aioredis import create_redis_pool as fake_redis_pool
 import orjson
+from fakeredis.aioredis import create_redis_pool as fake_redis_pool
 
 HASH_ALGO_MAP = {
     'SHA1': hashlib.sha1,
@@ -30,7 +27,7 @@
 
 @dataclass
 class Redis:
-    '''class dealing with aioredis functions.'''
+    """class dealing with aioredis functions."""
 
     config: Dict[str, Any] = field(default_factory=dict)
     pool: aioredis.Redis = field(init=False, repr=False)
@@ -54,9 +51,7 @@ class Redis:
     # used exclusively for pytest
     fakeredis: bool = False
 
-    def __post_init__(self) -> None:
-        '''Post constructor'''
-
+    def __post_init__(self):
         if self.fakeredis:
             self.pool = asyncio.get_event_loop().run_until_complete(fake_redis_pool())
         else:
@@ -69,13 +64,19 @@ def __post_init__(self) -> None:
             self.pool = loop.run_until_complete(
                 aioredis.create_redis_pool(primary_endpoint, minsize=self.pool_minsize, maxsize=self.pool_maxsize))
 
-    def __del__(self) -> None:
-        '''Teardown function'''
-
+    def __del__(self):
         self.pool.close()
 
-    async def get_one_item(self, cache_key: str, use_json: bool=None) -> str:
-        '''Check if an item is cached in redis.'''
+    async def get_one_item(self, cache_key: str, use_json: bool=None) -> Any:
+        """Check if an item is cached in redis.
+
+        Args:
+            cache_key (str): redis cache key
+            use_json (bool, optional): convert json value to object. Defaults to None.
+
+        Returns:
+            Any: cached value
+        """
 
         if use_json is None:
             use_json = self.use_json
@@ -87,8 +88,16 @@ async def get_one_item(self, cache_key: str, use_json: bool=None) -> str:
 
         return value
 
-    async def get_many_items(self, items: List[str], use_json: bool=None) -> List[str]:
-        '''Check if many items are cached in redis.'''
+    async def get_many_items(self, items: List[str], use_json: bool=None) -> List[Any]:
+        """Check if many items are cached in redis.
+
+        Args:
+            items (List[str]): list of redis cache keys
+            use_json (bool, optional): convert json values to objects. Defaults to None.
+
+        Returns:
+            List[Any]: list of objects
+        """
 
         if use_json is None:
             use_json = self.use_json
@@ -100,8 +109,15 @@ async def get_many_items(self, items: List[str], use_json: bool=None) -> List[st
 
         return values
 
-    async def set_one_item(self, cache_key: str, cache_value: str, expire: int=None, use_json: bool=None) -> None:
-        '''Set one key-value pair in redis.'''
+    async def set_one_item(self, cache_key: str, cache_value: str, expire: int=None, use_json: bool=None):
+        """Set one key-value pair in redis.
+
+        Args:
+            cache_key (str): redis cache key
+            cache_value (str): redis cache value
+            expire (int, optional): cache expiration. Defaults to None.
+            use_json (bool, optional): set object to json before writing to cache. Defaults to None.
+        """
 
         if expire is None:
             expire = self.expire
@@ -114,15 +130,31 @@ async def set_one_item(self, cache_key: str, cache_value: str, expire: int=None,
 
         await self.pool.set(cache_key, cache_value, expire=expire)
 
-    async def delete_one_item(self, cache_key: str) -> None:
-        '''Delete key from redis.'''
+    async def delete_one_item(self, cache_key: str) -> int:
+        """Delete key from redis.
+
+        Args:
+            cache_key (str): redis cache key
+
+        Returns:
+            int: 1 if key is found and deleted else 0
+        """
 
         return await self.pool.delete(cache_key)
 
     async def build_cache_key(self, payload: Dict[str, Any], separator='|', use_hashkey: bool=None) -> str:
-        '''If you'd like to build a cache key from a dictionary object this is the function for you.
-        This funciton will concatenate and normalize key-values from an unnested dict, taking
-        care of sorting the keys and each of their values (if a list).'''
+        """Build a cache key from a dictionary object. Concatenate and
+        normalize key-values from an unnested dict, taking care of sorting the
+        keys and each of their values (if a list).
+
+        Args:
+            payload (Dict[str, Any]): dict object to use to build cache key
+            separator (str, optional): character to use as a separator in the cache key. Defaults to '|'.
+            use_hashkey (bool, optional): use a hashkey for the cache key. Defaults to None.
+
+        Returns:
+            str: cache key
+        """
 
         if use_hashkey is None:
             use_hashkey = self.use_hashkey
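For an end-to-end feel of the Redis wrapper above: `fakeredis=True` mirrors how this repo's conftest builds its cache fixture, so the sketch runs without a live cluster:

```python
import asyncio

from aioradio.redis import Redis

cache = Redis(fakeredis=True)  # constructor creates the pool on the current loop

async def main():
    key = await cache.build_cache_key({'tool': 'pytest', 'version': 'python3'})
    await cache.set_one_item(cache_key=key, cache_value={'hello': 'world'})
    print(await cache.get_one_item(key))

asyncio.get_event_loop().run_until_complete(main())
```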
diff --git a/aioradio/requirements.txt b/aioradio/requirements.txt
index f80840e..ecf9af8 100644
--- a/aioradio/requirements.txt
+++ b/aioradio/requirements.txt
@@ -2,7 +2,7 @@ aioboto3==8.2.0
 aiobotocore==1.1.2
 aiojobs==0.3.0
 aioredis==1.3.1
-ddtrace==0.44.0
+ddtrace==0.45.0
 fakeredis==1.4.5
 flask==1.1.2
 httpx==0.16.1
diff --git a/aioradio/tests/aws_secrets_test.py b/aioradio/tests/aws_secrets_test.py
index f5f3417..116eb68 100644
--- a/aioradio/tests/aws_secrets_test.py
+++ b/aioradio/tests/aws_secrets_test.py
@@ -1,4 +1,4 @@
-'''pytest secrets'''
+"""pytest secrets."""
 
 # pylint: disable=unused-argument
 
@@ -10,7 +10,7 @@
 
 async def test_secrets_get_secret(create_secret):
-    '''Test getting secret from Secrets Manager.'''
+    """Test getting secret from Secrets Manager."""
 
     secret = await get_secret(secret_name='test-secret', region='us-east-2')
     assert secret == 'abc123'
@@ -18,6 +18,7 @@ async def test_secrets_get_secret(create_secret):
 
 @pytest.mark.xfail
 async def test_secrets_get_secret_with_bad_key():
-    '''Test exception raised when using a bad key retrieving from Secrets Manager.'''
+    """Test exception raised when using a bad key retrieving from Secrets
+    Manager."""
 
     await get_secret(secret_name='Pytest-Bad-Key', region='us-east-2')
diff --git a/aioradio/tests/dynamodb_test.py b/aioradio/tests/dynamodb_test.py
index 427b4c7..c578118 100644
--- a/aioradio/tests/dynamodb_test.py
+++ b/aioradio/tests/dynamodb_test.py
@@ -1,20 +1,17 @@
-'''pytest dynamodb'''
+"""pytest dynamodb."""
 
-from random import randint
 from decimal import Decimal
+from random import randint
 from uuid import uuid4
 
 import pytest
-from boto3.dynamodb.conditions import Attr
-from boto3.dynamodb.conditions import Key
+from boto3.dynamodb.conditions import Attr, Key
 
-from aioradio.aws.dynamodb import batch_get_items_from_dynamo
-from aioradio.aws.dynamodb import batch_write_to_dynamo
-from aioradio.aws.dynamodb import get_list_of_dynamo_tables
-from aioradio.aws.dynamodb import put_item_in_dynamo
-from aioradio.aws.dynamodb import query_dynamo
-from aioradio.aws.dynamodb import scan_dynamo
-from aioradio.aws.dynamodb import update_item_in_dynamo
+from aioradio.aws.dynamodb import (batch_get_items_from_dynamo,
+                                   batch_write_to_dynamo,
+                                   get_list_of_dynamo_tables,
+                                   put_item_in_dynamo, query_dynamo,
+                                   scan_dynamo, update_item_in_dynamo)
 
 # ****************************************
 # DO NOT CHANGE THE DB_TABLE OR REGION
@@ -26,21 +23,22 @@
 
 async def test_dynamodb_create_table(create_table):
-    '''Test creating a DynamoDB table.'''
+    """Test creating a DynamoDB table."""
 
     result = await create_table(table_name=DB_TABLE)
     assert result == DB_TABLE
 
 async def test_dynamodb_get_list_of_tables():
-    '''Test getting list of DynamoDB tables and our created table is in the list.'''
+    """Test getting list of DynamoDB tables and that our created table is in
+    the list."""
 
     assert DB_TABLE in await get_list_of_dynamo_tables(region=REGION)
 
 @pytest.mark.parametrize('fice', ['XXXXXX', '012345', '999999'])
 async def test_dynamo_put_item(fice):
-    '''Test writing an item to DynamoDB table.'''
+    """Test writing an item to DynamoDB table."""
 
     item = {
         'fice': fice,
@@ -59,7 +57,7 @@ async def test_dynamo_put_item(fice):
 
 async def
test_dynamo_update_item(): async def test_dynamo_write_batch(): - '''Test writing a batch of items to DynamoDB.''' + """Test writing a batch of items to DynamoDB.""" items = [] for fice in ['000000', '911911', '102779']: @@ -103,7 +101,7 @@ async def test_dynamo_write_batch(): async def test_batch_get_items_from_dynamo(): - '''Test getting batch of items from dynamo.''' + """Test getting batch of items from dynamo.""" items = [{'fice': fice} for fice in ['000000', '911911', '102779']] results = await batch_get_items_from_dynamo(table_name=DB_TABLE, region=REGION, items=items) @@ -113,7 +111,7 @@ async def test_batch_get_items_from_dynamo(): async def test_dynamo_query(): - '''Test querying data from dynamoDB.''' + """Test querying data from dynamoDB.""" key_condition_expression = Key('fice').eq('XXXXXX') result = await query_dynamo(table_name=DB_TABLE, region=REGION, key=key_condition_expression) @@ -121,7 +119,7 @@ async def test_dynamo_query(): async def test_dynamo_scan_table(): - '''Test scanning DynamoDB table.''' + """Test scanning DynamoDB table.""" filter_expression = Attr('data.unique_id.records').between(-10, -1) result = await scan_dynamo(table_name=DB_TABLE, region=REGION, key=filter_expression) diff --git a/aioradio/tests/file_ingestion_test.py b/aioradio/tests/file_ingestion_test.py index 59ec3e2..a747006 100644 --- a/aioradio/tests/file_ingestion_test.py +++ b/aioradio/tests/file_ingestion_test.py @@ -1,4 +1,4 @@ -'''pytest file_ingestion script''' +"""pytest file_ingestion script.""" # pylint: disable=broad-except # pylint: disable=c-extension-no-member @@ -7,18 +7,16 @@ import logging import os import time -from datetime import timedelta -from datetime import timezone +from datetime import timedelta, timezone import pytest -from aioradio.file_ingestion import delete_ftp_file -from aioradio.file_ingestion import establish_ftp_connection -from aioradio.file_ingestion import get_current_datetime_from_timestamp -from aioradio.file_ingestion import list_ftp_objects -from aioradio.file_ingestion import send_emails_via_mandrill -from aioradio.file_ingestion import unzip_file_get_filepaths -from aioradio.file_ingestion import write_file_to_ftp +from aioradio.file_ingestion import (delete_ftp_file, establish_ftp_connection, + get_current_datetime_from_timestamp, + list_ftp_objects, + send_emails_via_mandrill, + unzip_file_get_filepaths, + write_file_to_ftp) LOG = logging.getLogger(__name__) pytestmark = pytest.mark.asyncio @@ -33,7 +31,7 @@ async def test_unzip_file_get_filepaths(request, tmpdir_factory): - '''Test unzip_file_get_filepaths.''' + """Test unzip_file_get_filepaths.""" filepath = os.path.join(request.fspath.dirname, 'test_data', 'test_file_ingestion.zip') temp_directory = str(tmpdir_factory.mktemp("data")) @@ -52,7 +50,7 @@ async def test_unzip_file_get_filepaths(request, tmpdir_factory): async def test_get_current_datetime_from_timestamp(): - '''Test get_current_datetime_from_timestamp.''' + """Test get_current_datetime_from_timestamp.""" datetime_utc = await get_current_datetime_from_timestamp() assert len(datetime_utc) == 26 @@ -67,7 +65,7 @@ async def test_get_current_datetime_from_timestamp(): async def test_send_emails_via_mandrill(): - '''Test send_emails_via_mandrill.''' + """Test send_emails_via_mandrill.""" pytest.skip("Skip sending emails via mandrill.") @@ -89,7 +87,7 @@ async def test_send_emails_via_mandrill(): async def test_write_file_to_ftp(github_action, request): - '''Test write_file_to_ftp.''' + """Test write_file_to_ftp.""" if github_action: 
pytest.skip('Skip test_write_file_to_ftp when running via Github Action') @@ -119,7 +117,7 @@ async def test_write_file_to_ftp(github_action, request): async def test_list_ftp_objects(github_action): - '''Test test_list_ftp_objects.''' + """Test test_list_ftp_objects.""" if github_action: pytest.skip('Skip test_list_ftp_objects when running via Github Action') @@ -142,7 +140,7 @@ async def test_list_ftp_objects(github_action): async def test_list_ftp_objects_with_regex(github_action): - '''Test test_list_ftp_objects with regex.''' + """Test test_list_ftp_objects with regex.""" if github_action: pytest.skip('Skip test_list_ftp_objects_with_regex when running via Github Action') @@ -167,7 +165,7 @@ async def test_list_ftp_objects_with_regex(github_action): async def test_delete_ftp_file(github_action): - '''Test delete_ftp_file.''' + """Test delete_ftp_file.""" if github_action: pytest.skip('Skip test_list_ftp_objects_with_regex when running via Github Action') diff --git a/aioradio/tests/jira_test.py b/aioradio/tests/jira_test.py index 0073125..82ea030 100644 --- a/aioradio/tests/jira_test.py +++ b/aioradio/tests/jira_test.py @@ -1,6 +1,4 @@ -''' -pytest Jira -''' +"""pytest Jira.""" # pylint: disable=c-extension-no-member @@ -8,9 +6,7 @@ import pytest -from aioradio.jira import add_comment_to_jira -from aioradio.jira import get_jira_issue -from aioradio.jira import post_jira_issue +from aioradio.jira import add_comment_to_jira, get_jira_issue, post_jira_issue pytestmark = pytest.mark.asyncio @@ -20,7 +16,7 @@ async def test_post_jira_issue(): - '''Test posting Jira issue.''' + """Test posting Jira issue.""" pytest.skip("Skip Jira ticket creation as we don't want to create many pointless tickets.") @@ -44,7 +40,7 @@ async def test_post_jira_issue(): async def test_get_jira_issue(): - '''Test getting Jira issue.''' + """Test getting Jira issue.""" pytest.skip("Skip get jira issue.") @@ -54,7 +50,7 @@ async def test_get_jira_issue(): async def test_adding_jira_comment_to_issue(): - '''Test adding a comment to a Jira issue.''' + """Test adding a comment to a Jira issue.""" pytest.skip("Skip adding comments to jira ticket.") diff --git a/aioradio/tests/logger_test.py b/aioradio/tests/logger_test.py index 8c0dd57..c99db98 100644 --- a/aioradio/tests/logger_test.py +++ b/aioradio/tests/logger_test.py @@ -1,4 +1,4 @@ -'''pytest logger''' +"""pytest logger.""" import logging @@ -10,7 +10,7 @@ async def test_datadog_logger(): - '''Check if the logger has json formatted messages.''' + """Check if the logger has json formatted messages.""" logger = DatadogLogger( main_logger='pytest2', diff --git a/aioradio/tests/pyodbc_test.py b/aioradio/tests/pyodbc_test.py index 609efed..e9b22d1 100644 --- a/aioradio/tests/pyodbc_test.py +++ b/aioradio/tests/pyodbc_test.py @@ -1,11 +1,11 @@ -'''pytest file_ingestion script''' +"""pytest file_ingestion script.""" import os + import pytest -from aioradio.pyodbc import establish_pyodbc_connection -from aioradio.pyodbc import pyodbc_query_fetchone -from aioradio.pyodbc import pyodbc_query_fetchall +from aioradio.pyodbc import (establish_pyodbc_connection, + pyodbc_query_fetchall, pyodbc_query_fetchone) pytestmark = pytest.mark.asyncio @@ -14,7 +14,7 @@ @pytest.mark.xfail async def test_bad_unixodbc_driver(github_action): - '''Test using a bad unixodbc_driver that the proper exception is raised.''' + """Test using a bad unixodbc_driver that the proper exception is raised.""" if github_action: pytest.skip('Skip test_bad_unixodbc_driver when running via Github Action') @@ 
-24,8 +24,11 @@ async def test_bad_unixodbc_driver(github_action): async def test_pyodbc_query_fetchone_and_fetchall(github_action): - '''Test pyodbc_query_fetchone. Make sure you have unixodbc and freetds installed; - see here: http://www.christophers.tips/pages/pyodbc_mac.html.''' + """Test pyodbc_query_fetchone. + + Make sure you have unixodbc and freetds installed; + see here: http://www.christophers.tips/pages/pyodbc_mac.html. + """ if github_action: pytest.skip('Skip test_pyodbc_query_fetchone_and_fetchall when running via Github Action') diff --git a/aioradio/tests/redis_test.py b/aioradio/tests/redis_test.py index 3a700ae..c5a23aa 100644 --- a/aioradio/tests/redis_test.py +++ b/aioradio/tests/redis_test.py @@ -1,4 +1,4 @@ -'''pytest redis cache''' +"""pytest redis cache.""" # pylint: disable=c-extension-no-member @@ -10,14 +10,14 @@ async def test_build_cache_key(payload, cache): - '''Test health check.''' + """Test health check.""" key = await cache.build_cache_key(payload) assert key == 'opinion=[redis,rocks]|tool=pytest|version=python3' async def test_set_one_item(payload, cache): - '''Test set_one_item.''' + """Test set_one_item.""" key = await cache.build_cache_key(payload) await cache.set_one_item(cache_key=key, cache_value={'name': ['tim', 'walter', 'bryan'], 'app': 'aioradio'}) @@ -30,7 +30,7 @@ async def test_set_one_item(payload, cache): assert result == 1 async def test_set_one_item_with_hashed_key(payload, cache): - '''Test set_one_item.''' + """Test set_one_item.""" key = await cache.build_cache_key(payload, use_hashkey=True) assert key == 'bdeb95a5154f7151eecaeadbcea52ed43d80d7338192322a53ef88a50ec7e94a' @@ -46,7 +46,7 @@ async def test_set_one_item_with_hashed_key(payload, cache): async def test_get_many_items(cache): - '''Test get_many_items.''' + """Test get_many_items.""" await cache.set_one_item(cache_key='pytest-1', cache_value='one') await cache.set_one_item(cache_key='pytest-2', cache_value='two') diff --git a/aioradio/tests/s3_test.py b/aioradio/tests/s3_test.py index a2f1ff7..e5a4969 100644 --- a/aioradio/tests/s3_test.py +++ b/aioradio/tests/s3_test.py @@ -1,4 +1,4 @@ -'''pytest s3''' +"""pytest s3.""" # pylint: disable=logging-fstring-interpolation @@ -6,13 +6,9 @@ import pytest -from aioradio.aws.s3 import delete_s3_object -from aioradio.aws.s3 import download_file -from aioradio.aws.s3 import get_object -from aioradio.aws.s3 import get_s3_file_attributes -from aioradio.aws.s3 import list_s3_objects -from aioradio.aws.s3 import upload_file - +from aioradio.aws.s3 import (delete_s3_object, download_file, get_object, + get_s3_file_attributes, list_s3_objects, + upload_file) LOG = logging.getLogger(__name__) @@ -24,14 +20,17 @@ async def test_s3_creating_bucket(create_bucket): - '''Create the mock S3 bucket.''' + """Create the mock S3 bucket.""" result = await create_bucket(region_name='us-east-1', bucket_name=S3_BUCKET) assert result == S3_BUCKET async def test_s3_upload_file(tmpdir_factory): - '''Test uploading file to s3. In addition will test deleting a file and listing objects.''' + """Test uploading file to s3. + + In addition will test deleting a file and listing objects. 
+ """ filename = 'hello_world.txt' path = str(tmpdir_factory.mktemp('upload').join(filename)) @@ -53,7 +52,7 @@ async def test_s3_upload_file(tmpdir_factory): async def test_s3_download_file(tmpdir_factory): - '''Test uploading file to s3.''' + """Test uploading file to s3.""" filename = 'hello_world.txt' path = str(tmpdir_factory.mktemp('download').join(filename)) @@ -66,14 +65,14 @@ async def test_s3_download_file(tmpdir_factory): async def test_get_object(): - '''Test get_object from s3.''' + """Test get_object from s3.""" result = await get_object(bucket=S3_BUCKET, s3_key=f'{S3_PREFIX}/hello_world.txt') assert result is not None async def test_get_file_attributes(): - '''Test retrieving s3 object attributes.''' + """Test retrieving s3 object attributes.""" result = await get_s3_file_attributes(bucket=S3_BUCKET, s3_key=f'{S3_PREFIX}/hello_world.txt') assert result['ContentLength'] == 22 diff --git a/aioradio/tests/sqs_test.py b/aioradio/tests/sqs_test.py index 9a4b062..8f2c3ab 100644 --- a/aioradio/tests/sqs_test.py +++ b/aioradio/tests/sqs_test.py @@ -1,4 +1,4 @@ -'''pytest sqs''' +"""pytest sqs.""" # pylint: disable=c-extension-no-member @@ -8,10 +8,8 @@ import orjson import pytest -from aioradio.aws.sqs import delete_messages -from aioradio.aws.sqs import get_messages -from aioradio.aws.sqs import purge_messages -from aioradio.aws.sqs import send_messages +from aioradio.aws.sqs import (delete_messages, get_messages, purge_messages, + send_messages) QUEUE = 'pytest' REGION = 'us-east-2' @@ -21,7 +19,7 @@ async def test_sqs_creating_queue(sqs_queue_url): - '''Create mock SQS queue.''' + """Create mock SQS queue.""" queue_url = await sqs_queue_url(region_name=REGION, queue_name=QUEUE) assert queue_url @@ -29,13 +27,13 @@ async def test_sqs_creating_queue(sqs_queue_url): @pytest.mark.xfail async def test_sqs_non_existing_queue(): - '''Test purging all messages from SQS queue that does not exist.''' + """Test purging all messages from SQS queue that does not exist.""" await purge_messages(queue='this-queue-does-not-exist', region=REGION) async def test_sqs_send_messages(): - '''Test sending a batch of messages to an SQS queue.''' + """Test sending a batch of messages to an SQS queue.""" entries = [ {'Id': str(uuid4()), 'MessageBody': orjson.dumps({'data': 'Hello Austin!'}).decode()}, @@ -48,7 +46,7 @@ async def test_sqs_send_messages(): async def test_sqs_get_messages(): - '''Test receiving a batch of messages from an SQS queue.''' + """Test receiving a batch of messages from an SQS queue.""" msgs = await get_messages(queue=QUEUE, region=REGION) assert len(msgs) > 0 @@ -61,7 +59,7 @@ async def test_sqs_get_messages(): async def test_sqs_delete_messages(): - '''Test successful deletion of a batch of SQS queue messages.''' + """Test successful deletion of a batch of SQS queue messages.""" entries = [{'Id': str(uuid4()), 'ReceiptHandle': i} for i in RECEIPT_HANDLES] result = await delete_messages(queue=QUEUE, region=REGION, entries=entries) @@ -69,7 +67,7 @@ async def test_sqs_delete_messages(): async def test_sqs_purge_messages(): - '''Test purging all messages from SQS queue.''' + """Test purging all messages from SQS queue.""" # Iterate twice to exercise the err on issuing PurgeQueue within 60 seconds of previous call for _ in range(2): diff --git a/conftest.py b/conftest.py index d24611a..5070acc 100644 --- a/conftest.py +++ b/conftest.py @@ -1,4 +1,4 @@ -'''pytest configuration.''' +"""pytest configuration.""" import asyncio from itertools import chain @@ -9,16 +9,16 @@ from 
diff --git a/conftest.py b/conftest.py
index d24611a..5070acc 100644
--- a/conftest.py
+++ b/conftest.py
@@ -1,4 +1,4 @@
-'''pytest configuration.'''
+"""pytest configuration."""
 
 import asyncio
 from itertools import chain
@@ -9,16 +9,16 @@
 from aiobotocore.config import AioConfig
 
 from aioradio.aws.dynamodb import DYNAMO
+from aioradio.aws.moto_server import MotoService
 from aioradio.aws.s3 import S3
-from aioradio.aws.sqs import SQS
 from aioradio.aws.secrets import SECRETS
+from aioradio.aws.sqs import SQS
 from aioradio.redis import Redis
-from aioradio.aws.moto_server import MotoService
 
 
 @pytest.fixture(scope='session')
 def event_loop():
-    '''Redefine event_loop with scope set to session instead of function.'''
+    """Redefine event_loop with scope set to session instead of function."""
 
     loop = asyncio.get_event_loop()
     yield loop
@@ -27,7 +27,7 @@ def event_loop():
 
 @pytest.fixture(scope='module')
 def payload():
-    '''Test payload to reuse.'''
+    """Test payload to reuse."""
 
     return {
         'tool': 'pytest',
@@ -39,7 +39,7 @@
 
 @pytest.fixture(scope='module')
 def cache(github_action):
-    '''Redefine event_loop with scope set to session instead of function.'''
+    """Redis cache fixture reused across a test module."""
 
     if github_action:
         pytest.skip('Skip test_set_one_item when running via Github Action')
@@ -49,7 +49,8 @@
 
 def pytest_addoption(parser):
-    '''Command line argument --cleanse=false can be used to turn off address cleansing.'''
+    """Command line argument --github=true can be used to signal tests are
+    running via Github Action."""
 
     parser.addoption(
         '--github', action='store', default='false', help='pytest running from github action')
@@ -57,7 +58,7 @@
 
 @pytest.fixture(scope='session')
 def github_action(pytestconfig):
-    '''Return True/False depending on the --cleanse command line argument.'''
+    """Return True/False depending on the --github command line argument."""
 
     return pytestconfig.getoption("github").lower() == "true"
@@ -118,10 +119,8 @@ async def s3_server():
 async def create_bucket(s3_client):
     _bucket_name = None
 
-    async def _f(region_name, bucket_name=None):
+    async def _f(region_name, bucket_name):
         nonlocal _bucket_name
-        if bucket_name is None:
-            bucket_name = random_bucketname()
         _bucket_name = bucket_name
         bucket_kwargs = {'Bucket': bucket_name}
         if region_name != 'us-east-1':
@@ -184,10 +183,8 @@ async def sqs_client(session, region, sqs_config, sqs_server):
 async def sqs_queue_url(sqs_client):
     _queue_name = None
 
-    async def _f(region_name, queue_name=None):
+    async def _f(region_name, queue_name):
         nonlocal _queue_name
-        if queue_name is None:
-            queue_name = random_name()
         _queue_name = queue_name
         response = await sqs_client.create_queue(QueueName=queue_name)
         queue_url = response['QueueUrl']
@@ -284,10 +281,8 @@ async def _is_table_ready(table_name):
         response = await dynamodb_client.describe_table(TableName=table_name)
         return response['Table']['TableStatus'] == 'ACTIVE'
 
-    async def _f(table_name=None):
+    async def _f(table_name):
         nonlocal _table_name
-        if table_name is None:
-            table_name = random_tablename()
         _table_name = table_name
         table_kwargs = {
             'TableName': table_name,
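One behavioral change worth noting in the conftest.py hunks above: the bucket, queue, and table fixture factories no longer fall back to random names when none is supplied, so every caller must now pass the resource name explicitly. The tests earlier in this diff already do, for example:

    result = await create_bucket(region_name='us-east-1', bucket_name=S3_BUCKET)
    queue_url = await sqs_queue_url(region_name=REGION, queue_name=QUEUE)
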
diff --git a/setup.py b/setup.py
index 58e5d44..6bc1089 100644
--- a/setup.py
+++ b/setup.py
@@ -1,4 +1,5 @@
-'''Python utility for NRCCUA common generic functions to reuse across projects.'''
+"""Python utility for NRCCUA common generic functions to reuse across
+projects."""
 
 from setuptools import setup
@@ -6,7 +7,7 @@
     long_description = fileobj.read()
 
 setup(name='aioradio',
-      version='0.9.5',
+      version='0.9.6',
       description='Generic asynchronous i/o python utilities for AWS services (SQS, S3, DynamoDB, Secrets Manager), Redis, MSSQL (pyodbc), JIRA and more',
       long_description=long_description,
       long_description_content_type="text/markdown",