Skip to content

Commit

Permalink
Enable retry support for Microbatch models (#10751)
Browse files Browse the repository at this point in the history
* Add `PartialSuccess` status type and use it for microbatch models with mixed results

* Handle `PartialSuccess` in `interpret_run_result`

* Add `BatchResults` object to `BaseResult` and begin tracking during microbatch runs

* Ensure batch_results being propagated to `run_results` artifact

* Move `batch_results` from `BaseResult` class to `RunResult` class

* Move `BatchResults` and `BatchType` to separate arifacts file to avoid circular imports

In our next commit we're gonna modify `dbt/contracts/graph/nodes.py` to import the
`BatchType` as part of our work to implement dbt retry for microbatch model nodes.
Unfortunately, the import in `nodes.py` creates a circular dependency because
`dbt/artifacts/schemas/results.py` imports from `nodes.py` and `dbt/artifacts/schemas/run/v5/run.py`
imports from that `results.py`. Thus the new import creates a circular import. Now this
_shouldn't_ be necessary as nothing in artifacts should import from the rest of dbt-core.
However, we do. We should fix this, but this is also out of scope for this segement of work.

* Add `PartialSuccess` as a retry-able status, and use batches to retry microbatch models

* Fix BatchType type so that the first datetime is no longer Optional

* Ensure `PartialSuccess` causes skipping of downstream nodes

* Alter `PartialSuccess` status to be considered an error in `interpret_run_result`

* Update schemas and test artifacts to include new batch_results run results key

* Add functional test to check that 'dbt retry' retries 'PartialSuccess' models

* Update partition failure test to assert downstream models are skipped

* Improve `success`/`error`/`partial success` messaging for microbatch models

* Include `PartialSuccess` in status that `--fail-fast` counts as a failure

* Update `LogModelResult` to handle partial successes

* Update `EndOfRunSummary` to handle partial successes

* Cleanup TODO comment

* Raise a DbtInternalError if we get a batch run result without `batch_results`

* When running a microbatch model with supplied batches, force non full-refresh behavior

This is necessary because of retry. Say on the initial run the microbatch model
succeeds on 97% of it's batches. Then on retry it does the last 3%. If the retry
of the microbatch model executes in full refresh mode it _might_ blow away the
97% of work that has been done. This edge case seems to be adapter specific.

* Only pass batches to retry for microbatch model when there was a PartialSuccess

In the previous commit we made it so that retries of microbatch models wouldn't
run in full refresh mode when the microbatch model to retry has batches already
specified from the prior run. This is only problematic when the run being retried
was a full refresh AND all the batches for a given microbatch model failed. In
that case WE DO want to do a full refresh for the given microbatch model. To better
outline the problem, consider the following:

* a microbatch model had a begin of `2020-01-01` and has been running this way for awhile
* the begin config has changed to `2024-01-01` and  dbt run --full-refresh gets run
* every batch for an microbatch model fails
* on dbt retry the the relation is said to exist, and the now out of range data (2020-01-01 through 2023-12-31) is never purged

To avoid this, all we have to do is ONLY pass the batch information for partially successful microbatch
models. Note: microbatch models only have a partially successful status IFF they have both
successful and failed batches.

* Fix test_manifest unit tests to know about model 'batches' key

* Add some console output assertions to microbatch functional tests

* add batch_results: None to expected_run_results

* Add changie doc for microbatch retry functionality

* maintain protoc version 5.26.1

* Cleanup extraneous comment in LogModelResult

---------

Co-authored-by: Michelle Ark <[email protected]>
  • Loading branch information
QMalcolm and MichelleArk authored Sep 26, 2024
1 parent ac66f91 commit 1fd4d2e
Show file tree
Hide file tree
Showing 29 changed files with 479 additions and 102 deletions.
6 changes: 6 additions & 0 deletions .changes/unreleased/Features-20240925-165002.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Features
body: Enable `retry` support for microbatch models
time: 2024-09-25T16:50:02.105069-05:00
custom:
Author: QMalcolm MichelleArk
Issue: "10624"
21 changes: 21 additions & 0 deletions core/dbt/artifacts/schemas/batch_results.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from __future__ import annotations

from dataclasses import dataclass, field
from datetime import datetime
from typing import List, Tuple

from dbt_common.dataclass_schema import dbtClassMixin

BatchType = Tuple[datetime, datetime]


@dataclass
class BatchResults(dbtClassMixin):
successful: List[BatchType] = field(default_factory=list)
failed: List[BatchType] = field(default_factory=list)

def __add__(self, other: BatchResults) -> BatchResults:
return BatchResults(
successful=self.successful + other.successful,
failed=self.failed + other.failed,
)
2 changes: 2 additions & 0 deletions core/dbt/artifacts/schemas/results.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ class NodeStatus(StrEnum):
Fail = "fail"
Warn = "warn"
Skipped = "skipped"
PartialSuccess = "partial success"
Pass = "pass"
RuntimeErr = "runtime error"

Expand All @@ -63,6 +64,7 @@ class RunStatus(StrEnum):
Success = NodeStatus.Success
Error = NodeStatus.Error
Skipped = NodeStatus.Skipped
PartialSuccess = NodeStatus.PartialSuccess


class TestStatus(StrEnum):
Expand Down
7 changes: 7 additions & 0 deletions core/dbt/artifacts/schemas/run/v5/run.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import copy
import threading
from dataclasses import dataclass, field
Expand All @@ -17,6 +19,7 @@
get_artifact_schema_version,
schema_version,
)
from dbt.artifacts.schemas.batch_results import BatchResults
from dbt.artifacts.schemas.results import (
BaseResult,
ExecutionResult,
Expand All @@ -34,6 +37,7 @@ class RunResult(NodeResult):
agate_table: Optional["agate.Table"] = field(
default=None, metadata={"serialize": lambda x: None, "deserialize": lambda x: None}
)
batch_results: Optional[BatchResults] = None

@property
def skipped(self):
Expand All @@ -51,6 +55,7 @@ def from_node(cls, node: ResultNode, status: RunStatus, message: Optional[str]):
node=node,
adapter_response={},
failures=None,
batch_results=None,
)


Expand All @@ -67,6 +72,7 @@ class RunResultOutput(BaseResult):
compiled: Optional[bool]
compiled_code: Optional[str]
relation_name: Optional[str]
batch_results: Optional[BatchResults] = None


def process_run_result(result: RunResult) -> RunResultOutput:
Expand All @@ -82,6 +88,7 @@ def process_run_result(result: RunResult) -> RunResultOutput:
message=result.message,
adapter_response=result.adapter_response,
failures=result.failures,
batch_results=result.batch_results,
compiled=result.node.compiled if compiled else None, # type:ignore
compiled_code=result.node.compiled_code if compiled else None, # type:ignore
relation_name=result.node.relation_name if compiled else None, # type:ignore
Expand Down
3 changes: 3 additions & 0 deletions core/dbt/contracts/graph/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
from dbt.artifacts.resources import SqlOperation as SqlOperationResource
from dbt.artifacts.resources import TimeSpine
from dbt.artifacts.resources import UnitTestDefinition as UnitTestDefinitionResource
from dbt.artifacts.schemas.batch_results import BatchType
from dbt.contracts.graph.model_config import UnitTestNodeConfig
from dbt.contracts.graph.node_args import ModelNodeArgs
from dbt.contracts.graph.unparsed import (
Expand Down Expand Up @@ -442,6 +443,8 @@ def resource_class(cls) -> Type[HookNodeResource]:

@dataclass
class ModelNode(ModelResource, CompiledNode):
batches: Optional[List[BatchType]] = None

@classmethod
def resource_class(cls) -> Type[ModelResource]:
return ModelResource
Expand Down
1 change: 1 addition & 0 deletions core/dbt/events/core_types.proto
Original file line number Diff line number Diff line change
Expand Up @@ -1916,6 +1916,7 @@ message EndOfRunSummary {
int32 num_errors = 1;
int32 num_warnings = 2;
bool keyboard_interrupt = 3;
int32 num_partial_success = 4;
}

message EndOfRunSummaryMsg {
Expand Down
136 changes: 68 additions & 68 deletions core/dbt/events/core_types_pb2.py

Large diffs are not rendered by default.

11 changes: 10 additions & 1 deletion core/dbt/events/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -1296,6 +1296,9 @@ def message(self) -> str:
if self.status == "error":
info = "ERROR creating"
status = red(self.status.upper())
elif "PARTIAL SUCCESS" in self.status:
info = "PARTIALLY created"
status = yellow(self.status.upper())
else:
info = "OK created"
status = green(self.status)
Expand Down Expand Up @@ -1860,10 +1863,16 @@ def code(self) -> str:
def message(self) -> str:
error_plural = pluralize(self.num_errors, "error")
warn_plural = pluralize(self.num_warnings, "warning")
partial_success_plural = pluralize(self.num_partial_success, "partial success")

if self.keyboard_interrupt:
message = yellow("Exited because of keyboard interrupt")
elif self.num_errors > 0:
message = red(f"Completed with {error_plural} and {warn_plural}:")
message = red(
f"Completed with {error_plural}, {partial_success_plural}, and {warn_plural}:"
)
elif self.num_partial_success > 0:
message = yellow(f"Completed with {partial_success_plural} and {warn_plural}")
elif self.num_warnings > 0:
message = yellow(f"Completed with {warn_plural}:")
else:
Expand Down
7 changes: 4 additions & 3 deletions core/dbt/materializations/incremental/microbatch.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from datetime import datetime, timedelta
from typing import List, Optional, Tuple
from typing import List, Optional

import pytz

from dbt.artifacts.resources.types import BatchSize
from dbt.artifacts.schemas.batch_results import BatchType
from dbt.contracts.graph.nodes import ModelNode, NodeConfig
from dbt.exceptions import DbtInternalError, DbtRuntimeError

Expand Down Expand Up @@ -68,7 +69,7 @@ def build_start_time(self, checkpoint: Optional[datetime]):

return start

def build_batches(self, start: datetime, end: datetime) -> List[Tuple[datetime, datetime]]:
def build_batches(self, start: datetime, end: datetime) -> List[BatchType]:
"""
Given a start and end datetime, builds a list of batches where each batch is
the size of the model's batch_size.
Expand All @@ -79,7 +80,7 @@ def build_batches(self, start: datetime, end: datetime) -> List[Tuple[datetime,
curr_batch_start, batch_size, 1
)

batches: List[Tuple[datetime, datetime]] = [(curr_batch_start, curr_batch_end)]
batches: List[BatchType] = [(curr_batch_start, curr_batch_end)]
while curr_batch_end <= end:
curr_batch_start = curr_batch_end
curr_batch_end = MicrobatchBuilder.offset_timestamp(curr_batch_start, batch_size, 1)
Expand Down
3 changes: 3 additions & 0 deletions core/dbt/task/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,7 @@ def _build_run_result(
agate_table=None,
adapter_response=None,
failures=None,
batch_results=None,
):
execution_time = time.time() - start_time
thread_id = threading.current_thread().name
Expand All @@ -242,6 +243,7 @@ def _build_run_result(
agate_table=agate_table,
adapter_response=adapter_response,
failures=failures,
batch_results=batch_results,
)

def error_result(self, node, message, start_time, timing_info):
Expand Down Expand Up @@ -272,6 +274,7 @@ def from_run_result(self, result, start_time, timing_info):
agate_table=result.agate_table,
adapter_response=result.adapter_response,
failures=result.failures,
batch_results=result.batch_results,
)

def compile_and_execute(self, manifest: Manifest, ctx: ExecutionContext):
Expand Down
8 changes: 7 additions & 1 deletion core/dbt/task/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def execute(self, compiled_node, manifest):
message="NO-OP",
adapter_response={},
failures=0,
batch_results=None,
agate_table=None,
)

Expand All @@ -65,7 +66,12 @@ class BuildTask(RunTask):
I.E. a resource of type Model is handled by the ModelRunner which is
imported as run_model_runner."""

MARK_DEPENDENT_ERRORS_STATUSES = [NodeStatus.Error, NodeStatus.Fail, NodeStatus.Skipped]
MARK_DEPENDENT_ERRORS_STATUSES = [
NodeStatus.Error,
NodeStatus.Fail,
NodeStatus.Skipped,
NodeStatus.PartialSuccess,
]

RUNNER_MAP = {
NodeType.Model: run_model_runner,
Expand Down
1 change: 1 addition & 0 deletions core/dbt/task/clone.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def _build_run_model_result(self, model, context):
message=message,
adapter_response=adapter_response,
failures=None,
batch_results=None,
)

def compile(self, manifest: Manifest):
Expand Down
1 change: 1 addition & 0 deletions core/dbt/task/compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def execute(self, compiled_node, manifest):
message=None,
adapter_response={},
failures=None,
batch_results=None,
)

def compile(self, manifest: Manifest):
Expand Down
7 changes: 5 additions & 2 deletions core/dbt/task/printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def get_counts(flat_nodes) -> str:


def interpret_run_result(result) -> str:
if result.status in (NodeStatus.Error, NodeStatus.Fail):
if result.status in (NodeStatus.Error, NodeStatus.Fail, NodeStatus.PartialSuccess):
return "error"
elif result.status == NodeStatus.Skipped:
return "skip"
Expand Down Expand Up @@ -136,7 +136,7 @@ def print_run_result_error(
def print_run_end_messages(
results, keyboard_interrupt: bool = False, groups: Optional[Dict[str, Group]] = None
) -> None:
errors, warnings = [], []
errors, warnings, partial_successes = [], [], []
for r in results:
if r.status in (NodeStatus.RuntimeErr, NodeStatus.Error, NodeStatus.Fail):
errors.append(r)
Expand All @@ -146,12 +146,15 @@ def print_run_end_messages(
errors.append(r)
elif r.status == NodeStatus.Warn:
warnings.append(r)
elif r.status == NodeStatus.PartialSuccess:
partial_successes.append(r)

fire_event(Formatting(""))
fire_event(
EndOfRunSummary(
num_errors=len(errors),
num_warnings=len(warnings),
num_partial_success=len(partial_successes),
keyboard_interrupt=keyboard_interrupt,
)
)
Expand Down
19 changes: 18 additions & 1 deletion core/dbt/task/retry.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,13 @@
from dbt.task.test import TestTask
from dbt_common.exceptions import DbtRuntimeError

RETRYABLE_STATUSES = {NodeStatus.Error, NodeStatus.Fail, NodeStatus.Skipped, NodeStatus.RuntimeErr}
RETRYABLE_STATUSES = {
NodeStatus.Error,
NodeStatus.Fail,
NodeStatus.Skipped,
NodeStatus.RuntimeErr,
NodeStatus.PartialSuccess,
}
IGNORE_PARENT_FLAGS = {
"log_path",
"output_path",
Expand Down Expand Up @@ -123,6 +129,14 @@ def run(self):
]
)

batch_map = {
result.unique_id: result.batch_results.failed
for result in self.previous_results.results
if result.status == NodeStatus.PartialSuccess
and result.batch_results is not None
and len(result.batch_results.failed) > 0
}

class TaskWrapper(self.task_class):
def get_graph_queue(self):
new_graph = self.graph.get_subset_graph(unique_ids)
Expand All @@ -138,6 +152,9 @@ def get_graph_queue(self):
self.manifest,
)

if self.task_class == RunTask:
task.batch_map = batch_map

return_value = task.run()
return return_value

Expand Down
Loading

0 comments on commit 1fd4d2e

Please sign in to comment.