From 46772c9a55043b06d3e7d5074637bf65a306eb01 Mon Sep 17 00:00:00 2001
From: JordonPhillips
Date: Tue, 27 Aug 2024 15:22:19 +0200
Subject: [PATCH] Import AWS protocol framing and event stream support

This brings over the protocol framing and event stream support that is
currently being used in the transcribe streaming sdk, with some changes
from botocore. Minor refactoring has been done to reformat the code, add
type hints, and upgrade to newer language features.

Some issues reported by type checking have been addressed, mostly
accounting for cases where a value could be `None` but wasn't handled
as such.

More refactoring will be done in future PRs to make all this implement
a generic interface. Effort will be made to keep this module free of
dependencies, so that it can at least be used by the transcribe sdk, if
not botocore.
---
 Makefile                                      |   3 +
 python-packages/aws-event-stream/BUILD        |  28 +
 python-packages/aws-event-stream/MANIFEST.in  |   1 +
 python-packages/aws-event-stream/NOTICE       |   1 +
 python-packages/aws-event-stream/README.md    |   0
 .../aws-event-stream/aws_event_stream/BUILD   |  13 +
 .../aws_event_stream/__init__.py              |   2 +
 .../aws-event-stream/aws_event_stream/auth.py | 106 +++
 .../aws_event_stream/eventstream.py           | 663 ++++++++++++++++++
 .../aws_event_stream/exceptions.py            |  95 +++
 .../aws_event_stream/py.typed                 |   0
 .../aws_event_stream/structures.py            |  92 +++
 .../aws-event-stream/pyproject.toml           |  46 ++
 .../aws-event-stream/requirements.txt         |   3 +
 python-packages/aws-event-stream/tests/BUILD  |  20 +
 .../aws-event-stream/tests/__init__.py        |   2 +
 .../aws-event-stream/tests/py.typed           |   0
 .../tests/unit/__init__.py                    |   2 +
 .../tests/unit/test_eventstream.py            | 660 +++++++++++++++++
 19 files changed, 1737 insertions(+)
 create mode 100644 python-packages/aws-event-stream/BUILD
 create mode 100644 python-packages/aws-event-stream/MANIFEST.in
 create mode 100644 python-packages/aws-event-stream/NOTICE
 create mode 100644 python-packages/aws-event-stream/README.md
 create mode 100644 python-packages/aws-event-stream/aws_event_stream/BUILD
 create mode 100644 python-packages/aws-event-stream/aws_event_stream/__init__.py
 create mode 100644 python-packages/aws-event-stream/aws_event_stream/auth.py
 create mode 100644 python-packages/aws-event-stream/aws_event_stream/eventstream.py
 create mode 100644 python-packages/aws-event-stream/aws_event_stream/exceptions.py
 create mode 100644 python-packages/aws-event-stream/aws_event_stream/py.typed
 create mode 100644 python-packages/aws-event-stream/aws_event_stream/structures.py
 create mode 100644 python-packages/aws-event-stream/pyproject.toml
 create mode 100644 python-packages/aws-event-stream/requirements.txt
 create mode 100644 python-packages/aws-event-stream/tests/BUILD
 create mode 100644 python-packages/aws-event-stream/tests/__init__.py
 create mode 100644 python-packages/aws-event-stream/tests/py.typed
 create mode 100644 python-packages/aws-event-stream/tests/unit/__init__.py
 create mode 100644 python-packages/aws-event-stream/tests/unit/test_eventstream.py

diff --git a/Makefile b/Makefile
index 41a0b7c3..629a7254 100644
--- a/Makefile
+++ b/Makefile
@@ -55,6 +55,7 @@ lint-py: pants
 	./pants fix lint python-packages/smithy-aws-core::
 	./pants fix lint python-packages/smithy-json::
 	./pants fix lint python-packages/smithy-event-stream::
+	./pants fix lint python-packages/aws-event-stream::
 
 ## Runs checkers for the python packages.
@@ -64,6 +65,7 @@ check-py: pants ./pants check python-packages/smithy-aws-core:: ./pants check python-packages/smithy-json:: ./pants check python-packages/smithy-event-stream:: + ./pants check python-packages/aws-event-stream:: ## Runs tests for the python packages. @@ -73,6 +75,7 @@ test-py: pants ./pants test python-packages/smithy-aws-core:: ./pants test python-packages/smithy-json:: ./pants test python-packages/smithy-event-stream:: + ./pants test python-packages/aws-event-stream:: ## Runs formatters/fixers/linters/checkers/tests for the python packages. diff --git a/python-packages/aws-event-stream/BUILD b/python-packages/aws-event-stream/BUILD new file mode 100644 index 00000000..a0ad1e14 --- /dev/null +++ b/python-packages/aws-event-stream/BUILD @@ -0,0 +1,28 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +resource(name="pyproject", source="pyproject.toml") +resource(name="readme", source="README.md") +resource(name="notice", source="NOTICE") + +python_distribution( + name="dist", + dependencies=[ + ":pyproject", + ":readme", + ":notice", + "python-packages/aws-event-stream/aws_event_stream:source", + ], + provides=python_artifact( + name="aws_event_stream", + version="0.0.1", + ), +) + +# We shouldn't need this, but pants will assume that smithy_core is an external +# dependency since it's in pyproject.toml and there's no way to exclude it, so +# for now we need to duplicate things. +python_requirements( + name="requirements", + source="requirements.txt", +) diff --git a/python-packages/aws-event-stream/MANIFEST.in b/python-packages/aws-event-stream/MANIFEST.in new file mode 100644 index 00000000..bcd0f28a --- /dev/null +++ b/python-packages/aws-event-stream/MANIFEST.in @@ -0,0 +1 @@ +include aws_event_stream/py.typed diff --git a/python-packages/aws-event-stream/NOTICE b/python-packages/aws-event-stream/NOTICE new file mode 100644 index 00000000..616fc588 --- /dev/null +++ b/python-packages/aws-event-stream/NOTICE @@ -0,0 +1 @@ +Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. diff --git a/python-packages/aws-event-stream/README.md b/python-packages/aws-event-stream/README.md new file mode 100644 index 00000000..e69de29b diff --git a/python-packages/aws-event-stream/aws_event_stream/BUILD b/python-packages/aws-event-stream/aws_event_stream/BUILD new file mode 100644 index 00000000..c5a54b40 --- /dev/null +++ b/python-packages/aws-event-stream/aws_event_stream/BUILD @@ -0,0 +1,13 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +resource(name="pytyped", source="py.typed") + +python_sources( + name="source", + dependencies=[ + ":pytyped", + "python-packages/aws-event-stream:requirements", + ], + sources=["**/*.py"], +) diff --git a/python-packages/aws-event-stream/aws_event_stream/__init__.py b/python-packages/aws-event-stream/aws_event_stream/__init__.py new file mode 100644 index 00000000..04f8b7b7 --- /dev/null +++ b/python-packages/aws-event-stream/aws_event_stream/__init__.py @@ -0,0 +1,2 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 diff --git a/python-packages/aws-event-stream/aws_event_stream/auth.py b/python-packages/aws-event-stream/aws_event_stream/auth.py new file mode 100644 index 00000000..01c770cc --- /dev/null +++ b/python-packages/aws-event-stream/aws_event_stream/auth.py @@ -0,0 +1,106 @@ +# Copyright Amazon.com, Inc. or its affiliates. 
All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +import datetime +import hmac +from binascii import hexlify +from collections.abc import Callable +from dataclasses import dataclass +from hashlib import sha256 +from typing import Optional + +from .eventstream import ( + HEADER_SERIALIZATION_VALUE, + HEADERS_SERIALIZATION_DICT, + EventStreamMessageSerializer, +) + + +def _utc_now() -> datetime.datetime: + return datetime.datetime.now(datetime.UTC) + + +@dataclass +class Credentials: + access_key_id: str + secret_access_key: str + session_token: str | None + + +class EventSigner: + _ISO8601_TIMESTAMP_FMT = "%Y%m%dT%H%M%SZ" + _NOW_TYPE = Optional[Callable[[], datetime.datetime]] + + def __init__( + self, + signing_name: str, + region: str, + utc_now: _NOW_TYPE = None, + ): + self.signing_name = signing_name + self.region = region + self.serializer = EventStreamMessageSerializer() + if utc_now is None: + utc_now = _utc_now + self._utc_now = utc_now + + def sign( + self, payload: bytes, prior_signature: bytes, credentials: Credentials + ) -> HEADERS_SERIALIZATION_DICT: + now = self._utc_now() + + # pyright gets confused for some reason if we use + # HEADERS_SERIALIZATION_DICT here. It gets convinced that the dict + # can only have datetime values. + headers: dict[str, HEADER_SERIALIZATION_VALUE] = { + ":date": now, + } + + timestamp = now.strftime(self._ISO8601_TIMESTAMP_FMT) + string_to_sign = self._string_to_sign( + timestamp, headers, payload, prior_signature + ) + event_signature = self._sign_event(timestamp, string_to_sign, credentials) + headers[":chunk-signature"] = event_signature + return headers + + def _keypath(self, timestamp: str) -> str: + parts = [ + timestamp[:8], # Only using the YYYYMMDD + self.region, + self.signing_name, + "aws4_request", + ] + return "/".join(parts) + + def _string_to_sign( + self, + timestamp: str, + headers: HEADERS_SERIALIZATION_DICT, + payload: bytes, + prior_signature: bytes, + ) -> str: + encoded_headers = self.serializer.encode_headers(headers) + parts = [ + "AWS4-HMAC-SHA256-PAYLOAD", + timestamp, + self._keypath(timestamp), + hexlify(prior_signature).decode("utf-8"), + sha256(encoded_headers).hexdigest(), + sha256(payload).hexdigest(), + ] + return "\n".join(parts) + + def _hmac(self, key: bytes, msg: bytes) -> bytes: + return hmac.new(key, msg, sha256).digest() + + def _sign_event( + self, timestamp: str, string_to_sign: str, credentials: Credentials + ) -> bytes: + key = credentials.secret_access_key.encode("utf-8") + today = timestamp[:8].encode("utf-8") # Only using the YYYYMMDD + k_date = self._hmac(b"AWS4" + key, today) + k_region = self._hmac(k_date, self.region.encode("utf-8")) + k_service = self._hmac(k_region, self.signing_name.encode("utf-8")) + k_signing = self._hmac(k_service, b"aws4_request") + return self._hmac(k_signing, string_to_sign.encode("utf-8")) diff --git a/python-packages/aws-event-stream/aws_event_stream/eventstream.py b/python-packages/aws-event-stream/aws_event_stream/eventstream.py new file mode 100644 index 00000000..8db367ca --- /dev/null +++ b/python-packages/aws-event-stream/aws_event_stream/eventstream.py @@ -0,0 +1,663 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+# SPDX-License-Identifier: Apache-2.0
+"""Binary Event Stream support for the application/vnd.amazon.eventstream format."""
+
+import datetime
+import uuid
+from binascii import crc32
+from collections.abc import AsyncIterator, Callable, Mapping
+from dataclasses import dataclass
+from struct import pack, unpack
+from typing import Any, Protocol
+
+from .exceptions import (
+    ChecksumMismatch,
+    DuplicateHeader,
+    HeaderBytesExceedMaxLength,
+    HeaderValueBytesExceedMaxLength,
+    InvalidHeadersLength,
+    InvalidHeaderValue,
+    InvalidPayloadLength,
+    ParserError,
+    PayloadBytesExceedMaxLength,
+)
+from .structures import BufferableByteStream
+
+# byte length of the prelude (total_length + header_length + prelude_crc)
+_PRELUDE_LENGTH = 12
+_MAX_HEADERS_LENGTH = 128 * 1024  # 128 KiB
+_MAX_HEADER_VALUE_BYTE_LENGTH = 32 * 1024 - 1
+_MAX_PAYLOAD_LENGTH = 16 * 1024**2  # 16 MiB
+
+HEADER_VALUE = bool | bytes | int | str
+
+
+@dataclass
+class HeaderValue[T]:
+    """A data class for explicit header serialization.
+
+    This is used to represent types that Python doesn't natively have distinctions for,
+    notably fixed-size integers.
+    """
+
+    value: T
+
+
+class Int8HeaderValue(HeaderValue[int]):
+    """Value that should be explicitly serialized as an int8."""
+
+
+class Int16HeaderValue(HeaderValue[int]):
+    """Value that should be explicitly serialized as an int16."""
+
+
+class Int32HeaderValue(HeaderValue[int]):
+    """Value that should be explicitly serialized as an int32."""
+
+
+class Int64HeaderValue(HeaderValue[int]):
+    """Value that should be explicitly serialized as an int64."""
+
+
+type NumericHeaderValue = (
+    Int8HeaderValue | Int16HeaderValue | Int32HeaderValue | Int64HeaderValue
+)
+
+
+# Possible types for serializing headers differ from the possible types
+# returned when decoding
+HEADER_SERIALIZATION_VALUE = (
+    bool | bytes | int | str | uuid.UUID | datetime.datetime | NumericHeaderValue
+)
+HEADERS_SERIALIZATION_DICT = Mapping[str, HEADER_SERIALIZATION_VALUE]
+
+
+class EventStreamMessageSerializer:
+    DEFAULT_INT_TYPE: type[NumericHeaderValue] = Int32HeaderValue
+
+    def serialize(self, headers: HEADERS_SERIALIZATION_DICT, payload: bytes) -> bytes:
+        # TODO: Investigate performance of this once we can make requests
+        if len(payload) > _MAX_PAYLOAD_LENGTH:
+            raise PayloadBytesExceedMaxLength(len(payload))
+
+        # The encoded headers are variable length, and this length is
+        # required to generate the prelude, so generate the headers first
+        encoded_headers = self.encode_headers(headers)
+        if len(encoded_headers) > _MAX_HEADERS_LENGTH:
+            raise HeaderBytesExceedMaxLength(len(encoded_headers))
+        prelude_bytes = self._encode_prelude(encoded_headers, payload)
+
+        # Calculate the prelude_crc and its byte representation
+        prelude_crc = self._calculate_checksum(prelude_bytes)
+        prelude_crc_bytes = pack("!I", prelude_crc)
+        messages_bytes = prelude_crc_bytes + encoded_headers + payload
+
+        # Calculate the checksum continuing from the prelude crc
+        final_crc = self._calculate_checksum(messages_bytes, crc=prelude_crc)
+        final_crc_bytes = pack("!I", final_crc)
+        return prelude_bytes + messages_bytes + final_crc_bytes
+
+    def encode_headers(self, headers: HEADERS_SERIALIZATION_DICT) -> bytes:
+        encoded = b""
+        for key, val in headers.items():
+            encoded += self._encode_header_key(key)
+            encoded += self._encode_header_val(val)
+        return encoded
+
+    def _encode_header_key(self, key: str) -> bytes:
+        enc = key.encode("utf-8")
+        return pack("B", len(enc)) + enc
+
+    def _encode_header_val(self, val: HEADER_SERIALIZATION_VALUE)
-> bytes:
+        # Handle booleans first to avoid being viewed as ints
+        if val is True:
+            return b"\x00"
+        elif val is False:
+            return b"\x01"
+
+        if isinstance(val, int):
+            val = self.DEFAULT_INT_TYPE(val)
+
+        match val:
+            case Int8HeaderValue():
+                return b"\x02" + pack("!b", val.value)
+            case Int16HeaderValue():
+                return b"\x03" + pack("!h", val.value)
+            case Int32HeaderValue():
+                return b"\x04" + pack("!i", val.value)
+            case Int64HeaderValue():
+                return b"\x05" + pack("!q", val.value)
+            case bytes():
+                # Byte arrays are prefaced with a 16-bit length, but are
+                # restricted to a max length of 2**15 - 1; enforce this
+                # explicitly
+                if len(val) > _MAX_HEADER_VALUE_BYTE_LENGTH:
+                    raise HeaderValueBytesExceedMaxLength(len(val))
+                return b"\x06" + pack("!H", len(val)) + val
+            case str():
+                utf8_string = val.encode("utf-8")
+                # Strings are prefaced with a 16-bit length, but are
+                # restricted to a max length of 2**15 - 1; enforce this
+                # explicitly
+                if len(utf8_string) > _MAX_HEADER_VALUE_BYTE_LENGTH:
+                    raise HeaderValueBytesExceedMaxLength(len(utf8_string))
+                return b"\x07" + pack("!H", len(utf8_string)) + utf8_string
+            case datetime.datetime():
+                ms_timestamp = int(val.timestamp() * 1000)
+                return b"\x08" + pack("!q", ms_timestamp)
+            case uuid.UUID():
+                return b"\x09" + val.bytes
+
+        raise InvalidHeaderValue(val)
+
+    def _encode_prelude(self, encoded_headers: bytes, payload: bytes) -> bytes:
+        header_length = len(encoded_headers)
+        payload_length = len(payload)
+        # 16 accounts for the 12-byte prelude and the 4-byte message CRC
+        total_length = header_length + payload_length + 16
+        return pack("!II", total_length, header_length)
+
+    def _calculate_checksum(self, data: bytes, crc: int = 0) -> int:
+        return crc32(data, crc) & 0xFFFFFFFF
+
+
+class BaseEvent:
+    """Base class for typed events sent over an event stream with a service.
+
+    :param payload: bytes payload to be sent with the event
+    :param event_payload: boolean stating whether the event has a payload
+    """
+
+    def __init__(self, payload: bytes, event_payload: bool | None = None):
+        self.payload = payload
+        self.event_payload = event_payload
+        self.event = True
+
+
+class BaseStream:
+    """Base class for an EventStream established between the client and the
+    Transcribe service.
+
+    These streams will always be established automatically by the client.
+ """ + + def __init__( + self, + input_stream: Any = None, + event_serializer: Any = None, + eventstream_serializer: Any = None, + event_signer: Any = None, + initial_signature: Any = None, + credential_resolver: Any = None, + ): + if input_stream is None: + input_stream = BufferableByteStream() + self._input_stream: BufferableByteStream = input_stream + # TODO: Cant type due to circular import + self._event_serializer = event_serializer + if eventstream_serializer is None: + eventstream_serializer = EventStreamMessageSerializer() + self._eventstream_serializer = eventstream_serializer + self._event_signer = event_signer + self._prior_signature: Any = initial_signature + self._credential_resolver = credential_resolver + + async def send_event(self, event: BaseEvent): + headers, payload = self._event_serializer.serialize(event) + event_bytes = self._eventstream_serializer.serialize(headers, payload) + signed_bytes = await self._sign_event(event_bytes) + self._input_stream.write(signed_bytes) + + async def end_stream(self): + signed_bytes = await self._sign_event(b"") + self._input_stream.write(signed_bytes) + self._input_stream.end_stream() + + async def _sign_event(self, event_bytes: bytes): + creds = await self._credential_resolver.get_credentials() + signed_headers = self._event_signer.sign( + event_bytes, self._prior_signature, creds + ) + self._prior_signature = signed_headers.get(":chunk-signature") + return self._eventstream_serializer.serialize(signed_headers, event_bytes) + + +class DecodeUtils: + """Unpacking utility functions used in the decoder. + + All methods on this class take raw bytes and return a tuple containing the value + parsed from the bytes and the number of bytes consumed to parse that value. + """ + + UINT8_BYTE_FORMAT = "!B" + UINT16_BYTE_FORMAT = "!H" + UINT32_BYTE_FORMAT = "!I" + INT8_BYTE_FORMAT = "!b" + INT16_BYTE_FORMAT = "!h" + INT32_BYTE_FORMAT = "!i" + INT64_BYTE_FORMAT = "!q" + PRELUDE_BYTE_FORMAT = "!III" + + # uint byte size to unpack format + UINT_BYTE_FORMAT = { + 1: UINT8_BYTE_FORMAT, + 2: UINT16_BYTE_FORMAT, + 4: UINT32_BYTE_FORMAT, + } + + @staticmethod + def unpack_true(data: bytes) -> tuple[bool, int]: + """This method consumes none of the provided bytes and returns True. + + :param data: The bytes to parse from. This is ignored in this method. + :returns: The tuple (True, 0) + """ + return True, 0 + + @staticmethod + def unpack_false(data: bytes) -> tuple[bool, int]: + """This method consumes none of the provided bytes and returns False. + + :param data: The bytes to parse from. This is ignored in this method. + :returns: The tuple (False, 0) + """ + return False, 0 + + @staticmethod + def unpack_uint8(data: bytes) -> tuple[int, int]: + """Parse an unsigned 8-bit integer from the bytes. + + :param data: The bytes to parse from. + :returns: A tuple containing the (parsed integer value, bytes consumed) + """ + value = unpack(DecodeUtils.UINT8_BYTE_FORMAT, data[:1])[0] + return value, 1 + + @staticmethod + def unpack_uint32(data: bytes) -> tuple[int, int]: + """Parse an unsigned 32-bit integer from the bytes. + + :param data: The bytes to parse from. + :returns: A tuple containing the (parsed integer value, bytes consumed) + """ + value = unpack(DecodeUtils.UINT32_BYTE_FORMAT, data[:4])[0] + return value, 4 + + @staticmethod + def unpack_int8(data: bytes): + """Parse a signed 8-bit integer from the bytes. + + :param data: The bytes to parse from. 
+ :returns: A tuple containing the (parsed integer value, bytes consumed) + """ + value = unpack(DecodeUtils.INT8_BYTE_FORMAT, data[:1])[0] + return value, 1 + + @staticmethod + def unpack_int16(data: bytes) -> tuple[int, int]: + """Parse a signed 16-bit integer from the bytes. + + :param data: The bytes to parse from. + :returns: A tuple containing the (parsed integer value, bytes consumed) + """ + value = unpack(DecodeUtils.INT16_BYTE_FORMAT, data[:2])[0] + return value, 2 + + @staticmethod + def unpack_int32(data: bytes) -> tuple[int, int]: + """Parse a signed 32-bit integer from the bytes. + + :param data: The bytes to parse from. + :returns: A tuple containing the (parsed integer value, bytes consumed) + """ + value = unpack(DecodeUtils.INT32_BYTE_FORMAT, data[:4])[0] + return value, 4 + + @staticmethod + def unpack_int64(data: bytes) -> tuple[int, int]: + """Parse a signed 64-bit integer from the bytes. + + :param data: The bytes to parse from. + :returns: A tuple containing the (parsed integer value, bytes consumed) + """ + value = unpack(DecodeUtils.INT64_BYTE_FORMAT, data[:8])[0] + return value, 8 + + @staticmethod + def unpack_byte_array(data: bytes, length_byte_size: int = 2) -> tuple[bytes, int]: + """Parse a variable length byte array from the bytes. + + The bytes are expected to be in the following format: + [ length ][0 ... length bytes] + where length is an unsigned integer represented in the smallest number + of bytes to hold the maximum length of the array. + + :param data: The bytes to parse from. + :param length_byte_size: The byte size of the preceding integer that + represents the length of the array. Supported values are 1, 2, and 4. + :returns: A tuple containing the (parsed bytes, bytes consumed) + """ + uint_byte_format = DecodeUtils.UINT_BYTE_FORMAT[length_byte_size] + length = unpack(uint_byte_format, data[:length_byte_size])[0] + bytes_end = length + length_byte_size + array_bytes = data[length_byte_size:bytes_end] + return array_bytes, bytes_end + + @staticmethod + def unpack_utf8_string(data: bytes, length_byte_size: int = 2) -> tuple[str, int]: + """Parse a variable length utf-8 string from the bytes. + + The bytes are expected to be in the following format: + [ length ][0 ... length bytes] + where length is an unsigned integer represented in the smallest number + of bytes to hold the maximum length of the array and the following + bytes are a valid utf-8 string. + + :param data: The bytes to parse from. + :param length_byte_size: The byte size of the preceding integer that + represents the length of the array. Supported values are 1, 2, and 4. + :returns: A tuple containing the (parsed string, bytes consumed) + """ + array_bytes, consumed = DecodeUtils.unpack_byte_array(data, length_byte_size) + return array_bytes.decode("utf-8"), consumed + + @staticmethod + def unpack_uuid(data: bytes) -> tuple[bytes, int]: + """Parse a 16-byte uuid from the bytes. + + :param data: The bytes to parse from. + :returns: A tuple containing the (uuid bytes, bytes consumed). + """ + return data[:16], 16 + + @staticmethod + def unpack_prelude(data: bytes) -> tuple[tuple[Any, ...], int]: + """Parse the prelude for an event stream message from the bytes. + + The prelude for an event stream message has the following format: + [total_length][header_length][prelude_crc] + where each field is an unsigned 32-bit integer. + + :param data: The bytes to parse from. 
+ :returns: A tuple of ((total_length, headers_length, prelude_crc), + consumed) + """ + return (unpack(DecodeUtils.PRELUDE_BYTE_FORMAT, data), _PRELUDE_LENGTH) + + +def _validate_checksum(data: bytes, checksum: int, crc: int = 0) -> None: + # To generate the same numeric value across all Python versions and + # platforms use crc32(data) & 0xffffffff. + computed_checksum = crc32(data, crc) & 0xFFFFFFFF + if checksum != computed_checksum: + raise ChecksumMismatch(checksum, computed_checksum) + + +class MessagePrelude: + """Represents the prelude of an event stream message.""" + + def __init__(self, total_length: int, headers_length: int, crc: int): + self.total_length = total_length + self.headers_length = headers_length + self.crc = crc + + @property + def payload_length(self) -> int: + """Calculates the total payload length. + + The extra minus 4 bytes is for the message CRC. + """ + return self.total_length - self.headers_length - _PRELUDE_LENGTH - 4 + + @property + def payload_end(self) -> int: + """Calculates the byte offset for the end of the message payload. + + The extra minus 4 bytes is for the message CRC. + """ + return self.total_length - 4 + + @property + def headers_end(self) -> int: + """Calculates the byte offset for the end of the message headers.""" + return _PRELUDE_LENGTH + self.headers_length + + +class EventStreamMessage: + """Represents an event stream message.""" + + def __init__( + self, + prelude: MessagePrelude, + headers: dict[str, HEADER_VALUE], + payload: bytes, + crc: int, + ): + self.prelude = prelude + self.headers = headers + self.payload = payload + self.crc = crc + + def to_response_dict(self, status_code: int = 200) -> dict[str, Any]: + message_type = self.headers.get(":message-type") + if message_type == "error" or message_type == "exception": + status_code = 400 + return { + "status_code": status_code, + "headers": self.headers, + "body": self.payload, + } + + +class EventStreamHeaderParser: + """Parses the event headers from an event stream message. + + Expects all of the header data upfront and creates a dictionary of headers to + return. This object can be reused multiple times to parse the headers from multiple + event stream messages. + """ + + # Maps header type to appropriate unpacking function + # These unpacking functions return the value and the amount unpacked + _HEADER_TYPE_MAP: dict[int, Callable[[bytes], tuple[HEADER_VALUE, int]]] = { + # boolean_true + 0: DecodeUtils.unpack_true, + # boolean_false + 1: DecodeUtils.unpack_false, + # byte + 2: DecodeUtils.unpack_int8, + # short + 3: DecodeUtils.unpack_int16, + # integer + 4: DecodeUtils.unpack_int32, + # long + 5: DecodeUtils.unpack_int64, + # byte_array + 6: DecodeUtils.unpack_byte_array, + # string + 7: DecodeUtils.unpack_utf8_string, + # timestamp + 8: DecodeUtils.unpack_int64, + # uuid + 9: DecodeUtils.unpack_uuid, + } + + def __init__(self): + self._data: Any = None + + def parse(self, data: bytes) -> dict[str, HEADER_VALUE]: + """Parses the event stream headers from an event stream message. + + :param data: The bytes that correspond to the headers section of an event stream + message. + :returns: A dictionary of header key, value pairs. 
+ """ + self._data = data + return self._parse_headers() + + def _parse_headers(self) -> dict[str, HEADER_VALUE]: + headers: dict[str, HEADER_VALUE] = {} + while self._data: + name, value = self._parse_header() + if name in headers: + raise DuplicateHeader(name) + headers[name] = value + return headers + + def _parse_header(self) -> tuple[str, HEADER_VALUE]: + name = self._parse_name() + value = self._parse_value() + return name, value + + def _parse_name(self) -> str: + name, consumed = DecodeUtils.unpack_utf8_string(self._data, 1) + self._advance_data(consumed) + return name + + def _parse_type(self) -> int: + type, consumed = DecodeUtils.unpack_uint8(self._data) + self._advance_data(consumed) + return type + + def _parse_value(self) -> HEADER_VALUE: + header_type = self._parse_type() + value_unpacker = self._HEADER_TYPE_MAP[header_type] + value, consumed = value_unpacker(self._data) + self._advance_data(consumed) + return value + + def _advance_data(self, consumed: int): + self._data = self._data[consumed:] + + +class EventStreamBuffer: + """Streaming based event stream buffer. + + A buffer class that wraps bytes from an event stream providing parsed messages as + they become available via an iterable interface. + """ + + def __init__(self): + self._data: bytes = b"" + self._prelude: MessagePrelude | None = None + self._header_parser = EventStreamHeaderParser() + + def add_data(self, data: bytes): + """Add data to the buffer. + + :param data: The bytes to add to the buffer to be used when parsing. + """ + self._data += data + + def _validate_prelude(self, prelude: MessagePrelude): + if prelude.headers_length > _MAX_HEADERS_LENGTH: + raise InvalidHeadersLength(prelude.headers_length) + + if prelude.payload_length > _MAX_PAYLOAD_LENGTH: + raise InvalidPayloadLength(prelude.payload_length) + + def _parse_prelude(self) -> MessagePrelude: + prelude_bytes = self._data[:_PRELUDE_LENGTH] + raw_prelude, _ = DecodeUtils.unpack_prelude(prelude_bytes) + prelude = MessagePrelude(*raw_prelude) + self._validate_prelude(prelude) + + # The minus 4 removes the prelude crc from the bytes to be checked + _validate_checksum(prelude_bytes[: _PRELUDE_LENGTH - 4], prelude.crc) + return prelude + + def _parse_headers(self) -> dict[str, HEADER_VALUE]: + if not self._prelude: + raise ParserError("Attempted to parse headers with missing prelude.") + header_bytes = self._data[_PRELUDE_LENGTH : self._prelude.headers_end] + return self._header_parser.parse(header_bytes) + + def _parse_payload(self) -> bytes: + if not self._prelude: + raise ParserError("Attempted to parse payload with missing prelude.") + prelude = self._prelude + payload_bytes = self._data[prelude.headers_end : prelude.payload_end] + return payload_bytes + + def _parse_message_crc(self) -> int: + if not self._prelude: + raise ParserError("Attempted to parse crc with missing prelude.") + prelude = self._prelude + crc_bytes = self._data[prelude.payload_end : prelude.total_length] + message_crc, _ = DecodeUtils.unpack_uint32(crc_bytes) + return message_crc + + def _parse_message_bytes(self) -> bytes: + if not self._prelude: + raise ParserError("Attempted to parse message with missing prelude.") + # The minus 4 includes the prelude crc to the bytes to be checked + message_bytes = self._data[_PRELUDE_LENGTH - 4 : self._prelude.payload_end] + return message_bytes + + def _validate_message_crc(self) -> int: + if not self._prelude: + raise ParserError("Attempted to parse message with missing prelude.") + message_crc = self._parse_message_crc() + 
message_bytes = self._parse_message_bytes()
+        _validate_checksum(message_bytes, message_crc, crc=self._prelude.crc)
+        return message_crc
+
+    def _parse_message(self) -> EventStreamMessage:
+        if not self._prelude:
+            raise ParserError("Attempted to parse message with missing prelude.")
+        crc = self._validate_message_crc()
+        headers = self._parse_headers()
+        payload = self._parse_payload()
+        message = EventStreamMessage(self._prelude, headers, payload, crc)
+        self._prepare_for_next_message()
+        return message
+
+    def _prepare_for_next_message(self):
+        if not self._prelude:
+            raise ParserError("Attempted to parse message with missing prelude.")
+        # Advance the data and reset the current prelude
+        self._data = self._data[self._prelude.total_length :]
+        self._prelude = None
+
+    def next(self) -> EventStreamMessage:
+        """Provides the next available message parsed from the stream."""
+        if len(self._data) < _PRELUDE_LENGTH:
+            raise StopIteration()
+
+        if self._prelude is None:
+            self._prelude = self._parse_prelude()
+
+        if len(self._data) < self._prelude.total_length:
+            raise StopIteration()
+
+        return self._parse_message()
+
+    def __next__(self):
+        return self.next()
+
+    def __iter__(self):
+        return self
+
+
+class EventParser(Protocol):
+    def parse(self, event: EventStreamMessage) -> Any: ...
+
+
+class EventStream:
+    """Wrapper class for an event stream body.
+
+    This wraps the underlying streaming body, parsing it for individual events and
+    yielding them as they become available through the async iterator interface.
+    """
+
+    def __init__(self, raw_stream: Any, parser: EventParser):
+        self._raw_stream = raw_stream
+        self._parser = parser
+        self._event_generator: AsyncIterator[EventStreamMessage] = (
+            self._create_raw_event_generator()
+        )
+
+    async def __aiter__(self):
+        async for event in self._event_generator:
+            parsed_event = self._parser.parse(event)
+            yield parsed_event
+
+    async def _create_raw_event_generator(self) -> AsyncIterator[EventStreamMessage]:
+        event_stream_buffer = EventStreamBuffer()
+        async for chunk in self._raw_stream.chunks():
+            event_stream_buffer.add_data(chunk)
+            for event in event_stream_buffer:
+                yield event
diff --git a/python-packages/aws-event-stream/aws_event_stream/exceptions.py b/python-packages/aws-event-stream/aws_event_stream/exceptions.py
new file mode 100644
index 00000000..3d721c25
--- /dev/null
+++ b/python-packages/aws-event-stream/aws_event_stream/exceptions.py
@@ -0,0 +1,95 @@
+"""Exceptions for the application/vnd.amazon.eventstream format."""
+
+from typing import Any
+
+_MAX_HEADERS_LENGTH = 128 * 1024  # 128 KiB
+_MAX_PAYLOAD_LENGTH = 16 * 1024**2  # 16 MiB
+
+
+class ParserError(Exception):
+    """Base binary flow encoding parsing exception."""
+
+
+class DuplicateHeader(ParserError):
+    """Duplicate header found in the event."""
+
+    def __init__(self, header: str):
+        message = f'Duplicate header present: "{header}"'
+        super().__init__(message)
+
+
+class InvalidHeadersLength(ParserError):
+    """Headers length is longer than the maximum."""
+
+    def __init__(self, length: int):
+        message = (
+            f"Header length of {length} exceeded the maximum of {_MAX_HEADERS_LENGTH}"
+        )
+        super().__init__(message)
+
+
+class InvalidPayloadLength(ParserError):
+    """Payload length is longer than the maximum."""
+
+    def __init__(self, length: int):
+        message = (
+            f"Payload length of {length} exceeded the maximum of {_MAX_PAYLOAD_LENGTH}"
+        )
+        super().__init__(message)
+
+
+class ChecksumMismatch(ParserError):
+    """Calculated checksum did not match the
expected checksum.""" + + def __init__(self, expected: int, calculated: int): + message = f"Checksum mismatch: expected 0x{expected:08x}, calculated 0x{calculated:08x}" + super().__init__(message) + + +class NoInitialResponseError(ParserError): + """An event of type initial-response was not received. + + This exception is raised when the event stream produced no events or the first event + in the stream was not of the initial-response type. + """ + + def __init__(self): + message = "First event was not of the initial-response type" + super().__init__(message) + + +class SerializationError(Exception): + """Base binary flow encoding serialization exception.""" + + +class InvalidHeaderValue(SerializationError): + def __init__(self, value: Any): + message = f"Invalid header value type: {type(value)}" + super().__init__(message) + self.value = value + + +class HeaderBytesExceedMaxLength(SerializationError): + def __init__(self, length: int): + message = ( + f"Headers exceeded max serialization " + f"length of 128 KiB at {length} bytes" + ) + super().__init__(message) + + +class HeaderValueBytesExceedMaxLength(SerializationError): + def __init__(self, length: int): + message = ( + f"Header bytes value exceeds max serialization " + f"length of (32 KiB - 1) at {length} bytes" + ) + super().__init__(message) + + +class PayloadBytesExceedMaxLength(SerializationError): + def __init__(self, length: int): + message = ( + f"Payload exceeded max serialization " f"length of 16 MiB at {length} bytes" + ) + super().__init__(message) diff --git a/python-packages/aws-event-stream/aws_event_stream/py.typed b/python-packages/aws-event-stream/aws_event_stream/py.typed new file mode 100644 index 00000000..e69de29b diff --git a/python-packages/aws-event-stream/aws_event_stream/structures.py b/python-packages/aws-event-stream/aws_event_stream/structures.py new file mode 100644 index 00000000..f55789ea --- /dev/null +++ b/python-packages/aws-event-stream/aws_event_stream/structures.py @@ -0,0 +1,92 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+# SPDX-License-Identifier: Apache-2.0
+
+
+from io import BufferedIOBase
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    # We need to import this from _typeshed as this is not publicly exposed and
+    # would otherwise require us to redefine this type to subclass
+    # BufferedIOBase
+    from _typeshed import ReadableBuffer, WriteableBuffer
+
+
+class BufferableByteStream(BufferedIOBase):
+    """BufferableByteStream will always be in non-blocking mode."""
+
+    def __init__(self):
+        self._byte_chunks: list[bytes] = []
+        self.__done: bool = False
+        self.__closed: bool = False
+
+    def read(self, size: int | None = -1) -> bytes | None:  # type: ignore
+        if len(self._byte_chunks) < 1 and not self.__done:
+            raise BlockingIOError("read")
+        elif (self.__done and not self._byte_chunks) or self.closed:
+            return b""
+
+        temp_bytes = self._byte_chunks.pop(0)
+        remaining_bytes = b""
+        if size == -1 or size is None:
+            return temp_bytes
+        elif size > 0:
+            remaining_bytes = temp_bytes[size:]
+            temp_bytes = temp_bytes[:size]
+        else:
+            remaining_bytes = temp_bytes
+            temp_bytes = b""
+
+        if len(remaining_bytes) > 0:
+            self._byte_chunks.insert(0, remaining_bytes)
+        return temp_bytes
+
+    def read1(self, size: int = -1) -> bytes | None:  # type: ignore
+        return self.read(size)
+
+    def readinto(self, b: "WriteableBuffer", read1: bool = False) -> int:
+        if not isinstance(b, memoryview):
+            b = memoryview(b)
+        b = b.cast("B")
+
+        if read1:
+            data = self.read1(len(b))
+        else:
+            data = self.read(len(b))
+
+        if data is None:
+            raise BlockingIOError("readinto")
+
+        n = len(data)
+
+        b[:n] = data
+
+        return n
+
+    def write(self, b: "ReadableBuffer") -> int:
+        if not isinstance(b, bytes):
+            type_ = type(b)
+            raise ValueError(
+                f"Unexpected value written to BufferableByteStream. "
+                f"Only bytes are supported, but {type_} was provided."
+            )
+
+        if self.closed or self.__done:
+            raise OSError("Stream is completed and doesn't support further writes.")
+
+        if b:
+            self._byte_chunks.append(b)
+
+        return len(b)
+
+    @property
+    def closed(self) -> bool:
+        return self.__closed
+
+    def close(self):
+        # Discard any buffered chunks; read will return b"" once closed.
+        self._byte_chunks = []
+        self.__done = True
+        self.__closed = True
+
+    def end_stream(self):
+        self.__done = True
diff --git a/python-packages/aws-event-stream/pyproject.toml b/python-packages/aws-event-stream/pyproject.toml
new file mode 100644
index 00000000..9e31443d
--- /dev/null
+++ b/python-packages/aws-event-stream/pyproject.toml
@@ -0,0 +1,46 @@
+[build-system]
+requires = ["setuptools", "setuptools-scm", "wheel"]
+build-backend = "setuptools.build_meta"
+
+[project]
+name = "aws_event_stream"
+version = "0.0.1"
+description = "AWS event stream library for Smithy defined services in Python."
+readme = "README.md" +authors = [{name = "Amazon Web Services"}] +keywords = ["aws", "python", "sdk", "amazon", "smithy", "codegen", "http"] +requires-python = ">=3.12" +license = {text = "Apache License 2.0"} +classifiers = [ + "Development Status :: 2 - Pre-Alpha", + "Intended Audience :: Developers", + "Intended Audience :: System Administrators", + "Natural Language :: English", + "License :: OSI Approved :: Apache Software License", + "Operating System :: OS Independent", + "Programming Language :: Python", + "Programming Language :: Python :: 3 :: Only", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: Implementation :: CPython", + "Topic :: Software Development :: Libraries" +] + +[project.urls] +source = "https://github.com/awslabs/smithy-python/tree/develop/python-packages/aws-event-stream" +changelog = "https://github.com/awslabs/smithy-python/blob/develop/CHANGES.md" + +[tool.setuptools] +license-files = ["NOTICE"] +include-package-data = true + +[tool.setuptools.packages.find] +exclude=["tests*", "codegen", "designs"] + +[tool.isort] +profile = "black" +honor_noqa = true +src_paths = ["aws_event_stream", "tests"] + +[tool.pytest.ini_options] +asyncio_mode = "auto" diff --git a/python-packages/aws-event-stream/requirements.txt b/python-packages/aws-event-stream/requirements.txt new file mode 100644 index 00000000..baedc420 --- /dev/null +++ b/python-packages/aws-event-stream/requirements.txt @@ -0,0 +1,3 @@ +# We shouldn't need this, but pants will assume that smithy_core is an external +# dependency since it's in pyproject.toml and there's no way to exclude it, so +# for now we need to duplicate things. diff --git a/python-packages/aws-event-stream/tests/BUILD b/python-packages/aws-event-stream/tests/BUILD new file mode 100644 index 00000000..b80471f7 --- /dev/null +++ b/python-packages/aws-event-stream/tests/BUILD @@ -0,0 +1,20 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +python_test_utils( + name="test_utils", + sources=[ + "**/conftest.py", # pytest's conftest.py file + ], +) + +resource(name="pytyped", source="py.typed") + +python_tests( + name="tests", + dependencies=[":test_utils", ":pytyped"], + sources=[ + "**/test_*.py", + "**/tests.py", + ], +) diff --git a/python-packages/aws-event-stream/tests/__init__.py b/python-packages/aws-event-stream/tests/__init__.py new file mode 100644 index 00000000..04f8b7b7 --- /dev/null +++ b/python-packages/aws-event-stream/tests/__init__.py @@ -0,0 +1,2 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 diff --git a/python-packages/aws-event-stream/tests/py.typed b/python-packages/aws-event-stream/tests/py.typed new file mode 100644 index 00000000..e69de29b diff --git a/python-packages/aws-event-stream/tests/unit/__init__.py b/python-packages/aws-event-stream/tests/unit/__init__.py new file mode 100644 index 00000000..04f8b7b7 --- /dev/null +++ b/python-packages/aws-event-stream/tests/unit/__init__.py @@ -0,0 +1,2 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+# SPDX-License-Identifier: Apache-2.0 diff --git a/python-packages/aws-event-stream/tests/unit/test_eventstream.py b/python-packages/aws-event-stream/tests/unit/test_eventstream.py new file mode 100644 index 00000000..447bd1bd --- /dev/null +++ b/python-packages/aws-event-stream/tests/unit/test_eventstream.py @@ -0,0 +1,660 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +"""Unit tests for the binary event stream decoder.""" + +import datetime +import uuid +from typing import Any +from unittest.mock import Mock + +import pytest + +from aws_event_stream.auth import Credentials, EventSigner +from aws_event_stream.eventstream import ( + HEADERS_SERIALIZATION_DICT, + ChecksumMismatch, + DecodeUtils, + DuplicateHeader, + EventParser, + EventStream, + EventStreamBuffer, + EventStreamHeaderParser, + EventStreamMessage, + EventStreamMessageSerializer, + HeaderBytesExceedMaxLength, + HeaderValueBytesExceedMaxLength, + Int8HeaderValue, + Int16HeaderValue, + Int32HeaderValue, + Int64HeaderValue, + InvalidHeadersLength, + InvalidHeaderValue, + InvalidPayloadLength, + MessagePrelude, + PayloadBytesExceedMaxLength, +) + +EMPTY_MESSAGE = ( + b"\x00\x00\x00\x10\x00\x00\x00\x00\x05\xc2H\xeb}\x98\xc8\xff", + EventStreamMessage( + prelude=MessagePrelude( + total_length=0x10, + headers_length=0, + crc=0x05C248EB, + ), + headers={}, + payload=b"", + crc=0x7D98C8FF, + ), +) + +INT8_HEADER = ( + (b"\x00\x00\x00\x17\x00\x00\x00\x07)\x86\x01X\x04" b"byte\x02\xff\xc2\xf8i\xdc"), + EventStreamMessage( + prelude=MessagePrelude( + total_length=0x17, + headers_length=0x7, + crc=0x29860158, + ), + headers={"byte": -1}, + payload=b"", + crc=0xC2F869DC, + ), +) + +INT16_HEADER = ( + (b"\x00\x00\x00\x19\x00\x00\x00\tq\x0e\x92>\x05" b"short\x03\xff\xff\xb2|\xb6\xcc"), + EventStreamMessage( + prelude=MessagePrelude( + total_length=0x19, + headers_length=0x9, + crc=0x710E923E, + ), + headers={"short": -1}, + payload=b"", + crc=0xB27CB6CC, + ), +) + +INT32_HEADER = ( + ( + b"\x00\x00\x00\x1d\x00\x00\x00\r\x83\xe3\xf0\xe7\x07" + b"integer\x04\xff\xff\xff\xff\x8b\x8e\x12\xeb" + ), + EventStreamMessage( + prelude=MessagePrelude( + total_length=0x1D, + headers_length=0xD, + crc=0x83E3F0E7, + ), + headers={"integer": -1}, + payload=b"", + crc=0x8B8E12EB, + ), +) + +INT64_HEADER = ( + ( + b"\x00\x00\x00\x1e\x00\x00\x00\x0e]J\xdb\x8d\x04" + b"long\x05\xff\xff\xff\xff\xff\xff\xff\xffK\xc22\xda" + ), + EventStreamMessage( + prelude=MessagePrelude( + total_length=0x1E, + headers_length=0xE, + crc=0x5D4ADB8D, + ), + headers={"long": -1}, + payload=b"", + crc=0x4BC232DA, + ), +) + +PAYLOAD_NO_HEADERS = ( + b"\x00\x00\x00\x1d\x00\x00\x00\x00\xfdR\x8cZ{'foo':'bar'}\xc3e96", + EventStreamMessage( + prelude=MessagePrelude( + total_length=0x1D, + headers_length=0, + crc=0xFD528C5A, + ), + headers={}, + payload=b"{'foo':'bar'}", + crc=0xC3653936, + ), +) + +PAYLOAD_ONE_STR_HEADER = ( + ( + b"\x00\x00\x00=\x00\x00\x00 \x07\xfd\x83\x96\x0ccontent-type\x07\x00\x10" + b"application/json{'foo':'bar'}\x8d\x9c\x08\xb1" + ), + EventStreamMessage( + prelude=MessagePrelude( + total_length=0x3D, + headers_length=0x20, + crc=0x07FD8396, + ), + headers={"content-type": "application/json"}, + payload=b"{'foo':'bar'}", + crc=0x8D9C08B1, + ), +) + +ALL_HEADERS_TYPES = ( + ( + b"\x00\x00\x00\x62\x00\x00\x00\x52\x03\xb5\xcb\x9c" + b"\x010\x00\x011\x01\x012\x02\x02\x013\x03\x00\x03" + b"\x014\x04\x00\x00\x00\x04\x015\x05\x00\x00\x00\x00\x00\x00\x00\x05" + 
b"\x016\x06\x00\x05bytes\x017\x07\x00\x04utf8" + b"\x018\x08\x00\x00\x00\x00\x00\x00\x00\x08\x019\x090123456789abcdef" + b"\x63\x35\x36\x71" + ), + EventStreamMessage( + prelude=MessagePrelude( + total_length=0x62, + headers_length=0x52, + crc=0x03B5CB9C, + ), + headers={ + "0": True, + "1": False, + "2": 0x02, + "3": 0x03, + "4": 0x04, + "5": 0x05, + "6": b"bytes", + "7": "utf8", + "8": 0x08, + "9": b"0123456789abcdef", + }, + payload=b"", + crc=0x63353671, + ), +) + +ERROR_EVENT_MESSAGE = ( + ( + b"\x00\x00\x00\x52\x00\x00\x00\x42\xbf\x23\x63\x7e" + b"\x0d:message-type\x07\x00\x05error" + b"\x0b:error-code\x07\x00\x04code" + b"\x0e:error-message\x07\x00\x07message" + b"\x6b\x6c\xea\x3d" + ), + EventStreamMessage( + prelude=MessagePrelude( + total_length=0x52, + headers_length=0x42, + crc=0xBF23637E, + ), + headers={ + ":message-type": "error", + ":error-code": "code", + ":error-message": "message", + }, + payload=b"", + crc=0x6B6CEA3D, + ), +) + +# Tuples of encoded messages and their expected decoded output +POSITIVE_CASES = [ + EMPTY_MESSAGE, + INT8_HEADER, + INT16_HEADER, + INT32_HEADER, + INT64_HEADER, + PAYLOAD_NO_HEADERS, + PAYLOAD_ONE_STR_HEADER, + ALL_HEADERS_TYPES, + ERROR_EVENT_MESSAGE, +] + +CORRUPTED_HEADER_LENGTH = ( + ( + b"\x00\x00\x00=\xFF\x00\x01\x02\x07\xfd\x83\x96\x0ccontent-type\x07\x00" + b"\x10application/json{'foo':'bar'}\x8d\x9c\x08\xb1" + ), + InvalidHeadersLength, +) + +CORRUPTED_HEADERS = ( + ( + b"\x00\x00\x00=\x00\x00\x00 \x07\xfd\x83\x96\x0ccontent+type\x07\x00\x10" + b"application/json{'foo':'bar'}\x8d\x9c\x08\xb1" + ), + ChecksumMismatch, +) + +CORRUPTED_LENGTH = ( + b"\x01\x00\x00\x1d\x00\x00\x00\x00\xfdR\x8cZ{'foo':'bar'}\xc3e96", + InvalidPayloadLength, +) + +CORRUPTED_PAYLOAD = ( + b"\x00\x00\x00\x1d\x00\x00\x00\x00\xfdR\x8cZ{'foo':'bar'\x8d\xc3e96", + ChecksumMismatch, +) + +DUPLICATE_HEADER = ( + ( + b"\x00\x00\x00\x24\x00\x00\x00\x14\x4b\xb9\x82\xd0" + b"\x04test\x04asdf\x04test\x04asdf\xf3\xf4\x75\x63" + ), + DuplicateHeader, +) + +# Tuples of encoded messages and their expected exception +NEGATIVE_CASES = [ + CORRUPTED_LENGTH, + CORRUPTED_PAYLOAD, + CORRUPTED_HEADERS, + CORRUPTED_HEADER_LENGTH, + DUPLICATE_HEADER, +] + + +class IdentityParser(EventParser): + def parse(self, event: EventStreamMessage) -> Any: + return event + + +def assert_message_equal(message_a: EventStreamMessage, message_b: EventStreamMessage): + """Asserts all fields for two messages are equal.""" + assert message_a.prelude.total_length == message_b.prelude.total_length + assert message_a.prelude.headers_length == message_b.prelude.headers_length + assert message_a.prelude.crc == message_b.prelude.crc + assert message_a.headers == message_b.headers + assert message_a.payload == message_b.payload + assert message_a.crc == message_b.crc + + +def test_partial_message(): + """Ensure that we can receive partial payloads.""" + data = EMPTY_MESSAGE[0] + event_buffer = EventStreamBuffer() + # This mid point is an arbitrary break in the middle of the headers + mid_point = 15 + event_buffer.add_data(data[:mid_point]) + messages = list(event_buffer) + assert messages == [] + + event_buffer.add_data(data[mid_point : len(data)]) + for message in event_buffer: + assert_message_equal(message, EMPTY_MESSAGE[1]) + + +def check_message_decodes(encoded: bytes, decoded: EventStreamMessage): + """Ensure the message decodes to what we expect.""" + event_buffer = EventStreamBuffer() + event_buffer.add_data(encoded) + messages = list(event_buffer) + assert len(messages) == 1 + 
assert_message_equal(messages[0], decoded) + + +@pytest.mark.parametrize("encoded,decoded", POSITIVE_CASES) +def test_positive_cases(encoded: bytes, decoded: EventStreamMessage): + """Test that all positive cases decode how we expect.""" + check_message_decodes(encoded, decoded) + + +def test_all_positive_cases(): + """Test all positive cases can be decoded on the same buffer.""" + event_buffer = EventStreamBuffer() + # add all positive test cases to the same buffer + for encoded, _ in POSITIVE_CASES: + event_buffer.add_data(encoded) + # collect all of the expected messages + expected_messages = [decoded for (_, decoded) in POSITIVE_CASES] + # collect all of the decoded messages + decoded_messages = list(event_buffer) + # assert all messages match what we expect + for expected, decoded in zip(expected_messages, decoded_messages): + assert_message_equal(expected, decoded) + + +@pytest.mark.parametrize("encoded,exception", NEGATIVE_CASES) +def test_negative_cases(encoded: bytes, exception: type[Exception]): + """Test that all negative cases raise the expected exception.""" + with pytest.raises(exception): + check_message_decodes(encoded, None) # type: ignore + + +def test_header_parser(): + """Test that the header parser supports all header types.""" + headers_data = ( + b"\x010\x00\x011\x01\x012\x02\x02\x013\x03\x00\x03" + b"\x014\x04\x00\x00\x00\x04\x015\x05\x00\x00\x00\x00\x00\x00\x00\x05" + b"\x016\x06\x00\x05bytes\x017\x07\x00\x04utf8" + b"\x018\x08\x00\x00\x00\x00\x00\x00\x00\x08\x019\x090123456789abcdef" + ) + + expected_headers = { + "0": True, + "1": False, + "2": 0x02, + "3": 0x03, + "4": 0x04, + "5": 0x05, + "6": b"bytes", + "7": "utf8", + "8": 0x08, + "9": b"0123456789abcdef", + } + + parser = EventStreamHeaderParser() + headers = parser.parse(headers_data) + assert headers == expected_headers + + +def test_message_prelude_properties(): + """Test that calculated properties from the payload are correct.""" + # Total length: 40, Headers Length: 15, random crc + prelude = MessagePrelude(40, 15, 0x00000000) + assert prelude.payload_length == 9 + assert prelude.headers_end == 27 + assert prelude.payload_end == 36 + + +def test_message_to_response_dict(): + response_dict = PAYLOAD_ONE_STR_HEADER[1].to_response_dict() + assert response_dict["status_code"] == 200 + expected_headers = {"content-type": "application/json"} + assert response_dict["headers"] == expected_headers + assert response_dict["body"] == b"{'foo':'bar'}" + + +def test_message_to_response_dict_error(): + response_dict = ERROR_EVENT_MESSAGE[1].to_response_dict() + assert response_dict["status_code"] == 400 + headers = { + ":message-type": "error", + ":error-code": "code", + ":error-message": "message", + } + assert response_dict["headers"] == headers + assert response_dict["body"] == b"" + + +def test_unpack_uint8(): + (value, bytes_consumed) = DecodeUtils.unpack_uint8(b"\xDE") + assert bytes_consumed == 1 + assert value == 0xDE + + +def test_unpack_uint32(): + (value, bytes_consumed) = DecodeUtils.unpack_uint32(b"\xDE\xAD\xBE\xEF") + assert bytes_consumed == 4 + assert value == 0xDEADBEEF + + +def test_unpack_int8(): + (value, bytes_consumed) = DecodeUtils.unpack_int8(b"\xFE") + assert bytes_consumed == 1 + assert value == -2 + + +def test_unpack_int16(): + (value, bytes_consumed) = DecodeUtils.unpack_int16(b"\xFF\xFE") + assert bytes_consumed == 2 + assert value == -2 + + +def test_unpack_int32(): + (value, bytes_consumed) = DecodeUtils.unpack_int32(b"\xFF\xFF\xFF\xFE") + assert bytes_consumed == 4 + assert value == -2 
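+
+
+def test_int32_header_wire_layout():
+    # A hedged, illustrative sketch (this hypothetical test is not derived
+    # from botocore): a header is framed as a 1-byte name length, the UTF-8
+    # name, a 1-byte type tag (4 = int32), then a big-endian, type-dependent
+    # value. It mirrors the serializer's _encode_header_key and
+    # _encode_header_val pack formats and the INT32_HEADER case above.
+    from struct import pack
+
+    name = "integer".encode("utf-8")
+    encoded = pack("B", len(name)) + name + b"\x04" + pack("!i", -1)
+    assert encoded == b"\x07integer\x04\xff\xff\xff\xff"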
+ + +def test_unpack_int64(): + test_bytes = b"\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFE" + (value, bytes_consumed) = DecodeUtils.unpack_int64(test_bytes) + assert bytes_consumed == 8 + assert value == -2 + + +def test_unpack_array_short(): + test_bytes = b"\x00\x10application/json" + (value, bytes_consumed) = DecodeUtils.unpack_byte_array(test_bytes) + assert bytes_consumed == 18 + assert value == b"application/json" + + +def test_unpack_byte_array_int(): + (value, array_bytes_consumed) = DecodeUtils.unpack_byte_array( + b"\x00\x00\x00\x10application/json", length_byte_size=4 + ) + assert array_bytes_consumed == 20 + assert value == b"application/json" + + +def test_unpack_utf8_string(): + length = b"\x00\x09" + utf8_string = b"\xe6\x97\xa5\xe6\x9c\xac\xe8\xaa\x9e" + encoded = length + utf8_string + (value, bytes_consumed) = DecodeUtils.unpack_utf8_string(encoded) + assert bytes_consumed == 11 + assert value == utf8_string.decode("utf-8") + + +def test_unpack_prelude(): + data = b"\x00\x00\x00\x01\x00\x00\x00\x02\x00\x00\x00\x03" + prelude = DecodeUtils.unpack_prelude(data) + assert prelude == ((1, 2, 3), 12) + + +def create_mock_raw_stream(*data: bytes): + raw_stream = Mock() + + async def chunks(): + for chunk in data: + yield chunk + yield b"" + + raw_stream.chunks = chunks + return raw_stream + + +@pytest.mark.asyncio +async def test_event_stream_wrapper_iteration(): + raw_stream = create_mock_raw_stream( + b"\x00\x00\x00+\x00\x00\x00\x0e4\x8b\xec{\x08event-id\x04\x00", + b"\x00\xa0\x0c{'foo':'bar'}\xd3\x89\x02\x85", + ) + + parser = IdentityParser() + event_stream = EventStream(raw_stream, parser) + events: list[EventStreamMessage] = [] + async for event in event_stream: + events.append(event) + assert len(events) == 1 + event = events[0] + + assert event.headers == {"event-id": 0x0000A00C} + assert event.payload == b"{'foo':'bar'}" + + +SERIALIZATION_CASES: list[tuple[bytes, dict[str, Any], bytes]] = [ + # Empty headers and empty payload + (b"\x00\x00\x00\x10\x00\x00\x00\x00\x05\xc2H\xeb}\x98\xc8\xff", {}, b""), + # Empty headers with payload + ( + b"\x00\x00\x00\x1c\x00\x00\x00\x00\xc02\xa5\xeatest payload\x076E\xf9", + {}, + b"test payload", + ), + # Header true value, type 0 + ( + b"\x00\x00\x00\x16\x00\x00\x00\x06c\xe1\x18~\x04true\x00\xf1\xe7\xbc\xd7", + {"true": True}, + b"", + ), + # Header false value, type 1 + ( + b"\x00\x00\x00\x17\x00\x00\x00\x07)\x86\x01X\x05false\x01R1~\xf4", + {"false": False}, + b"", + ), + # Header byte, type 2 + ( + b"\x00\x00\x00\x17\x00\x00\x00\x07)\x86\x01X\x04byte\x02\xff\xc2\xf8i\xdc", + {"byte": Int8HeaderValue(-1)}, + b"", + ), + # Header short, type 3 + ( + b"\x00\x00\x00\x19\x00\x00\x00\tq\x0e\x92>\x05short\x03\xff\xff\xb2|\xb6\xcc", + {"short": Int16HeaderValue(-1)}, + b"", + ), + # Header integer, type 4 + ( + b"\x00\x00\x00\x1d\x00\x00\x00\r\x83\xe3\xf0\xe7\x07integer\x04\xff\xff\xff\xff\x8b\x8e\x12\xeb", + {"integer": Int32HeaderValue(-1)}, + b"", + ), + # Header integer, by default integers will be serialized as 32bits + ( + b"\x00\x00\x00\x1d\x00\x00\x00\r\x83\xe3\xf0\xe7\x07integer\x04\xff\xff\xff\xff\x8b\x8e\x12\xeb", + {"integer": -1}, + b"", + ), + # Header long, type 5 + ( + b"\x00\x00\x00\x1e\x00\x00\x00\x0e]J\xdb\x8d\x04long\x05\xff\xff\xff\xff\xff\xff\xff\xffK\xc22\xda", + {"long": Int64HeaderValue(-1)}, + b"", + ), + # Header bytes, type 6 + ( + b"\x00\x00\x00\x1d\x00\x00\x00\r\x83\xe3\xf0\xe7\x05bytes\x06\x00\x04\xde\xad\xbe\xef\x9a\xabK ", + {"bytes": b"\xde\xad\xbe\xef"}, + b"", + ), + # Header string, type 7 + ( + b"\x00\x00\x00 
\x00\x00\x00\x10\xb9T\xe0\t\x06string\x07\x00\x06foobarL\xc53(",
+        {"string": "foobar"},
+        b"",
+    ),
+    # Header timestamp, type 8
+    (
+        b"\x00\x00\x00#\x00\x00\x00\x13g\xfd\xcbc\ttimestamp\x08\x00\x00\x01r\xee\xbc'\xa6\xd4D^\x11",
+        {
+            "timestamp": datetime.datetime(
+                2020,
+                6,
+                26,
+                hour=3,
+                minute=46,
+                second=47,
+                microsecond=846000,
+                tzinfo=datetime.UTC,
+            )
+        },
+        b"",
+    ),
+    # Header UUID, type 9
+    (
+        b"\x00\x00\x00&\x00\x00\x00\x16\xdfw\xb0\x9c\x04uuid\t\xde\xad\xbe\xef\xde\xad\xbe\xef\xde\xad\xbe\xef\xde\xad\xbe\xef\xb1g\xd4{",
+        {"uuid": uuid.UUID("deadbeef-dead-beef-dead-beefdeadbeef")},
+        b"",
+    ),
+]
+
+
+class TestEventStreamMessageSerializer:
+    @pytest.fixture
+    def serializer(self):
+        return EventStreamMessageSerializer()
+
+    @pytest.mark.parametrize("expected, headers, payload", SERIALIZATION_CASES)
+    def test_serialized_message(
+        self,
+        serializer: EventStreamMessageSerializer,
+        expected: bytes,
+        headers: HEADERS_SERIALIZATION_DICT,
+        payload: bytes,
+    ):
+        serialized = serializer.serialize(headers, payload)
+        assert expected == serialized
+
+    def test_encode_headers(self, serializer: EventStreamMessageSerializer):
+        headers = {"foo": "bar"}
+        encoded_headers = serializer.encode_headers(headers)
+        assert b"\x03foo\x07\x00\x03bar" == encoded_headers
+
+    def test_invalid_header_value(self, serializer: EventStreamMessageSerializer):
+        # Floats are not a supported header value type
+        headers = {
+            "foo": 2.0,
+        }
+        with pytest.raises(InvalidHeaderValue):
+            serializer.serialize(headers, b"")  # type: ignore
+
+    def test_header_str_too_long(self, serializer: EventStreamMessageSerializer):
+        # Str header value lengths are stored in a uint16, but cannot be
+        # larger than 2 ** 15 - 1
+        headers = {
+            "foo": "a" * (2**16 - 1),
+        }
+        with pytest.raises(HeaderValueBytesExceedMaxLength):
+            serializer.serialize(headers, b"")
+
+    def test_header_bytes_too_long(self, serializer: EventStreamMessageSerializer):
+        # Bytes header value lengths are stored in a uint16, but cannot be
+        # larger than 2 ** 15 - 1
+        headers = {
+            "foo": b"a" * (2**16 - 1),
+        }
+        with pytest.raises(HeaderValueBytesExceedMaxLength):
+            serializer.serialize(headers, b"")
+
+    def test_headers_too_long(self, serializer: EventStreamMessageSerializer):
+        # These headers are roughly 150k bytes, more than the 128 KiB max
+        long_header_value = b"a" * 30000
+        headers = {
+            "a": long_header_value,
+            "b": long_header_value,
+            "c": long_header_value,
+            "d": long_header_value,
+            "e": long_header_value,
+        }
+        with pytest.raises(HeaderBytesExceedMaxLength):
+            serializer.serialize(headers, b"")
+
+    def test_payload_too_long(self, serializer: EventStreamMessageSerializer):
+        # 18 MiB payload, larger than the max of 16 MiB
+        payload = b"abcdefghijklmnopqr" * (1024**2)
+        with pytest.raises(PayloadBytesExceedMaxLength):
+            serializer.serialize({}, payload)
+
+
+class TestEventSigner:
+    @pytest.fixture
+    def credentials(self):
+        return Credentials("foo", "bar", None)
+
+    @pytest.fixture
+    def event_signer(self):
+        return EventSigner(
+            "signing-name",
+            "region-name",
+            utc_now=self.utc_now,
+        )
+
+    def utc_now(self):
+        return datetime.datetime(2020, 7, 23, 22, 39, 55, 29943, tzinfo=datetime.UTC)
+
+    def test_basic_event_signature(
+        self, event_signer: EventSigner, credentials: Credentials
+    ):
+        signed_headers = event_signer.sign(b"message", b"prior", credentials)
+        assert signed_headers[":date"] == self.utc_now()
+        expected_signature = (
+            b"\x0e\xf5n\xbf\x8cW\x0b>\xf3\xdc\x9fA\x99^\xd17\xcd"
b"\x86\x9c\xdb\xa0Y\x18\x88+\x9b\x10p{n$e" + ) + assert signed_headers[":chunk-signature"] == expected_signature