From 37dbd118d3043ce1ebe67d19ec121396275bdbf8 Mon Sep 17 00:00:00 2001 From: Michelle Ark Date: Thu, 28 Nov 2024 17:06:51 -0500 Subject: [PATCH] refactor: internalize parallel to RunTask._submit_batch --- core/dbt/task/run.py | 39 +++++++------------ .../functional/microbatch/test_microbatch.py | 2 +- tests/unit/task/test_run.py | 8 ++-- 3 files changed, 19 insertions(+), 30 deletions(-) diff --git a/core/dbt/task/run.py b/core/dbt/task/run.py index 28b0b9734c7..f5bcc578541 100644 --- a/core/dbt/task/run.py +++ b/core/dbt/task/run.py @@ -602,15 +602,15 @@ def _has_relation(self, model) -> bool: ) return relation is not None - def _should_run_in_parallel( - self, - relation_exists: bool, - ) -> bool: + def should_run_in_parallel(self) -> bool: if not self.adapter.supports(Capability.MicrobatchConcurrency): run_in_parallel = False - elif not relation_exists: + elif not self.relation_exists: # If the relation doesn't exist, we can't run in parallel run_in_parallel = False + elif self.batch_idx == 0 or self.batch_idx == len(self.batches) - 1: + # First and last batch don't run in parallel + run_in_parallel = False elif self.node.config.concurrent_batches is not None: # If the relation exists and the `concurrent_batches` config isn't None, use the config value run_in_parallel = self.node.config.concurrent_batches @@ -705,9 +705,7 @@ def handle_microbatch_model( ) -> RunResult: # Initial run computes batch metadata result = self.call_runner(runner) - batches = runner.batches - node = runner.node - relation_exists = runner.relation_exists + batches, node, relation_exists = runner.batches, runner.node, runner.relation_exists # Return early if model should be skipped, or there are no batches to execute if result.status == RunStatus.Skipped: @@ -717,30 +715,20 @@ def handle_microbatch_model( batch_results: List[RunResult] = [] batch_idx = 0 - - # Run first batch runs in serial - relation_exists = self._submit_batch( - node, relation_exists, batches, batch_idx, batch_results, pool, parallel=False - ) - batch_idx += 1 - - # Subsequent batches can be run in parallel + # Run all batches except last batch, in parallel if possible while batch_idx < len(runner.batches) - 1: - parallel = runner._should_run_in_parallel(relation_exists) relation_exists = self._submit_batch( - node, relation_exists, batches, batch_idx, batch_results, pool, parallel + node, relation_exists, batches, batch_idx, batch_results, pool ) batch_idx += 1 # Wait until all submitted batches have completed while len(batch_results) != batch_idx: pass + # Final batch runs once all others complete to ensure post_hook runs at the end + self._submit_batch(node, relation_exists, batches, batch_idx, batch_results, pool) - # Final batch runs in serial - self._submit_batch( - node, relation_exists, batches, batch_idx, batch_results, pool, parallel=False - ) - + # Finalize run: merge results, track model run, and print final result line runner.merge_batch_results(result, batch_results) track_model_run(runner.node_index, runner.num_nodes, result, adapter=runner.adapter) runner.print_result_line(result) @@ -755,7 +743,6 @@ def _submit_batch( batch_idx: int, batch_results: List[RunResult], pool: ThreadPool, - parallel: bool, ): node_copy = deepcopy(node) # Only run pre_hook(s) for first batch @@ -764,14 +751,14 @@ def _submit_batch( # Only run post_hook(s) for last batch elif batch_idx != len(batches) - 1: node_copy.config.post_hook = [] - + batch_runner = self.get_runner(node_copy) assert isinstance(batch_runner, MicrobatchModelRunner) batch_runner.set_batch_idx(batch_idx) batch_runner.set_relation_exists(relation_exists) batch_runner.set_batches(batches) - if parallel: + if batch_runner.should_run_in_parallel(): fire_event( MicrobatchExecutionDebug( msg=f"{batch_runner.describe_batch} is being run concurrently" diff --git a/tests/functional/microbatch/test_microbatch.py b/tests/functional/microbatch/test_microbatch.py index e3acc415273..56537ad48cf 100644 --- a/tests/functional/microbatch/test_microbatch.py +++ b/tests/functional/microbatch/test_microbatch.py @@ -875,7 +875,7 @@ def batch_exc_catcher(self) -> EventCatcher: def test_microbatch( self, mocker: MockerFixture, project, batch_exc_catcher: EventCatcher ) -> None: - mocked_srip = mocker.patch("dbt.task.run.MicrobatchModelRunner._should_run_in_parallel") + mocked_srip = mocker.patch("dbt.task.run.MicrobatchModelRunner.should_run_in_parallel") # Should be run in parallel mocked_srip.return_value = True diff --git a/tests/unit/task/test_run.py b/tests/unit/task/test_run.py index b33e6f57ffe..b28ac505a7f 100644 --- a/tests/unit/task/test_run.py +++ b/tests/unit/task/test_run.py @@ -264,7 +264,7 @@ class Relation: (False, False, False, True, False), ], ) - def test__should_run_in_parallel( + def test_should_run_in_parallel( self, mocker: MockerFixture, model_runner: MicrobatchModelRunner, @@ -276,11 +276,13 @@ def test__should_run_in_parallel( ) -> None: model_runner.node._has_this = has_this model_runner.node.config = ModelConfig(concurrent_batches=concurrent_batches) + model_runner.set_relation_exists(has_relation) + mocked_supports = mocker.patch.object(model_runner.adapter, "supports") mocked_supports.return_value = adapter_microbatch_concurrency - # Assert result of _should_run_in_parallel - assert model_runner._should_run_in_parallel(has_relation) == expectation + # Assert result of should_run_in_parallel + assert model_runner.should_run_in_parallel() == expectation class TestRunTask: