Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Requests should be cached on a per-datasource level #491

Open
wants to merge 3 commits into
base: series/2.x
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ lazy val docs = project

Global / onChangedBuildSource := ReloadOnSourceChanges

lazy val enforceMimaCompatibility = true // Enable / disable failing CI on binary incompatibilities
lazy val enforceMimaCompatibility = false // Enable / disable failing CI on binary incompatibilities

lazy val enableMimaSettingsJVM =
Def.settings(
Expand Down
66 changes: 50 additions & 16 deletions zio-query/shared/src/main/scala/zio/query/Cache.scala
Original file line number Diff line number Diff line change
Expand Up @@ -27,22 +27,31 @@ import java.util.concurrent.ConcurrentHashMap
* executed. This is used internally by the library to provide deduplication and
* caching of requests.
*/
trait Cache {
abstract class Cache {

/**
* Looks up a request in the cache, failing with the unit value if the request
* is not in the cache or succeeding with a `Promise` if the request is in the
* cache that will contain the result of the request when it is executed.
*/
def get[E, A](request: Request[E, A])(implicit trace: Trace): IO[Unit, Promise[E, A]]
def get[E, A, B](
datasource: DataSource[?, A],
request: A
)(implicit
trace: Trace,
ev: A <:< Request[E, B]
): IO[Unit, Promise[E, B]]

/**
* Looks up a request in the cache. If the request is not in the cache returns
* a `Left` with a `Promise` that can be completed to complete the request. If
* the request is in the cache returns a `Right` with a `Promise` that will
* contain the result of the request when it is executed.
*/
def lookup[E, A, B](request: A)(implicit
def lookup[E, A, B](
dataSource: DataSource[?, A],
request: A
)(implicit
ev: A <:< Request[E, B],
trace: Trace
): UIO[Either[Promise[E, B], Promise[E, B]]]
Expand All @@ -51,12 +60,25 @@ trait Cache {
* Inserts a request and a `Promise` that will contain the result of the
* request when it is executed into the cache.
*/
def put[E, A](request: Request[E, A], result: Promise[E, A])(implicit trace: Trace): UIO[Unit]
def put[E, A, B](
dataSource: DataSource[?, A],
request: A,
result: Promise[E, B]
)(implicit
ev: A <:< Request[E, B],
trace: Trace
): UIO[Unit]

/**
* Removes a request from the cache.
*/
def remove[E, A](request: Request[E, A])(implicit trace: Trace): UIO[Unit]
def remove[E, A, B](
dataSource: DataSource[?, A],
request: A
)(implicit
ev: A <:< Request[E, B],
trace: Trace
): UIO[Unit]
}

object Cache {
Expand All @@ -74,33 +96,45 @@ object Cache {
def empty(expectedNumOfElements: Int)(implicit trace: Trace): UIO[Cache] =
ZIO.succeed(Cache.unsafeMake(expectedNumOfElements))

private[query] final class Default(private val map: ConcurrentHashMap[Request[_, _], Promise[_, _]]) extends Cache {
private type Key = (DataSource[_, _], Request[_, _])

private[query] final class Default(private val map: ConcurrentHashMap[Key, Promise[_, _]]) extends Cache {

def get[E, A](request: Request[E, A])(implicit trace: Trace): IO[Unit, Promise[E, A]] =
def get[E, A, B](ds: DataSource[?, A], request: A)(implicit
trace: Trace,
ev: A <:< Request[E, B]
): IO[Unit, Promise[E, B]] =
ZIO.suspendSucceed {
val out = map.get(request).asInstanceOf[Promise[E, A]]
val out = map.get((ds, request)).asInstanceOf[Promise[E, B]]
if (out eq null) Exit.fail(()) else Exit.succeed(out)
}

def lookup[E, A, B](request: A)(implicit
def lookup[E, A, B](ds: DataSource[?, A], request: A)(implicit
ev: A <:< Request[E, B],
trace: Trace
): UIO[Either[Promise[E, B], Promise[E, B]]] =
ZIO.succeed(lookupUnsafe(request)(Unsafe.unsafe))
ZIO.succeed(lookupUnsafe(ds, request)(ev, Unsafe.unsafe))

def lookupUnsafe[E, A, B](request: Request[_, _])(implicit
def lookupUnsafe[E, A, B](ds: DataSource[?, A], request: A)(implicit
ev: A <:< Request[E, B],
unsafe: Unsafe
): Either[Promise[E, B], Promise[E, B]] = {
val newPromise = Promise.unsafe.make[E, B](FiberId.None)
val existing = map.putIfAbsent(request, newPromise).asInstanceOf[Promise[E, B]]
val existing = map.putIfAbsent((ds, request), newPromise).asInstanceOf[Promise[E, B]]
if (existing eq null) Left(newPromise) else Right(existing)
}

def put[E, A](request: Request[E, A], result: Promise[E, A])(implicit trace: Trace): UIO[Unit] =
ZIO.succeed(map.put(request, result))
def put[E, A, B](ds: DataSource[?, A], request: A, result: Promise[E, B])(implicit
ev: A <:< Request[E, B],
trace: Trace
): UIO[Unit] =
ZIO.succeed(map.put((ds, request), result))

def remove[E, A](request: Request[E, A])(implicit trace: Trace): UIO[Unit] =
ZIO.succeed(map.remove(request))
def remove[E, A, B](ds: DataSource[?, A], request: A)(implicit
ev: A <:< Request[E, B],
trace: Trace
): UIO[Unit] =
ZIO.succeed(map.remove((ds, request)))
}

// TODO: Initialize the map with a sensible default value. Default is 16, which seems way too small for a cache
Expand Down
4 changes: 2 additions & 2 deletions zio-query/shared/src/main/scala/zio/query/DataSource.scala
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ import zio.{Chunk, Exit, Trace, ZEnvironment, ZIO}
* for all requests received. Failure to do so will cause a query to die with a
* `QueryFailure` when run.
*/
trait DataSource[-R, -A] { self =>
trait DataSource[-R, -A] extends Serializable { self =>

/**
* Syntax for adding aspects.
Expand All @@ -65,7 +65,7 @@ trait DataSource[-R, -A] { self =>
*/
def batchN(n: Int): DataSource[R, A] =
new DataSource[R, A] {
val identifier = s"${self}.batchN($n)"
val identifier = s"$self.batchN($n)"
def runAll(requests: Chunk[Chunk[A]])(implicit trace: Trace): ZIO[R, Nothing, CompletedRequestMap] =
if (n < 1)
ZIO.die(new IllegalArgumentException("batchN: n must be at least 1"))
Expand Down
2 changes: 1 addition & 1 deletion zio-query/shared/src/main/scala/zio/query/Request.scala
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,4 @@ import zio.stacktracer.TracingImplicits.disableAutoTrace
*
* }}}
*/
trait Request[E, A]
trait Request[E, A] extends Product with Serializable
7 changes: 4 additions & 3 deletions zio-query/shared/src/main/scala/zio/query/ZQuery.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1532,6 +1532,7 @@ object ZQuery {
dataSource: DataSource[R, A],
request: A
)(implicit ev: A <:< Request[E, B], trace: Trace): CachedResult[R, E, B] = {
implicit val unsafe: Unsafe = Unsafe.unsafe

def foldPromise(either: Either[Promise[E, B], Promise[E, B]]): CachedResult[R, E, B] =
either match {
Expand All @@ -1543,16 +1544,16 @@ object ZQuery {
)
)
case Right(promise) =>
promise.unsafe.poll(Unsafe.unsafe) match {
promise.unsafe.poll match {
case None => CachedResult.Pure(Result.blocked(BlockedRequests.empty, Continue(promise)))
case Some(io: Exit[E, B]) => CachedResult.Pure(Result.fromExit(io))
case Some(io) => CachedResult.Effectful(io.exit.map(Result.fromExit))
}
}

cache match {
case cache: Cache.Default => foldPromise(cache.lookupUnsafe(request)(Unsafe.unsafe))
case cache => CachedResult.Effectful(cache.lookup(request).flatMap(foldPromise(_).toZIO))
case cache: Cache.Default => foldPromise(cache.lookupUnsafe(dataSource, request))
case cache => CachedResult.Effectful(cache.lookup(dataSource, request).flatMap(foldPromise(_).toZIO))
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ private[query] sealed trait BlockedRequests[-R] { self =>
completedRequestsM.remove(req) orElse completedRequests.lookup(req)
}
// cache responses that were not requested but were completed by the DataSource
if (completedRequestsM.nonEmpty) cacheLeftovers(cache, completedRequestsM) else ZIO.unit
if (completedRequestsM.nonEmpty) cacheLeftovers(cache, dataSource, completedRequestsM) else ZIO.unit
} else {
// No need to remove entries here since we don't need to know which ones we need to put in the cache
ZIO.succeed(completePromises(dataSource, sequential)(completedRequestsM.get))
Expand Down Expand Up @@ -295,22 +295,23 @@ private[query] object BlockedRequests {

private def cacheLeftovers(
cache: Cache,
dataSource: DataSource[?, Any],
map: mutable.HashMap[Request[_, _], Exit[Any, Any]]
)(implicit trace: Trace): UIO[Unit] =
cache match {
case cache: Cache.Default =>
ZIO.succeedUnsafe { implicit unsafe =>
map.foreach { case (request: Request[Any, Any], exit) =>
cache
.lookupUnsafe(request)
.lookupUnsafe(dataSource, request)
.merge
.unsafe
.done(exit)
}
}
case cache =>
ZIO.foreachDiscard(map) { case (request: Request[Any, Any], exit) =>
cache.lookup(request).flatMap(_.merge.done(exit))
cache.lookup(dataSource, request).flatMap(_.merge.done(exit))
}
}
}
44 changes: 39 additions & 5 deletions zio-query/shared/src/test/scala/zio/query/ZQuerySpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,22 @@ object ZQuerySpec extends ZIOBaseSpec {
log <- TestConsole.output
} yield assert(log)(hasSize(equalTo(2)))
},
test("N + 1 on composed queries") {
for {
ref <- Ref.make(0)
ds = countingDs(ref)
query1 = identityDs.query(IdReq(1)).flatMap(i => ds.query(CountingReq(i)))
query2 = identityDs.query(IdReq(2)).flatMap(i => ds.query(CountingReq(i)))
res <- (query1 <~> query2).run
i <- ref.get
expected = ("1", "2")
} yield assertTrue(i == 1, res == expected)
},
test("requests are cached per datasource") {
for {
res <- identityDs.query(IdReq(1)).flatMap(i => plus1Ds.query(IdReq(i))).run
} yield assertTrue(res == 2)
},
test("mapError does not prevent batching") {
implicit val canFail = zio.CanFail
val a = getUserNameById(1).zip(getUserNameById(2)).mapError(identity)
Expand Down Expand Up @@ -204,8 +220,9 @@ object ZQuerySpec extends ZIOBaseSpec {
} yield res
requestResult <- query.runCache(cache)
oneToTen = (1 to 10).toList
cachedResults <- ZIO.foreach(oneToTen)(i => cache.get(Req.Get(i)).flatMap(_.await))
cacheCheck = cachedResults == oneToTen.map(_.toString)
cachedResults <-
ZIO.foreach(oneToTen)(i => cache.get(dsCompletingMoreRequests, Req.Get(i)).flatMap(_.await))
cacheCheck = cachedResults == oneToTen.map(_.toString)
} yield assertTrue(requestResult == "1", cacheCheck)
},
test("caching disabled") {
Expand All @@ -216,7 +233,7 @@ object ZQuerySpec extends ZIOBaseSpec {
} yield res
requestResult <- query.runCache(cache)
oneToTen = (1 to 10).toList
cachedResults <- ZIO.foreach(oneToTen)(i => cache.get(Req.Get(i)).isFailure)
cachedResults <- ZIO.foreach(oneToTen)(i => cache.get(dsCompletingMoreRequests, Req.Get(i)).isFailure)
cacheCheck = cachedResults.forall(identity)
} yield assertTrue(requestResult == "1", cacheCheck)
}
Expand All @@ -226,7 +243,7 @@ object ZQuerySpec extends ZIOBaseSpec {
cache <- zio.query.Cache.empty
query = for {
_ <- getUserNameById(1)
_ <- ZQuery.fromZIO(cache.remove(GetNameById(1)))
_ <- ZQuery.fromZIO(cache.remove(UserRequestDataSource, GetNameById(1)))
_ <- getUserNameById(1)
} yield ()
_ <- query.runCache(cache)
Expand Down Expand Up @@ -270,7 +287,7 @@ object ZQuerySpec extends ZIOBaseSpec {
assert(log)(hasAt(0)(containsString("GetNameById(1)"))) &&
assert(log)(hasAt(0)(containsString("GetNameById(2)"))) &&
assert(log)(hasAt(1)(containsString("GetNameById(1)")))
} @@ nonFlaky,
} @@ nonFlaky(50),
suite("race")(
test("race with never") {
val query = ZQuery.never.race(ZQuery.succeed(()))
Expand Down Expand Up @@ -395,6 +412,23 @@ object ZQuerySpec extends ZIOBaseSpec {
}
}

case class IdReq(id: Int) extends Request[Nothing, Int]
val identityDs: DataSource[Any, IdReq] =
DataSource.fromFunctionBatchedZIO[Any, Nothing, IdReq, Int]("IdDs") { requests =>
ZIO.succeed(requests.map(_.id))
}

val plus1Ds: DataSource[Any, IdReq] =
DataSource.fromFunctionBatchedZIO[Any, Nothing, IdReq, Int]("Plus1Ds") { requests =>
ZIO.succeed(requests.map(_.id + 1))
}

case class CountingReq(id: Int) extends Request[Nothing, String]
def countingDs(ref: Ref[Int]): DataSource[Any, CountingReq] =
DataSource.fromFunctionBatchedZIO[Any, Nothing, CountingReq, String]("CountingDs") { requests =>
ref.update(_ + 1).as(requests.map(_.id.toString))
}

val getAllUserIds: ZQuery[Any, Nothing, List[Int]] =
ZQuery.fromRequest(GetAllIds)(UserRequestDataSource)

Expand Down
Loading