From 1331dc5443e6c3f5494eed2a44606062d3888446 Mon Sep 17 00:00:00 2001 From: Zelda Hessler Date: Thu, 21 Sep 2023 12:20:51 -0500 Subject: [PATCH] add support for nullable struct members when generating AWS SDKs (#2916) ## Motivation and Context smithy-rs#1767 aws-sdk-rust#536 ## Description This PR adds support for nullability i.e. much less unwraps will be required when using the AWS SDK. For generic clients, this new behavior can be enabled in codegen by setting `nullabilityCheckMode: "Client"` in their codegen config: ``` "plugins": { "rust-client-codegen": { "codegen": { "includeFluentClient": false, "nullabilityCheckMode": "CLIENT_CAREFUL" }, } ``` ## Testing Ran existing tests ## Checklist - [x] I have updated `CHANGELOG.next.toml` if I made changes to the smithy-rs codegen or runtime crates - [x] I have updated `CHANGELOG.next.toml` if I made changes to the AWS SDK, generated SDK code, or SDK runtime crates ---- _By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice._ --------- Co-authored-by: John DiSanti Co-authored-by: Russell Cohen --- CHANGELOG.next.toml | 15 ++ aws/rust-runtime/aws-config/src/sts/util.rs | 17 +- aws/sdk-adhoc-test/build.gradle.kts | 52 ++--- .../models/required-value-test.smithy | 28 +++ .../smithy/rustsdk/BaseRequestIdDecorator.kt | 22 +- .../s3/S3ExtendedRequestIdDecorator.kt | 6 + .../timestream/TimestreamDecorator.kt | 4 +- .../ec2/EC2MakePrimitivesOptionalTest.kt | 15 +- aws/sdk/build.gradle.kts | 4 +- aws/sdk/integration-tests/Cargo.toml | 2 +- .../dynamodb/tests/movies.rs | 15 +- .../qldbsession/tests/integration.rs | 14 +- codegen-client-test/build.gradle.kts | 2 + .../error-correction-nullability-test.smithy | 128 ++++++++++++ .../client/smithy/ClientCodegenVisitor.kt | 4 +- .../client/smithy/ClientRustSettings.kt | 5 + .../IdempotencyTokenGenerator.kt | 62 ++++-- .../generators/ClientBuilderInstantiator.kt | 4 +- .../smithy/generators/ErrorCorrection.kt | 12 +- .../protocol/ProtocolParserGenerator.kt | 34 ++-- .../protocol/ProtocolTestGenerator.kt | 2 +- .../smithy/protocols/ClientProtocolLoader.kt | 4 +- .../smithy/endpoint/EndpointsDecoratorTest.kt | 1 + .../smithy/generators/ErrorCorrectionTest.kt | 18 +- .../rust/codegen/core/rustlang/RustType.kt | 1 + .../codegen/core/smithy/CoreRustSettings.kt | 4 +- .../rust/codegen/core/smithy/SymbolVisitor.kt | 2 +- .../smithy/generators/BuilderGenerator.kt | 2 + .../core/smithy/generators/Instantiator.kt | 2 +- .../codegen/core/smithy/protocols/AwsJson.kt | 2 - .../parse/XmlBindingTraitParserGenerator.kt | 3 +- .../serialize/QuerySerializerGenerator.kt | 22 +- .../testutil/DefaultBuilderInstantiator.kt | 14 +- .../core/testutil/EventStreamTestModels.kt | 2 +- .../rust/codegen/core/testutil/TestHelpers.kt | 14 +- .../core/rustlang/RustReservedWordsTest.kt | 15 +- .../smithy/generators/BuilderGeneratorTest.kt | 3 +- .../generators/StructureGeneratorTest.kt | 1 - .../parse/JsonParserGeneratorTest.kt | 4 +- .../XmlBindingTraitParserGeneratorTest.kt | 2 +- .../AwsQuerySerializerGeneratorTest.kt | 162 ++++++++++++++- .../Ec2QuerySerializerGeneratorTest.kt | 169 +++++++++++++++- .../serialize/JsonSerializerGeneratorTest.kt | 190 +++++++++++++++++- .../XmlBindingTraitSerializerGeneratorTest.kt | 180 ++++++++++++++++- .../generators/protocol/ServerProtocol.kt | 2 +- examples/Cargo.toml | 3 +- examples/pokemon-service/tests/simple.rs | 12 +- .../tests/simple_integration_test.rs | 8 +- rust-runtime/aws-smithy-types/src/blob.rs | 2 +- 49 files changed, 1110 insertions(+), 181 deletions(-) create mode 100644 aws/sdk-adhoc-test/models/required-value-test.smithy create mode 100644 codegen-client-test/model/error-correction-nullability-test.smithy diff --git a/CHANGELOG.next.toml b/CHANGELOG.next.toml index eb146026fc..5468952f15 100644 --- a/CHANGELOG.next.toml +++ b/CHANGELOG.next.toml @@ -34,6 +34,21 @@ references = ["smithy-rs#2911"] meta = { "breaking" = true, "tada" = false, "bug" = false } author = "Velfi" +[[aws-sdk-rust]] +message = "Struct members modeled as required are no longer wrapped in `Option`s [when possible](https://smithy.io/2.0/spec/aggregate-types.html#structure-member-optionality). For upgrade guidance and more info, see [here](https://github.com/awslabs/smithy-rs/discussions/2929)." +references = ["smithy-rs#2916", "aws-sdk-rust#536"] +meta = { "breaking" = true, "tada" = true, "bug" = false } +author = "Velfi" + +[[smithy-rs]] +message = """ +Support for Smithy IDLv2 nullability is now enabled by default. You can maintain the old behavior by setting `nullabilityCheckMode: "CLIENT_ZERO_VALUE_V1" in your codegen config. +For upgrade guidance and more info, see [here](https://github.com/awslabs/smithy-rs/discussions/2929). +""" +references = ["smithy-rs#2916", "smithy-rs#1767"] +meta = { "breaking" = false, "tada" = true, "bug" = false, "target" = "client"} +author = "Velfi" + [[aws-sdk-rust]] message = """ All versions of SigningParams have been updated to contain an [`Identity`](https://docs.rs/aws-smithy-runtime-api/latest/aws_smithy_runtime_api/client/identity/struct.Identity.html) diff --git a/aws/rust-runtime/aws-config/src/sts/util.rs b/aws/rust-runtime/aws-config/src/sts/util.rs index bc6151985d..e215204f84 100644 --- a/aws/rust-runtime/aws-config/src/sts/util.rs +++ b/aws/rust-runtime/aws-config/src/sts/util.rs @@ -17,24 +17,15 @@ pub(crate) fn into_credentials( ) -> provider::Result { let sts_credentials = sts_credentials .ok_or_else(|| CredentialsError::unhandled("STS credentials must be defined"))?; - let expiration = SystemTime::try_from( - sts_credentials - .expiration - .ok_or_else(|| CredentialsError::unhandled("missing expiration"))?, - ) - .map_err(|_| { + let expiration = SystemTime::try_from(sts_credentials.expiration).map_err(|_| { CredentialsError::unhandled( "credential expiration time cannot be represented by a SystemTime", ) })?; Ok(AwsCredentials::new( - sts_credentials - .access_key_id - .ok_or_else(|| CredentialsError::unhandled("access key id missing from result"))?, - sts_credentials - .secret_access_key - .ok_or_else(|| CredentialsError::unhandled("secret access token missing"))?, - sts_credentials.session_token, + sts_credentials.access_key_id, + sts_credentials.secret_access_key, + Some(sts_credentials.session_token), Some(expiration), provider_name, )) diff --git a/aws/sdk-adhoc-test/build.gradle.kts b/aws/sdk-adhoc-test/build.gradle.kts index 907c8f9b87..eb6c8cb76a 100644 --- a/aws/sdk-adhoc-test/build.gradle.kts +++ b/aws/sdk-adhoc-test/build.gradle.kts @@ -35,40 +35,46 @@ dependencies { implementation("software.amazon.smithy:smithy-aws-protocol-tests:$smithyVersion") implementation("software.amazon.smithy:smithy-protocol-test-traits:$smithyVersion") implementation("software.amazon.smithy:smithy-aws-traits:$smithyVersion") + implementation("software.amazon.smithy:smithy-model:$smithyVersion") } -val allCodegenTests = listOf( - CodegenTest( - "com.amazonaws.apigateway#BackplaneControlService", - "apigateway", - imports = listOf("models/apigateway-rules.smithy"), +fun getNullabilityCheckMode(): String = properties.get("nullability.check.mode") ?: "CLIENT_CAREFUL" + +fun baseTest(service: String, module: String, imports: List = listOf()): CodegenTest { + return CodegenTest( + service = service, + module = module, + imports = imports, + extraCodegenConfig = """ + "includeFluentClient": false, + "nullabilityCheckMode": "${getNullabilityCheckMode()}" + """, extraConfig = """ - , - "codegen": { - "includeFluentClient": false - }, - "customizationConfig": { + , "customizationConfig": { "awsSdk": { - "generateReadme": false + "generateReadme": false, + "requireEndpointResolver": false } } """, + ) +} + +val allCodegenTests = listOf( + baseTest( + "com.amazonaws.apigateway#BackplaneControlService", + "apigateway", + imports = listOf("models/apigateway-rules.smithy"), ), - CodegenTest( + baseTest( "com.amazonaws.testservice#TestService", "endpoint-test-service", imports = listOf("models/single-static-endpoint.smithy"), - extraConfig = """ - , - "codegen": { - "includeFluentClient": false - }, - "customizationConfig": { - "awsSdk": { - "generateReadme": false - } - } - """, + ), + baseTest( + "com.amazonaws.testservice#RequiredValues", + "required-values", + imports = listOf("models/required-value-test.smithy"), ), ) diff --git a/aws/sdk-adhoc-test/models/required-value-test.smithy b/aws/sdk-adhoc-test/models/required-value-test.smithy new file mode 100644 index 0000000000..efb90d9250 --- /dev/null +++ b/aws/sdk-adhoc-test/models/required-value-test.smithy @@ -0,0 +1,28 @@ +$version: "1.0" + +namespace com.amazonaws.testservice + +use aws.api#service +use aws.protocols#restJson1 + +@restJson1 +@title("Test Service") +@service(sdkId: "Test") +@aws.auth#sigv4(name: "test-service") +service RequiredValues { + operations: [TestOperation] +} + +@http(method: "GET", uri: "/") +operation TestOperation { + errors: [Error] +} + +@error("client") +structure Error { + @required + requestId: String + + @required + message: String +} diff --git a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/BaseRequestIdDecorator.kt b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/BaseRequestIdDecorator.kt index 9b8adeddab..5025ffedab 100644 --- a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/BaseRequestIdDecorator.kt +++ b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/BaseRequestIdDecorator.kt @@ -5,7 +5,9 @@ package software.amazon.smithy.rustsdk +import software.amazon.smithy.model.shapes.MemberShape import software.amazon.smithy.model.shapes.OperationShape +import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext import software.amazon.smithy.rust.codegen.client.smithy.ClientRustModule import software.amazon.smithy.rust.codegen.client.smithy.customize.ClientCodegenDecorator @@ -19,6 +21,7 @@ import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.core.rustlang.writable import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType.Companion.preludeScope import software.amazon.smithy.rust.codegen.core.smithy.RustCrate import software.amazon.smithy.rust.codegen.core.smithy.generators.BuilderCustomization import software.amazon.smithy.rust.codegen.core.smithy.generators.BuilderSection @@ -26,6 +29,7 @@ import software.amazon.smithy.rust.codegen.core.smithy.generators.StructureCusto import software.amazon.smithy.rust.codegen.core.smithy.generators.StructureSection import software.amazon.smithy.rust.codegen.core.smithy.generators.error.ErrorImplCustomization import software.amazon.smithy.rust.codegen.core.smithy.generators.error.ErrorImplSection +import software.amazon.smithy.rust.codegen.core.smithy.isOptional import software.amazon.smithy.rust.codegen.core.smithy.traits.SyntheticOutputTrait import software.amazon.smithy.rust.codegen.core.util.hasTrait @@ -72,6 +76,10 @@ abstract class BaseRequestIdDecorator : ClientCodegenDecorator { } } + open fun asMemberShape(container: StructureShape): MemberShape? { + return container.members().firstOrNull { member -> member.memberName.lowercase() == "requestid" } + } + private inner class RequestIdOperationCustomization(private val codegenContext: ClientCodegenContext) : OperationCustomization() { override fun section(section: OperationSection): Writable = writable { @@ -82,12 +90,14 @@ abstract class BaseRequestIdDecorator : ClientCodegenDecorator { "apply_to_error" to applyToError(codegenContext), ) } + is OperationSection.MutateOutput -> { rust( "output._set_$fieldName(#T::$accessorFunctionName(${section.responseHeadersName}).map(str::to_string));", accessorTrait(codegenContext), ) } + is OperationSection.BeforeParseResponse -> { rustTemplate( "#{tracing}::debug!($fieldName = ?#{trait}::$accessorFunctionName(${section.responseName}));", @@ -95,6 +105,7 @@ abstract class BaseRequestIdDecorator : ClientCodegenDecorator { "trait" to accessorTrait(codegenContext), ) } + else -> {} } } @@ -123,8 +134,17 @@ abstract class BaseRequestIdDecorator : ClientCodegenDecorator { rustBlock("fn $accessorFunctionName(&self) -> Option<&str>") { rustBlock("match self") { section.allErrors.forEach { error -> + val optional = asMemberShape(error)?.let { member -> + codegenContext.symbolProvider.toSymbol(member).isOptional() + } ?: true + val wrapped = writable { + when (optional) { + false -> rustTemplate("#{Some}(e.$accessorFunctionName())", *preludeScope) + true -> rustTemplate("e.$accessorFunctionName()") + } + } val sym = codegenContext.symbolProvider.toSymbol(error) - rust("Self::${sym.name}(e) => e.$accessorFunctionName(),") + rust("Self::${sym.name}(e) => #T,", wrapped) } rust("Self::Unhandled(e) => e.$accessorFunctionName(),") } diff --git a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/customize/s3/S3ExtendedRequestIdDecorator.kt b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/customize/s3/S3ExtendedRequestIdDecorator.kt index 6b117b60da..3cd223ae44 100644 --- a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/customize/s3/S3ExtendedRequestIdDecorator.kt +++ b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/customize/s3/S3ExtendedRequestIdDecorator.kt @@ -5,6 +5,8 @@ package software.amazon.smithy.rustsdk.customize.s3 +import software.amazon.smithy.model.shapes.MemberShape +import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType import software.amazon.smithy.rustsdk.BaseRequestIdDecorator @@ -17,6 +19,10 @@ class S3ExtendedRequestIdDecorator : BaseRequestIdDecorator() { override val fieldName: String = "extended_request_id" override val accessorFunctionName: String = "extended_request_id" + override fun asMemberShape(container: StructureShape): MemberShape? { + return null + } + private val requestIdModule: RuntimeType = RuntimeType.forInlineDependency(InlineAwsDependency.forRustFile("s3_request_id")) diff --git a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/customize/timestream/TimestreamDecorator.kt b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/customize/timestream/TimestreamDecorator.kt index 503d8ae4f3..74c19ea048 100644 --- a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/customize/timestream/TimestreamDecorator.kt +++ b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/customize/timestream/TimestreamDecorator.kt @@ -60,12 +60,12 @@ class TimestreamDecorator : ClientCodegenDecorator { client.describe_endpoints().send().await.map_err(|e| { #{ResolveEndpointError}::from_source("failed to call describe_endpoints", e) })?; - let endpoint = describe_endpoints.endpoints().unwrap().get(0).unwrap(); + let endpoint = describe_endpoints.endpoints().get(0).unwrap(); let expiry = client.config().time_source().expect("checked when ep discovery was enabled").now() + #{Duration}::from_secs(endpoint.cache_period_in_minutes() as u64 * 60); Ok(( #{Endpoint}::builder() - .url(format!("https://{}", endpoint.address().unwrap())) + .url(format!("https://{}", endpoint.address())) .build(), expiry, )) diff --git a/aws/sdk-codegen/src/test/kotlin/software/amazon/smithy/rustsdk/customize/ec2/EC2MakePrimitivesOptionalTest.kt b/aws/sdk-codegen/src/test/kotlin/software/amazon/smithy/rustsdk/customize/ec2/EC2MakePrimitivesOptionalTest.kt index 51eba7fa86..ae919497f0 100644 --- a/aws/sdk-codegen/src/test/kotlin/software/amazon/smithy/rustsdk/customize/ec2/EC2MakePrimitivesOptionalTest.kt +++ b/aws/sdk-codegen/src/test/kotlin/software/amazon/smithy/rustsdk/customize/ec2/EC2MakePrimitivesOptionalTest.kt @@ -6,15 +6,22 @@ package software.amazon.smithy.rustsdk.customize.ec2 import io.kotest.matchers.shouldBe -import org.junit.jupiter.api.Test +import org.junit.jupiter.params.ParameterizedTest +import org.junit.jupiter.params.provider.CsvSource import software.amazon.smithy.model.knowledge.NullableIndex import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel import software.amazon.smithy.rust.codegen.core.util.lookup internal class EC2MakePrimitivesOptionalTest { - @Test - fun `primitive shapes are boxed`() { + @ParameterizedTest + @CsvSource( + "CLIENT", + "CLIENT_CAREFUL", + "CLIENT_ZERO_VALUE_V1", + "CLIENT_ZERO_VALUE_V1_NO_INPUT", + ) + fun `primitive shapes are boxed`(nullabilityCheckMode: NullableIndex.CheckMode) { val baseModel = """ namespace test structure Primitives { @@ -36,7 +43,7 @@ internal class EC2MakePrimitivesOptionalTest { val nullableIndex = NullableIndex(model) val struct = model.lookup("test#Primitives") struct.members().forEach { - nullableIndex.isMemberNullable(it, NullableIndex.CheckMode.CLIENT_ZERO_VALUE_V1) shouldBe true + nullableIndex.isMemberNullable(it, nullabilityCheckMode) shouldBe true } } } diff --git a/aws/sdk/build.gradle.kts b/aws/sdk/build.gradle.kts index 5ac1ef092c..9a2037642d 100644 --- a/aws/sdk/build.gradle.kts +++ b/aws/sdk/build.gradle.kts @@ -60,6 +60,7 @@ val crateVersioner by lazy { aws.sdk.CrateVersioner.defaultFor(rootProject, prop fun getRustMSRV(): String = properties.get("rust.msrv") ?: throw Exception("Rust MSRV missing") fun getPreviousReleaseVersionManifestPath(): String? = properties.get("aws.sdk.previous.release.versions.manifest") +fun getNullabilityCheckMode(): String = properties.get("nullability.check.mode") ?: "CLIENT_CAREFUL" fun loadServiceMembership(): Membership { val membershipOverride = properties.get("aws.services")?.let { parseMembership(it) } @@ -103,7 +104,8 @@ fun generateSmithyBuild(services: AwsServices): String { "renameErrors": false, "debugMode": $debugMode, "eventStreamAllowList": [$eventStreamAllowListMembers], - "enableUserConfigurableRuntimePlugins": false + "enableUserConfigurableRuntimePlugins": false, + "nullabilityCheckMode": "${getNullabilityCheckMode()}" }, "service": "${service.service}", "module": "$moduleName", diff --git a/aws/sdk/integration-tests/Cargo.toml b/aws/sdk/integration-tests/Cargo.toml index 284bc1bcb1..f18a443839 100644 --- a/aws/sdk/integration-tests/Cargo.toml +++ b/aws/sdk/integration-tests/Cargo.toml @@ -14,7 +14,7 @@ members = [ "s3", "s3control", "sts", - "transcribestreaming", "timestreamquery", + "transcribestreaming", "webassembly", ] diff --git a/aws/sdk/integration-tests/dynamodb/tests/movies.rs b/aws/sdk/integration-tests/dynamodb/tests/movies.rs index a3eaa244ad..7b045c6f5b 100644 --- a/aws/sdk/integration-tests/dynamodb/tests/movies.rs +++ b/aws/sdk/integration-tests/dynamodb/tests/movies.rs @@ -28,31 +28,36 @@ async fn create_table(client: &Client, table_name: &str) { KeySchemaElement::builder() .attribute_name("year") .key_type(KeyType::Hash) - .build(), + .build() + .unwrap(), ) .key_schema( KeySchemaElement::builder() .attribute_name("title") .key_type(KeyType::Range) - .build(), + .build() + .unwrap(), ) .attribute_definitions( AttributeDefinition::builder() .attribute_name("year") .attribute_type(ScalarAttributeType::N) - .build(), + .build() + .unwrap(), ) .attribute_definitions( AttributeDefinition::builder() .attribute_name("title") .attribute_type(ScalarAttributeType::S) - .build(), + .build() + .unwrap(), ) .provisioned_throughput( ProvisionedThroughput::builder() .read_capacity_units(10) .write_capacity_units(10) - .build(), + .build() + .unwrap(), ) .send() .await diff --git a/aws/sdk/integration-tests/qldbsession/tests/integration.rs b/aws/sdk/integration-tests/qldbsession/tests/integration.rs index b73dea2fc7..816f3cd8fb 100644 --- a/aws/sdk/integration-tests/qldbsession/tests/integration.rs +++ b/aws/sdk/integration-tests/qldbsession/tests/integration.rs @@ -3,19 +3,14 @@ * SPDX-License-Identifier: Apache-2.0 */ -use aws_sdk_qldbsession as qldbsession; +use aws_sdk_qldbsession::config::{Config, Credentials, Region}; +use aws_sdk_qldbsession::types::StartSessionRequest; +use aws_sdk_qldbsession::Client; use aws_smithy_client::test_connection::TestConnection; use aws_smithy_http::body::SdkBody; use http::Uri; -use qldbsession::config::{Config, Credentials, Region}; -use qldbsession::types::StartSessionRequest; -use qldbsession::Client; use std::time::{Duration, UNIX_EPOCH}; -// TODO(DVR): having the full HTTP requests right in the code is a bit gross, consider something -// like https://github.com/davidbarsky/sigv4/blob/master/aws-sigv4/src/lib.rs#L283-L315 to store -// the requests/responses externally - #[tokio::test] async fn signv4_use_correct_service_name() { let conn = TestConnection::new(vec![( @@ -46,7 +41,8 @@ async fn signv4_use_correct_service_name() { .start_session( StartSessionRequest::builder() .ledger_name("not-real-ledger") - .build(), + .build() + .unwrap(), ) .customize() .await diff --git a/codegen-client-test/build.gradle.kts b/codegen-client-test/build.gradle.kts index d7173e31d7..0d39e4b754 100644 --- a/codegen-client-test/build.gradle.kts +++ b/codegen-client-test/build.gradle.kts @@ -106,6 +106,8 @@ val allCodegenTests = listOf( "pokemon-service-awsjson-client", dependsOn = listOf("pokemon-awsjson.smithy", "pokemon-common.smithy"), ), + ClientTest("aws.protocoltests.json#RequiredValueJson", "required-values-json"), + ClientTest("aws.protocoltests.json#RequiredValueXml", "required-values-xml"), ).map(ClientTest::toCodegenTest) project.registerGenerateSmithyBuildTask(rootProject, pluginName, allCodegenTests) diff --git a/codegen-client-test/model/error-correction-nullability-test.smithy b/codegen-client-test/model/error-correction-nullability-test.smithy new file mode 100644 index 0000000000..8a125d3004 --- /dev/null +++ b/codegen-client-test/model/error-correction-nullability-test.smithy @@ -0,0 +1,128 @@ +$version: "2.0" + + +namespace aws.protocoltests.json + +use aws.protocols#awsJson1_0 +use aws.protocols#restXml +use smithy.test#httpResponseTests + +@awsJson1_0 +service RequiredValueJson { + operations: [SayHello], + version: "1" +} + + +@restXml +service RequiredValueXml { + operations: [SayHelloXml], + version: "1" +} + +@error("client") +structure Error { + @required + requestId: String + + @required + message: String +} + +@http(method: "POST", uri: "/") +operation SayHello { output: TestOutputDocument, errors: [Error] } + +@http(method: "POST", uri: "/") +operation SayHelloXml { output: TestOutput, errors: [Error] } + +structure TestOutputDocument with [TestStruct] { innerField: Nested, @required document: Document } +structure TestOutput with [TestStruct] { innerField: Nested } + +@mixin +structure TestStruct { + @required + foo: String, + @required + byteValue: Byte, + @required + listValue: StringList, + @required + mapValue: ListMap, + @required + doubleListValue: DoubleList + @required + nested: Nested + @required + blob: Blob + @required + enum: Enum + @required + union: U + notRequired: String +} + +enum Enum { + A, + B, + C +} +union U { + A: Integer, + B: String, + C: Unit +} + +structure Nested { + @required + a: String +} + +list StringList { + member: String +} + +list DoubleList { + member: StringList +} + +map ListMap { + key: String, + value: StringList +} + +apply SayHello @httpResponseTests([{ + id: "error_recovery_json", + protocol: awsJson1_0, + params: { + union: { A: 5 }, + enum: "A", + foo: "", + byteValue: 0, + blob: "", + listValue: [], + mapValue: {}, + doubleListValue: [] + document: null + nested: { a: "" } + }, + code: 200, + body: "{\"union\": { \"A\": 5 }, \"enum\": \"A\" }" + }]) + +apply SayHelloXml @httpResponseTests([{ + id: "error_recovery_xml", + protocol: restXml, + params: { + union: { A: 5 }, + enum: "A", + foo: "", + byteValue: 0, + blob: "", + listValue: [], + mapValue: {}, + doubleListValue: [] + nested: { a: "" } + }, + code: 200, + body: "5A" + }]) diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/ClientCodegenVisitor.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/ClientCodegenVisitor.kt index 987a1df85c..efeef74e4b 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/ClientCodegenVisitor.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/ClientCodegenVisitor.kt @@ -7,7 +7,6 @@ package software.amazon.smithy.rust.codegen.client.smithy import software.amazon.smithy.build.PluginContext import software.amazon.smithy.model.Model -import software.amazon.smithy.model.knowledge.NullableIndex import software.amazon.smithy.model.shapes.OperationShape import software.amazon.smithy.model.shapes.ServiceShape import software.amazon.smithy.model.shapes.Shape @@ -77,10 +76,11 @@ class ClientCodegenVisitor( val rustSymbolProviderConfig = RustSymbolProviderConfig( runtimeConfig = settings.runtimeConfig, renameExceptions = settings.codegenConfig.renameExceptions, - nullabilityCheckMode = NullableIndex.CheckMode.CLIENT_ZERO_VALUE_V1, + nullabilityCheckMode = settings.codegenConfig.nullabilityCheckMode, moduleProvider = ClientModuleProvider, nameBuilderFor = { symbol -> "${symbol.name}Builder" }, ) + val baseModel = baselineTransform(context.model) val untransformedService = settings.getService(baseModel) val (protocol, generator) = ClientProtocolLoader( diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/ClientRustSettings.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/ClientRustSettings.kt index dc6fd4f028..acc0a59184 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/ClientRustSettings.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/ClientRustSettings.kt @@ -6,6 +6,7 @@ package software.amazon.smithy.rust.codegen.client.smithy import software.amazon.smithy.model.Model +import software.amazon.smithy.model.knowledge.NullableIndex import software.amazon.smithy.model.node.ObjectNode import software.amazon.smithy.model.shapes.ShapeId import software.amazon.smithy.rust.codegen.core.smithy.CODEGEN_SETTINGS @@ -81,6 +82,7 @@ data class ClientRustSettings( data class ClientCodegenConfig( override val formatTimeoutSeconds: Int = defaultFormatTimeoutSeconds, override val debugMode: Boolean = defaultDebugMode, + val nullabilityCheckMode: NullableIndex.CheckMode = NullableIndex.CheckMode.CLIENT, val renameExceptions: Boolean = defaultRenameExceptions, val includeFluentClient: Boolean = defaultIncludeFluentClient, val addMessageToErrors: Boolean = defaultAddMessageToErrors, @@ -99,6 +101,7 @@ data class ClientCodegenConfig( private val defaultEventStreamAllowList: Set = emptySet() private const val defaultIncludeEndpointUrlConfig = true private const val defaultEnableUserConfigurableRuntimePlugins = true + private const val defaultNullabilityCheckMode = "CLIENT" fun fromCodegenConfigAndNode(coreCodegenConfig: CoreCodegenConfig, node: Optional) = if (node.isPresent) { @@ -115,11 +118,13 @@ data class ClientCodegenConfig( addMessageToErrors = node.get().getBooleanMemberOrDefault("addMessageToErrors", defaultAddMessageToErrors), includeEndpointUrlConfig = node.get().getBooleanMemberOrDefault("includeEndpointUrlConfig", defaultIncludeEndpointUrlConfig), enableUserConfigurableRuntimePlugins = node.get().getBooleanMemberOrDefault("enableUserConfigurableRuntimePlugins", defaultEnableUserConfigurableRuntimePlugins), + nullabilityCheckMode = NullableIndex.CheckMode.valueOf(node.get().getStringMemberOrDefault("nullabilityCheckMode", defaultNullabilityCheckMode)), ) } else { ClientCodegenConfig( formatTimeoutSeconds = coreCodegenConfig.formatTimeoutSeconds, debugMode = coreCodegenConfig.debugMode, + nullabilityCheckMode = NullableIndex.CheckMode.valueOf(defaultNullabilityCheckMode), ) } } diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customizations/IdempotencyTokenGenerator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customizations/IdempotencyTokenGenerator.kt index eead91bbe0..dd1bcfef08 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customizations/IdempotencyTokenGenerator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customizations/IdempotencyTokenGenerator.kt @@ -19,6 +19,7 @@ import software.amazon.smithy.rust.codegen.core.rustlang.toType import software.amazon.smithy.rust.codegen.core.rustlang.writable import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType.Companion.preludeScope +import software.amazon.smithy.rust.codegen.core.smithy.isOptional import software.amazon.smithy.rust.codegen.core.util.findMemberWithTrait import software.amazon.smithy.rust.codegen.core.util.inputShape @@ -37,28 +38,51 @@ class IdempotencyTokenGenerator( return emptySection } val memberName = symbolProvider.toMemberName(idempotencyTokenMember) + val codegenScope = arrayOf( + *preludeScope, + "Input" to symbolProvider.toSymbol(inputShape), + "IdempotencyTokenRuntimePlugin" to + InlineDependency.forRustFile( + RustModule.pubCrate("client_idempotency_token", parent = ClientRustModule.root), + "/inlineable/src/client_idempotency_token.rs", + CargoDependency.smithyRuntimeApi(runtimeConfig), + CargoDependency.smithyTypes(runtimeConfig), + ).toType().resolve("IdempotencyTokenRuntimePlugin"), + ) + return when (section) { is OperationSection.AdditionalRuntimePlugins -> writable { section.addOperationRuntimePlugin(this) { - rustTemplate( - """ - #{IdempotencyTokenRuntimePlugin}::new(|token_provider, input| { - let input: &mut #{Input} = input.downcast_mut().expect("correct type"); - if input.$memberName.is_none() { - input.$memberName = #{Some}(token_provider.make_idempotency_token()); - } - }) - """, - *preludeScope, - "Input" to symbolProvider.toSymbol(inputShape), - "IdempotencyTokenRuntimePlugin" to - InlineDependency.forRustFile( - RustModule.pubCrate("client_idempotency_token", parent = ClientRustModule.root), - "/inlineable/src/client_idempotency_token.rs", - CargoDependency.smithyRuntimeApi(runtimeConfig), - CargoDependency.smithyTypes(runtimeConfig), - ).toType().resolve("IdempotencyTokenRuntimePlugin"), - ) + if (symbolProvider.toSymbol(idempotencyTokenMember).isOptional()) { + // An idempotency token is optional. If the user didn't specify a token + // then we'll generate one and set it. + rustTemplate( + """ + #{IdempotencyTokenRuntimePlugin}::new(|token_provider, input| { + let input: &mut #{Input} = input.downcast_mut().expect("correct type"); + if input.$memberName.is_none() { + input.$memberName = #{Some}(token_provider.make_idempotency_token()); + } + }) + """, + *codegenScope, + ) + } else { + // An idempotency token is required, but it'll be set to an empty string if + // the user didn't specify one. If that's the case, then we'll generate one + // and set it. + rustTemplate( + """ + #{IdempotencyTokenRuntimePlugin}::new(|token_provider, input| { + let input: &mut #{Input} = input.downcast_mut().expect("correct type"); + if input.$memberName.is_empty() { + input.$memberName = token_provider.make_idempotency_token(); + } + }) + """, + *codegenScope, + ) + } } } else -> emptySection diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ClientBuilderInstantiator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ClientBuilderInstantiator.kt index 34d315a3f3..453568c8c5 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ClientBuilderInstantiator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ClientBuilderInstantiator.kt @@ -34,11 +34,11 @@ class ClientBuilderInstantiator(private val clientCodegenContext: ClientCodegenC } if (BuilderGenerator.hasFallibleBuilder(shape, clientCodegenContext.symbolProvider)) { rustTemplate( - "#{builder}.build()#{mapErr}?", + "#{builder}.build()#{mapErr}", "builder" to builderW, "mapErr" to ( mapErr?.map { - rust(".map_err(#T)", it) + rust(".map_err(#T)?", it) } ?: writable { } ), ) diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ErrorCorrection.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ErrorCorrection.kt index 6f212014de..b05c014383 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ErrorCorrection.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ErrorCorrection.kt @@ -58,8 +58,15 @@ private fun ClientCodegenContext.errorCorrectedDefault(member: MemberShape): Wri val instantiator = PrimitiveInstantiator(runtimeConfig, symbolProvider) return writable { when { - target is EnumShape || target.hasTrait() -> rustTemplate(""""no value was set".parse::<#{Shape}>().ok()""", "Shape" to targetSymbol) - target is BooleanShape || target is NumberShape || target is StringShape || target is DocumentShape || target is ListShape || target is MapShape -> rust("Some(Default::default())") + target is EnumShape || target.hasTrait() -> rustTemplate( + """"no value was set".parse::<#{Shape}>().ok()""", + "Shape" to targetSymbol, + ) + + target is BooleanShape || target is NumberShape || target is StringShape || target is DocumentShape || target is ListShape || target is MapShape -> rust( + "Some(Default::default())", + ) + target is StructureShape -> rustTemplate( "{ let builder = #{Builder}::default(); #{instantiate} }", "Builder" to symbolProvider.symbolForBuilder(target), @@ -73,6 +80,7 @@ private fun ClientCodegenContext.errorCorrectedDefault(member: MemberShape): Wri it.plus { rustTemplate(".map(#{Box}::new)", *preludeScope) } }, ) + target is TimestampShape -> instantiator.instantiate(target, Node.from(0)).some()(this) target is BlobShape -> instantiator.instantiate(target, Node.from("")).some()(this) target is UnionShape -> rust("Some(#T::Unknown)", targetSymbol) diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/ProtocolParserGenerator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/ProtocolParserGenerator.kt index 608cd5869a..17632df46f 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/ProtocolParserGenerator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/ProtocolParserGenerator.kt @@ -26,7 +26,6 @@ import software.amazon.smithy.rust.codegen.core.rustlang.writable import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider import software.amazon.smithy.rust.codegen.core.smithy.customize.writeCustomizations -import software.amazon.smithy.rust.codegen.core.smithy.generators.BuilderGenerator import software.amazon.smithy.rust.codegen.core.smithy.generators.setterName import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpBindingDescriptor import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpLocation @@ -131,7 +130,7 @@ class ProtocolParserGenerator( withBlock("Err(match error_code {", "})") { val errors = operationShape.operationErrors(model) errors.forEach { error -> - val errorShape = model.expectShape(error.id, software.amazon.smithy.model.shapes.StructureShape::class.java) + val errorShape = model.expectShape(error.id, StructureShape::class.java) val variantName = symbolProvider.toSymbol(model.expectShape(error.id)).name val errorCode = httpBindingResolver.errorCode(errorShape).dq() withBlock( @@ -139,7 +138,7 @@ class ProtocolParserGenerator( "}),", errorSymbol, ) { - software.amazon.smithy.rust.codegen.core.rustlang.Attribute.AllowUnusedMut.render(this) + Attribute.AllowUnusedMut.render(this) assignment("mut tmp") { rustBlock("") { renderShapeParser( @@ -159,14 +158,18 @@ class ProtocolParserGenerator( ) } } - if (errorShape.errorMessageMember() != null) { - rust( - """ - if tmp.message.is_none() { - tmp.message = _error_message; - } - """, - ) + val errorMessageMember = errorShape.errorMessageMember() + // If the message member is optional and wasn't set, we set a generic error message. + if (errorMessageMember != null) { + if (errorMessageMember.isOptional) { + rust( + """ + if tmp.message.is_none() { + tmp.message = _error_message; + } + """, + ) + } } rust("tmp") } @@ -257,18 +260,15 @@ class ProtocolParserGenerator( } } - val err = if (BuilderGenerator.hasFallibleBuilder(outputShape, symbolProvider)) { - ".map_err(${format(errorSymbol)}::unhandled)?" - } else { - "" + val mapErr = writable { + rust("#T::unhandled", errorSymbol) } writeCustomizations( customizations, OperationSection.MutateOutput(customizations, operationShape, "_response_headers"), ) - - rust("output.build()$err") + codegenContext.builderInstantiator().finalizeBuilder("output", outputShape, mapErr)(this) } /** diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/ProtocolTestGenerator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/ProtocolTestGenerator.kt index e7ccd10507..f2a79906de 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/ProtocolTestGenerator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/ProtocolTestGenerator.kt @@ -435,7 +435,7 @@ class DefaultProtocolTestGenerator( // When we generate a body instead of a stub, drop the trailing `;` and enable the assertion assertOk(rustWriter) { rustWriter.write( - "#T(&body, ${ + "#T(body, ${ rustWriter.escape(body).dq() }, #T::from(${(mediaType ?: "unknown").dq()}))", RT.protocolTest(rc, "validate_body"), diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/ClientProtocolLoader.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/ClientProtocolLoader.kt index d7cdd7594c..cb9c7db7e5 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/ClientProtocolLoader.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/ClientProtocolLoader.kt @@ -63,9 +63,9 @@ private class ClientAwsJsonFactory(private val version: AwsJsonVersion) : ProtocolGeneratorFactory { override fun protocol(codegenContext: ClientCodegenContext): Protocol = if (compatibleWithAwsQuery(codegenContext.serviceShape, version)) { - AwsQueryCompatible(codegenContext, AwsJson(codegenContext, version, codegenContext.builderInstantiator())) + AwsQueryCompatible(codegenContext, AwsJson(codegenContext, version)) } else { - AwsJson(codegenContext, version, codegenContext.builderInstantiator()) + AwsJson(codegenContext, version) } override fun buildProtocolGenerator(codegenContext: ClientCodegenContext): OperationGenerator = diff --git a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/EndpointsDecoratorTest.kt b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/EndpointsDecoratorTest.kt index d6878f1b96..390e9d2880 100644 --- a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/EndpointsDecoratorTest.kt +++ b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/EndpointsDecoratorTest.kt @@ -109,6 +109,7 @@ class EndpointsDecoratorTest { input: TestOperationInput } + @input structure TestOperationInput { @contextParam(name: "Bucket") @required diff --git a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ErrorCorrectionTest.kt b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ErrorCorrectionTest.kt index 5849010312..3cdc523d9c 100644 --- a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ErrorCorrectionTest.kt +++ b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ErrorCorrectionTest.kt @@ -97,24 +97,24 @@ class ErrorCorrectionTest { rustTemplate( """ let builder = #{correct_errors}(#{Shape}::builder().foo("abcd")); - let shape = builder.build(); + let shape = builder.build().unwrap(); // don't override a field already set - assert_eq!(shape.foo(), Some("abcd")); + assert_eq!(shape.foo(), "abcd"); // set nested fields - assert_eq!(shape.nested().unwrap().a(), Some("")); + assert_eq!(shape.nested().a(), ""); // don't default non-required fields assert_eq!(shape.not_required(), None); // set defaults for everything else - assert_eq!(shape.blob().unwrap().as_ref(), &[]); + assert_eq!(shape.blob().as_ref(), &[]); - assert_eq!(shape.list_value(), Some(&[][..])); - assert!(shape.map_value().unwrap().is_empty()); - assert_eq!(shape.double_list_value(), Some(&[][..])); + assert!(shape.list_value().is_empty()); + assert!(shape.map_value().is_empty()); + assert!(shape.double_list_value().is_empty()); // enums and unions become unknown variants - assert!(matches!(shape.r##enum(), Some(crate::types::Enum::Unknown(_)))); - assert!(shape.union().unwrap().is_unknown()); + assert!(matches!(shape.r##enum(), crate::types::Enum::Unknown(_))); + assert!(shape.union().is_unknown()); """, *codegenCtx, ) diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/RustType.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/RustType.kt index 36e0dbdb6a..1ada2d199c 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/RustType.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/RustType.kt @@ -494,6 +494,7 @@ class Attribute(val inner: Writable, val isDeriveHelper: Boolean = false) { } companion object { + val AllowNeedlessQuestionMark = Attribute(allow("clippy::needless_question_mark")) val AllowClippyBoxedLocal = Attribute(allow("clippy::boxed_local")) val AllowClippyLetAndReturn = Attribute(allow("clippy::let_and_return")) val AllowClippyNeedlessBorrow = Attribute(allow("clippy::needless_borrow")) diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/CoreRustSettings.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/CoreRustSettings.kt index 72db4046d7..b477ab5607 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/CoreRustSettings.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/CoreRustSettings.kt @@ -49,8 +49,8 @@ open class CoreCodegenConfig( fun fromNode(node: Optional): CoreCodegenConfig = if (node.isPresent) { CoreCodegenConfig( - node.get().getNumberMemberOrDefault("formatTimeoutSeconds", defaultFormatTimeoutSeconds).toInt(), - node.get().getBooleanMemberOrDefault("debugMode", defaultDebugMode), + formatTimeoutSeconds = node.get().getNumberMemberOrDefault("formatTimeoutSeconds", defaultFormatTimeoutSeconds).toInt(), + debugMode = node.get().getBooleanMemberOrDefault("debugMode", defaultDebugMode), ) } else { CoreCodegenConfig( diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/SymbolVisitor.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/SymbolVisitor.kt index 31659a3744..ec83941c8f 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/SymbolVisitor.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/SymbolVisitor.kt @@ -192,7 +192,7 @@ open class SymbolVisitor( val rustType = RustType.Opaque(shape.contextName(serviceShape).toPascalCase()) symbolBuilder(shape, rustType).locatedIn(moduleForShape(shape)).build() } else { - simpleShape(shape) + symbolBuilder(shape, RustType.String).build() } } diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/BuilderGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/BuilderGenerator.kt index 9a952ebe40..26a607f497 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/BuilderGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/BuilderGenerator.kt @@ -87,6 +87,8 @@ fun MemberShape.enforceRequired( return field } val shape = this + val isOptional = codegenContext.symbolProvider.toSymbol(shape).isOptional() + val field = field.letIf(!isOptional) { field.map { rust("Some(#T)", it) } } val error = OperationBuildError(codegenContext.runtimeConfig).missingField( codegenContext.symbolProvider.toMemberName(shape), "A required field was not set", ) diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/Instantiator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/Instantiator.kt index 799aa1a000..a01ad09545 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/Instantiator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/Instantiator.kt @@ -141,7 +141,7 @@ open class Instantiator( private fun renderMember(writer: RustWriter, memberShape: MemberShape, data: Node, ctx: Ctx) { val targetShape = model.expectShape(memberShape.target) val symbol = symbolProvider.toSymbol(memberShape) - if (data is NullNode) { + if (data is NullNode && !targetShape.isDocumentShape) { check(symbol.isOptional()) { "A null node was provided for $memberShape but the symbol was not optional. This is invalid input data." } diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/AwsJson.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/AwsJson.kt index 0b3178095c..0a53422552 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/AwsJson.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/AwsJson.kt @@ -18,7 +18,6 @@ import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType -import software.amazon.smithy.rust.codegen.core.smithy.generators.BuilderInstantiator import software.amazon.smithy.rust.codegen.core.smithy.generators.serializationError import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.JsonParserGenerator import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.StructuredDataParserGenerator @@ -123,7 +122,6 @@ class AwsJsonSerializerGenerator( open class AwsJson( val codegenContext: CodegenContext, val awsJsonVersion: AwsJsonVersion, - val builderInstantiator: BuilderInstantiator, ) : Protocol { private val runtimeConfig = codegenContext.runtimeConfig private val errorScope = arrayOf( diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/XmlBindingTraitParserGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/XmlBindingTraitParserGenerator.kt index b7205562d1..03caf527a8 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/XmlBindingTraitParserGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/XmlBindingTraitParserGenerator.kt @@ -476,6 +476,7 @@ class XmlBindingTraitParserGenerator( private fun RustWriter.parseStructure(shape: StructureShape, ctx: Ctx) { val symbol = symbolProvider.toSymbol(shape) val nestedParser = protocolFunctions.deserializeFn(shape) { fnName -> + Attribute.AllowNeedlessQuestionMark.render(this) rustBlockTemplate( "pub fn $fnName(decoder: &mut #{ScopedDecoder}) -> Result<#{Shape}, #{XmlDecodeError}>", *codegenScope, "Shape" to symbol, @@ -493,7 +494,7 @@ class XmlBindingTraitParserGenerator( shape, mapErr = { rustTemplate( - """.map_err(|_|#{XmlDecodeError}::custom("missing field"))?""", + """|_|#{XmlDecodeError}::custom("missing field")""", *codegenScope, ) }, diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/QuerySerializerGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/QuerySerializerGenerator.kt index 1a5b37bc74..c6adbd3bec 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/QuerySerializerGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/QuerySerializerGenerator.kt @@ -5,6 +5,7 @@ package software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize +import software.amazon.smithy.model.knowledge.NullableIndex import software.amazon.smithy.model.shapes.BlobShape import software.amazon.smithy.model.shapes.BooleanShape import software.amazon.smithy.model.shapes.CollectionShape @@ -44,13 +45,14 @@ import software.amazon.smithy.rust.codegen.core.util.inputShape import software.amazon.smithy.rust.codegen.core.util.isTargetUnit import software.amazon.smithy.rust.codegen.core.util.orNull -abstract class QuerySerializerGenerator(codegenContext: CodegenContext) : StructuredDataSerializerGenerator { +abstract class QuerySerializerGenerator(private val codegenContext: CodegenContext) : StructuredDataSerializerGenerator { protected data class Context( /** Expression that yields a QueryValueWriter */ val writerExpression: String, /** Expression representing the value to write to the QueryValueWriter */ val valueExpression: ValueExpression, val shape: T, + val isOptional: Boolean = false, ) protected data class MemberContext( @@ -88,6 +90,7 @@ abstract class QuerySerializerGenerator(codegenContext: CodegenContext) : Struct protected val model = codegenContext.model protected val symbolProvider = codegenContext.symbolProvider protected val runtimeConfig = codegenContext.runtimeConfig + private val nullableIndex = NullableIndex(model) private val target = codegenContext.target private val serviceShape = codegenContext.serviceShape private val serializerError = runtimeConfig.serializationError() @@ -118,7 +121,7 @@ abstract class QuerySerializerGenerator(codegenContext: CodegenContext) : Struct } override fun unsetStructure(structure: StructureShape): RuntimeType { - TODO("AwsQuery doesn't support payload serialization") + TODO("$protocolName doesn't support payload serialization") } override fun unsetUnion(union: UnionShape): RuntimeType { @@ -179,7 +182,8 @@ abstract class QuerySerializerGenerator(codegenContext: CodegenContext) : Struct rust("Ok(())") } } - rust("#T(${context.writerExpression}, ${context.valueExpression.name})?;", structureSerializer) + + rust("#T(${context.writerExpression}, ${context.valueExpression.asRef()})?;", structureSerializer) } private fun RustWriter.serializeStructureInner(context: Context) { @@ -216,9 +220,11 @@ abstract class QuerySerializerGenerator(codegenContext: CodegenContext) : Struct val writer = context.writerExpression val value = context.valueExpression when (target) { - is StringShape -> when (target.hasTrait()) { - true -> rust("$writer.string(${value.name}.as_str());") - false -> rust("$writer.string(${value.name});") + is StringShape -> { + when (target.hasTrait()) { + true -> rust("$writer.string(${value.name}.as_str());") + false -> rust("$writer.string(${value.asRef()});") + } } is BooleanShape -> rust("$writer.boolean(${value.asValue()});") is NumberShape -> { @@ -234,13 +240,13 @@ abstract class QuerySerializerGenerator(codegenContext: CodegenContext) : Struct ) } is BlobShape -> rust( - "$writer.string(&#T(${value.name}));", + "$writer.string(&#T(${value.asRef()}));", RuntimeType.base64Encode(runtimeConfig), ) is TimestampShape -> { val timestampFormat = determineTimestampFormat(context.shape) val timestampFormatType = RuntimeType.serializeTimestampFormat(runtimeConfig, timestampFormat) - rust("$writer.date_time(${value.name}, #T)?;", timestampFormatType) + rust("$writer.date_time(${value.asRef()}, #T)?;", timestampFormatType) } is CollectionShape -> serializeCollection(context, Context(writer, context.valueExpression, target)) is MapShape -> serializeMap(context, Context(writer, context.valueExpression, target)) diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/DefaultBuilderInstantiator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/DefaultBuilderInstantiator.kt index 6147f84e2f..96af195c76 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/DefaultBuilderInstantiator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/DefaultBuilderInstantiator.kt @@ -10,18 +10,28 @@ import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.rust.codegen.core.rustlang.Writable import software.amazon.smithy.rust.codegen.core.rustlang.rust import software.amazon.smithy.rust.codegen.core.rustlang.writable +import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider +import software.amazon.smithy.rust.codegen.core.smithy.generators.BuilderGenerator import software.amazon.smithy.rust.codegen.core.smithy.generators.BuilderInstantiator /** * A Default instantiator that uses `builder.build()` in all cases. This exists to support tests in codegen-core * and to serve as the base behavior for client and server instantiators. */ -class DefaultBuilderInstantiator : BuilderInstantiator { +class DefaultBuilderInstantiator(private val checkFallibleBuilder: Boolean, private val symbolProvider: RustSymbolProvider) : BuilderInstantiator { override fun setField(builder: String, value: Writable, field: MemberShape): Writable { return setFieldWithSetter(builder, value, field) } override fun finalizeBuilder(builder: String, shape: StructureShape, mapErr: Writable?): Writable { - return writable { rust("builder.build()") } + return writable { + rust("builder.build()") + if (checkFallibleBuilder && BuilderGenerator.hasFallibleBuilder(shape, symbolProvider)) { + if (mapErr != null) { + rust(".map_err(#T)", mapErr) + } + rust("?") + } + } } } diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/EventStreamTestModels.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/EventStreamTestModels.kt index d6b43b97cb..e944a552a0 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/EventStreamTestModels.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/EventStreamTestModels.kt @@ -145,7 +145,7 @@ object EventStreamTestModels { validTestUnion = """{"Foo":"hello"}""", validSomeError = """{"Message":"some error"}""", validUnmodeledError = """{"Message":"unmodeled error"}""", - ) { AwsJson(it, AwsJsonVersion.Json11, builderInstantiator = DefaultBuilderInstantiator()) }, + ) { AwsJson(it, AwsJsonVersion.Json11) }, // // restXml diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/TestHelpers.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/TestHelpers.kt index ae0398ae92..87ebb679cc 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/TestHelpers.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/TestHelpers.kt @@ -105,10 +105,10 @@ private object CodegenCoreTestModules { } } -val TestRustSymbolProviderConfig = RustSymbolProviderConfig( +fun testRustSymbolProviderConfig(nullabilityCheckMode: NullableIndex.CheckMode) = RustSymbolProviderConfig( runtimeConfig = TestRuntimeConfig, renameExceptions = true, - nullabilityCheckMode = NullableIndex.CheckMode.CLIENT_ZERO_VALUE_V1, + nullabilityCheckMode = nullabilityCheckMode, moduleProvider = CodegenCoreTestModules.TestModuleProvider, ) @@ -147,12 +147,12 @@ fun String.asSmithyModel(sourceLocation: String? = null, smithyVersion: String = internal fun testSymbolProvider( model: Model, rustReservedWordConfig: RustReservedWordConfig? = null, - config: RustSymbolProviderConfig = TestRustSymbolProviderConfig, + nullabilityCheckMode: NullableIndex.CheckMode = NullableIndex.CheckMode.CLIENT, ): RustSymbolProvider = SymbolVisitor( testRustSettings(), model, ServiceShape.builder().version("test").id("test#Service").build(), - config, + testRustSymbolProviderConfig(nullabilityCheckMode), ).let { BaseSymbolMetadataProvider(it, additionalAttributes = listOf(Attribute.NonExhaustive)) } .let { RustReservedWordSymbolProvider( @@ -167,9 +167,10 @@ internal fun testCodegenContext( serviceShape: ServiceShape? = null, settings: CoreRustSettings = testRustSettings(), codegenTarget: CodegenTarget = CodegenTarget.CLIENT, + nullabilityCheckMode: NullableIndex.CheckMode = NullableIndex.CheckMode.CLIENT, ): CodegenContext = object : CodegenContext( model, - testSymbolProvider(model), + testSymbolProvider(model, nullabilityCheckMode = nullabilityCheckMode), TestModuleDocProvider, serviceShape ?: model.serviceShapes.firstOrNull() @@ -179,14 +180,13 @@ internal fun testCodegenContext( codegenTarget, ) { override fun builderInstantiator(): BuilderInstantiator { - return DefaultBuilderInstantiator() + return DefaultBuilderInstantiator(codegenTarget == CodegenTarget.CLIENT, symbolProvider) } } /** * In tests, we frequently need to generate a struct, a builder, and an impl block to access said builder. */ - fun StructureShape.renderWithModelBuilder( model: Model, symbolProvider: RustSymbolProvider, diff --git a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/RustReservedWordsTest.kt b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/RustReservedWordsTest.kt index 25e47e0963..ec97799af4 100644 --- a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/RustReservedWordsTest.kt +++ b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/RustReservedWordsTest.kt @@ -8,20 +8,21 @@ package software.amazon.smithy.rust.codegen.core.rustlang import io.kotest.matchers.shouldBe import org.junit.jupiter.api.Test import software.amazon.smithy.model.Model +import software.amazon.smithy.model.knowledge.NullableIndex import software.amazon.smithy.model.shapes.MemberShape import software.amazon.smithy.model.shapes.ShapeId import software.amazon.smithy.rust.codegen.core.smithy.MaybeRenamed import software.amazon.smithy.rust.codegen.core.smithy.SymbolVisitor import software.amazon.smithy.rust.codegen.core.smithy.WrappingSymbolProvider import software.amazon.smithy.rust.codegen.core.smithy.renamedFrom -import software.amazon.smithy.rust.codegen.core.testutil.TestRustSymbolProviderConfig import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel import software.amazon.smithy.rust.codegen.core.testutil.testRustSettings +import software.amazon.smithy.rust.codegen.core.testutil.testRustSymbolProviderConfig import software.amazon.smithy.rust.codegen.core.util.lookup internal class RustReservedWordSymbolProviderTest { - private class TestSymbolProvider(model: Model) : - WrappingSymbolProvider(SymbolVisitor(testRustSettings(), model, null, TestRustSymbolProviderConfig)) + private class TestSymbolProvider(model: Model, nullabilityCheckMode: NullableIndex.CheckMode) : + WrappingSymbolProvider(SymbolVisitor(testRustSettings(), model, null, testRustSymbolProviderConfig(nullabilityCheckMode))) private val emptyConfig = RustReservedWordConfig(emptyMap(), emptyMap(), emptyMap()) @Test @@ -30,13 +31,13 @@ internal class RustReservedWordSymbolProviderTest { namespace test structure Self {} """.asSmithyModel() - val provider = RustReservedWordSymbolProvider(TestSymbolProvider(model), emptyConfig) + val provider = RustReservedWordSymbolProvider(TestSymbolProvider(model, NullableIndex.CheckMode.CLIENT), emptyConfig) val symbol = provider.toSymbol(model.lookup("test#Self")) symbol.name shouldBe "SelfValue" } private fun mappingTest(config: RustReservedWordConfig, model: Model, id: String, test: (String) -> Unit) { - val provider = RustReservedWordSymbolProvider(TestSymbolProvider(model), config) + val provider = RustReservedWordSymbolProvider(TestSymbolProvider(model, NullableIndex.CheckMode.CLIENT), config) val symbol = provider.toMemberName(model.lookup("test#Container\$$id")) test(symbol) } @@ -132,7 +133,7 @@ internal class RustReservedWordSymbolProviderTest { async: String } """.asSmithyModel() - val provider = RustReservedWordSymbolProvider(TestSymbolProvider(model), emptyConfig) + val provider = RustReservedWordSymbolProvider(TestSymbolProvider(model, NullableIndex.CheckMode.CLIENT), emptyConfig) provider.toMemberName( MemberShape.builder().id("namespace#container\$async").target("namespace#Integer").build(), ) shouldBe "r##async" @@ -149,7 +150,7 @@ internal class RustReservedWordSymbolProviderTest { @enum([{ name: "dontcare", value: "dontcare" }]) string Container """.asSmithyModel() val provider = RustReservedWordSymbolProvider( - TestSymbolProvider(model), + TestSymbolProvider(model, NullableIndex.CheckMode.CLIENT), reservedWordConfig = emptyConfig.copy( enumMemberMap = mapOf( "Unknown" to "UnknownValue", diff --git a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/BuilderGeneratorTest.kt b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/BuilderGeneratorTest.kt index 3e1886a195..06888429e6 100644 --- a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/BuilderGeneratorTest.kt +++ b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/BuilderGeneratorTest.kt @@ -17,7 +17,6 @@ import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.core.smithy.Default import software.amazon.smithy.rust.codegen.core.smithy.WrappingSymbolProvider import software.amazon.smithy.rust.codegen.core.smithy.setDefault -import software.amazon.smithy.rust.codegen.core.testutil.TestRustSymbolProviderConfig import software.amazon.smithy.rust.codegen.core.testutil.TestWorkspace import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel import software.amazon.smithy.rust.codegen.core.testutil.compileAndTest @@ -184,7 +183,7 @@ internal class BuilderGeneratorTest { val provider = testSymbolProvider( model, rustReservedWordConfig = StructureGeneratorTest.rustReservedWordConfig, - config = TestRustSymbolProviderConfig.copy(nullabilityCheckMode = NullableIndex.CheckMode.CLIENT_CAREFUL), + nullabilityCheckMode = NullableIndex.CheckMode.CLIENT_CAREFUL, ) val project = TestWorkspace.testProject(provider) val shape: StructureShape = model.lookup("com.test#MyStruct") diff --git a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/StructureGeneratorTest.kt b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/StructureGeneratorTest.kt index 77932ddfac..af7ff639b5 100644 --- a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/StructureGeneratorTest.kt +++ b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/StructureGeneratorTest.kt @@ -84,7 +84,6 @@ class StructureGeneratorTest { val credentials = model.lookup("com.test#Credentials") val secretStructure = model.lookup("com.test#SecretStructure") val structWithInnerSecretStructure = model.lookup("com.test#StructWithInnerSecretStructure") - val error = model.lookup("com.test#MyError") val rustReservedWordConfig: RustReservedWordConfig = RustReservedWordConfig( structureMemberMap = StructureGenerator.structureMemberNameMap, diff --git a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/JsonParserGeneratorTest.kt b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/JsonParserGeneratorTest.kt index 79435b5e9b..108207473f 100644 --- a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/JsonParserGeneratorTest.kt +++ b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/JsonParserGeneratorTest.kt @@ -149,7 +149,7 @@ class JsonParserGeneratorTest { let top = output.top.expect("top"); assert_eq!(Some(45), top.extra); assert_eq!(Some("something".to_string()), top.field); - assert_eq!(Some(Choice::Int(5)), top.choice); + assert_eq!(Choice::Int(5), top.choice); """, ) unitTest( @@ -166,7 +166,7 @@ class JsonParserGeneratorTest { // unknown variant let input = br#"{ "top": { "choice": { "somenewvariant": "data" } } }"#; let output = ${format(operationGenerator)}(input, test_output::OpOutput::builder()).unwrap().build(); - assert!(output.top.unwrap().choice.unwrap().is_unknown()); + assert!(output.top.unwrap().choice.is_unknown()); """, ) diff --git a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/XmlBindingTraitParserGeneratorTest.kt b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/XmlBindingTraitParserGeneratorTest.kt index 47f310e83d..0d78af182e 100644 --- a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/XmlBindingTraitParserGeneratorTest.kt +++ b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/XmlBindingTraitParserGeneratorTest.kt @@ -81,7 +81,7 @@ internal class XmlBindingTraitParserGeneratorTest { extra: Long, @xmlName("prefix:local") - renamedWithPrefix: String + renamedWithPrefix: String, } @http(uri: "/top", method: "POST") diff --git a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/AwsQuerySerializerGeneratorTest.kt b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/AwsQuerySerializerGeneratorTest.kt index ad44554e36..2203b05dc1 100644 --- a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/AwsQuerySerializerGeneratorTest.kt +++ b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/AwsQuerySerializerGeneratorTest.kt @@ -7,10 +7,12 @@ package software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize import org.junit.jupiter.params.ParameterizedTest import org.junit.jupiter.params.provider.CsvSource +import software.amazon.smithy.model.knowledge.NullableIndex import software.amazon.smithy.model.shapes.OperationShape import software.amazon.smithy.model.shapes.StringShape import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget +import software.amazon.smithy.rust.codegen.core.smithy.generators.BuilderGenerator.Companion.hasFallibleBuilder import software.amazon.smithy.rust.codegen.core.smithy.generators.EnumGenerator import software.amazon.smithy.rust.codegen.core.smithy.generators.TestEnumType import software.amazon.smithy.rust.codegen.core.smithy.generators.UnionGenerator @@ -95,7 +97,7 @@ class AwsQuerySerializerGeneratorTest { val model = RecursiveShapeBoxer().transform(OperationNormalizer.transform(baseModel)) val codegenContext = testCodegenContext(model, codegenTarget = codegenTarget) val symbolProvider = codegenContext.symbolProvider - val parserGenerator = AwsQuerySerializerGenerator(testCodegenContext(model, codegenTarget = codegenTarget)) + val parserGenerator = AwsQuerySerializerGenerator(codegenContext) val operationGenerator = parserGenerator.operationInputSerializer(model.lookup("test#Op")) val project = TestWorkspace.testProject(symbolProvider) @@ -152,4 +154,162 @@ class AwsQuerySerializerGeneratorTest { } project.compileAndTest() } + + private val baseModelWithRequiredTypes = """ + namespace test + use aws.protocols#restJson1 + + union Choice { + blob: Blob, + boolean: Boolean, + date: Timestamp, + enum: FooEnum, + int: Integer, + @xmlFlattened + list: SomeList, + long: Long, + map: MyMap, + number: Double, + s: String, + top: Top, + unit: Unit, + } + + @enum([{name: "FOO", value: "FOO"}]) + string FooEnum + + map MyMap { + key: String, + value: Choice, + } + + list SomeList { + member: Choice + } + + structure Top { + @required + choice: Choice, + @required + field: String, + @required + extra: Long, + @xmlName("rec") + recursive: TopList + } + + list TopList { + @xmlName("item") + member: Top + } + + structure OpInput { + @required + @xmlName("some_bool") + boolean: Boolean, + list: SomeList, + map: MyMap, + @required + top: Top, + @required + blob: Blob + } + + @http(uri: "/top", method: "POST") + operation Op { + input: OpInput, + } + """.asSmithyModel() + + @ParameterizedTest + @CsvSource( + "true, CLIENT", + "true, CLIENT_CAREFUL", + "true, CLIENT_ZERO_VALUE_V1", + "true, CLIENT_ZERO_VALUE_V1_NO_INPUT", + "false, SERVER", + ) + fun `generates valid serializers for required types`(generateUnknownVariant: Boolean, nullabilityCheckMode: NullableIndex.CheckMode) { + val codegenTarget = when (generateUnknownVariant) { + true -> CodegenTarget.CLIENT + false -> CodegenTarget.SERVER + } + val model = RecursiveShapeBoxer().transform(OperationNormalizer.transform(baseModelWithRequiredTypes)) + val codegenContext = testCodegenContext(model, codegenTarget = codegenTarget, nullabilityCheckMode = nullabilityCheckMode) + val symbolProvider = codegenContext.symbolProvider + val parserGenerator = AwsQuerySerializerGenerator(codegenContext) + val operationGenerator = parserGenerator.operationInputSerializer(model.lookup("test#Op")) + + val project = TestWorkspace.testProject(symbolProvider) + + // Depending on the nullability check mode, the builder can be fallible or not. When it's fallible, we need to + // add unwrap calls. + val builderIsFallible = hasFallibleBuilder(model.lookup("test#Top"), symbolProvider) + val maybeUnwrap = if (builderIsFallible) { ".unwrap()" } else { "" } + project.lib { + unitTest( + "query_serializer", + """ + use test_model::{Choice, Top}; + + let input = crate::test_input::OpInput::builder() + .top( + Top::builder() + .field("Hello") + .choice(Choice::Boolean(true)) + .extra(45) + .recursive( + Top::builder() + .field("World!") + .choice(Choice::Boolean(true)) + .extra(55) + .build() + $maybeUnwrap + ) + .build() + $maybeUnwrap + ) + .boolean(true) + .blob(aws_smithy_types::Blob::new(&b"test"[..])) + .build() + .unwrap(); + let serialized = ${format(operationGenerator!!)}(&input).unwrap(); + let output = std::str::from_utf8(serialized.bytes().unwrap()).unwrap(); + assert_eq!( + output, + "\ + Action=Op\ + &Version=test\ + &some_bool=true\ + &top.choice.choice=true\ + &top.field=Hello\ + &top.extra=45\ + &top.rec.item.1.choice.choice=true\ + &top.rec.item.1.field=World%21\ + &top.rec.item.1.extra=55\ + &blob=dGVzdA%3D%3D" + ); + """, + ) + } + model.lookup("test#Top").also { top -> + top.renderWithModelBuilder(model, symbolProvider, project) + project.moduleFor(top) { + UnionGenerator( + model, + symbolProvider, + this, + model.lookup("test#Choice"), + renderUnknownVariant = generateUnknownVariant, + ).render() + val enum = model.lookup("test#FooEnum") + EnumGenerator(model, symbolProvider, enum, TestEnumType).render(this) + } + } + + model.lookup("test#Op").inputShape(model).also { input -> + input.renderWithModelBuilder(model, symbolProvider, project) + } + project.compileAndTest() + } } diff --git a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/Ec2QuerySerializerGeneratorTest.kt b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/Ec2QuerySerializerGeneratorTest.kt index 2436aff706..4b5f490e13 100644 --- a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/Ec2QuerySerializerGeneratorTest.kt +++ b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/Ec2QuerySerializerGeneratorTest.kt @@ -5,10 +5,13 @@ package software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize -import org.junit.jupiter.api.Test +import org.junit.jupiter.params.ParameterizedTest +import org.junit.jupiter.params.provider.CsvSource +import software.amazon.smithy.model.knowledge.NullableIndex import software.amazon.smithy.model.shapes.OperationShape import software.amazon.smithy.model.shapes.StringShape import software.amazon.smithy.model.shapes.StructureShape +import software.amazon.smithy.rust.codegen.core.smithy.generators.BuilderGenerator import software.amazon.smithy.rust.codegen.core.smithy.generators.EnumGenerator import software.amazon.smithy.rust.codegen.core.smithy.generators.TestEnumType import software.amazon.smithy.rust.codegen.core.smithy.generators.UnionGenerator @@ -19,6 +22,7 @@ import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel import software.amazon.smithy.rust.codegen.core.testutil.compileAndTest import software.amazon.smithy.rust.codegen.core.testutil.renderWithModelBuilder import software.amazon.smithy.rust.codegen.core.testutil.testCodegenContext +import software.amazon.smithy.rust.codegen.core.testutil.testRustSettings import software.amazon.smithy.rust.codegen.core.testutil.testSymbolProvider import software.amazon.smithy.rust.codegen.core.testutil.unitTest import software.amazon.smithy.rust.codegen.core.util.inputShape @@ -83,10 +87,18 @@ class Ec2QuerySerializerGeneratorTest { } """.asSmithyModel() - @Test - fun `generates valid serializers`() { + @ParameterizedTest + @CsvSource( + "CLIENT", + "CLIENT_CAREFUL", + "CLIENT_ZERO_VALUE_V1", + "CLIENT_ZERO_VALUE_V1_NO_INPUT", + "SERVER", + ) + fun `generates valid serializers`(nullabilityCheckMode: NullableIndex.CheckMode) { val model = RecursiveShapeBoxer().transform(OperationNormalizer.transform(baseModel)) - val codegenContext = testCodegenContext(model) + val settings = testRustSettings() + val codegenContext = testCodegenContext(model, settings = settings, nullabilityCheckMode = nullabilityCheckMode) val symbolProvider = codegenContext.symbolProvider val parserGenerator = Ec2QuerySerializerGenerator(codegenContext) val operationGenerator = parserGenerator.operationInputSerializer(model.lookup("test#Op")) @@ -139,4 +151,153 @@ class Ec2QuerySerializerGeneratorTest { } project.compileAndTest() } + + private val baseModelWithRequiredTypes = """ + namespace test + + union Choice { + blob: Blob, + boolean: Boolean, + date: Timestamp, + enum: FooEnum, + int: Integer, + @xmlFlattened + list: SomeList, + long: Long, + map: MyMap, + number: Double, + s: String, + top: Top, + unit: Unit, + } + + @enum([{name: "FOO", value: "FOO"}]) + string FooEnum + + map MyMap { + key: String, + value: Choice, + } + + list SomeList { + member: Choice + } + + structure Top { + @required + choice: Choice, + @required + field: String, + @required + extra: Long, + @xmlName("rec") + recursive: TopList + } + + list TopList { + @xmlName("item") + member: Top + } + + structure OpInput { + @required + @xmlName("some_bool") + boolean: Boolean, + list: SomeList, + map: MyMap, + @required + top: Top, + @required + blob: Blob + } + + @http(uri: "/top", method: "POST") + operation Op { + input: OpInput, + } + """.asSmithyModel() + + @ParameterizedTest + @CsvSource( + "CLIENT", + "CLIENT_CAREFUL", + "CLIENT_ZERO_VALUE_V1", + "CLIENT_ZERO_VALUE_V1_NO_INPUT", + "SERVER", + ) + fun `generates valid serializers for required types`(nullabilityCheckMode: NullableIndex.CheckMode) { + val model = RecursiveShapeBoxer().transform(OperationNormalizer.transform(baseModelWithRequiredTypes)) + val settings = testRustSettings() + val codegenContext = testCodegenContext(model, settings = settings, nullabilityCheckMode = nullabilityCheckMode) + val symbolProvider = codegenContext.symbolProvider + val parserGenerator = Ec2QuerySerializerGenerator(codegenContext) + val operationGenerator = parserGenerator.operationInputSerializer(model.lookup("test#Op")) + + val project = TestWorkspace.testProject(testSymbolProvider(model)) + + // Depending on the nullability check mode, the builder can be fallible or not. When it's fallible, we need to + // add unwrap calls. + val builderIsFallible = + BuilderGenerator.hasFallibleBuilder(model.lookup("test#Top"), symbolProvider) + val maybeUnwrap = if (builderIsFallible) { ".unwrap()" } else { "" } + project.lib { + unitTest( + "ec2query_serializer", + """ + use test_model::{Choice, Top}; + + let input = crate::test_input::OpInput::builder() + .top( + Top::builder() + .field("Hello") + .choice(Choice::Boolean(true)) + .extra(45) + .recursive( + Top::builder() + .field("World!") + .choice(Choice::Boolean(true)) + .extra(55) + .build() + $maybeUnwrap + ) + .build() + $maybeUnwrap + ) + .boolean(true) + .blob(aws_smithy_types::Blob::new(&b"test"[..])) + .build() + .unwrap(); + let serialized = ${format(operationGenerator!!)}(&input).unwrap(); + let output = std::str::from_utf8(serialized.bytes().unwrap()).unwrap(); + assert_eq!( + output, + "\ + Action=Op\ + &Version=test\ + &Some_bool=true\ + &Top.Choice.Choice=true\ + &Top.Field=Hello\ + &Top.Extra=45\ + &Top.Rec.1.Choice.Choice=true\ + &Top.Rec.1.Field=World%21\ + &Top.Rec.1.Extra=55\ + &Blob=dGVzdA%3D%3D" + ); + """, + ) + } + model.lookup("test#Top").also { top -> + top.renderWithModelBuilder(model, symbolProvider, project) + project.moduleFor(top) { + UnionGenerator(model, symbolProvider, this, model.lookup("test#Choice")).render() + val enum = model.lookup("test#FooEnum") + EnumGenerator(model, symbolProvider, enum, TestEnumType).render(this) + } + } + + model.lookup("test#Op").inputShape(model).also { input -> + input.renderWithModelBuilder(model, symbolProvider, project) + } + project.compileAndTest() + } } diff --git a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/JsonSerializerGeneratorTest.kt b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/JsonSerializerGeneratorTest.kt index 23c27f331b..61140ecd42 100644 --- a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/JsonSerializerGeneratorTest.kt +++ b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/JsonSerializerGeneratorTest.kt @@ -5,10 +5,13 @@ package software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize -import org.junit.jupiter.api.Test +import org.junit.jupiter.params.ParameterizedTest +import org.junit.jupiter.params.provider.CsvSource +import software.amazon.smithy.model.knowledge.NullableIndex import software.amazon.smithy.model.shapes.OperationShape import software.amazon.smithy.model.shapes.StringShape import software.amazon.smithy.model.shapes.StructureShape +import software.amazon.smithy.rust.codegen.core.smithy.generators.BuilderGenerator import software.amazon.smithy.rust.codegen.core.smithy.generators.EnumGenerator import software.amazon.smithy.rust.codegen.core.smithy.generators.TestEnumType import software.amazon.smithy.rust.codegen.core.smithy.generators.UnionGenerator @@ -85,6 +88,7 @@ class JsonSerializerGeneratorTest { member: Top } + @input structure OpInput { @httpHeader("x-test") someHeader: String, @@ -98,10 +102,17 @@ class JsonSerializerGeneratorTest { } """.asSmithyModel() - @Test - fun `generates valid serializers`() { + @ParameterizedTest + @CsvSource( + "CLIENT", + "CLIENT_CAREFUL", + "CLIENT_ZERO_VALUE_V1", + "CLIENT_ZERO_VALUE_V1_NO_INPUT", + "SERVER", + ) + fun `generates valid serializers`(nullabilityCheckMode: NullableIndex.CheckMode) { val model = RecursiveShapeBoxer().transform(OperationNormalizer.transform(baseModel)) - val codegenContext = testCodegenContext(model) + val codegenContext = testCodegenContext(model, nullabilityCheckMode = nullabilityCheckMode) val symbolProvider = codegenContext.symbolProvider val parserSerializer = JsonSerializerGenerator( codegenContext, @@ -111,7 +122,7 @@ class JsonSerializerGeneratorTest { val operationGenerator = parserSerializer.operationInputSerializer(model.lookup("test#Op")) val documentGenerator = parserSerializer.documentSerializer() - val project = TestWorkspace.testProject(testSymbolProvider(model)) + val project = TestWorkspace.testProject(symbolProvider) project.lib { unitTest( "json_serializers", @@ -137,7 +148,174 @@ class JsonSerializerGeneratorTest { .choice(Choice::Unknown) .build() ).build().unwrap(); - let serialized = ${format(operationGenerator)}(&input).expect_err("cannot serialize unknown variant"); + ${format(operationGenerator)}(&input).expect_err("cannot serialize unknown variant"); + """, + ) + } + model.lookup("test#Top").also { top -> + top.renderWithModelBuilder(model, symbolProvider, project) + project.moduleFor(top) { + UnionGenerator(model, symbolProvider, this, model.lookup("test#Choice")).render() + val enum = model.lookup("test#FooEnum") + EnumGenerator(model, symbolProvider, enum, TestEnumType).render(this) + } + } + + model.lookup("test#Op").inputShape(model).also { input -> + input.renderWithModelBuilder(model, symbolProvider, project) + } + project.compileAndTest() + } + + private val baseModelWithRequiredTypes = """ + namespace test + use aws.protocols#restJson1 + + union Choice { + blob: Blob, + boolean: Boolean, + date: Timestamp, + document: Document, + enum: FooEnum, + int: Integer, + list: SomeList, + listSparse: SomeSparseList, + long: Long, + map: MyMap, + mapSparse: MySparseMap, + number: Double, + s: String, + top: Top, + unit: Unit, + } + + @enum([{name: "FOO", value: "FOO"}]) + string FooEnum + + map MyMap { + key: String, + value: Choice, + } + + @sparse + map MySparseMap { + key: String, + value: Choice, + } + + list SomeList { + member: Choice + } + + @sparse + list SomeSparseList { + member: Choice + } + + structure Top { + @required + choice: Choice, + @required + field: String, + @required + extra: Long, + @jsonName("rec") + recursive: TopList + } + + list TopList { + member: Top + } + + @input + structure OpInput { + @httpHeader("x-test") + someHeader: String, + + @required + boolean: Boolean, + list: SomeList, + map: MyMap, + + @required + top: Top + } + + @http(uri: "/top", method: "POST") + operation Op { + input: OpInput, + } + """.asSmithyModel() + + @ParameterizedTest + @CsvSource( + "CLIENT", + "CLIENT_CAREFUL", + "CLIENT_ZERO_VALUE_V1", + "CLIENT_ZERO_VALUE_V1_NO_INPUT", + "SERVER", + ) + fun `generates valid serializers for required types`(nullabilityCheckMode: NullableIndex.CheckMode) { + val model = RecursiveShapeBoxer().transform(OperationNormalizer.transform(baseModelWithRequiredTypes)) + val codegenContext = testCodegenContext(model, nullabilityCheckMode = nullabilityCheckMode) + val symbolProvider = codegenContext.symbolProvider + val parserSerializer = JsonSerializerGenerator( + codegenContext, + HttpTraitHttpBindingResolver(model, ProtocolContentTypes.consistent("application/json")), + ::restJsonFieldName, + ) + val operationGenerator = parserSerializer.operationInputSerializer(model.lookup("test#Op")) + val documentGenerator = parserSerializer.documentSerializer() + + val project = TestWorkspace.testProject(testSymbolProvider(model)) + + // Depending on the nullability check mode, the builder can be fallible or not. When it's fallible, we need to + // add unwrap calls. + val builderIsFallible = + BuilderGenerator.hasFallibleBuilder(model.lookup("test#Top"), symbolProvider) + val maybeUnwrap = if (builderIsFallible) { ".unwrap()" } else { "" } + project.lib { + unitTest( + "json_serializers", + """ + use test_model::{Choice, Top}; + + // Generate the document serializer even though it's not tested directly + // ${format(documentGenerator)} + + let input = crate::test_input::OpInput::builder() + .top( + Top::builder() + .field("Hello") + .choice(Choice::Boolean(true)) + .extra(45) + .recursive( + Top::builder() + .field("World!") + .choice(Choice::Boolean(true)) + .extra(55) + .build() + $maybeUnwrap + ) + .build() + $maybeUnwrap + ) + .boolean(true) + .build() + .unwrap(); + let serialized = ${format(operationGenerator!!)}(&input).unwrap(); + let output = std::str::from_utf8(serialized.bytes().unwrap()).unwrap(); + assert_eq!(output, r#"{"boolean":true,"top":{"choice":{"boolean":true},"field":"Hello","extra":45,"rec":[{"choice":{"boolean":true},"field":"World!","extra":55}]}}"#); + + let input = crate::test_input::OpInput::builder().top( + Top::builder() + .field("Hello") + .choice(Choice::Unknown) + .extra(45) + .build() + $maybeUnwrap + ).boolean(false).build().unwrap(); + ${format(operationGenerator)}(&input).expect_err("cannot serialize unknown variant"); """, ) } diff --git a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/XmlBindingTraitSerializerGeneratorTest.kt b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/XmlBindingTraitSerializerGeneratorTest.kt index a695d2a401..4410c33910 100644 --- a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/XmlBindingTraitSerializerGeneratorTest.kt +++ b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/XmlBindingTraitSerializerGeneratorTest.kt @@ -5,13 +5,18 @@ package software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize -import org.junit.jupiter.api.Test +import org.junit.jupiter.params.ParameterizedTest +import org.junit.jupiter.params.provider.CsvSource +import software.amazon.smithy.model.knowledge.NullableIndex +import software.amazon.smithy.model.shapes.MemberShape import software.amazon.smithy.model.shapes.OperationShape import software.amazon.smithy.model.shapes.StringShape import software.amazon.smithy.model.shapes.StructureShape +import software.amazon.smithy.rust.codegen.core.smithy.generators.BuilderGenerator import software.amazon.smithy.rust.codegen.core.smithy.generators.EnumGenerator import software.amazon.smithy.rust.codegen.core.smithy.generators.TestEnumType import software.amazon.smithy.rust.codegen.core.smithy.generators.UnionGenerator +import software.amazon.smithy.rust.codegen.core.smithy.isOptional import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpTraitHttpBindingResolver import software.amazon.smithy.rust.codegen.core.smithy.protocols.ProtocolContentTypes import software.amazon.smithy.rust.codegen.core.smithy.transformers.OperationNormalizer @@ -103,10 +108,17 @@ internal class XmlBindingTraitSerializerGeneratorTest { } """.asSmithyModel() - @Test - fun `generates valid serializers`() { + @ParameterizedTest + @CsvSource( + "CLIENT", + "CLIENT_CAREFUL", + "CLIENT_ZERO_VALUE_V1", + "CLIENT_ZERO_VALUE_V1_NO_INPUT", + "SERVER", + ) + fun `generates valid serializers`(nullabilityCheckMode: NullableIndex.CheckMode) { val model = RecursiveShapeBoxer().transform(OperationNormalizer.transform(baseModel)) - val codegenContext = testCodegenContext(model) + val codegenContext = testCodegenContext(model, nullabilityCheckMode = nullabilityCheckMode) val symbolProvider = codegenContext.symbolProvider val parserGenerator = XmlBindingTraitSerializerGenerator( codegenContext, @@ -120,14 +132,14 @@ internal class XmlBindingTraitSerializerGeneratorTest { "serialize_xml", """ use test_model::Top; - let inp = crate::test_input::OpInput::builder().payload( + let input = crate::test_input::OpInput::builder().payload( Top::builder() .field("hello!") .extra(45) .recursive(Top::builder().extra(55).build()) .build() ).build().unwrap(); - let serialized = ${format(operationSerializer)}(&inp.payload.unwrap()).unwrap(); + let serialized = ${format(operationSerializer)}(&input.payload.unwrap()).unwrap(); let output = std::str::from_utf8(&serialized).unwrap(); assert_eq!(output, "hello!"); """, @@ -158,4 +170,160 @@ internal class XmlBindingTraitSerializerGeneratorTest { } project.compileAndTest() } + + private val baseModelWithRequiredTypes = """ + namespace test + use aws.protocols#restXml + union Choice { + boolean: Boolean, + @xmlFlattened + @xmlName("Hi") + flatMap: MyMap, + deepMap: MyMap, + @xmlFlattened + flatList: SomeList, + deepList: SomeList, + s: String, + enum: FooEnum, + date: Timestamp, + number: Double, + top: Top, + blob: Blob, + unit: Unit, + } + + @enum([{name: "FOO", value: "FOO"}]) + string FooEnum + + map MyMap { + @xmlName("Name") + key: String, + @xmlName("Setting") + value: Choice, + } + + list SomeList { + member: Choice + } + + + structure Top { + @required + choice: Choice, + @required + field: String, + @required + @xmlAttribute + extra: Long, + @xmlName("prefix:local") + renamedWithPrefix: String, + @xmlName("rec") + @xmlFlattened + recursive: TopList + } + + list TopList { + member: Top + } + + structure OpInput { + @required + @httpPayload + payload: Top + } + + @http(uri: "/top", method: "POST") + operation Op { + input: OpInput, + } + """.asSmithyModel() + + @ParameterizedTest + @CsvSource( + "CLIENT", + "CLIENT_CAREFUL", + "CLIENT_ZERO_VALUE_V1", + "CLIENT_ZERO_VALUE_V1_NO_INPUT", + "SERVER", + ) + fun `generates valid serializers for required types`(nullabilityCheckMode: NullableIndex.CheckMode) { + val model = RecursiveShapeBoxer().transform(OperationNormalizer.transform(baseModelWithRequiredTypes)) + val codegenContext = testCodegenContext(model, nullabilityCheckMode = nullabilityCheckMode) + val symbolProvider = codegenContext.symbolProvider + val parserGenerator = XmlBindingTraitSerializerGenerator( + codegenContext, + HttpTraitHttpBindingResolver(model, ProtocolContentTypes.consistent("application/xml")), + ) + val operationSerializer = parserGenerator.payloadSerializer(model.lookup("test#OpInput\$payload")) + + val project = TestWorkspace.testProject(symbolProvider) + + // Depending on the nullability check mode, the builder can be fallible or not. When it's fallible, we need to + // add unwrap calls. + val builderIsFallible = + BuilderGenerator.hasFallibleBuilder(model.lookup("test#Top"), symbolProvider) + val maybeUnwrap = if (builderIsFallible) { ".unwrap()" } else { "" } + val payloadIsOptional = model.lookup("test#OpInput\$payload").let { + symbolProvider.toSymbol(it).isOptional() + } + val maybeUnwrapPayload = if (payloadIsOptional) { ".unwrap()" } else { "" } + project.lib { + unitTest( + "serialize_xml", + """ + use test_model::{Choice, Top}; + + let input = crate::test_input::OpInput::builder() + .payload( + Top::builder() + .field("Hello") + .choice(Choice::Boolean(true)) + .extra(45) + .recursive( + Top::builder() + .field("World!") + .choice(Choice::Boolean(true)) + .extra(55) + .build() + $maybeUnwrap + ) + .build() + $maybeUnwrap + ) + .build() + .unwrap(); + let serialized = ${format(operationSerializer)}(&input.payload$maybeUnwrapPayload).unwrap(); + let output = std::str::from_utf8(&serialized).unwrap(); + assert_eq!(output, "trueHellotrueWorld!"); + """, + ) + unitTest( + "unknown_variants", + """ + use test_model::{Choice, Top}; + let input = crate::test_input::OpInput::builder().payload( + Top::builder() + .field("Hello") + .choice(Choice::Unknown) + .extra(45) + .build() + $maybeUnwrap + ).build().unwrap(); + ${format(operationSerializer)}(&input.payload$maybeUnwrapPayload).expect_err("cannot serialize unknown variant"); + """, + ) + } + model.lookup("test#Top").also { top -> + top.renderWithModelBuilder(model, symbolProvider, project) + project.moduleFor(top) { + UnionGenerator(model, symbolProvider, this, model.lookup("test#Choice")).render() + val enum = model.lookup("test#FooEnum") + EnumGenerator(model, symbolProvider, enum, TestEnumType).render(this) + } + } + model.lookup("test#Op").inputShape(model).also { input -> + input.renderWithModelBuilder(model, symbolProvider, project) + } + project.compileAndTest() + } } diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocol.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocol.kt index 622123cd4a..f86c547da1 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocol.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocol.kt @@ -122,7 +122,7 @@ class ServerAwsJsonProtocol( private val serverCodegenContext: ServerCodegenContext, awsJsonVersion: AwsJsonVersion, private val additionalParserCustomizations: List = listOf(), -) : AwsJson(serverCodegenContext, awsJsonVersion, serverCodegenContext.builderInstantiator()), ServerProtocol { +) : AwsJson(serverCodegenContext, awsJsonVersion), ServerProtocol { private val runtimeConfig = codegenContext.runtimeConfig override val protocolModulePath: String diff --git a/examples/Cargo.toml b/examples/Cargo.toml index 33c374bbcb..453eee98e4 100644 --- a/examples/Cargo.toml +++ b/examples/Cargo.toml @@ -6,7 +6,8 @@ members = [ "pokemon-service-tls", "pokemon-service-lambda", "pokemon-service-server-sdk", - "pokemon-service-client" + "pokemon-service-client", + ] [profile.release] diff --git a/examples/pokemon-service/tests/simple.rs b/examples/pokemon-service/tests/simple.rs index 3a055da587..dc9f88d0b5 100644 --- a/examples/pokemon-service/tests/simple.rs +++ b/examples/pokemon-service/tests/simple.rs @@ -19,7 +19,7 @@ async fn simple_integration_test() { let client = common::client(); let service_statistics_out = client.get_server_statistics().send().await.unwrap(); - assert_eq!(0, service_statistics_out.calls_count.unwrap()); + assert_eq!(0, service_statistics_out.calls_count); let pokemon_species_output = client .get_pokemon_species() @@ -27,10 +27,10 @@ async fn simple_integration_test() { .send() .await .unwrap(); - assert_eq!("pikachu", pokemon_species_output.name().unwrap()); + assert_eq!("pikachu", pokemon_species_output.name()); let service_statistics_out = client.get_server_statistics().send().await.unwrap(); - assert_eq!(1, service_statistics_out.calls_count.unwrap()); + assert_eq!(1, service_statistics_out.calls_count); let storage_err = client .get_storage() @@ -56,12 +56,12 @@ async fn simple_integration_test() { .await .unwrap(); assert_eq!( - Some(vec![ + vec![ "bulbasaur".to_string(), "charmander".to_string(), "squirtle".to_string(), "pikachu".to_string() - ]), + ], storage_out.collection ); @@ -80,7 +80,7 @@ async fn simple_integration_test() { ); let service_statistics_out = client.get_server_statistics().send().await.unwrap(); - assert_eq!(2, service_statistics_out.calls_count.unwrap()); + assert_eq!(2, service_statistics_out.calls_count); let hyper_client = hyper::Client::new(); let health_check_url = format!("{}/ping", common::base_url()); diff --git a/examples/python/pokemon-service-test/tests/simple_integration_test.rs b/examples/python/pokemon-service-test/tests/simple_integration_test.rs index 39f979e9a3..0c928fa0d7 100644 --- a/examples/python/pokemon-service-test/tests/simple_integration_test.rs +++ b/examples/python/pokemon-service-test/tests/simple_integration_test.rs @@ -36,7 +36,7 @@ async fn simple_integration_test_http2() { async fn simple_integration_test_with_client(client: PokemonClient) { let service_statistics_out = client.get_server_statistics().send().await.unwrap(); - assert_eq!(0, service_statistics_out.calls_count.unwrap()); + assert_eq!(0, service_statistics_out.calls_count); let pokemon_species_output = client .get_pokemon_species() @@ -44,10 +44,10 @@ async fn simple_integration_test_with_client(client: PokemonClient) { .send() .await .unwrap(); - assert_eq!("pikachu", pokemon_species_output.name().unwrap()); + assert_eq!("pikachu", pokemon_species_output.name()); let service_statistics_out = client.get_server_statistics().send().await.unwrap(); - assert_eq!(1, service_statistics_out.calls_count.unwrap()); + assert_eq!(1, service_statistics_out.calls_count); let pokemon_species_error = client .get_pokemon_species() @@ -64,7 +64,7 @@ async fn simple_integration_test_with_client(client: PokemonClient) { ); let service_statistics_out = client.get_server_statistics().send().await.unwrap(); - assert_eq!(2, service_statistics_out.calls_count.unwrap()); + assert_eq!(2, service_statistics_out.calls_count); let _health_check = client.check_health().send().await.unwrap(); } diff --git a/rust-runtime/aws-smithy-types/src/blob.rs b/rust-runtime/aws-smithy-types/src/blob.rs index 5365b91249..ab7ae30d3c 100644 --- a/rust-runtime/aws-smithy-types/src/blob.rs +++ b/rust-runtime/aws-smithy-types/src/blob.rs @@ -43,7 +43,7 @@ mod serde_serialize { S: serde::Serializer, { if serializer.is_human_readable() { - serializer.serialize_str(&crate::base64::encode(&self.inner)) + serializer.serialize_str(&base64::encode(&self.inner)) } else { serializer.serialize_bytes(&self.inner) }