Skip to content

Commit

Permalink
Apply the update to SQLAlchemy 2.0
Browse files Browse the repository at this point in the history
  • Loading branch information
sbrunner committed Feb 1, 2023
1 parent e495732 commit 079d8f0
Show file tree
Hide file tree
Showing 9 changed files with 149 additions and 87 deletions.
1 change: 1 addition & 0 deletions .prospector.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ pycodestyle:
disable:
- E203 # Whitespace before ':', duplicated with black, with error in array
- E722 # do not use bare 'except', duplicated with pylint
- E261 # at least two spaces before inline comment, duplicated with black

pydocstyle:
disable:
Expand Down
78 changes: 45 additions & 33 deletions c2cwsgiutils/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import re
import warnings
from typing import Any, Callable, Iterable, Optional, Pattern, Tuple, Union, cast

import pyramid.config
import pyramid.config.settings
import pyramid.request
Expand Down Expand Up @@ -35,7 +34,7 @@ def setup_session(
force_master: Optional[Iterable[str]] = None,
force_slave: Optional[Iterable[str]] = None,
) -> Tuple[
Union[sqlalchemy.orm.Session, sqlalchemy.orm.scoped_session],
Union[sqlalchemy.orm.Session, sqlalchemy.orm.scoped_session[sqlalchemy.orm.Session]],
sqlalchemy.engine.Engine,
sqlalchemy.engine.Engine,
]:
Expand Down Expand Up @@ -67,7 +66,7 @@ def setup_session(
slave_prefix = master_prefix
settings = config.registry.settings
rw_engine = sqlalchemy.engine_from_config(settings, master_prefix + ".")
rw_engine.c2c_name = master_prefix
rw_engine.c2c_name = master_prefix # type: ignore
factory = sqlalchemy.orm.sessionmaker(bind=rw_engine)
register(factory)
db_session = sqlalchemy.orm.scoped_session(factory)
Expand All @@ -76,14 +75,14 @@ def setup_session(
if settings[master_prefix + ".url"] != settings.get(slave_prefix + ".url"):
LOG.info("Using a slave DB for reading %s", master_prefix)
ro_engine = sqlalchemy.engine_from_config(config.get_settings(), slave_prefix + ".")
ro_engine.c2c_name = slave_prefix
ro_engine.c2c_name = slave_prefix # type: ignore
tween_name = master_prefix.replace(".", "_")
_add_tween(config, tween_name, db_session, force_master, force_slave)
else:
ro_engine = rw_engine

db_session.c2c_rw_bind = rw_engine
db_session.c2c_ro_bind = ro_engine
db_session.c2c_rw_bind = rw_engine # type: ignore
db_session.c2c_ro_bind = ro_engine # type: ignore
return db_session, rw_engine, ro_engine


Expand All @@ -95,7 +94,7 @@ def create_session(
force_master: Optional[Iterable[str]] = None,
force_slave: Optional[Iterable[str]] = None,
**engine_config: Any,
) -> Union[sqlalchemy.orm.Session, sqlalchemy.orm.scoped_session]:
) -> Union[sqlalchemy.orm.Session, sqlalchemy.orm.scoped_session[sqlalchemy.orm.Session]]:
"""
Create a SQLAlchemy session.
Expand Down Expand Up @@ -133,21 +132,21 @@ def create_session(
LOG.info("Using a slave DB for reading %s", name)
ro_engine = sqlalchemy.create_engine(slave_url, **engine_config)
_add_tween(config, name, db_session, force_master, force_slave)
rw_engine.c2c_name = name + "_master"
ro_engine.c2c_name = name + "_slave"
rw_engine.c2c_name = name + "_master" # type: ignore
ro_engine.c2c_name = name + "_slave" # type: ignore
else:
rw_engine.c2c_name = name
rw_engine.c2c_name = name # type: ignore
ro_engine = rw_engine

db_session.c2c_rw_bind = rw_engine
db_session.c2c_ro_bind = ro_engine
db_session.c2c_rw_bind = rw_engine # type: ignore
db_session.c2c_ro_bind = ro_engine # type: ignore
return db_session


def _add_tween(
config: pyramid.config.Configurator,
name: str,
db_session: Union[sqlalchemy.orm.Session, sqlalchemy.orm.scoped_session],
db_session: sqlalchemy.orm.scoped_session[sqlalchemy.orm.Session],
force_master: Optional[Iterable[str]],
force_slave: Optional[Iterable[str]],
) -> None:
Expand Down Expand Up @@ -176,11 +175,19 @@ def db_chooser_tween(request: pyramid.request.Request) -> Any:
not has_force_master
and (request.method in ("GET", "OPTIONS") or any(r.match(method_path) for r in slave_paths))
):
LOG.debug("Using %s database for: %s", db_session.c2c_ro_bind.c2c_name, method_path)
session.bind = db_session.c2c_ro_bind
LOG.debug(
"Using %s database for: %s",
db_session.c2c_ro_bind.c2c_name, # type: ignore
method_path,
)
session.bind = db_session.c2c_ro_bind # type: ignore
else:
LOG.debug("Using %s database for: %s", db_session.c2c_rw_bind.c2c_name, method_path)
session.bind = db_session.c2c_rw_bind
LOG.debug(
"Using %s database for: %s",
db_session.c2c_rw_bind.c2c_name, # type: ignore
method_path,
)
session.bind = db_session.c2c_rw_bind # type: ignore

try:
return handler(request)
Expand All @@ -193,7 +200,7 @@ def db_chooser_tween(request: pyramid.request.Request) -> Any:
config.add_tween("c2cwsgiutils.db.tweens." + name, over="pyramid_tm.tm_tween_factory")


class SessionFactory(sessionmaker): # type: ignore
class SessionFactory(sessionmaker[sqlalchemy.orm.Session]): # pylint: disable=unsubscriptable-object
"""The custom session factory that manage the read only and read write sessions."""

def __init__(
Expand All @@ -213,18 +220,18 @@ def __init__(

def engine_name(self, readwrite: bool) -> str:
if readwrite:
return cast(str, self.rw_engine.c2c_name)
return cast(str, self.ro_engine.c2c_name)
return cast(str, self.rw_engine.c2c_name) # type: ignore
return cast(str, self.ro_engine.c2c_name) # type: ignore

def __call__(
def __call__( # type: ignore
self, request: Optional[pyramid.request.Request], readwrite: Optional[bool] = None, **local_kw: Any
) -> sqlalchemy.orm.Session:
) -> sqlalchemy.orm.scoped_session[sqlalchemy.orm.Session]:
if readwrite is not None:
if readwrite and not force_readonly:
LOG.debug("Using %s database", self.rw_engine.c2c_name)
LOG.debug("Using %s database", self.rw_engine.c2c_name) # type: ignore
self.configure(bind=self.rw_engine)
else:
LOG.debug("Using %s database", self.ro_engine.c2c_name)
LOG.debug("Using %s database", self.ro_engine.c2c_name) # type: ignore
self.configure(bind=self.ro_engine)
else:
assert request is not None
Expand All @@ -237,12 +244,12 @@ def __call__(
or any(r.match(method_path) for r in self.slave_paths)
)
):
LOG.debug("Using %s database for: %s", self.ro_engine.c2c_name, method_path)
LOG.debug("Using %s database for: %s", self.ro_engine.c2c_name, method_path) # type: ignore
self.configure(bind=self.ro_engine)
else:
LOG.debug("Using %s database for: %s", self.rw_engine.c2c_name, method_path)
LOG.debug("Using %s database for: %s", self.rw_engine.c2c_name, method_path) # type: ignore
self.configure(bind=self.rw_engine)
return super().__call__(**local_kw)
return super().__call__(**local_kw) # type: ignore


def get_engine(
Expand All @@ -252,15 +259,17 @@ def get_engine(
return engine_from_config(settings, prefix)


def get_session_factory(engine: sqlalchemy.engine.Engine) -> sessionmaker:
def get_session_factory(
engine: sqlalchemy.engine.Engine,
) -> sessionmaker[sqlalchemy.orm.Session]: # pylint: disable=unsubscriptable-object
"""Get the session factory from the engine."""
factory = sessionmaker()
factory.configure(bind=engine)
return factory


def get_tm_session(
session_factory: sessionmaker,
session_factory: sessionmaker[sqlalchemy.orm.Session], # pylint: disable=unsubscriptable-object
transaction_manager: transaction.TransactionManager,
) -> sqlalchemy.orm.Session:
"""
Expand Down Expand Up @@ -319,6 +328,7 @@ def get_tm_session(
request = dbsession.info["request"]
"""
dbsession = session_factory()
assert isinstance(dbsession, sqlalchemy.orm.Session)
zope.sqlalchemy.register(dbsession, transaction_manager=transaction_manager)
return dbsession

Expand All @@ -327,7 +337,7 @@ def get_tm_session_pyramid(
session_factory: SessionFactory,
transaction_manager: transaction.TransactionManager,
request: pyramid.request.Request,
) -> sqlalchemy.orm.Session:
) -> sqlalchemy.orm.scoped_session[sqlalchemy.orm.Session]:
"""
Get a ``sqlalchemy.orm.Session`` instance backed by a transaction.
Expand Down Expand Up @@ -367,13 +377,13 @@ def init(
dbengine = settings.get("dbengine")
if not dbengine:
rw_engine = get_engine(settings, master_prefix + ".")
rw_engine.c2c_name = master_prefix
rw_engine.c2c_name = master_prefix # type: ignore

# Setup a slave DB connection and add a tween to use it.
if slave_prefix and settings[master_prefix + ".url"] != settings.get(slave_prefix + ".url"):
LOG.info("Using a slave DB for reading %s", master_prefix)
ro_engine = get_engine(config.get_settings(), slave_prefix + ".")
ro_engine.c2c_name = slave_prefix
ro_engine.c2c_name = slave_prefix # type: ignore
else:
ro_engine = rw_engine
else:
Expand All @@ -383,12 +393,14 @@ def init(
config.registry["dbsession_factory"] = session_factory

# make request.dbsession available for use in Pyramid
def dbsession(request: pyramid.request.Request) -> sqlalchemy.orm.Session:
def dbsession(request: pyramid.request.Request) -> sqlalchemy.orm.scoped_session[sqlalchemy.orm.Session]:
# hook to share the dbsession fixture in testing
dbsession = request.environ.get("app.dbsession")
if dbsession is None:
# request.tm is the transaction manager used by pyramid_tm
dbsession = get_tm_session_pyramid(session_factory, request.tm, request=request)
assert dbsession is not None
assert isinstance(dbsession, sqlalchemy.orm.scoped_session)
return dbsession

config.add_request_method(dbsession, reify=True)
Expand Down
9 changes: 5 additions & 4 deletions c2cwsgiutils/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,10 +107,11 @@ def _include_dev_details(request: pyramid.request.Request) -> bool:
def _integrity_error(
exception: sqlalchemy.exc.StatementError, request: pyramid.request.Request
) -> pyramid.response.Response:
def reduce_info_sent(e: sqlalchemy.exc.StatementError) -> None:
# remove details (SQL statement and links to SQLAlchemy) from the error
e.statement = None
e.code = None
def reduce_info_sent(e: Exception) -> None:
if isinstance(e, sqlalchemy.exc.StatementError):
# remove details (SQL statement and links to SQLAlchemy) from the error
e.statement = None
e.code = None

return _do_error(request, 400, exception, reduce_info_sent=reduce_info_sent)

Expand Down
Loading

0 comments on commit 079d8f0

Please sign in to comment.