Skip to content

Commit

Permalink
Handle empty and incomplete event bytes
Browse files Browse the repository at this point in the history
This adds explicit handling for both empty and incomplete event
bytes. If nothing is able to be read from the source, event decoders
will return None. If there are bytes there, but they're truncated,
then an explicit error is thrown that wraps what would otherwise
be a `struct.error`. This is only applied for truncations that would
not already be caught by checksum validation.
  • Loading branch information
JordonPhillips committed Nov 26, 2024
1 parent 2848899 commit d700b24
Show file tree
Hide file tree
Showing 5 changed files with 89 additions and 47 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ def __init__(

async def receive(self) -> E | None:
event = await Event.decode_async(self._source)
if event is None:
return None

deserializer = EventDeserializer(
event=event,
payload_codec=self._payload_codec,
Expand Down
49 changes: 39 additions & 10 deletions python-packages/aws-event-stream/aws_event_stream/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from .exceptions import (
ChecksumMismatch,
DuplicateHeader,
InvalidEventBytes,
InvalidHeadersLength,
InvalidHeaderValue,
InvalidHeaderValueLength,
Expand Down Expand Up @@ -286,27 +287,41 @@ class Event:
"""

@classmethod
def decode(cls, source: BytesReader) -> Self:
def decode(cls, source: BytesReader) -> Self | None:
"""Decode an event from a byte stream.
:param source: An object to read event bytes from. It must have a `read` method
that accepts a number of bytes to read.
:returns: An Event representing the next event on the source.
:returns: An Event representing the next event on the source, or None if no
data can be read from the source.
"""

prelude_bytes = source.read(8)
if not prelude_bytes:
# If nothing can be read from the source, return None. If bytes are missing
# later, that indicates a problem with the source and therefore will result
# in an exception.
return None

prelude_crc_bytes = source.read(4)
prelude_crc: int = _DecodeUtils.unpack_uint32(prelude_crc_bytes)[0]
try:
prelude_crc: int = _DecodeUtils.unpack_uint32(prelude_crc_bytes)[0]
total_length, headers_length = unpack("!II", prelude_bytes)
except struct.error as e:
raise InvalidEventBytes() from e

total_length, headers_length = unpack("!II", prelude_bytes)
_validate_checksum(prelude_bytes, prelude_crc)
prelude = EventPrelude(
total_length=total_length, headers_length=headers_length, crc=prelude_crc
)

message_bytes = source.read(total_length - _MESSAGE_METADATA_SIZE)
crc: int = _DecodeUtils.unpack_uint32(source.read(4))[0]
try:
crc: int = _DecodeUtils.unpack_uint32(source.read(4))[0]
except struct.error as e:
raise InvalidEventBytes() from e

_validate_checksum(prelude_crc_bytes + message_bytes, crc, prelude_crc)

message = EventMessage(
Expand All @@ -316,27 +331,41 @@ def decode(cls, source: BytesReader) -> Self:
return cls(prelude, message, crc)

@classmethod
async def decode_async(cls, source: AsyncByteStream) -> Self:
async def decode_async(cls, source: AsyncByteStream) -> Self | None:
"""Decode an event from an async byte stream.
:param source: An object to read event bytes from. It must have a `read` method
that accepts a number of bytes to read.
:returns: An Event representing the next event on the source.
:returns: An Event representing the next event on the source, or None if no
data can be read from the source.
"""

prelude_bytes = await source.read(8)
if not prelude_bytes:
# If nothing can be read from the source, return None. If bytes are missing
# later, that indicates a problem with the source and therefore will result
# in an exception.
return None

prelude_crc_bytes = await source.read(4)
prelude_crc: int = _DecodeUtils.unpack_uint32(prelude_crc_bytes)[0]
try:
prelude_crc: int = _DecodeUtils.unpack_uint32(prelude_crc_bytes)[0]
total_length, headers_length = unpack("!II", prelude_bytes)
except struct.error as e:
raise InvalidEventBytes() from e

total_length, headers_length = unpack("!II", prelude_bytes)
_validate_checksum(prelude_bytes, prelude_crc)
prelude = EventPrelude(
total_length=total_length, headers_length=headers_length, crc=prelude_crc
)

message_bytes = await source.read(total_length - _MESSAGE_METADATA_SIZE)
crc: int = _DecodeUtils.unpack_uint32(await source.read(4))[0]
try:
crc: int = _DecodeUtils.unpack_uint32(await source.read(4))[0]
except struct.error as e:
raise InvalidEventBytes() from e

_validate_checksum(prelude_crc_bytes + message_bytes, crc, prelude_crc)

message = EventMessage(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,12 @@ def __init__(self, size: str, value: int):
super().__init__(message)


class InvalidEventBytes(EventError):
    """Error for event bytes that cannot be unpacked into an event.

    The event decoders raise this in place of the underlying ``struct.error``
    when the source yields some bytes but too few to unpack.
    """

    def __init__(self) -> None:
        super().__init__("Invalid event bytes.")


class MissingInitialResponse(EventError):
    """Error for when an expected initial response was not found."""

    def __init__(self) -> None:
        message = "Expected an initial response, but none was found."
        super().__init__(message)
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
@pytest.mark.parametrize("expected,given", EVENT_STREAM_SERDE_CASES)
def test_event_deserializer(expected: DeserializeableShape, given: EventMessage):
    """Decode each encoded serde case and check it deserializes to the expected shape."""
    decoded_event = Event.decode(BytesIO(given.encode()))
    # Event.decode returns None when nothing can be read from the source, so
    # narrow the optional before handing the event to the deserializer.
    assert decoded_event is not None
    event_deserializer = EventDeserializer(
        event=decoded_event, payload_codec=JSONCodec()
    )
    actual = EventStreamDeserializer().deserialize(event_deserializer)
    assert actual == expected
Expand All @@ -30,6 +31,7 @@ def test_event_deserializer(expected: DeserializeableShape, given: EventMessage)
def test_deserialize_initial_request():
    """Decode the initial-request case and check it deserializes to the expected shape."""
    want, message = INITIAL_REQUEST_CASE
    decoded_event = Event.decode(BytesIO(message.encode()))
    # Event.decode returns None when nothing can be read from the source, so
    # narrow the optional before handing the event to the deserializer.
    assert decoded_event is not None
    event_deserializer = EventDeserializer(
        event=decoded_event, payload_codec=JSONCodec()
    )
    actual = EventStreamOperationInputOutput.deserialize(event_deserializer)
    assert actual == want
Expand All @@ -38,6 +40,7 @@ def test_deserialize_initial_request():
def test_deserialize_initial_response():
    """Decode the initial-response case and check it deserializes to the expected shape."""
    want, message = INITIAL_RESPONSE_CASE
    decoded_event = Event.decode(BytesIO(message.encode()))
    # Event.decode returns None when nothing can be read from the source, so
    # narrow the optional before handing the event to the deserializer.
    assert decoded_event is not None
    event_deserializer = EventDeserializer(
        event=decoded_event, payload_codec=JSONCodec()
    )
    actual = EventStreamOperationInputOutput.deserialize(event_deserializer)
    assert actual == want
Expand All @@ -52,6 +55,7 @@ def test_deserialize_unmodeled_error():
}
)
source = Event.decode(BytesIO(message.encode()))
assert source is not None
deserializer = EventDeserializer(event=source, payload_codec=JSONCodec())

with pytest.raises(UnmodeledEventError, match="InternalError"):
Expand Down
74 changes: 37 additions & 37 deletions python-packages/aws-event-stream/tests/unit/test_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from aws_event_stream.exceptions import (
ChecksumMismatch,
DuplicateHeader,
InvalidEventBytes,
InvalidHeadersLength,
InvalidHeaderValueLength,
InvalidIntegerValue,
Expand Down Expand Up @@ -381,6 +382,8 @@
),
)

EMPTY_SOURCE = (b"", None)

# Tuples of encoded messages and their expected decoded output
POSITIVE_CASES = [
EMPTY_MESSAGE, # standard
Expand All @@ -398,6 +401,7 @@
PAYLOAD_ONE_STR_HEADER, # standard
ALL_HEADERS_TYPES, # standard
ERROR_EVENT_MESSAGE,
EMPTY_SOURCE,
]

CORRUPTED_HEADERS_LENGTH = (
Expand Down Expand Up @@ -489,66 +493,62 @@
InvalidPayloadLength,
)

TRUNCATED_PRELUDE = (b"\x00", InvalidEventBytes)

MISSING_PRELUDE_CRC_BYTES = (b"\x00\x00\x00\x16", InvalidEventBytes)

MISSING_MESSAGE_CRC_BYTES = (
(
b"\x00\x00\x00\x10" # total length
b"\x00\x00\x00\x00" # headers length
b"\x05\xc2\x48\xeb" # prelude crc
),
InvalidEventBytes,
)

# Tuples of encoded messages and their expected exception
NEGATIVE_CASES = [
CORRUPTED_LENGTH, # standard
CORRUPTED_PAYLOAD, # standard
CORRUPTED_HEADERS, # standard
CORRUPTED_HEADERS_LENGTH, # standard
DUPLICATE_HEADER,
INVALID_HEADERS_LENGTH,
INVALID_HEADER_VALUE_LENGTH,
INVALID_PAYLOAD_LENGTH,
]
NEGATIVE_CASES = {
"corrupted-length": CORRUPTED_LENGTH, # standard
"corrupted-payload": CORRUPTED_PAYLOAD, # standard
"corrupted-headers": CORRUPTED_HEADERS, # standard
"corrupted-headers-length": CORRUPTED_HEADERS_LENGTH, # standard
"duplicate-header": DUPLICATE_HEADER,
"invalid-headers-length": INVALID_HEADERS_LENGTH,
"invalid-header-value-length": INVALID_HEADER_VALUE_LENGTH,
"invalid-payload-length": INVALID_PAYLOAD_LENGTH,
"truncated-prelude": TRUNCATED_PRELUDE,
"missing-prelude-crc-bytes": MISSING_PRELUDE_CRC_BYTES,
"missing-message-crc-bytes": MISSING_MESSAGE_CRC_BYTES,
}


@pytest.mark.parametrize("encoded,expected", POSITIVE_CASES)
def test_decode(encoded: bytes, expected: Event):
def test_decode(encoded: bytes, expected: Event | None):
assert Event.decode(BytesIO(encoded)) == expected


@pytest.mark.parametrize("encoded,expected", POSITIVE_CASES)
async def test_decode_async(encoded: bytes, expected: Event):
async def test_decode_async(encoded: bytes, expected: Event | None):
assert await Event.decode_async(AsyncBytesReader(encoded)) == expected


@pytest.mark.parametrize("expected,event", POSITIVE_CASES)
@pytest.mark.parametrize(
"expected,event", [c for c in POSITIVE_CASES if c[1] is not None]
)
def test_encode(expected: bytes, event: Event):
assert event.message.encode() == expected


@pytest.mark.parametrize(
"encoded,expected",
NEGATIVE_CASES,
ids=[
"corrupted-length",
"corrupted-payload",
"corrupted-headers",
"corrupted-headers-length",
"duplicate-header",
"invalid-headers-length",
"invalid-header-value-length",
"invalid-payload-length",
],
"encoded,expected", NEGATIVE_CASES.values(), ids=NEGATIVE_CASES.keys()
)
def test_negative_cases(encoded: bytes, expected: type[Exception]):
with pytest.raises(expected):
Event.decode(BytesIO(encoded))


@pytest.mark.parametrize(
"encoded,expected",
NEGATIVE_CASES,
ids=[
"corrupted-length",
"corrupted-payload",
"corrupted-headers",
"corrupted-headers-length",
"duplicate-header",
"invalid-headers-length",
"invalid-header-value-length",
"invalid-payload-length",
],
"encoded,expected", NEGATIVE_CASES.values(), ids=NEGATIVE_CASES.keys()
)
async def test_negative_cases_async(encoded: bytes, expected: type[Exception]):
with pytest.raises(expected):
Expand Down

0 comments on commit d700b24

Please sign in to comment.