From df77d5ff33f143abaea3f8bf5a9cdeec05353ece Mon Sep 17 00:00:00 2001 From: Fahad Zubair Date: Thu, 14 Nov 2024 11:18:23 +0000 Subject: [PATCH] Enforce constraints for unnamed enums (#3884) ### Enforces Constraints for Unnamed Enums This PR addresses the issue where, on the server side, unnamed enums were incorrectly treated as infallible during deserialization, allowing any string value to be converted without validation. The solution introduces a `ConstraintViolation` and `TryFrom` implementation for unnamed enums, ensuring that deserialized values conform to the enum variants defined in the Smithy model. The following is an example of an unnamed enum: ```smithy @enum([ { value: "MONDAY" }, { value: "TUESDAY" } ]) string UnnamedDayOfWeek ``` On the server side the following type is generated for the Smithy shape: ```rust pub struct UnnamedDayOfWeek(String); impl ::std::convert::TryFrom<::std::string::String> for UnnamedDayOfWeek { type Error = crate::model::unnamed_day_of_week::ConstraintViolation; fn try_from( s: ::std::string::String, ) -> ::std::result::Result>::Error> { match s.as_str() { "MONDAY" | "TUESDAY" => Ok(Self(s)), _ => Err(crate::model::unnamed_day_of_week::ConstraintViolation(s)), } } } ``` This change prevents invalid values from being deserialized into unnamed enums and raises appropriate constraint violations when necessary. There is one difference between the Rust code generated for `TryFrom` for named enums versus unnamed enums. The implementation for unnamed enums passes the ownership of the `String` parameter to the generated structure, and the implementation for `TryFrom<&str>` delegates to `TryFrom`. ```rust impl ::std::convert::TryFrom<::std::string::String> for UnnamedDayOfWeek { type Error = crate::model::unnamed_day_of_week::ConstraintViolation; fn try_from( s: ::std::string::String, ) -> ::std::result::Result>::Error> { match s.as_str() { "MONDAY" | "TUESDAY" => Ok(Self(s)), _ => Err(crate::model::unnamed_day_of_week::ConstraintViolation(s)), } } } impl ::std::convert::TryFrom<&str> for UnnamedDayOfWeek { type Error = crate::model::unnamed_day_of_week::ConstraintViolation; fn try_from( s: &str, ) -> ::std::result::Result>::Error> { s.to_owned().try_into() } } ``` On the client side, the behaviour is unchanged, and the client does not validate for backward compatibility reasons. An [existing test](https://github.com/smithy-lang/smithy-rs/pull/3884/files#diff-021ec60146cfe231105d21a7389f2dffcd546595964fbb3f0684ebf068325e48R82) has been modified to ensure this. ```rust #[test] fn generate_unnamed_enums() { let result = "t2.nano" .parse::() .expect("static value validated to member"); assert_eq!(result, UnnamedEnum("t2.nano".to_owned())); let result = "not-a-valid-variant" .parse::() .expect("static value validated to member"); assert_eq!(result, UnnamedEnum("not-a-valid-variant".to_owned())); } ``` Fixes issue #3880 --------- Co-authored-by: Fahad Zubair --- .changelog/4329788.md | 18 +++ .../smithy/generators/ClientEnumGenerator.kt | 31 ++++++ .../generators/ClientInstantiatorTest.kt | 8 +- .../core/smithy/generators/EnumGenerator.kt | 36 ++---- .../smithy/generators/EnumGeneratorTest.kt | 10 ++ .../core/smithy/generators/TestEnumType.kt | 32 ++++++ .../smithy/generators/ServerEnumGenerator.kt | 104 ++++++++++++------ .../codegen/server/smithy/ConstraintsTest.kt | 58 +++++++++- 8 files changed, 235 insertions(+), 62 deletions(-) create mode 100644 .changelog/4329788.md diff --git a/.changelog/4329788.md b/.changelog/4329788.md new file mode 100644 index 0000000000..c45e4c8a7e --- /dev/null +++ b/.changelog/4329788.md @@ -0,0 +1,18 @@ +--- +applies_to: ["server"] +authors: ["drganjoo"] +references: ["smithy-rs#3880"] +breaking: true +new_feature: false +bug_fix: true +--- +Unnamed enums now validate assigned values and will raise a `ConstraintViolation` if an unknown variant is set. + +The following is an example of an unnamed enum: +```smithy +@enum([ + { value: "MONDAY" }, + { value: "TUESDAY" } +]) +string UnnamedDayOfWeek +``` diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ClientEnumGenerator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ClientEnumGenerator.kt index 77a5731b62..f45c19d74b 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ClientEnumGenerator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ClientEnumGenerator.kt @@ -79,6 +79,37 @@ data class InfallibleEnumType( ) } + override fun implFromForStrForUnnamedEnum(context: EnumGeneratorContext): Writable = + writable { + rustTemplate( + """ + impl #{From} for ${context.enumName} where T: #{AsRef} { + fn from(s: T) -> Self { + ${context.enumName}(s.as_ref().to_owned()) + } + } + """, + *preludeScope, + ) + } + + override fun implFromStrForUnnamedEnum(context: EnumGeneratorContext): Writable = + writable { + // Add an infallible FromStr implementation for uniformity + rustTemplate( + """ + impl ::std::str::FromStr for ${context.enumName} { + type Err = ::std::convert::Infallible; + + fn from_str(s: &str) -> #{Result}::Err> { + #{Ok}(${context.enumName}::from(s)) + } + } + """, + *preludeScope, + ) + } + override fun additionalEnumImpls(context: EnumGeneratorContext): Writable = writable { // `try_parse` isn't needed for unnamed enums diff --git a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ClientInstantiatorTest.kt b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ClientInstantiatorTest.kt index 00fa942961..60ee794914 100644 --- a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ClientInstantiatorTest.kt +++ b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ClientInstantiatorTest.kt @@ -69,6 +69,8 @@ internal class ClientInstantiatorTest { val shape = model.lookup("com.test#UnnamedEnum") val sut = ClientInstantiator(codegenContext) val data = Node.parse("t2.nano".dq()) + // The client SDK should accept unknown variants as valid. + val notValidVariant = Node.parse("not-a-valid-variant".dq()) val project = TestWorkspace.testProject(symbolProvider) project.moduleFor(shape) { @@ -77,7 +79,11 @@ internal class ClientInstantiatorTest { withBlock("let result = ", ";") { sut.render(this, shape, data) } - rust("""assert_eq!(result, UnnamedEnum("t2.nano".to_owned()));""") + rust("""assert_eq!(result, UnnamedEnum("$data".to_owned()));""") + withBlock("let result = ", ";") { + sut.render(this, shape, notValidVariant) + } + rust("""assert_eq!(result, UnnamedEnum("$notValidVariant".to_owned()));""") } } project.compileAndTest() diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/EnumGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/EnumGenerator.kt index 1f05ad7aab..d1eac1488e 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/EnumGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/EnumGenerator.kt @@ -59,6 +59,12 @@ abstract class EnumType { /** Returns a writable that implements `FromStr` for the enum */ abstract fun implFromStr(context: EnumGeneratorContext): Writable + /** Returns a writable that implements `From<&str>` and/or `TryFrom<&str>` for the unnamed enum */ + abstract fun implFromForStrForUnnamedEnum(context: EnumGeneratorContext): Writable + + /** Returns a writable that implements `FromStr` for the unnamed enum */ + abstract fun implFromStrForUnnamedEnum(context: EnumGeneratorContext): Writable + /** Optionally adds additional documentation to the `enum` docs */ open fun additionalDocs(context: EnumGeneratorContext): Writable = writable {} @@ -237,32 +243,10 @@ open class EnumGenerator( rust("&self.0") }, ) - - // Add an infallible FromStr implementation for uniformity - rustTemplate( - """ - impl ::std::str::FromStr for ${context.enumName} { - type Err = ::std::convert::Infallible; - - fn from_str(s: &str) -> #{Result}::Err> { - #{Ok}(${context.enumName}::from(s)) - } - } - """, - *preludeScope, - ) - - rustTemplate( - """ - impl #{From} for ${context.enumName} where T: #{AsRef} { - fn from(s: T) -> Self { - ${context.enumName}(s.as_ref().to_owned()) - } - } - - """, - *preludeScope, - ) + // impl From for Blah { ... } + enumType.implFromForStrForUnnamedEnum(context)(this) + // impl FromStr for Blah { ... } + enumType.implFromStrForUnnamedEnum(context)(this) } private fun RustWriter.renderEnum() { diff --git a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/EnumGeneratorTest.kt b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/EnumGeneratorTest.kt index 0e2b10788f..0528c2e364 100644 --- a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/EnumGeneratorTest.kt +++ b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/EnumGeneratorTest.kt @@ -494,6 +494,16 @@ class EnumGeneratorTest { // intentional no-op } + override fun implFromForStrForUnnamedEnum(context: EnumGeneratorContext): Writable = + writable { + // intentional no-op + } + + override fun implFromStrForUnnamedEnum(context: EnumGeneratorContext): Writable = + writable { + // intentional no-op + } + override fun additionalEnumMembers(context: EnumGeneratorContext): Writable = writable { rust("// additional enum members") diff --git a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/TestEnumType.kt b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/TestEnumType.kt index 5699c9e5ca..019d027ece 100644 --- a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/TestEnumType.kt +++ b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/TestEnumType.kt @@ -10,6 +10,7 @@ import software.amazon.smithy.rust.codegen.core.rustlang.rust import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.core.rustlang.writable import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType.Companion.preludeScope import software.amazon.smithy.rust.codegen.core.util.dq object TestEnumType : EnumType() { @@ -49,4 +50,35 @@ object TestEnumType : EnumType() { """, ) } + + override fun implFromForStrForUnnamedEnum(context: EnumGeneratorContext): Writable = + writable { + rustTemplate( + """ + impl #{From} for ${context.enumName} where T: #{AsRef} { + fn from(s: T) -> Self { + ${context.enumName}(s.as_ref().to_owned()) + } + } + """, + *preludeScope, + ) + } + + override fun implFromStrForUnnamedEnum(context: EnumGeneratorContext): Writable = + writable { + // Add an infallible FromStr implementation for uniformity + rustTemplate( + """ + impl ::std::str::FromStr for ${context.enumName} { + type Err = ::std::convert::Infallible; + + fn from_str(s: &str) -> #{Result}::Err> { + #{Ok}(${context.enumName}::from(s)) + } + } + """, + *preludeScope, + ) + } } diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerEnumGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerEnumGenerator.kt index 5bc2218ad1..a3cc269692 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerEnumGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerEnumGenerator.kt @@ -5,10 +5,9 @@ package software.amazon.smithy.rust.codegen.server.smithy.generators import software.amazon.smithy.model.shapes.StringShape +import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter import software.amazon.smithy.rust.codegen.core.rustlang.Writable import software.amazon.smithy.rust.codegen.core.rustlang.rust -import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock -import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.core.rustlang.writable import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType @@ -39,16 +38,14 @@ open class ConstrainedEnum( } private val constraintViolationSymbol = constraintViolationSymbolProvider.toSymbol(shape) private val constraintViolationName = constraintViolationSymbol.name - private val codegenScope = - arrayOf( - "String" to RuntimeType.String, - ) - override fun implFromForStr(context: EnumGeneratorContext): Writable = - writable { - withInlineModule(constraintViolationSymbol.module(), codegenContext.moduleDocProvider) { - rustTemplate( - """ + private fun generateConstraintViolation( + context: EnumGeneratorContext, + generateTryFromStrAndString: RustWriter.(EnumGeneratorContext) -> Unit, + ) = writable { + withInlineModule(constraintViolationSymbol.module(), codegenContext.moduleDocProvider) { + rustTemplate( + """ ##[derive(Debug, PartialEq)] pub struct $constraintViolationName(pub(crate) #{String}); @@ -60,47 +57,86 @@ open class ConstrainedEnum( impl #{Error} for $constraintViolationName {} """, - *codegenScope, - "Error" to RuntimeType.StdError, - "Display" to RuntimeType.Display, - ) + *preludeScope, + "Error" to RuntimeType.StdError, + "Display" to RuntimeType.Display, + ) - if (shape.isReachableFromOperationInput()) { - rustTemplate( - """ + if (shape.isReachableFromOperationInput()) { + rustTemplate( + """ impl $constraintViolationName { #{EnumShapeConstraintViolationImplBlock:W} } """, - "EnumShapeConstraintViolationImplBlock" to - validationExceptionConversionGenerator.enumShapeConstraintViolationImplBlock( - context.enumTrait, - ), - ) - } + "EnumShapeConstraintViolationImplBlock" to + validationExceptionConversionGenerator.enumShapeConstraintViolationImplBlock( + context.enumTrait, + ), + ) } - rustBlock("impl #T<&str> for ${context.enumName}", RuntimeType.TryFrom) { - rust("type Error = #T;", constraintViolationSymbol) - rustBlockTemplate("fn try_from(s: &str) -> #{Result}>::Error>", *preludeScope) { - rustBlock("match s") { - context.sortedMembers.forEach { member -> - rust("${member.value.dq()} => Ok(${context.enumName}::${member.derivedName()}),") + } + + generateTryFromStrAndString(context) + } + + override fun implFromForStr(context: EnumGeneratorContext): Writable = + generateConstraintViolation(context) { + rustTemplate( + """ + impl #{TryFrom}<&str> for ${context.enumName} { + type Error = #{ConstraintViolation}; + fn try_from(s: &str) -> #{Result}>::Error> { + match s { + #{MatchArms} + _ => Err(#{ConstraintViolation}(s.to_owned())) } - rust("_ => Err(#T(s.to_owned()))", constraintViolationSymbol) } } - } + impl #{TryFrom}<#{String}> for ${context.enumName} { + type Error = #{ConstraintViolation}; + fn try_from(s: #{String}) -> #{Result}>::Error> { + s.as_str().try_into() + } + } + """, + *preludeScope, + "ConstraintViolation" to constraintViolationSymbol, + "MatchArms" to + writable { + context.sortedMembers.forEach { member -> + rust("${member.value.dq()} => Ok(${context.enumName}::${member.derivedName()}),") + } + }, + ) + } + + override fun implFromForStrForUnnamedEnum(context: EnumGeneratorContext): Writable = + generateConstraintViolation(context) { rustTemplate( """ + impl #{TryFrom}<&str> for ${context.enumName} { + type Error = #{ConstraintViolation}; + fn try_from(s: &str) -> #{Result}>::Error> { + s.to_owned().try_into() + } + } impl #{TryFrom}<#{String}> for ${context.enumName} { type Error = #{ConstraintViolation}; fn try_from(s: #{String}) -> #{Result}>::Error> { - s.as_str().try_into() + match s.as_str() { + #{Values} => Ok(Self(s)), + _ => Err(#{ConstraintViolation}(s)) + } } } """, *preludeScope, "ConstraintViolation" to constraintViolationSymbol, + "Values" to + writable { + rust(context.sortedMembers.joinToString(" | ") { it.value.dq() }) + }, ) } @@ -118,6 +154,8 @@ open class ConstrainedEnum( "ConstraintViolation" to constraintViolationSymbol, ) } + + override fun implFromStrForUnnamedEnum(context: EnumGeneratorContext) = implFromStr(context) } class ServerEnumGenerator( diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstraintsTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstraintsTest.kt index b0b4a01d28..36e30230c6 100644 --- a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstraintsTest.kt +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstraintsTest.kt @@ -25,9 +25,12 @@ import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.model.traits.AbstractTrait import software.amazon.smithy.model.transform.ModelTransformer import software.amazon.smithy.protocol.traits.Rpcv2CborTrait +import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate +import software.amazon.smithy.rust.codegen.core.smithy.RustCrate import software.amazon.smithy.rust.codegen.core.testutil.IntegrationTestParams import software.amazon.smithy.rust.codegen.core.testutil.ServerAdditionalSettings import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel +import software.amazon.smithy.rust.codegen.core.testutil.unitTest import software.amazon.smithy.rust.codegen.core.util.lookup import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverIntegrationTest import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverTestSymbolProvider @@ -223,11 +226,11 @@ class ConstraintsTest { primitiveBoolean.isDirectlyConstrained(symbolProvider) shouldBe false } - // TODO(#3895): Move tests that use `generateAndCompileServer` into `constraints.smithy` once issue is resolved private fun generateAndCompileServer( model: Model, pubConstraints: Boolean = true, dir: File? = null, + test: (ServerCodegenContext, RustCrate) -> Unit = { _, _ -> }, ) { if (dir?.exists() == true) { dir.deleteRecursively() @@ -244,7 +247,57 @@ class ConstraintsTest { .toObjectNode(), overrideTestDir = dir, ), - ) { _, _ -> + test = test, + ) + } + + @Test + fun `unnamed and named enums should validate and have an associated ConstraintViolation error type`() { + val model = + """ + namespace test + use aws.protocols#restJson1 + use smithy.framework#ValidationException + + @restJson1 + service SampleService { + operations: [SampleOp] + } + + @http(uri: "/dailySummary", method: "POST") + operation SampleOp { + input := { + unnamedDay: UnnamedDayOfWeek + namedDay: DayOfWeek + } + errors: [ValidationException] + } + @enum([ + { value: "MONDAY" }, + { value: "TUESDAY" } + ]) + string UnnamedDayOfWeek + @enum([ + { value: "MONDAY", name: "MONDAY" }, + { value: "TUESDAY", name: "TUESDAY" } + ]) + string DayOfWeek + """.asSmithyModel(smithyVersion = "2") + + generateAndCompileServer(model) { _, crate -> + crate.unitTest("value_should_be_validated") { + rustTemplate( + """ + let x: Result = + "Friday".try_into(); + assert!(x.is_err()); + + let x: Result = + "Friday".try_into(); + assert!(x.is_err()); + """, + ) + } } } @@ -295,6 +348,7 @@ class ConstraintsTest { generateAndCompileServer(model) } + // TODO(#3895): Move tests that use `generateAndCompileServer` into `constraints.smithy` once issue is resolved. @Test fun `constrained list with an indirectly constrained map should compile`() { val model =