diff --git a/rust-runtime/aws-smithy-mocks-experimental/src/lib.rs b/rust-runtime/aws-smithy-mocks-experimental/src/lib.rs index 0a70a52786..cf69c201c8 100644 --- a/rust-runtime/aws-smithy-mocks-experimental/src/lib.rs +++ b/rust-runtime/aws-smithy-mocks-experimental/src/lib.rs @@ -108,6 +108,16 @@ where MockOutput::ModeledResponse(Arc::new(move || Ok(Output::erase(output())))), ) } + + /// If a rule matches, then return a specific error + pub fn then_error(self, output: impl Fn() -> E + Send + Sync + 'static) -> Rule { + Rule::new( + self.input_filter, + MockOutput::ModeledResponse(Arc::new(move || { + Err(OrchestratorError::operation(Error::erase(output()))) + })), + ) + } } #[derive(Clone)] diff --git a/rust-runtime/aws-smithy-mocks-experimental/tests/get-object-mocks.rs b/rust-runtime/aws-smithy-mocks-experimental/tests/get-object-mocks.rs index c8c234dcc3..0de8ad0076 100644 --- a/rust-runtime/aws-smithy-mocks-experimental/tests/get-object-mocks.rs +++ b/rust-runtime/aws-smithy-mocks-experimental/tests/get-object-mocks.rs @@ -5,9 +5,12 @@ use aws_sdk_s3::config::Region; use aws_sdk_s3::operation::get_object::{GetObjectError, GetObjectOutput}; +use aws_sdk_s3::operation::list_buckets::ListBucketsError; use aws_sdk_s3::{Client, Config}; use aws_smithy_types::body::SdkBody; use aws_smithy_types::byte_stream::ByteStream; +use aws_smithy_types::error::metadata::ProvideErrorMetadata; +use aws_smithy_types::error::ErrorMetadata; use aws_smithy_mocks_experimental::{mock, MockResponseInterceptor}; @@ -44,9 +47,14 @@ async fn create_mock_s3_get_object() { .build() }); + let modeled_error = mock!(Client::list_buckets).then_error(|| { + ListBucketsError::generic(ErrorMetadata::builder().code("InvalidAccessKey").build()) + }); + let get_object_mocks = MockResponseInterceptor::new() .with_rule(&s3_404) - .with_rule(&s3_real_object); + .with_rule(&s3_real_object) + .with_rule(&modeled_error); let s3 = aws_sdk_s3::Client::from_conf( Config::builder() @@ -83,4 +91,7 @@ async fn create_mock_s3_get_object() { .to_vec(); assert_eq!(data, b"test-test-test"); assert_eq!(s3_real_object.num_calls(), 1); + + let err = s3.list_buckets().send().await.expect_err("bad access key"); + assert_eq!(err.code(), Some("InvalidAccessKey")); }