Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Trigger ignore_others parameter addition #194

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 63 additions & 0 deletions docs/ignoring_triggers.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,3 +53,66 @@ If you're ignoring triggers and handling database errors, there are two ways to

1. Wrap the outer transaction in `with pgtrigger.ignore.session():` so that the session is completed outside the transaction.
2. Wrap the inner `try/except` in `with transaction.atomic():` so that the errored part of the transaction is rolled back before the [pgtrigger.ignore][] context manager ends.


## Ignore other triggers within a trigger

Provide an `ignore_others` list of trigger URIs you would like to ignore while executing
a certain trigger. See the example below for details:

The `increment_comment_count` trigger will update the `comment_count` on a topic, instead of calculating
the count each time a topic is queried. Let's assume you are fixing a Justin Bieber Instagram
[bug](https://www.wired.com/2015/11/how-instagram-solved-its-justin-bieber-problem/). However we have
also protected the `comment_count` with a `pgtrigger.ReadOnly(name='read_only_comment_count')` trigger.

In this case you would provide a `ignore_others=['tests.Topic:read_only_comment_count']` to the
`increment_comment_count` trigger.

```python
class Topic(models.Model):
name = models.CharField(max_length=100)
comment_count = models.PositiveIntegerField(default=0)

class Meta:
triggers = [
pgtrigger.ReadOnly(
name='read_only_comment_count',
fields=['comment_count']
)
]


class Comment(models.Model):
topic = models.ForeignKey(Topic, on_delete=models.CASCADE)
# Other fields

class Meta:
triggers = [
pgtrigger.Trigger(
func=pgtrigger.Func(
'''
UPDATE "{db_table}"
SET "{comment_count}" = "{comment_count}" + 1
WHERE
"{db_table}"."{topic_pk}" = NEW."{columns.topic}";
{reset_ignore}
RETURN NEW;
''',
db_table = Topic._meta.db_table,
comment_count = Topic._meta.get_field('comment_count').get_attname_column()[1],
topic_pk = Topic._meta.pk.get_attname_column()[1]
),
ignore_others=['tests.Topic:read_only_comment_count'],
when=pgtrigger.Before,
operation=pgtrigger.Insert,
name='increment_comment_count'
),
]
```

!!! important

Remember to use the `{reset_ignore}` placeholder in the trigger function before you return
from any branch. Without it the triggers you have ignored will persist throughout the session.

It is mandatory to provide an instace of `pgtrigger.Func` to the `func` parameter.
2 changes: 2 additions & 0 deletions pgtrigger/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def get_template(self):
RETURN NEW;
END IF;
END IF;
{local_ignore}
{func}
END;
$$ LANGUAGE plpgsql;
Expand Down Expand Up @@ -118,6 +119,7 @@ def __init__(
condition=_unset,
execute=_unset,
hash=None,
local_ignore="",
):
"""Initialize the SQL and store it in the `.data` attribute."""
self.kwargs = {
Expand Down
138 changes: 112 additions & 26 deletions pgtrigger/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
else:
raise AssertionError

from .version import __version__ as ver

# Postgres only allows identifiers to be 63 chars max. Since "pgtrigger_"
# is the prefix for trigger names, and since an additional "_" and
Expand Down Expand Up @@ -510,10 +511,11 @@ class Func:
possible to do inline SQL in the `Meta` of a model and reference its properties.
"""

def __init__(self, func):
def __init__(self, func: str, **kwargs):
self.func = func
self.kwargs = kwargs

def render(self, model: models.Model) -> str:
def render(self, model: type[models.Model], trigger: Trigger) -> str:
"""
Render the SQL of the function.

Expand All @@ -523,9 +525,24 @@ def render(self, model: models.Model) -> str:
Returns:
The rendered SQL.
"""
fields = utils.AttrDict({field.name: field for field in model._meta.fields})
columns = utils.AttrDict({field.name: field.column for field in model._meta.fields})
return self.func.format(meta=model._meta, fields=fields, columns=columns)
kwargs = {
"meta": model._meta,
"fields": utils.AttrDict({field.name: field for field in model._meta.fields}),
"columns": utils.AttrDict({field.name: field.column for field in model._meta.fields}),
"reset_ignore": (
"""
IF _prev_ignore IS NOT NULL AND (_prev_ignore = '') IS NOT TRUE THEN
PERFORM set_config('pgtrigger.ignore', _prev_ignore, true);
ELSE
PERFORM set_config('pgtrigger.ignore', '', true);
END IF;
"""
if trigger.ignores_others
else ""
),
} | self.kwargs

return self.func.format(**kwargs)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Func doesn't seem to be the optimal place for generating boilerplate wrapper SQL like this. For example, this is where triggers are wrapped with the ability to be ignored. I assumed that we would also generate this boilerplate in a similar way. Having a self.kwargs variable also strikes me as a hack.

Was there a reason behind overriding Func?

Copy link
Author

@ikcom ikcom Dec 15, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Only that Func seemed to be only a wrapper for string formatting. This extends it to also provide custom context. See docs/ignoring_triggers.md:101

The reset_ignore template can very well be defined outside the Func

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

About your example in compiler.py. I could not think of a way to wrap a user defined trigger body in such a way that resetting the _pgtrigger_ignore would always get executed. The user might return from any branch in their script.
The user would have to decide when to reset it



# Allows Trigger methods to be used as context managers, mostly for
Expand Down Expand Up @@ -559,6 +576,7 @@ class Trigger:
func: Func | str | None = None
declare: list[tuple[str, str]] | None = None
timing: Timing | None = None
ignore_others: list[str] | None = None

def __init__(
self,
Expand All @@ -572,6 +590,7 @@ def __init__(
func: Func | str | None = None,
declare: List[Tuple[str, str]] | None = None,
timing: Timing | None = None,
ignore_others: list[str] | None = None,
) -> None:
self.name = name or self.name
self.level = level or self.level
Expand All @@ -582,6 +601,7 @@ def __init__(
self.func = func or self.func
self.declare = declare or self.declare
self.timing = timing or self.timing
self.ignore_others = ignore_others or self.ignore_others

if not self.level or not isinstance(self.level, Level):
raise ValueError(f'Invalid "level" attribute: {self.level}')
Expand Down Expand Up @@ -609,6 +629,20 @@ def __init__(

self.validate_name()

if self.ignores_others:
if not isinstance(self.func, Func):
raise ValueError(
'Invalid "func" attribute. Triggers that ignore others must provide '
f"an instance of pgtrigger.Func(). Received {type(self.func)} instead"
)

if "{reset_ignore}" not in self.func.func:
raise ValueError(
f'Trigger "{self}" ignores other triggers, however, '
"placeholder {reset_ignore} was not found in the function "
f"body. Please refer to: https://django-pgtrigger.readthedocs.io/en/{ver}/ignoring_triggers/#ignore-other-triggers-within-a-trigger"
)

def __str__(self) -> str: # pragma: no cover
return self.name

Expand All @@ -627,7 +661,7 @@ def validate_name(self) -> None:
" Only alphanumeric characters, hyphens, and underscores are allowed."
)

def get_pgid(self, model: models.Model) -> str:
def get_pgid(self, model: type[models.Model]) -> str:
"""The ID of the trigger and function object in postgres

All objects are prefixed with "pgtrigger_" in order to be
Expand All @@ -650,7 +684,7 @@ def get_pgid(self, model: models.Model) -> str:
# and pruning tasks.
return pgid.lower()

def get_condition(self, model: models.Model) -> Condition:
def get_condition(self, model: type[models.Model]) -> Condition:
"""Get the condition of the trigger.

Args:
Expand All @@ -661,7 +695,7 @@ def get_condition(self, model: models.Model) -> Condition:
"""
return self.condition

def get_declare(self, model: models.Model) -> List[Tuple[str, str]]:
def get_declare(self, model: type[models.Model]) -> List[Tuple[str, str]]:
"""
Gets the DECLARE part of the trigger function if any variables
are used.
Expand All @@ -673,9 +707,60 @@ def get_declare(self, model: models.Model) -> List[Tuple[str, str]]:
A list of variable name / type tuples that will
be shown in the DECLARE. For example [('row_data', 'JSONB')]
"""
return self.declare or []
declare = self.declare or []

if self.ignore_others is not None:
declare.append(("_prev_ignore", "text"))
declare.append(self.declare_local_ignore(self.ignore_others))

return declare

def declare_local_ignore(self, ignore: list[str]) -> tuple[str, str]:
"""Given a list of trigger URIs compile the value for `_local_ignore`
variable of the trigger function

def get_func(self, model: models.Model) -> Union[str, Func]:
Parameters
----------
ignore : list[str]
List of trigger URIs

Returns
-------
tuple[str, str]
`_local_ignore` variable declaration and initial value for the DECLARE block
"""
local_ignore = (
"{"
+ ",".join(
f"{model._meta.db_table}:{(pgid:=trigger.get_pgid(model))},{pgid}"
for model, trigger in registry.registered(*ignore)
)
+ "}"
)
return ("_local_ignore", f"text[] = '{local_ignore}'")

@property
def ignores_others(self) -> bool:
"""True if the trigger is initialized with local trigger ignores"""
return self.ignore_others is not None

def render_local_ignore(self):
if self.ignores_others:
return """
BEGIN
SELECT CURRENT_SETTING('pgtrigger.ignore', true) INTO _prev_ignore;
EXCEPTION WHEN OTHERS THEN
END;

IF _prev_ignore IS NOT NULL AND (_prev_ignore = '') IS NOT TRUE THEN
SELECT _local_ignore || _prev_ignore::text[] INTO _local_ignore;
END IF;

PERFORM set_config('pgtrigger.ignore', _local_ignore::text, true);
"""
return ""

def get_func(self, model: type[models.Model]) -> Union[str, Func]:
"""
Returns the trigger function that comes between the BEGIN and END
clause.
Expand All @@ -690,7 +775,7 @@ def get_func(self, model: models.Model) -> Union[str, Func]:
raise ValueError("Must define func attribute or implement get_func")
return self.func

def get_uri(self, model: models.Model) -> str:
def get_uri(self, model: type[models.Model]) -> str:
"""The URI for the trigger.

Args:
Expand All @@ -702,7 +787,7 @@ def get_uri(self, model: models.Model) -> str:

return f"{model._meta.app_label}.{model._meta.object_name}:{self.name}"

def render_condition(self, model: models.Model) -> str:
def render_condition(self, model: type[models.Model]) -> str:
"""Renders the condition SQL in the trigger declaration.

Args:
Expand All @@ -721,7 +806,7 @@ def render_condition(self, model: models.Model) -> str:

return resolved

def render_declare(self, model: models.Model) -> str:
def render_declare(self, model: type[models.Model]) -> str:
"""Renders the DECLARE of the trigger function, if any.

Args:
Expand All @@ -740,7 +825,7 @@ def render_declare(self, model: models.Model) -> str:

return rendered_declare

def render_execute(self, model: models.Model) -> str:
def render_execute(self, model: type[models.Model]) -> str:
"""
Renders what should be executed by the trigger. This defaults
to the trigger function.
Expand All @@ -753,7 +838,7 @@ def render_execute(self, model: models.Model) -> str:
"""
return f"{self.get_pgid(model)}()"

def render_func(self, model: models.Model) -> str:
def render_func(self, model: type[models.Model]) -> str:
"""
Renders the func.

Expand All @@ -766,11 +851,11 @@ def render_func(self, model: models.Model) -> str:
func = self.get_func(model)

if isinstance(func, Func):
return func.render(model)
else:
return func
return func.render(model, self)

return func

def compile(self, model: models.Model) -> compiler.Trigger:
def compile(self, model: type[models.Model]) -> compiler.Trigger:
"""
Create a compiled representation of the trigger. useful for migrations.

Expand All @@ -796,10 +881,11 @@ def compile(self, model: models.Model) -> compiler.Trigger:
level=self.level,
condition=self.render_condition(model),
execute=self.render_execute(model),
local_ignore=self.render_local_ignore(),
),
)

def allow_migrate(self, model: models.Model, database: Union[str, None] = None) -> bool:
def allow_migrate(self, model: type[models.Model], database: Union[str, None] = None) -> bool:
"""True if the trigger for this model can be migrated.

Defaults to using the router's allow_migrate.
Expand Down Expand Up @@ -830,7 +916,7 @@ def format_sql(self, sql: str) -> str:
def exec_sql(
self,
sql: str,
model: models.Model,
model: type[models.Model],
database: Union[str, None] = None,
fetchall: bool = False,
) -> Any:
Expand All @@ -849,7 +935,7 @@ def exec_sql(
return utils.exec_sql(str(sql), database=database, fetchall=fetchall)

def get_installation_status(
self, model: models.Model, database: Union[str, None] = None
self, model: type[models.Model], database: Union[str, None] = None
) -> Tuple[str, Union[bool, None]]:
"""Returns the installation status of a trigger.

Expand Down Expand Up @@ -922,7 +1008,7 @@ def unregister(self, *models: models.Model):

return _cleanup_on_exit(lambda: self.register(*models))

def install(self, model: models.Model, database: Union[str, None] = None):
def install(self, model: type[models.Model], database: Union[str, None] = None):
"""Installs the trigger for a model.

Args:
Expand All @@ -934,7 +1020,7 @@ def install(self, model: models.Model, database: Union[str, None] = None):
self.exec_sql(install_sql, model, database=database)
return _cleanup_on_exit(lambda: self.uninstall(model, database=database))

def uninstall(self, model: models.Model, database: Union[str, None] = None):
def uninstall(self, model: type[models.Model], database: Union[str, None] = None):
"""Uninstalls the trigger for a model.

Args:
Expand All @@ -947,7 +1033,7 @@ def uninstall(self, model: models.Model, database: Union[str, None] = None):
lambda: self.install(model, database=database)
)

def enable(self, model: models.Model, database: Union[str, None] = None):
def enable(self, model: type[models.Model], database: Union[str, None] = None):
"""Enables the trigger for a model.

Args:
Expand All @@ -960,7 +1046,7 @@ def enable(self, model: models.Model, database: Union[str, None] = None):
lambda: self.disable(model, database=database)
)

def disable(self, model: models.Model, database: Union[str, None] = None):
def disable(self, model: type[models.Model], database: Union[str, None] = None):
"""Disables the trigger for a model.

Args:
Expand Down
Loading