Create batch_get method
Create a batch_get_items method to retrieve multiple items from a DynamoDB table
RomaricKanyamibwa committed Sep 21, 2023
1 parent 6fe7fc8 commit 779b677
Showing 3 changed files with 100 additions and 2 deletions.
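For orientation before the diff: the new batch_get_items method builds a RequestItems payload for boto3's batch_get_item call and retries any UnprocessedKeys with exponential backoff. A minimal sketch of the payload shape it assembles, using the customer table and name key from the tests in this commit (the values are illustrative only):

# Illustrative only: the shape of the RequestItems payload that
# batch_get_items builds before calling boto3's batch_get_item.
# Table and key values are taken from the tests in this commit.
request_items = {
    "customer": {
        "Keys": [{"name": "Doe"}, {"name": "Dupont"}],
        "ConsistentRead": True,
    }
}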
63 changes: 62 additions & 1 deletion src/e3/aws/dynamodb/__init__.py
@@ -2,6 +2,7 @@
from typing import TYPE_CHECKING
import logging
import re
import time
from botocore.exceptions import ClientError

if TYPE_CHECKING:
@@ -92,7 +93,7 @@ def get_item(
        :return: retrieved item
        """
        table = self.client.Table(table_name)
        logger.info(f"Retrievieng item {item} from {table_name}...")
        logger.info(f"Retrieving item {item} from {table_name}...")
        try:
            response = table.get_item(
                Key={key: item[key] for key in keys if key in item.keys()}
@@ -104,6 +105,66 @@
            else:
                return response.get("Item", {})

    def batch_get_items(
        self, items: list[dict[str, Any]], table_name: str, keys: list[str]
    ) -> list[dict[str, Any]]:
        """Retrieve multiple items from a table.

        When Amazon DynamoDB cannot process all items in a batch, a set of
        unprocessed keys is returned. This function uses an exponential backoff
        algorithm to retry getting the unprocessed keys until all are retrieved
        or the specified number of tries is reached.

        :param items: items we want to retrieve
        :param table_name: table containing the items
        :param keys: the primary keys of the items
        :return: the retrieved items
        """
logger.info(f"Retrieving items {items} from {table_name}...")
res = []

tries = 0
max_tries = 5
sleepy_time = 1 # Start with 1 second of sleep, then exponentially increase.
batch_keys = {
table_name: {
"Keys": [
{key: item[key] for key in keys if key in item.keys()}
for item in items
],
"ConsistentRead": True,
}
}
        logger.debug(f"Batch keys: {batch_keys}")
        while tries < max_tries:
            try:
                response = self.client.batch_get_item(
                    RequestItems=batch_keys,
                )
                res.extend(response.get("Responses", {table_name: []})[table_name])
                logger.debug(f"batch_get_item response: {response}")
                unprocessed = response["UnprocessedKeys"]
                if len(unprocessed) > 0:
                    batch_keys = unprocessed
                    unprocessed_count = sum(
                        [len(batch_key["Keys"]) for batch_key in batch_keys.values()]  # type: ignore
                    )
                    logger.info(
"%s unprocessed keys returned. Sleep, then retry.",
unprocessed_count,
)
tries += 1

                    if tries < max_tries:
                        logger.info("Sleeping for %s seconds.", sleepy_time)
                        time.sleep(sleepy_time)
                        sleepy_time = min(sleepy_time * 2, 32)
                else:
                    break
            except ClientError as e:
                logger.error(e)
                return []
        return res

    def update_item(
        self,
        item: dict[str, Any],
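Before the other changed files, a minimal usage sketch of the new method. It assumes an already-constructed DynamoDB wrapper (the class constructor is not shown in this diff) and the customer table with name as its primary key, as used by the tests below:

from e3.aws.dynamodb import DynamoDB


def fetch_customers(db: DynamoDB) -> list[dict]:
    """Fetch two customers in a single batch (db is an existing DynamoDB wrapper)."""
    return db.batch_get_items(
        items=[{"name": "Doe"}, {"name": "Dupont"}],
        table_name="customer",
        keys=["name"],
    )

Keys that do not exist in the table are simply absent from the result, which is why the batch test below expects only the Doe entry.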
5 changes: 5 additions & 0 deletions tests/coverage/base.rc
@@ -12,6 +12,11 @@ exclude_lines =
# + <os>-only and <os>: no cover
# + py2-only or py3-only
if TYPE_CHECKING:
# testing these lines is hard since it would require a table with
# more than 16MB of data
if len(unprocessed) > 0:
if tries < max_tries:



[html]
34 changes: 33 additions & 1 deletion tests/tests_e3_aws/dynamodb/main_test.py
@@ -7,8 +7,9 @@
from e3.aws.dynamodb import DynamoDB

if TYPE_CHECKING:
    from typing import Any
    from typing import Any, Generator
    from collections.abc import Iterable
    from pytest import LogCaptureFixture

TABLE_NAME = "customer"
PRIMARY_KEYS = ["name"]
@@ -83,6 +84,37 @@ def test_get_item_missing(client: DynamoDB) -> None:
    )


def test_batch_get_item(client: DynamoDB) -> None:
    """Test getting multiple items in a single batch."""
    assert client.batch_get_items(
        items=[{"name": "Doe"}, {"name": "Dupont"}],
        table_name=TABLE_NAME,
        keys=PRIMARY_KEYS,
    ) == [{"age": 23, "name": "Doe"}]


def test_batch_get_item_error(
    client: DynamoDB, caplog: Generator[LogCaptureFixture, Any, Any]
) -> None:
    """Test batch-getting items from a table that doesn't exist."""
    items = client.batch_get_items(
        items=[{"name": "Dupont"}], table_name="Fake_Table", keys=PRIMARY_KEYS
    )

    messages = []

    # Capture logs and ensure they are what we expect
    messages.extend([x.message for x in caplog.get_records("call")])  # type: ignore

    assert len(messages) == 2
    assert (
        "An error occurred (ResourceNotFoundException) when "
        "calling the BatchGetItem operation: Requested resource not found" in messages
    )

    assert items == []


def test_update_item(client: DynamoDB) -> None:
    """Test updating an item."""
    customers = [dict(customer) for customer in CUSTOMERS]
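The retry branches excluded from coverage in tests/coverage/base.rc are hard to hit against a real table, since that would require more than 16MB of data. A possible sketch of covering them by stubbing the underlying boto3 call with canned responses; the monkeypatch approach, the stubbed response contents, and the test name are assumptions, not part of this commit:

from typing import Any

from e3.aws.dynamodb import DynamoDB


def test_batch_get_items_unprocessed_keys(client: DynamoDB, monkeypatch: Any) -> None:
    """Sketch: force one UnprocessedKeys round-trip, then a clean second response."""
    canned = [
        # First call: one item returned, one key left unprocessed.
        {
            "Responses": {"customer": [{"age": 23, "name": "Doe"}]},
            "UnprocessedKeys": {
                "customer": {"Keys": [{"name": "Dupont"}], "ConsistentRead": True}
            },
        },
        # Second call: the remaining item, nothing left to retry.
        {"Responses": {"customer": [{"name": "Dupont"}]}, "UnprocessedKeys": {}},
    ]
    # client.client is the boto3 resource used by batch_get_items; replace its
    # batch_get_item with a stub that pops canned responses, and skip the real sleep.
    monkeypatch.setattr(
        client.client, "batch_get_item", lambda **kwargs: canned.pop(0)
    )
    monkeypatch.setattr("time.sleep", lambda seconds: None)

    assert client.batch_get_items(
        items=[{"name": "Doe"}, {"name": "Dupont"}],
        table_name="customer",
        keys=["name"],
    ) == [{"age": 23, "name": "Doe"}, {"name": "Dupont"}]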
