diff --git a/codegen/smithy-python-codegen/src/main/java/software/amazon/smithy/python/codegen/ClientGenerator.java b/codegen/smithy-python-codegen/src/main/java/software/amazon/smithy/python/codegen/ClientGenerator.java index 93068768..a56dea3c 100644 --- a/codegen/smithy-python-codegen/src/main/java/software/amazon/smithy/python/codegen/ClientGenerator.java +++ b/codegen/smithy-python-codegen/src/main/java/software/amazon/smithy/python/codegen/ClientGenerator.java @@ -26,6 +26,7 @@ import software.amazon.smithy.model.knowledge.TopDownIndex; import software.amazon.smithy.model.shapes.OperationShape; import software.amazon.smithy.model.shapes.ServiceShape; +import software.amazon.smithy.model.shapes.Shape; import software.amazon.smithy.model.traits.DocumentationTrait; import software.amazon.smithy.model.traits.StringTrait; import software.amazon.smithy.python.codegen.integration.PythonIntegration; @@ -123,6 +124,16 @@ def __init__(self, config: $1T | None = None, plugins: list[$2T] | None = None): } private void generateOperationExecutor(PythonWriter writer) { + writer.pushState(); + + var hasStreaming = hasEventStream(); + writer.putContext("hasEventStream", hasStreaming); + if (hasStreaming) { + writer.addImports("smithy_core.deserializers", Set.of( + "ShapeDeserializer", "DeserializeableShape")); + writer.addStdlibImport("typing", "Any"); + } + var transportRequest = context.applicationProtocol().requestType(); var transportResponse = context.applicationProtocol().responseType(); var errorSymbol = CodegenUtils.getServiceError(context.settings()); @@ -191,10 +202,18 @@ async def _execute_operation( deserialize: Callable[[$3T, $5T], Awaitable[Output]], config: $5T, operation_name: str, + ${?hasEventStream} + has_input_stream: bool = False, + event_deserializer: Callable[[ShapeDeserializer], Any] | None = None, + event_response_deserializer: DeserializeableShape | None = None, + ${/hasEventStream} ) -> Output: try: return await self._handle_execution( - input, plugins, serialize, deserialize, config, operation_name + input, plugins, serialize, deserialize, config, operation_name, + ${?hasEventStream} + has_input_stream, event_deserializer, event_response_deserializer, + ${/hasEventStream} ) except Exception as e: # Make sure every exception that we throw is an instance of $4T so @@ -211,6 +230,11 @@ async def _handle_execution( deserialize: Callable[[$3T, $5T], Awaitable[Output]], config: $5T, operation_name: str, + ${?hasEventStream} + has_input_stream: bool = False, + event_deserializer: Callable[[ShapeDeserializer], Any] | None = None, + event_response_deserializer: DeserializeableShape | None = None, + ${/hasEventStream} ) -> Output: logger.debug(f"Making request for operation {operation_name} with parameters: {input}") context: InterceptorContext[Input, None, None, None] = InterceptorContext( @@ -326,7 +350,16 @@ await sleep(retry_token.retry_delay) execution_context = cast( InterceptorContext[Input, Output, $2T | None, $3T | None], context ) + ${^hasEventStream} return await self._finalize_execution(interceptors, execution_context) + ${/hasEventStream} + ${?hasEventStream} + operation_output = await self._finalize_execution(interceptors, execution_context) + if has_input_stream or event_deserializer is not None: + ${6C|} + else: + return operation_output + ${/hasEventStream} async def _handle_attempt( self, @@ -342,7 +375,8 @@ async def _handle_attempt( for interceptor in interceptors: interceptor.read_before_attempt(context) - """, pluginSymbol, transportRequest, transportResponse, errorSymbol, configSymbol); + """, pluginSymbol, transportRequest, transportResponse, errorSymbol, configSymbol, + writer.consumer(w -> context.protocolGenerator().wrapEventStream(context, w))); boolean supportsAuth = !ServiceIndex.of(context.model()).getAuthSchemes(service).isEmpty(); writer.pushState(new ResolveIdentitySection()); @@ -604,6 +638,18 @@ async def _finalize_execution( return context.response """, transportRequest, transportResponse); writer.dedent(); + writer.popState(); + } + + private boolean hasEventStream() { + var streamIndex = EventStreamIndex.of(context.model()); + var topDownIndex = TopDownIndex.of(context.model()); + for (OperationShape operation : topDownIndex.getContainedOperations(context.settings().service())) { + if (streamIndex.getInputInfo(operation).isPresent() || streamIndex.getOutputInfo(operation).isPresent()) { + return true; + } + } + return false; } private void initializeHttpAuthParameters(PythonWriter writer) { @@ -649,40 +695,7 @@ private void generateOperation(PythonWriter writer, OperationShape operation) { writer.openBlock("async def $L(self, input: $T, plugins: list[$T] | None = None) -> $T:", "", operationSymbol.getName(), inputSymbol, pluginSymbol, outputSymbol, () -> { - writer.writeDocs(() -> { - var docs = operation.getTrait(DocumentationTrait.class) - .map(StringTrait::getValue) - .orElse(String.format("Invokes the %s operation.", operation.getId().getName())); - - var inputDocs = input.getTrait(DocumentationTrait.class) - .map(StringTrait::getValue) - .orElse("The operation's input."); - - writer.write(""" - $L - - :param input: $L - - :param plugins: A list of callables that modify the configuration dynamically. - Changes made by these plugins only apply for the duration of the operation - execution and will not affect any other operation invocations.""", docs, inputDocs); - }); - - var defaultPlugins = new LinkedHashSet(); - for (PythonIntegration integration : context.integrations()) { - for (RuntimeClientPlugin runtimeClientPlugin : integration.getClientPlugins()) { - if (runtimeClientPlugin.matchesOperation(context.model(), service, operation)) { - runtimeClientPlugin.getPythonPlugin().ifPresent(defaultPlugins::add); - } - } - } - writer.write(""" - operation_plugins: list[Plugin] = [ - $C - ] - if plugins: - operation_plugins.extend(plugins) - """, writer.consumer(w -> writeDefaultPlugins(w, defaultPlugins))); + writeSharedOperationInit(writer, operation, input); if (context.protocolGenerator() == null) { writer.write("raise NotImplementedError()"); @@ -704,16 +717,55 @@ private void generateOperation(PythonWriter writer, OperationShape operation) { }); } + private void writeSharedOperationInit(PythonWriter writer, OperationShape operation, Shape input) { + writer.writeDocs(() -> { + var docs = operation.getTrait(DocumentationTrait.class) + .map(StringTrait::getValue) + .orElse(String.format("Invokes the %s operation.", operation.getId().getName())); + + var inputDocs = input.getTrait(DocumentationTrait.class) + .map(StringTrait::getValue) + .orElse("The operation's input."); + + writer.write(""" + $L + + :param input: $L + + :param plugins: A list of callables that modify the configuration dynamically. + Changes made by these plugins only apply for the duration of the operation + execution and will not affect any other operation invocations.""", docs, inputDocs); + }); + + var defaultPlugins = new LinkedHashSet(); + for (PythonIntegration integration : context.integrations()) { + for (RuntimeClientPlugin runtimeClientPlugin : integration.getClientPlugins()) { + if (runtimeClientPlugin.matchesOperation(context.model(), service, operation)) { + runtimeClientPlugin.getPythonPlugin().ifPresent(defaultPlugins::add); + } + } + } + writer.write(""" + operation_plugins: list[Plugin] = [ + $C + ] + if plugins: + operation_plugins.extend(plugins) + """, writer.consumer(w -> writeDefaultPlugins(w, defaultPlugins))); + + } + private void generateEventStreamOperation(PythonWriter writer, OperationShape operation) { writer.pushState(); writer.addDependency(SmithyPythonDependency.SMITHY_EVENT_STREAM); - writer.addImports("smithy_event_stream.aio.interfaces", Set.of( - "EventStream", "InputEventStream", "OutputEventStream")); var operationSymbol = context.symbolProvider().toSymbol(operation); + writer.putContext("operationName", operationSymbol.getName()); var pluginSymbol = CodegenUtils.getPluginSymbol(context.settings()); + writer.putContext("plugin", pluginSymbol); var input = context.model().expectShape(operation.getInputShape()); var inputSymbol = context.symbolProvider().toSymbol(input); + writer.putContext("input", inputSymbol); var eventStreamIndex = EventStreamIndex.of(context.model()); var inputStreamSymbol = eventStreamIndex.getInputInfo(operation) @@ -724,22 +776,107 @@ private void generateEventStreamOperation(PythonWriter writer, OperationShape op var output = context.model().expectShape(operation.getOutputShape()); var outputSymbol = context.symbolProvider().toSymbol(output); + writer.putContext("output", outputSymbol); + var outputStreamSymbol = eventStreamIndex.getOutputInfo(operation) .map(EventStreamInfo::getEventStreamTarget) .map(target -> context.symbolProvider().toSymbol(target)) .orElse(null); writer.putContext("outputStream", outputStreamSymbol); - writer.write(""" - async def $L(self, input: $T, plugins: list[$T] | None = None) -> EventStream[ - ${?inputStream}InputEventStream[${inputStream:T}]${/inputStream}\ - ${^inputStream}None${/inputStream}, - ${?outputStream}OutputEventStream[${outputStream:T}]${/outputStream}\ - ${^outputStream}None${/outputStream}, - $T - ]: - raise NotImplementedError() - """, operationSymbol.getName(), inputSymbol, pluginSymbol, outputSymbol); + writer.putContext("hasProtocol", context.protocolGenerator() != null); + if (context.protocolGenerator() != null) { + var serSymbol = context.protocolGenerator().getSerializationFunction(context, operation); + writer.putContext("serSymbol", serSymbol); + var deserSymbol = context.protocolGenerator().getDeserializationFunction(context, operation); + writer.putContext("deserSymbol", deserSymbol); + } else { + writer.putContext("serSymbol", null); + writer.putContext("deserSymbol", null); + } + + if (inputStreamSymbol != null) { + if (outputStreamSymbol != null) { + writer.addImport("smithy_event_stream.aio.interfaces", "DuplexEventStream"); + writer.write(""" + async def ${operationName:L}( + self, + input: ${input:T}, + plugins: list[${plugin:T}] | None = None + ) -> DuplexEventStream[${inputStream:T}, ${outputStream:T}, ${output:T}]: + ${C|} + ${^hasProtocol} + raise NotImplementedError() + ${/hasProtocol} + ${?hasProtocol} + return await self._execute_operation( + input=input, + plugins=operation_plugins, + serialize=${serSymbol:T}, + deserialize=${deserSymbol:T}, + config=self._config, + operation_name=${operationName:S}, + has_input_stream=True, + event_deserializer=$T().deserialize, + event_response_deserializer=${output:T}, + ) # type: ignore + ${/hasProtocol} + """, + writer.consumer(w -> writeSharedOperationInit(w, operation, input)), + outputStreamSymbol.expectProperty(SymbolProperties.DESERIALIZER)); + } else { + writer.addImport("smithy_event_stream.aio.interfaces", "InputEventStream"); + writer.write(""" + async def ${operationName:L}( + self, + input: ${input:T}, + plugins: list[${plugin:T}] | None = None + ) -> InputEventStream[${inputStream:T}, ${output:T}]: + ${C|} + ${^hasProtocol} + raise NotImplementedError() + ${/hasProtocol} + ${?hasProtocol} + return await self._execute_operation( + input=input, + plugins=operation_plugins, + serialize=${serSymbol:T}, + deserialize=${deserSymbol:T}, + config=self._config, + operation_name=${operationName:S}, + has_input_stream=True, + ) # type: ignore + ${/hasProtocol} + """, writer.consumer(w -> writeSharedOperationInit(w, operation, input))); + } + } else { + writer.addImport("smithy_event_stream.aio.interfaces", "OutputEventStream"); + writer.write(""" + async def ${operationName:L}( + self, + input: ${input:T}, + plugins: list[${plugin:T}] | None = None + ) -> OutputEventStream[${outputStream:T}, ${output:T}]: + ${C|} + ${^hasProtocol} + raise NotImplementedError() + ${/hasProtocol} + ${?hasProtocol} + return await self._execute_operation( + input=input, + plugins=operation_plugins, + serialize=${serSymbol:T}, + deserialize=${deserSymbol:T}, + config=self._config, + operation_name=${operationName:S}, + event_deserializer=$T().deserialize, + event_response_deserializer=${output:T}, + ) # type: ignore + ${/hasProtocol} + """, + writer.consumer(w -> writeSharedOperationInit(w, operation, input)), + outputStreamSymbol.expectProperty(SymbolProperties.DESERIALIZER)); + } writer.popState(); } diff --git a/codegen/smithy-python-codegen/src/main/java/software/amazon/smithy/python/codegen/SmithyPythonDependency.java b/codegen/smithy-python-codegen/src/main/java/software/amazon/smithy/python/codegen/SmithyPythonDependency.java index d1af2a4d..954c207d 100644 --- a/codegen/smithy-python-codegen/src/main/java/software/amazon/smithy/python/codegen/SmithyPythonDependency.java +++ b/codegen/smithy-python-codegen/src/main/java/software/amazon/smithy/python/codegen/SmithyPythonDependency.java @@ -64,6 +64,9 @@ public final class SmithyPythonDependency { false ); + /** + * Core interfaces for event streams. + */ public static final PythonDependency SMITHY_EVENT_STREAM = new PythonDependency( "smithy_event_stream", "==0.0.1", @@ -71,6 +74,16 @@ public final class SmithyPythonDependency { false ); + /** + * EventStream implementations for application/vnd.amazon.eventstream. + */ + public static final PythonDependency AWS_EVENT_STREAM = new PythonDependency( + "aws_event_stream", + "==0.0.1", + Type.DEPENDENCY, + false + ); + /** * testing framework used in generated functional tests. */ diff --git a/codegen/smithy-python-codegen/src/main/java/software/amazon/smithy/python/codegen/integration/ProtocolGenerator.java b/codegen/smithy-python-codegen/src/main/java/software/amazon/smithy/python/codegen/integration/ProtocolGenerator.java index 5349ee6c..78257ec2 100644 --- a/codegen/smithy-python-codegen/src/main/java/software/amazon/smithy/python/codegen/integration/ProtocolGenerator.java +++ b/codegen/smithy-python-codegen/src/main/java/software/amazon/smithy/python/codegen/integration/ProtocolGenerator.java @@ -22,6 +22,7 @@ import software.amazon.smithy.model.shapes.ToShapeId; import software.amazon.smithy.python.codegen.ApplicationProtocol; import software.amazon.smithy.python.codegen.GenerationContext; +import software.amazon.smithy.python.codegen.PythonWriter; import software.amazon.smithy.utils.CaseUtils; import software.amazon.smithy.utils.SmithyUnstableApi; @@ -167,4 +168,25 @@ default void generateSharedDeserializerComponents(GenerationContext context) { */ default void generateProtocolTests(GenerationContext context) { } + + /** + * Generates the code to wrap an operation output into an event stream. + * + *

Important context variables are: + *

+ * + * @param context Generation context. + * @param writer The writer to write to. + */ + default void wrapEventStream(GenerationContext context, PythonWriter writer) { + } } diff --git a/codegen/smithy-python-codegen/src/main/java/software/amazon/smithy/python/codegen/integration/RestJsonProtocolGenerator.java b/codegen/smithy-python-codegen/src/main/java/software/amazon/smithy/python/codegen/integration/RestJsonProtocolGenerator.java index 7e566c6a..e40e16ce 100644 --- a/codegen/smithy-python-codegen/src/main/java/software/amazon/smithy/python/codegen/integration/RestJsonProtocolGenerator.java +++ b/codegen/smithy-python-codegen/src/main/java/software/amazon/smithy/python/codegen/integration/RestJsonProtocolGenerator.java @@ -167,8 +167,9 @@ protected void serializePayloadBody( // or a blob, meaning it's some potentially big collection of bytes. // See also: https://smithy.io/2.0/spec/streaming.html#smithy-api-streaming-trait if (payloadBinding.getMember().getMemberTrait(context.model(), StreamingTrait.class).isPresent()) { - // TODO: support event streams if (target.isUnionShape()) { + writer.addImport("smithy_core.aio.types", "AsyncBytesProvider"); + writer.write("body = AsyncBytesProvider()"); return; } @@ -306,7 +307,6 @@ protected void deserializePayloadBody( Shape operationOrError, HttpBinding payloadBinding ) { - writer.addDependency(SmithyPythonDependency.SMITHY_JSON); writer.addDependency(SmithyPythonDependency.SMITHY_CORE); writer.addImport("smithy_json", "JSONCodec"); @@ -379,4 +379,42 @@ protected void resolveErrorCodeAndMessage(GenerationContext context, } writer.write(")"); } + + @Override + public void wrapEventStream(GenerationContext context, PythonWriter writer) { + writer.addDependency(SmithyPythonDependency.SMITHY_JSON); + writer.addDependency(SmithyPythonDependency.AWS_EVENT_STREAM); + writer.addDependency(SmithyPythonDependency.SMITHY_CORE); + writer.addImports("aws_event_stream.aio", Set.of( + "AWSDuplexEventStream", "AWSInputEventStream", "AWSOutputEventStream")); + writer.addImport("smithy_json", "JSONCodec"); + writer.addImport("smithy_core.types", "TimestampFormat"); + writer.addStdlibImport("typing", "Any"); + + writer.write(""" + codec = JSONCodec(default_timestamp_format=TimestampFormat.EPOCH_SECONDS) + if has_input_stream: + if event_deserializer is not None: + return AWSDuplexEventStream[Any, Any, Any]( + payload_codec=codec, + initial_response=operation_output, + async_writer=execution_context.transport_request.body, # type: ignore + async_reader=execution_context.transport_response.body, # type: ignore + deserializer=event_deserializer, # type: ignore + ) + else: + return AWSInputEventStream[Any, Any]( + payload_codec=codec, + initial_response=operation_output, + async_writer=execution_context.transport_request.body, # type: ignore + ) + else: + return AWSOutputEventStream[Any, Any]( + payload_codec=codec, + initial_response=operation_output, + async_reader=execution_context.transport_response.body, # type: ignore + deserializer=event_deserializer, # type: ignore + ) + """); + } }