diff --git a/inngest/experimental/encryption_middleware.py b/inngest/experimental/encryption_middleware.py index b6b2574..590b304 100644 --- a/inngest/experimental/encryption_middleware.py +++ b/inngest/experimental/encryption_middleware.py @@ -54,6 +54,7 @@ def __init__( secret_key: typing.Union[bytes, str], *, decrypt_only: bool = False, + encrypt_invoke_data: bool = False, event_encryption_field: str = _default_event_encryption_field, fallback_decryption_keys: typing.Optional[ list[typing.Union[bytes, str]] @@ -66,6 +67,7 @@ def __init__( raw_request: Framework/platform specific request object. secret_key: Secret key used for encryption and decryption. decrypt_only: Only decrypt data (do not encrypt). + encrypt_invoke_data: Encrypt the data sent to invoked functions. Deprecated: Will be removed in a future release, where invoke data will always be encrypted (equivalent to encrypt_invoke_data=True). event_encryption_field: Automatically encrypt and decrypt this field in event data. fallback_decryption_keys: Fallback secret keys used for decryption. """ @@ -78,6 +80,7 @@ def __init__( ) self._decrypt_only = decrypt_only + self._encrypt_invoke_data = encrypt_invoke_data self._event_encryption_field = event_encryption_field self._fallback_decryption_boxes = [ @@ -94,6 +97,7 @@ def factory( secret_key: typing.Union[bytes, str], *, decrypt_only: bool = False, + encrypt_invoke_data: bool = False, event_encryption_field: str = _default_event_encryption_field, fallback_decryption_keys: typing.Optional[ list[typing.Union[bytes, str]] @@ -107,6 +111,7 @@ def factory( ---- secret_key: Fernet secret key used for encryption and decryption. decrypt_only: Only decrypt data (do not encrypt). + encrypt_invoke_data: Encrypt the data sent to invoked functions. Deprecated: Will be removed in a future release, where invoke data will always be encrypted (equivalent to encrypt_invoke_data=True). event_encryption_field: Automatically encrypt and decrypt this field in event data. fallback_decryption_keys: Fallback secret keys used for decryption. """ @@ -120,6 +125,7 @@ def _factory( raw_request, secret_key, decrypt_only=decrypt_only, + encrypt_invoke_data=encrypt_invoke_data, event_encryption_field=event_encryption_field, fallback_decryption_keys=fallback_decryption_keys, ) @@ -264,7 +270,8 @@ def transform_output(self, result: inngest.TransformOutputResult) -> None: # Encrypt invoke data if present. if ( - result.step is not None + self._encrypt_invoke_data + and result.step is not None and result.step.op is server_lib.Opcode.INVOKE and result.step.opts is not None ): diff --git a/tests/test_experimental/test_encryption_middleware/cases/invoke.py b/tests/test_experimental/test_encryption_middleware/cases/invoke.py index 60a6e1c..849370b 100644 --- a/tests/test_experimental/test_encryption_middleware/cases/invoke.py +++ b/tests/test_experimental/test_encryption_middleware/cases/invoke.py @@ -41,10 +41,14 @@ def create( event_name = base.create_event_name(framework, test_name) fn_id = base.create_fn_id(test_name) state = _State() + mw = EncryptionMiddleware.factory( + _secret_key, + encrypt_invoke_data=True, + ) @client.create_function( fn_id=f"{fn_id}/child", - middleware=[EncryptionMiddleware.factory(_secret_key)], + middleware=[mw], retries=0, trigger=inngest.TriggerEvent(event="never"), ) @@ -61,7 +65,7 @@ def child_fn_sync( @client.create_function( fn_id=fn_id, - middleware=[EncryptionMiddleware.factory(_secret_key)], + middleware=[mw], retries=0, trigger=inngest.TriggerEvent(event=event_name), ) @@ -81,7 +85,7 @@ def fn_sync( @client.create_function( fn_id=f"{fn_id}/child", - middleware=[EncryptionMiddleware.factory(_secret_key)], + middleware=[mw], retries=0, trigger=inngest.TriggerEvent(event="never"), ) @@ -98,7 +102,7 @@ async def child_fn_async( @client.create_function( fn_id=fn_id, - middleware=[EncryptionMiddleware.factory(_secret_key)], + middleware=[mw], retries=0, trigger=inngest.TriggerEvent(event=event_name), )