Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feature[next]: Inline center deref lift vars #1455

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 101 additions & 0 deletions src/gt4py/next/iterator/transforms/inline_center_deref_lift_vars.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
# GT4Py - GridTools Framework
#
# Copyright (c) 2014-2023, ETH Zurich
# All rights reserved.
#
# This file is part of the GT4Py project and the GridTools framework.
# GT4Py is free software: you can redistribute it and/or modify it under
# the terms of the GNU General Public License as published by the
# Free Software Foundation, either version 3 of the License, or any later
# version. See the LICENSE.txt file at the top-level directory of this
# distribution for a copy of the license or check <https://www.gnu.org/licenses/>.
#
# SPDX-License-Identifier: GPL-3.0-or-later

import dataclasses
from typing import ClassVar, Optional

import gt4py.next.iterator.ir_utils.common_pattern_matcher as common_pattern_matcher
from gt4py import eve
from gt4py.eve import utils as eve_utils
from gt4py.next.iterator import ir as itir
from gt4py.next.iterator.ir_utils import ir_makers as im
from gt4py.next.iterator.transforms.inline_lambdas import inline_lambda
from gt4py.next.iterator.transforms.inline_lifts import InlineLifts
from gt4py.next.iterator.transforms.trace_shifts import TraceShifts, copy_recorded_shifts


def is_center_derefed_only(node: itir.Node) -> bool:
    """
    Tell whether `node` was only ever accessed at the center position.

    The decision is based on the `recorded_shifts` annex of `node`: the result is `True`
    when the annex is present and equals either the empty set (no recorded access) or the
    set containing only the empty shift sequence `()`. A node without the annex yields
    `False`.
    """
    if not hasattr(node.annex, "recorded_shifts"):
        return False
    return node.annex.recorded_shifts in (set(), {()})


@dataclasses.dataclass
class InlineCenterDerefLiftVars(eve.NodeTranslator):
    """
    Inline all variables which are derefed in the center only (i.e. unshifted).

    Consider the following example where `var` is never shifted:

    `let(var, (↑stencil)(it))(·var + ·var)`

    Directly inlining `var` would increase the size of the tree and duplicate the calculation.
    Instead, this pass computes the value at the current location once and replaces all previous
    references to `var` by an applied lift which captures this value.

    `let(_icdlv_1, stencil(it))(·(↑(λ() → _icdlv_1))() + ·(↑(λ() → _icdlv_1))())`

    The lift inliner can then later easily transform this into a nice expression:

    `let(_icdlv_1, stencil(it))(_icdlv_1 + _icdlv_1)`

    Note: This pass uses and preserves the `recorded_shifts` annex.
    """

    PRESERVED_ANNEX_ATTRS: ClassVar[tuple[str, ...]] = ("recorded_shifts",)

    # Source of the sequential `_icdlv`-prefixed names bound by the introduced let-expressions.
    uids: eve_utils.UIDGenerator

    @classmethod
    def apply(cls, node: itir.FencilDefinition, uids: Optional[eve_utils.UIDGenerator] = None):
        """
        Run the pass on `node` and return the transformed tree.

        Args:
            node: The fencil to transform.
            uids: Generator used to name the newly bound scalars. A fresh generator is
                created when omitted.
        """
        if uids is None:
            uids = eve_utils.UIDGenerator()
        return cls(uids=uids).visit(node)

    def visit_StencilClosure(self, node: itir.StencilClosure, **kwargs):
        """Populate the `recorded_shifts` annex for the closure, then transform its children."""
        # TODO(tehrengruber): move the analysis out of this pass and just make it a requirement
        #  such that we don't need to run it multiple times if multiple passes use it.
        TraceShifts.apply(node, save_to_annex=True)
        return self.generic_visit(node, **kwargs)

    def visit_FunCall(self, node: itir.FunCall, **kwargs):
        """
        Rewrite let-expressions whose applied-lift arguments are only derefed in the center.

        Children are transformed first. Any call that is not a let-expression, or that has no
        eligible parameter, is returned as produced by the generic visit.
        """
        # forward `kwargs` so that nested visits see the same context as in the other visitors
        node = self.generic_visit(node, **kwargs)
        if not common_pattern_matcher.is_let(node):
            return node

        assert isinstance(node.fun, itir.Lambda)  # to make mypy happy
        eligible_params = [False] * len(node.fun.params)
        new_args = []
        # maps the name of each newly introduced scalar to the expression computing its value
        bound_scalars: dict[str, itir.Expr] = {}

        for i, (param, arg) in enumerate(zip(node.fun.params, node.args)):
            if common_pattern_matcher.is_applied_lift(arg) and is_center_derefed_only(param):
                eligible_params[i] = True
                bound_arg_name = self.uids.sequential_id(prefix="_icdlv")
                # the replacement argument is an iterator capturing the scalar bound below; it
                # inherits the shifts recorded for the parameter it stands in for
                capture_lift = im.promote_to_const_iterator(bound_arg_name)
                copy_recorded_shifts(from_=param, to=capture_lift)
                new_args.append(capture_lift)
                # since we deref an applied lift here we can (but don't need to) immediately
                # inline
                bound_scalars[bound_arg_name] = InlineLifts(
                    flags=InlineLifts.Flag.INLINE_TRIVIAL_DEREF_LIFT
                ).visit(im.deref(arg), recurse=False)
            else:
                new_args.append(arg)

        if not any(eligible_params):
            return node

        new_node = inline_lambda(
            im.call(node.fun)(*new_args),
            eligible_params=eligible_params,
        )
        # TODO(tehrengruber): propagate let outwards
        return im.let(*bound_scalars.items())(new_node)  # type: ignore[arg-type] # mypy not smart enough
8 changes: 6 additions & 2 deletions src/gt4py/next/iterator/transforms/inline_lambdas.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,9 +120,13 @@ def new_name(name):

@dataclasses.dataclass
class InlineLambdas(PreserveLocationVisitor, NodeTranslator):
"""Inline lambda calls by substituting every argument by its value."""
"""
Inline lambda calls by substituting every argument by its value.

PRESERVED_ANNEX_ATTRS = ("type",)
Note: This pass preserves, but doesn't use the `type` and `recorded_shifts` annex.
"""

PRESERVED_ANNEX_ATTRS = ("type", "recorded_shifts")

opcount_preserving: bool

Expand Down
11 changes: 8 additions & 3 deletions src/gt4py/next/iterator/transforms/pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
import enum
from typing import Callable, Optional

from gt4py.next.iterator import ir
from gt4py.eve import utils as eve_utils
from gt4py.next.iterator import ir as itir
from gt4py.next.iterator.transforms import simple_inline_heuristic
from gt4py.next.iterator.transforms.collapse_list_get import CollapseListGet
from gt4py.next.iterator.transforms.collapse_tuple import CollapseTuple
Expand All @@ -24,6 +25,7 @@
from gt4py.next.iterator.transforms.eta_reduction import EtaReduction
from gt4py.next.iterator.transforms.fuse_maps import FuseMaps
from gt4py.next.iterator.transforms.global_tmps import CreateGlobalTmps
from gt4py.next.iterator.transforms.inline_center_deref_lift_vars import InlineCenterDerefLiftVars
from gt4py.next.iterator.transforms.inline_fundefs import InlineFundefs, PruneUnreferencedFundefs
from gt4py.next.iterator.transforms.inline_into_scan import InlineIntoScan
from gt4py.next.iterator.transforms.inline_lambdas import InlineLambdas
Expand Down Expand Up @@ -74,7 +76,7 @@ def _inline_into_scan(ir, *, max_iter=10):
# TODO(tehrengruber): Revisit interface to configure temporary extraction. We currently forward
# `lift_mode` and `temporary_extraction_heuristics` which is inconvenient.
def apply_common_transforms(
ir: ir.Node,
ir: itir.Node,
*,
lift_mode=None,
offset_provider=None,
Expand All @@ -83,10 +85,12 @@ def apply_common_transforms(
force_inline_lambda_args=False,
unconditionally_collapse_tuples=False,
temporary_extraction_heuristics: Optional[
Callable[[ir.StencilClosure], Callable[[ir.Expr], bool]]
Callable[[itir.StencilClosure], Callable[[itir.Expr], bool]]
] = None,
symbolic_domain_sizes: Optional[dict[str, str]] = None,
):
icdlv_uids = eve_utils.UIDGenerator()

if lift_mode is None:
lift_mode = LiftMode.FORCE_INLINE
assert isinstance(lift_mode, LiftMode)
Expand All @@ -99,6 +103,7 @@ def apply_common_transforms(
for _ in range(10):
inlined = ir

inlined = InlineCenterDerefLiftVars.apply(inlined, uids=icdlv_uids) # type: ignore[arg-type] # always a fencil
inlined = _inline_lifts(inlined, lift_mode)

inlined = InlineLambdas.apply(
Expand Down
6 changes: 4 additions & 2 deletions src/gt4py/next/iterator/transforms/remap_symbols.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@


class RemapSymbolRefs(PreserveLocationVisitor, NodeTranslator):
PRESERVED_ANNEX_ATTRS = ("type",)
# This pass preserves, but doesn't use the `type` and `recorded_shifts` annex.
PRESERVED_ANNEX_ATTRS = ("type", "recorded_shifts")

def visit_SymRef(self, node: ir.SymRef, *, symbol_map: Dict[str, ir.Node]):
return symbol_map.get(str(node.id), node)
Expand All @@ -40,7 +41,8 @@ def generic_visit(self, node: ir.Node, **kwargs: Any): # type: ignore[override]


class RenameSymbols(PreserveLocationVisitor, NodeTranslator):
PRESERVED_ANNEX_ATTRS = ("type",)
# This pass preserves, but doesn't use the `type` and `recorded_shifts` annex.
PRESERVED_ANNEX_ATTRS = ("type", "recorded_shifts")

def visit_Sym(
self, node: ir.Sym, *, name_map: Dict[str, str], active: Optional[Set[str]] = None
Expand Down
50 changes: 48 additions & 2 deletions src/gt4py/next/iterator/transforms/trace_shifts.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,37 @@
from collections.abc import Callable
from typing import Any, Final, Iterable, Literal

from gt4py import eve
from gt4py.eve import NodeTranslator, PreserveLocationVisitor
from gt4py.next.iterator import ir
from gt4py.next.iterator.ir_utils.common_pattern_matcher import is_applied_lift


class ValidateRecordedShiftsAnnex(eve.NodeVisitor):
    """
    Assert that the `recorded_shifts` annex is present where it is expected.

    Each applied lift must carry the annex. If its recorded shifts are non-empty and the
    lifted stencil is a lambda, every parameter of that lambda must carry the annex as well.
    """

    def visit_FunCall(self, node: ir.FunCall):
        if is_applied_lift(node):
            assert hasattr(node.annex, "recorded_shifts")

            # a lift without any recorded shifts is not inspected further (children included)
            if not node.annex.recorded_shifts:
                return

            lifted = node.fun.args[0]  # type: ignore[attr-defined] # ensured by is_applied_lift
            if isinstance(lifted, ir.Lambda):
                assert all(hasattr(param.annex, "recorded_shifts") for param in lifted.params)

        self.generic_visit(node)


def copy_recorded_shifts(from_: ir.Node, to: ir.Node) -> None:
    """
    Give `to` the same `recorded_shifts` annex value as `from_`.

    The value is assigned as is, i.e. both nodes reference the same object afterwards.
    The function is merely a readability helper.
    """
    shifts = from_.annex.recorded_shifts
    to.annex.recorded_shifts = shifts


class Sentinel(enum.Enum):
Expand Down Expand Up @@ -246,7 +275,9 @@ def visit_SymRef(self, node: ir.SymRef, *, ctx: dict[str, Any]) -> Any:
return ctx[node.id]
elif node.id in ir.TYPEBUILTINS:
return Sentinel.TYPE
return _combine
elif node.id in (ir.ARITHMETIC_BUILTINS | {"list_get", "make_const_list", "cast_"}):
return _combine
raise ValueError(f"Undefined symbol {node.id}")

def visit_FunCall(self, node: ir.FunCall, *, ctx: dict[str, Any]) -> Any:
if node.fun == ir.SymRef(id="tuple_get"):
Expand Down Expand Up @@ -301,7 +332,7 @@ def visit_StencilClosure(self, node: ir.StencilClosure):

@classmethod
def apply(
cls, node: ir.StencilClosure, *, inputs_only=True
cls, node: ir.StencilClosure, *, inputs_only=True, save_to_annex=False
) -> (
dict[int, set[tuple[ir.OffsetLiteral, ...]]] | dict[str, set[tuple[ir.OffsetLiteral, ...]]]
):
Expand All @@ -310,10 +341,25 @@ def apply(

recorded_shifts = instance.shift_recorder.recorded_shifts

if save_to_annex:
_save_to_annex(node, recorded_shifts)

if __debug__:
ValidateRecordedShiftsAnnex().visit(node)

if inputs_only:
inputs_shifts = {}
for inp in node.inputs:
inputs_shifts[str(inp.id)] = recorded_shifts[id(inp)]
return inputs_shifts

return recorded_shifts


def _save_to_annex(
    node: ir.Node,
    recorded_shifts: dict[int, set[tuple[ir.OffsetLiteral, ...]]],
) -> None:
    """
    Store the recorded shifts in the `recorded_shifts` annex of the respective nodes.

    `recorded_shifts` is keyed by the `id()` of a node. Every value yielded by
    `node.pre_walk_values()` whose `id()` is a key of the mapping gets the associated
    entry assigned to its annex; all other values are left untouched.
    """
    for descendant in node.pre_walk_values():
        key = id(descendant)
        if key not in recorded_shifts:
            continue
        descendant.annex.recorded_shifts = recorded_shifts[key]
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ def conditional_indirection(inp, cond):
def test_simple_indirection(program_processor):
program_processor, validate = program_processor

pytest.xfail("Applied shifts in if_ statements are not supported in TraceShift pass.")

if program_processor in [
type_check.check_type_inference,
gtfn_format_sourcecode,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
# GT4Py - GridTools Framework
#
# Copyright (c) 2014-2023, ETH Zurich
# All rights reserved.
#
# This file is part of the GT4Py project and the GridTools framework.
# GT4Py is free software: you can redistribute it and/or modify it under
# the terms of the GNU General Public License as published by the
# Free Software Foundation, either version 3 of the License, or any later
# version. See the LICENSE.txt file at the top-level directory of this
# distribution for a copy of the license or check <https://www.gnu.org/licenses/>.
#
# SPDX-License-Identifier: GPL-3.0-or-later

from gt4py.next.iterator import ir as itir
from gt4py.next.iterator.ir_utils import ir_makers as im
from gt4py.next.iterator.transforms.inline_center_deref_lift_vars import InlineCenterDerefLiftVars


def wrap_in_fencil(expr: itir.Expr) -> itir.FencilDefinition:
    """
    Embed `expr` into a minimal fencil so that fencil-level passes can be applied to it.

    The expression becomes the body of a stencil `λ(it) → expr` inside the single closure
    of a fencil `f` with parameters `d`, `inp` and `out`.
    """
    fencil_params = [im.sym(name) for name in ("d", "inp", "out")]
    closure = itir.StencilClosure(
        domain=im.call("cartesian_domain")(),
        stencil=im.lambda_("it")(expr),
        output=im.ref("out"),
        inputs=[im.ref("inp")],
    )
    return itir.FencilDefinition(
        id="f",
        function_definitions=[],
        params=fencil_params,
        closures=[closure],
    )


def unwrap_from_fencil(fencil: itir.FencilDefinition) -> itir.Expr:
    """Return the body expression of the stencil of the first closure of `fencil`."""
    first_closure = fencil.closures[0]
    return first_closure.stencil.expr


def test_simple():
    """A variable that is derefed once, unshifted, is replaced by a captured scalar."""
    lifted_deref = im.lift("deref")("it")
    testee = im.let("var", lifted_deref)(im.deref("var"))

    transformed = InlineCenterDerefLiftVars.apply(wrap_in_fencil(testee))
    actual = unwrap_from_fencil(transformed)

    expected = "(λ(_icdlv_1) → ·(↑(λ() → _icdlv_1))())(·it)"
    assert str(actual) == expected


def test_double_deref():
    """Two unshifted derefs of the same variable both refer to the single captured scalar."""
    lifted_deref = im.lift("deref")("it")
    body = im.plus(im.deref("var"), im.deref("var"))
    testee = im.let("var", lifted_deref)(body)

    transformed = InlineCenterDerefLiftVars.apply(wrap_in_fencil(testee))
    actual = unwrap_from_fencil(transformed)

    expected = "(λ(_icdlv_1) → ·(↑(λ() → _icdlv_1))() + ·(↑(λ() → _icdlv_1))())(·it)"
    assert str(actual) == expected


def test_deref_at_non_center_different_pos():
    """A variable that is only derefed at a shifted position is left untouched."""
    shifted_deref = im.deref(im.shift("I", 1)("var"))
    testee = im.let("var", im.lift("deref")("it"))(shifted_deref)

    transformed = InlineCenterDerefLiftVars.apply(wrap_in_fencil(testee))
    actual = unwrap_from_fencil(transformed)

    assert testee == actual


def test_deref_at_multiple_pos():
    """A variable derefed both in the center and at a shifted position is left untouched."""
    center_deref = im.deref("var")
    shifted_deref = im.deref(im.shift("I", 1)("var"))
    testee = im.let("var", im.lift("deref")("it"))(im.plus(center_deref, shifted_deref))

    transformed = InlineCenterDerefLiftVars.apply(wrap_in_fencil(testee))
    actual = unwrap_from_fencil(transformed)

    assert testee == actual
Original file line number Diff line number Diff line change
Expand Up @@ -117,12 +117,13 @@ def test_neighbors():

def test_reduce():
testee = ir.StencilClosure(
# λ(inp) → reduce(plus, init)(·inp)
# λ(inp) → reduce(plus, 0.)(·inp)
stencil=ir.Lambda(
params=[ir.Sym(id="inp")],
expr=ir.FunCall(
fun=ir.FunCall(
fun=ir.SymRef(id="reduce"), args=[ir.SymRef(id="plus"), ir.SymRef(id="init")]
fun=ir.SymRef(id="reduce"),
args=[ir.SymRef(id="plus"), im.literal_from_value(0.0)],
),
args=[ir.FunCall(fun=ir.SymRef(id="deref"), args=[ir.SymRef(id="inp")])],
),
Expand Down
Loading