-
Notifications
You must be signed in to change notification settings - Fork 14
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
IR: Better tuple autocasting via pydantic field_validators #476
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -21,7 +21,7 @@ | |
from pymbolic.primitives import Expression | ||
|
||
from pydantic.dataclasses import dataclass as dataclass_validated | ||
from pydantic import model_validator | ||
from pydantic import field_validator | ||
|
||
from loki.expression import ( | ||
symbols as sym, Variable, parse_expr, AttachScopesMapper, | ||
|
@@ -257,17 +257,10 @@ class InternalNode(Node, _InternalNode): | |
|
||
_traversable = ['body'] | ||
|
||
@model_validator(mode='before') | ||
@field_validator('body', mode='before') | ||
@classmethod | ||
def pre_init(cls, values): | ||
""" Ensure non-nested tuples for body. """ | ||
if values.kwargs and 'body' in values.kwargs: | ||
values.kwargs['body'] = _sanitize_tuple(values.kwargs['body']) | ||
if values.args: | ||
# ArgsKwargs are immutable, so we need to force it a little | ||
new_args = (_sanitize_tuple(values.args[0]),) + values.args[1:] | ||
values = type(values)(args=new_args, kwargs=values.kwargs) | ||
return values | ||
def ensure_tuple(cls, value): | ||
return _sanitize_tuple(value) | ||
|
||
def __repr__(self): | ||
raise NotImplementedError | ||
|
@@ -721,14 +714,10 @@ class Conditional(InternalNode, _ConditionalBase): | |
|
||
_traversable = ['condition', 'body', 'else_body'] | ||
|
||
@model_validator(mode='before') | ||
@field_validator('body', 'else_body', mode='before') | ||
@classmethod | ||
def pre_init(cls, values): | ||
values = super().pre_init(values) | ||
# Ensure non-nested tuples for else_body | ||
if 'else_body' in values.kwargs: | ||
values.kwargs['else_body'] = _sanitize_tuple(values.kwargs['else_body']) | ||
return values | ||
def ensure_tuple(cls, value): | ||
return _sanitize_tuple(value) | ||
|
||
def __post_init__(self): | ||
super().__post_init__() | ||
|
@@ -960,8 +949,8 @@ class _CallStatementBase(): | |
""" Type definitions for :any:`CallStatement` node type. """ | ||
|
||
name: Expression | ||
arguments: Optional[Tuple[Expression, ...]] = None | ||
kwarguments: Optional[Tuple[Tuple[str, Expression], ...]] = None | ||
arguments: Optional[Tuple[Expression, ...]] = () | ||
kwarguments: Optional[Tuple[Tuple[str, Expression], ...]] = () | ||
Comment on lines
+952
to
+953
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is this a change in behaviour, i.e., There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Internally, we would always force this into |
||
pragma: Optional[Tuple[Node, ...]] = None | ||
not_active: Optional[bool] = None | ||
chevron: Optional[Tuple[Expression, ...]] = None | ||
|
@@ -997,22 +986,15 @@ class CallStatement(LeafNode, _CallStatementBase): | |
|
||
_traversable = ['name', 'arguments', 'kwarguments'] | ||
|
||
@model_validator(mode='before') | ||
@field_validator('arguments', mode='before') | ||
@classmethod | ||
def pre_init(cls, values): | ||
# Ensure non-nested tuples for arguments | ||
if 'arguments' in values.kwargs: | ||
values.kwargs['arguments'] = _sanitize_tuple(values.kwargs['arguments']) | ||
else: | ||
values.kwargs['arguments'] = () | ||
# Ensure two-level nested tuples for kwarguments | ||
if 'kwarguments' in values.kwargs: | ||
kwarguments = as_tuple(values.kwargs['kwarguments']) | ||
kwarguments = tuple(_sanitize_tuple(pair) for pair in kwarguments) | ||
values.kwargs['kwarguments'] = kwarguments | ||
else: | ||
values.kwargs['kwarguments'] = () | ||
return values | ||
def ensure_tuple(cls, value): | ||
return _sanitize_tuple(value) | ||
|
||
@field_validator('kwarguments', mode='before') | ||
@classmethod | ||
def ensure_nested_tuple(cls, value): | ||
return tuple(_sanitize_tuple(pair) for pair in as_tuple(value)) | ||
|
||
def __post_init__(self): | ||
super().__post_init__() | ||
|
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -96,6 +96,12 @@ | |||||
loop = ir.Loop(variable=i, bounds=bounds, body=( assign, (assign,), assign, None)) | ||||||
assert loop.body == (assign, assign, assign) | ||||||
|
||||||
# Test auto-casting with unnamed constructor args | ||||||
loop = ir.Loop(i, bounds, assign) | ||||||
assert loop.body == (assign,) | ||||||
loop = ir.Loop(i, bounds, [(assign,), None, assign]) | ||||||
assert loop.body == (assign, assign) | ||||||
|
||||||
# Test errors for wrong contructor usage | ||||||
with pytest.raises(ValidationError): | ||||||
ir.Loop(variable=i, bounds=bounds, body=n) | ||||||
|
@@ -148,6 +154,14 @@ | |||||
) | ||||||
assert not cond.body and cond.else_body == (assign, assign, assign) | ||||||
|
||||||
# Test auto-casting with unnamed constructor args | ||||||
cond = ir.Conditional(condition) | ||||||
assert cond.body == () and not cond.else_body | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Pylint doesn't like
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same here, |
||||||
cond = ir.Conditional(condition, assign) | ||||||
assert cond.body == (assign,) and not cond.else_body | ||||||
cond = ir.Conditional(condition, body=[assign, (assign,)], else_body=[assign, None, (assign,)]) | ||||||
assert cond.body == (assign, assign) and cond.else_body == (assign, assign) | ||||||
|
||||||
# TODO: Test inline, name, has_elseif | ||||||
|
||||||
|
||||||
|
@@ -235,6 +249,12 @@ | |||||
sec = ir.Section((assign, (func,), assign, None)) | ||||||
assert sec.body == (assign, func, assign) | ||||||
|
||||||
# Test auto-casting with unnamed constructor args | ||||||
sec = ir.Section() | ||||||
assert sec.body == () | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Or even just |
||||||
sec = ir.Section([(assign,), assign, None, assign]) | ||||||
assert sec.body == (assign, assign, assign) | ||||||
|
||||||
# Test prepend/insert/append additions | ||||||
sec = ir.Section(body=func) | ||||||
assert sec.body == (func,) | ||||||
|
@@ -281,6 +301,12 @@ | |||||
call = ir.CallStatement(name=cname, kwarguments=None) | ||||||
assert not call.arguments and not call.kwarguments | ||||||
|
||||||
# Test auto-casting with unnamed constructor args | ||||||
call = ir.CallStatement(cname, a_i) | ||||||
assert call.arguments == (a_i,) and not call.kwarguments | ||||||
call = ir.CallStatement(cname, [a_i, one], [('i', i), ('j', one)]) | ||||||
assert call.arguments == (a_i, one) and call.kwarguments == (('i', i), ('j', one)) | ||||||
|
||||||
# Test errors for wrong contructor usage | ||||||
with pytest.raises(ValidationError): | ||||||
ir.CallStatement(name='a', arguments=(sym.Literal(42.0),)) | ||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Likely my ignorance of pydantic showing here: The
pre_init
method did also stuff toargs
andkwargs
before. Is this implicitly handled now or not required anymore?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, but this one worked on the entire "Model" in pydantic speak, so it got a single
ArgsKwargs
argument that held all the constructor arguments asArgsKwarg.args
andArgsKwargs.kwargs
. This became an issue, as theArgsKwargs
object is immutable, but theArgsKwarg.kwargs
dict could be changed. This means, if any kw-args were given, I could change the dict, but I cannot replaceNone
with()
asArgsKwargs.args = ()
is not allowed.Long story short, this new validator is applied to each arg individual when pydantic creates the
ArgsKwargs
object, so all our troubles fade away and pydantic takes care of identifying unnamed args for free.