diff --git a/aggregator/src/main/scala/ai/chronon/aggregator/windowing/Resolution.scala b/aggregator/src/main/scala/ai/chronon/aggregator/windowing/Resolution.scala
index 681df08c9..5db1fd5a9 100644
--- a/aggregator/src/main/scala/ai/chronon/aggregator/windowing/Resolution.scala
+++ b/aggregator/src/main/scala/ai/chronon/aggregator/windowing/Resolution.scala
@@ -17,7 +17,10 @@ package ai.chronon.aggregator.windowing
 import ai.chronon.api.Extensions.{WindowOps, WindowUtils}
-import ai.chronon.api.{TimeUnit, Window}
+import ai.chronon.api.{GroupBy, TimeUnit, Window}
+
+import scala.util.ScalaJavaConversions.ListOps
+import scala.util.ScalaVersionSpecificCollectionsConverter.convertJavaListToScala
 
 trait Resolution extends Serializable {
   // For a given window what is the resolution of the tail
@@ -57,3 +60,19 @@ object DailyResolution extends Resolution {
 
   val hopSizes: Array[Long] = Array(WindowUtils.Day.millis)
 }
+
+object ResolutionUtils {
+
+  /**
+    * Find the smallest tail window resolution in a GroupBy. Returns None if the GroupBy does not define any windows.
+    * The tail window resolutions are: 5 min for windows < 12 hrs, 1 hr for windows < 12 days, and 1 day otherwise.
+    * */
+  def getSmallestWindowResolutionInMillis(groupBy: GroupBy): Option[Long] =
+    Option(
+      groupBy.aggregations.toScala.toArray
+        .flatMap(aggregation =>
+          if (aggregation.windows != null) aggregation.windows.toScala
+          else None)
+        .map(FiveMinuteResolution.calculateTailHop)
+    ).filter(_.nonEmpty).map(_.min)
+}
diff --git a/flink/src/main/scala/ai/chronon/flink/AsyncKVStoreWriter.scala b/flink/src/main/scala/ai/chronon/flink/AsyncKVStoreWriter.scala
index 0912a2e94..6d3fecae3 100644
--- a/flink/src/main/scala/ai/chronon/flink/AsyncKVStoreWriter.scala
+++ b/flink/src/main/scala/ai/chronon/flink/AsyncKVStoreWriter.scala
@@ -75,6 +75,8 @@ class AsyncKVStoreWriter(onlineImpl: Api, featureGroupName: String)
   // The context used for the future callbacks
   implicit lazy val executor: ExecutionContext = AsyncKVStoreWriter.ExecutionContextInstance
 
+  // One may want to use different KV stores depending on whether tiling is on.
+  // The untiled version of Chronon works on "append" store semantics, and the tiled version works on "overwrite".
   protected def getKVStore: KVStore = {
     onlineImpl.genKvStore
   }
diff --git a/flink/src/main/scala/ai/chronon/flink/AvroCodecFn.scala b/flink/src/main/scala/ai/chronon/flink/AvroCodecFn.scala
index a88ea0aa7..16eb8dbb2 100644
--- a/flink/src/main/scala/ai/chronon/flink/AvroCodecFn.scala
+++ b/flink/src/main/scala/ai/chronon/flink/AvroCodecFn.scala
@@ -3,6 +3,7 @@ package ai.chronon.flink
 import org.slf4j.LoggerFactory
 import ai.chronon.api.Extensions.GroupByOps
 import ai.chronon.api.{Constants, DataModel, Query, StructType => ChrononStructType}
+import ai.chronon.flink.window.TimestampedTile
 import ai.chronon.online.{AvroConversions, GroupByServingInfoParsed}
 import ai.chronon.online.KVStore.PutRequest
 import org.apache.flink.api.common.functions.RichFlatMapFunction
@@ -13,28 +14,32 @@ import org.apache.flink.util.Collector
 
 import scala.jdk.CollectionConverters._
 
 /**
- * A Flink function that is responsible for converting the Spark expr eval output and converting that to a form
- * that can be written out to the KV store (PutRequest object)
- * @param groupByServingInfoParsed The GroupBy we are working with
- * @tparam T The input data type
+ * Base class for the Avro conversion Flink operator.
+ *
+ * Subclasses should implement the RichFlatMapFunction methods (flatMap) and provide groupByServingInfoParsed.
+ * + * @tparam IN The input data type which contains the data to be avro-converted to bytes. + * @tparam OUT The output data type (generally a PutRequest). */ -case class AvroCodecFn[T](groupByServingInfoParsed: GroupByServingInfoParsed) - extends RichFlatMapFunction[Map[String, Any], PutRequest] { - @transient lazy val logger = LoggerFactory.getLogger(getClass) +sealed abstract class BaseAvroCodecFn[IN, OUT] extends RichFlatMapFunction[IN, OUT] { + def groupByServingInfoParsed: GroupByServingInfoParsed + @transient lazy val logger = LoggerFactory.getLogger(getClass) @transient protected var avroConversionErrorCounter: Counter = _ + @transient protected var eventProcessingErrorCounter: Counter = + _ // Shared metric for errors across the entire Flink app. - protected val query: Query = groupByServingInfoParsed.groupBy.streamingSource.get.getEvents.query - protected val streamingDataset: String = groupByServingInfoParsed.groupBy.streamingDataset + protected lazy val query: Query = groupByServingInfoParsed.groupBy.streamingSource.get.getEvents.query + protected lazy val streamingDataset: String = groupByServingInfoParsed.groupBy.streamingDataset // TODO: update to use constant names that are company specific - protected val timeColumnAlias: String = Constants.TimeColumn - protected val timeColumn: String = Option(query.timeColumn).getOrElse(timeColumnAlias) + protected lazy val timeColumnAlias: String = Constants.TimeColumn + protected lazy val timeColumn: String = Option(query.timeColumn).getOrElse(timeColumnAlias) - protected val (keyToBytes, valueToBytes): (Any => Array[Byte], Any => Array[Byte]) = + protected lazy val (keyToBytes, valueToBytes): (Any => Array[Byte], Any => Array[Byte]) = getKVSerializers(groupByServingInfoParsed) - protected val (keyColumns, valueColumns): (Array[String], Array[String]) = getKVColumns - protected val extraneousRecord: Any => Array[Any] = { + protected lazy val (keyColumns, valueColumns): (Array[String], Array[String]) = getKVColumns + protected lazy val extraneousRecord: Any => Array[Any] = { case x: Map[_, _] if x.keys.forall(_.isInstanceOf[String]) => x.flatMap { case (key, value) => Array(key, value) }.toArray } @@ -70,6 +75,16 @@ case class AvroCodecFn[T](groupByServingInfoParsed: GroupByServingInfoParsed) val valueColumns = groupByServingInfoParsed.groupBy.aggregationInputs ++ additionalColumns (keyColumns, valueColumns) } +} + +/** + * A Flink function that is responsible for converting the Spark expr eval output and converting that to a form + * that can be written out to the KV store (PutRequest object) + * @param groupByServingInfoParsed The GroupBy we are working with + * @tparam T The input data type + */ +case class AvroCodecFn[T](groupByServingInfoParsed: GroupByServingInfoParsed) + extends BaseAvroCodecFn[Map[String, Any], PutRequest] { override def open(configuration: Configuration): Unit = { super.open(configuration) @@ -87,16 +102,69 @@ case class AvroCodecFn[T](groupByServingInfoParsed: GroupByServingInfoParsed) } catch { case e: Exception => // To improve availability, we don't rethrow the exception. We just drop the event - // and track the errors in a metric. If there are too many errors we'll get alerted/paged. + // and track the errors in a metric. Alerts should be set up on this metric. 
logger.error(s"Error converting to Avro bytes - $e") + eventProcessingErrorCounter.inc() avroConversionErrorCounter.inc() } def avroConvertMapToPutRequest(in: Map[String, Any]): PutRequest = { val tsMills = in(timeColumnAlias).asInstanceOf[Long] - val keyBytes = keyToBytes(keyColumns.map(in.get(_).get)) - val valueBytes = valueToBytes(valueColumns.map(in.get(_).get)) + val keyBytes = keyToBytes(keyColumns.map(in(_))) + val valueBytes = valueToBytes(valueColumns.map(in(_))) PutRequest(keyBytes, valueBytes, streamingDataset, Some(tsMills)) } +} +/** + * A Flink function that is responsible for converting an array of pre-aggregates (aka a tile) to a form + * that can be written out to the KV store (PutRequest object). + * + * @param groupByServingInfoParsed The GroupBy we are working with + * @tparam T The input data type + */ +case class TiledAvroCodecFn[T](groupByServingInfoParsed: GroupByServingInfoParsed) + extends BaseAvroCodecFn[TimestampedTile, PutRequest] { + override def open(configuration: Configuration): Unit = { + super.open(configuration) + val metricsGroup = getRuntimeContext.getMetricGroup + .addGroup("chronon") + .addGroup("feature_group", groupByServingInfoParsed.groupBy.getMetaData.getName) + avroConversionErrorCounter = metricsGroup.counter("avro_conversion_errors") + eventProcessingErrorCounter = metricsGroup.counter("event_processing_error") + } + override def close(): Unit = super.close() + + override def flatMap(value: TimestampedTile, out: Collector[PutRequest]): Unit = + try { + out.collect(avroConvertTileToPutRequest(value)) + } catch { + case e: Exception => + // To improve availability, we don't rethrow the exception. We just drop the event + // and track the errors in a metric. Alerts should be set up on this metric. + logger.error(s"Error converting to Avro bytes - ", e) + eventProcessingErrorCounter.inc() + avroConversionErrorCounter.inc() + } + + def avroConvertTileToPutRequest(in: TimestampedTile): PutRequest = { + val tsMills = in.latestTsMillis + + // 'keys' is a map of (key name in schema -> key value), e.g. Map("card_number" -> "4242-4242-4242-4242") + // We convert to AnyRef because Chronon expects an AnyRef (for scala <> java interoperability reasons). 
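+    // Note: `keys` below is only used for the debug log; the bytes actually written to the KV store are
+    // `keyBytes` (built from `in.keys`) and `valueBytes` (the pre-aggregated tile itself).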
+ val keys: Map[String, AnyRef] = keyColumns.zip(in.keys.map(_.asInstanceOf[AnyRef])).toMap + val keyBytes = keyToBytes(in.keys.toArray) + val valueBytes = in.tileBytes + + logger.debug( + s""" + |Avro converting tile to PutRequest - tile=${in} + |groupBy=${groupByServingInfoParsed.groupBy.getMetaData.getName} tsMills=$tsMills keys=$keys + |keyBytes=${java.util.Base64.getEncoder.encodeToString(keyBytes)} + |valueBytes=${java.util.Base64.getEncoder.encodeToString(valueBytes)} + |streamingDataset=$streamingDataset""".stripMargin + ) + + PutRequest(keyBytes, valueBytes, streamingDataset, Some(tsMills)) + } } diff --git a/flink/src/main/scala/ai/chronon/flink/FlinkJob.scala b/flink/src/main/scala/ai/chronon/flink/FlinkJob.scala index 1a275e950..25b7f0039 100644 --- a/flink/src/main/scala/ai/chronon/flink/FlinkJob.scala +++ b/flink/src/main/scala/ai/chronon/flink/FlinkJob.scala @@ -1,25 +1,31 @@ package ai.chronon.flink +import ai.chronon.aggregator.windowing.ResolutionUtils +import ai.chronon.api.{DataType} import ai.chronon.api.Extensions.{GroupByOps, SourceOps} -import ai.chronon.online.GroupByServingInfoParsed +import ai.chronon.flink.window.{ + AlwaysFireOnElementTrigger, + FlinkRowAggProcessFunction, + FlinkRowAggregationFunction, + KeySelector, + TimestampedTile +} +import ai.chronon.online.{GroupByServingInfoParsed, SparkConversions} import ai.chronon.online.KVStore.PutRequest -import org.apache.flink.streaming.api.scala.{DataStream, StreamExecutionEnvironment} +import org.apache.flink.streaming.api.scala.{DataStream, OutputTag, StreamExecutionEnvironment} import org.apache.spark.sql.Encoder import org.apache.flink.api.scala._ import org.apache.flink.streaming.api.functions.async.RichAsyncFunction +import org.apache.flink.streaming.api.windowing.assigners.{TumblingEventTimeWindows, WindowAssigner} +import org.apache.flink.streaming.api.windowing.time.Time +import org.apache.flink.streaming.api.windowing.windows.TimeWindow +import org.slf4j.LoggerFactory /** - * Flink job that processes a single streaming GroupBy and writes out the results - * (raw events in untiled, pre-aggregates in case of tiled) to the KV store. - * At a high level, the operators are structured as follows: - * Kafka source -> Spark expression eval -> Avro conversion -> KV store writer - * Kafka source - Reads objects of type T (specific case class, Thrift / Proto) from a Kafka topic - * Spark expression eval - Evaluates the Spark SQL expression in the GroupBy and projects and filters the input data - * Avro conversion - Converts the Spark expr eval output to a form that can be written out to the KV store (PutRequest object) - * KV store writer - Writes the PutRequest objects to the KV store using the AsyncDataStream API + * Flink job that processes a single streaming GroupBy and writes out the results to the KV store. * - * In the untiled version there are no-shuffles and thus this ends up being a single node in the Flink DAG - * (with the above 4 operators and parallelism as injected by the user) + * There are two versions of the job, tiled and untiled. The untiled version writes out raw events while the tiled + * version writes out pre-aggregates. See the `runGroupByJob` and `runTiledGroupByJob` methods for more details. 
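+ *
+ * Rough usage sketch (names below match the constructor parameters; the sink, encoder and execution
+ * environment are whatever the caller provides, as in the integration tests):
+ * {{{
+ *   val job = new FlinkJob[T](eventSrc, sinkFn, groupByServingInfoParsed, encoder, parallelism)
+ *   val writes: DataStream[WriteResponse] = job.runTiledGroupByJob(env) // or job.runGroupByJob(env) for untiled
+ *   writes.addSink(someSink)
+ *   env.execute("chronon-flink-job")
+ * }}}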
* * @param eventSrc - Provider of a Flink Datastream[T] for the given topic and feature group * @param sinkFn - Async Flink writer function to help us write to the KV store @@ -33,10 +39,13 @@ class FlinkJob[T](eventSrc: FlinkSource[T], groupByServingInfoParsed: GroupByServingInfoParsed, encoder: Encoder[T], parallelism: Int) { + private[this] val logger = LoggerFactory.getLogger(getClass) + + val featureGroupName: String = groupByServingInfoParsed.groupBy.getMetaData.getName + logger.info(f"Creating Flink job. featureGroupName=${featureGroupName}") protected val exprEval: SparkExpressionEvalFn[T] = new SparkExpressionEvalFn[T](encoder, groupByServingInfoParsed.groupBy) - val featureGroupName: String = groupByServingInfoParsed.groupBy.getMetaData.getName if (groupByServingInfoParsed.groupBy.streamingSource.isEmpty) { throw new IllegalArgumentException( @@ -47,7 +56,25 @@ class FlinkJob[T](eventSrc: FlinkSource[T], // The source of our Flink application is a Kafka topic val kafkaTopic: String = groupByServingInfoParsed.groupBy.streamingSource.get.topic + /** + * The "untiled" version of the Flink app. + * + * At a high level, the operators are structured as follows: + * Kafka source -> Spark expression eval -> Avro conversion -> KV store writer + * Kafka source - Reads objects of type T (specific case class, Thrift / Proto) from a Kafka topic + * Spark expression eval - Evaluates the Spark SQL expression in the GroupBy and projects and filters the input data + * Avro conversion - Converts the Spark expr eval output to a form that can be written out to the KV store + * (PutRequest object) + * KV store writer - Writes the PutRequest objects to the KV store using the AsyncDataStream API + * + * In this untiled version, there are no shuffles and thus this ends up being a single node in the Flink DAG + * (with the above 4 operators and parallelism as injected by the user). + */ def runGroupByJob(env: StreamExecutionEnvironment): DataStream[WriteResponse] = { + logger.info( + f"Running Flink job for featureGroupName=${featureGroupName}, kafkaTopic=${kafkaTopic}. " + + f"Tiling is disabled.") + val sourceStream: DataStream[T] = eventSrc .getDataStream(kafkaTopic, featureGroupName)(env, parallelism) @@ -70,4 +97,100 @@ class FlinkJob[T](eventSrc: FlinkSource[T], featureGroupName ) } + + /** + * The "tiled" version of the Flink app. + * + * The operators are structured as follows: + * 1. Kafka source - Reads objects of type T (specific case class, Thrift / Proto) from a Kafka topic + * 2. Spark expression eval - Evaluates the Spark SQL expression in the GroupBy and projects and filters the input + * data + * 3. Window/tiling - This window aggregates incoming events, keeps track of the IRs, and sends them forward so + * they are written out to the KV store + * 4. Avro conversion - Finishes converting the output of the window (the IRs) to a form that can be written out + * to the KV store (PutRequest object) + * 5. KV store writer - Writes the PutRequest objects to the KV store using the AsyncDataStream API + * + * The window causes a split in the Flink DAG, so there are two nodes, (1+2) and (3+4+5). + */ + def runTiledGroupByJob(env: StreamExecutionEnvironment): DataStream[WriteResponse] = { + logger.info( + f"Running Flink job for featureGroupName=${featureGroupName}, kafkaTopic=${kafkaTopic}. 
" + + f"Tiling is enabled.") + + val tilingWindowSizeInMillis: Option[Long] = + ResolutionUtils.getSmallestWindowResolutionInMillis(groupByServingInfoParsed.groupBy) + + val sourceStream: DataStream[T] = + eventSrc + .getDataStream(kafkaTopic, featureGroupName)(env, parallelism) + + val sparkExprEvalDS: DataStream[Map[String, Any]] = sourceStream + .flatMap(exprEval) + .uid(s"spark-expr-eval-flatmap-$featureGroupName") + .name(s"Spark expression eval for $featureGroupName") + .setParallelism(sourceStream.parallelism) // Use same parallelism as previous operator + + val inputSchema: Seq[(String, DataType)] = + exprEval.getOutputSchema.fields + .map(field => (field.name, SparkConversions.toChrononType(field.name, field.dataType))) + .toSeq + + val window = TumblingEventTimeWindows + .of(Time.milliseconds(tilingWindowSizeInMillis.get)) + .asInstanceOf[WindowAssigner[Map[String, Any], TimeWindow]] + + // An alternative to AlwaysFireOnElementTrigger can be used: BufferedProcessingTimeTrigger. + // The latter will buffer writes so they happen at most every X milliseconds per GroupBy & key. + val trigger = new AlwaysFireOnElementTrigger() + + // We use Flink "Side Outputs" to track any late events that aren't computed. + val tilingLateEventsTag = OutputTag[Map[String, Any]]("tiling-late-events") + + // The tiling operator works the following way: + // 1. Input: Spark expression eval (previous operator) + // 2. Key by the entity key(s) defined in the groupby + // 3. Window by a tumbling window + // 4. Use our custom trigger that will "FIRE" on every element + // 5. the AggregationFunction merges each incoming element with the current IRs which are kept in state + // - Each time a "FIRE" is triggered (i.e. on every event), getResult() is called and the current IRs are emitted + // 6. A process window function does additional processing each time the AggregationFunction emits results + // - The only purpose of this window function is to mark tiles as closed so we can do client-side caching in SFS + // 7. 
Output: TimestampedTile, containing the current IRs (Avro encoded) and the timestamp of the current element + val tilingDS: DataStream[TimestampedTile] = + sparkExprEvalDS + .keyBy(KeySelector.getKeySelectionFunction(groupByServingInfoParsed.groupBy)) + .window(window) + .trigger(trigger) + .sideOutputLateData(tilingLateEventsTag) + .aggregate( + // See Flink's "ProcessWindowFunction with Incremental Aggregation" + preAggregator = new FlinkRowAggregationFunction(groupByServingInfoParsed.groupBy, inputSchema), + windowFunction = new FlinkRowAggProcessFunction(groupByServingInfoParsed.groupBy, inputSchema) + ) + .uid(s"tiling-01-$featureGroupName") + .name(s"Tiling for $featureGroupName") + .setParallelism(sourceStream.parallelism) + + // Track late events + val sideOutputStream: DataStream[Map[String, Any]] = + tilingDS + .getSideOutput(tilingLateEventsTag) + .flatMap(new LateEventCounter(featureGroupName)) + .uid(s"tiling-side-output-01-$featureGroupName") + .name(s"Tiling Side Output Late Data for $featureGroupName") + .setParallelism(sourceStream.parallelism) + + val putRecordDS: DataStream[PutRequest] = tilingDS + .flatMap(new TiledAvroCodecFn[T](groupByServingInfoParsed)) + .uid(s"avro-conversion-01-$featureGroupName") + .name(s"Avro conversion for $featureGroupName") + .setParallelism(sourceStream.parallelism) + + AsyncKVStoreWriter.withUnorderedWaits( + putRecordDS, + sinkFn, + featureGroupName + ) + } } diff --git a/flink/src/main/scala/ai/chronon/flink/FlinkSource.scala b/flink/src/main/scala/ai/chronon/flink/FlinkSource.scala index ceeb0d9c6..336525556 100644 --- a/flink/src/main/scala/ai/chronon/flink/FlinkSource.scala +++ b/flink/src/main/scala/ai/chronon/flink/FlinkSource.scala @@ -6,6 +6,8 @@ abstract class FlinkSource[T] extends Serializable { /** * Return a Flink DataStream for the given topic and feature group. + * + * When implementing a source, you should also make a conscious decision about your allowed lateness strategy. */ def getDataStream(topic: String, groupName: String)( env: StreamExecutionEnvironment, diff --git a/flink/src/main/scala/ai/chronon/flink/RichMetricsOperators.scala b/flink/src/main/scala/ai/chronon/flink/RichMetricsOperators.scala new file mode 100644 index 000000000..086ecc865 --- /dev/null +++ b/flink/src/main/scala/ai/chronon/flink/RichMetricsOperators.scala @@ -0,0 +1,27 @@ +package ai.chronon.flink + +import org.apache.flink.api.common.functions.RichFlatMapFunction +import org.apache.flink.configuration.Configuration +import org.apache.flink.metrics.Counter +import org.apache.flink.util.Collector + +/** + * Function to count late events. + * + * This function should consume the Side Output of the main tiling window. 
+ * */
+class LateEventCounter(featureGroupName: String) extends RichFlatMapFunction[Map[String, Any], Map[String, Any]] {
+  @transient private var lateEventCounter: Counter = _
+
+  override def open(parameters: Configuration): Unit = {
+    val metricsGroup = getRuntimeContext.getMetricGroup
+      .addGroup("chronon")
+      .addGroup("feature_group", featureGroupName)
+    lateEventCounter = metricsGroup.counter("tiling.late_events")
+  }
+
+  override def flatMap(in: Map[String, Any], out: Collector[Map[String, Any]]): Unit = {
+    lateEventCounter.inc()
+    out.collect(in)
+  }
+}
diff --git a/flink/src/main/scala/ai/chronon/flink/SparkExpressionEvalFn.scala b/flink/src/main/scala/ai/chronon/flink/SparkExpressionEvalFn.scala
index 78e44540b..793517cbc 100644
--- a/flink/src/main/scala/ai/chronon/flink/SparkExpressionEvalFn.scala
+++ b/flink/src/main/scala/ai/chronon/flink/SparkExpressionEvalFn.scala
@@ -101,7 +101,7 @@ class SparkExpressionEvalFn[T](encoder: Encoder[T], groupBy: GroupBy) extends Ri
     } catch {
       case e: Exception =>
         // To improve availability, we don't rethrow the exception. We just drop the event
-        // and track the errors in a metric. If there are too many errors we'll get alerted/paged.
+        // and track the errors in a metric. Alerts should be set up on this metric.
         logger.error(s"Error evaluating Spark expression - $e")
         exprEvalErrorCounter.inc()
     }
diff --git a/flink/src/main/scala/ai/chronon/flink/window/FlinkRowAggregators.scala b/flink/src/main/scala/ai/chronon/flink/window/FlinkRowAggregators.scala
new file mode 100644
index 000000000..090772362
--- /dev/null
+++ b/flink/src/main/scala/ai/chronon/flink/window/FlinkRowAggregators.scala
@@ -0,0 +1,206 @@
+package ai.chronon.flink.window
+
+import ai.chronon.aggregator.row.RowAggregator
+import ai.chronon.api.Extensions.GroupByOps
+import ai.chronon.api.{Constants, DataType, GroupBy, Row}
+import ai.chronon.online.{ArrayRow, TileCodec}
+import org.apache.flink.api.common.functions.AggregateFunction
+import org.apache.flink.configuration.Configuration
+import org.apache.flink.metrics.Counter
+import org.apache.flink.streaming.api.scala.function.ProcessWindowFunction
+import org.apache.flink.streaming.api.windowing.windows.TimeWindow
+import org.apache.flink.util.Collector
+import org.slf4j.LoggerFactory
+
+import scala.util.{Failure, Success, Try}
+
+/**
+ * TimestampedIR combines the current Intermediate Result with the timestamp of the event being processed.
+ * We need to keep track of the timestamp of the event processed so we can calculate processing lag down the line.
+ *
+ * Example: for a GroupBy with 2 windows, we'd have TimestampedIR([IR for window 1, IR for window 2], timestamp).
+ *
+ * @param ir the array of partial aggregates
+ * @param latestTsMillis timestamp of the current event being processed
+ */
+case class TimestampedIR(
+    ir: Array[Any],
+    latestTsMillis: Option[Long]
+)
+
+/**
+ * Wrapper Flink aggregator around Chronon's RowAggregator. Relies on Flink to pass in
+ * the correct set of events for the tile. As the aggregates produced by this function
+ * are used on the serving side along with other pre-aggregates, we don't 'finalize' the
+ * Chronon RowAggregator and instead return the intermediate representation.
+ *
+ * (This cannot be a RichAggregateFunction because Flink does not support Rich functions in windows.)
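+ *
+ * Wiring sketch (this is how FlinkJob.runTiledGroupByJob uses it; `keyedStream`, `window` and `trigger`
+ * are assumed to be set up by the caller):
+ * {{{
+ *   keyedStream
+ *     .window(window)
+ *     .trigger(trigger)
+ *     .aggregate(
+ *       preAggregator = new FlinkRowAggregationFunction(groupBy, inputSchema),
+ *       windowFunction = new FlinkRowAggProcessFunction(groupBy, inputSchema)
+ *     )
+ * }}}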
+ */
+class FlinkRowAggregationFunction(
+    groupBy: GroupBy,
+    inputSchema: Seq[(String, DataType)]
+) extends AggregateFunction[Map[String, Any], TimestampedIR, TimestampedIR] {
+  @transient private[flink] var rowAggregator: RowAggregator = _
+  @transient lazy val logger = LoggerFactory.getLogger(getClass)
+
+  private val valueColumns: Array[String] = inputSchema.map(_._1).toArray // column order matters
+  private val timeColumnAlias: String = Constants.TimeColumn
+
+  /*
+   * Initialize the transient rowAggregator.
+   * Running this method is an idempotent operation:
+   *   1. The initialized RowAggregator is always the same given a `groupBy` and `inputSchema`.
+   *   2. The RowAggregator itself doesn't hold state; Flink keeps track of the state of the IRs.
+   */
+  private def initializeRowAggregator(): Unit =
+    rowAggregator = TileCodec.buildRowAggregator(groupBy, inputSchema)
+
+  override def createAccumulator(): TimestampedIR = {
+    initializeRowAggregator()
+    TimestampedIR(rowAggregator.init, None)
+  }
+
+  override def add(
+      element: Map[String, Any],
+      accumulatorIr: TimestampedIR
+  ): TimestampedIR = {
+    // Most times, the time column is a Long, but it could be a Double.
+    val tsMills = Try(element(timeColumnAlias).asInstanceOf[Long])
+      .getOrElse(element(timeColumnAlias).asInstanceOf[Double].toLong)
+    val row = toChrononRow(element, tsMills)
+
+    // Given that the rowAggregator is transient, it may be null when a job is restored from a checkpoint
+    if (rowAggregator == null) {
+      logger.debug(
+        f"The Flink RowAggregator was null for groupBy=${groupBy.getMetaData.getName} tsMills=$tsMills"
+      )
+      initializeRowAggregator()
+    }
+
+    logger.debug(
+      f"Flink pre-aggregates BEFORE adding new element: accumulatorIr=[${accumulatorIr.ir
+        .mkString(", ")}] groupBy=${groupBy.getMetaData.getName} tsMills=$tsMills element=$element"
+    )
+
+    val partialAggregates = Try {
+      rowAggregator.update(accumulatorIr.ir, row)
+    }
+
+    partialAggregates match {
+      case Success(v) => {
+        logger.debug(
+          f"Flink pre-aggregates AFTER adding new element [${v.mkString(", ")}] " +
+            f"groupBy=${groupBy.getMetaData.getName} tsMills=$tsMills element=$element"
+        )
+        TimestampedIR(v, Some(tsMills))
+      }
+      case Failure(e) =>
+        logger.error(
+          s"Flink error calculating partial row aggregate. " +
+            s"groupBy=${groupBy.getMetaData.getName} tsMills=$tsMills element=$element",
+          e
+        )
+        throw e
+    }
+  }
+
+  // Note we return intermediate results here as the results of this
+  // aggregator are used on the serving side along with other pre-aggregates
+  override def getResult(accumulatorIr: TimestampedIR): TimestampedIR =
+    accumulatorIr
+
+  override def merge(aIr: TimestampedIR, bIr: TimestampedIR): TimestampedIR =
+    TimestampedIR(
+      rowAggregator.merge(aIr.ir, bIr.ir),
+      aIr.latestTsMillis
+        .flatMap(aL => bIr.latestTsMillis.map(bL => Math.max(aL, bL)))
+        .orElse(aIr.latestTsMillis.orElse(bIr.latestTsMillis))
+    )
+
+  def toChrononRow(value: Map[String, Any], tsMills: Long): Row = {
+    // The row values need to be in the same order as the input schema columns.
+    // The reason they are out of order in the first place is that the CatalystUtil does not return values in the
+    // same order as the schema.
+    val values: Array[Any] = valueColumns.map(value(_))
+    new ArrayRow(values, tsMills)
+  }
+}
+
+/**
+ * TimestampedTile combines the entity keys, the encoded Intermediate Result, and the timestamp of the event being processed.
+ *
+ * We need the timestamp of the event processed so we can calculate processing lag down the line.
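+ *
+ * Example (values illustrative): for a GroupBy keyed on "card_number" this could look like
+ * TimestampedTile(List("4242-4242-4242-4242"), <Avro-encoded tile IR bytes>, 1519862399984L).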
+ *
+ * @param keys the GroupBy entity keys
+ * @param tileBytes encoded tile IR
+ * @param latestTsMillis timestamp of the current event being processed
+ */
+case class TimestampedTile(
+    keys: List[Any],
+    tileBytes: Array[Byte],
+    latestTsMillis: Long
+)
+
+// This process function is only meant to be used downstream of the FlinkRowAggregationFunction
+class FlinkRowAggProcessFunction(
+    groupBy: GroupBy,
+    inputSchema: Seq[(String, DataType)]
+) extends ProcessWindowFunction[TimestampedIR, TimestampedTile, List[Any], TimeWindow] {
+
+  @transient private[flink] var tileCodec: TileCodec = _
+  @transient lazy val logger = LoggerFactory.getLogger(getClass)
+
+  @transient private var rowProcessingErrorCounter: Counter = _
+  @transient private var eventProcessingErrorCounter: Counter =
+    _ // Shared metric for errors across the entire Flink app.
+
+  override def open(parameters: Configuration): Unit = {
+    super.open(parameters)
+    tileCodec = new TileCodec(groupBy, inputSchema)
+
+    val metricsGroup = getRuntimeContext.getMetricGroup
+      .addGroup("chronon")
+      .addGroup("feature_group", groupBy.getMetaData.getName)
+    rowProcessingErrorCounter = metricsGroup.counter("tiling_process_function_error")
+    eventProcessingErrorCounter = metricsGroup.counter("event_processing_error")
+  }
+
+  /**
+    * Process events emitted from the aggregate function.
+    * Output format: (keys, encoded tile IR, timestamp of the event being processed)
+    * */
+  override def process(
+      keys: List[Any],
+      context: Context,
+      elements: Iterable[TimestampedIR],
+      out: Collector[TimestampedTile]
+  ): Unit = {
+    val windowEnd = context.window.getEnd
+    val irEntry = elements.head
+    val isComplete = context.currentWatermark >= windowEnd
+
+    val tileBytes = Try {
+      tileCodec.makeTileIr(irEntry.ir, isComplete)
+    }
+
+    tileBytes match {
+      case Success(v) => {
+        logger.debug(
+          s"""
+             |Flink aggregator processed element irEntry=$irEntry
+             |tileBytes=${java.util.Base64.getEncoder.encodeToString(v)}
+             |windowEnd=$windowEnd groupBy=${groupBy.getMetaData.getName}
+             |keys=$keys isComplete=$isComplete tileAvroSchema=${tileCodec.tileAvroSchema}""".stripMargin
+        )
+        // The timestamp should never be None here.
+        out.collect(TimestampedTile(keys, v, irEntry.latestTsMillis.get))
+      }
+      case Failure(e) =>
+        // To improve availability, we don't rethrow the exception. We just drop the event
+        // and track the errors in a metric. Alerts should be set up on this metric.
+        logger.error(s"Flink process error making tile IR", e)
+        eventProcessingErrorCounter.inc()
+        rowProcessingErrorCounter.inc()
+    }
+  }
+}
diff --git a/flink/src/main/scala/ai/chronon/flink/window/KeySelector.scala b/flink/src/main/scala/ai/chronon/flink/window/KeySelector.scala
new file mode 100644
index 000000000..900d8bebd
--- /dev/null
+++ b/flink/src/main/scala/ai/chronon/flink/window/KeySelector.scala
@@ -0,0 +1,33 @@
+package ai.chronon.flink.window
+
+import ai.chronon.api.GroupBy
+
+import scala.jdk.CollectionConverters._
+import org.slf4j.LoggerFactory
+
+/**
+ * A KeySelector is what Flink uses to determine how to partition a DataStream. In a distributed environment, the
+ * KeySelector guarantees that events with the same key always end up on the same machine.
+ * If invoked multiple times on the same object, the returned key must be the same.
+ */
+object KeySelector {
+  private[this] lazy val logger = LoggerFactory.getLogger(getClass)
+
+  /**
+    * Given a GroupBy, create a function to key the output of a SparkExprEval operator by the entities defined in the
+    * GroupBy.
The function returns a List of size equal to the number of keys in the GroupBy. + * + * For example, if a GroupBy is defined as "GroupBy(..., keys=["color", "size"], ...), the function will key the + * Flink SparkExprEval DataStream by color and size, so all events with the same (color, size) are sent to the same + * operator. + */ + def getKeySelectionFunction(groupBy: GroupBy): Map[String, Any] => List[Any] = { + // List uses MurmurHash.seqHash for its .hashCode(), which gives us hashing based on content. + // (instead of based on the instance, which is the case for Array). + val groupByKeys: List[String] = groupBy.keyColumns.asScala.toList + logger.info( + f"Creating key selection function for Flink app. groupByKeys=$groupByKeys" + ) + (sparkEvalOutput: Map[String, Any]) => groupByKeys.collect(sparkEvalOutput) + } +} diff --git a/flink/src/main/scala/ai/chronon/flink/window/Trigger.scala b/flink/src/main/scala/ai/chronon/flink/window/Trigger.scala new file mode 100644 index 000000000..f72dddbe6 --- /dev/null +++ b/flink/src/main/scala/ai/chronon/flink/window/Trigger.scala @@ -0,0 +1,180 @@ +package ai.chronon.flink.window + +import org.apache.flink.api.common.state.{ValueState, ValueStateDescriptor} +import org.apache.flink.streaming.api.windowing.triggers.{Trigger, TriggerResult} +import org.apache.flink.streaming.api.windowing.windows.TimeWindow + +/** + * Custom Flink Trigger that fires on every event received. + * */ +class AlwaysFireOnElementTrigger extends Trigger[Map[String, Any], TimeWindow] { + override def onElement( + element: Map[String, Any], + timestamp: Long, + window: TimeWindow, + ctx: Trigger.TriggerContext + ): TriggerResult = + TriggerResult.FIRE + + override def onProcessingTime( + time: Long, + window: TimeWindow, + ctx: Trigger.TriggerContext + ): TriggerResult = + TriggerResult.CONTINUE + + override def onEventTime( + time: Long, + window: TimeWindow, + ctx: Trigger.TriggerContext + ): TriggerResult = + // We don't need to PURGE here since we don't have explicit state. + // Flink's "Window Lifecycle" doc: "The window is completely removed when the time (event or processing time) + // passes its end timestamp plus the user-specified allowed lateness" + TriggerResult.CONTINUE + + // This Trigger doesn't hold state, so we don't need to do anything when the window is purged. + override def clear( + window: TimeWindow, + ctx: Trigger.TriggerContext + ): Unit = {} + + override def canMerge: Boolean = true + + override def onMerge( + window: TimeWindow, + mergeContext: Trigger.OnMergeContext + ): Unit = {} +} + +/** + * BufferedProcessingTimeTrigger is a custom Trigger that fires at most every 'bufferSizeMillis' within a window. + * It is intended for incremental window aggregations using event-time semantics. + * + * Purpose: This trigger exists as an optimization to reduce the number of writes to our online store and better handle + * contention that arises from having hot keys. + * + * Details: + * - The buffer timers are NOT aligned with the UNIX Epoch, they can fire at any timestamp. e.g., if the first + * event arrives at 14ms, and the buffer size is 100ms, the timer will fire at 114ms. + * - Buffer timers are only scheduled when events come in. If there's a gap in events, this trigger won't fire. + * + * Edge cases handled: + * - If the (event-time) window closes before the last (processing-time) buffer fires, this trigger will fire + * the remaining buffered elements before closing. 
+ *
+ * Example:
+ *   Window size = 300,000 ms (5 minutes)
+ *   BufferSizeMillis = 100 ms.
+ *   Assume we are using this trigger on a GroupBy that counts the number of unique IDs seen.
+ *   For simplicity, assume event time and processing time are synchronized (although in practice this is never true).
+ *
+ *   Event 1: ts = 14 ms, ID = A.
+ *     preAggregate (a Set that keeps track of all unique IDs seen) = [A]
+ *     this causes a timer to be set for timestamp = 114 ms.
+ *   Event 2: ts = 38 ms, ID = B.
+ *     preAggregate = [A, B]
+ *   Event 3: ts = 77 ms, ID = B.
+ *     preAggregate = [A, B]
+ *   Timer set for 114ms fires.
+ *     we emit the preAggregate [A, B].
+ *   Event 4: ts = 400ms, ID = C.
+ *     preAggregate = [A, B, C] (we don't purge the previous events when the timer fires!)
+ *     this causes a timer to be set for timestamp = 500 ms
+ *   Timer set for 500ms fires.
+ *     we emit the preAggregate [A, B, C].
+ * */
+class BufferedProcessingTimeTrigger(bufferSizeMillis: Long) extends Trigger[Map[String, Any], TimeWindow] {
+  // Each pane has its own state. A Flink pane is an actual instance of a defined window for a given key.
+  private val nextTimerTimestampStateDescriptor =
+    new ValueStateDescriptor[java.lang.Long]("nextTimerTimestampState", classOf[java.lang.Long])
+
+  /**
+    * When an element arrives, set up a processing time trigger to fire after `bufferSizeMillis`.
+    * If a timer is already set, we don't want to create a new one.
+    *
+    * Late events are treated the same way as regular events; they will still get buffered.
+    */
+  override def onElement(
+      element: Map[String, Any],
+      timestamp: Long,
+      window: TimeWindow,
+      ctx: Trigger.TriggerContext
+  ): TriggerResult = {
+    val nextTimerTimestampState: ValueState[java.lang.Long] = ctx.getPartitionedState(
+      nextTimerTimestampStateDescriptor
+    )
+
+    // Set timer if one doesn't already exist
+    if (nextTimerTimestampState.value() == null) {
+      val nextFireTimestampMillis = ctx.getCurrentProcessingTime + bufferSizeMillis
+      ctx.registerProcessingTimeTimer(nextFireTimestampMillis)
+      nextTimerTimestampState.update(nextFireTimestampMillis)
+    }
+
+    TriggerResult.CONTINUE
+  }
+
+  /**
+    * When the processing-time timer set up in `onElement` fires, we emit the results without purging the window.
+    * i.e., we keep the current pre-aggregates/IRs in the window so we can continue aggregating.
+    *
+    * Note: We don't need to PURGE the window anywhere. Flink will do that automatically when a window expires.
+    * Flink Docs: "[...] Flink keeps the state of windows until their allowed lateness expires. Once this happens, Flink
+    * removes the window and deletes its state [...]".
+    *
+    * Note: In case the app crashes after a processing-time timer is set, but before it fires, it will fire immediately
+    * after recovery.
+    */
+  override def onProcessingTime(
+      timestamp: Long,
+      window: TimeWindow,
+      ctx: Trigger.TriggerContext
+  ): TriggerResult = {
+    val nextTimerTimestampState = ctx.getPartitionedState(nextTimerTimestampStateDescriptor)
+    nextTimerTimestampState.update(null)
+    TriggerResult.FIRE
+  }
+
+  /**
+    * Fire any elements left in the buffer if the window ends before the last processing-time timer is fired.
+    * This can happen because we are using event-time semantics for the window, and processing-time for the buffer timer.
+    *
+    * Flink automatically sets up an event timer for the end of the window (+ allowed lateness) as soon as it
+    * sees the first element in it. See 'registerCleanupTimer' in Flink's 'WindowOperator.java'.
+ */ + override def onEventTime( + timestamp: Long, + window: TimeWindow, + ctx: Trigger.TriggerContext + ): TriggerResult = { + val nextTimerTimestampState: ValueState[java.lang.Long] = ctx.getPartitionedState( + nextTimerTimestampStateDescriptor + ) + if (nextTimerTimestampState.value() != null) { + TriggerResult.FIRE + } else { + TriggerResult.CONTINUE + } + } + + /** + * When a window is being purged (e.g., because it has expired), we delete timers and state. + * + * This function is called immediately after our 'onEventTime' which fires at the end of the window. + * See 'onEventTime' in Flink's 'WindowOperator.java'. + */ + override def clear(window: TimeWindow, ctx: Trigger.TriggerContext): Unit = { + // Remove the lingering processing-time timer if it exist. + val nextTimerTimestampState: ValueState[java.lang.Long] = ctx.getPartitionedState( + nextTimerTimestampStateDescriptor + ) + val nextTimerTimestampStateValue = nextTimerTimestampState.value() + if (nextTimerTimestampStateValue != null) { + ctx.deleteProcessingTimeTimer(nextTimerTimestampStateValue) + } + + // Delete state + nextTimerTimestampState.clear() + } +} diff --git a/flink/src/test/scala/ai/chronon/flink/test/FlinkJobIntegrationTest.scala b/flink/src/test/scala/ai/chronon/flink/test/FlinkJobIntegrationTest.scala index 3c04014cb..83f4bd55d 100644 --- a/flink/src/test/scala/ai/chronon/flink/test/FlinkJobIntegrationTest.scala +++ b/flink/src/test/scala/ai/chronon/flink/test/FlinkJobIntegrationTest.scala @@ -1,44 +1,20 @@ package ai.chronon.flink.test -import ai.chronon.api.Extensions.{WindowOps, WindowUtils} -import ai.chronon.api.{GroupBy, GroupByServingInfo, PartitionSpec} -import ai.chronon.flink.{FlinkJob, FlinkSource, SparkExpressionEvalFn, WriteResponse} -import ai.chronon.online.Extensions.StructTypeOps +import ai.chronon.flink.window.{TimestampedIR, TimestampedTile} +import ai.chronon.flink.{FlinkJob, SparkExpressionEvalFn} import ai.chronon.online.{Api, GroupByServingInfoParsed} +import ai.chronon.online.KVStore.PutRequest import org.apache.flink.runtime.testutils.MiniClusterResourceConfiguration -import org.apache.flink.streaming.api.scala.{DataStream, StreamExecutionEnvironment} -import org.apache.flink.api.scala._ -import org.apache.flink.streaming.api.functions.sink.SinkFunction +import org.apache.flink.streaming.api.scala.StreamExecutionEnvironment import org.apache.flink.test.util.MiniClusterWithClientResource import org.apache.spark.sql.Encoders -import org.apache.spark.sql.types.StructType import org.junit.Assert.assertEquals import org.junit.{After, Before, Test} import org.mockito.Mockito.withSettings import org.scalatestplus.mockito.MockitoSugar.mock -import java.util -import java.util.Collections import scala.jdk.CollectionConverters.asScalaBufferConverter -class E2EEventSource(mockEvents: Seq[E2ETestEvent]) extends FlinkSource[E2ETestEvent] { - override def getDataStream(topic: String, groupName: String)(env: StreamExecutionEnvironment, - parallelism: Int): DataStream[E2ETestEvent] = { - env.fromCollection(mockEvents) - } -} - -class CollectSink extends SinkFunction[WriteResponse] { - override def invoke(value: WriteResponse, context: SinkFunction.Context): Unit = { - CollectSink.values.add(value) - } -} - -object CollectSink { - // must be static - val values: util.List[WriteResponse] = Collections.synchronizedList(new util.ArrayList()) -} - class FlinkJobIntegrationTest { val flinkCluster = new MiniClusterWithClientResource( @@ -47,6 +23,30 @@ class FlinkJobIntegrationTest { 
.setNumberTaskManagers(1) .build) + // Decode a PutRequest into a TimestampedTile + def avroConvertPutRequestToTimestampedTile[T]( + in: PutRequest, + groupByServingInfoParsed: GroupByServingInfoParsed + ): TimestampedTile = { + // Decode the key bytes into a GenericRecord + val tileBytes = in.valueBytes + val record = groupByServingInfoParsed.keyCodec.decode(in.keyBytes) + + // Get all keys we expect to be in the GenericRecord + val decodedKeys: List[String] = + groupByServingInfoParsed.groupBy.keyColumns.asScala.map(record.get(_).toString).toList + + val tsMills = in.tsMillis.get + TimestampedTile(decodedKeys, tileBytes, tsMills) + } + + // Decode a TimestampedTile into a TimestampedIR + def avroConvertTimestampedTileToTimestampedIR(timestampedTile: TimestampedTile, + groupByServingInfoParsed: GroupByServingInfoParsed): TimestampedIR = { + val tileIR = groupByServingInfoParsed.tiledCodec.decodeTileIr(timestampedTile.tileBytes) + TimestampedIR(tileIR._1, Some(timestampedTile.latestTsMillis)) + } + @Before def setup(): Unit = { flinkCluster.before() @@ -56,45 +56,7 @@ class FlinkJobIntegrationTest { @After def teardown(): Unit = { flinkCluster.after() - } - - private def makeTestGroupByServingInfoParsed(groupBy: GroupBy, - inputSchema: StructType, - outputSchema: StructType): GroupByServingInfoParsed = { - val groupByServingInfo = new GroupByServingInfo() - groupByServingInfo.setGroupBy(groupBy) - - // Set input avro schema for groupByServingInfo - groupByServingInfo.setInputAvroSchema( - inputSchema.toAvroSchema("Input").toString(true) - ) - - // Set key avro schema for groupByServingInfo - groupByServingInfo.setKeyAvroSchema( - StructType( - groupBy.keyColumns.asScala.map { keyCol => - val keyColStructType = outputSchema.fields.find(field => field.name == keyCol) - keyColStructType match { - case Some(col) => col - case None => - throw new IllegalArgumentException(s"Missing key col from output schema: $keyCol") - } - } - ).toAvroSchema("Key") - .toString(true) - ) - - // Set value avro schema for groupByServingInfo - val aggInputColNames = groupBy.aggregations.asScala.map(_.inputColumn).toList - groupByServingInfo.setSelectedAvroSchema( - StructType(outputSchema.fields.filter(field => aggInputColNames.contains(field.name))) - .toAvroSchema("Value") - .toString(true) - ) - new GroupByServingInfoParsed( - groupByServingInfo, - PartitionSpec(format = "yyyy-MM-dd", spanMillis = WindowUtils.Day.millis) - ) + CollectSink.values.clear() } @Test @@ -113,9 +75,10 @@ class FlinkJobIntegrationTest { val outputSchema = new SparkExpressionEvalFn(encoder, groupBy).getOutputSchema - val groupByServingInfoParsed = makeTestGroupByServingInfoParsed(groupBy, encoder.schema, outputSchema) + val groupByServingInfoParsed = + FlinkTestUtils.makeTestGroupByServingInfoParsed(groupBy, encoder.schema, outputSchema) val mockApi = mock[Api](withSettings().serializable()) - val writerFn = new MockAsyncKVStoreWriter(Seq(true), mockApi, "testFG") + val writerFn = new MockAsyncKVStoreWriter(Seq(true), mockApi, "testFlinkJobEndToEndFG") val job = new FlinkJob[E2ETestEvent](source, writerFn, groupByServingInfoParsed, encoder, 2) job.runGroupByJob(env).addSink(new CollectSink) @@ -132,4 +95,67 @@ class FlinkJobIntegrationTest { // check that all the writes were successful assertEquals(writeEventCreatedDS.map(_.status), Seq(true, true, true)) } + + @Test + def testTiledFlinkJobEndToEnd(): Unit = { + implicit val env: StreamExecutionEnvironment = StreamExecutionEnvironment.getExecutionEnvironment + + // Create some test 
events with multiple different ids so we can check if tiling/pre-aggregation works correctly + // for each of them. + val id1Elements = Seq(E2ETestEvent(id = "id1", int_val = 1, double_val = 1.5, created = 1L), + E2ETestEvent(id = "id1", int_val = 1, double_val = 2.5, created = 2L)) + val id2Elements = Seq(E2ETestEvent(id = "id2", int_val = 1, double_val = 10.0, created = 3L)) + val elements: Seq[E2ETestEvent] = id1Elements ++ id2Elements + val source = new WatermarkedE2EEventSource(elements) + + // Make a GroupBy that SUMs the double_val of the elements. + val groupBy = FlinkTestUtils.makeGroupBy(Seq("id")) + + // Prepare the Flink Job + val encoder = Encoders.product[E2ETestEvent] + val outputSchema = new SparkExpressionEvalFn(encoder, groupBy).getOutputSchema + val groupByServingInfoParsed = + FlinkTestUtils.makeTestGroupByServingInfoParsed(groupBy, encoder.schema, outputSchema) + val mockApi = mock[Api](withSettings().serializable()) + val writerFn = new MockAsyncKVStoreWriter(Seq(true), mockApi, "testTiledFlinkJobEndToEndFG") + val job = new FlinkJob[E2ETestEvent](source, writerFn, groupByServingInfoParsed, encoder, 2) + job.runTiledGroupByJob(env).addSink(new CollectSink) + + env.execute("TiledFlinkJobIntegrationTest") + + // capture the datastream of the 'created' timestamps of all the written out events + val writeEventCreatedDS = CollectSink.values.asScala + + // BASIC ASSERTIONS + // All elements were processed + assert(writeEventCreatedDS.size == elements.size) + // check that the timestamps of the written out events match the input events + // we use a Set as we can have elements out of order given we have multiple tasks + assertEquals(writeEventCreatedDS.map(_.putRequest.tsMillis).map(_.get).toSet, elements.map(_.created).toSet) + // check that all the writes were successful + assertEquals(writeEventCreatedDS.map(_.status), Seq(true, true, true)) + + // Assert that the pre-aggregates/tiles are correct + // Get a list of the final IRs for each key. 
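+    // (For this test the end state should match `expectedFinalIRsPerKey` below:
+    //  Map(List("id1") -> List(4.0), List("id2") -> List(10.0)).)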
+    val finalIRsPerKey: Map[List[Any], List[Any]] = writeEventCreatedDS
+      .map(writeEvent => {
+        // First, we work back from the PutRequest, decoding it to a TimestampedTile and then a TimestampedIR
+        val timestampedTile =
+          avroConvertPutRequestToTimestampedTile(writeEvent.putRequest, groupByServingInfoParsed)
+        val timestampedIR = avroConvertTimestampedTileToTimestampedIR(timestampedTile, groupByServingInfoParsed)
+
+        // We're interested in the keys, Intermediate Result, and the timestamp for each processed event
+        (timestampedTile.keys, timestampedIR.ir.toList, writeEvent.putRequest.tsMillis.get)
+      })
+      .groupBy(_._1) // Group by the keys
+      .map((keys) => (keys._1, keys._2.maxBy(_._3)._2)) // pick just the event with the largest timestamp
+
+    // Looking back at our test events, we expect the following Intermediate Results to be generated:
+    val expectedFinalIRsPerKey = Map(
+      List("id1") -> List(4.0), // Add up the double_val of the two 'id1' events
+      List("id2") -> List(10.0)
+    )
+
+    assertEquals(expectedFinalIRsPerKey, finalIRsPerKey)
+  }
 }
diff --git a/flink/src/test/scala/ai/chronon/flink/test/FlinkTestUtils.scala b/flink/src/test/scala/ai/chronon/flink/test/FlinkTestUtils.scala
index 790b06f28..ff6a9ae2f 100644
--- a/flink/src/test/scala/ai/chronon/flink/test/FlinkTestUtils.scala
+++ b/flink/src/test/scala/ai/chronon/flink/test/FlinkTestUtils.scala
@@ -1,19 +1,50 @@ package ai.chronon.flink.test
 
 import ai.chronon.api.{Accuracy, Builders, GroupBy, Operation, TimeUnit, Window}
-import ai.chronon.flink.AsyncKVStoreWriter
+import ai.chronon.flink.{AsyncKVStoreWriter, FlinkSource, WriteResponse}
 import ai.chronon.online.{Api, KVStore}
-import org.apache.flink.api.java.ExecutionEnvironment
-import org.apache.flink.configuration.Configuration
-import org.apache.flink.streaming.api.scala.StreamExecutionEnvironment
+import ai.chronon.api.Extensions.{WindowOps, WindowUtils}
+import ai.chronon.api.{GroupByServingInfo, PartitionSpec}
+import ai.chronon.online.Extensions.StructTypeOps
+import ai.chronon.online.GroupByServingInfoParsed
+import org.apache.flink.api.common.eventtime.{SerializableTimestampAssigner, WatermarkStrategy}
+import org.apache.flink.api.scala.createTypeInformation
+import org.apache.flink.streaming.api.functions.sink.SinkFunction
+import org.apache.flink.streaming.api.scala.{DataStream, StreamExecutionEnvironment}
+import org.apache.spark.sql.types.StructType
 import org.mockito.ArgumentMatchers
 import org.mockito.Mockito.{when, withSettings}
 import org.scalatestplus.mockito.MockitoSugar.mock
 
+import java.time.Duration
+import java.util
+import java.util.Collections
 import scala.concurrent.{ExecutionContext, ExecutionContextExecutor, Future}
+import scala.jdk.CollectionConverters.asScalaBufferConverter
 
 case class E2ETestEvent(id: String, int_val: Int, double_val: Double, created: Long)
 
+class E2EEventSource(mockEvents: Seq[E2ETestEvent]) extends FlinkSource[E2ETestEvent] {
+  override def getDataStream(topic: String, groupName: String)(env: StreamExecutionEnvironment,
+                                                               parallelism: Int): DataStream[E2ETestEvent] = {
+    env.fromCollection(mockEvents)
+  }
+}
+
+class WatermarkedE2EEventSource(mockEvents: Seq[E2ETestEvent]) extends FlinkSource[E2ETestEvent] {
+  def watermarkStrategy: WatermarkStrategy[E2ETestEvent] =
+    WatermarkStrategy
+      .forBoundedOutOfOrderness[E2ETestEvent](Duration.ofSeconds(5))
+      .withTimestampAssigner(new SerializableTimestampAssigner[E2ETestEvent] {
+        override def extractTimestamp(event: E2ETestEvent, previousElementTimestamp: Long): Long =
+          event.created
+      })
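+  // Event time comes from the `created` field; the strategy above tolerates events up to 5 seconds out of order.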
+ override def getDataStream(topic: String, groupName: String)(env: StreamExecutionEnvironment, + parallelism: Int): DataStream[E2ETestEvent] = { + env.fromCollection(mockEvents).assignTimestampsAndWatermarks(watermarkStrategy) + } +} + class MockAsyncKVStoreWriter(mockResults: Seq[Boolean], onlineImpl: Api, featureGroup: String) extends AsyncKVStoreWriter(onlineImpl, featureGroup) { override def getKVStore: KVStore = { @@ -25,8 +56,56 @@ class MockAsyncKVStoreWriter(mockResults: Seq[Boolean], onlineImpl: Api, feature } } +class CollectSink extends SinkFunction[WriteResponse] { + override def invoke(value: WriteResponse, context: SinkFunction.Context): Unit = { + CollectSink.values.add(value) + } +} + +object CollectSink { + // must be static + val values: util.List[WriteResponse] = Collections.synchronizedList(new util.ArrayList()) +} + object FlinkTestUtils { + def makeTestGroupByServingInfoParsed(groupBy: GroupBy, + inputSchema: StructType, + outputSchema: StructType): GroupByServingInfoParsed = { + val groupByServingInfo = new GroupByServingInfo() + groupByServingInfo.setGroupBy(groupBy) + // Set input avro schema for groupByServingInfo + groupByServingInfo.setInputAvroSchema( + inputSchema.toAvroSchema("Input").toString(true) + ) + + // Set key avro schema for groupByServingInfo + groupByServingInfo.setKeyAvroSchema( + StructType( + groupBy.keyColumns.asScala.map { keyCol => + val keyColStructType = outputSchema.fields.find(field => field.name == keyCol) + keyColStructType match { + case Some(col) => col + case None => + throw new IllegalArgumentException(s"Missing key col from output schema: $keyCol") + } + } + ).toAvroSchema("Key") + .toString(true) + ) + + // Set value avro schema for groupByServingInfo + val aggInputColNames = groupBy.aggregations.asScala.map(_.inputColumn).toList + groupByServingInfo.setSelectedAvroSchema( + StructType(outputSchema.fields.filter(field => aggInputColNames.contains(field.name))) + .toAvroSchema("Value") + .toString(true) + ) + new GroupByServingInfoParsed( + groupByServingInfo, + PartitionSpec(format = "yyyy-MM-dd", spanMillis = WindowUtils.Day.millis) + ) + } def makeGroupBy(keyColumns: Seq[String], filters: Seq[String] = Seq.empty): GroupBy = Builders.GroupBy( sources = Seq( diff --git a/flink/src/test/scala/ai/chronon/flink/test/window/FlinkRowAggregationFunctionTest.scala b/flink/src/test/scala/ai/chronon/flink/test/window/FlinkRowAggregationFunctionTest.scala new file mode 100644 index 000000000..e702a2fa1 --- /dev/null +++ b/flink/src/test/scala/ai/chronon/flink/test/window/FlinkRowAggregationFunctionTest.scala @@ -0,0 +1,218 @@ +package ai.chronon.flink.test.window + +import ai.chronon.api._ +import ai.chronon.flink.window.FlinkRowAggregationFunction +import ai.chronon.online.TileCodec +import org.junit.Assert.fail +import org.junit.Test + +import scala.util.{Failure, Try} + +class FlinkRowAggregationFunctionTest { + private val aggregations: Seq[Aggregation] = Seq( + Builders.Aggregation( + Operation.AVERAGE, + "views", + Seq( + new Window(1, TimeUnit.DAYS), + new Window(1, TimeUnit.HOURS), + new Window(30, TimeUnit.DAYS) + ) + ), + Builders.Aggregation( + Operation.AVERAGE, + "rating", + Seq( + new Window(1, TimeUnit.DAYS), + new Window(1, TimeUnit.HOURS) + ) + ), + Builders.Aggregation( + Operation.MAX, + "title", + Seq( + new Window(1, TimeUnit.DAYS) + ) + ), + Builders.Aggregation( + Operation.LAST, + "title", + Seq( + new Window(1, TimeUnit.DAYS) + ) + ) + ) + + private val schema = List( + Constants.TimeColumn -> LongType, + "views" 
-> IntType, + "rating" -> FloatType, + "title" -> StringType + ) + + @Test + def testFlinkAggregatorProducesCorrectResults(): Unit = { + val groupByMetadata = Builders.MetaData(name = "my_group_by") + val groupBy = Builders.GroupBy(metaData = groupByMetadata, aggregations = aggregations) + val aggregateFunc = new FlinkRowAggregationFunction(groupBy, schema) + + var acc = aggregateFunc.createAccumulator() + val rows = Seq( + createRow(1519862399984L, 4, 4.0f, "A"), + createRow(1519862399984L, 40, 5.0f, "B"), + createRow(1519862399988L, 3, 3.0f, "C"), + createRow(1519862399988L, 5, 4.0f, "D"), + createRow(1519862399994L, 4, 4.0f, "A"), + createRow(1519862399999L, 10, 4.0f, "A") + ) + rows.foreach(row => acc = aggregateFunc.add(row, acc)) + val result = aggregateFunc.getResult(acc) + + // we sanity check the final result of the accumulator + // to do so, we must first expand / decompress the windowed tile IR into a full tile + // then we can finalize the tile and get the final result + val tileCodec = new TileCodec(groupBy, schema) + val expandedIr = tileCodec.expandWindowedTileIr(result.ir) + val finalResult = tileCodec.windowedRowAggregator.finalize(expandedIr) + + // expect 7 columns as we have 3 view avg time windows, 2 rating avg and 1 max title, 1 last title + assert(finalResult.length == 7) + val expectedAvgViews = 11.0f + val expectedAvgRating = 4.0f + val expectedMax = "D" + val expectedLast = "A" + val expectedResult = Array( + expectedAvgViews, + expectedAvgViews, + expectedAvgViews, + expectedAvgRating, + expectedAvgRating, + expectedMax, + expectedLast + ) + assert(finalResult sameElements expectedResult) + } + + @Test + def testFlinkAggregatorResultsCanBeMergedWithOtherPreAggregates(): Unit = { + val groupByMetadata = Builders.MetaData(name = "my_group_by") + val groupBy = Builders.GroupBy(metaData = groupByMetadata, aggregations = aggregations) + val aggregateFunc = new FlinkRowAggregationFunction(groupBy, schema) + + // create partial aggregate 1 + var acc1 = aggregateFunc.createAccumulator() + val rows1 = Seq( + createRow(1519862399984L, 4, 4.0f, "A"), + createRow(1519862399984L, 40, 5.0f, "B") + ) + rows1.foreach(row => acc1 = aggregateFunc.add(row, acc1)) + val partialResult1 = aggregateFunc.getResult(acc1) + + // create partial aggregate 2 + var acc2 = aggregateFunc.createAccumulator() + val rows2 = Seq( + createRow(1519862399988L, 3, 3.0f, "C"), + createRow(1519862399988L, 5, 4.0f, "D") + ) + rows2.foreach(row => acc2 = aggregateFunc.add(row, acc2)) + val partialResult2 = aggregateFunc.getResult(acc2) + + // create partial aggregate 3 + var acc3 = aggregateFunc.createAccumulator() + val rows3 = Seq( + createRow(1519862399994L, 4, 4.0f, "A"), + createRow(1519862399999L, 10, 4.0f, "A") + ) + rows3.foreach(row => acc3 = aggregateFunc.add(row, acc3)) + val partialResult3 = aggregateFunc.getResult(acc3) + + // lets merge the partial results together and check + val mergedPartialAggregates = aggregateFunc.rowAggregator + .merge( + aggregateFunc.rowAggregator.merge(partialResult1.ir, partialResult2.ir), + partialResult3.ir + ) + + // we sanity check the final result of the accumulator + // to do so, we must first expand / decompress the windowed tile IR into a full tile + // then we can finalize the tile and get the final result + val tileCodec = new TileCodec(groupBy, schema) + val expandedIr = tileCodec.expandWindowedTileIr(mergedPartialAggregates) + val finalResult = tileCodec.windowedRowAggregator.finalize(expandedIr) + + // expect 7 columns as we have 3 view avg time 
windows, 2 rating avg and 1 max title, 1 last title + assert(finalResult.length == 7) + val expectedAvgViews = 11.0f + val expectedAvgRating = 4.0f + val expectedMax = "D" + val expectedLast = "A" + val expectedResult = Array( + expectedAvgViews, + expectedAvgViews, + expectedAvgViews, + expectedAvgRating, + expectedAvgRating, + expectedMax, + expectedLast + ) + assert(finalResult sameElements expectedResult) + } + + @Test + def testFlinkAggregatorProducesCorrectResultsIfInputIsInIncorrectOrder(): Unit = { + val groupByMetadata = Builders.MetaData(name = "my_group_by") + val groupBy = Builders.GroupBy(metaData = groupByMetadata, aggregations = aggregations) + val aggregateFunc = new FlinkRowAggregationFunction(groupBy, schema) + + var acc = aggregateFunc.createAccumulator() + + // Create a map where the entries are not in the same order as `schema`. + val outOfOrderRow = Map[String, Any]( + "rating" -> 4.0f, + Constants.TimeColumn -> 1519862399999L, + "title" -> "A", + "views" -> 10 + ) + + // If the aggregator fails to fix the order, we'll get a ClassCastException + Try { + acc = aggregateFunc.add(outOfOrderRow, acc) + } match { + case Failure(e) => { + fail( + s"An exception was thrown by the aggregator when it should not have been. " + + s"The aggregator should fix the order without failing. $e") + } + case _ => + } + + val result = aggregateFunc.getResult(acc) + + // we sanity check the final result of the accumulator + // to do so, we must first expand / decompress the windowed tile IR into a full tile + // then we can finalize the tile and get the final result + val tileCodec = new TileCodec(groupBy, schema) + val expandedIr = tileCodec.expandWindowedTileIr(result.ir) + val finalResult = tileCodec.windowedRowAggregator.finalize(expandedIr) + assert(finalResult.length == 7) + + val expectedResult = Array( + outOfOrderRow("views"), + outOfOrderRow("views"), + outOfOrderRow("views"), + outOfOrderRow("rating"), + outOfOrderRow("rating"), + outOfOrderRow("title"), + outOfOrderRow("title") + ) + assert(finalResult sameElements expectedResult) + } + + def createRow(ts: Long, views: Int, rating: Float, title: String): Map[String, Any] = + Map( + Constants.TimeColumn -> ts, + "views" -> views, + "rating" -> rating, + "title" -> title + ) +} diff --git a/flink/src/test/scala/ai/chronon/flink/test/window/KeySelectorTest.scala b/flink/src/test/scala/ai/chronon/flink/test/window/KeySelectorTest.scala new file mode 100644 index 000000000..b81c39aab --- /dev/null +++ b/flink/src/test/scala/ai/chronon/flink/test/window/KeySelectorTest.scala @@ -0,0 +1,58 @@ +package ai.chronon.flink.test.window + +import ai.chronon.api.Builders +import ai.chronon.flink.window.KeySelector +import org.junit.Test + +class KeySelectorTest { + @Test + def TestChrononFlinkJobCorrectlyKeysByAGroupbysEntityKeys(): Unit = { + // We expect something like this to come out of the SparkExprEval operator + val sampleSparkExprEvalOutput: Map[String, Any] = + Map("number" -> 4242, "ip" -> "192.168.0.1", "user" -> "abc") + + val groupByWithOneEntityKey = Builders.GroupBy(keyColumns = Seq("number")) + val keyFunctionOne = KeySelector.getKeySelectionFunction(groupByWithOneEntityKey) + assert( + keyFunctionOne(sampleSparkExprEvalOutput) == List(4242) + ) + + val groupByWithTwoEntityKey = Builders.GroupBy(keyColumns = Seq("number", "user")) + val keyFunctionTwo = KeySelector.getKeySelectionFunction(groupByWithTwoEntityKey) + assert( + keyFunctionTwo(sampleSparkExprEvalOutput) == List(4242, "abc") + ) + } + + @Test + def 
testKeySelectorFunctionReturnsSameHashesForListsWithTheSameContent(): Unit = { + // This is more of a sanity check. It's not comprehensive. + // SINGLE ENTITY KEY + val map1: Map[String, Any] = + Map("number" -> 4242, "ip" -> "192.168.0.1", "user" -> "abc") + val map2: Map[String, Any] = + Map("number" -> 4242, "ip" -> "10.0.0.1", "user" -> "notabc") + val groupBySingleKey = Builders.GroupBy(keyColumns = Seq("number")) + val keyFunctionOne = KeySelector.getKeySelectionFunction(groupBySingleKey) + assert( + keyFunctionOne(map1).hashCode() == keyFunctionOne(map2).hashCode() + ) + + // TWO ENTITY KEYS + val map3: Map[String, Any] = + Map("number" -> 4242, "ip" -> "192.168.0.1", "user" -> "abc") + val map4: Map[String, Any] = + Map("ip" -> "192.168.0.1", "number" -> 4242, "user" -> "notabc") + val groupByTwoKeys = Builders.GroupBy(keyColumns = Seq("number", "ip")) + val keyFunctionTwo = KeySelector.getKeySelectionFunction(groupByTwoKeys) + assert( + keyFunctionTwo(map3).hashCode() == keyFunctionTwo(map4).hashCode() + ) + + val map5: Map[String, Any] = + Map("ip" -> "192.168.0.1", "number" -> null) + val map6: Map[String, Any] = + Map("ip" -> "192.168.0.1", "number" -> null) + assert(keyFunctionTwo(map5).hashCode() == keyFunctionTwo(map6).hashCode()) + } +}