Skip to content

Commit

Permalink
refactor: internalize parallel to RunTask._submit_batch
Browse files: browse the repository at this point in the history
  • Loading branch information
MichelleArk committed Nov 28, 2024
1 parent 32002ea commit 37dbd11
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 30 deletions.
39 changes: 13 additions & 26 deletions core/dbt/task/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -602,15 +602,15 @@ def _has_relation(self, model) -> bool:
)
return relation is not None

def _should_run_in_parallel(
self,
relation_exists: bool,
) -> bool:
def should_run_in_parallel(self) -> bool:
if not self.adapter.supports(Capability.MicrobatchConcurrency):
run_in_parallel = False
elif not relation_exists:
elif not self.relation_exists:
# If the relation doesn't exist, we can't run in parallel
run_in_parallel = False
elif self.batch_idx == 0 or self.batch_idx == len(self.batches) - 1:
# First and last batch don't run in parallel
run_in_parallel = False

Check warning on line 613 in core/dbt/task/run.py

View check run for this annotation

Codecov / codecov/patch

core/dbt/task/run.py#L613

Added line #L613 was not covered by tests
elif self.node.config.concurrent_batches is not None:
# If the relation exists and the `concurrent_batches` config isn't None, use the config value
run_in_parallel = self.node.config.concurrent_batches
Expand Down Expand Up @@ -705,9 +705,7 @@ def handle_microbatch_model(
) -> RunResult:
# Initial run computes batch metadata
result = self.call_runner(runner)
batches = runner.batches
node = runner.node
relation_exists = runner.relation_exists
batches, node, relation_exists = runner.batches, runner.node, runner.relation_exists

# Return early if model should be skipped, or there are no batches to execute
if result.status == RunStatus.Skipped:
Expand All @@ -717,30 +715,20 @@ def handle_microbatch_model(

batch_results: List[RunResult] = []
batch_idx = 0

# First batch runs in serial
relation_exists = self._submit_batch(
node, relation_exists, batches, batch_idx, batch_results, pool, parallel=False
)
batch_idx += 1

# Subsequent batches can be run in parallel
# Run all batches except last batch, in parallel if possible
while batch_idx < len(runner.batches) - 1:
parallel = runner._should_run_in_parallel(relation_exists)
relation_exists = self._submit_batch(
node, relation_exists, batches, batch_idx, batch_results, pool, parallel
node, relation_exists, batches, batch_idx, batch_results, pool
)
batch_idx += 1

# Wait until all submitted batches have completed
while len(batch_results) != batch_idx:
pass
# Final batch runs once all others complete to ensure post_hook runs at the end
self._submit_batch(node, relation_exists, batches, batch_idx, batch_results, pool)

# Final batch runs in serial
self._submit_batch(
node, relation_exists, batches, batch_idx, batch_results, pool, parallel=False
)

# Finalize run: merge results, track model run, and print final result line
runner.merge_batch_results(result, batch_results)
track_model_run(runner.node_index, runner.num_nodes, result, adapter=runner.adapter)
runner.print_result_line(result)
Expand All @@ -755,7 +743,6 @@ def _submit_batch(
batch_idx: int,
batch_results: List[RunResult],
pool: ThreadPool,
parallel: bool,
):
node_copy = deepcopy(node)
# Only run pre_hook(s) for first batch
Expand All @@ -764,14 +751,14 @@ def _submit_batch(
# Only run post_hook(s) for last batch
elif batch_idx != len(batches) - 1:
node_copy.config.post_hook = []

batch_runner = self.get_runner(node_copy)
assert isinstance(batch_runner, MicrobatchModelRunner)
batch_runner.set_batch_idx(batch_idx)
batch_runner.set_relation_exists(relation_exists)
batch_runner.set_batches(batches)

if parallel:
if batch_runner.should_run_in_parallel():
fire_event(
MicrobatchExecutionDebug(
msg=f"{batch_runner.describe_batch} is being run concurrently"
Expand Down
2 changes: 1 addition & 1 deletion tests/functional/microbatch/test_microbatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -875,7 +875,7 @@ def batch_exc_catcher(self) -> EventCatcher:
def test_microbatch(
self, mocker: MockerFixture, project, batch_exc_catcher: EventCatcher
) -> None:
mocked_srip = mocker.patch("dbt.task.run.MicrobatchModelRunner._should_run_in_parallel")
mocked_srip = mocker.patch("dbt.task.run.MicrobatchModelRunner.should_run_in_parallel")

# Should be run in parallel
mocked_srip.return_value = True
Expand Down
8 changes: 5 additions & 3 deletions tests/unit/task/test_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ class Relation:
(False, False, False, True, False),
],
)
def test__should_run_in_parallel(
def test_should_run_in_parallel(
self,
mocker: MockerFixture,
model_runner: MicrobatchModelRunner,
Expand All @@ -276,11 +276,13 @@ def test__should_run_in_parallel(
) -> None:
model_runner.node._has_this = has_this
model_runner.node.config = ModelConfig(concurrent_batches=concurrent_batches)
model_runner.set_relation_exists(has_relation)

mocked_supports = mocker.patch.object(model_runner.adapter, "supports")
mocked_supports.return_value = adapter_microbatch_concurrency

# Assert result of _should_run_in_parallel
assert model_runner._should_run_in_parallel(has_relation) == expectation
# Assert result of should_run_in_parallel
assert model_runner.should_run_in_parallel() == expectation


class TestRunTask:
Expand Down

0 comments on commit 37dbd11

Please sign in to comment.