Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
ezvz committed Feb 21, 2024
1 parent ac5095b commit f6c7b0e
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 1 deletion.
3 changes: 3 additions & 0 deletions spark/src/main/scala/ai/chronon/spark/Driver.scala
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,9 @@ object Driver {
opt[String](required = false,
descr =
"Start date to compute join backfill, this start date will override start partition in conf.")
// Optional cap on the number of output rows for a join backfill run.
// Useful for quick iteration/sampling: fewer rows means a faster (but partial) backfill.
val limit: ScallopOption[Int] =
  opt[Int](required = false,
           descr = "Limits the number of rows that your join will produce. Results in faster, partial backfills.")
lazy val joinConf: api.Join = parseConf[api.Join](confPath())
override def subcommandName() = s"join_${joinConf.metaData.name}"
}
Expand Down
54 changes: 53 additions & 1 deletion spark/src/main/scala/ai/chronon/spark/Join.scala
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,14 @@ import ai.chronon.spark.JoinUtils._
import org.apache.spark.sql
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions._

import java.util.concurrent.{Callable, ExecutorCompletionService, ExecutorService, Executors}

import scala.collection.Seq
import scala.collection.mutable
import scala.collection.parallel.ExecutionContextTaskSupport
import scala.concurrent.duration.{Duration, DurationInt}
import scala.concurrent.{Await, ExecutionContext, ExecutionContextExecutorService, Future}
import scala.jdk.CollectionConverters.{asJavaIterableConverter, asScalaBufferConverter, mapAsScalaMapConverter}
import scala.util.ScalaJavaConversions.{IterableOps, ListOps, MapOps}
import scala.util.{Failure, Success}

Expand Down Expand Up @@ -185,6 +186,54 @@ class Join(joinConf: api.Join,
coveringSetsPerJoinPart
}

def injectKeyFilter(leftDf: DataFrame, joinPart: api.JoinPart): Unit = {
  // Modifies the joinPart in place: for every GroupBy source, appends an `IN (...)` where-clause
  // restricting the source to the key values actually present in `leftDf`. Intended for small
  // left DataFrames, where hard-coding the key filter avoids scanning the full right-side source.

  val groupByKeyNames = joinPart.groupBy.getKeyColumns.asScala

  // Resolve each GroupBy key name to its left-side column name.
  // In case the joinPart uses a keyMapping, honor it; otherwise keys share names on both sides.
  val leftSideKeyNames: Map[String, String] = if (joinPart.keyMapping != null) {
    joinPart.keyMapping.asScala.toMap
  } else {
    groupByKeyNames.map(k => k -> k).toMap
  }

  joinPart.groupBy.sources.asScala.foreach { source =>
    // Expression each key maps to in the source's select clause, defaulting to the bare column name.
    val selectMap = Option(source.rootQuery.getQuerySelects).getOrElse(Map.empty[String, String])
    val groupByKeyExpressions = groupByKeyNames.map { key =>
      key -> selectMap.getOrElse(key, key)
    }.toMap

    val joinSelects: Map[String, String] =
      Option(joinConf.left.rootQuery.getQuerySelects).getOrElse(Map.empty[String, String])

    groupByKeyExpressions.foreach { case (keyName, groupByKeyExpression) =>
      // Fail with a descriptive error on a malformed keyMapping instead of a bare
      // NoSuchElementException from Option.get.
      val leftSideKeyName = leftSideKeyNames.getOrElse(
        keyName,
        throw new IllegalArgumentException(
          s"GroupBy key '$keyName' has no corresponding left-side column; check the joinPart keyMapping.")
      )
      val leftSelectExpression = joinSelects.getOrElse(leftSideKeyName, keyName)
      // NOTE(review): collects key values to the driver — only safe because callers gate this on
      // the left side being small; confirm the upstream size check before relying on it.
      val values = leftDf.select(leftSelectExpression).collect().map(row => row(0))

      // Check for null keys: warn if some are null, fail if all are null.
      val (notNullValues, nullValues) = values.partition(_ != null)
      if (notNullValues.isEmpty) {
        throw new RuntimeException(s"No not-null keys found for key: $keyName. Check source table or where clauses.")
      } else if (nullValues.nonEmpty) {
        logger.warn(s"Found ${nullValues.length} null keys for key: $keyName.")
      }

      // Render values as SQL literals; escape embedded single quotes so a value like
      // "O'Brien" cannot break (or inject into) the generated clause.
      val valueSet = notNullValues.map {
        case s: String => s"'${s.replace("'", "''")}'" // quoted + escaped string literal
        case other     => other.toString               // numeric and other types rendered as-is
      }.toSet

      // Append the final IN-clause to the source's where conditions.
      // NOTE(review): getWheres will NPE if the source has no where list initialized — confirm
      // upstream guarantees it is set, or initialize it before this call.
      source.rootQuery.getWheres.add(s"$groupByKeyExpression in (${valueSet.mkString(",")})")
    }
  }
}

override def computeRange(leftDf: DataFrame, leftRange: PartitionRange, bootstrapInfo: BootstrapInfo): DataFrame = {
val leftTaggedDf = if (leftDf.schema.names.contains(Constants.TimeColumn)) {
leftDf.withTimeBasedColumn(Constants.TimePartitionColumn)
Expand Down Expand Up @@ -251,6 +300,9 @@ class Join(joinConf: api.Join,
leftRange.isSingleDay,
s"Macro ${Constants.ChrononRunDs} is only supported for single day join, current range is ${leftRange}")
}

// If left DF is small, hardcode the key filter into the joinPart's GroupBy's where clause.
if (unfilledLeftDf.isDefined && unfilledLeftDf.get.df.)
val df =
computeRightTable(unfilledLeftDf, joinPart, leftRange, joinLevelBloomMapOpt).map(df => joinPart -> df)
Thread.currentThread().setName(s"done-$threadName")
Expand Down

0 comments on commit f6c7b0e

Please sign in to comment.