
[CHIP-1] Cache batch IRs in the Fetcher #682

Merged Jun 27, 2024 · 23 commits (changes shown from 12 commits)
3933137 Add LRU Cache (caiocamatta-stripe, Feb 12, 2024)
ab1d2b3 Double gauge (caiocamatta-stripe, Feb 12, 2024)
355f101 Revert accidental build sbt change (caiocamatta-stripe, Feb 12, 2024)
b52d988 Move GroupByRequestMeta to object (caiocamatta-stripe, Feb 12, 2024)
a28ce28 Add FetcherCache and tests (caiocamatta-stripe, Feb 12, 2024)
d7801bb Update FetcherBase to use cache (caiocamatta-stripe, Feb 21, 2024)
9f711d5 Scala 2.13 support? (caiocamatta-stripe, Feb 22, 2024)
df3242c Add getServingInfo unit tests (caiocamatta-stripe, Feb 26, 2024)
89d68bb Refactor getServingInfo tests (caiocamatta-stripe, Feb 28, 2024)
fd0df49 Fix stub (caiocamatta-stripe, Feb 28, 2024)
77c711a Fix Mockito "ambiguous reference to overloaded definition" error (caiocamatta-stripe, Apr 2, 2024)
0789bea Fix "Both batch and streaming data are null" check (caiocamatta-stripe, Apr 2, 2024)
acbe30a Address PR review: add comments and use logger (caiocamatta-stripe, Apr 3, 2024)
7f3683e Merge main (caiocamatta-stripe, May 9, 2024)
80bd7cf Fewer comments for constructGroupByResponse (caiocamatta-stripe, May 9, 2024)
2873a67 Use the FlagStore to determine if caching is enabled (caiocamatta-stripe, May 9, 2024)
d024746 Apply suggestions from code review (caiocamatta-stripe, Jun 21, 2024)
abd0b66 Merge branch 'main' into caiocamatta--fetcher-batch-ir-caching-oss (caiocamatta-stripe, Jun 24, 2024)
b9b54a0 Address review, add comments, rename tests (caiocamatta-stripe, Jun 24, 2024)
db950fc Change test names (caiocamatta-stripe, Jun 24, 2024)
36c4ba4 CamelCase FetcherBaseTest (caiocamatta-stripe, Jun 24, 2024)
581d427 Update online/src/main/scala/ai/chronon/online/FetcherBase.scala (caiocamatta-stripe, Jun 26, 2024)
0d8d6f4 fmt (caiocamatta-stripe, Jun 26, 2024)
6 changes: 4 additions & 2 deletions build.sbt
@@ -293,7 +293,8 @@ lazy val online = project
// statsd 3.0 has local aggregation - TODO: upgrade
"com.datadoghq" % "java-dogstatsd-client" % "2.7",
"org.rogach" %% "scallop" % "4.0.1",
"net.jodah" % "typetools" % "0.4.1"
"net.jodah" % "typetools" % "0.4.1",
"com.github.ben-manes.caffeine" % "caffeine" % "2.9.3"
),
libraryDependencies ++= fromMatrix(scalaVersion.value, "spark-all", "scala-parallel-collections", "netty-buffer")
)
@@ -308,7 +309,8 @@ lazy val online_unshaded = (project in file("online"))
// statsd 3.0 has local aggregation - TODO: upgrade
"com.datadoghq" % "java-dogstatsd-client" % "2.7",
"org.rogach" %% "scallop" % "4.0.1",
"net.jodah" % "typetools" % "0.4.1"
"net.jodah" % "typetools" % "0.4.1",
"com.github.ben-manes.caffeine" % "caffeine" % "2.9.3"
),
libraryDependencies ++= fromMatrix(scalaVersion.value,
"jackson",
259 changes: 180 additions & 79 deletions online/src/main/scala/ai/chronon/online/FetcherBase.scala

Large diffs are not rendered by default.

215 changes: 215 additions & 0 deletions online/src/main/scala/ai/chronon/online/FetcherCache.scala
@@ -0,0 +1,215 @@
package ai.chronon.online

import ai.chronon.aggregator.windowing.FinalBatchIr
import ai.chronon.api.Extensions.MetadataOps
import ai.chronon.api.GroupBy
import ai.chronon.online.FetcherBase.GroupByRequestMeta
import ai.chronon.online.Fetcher.Request
import ai.chronon.online.FetcherCache.{
BatchIrCache,
BatchResponses,
CachedBatchResponse,
CachedFinalIrBatchResponse,
CachedMapBatchResponse,
KvStoreBatchResponse
}
import ai.chronon.online.KVStore.{GetRequest, TimedValue}
import com.github.benmanes.caffeine.cache.{Cache => CaffeineCache}

import scala.util.{Success, Try}
import java.util.concurrent.ConcurrentHashMap
import scala.collection.JavaConverters.mapAsScalaConcurrentMapConverter
import scala.collection.Seq

/*
* FetcherCache is an extension to FetcherBase that provides caching functionality. It caches KV store
* requests to decrease feature serving latency.
 */
trait FetcherCache {
Review comment (caiocamatta-stripe, author):

Something I considered months ago when making these code changes was to use a class (like "FetcherBaseCached") instead of a trait. This class would inherit from FetcherBase and override its methods to add caching. That would arguably be slightly cleaner because it'd give us two completely separate versions of the Fetcher; users would be able to use the normal/old Fetcher with no caching if they wanted.

This would require a significant amount of additional refactoring and re-testing, and I don't think it's worth it. Ideally, once this IR cache is merged in, it becomes a core part of the fetcher that users can enable/disable for their GroupBys as necessary. We've already tested the status quo (no caching), so IMO the risks that arise from additional refactors would outweigh the benefits of having a separate class.

Review comment (collaborator):

Another challenge with the FetcherBaseCached approach is that it will lead to a profusion of FetcherBaseX and FetcherBaseY classes as we keep adding functionality to the fetcher in the future.

Review comment (collaborator):

Plus one to Piyush's point here.

val batchIrCacheName = "batch_cache"
val maybeBatchIrCache: Option[BatchIrCache] =
Option(System.getProperty("ai.chronon.fetcher.batch_ir_cache_size"))
Review comment (collaborator):

What is the unit for the size? MB or number of elements?

Review comment (caiocamatta-stripe, author):

Great point. It's in elements. I'll update to make that clearer.

.map(size => new BatchIrCache(batchIrCacheName, size.toInt))
.orElse(None)
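
The cache size is a count of entries, not bytes. As a minimal sketch of how a deployment would turn the cache on (the property name comes from the code above; the value is only an example):

    // Set before the Fetcher is constructed, since maybeBatchIrCache is a val
    // evaluated at construction time. Sized in elements, not bytes.
    System.setProperty("ai.chronon.fetcher.batch_ir_cache_size", "10000")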

def isCacheSizeConfigured: Boolean = maybeBatchIrCache.isDefined

// Memoize which GroupBys have caching enabled
private[online] val isCachingEnabledForGroupBy: collection.concurrent.Map[String, Boolean] =
new ConcurrentHashMap[String, Boolean]().asScala

def isCachingEnabled(groupBy: GroupBy): Boolean = {
if (!isCacheSizeConfigured || groupBy.getMetaData == null || groupBy.getMetaData.getName == null) return false

val groupByName = groupBy.getMetaData.getName
isCachingEnabledForGroupBy.getOrElse(
groupByName, {
groupBy.getMetaData.customJsonLookUp("enable_caching") match {
case b: Boolean =>
println(s"Caching is ${if (b) "enabled" else "disabled"} for $groupByName")
Review comment (collaborator):

Switch to log (here and others?)

isCachingEnabledForGroupBy.putIfAbsent(groupByName, b)
b
case null =>
println(s"Caching is disabled for $groupByName, enable_caching is not set.")
isCachingEnabledForGroupBy.putIfAbsent(groupByName, false)
false
case _ => false
}
}
)
}
Review comment (caiocamatta-stripe, author):

Once the "create feature flags function api" change (#686) lands, we can modify isCachingEnabled to use feature flags instead. That's much better than a parameter in the GroupBy definition, as it allows you to make immediate changes instead of having to modify your GroupBy to enable/disable caching.
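
For illustration, here is how a GroupBy would opt in via the "enable_caching" key that customJsonLookUp reads above. A minimal sketch using the thrift-generated api classes; the setter names and the GroupBy name are assumptions:

    import ai.chronon.api.{GroupBy, MetaData}

    val metaData = new MetaData()
    metaData.setName("my_team.my_group_by") // hypothetical GroupBy name
    metaData.setCustomJson("""{"enable_caching": true}""")
    val groupBy = new GroupBy()
    groupBy.setMetaData(metaData)
    // isCachingEnabled(groupBy) now returns true and memoizes the result per GroupBy name.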


protected val caffeineMetricsContext: Metrics.Context = Metrics.Context(Metrics.Environment.JoinFetching)

/**
* Obtain the Map[String, AnyRef] response from a batch response.
*
* If batch IR caching is enabled, this method will try to fetch the IR from the cache. If it's not in the cache,
* it will decode it from the batch bytes and store it.
*
* @param batchResponses the batch responses
* @param batchBytes the batch bytes corresponding to the batchResponses. Can be `null`.
* @param servingInfo the GroupByServingInfoParsed that contains the info to decode the bytes
* @param decodingFunction the function to decode bytes into Map[String, AnyRef]
* @param keys the keys used to fetch this particular batch response, for caching purposes
*/
private[online] def getMapResponseFromBatchResponse(batchResponses: BatchResponses,
batchBytes: Array[Byte],
decodingFunction: Array[Byte] => Map[String, AnyRef],
servingInfo: GroupByServingInfoParsed,
keys: Map[String, Any]): Map[String, AnyRef] = {
if (!isCachingEnabled(servingInfo.groupBy)) return decodingFunction(batchBytes)

batchResponses match {
case _: KvStoreBatchResponse =>
val batchRequestCacheKey =
BatchIrCache.Key(servingInfo.groupByOps.batchDataset, keys, servingInfo.batchEndTsMillis)
val decodedBytes = decodingFunction(batchBytes)
if (decodedBytes != null)
Review comment (caiocamatta-stripe, author):

Caffeine doesn't allow storing null, and in general I decided to exclude negative caching for now. Most likely, the keys that end up cached are not new keys and do have batch data (e.g. big merchants, or power users).

maybeBatchIrCache.get.cache.put(batchRequestCacheKey, CachedMapBatchResponse(decodedBytes))
decodedBytes
case cachedResponse: CachedBatchResponse =>
cachedResponse match {
case CachedFinalIrBatchResponse(_: FinalBatchIr) => decodingFunction(batchBytes)
case CachedMapBatchResponse(mapResponse: Map[String, AnyRef]) => mapResponse
}
}
}

/**
* Obtain the FinalBatchIr from a batch response.
*
* If batch IR caching is enabled, this method will try to fetch the IR from the cache. If it's not in the cache,
* it will decode it from the batch bytes and store it.
*
* @param batchResponses the batch responses
* @param batchBytes the batch bytes corresponding to the batchResponses. Can be `null`.
* @param servingInfo the GroupByServingInfoParsed that contains the info to decode the bytes
* @param decodingFunction the function to decode bytes into FinalBatchIr
* @param keys the keys used to fetch this particular batch response, for caching purposes
*/
private[online] def getBatchIrFromBatchResponse(
batchResponses: BatchResponses,
batchBytes: Array[Byte],
servingInfo: GroupByServingInfoParsed,
decodingFunction: (Array[Byte], GroupByServingInfoParsed) => FinalBatchIr,
keys: Map[String, Any]): FinalBatchIr = {
if (!isCachingEnabled(servingInfo.groupBy)) return decodingFunction(batchBytes, servingInfo)

batchResponses match {
case _: KvStoreBatchResponse =>
val batchRequestCacheKey =
BatchIrCache.Key(servingInfo.groupByOps.batchDataset, keys, servingInfo.batchEndTsMillis)
val decodedBytes = decodingFunction(batchBytes, servingInfo)
if (decodedBytes != null)
maybeBatchIrCache.get.cache.put(batchRequestCacheKey, CachedFinalIrBatchResponse(decodedBytes))
decodedBytes
case cachedResponse: CachedBatchResponse =>
cachedResponse match {
case CachedFinalIrBatchResponse(finalBatchIr: FinalBatchIr) => finalBatchIr
case CachedMapBatchResponse(_: Map[String, AnyRef]) => decodingFunction(batchBytes, servingInfo)
}
}
}

/**
* Given a list of GetRequests, return a map of GetRequests to cached FinalBatchIrs.
*/
def getCachedRequests(
groupByRequestToKvRequest: Seq[(Request, Try[GroupByRequestMeta])]): Map[GetRequest, CachedBatchResponse] = {
if (!isCacheSizeConfigured) return Map.empty

groupByRequestToKvRequest
.map {
case (request, Success(GroupByRequestMeta(servingInfo, batchRequest, _, _, _))) =>
if (!isCachingEnabled(servingInfo.groupBy)) { Map.empty }
else {
val batchRequestCacheKey =
BatchIrCache.Key(batchRequest.dataset, request.keys, servingInfo.batchEndTsMillis)

// Metrics so we can get per-groupby cache metrics
val metricsContext =
request.context.getOrElse(Metrics.Context(Metrics.Environment.JoinFetching, servingInfo.groupBy))

maybeBatchIrCache.get.cache.getIfPresent(batchRequestCacheKey) match {
case null =>
metricsContext.increment(s"${batchIrCacheName}_gb_misses")
val emptyMap: Map[GetRequest, CachedBatchResponse] = Map.empty
emptyMap
case cachedIr: CachedBatchResponse =>
metricsContext.increment(s"${batchIrCacheName}_gb_hits")
Map(batchRequest -> cachedIr)
}
}
case _ =>
val emptyMap: Map[GetRequest, CachedBatchResponse] = Map.empty
emptyMap
}
.foldLeft(Map.empty[GetRequest, CachedBatchResponse])(_ ++ _)
}
}

object FetcherCache {
private[online] class BatchIrCache(val cacheName: String, val maximumSize: Int = 10000) {
import BatchIrCache._

val cache: CaffeineCache[Key, Value] =
LRUCache[Key, Value](cacheName = cacheName, maximumSize = maximumSize)
}

private[online] object BatchIrCache {
// We use the dataset, keys, and batchEndTsMillis to identify a batch request.
// There's one edge case to be aware of: if a batch job is re-run in the same day, the batchEndTsMillis will
// be the same but the underlying data may have changed. If that new batch data is needed immediately, the
// Fetcher service should be restarted.
case class Key(dataset: String, keys: Map[String, Any], batchEndTsMillis: Long)

// FinalBatchIr is for GroupBys using temporally accurate aggregation.
// Map[String, Any] is for GroupBys using snapshot accurate aggregation or no aggregation.
type Value = BatchResponses
}

// BatchResponses encapsulates either a batch response from kv store or a cached batch response.
sealed abstract class BatchResponses {
def getBatchBytes(batchEndTsMillis: Long): Array[Byte]
}
object BatchResponses {
def apply(kvStoreResponse: Try[Seq[TimedValue]]): KvStoreBatchResponse = KvStoreBatchResponse(kvStoreResponse)
def apply(cachedResponse: FinalBatchIr): CachedFinalIrBatchResponse = CachedFinalIrBatchResponse(cachedResponse)
def apply(cachedResponse: Map[String, AnyRef]): CachedMapBatchResponse = CachedMapBatchResponse(cachedResponse)
}
case class KvStoreBatchResponse(response: Try[Seq[TimedValue]]) extends BatchResponses {
def getBatchBytes(batchEndTsMillis: Long): Array[Byte] =
response
.map(_.maxBy(_.millis))
.filter(_.millis >= batchEndTsMillis)
.map(_.bytes)
.getOrElse(null)
}
sealed abstract class CachedBatchResponse extends BatchResponses {
// This is the case where we don't have bytes because the decoded IR was cached so we didn't hit the KV store again.
def getBatchBytes(batchEndTsMillis: Long): Null = null
}
case class CachedFinalIrBatchResponse(response: FinalBatchIr) extends CachedBatchResponse
case class CachedMapBatchResponse(response: Map[String, AnyRef]) extends CachedBatchResponse
}
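
A short sketch of how the two response flavors behave, using the definitions above (the values are hypothetical; Success and TimedValue are already imported at the top of this file):

    val fromKv = BatchResponses(Success(Seq(TimedValue(Array[Byte](1), 2000L))))
    fromKv.getBatchBytes(batchEndTsMillis = 1000L) // bytes of the latest value, since 2000 >= 1000
    fromKv.getBatchBytes(batchEndTsMillis = 3000L) // null: the KV data predates the batch end

    // Cached responses never return bytes: BatchResponses(someFinalBatchIr).getBatchBytes(...)
    // is always null, which is how call sites know the decoded IR can be reused directly.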
57 changes: 57 additions & 0 deletions online/src/main/scala/ai/chronon/online/LRUCache.scala
@@ -0,0 +1,57 @@
package ai.chronon.online

import com.github.benmanes.caffeine.cache.{Caffeine, Cache => CaffeineCache}

/**
* Utility to create a cache with LRU semantics.
*
* The original purpose of having an LRU cache in Chronon is to cache KVStore calls and decoded IRs
 * in the Fetcher. This helps decrease feature serving latency.
*/
object LRUCache {
Review comment (caiocamatta-stripe, author):

Caffeine is technically not LRU, but it's similar. I think naming this LRUCache makes it easy to understand what it does (and to contrast it with the existing TTLCache).


/**
* Build a bounded, thread-safe Caffeine cache that stores KEY-VALUE pairs.
*
* @param cacheName Name of the cache
* @param maximumSize Maximum number of entries in the cache
* @tparam KEY The type of the key used to access the cache
* @tparam VALUE The type of the value stored in the cache
* @return Caffeine cache
*/
def apply[KEY <: Object, VALUE <: Object](cacheName: String, maximumSize: Int = 10000): CaffeineCache[KEY, VALUE] = {
buildCaffeineCache[KEY, VALUE](cacheName, maximumSize)
}

private def buildCaffeineCache[KEY <: Object, VALUE <: Object](
cacheName: String,
maximumSize: Int = 10000): CaffeineCache[KEY, VALUE] = {
println(s"Chronon Cache build started. cacheName=$cacheName")
val cache: CaffeineCache[KEY, VALUE] = Caffeine
.newBuilder()
.maximumSize(maximumSize)
.recordStats()
.build[KEY, VALUE]()
println(s"Chronon Cache build finished. cacheName=$cacheName")
cache
}

/**
* Report metrics for a Caffeine cache. The "cache" tag is added to all metrics.
*
* @param metricsContext Metrics.Context for recording metrics
* @param cache Caffeine cache to get metrics from
* @param cacheName Cache name for tagging
*/
def collectCaffeineCacheMetrics(metricsContext: Metrics.Context,
cache: CaffeineCache[_, _],
cacheName: String): Unit = {
val stats = cache.stats()
metricsContext.gauge(s"$cacheName.hits", stats.hitCount())
metricsContext.gauge(s"$cacheName.misses", stats.missCount())
metricsContext.gauge(s"$cacheName.evictions", stats.evictionCount())
metricsContext.gauge(s"$cacheName.loads", stats.loadCount())
metricsContext.gauge(s"$cacheName.hit_rate", stats.hitRate())
metricsContext.gauge(s"$cacheName.average_load_penalty", stats.averageLoadPenalty())
}
}
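
A quick usage sketch under the definitions above (key and value types must be reference types; metricsContext stands in for an existing Metrics.Context):

    val cache = LRUCache[String, java.lang.Long]("example_cache", maximumSize = 100)
    cache.put("user_123", 42L)
    Option(cache.getIfPresent("user_123")) // Some(42); getIfPresent returns null on a miss
    Option(cache.getIfPresent("user_456")) // None, and recorded as a miss in cache.stats()

    // Export the Caffeine stats as gauges, e.g. from a periodic reporting thread:
    LRUCache.collectCaffeineCacheMetrics(metricsContext, cache, "example_cache")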
2 changes: 2 additions & 0 deletions online/src/main/scala/ai/chronon/online/Metrics.scala
@@ -196,6 +196,8 @@ object Metrics {

def gauge(metric: String, value: Long): Unit = stats.gauge(prefix(metric), value, tags)

def gauge(metric: String, value: Double): Unit = stats.gauge(prefix(metric), value, tags)

def toTags: Array[String] = {
val joinNames: Array[String] = Option(join).map(_.split(",")).getOrElse(Array.empty[String]).map(_.sanitize)
assert(
54 changes: 51 additions & 3 deletions online/src/test/scala/ai/chronon/online/FetcherBaseTest.scala
@@ -16,7 +16,12 @@

package ai.chronon.online

import ai.chronon.aggregator.windowing.FinalBatchIr
import ai.chronon.api.Extensions.GroupByOps
import ai.chronon.api.MetaData
import ai.chronon.online.Fetcher.{ColumnSpec, Request, Response}
import ai.chronon.online.FetcherCache.BatchResponses
import ai.chronon.online.KVStore.TimedValue
import org.junit.{Before, Test}
import org.mockito.ArgumentMatchers.any
import org.mockito.Mockito._
@@ -29,8 +34,9 @@ import org.scalatestplus.mockito.MockitoSugar
import scala.concurrent.duration.DurationInt
import scala.concurrent.{Await, ExecutionContext, Future}
import scala.util.{Failure, Success}
import scala.util.Try

class FetcherBaseTest extends MockitoSugar with Matchers {
class FetcherBaseTest extends MockitoSugar with Matchers with MockitoHelper {
val GroupBy = "relevance.short_term_user_features"
val Column = "pdp_view_count_14d"
val GuestKey = "guest"
Expand Down Expand Up @@ -118,7 +124,7 @@ class FetcherBaseTest extends MockitoSugar with Matchers {
// Fetch a single query
val keyMap = Map(GuestKey -> GuestId)
val query = ColumnSpec(GroupBy, Column, None, Some(keyMap))

doAnswer(new Answer[Future[Seq[Fetcher.Response]]] {
def answer(invocation: InvocationOnMock): Future[Seq[Response]] = {
Future.successful(Seq())
@@ -130,7 +136,7 @@ class FetcherBaseTest extends MockitoSugar with Matchers {
queryResults.contains(query) shouldBe true
queryResults.get(query).map(_.values) match {
case Some(Failure(ex: IllegalStateException)) => succeed
case _                                        => fail()
}

// GroupBy request sent to KV store for the query
@@ -141,4 +147,46 @@ class FetcherBaseTest extends MockitoSugar with Matchers {
actualRequest.get.name shouldBe query.groupByName + "." + query.columnName
actualRequest.get.keys shouldBe query.keyMapping.get
}

// updateServingInfo() is called when the batch response is from the KV store.
@Test
def test_getServingInfo_ShouldCallUpdateServingInfoIfBatchResponseIsFromKvStore(): Unit = {
Review comment (collaborator):

Same as the other one, could you make it camelCase only, to be consistent with the naming convention? Thank you!

val oldServingInfo = mock[GroupByServingInfoParsed]
val updatedServingInfo = mock[GroupByServingInfoParsed]
doReturn(updatedServingInfo).when(fetcherBase).updateServingInfo(any(), any())

val batchTimedValuesSuccess = Success(Seq(TimedValue(Array(1.toByte), 2000L)))
val kvStoreBatchResponses = BatchResponses(batchTimedValuesSuccess)

val result = fetcherBase.getServingInfo(oldServingInfo, kvStoreBatchResponses)

// updateServingInfo is called
result shouldEqual updatedServingInfo
verify(fetcherBase).updateServingInfo(any(), any())
}

// If a batch response is cached, the serving info should be refreshed. This is needed to prevent
// the serving info from becoming stale if all the requests are cached.
@Test
def test_getServingInfo_ShouldRefreshServingInfoIfBatchResponseIsCached(): Unit = {
val ttlCache = mock[TTLCache[String, Try[GroupByServingInfoParsed]]]
doReturn(ttlCache).when(fetcherBase).getGroupByServingInfo

val oldServingInfo = mock[GroupByServingInfoParsed]
doReturn(Success(oldServingInfo)).when(ttlCache).refresh(any[String])

val metaDataMock = mock[MetaData]
val groupByOpsMock = mock[GroupByOps]
metaDataMock.name = "test"
groupByOpsMock.metaData = metaDataMock
doReturn(groupByOpsMock).when(oldServingInfo).groupByOps

val cachedBatchResponses = BatchResponses(mock[FinalBatchIr])
val result = fetcherBase.getServingInfo(oldServingInfo, cachedBatchResponses)

// FetcherBase.updateServingInfo is not called, but getGroupByServingInfo.refresh() is.
result shouldEqual oldServingInfo
verify(ttlCache).refresh(any())
verify(fetcherBase, never()).updateServingInfo(any(), any())
}
}