From bddcb85a0beb2dccb0345280c91403e5047f4d32 Mon Sep 17 00:00:00 2001 From: Kyri Petrou Date: Fri, 28 Jun 2024 21:30:17 +1000 Subject: [PATCH 1/3] Requests should be cached on a per-datasource level --- .../src/main/scala/zio/query/Cache.scala | 67 ++++++++++++++----- .../src/main/scala/zio/query/DataSource.scala | 4 +- .../src/main/scala/zio/query/Request.scala | 2 +- .../src/main/scala/zio/query/ZQuery.scala | 7 +- .../zio/query/internal/BlockedRequests.scala | 7 +- .../src/test/scala/zio/query/ZQuerySpec.scala | 43 ++++++++++-- 6 files changed, 99 insertions(+), 31 deletions(-) diff --git a/zio-query/shared/src/main/scala/zio/query/Cache.scala b/zio-query/shared/src/main/scala/zio/query/Cache.scala index b1dc0271..79020014 100644 --- a/zio-query/shared/src/main/scala/zio/query/Cache.scala +++ b/zio-query/shared/src/main/scala/zio/query/Cache.scala @@ -27,14 +27,20 @@ import java.util.concurrent.ConcurrentHashMap * executed. This is used internally by the library to provide deduplication and * caching of requests. */ -trait Cache { +abstract class Cache { /** * Looks up a request in the cache, failing with the unit value if the request * is not in the cache or succeeding with a `Promise` if the request is in the * cache that will contain the result of the request when it is executed. */ - def get[E, A](request: Request[E, A])(implicit trace: Trace): IO[Unit, Promise[E, A]] + def get[E, A, B]( + datasource: DataSource[?, A], + request: A + )(implicit + trace: Trace, + ev: A <:< Request[E, B] + ): IO[Unit, Promise[E, B]] /** * Looks up a request in the cache. If the request is not in the cache returns @@ -42,7 +48,10 @@ trait Cache { * the request is in the cache returns a `Right` with a `Promise` that will * contain the result of the request when it is executed. */ - def lookup[E, A, B](request: A)(implicit + def lookup[E, A, B]( + dataSource: DataSource[?, A], + request: A + )(implicit ev: A <:< Request[E, B], trace: Trace ): UIO[Either[Promise[E, B], Promise[E, B]]] @@ -51,16 +60,28 @@ trait Cache { * Inserts a request and a `Promise` that will contain the result of the * request when it is executed into the cache. */ - def put[E, A](request: Request[E, A], result: Promise[E, A])(implicit trace: Trace): UIO[Unit] + def put[E, A, B]( + dataSource: DataSource[?, A], + request: A, + result: Promise[E, B] + )(implicit + ev: A <:< Request[E, B], + trace: Trace + ): UIO[Unit] /** * Removes a request from the cache. */ - def remove[E, A](request: Request[E, A])(implicit trace: Trace): UIO[Unit] + def remove[E, A, B]( + dataSource: DataSource[?, A], + request: A + )(implicit + ev: A <:< Request[E, B], + trace: Trace + ): UIO[Unit] } object Cache { - /** * Constructs an empty cache. */ @@ -74,33 +95,45 @@ object Cache { def empty(expectedNumOfElements: Int)(implicit trace: Trace): UIO[Cache] = ZIO.succeed(Cache.unsafeMake(expectedNumOfElements)) - private[query] final class Default(private val map: ConcurrentHashMap[Request[_, _], Promise[_, _]]) extends Cache { + private type Key = (DataSource[_, _], Request[_, _]) + + private[query] final class Default(private val map: ConcurrentHashMap[Key, Promise[_, _]]) extends Cache { - def get[E, A](request: Request[E, A])(implicit trace: Trace): IO[Unit, Promise[E, A]] = + def get[E, A, B](ds: DataSource[?, A], request: A)(implicit + trace: Trace, + ev: A <:< Request[E, B] + ): IO[Unit, Promise[E, B]] = ZIO.suspendSucceed { - val out = map.get(request).asInstanceOf[Promise[E, A]] + val out = map.get((ds, request)).asInstanceOf[Promise[E, B]] if (out eq null) Exit.fail(()) else Exit.succeed(out) } - def lookup[E, A, B](request: A)(implicit + def lookup[E, A, B](ds: DataSource[?, A], request: A)(implicit ev: A <:< Request[E, B], trace: Trace ): UIO[Either[Promise[E, B], Promise[E, B]]] = - ZIO.succeed(lookupUnsafe(request)(Unsafe.unsafe)) + ZIO.succeed(lookupUnsafe(ds, request)(ev, Unsafe.unsafe)) - def lookupUnsafe[E, A, B](request: Request[_, _])(implicit + def lookupUnsafe[E, A, B](ds: DataSource[?, A], request: A)(implicit + ev: A <:< Request[E, B], unsafe: Unsafe ): Either[Promise[E, B], Promise[E, B]] = { val newPromise = Promise.unsafe.make[E, B](FiberId.None) - val existing = map.putIfAbsent(request, newPromise).asInstanceOf[Promise[E, B]] + val existing = map.putIfAbsent((ds, request), newPromise).asInstanceOf[Promise[E, B]] if (existing eq null) Left(newPromise) else Right(existing) } - def put[E, A](request: Request[E, A], result: Promise[E, A])(implicit trace: Trace): UIO[Unit] = - ZIO.succeed(map.put(request, result)) + def put[E, A, B](ds: DataSource[?, A], request: A, result: Promise[E, B])(implicit + ev: A <:< Request[E, B], + trace: Trace + ): UIO[Unit] = + ZIO.succeed(map.put((ds, request), result)) - def remove[E, A](request: Request[E, A])(implicit trace: Trace): UIO[Unit] = - ZIO.succeed(map.remove(request)) + def remove[E, A, B](ds: DataSource[?, A], request: A)(implicit + ev: A <:< Request[E, B], + trace: Trace + ): UIO[Unit] = + ZIO.succeed(map.remove((ds, request))) } // TODO: Initialize the map with a sensible default value. Default is 16, which seems way too small for a cache diff --git a/zio-query/shared/src/main/scala/zio/query/DataSource.scala b/zio-query/shared/src/main/scala/zio/query/DataSource.scala index 8ae9b27e..1d4ba31f 100644 --- a/zio-query/shared/src/main/scala/zio/query/DataSource.scala +++ b/zio-query/shared/src/main/scala/zio/query/DataSource.scala @@ -40,7 +40,7 @@ import zio.{Chunk, Exit, Trace, ZEnvironment, ZIO} * for all requests received. Failure to do so will cause a query to die with a * `QueryFailure` when run. */ -trait DataSource[-R, -A] { self => +trait DataSource[-R, -A] extends Serializable { self => /** * Syntax for adding aspects. @@ -65,7 +65,7 @@ trait DataSource[-R, -A] { self => */ def batchN(n: Int): DataSource[R, A] = new DataSource[R, A] { - val identifier = s"${self}.batchN($n)" + val identifier = s"$self.batchN($n)" def runAll(requests: Chunk[Chunk[A]])(implicit trace: Trace): ZIO[R, Nothing, CompletedRequestMap] = if (n < 1) ZIO.die(new IllegalArgumentException("batchN: n must be at least 1")) diff --git a/zio-query/shared/src/main/scala/zio/query/Request.scala b/zio-query/shared/src/main/scala/zio/query/Request.scala index 81f9d256..1c1a8814 100644 --- a/zio-query/shared/src/main/scala/zio/query/Request.scala +++ b/zio-query/shared/src/main/scala/zio/query/Request.scala @@ -30,4 +30,4 @@ import zio.stacktracer.TracingImplicits.disableAutoTrace * * }}} */ -trait Request[E, A] +trait Request[E, A] extends Product with Serializable diff --git a/zio-query/shared/src/main/scala/zio/query/ZQuery.scala b/zio-query/shared/src/main/scala/zio/query/ZQuery.scala index cbe68140..98918101 100644 --- a/zio-query/shared/src/main/scala/zio/query/ZQuery.scala +++ b/zio-query/shared/src/main/scala/zio/query/ZQuery.scala @@ -1532,6 +1532,7 @@ object ZQuery { dataSource: DataSource[R, A], request: A )(implicit ev: A <:< Request[E, B], trace: Trace): CachedResult[R, E, B] = { + implicit val unsafe: Unsafe = Unsafe.unsafe def foldPromise(either: Either[Promise[E, B], Promise[E, B]]): CachedResult[R, E, B] = either match { @@ -1543,7 +1544,7 @@ object ZQuery { ) ) case Right(promise) => - promise.unsafe.poll(Unsafe.unsafe) match { + promise.unsafe.poll match { case None => CachedResult.Pure(Result.blocked(BlockedRequests.empty, Continue(promise))) case Some(io: Exit[E, B]) => CachedResult.Pure(Result.fromExit(io)) case Some(io) => CachedResult.Effectful(io.exit.map(Result.fromExit)) @@ -1551,8 +1552,8 @@ object ZQuery { } cache match { - case cache: Cache.Default => foldPromise(cache.lookupUnsafe(request)(Unsafe.unsafe)) - case cache => CachedResult.Effectful(cache.lookup(request).flatMap(foldPromise(_).toZIO)) + case cache: Cache.Default => foldPromise(cache.lookupUnsafe(dataSource, request)) + case cache => CachedResult.Effectful(cache.lookup(dataSource, request).flatMap(foldPromise(_).toZIO)) } } diff --git a/zio-query/shared/src/main/scala/zio/query/internal/BlockedRequests.scala b/zio-query/shared/src/main/scala/zio/query/internal/BlockedRequests.scala index 6646825f..fe749e18 100644 --- a/zio-query/shared/src/main/scala/zio/query/internal/BlockedRequests.scala +++ b/zio-query/shared/src/main/scala/zio/query/internal/BlockedRequests.scala @@ -129,7 +129,7 @@ private[query] sealed trait BlockedRequests[-R] { self => completedRequestsM.remove(req) orElse completedRequests.lookup(req) } // cache responses that were not requested but were completed by the DataSource - if (completedRequestsM.nonEmpty) cacheLeftovers(cache, completedRequestsM) else ZIO.unit + if (completedRequestsM.nonEmpty) cacheLeftovers(cache, dataSource, completedRequestsM) else ZIO.unit } else { // No need to remove entries here since we don't need to know which ones we need to put in the cache ZIO.succeed(completePromises(dataSource, sequential)(completedRequestsM.get)) @@ -295,6 +295,7 @@ private[query] object BlockedRequests { private def cacheLeftovers( cache: Cache, + dataSource: DataSource[?, Any], map: mutable.HashMap[Request[_, _], Exit[Any, Any]] )(implicit trace: Trace): UIO[Unit] = cache match { @@ -302,7 +303,7 @@ private[query] object BlockedRequests { ZIO.succeedUnsafe { implicit unsafe => map.foreach { case (request: Request[Any, Any], exit) => cache - .lookupUnsafe(request) + .lookupUnsafe(dataSource, request) .merge .unsafe .done(exit) @@ -310,7 +311,7 @@ private[query] object BlockedRequests { } case cache => ZIO.foreachDiscard(map) { case (request: Request[Any, Any], exit) => - cache.lookup(request).flatMap(_.merge.done(exit)) + cache.lookup(dataSource, request).flatMap(_.merge.done(exit)) } } } diff --git a/zio-query/shared/src/test/scala/zio/query/ZQuerySpec.scala b/zio-query/shared/src/test/scala/zio/query/ZQuerySpec.scala index 391d018d..cd0495cd 100644 --- a/zio-query/shared/src/test/scala/zio/query/ZQuerySpec.scala +++ b/zio-query/shared/src/test/scala/zio/query/ZQuerySpec.scala @@ -16,6 +16,21 @@ object ZQuerySpec extends ZIOBaseSpec { log <- TestConsole.output } yield assert(log)(hasSize(equalTo(2))) }, + test("N + 1 on composed queries") { + for { + ref <- Ref.make(0) + ds = countingDs(ref) + query1 = identityDs.query(IdReq(1)).flatMap(i => ds.query(CountingReq(i))) + query2 = identityDs.query(IdReq(2)).flatMap(i => ds.query(CountingReq(i))) + res <- (query1 <~> query2).run + i <- ref.get + } yield assertTrue(i == 1, res == ("1", "2")) + }, + test("requests are cached per datasource") { + for { + res <- identityDs.query(IdReq(1)).flatMap(i => plus1Ds.query(IdReq(i))).run + } yield assertTrue(res == 2) + }, test("mapError does not prevent batching") { implicit val canFail = zio.CanFail val a = getUserNameById(1).zip(getUserNameById(2)).mapError(identity) @@ -204,8 +219,9 @@ object ZQuerySpec extends ZIOBaseSpec { } yield res requestResult <- query.runCache(cache) oneToTen = (1 to 10).toList - cachedResults <- ZIO.foreach(oneToTen)(i => cache.get(Req.Get(i)).flatMap(_.await)) - cacheCheck = cachedResults == oneToTen.map(_.toString) + cachedResults <- + ZIO.foreach(oneToTen)(i => cache.get(dsCompletingMoreRequests, Req.Get(i)).flatMap(_.await)) + cacheCheck = cachedResults == oneToTen.map(_.toString) } yield assertTrue(requestResult == "1", cacheCheck) }, test("caching disabled") { @@ -216,7 +232,7 @@ object ZQuerySpec extends ZIOBaseSpec { } yield res requestResult <- query.runCache(cache) oneToTen = (1 to 10).toList - cachedResults <- ZIO.foreach(oneToTen)(i => cache.get(Req.Get(i)).isFailure) + cachedResults <- ZIO.foreach(oneToTen)(i => cache.get(dsCompletingMoreRequests, Req.Get(i)).isFailure) cacheCheck = cachedResults.forall(identity) } yield assertTrue(requestResult == "1", cacheCheck) } @@ -226,7 +242,7 @@ object ZQuerySpec extends ZIOBaseSpec { cache <- zio.query.Cache.empty query = for { _ <- getUserNameById(1) - _ <- ZQuery.fromZIO(cache.remove(GetNameById(1))) + _ <- ZQuery.fromZIO(cache.remove(UserRequestDataSource, GetNameById(1))) _ <- getUserNameById(1) } yield () _ <- query.runCache(cache) @@ -270,7 +286,7 @@ object ZQuerySpec extends ZIOBaseSpec { assert(log)(hasAt(0)(containsString("GetNameById(1)"))) && assert(log)(hasAt(0)(containsString("GetNameById(2)"))) && assert(log)(hasAt(1)(containsString("GetNameById(1)"))) - } @@ nonFlaky, + } @@ nonFlaky(50), suite("race")( test("race with never") { val query = ZQuery.never.race(ZQuery.succeed(())) @@ -395,6 +411,23 @@ object ZQuerySpec extends ZIOBaseSpec { } } + case class IdReq(id: Int) extends Request[Nothing, Int] + val identityDs: DataSource[Any, IdReq] = + DataSource.fromFunctionBatchedZIO[Any, Nothing, IdReq, Int]("IdDs") { requests => + ZIO.succeed(requests.map(_.id)) + } + + val plus1Ds: DataSource[Any, IdReq] = + DataSource.fromFunctionBatchedZIO[Any, Nothing, IdReq, Int]("Plus1Ds") { requests => + ZIO.succeed(requests.map(_.id + 1)) + } + + case class CountingReq(id: Int) extends Request[Nothing, String] + def countingDs(ref: Ref[Int]): DataSource[Any, CountingReq] = + DataSource.fromFunctionBatchedZIO[Any, Nothing, CountingReq, String]("CountingDs") { requests => + ref.update(_ + 1).as(requests.map(_.id.toString)) + } + val getAllUserIds: ZQuery[Any, Nothing, List[Int]] = ZQuery.fromRequest(GetAllIds)(UserRequestDataSource) From 02827c9763bf766e2a8d407a5fc5a9a064cec655 Mon Sep 17 00:00:00 2001 From: Kyri Petrou Date: Mon, 1 Jul 2024 20:06:02 +1000 Subject: [PATCH 2/3] fmt and fix Scala 2.12 --- .../shared/src/main/scala/zio/query/Cache.scala | 1 + .../src/test/scala/zio/query/ZQuerySpec.scala | 15 ++++++++------- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/zio-query/shared/src/main/scala/zio/query/Cache.scala b/zio-query/shared/src/main/scala/zio/query/Cache.scala index 79020014..0ee74ec8 100644 --- a/zio-query/shared/src/main/scala/zio/query/Cache.scala +++ b/zio-query/shared/src/main/scala/zio/query/Cache.scala @@ -82,6 +82,7 @@ abstract class Cache { } object Cache { + /** * Constructs an empty cache. */ diff --git a/zio-query/shared/src/test/scala/zio/query/ZQuerySpec.scala b/zio-query/shared/src/test/scala/zio/query/ZQuerySpec.scala index cd0495cd..21bfb2ee 100644 --- a/zio-query/shared/src/test/scala/zio/query/ZQuerySpec.scala +++ b/zio-query/shared/src/test/scala/zio/query/ZQuerySpec.scala @@ -18,13 +18,14 @@ object ZQuerySpec extends ZIOBaseSpec { }, test("N + 1 on composed queries") { for { - ref <- Ref.make(0) - ds = countingDs(ref) - query1 = identityDs.query(IdReq(1)).flatMap(i => ds.query(CountingReq(i))) - query2 = identityDs.query(IdReq(2)).flatMap(i => ds.query(CountingReq(i))) - res <- (query1 <~> query2).run - i <- ref.get - } yield assertTrue(i == 1, res == ("1", "2")) + ref <- Ref.make(0) + ds = countingDs(ref) + query1 = identityDs.query(IdReq(1)).flatMap(i => ds.query(CountingReq(i))) + query2 = identityDs.query(IdReq(2)).flatMap(i => ds.query(CountingReq(i))) + res <- (query1 <~> query2).run + i <- ref.get + expected = ("1", "2") + } yield assertTrue(i == 1, res == expected) }, test("requests are cached per datasource") { for { From b2a2149b6513b389af794c40a2f7b905028ee4c1 Mon Sep 17 00:00:00 2001 From: Kyri Petrou Date: Mon, 1 Jul 2024 20:09:44 +1000 Subject: [PATCH 3/3] Disable Mima --- build.sbt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/build.sbt b/build.sbt index 1c364149..f2fd39e9 100644 --- a/build.sbt +++ b/build.sbt @@ -106,7 +106,7 @@ lazy val docs = project Global / onChangedBuildSource := ReloadOnSourceChanges -lazy val enforceMimaCompatibility = true // Enable / disable failing CI on binary incompatibilities +lazy val enforceMimaCompatibility = false // Enable / disable failing CI on binary incompatibilities lazy val enableMimaSettingsJVM = Def.settings(