diff --git a/src/e3/aws/dynamodb/__init__.py b/src/e3/aws/dynamodb/__init__.py index 931d562..def481c 100644 --- a/src/e3/aws/dynamodb/__init__.py +++ b/src/e3/aws/dynamodb/__init__.py @@ -2,6 +2,7 @@ from typing import TYPE_CHECKING import logging import re +import time from botocore.exceptions import ClientError if TYPE_CHECKING: @@ -92,7 +93,7 @@ def get_item( :return: retrieved item """ table = self.client.Table(table_name) - logger.info(f"Retrievieng item {item} from {table_name}...") + logger.info(f"Retrieving item {item} from {table_name}...") try: response = table.get_item( Key={key: item[key] for key in keys if key in item.keys()} @@ -104,6 +105,66 @@ def get_item( else: return response.get("Item", {}) + def batch_get_items( + self, items: list[dict[str, Any]], table_name: str, keys: list[str] + ) -> list[dict[str, Any]]: + """Retrieve multiple items from a table. + + When Amazon DynamoDB cannot process all items in a batch, a set of unprocessed + keys is returned. This function uses an exponential backoff algorithm to retry + getting the unprocessed keys until all are retrieved or the specified + number of tries is reached. + + :param items: items we want to retrieve + :param table_name: table containing the items + :param keys: the primary keys of the items + :return: retrieved item + """ + logger.info(f"Retrieving items {items} from {table_name}...") + res = [] + + tries = 0 + max_tries = 5 + sleepy_time = 1 # Start with 1 second of sleep, then exponentially increase. + batch_keys = { + table_name: { + "Keys": [ + {key: item[key] for key in keys if key in item.keys()} + for item in items + ], + "ConsistentRead": True, + } + } + print(batch_keys) + while tries < max_tries: + try: + response = self.client.batch_get_item( + RequestItems=batch_keys, + ) + res.extend(response.get("Responses", {table_name: []})[table_name]) + logger.debug(f"Get_item response: {response}") + unprocessed = response["UnprocessedKeys"] + if len(unprocessed) > 0: + batch_keys = unprocessed + unprocessed_count = sum( + [len(batch_key["Keys"]) for batch_key in batch_keys.values()] # type: ignore + ) + logger.info( + "%s unprocessed keys returned. Sleep, then retry.", + unprocessed_count, + ) + tries += 1 + if tries < max_tries: + logger.info("Sleeping for %s seconds.", sleepy_time) + time.sleep(sleepy_time) + sleepy_time = min(sleepy_time * 2, 32) + else: + break + except ClientError as e: + logger.error(e) + return [] + return res + def update_item( self, item: dict[str, Any], diff --git a/tests/coverage/base.rc b/tests/coverage/base.rc index 0bc0d2d..4f84be9 100644 --- a/tests/coverage/base.rc +++ b/tests/coverage/base.rc @@ -12,6 +12,11 @@ exclude_lines = # + -only and : no cover # + py2-only or py3-only if TYPE_CHECKING: + # testing this option is hard since it would require a table with + # more that 16MB of data + if len(unprocessed) > 0: + if tries < max_tries: + [html] diff --git a/tests/tests_e3_aws/dynamodb/main_test.py b/tests/tests_e3_aws/dynamodb/main_test.py index a4b3128..db58c3e 100644 --- a/tests/tests_e3_aws/dynamodb/main_test.py +++ b/tests/tests_e3_aws/dynamodb/main_test.py @@ -7,8 +7,9 @@ from e3.aws.dynamodb import DynamoDB if TYPE_CHECKING: - from typing import Any + from typing import Any, Generator from collections.abc import Iterable + from pytest import LogCaptureFixture TABLE_NAME = "customer" PRIMARY_KEYS = ["name"] @@ -83,6 +84,37 @@ def test_get_item_missing(client: DynamoDB) -> None: ) +def test_batch_get_item(client: DynamoDB) -> None: + """Test getting an item that doesn't exist.""" + assert client.batch_get_items( + items=[{"name": "Doe"}, {"name": "Dupont"}], + table_name=TABLE_NAME, + keys=PRIMARY_KEYS, + ) == [{"age": 23, "name": "Doe"}] + + +def test_batch_get_item_error( + client: DynamoDB, caplog: Generator[LogCaptureFixture, Any, Any] +) -> None: + """Test getting an item that doesn't exist.""" + items = client.batch_get_items( + items=[{"name": "Dupont"}], table_name="Fake_Table", keys=PRIMARY_KEYS + ) + + messages = [] + + # capture logs and ensure that are what we expect + messages.extend([x.message for x in caplog.get_records("call")]) # type: ignore + + assert len(messages) == 2 + assert ( + "An error occurred (ResourceNotFoundException) when " + "calling the BatchGetItem operation: Requested resource not found" in messages + ) + + assert items == [] + + def test_update_item(client: DynamoDB) -> None: """Test updating an item.""" customers = [dict(customer) for customer in CUSTOMERS]