Add tiled implementation of the Flink app #627
@@ -75,6 +75,8 @@ class AsyncKVStoreWriter(onlineImpl: Api, featureGroupName: String)

  // The context used for the future callbacks
  implicit lazy val executor: ExecutionContext = AsyncKVStoreWriter.ExecutionContextInstance

  // One may want to use different KV stores depending on whether tiling is on.
  // The untiled version of Chronon works on "append" store semantics, and the tiled version works on "overwrite".
  protected def getKVStore: KVStore = {
    onlineImpl.genKvStore
  }
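The new comment notes that untiled Chronon relies on "append" semantics while the tiled version relies on "overwrite". As a minimal sketch of how that choice could be wired in, a subclass might route to a different store when tiling is enabled; the class name and the genTiledKvStore method below are hypothetical, not part of this PR or of the Api trait:

// Illustrative sketch only: `genTiledKvStore` is an assumed method name.
class TilingAwareKVStoreWriter(onlineImpl: Api, featureGroupName: String, tilingEnabled: Boolean)
    extends AsyncKVStoreWriter(onlineImpl, featureGroupName) {

  // Tiled writes overwrite the tile for a (key, tile) pair; untiled writes append raw events.
  override protected def getKVStore: KVStore =
    if (tilingEnabled) onlineImpl.genTiledKvStore // hypothetical store with "overwrite" semantics
    else onlineImpl.genKvStore                    // existing store with "append" semantics
}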
@@ -3,6 +3,7 @@ package ai.chronon.flink

import org.slf4j.LoggerFactory
import ai.chronon.api.Extensions.GroupByOps
import ai.chronon.api.{Constants, DataModel, Query, StructType => ChrononStructType}
import ai.chronon.flink.window.TimestampedTile
import ai.chronon.online.{AvroConversions, GroupByServingInfoParsed}
import ai.chronon.online.KVStore.PutRequest
import org.apache.flink.api.common.functions.RichFlatMapFunction

@@ -12,17 +13,15 @@ import org.apache.flink.util.Collector

import scala.jdk.CollectionConverters._
/**
 * A Flink function that is responsible for converting the Spark expr eval output to a form
 * that can be written out to the KV store (PutRequest object)
 * @param groupByServingInfoParsed The GroupBy we are working with
 * @tparam T The input data type
 */
case class AvroCodecFn[T](groupByServingInfoParsed: GroupByServingInfoParsed)
  extends RichFlatMapFunction[Map[String, Any], PutRequest] {
  @transient lazy val logger = LoggerFactory.getLogger(getClass)

// This utility contains common code for AvroCodecFn and TiledAvroCodecFn

Review comment: lets also highlight what the subclasses need to override / specialize

sealed trait AvroCodecFnUtility {

Review comment: how about a name like
Review comment: to make things explicit, we could make this an abstract class that extends from RichFlatMapFunction (so then subclasses know their contract is to provide groupByServingInfoParsed and fill out the open / flatMap methods)
Review comment: agreed -

  // Should be overridden
  val groupByServingInfoParsed: GroupByServingInfoParsed
  @transient lazy val logger = LoggerFactory.getLogger(getClass)
  @transient protected var avroConversionErrorCounter: Counter = _
  @transient protected var eventProcessingErrorCounter: Counter =
    _ // Shared metric for errors across the entire Flink app.

  protected val query: Query = groupByServingInfoParsed.groupBy.streamingSource.get.getEvents.query
  protected val streamingDataset: String = groupByServingInfoParsed.groupBy.streamingDataset

@@ -70,6 +69,17 @@ case class AvroCodecFn[T](groupByServingInfoParsed: GroupByServingInfoParsed)
    val valueColumns = groupByServingInfoParsed.groupBy.aggregationInputs ++ additionalColumns
    (keyColumns, valueColumns)
  }
}
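A minimal sketch of the abstract-class shape suggested in the review thread above, assuming a hypothetical name BaseAvroCodecFn; the exact fields and final naming are not settled by this diff:

import ai.chronon.api.Extensions.GroupByOps
import ai.chronon.online.GroupByServingInfoParsed
import org.apache.flink.api.common.functions.RichFlatMapFunction
import org.apache.flink.metrics.Counter
import org.slf4j.LoggerFactory

// Illustrative sketch: subclasses provide groupByServingInfoParsed and implement open / flatMap.
abstract class BaseAvroCodecFn[IN, OUT] extends RichFlatMapFunction[IN, OUT] {
  // Contract for subclasses: the GroupBy being served.
  def groupByServingInfoParsed: GroupByServingInfoParsed

  @transient lazy val logger = LoggerFactory.getLogger(getClass)
  @transient protected var avroConversionErrorCounter: Counter = _
  @transient protected var eventProcessingErrorCounter: Counter = _

  // Shared derived fields (query, streamingDataset, key/value columns, codecs) would live here,
  // exactly as in the AvroCodecFnUtility trait above.
  protected lazy val streamingDataset: String = groupByServingInfoParsed.groupBy.streamingDataset
}

// The two codec functions would then differ only in input type and conversion logic, e.g.:
//   case class AvroCodecFn[T](groupByServingInfoParsed: GroupByServingInfoParsed)
//     extends BaseAvroCodecFn[Map[String, Any], PutRequest] { ... }
//   case class TiledAvroCodecFn[T](groupByServingInfoParsed: GroupByServingInfoParsed, debug: Boolean = false)
//     extends BaseAvroCodecFn[TimestampedTile, PutRequest] { ... }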
/**
 * A Flink function that is responsible for converting the Spark expr eval output to a form
 * that can be written out to the KV store (PutRequest object)
 * @param groupByServingInfoParsed The GroupBy we are working with
 * @tparam T The input data type
 */
case class AvroCodecFn[T](groupByServingInfoParsed: GroupByServingInfoParsed)
  extends RichFlatMapFunction[Map[String, Any], PutRequest]
  with AvroCodecFnUtility {

  override def open(configuration: Configuration): Unit = {
    super.open(configuration)

@@ -87,16 +97,70 @@ case class AvroCodecFn[T](groupByServingInfoParsed: GroupByServingInfoParsed)
    } catch {
      case e: Exception =>
        // To improve availability, we don't rethrow the exception. We just drop the event
        // and track the errors in a metric. If there are too many errors we'll get alerted/paged.
        // and track the errors in a metric. Alerts should be set up on this metric.
        logger.error(s"Error converting to Avro bytes - $e")
        eventProcessingErrorCounter.inc()
        avroConversionErrorCounter.inc()
    }

  def avroConvertMapToPutRequest(in: Map[String, Any]): PutRequest = {
    val tsMills = in(timeColumnAlias).asInstanceOf[Long]
    val keyBytes = keyToBytes(keyColumns.map(in.get(_).get))
    val valueBytes = valueToBytes(valueColumns.map(in.get(_).get))
    val keyBytes = keyToBytes(keyColumns.map(in(_)))
    val valueBytes = valueToBytes(valueColumns.map(in(_)))
    PutRequest(keyBytes, valueBytes, streamingDataset, Some(tsMills))
  }
}
/**
 * A Flink function that is responsible for converting an array of pre-aggregates (aka a tile) to a form
 * that can be written out to the KV store (PutRequest object).
 * @param groupByServingInfoParsed The GroupBy we are working with
 * @tparam T The input data type
 */
case class TiledAvroCodecFn[T](groupByServingInfoParsed: GroupByServingInfoParsed, debug: Boolean = false)
  extends RichFlatMapFunction[TimestampedTile, PutRequest]
  with AvroCodecFnUtility {
  override def open(configuration: Configuration): Unit = {
    super.open(configuration)
    val metricsGroup = getRuntimeContext.getMetricGroup
      .addGroup("chronon")
      .addGroup("feature_group", groupByServingInfoParsed.groupBy.getMetaData.getName)
    avroConversionErrorCounter = metricsGroup.counter("avro_conversion_errors")
    eventProcessingErrorCounter = metricsGroup.counter("event_processing_error")
  }
  override def close(): Unit = super.close()

  override def flatMap(value: TimestampedTile, out: Collector[PutRequest]): Unit =
    try {
      out.collect(avroConvertTileToPutRequest(value))
    } catch {
      case e: Exception =>
        // To improve availability, we don't rethrow the exception. We just drop the event
        // and track the errors in a metric. Alerts should be set up on this metric.
        logger.error(s"Error converting to Avro bytes - ", e)
        eventProcessingErrorCounter.inc()
        avroConversionErrorCounter.inc()
    }

  def avroConvertTileToPutRequest(in: TimestampedTile): PutRequest = {
    val tsMills = in.latestTsMillis

    // 'keys' is a map of (key name in schema -> key value), e.g. Map("card_number" -> "4242-4242-4242-4242")
    // We convert to AnyRef because Chronon expects an AnyRef (for scala <> java interoperability reasons).
    val keys: Map[String, AnyRef] = keyColumns.zip(in.keys.map(_.asInstanceOf[AnyRef])).toMap
    val keyBytes = keyToBytes(in.keys.toArray)
    val valueBytes = in.tileBytes

    if (debug) {
      logger.info(
        f"Avro converting tile to PutRequest - tile=${in} " +
          f"groupBy=${groupByServingInfoParsed.groupBy.getMetaData.getName} tsMills=$tsMills keys=$keys " +
          f"keyBytes=${java.util.Base64.getEncoder.encodeToString(keyBytes)} " +
          f"valueBytes=${java.util.Base64.getEncoder.encodeToString(valueBytes)} " +
          f"streamingDataset=$streamingDataset"
      )
    }

    PutRequest(keyBytes, valueBytes, streamingDataset, Some(tsMills))
  }
}
@@ -1,22 +1,37 @@
package ai.chronon.flink

import ai.chronon.aggregator.windowing.ResolutionUtils
import ai.chronon.api.{DataType}
import ai.chronon.api.Extensions.{GroupByOps, SourceOps}
import ai.chronon.online.GroupByServingInfoParsed
import ai.chronon.flink.window.{
  AlwaysFireOnElementTrigger,
  FlinkRowAggProcessFunction,
  FlinkRowAggregationFunction,
  KeySelector,
  TimestampedTile
}
import ai.chronon.online.{GroupByServingInfoParsed, SparkConversions}
import ai.chronon.online.KVStore.PutRequest
import org.apache.flink.streaming.api.scala.{DataStream, StreamExecutionEnvironment}
import org.apache.flink.streaming.api.scala.{DataStream, OutputTag, StreamExecutionEnvironment}
import org.apache.spark.sql.Encoder
import org.apache.flink.api.scala._
import org.apache.flink.streaming.api.functions.async.RichAsyncFunction
import org.apache.flink.streaming.api.windowing.assigners.{TumblingEventTimeWindows, WindowAssigner}
import org.apache.flink.streaming.api.windowing.time.Time
import org.apache.flink.streaming.api.windowing.windows.TimeWindow
import org.slf4j.LoggerFactory
/** | ||
* Flink job that processes a single streaming GroupBy and writes out the results | ||
* (raw events in untiled, pre-aggregates in case of tiled) to the KV store. | ||
* | ||
* At a high level, the operators are structured as follows: | ||
* Kafka source -> Spark expression eval -> Avro conversion -> KV store writer | ||
* Kafka source - Reads objects of type T (specific case class, Thrift / Proto) from a Kafka topic | ||
* Spark expression eval - Evaluates the Spark SQL expression in the GroupBy and projects and filters the input data | ||
* Avro conversion - Converts the Spark expr eval output to a form that can be written out to the KV store (PutRequest object) | ||
* KV store writer - Writes the PutRequest objects to the KV store using the AsyncDataStream API | ||
* Kafka source -> Spark expression eval -> Avro conversion -> KV store writer | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this changes now with the new tiling operators right? Should we add those too or maybe we move the operator breakdown to the |
||
* Kafka source - Reads objects of type T (specific case class, Thrift / Proto) from a Kafka topic | ||
* Spark expression eval - Evaluates the Spark SQL expression in the GroupBy and projects and filters the input data | ||
* Avro conversion - Converts the Spark expr eval output to a form that can be written out to the KV store | ||
* (PutRequest object) | ||
* KV store writer - Writes the PutRequest objects to the KV store using the AsyncDataStream API | ||
* | ||
* In the untiled version there are no-shuffles and thus this ends up being a single node in the Flink DAG | ||
* (with the above 4 operators and parallelism as injected by the user) | ||
|
@@ -26,17 +41,22 @@ import org.apache.flink.streaming.api.functions.async.RichAsyncFunction
 * @param groupByServingInfoParsed - The GroupBy we are working with
 * @param encoder - Spark Encoder for the input data type
 * @param parallelism - Parallelism to use for the Flink job
 * @param debug whether to enable debug logs
 * @tparam T - The input data type
 */
class FlinkJob[T](eventSrc: FlinkSource[T],
                  sinkFn: RichAsyncFunction[PutRequest, WriteResponse],
                  groupByServingInfoParsed: GroupByServingInfoParsed,
                  encoder: Encoder[T],
                  parallelism: Int) {
                  parallelism: Int,
                  debug: Boolean = false) {
  private[this] val logger = LoggerFactory.getLogger(getClass)

  val featureGroupName: String = groupByServingInfoParsed.groupBy.getMetaData.getName
  logger.info(f"Creating Flink job. featureGroupName=${featureGroupName}")

  protected val exprEval: SparkExpressionEvalFn[T] =
    new SparkExpressionEvalFn[T](encoder, groupByServingInfoParsed.groupBy)
  val featureGroupName: String = groupByServingInfoParsed.groupBy.getMetaData.getName

  if (groupByServingInfoParsed.groupBy.streamingSource.isEmpty) {
    throw new IllegalArgumentException(
@@ -48,6 +68,10 @@ class FlinkJob[T](eventSrc: FlinkSource[T],
  val kafkaTopic: String = groupByServingInfoParsed.groupBy.streamingSource.get.topic

  def runGroupByJob(env: StreamExecutionEnvironment): DataStream[WriteResponse] = {
    logger.info(
      f"Running Flink job for featureGroupName=${featureGroupName}, kafkaTopic=${kafkaTopic}. " +
        f"Tiling is disabled.")

    val sourceStream: DataStream[T] =
      eventSrc
        .getDataStream(kafkaTopic, featureGroupName)(env, parallelism)

@@ -70,4 +94,100 @@ class FlinkJob[T](eventSrc: FlinkSource[T],
      featureGroupName
    )
  }

  /**
   * The "tiled" version of the Flink app.
   *
   * The operators are structured as follows:
   *  1. Kafka source - Reads objects of type T (specific case class, Thrift / Proto) from a Kafka topic
   *  2. Spark expression eval - Evaluates the Spark SQL expression in the GroupBy and projects and filters the input
   *     data
   *  3. Window/tiling - This window aggregates incoming events, keeps track of the IRs, and sends them forward so
   *     they are written out to the KV store
   *  4. Avro conversion - Finishes converting the output of the window (the IRs) to a form that can be written out
   *     to the KV store (PutRequest object)
   *  5. KV store writer - Writes the PutRequest objects to the KV store using the AsyncDataStream API
   *
   * The window causes a split in the Flink DAG, so there are two nodes, (1+2) and (3+4+5).
   */
  def runTiledGroupByJob(env: StreamExecutionEnvironment): DataStream[WriteResponse] = {
    logger.info(
      f"Running Flink job for featureGroupName=${featureGroupName}, kafkaTopic=${kafkaTopic}. " +
        f"Tiling is enabled.")

    val tilingWindowSizeInMillis: Option[Long] =
      ResolutionUtils.getSmallestWindowResolutionInMillis(groupByServingInfoParsed.groupBy)

    val sourceStream: DataStream[T] =
      eventSrc
        .getDataStream(kafkaTopic, featureGroupName)(env, parallelism)

    val sparkExprEvalDS: DataStream[Map[String, Any]] = sourceStream
      .flatMap(exprEval)
      .uid(s"spark-expr-eval-flatmap-$featureGroupName")
      .name(s"Spark expression eval for $featureGroupName")
      .setParallelism(sourceStream.parallelism) // Use same parallelism as previous operator

    val inputSchema: Seq[(String, DataType)] =
      exprEval.getOutputSchema.fields
        .map(field => (field.name, SparkConversions.toChrononType(field.name, field.dataType)))
        .toSeq

    val window = TumblingEventTimeWindows
      .of(Time.milliseconds(tilingWindowSizeInMillis.get))
      .asInstanceOf[WindowAssigner[Map[String, Any], TimeWindow]]

    // An alternative to AlwaysFireOnElementTrigger can be used: BufferedProcessingTimeTrigger.
    // The latter will buffer writes so they happen at most every X milliseconds per GroupBy & key.
    val trigger = new AlwaysFireOnElementTrigger()

    // We use Flink "Side Outputs" to track any late events that aren't computed.
    val tilingLateEventsTag = OutputTag[Map[String, Any]]("tiling-late-events")

    // The tiling operator works the following way:

Review comment: Nice comments!

    // 1. Input: Spark expression eval (previous operator)
    // 2. Key by the entity key(s) defined in the groupby
    // 3. Window by a tumbling window
    // 4. Use our custom trigger that will "FIRE" on every element
    // 5. the AggregationFunction merges each incoming element with the current IRs which are kept in state
    //    - Each time a "FIRE" is triggered (i.e. on every event), getResult() is called and the current IRs are emitted
    // 6. A process window function does additional processing each time the AggregationFunction emits results
    //    - The only purpose of this window function is to mark tiles as closed so we can do client-side caching in SFS
    // 7. Output: TimestampedTile, containing the current IRs (Avro encoded) and the timestamp of the current element
    val tilingDS: DataStream[TimestampedTile] =
      sparkExprEvalDS
        .keyBy(KeySelector.getKeySelectionFunction(groupByServingInfoParsed.groupBy))
        .window(window)
        .trigger(trigger)
        .sideOutputLateData(tilingLateEventsTag)
        .aggregate(
          // See Flink's "ProcessWindowFunction with Incremental Aggregation"
          preAggregator = new FlinkRowAggregationFunction(groupByServingInfoParsed.groupBy, inputSchema, debug),
          windowFunction = new FlinkRowAggProcessFunction(groupByServingInfoParsed.groupBy, inputSchema, debug)
        )
        .uid(s"tiling-01-$featureGroupName")
        .name(s"Tiling for $featureGroupName")
        .setParallelism(sourceStream.parallelism)

    // Track late events
    val sideOutputStream: DataStream[Map[String, Any]] =
      tilingDS
        .getSideOutput(tilingLateEventsTag)
        .flatMap(new LateEventCounter(featureGroupName))
        .uid(s"tiling-side-output-01-$featureGroupName")
        .name(s"Tiling Side Output Late Data for $featureGroupName")
        .setParallelism(sourceStream.parallelism)

    val putRecordDS: DataStream[PutRequest] = tilingDS
      .flatMap(new TiledAvroCodecFn[T](groupByServingInfoParsed, debug))
      .uid(s"avro-conversion-01-$featureGroupName")
      .name(s"Avro conversion for $featureGroupName")
      .setParallelism(sourceStream.parallelism)

    AsyncKVStoreWriter.withUnorderedWaits(
      putRecordDS,
      sinkFn,
      featureGroupName
    )
  }
}
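The comment in runTiledGroupByJob mentions BufferedProcessingTimeTrigger as an alternative to AlwaysFireOnElementTrigger. That class is not shown in this diff; the sketch below only illustrates the buffering idea (fire at most every bufferMillis of processing time per key and window) and is not the PR's actual implementation:

import org.apache.flink.streaming.api.windowing.triggers.{Trigger, TriggerResult}
import org.apache.flink.streaming.api.windowing.windows.TimeWindow

// Illustrative sketch of a processing-time buffered trigger.
// A production version would also avoid registering redundant timers (e.g. by tracking a pending timer in state).
class BufferedFireTrigger(bufferMillis: Long) extends Trigger[Map[String, Any], TimeWindow] {
  override def onElement(element: Map[String, Any],
                         timestamp: Long,
                         window: TimeWindow,
                         ctx: Trigger.TriggerContext): TriggerResult = {
    // Buffer the element; schedule a fire at most bufferMillis from now.
    ctx.registerProcessingTimeTimer(ctx.getCurrentProcessingTime + bufferMillis)
    TriggerResult.CONTINUE
  }

  override def onProcessingTime(time: Long, window: TimeWindow, ctx: Trigger.TriggerContext): TriggerResult =
    TriggerResult.FIRE // emit the window's current IRs

  override def onEventTime(time: Long, window: TimeWindow, ctx: Trigger.TriggerContext): TriggerResult =
    TriggerResult.CONTINUE

  override def clear(window: TimeWindow, ctx: Trigger.TriggerContext): Unit = {
    // Nothing extra to clean up in this sketch; Flink clears the window state itself.
  }
}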
@@ -0,0 +1,27 @@
package ai.chronon.flink

import org.apache.flink.api.common.functions.RichFlatMapFunction
import org.apache.flink.configuration.Configuration
import org.apache.flink.metrics.Counter
import org.apache.flink.util.Collector

/**
 * Function to count late events.
 *
 * This function should consume the Side Output of the main tiling window.
 * */
class LateEventCounter(featureGroupName: String) extends RichFlatMapFunction[Map[String, Any], Map[String, Any]] {
  @transient private var lateEventCounter: Counter = _

  override def open(parameters: Configuration): Unit = {
    val metricsGroup = getRuntimeContext.getMetricGroup
      .addGroup("chronon")
      .addGroup("feature_group", featureGroupName)
    lateEventCounter = metricsGroup.counter("tiling.late_events")
  }

  override def flatMap(in: Map[String, Any], out: Collector[Map[String, Any]]): Unit = {
    lateEventCounter.inc()
    out.collect(in);
  }
}
Review comment: This could live in Extensions - we already have helper functions like MaxWindow etc on GroupBy as implicits: https://github.com/airbnb/chronon/blob/master/api/src/main/scala/ai/chronon/api/Extensions.scala#L416

Review comment: Extensions does seem like a good place for this, but it would require api to depend on aggregator so it can use FiveMinuteResolution.calculateTailHop. Or duplicating calculateTailHop in Extensions. What do you think?

Review comment: Aggregator should already be able to access api/Extensions.scala

Review comment: @caiocamatta-stripe if IntelliJ doesn't auto-complete or detect the functions in Aggregator you might need to install a scala extension, because the implicits need a little help to work in IDE I think. But it should still compile from sbt.

Review comment: Hmm, I may be doing something silly.. object aggregator is not a member of package ai.chronon

Review comment: Feel free to move resolution & 5minResolution into api if you want to achieve this. But not necessary in this PR IMO
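For reference, a rough sketch of what the reviewers' suggestion could look like: an implicit op on GroupBy (in the spirit of api/Extensions.scala) exposing the smallest window resolution. The names and placement are assumptions, and the tail-hop function is passed in as a parameter here precisely because, as the thread notes, FiveMinuteResolution.calculateTailHop lives in aggregator rather than api:

import scala.jdk.CollectionConverters._
import ai.chronon.api.{GroupBy, Window}

// Illustrative sketch only; not part of this PR.
object GroupByResolutionOps {
  implicit class GroupByWindowOps(groupBy: GroupBy) {
    // Smallest tail-hop (in millis) across all windows of all aggregations, if any windows are defined.
    def smallestWindowResolutionInMillis(calculateTailHop: Window => Long): Option[Long] = {
      val windows: Seq[Window] =
        Option(groupBy.aggregations).map(_.asScala.toSeq).getOrElse(Seq.empty)
          .flatMap(agg => Option(agg.windows).map(_.asScala.toSeq).getOrElse(Seq.empty))
      if (windows.isEmpty) None else Some(windows.map(calculateTailHop).min)
    }
  }
}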