From 90243c51d67f1464bd803373b9f1adf14791dc98 Mon Sep 17 00:00:00 2001 From: Donghan Zhang Date: Fri, 23 Feb 2024 16:37:34 -0800 Subject: [PATCH] unit test --- .../main/scala/ai/chronon/spark/Driver.scala | 4 +- .../main/scala/ai/chronon/spark/GroupBy.scala | 7 +- .../scala/ai/chronon/spark/JoinBase.scala | 17 +-- .../chronon/spark/stats/ConsistencyJob.scala | 2 +- .../ai/chronon/spark/test/AnalyzerTest.scala | 2 +- .../ai/chronon/spark/test/AvroTest.scala | 2 +- .../spark/test/ChainingFetcherTest.scala | 2 +- .../chronon/spark/test/FetchStatsTest.scala | 2 +- .../ai/chronon/spark/test/FetcherTest.scala | 2 +- .../ai/chronon/spark/test/JoinTest.scala | 143 +++++++++++++++--- .../spark/test/MigrationCompareTest.scala | 2 +- .../ai/chronon/spark/test/MutationsTest.scala | 2 +- .../spark/test/bootstrap/DerivationTest.scala | 12 +- .../test/bootstrap/LogBootstrapTest.scala | 4 +- .../test/bootstrap/TableBootstrapTest.scala | 6 +- 15 files changed, 157 insertions(+), 52 deletions(-) diff --git a/spark/src/main/scala/ai/chronon/spark/Driver.scala b/spark/src/main/scala/ai/chronon/spark/Driver.scala index 6904c6b897..b038d10eab 100644 --- a/spark/src/main/scala/ai/chronon/spark/Driver.scala +++ b/spark/src/main/scala/ai/chronon/spark/Driver.scala @@ -247,10 +247,10 @@ object Driver { !args.runFirstHole(), selectedJoinParts = args.selectedJoinParts.toOption ) - val df = join.computeJoin(args.stepDays.toOption, args.startPartitionOverride.toOption) + val df = join.computeJoin(args.stepDays.toOption, args.startPartitionOverride.toOption).get if (args.selectedJoinParts.isDefined) { - logger.info("Selected join parts are populated successfully. No final join is required. Exiting.") + logger.info("Selected join parts are populated successfully. Exiting.") return } if (args.shouldExport()) { diff --git a/spark/src/main/scala/ai/chronon/spark/GroupBy.scala b/spark/src/main/scala/ai/chronon/spark/GroupBy.scala index a8224f1cfc..3308331520 100644 --- a/spark/src/main/scala/ai/chronon/spark/GroupBy.scala +++ b/spark/src/main/scala/ai/chronon/spark/GroupBy.scala @@ -16,22 +16,21 @@ package ai.chronon.spark -import org.slf4j.LoggerFactory import ai.chronon.aggregator.base.TimeTuple import ai.chronon.aggregator.row.RowAggregator import ai.chronon.aggregator.windowing._ import ai.chronon.api -import ai.chronon.api.{Accuracy, Constants, DataModel, ParametricMacro} import ai.chronon.api.DataModel.{Entities, Events} import ai.chronon.api.Extensions._ +import ai.chronon.api.{Accuracy, Constants, DataModel, ParametricMacro} import ai.chronon.online.{RowWrapper, SparkConversions} import ai.chronon.spark.Extensions._ import org.apache.spark.rdd.RDD import org.apache.spark.sql.functions._ -import org.apache.spark.sql import org.apache.spark.sql.types._ import org.apache.spark.sql.{DataFrame, Row, SparkSession} import org.apache.spark.util.sketch.BloomFilter +import org.slf4j.LoggerFactory import java.util import scala.collection.{Seq, mutable} @@ -420,7 +419,7 @@ object GroupBy { val join = new Join(joinConf, endDate, tableUtils, mutationScan = false, showDf = showDf) if (computeDependency) { - val df = join.computeJoin() + val df = join.computeJoin().get if (showDf) { logger.info( s"printing output data from groupby::join_source: ${groupByConf.metaData.name}::${joinConf.metaData.name}") diff --git a/spark/src/main/scala/ai/chronon/spark/JoinBase.scala b/spark/src/main/scala/ai/chronon/spark/JoinBase.scala index b7caefcd75..f59536087d 100644 --- a/spark/src/main/scala/ai/chronon/spark/JoinBase.scala +++ b/spark/src/main/scala/ai/chronon/spark/JoinBase.scala @@ -16,7 +16,6 @@ package ai.chronon.spark -import org.slf4j.LoggerFactory import ai.chronon.api import ai.chronon.api.DataModel.{Entities, Events} import ai.chronon.api.Extensions._ @@ -28,6 +27,7 @@ import com.google.gson.Gson import org.apache.spark.sql.DataFrame import org.apache.spark.sql.functions._ import org.apache.spark.util.sketch.BloomFilter +import org.slf4j.LoggerFactory import java.time.Instant import scala.collection.JavaConverters._ @@ -289,7 +289,7 @@ abstract class JoinBase(joinConf: api.Join, def computeRange(leftDf: DataFrame, leftRange: PartitionRange, bootstrapInfo: BootstrapInfo): Option[DataFrame] - def computeJoin(stepDays: Option[Int] = None, overrideStartPartition: Option[String] = None): DataFrame = { + def computeJoin(stepDays: Option[Int] = None, overrideStartPartition: Option[String] = None): Option[DataFrame] = { assert(Option(joinConf.metaData.team).nonEmpty, s"join.metaData.team needs to be set for join ${joinConf.metaData.name}") @@ -338,7 +338,7 @@ abstract class JoinBase(joinConf: api.Join, def finalResult: DataFrame = tableUtils.sql(rangeToFill.genScanQuery(null, outputTable)) if (unfilledRanges.isEmpty) { logger.info(s"\nThere is no data to compute based on end partition of ${rangeToFill.end}.\n\n Exiting..") - return finalResult + return Some(finalResult) } stepDays.foreach(metrics.gauge("step_days", _)) @@ -361,8 +361,9 @@ abstract class JoinBase(joinConf: api.Join, // set autoExpand = true to ensure backward compatibility due to column ordering changes val finalDf = computeRange(leftDfInRange, range, bootstrapInfo) if (selectedJoinParts.isDefined) { - assert(finalDf.isEmpty, "finalDf should be empty") + assert(finalDf.isEmpty, "The arg `selectedJoinParts` is defined, so no final join is required. `finalDf` should be empty") logger.info(s"Skipping writing to the output table for range: ${range.toString} $progress") + return None } else { finalDf.get.save(outputTable, tableProps, autoExpand = true) val elapsedMins = (System.currentTimeMillis() - startMillis) / (60 * 1000) @@ -373,11 +374,7 @@ abstract class JoinBase(joinConf: api.Join, } } } - if (selectedJoinParts.isDefined) { - logger.info(s"Completed join parts: ${selectedJoinParts.get.mkString(", ")}") - } else { - logger.info(s"Wrote to table $outputTable, into partitions: $unfilledRanges") - } - finalResult + logger.info(s"Wrote to table $outputTable, into partitions: $unfilledRanges") + Some(finalResult) } } diff --git a/spark/src/main/scala/ai/chronon/spark/stats/ConsistencyJob.scala b/spark/src/main/scala/ai/chronon/spark/stats/ConsistencyJob.scala index 11d24f55f1..c3f1b1de74 100644 --- a/spark/src/main/scala/ai/chronon/spark/stats/ConsistencyJob.scala +++ b/spark/src/main/scala/ai/chronon/spark/stats/ConsistencyJob.scala @@ -79,7 +79,7 @@ class ConsistencyJob(session: SparkSession, joinConf: Join, endDate: String) ext if (unfilledRanges.isEmpty) return val join = new chronon.spark.Join(buildComparisonJoin(), unfilledRanges.last.end, TableUtils(session)) logger.info("Starting compute Join for comparison table") - val compareDf = join.computeJoin(Some(30)) + val compareDf = join.computeJoin(Some(30)).get logger.info("======= side-by-side comparison schema =======") logger.info(compareDf.schema.pretty) } diff --git a/spark/src/test/scala/ai/chronon/spark/test/AnalyzerTest.scala b/spark/src/test/scala/ai/chronon/spark/test/AnalyzerTest.scala index cbe97537d9..ffbdad66e3 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/AnalyzerTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/AnalyzerTest.scala @@ -68,7 +68,7 @@ class AnalyzerTest { val analyzer = new Analyzer(tableUtils, joinConf, oneMonthAgo, today, enableHitter = true) val analyzerSchema = analyzer.analyzeJoin(joinConf)._1.map { case (k, v) => s"${k} => ${v}" }.toList.sorted val join = new Join(joinConf = joinConf, endPartition = oneMonthAgo, tableUtils) - val computed = join.computeJoin() + val computed = join.computeJoin().get val expectedSchema = computed.schema.fields.map(field => s"${field.name} => ${field.dataType}").sorted logger.info("=== expected schema =====") logger.info(expectedSchema.mkString("\n")) diff --git a/spark/src/test/scala/ai/chronon/spark/test/AvroTest.scala b/spark/src/test/scala/ai/chronon/spark/test/AvroTest.scala index e9036f77f4..71e7fbddba 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/AvroTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/AvroTest.scala @@ -99,7 +99,7 @@ class AvroTest { metaData = Builders.MetaData(name = "unit_test.test_decimal", namespace = namespace, team = "chronon") ) val runner = new Join(joinConf, tableUtils.partitionSpec.minus(today, new Window(40, TimeUnit.DAYS)), tableUtils) - val df = runner.computeJoin() + val df = runner.computeJoin().get df.printSchema() } diff --git a/spark/src/test/scala/ai/chronon/spark/test/ChainingFetcherTest.scala b/spark/src/test/scala/ai/chronon/spark/test/ChainingFetcherTest.scala index 4ea7fe2152..8d8d599c6b 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/ChainingFetcherTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/ChainingFetcherTest.scala @@ -215,7 +215,7 @@ class ChainingFetcherTest extends TestCase { val inMemoryKvStore = kvStoreFunc() val mockApi = new MockApi(kvStoreFunc, namespace) - val joinedDf = new ai.chronon.spark.Join(joinConf, endDs, tableUtils).computeJoin() + val joinedDf = new ai.chronon.spark.Join(joinConf, endDs, tableUtils).computeJoin().get val joinTable = s"$namespace.join_test_expected_${joinConf.metaData.cleanName}" joinedDf.save(joinTable) logger.info("=== Expected join table computed: === " + joinTable) diff --git a/spark/src/test/scala/ai/chronon/spark/test/FetchStatsTest.scala b/spark/src/test/scala/ai/chronon/spark/test/FetchStatsTest.scala index e0602d06d9..4f6f253c03 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/FetchStatsTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/FetchStatsTest.scala @@ -122,7 +122,7 @@ class FetchStatsTest extends TestCase { // Compute daily join. val joinJob = new Join(joinConf, today, tableUtils) - joinJob.computeJoin() + joinJob.computeJoin().get // Load some data. implicit val executionContext: ExecutionContext = ExecutionContext.fromExecutor(Executors.newFixedThreadPool(1)) val kvStoreFunc = () => OnlineUtils.buildInMemoryKVStore("FetchStatsTest") diff --git a/spark/src/test/scala/ai/chronon/spark/test/FetcherTest.scala b/spark/src/test/scala/ai/chronon/spark/test/FetcherTest.scala index 624d3ec9fe..5c06294d85 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/FetcherTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/FetcherTest.scala @@ -529,7 +529,7 @@ class FetcherTest extends TestCase { val inMemoryKvStore = kvStoreFunc() val mockApi = new MockApi(kvStoreFunc, namespace) - val joinedDf = new ai.chronon.spark.Join(joinConf, endDs, tableUtils).computeJoin() + val joinedDf = new ai.chronon.spark.Join(joinConf, endDs, tableUtils).computeJoin().get val joinTable = s"$namespace.join_test_expected_${joinConf.metaData.cleanName}" joinedDf.save(joinTable) val endDsExpected = tableUtils.sql(s"SELECT * FROM $joinTable WHERE ds='$endDs'") diff --git a/spark/src/test/scala/ai/chronon/spark/test/JoinTest.scala b/spark/src/test/scala/ai/chronon/spark/test/JoinTest.scala index 5ed03f9d25..73428c07ac 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/JoinTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/JoinTest.scala @@ -139,7 +139,7 @@ class JoinTest { ) val runner1 = new Join(joinConf, tableUtils.partitionSpec.minus(today, new Window(40, TimeUnit.DAYS)), tableUtils) - runner1.computeJoin() + runner1.computeJoin().get val dropStart = tableUtils.partitionSpec.minus(today, new Window(55, TimeUnit.DAYS)) val dropEnd = tableUtils.partitionSpec.minus(today, new Window(45, TimeUnit.DAYS)) tableUtils.dropPartitionRange( @@ -160,7 +160,7 @@ class JoinTest { resetUDFs() val runner2 = new Join(joinConf, end, tableUtils) - val computed = runner2.computeJoin(Some(3)) + val computed = runner2.computeJoin(Some(3)).get println(s"join start = $start") val expectedQuery = s""" @@ -239,7 +239,7 @@ class JoinTest { val runner3 = new Join(joinConf, end, tableUtils) val expected2 = spark.sql(expectedQuery) - val computed2 = runner3.computeJoin(Some(3)) + val computed2 = runner3.computeJoin(Some(3)).get val diff2 = Comparison.sideBySide(computed2, expected2, List("user_name", "user", "ts", "ds")) if (diff2.count() > 0) { @@ -312,7 +312,7 @@ class JoinTest { ) val runner = new Join(joinConf, end, tableUtils) - val computed = runner.computeJoin(Some(7)) + val computed = runner.computeJoin(Some(7)).get val expected = tableUtils.sql(s""" |WITH | countries AS (SELECT country, ds from $countryTable where ds >= '$start' and ds <= '$end'), @@ -365,7 +365,7 @@ class JoinTest { val stream = new java.io.ByteArrayOutputStream() Console.withOut(stream) { // rerun the same join job - runner.computeJoin(Some(7)) + runner.computeJoin(Some(7)).get } val stdOutMsg = stream.toString() println(s"std out message =\n $stdOutMsg") @@ -414,7 +414,7 @@ class JoinTest { ) val runner = new Join(joinConf, end, tableUtils) - val computed = runner.computeJoin(Some(7)) + val computed = runner.computeJoin(Some(7)).get println("showing join result") computed.show() @@ -471,7 +471,7 @@ class JoinTest { (new Analyzer(tableUtils, joinConf, monthAgo, today)).run() val join = new Join(joinConf = joinConf, endPartition = monthAgo, tableUtils) - val computed = join.computeJoin() + val computed = join.computeJoin().get computed.show() val expected = tableUtils.sql(s""" @@ -539,7 +539,7 @@ class JoinTest { val start = tableUtils.partitionSpec.minus(today, new Window(100, TimeUnit.DAYS)) (new Analyzer(tableUtils, joinConf, monthAgo, today)).run() val join = new Join(joinConf = joinConf, endPartition = dayAndMonthBefore, tableUtils) - val computed = join.computeJoin(Some(100)) + val computed = join.computeJoin(Some(100)).get computed.show() val expected = tableUtils.sql(s""" @@ -601,7 +601,7 @@ class JoinTest { spark.sql(q).show() val start = tableUtils.partitionSpec.minus(today, new Window(100, TimeUnit.DAYS)) val join = new Join(joinConf = joinConf, endPartition = dayAndMonthBefore, tableUtils) - val computed = join.computeJoin(Some(100)) + val computed = join.computeJoin(Some(100)).get computed.show() val expected = tableUtils.sql(s""" @@ -751,7 +751,7 @@ class JoinTest { ) val runner = new Join(joinConf, end, tableUtils) - val computed = runner.computeJoin(Some(7)) + val computed = runner.computeJoin(Some(7)).get println(s"join start = $start") val expected = tableUtils.sql(s""" |WITH @@ -793,7 +793,7 @@ class JoinTest { // Run the old join to ensure that tables exist val oldJoin = new Join(joinConf = joinConf, endPartition = dayAndMonthBefore, tableUtils) - oldJoin.computeJoin(Some(100)) + oldJoin.computeJoin(Some(100)).get // Make sure that there is no versioning-detected changes at this phase val joinPartsToRecomputeNoChange = JoinUtils.tablesToRecompute(joinConf, joinConf.metaData.outputTable, tableUtils) @@ -822,7 +822,7 @@ class JoinTest { assertEquals(addPartRecompute.size, 1) assertEquals(addPartRecompute, Seq(addPartJoinConf.metaData.outputTable)) // Compute to ensure that it works and to set the stage for the next assertion - addPartJoin.computeJoin(Some(100)) + addPartJoin.computeJoin(Some(100)).get // Test modifying only one of two joinParts val rightModJoinConf = addPartJoinConf.deepCopy() @@ -837,7 +837,7 @@ class JoinTest { rightModJoinConf.getJoinParts.get(0).setPrefix("user_4") val rightModBothJoin = new Join(joinConf = rightModJoinConf, endPartition = dayAndMonthBefore, tableUtils) // Compute to ensure that it works - val computed = rightModBothJoin.computeJoin(Some(100)) + val computed = rightModBothJoin.computeJoin(Some(100)).get // Now assert that the actual output is correct after all these runs computed.show() @@ -955,7 +955,7 @@ class JoinTest { ) assertTrue(end < today) val toCompute = new Join(limitedJoin, today, tableUtils) - toCompute.computeJoin() + toCompute.computeJoin().get val ds = tableUtils.sql(s"SELECT MAX(ds) FROM ${limitedJoin.metaData.outputTable}") assertTrue(ds.first().getString(0) < today) } @@ -1004,7 +1004,7 @@ class JoinTest { joinParts = Seq(Builders.JoinPart(groupBy = viewsGroupBy, prefix = "user")), metaData = Builders.MetaData(name = "test.item_snapshot_bloom_test", namespace = namespace, team = "chronon") ) - val skipBloomComputed = new Join(joinConf, today, testTableUtils).computeJoin() + val skipBloomComputed = new Join(joinConf, today, testTableUtils).computeJoin().get val leftSideCount = testSpark.sql(s"SELECT item, ts, ds from $itemQueriesTable where ds >= '$start'").count() println("computed count: " + skipBloomComputed.count()) assertEquals(leftSideCount, skipBloomComputed.count()) @@ -1064,7 +1064,7 @@ class JoinTest { Builders.MetaData(name = s"test.item_temporal_features$nameSuffix", namespace = namespace, team = "item_team") ) val toCompute = new Join(join, today, tableUtils) - toCompute.computeJoin() + toCompute.computeJoin().get // Add stats new SummaryJob(spark, join, today).dailyRun(stepDays = Some(30)) } @@ -1167,7 +1167,116 @@ class JoinTest { ) val runner = new Join(joinConf, end, tableUtils) - val computed = runner.computeJoin(Some(7)) + val computed = runner.computeJoin(Some(7)).get assertFalse(computed.isEmpty) } + + /** + * Create a event table as left side, 3 group bys as right side. + * Generate data using DataFrameGen and save to the tables. + * Create a join with only one join part selected. + * Run computeJoin(). + * Check if the selected join part is computed and the other join parts are not computed. + */ + @Test + def testSelectedJoinParts(): Unit = { + // Left + val itemQueries = List( + Column("item", api.StringType, 100), + Column("value", api.LongType, 100) + ) + val itemQueriesTable = s"$namespace.item_queries_selected_join_parts" + spark.sql(s"DROP TABLE IF EXISTS $itemQueriesTable") + spark.sql(s"DROP TABLE IF EXISTS ${itemQueriesTable}_tmp") + DataFrameGen.events(spark, itemQueries, 10000, partitions = 30).save(s"${itemQueriesTable}_tmp") + val leftDf = tableUtils.sql(s"SELECT item, value, ts, ds FROM ${itemQueriesTable}_tmp") + leftDf.save(itemQueriesTable) + val start = monthAgo + + // Right + val viewsSchema = List( + Column("user", api.StringType, 10000), + Column("item", api.StringType, 100), + Column("value", api.LongType, 100) + ) + val viewsTable = s"$namespace.view_selected_join_parts" + spark.sql(s"DROP TABLE IF EXISTS $viewsTable") + DataFrameGen.events(spark, viewsSchema, count = 10000, partitions = 30).save(viewsTable) + + // Group By + val gb1 = Builders.GroupBy( + sources = Seq( + Builders.Source.events( + table = viewsTable, + query = Builders.Query(startPartition = start) + )), + keyColumns = Seq("item"), + aggregations = Seq( + Builders.Aggregation(operation = Operation.LAST_K, argMap = Map("k" -> "10"), inputColumn = "user"), + Builders.Aggregation(operation = Operation.MAX, argMap = Map("k" -> "2"), inputColumn = "value") + ), + metaData = + Builders.MetaData(name = s"unit_test.item_views_selected_join_parts_1", namespace = namespace, team = "item_team"), + accuracy = Accuracy.SNAPSHOT + ) + + val gb2 = Builders.GroupBy( + sources = Seq( + Builders.Source.events( + table = viewsTable, + query = Builders.Query(startPartition = start) + )), + keyColumns = Seq("item"), + aggregations = Seq( + Builders.Aggregation(operation = Operation.MIN, argMap = Map("k" -> "1"), inputColumn = "value") + ), + metaData = + Builders.MetaData(name = s"unit_test.item_views_selected_join_parts_2", namespace = namespace, team = "item_team"), + accuracy = Accuracy.SNAPSHOT + ) + + val gb3 = Builders.GroupBy( + sources = Seq( + Builders.Source.events( + table = viewsTable, + query = Builders.Query(startPartition = start) + )), + keyColumns = Seq("item"), + aggregations = Seq( + Builders.Aggregation(operation = Operation.AVERAGE, inputColumn = "value") + ), + metaData = + Builders.MetaData(name = s"unit_test.item_views_selected_join_parts_3", namespace = namespace, team = "item_team"), + accuracy = Accuracy.SNAPSHOT + ) + + // Join + val joinConf = Builders.Join( + left = Builders.Source.events(Builders.Query(startPartition = start), table = itemQueriesTable), + joinParts = Seq( + Builders.JoinPart(groupBy = gb1, prefix = "user1"), + Builders.JoinPart(groupBy = gb2, prefix = "user2"), + Builders.JoinPart(groupBy = gb3, prefix = "user3"), + ), + metaData = Builders.MetaData(name = s"unit_test.item_temporal_features.selected_join_parts", + namespace = namespace, + team = "item_team", + online = true) + ) + + // Compute daily join. + val joinJob = new Join(joinConf, today, tableUtils, selectedJoinParts = Some(List("user1"))) + val computedDf = joinJob.computeJoin() + assertTrue(computedDf.isEmpty) + +// val partTable1 = s"${joinConf.metaData.outputTable}_user1_unit_test_item_views_selected_join_parts_1" +// val partTable2 = s"${joinConf.metaData.outputTable}_user2_unit_test_item_views_selected_join_parts_2" +// val partTable3 = s"${joinConf.metaData.outputTable}_user3_unit_test_item_views_selected_join_parts_3" +// val part1 = tableUtils.sql(s"SELECT * FROM $partTable1") +// val part2 = tableUtils.sql(s"SELECT * FROM $partTable2") +// val part3 = tableUtils.sql(s"SELECT * FROM $partTable3") +// assertTrue(part1.count() > 0) +// assertTrue(part2.count() == 0) +// assertTrue(part3.count() == 0) + } } diff --git a/spark/src/test/scala/ai/chronon/spark/test/MigrationCompareTest.scala b/spark/src/test/scala/ai/chronon/spark/test/MigrationCompareTest.scala index e87622e5b6..fc85537d0a 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/MigrationCompareTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/MigrationCompareTest.scala @@ -77,7 +77,7 @@ class MigrationCompareTest { ) val join = new Join(joinConf = joinConf, endPartition = today, tableUtils) - join.computeJoin() + join.computeJoin().get //--------------------------------Staging Query----------------------------- val stagingQueryConf = Builders.StagingQuery( diff --git a/spark/src/test/scala/ai/chronon/spark/test/MutationsTest.scala b/spark/src/test/scala/ai/chronon/spark/test/MutationsTest.scala index 59b1eb8783..4d079cb2f0 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/MutationsTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/MutationsTest.scala @@ -207,7 +207,7 @@ class MutationsTest { ) val runner = new Join(joinConf, endPartition, tableUtils) - runner.computeJoin() + runner.computeJoin().get } /** diff --git a/spark/src/test/scala/ai/chronon/spark/test/bootstrap/DerivationTest.scala b/spark/src/test/scala/ai/chronon/spark/test/bootstrap/DerivationTest.scala index 75cbf2a8bd..79a97d3d18 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/bootstrap/DerivationTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/bootstrap/DerivationTest.scala @@ -139,7 +139,7 @@ class DerivationTest { ) val runner = new ai.chronon.spark.Join(baseJoin, today, tableUtils) - val outputDf = runner.computeJoin() + val outputDf = runner.computeJoin().get assertTrue( outputDf.columns.toSet == Set( @@ -243,7 +243,7 @@ class DerivationTest { ) val runner2 = new ai.chronon.spark.Join(bootstrapJoin, today, tableUtils) - val computed = runner2.computeJoin() + val computed = runner2.computeJoin().get // Comparison val expected = outputDf @@ -346,7 +346,7 @@ class DerivationTest { ) val runner = new ai.chronon.spark.Join(joinConf, today, tableUtils) - val outputDf = runner.computeJoin() + val outputDf = runner.computeJoin().get // assert that no computation happened for join part since all derivations have been bootstrapped assertFalse(tableUtils.tableExists(joinConf.partOutputTable(joinPart))) @@ -465,7 +465,7 @@ class DerivationTest { assertTrue(baseColumns.forall(logDf.columns.contains)) val bootstrapJoinJob = new ai.chronon.spark.Join(bootstrapJoin, endDs, tableUtils) - val computedDf = bootstrapJoinJob.computeJoin() + val computedDf = bootstrapJoinJob.computeJoin().get if (!wildcardSelection) { assertTrue(baseColumns.forall(c => !computedDf.columns.contains(c))) } @@ -474,7 +474,7 @@ class DerivationTest { assertFalse(tableUtils.tableExists(bootstrapJoin.partOutputTable(joinPart))) val baseJoinJob = new ai.chronon.spark.Join(baseJoin, endDs, tableUtils) - val baseDf = baseJoinJob.computeJoin() + val baseDf = baseJoinJob.computeJoin().get val expectedDf = JoinUtils .coalescedJoin( @@ -545,7 +545,7 @@ class DerivationTest { } def getSchema(joinConf: ai.chronon.api.Join): Seq[String] = { val runner = new ai.chronon.spark.Join(joinConf, today, tableUtils) - val outputDf = runner.computeJoin() + val outputDf = runner.computeJoin().get outputDf.columns } diff --git a/spark/src/test/scala/ai/chronon/spark/test/bootstrap/LogBootstrapTest.scala b/spark/src/test/scala/ai/chronon/spark/test/bootstrap/LogBootstrapTest.scala index 00fce2256e..e3093f7d54 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/bootstrap/LogBootstrapTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/bootstrap/LogBootstrapTest.scala @@ -147,7 +147,7 @@ class LogBootstrapTest { assertEquals(logDf.count(), responses.length) val baseJoinJob = new ai.chronon.spark.Join(baseJoinV2, endDs, tableUtils) - val baseOutput = baseJoinJob.computeJoin() + val baseOutput = baseJoinJob.computeJoin().get val expected = baseOutput .join(logDf, baseOutput("request_id") <=> logDf("request_id") and baseOutput("ds") <=> logDf("ds"), "left") @@ -169,7 +169,7 @@ class LogBootstrapTest { ) val joinJob = new ai.chronon.spark.Join(joinV2, endDs, tableUtils) - val computed = joinJob.computeJoin() + val computed = joinJob.computeJoin().get val overlapCount = baseOutput.join(logDf, Seq("request_id", "ds")).count() logger.info(s"""Debugging information: diff --git a/spark/src/test/scala/ai/chronon/spark/test/bootstrap/TableBootstrapTest.scala b/spark/src/test/scala/ai/chronon/spark/test/bootstrap/TableBootstrapTest.scala index 1bbf6a8872..7aee35f63d 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/bootstrap/TableBootstrapTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/bootstrap/TableBootstrapTest.scala @@ -98,7 +98,7 @@ class TableBootstrapTest { // Runs through standard backfill val runner1 = new ai.chronon.spark.Join(baseJoin, today, tableUtils) - val baseOutput = runner1.computeJoin() + val baseOutput = runner1.computeJoin().get // Create two bootstrap parts to verify that bootstrap coalesce respects the ordering of the input bootstrap parts val (bootstrapTable1, bootstrapTable2) = ("user_transactions_bootstrap1", "user_transactions_bootstrap2") @@ -112,7 +112,7 @@ class TableBootstrapTest { // Runs through boostrap backfill which combines backfill and bootstrap val runner2 = new ai.chronon.spark.Join(bootstrapJoin, today, tableUtils) - val computed = runner2.computeJoin() + val computed = runner2.computeJoin().get // Comparison val expected = baseOutput @@ -202,7 +202,7 @@ class TableBootstrapTest { ) val joinJob = new ai.chronon.spark.Join(join, endDs, tableUtils) - joinJob.computeJoin() + joinJob.computeJoin().get // assert that no computation happened for join part since all derivations have been bootstrapped assertFalse(tableUtils.tableExists(join.partOutputTable(joinPart)))