Skip to content

Commit

Permalink
Remove Shard Semaphores (#113)
Browse files Browse the repository at this point in the history
  • Loading branch information
etspaceman authored Jun 19, 2021
1 parent f0bd7ad commit 7d1e9dd
Show file tree
Hide file tree
Showing 37 changed files with 266 additions and 573 deletions.
10 changes: 3 additions & 7 deletions src/main/scala/kinesis/mock/api/CreateStreamRequest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package kinesis.mock
package api

import cats.Eq
import cats.effect.concurrent.{Ref, Semaphore}
import cats.effect.concurrent.Ref
import cats.effect.{Concurrent, IO}
import cats.syntax.all._
import io.circe
Expand All @@ -13,7 +13,6 @@ import kinesis.mock.validations.CommonValidations
final case class CreateStreamRequest(shardCount: Int, streamName: StreamName) {
def createStream(
streamsRef: Ref[IO, Streams],
shardSemaphoresRef: Ref[IO, Map[ShardSemaphoresKey, Semaphore[IO]]],
shardLimit: Int,
awsRegion: AwsRegion,
awsAccountId: AwsAccountId
Expand All @@ -38,16 +37,13 @@ final case class CreateStreamRequest(shardCount: Int, streamName: StreamName) {
else Right(()),
CommonValidations.validateShardLimit(shardCount, streams, shardLimit)
).traverseN { (_, _, _, _, _) =>
val (newStream, shardSemaphoreKeys) =
val newStream =
StreamData.create(shardCount, streamName, awsRegion, awsAccountId)
for {
_ <- streamsRef
res <- streamsRef
.update(x =>
x.copy(streams = x.streams ++ List(streamName -> newStream))
)
shardSemaphores <- shardSemaphoreKeys
.traverse(key => Semaphore[IO](1).map(s => key -> s))
res <- shardSemaphoresRef.update(x => x ++ shardSemaphores)
} yield res
}
}
Expand Down
12 changes: 3 additions & 9 deletions src/main/scala/kinesis/mock/api/DeleteStreamRequest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import scala.collection.SortedMap

import cats.Eq
import cats.effect.IO
import cats.effect.concurrent.{Ref, Semaphore}
import cats.effect.concurrent.Ref
import cats.syntax.all._
import io.circe

Expand All @@ -17,8 +17,7 @@ final case class DeleteStreamRequest(
enforceConsumerDeletion: Option[Boolean]
) {
def deleteStream(
streamsRef: Ref[IO, Streams],
shardSemaphoresRef: Ref[IO, Map[ShardSemaphoresKey, Semaphore[IO]]]
streamsRef: Ref[IO, Streams]
): IO[Response[Unit]] =
streamsRef.get.flatMap { streams =>
CommonValidations
Expand Down Expand Up @@ -51,18 +50,13 @@ final case class DeleteStreamRequest(
)
)

val shardSemaphoreKeys = stream.shards.keys.toList
.map(shard => ShardSemaphoresKey(stream.streamName, shard))

for {
_ <- streamsRef
res <- streamsRef
.update(streams =>
streams.copy(
streams = streams.streams ++ deletingStream
)
)
res <- shardSemaphoresRef
.update(shardSemaphores => shardSemaphores -- shardSemaphoreKeys)
} yield res
}
}
Expand Down
44 changes: 11 additions & 33 deletions src/main/scala/kinesis/mock/api/MergeShardsRequest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ package api
import java.time.Instant

import cats.Eq
import cats.effect.concurrent.{Ref, Semaphore}
import cats.effect.concurrent.Ref
import cats.effect.{Concurrent, IO}
import cats.syntax.all._
import io.circe
Expand All @@ -19,8 +19,7 @@ final case class MergeShardsRequest(
streamName: StreamName
) {
def mergeShards(
streamsRef: Ref[IO, Streams],
shardSemaphoresRef: Ref[IO, Map[ShardSemaphoresKey, Semaphore[IO]]]
streamsRef: Ref[IO, Streams]
)(implicit C: Concurrent[IO]): IO[Response[Unit]] =
streamsRef.get.flatMap(streams =>
CommonValidations
Expand Down Expand Up @@ -117,36 +116,15 @@ final case class MergeShardsRequest(
.copy(endingSequenceNumber = Some(SequenceNumber.shardEnd))
) -> shardData
)
shardSemaphoresRef.get.flatMap(shardSemaphores =>
shardSemaphores(
ShardSemaphoresKey(streamName, adjacentShard)
).withPermit(
shardSemaphores(ShardSemaphoresKey(streamName, shard))
.withPermit(
for {
_ <- streamsRef.update(x =>
x.updateStream(
stream.copy(
shards = stream.shards.filterNot { case (s, _) =>
s.shardId == adjacentShard.shardId || s.shardId == shard.shardId
}
++ (oldShards :+ newShard),
streamStatus = StreamStatus.UPDATING
)
)
)
newSemaphore <- Semaphore[IO](1)
newShardsSemaphoreKey = ShardSemaphoresKey(
streamName,
newShard._1
)
res <- shardSemaphoresRef.update(shardsSemaphore =>
shardsSemaphore ++ List(
newShardsSemaphoreKey -> newSemaphore
)
)
} yield res
)
streamsRef.update(x =>
x.updateStream(
stream.copy(
shards = stream.shards.filterNot { case (s, _) =>
s.shardId == adjacentShard.shardId || s.shardId == shard.shardId
}
++ (oldShards :+ newShard),
streamStatus = StreamStatus.UPDATING
)
)
)
}
Expand Down
50 changes: 22 additions & 28 deletions src/main/scala/kinesis/mock/api/PutRecordRequest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import java.time.Instant

import cats.Eq
import cats.effect.IO
import cats.effect.concurrent.{Ref, Semaphore}
import cats.effect.concurrent.Ref
import cats.syntax.all._
import io.circe

Expand All @@ -23,8 +23,7 @@ final case class PutRecordRequest(
streamName: StreamName
) {
def putRecord(
streamsRef: Ref[IO, Streams],
shardSemaphoresRef: Ref[IO, Map[ShardSemaphoresKey, Semaphore[IO]]]
streamsRef: Ref[IO, Streams]
): IO[Response[PutRecordResponse]] = streamsRef.get.flatMap { streams =>
val now = Instant.now()
CommonValidations
Expand Down Expand Up @@ -75,34 +74,29 @@ final case class PutRecordRequest(
Some(records.length),
Some(now)
)
// Use a semaphore to ensure synchronous operations on the shard
shardSemaphoresRef.get.flatMap(shardSemaphores =>
shardSemaphores(ShardSemaphoresKey(streamName, shard)).withPermit(
streamsRef
.update(x =>
x.updateStream {
stream.copy(
shards = stream.shards ++ SortedMap(
shard -> (records :+ KinesisRecord(
now,
data,
stream.encryptionType,
partitionKey,
seqNo
))
)
)
}
)
.as(
PutRecordResponse(
stream.encryptionType,
seqNo,
shard.shardId.shardId
streamsRef
.update(x =>
x.updateStream {
stream.copy(
shards = stream.shards ++ SortedMap(
shard -> (records :+ KinesisRecord(
now,
data,
stream.encryptionType,
partitionKey,
seqNo
))
)
)
}
)
.as(
PutRecordResponse(
stream.encryptionType,
seqNo,
shard.shardId.shardId
)
)
)
}
}
}
Expand Down
74 changes: 29 additions & 45 deletions src/main/scala/kinesis/mock/api/PutRecordsRequest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@ package api

import java.time.Instant

import cats.Eq
import cats.effect.IO
import cats.effect.concurrent.{Ref, Semaphore}
import cats.effect.concurrent.Ref
import cats.syntax.all._
import cats.{Eq, Parallel}
import io.circe

import kinesis.mock.models._
Expand All @@ -17,10 +17,7 @@ final case class PutRecordsRequest(
streamName: StreamName
) {
def putRecords(
streamsRef: Ref[IO, Streams],
shardSemaphoresRef: Ref[IO, Map[ShardSemaphoresKey, Semaphore[IO]]]
)(implicit
P: Parallel[IO]
streamsRef: Ref[IO, Streams]
): IO[Response[PutRecordsResponse]] =
streamsRef.get.flatMap { streams =>
val now = Instant.now()
Expand Down Expand Up @@ -89,49 +86,36 @@ final case class PutRecordsRequest(
}
.toList

shardSemaphoresRef.get.flatMap { shardSemaphores =>
val keys = grouped.map { case ((shard, _), _) =>
ShardSemaphoresKey(streamName, shard)
}

val semaphores = shardSemaphores.toList.filter { case (key, _) =>
keys.contains(key)
}
val newShards = grouped.map {
case ((shard, currentRecords), recordsToAdd) =>
(
shard,
(
currentRecords ++ recordsToAdd.map(_._1),
recordsToAdd.map(_._2)
)
)
}

for {
_ <- semaphores.parTraverse { case (_, semaphore) =>
semaphore.acquire
}
newShards = grouped.map {
case ((shard, currentRecords), recordsToAdd) =>
(
shard,
(
currentRecords ++ recordsToAdd.map(_._1),
recordsToAdd.map(_._2)
)
)
}
_ <- streamsRef.update(x =>
x.updateStream(
stream.copy(
shards = stream.shards ++ newShards.map {
case (shard, (records, _)) => shard -> records
}
)
streamsRef
.update(x =>
x.updateStream(
stream.copy(
shards = stream.shards ++ newShards.map {
case (shard, (records, _)) => shard -> records
}
)
)
_ <- semaphores.parTraverse { case (_, semaphore) =>
semaphore.release
}
} yield PutRecordsResponse(
stream.encryptionType,
0,
newShards.flatMap { case (_, (_, resultEntries)) =>
resultEntries
}
)
}
.as(
PutRecordsResponse(
stream.encryptionType,
0,
newShards.flatMap { case (_, (_, resultEntries)) =>
resultEntries
}
)
)
}
}
}
Expand Down
34 changes: 10 additions & 24 deletions src/main/scala/kinesis/mock/api/SplitShardRequest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ package api
import java.time.Instant

import cats.Eq
import cats.effect.concurrent.{Ref, Semaphore}
import cats.effect.concurrent.Ref
import cats.effect.{Concurrent, IO}
import cats.syntax.all._
import io.circe
Expand All @@ -20,7 +20,6 @@ final case class SplitShardRequest(
) {
def splitShard(
streamsRef: Ref[IO, Streams],
shardSemaphoresRef: Ref[IO, Map[ShardSemaphoresKey, Semaphore[IO]]],
shardLimit: Int
)(implicit C: Concurrent[IO]): IO[Response[Unit]] =
streamsRef.get.flatMap { streams =>
Expand Down Expand Up @@ -110,29 +109,16 @@ final case class SplitShardRequest(
)
) -> shardData

shardSemaphoresRef.get.flatMap { shardSemaphores =>
shardSemaphores(ShardSemaphoresKey(streamName, shard)).withPermit(
streamsRef.update(x =>
x.updateStream(
stream.copy(
shards = stream.shards.filterNot { case (shard, _) =>
shard.shardId == oldShard._1.shardId
} ++ (newShards :+ oldShard),
streamStatus = StreamStatus.UPDATING
)
)
)
) *> Semaphore[IO](1)
.flatMap(x => Semaphore[IO](1).map(y => List(x, y)))
.flatMap(semaphores =>
shardSemaphoresRef.update(shardsSemaphore =>
shardsSemaphore ++ List(
ShardSemaphoresKey(streamName, newShard1._1),
ShardSemaphoresKey(streamName, newShard2._1)
).zip(semaphores)
)
streamsRef.update(x =>
x.updateStream(
stream.copy(
shards = stream.shards.filterNot { case (shard, _) =>
shard.shardId == oldShard._1.shardId
} ++ (newShards :+ oldShard),
streamStatus = StreamStatus.UPDATING
)
}
)
)
}
}
}
Expand Down
Loading

0 comments on commit 7d1e9dd

Please sign in to comment.