From 10d17a148635a4eb98951a3c6ed86804e754325e Mon Sep 17 00:00:00 2001 From: Eric Meisel Date: Tue, 6 Jul 2021 16:48:50 -0400 Subject: [PATCH] Use Modify instead of Get / Update for Ref Updates (#134) * Use Modify instead of Get / Update for Ref Updates * Parallel PutRecords --- .../scala/kinesis/mock/PutRecordTests.scala | 6 +-- .../scala/kinesis/mock/PutRecordsTests.scala | 49 ++++++++++--------- .../mock/api/AddTagsToStreamRequest.scala | 10 ++-- .../mock/api/CreateStreamRequest.scala | 21 ++++---- ...DecreaseStreamRetentionPeriodRequest.scala | 13 +++-- .../mock/api/DeleteStreamRequest.scala | 21 ++++---- .../api/DeregisterStreamConsumerRequest.scala | 38 +++++++------- .../DisableEnhancedMonitoringRequest.scala | 30 ++++++------ .../api/EnableEnhancedMonitoringRequest.scala | 31 ++++++------ ...IncreaseStreamRetentionPeriodRequest.scala | 11 +++-- .../kinesis/mock/api/MergeShardsRequest.scala | 17 ++++--- .../kinesis/mock/api/PutRecordRequest.scala | 43 ++++++++-------- .../kinesis/mock/api/PutRecordsRequest.scala | 35 +++++++------ .../api/RegisterStreamConsumerRequest.scala | 27 +++++----- .../api/RemoveTagsFromStreamRequest.scala | 11 +++-- .../kinesis/mock/api/SplitShardRequest.scala | 17 ++++--- .../api/StartStreamEncryptionRequest.scala | 13 +++-- .../api/StopStreamEncryptionRequest.scala | 13 +++-- .../mock/api/UpdateShardCountRequest.scala | 17 ++++--- src/main/scala/kinesis/mock/cache/Cache.scala | 20 ++------ .../scala/kinesis/mock/syntax/either.scala | 19 +++++++ 21 files changed, 244 insertions(+), 218 deletions(-) create mode 100644 src/main/scala/kinesis/mock/syntax/either.scala diff --git a/src/fun/scala/kinesis/mock/PutRecordTests.scala b/src/fun/scala/kinesis/mock/PutRecordTests.scala index eab13b75..284e4612 100644 --- a/src/fun/scala/kinesis/mock/PutRecordTests.scala +++ b/src/fun/scala/kinesis/mock/PutRecordTests.scala @@ -18,7 +18,7 @@ class PutRecordTests extends munit.CatsEffectSuite with AwsFunctionalTests { for { recordRequests <- IO( putRecordRequestArb.arbitrary - .take(5) + .take(20) .toVector .map(_.copy(streamName = resources.streamName)) .map(x => @@ -34,7 +34,7 @@ class PutRecordTests extends munit.CatsEffectSuite with AwsFunctionalTests { .build() ) ) - _ <- recordRequests.traverse(x => + _ <- recordRequests.parTraverse(x => resources.kinesisClient.putRecord(x).toIO ) shards <- resources.kinesisClient @@ -68,7 +68,7 @@ class PutRecordTests extends munit.CatsEffectSuite with AwsFunctionalTests { ) res = gets.flatMap(_.records().asScala.toVector) } yield assert( - res.length == 5 && res.forall(rec => + res.length == 20 && res.forall(rec => recordRequests.exists(req => req.data.asByteArray.sameElements(rec.data.asByteArray) && req.partitionKey == rec.partitionKey diff --git a/src/fun/scala/kinesis/mock/PutRecordsTests.scala b/src/fun/scala/kinesis/mock/PutRecordsTests.scala index 225d96d5..cbf488ec 100644 --- a/src/fun/scala/kinesis/mock/PutRecordsTests.scala +++ b/src/fun/scala/kinesis/mock/PutRecordsTests.scala @@ -16,27 +16,29 @@ class PutRecordsTests extends munit.CatsEffectSuite with AwsFunctionalTests { fixture.test("It should put records") { resources => for { - req <- IO( - PutRecordsRequest - .builder() - .records( - putRecordsRequestEntryArb.arbitrary - .take(5) - .toVector - .map(x => - PutRecordsRequestEntry - .builder() - .data(SdkBytes.fromByteArray(x.data)) - .partitionKey(x.partitionKey) - .maybeTransform(x.explicitHashKey)(_.explicitHashKey(_)) - .build() - ) - .asJava - ) - .streamName(resources.streamName.streamName) - .build() + reqs <- IO( + List.fill(10)( + PutRecordsRequest + .builder() + .records( + putRecordsRequestEntryArb.arbitrary + .take(5) + .toVector + .map(x => + PutRecordsRequestEntry + .builder() + .data(SdkBytes.fromByteArray(x.data)) + .partitionKey(x.partitionKey) + .maybeTransform(x.explicitHashKey)(_.explicitHashKey(_)) + .build() + ) + .asJava + ) + .streamName(resources.streamName.streamName) + .build() + ) ) - _ <- resources.kinesisClient.putRecords(req).toIO + _ <- reqs.parTraverse(req => resources.kinesisClient.putRecords(req).toIO) shards <- resources.kinesisClient .listShards( ListShardsRequest @@ -67,14 +69,15 @@ class PutRecordsTests extends munit.CatsEffectSuite with AwsFunctionalTests { .toIO ) res = gets.flatMap(_.records().asScala.toVector) + records = reqs.flatMap(_.records.asScala.toVector) } yield assert( - res.length == 5 && res.forall(rec => - req.records.asScala.toVector.exists(req => + res.length == 50 && res.forall(rec => + records.exists(req => req.data.asByteArray.sameElements(rec.data.asByteArray) && req.partitionKey == rec.partitionKey ) ), - s"$res\n${req.records()}" + s"$res\n$records" ) } } diff --git a/src/main/scala/kinesis/mock/api/AddTagsToStreamRequest.scala b/src/main/scala/kinesis/mock/api/AddTagsToStreamRequest.scala index 474ae734..5992ad19 100644 --- a/src/main/scala/kinesis/mock/api/AddTagsToStreamRequest.scala +++ b/src/main/scala/kinesis/mock/api/AddTagsToStreamRequest.scala @@ -8,6 +8,7 @@ import cats.syntax.all._ import io.circe import kinesis.mock.models._ +import kinesis.mock.syntax.either._ import kinesis.mock.validations.CommonValidations // https://docs.aws.amazon.com/kinesis/latest/APIReference/API_AddTagsToStream.html @@ -19,7 +20,7 @@ final case class AddTagsToStreamRequest( ) { def addTagsToStream( streamsRef: Ref[IO, Streams] - ): IO[Response[Unit]] = streamsRef.get.flatMap { streams => + ): IO[Response[Unit]] = streamsRef.modify { streams => CommonValidations .validateStreamName(streamName) .flatMap(_ => @@ -62,11 +63,10 @@ final case class AddTagsToStreamRequest( ).mapN((_, _, _, _, _) => stream) ) ) - .traverse(stream => - streamsRef.update(x => - x.updateStream(stream.copy(tags = stream.tags |+| tags)) - ) + .map(stream => + (streams.updateStream(stream.copy(tags = stream.tags |+| tags)), ()) ) + .sequenceWithDefault(streams) } } diff --git a/src/main/scala/kinesis/mock/api/CreateStreamRequest.scala b/src/main/scala/kinesis/mock/api/CreateStreamRequest.scala index 848a8931..d2eff21d 100644 --- a/src/main/scala/kinesis/mock/api/CreateStreamRequest.scala +++ b/src/main/scala/kinesis/mock/api/CreateStreamRequest.scala @@ -2,12 +2,13 @@ package kinesis.mock package api import cats.Eq +import cats.effect.IO import cats.effect.concurrent.Ref -import cats.effect.{Concurrent, IO} import cats.syntax.all._ import io.circe import kinesis.mock.models._ +import kinesis.mock.syntax.either._ import kinesis.mock.validations.CommonValidations final case class CreateStreamRequest(shardCount: Int, streamName: StreamName) { @@ -16,8 +17,8 @@ final case class CreateStreamRequest(shardCount: Int, streamName: StreamName) { shardLimit: Int, awsRegion: AwsRegion, awsAccountId: AwsAccountId - )(implicit C: Concurrent[IO]): IO[Response[Unit]] = - streamsRef.get.flatMap { streams => + ): IO[Response[Unit]] = + streamsRef.modify { streams => ( CommonValidations.validateStreamName(streamName), if (streams.streams.contains(streamName)) @@ -36,16 +37,14 @@ final case class CreateStreamRequest(shardCount: Int, streamName: StreamName) { ).asLeft else Right(()), CommonValidations.validateShardLimit(shardCount, streams, shardLimit) - ).traverseN { (_, _, _, _, _) => + ).mapN { (_, _, _, _, _) => val newStream = StreamData.create(shardCount, streamName, awsRegion, awsAccountId) - for { - res <- streamsRef - .update(x => - x.copy(streams = x.streams + (streamName -> newStream)) - ) - } yield res - } + ( + streams.copy(streams = streams.streams + (streamName -> newStream)), + () + ) + }.sequenceWithDefault(streams) } } diff --git a/src/main/scala/kinesis/mock/api/DecreaseStreamRetentionPeriodRequest.scala b/src/main/scala/kinesis/mock/api/DecreaseStreamRetentionPeriodRequest.scala index f6437cad..d309b93f 100644 --- a/src/main/scala/kinesis/mock/api/DecreaseStreamRetentionPeriodRequest.scala +++ b/src/main/scala/kinesis/mock/api/DecreaseStreamRetentionPeriodRequest.scala @@ -10,6 +10,7 @@ import cats.syntax.all._ import io.circe import kinesis.mock.models._ +import kinesis.mock.syntax.either._ import kinesis.mock.validations.CommonValidations // https://docs.aws.amazon.com/kinesis/latest/APIReference/API_DecreaseStreamRetention.html @@ -19,7 +20,7 @@ final case class DecreaseStreamRetentionPeriodRequest( ) { def decreaseStreamRetention( streamsRef: Ref[IO, Streams] - ): IO[Response[Unit]] = streamsRef.get.flatMap { streams => + ): IO[Response[Unit]] = streamsRef.modify { streams => CommonValidations .validateStreamName(streamName) .flatMap(_ => @@ -38,13 +39,15 @@ final case class DecreaseStreamRetentionPeriodRequest( ).mapN((_, _, _) => stream) ) ) - .traverse(stream => - streamsRef.update(x => - x.updateStream( + .map(stream => + ( + streams.updateStream( stream.copy(retentionPeriod = retentionPeriodHours.hours) - ) + ), + () ) ) + .sequenceWithDefault(streams) } } diff --git a/src/main/scala/kinesis/mock/api/DeleteStreamRequest.scala b/src/main/scala/kinesis/mock/api/DeleteStreamRequest.scala index 0ce5e818..87ff17c0 100644 --- a/src/main/scala/kinesis/mock/api/DeleteStreamRequest.scala +++ b/src/main/scala/kinesis/mock/api/DeleteStreamRequest.scala @@ -8,6 +8,7 @@ import cats.syntax.all._ import io.circe import kinesis.mock.models._ +import kinesis.mock.syntax.either._ import kinesis.mock.validations.CommonValidations final case class DeleteStreamRequest( @@ -17,7 +18,7 @@ final case class DeleteStreamRequest( def deleteStream( streamsRef: Ref[IO, Streams] ): IO[Response[Unit]] = - streamsRef.get.flatMap { streams => + streamsRef.modify { streams => CommonValidations .validateStreamName(streamName) .flatMap(_ => @@ -37,7 +38,7 @@ final case class DeleteStreamRequest( ).mapN((_, _) => stream) ) ) - .traverse { stream => + .map { stream => val deletingStream = Map( streamName -> stream.copy( shards = Map.empty, @@ -47,16 +48,14 @@ final case class DeleteStreamRequest( consumers = Map.empty ) ) - - for { - res <- streamsRef - .update(streams => - streams.copy( - streams = streams.streams ++ deletingStream - ) - ) - } yield res + ( + streams.copy( + streams = streams.streams ++ deletingStream + ), + () + ) } + .sequenceWithDefault(streams) } } diff --git a/src/main/scala/kinesis/mock/api/DeregisterStreamConsumerRequest.scala b/src/main/scala/kinesis/mock/api/DeregisterStreamConsumerRequest.scala index 83773462..da78c111 100644 --- a/src/main/scala/kinesis/mock/api/DeregisterStreamConsumerRequest.scala +++ b/src/main/scala/kinesis/mock/api/DeregisterStreamConsumerRequest.scala @@ -8,6 +8,7 @@ import cats.syntax.all._ import io.circe import kinesis.mock.models._ +import kinesis.mock.syntax.either._ import kinesis.mock.validations.CommonValidations // https://docs.aws.amazon.com/kinesis/latest/APIReference/API_DeregisterStreamConsumer.html @@ -17,26 +18,26 @@ final case class DeregisterStreamConsumerRequest( streamArn: Option[String] ) { private def deregister( - streamsRef: Ref[IO, Streams], + streams: Streams, consumer: Consumer, stream: StreamData - ): IO[Consumer] = { - val newConsumer = consumer.copy(consumerStatus = ConsumerStatus.DELETING) + ): (Streams, Consumer) = { + val newConsumer = + consumer.copy(consumerStatus = ConsumerStatus.DELETING) - streamsRef - .update(x => - x.updateStream( - stream.copy(consumers = - stream.consumers + (consumer.consumerName -> newConsumer) - ) + ( + streams.updateStream( + stream.copy(consumers = + stream.consumers + (consumer.consumerName -> newConsumer) ) - ) - .as(newConsumer) + ), + newConsumer + ) } def deregisterStreamConsumer( streamsRef: Ref[IO, Streams] - ): IO[Response[Consumer]] = streamsRef.get.flatMap { streams => + ): IO[Response[Consumer]] = streamsRef.modify { streams => (consumerArn, consumerName, streamArn) match { case (Some(cArn), _, _) => CommonValidations @@ -50,9 +51,10 @@ final case class DeregisterStreamConsumerRequest( s"Consumer $consumerName is not in an ACTIVE state" ).asLeft } - .traverse { case (consumer, stream) => - deregister(streamsRef, consumer, stream) + .map { case (consumer, stream) => + deregister(streams, consumer, stream) } + .sequenceWithDefault(streams) case (None, Some(cName), Some(sArn)) => CommonValidations .findStreamByArn(sArn, streams) @@ -68,11 +70,13 @@ final case class DeregisterStreamConsumerRequest( } } - .traverse { case (consumer, stream) => - deregister(streamsRef, consumer, stream) + .map { case (consumer, stream) => + deregister(streams, consumer, stream) } + .sequenceWithDefault(streams) case _ => - IO( + ( + streams, InvalidArgumentException( "ConsumerArn or both ConsumerName and StreamARN are required for this request." ).asLeft diff --git a/src/main/scala/kinesis/mock/api/DisableEnhancedMonitoringRequest.scala b/src/main/scala/kinesis/mock/api/DisableEnhancedMonitoringRequest.scala index 055d825d..51e5b829 100644 --- a/src/main/scala/kinesis/mock/api/DisableEnhancedMonitoringRequest.scala +++ b/src/main/scala/kinesis/mock/api/DisableEnhancedMonitoringRequest.scala @@ -4,10 +4,10 @@ package api import cats.Eq import cats.effect.IO import cats.effect.concurrent.Ref -import cats.syntax.all._ import io.circe import kinesis.mock.models._ +import kinesis.mock.syntax.either._ import kinesis.mock.validations.CommonValidations // https://docs.aws.amazon.com/kinesis/latest/APIReference/API_DisableEnhancedMonitoring.html @@ -18,11 +18,11 @@ final case class DisableEnhancedMonitoringRequest( def disableEnhancedMonitoring( streamsRef: Ref[IO, Streams] ): IO[Response[DisableEnhancedMonitoringResponse]] = - streamsRef.get.flatMap { streams => + streamsRef.modify { streams => CommonValidations .validateStreamName(streamName) .flatMap(_ => CommonValidations.findStream(streamName, streams)) - .traverse { stream => + .map { stream => val current = stream.enhancedMonitoring.flatMap(_.shardLevelMetrics) val desired = @@ -30,21 +30,19 @@ final case class DisableEnhancedMonitoringRequest( Vector.empty else current.diff(shardLevelMetrics) - streamsRef - .update(x => - x.updateStream( - stream - .copy(enhancedMonitoring = Vector(ShardLevelMetrics(desired))) - ) - ) - .as( - DisableEnhancedMonitoringResponse( - current, - desired, - streamName - ) + ( + streams.updateStream( + stream + .copy(enhancedMonitoring = Vector(ShardLevelMetrics(desired))) + ), + DisableEnhancedMonitoringResponse( + current, + desired, + streamName ) + ) } + .sequenceWithDefault(streams) } } diff --git a/src/main/scala/kinesis/mock/api/EnableEnhancedMonitoringRequest.scala b/src/main/scala/kinesis/mock/api/EnableEnhancedMonitoringRequest.scala index ed6389f4..5ce1d2c1 100644 --- a/src/main/scala/kinesis/mock/api/EnableEnhancedMonitoringRequest.scala +++ b/src/main/scala/kinesis/mock/api/EnableEnhancedMonitoringRequest.scala @@ -4,10 +4,10 @@ package api import cats.Eq import cats.effect.IO import cats.effect.concurrent.Ref -import cats.syntax.all._ import io.circe import kinesis.mock.models._ +import kinesis.mock.syntax.either._ import kinesis.mock.validations.CommonValidations // https://docs.aws.amazon.com/kinesis/latest/APIReference/API_EnableEnhancedMonitoring.html @@ -18,11 +18,11 @@ final case class EnableEnhancedMonitoringRequest( def enableEnhancedMonitoring( streamsRef: Ref[IO, Streams] ): IO[Response[EnableEnhancedMonitoringResponse]] = - streamsRef.get.flatMap { streams => + streamsRef.modify { streams => CommonValidations .validateStreamName(streamName) .flatMap(_ => CommonValidations.findStream(streamName, streams)) - .traverse { stream => + .map { stream => val current = stream.enhancedMonitoring.flatMap(_.shardLevelMetrics) val desired = if (shardLevelMetrics.contains(ShardLevelMetric.ALL)) @@ -30,22 +30,19 @@ final case class EnableEnhancedMonitoringRequest( .filterNot(_ == ShardLevelMetric.ALL) .toVector else (current ++ shardLevelMetrics).distinct - - streamsRef - .update(x => - x.updateStream( - stream - .copy(enhancedMonitoring = Vector(ShardLevelMetrics(desired))) - ) - ) - .as( - EnableEnhancedMonitoringResponse( - current, - desired, - streamName - ) + ( + streams.updateStream( + stream + .copy(enhancedMonitoring = Vector(ShardLevelMetrics(desired))) + ), + EnableEnhancedMonitoringResponse( + current, + desired, + streamName ) + ) } + .sequenceWithDefault(streams) } } diff --git a/src/main/scala/kinesis/mock/api/IncreaseStreamRetentionPeriodRequest.scala b/src/main/scala/kinesis/mock/api/IncreaseStreamRetentionPeriodRequest.scala index b13dbd17..801e3bae 100644 --- a/src/main/scala/kinesis/mock/api/IncreaseStreamRetentionPeriodRequest.scala +++ b/src/main/scala/kinesis/mock/api/IncreaseStreamRetentionPeriodRequest.scala @@ -10,6 +10,7 @@ import cats.syntax.all._ import io.circe import kinesis.mock.models._ +import kinesis.mock.syntax.either._ import kinesis.mock.validations.CommonValidations // https://docs.aws.amazon.com/kinesis/latest/APIReference/API_IncreaseStreamRetention.html @@ -19,7 +20,7 @@ final case class IncreaseStreamRetentionPeriodRequest( ) { def increaseStreamRetention( streamsRef: Ref[IO, Streams] - ): IO[Response[Unit]] = streamsRef.get.flatMap { streams => + ): IO[Response[Unit]] = streamsRef.modify { streams => CommonValidations .validateStreamName(streamName) .flatMap(_ => @@ -38,13 +39,15 @@ final case class IncreaseStreamRetentionPeriodRequest( ).mapN((_, _, _) => stream) ) ) - .traverse(stream => - streamsRef.update(streams => + .map(stream => + ( streams.updateStream( stream.copy(retentionPeriod = retentionPeriodHours.hours) - ) + ), + () ) ) + .sequenceWithDefault(streams) } } diff --git a/src/main/scala/kinesis/mock/api/MergeShardsRequest.scala b/src/main/scala/kinesis/mock/api/MergeShardsRequest.scala index 84ef1240..1901009e 100644 --- a/src/main/scala/kinesis/mock/api/MergeShardsRequest.scala +++ b/src/main/scala/kinesis/mock/api/MergeShardsRequest.scala @@ -4,12 +4,13 @@ package api import java.time.Instant import cats.Eq +import cats.effect.IO import cats.effect.concurrent.Ref -import cats.effect.{Concurrent, IO} import cats.syntax.all._ import io.circe import kinesis.mock.models._ +import kinesis.mock.syntax.either._ import kinesis.mock.validations.CommonValidations // https://docs.aws.amazon.com/kinesis/latest/APIReference/API_MergeShards.html @@ -20,8 +21,8 @@ final case class MergeShardsRequest( ) { def mergeShards( streamsRef: Ref[IO, Streams] - )(implicit C: Concurrent[IO]): IO[Response[Unit]] = - streamsRef.get.flatMap(streams => + ): IO[Response[Unit]] = + streamsRef.modify(streams => CommonValidations .validateStreamName(streamName) .flatMap(_ => @@ -69,7 +70,7 @@ final case class MergeShardsRequest( } } ) - .traverse { + .map { case ( stream, (adjacentShard, adjacentData), @@ -116,8 +117,8 @@ final case class MergeShardsRequest( .copy(endingSequenceNumber = Some(SequenceNumber.shardEnd)) ) -> shardData ) - streamsRef.update(x => - x.updateStream( + ( + streams.updateStream( stream.copy( shards = stream.shards.filterNot { case (s, _) => s.shardId == adjacentShard.shardId || s.shardId == shard.shardId @@ -125,9 +126,11 @@ final case class MergeShardsRequest( ++ (oldShards :+ newShard), streamStatus = StreamStatus.UPDATING ) - ) + ), + () ) } + .sequenceWithDefault(streams) ) } diff --git a/src/main/scala/kinesis/mock/api/PutRecordRequest.scala b/src/main/scala/kinesis/mock/api/PutRecordRequest.scala index a0eb4e8b..33be6117 100644 --- a/src/main/scala/kinesis/mock/api/PutRecordRequest.scala +++ b/src/main/scala/kinesis/mock/api/PutRecordRequest.scala @@ -11,6 +11,7 @@ import io.circe import kinesis.mock.instances.circe._ import kinesis.mock.models._ +import kinesis.mock.syntax.either._ import kinesis.mock.validations.CommonValidations final case class PutRecordRequest( @@ -22,7 +23,7 @@ final case class PutRecordRequest( ) { def putRecord( streamsRef: Ref[IO, Streams] - ): IO[Response[PutRecordResponse]] = streamsRef.get.flatMap { streams => + ): IO[Response[PutRecordResponse]] = streamsRef.modify { streams => val now = Instant.now() CommonValidations .validateStreamName(streamName) @@ -64,7 +65,7 @@ final case class PutRecordRequest( } } ) - .traverse { case (stream, shard, records) => + .map { case (stream, shard, records) => val seqNo = SequenceNumber.create( shard.createdAtTimestamp, shard.shardId.index, @@ -72,30 +73,28 @@ final case class PutRecordRequest( Some(records.length), Some(now) ) - streamsRef - .update(x => - x.updateStream { - stream.copy( - shards = stream.shards ++ Map( - shard -> (records :+ KinesisRecord( - now, - data, - stream.encryptionType, - partitionKey, - seqNo - )) - ) + ( + streams.updateStream { + stream.copy( + shards = stream.shards ++ Map( + shard -> (records :+ KinesisRecord( + now, + data, + stream.encryptionType, + partitionKey, + seqNo + )) ) - } - ) - .as( - PutRecordResponse( - stream.encryptionType, - seqNo, - shard.shardId.shardId ) + }, + PutRecordResponse( + stream.encryptionType, + seqNo, + shard.shardId.shardId ) + ) } + .sequenceWithDefault(streams) } } diff --git a/src/main/scala/kinesis/mock/api/PutRecordsRequest.scala b/src/main/scala/kinesis/mock/api/PutRecordsRequest.scala index 297a723a..b19ca5e6 100644 --- a/src/main/scala/kinesis/mock/api/PutRecordsRequest.scala +++ b/src/main/scala/kinesis/mock/api/PutRecordsRequest.scala @@ -10,6 +10,7 @@ import cats.syntax.all._ import io.circe import kinesis.mock.models._ +import kinesis.mock.syntax.either._ import kinesis.mock.validations.CommonValidations final case class PutRecordsRequest( @@ -19,7 +20,7 @@ final case class PutRecordsRequest( def putRecords( streamsRef: Ref[IO, Streams] ): IO[Response[PutRecordsResponse]] = - streamsRef.get.flatMap { streams => + streamsRef.modify[Response[PutRecordsResponse]] { streams => val now = Instant.now() CommonValidations .validateStreamName(streamName) @@ -52,7 +53,7 @@ final case class PutRecordsRequest( ).mapN((_, recs) => (stream, recs)) } ) - .traverse { case (stream, recs) => + .map { case (stream, recs) => val grouped = recs .groupBy { case (shard, records, _) => (shard, records) @@ -97,26 +98,24 @@ final case class PutRecordsRequest( ) } - streamsRef - .update(x => - x.updateStream( - stream.copy( - shards = stream.shards ++ newShards.map { - case (shard, (records, _)) => shard -> records - } - ) - ) - ) - .as( - PutRecordsResponse( - stream.encryptionType, - 0, - newShards.flatMap { case (_, (_, resultEntries)) => - resultEntries + ( + streams.updateStream( + stream.copy( + shards = stream.shards ++ newShards.map { + case (shard, (records, _)) => shard -> records } ) + ), + PutRecordsResponse( + stream.encryptionType, + 0, + newShards.flatMap { case (_, (_, resultEntries)) => + resultEntries + } ) + ) } + .sequenceWithDefault(streams) } } diff --git a/src/main/scala/kinesis/mock/api/RegisterStreamConsumerRequest.scala b/src/main/scala/kinesis/mock/api/RegisterStreamConsumerRequest.scala index a1113a58..068db73f 100644 --- a/src/main/scala/kinesis/mock/api/RegisterStreamConsumerRequest.scala +++ b/src/main/scala/kinesis/mock/api/RegisterStreamConsumerRequest.scala @@ -8,6 +8,7 @@ import cats.syntax.all._ import io.circe import kinesis.mock.models._ +import kinesis.mock.syntax.either._ import kinesis.mock.validations.CommonValidations // https://docs.aws.amazon.com/kinesis/latest/APIReference/API_RegisterStreamConsumer.html @@ -18,7 +19,7 @@ final case class RegisterStreamConsumerRequest( def registerStreamConsumer( streamsRef: Ref[IO, Streams] ): IO[Response[RegisterStreamConsumerResponse]] = - streamsRef.get.flatMap { streams => + streamsRef.modify { streams => CommonValidations .validateStreamArn(streamArn) .flatMap(_ => @@ -48,24 +49,20 @@ final case class RegisterStreamConsumerRequest( ).mapN { (_, _, _, _) => (stream, streamArn, consumerName) } ) ) - .traverse { case (stream, streamArn, consumerName) => + .map { case (stream, streamArn, consumerName) => val consumer = Consumer.create(streamArn, consumerName) - streamsRef - .update(x => - x.updateStream( - stream - .copy(consumers = - stream.consumers + (consumerName -> consumer) - ) - ) - ) - .as( - RegisterStreamConsumerResponse( - ConsumerSummary.fromConsumer(consumer) - ) + ( + streams.updateStream( + stream + .copy(consumers = stream.consumers + (consumerName -> consumer)) + ), + RegisterStreamConsumerResponse( + ConsumerSummary.fromConsumer(consumer) ) + ) } + .sequenceWithDefault(streams) } } diff --git a/src/main/scala/kinesis/mock/api/RemoveTagsFromStreamRequest.scala b/src/main/scala/kinesis/mock/api/RemoveTagsFromStreamRequest.scala index 73232f68..1da1c35d 100644 --- a/src/main/scala/kinesis/mock/api/RemoveTagsFromStreamRequest.scala +++ b/src/main/scala/kinesis/mock/api/RemoveTagsFromStreamRequest.scala @@ -8,6 +8,7 @@ import cats.syntax.all._ import io.circe import kinesis.mock.models._ +import kinesis.mock.syntax.either._ import kinesis.mock.validations.CommonValidations final case class RemoveTagsFromStreamRequest( @@ -19,7 +20,7 @@ final case class RemoveTagsFromStreamRequest( // https://docs.aws.amazon.com/directoryservice/latest/devguide/API_Tag.html def removeTagsFromStream( streamsRef: Ref[IO, Streams] - ): IO[Response[Unit]] = streamsRef.get.flatMap(streams => + ): IO[Response[Unit]] = streamsRef.modify(streams => CommonValidations .validateStreamName(streamName) .flatMap(_ => @@ -38,11 +39,13 @@ final case class RemoveTagsFromStreamRequest( ).mapN((_, _) => stream) ) ) - .traverse(stream => - streamsRef.update(x => - x.updateStream(stream.copy(tags = stream.tags -- tagKeys)) + .map(stream => + ( + streams.updateStream(stream.copy(tags = stream.tags -- tagKeys)), + () ) ) + .sequenceWithDefault(streams) ) } diff --git a/src/main/scala/kinesis/mock/api/SplitShardRequest.scala b/src/main/scala/kinesis/mock/api/SplitShardRequest.scala index 61e62d27..d235fb7c 100644 --- a/src/main/scala/kinesis/mock/api/SplitShardRequest.scala +++ b/src/main/scala/kinesis/mock/api/SplitShardRequest.scala @@ -4,12 +4,13 @@ package api import java.time.Instant import cats.Eq +import cats.effect.IO import cats.effect.concurrent.Ref -import cats.effect.{Concurrent, IO} import cats.syntax.all._ import io.circe import kinesis.mock.models._ +import kinesis.mock.syntax.either._ import kinesis.mock.validations.CommonValidations // https://docs.aws.amazon.com/kinesis/latest/APIReference/API_SplitShard.html @@ -21,8 +22,8 @@ final case class SplitShardRequest( def splitShard( streamsRef: Ref[IO, Streams], shardLimit: Int - )(implicit C: Concurrent[IO]): IO[Response[Unit]] = - streamsRef.get.flatMap { streams => + ): IO[Response[Unit]] = + streamsRef.modify { streams => CommonValidations .validateStreamName(streamName) .flatMap(_ => @@ -63,7 +64,7 @@ final case class SplitShardRequest( } } ) - .traverse { case (shard, shardData, stream) => + .map { case (shard, shardData, stream) => val now = Instant.now() val newStartingHashKeyNumber = BigInt(newStartingHashKey) val newShardIndex1 = stream.shards.keys.map(_.shardId.index).max + 1 @@ -109,17 +110,19 @@ final case class SplitShardRequest( ) ) -> shardData - streamsRef.update(x => - x.updateStream( + ( + streams.updateStream( stream.copy( shards = stream.shards.filterNot { case (shard, _) => shard.shardId == oldShard._1.shardId } ++ (newShards :+ oldShard), streamStatus = StreamStatus.UPDATING ) - ) + ), + () ) } + .sequenceWithDefault(streams) } } diff --git a/src/main/scala/kinesis/mock/api/StartStreamEncryptionRequest.scala b/src/main/scala/kinesis/mock/api/StartStreamEncryptionRequest.scala index c2b41304..25106c2c 100644 --- a/src/main/scala/kinesis/mock/api/StartStreamEncryptionRequest.scala +++ b/src/main/scala/kinesis/mock/api/StartStreamEncryptionRequest.scala @@ -8,6 +8,7 @@ import cats.syntax.all._ import io.circe import kinesis.mock.models._ +import kinesis.mock.syntax.either._ import kinesis.mock.validations.CommonValidations final case class StartStreamEncryptionRequest( @@ -17,7 +18,7 @@ final case class StartStreamEncryptionRequest( ) { def startStreamEncryption( streamsRef: Ref[IO, Streams] - ): IO[Response[Unit]] = streamsRef.get.flatMap { streams => + ): IO[Response[Unit]] = streamsRef.modify { streams => CommonValidations .validateStreamName(streamName) .flatMap(_ => @@ -31,17 +32,19 @@ final case class StartStreamEncryptionRequest( ).mapN((_, _, _) => stream) ) ) - .traverse(stream => - streamsRef.update(x => - x.updateStream( + .map(stream => + ( + streams.updateStream( stream.copy( encryptionType = encryptionType, streamStatus = StreamStatus.UPDATING, keyId = Some(keyId) ) - ) + ), + () ) ) + .sequenceWithDefault(streams) } } diff --git a/src/main/scala/kinesis/mock/api/StopStreamEncryptionRequest.scala b/src/main/scala/kinesis/mock/api/StopStreamEncryptionRequest.scala index d9bbcb37..8916baeb 100644 --- a/src/main/scala/kinesis/mock/api/StopStreamEncryptionRequest.scala +++ b/src/main/scala/kinesis/mock/api/StopStreamEncryptionRequest.scala @@ -8,6 +8,7 @@ import cats.syntax.all._ import io.circe import kinesis.mock.models._ +import kinesis.mock.syntax.either._ import kinesis.mock.validations.CommonValidations final case class StopStreamEncryptionRequest( @@ -17,7 +18,7 @@ final case class StopStreamEncryptionRequest( ) { def stopStreamEncryption( streamsRef: Ref[IO, Streams] - ): IO[Response[Unit]] = streamsRef.get.flatMap(streams => + ): IO[Response[Unit]] = streamsRef.modify(streams => CommonValidations .validateStreamName(streamName) .flatMap(_ => @@ -31,17 +32,19 @@ final case class StopStreamEncryptionRequest( ).mapN((_, _, _) => stream) ) ) - .traverse(stream => - streamsRef.update(x => - x.updateStream( + .map(stream => + ( + streams.updateStream( stream.copy( encryptionType = EncryptionType.NONE, streamStatus = StreamStatus.UPDATING, keyId = None ) - ) + ), + () ) ) + .sequenceWithDefault(streams) ) } diff --git a/src/main/scala/kinesis/mock/api/UpdateShardCountRequest.scala b/src/main/scala/kinesis/mock/api/UpdateShardCountRequest.scala index 0e3b519f..aeb14cf8 100644 --- a/src/main/scala/kinesis/mock/api/UpdateShardCountRequest.scala +++ b/src/main/scala/kinesis/mock/api/UpdateShardCountRequest.scala @@ -6,12 +6,13 @@ import scala.concurrent.duration._ import java.time.Instant import cats.Eq +import cats.effect.IO import cats.effect.concurrent.Ref -import cats.effect.{Concurrent, IO} import cats.syntax.all._ import io.circe import kinesis.mock.models._ +import kinesis.mock.syntax.either._ import kinesis.mock.validations.CommonValidations // https://docs.aws.amazon.com/kinesis/latest/APIReference/API_UpdateShardCount.html @@ -23,8 +24,8 @@ final case class UpdateShardCountRequest( def updateShardCount( streamsRef: Ref[IO, Streams], shardLimit: Int - )(implicit C: Concurrent[IO]): IO[Response[Unit]] = - streamsRef.get.flatMap { streams => + ): IO[Response[Unit]] = + streamsRef.modify { streams => val now = Instant.now() CommonValidations .validateStreamName(streamName) @@ -69,7 +70,7 @@ final case class UpdateShardCountRequest( ).mapN((_, _, _) => stream) } ) - .traverse { stream => + .map { stream => val shards = stream.shards.toVector val oldShards = shards.map { case (shard, data) => ( @@ -86,15 +87,17 @@ final case class UpdateShardCountRequest( now, oldShards.map(_._1.shardId.index).max + 1 ) - streamsRef.update(x => - x.updateStream( + ( + streams.updateStream( stream.copy( shards = newShards ++ oldShards, streamStatus = StreamStatus.UPDATING ) - ) + ), + () ) } + .sequenceWithDefault(streams) } } diff --git a/src/main/scala/kinesis/mock/cache/Cache.scala b/src/main/scala/kinesis/mock/cache/Cache.scala index 0d9ea6af..03800d24 100644 --- a/src/main/scala/kinesis/mock/cache/Cache.scala +++ b/src/main/scala/kinesis/mock/cache/Cache.scala @@ -103,10 +103,7 @@ class Cache private ( req: CreateStreamRequest, context: LoggingContext, isCbor: Boolean - )(implicit - T: Timer[IO], - CS: ContextShift[IO] - ): IO[Response[Unit]] = { + )(implicit T: Timer[IO]): IO[Response[Unit]] = { val ctx = context + ("streamName" -> req.streamName.streamName) logger.debug(ctx.context)("Processing CreateStream request") *> logger.trace(ctx.addEncoded("request", req, isCbor).context)( @@ -1041,10 +1038,7 @@ class Cache private ( req: MergeShardsRequest, context: LoggingContext, isCbor: Boolean - )(implicit - T: Timer[IO], - CS: ContextShift[IO] - ): IO[Response[Unit]] = { + )(implicit T: Timer[IO]): IO[Response[Unit]] = { val ctx = context + ("streamName" -> req.streamName.streamName) logger.debug(ctx.context)( "Processing MergeShards request" @@ -1096,10 +1090,7 @@ class Cache private ( req: SplitShardRequest, context: LoggingContext, isCbor: Boolean - )(implicit - T: Timer[IO], - CS: ContextShift[IO] - ): IO[Response[Unit]] = { + )(implicit T: Timer[IO]): IO[Response[Unit]] = { val ctx = context + ("streamName" -> req.streamName.streamName) logger.debug(ctx.context)( "Processing SplitShard request" @@ -1151,10 +1142,7 @@ class Cache private ( req: UpdateShardCountRequest, context: LoggingContext, isCbor: Boolean - )(implicit - T: Timer[IO], - CS: ContextShift[IO] - ): IO[Response[Unit]] = { + )(implicit T: Timer[IO]): IO[Response[Unit]] = { val ctx = context + ("streamName" -> req.streamName.streamName) logger.debug(ctx.context)( "Processing UpdateShardCount request" diff --git a/src/main/scala/kinesis/mock/syntax/either.scala b/src/main/scala/kinesis/mock/syntax/either.scala new file mode 100644 index 00000000..11cc809a --- /dev/null +++ b/src/main/scala/kinesis/mock/syntax/either.scala @@ -0,0 +1,19 @@ +package kinesis.mock.syntax + +object either extends KinesisMockEitherSyntax + +trait KinesisMockEitherSyntax { + implicit def toKinesisMockEitherTupleOps[L, T1, T2]( + e: Either[L, (T1, T2)] + ): KinesisMockEitherSyntax.KinesisMockEitherTupleOps[L, T1, T2] = + new KinesisMockEitherSyntax.KinesisMockEitherTupleOps(e) +} + +object KinesisMockEitherSyntax { + final class KinesisMockEitherTupleOps[L, T1, T2]( + private val e: Either[L, (T1, T2)] + ) extends AnyVal { + def sequenceWithDefault(default: T1): (T1, Either[L, T2]) = + e.fold(e => (default, Left(e)), { case (t1, t2) => (t1, Right(t2)) }) + } +}