Skip to content

Commit

Permalink
Improve client init time by switching to regex-lite (#3269)
Browse files Browse the repository at this point in the history
Each client initialization was taking between 1 and 2 milliseconds,
regardless if the client had been constructed before or not. For
example, if a customer wants five clients with different credentials
providers, that could be 10 milliseconds of time spent in
`Client::from_conf`. Approximately 98% of this time was spent compiling
regular expressions for the endpoint partition resolver.

This change switches everything over to the regex-lite crate, which has
faster regex compile times, and shouldn't have much of an impact on
performance for our specific use-cases (small strings, only evaluated at
client initialization).

The use of regex was entirely removed in aws-sigv4 since it was overkill
for what it was being used for.

----

_By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice._
  • Loading branch information
jdisanti authored Nov 30, 2023
1 parent 6420816 commit 5b93fd2
Show file tree
Hide file tree
Showing 15 changed files with 116 additions and 106 deletions.
12 changes: 11 additions & 1 deletion CHANGELOG.next.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,14 @@
# message = "Fix typos in module documentation for generated crates"
# references = ["smithy-rs#920"]
# meta = { "breaking" = false, "tada" = false, "bug" = false, "target" = "client | server | all"}
# author = "rcoh"
# author = "rcoh"

[[aws-sdk-rust]]
message = """Client creation now takes microseconds instead of milliseconds.
Previously, it would take 2-3 milliseconds for each client instantiation due to time spent compiling regexes.
For applications that used several clients, this would increase start-up time in cases where it really matters,
such as for AWS Lambda cold starts. This time was improved by both changing regex implementation and caching the
result of the compilation."""
references = ["aws-sdk-rust#975", "smithy-rs#3269"]
meta = { "breaking" = false, "tada" = true, "bug" = false }
author = "jdisanti"
1 change: 0 additions & 1 deletion aws/rust-runtime/aws-sigv4/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ num-bigint = { version = "0.4", optional = true }
once_cell = "1.8"
p256 = { version = "0.11", features = ["ecdsa"], optional = true }
percent-encoding = { version = "2.1", optional = true }
regex = "1.5"
ring = { version = "0.17.5", optional = true }
sha2 = "0.10"
crypto-bigint = { version = "0.5.4", optional = true }
Expand Down
75 changes: 38 additions & 37 deletions aws/rust-runtime/aws-sigv4/src/http_request/canonical_request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -425,39 +425,37 @@ impl<'a> fmt::Display for CanonicalRequest<'a> {
}
}

/// A regex for matching on 2 or more spaces that acts on bytes.
static MULTIPLE_SPACES: once_cell::sync::Lazy<regex::bytes::Regex> =
once_cell::sync::Lazy::new(|| regex::bytes::Regex::new(r" {2,}").unwrap());

/// Removes excess spaces before and after a given byte string, and converts multiple sequential
/// spaces to a single space e.g. " Some example text " -> "Some example text".
///
/// This function ONLY affects spaces and not other kinds of whitespace.
fn trim_all(text: &[u8]) -> Cow<'_, [u8]> {
// The normal trim function will trim non-breaking spaces and other various whitespace chars.
// S3 ONLY trims spaces so we use trim_matches to trim spaces only
let text = trim_spaces_from_byte_string(text);
MULTIPLE_SPACES.replace_all(text, " ".as_bytes())
}

/// Removes excess spaces before and after a given byte string by returning a subset of those bytes.
/// Will return an empty slice if a string is composed entirely of whitespace.
fn trim_spaces_from_byte_string(bytes: &[u8]) -> &[u8] {
let starting_index = bytes.iter().position(|b| *b != b' ').unwrap_or(0);
let ending_offset = bytes.iter().rev().position(|b| *b != b' ').unwrap_or(0);
let ending_index = bytes.len() - ending_offset;
&bytes[starting_index..ending_index]
fn trim_all(text: &str) -> Cow<'_, str> {
let text = text.trim_matches(' ');
let requires_filter = text
.chars()
.zip(text.chars().skip(1))
.any(|(a, b)| a == ' ' && b == ' ');
if !requires_filter {
Cow::Borrowed(text)
} else {
// The normal trim function will trim non-breaking spaces and other various whitespace chars.
// S3 ONLY trims spaces so we use trim_matches to trim spaces only
Cow::Owned(
text.chars()
// Filter out consecutive spaces
.zip(text.chars().skip(1).chain(std::iter::once('!')))
.filter(|(a, b)| *a != ' ' || *b != ' ')
.map(|(a, _)| a)
.collect(),
)
}
}

/// Works just like [trim_all] but acts on HeaderValues instead of bytes.
/// Will ensure that the underlying bytes are valid UTF-8.
fn normalize_header_value(header_value: &str) -> Result<HeaderValue, CanonicalRequestError> {
let trimmed_value = trim_all(header_value.as_bytes());
HeaderValue::from_str(
std::str::from_utf8(&trimmed_value)
.map_err(CanonicalRequestError::invalid_utf8_in_header_value)?,
)
.map_err(CanonicalRequestError::from)
let trimmed_value = trim_all(header_value);
HeaderValue::from_str(&trimmed_value).map_err(CanonicalRequestError::from)
}

#[derive(Debug, PartialEq, Default)]
Expand Down Expand Up @@ -631,6 +629,7 @@ mod tests {
use http::{HeaderValue, Uri};
use pretty_assertions::assert_eq;
use proptest::{prelude::*, proptest};
use std::borrow::Cow;
use std::time::Duration;

fn signing_params(identity: &Identity, settings: SigningSettings) -> SigningParams<'_> {
Expand Down Expand Up @@ -982,32 +981,34 @@ mod tests {

#[test]
fn test_trim_all_handles_spaces_correctly() {
// Can't compare a byte array to a Cow so we convert both to slices before comparing
let expected = &b"Some example text"[..];
let actual = &trim_all(b" Some example text ")[..];

assert_eq!(expected, actual);
assert_eq!(Cow::Borrowed("don't touch me"), trim_all("don't touch me"));
assert_eq!("trim left", trim_all(" trim left"));
assert_eq!("trim right", trim_all("trim right "));
assert_eq!("trim both", trim_all(" trim both "));
assert_eq!("", trim_all(" "));
assert_eq!("", trim_all(" "));
assert_eq!("a b", trim_all(" a b "));
assert_eq!("Some example text", trim_all(" Some example text "));
}

#[test]
fn test_trim_all_ignores_other_forms_of_whitespace() {
// Can't compare a byte array to a Cow so we convert both to slices before comparing
let expected = &b"\t\xA0Some\xA0 example \xA0text\xA0\n"[..];
// \xA0 is a non-breaking space character
let actual = &trim_all(b"\t\xA0Some\xA0 example \xA0text\xA0\n")[..];

assert_eq!(expected, actual);
assert_eq!(
"\t\u{A0}Some\u{A0} example \u{A0}text\u{A0}\n",
trim_all("\t\u{A0}Some\u{A0} example \u{A0}text\u{A0}\n")
);
}

#[test]
fn trim_spaces_works_on_single_characters() {
assert_eq!(trim_all(b"2").as_ref(), b"2");
assert_eq!(trim_all("2").as_ref(), "2");
}

proptest! {
#[test]
fn test_trim_all_doesnt_elongate_strings(s in ".*") {
assert!(trim_all(s.as_bytes()).len() <= s.len())
assert!(trim_all(&s).len() <= s.len())
}

#[test]
Expand All @@ -1018,7 +1019,7 @@ mod tests {

#[test]
fn test_trim_all_does_nothing_when_there_are_no_spaces(s in "[^ ]*") {
assert_eq!(trim_all(s.as_bytes()).as_ref(), s.as_bytes());
assert_eq!(trim_all(&s).as_ref(), s);
}
}
}
10 changes: 0 additions & 10 deletions aws/rust-runtime/aws-sigv4/src/http_request/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ use http::header::{InvalidHeaderName, InvalidHeaderValue};
use http::uri::InvalidUri;
use std::error::Error;
use std::fmt;
use std::str::Utf8Error;

#[derive(Debug)]
enum SigningErrorKind {
Expand Down Expand Up @@ -63,7 +62,6 @@ impl From<CanonicalRequestError> for SigningError {
enum CanonicalRequestErrorKind {
InvalidHeaderName { source: InvalidHeaderName },
InvalidHeaderValue { source: InvalidHeaderValue },
InvalidUtf8InHeaderValue { source: Utf8Error },
InvalidUri { source: InvalidUri },
UnsupportedIdentityType,
}
Expand All @@ -79,7 +77,6 @@ impl fmt::Display for CanonicalRequestError {
match self.kind {
InvalidHeaderName { .. } => write!(f, "invalid header name"),
InvalidHeaderValue { .. } => write!(f, "invalid header value"),
InvalidUtf8InHeaderValue { .. } => write!(f, "invalid UTF-8 in header value"),
InvalidUri { .. } => write!(f, "the uri was invalid"),
UnsupportedIdentityType => {
write!(f, "only AWS credentials are supported for signing")
Expand All @@ -94,20 +91,13 @@ impl Error for CanonicalRequestError {
match &self.kind {
InvalidHeaderName { source } => Some(source),
InvalidHeaderValue { source } => Some(source),
InvalidUtf8InHeaderValue { source } => Some(source),
InvalidUri { source } => Some(source),
UnsupportedIdentityType => None,
}
}
}

impl CanonicalRequestError {
pub(crate) fn invalid_utf8_in_header_value(source: Utf8Error) -> Self {
Self {
kind: CanonicalRequestErrorKind::InvalidUtf8InHeaderValue { source },
}
}

pub(crate) fn unsupported_identity_type() -> Self {
Self {
kind: CanonicalRequestErrorKind::UnsupportedIdentityType,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,12 @@ import software.amazon.smithy.rulesengine.traits.ContextParamTrait
import software.amazon.smithy.rulesengine.traits.EndpointRuleSetTrait
import software.amazon.smithy.rust.codegen.client.smithy.endpoint.generators.EndpointStdLib
import software.amazon.smithy.rust.codegen.client.smithy.endpoint.generators.FunctionRegistry
import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency
import software.amazon.smithy.rust.codegen.core.rustlang.InlineDependency
import software.amazon.smithy.rust.codegen.core.rustlang.RustDependency
import software.amazon.smithy.rust.codegen.core.rustlang.RustModule
import software.amazon.smithy.rust.codegen.core.rustlang.RustType
import software.amazon.smithy.rust.codegen.core.rustlang.toType
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeConfig
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.smithy.makeOptional
Expand All @@ -43,16 +45,33 @@ fun Identifier.rustName(): String {
}

/**
* Endpoints standard library file
* Endpoints standard library
*/
internal fun endpointsLib(name: String, vararg additionalDependency: RustDependency) = InlineDependency.forRustFile(
RustModule.pubCrate(
name,
parent = EndpointStdLib,
),
"/inlineable/src/endpoint_lib/$name.rs",
*additionalDependency,
)
object EndpointsLib {
val DiagnosticCollector = endpointsLib("diagnostic").toType().resolve("DiagnosticCollector")
fun PartitionResolver(runtimeConfig: RuntimeConfig) =
endpointsLib("partition", CargoDependency.smithyJson(runtimeConfig), CargoDependency.RegexLite).toType()
.resolve("PartitionResolver")

val substring = endpointsLib("substring").toType().resolve("substring")
val isValidHostLabel = endpointsLib("host").toType().resolve("is_valid_host_label")
val parseUrl = endpointsLib("parse_url", CargoDependency.Http, CargoDependency.Url).toType().resolve("parse_url")
val uriEncode = endpointsLib("uri_encode", CargoDependency.PercentEncoding).toType().resolve("uri_encode")

val awsParseArn = endpointsLib("arn").toType().resolve("parse_arn")
val awsIsVirtualHostableS3Bucket =
endpointsLib("s3", endpointsLib("host"), CargoDependency.OnceCell, CargoDependency.RegexLite).toType()
.resolve("is_virtual_hostable_s3_bucket")

private fun endpointsLib(name: String, vararg additionalDependency: RustDependency) = InlineDependency.forRustFile(
RustModule.pubCrate(
name,
parent = EndpointStdLib,
),
"/inlineable/src/endpoint_lib/$name.rs",
*additionalDependency,
)
}

class Types(runtimeConfig: RuntimeConfig) {
private val smithyTypesEndpointModule = RuntimeType.smithyTypes(runtimeConfig).resolve("endpoint")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext
import software.amazon.smithy.rust.codegen.client.smithy.ClientRustModule
import software.amazon.smithy.rust.codegen.client.smithy.endpoint.Context
import software.amazon.smithy.rust.codegen.client.smithy.endpoint.EndpointTypesGenerator
import software.amazon.smithy.rust.codegen.client.smithy.endpoint.EndpointsLib
import software.amazon.smithy.rust.codegen.client.smithy.endpoint.Types
import software.amazon.smithy.rust.codegen.client.smithy.endpoint.endpointsLib
import software.amazon.smithy.rust.codegen.client.smithy.endpoint.memberName
import software.amazon.smithy.rust.codegen.client.smithy.endpoint.rulesgen.ExpressionGenerator
import software.amazon.smithy.rust.codegen.client.smithy.endpoint.rulesgen.Ownership
Expand All @@ -34,7 +34,6 @@ import software.amazon.smithy.rust.codegen.core.rustlang.escape
import software.amazon.smithy.rust.codegen.core.rustlang.join
import software.amazon.smithy.rust.codegen.core.rustlang.rust
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.toType
import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType.Companion.preludeScope
Expand Down Expand Up @@ -138,7 +137,7 @@ internal class EndpointResolverGenerator(
"ResolveEndpointError" to types.resolveEndpointError,
"EndpointError" to types.resolveEndpointError,
"ServiceSpecificEndpointResolver" to codegenContext.serviceSpecificEndpointResolver(),
"DiagnosticCollector" to endpointsLib("diagnostic").toType().resolve("DiagnosticCollector"),
"DiagnosticCollector" to EndpointsLib.DiagnosticCollector,
)

private val allowLintsForResolver = listOf(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@
package software.amazon.smithy.rust.codegen.client.smithy.endpoint.rulesgen

import software.amazon.smithy.model.node.Node
import software.amazon.smithy.rust.codegen.client.smithy.endpoint.endpointsLib
import software.amazon.smithy.rust.codegen.client.smithy.endpoint.EndpointsLib
import software.amazon.smithy.rust.codegen.client.smithy.endpoint.generators.CustomRuntimeFunction
import software.amazon.smithy.rust.codegen.client.smithy.endpoint.generators.EndpointStdLib
import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency
import software.amazon.smithy.rust.codegen.core.rustlang.Writable
import software.amazon.smithy.rust.codegen.core.rustlang.rust
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.toType
import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeConfig
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
Expand All @@ -22,16 +22,10 @@ import software.amazon.smithy.rust.codegen.core.util.dq
* Standard library functions available to all generated crates (e.g. not `aws.` specific / prefixed)
*/
internal val SmithyEndpointsStdLib: List<CustomRuntimeFunction> = listOf(
SimpleRuntimeFunction("substring", endpointsLib("substring").toType().resolve("substring")),
SimpleRuntimeFunction("isValidHostLabel", endpointsLib("host").toType().resolve("is_valid_host_label")),
SimpleRuntimeFunction(
"parseURL",
endpointsLib("parse_url", CargoDependency.Http, CargoDependency.Url).toType().resolve("parse_url"),
),
SimpleRuntimeFunction(
"uriEncode",
endpointsLib("uri_encode", CargoDependency.PercentEncoding).toType().resolve("uri_encode"),
),
SimpleRuntimeFunction("substring", EndpointsLib.substring),
SimpleRuntimeFunction("isValidHostLabel", EndpointsLib.isValidHostLabel),
SimpleRuntimeFunction("parseURL", EndpointsLib.parseUrl),
SimpleRuntimeFunction("uriEncode", EndpointsLib.uriEncode),
)

/**
Expand All @@ -40,20 +34,9 @@ internal val SmithyEndpointsStdLib: List<CustomRuntimeFunction> = listOf(
* This is defined in client-codegen to support running tests—it is not used when generating smithy-native services.
*/
fun awsStandardLib(runtimeConfig: RuntimeConfig, partitionsDotJson: Node) = listOf(
SimpleRuntimeFunction("aws.parseArn", endpointsLib("arn").toType().resolve("parse_arn")),
SimpleRuntimeFunction(
"aws.isVirtualHostableS3Bucket",
endpointsLib(
"s3",
endpointsLib("host"),
CargoDependency.OnceCell,
CargoDependency.Regex,
).toType().resolve("is_virtual_hostable_s3_bucket"),
),
AwsPartitionResolver(
runtimeConfig,
partitionsDotJson,
),
SimpleRuntimeFunction("aws.parseArn", EndpointsLib.awsParseArn),
SimpleRuntimeFunction("aws.isVirtualHostableS3Bucket", EndpointsLib.awsIsVirtualHostableS3Bucket),
AwsPartitionResolver(runtimeConfig, partitionsDotJson),
)

/**
Expand All @@ -65,19 +48,26 @@ class AwsPartitionResolver(runtimeConfig: RuntimeConfig, private val partitionsD
CustomRuntimeFunction() {
override val id: String = "aws.partition"
private val codegenScope = arrayOf(
"PartitionResolver" to endpointsLib(
"partition",
CargoDependency.smithyJson(runtimeConfig),
CargoDependency.Regex,
).toType()
.resolve("PartitionResolver"),
"PartitionResolver" to EndpointsLib.PartitionResolver(runtimeConfig),
"Lazy" to CargoDependency.OnceCell.toType().resolve("sync::Lazy"),
)

override fun structFieldInit() = writable {
val json = Node.printJson(partitionsDotJson).dq()
rustTemplate(
"""partition_resolver: #{PartitionResolver}::new_from_json(b$json).expect("valid JSON")""",
"""partition_resolver: #{DEFAULT_PARTITION_RESOLVER}.clone()""",
*codegenScope,
"DEFAULT_PARTITION_RESOLVER" to RuntimeType.forInlineFun("DEFAULT_PARTITION_RESOLVER", EndpointStdLib) {
rustTemplate(
"""
// Loading the partition JSON is expensive since it involves many regex compilations,
// so cache the result so that it only need to be paid for the first constructed client.
pub(crate) static DEFAULT_PARTITION_RESOLVER: #{Lazy}<#{PartitionResolver}> =
#{Lazy}::new(|| #{PartitionResolver}::new_from_json(b$json).expect("valid JSON"));
""",
*codegenScope,
)
},
)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,7 @@ data class CargoDependency(
val Md5: CargoDependency = CargoDependency("md-5", CratesIo("0.10.0"), rustName = "md5")
val PercentEncoding: CargoDependency = CargoDependency("percent-encoding", CratesIo("2.0.0"))
val Regex: CargoDependency = CargoDependency("regex", CratesIo("1.5.5"))
val RegexLite: CargoDependency = CargoDependency("regex-lite", CratesIo("0.1.5"))
val Ring: CargoDependency = CargoDependency("ring", CratesIo("0.17.5"))
val TokioStream: CargoDependency = CargoDependency("tokio-stream", CratesIo("0.1.7"))
val Tower: CargoDependency = CargoDependency("tower", CratesIo("0.4"))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,7 @@ data class RuntimeType(val path: String, val dependency: RustDependency? = null)
val PercentEncoding = CargoDependency.PercentEncoding.toType()
val PrettyAssertions = CargoDependency.PrettyAssertions.toType()
val Regex = CargoDependency.Regex.toType()
val RegexLite = CargoDependency.RegexLite.toType()
val Tokio = CargoDependency.Tokio.toType()
val TokioStream = CargoDependency.TokioStream.toType()
val Tower = CargoDependency.Tower.toType()
Expand Down
2 changes: 1 addition & 1 deletion rust-runtime/aws-smithy-protocol-test/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ repository = "https://github.com/smithy-lang/smithy-rs"
assert-json-diff = "1.1"
http = "0.2.1"
pretty_assertions = "1.3"
regex = "1.5"
regex-lite = "0.1.5"
roxmltree = "0.14.1"
serde_json = "1"
thiserror = "1.0.40"
Expand Down
Loading

0 comments on commit 5b93fd2

Please sign in to comment.