From f4660bbfd80d71d869745ceab63b581d50b4666f Mon Sep 17 00:00:00 2001 From: Cameron Hurst Date: Wed, 18 Dec 2019 12:41:01 -0500 Subject: [PATCH 01/12] feat: added refactored for integration with alembic Reworked location of some function calls to support multple purposes Added migration calls to be manually added into alembic migrations --- postgresql_audit/base.py | 96 ++++++---------------------------- postgresql_audit/migrations.py | 23 ++++++++ postgresql_audit/utils.py | 75 ++++++++++++++++++++++++++ 3 files changed, 115 insertions(+), 79 deletions(-) create mode 100644 postgresql_audit/utils.py diff --git a/postgresql_audit/base.py b/postgresql_audit/base.py index 4f3c50e..01742f1 100644 --- a/postgresql_audit/base.py +++ b/postgresql_audit/base.py @@ -1,7 +1,6 @@ -import os -import string import warnings from contextlib import contextmanager +from functools import partial from weakref import WeakSet import sqlalchemy as sa @@ -18,7 +17,9 @@ from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy_utils import get_class_by_table -HERE = os.path.dirname(os.path.abspath(__file__)) +from postgresql_audit.utils import render_tmpl, StatementExecutor, create_audit_table, create_operators, \ + build_register_table_query + cached_statements = {} @@ -29,23 +30,6 @@ class ImproperlyConfigured(Exception): class ClassNotVersioned(Exception): pass - -class StatementExecutor(object): - def __init__(self, stmt): - self.stmt = stmt - - def __call__(self, target, bind, **kwargs): - tx = bind.begin() - bind.execute(self.stmt) - tx.commit() - - -def read_file(file_): - with open(os.path.join(HERE, file_)) as f: - s = f.read() - return s - - def assign_actor(base, cls, actor_cls): if hasattr(cls, 'actor_id'): return @@ -194,10 +178,11 @@ def __init__( ), ) self.schema_name = schema_name + self.use_statement_level_triggers = use_statement_level_triggers self.table_listeners = self.get_table_listeners() self.pending_classes = WeakSet() self.cached_ddls = {} - self.use_statement_level_triggers = use_statement_level_triggers + def get_transaction_values(self): return self.values @@ -214,70 +199,28 @@ def disable(self, session): "SET LOCAL postgresql_audit.enable_versioning = 'true'" ) - def render_tmpl(self, tmpl_name): - file_contents = read_file( - 'templates/{}'.format(tmpl_name) - ).replace('%', '%%').replace('$$', '$$$$') - tmpl = string.Template(file_contents) - context = dict(schema_name=self.schema_name) - - if self.schema_name is None: - context['schema_prefix'] = '' - context['revoke_cmd'] = '' - else: - context['schema_prefix'] = '{}.'.format(self.schema_name) - context['revoke_cmd'] = ( - 'REVOKE ALL ON {schema_prefix}activity FROM public;' - ).format(**context) - - temp = tmpl.substitute(**context) - return temp - - def create_operators(self, target, bind, **kwargs): - if bind.dialect.server_version_info < (9, 5, 0): - StatementExecutor(self.render_tmpl('operators_pre95.sql'))( - target, bind, **kwargs - ) - if bind.dialect.server_version_info < (9, 6, 0): - StatementExecutor(self.render_tmpl('operators_pre96.sql'))( - target, bind, **kwargs - ) - if bind.dialect.server_version_info < (10, 0): - operators_template = self.render_tmpl('operators_pre100.sql') - StatementExecutor(operators_template)(target, bind, **kwargs) - operators_template = self.render_tmpl('operators.sql') - StatementExecutor(operators_template)(target, bind, **kwargs) - - def create_audit_table(self, target, bind, **kwargs): - sql = '' - if ( - self.use_statement_level_triggers and - bind.dialect.server_version_info >= (10, 0) - ): - sql += self.render_tmpl('create_activity_stmt_level.sql') - sql += self.render_tmpl('audit_table_stmt_level.sql') - else: - sql += self.render_tmpl('create_activity_row_level.sql') - sql += self.render_tmpl('audit_table_row_level.sql') - StatementExecutor(sql)(target, bind, **kwargs) - def get_table_listeners(self): listeners = {'transaction': []} listeners['activity'] = [ ('after_create', sa.schema.DDL( - self.render_tmpl('jsonb_change_key_name.sql') + render_tmpl('jsonb_change_key_name.sql', self.schema_name) )), - ('after_create', self.create_audit_table), - ('after_create', self.create_operators) + ('after_create', partial( + create_audit_table, + schema_name=self.schema_name, + use_statement_level_triggers=self.use_statement_level_triggers + ) + ), + ('after_create', partial(create_operators, schema_name=self.schema_name)) ] if self.schema_name is not None: listeners['transaction'] = [ ('before_create', sa.schema.DDL( - self.render_tmpl('create_schema.sql') + render_tmpl('create_schema.sql', self.schema_name) )), ('after_drop', sa.schema.DDL( - self.render_tmpl('drop_schema.sql') + render_tmpl('drop_schema.sql', self.schema_name) )), ] return listeners @@ -294,12 +237,7 @@ def audit_table(self, table, exclude_columns=None): ) ) args.append(array(exclude_columns)) - - if self.schema_name is None: - func = sa.func.audit_table - else: - func = getattr(getattr(sa.func, self.schema_name), 'audit_table') - query = sa.select([func(*args)]) + query = build_register_table_query(self.schema_name, *args) if query not in cached_statements: cached_statements[query] = StatementExecutor(query) listener = (table, 'after_create', cached_statements[query]) diff --git a/postgresql_audit/migrations.py b/postgresql_audit/migrations.py index 6fab915..56adab6 100644 --- a/postgresql_audit/migrations.py +++ b/postgresql_audit/migrations.py @@ -1,6 +1,8 @@ import sqlalchemy as sa from sqlalchemy.dialects.postgresql import JSONB +from postgresql_audit.utils import render_tmpl, create_audit_table, create_operators, \ + build_register_table_query from .expressions import jsonb_change_key_name @@ -16,6 +18,27 @@ def get_activity_table(schema=None): schema=schema, ) +def init_activity_table_triggers(conn, schema_name = None, use_statement_level_triggers=True): + conn.execute(render_tmpl('jsonb_change_key_name.sql', schema_name)) + create_audit_table(None, conn, schema_name, use_statement_level_triggers) + create_operators(None, conn, schema_name) + + if schema_name: + conn.execute(render_tmpl('create_schema.sql', schema_name)) + +def rollback_create_transaction(conn, schema_name=None): + if schema_name: + conn.execute(render_tmpl('drop_schema.sql', schema_name)) + +def init_before_create_transaction(conn, schema_name=None): + if schema_name: + conn.execute(render_tmpl('create_schema.sql', schema_name)) + + +def register_table(conn, table_name, exclude_columns, schema_name=None): + sql = build_register_table_query(schema_name, table_name, exclude_columns) + conn.execute(sql) + def alter_column(conn, table, column_name, func, schema=None): """ diff --git a/postgresql_audit/utils.py b/postgresql_audit/utils.py new file mode 100644 index 0000000..127e694 --- /dev/null +++ b/postgresql_audit/utils.py @@ -0,0 +1,75 @@ +import os +import string +import sqlalchemy as sa + +HERE = os.path.dirname(os.path.abspath(__file__)) + + +class StatementExecutor(object): + def __init__(self, stmt): + self.stmt = stmt + + def __call__(self, target, bind, **kwargs): + tx = bind.begin() + bind.execute(self.stmt) + tx.commit() + +def read_file(file_): + with open(os.path.join(HERE, file_)) as f: + s = f.read() + return s + +def render_tmpl(tmpl_name, schema_name=None): + file_contents = read_file( + 'templates/{}'.format(tmpl_name) + ).replace('%', '%%').replace('$$', '$$$$') + tmpl = string.Template(file_contents) + context = dict(schema_name=schema_name) + + if schema_name is None: + context['schema_prefix'] = '' + context['revoke_cmd'] = '' + else: + context['schema_prefix'] = '{}.'.format(schema_name) + context['revoke_cmd'] = ( + 'REVOKE ALL ON {schema_prefix}activity FROM public;' + ).format(**context) + + return tmpl.substitute(**context) + + +def create_operators(target, bind, schema_name, **kwargs): + if bind.dialect.server_version_info < (9, 5, 0): + StatementExecutor(render_tmpl('operators_pre95.sql', schema_name))( + target, bind, **kwargs + ) + if bind.dialect.server_version_info < (9, 6, 0): + StatementExecutor(render_tmpl('operators_pre96.sql', schema_name))( + target, bind, **kwargs + ) + if bind.dialect.server_version_info < (10, 0): + operators_template = render_tmpl('operators_pre100.sql', schema_name) + StatementExecutor(operators_template)(target, bind, **kwargs) + operators_template = render_tmpl('operators.sql', schema_name) + StatementExecutor(operators_template)(target, bind, **kwargs) + +def create_audit_table(target, bind, schema_name, use_statement_level_triggers, **kwargs): + sql = '' + if ( + use_statement_level_triggers and + bind.dialect.server_version_info >= (10, 0) + ): + sql += render_tmpl('create_activity_stmt_level.sql', schema_name) + sql += render_tmpl('audit_table_stmt_level.sql', schema_name) + else: + sql += render_tmpl('create_activity_row_level.sql', schema_name) + sql += render_tmpl('audit_table_row_level.sql', schema_name) + StatementExecutor(sql)(target, bind, **kwargs) + + +def build_register_table_query(schema_name, *args): + if schema_name is None: + func = sa.func.audit_table + else: + func = getattr(getattr(sa.func, schema_name), 'audit_table') + return sa.select([func(*args)]) From cb6cc901865a7b4c2fc222026fd5f3c7789c8ac2 Mon Sep 17 00:00:00 2001 From: Cameron Hurst Date: Thu, 19 Dec 2019 12:28:09 -0500 Subject: [PATCH 02/12] feat: Modularized the different aspects of VersioningManager Session related features are not under SessionManager There is a BaseVersioningManager for use with Alembic VersioningManager and FlaskVersioningManager remain api compatible --- postgresql_audit/base.py | 307 +++++++++++++++++++++++--------------- postgresql_audit/flask.py | 8 +- 2 files changed, 187 insertions(+), 128 deletions(-) diff --git a/postgresql_audit/base.py b/postgresql_audit/base.py index 01742f1..3b6758b 100644 --- a/postgresql_audit/base.py +++ b/postgresql_audit/base.py @@ -148,102 +148,22 @@ def convert_callables(values): } -class VersioningManager(object): - _actor_cls = None - - def __init__( - self, - actor_cls=None, - schema_name=None, - use_statement_level_triggers=True - ): - if actor_cls is not None: - self._actor_cls = actor_cls - self.values = {} +class SessionManager(object): + def __init__(self, transaction_cls, values=None): + self.transaction_cls = transaction_cls + self.values = values or {} + self._marked_transactions = set() self.listeners = ( - ( - orm.mapper, - 'instrument_class', - self.instrument_versioned_classes - ), - ( - orm.mapper, - 'after_configured', - self.configure_versioned_classes - ), ( orm.session.Session, 'before_flush', - self.receive_before_flush, + self.before_flush, ), ) - self.schema_name = schema_name - self.use_statement_level_triggers = use_statement_level_triggers - self.table_listeners = self.get_table_listeners() - self.pending_classes = WeakSet() - self.cached_ddls = {} - def get_transaction_values(self): return self.values - @contextmanager - def disable(self, session): - session.execute( - "SET LOCAL postgresql_audit.enable_versioning = 'false'" - ) - try: - yield - finally: - session.execute( - "SET LOCAL postgresql_audit.enable_versioning = 'true'" - ) - - def get_table_listeners(self): - listeners = {'transaction': []} - - listeners['activity'] = [ - ('after_create', sa.schema.DDL( - render_tmpl('jsonb_change_key_name.sql', self.schema_name) - )), - ('after_create', partial( - create_audit_table, - schema_name=self.schema_name, - use_statement_level_triggers=self.use_statement_level_triggers - ) - ), - ('after_create', partial(create_operators, schema_name=self.schema_name)) - ] - if self.schema_name is not None: - listeners['transaction'] = [ - ('before_create', sa.schema.DDL( - render_tmpl('create_schema.sql', self.schema_name) - )), - ('after_drop', sa.schema.DDL( - render_tmpl('drop_schema.sql', self.schema_name) - )), - ] - return listeners - - def audit_table(self, table, exclude_columns=None): - args = [table.name] - if exclude_columns: - for column in exclude_columns: - if column not in table.c: - raise ImproperlyConfigured( - "Could not configure versioning. Table '{}'' does " - "not have a column named '{}'.".format( - table.name, column - ) - ) - args.append(array(exclude_columns)) - query = build_register_table_query(self.schema_name, *args) - if query not in cached_statements: - cached_statements[query] = StatementExecutor(query) - listener = (table, 'after_create', cached_statements[query]) - if not sa.event.contains(*listener): - sa.event.listen(*listener) - def set_activity_values(self, session): dialect = session.bind.engine.dialect table = self.transaction_cls.__table__ @@ -303,40 +223,50 @@ def is_modified(self, obj_or_session): if hasattr(entity, '__versioned__') ) - def receive_before_flush(self, session, flush_context, instances): + def before_flush(self, session, flush_context, instances): + if session.transaction in self._marked_transactions: + return + if session.transaction: + self.add_entry_and_mark_transaction(session) + + def add_entry_and_mark_transaction(self, session): if self.is_modified(session): + self._marked_transactions.add(session.transaction) self.set_activity_values(session) - def instrument_versioned_classes(self, mapper, cls): - """ - Collect versioned class and add it to pending_classes list. - - :mapper mapper: SQLAlchemy mapper object - :cls cls: SQLAlchemy declarative class - """ - if hasattr(cls, '__versioned__') and cls not in self.pending_classes: - self.pending_classes.add(cls) + def attach_listeners(self): + for listener in self.listeners: + sa.event.listen(*listener) - def configure_versioned_classes(self): - """ - Configures all versioned classes that were collected during - instrumentation process. - """ - for cls in self.pending_classes: - self.audit_table(cls.__table__, cls.__versioned__.get('exclude')) - assign_actor(self.base, self.transaction_cls, self.actor_cls) + def remove_listeners(self): + for listener in self.listeners: + sa.event.remove(*listener) - def attach_table_listeners(self): - for values in self.table_listeners['transaction']: - sa.event.listen(self.transaction_cls.__table__, *values) - for values in self.table_listeners['activity']: - sa.event.listen(self.activity_cls.__table__, *values) +class BasicVersioningManager(object): + _actor_cls = None + _session_manager_factory = partial(SessionManager, values={}) - def remove_table_listeners(self): - for values in self.table_listeners['transaction']: - sa.event.remove(self.transaction_cls.__table__, *values) - for values in self.table_listeners['activity']: - sa.event.remove(self.activity_cls.__table__, *values) + def __init__( + self, + actor_cls=None, + session_manager_factory=None, + schema_name=None, + use_statement_level_triggers=True + ): + if actor_cls is not None: + self._actor_cls = actor_cls + if session_manager_factory is not None: + self._session_manager_factory = session_manager_factory + self.values = {} + self.listeners = ( + ( + orm.mapper, + 'after_configured', + self.after_configured + ), + ) + self.schema_name = schema_name + self.use_statement_level_triggers = use_statement_level_triggers @property def actor_cls(self): @@ -365,15 +295,8 @@ def actor_cls(self): ) return self._actor_cls - def attach_listeners(self): - self.attach_table_listeners() - for listener in self.listeners: - sa.event.listen(*listener) - - def remove_listeners(self): - self.remove_table_listeners() - for listener in self.listeners: - sa.event.remove(*listener) + def after_configured(self): + assign_actor(self.base, self.transaction_cls, self.actor_cls) def activity_model_factory(self, base, transaction_cls): class Activity(activity_base(base, self.schema_name, transaction_cls)): @@ -387,6 +310,28 @@ class Transaction(transaction_base(base, self.schema_name)): return Transaction + def attach_listeners(self): + for listener in self.listeners: + sa.event.listen(*listener) + self.session_manager.attach_listeners() + + def remove_listeners(self): + for listener in self.listeners: + sa.event.remove(*listener) + self.session_manager.remove_listeners() + + @contextmanager + def disable(self, session): + session.execute( + "SET LOCAL postgresql_audit.enable_versioning = 'false'" + ) + try: + yield + finally: + session.execute( + "SET LOCAL postgresql_audit.enable_versioning = 'true'" + ) + def init(self, base): self.base = base self.transaction_cls = self.transaction_model_factory(base) @@ -394,7 +339,123 @@ def init(self, base): base, self.transaction_cls ) + self.session_manager = self._session_manager_factory(self.transaction_cls) self.attach_listeners() +class VersioningManager(BasicVersioningManager): + def __init__( + self, + actor_cls=None, + session_manager_factory=None, + schema_name=None, + use_statement_level_triggers=True + ): + super().__init__( + actor_cls=actor_cls, + schema_name=schema_name, + use_statement_level_triggers=use_statement_level_triggers, + session_manager_factory=session_manager_factory + ) + self.listeners = ( + ( + orm.mapper, + 'instrument_class', + self.instrument_versioned_classes + ), + ( + orm.mapper, + 'after_configured', + self.configure_versioned_classes + ), + ) + self.table_listeners = self.get_table_listeners() + self.pending_classes = WeakSet() + self.cached_ddls = {} + + def get_table_listeners(self): + listeners = {'transaction': []} + + listeners['activity'] = [ + ('after_create', sa.schema.DDL( + render_tmpl('jsonb_change_key_name.sql', self.schema_name) + )), + ('after_create', partial( + create_audit_table, + schema_name=self.schema_name, + use_statement_level_triggers=self.use_statement_level_triggers + ) + ), + ('after_create', partial(create_operators, schema_name=self.schema_name)) + ] + if self.schema_name is not None: + listeners['transaction'] = [ + ('before_create', sa.schema.DDL( + render_tmpl('create_schema.sql', self.schema_name) + )), + ('after_drop', sa.schema.DDL( + render_tmpl('drop_schema.sql', self.schema_name) + )), + ] + return listeners + + def audit_table(self, table, exclude_columns=None): + args = [table.name] + if exclude_columns: + for column in exclude_columns: + if column not in table.c: + raise ImproperlyConfigured( + "Could not configure versioning. Table '{}'' does " + "not have a column named '{}'.".format( + table.name, column + ) + ) + args.append(array(exclude_columns)) + query = build_register_table_query(self.schema_name, *args) + if query not in cached_statements: + cached_statements[query] = StatementExecutor(query) + listener = (table, 'after_create', cached_statements[query]) + if not sa.event.contains(*listener): + sa.event.listen(*listener) + + def instrument_versioned_classes(self, mapper, cls): + """ + Collect versioned class and add it to pending_classes list. + + :mapper mapper: SQLAlchemy mapper object + :cls cls: SQLAlchemy declarative class + """ + if hasattr(cls, '__versioned__') and cls not in self.pending_classes: + self.pending_classes.add(cls) + + def configure_versioned_classes(self): + """ + Configures all versioned classes that were collected during + instrumentation process. + """ + for cls in self.pending_classes: + self.audit_table(cls.__table__, cls.__versioned__.get('exclude')) + assign_actor(self.base, self.transaction_cls, self.actor_cls) + + def attach_table_listeners(self): + for values in self.table_listeners['transaction']: + sa.event.listen(self.transaction_cls.__table__, *values) + for values in self.table_listeners['activity']: + sa.event.listen(self.activity_cls.__table__, *values) + + def remove_table_listeners(self): + for values in self.table_listeners['transaction']: + sa.event.remove(self.transaction_cls.__table__, *values) + for values in self.table_listeners['activity']: + sa.event.remove(self.activity_cls.__table__, *values) + + def attach_listeners(self): + self.attach_table_listeners() + super().attach_listeners() + + def remove_listeners(self): + self.remove_table_listeners() + super().remove_listeners() + + versioning_manager = VersioningManager() diff --git a/postgresql_audit/flask.py b/postgresql_audit/flask.py index 3ca1437..6ab23f8 100644 --- a/postgresql_audit/flask.py +++ b/postgresql_audit/flask.py @@ -6,12 +6,10 @@ from flask import g, request from flask.globals import _app_ctx_stack, _request_ctx_stack -from .base import VersioningManager as BaseVersioningManager +from .base import VersioningManager, SessionManager -class VersioningManager(BaseVersioningManager): - _actor_cls = 'User' - +class FlaskSessionManager(SessionManager): def get_transaction_values(self): values = copy(self.values) if context_available() and hasattr(g, 'activity_values'): @@ -79,4 +77,4 @@ def activity_values(**values): g.activity_values = previous_value -versioning_manager = VersioningManager() +versioning_manager = VersioningManager(actor_cls="User", session_manager_factory=FlaskSessionManager) From 67d105b389f9af6bc9758b316fdc168554be0c02 Mon Sep 17 00:00:00 2001 From: Cameron Hurst Date: Sun, 5 Jan 2020 16:14:38 -0500 Subject: [PATCH 03/12] feature: reworked migrations to allow autogeneration with alembic --- postgresql_audit/alembic/__init__.py | 130 ++++++++++++++++++ .../alembic/init_activity_table_triggers.py | 97 +++++++++++++ postgresql_audit/alembic/migration_ops.py | 71 ++++++++++ .../register_table_for_version_tracking.py | 75 ++++++++++ postgresql_audit/base.py | 6 +- postgresql_audit/migrations.py | 23 ---- 6 files changed, 376 insertions(+), 26 deletions(-) create mode 100644 postgresql_audit/alembic/__init__.py create mode 100644 postgresql_audit/alembic/init_activity_table_triggers.py create mode 100644 postgresql_audit/alembic/migration_ops.py create mode 100644 postgresql_audit/alembic/register_table_for_version_tracking.py diff --git a/postgresql_audit/alembic/__init__.py b/postgresql_audit/alembic/__init__.py new file mode 100644 index 0000000..0314043 --- /dev/null +++ b/postgresql_audit/alembic/__init__.py @@ -0,0 +1,130 @@ +import re +from itertools import groupby + +from alembic.autogenerate import comparators, rewriter +from alembic.operations import ops + +from postgresql_audit.alembic.init_activity_table_triggers import InitActivityTableTriggersOp, \ + RemoveActivityTableTriggersOp +from postgresql_audit.alembic.migration_ops import AddColumnToActivityOp, RemoveColumnFromRemoveActivityOp +from postgresql_audit.alembic.register_table_for_version_tracking import RegisterTableForVersionTrackingOp, \ + DeregisterTableForVersionTrackingOp + + +@comparators.dispatch_for("schema") +def compare_timestamp_schema(autogen_context, upgrade_ops, schemas): + routines = set() + for sch in schemas: + schema_name = autogen_context.dialect.default_schema_name if sch is None else sch + routines.update([ + (sch, *row) for row in autogen_context.connection.execute( + "select routine_name, routine_definition from information_schema.routines " + f"where routines.specific_schema='{schema_name}' " + )]) + + for sch in schemas: + should_track_versions = any("versioned" in table.info for table in autogen_context.sorted_tables if table.info and table.schema == sch) + schema_prefix = f"{sch}." if sch else "" + + a = next((v for k, v in groupby(routines, key=lambda x: x[0]) if k == sch), None) + a = list(a) if a else [] + if should_track_versions: + if f"{schema_prefix}audit_table" not in (x[1] for x in a): + upgrade_ops.ops.append( + InitActivityTableTriggersOp(False, schema=sch) + ) + else: + if f"{schema_prefix}audit_table" in (x[1] for x in a): + upgrade_ops.ops.append( + RemoveActivityTableTriggersOp(schema=sch) + ) + + +@comparators.dispatch_for("table") +def compare_timestamp_table(autogen_context, modify_ops, schemaname, tablename, conn_table, metadata_table): + if metadata_table is None: + return + meta_info = metadata_table.info or {} + # TODO: Query triggers on the table + schema_name = autogen_context.dialect.default_schema_name if schemaname is None else schemaname + + triggers = [row for row in autogen_context.connection.execute(f""" +select event_object_schema as table_schema, + event_object_table as table_name, + trigger_schema, + trigger_name, + string_agg(event_manipulation, ',') as event, + action_timing as activation, + action_condition as condition, + action_statement as definition +from information_schema.triggers +where event_object_table = '{tablename}' and trigger_schema = '{schema_name}' +group by 1,2,3,4,6,7,8 +order by table_schema, table_name; + """)] + + trigger_name = "audit_trigger" + + if "versioned" in meta_info: + excluded_columns = metadata_table.info["versioned"].get("exclude", tuple()) + trigger = next((trigger for trigger in triggers if trigger_name in trigger[3]), None) + original_excluded_columns = __get_existing_excluded_columns(trigger) + + if trigger and set(original_excluded_columns) == set(excluded_columns): + return + + modify_ops.ops.insert(0, + RegisterTableForVersionTrackingOp(tablename, excluded_columns, original_excluded_columns, schema=schema_name) + ) + else: + trigger = next((trigger for trigger in triggers if trigger_name in trigger[3]), None) + original_excluded_columns = __get_existing_excluded_columns(trigger) + + if trigger: + modify_ops.ops.append( + DeregisterTableForVersionTrackingOp(tablename, original_excluded_columns, schema=schema_name) + ) + + +def __get_existing_excluded_columns(trigger): + original_excluded_columns = () + if trigger: + arguments_match = re.search(r"EXECUTE FUNCTION create_activity\('{(.+)}'\)", trigger[7]) + if arguments_match: + original_excluded_columns = arguments_match.group(1).split(",") + return original_excluded_columns + + +writer = rewriter.Rewriter() + +@writer.rewrites(ops.AddColumnOp) +def add_column_rewrite(context, revision, op): + table_info = op.column.table.info or {} + if "versioned" in table_info and op.column.name not in table_info["versioned"].get("exclude", []): + return [ + op, + AddColumnToActivityOp( + op.table_name, + op.column.name, + schema=op.column.table.schema, + ), + ] + else: + return op + +@writer.rewrites(ops.DropColumnOp) +def drop_column_rewrite(context, revision, op): + column = op._orig_column + table_info = column.table.info or {} + if "versioned" in table_info and column.name not in table_info["versioned"].get("exclude", []): + return [ + op, + RemoveColumnFromRemoveActivityOp( + op.table_name, + column.name, + schema=column.table.schema, + ), + ] + else: + return op + diff --git a/postgresql_audit/alembic/init_activity_table_triggers.py b/postgresql_audit/alembic/init_activity_table_triggers.py new file mode 100644 index 0000000..c7d01c8 --- /dev/null +++ b/postgresql_audit/alembic/init_activity_table_triggers.py @@ -0,0 +1,97 @@ +from alembic.autogenerate import renderers +from alembic.operations import Operations, MigrateOperation + +from postgresql_audit.utils import render_tmpl, create_audit_table, create_operators + + +@Operations.register_operation("init_activity_table_triggers") +class InitActivityTableTriggersOp(MigrateOperation): + """Initialize Activity Table Triggers""" + + def __init__(self, use_statement_level_triggers, schema=None): + self.schema = schema + self.use_statement_level_triggers = use_statement_level_triggers + + @classmethod + def init_activity_table_triggers(cls, operations, use_statement_level_triggers, **kwargs): + op = InitActivityTableTriggersOp(use_statement_level_triggers, **kwargs) + return operations.invoke(op) + + def reverse(self): + # only needed to support autogenerate + return RemoveActivityTableTriggersOp(schema=self.schema) + +@Operations.register_operation("remove_activity_table_triggers") +class RemoveActivityTableTriggersOp(MigrateOperation): + """Drop Activity Table Triggers""" + + def __init__(self, use_statement_level_triggers, schema=None): + self.schema = schema + self.use_statement_level_triggers = use_statement_level_triggers + + + @classmethod + def remove_activity_table_triggers(cls, operations, use_statement_level_triggers, **kwargs): + op = RemoveActivityTableTriggersOp(use_statement_level_triggers, **kwargs) + return operations.invoke(op) + + def reverse(self): + # only needed to support autogenerate + return InitActivityTableTriggersOp(self.use_statement_level_triggers, schema=self.schema) + + +@Operations.implementation_for(InitActivityTableTriggersOp) +def init_activity_table_triggers(operations, operation): + conn = operations.connection + + if operation.schema: + conn.execute(render_tmpl('create_schema.sql', operation.schema)) + + conn.execute(render_tmpl('jsonb_change_key_name.sql', operation.schema)) + create_audit_table(None, conn, operation.schema, operation.use_statement_level_triggers) + create_operators(None, conn, operation.schema) + + +@Operations.implementation_for(RemoveActivityTableTriggersOp) +def remove_activity_table_triggers(operations, operation): + conn = operations.connection + bind = conn.bind + + if operation.schema: + conn.execute(render_tmpl('drop_schema.sql', operation.schema)) + + conn.execute("DROP FUNCTION jsonb_change_key_name(data jsonb, old_key text, new_key text)") + schema_prefix = f"{operation.schema}." if operation.schema else "" + + conn.execute(f"DROP FUNCTION {schema_prefix}audit_table(target_table regclass, ignored_cols text[])") + conn.execute(f"DROP FUNCTION {schema_prefix}create_activity()") + + + if bind.dialect.server_version_info < (9, 5, 0): + conn.execute(f"""DROP FUNCTION jsonb_subtract(jsonb, TEXT)""") + conn.execute(f"""DROP OPERATOR IF EXISTS - (jsonb, text);""") + conn.execute(f"""DROP FUNCTION jsonb_merge(jsonb, jsonb)""") + conn.execute(f"""DROP OPERATOR IF EXISTS || (jsonb, jsonb);""") + if bind.dialect.server_version_info < (9, 6, 0): + conn.execute(f"""DROP FUNCTION current_setting(TEXT, BOOL)""") + if bind.dialect.server_version_info < (10, 0): + conn.execute(f"""DROP FUNCTION jsonb_subtract(jsonb, TEXT[])""") + conn.execute(f"""DROP OPERATOR IF EXISTS - (jsonb, text[])""") + + conn.execute(f"""DROP FUNCTION jsonb_subtract(jsonb,jsonb)""") + conn.execute(f"""DROP OPERATOR IF EXISTS - (jsonb, jsonb)""") + conn.execute(f"""DROP FUNCTION get_setting(text, text)""") + + +@renderers.dispatch_for(InitActivityTableTriggersOp) +def render_init_activity_table_triggers(autogen_context, op): + return "op.init_activity_table_triggers(%r, **%r)" % ( + op.use_statement_level_triggers, + {"schema": op.schema} + ) + +@renderers.dispatch_for(RemoveActivityTableTriggersOp) +def render_remove_activity_table_triggers(autogen_context, op): + return "op.remove_activity_table_triggers(**%r)" % ( + {"schema": op.schema} + ) diff --git a/postgresql_audit/alembic/migration_ops.py b/postgresql_audit/alembic/migration_ops.py new file mode 100644 index 0000000..23d4b78 --- /dev/null +++ b/postgresql_audit/alembic/migration_ops.py @@ -0,0 +1,71 @@ +from alembic.autogenerate import renderers +from alembic.operations import Operations, MigrateOperation + +from postgresql_audit import add_column, remove_column + + +@Operations.register_operation("add_column_to_activity") +class AddColumnToActivityOp(MigrateOperation): + """Initialize Activity Table Triggers""" + + def __init__(self, table_name, column_name, default_value=None, schema=None): + self.schema = schema + self.table_name = table_name + self.column_name = column_name + self.default_value = default_value + + @classmethod + def add_column_to_activity(cls, operations, table_name, column_name, **kwargs): + op = AddColumnToActivityOp(table_name, column_name, **kwargs) + return operations.invoke(op) + + def reverse(self): + # only needed to support autogenerate + return RemoveColumnFromRemoveActivityOp(self.table_name, self.column_name, default_value=self.default_value, schema=self.schema) + +@Operations.register_operation("remove_column_from_activity") +class RemoveColumnFromRemoveActivityOp(MigrateOperation): + """Drop Activity Table Triggers""" + + def __init__(self, table_name, column_name, default_value=None, schema=None): + self.schema = schema + self.table_name = table_name + self.column_name = column_name + self.default_value = default_value + + @classmethod + def remove_column_from_activity(cls, operations, table_name, column_name, **kwargs): + op = RemoveColumnFromRemoveActivityOp(table_name, column_name, **kwargs) + return operations.invoke(op) + + def reverse(self): + # only needed to support autogenerate + return AddColumnToActivityOp(self.table_name, self.column_name, default_value=self.default_value, schema=self.schema) + + +@Operations.implementation_for(AddColumnToActivityOp) +def add_column_to_activity(operations, operation): + conn = operations.connection + add_column(conn, operation.table_name, operation.column_name, default_value=operation.default_value, schema=operation.schema) + + +@Operations.implementation_for(RemoveColumnFromRemoveActivityOp) +def remove_column_from_activity(operations, operation): + conn = operations.connection + remove_column(conn, operation.table_name, operation.column_name, operation.schema) + +@renderers.dispatch_for(AddColumnToActivityOp) +def render_add_column_to_activity(autogen_context, op): + return "op.add_column_to_activity(%r, %r, **%r)" % ( + op.table_name, + op.column_name, + {"schema": op.schema, "default_value": op.default_value} + ) + +@renderers.dispatch_for(RemoveColumnFromRemoveActivityOp) +def render_remove_column_from_activitys(autogen_context, op): + return "op.remove_column_from_activity(%r, %r, **%r)" % ( + op.table_name, + op.column_name, + {"schema": op.schema} + ) diff --git a/postgresql_audit/alembic/register_table_for_version_tracking.py b/postgresql_audit/alembic/register_table_for_version_tracking.py new file mode 100644 index 0000000..dd6a944 --- /dev/null +++ b/postgresql_audit/alembic/register_table_for_version_tracking.py @@ -0,0 +1,75 @@ +import sqlalchemy as sa + +from alembic.autogenerate import renderers +from alembic.operations import Operations, MigrateOperation + + +@Operations.register_operation("register_for_version_tracking") +class RegisterTableForVersionTrackingOp(MigrateOperation): + """Register Table for Version Tracking""" + + def __init__(self, tablename, excluded_columns, original_excluded_columns=None, schema=None): + self.schema = schema + self.tablename = tablename + self.excluded_columns = excluded_columns + self.original_excluded_columns = original_excluded_columns + + @classmethod + def register_for_version_tracking(cls, operations, tablename, exclude_columns, **kwargs): + op = RegisterTableForVersionTrackingOp(tablename, exclude_columns, **kwargs) + return operations.invoke(op) + + def reverse(self): + # only needed to support autogenerate + return DeregisterTableForVersionTrackingOp(self.tablename, self.original_excluded_columns, schema=self.schema) + +@Operations.register_operation("deregister_for_version_tracking") +class DeregisterTableForVersionTrackingOp(MigrateOperation): + """Drop Table from Version Tracking""" + + def __init__(self, tablename, excluded_columns, schema=None): + self.schema = schema + self.tablename = tablename + self.excluded_columns = excluded_columns + + + @classmethod + def deregister_for_version_tracking(cls, operations, tablename, **kwargs): + op = DeregisterTableForVersionTrackingOp(tablename, (), **kwargs) + return operations.invoke(op) + + def reverse(self): + # only needed to support autogenerate + return RegisterTableForVersionTrackingOp(self.tablename, self.excluded_columns, (), schema=self.schema) + + +@Operations.implementation_for(RegisterTableForVersionTrackingOp) +def register_for_version_tracking(operations, operation): + if operation.schema is None: + func = sa.func.audit_table + else: + func = getattr(getattr(sa.func, operation.schema), 'audit_table') + operations.execute(sa.select([func(operation.tablename, list(operation.excluded_columns))])) + + +@Operations.implementation_for(DeregisterTableForVersionTrackingOp) +def deregister_for_version_tracking(operations, operation): + operations.execute(f"drop trigger audit_trigger_insert on {operation.tablename} ") + operations.execute(f"drop trigger audit_trigger_update on {operation.tablename} ") + operations.execute(f"drop trigger audit_trigger_delete on {operation.tablename} ") + + +@renderers.dispatch_for(RegisterTableForVersionTrackingOp) +def render_register_for_version_tracking(autogen_context, op): + return "op.register_for_version_tracking(%r, %r, **%r)" % ( + op.tablename, + op.excluded_columns, + {"schema": op.schema} + ) + +@renderers.dispatch_for(DeregisterTableForVersionTrackingOp) +def render_deregister_for_version_tracking(autogen_context, op): + return "op.deregister_for_version_tracking(%r, **%r)" % ( + op.tablename, + {"schema": op.schema} + ) diff --git a/postgresql_audit/base.py b/postgresql_audit/base.py index 3b6758b..1976bac 100644 --- a/postgresql_audit/base.py +++ b/postgresql_audit/base.py @@ -207,9 +207,9 @@ def modified_columns(self, obj): def is_modified(self, obj_or_session): if hasattr(obj_or_session, '__mapper__'): - if not hasattr(obj_or_session, '__versioned__'): + if not (hasattr(obj_or_session, '__versioned__') or getattr(obj_or_session, '__table_args__', {}).get("versioned", None)): raise ClassNotVersioned(obj_or_session.__class__.__name__) - excluded = obj_or_session.__versioned__.get('exclude', []) + excluded = getattr(obj_or_session, "__versioned__", obj_or_session.__table_args__["versioned"]).get('exclude', []) return bool( set([ column.name @@ -220,7 +220,7 @@ def is_modified(self, obj_or_session): return any( self.is_modified(entity) or entity in obj_or_session.deleted for entity in obj_or_session - if hasattr(entity, '__versioned__') + if hasattr(entity, '__versioned__') or getattr(obj_or_session, '__table_args__', {}).get("versioned", None) ) def before_flush(self, session, flush_context, instances): diff --git a/postgresql_audit/migrations.py b/postgresql_audit/migrations.py index 56adab6..6fab915 100644 --- a/postgresql_audit/migrations.py +++ b/postgresql_audit/migrations.py @@ -1,8 +1,6 @@ import sqlalchemy as sa from sqlalchemy.dialects.postgresql import JSONB -from postgresql_audit.utils import render_tmpl, create_audit_table, create_operators, \ - build_register_table_query from .expressions import jsonb_change_key_name @@ -18,27 +16,6 @@ def get_activity_table(schema=None): schema=schema, ) -def init_activity_table_triggers(conn, schema_name = None, use_statement_level_triggers=True): - conn.execute(render_tmpl('jsonb_change_key_name.sql', schema_name)) - create_audit_table(None, conn, schema_name, use_statement_level_triggers) - create_operators(None, conn, schema_name) - - if schema_name: - conn.execute(render_tmpl('create_schema.sql', schema_name)) - -def rollback_create_transaction(conn, schema_name=None): - if schema_name: - conn.execute(render_tmpl('drop_schema.sql', schema_name)) - -def init_before_create_transaction(conn, schema_name=None): - if schema_name: - conn.execute(render_tmpl('create_schema.sql', schema_name)) - - -def register_table(conn, table_name, exclude_columns, schema_name=None): - sql = build_register_table_query(schema_name, table_name, exclude_columns) - conn.execute(sql) - def alter_column(conn, table, column_name, func, schema=None): """ From 834e6439b2dd5c21635c2c4f1c4cab9d62160fd6 Mon Sep 17 00:00:00 2001 From: Cameron Hurst Date: Sun, 5 Jan 2020 16:30:59 -0500 Subject: [PATCH 04/12] fix: cleanup some lingering issues --- postgresql_audit/alembic/__init__.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/postgresql_audit/alembic/__init__.py b/postgresql_audit/alembic/__init__.py index 0314043..bb1d84a 100644 --- a/postgresql_audit/alembic/__init__.py +++ b/postgresql_audit/alembic/__init__.py @@ -36,7 +36,7 @@ def compare_timestamp_schema(autogen_context, upgrade_ops, schemas): else: if f"{schema_prefix}audit_table" in (x[1] for x in a): upgrade_ops.ops.append( - RemoveActivityTableTriggersOp(schema=sch) + RemoveActivityTableTriggersOp(False, schema=sch) ) @@ -45,7 +45,6 @@ def compare_timestamp_table(autogen_context, modify_ops, schemaname, tablename, if metadata_table is None: return meta_info = metadata_table.info or {} - # TODO: Query triggers on the table schema_name = autogen_context.dialect.default_schema_name if schemaname is None else schemaname triggers = [row for row in autogen_context.connection.execute(f""" From a9ca7029a45c5d8d96b4b8a050905d9c908ecd9d Mon Sep 17 00:00:00 2001 From: Cameron Hurst Date: Fri, 17 Jan 2020 07:37:52 -0700 Subject: [PATCH 05/12] fix: minor fixes to variables to make code executable --- postgresql_audit/alembic/__init__.py | 2 +- .../alembic/init_activity_table_triggers.py | 19 ++++++++++--------- .../register_table_for_version_tracking.py | 7 ++++--- 3 files changed, 15 insertions(+), 13 deletions(-) diff --git a/postgresql_audit/alembic/__init__.py b/postgresql_audit/alembic/__init__.py index bb1d84a..e0509da 100644 --- a/postgresql_audit/alembic/__init__.py +++ b/postgresql_audit/alembic/__init__.py @@ -30,7 +30,7 @@ def compare_timestamp_schema(autogen_context, upgrade_ops, schemas): a = list(a) if a else [] if should_track_versions: if f"{schema_prefix}audit_table" not in (x[1] for x in a): - upgrade_ops.ops.append( + upgrade_ops.ops.insert(0, InitActivityTableTriggersOp(False, schema=sch) ) else: diff --git a/postgresql_audit/alembic/init_activity_table_triggers.py b/postgresql_audit/alembic/init_activity_table_triggers.py index c7d01c8..a40364a 100644 --- a/postgresql_audit/alembic/init_activity_table_triggers.py +++ b/postgresql_audit/alembic/init_activity_table_triggers.py @@ -19,7 +19,7 @@ def init_activity_table_triggers(cls, operations, use_statement_level_triggers, def reverse(self): # only needed to support autogenerate - return RemoveActivityTableTriggersOp(schema=self.schema) + return RemoveActivityTableTriggersOp(self.use_statement_level_triggers, schema=self.schema) @Operations.register_operation("remove_activity_table_triggers") class RemoveActivityTableTriggersOp(MigrateOperation): @@ -31,8 +31,8 @@ def __init__(self, use_statement_level_triggers, schema=None): @classmethod - def remove_activity_table_triggers(cls, operations, use_statement_level_triggers, **kwargs): - op = RemoveActivityTableTriggersOp(use_statement_level_triggers, **kwargs) + def remove_activity_table_triggers(cls, operations, **kwargs): + op = RemoveActivityTableTriggersOp(False, **kwargs) return operations.invoke(op) def reverse(self): @@ -42,20 +42,21 @@ def reverse(self): @Operations.implementation_for(InitActivityTableTriggersOp) def init_activity_table_triggers(operations, operation): - conn = operations.connection + conn = operations + bind = conn.get_bind() if operation.schema: conn.execute(render_tmpl('create_schema.sql', operation.schema)) conn.execute(render_tmpl('jsonb_change_key_name.sql', operation.schema)) - create_audit_table(None, conn, operation.schema, operation.use_statement_level_triggers) - create_operators(None, conn, operation.schema) + create_audit_table(None, bind, operation.schema, operation.use_statement_level_triggers) + create_operators(None, bind, operation.schema) @Operations.implementation_for(RemoveActivityTableTriggersOp) def remove_activity_table_triggers(operations, operation): - conn = operations.connection - bind = conn.bind + conn = operations + bind = conn.get_bind() if operation.schema: conn.execute(render_tmpl('drop_schema.sql', operation.schema)) @@ -78,9 +79,9 @@ def remove_activity_table_triggers(operations, operation): conn.execute(f"""DROP FUNCTION jsonb_subtract(jsonb, TEXT[])""") conn.execute(f"""DROP OPERATOR IF EXISTS - (jsonb, text[])""") - conn.execute(f"""DROP FUNCTION jsonb_subtract(jsonb,jsonb)""") conn.execute(f"""DROP OPERATOR IF EXISTS - (jsonb, jsonb)""") conn.execute(f"""DROP FUNCTION get_setting(text, text)""") + conn.execute(f"""DROP FUNCTION jsonb_subtract(jsonb,jsonb)""") @renderers.dispatch_for(InitActivityTableTriggersOp) diff --git a/postgresql_audit/alembic/register_table_for_version_tracking.py b/postgresql_audit/alembic/register_table_for_version_tracking.py index dd6a944..70a158a 100644 --- a/postgresql_audit/alembic/register_table_for_version_tracking.py +++ b/postgresql_audit/alembic/register_table_for_version_tracking.py @@ -54,9 +54,10 @@ def register_for_version_tracking(operations, operation): @Operations.implementation_for(DeregisterTableForVersionTrackingOp) def deregister_for_version_tracking(operations, operation): - operations.execute(f"drop trigger audit_trigger_insert on {operation.tablename} ") - operations.execute(f"drop trigger audit_trigger_update on {operation.tablename} ") - operations.execute(f"drop trigger audit_trigger_delete on {operation.tablename} ") + operations.execute(f"drop trigger if exists audit_trigger_insert on {operation.tablename} ") + operations.execute(f"drop trigger if exists audit_trigger_update on {operation.tablename} ") + operations.execute(f"drop trigger if exists audit_trigger_delete on {operation.tablename} ") + operations.execute(f"drop trigger if exists audit_trigger_row on {operation.tablename} ") @renderers.dispatch_for(RegisterTableForVersionTrackingOp) From 18ac8cc4e1e95565027387a2c0ba241cc8fbd730 Mon Sep 17 00:00:00 2001 From: Cameron Hurst Date: Fri, 7 Feb 2020 13:28:20 -0500 Subject: [PATCH 06/12] fix: using wroong ops --- postgresql_audit/alembic/migration_ops.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/postgresql_audit/alembic/migration_ops.py b/postgresql_audit/alembic/migration_ops.py index 23d4b78..f22d091 100644 --- a/postgresql_audit/alembic/migration_ops.py +++ b/postgresql_audit/alembic/migration_ops.py @@ -45,8 +45,7 @@ def reverse(self): @Operations.implementation_for(AddColumnToActivityOp) def add_column_to_activity(operations, operation): - conn = operations.connection - add_column(conn, operation.table_name, operation.column_name, default_value=operation.default_value, schema=operation.schema) + add_column(operations, operation.table_name, operation.column_name, default_value=operation.default_value, schema=operation.schema) @Operations.implementation_for(RemoveColumnFromRemoveActivityOp) From 5c28625f506f8c0f06533d8cba8fc8a879fad606 Mon Sep 17 00:00:00 2001 From: Cameron Hurst Date: Sun, 26 Apr 2020 17:55:25 -0400 Subject: [PATCH 07/12] fix: getting versioned info works properly not with __table_args__ --- postgresql_audit/base.py | 21 ++++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/postgresql_audit/base.py b/postgresql_audit/base.py index 1976bac..fdb47eb 100644 --- a/postgresql_audit/base.py +++ b/postgresql_audit/base.py @@ -1,4 +1,5 @@ import warnings +from collections import Sequence from contextlib import contextmanager from functools import partial from weakref import WeakSet @@ -207,9 +208,10 @@ def modified_columns(self, obj): def is_modified(self, obj_or_session): if hasattr(obj_or_session, '__mapper__'): - if not (hasattr(obj_or_session, '__versioned__') or getattr(obj_or_session, '__table_args__', {}).get("versioned", None)): + version_info = self.__get_versioned_info(obj_or_session) + if not version_info: raise ClassNotVersioned(obj_or_session.__class__.__name__) - excluded = getattr(obj_or_session, "__versioned__", obj_or_session.__table_args__["versioned"]).get('exclude', []) + excluded = version_info.get('exclude', []) return bool( set([ column.name @@ -220,9 +222,22 @@ def is_modified(self, obj_or_session): return any( self.is_modified(entity) or entity in obj_or_session.deleted for entity in obj_or_session - if hasattr(entity, '__versioned__') or getattr(obj_or_session, '__table_args__', {}).get("versioned", None) + if self.__get_versioned_info(entity) ) + def __get_versioned_info(self, entity): + v_args = getattr(entity, '__versioned__', None) + if v_args: + return v_args + table_args = getattr(entity, '__table_args__', None) + if not table_args: + return None + if isinstance(table_args, Sequence): + table_args = next((x for x in iter(table_args) if isinstance(x, dict)), None) + if not table_args: + return None + return table_args.get("info", {}).get("versioned", None) + def before_flush(self, session, flush_context, instances): if session.transaction in self._marked_transactions: return From 824322d7893003009476285d0525437cfcc1ce77 Mon Sep 17 00:00:00 2001 From: Jake Stewart Date: Tue, 24 Aug 2021 09:14:46 -0400 Subject: [PATCH 08/12] Update tests to support refactored versioning manager --- postgresql_audit/alembic/__init__.py | 160 ++++++++++++------ .../alembic/init_activity_table_triggers.py | 78 +++++---- postgresql_audit/alembic/migration_ops.py | 66 ++++++-- .../register_table_for_version_tracking.py | 64 ++++--- postgresql_audit/base.py | 52 +++--- postgresql_audit/flask.py | 6 +- postgresql_audit/utils.py | 8 +- tests/test_custom_schema.py | 8 +- tests/test_flask_integration.py | 10 +- tests/test_sqlalchemy_integration.py | 36 ++-- 10 files changed, 327 insertions(+), 161 deletions(-) diff --git a/postgresql_audit/alembic/__init__.py b/postgresql_audit/alembic/__init__.py index e0509da..2dbb2a0 100644 --- a/postgresql_audit/alembic/__init__.py +++ b/postgresql_audit/alembic/__init__.py @@ -4,102 +4,157 @@ from alembic.autogenerate import comparators, rewriter from alembic.operations import ops -from postgresql_audit.alembic.init_activity_table_triggers import InitActivityTableTriggersOp, \ +from postgresql_audit.alembic.init_activity_table_triggers import ( + InitActivityTableTriggersOp, RemoveActivityTableTriggersOp -from postgresql_audit.alembic.migration_ops import AddColumnToActivityOp, RemoveColumnFromRemoveActivityOp -from postgresql_audit.alembic.register_table_for_version_tracking import RegisterTableForVersionTrackingOp, \ - DeregisterTableForVersionTrackingOp - - -@comparators.dispatch_for("schema") +) +from postgresql_audit.alembic.migration_ops import ( + AddColumnToActivityOp, + RemoveColumnFromRemoveActivityOp +) +from postgresql_audit.alembic.register_table_for_version_tracking import ( + DeregisterTableForVersionTrackingOp, + RegisterTableForVersionTrackingOp +) + + +@comparators.dispatch_for('schema') def compare_timestamp_schema(autogen_context, upgrade_ops, schemas): routines = set() for sch in schemas: - schema_name = autogen_context.dialect.default_schema_name if sch is None else sch + schema_name = ( + autogen_context.dialect.default_schema_name if sch is None else sch + ) routines.update([ - (sch, *row) for row in autogen_context.connection.execute( - "select routine_name, routine_definition from information_schema.routines " - f"where routines.specific_schema='{schema_name}' " - )]) + (sch, *row) for row in autogen_context.connection.execute( + 'select routine_name, routine_definition ' + 'from information_schema.routines ' + f"where routines.specific_schema='{schema_name}' " + ) + ]) for sch in schemas: - should_track_versions = any("versioned" in table.info for table in autogen_context.sorted_tables if table.info and table.schema == sch) - schema_prefix = f"{sch}." if sch else "" + should_track_versions = any( + 'versioned' in table.info + for table in autogen_context.sorted_tables + if table.info and table.schema == sch + ) + schema_prefix = f'{sch}.' if sch else '' - a = next((v for k, v in groupby(routines, key=lambda x: x[0]) if k == sch), None) + a = next( + (v for k, v in groupby(routines, key=lambda x: x[0]) if k == sch), + None + ) a = list(a) if a else [] if should_track_versions: - if f"{schema_prefix}audit_table" not in (x[1] for x in a): - upgrade_ops.ops.insert(0, + if f'{schema_prefix}audit_table' not in (x[1] for x in a): + upgrade_ops.ops.insert( + 0, InitActivityTableTriggersOp(False, schema=sch) ) else: - if f"{schema_prefix}audit_table" in (x[1] for x in a): + if f'{schema_prefix}audit_table' in (x[1] for x in a): upgrade_ops.ops.append( RemoveActivityTableTriggersOp(False, schema=sch) ) -@comparators.dispatch_for("table") -def compare_timestamp_table(autogen_context, modify_ops, schemaname, tablename, conn_table, metadata_table): +@comparators.dispatch_for('table') +def compare_timestamp_table( + autogen_context, + modify_ops, + schemaname, + tablename, + conn_table, + metadata_table +): if metadata_table is None: return meta_info = metadata_table.info or {} - schema_name = autogen_context.dialect.default_schema_name if schemaname is None else schemaname - - triggers = [row for row in autogen_context.connection.execute(f""" -select event_object_schema as table_schema, - event_object_table as table_name, - trigger_schema, - trigger_name, - string_agg(event_manipulation, ',') as event, - action_timing as activation, - action_condition as condition, - action_statement as definition -from information_schema.triggers -where event_object_table = '{tablename}' and trigger_schema = '{schema_name}' -group by 1,2,3,4,6,7,8 -order by table_schema, table_name; - """)] - - trigger_name = "audit_trigger" - - if "versioned" in meta_info: - excluded_columns = metadata_table.info["versioned"].get("exclude", tuple()) - trigger = next((trigger for trigger in triggers if trigger_name in trigger[3]), None) + schema_name = ( + autogen_context.dialect.default_schema_name + if schemaname is None else schemaname + ) + + triggers = [row for row in autogen_context.connection.execute( + 'select event_object_schema as table_schema,' + 'event_object_table as table_name,' + 'trigger_schema,' + 'trigger_name,' + "string_agg(event_manipulation, ',') as event," + 'action_timing as activation,' + 'action_condition as condition,' + 'action_statement as definition' + 'from information_schema.triggers' + f"where event_object_table = '{tablename}'" + f"and trigger_schema = '{schema_name}'" + 'group by 1,2,3,4,6,7,8' + 'order by table_schema, table_name;' + )] + + trigger_name = 'audit_trigger' + + if 'versioned' in meta_info: + excluded_columns = ( + metadata_table.info['versioned'].get('exclude', tuple()) + ) + trigger = next( + (trigger for trigger in triggers if trigger_name in trigger[3]), + None + ) original_excluded_columns = __get_existing_excluded_columns(trigger) if trigger and set(original_excluded_columns) == set(excluded_columns): return - modify_ops.ops.insert(0, - RegisterTableForVersionTrackingOp(tablename, excluded_columns, original_excluded_columns, schema=schema_name) + modify_ops.ops.insert( + 0, + RegisterTableForVersionTrackingOp( + tablename, + excluded_columns, + original_excluded_columns, + schema=schema_name + ) ) else: - trigger = next((trigger for trigger in triggers if trigger_name in trigger[3]), None) + trigger = next( + (trigger for trigger in triggers if trigger_name in trigger[3]), + None + ) original_excluded_columns = __get_existing_excluded_columns(trigger) if trigger: modify_ops.ops.append( - DeregisterTableForVersionTrackingOp(tablename, original_excluded_columns, schema=schema_name) + DeregisterTableForVersionTrackingOp( + tablename, + original_excluded_columns, + schema=schema_name + ) ) def __get_existing_excluded_columns(trigger): original_excluded_columns = () if trigger: - arguments_match = re.search(r"EXECUTE FUNCTION create_activity\('{(.+)}'\)", trigger[7]) + arguments_match = re.search( + r"EXECUTE FUNCTION create_activity\('{(.+)}'\)", + trigger[7] + ) if arguments_match: - original_excluded_columns = arguments_match.group(1).split(",") + original_excluded_columns = arguments_match.group(1).split(',') return original_excluded_columns writer = rewriter.Rewriter() + @writer.rewrites(ops.AddColumnOp) def add_column_rewrite(context, revision, op): table_info = op.column.table.info or {} - if "versioned" in table_info and op.column.name not in table_info["versioned"].get("exclude", []): + if ( + 'versioned' in table_info + and op.column.name not in table_info['versioned'].get('exclude', []) + ): return [ op, AddColumnToActivityOp( @@ -111,11 +166,15 @@ def add_column_rewrite(context, revision, op): else: return op + @writer.rewrites(ops.DropColumnOp) def drop_column_rewrite(context, revision, op): column = op._orig_column table_info = column.table.info or {} - if "versioned" in table_info and column.name not in table_info["versioned"].get("exclude", []): + if ( + 'versioned' in table_info + and column.name not in table_info['versioned'].get('exclude', []) + ): return [ op, RemoveColumnFromRemoveActivityOp( @@ -126,4 +185,3 @@ def drop_column_rewrite(context, revision, op): ] else: return op - diff --git a/postgresql_audit/alembic/init_activity_table_triggers.py b/postgresql_audit/alembic/init_activity_table_triggers.py index a40364a..47a5703 100644 --- a/postgresql_audit/alembic/init_activity_table_triggers.py +++ b/postgresql_audit/alembic/init_activity_table_triggers.py @@ -1,10 +1,14 @@ from alembic.autogenerate import renderers -from alembic.operations import Operations, MigrateOperation +from alembic.operations import MigrateOperation, Operations -from postgresql_audit.utils import render_tmpl, create_audit_table, create_operators +from postgresql_audit.utils import ( + create_audit_table, + create_operators, + render_tmpl +) -@Operations.register_operation("init_activity_table_triggers") +@Operations.register_operation('init_activity_table_triggers') class InitActivityTableTriggersOp(MigrateOperation): """Initialize Activity Table Triggers""" @@ -13,15 +17,22 @@ def __init__(self, use_statement_level_triggers, schema=None): self.use_statement_level_triggers = use_statement_level_triggers @classmethod - def init_activity_table_triggers(cls, operations, use_statement_level_triggers, **kwargs): - op = InitActivityTableTriggersOp(use_statement_level_triggers, **kwargs) + def init_activity_table_triggers( + cls, operations, use_statement_level_triggers, **kwargs + ): + op = InitActivityTableTriggersOp( + use_statement_level_triggers, **kwargs + ) return operations.invoke(op) def reverse(self): # only needed to support autogenerate - return RemoveActivityTableTriggersOp(self.use_statement_level_triggers, schema=self.schema) + return RemoveActivityTableTriggersOp( + self.use_statement_level_triggers, schema=self.schema + ) -@Operations.register_operation("remove_activity_table_triggers") + +@Operations.register_operation('remove_activity_table_triggers') class RemoveActivityTableTriggersOp(MigrateOperation): """Drop Activity Table Triggers""" @@ -29,7 +40,6 @@ def __init__(self, use_statement_level_triggers, schema=None): self.schema = schema self.use_statement_level_triggers = use_statement_level_triggers - @classmethod def remove_activity_table_triggers(cls, operations, **kwargs): op = RemoveActivityTableTriggersOp(False, **kwargs) @@ -37,7 +47,9 @@ def remove_activity_table_triggers(cls, operations, **kwargs): def reverse(self): # only needed to support autogenerate - return InitActivityTableTriggersOp(self.use_statement_level_triggers, schema=self.schema) + return InitActivityTableTriggersOp( + self.use_statement_level_triggers, schema=self.schema + ) @Operations.implementation_for(InitActivityTableTriggersOp) @@ -49,7 +61,9 @@ def init_activity_table_triggers(operations, operation): conn.execute(render_tmpl('create_schema.sql', operation.schema)) conn.execute(render_tmpl('jsonb_change_key_name.sql', operation.schema)) - create_audit_table(None, bind, operation.schema, operation.use_statement_level_triggers) + create_audit_table( + None, bind, operation.schema, operation.use_statement_level_triggers + ) create_operators(None, bind, operation.schema) @@ -61,38 +75,44 @@ def remove_activity_table_triggers(operations, operation): if operation.schema: conn.execute(render_tmpl('drop_schema.sql', operation.schema)) - conn.execute("DROP FUNCTION jsonb_change_key_name(data jsonb, old_key text, new_key text)") - schema_prefix = f"{operation.schema}." if operation.schema else "" - - conn.execute(f"DROP FUNCTION {schema_prefix}audit_table(target_table regclass, ignored_cols text[])") - conn.execute(f"DROP FUNCTION {schema_prefix}create_activity()") + conn.execute( + 'DROP FUNCTION jsonb_change_key_name(' + 'data jsonb, old_key text, new_key text)' + ) + schema_prefix = f'{operation.schema}.' if operation.schema else '' + conn.execute( + f'DROP FUNCTION {schema_prefix}audit_table(' + 'target_table regclass, ignored_cols text[])' + ) + conn.execute(f'DROP FUNCTION {schema_prefix}create_activity()') if bind.dialect.server_version_info < (9, 5, 0): - conn.execute(f"""DROP FUNCTION jsonb_subtract(jsonb, TEXT)""") - conn.execute(f"""DROP OPERATOR IF EXISTS - (jsonb, text);""") - conn.execute(f"""DROP FUNCTION jsonb_merge(jsonb, jsonb)""") - conn.execute(f"""DROP OPERATOR IF EXISTS || (jsonb, jsonb);""") + conn.execute('DROP FUNCTION jsonb_subtract(jsonb, TEXT)') + conn.execute('DROP OPERATOR IF EXISTS - (jsonb, text);') + conn.execute('DROP FUNCTION jsonb_merge(jsonb, jsonb)') + conn.execute('DROP OPERATOR IF EXISTS || (jsonb, jsonb);') if bind.dialect.server_version_info < (9, 6, 0): - conn.execute(f"""DROP FUNCTION current_setting(TEXT, BOOL)""") + conn.execute('DROP FUNCTION current_setting(TEXT, BOOL)') if bind.dialect.server_version_info < (10, 0): - conn.execute(f"""DROP FUNCTION jsonb_subtract(jsonb, TEXT[])""") - conn.execute(f"""DROP OPERATOR IF EXISTS - (jsonb, text[])""") + conn.execute('DROP FUNCTION jsonb_subtract(jsonb, TEXT[])') + conn.execute('DROP OPERATOR IF EXISTS - (jsonb, text[])') - conn.execute(f"""DROP OPERATOR IF EXISTS - (jsonb, jsonb)""") - conn.execute(f"""DROP FUNCTION get_setting(text, text)""") - conn.execute(f"""DROP FUNCTION jsonb_subtract(jsonb,jsonb)""") + conn.execute('DROP OPERATOR IF EXISTS - (jsonb, jsonb)') + conn.execute('DROP FUNCTION get_setting(text, text)') + conn.execute('DROP FUNCTION jsonb_subtract(jsonb,jsonb)') @renderers.dispatch_for(InitActivityTableTriggersOp) def render_init_activity_table_triggers(autogen_context, op): - return "op.init_activity_table_triggers(%r, **%r)" % ( + return 'op.init_activity_table_triggers(%r, **%r)' % ( op.use_statement_level_triggers, - {"schema": op.schema} + {'schema': op.schema} ) + @renderers.dispatch_for(RemoveActivityTableTriggersOp) def render_remove_activity_table_triggers(autogen_context, op): - return "op.remove_activity_table_triggers(**%r)" % ( - {"schema": op.schema} + return 'op.remove_activity_table_triggers(**%r)' % ( + {'schema': op.schema} ) diff --git a/postgresql_audit/alembic/migration_ops.py b/postgresql_audit/alembic/migration_ops.py index f22d091..29f0c73 100644 --- a/postgresql_audit/alembic/migration_ops.py +++ b/postgresql_audit/alembic/migration_ops.py @@ -1,70 +1,104 @@ from alembic.autogenerate import renderers -from alembic.operations import Operations, MigrateOperation +from alembic.operations import MigrateOperation, Operations from postgresql_audit import add_column, remove_column -@Operations.register_operation("add_column_to_activity") +@Operations.register_operation('add_column_to_activity') class AddColumnToActivityOp(MigrateOperation): """Initialize Activity Table Triggers""" - def __init__(self, table_name, column_name, default_value=None, schema=None): + def __init__( + self, table_name, column_name, default_value=None, schema=None + ): self.schema = schema self.table_name = table_name self.column_name = column_name self.default_value = default_value @classmethod - def add_column_to_activity(cls, operations, table_name, column_name, **kwargs): + def add_column_to_activity( + cls, operations, table_name, column_name, **kwargs + ): op = AddColumnToActivityOp(table_name, column_name, **kwargs) return operations.invoke(op) def reverse(self): # only needed to support autogenerate - return RemoveColumnFromRemoveActivityOp(self.table_name, self.column_name, default_value=self.default_value, schema=self.schema) + return RemoveColumnFromRemoveActivityOp( + self.table_name, + self.column_name, + default_value=self.default_value, + schema=self.schema + ) -@Operations.register_operation("remove_column_from_activity") + +@Operations.register_operation('remove_column_from_activity') class RemoveColumnFromRemoveActivityOp(MigrateOperation): """Drop Activity Table Triggers""" - def __init__(self, table_name, column_name, default_value=None, schema=None): + def __init__( + self, table_name, column_name, default_value=None, schema=None + ): self.schema = schema self.table_name = table_name self.column_name = column_name self.default_value = default_value @classmethod - def remove_column_from_activity(cls, operations, table_name, column_name, **kwargs): - op = RemoveColumnFromRemoveActivityOp(table_name, column_name, **kwargs) + def remove_column_from_activity( + cls, operations, table_name, column_name, **kwargs + ): + op = RemoveColumnFromRemoveActivityOp( + table_name, column_name, **kwargs + ) return operations.invoke(op) def reverse(self): # only needed to support autogenerate - return AddColumnToActivityOp(self.table_name, self.column_name, default_value=self.default_value, schema=self.schema) + return AddColumnToActivityOp( + self.table_name, + self.column_name, + default_value=self.default_value, + schema=self.schema + ) @Operations.implementation_for(AddColumnToActivityOp) def add_column_to_activity(operations, operation): - add_column(operations, operation.table_name, operation.column_name, default_value=operation.default_value, schema=operation.schema) + add_column( + operations, + operation.table_name, + operation.column_name, + default_value=operation.default_value, + schema=operation.schema + ) @Operations.implementation_for(RemoveColumnFromRemoveActivityOp) def remove_column_from_activity(operations, operation): conn = operations.connection - remove_column(conn, operation.table_name, operation.column_name, operation.schema) + remove_column( + conn, + operation.table_name, + operation.column_name, + operation.schema + ) + @renderers.dispatch_for(AddColumnToActivityOp) def render_add_column_to_activity(autogen_context, op): - return "op.add_column_to_activity(%r, %r, **%r)" % ( + return 'op.add_column_to_activity(%r, %r, **%r)' % ( op.table_name, op.column_name, - {"schema": op.schema, "default_value": op.default_value} + {'schema': op.schema, 'default_value': op.default_value} ) + @renderers.dispatch_for(RemoveColumnFromRemoveActivityOp) def render_remove_column_from_activitys(autogen_context, op): - return "op.remove_column_from_activity(%r, %r, **%r)" % ( + return 'op.remove_column_from_activity(%r, %r, **%r)' % ( op.table_name, op.column_name, - {"schema": op.schema} + {'schema': op.schema} ) diff --git a/postgresql_audit/alembic/register_table_for_version_tracking.py b/postgresql_audit/alembic/register_table_for_version_tracking.py index 70a158a..4316310 100644 --- a/postgresql_audit/alembic/register_table_for_version_tracking.py +++ b/postgresql_audit/alembic/register_table_for_version_tracking.py @@ -1,29 +1,41 @@ import sqlalchemy as sa - from alembic.autogenerate import renderers -from alembic.operations import Operations, MigrateOperation +from alembic.operations import MigrateOperation, Operations -@Operations.register_operation("register_for_version_tracking") +@Operations.register_operation('register_for_version_tracking') class RegisterTableForVersionTrackingOp(MigrateOperation): """Register Table for Version Tracking""" - def __init__(self, tablename, excluded_columns, original_excluded_columns=None, schema=None): + def __init__( + self, + tablename, + excluded_columns, + original_excluded_columns=None, + schema=None + ): self.schema = schema self.tablename = tablename self.excluded_columns = excluded_columns self.original_excluded_columns = original_excluded_columns @classmethod - def register_for_version_tracking(cls, operations, tablename, exclude_columns, **kwargs): - op = RegisterTableForVersionTrackingOp(tablename, exclude_columns, **kwargs) + def register_for_version_tracking( + cls, operations, tablename, exclude_columns, **kwargs + ): + op = RegisterTableForVersionTrackingOp( + tablename, exclude_columns, **kwargs + ) return operations.invoke(op) def reverse(self): # only needed to support autogenerate - return DeregisterTableForVersionTrackingOp(self.tablename, self.original_excluded_columns, schema=self.schema) + return DeregisterTableForVersionTrackingOp( + self.tablename, self.original_excluded_columns, schema=self.schema + ) + -@Operations.register_operation("deregister_for_version_tracking") +@Operations.register_operation('deregister_for_version_tracking') class DeregisterTableForVersionTrackingOp(MigrateOperation): """Drop Table from Version Tracking""" @@ -32,7 +44,6 @@ def __init__(self, tablename, excluded_columns, schema=None): self.tablename = tablename self.excluded_columns = excluded_columns - @classmethod def deregister_for_version_tracking(cls, operations, tablename, **kwargs): op = DeregisterTableForVersionTrackingOp(tablename, (), **kwargs) @@ -40,7 +51,9 @@ def deregister_for_version_tracking(cls, operations, tablename, **kwargs): def reverse(self): # only needed to support autogenerate - return RegisterTableForVersionTrackingOp(self.tablename, self.excluded_columns, (), schema=self.schema) + return RegisterTableForVersionTrackingOp( + self.tablename, self.excluded_columns, (), schema=self.schema + ) @Operations.implementation_for(RegisterTableForVersionTrackingOp) @@ -49,28 +62,41 @@ def register_for_version_tracking(operations, operation): func = sa.func.audit_table else: func = getattr(getattr(sa.func, operation.schema), 'audit_table') - operations.execute(sa.select([func(operation.tablename, list(operation.excluded_columns))])) + operations.execute( + sa.select( + [func(operation.tablename, list(operation.excluded_columns))] + ) + ) @Operations.implementation_for(DeregisterTableForVersionTrackingOp) def deregister_for_version_tracking(operations, operation): - operations.execute(f"drop trigger if exists audit_trigger_insert on {operation.tablename} ") - operations.execute(f"drop trigger if exists audit_trigger_update on {operation.tablename} ") - operations.execute(f"drop trigger if exists audit_trigger_delete on {operation.tablename} ") - operations.execute(f"drop trigger if exists audit_trigger_row on {operation.tablename} ") + operations.execute( + f'drop trigger if exists audit_trigger_insert on {operation.tablename}' + ) + operations.execute( + f'drop trigger if exists audit_trigger_update on {operation.tablename}' + ) + operations.execute( + f'drop trigger if exists audit_trigger_delete on {operation.tablename}' + ) + operations.execute( + f'drop trigger if exists audit_trigger_row on {operation.tablename}' + ) @renderers.dispatch_for(RegisterTableForVersionTrackingOp) def render_register_for_version_tracking(autogen_context, op): - return "op.register_for_version_tracking(%r, %r, **%r)" % ( + return 'op.register_for_version_tracking(%r, %r, **%r)' % ( op.tablename, op.excluded_columns, - {"schema": op.schema} + {'schema': op.schema} ) + @renderers.dispatch_for(DeregisterTableForVersionTrackingOp) def render_deregister_for_version_tracking(autogen_context, op): - return "op.deregister_for_version_tracking(%r, **%r)" % ( + return 'op.deregister_for_version_tracking(%r, **%r)' % ( op.tablename, - {"schema": op.schema} + {'schema': op.schema} ) diff --git a/postgresql_audit/base.py b/postgresql_audit/base.py index fdb47eb..9bce022 100644 --- a/postgresql_audit/base.py +++ b/postgresql_audit/base.py @@ -18,8 +18,13 @@ from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy_utils import get_class_by_table -from postgresql_audit.utils import render_tmpl, StatementExecutor, create_audit_table, create_operators, \ - build_register_table_query +from postgresql_audit.utils import ( + build_register_table_query, + create_audit_table, + create_operators, + render_tmpl, + StatementExecutor +) cached_statements = {} @@ -31,6 +36,7 @@ class ImproperlyConfigured(Exception): class ClassNotVersioned(Exception): pass + def assign_actor(base, cls, actor_cls): if hasattr(cls, 'actor_id'): return @@ -208,10 +214,9 @@ def modified_columns(self, obj): def is_modified(self, obj_or_session): if hasattr(obj_or_session, '__mapper__'): - version_info = self.__get_versioned_info(obj_or_session) - if not version_info: + if not hasattr(obj_or_session, '__versioned__'): raise ClassNotVersioned(obj_or_session.__class__.__name__) - excluded = version_info.get('exclude', []) + excluded = obj_or_session.__versioned__.get('exclude', []) return bool( set([ column.name @@ -222,7 +227,7 @@ def is_modified(self, obj_or_session): return any( self.is_modified(entity) or entity in obj_or_session.deleted for entity in obj_or_session - if self.__get_versioned_info(entity) + if hasattr(entity, '__versioned__') ) def __get_versioned_info(self, entity): @@ -233,10 +238,11 @@ def __get_versioned_info(self, entity): if not table_args: return None if isinstance(table_args, Sequence): - table_args = next((x for x in iter(table_args) if isinstance(x, dict)), None) + table_args = next( + (x for x in iter(table_args) if isinstance(x, dict)), None) if not table_args: return None - return table_args.get("info", {}).get("versioned", None) + return table_args.get('info', {}).get('versioned', None) def before_flush(self, session, flush_context, instances): if session.transaction in self._marked_transactions: @@ -257,6 +263,7 @@ def remove_listeners(self): for listener in self.listeners: sa.event.remove(*listener) + class BasicVersioningManager(object): _actor_cls = None _session_manager_factory = partial(SessionManager, values={}) @@ -354,7 +361,9 @@ def init(self, base): base, self.transaction_cls ) - self.session_manager = self._session_manager_factory(self.transaction_cls) + self.session_manager = self._session_manager_factory( + self.transaction_cls + ) self.attach_listeners() @@ -391,18 +400,21 @@ def __init__( def get_table_listeners(self): listeners = {'transaction': []} - listeners['activity'] = [ - ('after_create', sa.schema.DDL( + listeners['activity'] = [( + 'after_create', sa.schema.DDL( render_tmpl('jsonb_change_key_name.sql', self.schema_name) - )), - ('after_create', partial( - create_audit_table, - schema_name=self.schema_name, - use_statement_level_triggers=self.use_statement_level_triggers - ) - ), - ('after_create', partial(create_operators, schema_name=self.schema_name)) - ] + ) + ), ( + 'after_create', partial( + create_audit_table, + schema_name=self.schema_name, + use_statement_level_triggers=self.use_statement_level_triggers + ) + ), ( + 'after_create', partial( + create_operators, schema_name=self.schema_name + ) + )] if self.schema_name is not None: listeners['transaction'] = [ ('before_create', sa.schema.DDL( diff --git a/postgresql_audit/flask.py b/postgresql_audit/flask.py index 6ab23f8..6049d3a 100644 --- a/postgresql_audit/flask.py +++ b/postgresql_audit/flask.py @@ -6,7 +6,7 @@ from flask import g, request from flask.globals import _app_ctx_stack, _request_ctx_stack -from .base import VersioningManager, SessionManager +from .base import SessionManager, VersioningManager class FlaskSessionManager(SessionManager): @@ -77,4 +77,6 @@ def activity_values(**values): g.activity_values = previous_value -versioning_manager = VersioningManager(actor_cls="User", session_manager_factory=FlaskSessionManager) +versioning_manager = VersioningManager( + actor_cls='User', session_manager_factory=FlaskSessionManager +) diff --git a/postgresql_audit/utils.py b/postgresql_audit/utils.py index 127e694..d049c4c 100644 --- a/postgresql_audit/utils.py +++ b/postgresql_audit/utils.py @@ -1,5 +1,6 @@ import os import string + import sqlalchemy as sa HERE = os.path.dirname(os.path.abspath(__file__)) @@ -14,11 +15,13 @@ def __call__(self, target, bind, **kwargs): bind.execute(self.stmt) tx.commit() + def read_file(file_): with open(os.path.join(HERE, file_)) as f: s = f.read() return s + def render_tmpl(tmpl_name, schema_name=None): file_contents = read_file( 'templates/{}'.format(tmpl_name) @@ -53,7 +56,10 @@ def create_operators(target, bind, schema_name, **kwargs): operators_template = render_tmpl('operators.sql', schema_name) StatementExecutor(operators_template)(target, bind, **kwargs) -def create_audit_table(target, bind, schema_name, use_statement_level_triggers, **kwargs): + +def create_audit_table( + target, bind, schema_name, use_statement_level_triggers, **kwargs +): sql = '' if ( use_statement_level_triggers and diff --git a/tests/test_custom_schema.py b/tests/test_custom_schema.py index 4e17bdf..d167ace 100644 --- a/tests/test_custom_schema.py +++ b/tests/test_custom_schema.py @@ -98,7 +98,7 @@ def test_manager_defaults( versioning_manager, activity_cls ): - versioning_manager.values = {'actor_id': 1} + versioning_manager.session_manager.values = {'actor_id': 1} user = user_class(name='John') session.add(user) session.commit() @@ -112,7 +112,7 @@ def test_callables_as_manager_defaults( versioning_manager, activity_cls ): - versioning_manager.values = {'actor_id': lambda: 1} + versioning_manager.session_manager.values = {'actor_id': lambda: 1} user = user_class(name='John') session.add(user) session.commit() @@ -126,8 +126,8 @@ def test_raw_inserts( versioning_manager, activity_cls ): - versioning_manager.values = {'actor_id': 1} - versioning_manager.set_activity_values(session) + versioning_manager.session_manager.values = {'actor_id': 1} + versioning_manager.session_manager.set_activity_values(session) session.execute(user_class.__table__.insert().values(name='John')) session.execute(user_class.__table__.insert().values(name='John')) activity = session.query(activity_cls).first() diff --git a/tests/test_flask_integration.py b/tests/test_flask_integration.py index 3721854..9c175ed 100644 --- a/tests/test_flask_integration.py +++ b/tests/test_flask_integration.py @@ -7,7 +7,11 @@ from flask_sqlalchemy import SQLAlchemy import sqlalchemy as sa -from postgresql_audit.flask import activity_values, VersioningManager +from postgresql_audit.flask import ( + activity_values, + FlaskSessionManager, + VersioningManager +) def login(client, user): @@ -77,7 +81,9 @@ def test_simple_flush(): @pytest.yield_fixture def versioning_manager(db): - vm = VersioningManager() + vm = VersioningManager( + actor_cls="User", session_manager_factory=FlaskSessionManager + ) vm.init(db.Model) yield vm vm.remove_listeners() diff --git a/tests/test_sqlalchemy_integration.py b/tests/test_sqlalchemy_integration.py index 080a66c..58dc7b0 100644 --- a/tests/test_sqlalchemy_integration.py +++ b/tests/test_sqlalchemy_integration.py @@ -73,7 +73,7 @@ def test_manager_defaults( session, versioning_manager ): - versioning_manager.values = {'actor_id': 1} + versioning_manager.session_manager.values = {'actor_id': 1} user = user_class(name='John') session.add(user) session.commit() @@ -86,7 +86,7 @@ def test_callables_as_manager_defaults( session, versioning_manager ): - versioning_manager.values = {'actor_id': lambda: 1} + versioning_manager.session_manager.values = {'actor_id': lambda: 1} user = user_class(name='John') session.add(user) session.commit() @@ -100,8 +100,8 @@ def test_raw_inserts( versioning_manager, activity_cls ): - versioning_manager.values = {'actor_id': 1} - versioning_manager.set_activity_values(session) + versioning_manager.session_manager.values = {'actor_id': 1} + versioning_manager.session_manager.set_activity_values(session) session.execute(user_class.__table__.insert().values(name='John')) session.execute(user_class.__table__.insert().values(name='John')) @@ -208,7 +208,9 @@ def test_multiple_flush_within_same_transaction( session, versioning_manager ): - versioning_manager.values = {'client_addr': '127.0.0.1'} + versioning_manager.session_manager.values = { + 'client_addr': '127.0.0.1' + } user = user_class(name='Jack') session.add(user) session.flush() @@ -336,7 +338,7 @@ def test_class_with_synonyms( ): article = article_class(name='Someone', _created_at=datetime.now()) session.add(article) - assert versioning_manager.is_modified(article) + assert versioning_manager.session_manager.is_modified(article) def test_modified_transient_object( self, @@ -346,8 +348,8 @@ def test_modified_transient_object( ): article = article_class(name='Article 1') session.add(article) - assert versioning_manager.is_modified(article) - assert versioning_manager.is_modified(session) + assert versioning_manager.session_manager.is_modified(article) + assert versioning_manager.session_manager.is_modified(session) def test_modified_excluded_column_with_persistent_object( self, @@ -356,8 +358,8 @@ def test_modified_excluded_column_with_persistent_object( session ): article.updated_at = datetime.now() - assert not versioning_manager.is_modified(article) - assert not versioning_manager.is_modified(session) + assert not versioning_manager.session_manager.is_modified(article) + assert not versioning_manager.session_manager.is_modified(session) def test_modified_persistent_object( self, @@ -366,8 +368,8 @@ def test_modified_persistent_object( session ): article.name = 'Article updated' - assert versioning_manager.is_modified(article) - assert versioning_manager.is_modified(session) + assert versioning_manager.session_manager.is_modified(article) + assert versioning_manager.session_manager.is_modified(session) def test_modified_excluded_relationship_column( self, @@ -377,8 +379,8 @@ def test_modified_excluded_relationship_column( session ): article.creator = user_class(name='Someone') - assert not versioning_manager.is_modified(article) - assert not versioning_manager.is_modified(session) + assert not versioning_manager.session_manager.is_modified(article) + assert not versioning_manager.session_manager.is_modified(session) def test_modified_relationship( self, @@ -388,8 +390,8 @@ def test_modified_relationship( session ): article.author = user_class(name='Someone') - assert versioning_manager.is_modified(article) - assert versioning_manager.is_modified(session) + assert versioning_manager.session_manager.is_modified(article) + assert versioning_manager.session_manager.is_modified(session) def test_deleted_object( self, @@ -399,7 +401,7 @@ def test_deleted_object( session ): session.delete(article) - assert versioning_manager.is_modified(session) + assert versioning_manager.session_manager.is_modified(session) @pytest.mark.usefixtures('versioning_manager', 'table_creator') From f31c371dc1e3b2e41112ed1738f1a9a0ac61ee33 Mon Sep 17 00:00:00 2001 From: Jake Stewart Date: Tue, 31 Aug 2021 12:36:03 -0400 Subject: [PATCH 09/12] Keep versioning manager as it was. --- postgresql_audit/alembic/__init__.py | 68 ++-- .../alembic/init_activity_table_triggers.py | 18 +- .../register_table_for_version_tracking.py | 8 +- postgresql_audit/base.py | 342 +++++++----------- postgresql_audit/flask.py | 10 +- postgresql_audit/utils.py | 12 +- tests/test_custom_schema.py | 8 +- tests/test_flask_integration.py | 10 +- tests/test_sqlalchemy_integration.py | 36 +- 9 files changed, 217 insertions(+), 295 deletions(-) diff --git a/postgresql_audit/alembic/__init__.py b/postgresql_audit/alembic/__init__.py index 2dbb2a0..8785a47 100644 --- a/postgresql_audit/alembic/__init__.py +++ b/postgresql_audit/alembic/__init__.py @@ -1,5 +1,4 @@ import re -from itertools import groupby from alembic.autogenerate import comparators, rewriter from alembic.operations import ops @@ -21,41 +20,40 @@ @comparators.dispatch_for('schema') def compare_timestamp_schema(autogen_context, upgrade_ops, schemas): routines = set() - for sch in schemas: + for schema in schemas: schema_name = ( - autogen_context.dialect.default_schema_name if sch is None else sch + autogen_context.dialect.default_schema_name if schema is None + else schema ) routines.update([ - (sch, *row) for row in autogen_context.connection.execute( - 'select routine_name, routine_definition ' - 'from information_schema.routines ' - f"where routines.specific_schema='{schema_name}' " - ) + (schema, *row) for row in autogen_context.connection.execute(f''' + SELECT routine_name, routine_definition + FROM information_schema.routines + WHERE routines.specific_schema='{schema_name}' + ''') ]) - for sch in schemas: + for schema in schemas: should_track_versions = any( 'versioned' in table.info for table in autogen_context.sorted_tables - if table.info and table.schema == sch + if table.info and table.schema == schema ) - schema_prefix = f'{sch}.' if sch else '' + schema_prefix = f'{schema}.' if schema else '' + tracked = f'{schema_prefix}audit_table' in [ + routine[1] for routine in routines if routine[0] == schema + ] - a = next( - (v for k, v in groupby(routines, key=lambda x: x[0]) if k == sch), - None - ) - a = list(a) if a else [] if should_track_versions: - if f'{schema_prefix}audit_table' not in (x[1] for x in a): + if not tracked: upgrade_ops.ops.insert( 0, - InitActivityTableTriggersOp(False, schema=sch) + InitActivityTableTriggersOp(False, schema=schema) ) else: - if f'{schema_prefix}audit_table' in (x[1] for x in a): + if tracked: upgrade_ops.ops.append( - RemoveActivityTableTriggersOp(False, schema=sch) + RemoveActivityTableTriggersOp(False, schema=schema) ) @@ -76,21 +74,21 @@ def compare_timestamp_table( if schemaname is None else schemaname ) - triggers = [row for row in autogen_context.connection.execute( - 'select event_object_schema as table_schema,' - 'event_object_table as table_name,' - 'trigger_schema,' - 'trigger_name,' - "string_agg(event_manipulation, ',') as event," - 'action_timing as activation,' - 'action_condition as condition,' - 'action_statement as definition' - 'from information_schema.triggers' - f"where event_object_table = '{tablename}'" - f"and trigger_schema = '{schema_name}'" - 'group by 1,2,3,4,6,7,8' - 'order by table_schema, table_name;' - )] + triggers = [row for row in autogen_context.connection.execute(f''' + SELECT event_object_schema AS table_schema, + event_object_table AS table_name, + trigger_schema, + trigger_name, + STRING_AGG(event_manipulation, ',') AS event, + action_timing AS activation, + action_condition AS condition, + action_statement AS definition + FROM information_schema.triggers + WHERE event_object_table = '{tablename}' + AND trigger_schema = '{schema_name}' + GROUP BY 1,2,3,4,6,7,8 + ORDER BY table_schema, table_name + ''')] trigger_name = 'audit_trigger' diff --git a/postgresql_audit/alembic/init_activity_table_triggers.py b/postgresql_audit/alembic/init_activity_table_triggers.py index 47a5703..5a37549 100644 --- a/postgresql_audit/alembic/init_activity_table_triggers.py +++ b/postgresql_audit/alembic/init_activity_table_triggers.py @@ -75,16 +75,18 @@ def remove_activity_table_triggers(operations, operation): if operation.schema: conn.execute(render_tmpl('drop_schema.sql', operation.schema)) - conn.execute( - 'DROP FUNCTION jsonb_change_key_name(' - 'data jsonb, old_key text, new_key text)' - ) + conn.execute(''' + DROP FUNCTION jsonb_change_key_name( + data jsonb, old_key text, new_key text + ) + ''') schema_prefix = f'{operation.schema}.' if operation.schema else '' - conn.execute( - f'DROP FUNCTION {schema_prefix}audit_table(' - 'target_table regclass, ignored_cols text[])' - ) + conn.execute(f''' + DROP FUNCTION {schema_prefix}audit_table( + target_table regclass, ignored_cols text[] + ) + ''') conn.execute(f'DROP FUNCTION {schema_prefix}create_activity()') if bind.dialect.server_version_info < (9, 5, 0): diff --git a/postgresql_audit/alembic/register_table_for_version_tracking.py b/postgresql_audit/alembic/register_table_for_version_tracking.py index 4316310..daf4b50 100644 --- a/postgresql_audit/alembic/register_table_for_version_tracking.py +++ b/postgresql_audit/alembic/register_table_for_version_tracking.py @@ -72,16 +72,16 @@ def register_for_version_tracking(operations, operation): @Operations.implementation_for(DeregisterTableForVersionTrackingOp) def deregister_for_version_tracking(operations, operation): operations.execute( - f'drop trigger if exists audit_trigger_insert on {operation.tablename}' + f'DROP TRIGGER IF EXISTS audit_trigger_insert ON {operation.tablename}' ) operations.execute( - f'drop trigger if exists audit_trigger_update on {operation.tablename}' + f'DROP TRIGGER IF EXISTS audit_trigger_update ON {operation.tablename}' ) operations.execute( - f'drop trigger if exists audit_trigger_delete on {operation.tablename}' + f'DROP TRIGGER IF EXISTS audit_trigger_delete ON {operation.tablename}' ) operations.execute( - f'drop trigger if exists audit_trigger_row on {operation.tablename}' + f'DROP TRIGGER IF EXISTS audit_trigger_row ON {operation.tablename}' ) diff --git a/postgresql_audit/base.py b/postgresql_audit/base.py index 9bce022..a32b1ea 100644 --- a/postgresql_audit/base.py +++ b/postgresql_audit/base.py @@ -1,7 +1,5 @@ import warnings -from collections import Sequence from contextlib import contextmanager -from functools import partial from weakref import WeakSet import sqlalchemy as sa @@ -19,7 +17,6 @@ from sqlalchemy_utils import get_class_by_table from postgresql_audit.utils import ( - build_register_table_query, create_audit_table, create_operators, render_tmpl, @@ -155,22 +152,116 @@ def convert_callables(values): } -class SessionManager(object): - def __init__(self, transaction_cls, values=None): - self.transaction_cls = transaction_cls - self.values = values or {} - self._marked_transactions = set() +class VersioningManager(object): + _actor_cls = None + + def __init__( + self, + actor_cls=None, + schema_name=None, + use_statement_level_triggers=True + ): + if actor_cls is not None: + self._actor_cls = actor_cls + self.values = {} self.listeners = ( + ( + orm.mapper, + 'instrument_class', + self.instrument_versioned_classes + ), + ( + orm.mapper, + 'after_configured', + self.configure_versioned_classes + ), ( orm.session.Session, 'before_flush', - self.before_flush, + self.receive_before_flush, ), ) + self.schema_name = schema_name + self.table_listeners = self.get_table_listeners() + self.pending_classes = WeakSet() + self.cached_ddls = {} + self.use_statement_level_triggers = use_statement_level_triggers def get_transaction_values(self): return self.values + @contextmanager + def disable(self, session): + session.execute( + "SET LOCAL postgresql_audit.enable_versioning = 'false'" + ) + try: + yield + finally: + session.execute( + "SET LOCAL postgresql_audit.enable_versioning = 'true'" + ) + + def render_tmpl(self, tmpl_name): + return render_tmpl(tmpl_name, schema_name=self.schema_name) + + def create_operators(self, target, bind, **kwargs): + create_operators(target, bind, self.schema_name, **kwargs) + + def create_audit_table(self, target, bind, **kwargs): + create_audit_table( + target, + bind, + self.schema_name, + self.use_statement_level_triggers, + **kwargs + ) + + def get_table_listeners(self): + listeners = {'transaction': []} + + listeners['activity'] = [ + ('after_create', sa.schema.DDL( + self.render_tmpl('jsonb_change_key_name.sql') + )), + ('after_create', self.create_audit_table), + ('after_create', self.create_operators) + ] + if self.schema_name is not None: + listeners['transaction'] = [ + ('before_create', sa.schema.DDL( + self.render_tmpl('create_schema.sql') + )), + ('after_drop', sa.schema.DDL( + self.render_tmpl('drop_schema.sql') + )), + ] + return listeners + + def audit_table(self, table, exclude_columns=None): + args = [table.name] + if exclude_columns: + for column in exclude_columns: + if column not in table.c: + raise ImproperlyConfigured( + "Could not configure versioning. Table '{}'' does " + "not have a column named '{}'.".format( + table.name, column + ) + ) + args.append(array(exclude_columns)) + + if self.schema_name is None: + func = sa.func.audit_table + else: + func = getattr(getattr(sa.func, self.schema_name), 'audit_table') + query = sa.select([func(*args)]) + if query not in cached_statements: + cached_statements[query] = StatementExecutor(query) + listener = (table, 'after_create', cached_statements[query]) + if not sa.event.contains(*listener): + sa.event.listen(*listener) + def set_activity_values(self, session): dialect = session.bind.engine.dialect table = self.transaction_cls.__table__ @@ -230,65 +321,40 @@ def is_modified(self, obj_or_session): if hasattr(entity, '__versioned__') ) - def __get_versioned_info(self, entity): - v_args = getattr(entity, '__versioned__', None) - if v_args: - return v_args - table_args = getattr(entity, '__table_args__', None) - if not table_args: - return None - if isinstance(table_args, Sequence): - table_args = next( - (x for x in iter(table_args) if isinstance(x, dict)), None) - if not table_args: - return None - return table_args.get('info', {}).get('versioned', None) - - def before_flush(self, session, flush_context, instances): - if session.transaction in self._marked_transactions: - return - if session.transaction: - self.add_entry_and_mark_transaction(session) - - def add_entry_and_mark_transaction(self, session): + def receive_before_flush(self, session, flush_context, instances): if self.is_modified(session): - self._marked_transactions.add(session.transaction) self.set_activity_values(session) - def attach_listeners(self): - for listener in self.listeners: - sa.event.listen(*listener) + def instrument_versioned_classes(self, mapper, cls): + """ + Collect versioned class and add it to pending_classes list. - def remove_listeners(self): - for listener in self.listeners: - sa.event.remove(*listener) + :mapper mapper: SQLAlchemy mapper object + :cls cls: SQLAlchemy declarative class + """ + if hasattr(cls, '__versioned__') and cls not in self.pending_classes: + self.pending_classes.add(cls) + def configure_versioned_classes(self): + """ + Configures all versioned classes that were collected during + instrumentation process. + """ + for cls in self.pending_classes: + self.audit_table(cls.__table__, cls.__versioned__.get('exclude')) + assign_actor(self.base, self.transaction_cls, self.actor_cls) -class BasicVersioningManager(object): - _actor_cls = None - _session_manager_factory = partial(SessionManager, values={}) + def attach_table_listeners(self): + for values in self.table_listeners['transaction']: + sa.event.listen(self.transaction_cls.__table__, *values) + for values in self.table_listeners['activity']: + sa.event.listen(self.activity_cls.__table__, *values) - def __init__( - self, - actor_cls=None, - session_manager_factory=None, - schema_name=None, - use_statement_level_triggers=True - ): - if actor_cls is not None: - self._actor_cls = actor_cls - if session_manager_factory is not None: - self._session_manager_factory = session_manager_factory - self.values = {} - self.listeners = ( - ( - orm.mapper, - 'after_configured', - self.after_configured - ), - ) - self.schema_name = schema_name - self.use_statement_level_triggers = use_statement_level_triggers + def remove_table_listeners(self): + for values in self.table_listeners['transaction']: + sa.event.remove(self.transaction_cls.__table__, *values) + for values in self.table_listeners['activity']: + sa.event.remove(self.activity_cls.__table__, *values) @property def actor_cls(self): @@ -317,8 +383,15 @@ def actor_cls(self): ) return self._actor_cls - def after_configured(self): - assign_actor(self.base, self.transaction_cls, self.actor_cls) + def attach_listeners(self): + self.attach_table_listeners() + for listener in self.listeners: + sa.event.listen(*listener) + + def remove_listeners(self): + self.remove_table_listeners() + for listener in self.listeners: + sa.event.remove(*listener) def activity_model_factory(self, base, transaction_cls): class Activity(activity_base(base, self.schema_name, transaction_cls)): @@ -332,28 +405,6 @@ class Transaction(transaction_base(base, self.schema_name)): return Transaction - def attach_listeners(self): - for listener in self.listeners: - sa.event.listen(*listener) - self.session_manager.attach_listeners() - - def remove_listeners(self): - for listener in self.listeners: - sa.event.remove(*listener) - self.session_manager.remove_listeners() - - @contextmanager - def disable(self, session): - session.execute( - "SET LOCAL postgresql_audit.enable_versioning = 'false'" - ) - try: - yield - finally: - session.execute( - "SET LOCAL postgresql_audit.enable_versioning = 'true'" - ) - def init(self, base): self.base = base self.transaction_cls = self.transaction_model_factory(base) @@ -361,128 +412,7 @@ def init(self, base): base, self.transaction_cls ) - self.session_manager = self._session_manager_factory( - self.transaction_cls - ) self.attach_listeners() -class VersioningManager(BasicVersioningManager): - def __init__( - self, - actor_cls=None, - session_manager_factory=None, - schema_name=None, - use_statement_level_triggers=True - ): - super().__init__( - actor_cls=actor_cls, - schema_name=schema_name, - use_statement_level_triggers=use_statement_level_triggers, - session_manager_factory=session_manager_factory - ) - self.listeners = ( - ( - orm.mapper, - 'instrument_class', - self.instrument_versioned_classes - ), - ( - orm.mapper, - 'after_configured', - self.configure_versioned_classes - ), - ) - self.table_listeners = self.get_table_listeners() - self.pending_classes = WeakSet() - self.cached_ddls = {} - - def get_table_listeners(self): - listeners = {'transaction': []} - - listeners['activity'] = [( - 'after_create', sa.schema.DDL( - render_tmpl('jsonb_change_key_name.sql', self.schema_name) - ) - ), ( - 'after_create', partial( - create_audit_table, - schema_name=self.schema_name, - use_statement_level_triggers=self.use_statement_level_triggers - ) - ), ( - 'after_create', partial( - create_operators, schema_name=self.schema_name - ) - )] - if self.schema_name is not None: - listeners['transaction'] = [ - ('before_create', sa.schema.DDL( - render_tmpl('create_schema.sql', self.schema_name) - )), - ('after_drop', sa.schema.DDL( - render_tmpl('drop_schema.sql', self.schema_name) - )), - ] - return listeners - - def audit_table(self, table, exclude_columns=None): - args = [table.name] - if exclude_columns: - for column in exclude_columns: - if column not in table.c: - raise ImproperlyConfigured( - "Could not configure versioning. Table '{}'' does " - "not have a column named '{}'.".format( - table.name, column - ) - ) - args.append(array(exclude_columns)) - query = build_register_table_query(self.schema_name, *args) - if query not in cached_statements: - cached_statements[query] = StatementExecutor(query) - listener = (table, 'after_create', cached_statements[query]) - if not sa.event.contains(*listener): - sa.event.listen(*listener) - - def instrument_versioned_classes(self, mapper, cls): - """ - Collect versioned class and add it to pending_classes list. - - :mapper mapper: SQLAlchemy mapper object - :cls cls: SQLAlchemy declarative class - """ - if hasattr(cls, '__versioned__') and cls not in self.pending_classes: - self.pending_classes.add(cls) - - def configure_versioned_classes(self): - """ - Configures all versioned classes that were collected during - instrumentation process. - """ - for cls in self.pending_classes: - self.audit_table(cls.__table__, cls.__versioned__.get('exclude')) - assign_actor(self.base, self.transaction_cls, self.actor_cls) - - def attach_table_listeners(self): - for values in self.table_listeners['transaction']: - sa.event.listen(self.transaction_cls.__table__, *values) - for values in self.table_listeners['activity']: - sa.event.listen(self.activity_cls.__table__, *values) - - def remove_table_listeners(self): - for values in self.table_listeners['transaction']: - sa.event.remove(self.transaction_cls.__table__, *values) - for values in self.table_listeners['activity']: - sa.event.remove(self.activity_cls.__table__, *values) - - def attach_listeners(self): - self.attach_table_listeners() - super().attach_listeners() - - def remove_listeners(self): - self.remove_table_listeners() - super().remove_listeners() - - versioning_manager = VersioningManager() diff --git a/postgresql_audit/flask.py b/postgresql_audit/flask.py index 6049d3a..3ca1437 100644 --- a/postgresql_audit/flask.py +++ b/postgresql_audit/flask.py @@ -6,10 +6,12 @@ from flask import g, request from flask.globals import _app_ctx_stack, _request_ctx_stack -from .base import SessionManager, VersioningManager +from .base import VersioningManager as BaseVersioningManager -class FlaskSessionManager(SessionManager): +class VersioningManager(BaseVersioningManager): + _actor_cls = 'User' + def get_transaction_values(self): values = copy(self.values) if context_available() and hasattr(g, 'activity_values'): @@ -77,6 +79,4 @@ def activity_values(**values): g.activity_values = previous_value -versioning_manager = VersioningManager( - actor_cls='User', session_manager_factory=FlaskSessionManager -) +versioning_manager = VersioningManager() diff --git a/postgresql_audit/utils.py b/postgresql_audit/utils.py index d049c4c..2df286e 100644 --- a/postgresql_audit/utils.py +++ b/postgresql_audit/utils.py @@ -7,19 +7,19 @@ class StatementExecutor(object): - def __init__(self, stmt): - self.stmt = stmt + def __init__(self, statement): + self.statement = statement def __call__(self, target, bind, **kwargs): tx = bind.begin() - bind.execute(self.stmt) + bind.execute(self.statement) tx.commit() def read_file(file_): - with open(os.path.join(HERE, file_)) as f: - s = f.read() - return s + with open(os.path.join(HERE, file_)) as file: + data = file.read() + return data def render_tmpl(tmpl_name, schema_name=None): diff --git a/tests/test_custom_schema.py b/tests/test_custom_schema.py index d167ace..4e17bdf 100644 --- a/tests/test_custom_schema.py +++ b/tests/test_custom_schema.py @@ -98,7 +98,7 @@ def test_manager_defaults( versioning_manager, activity_cls ): - versioning_manager.session_manager.values = {'actor_id': 1} + versioning_manager.values = {'actor_id': 1} user = user_class(name='John') session.add(user) session.commit() @@ -112,7 +112,7 @@ def test_callables_as_manager_defaults( versioning_manager, activity_cls ): - versioning_manager.session_manager.values = {'actor_id': lambda: 1} + versioning_manager.values = {'actor_id': lambda: 1} user = user_class(name='John') session.add(user) session.commit() @@ -126,8 +126,8 @@ def test_raw_inserts( versioning_manager, activity_cls ): - versioning_manager.session_manager.values = {'actor_id': 1} - versioning_manager.session_manager.set_activity_values(session) + versioning_manager.values = {'actor_id': 1} + versioning_manager.set_activity_values(session) session.execute(user_class.__table__.insert().values(name='John')) session.execute(user_class.__table__.insert().values(name='John')) activity = session.query(activity_cls).first() diff --git a/tests/test_flask_integration.py b/tests/test_flask_integration.py index 9c175ed..3721854 100644 --- a/tests/test_flask_integration.py +++ b/tests/test_flask_integration.py @@ -7,11 +7,7 @@ from flask_sqlalchemy import SQLAlchemy import sqlalchemy as sa -from postgresql_audit.flask import ( - activity_values, - FlaskSessionManager, - VersioningManager -) +from postgresql_audit.flask import activity_values, VersioningManager def login(client, user): @@ -81,9 +77,7 @@ def test_simple_flush(): @pytest.yield_fixture def versioning_manager(db): - vm = VersioningManager( - actor_cls="User", session_manager_factory=FlaskSessionManager - ) + vm = VersioningManager() vm.init(db.Model) yield vm vm.remove_listeners() diff --git a/tests/test_sqlalchemy_integration.py b/tests/test_sqlalchemy_integration.py index 58dc7b0..080a66c 100644 --- a/tests/test_sqlalchemy_integration.py +++ b/tests/test_sqlalchemy_integration.py @@ -73,7 +73,7 @@ def test_manager_defaults( session, versioning_manager ): - versioning_manager.session_manager.values = {'actor_id': 1} + versioning_manager.values = {'actor_id': 1} user = user_class(name='John') session.add(user) session.commit() @@ -86,7 +86,7 @@ def test_callables_as_manager_defaults( session, versioning_manager ): - versioning_manager.session_manager.values = {'actor_id': lambda: 1} + versioning_manager.values = {'actor_id': lambda: 1} user = user_class(name='John') session.add(user) session.commit() @@ -100,8 +100,8 @@ def test_raw_inserts( versioning_manager, activity_cls ): - versioning_manager.session_manager.values = {'actor_id': 1} - versioning_manager.session_manager.set_activity_values(session) + versioning_manager.values = {'actor_id': 1} + versioning_manager.set_activity_values(session) session.execute(user_class.__table__.insert().values(name='John')) session.execute(user_class.__table__.insert().values(name='John')) @@ -208,9 +208,7 @@ def test_multiple_flush_within_same_transaction( session, versioning_manager ): - versioning_manager.session_manager.values = { - 'client_addr': '127.0.0.1' - } + versioning_manager.values = {'client_addr': '127.0.0.1'} user = user_class(name='Jack') session.add(user) session.flush() @@ -338,7 +336,7 @@ def test_class_with_synonyms( ): article = article_class(name='Someone', _created_at=datetime.now()) session.add(article) - assert versioning_manager.session_manager.is_modified(article) + assert versioning_manager.is_modified(article) def test_modified_transient_object( self, @@ -348,8 +346,8 @@ def test_modified_transient_object( ): article = article_class(name='Article 1') session.add(article) - assert versioning_manager.session_manager.is_modified(article) - assert versioning_manager.session_manager.is_modified(session) + assert versioning_manager.is_modified(article) + assert versioning_manager.is_modified(session) def test_modified_excluded_column_with_persistent_object( self, @@ -358,8 +356,8 @@ def test_modified_excluded_column_with_persistent_object( session ): article.updated_at = datetime.now() - assert not versioning_manager.session_manager.is_modified(article) - assert not versioning_manager.session_manager.is_modified(session) + assert not versioning_manager.is_modified(article) + assert not versioning_manager.is_modified(session) def test_modified_persistent_object( self, @@ -368,8 +366,8 @@ def test_modified_persistent_object( session ): article.name = 'Article updated' - assert versioning_manager.session_manager.is_modified(article) - assert versioning_manager.session_manager.is_modified(session) + assert versioning_manager.is_modified(article) + assert versioning_manager.is_modified(session) def test_modified_excluded_relationship_column( self, @@ -379,8 +377,8 @@ def test_modified_excluded_relationship_column( session ): article.creator = user_class(name='Someone') - assert not versioning_manager.session_manager.is_modified(article) - assert not versioning_manager.session_manager.is_modified(session) + assert not versioning_manager.is_modified(article) + assert not versioning_manager.is_modified(session) def test_modified_relationship( self, @@ -390,8 +388,8 @@ def test_modified_relationship( session ): article.author = user_class(name='Someone') - assert versioning_manager.session_manager.is_modified(article) - assert versioning_manager.session_manager.is_modified(session) + assert versioning_manager.is_modified(article) + assert versioning_manager.is_modified(session) def test_deleted_object( self, @@ -401,7 +399,7 @@ def test_deleted_object( session ): session.delete(article) - assert versioning_manager.session_manager.is_modified(session) + assert versioning_manager.is_modified(session) @pytest.mark.usefixtures('versioning_manager', 'table_creator') From 3a7ae2226a06a8307e07c46df4be3361792d8219 Mon Sep 17 00:00:00 2001 From: Jake Stewart Date: Tue, 31 Aug 2021 14:22:48 -0400 Subject: [PATCH 10/12] Leave trailing whitespace --- postgresql_audit/alembic/__init__.py | 30 ++++++++++++++-------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/postgresql_audit/alembic/__init__.py b/postgresql_audit/alembic/__init__.py index 8785a47..6101b42 100644 --- a/postgresql_audit/alembic/__init__.py +++ b/postgresql_audit/alembic/__init__.py @@ -74,21 +74,21 @@ def compare_timestamp_table( if schemaname is None else schemaname ) - triggers = [row for row in autogen_context.connection.execute(f''' - SELECT event_object_schema AS table_schema, - event_object_table AS table_name, - trigger_schema, - trigger_name, - STRING_AGG(event_manipulation, ',') AS event, - action_timing AS activation, - action_condition AS condition, - action_statement AS definition - FROM information_schema.triggers - WHERE event_object_table = '{tablename}' - AND trigger_schema = '{schema_name}' - GROUP BY 1,2,3,4,6,7,8 - ORDER BY table_schema, table_name - ''')] + triggers = [row for row in autogen_context.connection.execute( + 'SELECT event_object_schema AS table_schema,' + 'event_object_table AS table_name,' + 'trigger_schema,' + 'trigger_name,' + 'STRING_AGG(event_manipulation, ',') AS event,' + 'action_timing AS activation,' + 'action_condition AS condition,' + 'action_statement AS definition ' + 'FROM information_schema.triggers ' + f"WHERE event_object_table = '{tablename}' " + f"AND trigger_schema = '{schema_name}' " + 'GROUP BY 1,2,3,4,6,7,8 ' + 'ORDER BY table_schema, table_name' + )] trigger_name = 'audit_trigger' From a1f623d8024be9a823cbe92aab79f422097bd08f Mon Sep 17 00:00:00 2001 From: Jake Stewart Date: Tue, 31 Aug 2021 14:35:29 -0400 Subject: [PATCH 11/12] Wrap single quotes string with double quotes --- postgresql_audit/alembic/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/postgresql_audit/alembic/__init__.py b/postgresql_audit/alembic/__init__.py index 6101b42..3406fb5 100644 --- a/postgresql_audit/alembic/__init__.py +++ b/postgresql_audit/alembic/__init__.py @@ -79,7 +79,7 @@ def compare_timestamp_table( 'event_object_table AS table_name,' 'trigger_schema,' 'trigger_name,' - 'STRING_AGG(event_manipulation, ',') AS event,' + "STRING_AGG(event_manipulation, ',') AS event," 'action_timing AS activation,' 'action_condition AS condition,' 'action_statement AS definition ' From b0e5ce761b5dbec321462596eaddebeb59d9d05c Mon Sep 17 00:00:00 2001 From: Jake Stewart Date: Wed, 1 Sep 2021 12:46:24 -0400 Subject: [PATCH 12/12] Get Column with to_column --- postgresql_audit/alembic/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/postgresql_audit/alembic/__init__.py b/postgresql_audit/alembic/__init__.py index 3406fb5..12cbaa6 100644 --- a/postgresql_audit/alembic/__init__.py +++ b/postgresql_audit/alembic/__init__.py @@ -167,7 +167,7 @@ def add_column_rewrite(context, revision, op): @writer.rewrites(ops.DropColumnOp) def drop_column_rewrite(context, revision, op): - column = op._orig_column + column = op.to_column() table_info = column.table.info or {} if ( 'versioned' in table_info