From 7cec4ee569e46ba2e7c7bf640e053438f6b3d8e0 Mon Sep 17 00:00:00 2001 From: JordonPhillips Date: Mon, 21 Oct 2024 17:52:53 +0200 Subject: [PATCH] Create high-level AWS event streams This updates the high-level event stream interfaces and creates AWS implementations of them. The `EventStream` protocol was split into three protocols: `DuplexEventStream`, `InputEventStream`, and `OutputEventStream`. These three classes encompass the three different configurations that clients can expect, and each are typed with their particular use-case in mind. This lets the type declarations be more concise and accurate. Before, it could be extremely ambiguous from a typing perspective what you were getting. The old `InputEventStream` and `OutputEventStream` classes were renamed to `AsyncEventPublisher` and `AsyncEventReceiver`, respectively. This is a more accurate description of what they do, particularly as they can be used for a service implementation as well. In the AWS implementation, some changes needed to be made. Notably the `Event` class had to get a `decode_async` method to be able to read from an async stream. Then the calling of that method had to be pulled out of the deserializer so that both sync and async clients can use it. Test cases were updated to also test the async method. Tests for the event stream classes will come in the form of protocol tests later on. --- .../_private/deserializers.py | 49 +++- .../aws_event_stream/_private/serializers.py | 40 ++- .../aws_event_stream/aio/__init__.py | 247 ++++++++++++++++++ .../aws_event_stream/events.py | 31 +++ .../aws_event_stream/exceptions.py | 5 + .../aws-event-stream/pyproject.toml | 4 + .../tests/unit/_private/test_deserializers.py | 18 +- .../tests/unit/test_events.py | 25 ++ .../smithy_core/aio/interfaces/__init__.py | 14 + .../smithy_event_stream/aio/__init__.py | 2 + .../smithy_event_stream/aio/interfaces.py | 188 ++++++++++--- 11 files changed, 569 insertions(+), 54 deletions(-) create mode 100644 python-packages/aws-event-stream/aws_event_stream/aio/__init__.py create mode 100644 python-packages/smithy-event-stream/smithy_event_stream/aio/__init__.py diff --git a/python-packages/aws-event-stream/aws_event_stream/_private/deserializers.py b/python-packages/aws-event-stream/aws_event_stream/_private/deserializers.py index 1f4ca8a0..f04b7e50 100644 --- a/python-packages/aws-event-stream/aws_event_stream/_private/deserializers.py +++ b/python-packages/aws-event-stream/aws_event_stream/_private/deserializers.py @@ -3,11 +3,16 @@ import datetime from collections.abc import Callable +from smithy_core.aio.interfaces import AsyncByteStream, AsyncCloseable from smithy_core.codecs import Codec -from smithy_core.deserializers import ShapeDeserializer, SpecificShapeDeserializer -from smithy_core.interfaces import BytesReader +from smithy_core.deserializers import ( + DeserializeableShape, + ShapeDeserializer, + SpecificShapeDeserializer, +) from smithy_core.schemas import Schema from smithy_core.utils import expect_type +from smithy_event_stream.aio.interfaces import AsyncEventReceiver from ..events import HEADERS_DICT, Event from ..exceptions import EventError, UnmodeledEventError @@ -17,11 +22,38 @@ INITIAL_MESSAGE_TYPES = (INITIAL_REQUEST_EVENT_TYPE, INITIAL_RESPONSE_EVENT_TYPE) -class EventDeserializer(SpecificShapeDeserializer): +class AWSAsyncEventReceiver[E: DeserializeableShape](AsyncEventReceiver[E]): def __init__( - self, source: BytesReader, payload_codec: Codec, is_client_mode: bool = True + self, + payload_codec: Codec, + source: AsyncByteStream, + deserializer: Callable[[ShapeDeserializer], E], + is_client_mode: bool = True, ) -> None: + self._payload_codec = payload_codec self._source = source + self._is_client_mode = is_client_mode + self._deserializer = deserializer + + async def receive(self) -> E | None: + event = await Event.decode_async(self._source) + deserializer = EventDeserializer( + event=event, + payload_codec=self._payload_codec, + is_client_mode=self._is_client_mode, + ) + return self._deserializer(deserializer) + + async def close(self) -> None: + if isinstance(self._source, AsyncCloseable): + await self._source.close() + + +class EventDeserializer(SpecificShapeDeserializer): + def __init__( + self, event: Event, payload_codec: Codec, is_client_mode: bool = True + ) -> None: + self._event = event self._payload_codec = payload_codec self._is_client_mode = is_client_mode @@ -30,13 +62,12 @@ def read_struct( schema: Schema, consumer: Callable[[Schema, ShapeDeserializer], None], ) -> None: - event = Event.decode(self._source) - headers = event.message.headers + headers = self._event.message.headers payload_deserializer = None - if event.message.payload: + if self._event.message.payload: payload_deserializer = self._payload_codec.create_deserializer( - event.message.payload + self._event.message.payload ) message_deserializer = EventMessageDeserializer(headers, payload_deserializer) @@ -61,7 +92,7 @@ def read_struct( expect_type(str, headers[":error-message"]), ) case _: - raise EventError(f"Unknown event structure: {event}") + raise EventError(f"Unknown event structure: {self._event}") class EventMessageDeserializer(SpecificShapeDeserializer): diff --git a/python-packages/aws-event-stream/aws_event_stream/_private/serializers.py b/python-packages/aws-event-stream/aws_event_stream/_private/serializers.py index 276488c4..3baf1c0e 100644 --- a/python-packages/aws-event-stream/aws_event_stream/_private/serializers.py +++ b/python-packages/aws-event-stream/aws_event_stream/_private/serializers.py @@ -1,20 +1,24 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 import datetime -from collections.abc import Iterator +from collections.abc import Callable, Iterator from contextlib import contextmanager from io import BytesIO from typing import Never +from smithy_core.aio.interfaces import AsyncCloseable, AsyncWriter from smithy_core.codecs import Codec +from smithy_core.exceptions import ExpectationNotMetException from smithy_core.schemas import Schema from smithy_core.serializers import ( InterceptingSerializer, + SerializeableShape, ShapeSerializer, SpecificShapeSerializer, ) from smithy_core.shapes import ShapeType from smithy_core.utils import expect_type +from smithy_event_stream.aio.interfaces import AsyncEventPublisher from ..events import EventHeaderEncoder, EventMessage from ..exceptions import InvalidHeaderValue @@ -30,6 +34,40 @@ _DEFAULT_BLOB_CONTENT_TYPE = "application/octet-stream" +type Signer = Callable[[EventMessage], EventMessage] +"""A function that takes an event message and signs it, and returns it signed.""" + + +class AWSAsyncEventPublisher[E: SerializeableShape](AsyncEventPublisher[E]): + def __init__( + self, + payload_codec: Codec, + async_writer: AsyncWriter, + signer: Signer | None = None, + is_client_mode: bool = True, + ): + self._writer = async_writer + self._signer = signer + self._serializer = EventSerializer( + payload_codec=payload_codec, is_client_mode=is_client_mode + ) + + async def send(self, event: E) -> None: + event.serialize(self._serializer) + result = self._serializer.get_result() + if result is None: + raise ExpectationNotMetException( + "Expected an event message to be serialized, but was None." + ) + if self._signer is not None: + result = self._signer(result) + await self._writer.write(result.encode()) + + async def close(self) -> None: + if isinstance(self._writer, AsyncCloseable): + await self._writer.close() + + class EventSerializer(SpecificShapeSerializer): def __init__( self, diff --git a/python-packages/aws-event-stream/aws_event_stream/aio/__init__.py b/python-packages/aws-event-stream/aws_event_stream/aio/__init__.py new file mode 100644 index 00000000..b019e004 --- /dev/null +++ b/python-packages/aws-event-stream/aws_event_stream/aio/__init__.py @@ -0,0 +1,247 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +import asyncio +from collections.abc import Callable +from typing import Self + +from smithy_core.aio.interfaces import AsyncByteStream, AsyncWriter +from smithy_core.codecs import Codec +from smithy_core.deserializers import DeserializeableShape, ShapeDeserializer +from smithy_core.serializers import SerializeableShape +from smithy_event_stream.aio.interfaces import ( + AsyncEventReceiver, + DuplexEventStream, + InputEventStream, + OutputEventStream, +) + +from .._private.deserializers import AWSAsyncEventReceiver as _AWSEventReceiver +from .._private.serializers import AWSAsyncEventPublisher as _AWSEventPublisher +from .._private.serializers import Signer +from ..exceptions import MissingInitialResponse + + +class AWSDuplexEventStream[ + I: SerializeableShape, O: DeserializeableShape, R: DeserializeableShape +](DuplexEventStream[I, O, R]): + """A duplex event stream using the application/vnd.amazon.eventstream format.""" + + def __init__( + self, + payload_codec: Codec, + async_writer: AsyncWriter, + deserializer: Callable[[ShapeDeserializer], O], + async_reader: AsyncByteStream | None = None, + initial_response: R | None = None, + deserializeable_response: type[R] | None = None, + signer: Signer | None = None, + is_client_mode: bool = True, + ) -> None: + """Construct an AWSDuplexEventStream. + + :param payload_codec: The codec to encode the event payload with. + :param async_writer: The writer to write event bytes to. + :param deserializer: A callable to deserialize events with. This should be the + union's deserialize method. + :param async_reader: The reader to read event bytes from, if available. If not + immediately available, output will be blocked on it becoming available. + :param initial_response: The deserialized operation response, if available. If + not immediately available, output will be blocked on it becoming available. + :param deserializeable_response: The deserializeable response class. Setting + this indicates that the initial response is sent over the event stream. The + deserialize method of this class will be used to deserialize it upon + calling ``await_output``. + :param signer: An optional callable to sign events with prior to them being + encoded. + :param is_client_mode: Whether the stream is being constructed for a client or + server implementation. + """ + self.input_stream = _AWSEventPublisher( + payload_codec=payload_codec, + async_writer=async_writer, + signer=signer, + is_client_mode=is_client_mode, + ) + + self._deserializer = deserializer + self._payload_codec = payload_codec + self._is_client_mode = is_client_mode + + # Create a future to allow awaiting the reader + loop = asyncio.get_event_loop() + self._reader_future: asyncio.Future[AsyncByteStream] = loop.create_future() + if async_reader is not None: + self._reader_future.set_result(async_reader) + + # Create a future to allow awaiting the initial response + self._response = initial_response + self._deserializerable_response = deserializeable_response + self._response_future: asyncio.Future[R] = loop.create_future() + + @property + def response(self) -> R | None: + return self._response + + @response.setter + def response(self, value: R) -> None: + self._response_future.set_result(value) + self._response = value + + def set_reader(self, value: AsyncByteStream) -> None: + """Sets the object to read events from. + + :param value: An async readable object to read event bytes from. + """ + self._reader_future.set_result(value) + + async def await_output(self) -> tuple[R, AsyncEventReceiver[O]]: + async_reader = await self._reader_future + if self.output_stream is None: + self.output_stream = _AWSEventReceiver[O]( + payload_codec=self._payload_codec, + source=async_reader, + deserializer=self._deserializer, + is_client_mode=self._is_client_mode, + ) + + if self.response is None: + if self._deserializerable_response is None: + initial_response = await self._response_future + else: + initial_response_stream = _AWSEventReceiver( + payload_codec=self._payload_codec, + source=async_reader, + deserializer=self._deserializerable_response.deserialize, + is_client_mode=self._is_client_mode, + ) + initial_response = await initial_response_stream.receive() + if initial_response is None: + raise MissingInitialResponse() + self.response = initial_response + else: + initial_response = self.response + + return initial_response, self.output_stream + + +class AWSInputEventStream[I: SerializeableShape, R](InputEventStream[I, R]): + """An input event stream using the application/vnd.amazon.eventstream format.""" + + def __init__( + self, + payload_codec: Codec, + async_writer: AsyncWriter, + initial_response: R | None = None, + signer: Signer | None = None, + is_client_mode: bool = True, + ) -> None: + """Construct an AWSInputEventStream. + + :param payload_codec: The codec to encode the event payload with. + :param async_writer: The writer to write event bytes to. + :param initial_response: The deserialized operation response, if available. + :param signer: An optional callable to sign events with prior to them being + encoded. + :param is_client_mode: Whether the stream is being constructed for a client or + server implementation. + """ + self._response = initial_response + + # Create a future to allow awaiting the initial response. + loop = asyncio.get_event_loop() + self._response_future: asyncio.Future[R] = loop.create_future() + if initial_response is not None: + self._response_future.set_result(initial_response) + + self.input_stream = _AWSEventPublisher( + payload_codec=payload_codec, + async_writer=async_writer, + signer=signer, + is_client_mode=is_client_mode, + ) + + @property + def response(self) -> R | None: + return self._response + + @response.setter + def response(self, value: R) -> None: + self._response_future.set_result(value) + self._response = value + + async def await_output(self) -> R: + return await self._response_future + + +class AWSOutputEventStream[O: DeserializeableShape, R: DeserializeableShape]( + OutputEventStream[O, R] +): + """An output event stream using the application/vnd.amazon.eventstream format.""" + + def __init__( + self, + payload_codec: Codec, + initial_response: R, + async_reader: AsyncByteStream, + deserializer: Callable[[ShapeDeserializer], O], + is_client_mode: bool = True, + ) -> None: + """Construct an AWSOutputEventStream. + + :param payload_codec: The codec to decode event payloads with. + :param initial_response: The deserialized operation response. If this is not + available immediately, use ``AWSOutputEventStream.create``. + :param async_reader: An async reader to read event bytes from. + :param deserializer: A callable to deserialize events with. This should be the + union's deserialize method. + :param is_client_mode: Whether the stream is being constructed for a client or + server implementation. + """ + self.response = initial_response + self.output_stream = _AWSEventReceiver[O]( + payload_codec=payload_codec, + source=async_reader, + deserializer=deserializer, + is_client_mode=is_client_mode, + ) + + @classmethod + async def create( + cls, + payload_codec: Codec, + deserializeable_response: type[R], + async_reader: AsyncByteStream, + deserializer: Callable[[ShapeDeserializer], O], + is_client_mode: bool = True, + ) -> Self: + """Construct an AWSOutputEventStream and decode the initial response. + + :param payload_codec: The codec to decode event payloads with. + :param deserializeable_response: The deserializeable response class. The + deserialize method of this class will be used to deserialize the + initial response from the stream.. + :param initial_response: The deserialized operation response. If this is not + available immediately, use ``AWSOutputEventStream.create``. + :param async_reader: An async reader to read event bytes from. + :param deserializer: A callable to deserialize events with. This should be the + union's deserialize method. + :param is_client_mode: Whether the stream is being constructed for a client or + server implementation. + """ + initial_response_stream = _AWSEventReceiver( + payload_codec=payload_codec, + source=async_reader, + deserializer=deserializeable_response.deserialize, + is_client_mode=is_client_mode, + ) + initial_response = await initial_response_stream.receive() + if initial_response is None: + raise MissingInitialResponse() + + return cls( + payload_codec=payload_codec, + initial_response=initial_response, + async_reader=async_reader, + deserializer=deserializer, + is_client_mode=is_client_mode, + ) diff --git a/python-packages/aws-event-stream/aws_event_stream/events.py b/python-packages/aws-event-stream/aws_event_stream/events.py index 50da440b..ec77473f 100644 --- a/python-packages/aws-event-stream/aws_event_stream/events.py +++ b/python-packages/aws-event-stream/aws_event_stream/events.py @@ -17,6 +17,7 @@ from struct import pack, unpack from typing import Literal, Self +from smithy_core.aio.interfaces import AsyncByteStream from smithy_core.interfaces import BytesReader from smithy_core.types import TimestampFormat @@ -314,6 +315,36 @@ def decode(cls, source: BytesReader) -> Self: ) return cls(prelude, message, crc) + @classmethod + async def decode_async(cls, source: AsyncByteStream) -> Self: + """Decode an event from an async byte stream. + + :param source: An object to read event bytes from. It must have a `read` method + that accepts a number of bytes to read. + + :returns: An Event representing the next event on the source. + """ + + prelude_bytes = await source.read(8) + prelude_crc_bytes = await source.read(4) + prelude_crc: int = _DecodeUtils.unpack_uint32(prelude_crc_bytes)[0] + + total_length, headers_length = unpack("!II", prelude_bytes) + _validate_checksum(prelude_bytes, prelude_crc) + prelude = EventPrelude( + total_length=total_length, headers_length=headers_length, crc=prelude_crc + ) + + message_bytes = await source.read(total_length - _MESSAGE_METADATA_SIZE) + crc: int = _DecodeUtils.unpack_uint32(await source.read(4))[0] + _validate_checksum(prelude_crc_bytes + message_bytes, crc, prelude_crc) + + message = EventMessage( + headers_bytes=message_bytes[: prelude.headers_length], + payload=message_bytes[prelude.headers_length :], + ) + return cls(prelude, message, crc) + class _EventEncoder: """A utility class that encodes message bytes into binary events.""" diff --git a/python-packages/aws-event-stream/aws_event_stream/exceptions.py b/python-packages/aws-event-stream/aws_event_stream/exceptions.py index 8c5dc083..ec705c61 100644 --- a/python-packages/aws-event-stream/aws_event_stream/exceptions.py +++ b/python-packages/aws-event-stream/aws_event_stream/exceptions.py @@ -91,3 +91,8 @@ def __init__(self, size: str, value: int): f"be 32-bit." ) super().__init__(message) + + +class MissingInitialResponse(EventError): + def __init__(self) -> None: + super().__init__("Expected an initial response, but none was found.") diff --git a/python-packages/aws-event-stream/pyproject.toml b/python-packages/aws-event-stream/pyproject.toml index 9e31443d..37a65161 100644 --- a/python-packages/aws-event-stream/pyproject.toml +++ b/python-packages/aws-event-stream/pyproject.toml @@ -25,6 +25,10 @@ classifiers = [ "Programming Language :: Python :: Implementation :: CPython", "Topic :: Software Development :: Libraries" ] +dependencies=[ + "smithy_core==0.0.1", + "smithy_event_stream==0.0.1", +] [project.urls] source = "https://github.com/awslabs/smithy-python/tree/develop/python-packages/aws-event-stream" diff --git a/python-packages/aws-event-stream/tests/unit/_private/test_deserializers.py b/python-packages/aws-event-stream/tests/unit/_private/test_deserializers.py index 9dfa84df..30153390 100644 --- a/python-packages/aws-event-stream/tests/unit/_private/test_deserializers.py +++ b/python-packages/aws-event-stream/tests/unit/_private/test_deserializers.py @@ -7,7 +7,7 @@ from smithy_json import JSONCodec from aws_event_stream._private.deserializers import EventDeserializer -from aws_event_stream.events import EventMessage +from aws_event_stream.events import Event, EventMessage from aws_event_stream.exceptions import UnmodeledEventError from . import ( @@ -21,24 +21,24 @@ @pytest.mark.parametrize("expected,given", EVENT_STREAM_SERDE_CASES) def test_event_deserializer(expected: DeserializeableShape, given: EventMessage): - source = BytesIO(given.encode()) - deserializer = EventDeserializer(source=source, payload_codec=JSONCodec()) + source = Event.decode(BytesIO(given.encode())) + deserializer = EventDeserializer(event=source, payload_codec=JSONCodec()) result = EventStreamDeserializer().deserialize(deserializer) assert result == expected def test_deserialize_initial_request(): expected, given = INITIAL_REQUEST_CASE - source = BytesIO(given.encode()) - deserializer = EventDeserializer(source=source, payload_codec=JSONCodec()) + source = Event.decode(BytesIO(given.encode())) + deserializer = EventDeserializer(event=source, payload_codec=JSONCodec()) result = EventStreamOperationInputOutput.deserialize(deserializer) assert result == expected def test_deserialize_initial_response(): expected, given = INITIAL_RESPONSE_CASE - source = BytesIO(given.encode()) - deserializer = EventDeserializer(source=source, payload_codec=JSONCodec()) + source = Event.decode(BytesIO(given.encode())) + deserializer = EventDeserializer(event=source, payload_codec=JSONCodec()) result = EventStreamOperationInputOutput.deserialize(deserializer) assert result == expected @@ -51,8 +51,8 @@ def test_deserialize_unmodeled_error(): ":error-message": "An internal server error occurred.", } ) - source = BytesIO(message.encode()) - deserializer = EventDeserializer(source=source, payload_codec=JSONCodec()) + source = Event.decode(BytesIO(message.encode())) + deserializer = EventDeserializer(event=source, payload_codec=JSONCodec()) with pytest.raises(UnmodeledEventError, match="InternalError"): EventStreamOperationInputOutput.deserialize(deserializer) diff --git a/python-packages/aws-event-stream/tests/unit/test_events.py b/python-packages/aws-event-stream/tests/unit/test_events.py index 0bbdfec9..d6c534bb 100644 --- a/python-packages/aws-event-stream/tests/unit/test_events.py +++ b/python-packages/aws-event-stream/tests/unit/test_events.py @@ -6,6 +6,7 @@ from io import BytesIO import pytest +from smithy_core.aio.types import AsyncBytesReader from aws_event_stream.events import ( MAX_HEADER_VALUE_BYTE_LENGTH, @@ -506,6 +507,11 @@ def test_decode(encoded: bytes, expected: Event): assert Event.decode(BytesIO(encoded)) == expected +@pytest.mark.parametrize("encoded,expected", POSITIVE_CASES) +async def test_decode_async(encoded: bytes, expected: Event): + assert await Event.decode_async(AsyncBytesReader(encoded)) == expected + + @pytest.mark.parametrize("expected,event", POSITIVE_CASES) def test_encode(expected: bytes, event: Event): assert event.message.encode() == expected @@ -530,6 +536,25 @@ def test_negative_cases(encoded: bytes, expected: type[Exception]): Event.decode(BytesIO(encoded)) +@pytest.mark.parametrize( + "encoded,expected", + NEGATIVE_CASES, + ids=[ + "corrupted-length", + "corrupted-payload", + "corrupted-headers", + "corrupted-headers-length", + "duplicate-header", + "invalid-headers-length", + "invalid-header-value-length", + "invalid-payload-length", + ], +) +async def test_negative_cases_async(encoded: bytes, expected: type[Exception]): + with pytest.raises(expected): + await Event.decode_async(AsyncBytesReader(encoded)) + + def test_event_prelude_rejects_long_headers(): headers_length = MAX_HEADERS_LENGTH + 1 total_length = headers_length + 16 diff --git a/python-packages/smithy-core/smithy_core/aio/interfaces/__init__.py b/python-packages/smithy-core/smithy_core/aio/interfaces/__init__.py index bdf71d54..79fdb60c 100644 --- a/python-packages/smithy-core/smithy_core/aio/interfaces/__init__.py +++ b/python-packages/smithy-core/smithy_core/aio/interfaces/__init__.py @@ -14,6 +14,20 @@ class AsyncByteStream(Protocol): async def read(self, size: int = -1) -> bytes: ... +@runtime_checkable +class AsyncWriter(Protocol): + """An object with an async write method.""" + + async def write(self, data: bytes) -> None: ... + + +@runtime_checkable +class AsyncCloseable(Protocol): + """An object that can asynchronously close.""" + + async def close(self): ... + + # A union of all acceptable streaming blob types. Deserialized payloads will # always return a ByteStream, or AsyncByteStream if async is enabled. type StreamingBlob = SyncStreamingBlob | AsyncByteStream | AsyncIterable[bytes] diff --git a/python-packages/smithy-event-stream/smithy_event_stream/aio/__init__.py b/python-packages/smithy-event-stream/smithy_event_stream/aio/__init__.py new file mode 100644 index 00000000..04f8b7b7 --- /dev/null +++ b/python-packages/smithy-event-stream/smithy_event_stream/aio/__init__.py @@ -0,0 +1,2 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 diff --git a/python-packages/smithy-event-stream/smithy_event_stream/aio/interfaces.py b/python-packages/smithy-event-stream/smithy_event_stream/aio/interfaces.py index 5708f143..a38468d0 100644 --- a/python-packages/smithy-event-stream/smithy_event_stream/aio/interfaces.py +++ b/python-packages/smithy-event-stream/smithy_event_stream/aio/interfaces.py @@ -6,7 +6,7 @@ from smithy_core.serializers import SerializeableShape -class InputEventStream[E: SerializeableShape](Protocol): +class AsyncEventPublisher[E: SerializeableShape](Protocol): """Asynchronously sends events to a service. This may be used as a context manager to ensure the stream is closed before exiting. @@ -30,7 +30,7 @@ async def __aexit__(self, exc_type: Any, exc_value: Any, traceback: Any): await self.close() -class OutputEventStream[E: DeserializeableShape](Protocol): +class AsyncEventReceiver[E: DeserializeableShape](Protocol): """Asynchronously receives events from a service. Events may be received via the ``receive`` method or by using this class as @@ -69,10 +69,8 @@ async def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any): await self.close() -class EventStream[I: InputEventStream[Any] | None, O: OutputEventStream[Any] | None, R]( - Protocol -): - """A unidirectional or bidirectional event stream. +class DuplexEventStream[I: SerializeableShape, O: DeserializeableShape, R](Protocol): + """An event stream that both sends and receives messages. To ensure that streams are closed upon exiting, this class may be used as an async context manager. @@ -104,30 +102,46 @@ async def handle_output(stream: EventStream) -> None: return """ - input_stream: I - """An event stream that sends events to the service. + input_stream: AsyncEventPublisher[I] + """An event stream that sends events to the service.""" - This value will be None if the operation has no input stream. - """ + # Exposing response and output_stream via @property allows implementations that + # don't have it immediately available to do things like put a future in + # await_output or otherwise reasonably implement that method while still allowing + # them to inherit directly from the protocol class. + _output_stream: AsyncEventReceiver[O] | None = None + _response: R | None = None - output_stream: O | None = None - """An event stream that receives events from the service. + @property + def output_stream(self) -> AsyncEventReceiver[O] | None: + """An event stream that receives events from the service. - This value may be None until ``await_output`` has been called. + This value may be None until ``await_output`` has been called. - This value will also be None if the operation has no output stream. - """ + This value will also be None if the operation has no output stream. + """ + return self._output_stream - response: R | None = None - """The initial response from the service. + @output_stream.setter + def output_stream(self, value: AsyncEventReceiver[O]) -> None: + self._output_stream = value - This value may be None until ``await_output`` has been called. + @property + def response(self) -> R | None: + """The initial response from the service. - This may include context necessary to interpret output events or prepare - input events. It will always be available before any events. - """ + This value may be None until ``await_output`` has been called. + + This may include context necessary to interpret output events or prepare + input events. It will always be available before any events. + """ + return self._response - async def await_output(self) -> tuple[R, O]: + @response.setter + def response(self, value: R) -> None: + self._response = value + + async def await_output(self) -> tuple[R, AsyncEventReceiver[O]]: """Await the operation's output. The EventStream will be returned as soon as the input stream is ready to @@ -146,17 +160,6 @@ async def await_output(self) -> tuple[R, O]: :returns: A tuple containing the initial response and output stream. If the operation has no output stream, the second value will be None. """ - if self.response is not None: - self.response, self.output_stream = await self._await_output() - - return self._response, self._output_stream # type: ignore - - async def _await_output(self) -> tuple[R, O]: - """Await the operation's output without caching. - - This method is meant to be used with the default implementation of await_output. - It should return the output directly without caching. - """ ... async def close(self) -> None: @@ -167,8 +170,123 @@ async def close(self) -> None: if self.output_stream is None: _, self.output_stream = await self.await_output() - if self.output_stream is not None: - await self.output_stream.close() + await self.input_stream.close() + await self.output_stream.close() + + async def __aenter__(self) -> Self: + return self + + async def __aexit__(self, exc_type: Any, exc_value: Any, traceback: Any): + await self.close() + + +class InputEventStream[I: SerializeableShape, R](Protocol): + """An event stream that streams messages to the service. + + To ensure that streams are closed upon exiting, this class may be used as an async + context manager. + + .. code-block:: python + + async def main(): + client = ChatClient() + input = PublishMessagesInput(chat_room="aws-python-sdk", username="hunter7") + + async with client.publish_messages(input=input) as stream: + stream.input_stream.send(MessageStreamMessage("High severity ticket alert!")) + await stream.await_output() + """ + + input_stream: AsyncEventPublisher[I] + """An event stream that sends events to the service.""" + + # Exposing response via @property allows implementations that don't have it + # immediately available to do things like put a future in await_output or + # otherwise reasonably implement that method while still allowing them to + # inherit directly from the protocol class. + _response: R | None = None + + @property + def response(self) -> R | None: + """The initial response from the service. + + This value may be None until ``await_output`` has been called. + + This may include context necessary to interpret output events or prepare + input events. It will always be available before any events. + """ + return self._response + + @response.setter + def response(self, value: R) -> None: + self._response = value + + async def await_output(self) -> R: + """Await the operation's output. + + The InputEventStream will be returned as soon as the input stream is ready to + receive events, which may be before the initial response has been received. + + Awaiting this method will wait until the initial response was received. The + operation response will be returned by this operation and also cached in + ``response``. + + The default implementation of this method performs the caching behavior, + delegating to the abstract ``_await_output`` method to actually retrieve the + operation response. + + :returns: The operation's response. + """ + ... + + async def close(self) -> None: + """Closes the event stream.""" + await self.input_stream.close() + + async def __aenter__(self) -> Self: + return self + + async def __aexit__(self, exc_type: Any, exc_value: Any, traceback: Any): + await self.close() + + +class OutputEventStream[O: DeserializeableShape, R](Protocol): + """An event stream that streams messages from the service. + + To ensure that streams are closed upon exiting, this class may be used as an async + context manager. + + .. code-block:: python + + async def main(): + client = ChatClient() + input = ReceiveMessagesInput(chat_room="aws-python-sdk") + + async with client.receive_messages(input=input) as stream: + async for event in stream.output_stream: + match event: + case MessageStreamMessage(): + print(event.value) + case _: + return + """ + + output_stream: AsyncEventReceiver[O] + """An event stream that receives events from the service. + + This value will also be None if the operation has no output stream. + """ + + response: R + """The initial response from the service. + + This may include context necessary to interpret output events or prepare input + events. It will always be available before any events. + """ + + async def close(self) -> None: + """Closes the event stream.""" + await self.output_stream.close() async def __aenter__(self) -> Self: return self