Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improvements to TileCodec to compress and skip redundant windows #539

Merged
merged 7 commits into from
Oct 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 57 additions & 10 deletions online/src/main/scala/ai/chronon/online/TileCodec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,26 +3,44 @@ package ai.chronon.online
import ai.chronon.aggregator.row.RowAggregator
import ai.chronon.api.{BooleanType, DataType, GroupBy, StructType}
import org.apache.avro.generic.GenericData
import ai.chronon.api.Extensions.{AggregationOps, MetadataOps}
import ai.chronon.api.Extensions.{AggregationOps, MetadataOps, WindowUtils}

import scala.collection.JavaConverters._
import scala.util.ScalaJavaConversions.ListOps
import scala.util.ScalaVersionSpecificCollectionsConverter

object TileCodec {

  /**
   * Builds a RowAggregator over the un-windowed flattening of the groupBy's aggregations.
   *
   * Windows are deliberately dropped here: within a single tile, every window of a given
   * counter holds the same value (e.g. sum_1d and sum_7d are identical inside one tile),
   * so aggregating once per counter keeps both the aggregation work and the payload small.
   */
  def buildRowAggregator(groupBy: GroupBy, inputSchema: Seq[(String, DataType)]): RowAggregator = {
    val flattenedAggregations = groupBy.aggregations.asScala.flatMap(_.unWindowed)
    new RowAggregator(inputSchema, flattenedAggregations)
  }

  /**
   * Builds a RowAggregator over the full (buckets x windows) cross product of the
   * groupBy's aggregations.
   *
   * Unlike [[buildRowAggregator]], this keeps the windows, producing one feature column per
   * (bucket, window) pair — useful in unit tests that compare final, fully expanded results.
   */
  def buildWindowedRowAggregator(groupBy: GroupBy, inputSchema: Seq[(String, DataType)]): RowAggregator = {
    val flattenedAggregations = groupBy.aggregations.asScala.flatMap(_.unpack)
    new RowAggregator(inputSchema, flattenedAggregations)
  }
}

/**
* TileCodec is a helper class that allows for the creation of pre-aggregated tiles of feature values.
* These pre-aggregated tiles can be used in the serving layer to compute the final feature values along
* with batch pre-aggregates produced by GroupByUploads.
* The pre-aggregated tiles are serialized as Avro and indicate whether the tile is complete or not (partial aggregates)
*/
class TileCodec(rowAggregator: RowAggregator, groupBy: GroupBy) {
* TileCodec is a helper class that allows for the creation of pre-aggregated tiles of feature values.
* These pre-aggregated tiles can be used in the serving layer to compute the final feature values along
* with batch pre-aggregates produced by GroupByUploads.
* The pre-aggregated tiles are serialized as Avro and indicate whether the tile is complete or not (partial aggregates)
*/
class TileCodec(groupBy: GroupBy, inputSchema: Seq[(String, DataType)]) {

import TileCodec._
val rowAggregator: RowAggregator = buildRowAggregator(groupBy, inputSchema)
val windowedRowAggregator: RowAggregator = buildWindowedRowAggregator(groupBy, inputSchema)

val windowedIrSchema: StructType = StructType.from("WindowedIr", rowAggregator.irSchema)
val fields: Array[(String, DataType)] = Array(
"collapsedIr" -> windowedIrSchema,
Expand All @@ -32,7 +50,6 @@ class TileCodec(rowAggregator: RowAggregator, groupBy: GroupBy) {
val tileChrononSchema: StructType =
StructType.from(s"${groupBy.metaData.cleanName}_TILE_IR", fields)
val tileAvroSchema: String = AvroConversions.fromChrononSchema(tileChrononSchema).toString()
val tileAvroCodec: AvroCodec = AvroCodec.of(tileAvroSchema)
private val irToBytesFn = AvroConversions.encodeBytes(tileChrononSchema, null)

def makeTileIr(ir: Array[Any], isComplete: Boolean): Array[Byte] = {
Expand All @@ -42,6 +59,7 @@ class TileCodec(rowAggregator: RowAggregator, groupBy: GroupBy) {
}

def decodeTileIr(tileIr: Array[Byte]): (Array[Any], Boolean) = {
val tileAvroCodec: AvroCodec = AvroCodec.of(tileAvroSchema)
val decodedTileIr = tileAvroCodec.decode(tileIr)
val collapsedIr = decodedTileIr
.get("collapsedIr")
Expand All @@ -51,7 +69,36 @@ class TileCodec(rowAggregator: RowAggregator, groupBy: GroupBy) {
.toChrononRow(collapsedIr, windowedIrSchema)
.asInstanceOf[Array[Any]]
val denormalizedIr = rowAggregator.denormalize(ir)
val expandedWindowedIr = expandWindowedTileIr(denormalizedIr)
val isComplete = decodedTileIr.get("isComplete").asInstanceOf[Boolean]
(denormalizedIr, isComplete)
(expandedWindowedIr, isComplete)
}

// Expands a tile IR from its compact un-windowed form to the fully windowed form.
// Example: [myfield_sum, myfield_average] ->
//          [myfield_sum_1d, myfield_sum_7d, myfield_average_1d, myfield_average_7d]
// Within a tile every window of a counter shares the same value, so expansion is a
// clone of the base counter's IR once per (bucket, window) pair.
// NOTE(review): non-code text (a scraped review-comment widget) was embedded inside this
// method body and has been removed so the method compiles.
def expandWindowedTileIr(baseIr: Array[Any]): Array[Any] = {
  val flattenedIr = windowedRowAggregator.init
  var irPos = 0      // write cursor into the windowed (expanded) IR
  var bucketPos = 0  // read cursor into the un-windowed base IR
  groupBy.aggregations.asScala.foreach {
    aggr =>
      // null buckets / windows mean "a single unbucketed / unbounded entry"
      val buckets = Option(aggr.buckets)
        .map(_.toScala)
        .getOrElse(Seq(null))
      val windows = Option(aggr.windows)
        .map(_.toScala)
        .getOrElse(Seq(WindowUtils.Unbounded))
      // for each aggregation we have 1/more buckets and 1/more windows;
      // iterate over the baseIr and clone a given counter's value n times,
      // where n is the number of windows for that counter
      for (_ <- buckets) {
        for (_ <- windows) {
          // clone so the expanded slots don't alias the same mutable IR object
          flattenedIr(irPos) = rowAggregator.columnAggregators(bucketPos).clone(baseIr(bucketPos))
          irPos += 1
        }
        bucketPos += 1
      }
  }
  flattenedIr
}
}
106 changes: 73 additions & 33 deletions online/src/test/scala/ai/chronon/online/TileCodecTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,35 +8,38 @@ import scala.collection.JavaConverters._
class TileCodecTest {
private val histogram = Map[String, Int]("A" -> 3, "B" -> 2).asJava

private val aggregationsAndExpected: Array[(Aggregation, Any)] = Array(
Builders.Aggregation(Operation.AVERAGE, "views", Seq(new Window(1, TimeUnit.DAYS))) -> 16.0,
Builders.Aggregation(Operation.AVERAGE, "rating", Seq(new Window(1, TimeUnit.DAYS))) -> 4.0,
private val aggregationsAndExpected: Array[(Aggregation, Seq[Any])] = Array(
Builders.Aggregation(Operation.AVERAGE, "views", Seq(new Window(1, TimeUnit.DAYS))) -> Seq(16.0),
Builders.Aggregation(Operation.AVERAGE, "rating", Seq(new Window(1, TimeUnit.DAYS))) -> Seq(4.0),

Builders.Aggregation(Operation.SUM, "rating", Seq(new Window(1, TimeUnit.DAYS))) -> 12.0f,
Builders.Aggregation(Operation.SUM, "rating", Seq(new Window(7, TimeUnit.DAYS))) -> 12.0f,
Builders.Aggregation(Operation.SUM, "rating", Seq(new Window(1, TimeUnit.DAYS), new Window(7, TimeUnit.DAYS))) -> Seq(12.0f, 12.0f),

Builders.Aggregation(Operation.UNIQUE_COUNT, "title", Seq(new Window(1, TimeUnit.DAYS))) -> 3L,
Builders.Aggregation(Operation.UNIQUE_COUNT, "title", Seq(new Window(7, TimeUnit.DAYS))) -> 3L,
Builders.Aggregation(Operation.UNIQUE_COUNT, "title", Seq(new Window(1, TimeUnit.DAYS), new Window(7, TimeUnit.DAYS))) -> Seq(3L, 3L),

Builders.Aggregation(Operation.LAST, "title", Seq(new Window(1, TimeUnit.DAYS))) -> "C",
Builders.Aggregation(Operation.LAST, "title", Seq(new Window(7, TimeUnit.DAYS))) -> "C",
Builders.Aggregation(Operation.LAST, "title", Seq(new Window(1, TimeUnit.DAYS), new Window(7, TimeUnit.DAYS))) -> Seq("C", "C"),

Builders.Aggregation(Operation.LAST_K, "title", Seq(new Window(1, TimeUnit.DAYS)), argMap = Map("k" -> "2")) -> List("C", "B").asJava,
Builders.Aggregation(Operation.LAST_K, "title", Seq(new Window(7, TimeUnit.DAYS)), argMap = Map("k" -> "2")) -> List("C", "B").asJava,
Builders.Aggregation(Operation.LAST_K, "title", Seq(new Window(1, TimeUnit.DAYS), new Window(7, TimeUnit.DAYS)), argMap = Map("k" -> "2")) -> Seq(List("C", "B").asJava, List("C", "B").asJava),

Builders.Aggregation(Operation.TOP_K, "title", Seq(new Window(1, TimeUnit.DAYS)), argMap = Map("k" -> "1")) -> List("C").asJava,
Builders.Aggregation(Operation.TOP_K, "title", Seq(new Window(7, TimeUnit.DAYS)), argMap = Map("k" -> "1")) -> List("C").asJava,
Builders.Aggregation(Operation.TOP_K, "title", Seq(new Window(1, TimeUnit.DAYS), new Window(7, TimeUnit.DAYS)), argMap = Map("k" -> "1")) -> Seq(List("C").asJava, List("C").asJava),

Builders.Aggregation(Operation.MIN, "title", Seq(new Window(1, TimeUnit.DAYS))) -> "A",
Builders.Aggregation(Operation.MIN, "title", Seq(new Window(7, TimeUnit.DAYS))) -> "A",
Builders.Aggregation(Operation.MIN, "title", Seq(new Window(1, TimeUnit.DAYS), new Window(7, TimeUnit.DAYS))) -> Seq("A", "A"),

Builders.Aggregation(Operation.APPROX_UNIQUE_COUNT, "title", Seq(new Window(1, TimeUnit.DAYS))) -> 3L,
Builders.Aggregation(Operation.APPROX_UNIQUE_COUNT, "title", Seq(new Window(7, TimeUnit.DAYS))) -> 3L,
Builders.Aggregation(Operation.APPROX_UNIQUE_COUNT, "title", Seq(new Window(1, TimeUnit.DAYS), new Window(7, TimeUnit.DAYS))) -> Seq(3L, 3L),

Builders.Aggregation(Operation.HISTOGRAM, "hist_input", Seq(new Window(1, TimeUnit.DAYS)), argMap = Map("k" -> "2")) -> histogram,
Builders.Aggregation(Operation.HISTOGRAM, "hist_input", Seq(new Window(7, TimeUnit.DAYS)), argMap = Map("k" -> "2")) -> histogram
Builders.Aggregation(Operation.HISTOGRAM, "hist_input", Seq(new Window(1, TimeUnit.DAYS), new Window(7, TimeUnit.DAYS)), argMap = Map("k" -> "2")) -> Seq(histogram, histogram)
)

private val bucketedAggregations: Array[Aggregation] = Array(
Builders.Aggregation(
operation = Operation.AVERAGE,
inputColumn = "views",
buckets = Seq("title"),
windows = Seq(new Window(1, TimeUnit.DAYS), new Window(7, TimeUnit.DAYS))
)
)
private val expectedBucketResult = Map("A" -> 4.0, "B" -> 40.0, "C" -> 4.0).asJava
private val expectedBucketedResults = Seq(expectedBucketResult, expectedBucketResult)

private val schema = List(
"created" -> LongType,
"views" -> IntType,
Expand All @@ -45,45 +48,82 @@ class TileCodecTest {
"hist_input" -> ListType(StringType)
)

// Test helper: builds an ArrayRow whose columns follow the fixture schema
// (created, views, rating, title, hist_input), stamped with the given event time.
def createRow(ts: Long, views: Int, rating: Float, title: String, histInput: Seq[String]): Row = {
  val columnValues: Array[Any] = Array(ts, views, rating, title, histInput)
  new ArrayRow(columnValues, ts)
}

@Test
def testTileCodecIrSerRoundTrip(): Unit = {
val groupByMetadata = Builders.MetaData(name = "my_group_by")
val (aggregations, expectedVals) = aggregationsAndExpected.unzip
val expectedFlattenedVals = expectedVals.flatten
val groupBy = Builders.GroupBy(metaData = groupByMetadata, aggregations = aggregations)
val rowAggregator = TileCodec.buildRowAggregator(groupBy, schema)
val rowIR = rowAggregator.init
val tileCodec = new TileCodec(rowAggregator, groupBy)
val tileCodec = new TileCodec(groupBy, schema)
val rowIR = tileCodec.rowAggregator.init

val originalIsComplete = true
val rows = Seq(
createRow(1519862399984L, 4, 4.0f, "A", Seq("D", "A", "B", "A")),
createRow(1519862399984L, 40, 5.0f, "B", Seq()),
createRow(1519862399988L, 4, 3.0f, "C", Seq("A", "B", "C"))
)
rows.foreach(row => rowAggregator.update(rowIR, row))
rows.foreach(row => tileCodec.rowAggregator.update(rowIR, row))
val bytes = tileCodec.makeTileIr(rowIR, originalIsComplete)
assert(bytes.length > 0)

val (deserPayload, isComplete) = tileCodec.decodeTileIr(bytes)
assert(isComplete == originalIsComplete)

// lets finalize the payload intermediate results and verify things
val finalResults = rowAggregator.finalize(deserPayload)
expectedVals.zip(finalResults).zip(rowAggregator.outputSchema.map(_._1)).foreach {
val finalResults = tileCodec.windowedRowAggregator.finalize(deserPayload)
assertEquals(expectedFlattenedVals.length, finalResults.length)

// we use a windowed row aggregator for the final results as we want the final flattened results
val windowedRowAggregator = TileCodec.buildWindowedRowAggregator(groupBy, schema)
expectedFlattenedVals.zip(finalResults).zip(windowedRowAggregator.outputSchema.map(_._1)).foreach {
case ((expected, actual), name) =>
println(s"Checking: $name")
assertEquals(expected, actual)
}
}

def createRow(ts: Long, views: Int, rating: Float, title: String, histInput: Seq[String]): Row = {
val values: Array[(String, Any)] = Array(
"created" -> ts,
"views" -> views,
"rating" -> rating,
"title" -> title,
"hist_input" -> histInput
@Test
def testTileCodecIrSerRoundTrip_WithBuckets(): Unit = {
val groupByMetadata = Builders.MetaData(name = "my_group_by")
val groupBy = Builders.GroupBy(metaData = groupByMetadata, aggregations = bucketedAggregations)
val tileCodec = new TileCodec(groupBy, schema)
val rowIR = tileCodec.rowAggregator.init

val originalIsComplete = true
val rows = Seq(
createRow(1519862399984L, 4, 4.0f, "A", Seq("D", "A", "B", "A")),
createRow(1519862399984L, 40, 5.0f, "B", Seq()),
createRow(1519862399988L, 4, 3.0f, "C", Seq("A", "B", "C"))
)
new ArrayRow(values.map(_._2), ts)
rows.foreach(row => tileCodec.rowAggregator.update(rowIR, row))
val bytes = tileCodec.makeTileIr(rowIR, originalIsComplete)
assert(bytes.length > 0)

val (deserPayload, isComplete) = tileCodec.decodeTileIr(bytes)
assert(isComplete == originalIsComplete)

// lets finalize the payload intermediate results and verify things
val finalResults = tileCodec.windowedRowAggregator.finalize(deserPayload)
assertEquals(expectedBucketedResults.size, finalResults.length)

// we use a windowed row aggregator for the final results as we want the final flattened results
val windowedRowAggregator = TileCodec.buildWindowedRowAggregator(groupBy, schema)
expectedBucketedResults.zip(finalResults).zip(windowedRowAggregator.outputSchema.map(_._1)).foreach {
case ((expected, actual), name) =>
println(s"Checking: $name")
assertEquals(expected, actual)
}
}
}