diff --git a/spark/src/main/scala/ai/chronon/spark/Analyzer.scala b/spark/src/main/scala/ai/chronon/spark/Analyzer.scala
index d8965df447..ab98928c68 100644
--- a/spark/src/main/scala/ai/chronon/spark/Analyzer.scala
+++ b/spark/src/main/scala/ai/chronon/spark/Analyzer.scala
@@ -250,6 +250,7 @@ class Analyzer(tableUtils: TableUtils,
   }
 
   def analyzeJoin(joinConf: api.Join,
+                  joinPartOnly: Option[List[String]] = None,
                   enableHitter: Boolean = false,
                   validationAssert: Boolean = false): (Map[String, DataType], ListBuffer[AggregationMetadata]) = {
     val name = "joins/" + joinConf.metaData.name
@@ -276,7 +277,9 @@ class Analyzer(tableUtils: TableUtils,
       .unfilledRanges(joinConf.metaData.outputTable, rangeToFill, Some(Seq(joinConf.left.table)))
       .getOrElse(Seq.empty)
 
-    joinConf.joinParts.toScala.foreach { part =>
+    joinConf.joinParts.toScala
+      .filter(part => joinPartOnly.forall(_.contains(part.groupBy.metaData.name)))
+      .foreach { part =>
       val (aggMetadata, gbKeySchema) =
         analyzeGroupBy(part.groupBy, part.fullPrefix, includeOutputTableName = true, enableHitter = enableHitter)
       aggregationsMetadata ++= aggMetadata.map { aggMeta =>
diff --git a/spark/src/main/scala/ai/chronon/spark/Driver.scala b/spark/src/main/scala/ai/chronon/spark/Driver.scala
index 7df7949616..02fb166a99 100644
--- a/spark/src/main/scala/ai/chronon/spark/Driver.scala
+++ b/spark/src/main/scala/ai/chronon/spark/Driver.scala
@@ -232,6 +232,10 @@ object Driver {
       opt[String](required = false,
                   descr =
                     "Start date to compute join backfill, this start date will override start partition in conf.")
+    // Optionally restrict the backfill to a subset of join parts, matched by groupBy name.
+    val selectedJoinParts: ScallopOption[List[String]] =
+      opt[List[String]](required = false,
+                        descr = "Only backfill the specified join parts.")
     lazy val joinConf: api.Join = parseConf[api.Join](confPath())
     override def subcommandName() = s"join_${joinConf.metaData.name}"
   }
@@ -244,7 +248,8 @@ object Driver {
      args.buildTableUtils(),
      !args.runFirstHole()
    )
-    val df = join.computeJoin(args.stepDays.toOption, args.startPartitionOverride.toOption)
+    // Forward the optional join part selection into the join computation.
+    val df = join.computeJoin(args.stepDays.toOption, args.startPartitionOverride.toOption, args.selectedJoinParts.toOption)
 
     if (args.shouldExport()) {
       args.exportTableToLocal(args.joinConf.metaData.outputTable, tableUtils)
diff --git a/spark/src/main/scala/ai/chronon/spark/JoinBase.scala b/spark/src/main/scala/ai/chronon/spark/JoinBase.scala
index 12664ab03c..675fb30ef4 100644
--- a/spark/src/main/scala/ai/chronon/spark/JoinBase.scala
+++ b/spark/src/main/scala/ai/chronon/spark/JoinBase.scala
@@ -157,7 +157,7 @@ abstract class JoinBase(joinConf: api.Join,
         // Cache join part data into intermediate table
         if (filledDf.isDefined) {
           logger.info(s"Writing to join part table: $partTable for partition range $unfilledRange")
-          filledDf.get.save(partTable, tableProps, stats = prunedLeft.map(_.stats))
+          filledDf.get.save(partTable, tableProps, stats = prunedLeft.map(_.stats))
         }
       })
     val elapsedMins = (System.currentTimeMillis() - start) / 60000
@@ -288,12 +288,14 @@ abstract class JoinBase(joinConf: api.Join,
 
   def computeRange(leftDf: DataFrame, leftRange: PartitionRange, bootstrapInfo: BootstrapInfo): DataFrame
 
-  def computeJoin(stepDays: Option[Int] = None, overrideStartPartition: Option[String] = None): DataFrame = {
+  def computeJoin(stepDays: Option[Int] = None, overrideStartPartition: Option[String] = None, selectedJoinParts: Option[List[String]] = None): DataFrame = {
     assert(Option(joinConf.metaData.team).nonEmpty,
            s"join.metaData.team needs to be set for join ${joinConf.metaData.name}")
 
-    joinConf.joinParts.asScala.foreach { jp =>
+    joinConf.joinParts.asScala
+      .filter(jp => selectedJoinParts.forall(_.contains(jp.groupBy.metaData.name)))
+      .foreach { jp =>
       assert(Option(jp.groupBy.metaData.team).nonEmpty,
              s"groupBy.metaData.team needs to be set for joinPart ${jp.groupBy.metaData.name}")
     }
 
@@ -302,7 +304,7 @@ abstract class JoinBase(joinConf: api.Join,
     val today = tableUtils.partitionSpec.at(System.currentTimeMillis())
     val analyzer = new Analyzer(tableUtils, joinConf, today, today, silenceMode = true)
     try {
-      analyzer.analyzeJoin(joinConf, validationAssert = true)
+      analyzer.analyzeJoin(joinConf, joinPartOnly = selectedJoinParts, validationAssert = true)
       metrics.gauge(Metrics.Name.validationSuccess, 1)
       logger.info("Join conf validation succeeded. No error found.")
     } catch {