Skip to content

Commit

Permalink
Add cursor contextmanager, add comments to contextmanager return types
Browse files Browse the repository at this point in the history
  • Loading branch information
Jason Lubken committed Aug 24, 2020
1 parent 7a614fd commit fc1fa8f
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 50 deletions.
40 changes: 17 additions & 23 deletions src/dsdk/mssql.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,23 +101,10 @@ def check(self, cur, exceptions=(DatabaseError, InterfaceError)):
raise RuntimeError(self.ERRORS, errors)
logger.info(self.END)

@contextmanager
def commit(self) -> Generator[Any, None, None]:
"""Commit."""
with self.connect() as con:
try:
with con.cursor(as_dict=True) as cur:
yield cur
con.commit()
logger.info(self.COMMIT)
except BaseException:
con.rollback()
logger.info(self.ROLLBACK)
raise

@contextmanager
def connect(self) -> Generator[Any, None, None]:
"""Connect."""
# Replace return type with ContextManager[Any] when mypy is fixed.
con = connect(
server=self.host,
user=self.username,
Expand All @@ -134,15 +121,11 @@ def connect(self) -> Generator[Any, None, None]:
logger.info(self.CLOSE)

@contextmanager
def rollback(self) -> Generator[Any, None, None]:
"""Rollback."""
with self.connect() as con:
try:
with con.cursor(as_dict=True) as cur:
yield cur
finally:
con.rollback()
logger.info(self.ROLLBACK)
def cursor(self, con) -> Generator[Any, None, None]:
"""Yield cursor that provides dicts."""
# Replace return type with ContextManager[Any] when mypy is fixed.
with con.cursor(as_dict=True) as cur:
yield cur


class AlchemyPersistor(Messages, BaseAbstractPersistor):
Expand All @@ -154,6 +137,7 @@ def configure(
cls, service: Service, parser
) -> Generator[None, None, None]:
"""Dependencies."""
# Replace return type with ContextManager[None] when mypy is fixed.
kwargs: Dict[str, Any] = {}

for key, help_, inject in (
Expand Down Expand Up @@ -230,6 +214,7 @@ def check(

@contextmanager
def connect(self) -> Generator[Any, None, None]:
# Replace return type with ContextManager[Any] when mypy is fixed.
"""Connect."""
con = self.engine.connect()
logger.info(self.OPEN)
Expand All @@ -239,6 +224,13 @@ def connect(self) -> Generator[Any, None, None]:
con.close()
logger.info(self.CLOSE)

@contextmanager
def cursor(self, con) -> Generator[Any, None, None]:
# Replace return type with ContextManager[Any] when mypy is fixed.
"""Yield a cursor that provides dicts."""
with con.cursor() as cur:
yield cur


class Mixin(BaseMixin):
"""Mixin."""
Expand All @@ -254,6 +246,7 @@ def inject_arguments(
self, parser: ArgumentParser
) -> Generator[None, None, None]:
"""Inject arguments."""
# Replace return type with ContextManager[Any] when mypy is fixed.
with self.mssql_cls.configure(self, parser):
with super().inject_arguments(parser):
yield
Expand All @@ -274,6 +267,7 @@ def __init__(
def inject_arguments(
self, parser: ArgumentParser
) -> Generator[None, None, None]:
# Replace return type with ContextManager[None] when mypy is fixed.
"""Inject arguments."""
with self.mssql_cls.configure(self, parser):
with super().inject_arguments(parser):
Expand Down
15 changes: 13 additions & 2 deletions src/dsdk/persistor.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,9 +106,10 @@ def check(self, cur, exceptions):
@contextmanager
def commit(self) -> Generator[Any, None, None]:
"""Commit."""
# Replace return type with ContextManager[Any] when mypy is fixed.
with self.connect() as con:
try:
with con.cursor() as cur:
with self.cursor(con) as cur:
yield cur
con.commit()
logger.info(self.COMMIT)
Expand All @@ -120,6 +121,13 @@ def commit(self) -> Generator[Any, None, None]:
@contextmanager
def connect(self) -> Generator[Any, None, None]:
"""Connect."""
# Replace return type with ContextManager[Any] when mypy is fixed.
raise NotImplementedError()

@contextmanager
def cursor(self, con):
"""Yield a cursor that provides dicts."""
# Replace return type with ContextManager[Any] when mypy is fixed.
raise NotImplementedError()

def extant(self, table: str) -> str:
Expand All @@ -131,9 +139,10 @@ def extant(self, table: str) -> str:
@contextmanager
def rollback(self) -> Generator[Any, None, None]:
"""Rollback."""
# Replace return type with ContextManager[Any] when mypy is fixed.
with self.connect() as con:
try:
with con.cursor() as cur:
with self.cursor(con) as cur:
yield cur
finally:
con.rollback()
Expand All @@ -149,6 +158,7 @@ def configure(
cls, service: Service, parser
) -> Generator[None, None, None]:
"""Configure."""
# Replace return type with ContextManager[None] when mypy is fixed.
kwargs: Dict[str, Any] = {}

for key, help_, inject in (
Expand Down Expand Up @@ -197,4 +207,5 @@ def __init__( # pylint: disable=too-many-arguments
@contextmanager
def connect(self) -> Generator[Any, None, None]:
"""Connect."""
# Replace return type with ContextManager[Any] when mypy is fixed.
raise NotImplementedError()
37 changes: 12 additions & 25 deletions src/dsdk/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,23 +71,10 @@ def check(self, cur, exceptions=(DatabaseError, InterfaceError)):
"""Check."""
super().check(cur, exceptions)

@contextmanager
def commit(self) -> Generator[Any, None, None]:
"""Commit."""
with self.connect() as con:
try:
with con.cursor(cursor_factory=DictCursor) as cur:
yield cur
con.commit()
logger.info(self.COMMIT)
except BaseException:
con.rollback()
logger.info(self.ROLLBACK)
raise

@contextmanager
def connect(self) -> Generator[Any, None, None]:
"""Connect."""
# Replace return type with ContextManager[Any] when mypy is fixed.
# The `with ... as con:` formulation does not close the connection:
# https://www.psycopg.org/docs/usage.html#with-statement
con = self.retry_connect()
Expand All @@ -98,6 +85,13 @@ def connect(self) -> Generator[Any, None, None]:
con.close()
logger.info(self.CLOSE)

@contextmanager
def cursor(self, con) -> Generator[Any, None, None]:
"""Yield a cursor that provides dicts."""
# Replace return type with ContextManager[Any] when mypy is fixed.
with con.cursor(cursor_factory=DictCursor) as cur:
yield cur

@retry((OperationalError,))
def retry_connect(self):
"""Retry connect."""
Expand All @@ -109,17 +103,6 @@ def retry_connect(self):
dbname=self.database,
)

@contextmanager
def rollback(self) -> Generator[Any, None, None]:
"""Rollback."""
with self.connect() as con:
try:
with con.cursor(cursor_factory=DictCursor) as cur:
yield cur
finally:
con.rollback()
logger.info(self.ROLLBACK)


class Mixin(BaseMixin):
"""Mixin."""
Expand All @@ -137,6 +120,7 @@ def inject_arguments(
self, parser: ArgumentParser
) -> Generator[None, None, None]:
"""Inject arguments."""
# Replace return type with ContextManager[None] when mypy is fixed.
with self.postgres_cls.configure(self, parser):
with super().inject_arguments(parser):
yield
Expand All @@ -162,6 +146,7 @@ def open_run(
self, microservice_version: str, model_version: str
) -> Generator[Run, None, None]:
"""Open run."""
# Replace return type with ContextManager[Run] when mypy is fixed.
sql = self.sql
with self.commit() as cur:
cur.execute(sql.schema)
Expand All @@ -180,7 +165,9 @@ def open_run(
row["duration"],
)
break

yield run

with self.commit() as cur:
cur.execute(sql.schema)
if run.predictions is not None:
Expand Down

0 comments on commit fc1fa8f

Please sign in to comment.