Skip to content

Commit

Permalink
Merge pull request #998 from ithinkicancode/fix/thread-safe-cyclic-re…
Browse files Browse the repository at this point in the history
…sponses

Made thenRespondCyclic* functions thread-safe
  • Loading branch information
adamw authored Jun 1, 2021
2 parents a0cef55 + 913e4ce commit 11968c9
Show file tree
Hide file tree
Showing 6 changed files with 88 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ class SttpBackendStubZioTests extends AnyFlatSpec with Matchers with ScalaFuture
val backend: SttpBackendStub[Task, Any] = SttpBackendStub(new RIOMonadAsyncError[Any])
.whenRequestMatches(_ => true)
.thenRespondCyclic("a", "b", "c")

// when
val r = basicRequest.get(uri"http://example.org/a/b/c").send(backend)

Expand All @@ -28,6 +27,24 @@ class SttpBackendStubZioTests extends AnyFlatSpec with Matchers with ScalaFuture
runtime.unsafeRun(r).body shouldBe Right("a")
}

it should "cycle through responses when called concurrently" in {
// given
val backend: SttpBackendStub[Task, Any] = SttpBackendStub(new RIOMonadAsyncError[Any])
.whenRequestMatches(_ => true)
.thenRespondCyclic("a", "b", "c")

// when
val r = basicRequest.get(uri"http://example.org/a/b/c").send(backend)

// then
val effect = ZIO
.collectAllPar(Seq.fill(100)(r))
.map(_.map(_.body))

runtime.unsafeRun(effect) should contain theSameElementsAs ((1 to 33).flatMap(_ => Seq("a", "b", "c")) ++ Seq("a"))
.map(Right(_))
}

it should "allow effectful stubbing" in {
import stubbing._
val r1 = send(basicRequest.get(uri"http://example.org/a")).map(_.body)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
package sttp.client3.testing

import java.util.concurrent.atomic.AtomicInteger
import java.util.function.IntUnaryOperator
import scala.util.{Failure, Success, Try}

final class AtomicCyclicIterator[+T] private(val elements: Seq[T]) {
private val vector = elements.toVector
private val lastIndex = elements.length - 1
private val currentIndex = new AtomicInteger(0)

private val toNextIndex = new IntUnaryOperator {
final override def applyAsInt(i: Int): Int =
if (i == lastIndex) 0 else i + 1
}

def next(): T = {
val index = currentIndex.getAndUpdate(toNextIndex)
vector(index)
}
}

object AtomicCyclicIterator {

def tryFrom[T](elements: Seq[T]): Try[AtomicCyclicIterator[T]] = {
if (elements.nonEmpty)
Success(new AtomicCyclicIterator(elements))
else
Failure(new IllegalArgumentException("Argument must be a non-empty collection."))
}

def unsafeFrom[T](elements: Seq[T]): AtomicCyclicIterator[T] = tryFrom(elements).get

def apply[T](head: T, tail: Seq[T]): AtomicCyclicIterator[T] = unsafeFrom(head +: tail)

def of[T](head: T, tail: T*): AtomicCyclicIterator[T] = apply(head, tail)
}
Original file line number Diff line number Diff line change
Expand Up @@ -108,18 +108,15 @@ class SttpBackendStub[F[_], +P](
new SttpBackendStub[F, P](monad, matchers.orElse(m), fallback)
}

/** Not thread-safe!
*/
def thenRespondCyclic[T](bodies: T*): SttpBackendStub[F, P] = {
thenRespondCyclicResponses(bodies.map(body => Response[T](body, StatusCode.Ok, "OK")): _*)
}

/** Not thread-safe!
*/
def thenRespondCyclicResponses[T](responses: Response[T]*): SttpBackendStub[F, P] = {
val iterator = Iterator.continually(responses).flatten
thenRespond(iterator.next)
val iterator = AtomicCyclicIterator.unsafeFrom(responses)
thenRespond(iterator.next())
}

def thenRespondF(resp: => F[Response[_]]): SttpBackendStub[F, P] = {
val m: PartialFunction[Request[_, _], F[Response[_]]] = {
case r if p(r) => resp
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -218,9 +218,11 @@ class SttpBackendStubTests extends AnyFlatSpec with Matchers with ScalaFutures {
Response("error", StatusCode.InternalServerError, "Something went wrong")
)

basicRequest.get(uri"http://example.org").send(backend).is200 should be(true)
basicRequest.get(uri"http://example.org").send(backend).isServerError should be(true)
basicRequest.get(uri"http://example.org").send(backend).is200 should be(true)
def testResult = basicRequest.get(uri"http://example.org").send(backend)

testResult.is200 should be(true)
testResult.isServerError should be(true)
testResult.is200 should be(true)
}

it should "always return a string when requested to do so" in {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,14 +64,15 @@ trait SttpClientStubbingBase[R, P] {
private def whenRequest(
f: SttpBackendStub[RIO[R, *], P]#WhenRequest => SttpBackendStub[RIO[R, *], P]
): URIO[SttpClientStubbing, Unit] =
URIO.accessM(_.get.update(stub => f(stub.whenRequestMatches(p))))
URIO.serviceWith(_.update(stub => f(stub.whenRequestMatches(p))))
}

val layer: ZLayer[Any, Nothing, Has[Service] with Has[SttpBackend[RIO[R, *], P]]] = {
val monad = new RIOMonadAsyncError[R]
implicit val _serviceTag: Tag[Service] = serviceTag
implicit val _backendTag: Tag[SttpBackend[RIO[R, *], P]] = sttpBackendTag
ZLayer.fromEffectMany(for {

val composed = for {
stub <- Ref.make(SttpBackendStub[RIO[R, *], P](monad))
stubber = new StubWrapper(stub)
proxy = new SttpBackend[RIO[R, *], P] {
Expand All @@ -83,6 +84,8 @@ trait SttpClientStubbingBase[R, P] {

override def responseMonad: MonadError[RIO[R, *]] = monad
}
} yield Has.allOf[Service, SttpBackend[RIO[R, *], P]](stubber, proxy))
} yield Has.allOf[Service, SttpBackend[RIO[R, *], P]](stubber, proxy)

composed.toLayerMany
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import sttp.client3._
import sttp.client3.impl.zio._
import sttp.client3.testing.SttpBackendStub
import sttp.model.Method
import zio.Task
import zio.{Task, ZIO}
import zio.stream.ZStream

class SttpBackendStubZioTests extends AnyFlatSpec with Matchers with ScalaFutures with ZioTestBase {
Expand All @@ -17,7 +17,6 @@ class SttpBackendStubZioTests extends AnyFlatSpec with Matchers with ScalaFuture
val backend: SttpBackendStub[Task, Any] = SttpBackendStub(new RIOMonadAsyncError[Any])
.whenRequestMatches(_ => true)
.thenRespondCyclic("a", "b", "c")

// when
val r = basicRequest.get(uri"http://example.org/a/b/c").send(backend)

Expand All @@ -28,6 +27,24 @@ class SttpBackendStubZioTests extends AnyFlatSpec with Matchers with ScalaFuture
runtime.unsafeRun(r).body shouldBe Right("a")
}

it should "cycle through responses when called concurrently" in {
// given
val backend: SttpBackendStub[Task, Any] = SttpBackendStub(new RIOMonadAsyncError[Any])
.whenRequestMatches(_ => true)
.thenRespondCyclic("a", "b", "c")

// when
val r = basicRequest.get(uri"http://example.org/a/b/c").send(backend)

// then
val effect = ZIO
.collectAllPar(Seq.fill(100)(r))
.map(_.map(_.body))

runtime.unsafeRun(effect) should contain theSameElementsAs ((1 to 33).flatMap(_ => Seq("a", "b", "c")) ++ Seq("a"))
.map(Right(_))
}

it should "allow effectful stubbing" in {
import stubbing._
val r1 = send(basicRequest.get(uri"http://example.org/a")).map(_.body)
Expand Down

0 comments on commit 11968c9

Please sign in to comment.