Skip to content

Commit

Permalink
Fix Erroneous Audit Info Update (#127)
Browse files Browse the repository at this point in the history
* implement SqlaEventHandler.pause ctx manager

* pause event handlers during re-setting of default version

* adjust tests
  • Loading branch information
meksor authored Oct 24, 2024
1 parent 2df83c5 commit 2785183
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 22 deletions.
38 changes: 29 additions & 9 deletions ixmp4/data/db/events.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, Mapping, Sequence, cast

from sqlalchemy import Connection, event, sql
Expand All @@ -20,15 +21,34 @@ class SqlaEventHandler(object):

def __init__(self, backend: "SqlAlchemyBackend") -> None:
self.backend = backend
event.listen(
self.backend.session, "do_orm_execute", self.receive_do_orm_execute
)
event.listen(
base.BaseModel, "before_insert", self.receive_before_insert, propagate=True
)
event.listen(
base.BaseModel, "before_update", self.receive_before_update, propagate=True
)
self.listeners = [
((backend.session, "do_orm_execute", self.receive_do_orm_execute), {}),
(
(base.BaseModel, "before_insert", self.receive_before_insert),
{"propagate": True},
),
(
(base.BaseModel, "before_update", self.receive_before_update),
{"propagate": True},
),
]
self.add_listeners()

def add_listeners(self):
for args, kwargs in self.listeners:
event.listen(*args, **kwargs)

def remove_listeners(self):
for args, kwargs in self.listeners:
if event.contains(*args):
event.remove(*args)

@contextmanager
def pause(self):
"""Temporarily removes all event listeners for the enclosed scope."""
self.remove_listeners()
yield
self.add_listeners()

def set_logger(self, state):
self.logger = logging.getLogger(__name__ + "." + str(id(state)))
Expand Down
8 changes: 6 additions & 2 deletions ixmp4/data/db/run/repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,12 +149,16 @@ def set_as_default_version(self, id: int) -> None:
.where(
Run.model__id == run.model__id,
Run.scenario__id == run.scenario__id,
Run.is_default,
)
.values(is_default=False)
)

self.session.execute(exc)
self.session.commit()
with self.backend.event_handler.pause():
# we dont want to trigger the
# updated_at fields for this operation.
self.session.execute(exc)
self.session.commit()

run.is_default = True
self.session.commit()
Expand Down
11 changes: 0 additions & 11 deletions tests/data/test_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,14 +110,3 @@ def drop_audit_info(self, df):
inplace=True,
columns=["created_by", "created_at", "updated_by", "updated_at"],
)

def test_audit_info(self, platform: ixmp4.Platform):
run = platform.backend.runs.create("Model", "Scenario")
platform.backend.runs.set_as_default_version(run.id)
platform.backend.runs.create("Model", "Scenario")

runs = platform.backend.runs.tabulate(default_only=False)
assert (runs["created_by"] == "@unknown").all()
# was updated by set_as_default_version
assert runs["updated_by"][0] == "@unknown"
assert runs["updated_by"][1] is None
25 changes: 25 additions & 0 deletions tests/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,31 @@ def test_guards(self, sqlite_platform: ixmp4.Platform, user, truths):
with pytest.raises(Forbidden):
run.meta = {"meta": "test"}

def test_run_audit_info(self, db_platform: ixmp4.Platform):
backend = cast(SqlAlchemyBackend, db_platform.backend)

test_user = User(username="test_audit", is_verified=True, is_superuser=True)

run1 = backend.runs.create("Model 1", "Scenario 1")

backend.runs.create("Model 1", "Scenario 1")
backend.runs.set_as_default_version(run1.id)

with backend.auth(test_user, self.mock_manager, self.TEST_PLATFORMS[0]):
run3 = backend.runs.create("Model 1", "Scenario 1")
backend.runs.set_as_default_version(run3.id)

runs = backend.runs.tabulate(default_only=False)
assert runs["created_by"][0] == "@unknown"
assert runs["created_by"][1] == "@unknown"
assert runs["created_by"][2] == "test_audit"

# run1 was updated by set_as_default_version
# run2 was not
assert runs["updated_by"][0] == "@unknown"
assert runs["updated_by"][1] is None
assert runs["updated_by"][2] == "test_audit"

@pytest.mark.parametrize(
"model, platform_info, access",
[
Expand Down

0 comments on commit 2785183

Please sign in to comment.