WIP
ezvz committed Feb 20, 2024
1 parent 834a370 commit 255ece0
Showing 5 changed files with 104 additions and 44 deletions.
19 changes: 13 additions & 6 deletions api/py/ai/chronon/repo/run.py
@@ -485,15 +485,22 @@ def run(self, runtime_args=""):

if self.mode == "sample":
# After the local data sampling completes, sample mode also runs the job itself locally
print("Sampling complete. Running {} in local mode".format(self.conf))
conf_path = os.path.join(args.repo, args.conf)
with open(conf_path, "r") as conf_file:
conf_json = json.load(conf_file)
name = conf_json.get("metaData").get("name")
print("Sampling complete. Running {} in local mode".format(name))
self.mode = "sampled-backfill"
# Make sure you set `--master "${SPARK_JOB_MODE:-yarn}"` in your spark_submit script for this to work as intended
# os.environ["SPARK_JOB_MODE"] = "local[*]"
#if not self.local_warehouse_location:
# raise RuntimeError("You must provide the `local-warehouse-dir` argument to use sample mode")
#os.environ["SPARK_local_warehouse_location"] = self.local_warehouse_location
self.run()
full_output_directory = "{}/data/{}".format(self.local_warehouse_location, name.replace(".", "_"))
print("""\n\n
Sampled run complete. Note that your data exists outside of your production warehouse.
To query this data from pyspark:
from pyspark.sql import SparkSession
spark = SparkSession.builder.getOrCreate()
df = spark.read.parquet("{}")
""".format(full_output_directory))


def set_defaults(parser, pre_parse_args=None):
21 changes: 13 additions & 8 deletions api/src/main/scala/ai/chronon/api/QueryUtils.scala
@@ -18,6 +18,18 @@ package ai.chronon.api

// utilized by both streaming and batch
object QueryUtils {

def getWhereClause(wheres: Seq[String], includeWhere: Boolean = true): String = {
val whereString = if (includeWhere) "where" else ""
Option(wheres)
.filter(_.nonEmpty)
.map { ws =>
s"""
|$whereString
| ${ws.map(w => s"(${w})").mkString(" AND ")}""".stripMargin
}
.getOrElse("")
}
// when the value in fillIfAbsent for a key is null, we expect the column with the same name as the key
// to be present in the table that the generated query runs on.
def build(selects: Map[String, String],
@@ -41,14 +53,7 @@ object QueryUtils {
case (None, _) => Seq("*")
}

val whereClause = Option(wheres)
.filter(_.nonEmpty)
.map { ws =>
s"""
|WHERE
| ${ws.map(w => s"(${w})").mkString(" AND ")}""".stripMargin
}
.getOrElse("")
val whereClause = getWhereClause(wheres)

s"""SELECT
| ${finalSelects.mkString(",\n ")}
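For context on the refactor above, here is a minimal standalone sketch of what the extracted helper produces, using made-up predicates (the object name and inputs are purely illustrative):

    object WhereClauseDemo {
      // Mirrors QueryUtils.getWhereClause: wraps each predicate in parentheses and ANDs them together.
      def getWhereClause(wheres: Seq[String], includeWhere: Boolean = true): String = {
        val whereString = if (includeWhere) "where" else ""
        Option(wheres)
          .filter(_.nonEmpty)
          .map { ws =>
            s"""
               |$whereString
               |  ${ws.map(w => s"(${w})").mkString(" AND ")}""".stripMargin
          }
          .getOrElse("")
      }

      def main(args: Array[String]): Unit = {
        // Hypothetical predicates.
        println(getWhereClause(Seq("ds = '2024-02-20'", "user_id IS NOT NULL")))
        // -> "where" on one line, "(ds = '2024-02-20') AND (user_id IS NOT NULL)" on the next
        println(getWhereClause(Seq("event_type = 'click'"), includeWhere = false))
        // -> just "(event_type = 'click')", ready to be appended with an explicit AND, as sampleJoin does below
        println(getWhereClause(Nil)) // -> "" when there are no predicates
      }
    }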
100 changes: 73 additions & 27 deletions spark/src/main/scala/ai/chronon/spark/Sample.scala
@@ -7,8 +7,9 @@ import scala.io.Source
import scala.jdk.CollectionConverters.{asScalaBufferConverter, mapAsScalaMapConverter}

import ai.chronon.api
import ai.chronon.api.Extensions.{GroupByOps, SourceOps}
import ai.chronon.spark.Driver.{parseConf}
import ai.chronon.api.Extensions.{GroupByOps, QueryOps, SourceOps}
import ai.chronon.api.QueryUtils
import ai.chronon.spark.Driver.parseConf
import ai.chronon.spark.SampleHelper.{getPriorRunManifestMetadata, writeManifestMetadata}
import com.google.gson.{Gson, GsonBuilder}
import com.google.gson.reflect.TypeToken
@@ -32,14 +33,23 @@ class Sample(conf: Any,
}

def sampleSource(table: String, dateRange: PartitionRange, includeLimit: Boolean,
keySetOpt: Option[Seq[Map[String, Array[Any]]]] = None): DataFrame = {
keySetOpt: Option[Seq[Map[String, Array[Any]]]] = None, baseWhereClause: String = ""): DataFrame = {

val outputFile = s"$outputDir/$table"
val additionalFilterString = keySetOpt.map { keySetList =>

val keyFilterWheres = keySetList.map { keySet =>
val filterList = keySet.map { case (keyName: String, values: Array[Any]) =>
val valueSet = values.map {

val (notNullValues, nullValues) = values.partition(_ != null)

if (notNullValues.isEmpty) {
throw new RuntimeException(s"No not-null keys found for table: $table key: $keyName. Check source table or where clauses.")
} else if (!nullValues.isEmpty) {
logger.warn(s"Found ${nullValues.length} null keys for table: $table key: $keyName.")
}

val valueSet = notNullValues.map {
case s: String => s"'$s'" // Add single quotes for string values
case other => other.toString // Keep other types (like Int) as they are
}.toSet
@@ -54,12 +64,19 @@

val limitString = if (includeLimit) s"LIMIT $numRows" else ""

val whereClauseInjection = if (baseWhereClause.isEmpty) {
""
} else {
s"AND $baseWhereClause"
}

val sql =
s"""
|SELECT * FROM $table
|WHERE ds >= "${dateRange.start}" AND
|ds <= "${dateRange.end}"
|$additionalFilterString
|$whereClauseInjection
|$limitString
|""".stripMargin

@@ -77,11 +94,16 @@
df
}

def createKeyFilters(keys: Seq[String], df: DataFrame): Map[String, Array[Any]] = {
keys.map{ key =>
val values = df.select(key).collect().map(row => row(0))
key -> values
}.toMap
def createKeyFilters(keys: Map[String, String], df: DataFrame, joinOpt: Option[api.Join] = None): Map[String, Array[Any]] = {
val joinSelects: Map[String, String] = joinOpt.map{ join =>
Option(join.left.rootQuery.getQuerySelects).getOrElse(Map.empty[String, String])
}.getOrElse(Map.empty[String, String])

keys.map{ case (keyName, keyExpression) =>
val selectExpression = joinSelects.getOrElse(keyName, keyName)
val values = df.select(selectExpression).collect().map(row => row(0))
keyExpression -> values
}
}

def getTableSemanticHash(source: api.Source, groupBy: api.GroupBy): Int = {
@@ -134,7 +156,9 @@

val distinctKeysDf = tableUtils.sparkSession.sql(distinctKeysSql)

val keyFilters: Map[String, Array[Any]] = createKeyFilters(keys, distinctKeysDf)
val keyExpressions = createKeyExpressionMap(keys, source.rootQuery)

val keyFilters: Map[String, Array[Any]] = createKeyFilters(keyExpressions, distinctKeysDf)

// We don't need to do anything with the output df in this case
sampleSource(source.rootTable, range, false, Option(Seq(keyFilters)))
@@ -144,6 +168,12 @@
}
}

def createKeyExpressionMap(keys: Seq[String], query: api.Query): Map[String, String] = {
val selectMap = Option(query.getQuerySelects).getOrElse(Map.empty[String, String])
keys.map { key=>
key -> selectMap.getOrElse(key, key)
}.toMap
}
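
To illustrate the key-mapping logic outside of the Chronon API, here is a hedged sketch that substitutes a plain Map for query.getQuerySelects and a small local DataFrame for the sampled left side; every name and value is hypothetical:

    import org.apache.spark.sql.SparkSession

    object KeyFilterDemo {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().master("local[*]").appName("key-filter-demo").getOrCreate()
        import spark.implicits._

        // Stand-in for the sampled left-side DataFrame.
        val sampledLeftDf = Seq((1L, "2024-02-20"), (2L, "2024-02-20")).toDF("user", "ds")

        // Stand-in for query.getQuerySelects: logical key name -> select expression.
        val selects = Map("user_id" -> "user")

        // Same idea as createKeyExpressionMap: fall back to the key name when no mapping exists.
        val keys = Seq("user_id", "ds")
        val keyExpressions: Map[String, String] = keys.map(k => k -> selects.getOrElse(k, k)).toMap

        // Same idea as createKeyFilters (simplified): collect the values behind each expression.
        val keyFilters: Map[String, Array[Any]] = keyExpressions.map { case (_, expr) =>
          expr -> sampledLeftDf.select(expr).collect().map(_.get(0))
        }
        keyFilters.foreach { case (expr, values) => println(s"$expr -> ${values.mkString(", ")}") }

        spark.stop()
      }
    }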

def dsToInt(ds: String): Int = {
ds.replace("-", "").toInt
@@ -178,16 +208,32 @@

def sampleJoin(join: api.Join): Unit = {

// TODO: Fix logic for when left ts != ds

// Create a map of table -> List[joinPart]
// So that we can generate semantic hashing per table, and construct one query per source table
val tablesMap: Map[String, List[api.JoinPart]] = join.joinParts.asScala.flatMap { joinPart =>
val tablesMap: Map[String, List[(api.JoinPart, Map[String, String])]] = join.joinParts.asScala.flatMap { joinPart =>
// Get the key cols
val keyNames: Seq[String] = if (joinPart.keyMapping != null) {
joinPart.keyMapping.asScala.keys.toSeq
} else {
joinPart.groupBy.getKeyColumns.asScala
}

joinPart.groupBy.sources.asScala.map { source =>
(source.rootTable, joinPart)
val selectMap = Option(source.rootQuery.getQuerySelects).getOrElse(Map.empty[String, String])
val keyExpressions = keyNames.map { key=>
key -> selectMap.getOrElse(key, key)
}.toMap
(source.rootTable, joinPart, keyExpressions)
}
}.map {
case (rootTable, joinPart, keyMap) => (rootTable, (joinPart, keyMap))
}.groupBy(_._1).mapValues(_.map(_._2).toList)

val tableHashes: Map[String, Int] = tablesMap.map{ case(table, joinParts) =>
(table, getTableSemanticHash(joinParts, join))

val tableHashes: Map[String, Int] = tablesMap.map{ case(table, joinPartsAndKeys) =>
(table, getTableSemanticHash(joinPartsAndKeys.map(_._1), join))
} ++ Map(join.getLeft.rootTable -> Option(join.left.query.wheres.asScala).getOrElse("").hashCode())

val tablesToSample: Seq[String] = if (forceResample) {
@@ -204,25 +250,25 @@
val queryRange = PartitionRange(startDate, endDate)(tableUtils)
// First sample the left side
val leftRoot = join.getLeft.rootTable
val sampledLeftDf = sampleSource(leftRoot, queryRange, true)
// val wheres = join.getLeft.getJoinSource.getQuery.getWheres.asScala TODO: same as above
val wheres = join.getLeft.getEvents.getQuery.getWheres.asScala

val filteredTablesMap = tablesMap.filter { case (key, _) =>
tablesToSample.contains(key) }
// QueryUtils.build(null, leftRoot, wheres)
val whereClause = QueryUtils.getWhereClause(wheres, false)

filteredTablesMap.map { case (table, joinParts) =>
val startDateAndKeyFilters = joinParts.map { joinPart =>
val groupBy = joinPart.groupBy
val sampledLeftDf = sampleSource(leftRoot, queryRange, true, keySetOpt = None, baseWhereClause = whereClause)

// Get the key cols
val keys: Seq[String] = if (joinPart.keyMapping != null) {
joinPart.keyMapping.asScala.keys.toSeq
} else {
groupBy.getKeyColumns.asScala
}
val filteredTablesMap = tablesMap.filter { case (tableName, _) =>
tablesToSample.contains(tableName) }

filteredTablesMap.map { case (table, joinPartsAndKeys) =>
val startDateAndKeyFilters = joinPartsAndKeys.map { case (joinPart, keyMap) =>
val groupBy = joinPart.groupBy

// Construct the specific key filter for each GroupBy using the sampledLeftDf
val keyFilters: Map[String, Array[Any]] = createKeyFilters(keys, sampledLeftDf)
val keyFilters: Map[String, Array[Any]] = createKeyFilters(keyMap, sampledLeftDf, Option(join))

// TODO, this queryrange should be based off of timestamps
val start = QueryRangeHelper.earliestDate(join.left.dataModel, groupBy, tableUtils, queryRange)
(start, keyFilters)
}
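As a rough illustration of the query sampleSource assembles, the sketch below reconstructs the filter string from the pieces visible in the diff; the exact IN-clause assembly and multi-key-set handling are elided in the hunk above, so that part is an assumption, and the table, keys, and where clause are invented:

    object SampleSqlDemo {
      // Rough reconstruction of the filter assembly in Sample.sampleSource (illustrative only;
      // the real code also handles multiple key sets and warns on null key values).
      def buildSampleSql(table: String,
                         startDs: String,
                         endDs: String,
                         keyFilters: Map[String, Array[Any]],
                         baseWhereClause: String = "",
                         limit: Option[Int] = None): String = {
        val keyFilterWhere = keyFilters.map { case (keyExpr, values) =>
          val valueSet = values.filter(_ != null).map {
            case s: String => s"'$s'"        // add single quotes for string values
            case other     => other.toString // keep numeric values as-is
          }.toSet
          s"$keyExpr IN (${valueSet.mkString(", ")})" // assumed shape of the per-key filter
        }.mkString(" AND ")

        val additionalFilterString = if (keyFilterWhere.isEmpty) "" else s"AND $keyFilterWhere"
        val whereClauseInjection = if (baseWhereClause.isEmpty) "" else s"AND $baseWhereClause"
        val limitString = limit.map(n => s"LIMIT $n").getOrElse("")

        s"""
           |SELECT * FROM $table
           |WHERE ds >= "$startDs" AND
           |ds <= "$endDs"
           |$additionalFilterString
           |$whereClauseInjection
           |$limitString
           |""".stripMargin
      }

      def main(args: Array[String]): Unit = {
        // Hypothetical inputs.
        println(buildSampleSql(
          table = "prod_db.user_events",
          startDs = "2024-02-01",
          endDs = "2024-02-20",
          keyFilters = Map("user" -> Array[Any](123L, 456L)),
          baseWhereClause = "(event_type = 'click')",
          limit = Some(1000)))
      }
    }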
6 changes: 3 additions & 3 deletions spark/src/main/scala/ai/chronon/spark/SampleDataLoader.scala
@@ -52,7 +52,8 @@ object SampleDataLoader {

getPriorRunManifestMetadata(sampleDirectory).foreach { case(tableName, hash) =>
logger.info(s"Checking $tableName")
val resampleTable = if (tableUtils.tableExists(tableName)) {
val database = tableName.split("\\.").head
val resampleTable = if (tableUtils.databaseExists(database) && tableUtils.tableExists(tableName)) {
logger.info(s"Found existing table $tableName.")
val tableDesc = sparkSession.sql(s"DESCRIBE TABLE EXTENDED $tableName")
tableDesc.show(10)
@@ -61,8 +62,7 @@
val semanticHashValue = SEMANTIC_HASH_REGEX.findFirstMatchIn(tableMetadata) match {
case Some(matchFound) => matchFound.group(1)
case None =>
logger.error(s"Failed to parse semantic hash from $tableName. Table metadata: $tableMetadata")
""
throw new RuntimeException(s"Failed to parse semantic hash from $tableName. Table metadata: $tableMetadata")
}

// Reload table when the data hash changed
2 changes: 2 additions & 0 deletions spark/src/main/scala/ai/chronon/spark/TableUtils.scala
@@ -104,6 +104,8 @@ case class TableUtils(sparkSession: SparkSession) {

def tableExists(tableName: String): Boolean = sparkSession.catalog.tableExists(tableName)

def databaseExists(databaseName: String): Boolean = sparkSession.catalog.databaseExists(databaseName)

def loadEntireTable(tableName: String): DataFrame = sparkSession.table(tableName)

def isPartitioned(tableName: String): Boolean = {
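A minimal, hedged usage sketch of the catalog checks these wrappers delegate to (assumes a local SparkSession; the database and table names are made up):

    import org.apache.spark.sql.SparkSession

    object CatalogCheckDemo {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder()
          .appName("catalog-check-demo")
          .master("local[*]")
          .getOrCreate()

        val database = "chronon_sample"      // hypothetical local sample database
        val table = s"$database.user_events" // hypothetical sampled table

        // Mirrors the SampleDataLoader guard above: check the database before the table,
        // so a missing local sample database is handled cleanly instead of failing the lookup.
        if (spark.catalog.databaseExists(database) && spark.catalog.tableExists(table)) {
          spark.sql(s"DESCRIBE TABLE EXTENDED $table").show(10, truncate = false)
        } else {
          println(s"$table not found; it would be (re)sampled on the next run.")
        }

        spark.stop()
      }
    }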
