Skip to content

Commit

Permalink
Implement has_account() as AST substitution
Browse files Browse the repository at this point in the history
The AST is rewritten to contain an explicit reference to the accounts
column. This removes the need to pass the current row object (which is
of the expected kind only when iterating the postings or entries
tables) to the function evaluation. This removes the last function
implementation that uses the current row.
  • Loading branch information
dnicolodi committed Jan 19, 2025
1 parent 1990d80 commit 75b3719
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 9 deletions.
5 changes: 5 additions & 0 deletions beanquery/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -560,6 +560,11 @@ def _function(self, node: ast.Function):
ast.Attribute(ast.Column('entry', parseinfo=node.parseinfo), 'meta'), key])])
return self._compile(node)

# Replace ``has_account(regexp)`` with ``('(?i)' + regexp) ~? any (accounts)``.
if node.fname == 'has_account':
node = ast.Any(ast.Add(ast.Constant('(?i)'), node.operands[0]), '?~', ast.Column('accounts'))
return self._compile(node)

function = function(self.context, operands)
# Constants folding.
if all(isinstance(operand, EvalConstant) for operand in operands) and function.pure:
Expand Down
15 changes: 6 additions & 9 deletions beanquery/query_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,21 +38,18 @@
from beanquery import types


def function(intypes, outtype, pass_row=False, pass_context=None, name=None):
assert not (pass_row and pass_context)
def function(intypes, outtype, pass_context=None, name=None):
def decorator(func):
class Func(query_compile.EvalFunction):
__intypes__ = intypes
pure = not pass_row and not pass_context
pure = not pass_context
def __init__(self, context, operands):
super().__init__(context, operands, outtype)
def __call__(self, row):
args = [operand(row) for operand in self.operands]
for arg in args:
if arg is None:
return None
if pass_row:
return func(row, *args)
if pass_context:
return func(self.context, *args)
return func(*args)
Expand Down Expand Up @@ -395,7 +392,7 @@ def entry_meta(context, key):


# Stub kept only for function type checking and for generating documentation.
@function([str], object, pass_row=True)
@function([str], object)
def any_meta(context, key):
"""Get metadata from the posting or its parent transaction if not present."""
raise NotImplementedError
Expand Down Expand Up @@ -423,11 +420,11 @@ def account_sortkey(context, acc):
return '{}-{}'.format(index, name)


@function([str], bool, pass_row=True)
# Stub kept only for function type checking and for generating documentation.
@function([str], bool)
def has_account(context, pattern):
"""True if the transaction has at least one posting matching the regular expression argument."""
search = re.compile(pattern, re.IGNORECASE).search
return any(search(account) for account in getters.get_entry_accounts(context.entry))
raise NotImplementedError


# Note: Don't provide this, because polymorphic multiplication on Amount,
Expand Down

0 comments on commit 75b3719

Please sign in to comment.