Skip to content

Commit

Permalink
Refactor UnconstrainedUnionGenerator to use `ValidationExceptionCon…
Browse files Browse the repository at this point in the history
…versionGenerator` (#3733)

`UnconstrainedUnionGenerator` should use
`ValidationExceptionConversionGenerator` to generate the
`as_validation_exception` method.

---------

Co-authored-by: Fahad Zubair <[email protected]>
  • Loading branch information
drganjoo and Fahad Zubair authored Jul 2, 2024
1 parent 9af72f5 commit dc1ffb8
Show file tree
Hide file tree
Showing 7 changed files with 62 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,7 @@ class PythonServerCodegenVisitor(
rustCrate.createInlineModuleCreator(),
this@modelsModuleWriter,
shape,
validationExceptionConversionGenerator,
).render()
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -596,6 +596,7 @@ open class ServerCodegenVisitor(
rustCrate.createInlineModuleCreator(),
this@modelsModuleWriter,
shape,
validationExceptionConversionGenerator,
).render()
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import software.amazon.smithy.rust.codegen.core.rustlang.rust
import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock
import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.withBlock
import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider
Expand All @@ -31,6 +32,7 @@ import software.amazon.smithy.rust.codegen.server.smithy.generators.Length
import software.amazon.smithy.rust.codegen.server.smithy.generators.Pattern
import software.amazon.smithy.rust.codegen.server.smithy.generators.Range
import software.amazon.smithy.rust.codegen.server.smithy.generators.StringTraitInfo
import software.amazon.smithy.rust.codegen.server.smithy.generators.UnionConstraintTraitInfo
import software.amazon.smithy.rust.codegen.server.smithy.generators.ValidationExceptionConversionGenerator
import software.amazon.smithy.rust.codegen.server.smithy.generators.isKeyConstrained
import software.amazon.smithy.rust.codegen.server.smithy.generators.isValueConstrained
Expand Down Expand Up @@ -320,4 +322,19 @@ class ValidationExceptionWithReasonConversionGenerator(private val codegenContex
"AsValidationExceptionFields" to validationExceptionFields.join("\n"),
)
}

override fun unionShapeConstraintViolationImplBlock(
unionConstraintTraitInfo: Collection<UnionConstraintTraitInfo>,
) = writable {
rustBlockTemplate(
"pub(crate) fn as_validation_exception_field(self, path: #{String}) -> crate::model::ValidationExceptionField",
"String" to RuntimeType.String,
) {
withBlock("match self {", "}") {
for (constraintViolation in unionConstraintTraitInfo) {
rust("""Self::${constraintViolation.name()}(inner) => inner.as_validation_exception_field(path + "/${constraintViolation.forMember.memberName}"),""")
}
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import software.amazon.smithy.rust.codegen.core.rustlang.rust
import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock
import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.withBlock
import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider
Expand All @@ -30,6 +31,7 @@ import software.amazon.smithy.rust.codegen.server.smithy.generators.ConstraintVi
import software.amazon.smithy.rust.codegen.server.smithy.generators.Range
import software.amazon.smithy.rust.codegen.server.smithy.generators.StringTraitInfo
import software.amazon.smithy.rust.codegen.server.smithy.generators.TraitInfo
import software.amazon.smithy.rust.codegen.server.smithy.generators.UnionConstraintTraitInfo
import software.amazon.smithy.rust.codegen.server.smithy.generators.ValidationExceptionConversionGenerator
import software.amazon.smithy.rust.codegen.server.smithy.generators.isKeyConstrained
import software.amazon.smithy.rust.codegen.server.smithy.generators.isValueConstrained
Expand Down Expand Up @@ -244,4 +246,19 @@ class SmithyValidationExceptionConversionGenerator(private val codegenContext: S
"AsValidationExceptionFields" to validationExceptionFields.join(""),
)
}

override fun unionShapeConstraintViolationImplBlock(
unionConstraintTraitInfo: Collection<UnionConstraintTraitInfo>,
) = writable {
rustBlockTemplate(
"pub(crate) fn as_validation_exception_field(self, path: #{String}) -> crate::model::ValidationExceptionField",
"String" to RuntimeType.String,
) {
withBlock("match self {", "}") {
for (constraintViolation in unionConstraintTraitInfo) {
rust("""Self::${constraintViolation.name()}(inner) => inner.as_validation_exception_field(path + "/${constraintViolation.forMember.memberName}"),""")
}
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.core.rustlang.Visibility
import software.amazon.smithy.rust.codegen.core.rustlang.rust
import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock
import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.withBlock
import software.amazon.smithy.rust.codegen.core.rustlang.withBlockTemplate
Expand Down Expand Up @@ -53,6 +52,7 @@ class UnconstrainedUnionGenerator(
private val inlineModuleCreator: InlineModuleCreator,
private val modelsModuleWriter: RustWriter,
val shape: UnionShape,
private val validationExceptionConversionGenerator: ValidationExceptionConversionGenerator,
) {
private val model = codegenContext.model
private val symbolProvider = codegenContext.symbolProvider
Expand Down Expand Up @@ -172,18 +172,15 @@ class UnconstrainedUnionGenerator(
)

if (shape.isReachableFromOperationInput()) {
rustBlock("impl $constraintViolationName") {
rustBlockTemplate(
"pub(crate) fn as_validation_exception_field(self, path: #{String}) -> crate::model::ValidationExceptionField",
"String" to RuntimeType.String,
) {
withBlock("match self {", "}") {
for (constraintViolation in constraintViolations()) {
rust("""Self::${constraintViolation.name()}(inner) => inner.as_validation_exception_field(path + "/${constraintViolation.forMember.memberName}"),""")
}
}
rustTemplate(
"""
impl $constraintViolationName {
#{UnionShapeConstraintViolationImplBlock:W}
}
}
""",
"UnionShapeConstraintViolationImplBlock" to
validationExceptionConversionGenerator.unionShapeConstraintViolationImplBlock(constraintViolations()),
)
}
}
}
Expand All @@ -199,30 +196,26 @@ class UnconstrainedUnionGenerator(
}
}

data class ConstraintViolation(val forMember: MemberShape) {
fun name() = forMember.memberName.toPascalCase()
}

private fun constraintViolations() =
sortedMembers
.filter { it.targetCanReachConstrainedShape(model, symbolProvider) }
.map { ConstraintViolation(it) }
.map { UnionConstraintTraitInfo(it) }

private fun renderConstraintViolation(
writer: RustWriter,
constraintViolation: ConstraintViolation,
unionConstraintTraitInfo: UnionConstraintTraitInfo,
) {
val targetShape = model.expectShape(constraintViolation.forMember.target)
val targetShape = model.expectShape(unionConstraintTraitInfo.forMember.target)

val constraintViolationSymbol =
constraintViolationSymbolProvider.toSymbol(targetShape)
// Box this constraint violation symbol if necessary.
.letIf(constraintViolation.forMember.hasTrait<ConstraintViolationRustBoxTrait>()) {
.letIf(unionConstraintTraitInfo.forMember.hasTrait<ConstraintViolationRustBoxTrait>()) {
it.makeRustBoxed()
}

writer.rust(
"${constraintViolation.name()}(#T),",
"${unionConstraintTraitInfo.name()}(#T),",
constraintViolationSymbol,
)
}
Expand Down Expand Up @@ -291,7 +284,7 @@ class UnconstrainedUnionGenerator(
{
let constrained: #{ConstrainedSymbol} = $unconstrainedVar
.try_into()$boxIt$boxErr
.map_err(Self::Error::${ConstraintViolation(member).name()})?;
.map_err(Self::Error::${UnionConstraintTraitInfo(member).name()})?;
constrained.into()
}
""",
Expand All @@ -304,9 +297,13 @@ class UnconstrainedUnionGenerator(
.try_into()
$boxIt
$boxErr
.map_err(Self::Error::${ConstraintViolation(member).name()})?
.map_err(Self::Error::${UnionConstraintTraitInfo(member).name()})?
""",
)
}
}
}

data class UnionConstraintTraitInfo(val forMember: MemberShape) {
fun name() = forMember.memberName.toPascalCase()
}
Original file line number Diff line number Diff line change
Expand Up @@ -53,4 +53,8 @@ interface ValidationExceptionConversionGenerator {
collectionConstraintsInfo: Collection<CollectionTraitInfo>,
isMemberConstrained: Boolean,
): Writable

fun unionShapeConstraintViolationImplBlock(
unionConstraintTraitInfo: Collection<UnionConstraintTraitInfo>,
): Writable
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import software.amazon.smithy.rust.codegen.core.testutil.unitTest
import software.amazon.smithy.rust.codegen.core.util.lookup
import software.amazon.smithy.rust.codegen.server.smithy.ServerRustModule
import software.amazon.smithy.rust.codegen.server.smithy.createInlineModuleCreator
import software.amazon.smithy.rust.codegen.server.smithy.customizations.SmithyValidationExceptionConversionGenerator
import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerRestJsonProtocol
import software.amazon.smithy.rust.codegen.server.smithy.renderInlineMemoryModules
import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverRenderWithModelBuilder
Expand Down Expand Up @@ -63,7 +64,7 @@ class UnconstrainedUnionGeneratorTest {
TestUtility.generateIsError().invoke(this)

project.withModule(ServerRustModule.Model) modelsModuleWriter@{
UnconstrainedUnionGenerator(codegenContext, project.createInlineModuleCreator(), this@modelsModuleWriter, unionShape).render()
UnconstrainedUnionGenerator(codegenContext, project.createInlineModuleCreator(), this@modelsModuleWriter, unionShape, SmithyValidationExceptionConversionGenerator(codegenContext)).render()

this@unconstrainedModuleWriter.unitTest(
name = "unconstrained_union_fail_to_constrain",
Expand Down

0 comments on commit dc1ffb8

Please sign in to comment.