diff --git a/.changes/unreleased/Under the Hood-20241204-185912.yaml b/.changes/unreleased/Under the Hood-20241204-185912.yaml new file mode 100644 index 00000000..5c731703 --- /dev/null +++ b/.changes/unreleased/Under the Hood-20241204-185912.yaml @@ -0,0 +1,6 @@ +kind: Under the Hood +body: Add retry logic for retryable exceptions. +time: 2024-12-04T18:59:12.48816-08:00 +custom: + Author: 'colin-rogers-dbt ' + Issue: "368" diff --git a/dbt/adapters/sql/connections.py b/dbt/adapters/sql/connections.py index baccddc9..0c6797cf 100644 --- a/dbt/adapters/sql/connections.py +++ b/dbt/adapters/sql/connections.py @@ -1,13 +1,26 @@ import abc import time -from typing import Any, Dict, Iterable, Iterator, List, Optional, Tuple, TYPE_CHECKING +from typing import ( + Any, + Dict, + Iterable, + Iterator, + List, + Optional, + Tuple, + TYPE_CHECKING, + Callable, + Type, + Union, +) from dbt_common.events.contextvars import get_node_info from dbt_common.events.functions import fire_event -from dbt_common.exceptions import DbtInternalError, NotImplementedError +from dbt_common.exceptions import DbtInternalError, NotImplementedError, DbtRuntimeError from dbt_common.utils import cast_to_str from dbt.adapters.base import BaseConnectionManager +from dbt.adapters.base.connections import SleepTime from dbt.adapters.contracts.connection import ( AdapterResponse, Connection, @@ -18,6 +31,7 @@ SQLCommit, SQLQuery, SQLQueryStatus, + AdapterEventDebug, ) if TYPE_CHECKING: @@ -61,6 +75,9 @@ def add_query( auto_begin: bool = True, bindings: Optional[Any] = None, abridge_sql_log: bool = False, + retryable_exceptions: Iterable[Type[Exception]] = [], + retry_limit: int = 1, + retry_timeout: Union[Callable[[int], SleepTime], SleepTime] = 1, ) -> Tuple[Connection, Any]: connection = self.get_thread_connection() if auto_begin and connection.transaction_open is False: @@ -90,7 +107,14 @@ def add_query( pre = time.perf_counter() cursor = connection.handle.cursor() - cursor.execute(sql, bindings) + self._retryable_cursor_execute( + execute_fn=cursor.execute, + sql=sql, + bindings=bindings, + retryable_exceptions=retryable_exceptions, + retry_limit=retry_limit, + retry_timeout=retry_timeout, + ) result = self.get_response(cursor) @@ -199,3 +223,45 @@ def commit(self): connection.transaction_open = False return connection + + def _retryable_cursor_execute( + self, + execute_fn: Callable, + sql: str, + bindings: Optional[Any] = None, + retryable_exceptions: Iterable[Type[Exception]] = [], + retry_limit: int = 1, + retry_timeout: Union[Callable[[int], SleepTime], SleepTime] = 1, + _attempts: int = 0, + ) -> None: + timeout = retry_timeout(_attempts) if callable(retry_timeout) else retry_timeout + if timeout < 0: + raise DbtRuntimeError("retry_timeout cannot be negative or return a negative time.") + + try: + execute_fn(sql, bindings) + + except tuple(retryable_exceptions) as e: + retry_limit -= 1 + if retry_limit <= 0: + raise e + fire_event( + AdapterEventDebug( + message=f"Got a retryable error {type(e)} when attempting to execute a query.\n" + f"{retry_limit} attempts remaining. Retrying in {timeout} seconds.\n" + f"Error:\n{e}" + ) + ) + + time.sleep(timeout) + return self._retryable_cursor_execute( + execute_fn=execute_fn, + sql=sql, + retry_limit=retry_limit - 1, + retry_timeout=retry_timeout, + retryable_exceptions=retryable_exceptions, + _attempts=_attempts + 1, + ) + + except Exception as e: + raise e