Skip to content

Commit

Permalink
Create high-level AWS event streams
Browse files Browse the repository at this point in the history
This updates the high-level event stream interfaces and creates AWS
implementations of them.

The `EventStream` protocol was split into three protocols:
`DuplexEventStream`, `InputEventStream`, and `OutputEventStream`.
These three classes encompass the three different configurations that
clients can expect, and each are typed with their particular use-case
in mind. This lets the type declarations be more concise and
accurate. Before, it could be extremely ambiguous from a typing
perspective what you were getting.

The old `InputEventStream` and `OutputEventStream` classes were
renamed to `AsyncEventPublisher` and `AsyncEventReceiver`,
respectively. This is a more accurate description of what they do,
particularly as they can be used for a service implementation as
well.

In the AWS implementation, some changes needed to be made. Notably
the `Event` class had to get a `decode_async` method to be able to
read from an async stream. Then the calling of that method had to
be pulled out of the deserializer so that both sync and async
clients can use it. Test cases were updated to also test the async
method.

Tests for the event stream classes will come in the form of protocol
tests later on.
  • Loading branch information
JordonPhillips committed Oct 24, 2024
1 parent fda945a commit 7cec4ee
Show file tree
Hide file tree
Showing 11 changed files with 569 additions and 54 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,16 @@
import datetime
from collections.abc import Callable

from smithy_core.aio.interfaces import AsyncByteStream, AsyncCloseable
from smithy_core.codecs import Codec
from smithy_core.deserializers import ShapeDeserializer, SpecificShapeDeserializer
from smithy_core.interfaces import BytesReader
from smithy_core.deserializers import (
DeserializeableShape,
ShapeDeserializer,
SpecificShapeDeserializer,
)
from smithy_core.schemas import Schema
from smithy_core.utils import expect_type
from smithy_event_stream.aio.interfaces import AsyncEventReceiver

from ..events import HEADERS_DICT, Event
from ..exceptions import EventError, UnmodeledEventError
Expand All @@ -17,11 +22,38 @@
INITIAL_MESSAGE_TYPES = (INITIAL_REQUEST_EVENT_TYPE, INITIAL_RESPONSE_EVENT_TYPE)


class EventDeserializer(SpecificShapeDeserializer):
class AWSAsyncEventReceiver[E: DeserializeableShape](AsyncEventReceiver[E]):
def __init__(
self, source: BytesReader, payload_codec: Codec, is_client_mode: bool = True
self,
payload_codec: Codec,
source: AsyncByteStream,
deserializer: Callable[[ShapeDeserializer], E],
is_client_mode: bool = True,
) -> None:
self._payload_codec = payload_codec
self._source = source
self._is_client_mode = is_client_mode
self._deserializer = deserializer

async def receive(self) -> E | None:
event = await Event.decode_async(self._source)
deserializer = EventDeserializer(
event=event,
payload_codec=self._payload_codec,
is_client_mode=self._is_client_mode,
)
return self._deserializer(deserializer)

async def close(self) -> None:
if isinstance(self._source, AsyncCloseable):
await self._source.close()


class EventDeserializer(SpecificShapeDeserializer):
def __init__(
self, event: Event, payload_codec: Codec, is_client_mode: bool = True
) -> None:
self._event = event
self._payload_codec = payload_codec
self._is_client_mode = is_client_mode

Expand All @@ -30,13 +62,12 @@ def read_struct(
schema: Schema,
consumer: Callable[[Schema, ShapeDeserializer], None],
) -> None:
event = Event.decode(self._source)
headers = event.message.headers
headers = self._event.message.headers

payload_deserializer = None
if event.message.payload:
if self._event.message.payload:
payload_deserializer = self._payload_codec.create_deserializer(
event.message.payload
self._event.message.payload
)

message_deserializer = EventMessageDeserializer(headers, payload_deserializer)
Expand All @@ -61,7 +92,7 @@ def read_struct(
expect_type(str, headers[":error-message"]),
)
case _:
raise EventError(f"Unknown event structure: {event}")
raise EventError(f"Unknown event structure: {self._event}")


class EventMessageDeserializer(SpecificShapeDeserializer):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,20 +1,24 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
import datetime
from collections.abc import Iterator
from collections.abc import Callable, Iterator
from contextlib import contextmanager
from io import BytesIO
from typing import Never

from smithy_core.aio.interfaces import AsyncCloseable, AsyncWriter
from smithy_core.codecs import Codec
from smithy_core.exceptions import ExpectationNotMetException
from smithy_core.schemas import Schema
from smithy_core.serializers import (
InterceptingSerializer,
SerializeableShape,
ShapeSerializer,
SpecificShapeSerializer,
)
from smithy_core.shapes import ShapeType
from smithy_core.utils import expect_type
from smithy_event_stream.aio.interfaces import AsyncEventPublisher

from ..events import EventHeaderEncoder, EventMessage
from ..exceptions import InvalidHeaderValue
Expand All @@ -30,6 +34,40 @@
_DEFAULT_BLOB_CONTENT_TYPE = "application/octet-stream"


type Signer = Callable[[EventMessage], EventMessage]
"""A function that takes an event message and signs it, and returns it signed."""


class AWSAsyncEventPublisher[E: SerializeableShape](AsyncEventPublisher[E]):
def __init__(
self,
payload_codec: Codec,
async_writer: AsyncWriter,
signer: Signer | None = None,
is_client_mode: bool = True,
):
self._writer = async_writer
self._signer = signer
self._serializer = EventSerializer(
payload_codec=payload_codec, is_client_mode=is_client_mode
)

async def send(self, event: E) -> None:
event.serialize(self._serializer)
result = self._serializer.get_result()
if result is None:
raise ExpectationNotMetException(
"Expected an event message to be serialized, but was None."
)
if self._signer is not None:
result = self._signer(result)
await self._writer.write(result.encode())

async def close(self) -> None:
if isinstance(self._writer, AsyncCloseable):
await self._writer.close()


class EventSerializer(SpecificShapeSerializer):
def __init__(
self,
Expand Down
Loading

0 comments on commit 7cec4ee

Please sign in to comment.