Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactored for integration with alembic #46

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 129 additions & 0 deletions postgresql_audit/alembic/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
import re
from itertools import groupby

from alembic.autogenerate import comparators, rewriter
from alembic.operations import ops

from postgresql_audit.alembic.init_activity_table_triggers import InitActivityTableTriggersOp, \
RemoveActivityTableTriggersOp
from postgresql_audit.alembic.migration_ops import AddColumnToActivityOp, RemoveColumnFromRemoveActivityOp
from postgresql_audit.alembic.register_table_for_version_tracking import RegisterTableForVersionTrackingOp, \
DeregisterTableForVersionTrackingOp


@comparators.dispatch_for("schema")
def compare_timestamp_schema(autogen_context, upgrade_ops, schemas):
routines = set()
for sch in schemas:
schema_name = autogen_context.dialect.default_schema_name if sch is None else sch
routines.update([
(sch, *row) for row in autogen_context.connection.execute(
"select routine_name, routine_definition from information_schema.routines "
f"where routines.specific_schema='{schema_name}' "
)])

for sch in schemas:
should_track_versions = any("versioned" in table.info for table in autogen_context.sorted_tables if table.info and table.schema == sch)
schema_prefix = f"{sch}." if sch else ""

a = next((v for k, v in groupby(routines, key=lambda x: x[0]) if k == sch), None)
a = list(a) if a else []
if should_track_versions:
if f"{schema_prefix}audit_table" not in (x[1] for x in a):
upgrade_ops.ops.insert(0,
InitActivityTableTriggersOp(False, schema=sch)
)
else:
if f"{schema_prefix}audit_table" in (x[1] for x in a):
upgrade_ops.ops.append(
RemoveActivityTableTriggersOp(False, schema=sch)
)


@comparators.dispatch_for("table")
def compare_timestamp_table(autogen_context, modify_ops, schemaname, tablename, conn_table, metadata_table):
if metadata_table is None:
return
meta_info = metadata_table.info or {}
schema_name = autogen_context.dialect.default_schema_name if schemaname is None else schemaname

triggers = [row for row in autogen_context.connection.execute(f"""
select event_object_schema as table_schema,
event_object_table as table_name,
trigger_schema,
trigger_name,
string_agg(event_manipulation, ',') as event,
action_timing as activation,
action_condition as condition,
action_statement as definition
from information_schema.triggers
where event_object_table = '{tablename}' and trigger_schema = '{schema_name}'
group by 1,2,3,4,6,7,8
order by table_schema, table_name;
""")]

trigger_name = "audit_trigger"

if "versioned" in meta_info:
excluded_columns = metadata_table.info["versioned"].get("exclude", tuple())
trigger = next((trigger for trigger in triggers if trigger_name in trigger[3]), None)
original_excluded_columns = __get_existing_excluded_columns(trigger)

if trigger and set(original_excluded_columns) == set(excluded_columns):
return

modify_ops.ops.insert(0,
RegisterTableForVersionTrackingOp(tablename, excluded_columns, original_excluded_columns, schema=schema_name)
)
else:
trigger = next((trigger for trigger in triggers if trigger_name in trigger[3]), None)
original_excluded_columns = __get_existing_excluded_columns(trigger)

if trigger:
modify_ops.ops.append(
DeregisterTableForVersionTrackingOp(tablename, original_excluded_columns, schema=schema_name)
)


def __get_existing_excluded_columns(trigger):
original_excluded_columns = ()
if trigger:
arguments_match = re.search(r"EXECUTE FUNCTION create_activity\('{(.+)}'\)", trigger[7])
if arguments_match:
original_excluded_columns = arguments_match.group(1).split(",")
return original_excluded_columns


writer = rewriter.Rewriter()

@writer.rewrites(ops.AddColumnOp)
def add_column_rewrite(context, revision, op):
table_info = op.column.table.info or {}
if "versioned" in table_info and op.column.name not in table_info["versioned"].get("exclude", []):
return [
op,
AddColumnToActivityOp(
op.table_name,
op.column.name,
schema=op.column.table.schema,
),
]
else:
return op

@writer.rewrites(ops.DropColumnOp)
def drop_column_rewrite(context, revision, op):
column = op._orig_column
table_info = column.table.info or {}
if "versioned" in table_info and column.name not in table_info["versioned"].get("exclude", []):
return [
op,
RemoveColumnFromRemoveActivityOp(
op.table_name,
column.name,
schema=column.table.schema,
),
]
else:
return op

98 changes: 98 additions & 0 deletions postgresql_audit/alembic/init_activity_table_triggers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
from alembic.autogenerate import renderers
from alembic.operations import Operations, MigrateOperation

from postgresql_audit.utils import render_tmpl, create_audit_table, create_operators


@Operations.register_operation("init_activity_table_triggers")
class InitActivityTableTriggersOp(MigrateOperation):
"""Initialize Activity Table Triggers"""

def __init__(self, use_statement_level_triggers, schema=None):
self.schema = schema
self.use_statement_level_triggers = use_statement_level_triggers

@classmethod
def init_activity_table_triggers(cls, operations, use_statement_level_triggers, **kwargs):
op = InitActivityTableTriggersOp(use_statement_level_triggers, **kwargs)
return operations.invoke(op)

def reverse(self):
# only needed to support autogenerate
return RemoveActivityTableTriggersOp(self.use_statement_level_triggers, schema=self.schema)

@Operations.register_operation("remove_activity_table_triggers")
class RemoveActivityTableTriggersOp(MigrateOperation):
"""Drop Activity Table Triggers"""

def __init__(self, use_statement_level_triggers, schema=None):
self.schema = schema
self.use_statement_level_triggers = use_statement_level_triggers


@classmethod
def remove_activity_table_triggers(cls, operations, **kwargs):
op = RemoveActivityTableTriggersOp(False, **kwargs)
return operations.invoke(op)

def reverse(self):
# only needed to support autogenerate
return InitActivityTableTriggersOp(self.use_statement_level_triggers, schema=self.schema)


@Operations.implementation_for(InitActivityTableTriggersOp)
def init_activity_table_triggers(operations, operation):
conn = operations
bind = conn.get_bind()

if operation.schema:
conn.execute(render_tmpl('create_schema.sql', operation.schema))

conn.execute(render_tmpl('jsonb_change_key_name.sql', operation.schema))
create_audit_table(None, bind, operation.schema, operation.use_statement_level_triggers)
create_operators(None, bind, operation.schema)


@Operations.implementation_for(RemoveActivityTableTriggersOp)
def remove_activity_table_triggers(operations, operation):
conn = operations
bind = conn.get_bind()

if operation.schema:
conn.execute(render_tmpl('drop_schema.sql', operation.schema))

conn.execute("DROP FUNCTION jsonb_change_key_name(data jsonb, old_key text, new_key text)")
schema_prefix = f"{operation.schema}." if operation.schema else ""

conn.execute(f"DROP FUNCTION {schema_prefix}audit_table(target_table regclass, ignored_cols text[])")
conn.execute(f"DROP FUNCTION {schema_prefix}create_activity()")


if bind.dialect.server_version_info < (9, 5, 0):
conn.execute(f"""DROP FUNCTION jsonb_subtract(jsonb, TEXT)""")
conn.execute(f"""DROP OPERATOR IF EXISTS - (jsonb, text);""")
conn.execute(f"""DROP FUNCTION jsonb_merge(jsonb, jsonb)""")
conn.execute(f"""DROP OPERATOR IF EXISTS || (jsonb, jsonb);""")
if bind.dialect.server_version_info < (9, 6, 0):
conn.execute(f"""DROP FUNCTION current_setting(TEXT, BOOL)""")
if bind.dialect.server_version_info < (10, 0):
conn.execute(f"""DROP FUNCTION jsonb_subtract(jsonb, TEXT[])""")
conn.execute(f"""DROP OPERATOR IF EXISTS - (jsonb, text[])""")

conn.execute(f"""DROP OPERATOR IF EXISTS - (jsonb, jsonb)""")
conn.execute(f"""DROP FUNCTION get_setting(text, text)""")
conn.execute(f"""DROP FUNCTION jsonb_subtract(jsonb,jsonb)""")


@renderers.dispatch_for(InitActivityTableTriggersOp)
def render_init_activity_table_triggers(autogen_context, op):
return "op.init_activity_table_triggers(%r, **%r)" % (
op.use_statement_level_triggers,
{"schema": op.schema}
)

@renderers.dispatch_for(RemoveActivityTableTriggersOp)
def render_remove_activity_table_triggers(autogen_context, op):
return "op.remove_activity_table_triggers(**%r)" % (
{"schema": op.schema}
)
70 changes: 70 additions & 0 deletions postgresql_audit/alembic/migration_ops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
from alembic.autogenerate import renderers
from alembic.operations import Operations, MigrateOperation

from postgresql_audit import add_column, remove_column


@Operations.register_operation("add_column_to_activity")
class AddColumnToActivityOp(MigrateOperation):
"""Initialize Activity Table Triggers"""

def __init__(self, table_name, column_name, default_value=None, schema=None):
self.schema = schema
self.table_name = table_name
self.column_name = column_name
self.default_value = default_value

@classmethod
def add_column_to_activity(cls, operations, table_name, column_name, **kwargs):
op = AddColumnToActivityOp(table_name, column_name, **kwargs)
return operations.invoke(op)

def reverse(self):
# only needed to support autogenerate
return RemoveColumnFromRemoveActivityOp(self.table_name, self.column_name, default_value=self.default_value, schema=self.schema)

@Operations.register_operation("remove_column_from_activity")
class RemoveColumnFromRemoveActivityOp(MigrateOperation):
"""Drop Activity Table Triggers"""

def __init__(self, table_name, column_name, default_value=None, schema=None):
self.schema = schema
self.table_name = table_name
self.column_name = column_name
self.default_value = default_value

@classmethod
def remove_column_from_activity(cls, operations, table_name, column_name, **kwargs):
op = RemoveColumnFromRemoveActivityOp(table_name, column_name, **kwargs)
return operations.invoke(op)

def reverse(self):
# only needed to support autogenerate
return AddColumnToActivityOp(self.table_name, self.column_name, default_value=self.default_value, schema=self.schema)


@Operations.implementation_for(AddColumnToActivityOp)
def add_column_to_activity(operations, operation):
add_column(operations, operation.table_name, operation.column_name, default_value=operation.default_value, schema=operation.schema)


@Operations.implementation_for(RemoveColumnFromRemoveActivityOp)
def remove_column_from_activity(operations, operation):
conn = operations.connection
remove_column(conn, operation.table_name, operation.column_name, operation.schema)

@renderers.dispatch_for(AddColumnToActivityOp)
def render_add_column_to_activity(autogen_context, op):
return "op.add_column_to_activity(%r, %r, **%r)" % (
op.table_name,
op.column_name,
{"schema": op.schema, "default_value": op.default_value}
)

@renderers.dispatch_for(RemoveColumnFromRemoveActivityOp)
def render_remove_column_from_activitys(autogen_context, op):
return "op.remove_column_from_activity(%r, %r, **%r)" % (
op.table_name,
op.column_name,
{"schema": op.schema}
)
76 changes: 76 additions & 0 deletions postgresql_audit/alembic/register_table_for_version_tracking.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
import sqlalchemy as sa

from alembic.autogenerate import renderers
from alembic.operations import Operations, MigrateOperation


@Operations.register_operation("register_for_version_tracking")
class RegisterTableForVersionTrackingOp(MigrateOperation):
"""Register Table for Version Tracking"""

def __init__(self, tablename, excluded_columns, original_excluded_columns=None, schema=None):
self.schema = schema
self.tablename = tablename
self.excluded_columns = excluded_columns
self.original_excluded_columns = original_excluded_columns

@classmethod
def register_for_version_tracking(cls, operations, tablename, exclude_columns, **kwargs):
op = RegisterTableForVersionTrackingOp(tablename, exclude_columns, **kwargs)
return operations.invoke(op)

def reverse(self):
# only needed to support autogenerate
return DeregisterTableForVersionTrackingOp(self.tablename, self.original_excluded_columns, schema=self.schema)

@Operations.register_operation("deregister_for_version_tracking")
class DeregisterTableForVersionTrackingOp(MigrateOperation):
"""Drop Table from Version Tracking"""

def __init__(self, tablename, excluded_columns, schema=None):
self.schema = schema
self.tablename = tablename
self.excluded_columns = excluded_columns


@classmethod
def deregister_for_version_tracking(cls, operations, tablename, **kwargs):
op = DeregisterTableForVersionTrackingOp(tablename, (), **kwargs)
return operations.invoke(op)

def reverse(self):
# only needed to support autogenerate
return RegisterTableForVersionTrackingOp(self.tablename, self.excluded_columns, (), schema=self.schema)


@Operations.implementation_for(RegisterTableForVersionTrackingOp)
def register_for_version_tracking(operations, operation):
if operation.schema is None:
func = sa.func.audit_table
else:
func = getattr(getattr(sa.func, operation.schema), 'audit_table')
operations.execute(sa.select([func(operation.tablename, list(operation.excluded_columns))]))


@Operations.implementation_for(DeregisterTableForVersionTrackingOp)
def deregister_for_version_tracking(operations, operation):
operations.execute(f"drop trigger if exists audit_trigger_insert on {operation.tablename} ")
operations.execute(f"drop trigger if exists audit_trigger_update on {operation.tablename} ")
operations.execute(f"drop trigger if exists audit_trigger_delete on {operation.tablename} ")
operations.execute(f"drop trigger if exists audit_trigger_row on {operation.tablename} ")


@renderers.dispatch_for(RegisterTableForVersionTrackingOp)
def render_register_for_version_tracking(autogen_context, op):
return "op.register_for_version_tracking(%r, %r, **%r)" % (
op.tablename,
op.excluded_columns,
{"schema": op.schema}
)

@renderers.dispatch_for(DeregisterTableForVersionTrackingOp)
def render_deregister_for_version_tracking(autogen_context, op):
return "op.deregister_for_version_tracking(%r, **%r)" % (
op.tablename,
{"schema": op.schema}
)
Loading