Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update dependency SQLAlchemy to v2 (master) #1726

Merged
merged 2 commits into from
Feb 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .prospector.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ pycodestyle:
disable:
- E203 # Whitespace before ':', duplicated with black, with error in array
- E722 # do not use bare 'except', duplicated with pylint
- E261 # at least two spaces before inline comment, duplicated with black

pydocstyle:
disable:
Expand Down
75 changes: 44 additions & 31 deletions c2cwsgiutils/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def setup_session(
force_master: Optional[Iterable[str]] = None,
force_slave: Optional[Iterable[str]] = None,
) -> Tuple[
Union[sqlalchemy.orm.Session, sqlalchemy.orm.scoped_session],
Union[sqlalchemy.orm.Session, sqlalchemy.orm.scoped_session[sqlalchemy.orm.Session]],
sbrunner marked this conversation as resolved.
Show resolved Hide resolved
sqlalchemy.engine.Engine,
sqlalchemy.engine.Engine,
]:
Expand Down Expand Up @@ -67,7 +67,7 @@ def setup_session(
slave_prefix = master_prefix
settings = config.registry.settings
rw_engine = sqlalchemy.engine_from_config(settings, master_prefix + ".")
rw_engine.c2c_name = master_prefix
rw_engine.c2c_name = master_prefix # type: ignore
sbrunner marked this conversation as resolved.
Show resolved Hide resolved
factory = sqlalchemy.orm.sessionmaker(bind=rw_engine)
register(factory)
db_session = sqlalchemy.orm.scoped_session(factory)
Expand All @@ -76,14 +76,14 @@ def setup_session(
if settings[master_prefix + ".url"] != settings.get(slave_prefix + ".url"):
LOG.info("Using a slave DB for reading %s", master_prefix)
ro_engine = sqlalchemy.engine_from_config(config.get_settings(), slave_prefix + ".")
ro_engine.c2c_name = slave_prefix
ro_engine.c2c_name = slave_prefix # type: ignore
tween_name = master_prefix.replace(".", "_")
_add_tween(config, tween_name, db_session, force_master, force_slave)
else:
ro_engine = rw_engine

db_session.c2c_rw_bind = rw_engine
db_session.c2c_ro_bind = ro_engine
db_session.c2c_rw_bind = rw_engine # type: ignore
db_session.c2c_ro_bind = ro_engine # type: ignore
return db_session, rw_engine, ro_engine


Expand All @@ -95,7 +95,7 @@ def create_session(
force_master: Optional[Iterable[str]] = None,
force_slave: Optional[Iterable[str]] = None,
**engine_config: Any,
) -> Union[sqlalchemy.orm.Session, sqlalchemy.orm.scoped_session]:
) -> Union[sqlalchemy.orm.Session, sqlalchemy.orm.scoped_session[sqlalchemy.orm.Session]]:
"""
Create a SQLAlchemy session.

Expand Down Expand Up @@ -133,21 +133,21 @@ def create_session(
LOG.info("Using a slave DB for reading %s", name)
ro_engine = sqlalchemy.create_engine(slave_url, **engine_config)
_add_tween(config, name, db_session, force_master, force_slave)
rw_engine.c2c_name = name + "_master"
ro_engine.c2c_name = name + "_slave"
rw_engine.c2c_name = name + "_master" # type: ignore
ro_engine.c2c_name = name + "_slave" # type: ignore
else:
rw_engine.c2c_name = name
rw_engine.c2c_name = name # type: ignore
ro_engine = rw_engine

db_session.c2c_rw_bind = rw_engine
db_session.c2c_ro_bind = ro_engine
db_session.c2c_rw_bind = rw_engine # type: ignore
db_session.c2c_ro_bind = ro_engine # type: ignore
return db_session


def _add_tween(
config: pyramid.config.Configurator,
name: str,
db_session: Union[sqlalchemy.orm.Session, sqlalchemy.orm.scoped_session],
db_session: sqlalchemy.orm.scoped_session[sqlalchemy.orm.Session],
force_master: Optional[Iterable[str]],
force_slave: Optional[Iterable[str]],
) -> None:
Expand Down Expand Up @@ -176,11 +176,19 @@ def db_chooser_tween(request: pyramid.request.Request) -> Any:
not has_force_master
and (request.method in ("GET", "OPTIONS") or any(r.match(method_path) for r in slave_paths))
):
LOG.debug("Using %s database for: %s", db_session.c2c_ro_bind.c2c_name, method_path)
session.bind = db_session.c2c_ro_bind
LOG.debug(
"Using %s database for: %s",
db_session.c2c_ro_bind.c2c_name, # type: ignore
method_path,
)
session.bind = db_session.c2c_ro_bind # type: ignore
else:
LOG.debug("Using %s database for: %s", db_session.c2c_rw_bind.c2c_name, method_path)
session.bind = db_session.c2c_rw_bind
LOG.debug(
"Using %s database for: %s",
db_session.c2c_rw_bind.c2c_name, # type: ignore
method_path,
)
session.bind = db_session.c2c_rw_bind # type: ignore

try:
return handler(request)
Expand All @@ -193,7 +201,7 @@ def db_chooser_tween(request: pyramid.request.Request) -> Any:
config.add_tween("c2cwsgiutils.db.tweens." + name, over="pyramid_tm.tm_tween_factory")


class SessionFactory(sessionmaker): # type: ignore
class SessionFactory(sessionmaker[sqlalchemy.orm.Session]): # pylint: disable=unsubscriptable-object
"""The custom session factory that manage the read only and read write sessions."""

def __init__(
Expand All @@ -213,18 +221,18 @@ def __init__(

def engine_name(self, readwrite: bool) -> str:
if readwrite:
return cast(str, self.rw_engine.c2c_name)
return cast(str, self.ro_engine.c2c_name)
return cast(str, self.rw_engine.c2c_name) # type: ignore
return cast(str, self.ro_engine.c2c_name) # type: ignore

def __call__(
def __call__( # type: ignore
self, request: Optional[pyramid.request.Request], readwrite: Optional[bool] = None, **local_kw: Any
) -> sqlalchemy.orm.Session:
) -> sqlalchemy.orm.scoped_session[sqlalchemy.orm.Session]:
if readwrite is not None:
if readwrite and not force_readonly:
LOG.debug("Using %s database", self.rw_engine.c2c_name)
LOG.debug("Using %s database", self.rw_engine.c2c_name) # type: ignore
self.configure(bind=self.rw_engine)
else:
LOG.debug("Using %s database", self.ro_engine.c2c_name)
LOG.debug("Using %s database", self.ro_engine.c2c_name) # type: ignore
self.configure(bind=self.ro_engine)
else:
assert request is not None
Expand All @@ -237,12 +245,12 @@ def __call__(
or any(r.match(method_path) for r in self.slave_paths)
)
):
LOG.debug("Using %s database for: %s", self.ro_engine.c2c_name, method_path)
LOG.debug("Using %s database for: %s", self.ro_engine.c2c_name, method_path) # type: ignore
self.configure(bind=self.ro_engine)
else:
LOG.debug("Using %s database for: %s", self.rw_engine.c2c_name, method_path)
LOG.debug("Using %s database for: %s", self.rw_engine.c2c_name, method_path) # type: ignore
self.configure(bind=self.rw_engine)
return super().__call__(**local_kw)
return super().__call__(**local_kw) # type: ignore


def get_engine(
Expand All @@ -252,15 +260,17 @@ def get_engine(
return engine_from_config(settings, prefix)


def get_session_factory(engine: sqlalchemy.engine.Engine) -> sessionmaker:
def get_session_factory(
engine: sqlalchemy.engine.Engine,
) -> sessionmaker[sqlalchemy.orm.Session]: # pylint: disable=unsubscriptable-object
"""Get the session factory from the engine."""
factory = sessionmaker()
factory.configure(bind=engine)
return factory


def get_tm_session(
session_factory: sessionmaker,
session_factory: sessionmaker[sqlalchemy.orm.Session], # pylint: disable=unsubscriptable-object
transaction_manager: transaction.TransactionManager,
) -> sqlalchemy.orm.Session:
"""
Expand Down Expand Up @@ -319,6 +329,7 @@ def get_tm_session(
request = dbsession.info["request"]
"""
dbsession = session_factory()
assert isinstance(dbsession, sqlalchemy.orm.Session), type(dbsession)
zope.sqlalchemy.register(dbsession, transaction_manager=transaction_manager)
return dbsession

Expand All @@ -327,7 +338,7 @@ def get_tm_session_pyramid(
session_factory: SessionFactory,
transaction_manager: transaction.TransactionManager,
request: pyramid.request.Request,
) -> sqlalchemy.orm.Session:
) -> sqlalchemy.orm.scoped_session[sqlalchemy.orm.Session]:
"""
Get a ``sqlalchemy.orm.Session`` instance backed by a transaction.

Expand Down Expand Up @@ -367,13 +378,13 @@ def init(
dbengine = settings.get("dbengine")
if not dbengine:
rw_engine = get_engine(settings, master_prefix + ".")
rw_engine.c2c_name = master_prefix
rw_engine.c2c_name = master_prefix # type: ignore

# Setup a slave DB connection and add a tween to use it.
if slave_prefix and settings[master_prefix + ".url"] != settings.get(slave_prefix + ".url"):
LOG.info("Using a slave DB for reading %s", master_prefix)
ro_engine = get_engine(config.get_settings(), slave_prefix + ".")
ro_engine.c2c_name = slave_prefix
ro_engine.c2c_name = slave_prefix # type: ignore
else:
ro_engine = rw_engine
else:
Expand All @@ -389,6 +400,8 @@ def dbsession(request: pyramid.request.Request) -> sqlalchemy.orm.Session:
if dbsession is None:
# request.tm is the transaction manager used by pyramid_tm
dbsession = get_tm_session_pyramid(session_factory, request.tm, request=request)
assert dbsession is not None
assert isinstance(dbsession, sqlalchemy.orm.Session), type(dbsession)
return dbsession

config.add_request_method(dbsession, reify=True)
Expand Down
9 changes: 5 additions & 4 deletions c2cwsgiutils/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,10 +107,11 @@ def _include_dev_details(request: pyramid.request.Request) -> bool:
def _integrity_error(
exception: sqlalchemy.exc.StatementError, request: pyramid.request.Request
) -> pyramid.response.Response:
def reduce_info_sent(e: sqlalchemy.exc.StatementError) -> None:
# remove details (SQL statement and links to SQLAlchemy) from the error
e.statement = None
e.code = None
def reduce_info_sent(e: Exception) -> None:
if isinstance(e, sqlalchemy.exc.StatementError):
sbrunner marked this conversation as resolved.
Show resolved Hide resolved
# remove details (SQL statement and links to SQLAlchemy) from the error
e.statement = None
e.code = None

return _do_error(request, 400, exception, reduce_info_sent=reduce_info_sent)

Expand Down
57 changes: 34 additions & 23 deletions c2cwsgiutils/health_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import requests
import sqlalchemy.engine
import sqlalchemy.orm
import sqlalchemy.sql
from pyramid.httpexceptions import HTTPNotFound

import c2cwsgiutils.db
Expand Down Expand Up @@ -58,7 +59,7 @@ class _Binding:
def name(self) -> str:
raise NotImplementedError()

def __enter__(self) -> sqlalchemy.orm.Session:
def __enter__(self) -> sqlalchemy.orm.scoped_session[sqlalchemy.orm.Session]:
raise NotImplementedError()

def __exit__(
Expand All @@ -78,21 +79,23 @@ def __init__(self, session: c2cwsgiutils.db.SessionFactory, readwrite: bool):
def name(self) -> str:
return self.session.engine_name(self.readwrite)

def __enter__(self) -> sqlalchemy.orm.Session:
def __enter__(self) -> sqlalchemy.orm.scoped_session[sqlalchemy.orm.Session]:
return self.session(None, self.readwrite)


class _OldBinding(_Binding):
def __init__(self, session: sqlalchemy.orm.scoping.scoped_session, engine: sqlalchemy.engine.Engine):
def __init__(
self, session: sqlalchemy.orm.scoped_session[sqlalchemy.orm.Session], engine: sqlalchemy.engine.Engine
):
self.session = session
self.engine = engine
self.prev_bind = None

def name(self) -> str:
return cast(str, self.engine.c2c_name)
return cast(str, self.engine.c2c_name) # type: ignore

def __enter__(self) -> sqlalchemy.orm.Session:
self.prev_bind = self.session.bind
def __enter__(self) -> sqlalchemy.orm.scoped_session[sqlalchemy.orm.Session]:
self.prev_bind = self.session.bind # type: ignore
self.session.bind = self.engine
return self.session

Expand All @@ -107,7 +110,7 @@ def __exit__(


def _get_binding_class(
session: Union[sqlalchemy.orm.scoping.scoped_session, c2cwsgiutils.db.SessionFactory],
session: Union[sqlalchemy.orm.scoped_session[sqlalchemy.orm.Session], c2cwsgiutils.db.SessionFactory],
ro_engin: sqlalchemy.engine.Engine,
rw_engin: sqlalchemy.engine.Engine,
readwrite: bool,
Expand All @@ -119,15 +122,15 @@ def _get_binding_class(


def _get_bindings(
session: Union[sqlalchemy.orm.scoping.scoped_session, c2cwsgiutils.db.SessionFactory],
session: Union[sqlalchemy.orm.scoped_session[sqlalchemy.orm.Session], c2cwsgiutils.db.SessionFactory],
engine_type: EngineType,
) -> List[sqlalchemy.engine.Engine]:
) -> List[_Binding]:
if isinstance(session, c2cwsgiutils.db.SessionFactory):
ro_engin = session.ro_engine
rw_engin = session.rw_engine
else:
ro_engin = session.c2c_ro_bind
rw_engin = session.c2c_rw_bind
ro_engin = session.c2c_ro_bind # type: ignore
rw_engin = session.c2c_rw_bind # type: ignore

if rw_engin == ro_engin:
engine_type = EngineType.WRITE_ONLY
Expand Down Expand Up @@ -192,8 +195,8 @@ def __init__(self, config: pyramid.config.Configurator) -> None:

def add_db_session_check(
self,
session: Union[sqlalchemy.orm.scoping.scoped_session, c2cwsgiutils.db.SessionFactory],
query_cb: Optional[Callable[[sqlalchemy.orm.scoping.scoped_session], Any]] = None,
session: Union[sqlalchemy.orm.scoped_session[sqlalchemy.orm.Session], c2cwsgiutils.db.SessionFactory],
query_cb: Optional[Callable[[sqlalchemy.orm.scoped_session[sqlalchemy.orm.Session]], Any]] = None,
at_least_one_model: Optional[object] = None,
level: int = 1,
engine_type: EngineType = EngineType.READ_AND_WRITE,
Expand All @@ -220,7 +223,7 @@ def add_db_session_check(

def add_alembic_check(
self,
session: Union[sqlalchemy.orm.scoping.scoped_session, c2cwsgiutils.db.SessionFactory],
session: sqlalchemy.orm.scoped_session[sqlalchemy.orm.Session],
alembic_ini_path: str,
level: int = 2,
name: str = "alembic",
Expand Down Expand Up @@ -249,17 +252,21 @@ def add_alembic_check(

if version_schema is None:
version_schema = config.get(name, "version_table_schema", fallback="public")
assert version_schema

if version_table is None:
version_table = config.get(name, "version_table", fallback="alembic_version")
assert version_table

class _Check:
def __init__(self, session: sqlalchemy.orm.scoping.scoped_session) -> None:
def __init__(self, session: sqlalchemy.orm.scoped_session[sqlalchemy.orm.Session]) -> None:
self.session = session

def __call__(self, request: pyramid.request.Request) -> str:
assert version_schema
assert version_table
for binding in _get_bindings(self.session, EngineType.READ_AND_WRITE):
with binding as session:
with binding as binded_session:
if stats.USE_TAGS:
key = ["sql", "manual", "health_check", "alembic"]
tags: Optional[Dict[str, str]] = {"conf": alembic_ini_path, "con": binding.name()}
Expand All @@ -274,11 +281,15 @@ def __call__(self, request: pyramid.request.Request) -> str:
]
tags = None
with stats.timer_context(key, tags):
quote = session.bind.dialect.identifier_preparer.quote
(actual_version,) = session.execute(
"SELECT version_num FROM " # nosec
f"{quote(version_schema)}.{quote(version_table)}"
result = binded_session.execute(
sqlalchemy.text(
"SELECT version_num FROM " # nosec
f"{sqlalchemy.sql.quoted_name(version_schema, True)}."
f"{sqlalchemy.sql.quoted_name(version_table, True)}"
)
).fetchone()
assert result is not None
(actual_version,) = result
if stats.USE_TAGS:
stats.increment_counter(
["alembic_version"], 1, tags={"version": actual_version, "name": name}
Expand Down Expand Up @@ -492,7 +503,7 @@ def _run_one(
@staticmethod
def _create_db_engine_check(
binding: _Binding,
query_cb: Callable[[sqlalchemy.orm.scoping.scoped_session], None],
query_cb: Callable[[sqlalchemy.orm.scoped_session[sqlalchemy.orm.Session]], None],
) -> Tuple[str, Callable[[pyramid.request.Request], None]]:
def check(request: pyramid.request.Request) -> None:
with binding as session:
Expand All @@ -508,8 +519,8 @@ def check(request: pyramid.request.Request) -> None:
return "db_engine_" + binding.name(), check

@staticmethod
def _at_least_one(model: Any) -> Callable[[sqlalchemy.orm.scoping.scoped_session], Any]:
def query(session: sqlalchemy.orm.scoping.scoped_session) -> None:
def _at_least_one(model: Any) -> Callable[[sqlalchemy.orm.scoped_session[sqlalchemy.orm.Session]], Any]:
def query(session: sqlalchemy.orm.scoped_session[sqlalchemy.orm.Session]) -> None:
result = session.query(model).first()
if result is None:
raise HTTPNotFound(model.__name__ + " record not found")
Expand Down
Loading