Skip to content

Commit

Permalink
Adding support to list/set/tuple fields + renaming the class of the p…
Browse files Browse the repository at this point in the history
…rovider
  • Loading branch information
leandrodamascena committed Dec 18, 2023
1 parent ca7897f commit fec33a6
Show file tree
Hide file tree
Showing 17 changed files with 161 additions and 67 deletions.
102 changes: 75 additions & 27 deletions aws_lambda_powertools/utilities/_data_masking/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@
import logging
from typing import Any, Callable, Iterable, Optional, Union

from aws_lambda_powertools.utilities._data_masking.exceptions import DataMaskingUnsupportedTypeError
from aws_lambda_powertools.utilities._data_masking.exceptions import (
DataMaskingFieldNotFound,
DataMaskingUnsupportedTypeError,
)
from aws_lambda_powertools.utilities._data_masking.provider import BaseProvider

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -137,9 +140,56 @@ def _apply_action_to_fields(
```
"""

data_parsed = {}
data_parsed: dict = self._normalize_data_to_parse(fields, data)

if fields is None:
for nested_field in fields:
logger.debug(f"Processing nested field: {nested_field}")

nested_parsed_field = nested_field

# Ensure the nested field is represented as a string
if not isinstance(nested_parsed_field, str):
nested_parsed_field = json.dumps(nested_parsed_field)

# Split the nested field into keys using dot, square brackets as separators
# keys = re.split(r"\.|\[|\]", nested_field) # noqa ERA001 - REVIEW THIS

keys = nested_parsed_field.replace("][", ".").replace("[", ".").replace("]", "").split(".")
keys = [key for key in keys if key] # Remove empty strings from the split

# Traverse the dictionary hierarchy by iterating through the list of nested keys
current_dict = data_parsed

for key in keys[:-1]:
# If we get here, the customer is potentially addressing a list, set or tuple
# Example "payload[0]"

logger.debug(f"Processing {key} in field {nested_field}")

# It supports dict, list, set and tuple
try:
if isinstance(current_dict, dict) and key in current_dict:
# If we get here, the key is the name of a dictionary key
# Example "payload"
current_dict = current_dict[key]
elif (
isinstance(current_dict, (set, tuple, list)) and key.isdigit() and int(key) < len(current_dict)
):
# If we get here, the key is a positional index into a list, set or tuple
# Example "[0]"
current_dict = current_dict[int(key)]
except KeyError:
# Handle the case when the key doesn't exist
raise DataMaskingFieldNotFound(f"Key {key} not found in {current_dict}")

last_key = keys[-1]

current_dict = self._apply_action_to_specific_type(current_dict, action, last_key, **provider_options)

return data_parsed

def _normalize_data_to_parse(self, fields: list, data: str | dict) -> dict:
if not fields:
raise ValueError("No fields specified.")

if isinstance(data, str):
Expand All @@ -154,29 +204,27 @@ def _apply_action_to_fields(
f"Unsupported data type. Expected a traversable type (dict or str), but got {type(data)}.",
)

for nested_field in fields:
# Prevent overriding loop variable
current_nested_field = nested_field

# Ensure the nested field is represented as a string
if not isinstance(current_nested_field, str):
current_nested_field = json.dumps(current_nested_field)

# Split the nested field string into a list of nested keys
# ['a.b.c'] -> ['a', 'b', 'c']
nested_keys = current_nested_field.split(".")

# Initialize the current dictionary to the root dictionary
current_dict = data_parsed

# Traverse the dictionary hierarchy by iterating through the list of nested keys
for key in nested_keys[:-1]:
current_dict = current_dict[key]

# Retrieve the final value of the nested field
target_value = current_dict[nested_keys[-1]]
return data_parsed

# Apply the specified 'action' to the target value
current_dict[nested_keys[-1]] = action(target_value, **provider_options)
def _apply_action_to_specific_type(self, current_dict: dict, action: Callable, last_key, **provider_options):
    """
    Apply ``action`` to the single element addressed by ``last_key``.

    Despite its name and annotation, ``current_dict`` is handled as any of
    ``dict``, ``list``, ``tuple`` or ``set``.

    Parameters
    ----------
    current_dict : dict | list | tuple | set
        Container holding the value to process.
    action : Callable
        Called as ``action(value, **provider_options)``; its result replaces
        the original value.
    last_key : str
        Dictionary key, or a non-negative decimal index for list/tuple/set.
    provider_options
        Extra keyword arguments forwarded unchanged to ``action``.

    Returns
    -------
    dict | list | tuple | set
        ``dict`` and ``list`` are updated in place and returned as-is.
        ``tuple`` and ``set`` cannot be updated in place, so a NEW object is
        returned and the caller must use the return value.

    Raises
    ------
    DataMaskingFieldNotFound
        When ``last_key`` is not a key of the dict, is not a valid in-range
        index for the sequence, or the container type is unsupported.
    """
    logger.debug("Processing the last fields to apply the action")

    if isinstance(current_dict, dict):
        if last_key in current_dict:
            current_dict[last_key] = action(current_dict[last_key], **provider_options)
            return current_dict

    elif isinstance(current_dict, (list, tuple, set)):
        # isdecimal() (not isdigit()) so that every accepted string is also
        # accepted by int(); e.g. "²".isdigit() is True but int("²") raises.
        is_index = isinstance(last_key, str) and last_key.isdecimal()
        index = int(last_key) if is_index else -1

        # The same validation is applied to list, tuple and set so that an
        # invalid index always surfaces as DataMaskingFieldNotFound.
        if 0 <= index < len(current_dict):
            if isinstance(current_dict, list):
                current_dict[index] = action(current_dict[index], **provider_options)
                return current_dict

            if isinstance(current_dict, tuple):
                # Tuples are immutable: rebuild around the processed element
                return (
                    current_dict[:index]
                    + (action(current_dict[index], **provider_options),)
                    + current_dict[index + 1 :]
                )

            # NOTE(review): sets have no defined order, so which element
            # "index N" refers to is not stable - confirm this is intended.
            elements_list = list(current_dict)
            elements_list[index] = action(elements_list[index], **provider_options)
            return set(elements_list)

    # Handle the case when the last key doesn't exist
    # NOTE(review): the message embeds the container contents - confirm this
    # cannot expose unmasked data through logs or tracebacks.
    raise DataMaskingFieldNotFound(f"Key {last_key} not found in {current_dict}")
6 changes: 6 additions & 0 deletions aws_lambda_powertools/utilities/_data_masking/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,9 @@ class DataMaskingContextMismatchError(Exception):
"""
Decrypting with the incorrect encryption context.
"""


class DataMaskingFieldNotFound(Exception):
    """
    Field not found.

    Raised when a requested field (a dictionary key or a sequence index)
    cannot be located in the data being processed.
    """
10 changes: 5 additions & 5 deletions aws_lambda_powertools/utilities/_data_masking/provider/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def decrypt(self, data) -> Any:
def mask(self, data) -> Union[str, Iterable]:
# Implementation logic for data masking
pass
def lambda_handler(event, context):
provider = MyCustomProvider(["secret-key"])
data_masker = DataMasking(provider=provider)
Expand Down Expand Up @@ -83,13 +83,13 @@ def decrypt(self, data) -> Any:

def mask(self, data) -> Union[str, Iterable]:
"""
This method irreversibly masks data.
This method irreversibly masks data.
If the data to be masked is of type `str`, `dict`, or `bytes`,
this method will return a masked string, i.e. "*****".
If the data to be masked is of an iterable type like `list`, `tuple`,
or `set`, this method will return a new object of the same type as the
If the data to be masked is of an iterable type like `list`, `tuple`,
or `set`, this method will return a new object of the same type as the
input data but with each element replaced by the string "*****".
"""
if isinstance(data, (str, dict, bytes)):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from aws_lambda_powertools.utilities._data_masking.provider.kms.aws_encryption_sdk import AwsEncryptionSdkProvider
from aws_lambda_powertools.utilities._data_masking.provider.kms.aws_encryption_sdk import AWSEncryptionSDKProvider

__all__ = [
"AwsEncryptionSdkProvider",
"AWSEncryptionSDKProvider",
]
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@
logger = logging.getLogger(__name__)


class AwsEncryptionSdkProvider(BaseProvider):
class AWSEncryptionSDKProvider(BaseProvider):
"""
The AwsEncryptionSdkProvider is used as a provider for the DataMasking class.
The AWSEncryptionSDKProvider is used as a provider for the DataMasking class.
This provider allows you to perform data masking using the AWS Encryption SDK
for encryption and decryption. It integrates with the DataMasking class to
Expand All @@ -44,12 +44,12 @@ class AwsEncryptionSdkProvider(BaseProvider):
```
from aws_lambda_powertools.utilities.data_masking import DataMasking
from aws_lambda_powertools.utilities.data_masking.providers.kms.aws_encryption_sdk import (
AwsEncryptionSdkProvider,
AWSEncryptionSDKProvider,
)
def lambda_handler(event, context):
provider = AwsEncryptionSdkProvider(["arn:aws:kms:us-east-1:0123456789012:key/key-id"])
provider = AWSEncryptionSDKProvider(["arn:aws:kms:us-east-1:0123456789012:key/key-id"])
data_masker = DataMasking(provider=provider)
data = {
Expand Down Expand Up @@ -130,7 +130,7 @@ def __init__(

def encrypt(self, data: bytes | str | Dict | float, **provider_options) -> str:
"""
Encrypt data using the AwsEncryptionSdkProvider.
Encrypt data using the AWSEncryptionSDKProvider.
Parameters
-------
Expand Down
4 changes: 2 additions & 2 deletions examples/data_masking/src/data_masking_function_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from aws_lambda_powertools import Logger, Tracer
from aws_lambda_powertools.utilities._data_masking import DataMasking
from aws_lambda_powertools.utilities._data_masking.provider.kms.aws_encryption_sdk import AwsEncryptionSdkProvider
from aws_lambda_powertools.utilities._data_masking.provider.kms.aws_encryption_sdk import AWSEncryptionSDKProvider
from aws_lambda_powertools.utilities.typing import LambdaContext

KMS_KEY_ARN = os.getenv("KMS_KEY_ARN", "")
Expand All @@ -18,7 +18,7 @@ def lambda_handler(event: dict, context: LambdaContext) -> dict:

data = event["body"]

data_masker = DataMasking(provider=AwsEncryptionSdkProvider(keys=[KMS_KEY_ARN]))
data_masker = DataMasking(provider=AWSEncryptionSDKProvider(keys=[KMS_KEY_ARN]))
encrypted = data_masker.encrypt(data, fields=["address.street", "job_history.company.company_name"])
decrypted = data_masker.decrypt(encrypted, fields=["address.street", "job_history.company.company_name"])
return {"Decrypted_json": decrypted}
4 changes: 2 additions & 2 deletions examples/data_masking/src/getting_started_decrypt_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@

from aws_lambda_powertools import Logger
from aws_lambda_powertools.utilities._data_masking import DataMasking
from aws_lambda_powertools.utilities._data_masking.provider.kms.aws_encryption_sdk import AwsEncryptionSdkProvider
from aws_lambda_powertools.utilities._data_masking.provider.kms.aws_encryption_sdk import AWSEncryptionSDKProvider
from aws_lambda_powertools.utilities.typing import LambdaContext

KMS_KEY_ARN = os.getenv("KMS_KEY_ARN", "")

encryption_provider = AwsEncryptionSdkProvider(keys=[KMS_KEY_ARN])
encryption_provider = AWSEncryptionSDKProvider(keys=[KMS_KEY_ARN])
data_masker = DataMasking(provider=encryption_provider)

logger = Logger()
Expand Down
4 changes: 2 additions & 2 deletions examples/data_masking/src/getting_started_decrypt_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@

from aws_lambda_powertools import Logger
from aws_lambda_powertools.utilities._data_masking import DataMasking
from aws_lambda_powertools.utilities._data_masking.provider.kms.aws_encryption_sdk import AwsEncryptionSdkProvider
from aws_lambda_powertools.utilities._data_masking.provider.kms.aws_encryption_sdk import AWSEncryptionSDKProvider
from aws_lambda_powertools.utilities.typing import LambdaContext

KMS_KEY_ARN = os.getenv("KMS_KEY_ARN", "") # (1)!

encryption_provider = AwsEncryptionSdkProvider(keys=[KMS_KEY_ARN]) # (2)!
encryption_provider = AWSEncryptionSDKProvider(keys=[KMS_KEY_ARN]) # (2)!
data_masker = DataMasking(provider=encryption_provider)

logger = Logger()
Expand Down
4 changes: 2 additions & 2 deletions examples/data_masking/src/getting_started_encrypt_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@

from aws_lambda_powertools import Logger
from aws_lambda_powertools.utilities._data_masking import DataMasking
from aws_lambda_powertools.utilities._data_masking.provider.kms.aws_encryption_sdk import AwsEncryptionSdkProvider
from aws_lambda_powertools.utilities._data_masking.provider.kms.aws_encryption_sdk import AWSEncryptionSDKProvider
from aws_lambda_powertools.utilities.typing import LambdaContext

KMS_KEY_ARN = os.getenv("KMS_KEY_ARN", "")

encryption_provider = AwsEncryptionSdkProvider(keys=[KMS_KEY_ARN])
encryption_provider = AWSEncryptionSDKProvider(keys=[KMS_KEY_ARN])
data_masker = DataMasking(provider=encryption_provider)

logger = Logger()
Expand Down
4 changes: 2 additions & 2 deletions examples/data_masking/src/getting_started_encrypt_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@

from aws_lambda_powertools import Logger
from aws_lambda_powertools.utilities._data_masking import DataMasking
from aws_lambda_powertools.utilities._data_masking.provider.kms.aws_encryption_sdk import AwsEncryptionSdkProvider
from aws_lambda_powertools.utilities._data_masking.provider.kms.aws_encryption_sdk import AWSEncryptionSDKProvider
from aws_lambda_powertools.utilities.typing import LambdaContext

KMS_KEY_ARN = os.getenv("KMS_KEY_ARN", "")

encryption_provider = AwsEncryptionSdkProvider(keys=[KMS_KEY_ARN]) # (1)!
encryption_provider = AWSEncryptionSDKProvider(keys=[KMS_KEY_ARN]) # (1)!
data_masker = DataMasking(provider=encryption_provider)

logger = Logger()
Expand Down
4 changes: 2 additions & 2 deletions examples/data_masking/tests/src/single_mock.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from aws_lambda_powertools.utilities._data_masking.base import DataMasking
from aws_lambda_powertools.utilities._data_masking.provider import BaseProvider
from aws_lambda_powertools.utilities._data_masking.provider.kms.aws_encryption_sdk import AwsEncryptionSdkProvider
from aws_lambda_powertools.utilities._data_masking.provider.kms.aws_encryption_sdk import AWSEncryptionSDKProvider


class FakeEncryptionKeyProvider(BaseProvider):
Expand All @@ -31,7 +31,7 @@ def handler(event, context):
data = "mock_value"

fake_key_provider = FakeEncryptionKeyProvider()
provider = AwsEncryptionSdkProvider(
provider = AWSEncryptionSDKProvider(
keys=["dummy"],
key_provider=fake_key_provider,
)
Expand Down
4 changes: 2 additions & 2 deletions tests/e2e/data_masking/handlers/basic_handler.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from aws_lambda_powertools import Logger
from aws_lambda_powertools.utilities._data_masking import DataMasking
from aws_lambda_powertools.utilities._data_masking.provider.kms.aws_encryption_sdk import AwsEncryptionSdkProvider
from aws_lambda_powertools.utilities._data_masking.provider.kms.aws_encryption_sdk import AWSEncryptionSDKProvider

logger = Logger()

Expand All @@ -14,7 +14,7 @@ def lambda_handler(event, context):

# Encrypting data for test_encryption_in_handler test
kms_key = event.get("kms_key", "")
data_masker = DataMasking(provider=AwsEncryptionSdkProvider(keys=[kms_key]))
data_masker = DataMasking(provider=AWSEncryptionSDKProvider(keys=[kms_key]))
value = [1, 2, "string", 4.5]
encrypted_data = data_masker.encrypt(value)
response = {}
Expand Down
12 changes: 6 additions & 6 deletions tests/e2e/data_masking/test_e2e_data_masking.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
from aws_encryption_sdk.exceptions import DecryptKeyError

from aws_lambda_powertools.utilities._data_masking import DataMasking
from aws_lambda_powertools.utilities._data_masking.exceptions import DataMaskingContextMismatchError
from aws_lambda_powertools.utilities._data_masking.provider.kms.aws_encryption_sdk import (
AwsEncryptionSdkProvider,
ContextMismatchError,
AWSEncryptionSDKProvider,
)
from tests.e2e.utils import data_fetcher

Expand Down Expand Up @@ -36,7 +36,7 @@ def kms_key2_arn(infrastructure: dict) -> str:

@pytest.fixture
def data_masker(kms_key1_arn) -> DataMasking:
    """Provide a DataMasking instance backed by AWSEncryptionSDKProvider and the first KMS key."""
    return DataMasking(provider=AWSEncryptionSDKProvider(keys=[kms_key1_arn]))


@pytest.mark.xdist_group(name="data_masking")
Expand Down Expand Up @@ -79,7 +79,7 @@ def test_encryption_context_mismatch(data_masker):
encrypted_data = data_masker.encrypt(value, encryption_context={"this": "is_secure"})

# THEN decrypting with a different encryption_context should raise a DataMaskingContextMismatchError
with pytest.raises(ContextMismatchError):
with pytest.raises(DataMaskingContextMismatchError):
data_masker.decrypt(encrypted_data, encryption_context={"not": "same_context"})


Expand All @@ -93,7 +93,7 @@ def test_encryption_no_context_fail(data_masker):
encrypted_data = data_masker.encrypt(value)

# THEN decrypting with an encryption_context should raise a DataMaskingContextMismatchError
with pytest.raises(ContextMismatchError):
with pytest.raises(DataMaskingContextMismatchError):
data_masker.decrypt(encrypted_data, encryption_context={"this": "is_secure"})


Expand All @@ -106,7 +106,7 @@ def test_encryption_decryption_key_mismatch(data_masker, kms_key2_arn):
encrypted_data = data_masker.encrypt(value)

# THEN when decrypting with a different key it should fail
data_masker_key2 = DataMasking(provider=AwsEncryptionSdkProvider(keys=[kms_key2_arn]))
data_masker_key2 = DataMasking(provider=AWSEncryptionSDKProvider(keys=[kms_key2_arn]))

with pytest.raises(DecryptKeyError):
data_masker_key2.decrypt(encrypted_data)
Expand Down
Loading

0 comments on commit fec33a6

Please sign in to comment.