Skip to content

Commit

Permalink
Update tests to support refactored versioning manager
Browse files Browse the repository at this point in the history
  • Loading branch information
xerona committed Aug 24, 2021
1 parent 5c28625 commit 824322d
Show file tree
Hide file tree
Showing 10 changed files with 327 additions and 161 deletions.
160 changes: 109 additions & 51 deletions postgresql_audit/alembic/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,102 +4,157 @@
from alembic.autogenerate import comparators, rewriter
from alembic.operations import ops

from postgresql_audit.alembic.init_activity_table_triggers import InitActivityTableTriggersOp, \
from postgresql_audit.alembic.init_activity_table_triggers import (
InitActivityTableTriggersOp,
RemoveActivityTableTriggersOp
from postgresql_audit.alembic.migration_ops import AddColumnToActivityOp, RemoveColumnFromRemoveActivityOp
from postgresql_audit.alembic.register_table_for_version_tracking import RegisterTableForVersionTrackingOp, \
DeregisterTableForVersionTrackingOp


@comparators.dispatch_for("schema")
)
from postgresql_audit.alembic.migration_ops import (
AddColumnToActivityOp,
RemoveColumnFromRemoveActivityOp
)
from postgresql_audit.alembic.register_table_for_version_tracking import (
DeregisterTableForVersionTrackingOp,
RegisterTableForVersionTrackingOp
)


@comparators.dispatch_for('schema')
def compare_timestamp_schema(autogen_context, upgrade_ops, schemas):
routines = set()
for sch in schemas:
schema_name = autogen_context.dialect.default_schema_name if sch is None else sch
schema_name = (
autogen_context.dialect.default_schema_name if sch is None else sch
)
routines.update([
(sch, *row) for row in autogen_context.connection.execute(
"select routine_name, routine_definition from information_schema.routines "
f"where routines.specific_schema='{schema_name}' "
)])
(sch, *row) for row in autogen_context.connection.execute(
'select routine_name, routine_definition '
'from information_schema.routines '
f"where routines.specific_schema='{schema_name}' "
)
])

for sch in schemas:
should_track_versions = any("versioned" in table.info for table in autogen_context.sorted_tables if table.info and table.schema == sch)
schema_prefix = f"{sch}." if sch else ""
should_track_versions = any(
'versioned' in table.info
for table in autogen_context.sorted_tables
if table.info and table.schema == sch
)
schema_prefix = f'{sch}.' if sch else ''

a = next((v for k, v in groupby(routines, key=lambda x: x[0]) if k == sch), None)
a = next(
(v for k, v in groupby(routines, key=lambda x: x[0]) if k == sch),
None
)
a = list(a) if a else []
if should_track_versions:
if f"{schema_prefix}audit_table" not in (x[1] for x in a):
upgrade_ops.ops.insert(0,
if f'{schema_prefix}audit_table' not in (x[1] for x in a):
upgrade_ops.ops.insert(
0,
InitActivityTableTriggersOp(False, schema=sch)
)
else:
if f"{schema_prefix}audit_table" in (x[1] for x in a):
if f'{schema_prefix}audit_table' in (x[1] for x in a):
upgrade_ops.ops.append(
RemoveActivityTableTriggersOp(False, schema=sch)
)


@comparators.dispatch_for("table")
def compare_timestamp_table(autogen_context, modify_ops, schemaname, tablename, conn_table, metadata_table):
@comparators.dispatch_for('table')
def compare_timestamp_table(
autogen_context,
modify_ops,
schemaname,
tablename,
conn_table,
metadata_table
):
if metadata_table is None:
return
meta_info = metadata_table.info or {}
schema_name = autogen_context.dialect.default_schema_name if schemaname is None else schemaname

triggers = [row for row in autogen_context.connection.execute(f"""
select event_object_schema as table_schema,
event_object_table as table_name,
trigger_schema,
trigger_name,
string_agg(event_manipulation, ',') as event,
action_timing as activation,
action_condition as condition,
action_statement as definition
from information_schema.triggers
where event_object_table = '{tablename}' and trigger_schema = '{schema_name}'
group by 1,2,3,4,6,7,8
order by table_schema, table_name;
""")]

trigger_name = "audit_trigger"

if "versioned" in meta_info:
excluded_columns = metadata_table.info["versioned"].get("exclude", tuple())
trigger = next((trigger for trigger in triggers if trigger_name in trigger[3]), None)
schema_name = (
autogen_context.dialect.default_schema_name
if schemaname is None else schemaname
)

triggers = [row for row in autogen_context.connection.execute(
'select event_object_schema as table_schema,'
'event_object_table as table_name,'
'trigger_schema,'
'trigger_name,'
"string_agg(event_manipulation, ',') as event,"
'action_timing as activation,'
'action_condition as condition,'
'action_statement as definition'
'from information_schema.triggers'
f"where event_object_table = '{tablename}'"
f"and trigger_schema = '{schema_name}'"
'group by 1,2,3,4,6,7,8'
'order by table_schema, table_name;'
)]

trigger_name = 'audit_trigger'

if 'versioned' in meta_info:
excluded_columns = (
metadata_table.info['versioned'].get('exclude', tuple())
)
trigger = next(
(trigger for trigger in triggers if trigger_name in trigger[3]),
None
)
original_excluded_columns = __get_existing_excluded_columns(trigger)

if trigger and set(original_excluded_columns) == set(excluded_columns):
return

modify_ops.ops.insert(0,
RegisterTableForVersionTrackingOp(tablename, excluded_columns, original_excluded_columns, schema=schema_name)
modify_ops.ops.insert(
0,
RegisterTableForVersionTrackingOp(
tablename,
excluded_columns,
original_excluded_columns,
schema=schema_name
)
)
else:
trigger = next((trigger for trigger in triggers if trigger_name in trigger[3]), None)
trigger = next(
(trigger for trigger in triggers if trigger_name in trigger[3]),
None
)
original_excluded_columns = __get_existing_excluded_columns(trigger)

if trigger:
modify_ops.ops.append(
DeregisterTableForVersionTrackingOp(tablename, original_excluded_columns, schema=schema_name)
DeregisterTableForVersionTrackingOp(
tablename,
original_excluded_columns,
schema=schema_name
)
)


def __get_existing_excluded_columns(trigger):
original_excluded_columns = ()
if trigger:
arguments_match = re.search(r"EXECUTE FUNCTION create_activity\('{(.+)}'\)", trigger[7])
arguments_match = re.search(
r"EXECUTE FUNCTION create_activity\('{(.+)}'\)",
trigger[7]
)
if arguments_match:
original_excluded_columns = arguments_match.group(1).split(",")
original_excluded_columns = arguments_match.group(1).split(',')
return original_excluded_columns


writer = rewriter.Rewriter()


@writer.rewrites(ops.AddColumnOp)
def add_column_rewrite(context, revision, op):
table_info = op.column.table.info or {}
if "versioned" in table_info and op.column.name not in table_info["versioned"].get("exclude", []):
if (
'versioned' in table_info
and op.column.name not in table_info['versioned'].get('exclude', [])
):
return [
op,
AddColumnToActivityOp(
Expand All @@ -111,11 +166,15 @@ def add_column_rewrite(context, revision, op):
else:
return op


@writer.rewrites(ops.DropColumnOp)
def drop_column_rewrite(context, revision, op):
column = op._orig_column
table_info = column.table.info or {}
if "versioned" in table_info and column.name not in table_info["versioned"].get("exclude", []):
if (
'versioned' in table_info
and column.name not in table_info['versioned'].get('exclude', [])
):
return [
op,
RemoveColumnFromRemoveActivityOp(
Expand All @@ -126,4 +185,3 @@ def drop_column_rewrite(context, revision, op):
]
else:
return op

78 changes: 49 additions & 29 deletions postgresql_audit/alembic/init_activity_table_triggers.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
from alembic.autogenerate import renderers
from alembic.operations import Operations, MigrateOperation
from alembic.operations import MigrateOperation, Operations

from postgresql_audit.utils import render_tmpl, create_audit_table, create_operators
from postgresql_audit.utils import (
create_audit_table,
create_operators,
render_tmpl
)


@Operations.register_operation("init_activity_table_triggers")
@Operations.register_operation('init_activity_table_triggers')
class InitActivityTableTriggersOp(MigrateOperation):
"""Initialize Activity Table Triggers"""

Expand All @@ -13,31 +17,39 @@ def __init__(self, use_statement_level_triggers, schema=None):
self.use_statement_level_triggers = use_statement_level_triggers

@classmethod
def init_activity_table_triggers(cls, operations, use_statement_level_triggers, **kwargs):
op = InitActivityTableTriggersOp(use_statement_level_triggers, **kwargs)
def init_activity_table_triggers(
cls, operations, use_statement_level_triggers, **kwargs
):
op = InitActivityTableTriggersOp(
use_statement_level_triggers, **kwargs
)
return operations.invoke(op)

def reverse(self):
# only needed to support autogenerate
return RemoveActivityTableTriggersOp(self.use_statement_level_triggers, schema=self.schema)
return RemoveActivityTableTriggersOp(
self.use_statement_level_triggers, schema=self.schema
)

@Operations.register_operation("remove_activity_table_triggers")

@Operations.register_operation('remove_activity_table_triggers')
class RemoveActivityTableTriggersOp(MigrateOperation):
"""Drop Activity Table Triggers"""

def __init__(self, use_statement_level_triggers, schema=None):
self.schema = schema
self.use_statement_level_triggers = use_statement_level_triggers


@classmethod
def remove_activity_table_triggers(cls, operations, **kwargs):
op = RemoveActivityTableTriggersOp(False, **kwargs)
return operations.invoke(op)

def reverse(self):
# only needed to support autogenerate
return InitActivityTableTriggersOp(self.use_statement_level_triggers, schema=self.schema)
return InitActivityTableTriggersOp(
self.use_statement_level_triggers, schema=self.schema
)


@Operations.implementation_for(InitActivityTableTriggersOp)
Expand All @@ -49,7 +61,9 @@ def init_activity_table_triggers(operations, operation):
conn.execute(render_tmpl('create_schema.sql', operation.schema))

conn.execute(render_tmpl('jsonb_change_key_name.sql', operation.schema))
create_audit_table(None, bind, operation.schema, operation.use_statement_level_triggers)
create_audit_table(
None, bind, operation.schema, operation.use_statement_level_triggers
)
create_operators(None, bind, operation.schema)


Expand All @@ -61,38 +75,44 @@ def remove_activity_table_triggers(operations, operation):
if operation.schema:
conn.execute(render_tmpl('drop_schema.sql', operation.schema))

conn.execute("DROP FUNCTION jsonb_change_key_name(data jsonb, old_key text, new_key text)")
schema_prefix = f"{operation.schema}." if operation.schema else ""

conn.execute(f"DROP FUNCTION {schema_prefix}audit_table(target_table regclass, ignored_cols text[])")
conn.execute(f"DROP FUNCTION {schema_prefix}create_activity()")
conn.execute(
'DROP FUNCTION jsonb_change_key_name('
'data jsonb, old_key text, new_key text)'
)
schema_prefix = f'{operation.schema}.' if operation.schema else ''

conn.execute(
f'DROP FUNCTION {schema_prefix}audit_table('
'target_table regclass, ignored_cols text[])'
)
conn.execute(f'DROP FUNCTION {schema_prefix}create_activity()')

if bind.dialect.server_version_info < (9, 5, 0):
conn.execute(f"""DROP FUNCTION jsonb_subtract(jsonb, TEXT)""")
conn.execute(f"""DROP OPERATOR IF EXISTS - (jsonb, text);""")
conn.execute(f"""DROP FUNCTION jsonb_merge(jsonb, jsonb)""")
conn.execute(f"""DROP OPERATOR IF EXISTS || (jsonb, jsonb);""")
conn.execute('DROP FUNCTION jsonb_subtract(jsonb, TEXT)')
conn.execute('DROP OPERATOR IF EXISTS - (jsonb, text);')
conn.execute('DROP FUNCTION jsonb_merge(jsonb, jsonb)')
conn.execute('DROP OPERATOR IF EXISTS || (jsonb, jsonb);')
if bind.dialect.server_version_info < (9, 6, 0):
conn.execute(f"""DROP FUNCTION current_setting(TEXT, BOOL)""")
conn.execute('DROP FUNCTION current_setting(TEXT, BOOL)')
if bind.dialect.server_version_info < (10, 0):
conn.execute(f"""DROP FUNCTION jsonb_subtract(jsonb, TEXT[])""")
conn.execute(f"""DROP OPERATOR IF EXISTS - (jsonb, text[])""")
conn.execute('DROP FUNCTION jsonb_subtract(jsonb, TEXT[])')
conn.execute('DROP OPERATOR IF EXISTS - (jsonb, text[])')

conn.execute(f"""DROP OPERATOR IF EXISTS - (jsonb, jsonb)""")
conn.execute(f"""DROP FUNCTION get_setting(text, text)""")
conn.execute(f"""DROP FUNCTION jsonb_subtract(jsonb,jsonb)""")
conn.execute('DROP OPERATOR IF EXISTS - (jsonb, jsonb)')
conn.execute('DROP FUNCTION get_setting(text, text)')
conn.execute('DROP FUNCTION jsonb_subtract(jsonb,jsonb)')


@renderers.dispatch_for(InitActivityTableTriggersOp)
def render_init_activity_table_triggers(autogen_context, op):
return "op.init_activity_table_triggers(%r, **%r)" % (
return 'op.init_activity_table_triggers(%r, **%r)' % (
op.use_statement_level_triggers,
{"schema": op.schema}
{'schema': op.schema}
)


@renderers.dispatch_for(RemoveActivityTableTriggersOp)
def render_remove_activity_table_triggers(autogen_context, op):
return "op.remove_activity_table_triggers(**%r)" % (
{"schema": op.schema}
return 'op.remove_activity_table_triggers(**%r)' % (
{'schema': op.schema}
)
Loading

0 comments on commit 824322d

Please sign in to comment.