Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat[next]: Limit use of global type inference in CollapseTuple pass #1355

Merged
merged 16 commits into from
Nov 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 36 additions & 10 deletions src/gt4py/next/iterator/transforms/collapse_tuple.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,36 @@
# distribution for a copy of the license or check <https://www.gnu.org/licenses/>.
#
# SPDX-License-Identifier: GPL-3.0-or-later

from dataclasses import dataclass
from typing import Optional

from gt4py import eve
from gt4py.next import type_inference
from gt4py.next.iterator import ir, type_inference as it_type_inference


def _get_tuple_size(type_: type_inference.Type) -> int:
assert isinstance(type_, it_type_inference.Val) and isinstance(
type_.dtype, it_type_inference.Tuple
)
class UnknownLength:
    """Sentinel type signalling that a tuple length could not be determined.

    The class object itself (not an instance) is returned by
    `_get_tuple_size` when local type inference does not yield a tuple type.
    """


def _get_tuple_size(elem: ir.Node, node_types: Optional[dict] = None) -> int | type[UnknownLength]:
    """
    Return the number of elements of the tuple type of `elem`.

    If a non-empty `node_types` mapping is given, the type is looked up there
    under `id(elem)` and is required to be a tuple value type. Otherwise the
    type of `elem` is inferred locally, and `UnknownLength` is returned when
    that inference does not produce a tuple value type.
    """
    if not node_types:
        # no global information available: fall back to local type inference
        assert isinstance(elem, ir.Node)
        inferred = it_type_inference.infer(elem)

        is_tuple_val = isinstance(inferred, it_type_inference.Val) and isinstance(
            inferred.dtype, it_type_inference.Tuple
        )
        if not is_tuple_val:
            return UnknownLength
        return len(inferred.dtype)

    inferred = node_types[id(elem)]
    # global inference should always give a length, function should fail otherwise
    assert isinstance(inferred, it_type_inference.Val) and isinstance(
        inferred.dtype, it_type_inference.Tuple
    )
    return len(inferred.dtype)


Expand All @@ -38,8 +56,8 @@ class CollapseTuple(eve.NodeTranslator):
ignore_tuple_size: bool
collapse_make_tuple_tuple_get: bool
collapse_tuple_get_make_tuple: bool

_node_types: dict[int, type_inference.Type]
use_global_type_inference: bool
_node_types: Optional[dict[int, type_inference.Type]] = None

@classmethod
def apply(
Expand All @@ -50,22 +68,30 @@ def apply(
# the following options are mostly for allowing separate testing of the modes
collapse_make_tuple_tuple_get: bool = True,
collapse_tuple_get_make_tuple: bool = True,
use_global_type_inference: bool = False,
) -> ir.Node:
"""
Simplifies `make_tuple`, `tuple_get` calls.

If `ignore_tuple_size`, apply the transformation even if length of the inner tuple
is greater than the length of the outer tuple.
"""
node_types = it_type_inference.infer_all(node)

node_types = it_type_inference.infer_all(node) if use_global_type_inference else None
return cls(
ignore_tuple_size,
collapse_make_tuple_tuple_get,
collapse_tuple_get_make_tuple,
use_global_type_inference,
node_types,
).visit(node)

return cls(
ignore_tuple_size,
collapse_make_tuple_tuple_get,
collapse_tuple_get_make_tuple,
use_global_type_inference,
).visit(node)

def visit_FunCall(self, node: ir.FunCall, **kwargs) -> ir.Node:
if (
self.collapse_make_tuple_tuple_get
Expand All @@ -86,7 +112,7 @@ def visit_FunCall(self, node: ir.FunCall, **kwargs) -> ir.Node:
# tuple argument differs, just continue with the rest of the tree
return self.generic_visit(node)

if self.ignore_tuple_size or _get_tuple_size(self._node_types[id(first_expr)]) == len(
if self.ignore_tuple_size or _get_tuple_size(first_expr, self._node_types) == len(
node.args
):
return first_expr
Expand Down
6 changes: 5 additions & 1 deletion src/gt4py/next/iterator/transforms/pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,11 @@ def apply_common_transforms(
inlined = ConstantFolding.apply(inlined)
# This pass is required to be in the loop such that when an `if_` call with tuple arguments
# is constant-folded the surrounding tuple_get calls can be removed.
inlined = CollapseTuple.apply(inlined)
inlined = CollapseTuple.apply(
inlined,
# To limit the number of times global type inference is executed, only enable it in the last iterations (i.e. when `inlined == ir`).
use_global_type_inference=inlined == ir,
)

if inlined == ir:
break
Expand Down