Skip to content

Commit

Permalink
fix const names in codegen (#3639)
Browse files Browse the repository at this point in the history
  • Loading branch information
aws-sdk-rust-ci authored May 15, 2024
2 parents 1117dc7 + ce189b0 commit 1ccf22d
Show file tree
Hide file tree
Showing 30 changed files with 163 additions and 164 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ class RegionProviderConfig(codegenContext: ClientCodegenContext) : ConfigCustomi
}

is ServiceConfig.BuilderFromConfigBag -> {
rustTemplate("${section.builder}.set_region(${section.config_bag}.load::<#{Region}>().cloned());", *codegenScope)
rustTemplate("${section.builder}.set_region(${section.configBag}.load::<#{Region}>().cloned());", *codegenScope)
}

else -> emptySection
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ class UserAgentDecorator : ClientCodegenDecorator {

is ServiceConfig.BuilderFromConfigBag ->
writable {
rustTemplate("${section.builder}.set_app_name(${section.config_bag}.load::<#{AppName}>().cloned());", *codegenScope)
rustTemplate("${section.builder}.set_app_name(${section.configBag}.load::<#{AppName}>().cloned());", *codegenScope)
}

is ServiceConfig.BuilderBuild ->
Expand Down
8 changes: 4 additions & 4 deletions buildSrc/src/main/kotlin/CodegenTestCommon.kt
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ fun Project.registerGenerateSmithyBuildTask(

// If this is a rebuild, cache all the hashes of the generated Rust files. These are later used by the
// `modifyMtime` task.
project.extra[previousBuildHashesKey] =
project.extra[PREVIOUS_BUILD_HASHES_KEY] =
project.buildDir.walk()
.filter { it.isFile }
.map {
Expand Down Expand Up @@ -217,7 +217,7 @@ fun Project.registerGenerateCargoConfigTomlTask(outputDir: File) {
}
}

const val previousBuildHashesKey = "previousBuildHashes"
const val PREVIOUS_BUILD_HASHES_KEY = "previousBuildHashes"

fun Project.registerModifyMtimeTask() {
// Cargo uses `mtime` (among other factors) to determine whether a compilation unit needs a rebuild. While developing,
Expand All @@ -232,11 +232,11 @@ fun Project.registerModifyMtimeTask() {
dependsOn("generateSmithyBuild")

doFirst {
if (!project.extra.has(previousBuildHashesKey)) {
if (!project.extra.has(PREVIOUS_BUILD_HASHES_KEY)) {
println("No hashes from a previous build exist because `generateSmithyBuild` is up to date, skipping `mtime` fixups")
} else {
@Suppress("UNCHECKED_CAST")
val previousBuildHashes: Map<String, Long> = project.extra[previousBuildHashesKey] as Map<String, Long>
val previousBuildHashes: Map<String, Long> = project.extra[PREVIOUS_BUILD_HASHES_KEY] as Map<String, Long>

project.buildDir.walk()
.filter { it.isFile }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ val ClientReservedWords =
mapOf(
// Unions contain an `Unknown` variant. This exists to support parsing data returned from the server
// that represent union variants that have been added since this SDK was generated.
UnionGenerator.UnknownVariantName to "${UnionGenerator.UnknownVariantName}Value",
"${UnionGenerator.UnknownVariantName}Value" to "${UnionGenerator.UnknownVariantName}Value_",
UnionGenerator.UNKNOWN_VARIANT_NAME to "${UnionGenerator.UNKNOWN_VARIANT_NAME}Value",
"${UnionGenerator.UNKNOWN_VARIANT_NAME}Value" to "${UnionGenerator.UNKNOWN_VARIANT_NAME}Value_",
),
enumMemberMap =
mapOf(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,59 +86,59 @@ data class ClientRustSettings(
* [addMessageToErrors]: Adds a `message` field automatically to all error shapes
*/
data class ClientCodegenConfig(
override val formatTimeoutSeconds: Int = defaultFormatTimeoutSeconds,
override val debugMode: Boolean = defaultDebugMode,
override val flattenCollectionAccessors: Boolean = defaultFlattenAccessors,
override val formatTimeoutSeconds: Int = DEFAULT_FORMAT_TIMEOUT_SECONDS,
override val debugMode: Boolean = DEFAULT_DEBUG_MODE,
override val flattenCollectionAccessors: Boolean = DEFAULT_FLATTEN_ACCESSORS,
val nullabilityCheckMode: NullableIndex.CheckMode = NullableIndex.CheckMode.CLIENT,
val renameExceptions: Boolean = defaultRenameExceptions,
val includeFluentClient: Boolean = defaultIncludeFluentClient,
val addMessageToErrors: Boolean = defaultAddMessageToErrors,
val renameExceptions: Boolean = DEFAULT_RENAME_EXCEPTIONS,
val includeFluentClient: Boolean = DEFAULT_INCLUDE_FLUENT_CLIENT,
val addMessageToErrors: Boolean = DEFAULT_ADD_MESSAGE_TO_ERRORS,
// TODO(EventStream): [CLEANUP] Remove this property when turning on Event Stream for all services
val eventStreamAllowList: Set<String> = defaultEventStreamAllowList,
val eventStreamAllowList: Set<String> = DEFAULT_EVENT_STREAM_ALLOW_LIST,
/** If true, adds `endpoint_url`/`set_endpoint_url` methods to the service config */
val includeEndpointUrlConfig: Boolean = defaultIncludeEndpointUrlConfig,
val enableUserConfigurableRuntimePlugins: Boolean = defaultEnableUserConfigurableRuntimePlugins,
val includeEndpointUrlConfig: Boolean = DEFAULT_INCLUDE_ENDPOINT_URL_CONFIG,
val enableUserConfigurableRuntimePlugins: Boolean = DEFAULT_ENABLE_USER_CONFIGURABLE_RUNTIME_PLUGINS,
) : CoreCodegenConfig(
formatTimeoutSeconds, debugMode, defaultFlattenAccessors,
formatTimeoutSeconds, debugMode, DEFAULT_FLATTEN_ACCESSORS,
) {
companion object {
private const val defaultRenameExceptions = true
private const val defaultIncludeFluentClient = true
private const val defaultAddMessageToErrors = true
private val defaultEventStreamAllowList: Set<String> = emptySet()
private const val defaultIncludeEndpointUrlConfig = true
private const val defaultEnableUserConfigurableRuntimePlugins = true
private const val defaultNullabilityCheckMode = "CLIENT"
private const val DEFAULT_RENAME_EXCEPTIONS = true
private const val DEFAULT_INCLUDE_FLUENT_CLIENT = true
private const val DEFAULT_ADD_MESSAGE_TO_ERRORS = true
private val DEFAULT_EVENT_STREAM_ALLOW_LIST: Set<String> = emptySet()
private const val DEFAULT_INCLUDE_ENDPOINT_URL_CONFIG = true
private const val DEFAULT_ENABLE_USER_CONFIGURABLE_RUNTIME_PLUGINS = true
private const val DEFAULT_NULLABILITY_CHECK_MODE = "CLIENT"

// Note: only clients default to true, servers default to false
private const val defaultFlattenAccessors = true
private const val DEFAULT_FLATTEN_ACCESSORS = true

fun fromCodegenConfigAndNode(
coreCodegenConfig: CoreCodegenConfig,
node: Optional<ObjectNode>,
) = if (node.isPresent) {
ClientCodegenConfig(
formatTimeoutSeconds = coreCodegenConfig.formatTimeoutSeconds,
flattenCollectionAccessors = node.get().getBooleanMemberOrDefault("flattenCollectionAccessors", defaultFlattenAccessors),
flattenCollectionAccessors = node.get().getBooleanMemberOrDefault("flattenCollectionAccessors", DEFAULT_FLATTEN_ACCESSORS),
debugMode = coreCodegenConfig.debugMode,
eventStreamAllowList =
node.get().getArrayMember("eventStreamAllowList").map { array ->
array.toList().mapNotNull { node ->
node.asStringNode().orNull()?.value
}
}.orNull()?.toSet() ?: defaultEventStreamAllowList,
renameExceptions = node.get().getBooleanMemberOrDefault("renameErrors", defaultRenameExceptions),
includeFluentClient = node.get().getBooleanMemberOrDefault("includeFluentClient", defaultIncludeFluentClient),
addMessageToErrors = node.get().getBooleanMemberOrDefault("addMessageToErrors", defaultAddMessageToErrors),
includeEndpointUrlConfig = node.get().getBooleanMemberOrDefault("includeEndpointUrlConfig", defaultIncludeEndpointUrlConfig),
enableUserConfigurableRuntimePlugins = node.get().getBooleanMemberOrDefault("enableUserConfigurableRuntimePlugins", defaultEnableUserConfigurableRuntimePlugins),
nullabilityCheckMode = NullableIndex.CheckMode.valueOf(node.get().getStringMemberOrDefault("nullabilityCheckMode", defaultNullabilityCheckMode)),
}.orNull()?.toSet() ?: DEFAULT_EVENT_STREAM_ALLOW_LIST,
renameExceptions = node.get().getBooleanMemberOrDefault("renameErrors", DEFAULT_RENAME_EXCEPTIONS),
includeFluentClient = node.get().getBooleanMemberOrDefault("includeFluentClient", DEFAULT_INCLUDE_FLUENT_CLIENT),
addMessageToErrors = node.get().getBooleanMemberOrDefault("addMessageToErrors", DEFAULT_ADD_MESSAGE_TO_ERRORS),
includeEndpointUrlConfig = node.get().getBooleanMemberOrDefault("includeEndpointUrlConfig", DEFAULT_INCLUDE_ENDPOINT_URL_CONFIG),
enableUserConfigurableRuntimePlugins = node.get().getBooleanMemberOrDefault("enableUserConfigurableRuntimePlugins", DEFAULT_ENABLE_USER_CONFIGURABLE_RUNTIME_PLUGINS),
nullabilityCheckMode = NullableIndex.CheckMode.valueOf(node.get().getStringMemberOrDefault("nullabilityCheckMode", DEFAULT_NULLABILITY_CHECK_MODE)),
)
} else {
ClientCodegenConfig(
formatTimeoutSeconds = coreCodegenConfig.formatTimeoutSeconds,
debugMode = coreCodegenConfig.debugMode,
nullabilityCheckMode = NullableIndex.CheckMode.valueOf(defaultNullabilityCheckMode),
nullabilityCheckMode = NullableIndex.CheckMode.valueOf(DEFAULT_NULLABILITY_CHECK_MODE),
)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -292,15 +292,15 @@ class ResiliencyConfigCustomization(codegenContext: ClientCodegenContext) : Conf

is ServiceConfig.BuilderFromConfigBag -> {
rustTemplate(
"${section.builder}.set_retry_config(${section.config_bag}.load::<#{RetryConfig}>().cloned());",
"${section.builder}.set_retry_config(${section.configBag}.load::<#{RetryConfig}>().cloned());",
*codegenScope,
)
rustTemplate(
"${section.builder}.set_timeout_config(${section.config_bag}.load::<#{TimeoutConfig}>().cloned());",
"${section.builder}.set_timeout_config(${section.configBag}.load::<#{TimeoutConfig}>().cloned());",
*codegenScope,
)
rustTemplate(
"${section.builder}.set_retry_partition(${section.config_bag}.load::<#{RetryPartition}>().cloned());",
"${section.builder}.set_retry_partition(${section.configBag}.load::<#{RetryPartition}>().cloned());",
*codegenScope,
)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -158,8 +158,8 @@ internal class EndpointResolverGenerator(
private val context = Context(registry, runtimeConfig)

companion object {
const val DiagnosticCollector = "_diagnostic_collector"
private const val ParamsName = "_params"
const val DIAGNOSTIC_COLLECTOR = "_diagnostic_collector"
private const val PARAMS_NAME = "_params"
}

/**
Expand Down Expand Up @@ -224,7 +224,7 @@ internal class EndpointResolverGenerator(
Attribute(allow(allowLintsForResolver)).render(this)
rustTemplate(
"""
pub(super) fn resolve_endpoint($ParamsName: &#{Params}, $DiagnosticCollector: &mut #{DiagnosticCollector}, #{additional_args}) -> #{endpoint}::Result {
pub(super) fn resolve_endpoint($PARAMS_NAME: &#{Params}, $DIAGNOSTIC_COLLECTOR: &mut #{DiagnosticCollector}, #{additional_args}) -> #{endpoint}::Result {
#{body:W}
}
Expand All @@ -241,7 +241,7 @@ internal class EndpointResolverGenerator(
writable {
endpointRuleSet.parameters.toList().forEach {
Attribute.AllowUnusedVariables.render(this)
rust("let ${it.memberName()} = &$ParamsName.${it.memberName()};")
rust("let ${it.memberName()} = &$PARAMS_NAME.${it.memberName()};")
}
generateRulesList(endpointRuleSet.rules)(this)
}
Expand All @@ -256,7 +256,7 @@ internal class EndpointResolverGenerator(
// it's hard to figure out if these are always needed or not
Attribute.AllowUnreachableCode.render(this)
rustTemplate(
"""return Err(#{EndpointError}::message(format!("No rules matched these parameters. This is a bug. {:?}", $ParamsName)));""",
"""return Err(#{EndpointError}::message(format!("No rules matched these parameters. This is a bug. {:?}", $PARAMS_NAME)));""",
*codegenScope,
)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ class ExpressionGenerator(
val expressionGenerator = ExpressionGenerator(Ownership.Borrowed, context)
val argWritables = args.map { expressionGenerator.generate(it) }
rustTemplate(
"#{fn}(#{args}, ${EndpointResolverGenerator.DiagnosticCollector})",
"#{fn}(#{args}, ${EndpointResolverGenerator.DIAGNOSTIC_COLLECTOR})",
"fn" to fnDefinition.usage(),
"args" to argWritables.join(","),
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,10 @@ data class InfallibleEnumType(
) : EnumType() {
companion object {
/** Name of the generated unknown enum member name for enums with named members. */
const val UnknownVariant = "Unknown"
const val UNKNOWN_VARIANT = "Unknown"

/** Name of the opaque struct that is inner data for the generated [UnknownVariant]. */
const val UnknownVariantValue = "UnknownVariantValue"
/** Name of the opaque struct that is inner data for the generated [UNKNOWN_VARIANT]. */
const val UNKNOWN_VARIANT_VALUE = "UnknownVariantValue"
}

override fun implFromForStr(context: EnumGeneratorContext): Writable =
Expand All @@ -56,7 +56,7 @@ data class InfallibleEnumType(
rust("${member.value.dq()} => ${context.enumName}::${member.derivedName()},")
}
rust(
"other => ${context.enumName}::$UnknownVariant(#T(other.to_owned()))",
"other => ${context.enumName}::$UNKNOWN_VARIANT(#T(other.to_owned()))",
unknownVariantValue(context),
)
},
Expand Down Expand Up @@ -131,25 +131,25 @@ data class InfallibleEnumType(

override fun additionalDocs(context: EnumGeneratorContext): Writable =
writable {
renderForwardCompatibilityNote(context.enumName, context.sortedMembers, UnknownVariant, UnknownVariantValue)
renderForwardCompatibilityNote(context.enumName, context.sortedMembers, UNKNOWN_VARIANT, UNKNOWN_VARIANT_VALUE)
}

override fun additionalEnumMembers(context: EnumGeneratorContext): Writable =
writable {
docs("`$UnknownVariant` contains new variants that have been added since this code was generated.")
docs("`$UNKNOWN_VARIANT` contains new variants that have been added since this code was generated.")
rust(
"""##[deprecated(note = "Don't directly match on `$UnknownVariant`. See the docs on this enum for the correct way to handle unknown variants.")]""",
"""##[deprecated(note = "Don't directly match on `$UNKNOWN_VARIANT`. See the docs on this enum for the correct way to handle unknown variants.")]""",
)
rust("$UnknownVariant(#T)", unknownVariantValue(context))
rust("$UNKNOWN_VARIANT(#T)", unknownVariantValue(context))
}

override fun additionalAsStrMatchArms(context: EnumGeneratorContext): Writable =
writable {
rust("${context.enumName}::$UnknownVariant(value) => value.as_str()")
rust("${context.enumName}::$UNKNOWN_VARIANT(value) => value.as_str()")
}

private fun unknownVariantValue(context: EnumGeneratorContext): RuntimeType {
return RuntimeType.forInlineFun(UnknownVariantValue, unknownVariantModule) {
return RuntimeType.forInlineFun(UNKNOWN_VARIANT_VALUE, unknownVariantModule) {
docs(
"""
Opaque struct used as inner data for the `Unknown` variant defined in enums in
Expand All @@ -159,16 +159,16 @@ data class InfallibleEnumType(
""".trimIndent(),
)
context.enumMeta.render(this)
rustTemplate("struct $UnknownVariantValue(pub(crate) #{String});", *preludeScope)
rustBlock("impl $UnknownVariantValue") {
rustTemplate("struct $UNKNOWN_VARIANT_VALUE(pub(crate) #{String});", *preludeScope)
rustBlock("impl $UNKNOWN_VARIANT_VALUE") {
// The generated as_str is not pub as we need to prevent users from calling it on this opaque struct.
rustBlock("pub(crate) fn as_str(&self) -> &str") {
rust("&self.0")
}
}
rustTemplate(
"""
impl #{Display} for $UnknownVariantValue {
impl #{Display} for $UNKNOWN_VARIANT_VALUE {
fn fmt(&self, f: &mut #{Fmt}::Formatter) -> #{Fmt}::Result {
write!(f, "{}", self.0)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ class ServiceGenerator(
)
serviceConfigGenerator.render(this)

// Enable users to opt in to the test-utils in the runtime crate
// Enable users to opt in to the `test-util` feature in the runtime crate
rustCrate.mergeFeature(TestUtilFeature.copy(deps = listOf("aws-smithy-runtime/test-util")))

ServiceRuntimePluginGenerator(codegenContext)
Expand Down
Loading

0 comments on commit 1ccf22d

Please sign in to comment.