IR: Better tuple autocasting via pydantic field_validators #476

Merged
merged 2 commits into from
Jan 17, 2025
52 changes: 17 additions & 35 deletions loki/ir/nodes.py
@@ -21,7 +21,7 @@
 from pymbolic.primitives import Expression

 from pydantic.dataclasses import dataclass as dataclass_validated
-from pydantic import model_validator
+from pydantic import field_validator

 from loki.expression import (
     symbols as sym, Variable, parse_expr, AttachScopesMapper,
@@ -257,17 +257,10 @@ class InternalNode(Node, _InternalNode):

     _traversable = ['body']

-    @model_validator(mode='before')
+    @field_validator('body', mode='before')
Collaborator

Likely my ignorance of pydantic showing here: the pre_init method previously also manipulated args and kwargs. Is this now handled implicitly, or is it no longer required?

Collaborator Author

Yes, but the old validator worked on the entire "Model" in pydantic speak, so it received a single ArgsKwargs argument that held all the constructor arguments as ArgsKwargs.args and ArgsKwargs.kwargs. This became an issue because the ArgsKwargs object is immutable, while the ArgsKwargs.kwargs dict can still be changed. So if any keyword arguments were given, I could change the dict, but I could not replace None with (), as assigning ArgsKwargs.args = () is not allowed.

Long story short, the new validator is applied to each argument individually when pydantic creates the ArgsKwargs object, so all our troubles fade away and pydantic takes care of identifying unnamed args for free.

     @classmethod
-    def pre_init(cls, values):
-        """ Ensure non-nested tuples for body. """
-        if values.kwargs and 'body' in values.kwargs:
-            values.kwargs['body'] = _sanitize_tuple(values.kwargs['body'])
-        if values.args:
-            # ArgsKwargs are immutable, so we need to force it a little
-            new_args = (_sanitize_tuple(values.args[0]),) + values.args[1:]
-            values = type(values)(args=new_args, kwargs=values.kwargs)
-        return values
+    def ensure_tuple(cls, value):
+        return _sanitize_tuple(value)

     def __repr__(self):
         raise NotImplementedError
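
The author's reply in this thread describes an immutable container that holds a mutable kwargs dict. A minimal stdlib sketch of that situation follows; `FakeArgsKwargs` is a hypothetical frozen dataclass used as a stand-in, not pydantic's `ArgsKwargs`, and the real class may differ in detail:

```python
from dataclasses import dataclass, field, FrozenInstanceError


@dataclass(frozen=True)
class FakeArgsKwargs:
    """Hypothetical stand-in: attributes are read-only, the dict is not."""
    args: tuple = ()
    kwargs: dict = field(default_factory=dict)


values = FakeArgsKwargs(args=(None,), kwargs={'body': None})

# The dict held by the frozen object can still be edited in place
values.kwargs['body'] = ()

# Rebinding an attribute of the frozen object is rejected
try:
    values.args = ((),)
    rebind_failed = False
except FrozenInstanceError:
    rebind_failed = True

# Workaround in the spirit of the removed pre_init: build a new object
rebuilt = type(values)(args=((),) + values.args[1:], kwargs=values.kwargs)
```

The rebuild step is what the removed code did with `type(values)(args=new_args, kwargs=values.kwargs)`; a per-field validator avoids it because it only ever sees one value at a time.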
@@ -721,14 +714,10 @@ class Conditional(InternalNode, _ConditionalBase):

     _traversable = ['condition', 'body', 'else_body']

-    @model_validator(mode='before')
+    @field_validator('body', 'else_body', mode='before')
     @classmethod
-    def pre_init(cls, values):
-        values = super().pre_init(values)
-        # Ensure non-nested tuples for else_body
-        if 'else_body' in values.kwargs:
-            values.kwargs['else_body'] = _sanitize_tuple(values.kwargs['else_body'])
-        return values
+    def ensure_tuple(cls, value):
+        return _sanitize_tuple(value)

     def __post_init__(self):
         super().__post_init__()
@@ -960,8 +949,8 @@ class _CallStatementBase():
     """ Type definitions for :any:`CallStatement` node type. """

     name: Expression
-    arguments: Optional[Tuple[Expression, ...]] = None
-    kwarguments: Optional[Tuple[Tuple[str, Expression], ...]] = None
+    arguments: Optional[Tuple[Expression, ...]] = ()
+    kwarguments: Optional[Tuple[Tuple[str, Expression], ...]] = ()
Comment on lines +952 to +953
Collaborator

Is this a change in behaviour, i.e., are arguments and kwarguments now always a tuple instead of None? (I'm not opposed, just trying to understand the impact.)

Collaborator Author

Internally, we always forced this to () in the pre-validator autocast mechanism. The reason I need to change it here is that the before-validator is not invoked if the argument is omitted, so the default is used directly. This should preserve the previous behaviour. 🤞

     pragma: Optional[Tuple[Node, ...]] = None
     not_active: Optional[bool] = None
     chevron: Optional[Tuple[Expression, ...]] = None
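
The author's reply on lines +952 to +953 gives the reason for changing the defaults: a before-validator only sees values the caller passes, so an omitted argument keeps its declared default untouched. A toy stdlib model of that rule follows; `build_fields` and `to_tuple` are hypothetical helpers written for illustration and neither use nor reproduce pydantic's machinery:

```python
def to_tuple(value):
    """Toy autocast: None becomes (), a single item becomes a 1-tuple."""
    if value is None:
        return ()
    return tuple(value) if isinstance(value, (list, tuple)) else (value,)


def build_fields(defaults, validators, **given):
    """Toy constructor: validators run only on explicitly passed values."""
    fields = {}
    for name, default in defaults.items():
        if name in given:
            validate = validators.get(name, lambda v: v)
            fields[name] = validate(given[name])
        else:
            fields[name] = default  # default is used as-is, no validator call
    return fields


validators = {'arguments': to_tuple}

# With a None default, omitting the argument leaves None in place
old_style = build_fields({'arguments': None}, validators)
# With a () default, the omitted argument is already a tuple
new_style = build_fields({'arguments': ()}, validators)
# An explicit None still goes through the validator and becomes ()
explicit = build_fields({'arguments': ()}, validators, arguments=None)
```

Under this model, switching the declared default from `None` to `()` is what keeps an omitted `arguments` a tuple once the model-level validator is gone.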
@@ -997,22 +986,15 @@ class CallStatement(LeafNode, _CallStatementBase):

     _traversable = ['name', 'arguments', 'kwarguments']

-    @model_validator(mode='before')
+    @field_validator('arguments', mode='before')
     @classmethod
-    def pre_init(cls, values):
-        # Ensure non-nested tuples for arguments
-        if 'arguments' in values.kwargs:
-            values.kwargs['arguments'] = _sanitize_tuple(values.kwargs['arguments'])
-        else:
-            values.kwargs['arguments'] = ()
-        # Ensure two-level nested tuples for kwarguments
-        if 'kwarguments' in values.kwargs:
-            kwarguments = as_tuple(values.kwargs['kwarguments'])
-            kwarguments = tuple(_sanitize_tuple(pair) for pair in kwarguments)
-            values.kwargs['kwarguments'] = kwarguments
-        else:
-            values.kwargs['kwarguments'] = ()
-        return values
+    def ensure_tuple(cls, value):
+        return _sanitize_tuple(value)
+
+    @field_validator('kwarguments', mode='before')
+    @classmethod
+    def ensure_nested_tuple(cls, value):
+        return tuple(_sanitize_tuple(pair) for pair in as_tuple(value))

     def __post_init__(self):
         super().__post_init__()
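
The `ensure_nested_tuple` validator keeps one level of nesting, so each keyword argument stays a `(name, value)` pair. A stdlib-only sketch of that two-level normalisation follows; `as_tuple_sketch` and `clean_pair` are hypothetical stand-ins for Loki's `as_tuple` and `_sanitize_tuple`, whose implementations are not shown in this diff:

```python
def as_tuple_sketch(value):
    """Hypothetical stand-in: wrap a value into a tuple, mapping None to ()."""
    if value is None:
        return ()
    if isinstance(value, (list, tuple)):
        return tuple(value)
    return (value,)


def clean_pair(pair):
    """Hypothetical stand-in: turn one pair into a tuple without None."""
    return tuple(item for item in as_tuple_sketch(pair) if item is not None)


def ensure_nested_tuple_sketch(value):
    """Normalise the outer container and each pair, keeping two levels."""
    return tuple(clean_pair(pair) for pair in as_tuple_sketch(value))


# Lists and tuples are both accepted for the outer container and the pairs
kwargs = ensure_nested_tuple_sketch([['i', 1], ('j', 2)])
empty = ensure_nested_tuple_sketch(None)
```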
26 changes: 26 additions & 0 deletions loki/ir/tests/test_ir_nodes.py
@@ -96,6 +96,12 @@
     loop = ir.Loop(variable=i, bounds=bounds, body=( assign, (assign,), assign, None))
     assert loop.body == (assign, assign, assign)

+    # Test auto-casting with unnamed constructor args
+    loop = ir.Loop(i, bounds, assign)
+    assert loop.body == (assign,)
+    loop = ir.Loop(i, bounds, [(assign,), None, assign])
+    assert loop.body == (assign, assign)
+
     # Test errors for wrong contructor usage
     with pytest.raises(ValidationError):
         ir.Loop(variable=i, bounds=bounds, body=n)
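
The `ir.Loop` assertions in this hunk imply that `body` is flattened into a non-nested tuple with `None` entries dropped. One way such a helper could behave is sketched below using only the stdlib; `sanitize_tuple_sketch` is a hypothetical name and not Loki's `_sanitize_tuple`, whose implementation is not shown in this diff:

```python
def sanitize_tuple_sketch(value):
    """Flatten nested lists/tuples into one tuple and drop None entries."""
    if value is None:
        return ()
    if not isinstance(value, (list, tuple)):
        return (value,)
    flat = []
    for item in value:
        flat.extend(sanitize_tuple_sketch(item))
    return tuple(flat)


# Plain strings stand in for IR nodes here
assign = 'assign'
single = sanitize_tuple_sketch(assign)
mixed = sanitize_tuple_sketch([(assign,), None, assign])
nested = sanitize_tuple_sketch((assign, (assign,), assign, None))
```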
@@ -148,6 +154,14 @@
     )
     assert not cond.body and cond.else_body == (assign, assign, assign)

+    # Test auto-casting with unnamed constructor args
+    cond = ir.Conditional(condition)
+    assert cond.body == () and not cond.else_body

Check failure on line 159 in loki/ir/tests/test_ir_nodes.py (GitHub Actions / code checks (3.11)):
C1803: "cond.body == ()" can be simplified to "not cond.body", if it is strictly a sequence, as an empty tuple is falsey (use-implicit-booleaness-not-comparison)
Collaborator

Pylint doesn't like cond.body == (), maybe

Suggested change
-    assert cond.body == () and not cond.else_body
+    assert isinstance(cond.body, tuple) and not (cond.body or cond.else_body)

Collaborator Author

Same here, cond.body is () should do to appease pylint...

+    cond = ir.Conditional(condition, assign)
+    assert cond.body == (assign,) and not cond.else_body
+    cond = ir.Conditional(condition, body=[assign, (assign,)], else_body=[assign, None, (assign,)])
+    assert cond.body == (assign, assign) and cond.else_body == (assign, assign)
+
     # TODO: Test inline, name, has_elseif


@@ -235,6 +249,12 @@
     sec = ir.Section((assign, (func,), assign, None))
     assert sec.body == (assign, func, assign)

+    # Test auto-casting with unnamed constructor args
+    sec = ir.Section()
+    assert sec.body == ()

Check failure on line 254 in loki/ir/tests/test_ir_nodes.py (GitHub Actions / code checks (3.11)):
C1803: "sec.body == ()" can be simplified to "not sec.body", if it is strictly a sequence, as an empty tuple is falsey (use-implicit-booleaness-not-comparison)
Collaborator

Suggested change
-    assert sec.body == ()
+    assert isinstance(sec.body, tuple) and not sec.body

Collaborator Author

Or even just sec.body is ()...

+    sec = ir.Section([(assign,), assign, None, assign])
+    assert sec.body == (assign, assign, assign)
+
     # Test prepend/insert/append additions
     sec = ir.Section(body=func)
     assert sec.body == (func,)
@@ -281,6 +301,12 @@
     call = ir.CallStatement(name=cname, kwarguments=None)
     assert not call.arguments and not call.kwarguments

+    # Test auto-casting with unnamed constructor args
+    call = ir.CallStatement(cname, a_i)
+    assert call.arguments == (a_i,) and not call.kwarguments
+    call = ir.CallStatement(cname, [a_i, one], [('i', i), ('j', one)])
+    assert call.arguments == (a_i, one) and call.kwarguments == (('i', i), ('j', one))
+
     # Test errors for wrong contructor usage
     with pytest.raises(ValidationError):
         ir.CallStatement(name='a', arguments=(sym.Literal(42.0),))