Add tiled implementation of the Flink app (#627)
* Add custom triggers

* Move triggers

* Add KeySelector

* Comments

* Rename tiling package to window

* WIP runTiledGroupByJob

* Comment-out AsyncKVStoreWriterTest.scala

* Add ChrononFlinkRowAggregators

* Refactor AvroCodec slightly

* Add TiledAvroCodecFn

* Add LateEventCounter

* Finish runTiledGroupByJob

* Add ChrononFlinkRowAggregationFunctionTest

* Add missing @test decorator

* Add KeySelector tests

* Add e2e tiled test

* Scalafmt

* Comments

* Uncomment AsyncKVStoreWriterTest

* Remove slot sharing so that test finally halts

* Tweak strings in key selector test

* Rename files, change comments

* keyToBytes in process function should convert to array first

* Refactor tiled Flink test, use watermark strategy

* Improve e2e test so that we check actual tile IRs

* rm debug=true

* Use log4j

* Remove comment

* Minor clean up, change comments

* scalafmt

* Add missing getSmallestWindowResolutionInMillis

* Add missing tiledCodec

* Enable debug logs in tests

* Info instead of debug

* Fix lack of isolation in test sink

* Make BaseAvroCodecFn abstract

* Update FlinkJob comments

* Comment

* Move getSmallestWindowResolutionInMillis to GroupByOps

* Use new GroupByOps method, fix mistake

* Revert "Move getSmallestWindowResolutionInMillis to GroupByOps"

* Use toScala, use multiline strings

* Use logger.debug

* Scalafmt

* Fix compile error
caiocamatta-stripe authored Feb 28, 2024
1 parent c93ba28 commit dcb3750
Showing 14 changed files with 1,146 additions and 105 deletions.
@@ -17,7 +17,10 @@
package ai.chronon.aggregator.windowing

import ai.chronon.api.Extensions.{WindowOps, WindowUtils}
import ai.chronon.api.{TimeUnit, Window}
import ai.chronon.api.{GroupBy, TimeUnit, Window}

import scala.util.ScalaJavaConversions.ListOps
import scala.util.ScalaVersionSpecificCollectionsConverter.convertJavaListToScala

trait Resolution extends Serializable {
// For a given window what is the resolution of the tail
@@ -57,3 +60,19 @@ object DailyResolution extends Resolution {

val hopSizes: Array[Long] = Array(WindowUtils.Day.millis)
}

object ResolutionUtils {

/**
* Find the smallest tail window resolution in a GroupBy. Returns None if the GroupBy does not define any windows.
* The window resolutions are: 5 min for a window < 12 hrs, 1 hr for a window < 12 days, and 1 day for anything longer.
* */
def getSmallestWindowResolutionInMillis(groupBy: GroupBy): Option[Long] =
Option(
groupBy.aggregations.toScala.toArray
.flatMap(aggregation =>
if (aggregation.windows != null) aggregation.windows.toScala
else None)
.map(FiveMinuteResolution.calculateTailHop)
).filter(_.nonEmpty).map(_.min)
}
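
For orientation, a minimal usage sketch of the new helper (the example GroupBy and its windows are assumptions, not part of this diff):

// Hypothetical usage: assume `groupBy` defines a 1-hour window and a 30-day window.
val tileSizeMs: Option[Long] = ResolutionUtils.getSmallestWindowResolutionInMillis(groupBy)
// => Some(5 * 60 * 1000L): the 1-hour window maps to a 5-minute tail hop, the 30-day window
//    to a 1-day hop, and the smallest hop wins. A GroupBy with no windows at all yields None.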
@@ -75,6 +75,8 @@ class AsyncKVStoreWriter(onlineImpl: Api, featureGroupName: String)
// The context used for the future callbacks
implicit lazy val executor: ExecutionContext = AsyncKVStoreWriter.ExecutionContextInstance

// One may want to use different KV stores depending on whether tiling is on.
// The untiled version of Chronon works on "append" store semantics, and the tiled version works on "overwrite".
protected def getKVStore: KVStore = {
onlineImpl.genKvStore
}
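
To illustrate the comment on getKVStore above, a hedged sketch of how an implementation could swap in a different store when tiling is enabled (the TiledAsyncKVStoreWriter name is hypothetical; only AsyncKVStoreWriter, Api, KVStore, and genKvStore appear in this diff):

// Hypothetical subclass; assumes the Api implementation can provide a store with overwrite semantics.
class TiledAsyncKVStoreWriter(onlineImpl: Api, featureGroupName: String)
    extends AsyncKVStoreWriter(onlineImpl, featureGroupName) {
  override protected def getKVStore: KVStore = {
    // In practice this could return a differently-configured KVStore (overwrite rather than append);
    // genKvStore is reused here purely as a placeholder.
    onlineImpl.genKvStore
  }
}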
102 changes: 85 additions & 17 deletions flink/src/main/scala/ai/chronon/flink/AvroCodecFn.scala
@@ -3,6 +3,7 @@ package ai.chronon.flink
import org.slf4j.LoggerFactory
import ai.chronon.api.Extensions.GroupByOps
import ai.chronon.api.{Constants, DataModel, Query, StructType => ChrononStructType}
import ai.chronon.flink.window.TimestampedTile
import ai.chronon.online.{AvroConversions, GroupByServingInfoParsed}
import ai.chronon.online.KVStore.PutRequest
import org.apache.flink.api.common.functions.RichFlatMapFunction
@@ -13,28 +14,32 @@ import org.apache.flink.util.Collector
import scala.jdk.CollectionConverters._

/**
* A Flink function that is responsible for converting the Spark expr eval output and converting that to a form
* that can be written out to the KV store (PutRequest object)
* @param groupByServingInfoParsed The GroupBy we are working with
* @tparam T The input data type
* Base class for the Avro conversion Flink operator.
*
* Subclasses should override the RichFlatMapFunction methods (flatMap) and groupByServingInfoParsed.
*
* @tparam IN The input data type which contains the data to be avro-converted to bytes.
* @tparam OUT The output data type (generally a PutRequest).
*/
case class AvroCodecFn[T](groupByServingInfoParsed: GroupByServingInfoParsed)
extends RichFlatMapFunction[Map[String, Any], PutRequest] {
@transient lazy val logger = LoggerFactory.getLogger(getClass)
sealed abstract class BaseAvroCodecFn[IN, OUT] extends RichFlatMapFunction[IN, OUT] {
def groupByServingInfoParsed: GroupByServingInfoParsed

@transient lazy val logger = LoggerFactory.getLogger(getClass)
@transient protected var avroConversionErrorCounter: Counter = _
@transient protected var eventProcessingErrorCounter: Counter =
_ // Shared metric for errors across the entire Flink app.

protected val query: Query = groupByServingInfoParsed.groupBy.streamingSource.get.getEvents.query
protected val streamingDataset: String = groupByServingInfoParsed.groupBy.streamingDataset
protected lazy val query: Query = groupByServingInfoParsed.groupBy.streamingSource.get.getEvents.query
protected lazy val streamingDataset: String = groupByServingInfoParsed.groupBy.streamingDataset

// TODO: update to use constant names that are company specific
protected val timeColumnAlias: String = Constants.TimeColumn
protected val timeColumn: String = Option(query.timeColumn).getOrElse(timeColumnAlias)
protected lazy val timeColumnAlias: String = Constants.TimeColumn
protected lazy val timeColumn: String = Option(query.timeColumn).getOrElse(timeColumnAlias)

protected val (keyToBytes, valueToBytes): (Any => Array[Byte], Any => Array[Byte]) =
protected lazy val (keyToBytes, valueToBytes): (Any => Array[Byte], Any => Array[Byte]) =
getKVSerializers(groupByServingInfoParsed)
protected val (keyColumns, valueColumns): (Array[String], Array[String]) = getKVColumns
protected val extraneousRecord: Any => Array[Any] = {
protected lazy val (keyColumns, valueColumns): (Array[String], Array[String]) = getKVColumns
protected lazy val extraneousRecord: Any => Array[Any] = {
case x: Map[_, _] if x.keys.forall(_.isInstanceOf[String]) =>
x.flatMap { case (key, value) => Array(key, value) }.toArray
}
@@ -70,6 +75,16 @@ case class AvroCodecFn[T](groupByServingInfoParsed: GroupByServingInfoParsed)
val valueColumns = groupByServingInfoParsed.groupBy.aggregationInputs ++ additionalColumns
(keyColumns, valueColumns)
}
}

/**
* A Flink function that is responsible for converting the Spark expr eval output and converting that to a form
* that can be written out to the KV store (PutRequest object)
* @param groupByServingInfoParsed The GroupBy we are working with
* @tparam T The input data type
*/
case class AvroCodecFn[T](groupByServingInfoParsed: GroupByServingInfoParsed)
extends BaseAvroCodecFn[Map[String, Any], PutRequest] {

override def open(configuration: Configuration): Unit = {
super.open(configuration)
@@ -87,16 +102,69 @@ case class AvroCodecFn[T](groupByServingInfoParsed: GroupByServingInfoParsed)
} catch {
case e: Exception =>
// To improve availability, we don't rethrow the exception. We just drop the event
// and track the errors in a metric. If there are too many errors we'll get alerted/paged.
// and track the errors in a metric. Alerts should be set up on this metric.
logger.error(s"Error converting to Avro bytes - $e")
eventProcessingErrorCounter.inc()
avroConversionErrorCounter.inc()
}

def avroConvertMapToPutRequest(in: Map[String, Any]): PutRequest = {
val tsMills = in(timeColumnAlias).asInstanceOf[Long]
val keyBytes = keyToBytes(keyColumns.map(in.get(_).get))
val valueBytes = valueToBytes(valueColumns.map(in.get(_).get))
val keyBytes = keyToBytes(keyColumns.map(in(_)))
val valueBytes = valueToBytes(valueColumns.map(in(_)))
PutRequest(keyBytes, valueBytes, streamingDataset, Some(tsMills))
}
}

/**
* A Flink function that is responsible for converting an array of pre-aggregates (aka a tile) to a form
* that can be written out to the KV store (PutRequest object).
*
* @param groupByServingInfoParsed The GroupBy we are working with
* @tparam T The input data type
*/
case class TiledAvroCodecFn[T](groupByServingInfoParsed: GroupByServingInfoParsed)
extends BaseAvroCodecFn[TimestampedTile, PutRequest] {
override def open(configuration: Configuration): Unit = {
super.open(configuration)
val metricsGroup = getRuntimeContext.getMetricGroup
.addGroup("chronon")
.addGroup("feature_group", groupByServingInfoParsed.groupBy.getMetaData.getName)
avroConversionErrorCounter = metricsGroup.counter("avro_conversion_errors")
eventProcessingErrorCounter = metricsGroup.counter("event_processing_error")
}
override def close(): Unit = super.close()

override def flatMap(value: TimestampedTile, out: Collector[PutRequest]): Unit =
try {
out.collect(avroConvertTileToPutRequest(value))
} catch {
case e: Exception =>
// To improve availability, we don't rethrow the exception. We just drop the event
// and track the errors in a metric. Alerts should be set up on this metric.
logger.error(s"Error converting to Avro bytes - ", e)
eventProcessingErrorCounter.inc()
avroConversionErrorCounter.inc()
}

def avroConvertTileToPutRequest(in: TimestampedTile): PutRequest = {
val tsMills = in.latestTsMillis

// 'keys' is a map of (key name in schema -> key value), e.g. Map("card_number" -> "4242-4242-4242-4242")
// We convert to AnyRef because Chronon expects an AnyRef (for scala <> java interoperability reasons).
val keys: Map[String, AnyRef] = keyColumns.zip(in.keys.map(_.asInstanceOf[AnyRef])).toMap
val keyBytes = keyToBytes(in.keys.toArray)
val valueBytes = in.tileBytes

logger.debug(
s"""
|Avro converting tile to PutRequest - tile=${in}
|groupBy=${groupByServingInfoParsed.groupBy.getMetaData.getName} tsMills=$tsMills keys=$keys
|keyBytes=${java.util.Base64.getEncoder.encodeToString(keyBytes)}
|valueBytes=${java.util.Base64.getEncoder.encodeToString(valueBytes)}
|streamingDataset=$streamingDataset""".stripMargin
)

PutRequest(keyBytes, valueBytes, streamingDataset, Some(tsMills))
}
}
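
The TimestampedTile class itself lives in the new ai.chronon.flink.window package, which is not part of the loaded portion of this diff. Based on how it is used above (in.keys, in.tileBytes, in.latestTsMillis), its shape is roughly the following sketch; the field order and exact types are inferred, not confirmed by this page:

// Inferred shape, for reading convenience only.
case class TimestampedTile(
    keys: List[Any],        // entity key values, in the same order as keyColumns
    tileBytes: Array[Byte], // Avro-encoded pre-aggregates (IRs) for the tile
    latestTsMillis: Long    // event time of the latest element folded into the tile
)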
149 changes: 136 additions & 13 deletions flink/src/main/scala/ai/chronon/flink/FlinkJob.scala
@@ -1,25 +1,31 @@
package ai.chronon.flink

import ai.chronon.aggregator.windowing.ResolutionUtils
import ai.chronon.api.{DataType}
import ai.chronon.api.Extensions.{GroupByOps, SourceOps}
import ai.chronon.online.GroupByServingInfoParsed
import ai.chronon.flink.window.{
AlwaysFireOnElementTrigger,
FlinkRowAggProcessFunction,
FlinkRowAggregationFunction,
KeySelector,
TimestampedTile
}
import ai.chronon.online.{GroupByServingInfoParsed, SparkConversions}
import ai.chronon.online.KVStore.PutRequest
import org.apache.flink.streaming.api.scala.{DataStream, StreamExecutionEnvironment}
import org.apache.flink.streaming.api.scala.{DataStream, OutputTag, StreamExecutionEnvironment}
import org.apache.spark.sql.Encoder
import org.apache.flink.api.scala._
import org.apache.flink.streaming.api.functions.async.RichAsyncFunction
import org.apache.flink.streaming.api.windowing.assigners.{TumblingEventTimeWindows, WindowAssigner}
import org.apache.flink.streaming.api.windowing.time.Time
import org.apache.flink.streaming.api.windowing.windows.TimeWindow
import org.slf4j.LoggerFactory

/**
* Flink job that processes a single streaming GroupBy and writes out the results
* (raw events in untiled, pre-aggregates in case of tiled) to the KV store.
* At a high level, the operators are structured as follows:
* Kafka source -> Spark expression eval -> Avro conversion -> KV store writer
* Kafka source - Reads objects of type T (specific case class, Thrift / Proto) from a Kafka topic
* Spark expression eval - Evaluates the Spark SQL expression in the GroupBy and projects and filters the input data
* Avro conversion - Converts the Spark expr eval output to a form that can be written out to the KV store (PutRequest object)
* KV store writer - Writes the PutRequest objects to the KV store using the AsyncDataStream API
* Flink job that processes a single streaming GroupBy and writes out the results to the KV store.
*
* In the untiled version there are no-shuffles and thus this ends up being a single node in the Flink DAG
* (with the above 4 operators and parallelism as injected by the user)
* There are two versions of the job, tiled and untiled. The untiled version writes out raw events while the tiled
* version writes out pre-aggregates. See the `runGroupByJob` and `runTiledGroupByJob` methods for more details.
*
* @param eventSrc - Provider of a Flink Datastream[T] for the given topic and feature group
* @param sinkFn - Async Flink writer function to help us write to the KV store
@@ -33,10 +39,13 @@ class FlinkJob[T](eventSrc: FlinkSource[T],
groupByServingInfoParsed: GroupByServingInfoParsed,
encoder: Encoder[T],
parallelism: Int) {
private[this] val logger = LoggerFactory.getLogger(getClass)

val featureGroupName: String = groupByServingInfoParsed.groupBy.getMetaData.getName
logger.info(f"Creating Flink job. featureGroupName=${featureGroupName}")

protected val exprEval: SparkExpressionEvalFn[T] =
new SparkExpressionEvalFn[T](encoder, groupByServingInfoParsed.groupBy)
val featureGroupName: String = groupByServingInfoParsed.groupBy.getMetaData.getName

if (groupByServingInfoParsed.groupBy.streamingSource.isEmpty) {
throw new IllegalArgumentException(
@@ -47,7 +56,25 @@ class FlinkJob[T](eventSrc: FlinkSource[T],
// The source of our Flink application is a Kafka topic
val kafkaTopic: String = groupByServingInfoParsed.groupBy.streamingSource.get.topic

/**
* The "untiled" version of the Flink app.
*
* At a high level, the operators are structured as follows:
* Kafka source -> Spark expression eval -> Avro conversion -> KV store writer
* Kafka source - Reads objects of type T (specific case class, Thrift / Proto) from a Kafka topic
* Spark expression eval - Evaluates the Spark SQL expression in the GroupBy and projects and filters the input data
* Avro conversion - Converts the Spark expr eval output to a form that can be written out to the KV store
* (PutRequest object)
* KV store writer - Writes the PutRequest objects to the KV store using the AsyncDataStream API
*
* In this untiled version, there are no shuffles and thus this ends up being a single node in the Flink DAG
* (with the above 4 operators and parallelism as injected by the user).
*/
def runGroupByJob(env: StreamExecutionEnvironment): DataStream[WriteResponse] = {
logger.info(
f"Running Flink job for featureGroupName=${featureGroupName}, kafkaTopic=${kafkaTopic}. " +
f"Tiling is disabled.")

val sourceStream: DataStream[T] =
eventSrc
.getDataStream(kafkaTopic, featureGroupName)(env, parallelism)
@@ -70,4 +97,100 @@ class FlinkJob[T](eventSrc: FlinkSource[T],
featureGroupName
)
}

/**
* The "tiled" version of the Flink app.
*
* The operators are structured as follows:
* 1. Kafka source - Reads objects of type T (specific case class, Thrift / Proto) from a Kafka topic
* 2. Spark expression eval - Evaluates the Spark SQL expression in the GroupBy and projects and filters the input
* data
* 3. Window/tiling - This window aggregates incoming events, keeps track of the IRs, and sends them forward so
* they are written out to the KV store
* 4. Avro conversion - Finishes converting the output of the window (the IRs) to a form that can be written out
* to the KV store (PutRequest object)
* 5. KV store writer - Writes the PutRequest objects to the KV store using the AsyncDataStream API
*
* The window causes a split in the Flink DAG, so there are two nodes, (1+2) and (3+4+5).
*/
def runTiledGroupByJob(env: StreamExecutionEnvironment): DataStream[WriteResponse] = {
logger.info(
f"Running Flink job for featureGroupName=${featureGroupName}, kafkaTopic=${kafkaTopic}. " +
f"Tiling is enabled.")

val tilingWindowSizeInMillis: Option[Long] =
ResolutionUtils.getSmallestWindowResolutionInMillis(groupByServingInfoParsed.groupBy)

val sourceStream: DataStream[T] =
eventSrc
.getDataStream(kafkaTopic, featureGroupName)(env, parallelism)

val sparkExprEvalDS: DataStream[Map[String, Any]] = sourceStream
.flatMap(exprEval)
.uid(s"spark-expr-eval-flatmap-$featureGroupName")
.name(s"Spark expression eval for $featureGroupName")
.setParallelism(sourceStream.parallelism) // Use same parallelism as previous operator

val inputSchema: Seq[(String, DataType)] =
exprEval.getOutputSchema.fields
.map(field => (field.name, SparkConversions.toChrononType(field.name, field.dataType)))
.toSeq

val window = TumblingEventTimeWindows
.of(Time.milliseconds(tilingWindowSizeInMillis.get))
.asInstanceOf[WindowAssigner[Map[String, Any], TimeWindow]]

// An alternative to AlwaysFireOnElementTrigger can be used: BufferedProcessingTimeTrigger.
// The latter will buffer writes so they happen at most every X milliseconds per GroupBy & key.
val trigger = new AlwaysFireOnElementTrigger()

// We use Flink "Side Outputs" to track any late events that aren't computed.
val tilingLateEventsTag = OutputTag[Map[String, Any]]("tiling-late-events")

// The tiling operator works the following way:
// 1. Input: Spark expression eval (previous operator)
// 2. Key by the entity key(s) defined in the groupby
// 3. Window by a tumbling window
// 4. Use our custom trigger that will "FIRE" on every element
// 5. the AggregationFunction merges each incoming element with the current IRs which are kept in state
// - Each time a "FIRE" is triggered (i.e. on every event), getResult() is called and the current IRs are emitted
// 6. A process window function does additional processing each time the AggregationFunction emits results
// - The only purpose of this window function is to mark tiles as closed so we can do client-side caching in SFS
// 7. Output: TimestampedTile, containing the current IRs (Avro encoded) and the timestamp of the current element
val tilingDS: DataStream[TimestampedTile] =
sparkExprEvalDS
.keyBy(KeySelector.getKeySelectionFunction(groupByServingInfoParsed.groupBy))
.window(window)
.trigger(trigger)
.sideOutputLateData(tilingLateEventsTag)
.aggregate(
// See Flink's "ProcessWindowFunction with Incremental Aggregation"
preAggregator = new FlinkRowAggregationFunction(groupByServingInfoParsed.groupBy, inputSchema),
windowFunction = new FlinkRowAggProcessFunction(groupByServingInfoParsed.groupBy, inputSchema)
)
.uid(s"tiling-01-$featureGroupName")
.name(s"Tiling for $featureGroupName")
.setParallelism(sourceStream.parallelism)

// Track late events
val sideOutputStream: DataStream[Map[String, Any]] =
tilingDS
.getSideOutput(tilingLateEventsTag)
.flatMap(new LateEventCounter(featureGroupName))
.uid(s"tiling-side-output-01-$featureGroupName")
.name(s"Tiling Side Output Late Data for $featureGroupName")
.setParallelism(sourceStream.parallelism)

val putRecordDS: DataStream[PutRequest] = tilingDS
.flatMap(new TiledAvroCodecFn[T](groupByServingInfoParsed))
.uid(s"avro-conversion-01-$featureGroupName")
.name(s"Avro conversion for $featureGroupName")
.setParallelism(sourceStream.parallelism)

AsyncKVStoreWriter.withUnorderedWaits(
putRecordDS,
sinkFn,
featureGroupName
)
}
}
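
For context, a hedged wiring sketch of how a FlinkJob might be constructed and launched. MyEvent and the values bound to ??? are placeholder assumptions; only FlinkJob, AsyncKVStoreWriter, the constructor parameter names, and the run methods come from this diff:

// Hypothetical driver code, using named arguments because the full constructor is collapsed above.
val env = StreamExecutionEnvironment.getExecutionEnvironment
val servingInfo: GroupByServingInfoParsed = ??? // fetched from the Chronon metadata store in practice
val myKafkaSource: FlinkSource[MyEvent] = ???   // e.g. a Kafka-backed FlinkSource implementation
val myEncoder: Encoder[MyEvent] = ???           // Spark Encoder for the assumed MyEvent case class
val myApi: Api = ???                            // company-specific ai.chronon.online.Api implementation

val job = new FlinkJob[MyEvent](
  eventSrc = myKafkaSource,
  sinkFn = new AsyncKVStoreWriter(myApi, servingInfo.groupBy.getMetaData.getName),
  groupByServingInfoParsed = servingInfo,
  encoder = myEncoder,
  parallelism = 4
)
job.runTiledGroupByJob(env) // or job.runGroupByJob(env) for the untiled path
env.execute(s"Chronon Flink job for ${job.featureGroupName}")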
2 changes: 2 additions & 0 deletions flink/src/main/scala/ai/chronon/flink/FlinkSource.scala
@@ -6,6 +6,8 @@ abstract class FlinkSource[T] extends Serializable {

/**
* Return a Flink DataStream for the given topic and feature group.
*
* When implementing a source, you should also make a conscious decision about your allowed lateness strategy.
*/
def getDataStream(topic: String, groupName: String)(
env: StreamExecutionEnvironment,
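
Since the new doc comment asks implementers to make an explicit allowed-lateness decision, here is a minimal sketch of a concrete source using a bounded out-of-orderness watermark. MyEvent, MyEventDeserializationSchema, the one-minute bound, and the Kafka builder configuration are assumptions for illustration; only the FlinkSource contract comes from this diff.

import java.nio.charset.StandardCharsets
import java.time.Duration

import org.apache.flink.api.common.eventtime.{SerializableTimestampAssigner, WatermarkStrategy}
import org.apache.flink.api.common.serialization.AbstractDeserializationSchema
import org.apache.flink.api.scala._
import org.apache.flink.connector.kafka.source.KafkaSource
import org.apache.flink.streaming.api.scala.{DataStream, StreamExecutionEnvironment}

case class MyEvent(id: String, createdMillis: Long)

// Placeholder deserializer; a real one would decode Thrift / Proto / Avro payloads.
class MyEventDeserializationSchema extends AbstractDeserializationSchema[MyEvent] {
  override def deserialize(message: Array[Byte]): MyEvent =
    MyEvent(id = new String(message, StandardCharsets.UTF_8), createdMillis = System.currentTimeMillis())
}

class MyEventKafkaSource(bootstrapServers: String) extends FlinkSource[MyEvent] {
  override def getDataStream(topic: String, groupName: String)(
      env: StreamExecutionEnvironment,
      parallelism: Int): DataStream[MyEvent] = {
    val kafkaSource = KafkaSource
      .builder[MyEvent]()
      .setBootstrapServers(bootstrapServers)
      .setTopics(topic)
      .setValueOnlyDeserializer(new MyEventDeserializationSchema)
      .build()

    // Allowed-lateness decision: events may arrive up to one minute out of order; anything later
    // is dropped from the main stream and, in the tiled job, lands in the late-events side output.
    val watermarks = WatermarkStrategy
      .forBoundedOutOfOrderness[MyEvent](Duration.ofMinutes(1))
      .withTimestampAssigner(new SerializableTimestampAssigner[MyEvent] {
        override def extractTimestamp(event: MyEvent, recordTimestamp: Long): Long = event.createdMillis
      })

    env
      .fromSource(kafkaSource, watermarks, s"Kafka source for $groupName")
      .setParallelism(parallelism)
  }
}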