Skip to content

Commit

Permalink
feat: only backfill selected join parts
Browse files Browse the repository at this point in the history
  • Loading branch information
Donghan Zhang committed Feb 16, 2024
1 parent ac5095b commit 78dd02a
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 6 deletions.
5 changes: 4 additions & 1 deletion spark/src/main/scala/ai/chronon/spark/Analyzer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,7 @@ class Analyzer(tableUtils: TableUtils,
}

def analyzeJoin(joinConf: api.Join,
joinPartOnly: Option[List[String]] = None,
enableHitter: Boolean = false,
validationAssert: Boolean = false): (Map[String, DataType], ListBuffer[AggregationMetadata]) = {
val name = "joins/" + joinConf.metaData.name
Expand All @@ -276,7 +277,9 @@ class Analyzer(tableUtils: TableUtils,
.unfilledRanges(joinConf.metaData.outputTable, rangeToFill, Some(Seq(joinConf.left.table)))
.getOrElse(Seq.empty)

joinConf.joinParts.toScala.foreach { part =>
// Restrict the analysis to the requested join parts. No selection (None) must mean
// "analyze every join part"; only an explicitly supplied list narrows the set.
joinConf.joinParts.toScala
  .filter(part => joinPartOnly.forall(_.contains(part.groupBy.metaData.name)))
.foreach { part =>
val (aggMetadata, gbKeySchema) =
analyzeGroupBy(part.groupBy, part.fullPrefix, includeOutputTableName = true, enableHitter = enableHitter)
aggregationsMetadata ++= aggMetadata.map { aggMeta =>
Expand Down
7 changes: 6 additions & 1 deletion spark/src/main/scala/ai/chronon/spark/Driver.scala
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,10 @@ object Driver {
opt[String](required = false,
descr =
"Start date to compute join backfill, this start date will override start partition in conf.")
// Optional list of join part names. When supplied, it is forwarded to the join computation
// so that only the named join parts are backfilled.
val selectedJoinParts: ScallopOption[List[String]] =
opt[List[String]](required = false,
descr = "Only backfill the specified JoinPart.")
lazy val joinConf: api.Join = parseConf[api.Join](confPath())
override def subcommandName() = s"join_${joinConf.metaData.name}"
}
Expand All @@ -244,7 +248,8 @@ object Driver {
args.buildTableUtils(),
!args.runFirstHole()
)
val df = join.computeJoin(args.stepDays.toOption, args.startPartitionOverride.toOption)
// Pass the optional join-part selection through so that only the requested parts are backfilled.
val df = join.computeJoin(args.stepDays.toOption, args.startPartitionOverride.toOption, args.selectedJoinParts.toOption)

if (args.shouldExport()) {
args.exportTableToLocal(args.joinConf.metaData.outputTable, tableUtils)
Expand Down
10 changes: 6 additions & 4 deletions spark/src/main/scala/ai/chronon/spark/JoinBase.scala
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ abstract class JoinBase(joinConf: api.Join,
// Cache join part data into intermediate table
if (filledDf.isDefined) {
logger.info(s"Writing to join part table: $partTable for partition range $unfilledRange")
filledDf.get.save(partTable, tableProps, stats = prunedLeft.map(_.stats))
// Write the filled join part to partTable using the table properties and the left-side stats.
filledDf.get.save(partTable, tableProps, stats = prunedLeft.map(_.stats))
}
})
val elapsedMins = (System.currentTimeMillis() - start) / 60000
Expand Down Expand Up @@ -288,12 +288,14 @@ abstract class JoinBase(joinConf: api.Join,

def computeRange(leftDf: DataFrame, leftRange: PartitionRange, bootstrapInfo: BootstrapInfo): DataFrame

def computeJoin(stepDays: Option[Int] = None, overrideStartPartition: Option[String] = None): DataFrame = {
def computeJoin(stepDays: Option[Int] = None, overrideStartPartition: Option[String] = None, selectedJoinParts: Option[List[String]]): DataFrame = {

assert(Option(joinConf.metaData.team).nonEmpty,
s"join.metaData.team needs to be set for join ${joinConf.metaData.name}")

joinConf.joinParts.asScala.foreach { jp =>
// Check that every join part taking part in this run has a team set. When no selection is
// given (None) all join parts are checked; a supplied list restricts the check to those names.
joinConf.joinParts.asScala
  .filter(jp => selectedJoinParts.forall(_.contains(jp.groupBy.metaData.name)))
  .foreach { jp =>
    assert(Option(jp.groupBy.metaData.team).nonEmpty,
           s"groupBy.metaData.team needs to be set for joinPart ${jp.groupBy.metaData.name}")
  }
Expand All @@ -302,7 +304,7 @@ abstract class JoinBase(joinConf: api.Join,
val today = tableUtils.partitionSpec.at(System.currentTimeMillis())
val analyzer = new Analyzer(tableUtils, joinConf, today, today, silenceMode = true)
try {
analyzer.analyzeJoin(joinConf, validationAssert = true)
analyzer.analyzeJoin(joinConf, selectedJoinParts, validationAssert = true)
metrics.gauge(Metrics.Name.validationSuccess, 1)
logger.info("Join conf validation succeeded. No error found.")
} catch {
Expand Down

0 comments on commit 78dd02a

Please sign in to comment.