diff --git a/postgresql_audit/alembic/__init__.py b/postgresql_audit/alembic/__init__.py new file mode 100644 index 0000000..12cbaa6 --- /dev/null +++ b/postgresql_audit/alembic/__init__.py @@ -0,0 +1,185 @@ +import re + +from alembic.autogenerate import comparators, rewriter +from alembic.operations import ops + +from postgresql_audit.alembic.init_activity_table_triggers import ( + InitActivityTableTriggersOp, + RemoveActivityTableTriggersOp +) +from postgresql_audit.alembic.migration_ops import ( + AddColumnToActivityOp, + RemoveColumnFromRemoveActivityOp +) +from postgresql_audit.alembic.register_table_for_version_tracking import ( + DeregisterTableForVersionTrackingOp, + RegisterTableForVersionTrackingOp +) + + +@comparators.dispatch_for('schema') +def compare_timestamp_schema(autogen_context, upgrade_ops, schemas): + routines = set() + for schema in schemas: + schema_name = ( + autogen_context.dialect.default_schema_name if schema is None + else schema + ) + routines.update([ + (schema, *row) for row in autogen_context.connection.execute(f''' + SELECT routine_name, routine_definition + FROM information_schema.routines + WHERE routines.specific_schema='{schema_name}' + ''') + ]) + + for schema in schemas: + should_track_versions = any( + 'versioned' in table.info + for table in autogen_context.sorted_tables + if table.info and table.schema == schema + ) + schema_prefix = f'{schema}.' if schema else '' + tracked = f'{schema_prefix}audit_table' in [ + routine[1] for routine in routines if routine[0] == schema + ] + + if should_track_versions: + if not tracked: + upgrade_ops.ops.insert( + 0, + InitActivityTableTriggersOp(False, schema=schema) + ) + else: + if tracked: + upgrade_ops.ops.append( + RemoveActivityTableTriggersOp(False, schema=schema) + ) + + +@comparators.dispatch_for('table') +def compare_timestamp_table( + autogen_context, + modify_ops, + schemaname, + tablename, + conn_table, + metadata_table +): + if metadata_table is None: + return + meta_info = metadata_table.info or {} + schema_name = ( + autogen_context.dialect.default_schema_name + if schemaname is None else schemaname + ) + + triggers = [row for row in autogen_context.connection.execute( + 'SELECT event_object_schema AS table_schema,' + 'event_object_table AS table_name,' + 'trigger_schema,' + 'trigger_name,' + "STRING_AGG(event_manipulation, ',') AS event," + 'action_timing AS activation,' + 'action_condition AS condition,' + 'action_statement AS definition ' + 'FROM information_schema.triggers ' + f"WHERE event_object_table = '{tablename}' " + f"AND trigger_schema = '{schema_name}' " + 'GROUP BY 1,2,3,4,6,7,8 ' + 'ORDER BY table_schema, table_name' + )] + + trigger_name = 'audit_trigger' + + if 'versioned' in meta_info: + excluded_columns = ( + metadata_table.info['versioned'].get('exclude', tuple()) + ) + trigger = next( + (trigger for trigger in triggers if trigger_name in trigger[3]), + None + ) + original_excluded_columns = __get_existing_excluded_columns(trigger) + + if trigger and set(original_excluded_columns) == set(excluded_columns): + return + + modify_ops.ops.insert( + 0, + RegisterTableForVersionTrackingOp( + tablename, + excluded_columns, + original_excluded_columns, + schema=schema_name + ) + ) + else: + trigger = next( + (trigger for trigger in triggers if trigger_name in trigger[3]), + None + ) + original_excluded_columns = __get_existing_excluded_columns(trigger) + + if trigger: + modify_ops.ops.append( + DeregisterTableForVersionTrackingOp( + tablename, + original_excluded_columns, + schema=schema_name + ) + ) + + +def __get_existing_excluded_columns(trigger): + original_excluded_columns = () + if trigger: + arguments_match = re.search( + r"EXECUTE FUNCTION create_activity\('{(.+)}'\)", + trigger[7] + ) + if arguments_match: + original_excluded_columns = arguments_match.group(1).split(',') + return original_excluded_columns + + +writer = rewriter.Rewriter() + + +@writer.rewrites(ops.AddColumnOp) +def add_column_rewrite(context, revision, op): + table_info = op.column.table.info or {} + if ( + 'versioned' in table_info + and op.column.name not in table_info['versioned'].get('exclude', []) + ): + return [ + op, + AddColumnToActivityOp( + op.table_name, + op.column.name, + schema=op.column.table.schema, + ), + ] + else: + return op + + +@writer.rewrites(ops.DropColumnOp) +def drop_column_rewrite(context, revision, op): + column = op.to_column() + table_info = column.table.info or {} + if ( + 'versioned' in table_info + and column.name not in table_info['versioned'].get('exclude', []) + ): + return [ + op, + RemoveColumnFromRemoveActivityOp( + op.table_name, + column.name, + schema=column.table.schema, + ), + ] + else: + return op diff --git a/postgresql_audit/alembic/init_activity_table_triggers.py b/postgresql_audit/alembic/init_activity_table_triggers.py new file mode 100644 index 0000000..5a37549 --- /dev/null +++ b/postgresql_audit/alembic/init_activity_table_triggers.py @@ -0,0 +1,120 @@ +from alembic.autogenerate import renderers +from alembic.operations import MigrateOperation, Operations + +from postgresql_audit.utils import ( + create_audit_table, + create_operators, + render_tmpl +) + + +@Operations.register_operation('init_activity_table_triggers') +class InitActivityTableTriggersOp(MigrateOperation): + """Initialize Activity Table Triggers""" + + def __init__(self, use_statement_level_triggers, schema=None): + self.schema = schema + self.use_statement_level_triggers = use_statement_level_triggers + + @classmethod + def init_activity_table_triggers( + cls, operations, use_statement_level_triggers, **kwargs + ): + op = InitActivityTableTriggersOp( + use_statement_level_triggers, **kwargs + ) + return operations.invoke(op) + + def reverse(self): + # only needed to support autogenerate + return RemoveActivityTableTriggersOp( + self.use_statement_level_triggers, schema=self.schema + ) + + +@Operations.register_operation('remove_activity_table_triggers') +class RemoveActivityTableTriggersOp(MigrateOperation): + """Drop Activity Table Triggers""" + + def __init__(self, use_statement_level_triggers, schema=None): + self.schema = schema + self.use_statement_level_triggers = use_statement_level_triggers + + @classmethod + def remove_activity_table_triggers(cls, operations, **kwargs): + op = RemoveActivityTableTriggersOp(False, **kwargs) + return operations.invoke(op) + + def reverse(self): + # only needed to support autogenerate + return InitActivityTableTriggersOp( + self.use_statement_level_triggers, schema=self.schema + ) + + +@Operations.implementation_for(InitActivityTableTriggersOp) +def init_activity_table_triggers(operations, operation): + conn = operations + bind = conn.get_bind() + + if operation.schema: + conn.execute(render_tmpl('create_schema.sql', operation.schema)) + + conn.execute(render_tmpl('jsonb_change_key_name.sql', operation.schema)) + create_audit_table( + None, bind, operation.schema, operation.use_statement_level_triggers + ) + create_operators(None, bind, operation.schema) + + +@Operations.implementation_for(RemoveActivityTableTriggersOp) +def remove_activity_table_triggers(operations, operation): + conn = operations + bind = conn.get_bind() + + if operation.schema: + conn.execute(render_tmpl('drop_schema.sql', operation.schema)) + + conn.execute(''' + DROP FUNCTION jsonb_change_key_name( + data jsonb, old_key text, new_key text + ) + ''') + schema_prefix = f'{operation.schema}.' if operation.schema else '' + + conn.execute(f''' + DROP FUNCTION {schema_prefix}audit_table( + target_table regclass, ignored_cols text[] + ) + ''') + conn.execute(f'DROP FUNCTION {schema_prefix}create_activity()') + + if bind.dialect.server_version_info < (9, 5, 0): + conn.execute('DROP FUNCTION jsonb_subtract(jsonb, TEXT)') + conn.execute('DROP OPERATOR IF EXISTS - (jsonb, text);') + conn.execute('DROP FUNCTION jsonb_merge(jsonb, jsonb)') + conn.execute('DROP OPERATOR IF EXISTS || (jsonb, jsonb);') + if bind.dialect.server_version_info < (9, 6, 0): + conn.execute('DROP FUNCTION current_setting(TEXT, BOOL)') + if bind.dialect.server_version_info < (10, 0): + conn.execute('DROP FUNCTION jsonb_subtract(jsonb, TEXT[])') + conn.execute('DROP OPERATOR IF EXISTS - (jsonb, text[])') + + conn.execute('DROP OPERATOR IF EXISTS - (jsonb, jsonb)') + conn.execute('DROP FUNCTION get_setting(text, text)') + conn.execute('DROP FUNCTION jsonb_subtract(jsonb,jsonb)') + + +@renderers.dispatch_for(InitActivityTableTriggersOp) +def render_init_activity_table_triggers(autogen_context, op): + return 'op.init_activity_table_triggers(%r, **%r)' % ( + op.use_statement_level_triggers, + {'schema': op.schema} + ) + + +@renderers.dispatch_for(RemoveActivityTableTriggersOp) +def render_remove_activity_table_triggers(autogen_context, op): + return 'op.remove_activity_table_triggers(**%r)' % ( + {'schema': op.schema} + ) diff --git a/postgresql_audit/alembic/migration_ops.py b/postgresql_audit/alembic/migration_ops.py new file mode 100644 index 0000000..29f0c73 --- /dev/null +++ b/postgresql_audit/alembic/migration_ops.py @@ -0,0 +1,104 @@ +from alembic.autogenerate import renderers +from alembic.operations import MigrateOperation, Operations + +from postgresql_audit import add_column, remove_column + + +@Operations.register_operation('add_column_to_activity') +class AddColumnToActivityOp(MigrateOperation): + """Initialize Activity Table Triggers""" + + def __init__( + self, table_name, column_name, default_value=None, schema=None + ): + self.schema = schema + self.table_name = table_name + self.column_name = column_name + self.default_value = default_value + + @classmethod + def add_column_to_activity( + cls, operations, table_name, column_name, **kwargs + ): + op = AddColumnToActivityOp(table_name, column_name, **kwargs) + return operations.invoke(op) + + def reverse(self): + # only needed to support autogenerate + return RemoveColumnFromRemoveActivityOp( + self.table_name, + self.column_name, + default_value=self.default_value, + schema=self.schema + ) + + +@Operations.register_operation('remove_column_from_activity') +class RemoveColumnFromRemoveActivityOp(MigrateOperation): + """Drop Activity Table Triggers""" + + def __init__( + self, table_name, column_name, default_value=None, schema=None + ): + self.schema = schema + self.table_name = table_name + self.column_name = column_name + self.default_value = default_value + + @classmethod + def remove_column_from_activity( + cls, operations, table_name, column_name, **kwargs + ): + op = RemoveColumnFromRemoveActivityOp( + table_name, column_name, **kwargs + ) + return operations.invoke(op) + + def reverse(self): + # only needed to support autogenerate + return AddColumnToActivityOp( + self.table_name, + self.column_name, + default_value=self.default_value, + schema=self.schema + ) + + +@Operations.implementation_for(AddColumnToActivityOp) +def add_column_to_activity(operations, operation): + add_column( + operations, + operation.table_name, + operation.column_name, + default_value=operation.default_value, + schema=operation.schema + ) + + +@Operations.implementation_for(RemoveColumnFromRemoveActivityOp) +def remove_column_from_activity(operations, operation): + conn = operations.connection + remove_column( + conn, + operation.table_name, + operation.column_name, + operation.schema + ) + + +@renderers.dispatch_for(AddColumnToActivityOp) +def render_add_column_to_activity(autogen_context, op): + return 'op.add_column_to_activity(%r, %r, **%r)' % ( + op.table_name, + op.column_name, + {'schema': op.schema, 'default_value': op.default_value} + ) + + +@renderers.dispatch_for(RemoveColumnFromRemoveActivityOp) +def render_remove_column_from_activitys(autogen_context, op): + return 'op.remove_column_from_activity(%r, %r, **%r)' % ( + op.table_name, + op.column_name, + {'schema': op.schema} + ) diff --git a/postgresql_audit/alembic/register_table_for_version_tracking.py b/postgresql_audit/alembic/register_table_for_version_tracking.py new file mode 100644 index 0000000..daf4b50 --- /dev/null +++ b/postgresql_audit/alembic/register_table_for_version_tracking.py @@ -0,0 +1,102 @@ +import sqlalchemy as sa +from alembic.autogenerate import renderers +from alembic.operations import MigrateOperation, Operations + + +@Operations.register_operation('register_for_version_tracking') +class RegisterTableForVersionTrackingOp(MigrateOperation): + """Register Table for Version Tracking""" + + def __init__( + self, + tablename, + excluded_columns, + original_excluded_columns=None, + schema=None + ): + self.schema = schema + self.tablename = tablename + self.excluded_columns = excluded_columns + self.original_excluded_columns = original_excluded_columns + + @classmethod + def register_for_version_tracking( + cls, operations, tablename, exclude_columns, **kwargs + ): + op = RegisterTableForVersionTrackingOp( + tablename, exclude_columns, **kwargs + ) + return operations.invoke(op) + + def reverse(self): + # only needed to support autogenerate + return DeregisterTableForVersionTrackingOp( + self.tablename, self.original_excluded_columns, schema=self.schema + ) + + +@Operations.register_operation('deregister_for_version_tracking') +class DeregisterTableForVersionTrackingOp(MigrateOperation): + """Drop Table from Version Tracking""" + + def __init__(self, tablename, excluded_columns, schema=None): + self.schema = schema + self.tablename = tablename + self.excluded_columns = excluded_columns + + @classmethod + def deregister_for_version_tracking(cls, operations, tablename, **kwargs): + op = DeregisterTableForVersionTrackingOp(tablename, (), **kwargs) + return operations.invoke(op) + + def reverse(self): + # only needed to support autogenerate + return RegisterTableForVersionTrackingOp( + self.tablename, self.excluded_columns, (), schema=self.schema + ) + + +@Operations.implementation_for(RegisterTableForVersionTrackingOp) +def register_for_version_tracking(operations, operation): + if operation.schema is None: + func = sa.func.audit_table + else: + func = getattr(getattr(sa.func, operation.schema), 'audit_table') + operations.execute( + sa.select( + [func(operation.tablename, list(operation.excluded_columns))] + ) + ) + + +@Operations.implementation_for(DeregisterTableForVersionTrackingOp) +def deregister_for_version_tracking(operations, operation): + operations.execute( + f'DROP TRIGGER IF EXISTS audit_trigger_insert ON {operation.tablename}' + ) + operations.execute( + f'DROP TRIGGER IF EXISTS audit_trigger_update ON {operation.tablename}' + ) + operations.execute( + f'DROP TRIGGER IF EXISTS audit_trigger_delete ON {operation.tablename}' + ) + operations.execute( + f'DROP TRIGGER IF EXISTS audit_trigger_row ON {operation.tablename}' + ) + + +@renderers.dispatch_for(RegisterTableForVersionTrackingOp) +def render_register_for_version_tracking(autogen_context, op): + return 'op.register_for_version_tracking(%r, %r, **%r)' % ( + op.tablename, + op.excluded_columns, + {'schema': op.schema} + ) + + +@renderers.dispatch_for(DeregisterTableForVersionTrackingOp) +def render_deregister_for_version_tracking(autogen_context, op): + return 'op.deregister_for_version_tracking(%r, **%r)' % ( + op.tablename, + {'schema': op.schema} + ) diff --git a/postgresql_audit/base.py b/postgresql_audit/base.py index 4f3c50e..a32b1ea 100644 --- a/postgresql_audit/base.py +++ b/postgresql_audit/base.py @@ -1,5 +1,3 @@ -import os -import string import warnings from contextlib import contextmanager from weakref import WeakSet @@ -18,7 +16,13 @@ from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy_utils import get_class_by_table -HERE = os.path.dirname(os.path.abspath(__file__)) +from postgresql_audit.utils import ( + create_audit_table, + create_operators, + render_tmpl, + StatementExecutor +) + cached_statements = {} @@ -30,22 +34,6 @@ class ClassNotVersioned(Exception): pass -class StatementExecutor(object): - def __init__(self, stmt): - self.stmt = stmt - - def __call__(self, target, bind, **kwargs): - tx = bind.begin() - bind.execute(self.stmt) - tx.commit() - - -def read_file(file_): - with open(os.path.join(HERE, file_)) as f: - s = f.read() - return s - - def assign_actor(base, cls, actor_cls): if hasattr(cls, 'actor_id'): return @@ -215,51 +203,19 @@ def disable(self, session): ) def render_tmpl(self, tmpl_name): - file_contents = read_file( - 'templates/{}'.format(tmpl_name) - ).replace('%', '%%').replace('$$', '$$$$') - tmpl = string.Template(file_contents) - context = dict(schema_name=self.schema_name) - - if self.schema_name is None: - context['schema_prefix'] = '' - context['revoke_cmd'] = '' - else: - context['schema_prefix'] = '{}.'.format(self.schema_name) - context['revoke_cmd'] = ( - 'REVOKE ALL ON {schema_prefix}activity FROM public;' - ).format(**context) - - temp = tmpl.substitute(**context) - return temp + return render_tmpl(tmpl_name, schema_name=self.schema_name) def create_operators(self, target, bind, **kwargs): - if bind.dialect.server_version_info < (9, 5, 0): - StatementExecutor(self.render_tmpl('operators_pre95.sql'))( - target, bind, **kwargs - ) - if bind.dialect.server_version_info < (9, 6, 0): - StatementExecutor(self.render_tmpl('operators_pre96.sql'))( - target, bind, **kwargs - ) - if bind.dialect.server_version_info < (10, 0): - operators_template = self.render_tmpl('operators_pre100.sql') - StatementExecutor(operators_template)(target, bind, **kwargs) - operators_template = self.render_tmpl('operators.sql') - StatementExecutor(operators_template)(target, bind, **kwargs) + create_operators(target, bind, self.schema_name, **kwargs) def create_audit_table(self, target, bind, **kwargs): - sql = '' - if ( - self.use_statement_level_triggers and - bind.dialect.server_version_info >= (10, 0) - ): - sql += self.render_tmpl('create_activity_stmt_level.sql') - sql += self.render_tmpl('audit_table_stmt_level.sql') - else: - sql += self.render_tmpl('create_activity_row_level.sql') - sql += self.render_tmpl('audit_table_row_level.sql') - StatementExecutor(sql)(target, bind, **kwargs) + create_audit_table( + target, + bind, + self.schema_name, + self.use_statement_level_triggers, + **kwargs + ) def get_table_listeners(self): listeners = {'transaction': []} diff --git a/postgresql_audit/utils.py b/postgresql_audit/utils.py new file mode 100644 index 0000000..2df286e --- /dev/null +++ b/postgresql_audit/utils.py @@ -0,0 +1,81 @@ +import os +import string + +import sqlalchemy as sa + +HERE = os.path.dirname(os.path.abspath(__file__)) + + +class StatementExecutor(object): + def __init__(self, statement): + self.statement = statement + + def __call__(self, target, bind, **kwargs): + tx = bind.begin() + bind.execute(self.statement) + tx.commit() + + +def read_file(file_): + with open(os.path.join(HERE, file_)) as file: + data = file.read() + return data + + +def render_tmpl(tmpl_name, schema_name=None): + file_contents = read_file( + 'templates/{}'.format(tmpl_name) + ).replace('%', '%%').replace('$$', '$$$$') + tmpl = string.Template(file_contents) + context = dict(schema_name=schema_name) + + if schema_name is None: + context['schema_prefix'] = '' + context['revoke_cmd'] = '' + else: + context['schema_prefix'] = '{}.'.format(schema_name) + context['revoke_cmd'] = ( + 'REVOKE ALL ON {schema_prefix}activity FROM public;' + ).format(**context) + + return tmpl.substitute(**context) + + +def create_operators(target, bind, schema_name, **kwargs): + if bind.dialect.server_version_info < (9, 5, 0): + StatementExecutor(render_tmpl('operators_pre95.sql', schema_name))( + target, bind, **kwargs + ) + if bind.dialect.server_version_info < (9, 6, 0): + StatementExecutor(render_tmpl('operators_pre96.sql', schema_name))( + target, bind, **kwargs + ) + if bind.dialect.server_version_info < (10, 0): + operators_template = render_tmpl('operators_pre100.sql', schema_name) + StatementExecutor(operators_template)(target, bind, **kwargs) + operators_template = render_tmpl('operators.sql', schema_name) + StatementExecutor(operators_template)(target, bind, **kwargs) + + +def create_audit_table( + target, bind, schema_name, use_statement_level_triggers, **kwargs +): + sql = '' + if ( + use_statement_level_triggers and + bind.dialect.server_version_info >= (10, 0) + ): + sql += render_tmpl('create_activity_stmt_level.sql', schema_name) + sql += render_tmpl('audit_table_stmt_level.sql', schema_name) + else: + sql += render_tmpl('create_activity_row_level.sql', schema_name) + sql += render_tmpl('audit_table_row_level.sql', schema_name) + StatementExecutor(sql)(target, bind, **kwargs) + + +def build_register_table_query(schema_name, *args): + if schema_name is None: + func = sa.func.audit_table + else: + func = getattr(getattr(sa.func, schema_name), 'audit_table') + return sa.select([func(*args)])