diff --git a/online/src/main/scala/ai/chronon/online/TileCodec.scala b/online/src/main/scala/ai/chronon/online/TileCodec.scala
index aa70bc687..dea0f416f 100644
--- a/online/src/main/scala/ai/chronon/online/TileCodec.scala
+++ b/online/src/main/scala/ai/chronon/online/TileCodec.scala
@@ -3,26 +3,44 @@ package ai.chronon.online
 import ai.chronon.aggregator.row.RowAggregator
 import ai.chronon.api.{BooleanType, DataType, GroupBy, StructType}
 import org.apache.avro.generic.GenericData
-import ai.chronon.api.Extensions.{AggregationOps, MetadataOps}
+import ai.chronon.api.Extensions.{AggregationOps, MetadataOps, WindowUtils}
 import scala.collection.JavaConverters._
+import scala.util.ScalaJavaConversions.ListOps
+import scala.util.ScalaVersionSpecificCollectionsConverter
 
 object TileCodec {
   def buildRowAggregator(groupBy: GroupBy, inputSchema: Seq[(String, DataType)]): RowAggregator = {
     // a set of Chronon groupBy aggregations needs to be flatted out to get the
-    // full cross-product of all the feature column aggregations to be computed
+    // feature column aggregations to be computed. We don't include windows in this
+    // to keep the aggregation work & payload size small as the multiple windows for a given
+    // counter are identical value wise within a tile (e.g. sum_1d and sum_7d are the same in a tile)
+    val unpackedAggs = groupBy.aggregations.asScala.flatMap(_.unWindowed)
+    new RowAggregator(inputSchema, unpackedAggs)
+  }
+
+  def buildWindowedRowAggregator(groupBy: GroupBy, inputSchema: Seq[(String, DataType)]): RowAggregator = {
+    // a set of Chronon groupBy aggregations needs to be flatted out to get the
+    // feature column aggregations to be computed. This version includes windows in the feature
+    // columns to get the full cross product (buckets * windows) as this is useful in unit tests to compare
+    // the final results
     val unpackedAggs = groupBy.aggregations.asScala.flatMap(_.unpack)
     new RowAggregator(inputSchema, unpackedAggs)
   }
 }
 
 /**
- * TileCodec is a helper class that allows for the creation of pre-aggregated tiles of feature values.
- * These pre-aggregated tiles can be used in the serving layer to compute the final feature values along
- * with batch pre-aggregates produced by GroupByUploads.
- * The pre-aggregated tiles are serialized as Avro and indicate whether the tile is complete or not (partial aggregates)
- */
-class TileCodec(rowAggregator: RowAggregator, groupBy: GroupBy) {
+  * TileCodec is a helper class that allows for the creation of pre-aggregated tiles of feature values.
+  * These pre-aggregated tiles can be used in the serving layer to compute the final feature values along
+  * with batch pre-aggregates produced by GroupByUploads.
+  * The pre-aggregated tiles are serialized as Avro and indicate whether the tile is complete or not (partial aggregates)
+  */
+class TileCodec(groupBy: GroupBy, inputSchema: Seq[(String, DataType)]) {
+
+  import TileCodec._
+  val rowAggregator: RowAggregator = buildRowAggregator(groupBy, inputSchema)
+  val windowedRowAggregator: RowAggregator = buildWindowedRowAggregator(groupBy, inputSchema)
+
   val windowedIrSchema: StructType = StructType.from("WindowedIr", rowAggregator.irSchema)
   val fields: Array[(String, DataType)] = Array(
     "collapsedIr" -> windowedIrSchema,
@@ -32,7 +50,6 @@ class TileCodec(rowAggregator: RowAggregator, groupBy: GroupBy) {
 
   val tileChrononSchema: StructType = StructType.from(s"${groupBy.metaData.cleanName}_TILE_IR", fields)
   val tileAvroSchema: String = AvroConversions.fromChrononSchema(tileChrononSchema).toString()
-  val tileAvroCodec: AvroCodec = AvroCodec.of(tileAvroSchema)
   private val irToBytesFn = AvroConversions.encodeBytes(tileChrononSchema, null)
 
   def makeTileIr(ir: Array[Any], isComplete: Boolean): Array[Byte] = {
@@ -42,6 +59,7 @@ class TileCodec(rowAggregator: RowAggregator, groupBy: GroupBy) {
   }
 
   def decodeTileIr(tileIr: Array[Byte]): (Array[Any], Boolean) = {
+    val tileAvroCodec: AvroCodec = AvroCodec.of(tileAvroSchema)
     val decodedTileIr = tileAvroCodec.decode(tileIr)
     val collapsedIr = decodedTileIr
       .get("collapsedIr")
@@ -51,7 +69,36 @@ class TileCodec(rowAggregator: RowAggregator, groupBy: GroupBy) {
       .toChrononRow(collapsedIr, windowedIrSchema)
       .asInstanceOf[Array[Any]]
     val denormalizedIr = rowAggregator.denormalize(ir)
+    val expandedWindowedIr = expandWindowedTileIr(denormalizedIr)
     val isComplete = decodedTileIr.get("isComplete").asInstanceOf[Boolean]
-    (denormalizedIr, isComplete)
+    (expandedWindowedIr, isComplete)
+  }
+
+  // method that takes a tile IR in the unwindowed form and expands it to the windowed form
+  // as an example: [myfield_sum, myfield_average] -> [myfield_sum_1d, myfield_sum_7d, myfield_average_1d, myfield_average_7d]
+  def expandWindowedTileIr(baseIr: Array[Any]): Array[Any] = {
+    val flattenedIr = windowedRowAggregator.init
+    var irPos = 0
+    var bucketPos = 0
+    groupBy.aggregations.asScala.foreach {
+      aggr =>
+        val buckets = Option(aggr.buckets)
+          .map(_.toScala)
+          .getOrElse(Seq(null))
+        val windows = Option(aggr.windows)
+          .map(_.toScala)
+          .getOrElse(Seq(WindowUtils.Unbounded))
+        // for each aggregation we have 1/more buckets and 1/more windows
+        // we need to iterate over the baseIr and clone a given counter's values n times where
+        // n is the number of windows for that counter
+        for (_ <- buckets) {
+          for (_ <- windows) {
+            flattenedIr(irPos) = rowAggregator.columnAggregators(bucketPos).clone(baseIr(bucketPos))
+            irPos += 1
+          }
+          bucketPos += 1
+        }
+    }
+    flattenedIr
   }
 }
diff --git a/online/src/test/scala/ai/chronon/online/TileCodecTest.scala b/online/src/test/scala/ai/chronon/online/TileCodecTest.scala
index c666211fd..bb80a5f82 100644
--- a/online/src/test/scala/ai/chronon/online/TileCodecTest.scala
+++ b/online/src/test/scala/ai/chronon/online/TileCodecTest.scala
@@ -8,35 +8,38 @@ import scala.collection.JavaConverters._
 
 class TileCodecTest {
   private val histogram = Map[String, Int]("A" -> 3, "B" -> 2).asJava
-  private val aggregationsAndExpected: Array[(Aggregation, Any)] = Array(
-    Builders.Aggregation(Operation.AVERAGE, "views", Seq(new Window(1, TimeUnit.DAYS))) -> 16.0,
-    Builders.Aggregation(Operation.AVERAGE, "rating", Seq(new Window(1, TimeUnit.DAYS))) -> 4.0,
+  private val aggregationsAndExpected: Array[(Aggregation, Seq[Any])] = Array(
+    Builders.Aggregation(Operation.AVERAGE, "views", Seq(new Window(1, TimeUnit.DAYS))) -> Seq(16.0),
+    Builders.Aggregation(Operation.AVERAGE, "rating", Seq(new Window(1, TimeUnit.DAYS))) -> Seq(4.0),
 
-    Builders.Aggregation(Operation.SUM, "rating", Seq(new Window(1, TimeUnit.DAYS))) -> 12.0f,
-    Builders.Aggregation(Operation.SUM, "rating", Seq(new Window(7, TimeUnit.DAYS))) -> 12.0f,
+    Builders.Aggregation(Operation.SUM, "rating", Seq(new Window(1, TimeUnit.DAYS), new Window(7, TimeUnit.DAYS))) -> Seq(12.0f, 12.0f),
 
-    Builders.Aggregation(Operation.UNIQUE_COUNT, "title", Seq(new Window(1, TimeUnit.DAYS))) -> 3L,
-    Builders.Aggregation(Operation.UNIQUE_COUNT, "title", Seq(new Window(7, TimeUnit.DAYS))) -> 3L,
+    Builders.Aggregation(Operation.UNIQUE_COUNT, "title", Seq(new Window(1, TimeUnit.DAYS), new Window(7, TimeUnit.DAYS))) -> Seq(3L, 3L),
 
-    Builders.Aggregation(Operation.LAST, "title", Seq(new Window(1, TimeUnit.DAYS))) -> "C",
-    Builders.Aggregation(Operation.LAST, "title", Seq(new Window(7, TimeUnit.DAYS))) -> "C",
+    Builders.Aggregation(Operation.LAST, "title", Seq(new Window(1, TimeUnit.DAYS), new Window(7, TimeUnit.DAYS))) -> Seq("C", "C"),
 
-    Builders.Aggregation(Operation.LAST_K, "title", Seq(new Window(1, TimeUnit.DAYS)), argMap = Map("k" -> "2")) -> List("C", "B").asJava,
-    Builders.Aggregation(Operation.LAST_K, "title", Seq(new Window(7, TimeUnit.DAYS)), argMap = Map("k" -> "2")) -> List("C", "B").asJava,
+    Builders.Aggregation(Operation.LAST_K, "title", Seq(new Window(1, TimeUnit.DAYS), new Window(7, TimeUnit.DAYS)), argMap = Map("k" -> "2")) -> Seq(List("C", "B").asJava, List("C", "B").asJava),
 
-    Builders.Aggregation(Operation.TOP_K, "title", Seq(new Window(1, TimeUnit.DAYS)), argMap = Map("k" -> "1")) -> List("C").asJava,
-    Builders.Aggregation(Operation.TOP_K, "title", Seq(new Window(7, TimeUnit.DAYS)), argMap = Map("k" -> "1")) -> List("C").asJava,
+    Builders.Aggregation(Operation.TOP_K, "title", Seq(new Window(1, TimeUnit.DAYS), new Window(7, TimeUnit.DAYS)), argMap = Map("k" -> "1")) -> Seq(List("C").asJava, List("C").asJava),
 
-    Builders.Aggregation(Operation.MIN, "title", Seq(new Window(1, TimeUnit.DAYS))) -> "A",
-    Builders.Aggregation(Operation.MIN, "title", Seq(new Window(7, TimeUnit.DAYS))) -> "A",
+    Builders.Aggregation(Operation.MIN, "title", Seq(new Window(1, TimeUnit.DAYS), new Window(7, TimeUnit.DAYS))) -> Seq("A", "A"),
 
-    Builders.Aggregation(Operation.APPROX_UNIQUE_COUNT, "title", Seq(new Window(1, TimeUnit.DAYS))) -> 3L,
-    Builders.Aggregation(Operation.APPROX_UNIQUE_COUNT, "title", Seq(new Window(7, TimeUnit.DAYS))) -> 3L,
+    Builders.Aggregation(Operation.APPROX_UNIQUE_COUNT, "title", Seq(new Window(1, TimeUnit.DAYS), new Window(7, TimeUnit.DAYS))) -> Seq(3L, 3L),
 
-    Builders.Aggregation(Operation.HISTOGRAM, "hist_input", Seq(new Window(1, TimeUnit.DAYS)), argMap = Map("k" -> "2")) -> histogram,
-    Builders.Aggregation(Operation.HISTOGRAM, "hist_input", Seq(new Window(7, TimeUnit.DAYS)), argMap = Map("k" -> "2")) -> histogram
+    Builders.Aggregation(Operation.HISTOGRAM, "hist_input", Seq(new Window(1, TimeUnit.DAYS), new Window(7, TimeUnit.DAYS)), argMap = Map("k" -> "2")) -> Seq(histogram, histogram)
   )
+  private val bucketedAggregations: Array[Aggregation] = Array(
+    Builders.Aggregation(
+      operation = Operation.AVERAGE,
+      inputColumn = "views",
+      buckets = Seq("title"),
+      windows = Seq(new Window(1, TimeUnit.DAYS), new Window(7, TimeUnit.DAYS))
+    )
+  )
+  private val expectedBucketResult = Map("A" -> 4.0, "B" -> 40.0, "C" -> 4.0).asJava
+  private val expectedBucketedResults = Seq(expectedBucketResult, expectedBucketResult)
+
 
   private val schema = List(
     "created" -> LongType,
     "views" -> IntType,
@@ -45,14 +48,25 @@ class TileCodecTest {
     "hist_input" -> ListType(StringType)
   )
 
+  def createRow(ts: Long, views: Int, rating: Float, title: String, histInput: Seq[String]): Row = {
+    val values: Array[(String, Any)] = Array(
+      "created" -> ts,
+      "views" -> views,
+      "rating" -> rating,
+      "title" -> title,
+      "hist_input" -> histInput
+    )
+    new ArrayRow(values.map(_._2), ts)
+  }
+
   @Test
   def testTileCodecIrSerRoundTrip(): Unit = {
     val groupByMetadata = Builders.MetaData(name = "my_group_by")
     val (aggregations, expectedVals) = aggregationsAndExpected.unzip
+    val expectedFlattenedVals = expectedVals.flatten
     val groupBy = Builders.GroupBy(metaData = groupByMetadata, aggregations = aggregations)
-    val rowAggregator = TileCodec.buildRowAggregator(groupBy, schema)
-    val rowIR = rowAggregator.init
-    val tileCodec = new TileCodec(rowAggregator, groupBy)
+    val tileCodec = new TileCodec(groupBy, schema)
+    val rowIR = tileCodec.rowAggregator.init
 
     val originalIsComplete = true
     val rows = Seq(
@@ -60,7 +74,7 @@ class TileCodecTest {
      createRow(1519862399984L, 40, 5.0f, "B", Seq()),
      createRow(1519862399988L, 4, 3.0f, "C", Seq("A", "B", "C"))
     )
-    rows.foreach(row => rowAggregator.update(rowIR, row))
+    rows.foreach(row => tileCodec.rowAggregator.update(rowIR, row))
     val bytes = tileCodec.makeTileIr(rowIR, originalIsComplete)
     assert(bytes.length > 0)
 
@@ -68,22 +82,48 @@ class TileCodecTest {
     assert(isComplete == originalIsComplete)
 
     // lets finalize the payload intermediate results and verify things
-    val finalResults = rowAggregator.finalize(deserPayload)
-    expectedVals.zip(finalResults).zip(rowAggregator.outputSchema.map(_._1)).foreach {
+    val finalResults = tileCodec.windowedRowAggregator.finalize(deserPayload)
+    assertEquals(expectedFlattenedVals.length, finalResults.length)
+
+    // we use a windowed row aggregator for the final results as we want the final flattened results
+    val windowedRowAggregator = TileCodec.buildWindowedRowAggregator(groupBy, schema)
+    expectedFlattenedVals.zip(finalResults).zip(windowedRowAggregator.outputSchema.map(_._1)).foreach {
       case ((expected, actual), name) =>
        println(s"Checking: $name")
        assertEquals(expected, actual)
     }
   }
 
-  def createRow(ts: Long, views: Int, rating: Float, title: String, histInput: Seq[String]): Row = {
-    val values: Array[(String, Any)] = Array(
-      "created" -> ts,
-      "views" -> views,
-      "rating" -> rating,
-      "title" -> title,
-      "hist_input" -> histInput
+  @Test
+  def testTileCodecIrSerRoundTrip_WithBuckets(): Unit = {
+    val groupByMetadata = Builders.MetaData(name = "my_group_by")
+    val groupBy = Builders.GroupBy(metaData = groupByMetadata, aggregations = bucketedAggregations)
+    val tileCodec = new TileCodec(groupBy, schema)
+    val rowIR = tileCodec.rowAggregator.init
+
+    val originalIsComplete = true
+    val rows = Seq(
+      createRow(1519862399984L, 4, 4.0f, "A", Seq("D", "A", "B", "A")),
+      createRow(1519862399984L, 40, 5.0f, "B", Seq()),
+      createRow(1519862399988L, 4, 3.0f, "C", Seq("A", "B", "C"))
     )
-    new ArrayRow(values.map(_._2), ts)
+    rows.foreach(row => tileCodec.rowAggregator.update(rowIR, row))
+    val bytes = tileCodec.makeTileIr(rowIR, originalIsComplete)
+    assert(bytes.length > 0)
+
+    val (deserPayload, isComplete) = tileCodec.decodeTileIr(bytes)
+    assert(isComplete == originalIsComplete)
+
+    // lets finalize the payload intermediate results and verify things
+    val finalResults = tileCodec.windowedRowAggregator.finalize(deserPayload)
+    assertEquals(expectedBucketedResults.size, finalResults.length)
+
+    // we use a windowed row aggregator for the final results as we want the final flattened results
+    val windowedRowAggregator = TileCodec.buildWindowedRowAggregator(groupBy, schema)
+    expectedBucketedResults.zip(finalResults).zip(windowedRowAggregator.outputSchema.map(_._1)).foreach {
+      case ((expected, actual), name) =>
+        println(s"Checking: $name")
+        assertEquals(expected, actual)
+    }
   }
 }
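
For orientation, a minimal usage sketch of the two-argument TileCodec constructor introduced above, assembled from the calls exercised in TileCodecTest. The values groupBy, inputSchema, and inputRows are assumed inputs supplied by the caller, not part of this change.

// Sketch only: assumes a GroupBy `groupBy`, an `inputSchema: Seq[(String, DataType)]`,
// and a collection of raw `inputRows` matching that schema.
val tileCodec = new TileCodec(groupBy, inputSchema)

// Aggregate raw rows into the un-windowed tile IR, then serialize it as Avro bytes.
val rowIR = tileCodec.rowAggregator.init
inputRows.foreach(row => tileCodec.rowAggregator.update(rowIR, row))
val tileBytes = tileCodec.makeTileIr(rowIR, isComplete = true)

// Decoding expands the IR back out to the windowed (buckets * windows) form,
// which can then be finalized into feature values.
val (windowedIr, isComplete) = tileCodec.decodeTileIr(tileBytes)
val finalValues = tileCodec.windowedRowAggregator.finalize(windowedIr)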