From b554596f3a7e7dfaea9a7a37c0a71c5786b3d154 Mon Sep 17 00:00:00 2001 From: Daniil Filippov Date: Tue, 7 Nov 2023 13:30:35 +0300 Subject: [PATCH] Override blocking inside CatsMonadError Fix #1991 --- .../client4/impl/cats/CatsMonadError.scala | 2 ++ .../impl/cats/CatsMonadErrorTest.scala | 30 +++++++++++++++++++ 2 files changed, 32 insertions(+) create mode 100644 effects/cats/src/test/scalajvm/sttp/client4/impl/cats/CatsMonadErrorTest.scala diff --git a/effects/cats/src/main/scala/sttp/client4/impl/cats/CatsMonadError.scala b/effects/cats/src/main/scala/sttp/client4/impl/cats/CatsMonadError.scala index 186cee67df..57269d0101 100644 --- a/effects/cats/src/main/scala/sttp/client4/impl/cats/CatsMonadError.scala +++ b/effects/cats/src/main/scala/sttp/client4/impl/cats/CatsMonadError.scala @@ -21,4 +21,6 @@ class CatsMonadError[F[_]](implicit F: Sync[F]) extends MonadError[F] { override def flatten[T](ffa: F[F[T]]): F[T] = F.flatten(ffa) override def ensure[T](f: F[T], e: => F[Unit]): F[T] = F.guaranteeCase(f)(_ => e) + + override def blocking[T](t: => T): F[T] = F.blocking(t) } diff --git a/effects/cats/src/test/scalajvm/sttp/client4/impl/cats/CatsMonadErrorTest.scala b/effects/cats/src/test/scalajvm/sttp/client4/impl/cats/CatsMonadErrorTest.scala new file mode 100644 index 0000000000..cde3257bfc --- /dev/null +++ b/effects/cats/src/test/scalajvm/sttp/client4/impl/cats/CatsMonadErrorTest.scala @@ -0,0 +1,30 @@ +package sttp.client4.impl.cats + +import cats.effect.IO +import cats.effect.unsafe.{IORuntime, IORuntimeConfig} +import org.scalatest.freespec.AsyncFreeSpec + +class CatsMonadErrorTest extends AsyncFreeSpec { + "blocking" - { + "should shift to blocking execution context" in { + implicit val ioRuntime: IORuntime = createIORuntime(computePoolSize = 1) + val monad = new CatsMonadError[IO] + + val program = monad + .blocking(Thread.sleep(100)) + .background + .use(getOutcome => IO.race(getOutcome, IO.unit)) + + program + .flatMap(either => IO.delay(assert(either.isRight))) + .unsafeToFuture() + } + } + + private def createIORuntime(computePoolSize: Int): IORuntime = { + val (compute, _) = IORuntime.createWorkStealingComputeThreadPool(computePoolSize) + val (blocking, _) = IORuntime.createDefaultBlockingExecutionContext() + + IORuntime(compute, blocking, compute, () => (), IORuntimeConfig()) + } +}