diff --git a/spark/src/main/scala/ai/chronon/spark/Driver.scala b/spark/src/main/scala/ai/chronon/spark/Driver.scala
index 7df794961..835a619fa 100644
--- a/spark/src/main/scala/ai/chronon/spark/Driver.scala
+++ b/spark/src/main/scala/ai/chronon/spark/Driver.scala
@@ -232,6 +232,9 @@ object Driver {
       opt[String](required = false,
                   descr = "Start date to compute join backfill, this start date will override start partition in conf.")
+    val limit: ScallopOption[Int] =
+      opt[Int](required = false,
+               descr = "Limits the number of rows that your join will produce. Results in faster runs.")
     lazy val joinConf: api.Join = parseConf[api.Join](confPath())
     override def subcommandName() = s"join_${joinConf.metaData.name}"
   }
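Note: the hunk above only declares the flag; this diff does not show where `limit` is consumed. A minimal sketch of how it could be wired in, assuming Scallop's `ScallopOption.toOption` and Spark's `DataFrame.limit` (the `args` and `leftDf` names here are illustrative, not from the patch):

    // Hypothetical wiring: cap the left DataFrame when --limit is passed.
    val limitedLeftDf: DataFrame = args.limit.toOption
      .map(n => leftDf.limit(n)) // DataFrame.limit(n) keeps at most n rows
      .getOrElse(leftDf)
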
diff --git a/spark/src/main/scala/ai/chronon/spark/Join.scala b/spark/src/main/scala/ai/chronon/spark/Join.scala
index c06d54390..c7984ac2e 100644
--- a/spark/src/main/scala/ai/chronon/spark/Join.scala
+++ b/spark/src/main/scala/ai/chronon/spark/Join.scala
@@ -26,13 +26,14 @@ import ai.chronon.spark.JoinUtils._
 import org.apache.spark.sql
 import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.functions._
-
 import java.util.concurrent.{Callable, ExecutorCompletionService, ExecutorService, Executors}
+
 import scala.collection.Seq
 import scala.collection.mutable
 import scala.collection.parallel.ExecutionContextTaskSupport
 import scala.concurrent.duration.{Duration, DurationInt}
 import scala.concurrent.{Await, ExecutionContext, ExecutionContextExecutorService, Future}
+import scala.jdk.CollectionConverters._
 import scala.util.ScalaJavaConversions.{IterableOps, ListOps, MapOps}
 import scala.util.{Failure, Success}
@@ -185,6 +186,54 @@ class Join(joinConf: api.Join,
     coveringSetsPerJoinPart
   }
 
+  def injectKeyFilter(leftDf: DataFrame, joinPart: api.JoinPart): Unit = {
+    // Modifies the joinPart in place: injects a key filter into the where clauses of its GroupBy's sources
+    val groupByKeyNames = joinPart.groupBy.getKeyColumns.asScala
+
+    // In case the joinPart uses a keyMapping. keyMapping maps left column -> groupBy key,
+    // so invert it to look up the left-side column by groupBy key name.
+    val leftSideKeyNames: Map[String, String] = if (joinPart.keyMapping != null) {
+      joinPart.keyMapping.asScala.toMap.map(_.swap)
+    } else {
+      groupByKeyNames.map { k => k -> k }.toMap
+    }
+
+    joinPart.groupBy.sources.asScala.foreach { source =>
+      val selectMap = Option(source.rootQuery.getQuerySelects).getOrElse(Map.empty[String, String])
+      val groupByKeyExpressions = groupByKeyNames.map { key =>
+        key -> selectMap.getOrElse(key, key)
+      }.toMap
+
+      val joinSelects: Map[String, String] =
+        Option(joinConf.left.rootQuery.getQuerySelects).getOrElse(Map.empty[String, String])
+
+      // Query.wheres may be null on the thrift object; initialize it before appending below
+      if (source.rootQuery.getWheres == null) source.rootQuery.setWheres(new java.util.ArrayList[String]())
+
+      groupByKeyExpressions.map { case (keyName, groupByKeyExpression) =>
+        val leftSideKeyName = leftSideKeyNames(keyName)
+        val leftSelectExpression = joinSelects.getOrElse(leftSideKeyName, leftSideKeyName)
+        val values = leftDf.selectExpr(leftSelectExpression).collect().map(row => row(0))
+
+        // Check for null keys, warn if any are found, err if all are null
+        val (notNullValues, nullValues) = values.partition(_ != null)
+        if (notNullValues.isEmpty) {
+          throw new RuntimeException(s"No not-null keys found for key: $keyName. Check source table or where clauses.")
+        } else if (nullValues.nonEmpty) {
+          logger.warn(s"Found ${nullValues.length} null keys for key: $keyName.")
+        }
+
+        // String-manipulate the values into valid SQL literals
+        val valueSet = notNullValues.map {
+          case s: String => s"'$s'" // Add single quotes for string values
+          case other     => other.toString // Keep other types (like Int) as they are
+        }.toSet
+
+        // Form the final where clause for injection
+        s"$groupByKeyExpression in (${valueSet.mkString(",")})"
+      }.foreach(source.rootQuery.getWheres.add(_))
+    }
+  }
+
   override def computeRange(leftDf: DataFrame, leftRange: PartitionRange, bootstrapInfo: BootstrapInfo): DataFrame = {
     val leftTaggedDf = if (leftDf.schema.names.contains(Constants.TimeColumn)) {
       leftDf.withTimeBasedColumn(Constants.TimePartitionColumn)
@@ -251,6 +300,9 @@ class Join(joinConf: api.Join,
                 leftRange.isSingleDay,
                 s"Macro ${Constants.ChrononRunDs} is only supported for single day join, current range is ${leftRange}")
       }
+
+      // If the left DF is small, hardcode the key filter into the joinPart's GroupBy's where clauses (the row-count cutoff below is an assumed placeholder).
+      if (unfilledLeftDf.isDefined && unfilledLeftDf.get.df.count() <= 100000) injectKeyFilter(unfilledLeftDf.get.df, joinPart)
       val df = computeRightTable(unfilledLeftDf, joinPart, leftRange, joinLevelBloomMapOpt).map(df => joinPart -> df)
       Thread.currentThread().setName(s"done-$threadName")
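
For reference, a self-contained sketch of the IN-clause construction that `injectKeyFilter` performs once the left-side key values are collected (the object name and sample values are made up; only the quoting and null handling mirror the diff):

    // Standalone illustration of the where-clause string building above.
    object KeyFilterSketch extends App {
      val collectedKeys: Array[Any] = Array("alice", "bob", null, 42)

      val (notNull, nulls) = collectedKeys.partition(_ != null)
      require(notNull.nonEmpty, "No not-null keys found") // mirrors the RuntimeException
      if (nulls.nonEmpty) println(s"warn: dropping ${nulls.length} null keys")

      // Quote strings, leave other literals bare, dedupe via Set.
      val literals = notNull.map {
        case s: String => s"'$s'"
        case other     => other.toString
      }.toSet

      // The clause appended to each GroupBy source's where list.
      println(s"user_id in (${literals.mkString(",")})") // e.g. user_id in ('alice','bob',42)
    }

One review note: string keys are wrapped in single quotes without escaping embedded quotes, so a key like "O'Brien" would produce malformed SQL.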