From 3d50b74e90b5874143f5f5f66d46c9dd0e14a27c Mon Sep 17 00:00:00 2001 From: Daniil Filippov Date: Tue, 7 Nov 2023 13:30:35 +0300 Subject: [PATCH] Override blocking inside CatsMonadError Fix #1991 --- .../client3/impl/cats/CatsMonadError.scala | 2 ++ .../impl/cats/CatsMonadErrorTest.scala | 32 +++++++++++++++++++ 2 files changed, 34 insertions(+) create mode 100644 effects/cats/src/test/scala/sttp/client3/impl/cats/CatsMonadErrorTest.scala diff --git a/effects/cats/src/main/scala/sttp/client3/impl/cats/CatsMonadError.scala b/effects/cats/src/main/scala/sttp/client3/impl/cats/CatsMonadError.scala index 22ee7f3f56..2e56ddcfb9 100644 --- a/effects/cats/src/main/scala/sttp/client3/impl/cats/CatsMonadError.scala +++ b/effects/cats/src/main/scala/sttp/client3/impl/cats/CatsMonadError.scala @@ -21,4 +21,6 @@ class CatsMonadError[F[_]](implicit F: Sync[F]) extends MonadError[F] { override def flatten[T](ffa: F[F[T]]): F[T] = F.flatten(ffa) override def ensure[T](f: F[T], e: => F[Unit]): F[T] = F.guaranteeCase(f)(_ => e) + + override def blocking[T](t: => T): F[T] = F.blocking(t) } diff --git a/effects/cats/src/test/scala/sttp/client3/impl/cats/CatsMonadErrorTest.scala b/effects/cats/src/test/scala/sttp/client3/impl/cats/CatsMonadErrorTest.scala new file mode 100644 index 0000000000..82f47847f9 --- /dev/null +++ b/effects/cats/src/test/scala/sttp/client3/impl/cats/CatsMonadErrorTest.scala @@ -0,0 +1,32 @@ +package sttp.client3.impl.cats + +import cats.effect.IO +import cats.effect.unsafe.{IORuntime, IORuntimeConfig} +import org.scalatest.freespec.AsyncFreeSpec + +class CatsMonadErrorTest extends AsyncFreeSpec { + import scala.concurrent.duration._ + + "blocking" - { + "should shift to blocking execution context " in { + implicit val ioRuntime: IORuntime = createIORuntime(computePoolSize = 1) + val monad = new CatsMonadError[IO] + + val program = monad + .blocking(Thread.sleep(100.millis.toMillis)) + .background + .use(getOutcome => IO.race(getOutcome, IO.unit)) + + program + .flatMap(either => IO.delay(assert(either.isRight))) + .unsafeToFuture() + } + } + + private def createIORuntime(computePoolSize: Int): IORuntime = { + val (compute, _) = IORuntime.createWorkStealingComputeThreadPool(computePoolSize) + val (blocking, _) = IORuntime.createDefaultBlockingExecutionContext() + + IORuntime(compute, blocking, compute, () => (), IORuntimeConfig()) + } +}