Skip to content

Commit

Permalink
CollapseTuple configurable such that whether the ITIR type inference …
Browse files Browse the repository at this point in the history
…is used to get the tuple size or the simple heuristics can be configured using a boolean flag to the pass

Execute with ITIR type inference once in loop in pass manager and in all subsequent runs use the simple heuristics
  • Loading branch information
nfarabullini committed Nov 3, 2023
1 parent 1e2de16 commit 555a40d
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 10 deletions.
34 changes: 27 additions & 7 deletions src/gt4py/next/iterator/transforms/collapse_tuple.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,19 +11,20 @@
# distribution for a copy of the license or check <https://www.gnu.org/licenses/>.
#
# SPDX-License-Identifier: GPL-3.0-or-later

from dataclasses import dataclass
from typing import Optional

from gt4py import eve
from gt4py.next import type_inference
from gt4py.next.iterator import ir, type_inference as it_type_inference


def _get_tuple_size(type_: type_inference.Type) -> int:
assert isinstance(type_, it_type_inference.Val) and isinstance(
type_.dtype, it_type_inference.Tuple
def _get_tuple_size(elem: type_inference.Type | ir.Node) -> int:
infered_type = it_type_inference.infer(elem) if isinstance(elem, ir.Node) else elem
assert isinstance(infered_type, it_type_inference.Val) and isinstance(
infered_type.dtype, it_type_inference.Tuple
)
return len(type_.dtype)
return len(infered_type.dtype)


@dataclass(frozen=True)
Expand All @@ -38,6 +39,8 @@ class CollapseTuple(eve.NodeTranslator):
ignore_tuple_size: bool
collapse_make_tuple_tuple_get: bool
collapse_tuple_get_make_tuple: bool
collapse_tuple_inference: bool
node_types: Optional[dict[int, type_inference.Type]] = None

@classmethod
def apply(
Expand All @@ -48,15 +51,29 @@ def apply(
# the following options are mostly for allowing separate testing of the modes
collapse_make_tuple_tuple_get: bool = True,
collapse_tuple_get_make_tuple: bool = True,
collapse_tuple_inference: bool = False,
) -> ir.Node:
"""
Simplifies `make_tuple`, `tuple_get` calls.
If `ignore_tuple_size`, apply the transformation even if length of the inner tuple
is greater than the length of the outer tuple.
"""
if collapse_tuple_inference:
node_types = it_type_inference.infer_all(node)
return cls(
ignore_tuple_size,
collapse_make_tuple_tuple_get,
collapse_tuple_get_make_tuple,
collapse_tuple_inference,
node_types,
).visit(node)

return cls(
ignore_tuple_size, collapse_make_tuple_tuple_get, collapse_tuple_get_make_tuple
ignore_tuple_size,
collapse_make_tuple_tuple_get,
collapse_tuple_get_make_tuple,
collapse_tuple_inference,
).visit(node)

def visit_FunCall(self, node: ir.FunCall, **kwargs) -> ir.Node:
Expand All @@ -79,7 +96,10 @@ def visit_FunCall(self, node: ir.FunCall, **kwargs) -> ir.Node:
# tuple argument differs, just continue with the rest of the tree
return self.generic_visit(node)

if self.ignore_tuple_size or _get_tuple_size(first_expr) == len(node.args):
tuple_expr: type_inference.Type | ir.Node = (
self.node_types[id(first_expr)] if self.node_types is not None else first_expr
)
if self.ignore_tuple_size or _get_tuple_size(tuple_expr) == len(node.args):
return first_expr
if (
self.collapse_tuple_get_make_tuple
Expand Down
9 changes: 6 additions & 3 deletions src/gt4py/next/iterator/transforms/pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,15 +82,15 @@ def apply_common_transforms(
unconditionally_collapse_tuples=False,
):
if lift_mode is None:
lift_mode = LiftMode.FORCE_INLINE
lift_mode = LiftMode.FORCE_TEMPORARIES
assert isinstance(lift_mode, LiftMode)
ir = MergeLet().visit(ir)
ir = InlineFundefs().visit(ir)
ir = PruneUnreferencedFundefs().visit(ir)
ir = PropagateDeref.apply(ir)
ir = NormalizeShifts().visit(ir)

for _ in range(10):
for i in range(10):
inlined = ir

inlined = _inline_lifts(inlined, lift_mode)
Expand All @@ -106,7 +106,10 @@ def apply_common_transforms(
inlined = ConstantFolding.apply(inlined)
# This pass is required to be in the loop such that when an `if_` call with tuple arguments
# is constant-folded the surrounding tuple_get calls can be removed.
inlined = CollapseTuple.apply(inlined)
if i == 1:
inlined = CollapseTuple.apply(inlined, collapse_tuple_inference=True)
else:
inlined = CollapseTuple.apply(inlined, collapse_tuple_inference=False)

if inlined == ir:
break
Expand Down

0 comments on commit 555a40d

Please sign in to comment.