diff --git a/api/py/ai/chronon/repo/run.py b/api/py/ai/chronon/repo/run.py index 696430f5d..c97e4e921 100755 --- a/api/py/ai/chronon/repo/run.py +++ b/api/py/ai/chronon/repo/run.py @@ -29,7 +29,8 @@ ONLINE_ARGS = "--online-jar={online_jar} --online-class={online_class} " OFFLINE_ARGS = "--conf-path={conf_path} --end-date={ds} " -SAMPLING_ARGS = "--output-dir={output_dir} " +SAMPLING_ARGS = "--output-dir={output_dir} --num-rows={num_rows}" +SAMPLED_BACKFILL_ARGS = "--output-dir={output_dir} --local-warehouse-location={local_warehouse_location} --start-date={start_date}" ONLINE_WRITE_ARGS = "--conf-path={conf_path} " + ONLINE_ARGS ONLINE_OFFLINE_WRITE_ARGS = OFFLINE_ARGS + ONLINE_ARGS ONLINE_MODES = [ @@ -52,7 +53,8 @@ "log-flattener", "metadata-export", "label-join", - "sample" + "sample", + "sampled-backfill" ] MODES_USING_EMBEDDED = ["metadata-upload", "fetch", "local-streaming"] @@ -77,6 +79,7 @@ "label-join": OFFLINE_ARGS, "streaming-client": ONLINE_WRITE_ARGS, "sample": OFFLINE_ARGS + SAMPLING_ARGS, + "sampled-backfill": OFFLINE_ARGS + SAMPLED_BACKFILL_ARGS, "info": "", } @@ -104,7 +107,8 @@ "log-flattener": "log-flattener", "metadata-export": "metadata-export", "label-join": "label-join", - "sample": "sample" + "sample": "sample", + "sampled-backfill": "sampled-join" }, "staging_queries": { "backfill": "staging-query-backfill", @@ -350,6 +354,9 @@ def __init__(self, args, jar_path): self.mode = args.mode self.online_jar = args.online_jar self.output_dir = args.output_dir + self.num_rows = args.num_rows + self.start_date = args.start_date + self.local_warehouse_location = args.local_warehouse_location valid_jar = args.online_jar and os.path.exists(args.online_jar) # fetch online jar if necessary if (self.mode in ONLINE_MODES) and (not args.sub_help) and not valid_jar: @@ -395,15 +402,18 @@ def __init__(self, args, jar_path): self.spark_submit = args.spark_submit_path self.list_apps_cmd = args.list_apps - def run(self): + def run(self, runtime_args=""): base_args = MODE_ARGS[self.mode].format( conf_path=self.conf, ds=self.ds, online_jar=self.online_jar, online_class=self.online_class, - output_dir=self.output_dir + output_dir=self.output_dir, + num_rows=self.num_rows, + local_warehouse_location=self.local_warehouse_location, + start_date=self.start_date ) - final_args = base_args + " " + str(self.args) + final_args = base_args + " " + str(self.args) + " " + runtime_args if self.mode == "info": command = "python3 {script} --conf {conf} --ds {ds} --repo {repo}".format( script=self.render_info, conf=self.conf, ds=self.ds, repo=self.repo @@ -476,17 +486,23 @@ def run(self): if self.mode == "sample": # After executing the local data sampling, sample mode runs also run a local execution of the job itself print("Sampling complete. Running {} in local mode".format(self.conf)) - self.mode = "backfill" + self.mode = "sampled-backfill" # Make sure you set `--master "${SPARK_JOB_MODE:-yarn}"` in your spark_submit script for this to work as intended - os.environ["SPARK_JOB_MODE"] = "local[*]" + # os.environ["SPARK_JOB_MODE"] = "local[*]" + #if not self.local_warehouse_location: + # raise RuntimeError("You must provide the `local-warehouse-dir` argument to use sample mode") + #os.environ["SPARK_local_warehouse_location"] = self.local_warehouse_location self.run() -def set_defaults(parser): +def set_defaults(parser, pre_parse_args=None): """Set default values based on environment""" chronon_repo_path = os.environ.get("CHRONON_REPO_PATH", ".") today = datetime.today().strftime("%Y-%m-%d") + start_date_default = None + if pre_parse_args: + start_date_default = pre_parse_args.ds if pre_parse_args.mode == "sample" else None parser.set_defaults( mode="backfill", ds=today, @@ -506,7 +522,10 @@ def set_defaults(parser): chronon_jar=os.environ.get("CHRONON_DRIVER_JAR"), list_apps="python3 " + os.path.join(chronon_repo_path, "scripts/yarn_list.py"), render_info=os.path.join(chronon_repo_path, RENDER_INFO_DEFAULT_SCRIPT), - output_dir=os.environ.get("CHRONON_LOCAL_DATA_DIR") + output_dir=os.environ.get("CHRONON_LOCAL_DATA_DIR"), + num_rows=100, + local_warehouse_location=os.environ.get("CHRONON_LOCAL_WAREHOUSE_LOCATION", os.getcwd() + "/chronon_local"), + start_date=start_date_default ) if __name__ == "__main__": @@ -517,7 +536,7 @@ def set_defaults(parser): help="Conf param - required for every mode except fetch", ) parser.add_argument("--mode", choices=MODE_ARGS.keys()) - parser.add_argument("--ds", help="the end partition to backfill the data") + parser.add_argument("--ds", help="the end partition to backfill the data. Acts as start_date default as well in sampling case.") parser.add_argument( "--app-name", help="app name. Default to {}".format(APP_NAME_TEMPLATE) ) @@ -585,11 +604,21 @@ def set_defaults(parser): help="Path to local directory to store sampled data for in-memory runs. " + "Only applicable when mode is set to sample", ) + parser.add_argument( + "--num-rows", + help="Number of output rows desired for sample run. " + + "Only applicable when mode is set to sample", + ) + parser.add_argument( + "--local-warehouse-location", + help="Directory to use as the local warehouse for local runs." + + "Only applicable when mode is set to sample", + ) set_defaults(parser) pre_parse_args, _ = parser.parse_known_args() # We do a pre-parse to extract conf, mode, etc and set environment variables and re parse default values. set_runtime_env(pre_parse_args) - set_defaults(parser) + set_defaults(parser, pre_parse_args) args, unknown_args = parser.parse_known_args() jar_type = "embedded" if args.mode in MODES_USING_EMBEDDED else "uber" extra_args = (" " + args.online_args) if args.mode in ONLINE_MODES else "" diff --git a/spark/src/main/scala/ai/chronon/spark/Driver.scala b/spark/src/main/scala/ai/chronon/spark/Driver.scala index 6666fa35d..a6f13d2f9 100644 --- a/spark/src/main/scala/ai/chronon/spark/Driver.scala +++ b/spark/src/main/scala/ai/chronon/spark/Driver.scala @@ -405,6 +405,59 @@ object Driver { } } + object SampledJoin { + class Args extends Subcommand("sampled-join") with OfflineSubcommand { + val startDate: ScallopOption[String] = + opt[String](required = false, + descr = "The earliest date for which you want to run your join, defaults to endDate if not set", + default = None) + val outputDir: ScallopOption[String] = + opt[String]( + required = true, + descr = "The output directory for your sample data which will be used by the join." + ) + val forceResample: ScallopOption[Boolean] = + opt[Boolean]( + required = false, + descr = + "Option to force resampling even if semantics haven't changed", + default = Some(false) + ) + lazy val joinConf: api.Join = parseConf[api.Join](confPath()) + override def subcommandName() = "sampledJoin" + } + + def run(args: Args): Unit = { + // Todo - start date and force override + + val session = SparkSessionBuilder.build2(args.subcommandName(), + local = true, + localWarehouseLocation = args.localWarehouseLocation.toOption) + + val tableUtils = TableUtils(session) + SampleDataLoader.loadAllData(args.outputDir(), tableUtils) + + println(s"\n\n\n \n ==== ======= \n\n\n Creating join with end date of ${args.endDate()} and start date of ${args.startDate}") + + val join = new Join( + args.joinConf, + args.endDate(), + args.buildTableUtils(), + false + ) + println ("\n\n\n ------------------------------------------- \n\n") + session.sql("use data") + session.sql("show tables").show(10) + val df = join.computeJoin(overrideStartPartition = args.startDate.toOption) + + df.show(numRows = 3, truncate = 0, vertical = true) + logger.info( + s"\nShowing three rows of output above.\nQuery table `${args.joinConf.metaData.outputTable}` for more.\n") + + } + + } + object MetadataExport { class Args extends Subcommand("metadata-export") with OfflineSubcommand { val inputRootPath: ScallopOption[String] = @@ -866,6 +919,8 @@ object Driver { addSubcommand(AnalyzerArgs) object SampleArgs extends Sample.Args addSubcommand(SampleArgs) + object SampledJoinArgs extends SampledJoin.Args + addSubcommand(SampledJoinArgs) object DailyStatsArgs extends DailyStats.Args addSubcommand(DailyStatsArgs) object LogStatsArgs extends LogStats.Args @@ -891,9 +946,11 @@ object Driver { def main(baseArgs: Array[String]): Unit = { val args = new Args(baseArgs) + println("HERE!") var shouldExit = true args.subcommand match { case Some(x) => + println(s"HERE!!! ${x}") x match { case args.JoinBackFillArgs => JoinBackfill.run(args.JoinBackFillArgs) case args.GroupByBackfillArgs => GroupByBackfill.run(args.GroupByBackfillArgs) @@ -910,6 +967,7 @@ object Driver { case args.CompareJoinQueryArgs => CompareJoinQuery.run(args.CompareJoinQueryArgs) case args.AnalyzerArgs => Analyzer.run(args.AnalyzerArgs) case args.SampleArgs => Sample.run(args.SampleArgs) + case args.SampledJoinArgs => SampledJoin.run(args.SampledJoinArgs) case args.DailyStatsArgs => DailyStats.run(args.DailyStatsArgs) case args.LogStatsArgs => LogStats.run(args.LogStatsArgs) case args.MetadataExportArgs => MetadataExport.run(args.MetadataExportArgs) diff --git a/spark/src/main/scala/ai/chronon/spark/Sample.scala b/spark/src/main/scala/ai/chronon/spark/Sample.scala index c8393bd72..2b05aef6a 100644 --- a/spark/src/main/scala/ai/chronon/spark/Sample.scala +++ b/spark/src/main/scala/ai/chronon/spark/Sample.scala @@ -8,11 +8,12 @@ import scala.jdk.CollectionConverters.{asScalaBufferConverter, mapAsScalaMapConv import ai.chronon.api import ai.chronon.api.Extensions.{GroupByOps, SourceOps} -import ai.chronon.spark.Driver.{logger, parseConf} +import ai.chronon.spark.Driver.{parseConf} import ai.chronon.spark.SampleHelper.{getPriorRunManifestMetadata, writeManifestMetadata} import com.google.gson.{Gson, GsonBuilder} import com.google.gson.reflect.TypeToken import org.apache.spark.sql.DataFrame +import org.slf4j.LoggerFactory class Sample(conf: Any, tableUtils: TableUtils, @@ -22,6 +23,7 @@ class Sample(conf: Any, forceResample: Boolean = false, numRows: Int = 100) { + @transient lazy val logger = LoggerFactory.getLogger(getClass) val MANIFEST_FILE_NAME = "manifest.json" val MANIFEST_SERDE_TYPE_TOKEN = new TypeToken[java.util.Map[String, Integer]](){}.getType @@ -153,7 +155,7 @@ class Sample(conf: Any, val joinPartsSemantics = joinParts.map{ joinPart => // For each joinPart, the only relevant sampling metadata for it's sources are keyMapping and keyColumn s"${Option(joinPart.prefix).getOrElse("")}${joinPart.groupBy.metaData.getName}" -> - (joinPart.keyMapping.asScala, joinPart.groupBy.keyColumns.asScala).hashCode() + (joinPart.keyMapping.asScala, joinPart.groupBy.keyColumns.asScala, joinPart.groupBy.maxWindow).hashCode() }.toMap // The left side hash only depends on the source where clauses @@ -186,7 +188,7 @@ class Sample(conf: Any, val tableHashes: Map[String, Int] = tablesMap.map{ case(table, joinParts) => (table, getTableSemanticHash(joinParts, join)) - } + } ++ Map(join.getLeft.rootTable -> Option(join.left.query.wheres.asScala).getOrElse("").hashCode()) val tablesToSample: Seq[String] = if (forceResample) { tableHashes.keys.toSeq diff --git a/spark/src/main/scala/ai/chronon/spark/SampleDataLoader.scala b/spark/src/main/scala/ai/chronon/spark/SampleDataLoader.scala index fab42a96e..42b61f9ac 100644 --- a/spark/src/main/scala/ai/chronon/spark/SampleDataLoader.scala +++ b/spark/src/main/scala/ai/chronon/spark/SampleDataLoader.scala @@ -1,10 +1,12 @@ package ai.chronon.spark -import java.nio.file.{Files, Paths} - import ai.chronon.spark.SampleHelper.getPriorRunManifestMetadata import org.apache.spark.sql.SparkSession import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.functions._ +import org.slf4j.LoggerFactory + + /* This class works in conjunction with Sample.scala @@ -15,24 +17,66 @@ as possible loading tables in between sampled data runs. */ object SampleDataLoader { - def loadTable(directory: String, session: SparkSession): Unit = { - val df: DataFrame = session.read.parquet(directory) - val folderName = Paths.get(directory).getFileName.toString - val splits = folderName.split("\\.") - assert(splits.nonEmpty && splits.size == 2, - "local data directory must be named `namespace.table`. This should be auto-generated by Sample.scala. Check" + - "path or job logic.") - session.sql(s"DROP TABLE IF EXISTS $folderName") - df.write.partitionBy("ds").saveAsTable(folderName) + @transient lazy val logger = LoggerFactory.getLogger(getClass) + private val SEMANTIC_HASH_REGEX = "semantic_hash=(-?\\d+)".r + + private def loadTable(directory: String, tableName: String, session: SparkSession, semanticHash: Int): Unit = { + val df: DataFrame = session.read.parquet(s"$directory/$tableName") + session.sql(s"DROP TABLE IF EXISTS $tableName") + session.sql(s"CREATE DATABASE IF NOT EXISTS ${tableName.split("\\.").head}") + df.write.partitionBy("ds").saveAsTable(tableName) + session.sql(s"ALTER TABLE $tableName SET TBLPROPERTIES ('semantic_hash' = '$semanticHash')") + } + + private def checkLocalWarehouse(tableUtils: TableUtils): Unit = { + + print(s"\n\n\n =========================== \n\n ${tableUtils.dataWarehouseDir.get} \n ${tableUtils.jobMode} \n ${tableUtils.hiveMetastore} \n\n\n") + // TODO: Is this the best way to run this safety check? Alternatively, could take the intended + // path in as an argument and check that it matches the spark setting. + if (tableUtils.dataWarehouseDir.get.startsWith("hdfs") || tableUtils.dataWarehouseDir.get.startsWith("s3") || !tableUtils.jobMode.startsWith("local") ) { + throw new RuntimeException( + """ + |SampleDataLoader is only meant to be run with a local spark directory. This is a safety precaution + |As this module drops and overwrites tables that likely share a name with production tables. + |Please be sure to set a local data warehouse path when calling the Chronon sample job. This + |Can be done with the `spark.sql.warehouse.dir` config option.""".stripMargin + ) + } } - def loadAllData(sampleDirectory: String, session: SparkSession): Unit = { - val basePath = Paths.get(sampleDirectory) - val directoriesStream = Files.newDirectoryStream(basePath) + def loadAllData(sampleDirectory: String, tableUtils: TableUtils): Unit = { + // Important: Always check that we're in local mode before doing anything + checkLocalWarehouse(tableUtils) - val sampleManifest: Map[String, Int] = getPriorRunManifestMetadata(sampleDirectory) + val sparkSession = tableUtils.sparkSession + getPriorRunManifestMetadata(sampleDirectory).foreach { case(tableName, hash) => + logger.info(s"Checking $tableName") + val resampleTable = if (tableUtils.tableExists(tableName)) { + logger.info(s"Found existing table $tableName.") + val tableDesc = sparkSession.sql(s"DESCRIBE TABLE EXTENDED $tableName") + tableDesc.show(10) + val tableMetadata = tableDesc.filter(col("col_name") === "Table Properties").select("data_type").collect().head.getAs[String](0) - } + val semanticHashValue = SEMANTIC_HASH_REGEX.findFirstMatchIn(tableMetadata) match { + case Some(matchFound) => matchFound.group(1) + case None => + logger.error(s"Failed to parse semantic hash from $tableName. Table metadata: $tableMetadata") + "" + } + // Reload table when the data hash changed + semanticHashValue != hash.toString + } else { + true + } + + if (resampleTable) { + logger.info(s"Loading data for table $tableName") + loadTable(sampleDirectory, tableName, sparkSession, hash) + } else { + logger.info(s"$tableName already up to date based on semantic hash.") + } + } + } } diff --git a/spark/src/main/scala/ai/chronon/spark/SparkSessionBuilder.scala b/spark/src/main/scala/ai/chronon/spark/SparkSessionBuilder.scala index 64ce6e654..34dfa1833 100644 --- a/spark/src/main/scala/ai/chronon/spark/SparkSessionBuilder.scala +++ b/spark/src/main/scala/ai/chronon/spark/SparkSessionBuilder.scala @@ -90,6 +90,40 @@ object SparkSessionBuilder { spark } + def build2(name: String, + local: Boolean = false, + localWarehouseLocation: Option[String] = None, + additionalConfig: Option[Map[String, String]] = None): SparkSession = { + if (local) { + //required to run spark locally with hive support enabled - for sbt test + System.setSecurityManager(null) + } + val userName = Properties.userName + val warehouseDir = localWarehouseLocation.map(expandUser).getOrElse(DefaultWarehouseDir.getAbsolutePath) + val metastoreDb = s"jdbc:derby:;databaseName=$warehouseDir/metastore_db;create=true" + val baseBuilder = SparkSession + .builder() + .appName(name) + .config("spark.sql.session.timeZone", "UTC") + //otherwise overwrite will delete ALL partitions, not just the ones it touches + .config("spark.sql.sources.partitionOverwriteMode", "dynamic") + .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer") + .config("spark.kryo.registrator", "ai.chronon.spark.ChrononKryoRegistrator") + .config("spark.kryoserializer.buffer.max", "2000m") + .config("spark.kryo.referenceTracking", "false") + .config("spark.sql.legacy.timeParserPolicy", "LEGACY") + .master("local[*]") + .config("spark.hadoop.javax.jdo.option.ConnectionURL", metastoreDb) + .config("spark.kryo.registrationRequired", s"${localWarehouseLocation.isEmpty}") + .config("spark.sql.warehouse.dir", s"$warehouseDir") + + val spark = baseBuilder.getOrCreate() + // disable log spam + spark.sparkContext.setLogLevel("ERROR") + Logger.getLogger("parquet.hadoop").setLevel(java.util.logging.Level.SEVERE) + spark + } + def buildStreaming(local: Boolean): SparkSession = { val userName = Properties.userName val baseBuilder = SparkSession diff --git a/spark/src/main/scala/ai/chronon/spark/TableUtils.scala b/spark/src/main/scala/ai/chronon/spark/TableUtils.scala index 15007dee3..c5f56fb7e 100644 --- a/spark/src/main/scala/ai/chronon/spark/TableUtils.scala +++ b/spark/src/main/scala/ai/chronon/spark/TableUtils.scala @@ -72,6 +72,9 @@ case class TableUtils(sparkSession: SparkSession) { val joinPartParallelism: Int = sparkSession.conf.get("spark.chronon.join.part.parallelism", "1").toInt val aggregationParallelism: Int = sparkSession.conf.get("spark.chronon.group_by.parallelism", "1000").toInt val maxWait: Int = sparkSession.conf.get("spark.chronon.wait.hours", "48").toInt + val dataWarehouseDir: Option[String] = sparkSession.sparkContext.getConf.getOption("spark.sql.warehouse.dir") + val jobMode: String = sparkSession.sparkContext.master + val hiveMetastore: Option[String] = sparkSession.sparkContext.getConf.getOption("hive.metastore.uris") sparkSession.sparkContext.setLogLevel("ERROR") // converts String-s like "a=b/c=d" to Map("a" -> "b", "c" -> "d")