Skip to content

Commit

Permalink
feat(duckdb): enable backend create from dbapi con
Browse files Browse the repository at this point in the history
  • Loading branch information
deepyaman committed Jul 16, 2024
1 parent f50cbfc commit 6fa7d34
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 1 deletion.
15 changes: 14 additions & 1 deletion ibis/backends/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -767,13 +767,22 @@ class BaseBackend(abc.ABC, _FileIOHandler):
def __init__(self, *args, **kwargs):
self._con_args: tuple[Any] = args
self._con_kwargs: dict[str, Any] = kwargs
self._can_reconnect: bool = True
# expression cache
self._query_cache = RefCountedCache(
populate=self._load_into_cache,
lookup=lambda name: self.table(name).op(),
finalize=self._clean_up_cached_table,
)

@property
def can_reconnect(self) -> bool:
return self._can_reconnect

@can_reconnect.setter
def can_reconnect(self, value) -> None:
self._can_reconnect = value

@property
@abc.abstractmethod
def dialect(self) -> sg.Dialect | None:
Expand Down Expand Up @@ -842,6 +851,7 @@ def connect(self, *args, **kwargs) -> BaseBackend:
"""
new_backend = self.__class__(*args, **kwargs)
new_backend.can_reconnect = True
new_backend.reconnect()
return new_backend

Expand All @@ -856,7 +866,10 @@ def _convert_kwargs(kwargs: MutableMapping) -> None:
# TODO(kszucs): should call self.connect(*self._con_args, **self._con_kwargs)
def reconnect(self) -> None:
"""Reconnect to the database already configured with connect."""
self.do_connect(*self._con_args, **self._con_kwargs)
if self.can_reconnect:
self.do_connect(*self._con_args, **self._con_kwargs)
else:
raise exc.IbisError("Cannot reconnect to unconfigured {self.name} backend")

def do_connect(self, *args, **kwargs) -> None:
"""Connect to database specified by `args` and `kwargs`."""
Expand Down
3 changes: 3 additions & 0 deletions ibis/backends/duckdb/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,6 +473,9 @@ def do_connect(

self.con = duckdb.connect(str(database), config=config, read_only=read_only)

self._post_connect(extensions)

def _post_connect(self, extensions: Sequence[str] | None = None) -> None:
# Load any pre-specified extensions
if extensions is not None:
self._load_extensions(extensions)
Expand Down
8 changes: 8 additions & 0 deletions ibis/backends/sql/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -547,6 +547,14 @@ def truncate_table(
with self._safe_raw_sql(f"TRUNCATE TABLE {ident}"):
pass

@classmethod
def from_dbapi_connection(cls, con: Any, **kwargs: Any) -> BaseBackend:
new_backend = cls(**kwargs)
new_backend.can_reconnect = False
new_backend.con = con
new_backend._post_connect(**kwargs)
return new_backend

def disconnect(self):
# This is part of the Python DB-API specification so should work for
# _most_ sqlglot backends
Expand Down
5 changes: 5 additions & 0 deletions ibis/backends/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1764,3 +1764,8 @@ def test_insert_using_col_name_not_position(con, first_row, second_row, monkeypa
# Ideally we'd use a temp table for this test, but several backends don't
# support them and it's nice to know that data are being inserted correctly.
con.drop_table(table_name)


def test_from_dbapi_connection(con):
new_con = type(con).from_dbapi_connection(con.con, **con._con_kwargs)
assert {"astronauts", "batting", "diamonds"} <= set(new_con.list_tables())

0 comments on commit 6fa7d34

Please sign in to comment.