diff --git a/online/src/main/scala/ai/chronon/online/FetcherCache.scala b/online/src/main/scala/ai/chronon/online/FetcherCache.scala new file mode 100644 index 000000000..b67fd57d9 --- /dev/null +++ b/online/src/main/scala/ai/chronon/online/FetcherCache.scala @@ -0,0 +1,214 @@ +package ai.chronon.online + +import ai.chronon.aggregator.windowing.FinalBatchIr +import ai.chronon.api.Extensions.MetadataOps +import ai.chronon.api.GroupBy +import ai.chronon.online.FetcherBase.GroupByRequestMeta +import ai.chronon.online.Fetcher.Request +import ai.chronon.online.FetcherCache.{ + BatchIrCache, + BatchResponses, + CachedBatchResponse, + CachedFinalIrBatchResponse, + CachedMapBatchResponse, + KvStoreBatchResponse +} +import ai.chronon.online.KVStore.{GetRequest, TimedValue} +import com.github.benmanes.caffeine.cache.{Cache => CaffeineCache} + +import scala.util.{Success, Try} +import java.util.concurrent.ConcurrentHashMap +import scala.collection.JavaConverters.mapAsScalaConcurrentMapConverter + +/* + * FetcherCache is an extension to FetcherBase that provides caching functionality. It caches KV store + * requests to decrease feature serving latency. + * */ +trait FetcherCache { + val batchIrCacheName = "batch_cache" + val maybeBatchIrCache: Option[BatchIrCache] = + Option(System.getProperty("ai.chronon.fetcher.batch_ir_cache_size")) + .map(size => new BatchIrCache(batchIrCacheName, size.toInt)) + .orElse(None) + + def isCacheSizeConfigured: Boolean = maybeBatchIrCache.isDefined + + // Memoize which GroupBys have caching enabled + private[online] val isCachingEnabledForGroupBy: collection.concurrent.Map[String, Boolean] = + new ConcurrentHashMap[String, Boolean]().asScala + + def isCachingEnabled(groupBy: GroupBy): Boolean = { + if (!isCacheSizeConfigured || groupBy.getMetaData == null || groupBy.getMetaData.getName == null) return false + + val groupByName = groupBy.getMetaData.getName + isCachingEnabledForGroupBy.getOrElse( + groupByName, { + groupBy.getMetaData.customJsonLookUp("enable_caching") match { + case b: Boolean => + println(s"Caching is ${if (b) "enabled" else "disabled"} for $groupByName") + isCachingEnabledForGroupBy.putIfAbsent(groupByName, b) + b + case null => + println(s"Caching is disabled for $groupByName, enable_caching is not set.") + isCachingEnabledForGroupBy.putIfAbsent(groupByName, false) + false + case _ => false + } + } + ) + } + + protected val caffeineMetricsContext: Metrics.Context = Metrics.Context(Metrics.Environment.JoinFetching) + + /** + * Obtain the Map[String, AnyRef] response from a batch response. + * + * If batch IR caching is enabled, this method will try to fetch the IR from the cache. If it's not in the cache, + * it will decode it from the batch bytes and store it. + * + * @param batchResponses the batch responses + * @param batchBytes the batch bytes corresponding to the batchResponses. Can be `null`. + * @param servingInfo the GroupByServingInfoParsed that contains the info to decode the bytes + * @param decodingFunction the function to decode bytes into Map[String, AnyRef] + * @param keys the keys used to fetch this particular batch response, for caching purposes + */ + private[online] def getMapResponseFromBatchResponse(batchResponses: BatchResponses, + batchBytes: Array[Byte], + decodingFunction: Array[Byte] => Map[String, AnyRef], + servingInfo: GroupByServingInfoParsed, + keys: Map[String, Any]): Map[String, AnyRef] = { + if (!isCachingEnabled(servingInfo.groupBy)) return decodingFunction(batchBytes) + + batchResponses match { + case _: KvStoreBatchResponse => + val batchRequestCacheKey = + BatchIrCache.Key(servingInfo.groupByOps.batchDataset, keys, servingInfo.batchEndTsMillis) + val decodedBytes = decodingFunction(batchBytes) + if (decodedBytes != null) + maybeBatchIrCache.get.cache.put(batchRequestCacheKey, CachedMapBatchResponse(decodedBytes)) + decodedBytes + case cachedResponse: CachedBatchResponse => + cachedResponse match { + case CachedFinalIrBatchResponse(_: FinalBatchIr) => decodingFunction(batchBytes) + case CachedMapBatchResponse(mapResponse: Map[String, AnyRef]) => mapResponse + } + } + } + + /** + * Obtain the FinalBatchIr from a batch response. + * + * If batch IR caching is enabled, this method will try to fetch the IR from the cache. If it's not in the cache, + * it will decode it from the batch bytes and store it. + * + * @param batchResponses the batch responses + * @param batchBytes the batch bytes corresponding to the batchResponses. Can be `null`. + * @param servingInfo the GroupByServingInfoParsed that contains the info to decode the bytes + * @param decodingFunction the function to decode bytes into FinalBatchIr + * @param keys the keys used to fetch this particular batch response, for caching purposes + */ + private[online] def getBatchIrFromBatchResponse( + batchResponses: BatchResponses, + batchBytes: Array[Byte], + servingInfo: GroupByServingInfoParsed, + decodingFunction: (Array[Byte], GroupByServingInfoParsed) => FinalBatchIr, + keys: Map[String, Any]): FinalBatchIr = { + if (!isCachingEnabled(servingInfo.groupBy)) return decodingFunction(batchBytes, servingInfo) + + batchResponses match { + case _: KvStoreBatchResponse => + val batchRequestCacheKey = + BatchIrCache.Key(servingInfo.groupByOps.batchDataset, keys, servingInfo.batchEndTsMillis) + val decodedBytes = decodingFunction(batchBytes, servingInfo) + if (decodedBytes != null) + maybeBatchIrCache.get.cache.put(batchRequestCacheKey, CachedFinalIrBatchResponse(decodedBytes)) + decodedBytes + case cachedResponse: CachedBatchResponse => + cachedResponse match { + case CachedFinalIrBatchResponse(finalBatchIr: FinalBatchIr) => finalBatchIr + case CachedMapBatchResponse(_: Map[String, AnyRef]) => decodingFunction(batchBytes, servingInfo) + } + } + } + + /** + * Given a list of GetRequests, return a map of GetRequests to cached FinalBatchIrs. + */ + def getCachedRequests( + groupByRequestToKvRequest: Seq[(Request, Try[GroupByRequestMeta])]): Map[GetRequest, CachedBatchResponse] = { + if (!isCacheSizeConfigured) return Map.empty + + groupByRequestToKvRequest + .map { + case (request, Success(GroupByRequestMeta(servingInfo, batchRequest, _, _, _))) => + if (!isCachingEnabled(servingInfo.groupBy)) { Map.empty } + else { + val batchRequestCacheKey = + BatchIrCache.Key(batchRequest.dataset, request.keys, servingInfo.batchEndTsMillis) + + // Metrics so we can get per-groupby cache metrics + val metricsContext = + request.context.getOrElse(Metrics.Context(Metrics.Environment.JoinFetching, servingInfo.groupBy)) + + maybeBatchIrCache.get.cache.getIfPresent(batchRequestCacheKey) match { + case null => + metricsContext.increment(s"${batchIrCacheName}_gb_misses") + val emptyMap: Map[GetRequest, CachedBatchResponse] = Map.empty + emptyMap + case cachedIr: CachedBatchResponse => + metricsContext.increment(s"${batchIrCacheName}_gb_hits") + Map(batchRequest -> cachedIr) + } + } + case _ => + val emptyMap: Map[GetRequest, CachedBatchResponse] = Map.empty + emptyMap + } + .foldLeft(Map.empty[GetRequest, CachedBatchResponse])(_ ++ _) + } +} + +object FetcherCache { + private[online] class BatchIrCache(val cacheName: String, val maximumSize: Int = 10000) { + import BatchIrCache._ + + val cache: CaffeineCache[Key, Value] = + LRUCache[Key, Value](cacheName = cacheName, maximumSize = maximumSize) + } + + private[online] object BatchIrCache { + // We use the dataset, keys, and batchEndTsMillis to identify a batch request. + // There's one edge case to be aware of: if a batch job is re-run in the same day, the batchEndTsMillis will + // be the same but the underlying data may have have changed. If that new batch data is needed immediately, the + // Fetcher service should be restarted. + case class Key(dataset: String, keys: Map[String, Any], batchEndTsMillis: Long) + + // FinalBatchIr is for GroupBys using temporally accurate aggregation. + // Map[String, Any] is for GroupBys using snapshot accurate aggregation or no aggregation. + type Value = BatchResponses + } + + // BatchResponses encapsulates either a batch response from kv store or a cached batch response. + sealed abstract class BatchResponses { + def getBatchBytes(batchEndTsMillis: Long): Array[Byte] + } + object BatchResponses { + def apply(kvStoreResponse: Try[Seq[TimedValue]]): KvStoreBatchResponse = KvStoreBatchResponse(kvStoreResponse) + def apply(cachedResponse: FinalBatchIr): CachedFinalIrBatchResponse = CachedFinalIrBatchResponse(cachedResponse) + def apply(cachedResponse: Map[String, AnyRef]): CachedMapBatchResponse = CachedMapBatchResponse(cachedResponse) + } + case class KvStoreBatchResponse(response: Try[Seq[TimedValue]]) extends BatchResponses { + def getBatchBytes(batchEndTsMillis: Long): Array[Byte] = + response + .map(_.maxBy(_.millis)) + .filter(_.millis >= batchEndTsMillis) + .map(_.bytes) + .getOrElse(null) + } + sealed abstract class CachedBatchResponse extends BatchResponses { + // This is the case where we don't have bytes because the decoded IR was cached so we didn't hit the KV store again. + def getBatchBytes(batchEndTsMillis: Long): Null = null + } + case class CachedFinalIrBatchResponse(response: FinalBatchIr) extends CachedBatchResponse + case class CachedMapBatchResponse(response: Map[String, AnyRef]) extends CachedBatchResponse +} diff --git a/online/src/test/scala/ai/chronon/online/FetcherCacheTest.scala b/online/src/test/scala/ai/chronon/online/FetcherCacheTest.scala new file mode 100644 index 000000000..bb476b2f0 --- /dev/null +++ b/online/src/test/scala/ai/chronon/online/FetcherCacheTest.scala @@ -0,0 +1,381 @@ +package ai.chronon.online + +import ai.chronon.aggregator.windowing.FinalBatchIr +import ai.chronon.api.Extensions.GroupByOps +import ai.chronon.api.{Builders, GroupBy} +import ai.chronon.online.FetcherBase._ +import ai.chronon.online.{AvroCodec, FetcherCache, GroupByServingInfoParsed, KVStore, Metrics} +import ai.chronon.online.Fetcher.Request +import ai.chronon.online.FetcherCache.{BatchIrCache, BatchResponses, CachedMapBatchResponse} +import ai.chronon.online.KVStore.TimedValue +import ai.chronon.online.Metrics.Context +import org.junit.Assert.{assertArrayEquals, assertEquals, assertFalse, assertNull, assertTrue, fail} +import org.junit.Test +import org.mockito.Mockito._ +import org.mockito.ArgumentMatchers.any +import org.mockito.Mockito +import org.mockito.stubbing.Stubber +import org.scalatestplus.mockito.MockitoSugar + +import scala.collection.JavaConverters._ +import scala.util.{Failure, Success, Try} + +trait MockitoHelper extends MockitoSugar { + // Overriding doReturn to fix a known Java/Scala interoperability issue when using Mockito. ("doReturn: ambiguous + // reference to overloaded definition"). An alternative would be to use the 'mockito-scala' library. + def doReturn(toBeReturned: Any): Stubber = { + Mockito.doReturn(toBeReturned, Nil: _*) + } +} + +class FetcherCacheTest extends MockitoHelper { + class TestableFetcherCache(cache: Option[BatchIrCache]) extends FetcherCache { + override val maybeBatchIrCache: Option[BatchIrCache] = cache + } + val batchIrCacheMaximumSize = 50 + + @Test + def test_BatchIrCache_CorrectlyCachesBatchIrs(): Unit = { + val cacheName = "test" + val batchIrCache = new BatchIrCache(cacheName, batchIrCacheMaximumSize) + val dataset = "TEST_GROUPBY_BATCH" + val batchEndTsMillis = 1000L + + def createBatchir(i: Int) = + BatchResponses(FinalBatchIr(collapsed = Array(i), tailHops = Array(Array(Array(i)), Array(Array(i))))) + def createCacheKey(i: Int) = BatchIrCache.Key(dataset, Map("key" -> i), batchEndTsMillis) + + // Create a bunch of test batchIrs and store them in cache + val batchIrs: Map[BatchIrCache.Key, BatchIrCache.Value] = + (0 until batchIrCacheMaximumSize).map(i => createCacheKey(i) -> createBatchir(i)).toMap + batchIrCache.cache.putAll(batchIrs.asJava) + + // Check that the cache contains all the batchIrs we created + batchIrs.foreach(entry => { + val cachedBatchIr = batchIrCache.cache.getIfPresent(entry._1) + assertEquals(cachedBatchIr, entry._2) + }) + } + + @Test + def test_BatchIrCache_CorrectlyCachesMapResponse(): Unit = { + val cacheName = "test" + val batchIrCache = new BatchIrCache(cacheName, batchIrCacheMaximumSize) + val dataset = "TEST_GROUPBY_BATCH" + val batchEndTsMillis = 1000L + + def createMapResponse(i: Int) = + BatchResponses(Map("group_by_key" -> i.asInstanceOf[AnyRef])) + def createCacheKey(i: Int) = BatchIrCache.Key(dataset, Map("key" -> i), batchEndTsMillis) + + // Create a bunch of test mapResponses and store them in cache + val mapResponses: Map[BatchIrCache.Key, BatchIrCache.Value] = + (0 until batchIrCacheMaximumSize).map(i => createCacheKey(i) -> createMapResponse(i)).toMap + batchIrCache.cache.putAll(mapResponses.asJava) + + // Check that the cache contains all the mapResponses we created + mapResponses.foreach(entry => { + val cachedBatchIr = batchIrCache.cache.getIfPresent(entry._1) + assertEquals(cachedBatchIr, entry._2) + }) + } + + // Test that the cache keys are compared by equality, not by reference. In practice, this means that if two keys + // have the same (dataset, keys, batchEndTsMillis), they will only be stored once in the cache. + @Test + def test_BatchIrCache_KeysAreComparedByEquality(): Unit = { + val cacheName = "test" + val batchIrCache = new BatchIrCache(cacheName, batchIrCacheMaximumSize) + + val dataset = "TEST_GROUPBY_BATCH" + val batchEndTsMillis = 1000L + + def createCacheValue(i: Int) = + BatchResponses(FinalBatchIr(collapsed = Array(i), tailHops = Array(Array(Array(i)), Array(Array(i))))) + def createCacheKey(i: Int) = BatchIrCache.Key(dataset, Map("key" -> i), batchEndTsMillis) + + assert(batchIrCache.cache.estimatedSize() == 0) + batchIrCache.cache.put(createCacheKey(1), createCacheValue(1)) + assert(batchIrCache.cache.estimatedSize() == 1) + // Create a second key object with the same values as the first key, make sure it's not stored separately + batchIrCache.cache.put(createCacheKey(1), createCacheValue(1)) + assert(batchIrCache.cache.estimatedSize() == 1) + } + + @Test + def test_getCachedRequests_ReturnsCorrectCachedDataWhenCacheIsEnabled(): Unit = { + val cacheName = "test" + val testCache = Some(new BatchIrCache(cacheName, batchIrCacheMaximumSize)) + val fetcherCache = new TestableFetcherCache(testCache) { + override def isCachingEnabled(groupBy: GroupBy) = true + } + + // Prepare groupByRequestToKvRequest + val batchEndTsMillis = 0L + val keys = Map("key" -> "value") + val eventTs = 1000L + val dataset = "TEST_GROUPBY_BATCH" + val mockGroupByServingInfoParsed = mock[GroupByServingInfoParsed] + val mockContext = mock[Metrics.Context] + val request = Request("req_name", keys, Some(eventTs), Some(mock[Context])) + val getRequest = KVStore.GetRequest("key".getBytes, dataset, Some(eventTs)) + val requestMeta = + GroupByRequestMeta(mockGroupByServingInfoParsed, getRequest, Some(getRequest), Some(eventTs), mockContext) + val groupByRequestToKvRequest: Seq[(Request, Try[GroupByRequestMeta])] = Seq((request, Success(requestMeta))) + + // getCachedRequests should return an empty list when the cache is empty + val cachedRequestBeforePopulating = fetcherCache.getCachedRequests(groupByRequestToKvRequest) + assert(cachedRequestBeforePopulating.isEmpty) + + // Add a GetRequest and a FinalBatchIr + val key = BatchIrCache.Key(getRequest.dataset, keys, batchEndTsMillis) + val finalBatchIr = BatchResponses(FinalBatchIr(Array(1), Array(Array(Array(1)), Array(Array(1))))) + testCache.get.cache.put(key, finalBatchIr) + + // getCachedRequests should return the GetRequest and FinalBatchIr we cached + val cachedRequestsAfterAddingItem = fetcherCache.getCachedRequests(groupByRequestToKvRequest) + assert(cachedRequestsAfterAddingItem.head._1 == getRequest) + assert(cachedRequestsAfterAddingItem.head._2 == finalBatchIr) + } + + @Test + def test_getCachedRequests_DoesNotCacheWhenCacheIsDisabledForGroupBy(): Unit = { + val testCache = new BatchIrCache("test", batchIrCacheMaximumSize) + val spiedTestCache = spy(testCache) + val fetcherCache = new TestableFetcherCache(Some(testCache)) { + // Cache is enabled globally, but disabled for a specific groupBy + override def isCachingEnabled(groupBy: GroupBy) = false + } + + // Prepare groupByRequestToKvRequest + val keys = Map("key" -> "value") + val eventTs = 1000L + val dataset = "TEST_GROUPBY_BATCH" + val mockGroupByServingInfoParsed = mock[GroupByServingInfoParsed] + val mockContext = mock[Metrics.Context] + val request = Request("req_name", keys, Some(eventTs)) + val getRequest = KVStore.GetRequest("key".getBytes, dataset, Some(eventTs)) + val requestMeta = + GroupByRequestMeta(mockGroupByServingInfoParsed, getRequest, Some(getRequest), Some(eventTs), mockContext) + val groupByRequestToKvRequest: Seq[(Request, Try[GroupByRequestMeta])] = Seq((request, Success(requestMeta))) + + val cachedRequests = fetcherCache.getCachedRequests(groupByRequestToKvRequest) + assert(cachedRequests.isEmpty) + // Cache was never called + verify(spiedTestCache, never()).cache + } + + @Test + def test_getBatchBytes_ReturnsLatestTimedValueBytesIfGreaterThanBatchEnd(): Unit = { + val kvStoreResponse = Success( + Seq(TimedValue(Array(1.toByte), 1000L), TimedValue(Array(2.toByte), 2000L)) + ) + val batchResponses = BatchResponses(kvStoreResponse) + val batchBytes = batchResponses.getBatchBytes(1500L) + assertArrayEquals(Array(2.toByte), batchBytes) + } + + @Test + def test_getBatchBytes_ReturnsNullIfLatestTimedValueTimestampIsLessThanBatchEnd(): Unit = { + val kvStoreResponse = Success( + Seq(TimedValue(Array(1.toByte), 1000L), TimedValue(Array(2.toByte), 1500L)) + ) + val batchResponses = BatchResponses(kvStoreResponse) + val batchBytes = batchResponses.getBatchBytes(2000L) + assertNull(batchBytes) + } + + @Test + def test_getBatchBytes_ReturnsNullWhenCachedBatchResponse(): Unit = { + val finalBatchIr = mock[FinalBatchIr] + val batchResponses = BatchResponses(finalBatchIr) + val batchBytes = batchResponses.getBatchBytes(1000L) + assertNull(batchBytes) + } + + @Test + def test_getBatchBytes_ReturnsNullWhenKvStoreBatchResponseFails(): Unit = { + val kvStoreResponse = Failure(new RuntimeException("KV Store error")) + val batchResponses = BatchResponses(kvStoreResponse) + val batchBytes = batchResponses.getBatchBytes(1000L) + assertNull(batchBytes) + } + + @Test + def test_getBatchIrFromBatchResponse_ReturnsCorrectIRsWithCacheEnabled(): Unit = { + // Use a real cache + val batchIrCache = new BatchIrCache("test_cache", batchIrCacheMaximumSize) + + // Create all necessary mocks + val servingInfo = mock[GroupByServingInfoParsed] + val groupByOps = mock[GroupByOps] + val toBatchIr = mock[(Array[Byte], GroupByServingInfoParsed) => FinalBatchIr] + when(servingInfo.groupByOps).thenReturn(groupByOps) + when(groupByOps.batchDataset).thenReturn("test_dataset") + when(servingInfo.groupByOps.batchDataset).thenReturn("test_dataset") + when(servingInfo.batchEndTsMillis).thenReturn(1000L) + + // Dummy data + val batchBytes = Array[Byte](1, 1) + val keys = Map("key" -> "value") + val cacheKey = BatchIrCache.Key(servingInfo.groupByOps.batchDataset, keys, servingInfo.batchEndTsMillis) + + val fetcherCache = new TestableFetcherCache(Some(batchIrCache)) + val spiedFetcherCache = Mockito.spy(fetcherCache) + doReturn(true).when(spiedFetcherCache).isCachingEnabled(any()) + + // 1. Cached BatchResponse returns the same IRs passed in + val finalBatchIr1 = mock[FinalBatchIr] + val cachedBatchResponse = BatchResponses(finalBatchIr1) + val cachedIr = + spiedFetcherCache.getBatchIrFromBatchResponse(cachedBatchResponse, batchBytes, servingInfo, toBatchIr, keys) + assertEquals(finalBatchIr1, cachedIr) + verify(toBatchIr, never())(any(classOf[Array[Byte]]), any()) // no decoding needed + + // 2. Un-cached BatchResponse has IRs added to cache + val finalBatchIr2 = mock[FinalBatchIr] + val kvStoreBatchResponses = BatchResponses(Success(Seq(TimedValue(batchBytes, 1000L)))) + when(toBatchIr(any(), any())).thenReturn(finalBatchIr2) + val uncachedIr = + spiedFetcherCache.getBatchIrFromBatchResponse(kvStoreBatchResponses, batchBytes, servingInfo, toBatchIr, keys) + assertEquals(finalBatchIr2, uncachedIr) + assertEquals(batchIrCache.cache.getIfPresent(cacheKey), BatchResponses(finalBatchIr2)) // key was added + verify(toBatchIr, times(1))(any(), any()) // decoding did happen + } + + @Test + def test_getBatchIrFromBatchResponse_DecodesBatchBytesIfCacheDisabled(): Unit = { + // Set up mocks and dummy data + val servingInfo = mock[GroupByServingInfoParsed] + val batchBytes = Array[Byte](1, 2, 3) + val keys = Map("key" -> "value") + val finalBatchIr = mock[FinalBatchIr] + val toBatchIr = mock[(Array[Byte], GroupByServingInfoParsed) => FinalBatchIr] + val kvStoreBatchResponses = BatchResponses(Success(Seq(TimedValue(batchBytes, 1000L)))) + + val spiedFetcherCache = Mockito.spy(new TestableFetcherCache(None)) + when(toBatchIr(any(), any())).thenReturn(finalBatchIr) + + // When getBatchIrFromBatchResponse is called, it decodes the bytes and doesn't hit the cache + val ir = + spiedFetcherCache.getBatchIrFromBatchResponse(kvStoreBatchResponses, batchBytes, servingInfo, toBatchIr, keys) + verify(toBatchIr, times(1))(batchBytes, servingInfo) // decoding did happen + assertEquals(finalBatchIr, ir) + } + + @Test + def test_getBatchIrFromBatchResponse_ReturnsCorrectMapResponseWithCacheEnabled(): Unit = { + // Use a real cache + val batchIrCache = new BatchIrCache("test_cache", batchIrCacheMaximumSize) + // Set up mocks and dummy data + val servingInfo = mock[GroupByServingInfoParsed] + val groupByOps = mock[GroupByOps] + val outputCodec = mock[AvroCodec] + when(servingInfo.groupByOps).thenReturn(groupByOps) + when(groupByOps.batchDataset).thenReturn("test_dataset") + when(servingInfo.groupByOps.batchDataset).thenReturn("test_dataset") + when(servingInfo.batchEndTsMillis).thenReturn(1000L) + val batchBytes = Array[Byte](1, 2, 3) + val keys = Map("key" -> "value") + val cacheKey = BatchIrCache.Key(servingInfo.groupByOps.batchDataset, keys, servingInfo.batchEndTsMillis) + + val spiedFetcherCache = Mockito.spy(new TestableFetcherCache(Some(batchIrCache))) + doReturn(true).when(spiedFetcherCache).isCachingEnabled(any()) + + // 1. Cached BatchResponse returns the same Map responses passed in + val mapResponse1 = mock[Map[String, AnyRef]] + val cachedBatchResponse = BatchResponses(mapResponse1) + val decodingFunction1 = (bytes: Array[Byte]) => { + fail("Decoding function should not be called when batch response is cached") + mapResponse1 + } + val cachedMapResponse = spiedFetcherCache.getMapResponseFromBatchResponse(cachedBatchResponse, + batchBytes, + decodingFunction1, + servingInfo, + keys) + assertEquals(mapResponse1, cachedMapResponse) + + // 2. Un-cached BatchResponse has Map responses added to cache + val mapResponse2 = mock[Map[String, AnyRef]] + val kvStoreBatchResponses = BatchResponses(Success(Seq(TimedValue(batchBytes, 1000L)))) + def decodingFunction2 = (bytes: Array[Byte]) => mapResponse2 + val decodedMapResponse = spiedFetcherCache.getMapResponseFromBatchResponse(kvStoreBatchResponses, + batchBytes, + decodingFunction2, + servingInfo, + keys) + assertEquals(mapResponse2, decodedMapResponse) + assertEquals(batchIrCache.cache.getIfPresent(cacheKey), CachedMapBatchResponse(mapResponse2)) // key was added + } + + @Test + def test_getMapResponseFromBatchResponse_DecodesBatchBytesIfCacheDisabled(): Unit = { + // Set up mocks and dummy data + val servingInfo = mock[GroupByServingInfoParsed] + val batchBytes = Array[Byte](1, 2, 3) + val keys = Map("key" -> "value") + val mapResponse = mock[Map[String, AnyRef]] + val outputCodec = mock[AvroCodec] + val kvStoreBatchResponses = BatchResponses(Success(Seq(TimedValue(batchBytes, 1000L)))) + when(servingInfo.outputCodec).thenReturn(outputCodec) + when(outputCodec.decodeMap(any())).thenReturn(mapResponse) + + val spiedFetcherCache = Mockito.spy(new TestableFetcherCache(None)) + + // When getMapResponseFromBatchResponse is called, it decodes the bytes and doesn't hit the cache + val decodedMapResponse = spiedFetcherCache.getMapResponseFromBatchResponse(kvStoreBatchResponses, + batchBytes, + servingInfo.outputCodec.decodeMap, + servingInfo, + keys) + verify(servingInfo.outputCodec.decodeMap(any()), times(1)) // decoding did happen + assertEquals(mapResponse, decodedMapResponse) + } + + @Test + def test_isCachingEnabled_CorrectlyDetermineIfCacheIsEnabled(): Unit = { + val baseFetcher = new TestableFetcherCache(Some(new BatchIrCache("test", batchIrCacheMaximumSize))) + def buildGroupByWithCustomJson(name: String, customJson: String = null): GroupBy = + Builders.GroupBy( + metaData = Builders.MetaData(name = name, customJson = customJson) + ) + + assertFalse(baseFetcher.isCachingEnabled(buildGroupByWithCustomJson("test_groupby_1"))) + assertFalse(baseFetcher.isCachingEnabled(buildGroupByWithCustomJson("test_groupby_2", "{}"))) + assertTrue( + baseFetcher + .isCachingEnabled(buildGroupByWithCustomJson("test_groupby_3", "{\"enable_caching\": true}")) + ) + assertFalse( + baseFetcher + .isCachingEnabled(buildGroupByWithCustomJson("test_groupby_4", "{\"enable_caching\": false}")) + ) + assertFalse( + baseFetcher + .isCachingEnabled( + buildGroupByWithCustomJson("test_groupby_5", "{\"enable_caching\": \"string instead of bool\"}") + ) + ) + } + + @Test + def test_isCachingEnabled_Memoizes(): Unit = { + val baseFetcher = new TestableFetcherCache(Some(new BatchIrCache("test", batchIrCacheMaximumSize))) + def buildGroupByWithCustomJson(name: String, customJson: String = null): GroupBy = + Builders.GroupBy( + metaData = Builders.MetaData(name = name, customJson = customJson) + ) + + // the map is clean at the start + assertFalse(baseFetcher.isCachingEnabledForGroupBy.contains("test_memo_gb_1")) + assertFalse(baseFetcher.isCachingEnabledForGroupBy.contains("test_memo_gb_2")) + + // memoization works for both true and false + baseFetcher.isCachingEnabled(buildGroupByWithCustomJson("test_memo_gb_1")) + assertFalse(baseFetcher.isCachingEnabledForGroupBy("test_memo_gb_1")) + + baseFetcher.isCachingEnabled(buildGroupByWithCustomJson("test_memo_gb_2", "{\"enable_caching\": true}")) + assertTrue(baseFetcher.isCachingEnabledForGroupBy("test_memo_gb_2")) + } +}