diff --git a/api/src/main/scala/ai/chronon/api/Constants.scala b/api/src/main/scala/ai/chronon/api/Constants.scala
index e72152bb5..3b4fedf4f 100644
--- a/api/src/main/scala/ai/chronon/api/Constants.scala
+++ b/api/src/main/scala/ai/chronon/api/Constants.scala
@@ -63,4 +63,5 @@ object Constants {
   val LabelViewPropertyFeatureTable: String = "feature_table"
   val LabelViewPropertyKeyLabelTable: String = "label_table"
   val ChrononRunDs: String = "CHRONON_RUN_DS"
+  val SmallJoinCutoff: Int = 5000
 }
diff --git a/spark/src/main/scala/ai/chronon/spark/Join.scala b/spark/src/main/scala/ai/chronon/spark/Join.scala
index c7984ac2e..6f8759d5b 100644
--- a/spark/src/main/scala/ai/chronon/spark/Join.scala
+++ b/spark/src/main/scala/ai/chronon/spark/Join.scala
@@ -16,6 +16,8 @@ package ai.chronon.spark
 
+import java.util
+
 import org.slf4j.LoggerFactory
 import ai.chronon.api
 import ai.chronon.api.Extensions._
@@ -37,6 +39,8 @@ import scala.jdk.CollectionConverters.{asJavaIterableConverter, asScalaBufferConverter}
 import scala.util.ScalaJavaConversions.{IterableOps, ListOps, MapOps}
 import scala.util.{Failure, Success}
 
+import ai.chronon.api.Constants.SmallJoinCutoff
+
 /*
  * hashes: a list containing bootstrap hashes that represent the list of bootstrap parts that a record has matched
  *   during the bootstrap join
@@ -186,6 +190,16 @@ class Join(joinConf: api.Join,
     coveringSetsPerJoinPart
   }
 
+  def getAllLeftSideKeyNames(): Seq[String] = {
+    joinConf.getJoinParts.asScala.flatMap { joinPart =>
+      if (joinPart.keyMapping != null) {
+        joinPart.keyMapping.asScala.keys.toSeq
+      } else {
+        joinPart.groupBy.getKeyColumns.asScala
+      }
+    }
+  }
+
   def injectKeyFilter(leftDf: DataFrame, joinPart: api.JoinPart): Unit = {
     // Modifies the joinPart to inject the key filter into the
@@ -193,7 +207,7 @@ class Join(joinConf: api.Join,
     // In case the joinPart uses a keymapping
     val leftSideKeyNames: Map[String, String] = if (joinPart.keyMapping != null) {
-      joinPart.keyMapping.asScala.toMap
+      joinPart.rightToLeft
     } else {
       groupByKeyNames.map { k =>
         (k, k)
       }
@@ -210,9 +224,12 @@ class Join(joinConf: api.Join,
     val joinSelects: Map[String, String] = Option(joinConf.left.rootQuery.getQuerySelects).getOrElse(Map.empty[String, String])
     groupByKeyExpressions.map{ case (keyName, groupByKeyExpression) =>
+        println("---------------------------------------")
+        println(s"Left side keynames ${leftSideKeyNames.mkString(",")}")
+        println(s"keyName: $keyName, expression: $groupByKeyExpressions")
+        println("---------------------------------------")
         val leftSideKeyName = leftSideKeyNames.get(keyName).get
-        val leftSelectExpression = joinSelects.getOrElse(leftSideKeyName, keyName)
-        val values = leftDf.select(leftSelectExpression).collect().map(row => row(0))
+        val values = leftDf.select(leftSideKeyName).collect().map(row => row(0))
 
         // Check for null keys, warn if found, err if all null
         val (notNullValues, nullValues) = values.partition(_ != null)
@@ -230,11 +247,15 @@ class Join(joinConf: api.Join,
         // Form the final WHERE clause for injection
         s"$groupByKeyExpression in (${valueSet.mkString(sep = ",")})"
-      }.foreach(source.rootQuery.getWheres.add(_))
+      }.foreach { whereClause =>
+        val currentWheres = Option(source.rootQuery.getWheres).getOrElse(new util.ArrayList[String]())
+        currentWheres.add(whereClause)
+        source.rootQuery.setWheres(currentWheres)
+      }
     }
   }
 
-  override def computeRange(leftDf: DataFrame, leftRange: PartitionRange, bootstrapInfo: BootstrapInfo): DataFrame = {
+  override def computeRange(leftDf: DataFrame, leftRange: PartitionRange, bootstrapInfo: BootstrapInfo, runSmallMode: Boolean = false): DataFrame = {
     val leftTaggedDf = if (leftDf.schema.names.contains(Constants.TimeColumn)) {
       leftDf.withTimeBasedColumn(Constants.TimePartitionColumn)
     } else {
       leftDf
     }
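
For context on the injection above: when the left side is small, injectKeyFilter collects the distinct left-side key values on the driver and appends an IN clause to the GroupBy source's WHERE list, so the right-side scan is pre-filtered instead of bloom-filtered. A minimal, self-contained sketch of that idea (the object and method names below are illustrative, not part of this change, and the quoting only handles strings and numbers):

```scala
// Sketch only: build a "key in (...)" predicate from a small left DataFrame.
import org.apache.spark.sql.{DataFrame, SparkSession}

object KeyFilterSketch {
  def inClause(leftDf: DataFrame, leftKeyColumn: String, rightKeyExpression: String): String = {
    // Collect the (small) set of distinct, non-null key values to the driver.
    val values = leftDf.select(leftKeyColumn).collect().map(_.get(0)).filter(_ != null).distinct
    val rendered = values.map {
      case s: String => s"'$s'"
      case other     => other.toString
    }
    // This string is what would be appended to the GroupBy source's wheres.
    s"$rightKeyExpression in (${rendered.mkString(", ")})"
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("key-filter-sketch").getOrCreate()
    import spark.implicits._
    val left = Seq("u1", "u2", "u2").toDF("user_id")
    // Prints something like: user_id in ('u1', 'u2')
    println(inClause(left, "user_id", "user_id"))
    spark.stop()
  }
}
```
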
@@ -251,7 +272,7 @@ class Join(joinConf: api.Join,
     val bootstrapCoveringSets = findBootstrapSetCoverings(bootstrapDf, bootstrapInfo, leftRange)
 
     // compute a single bloomfilter at join level if there is no bootstrap operation
-    val joinLevelBloomMapOpt = if (bootstrapDf.columns.contains(Constants.MatchedHashes)) {
+    lazy val joinLevelBloomMapOpt = if (bootstrapDf.columns.contains(Constants.MatchedHashes)) {
       // do not compute if any bootstrap is involved
       None
     } else {
@@ -266,8 +287,16 @@ class Join(joinConf: api.Join,
       }
     }
 
+    val parallelism = if (runSmallMode) {
+      // Max out parallelism
+      joinConf.getJoinParts.asScala.length
+    } else {
+      tableUtils.joinPartParallelism
+    }
+
     implicit val executionContext: ExecutionContextExecutorService =
-      ExecutionContext.fromExecutorService(Executors.newFixedThreadPool(tableUtils.joinPartParallelism))
+      ExecutionContext.fromExecutorService(Executors.newFixedThreadPool(parallelism))
+
     val joinedDfTry =
       tableUtils
         .wrapWithCache("Computing left parts for bootstrap table", bootstrapDf) {
@@ -301,10 +330,14 @@ class Join(joinConf: api.Join,
                 s"Macro ${Constants.ChrononRunDs} is only supported for single day join, current range is ${leftRange}")
             }
 
-            // If left DF is small, hardcode the key filter into the joinPart's GroupBy's where clause.
-            if (unfilledLeftDf.isDefined && unfilledLeftDf.get.df.)
-            val df =
-              computeRightTable(unfilledLeftDf, joinPart, leftRange, joinLevelBloomMapOpt).map(df => joinPart -> df)
+            val (bloomFilterOpt, skipFilter) = if (runSmallMode) {
+              // If left DF is small, hardcode the key filter into the joinPart's GroupBy's where clause.
+              injectKeyFilter(leftDf, joinPart)
+              (None, true)
+            } else {
+              (joinLevelBloomMapOpt, false)
+            }
+            val df = computeRightTable(unfilledLeftDf, joinPart, leftRange, bloomFilterOpt, skipFilter).map(df => joinPart -> df)
             Thread.currentThread().setName(s"done-$threadName")
             df
           }
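
One more note on the small-mode branch of computeRange: the executor pool is sized to the number of join parts, so every (cheap) join part runs concurrently instead of queueing behind tableUtils.joinPartParallelism. A rough standalone sketch of that sizing decision (SmallModeParallelismSketch and runAll are illustrative names only, not part of this change):

```scala
// Sketch only: size a fixed thread pool to the number of tasks in small mode.
import java.util.concurrent.Executors
import scala.concurrent.{Await, ExecutionContext, Future}
import scala.concurrent.duration.Duration

object SmallModeParallelismSketch {
  def runAll[T](tasks: Seq[() => T], runSmallMode: Boolean, configuredParallelism: Int): Seq[T] = {
    // In small mode every join part is cheap, so give each one its own thread.
    val parallelism = if (runSmallMode) math.max(tasks.size, 1) else configuredParallelism
    val pool = Executors.newFixedThreadPool(parallelism)
    implicit val ec: ExecutionContext = ExecutionContext.fromExecutorService(pool)
    try Await.result(Future.sequence(tasks.map(t => Future(t()))), Duration.Inf)
    finally pool.shutdown()
  }
}
```
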
diff --git a/spark/src/main/scala/ai/chronon/spark/JoinBase.scala b/spark/src/main/scala/ai/chronon/spark/JoinBase.scala
index 12664ab03..7c59d14f4 100644
--- a/spark/src/main/scala/ai/chronon/spark/JoinBase.scala
+++ b/spark/src/main/scala/ai/chronon/spark/JoinBase.scala
@@ -28,11 +28,13 @@ import com.google.gson.Gson
 import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.functions._
 import org.apache.spark.util.sketch.BloomFilter
-
 import java.time.Instant
+
 import scala.collection.JavaConverters._
 import scala.collection.Seq
 
+import ai.chronon.api.Constants.SmallJoinCutoff
+
 abstract class JoinBase(joinConf: api.Join,
                         endPartition: String,
                         tableUtils: TableUtils,
@@ -117,13 +119,14 @@ abstract class JoinBase(joinConf: api.Join,
   def computeRightTable(leftDf: Option[DfWithStats],
                         joinPart: JoinPart,
                         leftRange: PartitionRange,
-                        joinLevelBloomMapOpt: Option[Map[String, BloomFilter]]): Option[DataFrame] = {
+                        joinLevelBloomMapOpt: Option[Map[String, BloomFilter]],
+                        skipBloom: Boolean = false): Option[DataFrame] = {
 
     val partTable = joinConf.partOutputTable(joinPart)
     val partMetrics = Metrics.Context(metrics, joinPart)
     if (joinPart.groupBy.aggregations == null) {
       // for non-aggregation cases, we directly read from the source table and there is no intermediate join part table
-      computeJoinPart(leftDf, joinPart, joinLevelBloomMapOpt)
+      computeJoinPart(leftDf, joinPart, joinLevelBloomMapOpt, skipBloom)
     } else {
       // in Events <> batch GB case, the partition dates are offset by 1
       val shiftDays =
@@ -145,15 +148,19 @@ abstract class JoinBase(joinConf: api.Join,
           skipFirstHole = false
         )
         .getOrElse(Seq())
-      val partitionCount = unfilledRanges.map(_.partitions.length).sum
+
+      // todo: undo this, just for debugging
+      val unfilledRangeCombined = Seq(PartitionRange(unfilledRanges.minBy(_.start).start, unfilledRanges.maxBy(_.end).end)(tableUtils))
+
+      val partitionCount = unfilledRangeCombined.map(_.partitions.length).sum
 
       if (partitionCount > 0) {
         val start = System.currentTimeMillis()
-        unfilledRanges
+        unfilledRangeCombined
          .foreach(unfilledRange => {
            val leftUnfilledRange = unfilledRange.shift(-shiftDays)
            val prunedLeft = leftDf.flatMap(_.prunePartitions(leftUnfilledRange))
            val filledDf =
-              computeJoinPart(prunedLeft, joinPart, joinLevelBloomMapOpt)
+              computeJoinPart(prunedLeft, joinPart, joinLevelBloomMapOpt, skipBloom)
            // Cache join part data into intermediate table
            if (filledDf.isDefined) {
              logger.info(s"Writing to join part table: $partTable for partition range $unfilledRange")
@@ -182,7 +189,8 @@ abstract class JoinBase(joinConf: api.Join,
 
   def computeJoinPart(leftDfWithStats: Option[DfWithStats],
                       joinPart: JoinPart,
-                      joinLevelBloomMapOpt: Option[Map[String, BloomFilter]]): Option[DataFrame] = {
+                      joinLevelBloomMapOpt: Option[Map[String, BloomFilter]],
+                      skipBloom: Boolean = false): Option[DataFrame] = {
 
     if (leftDfWithStats.isEmpty) {
       // happens when all rows are already filled by bootstrap tables
@@ -196,14 +204,17 @@ abstract class JoinBase(joinConf: api.Join,
     logger.info(
       s"\nBackfill is required for ${joinPart.groupBy.metaData.name} for $rowCount rows on range $unfilledRange")
 
-    val rightBloomMap =
+    val rightBloomMap = if (skipBloom) {
+      None
+    } else {
       JoinUtils.genBloomFilterIfNeeded(leftDf,
-                                       joinPart,
-                                       joinConf,
-                                       rowCount,
-                                       unfilledRange,
-                                       tableUtils,
-                                       joinLevelBloomMapOpt)
+                                         joinPart,
+                                         joinConf,
+                                         rowCount,
+                                         unfilledRange,
+                                         tableUtils,
+                                         joinLevelBloomMapOpt)
+    }
     val rightSkewFilter = joinConf.partSkewFilter(joinPart)
     def genGroupBy(partitionRange: PartitionRange) =
       GroupBy.from(joinPart.groupBy,
@@ -286,7 +297,7 @@ abstract class JoinBase(joinConf: api.Join,
     Some(rightDfWithDerivations)
   }
 
-  def computeRange(leftDf: DataFrame, leftRange: PartitionRange, bootstrapInfo: BootstrapInfo): DataFrame
+  def computeRange(leftDf: DataFrame, leftRange: PartitionRange, bootstrapInfo: BootstrapInfo, runSmallMode: Boolean = false): DataFrame
 
   def computeJoin(stepDays: Option[Int] = None, overrideStartPartition: Option[String] = None): DataFrame = {
@@ -302,8 +313,8 @@ abstract class JoinBase(joinConf: api.Join,
     val today = tableUtils.partitionSpec.at(System.currentTimeMillis())
     val analyzer = new Analyzer(tableUtils, joinConf, today, today, silenceMode = true)
     try {
-      analyzer.analyzeJoin(joinConf, validationAssert = true)
-      metrics.gauge(Metrics.Name.validationSuccess, 1)
+      //analyzer.analyzeJoin(joinConf, validationAssert = true)
+      //metrics.gauge(Metrics.Name.validationSuccess, 1)
       logger.info("Join conf validation succeeded. No error found.")
     } catch {
       case ex: AssertionError =>
@@ -323,7 +334,7 @@ abstract class JoinBase(joinConf: api.Join,
     // detect holes and chunks to fill
     // OverrideStartPartition is used to replace the start partition of the join config. This is useful when
     // 1 - User would like to test run with different start partition
-    // 2 - User has entity table which is accumulative and only want to run backfill for the latest partition
+    // 2 - User has entity table which is cumulative and only want to run backfill for the latest partition
     val rangeToFill = JoinUtils.getRangesToFill(joinConf.left,
                                                 tableUtils,
                                                 endPartition,
@@ -349,16 +360,36 @@ abstract class JoinBase(joinConf: api.Join,
     // build bootstrap info once for the entire job
     val bootstrapInfo = BootstrapInfo.from(joinConf, rangeToFill, tableUtils, leftSchema, mutationScan = mutationScan)
 
-    logger.info(s"Join ranges to compute: ${stepRanges.map { _.toString }.pretty}")
-    stepRanges.zipWithIndex.foreach {
+    val wholeRange = PartitionRange(unfilledRanges.minBy(_.start).start, unfilledRanges.maxBy(_.end).end)(tableUtils)
+
+    val runSmallMode = {
+      val thresholdCount = leftDf(joinConf, wholeRange, tableUtils, limit = Some(SmallJoinCutoff + 1)).get.count()
+      val result = thresholdCount <= SmallJoinCutoff
+      if (result) {
+        logger.info(s"Counted $thresholdCount rows, running join in small mode.")
+        tableUtils.shouldRepartition = false
+      } else {
+        logger.info(s"Counted more than $SmallJoinCutoff rows, proceeding with normal computation.")
+      }
+      result
+    }
+
+    val effectiveRanges = if (runSmallMode) {
+      Seq(wholeRange)
+    } else {
+      stepRanges
+    }
+
+    logger.info(s"Join ranges to compute: ${effectiveRanges.map { _.toString }.pretty}")
+    effectiveRanges.zipWithIndex.foreach {
       case (range, index) =>
         val startMillis = System.currentTimeMillis()
-        val progress = s"| [${index + 1}/${stepRanges.size}]"
+        val progress = s"| [${index + 1}/${effectiveRanges.size}]"
         logger.info(s"Computing join for range: ${range.toString} $progress")
         leftDf(joinConf, range, tableUtils).map { leftDfInRange =>
           if (showDf) leftDfInRange.prettyPrint()
           // set autoExpand = true to ensure backward compatibility due to column ordering changes
-          computeRange(leftDfInRange, range, bootstrapInfo).save(outputTable, tableProps, autoExpand = true)
+          computeRange(leftDfInRange, range, bootstrapInfo, runSmallMode).save(outputTable, tableProps, autoExpand = true)
           val elapsedMins = (System.currentTimeMillis() - startMillis) / (60 * 1000)
           metrics.gauge(Metrics.Name.LatencyMinutes, elapsedMins)
           metrics.gauge(Metrics.Name.PartitionCount, range.partitions.length)
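
The runSmallMode check above deliberately reads at most SmallJoinCutoff + 1 left rows: that is enough to distinguish "small" from "not small" without paying for a full count of a large left table. A standalone sketch of the same decision, assuming a hypothetical left table and a `ds` partition column (not part of this change):

```scala
// Sketch only: decide small mode from a bounded read of the left side.
import org.apache.spark.sql.SparkSession

object SmallModeCheckSketch {
  val SmallJoinCutoff = 5000

  def isSmall(spark: SparkSession, leftTable: String, ds: String): Boolean = {
    // limit(cutoff + 1) caps the scan; count() on the limited frame can never exceed cutoff + 1.
    val sampled = spark.table(leftTable).where(s"ds = '$ds'").limit(SmallJoinCutoff + 1).count()
    sampled <= SmallJoinCutoff
  }
}
```
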
diff --git a/spark/src/main/scala/ai/chronon/spark/TableUtils.scala b/spark/src/main/scala/ai/chronon/spark/TableUtils.scala
index 15007dee3..c086d0307 100644
--- a/spark/src/main/scala/ai/chronon/spark/TableUtils.scala
+++ b/spark/src/main/scala/ai/chronon/spark/TableUtils.scala
@@ -16,6 +16,8 @@ package ai.chronon.spark
 
+import java.io.{PrintWriter, StringWriter}
+
 import org.slf4j.LoggerFactory
 import ai.chronon.aggregator.windowing.TsUtils
 import ai.chronon.api.{Constants, PartitionSpec}
@@ -29,10 +31,10 @@ import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.{DataFrame, Row, SaveMode, SparkSession}
 import org.apache.spark.storage.StorageLevel
-
 import java.time.format.DateTimeFormatter
 import java.time.{Instant, ZoneId}
 import java.util.concurrent.{ExecutorService, Executors}
+
 import scala.collection.{Seq, mutable}
 import scala.concurrent.{ExecutionContext, ExecutionContextExecutor}
 import scala.util.{Failure, Success, Try}
@@ -46,6 +48,8 @@ case class TableUtils(sparkSession: SparkSession) {
     .withZone(ZoneId.systemDefault())
 
   val partitionColumn: String = sparkSession.conf.get("spark.chronon.partition.column", "ds")
+  var shouldRepartition: Boolean = false
+  //sparkSession.conf.get("spark.chronon.repartition", "true").toBoolean
   private val partitionFormat: String = sparkSession.conf.get("spark.chronon.partition.format", "yyyy-MM-dd")
   val partitionSpec: PartitionSpec = PartitionSpec(partitionFormat, WindowUtils.Day.millis)
@@ -76,6 +80,9 @@ case class TableUtils(sparkSession: SparkSession) {
   sparkSession.sparkContext.setLogLevel("ERROR")
   // converts String-s like "a=b/c=d" to Map("a" -> "b", "c" -> "d")
 
+  def setRepartition(setTo: Boolean): Unit = {
+    this.shouldRepartition = setTo
+  }
   def preAggRepartition(df: DataFrame): DataFrame =
     if (df.rdd.getNumPartitions < aggregationParallelism) {
       df.repartition(aggregationParallelism)
@@ -322,8 +329,15 @@ case class TableUtils(sparkSession: SparkSession) {
 
   def sql(query: String): DataFrame = {
     val partitionCount = sparkSession.sparkContext.getConf.getInt("spark.default.parallelism", 1000)
+    val sw = new StringWriter()
+    val pw = new PrintWriter(sw)
+    new Throwable().printStackTrace(pw)
+    val stackTraceString = sw.toString
+    val stackTraceStringPretty = stackTraceString.split("\n").filter(_.contains("chronon")).map(_.replace("at ai.chronon.spark.", "")).mkString("\n")
+
     logger.info(
-      s"\n----[Running query coalesced into at most $partitionCount partitions]----\n$query\n----[End of Query]----\n")
+      s"\n----[Running query coalesced into at most $partitionCount partitions]----\n$query\n----[End of Query]----\n\n Query call path (not an error stack trace): \n$stackTraceStringPretty \n\n --------")
+
     val df = sparkSession.sql(query).coalesce(partitionCount)
     df
   }
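
The sql() change attaches the Chronon call path to every query log line by capturing a throwaway stack trace and keeping only the chronon frames, which makes it easier to tell which join stage issued a given query. A minimal standalone version of that helper (CallPathSketch is an illustrative name, not part of this change):

```scala
// Sketch only: derive a "who called me" path from the current stack trace.
import java.io.{PrintWriter, StringWriter}

object CallPathSketch {
  def chrononCallPath(): String = {
    val sw = new StringWriter()
    new Throwable().printStackTrace(new PrintWriter(sw))
    sw.toString
      .split("\n")
      .filter(_.contains("chronon"))      // keep only frames from Chronon packages
      .map(_.trim.stripPrefix("at "))
      .mkString("\n")
  }

  def main(args: Array[String]): Unit =
    // Prints nothing outside a Chronon call stack; inside one it shows the chain of callers.
    println(chrononCallPath())
}
```
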
repartition...") + df.write.mode(saveMode).insertInto(tableName) + logger.info(s"Finished writing to $tableName") + } + }.get } @@ -392,6 +414,7 @@ case class TableUtils(sparkSession: SparkSession) { saveMode: SaveMode, stats: Option[DfStats]): Unit = { // get row count and table partition count statistics + val (rowCount: Long, tablePartitionCount: Int) = if (df.schema.fieldNames.contains(partitionColumn)) { if (stats.isDefined && stats.get.partitionRange.wellDefined) {