Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add types and code generation for event stream operation signatures #318

Merged
merged 6 commits into from
Oct 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 71 additions & 3 deletions codegen/smithy-python-codegen-test/model/main.smithy
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,14 @@ service Weather {
operations: [
GetCurrentTime
TestUnionListOperation
StreamAtmosphericConditions
]
}

resource City {
identifiers: { cityId: CityId }
identifiers: {
cityId: CityId
}
read: GetCity
list: ListCities
resources: [
Expand Down Expand Up @@ -56,12 +59,16 @@ union UnionListMember {
}

resource Forecast {
identifiers: { cityId: CityId }
identifiers: {
cityId: CityId
}
read: GetForecast
}

resource CityImage {
identifiers: { cityId: CityId }
identifiers: {
cityId: CityId
}
read: GetCityImage
}

Expand Down Expand Up @@ -622,6 +629,67 @@ union Precipitation {
baz: example.weather.nested.more#Baz
}

@http(method: "POST", uri: "/cities/{cityId}/atmosphere")
operation StreamAtmosphericConditions {
input := {
@required
@httpLabel
cityId: CityId

@required
@httpPayload
stream: AtmosphericConditions
}

output := {
@required
@httpHeader("x-initial-sample-rate")
initialSampleRate: Double

@required
@httpPayload
stream: CollectionDirectives
}
}

@streaming
union AtmosphericConditions {
humidity: HumiditySample
pressure: PressureSample
temperature: TemperatureSample
}

@mixin
structure Sample {
@required
collectionTime: Timestamp
}

structure HumiditySample with [Sample] {
@required
humidity: Double
}

structure PressureSample with [Sample] {
@required
pressure: Double
}

structure TemperatureSample with [Sample] {
@required
temperature: Double
}

@streaming
union CollectionDirectives {
sampleRate: SampleRate
}

structure SampleRate {
@required
samplesPerMinute: Double
}

structure OtherStructure {}

enum StringYesNo {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
import java.util.LinkedHashSet;
import java.util.Set;
import software.amazon.smithy.codegen.core.SymbolReference;
import software.amazon.smithy.model.knowledge.EventStreamIndex;
import software.amazon.smithy.model.knowledge.EventStreamInfo;
import software.amazon.smithy.model.knowledge.ServiceIndex;
import software.amazon.smithy.model.knowledge.TopDownIndex;
import software.amazon.smithy.model.shapes.OperationShape;
Expand Down Expand Up @@ -104,8 +106,14 @@ def __init__(self, config: $1T | None = None, plugins: list[$2T] | None = None):
""", configSymbol, pluginSymbol, writer.consumer(w -> writeDefaultPlugins(w, defaultPlugins)));

var topDownIndex = TopDownIndex.of(context.model());
var eventStreamIndex = EventStreamIndex.of(context.model());
for (OperationShape operation : topDownIndex.getContainedOperations(service)) {
generateOperation(writer, operation);
if (eventStreamIndex.getInputInfo(operation).isPresent()
|| eventStreamIndex.getOutputInfo(operation).isPresent()) {
generateEventStreamOperation(writer, operation);
} else {
generateOperation(writer, operation);
}
}
});

Expand Down Expand Up @@ -695,4 +703,44 @@ private void generateOperation(PythonWriter writer, OperationShape operation) {
}
});
}

private void generateEventStreamOperation(PythonWriter writer, OperationShape operation) {
writer.pushState();
writer.addDependency(SmithyPythonDependency.SMITHY_EVENT_STREAM);
writer.addImports("smithy_event_stream.aio.interfaces", Set.of(
"EventStream", "InputEventStream", "OutputEventStream"));
var operationSymbol = context.symbolProvider().toSymbol(operation);
var pluginSymbol = CodegenUtils.getPluginSymbol(context.settings());

var input = context.model().expectShape(operation.getInputShape());
var inputSymbol = context.symbolProvider().toSymbol(input);

var eventStreamIndex = EventStreamIndex.of(context.model());
var inputStreamSymbol = eventStreamIndex.getInputInfo(operation)
.map(EventStreamInfo::getEventStreamTarget)
.map(target -> context.symbolProvider().toSymbol(target))
.orElse(null);
writer.putContext("inputStream", inputStreamSymbol);

var output = context.model().expectShape(operation.getOutputShape());
var outputSymbol = context.symbolProvider().toSymbol(output);
var outputStreamSymbol = eventStreamIndex.getOutputInfo(operation)
.map(EventStreamInfo::getEventStreamTarget)
.map(target -> context.symbolProvider().toSymbol(target))
.orElse(null);
writer.putContext("outputStream", outputStreamSymbol);

writer.write("""
async def $L(self, input: $T, plugins: list[$T] | None = None) -> EventStream[
${?inputStream}InputEventStream[${inputStream:T}]${/inputStream}\
${^inputStream}None${/inputStream},
${?outputStream}OutputEventStream[${outputStream:T}]${/outputStream}\
${^outputStream}None${/outputStream},
$T
]:
raise NotImplementedError()
""", operationSymbol.getName(), inputSymbol, pluginSymbol, outputSymbol);

writer.popState();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,13 @@ public final class SmithyPythonDependency {
false
);

public static final PythonDependency SMITHY_EVENT_STREAM = new PythonDependency(
"smithy_event_stream",
"==0.0.1",
Type.DEPENDENCY,
false
);

/**
* testing framework used in generated functional tests.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,16 +72,20 @@ public void run() {

writer.write("""
@dataclass
class $L:
${C|}
class $1L:
${2C|}

value: $T
value: $3T

def serialize(self, serializer: ShapeSerializer):
serializer.write_struct($T, self)
serializer.write_struct($4T, self)

def serialize_members(self, serializer: ShapeSerializer):
${C|}
${5C|}

@classmethod
def deserialize(cls, deserializer: ShapeDeserializer) -> Self:
return cls(value=${6C|})

""",
memberSymbol.getName(),
Expand All @@ -90,7 +94,11 @@ def serialize_members(self, serializer: ShapeSerializer):
targetSymbol,
schemaSymbol,
writer.consumer(w -> target.accept(
new MemberSerializerGenerator(context, w, member, "serializer"))));
new MemberSerializerGenerator(context, w, member, "serializer"))),
writer.consumer(w -> target.accept(
new MemberDeserializerGenerator(context, w, member, "deserializer")))

);
}

// Note that the unknown variant doesn't implement __eq__. This is because
Expand Down Expand Up @@ -118,11 +126,15 @@ raise SmithyException("Unknown union variants may not be serialized.")
def serialize_members(self, serializer: ShapeSerializer):
raise SmithyException("Unknown union variants may not be serialized.")

@classmethod
def deserialize(cls, deserializer: ShapeDeserializer) -> Self:
raise NotImplementedError()

""", unknownSymbol.getName());
memberNames.add(unknownSymbol.getName());

shape.getTrait(DocumentationTrait.class).ifPresent(trait -> writer.writeComment(trait.getValue()));
writer.write("type $L = $L\n", parentName, String.join(" | ", memberNames));
shape.getTrait(DocumentationTrait.class).ifPresent(trait -> writer.writeDocs(trait.getValue()));

generateDeserializer();
writer.popState();
Expand Down Expand Up @@ -173,13 +185,10 @@ raise SmithyException("Unions must have exactly one value, but found more than o
private void deserializeMembers() {
int index = 0;
for (MemberShape member : shape.members()) {
var target = model.expectShape(member.getTarget());
writer.write("""
case $L:
self._set_result($T(${C|}))
""", index++, symbolProvider.toSymbol(member), writer.consumer(w ->
target.accept(new MemberDeserializerGenerator(context, writer, member, "de"))
));
self._set_result($T.deserialize(de))
""", index++, symbolProvider.toSymbol(member));
}
}
}
Loading
Loading