Skip to content

Commit

Permalink
Use Modify instead of Get / Update for Ref Updates (#134)
Browse files Browse the repository at this point in the history
* Use Modify instead of Get / Update for Ref Updates

* Parallel PutRecords
  • Loading branch information
etspaceman authored Jul 6, 2021
1 parent 169dd64 commit 10d17a1
Show file tree
Hide file tree
Showing 21 changed files with 244 additions and 218 deletions.
6 changes: 3 additions & 3 deletions src/fun/scala/kinesis/mock/PutRecordTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class PutRecordTests extends munit.CatsEffectSuite with AwsFunctionalTests {
for {
recordRequests <- IO(
putRecordRequestArb.arbitrary
.take(5)
.take(20)
.toVector
.map(_.copy(streamName = resources.streamName))
.map(x =>
Expand All @@ -34,7 +34,7 @@ class PutRecordTests extends munit.CatsEffectSuite with AwsFunctionalTests {
.build()
)
)
_ <- recordRequests.traverse(x =>
_ <- recordRequests.parTraverse(x =>
resources.kinesisClient.putRecord(x).toIO
)
shards <- resources.kinesisClient
Expand Down Expand Up @@ -68,7 +68,7 @@ class PutRecordTests extends munit.CatsEffectSuite with AwsFunctionalTests {
)
res = gets.flatMap(_.records().asScala.toVector)
} yield assert(
res.length == 5 && res.forall(rec =>
res.length == 20 && res.forall(rec =>
recordRequests.exists(req =>
req.data.asByteArray.sameElements(rec.data.asByteArray)
&& req.partitionKey == rec.partitionKey
Expand Down
49 changes: 26 additions & 23 deletions src/fun/scala/kinesis/mock/PutRecordsTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,27 +16,29 @@ class PutRecordsTests extends munit.CatsEffectSuite with AwsFunctionalTests {

fixture.test("It should put records") { resources =>
for {
req <- IO(
PutRecordsRequest
.builder()
.records(
putRecordsRequestEntryArb.arbitrary
.take(5)
.toVector
.map(x =>
PutRecordsRequestEntry
.builder()
.data(SdkBytes.fromByteArray(x.data))
.partitionKey(x.partitionKey)
.maybeTransform(x.explicitHashKey)(_.explicitHashKey(_))
.build()
)
.asJava
)
.streamName(resources.streamName.streamName)
.build()
reqs <- IO(
List.fill(10)(
PutRecordsRequest
.builder()
.records(
putRecordsRequestEntryArb.arbitrary
.take(5)
.toVector
.map(x =>
PutRecordsRequestEntry
.builder()
.data(SdkBytes.fromByteArray(x.data))
.partitionKey(x.partitionKey)
.maybeTransform(x.explicitHashKey)(_.explicitHashKey(_))
.build()
)
.asJava
)
.streamName(resources.streamName.streamName)
.build()
)
)
_ <- resources.kinesisClient.putRecords(req).toIO
_ <- reqs.parTraverse(req => resources.kinesisClient.putRecords(req).toIO)
shards <- resources.kinesisClient
.listShards(
ListShardsRequest
Expand Down Expand Up @@ -67,14 +69,15 @@ class PutRecordsTests extends munit.CatsEffectSuite with AwsFunctionalTests {
.toIO
)
res = gets.flatMap(_.records().asScala.toVector)
records = reqs.flatMap(_.records.asScala.toVector)
} yield assert(
res.length == 5 && res.forall(rec =>
req.records.asScala.toVector.exists(req =>
res.length == 50 && res.forall(rec =>
records.exists(req =>
req.data.asByteArray.sameElements(rec.data.asByteArray)
&& req.partitionKey == rec.partitionKey
)
),
s"$res\n${req.records()}"
s"$res\n$records"
)
}
}
10 changes: 5 additions & 5 deletions src/main/scala/kinesis/mock/api/AddTagsToStreamRequest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import cats.syntax.all._
import io.circe

import kinesis.mock.models._
import kinesis.mock.syntax.either._
import kinesis.mock.validations.CommonValidations

// https://docs.aws.amazon.com/kinesis/latest/APIReference/API_AddTagsToStream.html
Expand All @@ -19,7 +20,7 @@ final case class AddTagsToStreamRequest(
) {
def addTagsToStream(
streamsRef: Ref[IO, Streams]
): IO[Response[Unit]] = streamsRef.get.flatMap { streams =>
): IO[Response[Unit]] = streamsRef.modify { streams =>
CommonValidations
.validateStreamName(streamName)
.flatMap(_ =>
Expand Down Expand Up @@ -62,11 +63,10 @@ final case class AddTagsToStreamRequest(
).mapN((_, _, _, _, _) => stream)
)
)
.traverse(stream =>
streamsRef.update(x =>
x.updateStream(stream.copy(tags = stream.tags |+| tags))
)
.map(stream =>
(streams.updateStream(stream.copy(tags = stream.tags |+| tags)), ())
)
.sequenceWithDefault(streams)
}
}

Expand Down
21 changes: 10 additions & 11 deletions src/main/scala/kinesis/mock/api/CreateStreamRequest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@ package kinesis.mock
package api

import cats.Eq
import cats.effect.IO
import cats.effect.concurrent.Ref
import cats.effect.{Concurrent, IO}
import cats.syntax.all._
import io.circe

import kinesis.mock.models._
import kinesis.mock.syntax.either._
import kinesis.mock.validations.CommonValidations

final case class CreateStreamRequest(shardCount: Int, streamName: StreamName) {
Expand All @@ -16,8 +17,8 @@ final case class CreateStreamRequest(shardCount: Int, streamName: StreamName) {
shardLimit: Int,
awsRegion: AwsRegion,
awsAccountId: AwsAccountId
)(implicit C: Concurrent[IO]): IO[Response[Unit]] =
streamsRef.get.flatMap { streams =>
): IO[Response[Unit]] =
streamsRef.modify { streams =>
(
CommonValidations.validateStreamName(streamName),
if (streams.streams.contains(streamName))
Expand All @@ -36,16 +37,14 @@ final case class CreateStreamRequest(shardCount: Int, streamName: StreamName) {
).asLeft
else Right(()),
CommonValidations.validateShardLimit(shardCount, streams, shardLimit)
).traverseN { (_, _, _, _, _) =>
).mapN { (_, _, _, _, _) =>
val newStream =
StreamData.create(shardCount, streamName, awsRegion, awsAccountId)
for {
res <- streamsRef
.update(x =>
x.copy(streams = x.streams + (streamName -> newStream))
)
} yield res
}
(
streams.copy(streams = streams.streams + (streamName -> newStream)),
()
)
}.sequenceWithDefault(streams)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import cats.syntax.all._
import io.circe

import kinesis.mock.models._
import kinesis.mock.syntax.either._
import kinesis.mock.validations.CommonValidations

// https://docs.aws.amazon.com/kinesis/latest/APIReference/API_DecreaseStreamRetention.html
Expand All @@ -19,7 +20,7 @@ final case class DecreaseStreamRetentionPeriodRequest(
) {
def decreaseStreamRetention(
streamsRef: Ref[IO, Streams]
): IO[Response[Unit]] = streamsRef.get.flatMap { streams =>
): IO[Response[Unit]] = streamsRef.modify { streams =>
CommonValidations
.validateStreamName(streamName)
.flatMap(_ =>
Expand All @@ -38,13 +39,15 @@ final case class DecreaseStreamRetentionPeriodRequest(
).mapN((_, _, _) => stream)
)
)
.traverse(stream =>
streamsRef.update(x =>
x.updateStream(
.map(stream =>
(
streams.updateStream(
stream.copy(retentionPeriod = retentionPeriodHours.hours)
)
),
()
)
)
.sequenceWithDefault(streams)
}
}

Expand Down
21 changes: 10 additions & 11 deletions src/main/scala/kinesis/mock/api/DeleteStreamRequest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import cats.syntax.all._
import io.circe

import kinesis.mock.models._
import kinesis.mock.syntax.either._
import kinesis.mock.validations.CommonValidations

final case class DeleteStreamRequest(
Expand All @@ -17,7 +18,7 @@ final case class DeleteStreamRequest(
def deleteStream(
streamsRef: Ref[IO, Streams]
): IO[Response[Unit]] =
streamsRef.get.flatMap { streams =>
streamsRef.modify { streams =>
CommonValidations
.validateStreamName(streamName)
.flatMap(_ =>
Expand All @@ -37,7 +38,7 @@ final case class DeleteStreamRequest(
).mapN((_, _) => stream)
)
)
.traverse { stream =>
.map { stream =>
val deletingStream = Map(
streamName -> stream.copy(
shards = Map.empty,
Expand All @@ -47,16 +48,14 @@ final case class DeleteStreamRequest(
consumers = Map.empty
)
)

for {
res <- streamsRef
.update(streams =>
streams.copy(
streams = streams.streams ++ deletingStream
)
)
} yield res
(
streams.copy(
streams = streams.streams ++ deletingStream
),
()
)
}
.sequenceWithDefault(streams)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import cats.syntax.all._
import io.circe

import kinesis.mock.models._
import kinesis.mock.syntax.either._
import kinesis.mock.validations.CommonValidations

// https://docs.aws.amazon.com/kinesis/latest/APIReference/API_DeregisterStreamConsumer.html
Expand All @@ -17,26 +18,26 @@ final case class DeregisterStreamConsumerRequest(
streamArn: Option[String]
) {
private def deregister(
streamsRef: Ref[IO, Streams],
streams: Streams,
consumer: Consumer,
stream: StreamData
): IO[Consumer] = {
val newConsumer = consumer.copy(consumerStatus = ConsumerStatus.DELETING)
): (Streams, Consumer) = {
val newConsumer =
consumer.copy(consumerStatus = ConsumerStatus.DELETING)

streamsRef
.update(x =>
x.updateStream(
stream.copy(consumers =
stream.consumers + (consumer.consumerName -> newConsumer)
)
(
streams.updateStream(
stream.copy(consumers =
stream.consumers + (consumer.consumerName -> newConsumer)
)
)
.as(newConsumer)
),
newConsumer
)
}

def deregisterStreamConsumer(
streamsRef: Ref[IO, Streams]
): IO[Response[Consumer]] = streamsRef.get.flatMap { streams =>
): IO[Response[Consumer]] = streamsRef.modify { streams =>
(consumerArn, consumerName, streamArn) match {
case (Some(cArn), _, _) =>
CommonValidations
Expand All @@ -50,9 +51,10 @@ final case class DeregisterStreamConsumerRequest(
s"Consumer $consumerName is not in an ACTIVE state"
).asLeft
}
.traverse { case (consumer, stream) =>
deregister(streamsRef, consumer, stream)
.map { case (consumer, stream) =>
deregister(streams, consumer, stream)
}
.sequenceWithDefault(streams)
case (None, Some(cName), Some(sArn)) =>
CommonValidations
.findStreamByArn(sArn, streams)
Expand All @@ -68,11 +70,13 @@ final case class DeregisterStreamConsumerRequest(

}
}
.traverse { case (consumer, stream) =>
deregister(streamsRef, consumer, stream)
.map { case (consumer, stream) =>
deregister(streams, consumer, stream)
}
.sequenceWithDefault(streams)
case _ =>
IO(
(
streams,
InvalidArgumentException(
"ConsumerArn or both ConsumerName and StreamARN are required for this request."
).asLeft
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@ package api
import cats.Eq
import cats.effect.IO
import cats.effect.concurrent.Ref
import cats.syntax.all._
import io.circe

import kinesis.mock.models._
import kinesis.mock.syntax.either._
import kinesis.mock.validations.CommonValidations

// https://docs.aws.amazon.com/kinesis/latest/APIReference/API_DisableEnhancedMonitoring.html
Expand All @@ -18,33 +18,31 @@ final case class DisableEnhancedMonitoringRequest(
def disableEnhancedMonitoring(
streamsRef: Ref[IO, Streams]
): IO[Response[DisableEnhancedMonitoringResponse]] =
streamsRef.get.flatMap { streams =>
streamsRef.modify { streams =>
CommonValidations
.validateStreamName(streamName)
.flatMap(_ => CommonValidations.findStream(streamName, streams))
.traverse { stream =>
.map { stream =>
val current =
stream.enhancedMonitoring.flatMap(_.shardLevelMetrics)
val desired =
if (shardLevelMetrics.contains(ShardLevelMetric.ALL))
Vector.empty
else current.diff(shardLevelMetrics)

streamsRef
.update(x =>
x.updateStream(
stream
.copy(enhancedMonitoring = Vector(ShardLevelMetrics(desired)))
)
)
.as(
DisableEnhancedMonitoringResponse(
current,
desired,
streamName
)
(
streams.updateStream(
stream
.copy(enhancedMonitoring = Vector(ShardLevelMetrics(desired)))
),
DisableEnhancedMonitoringResponse(
current,
desired,
streamName
)
)
}
.sequenceWithDefault(streams)
}
}

Expand Down
Loading

0 comments on commit 10d17a1

Please sign in to comment.