diff --git a/src/dsdk/mssql.py b/src/dsdk/mssql.py index 0c9ef63..138df1f 100644 --- a/src/dsdk/mssql.py +++ b/src/dsdk/mssql.py @@ -101,23 +101,10 @@ def check(self, cur, exceptions=(DatabaseError, InterfaceError)): raise RuntimeError(self.ERRORS, errors) logger.info(self.END) - @contextmanager - def commit(self) -> Generator[Any, None, None]: - """Commit.""" - with self.connect() as con: - try: - with con.cursor(as_dict=True) as cur: - yield cur - con.commit() - logger.info(self.COMMIT) - except BaseException: - con.rollback() - logger.info(self.ROLLBACK) - raise - @contextmanager def connect(self) -> Generator[Any, None, None]: """Connect.""" + # Replace return type with ContextManager[Any] when mypy is fixed. con = connect( server=self.host, user=self.username, @@ -134,15 +121,11 @@ def connect(self) -> Generator[Any, None, None]: logger.info(self.CLOSE) @contextmanager - def rollback(self) -> Generator[Any, None, None]: - """Rollback.""" - with self.connect() as con: - try: - with con.cursor(as_dict=True) as cur: - yield cur - finally: - con.rollback() - logger.info(self.ROLLBACK) + def cursor(self, con) -> Generator[Any, None, None]: + """Yield cursor that provides dicts.""" + # Replace return type with ContextManager[Any] when mypy is fixed. + with con.cursor(as_dict=True) as cur: + yield cur class AlchemyPersistor(Messages, BaseAbstractPersistor): @@ -154,6 +137,7 @@ def configure( cls, service: Service, parser ) -> Generator[None, None, None]: """Dependencies.""" + # Replace return type with ContextManager[None] when mypy is fixed. kwargs: Dict[str, Any] = {} for key, help_, inject in ( @@ -230,6 +214,7 @@ def check( @contextmanager def connect(self) -> Generator[Any, None, None]: + # Replace return type with ContextManager[Any] when mypy is fixed. """Connect.""" con = self.engine.connect() logger.info(self.OPEN) @@ -239,6 +224,13 @@ def connect(self) -> Generator[Any, None, None]: con.close() logger.info(self.CLOSE) + @contextmanager + def cursor(self, con) -> Generator[Any, None, None]: + # Replace return type with ContextManager[Any] when mypy is fixed. + """Yield a cursor that provides dicts.""" + with con.cursor() as cur: + yield cur + class Mixin(BaseMixin): """Mixin.""" @@ -254,6 +246,7 @@ def inject_arguments( self, parser: ArgumentParser ) -> Generator[None, None, None]: """Inject arguments.""" + # Replace return type with ContextManager[Any] when mypy is fixed. with self.mssql_cls.configure(self, parser): with super().inject_arguments(parser): yield @@ -274,6 +267,7 @@ def __init__( def inject_arguments( self, parser: ArgumentParser ) -> Generator[None, None, None]: + # Replace return type with ContextManager[None] when mypy is fixed. """Inject arguments.""" with self.mssql_cls.configure(self, parser): with super().inject_arguments(parser): diff --git a/src/dsdk/persistor.py b/src/dsdk/persistor.py index 320c33a..4ed87b1 100644 --- a/src/dsdk/persistor.py +++ b/src/dsdk/persistor.py @@ -106,9 +106,10 @@ def check(self, cur, exceptions): @contextmanager def commit(self) -> Generator[Any, None, None]: """Commit.""" + # Replace return type with ContextManager[Any] when mypy is fixed. with self.connect() as con: try: - with con.cursor() as cur: + with self.cursor(con) as cur: yield cur con.commit() logger.info(self.COMMIT) @@ -120,6 +121,13 @@ def commit(self) -> Generator[Any, None, None]: @contextmanager def connect(self) -> Generator[Any, None, None]: """Connect.""" + # Replace return type with ContextManager[Any] when mypy is fixed. + raise NotImplementedError() + + @contextmanager + def cursor(self, con): + """Yield a cursor that provides dicts.""" + # Replace return type with ContextManager[Any] when mypy is fixed. raise NotImplementedError() def extant(self, table: str) -> str: @@ -131,9 +139,10 @@ def extant(self, table: str) -> str: @contextmanager def rollback(self) -> Generator[Any, None, None]: """Rollback.""" + # Replace return type with ContextManager[Any] when mypy is fixed. with self.connect() as con: try: - with con.cursor() as cur: + with self.cursor(con) as cur: yield cur finally: con.rollback() @@ -149,6 +158,7 @@ def configure( cls, service: Service, parser ) -> Generator[None, None, None]: """Configure.""" + # Replace return type with ContextManager[None] when mypy is fixed. kwargs: Dict[str, Any] = {} for key, help_, inject in ( @@ -197,4 +207,5 @@ def __init__( # pylint: disable=too-many-arguments @contextmanager def connect(self) -> Generator[Any, None, None]: """Connect.""" + # Replace return type with ContextManager[Any] when mypy is fixed. raise NotImplementedError() diff --git a/src/dsdk/postgres.py b/src/dsdk/postgres.py index de8a7d9..7a77960 100644 --- a/src/dsdk/postgres.py +++ b/src/dsdk/postgres.py @@ -71,23 +71,10 @@ def check(self, cur, exceptions=(DatabaseError, InterfaceError)): """Check.""" super().check(cur, exceptions) - @contextmanager - def commit(self) -> Generator[Any, None, None]: - """Commit.""" - with self.connect() as con: - try: - with con.cursor(cursor_factory=DictCursor) as cur: - yield cur - con.commit() - logger.info(self.COMMIT) - except BaseException: - con.rollback() - logger.info(self.ROLLBACK) - raise - @contextmanager def connect(self) -> Generator[Any, None, None]: """Connect.""" + # Replace return type with ContextManager[Any] when mypy is fixed. # The `with ... as con:` formulation does not close the connection: # https://www.psycopg.org/docs/usage.html#with-statement con = self.retry_connect() @@ -98,6 +85,13 @@ def connect(self) -> Generator[Any, None, None]: con.close() logger.info(self.CLOSE) + @contextmanager + def cursor(self, con) -> Generator[Any, None, None]: + """Yield a cursor that provides dicts.""" + # Replace return type with ContextManager[Any] when mypy is fixed. + with con.cursor(cursor_factory=DictCursor) as cur: + yield cur + @retry((OperationalError,)) def retry_connect(self): """Retry connect.""" @@ -109,17 +103,6 @@ def retry_connect(self): dbname=self.database, ) - @contextmanager - def rollback(self) -> Generator[Any, None, None]: - """Rollback.""" - with self.connect() as con: - try: - with con.cursor(cursor_factory=DictCursor) as cur: - yield cur - finally: - con.rollback() - logger.info(self.ROLLBACK) - class Mixin(BaseMixin): """Mixin.""" @@ -137,6 +120,7 @@ def inject_arguments( self, parser: ArgumentParser ) -> Generator[None, None, None]: """Inject arguments.""" + # Replace return type with ContextManager[None] when mypy is fixed. with self.postgres_cls.configure(self, parser): with super().inject_arguments(parser): yield @@ -162,6 +146,7 @@ def open_run( self, microservice_version: str, model_version: str ) -> Generator[Run, None, None]: """Open run.""" + # Replace return type with ContextManager[Run] when mypy is fixed. sql = self.sql with self.commit() as cur: cur.execute(sql.schema) @@ -180,7 +165,9 @@ def open_run( row["duration"], ) break + yield run + with self.commit() as cur: cur.execute(sql.schema) if run.predictions is not None: