Chaining fix (#567)
* Set namespace in joinSource

* Join backfill customization

* Add flag to join metadata & start partition override

* Add unit test

* format

* cleanup

* comments

* format

* minor

---------

Co-authored-by: Sophie Wang <[email protected]>
nikhilsimha and Sophie Wang authored Sep 27, 2023
1 parent 39febaa commit af47ab4
Showing 18 changed files with 139 additions and 16 deletions.
2 changes: 2 additions & 0 deletions api/py/ai/chronon/group_by.py
@@ -452,6 +452,8 @@ def _normalize_source(source):
elif isinstance(source, ttypes.EntitySource):
return ttypes.Source(entities=source)
elif isinstance(source, ttypes.JoinSource):
if source.join.metaData.outputNamespace is None:
source.join.metaData.outputNamespace = output_namespace
return ttypes.Source(joinSource=source)
elif isinstance(source, ttypes.Source):
return source
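
The effect for chaining: the upstream join inside a JoinSource is materialized under the group by's output namespace unless one was set explicitly on the join itself. A minimal sketch of the same defaulting rule on the JVM side, assuming the thrift-generated Java bean accessors on MetaData (not part of this change):

import ai.chronon.api.MetaData

// Sketch only: default an upstream join's output namespace when unset,
// mirroring the Python normalization above.
def inheritNamespace(joinMetaData: MetaData, groupByNamespace: String): Unit =
  if (!joinMetaData.isSetOutputNamespace)
    joinMetaData.setOutputNamespace(groupByNamespace)
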
8 changes: 7 additions & 1 deletion api/py/ai/chronon/join.py
@@ -377,6 +377,7 @@ def Join(left: api.Source,
consistency_sample_percent: float = 5.0,
online_external_parts: List[api.ExternalPart] = None,
offline_schedule: str = '@daily',
historical_backfill: bool = None,
row_ids: List[str] = None,
bootstrap_parts: List[api.BootstrapPart] = None,
bootstrap_from_log: bool = False,
@@ -468,6 +469,10 @@ def Join(left: api.Source,
:param tags:
Additional metadata about the Join that you wish to track. Does not affect computation.
:type tags: Dict[str, str]
:param historical_backfill:
Flag indicating whether the join backfill should fill in previous holes.
Setting it to false will backfill only the latest single partition.
:type historical_backfill: bool
:return:
A join object that can be used to backfill or serve data. For ML use-cases this should map 1:1 to model.
"""
@@ -568,7 +573,8 @@ def Join(left: api.Source,
modeToEnvMap=env,
samplePercent=sample_percent,
offlineSchedule=offline_schedule,
consistencySamplePercent=consistency_sample_percent
consistencySamplePercent=consistency_sample_percent,
historicalBackfill=historical_backfill
)

return api.Join(
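
The new keyword lands on the join's MetaData; when left unset, downstream readers default it to true (see Extensions.scala below). A sketch of the equivalent construction on the Scala side via the Builders helper changed in this commit, with illustrative name, namespace, and team values:

import ai.chronon.api.Builders

// Sketch: join metadata with historical backfill disabled, so offline runs
// fill only the latest partition instead of every historical hole.
val metaData = Builders.MetaData(
  name = "sample_team.sample_chaining_join",
  namespace = "test_namespace",
  team = "sample_team",
  historicalBackfill = false
)
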
@@ -59,5 +59,5 @@
"sample_config_json": """{"sample_key": "sample_value"}""",
"description": "sample description"
},
output_namespace="sample_namespace",
output_namespace="test_namespace",
)
3 changes: 2 additions & 1 deletion api/py/test/sample/joins/sample_team/sample_chaining_join.py
@@ -34,7 +34,8 @@
),
],
online=True,
check_consistency=True
check_consistency=True,
historical_backfill=False,
)

chaining_group_by_v1 = GroupBy(
@@ -36,7 +36,8 @@
"team": "sample_team",
"samplePercent": 100.0,
"offlineSchedule": "@daily",
"consistencySamplePercent": 5.0
"consistencySamplePercent": 5.0,
"historicalBackfill": 0
},
"left": {
"events": {
@@ -16,7 +16,8 @@
"team": "sample_team",
"samplePercent": 100.0,
"offlineSchedule": "@daily",
"consistencySamplePercent": 5.0
"consistencySamplePercent": 5.0,
"historicalBackfill": 0
},
"left": {
"events": {
@@ -72,7 +72,8 @@
"team": "sample_team",
"samplePercent": 100.0,
"offlineSchedule": "@daily",
"consistencySamplePercent": 5.0
"consistencySamplePercent": 5.0,
"historicalBackfill": 0
},
"left": {
"events": {
4 changes: 3 additions & 1 deletion api/src/main/scala/ai/chronon/api/Builders.scala
@@ -250,7 +250,8 @@ object Builders {
team: String = null,
samplePercent: Double = 100,
consistencySamplePercent: Double = 5,
tableProperties: Map[String, String] = Map.empty
tableProperties: Map[String, String] = Map.empty,
historicalBackfill: Boolean = true
): MetaData = {
val result = new MetaData()
result.setName(name)
@@ -259,6 +260,7 @@
result.setCustomJson(customJson)
result.setOutputNamespace(namespace)
result.setTeam(Option(team).getOrElse("chronon"))
result.setHistoricalBackfill(historicalBackfill)
if (dependencies != null)
result.setDependencies(dependencies.toSeq.toJava)
if (samplePercent > 0)
4 changes: 4 additions & 0 deletions api/src/main/scala/ai/chronon/api/Extensions.scala
@@ -696,6 +696,10 @@ object Extensions {
.toArray
}

def historicalBackfill: Boolean = {
Option(join.metaData.historicalBackfill).getOrElse(true)
}

def computedFeatureCols: Seq[String] = joinPartOps.flatMap(_.valueColumns)

def partOutputTable(jp: JoinPart): String =
3 changes: 3 additions & 0 deletions api/thrift/api.thrift
@@ -253,6 +253,9 @@ struct MetaData {
12: optional string offlineSchedule
// percentage of online serving requests used to compute consistency metrics
13: optional double consistencySamplePercent
// Flag indicating whether the join backfill should fill in previous holes.
// Setting it to false will backfill only the latest single partition
14: optional bool historicalBackfill
}

// Equivalent to a FeatureSet in chronon terms
3 changes: 2 additions & 1 deletion spark/src/main/scala/ai/chronon/spark/Analyzer.scala
@@ -246,7 +246,8 @@ class Analyzer(tableUtils: TableUtils,
// Triple of (table name, group_by name, expected_start) indicating that the table does not have data available for the required group_by
val dataAvailabilityErrors: ListBuffer[(String, String, String)] = ListBuffer.empty[(String, String, String)]

val rangeToFill = JoinUtils.getRangesToFill(joinConf.left, tableUtils, endDate)
val rangeToFill =
JoinUtils.getRangesToFill(joinConf.left, tableUtils, endDate, historicalBackfill = joinConf.historicalBackfill)
println(s"[Analyzer] Join range to fill $rangeToFill")
val unfilledRanges = tableUtils
.unfilledRanges(joinConf.metaData.outputTable, rangeToFill, Some(Seq(joinConf.left.table)))
6 changes: 5 additions & 1 deletion spark/src/main/scala/ai/chronon/spark/Driver.scala
@@ -209,6 +209,10 @@ object Driver {
opt[Boolean](required = false,
default = Some(false),
descr = "Skip the first unfilled partition range if some future partitions have been populated.")
val startPartitionOverride: ScallopOption[String] =
opt[String](required = false,
descr =
"Start date to compute join backfill, this start date will override start partition in conf.")
lazy val joinConf: api.Join = parseConf[api.Join](confPath())
override def subcommandName() = s"join_${joinConf.metaData.name}"
}
Expand All @@ -221,7 +225,7 @@ object Driver {
args.buildTableUtils(),
!args.runFirstHole()
)
val df = join.computeJoin(args.stepDays.toOption)
val df = join.computeJoin(args.stepDays.toOption, args.startPartitionOverride.toOption)

if (args.shouldExport()) {
args.exportTableToLocal(args.joinConf.metaData.outputTable, tableUtils)
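
On the command line the new option would surface as --start-partition-override=yyyy-MM-dd (a hypothetical invocation; Scallop derives kebab-case flag names from the startPartitionOverride identifier), and its value threads through computeJoin as the optional override handled in JoinBase below.
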
11 changes: 8 additions & 3 deletions spark/src/main/scala/ai/chronon/spark/GroupBy.scala
@@ -4,7 +4,7 @@ import ai.chronon.aggregator.base.TimeTuple
import ai.chronon.aggregator.row.RowAggregator
import ai.chronon.aggregator.windowing._
import ai.chronon.api
import ai.chronon.api.Constants
import ai.chronon.api.{Accuracy, Constants, DataModel}
import ai.chronon.api.DataModel.{Entities, Events}
import ai.chronon.api.Extensions._
import ai.chronon.online.{RowWrapper, SparkConversions}
@@ -388,8 +391,13 @@ object GroupBy {
println("Join source detected. Materializing the join.")
val joinSource = source.getJoinSource
val joinConf = joinSource.join
// materialize the table
val join = new Join(joinConf, queryRange.end, tableUtils, mutationScan = false, showDf = showDf)
// materialize the table with the right end date; queryRange.end could be shifted for temporal events
val beforeDs = tableUtils.partitionSpec.before(queryRange.end)
val isPreShifted =
groupByConf.dataModel == DataModel.Events && groupByConf.inferredAccuracy == Accuracy.TEMPORAL
val endDate = if (isPreShifted) beforeDs else queryRange.end

val join = new Join(joinConf, endDate, tableUtils, mutationScan = false, showDf = showDf)
if (computeDependency) {
val df = join.computeJoin()
if (showDf) {
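
The shift exists because, for an Events group by with TEMPORAL accuracy, queryRange.end arrives pre-shifted one partition forward, so the parent join only needs to be materialized through the previous day. A self-contained sketch of the date arithmetic, standing in for tableUtils.partitionSpec.before under the assumption of daily "yyyy-MM-dd" partitions:

import java.time.LocalDate
import java.time.format.DateTimeFormatter

object EndDateShiftSketch {
  private val Fmt = DateTimeFormatter.ofPattern("yyyy-MM-dd")

  // Illustrative stand-in for tableUtils.partitionSpec.before
  def before(ds: String): String =
    LocalDate.parse(ds, Fmt).minusDays(1).format(Fmt)

  def main(args: Array[String]): Unit = {
    val queryRangeEnd = "2023-09-27"
    val isPreShifted = true // Events data model + TEMPORAL accuracy
    val endDate = if (isPreShifted) before(queryRangeEnd) else queryRangeEnd
    println(endDate) // prints 2023-09-26
  }
}
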
2 changes: 2 additions & 0 deletions spark/src/main/scala/ai/chronon/spark/GroupByUpload.scala
@@ -166,6 +166,8 @@
showDf = showDf)
lazy val groupByUpload = new GroupByUpload(endDs, groupBy)
// for temporal accuracy - we don't need to scan mutations for upload
// when endDs = xxxx-01-02, the timestamp from airflow is past (xxxx-01-03 00:00:00),
// so we wait for the event partition of (xxxx-01-02), which contains data up to (xxxx-01-02 23:59:59.999)
lazy val shiftedGroupBy =
GroupBy.from(groupByConf,
PartitionRange(endDs, endDs).shift(1),
11 changes: 9 additions & 2 deletions spark/src/main/scala/ai/chronon/spark/JoinBase.scala
@@ -270,7 +270,7 @@ abstract class JoinBase(joinConf: api.Join,

def computeRange(leftDf: DataFrame, leftRange: PartitionRange, bootstrapInfo: BootstrapInfo): DataFrame

def computeJoin(stepDays: Option[Int] = None): DataFrame = {
def computeJoin(stepDays: Option[Int] = None, overrideStartPartition: Option[String] = None): DataFrame = {

assert(Option(joinConf.metaData.team).nonEmpty,
s"join.metaData.team needs to be set for join ${joinConf.metaData.name}")
@@ -303,7 +303,14 @@
tableUtils.archiveOrDropTableIfExists(_, Some(archivedAtTs)))

// detect holes and chunks to fill
val rangeToFill = JoinUtils.getRangesToFill(joinConf.left, tableUtils, endPartition)
// OverrideStartPartition replaces the start partition of the join config. This is useful when
// 1 - the user would like to test run with a different start partition
// 2 - the user has an entity table that is cumulative and only wants to backfill the latest partition
val rangeToFill = JoinUtils.getRangesToFill(joinConf.left,
tableUtils,
endPartition,
overrideStartPartition,
joinConf.historicalBackfill)
println(s"Join range to fill $rangeToFill")
val unfilledRanges = tableUtils
.unfilledRanges(outputTable, rangeToFill, Some(Seq(joinConf.left.table)), skipFirstHole = skipFirstHole)
13 changes: 11 additions & 2 deletions spark/src/main/scala/ai/chronon/spark/JoinUtils.scala
@@ -79,9 +79,18 @@
*/
def getRangesToFill(leftSource: ai.chronon.api.Source,
tableUtils: TableUtils,
endPartition: String): PartitionRange = {
val leftStart = Option(leftSource.query.startPartition)
endPartition: String,
overrideStartPartition: Option[String] = None,
historicalBackfill: Boolean = true): PartitionRange = {
val overrideStart = if (historicalBackfill) {
overrideStartPartition
} else {
println(s"Historical backfill is set to false. Backfill single partition only: $endPartition")
Some(endPartition)
}
lazy val defaultLeftStart = Option(leftSource.query.startPartition)
.getOrElse(tableUtils.firstAvailablePartition(leftSource.table, leftSource.subPartitionFilters).get)
val leftStart = overrideStart.getOrElse(defaultLeftStart)
val leftEnd = Option(leftSource.query.endPartition).getOrElse(endPartition)
PartitionRange(leftStart, leftEnd)(tableUtils)
}
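
Taken together, the start of the fill range resolves with a fixed precedence: historicalBackfill = false collapses the range to the end partition; otherwise an explicit override beats the conf's start partition, which beats scanning the left table for its first available partition. A minimal standalone sketch of that precedence (function and parameter names are illustrative, not the library API):

// firstAvailable is by-name, so the potentially expensive table scan only
// runs when neither an override nor a conf start partition exists.
def resolveStartPartition(confStart: Option[String],
                          firstAvailable: => String,
                          overrideStart: Option[String],
                          historicalBackfill: Boolean,
                          endPartition: String): String =
  if (!historicalBackfill) endPartition // single-partition backfill; override is ignored
  else overrideStart.getOrElse(confStart.getOrElse(firstAvailable))

// resolveStartPartition(Some("2023-04-15"), sys.error("unused"),
//   Some("2023-08-01"), historicalBackfill = true, "2023-08-08") == "2023-08-01"
// resolveStartPartition(Some("2023-04-15"), sys.error("unused"),
//   None, historicalBackfill = false, "2023-08-08") == "2023-08-08"
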
56 changes: 56 additions & 0 deletions spark/src/test/scala/ai/chronon/spark/test/JoinTest.scala
@@ -356,6 +356,62 @@ class JoinTest {
assertTrue(stdOutMsg.contains(s"There is no data to compute based on end partition of $end."))
}

@Test
def testEntitiesEntitiesNoHistoricalBackfill(): Unit = {
// Only backfill latest partition if historical_backfill is turned off
val weightSchema = List(
Column("user", api.StringType, 1000),
Column("country", api.StringType, 100),
Column("weight", api.DoubleType, 500)
)
val weightTable = s"$namespace.weights_no_historical_backfill"
DataFrameGen.entities(spark, weightSchema, 1000, partitions = 400).save(weightTable)

val weightSource = Builders.Source.entities(
query = Builders.Query(selects = Builders.Selects("weight"),
startPartition = yearAgo,
endPartition = today),
snapshotTable = weightTable
)

val weightGroupBy = Builders.GroupBy(
sources = Seq(weightSource),
keyColumns = Seq("country"),
aggregations = Seq(Builders.Aggregation(operation = Operation.AVERAGE, inputColumn = "weight")),
metaData = Builders.MetaData(name = "test.country_weights_no_backfill", namespace = namespace)
)

// left side
val countrySchema = List(Column("country", api.StringType, 100))
val countryTable = s"$namespace.countries_no_historical_backfill"
DataFrameGen.entities(spark, countrySchema, 1000, partitions = 30).save(countryTable)

val start = tableUtils.partitionSpec.minus(today, new Window(30, TimeUnit.DAYS))
val end = tableUtils.partitionSpec.minus(today, new Window(5, TimeUnit.DAYS))
val joinConf = Builders.Join(
left = Builders.Source.entities(Builders.Query(startPartition = start), snapshotTable = countryTable),
joinParts = Seq(Builders.JoinPart(groupBy = weightGroupBy)),
metaData = Builders.MetaData(name = "test.country_no_historical_backfill",
namespace = namespace,
team = "chronon",
historicalBackfill = false)
)

val runner = new Join(joinConf, end, tableUtils)
val computed = runner.computeJoin(Some(7))
println("showing join result")
computed.show()

val leftSideCount = spark.sql(s"SELECT country, ds from $countryTable where ds == '$end'").count()
println(s"Left side expected count: $leftSideCount")
println(s"Actual count: ${computed.count()}")
assertEquals(leftSideCount, computed.count())
// There should be only one partition in computed df which equals to end partition
val allPartitions = computed.select("ds").rdd.map(row => row(0)).collect().toSet
assert(allPartitions.size == 1)
assertEquals(allPartitions.toList(0), end)
}

@Test
def testEventsEventsSnapshot(): Unit = {
val viewsSchema = List(
18 changes: 18 additions & 0 deletions spark/src/test/scala/ai/chronon/spark/test/JoinUtilsTest.scala
@@ -327,6 +327,24 @@ class JoinUtilsTest {
assertEquals(range, PartitionRange(startPartition, endPartition)(tableUtils))
}

@Test
def testGetRangesToFillWithOverride(): Unit = {
spark.sql("CREATE DATABASE IF NOT EXISTS joinUtil")
// left table
val itemQueries = List(Column("item", api.StringType, 100))
val itemQueriesTable = "joinUtil.queries_table"
DataFrameGen
.events(spark, itemQueries, 1000, partitions = 50)
.save(itemQueriesTable)

val startPartition = "2023-04-15"
val startPartitionOverride = "2023-08-01"
val endPartition = "2023-08-08"
val leftSource = Builders.Source.events(Builders.Query(startPartition = startPartition), table = itemQueriesTable)
val range = JoinUtils.getRangesToFill(leftSource, tableUtils, endPartition, Some(startPartitionOverride))
assertEquals(range, PartitionRange(startPartitionOverride, endPartition)(tableUtils))
}

import ai.chronon.api.{LongType, StringType, StructField, StructType}

def createSampleTable(tableName: String = "testSampleTable"): DataFrame = {
