Skip to content

Commit

Permalink
Apply the update to SQLAlchemy 2.0
Browse files Browse the repository at this point in the history
  • Loading branch information
sbrunner committed Feb 2, 2023
1 parent 46a3f07 commit 1770ecf
Show file tree
Hide file tree
Showing 11 changed files with 153 additions and 84 deletions.
1 change: 1 addition & 0 deletions .prospector.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ pycodestyle:
disable:
- E203 # Whitespace before ':', duplicated with black, with error in array
- E722 # do not use bare 'except', duplicated with pylint
- E261 # at least two spaces before inline comment, duplicated with black

pydocstyle:
disable:
Expand Down
75 changes: 44 additions & 31 deletions c2cwsgiutils/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def setup_session(
force_master: Optional[Iterable[str]] = None,
force_slave: Optional[Iterable[str]] = None,
) -> Tuple[
Union[sqlalchemy.orm.Session, sqlalchemy.orm.scoped_session],
Union[sqlalchemy.orm.Session, sqlalchemy.orm.scoped_session[sqlalchemy.orm.Session]],
sqlalchemy.engine.Engine,
sqlalchemy.engine.Engine,
]:
Expand Down Expand Up @@ -67,7 +67,7 @@ def setup_session(
slave_prefix = master_prefix
settings = config.registry.settings
rw_engine = sqlalchemy.engine_from_config(settings, master_prefix + ".")
rw_engine.c2c_name = master_prefix
rw_engine.c2c_name = master_prefix # type: ignore
factory = sqlalchemy.orm.sessionmaker(bind=rw_engine)
register(factory)
db_session = sqlalchemy.orm.scoped_session(factory)
Expand All @@ -76,14 +76,14 @@ def setup_session(
if settings[master_prefix + ".url"] != settings.get(slave_prefix + ".url"):
LOG.info("Using a slave DB for reading %s", master_prefix)
ro_engine = sqlalchemy.engine_from_config(config.get_settings(), slave_prefix + ".")
ro_engine.c2c_name = slave_prefix
ro_engine.c2c_name = slave_prefix # type: ignore
tween_name = master_prefix.replace(".", "_")
_add_tween(config, tween_name, db_session, force_master, force_slave)
else:
ro_engine = rw_engine

db_session.c2c_rw_bind = rw_engine
db_session.c2c_ro_bind = ro_engine
db_session.c2c_rw_bind = rw_engine # type: ignore
db_session.c2c_ro_bind = ro_engine # type: ignore
return db_session, rw_engine, ro_engine


Expand All @@ -95,7 +95,7 @@ def create_session(
force_master: Optional[Iterable[str]] = None,
force_slave: Optional[Iterable[str]] = None,
**engine_config: Any,
) -> Union[sqlalchemy.orm.Session, sqlalchemy.orm.scoped_session]:
) -> Union[sqlalchemy.orm.Session, sqlalchemy.orm.scoped_session[sqlalchemy.orm.Session]]:
"""
Create a SQLAlchemy session.
Expand Down Expand Up @@ -133,21 +133,21 @@ def create_session(
LOG.info("Using a slave DB for reading %s", name)
ro_engine = sqlalchemy.create_engine(slave_url, **engine_config)
_add_tween(config, name, db_session, force_master, force_slave)
rw_engine.c2c_name = name + "_master"
ro_engine.c2c_name = name + "_slave"
rw_engine.c2c_name = name + "_master" # type: ignore
ro_engine.c2c_name = name + "_slave" # type: ignore
else:
rw_engine.c2c_name = name
rw_engine.c2c_name = name # type: ignore
ro_engine = rw_engine

db_session.c2c_rw_bind = rw_engine
db_session.c2c_ro_bind = ro_engine
db_session.c2c_rw_bind = rw_engine # type: ignore
db_session.c2c_ro_bind = ro_engine # type: ignore
return db_session


def _add_tween(
config: pyramid.config.Configurator,
name: str,
db_session: Union[sqlalchemy.orm.Session, sqlalchemy.orm.scoped_session],
db_session: sqlalchemy.orm.scoped_session[sqlalchemy.orm.Session],
force_master: Optional[Iterable[str]],
force_slave: Optional[Iterable[str]],
) -> None:
Expand Down Expand Up @@ -176,11 +176,19 @@ def db_chooser_tween(request: pyramid.request.Request) -> Any:
not has_force_master
and (request.method in ("GET", "OPTIONS") or any(r.match(method_path) for r in slave_paths))
):
LOG.debug("Using %s database for: %s", db_session.c2c_ro_bind.c2c_name, method_path)
session.bind = db_session.c2c_ro_bind
LOG.debug(
"Using %s database for: %s",
db_session.c2c_ro_bind.c2c_name, # type: ignore
method_path,
)
session.bind = db_session.c2c_ro_bind # type: ignore
else:
LOG.debug("Using %s database for: %s", db_session.c2c_rw_bind.c2c_name, method_path)
session.bind = db_session.c2c_rw_bind
LOG.debug(
"Using %s database for: %s",
db_session.c2c_rw_bind.c2c_name, # type: ignore
method_path,
)
session.bind = db_session.c2c_rw_bind # type: ignore

try:
return handler(request)
Expand All @@ -193,7 +201,7 @@ def db_chooser_tween(request: pyramid.request.Request) -> Any:
config.add_tween("c2cwsgiutils.db.tweens." + name, over="pyramid_tm.tm_tween_factory")


class SessionFactory(sessionmaker): # type: ignore
class SessionFactory(sessionmaker[sqlalchemy.orm.Session]): # pylint: disable=unsubscriptable-object
"""The custom session factory that manage the read only and read write sessions."""

def __init__(
Expand All @@ -213,18 +221,18 @@ def __init__(

def engine_name(self, readwrite: bool) -> str:
if readwrite:
return cast(str, self.rw_engine.c2c_name)
return cast(str, self.ro_engine.c2c_name)
return cast(str, self.rw_engine.c2c_name) # type: ignore
return cast(str, self.ro_engine.c2c_name) # type: ignore

def __call__(
def __call__( # type: ignore
self, request: Optional[pyramid.request.Request], readwrite: Optional[bool] = None, **local_kw: Any
) -> sqlalchemy.orm.Session:
) -> sqlalchemy.orm.scoped_session[sqlalchemy.orm.Session]:
if readwrite is not None:
if readwrite and not force_readonly:
LOG.debug("Using %s database", self.rw_engine.c2c_name)
LOG.debug("Using %s database", self.rw_engine.c2c_name) # type: ignore
self.configure(bind=self.rw_engine)
else:
LOG.debug("Using %s database", self.ro_engine.c2c_name)
LOG.debug("Using %s database", self.ro_engine.c2c_name) # type: ignore
self.configure(bind=self.ro_engine)
else:
assert request is not None
Expand All @@ -237,12 +245,12 @@ def __call__(
or any(r.match(method_path) for r in self.slave_paths)
)
):
LOG.debug("Using %s database for: %s", self.ro_engine.c2c_name, method_path)
LOG.debug("Using %s database for: %s", self.ro_engine.c2c_name, method_path) # type: ignore
self.configure(bind=self.ro_engine)
else:
LOG.debug("Using %s database for: %s", self.rw_engine.c2c_name, method_path)
LOG.debug("Using %s database for: %s", self.rw_engine.c2c_name, method_path) # type: ignore
self.configure(bind=self.rw_engine)
return super().__call__(**local_kw)
return super().__call__(**local_kw) # type: ignore


def get_engine(
Expand All @@ -252,15 +260,17 @@ def get_engine(
return engine_from_config(settings, prefix)


def get_session_factory(engine: sqlalchemy.engine.Engine) -> sessionmaker:
def get_session_factory(
engine: sqlalchemy.engine.Engine,
) -> sessionmaker[sqlalchemy.orm.Session]: # pylint: disable=unsubscriptable-object
"""Get the session factory from the engine."""
factory = sessionmaker()
factory.configure(bind=engine)
return factory


def get_tm_session(
session_factory: sessionmaker,
session_factory: sessionmaker[sqlalchemy.orm.Session], # pylint: disable=unsubscriptable-object
transaction_manager: transaction.TransactionManager,
) -> sqlalchemy.orm.Session:
"""
Expand Down Expand Up @@ -319,6 +329,7 @@ def get_tm_session(
request = dbsession.info["request"]
"""
dbsession = session_factory()
assert isinstance(dbsession, sqlalchemy.orm.Session)
zope.sqlalchemy.register(dbsession, transaction_manager=transaction_manager)
return dbsession

Expand All @@ -327,7 +338,7 @@ def get_tm_session_pyramid(
session_factory: SessionFactory,
transaction_manager: transaction.TransactionManager,
request: pyramid.request.Request,
) -> sqlalchemy.orm.Session:
) -> sqlalchemy.orm.scoped_session[sqlalchemy.orm.Session]:
"""
Get a ``sqlalchemy.orm.Session`` instance backed by a transaction.
Expand Down Expand Up @@ -367,13 +378,13 @@ def init(
dbengine = settings.get("dbengine")
if not dbengine:
rw_engine = get_engine(settings, master_prefix + ".")
rw_engine.c2c_name = master_prefix
rw_engine.c2c_name = master_prefix # type: ignore

# Setup a slave DB connection and add a tween to use it.
if slave_prefix and settings[master_prefix + ".url"] != settings.get(slave_prefix + ".url"):
LOG.info("Using a slave DB for reading %s", master_prefix)
ro_engine = get_engine(config.get_settings(), slave_prefix + ".")
ro_engine.c2c_name = slave_prefix
ro_engine.c2c_name = slave_prefix # type: ignore
else:
ro_engine = rw_engine
else:
Expand All @@ -389,6 +400,8 @@ def dbsession(request: pyramid.request.Request) -> sqlalchemy.orm.Session:
if dbsession is None:
# request.tm is the transaction manager used by pyramid_tm
dbsession = get_tm_session_pyramid(session_factory, request.tm, request=request)
assert dbsession is not None
assert isinstance(dbsession, sqlalchemy.orm.Session), type(dbsession)
return dbsession

config.add_request_method(dbsession, reify=True)
Expand Down
9 changes: 5 additions & 4 deletions c2cwsgiutils/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,10 +107,11 @@ def _include_dev_details(request: pyramid.request.Request) -> bool:
def _integrity_error(
exception: sqlalchemy.exc.StatementError, request: pyramid.request.Request
) -> pyramid.response.Response:
def reduce_info_sent(e: sqlalchemy.exc.StatementError) -> None:
# remove details (SQL statement and links to SQLAlchemy) from the error
e.statement = None
e.code = None
def reduce_info_sent(e: Exception) -> None:
if isinstance(e, sqlalchemy.exc.StatementError):
# remove details (SQL statement and links to SQLAlchemy) from the error
e.statement = None
e.code = None

return _do_error(request, 400, exception, reduce_info_sent=reduce_info_sent)

Expand Down
57 changes: 34 additions & 23 deletions c2cwsgiutils/health_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import requests
import sqlalchemy.engine
import sqlalchemy.orm
import sqlalchemy.sql
from pyramid.httpexceptions import HTTPNotFound

import c2cwsgiutils.db
Expand Down Expand Up @@ -58,7 +59,7 @@ class _Binding:
def name(self) -> str:
raise NotImplementedError()

def __enter__(self) -> sqlalchemy.orm.Session:
def __enter__(self) -> sqlalchemy.orm.scoped_session[sqlalchemy.orm.Session]:
raise NotImplementedError()

def __exit__(
Expand All @@ -78,21 +79,23 @@ def __init__(self, session: c2cwsgiutils.db.SessionFactory, readwrite: bool):
def name(self) -> str:
return self.session.engine_name(self.readwrite)

def __enter__(self) -> sqlalchemy.orm.Session:
def __enter__(self) -> sqlalchemy.orm.scoped_session[sqlalchemy.orm.Session]:
return self.session(None, self.readwrite)


class _OldBinding(_Binding):
def __init__(self, session: sqlalchemy.orm.scoping.scoped_session, engine: sqlalchemy.engine.Engine):
def __init__(
self, session: sqlalchemy.orm.scoped_session[sqlalchemy.orm.Session], engine: sqlalchemy.engine.Engine
):
self.session = session
self.engine = engine
self.prev_bind = None

def name(self) -> str:
return cast(str, self.engine.c2c_name)
return cast(str, self.engine.c2c_name) # type: ignore

def __enter__(self) -> sqlalchemy.orm.Session:
self.prev_bind = self.session.bind
def __enter__(self) -> sqlalchemy.orm.scoped_session[sqlalchemy.orm.Session]:
self.prev_bind = self.session.bind # type: ignore
self.session.bind = self.engine
return self.session

Expand All @@ -107,7 +110,7 @@ def __exit__(


def _get_binding_class(
session: Union[sqlalchemy.orm.scoping.scoped_session, c2cwsgiutils.db.SessionFactory],
session: Union[sqlalchemy.orm.scoped_session[sqlalchemy.orm.Session], c2cwsgiutils.db.SessionFactory],
ro_engin: sqlalchemy.engine.Engine,
rw_engin: sqlalchemy.engine.Engine,
readwrite: bool,
Expand All @@ -119,15 +122,15 @@ def _get_binding_class(


def _get_bindings(
session: Union[sqlalchemy.orm.scoping.scoped_session, c2cwsgiutils.db.SessionFactory],
session: Union[sqlalchemy.orm.scoped_session[sqlalchemy.orm.Session], c2cwsgiutils.db.SessionFactory],
engine_type: EngineType,
) -> List[sqlalchemy.engine.Engine]:
) -> List[_Binding]:
if isinstance(session, c2cwsgiutils.db.SessionFactory):
ro_engin = session.ro_engine
rw_engin = session.rw_engine
else:
ro_engin = session.c2c_ro_bind
rw_engin = session.c2c_rw_bind
ro_engin = session.c2c_ro_bind # type: ignore
rw_engin = session.c2c_rw_bind # type: ignore

if rw_engin == ro_engin:
engine_type = EngineType.WRITE_ONLY
Expand Down Expand Up @@ -192,8 +195,8 @@ def __init__(self, config: pyramid.config.Configurator) -> None:

def add_db_session_check(
self,
session: Union[sqlalchemy.orm.scoping.scoped_session, c2cwsgiutils.db.SessionFactory],
query_cb: Optional[Callable[[sqlalchemy.orm.scoping.scoped_session], Any]] = None,
session: Union[sqlalchemy.orm.scoped_session[sqlalchemy.orm.Session], c2cwsgiutils.db.SessionFactory],
query_cb: Optional[Callable[[sqlalchemy.orm.scoped_session[sqlalchemy.orm.Session]], Any]] = None,
at_least_one_model: Optional[object] = None,
level: int = 1,
engine_type: EngineType = EngineType.READ_AND_WRITE,
Expand All @@ -220,7 +223,7 @@ def add_db_session_check(

def add_alembic_check(
self,
session: Union[sqlalchemy.orm.scoping.scoped_session, c2cwsgiutils.db.SessionFactory],
session: sqlalchemy.orm.scoped_session[sqlalchemy.orm.Session],
alembic_ini_path: str,
level: int = 2,
name: str = "alembic",
Expand Down Expand Up @@ -249,17 +252,21 @@ def add_alembic_check(

if version_schema is None:
version_schema = config.get(name, "version_table_schema", fallback="public")
assert version_schema

if version_table is None:
version_table = config.get(name, "version_table", fallback="alembic_version")
assert version_table

class _Check:
def __init__(self, session: sqlalchemy.orm.scoping.scoped_session) -> None:
def __init__(self, session: sqlalchemy.orm.scoped_session[sqlalchemy.orm.Session]) -> None:
self.session = session

def __call__(self, request: pyramid.request.Request) -> str:
assert version_schema
assert version_table
for binding in _get_bindings(self.session, EngineType.READ_AND_WRITE):
with binding as session:
with binding as binded_session:
if stats.USE_TAGS:
key = ["sql", "manual", "health_check", "alembic"]
tags: Optional[Dict[str, str]] = {"conf": alembic_ini_path, "con": binding.name()}
Expand All @@ -274,11 +281,15 @@ def __call__(self, request: pyramid.request.Request) -> str:
]
tags = None
with stats.timer_context(key, tags):
quote = session.bind.dialect.identifier_preparer.quote
(actual_version,) = session.execute(
"SELECT version_num FROM " # nosec
f"{quote(version_schema)}.{quote(version_table)}"
result = binded_session.execute(
sqlalchemy.text(
"SELECT version_num FROM " # nosec
f"{sqlalchemy.sql.quoted_name(version_schema, True)}."
f"{sqlalchemy.sql.quoted_name(version_table, True)}"
)
).fetchone()
assert result is not None
(actual_version,) = result
if stats.USE_TAGS:
stats.increment_counter(
["alembic_version"], 1, tags={"version": actual_version, "name": name}
Expand Down Expand Up @@ -492,7 +503,7 @@ def _run_one(
@staticmethod
def _create_db_engine_check(
binding: _Binding,
query_cb: Callable[[sqlalchemy.orm.scoping.scoped_session], None],
query_cb: Callable[[sqlalchemy.orm.scoped_session[sqlalchemy.orm.Session]], None],
) -> Tuple[str, Callable[[pyramid.request.Request], None]]:
def check(request: pyramid.request.Request) -> None:
with binding as session:
Expand All @@ -508,8 +519,8 @@ def check(request: pyramid.request.Request) -> None:
return "db_engine_" + binding.name(), check

@staticmethod
def _at_least_one(model: Any) -> Callable[[sqlalchemy.orm.scoping.scoped_session], Any]:
def query(session: sqlalchemy.orm.scoping.scoped_session) -> None:
def _at_least_one(model: Any) -> Callable[[sqlalchemy.orm.scoped_session[sqlalchemy.orm.Session]], Any]:
def query(session: sqlalchemy.orm.scoped_session[sqlalchemy.orm.Session]) -> None:
result = session.query(model).first()
if result is None:
raise HTTPNotFound(model.__name__ + " record not found")
Expand Down
Loading

0 comments on commit 1770ecf

Please sign in to comment.