feature[next]: Non-tree-size-increasing collapse tuple on ifs #1762

Merged
merged 27 commits on Jan 15, 2025
Changes from 15 commits

Commits (27 total):
34d6040
Non-tree-size-increasing collapse tuple on ifs
tehrengruber Dec 1, 2024
48abc08
Fix typos
tehrengruber Dec 1, 2024
3790944
Fix typos
tehrengruber Dec 1, 2024
a8a63bf
Disable PROPAGATE_TO_IF_ON_TUPLES by default in pass manager
tehrengruber Dec 1, 2024
2c44ffc
Improve doc
tehrengruber Dec 1, 2024
42b5817
Improve typo
tehrengruber Dec 1, 2024
9da19a2
Improve typo
tehrengruber Dec 1, 2024
bcd9e48
Improve typo
tehrengruber Dec 1, 2024
0a212bd
Improve typo
tehrengruber Dec 1, 2024
70562fe
Fix typo
tehrengruber Dec 1, 2024
9cee650
Improve doc
tehrengruber Dec 1, 2024
43f5741
Format
tehrengruber Dec 1, 2024
7b37f1c
Fix type synthesizer for partially typed arithmetic ops
tehrengruber Dec 1, 2024
a0341a6
Fix test
tehrengruber Dec 1, 2024
5a892f3
Fix test
tehrengruber Dec 1, 2024
914a9e5
Address review comments
tehrengruber Dec 6, 2024
fc46edf
Add test and fix nested transformation on nested ifs
tehrengruber Dec 10, 2024
4e12195
Merge origin/main
tehrengruber Dec 10, 2024
f8b5b99
Merge remote-tracking branch 'origin/main' into ct_cps
tehrengruber Jan 2, 2025
5e1a88c
Improve documentation
tehrengruber Jan 9, 2025
53d442a
Address review comments
tehrengruber Jan 10, 2025
3b1af1e
Address review comments
tehrengruber Jan 10, 2025
96a840a
Address review comments
tehrengruber Jan 10, 2025
9263f4d
Address review comments
tehrengruber Jan 14, 2025
4d700e7
Merge branch 'main' into ct_cps
tehrengruber Jan 14, 2025
72ca50d
Fix format
tehrengruber Jan 14, 2025
2a0bad1
Merge branch 'main' into ct_cps
tehrengruber Jan 15, 2025
153 changes: 135 additions & 18 deletions src/gt4py/next/iterator/transforms/collapse_tuple.py
@@ -28,10 +28,11 @@
from gt4py.next.type_system import type_info, type_specifications as ts


def _with_altered_arg(node: ir.FunCall, arg_idx: int, new_arg: ir.Expr):
def _with_altered_arg(node: ir.FunCall, arg_idx: int, new_arg: ir.Expr | str):
"""Given a itir.FunCall return a new call with one of its argument replaced."""
return ir.FunCall(
fun=node.fun, args=[arg if i != arg_idx else new_arg for i, arg in enumerate(node.args)]
fun=node.fun,
args=[arg if i != arg_idx else im.ensure_expr(new_arg) for i, arg in enumerate(node.args)],
)
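
A quick sketch of what the relaxed `ir.Expr | str` signature allows: plain strings are promoted to expressions via `im.ensure_expr`, so callers can pass symbol names directly. The import paths and builder usage below are assumptions mirroring this module's own helpers; the example is illustrative only.

from gt4py.next.iterator.ir_utils import ir_makers as im
from gt4py.next.iterator.transforms.collapse_tuple import _with_altered_arg

# Build `foo(a, b)` and replace its second argument by the symbol `c`,
# passing the replacement as a plain string instead of an ir.Expr.
call = im.call(im.ref("foo"))(im.ref("a"), im.ref("b"))
new_call = _with_altered_arg(call, 1, "c")  # -> `foo(a, c)`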


@@ -47,6 +48,32 @@ def _is_trivial_make_tuple_call(node: ir.Expr):
return True


def _is_trivial_or_tuple_thereof_expr(node: ir.Node) -> bool:
"""
Return `True` if the expr is a trivial expression or a tuple thereof.

>>> _is_trivial_or_tuple_thereof_expr(im.make_tuple("a", "b"))
True
>>> _is_trivial_or_tuple_thereof_expr(im.tuple_get(1, "a"))
True
>>> _is_trivial_or_tuple_thereof_expr(
... im.let("t", im.make_tuple("a", "b"))(im.tuple_get(1, "t"))
... )
True
"""
if cpm.is_call_to(node, "make_tuple"):
return all(_is_trivial_or_tuple_thereof_expr(arg) for arg in node.args)
if cpm.is_call_to(node, "tuple_get"):
return _is_trivial_or_tuple_thereof_expr(node.args[1])
if isinstance(node, (ir.SymRef, ir.Literal)):
return True
if cpm.is_let(node):
return _is_trivial_or_tuple_thereof_expr(node.fun.expr) and all( # type: ignore[attr-defined] # ensured by is_let
_is_trivial_or_tuple_thereof_expr(arg) for arg in node.args
)
return False
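
For contrast with the doctest above, a hedged illustration of an expression the helper rejects (import paths assumed; the builder calls mirror the doctest style):

from gt4py.next.iterator.ir_utils import ir_makers as im
from gt4py.next.iterator.transforms.collapse_tuple import _is_trivial_or_tuple_thereof_expr

# An arbitrary function call is neither a `make_tuple`, `tuple_get`, `let`,
# `SymRef` nor `Literal`, so it is conservatively reported as non-trivial.
assert not _is_trivial_or_tuple_thereof_expr(im.call(im.ref("plus"))(im.ref("a"), im.ref("b")))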


# TODO(tehrengruber): Conceptually the structure of this pass makes sense: Visit depth first,
# transform each node until no transformations apply anymore, whenever a node is to be transformed
# go through all available transformations and apply them. However the final result here still
@@ -76,28 +103,42 @@ class Flag(enum.Flag):
#: `let(tup, {trivial_expr1, trivial_expr2})(foo(tup))`
#: -> `foo({trivial_expr1, trivial_expr2})`
INLINE_TRIVIAL_MAKE_TUPLE = enum.auto()
#: Similar to `PROPAGATE_TO_IF_ON_TUPLES`, but propagates in the opposite direction, i.e.
#: into the tree, allowing removal of tuple expressions across `if_` calls without
#: increasing the size of the tree. This is particularly important for `if` statements
#: in the frontend, where outward propagation can have devastating effects on the tree
#: size, without any gained optimization potential. For example
#: ```
#: complex_lambda(if cond1
#:   if cond2
#:     {...}
#:   else
#:     {...}
#: else
#:   {...})
#: ```
#: is problematic, since `PROPAGATE_TO_IF_ON_TUPLES` would propagate and hence duplicate
#: `complex_lambda` three times, while we only want to get rid of the tuple expressions
#: inside of the `if_`s.
#: Note that this transformation is not mutually exclusive with `PROPAGATE_TO_IF_ON_TUPLES`.
PROPAGATE_TO_IF_ON_TUPLES_CPS = enum.auto()
#: `(if cond then {1, 2} else {3, 4})[0]` -> `if cond then {1, 2}[0] else {3, 4}[0]`
PROPAGATE_TO_IF_ON_TUPLES = enum.auto()
#: `let((a, let(b, 1)(a_val)))(a)`-> `let(b, 1)(let(a, a_val)(a))`
PROPAGATE_NESTED_LET = enum.auto()
#: `let(a, 1)(a)` -> `1`
#: `let(a, 1)(a)` -> `1` or `let(a, b)(f(a))` -> `f(b)`
INLINE_TRIVIAL_LET = enum.auto()

@classmethod
def all(self) -> CollapseTuple.Flag:
return functools.reduce(operator.or_, self.__members__.values())

uids: eve_utils.UIDGenerator
ignore_tuple_size: bool
flags: Flag = Flag.all() # noqa: RUF009 [function-call-in-dataclass-default-argument]

PRESERVED_ANNEX_ATTRS = ("type",)

# we use one UID generator per instance such that the generated ids are
# stable across multiple runs (required for caching to properly work)
_letify_make_tuple_uids: eve_utils.UIDGenerator = dataclasses.field(
init=False, repr=False, default_factory=lambda: eve_utils.UIDGenerator(prefix="_tuple_el")
)

@classmethod
def apply(
cls,
@@ -111,6 +152,7 @@ def apply(
flags: Optional[Flag] = None,
# allow sym references without a symbol declaration, mostly for testing
allow_undeclared_symbols: bool = False,
uids: Optional[eve_utils.UIDGenerator] = None,
) -> ir.Node:
"""
Simplifies `make_tuple`, `tuple_get` calls.
@@ -127,6 +169,7 @@
"""
flags = flags or cls.flags
offset_provider_type = offset_provider_type or {}
uids = uids or eve_utils.UIDGenerator()

if isinstance(node, (ir.Program, ir.FencilDefinition)):
within_stencil = False
@@ -145,6 +188,7 @@
new_node = cls(
ignore_tuple_size=ignore_tuple_size,
flags=flags,
uids=uids,
).visit(node, within_stencil=within_stencil)

# inline to remove left-overs from LETIFY_MAKE_TUPLE_ELEMENTS. this is important
@@ -185,6 +229,8 @@ def transform(self, node: ir.Node, **kwargs) -> Optional[ir.Node]:
method = getattr(self, f"transform_{transformation.name.lower()}")
result = method(node, **kwargs)
if result is not None:
assert result is not node
Review comment (Contributor): Why is this assert useful?

Reply (Contributor Author): In case one (erroneously) returns the same node again instead of None to signify no change, `fp_transform` runs into an endless loop. This is a cheap way to prevent this case.
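
To see why the assert guards against that, here is a minimal, hypothetical sketch of a fixpoint driver in the spirit of `fp_transform` (name and structure are assumptions for illustration, not the actual gt4py implementation):

from typing import Callable, Optional

def fixpoint_transform(node, transform: Callable[[object], Optional[object]]):
    # Repeatedly apply `transform` until it signals "no change" by returning None.
    while True:
        new_node = transform(node)
        if new_node is None:  # no transformation applied -> fixpoint reached
            return node
        # Returning the *same* object instead of None would spin forever;
        # the assert above catches that mistake cheaply.
        assert new_node is not node
        node = new_node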

itir_type_inference.reinfer(result)
return result
return None

@@ -263,13 +309,13 @@ def transform_letify_make_tuple_elements(self, node: ir.FunCall, **kwargs) -> Op
if node.fun == ir.SymRef(id="make_tuple"):
# `make_tuple(expr1, expr2)`
# -> `let((_tuple_el_1, expr1), (_tuple_el_2, expr2))(make_tuple(_tuple_el_1, _tuple_el_2))`
bound_vars: dict[str, ir.Expr] = {}
bound_vars: dict[ir.Sym, ir.Expr] = {}
new_args: list[ir.Expr] = []
for arg in node.args:
if cpm.is_call_to(node, "make_tuple") and not _is_trivial_make_tuple_call(node):
el_name = self._letify_make_tuple_uids.sequential_id()
new_args.append(im.ref(el_name))
bound_vars[el_name] = arg
el_name = self.uids.sequential_id(prefix="__ct_el")
new_args.append(im.ref(el_name, arg.type))
bound_vars[im.sym(el_name, arg.type)] = arg
else:
new_args.append(arg)

@@ -312,6 +358,73 @@ def transform_propagate_to_if_on_tuples(self, node: ir.FunCall, **kwargs) -> Opt
return im.if_(cond, new_true_branch, new_false_branch)
return None

def transform_propagate_to_if_on_tuples_cps(
self, node: ir.FunCall, **kwargs
) -> Optional[ir.Node]:
if not cpm.is_call_to(node, "if_"):
for i, arg in enumerate(node.args):
Review comment (Contributor): Here we iterate over any of the function's args, or do we know more?

Reply (Contributor Author): No, we don't know more. We need to build and simplify the continuation first to know if we want to transform. It might be possible to handle multiple ifs, but the code would become more complex and error-prone, and there is no benefit. It is just important to remember that the entire node is split into the continuation and the first matching if argument.

if cpm.is_call_to(arg, "if_"):
Review comment (Contributor): Not sure how it looks, but maybe do the same here: `if not cpm.is_call_to(arg, "if_"): continue`.

itir_type_inference.reinfer(arg)
Review comment (Contributor): Is the missing type from a previous transform in the fp iteration?

Reply (Contributor Author): Yes, as discussed, the approach is to reinfer on demand.

if not any(isinstance(branch.type, ts.TupleType) for branch in arg.args[1:]):
continue

cond, true_branch, false_branch = arg.args
tuple_type: ts.TupleType = true_branch.type # type: ignore[assignment] # type ensured above
tuple_len = len(tuple_type.types)
itir_type_inference.reinfer(node)
Review comment (Contributor): Under which conditions is this needed?

Reply (Contributor Author): Which part?

assert node.type

# transform function into continuation-passing-style
f_type = ts.FunctionType(
pos_only_args=tuple_type.types,
pos_or_kw_args={},
kw_only_args={},
returns=node.type,
)
f_params = [
im.sym(self.uids.sequential_id(prefix="__ct_el_cps"), type_)
for type_ in tuple_type.types
]
f_args = [im.ref(param.id, param.type) for param in f_params]
f_body = _with_altered_arg(node, i, im.make_tuple(*f_args))
# simplify, e.g., inline trivial make_tuple args
new_f_body = self.fp_transform(f_body, **kwargs)
Review comment (Contributor): In this recursion, we handle all the other args after the current one? I am confused...

Reply (Contributor Author): The expression is split into a continuation (which includes all the args following the current one) and the currently transformed arg. Since the currently transformed arg will use the newly built continuation, it is natural to first build and transform the continuation and only then tackle the arg. In the new node (continuation + arg) the continuation appears first. So in the expression the continuation appears first, but evaluation is still first the arg, then the continuation. This is somewhat what continuations are all about (https://en.wikipedia.org/wiki/Continuation-passing_style): "This has the effect of turning expressions 'inside-out' because the innermost parts of the expression must be evaluated first."
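
To make this split concrete, here is a small sketch of the intended before/after shape of the rewrite, written with the same `im` ir_makers helpers used in this file (the import path and the final simplified form are assumptions for illustration, not output of this PR's tests):

from gt4py.next.iterator.ir_utils import ir_makers as im

# Before: the whole node is `tuple_get(0, if_(cond, {a, b}, {c, d}))`.
before = im.tuple_get(
    0, im.if_(im.ref("cond"), im.make_tuple("a", "b"), im.make_tuple("c", "d"))
)

# The continuation replaces the `if_` argument by a fresh tuple of parameters:
#   f = λ(__ct_el_cps_1, __ct_el_cps_2) → tuple_get(0, make_tuple(__ct_el_cps_1, __ct_el_cps_2))
# `fp_transform` simplifies the body to `__ct_el_cps_1`, which is trivial, so the
# rewrite fires; calling the continuation inside each branch and simplifying again
# removes the tuple without duplicating any surrounding expression:
after = im.if_(im.ref("cond"), im.ref("a"), im.ref("c"))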

# if the function did not simplify there is nothing to gain. Skip
# transformation.
if new_f_body is f_body:
continue
# if the function is not trivial the transformation would still work, but
# inlining would result in a larger tree again and we wouldn't gain
# anything compared to regular `propagate_to_if_on_tuples`. Not inlining also
# works, but we don't want bound lambda functions in our tree (at least right
# now).
if not _is_trivial_or_tuple_thereof_expr(new_f_body):
continue
f = im.lambda_(*f_params)(new_f_body)
Review comment (Contributor): Basically, here it's decided that we actually do something.


tuple_var = self.uids.sequential_id(prefix="__ct_tuple_cps")
f_var = self.uids.sequential_id(prefix="__ct_cont")
new_branches = []
for branch in arg.args[1:]:
new_branch = im.let(tuple_var, branch)(
im.call(im.ref(f_var, f_type))(
*(
im.tuple_get(i, im.ref(tuple_var, branch.type))
for i in range(tuple_len)
)
)
)
new_branches.append(self.fp_transform(new_branch, **kwargs))

new_node = im.let(f_var, f)(im.if_(cond, *new_branches))
new_node = inline_lambda(new_node, eligible_params=[True])
assert cpm.is_call_to(new_node, "if_")
new_node = im.if_(
cond, *(self.fp_transform(branch, **kwargs) for branch in new_node.args[1:])
)
return new_node
return None

def transform_propagate_nested_let(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]:
if cpm.is_let(node):
# `let((a, let(b, 1)(a_val)))(a)`-> `let(b, 1)(let(a, a_val)(a))`
@@ -339,9 +452,13 @@ def transform_propagate_nested_let(self, node: ir.FunCall, **kwargs) -> Optional
return None

def transform_inline_trivial_let(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]:
if cpm.is_let(node) and isinstance(node.fun.expr, ir.SymRef): # type: ignore[attr-defined] # ensured by is_let
# `let(a, 1)(a)` -> `1`
for arg_sym, arg in zip(node.fun.params, node.args): # type: ignore[attr-defined] # ensured by is_let
if isinstance(node.fun.expr, ir.SymRef) and node.fun.expr.id == arg_sym.id: # type: ignore[attr-defined] # ensured by is_let
return arg
if cpm.is_let(node):
if isinstance(node.fun.expr, ir.SymRef): # type: ignore[attr-defined] # ensured by is_let
# `let(a, 1)(a)` -> `1`
for arg_sym, arg in zip(node.fun.params, node.args): # type: ignore[attr-defined] # ensured by is_let
if isinstance(node.fun.expr, ir.SymRef) and node.fun.expr.id == arg_sym.id: # type: ignore[attr-defined] # ensured by is_let
return arg
if any(trivial_args := [isinstance(arg, (ir.SymRef, ir.Literal)) for arg in node.args]):
return inline_lambda(node, eligible_params=trivial_args)

return None
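
As an illustration of the broadened behaviour documented above (`let(a, b)(f(a))` -> `f(b)`), a small sketch using the module's builder helpers (import path assumed):

from gt4py.next.iterator.ir_utils import ir_makers as im

# Binds `a` to the plain symbol reference `b` and uses it inside a call.
# Since substituting a SymRef (or Literal) argument cannot duplicate work,
# INLINE_TRIVIAL_LET now inlines this to `f(b)`.
expr = im.let("a", im.ref("b"))(im.call(im.ref("f"))(im.ref("a")))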
24 changes: 20 additions & 4 deletions src/gt4py/next/iterator/transforms/pass_manager.py
@@ -69,6 +69,7 @@ def apply_common_transforms(

tmp_uids = eve_utils.UIDGenerator(prefix="__tmp")
mergeasfop_uids = eve_utils.UIDGenerator()
collapse_tuple_uids = eve_utils.UIDGenerator()

ir = MergeLet().visit(ir)
ir = inline_fundefs.InlineFundefs().visit(ir)
@@ -80,7 +81,12 @@
# Inline. The domain inference can not handle "user" functions, e.g. `let f = λ(...) → ... in f(...)`
ir = InlineLambdas.apply(ir, opcount_preserving=True, force_inline_lambda_args=True)
# required in order to get rid of expressions without a domain (e.g. when a tuple element is never accessed)
ir = CollapseTuple.apply(ir, offset_provider_type=offset_provider_type) # type: ignore[assignment] # always an itir.Program
ir = CollapseTuple.apply(
ir,
flags=~CollapseTuple.Flag.PROPAGATE_TO_IF_ON_TUPLES,
uids=collapse_tuple_uids,
offset_provider_type=offset_provider_type,
) # type: ignore[assignment] # always an itir.Program
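
The `flags=~CollapseTuple.Flag.PROPAGATE_TO_IF_ON_TUPLES` argument relies on standard `enum.Flag` complement semantics: negating one member selects the combination of all other members, i.e. every transformation except outward propagation. A minimal, self-contained sketch with hypothetical flag names (not gt4py code):

import enum

class Flag(enum.Flag):
    COLLAPSE = enum.auto()
    PROPAGATE = enum.auto()
    INLINE = enum.auto()

# Negating one member yields every other defined member OR-ed together.
print(~Flag.PROPAGATE)  # e.g. Flag.COLLAPSE|INLINE (member order may vary by Python version)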
ir = infer_domain.infer_program(
ir, # type: ignore[arg-type] # always an itir.Program
offset_provider=offset_provider,
@@ -94,7 +100,12 @@
inlined = ConstantFolding.apply(inlined) # type: ignore[assignment] # always an itir.Program
# This pass is required to be in the loop such that when an `if_` call with tuple arguments
# is constant-folded the surrounding tuple_get calls can be removed.
inlined = CollapseTuple.apply(inlined, offset_provider_type=offset_provider_type) # type: ignore[assignment] # always an itir.Program
inlined = CollapseTuple.apply(
inlined,
flags=~CollapseTuple.Flag.PROPAGATE_TO_IF_ON_TUPLES,
uids=collapse_tuple_uids,
offset_provider_type=offset_provider_type,
) # type: ignore[assignment] # always an itir.Program
inlined = InlineScalar.apply(inlined, offset_provider_type=offset_provider_type)

# This pass is required to run after CollapseTuple as otherwise we can not inline
@@ -126,7 +137,10 @@ def apply_common_transforms(
# only run the unconditional version here instead of in the loop above.
if unconditionally_collapse_tuples:
ir = CollapseTuple.apply(
ir, ignore_tuple_size=True, offset_provider_type=offset_provider_type
ir,
ignore_tuple_size=True,
uids=collapse_tuple_uids,
offset_provider_type=offset_provider_type,
) # type: ignore[assignment] # always an itir.Program

ir = NormalizeShifts().visit(ir)
@@ -164,7 +178,9 @@ def apply_fieldview_transforms(
ir = inline_fundefs.prune_unreferenced_fundefs(ir)
ir = InlineLambdas.apply(ir, opcount_preserving=True, force_inline_lambda_args=True)
ir = CollapseTuple.apply(
ir, offset_provider_type=common.offset_provider_to_type(offset_provider)
ir,
flags=~CollapseTuple.Flag.PROPAGATE_TO_IF_ON_TUPLES,
offset_provider_type=common.offset_provider_to_type(offset_provider),
) # type: ignore[assignment] # type is still `itir.Program`
ir = infer_domain.infer_program(ir, offset_provider=offset_provider)
return ir