From 255ece08bd4b95d02f123e96718274bcb984580a Mon Sep 17 00:00:00 2001
From: Varant Zanoyan
Date: Tue, 20 Feb 2024 08:16:39 -0800
Subject: [PATCH] WIP

---
 api/py/ai/chronon/repo/run.py                 |  19 ++--
 .../scala/ai/chronon/api/QueryUtils.scala     |  21 ++--
 .../main/scala/ai/chronon/spark/Sample.scala  | 100 +++++++++++++-----
 .../ai/chronon/spark/SampleDataLoader.scala   |   6 +-
 .../scala/ai/chronon/spark/TableUtils.scala   |   2 +
 5 files changed, 104 insertions(+), 44 deletions(-)

diff --git a/api/py/ai/chronon/repo/run.py b/api/py/ai/chronon/repo/run.py
index c97e4e921..9eb3ba057 100755
--- a/api/py/ai/chronon/repo/run.py
+++ b/api/py/ai/chronon/repo/run.py
@@ -485,15 +485,22 @@ def run(self, runtime_args=""):
         if self.mode == "sample":
             # After executing the local data sampling, sample mode runs also run a local execution of the job itself
-            print("Sampling complete. Running {} in local mode".format(self.conf))
+            conf_path = os.path.join(args.repo, args.conf)
+            with open(conf_path, "r") as conf_file:
+                conf_json = json.load(conf_file)
+            name = conf_json.get("metaData").get("name")
+            print("Sampling complete. Running {} in local mode".format(name))
             self.mode = "sampled-backfill"
-            # Make sure you set `--master "${SPARK_JOB_MODE:-yarn}"` in your spark_submit script for this to work as intended
-            # os.environ["SPARK_JOB_MODE"] = "local[*]"
-            #if not self.local_warehouse_location:
-            #    raise RuntimeError("You must provide the `local-warehouse-dir` argument to use sample mode")
-            #os.environ["SPARK_local_warehouse_location"] = self.local_warehouse_location
             self.run()
+            full_output_directory = "{}/data/{}".format(self.local_warehouse_location, name.replace(".", "_"))
+            print("""\n\n
+            Sampled run complete. Note that this data exists outside of your production warehouse.
+            To query this data from pyspark:
+            from pyspark.sql import SparkSession
+            df = SparkSession.builder.getOrCreate().read.parquet("{}")
+
+            """.format(full_output_directory))
 
 
 def set_defaults(parser, pre_parse_args=None):
diff --git a/api/src/main/scala/ai/chronon/api/QueryUtils.scala b/api/src/main/scala/ai/chronon/api/QueryUtils.scala
index dc872c35f..68a184486 100644
--- a/api/src/main/scala/ai/chronon/api/QueryUtils.scala
+++ b/api/src/main/scala/ai/chronon/api/QueryUtils.scala
@@ -18,6 +18,18 @@ package ai.chronon.api
 
 // utilized by both streaming and batch
 object QueryUtils {
+
+  def getWhereClause(wheres: Seq[String], includeWhere: Boolean = true): String = {
+    val whereString = if (includeWhere) "where" else ""
+    Option(wheres)
+      .filter(_.nonEmpty)
+      .map { ws =>
+        s"""
+           |$whereString
+           |  ${ws.map(w => s"(${w})").mkString(" AND ")}""".stripMargin
+      }
+      .getOrElse("")
+  }
   // when the value in fillIfAbsent for a key is null, we expect the column with the same name as the key
   // to be present in the table that the generated query runs on.
   def build(selects: Map[String, String],
@@ -41,14 +53,7 @@ object QueryUtils {
       case (None, _)    => Seq("*")
     }
 
-    val whereClause = Option(wheres)
-      .filter(_.nonEmpty)
-      .map { ws =>
-        s"""
-           |WHERE
-           |  ${ws.map(w => s"(${w})").mkString(" AND ")}""".stripMargin
-      }
-      .getOrElse("")
+    val whereClause = getWhereClause(wheres)
 
     s"""SELECT
        |  ${finalSelects.mkString(",\n  ")}
diff --git a/spark/src/main/scala/ai/chronon/spark/Sample.scala b/spark/src/main/scala/ai/chronon/spark/Sample.scala
index 2b05aef6a..c763131b2 100644
--- a/spark/src/main/scala/ai/chronon/spark/Sample.scala
+++ b/spark/src/main/scala/ai/chronon/spark/Sample.scala
@@ -7,8 +7,9 @@ import scala.io.Source
 import scala.jdk.CollectionConverters.{asScalaBufferConverter, mapAsScalaMapConverter}
 
 import ai.chronon.api
-import ai.chronon.api.Extensions.{GroupByOps, SourceOps}
-import ai.chronon.spark.Driver.{parseConf}
+import ai.chronon.api.Extensions.{GroupByOps, QueryOps, SourceOps}
+import ai.chronon.api.QueryUtils
+import ai.chronon.spark.Driver.parseConf
 import ai.chronon.spark.SampleHelper.{getPriorRunManifestMetadata, writeManifestMetadata}
 import com.google.gson.{Gson, GsonBuilder}
 import com.google.gson.reflect.TypeToken
@@ -32,14 +33,23 @@ class Sample(conf: Any,
   }
 
   def sampleSource(table: String, dateRange: PartitionRange, includeLimit: Boolean,
-                   keySetOpt: Option[Seq[Map[String, Array[Any]]]] = None): DataFrame = {
+                   keySetOpt: Option[Seq[Map[String, Array[Any]]]] = None, baseWhereClause: String = ""): DataFrame = {
 
     val outputFile = s"$outputDir/$table"
     val additionalFilterString = keySetOpt.map { keySetList =>
 
       val keyFilterWheres = keySetList.map { keySet =>
         val filterList = keySet.map { case (keyName: String, values: Array[Any]) =>
-          val valueSet = values.map {
+
+          val (notNullValues, nullValues) = values.partition(_ != null)
+
+          if (notNullValues.isEmpty) {
+            throw new RuntimeException(s"No non-null keys found for table: $table key: $keyName. Check source table or where clauses.")
+          } else if (nullValues.nonEmpty) {
+            logger.warn(s"Found ${nullValues.length} null keys for table: $table key: $keyName.")
+          }
+
+          val valueSet = notNullValues.map {
             case s: String => s"'$s'" // Add single quotes for string values
             case other => other.toString // Keep other types (like Int) as they are
           }.toSet
@@ -54,12 +64,19 @@ class Sample(conf: Any,
 
     val limitString = if (includeLimit) s"LIMIT $numRows" else ""
 
+    val whereClauseInjection = if (baseWhereClause.isEmpty) {
+      ""
+    } else {
+      s"AND $baseWhereClause"
+    }
+
     val sql =
       s"""
          |SELECT * FROM $table
          |WHERE ds >= "${dateRange.start}" AND
          |ds <= "${dateRange.end}"
          |$additionalFilterString
+         |$whereClauseInjection
          |$limitString
          |""".stripMargin
 
@@ -77,11 +94,16 @@ class Sample(conf: Any,
     df
   }
 
-  def createKeyFilters(keys: Seq[String], df: DataFrame): Map[String, Array[Any]] = {
-    keys.map{ key =>
-      val values = df.select(key).collect().map(row => row(0))
-      key -> values
-    }.toMap
+  def createKeyFilters(keys: Map[String, String], df: DataFrame, joinOpt: Option[api.Join] = None): Map[String, Array[Any]] = {
+    val joinSelects: Map[String, String] = joinOpt.map{ join =>
+      Option(join.left.rootQuery.getQuerySelects).getOrElse(Map.empty[String, String])
+    }.getOrElse(Map.empty[String, String])
+
+    keys.map{ case (keyName, keyExpression) =>
+      val selectExpression = joinSelects.getOrElse(keyName, keyName)
+      val values = df.select(selectExpression).collect().map(row => row(0))
+      keyExpression -> values
+    }
   }
 
   def getTableSemanticHash(source: api.Source, groupBy: api.GroupBy): Int = {
@@ -134,7 +156,9 @@ class Sample(conf: Any,
 
       val distinctKeysDf = tableUtils.sparkSession.sql(distinctKeysSql)
 
-      val keyFilters: Map[String, Array[Any]] = createKeyFilters(keys, distinctKeysDf)
+      val keyExpressions = createKeyExpressionMap(keys, source.rootQuery)
+
+      val keyFilters: Map[String, Array[Any]] = createKeyFilters(keyExpressions, distinctKeysDf)
 
       // We don't need to do anything with the output df in this case
       sampleSource(source.rootTable, range, false, Option(Seq(keyFilters)))
@@ -144,6 +168,12 @@ class Sample(conf: Any,
     }
   }
 
+  def createKeyExpressionMap(keys: Seq[String], query: api.Query): Map[String, String] = {
+    val selectMap = Option(query.getQuerySelects).getOrElse(Map.empty[String, String])
+    keys.map { key =>
+      key -> selectMap.getOrElse(key, key)
+    }.toMap
+  }
 
   def dsToInt(ds: String): Int = {
     ds.replace("-", "").toInt
@@ -178,16 +208,32 @@ class Sample(conf: Any,
 
   def sampleJoin(join: api.Join): Unit = {
 
+    // TODO: Fix logic for when left ts != ds
+
     // Create a map of table -> List[joinPart]
     // So that we can generate semantic hashing per table, and construct one query per source table
-    val tablesMap: Map[String, List[api.JoinPart]] = join.joinParts.asScala.flatMap { joinPart =>
+    val tablesMap: Map[String, List[(api.JoinPart, Map[String, String])]] = join.joinParts.asScala.flatMap { joinPart =>
+      // Get the key cols
+      val keyNames: Seq[String] = if (joinPart.keyMapping != null) {
+        joinPart.keyMapping.asScala.keys.toSeq
+      } else {
+        joinPart.groupBy.getKeyColumns.asScala
+      }
+
       joinPart.groupBy.sources.asScala.map { source =>
-        (source.rootTable, joinPart)
+        val selectMap = Option(source.rootQuery.getQuerySelects).getOrElse(Map.empty[String, String])
+        val keyExpressions = keyNames.map { key =>
+          key -> selectMap.getOrElse(key, key)
+        }.toMap
+        (source.rootTable, joinPart, keyExpressions)
       }
+    }.map {
+      case (rootTable, joinPart, keyMap) => (rootTable, (joinPart, keyMap))
     }.groupBy(_._1).mapValues(_.map(_._2).toList)
 
-    val tableHashes: Map[String, Int] = tablesMap.map{ case(table, joinParts) =>
-      (table, getTableSemanticHash(joinParts, join))
+
+    val tableHashes: Map[String, Int] = tablesMap.map{ case(table, joinPartsAndKeys) =>
+      (table, getTableSemanticHash(joinPartsAndKeys.map(_._1), join))
     } ++ Map(join.getLeft.rootTable -> Option(join.left.query.wheres.asScala).getOrElse("").hashCode())
 
     val tablesToSample: Seq[String] = if (forceResample) {
@@ -204,25 +250,25 @@
     val queryRange = PartitionRange(startDate, endDate)(tableUtils)
 
     // First sample the left side
     val leftRoot = join.getLeft.rootTable
-    val sampledLeftDf = sampleSource(leftRoot, queryRange, true)
+    // val wheres = join.getLeft.getJoinSource.getQuery.getWheres.asScala  // TODO: same as above
+    val wheres = join.getLeft.getEvents.getQuery.getWheres.asScala
 
-    val filteredTablesMap = tablesMap.filter { case (key, _) =>
-      tablesToSample.contains(key) }
+    // QueryUtils.build(null, leftRoot, wheres)
+    val whereClause = QueryUtils.getWhereClause(wheres, false)
 
-    filteredTablesMap.map { case (table, joinParts) =>
-      val startDateAndKeyFilters = joinParts.map { joinPart =>
-        val groupBy = joinPart.groupBy
+    val sampledLeftDf = sampleSource(leftRoot, queryRange, true, keySetOpt = None, baseWhereClause = whereClause)
 
-        // Get the key cols
-        val keys: Seq[String] = if (joinPart.keyMapping != null) {
-          joinPart.keyMapping.asScala.keys.toSeq
-        } else {
-          groupBy.getKeyColumns.asScala
-        }
+    val filteredTablesMap = tablesMap.filter { case (tableName, _) =>
+      tablesToSample.contains(tableName) }
+
+    filteredTablesMap.map { case (table, joinPartsAndKeys) =>
+      val startDateAndKeyFilters = joinPartsAndKeys.map { case (joinPart, keyMap) =>
+        val groupBy = joinPart.groupBy
 
         // Construct the specific key filter for each GroupBy using the sampledLeftDf
-        val keyFilters: Map[String, Array[Any]] = createKeyFilters(keys, sampledLeftDf)
+        val keyFilters: Map[String, Array[Any]] = createKeyFilters(keyMap, sampledLeftDf, Option(join))
+        // TODO: this query range should be based on timestamps
        val start = QueryRangeHelper.earliestDate(join.left.dataModel, groupBy, tableUtils, queryRange)
        (start, keyFilters)
      }
diff --git a/spark/src/main/scala/ai/chronon/spark/SampleDataLoader.scala b/spark/src/main/scala/ai/chronon/spark/SampleDataLoader.scala
index 42b61f9ac..9aa6f07bd 100644
--- a/spark/src/main/scala/ai/chronon/spark/SampleDataLoader.scala
+++ b/spark/src/main/scala/ai/chronon/spark/SampleDataLoader.scala
@@ -52,7 +52,8 @@ object SampleDataLoader {
     getPriorRunManifestMetadata(sampleDirectory).foreach { case(tableName, hash) =>
       logger.info(s"Checking $tableName")
 
-      val resampleTable = if (tableUtils.tableExists(tableName)) {
+      val database = tableName.split("\\.").head
+      val resampleTable = if (tableUtils.databaseExists(database) && tableUtils.tableExists(tableName)) {
         logger.info(s"Found existing table $tableName.")
         val tableDesc = sparkSession.sql(s"DESCRIBE TABLE EXTENDED $tableName")
         tableDesc.show(10)
@@ -61,8 +62,7 @@ object SampleDataLoader {
         val semanticHashValue = SEMANTIC_HASH_REGEX.findFirstMatchIn(tableMetadata) match {
           case Some(matchFound) => matchFound.group(1)
           case None =>
-            logger.error(s"Failed to parse semantic hash from $tableName. Table metadata: $tableMetadata")
-            ""
+            throw new RuntimeException(s"Failed to parse semantic hash from $tableName. Table metadata: $tableMetadata")
         }
 
         // Reload table when the data hash changed
diff --git a/spark/src/main/scala/ai/chronon/spark/TableUtils.scala b/spark/src/main/scala/ai/chronon/spark/TableUtils.scala
index c5f56fb7e..942e92d77 100644
--- a/spark/src/main/scala/ai/chronon/spark/TableUtils.scala
+++ b/spark/src/main/scala/ai/chronon/spark/TableUtils.scala
@@ -104,6 +104,8 @@ case class TableUtils(sparkSession: SparkSession) {
 
   def tableExists(tableName: String): Boolean = sparkSession.catalog.tableExists(tableName)
 
+  def databaseExists(databaseName: String): Boolean = sparkSession.catalog.databaseExists(databaseName)
+
   def loadEntireTable(tableName: String): DataFrame = sparkSession.table(tableName)
 
   def isPartitioned(tableName: String): Boolean = {