From 6fa7d34841896435752a3a2c0212339f2be171c2 Mon Sep 17 00:00:00 2001 From: Deepyaman Datta Date: Mon, 15 Jul 2024 22:54:06 -0600 Subject: [PATCH] feat(duckdb): enable backend create from dbapi con --- ibis/backends/__init__.py | 15 ++++++++++++++- ibis/backends/duckdb/__init__.py | 3 +++ ibis/backends/sql/__init__.py | 8 ++++++++ ibis/backends/tests/test_client.py | 5 +++++ 4 files changed, 30 insertions(+), 1 deletion(-) diff --git a/ibis/backends/__init__.py b/ibis/backends/__init__.py index bfc3f461adb4f..b2e3fc559f695 100644 --- a/ibis/backends/__init__.py +++ b/ibis/backends/__init__.py @@ -767,6 +767,7 @@ class BaseBackend(abc.ABC, _FileIOHandler): def __init__(self, *args, **kwargs): self._con_args: tuple[Any] = args self._con_kwargs: dict[str, Any] = kwargs + self._can_reconnect: bool = True # expression cache self._query_cache = RefCountedCache( populate=self._load_into_cache, @@ -774,6 +775,14 @@ def __init__(self, *args, **kwargs): finalize=self._clean_up_cached_table, ) + @property + def can_reconnect(self) -> bool: + return self._can_reconnect + + @can_reconnect.setter + def can_reconnect(self, value) -> None: + self._can_reconnect = value + @property @abc.abstractmethod def dialect(self) -> sg.Dialect | None: @@ -842,6 +851,7 @@ def connect(self, *args, **kwargs) -> BaseBackend: """ new_backend = self.__class__(*args, **kwargs) + new_backend.can_reconnect = True new_backend.reconnect() return new_backend @@ -856,7 +866,10 @@ def _convert_kwargs(kwargs: MutableMapping) -> None: # TODO(kszucs): should call self.connect(*self._con_args, **self._con_kwargs) def reconnect(self) -> None: """Reconnect to the database already configured with connect.""" - self.do_connect(*self._con_args, **self._con_kwargs) + if self.can_reconnect: + self.do_connect(*self._con_args, **self._con_kwargs) + else: + raise exc.IbisError("Cannot reconnect to unconfigured {self.name} backend") def do_connect(self, *args, **kwargs) -> None: """Connect to database specified by `args` and `kwargs`.""" diff --git a/ibis/backends/duckdb/__init__.py b/ibis/backends/duckdb/__init__.py index 8f23c3bce534c..08c5c2a22ce87 100644 --- a/ibis/backends/duckdb/__init__.py +++ b/ibis/backends/duckdb/__init__.py @@ -473,6 +473,9 @@ def do_connect( self.con = duckdb.connect(str(database), config=config, read_only=read_only) + self._post_connect(extensions) + + def _post_connect(self, extensions: Sequence[str] | None = None) -> None: # Load any pre-specified extensions if extensions is not None: self._load_extensions(extensions) diff --git a/ibis/backends/sql/__init__.py b/ibis/backends/sql/__init__.py index a8356a6945470..996150f9bf194 100644 --- a/ibis/backends/sql/__init__.py +++ b/ibis/backends/sql/__init__.py @@ -547,6 +547,14 @@ def truncate_table( with self._safe_raw_sql(f"TRUNCATE TABLE {ident}"): pass + @classmethod + def from_dbapi_connection(cls, con: Any, **kwargs: Any) -> BaseBackend: + new_backend = cls(**kwargs) + new_backend.can_reconnect = False + new_backend.con = con + new_backend._post_connect(**kwargs) + return new_backend + def disconnect(self): # This is part of the Python DB-API specification so should work for # _most_ sqlglot backends diff --git a/ibis/backends/tests/test_client.py b/ibis/backends/tests/test_client.py index 1bbadd4c89572..6d112cdb8a695 100644 --- a/ibis/backends/tests/test_client.py +++ b/ibis/backends/tests/test_client.py @@ -1764,3 +1764,8 @@ def test_insert_using_col_name_not_position(con, first_row, second_row, monkeypa # Ideally we'd use a temp table for this test, but several backends don't # support them and it's nice to know that data are being inserted correctly. con.drop_table(table_name) + + +def test_from_dbapi_connection(con): + new_con = type(con).from_dbapi_connection(con.con, **con._con_kwargs) + assert {"astronauts", "batting", "diamonds"} <= set(new_con.list_tables())