Skip to content

Commit

Permalink
feat[next]: Limit use of global type inference in CollapseTuple pass (#…
Browse files Browse the repository at this point in the history
…1355)

CollapseTuple configurable such that whether the ITIR type inference is used to get the tuple size or the simple heuristics can be configured using a boolean flag to the pass
Execute with ITIR type inference once in loop in pass manager and in all subsequent runs use the simple heuristics
  • Loading branch information
nfarabullini authored Nov 6, 2023
1 parent 3c463a6 commit 4d8df69
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 11 deletions.
46 changes: 36 additions & 10 deletions src/gt4py/next/iterator/transforms/collapse_tuple.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,36 @@
# distribution for a copy of the license or check <https://www.gnu.org/licenses/>.
#
# SPDX-License-Identifier: GPL-3.0-or-later

from dataclasses import dataclass
from typing import Optional

from gt4py import eve
from gt4py.next import type_inference
from gt4py.next.iterator import ir, type_inference as it_type_inference


def _get_tuple_size(type_: type_inference.Type) -> int:
assert isinstance(type_, it_type_inference.Val) and isinstance(
type_.dtype, it_type_inference.Tuple
)
class UnknownLength:
pass


def _get_tuple_size(elem: ir.Node, node_types: Optional[dict] = None) -> int | type[UnknownLength]:
if node_types:
type_ = node_types[id(elem)]
# global inference should always give a length, function should fail otherwise
assert isinstance(type_, it_type_inference.Val) and isinstance(
type_.dtype, it_type_inference.Tuple
)
else:
# use local type inference if no global information is available
assert isinstance(elem, ir.Node)
type_ = it_type_inference.infer(elem)

if not (
isinstance(type_, it_type_inference.Val)
and isinstance(type_.dtype, it_type_inference.Tuple)
):
return UnknownLength

return len(type_.dtype)


Expand All @@ -38,8 +56,8 @@ class CollapseTuple(eve.NodeTranslator):
ignore_tuple_size: bool
collapse_make_tuple_tuple_get: bool
collapse_tuple_get_make_tuple: bool

_node_types: dict[int, type_inference.Type]
use_global_type_inference: bool
_node_types: Optional[dict[int, type_inference.Type]] = None

@classmethod
def apply(
Expand All @@ -50,22 +68,30 @@ def apply(
# the following options are mostly for allowing separate testing of the modes
collapse_make_tuple_tuple_get: bool = True,
collapse_tuple_get_make_tuple: bool = True,
use_global_type_inference: bool = False,
) -> ir.Node:
"""
Simplifies `make_tuple`, `tuple_get` calls.
If `ignore_tuple_size`, apply the transformation even if length of the inner tuple
is greater than the length of the outer tuple.
"""
node_types = it_type_inference.infer_all(node)

node_types = it_type_inference.infer_all(node) if use_global_type_inference else None
return cls(
ignore_tuple_size,
collapse_make_tuple_tuple_get,
collapse_tuple_get_make_tuple,
use_global_type_inference,
node_types,
).visit(node)

return cls(
ignore_tuple_size,
collapse_make_tuple_tuple_get,
collapse_tuple_get_make_tuple,
use_global_type_inference,
).visit(node)

def visit_FunCall(self, node: ir.FunCall, **kwargs) -> ir.Node:
if (
self.collapse_make_tuple_tuple_get
Expand All @@ -86,7 +112,7 @@ def visit_FunCall(self, node: ir.FunCall, **kwargs) -> ir.Node:
# tuple argument differs, just continue with the rest of the tree
return self.generic_visit(node)

if self.ignore_tuple_size or _get_tuple_size(self._node_types[id(first_expr)]) == len(
if self.ignore_tuple_size or _get_tuple_size(first_expr, self._node_types) == len(
node.args
):
return first_expr
Expand Down
6 changes: 5 additions & 1 deletion src/gt4py/next/iterator/transforms/pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,11 @@ def apply_common_transforms(
inlined = ConstantFolding.apply(inlined)
# This pass is required to be in the loop such that when an `if_` call with tuple arguments
# is constant-folded the surrounding tuple_get calls can be removed.
inlined = CollapseTuple.apply(inlined)
inlined = CollapseTuple.apply(
inlined,
# to limit number of times global type inference is executed, only in the last iterations.
use_global_type_inference=inlined == ir,
)

if inlined == ir:
break
Expand Down

0 comments on commit 4d8df69

Please sign in to comment.