From b0d688a598967ca1fc630b9906a225730123896d Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Wed, 25 Sep 2024 21:19:17 +0200 Subject: [PATCH 001/150] Add support for IfStmt in ITIR --- src/gt4py/next/iterator/embedded.py | 8 ++ src/gt4py/next/iterator/ir.py | 7 ++ src/gt4py/next/iterator/pretty_parser.py | 34 ++++++++- src/gt4py/next/iterator/pretty_printer.py | 13 ++++ src/gt4py/next/iterator/runtime.py | 7 +- src/gt4py/next/iterator/tracing.py | 21 +++++ .../next/iterator/type_system/inference.py | 5 ++ .../codegens/gtfn/codegen.py | 10 +++ .../codegens/gtfn/gtfn_ir.py | 29 ++++--- .../codegens/gtfn/gtfn_module.py | 9 +-- .../codegens/gtfn/itir_to_gtfn_ir.py | 15 +++- .../program_processors/runners/roundtrip.py | 4 + tests/next_tests/definitions.py | 2 + .../iterator_tests/test_if_stmt.py | 76 +++++++++++++++++++ .../iterator_tests/test_pretty_parser.py | 31 ++++++++ 15 files changed, 250 insertions(+), 21 deletions(-) create mode 100644 tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_if_stmt.py diff --git a/src/gt4py/next/iterator/embedded.py b/src/gt4py/next/iterator/embedded.py index e1b52043ed..997851d0b7 100644 --- a/src/gt4py/next/iterator/embedded.py +++ b/src/gt4py/next/iterator/embedded.py @@ -1603,6 +1603,14 @@ def set_at(expr: common.Field, domain: common.DomainLike, target: common.Mutable operators._tuple_assign_field(target, expr, common.domain(domain)) +@runtime.if_stmt.register(EMBEDDED) +def if_stmt(cond: bool, true_branch: Callable[[], None], false_branch: Callable[[], None]) -> None: + if cond: + true_branch() + else: + false_branch() + + def _compute_at_position( sten: Callable, ins: Sequence[common.Field], diff --git a/src/gt4py/next/iterator/ir.py b/src/gt4py/next/iterator/ir.py index 28adaaddf1..dc05f21d6a 100644 --- a/src/gt4py/next/iterator/ir.py +++ b/src/gt4py/next/iterator/ir.py @@ -211,6 +211,12 @@ class SetAt(Stmt): # from JAX array.at[...].set() target: Expr # `make_tuple` or SymRef +class IfStmt(Stmt): + cond: Expr + true_branch: list[Stmt] + false_branch: list[Stmt] + + class Temporary(Node): id: Coerced[eve.SymbolName] domain: Optional[Expr] = None @@ -243,3 +249,4 @@ class Program(Node, ValidatedSymbolTableTrait): FencilDefinition.__hash__ = Node.__hash__ # type: ignore[method-assign] Program.__hash__ = Node.__hash__ # type: ignore[method-assign] SetAt.__hash__ = Node.__hash__ # type: ignore[method-assign] +IfStmt.__hash__ = Node.__hash__ # type: ignore[method-assign] diff --git a/src/gt4py/next/iterator/pretty_parser.py b/src/gt4py/next/iterator/pretty_parser.py index 6af573272b..08459a9423 100644 --- a/src/gt4py/next/iterator/pretty_parser.py +++ b/src/gt4py/next/iterator/pretty_parser.py @@ -8,7 +8,7 @@ from typing import Union -from lark import lark, lexer as lark_lexer, visitors as lark_visitors +from lark import lark, lexer as lark_lexer, tree as lark_tree, visitors as lark_visitors from gt4py.next.iterator import ir from gt4py.next.iterator.ir_utils import ir_makers as im @@ -21,6 +21,7 @@ | declaration | stencil_closure | set_at + | if_stmt | program | prec0 @@ -78,13 +79,17 @@ | named_range | "(" prec0 ")" + ?stmt: set_at | if_stmt + set_at: prec0 "@" prec0 "←" prec1 ";" + else_branch_seperator: "else" + if_stmt: "if" "(" prec0 ")" "{" ( stmt )* "}" else_branch_seperator "{" ( stmt )* "}" + named_range: AXIS_NAME ":" "[" prec0 "," prec0 ")" function_definition: ID_NAME "=" "λ(" ( SYM "," )* SYM? 
")" "→" prec0 ";" declaration: ID_NAME "=" "temporary(" "domain=" prec0 "," "dtype=" TYPE_LITERAL ")" ";" stencil_closure: prec0 "←" "(" prec0 ")" "(" ( SYM_REF ", " )* SYM_REF ")" "@" prec0 ";" - set_at: prec0 "@" prec0 "←" prec1 ";" fencil_definition: ID_NAME "(" ( SYM "," )* SYM ")" "{" ( function_definition )* ( stencil_closure )+ "}" - program: ID_NAME "(" ( SYM "," )* SYM ")" "{" ( function_definition )* ( declaration )* ( set_at )+ "}" + program: ID_NAME "(" ( SYM "," )* SYM ")" "{" ( function_definition )* ( declaration )* ( stmt )+ "}" %import common (CNAME, SIGNED_FLOAT, SIGNED_INT, WS) %ignore WS @@ -215,6 +220,27 @@ def stencil_closure(self, *args: ir.Expr) -> ir.StencilClosure: output, stencil, *inputs, domain = args return ir.StencilClosure(domain=domain, stencil=stencil, output=output, inputs=inputs) + def if_stmt(self, cond: ir.Expr, *args): + found_else_seperator = False + true_branch = [] + false_branch = [] + for arg in args: + if isinstance(arg, lark_tree.Tree): + assert arg.data == "else_branch_seperator" + found_else_seperator = True + continue + + if not found_else_seperator: + true_branch.append(arg) + else: + false_branch.append(arg) + + return ir.IfStmt( + cond=cond, + true_branch=true_branch, + false_branch=false_branch, + ) + def declaration(self, *args: ir.Expr) -> ir.Temporary: tid, domain, dtype = args return ir.Temporary(id=tid, domain=domain, dtype=dtype) @@ -253,7 +279,7 @@ def program(self, fid: str, *args: ir.Node) -> ir.Program: elif isinstance(arg, ir.Temporary): declarations.append(arg) else: - assert isinstance(arg, ir.SetAt) + assert isinstance(arg, ir.Stmt) body.append(arg) return ir.Program( id=fid, diff --git a/src/gt4py/next/iterator/pretty_printer.py b/src/gt4py/next/iterator/pretty_printer.py index d9f62717a6..99287f8a11 100644 --- a/src/gt4py/next/iterator/pretty_printer.py +++ b/src/gt4py/next/iterator/pretty_printer.py @@ -299,6 +299,19 @@ def visit_SetAt(self, node: ir.SetAt, *, prec: int) -> list[str]: ) return self._optimum(h, v) + def visit_IfStmt(self, node: ir.IfStmt, *, prec: int) -> list[str]: + cond = self.visit(node.cond, prec=0) + true_branch = self._vmerge(*self.visit(node.true_branch, prec=0)) + false_branch = self._vmerge(*self.visit(node.false_branch, prec=0)) + + hhead = self._hmerge(["if ("], cond, [") {"]) + vhead = self._vmerge(["if ("], cond, [") {"]) + head = self._optimum(hhead, vhead) + + return self._vmerge( + head, self._indent(true_branch), ["} else {"], self._indent(false_branch), ["}"] + ) + def visit_FencilDefinition(self, node: ir.FencilDefinition, *, prec: int) -> list[str]: assert prec == 0 function_definitions = self.visit(node.function_definitions, prec=0) diff --git a/src/gt4py/next/iterator/runtime.py b/src/gt4py/next/iterator/runtime.py index 1a8d22b090..6618a34cfd 100644 --- a/src/gt4py/next/iterator/runtime.py +++ b/src/gt4py/next/iterator/runtime.py @@ -25,7 +25,7 @@ ) -__all__ = ["offset", "fundef", "fendef", "closure", "set_at"] +__all__ = ["offset", "fundef", "fendef", "closure", "set_at", "if_stmt"] @dataclass(frozen=True) @@ -207,3 +207,8 @@ def closure(*args): # TODO remove @builtin_dispatch def set_at(*args): return BackendNotSelectedError() + + +@builtin_dispatch +def if_stmt(*args): + return BackendNotSelectedError() diff --git a/src/gt4py/next/iterator/tracing.py b/src/gt4py/next/iterator/tracing.py index a0f0b86392..6772d4b507 100644 --- a/src/gt4py/next/iterator/tracing.py +++ b/src/gt4py/next/iterator/tracing.py @@ -247,6 +247,27 @@ def set_at(expr: itir.Expr, domain: itir.Expr, target: 
itir.Expr) -> None: TracerContext.add_stmt(itir.SetAt(expr=expr, domain=domain, target=target)) +@iterator.runtime.if_stmt.register(TRACING) +def if_stmt( + cond: itir.Expr, true_branch_f: typing.Callable, false_branch_f: typing.Callable +) -> None: + true_branch: List[itir.Stmt] = [] + false_branch: List[itir.Stmt] = [] + + old_body = TracerContext.body + TracerContext.body = true_branch + true_branch_f() + + TracerContext.body = false_branch + false_branch_f() + + TracerContext.body = old_body + + TracerContext.add_stmt( + itir.IfStmt(cond=cond, true_branch=true_branch, false_branch=false_branch) + ) + + def _contains_tuple_dtype_field(arg): if isinstance(arg, tuple): return any(_contains_tuple_dtype_field(el) for el in arg) diff --git a/src/gt4py/next/iterator/type_system/inference.py b/src/gt4py/next/iterator/type_system/inference.py index 465669245f..c141c80999 100644 --- a/src/gt4py/next/iterator/type_system/inference.py +++ b/src/gt4py/next/iterator/type_system/inference.py @@ -499,6 +499,11 @@ def visit_Temporary(self, node: itir.Temporary, *, ctx) -> ts.FieldType | ts.Tup lambda dtype: ts.FieldType(dims=domain.dims, dtype=dtype), node.dtype ) + def visit_IfStmt(self, node: itir.IfStmt, *, ctx) -> None: + self.visit(node.cond, ctx=ctx) # TODO: check is boolean + self.visit(node.true_branch, ctx=ctx) + self.visit(node.false_branch, ctx=ctx) + def visit_SetAt(self, node: itir.SetAt, *, ctx) -> None: self.visit(node.expr, ctx=ctx) self.visit(node.domain, ctx=ctx) diff --git a/src/gt4py/next/program_processors/codegens/gtfn/codegen.py b/src/gt4py/next/program_processors/codegens/gtfn/codegen.py index c9107469fd..92dbcedeaa 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/codegen.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/codegen.py @@ -166,6 +166,16 @@ def visit_FunCall(self, node: gtfn_ir.FunCall, **kwargs: Any) -> str: "{backend}.vertical_executor({axis})().{'.'.join('arg(' + a + ')' for a in args)}.{'.'.join(scans)}.execute();" ) + IfStmt = as_mako( + """ + if (${cond}) { + ${'\\n'.join(true_branch)} + } else { + ${'\\n'.join(false_branch)} + } + """ + ) + ScanPassDefinition = as_mako( """ struct ${id} : ${'gtfn::fwd' if _this_node.forward else 'gtfn::bwd'} { diff --git a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir.py b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir.py index 825623d6e8..1995e4de0b 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir.py @@ -96,23 +96,23 @@ class Backend(Node): domain: Union[SymRef, CartesianDomain, UnstructuredDomain] -def _is_ref_or_tuple_expr_of_ref(expr: Expr) -> bool: +def _is_ref_literal_or_tuple_expr_of_ref(expr: Expr) -> bool: if ( isinstance(expr, FunCall) and isinstance(expr.fun, SymRef) and expr.fun.id == "tuple_get" and len(expr.args) == 2 - and _is_ref_or_tuple_expr_of_ref(expr.args[1]) + and _is_ref_literal_or_tuple_expr_of_ref(expr.args[1]) ): return True if ( isinstance(expr, FunCall) and isinstance(expr.fun, SymRef) and expr.fun.id == "make_tuple" - and all(_is_ref_or_tuple_expr_of_ref(arg) for arg in expr.args) + and all(_is_ref_literal_or_tuple_expr_of_ref(arg) for arg in expr.args) ): return True - if isinstance(expr, SymRef): + if isinstance(expr, (SymRef, Literal)): return True return False @@ -125,7 +125,8 @@ def _values_validator( self: datamodels.DataModelTP, attribute: datamodels.Attribute, value: list[Expr] ) -> None: if not all( - isinstance(el, (SidFromScalar, SidComposite)) or 
_is_ref_or_tuple_expr_of_ref(el) + isinstance(el, (SidFromScalar, SidComposite)) + or _is_ref_literal_or_tuple_expr_of_ref(el) for el in value ): raise ValueError( @@ -140,11 +141,15 @@ class SidFromScalar(Expr): def _arg_validator( self: datamodels.DataModelTP, attribute: datamodels.Attribute, value: Expr ) -> None: - if not _is_ref_or_tuple_expr_of_ref(value): + if not _is_ref_literal_or_tuple_expr_of_ref(value): raise ValueError("Only 'SymRef' or tuple expr of 'SymRef' allowed.") -class StencilExecution(Node): +class Stmt(Node): + pass + + +class StencilExecution(Stmt): backend: Backend stencil: SymRef output: Union[SymRef, SidComposite] @@ -158,13 +163,19 @@ class Scan(Node): init: Expr -class ScanExecution(Node): +class ScanExecution(Stmt): backend: Backend scans: list[Scan] args: list[Expr] axis: SymRef +class IfStmt(Stmt): + cond: Expr + true_branch: list[Stmt] + false_branch: list[Stmt] + + class TemporaryAllocation(Node): id: SymbolName dtype: str @@ -199,7 +210,7 @@ class Program(Node, ValidatedSymbolTableTrait): function_definitions: list[ Union[FunctionDefinition, ScanPassDefinition, ImperativeFunctionDefinition] ] - executions: list[Union[StencilExecution, ScanExecution]] + executions: list[Stmt] offset_definitions: list[TagDefinition] grid_type: common.GridType temporaries: list[TemporaryAllocation] diff --git a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py index ac5325aade..5d30141b45 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py @@ -75,7 +75,7 @@ def _default_language_settings(self) -> languages.LanguageWithHeaderFilesSetting def _process_regular_arguments( self, - program: itir.FencilDefinition, + program: itir.FencilDefinition | itir.Program, args: tuple[Any, ...], offset_provider: dict[str, Connectivity | Dimension], ) -> tuple[list[interface.Parameter], list[str]]: @@ -154,10 +154,10 @@ def _process_connectivity_args( def _preprocess_program( self, - program: itir.FencilDefinition, + program: itir.FencilDefinition | itir.Program, offset_provider: dict[str, Connectivity | Dimension], ) -> itir.Program: - if not self.enable_itir_transforms: + if isinstance(program, itir.FencilDefinition) and not self.enable_itir_transforms: return fencil_to_program.FencilToProgram().apply( program ) # FIXME[#1582](tehrengruber): should be removed after refactoring to combined IR @@ -188,7 +188,7 @@ def _preprocess_program( def generate_stencil_source( self, - program: itir.FencilDefinition, + program: itir.FencilDefinition | itir.Program, offset_provider: dict[str, Connectivity | Dimension], column_axis: Optional[common.Dimension], ) -> str: @@ -209,7 +209,6 @@ def __call__( ) -> stages.ProgramSource[languages.NanobindSrcL, languages.LanguageWithHeaderFilesSettings]: """Generate GTFN C++ code from the ITIR definition.""" program: itir.FencilDefinition | itir.Program = inp.data - assert isinstance(program, itir.FencilDefinition) # handle regular parameters and arguments of the program (i.e. 
what the user defined in # the program) diff --git a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py index e9a0ad16c4..3bd96d14d7 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py @@ -24,6 +24,7 @@ CastExpr, FunCall, FunctionDefinition, + IfStmt, IntegralConstant, Lambda, Literal, @@ -66,8 +67,11 @@ def pytype_to_cpptype(t: ts.ScalarType | str) -> Optional[str]: _horizontal_dimension = "gtfn::unstructured::dim::horizontal" -def _get_domains(node: Iterable[itir.Stmt]) -> Iterable[itir.FunCall]: - return eve_utils.xiter(node).if_isinstance(itir.SetAt).getattr("domain").to_set() +def _get_domains(nodes: Iterable[itir.Stmt]) -> Iterable[itir.FunCall]: + result = set() + for node in nodes: + result |= node.walk_values().if_isinstance(itir.SetAt).getattr("domain").to_set() + return result def _extract_grid_type(domain: itir.FunCall) -> common.GridType: @@ -573,6 +577,13 @@ def remap_args(s: Scan) -> Scan: def visit_Stmt(self, node: itir.Stmt, **kwargs: Any) -> None: raise AssertionError("All Stmts need to be handled explicitly.") + def visit_IfStmt(self, node: itir.IfStmt, **kwargs: Any) -> IfStmt: + return IfStmt( + cond=self.visit(node.cond, **kwargs), + true_branch=self.visit(node.true_branch, **kwargs), + false_branch=self.visit(node.false_branch, **kwargs), + ) + def visit_SetAt( self, node: itir.SetAt, *, extracted_functions: list, **kwargs: Any ) -> Union[StencilExecution, ScanExecution]: diff --git a/src/gt4py/next/program_processors/runners/roundtrip.py b/src/gt4py/next/program_processors/runners/roundtrip.py index 8337a2b44a..d385e078cb 100644 --- a/src/gt4py/next/program_processors/runners/roundtrip.py +++ b/src/gt4py/next/program_processors/runners/roundtrip.py @@ -66,6 +66,10 @@ def ${id}(${','.join(params)}): """ ) SetAt = as_mako("set_at(${expr}, ${domain}, ${target})") + IfStmt = as_mako("""if_stmt(${cond}, + lambda: [${','.join(true_branch)}], + lambda: [${','.join(false_branch)}] + )""") def visit_Temporary(self, node: itir.Temporary, **kwargs: Any) -> str: assert ( diff --git a/tests/next_tests/definitions.py b/tests/next_tests/definitions.py index c0066872f3..d1824ffc84 100644 --- a/tests/next_tests/definitions.py +++ b/tests/next_tests/definitions.py @@ -104,6 +104,7 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): USES_DYNAMIC_OFFSETS = "uses_dynamic_offsets" USES_FLOORDIV = "uses_floordiv" USES_IF_STMTS = "uses_if_stmts" +USES_IR_IF_STMTS = "uses_ir_if_stmts" USES_INDEX_FIELDS = "uses_index_fields" USES_LIFT_EXPRESSIONS = "uses_lift_expressions" USES_NEGATIVE_MODULO = "uses_negative_modulo" @@ -146,6 +147,7 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): (USES_SPARSE_FIELDS_AS_OUTPUT, XFAIL, UNSUPPORTED_MESSAGE), ] DACE_SKIP_TEST_LIST = COMMON_SKIP_TEST_LIST + [ + (USES_IR_IF_STMTS, XFAIL, UNSUPPORTED_MESSAGE), (USES_SCALAR_IN_DOMAIN_AND_FO, XFAIL, UNSUPPORTED_MESSAGE), (USES_INDEX_FIELDS, XFAIL, UNSUPPORTED_MESSAGE), (USES_LIFT_EXPRESSIONS, XFAIL, UNSUPPORTED_MESSAGE), diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_if_stmt.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_if_stmt.py new file mode 100644 index 0000000000..1507def2c2 --- /dev/null +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_if_stmt.py @@ -0,0 +1,76 @@ +# GT4Py - GridTools 
Framework
+#
+# Copyright (c) 2014-2024, ETH Zurich
+# All rights reserved.
+#
+# Please, refer to the LICENSE file in the root directory.
+# SPDX-License-Identifier: BSD-3-Clause
+
+# GT4Py - GridTools Framework
+#
+# Copyright (c) 2014-2023, ETH Zurich
+# All rights reserved.
+#
+# This file is part of the GT4Py project and the GridTools framework.
+# GT4Py is free software: you can redistribute it and/or modify it under
+# the terms of the GNU General Public License as published by the
+# Free Software Foundation, either version 3 of the License, or any later
+# version. See the LICENSE.txt file at the top-level directory of this
+# distribution for a copy of the license or check .
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+import numpy as np
+import pytest
+
+import gt4py.next as gtx
+from gt4py.next.iterator.builtins import cartesian_domain, deref, as_fieldop, named_range
+from gt4py.next.iterator.runtime import set_at, if_stmt, fendef, fundef, offset
+from gt4py.next.program_processors.runners import gtfn
+
+from next_tests.unit_tests.conftest import program_processor, run_processor
+
+i = offset("i")
+
+
+@fundef
+def multiply(alpha, inp):
+    return deref(alpha) * deref(inp)
+
+
+IDim = gtx.Dimension("IDim")
+
+
+@pytest.mark.uses_ir_if_stmts
+@pytest.mark.parametrize("cond", [True, False])
+def test_if_stmt(program_processor, cond):
+    program_processor, validate = program_processor
+    size = 10
+
+    @fendef(offset_provider={"i": IDim})
+    def fencil(cond1, inp, out):
+        domain = cartesian_domain(named_range(IDim, 0, size))
+        if_stmt(
+            cond1,
+            lambda: set_at(
+                as_fieldop(multiply, domain)(1.0, inp),
+                domain,
+                out,
+            ),
+            lambda: set_at(
+                as_fieldop(multiply, domain)(2.0, inp),
+                domain,
+                out,
+            ),
+        )
+
+    rng = np.random.default_rng()
+    inp = gtx.as_field([IDim], rng.normal(size=size))
+    out = gtx.as_field([IDim], np.zeros(size))
+    ref = inp if cond else 2.0 * inp
+
+    run_processor(fencil, program_processor, cond, inp, out)
+
+    if validate:
+        assert np.allclose(out.asnumpy(), ref.asnumpy())
diff --git a/tests/next_tests/unit_tests/iterator_tests/test_pretty_parser.py b/tests/next_tests/unit_tests/iterator_tests/test_pretty_parser.py
index 5bb359dffc..da4bea8874 100644
--- a/tests/next_tests/unit_tests/iterator_tests/test_pretty_parser.py
+++ b/tests/next_tests/unit_tests/iterator_tests/test_pretty_parser.py
@@ -231,6 +231,37 @@ def test_set_at():
     assert actual == expected
 
 
+def test_if_stmt():
+    testee = """if (cond) {
+      y @ cartesian_domain() ← x;
+      if (cond) {
+        y @ cartesian_domain() ← x;
+      } else {
+      }
+    } else {
+      y @ cartesian_domain() ← x;
+    }"""
+    stmt = ir.SetAt(
+        expr=im.ref("x"),
+        domain=im.domain("cartesian_domain", {}),
+        target=im.ref("y"),
+    )
+    expected = ir.IfStmt(
+        cond=im.ref("cond"),
+        true_branch=[
+            stmt,
+            ir.IfStmt(
+                cond=im.ref("cond"),
+                true_branch=[stmt],
+                false_branch=[],
+            ),
+        ],
+        false_branch=[stmt],
+    )
+    actual = pparse(testee)
+    assert actual == expected
+
+
 # TODO(havogt): remove after refactoring to GTIR
 def test_fencil_definition():
     testee = "f(d, x, y) {\n  g = λ(x) → x;\n  y ← (deref)(x) @ cartesian_domain();\n}"

From 65a22fe27636ae691bad8591018e14683ef187c7 Mon Sep 17 00:00:00 2001
From: Till Ehrengruber
Date: Sat, 28 Sep 2024 11:55:02 +0200
Subject: [PATCH 002/150] Initial version of FuseAsFieldOp

---
 .../ir_utils/common_pattern_matcher.py        |   4 +
 .../iterator/transforms/fuse_as_fieldop.py    | 165 ++++++++++++++++++
 .../transforms_tests/test_fuse_as_fieldop.py  | 126 +++++++++++++
 3 files changed, 295 insertions(+)
create mode 100644 src/gt4py/next/iterator/transforms/fuse_as_fieldop.py create mode 100644 tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_fuse_as_fieldop.py diff --git a/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py b/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py index 135b18c367..4aea7ef149 100644 --- a/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py +++ b/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py @@ -70,3 +70,7 @@ def is_call_to(node: itir.Node, fun: str | Iterable[str]) -> TypeGuard[itir.FunC return ( isinstance(node, itir.FunCall) and isinstance(node.fun, itir.SymRef) and node.fun.id == fun ) + + +def is_ref_to(node, ref: str): + return isinstance(node, itir.SymRef) and node.id == ref diff --git a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py new file mode 100644 index 0000000000..6922857dbd --- /dev/null +++ b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py @@ -0,0 +1,165 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +import dataclasses +from typing import Optional + +from gt4py import eve +from gt4py.eve import utils as eve_utils +from gt4py.next.iterator import ir as itir +from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im +from gt4py.next.iterator.transforms import inline_lambdas, inline_lifts, trace_shifts +from gt4py.next.iterator.type_system import inference as type_inference, type_specifications as ts +from gt4py.next.type_system import type_info + + +def inline_as_fieldop_arg(arg, uids): + assert cpm.is_applied_as_fieldop(arg) + arg = canonicalize_as_fieldop(arg) + + stencil, *_ = arg.fun.args + inner_args = arg.args + extracted_args = {} # mapping from stencil param to arg + + stencil_params = [] + stencil_body = stencil.expr + + for inner_param, inner_arg in zip(stencil.params, inner_args, strict=True): + if isinstance(inner_arg, itir.SymRef): + stencil_params.append(inner_param) + extracted_args[inner_arg.id] = inner_arg + elif isinstance(inner_arg, itir.Literal): # TODO: all non capturing scalars + stencil_body = im.let(inner_param, im.promote_to_const_iterator(inner_arg))( + stencil_body + ) + else: # either a literal or a previous not inlined arg + stencil_params.append(inner_param) + new_outer_stencil_param = uids.sequential_id(prefix="__iasfop") + extracted_args[new_outer_stencil_param] = inner_arg + + return im.lift(im.lambda_(*stencil_params)(stencil_body))( + *extracted_args.keys() + ), extracted_args + + +def merge_arguments(args1: dict, arg2: dict): + new_args = {**args1} + for stencil_param, stencil_arg in arg2.items(): + if stencil_param not in new_args: + new_args[stencil_param] = stencil_arg + else: + assert new_args[stencil_param] == stencil_arg + return new_args + + +def canonicalize_as_fieldop(expr: itir.Expr) -> itir.Expr: + assert cpm.is_applied_as_fieldop(expr) + + stencil = expr.fun.args[0] + domain = expr.fun.args[1] if len(expr.fun.args) > 1 else None + if cpm.is_ref_to(stencil, "deref"): + stencil = im.lambda_("arg")(im.deref("arg")) + new_expr = im.as_fieldop(stencil, domain)(*expr.args) + type_inference.copy_type(from_=expr, to=new_expr) + + return new_expr + + return expr + + +@dataclasses.dataclass +class FuseAsFieldOp(eve.NodeTranslator): + uids: eve_utils.UIDGenerator + + @classmethod + def apply( + cls, + 
node: itir.Program, + *, + offset_provider, + uids: Optional[eve_utils.UIDGenerator] = None, + allow_undeclared_symbols=False, + ): + node = type_inference.infer( + node, offset_provider=offset_provider, allow_undeclared_symbols=allow_undeclared_symbols + ) + + if not uids: + uids = eve_utils.UIDGenerator() + + return cls(uids=uids).visit(node) + + def visit_FunCall(self, node: itir.FunCall): + node = self.generic_visit(node) + + if cpm.is_call_to(node.fun, "as_fieldop"): + node = canonicalize_as_fieldop(node) + + if cpm.is_call_to(node.fun, "as_fieldop") and isinstance(node.fun.args[0], itir.Lambda): + stencil = node.fun.args[0] + domain = node.fun.args[1] if len(node.fun.args) > 1 else None + + shifts = trace_shifts.trace_stencil(stencil) + + args = node.args + + new_args = {} + new_stencil_body = stencil.expr + + for stencil_param, arg, arg_shifts in zip(stencil.params, args, shifts, strict=True): + dtype = type_info.extract_dtype(arg.type) + should_inline = isinstance(arg, itir.Literal) or ( + isinstance(arg, itir.FunCall) + and (cpm.is_call_to(arg.fun, "as_fieldop") or cpm.is_call_to(arg, "cond")) + and (isinstance(dtype, ts.ListType) or len(arg_shifts) <= 1) + ) + if should_inline: + if cpm.is_applied_as_fieldop(arg): + pass + elif cpm.is_call_to(arg, "if_"): + type_ = arg.type + arg = im.op_as_fieldop("if_")(*arg.args) + arg.type = type_ + elif isinstance(arg, itir.Literal): + arg = im.op_as_fieldop(im.lambda_()(arg))() + else: + raise NotImplementedError() + + inline_expr, extracted_args = inline_as_fieldop_arg(arg, self.uids) + + new_stencil_body = im.let(stencil_param, inline_expr)(new_stencil_body) + + new_args = merge_arguments(new_args, extracted_args) + else: + # see test_tuple_with_local_field_in_reduction_shifted for ex where assert fails + # assert not isinstance(dtype, ts.ListType) + if isinstance( + arg, itir.SymRef + ): # use name from outer scope (optional, just to get a nice IR) + new_param = arg.id + new_stencil_body = im.let(stencil_param.id, arg.id)(new_stencil_body) + else: + new_param = stencil_param.id + new_args = merge_arguments(new_args, {new_param: arg}) + + new_stencil_body = inline_lambdas.InlineLambdas.apply( + new_stencil_body, + opcount_preserving=True, + force_inline_lift_args=False, + # If trivial lifts are not inlined we might create temporaries for constants. In all + # other cases we want it anyway. + force_inline_trivial_lift_args=True, + ) + new_stencil_body = inline_lifts.InlineLifts().visit(new_stencil_body) + + new_node = im.as_fieldop(im.lambda_(*new_args.keys())(new_stencil_body), domain)( + *new_args.values() + ) + new_node.type = node.type + return new_node + return node diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_fuse_as_fieldop.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_fuse_as_fieldop.py new file mode 100644 index 0000000000..bc371f64de --- /dev/null +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_fuse_as_fieldop.py @@ -0,0 +1,126 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. 
+# SPDX-License-Identifier: BSD-3-Clause +from typing import Callable, Optional + +from gt4py import next as gtx +from gt4py.next.iterator import ir as itir +from gt4py.next.iterator.ir_utils import ir_makers as im +from gt4py.next.iterator.transforms import fuse_as_fieldop +from gt4py.next.type_system import type_specifications as ts + +IDim = gtx.Dimension("IDim") +field_type = ts.FieldType(dims=[IDim], dtype=ts.ScalarType(kind=ts.ScalarKind.INT32)) + + +def op_asfieldop2(op: str | itir.SymRef | Callable, domain: Optional[itir.FunCall] = None): + assert isinstance(op, itir.Lambda) + op = im.call(op) + + args = [param.id for param in op.fun.params] + + def _impl(*its: itir.Expr) -> itir.FunCall: + return im.as_fieldop(im.lambda_(*args)(op(*[im.deref(arg) for arg in args])), domain)(*its) + + return _impl + + +def test_trivial(): + d = im.domain("cartesian_domain", {}) + testee = im.op_as_fieldop("plus", d)( + im.op_as_fieldop("multiplies", d)(im.ref("inp1", field_type), im.ref("inp2", field_type)), + im.ref("inp3", field_type), + ) + expected = im.as_fieldop( + im.lambda_("inp1", "inp2", "inp3", "__iasfop_4")( + im.plus( + im.multiplies_(im.deref("__iasfop_1"), im.deref("__iasfop_2")), + ) + ), + d, + )(1, 2, 3, 4) + actual = fuse_as_fieldop.FuseAsFieldOp.apply( + testee, offset_provider={}, allow_undeclared_symbols=True + ) + assert actual == expected + + +def test_trivial_literal(): + d = im.domain("cartesian_domain", {}) + testee = im.op_as_fieldop("plus", d)(im.op_as_fieldop("multiplies", d)(1, 2), 3) + expected = im.as_fieldop(im.lambda_()(im.plus(im.multiplies_(1, 2), 3)), d)() + actual = fuse_as_fieldop.FuseAsFieldOp.apply( + testee, offset_provider={}, allow_undeclared_symbols=True + ) + assert actual == expected + + +def test_symref_used_twice(): + d = im.domain("cartesian_domain", {IDim: (0, 1)}) + testee = op_asfieldop2(im.lambda_("a", "b")(im.plus("a", "b")), d)( + im.as_fieldop(im.lambda_("c", "d")(im.multiplies_(im.deref("c"), im.deref("d"))), d)( + im.ref("inp1", field_type), im.ref("inp2", field_type) + ), + im.ref("inp1", field_type), + ) + expected = im.as_fieldop( + im.lambda_("inp1", "inp2")( + im.plus(im.multiplies_(im.deref("inp1"), im.deref("inp2")), im.deref("inp1")) + ), + d, + )("inp1", "inp2") + actual = fuse_as_fieldop.FuseAsFieldOp.apply( + testee, offset_provider={}, allow_undeclared_symbols=True + ) + assert actual == expected + + +def test_no_inline(): + d1 = im.domain("cartesian_domain", {IDim: (1, 2)}) + d2 = im.domain("cartesian_domain", {IDim: (0, 3)}) + testee = im.as_fieldop( + im.lambda_("a")( + im.plus(im.deref(im.shift("IOff", 1)("a")), im.deref(im.shift("IOff", -1)("a"))) + ), + d1, + )(im.as_fieldop(im.lambda_("inp1")(im.deref("inp1")), d2)(im.ref("inp1", field_type))) + actual = fuse_as_fieldop.FuseAsFieldOp.apply( + testee, offset_provider={"IOff": IDim}, allow_undeclared_symbols=True + ) + assert actual == testee + + +def test_partial_inline(): + d1 = im.domain("cartesian_domain", {IDim: (1, 2)}) + d2 = im.domain("cartesian_domain", {IDim: (0, 3)}) + testee = im.as_fieldop( + # first argument used at multiple locations -> not inlined + # second argument only used at a single location -> inlined + im.lambda_("a", "b")( + im.plus( + im.plus(im.deref(im.shift("IOff", 1)("a")), im.deref(im.shift("IOff", -1)("a"))), + im.deref("b"), + ) + ), + d1, + )( + im.as_fieldop(im.lambda_("inp1")(im.deref("inp1")), d2)(im.ref("inp1", field_type)), + im.as_fieldop(im.lambda_("inp1")(im.deref("inp1")), d2)(im.ref("inp1", field_type)), + ) + expected = 
im.as_fieldop(
        im.lambda_("a", "inp1")(
            im.plus(
                im.plus(im.deref(im.shift("IOff", 1)("a")), im.deref(im.shift("IOff", -1)("a"))),
                im.deref("inp1"),
            )
        ),
        d1,
    )(im.as_fieldop(im.lambda_("inp1")(im.deref("inp1")), d2)(im.ref("inp1", field_type)), "inp1")
    actual = fuse_as_fieldop.FuseAsFieldOp.apply(
        testee, offset_provider={"IOff": IDim}, allow_undeclared_symbols=True
    )
    assert actual == expected

From 3edc130867902024e85b48bc5d1606354f43fce9 Mon Sep 17 00:00:00 2001
From: Till Ehrengruber
Date: Sat, 28 Sep 2024 12:06:42 +0200
Subject: [PATCH 003/150] Move symbolic domain utilities out of global tmp pass into separate module

---
 .../next/iterator/ir_utils/domain_utils.py    | 150 ++++++++++++++++++
 .../next/iterator/transforms/global_tmps.py   | 139 +---------------
 .../next/iterator/transforms/infer_domain.py  |  38 +++--
 3 files changed, 178 insertions(+), 149 deletions(-)
 create mode 100644 src/gt4py/next/iterator/ir_utils/domain_utils.py

diff --git a/src/gt4py/next/iterator/ir_utils/domain_utils.py b/src/gt4py/next/iterator/ir_utils/domain_utils.py
new file mode 100644
index 0000000000..ed04eb11c9
--- /dev/null
+++ b/src/gt4py/next/iterator/ir_utils/domain_utils.py
@@ -0,0 +1,150 @@
+# GT4Py - GridTools Framework
+#
+# Copyright (c) 2014-2024, ETH Zurich
+# All rights reserved.
+#
+# Please, refer to the LICENSE file in the root directory.
+# SPDX-License-Identifier: BSD-3-Clause
+
+from __future__ import annotations
+
+import dataclasses
+import functools
+from typing import Any, Literal, Mapping
+
+import gt4py.next as gtx
+from gt4py.next import common
+from gt4py.next.iterator import ir as itir
+from gt4py.next.iterator.ir_utils import ir_makers as im
+
+
+def _max_domain_sizes_by_location_type(offset_provider: Mapping[str, Any]) -> dict[str, int]:
+    """
+    Extract horizontal domain sizes from an `offset_provider`.
+
+    Considers the shape of the neighbor table to get the size of each `origin_axis` and the maximum
+    value inside the neighbor table to get the size of each `neighbor_axis`.
+ """ + sizes = dict[str, int]() + for provider in offset_provider.values(): + if isinstance(provider, gtx.NeighborTableOffsetProvider): + assert provider.origin_axis.kind == gtx.DimensionKind.HORIZONTAL + assert provider.neighbor_axis.kind == gtx.DimensionKind.HORIZONTAL + sizes[provider.origin_axis.value] = max( + sizes.get(provider.origin_axis.value, 0), provider.table.shape[0] + ) + sizes[provider.neighbor_axis.value] = max( + sizes.get(provider.neighbor_axis.value, 0), + provider.table.max() + 1, # type: ignore[attr-defined] # TODO(havogt): improve typing for NDArrayObject + ) + return sizes + + +@dataclasses.dataclass +class SymbolicRange: + start: itir.Expr + stop: itir.Expr + + def translate(self, distance: int) -> SymbolicRange: + return SymbolicRange(im.plus(self.start, distance), im.plus(self.stop, distance)) + + +@dataclasses.dataclass +class SymbolicDomain: + grid_type: Literal["unstructured_domain", "cartesian_domain"] + ranges: dict[ + common.Dimension, SymbolicRange + ] # TODO(havogt): remove `AxisLiteral` by `Dimension` everywhere + + @classmethod + def from_expr(cls, node: itir.Node) -> SymbolicDomain: + assert isinstance(node, itir.FunCall) and node.fun in [ + im.ref("unstructured_domain"), + im.ref("cartesian_domain"), + ] + + ranges: dict[common.Dimension, SymbolicRange] = {} + for named_range in node.args: + assert ( + isinstance(named_range, itir.FunCall) + and isinstance(named_range.fun, itir.SymRef) + and named_range.fun.id == "named_range" + ) + axis_literal, lower_bound, upper_bound = named_range.args + assert isinstance(axis_literal, itir.AxisLiteral) + + ranges[common.Dimension(value=axis_literal.value, kind=axis_literal.kind)] = ( + SymbolicRange(lower_bound, upper_bound) + ) + return cls(node.fun.id, ranges) # type: ignore[attr-defined] # ensure by assert above + + def as_expr(self) -> itir.FunCall: + converted_ranges: dict[common.Dimension | str, tuple[itir.Expr, itir.Expr]] = { + key: (value.start, value.stop) for key, value in self.ranges.items() + } + return im.domain(self.grid_type, converted_ranges) + + def translate( + self: SymbolicDomain, + shift: tuple[itir.OffsetLiteral, ...], + offset_provider: common.OffsetProvider, + ) -> SymbolicDomain: + dims = list(self.ranges.keys()) + new_ranges = {dim: self.ranges[dim] for dim in dims} + if len(shift) == 0: + return self + if len(shift) == 2: + off, val = shift + assert isinstance(off.value, str) and isinstance(val.value, int) + nbt_provider = offset_provider[off.value] + if isinstance(nbt_provider, common.Dimension): + current_dim = nbt_provider + # cartesian offset + new_ranges[current_dim] = SymbolicRange.translate( + self.ranges[current_dim], val.value + ) + elif isinstance(nbt_provider, common.Connectivity): + # unstructured shift + # note: ugly but cheap re-computation, but should disappear + horizontal_sizes = _max_domain_sizes_by_location_type(offset_provider) + + old_dim = nbt_provider.origin_axis + new_dim = nbt_provider.neighbor_axis + + assert new_dim not in new_ranges or old_dim == new_dim + + # TODO(tehrengruber): Do we need symbolic sizes, e.g., for ICON? 
+ new_range = SymbolicRange( + im.literal("0", itir.INTEGER_INDEX_BUILTIN), + im.literal(str(horizontal_sizes[new_dim.value]), itir.INTEGER_INDEX_BUILTIN), + ) + new_ranges = dict( + (dim, range_) if dim != old_dim else (new_dim, new_range) + for dim, range_ in new_ranges.items() + ) + else: + raise AssertionError() + return SymbolicDomain(self.grid_type, new_ranges) + elif len(shift) > 2: + return self.translate(shift[0:2], offset_provider).translate(shift[2:], offset_provider) + else: + raise AssertionError("Number of shifts must be a multiple of 2.") + + +def domain_union(*domains: SymbolicDomain) -> SymbolicDomain: + """Return the (set) union of a list of domains.""" + new_domain_ranges = {} + assert all(domain.grid_type == domains[0].grid_type for domain in domains) + assert all(domain.ranges.keys() == domains[0].ranges.keys() for domain in domains) + for dim in domains[0].ranges.keys(): + start = functools.reduce( + lambda current_expr, el_expr: im.call("minimum")(current_expr, el_expr), + [domain.ranges[dim].start for domain in domains], + ) + stop = functools.reduce( + lambda current_expr, el_expr: im.call("maximum")(current_expr, el_expr), + [domain.ranges[dim].stop for domain in domains], + ) + new_domain_ranges[dim] = SymbolicRange(start, stop) + + return SymbolicDomain(domains[0].grid_type, new_domain_ranges) diff --git a/src/gt4py/next/iterator/transforms/global_tmps.py b/src/gt4py/next/iterator/transforms/global_tmps.py index f00a5e9f70..5a6873f916 100644 --- a/src/gt4py/next/iterator/transforms/global_tmps.py +++ b/src/gt4py/next/iterator/transforms/global_tmps.py @@ -10,18 +10,22 @@ import copy import dataclasses -import functools from collections.abc import Mapping from typing import Any, Callable, Final, Iterable, Literal, Optional, Sequence import gt4py.next as gtx from gt4py.eve import NodeTranslator, PreserveLocationVisitor -from gt4py.eve.extended_typing import Tuple from gt4py.eve.traits import SymbolTableTrait from gt4py.eve.utils import UIDGenerator from gt4py.next import common from gt4py.next.iterator import ir from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im +from gt4py.next.iterator.ir_utils.domain_utils import ( + SymbolicDomain, + SymbolicRange, + _max_domain_sizes_by_location_type, + domain_union, +) from gt4py.next.iterator.pretty_printer import PrettyPrinter from gt4py.next.iterator.transforms import trace_shifts from gt4py.next.iterator.transforms.cse import extract_subexpression @@ -387,137 +391,6 @@ def prune_unused_temporaries(node: FencilWithTemporaries) -> FencilWithTemporari ) -def _max_domain_sizes_by_location_type(offset_provider: Mapping[str, Any]) -> dict[str, int]: - """Extract horizontal domain sizes from an `offset_provider`. - - Considers the shape of the neighbor table to get the size of each `origin_axis` and the maximum - value inside the neighbor table to get the size of each `neighbor_axis`. 
- """ - sizes = dict[str, int]() - for provider in offset_provider.values(): - if isinstance(provider, gtx.NeighborTableOffsetProvider): - assert provider.origin_axis.kind == gtx.DimensionKind.HORIZONTAL - assert provider.neighbor_axis.kind == gtx.DimensionKind.HORIZONTAL - sizes[provider.origin_axis.value] = max( - sizes.get(provider.origin_axis.value, 0), provider.table.shape[0] - ) - sizes[provider.neighbor_axis.value] = max( - sizes.get(provider.neighbor_axis.value, 0), - provider.table.max() + 1, # type: ignore[attr-defined] # TODO(havogt): improve typing for NDArrayObject - ) - return sizes - - -@dataclasses.dataclass -class SymbolicRange: - start: ir.Expr - stop: ir.Expr - - def translate(self, distance: int) -> "SymbolicRange": - return SymbolicRange(im.plus(self.start, distance), im.plus(self.stop, distance)) - - -@dataclasses.dataclass -class SymbolicDomain: - grid_type: Literal["unstructured_domain", "cartesian_domain"] - ranges: dict[ - common.Dimension, SymbolicRange - ] # TODO(havogt): remove `AxisLiteral` by `Dimension` everywhere - - @classmethod - def from_expr(cls, node: ir.Node) -> SymbolicDomain: - assert isinstance(node, ir.FunCall) and node.fun in [ - im.ref("unstructured_domain"), - im.ref("cartesian_domain"), - ] - - ranges: dict[common.Dimension, SymbolicRange] = {} - for named_range in node.args: - assert ( - isinstance(named_range, ir.FunCall) - and isinstance(named_range.fun, ir.SymRef) - and named_range.fun.id == "named_range" - ) - axis_literal, lower_bound, upper_bound = named_range.args - assert isinstance(axis_literal, ir.AxisLiteral) - - ranges[common.Dimension(value=axis_literal.value, kind=axis_literal.kind)] = ( - SymbolicRange(lower_bound, upper_bound) - ) - return cls(node.fun.id, ranges) # type: ignore[attr-defined] # ensure by assert above - - def as_expr(self) -> ir.FunCall: - converted_ranges: dict[common.Dimension | str, tuple[ir.Expr, ir.Expr]] = { - key: (value.start, value.stop) for key, value in self.ranges.items() - } - return im.domain(self.grid_type, converted_ranges) - - def translate( - self: SymbolicDomain, - shift: Tuple[ir.OffsetLiteral, ...], - offset_provider: common.OffsetProvider, - ) -> SymbolicDomain: - dims = list(self.ranges.keys()) - new_ranges = {dim: self.ranges[dim] for dim in dims} - if len(shift) == 0: - return self - if len(shift) == 2: - off, val = shift - assert isinstance(off.value, str) and isinstance(val.value, int) - nbt_provider = offset_provider[off.value] - if isinstance(nbt_provider, common.Dimension): - current_dim = nbt_provider - # cartesian offset - new_ranges[current_dim] = SymbolicRange.translate( - self.ranges[current_dim], val.value - ) - elif isinstance(nbt_provider, common.Connectivity): - # unstructured shift - # note: ugly but cheap re-computation, but should disappear - horizontal_sizes = _max_domain_sizes_by_location_type(offset_provider) - - old_dim = nbt_provider.origin_axis - new_dim = nbt_provider.neighbor_axis - - assert new_dim not in new_ranges or old_dim == new_dim - - # TODO(tehrengruber): Do we need symbolic sizes, e.g., for ICON? 
- new_range = SymbolicRange( - im.literal("0", ir.INTEGER_INDEX_BUILTIN), - im.literal(str(horizontal_sizes[new_dim.value]), ir.INTEGER_INDEX_BUILTIN), - ) - new_ranges = dict( - (dim, range_) if dim != old_dim else (new_dim, new_range) - for dim, range_ in new_ranges.items() - ) - else: - raise AssertionError() - return SymbolicDomain(self.grid_type, new_ranges) - elif len(shift) > 2: - return self.translate(shift[0:2], offset_provider).translate(shift[2:], offset_provider) - else: - raise AssertionError("Number of shifts must be a multiple of 2.") - - -def domain_union(*domains: SymbolicDomain) -> SymbolicDomain: - """Return the (set) union of a list of domains.""" - new_domain_ranges = {} - assert all(domain.grid_type == domains[0].grid_type for domain in domains) - assert all(domain.ranges.keys() == domains[0].ranges.keys() for domain in domains) - for dim in domains[0].ranges.keys(): - start = functools.reduce( - lambda current_expr, el_expr: im.call("minimum")(current_expr, el_expr), - [domain.ranges[dim].start for domain in domains], - ) - stop = functools.reduce( - lambda current_expr, el_expr: im.call("maximum")(current_expr, el_expr), - [domain.ranges[dim].stop for domain in domains], - ) - new_domain_ranges[dim] = SymbolicRange(start, stop) - - return SymbolicDomain(domains[0].grid_type, new_domain_ranges) - - def _group_offsets( offset_literals: Sequence[ir.OffsetLiteral], ) -> Sequence[tuple[str, int | Literal[trace_shifts.Sentinel.ALL_NEIGHBORS]]]: diff --git a/src/gt4py/next/iterator/transforms/infer_domain.py b/src/gt4py/next/iterator/transforms/infer_domain.py index 5104d09d3a..4b4b60e1c7 100644 --- a/src/gt4py/next/iterator/transforms/infer_domain.py +++ b/src/gt4py/next/iterator/transforms/infer_domain.py @@ -16,13 +16,16 @@ from gt4py.next import common from gt4py.next.common import Dimension from gt4py.next.iterator import ir as itir -from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im +from gt4py.next.iterator.ir_utils import ( + common_pattern_matcher as cpm, + domain_utils, + ir_makers as im, +) from gt4py.next.iterator.transforms import trace_shifts -from gt4py.next.iterator.transforms.global_tmps import SymbolicDomain, domain_union from gt4py.next.utils import tree_map -DOMAIN: TypeAlias = SymbolicDomain | None | tuple["DOMAIN", ...] +DOMAIN: TypeAlias = domain_utils.SymbolicDomain | None | tuple["DOMAIN", ...] ACCESSED_DOMAINS: TypeAlias = dict[str, DOMAIN] @@ -41,12 +44,14 @@ def split_dict_by_key(pred: Callable, d: dict): return a, b -# TODO(tehrengruber): Revisit whether we want to move this behaviour to `domain_union`. -def _domain_union_with_none(*domains: SymbolicDomain | None) -> SymbolicDomain | None: - filtered_domains: list[SymbolicDomain] = [d for d in domains if d is not None] +# TODO(tehrengruber): Revisit whether we want to move this behaviour to `domain_utils.domain_union`. 
+def domain_union_with_none( + *domains: domain_utils.SymbolicDomain | None, +) -> domain_utils.SymbolicDomain | None: + filtered_domains: list[domain_utils.SymbolicDomain] = [d for d in domains if d is not None] if len(filtered_domains) == 0: return None - return domain_union(*filtered_domains) + return domain_utils.domain_union(*filtered_domains) def canonicalize_domain_structure(d1: DOMAIN, d2: DOMAIN) -> tuple[DOMAIN, DOMAIN]: @@ -93,7 +98,7 @@ def _merge_domains( original_domain, domain = canonicalize_domain_structure( original_domains.get(key, None), domain ) - new_domains[key] = tree_map(_domain_union_with_none)(original_domain, domain) + new_domains[key] = tree_map(domain_union_with_none)(original_domain, domain) return new_domains @@ -101,19 +106,20 @@ def _merge_domains( def extract_accessed_domains( stencil: itir.Expr, input_ids: list[str], - target_domain: SymbolicDomain, + target_domain: domain_utils.SymbolicDomain, offset_provider: common.OffsetProvider, ) -> ACCESSED_DOMAINS: - accessed_domains: dict[str, SymbolicDomain | None] = {} + accessed_domains: dict[str, domain_utils.SymbolicDomain | None] = {} shifts_results = trace_shifts.trace_stencil(stencil, num_args=len(input_ids)) for in_field_id, shifts_list in zip(input_ids, shifts_results, strict=True): new_domains = [ - SymbolicDomain.translate(target_domain, shift, offset_provider) for shift in shifts_list + domain_utils.SymbolicDomain.translate(target_domain, shift, offset_provider) + for shift in shifts_list ] # `None` means field is never accessed - accessed_domains[in_field_id] = _domain_union_with_none( + accessed_domains[in_field_id] = domain_union_with_none( accessed_domains.get(in_field_id, None), *new_domains ) @@ -129,8 +135,8 @@ def infer_as_fieldop( assert cpm.is_call_to(applied_fieldop.fun, "as_fieldop") if target_domain is None: raise ValueError("'target_domain' cannot be 'None'.") - if not isinstance(target_domain, SymbolicDomain): - raise ValueError("'target_domain' needs to be a 'SymbolicDomain'.") + if not isinstance(target_domain, domain_utils.SymbolicDomain): + raise ValueError("'target_domain' needs to be a 'domain_utils.SymbolicDomain'.") # `as_fieldop(stencil)(inputs...)` stencil, inputs = applied_fieldop.fun.args[0], applied_fieldop.args @@ -166,7 +172,7 @@ def infer_as_fieldop( accessed_domains = _merge_domains(accessed_domains, accessed_domains_tmp) - transformed_call = im.as_fieldop(stencil, SymbolicDomain.as_expr(target_domain))( + transformed_call = im.as_fieldop(stencil, domain_utils.SymbolicDomain.as_expr(target_domain))( *transformed_inputs ) @@ -318,7 +324,7 @@ def infer_program( assert isinstance(set_at, itir.SetAt) transformed_call, _unused_domain = infer_expr( - set_at.expr, SymbolicDomain.from_expr(set_at.domain), offset_provider + set_at.expr, domain_utils.SymbolicDomain.from_expr(set_at.domain), offset_provider ) transformed_set_ats.append( itir.SetAt( From dd63122b7bff43d1746b66120e342d48fa88d002 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Tue, 1 Oct 2024 13:33:09 +0000 Subject: [PATCH 004/150] feat[next]: gtir lowering of broadcasted scalars --- src/gt4py/next/ffront/foast_to_gtir.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/gt4py/next/ffront/foast_to_gtir.py b/src/gt4py/next/ffront/foast_to_gtir.py index 948a8481d7..fd6e082477 100644 --- a/src/gt4py/next/ffront/foast_to_gtir.py +++ b/src/gt4py/next/ffront/foast_to_gtir.py @@ -373,7 +373,10 @@ def create_if(true_: itir.Expr, false_: itir.Expr) -> itir.FunCall: _visit_concat_where = _visit_where 
# TODO(havogt): upgrade concat_where def _visit_broadcast(self, node: foast.Call, **kwargs: Any) -> itir.FunCall: - return self.visit(node.args[0], **kwargs) + expr = self.visit(node.args[0], **kwargs) + if type_info.is_type_or_tuple_of_type(node.args[0].type, ts.ScalarType): + return im.as_fieldop(im.ref("deref"))(expr) + return expr def _visit_math_built_in(self, node: foast.Call, **kwargs: Any) -> itir.FunCall: return self._map(self.visit(node.func, **kwargs), *node.args) From 717244204bfa4e4ae97529f41f8f004954aa6fab Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Wed, 2 Oct 2024 00:39:26 +0200 Subject: [PATCH 005/150] Second draft of temporary pass --- src/gt4py/next/iterator/transforms/cse.py | 4 +- .../iterator/transforms/fencil_to_program.py | 12 +- .../next/iterator/transforms/global_tmps.py | 687 ++++-------------- .../next/iterator/transforms/pass_manager.py | 2 +- .../transforms_tests/test_global_tmps.py | 562 ++++---------- 5 files changed, 292 insertions(+), 975 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/cse.py b/src/gt4py/next/iterator/transforms/cse.py index 1a89adbb20..a2a169f6e4 100644 --- a/src/gt4py/next/iterator/transforms/cse.py +++ b/src/gt4py/next/iterator/transforms/cse.py @@ -51,7 +51,7 @@ def visit_FunCall(self, node: itir.FunCall) -> itir.Node: if isinstance(node, itir.FunCall) and isinstance(node.fun, itir.Lambda): eligible_params = [] for arg in node.args: - eligible_params.append(isinstance(arg, itir.SymRef) and arg.id.startswith("_cs")) + eligible_params.append(isinstance(arg, itir.SymRef)) # and arg.id.startswith("_cs")) # TODO: document? this is for lets in the global tmp pass, e.g. test_trivial_let if any(eligible_params): # note: the inline is opcount preserving anyway so avoid the additional # effort in the inliner by disabling opcount preservation. 
@@ -319,7 +319,7 @@ def extract_subexpression( subexprs = CollectSubexpressions.apply(node) # collect multiple occurrences and map them to fresh symbols - expr_map = dict[int, itir.SymRef]() + expr_map: dict[int, itir.SymRef] = {} ignored_ids = set() for expr, subexpr_entry in ( subexprs.items() if not deepest_expr_first else reversed(subexprs.items()) diff --git a/src/gt4py/next/iterator/transforms/fencil_to_program.py b/src/gt4py/next/iterator/transforms/fencil_to_program.py index db0b81a837..e07cbc282a 100644 --- a/src/gt4py/next/iterator/transforms/fencil_to_program.py +++ b/src/gt4py/next/iterator/transforms/fencil_to_program.py @@ -15,7 +15,7 @@ class FencilToProgram(eve.NodeTranslator): @classmethod def apply( - cls, node: itir.FencilDefinition | global_tmps.FencilWithTemporaries | itir.Program + cls, node: itir.FencilDefinition | itir.Program ) -> itir.Program: return cls().visit(node) @@ -32,13 +32,3 @@ def visit_FencilDefinition(self, node: itir.FencilDefinition) -> itir.Program: body=self.visit(node.closures), implicit_domain=node.implicit_domain, ) - - def visit_FencilWithTemporaries(self, node: global_tmps.FencilWithTemporaries) -> itir.Program: - return itir.Program( - id=node.fencil.id, - function_definitions=node.fencil.function_definitions, - params=node.params, - declarations=node.tmps, - body=self.visit(node.fencil.closures), - implicit_domain=node.fencil.implicit_domain, - ) diff --git a/src/gt4py/next/iterator/transforms/global_tmps.py b/src/gt4py/next/iterator/transforms/global_tmps.py index 5a6873f916..13511f2c63 100644 --- a/src/gt4py/next/iterator/transforms/global_tmps.py +++ b/src/gt4py/next/iterator/transforms/global_tmps.py @@ -10,580 +10,153 @@ import copy import dataclasses +import functools from collections.abc import Mapping from typing import Any, Callable, Final, Iterable, Literal, Optional, Sequence -import gt4py.next as gtx -from gt4py.eve import NodeTranslator, PreserveLocationVisitor -from gt4py.eve.traits import SymbolTableTrait -from gt4py.eve.utils import UIDGenerator -from gt4py.next import common -from gt4py.next.iterator import ir +from gt4py import eve +from gt4py.eve import utils as eve_utils +from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im -from gt4py.next.iterator.ir_utils.domain_utils import ( - SymbolicDomain, - SymbolicRange, - _max_domain_sizes_by_location_type, - domain_union, -) -from gt4py.next.iterator.pretty_printer import PrettyPrinter -from gt4py.next.iterator.transforms import trace_shifts -from gt4py.next.iterator.transforms.cse import extract_subexpression -from gt4py.next.iterator.transforms.eta_reduction import EtaReduction -from gt4py.next.iterator.transforms.inline_lambdas import InlineLambdas -from gt4py.next.iterator.transforms.prune_closure_inputs import PruneClosureInputs -from gt4py.next.iterator.transforms.symbol_ref_utils import collect_symbol_refs -from gt4py.next.iterator.type_system import ( - inference as itir_type_inference, - type_specifications as it_ts, -) -from gt4py.next.type_system import type_specifications as ts - - -"""Iterator IR extension for global temporaries. - -Replaces lifted function calls by temporaries using the following steps: -1. Split closures by popping up lifted function calls to the top of the expression tree, (that is, - to stencil arguments) and then extracting them as new closures. -2. Introduces a new fencil-scope variable (the temporary) for each output of newly created closures. 
- The domain size is set to a new symbol `_gtmp_auto_domain`. -3. Infer the domain sizes for the new closures by analysing the accesses/shifts within all closures - and replace all occurrences of `_gtmp_auto_domain` by concrete domain sizes. -4. Infer the data type and size of the temporary buffers. -""" - - -AUTO_DOMAIN: Final = ir.FunCall(fun=ir.SymRef(id="_gtmp_auto_domain"), args=[]) - - -# Iterator IR extension nodes - - -class FencilWithTemporaries( - ir.Node, SymbolTableTrait -): # TODO(havogt): remove and use new `itir.Program` instead. - """Iterator IR extension: declaration of a fencil with temporary buffers.""" - - fencil: ir.FencilDefinition - params: list[ir.Sym] - tmps: list[ir.Temporary] - - -# Extensions for `PrettyPrinter` for easier debugging - - -def pformat_FencilWithTemporaries( - printer: PrettyPrinter, node: FencilWithTemporaries, *, prec: int -) -> list[str]: - assert prec == 0 - params = printer.visit(node.params, prec=0) - fencil = printer.visit(node.fencil, prec=0) - tmps = printer.visit(node.tmps, prec=0) - args = params + [[tmp.id] for tmp in node.tmps] - - hparams = printer._hmerge([node.fencil.id + "("], *printer._hinterleave(params, ", "), [") {"]) - vparams = printer._vmerge( - [node.fencil.id + "("], *printer._hinterleave(params, ",", indent=True), [") {"] +from gt4py.next.type_system import type_info +from gt4py.next.iterator.transforms import inline_lambdas + +# TODO: remove +SimpleTemporaryExtractionHeuristics = None +CreateGlobalTmps = None + +from gt4py.next.iterator.transforms import cse + +class IncompleteTemporary: + expr: itir.Expr + target: itir.Expr + +def get_expr_domain(expr: itir.Expr, ctx=None): + ctx = ctx or {} + + if cpm.is_applied_as_fieldop(expr): + _, domain = expr.fun.args + return domain + elif cpm.is_call_to(expr, "tuple_get"): + idx_expr, tuple_expr = expr.args + assert isinstance(idx_expr, itir.Literal) and type_info.is_integer(idx_expr.type) + idx = int(idx_expr.value) + tuple_expr_domain = get_expr_domain(tuple_expr, ctx) + assert isinstance(tuple_expr_domain, tuple) and idx < len(tuple_expr_domain) + return tuple_expr_domain[idx] + elif cpm.is_call_to(expr, "make_tuple"): + return tuple(get_expr_domain(el, ctx) for el in expr.args) + elif cpm.is_call_to(expr, "if_"): + cond, true_val, false_val = expr.args + true_domain, false_domain = get_expr_domain(true_val, ctx), get_expr_domain(false_val, ctx) + assert true_domain == false_domain + return true_domain + elif cpm.is_let(expr): + new_ctx = {} + for var_name, var_value in zip(expr.fun.params, expr.args, strict=True): + new_ctx[var_name.id] = get_expr_domain(var_value, ctx) + return get_expr_domain(expr.fun.expr, ctx={**ctx, **new_ctx}) + raise ValueError() + + +def transform_if(stmt: itir.SetAt, declarations: list[itir.Temporary], uids: eve_utils.UIDGenerator): + if not isinstance(stmt, itir.SetAt): + return None + + if cpm.is_call_to(stmt.expr, "if_"): + cond, true_val, false_val = stmt.expr.args + return [itir.IfStmt( + cond=cond, + # recursively transform + true_branch=transform(itir.SetAt(target=stmt.target, expr=true_val, domain=stmt.domain), declarations, uids), + false_branch=transform(itir.SetAt(target=stmt.target, expr=false_val, domain=stmt.domain), declarations, uids), + )] + return None + +def transform_by_pattern(stmt: itir.SetAt, predicate, declarations: list[itir.Temporary], uids: eve_utils.UIDGenerator): + if not isinstance(stmt, itir.SetAt): + return None + + new_expr, extracted_fields, _ = cse.extract_subexpression( + stmt.expr, + predicate=predicate, + 
uid_generator=uids, + # allows better fusing later on + #deepest_expr_first=True # TODO: better, but not supported right now ) - params = printer._optimum(hparams, vparams) - - hargs = printer._hmerge(*printer._hinterleave(args, ", ")) - vargs = printer._vmerge(*printer._hinterleave(args, ",")) - args = printer._optimum(hargs, vargs) - - fencil = printer._hmerge(fencil, [";"]) - - hcall = printer._hmerge([node.fencil.id + "("], args, [");"]) - vcall = printer._vmerge(printer._hmerge([node.fencil.id + "("]), printer._indent(args), [");"]) - call = printer._optimum(hcall, vcall) - - body = printer._vmerge(*tmps, fencil, call) - return printer._vmerge(params, printer._indent(body), ["}"]) - - -PrettyPrinter.visit_FencilWithTemporaries = pformat_FencilWithTemporaries # type: ignore - - -# Main implementation -def canonicalize_applied_lift(closure_params: list[str], node: ir.FunCall) -> ir.FunCall: - """ - Canonicalize applied lift expressions. - - Transform lift such that the arguments to the applied lift are only symbols. - - >>> bool_type = ts.ScalarType(kind=ts.ScalarKind.BOOL) - >>> it_type = it_ts.IteratorType(position_dims=[], defined_dims=[], element_type=bool_type) - >>> expr = im.lift(im.lambda_("a")(im.deref("a")))(im.lift("deref")(im.ref("inp", it_type))) - >>> print(expr) - (↑(λ(a) → ·a))((↑deref)(inp)) - >>> print(canonicalize_applied_lift(["inp"], expr)) - (↑(λ(inp) → (λ(a) → ·a)((↑deref)(inp))))(inp) - """ - assert cpm.is_applied_lift(node) - stencil = node.fun.args[0] # type: ignore[attr-defined] # ensured by is_applied lift - it_args = node.args - if any(not isinstance(it_arg, ir.SymRef) for it_arg in it_args): - closure_param_refs = collect_symbol_refs(node, as_ref=True) - assert not ({str(ref.id) for ref in closure_param_refs} - set(closure_params)) - new_node = im.lift( - im.lambda_(*[im.sym(param.id) for param in closure_param_refs])( - im.call(stencil)(*it_args) - ) - )(*closure_param_refs) - # ensure all types are inferred - return itir_type_inference.infer( - new_node, inplace=True, allow_undeclared_symbols=True, offset_provider={} - ) - return node - - -@dataclasses.dataclass(frozen=True) -class TemporaryExtractionPredicate: - """ - Construct a callable that determines if a lift expr can and should be extracted to a temporary. - - The class optionally takes a heuristic that can restrict the extraction. - """ - - heuristics: Optional[Callable[[ir.Expr], bool]] = None - - def __call__(self, expr: ir.Expr, num_occurences: int) -> bool: - """Determine if `expr` is an applied lift that should be extracted as a temporary.""" - if not cpm.is_applied_lift(expr): - return False - # do not extract when the result is a list (i.e. a lift expression used in a `reduce` call) - # as we can not create temporaries for these stencils - assert isinstance(expr.type, it_ts.IteratorType) - if isinstance(expr.type.element_type, it_ts.ListType): - return False - if self.heuristics and not self.heuristics(expr): - return False - stencil = expr.fun.args[0] # type: ignore[attr-defined] # ensured by `is_applied_lift` - # do not extract when the stencil is capturing - used_symbols = collect_symbol_refs(stencil) - if used_symbols: - return False - return True - - -@dataclasses.dataclass(frozen=True) -class SimpleTemporaryExtractionHeuristics: - """ - Heuristic that extracts only if a lift expr is derefed in more than one position. - - Note that such expression result in redundant computations if inlined instead of being - placed into a temporary. 
- """ - - closure: ir.StencilClosure - - def __post_init__(self) -> None: - trace_shifts.trace_stencil( - self.closure.stencil, num_args=len(self.closure.inputs), save_to_annex=True - ) - - def __call__(self, expr: ir.Expr) -> bool: - shifts = expr.annex.recorded_shifts - if len(shifts) > 1: - return True - return False - - -def _closure_parameter_argument_mapping(closure: ir.StencilClosure) -> dict[str, ir.Expr]: - """ - Create a mapping from the closures parameters to the closure arguments. - - E.g. for the closure `out ← (λ(param) → ...)(arg) @ u⟨ ... ⟩;` we get a mapping from `param` - to `arg`. In case the stencil is a scan, a mapping from closure inputs to scan pass (i.e. first - arg is ignored) is returned. - """ - is_scan = cpm.is_call_to(closure.stencil, "scan") - - if is_scan: - stencil = closure.stencil.args[0] # type: ignore[attr-defined] # ensured by is_scan - return { - param.id: arg for param, arg in zip(stencil.params[1:], closure.inputs, strict=True) - } - else: - assert isinstance(closure.stencil, ir.Lambda) - return { - param.id: arg for param, arg in zip(closure.stencil.params, closure.inputs, strict=True) - } - - -def _ensure_expr_does_not_capture(expr: ir.Expr, whitelist: list[ir.Sym]) -> None: - used_symbol_refs = collect_symbol_refs(expr) - assert not (set(used_symbol_refs) - {param.id for param in whitelist}) - - -def split_closures( - node: ir.FencilDefinition, - offset_provider: common.OffsetProvider, - *, - extraction_heuristics: Optional[ - Callable[[ir.StencilClosure], Callable[[ir.Expr], bool]] - ] = None, -) -> FencilWithTemporaries: - """Split closures on lifted function calls and introduce new temporary buffers for return values. - - Newly introduced temporaries will have the symbolic size of `AUTO_DOMAIN`. A symbol with the - same name is also added as a fencil argument (to be replaced at a later stage). - For each closure, follows these steps: - 1. Pops up lifted function calls to the top of the expression tree. - 2. Introduce new temporary for the output. - 3. Extract lifted function class as new closures with the previously created temporary as output. - The closures are processed in reverse order to properly respect the dependencies. 
- """ - if not extraction_heuristics: - # extract all (eligible) lifts - def always_extract_heuristics(_: ir.StencilClosure) -> Callable[[ir.Expr], bool]: - return lambda _: True + if extracted_fields: + new_stmts = [] + for tmp_sym, tmp_expr in extracted_fields.items(): + # TODO: expr domain can not be a tuple here + domain = get_expr_domain(tmp_expr) - extraction_heuristics = always_extract_heuristics - - uid_gen_tmps = UIDGenerator(prefix="_tmp") - - node = itir_type_inference.infer(node, offset_provider=offset_provider) - - tmps: list[tuple[str, ts.DataType]] = [] - - closures: list[ir.StencilClosure] = [] - for closure in reversed(node.closures): - closure_stack: list[ir.StencilClosure] = [closure] - while closure_stack: - current_closure: ir.StencilClosure = closure_stack.pop() - - if ( - isinstance(current_closure.stencil, ir.SymRef) - and current_closure.stencil.id == "deref" - ): - closures.append(current_closure) - continue - - is_scan: bool = cpm.is_call_to(current_closure.stencil, "scan") - current_closure_stencil = ( - current_closure.stencil if not is_scan else current_closure.stencil.args[0] # type: ignore[attr-defined] # ensured by is_scan + scalar_type = type_info.apply_to_primitive_constituents( + type_info.extract_dtype, tmp_expr.type ) + declarations.append(itir.Temporary(id=tmp_sym.id, domain=domain, dtype=scalar_type)) - extraction_predicate = TemporaryExtractionPredicate( - extraction_heuristics(current_closure) - ) + # TODO: transform not needed if deepest_expr_first=True + new_stmts.extend(transform(itir.SetAt(target=im.ref(tmp_sym.id), domain=domain, expr=tmp_expr), declarations, uids)) - stencil_body, extracted_lifts, _ = extract_subexpression( - current_closure_stencil.expr, - extraction_predicate, - uid_gen_tmps, - once_only=True, - deepest_expr_first=True, + return [ + *new_stmts, + itir.SetAt( + target=stmt.target, + domain=stmt.domain, + expr=new_expr ) - - if extracted_lifts: - for tmp_sym, lift_expr in extracted_lifts.items(): - # make sure the applied lift is not capturing anything except of closure params - _ensure_expr_does_not_capture(lift_expr, current_closure_stencil.params) - - assert isinstance(lift_expr, ir.FunCall) and isinstance( - lift_expr.fun, ir.FunCall - ) - - # make sure the arguments to the applied lift are only symbols - if not all(isinstance(arg, ir.SymRef) for arg in lift_expr.args): - lift_expr = canonicalize_applied_lift( - [str(param.id) for param in current_closure_stencil.params], lift_expr - ) - assert all(isinstance(arg, ir.SymRef) for arg in lift_expr.args) - - # create a mapping from the closures parameters to the closure arguments - closure_param_arg_mapping = _closure_parameter_argument_mapping(current_closure) - - # usually an ir.Lambda or scan - stencil: ir.Node = lift_expr.fun.args[0] # type: ignore[attr-defined] # ensured by canonicalize_applied_lift - - # allocate a new temporary - assert isinstance(stencil.type, ts.FunctionType) - assert isinstance(stencil.type.returns, ts.DataType) - tmps.append((tmp_sym.id, stencil.type.returns)) - - # create a new closure that executes the stencil of the applied lift and - # writes the result to the newly created temporary - closure_stack.append( - ir.StencilClosure( - domain=AUTO_DOMAIN, - stencil=stencil, - output=im.ref(tmp_sym.id), - inputs=[ - closure_param_arg_mapping[param.id] # type: ignore[attr-defined] - for param in lift_expr.args - ], - location=current_closure.location, - ) - ) - - new_stencil: ir.Lambda | ir.FunCall - # create a new stencil where all applied lifts that 
have been extracted are - # replaced by references to the respective temporary - new_stencil = ir.Lambda( - params=current_closure_stencil.params + list(extracted_lifts.keys()), - expr=stencil_body, - ) - # if we are extracting from an applied scan we have to wrap the scan pass again, - # i.e. transform `λ(state, ...) → ...` into `scan(λ(state, ...) → ..., ...)` - if is_scan: - new_stencil = im.call("scan")(new_stencil, current_closure.stencil.args[1:]) # type: ignore[attr-defined] # ensure by is_scan - # inline such that let statements which are just rebinding temporaries disappear - new_stencil = InlineLambdas.apply( - new_stencil, opcount_preserving=True, force_inline_lift_args=False - ) - # we're done with the current closure, add it back to the stack for further - # extraction. - closure_stack.append( - ir.StencilClosure( - domain=current_closure.domain, - stencil=new_stencil, - output=current_closure.output, - inputs=current_closure.inputs - + [ir.SymRef(id=sym.id) for sym in extracted_lifts.keys()], - location=current_closure.location, - ) - ) - else: - closures.append(current_closure) - - return FencilWithTemporaries( - fencil=ir.FencilDefinition( - id=node.id, - function_definitions=node.function_definitions, - params=node.params + [im.sym(name) for name, _ in tmps] + [im.sym(AUTO_DOMAIN.fun.id)], # type: ignore[attr-defined] # value is a global constant - closures=list(reversed(closures)), - location=node.location, - implicit_domain=node.implicit_domain, - ), - params=node.params, - tmps=[ir.Temporary(id=name, dtype=type_) for name, type_ in tmps], - ) - - -def prune_unused_temporaries(node: FencilWithTemporaries) -> FencilWithTemporaries: - """Remove temporaries that are never read.""" - unused_tmps = {tmp.id for tmp in node.tmps} - for closure in node.fencil.closures: - unused_tmps -= {inp.id for inp in closure.inputs} - - if not unused_tmps: - return node - - closures = [ - closure - for closure in node.fencil.closures - if not (isinstance(closure.output, ir.SymRef) and closure.output.id in unused_tmps) + ] + return None + +def transform(stmt: itir.SetAt, declarations: list[itir.Temporary], uids: eve_utils.UIDGenerator): + # TODO: what happens for a trivial let, e.g `let a=as_fieldop() in a end`? 
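Editor's note: the body of `transform` below is a small worklist loop — each statement is offered to a list of rewrite rules, the first rule that fires replaces it with new statements that are pushed back onto the queue, and a statement no rule matches is emitted unchanged. A minimal standalone sketch of that control flow, using plain strings as stand-in statements and a single toy rule (illustrative only, not the gt4py API):

    from typing import Callable, Optional

    # A rule returns None when it does not apply, otherwise a list of replacement
    # "statements" that are re-queued so further rules can fire on them.
    Rule = Callable[[str], Optional[list[str]]]

    def split_if(stmt: str) -> Optional[list[str]]:
        # "if:<a>|<b>" -> two independent statements, loosely mirroring how a
        # functional `if_` at the root of a SetAt becomes an IfStmt with branches.
        if stmt.startswith("if:"):
            left, right = stmt[3:].split("|", 1)
            return [left, right]
        return None

    def run_worklist(stmts: list[str], rules: list[Rule]) -> list[str]:
        unprocessed = list(stmts)
        done: list[str] = []
        while unprocessed:
            stmt = unprocessed.pop(0)
            for rule in rules:
                replacement = rule(stmt)
                if replacement:
                    unprocessed = [*replacement, *unprocessed]
                    break
            else:  # no rule matched: the statement is final
                done.append(stmt)
        return done

    assert run_worklist(["if:set_a|if:set_b|set_c"], [split_if]) == ["set_a", "set_b", "set_c"]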
+ unprocessed_stmts = [stmt] + stmts = [] + + transforms = [ + # transform functional if_ into if-stmt + transform_if, + # extract applied `as_fieldop` to top-level + functools.partial(transform_by_pattern, predicate=lambda expr, _: cpm.is_applied_as_fieldop(expr)), + # extract functional if_ to the top-level + functools.partial(transform_by_pattern, predicate=lambda expr, _: cpm.is_call_to(expr, "if_")), ] - return FencilWithTemporaries( - fencil=ir.FencilDefinition( - id=node.fencil.id, - function_definitions=node.fencil.function_definitions, - params=[p for p in node.fencil.params if p.id not in unused_tmps], - closures=closures, - location=node.fencil.location, - ), - params=node.params, - tmps=[tmp for tmp in node.tmps if tmp.id not in unused_tmps], - ) - -def _group_offsets( - offset_literals: Sequence[ir.OffsetLiteral], -) -> Sequence[tuple[str, int | Literal[trace_shifts.Sentinel.ALL_NEIGHBORS]]]: - tags = [tag.value for tag in offset_literals[::2]] - offsets = [ - offset.value if isinstance(offset, ir.OffsetLiteral) else offset - for offset in offset_literals[1::2] - ] - assert all(isinstance(tag, str) for tag in tags) - assert all( - isinstance(offset, int) or offset == trace_shifts.Sentinel.ALL_NEIGHBORS - for offset in offsets - ) - return zip(tags, offsets, strict=True) # type: ignore[return-value] # mypy doesn't infer literal correctly + while unprocessed_stmts: + stmt = unprocessed_stmts.pop(0) + did_transform = False + for transform in transforms: + transformed_stmts = transform(stmt=stmt, declarations=declarations, uids=uids) + if transformed_stmts: + unprocessed_stmts = [*transformed_stmts, *unprocessed_stmts] + did_transform = True + break -def update_domains( - node: FencilWithTemporaries, - offset_provider: Mapping[str, Any], - symbolic_sizes: Optional[dict[str, str]], -) -> FencilWithTemporaries: - horizontal_sizes = _max_domain_sizes_by_location_type(offset_provider) - closures: list[ir.StencilClosure] = [] - domains = dict[str, ir.FunCall]() - for closure in reversed(node.fencil.closures): - if closure.domain == AUTO_DOMAIN: - # every closure with auto domain should have a single out field - assert isinstance(closure.output, ir.SymRef) + # no transformation occurred + if not did_transform: + stmts.append(stmt) - if closure.output.id not in domains: - raise NotImplementedError(f"Closure output '{closure.output.id}' is never used.") + return stmts - domain = domains[closure.output.id] +def create_global_tmps(program: itir.Program): + uids = eve_utils.UIDGenerator(prefix="__tmp") + declarations = program.declarations + new_body = [] - closure = ir.StencilClosure( - domain=copy.deepcopy(domain), - stencil=closure.stencil, - output=closure.output, - inputs=closure.inputs, - location=closure.location, + for stmt in program.body: + if isinstance(stmt, (itir.SetAt, itir.IfStmt)): + new_body.extend( + transform(stmt, uids=uids, declarations=declarations) ) else: - domain = closure.domain - - closures.append(closure) - - local_shifts = trace_shifts.trace_stencil(closure.stencil, num_args=len(closure.inputs)) - for param_sym, shift_chains in zip(closure.inputs, local_shifts): - param = param_sym.id - assert isinstance(param, str) - consumed_domains: list[SymbolicDomain] = ( - [SymbolicDomain.from_expr(domains[param])] if param in domains else [] - ) - for shift_chain in shift_chains: - consumed_domain = SymbolicDomain.from_expr(domain) - for offset_name, offset in _group_offsets(shift_chain): - if isinstance(offset_provider[offset_name], gtx.Dimension): - # cartesian shift - 
dim = offset_provider[offset_name] - assert offset is not trace_shifts.Sentinel.ALL_NEIGHBORS - consumed_domain.ranges[dim] = consumed_domain.ranges[dim].translate(offset) - elif isinstance(offset_provider[offset_name], common.Connectivity): - # unstructured shift - nbt_provider = offset_provider[offset_name] - old_axis = nbt_provider.origin_axis - new_axis = nbt_provider.neighbor_axis - - assert new_axis not in consumed_domain.ranges or old_axis == new_axis - - if symbolic_sizes is None: - new_range = SymbolicRange( - im.literal("0", ir.INTEGER_INDEX_BUILTIN), - im.literal( - str(horizontal_sizes[new_axis.value]), ir.INTEGER_INDEX_BUILTIN - ), - ) - else: - new_range = SymbolicRange( - im.literal("0", ir.INTEGER_INDEX_BUILTIN), - im.ref(symbolic_sizes[new_axis.value]), - ) - consumed_domain.ranges = dict( - (axis, range_) if axis != old_axis else (new_axis, new_range) - for axis, range_ in consumed_domain.ranges.items() - ) - # TODO(tehrengruber): Revisit. Somehow the order matters so preserve it. - consumed_domain.ranges = dict( - (axis, range_) if axis != old_axis else (new_axis, new_range) - for axis, range_ in consumed_domain.ranges.items() - ) - else: - raise NotImplementedError() - consumed_domains.append(consumed_domain) - - # compute the bounds of all consumed domains - if consumed_domains: - if all( - consumed_domain.ranges.keys() == consumed_domains[0].ranges.keys() - for consumed_domain in consumed_domains - ): # scalar otherwise - domains[param] = domain_union(*consumed_domains).as_expr() - - return FencilWithTemporaries( - fencil=ir.FencilDefinition( - id=node.fencil.id, - function_definitions=node.fencil.function_definitions, - params=node.fencil.params[:-1], # remove `_gtmp_auto_domain` param again - closures=list(reversed(closures)), - location=node.fencil.location, - implicit_domain=node.fencil.implicit_domain, - ), - params=node.params, - tmps=node.tmps, - ) - - -def _tuple_constituents(node: ir.Expr) -> Iterable[ir.Expr]: - if cpm.is_call_to(node, "make_tuple"): - for arg in node.args: - yield from _tuple_constituents(arg) - else: - yield node - - -def collect_tmps_info( - node: FencilWithTemporaries, *, offset_provider: common.OffsetProvider -) -> FencilWithTemporaries: - """Perform type inference for finding the types of temporaries and sets the temporary size.""" - tmps = {tmp.id for tmp in node.tmps} - domains: dict[str, ir.Expr] = {} - for closure in node.fencil.closures: - for output_field in _tuple_constituents(closure.output): - assert isinstance(output_field, ir.SymRef) - if output_field.id not in tmps: - continue - - assert output_field.id not in domains or domains[output_field.id] == closure.domain - domains[output_field.id] = closure.domain - - new_node = FencilWithTemporaries( - fencil=node.fencil, - params=node.params, - tmps=[ - ir.Temporary(id=tmp.id, domain=domains[tmp.id], dtype=tmp.dtype) for tmp in node.tmps - ], - ) - # TODO(tehrengruber): type inference is only really needed to infer the types of the temporaries - # and write them to the params of the inner fencil. This should be cleaned up after we - # refactored the IR. - return itir_type_inference.infer(new_node, offset_provider=offset_provider) - - -def validate_no_dynamic_offsets(node: ir.Node) -> None: - """Vaidate we have no dynamic offsets, e.g. 
`shift(Ioff, deref(...))(...)`""" - for call_node in node.walk_values().if_isinstance(ir.FunCall): - assert isinstance(call_node, ir.FunCall) - if cpm.is_call_to(call_node, "shift"): - if any(not isinstance(arg, ir.OffsetLiteral) for arg in call_node.args): - raise NotImplementedError("Dynamic offsets not supported in temporary pass.") - - -# TODO(tehrengruber): Add support for dynamic shifts (e.g. the distance is a symbol). This can be -# tricky: For every lift statement that is dynamically shifted we can not compute bounds anymore -# and hence also not extract as a temporary. -class CreateGlobalTmps(PreserveLocationVisitor, NodeTranslator): - """Main entry point for introducing global temporaries. - - Transforms an existing iterator IR fencil into a fencil with global temporaries. - """ - - def visit_FencilDefinition( - self, - node: ir.FencilDefinition, - *, - offset_provider: Mapping[str, Any], - extraction_heuristics: Optional[ - Callable[[ir.StencilClosure], Callable[[ir.Expr], bool]] - ] = None, - symbolic_sizes: Optional[dict[str, str]], - ) -> FencilWithTemporaries: - # Vaidate we have no dynamic offsets, e.g. `shift(Ioff, deref(...))(...)` - validate_no_dynamic_offsets(node) - # Split closures on lifted function calls and introduce temporaries - res = split_closures( - node, offset_provider=offset_provider, extraction_heuristics=extraction_heuristics - ) - # Prune unreferences closure inputs introduced in the previous step - res = PruneClosureInputs().visit(res) - # Prune unused temporaries possibly introduced in the previous step - res = prune_unused_temporaries(res) - # Perform an eta-reduction which should put all calls at the highest level of a closure - res = EtaReduction().visit(res) - # Perform a naive extent analysis to compute domain sizes of closures and temporaries - res = update_domains(res, offset_provider, symbolic_sizes) - # Use type inference to determine the data type of the temporaries - return collect_tmps_info(res, offset_provider=offset_provider) + raise NotImplementedError() + + return itir.Program( + id=program.id, + function_definitions=program.function_definitions, + params=program.params, + declarations=declarations, + body=new_body + ) \ No newline at end of file diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index 8dd76b289b..6cd4784062 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -18,7 +18,7 @@ from gt4py.next.iterator.transforms.cse import CommonSubexpressionElimination from gt4py.next.iterator.transforms.eta_reduction import EtaReduction from gt4py.next.iterator.transforms.fuse_maps import FuseMaps -from gt4py.next.iterator.transforms.global_tmps import CreateGlobalTmps, FencilWithTemporaries +from gt4py.next.iterator.transforms.global_tmps import CreateGlobalTmps from gt4py.next.iterator.transforms.inline_center_deref_lift_vars import InlineCenterDerefLiftVars from gt4py.next.iterator.transforms.inline_into_scan import InlineIntoScan from gt4py.next.iterator.transforms.inline_lambdas import InlineLambdas diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py index ffb5447684..c8a91e037e 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py @@ -10,460 +10,214 @@ 
# itir. Currently we only test temporaries from frontend code which makes testing changes # to anything related to temporaries tedious. import copy +from typing import Optional import gt4py.next as gtx from gt4py.eve.utils import UIDs from gt4py.next import common -from gt4py.next.iterator import ir +from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import ir_makers as im -from gt4py.next.iterator.transforms.global_tmps import ( - AUTO_DOMAIN, - FencilWithTemporaries, - SimpleTemporaryExtractionHeuristics, - collect_tmps_info, - split_closures, - update_domains, -) +from gt4py.next.iterator.transforms import global_tmps, infer_domain +from gt4py.next.iterator.type_system import inference as type_inference from gt4py.next.type_system import type_specifications as ts IDim = common.Dimension(value="IDim") JDim = common.Dimension(value="JDim") KDim = common.Dimension(value="KDim", kind=common.DimensionKind.VERTICAL) -index_type = ts.ScalarType(kind=getattr(ts.ScalarKind, ir.INTEGER_INDEX_BUILTIN.upper())) +index_type = ts.ScalarType(kind=getattr(ts.ScalarKind, itir.INTEGER_INDEX_BUILTIN.upper())) float_type = ts.ScalarType(kind=ts.ScalarKind.FLOAT64) i_field_type = ts.FieldType(dims=[IDim], dtype=float_type) index_field_type_factory = lambda dim: ts.FieldType(dims=[dim], dtype=index_type) - -def test_split_closures(): - UIDs.reset_sequence() - testee = ir.FencilDefinition( - id="f", +def program_factory( + params: list[itir.Sym], + body: list[itir.SetAt], + declarations: Optional[list[itir.Temporary]] = None +) -> itir.Program: + return itir.Program( + id="testee", function_definitions=[], - params=[ - im.sym("d", i_field_type), - im.sym("inp", i_field_type), - im.sym("out", i_field_type), - ], - closures=[ - ir.StencilClosure( - domain=im.call("cartesian_domain")(), - stencil=im.lambda_("baz_inp")( - im.deref( - im.lift( - im.lambda_("bar_inp")( - im.deref( - im.lift(im.lambda_("foo_inp")(im.deref("foo_inp")))("bar_inp") - ) - ) - )("baz_inp") - ) - ), - output=im.ref("out"), - inputs=[im.ref("inp")], - ) - ], + params=params, + declarations=declarations or [], + body=body ) - expected = ir.FencilDefinition( - id="f", - function_definitions=[], - params=[ - im.sym("d", i_field_type), - im.sym("inp", i_field_type), - im.sym("out", i_field_type), - im.sym("_tmp_1", i_field_type), - im.sym("_tmp_2", i_field_type), - im.sym("_gtmp_auto_domain", ts.DeferredType(constraint=None)), - ], - closures=[ - ir.StencilClosure( - domain=AUTO_DOMAIN, - stencil=im.lambda_("foo_inp")(im.deref("foo_inp")), - output=im.ref("_tmp_2"), - inputs=[im.ref("inp")], - ), - ir.StencilClosure( - domain=AUTO_DOMAIN, - stencil=im.lambda_("bar_inp", "_tmp_2")(im.deref("_tmp_2")), - output=im.ref("_tmp_1"), - inputs=[im.ref("inp"), im.ref("_tmp_2")], - ), - ir.StencilClosure( - domain=im.call("cartesian_domain")(), - stencil=im.lambda_("baz_inp", "_tmp_1")(im.deref("_tmp_1")), - output=im.ref("out"), - inputs=[im.ref("inp"), im.ref("_tmp_1")], - ), - ], - ) - actual = split_closures(testee, offset_provider={}) - assert actual.tmps == [ - ir.Temporary(id="_tmp_1", dtype=float_type), - ir.Temporary(id="_tmp_2", dtype=float_type), - ] - assert actual.fencil == expected - - -def test_split_closures_simple_heuristics(): - UIDs.reset_sequence() - testee = ir.FencilDefinition( - id="f", - function_definitions=[], - params=[ - im.sym("d", i_field_type), - im.sym("inp", i_field_type), - im.sym("out", i_field_type), - ], - closures=[ - ir.StencilClosure( - domain=im.call("cartesian_domain")(), - 
stencil=im.lambda_("foo")( - im.let("lifted_it", im.lift(im.lambda_("bar")(im.deref("bar")))("foo"))( - im.plus(im.deref("lifted_it"), im.deref(im.shift("I", 1)("lifted_it"))) - ) +def test_trivial(): + domain = im.domain("cartesian_domain", {IDim: (0, 1)}) + offset_provider = {} + testee = program_factory( + params=[im.sym("inp", i_field_type), im.sym("out", i_field_type)], + body=[ + itir.SetAt( + target=im.ref("out"), + expr=im.as_fieldop("deref", domain)( + im.as_fieldop("deref", domain)("inp") ), - output=im.ref("out"), - inputs=[im.ref("inp")], + domain=domain ) - ], + ] ) + testee = type_inference.infer(testee, offset_provider=offset_provider) - expected = ir.FencilDefinition( - id="f", - function_definitions=[], - params=[ - im.sym("d", i_field_type), - im.sym("inp", i_field_type), - im.sym("out", i_field_type), - im.sym("_tmp_1", i_field_type), - im.sym("_gtmp_auto_domain", ts.DeferredType(constraint=None)), - ], - closures=[ - ir.StencilClosure( - domain=AUTO_DOMAIN, - stencil=im.lambda_("bar")(im.deref("bar")), - output=im.ref("_tmp_1"), - inputs=[im.ref("inp")], - ), - ir.StencilClosure( - domain=im.call("cartesian_domain")(), - stencil=im.lambda_("foo", "_tmp_1")( - im.plus(im.deref("_tmp_1"), im.deref(im.shift("I", 1)("_tmp_1"))) - ), - output=im.ref("out"), - inputs=[im.ref("inp"), im.ref("_tmp_1")], + expected = program_factory( + params=[im.sym("inp", i_field_type), im.sym("out", i_field_type)], + declarations=[itir.Temporary( + id="__tmp_1", + domain=domain, + dtype=float_type + )], + body=[ + itir.SetAt( + target=im.ref("__tmp_1"), + expr=im.as_fieldop("deref", domain)("inp"), + domain=domain ), - ], - ) - actual = split_closures( - testee, - extraction_heuristics=SimpleTemporaryExtractionHeuristics, - offset_provider={"I": IDim}, + itir.SetAt( + target=im.ref("out"), + expr=im.as_fieldop("deref", domain)("__tmp_1"), + domain=domain + ) + ] ) - assert actual.tmps == [ir.Temporary(id="_tmp_1", dtype=float_type)] - assert actual.fencil == expected + actual = global_tmps.create_global_tmps(testee) + assert actual == expected -def test_split_closures_lifted_scan(): - UIDs.reset_sequence() - testee = ir.FencilDefinition( - id="f", - function_definitions=[], +def test_trivial_let(): + domain = im.domain("cartesian_domain", {IDim: (0, 1)}) + offset_provider = {} + testee = program_factory( params=[im.sym("inp", i_field_type), im.sym("out", i_field_type)], - closures=[ - ir.StencilClosure( - domain=im.call("cartesian_domain")(), - stencil=im.lambda_("a")( - im.call( - im.call("scan")( - im.lambda_("carry", "b")(im.plus("carry", im.deref("b"))), - True, - im.literal_from_value(0.0), - ) - )( - im.lift( - im.call("scan")( - im.lambda_("carry", "c")(im.plus("carry", im.deref("c"))), - False, - im.literal_from_value(0.0), - ) - )("a") - ) - ), - output=im.ref("out"), - inputs=[im.ref("inp")], + body=[ + itir.SetAt( + target=im.ref("out"), + expr=im.let("tmp", im.as_fieldop("deref", domain)("inp"))( + im.as_fieldop("deref", domain)("tmp")), + domain=domain ) - ], + ] ) + testee = type_inference.infer(testee, offset_provider=offset_provider) - expected = ir.FencilDefinition( - id="f", - function_definitions=[], - params=[ - im.sym("inp", i_field_type), - im.sym("out", i_field_type), - im.sym("_tmp_1", i_field_type), - im.sym("_gtmp_auto_domain", ts.DeferredType(constraint=None)), - ], - closures=[ - ir.StencilClosure( - domain=AUTO_DOMAIN, - stencil=im.call("scan")( - im.lambda_("carry", "c")(im.plus("carry", im.deref("c"))), - False, - im.literal_from_value(0.0), - ), - 
output=im.ref("_tmp_1"), - inputs=[im.ref("inp")], - ), - ir.StencilClosure( - domain=ir.FunCall(fun=ir.SymRef(id="cartesian_domain"), args=[]), - stencil=im.lambda_("a", "_tmp_1")( - im.call( - im.call("scan")( - im.lambda_("carry", "b")(im.plus("carry", im.deref("b"))), - True, - im.literal_from_value(0.0), - ) - )("_tmp_1") - ), - output=im.ref("out"), - inputs=[im.ref("inp"), im.ref("_tmp_1")], + expected = program_factory( + params=[im.sym("inp", i_field_type), im.sym("out", i_field_type)], + declarations=[itir.Temporary( + id="__tmp_1", + domain=domain, + dtype=float_type + )], + body=[ + itir.SetAt( + target=im.ref("__tmp_1"), + expr=im.as_fieldop("deref", domain)("inp"), + domain=domain ), - ], + itir.SetAt( + target=im.ref("out"), + expr=im.as_fieldop("deref", domain)("__tmp_1"), + domain=domain + ) + ] ) - actual = split_closures(testee, offset_provider={}) - assert actual.tmps == [ir.Temporary(id="_tmp_1", dtype=float_type)] - assert actual.fencil == expected + actual = global_tmps.create_global_tmps(testee) + assert actual == expected -def test_update_cartesian_domains(): - testee = FencilWithTemporaries( - fencil=ir.FencilDefinition( - id="f", - function_definitions=[], - params=[ - im.sym("i", index_type), - im.sym("j", index_type), - im.sym("k", index_type), - im.sym("inp", i_field_type), - im.sym("out", i_field_type), - im.sym("_gtmp_0", i_field_type), - im.sym("_gtmp_1", i_field_type), - im.sym("_gtmp_auto_domain", ts.DeferredType(constraint=None)), - ], - closures=[ - ir.StencilClosure( - domain=AUTO_DOMAIN, - stencil=im.lambda_("foo_inp")(im.deref("foo_inp")), - output=im.ref("_gtmp_1"), - inputs=[im.ref("inp")], - ), - ir.StencilClosure( - domain=AUTO_DOMAIN, - stencil=im.ref("deref"), - output=im.ref("_gtmp_0"), - inputs=[im.ref("_gtmp_1")], - ), - ir.StencilClosure( - domain=im.call("cartesian_domain")( - *( - im.call("named_range")( - ir.AxisLiteral(value=a), - im.literal("0", ir.INTEGER_INDEX_BUILTIN), - im.ref(s), - ) - for a, s in (("IDim", "i"), ("JDim", "j"), ("KDim", "k")) - ) - ), - stencil=im.lambda_("baz_inp", "_lift_2")(im.deref(im.shift("I", 1)("_lift_2"))), - output=im.ref("out"), - inputs=[im.ref("inp"), im.ref("_gtmp_0")], - ), - ], - ), - params=[im.sym("i"), im.sym("j"), im.sym("k"), im.sym("inp"), im.sym("out")], - tmps=[ir.Temporary(id="_gtmp_0"), ir.Temporary(id="_gtmp_1")], - ) - expected = copy.deepcopy(testee) - assert expected.fencil.params.pop() == im.sym("_gtmp_auto_domain") - expected.fencil.closures[0].domain = ir.FunCall( - fun=im.ref("cartesian_domain"), - args=[ - ir.FunCall( - fun=im.ref("named_range"), - args=[ - ir.AxisLiteral(value="IDim"), - im.plus( - im.literal("0", ir.INTEGER_INDEX_BUILTIN), - im.literal("1", ir.INTEGER_INDEX_BUILTIN), - ), - im.plus(im.ref("i"), im.literal("1", ir.INTEGER_INDEX_BUILTIN)), - ], +def test_top_level_if(): + domain = im.domain("cartesian_domain", {IDim: (0, 1)}) + offset_provider = {} + testee = program_factory( + params=[im.sym("inp1", i_field_type), im.sym("inp2", i_field_type), im.sym("out", i_field_type)], + body=[ + itir.SetAt( + target=im.ref("out"), + expr=im.if_(True, im.as_fieldop("deref", domain)("inp1"), im.as_fieldop("deref", domain)("inp2")), + domain=domain ) ] - + [ - ir.FunCall( - fun=im.ref("named_range"), - args=[ - ir.AxisLiteral(value=a), - im.literal("0", ir.INTEGER_INDEX_BUILTIN), - im.ref(s), - ], - ) - for a, s in (("JDim", "j"), ("KDim", "k")) - ], ) - expected.fencil.closures[1].domain = ir.FunCall( - fun=im.ref("cartesian_domain"), - args=[ - ir.FunCall( - 
fun=im.ref("named_range"), - args=[ - ir.AxisLiteral(value="IDim"), - im.plus( - im.literal("0", ir.INTEGER_INDEX_BUILTIN), - im.literal("1", ir.INTEGER_INDEX_BUILTIN), - ), - im.plus(im.ref("i"), im.literal("1", ir.INTEGER_INDEX_BUILTIN)), + testee = type_inference.infer(testee, offset_provider=offset_provider) + + expected = program_factory( + params=[im.sym("inp1", i_field_type), im.sym("inp2", i_field_type), im.sym("out", i_field_type)], + declarations=[], + body=[ + itir.IfStmt( + cond=im.literal_from_value(True), + true_branch=[ + itir.SetAt( + target=im.ref("out"), + expr=im.as_fieldop("deref", domain)("inp1"), + domain=domain + ) ], + false_branch=[ + itir.SetAt( + target=im.ref("out"), + expr=im.as_fieldop("deref", domain)("inp2"), + domain=domain + ) + ] ) ] - + [ - ir.FunCall( - fun=im.ref("named_range"), - args=[ - ir.AxisLiteral(value=a), - im.literal("0", ir.INTEGER_INDEX_BUILTIN), - im.ref(s), - ], - ) - for a, s in (("JDim", "j"), ("KDim", "k")) - ], ) - actual = update_domains(testee, {"I": gtx.Dimension("IDim")}, symbolic_sizes=None) + + actual = global_tmps.create_global_tmps(testee) assert actual == expected -def test_collect_tmps_info(): - tmp_domain = ir.FunCall( - fun=im.ref("cartesian_domain"), - args=[ - ir.FunCall( - fun=im.ref("named_range"), - args=[ - ir.AxisLiteral(value="IDim"), - im.literal("0", ir.INTEGER_INDEX_BUILTIN), - ir.FunCall( - fun=im.ref("plus"), - args=[im.ref("i"), im.literal("1", ir.INTEGER_INDEX_BUILTIN)], - ), - ], +def test_nested_if(): + domain = im.domain("cartesian_domain", {IDim: (0, 1)}) + offset_provider = {} + testee = program_factory( + params=[im.sym("inp1", i_field_type), im.sym("inp2", i_field_type), im.sym("out", i_field_type)], + body=[ + itir.SetAt( + target=im.ref("out"), + expr=im.as_fieldop("deref", domain)( + im.if_(True, im.as_fieldop("deref", domain)("inp1"), im.as_fieldop("deref", domain)("inp2"))), + domain=domain ) ] - + [ - ir.FunCall( - fun=im.ref("named_range"), - args=[ - ir.AxisLiteral(value=a), - im.literal("0", ir.INTEGER_INDEX_BUILTIN), - im.ref(s), + ) + testee = type_inference.infer(testee, offset_provider=offset_provider) + + expected = program_factory( + params=[im.sym("inp1", i_field_type), im.sym("inp2", i_field_type), im.sym("out", i_field_type)], + declarations=[itir.Temporary( + id="__tmp_1", + domain=domain, + dtype=float_type + )], + body=[ + itir.IfStmt( + cond=im.literal_from_value(True), + true_branch=[ + itir.SetAt( + target=im.ref("__tmp_1"), + expr=im.as_fieldop("deref", domain)("inp1"), + domain=domain + ) ], + false_branch=[ + itir.SetAt( + target=im.ref("__tmp_1"), + expr=im.as_fieldop("deref", domain)("inp2"), + domain=domain + ) + ] + ), + itir.SetAt( + target=im.ref("out"), + expr=im.as_fieldop("deref", domain)("__tmp_1"), + domain=domain ) - for a, s in (("JDim", "j"), ("KDim", "k")) - ], + ] ) - i = im.sym("i", index_type) - j = im.sym("j", index_type) - k = im.sym("k", index_type) - inp = im.sym("inp", i_field_type) - out = im.sym("out", i_field_type) - - testee = FencilWithTemporaries( - fencil=ir.FencilDefinition( - id="f", - function_definitions=[], - params=[ - i, - j, - k, - inp, - out, - im.sym("_gtmp_0", i_field_type), - im.sym("_gtmp_1", i_field_type), - ], - closures=[ - ir.StencilClosure( - domain=tmp_domain, - stencil=ir.Lambda( - params=[ir.Sym(id="foo_inp")], - expr=ir.FunCall(fun=im.ref("deref"), args=[im.ref("foo_inp")]), - ), - output=im.ref("_gtmp_1"), - inputs=[im.ref("inp")], - ), - ir.StencilClosure( - domain=tmp_domain, - stencil=im.ref("deref"), - 
output=im.ref("_gtmp_0"), - inputs=[im.ref("_gtmp_1")], - ), - ir.StencilClosure( - domain=ir.FunCall( - fun=im.ref("cartesian_domain"), - args=[ - ir.FunCall( - fun=im.ref("named_range"), - args=[ - ir.AxisLiteral(value=a), - im.literal("0", ir.INTEGER_INDEX_BUILTIN), - im.ref(s), - ], - ) - for a, s in (("IDim", "i"), ("JDim", "j"), ("KDim", "k")) - ], - ), - stencil=ir.Lambda( - params=[ir.Sym(id="baz_inp"), ir.Sym(id="_lift_2")], - expr=ir.FunCall( - fun=im.ref("deref"), - args=[ - ir.FunCall( - fun=ir.FunCall( - fun=im.ref("shift"), - args=[ - ir.OffsetLiteral(value="I"), - ir.OffsetLiteral(value=1), - ], - ), - args=[im.ref("_lift_2")], - ) - ], - ), - ), - output=im.ref("out"), - inputs=[im.ref("inp"), im.ref("_gtmp_0")], - ), - ], - ), - params=[i, j, k, inp, out], - tmps=[ - ir.Temporary(id="_gtmp_0", dtype=float_type), - ir.Temporary(id="_gtmp_1", dtype=float_type), - ], - ) - expected = FencilWithTemporaries( - fencil=testee.fencil, - params=testee.params, - tmps=[ - ir.Temporary(id="_gtmp_0", domain=tmp_domain, dtype=float_type), - ir.Temporary(id="_gtmp_1", domain=tmp_domain, dtype=float_type), - ], - ) - actual = collect_tmps_info(testee, offset_provider={"I": IDim, "J": JDim, "K": KDim}) + actual = global_tmps.create_global_tmps(testee) assert actual == expected + From d4f066a764f5651affeb6d7cb734c47e93795d48 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Wed, 2 Oct 2024 03:34:41 +0200 Subject: [PATCH 006/150] Cleanup --- .../next/iterator/transforms/infer_domain.py | 4 +- .../transforms_tests/test_domain_inference.py | 46 +++++++++---------- 2 files changed, 25 insertions(+), 25 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/infer_domain.py b/src/gt4py/next/iterator/transforms/infer_domain.py index 1d789406fa..c1a743af1c 100644 --- a/src/gt4py/next/iterator/transforms/infer_domain.py +++ b/src/gt4py/next/iterator/transforms/infer_domain.py @@ -97,7 +97,7 @@ def _merge_domains( original_domain, domain = _canonicalize_domain_structure( original_domains.get(key, None), domain ) - new_domains[key] = tree_map(domain_union_with_none)(original_domain, domain) + new_domains[key] = tree_map(_domain_union_with_none)(original_domain, domain) return new_domains @@ -118,7 +118,7 @@ def _extract_accessed_domains( for shift in shifts_list ] # `None` means field is never accessed - accessed_domains[in_field_id] = domain_union_with_none( + accessed_domains[in_field_id] = _domain_union_with_none( accessed_domains.get(in_field_id, None), *new_domains ) diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py index 79456e4d85..5d13337a94 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py @@ -16,7 +16,7 @@ from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im from gt4py.next.iterator import ir as itir from gt4py.next.iterator.transforms import infer_domain -from gt4py.next.iterator.transforms.global_tmps import SymbolicDomain +from gt4py.next.iterator.ir_utils import domain_utils from gt4py.next.common import Dimension from gt4py.next import common, NeighborTableOffsetProvider from gt4py.next.type_system import type_specifications as ts @@ -86,7 +86,7 @@ def run_test_expr( offset_provider: common.OffsetProvider, ): actual_call, actual_domains = infer_domain.infer_expr( - testee, 
SymbolicDomain.from_expr(domain), offset_provider + testee, domain_utils.SymbolicDomain.from_expr(domain), offset_provider ) folded_call = constant_fold_domain_exprs(actual_call) folded_domains = constant_fold_accessed_domains(actual_domains) if actual_domains else None @@ -122,7 +122,7 @@ def constant_fold_domain_exprs(arg: itir.Node) -> itir.Node: def constant_fold_accessed_domains( domains: infer_domain.ACCESSED_DOMAINS, ) -> infer_domain.ACCESSED_DOMAINS: - def fold_domain(domain: SymbolicDomain | None): + def fold_domain(domain: domain_utils.SymbolicDomain | None): if domain is None: return domain return constant_fold_domain_exprs(domain.as_expr()) @@ -134,7 +134,7 @@ def translate_domain( domain: itir.FunCall, shifts: dict[str, tuple[itir.Expr, itir.Expr]], offset_provider: common.OffsetProvider, -) -> SymbolicDomain: +) -> domain_utils.SymbolicDomain: shift_tuples = [ ( im.ensure_offset(d), @@ -145,7 +145,7 @@ def translate_domain( shift_list = [item for sublist in shift_tuples for item in sublist] - translated_domain_expr = SymbolicDomain.from_expr(domain).translate(shift_list, offset_provider) + translated_domain_expr = domain_utils.SymbolicDomain.from_expr(domain).translate(shift_list, offset_provider) return constant_fold_domain_exprs(translated_domain_expr.as_expr()) @@ -330,7 +330,7 @@ def test_nested_stencils(offset_provider): "in_field2": translate_domain(domain, {"Ioff": 0, "Joff": -2}, offset_provider), } actual_call, actual_domains = infer_domain.infer_expr( - testee, SymbolicDomain.from_expr(domain), offset_provider + testee, domain_utils.SymbolicDomain.from_expr(domain), offset_provider ) folded_domains = constant_fold_accessed_domains(actual_domains) folded_call = constant_fold_domain_exprs(actual_call) @@ -374,7 +374,7 @@ def test_nested_stencils_n_times(offset_provider, iterations): } actual_call, actual_domains = infer_domain.infer_expr( - testee, SymbolicDomain.from_expr(domain), offset_provider + testee, domain_utils.SymbolicDomain.from_expr(domain), offset_provider ) folded_domains = constant_fold_accessed_domains(actual_domains) @@ -512,7 +512,7 @@ def test_cond(offset_provider): expected = im.if_(cond, expected_field_1, expected_field_2) actual_call, actual_domains = infer_domain.infer_expr( - testee, SymbolicDomain.from_expr(domain), offset_provider + testee, domain_utils.SymbolicDomain.from_expr(domain), offset_provider ) folded_domains = constant_fold_accessed_domains(actual_domains) @@ -569,7 +569,7 @@ def test_let(offset_provider): expected_domains_sym = {"in_field": translate_domain(domain, {"Ioff": 2}, offset_provider)} actual_call2, actual_domains2 = infer_domain.infer_expr( - testee2, SymbolicDomain.from_expr(domain), offset_provider + testee2, domain_utils.SymbolicDomain.from_expr(domain), offset_provider ) folded_domains2 = constant_fold_accessed_domains(actual_domains2) folded_call2 = constant_fold_domain_exprs(actual_call2) @@ -789,7 +789,7 @@ def test_make_tuple(offset_provider): actual, actual_domains = infer_domain.infer_expr( testee, - (SymbolicDomain.from_expr(domain1), SymbolicDomain.from_expr(domain2)), + (domain_utils.SymbolicDomain.from_expr(domain1), domain_utils.SymbolicDomain.from_expr(domain2)), offset_provider, ) @@ -808,7 +808,7 @@ def test_tuple_get_1_make_tuple(offset_provider): } actual, actual_domains = infer_domain.infer_expr( - testee, SymbolicDomain.from_expr(domain), offset_provider + testee, domain_utils.SymbolicDomain.from_expr(domain), offset_provider ) assert expected == actual @@ -824,7 +824,7 @@ def 
test_tuple_get_1_nested_make_tuple(offset_provider): actual, actual_domains = infer_domain.infer_expr( testee, - (SymbolicDomain.from_expr(domain1), SymbolicDomain.from_expr(domain2)), + (domain_utils.SymbolicDomain.from_expr(domain1), domain_utils.SymbolicDomain.from_expr(domain2)), offset_provider, ) @@ -840,7 +840,7 @@ def test_tuple_get_let_arg_make_tuple(offset_provider): actual, actual_domains = infer_domain.infer_expr( testee, - SymbolicDomain.from_expr(im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)})), + domain_utils.SymbolicDomain.from_expr(im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)})), offset_provider, ) @@ -856,7 +856,7 @@ def test_tuple_get_let_make_tuple(offset_provider): actual, actual_domains = infer_domain.infer_expr( testee, - SymbolicDomain.from_expr(domain), + domain_utils.SymbolicDomain.from_expr(domain), offset_provider, ) @@ -877,10 +877,10 @@ def test_nested_make_tuple(offset_provider): testee, ( ( - SymbolicDomain.from_expr(domain1), - (SymbolicDomain.from_expr(domain2_1), SymbolicDomain.from_expr(domain2_2)), + domain_utils.SymbolicDomain.from_expr(domain1), + (domain_utils.SymbolicDomain.from_expr(domain2_1), domain_utils.SymbolicDomain.from_expr(domain2_2)), ), - SymbolicDomain.from_expr(domain3), + domain_utils.SymbolicDomain.from_expr(domain3), ), offset_provider, ) @@ -896,7 +896,7 @@ def test_tuple_get_1(offset_provider): expected_domains = {"a": (None, domain)} actual, actual_domains = infer_domain.infer_expr( - testee, SymbolicDomain.from_expr(domain), offset_provider + testee, domain_utils.SymbolicDomain.from_expr(domain), offset_provider ) assert expected == actual @@ -912,7 +912,7 @@ def test_domain_tuple(offset_provider): actual, actual_domains = infer_domain.infer_expr( testee, - (SymbolicDomain.from_expr(domain1), SymbolicDomain.from_expr(domain2)), + (domain_utils.SymbolicDomain.from_expr(domain1), domain_utils.SymbolicDomain.from_expr(domain2)), offset_provider, ) @@ -929,7 +929,7 @@ def test_as_fieldop_tuple_get(offset_provider): expected_domains = {"a": (domain, domain)} actual, actual_domains = infer_domain.infer_expr( - testee, SymbolicDomain.from_expr(domain), offset_provider + testee, domain_utils.SymbolicDomain.from_expr(domain), offset_provider ) assert expected == actual @@ -945,7 +945,7 @@ def test_make_tuple_2tuple_get(offset_provider): actual, actual_domains = infer_domain.infer_expr( testee, - (SymbolicDomain.from_expr(domain1), SymbolicDomain.from_expr(domain2)), + (domain_utils.SymbolicDomain.from_expr(domain1), domain_utils.SymbolicDomain.from_expr(domain2)), offset_provider, ) @@ -963,7 +963,7 @@ def test_make_tuple_non_tuple_domain(offset_provider): expected_domains = {"in_field1": domain, "in_field2": domain} actual, actual_domains = infer_domain.infer_expr( - testee, SymbolicDomain.from_expr(domain), offset_provider + testee, domain_utils.SymbolicDomain.from_expr(domain), offset_provider ) assert expected == actual @@ -977,7 +977,7 @@ def test_arithmetic_builtin(offset_provider): expected_domains = {} actual_call, actual_domains = infer_domain.infer_expr( - testee, SymbolicDomain.from_expr(domain), offset_provider + testee, domain_utils.SymbolicDomain.from_expr(domain), offset_provider ) folded_call = constant_fold_domain_exprs(actual_call) From c1038facbeeb553e7537bf24ecf84525c17d40f0 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Thu, 3 Oct 2024 10:53:24 +0200 Subject: [PATCH 007/150] Address review comments --- src/gt4py/next/iterator/embedded.py | 12 ++++++++++++ 
src/gt4py/next/iterator/type_system/inference.py | 3 ++- .../iterator_tests/test_if_stmt.py | 3 +-- 3 files changed, 15 insertions(+), 3 deletions(-) diff --git a/src/gt4py/next/iterator/embedded.py b/src/gt4py/next/iterator/embedded.py index 997851d0b7..afe0cec402 100644 --- a/src/gt4py/next/iterator/embedded.py +++ b/src/gt4py/next/iterator/embedded.py @@ -1605,6 +1605,18 @@ def set_at(expr: common.Field, domain: common.DomainLike, target: common.Mutable @runtime.if_stmt.register(EMBEDDED) def if_stmt(cond: bool, true_branch: Callable[[], None], false_branch: Callable[[], None]) -> None: + """ + (Stateful) if statement. + + The two branches are represented as lambda functions, such that they are not executed eagerly. + This is required to avoid out-of-bounds accesses. Note that a dedicated built-in is required, + contrary to using a plain python if-stmt, such that tracing / double roundtrip works. + + Arguments: + cond: The condition to decide which branch to execute. + true_branch: A lambda function to be executed when `cond` is `True`. + false_branch: A lambda function to be executed when `cond` is `False`. + """ if cond: true_branch() else: diff --git a/src/gt4py/next/iterator/type_system/inference.py b/src/gt4py/next/iterator/type_system/inference.py index c141c80999..47c04def3e 100644 --- a/src/gt4py/next/iterator/type_system/inference.py +++ b/src/gt4py/next/iterator/type_system/inference.py @@ -500,7 +500,8 @@ def visit_Temporary(self, node: itir.Temporary, *, ctx) -> ts.FieldType | ts.Tup ) def visit_IfStmt(self, node: itir.IfStmt, *, ctx) -> None: - self.visit(node.cond, ctx=ctx) # TODO: check is boolean + assert node.cond == ts.ScalarType(kind=ts.ScalarKind.BOOL) + self.visit(node.cond, ctx=ctx) self.visit(node.true_branch, ctx=ctx) self.visit(node.false_branch, ctx=ctx) diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_if_stmt.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_if_stmt.py index 1507def2c2..2dde7d7653 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_if_stmt.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_if_stmt.py @@ -65,12 +65,11 @@ def fencil(cond1, inp, out): ) rng = np.random.default_rng() - cond = False inp = gtx.as_field([IDim], rng.normal(size=size)) out = gtx.as_field([IDim], np.zeros(size)) ref = inp if cond else 2.0 * inp - run_processor(fencil, program_processor, False, inp, out) + run_processor(fencil, program_processor, cond, inp, out) if validate: assert np.allclose(out.asnumpy(), ref.asnumpy()) From 160aeaf8e04781240f77b6524041f44aa6932d51 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Thu, 3 Oct 2024 13:48:56 +0200 Subject: [PATCH 008/150] Add type inference test for IfStmt --- .../next/iterator/type_system/inference.py | 4 ++-- .../iterator_tests/test_type_inference.py | 22 +++++++++++++++++++ 2 files changed, 24 insertions(+), 2 deletions(-) diff --git a/src/gt4py/next/iterator/type_system/inference.py b/src/gt4py/next/iterator/type_system/inference.py index 47c04def3e..bc1095dfb8 100644 --- a/src/gt4py/next/iterator/type_system/inference.py +++ b/src/gt4py/next/iterator/type_system/inference.py @@ -500,8 +500,8 @@ def visit_Temporary(self, node: itir.Temporary, *, ctx) -> ts.FieldType | ts.Tup ) def visit_IfStmt(self, node: itir.IfStmt, *, ctx) -> None: - assert node.cond == ts.ScalarType(kind=ts.ScalarKind.BOOL) - self.visit(node.cond, ctx=ctx) + cond = self.visit(node.cond, ctx=ctx) + assert 
cond == ts.ScalarType(kind=ts.ScalarKind.BOOL) self.visit(node.true_branch, ctx=ctx) self.visit(node.false_branch, ctx=ctx) diff --git a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py index acfb1d0bd8..05cd6b6854 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py @@ -456,3 +456,25 @@ def test_program_tuple_setat_short_target(): isinstance(result.body[0].target.type, ts.TupleType) and len(result.body[0].target.type.types) == 1 ) + + +def test_if_stmt(): + cartesian_domain = im.call("cartesian_domain")( + im.call("named_range")(itir.AxisLiteral(value="IDim"), 0, 1) + ) + + testee = itir.IfStmt( + cond=im.literal_from_value(True), + true_branch=[ + itir.SetAt( + expr=im.as_fieldop("deref", cartesian_domain)(im.ref("inp", float_i_field)), + domain=cartesian_domain, + target=im.ref("out", float_i_field), + ) + ], + false_branch=[], + ) + + result = itir_type_inference.infer(testee, offset_provider={}, allow_undeclared_symbols=True) + assert result.cond.type == bool_type + assert result.true_branch[0].expr.type == float_i_field From 19a2c5e53a023f197793884a3455a966f106cb32 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Thu, 10 Oct 2024 00:20:15 +0200 Subject: [PATCH 009/150] Cleanup --- src/gt4py/next/iterator/ir_utils/ir_makers.py | 2 +- .../iterator/transforms/fuse_as_fieldop.py | 60 ++++++++++--------- .../next/iterator/type_system/inference.py | 11 ++++ .../transforms_tests/test_fuse_as_fieldop.py | 14 ++--- 4 files changed, 51 insertions(+), 36 deletions(-) diff --git a/src/gt4py/next/iterator/ir_utils/ir_makers.py b/src/gt4py/next/iterator/ir_utils/ir_makers.py index b2662fa278..19e26f24b6 100644 --- a/src/gt4py/next/iterator/ir_utils/ir_makers.py +++ b/src/gt4py/next/iterator/ir_utils/ir_makers.py @@ -446,7 +446,7 @@ def domain( ) -def as_fieldop(expr: itir.Expr, domain: Optional[itir.FunCall] = None) -> call: +def as_fieldop(expr: itir.Expr, domain: Optional[itir.Expr] = None) -> call: """ Create an `as_fieldop` call. 
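Editor's note on the widened `domain` parameter (`itir.FunCall` → `itir.Expr`): it matches how the helper is exercised elsewhere in this series. For reference, a typical construction as it appears in the tests above (assumes an installed gt4py; the calls mirror those test modules, nothing new is introduced):

    from gt4py.next import common
    from gt4py.next.iterator.ir_utils import ir_makers as im

    IDim = common.Dimension(value="IDim")
    domain = im.domain("cartesian_domain", {IDim: (0, 1)})
    # apply `deref` over `domain` to the field referenced by "inp"
    expr = im.as_fieldop("deref", domain)("inp")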
diff --git a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py index 6922857dbd..573a84b79f 100644 --- a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py +++ b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py @@ -14,17 +14,20 @@ from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im from gt4py.next.iterator.transforms import inline_lambdas, inline_lifts, trace_shifts -from gt4py.next.iterator.type_system import inference as type_inference, type_specifications as ts -from gt4py.next.type_system import type_info +from gt4py.next.iterator.type_system import ( + inference as type_inference, + type_specifications as it_ts, +) +from gt4py.next.type_system import type_info, type_specifications as ts -def inline_as_fieldop_arg(arg, uids): +def _inline_as_fieldop_arg(arg: itir.Expr, uids: eve_utils.UIDGenerator): assert cpm.is_applied_as_fieldop(arg) - arg = canonicalize_as_fieldop(arg) + arg = _canonicalize_as_fieldop(arg) - stencil, *_ = arg.fun.args + stencil, *_ = arg.fun.args # type: ignore[attr-defined] # ensured by `is_applied_as_fieldop` inner_args = arg.args - extracted_args = {} # mapping from stencil param to arg + extracted_args: dict[str, itir.Expr] = {} # mapping from stencil param to arg stencil_params = [] stencil_body = stencil.expr @@ -33,7 +36,9 @@ def inline_as_fieldop_arg(arg, uids): if isinstance(inner_arg, itir.SymRef): stencil_params.append(inner_param) extracted_args[inner_arg.id] = inner_arg - elif isinstance(inner_arg, itir.Literal): # TODO: all non capturing scalars + # note: only literals, not all scalar expressions are required as it doesn't make sense + # for them to be computed per grid point. 
+ elif isinstance(inner_arg, itir.Literal): stencil_body = im.let(inner_param, im.promote_to_const_iterator(inner_arg))( stencil_body ) @@ -47,7 +52,7 @@ def inline_as_fieldop_arg(arg, uids): ), extracted_args -def merge_arguments(args1: dict, arg2: dict): +def _merge_arguments(args1: dict, arg2: dict): new_args = {**args1} for stencil_param, stencil_arg in arg2.items(): if stencil_param not in new_args: @@ -57,11 +62,11 @@ def merge_arguments(args1: dict, arg2: dict): return new_args -def canonicalize_as_fieldop(expr: itir.Expr) -> itir.Expr: +def _canonicalize_as_fieldop(expr: itir.FunCall) -> itir.FunCall: assert cpm.is_applied_as_fieldop(expr) - stencil = expr.fun.args[0] - domain = expr.fun.args[1] if len(expr.fun.args) > 1 else None + stencil = expr.fun.args[0] # type: ignore[attr-defined] + domain = expr.fun.args[1] if len(expr.fun.args) > 1 else None # type: ignore[attr-defined] if cpm.is_ref_to(stencil, "deref"): stencil = im.lambda_("arg")(im.deref("arg")) new_expr = im.as_fieldop(stencil, domain)(*expr.args) @@ -98,25 +103,27 @@ def visit_FunCall(self, node: itir.FunCall): node = self.generic_visit(node) if cpm.is_call_to(node.fun, "as_fieldop"): - node = canonicalize_as_fieldop(node) + node = _canonicalize_as_fieldop(node) if cpm.is_call_to(node.fun, "as_fieldop") and isinstance(node.fun.args[0], itir.Lambda): - stencil = node.fun.args[0] + stencil: itir.Lambda = node.fun.args[0] domain = node.fun.args[1] if len(node.fun.args) > 1 else None shifts = trace_shifts.trace_stencil(stencil) - args = node.args + args: list[itir.Expr] = node.args - new_args = {} - new_stencil_body = stencil.expr + new_args: dict[str, itir.Expr] = {} + new_stencil_body: itir.Expr = stencil.expr for stencil_param, arg, arg_shifts in zip(stencil.params, args, shifts, strict=True): + assert isinstance(arg.type, ts.TypeSpec) dtype = type_info.extract_dtype(arg.type) + # TODO(tehrengruber): make this configurable should_inline = isinstance(arg, itir.Literal) or ( isinstance(arg, itir.FunCall) - and (cpm.is_call_to(arg.fun, "as_fieldop") or cpm.is_call_to(arg, "cond")) - and (isinstance(dtype, ts.ListType) or len(arg_shifts) <= 1) + and (cpm.is_call_to(arg.fun, "as_fieldop") or cpm.is_call_to(arg, "if_")) + and (isinstance(dtype, it_ts.ListType) or len(arg_shifts) <= 1) ) if should_inline: if cpm.is_applied_as_fieldop(arg): @@ -130,14 +137,13 @@ def visit_FunCall(self, node: itir.FunCall): else: raise NotImplementedError() - inline_expr, extracted_args = inline_as_fieldop_arg(arg, self.uids) + inline_expr, extracted_args = _inline_as_fieldop_arg(arg, self.uids) new_stencil_body = im.let(stencil_param, inline_expr)(new_stencil_body) - new_args = merge_arguments(new_args, extracted_args) + new_args = _merge_arguments(new_args, extracted_args) else: - # see test_tuple_with_local_field_in_reduction_shifted for ex where assert fails - # assert not isinstance(dtype, ts.ListType) + new_param: str if isinstance( arg, itir.SymRef ): # use name from outer scope (optional, just to get a nice IR) @@ -145,21 +151,21 @@ def visit_FunCall(self, node: itir.FunCall): new_stencil_body = im.let(stencil_param.id, arg.id)(new_stencil_body) else: new_param = stencil_param.id - new_args = merge_arguments(new_args, {new_param: arg}) + new_args = _merge_arguments(new_args, {new_param: arg}) + # simplify stencil directly to keep the tree small new_stencil_body = inline_lambdas.InlineLambdas.apply( new_stencil_body, opcount_preserving=True, + # TODO(tehrengruber): Revisit. 
Set to False for now to expose cases where we end + # up with lifts remaining in the IR. force_inline_lift_args=False, - # If trivial lifts are not inlined we might create temporaries for constants. In all - # other cases we want it anyway. - force_inline_trivial_lift_args=True, ) new_stencil_body = inline_lifts.InlineLifts().visit(new_stencil_body) new_node = im.as_fieldop(im.lambda_(*new_args.keys())(new_stencil_body), domain)( *new_args.values() ) - new_node.type = node.type + type_inference.copy_type(from_=node, to=new_node) return new_node return node diff --git a/src/gt4py/next/iterator/type_system/inference.py b/src/gt4py/next/iterator/type_system/inference.py index 465669245f..76ea543d69 100644 --- a/src/gt4py/next/iterator/type_system/inference.py +++ b/src/gt4py/next/iterator/type_system/inference.py @@ -96,6 +96,17 @@ def _set_node_type(node: itir.Node, type_: ts.TypeSpec) -> None: node.type = type_ +def copy_type(from_: itir.Node, to: itir.Node) -> itir.Node: + """ + Copy type from one node to another. + + This function mainly exists for readability reasons. + """ + assert isinstance(from_.type, ts.TypeSpec) + _set_node_type(to, from_.type) + return to + + def on_inferred(callback: Callable, *args: Union[ts.TypeSpec, ObservableTypeSynthesizer]) -> None: """ Execute `callback` as soon as all `args` have a type. diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_fuse_as_fieldop.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_fuse_as_fieldop.py index bc371f64de..082f16b3a1 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_fuse_as_fieldop.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_fuse_as_fieldop.py @@ -30,19 +30,17 @@ def _impl(*its: itir.Expr) -> itir.FunCall: def test_trivial(): - d = im.domain("cartesian_domain", {}) + d = im.domain("cartesian_domain", {IDim: (0, 1)}) testee = im.op_as_fieldop("plus", d)( im.op_as_fieldop("multiplies", d)(im.ref("inp1", field_type), im.ref("inp2", field_type)), im.ref("inp3", field_type), ) expected = im.as_fieldop( - im.lambda_("inp1", "inp2", "inp3", "__iasfop_4")( - im.plus( - im.multiplies_(im.deref("__iasfop_1"), im.deref("__iasfop_2")), - ) + im.lambda_("inp1", "inp2", "inp3")( + im.plus(im.multiplies_(im.deref("inp1"), im.deref("inp2")), im.deref("inp3")) ), d, - )(1, 2, 3, 4) + )(im.ref("inp1", field_type), im.ref("inp2", field_type), im.ref("inp3", field_type)) actual = fuse_as_fieldop.FuseAsFieldOp.apply( testee, offset_provider={}, allow_undeclared_symbols=True ) @@ -98,8 +96,8 @@ def test_partial_inline(): d1 = im.domain("cartesian_domain", {IDim: (1, 2)}) d2 = im.domain("cartesian_domain", {IDim: (0, 3)}) testee = im.as_fieldop( - # first argument used at multiple locations -> not inlined - # second argument only used at a single location -> inlined + # first argument read at multiple locations -> not inlined + # second argument only read at a single location -> inlined im.lambda_("a", "b")( im.plus( im.plus(im.deref(im.shift("IOff", 1)("a")), im.deref(im.shift("IOff", -1)("a"))), From b50b85a1b5ac2310ccab7ce525e8fc42bbb0f2f5 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Thu, 10 Oct 2024 00:40:03 +0200 Subject: [PATCH 010/150] Cleanup --- .../iterator_tests/transforms_tests/test_fuse_as_fieldop.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_fuse_as_fieldop.py
b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_fuse_as_fieldop.py index 082f16b3a1..569843a0a5 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_fuse_as_fieldop.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_fuse_as_fieldop.py @@ -59,7 +59,7 @@ def test_trivial_literal(): def test_symref_used_twice(): d = im.domain("cartesian_domain", {IDim: (0, 1)}) - testee = op_asfieldop2(im.lambda_("a", "b")(im.plus("a", "b")), d)( + testee = im.as_fieldop(im.lambda_("a", "b")(im.plus(im.deref("a"), im.deref("b"))), d)( im.as_fieldop(im.lambda_("c", "d")(im.multiplies_(im.deref("c"), im.deref("d"))), d)( im.ref("inp1", field_type), im.ref("inp2", field_type) ), From bb2f2b1695a2a0224e25c020dc13b4c9ab0786aa Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Thu, 10 Oct 2024 00:46:51 +0200 Subject: [PATCH 011/150] Cleanup --- .../iterator/transforms/fuse_as_fieldop.py | 26 ++++++++++++++++--- 1 file changed, 22 insertions(+), 4 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py index 573a84b79f..8f39236961 100644 --- a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py +++ b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py @@ -79,6 +79,27 @@ def _canonicalize_as_fieldop(expr: itir.FunCall) -> itir.FunCall: @dataclasses.dataclass class FuseAsFieldOp(eve.NodeTranslator): + """ + Merge multiple `as_fieldop` calls into one. + + >>> from gt4py import next as gtx + >>> from gt4py.next.iterator.ir_utils import ir_makers as im + >>> IDim = gtx.Dimension("IDim") + >>> field_type = ts.FieldType(dims=[IDim], dtype=ts.ScalarType(kind=ts.ScalarKind.INT32)) + >>> d = im.domain("cartesian_domain", {IDim: (0, 1)}) + >>> nested_as_fieldop = im.op_as_fieldop("plus", d)( + ... im.op_as_fieldop("multiplies", d)(im.ref("inp1", field_type), im.ref("inp2", field_type)), + ... im.ref("inp3", field_type), + ... ) + >>> print(nested_as_fieldop) + as_fieldop(λ(__arg0, __arg1) → ·__arg0 + ·__arg1, c⟨ IDimₕ: [0, 1) ⟩)( + as_fieldop(λ(__arg0, __arg1) → ·__arg0 × ·__arg1, c⟨ IDimₕ: [0, 1) ⟩)(inp1, inp2), inp3 + ) + >>> print(FuseAsFieldOp.apply( + ... nested_as_fieldop, offset_provider={}, allow_undeclared_symbols=True + ... )) + as_fieldop(λ(inp1, inp2, inp3) → ·inp1 × ·inp2 + ·inp3, c⟨ IDimₕ: [0, 1) ⟩)(inp1, inp2, inp3) + """ uids: eve_utils.UIDGenerator @classmethod @@ -156,10 +177,7 @@ def visit_FunCall(self, node: itir.FunCall): # simplify stencil directly to keep the tree small new_stencil_body = inline_lambdas.InlineLambdas.apply( new_stencil_body, - opcount_preserving=True, - # TODO(tehrengruber): Revisit. Set to False for now to expose cases where we end - # up with lifts remaining in the IR. 
- force_inline_lift_args=False, + opcount_preserving=True ) new_stencil_body = inline_lifts.InlineLifts().visit(new_stencil_body) From 729968e78ddd742c7fe310ac1db118cdd7cc0dce Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Thu, 10 Oct 2024 00:47:09 +0200 Subject: [PATCH 012/150] Cleanup --- .../next/iterator/transforms/fuse_as_fieldop.py | 16 ++++++++++------ .../transforms_tests/test_fuse_as_fieldop.py | 12 ------------ 2 files changed, 10 insertions(+), 18 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py index 8f39236961..225f063342 100644 --- a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py +++ b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py @@ -88,18 +88,23 @@ class FuseAsFieldOp(eve.NodeTranslator): >>> field_type = ts.FieldType(dims=[IDim], dtype=ts.ScalarType(kind=ts.ScalarKind.INT32)) >>> d = im.domain("cartesian_domain", {IDim: (0, 1)}) >>> nested_as_fieldop = im.op_as_fieldop("plus", d)( - ... im.op_as_fieldop("multiplies", d)(im.ref("inp1", field_type), im.ref("inp2", field_type)), + ... im.op_as_fieldop("multiplies", d)( + ... im.ref("inp1", field_type), im.ref("inp2", field_type) + ... ), ... im.ref("inp3", field_type), ... ) >>> print(nested_as_fieldop) as_fieldop(λ(__arg0, __arg1) → ·__arg0 + ·__arg1, c⟨ IDimₕ: [0, 1) ⟩)( as_fieldop(λ(__arg0, __arg1) → ·__arg0 × ·__arg1, c⟨ IDimₕ: [0, 1) ⟩)(inp1, inp2), inp3 ) - >>> print(FuseAsFieldOp.apply( - ... nested_as_fieldop, offset_provider={}, allow_undeclared_symbols=True - ... )) + >>> print( + ... FuseAsFieldOp.apply( + ... nested_as_fieldop, offset_provider={}, allow_undeclared_symbols=True + ... ) + ... ) as_fieldop(λ(inp1, inp2, inp3) → ·inp1 × ·inp2 + ·inp3, c⟨ IDimₕ: [0, 1) ⟩)(inp1, inp2, inp3) """ + uids: eve_utils.UIDGenerator @classmethod @@ -176,8 +181,7 @@ def visit_FunCall(self, node: itir.FunCall): # simplify stencil directly to keep the tree small new_stencil_body = inline_lambdas.InlineLambdas.apply( - new_stencil_body, - opcount_preserving=True + new_stencil_body, opcount_preserving=True ) new_stencil_body = inline_lifts.InlineLifts().visit(new_stencil_body) diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_fuse_as_fieldop.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_fuse_as_fieldop.py index 569843a0a5..da2c16336e 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_fuse_as_fieldop.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_fuse_as_fieldop.py @@ -17,18 +17,6 @@ field_type = ts.FieldType(dims=[IDim], dtype=ts.ScalarType(kind=ts.ScalarKind.INT32)) -def op_asfieldop2(op: str | itir.SymRef | Callable, domain: Optional[itir.FunCall] = None): - assert isinstance(op, itir.Lambda) - op = im.call(op) - - args = [param.id for param in op.fun.params] - - def _impl(*its: itir.Expr) -> itir.FunCall: - return im.as_fieldop(im.lambda_(*args)(op(*[im.deref(arg) for arg in args])), domain)(*its) - - return _impl - - def test_trivial(): d = im.domain("cartesian_domain", {IDim: (0, 1)}) testee = im.op_as_fieldop("plus", d)( From cdcaac05d795d61a528dd9a6e373911bb36694af Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Thu, 10 Oct 2024 00:57:15 +0200 Subject: [PATCH 013/150] Cleanup --- src/gt4py/next/iterator/transforms/fuse_as_fieldop.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py 
index 225f063342..f105cabd64 100644 --- a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py +++ b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py @@ -82,6 +82,8 @@ class FuseAsFieldOp(eve.NodeTranslator): """ Merge multiple `as_fieldop` calls into one. + # ruff: noqa: RUF002 + >>> from gt4py import next as gtx >>> from gt4py.next.iterator.ir_utils import ir_makers as im >>> IDim = gtx.Dimension("IDim") @@ -103,7 +105,7 @@ class FuseAsFieldOp(eve.NodeTranslator): ... ) ... ) as_fieldop(λ(inp1, inp2, inp3) → ·inp1 × ·inp2 + ·inp3, c⟨ IDimₕ: [0, 1) ⟩)(inp1, inp2, inp3) - """ + """ # noqa: RUF002 # ignore × character ambiguity uids: eve_utils.UIDGenerator From 9472502b3b1996b43990b9ffe28221d178d673bb Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Thu, 10 Oct 2024 01:03:35 +0200 Subject: [PATCH 014/150] Cleanup --- .../next/iterator/transforms/fuse_as_fieldop.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py index f105cabd64..11950ff5d5 100644 --- a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py +++ b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py @@ -23,26 +23,28 @@ def _inline_as_fieldop_arg(arg: itir.Expr, uids: eve_utils.UIDGenerator): assert cpm.is_applied_as_fieldop(arg) - arg = _canonicalize_as_fieldop(arg) + arg: itir.FunCall = _canonicalize_as_fieldop(arg) stencil, *_ = arg.fun.args # type: ignore[attr-defined] # ensured by `is_applied_as_fieldop` - inner_args = arg.args + inner_args: list[itir.Expr] = arg.args extracted_args: dict[str, itir.Expr] = {} # mapping from stencil param to arg - stencil_params = [] - stencil_body = stencil.expr + stencil_params: list[itir.Sym] = [] + stencil_body: itir.Expr = stencil.expr for inner_param, inner_arg in zip(stencil.params, inner_args, strict=True): if isinstance(inner_arg, itir.SymRef): stencil_params.append(inner_param) extracted_args[inner_arg.id] = inner_arg - # note: only literals, not all scalar expressions are required as it doesn't make sense - # for them to be computed per grid point. elif isinstance(inner_arg, itir.Literal): + # note: only literals, not all scalar expressions are required as it doesn't make sense + # for them to be computed per grid point. stencil_body = im.let(inner_param, im.promote_to_const_iterator(inner_arg))( stencil_body ) - else: # either a literal or a previous not inlined arg + else: + # a scalar expression, a previously not inlined `as_fieldop` call or an opaque + # expression e.g. 
containing a tuple stencil_params.append(inner_param) new_outer_stencil_param = uids.sequential_id(prefix="__iasfop") extracted_args[new_outer_stencil_param] = inner_arg From 01dd86ed245f91141bf92cc1c766d55033dee3ab Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Thu, 10 Oct 2024 01:04:36 +0200 Subject: [PATCH 015/150] Cleanup --- src/gt4py/next/iterator/transforms/fuse_as_fieldop.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py index 11950ff5d5..b069ce4543 100644 --- a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py +++ b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py @@ -23,7 +23,7 @@ def _inline_as_fieldop_arg(arg: itir.Expr, uids: eve_utils.UIDGenerator): assert cpm.is_applied_as_fieldop(arg) - arg: itir.FunCall = _canonicalize_as_fieldop(arg) + arg = _canonicalize_as_fieldop(arg) stencil, *_ = arg.fun.args # type: ignore[attr-defined] # ensured by `is_applied_as_fieldop` inner_args: list[itir.Expr] = arg.args @@ -84,8 +84,6 @@ class FuseAsFieldOp(eve.NodeTranslator): """ Merge multiple `as_fieldop` calls into one. - # ruff: noqa: RUF002 - >>> from gt4py import next as gtx >>> from gt4py.next.iterator.ir_utils import ir_makers as im >>> IDim = gtx.Dimension("IDim") @@ -107,7 +105,7 @@ class FuseAsFieldOp(eve.NodeTranslator): ... ) ... ) as_fieldop(λ(inp1, inp2, inp3) → ·inp1 × ·inp2 + ·inp3, c⟨ IDimₕ: [0, 1) ⟩)(inp1, inp2, inp3) - """ # noqa: RUF002 # ignore × character ambiguity + """ # noqa: RUF002 # ignore ambiguous multiplication character uids: eve_utils.UIDGenerator From e078b362657be54559f307164280bdb3d26df5cb Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Thu, 10 Oct 2024 10:45:06 +0200 Subject: [PATCH 016/150] Cleanup --- src/gt4py/next/iterator/transforms/fuse_as_fieldop.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py index b069ce4543..4be2fa312f 100644 --- a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py +++ b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py @@ -157,6 +157,7 @@ def visit_FunCall(self, node: itir.FunCall): if cpm.is_applied_as_fieldop(arg): pass elif cpm.is_call_to(arg, "if_"): + # TODO(tehrengruber): revisit if we want to inline if_ type_ = arg.type arg = im.op_as_fieldop("if_")(*arg.args) arg.type = type_ From f0331cb0ce9a7d8009f73ecdb6585ddb1f570a06 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Thu, 10 Oct 2024 13:10:09 +0200 Subject: [PATCH 017/150] Cleanup --- src/gt4py/next/iterator/transforms/cse.py | 6 +- .../iterator/transforms/fencil_to_program.py | 5 +- .../next/iterator/transforms/global_tmps.py | 206 ++++++++++-------- .../next/iterator/transforms/infer_domain.py | 45 ++-- .../next/iterator/transforms/pass_manager.py | 52 ++--- .../next/iterator/type_system/inference.py | 28 +-- .../next/program_processors/runners/gtfn.py | 4 +- src/gt4py/next/utils.py | 26 ++- .../transforms_tests/test_domain_inference.py | 33 ++- .../transforms_tests/test_global_tmps.py | 149 ++++++------- .../runners_tests/test_gtfn.py | 17 -- 11 files changed, 306 insertions(+), 265 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/cse.py b/src/gt4py/next/iterator/transforms/cse.py index a2a169f6e4..518f294cbb 100644 --- a/src/gt4py/next/iterator/transforms/cse.py +++ b/src/gt4py/next/iterator/transforms/cse.py @@ -32,7 +32,7 @@ @dataclasses.dataclass class 
_NodeReplacer(PreserveLocationVisitor, NodeTranslator): - PRESERVED_ANNEX_ATTRS = ("type",) + PRESERVED_ANNEX_ATTRS = ("type", "domain") expr_map: dict[int, itir.SymRef] @@ -51,7 +51,9 @@ def visit_FunCall(self, node: itir.FunCall) -> itir.Node: if isinstance(node, itir.FunCall) and isinstance(node.fun, itir.Lambda): eligible_params = [] for arg in node.args: - eligible_params.append(isinstance(arg, itir.SymRef)) # and arg.id.startswith("_cs")) # TODO: document? this is for lets in the global tmp pass, e.g. test_trivial_let + eligible_params.append( + isinstance(arg, itir.SymRef) + ) # and arg.id.startswith("_cs")) # TODO: document? this is for lets in the global tmp pass, e.g. test_trivial_let if any(eligible_params): # note: the inline is opcount preserving anyway so avoid the additional # effort in the inliner by disabling opcount preservation. diff --git a/src/gt4py/next/iterator/transforms/fencil_to_program.py b/src/gt4py/next/iterator/transforms/fencil_to_program.py index e07cbc282a..4ad91645d4 100644 --- a/src/gt4py/next/iterator/transforms/fencil_to_program.py +++ b/src/gt4py/next/iterator/transforms/fencil_to_program.py @@ -9,14 +9,11 @@ from gt4py import eve from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import ir_makers as im -from gt4py.next.iterator.transforms import global_tmps class FencilToProgram(eve.NodeTranslator): @classmethod - def apply( - cls, node: itir.FencilDefinition | itir.Program - ) -> itir.Program: + def apply(cls, node: itir.FencilDefinition | itir.Program) -> itir.Program: return cls().visit(node) def visit_StencilClosure(self, node: itir.StencilClosure) -> itir.SetAt: diff --git a/src/gt4py/next/iterator/transforms/global_tmps.py b/src/gt4py/next/iterator/transforms/global_tmps.py index 13511f2c63..fbff2e8528 100644 --- a/src/gt4py/next/iterator/transforms/global_tmps.py +++ b/src/gt4py/next/iterator/transforms/global_tmps.py @@ -8,119 +8,145 @@ from __future__ import annotations -import copy -import dataclasses import functools -from collections.abc import Mapping -from typing import Any, Callable, Final, Iterable, Literal, Optional, Sequence +from typing import Callable, Optional -from gt4py import eve from gt4py.eve import utils as eve_utils +from gt4py.next import common, utils as next_utils from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im -from gt4py.next.type_system import type_info -from gt4py.next.iterator.transforms import inline_lambdas - -# TODO: remove -SimpleTemporaryExtractionHeuristics = None -CreateGlobalTmps = None - -from gt4py.next.iterator.transforms import cse - -class IncompleteTemporary: - expr: itir.Expr - target: itir.Expr - -def get_expr_domain(expr: itir.Expr, ctx=None): - ctx = ctx or {} - - if cpm.is_applied_as_fieldop(expr): - _, domain = expr.fun.args - return domain - elif cpm.is_call_to(expr, "tuple_get"): - idx_expr, tuple_expr = expr.args - assert isinstance(idx_expr, itir.Literal) and type_info.is_integer(idx_expr.type) - idx = int(idx_expr.value) - tuple_expr_domain = get_expr_domain(tuple_expr, ctx) - assert isinstance(tuple_expr_domain, tuple) and idx < len(tuple_expr_domain) - return tuple_expr_domain[idx] - elif cpm.is_call_to(expr, "make_tuple"): - return tuple(get_expr_domain(el, ctx) for el in expr.args) - elif cpm.is_call_to(expr, "if_"): - cond, true_val, false_val = expr.args - true_domain, false_domain = get_expr_domain(true_val, ctx), get_expr_domain(false_val, ctx) - assert true_domain == 
false_domain - return true_domain - elif cpm.is_let(expr): - new_ctx = {} - for var_name, var_value in zip(expr.fun.params, expr.args, strict=True): - new_ctx[var_name.id] = get_expr_domain(var_value, ctx) - return get_expr_domain(expr.fun.expr, ctx={**ctx, **new_ctx}) - raise ValueError() - - -def transform_if(stmt: itir.SetAt, declarations: list[itir.Temporary], uids: eve_utils.UIDGenerator): +from gt4py.next.iterator.transforms import cse, infer_domain, inline_lambdas +from gt4py.next.iterator.type_system import inference as type_inference +from gt4py.next.type_system import type_info, type_specifications as ts + + +def transform_if( + stmt: itir.Stmt, declarations: list[itir.Temporary], uids: eve_utils.UIDGenerator +) -> Optional[list[itir.Stmt]]: if not isinstance(stmt, itir.SetAt): return None if cpm.is_call_to(stmt.expr, "if_"): cond, true_val, false_val = stmt.expr.args - return [itir.IfStmt( - cond=cond, - # recursively transform - true_branch=transform(itir.SetAt(target=stmt.target, expr=true_val, domain=stmt.domain), declarations, uids), - false_branch=transform(itir.SetAt(target=stmt.target, expr=false_val, domain=stmt.domain), declarations, uids), - )] + return [ + itir.IfStmt( + cond=cond, + # recursively transform + true_branch=transform( + itir.SetAt(target=stmt.target, expr=true_val, domain=stmt.domain), + declarations, + uids, + ), + false_branch=transform( + itir.SetAt(target=stmt.target, expr=false_val, domain=stmt.domain), + declarations, + uids, + ), + ) + ] return None -def transform_by_pattern(stmt: itir.SetAt, predicate, declarations: list[itir.Temporary], uids: eve_utils.UIDGenerator): + +def transform_by_pattern( + stmt: itir.Stmt, predicate, declarations: list[itir.Temporary], uids: eve_utils.UIDGenerator +) -> Optional[list[itir.Stmt]]: if not isinstance(stmt, itir.SetAt): return None new_expr, extracted_fields, _ = cse.extract_subexpression( stmt.expr, predicate=predicate, - uid_generator=uids, - # allows better fusing later on - #deepest_expr_first=True # TODO: better, but not supported right now + uid_generator=eve_utils.UIDGenerator(prefix="__tmp_subexpr"), + # TODO(tehrengruber): extracting the deepest expression first would allow us to fuse + # the extracted expressions resulting in fewer kernel calls and better data locality. + # Extracting multiple expressions deepest-first is, however, not supported right now. + # deepest_expr_first=True # noqa: ERA001 ) if extracted_fields: - new_stmts = [] + tmp_stmts: list[itir.Stmt] = [] + + # for each extracted expression generate: + # - one or more `Temporary` declarations (depending on whether the expression is a field + # or a tuple thereof) + # - one `SetAt` statement that materializes the expression into the temporary for tmp_sym, tmp_expr in extracted_fields.items(): - # TODO: expr domain can not be a tuple here - domain = get_expr_domain(tmp_expr) + domain = tmp_expr.annex.domain + + # TODO(tehrengruber): Implement. This happens when the expression is a combination + # of an `if_` call with a tuple, e.g., `if_(cond, {a, b}, {c, d})`. As long as we are + # able to eliminate all tuples, e.g., by propagating the scalar ifs to the top-level + # of a SetAt, the CollapseTuple pass will eliminate most of these cases.
+ if isinstance(domain, tuple): + flattened_domains: tuple[itir.Expr] = next_utils.flatten_nested_tuple(domain) # type: ignore[assignment] # mypy not smart enough + if not all(d == flattened_domains[0] for d in flattened_domains): + raise NotImplementedError( + "Tuple expressions with different domains is not " "supported yet." + ) + domain = flattened_domains[0] + + assert isinstance(tmp_expr.type, ts.TypeSpec) + tmp_names: str | tuple[str | tuple, ...] = type_info.apply_to_primitive_constituents( + lambda x: uids.sequential_id(), + tmp_expr.type, + tuple_constructor=lambda *elements: tuple(elements), + ) + tmp_dtypes: ts.ScalarType | tuple[ts.ScalarType | tuple, ...] = ( + type_info.apply_to_primitive_constituents( + type_info.extract_dtype, + tmp_expr.type, + tuple_constructor=lambda *elements: tuple(elements), + ) + ) - scalar_type = type_info.apply_to_primitive_constituents( - type_info.extract_dtype, tmp_expr.type + # allocate temporary for all tuple elements + def allocate_temporary(tmp_name: str, dtype: ts.ScalarType, domain: itir.Expr): + declarations.append(itir.Temporary(id=tmp_name, domain=domain, dtype=dtype)) + + next_utils.tree_map(functools.partial(allocate_temporary, domain=domain))( + tmp_names, tmp_dtypes ) - declarations.append(itir.Temporary(id=tmp_sym.id, domain=domain, dtype=scalar_type)) - # TODO: transform not needed if deepest_expr_first=True - new_stmts.extend(transform(itir.SetAt(target=im.ref(tmp_sym.id), domain=domain, expr=tmp_expr), declarations, uids)) + # if the expr is a field this just gives a simple `itir.SymRef`, otherwise we generate a + # `make_tuple` expression. + target_expr: itir.Expr = next_utils.tree_map( + lambda x: im.ref(x), result_collection_constructor=lambda els: im.make_tuple(*els) + )(tmp_names) # type: ignore[assignment] # typing of tree_map does not reflect action of `result_collection_constructor` yet - return [ - *new_stmts, - itir.SetAt( - target=stmt.target, - domain=stmt.domain, - expr=new_expr + # note: the let would be removed automatically by the `cse.extract_subexpression`, but + # we remove it here for readability & debuggability. + new_expr = inline_lambdas.inline_lambda( + im.let(tmp_sym, target_expr)(new_expr), opcount_preserving=False ) - ] + + # TODO: transform not needed if deepest_expr_first=True + tmp_stmts.extend( + transform( + itir.SetAt(target=target_expr, domain=domain, expr=tmp_expr), declarations, uids + ) + ) + + return [*tmp_stmts, itir.SetAt(target=stmt.target, domain=stmt.domain, expr=new_expr)] return None -def transform(stmt: itir.SetAt, declarations: list[itir.Temporary], uids: eve_utils.UIDGenerator): - # TODO: what happens for a trivial let, e.g `let a=as_fieldop() in a end`? 
- unprocessed_stmts = [stmt] - stmts = [] - transforms = [ +def transform( + stmt: itir.Stmt, declarations: list[itir.Temporary], uids: eve_utils.UIDGenerator +) -> list[itir.Stmt]: + unprocessed_stmts: list[itir.Stmt] = [stmt] + stmts: list[itir.Stmt] = [] + + transforms: list[Callable] = [ # transform functional if_ into if-stmt transform_if, # extract applied `as_fieldop` to top-level - functools.partial(transform_by_pattern, predicate=lambda expr, _: cpm.is_applied_as_fieldop(expr)), + functools.partial( + transform_by_pattern, predicate=lambda expr, _: cpm.is_applied_as_fieldop(expr) + ), # extract functional if_ to the top-level - functools.partial(transform_by_pattern, predicate=lambda expr, _: cpm.is_call_to(expr, "if_")), + functools.partial( + transform_by_pattern, predicate=lambda expr, _: cpm.is_call_to(expr, "if_") + ), ] while unprocessed_stmts: @@ -140,23 +166,25 @@ def transform(stmt: itir.SetAt, declarations: list[itir.Temporary], uids: eve_ut return stmts -def create_global_tmps(program: itir.Program): + +def create_global_tmps( + program: itir.Program, offset_provider: common.OffsetProvider +) -> itir.Program: + program = infer_domain.infer_program(program, offset_provider) + program = type_inference.infer(program, offset_provider=offset_provider) + uids = eve_utils.UIDGenerator(prefix="__tmp") - declarations = program.declarations + declarations = program.declarations.copy() new_body = [] for stmt in program.body: - if isinstance(stmt, (itir.SetAt, itir.IfStmt)): - new_body.extend( - transform(stmt, uids=uids, declarations=declarations) - ) - else: - raise NotImplementedError() + assert isinstance(stmt, itir.SetAt) + new_body.extend(transform(stmt, uids=uids, declarations=declarations)) return itir.Program( id=program.id, function_definitions=program.function_definitions, params=program.params, declarations=declarations, - body=new_body - ) \ No newline at end of file + body=new_body, + ) diff --git a/src/gt4py/next/iterator/transforms/infer_domain.py b/src/gt4py/next/iterator/transforms/infer_domain.py index c1a743af1c..7de430aad9 100644 --- a/src/gt4py/next/iterator/transforms/infer_domain.py +++ b/src/gt4py/next/iterator/transforms/infer_domain.py @@ -21,7 +21,7 @@ ir_makers as im, ) from gt4py.next.iterator.transforms import trace_shifts -from gt4py.next.utils import tree_map +from gt4py.next.utils import flatten_nested_tuple, tree_map DOMAIN: TypeAlias = domain_utils.SymbolicDomain | None | tuple["DOMAIN", ...] 
@@ -125,6 +125,10 @@ def _extract_accessed_domains( return typing.cast(ACCESSED_DOMAINS, accessed_domains) +def copy_domain_annex(from_: itir.Expr, to: itir.Expr): + to.annex.domain = from_.annex.domain + + def infer_as_fieldop( applied_fieldop: itir.FunCall, target_domain: DOMAIN, @@ -134,6 +138,9 @@ def infer_as_fieldop( assert cpm.is_call_to(applied_fieldop.fun, "as_fieldop") if target_domain is None: raise ValueError("'target_domain' cannot be 'None'.") + # TODO: needed for scans, try test_solve_triag + if isinstance(target_domain, tuple): + target_domain = _domain_union_with_none(*flatten_nested_tuple(target_domain)) if not isinstance(target_domain, domain_utils.SymbolicDomain): raise ValueError("'target_domain' needs to be a 'domain_utils.SymbolicDomain'.") @@ -157,23 +164,28 @@ def infer_as_fieldop( raise ValueError(f"Unsupported expression of type '{type(in_field)}'.") input_ids.append(id_) - accessed_domains: ACCESSED_DOMAINS = _extract_accessed_domains( + # TODO: note for pr: this dict contains as keys not only the symref inputs, but also + # temporary ids. The symrefs are already added to the result dict by the loop below, while + # the temporary ids should not be in the result anyway. as such do not use this dict + # as the starting point for the domain union in the loop below. + inputs_accessed_domains: ACCESSED_DOMAINS = _extract_accessed_domains( stencil, input_ids, target_domain, offset_provider ) # Recursively infer domain of inputs and update domain arg of nested `as_fieldop`s + accessed_domains: ACCESSED_DOMAINS = {} transformed_inputs: list[itir.Expr] = [] for in_field_id, in_field in zip(input_ids, inputs): transformed_input, accessed_domains_tmp = infer_expr( - in_field, accessed_domains[in_field_id], offset_provider + in_field, inputs_accessed_domains[in_field_id], offset_provider ) transformed_inputs.append(transformed_input) accessed_domains = _merge_domains(accessed_domains, accessed_domains_tmp) - transformed_call = im.as_fieldop(stencil, domain_utils.SymbolicDomain.as_expr(target_domain))( - *transformed_inputs - ) + target_domain_expr = domain_utils.SymbolicDomain.as_expr(target_domain) + transformed_call = im.as_fieldop(stencil, target_domain_expr)(*transformed_inputs) + transformed_call.annex.domain = target_domain_expr accessed_domains_without_tmp = { k: v @@ -219,6 +231,7 @@ def infer_let( for param, call in zip(let_expr.fun.params, transformed_calls_args, strict=True) ) )(transformed_calls_expr) + transformed_call.annex.domain = tree_map(lambda x: x.as_expr() if x else None)(input_domain) return transformed_call, accessed_domains_outer @@ -245,7 +258,9 @@ def infer_make_tuple( infered_arg_expr, actual_domains_arg = infer_expr(arg, domain[i], offset_provider) infered_args_expr.append(infered_arg_expr) actual_domains = _merge_domains(actual_domains, actual_domains_arg) - return im.call(expr.fun)(*infered_args_expr), actual_domains + result_expr = im.call(expr.fun)(*infered_args_expr) + result_expr.annex.domain = tree_map(lambda x: x.as_expr() if x else None)(domain) + return result_expr, actual_domains def infer_tuple_get( @@ -255,13 +270,15 @@ def infer_tuple_get( ) -> tuple[itir.Expr, ACCESSED_DOMAINS]: assert cpm.is_call_to(expr, "tuple_get") actual_domains: ACCESSED_DOMAINS = {} - idx, tuple_arg = expr.args - assert isinstance(idx, itir.Literal) - child_domain = tuple(None if i != int(idx.value) else domain for i in range(int(idx.value) + 1)) - infered_arg_expr, actual_domains_arg = infer_expr(tuple_arg, child_domain, offset_provider) + idx_expr, 
tuple_arg = expr.args + assert isinstance(idx_expr, itir.Literal) + idx = int(idx_expr.value) + tuple_domain = tuple(None if i != idx else domain for i in range(idx + 1)) + infered_arg_expr, actual_domains_arg = infer_expr(tuple_arg, tuple_domain, offset_provider) - infered_args_expr = im.tuple_get(idx.value, infered_arg_expr) + infered_args_expr = im.tuple_get(idx, infered_arg_expr) actual_domains = _merge_domains(actual_domains, actual_domains_arg) + infered_args_expr.annex.domain = tree_map(lambda x: x.as_expr() if x else None)(domain) return infered_args_expr, actual_domains @@ -278,7 +295,9 @@ def infer_if( infered_arg_expr, actual_domains_arg = infer_expr(arg, domain, offset_provider) infered_args_expr.append(infered_arg_expr) actual_domains = _merge_domains(actual_domains, actual_domains_arg) - return im.call(expr.fun)(cond, *infered_args_expr), actual_domains + result_expr = im.call(expr.fun)(cond, *infered_args_expr) + result_expr.annex.domain = tree_map(lambda x: x.as_expr() if x else None)(domain) + return result_expr, actual_domains def infer_expr( diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index 6cd4784062..b3bb7bc6e1 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -18,7 +18,6 @@ from gt4py.next.iterator.transforms.cse import CommonSubexpressionElimination from gt4py.next.iterator.transforms.eta_reduction import EtaReduction from gt4py.next.iterator.transforms.fuse_maps import FuseMaps -from gt4py.next.iterator.transforms.global_tmps import CreateGlobalTmps from gt4py.next.iterator.transforms.inline_center_deref_lift_vars import InlineCenterDerefLiftVars from gt4py.next.iterator.transforms.inline_into_scan import InlineIntoScan from gt4py.next.iterator.transforms.inline_lambdas import InlineLambdas @@ -74,12 +73,14 @@ def apply_common_transforms( common_subexpression_elimination=True, force_inline_lambda_args=False, unconditionally_collapse_tuples=False, + # FIXME[#1582](tehrengruber): Revisit and cleanup after new GTIR temporary pass is in place temporary_extraction_heuristics: Optional[ Callable[[itir.StencilClosure], Callable[[itir.Expr], bool]] ] = None, + # FIXME[#1582](tehrengruber): Revisit and cleanup after new GTIR temporary pass is in place symbolic_domain_sizes: Optional[dict[str, str]] = None, ) -> itir.Program: - if isinstance(ir, (itir.FencilDefinition, FencilWithTemporaries)): + if isinstance(ir, itir.FencilDefinition): ir = fencil_to_program.FencilToProgram().apply( ir ) # FIXME[#1582](havogt): should be removed after refactoring to combined IR @@ -137,29 +138,30 @@ def apply_common_transforms( if lift_mode != LiftMode.FORCE_INLINE: # FIXME[#1582](tehrengruber): implement new temporary pass here raise NotImplementedError() - assert offset_provider is not None - ir = CreateGlobalTmps().visit( - ir, - offset_provider=offset_provider, - extraction_heuristics=temporary_extraction_heuristics, - symbolic_sizes=symbolic_domain_sizes, - ) - - for _ in range(10): - inlined = InlineLifts().visit(ir) - inlined = InlineLambdas.apply( - inlined, opcount_preserving=True, force_inline_lift_args=True - ) - if inlined == ir: - break - ir = inlined - else: - raise RuntimeError("Inlining 'lift' and 'lambdas' did not converge.") - - # If after creating temporaries, the scan is not at the top, we inline. - # The following example doesn't have a lift around the shift, i.e. temporary pass will not extract it. 
- # λ(inp) → scan(λ(state, k, kp) → state + ·k + ·kp, True, 0.0)(inp, ⟪Koffₒ, 1ₒ⟫(inp))` - ir = _inline_into_scan(ir) + # ruff: noqa: ERA001 + # assert offset_provider is not None + # ir = CreateGlobalTmps().visit( + # ir, + # offset_provider=offset_provider, + # extraction_heuristics=temporary_extraction_heuristics, + # symbolic_sizes=symbolic_domain_sizes, + # ) + # + # for _ in range(10): + # inlined = InlineLifts().visit(ir) + # inlined = InlineLambdas.apply( + # inlined, opcount_preserving=True, force_inline_lift_args=True + # ) + # if inlined == ir: + # break + # ir = inlined + # else: + # raise RuntimeError("Inlining 'lift' and 'lambdas' did not converge.") + # + # # If after creating temporaries, the scan is not at the top, we inline. + # # The following example doesn't have a lift around the shift, i.e. temporary pass will not extract it. + # # λ(inp) → scan(λ(state, k, kp) → state + ·k + ·kp, True, 0.0)(inp, ⟪Koffₒ, 1ₒ⟫(inp))` + # ir = _inline_into_scan(ir) # Since `CollapseTuple` relies on the type inference which does not support returning tuples # larger than the number of closure outputs as given by the unconditional collapse, we can diff --git a/src/gt4py/next/iterator/type_system/inference.py b/src/gt4py/next/iterator/type_system/inference.py index c141c80999..288917d281 100644 --- a/src/gt4py/next/iterator/type_system/inference.py +++ b/src/gt4py/next/iterator/type_system/inference.py @@ -19,7 +19,6 @@ from gt4py.next import common from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils.common_pattern_matcher import is_call_to -from gt4py.next.iterator.transforms import global_tmps from gt4py.next.iterator.type_system import type_specifications as it_ts, type_synthesizer from gt4py.next.type_system import type_info, type_specifications as ts from gt4py.next.type_system.type_info import primitive_constituents @@ -282,6 +281,9 @@ def type_synthesizer(*args, **kwargs): class SanitizeTypes(eve.NodeTranslator, eve.VisitorWithSymbolTableTrait): + # TODO: all + PRESERVED_ANNEX_ATTRS = ("domain",) + def visit_Node(self, node: itir.Node, *, symtable: dict[str, itir.Node]) -> itir.Node: node = self.generic_visit(node) # We only want to sanitize types that have been inferred previously such that we don't run @@ -305,6 +307,8 @@ class ITIRTypeInference(eve.NodeTranslator): See :method:ITIRTypeInference.apply for more details. """ + PRESERVED_ANNEX_ATTRS = ("domain",) + offset_provider: common.OffsetProvider #: Mapping from a dimension name to the actual dimension instance. dimensions: dict[str, common.Dimension] @@ -456,28 +460,6 @@ def visit_FencilDefinition(self, node: itir.FencilDefinition, *, ctx) -> it_ts.F closures = self.visit(node.closures, ctx=ctx | params | function_definitions) return it_ts.FencilType(params=params, closures=closures) - # TODO(tehrengruber): Remove after new ITIR format with apply_stencil is used everywhere - def visit_FencilWithTemporaries( - self, node: global_tmps.FencilWithTemporaries, *, ctx - ) -> it_ts.FencilType: - # TODO(tehrengruber): This implementation is not very appealing. Since we are about to - # refactor the IR anyway this is fine for now. 
- params: dict[str, ts.DataType] = {} - for param in node.params: - assert isinstance(param.type, ts.DataType) - params[param.id] = param.type - # infer types of temporary declarations - tmps: dict[str, ts.FieldType] = {} - for tmp_node in node.tmps: - tmps[tmp_node.id] = self.visit(tmp_node, ctx=ctx | params) - # and store them in the inner fencil - for fencil_param in node.fencil.params: - if fencil_param.id in tmps: - fencil_param.type = tmps[fencil_param.id] - self.visit(node.fencil, ctx=ctx) - assert isinstance(node.fencil.type, it_ts.FencilType) - return node.fencil.type - def visit_Program(self, node: itir.Program, *, ctx) -> it_ts.ProgramType: params: dict[str, ts.DataType] = {} for param in node.params: diff --git a/src/gt4py/next/program_processors/runners/gtfn.py b/src/gt4py/next/program_processors/runners/gtfn.py index 9c4e73520f..652440e29f 100644 --- a/src/gt4py/next/program_processors/runners/gtfn.py +++ b/src/gt4py/next/program_processors/runners/gtfn.py @@ -18,7 +18,6 @@ from gt4py.eve.utils import content_hash from gt4py.next import backend, common, config from gt4py.next.iterator import transforms -from gt4py.next.iterator.transforms import global_tmps from gt4py.next.otf import arguments, recipes, stages, workflow from gt4py.next.otf.binding import nanobind from gt4py.next.otf.compilation import compiler @@ -173,8 +172,9 @@ class Params: name_cached="_cached", ) use_temporaries = factory.Trait( + # FIXME[#1582](tehrengruber): Revisit and cleanup after new GTIR temporary pass is in place otf_workflow__translation__lift_mode=transforms.LiftMode.USE_TEMPORARIES, - otf_workflow__translation__temporary_extraction_heuristics=global_tmps.SimpleTemporaryExtractionHeuristics, + # otf_workflow__translation__temporary_extraction_heuristics=global_tmps.SimpleTemporaryExtractionHeuristics, # noqa: ERA001 name_temps="_with_temporaries", ) device_type = core_defs.DeviceType.CPU diff --git a/src/gt4py/next/utils.py b/src/gt4py/next/utils.py index 44fa929e56..7489908ba9 100644 --- a/src/gt4py/next/utils.py +++ b/src/gt4py/next/utils.py @@ -68,7 +68,12 @@ def flatten_nested_tuple( @overload -def tree_map(fun: Callable[_P, _R], /) -> Callable[..., _R | tuple[_R | tuple, ...]]: ... +def tree_map( + fun: Callable[_P, _R], + *, + collection_type: type | tuple[type, ...] = tuple, + result_collection_constructor: Optional[type | Callable] = None, +) -> Callable[..., _R | tuple[_R | tuple, ...]]: ... @overload @@ -82,7 +87,8 @@ def tree_map( def tree_map( - *args: Callable[_P, _R], + fun: Optional[Callable[_P, _R]] = None, + *, collection_type: type | tuple[type, ...] = tuple, result_collection_constructor: Optional[type | Callable] = None, ) -> Callable[..., _R | tuple[_R | tuple, ...]] | Callable[[Callable[_P, _R]], Callable[..., Any]]: @@ -108,6 +114,12 @@ def tree_map( ... [[1, 2], 3] ... ) ((2, 3), 4) + + >>> @tree_map + ... def impl(x): + ... 
return x + 1 + >>> impl(((1, 2), 3)) + ((2, 3), 4) """ if result_collection_constructor is None: @@ -117,8 +129,7 @@ def tree_map( ) result_collection_constructor = collection_type - if len(args) == 1: - fun = args[0] + if fun: @functools.wraps(fun) def impl(*args: Any | tuple[Any | tuple, ...]) -> _R | tuple[_R | tuple, ...]: @@ -129,17 +140,14 @@ def impl(*args: Any | tuple[Any | tuple, ...]) -> _R | tuple[_R | tuple, ...]: assert result_collection_constructor is not None return result_collection_constructor(impl(*arg) for arg in zip(*args)) - return fun( + return fun( # type: ignore[misc] # mypy not smart enough *cast(_P.args, args) ) # mypy doesn't understand that `args` at this point is of type `_P.args` return impl - if len(args) == 0: + else: return functools.partial( tree_map, collection_type=collection_type, result_collection_constructor=result_collection_constructor, ) - raise TypeError( - "tree_map() can be used as decorator with optional kwarg `collection_type` and `result_collection_constructor`." - ) diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py index 5d13337a94..37c63825b2 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py @@ -145,7 +145,9 @@ def translate_domain( shift_list = [item for sublist in shift_tuples for item in sublist] - translated_domain_expr = domain_utils.SymbolicDomain.from_expr(domain).translate(shift_list, offset_provider) + translated_domain_expr = domain_utils.SymbolicDomain.from_expr(domain).translate( + shift_list, offset_provider + ) return constant_fold_domain_exprs(translated_domain_expr.as_expr()) @@ -789,7 +791,10 @@ def test_make_tuple(offset_provider): actual, actual_domains = infer_domain.infer_expr( testee, - (domain_utils.SymbolicDomain.from_expr(domain1), domain_utils.SymbolicDomain.from_expr(domain2)), + ( + domain_utils.SymbolicDomain.from_expr(domain1), + domain_utils.SymbolicDomain.from_expr(domain2), + ), offset_provider, ) @@ -824,7 +829,10 @@ def test_tuple_get_1_nested_make_tuple(offset_provider): actual, actual_domains = infer_domain.infer_expr( testee, - (domain_utils.SymbolicDomain.from_expr(domain1), domain_utils.SymbolicDomain.from_expr(domain2)), + ( + domain_utils.SymbolicDomain.from_expr(domain1), + domain_utils.SymbolicDomain.from_expr(domain2), + ), offset_provider, ) @@ -840,7 +848,9 @@ def test_tuple_get_let_arg_make_tuple(offset_provider): actual, actual_domains = infer_domain.infer_expr( testee, - domain_utils.SymbolicDomain.from_expr(im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)})), + domain_utils.SymbolicDomain.from_expr( + im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) + ), offset_provider, ) @@ -878,7 +888,10 @@ def test_nested_make_tuple(offset_provider): ( ( domain_utils.SymbolicDomain.from_expr(domain1), - (domain_utils.SymbolicDomain.from_expr(domain2_1), domain_utils.SymbolicDomain.from_expr(domain2_2)), + ( + domain_utils.SymbolicDomain.from_expr(domain2_1), + domain_utils.SymbolicDomain.from_expr(domain2_2), + ), ), domain_utils.SymbolicDomain.from_expr(domain3), ), @@ -912,7 +925,10 @@ def test_domain_tuple(offset_provider): actual, actual_domains = infer_domain.infer_expr( testee, - (domain_utils.SymbolicDomain.from_expr(domain1), domain_utils.SymbolicDomain.from_expr(domain2)), + ( + 
domain_utils.SymbolicDomain.from_expr(domain1), + domain_utils.SymbolicDomain.from_expr(domain2), + ), offset_provider, ) @@ -945,7 +961,10 @@ def test_make_tuple_2tuple_get(offset_provider): actual, actual_domains = infer_domain.infer_expr( testee, - (domain_utils.SymbolicDomain.from_expr(domain1), domain_utils.SymbolicDomain.from_expr(domain2)), + ( + domain_utils.SymbolicDomain.from_expr(domain1), + domain_utils.SymbolicDomain.from_expr(domain2), + ), offset_provider, ) diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py index c8a91e037e..23f62842c4 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py @@ -6,14 +6,8 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause -# TODO(tehrengruber): add integration tests for temporaries starting from manually written -# itir. Currently we only test temporaries from frontend code which makes testing changes -# to anything related to temporaries tedious. -import copy from typing import Optional -import gt4py.next as gtx -from gt4py.eve.utils import UIDs from gt4py.next import common from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import ir_makers as im @@ -30,19 +24,21 @@ i_field_type = ts.FieldType(dims=[IDim], dtype=float_type) index_field_type_factory = lambda dim: ts.FieldType(dims=[dim], dtype=index_type) + def program_factory( params: list[itir.Sym], body: list[itir.SetAt], - declarations: Optional[list[itir.Temporary]] = None + declarations: Optional[list[itir.Temporary]] = None, ) -> itir.Program: return itir.Program( id="testee", function_definitions=[], params=params, declarations=declarations or [], - body=body + body=body, ) + def test_trivial(): domain = im.domain("cartesian_domain", {IDim: (0, 1)}) offset_provider = {} @@ -51,37 +47,28 @@ def test_trivial(): body=[ itir.SetAt( target=im.ref("out"), - expr=im.as_fieldop("deref", domain)( - im.as_fieldop("deref", domain)("inp") - ), - domain=domain + expr=im.as_fieldop("deref", domain)(im.as_fieldop("deref", domain)("inp")), + domain=domain, ) - ] + ], ) testee = type_inference.infer(testee, offset_provider=offset_provider) + testee = infer_domain.infer_program(testee, offset_provider=offset_provider) expected = program_factory( params=[im.sym("inp", i_field_type), im.sym("out", i_field_type)], - declarations=[itir.Temporary( - id="__tmp_1", - domain=domain, - dtype=float_type - )], + declarations=[itir.Temporary(id="__tmp_1", domain=domain, dtype=float_type)], body=[ itir.SetAt( - target=im.ref("__tmp_1"), - expr=im.as_fieldop("deref", domain)("inp"), - domain=domain + target=im.ref("__tmp_1"), expr=im.as_fieldop("deref", domain)("inp"), domain=domain ), itir.SetAt( - target=im.ref("out"), - expr=im.as_fieldop("deref", domain)("__tmp_1"), - domain=domain - ) - ] + target=im.ref("out"), expr=im.as_fieldop("deref", domain)("__tmp_1"), domain=domain + ), + ], ) - actual = global_tmps.create_global_tmps(testee) + actual = global_tmps.create_global_tmps(testee, offset_provider) assert actual == expected @@ -94,35 +81,29 @@ def test_trivial_let(): itir.SetAt( target=im.ref("out"), expr=im.let("tmp", im.as_fieldop("deref", domain)("inp"))( - im.as_fieldop("deref", domain)("tmp")), - domain=domain + im.as_fieldop("deref", domain)("tmp") + ), + domain=domain, ) - ] + ], ) 
testee = type_inference.infer(testee, offset_provider=offset_provider) + testee = infer_domain.infer_program(testee, offset_provider=offset_provider) expected = program_factory( params=[im.sym("inp", i_field_type), im.sym("out", i_field_type)], - declarations=[itir.Temporary( - id="__tmp_1", - domain=domain, - dtype=float_type - )], + declarations=[itir.Temporary(id="__tmp_1", domain=domain, dtype=float_type)], body=[ itir.SetAt( - target=im.ref("__tmp_1"), - expr=im.as_fieldop("deref", domain)("inp"), - domain=domain + target=im.ref("__tmp_1"), expr=im.as_fieldop("deref", domain)("inp"), domain=domain ), itir.SetAt( - target=im.ref("out"), - expr=im.as_fieldop("deref", domain)("__tmp_1"), - domain=domain - ) - ] + target=im.ref("out"), expr=im.as_fieldop("deref", domain)("__tmp_1"), domain=domain + ), + ], ) - actual = global_tmps.create_global_tmps(testee) + actual = global_tmps.create_global_tmps(testee, offset_provider) assert actual == expected @@ -130,19 +111,32 @@ def test_top_level_if(): domain = im.domain("cartesian_domain", {IDim: (0, 1)}) offset_provider = {} testee = program_factory( - params=[im.sym("inp1", i_field_type), im.sym("inp2", i_field_type), im.sym("out", i_field_type)], + params=[ + im.sym("inp1", i_field_type), + im.sym("inp2", i_field_type), + im.sym("out", i_field_type), + ], body=[ itir.SetAt( target=im.ref("out"), - expr=im.if_(True, im.as_fieldop("deref", domain)("inp1"), im.as_fieldop("deref", domain)("inp2")), - domain=domain + expr=im.if_( + True, + im.as_fieldop("deref", domain)("inp1"), + im.as_fieldop("deref", domain)("inp2"), + ), + domain=domain, ) - ] + ], ) testee = type_inference.infer(testee, offset_provider=offset_provider) + testee = infer_domain.infer_program(testee, offset_provider=offset_provider) expected = program_factory( - params=[im.sym("inp1", i_field_type), im.sym("inp2", i_field_type), im.sym("out", i_field_type)], + params=[ + im.sym("inp1", i_field_type), + im.sym("inp2", i_field_type), + im.sym("out", i_field_type), + ], declarations=[], body=[ itir.IfStmt( @@ -151,21 +145,21 @@ def test_top_level_if(): itir.SetAt( target=im.ref("out"), expr=im.as_fieldop("deref", domain)("inp1"), - domain=domain + domain=domain, ) ], false_branch=[ itir.SetAt( target=im.ref("out"), expr=im.as_fieldop("deref", domain)("inp2"), - domain=domain + domain=domain, ) - ] + ], ) - ] + ], ) - actual = global_tmps.create_global_tmps(testee) + actual = global_tmps.create_global_tmps(testee, offset_provider) assert actual == expected @@ -173,25 +167,35 @@ def test_nested_if(): domain = im.domain("cartesian_domain", {IDim: (0, 1)}) offset_provider = {} testee = program_factory( - params=[im.sym("inp1", i_field_type), im.sym("inp2", i_field_type), im.sym("out", i_field_type)], + params=[ + im.sym("inp1", i_field_type), + im.sym("inp2", i_field_type), + im.sym("out", i_field_type), + ], body=[ itir.SetAt( target=im.ref("out"), expr=im.as_fieldop("deref", domain)( - im.if_(True, im.as_fieldop("deref", domain)("inp1"), im.as_fieldop("deref", domain)("inp2"))), - domain=domain + im.if_( + True, + im.as_fieldop("deref", domain)("inp1"), + im.as_fieldop("deref", domain)("inp2"), + ) + ), + domain=domain, ) - ] + ], ) testee = type_inference.infer(testee, offset_provider=offset_provider) + testee = infer_domain.infer_program(testee, offset_provider=offset_provider) expected = program_factory( - params=[im.sym("inp1", i_field_type), im.sym("inp2", i_field_type), im.sym("out", i_field_type)], - declarations=[itir.Temporary( - id="__tmp_1", - domain=domain, - 
dtype=float_type - )], + params=[ + im.sym("inp1", i_field_type), + im.sym("inp2", i_field_type), + im.sym("out", i_field_type), + ], + declarations=[itir.Temporary(id="__tmp_1", domain=domain, dtype=float_type)], body=[ itir.IfStmt( cond=im.literal_from_value(True), @@ -199,25 +203,22 @@ def test_nested_if(): itir.SetAt( target=im.ref("__tmp_1"), expr=im.as_fieldop("deref", domain)("inp1"), - domain=domain + domain=domain, ) ], false_branch=[ itir.SetAt( target=im.ref("__tmp_1"), expr=im.as_fieldop("deref", domain)("inp2"), - domain=domain + domain=domain, ) - ] + ], ), itir.SetAt( - target=im.ref("out"), - expr=im.as_fieldop("deref", domain)("__tmp_1"), - domain=domain - ) - ] + target=im.ref("out"), expr=im.as_fieldop("deref", domain)("__tmp_1"), domain=domain + ), + ], ) - actual = global_tmps.create_global_tmps(testee) + actual = global_tmps.create_global_tmps(testee, offset_provider) assert actual == expected - diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_gtfn.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_gtfn.py index 4a7a0e29ca..9b05a01b88 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_gtfn.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_gtfn.py @@ -54,23 +54,6 @@ def test_backend_factory_trait_cached(): assert cached_version.executor.__name__ == "run_gtfn_cpu_cached" -def test_backend_factory_trait_temporaries(): - inline_version = gtfn.GTFNBackendFactory(cached=False) - temps_version = gtfn.GTFNBackendFactory(cached=False, use_temporaries=True) - - assert inline_version.executor.otf_workflow.translation.lift_mode is None - assert ( - temps_version.executor.otf_workflow.translation.lift_mode - is transforms.LiftMode.USE_TEMPORARIES - ) - - assert inline_version.executor.otf_workflow.translation.temporary_extraction_heuristics is None - assert ( - temps_version.executor.otf_workflow.translation.temporary_extraction_heuristics - is global_tmps.SimpleTemporaryExtractionHeuristics - ) - - def test_backend_factory_build_cache_config(monkeypatch): monkeypatch.setattr(config, "BUILD_CACHE_LIFETIME", config.BuildCacheLifetime.SESSION) session_version = gtfn.GTFNBackendFactory() From 61e966581cfc227826ff156ac7e21fd095e2e7b5 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Thu, 10 Oct 2024 13:20:57 +0200 Subject: [PATCH 018/150] Cleanup --- src/gt4py/next/iterator/transforms/cse.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/gt4py/next/iterator/transforms/cse.py b/src/gt4py/next/iterator/transforms/cse.py index 518f294cbb..b59d1cfe78 100644 --- a/src/gt4py/next/iterator/transforms/cse.py +++ b/src/gt4py/next/iterator/transforms/cse.py @@ -43,11 +43,14 @@ def visit_Expr(self, node: itir.Node) -> itir.Node: def visit_FunCall(self, node: itir.FunCall) -> itir.Node: node = cast(itir.FunCall, self.visit_Expr(node)) + # TODO(tehrengruber): This symbol name from the inner expr, to increase readability of IR # If we encounter an expression like: # (λ(_cs_1) → (λ(a) → a+a)(_cs_1))(outer_expr) # (non-recursively) inline the lambda to obtain: # (λ(_cs_1) → _cs_1+_cs_1)(outer_expr) - # This allows identifying more common subexpressions later on + # In the CSE this allows identifying more common subexpressions later on. Other users + # of `extract_subexpression` (e.g. temporary extraction) can also rely on this to avoid + # the need to handle this artificial let-statements. 
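# As a concrete illustration of the rewrite described in the comment above (a minimal sketch:
# `outer_expr` is a hypothetical placeholder, and only the `ir_makers` / `inline_lambdas`
# helpers already used elsewhere in this series are assumed):
from gt4py.next.iterator.ir_utils import ir_makers as im
from gt4py.next.iterator.transforms import inline_lambdas

outer_expr = im.plus("x", "y")  # stands in for an arbitrary extracted argument
# builds (λ(_cs_1) → (λ(a) → a + a)(_cs_1))(outer_expr)
testee = im.let("_cs_1", outer_expr)(im.let("a", "_cs_1")(im.plus("a", "a")))
# non-recursively inlining only the inner lambda yields `_cs_1 + _cs_1`, i.e. the whole
# expression becomes (λ(_cs_1) → _cs_1 + _cs_1)(outer_expr)
inner_inlined = inline_lambdas.inline_lambda(testee.fun.expr, opcount_preserving=False)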
if isinstance(node, itir.FunCall) and isinstance(node.fun, itir.Lambda): eligible_params = [] for arg in node.args: From f9dff502f00865eaea686379bec6186232ed3f7a Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Thu, 10 Oct 2024 13:21:15 +0200 Subject: [PATCH 019/150] Cleanup --- src/gt4py/next/iterator/transforms/cse.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gt4py/next/iterator/transforms/cse.py b/src/gt4py/next/iterator/transforms/cse.py index b59d1cfe78..88c02daa2f 100644 --- a/src/gt4py/next/iterator/transforms/cse.py +++ b/src/gt4py/next/iterator/transforms/cse.py @@ -43,7 +43,7 @@ def visit_Expr(self, node: itir.Node) -> itir.Node: def visit_FunCall(self, node: itir.FunCall) -> itir.Node: node = cast(itir.FunCall, self.visit_Expr(node)) - # TODO(tehrengruber): This symbol name from the inner expr, to increase readability of IR + # TODO(tehrengruber): Use symbol name from the inner let, to increase readability of IR # If we encounter an expression like: # (λ(_cs_1) → (λ(a) → a+a)(_cs_1))(outer_expr) # (non-recursively) inline the lambda to obtain: From 71f512f29930b9615da8d941e0f7540abc095c2d Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Thu, 10 Oct 2024 16:29:10 +0200 Subject: [PATCH 020/150] Cleanup --- .../next/iterator/transforms/infer_domain.py | 2 +- .../transforms_tests/test_cse.py | 12 +++------- .../transforms_tests/test_domain_inference.py | 23 +++++++++++++++++-- 3 files changed, 25 insertions(+), 12 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/infer_domain.py b/src/gt4py/next/iterator/transforms/infer_domain.py index 7de430aad9..47436fec81 100644 --- a/src/gt4py/next/iterator/transforms/infer_domain.py +++ b/src/gt4py/next/iterator/transforms/infer_domain.py @@ -138,7 +138,7 @@ def infer_as_fieldop( assert cpm.is_call_to(applied_fieldop.fun, "as_fieldop") if target_domain is None: raise ValueError("'target_domain' cannot be 'None'.") - # TODO: needed for scans, try test_solve_triag + # FIXME[#1582](tehrengruber): Temporary solution for `tuple_get` on scan result. See `test_solve_triag`. 
if isinstance(target_domain, tuple): target_domain = _domain_union_with_none(*flatten_nested_tuple(target_domain)) if not isinstance(target_domain, domain_utils.SymbolicDomain): diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_cse.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_cse.py index 78f95da8ca..3204b49371 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_cse.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_cse.py @@ -74,16 +74,10 @@ def common_expr(): return im.plus("x", "x") # λ(x) → (λ(y) → y + (x + x + (x + x)))(z) - testee = im.lambda_("x")( - im.call(im.lambda_("y")(im.plus("y", im.plus(common_expr(), common_expr()))))("z") - ) - # λ(x) → (λ(_cs_1) → (λ(y) → y + (_cs_1 + _cs_1))(z))(x + x) + testee = im.lambda_("x")(im.let("y", "z")(im.plus("y", im.plus(common_expr(), common_expr())))) + # λ(x) → (λ(_cs_1) → z + (_cs_1 + _cs_1))(x + x) expected = im.lambda_("x")( - im.call( - im.lambda_("_cs_1")( - im.call(im.lambda_("y")(im.plus("y", im.plus("_cs_1", "_cs_1"))))("z") - ) - )(common_expr()) + im.let("_cs_1", common_expr())(im.plus("z", im.plus("_cs_1", "_cs_1"))) ) actual = CSE.apply(testee, is_local_view=True) assert actual == expected diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py index 37c63825b2..50756f40e7 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py @@ -990,9 +990,9 @@ def test_make_tuple_non_tuple_domain(offset_provider): def test_arithmetic_builtin(offset_provider): - testee = im.plus(im.ref("in_field1"), im.ref("in_field2")) + testee = im.plus(im.ref("alpha"), im.ref("beta")) domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) - expected = im.plus(im.ref("in_field1"), im.ref("in_field2")) + expected = im.plus(im.ref("alpha"), im.ref("beta")) expected_domains = {} actual_call, actual_domains = infer_domain.infer_expr( @@ -1002,3 +1002,22 @@ def test_arithmetic_builtin(offset_provider): assert folded_call == expected assert actual_domains == expected_domains + + +def test_scan(offset_provider): + domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) + testee = im.as_fieldop( + im.call("scan")(im.lambda_("init", "it")(im.deref(im.shift("Ioff", 1)("it"))), True, 0.0) + )("a") + expected = im.as_fieldop( + im.call("scan")(im.lambda_("init", "it")(im.deref(im.shift("Ioff", 1)("it"))), True, 0.0), + domain, + )("a") + + run_test_expr( + testee, + expected, + domain, + {"a": im.domain(common.GridType.CARTESIAN, {IDim: (1, 12)})}, + offset_provider, + ) From 6fb424fd8f55a27962e9c999f5328d69347f5a64 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Thu, 10 Oct 2024 16:30:18 +0200 Subject: [PATCH 021/150] Cleanup --- src/gt4py/next/iterator/transforms/cse.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/cse.py b/src/gt4py/next/iterator/transforms/cse.py index 88c02daa2f..ccc1d2195f 100644 --- a/src/gt4py/next/iterator/transforms/cse.py +++ b/src/gt4py/next/iterator/transforms/cse.py @@ -52,11 +52,7 @@ def visit_FunCall(self, node: itir.FunCall) -> itir.Node: # of `extract_subexpression` (e.g. temporary extraction) can also rely on this to avoid # the need to handle this artificial let-statements. 
if isinstance(node, itir.FunCall) and isinstance(node.fun, itir.Lambda): - eligible_params = [] - for arg in node.args: - eligible_params.append( - isinstance(arg, itir.SymRef) - ) # and arg.id.startswith("_cs")) # TODO: document? this is for lets in the global tmp pass, e.g. test_trivial_let + eligible_params = [isinstance(arg, itir.SymRef) for arg in node.args] if any(eligible_params): # note: the inline is opcount preserving anyway so avoid the additional # effort in the inliner by disabling opcount preservation. From 7d809fb8b076ff7185e0a1bafa87440cd71678ba Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Thu, 10 Oct 2024 16:34:19 +0200 Subject: [PATCH 022/150] Cleanup --- .../next/iterator/transforms/infer_domain.py | 22 ++++++++++--------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/infer_domain.py b/src/gt4py/next/iterator/transforms/infer_domain.py index 47436fec81..f96de32b21 100644 --- a/src/gt4py/next/iterator/transforms/infer_domain.py +++ b/src/gt4py/next/iterator/transforms/infer_domain.py @@ -125,10 +125,6 @@ def _extract_accessed_domains( return typing.cast(ACCESSED_DOMAINS, accessed_domains) -def copy_domain_annex(from_: itir.Expr, to: itir.Expr): - to.annex.domain = from_.annex.domain - - def infer_as_fieldop( applied_fieldop: itir.FunCall, target_domain: DOMAIN, @@ -185,7 +181,6 @@ def infer_as_fieldop( target_domain_expr = domain_utils.SymbolicDomain.as_expr(target_domain) transformed_call = im.as_fieldop(stencil, target_domain_expr)(*transformed_inputs) - transformed_call.annex.domain = target_domain_expr accessed_domains_without_tmp = { k: v @@ -231,7 +226,6 @@ def infer_let( for param, call in zip(let_expr.fun.params, transformed_calls_args, strict=True) ) )(transformed_calls_expr) - transformed_call.annex.domain = tree_map(lambda x: x.as_expr() if x else None)(input_domain) return transformed_call, accessed_domains_outer @@ -259,7 +253,6 @@ def infer_make_tuple( infered_args_expr.append(infered_arg_expr) actual_domains = _merge_domains(actual_domains, actual_domains_arg) result_expr = im.call(expr.fun)(*infered_args_expr) - result_expr.annex.domain = tree_map(lambda x: x.as_expr() if x else None)(domain) return result_expr, actual_domains @@ -278,7 +271,6 @@ def infer_tuple_get( infered_args_expr = im.tuple_get(idx, infered_arg_expr) actual_domains = _merge_domains(actual_domains, actual_domains_arg) - infered_args_expr.annex.domain = tree_map(lambda x: x.as_expr() if x else None)(domain) return infered_args_expr, actual_domains @@ -296,11 +288,10 @@ def infer_if( infered_args_expr.append(infered_arg_expr) actual_domains = _merge_domains(actual_domains, actual_domains_arg) result_expr = im.call(expr.fun)(cond, *infered_args_expr) - result_expr.annex.domain = tree_map(lambda x: x.as_expr() if x else None)(domain) return result_expr, actual_domains -def infer_expr( +def _infer_expr( expr: itir.Expr, domain: DOMAIN, offset_provider: common.OffsetProvider, @@ -329,6 +320,17 @@ def infer_expr( raise ValueError(f"Unsupported expression: {expr}") +def infer_expr( + expr: itir.Expr, + domain: DOMAIN, + offset_provider: common.OffsetProvider, +) -> tuple[itir.Expr, ACCESSED_DOMAINS]: + # this is just a small wrapper that populates the `domain` annex + expr, accessed_domains = _infer_expr(expr, domain, offset_provider) + expr.annex.domain = domain.as_expr() + return expr, accessed_domains + + def infer_program( program: itir.Program, offset_provider: common.OffsetProvider, From c7b79c0c4c848f8cefa37aa3e98b09526071f91e 
Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Thu, 10 Oct 2024 16:39:58 +0200 Subject: [PATCH 023/150] Cleanup --- src/gt4py/next/iterator/transforms/global_tmps.py | 6 ++++++ src/gt4py/next/iterator/type_system/inference.py | 1 - 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/src/gt4py/next/iterator/transforms/global_tmps.py b/src/gt4py/next/iterator/transforms/global_tmps.py index fbff2e8528..2571cbc78a 100644 --- a/src/gt4py/next/iterator/transforms/global_tmps.py +++ b/src/gt4py/next/iterator/transforms/global_tmps.py @@ -170,6 +170,12 @@ def transform( def create_global_tmps( program: itir.Program, offset_provider: common.OffsetProvider ) -> itir.Program: + """ + Given an `itir.Program` create temporaries for intermediate values. + + This pass looks at all `as_fieldop` calls and transforms field-typed subexpressions of its + arguments into temporaries. + """ program = infer_domain.infer_program(program, offset_provider) program = type_inference.infer(program, offset_provider=offset_provider) diff --git a/src/gt4py/next/iterator/type_system/inference.py b/src/gt4py/next/iterator/type_system/inference.py index 5835d58013..2cc28eb45f 100644 --- a/src/gt4py/next/iterator/type_system/inference.py +++ b/src/gt4py/next/iterator/type_system/inference.py @@ -281,7 +281,6 @@ def type_synthesizer(*args, **kwargs): class SanitizeTypes(eve.NodeTranslator, eve.VisitorWithSymbolTableTrait): - # TODO: all PRESERVED_ANNEX_ATTRS = ("domain",) def visit_Node(self, node: itir.Node, *, symtable: dict[str, itir.Node]) -> itir.Node: From f16a961c38a4127c818dbaf7cd91135c02b9a3a3 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Thu, 10 Oct 2024 16:41:27 +0200 Subject: [PATCH 024/150] Cleanup --- .../next/iterator/transforms/global_tmps.py | 46 +++++++++---------- 1 file changed, 23 insertions(+), 23 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/global_tmps.py b/src/gt4py/next/iterator/transforms/global_tmps.py index 2571cbc78a..bdd3e76252 100644 --- a/src/gt4py/next/iterator/transforms/global_tmps.py +++ b/src/gt4py/next/iterator/transforms/global_tmps.py @@ -20,7 +20,7 @@ from gt4py.next.type_system import type_info, type_specifications as ts -def transform_if( +def _transform_stmt_if( stmt: itir.Stmt, declarations: list[itir.Temporary], uids: eve_utils.UIDGenerator ) -> Optional[list[itir.Stmt]]: if not isinstance(stmt, itir.SetAt): @@ -31,13 +31,13 @@ def transform_if( return [ itir.IfStmt( cond=cond, - # recursively transform - true_branch=transform( + # recursively _transform_stmt + true_branch=_transform_stmt( itir.SetAt(target=stmt.target, expr=true_val, domain=stmt.domain), declarations, uids, ), - false_branch=transform( + false_branch=_transform_stmt( itir.SetAt(target=stmt.target, expr=false_val, domain=stmt.domain), declarations, uids, @@ -47,7 +47,7 @@ def transform_if( return None -def transform_by_pattern( +def _transform_stmt_by_pattern( stmt: itir.Stmt, predicate, declarations: list[itir.Temporary], uids: eve_utils.UIDGenerator ) -> Optional[list[itir.Stmt]]: if not isinstance(stmt, itir.SetAt): @@ -119,9 +119,9 @@ def allocate_temporary(tmp_name: str, dtype: ts.ScalarType, domain: itir.Expr): im.let(tmp_sym, target_expr)(new_expr), opcount_preserving=False ) - # TODO: transform not needed if deepest_expr_first=True + # TODO: _transform_stmt not needed if deepest_expr_first=True tmp_stmts.extend( - transform( + _transform_stmt( itir.SetAt(target=target_expr, domain=domain, expr=tmp_expr), declarations, uids ) ) @@ -130,38 +130,38 @@ def 
allocate_temporary(tmp_name: str, dtype: ts.ScalarType, domain: itir.Expr): return None -def transform( +def _transform_stmt( stmt: itir.Stmt, declarations: list[itir.Temporary], uids: eve_utils.UIDGenerator ) -> list[itir.Stmt]: unprocessed_stmts: list[itir.Stmt] = [stmt] stmts: list[itir.Stmt] = [] - transforms: list[Callable] = [ - # transform functional if_ into if-stmt - transform_if, + _transform_stmts: list[Callable] = [ + # _transform_stmt functional if_ into if-stmt + _transform_stmt_if, # extract applied `as_fieldop` to top-level functools.partial( - transform_by_pattern, predicate=lambda expr, _: cpm.is_applied_as_fieldop(expr) + _transform_stmt_by_pattern, predicate=lambda expr, _: cpm.is_applied_as_fieldop(expr) ), # extract functional if_ to the top-level functools.partial( - transform_by_pattern, predicate=lambda expr, _: cpm.is_call_to(expr, "if_") + _transform_stmt_by_pattern, predicate=lambda expr, _: cpm.is_call_to(expr, "if_") ), ] while unprocessed_stmts: stmt = unprocessed_stmts.pop(0) - did_transform = False - for transform in transforms: - transformed_stmts = transform(stmt=stmt, declarations=declarations, uids=uids) - if transformed_stmts: - unprocessed_stmts = [*transformed_stmts, *unprocessed_stmts] - did_transform = True + did_transform_stmt = False + for _transform_stmt in _transform_stmts: + _transform_stmted_stmts = _transform_stmt(stmt=stmt, declarations=declarations, uids=uids) + if _transform_stmted_stmts: + unprocessed_stmts = [*_transform_stmted_stmts, *unprocessed_stmts] + did_transform_stmt = True break - # no transformation occurred - if not did_transform: + # no _transform_stmtation occurred + if not did_transform_stmt: stmts.append(stmt) return stmts @@ -173,7 +173,7 @@ def create_global_tmps( """ Given an `itir.Program` create temporaries for intermediate values. - This pass looks at all `as_fieldop` calls and transforms field-typed subexpressions of its + This pass looks at all `as_fieldop` calls and _transform_stmts field-typed subexpressions of its arguments into temporaries. 
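+
+    Schematically (the names are illustrative, not actual IR from a test), a statement like
+        `out @ dom ← as_fieldop(f, dom)(as_fieldop(g, dom)(inp));`
+    becomes
+        `__tmp_1 @ dom ← as_fieldop(g, dom)(inp);`
+        `out @ dom ← as_fieldop(f, dom)(__tmp_1);`
+    with `__tmp_1` declared as a temporary whose domain and dtype are inferred.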
""" program = infer_domain.infer_program(program, offset_provider) @@ -185,7 +185,7 @@ def create_global_tmps( for stmt in program.body: assert isinstance(stmt, itir.SetAt) - new_body.extend(transform(stmt, uids=uids, declarations=declarations)) + new_body.extend(_transform_stmt(stmt, uids=uids, declarations=declarations)) return itir.Program( id=program.id, From 3e3f9a1b72a7676ff3084eb1d675b8d76c9e064d Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Thu, 10 Oct 2024 16:56:08 +0200 Subject: [PATCH 025/150] Cleanup --- .../next/iterator/transforms/global_tmps.py | 24 ++++++++++++++----- .../next/iterator/transforms/infer_domain.py | 2 +- 2 files changed, 19 insertions(+), 7 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/global_tmps.py b/src/gt4py/next/iterator/transforms/global_tmps.py index bdd3e76252..2c63fd2c3a 100644 --- a/src/gt4py/next/iterator/transforms/global_tmps.py +++ b/src/gt4py/next/iterator/transforms/global_tmps.py @@ -14,7 +14,11 @@ from gt4py.eve import utils as eve_utils from gt4py.next import common, utils as next_utils from gt4py.next.iterator import ir as itir -from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im +from gt4py.next.iterator.ir_utils import ( + common_pattern_matcher as cpm, + domain_utils, + ir_makers as im, +) from gt4py.next.iterator.transforms import cse, infer_domain, inline_lambdas from gt4py.next.iterator.type_system import inference as type_inference from gt4py.next.type_system import type_info, type_specifications as ts @@ -78,12 +82,16 @@ def _transform_stmt_by_pattern( # able to eliminate all tuples, e.g., by propagating the scalar ifs to the top-level # of a SetAt, the CollapseTuple pass will eliminate most of this cases. if isinstance(domain, tuple): - flattened_domains: tuple[itir.Expr] = next_utils.flatten_nested_tuple(domain) # type: ignore[assignment] # mypy not smart enough + flattened_domains: tuple[domain_utils.SymbolicDomain] = ( + next_utils.flatten_nested_tuple(domain) # type: ignore[assignment] # mypy not smart enough + ) if not all(d == flattened_domains[0] for d in flattened_domains): raise NotImplementedError( "Tuple expressions with different domains is not " "supported yet." ) domain = flattened_domains[0] + assert isinstance(domain, domain_utils.SymbolicDomain) + domain_expr = domain.as_expr() assert isinstance(tmp_expr.type, ts.TypeSpec) tmp_names: str | tuple[str | tuple, ...] 
= type_info.apply_to_primitive_constituents( @@ -103,7 +111,7 @@ def _transform_stmt_by_pattern( def allocate_temporary(tmp_name: str, dtype: ts.ScalarType, domain: itir.Expr): declarations.append(itir.Temporary(id=tmp_name, domain=domain, dtype=dtype)) - next_utils.tree_map(functools.partial(allocate_temporary, domain=domain))( + next_utils.tree_map(functools.partial(allocate_temporary, domain=domain_expr))( tmp_names, tmp_dtypes ) @@ -119,10 +127,12 @@ def allocate_temporary(tmp_name: str, dtype: ts.ScalarType, domain: itir.Expr): im.let(tmp_sym, target_expr)(new_expr), opcount_preserving=False ) - # TODO: _transform_stmt not needed if deepest_expr_first=True + # TODO(tehrengruber): _transform_stmt not needed if deepest_expr_first=True tmp_stmts.extend( _transform_stmt( - itir.SetAt(target=target_expr, domain=domain, expr=tmp_expr), declarations, uids + itir.SetAt(target=target_expr, domain=domain_expr, expr=tmp_expr), + declarations, + uids, ) ) @@ -154,7 +164,9 @@ def _transform_stmt( did_transform_stmt = False for _transform_stmt in _transform_stmts: - _transform_stmted_stmts = _transform_stmt(stmt=stmt, declarations=declarations, uids=uids) + _transform_stmted_stmts = _transform_stmt( + stmt=stmt, declarations=declarations, uids=uids + ) if _transform_stmted_stmts: unprocessed_stmts = [*_transform_stmted_stmts, *unprocessed_stmts] did_transform_stmt = True diff --git a/src/gt4py/next/iterator/transforms/infer_domain.py b/src/gt4py/next/iterator/transforms/infer_domain.py index f96de32b21..acf249363c 100644 --- a/src/gt4py/next/iterator/transforms/infer_domain.py +++ b/src/gt4py/next/iterator/transforms/infer_domain.py @@ -327,7 +327,7 @@ def infer_expr( ) -> tuple[itir.Expr, ACCESSED_DOMAINS]: # this is just a small wrapper that populates the `domain` annex expr, accessed_domains = _infer_expr(expr, domain, offset_provider) - expr.annex.domain = domain.as_expr() + expr.annex.domain = domain return expr, accessed_domains From edffd9794740cf2ac02b1261a27f6af9ff473878 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Thu, 10 Oct 2024 16:56:42 +0200 Subject: [PATCH 026/150] Cleanup --- src/gt4py/next/iterator/transforms/global_tmps.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/global_tmps.py b/src/gt4py/next/iterator/transforms/global_tmps.py index 2c63fd2c3a..5f698cfa48 100644 --- a/src/gt4py/next/iterator/transforms/global_tmps.py +++ b/src/gt4py/next/iterator/transforms/global_tmps.py @@ -24,7 +24,7 @@ from gt4py.next.type_system import type_info, type_specifications as ts -def _transform_stmt_if( +def _transform_if( stmt: itir.Stmt, declarations: list[itir.Temporary], uids: eve_utils.UIDGenerator ) -> Optional[list[itir.Stmt]]: if not isinstance(stmt, itir.SetAt): @@ -51,7 +51,7 @@ def _transform_stmt_if( return None -def _transform_stmt_by_pattern( +def _transform_by_pattern( stmt: itir.Stmt, predicate, declarations: list[itir.Temporary], uids: eve_utils.UIDGenerator ) -> Optional[list[itir.Stmt]]: if not isinstance(stmt, itir.SetAt): @@ -148,14 +148,14 @@ def _transform_stmt( _transform_stmts: list[Callable] = [ # _transform_stmt functional if_ into if-stmt - _transform_stmt_if, + _transform_if, # extract applied `as_fieldop` to top-level functools.partial( - _transform_stmt_by_pattern, predicate=lambda expr, _: cpm.is_applied_as_fieldop(expr) + _transform_by_pattern, predicate=lambda expr, _: cpm.is_applied_as_fieldop(expr) ), # extract functional if_ to the top-level functools.partial( - 
_transform_stmt_by_pattern, predicate=lambda expr, _: cpm.is_call_to(expr, "if_") + _transform_by_pattern, predicate=lambda expr, _: cpm.is_call_to(expr, "if_") ), ] From 3196a1152fae26a1caf15389a0fc854422962f17 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Thu, 10 Oct 2024 17:01:12 +0200 Subject: [PATCH 027/150] Cleanup --- .../next/iterator/transforms/global_tmps.py | 31 +++++++++---------- 1 file changed, 14 insertions(+), 17 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/global_tmps.py b/src/gt4py/next/iterator/transforms/global_tmps.py index 5f698cfa48..f80dc74834 100644 --- a/src/gt4py/next/iterator/transforms/global_tmps.py +++ b/src/gt4py/next/iterator/transforms/global_tmps.py @@ -35,7 +35,6 @@ def _transform_if( return [ itir.IfStmt( cond=cond, - # recursively _transform_stmt true_branch=_transform_stmt( itir.SetAt(target=stmt.target, expr=true_val, domain=stmt.domain), declarations, @@ -62,8 +61,8 @@ def _transform_by_pattern( predicate=predicate, uid_generator=eve_utils.UIDGenerator(prefix="__tmp_subexpr"), # TODO(tehrengruber): extracting the deepest expression first would allow us to fuse - # the extracted expressions resulting in fewer kernel calls, better data-locality. - # Extracting the multiple expressions deepest-first is however not supported right now. + # the extracted expressions resulting in fewer kernel calls & better data-locality. + # Extracting multiple expressions deepest-first is however not supported right now. # deepest_expr_first=True # noqa: ERA001 ) @@ -146,14 +145,14 @@ def _transform_stmt( unprocessed_stmts: list[itir.Stmt] = [stmt] stmts: list[itir.Stmt] = [] - _transform_stmts: list[Callable] = [ - # _transform_stmt functional if_ into if-stmt + transforms: list[Callable] = [ + # transform `if_` call into `IfStmt` _transform_if, # extract applied `as_fieldop` to top-level functools.partial( _transform_by_pattern, predicate=lambda expr, _: cpm.is_applied_as_fieldop(expr) ), - # extract functional if_ to the top-level + # extract if_ call to the top-level functools.partial( _transform_by_pattern, predicate=lambda expr, _: cpm.is_call_to(expr, "if_") ), @@ -162,18 +161,16 @@ def _transform_stmt( while unprocessed_stmts: stmt = unprocessed_stmts.pop(0) - did_transform_stmt = False - for _transform_stmt in _transform_stmts: - _transform_stmted_stmts = _transform_stmt( - stmt=stmt, declarations=declarations, uids=uids - ) - if _transform_stmted_stmts: - unprocessed_stmts = [*_transform_stmted_stmts, *unprocessed_stmts] - did_transform_stmt = True + did_transform = False + for transform in transforms: + transformed_stmts = transform(stmt=stmt, declarations=declarations, uids=uids) + if transformed_stmts: + unprocessed_stmts = [*transformed_stmts, *unprocessed_stmts] + did_transform = True break - # no _transform_stmtation occurred - if not did_transform_stmt: + # no transformation occurred + if not did_transform: stmts.append(stmt) return stmts @@ -185,7 +182,7 @@ def create_global_tmps( """ Given an `itir.Program` create temporaries for intermediate values. - This pass looks at all `as_fieldop` calls and _transform_stmts field-typed subexpressions of its + This pass looks at all `as_fieldop` calls and transforms field-typed subexpressions of its arguments into temporaries. 
""" program = infer_domain.infer_program(program, offset_provider) From 04f59dd804a0f59a3d734b9743880ef22657c077 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Thu, 10 Oct 2024 22:20:44 +0200 Subject: [PATCH 028/150] Inline lambda pass: ensure opcount preserving option works whether `itir.SymRef` has a type or not --- src/gt4py/next/iterator/ir.py | 3 ++- .../transforms_tests/test_inline_lambdas.py | 16 ++++++++++++++++ 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/src/gt4py/next/iterator/ir.py b/src/gt4py/next/iterator/ir.py index b2a549501f..2ca460b229 100644 --- a/src/gt4py/next/iterator/ir.py +++ b/src/gt4py/next/iterator/ir.py @@ -40,7 +40,8 @@ def __hash__(self) -> int: return hash(type(self)) ^ hash( tuple( hash(tuple(v)) if isinstance(v, list) else hash(v) - for v in self.iter_children_values() + for (k, v) in self.iter_children_items() + if k not in ["location", "type"] ) ) diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_lambdas.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_lambdas.py index e45281734b..2e0a83d33b 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_lambdas.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_lambdas.py @@ -8,6 +8,7 @@ import pytest +from gt4py.next.type_system import type_specifications as ts from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.iterator.transforms.inline_lambdas import InlineLambdas @@ -39,6 +40,21 @@ ), im.multiplies_(im.plus(2, 1), im.plus("x", "x")), ), + ( + # ensure opcount preserving option works whether `itir.SymRef` has a type or not + "typed_ref", + im.let("a", im.call("opaque")())( + im.plus(im.ref("a", ts.ScalarType(kind=ts.ScalarKind.FLOAT32)), im.ref("a", None)) + ), + { + True: im.let("a", im.call("opaque")())( + im.plus( # stays as is + im.ref("a", ts.ScalarType(kind=ts.ScalarKind.FLOAT32)), im.ref("a", None) + ) + ), + False: im.plus(im.call("opaque")(), im.call("opaque")()), + }, + ), ] From feab64705832fd226ea9066a5d09704850a74e97 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Thu, 10 Oct 2024 22:37:05 +0200 Subject: [PATCH 029/150] Retrigger CI From 376153f57093dcd669fcb91a7073ba6135d8ee31 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Fri, 11 Oct 2024 11:23:06 +0200 Subject: [PATCH 030/150] Address review comments --- .../iterator/transforms/fuse_as_fieldop.py | 78 ++++++++++--------- .../next/iterator/type_system/inference.py | 3 +- 2 files changed, 44 insertions(+), 37 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py index 4be2fa312f..51bbd91d83 100644 --- a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py +++ b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py @@ -21,40 +21,9 @@ from gt4py.next.type_system import type_info, type_specifications as ts -def _inline_as_fieldop_arg(arg: itir.Expr, uids: eve_utils.UIDGenerator): - assert cpm.is_applied_as_fieldop(arg) - arg = _canonicalize_as_fieldop(arg) - - stencil, *_ = arg.fun.args # type: ignore[attr-defined] # ensured by `is_applied_as_fieldop` - inner_args: list[itir.Expr] = arg.args - extracted_args: dict[str, itir.Expr] = {} # mapping from stencil param to arg - - stencil_params: list[itir.Sym] = [] - stencil_body: itir.Expr = stencil.expr - - for inner_param, inner_arg in zip(stencil.params, inner_args, strict=True): - if isinstance(inner_arg, itir.SymRef): - 
stencil_params.append(inner_param) - extracted_args[inner_arg.id] = inner_arg - elif isinstance(inner_arg, itir.Literal): - # note: only literals, not all scalar expressions are required as it doesn't make sense - # for them to be computed per grid point. - stencil_body = im.let(inner_param, im.promote_to_const_iterator(inner_arg))( - stencil_body - ) - else: - # a scalar expression, a previously not inlined `as_fieldop` call or an opaque - # expression e.g. containing a tuple - stencil_params.append(inner_param) - new_outer_stencil_param = uids.sequential_id(prefix="__iasfop") - extracted_args[new_outer_stencil_param] = inner_arg - - return im.lift(im.lambda_(*stencil_params)(stencil_body))( - *extracted_args.keys() - ), extracted_args - - -def _merge_arguments(args1: dict, arg2: dict): +def _merge_arguments( + args1: dict[str, itir.Expr], arg2: dict[str, itir.Expr] +) -> dict[str, itir.Expr]: new_args = {**args1} for stencil_param, stencil_arg in arg2.items(): if stencil_param not in new_args: @@ -65,6 +34,12 @@ def _merge_arguments(args1: dict, arg2: dict): def _canonicalize_as_fieldop(expr: itir.FunCall) -> itir.FunCall: + """ + Canonicalize applied `as_fieldop`s. + + In case the stencil argument is a `deref` wrap it into a lambda such that we have a unified + format to work with (e.g. each parameter has a name without the need to special case). + """ assert cpm.is_applied_as_fieldop(expr) stencil = expr.fun.args[0] # type: ignore[attr-defined] @@ -109,6 +84,38 @@ class FuseAsFieldOp(eve.NodeTranslator): uids: eve_utils.UIDGenerator + def _inline_as_fieldop_arg(self, arg: itir.Expr) -> tuple[itir.Expr, dict[str, itir.Expr]]: + assert cpm.is_applied_as_fieldop(arg) + arg = _canonicalize_as_fieldop(arg) + + stencil, *_ = arg.fun.args # type: ignore[attr-defined] # ensured by `is_applied_as_fieldop` + inner_args: list[itir.Expr] = arg.args + extracted_args: dict[str, itir.Expr] = {} # mapping from outer-stencil param to arg + + stencil_params: list[itir.Sym] = [] + stencil_body: itir.Expr = stencil.expr + + for inner_param, inner_arg in zip(stencil.params, inner_args, strict=True): + if isinstance(inner_arg, itir.SymRef): + stencil_params.append(inner_param) + extracted_args[inner_arg.id] = inner_arg + elif isinstance(inner_arg, itir.Literal): + # note: only literals, not all scalar expressions are required as it doesn't make sense + # for them to be computed per grid point. + stencil_body = im.let(inner_param, im.promote_to_const_iterator(inner_arg))( + stencil_body + ) + else: + # a scalar expression, a previously not inlined `as_fieldop` call or an opaque + # expression e.g. 
containing a tuple + stencil_params.append(inner_param) + new_outer_stencil_param = self.uids.sequential_id(prefix="__iasfop") + extracted_args[new_outer_stencil_param] = inner_arg + + return im.lift(im.lambda_(*stencil_params)(stencil_body))( + *extracted_args.keys() + ), extracted_args + @classmethod def apply( cls, @@ -166,7 +173,7 @@ def visit_FunCall(self, node: itir.FunCall): else: raise NotImplementedError() - inline_expr, extracted_args = _inline_as_fieldop_arg(arg, self.uids) + inline_expr, extracted_args = self._inline_as_fieldop_arg(arg) new_stencil_body = im.let(stencil_param, inline_expr)(new_stencil_body) @@ -192,5 +199,6 @@ def visit_FunCall(self, node: itir.FunCall): *new_args.values() ) type_inference.copy_type(from_=node, to=new_node) + return new_node return node diff --git a/src/gt4py/next/iterator/type_system/inference.py b/src/gt4py/next/iterator/type_system/inference.py index 41f1670dcd..fccaa56232 100644 --- a/src/gt4py/next/iterator/type_system/inference.py +++ b/src/gt4py/next/iterator/type_system/inference.py @@ -96,7 +96,7 @@ def _set_node_type(node: itir.Node, type_: ts.TypeSpec) -> None: node.type = type_ -def copy_type(from_: itir.Node, to: itir.Node) -> itir.Node: +def copy_type(from_: itir.Node, to: itir.Node) -> None: """ Copy type from one node to another. @@ -104,7 +104,6 @@ def copy_type(from_: itir.Node, to: itir.Node) -> itir.Node: """ assert isinstance(from_.type, ts.TypeSpec) _set_node_type(to, from_.type) - return to def on_inferred(callback: Callable, *args: Union[ts.TypeSpec, ObservableTypeSynthesizer]) -> None: From 52a1b908f958c84d2da30412f8e4b682190f7b16 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Mon, 14 Oct 2024 18:23:01 +0200 Subject: [PATCH 031/150] Allow type inference without domain argument to `as_fieldop` --- .../next/iterator/type_system/inference.py | 4 ++- .../type_system/type_specifications.py | 2 +- .../iterator/type_system/type_synthesizer.py | 27 ++++++++++++++++--- .../iterator_tests/test_type_inference.py | 13 +++++++++ 4 files changed, 40 insertions(+), 6 deletions(-) diff --git a/src/gt4py/next/iterator/type_system/inference.py b/src/gt4py/next/iterator/type_system/inference.py index bc1095dfb8..52661fb42c 100644 --- a/src/gt4py/next/iterator/type_system/inference.py +++ b/src/gt4py/next/iterator/type_system/inference.py @@ -494,9 +494,11 @@ def visit_Program(self, node: itir.Program, *, ctx) -> it_ts.ProgramType: def visit_Temporary(self, node: itir.Temporary, *, ctx) -> ts.FieldType | ts.TupleType: domain = self.visit(node.domain, ctx=ctx) assert isinstance(domain, it_ts.DomainType) + assert domain.dims != "unknown" assert node.dtype return type_info.apply_to_primitive_constituents( - lambda dtype: ts.FieldType(dims=domain.dims, dtype=dtype), node.dtype + lambda dtype: ts.FieldType(dims=domain.dims, dtype=dtype), # type: ignore[arg-type] # ensured by domain.dims != "unknown" above + node.dtype, ) def visit_IfStmt(self, node: itir.IfStmt, *, ctx) -> None: diff --git a/src/gt4py/next/iterator/type_system/type_specifications.py b/src/gt4py/next/iterator/type_system/type_specifications.py index cfe3987b8c..94a174dca4 100644 --- a/src/gt4py/next/iterator/type_system/type_specifications.py +++ b/src/gt4py/next/iterator/type_system/type_specifications.py @@ -20,7 +20,7 @@ class NamedRangeType(ts.TypeSpec): @dataclasses.dataclass(frozen=True) class DomainType(ts.DataType): - dims: list[common.Dimension] + dims: list[common.Dimension] | Literal["unknown"] @dataclasses.dataclass(frozen=True) diff --git 
a/src/gt4py/next/iterator/type_system/type_synthesizer.py b/src/gt4py/next/iterator/type_system/type_synthesizer.py index 77cd39389a..4f1e05bdfd 100644 --- a/src/gt4py/next/iterator/type_system/type_synthesizer.py +++ b/src/gt4py/next/iterator/type_system/type_synthesizer.py @@ -271,17 +271,36 @@ def _convert_as_fieldop_input_to_iterator( @_register_builtin_type_synthesizer def as_fieldop( - stencil: TypeSynthesizer, domain: it_ts.DomainType, offset_provider: common.OffsetProvider + stencil: TypeSynthesizer, + domain: Optional[it_ts.DomainType] = None, + *, + offset_provider: common.OffsetProvider, ) -> TypeSynthesizer: + # In case we don't have a domain argument to `as_fieldop` we can not infer the exact result + # type. In order to still allow some passes to run before the domain inference which don't + # need this information we continue with a dummy domain. One example is the CollapseTuple + # pass which only needs information about the structure, e.g. how many tuple elements does + # this node have, but not the dimensions of a field. + # Note that it might appear as if using the TraceShift pass would allow us to deduce the + # return type of `as_fieldop` without a domain, but this is not the case, since we don't + # have information on the ordering of dimensions. In this example + # `as_fieldop(it1, it2 -> deref(it1) + deref(it2))(i_field, j_field)` + # it is unclear if the result has dimenion I, J or J, I. + if domain is None: + domain = it_ts.DomainType(dims="unknown") + @TypeSynthesizer - def applied_as_fieldop(*fields) -> ts.FieldType: + def applied_as_fieldop(*fields) -> ts.FieldType | ts.DeferredType: stencil_return = stencil( *(_convert_as_fieldop_input_to_iterator(domain, field) for field in fields), offset_provider=offset_provider, ) assert isinstance(stencil_return, ts.DataType) return type_info.apply_to_primitive_constituents( - lambda el_type: ts.FieldType(dims=domain.dims, dtype=el_type), stencil_return + lambda el_type: ts.FieldType(dims=domain.dims, dtype=el_type) + if domain.dims != "unknown" + else ts.DeferredType(constraint=ts.FieldType), + stencil_return, ) return applied_as_fieldop @@ -329,7 +348,7 @@ def applied_reduce(*args: it_ts.ListType, offset_provider: common.OffsetProvider @_register_builtin_type_synthesizer -def shift(*offset_literals, offset_provider) -> TypeSynthesizer: +def shift(*offset_literals, offset_provider: common.OffsetProvider) -> TypeSynthesizer: @TypeSynthesizer def apply_shift( it: it_ts.IteratorType | ts.DeferredType, diff --git a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py index 05cd6b6854..20a1d7e9b7 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py @@ -478,3 +478,16 @@ def test_if_stmt(): result = itir_type_inference.infer(testee, offset_provider={}, allow_undeclared_symbols=True) assert result.cond.type == bool_type assert result.true_branch[0].expr.type == float_i_field + + +def test_as_fieldop_without_domain(): + testee = im.as_fieldop(im.lambda_("it")(im.deref(im.shift("IOff", 1)("it"))))( + im.ref("inp", float_i_field) + ) + result = itir_type_inference.infer( + testee, offset_provider={"IOff": IDim}, allow_undeclared_symbols=True + ) + assert result.type == ts.DeferredType(constraint=ts.FieldType) + assert result.fun.args[0].type.pos_only_args[0] == it_ts.IteratorType( + position_dims="unknown", defined_dims=float_i_field.dims, 
element_type=float_i_field.dtype + ) From a1b4448c5b8af14e8b6801d1a777bba6b436f9e1 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Mon, 14 Oct 2024 18:26:06 +0200 Subject: [PATCH 032/150] Cleanup --- .../iterator/type_system/type_synthesizer.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/src/gt4py/next/iterator/type_system/type_synthesizer.py b/src/gt4py/next/iterator/type_system/type_synthesizer.py index 4f1e05bdfd..c836de1391 100644 --- a/src/gt4py/next/iterator/type_system/type_synthesizer.py +++ b/src/gt4py/next/iterator/type_system/type_synthesizer.py @@ -277,15 +277,15 @@ def as_fieldop( offset_provider: common.OffsetProvider, ) -> TypeSynthesizer: # In case we don't have a domain argument to `as_fieldop` we can not infer the exact result - # type. In order to still allow some passes to run before the domain inference which don't - # need this information we continue with a dummy domain. One example is the CollapseTuple - # pass which only needs information about the structure, e.g. how many tuple elements does - # this node have, but not the dimensions of a field. - # Note that it might appear as if using the TraceShift pass would allow us to deduce the - # return type of `as_fieldop` without a domain, but this is not the case, since we don't - # have information on the ordering of dimensions. In this example - # `as_fieldop(it1, it2 -> deref(it1) + deref(it2))(i_field, j_field)` - # it is unclear if the result has dimenion I, J or J, I. + # type. In order to still allow some passes which don't need this information to run before the + # domain inference, we continue with a dummy domain. One example is the CollapseTuple pass + # which only needs information about the structure, e.g. how many tuple elements does this node + # have, but not the dimensions of a field. + # Note that it might appear as if using the TraceShift pass would allow us to deduce the return + # type of `as_fieldop` without a domain, but this is not the case, since we don't have + # information on the ordering of dimensions. In this example + # `as_fieldop(it1, it2 -> deref(it1) + deref(it2))(i_field, j_field)` + # it is unclear if the result has dimension I, J or J, I. if domain is None: domain = it_ts.DomainType(dims="unknown") From 5b86f19c5484bc7760b26e6ead3cf6cf7f84ed93 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Tue, 15 Oct 2024 15:30:22 +0200 Subject: [PATCH 033/150] Remove comment --- src/gt4py/next/iterator/transforms/infer_domain.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/infer_domain.py b/src/gt4py/next/iterator/transforms/infer_domain.py index acf249363c..2a85e6f2cf 100644 --- a/src/gt4py/next/iterator/transforms/infer_domain.py +++ b/src/gt4py/next/iterator/transforms/infer_domain.py @@ -160,10 +160,6 @@ def infer_as_fieldop( raise ValueError(f"Unsupported expression of type '{type(in_field)}'.") input_ids.append(id_) - # TODO: note for pr: this dict contains as keys not only the symref inputs, but also - # temporary ids. The symrefs are already added to the result dict by the loop below, while - # the temporary ids should not be in the result anyway. as such do not use this dict - # as the starting point for the domain union in the loop below. 
inputs_accessed_domains: ACCESSED_DOMAINS = _extract_accessed_domains( stencil, input_ids, target_domain, offset_provider ) From 0e50214640848e4f9f72402d35848097c944f18b Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Wed, 16 Oct 2024 01:01:10 +0200 Subject: [PATCH 034/150] Use GTIR in embedded and gtfn --- src/gt4py/next/backend.py | 13 +- src/gt4py/next/config.py | 2 +- src/gt4py/next/ffront/decorator.py | 13 +- src/gt4py/next/ffront/foast_to_gtir.py | 35 +++- src/gt4py/next/ffront/foast_to_past.py | 9 +- src/gt4py/next/ffront/gtcallable.py | 11 ++ src/gt4py/next/ffront/past_to_itir.py | 17 +- .../next/iterator/ir_utils/domain_utils.py | 21 ++- src/gt4py/next/iterator/ir_utils/ir_makers.py | 5 +- .../next/iterator/transforms/__init__.py | 4 +- .../iterator/transforms/collapse_list_get.py | 9 + .../iterator/transforms/collapse_tuple.py | 7 + src/gt4py/next/iterator/transforms/cse.py | 19 +- .../iterator/transforms/fuse_as_fieldop.py | 1 + .../next/iterator/transforms/global_tmps.py | 8 +- .../next/iterator/transforms/infer_domain.py | 51 +++-- .../iterator/transforms/inline_into_scan.py | 2 +- .../next/iterator/transforms/pass_manager.py | 165 ++++------------ .../transforms/pass_manager_legacy.py | 176 ++++++++++++++++++ .../next/iterator/transforms/unroll_reduce.py | 9 +- .../codegens/gtfn/gtfn_ir.py | 56 ++++-- .../codegens/gtfn/gtfn_module.py | 10 +- .../codegens/gtfn/itir_to_gtfn_ir.py | 27 ++- .../program_processors/formatters/lisp.py | 4 +- .../next/program_processors/runners/dace.py | 8 +- .../runners/dace_iterator/__init__.py | 13 +- .../runners/dace_iterator/workflow.py | 3 - .../next/program_processors/runners/gtfn.py | 9 - .../program_processors/runners/roundtrip.py | 26 ++- tests/next_tests/definitions.py | 9 +- .../ffront_tests/ffront_test_utils.py | 3 +- .../ffront_tests/test_execution.py | 9 - .../ffront_tests/test_scalar_if.py | 1 + .../test_temporaries_with_sizes.py | 18 +- .../transforms_tests/test_domain_inference.py | 25 ++- 35 files changed, 541 insertions(+), 257 deletions(-) create mode 100644 src/gt4py/next/iterator/transforms/pass_manager_legacy.py diff --git a/src/gt4py/next/backend.py b/src/gt4py/next/backend.py index 0340d61f89..017b7324cc 100644 --- a/src/gt4py/next/backend.py +++ b/src/gt4py/next/backend.py @@ -15,6 +15,7 @@ from gt4py._core import definitions as core_defs from gt4py.next import allocators as next_allocators from gt4py.next.ffront import ( + foast_to_gtir, foast_to_itir, foast_to_past, func_to_foast, @@ -76,7 +77,7 @@ class Transforms(workflow.MultiWorkflow[INPUT_PAIR, stages.CompilableProgram]): ) foast_to_itir: workflow.Workflow[AOT_FOP, itir.Expr] = dataclasses.field( - default_factory=foast_to_itir.adapted_foast_to_itir_factory + default_factory=foast_to_gtir.adapted_foast_to_gtir_factory ) field_view_op_to_prog: workflow.Workflow[AOT_FOP, AOT_PRG] = dataclasses.field( @@ -134,6 +135,16 @@ def step_order(self, inp: INPUT_PAIR) -> list[str]: DEFAULT_TRANSFORMS: Transforms = Transforms() +# FIXME[#1582](havogt): remove after refactoring to GTIR +_foast_to_itir_step = foast_to_itir.adapted_foast_to_itir_factory(cached=True) +LEGACY_TRANSFORMS: Transforms = Transforms( + past_to_itir=past_to_itir.past_to_itir_factory(to_gtir=False), + foast_to_itir=_foast_to_itir_step, + field_view_op_to_prog=foast_to_past.operator_to_program_factory( + foast_to_itir_step=_foast_to_itir_step + ), +) + # TODO(tehrengruber): Rename class and `executor` & `transforms` attribute. 
Maybe: # `Backend` -> `Toolchain` diff --git a/src/gt4py/next/config.py b/src/gt4py/next/config.py index ed244c2932..4f53e3c535 100644 --- a/src/gt4py/next/config.py +++ b/src/gt4py/next/config.py @@ -56,7 +56,7 @@ def env_flag_to_bool(name: str, default: bool) -> bool: #: Master debug flag #: Changes defaults for all the other options to be as helpful for debugging as possible. #: Does not override values set in environment variables. -DEBUG: Final[bool] = env_flag_to_bool(f"{_PREFIX}_DEBUG", default=False) +DEBUG: Final[bool] = env_flag_to_bool(f"{_PREFIX}_DEBUG", default=True) #: Verbose flag for DSL compilation errors diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index 52fe8d8116..dc2421e1d2 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -34,6 +34,8 @@ from gt4py.next.embedded import operators as embedded_operators from gt4py.next.ffront import ( field_operator_ast as foast, + foast_to_gtir, + foast_to_itir, past_process_args, signature, stages as ffront_stages, @@ -560,10 +562,15 @@ def with_grid_type(self, grid_type: GridType) -> FieldOperator: self, definition_stage=dataclasses.replace(self.definition_stage, grid_type=grid_type) ) + # TODO(tehrengruber): We can not use transforms from `self.backend` since this can be + # a different backend than the one of the program that calls this field operator. Just use + # the hard-coded lowering until this is cleaned up. def __gt_itir__(self) -> itir.FunctionDefinition: - return self._frontend_transforms.foast_to_itir( - toolchain.CompilableProgram(self.foast_stage, arguments.CompileTimeArgs.empty()) - ) + return foast_to_itir.foast_to_itir(self.foast_stage) + + # FIXME[#1582](tehrengruber): remove after refactoring to GTIR + def __gt_gtir__(self) -> itir.FunctionDefinition: + return foast_to_gtir.foast_to_gtir(self.foast_stage) def __gt_closure_vars__(self) -> dict[str, Any]: return self.foast_stage.closure_vars diff --git a/src/gt4py/next/ffront/foast_to_gtir.py b/src/gt4py/next/ffront/foast_to_gtir.py index 948a8481d7..5a27fe380f 100644 --- a/src/gt4py/next/ffront/foast_to_gtir.py +++ b/src/gt4py/next/ffront/foast_to_gtir.py @@ -116,7 +116,31 @@ def visit_FieldOperator( def visit_ScanOperator( self, node: foast.ScanOperator, **kwargs: Any ) -> itir.FunctionDefinition: - raise NotImplementedError("TODO") + # note: we don't need the axis here as this is handled by the program + # decorator + assert isinstance(node.type, ts_ffront.ScanOperatorType) + + # We are lowering node.forward and node.init to iterators, but here we expect values -> `deref`. + # In iterator IR we didn't properly specify if this is legal, + # however after lift-inlining the expressions are transformed back to literals. 
+ forward = self.visit(node.forward, **kwargs) + init = self.visit(node.init, **kwargs) + + # lower definition function + func_definition: itir.FunctionDefinition = self.visit(node.definition, **kwargs) + new_body = func_definition.expr + + stencil_args: list[itir.Expr] = [] + assert not node.type.definition.pos_only_args and not node.type.definition.kw_only_args + for param in func_definition.params[1:]: + new_body = im.let(param.id, im.deref(param.id))(new_body) + stencil_args.append(im.ref(param.id)) + + definition = itir.Lambda(params=func_definition.params, expr=new_body) + + body = im.as_fieldop(im.call("scan")(definition, forward, init))(*stencil_args) + + return itir.FunctionDefinition(id=node.id, params=definition.params[1:], expr=body) def visit_Stmt(self, node: foast.Stmt, **kwargs: Any) -> Never: raise AssertionError("Statements must always be visited in the context of a function.") @@ -324,10 +348,6 @@ def visit_Call(self, node: foast.Call, **kwargs: Any) -> itir.Expr: *lowered_args, *lowered_kwargs.values() ) - # scan operators return an iterator of tuples, transform into tuples of iterator again - if isinstance(node.func.type, ts_ffront.ScanOperatorType): - raise NotImplementedError("TODO") - return result raise AssertionError( @@ -373,7 +393,10 @@ def create_if(true_: itir.Expr, false_: itir.Expr) -> itir.FunCall: _visit_concat_where = _visit_where # TODO(havogt): upgrade concat_where def _visit_broadcast(self, node: foast.Call, **kwargs: Any) -> itir.FunCall: - return self.visit(node.args[0], **kwargs) + expr = self.visit(node.args[0], **kwargs) + if type_info.is_type_or_tuple_of_type(node.args[0].type, ts.ScalarType): + return im.as_fieldop(im.ref("deref"))(expr) + return expr def _visit_math_built_in(self, node: foast.Call, **kwargs: Any) -> itir.FunCall: return self._map(self.visit(node.func, **kwargs), *node.args) diff --git a/src/gt4py/next/ffront/foast_to_past.py b/src/gt4py/next/ffront/foast_to_past.py index 0844f63286..a3b6c00ffa 100644 --- a/src/gt4py/next/ffront/foast_to_past.py +++ b/src/gt4py/next/ffront/foast_to_past.py @@ -12,7 +12,7 @@ from gt4py.eve import utils as eve_utils from gt4py.next.ffront import ( dialect_ast_enums, - foast_to_itir, + foast_to_gtir, program_ast as past, stages as ffront_stages, type_specifications as ts_ffront, @@ -45,6 +45,11 @@ def __gt_type__(self) -> ts.CallableType: def __gt_itir__(self) -> itir.Expr: return self.foast_to_itir(self.definition) + # FIXME[#1582](tehrengruber): remove after refactoring to GTIR + def __gt_gtir__(self) -> itir.Expr: + # backend should have self.foast_to_itir set to foast_to_gtir + return self.foast_to_itir(self.definition) + @dataclasses.dataclass(frozen=True) class OperatorToProgram(workflow.Workflow[AOT_FOP, AOT_PRG]): @@ -164,7 +169,7 @@ def operator_to_program_factory( ) -> workflow.Workflow[AOT_FOP, AOT_PRG]: """Optionally wrap `OperatorToProgram` in a `CachedStep`.""" wf: workflow.Workflow[AOT_FOP, AOT_PRG] = OperatorToProgram( - foast_to_itir_step or foast_to_itir.adapted_foast_to_itir_factory() + foast_to_itir_step or foast_to_gtir.adapted_foast_to_gtir_factory() ) if cached: wf = workflow.CachedStep(wf, hash_function=ffront_stages.fingerprint_stage) diff --git a/src/gt4py/next/ffront/gtcallable.py b/src/gt4py/next/ffront/gtcallable.py index beaebb3a5a..072890a5d1 100644 --- a/src/gt4py/next/ffront/gtcallable.py +++ b/src/gt4py/next/ffront/gtcallable.py @@ -52,6 +52,17 @@ def __gt_itir__(self) -> itir.FunctionDefinition: """ ... 
+ # FIXME[#1582](tehrengruber): remove after refactoring to GTIR + @abc.abstractmethod + def __gt_gtir__(self) -> itir.FunctionDefinition: + """ + Return iterator IR function definition representing the callable. + + Used internally by the Program decorator to populate the function + definitions of the iterator IR. + """ + ... + # TODO(tehrengruber): For embedded execution a `__call__` method and for # "truly" embedded execution arguably also a `from_function` method is # required. Since field operators currently have a `__gt_type__` with a diff --git a/src/gt4py/next/ffront/past_to_itir.py b/src/gt4py/next/ffront/past_to_itir.py index a20c517cce..c0348bb5c6 100644 --- a/src/gt4py/next/ffront/past_to_itir.py +++ b/src/gt4py/next/ffront/past_to_itir.py @@ -80,11 +80,18 @@ def past_to_itir(inp: AOT_PRG, to_gtir: bool = False) -> stages.CompilableProgra gt_callables = transform_utils._filter_closure_vars_by_type( all_closure_vars, gtcallable.GTCallable ).values() + + # FIXME[#1582](tehrengruber): remove after refactoring to GTIR # TODO(ricoh): The following calls to .__gt_itir__, which will use whatever - # backend is set for each of these field operators (GTCallables). Instead - # we should use the current toolchain to lower these to ITIR. This will require - # making this step aware of the toolchain it is called by (it can be part of multiple). - lowered_funcs = [gt_callable.__gt_itir__() for gt_callable in gt_callables] + # backend is set for each of these field operators (GTCallables). Instead + # we should use the current toolchain to lower these to ITIR. This will require + # making this step aware of the toolchain it is called by (it can be part of multiple). + lowered_funcs = [] + for gt_callable in gt_callables: + if to_gtir: + lowered_funcs.append(gt_callable.__gt_gtir__()) + else: + lowered_funcs.append(gt_callable.__gt_itir__()) itir_program = ProgramLowering.apply( inp.data.past_node, function_definitions=lowered_funcs, grid_type=grid_type, to_gtir=to_gtir @@ -101,7 +108,7 @@ def past_to_itir(inp: AOT_PRG, to_gtir: bool = False) -> stages.CompilableProgra # FIXME[#1582](havogt): remove `to_gtir` arg after refactoring to GTIR def past_to_itir_factory( - cached: bool = True, to_gtir: bool = False + cached: bool = True, to_gtir: bool = True ) -> workflow.Workflow[AOT_PRG, stages.CompilableProgram]: wf = workflow.make_step(functools.partial(past_to_itir, to_gtir=to_gtir)) if cached: diff --git a/src/gt4py/next/iterator/ir_utils/domain_utils.py b/src/gt4py/next/iterator/ir_utils/domain_utils.py index 8eec405136..9e8729b339 100644 --- a/src/gt4py/next/iterator/ir_utils/domain_utils.py +++ b/src/gt4py/next/iterator/ir_utils/domain_utils.py @@ -10,7 +10,7 @@ import dataclasses import functools -from typing import Any, Literal, Mapping +from typing import Any, Literal, Mapping, Optional import gt4py.next as gtx from gt4py.next import common @@ -93,6 +93,7 @@ def translate( ..., ], offset_provider: common.OffsetProvider, + symbolic_domain_sizes: Optional[dict[str, str]], ) -> SymbolicDomain: dims = list(self.ranges.keys()) new_ranges = {dim: self.ranges[dim] for dim in dims} @@ -119,18 +120,24 @@ def translate( trace_shifts.Sentinel.ALL_NEIGHBORS, trace_shifts.Sentinel.VALUE, ] - # note: ugly but cheap re-computation, but should disappear - horizontal_sizes = _max_domain_sizes_by_location_type(offset_provider) + horizontal_sizes: dict[str, itir.Expr] + if symbolic_domain_sizes is not None: + horizontal_sizes = {k: im.ref(v) for k, v in symbolic_domain_sizes.items()} + else: + # note: ugly 
but cheap re-computation, but should disappear + horizontal_sizes = { + k: im.literal(str(v), itir.INTEGER_INDEX_BUILTIN) + for k, v in _max_domain_sizes_by_location_type(offset_provider).items() + } old_dim = nbt_provider.origin_axis new_dim = nbt_provider.neighbor_axis assert new_dim not in new_ranges or old_dim == new_dim - # TODO(tehrengruber): Do we need symbolic sizes, e.g., for ICON? new_range = SymbolicRange( im.literal("0", itir.INTEGER_INDEX_BUILTIN), - im.literal(str(horizontal_sizes[new_dim.value]), itir.INTEGER_INDEX_BUILTIN), + horizontal_sizes[new_dim.value], ) new_ranges = dict( (dim, range_) if dim != old_dim else (new_dim, new_range) @@ -140,7 +147,9 @@ def translate( raise AssertionError() return SymbolicDomain(self.grid_type, new_ranges) elif len(shift) > 2: - return self.translate(shift[0:2], offset_provider).translate(shift[2:], offset_provider) + return self.translate(shift[0:2], offset_provider, symbolic_domain_sizes).translate( + shift[2:], offset_provider, symbolic_domain_sizes + ) else: raise AssertionError("Number of shifts must be a multiple of 2.") diff --git a/src/gt4py/next/iterator/ir_utils/ir_makers.py b/src/gt4py/next/iterator/ir_utils/ir_makers.py index 19e26f24b6..d7a66b8285 100644 --- a/src/gt4py/next/iterator/ir_utils/ir_makers.py +++ b/src/gt4py/next/iterator/ir_utils/ir_makers.py @@ -10,7 +10,6 @@ from typing import Callable, Optional, Union from gt4py._core import definitions as core_defs -from gt4py.eve.extended_typing import Dict, Tuple from gt4py.next import common from gt4py.next.iterator import ir as itir from gt4py.next.type_system import type_specifications as ts, type_translation @@ -412,7 +411,7 @@ def _impl(*its: itir.Expr) -> itir.FunCall: def domain( grid_type: Union[common.GridType, str], - ranges: Dict[Union[common.Dimension, str], Tuple[itir.Expr, itir.Expr]], + ranges: dict[Union[common.Dimension, str], tuple[itir.Expr, itir.Expr]], ) -> itir.FunCall: """ >>> str( @@ -446,7 +445,7 @@ def domain( ) -def as_fieldop(expr: itir.Expr, domain: Optional[itir.Expr] = None) -> call: +def as_fieldop(expr: itir.Expr | str, domain: Optional[itir.Expr] = None) -> call: """ Create an `as_fieldop` call. diff --git a/src/gt4py/next/iterator/transforms/__init__.py b/src/gt4py/next/iterator/transforms/__init__.py index 58678cfc9c..6c35b53729 100644 --- a/src/gt4py/next/iterator/transforms/__init__.py +++ b/src/gt4py/next/iterator/transforms/__init__.py @@ -6,7 +6,7 @@ # Please, refer to the LICENSE file in the root directory. 
# SPDX-License-Identifier: BSD-3-Clause -from gt4py.next.iterator.transforms.pass_manager import LiftMode, apply_common_transforms +from gt4py.next.iterator.transforms.pass_manager import apply_common_transforms -__all__ = ["apply_common_transforms", "LiftMode"] +__all__ = ["apply_common_transforms"] diff --git a/src/gt4py/next/iterator/transforms/collapse_list_get.py b/src/gt4py/next/iterator/transforms/collapse_list_get.py index f8a3c08e8f..0795cf5739 100644 --- a/src/gt4py/next/iterator/transforms/collapse_list_get.py +++ b/src/gt4py/next/iterator/transforms/collapse_list_get.py @@ -8,6 +8,7 @@ from gt4py import eve from gt4py.next.iterator import ir +from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im class CollapseListGet(eve.PreserveLocationVisitor, eve.NodeTranslator): @@ -21,6 +22,14 @@ class CollapseListGet(eve.PreserveLocationVisitor, eve.NodeTranslator): def visit_FunCall(self, node: ir.FunCall, **kwargs) -> ir.Node: node = self.generic_visit(node) if node.fun == ir.SymRef(id="list_get"): + if cpm.is_call_to(node.args[1], "if_"): + list_idx = node.args[0] + cond, true_val, false_val = node.args[1].args + return im.if_( + cond, + self.visit(im.call("list_get")(list_idx, true_val)), + self.visit(im.call("list_get")(list_idx, false_val)), + ) if isinstance(node.args[1], ir.FunCall): if node.args[1].fun == ir.SymRef(id="neighbors"): offset_tag = node.args[1].args[0] diff --git a/src/gt4py/next/iterator/transforms/collapse_tuple.py b/src/gt4py/next/iterator/transforms/collapse_tuple.py index 40d98208dd..2a2608081c 100644 --- a/src/gt4py/next/iterator/transforms/collapse_tuple.py +++ b/src/gt4py/next/iterator/transforms/collapse_tuple.py @@ -87,6 +87,7 @@ def all(self) -> CollapseTuple.Flag: return functools.reduce(operator.or_, self.__members__.values()) ignore_tuple_size: bool + field_view_only: bool flags: Flag = Flag.all() # noqa: RUF009 [function-call-in-dataclass-default-argument] PRESERVED_ANNEX_ATTRS = ("type",) @@ -105,6 +106,7 @@ def apply( ignore_tuple_size: bool = False, remove_letified_make_tuple_elements: bool = True, offset_provider=None, + field_view_only=True, # manually passing flags is mostly for allowing separate testing of the modes flags=None, # allow sym references without a symbol declaration, mostly for testing @@ -135,6 +137,7 @@ def apply( new_node = cls( ignore_tuple_size=ignore_tuple_size, + field_view_only=field_view_only, flags=flags, ).visit(node) @@ -151,6 +154,10 @@ def apply( return new_node def visit_FunCall(self, node: ir.FunCall) -> ir.Node: + # don't visit stencil argument of `as_fieldop` + if self.field_view_only and cpm.is_call_to(node, "as_fieldop"): + return node + node = self.generic_visit(node) return self.fp_transform(node) diff --git a/src/gt4py/next/iterator/transforms/cse.py b/src/gt4py/next/iterator/transforms/cse.py index ccc1d2195f..aefff0c8bb 100644 --- a/src/gt4py/next/iterator/transforms/cse.py +++ b/src/gt4py/next/iterator/transforms/cse.py @@ -30,6 +30,21 @@ from gt4py.next.type_system import type_info, type_specifications as ts +def _is_trivial_tuple_expr(node: itir.Expr): + """Return if node is a `make_tuple` call with all elements `SymRef`s, `Literal`s or tuples thereof.""" + if cpm.is_call_to(node, "make_tuple") and all( + isinstance(arg, (itir.SymRef, itir.Literal)) or _is_trivial_tuple_expr(arg) + for arg in node.args + ): + return True + if cpm.is_call_to(node, "tuple_get") and ( + isinstance(node.args[1], (itir.SymRef, itir.Literal)) + or _is_trivial_tuple_expr(node.args[1]) + ): + 
return True + return False + + @dataclasses.dataclass class _NodeReplacer(PreserveLocationVisitor, NodeTranslator): PRESERVED_ANNEX_ATTRS = ("type", "domain") @@ -437,11 +452,13 @@ def predicate(subexpr: itir.Expr, num_occurences: int): # only extract fields outside of `as_fieldop` # `as_fieldop(...)(field_expr, field_expr)` # -> `(λ(_cs_1) → as_fieldop(...)(_cs_1, _cs_1))(field_expr)` + # only extract if subexpression is not a trivial tuple expression, e.g., + # `make_tuple(a, b)`, as this would result in a more costly temporary. assert isinstance(subexpr.type, ts.TypeSpec) if all( isinstance(stype, ts.FieldType) for stype in type_info.primitive_constituents(subexpr.type) - ): + ) and not _is_trivial_tuple_expr(subexpr): return True return False diff --git a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py index 51bbd91d83..8cb54ef305 100644 --- a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py +++ b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py @@ -179,6 +179,7 @@ def visit_FunCall(self, node: itir.FunCall): new_args = _merge_arguments(new_args, extracted_args) else: + assert not isinstance(dtype, it_ts.ListType) new_param: str if isinstance( arg, itir.SymRef diff --git a/src/gt4py/next/iterator/transforms/global_tmps.py b/src/gt4py/next/iterator/transforms/global_tmps.py index f80dc74834..2444828895 100644 --- a/src/gt4py/next/iterator/transforms/global_tmps.py +++ b/src/gt4py/next/iterator/transforms/global_tmps.py @@ -177,7 +177,10 @@ def _transform_stmt( def create_global_tmps( - program: itir.Program, offset_provider: common.OffsetProvider + program: itir.Program, + offset_provider: common.OffsetProvider, + *, + uids: Optional[eve_utils.UIDGenerator] = None, ) -> itir.Program: """ Given an `itir.Program` create temporaries for intermediate values. 
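# Illustrative sketch (editorial, hedged) of the new keyword-only `uids` parameter of
# `create_global_tmps` shown in the hunk above: sharing one UID generator keeps temporary
# names ("__tmp_1", "__tmp_2", ...) stable across passes. This mirrors the call site added
# to pass_manager.py further below; `program` and `offset_provider` stand in for the usual
# type-inferred itir.Program and offset-provider mapping.
from gt4py.eve import utils as eve_utils
from gt4py.next.iterator.transforms import global_tmps

tmp_uids = eve_utils.UIDGenerator(prefix="__tmp")
program_with_tmps = global_tmps.create_global_tmps(
    program, offset_provider=offset_provider, uids=tmp_uids
)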
@@ -188,7 +191,8 @@ def create_global_tmps( program = infer_domain.infer_program(program, offset_provider) program = type_inference.infer(program, offset_provider=offset_provider) - uids = eve_utils.UIDGenerator(prefix="__tmp") + if not uids: + uids = eve_utils.UIDGenerator(prefix="__tmp") declarations = program.declarations.copy() new_body = [] diff --git a/src/gt4py/next/iterator/transforms/infer_domain.py b/src/gt4py/next/iterator/transforms/infer_domain.py index 2a85e6f2cf..340ae7c53e 100644 --- a/src/gt4py/next/iterator/transforms/infer_domain.py +++ b/src/gt4py/next/iterator/transforms/infer_domain.py @@ -10,7 +10,7 @@ import itertools import typing -from typing import Callable, TypeAlias +from typing import Callable, Optional, TypeAlias from gt4py.eve import utils as eve_utils from gt4py.next import common @@ -107,6 +107,7 @@ def _extract_accessed_domains( input_ids: list[str], target_domain: domain_utils.SymbolicDomain, offset_provider: common.OffsetProvider, + symbolic_domain_sizes: Optional[dict[str, str]], ) -> ACCESSED_DOMAINS: accessed_domains: dict[str, domain_utils.SymbolicDomain | None] = {} @@ -114,7 +115,9 @@ def _extract_accessed_domains( for in_field_id, shifts_list in zip(input_ids, shifts_results, strict=True): new_domains = [ - domain_utils.SymbolicDomain.translate(target_domain, shift, offset_provider) + domain_utils.SymbolicDomain.translate( + target_domain, shift, offset_provider, symbolic_domain_sizes + ) for shift in shifts_list ] # `None` means field is never accessed @@ -129,6 +132,7 @@ def infer_as_fieldop( applied_fieldop: itir.FunCall, target_domain: DOMAIN, offset_provider: common.OffsetProvider, + symbolic_domain_sizes: Optional[dict[str, str]], ) -> tuple[itir.FunCall, ACCESSED_DOMAINS]: assert isinstance(applied_fieldop, itir.FunCall) assert cpm.is_call_to(applied_fieldop.fun, "as_fieldop") @@ -161,7 +165,7 @@ def infer_as_fieldop( input_ids.append(id_) inputs_accessed_domains: ACCESSED_DOMAINS = _extract_accessed_domains( - stencil, input_ids, target_domain, offset_provider + stencil, input_ids, target_domain, offset_provider, symbolic_domain_sizes ) # Recursively infer domain of inputs and update domain arg of nested `as_fieldop`s @@ -169,7 +173,7 @@ def infer_as_fieldop( transformed_inputs: list[itir.Expr] = [] for in_field_id, in_field in zip(input_ids, inputs): transformed_input, accessed_domains_tmp = infer_expr( - in_field, inputs_accessed_domains[in_field_id], offset_provider + in_field, inputs_accessed_domains[in_field_id], offset_provider, symbolic_domain_sizes ) transformed_inputs.append(transformed_input) @@ -191,11 +195,12 @@ def infer_let( let_expr: itir.FunCall, input_domain: DOMAIN, offset_provider: common.OffsetProvider, + symbolic_domain_sizes: Optional[dict[str, str]], ) -> tuple[itir.FunCall, ACCESSED_DOMAINS]: assert cpm.is_let(let_expr) assert isinstance(let_expr.fun, itir.Lambda) # just to make mypy happy transformed_calls_expr, accessed_domains = infer_expr( - let_expr.fun.expr, input_domain, offset_provider + let_expr.fun.expr, input_domain, offset_provider, symbolic_domain_sizes ) let_params = {param_sym.id for param_sym in let_expr.fun.params} @@ -212,6 +217,7 @@ def infer_let( None, ), offset_provider, + symbolic_domain_sizes, ) accessed_domains_outer = _merge_domains(accessed_domains_outer, accessed_domains_arg) transformed_calls_args.append(transformed_calls_arg) @@ -230,6 +236,7 @@ def infer_make_tuple( expr: itir.Expr, domain: DOMAIN, offset_provider: common.OffsetProvider, + symbolic_domain_sizes: Optional[dict[str, 
str]], ) -> tuple[itir.Expr, ACCESSED_DOMAINS]: assert cpm.is_call_to(expr, "make_tuple") infered_args_expr = [] @@ -245,7 +252,9 @@ def infer_make_tuple( # e.g. `im.tuple_get(0, im.make_tuple(a, b), domain=domain)` domain = (*domain, *(None for _ in range(len(expr.args) - len(domain)))) for i, arg in enumerate(expr.args): - infered_arg_expr, actual_domains_arg = infer_expr(arg, domain[i], offset_provider) + infered_arg_expr, actual_domains_arg = infer_expr( + arg, domain[i], offset_provider, symbolic_domain_sizes + ) infered_args_expr.append(infered_arg_expr) actual_domains = _merge_domains(actual_domains, actual_domains_arg) result_expr = im.call(expr.fun)(*infered_args_expr) @@ -256,6 +265,7 @@ def infer_tuple_get( expr: itir.Expr, domain: DOMAIN, offset_provider: common.OffsetProvider, + symbolic_domain_sizes: Optional[dict[str, str]], ) -> tuple[itir.Expr, ACCESSED_DOMAINS]: assert cpm.is_call_to(expr, "tuple_get") actual_domains: ACCESSED_DOMAINS = {} @@ -263,7 +273,9 @@ def infer_tuple_get( assert isinstance(idx_expr, itir.Literal) idx = int(idx_expr.value) tuple_domain = tuple(None if i != idx else domain for i in range(idx + 1)) - infered_arg_expr, actual_domains_arg = infer_expr(tuple_arg, tuple_domain, offset_provider) + infered_arg_expr, actual_domains_arg = infer_expr( + tuple_arg, tuple_domain, offset_provider, symbolic_domain_sizes + ) infered_args_expr = im.tuple_get(idx, infered_arg_expr) actual_domains = _merge_domains(actual_domains, actual_domains_arg) @@ -274,13 +286,16 @@ def infer_if( expr: itir.Expr, domain: DOMAIN, offset_provider: common.OffsetProvider, + symbolic_domain_sizes: Optional[dict[str, str]], ) -> tuple[itir.Expr, ACCESSED_DOMAINS]: assert cpm.is_call_to(expr, "if_") infered_args_expr = [] actual_domains: ACCESSED_DOMAINS = {} cond, true_val, false_val = expr.args for arg in [true_val, false_val]: - infered_arg_expr, actual_domains_arg = infer_expr(arg, domain, offset_provider) + infered_arg_expr, actual_domains_arg = infer_expr( + arg, domain, offset_provider, symbolic_domain_sizes + ) infered_args_expr.append(infered_arg_expr) actual_domains = _merge_domains(actual_domains, actual_domains_arg) result_expr = im.call(expr.fun)(cond, *infered_args_expr) @@ -291,21 +306,22 @@ def _infer_expr( expr: itir.Expr, domain: DOMAIN, offset_provider: common.OffsetProvider, + symbolic_domain_sizes: Optional[dict[str, str]], ) -> tuple[itir.Expr, ACCESSED_DOMAINS]: if isinstance(expr, itir.SymRef): return expr, {str(expr.id): domain} elif isinstance(expr, itir.Literal): return expr, {} elif cpm.is_applied_as_fieldop(expr): - return infer_as_fieldop(expr, domain, offset_provider) + return infer_as_fieldop(expr, domain, offset_provider, symbolic_domain_sizes) elif cpm.is_let(expr): - return infer_let(expr, domain, offset_provider) + return infer_let(expr, domain, offset_provider, symbolic_domain_sizes) elif cpm.is_call_to(expr, "make_tuple"): - return infer_make_tuple(expr, domain, offset_provider) + return infer_make_tuple(expr, domain, offset_provider, symbolic_domain_sizes) elif cpm.is_call_to(expr, "tuple_get"): - return infer_tuple_get(expr, domain, offset_provider) + return infer_tuple_get(expr, domain, offset_provider, symbolic_domain_sizes) elif cpm.is_call_to(expr, "if_"): - return infer_if(expr, domain, offset_provider) + return infer_if(expr, domain, offset_provider, symbolic_domain_sizes) elif ( cpm.is_call_to(expr, itir.ARITHMETIC_BUILTINS) or cpm.is_call_to(expr, itir.TYPEBUILTINS) @@ -320,9 +336,10 @@ def infer_expr( expr: itir.Expr, domain: DOMAIN, 
offset_provider: common.OffsetProvider, + symbolic_domain_sizes: Optional[dict[str, str]] = None, ) -> tuple[itir.Expr, ACCESSED_DOMAINS]: # this is just a small wrapper that populates the `domain` annex - expr, accessed_domains = _infer_expr(expr, domain, offset_provider) + expr, accessed_domains = _infer_expr(expr, domain, offset_provider, symbolic_domain_sizes) expr.annex.domain = domain return expr, accessed_domains @@ -330,6 +347,7 @@ def infer_expr( def infer_program( program: itir.Program, offset_provider: common.OffsetProvider, + symbolic_domain_sizes: Optional[dict[str, str]] = None, ) -> itir.Program: transformed_set_ats: list[itir.SetAt] = [] assert ( @@ -340,7 +358,10 @@ def infer_program( assert isinstance(set_at, itir.SetAt) transformed_call, _unused_domain = infer_expr( - set_at.expr, domain_utils.SymbolicDomain.from_expr(set_at.domain), offset_provider + set_at.expr, + domain_utils.SymbolicDomain.from_expr(set_at.domain), + offset_provider, + symbolic_domain_sizes, ) transformed_set_ats.append( itir.SetAt( diff --git a/src/gt4py/next/iterator/transforms/inline_into_scan.py b/src/gt4py/next/iterator/transforms/inline_into_scan.py index f899da73b1..33e36bfa4b 100644 --- a/src/gt4py/next/iterator/transforms/inline_into_scan.py +++ b/src/gt4py/next/iterator/transforms/inline_into_scan.py @@ -5,7 +5,7 @@ # # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause - +# FIXME[#1582](tehrengruber): This transformation is not used anymore. Decide on its fate. from typing import Sequence, TypeGuard from gt4py import eve diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index b3bb7bc6e1..d5c084a0a9 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -6,68 +6,34 @@ # Please, refer to the LICENSE file in the root directory. 
# SPDX-License-Identifier: BSD-3-Clause -import enum from typing import Callable, Optional from gt4py.eve import utils as eve_utils from gt4py.next.iterator import ir as itir -from gt4py.next.iterator.transforms import fencil_to_program, inline_fundefs +from gt4py.next.iterator.transforms import ( + fuse_as_fieldop, + global_tmps, + infer_domain, + inline_fundefs, +) from gt4py.next.iterator.transforms.collapse_list_get import CollapseListGet from gt4py.next.iterator.transforms.collapse_tuple import CollapseTuple from gt4py.next.iterator.transforms.constant_folding import ConstantFolding from gt4py.next.iterator.transforms.cse import CommonSubexpressionElimination -from gt4py.next.iterator.transforms.eta_reduction import EtaReduction from gt4py.next.iterator.transforms.fuse_maps import FuseMaps -from gt4py.next.iterator.transforms.inline_center_deref_lift_vars import InlineCenterDerefLiftVars -from gt4py.next.iterator.transforms.inline_into_scan import InlineIntoScan from gt4py.next.iterator.transforms.inline_lambdas import InlineLambdas -from gt4py.next.iterator.transforms.inline_lifts import InlineLifts from gt4py.next.iterator.transforms.merge_let import MergeLet from gt4py.next.iterator.transforms.normalize_shifts import NormalizeShifts -from gt4py.next.iterator.transforms.propagate_deref import PropagateDeref -from gt4py.next.iterator.transforms.scan_eta_reduction import ScanEtaReduction from gt4py.next.iterator.transforms.unroll_reduce import UnrollReduce - - -@enum.unique -class LiftMode(enum.Enum): - FORCE_INLINE = enum.auto() - USE_TEMPORARIES = enum.auto() - - -def _inline_lifts(ir, lift_mode): - if lift_mode == LiftMode.FORCE_INLINE: - return InlineLifts().visit(ir) - elif lift_mode == LiftMode.USE_TEMPORARIES: - return InlineLifts( - flags=InlineLifts.Flag.INLINE_TRIVIAL_DEREF_LIFT - | InlineLifts.Flag.INLINE_DEREF_LIFT # some tuple exprs found in FVM don't work yet. - ).visit(ir) - else: - raise ValueError() - - return ir - - -def _inline_into_scan(ir, *, max_iter=10): - for _ in range(10): - # in case there are multiple levels of lambdas around the scan we have to do multiple iterations - inlined = InlineIntoScan().visit(ir) - inlined = InlineLambdas.apply(inlined, opcount_preserving=True, force_inline_lift_args=True) - if inlined == ir: - break - ir = inlined - else: - raise RuntimeError(f"Inlining into 'scan' did not converge within {max_iter} iterations.") - return ir +from gt4py.next.iterator.type_system.inference import infer # TODO(tehrengruber): Revisit interface to configure temporary extraction. We currently forward -# `lift_mode` and `temporary_extraction_heuristics` which is inconvenient. +# `extract_temporaries` and `temporary_extraction_heuristics` which is inconvenient. 
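# Illustrative sketch (editorial, hedged) of calling the reworked entry point defined just
# below: temporaries are now requested via `extract_temporaries` (the old `lift_mode` enum is
# gone) and `symbolic_domain_sizes` maps horizontal dimensions to size symbols, as exercised
# by test_temporaries_with_sizes later in this patch. `program` and `offset_provider` stand in
# for the usual inputs.
from gt4py.next.iterator.transforms import apply_common_transforms

transformed = apply_common_transforms(
    program,  # must already be an itir.Program; FencilDefinitions go through the legacy pass manager
    offset_provider=offset_provider,
    extract_temporaries=True,
    symbolic_domain_sizes={"Cell": "num_cells", "Edge": "num_edges", "Vertex": "num_vertices"},
)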
def apply_common_transforms( ir: itir.Node, *, - lift_mode=None, + extract_temporaries=False, offset_provider=None, unroll_reduce=False, common_subexpression_elimination=True, @@ -77,57 +43,44 @@ def apply_common_transforms( temporary_extraction_heuristics: Optional[ Callable[[itir.StencilClosure], Callable[[itir.Expr], bool]] ] = None, - # FIXME[#1582](tehrengruber): Revisit and cleanup after new GTIR temporary pass is in place symbolic_domain_sizes: Optional[dict[str, str]] = None, ) -> itir.Program: - if isinstance(ir, itir.FencilDefinition): - ir = fencil_to_program.FencilToProgram().apply( - ir - ) # FIXME[#1582](havogt): should be removed after refactoring to combined IR - else: - assert isinstance(ir, itir.Program) - # FIXME[#1582](havogt): note: currently the case when using the roundtrip backend - pass + assert isinstance(ir, itir.Program) - icdlv_uids = eve_utils.UIDGenerator() + tmp_uids = eve_utils.UIDGenerator(prefix="__tmp") + mergeasfop_uids = eve_utils.UIDGenerator() - if lift_mode is None: - lift_mode = LiftMode.FORCE_INLINE - assert isinstance(lift_mode, LiftMode) ir = MergeLet().visit(ir) ir = inline_fundefs.InlineFundefs().visit(ir) ir = inline_fundefs.prune_unreferenced_fundefs(ir) # type: ignore[arg-type] # all previous passes return itir.Program - ir = PropagateDeref.apply(ir) ir = NormalizeShifts().visit(ir) + # note: this increases the size of the tree + # Inline. The domain inference can not handle "user" functions, e.g. `let f = λ(...) → ... in f(...)` + ir = InlineLambdas.apply(ir, opcount_preserving=True, force_inline_lambda_args=True) + ir = infer_domain.infer_program( + ir, # type: ignore[arg-type] # always an itir.Program + offset_provider=offset_provider, + symbolic_domain_sizes=symbolic_domain_sizes, + ) + for _ in range(10): inlined = ir - inlined = InlineCenterDerefLiftVars.apply(inlined, uids=icdlv_uids) # type: ignore[arg-type] # always a fencil - inlined = _inline_lifts(inlined, lift_mode) - - inlined = InlineLambdas.apply( - inlined, - opcount_preserving=True, - force_inline_lift_args=(lift_mode == LiftMode.FORCE_INLINE), - # If trivial lifts are not inlined we might create temporaries for constants. In all - # other cases we want it anyway. - force_inline_trivial_lift_args=True, - ) - inlined = ConstantFolding.apply(inlined) + inlined = InlineLambdas.apply(inlined, opcount_preserving=True) + inlined = ConstantFolding.apply(inlined) # type: ignore[assignment] # always an itir.Program # This pass is required to be in the loop such that when an `if_` call with tuple arguments # is constant-folded the surrounding tuple_get calls can be removed. - inlined = CollapseTuple.apply( - inlined, - offset_provider=offset_provider, - # TODO(tehrengruber): disabled since it increases compile-time too much right now - flags=~CollapseTuple.Flag.PROPAGATE_TO_IF_ON_TUPLES, + inlined = CollapseTuple.apply(inlined, offset_provider=offset_provider) # type: ignore[assignment] # always an itir.Program + + # This pass is required to run after CollapseTuple as otherwise we can not inline + # expressions like `tuple_get(make_tuple(as_fieldop(stencil)(...)))` where stencil returns + # a list. Such expressions must be inlined however because no backend supports such + # field operators right now. 
+ inlined = fuse_as_fieldop.FuseAsFieldOp.apply( + inlined, uids=mergeasfop_uids, offset_provider=offset_provider ) - # This pass is required such that a deref outside of a - # `tuple_get(make_tuple(let(...), ...))` call is propagated into the let after the - # `tuple_get` is removed by the `CollapseTuple` pass. - inlined = PropagateDeref.apply(inlined) if inlined == ir: break @@ -135,48 +88,21 @@ def apply_common_transforms( else: raise RuntimeError("Inlining 'lift' and 'lambdas' did not converge.") - if lift_mode != LiftMode.FORCE_INLINE: - # FIXME[#1582](tehrengruber): implement new temporary pass here - raise NotImplementedError() - # ruff: noqa: ERA001 - # assert offset_provider is not None - # ir = CreateGlobalTmps().visit( - # ir, - # offset_provider=offset_provider, - # extraction_heuristics=temporary_extraction_heuristics, - # symbolic_sizes=symbolic_domain_sizes, - # ) - # - # for _ in range(10): - # inlined = InlineLifts().visit(ir) - # inlined = InlineLambdas.apply( - # inlined, opcount_preserving=True, force_inline_lift_args=True - # ) - # if inlined == ir: - # break - # ir = inlined - # else: - # raise RuntimeError("Inlining 'lift' and 'lambdas' did not converge.") - # - # # If after creating temporaries, the scan is not at the top, we inline. - # # The following example doesn't have a lift around the shift, i.e. temporary pass will not extract it. - # # λ(inp) → scan(λ(state, k, kp) → state + ·k + ·kp, True, 0.0)(inp, ⟪Koffₒ, 1ₒ⟫(inp))` - # ir = _inline_into_scan(ir) + # breaks in test_zero_dim_tuple_arg as trivial tuple_get is not inlined + if common_subexpression_elimination: + ir = CommonSubexpressionElimination.apply(ir, offset_provider=offset_provider) + ir = MergeLet().visit(ir) + ir = InlineLambdas.apply(ir, opcount_preserving=True) + + if extract_temporaries: + ir = infer(ir, inplace=True, offset_provider=offset_provider) + ir = global_tmps.create_global_tmps(ir, offset_provider=offset_provider, uids=tmp_uids) # type: ignore[arg-type] # always an itir.Program # Since `CollapseTuple` relies on the type inference which does not support returning tuples # larger than the number of closure outputs as given by the unconditional collapse, we can # only run the unconditional version here instead of in the loop above. 
if unconditionally_collapse_tuples: - ir = CollapseTuple.apply( - ir, - ignore_tuple_size=True, - offset_provider=offset_provider, - # TODO(tehrengruber): disabled since it increases compile-time too much right now - flags=~CollapseTuple.Flag.PROPAGATE_TO_IF_ON_TUPLES, - ) - - if lift_mode == LiftMode.FORCE_INLINE: - ir = _inline_into_scan(ir) + ir = CollapseTuple.apply(ir, ignore_tuple_size=True, offset_provider=offset_provider) ir = NormalizeShifts().visit(ir) @@ -191,18 +117,9 @@ def apply_common_transforms( ir = unrolled ir = CollapseListGet().visit(ir) ir = NormalizeShifts().visit(ir) - ir = _inline_lifts(ir, LiftMode.FORCE_INLINE) - ir = NormalizeShifts().visit(ir) else: raise RuntimeError("Reduction unrolling failed.") - ir = EtaReduction().visit(ir) - ir = ScanEtaReduction().visit(ir) - - if common_subexpression_elimination: - ir = CommonSubexpressionElimination.apply(ir, offset_provider=offset_provider) # type: ignore[type-var] # always an itir.Program - ir = MergeLet().visit(ir) - ir = InlineLambdas.apply( ir, opcount_preserving=True, force_inline_lambda_args=force_inline_lambda_args ) diff --git a/src/gt4py/next/iterator/transforms/pass_manager_legacy.py b/src/gt4py/next/iterator/transforms/pass_manager_legacy.py new file mode 100644 index 0000000000..9933fcd4ae --- /dev/null +++ b/src/gt4py/next/iterator/transforms/pass_manager_legacy.py @@ -0,0 +1,176 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause +# FIXME[#1582](tehrengruber): file should be removed after refactoring to GTIR +import enum +from typing import Callable, Optional + +from gt4py.eve import utils as eve_utils +from gt4py.next.iterator import ir as itir +from gt4py.next.iterator.transforms import fencil_to_program, inline_fundefs +from gt4py.next.iterator.transforms.collapse_list_get import CollapseListGet +from gt4py.next.iterator.transforms.collapse_tuple import CollapseTuple +from gt4py.next.iterator.transforms.constant_folding import ConstantFolding +from gt4py.next.iterator.transforms.cse import CommonSubexpressionElimination +from gt4py.next.iterator.transforms.eta_reduction import EtaReduction +from gt4py.next.iterator.transforms.fuse_maps import FuseMaps +from gt4py.next.iterator.transforms.inline_center_deref_lift_vars import InlineCenterDerefLiftVars +from gt4py.next.iterator.transforms.inline_into_scan import InlineIntoScan +from gt4py.next.iterator.transforms.inline_lambdas import InlineLambdas +from gt4py.next.iterator.transforms.inline_lifts import InlineLifts +from gt4py.next.iterator.transforms.merge_let import MergeLet +from gt4py.next.iterator.transforms.normalize_shifts import NormalizeShifts +from gt4py.next.iterator.transforms.propagate_deref import PropagateDeref +from gt4py.next.iterator.transforms.scan_eta_reduction import ScanEtaReduction +from gt4py.next.iterator.transforms.unroll_reduce import UnrollReduce + + +@enum.unique +class LiftMode(enum.Enum): + FORCE_INLINE = enum.auto() + USE_TEMPORARIES = enum.auto() + + +def _inline_lifts(ir, lift_mode): + if lift_mode == LiftMode.FORCE_INLINE: + return InlineLifts().visit(ir) + elif lift_mode == LiftMode.USE_TEMPORARIES: + return InlineLifts( + flags=InlineLifts.Flag.INLINE_TRIVIAL_DEREF_LIFT + | InlineLifts.Flag.INLINE_DEREF_LIFT # some tuple exprs found in FVM don't work yet. 
+ ).visit(ir) + else: + raise ValueError() + + return ir + + +def _inline_into_scan(ir, *, max_iter=10): + for _ in range(10): + # in case there are multiple levels of lambdas around the scan we have to do multiple iterations + inlined = InlineIntoScan().visit(ir) + inlined = InlineLambdas.apply(inlined, opcount_preserving=True, force_inline_lift_args=True) + if inlined == ir: + break + ir = inlined + else: + raise RuntimeError(f"Inlining into 'scan' did not converge within {max_iter} iterations.") + return ir + + +def apply_common_transforms( + ir: itir.Node, + *, + lift_mode=None, + offset_provider=None, + unroll_reduce=False, + common_subexpression_elimination=True, + force_inline_lambda_args=False, + unconditionally_collapse_tuples=False, + temporary_extraction_heuristics: Optional[ + Callable[[itir.StencilClosure], Callable[[itir.Expr], bool]] + ] = None, + symbolic_domain_sizes: Optional[dict[str, str]] = None, +) -> itir.Program: + assert isinstance(ir, itir.FencilDefinition) + ir = fencil_to_program.FencilToProgram().apply(ir) + icdlv_uids = eve_utils.UIDGenerator() + + if lift_mode is None: + lift_mode = LiftMode.FORCE_INLINE + assert isinstance(lift_mode, LiftMode) + ir = MergeLet().visit(ir) + ir = inline_fundefs.InlineFundefs().visit(ir) + + ir = inline_fundefs.prune_unreferenced_fundefs(ir) # type: ignore[arg-type] # all previous passes return itir.Program + ir = PropagateDeref.apply(ir) + ir = NormalizeShifts().visit(ir) + + for _ in range(10): + inlined = ir + + inlined = InlineCenterDerefLiftVars.apply(inlined, uids=icdlv_uids) # type: ignore[arg-type] # always a fencil + inlined = _inline_lifts(inlined, lift_mode) + + inlined = InlineLambdas.apply( + inlined, + opcount_preserving=True, + force_inline_lift_args=(lift_mode == LiftMode.FORCE_INLINE), + # If trivial lifts are not inlined we might create temporaries for constants. In all + # other cases we want it anyway. + force_inline_trivial_lift_args=True, + ) + inlined = ConstantFolding.apply(inlined) + # This pass is required to be in the loop such that when an `if_` call with tuple arguments + # is constant-folded the surrounding tuple_get calls can be removed. + inlined = CollapseTuple.apply( + inlined, + offset_provider=offset_provider, + # TODO(tehrengruber): disabled since it increases compile-time too much right now + flags=~CollapseTuple.Flag.PROPAGATE_TO_IF_ON_TUPLES, + field_view_only=False, + ) + # This pass is required such that a deref outside of a + # `tuple_get(make_tuple(let(...), ...))` call is propagated into the let after the + # `tuple_get` is removed by the `CollapseTuple` pass. + inlined = PropagateDeref.apply(inlined) + + if inlined == ir: + break + ir = inlined + else: + raise RuntimeError("Inlining 'lift' and 'lambdas' did not converge.") + + if lift_mode != LiftMode.FORCE_INLINE: + raise NotImplementedError() + + # Since `CollapseTuple` relies on the type inference which does not support returning tuples + # larger than the number of closure outputs as given by the unconditional collapse, we can + # only run the unconditional version here instead of in the loop above. 
+ if unconditionally_collapse_tuples: + ir = CollapseTuple.apply( + ir, + ignore_tuple_size=True, + offset_provider=offset_provider, + # TODO(tehrengruber): disabled since it increases compile-time too much right now + flags=~CollapseTuple.Flag.PROPAGATE_TO_IF_ON_TUPLES, + ) + + if lift_mode == LiftMode.FORCE_INLINE: + ir = _inline_into_scan(ir) + + ir = NormalizeShifts().visit(ir) + + ir = FuseMaps().visit(ir) + ir = CollapseListGet().visit(ir) + + if unroll_reduce: + for _ in range(10): + unrolled = UnrollReduce.apply(ir, offset_provider=offset_provider) + if unrolled == ir: + break + ir = unrolled + ir = CollapseListGet().visit(ir) + ir = NormalizeShifts().visit(ir) + ir = _inline_lifts(ir, LiftMode.FORCE_INLINE) + ir = NormalizeShifts().visit(ir) + else: + raise RuntimeError("Reduction unrolling failed.") + + ir = EtaReduction().visit(ir) + ir = ScanEtaReduction().visit(ir) + + if common_subexpression_elimination: + ir = CommonSubexpressionElimination.apply(ir, offset_provider=offset_provider) # type: ignore[type-var] # always an itir.Program + ir = MergeLet().visit(ir) + + ir = InlineLambdas.apply( + ir, opcount_preserving=True, force_inline_lambda_args=force_inline_lambda_args + ) + + assert isinstance(ir, itir.Program) + return ir diff --git a/src/gt4py/next/iterator/transforms/unroll_reduce.py b/src/gt4py/next/iterator/transforms/unroll_reduce.py index 700b8571a5..ec9c3efb2b 100644 --- a/src/gt4py/next/iterator/transforms/unroll_reduce.py +++ b/src/gt4py/next/iterator/transforms/unroll_reduce.py @@ -30,7 +30,14 @@ def _is_neighbors_or_lifted_and_neighbors(arg: itir.Expr) -> TypeGuard[itir.FunC def _get_neighbors_args(reduce_args: Iterable[itir.Expr]) -> Iterator[itir.FunCall]: - return filter(_is_neighbors_or_lifted_and_neighbors, reduce_args) + flat_reduce_args: list[itir.Expr] = [] + for arg in reduce_args: + if cpm.is_call_to(arg, "if_"): + flat_reduce_args.extend(_get_neighbors_args(arg.args[1:3])) + else: + flat_reduce_args.append(arg) + + return filter(_is_neighbors_or_lifted_and_neighbors, flat_reduce_args) def _is_list_of_funcalls(lst: list) -> TypeGuard[list[itir.FunCall]]: diff --git a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir.py b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir.py index 1995e4de0b..f4306bca1f 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir.py @@ -8,7 +8,7 @@ from __future__ import annotations -from typing import ClassVar, Optional, Union +from typing import Callable, ClassVar, Optional, Union from gt4py.eve import Coerced, SymbolName, datamodels from gt4py.eve.traits import SymbolTableTrait, ValidatedSymbolTableTrait @@ -96,25 +96,23 @@ class Backend(Node): domain: Union[SymRef, CartesianDomain, UnstructuredDomain] -def _is_ref_literal_or_tuple_expr_of_ref(expr: Expr) -> bool: +def _is_tuple_expr_of(pred: Callable[[Expr], bool], expr: Expr) -> bool: if ( isinstance(expr, FunCall) and isinstance(expr.fun, SymRef) and expr.fun.id == "tuple_get" and len(expr.args) == 2 - and _is_ref_literal_or_tuple_expr_of_ref(expr.args[1]) + and _is_tuple_expr_of(pred, expr.args[1]) ): return True if ( isinstance(expr, FunCall) and isinstance(expr.fun, SymRef) and expr.fun.id == "make_tuple" - and all(_is_ref_literal_or_tuple_expr_of_ref(arg) for arg in expr.args) + and all(_is_tuple_expr_of(pred, arg) for arg in expr.args) ): return True - if isinstance(expr, (SymRef, Literal)): - return True - return False + return pred(expr) class SidComposite(Expr): @@ -126,14 
+124,32 @@ def _values_validator( ) -> None: if not all( isinstance(el, (SidFromScalar, SidComposite)) - or _is_ref_literal_or_tuple_expr_of_ref(el) + or _is_tuple_expr_of(lambda expr: isinstance(expr, (SymRef, Literal)), el) for el in value ): raise ValueError( - "Only 'SymRef', tuple expr of 'SymRef', 'SidFromScalar', or 'SidComposite' allowed." + "Only 'SymRef', 'Literal', tuple expr thereof, 'SidFromScalar', or 'SidComposite' allowed." ) +def _might_be_scalar_expr(expr: Expr) -> bool: + if isinstance(expr, BinaryExpr): + return all(_is_tuple_expr_of(_might_be_scalar_expr, arg) for arg in (expr.lhs, expr.rhs)) + if isinstance(expr, UnaryExpr): + return _is_tuple_expr_of(_might_be_scalar_expr, expr.expr) + if ( + isinstance(expr, FunCall) + and isinstance(expr.fun, SymRef) + and expr.fun.id in ARITHMETIC_BUILTINS + ): + return all(_might_be_scalar_expr(arg) for arg in expr.args) + if isinstance(expr, CastExpr): + return _might_be_scalar_expr(expr.obj_expr) + if _is_tuple_expr_of(lambda e: isinstance(e, (SymRef, Literal)), expr): + return True + return False + + class SidFromScalar(Expr): arg: Expr @@ -141,8 +157,10 @@ class SidFromScalar(Expr): def _arg_validator( self: datamodels.DataModelTP, attribute: datamodels.Attribute, value: Expr ) -> None: - if not _is_ref_literal_or_tuple_expr_of_ref(value): - raise ValueError("Only 'SymRef' or tuple expr of 'SymRef' allowed.") + if not _might_be_scalar_expr(value): + raise ValueError( + "Only 'SymRef', 'Literal', arithmetic op or tuple expr thereof allowed." + ) class Stmt(Node): @@ -153,7 +171,21 @@ class StencilExecution(Stmt): backend: Backend stencil: SymRef output: Union[SymRef, SidComposite] - inputs: list[Union[SymRef, SidComposite, SidFromScalar]] + inputs: list[ + Union[SymRef, SidComposite, SidFromScalar, FunCall] + ] # TODO: StencilExecution only for tuple_get + + @datamodels.validator("inputs") + def _arg_validator( + self: datamodels.DataModelTP, attribute: datamodels.Attribute, inputs: list[Expr] + ) -> None: + for inp in inputs: + if not _is_tuple_expr_of( + lambda expr: isinstance(expr, (SymRef, SidComposite, SidFromScalar)), inp + ): + raise ValueError( + "Only 'SymRef', 'SidComposite', 'SidFromScalar' or tuple expr thereof allowed." 
+ ) class Scan(Node): diff --git a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py index d729a5ba2f..85260afa07 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py @@ -21,7 +21,7 @@ from gt4py.next.common import Connectivity, Dimension from gt4py.next.ffront import fbuiltins from gt4py.next.iterator import ir as itir -from gt4py.next.iterator.transforms import LiftMode, fencil_to_program, pass_manager +from gt4py.next.iterator.transforms import pass_manager from gt4py.next.otf import languages, stages, step_types, workflow from gt4py.next.otf.binding import cpp_interface, interface from gt4py.next.program_processors.codegens.gtfn.codegen import GTFNCodegen, GTFNIMCodegen @@ -52,7 +52,6 @@ class GTFNTranslationStep( # TODO replace by more general mechanism, see https://github.com/GridTools/gt4py/issues/1135 enable_itir_transforms: bool = True use_imperative_backend: bool = False - lift_mode: Optional[LiftMode] = None device_type: core_defs.DeviceType = core_defs.DeviceType.CPU symbolic_domain_sizes: Optional[dict[str, str]] = None temporary_extraction_heuristics: Optional[ @@ -164,14 +163,9 @@ def _preprocess_program( program: itir.FencilDefinition | itir.Program, offset_provider: dict[str, Connectivity | Dimension], ) -> itir.Program: - if isinstance(program, itir.FencilDefinition) and not self.enable_itir_transforms: - return fencil_to_program.FencilToProgram().apply( - program - ) # FIXME[#1582](tehrengruber): should be removed after refactoring to combined IR - apply_common_transforms = functools.partial( pass_manager.apply_common_transforms, - lift_mode=self.lift_mode, + extract_temporaries=True, offset_provider=offset_provider, # sid::composite (via hymap) supports assigning from tuple with more elements to tuple with fewer elements unconditionally_collapse_tuples=True, diff --git a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py index 3bd96d14d7..47cca740f9 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py @@ -587,6 +587,32 @@ def visit_IfStmt(self, node: itir.IfStmt, **kwargs: Any) -> IfStmt: def visit_SetAt( self, node: itir.SetAt, *, extracted_functions: list, **kwargs: Any ) -> Union[StencilExecution, ScanExecution]: + # TODO: symref, literal, tuple thereof is also fine, similar to broadcast fix in gtir lowering + def _is_ref_or_tuple_expr_of_ref(expr: itir.Expr) -> bool: + if ( + isinstance(expr, itir.FunCall) + and isinstance(expr.fun, itir.SymRef) + and expr.fun.id == "tuple_get" + and len(expr.args) == 2 + and _is_ref_or_tuple_expr_of_ref(expr.args[1]) + ): + return True + if ( + isinstance(expr, itir.FunCall) + and isinstance(expr.fun, itir.SymRef) + and expr.fun.id == "make_tuple" + and all(_is_ref_or_tuple_expr_of_ref(arg) for arg in expr.args) + ): + return True + if isinstance(expr, (itir.SymRef, itir.Literal)): + return True + return False + + from gt4py.next.iterator.ir_utils import ir_makers as im + + if _is_ref_or_tuple_expr_of_ref(node.expr): + node.expr = im.as_fieldop("deref", node.domain)(node.expr) + assert cpm.is_applied_as_fieldop(node.expr) stencil = node.expr.fun.args[0] # type: ignore[attr-defined] # checked in assert domain = node.domain @@ -611,7 +637,6 @@ def convert_el_to_sid(el_expr: Expr, 
el_type: ts.ScalarType | ts.FieldType) -> E tuple_constructor=lambda *elements: SidComposite(values=list(elements)), ) - assert isinstance(lowered_input_as_sid, (SidComposite, SidFromScalar, SymRef)) lowered_inputs.append(lowered_input_as_sid) backend = Backend(domain=self.visit(domain, stencil=stencil, **kwargs)) diff --git a/src/gt4py/next/program_processors/formatters/lisp.py b/src/gt4py/next/program_processors/formatters/lisp.py index c477795c34..0a71ebd1ef 100644 --- a/src/gt4py/next/program_processors/formatters/lisp.py +++ b/src/gt4py/next/program_processors/formatters/lisp.py @@ -51,9 +51,7 @@ class ToLispLike(TemplatedGenerator): @classmethod def apply(cls, root: itir.Node, **kwargs: Any) -> str: # type: ignore[override] - transformed = apply_common_transforms( - root, lift_mode=kwargs.get("lift_mode"), offset_provider=kwargs["offset_provider"] - ) + transformed = apply_common_transforms(root, offset_provider=kwargs["offset_provider"]) generated_code = super().apply(transformed, **kwargs) try: from yasi import indent_code diff --git a/src/gt4py/next/program_processors/runners/dace.py b/src/gt4py/next/program_processors/runners/dace.py index 2db8e98804..3deaf01fcf 100644 --- a/src/gt4py/next/program_processors/runners/dace.py +++ b/src/gt4py/next/program_processors/runners/dace.py @@ -9,7 +9,6 @@ import factory from gt4py.next import allocators as next_allocators, backend -from gt4py.next.ffront import foast_to_gtir, past_to_itir from gt4py.next.program_processors.runners.dace_fieldview import workflow as dace_fieldview_workflow from gt4py.next.program_processors.runners.dace_iterator import workflow as dace_iterator_workflow from gt4py.next.program_processors.runners.gtfn import GTFNBackendFactory @@ -33,7 +32,7 @@ class Params: lambda o: f"run_dace_{o.name_device}{o.name_temps}{o.name_cached}{o.name_postfix}" ) - transforms = backend.DEFAULT_TRANSFORMS + transforms = backend.LEGACY_TRANSFORMS run_dace_cpu = DaCeIteratorBackendFactory(cached=True, auto_optimize=True) @@ -49,8 +48,5 @@ class Params: name="dace.gtir.cpu", executor=dace_fieldview_workflow.DaCeWorkflowFactory(), allocator=next_allocators.StandardCPUFieldBufferAllocator(), - transforms=backend.Transforms( - past_to_itir=past_to_itir.past_to_itir_factory(to_gtir=True), - foast_to_itir=foast_to_gtir.adapted_foast_to_gtir_factory(cached=True), - ), + transforms=backend.DEFAULT_TRANSFORMS, ) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py index dab8d29fd1..ef3aa53b54 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py @@ -24,7 +24,10 @@ from gt4py.next import common from gt4py.next.ffront import decorator from gt4py.next.iterator import transforms as itir_transforms -from gt4py.next.iterator.transforms import program_to_fencil +from gt4py.next.iterator.transforms import ( + pass_manager_legacy as legacy_itir_transforms, + program_to_fencil, +) from gt4py.next.iterator.type_system import inference as itir_type_inference from gt4py.next.program_processors.runners.dace_common import utility as dace_utils from gt4py.next.type_system import type_specifications as ts @@ -35,14 +38,14 @@ def preprocess_program( program: itir.FencilDefinition, offset_provider: Mapping[str, Any], - lift_mode: itir_transforms.LiftMode, + lift_mode: legacy_itir_transforms.LiftMode, symbolic_domain_sizes: Optional[dict[str, str]] = None, 
temporary_extraction_heuristics: Optional[ Callable[[itir.StencilClosure], Callable[[itir.Expr], bool]] ] = None, unroll_reduce: bool = False, ): - node = itir_transforms.apply_common_transforms( + node = legacy_itir_transforms.apply_common_transforms( program, common_subexpression_elimination=False, force_inline_lambda_args=True, @@ -72,7 +75,7 @@ def build_sdfg_from_itir( auto_optimize: bool = False, on_gpu: bool = False, column_axis: Optional[common.Dimension] = None, - lift_mode: itir_transforms.LiftMode = itir_transforms.LiftMode.FORCE_INLINE, + lift_mode: legacy_itir_transforms.LiftMode = legacy_itir_transforms.LiftMode.FORCE_INLINE, symbolic_domain_sizes: Optional[dict[str, str]] = None, temporary_extraction_heuristics: Optional[ Callable[[itir.StencilClosure], Callable[[itir.Expr], bool]] @@ -228,7 +231,7 @@ def __sdfg__(self, *args, **kwargs) -> dace.sdfg.sdfg.SDFG: } sdfg.offset_providers_per_input_field = {} - itir_tmp = itir_transforms.apply_common_transforms( + itir_tmp = legacy_itir_transforms.apply_common_transforms( self.itir, offset_provider=offset_provider ) itir_tmp_fencil = program_to_fencil.program_to_fencil(itir_tmp) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/workflow.py b/src/gt4py/next/program_processors/runners/dace_iterator/workflow.py index 7a442e3819..aa2e94ee68 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/workflow.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/workflow.py @@ -18,7 +18,6 @@ from gt4py._core import definitions as core_defs from gt4py.next import common, config from gt4py.next.iterator import ir as itir -from gt4py.next.iterator.transforms import LiftMode from gt4py.next.otf import languages, recipes, stages, step_types, workflow from gt4py.next.otf.binding import interface from gt4py.next.otf.languages import LanguageSettings @@ -36,7 +35,6 @@ class DaCeTranslator( step_types.TranslationStep[languages.SDFG, languages.LanguageSettings], ): auto_optimize: bool = False - lift_mode: LiftMode = LiftMode.FORCE_INLINE device_type: core_defs.DeviceType = core_defs.DeviceType.CPU symbolic_domain_sizes: Optional[dict[str, str]] = None temporary_extraction_heuristics: Optional[ @@ -69,7 +67,6 @@ def generate_sdfg( auto_optimize=self.auto_optimize, on_gpu=on_gpu, column_axis=column_axis, - lift_mode=self.lift_mode, symbolic_domain_sizes=self.symbolic_domain_sizes, temporary_extraction_heuristics=self.temporary_extraction_heuristics, load_sdfg_from_file=False, diff --git a/src/gt4py/next/program_processors/runners/gtfn.py b/src/gt4py/next/program_processors/runners/gtfn.py index 2275576081..e82c12fad2 100644 --- a/src/gt4py/next/program_processors/runners/gtfn.py +++ b/src/gt4py/next/program_processors/runners/gtfn.py @@ -17,7 +17,6 @@ import gt4py.next.allocators as next_allocators from gt4py.eve.utils import content_hash from gt4py.next import backend, common, config -from gt4py.next.iterator import transforms from gt4py.next.otf import arguments, recipes, stages, workflow from gt4py.next.otf.binding import nanobind from gt4py.next.otf.compilation import compiler @@ -166,12 +165,6 @@ class Params: ), name_cached="_cached", ) - use_temporaries = factory.Trait( - # FIXME[#1582](tehrengruber): Revisit and cleanup after new GTIR temporary pass is in place - otf_workflow__translation__lift_mode=transforms.LiftMode.USE_TEMPORARIES, - # otf_workflow__translation__temporary_extraction_heuristics=global_tmps.SimpleTemporaryExtractionHeuristics, # noqa: ERA001 - name_temps="_with_temporaries", - ) 
device_type = core_defs.DeviceType.CPU hash_function = compilation_hash otf_workflow = factory.SubFactory( @@ -195,8 +188,6 @@ class Params: run_gtfn_cached = GTFNBackendFactory(cached=True) -run_gtfn_with_temporaries = GTFNBackendFactory(use_temporaries=True) - run_gtfn_gpu = GTFNBackendFactory(gpu=True) run_gtfn_gpu_cached = GTFNBackendFactory(gpu=True, cached=True) diff --git a/src/gt4py/next/program_processors/runners/roundtrip.py b/src/gt4py/next/program_processors/runners/roundtrip.py index 93e6d09c5b..2501e96caf 100644 --- a/src/gt4py/next/program_processors/runners/roundtrip.py +++ b/src/gt4py/next/program_processors/runners/roundtrip.py @@ -20,7 +20,7 @@ from gt4py.eve import codegen from gt4py.eve.codegen import FormatTemplate as as_fmt, MakoTemplate as as_mako from gt4py.next import allocators as next_allocators, backend as next_backend, common, config -from gt4py.next.ffront import foast_to_gtir, past_to_itir +from gt4py.next.ffront import foast_to_gtir, foast_to_past, past_to_itir from gt4py.next.iterator import ir as itir, transforms as itir_transforms from gt4py.next.otf import stages, workflow from gt4py.next.type_system import type_specifications as ts @@ -92,7 +92,7 @@ def visit_Temporary(self, node: itir.Temporary, **kwargs: Any) -> str: def fencil_generator( ir: itir.Node, debug: bool, - lift_mode: itir_transforms.LiftMode, + extract_temporaries: bool, use_embedded: bool, offset_provider: dict[str, common.Connectivity | common.Dimension], ) -> stages.CompiledProgram: @@ -102,7 +102,7 @@ def fencil_generator( Arguments: ir: The iterator IR (ITIR) node. debug: Keep module source containing fencil implementation. - lift_mode: Change the way lifted function calls are evaluated. + extract_temporaries: Extract intermediate field values into temporaries. use_embedded: Directly use builtins from embedded backend instead of generic dispatcher. Gives faster performance and is easier to debug. 
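# Illustrative sketch (editorial, hedged) of the corresponding user-facing change in the
# roundtrip runner: requesting temporaries now uses a plain boolean flag instead of the
# removed LiftMode enum, matching `executor_with_temporaries` further down in this diff.
from gt4py.next.program_processors.runners import roundtrip

with_temporaries = roundtrip.Roundtrip(extract_temporaries=True)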
@@ -110,14 +110,14 @@ def fencil_generator( """ # TODO(tehrengruber): just a temporary solution until we have a proper generic # caching mechanism - cache_key = hash((ir, lift_mode, debug, use_embedded, tuple(offset_provider.items()))) + cache_key = hash((ir, extract_temporaries, debug, use_embedded, tuple(offset_provider.items()))) if cache_key in _FENCIL_CACHE: if debug: print(f"Using cached fencil for key {cache_key}") return typing.cast(stages.CompiledProgram, _FENCIL_CACHE[cache_key]) ir = itir_transforms.apply_common_transforms( - ir, lift_mode=lift_mode, offset_provider=offset_provider + ir, extract_temporaries=extract_temporaries, offset_provider=offset_provider ) program = EmbeddedDSL.apply(ir) @@ -187,18 +187,19 @@ def fencil_generator( @dataclasses.dataclass(frozen=True) class Roundtrip(workflow.Workflow[stages.CompilableProgram, stages.CompiledProgram]): debug: Optional[bool] = None - lift_mode: itir_transforms.LiftMode = itir_transforms.LiftMode.FORCE_INLINE + extract_temporaries: bool = False use_embedded: bool = True dispatch_backend: Optional[next_backend.Backend] = None def __call__(self, inp: stages.CompilableProgram) -> stages.CompiledProgram: debug = config.DEBUG if self.debug is None else self.debug + assert isinstance(inp.data, itir.Program) fencil = fencil_generator( inp.data, offset_provider=inp.args.offset_provider, debug=debug, - lift_mode=self.lift_mode, + extract_temporaries=self.extract_temporaries, use_embedded=self.use_embedded, ) @@ -211,7 +212,7 @@ def decorated_fencil( ) -> None: if out is not None: args = (*args, out) - if not column_axis: + if not column_axis: # TODO(tehrengruber): This variable is never used. Bug? column_axis = inp.args.column_axis fencil( *args, @@ -225,7 +226,7 @@ def decorated_fencil( executor = Roundtrip() -executor_with_temporaries = Roundtrip(lift_mode=itir_transforms.LiftMode.USE_TEMPORARIES) +executor_with_temporaries = Roundtrip(extract_temporaries=True) default = next_backend.Backend( name="roundtrip", @@ -240,12 +241,17 @@ def decorated_fencil( transforms=next_backend.DEFAULT_TRANSFORMS, ) +foast_to_gtir_step = foast_to_gtir.adapted_foast_to_gtir_factory(cached=True) + gtir = next_backend.Backend( name="roundtrip_gtir", executor=executor, allocator=next_allocators.StandardCPUFieldBufferAllocator(), transforms=next_backend.Transforms( past_to_itir=past_to_itir.past_to_itir_factory(to_gtir=True), - foast_to_itir=foast_to_gtir.adapted_foast_to_gtir_factory(cached=True), + foast_to_itir=foast_to_gtir_step, + field_view_op_to_prog=foast_to_past.operator_to_program_factory( + foast_to_itir_step=foast_to_gtir_step + ), ), ) diff --git a/tests/next_tests/definitions.py b/tests/next_tests/definitions.py index 123384a098..1c9af94e55 100644 --- a/tests/next_tests/definitions.py +++ b/tests/next_tests/definitions.py @@ -129,13 +129,12 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): (REQUIRES_ATLAS, XFAIL, BINDINGS_UNSUPPORTED_MESSAGE), (STARTS_FROM_GTIR_PROGRAM, SKIP, UNSUPPORTED_MESSAGE), (USES_APPLIED_SHIFTS, XFAIL, UNSUPPORTED_MESSAGE), - (USES_IF_STMTS, XFAIL, UNSUPPORTED_MESSAGE), (USES_NEGATIVE_MODULO, XFAIL, UNSUPPORTED_MESSAGE), (USES_REDUCTION_WITH_ONLY_SPARSE_FIELDS, XFAIL, REDUCTION_WITH_ONLY_SPARSE_FIELDS_MESSAGE), - (USES_SCAN_IN_FIELD_OPERATOR, XFAIL, UNSUPPORTED_MESSAGE), (USES_SPARSE_FIELDS_AS_OUTPUT, XFAIL, UNSUPPORTED_MESSAGE), ] DACE_SKIP_TEST_LIST = COMMON_SKIP_TEST_LIST + [ + (USES_IF_STMTS, XFAIL, UNSUPPORTED_MESSAGE), (USES_IR_IF_STMTS, XFAIL, UNSUPPORTED_MESSAGE), (USES_SCALAR_IN_DOMAIN_AND_FO, 
XFAIL, UNSUPPORTED_MESSAGE), (USES_INDEX_FIELDS, XFAIL, UNSUPPORTED_MESSAGE), @@ -145,6 +144,7 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): (USES_TUPLE_ARGS, XFAIL, UNSUPPORTED_MESSAGE), (USES_TUPLE_RETURNS, XFAIL, UNSUPPORTED_MESSAGE), (USES_ZERO_DIMENSIONAL_FIELDS, XFAIL, UNSUPPORTED_MESSAGE), + (USES_SCAN_IN_FIELD_OPERATOR, XFAIL, UNSUPPORTED_MESSAGE), ] GTIR_DACE_SKIP_TEST_LIST = [ (ALL, SKIP, UNSUPPORTED_MESSAGE), @@ -182,7 +182,10 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): ProgramBackendId.GTFN_GPU: GTFN_SKIP_TEST_LIST + [(USES_SCAN_NESTED, XFAIL, UNSUPPORTED_MESSAGE)], ProgramBackendId.GTFN_CPU_WITH_TEMPORARIES: GTFN_SKIP_TEST_LIST - + [(ALL, XFAIL, UNSUPPORTED_MESSAGE), (USES_DYNAMIC_OFFSETS, XFAIL, UNSUPPORTED_MESSAGE)], + + [ + # (ALL, XFAIL, UNSUPPORTED_MESSAGE), + (USES_DYNAMIC_OFFSETS, XFAIL, UNSUPPORTED_MESSAGE) + ], ProgramFormatterId.GTFN_CPP_FORMATTER: [ (USES_REDUCTION_WITH_ONLY_SPARSE_FIELDS, XFAIL, REDUCTION_WITH_ONLY_SPARSE_FIELDS_MESSAGE) ], diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py index a0e72ede8d..15d08e6daa 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py @@ -46,10 +46,9 @@ def __gt_allocator__( @pytest.fixture( params=[ next_tests.definitions.ProgramBackendId.ROUNDTRIP, - # next_tests.definitions.ProgramBackendId.GTIR_EMBEDDED, # FIXME[#1582](havogt): enable once all ingredients for GTIR are available + next_tests.definitions.ProgramBackendId.GTIR_EMBEDDED, # FIXME[#1582](havogt): enable once all ingredients for GTIR are available # noqa: ERA001 next_tests.definitions.ProgramBackendId.GTFN_CPU, next_tests.definitions.ProgramBackendId.GTFN_CPU_IMPERATIVE, - next_tests.definitions.ProgramBackendId.GTFN_CPU_WITH_TEMPORARIES, pytest.param( next_tests.definitions.ProgramBackendId.GTFN_GPU, marks=pytest.mark.requires_gpu ), diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index 36d6debf9d..c0b2f97db7 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -684,9 +684,6 @@ def simple_scan_operator(carry: float) -> float: @pytest.mark.uses_lift_expressions @pytest.mark.uses_scan_nested def test_solve_triag(cartesian_case): - if cartesian_case.backend == gtfn.run_gtfn_with_temporaries: - pytest.xfail("Temporary extraction does not work correctly in combination with scans.") - @gtx.scan_operator(axis=KDim, forward=True, init=(0.0, 0.0)) def tridiag_forward( state: tuple[float, float], a: float, b: float, c: float, d: float @@ -785,9 +782,6 @@ def testee(a: cases.EField, b: cases.EField) -> cases.VField: @pytest.mark.uses_scan def test_ternary_scan(cartesian_case): - if cartesian_case.backend in [gtfn.run_gtfn_with_temporaries]: - pytest.xfail("Temporary extraction does not work correctly in combination with scans.") - @gtx.scan_operator(axis=KDim, forward=True, init=0.0) def simple_scan_operator(carry: float, a: float) -> float: return carry if carry > a else carry + 1.0 @@ -810,9 +804,6 @@ def simple_scan_operator(carry: float, a: float) -> float: @pytest.mark.uses_scan_without_field_args 
@pytest.mark.uses_tuple_returns def test_scan_nested_tuple_output(forward, cartesian_case): - if cartesian_case.backend in [gtfn.run_gtfn_with_temporaries]: - pytest.xfail("Temporary extraction does not work correctly in combination with scans.") - init = (1, (2, 3)) k_size = cartesian_case.default_sizes[KDim] expected = np.arange(1, 1 + k_size, 1, dtype=int32) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_scalar_if.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_scalar_if.py index 0efb599f9e..f5d946c7bd 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_scalar_if.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_scalar_if.py @@ -56,6 +56,7 @@ def simple_if(a: cases.IField, b: cases.IField, condition: bool) -> cases.IField cases.verify(cartesian_case, simple_if, a, b, condition, out=out, ref=a if condition else b) +# TODO: test with fields on different domains @pytest.mark.parametrize("condition1, condition2", [[True, False], [True, False]]) @pytest.mark.uses_if_stmts def test_simple_if_conditional(condition1, condition2, cartesian_case): diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py index 0305a5841a..bb02f0a89e 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py @@ -11,8 +11,8 @@ from gt4py import next as gtx from gt4py.next import backend, common -from gt4py.next.iterator.transforms import LiftMode, apply_common_transforms -from gt4py.next.program_processors.runners.gtfn import run_gtfn_with_temporaries +from gt4py.next.iterator.transforms import apply_common_transforms +from gt4py.next.program_processors.runners.gtfn import run_gtfn from next_tests.integration_tests import cases from next_tests.integration_tests.cases import ( @@ -34,8 +34,8 @@ def run_gtfn_with_temporaries_and_symbolic_sizes(): return backend.Backend( name="run_gtfn_with_temporaries_and_sizes", transforms=backend.DEFAULT_TRANSFORMS, - executor=run_gtfn_with_temporaries.executor.replace( - translation=run_gtfn_with_temporaries.executor.translation.replace( + executor=run_gtfn.executor.replace( + translation=run_gtfn.executor.translation.replace( symbolic_domain_sizes={ "Cell": "num_cells", "Edge": "num_edges", @@ -43,7 +43,7 @@ def run_gtfn_with_temporaries_and_symbolic_sizes(): } ) ), - allocator=run_gtfn_with_temporaries.allocator, + allocator=run_gtfn.allocator, ) @@ -64,9 +64,6 @@ def prog( def test_verification(testee, run_gtfn_with_temporaries_and_symbolic_sizes, mesh_descriptor): - # FIXME[#1582](tehrengruber): enable when temporary pass has been implemented - pytest.xfail("Temporary pass not implemented.") - unstructured_case = Case( run_gtfn_with_temporaries_and_symbolic_sizes, offset_provider=mesh_descriptor.offset_provider, @@ -100,12 +97,9 @@ def test_verification(testee, run_gtfn_with_temporaries_and_symbolic_sizes, mesh def test_temporary_symbols(testee, mesh_descriptor): - # FIXME[#1582](tehrengruber): enable when temporary pass has been implemented - pytest.xfail("Temporary pass not implemented.") - itir_with_tmp = apply_common_transforms( testee.itir, - lift_mode=LiftMode.USE_TEMPORARIES, + extract_temporaries=True, offset_provider=mesh_descriptor.offset_provider, ) diff --git 
a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py index 50756f40e7..141091b450 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py @@ -84,9 +84,13 @@ def run_test_expr( domain: itir.FunCall, expected_domains: dict[str, itir.Expr | dict[str | Dimension, tuple[itir.Expr, itir.Expr]]], offset_provider: common.OffsetProvider, + symbolic_domain_sizes: Optional[dict[str, str]] = None, ): actual_call, actual_domains = infer_domain.infer_expr( - testee, domain_utils.SymbolicDomain.from_expr(domain), offset_provider + testee, + domain_utils.SymbolicDomain.from_expr(domain), + offset_provider, + symbolic_domain_sizes, ) folded_call = constant_fold_domain_exprs(actual_call) folded_domains = constant_fold_accessed_domains(actual_domains) if actual_domains else None @@ -1021,3 +1025,22 @@ def test_scan(offset_provider): {"a": im.domain(common.GridType.CARTESIAN, {IDim: (1, 12)})}, offset_provider, ) + + +def test_symbolic_domain_sizes(unstructured_offset_provider): + stencil = im.lambda_("arg0")(im.deref(im.shift("E2V", 1)("arg0"))) + domain = im.domain(common.GridType.UNSTRUCTURED, {Edge: (0, 1)}) + symbolic_domain_sizes = {"Vertex": "num_vertices"} + + testee, expected = setup_test_as_fieldop( + stencil, + domain, + ) + run_test_expr( + testee, + expected, + domain, + {"in_field1": {Vertex: (0, im.ref("num_vertices"))}}, + unstructured_offset_provider, + symbolic_domain_sizes, + ) From 068ff06b04a9980accf6fb15da306224acfe51be Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Wed, 16 Oct 2024 10:30:51 +0200 Subject: [PATCH 035/150] Cleanup --- src/gt4py/next/iterator/transforms/fuse_as_fieldop.py | 6 +++++- src/gt4py/next/iterator/transforms/pass_manager.py | 5 +++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py index 8cb54ef305..dddaa678a7 100644 --- a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py +++ b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py @@ -157,7 +157,11 @@ def visit_FunCall(self, node: itir.FunCall): # TODO(tehrengruber): make this configurable should_inline = isinstance(arg, itir.Literal) or ( isinstance(arg, itir.FunCall) - and (cpm.is_call_to(arg.fun, "as_fieldop") or cpm.is_call_to(arg, "if_")) + and ( + cpm.is_call_to(arg.fun, "as_fieldop") + and isinstance(arg.fun.args[0], itir.Lambda) + or cpm.is_call_to(arg, "if_") + ) and (isinstance(dtype, it_ts.ListType) or len(arg_shifts) <= 1) ) if should_inline: diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index d5c084a0a9..f8fa006e5b 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -15,6 +15,7 @@ global_tmps, infer_domain, inline_fundefs, + inline_lifts, ) from gt4py.next.iterator.transforms.collapse_list_get import CollapseListGet from gt4py.next.iterator.transforms.collapse_tuple import CollapseTuple @@ -117,6 +118,10 @@ def apply_common_transforms( ir = unrolled ir = CollapseListGet().visit(ir) ir = NormalizeShifts().visit(ir) + # this is required as nested neighbor reductions can contain lifts, e.g., + # `neighbors(V2Eₒ, ↑f(...))` + ir = inline_lifts.InlineLifts().visit(ir) + ir = 
NormalizeShifts().visit(ir) else: raise RuntimeError("Reduction unrolling failed.") From aaba729ac0f28faf17d9de14e21a14aa8c30116d Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Wed, 16 Oct 2024 10:34:27 +0200 Subject: [PATCH 036/150] Cleanup --- .../feature_tests/ffront_tests/test_decorator.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_decorator.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_decorator.py index e3e919e52e..47419c278b 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_decorator.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_decorator.py @@ -30,8 +30,10 @@ def testee_op(a: cases.IField) -> cases.IField: def testee(a: cases.IField, out: cases.IField): testee_op(a, out=out) - assert isinstance(testee.itir, itir.FencilDefinition) - assert isinstance(testee.with_backend(cartesian_case.backend).itir, itir.FencilDefinition) + assert isinstance(testee.itir, (itir.Program, itir.FencilDefinition)) + assert isinstance( + testee.with_backend(cartesian_case.backend).itir, (itir.Program, itir.FencilDefinition) + ) def test_frozen(cartesian_case): From 6044d760dd575e204c0b8c10039daca6809e0fb6 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Wed, 16 Oct 2024 14:15:58 +0200 Subject: [PATCH 037/150] Cleanup --- .../iterator/transforms/collapse_tuple.py | 83 ++++++++++++------- .../iterator/transforms/fuse_as_fieldop.py | 14 +++- .../next/iterator/transforms/pass_manager.py | 5 ++ .../transforms/pass_manager_legacy.py | 1 - .../program_processors/runners/roundtrip.py | 1 - tests/next_tests/unit_tests/conftest.py | 1 - .../transforms_tests/test_fuse_as_fieldop.py | 15 ++++ 7 files changed, 82 insertions(+), 38 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/collapse_tuple.py b/src/gt4py/next/iterator/transforms/collapse_tuple.py index 2a2608081c..cfabf5b6d1 100644 --- a/src/gt4py/next/iterator/transforms/collapse_tuple.py +++ b/src/gt4py/next/iterator/transforms/collapse_tuple.py @@ -87,7 +87,6 @@ def all(self) -> CollapseTuple.Flag: return functools.reduce(operator.or_, self.__members__.values()) ignore_tuple_size: bool - field_view_only: bool flags: Flag = Flag.all() # noqa: RUF009 [function-call-in-dataclass-default-argument] PRESERVED_ANNEX_ATTRS = ("type",) @@ -106,7 +105,7 @@ def apply( ignore_tuple_size: bool = False, remove_letified_make_tuple_elements: bool = True, offset_provider=None, - field_view_only=True, + is_local_view=None, # manually passing flags is mostly for allowing separate testing of the modes flags=None, # allow sym references without a symbol declaration, mostly for testing @@ -128,6 +127,13 @@ def apply( flags = flags or cls.flags offset_provider = offset_provider or {} + if isinstance(node, (ir.Program, ir.FencilDefinition)): + is_local_view = False + assert is_local_view in [ + True, + False, + ], "Parameter 'is_local_view' mandatory if node is not a 'Program'." + if not ignore_tuple_size: node = itir_type_inference.infer( node, @@ -137,9 +143,8 @@ def apply( new_node = cls( ignore_tuple_size=ignore_tuple_size, - field_view_only=field_view_only, flags=flags, - ).visit(node) + ).visit(node, is_local_view=is_local_view) # inline to remove left-overs from LETIFY_MAKE_TUPLE_ELEMENTS. 
this is important # as otherwise two equal expressions containing a tuple will not be equal anymore @@ -153,24 +158,24 @@ def apply( return new_node - def visit_FunCall(self, node: ir.FunCall) -> ir.Node: + def visit_FunCall(self, node: ir.FunCall, **kwargs) -> ir.Node: # don't visit stencil argument of `as_fieldop` - if self.field_view_only and cpm.is_call_to(node, "as_fieldop"): - return node + if cpm.is_call_to(node, "as_fieldop"): + kwargs = {**kwargs, "is_local_view": True} - node = self.generic_visit(node) - return self.fp_transform(node) + node = self.generic_visit(node, **kwargs) + return self.fp_transform(node, **kwargs) - def fp_transform(self, node: ir.Node) -> ir.Node: + def fp_transform(self, node: ir.Node, **kwargs) -> ir.Node: while True: - new_node = self.transform(node) + new_node = self.transform(node, **kwargs) if new_node is None: break assert new_node != node node = new_node return node - def transform(self, node: ir.Node) -> Optional[ir.Node]: + def transform(self, node: ir.Node, **kwargs) -> Optional[ir.Node]: if not isinstance(node, ir.FunCall): return None @@ -178,12 +183,14 @@ def transform(self, node: ir.Node) -> Optional[ir.Node]: if self.flags & transformation: assert isinstance(transformation.name, str) method = getattr(self, f"transform_{transformation.name.lower()}") - result = method(node) + result = method(node, **kwargs) if result is not None: return result return None - def transform_collapse_make_tuple_tuple_get(self, node: ir.FunCall) -> Optional[ir.Node]: + def transform_collapse_make_tuple_tuple_get( + self, node: ir.FunCall, **kwargs + ) -> Optional[ir.Node]: if node.fun == ir.SymRef(id="make_tuple") and all( isinstance(arg, ir.FunCall) and arg.fun == ir.SymRef(id="tuple_get") for arg in node.args @@ -204,7 +211,9 @@ def transform_collapse_make_tuple_tuple_get(self, node: ir.FunCall) -> Optional[ return first_expr return None - def transform_collapse_tuple_get_make_tuple(self, node: ir.FunCall) -> Optional[ir.Node]: + def transform_collapse_tuple_get_make_tuple( + self, node: ir.FunCall, **kwargs + ) -> Optional[ir.Node]: if ( node.fun == ir.SymRef(id="tuple_get") and isinstance(node.args[1], ir.FunCall) @@ -221,7 +230,7 @@ def transform_collapse_tuple_get_make_tuple(self, node: ir.FunCall) -> Optional[ return node.args[1].args[idx] return None - def transform_propagate_tuple_get(self, node: ir.FunCall) -> Optional[ir.Node]: + def transform_propagate_tuple_get(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: if node.fun == ir.SymRef(id="tuple_get") and isinstance(node.args[0], ir.Literal): # TODO(tehrengruber): extend to general symbols as long as the tail call in the let # does not capture @@ -230,7 +239,7 @@ def transform_propagate_tuple_get(self, node: ir.FunCall) -> Optional[ir.Node]: idx, let_expr = node.args return im.call( im.lambda_(*let_expr.fun.params)( # type: ignore[attr-defined] # ensured by is_let - self.fp_transform(im.tuple_get(idx.value, let_expr.fun.expr)) # type: ignore[attr-defined] # ensured by is_let + self.fp_transform(im.tuple_get(idx.value, let_expr.fun.expr), **kwargs) # type: ignore[attr-defined] # ensured by is_let ) )( *let_expr.args # type: ignore[attr-defined] # ensured by is_let @@ -240,12 +249,12 @@ def transform_propagate_tuple_get(self, node: ir.FunCall) -> Optional[ir.Node]: cond, true_branch, false_branch = node.args[1].args return im.if_( cond, - self.fp_transform(im.tuple_get(idx.value, true_branch)), - self.fp_transform(im.tuple_get(idx.value, false_branch)), + self.fp_transform(im.tuple_get(idx.value, 
true_branch), **kwargs), + self.fp_transform(im.tuple_get(idx.value, false_branch), **kwargs), ) return None - def transform_letify_make_tuple_elements(self, node: ir.FunCall) -> Optional[ir.Node]: + def transform_letify_make_tuple_elements(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: if node.fun == ir.SymRef(id="make_tuple"): # `make_tuple(expr1, expr1)` # -> `let((_tuple_el_1, expr1), (_tuple_el_2, expr2))(make_tuple(_tuple_el_1, _tuple_el_2))` @@ -260,21 +269,24 @@ def transform_letify_make_tuple_elements(self, node: ir.FunCall) -> Optional[ir. new_args.append(arg) if bound_vars: - return self.fp_transform(im.let(*bound_vars.items())(im.call(node.fun)(*new_args))) + return self.fp_transform( + im.let(*bound_vars.items())(im.call(node.fun)(*new_args)), **kwargs + ) return None - def transform_inline_trivial_make_tuple(self, node: ir.FunCall) -> Optional[ir.Node]: + def transform_inline_trivial_make_tuple(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: if cpm.is_let(node): # `let(tup, make_tuple(trivial_expr1, trivial_expr2))(foo(tup))` # -> `foo(make_tuple(trivial_expr1, trivial_expr2))` eligible_params = [_is_trivial_make_tuple_call(arg) for arg in node.args] if any(eligible_params): - return self.visit(inline_lambda(node, eligible_params=eligible_params)) + return self.visit(inline_lambda(node, eligible_params=eligible_params), **kwargs) return None - def transform_propagate_to_if_on_tuples(self, node: ir.FunCall) -> Optional[ir.Node]: - if not cpm.is_call_to(node, "if_"): - # TODO(tehrengruber): This significantly increases the size of the tree. Revisit. + def transform_propagate_to_if_on_tuples(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: + # TODO(tehrengruber): This significantly increases the size of the tree. Skip transformation + # in local-view for now. Revisit. + if not cpm.is_call_to(node, "if_") and not kwargs["is_local_view"]: # TODO(tehrengruber): Only inline if type of branch value is a tuple. 
# Examples: # `(if cond then {1, 2} else {3, 4})[0]` -> `if cond then {1, 2}[0] else {3, 4}[0]` @@ -283,12 +295,16 @@ def transform_propagate_to_if_on_tuples(self, node: ir.FunCall) -> Optional[ir.N for i, arg in enumerate(node.args): if cpm.is_call_to(arg, "if_"): cond, true_branch, false_branch = arg.args - new_true_branch = self.fp_transform(_with_altered_arg(node, i, true_branch)) - new_false_branch = self.fp_transform(_with_altered_arg(node, i, false_branch)) + new_true_branch = self.fp_transform( + _with_altered_arg(node, i, true_branch), **kwargs + ) + new_false_branch = self.fp_transform( + _with_altered_arg(node, i, false_branch), **kwargs + ) return im.if_(cond, new_true_branch, new_false_branch) return None - def transform_propagate_nested_let(self, node: ir.FunCall) -> Optional[ir.Node]: + def transform_propagate_nested_let(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: if cpm.is_let(node): # `let((a, let(b, 1)(a_val)))(a)`-> `let(b, 1)(let(a, a_val)(a))` outer_vars = {} @@ -306,12 +322,15 @@ def transform_propagate_nested_let(self, node: ir.FunCall) -> Optional[ir.Node]: if outer_vars: return self.fp_transform( im.let(*outer_vars.items())( - self.fp_transform(im.let(*inner_vars.items())(original_inner_expr)) - ) + self.fp_transform( + im.let(*inner_vars.items())(original_inner_expr), **kwargs + ) + ), + **kwargs, ) return None - def transform_inline_trivial_let(self, node: ir.FunCall) -> Optional[ir.Node]: + def transform_inline_trivial_let(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: if cpm.is_let(node) and isinstance(node.fun.expr, ir.SymRef): # type: ignore[attr-defined] # ensured by is_let # `let(a, 1)(a)` -> `1` for arg_sym, arg in zip(node.fun.params, node.args): # type: ignore[attr-defined] # ensured by is_let diff --git a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py index dddaa678a7..8928b406f3 100644 --- a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py +++ b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py @@ -54,6 +54,14 @@ def _canonicalize_as_fieldop(expr: itir.FunCall) -> itir.FunCall: return expr +def _is_tuple_expr_of_literals(expr: itir.Expr): + if cpm.is_call_to(expr, "make_tuple"): + return all(_is_tuple_expr_of_literals(arg) for arg in expr.args) + if cpm.is_call_to(expr, "tuple_get"): + return _is_tuple_expr_of_literals(expr.args[1]) + return isinstance(expr, itir.Literal) + + @dataclasses.dataclass class FuseAsFieldOp(eve.NodeTranslator): """ @@ -153,9 +161,9 @@ def visit_FunCall(self, node: itir.FunCall): for stencil_param, arg, arg_shifts in zip(stencil.params, args, shifts, strict=True): assert isinstance(arg.type, ts.TypeSpec) - dtype = type_info.extract_dtype(arg.type) + dtype = type_info.apply_to_primitive_constituents(type_info.extract_dtype, arg.type) # TODO(tehrengruber): make this configurable - should_inline = isinstance(arg, itir.Literal) or ( + should_inline = _is_tuple_expr_of_literals(arg) or ( isinstance(arg, itir.FunCall) and ( cpm.is_call_to(arg.fun, "as_fieldop") @@ -172,7 +180,7 @@ def visit_FunCall(self, node: itir.FunCall): type_ = arg.type arg = im.op_as_fieldop("if_")(*arg.args) arg.type = type_ - elif isinstance(arg, itir.Literal): + elif _is_tuple_expr_of_literals(arg): arg = im.op_as_fieldop(im.lambda_()(arg))() else: raise NotImplementedError() diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index f8fa006e5b..e711021cf9 100644 --- 
a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -11,6 +11,7 @@ from gt4py.eve import utils as eve_utils from gt4py.next.iterator import ir as itir from gt4py.next.iterator.transforms import ( + fencil_to_program, fuse_as_fieldop, global_tmps, infer_domain, @@ -46,6 +47,9 @@ def apply_common_transforms( ] = None, symbolic_domain_sizes: Optional[dict[str, str]] = None, ) -> itir.Program: + # FIXME[#1582](tehrengruber): Rewrite iterator tests with itir.Program and remove this + if isinstance(ir, itir.FencilDefinition): + ir = fencil_to_program.FencilToProgram.apply(ir) assert isinstance(ir, itir.Program) tmp_uids = eve_utils.UIDGenerator(prefix="__tmp") @@ -60,6 +64,7 @@ def apply_common_transforms( # note: this increases the size of the tree # Inline. The domain inference can not handle "user" functions, e.g. `let f = λ(...) → ... in f(...)` ir = InlineLambdas.apply(ir, opcount_preserving=True, force_inline_lambda_args=True) + # todo: run collapse tuple ir = infer_domain.infer_program( ir, # type: ignore[arg-type] # always an itir.Program offset_provider=offset_provider, diff --git a/src/gt4py/next/iterator/transforms/pass_manager_legacy.py b/src/gt4py/next/iterator/transforms/pass_manager_legacy.py index 9933fcd4ae..792bb421f1 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager_legacy.py +++ b/src/gt4py/next/iterator/transforms/pass_manager_legacy.py @@ -112,7 +112,6 @@ def apply_common_transforms( offset_provider=offset_provider, # TODO(tehrengruber): disabled since it increases compile-time too much right now flags=~CollapseTuple.Flag.PROPAGATE_TO_IF_ON_TUPLES, - field_view_only=False, ) # This pass is required such that a deref outside of a # `tuple_get(make_tuple(let(...), ...))` call is propagated into the let after the diff --git a/src/gt4py/next/program_processors/runners/roundtrip.py b/src/gt4py/next/program_processors/runners/roundtrip.py index 2501e96caf..f6983d81f5 100644 --- a/src/gt4py/next/program_processors/runners/roundtrip.py +++ b/src/gt4py/next/program_processors/runners/roundtrip.py @@ -194,7 +194,6 @@ class Roundtrip(workflow.Workflow[stages.CompilableProgram, stages.CompiledProgr def __call__(self, inp: stages.CompilableProgram) -> stages.CompiledProgram: debug = config.DEBUG if self.debug is None else self.debug - assert isinstance(inp.data, itir.Program) fencil = fencil_generator( inp.data, offset_provider=inp.args.offset_provider, diff --git a/tests/next_tests/unit_tests/conftest.py b/tests/next_tests/unit_tests/conftest.py index 8a4aa50730..87cdafc025 100644 --- a/tests/next_tests/unit_tests/conftest.py +++ b/tests/next_tests/unit_tests/conftest.py @@ -33,7 +33,6 @@ (next_tests.definitions.ProgramBackendId.DOUBLE_ROUNDTRIP, True), (next_tests.definitions.ProgramBackendId.GTFN_CPU, True), (next_tests.definitions.ProgramBackendId.GTFN_CPU_IMPERATIVE, True), - (next_tests.definitions.ProgramBackendId.GTFN_CPU_WITH_TEMPORARIES, True), # pytest.param((definitions.ProgramBackendId.GTFN_GPU, True), marks=pytest.mark.requires_gpu), # TODO(havogt): update tests to use proper allocation (next_tests.definitions.ProgramFormatterId.LISP_FORMATTER, False), (next_tests.definitions.ProgramFormatterId.ITIR_PRETTY_PRINTER, False), diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_fuse_as_fieldop.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_fuse_as_fieldop.py index da2c16336e..273d5afd6b 100644 --- 
a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_fuse_as_fieldop.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_fuse_as_fieldop.py @@ -45,6 +45,21 @@ def test_trivial_literal(): assert actual == expected +def test_tuple_arg(): + d = im.domain("cartesian_domain", {}) + testee = im.op_as_fieldop("plus", d)( + im.op_as_fieldop(im.lambda_("t")(im.plus(im.tuple_get(0, "t"), im.tuple_get(1, "t"))), d)( + im.make_tuple(1, 2) + ), + 3, + ) + expected = im.as_fieldop(im.lambda_()(im.plus(im.multiplies_(1, 2), 3)), d)() + actual = fuse_as_fieldop.FuseAsFieldOp.apply( + testee, offset_provider={}, allow_undeclared_symbols=True + ) + assert actual == expected + + def test_symref_used_twice(): d = im.domain("cartesian_domain", {IDim: (0, 1)}) testee = im.as_fieldop(im.lambda_("a", "b")(im.plus(im.deref("a"), im.deref("b"))), d)( From 1dc9ebbc38c15a30b34384e86095f696b9a90360 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Wed, 16 Oct 2024 14:17:20 +0200 Subject: [PATCH 038/150] Cleanup --- src/gt4py/next/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gt4py/next/config.py b/src/gt4py/next/config.py index 4f53e3c535..ed244c2932 100644 --- a/src/gt4py/next/config.py +++ b/src/gt4py/next/config.py @@ -56,7 +56,7 @@ def env_flag_to_bool(name: str, default: bool) -> bool: #: Master debug flag #: Changes defaults for all the other options to be as helpful for debugging as possible. #: Does not override values set in environment variables. -DEBUG: Final[bool] = env_flag_to_bool(f"{_PREFIX}_DEBUG", default=True) +DEBUG: Final[bool] = env_flag_to_bool(f"{_PREFIX}_DEBUG", default=False) #: Verbose flag for DSL compilation errors From 685bedbbe6e9329e0c3f1e53569046e47226c40f Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Wed, 16 Oct 2024 14:28:58 +0200 Subject: [PATCH 039/150] Cleanup --- .../iterator_tests/test_vertical_advection.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_vertical_advection.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_vertical_advection.py index a89f250571..5d34912c94 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_vertical_advection.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_vertical_advection.py @@ -12,7 +12,6 @@ import gt4py.next as gtx from gt4py.next.iterator.builtins import * from gt4py.next.iterator.runtime import closure, fendef, fundef -from gt4py.next.iterator.transforms import LiftMode from gt4py.next.program_processors.formatters import gtfn as gtfn_formatters from gt4py.next.program_processors.runners import gtfn From 378b3b3e4ddda5a14f19c07cee9ce84522cab2d2 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Wed, 16 Oct 2024 14:29:05 +0200 Subject: [PATCH 040/150] Cleanup --- .../multi_feature_tests/iterator_tests/test_fvm_nabla.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_fvm_nabla.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_fvm_nabla.py index 156bc1c37f..3db4497910 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_fvm_nabla.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_fvm_nabla.py @@ -30,7 +30,6 @@ unstructured_domain, ) from gt4py.next.iterator.runtime import closure, fendef, fundef, offset -from 
gt4py.next.iterator.transforms.pass_manager import LiftMode from next_tests.integration_tests.multi_feature_tests.fvm_nabla_setup import ( assert_close, From 83e5ce2580ca8a97e2863c42cbcb829fde5c647f Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Wed, 16 Oct 2024 16:00:09 +0200 Subject: [PATCH 041/150] Cleanup --- .../next/iterator/transforms/infer_domain.py | 13 +++++++++++++ .../next/iterator/transforms/inline_lambdas.py | 10 +++++++--- .../next/iterator/transforms/pass_manager.py | 3 ++- .../next/iterator/transforms/remap_symbols.py | 8 ++++---- src/gt4py/next/iterator/type_system/inference.py | 15 ++++++++------- .../next/iterator/type_system/type_synthesizer.py | 3 +++ .../transforms_tests/test_fuse_as_fieldop.py | 12 +++++++++++- 7 files changed, 48 insertions(+), 16 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/infer_domain.py b/src/gt4py/next/iterator/transforms/infer_domain.py index 340ae7c53e..f15b54b89e 100644 --- a/src/gt4py/next/iterator/transforms/infer_domain.py +++ b/src/gt4py/next/iterator/transforms/infer_domain.py @@ -12,6 +12,7 @@ import typing from typing import Callable, Optional, TypeAlias +from gt4py import eve from gt4py.eve import utils as eve_utils from gt4py.next import common from gt4py.next.iterator import ir as itir @@ -28,6 +29,18 @@ ACCESSED_DOMAINS: TypeAlias = dict[str, DOMAIN] +class DomainAnnexDebugger(eve.NodeVisitor): + """ + Small utility class to debug missing domain attribute in annex. + """ + + def visit_Node(self, node: itir.Node): + if cpm.is_applied_as_fieldop(node): + if not hasattr(node.annex, "domain"): + breakpoint() # noqa: T100 + return self.generic_visit(node) + + def _split_dict_by_key(pred: Callable, d: dict): """ Split dictionary into two based on predicate. diff --git a/src/gt4py/next/iterator/transforms/inline_lambdas.py b/src/gt4py/next/iterator/transforms/inline_lambdas.py index 920d628166..55eb002b83 100644 --- a/src/gt4py/next/iterator/transforms/inline_lambdas.py +++ b/src/gt4py/next/iterator/transforms/inline_lambdas.py @@ -98,7 +98,7 @@ def new_name(name): new_expr.location = node.location return new_expr else: - return ir.FunCall( + new_expr = ir.FunCall( fun=ir.Lambda( params=[ param @@ -110,6 +110,10 @@ def new_name(name): args=[arg for arg, eligible in zip(node.args, eligible_params) if not eligible], location=node.location, ) + for attr in ("type", "recorded_shifts", "domain"): + if hasattr(node.annex, attr): + setattr(new_expr.annex, attr, getattr(node.annex, attr)) + return new_expr @dataclasses.dataclass @@ -117,10 +121,10 @@ class InlineLambdas(PreserveLocationVisitor, NodeTranslator): """ Inline lambda calls by substituting every argument by its value. - Note: This pass preserves, but doesn't use the `type` and `recorded_shifts` annex. + Note: This pass preserves, but doesn't use the `type` `recorded_shifts`, `domain` annex. """ - PRESERVED_ANNEX_ATTRS = ("type", "recorded_shifts") + PRESERVED_ANNEX_ATTRS = ("type", "recorded_shifts", "domain") opcount_preserving: bool diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index e711021cf9..27a8d3c558 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -64,7 +64,8 @@ def apply_common_transforms( # note: this increases the size of the tree # Inline. The domain inference can not handle "user" functions, e.g. `let f = λ(...) → ... 
in f(...)` ir = InlineLambdas.apply(ir, opcount_preserving=True, force_inline_lambda_args=True) - # todo: run collapse tuple + # required in order to get rid of expressions without a domain (e.g. when a tuple element is never accessed) + ir = CollapseTuple.apply(ir, offset_provider=offset_provider) ir = infer_domain.infer_program( ir, # type: ignore[arg-type] # always an itir.Program offset_provider=offset_provider, diff --git a/src/gt4py/next/iterator/transforms/remap_symbols.py b/src/gt4py/next/iterator/transforms/remap_symbols.py index 02180a3699..08d896121d 100644 --- a/src/gt4py/next/iterator/transforms/remap_symbols.py +++ b/src/gt4py/next/iterator/transforms/remap_symbols.py @@ -13,8 +13,8 @@ class RemapSymbolRefs(PreserveLocationVisitor, NodeTranslator): - # This pass preserves, but doesn't use the `type` and `recorded_shifts` annex. - PRESERVED_ANNEX_ATTRS = ("type", "recorded_shifts") + # This pass preserves, but doesn't use the `type`, `recorded_shifts`, `domain` annex. + PRESERVED_ANNEX_ATTRS = ("type", "recorded_shifts", "domain") def visit_SymRef(self, node: ir.SymRef, *, symbol_map: Dict[str, ir.Node]): return symbol_map.get(str(node.id), node) @@ -32,8 +32,8 @@ def generic_visit(self, node: ir.Node, **kwargs: Any): # type: ignore[override] class RenameSymbols(PreserveLocationVisitor, NodeTranslator): - # This pass preserves, but doesn't use the `type` and `recorded_shifts` annex. - PRESERVED_ANNEX_ATTRS = ("type", "recorded_shifts") + # This pass preserves, but doesn't use the `type`, `recorded_shifts`, `domain` annex. + PRESERVED_ANNEX_ATTRS = ("type", "recorded_shifts", "domain") def visit_Sym( self, node: ir.Sym, *, name_map: Dict[str, str], active: Optional[Set[str]] = None diff --git a/src/gt4py/next/iterator/type_system/inference.py b/src/gt4py/next/iterator/type_system/inference.py index 4640aa11d1..6f9a59b037 100644 --- a/src/gt4py/next/iterator/type_system/inference.py +++ b/src/gt4py/next/iterator/type_system/inference.py @@ -435,7 +435,7 @@ def visit(self, node: concepts.RootNode, **kwargs: Any) -> Any: result = super().visit(node, **kwargs) if isinstance(node, itir.Node): if isinstance(result, ts.TypeSpec): - if node.type: + if node.type and not isinstance(node.type, ts.DeferredType): assert _is_compatible_type(node.type, result) node.type = result elif isinstance(result, ObservableTypeSynthesizer) or result is None: @@ -511,17 +511,18 @@ def visit_SetAt(self, node: itir.SetAt, *, ctx) -> None: path, node.expr.type, ) - assert isinstance(target_type, ts.FieldType) - assert isinstance(expr_type, ts.FieldType) + assert isinstance(target_type, (ts.FieldType, ts.DeferredType)) + assert isinstance(expr_type, (ts.FieldType, ts.DeferredType)) # TODO(tehrengruber): The lowering emits domains that always have the horizontal domain # first. Since the expr inherits the ordering from the domain this can lead to a mismatch # between the target and expr (e.g. when the target has dimension K, Vertex). We should # probably just change the behaviour of the lowering. Until then we do this more # complicated comparison. 
- assert ( - set(expr_type.dims) == set(target_type.dims) - and target_type.dtype == expr_type.dtype - ) + if isinstance(target_type, ts.FieldType) and isinstance(expr_type, ts.FieldType): + assert ( + set(expr_type.dims) == set(target_type.dims) + and target_type.dtype == expr_type.dtype + ) # TODO(tehrengruber): Remove after new ITIR format with apply_stencil is used everywhere def visit_StencilClosure(self, node: itir.StencilClosure, *, ctx) -> it_ts.StencilClosureType: diff --git a/src/gt4py/next/iterator/type_system/type_synthesizer.py b/src/gt4py/next/iterator/type_system/type_synthesizer.py index c836de1391..f30dfc0fcf 100644 --- a/src/gt4py/next/iterator/type_system/type_synthesizer.py +++ b/src/gt4py/next/iterator/type_system/type_synthesizer.py @@ -291,6 +291,9 @@ def as_fieldop( @TypeSynthesizer def applied_as_fieldop(*fields) -> ts.FieldType | ts.DeferredType: + if any(isinstance(f, ts.DeferredType) for f in fields): + return ts.DeferredType(constraint=ts.FieldType) + stencil_return = stencil( *(_convert_as_fieldop_input_to_iterator(domain, field) for field in fields), offset_provider=offset_provider, diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_fuse_as_fieldop.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_fuse_as_fieldop.py index 273d5afd6b..b5b9a62009 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_fuse_as_fieldop.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_fuse_as_fieldop.py @@ -53,7 +53,17 @@ def test_tuple_arg(): ), 3, ) - expected = im.as_fieldop(im.lambda_()(im.plus(im.multiplies_(1, 2), 3)), d)() + expected = im.as_fieldop( + im.lambda_()( + im.plus( + im.let("t", im.make_tuple(1, 2))( + im.plus(im.tuple_get(0, "t"), im.tuple_get(1, "t")) + ), + 3, + ) + ), + d, + )() actual = fuse_as_fieldop.FuseAsFieldOp.apply( testee, offset_provider={}, allow_undeclared_symbols=True ) From b3ae17b679fd60f250865dde6a7fb9aad9913198 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Wed, 16 Oct 2024 16:26:29 +0200 Subject: [PATCH 042/150] Cleanup --- src/gt4py/next/ffront/foast_to_gtir.py | 2 +- .../unit_tests/ffront_tests/test_foast_to_gtir.py | 11 +++++++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/src/gt4py/next/ffront/foast_to_gtir.py b/src/gt4py/next/ffront/foast_to_gtir.py index fd6e082477..9cb0ce05f5 100644 --- a/src/gt4py/next/ffront/foast_to_gtir.py +++ b/src/gt4py/next/ffront/foast_to_gtir.py @@ -374,7 +374,7 @@ def create_if(true_: itir.Expr, false_: itir.Expr) -> itir.FunCall: def _visit_broadcast(self, node: foast.Call, **kwargs: Any) -> itir.FunCall: expr = self.visit(node.args[0], **kwargs) - if type_info.is_type_or_tuple_of_type(node.args[0].type, ts.ScalarType): + if isinstance(node.args[0].type, ts.ScalarType): return im.as_fieldop(im.ref("deref"))(expr) return expr diff --git a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py index 3951c410dc..09f18246dc 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py @@ -916,3 +916,14 @@ def foo(inp: gtx.Field[[TDim], float64]): assert lowered.id == "foo" assert lowered.expr == im.ref("inp") + + +def test_scalar_broadcast(): + def foo(): + return broadcast(1, (UDim, TDim)) + + parsed = FieldOperatorParser.apply_to_function(foo) + lowered = FieldOperatorLowering.apply(parsed) + + assert lowered.id == "foo" + assert 
lowered.expr == im.as_fieldop("deref")(1) From dbd71a9c04b63f2bf73d1954f3b7f813cfd0ced1 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Wed, 16 Oct 2024 17:01:27 +0200 Subject: [PATCH 043/150] Small fix --- .pre-commit-config.yaml | 1 - src/gt4py/next/iterator/type_system/inference.py | 2 +- src/gt4py/next/type_system/type_info.py | 4 +++- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 5e0314bca3..7088f8febf 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -43,7 +43,6 @@ repos: - id: check-merge-conflict - id: check-toml - id: check-yaml - - id: debug-statements - repo: https://github.com/astral-sh/ruff-pre-commit ##[[[cog diff --git a/src/gt4py/next/iterator/type_system/inference.py b/src/gt4py/next/iterator/type_system/inference.py index 6f9a59b037..dff81ce40f 100644 --- a/src/gt4py/next/iterator/type_system/inference.py +++ b/src/gt4py/next/iterator/type_system/inference.py @@ -84,7 +84,7 @@ def _is_compatible_type(type_a: ts.TypeSpec, type_b: ts.TypeSpec): is_compatible &= _is_compatible_type(arg_a, arg_b) is_compatible &= _is_compatible_type(type_a.returns, type_b.returns) else: - is_compatible &= type_a == type_b + is_compatible &= type_info.is_concretizable(type_a, type_b) return is_compatible diff --git a/src/gt4py/next/type_system/type_info.py b/src/gt4py/next/type_system/type_info.py index 5bda9a6f2e..66f8937dc5 100644 --- a/src/gt4py/next/type_system/type_info.py +++ b/src/gt4py/next/type_system/type_info.py @@ -459,7 +459,9 @@ def is_concretizable(symbol_type: ts.TypeSpec, to_type: ts.TypeSpec) -> bool: """ if isinstance(symbol_type, ts.DeferredType) and ( - symbol_type.constraint is None or issubclass(type_class(to_type), symbol_type.constraint) + symbol_type.constraint is None + or (isinstance(to_type, ts.DeferredType) and to_type.constraint is None) + or issubclass(type_class(to_type), symbol_type.constraint) ): return True elif is_concrete(symbol_type): From f54598412308afe6e0b690fa2d7f6fbb39ab7c30 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Wed, 16 Oct 2024 17:11:05 +0200 Subject: [PATCH 044/150] Remove superfluous test backend --- .../next/program_processors/runners/roundtrip.py | 15 +-------------- .../ffront_tests/ffront_test_utils.py | 1 - 2 files changed, 1 insertion(+), 15 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/roundtrip.py b/src/gt4py/next/program_processors/runners/roundtrip.py index f6983d81f5..8dfa1f823f 100644 --- a/src/gt4py/next/program_processors/runners/roundtrip.py +++ b/src/gt4py/next/program_processors/runners/roundtrip.py @@ -20,7 +20,7 @@ from gt4py.eve import codegen from gt4py.eve.codegen import FormatTemplate as as_fmt, MakoTemplate as as_mako from gt4py.next import allocators as next_allocators, backend as next_backend, common, config -from gt4py.next.ffront import foast_to_gtir, foast_to_past, past_to_itir +from gt4py.next.ffront import foast_to_gtir from gt4py.next.iterator import ir as itir, transforms as itir_transforms from gt4py.next.otf import stages, workflow from gt4py.next.type_system import type_specifications as ts @@ -241,16 +241,3 @@ def decorated_fencil( ) foast_to_gtir_step = foast_to_gtir.adapted_foast_to_gtir_factory(cached=True) - -gtir = next_backend.Backend( - name="roundtrip_gtir", - executor=executor, - allocator=next_allocators.StandardCPUFieldBufferAllocator(), - transforms=next_backend.Transforms( - past_to_itir=past_to_itir.past_to_itir_factory(to_gtir=True), - 
foast_to_itir=foast_to_gtir_step, - field_view_op_to_prog=foast_to_past.operator_to_program_factory( - foast_to_itir_step=foast_to_gtir_step - ), - ), -) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py index 15d08e6daa..765ee9d686 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py @@ -46,7 +46,6 @@ def __gt_allocator__( @pytest.fixture( params=[ next_tests.definitions.ProgramBackendId.ROUNDTRIP, - next_tests.definitions.ProgramBackendId.GTIR_EMBEDDED, # FIXME[#1582](havogt): enable once all ingredients for GTIR are available # noqa: ERA001 next_tests.definitions.ProgramBackendId.GTFN_CPU, next_tests.definitions.ProgramBackendId.GTFN_CPU_IMPERATIVE, pytest.param( From 0904d88ba5e586f9d3a64744c3b1540abf56b6a3 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Wed, 16 Oct 2024 17:12:37 +0200 Subject: [PATCH 045/150] Cleanup --- tests/next_tests/definitions.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/next_tests/definitions.py b/tests/next_tests/definitions.py index 1c9af94e55..946361b1d1 100644 --- a/tests/next_tests/definitions.py +++ b/tests/next_tests/definitions.py @@ -183,7 +183,6 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): + [(USES_SCAN_NESTED, XFAIL, UNSUPPORTED_MESSAGE)], ProgramBackendId.GTFN_CPU_WITH_TEMPORARIES: GTFN_SKIP_TEST_LIST + [ - # (ALL, XFAIL, UNSUPPORTED_MESSAGE), (USES_DYNAMIC_OFFSETS, XFAIL, UNSUPPORTED_MESSAGE) ], ProgramFormatterId.GTFN_CPP_FORMATTER: [ From 6f6c65b821ab7baa1eee8e51e68db53cf19a0ae6 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Wed, 16 Oct 2024 17:15:32 +0200 Subject: [PATCH 046/150] Cleanup --- src/gt4py/next/ffront/foast_to_past.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gt4py/next/ffront/foast_to_past.py b/src/gt4py/next/ffront/foast_to_past.py index a3b6c00ffa..330bc79809 100644 --- a/src/gt4py/next/ffront/foast_to_past.py +++ b/src/gt4py/next/ffront/foast_to_past.py @@ -68,7 +68,7 @@ class OperatorToProgram(workflow.Workflow[AOT_FOP, AOT_PRG]): ... def copy(a: gtx.Field[[IDim], gtx.float32]) -> gtx.Field[[IDim], gtx.float32]: ... return a - >>> op_to_prog = OperatorToProgram(foast_to_itir.adapted_foast_to_itir_factory()) + >>> op_to_prog = OperatorToProgram(foast_to_gtir.adapted_foast_to_gtir_factory()) >>> compile_time_args = arguments.CompileTimeArgs( ... 
args=tuple(param.type for param in copy.foast_stage.foast_node.definition.params), From 606f662adb0300f14dee59ecf49dcc614af4d712 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Wed, 16 Oct 2024 17:20:10 +0200 Subject: [PATCH 047/150] Cleanup --- src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir.py | 2 +- tests/next_tests/definitions.py | 4 +--- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir.py b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir.py index f4306bca1f..fa1362719f 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir.py @@ -173,7 +173,7 @@ class StencilExecution(Stmt): output: Union[SymRef, SidComposite] inputs: list[ Union[SymRef, SidComposite, SidFromScalar, FunCall] - ] # TODO: StencilExecution only for tuple_get + ] @datamodels.validator("inputs") def _arg_validator( diff --git a/tests/next_tests/definitions.py b/tests/next_tests/definitions.py index 946361b1d1..2c61ff085c 100644 --- a/tests/next_tests/definitions.py +++ b/tests/next_tests/definitions.py @@ -182,9 +182,7 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): ProgramBackendId.GTFN_GPU: GTFN_SKIP_TEST_LIST + [(USES_SCAN_NESTED, XFAIL, UNSUPPORTED_MESSAGE)], ProgramBackendId.GTFN_CPU_WITH_TEMPORARIES: GTFN_SKIP_TEST_LIST - + [ - (USES_DYNAMIC_OFFSETS, XFAIL, UNSUPPORTED_MESSAGE) - ], + + [(USES_DYNAMIC_OFFSETS, XFAIL, UNSUPPORTED_MESSAGE)], ProgramFormatterId.GTFN_CPP_FORMATTER: [ (USES_REDUCTION_WITH_ONLY_SPARSE_FIELDS, XFAIL, REDUCTION_WITH_ONLY_SPARSE_FIELDS_MESSAGE) ], From b917011a10cd1ca57a4e5c73dc913581fb582c96 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Wed, 16 Oct 2024 17:20:17 +0200 Subject: [PATCH 048/150] Cleanup --- src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir.py b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir.py index fa1362719f..571baf2d9f 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir.py @@ -171,9 +171,7 @@ class StencilExecution(Stmt): backend: Backend stencil: SymRef output: Union[SymRef, SidComposite] - inputs: list[ - Union[SymRef, SidComposite, SidFromScalar, FunCall] - ] + inputs: list[Union[SymRef, SidComposite, SidFromScalar, FunCall]] @datamodels.validator("inputs") def _arg_validator( From 6c9e8ab8479567c8129e242b9ae6effd6dfefa0b Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Wed, 16 Oct 2024 17:22:43 +0200 Subject: [PATCH 049/150] Cleanup --- .../codegens/gtfn/itir_to_gtfn_ir.py | 47 +++++++++---------- 1 file changed, 23 insertions(+), 24 deletions(-) diff --git a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py index 47cca740f9..68cc884429 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py @@ -45,6 +45,7 @@ ) from gt4py.next.program_processors.codegens.gtfn.gtfn_ir_common import Expr, Node, Sym, SymRef from gt4py.next.type_system import type_info, type_specifications as ts +from gt4py.next.iterator.ir_utils import ir_makers as im def pytype_to_cpptype(t: ts.ScalarType | str) -> Optional[str]: @@ -67,6 +68,27 @@ def pytype_to_cpptype(t: ts.ScalarType | str) -> 
Optional[str]: _horizontal_dimension = "gtfn::unstructured::dim::horizontal" +def _is_tuple_of_ref_or_literal(expr: itir.Expr) -> bool: + if ( + isinstance(expr, itir.FunCall) + and isinstance(expr.fun, itir.SymRef) + and expr.fun.id == "tuple_get" + and len(expr.args) == 2 + and _is_tuple_of_ref_or_literal(expr.args[1]) + ): + return True + if ( + isinstance(expr, itir.FunCall) + and isinstance(expr.fun, itir.SymRef) + and expr.fun.id == "make_tuple" + and all(_is_tuple_of_ref_or_literal(arg) for arg in expr.args) + ): + return True + if isinstance(expr, (itir.SymRef, itir.Literal)): + return True + return False + + def _get_domains(nodes: Iterable[itir.Stmt]) -> Iterable[itir.FunCall]: result = set() for node in nodes: @@ -587,30 +609,7 @@ def visit_IfStmt(self, node: itir.IfStmt, **kwargs: Any) -> IfStmt: def visit_SetAt( self, node: itir.SetAt, *, extracted_functions: list, **kwargs: Any ) -> Union[StencilExecution, ScanExecution]: - # TODO: symref, literal, tuple thereof is also fine, similar to broadcast fix in gtir lowering - def _is_ref_or_tuple_expr_of_ref(expr: itir.Expr) -> bool: - if ( - isinstance(expr, itir.FunCall) - and isinstance(expr.fun, itir.SymRef) - and expr.fun.id == "tuple_get" - and len(expr.args) == 2 - and _is_ref_or_tuple_expr_of_ref(expr.args[1]) - ): - return True - if ( - isinstance(expr, itir.FunCall) - and isinstance(expr.fun, itir.SymRef) - and expr.fun.id == "make_tuple" - and all(_is_ref_or_tuple_expr_of_ref(arg) for arg in expr.args) - ): - return True - if isinstance(expr, (itir.SymRef, itir.Literal)): - return True - return False - - from gt4py.next.iterator.ir_utils import ir_makers as im - - if _is_ref_or_tuple_expr_of_ref(node.expr): + if _is_tuple_of_ref_or_literal(node.expr): node.expr = im.as_fieldop("deref", node.domain)(node.expr) assert cpm.is_applied_as_fieldop(node.expr) From 320c7f8258e4cc3e26a765c3128f4078763f2953 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Thu, 17 Oct 2024 11:15:37 +0200 Subject: [PATCH 050/150] Cleanup --- .../next/iterator/transforms/global_tmps.py | 22 +++++++++---------- .../codegens/gtfn/itir_to_gtfn_ir.py | 3 +-- 2 files changed, 11 insertions(+), 14 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/global_tmps.py b/src/gt4py/next/iterator/transforms/global_tmps.py index 2444828895..90f8a6cded 100644 --- a/src/gt4py/next/iterator/transforms/global_tmps.py +++ b/src/gt4py/next/iterator/transforms/global_tmps.py @@ -27,10 +27,7 @@ def _transform_if( stmt: itir.Stmt, declarations: list[itir.Temporary], uids: eve_utils.UIDGenerator ) -> Optional[list[itir.Stmt]]: - if not isinstance(stmt, itir.SetAt): - return None - - if cpm.is_call_to(stmt.expr, "if_"): + if isinstance(stmt, itir.SetAt) and cpm.is_call_to(stmt.expr, "if_"): cond, true_val, false_val = stmt.expr.args return [ itir.IfStmt( @@ -51,7 +48,10 @@ def _transform_if( def _transform_by_pattern( - stmt: itir.Stmt, predicate, declarations: list[itir.Temporary], uids: eve_utils.UIDGenerator + stmt: itir.Stmt, + predicate: Callable[[itir.Expr, int], bool], + declarations: list[itir.Temporary], + uids: eve_utils.UIDGenerator, ) -> Optional[list[itir.Stmt]]: if not isinstance(stmt, itir.SetAt): return None @@ -76,7 +76,7 @@ def _transform_by_pattern( for tmp_sym, tmp_expr in extracted_fields.items(): domain = tmp_expr.annex.domain - # TODO(tehrengruber): Implement. This happens when the expression for a combination + # TODO(tehrengruber): Implement. 
This happens when the expression is a combination # of an `if_` call with a tuple, e.g., `if_(cond, {a, b}, {c, d})`. As long as we are # able to eliminate all tuples, e.g., by propagating the scalar ifs to the top-level # of a SetAt, the CollapseTuple pass will eliminate most of this cases. @@ -86,7 +86,7 @@ def _transform_by_pattern( ) if not all(d == flattened_domains[0] for d in flattened_domains): raise NotImplementedError( - "Tuple expressions with different domains is not " "supported yet." + "Tuple expressions with different domains is not supported yet." ) domain = flattened_domains[0] assert isinstance(domain, domain_utils.SymbolicDomain) @@ -107,12 +107,10 @@ def _transform_by_pattern( ) # allocate temporary for all tuple elements - def allocate_temporary(tmp_name: str, dtype: ts.ScalarType, domain: itir.Expr): - declarations.append(itir.Temporary(id=tmp_name, domain=domain, dtype=dtype)) + def allocate_temporary(tmp_name: str, dtype: ts.ScalarType): + declarations.append(itir.Temporary(id=tmp_name, domain=domain_expr, dtype=dtype)) # noqa: B023 # function only used inside loop - next_utils.tree_map(functools.partial(allocate_temporary, domain=domain_expr))( - tmp_names, tmp_dtypes - ) + next_utils.tree_map(allocate_temporary)(tmp_names, tmp_dtypes) # if the expr is a field this just gives a simple `itir.SymRef`, otherwise we generate a # `make_tuple` expression. diff --git a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py index 68cc884429..bc2bd645e8 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py @@ -15,7 +15,7 @@ from gt4py.eve.concepts import SymbolName from gt4py.next import common from gt4py.next.iterator import ir as itir -from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm +from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im from gt4py.next.iterator.type_system import inference as itir_type_inference from gt4py.next.program_processors.codegens.gtfn.gtfn_ir import ( Backend, @@ -45,7 +45,6 @@ ) from gt4py.next.program_processors.codegens.gtfn.gtfn_ir_common import Expr, Node, Sym, SymRef from gt4py.next.type_system import type_info, type_specifications as ts -from gt4py.next.iterator.ir_utils import ir_makers as im def pytype_to_cpptype(t: ts.ScalarType | str) -> Optional[str]: From b58934639602ead9558be826f8dad29577e6f983 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Thu, 17 Oct 2024 11:15:37 +0200 Subject: [PATCH 051/150] Cleanup --- .../next/iterator/transforms/global_tmps.py | 22 +++++++++---------- .../codegens/gtfn/itir_to_gtfn_ir.py | 2 +- 2 files changed, 11 insertions(+), 13 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/global_tmps.py b/src/gt4py/next/iterator/transforms/global_tmps.py index f80dc74834..11d3fccec1 100644 --- a/src/gt4py/next/iterator/transforms/global_tmps.py +++ b/src/gt4py/next/iterator/transforms/global_tmps.py @@ -27,10 +27,7 @@ def _transform_if( stmt: itir.Stmt, declarations: list[itir.Temporary], uids: eve_utils.UIDGenerator ) -> Optional[list[itir.Stmt]]: - if not isinstance(stmt, itir.SetAt): - return None - - if cpm.is_call_to(stmt.expr, "if_"): + if isinstance(stmt, itir.SetAt) and cpm.is_call_to(stmt.expr, "if_"): cond, true_val, false_val = stmt.expr.args return [ itir.IfStmt( @@ -51,7 +48,10 @@ def _transform_if( def _transform_by_pattern( - stmt: itir.Stmt, 
predicate, declarations: list[itir.Temporary], uids: eve_utils.UIDGenerator + stmt: itir.Stmt, + predicate: Callable[[itir.Expr, int], bool], + declarations: list[itir.Temporary], + uids: eve_utils.UIDGenerator, ) -> Optional[list[itir.Stmt]]: if not isinstance(stmt, itir.SetAt): return None @@ -76,7 +76,7 @@ def _transform_by_pattern( for tmp_sym, tmp_expr in extracted_fields.items(): domain = tmp_expr.annex.domain - # TODO(tehrengruber): Implement. This happens when the expression for a combination + # TODO(tehrengruber): Implement. This happens when the expression is a combination # of an `if_` call with a tuple, e.g., `if_(cond, {a, b}, {c, d})`. As long as we are # able to eliminate all tuples, e.g., by propagating the scalar ifs to the top-level # of a SetAt, the CollapseTuple pass will eliminate most of this cases. @@ -86,7 +86,7 @@ def _transform_by_pattern( ) if not all(d == flattened_domains[0] for d in flattened_domains): raise NotImplementedError( - "Tuple expressions with different domains is not " "supported yet." + "Tuple expressions with different domains is not supported yet." ) domain = flattened_domains[0] assert isinstance(domain, domain_utils.SymbolicDomain) @@ -107,12 +107,10 @@ def _transform_by_pattern( ) # allocate temporary for all tuple elements - def allocate_temporary(tmp_name: str, dtype: ts.ScalarType, domain: itir.Expr): - declarations.append(itir.Temporary(id=tmp_name, domain=domain, dtype=dtype)) + def allocate_temporary(tmp_name: str, dtype: ts.ScalarType): + declarations.append(itir.Temporary(id=tmp_name, domain=domain_expr, dtype=dtype)) # noqa: B023 # function only used inside loop - next_utils.tree_map(functools.partial(allocate_temporary, domain=domain_expr))( - tmp_names, tmp_dtypes - ) + next_utils.tree_map(allocate_temporary)(tmp_names, tmp_dtypes) # if the expr is a field this just gives a simple `itir.SymRef`, otherwise we generate a # `make_tuple` expression. 
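
As a rough illustration of the `_transform_if` helper touched in the `global_tmps.py` hunk above: it rewrites a `SetAt` whose expression is a scalar `if_` call into an `IfStmt` with one `SetAt` per branch, each writing the same target over the same domain, so that temporaries are extracted per branch. The following standalone sketch mirrors only that control flow; the classes here are simplified stand-ins, not the real `gt4py.next.iterator.ir` nodes, and the sketch is illustrative rather than part of the patch.

from dataclasses import dataclass
from typing import Optional

@dataclass
class FunCall:  # stand-in for an ITIR function call node
    fun: str
    args: list

@dataclass
class SetAt:  # stand-in for a `target @ domain <- expr;` statement
    expr: object
    domain: str
    target: str

@dataclass
class IfStmt:  # stand-in for an if-statement with statement lists per branch
    cond: object
    true_branch: list
    false_branch: list

def transform_if(stmt) -> Optional[list]:
    # Only rewrite `SetAt`s whose right-hand side is an `if_` call.
    if isinstance(stmt, SetAt) and isinstance(stmt.expr, FunCall) and stmt.expr.fun == "if_":
        cond, true_val, false_val = stmt.expr.args
        return [
            IfStmt(
                cond=cond,
                # each branch writes the same target over the same domain
                true_branch=[SetAt(expr=true_val, domain=stmt.domain, target=stmt.target)],
                false_branch=[SetAt(expr=false_val, domain=stmt.domain, target=stmt.target)],
            )
        ]
    return None  # not an `if_` set-at; the caller tries the next transformation

if __name__ == "__main__":
    stmt = SetAt(expr=FunCall("if_", ["cond", "a", "b"]), domain="dom", target="out")
    print(transform_if(stmt))
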
diff --git a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py index 3bd96d14d7..9e4088a349 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py @@ -15,7 +15,7 @@ from gt4py.eve.concepts import SymbolName from gt4py.next import common from gt4py.next.iterator import ir as itir -from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm +from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im from gt4py.next.iterator.type_system import inference as itir_type_inference from gt4py.next.program_processors.codegens.gtfn.gtfn_ir import ( Backend, From 4d2b3da5d3289210c3d21fec912c4a647a293d82 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Thu, 17 Oct 2024 14:07:38 +0200 Subject: [PATCH 052/150] Fix format --- .../next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py index 9e4088a349..3bd96d14d7 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py @@ -15,7 +15,7 @@ from gt4py.eve.concepts import SymbolName from gt4py.next import common from gt4py.next.iterator import ir as itir -from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im +from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm from gt4py.next.iterator.type_system import inference as itir_type_inference from gt4py.next.program_processors.codegens.gtfn.gtfn_ir import ( Backend, From 70c0dffa89b4256ea655d6f06ab5bf6112f044a9 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Sat, 19 Oct 2024 17:30:10 +0200 Subject: [PATCH 053/150] Fix scalar let in gtfn --- .../iterator/transforms/collapse_tuple.py | 5 +++ .../iterator/transforms/fuse_as_fieldop.py | 23 +++++++++----- .../next/iterator/transforms/inline_scalar.py | 31 +++++++++++++++++++ .../next/iterator/transforms/pass_manager.py | 2 ++ .../next/iterator/type_system/inference.py | 2 ++ .../iterator/type_system/type_synthesizer.py | 2 +- .../codegens/gtfn/itir_to_gtfn_ir.py | 3 +- .../ffront_tests/test_execution.py | 15 +++++++++ 8 files changed, 73 insertions(+), 10 deletions(-) create mode 100644 src/gt4py/next/iterator/transforms/inline_scalar.py diff --git a/src/gt4py/next/iterator/transforms/collapse_tuple.py b/src/gt4py/next/iterator/transforms/collapse_tuple.py index cfabf5b6d1..d5d85e2bc2 100644 --- a/src/gt4py/next/iterator/transforms/collapse_tuple.py +++ b/src/gt4py/next/iterator/transforms/collapse_tuple.py @@ -206,6 +206,11 @@ def transform_collapse_make_tuple_tuple_get( # tuple argument differs, just continue with the rest of the tree return None + # this occurs in case of a scan returning a tuple, e.g.: + # `tuple_get(0, as_fieldop(scan(...)(...))(...))` + if isinstance(first_expr.type, ts.DeferredType): + return None + assert self.ignore_tuple_size or isinstance(first_expr.type, ts.TupleType) if self.ignore_tuple_size or len(first_expr.type.types) == len(node.args): # type: ignore[union-attr] # ensured by assert above return first_expr diff --git a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py index 8928b406f3..da238733da 100644 
--- a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py +++ b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py @@ -13,7 +13,12 @@ from gt4py.eve import utils as eve_utils from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im -from gt4py.next.iterator.transforms import inline_lambdas, inline_lifts, trace_shifts +from gt4py.next.iterator.transforms import ( + inline_center_deref_lift_vars, + inline_lambdas, + inline_lifts, + trace_shifts, +) from gt4py.next.iterator.type_system import ( inference as type_inference, type_specifications as it_ts, @@ -202,15 +207,19 @@ def visit_FunCall(self, node: itir.FunCall): new_param = stencil_param.id new_args = _merge_arguments(new_args, {new_param: arg}) - # simplify stencil directly to keep the tree small - new_stencil_body = inline_lambdas.InlineLambdas.apply( - new_stencil_body, opcount_preserving=True - ) - new_stencil_body = inline_lifts.InlineLifts().visit(new_stencil_body) - new_node = im.as_fieldop(im.lambda_(*new_args.keys())(new_stencil_body), domain)( *new_args.values() ) + + # simplify stencil directly to keep the tree small + new_node = inline_center_deref_lift_vars.InlineCenterDerefLiftVars.apply( + new_node + ) # to keep the tree small + new_node = inline_lambdas.InlineLambdas.apply( + new_node, opcount_preserving=True, force_inline_lift_args=True + ) + new_node = inline_lifts.InlineLifts().visit(new_node) + type_inference.copy_type(from_=node, to=new_node) return new_node diff --git a/src/gt4py/next/iterator/transforms/inline_scalar.py b/src/gt4py/next/iterator/transforms/inline_scalar.py new file mode 100644 index 0000000000..c6e2c38b90 --- /dev/null +++ b/src/gt4py/next/iterator/transforms/inline_scalar.py @@ -0,0 +1,31 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. 
+# SPDX-License-Identifier: BSD-3-Clause + +from gt4py import eve +from gt4py.next import common +from gt4py.next.iterator import ir as itir +from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm +from gt4py.next.iterator.transforms import inline_lambdas +from gt4py.next.iterator.type_system import inference as itir_inference +from gt4py.next.type_system import type_specifications as ts + + +class InlineScalar(eve.NodeTranslator): + @classmethod + def apply(cls, program: itir.Program, offset_provider: common.OffsetProvider): + program = itir_inference.infer(program, offset_provider=offset_provider) + return cls().visit(program) + + def visit_Expr(self, node: itir.Expr): + node = self.generic_visit(node) + + if cpm.is_let(node): + eligible_params = [isinstance(arg.type, ts.ScalarType) for arg in node.args] + node = inline_lambdas.inline_lambda(node, eligible_params=eligible_params) + return node + return node diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index 1faffecc0b..32560683f8 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -24,6 +24,7 @@ from gt4py.next.iterator.transforms.cse import CommonSubexpressionElimination from gt4py.next.iterator.transforms.fuse_maps import FuseMaps from gt4py.next.iterator.transforms.inline_lambdas import InlineLambdas +from gt4py.next.iterator.transforms.inline_scalar import InlineScalar from gt4py.next.iterator.transforms.merge_let import MergeLet from gt4py.next.iterator.transforms.normalize_shifts import NormalizeShifts from gt4py.next.iterator.transforms.unroll_reduce import UnrollReduce @@ -81,6 +82,7 @@ def apply_common_transforms( # This pass is required to be in the loop such that when an `if_` call with tuple arguments # is constant-folded the surrounding tuple_get calls can be removed. 
inlined = CollapseTuple.apply(inlined, offset_provider=offset_provider) # type: ignore[assignment] # always an itir.Program + inlined = InlineScalar.apply(inlined, offset_provider=offset_provider) # This pass is required to run after CollapseTuple as otherwise we can not inline # expressions like `tuple_get(make_tuple(as_fieldop(stencil)(...)))` where stencil returns diff --git a/src/gt4py/next/iterator/type_system/inference.py b/src/gt4py/next/iterator/type_system/inference.py index dff81ce40f..87fdefc4c1 100644 --- a/src/gt4py/next/iterator/type_system/inference.py +++ b/src/gt4py/next/iterator/type_system/inference.py @@ -624,6 +624,8 @@ def visit_FunCall( self.visit(tuple_, ctx=ctx) # ensure tuple is typed assert isinstance(index_literal, itir.Literal) index = int(index_literal.value) + if isinstance(tuple_.type, ts.DeferredType): + return ts.DeferredType(constraint=None) assert isinstance(tuple_.type, ts.TupleType) return tuple_.type.types[index] diff --git a/src/gt4py/next/iterator/type_system/type_synthesizer.py b/src/gt4py/next/iterator/type_system/type_synthesizer.py index f30dfc0fcf..c55cfd8d51 100644 --- a/src/gt4py/next/iterator/type_system/type_synthesizer.py +++ b/src/gt4py/next/iterator/type_system/type_synthesizer.py @@ -292,7 +292,7 @@ def as_fieldop( @TypeSynthesizer def applied_as_fieldop(*fields) -> ts.FieldType | ts.DeferredType: if any(isinstance(f, ts.DeferredType) for f in fields): - return ts.DeferredType(constraint=ts.FieldType) + return ts.DeferredType(constraint=None) stencil_return = stencil( *(_convert_as_fieldop_input_to_iterator(domain, field) for field in fields), diff --git a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py index 7c7265ff32..bc2bd645e8 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py @@ -15,8 +15,7 @@ from gt4py.eve.concepts import SymbolName from gt4py.next import common from gt4py.next.iterator import ir as itir -from gt4py.next.iterator.ir_utils import ir_makers as im -from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm +from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im from gt4py.next.iterator.type_system import inference as itir_type_inference from gt4py.next.program_processors.codegens.gtfn.gtfn_ir import ( Backend, diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index c0b2f97db7..efcbc33c55 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -297,6 +297,21 @@ def testee(a: cases.IJKField, b: int32) -> cases.IJKField: cases.verify(cartesian_case, testee, a, b, out=out, ref=ref) +@pytest.mark.uses_tuple_args +def test_double_use_scalar(cartesian_case): + # TODO(tehrengruber): This should be a regression test on ITIR level, but tracing doesn't + # work for this case. 
+ @gtx.field_operator + def testee(a: np.int32, b: np.int32, c: cases.IField) -> cases.IField: + tmp = a * b + tmp2 = tmp * tmp + # important part here is that we use the intermediate twice so that it is + # not inlined + return tmp2 * tmp2 * c + + cases.verify_with_default_data(cartesian_case, testee, ref=lambda a, b, c: a * b * a * b * c) + + @pytest.mark.uses_scalar_in_domain_and_fo def test_scalar_in_domain_spec_and_fo_call(cartesian_case): @gtx.field_operator From 108af05534b1e619ea6b419a3e29e7b542719222 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Sun, 20 Oct 2024 17:33:44 +0200 Subject: [PATCH 054/150] Fix some tests --- src/gt4py/next/iterator/type_system/inference.py | 4 ++-- .../runners/dace_iterator/workflow.py | 6 +++++- .../feature_tests/iterator_tests/test_scan.py | 2 ++ .../ffront_tests/test_icon_like_scan.py | 13 +------------ 4 files changed, 10 insertions(+), 15 deletions(-) diff --git a/src/gt4py/next/iterator/type_system/inference.py b/src/gt4py/next/iterator/type_system/inference.py index 87fdefc4c1..623134b619 100644 --- a/src/gt4py/next/iterator/type_system/inference.py +++ b/src/gt4py/next/iterator/type_system/inference.py @@ -520,8 +520,8 @@ def visit_SetAt(self, node: itir.SetAt, *, ctx) -> None: # complicated comparison. if isinstance(target_type, ts.FieldType) and isinstance(expr_type, ts.FieldType): assert ( - set(expr_type.dims) == set(target_type.dims) - and target_type.dtype == expr_type.dtype + set(expr_type.dims).issubset(set(target_type.dims)) + and target_type.dtype == expr_type.dtype ) # TODO(tehrengruber): Remove after new ITIR format with apply_stencil is used everywhere diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/workflow.py b/src/gt4py/next/program_processors/runners/dace_iterator/workflow.py index aa2e94ee68..19f4ca92dd 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/workflow.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/workflow.py @@ -18,6 +18,7 @@ from gt4py._core import definitions as core_defs from gt4py.next import common, config from gt4py.next.iterator import ir as itir +from gt4py.next.iterator.transforms.fencil_to_program import FencilToProgram from gt4py.next.otf import languages, recipes, stages, step_types, workflow from gt4py.next.otf.binding import interface from gt4py.next.otf.languages import LanguageSettings @@ -79,7 +80,10 @@ def __call__( ) -> stages.ProgramSource[languages.SDFG, LanguageSettings]: """Generate DaCe SDFG file from the ITIR definition.""" program: itir.FencilDefinition | itir.Program = inp.data - assert isinstance(program, itir.FencilDefinition) + + # FIXME[#1582](tehrengruber): Remove. This code-path is only used by the dace_itir backend. + if isinstance(program, itir.FencilDefinition): + program = FencilToProgram.apply(program) sdfg = self.generate_sdfg( program, diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_scan.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_scan.py index a86959d075..efa2e3f5b3 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_scan.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_scan.py @@ -19,6 +19,8 @@ @pytest.mark.uses_index_fields def test_scan_in_stencil(program_processor): + # FIXME[#1582](tehrengruber): Remove test after scan is reworked. 
+ pytest.skip("Scan inside of stencil is not supported in GTIR.") program_processor, validate = program_processor isize = 1 diff --git a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py index 505879a506..a37d0962e7 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py @@ -229,7 +229,7 @@ def test_solve_nonhydro_stencil_52_like_z_q(test_setup): def test_solve_nonhydro_stencil_52_like_z_q_tup(test_setup): if ( test_setup.case.backend - == test_definitions.ProgramBackendId.GTFN_CPU_WITH_TEMPORARIES.load() + == test_definitions.ProgramBackendId.GTFN_CPU_WITH_TEMPORARIES.load() ): pytest.xfail( "Needs implementation of scan projector. Breaks in type inference as executed" @@ -254,12 +254,6 @@ def test_solve_nonhydro_stencil_52_like_z_q_tup(test_setup): @pytest.mark.uses_tuple_returns def test_solve_nonhydro_stencil_52_like(test_setup): - if ( - test_setup.case.backend - == test_definitions.ProgramBackendId.GTFN_CPU_WITH_TEMPORARIES.load() - ): - pytest.xfail("Temporary extraction does not work correctly in combination with scans.") - cases.run( test_setup.case, solve_nonhydro_stencil_52_like, @@ -276,11 +270,6 @@ def test_solve_nonhydro_stencil_52_like(test_setup): @pytest.mark.uses_tuple_returns def test_solve_nonhydro_stencil_52_like_with_gtfn_tuple_merge(test_setup): - if ( - test_setup.case.backend - == test_definitions.ProgramBackendId.GTFN_CPU_WITH_TEMPORARIES.load() - ): - pytest.xfail("Temporary extraction does not work correctly in combination with scans.") if test_setup.case.backend == test_definitions.ProgramBackendId.ROUNDTRIP.load(): pytest.xfail("Needs proper handling of tuple[Column] <-> Column[tuple].") From 7380d6ecd2a7ac5d1a2c7149f6f60fafc5734e15 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Sun, 20 Oct 2024 17:35:54 +0200 Subject: [PATCH 055/150] Fix --- src/gt4py/next/iterator/type_system/inference.py | 4 ++-- .../program_processors/runners/dace_iterator/workflow.py | 6 +++--- .../multi_feature_tests/ffront_tests/test_icon_like_scan.py | 2 +- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/gt4py/next/iterator/type_system/inference.py b/src/gt4py/next/iterator/type_system/inference.py index 623134b619..edcb9b540c 100644 --- a/src/gt4py/next/iterator/type_system/inference.py +++ b/src/gt4py/next/iterator/type_system/inference.py @@ -520,8 +520,8 @@ def visit_SetAt(self, node: itir.SetAt, *, ctx) -> None: # complicated comparison. 
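A small illustration of why the assertion below compares dimension sets rather than ordered tuples, using the `K`/`Vertex` example from the comment above (plain Python values, not gt4py types):

# Same dimensions in a different order are still compatible, and the expression
# may be defined on a subset of the target's dimensions.
target_dims = ("Vertex", "K")
expr_dims = ("K", "Vertex")
assert set(expr_dims).issubset(set(target_dims))

expr_dims_fewer = ("Vertex",)
assert set(expr_dims_fewer).issubset(set(target_dims))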
if isinstance(target_type, ts.FieldType) and isinstance(expr_type, ts.FieldType): assert ( - set(expr_type.dims).issubset(set(target_type.dims)) - and target_type.dtype == expr_type.dtype + set(expr_type.dims).issubset(set(target_type.dims)) + and target_type.dtype == expr_type.dtype ) # TODO(tehrengruber): Remove after new ITIR format with apply_stencil is used everywhere diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/workflow.py b/src/gt4py/next/program_processors/runners/dace_iterator/workflow.py index 19f4ca92dd..6e27af5f95 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/workflow.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/workflow.py @@ -18,7 +18,7 @@ from gt4py._core import definitions as core_defs from gt4py.next import common, config from gt4py.next.iterator import ir as itir -from gt4py.next.iterator.transforms.fencil_to_program import FencilToProgram +from gt4py.next.iterator.transforms import program_to_fencil from gt4py.next.otf import languages, recipes, stages, step_types, workflow from gt4py.next.otf.binding import interface from gt4py.next.otf.languages import LanguageSettings @@ -82,8 +82,8 @@ def __call__( program: itir.FencilDefinition | itir.Program = inp.data # FIXME[#1582](tehrengruber): Remove. This code-path is only used by the dace_itir backend. - if isinstance(program, itir.FencilDefinition): - program = FencilToProgram.apply(program) + if isinstance(program, itir.Program): + program = program_to_fencil.program_to_fencil(program) sdfg = self.generate_sdfg( program, diff --git a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py index a37d0962e7..b4079a0080 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py @@ -229,7 +229,7 @@ def test_solve_nonhydro_stencil_52_like_z_q(test_setup): def test_solve_nonhydro_stencil_52_like_z_q_tup(test_setup): if ( test_setup.case.backend - == test_definitions.ProgramBackendId.GTFN_CPU_WITH_TEMPORARIES.load() + == test_definitions.ProgramBackendId.GTFN_CPU_WITH_TEMPORARIES.load() ): pytest.xfail( "Needs implementation of scan projector. 
Breaks in type inference as executed" From c268ee19d32bdb27fadcfd49b2aef2fbd97852f7 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Mon, 21 Oct 2024 00:00:35 +0200 Subject: [PATCH 056/150] Fix some tests --- .../next/iterator/ir_utils/domain_utils.py | 2 +- src/gt4py/next/iterator/transforms/cse.py | 2 +- .../next/iterator/transforms/infer_domain.py | 50 ++++++++++++------- .../test_with_toy_connectivity.py | 1 - .../transforms_tests/test_collapse_tuple.py | 19 ++++++- 5 files changed, 51 insertions(+), 23 deletions(-) diff --git a/src/gt4py/next/iterator/ir_utils/domain_utils.py b/src/gt4py/next/iterator/ir_utils/domain_utils.py index 9e8729b339..e25508e279 100644 --- a/src/gt4py/next/iterator/ir_utils/domain_utils.py +++ b/src/gt4py/next/iterator/ir_utils/domain_utils.py @@ -93,7 +93,7 @@ def translate( ..., ], offset_provider: common.OffsetProvider, - symbolic_domain_sizes: Optional[dict[str, str]], + symbolic_domain_sizes: Optional[dict[str, str]] = None, ) -> SymbolicDomain: dims = list(self.ranges.keys()) new_ranges = {dim: self.ranges[dim] for dim in dims} diff --git a/src/gt4py/next/iterator/transforms/cse.py b/src/gt4py/next/iterator/transforms/cse.py index aefff0c8bb..ecdde572dc 100644 --- a/src/gt4py/next/iterator/transforms/cse.py +++ b/src/gt4py/next/iterator/transforms/cse.py @@ -42,7 +42,7 @@ def _is_trivial_tuple_expr(node: itir.Expr): or _is_trivial_tuple_expr(node.args[1]) ): return True - return True + return False @dataclasses.dataclass diff --git a/src/gt4py/next/iterator/transforms/infer_domain.py b/src/gt4py/next/iterator/transforms/infer_domain.py index f15b54b89e..d897bd4ec8 100644 --- a/src/gt4py/next/iterator/transforms/infer_domain.py +++ b/src/gt4py/next/iterator/transforms/infer_domain.py @@ -357,37 +357,49 @@ def infer_expr( return expr, accessed_domains +def _infer_stmt( + stmt: itir.Stmt, + offset_provider: common.OffsetProvider, + symbolic_domain_sizes: Optional[dict[str, str]], +): + if isinstance(stmt, itir.SetAt): + transformed_call, _unused_domain = infer_expr( + stmt.expr, + domain_utils.SymbolicDomain.from_expr(stmt.domain), + offset_provider, + symbolic_domain_sizes, + ) + return itir.SetAt( + expr=transformed_call, + domain=stmt.domain, + target=stmt.target, + ) + elif isinstance(stmt, itir.IfStmt): + return itir.IfStmt( + cond=stmt.cond, + true_branch=[ + _infer_stmt(c, offset_provider, symbolic_domain_sizes) for c in stmt.true_branch + ], + false_branch=[ + _infer_stmt(c, offset_provider, symbolic_domain_sizes) for c in stmt.false_branch + ], + ) + raise ValueError(f"Unsupported stmt: {stmt}") + + def infer_program( program: itir.Program, offset_provider: common.OffsetProvider, symbolic_domain_sizes: Optional[dict[str, str]] = None, ) -> itir.Program: - transformed_set_ats: list[itir.SetAt] = [] assert ( not program.function_definitions ), "Domain propagation does not support function definitions." 
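To make the recursion in `_infer_stmt` above easier to follow, here is a self-contained toy version in plain Python (the dict-based statement representation is invented for this sketch): leaf statements are transformed directly, while an if-statement recurses into both branches, so statements nested inside `IfStmt` bodies are reached as well.

def walk_stmts(stmt, transform_leaf):
    # A dict with a "cond" key stands in for an if-statement; anything else stands
    # in for a leaf statement such as a set-at.
    if isinstance(stmt, dict) and "cond" in stmt:
        return {
            "cond": stmt["cond"],
            "true_branch": [walk_stmts(s, transform_leaf) for s in stmt["true_branch"]],
            "false_branch": [walk_stmts(s, transform_leaf) for s in stmt["false_branch"]],
        }
    return transform_leaf(stmt)

body = ["set_at_a", {"cond": "c", "true_branch": ["set_at_b"], "false_branch": ["set_at_c"]}]
print([walk_stmts(s, str.upper) for s in body])
# -> ['SET_AT_A', {'cond': 'c', 'true_branch': ['SET_AT_B'], 'false_branch': ['SET_AT_C']}]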
- for set_at in program.body: - assert isinstance(set_at, itir.SetAt) - - transformed_call, _unused_domain = infer_expr( - set_at.expr, - domain_utils.SymbolicDomain.from_expr(set_at.domain), - offset_provider, - symbolic_domain_sizes, - ) - transformed_set_ats.append( - itir.SetAt( - expr=transformed_call, - domain=set_at.domain, - target=set_at.target, - ), - ) - return itir.Program( id=program.id, function_definitions=program.function_definitions, params=program.params, declarations=program.declarations, - body=transformed_set_ats, + body=[_infer_stmt(stmt, offset_provider, symbolic_domain_sizes) for stmt in program.body], ) diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_with_toy_connectivity.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_with_toy_connectivity.py index 6fb1d4c152..6fdc6a77a1 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_with_toy_connectivity.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_with_toy_connectivity.py @@ -383,7 +383,6 @@ def test_shift_sparse_input_field2(program_processor): if program_processor in [ gtfn.run_gtfn, gtfn.run_gtfn_imperative, - gtfn.run_gtfn_with_temporaries, ]: pytest.xfail( "Bug in bindings/compilation/caching: only the first program seems to be compiled." diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py index bcf8b726be..2fe39cb4c9 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py @@ -19,6 +19,7 @@ def test_simple_make_tuple_tuple_get(): remove_letified_make_tuple_elements=False, flags=CollapseTuple.Flag.COLLAPSE_MAKE_TUPLE_TUPLE_GET, allow_undeclared_symbols=True, + is_local_view=False, ) expected = tuple_of_size_2 @@ -36,6 +37,7 @@ def test_nested_make_tuple_tuple_get(): remove_letified_make_tuple_elements=False, flags=CollapseTuple.Flag.COLLAPSE_MAKE_TUPLE_TUPLE_GET, allow_undeclared_symbols=True, + is_local_view=False, ) assert actual == tup_of_size2_from_lambda @@ -51,6 +53,7 @@ def test_different_tuples_make_tuple_tuple_get(): remove_letified_make_tuple_elements=False, flags=CollapseTuple.Flag.COLLAPSE_MAKE_TUPLE_TUPLE_GET, allow_undeclared_symbols=True, + is_local_view=False, ) assert actual == testee # did nothing @@ -64,6 +67,7 @@ def test_incompatible_order_make_tuple_tuple_get(): remove_letified_make_tuple_elements=False, flags=CollapseTuple.Flag.COLLAPSE_MAKE_TUPLE_TUPLE_GET, allow_undeclared_symbols=True, + is_local_view=False, ) assert actual == testee # did nothing @@ -75,6 +79,7 @@ def test_incompatible_size_make_tuple_tuple_get(): remove_letified_make_tuple_elements=False, flags=CollapseTuple.Flag.COLLAPSE_MAKE_TUPLE_TUPLE_GET, allow_undeclared_symbols=True, + is_local_view=False, ) assert actual == testee # did nothing @@ -86,6 +91,7 @@ def test_merged_with_smaller_outer_size_make_tuple_tuple_get(): ignore_tuple_size=True, flags=CollapseTuple.Flag.COLLAPSE_MAKE_TUPLE_TUPLE_GET, allow_undeclared_symbols=True, + is_local_view=False, ) assert actual == im.make_tuple("first", "second") @@ -98,6 +104,7 @@ def test_simple_tuple_get_make_tuple(): remove_letified_make_tuple_elements=False, flags=CollapseTuple.Flag.COLLAPSE_TUPLE_GET_MAKE_TUPLE, allow_undeclared_symbols=True, + is_local_view=False, ) assert expected == 
actual @@ -110,6 +117,7 @@ def test_propagate_tuple_get(): remove_letified_make_tuple_elements=False, flags=CollapseTuple.Flag.PROPAGATE_TUPLE_GET, allow_undeclared_symbols=True, + is_local_view=False, ) assert expected == actual @@ -127,6 +135,7 @@ def test_letify_make_tuple_elements(): remove_letified_make_tuple_elements=False, flags=CollapseTuple.Flag.LETIFY_MAKE_TUPLE_ELEMENTS, allow_undeclared_symbols=True, + is_local_view=False, ) assert actual == expected @@ -140,6 +149,7 @@ def test_letify_make_tuple_with_trivial_elements(): remove_letified_make_tuple_elements=False, flags=CollapseTuple.Flag.LETIFY_MAKE_TUPLE_ELEMENTS, allow_undeclared_symbols=True, + is_local_view=False, ) assert actual == expected @@ -153,6 +163,7 @@ def test_inline_trivial_make_tuple(): remove_letified_make_tuple_elements=False, flags=CollapseTuple.Flag.INLINE_TRIVIAL_MAKE_TUPLE, allow_undeclared_symbols=True, + is_local_view=False, ) assert actual == expected @@ -171,6 +182,7 @@ def test_propagate_to_if_on_tuples(): remove_letified_make_tuple_elements=False, flags=CollapseTuple.Flag.PROPAGATE_TO_IF_ON_TUPLES, allow_undeclared_symbols=True, + is_local_view=False, ) assert actual == expected @@ -188,6 +200,7 @@ def test_propagate_to_if_on_tuples_with_let(): flags=CollapseTuple.Flag.PROPAGATE_TO_IF_ON_TUPLES | CollapseTuple.Flag.LETIFY_MAKE_TUPLE_ELEMENTS, allow_undeclared_symbols=True, + is_local_view=False, ) assert actual == expected @@ -200,6 +213,7 @@ def test_propagate_nested_lift(): remove_letified_make_tuple_elements=False, flags=CollapseTuple.Flag.PROPAGATE_NESTED_LET, allow_undeclared_symbols=True, + is_local_view=False, ) assert actual == expected @@ -210,6 +224,9 @@ def test_if_on_tuples_with_let(): )(im.tuple_get(0, "val")) expected = im.if_("pred", 1, 3) actual = CollapseTuple.apply( - testee, remove_letified_make_tuple_elements=False, allow_undeclared_symbols=True + testee, + remove_letified_make_tuple_elements=False, + allow_undeclared_symbols=True, + is_local_view=False, ) assert actual == expected From c5b0171733cfbe8732f02fbc4add9dec6bdf2a27 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Mon, 21 Oct 2024 23:56:44 +0200 Subject: [PATCH 057/150] Address review comments --- .../iterator/transforms/collapse_list_get.py | 43 +++++++------------ .../iterator/transforms/collapse_tuple.py | 7 +-- 2 files changed, 20 insertions(+), 30 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/collapse_list_get.py b/src/gt4py/next/iterator/transforms/collapse_list_get.py index 0795cf5739..4a354879ca 100644 --- a/src/gt4py/next/iterator/transforms/collapse_list_get.py +++ b/src/gt4py/next/iterator/transforms/collapse_list_get.py @@ -7,7 +7,7 @@ # SPDX-License-Identifier: BSD-3-Clause from gt4py import eve -from gt4py.next.iterator import ir +from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im @@ -19,9 +19,9 @@ class CollapseListGet(eve.PreserveLocationVisitor, eve.NodeTranslator): - `list_get(i, make_const_list(e))` -> `e` """ - def visit_FunCall(self, node: ir.FunCall, **kwargs) -> ir.Node: + def visit_FunCall(self, node: itir.FunCall, **kwargs) -> itir.Node: node = self.generic_visit(node) - if node.fun == ir.SymRef(id="list_get"): + if cpm.is_call_to(node, "list_get"): if cpm.is_call_to(node.args[1], "if_"): list_idx = node.args[0] cond, true_val, false_val = node.args[1].args @@ -30,29 +30,18 @@ def visit_FunCall(self, node: ir.FunCall, **kwargs) -> ir.Node: self.visit(im.call("list_get")(list_idx, true_val)), 
self.visit(im.call("list_get")(list_idx, false_val)), ) - if isinstance(node.args[1], ir.FunCall): - if node.args[1].fun == ir.SymRef(id="neighbors"): - offset_tag = node.args[1].args[0] - offset_index = ( - ir.OffsetLiteral(value=int(node.args[0].value)) - if isinstance(node.args[0], ir.Literal) - else node.args[ - 0 - ] # else-branch: e.g. SymRef from unroll_reduce, TODO(havogt): remove when we replace unroll_reduce by list support in gtfn - ) - it = node.args[1].args[1] - return ir.FunCall( - fun=ir.SymRef(id="deref"), - args=[ - ir.FunCall( - fun=ir.FunCall( - fun=ir.SymRef(id="shift"), args=[offset_tag, offset_index] - ), - args=[it], - ) - ], - ) - if node.args[1].fun == ir.SymRef(id="make_const_list"): - return node.args[1].args[0] + if cpm.is_call_to(node.args[1], "neighbors"): + offset_tag = node.args[1].args[0] + offset_index = ( + itir.OffsetLiteral(value=int(node.args[0].value)) + if isinstance(node.args[0], itir.Literal) + else node.args[ + 0 + ] # else-branch: e.g. SymRef from unroll_reduce, TODO(havogt): remove when we replace unroll_reduce by list support in gtfn + ) + it = node.args[1].args[1] + return im.deref(im.shift(offset_tag, offset_index)(it)) + if cpm.is_call_to(node.args[1], "make_const_list"): + return node.args[1].args[0] return node diff --git a/src/gt4py/next/iterator/transforms/collapse_tuple.py b/src/gt4py/next/iterator/transforms/collapse_tuple.py index d5d85e2bc2..a4a46a6bc0 100644 --- a/src/gt4py/next/iterator/transforms/collapse_tuple.py +++ b/src/gt4py/next/iterator/transforms/collapse_tuple.py @@ -16,6 +16,7 @@ from gt4py import eve from gt4py.eve import utils as eve_utils +from gt4py.next import common from gt4py.next.iterator import ir from gt4py.next.iterator.ir_utils import ( common_pattern_matcher as cpm, @@ -104,10 +105,10 @@ def apply( *, ignore_tuple_size: bool = False, remove_letified_make_tuple_elements: bool = True, - offset_provider=None, - is_local_view=None, + offset_provider: Optional[common.OffsetProvider] = None, + is_local_view: Optional[bool] = None, # manually passing flags is mostly for allowing separate testing of the modes - flags=None, + flags: Optional[Flag] = None, # allow sym references without a symbol declaration, mostly for testing allow_undeclared_symbols: bool = False, ) -> ir.Node: From bba6aa4f17b8a3274c73e94a91f96a060c536bd7 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Thu, 24 Oct 2024 10:52:18 +0200 Subject: [PATCH 058/150] Allow partial type inference in ITIR --- .../next/iterator/type_system/inference.py | 19 +++++++++++-------- .../iterator/type_system/type_synthesizer.py | 3 +++ 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/src/gt4py/next/iterator/type_system/inference.py b/src/gt4py/next/iterator/type_system/inference.py index 4640aa11d1..edcb9b540c 100644 --- a/src/gt4py/next/iterator/type_system/inference.py +++ b/src/gt4py/next/iterator/type_system/inference.py @@ -84,7 +84,7 @@ def _is_compatible_type(type_a: ts.TypeSpec, type_b: ts.TypeSpec): is_compatible &= _is_compatible_type(arg_a, arg_b) is_compatible &= _is_compatible_type(type_a.returns, type_b.returns) else: - is_compatible &= type_a == type_b + is_compatible &= type_info.is_concretizable(type_a, type_b) return is_compatible @@ -435,7 +435,7 @@ def visit(self, node: concepts.RootNode, **kwargs: Any) -> Any: result = super().visit(node, **kwargs) if isinstance(node, itir.Node): if isinstance(result, ts.TypeSpec): - if node.type: + if node.type and not isinstance(node.type, ts.DeferredType): assert 
_is_compatible_type(node.type, result) node.type = result elif isinstance(result, ObservableTypeSynthesizer) or result is None: @@ -511,17 +511,18 @@ def visit_SetAt(self, node: itir.SetAt, *, ctx) -> None: path, node.expr.type, ) - assert isinstance(target_type, ts.FieldType) - assert isinstance(expr_type, ts.FieldType) + assert isinstance(target_type, (ts.FieldType, ts.DeferredType)) + assert isinstance(expr_type, (ts.FieldType, ts.DeferredType)) # TODO(tehrengruber): The lowering emits domains that always have the horizontal domain # first. Since the expr inherits the ordering from the domain this can lead to a mismatch # between the target and expr (e.g. when the target has dimension K, Vertex). We should # probably just change the behaviour of the lowering. Until then we do this more # complicated comparison. - assert ( - set(expr_type.dims) == set(target_type.dims) - and target_type.dtype == expr_type.dtype - ) + if isinstance(target_type, ts.FieldType) and isinstance(expr_type, ts.FieldType): + assert ( + set(expr_type.dims).issubset(set(target_type.dims)) + and target_type.dtype == expr_type.dtype + ) # TODO(tehrengruber): Remove after new ITIR format with apply_stencil is used everywhere def visit_StencilClosure(self, node: itir.StencilClosure, *, ctx) -> it_ts.StencilClosureType: @@ -623,6 +624,8 @@ def visit_FunCall( self.visit(tuple_, ctx=ctx) # ensure tuple is typed assert isinstance(index_literal, itir.Literal) index = int(index_literal.value) + if isinstance(tuple_.type, ts.DeferredType): + return ts.DeferredType(constraint=None) assert isinstance(tuple_.type, ts.TupleType) return tuple_.type.types[index] diff --git a/src/gt4py/next/iterator/type_system/type_synthesizer.py b/src/gt4py/next/iterator/type_system/type_synthesizer.py index c836de1391..c55cfd8d51 100644 --- a/src/gt4py/next/iterator/type_system/type_synthesizer.py +++ b/src/gt4py/next/iterator/type_system/type_synthesizer.py @@ -291,6 +291,9 @@ def as_fieldop( @TypeSynthesizer def applied_as_fieldop(*fields) -> ts.FieldType | ts.DeferredType: + if any(isinstance(f, ts.DeferredType) for f in fields): + return ts.DeferredType(constraint=None) + stencil_return = stencil( *(_convert_as_fieldop_input_to_iterator(domain, field) for field in fields), offset_provider=offset_provider, From f9fc5c54de7a061bc6986cf66a552016a90ed65f Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Thu, 24 Oct 2024 14:19:24 +0200 Subject: [PATCH 059/150] Small fix --- src/gt4py/next/iterator/transforms/collapse_tuple.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/collapse_tuple.py b/src/gt4py/next/iterator/transforms/collapse_tuple.py index 40d98208dd..0db7388a58 100644 --- a/src/gt4py/next/iterator/transforms/collapse_tuple.py +++ b/src/gt4py/next/iterator/transforms/collapse_tuple.py @@ -192,8 +192,13 @@ def transform_collapse_make_tuple_tuple_get(self, node: ir.FunCall) -> Optional[ # tuple argument differs, just continue with the rest of the tree return None - assert self.ignore_tuple_size or isinstance(first_expr.type, ts.TupleType) - if self.ignore_tuple_size or len(first_expr.type.types) == len(node.args): # type: ignore[union-attr] # ensured by assert above + assert self.ignore_tuple_size or isinstance( + first_expr.type, (ts.TupleType, ts.DeferredType) + ) + if self.ignore_tuple_size or ( + isinstance(first_expr.type, ts.TupleType) + and len(first_expr.type.types) == len(node.args) + ): # type: ignore[union-attr] # ensured by assert above return first_expr return 
None From c241bc4a98f787f3fcdfd95437ef2f88d3431359 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Mon, 28 Oct 2024 11:47:46 +0100 Subject: [PATCH 060/150] Add tests --- .../iterator/transforms/collapse_tuple.py | 2 +- .../iterator_tests/test_type_inference.py | 37 +++++++++++++++++++ .../transforms_tests/test_collapse_tuple.py | 9 +++++ 3 files changed, 47 insertions(+), 1 deletion(-) diff --git a/src/gt4py/next/iterator/transforms/collapse_tuple.py b/src/gt4py/next/iterator/transforms/collapse_tuple.py index 0db7388a58..b61fb2ba87 100644 --- a/src/gt4py/next/iterator/transforms/collapse_tuple.py +++ b/src/gt4py/next/iterator/transforms/collapse_tuple.py @@ -198,7 +198,7 @@ def transform_collapse_make_tuple_tuple_get(self, node: ir.FunCall) -> Optional[ if self.ignore_tuple_size or ( isinstance(first_expr.type, ts.TupleType) and len(first_expr.type.types) == len(node.args) - ): # type: ignore[union-attr] # ensured by assert above + ): return first_expr return None diff --git a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py index 20a1d7e9b7..7b6214fb1b 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py @@ -103,6 +103,10 @@ def expression_test_cases(): # tuple_get (im.tuple_get(0, im.make_tuple(im.ref("a", int_type), im.ref("b", bool_type))), int_type), (im.tuple_get(1, im.make_tuple(im.ref("a", int_type), im.ref("b", bool_type))), bool_type), + ( + im.tuple_get(0, im.ref("t", ts.DeferredType(constraint=None))), + ts.DeferredType(constraint=None), + ), # neighbors ( im.neighbors("E2V", im.ref("a", it_on_e_of_e_type)), @@ -171,6 +175,12 @@ def expression_test_cases(): )(im.ref("inp1", float_i_field), im.ref("inp2", float_i_field)), ts.TupleType(types=[float_i_field, float_i_field]), ), + ( + im.as_fieldop(im.lambda_("x")(im.deref("x")))( + im.ref("inp", ts.DeferredType(constraint=None)) + ), + ts.DeferredType(constraint=None), + ), # if in field-view scope ( im.if_( @@ -458,6 +468,33 @@ def test_program_tuple_setat_short_target(): ) +def test_program_setat_without_domain(): + cartesian_domain = im.call("cartesian_domain")( + im.call("named_range")(itir.AxisLiteral(value="IDim"), 0, 1) + ) + + testee = itir.Program( + id="f", + function_definitions=[], + params=[im.sym("inp", float_i_field), im.sym("out", float_i_field)], + declarations=[], + body=[ + itir.SetAt( + expr=im.as_fieldop(im.lambda_("x")(im.deref("x")))("inp"), + domain=cartesian_domain, + target=im.ref("out", float_i_field), + ) + ], + ) + + result = itir_type_inference.infer(testee, offset_provider={"Ioff": IDim}) + + assert ( + isinstance(result.body[0].expr.type, ts.DeferredType) + and result.body[0].expr.type.constraint == ts.FieldType + ) + + def test_if_stmt(): cartesian_domain = im.call("cartesian_domain")( im.call("named_range")(itir.AxisLiteral(value="IDim"), 0, 1) diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py index bcf8b726be..720076c8c2 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py @@ -8,6 +8,7 @@ from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.iterator.transforms.collapse_tuple import CollapseTuple +from gt4py.next.type_system import 
type_specifications as ts def test_simple_make_tuple_tuple_get(): @@ -213,3 +214,11 @@ def test_if_on_tuples_with_let(): testee, remove_letified_make_tuple_elements=False, allow_undeclared_symbols=True ) assert actual == expected + + +def test_tuple_get_on_untyped_ref(): + # test pass gracefully handles untyped nodes. + testee = im.tuple_get(0, im.ref("val", ts.DeferredType(constraint=None))) + + actual = CollapseTuple.apply(testee, allow_undeclared_symbols=True) + assert actual == testee From 45f41beb32296575cfcd01ef947dd3520392add8 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Mon, 28 Oct 2024 12:43:37 +0100 Subject: [PATCH 061/150] Address review comments --- .../next/iterator/transforms/collapse_tuple.py | 1 - .../next/iterator/transforms/infer_domain.py | 17 +++++++++++++++++ .../next/iterator/transforms/pass_manager.py | 2 +- .../runners/dace_iterator/workflow.py | 1 - 4 files changed, 18 insertions(+), 3 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/collapse_tuple.py b/src/gt4py/next/iterator/transforms/collapse_tuple.py index a4a46a6bc0..708b5a93be 100644 --- a/src/gt4py/next/iterator/transforms/collapse_tuple.py +++ b/src/gt4py/next/iterator/transforms/collapse_tuple.py @@ -160,7 +160,6 @@ def apply( return new_node def visit_FunCall(self, node: ir.FunCall, **kwargs) -> ir.Node: - # don't visit stencil argument of `as_fieldop` if cpm.is_call_to(node, "as_fieldop"): kwargs = {**kwargs, "is_local_view": True} diff --git a/src/gt4py/next/iterator/transforms/infer_domain.py b/src/gt4py/next/iterator/transforms/infer_domain.py index d897bd4ec8..8de88959e2 100644 --- a/src/gt4py/next/iterator/transforms/infer_domain.py +++ b/src/gt4py/next/iterator/transforms/infer_domain.py @@ -351,6 +351,23 @@ def infer_expr( offset_provider: common.OffsetProvider, symbolic_domain_sizes: Optional[dict[str, str]] = None, ) -> tuple[itir.Expr, ACCESSED_DOMAINS]: + """ + Infer the domain of all field subexpressions of `expr`. + + Given an expression `expr` and the domain it is accessed at, back-propagate the domain of all + (field-typed) subexpression. + + Arguments: + - expr: The expression to be inferred. + - domain: The domain `expr` is read at. + - symbolic_domain_sizes: A dictionary mapping axes names, e.g., `I`, `Vertex`, to a symbol + name that evaluates to the length of that axis. + + Returns: + A tuple containing the inferred expression with all applied `as_fieldop` (that are accessed) + having a domain argument now, and a dictionary mapping symbol names referenced in `expr` to + domain they are accessed at. + """ # this is just a small wrapper that populates the `domain` annex expr, accessed_domains = _infer_expr(expr, domain, offset_provider, symbolic_domain_sizes) expr.annex.domain = domain diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index a762ee99e1..98dc6c623d 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -53,7 +53,7 @@ def apply_common_transforms( temporary_extraction_heuristics: Optional[ Callable[[itir.StencilClosure], Callable[[itir.Expr], bool]] ] = None, - # FIXME[#1582](tehrengruber): Revisit and cleanup after new GTIR temporary pass is in place + #: A dictionary mapping axes names, e.g., `I`, `Vertex`, to a symbol name that evaluates to the length of that axis. 
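As a hedged example of what the `symbolic_domain_sizes` mapping documented above could look like at a call site (the symbol names on the right-hand side are placeholders invented for this sketch, not values taken from the patch):

# Axis name -> name of a symbol that evaluates to the length of that axis.
symbolic_domain_sizes = {
    "Vertex": "num_vertices",
    "Edge": "num_edges",
    "K": "num_levels",
}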
symbolic_domain_sizes: Optional[dict[str, str]] = None, ) -> itir.Program: # FIXME[#1582](tehrengruber): Rewrite iterator tests with itir.Program and remove this diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/workflow.py b/src/gt4py/next/program_processors/runners/dace_iterator/workflow.py index 6e27af5f95..740f1979cd 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/workflow.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/workflow.py @@ -81,7 +81,6 @@ def __call__( """Generate DaCe SDFG file from the ITIR definition.""" program: itir.FencilDefinition | itir.Program = inp.data - # FIXME[#1582](tehrengruber): Remove. This code-path is only used by the dace_itir backend. if isinstance(program, itir.Program): program = program_to_fencil.program_to_fencil(program) From 1874eba5334b5761abfd92a57ad9c0f5efb9daee Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Tue, 5 Nov 2024 09:43:31 +0100 Subject: [PATCH 062/150] Fix tests --- .../codegens/gtfn/gtfn_module.py | 7 ++- .../next/program_processors/runners/gtfn.py | 2 + .../program_processors/runners/roundtrip.py | 7 +++ tests/next_tests/definitions.py | 13 ++--- .../iterator_tests/test_if_stmt.py | 6 +-- .../iterator_tests/test_vertical_advection.py | 4 ++ tests/next_tests/unit_tests/conftest.py | 53 ++++++++++++------- 7 files changed, 62 insertions(+), 30 deletions(-) diff --git a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py index 85260afa07..275f70b401 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py @@ -193,7 +193,12 @@ def generate_stencil_source( offset_provider: dict[str, Connectivity | Dimension], column_axis: Optional[common.Dimension], ) -> str: - new_program = self._preprocess_program(program, offset_provider) + if self.enable_itir_transforms: + new_program = self._preprocess_program(program, offset_provider) + else: + assert isinstance(program, itir.Program) + new_program = program + gtfn_ir = GTFN_lowering.apply( new_program, offset_provider=offset_provider, column_axis=column_axis ) diff --git a/src/gt4py/next/program_processors/runners/gtfn.py b/src/gt4py/next/program_processors/runners/gtfn.py index e82c12fad2..aa832cf29b 100644 --- a/src/gt4py/next/program_processors/runners/gtfn.py +++ b/src/gt4py/next/program_processors/runners/gtfn.py @@ -191,3 +191,5 @@ class Params: run_gtfn_gpu = GTFNBackendFactory(gpu=True) run_gtfn_gpu_cached = GTFNBackendFactory(gpu=True, cached=True) + +run_gtfn_no_transforms = GTFNBackendFactory(otf_workflow__translation__enable_itir_transforms=False) diff --git a/src/gt4py/next/program_processors/runners/roundtrip.py b/src/gt4py/next/program_processors/runners/roundtrip.py index 70f0bc7df9..4d518d7fcc 100644 --- a/src/gt4py/next/program_processors/runners/roundtrip.py +++ b/src/gt4py/next/program_processors/runners/roundtrip.py @@ -223,6 +223,7 @@ def decorated_fencil( return decorated_fencil +# TODO(tehrengruber): introduce factory default = next_backend.Backend( name="roundtrip", executor=Roundtrip( @@ -245,6 +246,12 @@ def decorated_fencil( allocator=next_allocators.StandardCPUFieldBufferAllocator(), transforms=next_backend.DEFAULT_TRANSFORMS, ) +no_transforms = next_backend.Backend( + name="roundtrip", + executor=Roundtrip(transforms=lambda o, *, offset_provider: o), + allocator=next_allocators.StandardCPUFieldBufferAllocator(), + 
transforms=next_backend.DEFAULT_TRANSFORMS, +) gtir = next_backend.Backend( diff --git a/tests/next_tests/definitions.py b/tests/next_tests/definitions.py index 2c61ff085c..439b8be95c 100644 --- a/tests/next_tests/definitions.py +++ b/tests/next_tests/definitions.py @@ -43,11 +43,10 @@ def short_id(self, num_components: int = 2) -> str: class ProgramBackendId(_PythonObjectIdMixin, str, enum.Enum): GTFN_CPU = "gt4py.next.program_processors.runners.gtfn.run_gtfn" GTFN_CPU_IMPERATIVE = "gt4py.next.program_processors.runners.gtfn.run_gtfn_imperative" - GTFN_CPU_WITH_TEMPORARIES = ( - "gt4py.next.program_processors.runners.gtfn.run_gtfn_with_temporaries" - ) + GTFN_CPU_NO_TRANSFORMS = "gt4py.next.program_processors.runners.gtfn.run_gtfn_no_transforms" GTFN_GPU = "gt4py.next.program_processors.runners.gtfn.run_gtfn_gpu" ROUNDTRIP = "gt4py.next.program_processors.runners.roundtrip.default" + ROUNDTRIP_NO_TRANSFORMS = "gt4py.next.program_processors.runners.roundtrip.no_transforms" GTIR_EMBEDDED = "gt4py.next.program_processors.runners.roundtrip.gtir" ROUNDTRIP_WITH_TEMPORARIES = "gt4py.next.program_processors.runners.roundtrip.with_temporaries" DOUBLE_ROUNDTRIP = "gt4py.next.program_processors.runners.double_roundtrip.backend" @@ -127,6 +126,7 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): # Common list of feature markers to skip COMMON_SKIP_TEST_LIST = [ (REQUIRES_ATLAS, XFAIL, BINDINGS_UNSUPPORTED_MESSAGE), + (USES_DYNAMIC_OFFSETS, XFAIL, UNSUPPORTED_MESSAGE), (STARTS_FROM_GTIR_PROGRAM, SKIP, UNSUPPORTED_MESSAGE), (USES_APPLIED_SHIFTS, XFAIL, UNSUPPORTED_MESSAGE), (USES_NEGATIVE_MODULO, XFAIL, UNSUPPORTED_MESSAGE), @@ -181,12 +181,13 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): + [(USES_SCAN_NESTED, XFAIL, UNSUPPORTED_MESSAGE)], ProgramBackendId.GTFN_GPU: GTFN_SKIP_TEST_LIST + [(USES_SCAN_NESTED, XFAIL, UNSUPPORTED_MESSAGE)], - ProgramBackendId.GTFN_CPU_WITH_TEMPORARIES: GTFN_SKIP_TEST_LIST - + [(USES_DYNAMIC_OFFSETS, XFAIL, UNSUPPORTED_MESSAGE)], ProgramFormatterId.GTFN_CPP_FORMATTER: [ (USES_REDUCTION_WITH_ONLY_SPARSE_FIELDS, XFAIL, REDUCTION_WITH_ONLY_SPARSE_FIELDS_MESSAGE) ], - ProgramBackendId.ROUNDTRIP: [(USES_SPARSE_FIELDS_AS_OUTPUT, XFAIL, UNSUPPORTED_MESSAGE)], + ProgramBackendId.ROUNDTRIP: [ + (USES_SPARSE_FIELDS_AS_OUTPUT, XFAIL, UNSUPPORTED_MESSAGE), + (USES_DYNAMIC_OFFSETS, XFAIL, UNSUPPORTED_MESSAGE), + ], ProgramBackendId.ROUNDTRIP_WITH_TEMPORARIES: [ (ALL, XFAIL, UNSUPPORTED_MESSAGE), (USES_SPARSE_FIELDS_AS_OUTPUT, XFAIL, UNSUPPORTED_MESSAGE), diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_if_stmt.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_if_stmt.py index 2dde7d7653..c38a29bc61 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_if_stmt.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_if_stmt.py @@ -28,7 +28,7 @@ from gt4py.next.iterator.runtime import set_at, if_stmt, fendef, fundef, offset from gt4py.next.program_processors.runners import gtfn -from next_tests.unit_tests.conftest import program_processor, run_processor +from next_tests.unit_tests.conftest import program_processor_no_transforms, run_processor i = offset("i") @@ -43,8 +43,8 @@ def multiply(alpha, inp): @pytest.mark.uses_ir_if_stmts @pytest.mark.parametrize("cond", [True, False]) -def test_if_stmt(program_processor, cond): - program_processor, validate = program_processor +def test_if_stmt(program_processor_no_transforms, 
cond): + program_processor, validate = program_processor_no_transforms size = 10 @fendef(offset_provider={"i": IDim}) diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_vertical_advection.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_vertical_advection.py index a56cf694f2..961e536cc6 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_vertical_advection.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_vertical_advection.py @@ -91,6 +91,10 @@ def fen_solve_tridiag2(i_size, j_size, k_size, a, b, c, d, x): @pytest.mark.parametrize("fencil", [fen_solve_tridiag, fen_solve_tridiag2]) def test_tridiag(fencil, tridiag_reference, program_processor): program_processor, validate = program_processor + + if "dace" in program_processor.name: + pytest.xfail("Dace ITIR backend doesn't support the IR format used in this test.") + a, b, c, d, x = tridiag_reference shape = a.shape as_3d_field = gtx.as_field.partial([IDim, JDim, KDim]) diff --git a/tests/next_tests/unit_tests/conftest.py b/tests/next_tests/unit_tests/conftest.py index 87cdafc025..ca66b45d6d 100644 --- a/tests/next_tests/unit_tests/conftest.py +++ b/tests/next_tests/unit_tests/conftest.py @@ -25,7 +25,30 @@ ProgramProcessor: TypeAlias = backend.Backend | program_formatter.ProgramFormatter -@pytest.fixture( +def _program_processor(request) -> tuple[ProgramProcessor, bool]: + """ + Fixture creating program processors on-demand for tests. + + Notes: + Check ADR 15 for details on the test-exclusion matrices. + """ + processor_id, is_backend = request.param + if processor_id is None: + return None, is_backend + + processor = processor_id.load() + + for marker, skip_mark, msg in next_tests.definitions.BACKEND_SKIP_TEST_MATRIX.get( + processor_id, [] + ): + if marker == next_tests.definitions.ALL or request.node.get_closest_marker(marker): + skip_mark(msg.format(marker=marker, backend=processor_id)) + + return processor, is_backend + + +program_processor = pytest.fixture( + _program_processor, params=[ (None, True), (next_tests.definitions.ProgramBackendId.ROUNDTRIP, True), @@ -49,26 +72,16 @@ ], ids=lambda p: p[0].short_id() if p[0] is not None else "None", ) -def program_processor(request) -> tuple[ProgramProcessor, bool]: - """ - Fixture creating program processors on-demand for tests. - - Notes: - Check ADR 15 for details on the test-exclusion matrices. 
- """ - processor_id, is_backend = request.param - if processor_id is None: - return None, is_backend - - processor = processor_id.load() - - for marker, skip_mark, msg in next_tests.definitions.BACKEND_SKIP_TEST_MATRIX.get( - processor_id, [] - ): - if marker == next_tests.definitions.ALL or request.node.get_closest_marker(marker): - skip_mark(msg.format(marker=marker, backend=processor_id)) - return processor, is_backend +program_processor_no_transforms = pytest.fixture( + _program_processor, + params=[ + (None, True), + (next_tests.definitions.ProgramBackendId.GTFN_CPU_NO_TRANSFORMS, True), + (next_tests.definitions.ProgramBackendId.ROUNDTRIP_NO_TRANSFORMS, True), + ], + ids=lambda p: p[0].short_id() if p[0] is not None else "None", +) def run_processor( From dd5bfa739932a09e3eee04fbab324eaac8b95aba Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Tue, 5 Nov 2024 10:20:46 +0100 Subject: [PATCH 063/150] Fix tests --- .../feature_tests/iterator_tests/test_builtins.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py index c2f72e4ca7..3fc4ed9945 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py @@ -45,8 +45,9 @@ plus, shift, xor_, + as_fieldop, ) -from gt4py.next.iterator.runtime import closure, fendef, fundef, offset +from gt4py.next.iterator.runtime import set_at, closure, fendef, fundef, offset from gt4py.next.program_processors.runners.gtfn import run_gtfn from next_tests.integration_tests.feature_tests.math_builtin_test_data import math_builtin_test_data @@ -87,7 +88,9 @@ def dispatch(arg0): @fendef(offset_provider={}, column_axis=column_axis) def fenimpl(size, arg0, out): - closure(cartesian_domain(named_range(IDim, 0, size)), dispatch, out, [arg0]) + domain = cartesian_domain(named_range(IDim, 0, size)) + + set_at(as_fieldop(dispatch, domain)(arg0), domain, out) elif len(inps) == 2: @@ -102,7 +105,9 @@ def dispatch(arg0, arg1): @fendef(offset_provider={}, column_axis=column_axis) def fenimpl(size, arg0, arg1, out): - closure(cartesian_domain(named_range(IDim, 0, size)), dispatch, out, [arg0, arg1]) + domain = cartesian_domain(named_range(IDim, 0, size)) + + set_at(as_fieldop(dispatch, domain)(arg0, arg1), domain, out) elif len(inps) == 3: @@ -117,7 +122,9 @@ def dispatch(arg0, arg1, arg2): @fendef(offset_provider={}, column_axis=column_axis) def fenimpl(size, arg0, arg1, arg2, out): - closure(cartesian_domain(named_range(IDim, 0, size)), dispatch, out, [arg0, arg1, arg2]) + domain = cartesian_domain(named_range(IDim, 0, size)) + + set_at(as_fieldop(dispatch, domain)(arg0, arg1, arg2), domain, out) else: raise AssertionError("Add overload.") From 8f1e84abc6b65dcd197c0cc3af01ffa1c6ac8255 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Tue, 5 Nov 2024 10:25:28 +0100 Subject: [PATCH 064/150] Fix tests --- src/gt4py/next/iterator/transforms/infer_domain.py | 2 +- src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir.py | 6 ++++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/infer_domain.py b/src/gt4py/next/iterator/transforms/infer_domain.py index 4536bae632..37f61f4e78 100644 --- a/src/gt4py/next/iterator/transforms/infer_domain.py +++ b/src/gt4py/next/iterator/transforms/infer_domain.py @@ -338,7 
+338,7 @@ def _infer_expr( elif ( cpm.is_call_to(expr, itir.ARITHMETIC_BUILTINS) or cpm.is_call_to(expr, itir.TYPEBUILTINS) - or cpm.is_call_to(expr, ("cast_", "unstructured_domain", "cartesian_domain")) + or cpm.is_call_to(expr, ("cast_", "index", "unstructured_domain", "cartesian_domain")) ): return expr, {} else: diff --git a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir.py b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir.py index cc229b6652..d6c1542429 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir.py @@ -179,10 +179,12 @@ def _arg_validator( ) -> None: for inp in inputs: if not _is_tuple_expr_of( - lambda expr: isinstance(expr, (SymRef, SidComposite, SidFromScalar)), inp + lambda expr: isinstance(expr, (SymRef, SidComposite, SidFromScalar)) + or (isinstance(expr, FunCall) and isinstance(expr.fun, SymRef) and expr.fun.id == "index"), + inp, ): raise ValueError( - "Only 'SymRef', 'SidComposite', 'SidFromScalar' or tuple expr thereof allowed." + "Only 'SymRef', 'SidComposite', 'SidFromScalar', 'index' call or tuple expr thereof allowed." ) From d53d3bbd6cedfbac37bb06bb8d93d88b93b3f97f Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Tue, 5 Nov 2024 10:25:38 +0100 Subject: [PATCH 065/150] Fix tests --- src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir.py b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir.py index d6c1542429..85a100a88d 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir.py @@ -180,7 +180,11 @@ def _arg_validator( for inp in inputs: if not _is_tuple_expr_of( lambda expr: isinstance(expr, (SymRef, SidComposite, SidFromScalar)) - or (isinstance(expr, FunCall) and isinstance(expr.fun, SymRef) and expr.fun.id == "index"), + or ( + isinstance(expr, FunCall) + and isinstance(expr.fun, SymRef) + and expr.fun.id == "index" + ), inp, ): raise ValueError( From 4bfef546a37ec25238ae75df4d2f3159ab7903dc Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Tue, 5 Nov 2024 10:43:52 +0100 Subject: [PATCH 066/150] Fix tests --- tests/next_tests/definitions.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/tests/next_tests/definitions.py b/tests/next_tests/definitions.py index eead36be19..5ef16e7511 100644 --- a/tests/next_tests/definitions.py +++ b/tests/next_tests/definitions.py @@ -160,6 +160,10 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): UNSUPPORTED_MESSAGE, ), # we can't extract the field type from scan args ] +ROUNDTRIP_SKIP_LIST = [ + (USES_SPARSE_FIELDS_AS_OUTPUT, XFAIL, UNSUPPORTED_MESSAGE), + (USES_DYNAMIC_OFFSETS, XFAIL, UNSUPPORTED_MESSAGE), +] GTFN_SKIP_TEST_LIST = COMMON_SKIP_TEST_LIST + [ # floordiv not yet supported, see https://github.com/GridTools/gt4py/issues/1136 (USES_FLOORDIV, XFAIL, BINDINGS_UNSUPPORTED_MESSAGE), @@ -184,20 +188,16 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): ProgramBackendId.GTFN_GPU: GTFN_SKIP_TEST_LIST + [(USES_SCAN_NESTED, XFAIL, UNSUPPORTED_MESSAGE)], ProgramFormatterId.GTFN_CPP_FORMATTER: [ - (USES_REDUCTION_WITH_ONLY_SPARSE_FIELDS, XFAIL, REDUCTION_WITH_ONLY_SPARSE_FIELDS_MESSAGE) - ], - ProgramBackendId.ROUNDTRIP: [ - (USES_SPARSE_FIELDS_AS_OUTPUT, XFAIL, UNSUPPORTED_MESSAGE), (USES_DYNAMIC_OFFSETS, XFAIL, 
UNSUPPORTED_MESSAGE), + (USES_REDUCTION_WITH_ONLY_SPARSE_FIELDS, XFAIL, REDUCTION_WITH_ONLY_SPARSE_FIELDS_MESSAGE), ], - ProgramBackendId.GTIR_EMBEDDED: [ - (USES_SPARSE_FIELDS_AS_OUTPUT, XFAIL, UNSUPPORTED_MESSAGE), - (USES_DYNAMIC_OFFSETS, XFAIL, UNSUPPORTED_MESSAGE), - ], - ProgramBackendId.ROUNDTRIP_WITH_TEMPORARIES: [ + ProgramFormatterId.LISP_FORMATTER: [(USES_DYNAMIC_OFFSETS, XFAIL, UNSUPPORTED_MESSAGE)], + ProgramBackendId.ROUNDTRIP: ROUNDTRIP_SKIP_LIST, + ProgramBackendId.DOUBLE_ROUNDTRIP: ROUNDTRIP_SKIP_LIST, + ProgramBackendId.ROUNDTRIP_WITH_TEMPORARIES: ROUNDTRIP_SKIP_LIST + + [ (ALL, XFAIL, UNSUPPORTED_MESSAGE), - (USES_SPARSE_FIELDS_AS_OUTPUT, XFAIL, UNSUPPORTED_MESSAGE), - (USES_DYNAMIC_OFFSETS, XFAIL, UNSUPPORTED_MESSAGE), (USES_STRIDED_NEIGHBOR_OFFSET, XFAIL, UNSUPPORTED_MESSAGE), ], + ProgramBackendId.GTIR_EMBEDDED: ROUNDTRIP_SKIP_LIST, } From 78b7a98ac9fd31c5700e80489b7b099357fa6243 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Tue, 5 Nov 2024 10:59:25 +0100 Subject: [PATCH 067/150] Fix tests --- tests/next_tests/definitions.py | 13 ++++++++----- .../iterator_tests/test_vertical_advection.py | 3 ++- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/tests/next_tests/definitions.py b/tests/next_tests/definitions.py index 5ef16e7511..59f3a152b2 100644 --- a/tests/next_tests/definitions.py +++ b/tests/next_tests/definitions.py @@ -160,9 +160,12 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): UNSUPPORTED_MESSAGE, ), # we can't extract the field type from scan args ] -ROUNDTRIP_SKIP_LIST = [ - (USES_SPARSE_FIELDS_AS_OUTPUT, XFAIL, UNSUPPORTED_MESSAGE), +DOMAIN_INFERENCE_SKIP_LIST = [ (USES_DYNAMIC_OFFSETS, XFAIL, UNSUPPORTED_MESSAGE), + (USES_STRIDED_NEIGHBOR_OFFSET, XFAIL, UNSUPPORTED_MESSAGE), +] +ROUNDTRIP_SKIP_LIST = DOMAIN_INFERENCE_SKIP_LIST + [ + (USES_SPARSE_FIELDS_AS_OUTPUT, XFAIL, UNSUPPORTED_MESSAGE), ] GTFN_SKIP_TEST_LIST = COMMON_SKIP_TEST_LIST + [ # floordiv not yet supported, see https://github.com/GridTools/gt4py/issues/1136 @@ -187,11 +190,11 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): + [(USES_SCAN_NESTED, XFAIL, UNSUPPORTED_MESSAGE)], ProgramBackendId.GTFN_GPU: GTFN_SKIP_TEST_LIST + [(USES_SCAN_NESTED, XFAIL, UNSUPPORTED_MESSAGE)], - ProgramFormatterId.GTFN_CPP_FORMATTER: [ - (USES_DYNAMIC_OFFSETS, XFAIL, UNSUPPORTED_MESSAGE), + ProgramFormatterId.GTFN_CPP_FORMATTER: DOMAIN_INFERENCE_SKIP_LIST + + [ (USES_REDUCTION_WITH_ONLY_SPARSE_FIELDS, XFAIL, REDUCTION_WITH_ONLY_SPARSE_FIELDS_MESSAGE), ], - ProgramFormatterId.LISP_FORMATTER: [(USES_DYNAMIC_OFFSETS, XFAIL, UNSUPPORTED_MESSAGE)], + ProgramFormatterId.LISP_FORMATTER: DOMAIN_INFERENCE_SKIP_LIST, ProgramBackendId.ROUNDTRIP: ROUNDTRIP_SKIP_LIST, ProgramBackendId.DOUBLE_ROUNDTRIP: ROUNDTRIP_SKIP_LIST, ProgramBackendId.ROUNDTRIP_WITH_TEMPORARIES: ROUNDTRIP_SKIP_LIST diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_vertical_advection.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_vertical_advection.py index 961e536cc6..30ceaf9376 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_vertical_advection.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_vertical_advection.py @@ -10,6 +10,7 @@ import pytest import gt4py.next as gtx +from gt4py.next import backend from gt4py.next.iterator.builtins import * from gt4py.next.iterator.runtime import set_at, fendef, fundef from gt4py.next.program_processors.formatters 
import gtfn as gtfn_formatters @@ -92,7 +93,7 @@ def fen_solve_tridiag2(i_size, j_size, k_size, a, b, c, d, x): def test_tridiag(fencil, tridiag_reference, program_processor): program_processor, validate = program_processor - if "dace" in program_processor.name: + if isinstance(program_processor, backend.Backend) and "dace" in program_processor.name: pytest.xfail("Dace ITIR backend doesn't support the IR format used in this test.") a, b, c, d, x = tridiag_reference From 88a76607735a8480d03ca52f347b8259283bd4bf Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Tue, 5 Nov 2024 11:27:24 +0100 Subject: [PATCH 068/150] Fix tests --- tests/next_tests/definitions.py | 65 +++++++++++-------- .../transforms_tests/test_collapse_tuple.py | 2 +- 2 files changed, 38 insertions(+), 29 deletions(-) diff --git a/tests/next_tests/definitions.py b/tests/next_tests/definitions.py index 59f3a152b2..47baba57d2 100644 --- a/tests/next_tests/definitions.py +++ b/tests/next_tests/definitions.py @@ -133,25 +133,34 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): (USES_REDUCTION_WITH_ONLY_SPARSE_FIELDS, XFAIL, REDUCTION_WITH_ONLY_SPARSE_FIELDS_MESSAGE), (USES_SPARSE_FIELDS_AS_OUTPUT, XFAIL, UNSUPPORTED_MESSAGE), ] -DACE_SKIP_TEST_LIST = COMMON_SKIP_TEST_LIST + [ - (USES_IF_STMTS, XFAIL, UNSUPPORTED_MESSAGE), - (USES_IR_IF_STMTS, XFAIL, UNSUPPORTED_MESSAGE), - (USES_SCALAR_IN_DOMAIN_AND_FO, XFAIL, UNSUPPORTED_MESSAGE), - (USES_INDEX_FIELDS, XFAIL, UNSUPPORTED_MESSAGE), - (USES_LIFT_EXPRESSIONS, XFAIL, UNSUPPORTED_MESSAGE), - (USES_ORIGIN, XFAIL, UNSUPPORTED_MESSAGE), - (USES_STRIDED_NEIGHBOR_OFFSET, XFAIL, BINDINGS_UNSUPPORTED_MESSAGE), - (USES_TUPLE_ARGS, XFAIL, UNSUPPORTED_MESSAGE), - (USES_TUPLE_RETURNS, XFAIL, UNSUPPORTED_MESSAGE), - (USES_ZERO_DIMENSIONAL_FIELDS, XFAIL, UNSUPPORTED_MESSAGE), - (USES_INDEX_BUILTIN, XFAIL, UNSUPPORTED_MESSAGE), - (STARTS_FROM_GTIR_PROGRAM, SKIP, UNSUPPORTED_MESSAGE), - (USES_SCAN_IN_FIELD_OPERATOR, XFAIL, UNSUPPORTED_MESSAGE), +# Markers to skip because of missing features in the domain inference +DOMAIN_INFERENCE_SKIP_LIST = [ + (USES_DYNAMIC_OFFSETS, XFAIL, UNSUPPORTED_MESSAGE), + (USES_STRIDED_NEIGHBOR_OFFSET, XFAIL, UNSUPPORTED_MESSAGE), ] -GTIR_DACE_SKIP_TEST_LIST = [ +DACE_SKIP_TEST_LIST = ( + COMMON_SKIP_TEST_LIST + + DOMAIN_INFERENCE_SKIP_LIST + + [ + (USES_IF_STMTS, XFAIL, UNSUPPORTED_MESSAGE), + (USES_IR_IF_STMTS, XFAIL, UNSUPPORTED_MESSAGE), + (USES_SCALAR_IN_DOMAIN_AND_FO, XFAIL, UNSUPPORTED_MESSAGE), + (USES_INDEX_FIELDS, XFAIL, UNSUPPORTED_MESSAGE), + (USES_LIFT_EXPRESSIONS, XFAIL, UNSUPPORTED_MESSAGE), + (USES_ORIGIN, XFAIL, UNSUPPORTED_MESSAGE), + (USES_STRIDED_NEIGHBOR_OFFSET, XFAIL, BINDINGS_UNSUPPORTED_MESSAGE), + (USES_TUPLE_ARGS, XFAIL, UNSUPPORTED_MESSAGE), + (USES_TUPLE_RETURNS, XFAIL, UNSUPPORTED_MESSAGE), + (USES_ZERO_DIMENSIONAL_FIELDS, XFAIL, UNSUPPORTED_MESSAGE), + (USES_INDEX_BUILTIN, XFAIL, UNSUPPORTED_MESSAGE), + (STARTS_FROM_GTIR_PROGRAM, SKIP, UNSUPPORTED_MESSAGE), + (USES_SCAN_IN_FIELD_OPERATOR, XFAIL, UNSUPPORTED_MESSAGE), + ] +) +GTIR_DACE_SKIP_TEST_LIST = DOMAIN_INFERENCE_SKIP_LIST + [ (ALL, SKIP, UNSUPPORTED_MESSAGE), ] -EMBEDDED_SKIP_LIST = [ +EMBEDDED_SKIP_LIST = DOMAIN_INFERENCE_SKIP_LIST + [ (USES_DYNAMIC_OFFSETS, XFAIL, UNSUPPORTED_MESSAGE), (CHECKS_SPECIFIC_ERROR, XFAIL, UNSUPPORTED_MESSAGE), ( @@ -160,21 +169,21 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): UNSUPPORTED_MESSAGE, ), # we can't extract the field type from scan args ] -DOMAIN_INFERENCE_SKIP_LIST = [ - (USES_DYNAMIC_OFFSETS, XFAIL, 
UNSUPPORTED_MESSAGE), - (USES_STRIDED_NEIGHBOR_OFFSET, XFAIL, UNSUPPORTED_MESSAGE), -] ROUNDTRIP_SKIP_LIST = DOMAIN_INFERENCE_SKIP_LIST + [ (USES_SPARSE_FIELDS_AS_OUTPUT, XFAIL, UNSUPPORTED_MESSAGE), ] -GTFN_SKIP_TEST_LIST = COMMON_SKIP_TEST_LIST + [ - # floordiv not yet supported, see https://github.com/GridTools/gt4py/issues/1136 - (USES_FLOORDIV, XFAIL, BINDINGS_UNSUPPORTED_MESSAGE), - (USES_STRIDED_NEIGHBOR_OFFSET, XFAIL, BINDINGS_UNSUPPORTED_MESSAGE), - # max_over broken, see https://github.com/GridTools/gt4py/issues/1289 - (USES_MAX_OVER, XFAIL, UNSUPPORTED_MESSAGE), - (USES_SCAN_REQUIRING_PROJECTOR, XFAIL, UNSUPPORTED_MESSAGE), -] +GTFN_SKIP_TEST_LIST = ( + COMMON_SKIP_TEST_LIST + + DOMAIN_INFERENCE_SKIP_LIST + + [ + # floordiv not yet supported, see https://github.com/GridTools/gt4py/issues/1136 + (USES_FLOORDIV, XFAIL, BINDINGS_UNSUPPORTED_MESSAGE), + (USES_STRIDED_NEIGHBOR_OFFSET, XFAIL, BINDINGS_UNSUPPORTED_MESSAGE), + # max_over broken, see https://github.com/GridTools/gt4py/issues/1289 + (USES_MAX_OVER, XFAIL, UNSUPPORTED_MESSAGE), + (USES_SCAN_REQUIRING_PROJECTOR, XFAIL, UNSUPPORTED_MESSAGE), + ] +) #: Skip matrix, contains for each backend processor a list of tuples with following fields: #: (, ) diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py index 457e41cb29..4d8b9a1c0a 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py @@ -237,5 +237,5 @@ def test_tuple_get_on_untyped_ref(): # test pass gracefully handles untyped nodes. testee = im.tuple_get(0, im.ref("val", ts.DeferredType(constraint=None))) - actual = CollapseTuple.apply(testee, allow_undeclared_symbols=True) + actual = CollapseTuple.apply(testee, allow_undeclared_symbols=True, is_local_view=False) assert actual == testee From 1b79b3a9611033a621fde43638469cabe2c3ff01 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Tue, 5 Nov 2024 11:27:46 +0100 Subject: [PATCH 069/150] Fix tests --- tests/next_tests/definitions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/next_tests/definitions.py b/tests/next_tests/definitions.py index 47baba57d2..1302a461dc 100644 --- a/tests/next_tests/definitions.py +++ b/tests/next_tests/definitions.py @@ -160,7 +160,7 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): GTIR_DACE_SKIP_TEST_LIST = DOMAIN_INFERENCE_SKIP_LIST + [ (ALL, SKIP, UNSUPPORTED_MESSAGE), ] -EMBEDDED_SKIP_LIST = DOMAIN_INFERENCE_SKIP_LIST + [ +EMBEDDED_SKIP_LIST = [ (USES_DYNAMIC_OFFSETS, XFAIL, UNSUPPORTED_MESSAGE), (CHECKS_SPECIFIC_ERROR, XFAIL, UNSUPPORTED_MESSAGE), ( From b6b603e32baefd59bf13e582bf233fd523821c55 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Tue, 5 Nov 2024 14:19:48 +0100 Subject: [PATCH 070/150] Bump CI version to 22.04 --- ci/base.Dockerfile | 2 +- src/gt4py/next/backend.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/ci/base.Dockerfile b/ci/base.Dockerfile index d20d9ca6ef..68fb4b6b3c 100644 --- a/ci/base.Dockerfile +++ b/ci/base.Dockerfile @@ -1,5 +1,5 @@ ARG CUDA_VERSION=12.5.0 -FROM docker.io/nvidia/cuda:${CUDA_VERSION}-devel-ubuntu20.04 +FROM docker.io/nvidia/cuda:${CUDA_VERSION}-devel-ubuntu22.04 ENV LANG C.UTF-8 ENV LC_ALL C.UTF-8 diff --git a/src/gt4py/next/backend.py b/src/gt4py/next/backend.py index 017b7324cc..e223d7771c 100644 --- 
a/src/gt4py/next/backend.py +++ b/src/gt4py/next/backend.py @@ -136,6 +136,7 @@ def step_order(self, inp: INPUT_PAIR) -> list[str]: DEFAULT_TRANSFORMS: Transforms = Transforms() # FIXME[#1582](havogt): remove after refactoring to GTIR +# note: this step is deliberately placed here, such that the cache is shared _foast_to_itir_step = foast_to_itir.adapted_foast_to_itir_factory(cached=True) LEGACY_TRANSFORMS: Transforms = Transforms( past_to_itir=past_to_itir.past_to_itir_factory(to_gtir=False), From 5beccf087f7288c04e3b5dc34e69dd8a0023920b Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Tue, 5 Nov 2024 14:36:19 +0100 Subject: [PATCH 071/150] Bump CI version to 22.04 --- ci/base.Dockerfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ci/base.Dockerfile b/ci/base.Dockerfile index 68fb4b6b3c..77ac401546 100644 --- a/ci/base.Dockerfile +++ b/ci/base.Dockerfile @@ -1,4 +1,4 @@ -ARG CUDA_VERSION=12.5.0 +ARG CUDA_VERSION=12.6.2 FROM docker.io/nvidia/cuda:${CUDA_VERSION}-devel-ubuntu22.04 ENV LANG C.UTF-8 ENV LC_ALL C.UTF-8 From deca907e1d36a25249643705036a5051030af383 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Tue, 5 Nov 2024 14:42:41 +0100 Subject: [PATCH 072/150] Bump CI version to 22.04 --- ci/cscs-ci.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ci/cscs-ci.yml b/ci/cscs-ci.yml index 7fcd65106d..e2a358dac4 100644 --- a/ci/cscs-ci.yml +++ b/ci/cscs-ci.yml @@ -46,12 +46,12 @@ stages: .build_baseimage_x86_64: extends: [.container-builder-cscs-zen2, .build_baseimage] variables: - CUDA_VERSION: 11.2.2 + CUDA_VERSION: 11.7.1 CUPY_PACKAGE: cupy-cuda11x .build_baseimage_aarch64: extends: [.container-builder-cscs-gh200, .build_baseimage] variables: - CUDA_VERSION: 12.4.1 + CUDA_VERSION: 12.6.2 CUPY_PACKAGE: cupy-cuda12x # TODO: enable CI job when Todi is back in operational state when: manual From af9d7764fa2aa5d9933f5b9a7a1897bd1a5f0f1b Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Tue, 5 Nov 2024 14:45:30 +0100 Subject: [PATCH 073/150] Bump CI version to 22.04 --- ci/base.Dockerfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ci/base.Dockerfile b/ci/base.Dockerfile index 77ac401546..eb7e0b09ea 100644 --- a/ci/base.Dockerfile +++ b/ci/base.Dockerfile @@ -22,7 +22,7 @@ RUN apt-get update -qq && apt-get install -qq -y --no-install-recommends \ tk-dev \ libffi-dev \ liblzma-dev \ - python-openssl \ + python3-openssl \ libreadline-dev \ git \ rustc \ From ee0b94ab5b60731358b31aa92077f1a571cc6639 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Tue, 5 Nov 2024 23:57:15 +0100 Subject: [PATCH 074/150] Fix failing dace tests --- tests/next_tests/definitions.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/next_tests/definitions.py b/tests/next_tests/definitions.py index 002b1ebfcd..b3d491f989 100644 --- a/tests/next_tests/definitions.py +++ b/tests/next_tests/definitions.py @@ -143,6 +143,8 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): COMMON_SKIP_TEST_LIST + DOMAIN_INFERENCE_SKIP_LIST + [ + (USES_IF_STMTS, XFAIL, UNSUPPORTED_MESSAGE), + (USES_SCAN_IN_FIELD_OPERATOR, XFAIL, UNSUPPORTED_MESSAGE), (USES_IR_IF_STMTS, XFAIL, UNSUPPORTED_MESSAGE), (USES_SCALAR_IN_DOMAIN_AND_FO, XFAIL, UNSUPPORTED_MESSAGE), (USES_INDEX_FIELDS, XFAIL, UNSUPPORTED_MESSAGE), From 270e1734b365b586cf158b975dd17c79f2ea4aec Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Wed, 6 Nov 2024 00:58:17 +0100 Subject: [PATCH 075/150] Address review comments --- .../next/iterator/ir_utils/domain_utils.py | 2 + 
.../iterator/transforms/collapse_tuple.py | 14 +-- src/gt4py/next/iterator/transforms/cse.py | 24 ++--- .../next/iterator/transforms/infer_domain.py | 25 +++--- .../next/iterator/transforms/pass_manager.py | 3 +- .../transforms_tests/test_collapse_tuple.py | 32 +++---- .../transforms_tests/test_cse.py | 26 +++--- .../transforms_tests/test_unroll_reduce.py | 87 ++++--------------- 8 files changed, 86 insertions(+), 127 deletions(-) diff --git a/src/gt4py/next/iterator/ir_utils/domain_utils.py b/src/gt4py/next/iterator/ir_utils/domain_utils.py index e25508e279..8f842e1c13 100644 --- a/src/gt4py/next/iterator/ir_utils/domain_utils.py +++ b/src/gt4py/next/iterator/ir_utils/domain_utils.py @@ -93,6 +93,8 @@ def translate( ..., ], offset_provider: common.OffsetProvider, + #: A dictionary mapping axes names to their length. See + #: func:`gt4py.next.iterator.transforms.infer_domain.infer_expr` for more details. symbolic_domain_sizes: Optional[dict[str, str]] = None, ) -> SymbolicDomain: dims = list(self.ranges.keys()) diff --git a/src/gt4py/next/iterator/transforms/collapse_tuple.py b/src/gt4py/next/iterator/transforms/collapse_tuple.py index b498290735..4a7f070ac7 100644 --- a/src/gt4py/next/iterator/transforms/collapse_tuple.py +++ b/src/gt4py/next/iterator/transforms/collapse_tuple.py @@ -106,7 +106,7 @@ def apply( ignore_tuple_size: bool = False, remove_letified_make_tuple_elements: bool = True, offset_provider: Optional[common.OffsetProvider] = None, - is_local_view: Optional[bool] = None, + within_stencil: Optional[bool] = None, # manually passing flags is mostly for allowing separate testing of the modes flags: Optional[Flag] = None, # allow sym references without a symbol declaration, mostly for testing @@ -129,11 +129,11 @@ def apply( offset_provider = offset_provider or {} if isinstance(node, (ir.Program, ir.FencilDefinition)): - is_local_view = False - assert is_local_view in [ + within_stencil = False + assert within_stencil in [ True, False, - ], "Parameter 'is_local_view' mandatory if node is not a 'Program'." + ], "Parameter 'within_stencil' mandatory if node is not a 'Program'." if not ignore_tuple_size: node = itir_type_inference.infer( @@ -145,7 +145,7 @@ def apply( new_node = cls( ignore_tuple_size=ignore_tuple_size, flags=flags, - ).visit(node, is_local_view=is_local_view) + ).visit(node, within_stencil=within_stencil) # inline to remove left-overs from LETIFY_MAKE_TUPLE_ELEMENTS. this is important # as otherwise two equal expressions containing a tuple will not be equal anymore @@ -161,7 +161,7 @@ def apply( def visit_FunCall(self, node: ir.FunCall, **kwargs) -> ir.Node: if cpm.is_call_to(node, "as_fieldop"): - kwargs = {**kwargs, "is_local_view": True} + kwargs = {**kwargs, "within_stencil": True} node = self.generic_visit(node, **kwargs) return self.fp_transform(node, **kwargs) @@ -291,7 +291,7 @@ def transform_inline_trivial_make_tuple(self, node: ir.FunCall, **kwargs) -> Opt def transform_propagate_to_if_on_tuples(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: # TODO(tehrengruber): This significantly increases the size of the tree. Skip transformation # in local-view for now. Revisit. - if not cpm.is_call_to(node, "if_") and not kwargs["is_local_view"]: + if not cpm.is_call_to(node, "if_") and not kwargs["within_stencil"]: # TODO(tehrengruber): Only inline if type of branch value is a tuple. 
# Examples: # `(if cond then {1, 2} else {3, 4})[0]` -> `if cond then {1, 2}[0] else {3, 4}[0]` diff --git a/src/gt4py/next/iterator/transforms/cse.py b/src/gt4py/next/iterator/transforms/cse.py index ecdde572dc..8a0c35cfc7 100644 --- a/src/gt4py/next/iterator/transforms/cse.py +++ b/src/gt4py/next/iterator/transforms/cse.py @@ -388,7 +388,7 @@ class CommonSubexpressionElimination(PreserveLocationVisitor, NodeTranslator): >>> x = itir.SymRef(id="x") >>> plus = lambda a, b: itir.FunCall(fun=itir.SymRef(id=("plus")), args=[a, b]) >>> expr = plus(plus(x, x), plus(x, x)) - >>> print(CommonSubexpressionElimination.apply(expr, is_local_view=True)) + >>> print(CommonSubexpressionElimination.apply(expr, within_stencil=True)) (λ(_cs_1) → _cs_1 + _cs_1)(x + x) The pass visits the tree top-down starting from the root node, e.g. an itir.Program. @@ -410,33 +410,33 @@ class CommonSubexpressionElimination(PreserveLocationVisitor, NodeTranslator): def apply( cls, node: ProgramOrExpr, - is_local_view: bool | None = None, + within_stencil: bool | None = None, offset_provider: common.OffsetProvider | None = None, ) -> ProgramOrExpr: is_program = isinstance(node, (itir.Program, itir.FencilDefinition)) if is_program: - assert is_local_view is None - is_local_view = False + assert within_stencil is None + within_stencil = False else: assert ( - is_local_view is not None - ), "The expression's context must be specified using `is_local_view`." + within_stencil is not None + ), "The expression's context must be specified using `within_stencil`." offset_provider = offset_provider or {} node = itir_type_inference.infer( node, offset_provider=offset_provider, allow_undeclared_symbols=not is_program ) - return cls().visit(node, is_local_view=is_local_view) + return cls().visit(node, within_stencil=within_stencil) def generic_visit(self, node, **kwargs): if cpm.is_call_to("as_fieldop", node): - assert not kwargs.get("is_local_view") - is_local_view = cpm.is_call_to("as_fieldop", node) or kwargs.get("is_local_view") + assert not kwargs.get("within_stencil") + within_stencil = cpm.is_call_to("as_fieldop", node) or kwargs.get("within_stencil") - return super().generic_visit(node, **(kwargs | {"is_local_view": is_local_view})) + return super().generic_visit(node, **(kwargs | {"within_stencil": within_stencil})) def visit_FunCall(self, node: itir.FunCall, **kwargs): - is_local_view = kwargs["is_local_view"] + within_stencil = kwargs["within_stencil"] if cpm.is_call_to(node, ("cartesian_domain", "unstructured_domain")): return node @@ -446,7 +446,7 @@ def predicate(subexpr: itir.Expr, num_occurences: int): # view, even though the syntactic context `node` is in field view. # note: what is extracted is sketched in the docstring above. keep it updated. 
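# [Editor's annotation, not part of this patch] `within_stencil` is the renamed
# `is_local_view` flag and selects between the two extraction policies handled just
# below: inside a stencil any repeated subexpression is extracted, while in field view
# only repeated field-typed subexpressions are (see `test_field_extraction_outside_asfieldop`
# further down in this commit). A minimal sketch, assuming the `im` ir_makers alias and
# the `CSE` import used in this commit's tests:
#
#     expr = im.plus(im.plus("x", "x"), im.plus("x", "x"))
#     CSE.apply(expr, within_stencil=True)
#     # -> (λ(_cs_1) → _cs_1 + _cs_1)(x + x), as in the doctest above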
if num_occurences > 1: - if is_local_view: + if within_stencil: return True else: # only extract fields outside of `as_fieldop` diff --git a/src/gt4py/next/iterator/transforms/infer_domain.py b/src/gt4py/next/iterator/transforms/infer_domain.py index 37f61f4e78..6852b47a7a 100644 --- a/src/gt4py/next/iterator/transforms/infer_domain.py +++ b/src/gt4py/next/iterator/transforms/infer_domain.py @@ -141,7 +141,7 @@ def _extract_accessed_domains( return typing.cast(ACCESSED_DOMAINS, accessed_domains) -def infer_as_fieldop( +def _infer_as_fieldop( applied_fieldop: itir.FunCall, target_domain: DOMAIN, offset_provider: common.OffsetProvider, @@ -204,7 +204,7 @@ def infer_as_fieldop( return transformed_call, accessed_domains_without_tmp -def infer_let( +def _infer_let( let_expr: itir.FunCall, input_domain: DOMAIN, offset_provider: common.OffsetProvider, @@ -245,7 +245,7 @@ def infer_let( return transformed_call, accessed_domains_outer -def infer_make_tuple( +def _infer_make_tuple( expr: itir.Expr, domain: DOMAIN, offset_provider: common.OffsetProvider, @@ -274,7 +274,7 @@ def infer_make_tuple( return result_expr, actual_domains -def infer_tuple_get( +def _infer_tuple_get( expr: itir.Expr, domain: DOMAIN, offset_provider: common.OffsetProvider, @@ -295,7 +295,7 @@ def infer_tuple_get( return infered_args_expr, actual_domains -def infer_if( +def _infer_if( expr: itir.Expr, domain: DOMAIN, offset_provider: common.OffsetProvider, @@ -326,15 +326,15 @@ def _infer_expr( elif isinstance(expr, itir.Literal): return expr, {} elif cpm.is_applied_as_fieldop(expr): - return infer_as_fieldop(expr, domain, offset_provider, symbolic_domain_sizes) + return _infer_as_fieldop(expr, domain, offset_provider, symbolic_domain_sizes) elif cpm.is_let(expr): - return infer_let(expr, domain, offset_provider, symbolic_domain_sizes) + return _infer_let(expr, domain, offset_provider, symbolic_domain_sizes) elif cpm.is_call_to(expr, "make_tuple"): - return infer_make_tuple(expr, domain, offset_provider, symbolic_domain_sizes) + return _infer_make_tuple(expr, domain, offset_provider, symbolic_domain_sizes) elif cpm.is_call_to(expr, "tuple_get"): - return infer_tuple_get(expr, domain, offset_provider, symbolic_domain_sizes) + return _infer_tuple_get(expr, domain, offset_provider, symbolic_domain_sizes) elif cpm.is_call_to(expr, "if_"): - return infer_if(expr, domain, offset_provider, symbolic_domain_sizes) + return _infer_if(expr, domain, offset_provider, symbolic_domain_sizes) elif ( cpm.is_call_to(expr, itir.ARITHMETIC_BUILTINS) or cpm.is_call_to(expr, itir.TYPEBUILTINS) @@ -409,6 +409,11 @@ def infer_program( offset_provider: common.OffsetProvider, symbolic_domain_sizes: Optional[dict[str, str]] = None, ) -> itir.Program: + """ + Infer the domain of all field subexpressions inside a program. + + See :func:`infer_expr` for more details. + """ assert ( not program.function_definitions ), "Domain propagation does not support function definitions." diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index 108e642bbc..52a452155a 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -53,7 +53,8 @@ def apply_common_transforms( temporary_extraction_heuristics: Optional[ Callable[[itir.StencilClosure], Callable[[itir.Expr], bool]] ] = None, - #: A dictionary mapping axes names, e.g., `I`, `Vertex`, to a symbol name that evaluates to the length of that axis. 
+ #: A dictionary mapping axes names to their length. See :func:`infer_domain.infer_expr` for + #: more details. symbolic_domain_sizes: Optional[dict[str, str]] = None, ) -> itir.Program: # FIXME[#1582](tehrengruber): Rewrite iterator tests with itir.Program and remove this diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py index 4d8b9a1c0a..28090ff1e2 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py @@ -20,7 +20,7 @@ def test_simple_make_tuple_tuple_get(): remove_letified_make_tuple_elements=False, flags=CollapseTuple.Flag.COLLAPSE_MAKE_TUPLE_TUPLE_GET, allow_undeclared_symbols=True, - is_local_view=False, + within_stencil=False, ) expected = tuple_of_size_2 @@ -38,7 +38,7 @@ def test_nested_make_tuple_tuple_get(): remove_letified_make_tuple_elements=False, flags=CollapseTuple.Flag.COLLAPSE_MAKE_TUPLE_TUPLE_GET, allow_undeclared_symbols=True, - is_local_view=False, + within_stencil=False, ) assert actual == tup_of_size2_from_lambda @@ -54,7 +54,7 @@ def test_different_tuples_make_tuple_tuple_get(): remove_letified_make_tuple_elements=False, flags=CollapseTuple.Flag.COLLAPSE_MAKE_TUPLE_TUPLE_GET, allow_undeclared_symbols=True, - is_local_view=False, + within_stencil=False, ) assert actual == testee # did nothing @@ -68,7 +68,7 @@ def test_incompatible_order_make_tuple_tuple_get(): remove_letified_make_tuple_elements=False, flags=CollapseTuple.Flag.COLLAPSE_MAKE_TUPLE_TUPLE_GET, allow_undeclared_symbols=True, - is_local_view=False, + within_stencil=False, ) assert actual == testee # did nothing @@ -80,7 +80,7 @@ def test_incompatible_size_make_tuple_tuple_get(): remove_letified_make_tuple_elements=False, flags=CollapseTuple.Flag.COLLAPSE_MAKE_TUPLE_TUPLE_GET, allow_undeclared_symbols=True, - is_local_view=False, + within_stencil=False, ) assert actual == testee # did nothing @@ -92,7 +92,7 @@ def test_merged_with_smaller_outer_size_make_tuple_tuple_get(): ignore_tuple_size=True, flags=CollapseTuple.Flag.COLLAPSE_MAKE_TUPLE_TUPLE_GET, allow_undeclared_symbols=True, - is_local_view=False, + within_stencil=False, ) assert actual == im.make_tuple("first", "second") @@ -105,7 +105,7 @@ def test_simple_tuple_get_make_tuple(): remove_letified_make_tuple_elements=False, flags=CollapseTuple.Flag.COLLAPSE_TUPLE_GET_MAKE_TUPLE, allow_undeclared_symbols=True, - is_local_view=False, + within_stencil=False, ) assert expected == actual @@ -118,7 +118,7 @@ def test_propagate_tuple_get(): remove_letified_make_tuple_elements=False, flags=CollapseTuple.Flag.PROPAGATE_TUPLE_GET, allow_undeclared_symbols=True, - is_local_view=False, + within_stencil=False, ) assert expected == actual @@ -136,7 +136,7 @@ def test_letify_make_tuple_elements(): remove_letified_make_tuple_elements=False, flags=CollapseTuple.Flag.LETIFY_MAKE_TUPLE_ELEMENTS, allow_undeclared_symbols=True, - is_local_view=False, + within_stencil=False, ) assert actual == expected @@ -150,7 +150,7 @@ def test_letify_make_tuple_with_trivial_elements(): remove_letified_make_tuple_elements=False, flags=CollapseTuple.Flag.LETIFY_MAKE_TUPLE_ELEMENTS, allow_undeclared_symbols=True, - is_local_view=False, + within_stencil=False, ) assert actual == expected @@ -164,7 +164,7 @@ def test_inline_trivial_make_tuple(): remove_letified_make_tuple_elements=False, 
flags=CollapseTuple.Flag.INLINE_TRIVIAL_MAKE_TUPLE, allow_undeclared_symbols=True, - is_local_view=False, + within_stencil=False, ) assert actual == expected @@ -183,7 +183,7 @@ def test_propagate_to_if_on_tuples(): remove_letified_make_tuple_elements=False, flags=CollapseTuple.Flag.PROPAGATE_TO_IF_ON_TUPLES, allow_undeclared_symbols=True, - is_local_view=False, + within_stencil=False, ) assert actual == expected @@ -201,7 +201,7 @@ def test_propagate_to_if_on_tuples_with_let(): flags=CollapseTuple.Flag.PROPAGATE_TO_IF_ON_TUPLES | CollapseTuple.Flag.LETIFY_MAKE_TUPLE_ELEMENTS, allow_undeclared_symbols=True, - is_local_view=False, + within_stencil=False, ) assert actual == expected @@ -214,7 +214,7 @@ def test_propagate_nested_lift(): remove_letified_make_tuple_elements=False, flags=CollapseTuple.Flag.PROPAGATE_NESTED_LET, allow_undeclared_symbols=True, - is_local_view=False, + within_stencil=False, ) assert actual == expected @@ -228,7 +228,7 @@ def test_if_on_tuples_with_let(): testee, remove_letified_make_tuple_elements=False, allow_undeclared_symbols=True, - is_local_view=False, + within_stencil=False, ) assert actual == expected @@ -237,5 +237,5 @@ def test_tuple_get_on_untyped_ref(): # test pass gracefully handles untyped nodes. testee = im.tuple_get(0, im.ref("val", ts.DeferredType(constraint=None))) - actual = CollapseTuple.apply(testee, allow_undeclared_symbols=True, is_local_view=False) + actual = CollapseTuple.apply(testee, allow_undeclared_symbols=True, within_stencil=False) assert actual == testee diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_cse.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_cse.py index 3204b49371..e04856b75f 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_cse.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_cse.py @@ -37,7 +37,7 @@ def test_trivial(): ), args=[common], ) - actual = CSE.apply(testee, is_local_view=True) + actual = CSE.apply(testee, within_stencil=True) assert actual == expected @@ -45,7 +45,7 @@ def test_lambda_capture(): common = ir.FunCall(fun=ir.SymRef(id="plus"), args=[ir.SymRef(id="x"), ir.SymRef(id="y")]) testee = ir.FunCall(fun=ir.Lambda(params=[ir.Sym(id="x")], expr=common), args=[common]) expected = testee - actual = CSE.apply(testee, is_local_view=True) + actual = CSE.apply(testee, within_stencil=True) assert actual == expected @@ -53,7 +53,7 @@ def test_lambda_no_capture(): common = im.plus("x", "y") testee = im.call(im.lambda_("z")(im.plus("x", "y")))(im.plus("x", "y")) expected = im.let("_cs_1", common)("_cs_1") - actual = CSE.apply(testee, is_local_view=True) + actual = CSE.apply(testee, within_stencil=True) assert actual == expected @@ -65,7 +65,7 @@ def common_expr(): testee = im.call(im.lambda_("x", "y")(common_expr()))(common_expr(), common_expr()) # (λ(_cs_1) → _cs_1 + _cs_1)(x + y) expected = im.let("_cs_1", common_expr())(im.plus("_cs_1", "_cs_1")) - actual = CSE.apply(testee, is_local_view=True) + actual = CSE.apply(testee, within_stencil=True) assert actual == expected @@ -79,7 +79,7 @@ def common_expr(): expected = im.lambda_("x")( im.let("_cs_1", common_expr())(im.plus("z", im.plus("_cs_1", "_cs_1"))) ) - actual = CSE.apply(testee, is_local_view=True) + actual = CSE.apply(testee, within_stencil=True) assert actual == expected @@ -93,7 +93,7 @@ def common_expr(): ) # (λ(_cs_1) → _cs_1(2) + _cs_1(3))(λ(a) → a + 1) expected = im.let("_cs_1", common_expr())(im.plus(im.call("_cs_1")(2), im.call("_cs_1")(3))) - 
actual = CSE.apply(testee, is_local_view=True) + actual = CSE.apply(testee, within_stencil=True) assert actual == expected @@ -109,7 +109,7 @@ def common_expr(): expected = im.let("_cs_1", common_expr())( im.let("_cs_2", im.call("_cs_1")(2))(im.plus("_cs_2", "_cs_2")) ) - actual = CSE.apply(testee, is_local_view=True) + actual = CSE.apply(testee, within_stencil=True) assert actual == expected @@ -133,7 +133,7 @@ def common_expr(): ) ) ) - actual = CSE.apply(testee, is_local_view=True) + actual = CSE.apply(testee, within_stencil=True) assert actual == expected @@ -157,7 +157,7 @@ def test_if_can_deref_no_extraction(offset_provider): ) ) - actual = CSE.apply(testee, offset_provider=offset_provider, is_local_view=True) + actual = CSE.apply(testee, offset_provider=offset_provider, within_stencil=True) assert actual == expected @@ -178,7 +178,7 @@ def test_if_can_deref_eligible_extraction(offset_provider): ) ) - actual = CSE.apply(testee, offset_provider=offset_provider, is_local_view=True) + actual = CSE.apply(testee, offset_provider=offset_provider, within_stencil=True) assert actual == expected @@ -191,7 +191,7 @@ def test_if_eligible_extraction(offset_provider): # (λ(_cs_1) → if _cs_1 ∧ _cs_1 then c else d)(a ∧ b) expected = im.let("_cs_1", im.and_("a", "b"))(im.if_(im.and_("_cs_1", "_cs_1"), "c", "d")) - actual = CSE.apply(testee, offset_provider=offset_provider, is_local_view=True) + actual = CSE.apply(testee, offset_provider=offset_provider, within_stencil=True) assert actual == expected @@ -268,7 +268,7 @@ def test_no_extraction_outside_asfieldop(): identity_fieldop(im.ref("a", field_type)), identity_fieldop(im.ref("b", field_type)) ) - actual = CSE.apply(testee, is_local_view=False) + actual = CSE.apply(testee, within_stencil=False) assert actual == testee @@ -289,5 +289,5 @@ def test_field_extraction_outside_asfieldop(): # ) expected = im.let("_cs_1", identity_fieldop(field))(plus_fieldop("_cs_1", "_cs_1")) - actual = CSE.apply(testee, is_local_view=False) + actual = CSE.apply(testee, within_stencil=False) assert actual == expected diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_unroll_reduce.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_unroll_reduce.py index 09ed204a91..28bd88b853 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_unroll_reduce.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_unroll_reduce.py @@ -26,93 +26,35 @@ def has_skip_values(request): @pytest.fixture def basic_reduction(): UIDs.reset_sequence() - return ir.FunCall( - fun=ir.FunCall( - fun=ir.SymRef(id="reduce"), - args=[ir.SymRef(id="foo"), im.literal("0.0", "float64")], - ), - args=[ - ir.FunCall( - fun=ir.SymRef(id="neighbors"), - args=[ir.OffsetLiteral(value="Dim"), ir.SymRef(id="x")], - ) - ], - ) + return im.call(im.call("reduce")("foo", 0.0))(im.neighbors("Dim", "x")) @pytest.fixture def reduction_with_shift_on_second_arg(): UIDs.reset_sequence() - return ir.FunCall( - fun=ir.FunCall( - fun=ir.SymRef(id="reduce"), - args=[ir.SymRef(id="foo"), im.literal("0.0", "float64")], - ), - args=[ - ir.SymRef(id="x"), - ir.FunCall( - fun=ir.SymRef(id="neighbors"), - args=[ir.OffsetLiteral(value="Dim"), ir.SymRef(id="y")], - ), - ], - ) + return im.call(im.call("reduce")("foo", 0.0))("x", im.neighbors("Dim", "y")) @pytest.fixture def reduction_with_incompatible_shifts(): UIDs.reset_sequence() - return ir.FunCall( - fun=ir.FunCall( - fun=ir.SymRef(id="reduce"), - args=[ir.SymRef(id="foo"), im.literal("0.0", 
"float64")], - ), - args=[ - ir.FunCall( - fun=ir.SymRef(id="neighbors"), - args=[ir.OffsetLiteral(value="Dim"), ir.SymRef(id="x")], - ), - ir.FunCall( - fun=ir.SymRef(id="neighbors"), - args=[ir.OffsetLiteral(value="Dim2"), ir.SymRef(id="y")], - ), - ], + return im.call(im.call("reduce")("foo", 0.0))( + im.neighbors("Dim", "x"), im.neighbors("Dim2", "y") ) @pytest.fixture def reduction_with_irrelevant_full_shift(): UIDs.reset_sequence() - return ir.FunCall( - fun=ir.FunCall( - fun=ir.SymRef(id="reduce"), - args=[ir.SymRef(id="foo"), im.literal("0.0", "float64")], - ), - args=[ - ir.FunCall( - fun=ir.SymRef(id="neighbors"), - args=[ - ir.OffsetLiteral(value="Dim"), - ir.FunCall( - fun=ir.FunCall( - fun=ir.SymRef(id="shift"), - args=[ - ir.OffsetLiteral(value="IrrelevantDim"), - ir.OffsetLiteral(value="0"), - ], - ), - args=[ir.SymRef(id="x")], - ), - ], - ), - ir.FunCall( - fun=ir.SymRef(id="neighbors"), - args=[ir.OffsetLiteral(value="Dim"), ir.SymRef(id="y")], - ), - ], + return im.call(im.call("reduce")("foo", 0.0))( + im.neighbors("Dim", im.shift("IrrelevantDim", 0)("x")), im.neighbors("Dim", "y") ) -# TODO add a test with lift +@pytest.fixture +def reduction_if(): + UIDs.reset_sequence() + return im.call(im.call("reduce")("foo", 0.0))(im.if_(True, im.neighbors("Dim", "x"), "y")) @pytest.mark.parametrize( @@ -121,6 +63,7 @@ def reduction_with_irrelevant_full_shift(): "basic_reduction", "reduction_with_irrelevant_full_shift", "reduction_with_shift_on_second_arg", + "reduction_if", ], ) def test_get_partial_offsets(reduction, request): @@ -178,6 +121,14 @@ def test_reduction_with_shift_on_second_arg(reduction_with_shift_on_second_arg, assert actual == expected +def test_reduction_with_if(reduction_if): + expected = _expected(reduction_if, "Dim", 2, False) + + offset_provider = {"Dim": DummyConnectivity(max_neighbors=2, has_skip_values=False)} + actual = UnrollReduce.apply(reduction_if, offset_provider=offset_provider) + assert actual == expected + + def test_reduction_with_irrelevant_full_shift(reduction_with_irrelevant_full_shift): expected = _expected(reduction_with_irrelevant_full_shift, "Dim", 3, False) From 8e2ba0ca9577726656d485aff3da47519c9f0cf4 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Fri, 8 Nov 2024 14:45:42 +0100 Subject: [PATCH 076/150] Revert CI changes & skip failing test --- ci/base.Dockerfile | 6 +++--- ci/cscs-ci.yml | 4 ++-- .../ffront_tests/test_temporaries_with_sizes.py | 8 ++++++++ 3 files changed, 13 insertions(+), 5 deletions(-) diff --git a/ci/base.Dockerfile b/ci/base.Dockerfile index eb7e0b09ea..d20d9ca6ef 100644 --- a/ci/base.Dockerfile +++ b/ci/base.Dockerfile @@ -1,5 +1,5 @@ -ARG CUDA_VERSION=12.6.2 -FROM docker.io/nvidia/cuda:${CUDA_VERSION}-devel-ubuntu22.04 +ARG CUDA_VERSION=12.5.0 +FROM docker.io/nvidia/cuda:${CUDA_VERSION}-devel-ubuntu20.04 ENV LANG C.UTF-8 ENV LC_ALL C.UTF-8 @@ -22,7 +22,7 @@ RUN apt-get update -qq && apt-get install -qq -y --no-install-recommends \ tk-dev \ libffi-dev \ liblzma-dev \ - python3-openssl \ + python-openssl \ libreadline-dev \ git \ rustc \ diff --git a/ci/cscs-ci.yml b/ci/cscs-ci.yml index e2a358dac4..7fcd65106d 100644 --- a/ci/cscs-ci.yml +++ b/ci/cscs-ci.yml @@ -46,12 +46,12 @@ stages: .build_baseimage_x86_64: extends: [.container-builder-cscs-zen2, .build_baseimage] variables: - CUDA_VERSION: 11.7.1 + CUDA_VERSION: 11.2.2 CUPY_PACKAGE: cupy-cuda11x .build_baseimage_aarch64: extends: [.container-builder-cscs-gh200, .build_baseimage] variables: - CUDA_VERSION: 12.6.2 + CUDA_VERSION: 12.4.1 CUPY_PACKAGE: 
cupy-cuda12x # TODO: enable CI job when Todi is back in operational state when: manual diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py index bb02f0a89e..b2fb454932 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py @@ -5,6 +5,7 @@ # # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause +import platform import pytest from numpy import int32, int64 @@ -64,6 +65,13 @@ def prog( def test_verification(testee, run_gtfn_with_temporaries_and_symbolic_sizes, mesh_descriptor): + if platform.machine() == 'x86_64': + pytest.xfail(reason="The C++ code generated in this test contains unicode characters " + "(coming from the ssa pass) which is not supported by gcc 9 used" + "in the CI. Bumping the container version sadly did not work for" + "unrelated and unclear reasons. Since the issue is not present" + "on Alps we just skip the test for now before investing more time.") + unstructured_case = Case( run_gtfn_with_temporaries_and_symbolic_sizes, offset_provider=mesh_descriptor.offset_provider, From 580cb79a36def1fd4549b4fdc38ad0a3bcfcd6de Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Fri, 8 Nov 2024 15:04:28 +0100 Subject: [PATCH 077/150] Fix format --- .../ffront_tests/test_temporaries_with_sizes.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py index b2fb454932..11e28de9e1 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py @@ -65,12 +65,14 @@ def prog( def test_verification(testee, run_gtfn_with_temporaries_and_symbolic_sizes, mesh_descriptor): - if platform.machine() == 'x86_64': - pytest.xfail(reason="The C++ code generated in this test contains unicode characters " - "(coming from the ssa pass) which is not supported by gcc 9 used" - "in the CI. Bumping the container version sadly did not work for" - "unrelated and unclear reasons. Since the issue is not present" - "on Alps we just skip the test for now before investing more time.") + if platform.machine() == "x86_64": + pytest.xfail( + reason="The C++ code generated in this test contains unicode characters " + "(coming from the ssa pass) which is not supported by gcc 9 used" + "in the CI. Bumping the container version sadly did not work for" + "unrelated and unclear reasons. Since the issue is not present" + "on Alps we just skip the test for now before investing more time." 
+ ) unstructured_case = Case( run_gtfn_with_temporaries_and_symbolic_sizes, From c67d355a191000793e08229ebef26a8978beb390 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Fri, 8 Nov 2024 15:29:36 +0100 Subject: [PATCH 078/150] Ugly fix --- src/gt4py/next/program_processors/runners/gtfn.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/gtfn.py b/src/gt4py/next/program_processors/runners/gtfn.py index 7f16aaf5d5..ea88ffc74c 100644 --- a/src/gt4py/next/program_processors/runners/gtfn.py +++ b/src/gt4py/next/program_processors/runners/gtfn.py @@ -166,19 +166,19 @@ class Params: cached_translation = factory.Trait( translation=factory.LazyAttribute( lambda o: workflow.CachedStep( - o.translation_, + o.uncached_translation, hash_function=fingerprint_compilable_program, cache=FileCache(str(config.BUILD_CACHE_DIR / "gtfn_cache")), ) ), ) - translation_ = factory.SubFactory( + uncached_translation = factory.SubFactory( gtfn_module.GTFNTranslationStepFactory, device_type=factory.SelfAttribute("..device_type"), ) - translation = factory.LazyAttribute(lambda o: o.translation_) + translation = factory.LazyAttribute(lambda o: o.uncached_translation) bindings: workflow.Workflow[stages.ProgramSource, stages.CompilableSource] = ( nanobind.bind_source @@ -240,4 +240,4 @@ class Params: run_gtfn_gpu_cached = GTFNBackendFactory(gpu=True, cached=True) -run_gtfn_no_transforms = GTFNBackendFactory(otf_workflow__translation__enable_itir_transforms=False) +run_gtfn_no_transforms = GTFNBackendFactory(otf_workflow__uncached_translation__enable_itir_transforms=False) \ No newline at end of file From 79ab838bb929235457186a10520261a9075e48bb Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Fri, 8 Nov 2024 15:49:59 +0100 Subject: [PATCH 079/150] Format --- src/gt4py/next/program_processors/runners/gtfn.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/gtfn.py b/src/gt4py/next/program_processors/runners/gtfn.py index ea88ffc74c..5028d73663 100644 --- a/src/gt4py/next/program_processors/runners/gtfn.py +++ b/src/gt4py/next/program_processors/runners/gtfn.py @@ -166,19 +166,19 @@ class Params: cached_translation = factory.Trait( translation=factory.LazyAttribute( lambda o: workflow.CachedStep( - o.uncached_translation, + o.bare_translation, hash_function=fingerprint_compilable_program, cache=FileCache(str(config.BUILD_CACHE_DIR / "gtfn_cache")), ) ), ) - uncached_translation = factory.SubFactory( + bare_translation = factory.SubFactory( gtfn_module.GTFNTranslationStepFactory, device_type=factory.SelfAttribute("..device_type"), ) - translation = factory.LazyAttribute(lambda o: o.uncached_translation) + translation = factory.LazyAttribute(lambda o: o.bare_translation) bindings: workflow.Workflow[stages.ProgramSource, stages.CompilableSource] = ( nanobind.bind_source @@ -240,4 +240,6 @@ class Params: run_gtfn_gpu_cached = GTFNBackendFactory(gpu=True, cached=True) -run_gtfn_no_transforms = GTFNBackendFactory(otf_workflow__uncached_translation__enable_itir_transforms=False) \ No newline at end of file +run_gtfn_no_transforms = GTFNBackendFactory( + otf_workflow__uncached_translation__enable_itir_transforms=False +) From d53fda93a7dca7567156867ca9096feb12c015f6 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Fri, 8 Nov 2024 15:57:09 +0100 Subject: [PATCH 080/150] Small fix --- src/gt4py/next/program_processors/runners/gtfn.py | 2 +- 1 file changed, 1 insertion(+), 1 
deletion(-) diff --git a/src/gt4py/next/program_processors/runners/gtfn.py b/src/gt4py/next/program_processors/runners/gtfn.py index 5028d73663..965c6417b2 100644 --- a/src/gt4py/next/program_processors/runners/gtfn.py +++ b/src/gt4py/next/program_processors/runners/gtfn.py @@ -241,5 +241,5 @@ class Params: run_gtfn_gpu_cached = GTFNBackendFactory(gpu=True, cached=True) run_gtfn_no_transforms = GTFNBackendFactory( - otf_workflow__uncached_translation__enable_itir_transforms=False + otf_workflow__bare_translation__enable_itir_transforms=False ) From 24c3e871628de97fc040e1b31759932f349cf103 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Mon, 11 Nov 2024 10:13:10 +0100 Subject: [PATCH 081/150] Bump gridtools-cpp to 2.3.7 in preperation of #1648 --- .pre-commit-config.yaml | 8 ++++---- constraints.txt | 16 ++++++++-------- min-extra-requirements-test.txt | 2 +- min-requirements-test.txt | 2 +- pyproject.toml | 2 +- requirements-dev.txt | 16 ++++++++-------- 6 files changed, 23 insertions(+), 23 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index f2f5b73613..93ea4685f4 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -51,7 +51,7 @@ repos: ## version = re.search('ruff==([0-9\.]*)', open("constraints.txt").read())[1] ## print(f"rev: v{version}") ##]]] - rev: v0.7.2 + rev: v0.7.3 ##[[[end]]] hooks: # Run the linter. @@ -97,14 +97,14 @@ repos: - boltons==24.1.0 - cached-property==2.0.1 - click==8.1.7 - - cmake==3.30.5 + - cmake==3.31.0.1 - cytoolz==1.0.0 - deepdiff==8.0.1 - devtools==0.12.2 - diskcache==5.6.3 - factory-boy==3.3.1 - frozendict==2.4.6 - - gridtools-cpp==2.3.6 + - gridtools-cpp==2.3.7 - importlib-resources==6.4.5 - jinja2==3.1.4 - lark==1.2.2 @@ -112,7 +112,7 @@ repos: - nanobind==2.2.0 - ninja==1.11.1.1 - numpy==1.24.4 - - packaging==24.1 + - packaging==24.2 - pybind11==2.13.6 - setuptools==75.3.0 - tabulate==0.9.0 diff --git a/constraints.txt b/constraints.txt index e7acc466cd..4aca6645d5 100644 --- a/constraints.txt +++ b/constraints.txt @@ -25,7 +25,7 @@ chardet==5.2.0 # via tox charset-normalizer==3.4.0 # via requests clang-format==19.1.3 # via -r requirements-dev.in, gt4py (pyproject.toml) click==8.1.7 # via black, bump-my-version, gt4py (pyproject.toml), pip-tools, rich-click -cmake==3.30.5 # via gt4py (pyproject.toml) +cmake==3.31.0.1 # via gt4py (pyproject.toml) cogapp==3.4.1 # via -r requirements-dev.in colorama==0.4.6 # via tox comm==0.2.2 # via ipykernel @@ -35,7 +35,7 @@ cycler==0.12.1 # via matplotlib cytoolz==1.0.0 # via gt4py (pyproject.toml) dace==0.16.1 # via gt4py (pyproject.toml) darglint==1.8.1 # via -r requirements-dev.in -debugpy==1.8.7 # via ipykernel +debugpy==1.8.8 # via ipykernel decorator==5.1.1 # via ipython deepdiff==8.0.1 # via gt4py (pyproject.toml) devtools==0.12.2 # via gt4py (pyproject.toml) @@ -55,7 +55,7 @@ fparser==0.1.4 # via dace frozendict==2.4.6 # via gt4py (pyproject.toml) gitdb==4.0.11 # via gitpython gitpython==3.1.43 # via tach -gridtools-cpp==2.3.6 # via gt4py (pyproject.toml) +gridtools-cpp==2.3.7 # via gt4py (pyproject.toml) hypothesis==6.113.0 # via -r requirements-dev.in, gt4py (pyproject.toml) identify==2.6.1 # via pre-commit idna==3.10 # via requests @@ -66,7 +66,7 @@ inflection==0.5.1 # via pytest-factoryboy iniconfig==2.0.0 # via pytest ipykernel==6.29.5 # via nbmake ipython==8.12.3 # via ipykernel -jedi==0.19.1 # via ipython +jedi==0.19.2 # via ipython jinja2==3.1.4 # via dace, gt4py (pyproject.toml), sphinx jsonschema==4.23.0 # via nbformat 
jsonschema-specifications==2023.12.1 # via jsonschema @@ -95,7 +95,7 @@ ninja==1.11.1.1 # via gt4py (pyproject.toml) nodeenv==1.9.1 # via pre-commit numpy==1.24.4 # via contourpy, dace, gt4py (pyproject.toml), matplotlib, scipy orderly-set==5.2.2 # via deepdiff -packaging==24.1 # via black, build, gt4py (pyproject.toml), ipykernel, jupytext, matplotlib, pipdeptree, pyproject-api, pytest, pytest-factoryboy, setuptools-scm, sphinx, tox +packaging==24.2 # via black, build, gt4py (pyproject.toml), ipykernel, jupytext, matplotlib, pipdeptree, pyproject-api, pytest, pytest-factoryboy, setuptools-scm, sphinx, tox parso==0.8.4 # via jedi pathspec==0.12.1 # via black pexpect==4.9.0 # via ipython @@ -139,7 +139,7 @@ requests==2.32.3 # via sphinx rich==13.9.4 # via bump-my-version, rich-click, tach rich-click==1.8.3 # via bump-my-version rpds-py==0.20.1 # via jsonschema, referencing -ruff==0.7.2 # via -r requirements-dev.in +ruff==0.7.3 # via -r requirements-dev.in scipy==1.10.1 # via gt4py (pyproject.toml) setuptools-scm==8.1.0 # via fparser six==1.16.0 # via asttokens, astunparse, python-dateutil @@ -159,7 +159,7 @@ stack-data==0.6.3 # via ipython stdlib-list==0.10.0 # via tach sympy==1.12.1 # via dace, gt4py (pyproject.toml) tabulate==0.9.0 # via gt4py (pyproject.toml) -tach==0.14.2 # via -r requirements-dev.in +tach==0.14.3 # via -r requirements-dev.in tomli==2.0.2 ; python_version < "3.11" # via -r requirements-dev.in, black, build, coverage, jupytext, mypy, pip-tools, pyproject-api, pytest, setuptools-scm, tach, tox tomli-w==1.0.0 # via tach tomlkit==0.13.2 # via bump-my-version @@ -174,7 +174,7 @@ virtualenv==20.27.1 # via pre-commit, tox wcmatch==10.0 # via bump-my-version wcwidth==0.2.13 # via prompt-toolkit websockets==13.1 # via dace -wheel==0.44.0 # via astunparse, pip-tools +wheel==0.45.0 # via astunparse, pip-tools xxhash==3.0.0 # via gt4py (pyproject.toml) zipp==3.20.2 # via importlib-metadata, importlib-resources diff --git a/min-extra-requirements-test.txt b/min-extra-requirements-test.txt index f63042906c..6fd3d1af55 100644 --- a/min-extra-requirements-test.txt +++ b/min-extra-requirements-test.txt @@ -68,7 +68,7 @@ devtools==0.6 diskcache==5.6.3 factory-boy==3.3.0 frozendict==2.3 -gridtools-cpp==2.3.6 +gridtools-cpp==2.3.7 hypothesis==6.0.0 importlib-resources==5.0; python_version < "3.9" jax[cpu]==0.4.18; python_version >= "3.10" diff --git a/min-requirements-test.txt b/min-requirements-test.txt index 666aa79107..b8779096c0 100644 --- a/min-requirements-test.txt +++ b/min-requirements-test.txt @@ -64,7 +64,7 @@ devtools==0.6 diskcache==5.6.3 factory-boy==3.3.0 frozendict==2.3 -gridtools-cpp==2.3.6 +gridtools-cpp==2.3.7 hypothesis==6.0.0 importlib-resources==5.0; python_version < "3.9" jinja2==3.0.0 diff --git a/pyproject.toml b/pyproject.toml index c9f7b3b50b..7d63f70f15 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,7 +37,7 @@ dependencies = [ 'diskcache>=5.6.3', 'factory-boy>=3.3.0', 'frozendict>=2.3', - 'gridtools-cpp>=2.3.6,==2.*', + 'gridtools-cpp>=2.3.7,==2.*', "importlib-resources>=5.0;python_version<'3.9'", 'jinja2>=3.0.0', 'lark>=1.1.2', diff --git a/requirements-dev.txt b/requirements-dev.txt index a036307e80..8892620786 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -25,7 +25,7 @@ chardet==5.2.0 # via -c constraints.txt, tox charset-normalizer==3.4.0 # via -c constraints.txt, requests clang-format==19.1.3 # via -c constraints.txt, -r requirements-dev.in, gt4py (pyproject.toml) click==8.1.7 # via -c constraints.txt, black, bump-my-version, gt4py 
(pyproject.toml), pip-tools, rich-click -cmake==3.30.5 # via -c constraints.txt, gt4py (pyproject.toml) +cmake==3.31.0.1 # via -c constraints.txt, gt4py (pyproject.toml) cogapp==3.4.1 # via -c constraints.txt, -r requirements-dev.in colorama==0.4.6 # via -c constraints.txt, tox comm==0.2.2 # via -c constraints.txt, ipykernel @@ -35,7 +35,7 @@ cycler==0.12.1 # via -c constraints.txt, matplotlib cytoolz==1.0.0 # via -c constraints.txt, gt4py (pyproject.toml) dace==0.16.1 # via -c constraints.txt, gt4py (pyproject.toml) darglint==1.8.1 # via -c constraints.txt, -r requirements-dev.in -debugpy==1.8.7 # via -c constraints.txt, ipykernel +debugpy==1.8.8 # via -c constraints.txt, ipykernel decorator==5.1.1 # via -c constraints.txt, ipython deepdiff==8.0.1 # via -c constraints.txt, gt4py (pyproject.toml) devtools==0.12.2 # via -c constraints.txt, gt4py (pyproject.toml) @@ -55,7 +55,7 @@ fparser==0.1.4 # via -c constraints.txt, dace frozendict==2.4.6 # via -c constraints.txt, gt4py (pyproject.toml) gitdb==4.0.11 # via -c constraints.txt, gitpython gitpython==3.1.43 # via -c constraints.txt, tach -gridtools-cpp==2.3.6 # via -c constraints.txt, gt4py (pyproject.toml) +gridtools-cpp==2.3.7 # via -c constraints.txt, gt4py (pyproject.toml) hypothesis==6.113.0 # via -c constraints.txt, -r requirements-dev.in, gt4py (pyproject.toml) identify==2.6.1 # via -c constraints.txt, pre-commit idna==3.10 # via -c constraints.txt, requests @@ -66,7 +66,7 @@ inflection==0.5.1 # via -c constraints.txt, pytest-factoryboy iniconfig==2.0.0 # via -c constraints.txt, pytest ipykernel==6.29.5 # via -c constraints.txt, nbmake ipython==8.12.3 # via -c constraints.txt, ipykernel -jedi==0.19.1 # via -c constraints.txt, ipython +jedi==0.19.2 # via -c constraints.txt, ipython jinja2==3.1.4 # via -c constraints.txt, dace, gt4py (pyproject.toml), sphinx jsonschema==4.23.0 # via -c constraints.txt, nbformat jsonschema-specifications==2023.12.1 # via -c constraints.txt, jsonschema @@ -95,7 +95,7 @@ ninja==1.11.1.1 # via -c constraints.txt, gt4py (pyproject.toml) nodeenv==1.9.1 # via -c constraints.txt, pre-commit numpy==1.24.4 # via -c constraints.txt, contourpy, dace, gt4py (pyproject.toml), matplotlib orderly-set==5.2.2 # via -c constraints.txt, deepdiff -packaging==24.1 # via -c constraints.txt, black, build, gt4py (pyproject.toml), ipykernel, jupytext, matplotlib, pipdeptree, pyproject-api, pytest, pytest-factoryboy, setuptools-scm, sphinx, tox +packaging==24.2 # via -c constraints.txt, black, build, gt4py (pyproject.toml), ipykernel, jupytext, matplotlib, pipdeptree, pyproject-api, pytest, pytest-factoryboy, setuptools-scm, sphinx, tox parso==0.8.4 # via -c constraints.txt, jedi pathspec==0.12.1 # via -c constraints.txt, black pexpect==4.9.0 # via -c constraints.txt, ipython @@ -139,7 +139,7 @@ requests==2.32.3 # via -c constraints.txt, sphinx rich==13.9.4 # via -c constraints.txt, bump-my-version, rich-click, tach rich-click==1.8.3 # via -c constraints.txt, bump-my-version rpds-py==0.20.1 # via -c constraints.txt, jsonschema, referencing -ruff==0.7.2 # via -c constraints.txt, -r requirements-dev.in +ruff==0.7.3 # via -c constraints.txt, -r requirements-dev.in setuptools-scm==8.1.0 # via -c constraints.txt, fparser six==1.16.0 # via -c constraints.txt, asttokens, astunparse, python-dateutil smmap==5.0.1 # via -c constraints.txt, gitdb @@ -158,7 +158,7 @@ stack-data==0.6.3 # via -c constraints.txt, ipython stdlib-list==0.10.0 # via -c constraints.txt, tach sympy==1.12.1 # via -c constraints.txt, dace, gt4py (pyproject.toml) 
tabulate==0.9.0 # via -c constraints.txt, gt4py (pyproject.toml) -tach==0.14.2 # via -c constraints.txt, -r requirements-dev.in +tach==0.14.3 # via -c constraints.txt, -r requirements-dev.in tomli==2.0.2 ; python_version < "3.11" # via -c constraints.txt, -r requirements-dev.in, black, build, coverage, jupytext, mypy, pip-tools, pyproject-api, pytest, setuptools-scm, tach, tox tomli-w==1.0.0 # via -c constraints.txt, tach tomlkit==0.13.2 # via -c constraints.txt, bump-my-version @@ -173,7 +173,7 @@ virtualenv==20.27.1 # via -c constraints.txt, pre-commit, tox wcmatch==10.0 # via -c constraints.txt, bump-my-version wcwidth==0.2.13 # via -c constraints.txt, prompt-toolkit websockets==13.1 # via -c constraints.txt, dace -wheel==0.44.0 # via -c constraints.txt, astunparse, pip-tools +wheel==0.45.0 # via -c constraints.txt, astunparse, pip-tools xxhash==3.0.0 # via -c constraints.txt, gt4py (pyproject.toml) zipp==3.20.2 # via -c constraints.txt, importlib-metadata, importlib-resources From 0faa7efe60bd4cae22c7e3b91384b4ec6bcab48f Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Mon, 11 Nov 2024 12:02:08 +0100 Subject: [PATCH 082/150] Test tuple fix in gridtools --- pyproject.toml | 2 +- requirements-dev.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 7d63f70f15..d665eb119e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,7 +37,7 @@ dependencies = [ 'diskcache>=5.6.3', 'factory-boy>=3.3.0', 'frozendict>=2.3', - 'gridtools-cpp>=2.3.7,==2.*', + #'gridtools-cpp>=2.3.7,==2.*', "importlib-resources>=5.0;python_version<'3.9'", 'jinja2>=3.0.0', 'lark>=1.1.2', diff --git a/requirements-dev.txt b/requirements-dev.txt index 8892620786..3a65ebf5da 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -55,7 +55,7 @@ fparser==0.1.4 # via -c constraints.txt, dace frozendict==2.4.6 # via -c constraints.txt, gt4py (pyproject.toml) gitdb==4.0.11 # via -c constraints.txt, gitpython gitpython==3.1.43 # via -c constraints.txt, tach -gridtools-cpp==2.3.7 # via -c constraints.txt, gt4py (pyproject.toml) +gridtools-cpp@git+https://github.com/havogt/gridtools@fix_tuple_copy_assignment_refs#subdirectory=.python_package, # via -c constraints.txt, gt4py (pyproject.toml) hypothesis==6.113.0 # via -c constraints.txt, -r requirements-dev.in, gt4py (pyproject.toml) identify==2.6.1 # via -c constraints.txt, pre-commit idna==3.10 # via -c constraints.txt, requests From 8e8a0a151bf6a5cd49134348513fb46924c7c292 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Mon, 11 Nov 2024 12:04:56 +0100 Subject: [PATCH 083/150] Fix typo --- requirements-dev.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements-dev.txt b/requirements-dev.txt index 3a65ebf5da..6a1af877f0 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -55,7 +55,7 @@ fparser==0.1.4 # via -c constraints.txt, dace frozendict==2.4.6 # via -c constraints.txt, gt4py (pyproject.toml) gitdb==4.0.11 # via -c constraints.txt, gitpython gitpython==3.1.43 # via -c constraints.txt, tach -gridtools-cpp@git+https://github.com/havogt/gridtools@fix_tuple_copy_assignment_refs#subdirectory=.python_package, # via -c constraints.txt, gt4py (pyproject.toml) +gridtools-cpp@git+https://github.com/havogt/gridtools@fix_tuple_copy_assignment_refs#subdirectory=.python_package # via -c constraints.txt, gt4py (pyproject.toml) hypothesis==6.113.0 # via -c constraints.txt, -r requirements-dev.in, gt4py (pyproject.toml) identify==2.6.1 # via -c constraints.txt, pre-commit 
idna==3.10 # via -c constraints.txt, requests From e03dd387cfb5f7f20ac8d1c1067b3da15ad72c59 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Tue, 12 Nov 2024 09:59:52 +0100 Subject: [PATCH 084/150] Fix ITIR program hash stability --- src/gt4py/next/iterator/ir.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/gt4py/next/iterator/ir.py b/src/gt4py/next/iterator/ir.py index f50d8080eb..7098e9fa2e 100644 --- a/src/gt4py/next/iterator/ir.py +++ b/src/gt4py/next/iterator/ir.py @@ -242,7 +242,9 @@ class Program(Node, ValidatedSymbolTableTrait): body: List[Stmt] implicit_domain: bool = False - _NODE_SYMBOLS_: ClassVar[List[Sym]] = [Sym(id=name) for name in GTIR_BUILTINS] + _NODE_SYMBOLS_: ClassVar[List[Sym]] = [ + Sym(id=name) for name in sorted(GTIR_BUILTINS) + ] # sorted for serialization stability # TODO(fthaler): just use hashable types in nodes (tuples instead of lists) From 405cbb05208e9fcb74f94aa3d201438f27753b68 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Wed, 13 Nov 2024 17:00:04 +0100 Subject: [PATCH 085/150] Revert "Fix typo" This reverts commit 8e8a0a151bf6a5cd49134348513fb46924c7c292. --- requirements-dev.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements-dev.txt b/requirements-dev.txt index 6a1af877f0..3a65ebf5da 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -55,7 +55,7 @@ fparser==0.1.4 # via -c constraints.txt, dace frozendict==2.4.6 # via -c constraints.txt, gt4py (pyproject.toml) gitdb==4.0.11 # via -c constraints.txt, gitpython gitpython==3.1.43 # via -c constraints.txt, tach -gridtools-cpp@git+https://github.com/havogt/gridtools@fix_tuple_copy_assignment_refs#subdirectory=.python_package # via -c constraints.txt, gt4py (pyproject.toml) +gridtools-cpp@git+https://github.com/havogt/gridtools@fix_tuple_copy_assignment_refs#subdirectory=.python_package, # via -c constraints.txt, gt4py (pyproject.toml) hypothesis==6.113.0 # via -c constraints.txt, -r requirements-dev.in, gt4py (pyproject.toml) identify==2.6.1 # via -c constraints.txt, pre-commit idna==3.10 # via -c constraints.txt, requests From 4c272792d6d2ecae11e67c25f4ac4f9cc2c3921f Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Wed, 13 Nov 2024 17:00:08 +0100 Subject: [PATCH 086/150] Revert "Test tuple fix in gridtools" This reverts commit 0faa7efe60bd4cae22c7e3b91384b4ec6bcab48f. 
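[Editor's annotation, not part of the commit message] Together with the preceding revert,
this restores the dependency setup that commits 082/083 had temporarily replaced with a
git checkout, i.e. requirements-dev.txt goes back from

    gridtools-cpp@git+https://github.com/havogt/gridtools@fix_tuple_copy_assignment_refs#subdirectory=.python_package

to the pinned release gridtools-cpp==2.3.7 (and pyproject.toml re-enables the
'gridtools-cpp>=2.3.7,==2.*' constraint), before the next commit trials gridtools-cpp
2.3.8 from the test.pypi extra index (--extra-index-url https://test.pypi.org/simple/).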
--- pyproject.toml | 2 +- requirements-dev.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index d665eb119e..7d63f70f15 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,7 +37,7 @@ dependencies = [ 'diskcache>=5.6.3', 'factory-boy>=3.3.0', 'frozendict>=2.3', - #'gridtools-cpp>=2.3.7,==2.*', + 'gridtools-cpp>=2.3.7,==2.*', "importlib-resources>=5.0;python_version<'3.9'", 'jinja2>=3.0.0', 'lark>=1.1.2', diff --git a/requirements-dev.txt b/requirements-dev.txt index 3a65ebf5da..8892620786 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -55,7 +55,7 @@ fparser==0.1.4 # via -c constraints.txt, dace frozendict==2.4.6 # via -c constraints.txt, gt4py (pyproject.toml) gitdb==4.0.11 # via -c constraints.txt, gitpython gitpython==3.1.43 # via -c constraints.txt, tach -gridtools-cpp@git+https://github.com/havogt/gridtools@fix_tuple_copy_assignment_refs#subdirectory=.python_package, # via -c constraints.txt, gt4py (pyproject.toml) +gridtools-cpp==2.3.7 # via -c constraints.txt, gt4py (pyproject.toml) hypothesis==6.113.0 # via -c constraints.txt, -r requirements-dev.in, gt4py (pyproject.toml) identify==2.6.1 # via -c constraints.txt, pre-commit idna==3.10 # via -c constraints.txt, requests From a666eef89542a8eb51aefd235435f747e1a0c654 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Wed, 13 Nov 2024 17:01:54 +0100 Subject: [PATCH 087/150] Test 2.3.8 --- pyproject.toml | 2 +- requirements-dev.txt | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 7d63f70f15..1504c8b17b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,7 +37,7 @@ dependencies = [ 'diskcache>=5.6.3', 'factory-boy>=3.3.0', 'frozendict>=2.3', - 'gridtools-cpp>=2.3.7,==2.*', + 'gridtools-cpp>=2.3.8,==2.*', "importlib-resources>=5.0;python_version<'3.9'", 'jinja2>=3.0.0', 'lark>=1.1.2', diff --git a/requirements-dev.txt b/requirements-dev.txt index 8892620786..039da03a6f 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -55,7 +55,6 @@ fparser==0.1.4 # via -c constraints.txt, dace frozendict==2.4.6 # via -c constraints.txt, gt4py (pyproject.toml) gitdb==4.0.11 # via -c constraints.txt, gitpython gitpython==3.1.43 # via -c constraints.txt, tach -gridtools-cpp==2.3.7 # via -c constraints.txt, gt4py (pyproject.toml) hypothesis==6.113.0 # via -c constraints.txt, -r requirements-dev.in, gt4py (pyproject.toml) identify==2.6.1 # via -c constraints.txt, pre-commit idna==3.10 # via -c constraints.txt, requests @@ -180,3 +179,6 @@ zipp==3.20.2 # via -c constraints.txt, importlib-metadata, importli # The following packages are considered to be unsafe in a requirements file: pip==24.3.1 # via -c constraints.txt, pip-tools, pipdeptree setuptools==75.3.0 # via -c constraints.txt, gt4py (pyproject.toml), pip-tools, setuptools-scm + +--extra-index-url https://test.pypi.org/simple/ +gridtools-cpp==2.3.8 # via -c constraints.txt, gt4py (pyproject.toml) \ No newline at end of file From 8402e4ffe68c1bdd1d443b79c5a91a6c4ca108bd Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Wed, 13 Nov 2024 20:41:04 +0100 Subject: [PATCH 088/150] Fix type preservation in CSE --- src/gt4py/next/iterator/transforms/cse.py | 14 ++++++++------ .../next/iterator/transforms/inline_lambdas.py | 5 ++++- src/gt4py/next/iterator/type_system/inference.py | 6 +++--- 3 files changed, 15 insertions(+), 10 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/cse.py b/src/gt4py/next/iterator/transforms/cse.py index 
ccc1d2195f..5e3b77b062 100644 --- a/src/gt4py/next/iterator/transforms/cse.py +++ b/src/gt4py/next/iterator/transforms/cse.py @@ -14,6 +14,7 @@ import operator from typing import Callable, Iterable, TypeVar, Union, cast +import gt4py.next.iterator.ir_utils.ir_makers as im from gt4py.eve import ( NodeTranslator, NodeVisitor, @@ -241,7 +242,6 @@ def extract_subexpression( Examples: Default case for `(x+y) + ((x+y)+z)`: - >>> import gt4py.next.iterator.ir_utils.ir_makers as im >>> from gt4py.eve.utils import UIDGenerator >>> expr = im.plus(im.plus("x", "y"), im.plus(im.plus("x", "y"), "z")) >>> predicate = lambda subexpr, num_occurences: num_occurences > 1 @@ -433,10 +433,14 @@ def predicate(subexpr: itir.Expr, num_occurences: int): if num_occurences > 1: if is_local_view: return True - else: + # condition is only necessary since typing on lambdas is not preserved during + # the pass + elif not isinstance(subexpr, itir.Lambda): # only extract fields outside of `as_fieldop` # `as_fieldop(...)(field_expr, field_expr)` # -> `(λ(_cs_1) → as_fieldop(...)(_cs_1, _cs_1))(field_expr)` + # only extract if subexpression is not a trivial tuple expressions, e.g., + # `make_tuple(a, b)`, as this would result in a more costly temporary. assert isinstance(subexpr.type, ts.TypeSpec) if all( isinstance(stype, ts.FieldType) @@ -451,10 +455,8 @@ def predicate(subexpr: itir.Expr, num_occurences: int): return self.generic_visit(node, **kwargs) # apply remapping - result = itir.FunCall( - fun=itir.Lambda(params=list(extracted.keys()), expr=new_expr), - args=list(extracted.values()), - ) + result = im.let(*extracted.items())(new_expr) + itir_type_inference.copy_type(from_=node, to=result, allow_untyped=True) # if the node id is ignored (because its parent is eliminated), but it occurs # multiple times then we want to visit the final result once more. diff --git a/src/gt4py/next/iterator/transforms/inline_lambdas.py b/src/gt4py/next/iterator/transforms/inline_lambdas.py index 920d628166..399a7a3dc6 100644 --- a/src/gt4py/next/iterator/transforms/inline_lambdas.py +++ b/src/gt4py/next/iterator/transforms/inline_lambdas.py @@ -14,6 +14,7 @@ from gt4py.next.iterator.ir_utils.common_pattern_matcher import is_applied_lift from gt4py.next.iterator.transforms.remap_symbols import RemapSymbolRefs, RenameSymbols from gt4py.next.iterator.transforms.symbol_ref_utils import CountSymbolRefs +from gt4py.next.iterator.type_system import inference as itir_inference # TODO(tehrengruber): Reduce complexity of the function by removing the different options here @@ -98,7 +99,7 @@ def new_name(name): new_expr.location = node.location return new_expr else: - return ir.FunCall( + new_expr = ir.FunCall( fun=ir.Lambda( params=[ param @@ -110,6 +111,8 @@ def new_name(name): args=[arg for arg, eligible in zip(node.args, eligible_params) if not eligible], location=node.location, ) + itir_inference.copy_type(from_=node, to=new_expr, allow_untyped=True) + return new_expr @dataclasses.dataclass diff --git a/src/gt4py/next/iterator/type_system/inference.py b/src/gt4py/next/iterator/type_system/inference.py index edcb9b540c..66d8345b94 100644 --- a/src/gt4py/next/iterator/type_system/inference.py +++ b/src/gt4py/next/iterator/type_system/inference.py @@ -95,14 +95,14 @@ def _set_node_type(node: itir.Node, type_: ts.TypeSpec) -> None: node.type = type_ -def copy_type(from_: itir.Node, to: itir.Node) -> None: +def copy_type(from_: itir.Node, to: itir.Node, allow_untyped=False) -> None: """ Copy type from one node to another. 
This function mainly exists for readability reasons. """ - assert isinstance(from_.type, ts.TypeSpec) - _set_node_type(to, from_.type) + assert allow_untyped is not None or isinstance(from_.type, ts.TypeSpec) + _set_node_type(to, from_.type) # type: ignore[arg-type] def on_inferred(callback: Callable, *args: Union[ts.TypeSpec, ObservableTypeSynthesizer]) -> None: From 4547295a99b71efbe42695c627e48b879a51b103 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Wed, 13 Nov 2024 20:42:28 +0100 Subject: [PATCH 089/150] Fix type preservation in CSE --- src/gt4py/next/iterator/transforms/cse.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/cse.py b/src/gt4py/next/iterator/transforms/cse.py index 5e3b77b062..4932d376ad 100644 --- a/src/gt4py/next/iterator/transforms/cse.py +++ b/src/gt4py/next/iterator/transforms/cse.py @@ -434,13 +434,11 @@ def predicate(subexpr: itir.Expr, num_occurences: int): if is_local_view: return True # condition is only necessary since typing on lambdas is not preserved during - # the pass + # the transformation elif not isinstance(subexpr, itir.Lambda): # only extract fields outside of `as_fieldop` # `as_fieldop(...)(field_expr, field_expr)` # -> `(λ(_cs_1) → as_fieldop(...)(_cs_1, _cs_1))(field_expr)` - # only extract if subexpression is not a trivial tuple expressions, e.g., - # `make_tuple(a, b)`, as this would result in a more costly temporary. assert isinstance(subexpr.type, ts.TypeSpec) if all( isinstance(stype, ts.FieldType) From 2d6464fd989178dafd11b6897791fc77984eb910 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Thu, 14 Nov 2024 10:21:04 +0100 Subject: [PATCH 090/150] Fix test skip matrix --- tests/next_tests/definitions.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/next_tests/definitions.py b/tests/next_tests/definitions.py index b3d491f989..592177301b 100644 --- a/tests/next_tests/definitions.py +++ b/tests/next_tests/definitions.py @@ -128,7 +128,6 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): # Common list of feature markers to skip COMMON_SKIP_TEST_LIST = [ (REQUIRES_ATLAS, XFAIL, BINDINGS_UNSUPPORTED_MESSAGE), - (USES_DYNAMIC_OFFSETS, XFAIL, UNSUPPORTED_MESSAGE), (USES_APPLIED_SHIFTS, XFAIL, UNSUPPORTED_MESSAGE), (USES_NEGATIVE_MODULO, XFAIL, UNSUPPORTED_MESSAGE), (USES_REDUCTION_WITH_ONLY_SPARSE_FIELDS, XFAIL, REDUCTION_WITH_ONLY_SPARSE_FIELDS_MESSAGE), @@ -141,7 +140,6 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): ] DACE_SKIP_TEST_LIST = ( COMMON_SKIP_TEST_LIST - + DOMAIN_INFERENCE_SKIP_LIST + [ (USES_IF_STMTS, XFAIL, UNSUPPORTED_MESSAGE), (USES_SCAN_IN_FIELD_OPERATOR, XFAIL, UNSUPPORTED_MESSAGE), From 4b67a99a382a557bead1dfedf6af48be223e41ff Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Thu, 14 Nov 2024 10:23:07 +0100 Subject: [PATCH 091/150] Fix format --- tests/next_tests/definitions.py | 31 ++++++++++++++----------------- 1 file changed, 14 insertions(+), 17 deletions(-) diff --git a/tests/next_tests/definitions.py b/tests/next_tests/definitions.py index 592177301b..b0bec661f9 100644 --- a/tests/next_tests/definitions.py +++ b/tests/next_tests/definitions.py @@ -138,23 +138,20 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): (USES_DYNAMIC_OFFSETS, XFAIL, UNSUPPORTED_MESSAGE), (USES_STRIDED_NEIGHBOR_OFFSET, XFAIL, UNSUPPORTED_MESSAGE), ] -DACE_SKIP_TEST_LIST = ( - COMMON_SKIP_TEST_LIST - + [ - (USES_IF_STMTS, XFAIL, UNSUPPORTED_MESSAGE), - (USES_SCAN_IN_FIELD_OPERATOR, XFAIL, UNSUPPORTED_MESSAGE), - 
(USES_IR_IF_STMTS, XFAIL, UNSUPPORTED_MESSAGE), - (USES_SCALAR_IN_DOMAIN_AND_FO, XFAIL, UNSUPPORTED_MESSAGE), - (USES_INDEX_FIELDS, XFAIL, UNSUPPORTED_MESSAGE), - (USES_LIFT_EXPRESSIONS, XFAIL, UNSUPPORTED_MESSAGE), - (USES_ORIGIN, XFAIL, UNSUPPORTED_MESSAGE), - (USES_STRIDED_NEIGHBOR_OFFSET, XFAIL, BINDINGS_UNSUPPORTED_MESSAGE), - (USES_TUPLE_ARGS, XFAIL, UNSUPPORTED_MESSAGE), - (USES_TUPLE_RETURNS, XFAIL, UNSUPPORTED_MESSAGE), - (USES_ZERO_DIMENSIONAL_FIELDS, XFAIL, UNSUPPORTED_MESSAGE), - (STARTS_FROM_GTIR_PROGRAM, SKIP, UNSUPPORTED_MESSAGE), - ] -) +DACE_SKIP_TEST_LIST = COMMON_SKIP_TEST_LIST + [ + (USES_IF_STMTS, XFAIL, UNSUPPORTED_MESSAGE), + (USES_SCAN_IN_FIELD_OPERATOR, XFAIL, UNSUPPORTED_MESSAGE), + (USES_IR_IF_STMTS, XFAIL, UNSUPPORTED_MESSAGE), + (USES_SCALAR_IN_DOMAIN_AND_FO, XFAIL, UNSUPPORTED_MESSAGE), + (USES_INDEX_FIELDS, XFAIL, UNSUPPORTED_MESSAGE), + (USES_LIFT_EXPRESSIONS, XFAIL, UNSUPPORTED_MESSAGE), + (USES_ORIGIN, XFAIL, UNSUPPORTED_MESSAGE), + (USES_STRIDED_NEIGHBOR_OFFSET, XFAIL, BINDINGS_UNSUPPORTED_MESSAGE), + (USES_TUPLE_ARGS, XFAIL, UNSUPPORTED_MESSAGE), + (USES_TUPLE_RETURNS, XFAIL, UNSUPPORTED_MESSAGE), + (USES_ZERO_DIMENSIONAL_FIELDS, XFAIL, UNSUPPORTED_MESSAGE), + (STARTS_FROM_GTIR_PROGRAM, SKIP, UNSUPPORTED_MESSAGE), +] GTIR_DACE_SKIP_TEST_LIST = DOMAIN_INFERENCE_SKIP_LIST + [ (USES_DYNAMIC_OFFSETS, XFAIL, UNSUPPORTED_MESSAGE), (USES_INDEX_BUILTIN, XFAIL, UNSUPPORTED_MESSAGE), From af2ed5f3fc54c3746e9fc2e124aeedc02547c1ae Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Thu, 14 Nov 2024 12:46:17 +0100 Subject: [PATCH 092/150] Address review comments --- tests/next_tests/definitions.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/next_tests/definitions.py b/tests/next_tests/definitions.py index b0bec661f9..175edcb160 100644 --- a/tests/next_tests/definitions.py +++ b/tests/next_tests/definitions.py @@ -153,7 +153,6 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): (STARTS_FROM_GTIR_PROGRAM, SKIP, UNSUPPORTED_MESSAGE), ] GTIR_DACE_SKIP_TEST_LIST = DOMAIN_INFERENCE_SKIP_LIST + [ - (USES_DYNAMIC_OFFSETS, XFAIL, UNSUPPORTED_MESSAGE), (USES_INDEX_BUILTIN, XFAIL, UNSUPPORTED_MESSAGE), (USES_NEGATIVE_MODULO, XFAIL, UNSUPPORTED_MESSAGE), (USES_SCAN, XFAIL, UNSUPPORTED_MESSAGE), From bff17b2fd01a66cc9d7d04b19044a7f7ce168788 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Thu, 14 Nov 2024 17:59:17 +0100 Subject: [PATCH 093/150] Inline dynamic shifts --- .../iterator/transforms/fuse_as_fieldop.py | 207 ++++++++++-------- .../next/iterator/transforms/infer_domain.py | 103 +++++++-- .../transforms/inline_dynamic_shifts.py | 76 +++++++ .../next/iterator/transforms/pass_manager.py | 7 + tests/next_tests/definitions.py | 1 - .../test_inline_dynamic_shifts.py | 48 ++++ 6 files changed, 326 insertions(+), 116 deletions(-) create mode 100644 src/gt4py/next/iterator/transforms/inline_dynamic_shifts.py create mode 100644 tests/next_tests/unit_tests/iterator_tests/test_inline_dynamic_shifts.py diff --git a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py index da238733da..becab10853 100644 --- a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py +++ b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py @@ -67,6 +67,107 @@ def _is_tuple_expr_of_literals(expr: itir.Expr): return isinstance(expr, itir.Literal) +def _inline_as_fieldop_arg( + arg: itir.Expr, *, uids: eve_utils.UIDGenerator +) -> tuple[itir.Expr, dict[str, itir.Expr]]: + assert cpm.is_applied_as_fieldop(arg) + arg = 
_canonicalize_as_fieldop(arg) + + stencil, *_ = arg.fun.args # type: ignore[attr-defined] # ensured by `is_applied_as_fieldop` + inner_args: list[itir.Expr] = arg.args + extracted_args: dict[str, itir.Expr] = {} # mapping from outer-stencil param to arg + + stencil_params: list[itir.Sym] = [] + stencil_body: itir.Expr = stencil.expr + + for inner_param, inner_arg in zip(stencil.params, inner_args, strict=True): + if isinstance(inner_arg, itir.SymRef): + stencil_params.append(inner_param) + extracted_args[inner_arg.id] = inner_arg + elif isinstance(inner_arg, itir.Literal): + # note: only literals, not all scalar expressions are required as it doesn't make sense + # for them to be computed per grid point. + stencil_body = im.let(inner_param, im.promote_to_const_iterator(inner_arg))( + stencil_body + ) + else: + # a scalar expression, a previously not inlined `as_fieldop` call or an opaque + # expression e.g. containing a tuple + stencil_params.append(inner_param) + new_outer_stencil_param = uids.sequential_id(prefix="__iasfop") + extracted_args[new_outer_stencil_param] = inner_arg + + return im.lift(im.lambda_(*stencil_params)(stencil_body))( + *extracted_args.keys() + ), extracted_args + + +def fuse_as_fieldop( + expr: itir.Expr, eligible_args: list[bool], *, uids: eve_utils.UIDGenerator +) -> itir.Expr: + assert cpm.is_applied_as_fieldop(expr) and isinstance(expr.fun.args[0], itir.Lambda) # type: ignore[attr-defined] # ensured by is_applied_as_fieldop + + stencil: itir.Lambda = expr.fun.args[0] # type: ignore[attr-defined] # ensured by is_applied_as_fieldop + domain = expr.fun.args[1] if len(expr.fun.args) > 1 else None # type: ignore[attr-defined] # ensured by is_applied_as_fieldop + + args: list[itir.Expr] = expr.args + + new_args: dict[str, itir.Expr] = {} + new_stencil_body: itir.Expr = stencil.expr + + for eligible, stencil_param, arg in zip(eligible_args, stencil.params, args, strict=True): + if eligible: + if cpm.is_applied_as_fieldop(arg): + pass + elif cpm.is_call_to(arg, "if_"): + # TODO(tehrengruber): revisit if we want to inline if_ + type_ = arg.type + arg = im.op_as_fieldop("if_")(*arg.args) + arg.type = type_ + elif _is_tuple_expr_of_literals(arg): + arg = im.op_as_fieldop(im.lambda_()(arg))() + else: + raise NotImplementedError() + + inline_expr, extracted_args = _inline_as_fieldop_arg(arg, uids=uids) + + new_stencil_body = im.let(stencil_param, inline_expr)(new_stencil_body) + + new_args = _merge_arguments(new_args, extracted_args) + else: + # just a safety check if typing information is available + if arg.type and not isinstance(arg.type, ts.DeferredType): + assert isinstance(arg.type, ts.TypeSpec) + dtype = type_info.apply_to_primitive_constituents(type_info.extract_dtype, arg.type) + assert not isinstance(dtype, it_ts.ListType) + new_param: str + if isinstance( + arg, itir.SymRef + ): # use name from outer scope (optional, just to get a nice IR) + new_param = arg.id + new_stencil_body = im.let(stencil_param.id, arg.id)(new_stencil_body) + else: + new_param = stencil_param.id + new_args = _merge_arguments(new_args, {new_param: arg}) + + new_node = im.as_fieldop(im.lambda_(*new_args.keys())(new_stencil_body), domain)( + *new_args.values() + ) + + # simplify stencil directly to keep the tree small + new_node = inline_center_deref_lift_vars.InlineCenterDerefLiftVars.apply( + new_node + ) # to keep the tree small + new_node = inline_lambdas.InlineLambdas.apply( + new_node, opcount_preserving=True, force_inline_lift_args=True + ) + new_node = 
inline_lifts.InlineLifts().visit(new_node) + + type_inference.copy_type(from_=expr, to=new_node) + + return new_node + + @dataclasses.dataclass class FuseAsFieldOp(eve.NodeTranslator): """ @@ -97,38 +198,6 @@ class FuseAsFieldOp(eve.NodeTranslator): uids: eve_utils.UIDGenerator - def _inline_as_fieldop_arg(self, arg: itir.Expr) -> tuple[itir.Expr, dict[str, itir.Expr]]: - assert cpm.is_applied_as_fieldop(arg) - arg = _canonicalize_as_fieldop(arg) - - stencil, *_ = arg.fun.args # type: ignore[attr-defined] # ensured by `is_applied_as_fieldop` - inner_args: list[itir.Expr] = arg.args - extracted_args: dict[str, itir.Expr] = {} # mapping from outer-stencil param to arg - - stencil_params: list[itir.Sym] = [] - stencil_body: itir.Expr = stencil.expr - - for inner_param, inner_arg in zip(stencil.params, inner_args, strict=True): - if isinstance(inner_arg, itir.SymRef): - stencil_params.append(inner_param) - extracted_args[inner_arg.id] = inner_arg - elif isinstance(inner_arg, itir.Literal): - # note: only literals, not all scalar expressions are required as it doesn't make sense - # for them to be computed per grid point. - stencil_body = im.let(inner_param, im.promote_to_const_iterator(inner_arg))( - stencil_body - ) - else: - # a scalar expression, a previously not inlined `as_fieldop` call or an opaque - # expression e.g. containing a tuple - stencil_params.append(inner_param) - new_outer_stencil_param = self.uids.sequential_id(prefix="__iasfop") - extracted_args[new_outer_stencil_param] = inner_arg - - return im.lift(im.lambda_(*stencil_params)(stencil_body))( - *extracted_args.keys() - ), extracted_args - @classmethod def apply( cls, @@ -155,72 +224,26 @@ def visit_FunCall(self, node: itir.FunCall): if cpm.is_call_to(node.fun, "as_fieldop") and isinstance(node.fun.args[0], itir.Lambda): stencil: itir.Lambda = node.fun.args[0] - domain = node.fun.args[1] if len(node.fun.args) > 1 else None - - shifts = trace_shifts.trace_stencil(stencil) - args: list[itir.Expr] = node.args + shifts = trace_shifts.trace_stencil(stencil) - new_args: dict[str, itir.Expr] = {} - new_stencil_body: itir.Expr = stencil.expr - - for stencil_param, arg, arg_shifts in zip(stencil.params, args, shifts, strict=True): + eligible_args = [] + for arg, arg_shifts in zip(args, shifts, strict=True): assert isinstance(arg.type, ts.TypeSpec) dtype = type_info.apply_to_primitive_constituents(type_info.extract_dtype, arg.type) # TODO(tehrengruber): make this configurable - should_inline = _is_tuple_expr_of_literals(arg) or ( - isinstance(arg, itir.FunCall) - and ( - cpm.is_call_to(arg.fun, "as_fieldop") - and isinstance(arg.fun.args[0], itir.Lambda) - or cpm.is_call_to(arg, "if_") + eligible_args.append( + _is_tuple_expr_of_literals(arg) + or ( + isinstance(arg, itir.FunCall) + and ( + cpm.is_call_to(arg.fun, "as_fieldop") + and isinstance(arg.fun.args[0], itir.Lambda) + or cpm.is_call_to(arg, "if_") + ) + and (isinstance(dtype, it_ts.ListType) or len(arg_shifts) <= 1) ) - and (isinstance(dtype, it_ts.ListType) or len(arg_shifts) <= 1) ) - if should_inline: - if cpm.is_applied_as_fieldop(arg): - pass - elif cpm.is_call_to(arg, "if_"): - # TODO(tehrengruber): revisit if we want to inline if_ - type_ = arg.type - arg = im.op_as_fieldop("if_")(*arg.args) - arg.type = type_ - elif _is_tuple_expr_of_literals(arg): - arg = im.op_as_fieldop(im.lambda_()(arg))() - else: - raise NotImplementedError() - - inline_expr, extracted_args = self._inline_as_fieldop_arg(arg) - - new_stencil_body = im.let(stencil_param, 
inline_expr)(new_stencil_body) - - new_args = _merge_arguments(new_args, extracted_args) - else: - assert not isinstance(dtype, it_ts.ListType) - new_param: str - if isinstance( - arg, itir.SymRef - ): # use name from outer scope (optional, just to get a nice IR) - new_param = arg.id - new_stencil_body = im.let(stencil_param.id, arg.id)(new_stencil_body) - else: - new_param = stencil_param.id - new_args = _merge_arguments(new_args, {new_param: arg}) - - new_node = im.as_fieldop(im.lambda_(*new_args.keys())(new_stencil_body), domain)( - *new_args.values() - ) - - # simplify stencil directly to keep the tree small - new_node = inline_center_deref_lift_vars.InlineCenterDerefLiftVars.apply( - new_node - ) # to keep the tree small - new_node = inline_lambdas.InlineLambdas.apply( - new_node, opcount_preserving=True, force_inline_lift_args=True - ) - new_node = inline_lifts.InlineLifts().visit(new_node) - - type_inference.copy_type(from_=node, to=new_node) - return new_node + return fuse_as_fieldop(node, eligible_args, uids=self.uids) return node diff --git a/src/gt4py/next/iterator/transforms/infer_domain.py b/src/gt4py/next/iterator/transforms/infer_domain.py index 6852b47a7a..b1612d5f63 100644 --- a/src/gt4py/next/iterator/transforms/infer_domain.py +++ b/src/gt4py/next/iterator/transforms/infer_domain.py @@ -10,7 +10,7 @@ import itertools import typing -from typing import Callable, Optional, TypeAlias +from typing import Callable, Literal, Optional, TypeAlias from gt4py import eve from gt4py.eve import utils as eve_utils @@ -25,7 +25,7 @@ from gt4py.next.utils import flatten_nested_tuple, tree_map -DOMAIN: TypeAlias = domain_utils.SymbolicDomain | None | tuple["DOMAIN", ...] +DOMAIN: TypeAlias = domain_utils.SymbolicDomain | None | Literal["UNKNOWN"] | tuple["DOMAIN", ...] ACCESSED_DOMAINS: TypeAlias = dict[str, DOMAIN] @@ -58,9 +58,12 @@ def _split_dict_by_key(pred: Callable, d: dict): # TODO(tehrengruber): Revisit whether we want to move this behaviour to `domain_utils.domain_union`. 
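+# Note: in the presence of dynamic shifts a domain may be the special marker
+# "UNKNOWN". The union below is absorbing in that marker, while `None` (meaning
+# the field is never accessed) acts as the neutral element, e.g.
+#   _domain_union_with_none(None, d)      -> d
+#   _domain_union_with_none("UNKNOWN", d) -> "UNKNOWN"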
def _domain_union_with_none( - *domains: domain_utils.SymbolicDomain | None, -) -> domain_utils.SymbolicDomain | None: - filtered_domains: list[domain_utils.SymbolicDomain] = [d for d in domains if d is not None] + *domains: domain_utils.SymbolicDomain | None | Literal["UNKNOWN"], +) -> domain_utils.SymbolicDomain | None | Literal["UNKNOWN"]: + if any(d == "UNKNOWN" for d in domains): + return "UNKNOWN" + + filtered_domains: list[domain_utils.SymbolicDomain] = [d for d in domains if d is not None] # type: ignore[misc] # domain can never be none because as such cases are filtered above if len(filtered_domains) == 0: return None return domain_utils.domain_union(*filtered_domains) @@ -122,11 +125,16 @@ def _extract_accessed_domains( offset_provider: common.OffsetProvider, symbolic_domain_sizes: Optional[dict[str, str]], ) -> ACCESSED_DOMAINS: - accessed_domains: dict[str, domain_utils.SymbolicDomain | None] = {} + accessed_domains: dict[str, domain_utils.SymbolicDomain | None | Literal["UNKNOWN"]] = {} shifts_results = trace_shifts.trace_stencil(stencil, num_args=len(input_ids)) for in_field_id, shifts_list in zip(input_ids, shifts_results, strict=True): + # special marker for dynamic shifts + if any(s == trace_shifts.Sentinel.VALUE for shift in shifts_list for s in shift): + accessed_domains[in_field_id] = "UNKNOWN" + continue + new_domains = [ domain_utils.SymbolicDomain.translate( target_domain, shift, offset_provider, symbolic_domain_sizes @@ -146,6 +154,7 @@ def _infer_as_fieldop( target_domain: DOMAIN, offset_provider: common.OffsetProvider, symbolic_domain_sizes: Optional[dict[str, str]], + allowed_unknown_domains: list[str], ) -> tuple[itir.FunCall, ACCESSED_DOMAINS]: assert isinstance(applied_fieldop, itir.FunCall) assert cpm.is_call_to(applied_fieldop.fun, "as_fieldop") @@ -186,7 +195,11 @@ def _infer_as_fieldop( transformed_inputs: list[itir.Expr] = [] for in_field_id, in_field in zip(input_ids, inputs): transformed_input, accessed_domains_tmp = infer_expr( - in_field, inputs_accessed_domains[in_field_id], offset_provider, symbolic_domain_sizes + in_field, + inputs_accessed_domains[in_field_id], + offset_provider, + symbolic_domain_sizes, + allowed_unknown_domains, ) transformed_inputs.append(transformed_input) @@ -209,14 +222,20 @@ def _infer_let( input_domain: DOMAIN, offset_provider: common.OffsetProvider, symbolic_domain_sizes: Optional[dict[str, str]], + allowed_unknown_domains: list[str], ) -> tuple[itir.FunCall, ACCESSED_DOMAINS]: assert cpm.is_let(let_expr) assert isinstance(let_expr.fun, itir.Lambda) # just to make mypy happy + let_params = {param_sym.id for param_sym in let_expr.fun.params} + transformed_calls_expr, accessed_domains = infer_expr( - let_expr.fun.expr, input_domain, offset_provider, symbolic_domain_sizes + let_expr.fun.expr, + input_domain, + offset_provider, + symbolic_domain_sizes, + [p for p in allowed_unknown_domains if p not in let_params], ) - let_params = {param_sym.id for param_sym in let_expr.fun.params} accessed_domains_let_args, accessed_domains_outer = _split_dict_by_key( lambda k: k in let_params, accessed_domains ) @@ -231,6 +250,7 @@ def _infer_let( ), offset_provider, symbolic_domain_sizes, + allowed_unknown_domains, ) accessed_domains_outer = _merge_domains(accessed_domains_outer, accessed_domains_arg) transformed_calls_args.append(transformed_calls_arg) @@ -250,6 +270,7 @@ def _infer_make_tuple( domain: DOMAIN, offset_provider: common.OffsetProvider, symbolic_domain_sizes: Optional[dict[str, str]], + allowed_unknown_domains: list[str], ) 
-> tuple[itir.Expr, ACCESSED_DOMAINS]: assert cpm.is_call_to(expr, "make_tuple") infered_args_expr = [] @@ -266,7 +287,7 @@ def _infer_make_tuple( domain = (*domain, *(None for _ in range(len(expr.args) - len(domain)))) for i, arg in enumerate(expr.args): infered_arg_expr, actual_domains_arg = infer_expr( - arg, domain[i], offset_provider, symbolic_domain_sizes + arg, domain[i], offset_provider, symbolic_domain_sizes, allowed_unknown_domains ) infered_args_expr.append(infered_arg_expr) actual_domains = _merge_domains(actual_domains, actual_domains_arg) @@ -279,6 +300,7 @@ def _infer_tuple_get( domain: DOMAIN, offset_provider: common.OffsetProvider, symbolic_domain_sizes: Optional[dict[str, str]], + allowed_unknown_domains: list[str], ) -> tuple[itir.Expr, ACCESSED_DOMAINS]: assert cpm.is_call_to(expr, "tuple_get") actual_domains: ACCESSED_DOMAINS = {} @@ -287,7 +309,7 @@ def _infer_tuple_get( idx = int(idx_expr.value) tuple_domain = tuple(None if i != idx else domain for i in range(idx + 1)) infered_arg_expr, actual_domains_arg = infer_expr( - tuple_arg, tuple_domain, offset_provider, symbolic_domain_sizes + tuple_arg, tuple_domain, offset_provider, symbolic_domain_sizes, allowed_unknown_domains ) infered_args_expr = im.tuple_get(idx, infered_arg_expr) @@ -300,6 +322,7 @@ def _infer_if( domain: DOMAIN, offset_provider: common.OffsetProvider, symbolic_domain_sizes: Optional[dict[str, str]], + allowed_unknown_domains: list[str], ) -> tuple[itir.Expr, ACCESSED_DOMAINS]: assert cpm.is_call_to(expr, "if_") infered_args_expr = [] @@ -307,7 +330,7 @@ def _infer_if( cond, true_val, false_val = expr.args for arg in [true_val, false_val]: infered_arg_expr, actual_domains_arg = infer_expr( - arg, domain, offset_provider, symbolic_domain_sizes + arg, domain, offset_provider, symbolic_domain_sizes, allowed_unknown_domains ) infered_args_expr.append(infered_arg_expr) actual_domains = _merge_domains(actual_domains, actual_domains_arg) @@ -320,21 +343,32 @@ def _infer_expr( domain: DOMAIN, offset_provider: common.OffsetProvider, symbolic_domain_sizes: Optional[dict[str, str]], + allowed_unknown_domains: list[str], ) -> tuple[itir.Expr, ACCESSED_DOMAINS]: if isinstance(expr, itir.SymRef): return expr, {str(expr.id): domain} elif isinstance(expr, itir.Literal): return expr, {} elif cpm.is_applied_as_fieldop(expr): - return _infer_as_fieldop(expr, domain, offset_provider, symbolic_domain_sizes) + return _infer_as_fieldop( + expr, domain, offset_provider, symbolic_domain_sizes, allowed_unknown_domains + ) elif cpm.is_let(expr): - return _infer_let(expr, domain, offset_provider, symbolic_domain_sizes) + return _infer_let( + expr, domain, offset_provider, symbolic_domain_sizes, allowed_unknown_domains + ) elif cpm.is_call_to(expr, "make_tuple"): - return _infer_make_tuple(expr, domain, offset_provider, symbolic_domain_sizes) + return _infer_make_tuple( + expr, domain, offset_provider, symbolic_domain_sizes, allowed_unknown_domains + ) elif cpm.is_call_to(expr, "tuple_get"): - return _infer_tuple_get(expr, domain, offset_provider, symbolic_domain_sizes) + return _infer_tuple_get( + expr, domain, offset_provider, symbolic_domain_sizes, allowed_unknown_domains + ) elif cpm.is_call_to(expr, "if_"): - return _infer_if(expr, domain, offset_provider, symbolic_domain_sizes) + return _infer_if( + expr, domain, offset_provider, symbolic_domain_sizes, allowed_unknown_domains + ) elif ( cpm.is_call_to(expr, itir.ARITHMETIC_BUILTINS) or cpm.is_call_to(expr, itir.TYPEBUILTINS) @@ -350,6 +384,7 @@ def infer_expr( domain: 
DOMAIN, offset_provider: common.OffsetProvider, symbolic_domain_sizes: Optional[dict[str, str]] = None, + allowed_unknown_domains: Optional[list[str]] = None, ) -> tuple[itir.Expr, ACCESSED_DOMAINS]: """ Infer the domain of all field subexpressions of `expr`. @@ -362,15 +397,27 @@ def infer_expr( - domain: The domain `expr` is read at. - symbolic_domain_sizes: A dictionary mapping axes names, e.g., `I`, `Vertex`, to a symbol name that evaluates to the length of that axis. + - allowed_unknown_domains: A list of references (as strings) for which the domain does not need + to be inferred, but can be `UNKNOWN`. This is used by `infer_program` to allow dynamic shifts + on program inputs. Returns: A tuple containing the inferred expression with all applied `as_fieldop` (that are accessed) having a domain argument now, and a dictionary mapping symbol names referenced in `expr` to domain they are accessed at. """ - # this is just a small wrapper that populates the `domain` annex - expr, accessed_domains = _infer_expr(expr, domain, offset_provider, symbolic_domain_sizes) + allowed_unknown_domains = allowed_unknown_domains or [] + + expr, accessed_domains = _infer_expr( + expr, domain, offset_provider, symbolic_domain_sizes, allowed_unknown_domains + ) expr.annex.domain = domain + + if any( + p not in allowed_unknown_domains and d == "UNKNOWN" for p, d in accessed_domains.items() + ): + raise ValueError("Some accessed domains are unknown, e.g. because of a dynamic shift.") + return expr, accessed_domains @@ -378,14 +425,17 @@ def _infer_stmt( stmt: itir.Stmt, offset_provider: common.OffsetProvider, symbolic_domain_sizes: Optional[dict[str, str]], + allowed_unknown_domains: list[str], ): if isinstance(stmt, itir.SetAt): - transformed_call, _unused_domain = infer_expr( + transformed_call, _ = infer_expr( stmt.expr, domain_utils.SymbolicDomain.from_expr(stmt.domain), offset_provider, symbolic_domain_sizes, + allowed_unknown_domains=allowed_unknown_domains, ) + return itir.SetAt( expr=transformed_call, domain=stmt.domain, @@ -395,10 +445,12 @@ def _infer_stmt( return itir.IfStmt( cond=stmt.cond, true_branch=[ - _infer_stmt(c, offset_provider, symbolic_domain_sizes) for c in stmt.true_branch + _infer_stmt(c, offset_provider, symbolic_domain_sizes, allowed_unknown_domains) + for c in stmt.true_branch ], false_branch=[ - _infer_stmt(c, offset_provider, symbolic_domain_sizes) for c in stmt.false_branch + _infer_stmt(c, offset_provider, symbolic_domain_sizes, allowed_unknown_domains) + for c in stmt.false_branch ], ) raise ValueError(f"Unsupported stmt: {stmt}") @@ -423,5 +475,10 @@ def infer_program( function_definitions=program.function_definitions, params=program.params, declarations=program.declarations, - body=[_infer_stmt(stmt, offset_provider, symbolic_domain_sizes) for stmt in program.body], + body=[ + _infer_stmt( + stmt, offset_provider, symbolic_domain_sizes, [param.id for param in program.params] + ) + for stmt in program.body + ], ) diff --git a/src/gt4py/next/iterator/transforms/inline_dynamic_shifts.py b/src/gt4py/next/iterator/transforms/inline_dynamic_shifts.py new file mode 100644 index 0000000000..ee231ecd2f --- /dev/null +++ b/src/gt4py/next/iterator/transforms/inline_dynamic_shifts.py @@ -0,0 +1,76 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. 
+# SPDX-License-Identifier: BSD-3-Clause + +import dataclasses +from typing import Optional + +import gt4py.next.iterator.ir_utils.common_pattern_matcher as cpm +from gt4py import eve +from gt4py.eve import utils as eve_utils +from gt4py.next.iterator import ir as itir +from gt4py.next.iterator.transforms import fuse_as_fieldop, inline_lambdas, trace_shifts +from gt4py.next.iterator.transforms.symbol_ref_utils import collect_symbol_refs + + +def _dynamic_shift_args(node: itir.Expr) -> None | list[bool]: + if not cpm.is_applied_as_fieldop(node): + return None + params_shifts = trace_shifts.trace_stencil( + node.fun.args[0], # type: ignore[attr-defined] # ensured by is_applied_as_fieldop + num_args=len(node.args), + save_to_annex=True, + ) + dynamic_shifts = [] + for param_shifts in params_shifts: + has_dynamic_shift = False + for shifts in param_shifts: + for _, offset in zip(shifts[::2], shifts[1::2], strict=True): + has_dynamic_shift |= offset == trace_shifts.Sentinel.VALUE + dynamic_shifts.append(has_dynamic_shift) + return dynamic_shifts + + +@dataclasses.dataclass +class InlineDynamicShifts(eve.NodeTranslator, eve.VisitorWithSymbolTableTrait): + uids: eve_utils.UIDGenerator + + @classmethod + def apply(cls, node: itir.Program, uids: Optional[eve_utils.UIDGenerator] = None): + if not uids: + uids = eve_utils.UIDGenerator() + + return cls(uids=uids).visit(node) + + def visit_FunCall(self, node: itir.FunCall, **kwargs): + node = self.generic_visit(node, **kwargs) + + if cpm.is_let(node) and ( + dynamic_shift_args := _dynamic_shift_args(let_body := node.fun.expr) # type: ignore[attr-defined] # ensured by is_let + ): + inline_let_params = {p.id: False for p in node.fun.params} # type: ignore[attr-defined] # ensured by is_let + + for inp, is_dynamic_shift_arg in zip(let_body.args, dynamic_shift_args, strict=True): + for ref in collect_symbol_refs(inp): + if ref in inline_let_params and is_dynamic_shift_arg: + inline_let_params[ref] = True + + if any(inline_let_params): + node = inline_lambdas.inline_lambda( + node, eligible_params=list(inline_let_params.values()) + ) + + if dynamic_shift_args := _dynamic_shift_args(node): + assert len(node.fun.args) in [1, 2] # type: ignore[attr-defined] # ensured by is_applied_as_fieldop in _dynamic_shift_args + fuse_args = [ + not isinstance(inp, itir.SymRef) and dynamic_shift_arg + for inp, dynamic_shift_arg in zip(node.args, dynamic_shift_args, strict=True) + ] + if any(fuse_args): + return fuse_as_fieldop.fuse_as_fieldop(node, fuse_args, uids=self.uids) + + return node diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index 52a452155a..6fc7833c6f 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -16,6 +16,7 @@ fuse_as_fieldop, global_tmps, infer_domain, + inline_dynamic_shifts, inline_fundefs, inline_lifts, ) @@ -76,6 +77,9 @@ def apply_common_transforms( ir = InlineLambdas.apply(ir, opcount_preserving=True, force_inline_lambda_args=True) # required in order to get rid of expressions without a domain (e.g. 
when a tuple element is never accessed) ir = CollapseTuple.apply(ir, offset_provider=offset_provider) # type: ignore[assignment] # always an itir.Program + ir = inline_dynamic_shifts.InlineDynamicShifts.apply( + ir # type: ignore[arg-type] # always an itir.Program + ) # domain inference does not support dynamic offsets yet ir = infer_domain.infer_program( ir, # type: ignore[arg-type] # always an itir.Program offset_provider=offset_provider, @@ -157,5 +161,8 @@ def apply_fieldview_transforms( ir = inline_fundefs.prune_unreferenced_fundefs(ir) ir = InlineLambdas.apply(ir, opcount_preserving=True, force_inline_lambda_args=True) ir = CollapseTuple.apply(ir, offset_provider=offset_provider) # type: ignore[assignment] # type is still `itir.Program` + ir = inline_dynamic_shifts.InlineDynamicShifts.apply( + ir + ) # domain inference does not support dynamic offsets yet ir = infer_domain.infer_program(ir, offset_provider=offset_provider) return ir diff --git a/tests/next_tests/definitions.py b/tests/next_tests/definitions.py index b0bec661f9..8221db98e9 100644 --- a/tests/next_tests/definitions.py +++ b/tests/next_tests/definitions.py @@ -135,7 +135,6 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): ] # Markers to skip because of missing features in the domain inference DOMAIN_INFERENCE_SKIP_LIST = [ - (USES_DYNAMIC_OFFSETS, XFAIL, UNSUPPORTED_MESSAGE), (USES_STRIDED_NEIGHBOR_OFFSET, XFAIL, UNSUPPORTED_MESSAGE), ] DACE_SKIP_TEST_LIST = COMMON_SKIP_TEST_LIST + [ diff --git a/tests/next_tests/unit_tests/iterator_tests/test_inline_dynamic_shifts.py b/tests/next_tests/unit_tests/iterator_tests/test_inline_dynamic_shifts.py new file mode 100644 index 0000000000..ff7a761c5a --- /dev/null +++ b/tests/next_tests/unit_tests/iterator_tests/test_inline_dynamic_shifts.py @@ -0,0 +1,48 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. 
+# SPDX-License-Identifier: BSD-3-Clause +from typing import Callable, Optional + +from gt4py import next as gtx +from gt4py.next.iterator import ir as itir +from gt4py.next.iterator.ir_utils import ir_makers as im +from gt4py.next.iterator.transforms import inline_dynamic_shifts +from gt4py.next.type_system import type_specifications as ts + +IDim = gtx.Dimension("IDim") +field_type = ts.FieldType(dims=[IDim], dtype=ts.ScalarType(kind=ts.ScalarKind.INT32)) + + +def test_inline_dynamic_shift_as_fieldop_arg(): + testee = im.as_fieldop(im.lambda_("a", "b")(im.deref(im.shift("IOff", im.deref("b"))("a"))))( + im.as_fieldop("deref")("inp"), "offset_field" + ) + expected = im.as_fieldop( + im.lambda_("inp", "offset_field")( + im.deref(im.shift("IOff", im.deref("offset_field"))("inp")) + ) + )("inp", "offset_field") + + actual = inline_dynamic_shifts.InlineDynamicShifts.apply(testee) + assert actual == expected + + +def test_inline_dynamic_shift_let_var(): + testee = im.let("tmp", im.as_fieldop("deref")("inp"))( + im.as_fieldop(im.lambda_("a", "b")(im.deref(im.shift("IOff", im.deref("b"))("a"))))( + "tmp", "offset_field" + ) + ) + + expected = im.as_fieldop( + im.lambda_("inp", "offset_field")( + im.deref(im.shift("IOff", im.deref("offset_field"))("inp")) + ) + )("inp", "offset_field") + + actual = inline_dynamic_shifts.InlineDynamicShifts.apply(testee) + assert actual == expected From 79026c740cc276ea680faae6b60444bab35478ee Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Fri, 15 Nov 2024 10:02:37 +0100 Subject: [PATCH 094/150] dace-related changes --- .../program_processors/runners/dace_fieldview/gtir_dataflow.py | 1 - tests/next_tests/definitions.py | 1 - 2 files changed, 2 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py index cf91d15aba..3c508742b2 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py @@ -1057,7 +1057,6 @@ def _make_unstructured_shift( """Implements shift in unstructured domain by means of a neighbor table.""" assert connectivity.neighbor_axis in it.dimensions neighbor_dim = connectivity.neighbor_axis - assert neighbor_dim not in it.indices origin_dim = connectivity.origin_axis assert origin_dim in it.indices diff --git a/tests/next_tests/definitions.py b/tests/next_tests/definitions.py index 8221db98e9..f3accb61a3 100644 --- a/tests/next_tests/definitions.py +++ b/tests/next_tests/definitions.py @@ -152,7 +152,6 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): (STARTS_FROM_GTIR_PROGRAM, SKIP, UNSUPPORTED_MESSAGE), ] GTIR_DACE_SKIP_TEST_LIST = DOMAIN_INFERENCE_SKIP_LIST + [ - (USES_DYNAMIC_OFFSETS, XFAIL, UNSUPPORTED_MESSAGE), (USES_INDEX_BUILTIN, XFAIL, UNSUPPORTED_MESSAGE), (USES_NEGATIVE_MODULO, XFAIL, UNSUPPORTED_MESSAGE), (USES_SCAN, XFAIL, UNSUPPORTED_MESSAGE), From 16f143f7f1159900164c14656d81508193218b15 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Fri, 15 Nov 2024 10:32:45 +0100 Subject: [PATCH 095/150] Address review comments --- tests/next_tests/definitions.py | 3 +++ .../feature_tests/ffront_tests/test_scalar_if.py | 2 +- .../feature_tests/iterator_tests/test_scan.py | 2 +- 3 files changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/next_tests/definitions.py b/tests/next_tests/definitions.py index 175edcb160..c86ba88ead 100644 --- a/tests/next_tests/definitions.py +++ 
b/tests/next_tests/definitions.py @@ -101,6 +101,7 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): USES_REDUCTION_OVER_LIFT_EXPRESSIONS = "uses_reduction_over_lift_expressions" USES_SCAN = "uses_scan" USES_SCAN_IN_FIELD_OPERATOR = "uses_scan_in_field_operator" +USES_SCAN_IN_STENCIL = "uses_scan_in_stencil" USES_SCAN_WITHOUT_FIELD_ARGS = "uses_scan_without_field_args" USES_SCAN_NESTED = "uses_scan_nested" USES_SCAN_REQUIRING_PROJECTOR = "uses_scan_requiring_projector" @@ -176,6 +177,7 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): + [ # floordiv not yet supported, see https://github.com/GridTools/gt4py/issues/1136 (USES_FLOORDIV, XFAIL, BINDINGS_UNSUPPORTED_MESSAGE), + (USES_SCAN_IN_STENCIL, XFAIL, BINDINGS_UNSUPPORTED_MESSAGE), (USES_STRIDED_NEIGHBOR_OFFSET, XFAIL, BINDINGS_UNSUPPORTED_MESSAGE), # max_over broken, see https://github.com/GridTools/gt4py/issues/1289 (USES_MAX_OVER, XFAIL, UNSUPPORTED_MESSAGE), @@ -204,6 +206,7 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): + [(USES_SCAN_NESTED, XFAIL, UNSUPPORTED_MESSAGE)], ProgramFormatterId.GTFN_CPP_FORMATTER: DOMAIN_INFERENCE_SKIP_LIST + [ + (USES_SCAN_IN_STENCIL, XFAIL, BINDINGS_UNSUPPORTED_MESSAGE), (USES_REDUCTION_WITH_ONLY_SPARSE_FIELDS, XFAIL, REDUCTION_WITH_ONLY_SPARSE_FIELDS_MESSAGE), ], ProgramFormatterId.LISP_FORMATTER: DOMAIN_INFERENCE_SKIP_LIST, diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_scalar_if.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_scalar_if.py index f5d946c7bd..7ff7edf226 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_scalar_if.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_scalar_if.py @@ -56,7 +56,7 @@ def simple_if(a: cases.IField, b: cases.IField, condition: bool) -> cases.IField cases.verify(cartesian_case, simple_if, a, b, condition, out=out, ref=a if condition else b) -# TODO: test with fields on different domains +# TODO(tehrengruber): test with fields on different domains @pytest.mark.parametrize("condition1, condition2", [[True, False], [True, False]]) @pytest.mark.uses_if_stmts def test_simple_if_conditional(condition1, condition2, cartesian_case): diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_scan.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_scan.py index efa2e3f5b3..e462aa07eb 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_scan.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_scan.py @@ -18,9 +18,9 @@ @pytest.mark.uses_index_fields +@pytest.mark.uses_scan_in_stencil def test_scan_in_stencil(program_processor): # FIXME[#1582](tehrengruber): Remove test after scan is reworked. 
- pytest.skip("Scan inside of stencil is not supported in GTIR.") program_processor, validate = program_processor isize = 1 From 75695d906b3b8ec461b1d8d4e33922fd4e57185b Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Fri, 15 Nov 2024 10:34:02 +0100 Subject: [PATCH 096/150] Address review comments --- src/gt4py/next/iterator/transforms/collapse_tuple.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/collapse_tuple.py b/src/gt4py/next/iterator/transforms/collapse_tuple.py index 4a7f070ac7..f84714e779 100644 --- a/src/gt4py/next/iterator/transforms/collapse_tuple.py +++ b/src/gt4py/next/iterator/transforms/collapse_tuple.py @@ -289,9 +289,12 @@ def transform_inline_trivial_make_tuple(self, node: ir.FunCall, **kwargs) -> Opt return None def transform_propagate_to_if_on_tuples(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: - # TODO(tehrengruber): This significantly increases the size of the tree. Skip transformation - # in local-view for now. Revisit. - if not cpm.is_call_to(node, "if_") and not kwargs["within_stencil"]: + if kwargs["within_stencil"]: + # TODO(tehrengruber): This significantly increases the size of the tree. Skip transformation + # in local-view for now. Revisit. + return None + + if not cpm.is_call_to(node, "if_"): # TODO(tehrengruber): Only inline if type of branch value is a tuple. # Examples: # `(if cond then {1, 2} else {3, 4})[0]` -> `if cond then {1, 2}[0] else {3, 4}[0]` From 6e404999e14d2ae5e01e5c9292d5d8e0b2af1b3a Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Fri, 15 Nov 2024 11:57:14 +0100 Subject: [PATCH 097/150] Small cleanup --- src/gt4py/next/iterator/transforms/infer_domain.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/infer_domain.py b/src/gt4py/next/iterator/transforms/infer_domain.py index b1612d5f63..54fa3dfc91 100644 --- a/src/gt4py/next/iterator/transforms/infer_domain.py +++ b/src/gt4py/next/iterator/transforms/infer_domain.py @@ -63,7 +63,7 @@ def _domain_union_with_none( if any(d == "UNKNOWN" for d in domains): return "UNKNOWN" - filtered_domains: list[domain_utils.SymbolicDomain] = [d for d in domains if d is not None] # type: ignore[misc] # domain can never be none because as such cases are filtered above + filtered_domains: list[domain_utils.SymbolicDomain] = [d for d in domains if d is not None] # type: ignore[misc] # domain can never be none because as these cases are filtered above if len(filtered_domains) == 0: return None return domain_utils.domain_union(*filtered_domains) @@ -130,7 +130,8 @@ def _extract_accessed_domains( shifts_results = trace_shifts.trace_stencil(stencil, num_args=len(input_ids)) for in_field_id, shifts_list in zip(input_ids, shifts_results, strict=True): - # special marker for dynamic shifts + # TODO(tehrengruber): Dynamic shifts are not supported by `SymbolicDomain.translate`. Use + # special `UNKNOWN` marker for them until we have implemented a proper solution. 
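+        # (A dynamic shift is a shift whose offset is itself read from a field at
+        #  runtime, e.g. `·⟪Ioffₒ, ·offset_field⟫(inp)`; for such offsets
+        #  `trace_shifts` reports `Sentinel.VALUE` instead of a literal value.)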
if any(s == trace_shifts.Sentinel.VALUE for shift in shifts_list for s in shift): accessed_domains[in_field_id] = "UNKNOWN" continue From b0bd6583b32ceef68c506d50f6d2b463267029a7 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Wed, 27 Nov 2024 21:08:32 +0100 Subject: [PATCH 098/150] Address review comments --- .../next/iterator/transforms/global_tmps.py | 2 +- .../next/iterator/transforms/infer_domain.py | 223 +++++++++--------- .../iterator_tests/test_type_inference.py | 2 +- .../transforms_tests/test_domain_inference.py | 116 ++++++--- 4 files changed, 192 insertions(+), 151 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/global_tmps.py b/src/gt4py/next/iterator/transforms/global_tmps.py index 90f8a6cded..c8e75de259 100644 --- a/src/gt4py/next/iterator/transforms/global_tmps.py +++ b/src/gt4py/next/iterator/transforms/global_tmps.py @@ -186,7 +186,7 @@ def create_global_tmps( This pass looks at all `as_fieldop` calls and transforms field-typed subexpressions of its arguments into temporaries. """ - program = infer_domain.infer_program(program, offset_provider) + program = infer_domain.infer_program(program, offset_provider=offset_provider) program = type_inference.infer(program, offset_provider=offset_provider) if not uids: diff --git a/src/gt4py/next/iterator/transforms/infer_domain.py b/src/gt4py/next/iterator/transforms/infer_domain.py index 54fa3dfc91..56b626cc61 100644 --- a/src/gt4py/next/iterator/transforms/infer_domain.py +++ b/src/gt4py/next/iterator/transforms/infer_domain.py @@ -10,10 +10,10 @@ import itertools import typing -from typing import Callable, Literal, Optional, TypeAlias from gt4py import eve from gt4py.eve import utils as eve_utils +from gt4py.eve.extended_typing import Callable, Optional, TypeAlias, Unpack from gt4py.next import common from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import ( @@ -25,10 +25,29 @@ from gt4py.next.utils import flatten_nested_tuple, tree_map -DOMAIN: TypeAlias = domain_utils.SymbolicDomain | None | Literal["UNKNOWN"] | tuple["DOMAIN", ...] +class DomainAccessDescriptor(eve.StrEnum): + """ + Descriptor for domains that could not be inferred. + """ + + #: The access if unknown because of a dynamic shift.whose extent is not known. + #: E.g.: `(⇑(λ(arg0, arg1) → ·⟪Ioffₒ, ·arg1⟫(arg0)))(in_field1, in_field2)` + UNKNOWN = "unknown" + #: The domain is never accessed. + #: E.g.: `{in_field1, in_field2}[0]` + NEVER = "never" + + +DOMAIN: TypeAlias = domain_utils.SymbolicDomain | DomainAccessDescriptor | tuple["DOMAIN", ...] ACCESSED_DOMAINS: TypeAlias = dict[str, DOMAIN] +class InferenceOptions(typing.TypedDict): + offset_provider: common.OffsetProvider + symbolic_domain_sizes: Optional[dict[str, str]] + allow_uninferred: bool + + class DomainAnnexDebugger(eve.NodeVisitor): """ Small utility class to debug missing domain attribute in annex. @@ -57,15 +76,19 @@ def _split_dict_by_key(pred: Callable, d: dict): # TODO(tehrengruber): Revisit whether we want to move this behaviour to `domain_utils.domain_union`. 
-def _domain_union_with_none( - *domains: domain_utils.SymbolicDomain | None | Literal["UNKNOWN"], -) -> domain_utils.SymbolicDomain | None | Literal["UNKNOWN"]: - if any(d == "UNKNOWN" for d in domains): - return "UNKNOWN" - - filtered_domains: list[domain_utils.SymbolicDomain] = [d for d in domains if d is not None] # type: ignore[misc] # domain can never be none because as these cases are filtered above +def _domain_union( + *domains: domain_utils.SymbolicDomain | DomainAccessDescriptor, +) -> domain_utils.SymbolicDomain | DomainAccessDescriptor: + if any(d == DomainAccessDescriptor.UNKNOWN for d in domains): + return DomainAccessDescriptor.UNKNOWN + + filtered_domains: list[domain_utils.SymbolicDomain] = [ + d # type: ignore[misc] # domain can never be unknown because as these cases are filtered above + for d in domains + if d != DomainAccessDescriptor.NEVER + ] if len(filtered_domains) == 0: - return None + return DomainAccessDescriptor.NEVER return domain_utils.domain_union(*filtered_domains) @@ -74,29 +97,35 @@ def _canonicalize_domain_structure(d1: DOMAIN, d2: DOMAIN) -> tuple[DOMAIN, DOMA Given two domains or composites thereof, canonicalize their structure. If one of the arguments is a tuple the other one will be promoted to a tuple of same structure - unless it already is a tuple. Missing values are replaced by None, meaning no domain is - specified. + unless it already is a tuple. Missing values are filled by :ref:`DomainAccessDescriptor.NEVER`. >>> domain = im.domain(common.GridType.CARTESIAN, {}) >>> _canonicalize_domain_structure((domain,), (domain, domain)) == ( - ... (domain, None), + ... (domain, DomainAccessDescriptor.NEVER), ... (domain, domain), ... ) True - >>> _canonicalize_domain_structure((domain, None), None) == ((domain, None), (None, None)) + >>> _canonicalize_domain_structure( + ... (domain, DomainAccessDescriptor.NEVER), DomainAccessDescriptor.NEVER + ... ) == ( + ... (domain, DomainAccessDescriptor.NEVER), + ... (DomainAccessDescriptor.NEVER, DomainAccessDescriptor.NEVER), + ... 
) True """ - if d1 is None and isinstance(d2, tuple): - return _canonicalize_domain_structure((None,) * len(d2), d2) - if d2 is None and isinstance(d1, tuple): - return _canonicalize_domain_structure(d1, (None,) * len(d1)) + if d1 is DomainAccessDescriptor.NEVER and isinstance(d2, tuple): + return _canonicalize_domain_structure((DomainAccessDescriptor.NEVER,) * len(d2), d2) + if d2 is DomainAccessDescriptor.NEVER and isinstance(d1, tuple): + return _canonicalize_domain_structure(d1, (DomainAccessDescriptor.NEVER,) * len(d1)) if isinstance(d1, tuple) and isinstance(d2, tuple): return tuple( zip( *( _canonicalize_domain_structure(el1, el2) - for el1, el2 in itertools.zip_longest(d1, d2, fillvalue=None) + for el1, el2 in itertools.zip_longest( + d1, d2, fillvalue=DomainAccessDescriptor.NEVER + ) ) ) ) # type: ignore[return-value] # mypy not smart enough @@ -111,9 +140,9 @@ def _merge_domains( for key, domain in additional_domains.items(): original_domain, domain = _canonicalize_domain_structure( - original_domains.get(key, None), domain + original_domains.get(key, DomainAccessDescriptor.NEVER), domain ) - new_domains[key] = tree_map(_domain_union_with_none)(original_domain, domain) + new_domains[key] = tree_map(_domain_union)(original_domain, domain) return new_domains @@ -124,8 +153,9 @@ def _extract_accessed_domains( target_domain: domain_utils.SymbolicDomain, offset_provider: common.OffsetProvider, symbolic_domain_sizes: Optional[dict[str, str]], + allow_uninferred: bool, ) -> ACCESSED_DOMAINS: - accessed_domains: dict[str, domain_utils.SymbolicDomain | None | Literal["UNKNOWN"]] = {} + accessed_domains: dict[str, domain_utils.SymbolicDomain | DomainAccessDescriptor] = {} shifts_results = trace_shifts.trace_stencil(stencil, num_args=len(input_ids)) @@ -133,7 +163,9 @@ def _extract_accessed_domains( # TODO(tehrengruber): Dynamic shifts are not supported by `SymbolicDomain.translate`. Use # special `UNKNOWN` marker for them until we have implemented a proper solution. if any(s == trace_shifts.Sentinel.VALUE for shift in shifts_list for s in shift): - accessed_domains[in_field_id] = "UNKNOWN" + if not allow_uninferred: + raise ValueError("Dynamic shifts not allowed if `allow_uninferred=False`") + accessed_domains[in_field_id] = DomainAccessDescriptor.UNKNOWN continue new_domains = [ @@ -142,9 +174,8 @@ def _extract_accessed_domains( ) for shift in shifts_list ] - # `None` means field is never accessed - accessed_domains[in_field_id] = _domain_union_with_none( - accessed_domains.get(in_field_id, None), *new_domains + accessed_domains[in_field_id] = _domain_union( + accessed_domains.get(in_field_id, DomainAccessDescriptor.NEVER), *new_domains ) return typing.cast(ACCESSED_DOMAINS, accessed_domains) @@ -153,17 +184,18 @@ def _extract_accessed_domains( def _infer_as_fieldop( applied_fieldop: itir.FunCall, target_domain: DOMAIN, + *, offset_provider: common.OffsetProvider, symbolic_domain_sizes: Optional[dict[str, str]], - allowed_unknown_domains: list[str], + allow_uninferred: bool, ) -> tuple[itir.FunCall, ACCESSED_DOMAINS]: assert isinstance(applied_fieldop, itir.FunCall) assert cpm.is_call_to(applied_fieldop.fun, "as_fieldop") - if target_domain is None: - raise ValueError("'target_domain' cannot be 'None'.") + if target_domain is DomainAccessDescriptor.NEVER: + raise ValueError("'target_domain' cannot be 'NEVER'.") # FIXME[#1582](tehrengruber): Temporary solution for `tuple_get` on scan result. See `test_solve_triag`. 
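+    # A tuple-valued target domain is collapsed into a single domain by taking the
+    # union over all (flattened) tuple elements.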
if isinstance(target_domain, tuple): - target_domain = _domain_union_with_none(*flatten_nested_tuple(target_domain)) + target_domain = _domain_union(*flatten_nested_tuple(target_domain)) # type: ignore[arg-type] # mypy not smart enough if not isinstance(target_domain, domain_utils.SymbolicDomain): raise ValueError("'target_domain' needs to be a 'domain_utils.SymbolicDomain'.") @@ -188,7 +220,7 @@ def _infer_as_fieldop( input_ids.append(id_) inputs_accessed_domains: ACCESSED_DOMAINS = _extract_accessed_domains( - stencil, input_ids, target_domain, offset_provider, symbolic_domain_sizes + stencil, input_ids, target_domain, offset_provider, symbolic_domain_sizes, allow_uninferred ) # Recursively infer domain of inputs and update domain arg of nested `as_fieldop`s @@ -198,9 +230,9 @@ def _infer_as_fieldop( transformed_input, accessed_domains_tmp = infer_expr( in_field, inputs_accessed_domains[in_field_id], - offset_provider, - symbolic_domain_sizes, - allowed_unknown_domains, + offset_provider=offset_provider, + symbolic_domain_sizes=symbolic_domain_sizes, + allow_uninferred=allow_uninferred, ) transformed_inputs.append(transformed_input) @@ -221,21 +253,13 @@ def _infer_as_fieldop( def _infer_let( let_expr: itir.FunCall, input_domain: DOMAIN, - offset_provider: common.OffsetProvider, - symbolic_domain_sizes: Optional[dict[str, str]], - allowed_unknown_domains: list[str], + **kwargs: Unpack[InferenceOptions], ) -> tuple[itir.FunCall, ACCESSED_DOMAINS]: assert cpm.is_let(let_expr) assert isinstance(let_expr.fun, itir.Lambda) # just to make mypy happy let_params = {param_sym.id for param_sym in let_expr.fun.params} - transformed_calls_expr, accessed_domains = infer_expr( - let_expr.fun.expr, - input_domain, - offset_provider, - symbolic_domain_sizes, - [p for p in allowed_unknown_domains if p not in let_params], - ) + transformed_calls_expr, accessed_domains = infer_expr(let_expr.fun.expr, input_domain, **kwargs) accessed_domains_let_args, accessed_domains_outer = _split_dict_by_key( lambda k: k in let_params, accessed_domains @@ -247,11 +271,9 @@ def _infer_let( arg, accessed_domains_let_args.get( param.id, - None, + DomainAccessDescriptor.NEVER, ), - offset_provider, - symbolic_domain_sizes, - allowed_unknown_domains, + **kwargs, ) accessed_domains_outer = _merge_domains(accessed_domains_outer, accessed_domains_arg) transformed_calls_args.append(transformed_calls_arg) @@ -269,9 +291,7 @@ def _infer_let( def _infer_make_tuple( expr: itir.Expr, domain: DOMAIN, - offset_provider: common.OffsetProvider, - symbolic_domain_sizes: Optional[dict[str, str]], - allowed_unknown_domains: list[str], + **kwargs: Unpack[InferenceOptions], ) -> tuple[itir.Expr, ACCESSED_DOMAINS]: assert cpm.is_call_to(expr, "make_tuple") infered_args_expr = [] @@ -283,13 +303,12 @@ def _infer_make_tuple( # out @ c⟨ IDimₕ: [0, __out_size_0) ⟩ ← {__sym_1, __sym_2}; domain = (domain,) * len(expr.args) assert len(expr.args) >= len(domain) - # There may be less domains than tuple args, pad the domain with `None` in that case. - # e.g. `im.tuple_get(0, im.make_tuple(a, b), domain=domain)` - domain = (*domain, *(None for _ in range(len(expr.args) - len(domain)))) + # There may be fewer domains than tuple args, pad the domain with `NEVER` + # in that case. + # e.g. 
`im.tuple_get(0, im.make_tuple(a, b), domain=domain)` + domain = (*domain, *(DomainAccessDescriptor.NEVER for _ in range(len(expr.args) - len(domain)))) for i, arg in enumerate(expr.args): - infered_arg_expr, actual_domains_arg = infer_expr( - arg, domain[i], offset_provider, symbolic_domain_sizes, allowed_unknown_domains - ) + infered_arg_expr, actual_domains_arg = infer_expr(arg, domain[i], **kwargs) infered_args_expr.append(infered_arg_expr) actual_domains = _merge_domains(actual_domains, actual_domains_arg) result_expr = im.call(expr.fun)(*infered_args_expr) @@ -299,19 +318,17 @@ def _infer_make_tuple( def _infer_tuple_get( expr: itir.Expr, domain: DOMAIN, - offset_provider: common.OffsetProvider, - symbolic_domain_sizes: Optional[dict[str, str]], - allowed_unknown_domains: list[str], + **kwargs: Unpack[InferenceOptions], ) -> tuple[itir.Expr, ACCESSED_DOMAINS]: assert cpm.is_call_to(expr, "tuple_get") actual_domains: ACCESSED_DOMAINS = {} idx_expr, tuple_arg = expr.args assert isinstance(idx_expr, itir.Literal) idx = int(idx_expr.value) - tuple_domain = tuple(None if i != idx else domain for i in range(idx + 1)) - infered_arg_expr, actual_domains_arg = infer_expr( - tuple_arg, tuple_domain, offset_provider, symbolic_domain_sizes, allowed_unknown_domains + tuple_domain = tuple( + DomainAccessDescriptor.NEVER if i != idx else domain for i in range(idx + 1) ) + infered_arg_expr, actual_domains_arg = infer_expr(tuple_arg, tuple_domain, **kwargs) infered_args_expr = im.tuple_get(idx, infered_arg_expr) actual_domains = _merge_domains(actual_domains, actual_domains_arg) @@ -321,18 +338,14 @@ def _infer_tuple_get( def _infer_if( expr: itir.Expr, domain: DOMAIN, - offset_provider: common.OffsetProvider, - symbolic_domain_sizes: Optional[dict[str, str]], - allowed_unknown_domains: list[str], + **kwargs: Unpack[InferenceOptions], ) -> tuple[itir.Expr, ACCESSED_DOMAINS]: assert cpm.is_call_to(expr, "if_") infered_args_expr = [] actual_domains: ACCESSED_DOMAINS = {} cond, true_val, false_val = expr.args for arg in [true_val, false_val]: - infered_arg_expr, actual_domains_arg = infer_expr( - arg, domain, offset_provider, symbolic_domain_sizes, allowed_unknown_domains - ) + infered_arg_expr, actual_domains_arg = infer_expr(arg, domain, **kwargs) infered_args_expr.append(infered_arg_expr) actual_domains = _merge_domains(actual_domains, actual_domains_arg) result_expr = im.call(expr.fun)(cond, *infered_args_expr) @@ -342,34 +355,22 @@ def _infer_if( def _infer_expr( expr: itir.Expr, domain: DOMAIN, - offset_provider: common.OffsetProvider, - symbolic_domain_sizes: Optional[dict[str, str]], - allowed_unknown_domains: list[str], + **kwargs: Unpack[InferenceOptions], ) -> tuple[itir.Expr, ACCESSED_DOMAINS]: if isinstance(expr, itir.SymRef): return expr, {str(expr.id): domain} elif isinstance(expr, itir.Literal): return expr, {} elif cpm.is_applied_as_fieldop(expr): - return _infer_as_fieldop( - expr, domain, offset_provider, symbolic_domain_sizes, allowed_unknown_domains - ) + return _infer_as_fieldop(expr, domain, **kwargs) elif cpm.is_let(expr): - return _infer_let( - expr, domain, offset_provider, symbolic_domain_sizes, allowed_unknown_domains - ) + return _infer_let(expr, domain, **kwargs) elif cpm.is_call_to(expr, "make_tuple"): - return _infer_make_tuple( - expr, domain, offset_provider, symbolic_domain_sizes, allowed_unknown_domains - ) + return _infer_make_tuple(expr, domain, **kwargs) elif cpm.is_call_to(expr, "tuple_get"): - return _infer_tuple_get( - expr, domain, offset_provider, 
symbolic_domain_sizes, allowed_unknown_domains - ) + return _infer_tuple_get(expr, domain, **kwargs) elif cpm.is_call_to(expr, "if_"): - return _infer_if( - expr, domain, offset_provider, symbolic_domain_sizes, allowed_unknown_domains - ) + return _infer_if(expr, domain, **kwargs) elif ( cpm.is_call_to(expr, itir.ARITHMETIC_BUILTINS) or cpm.is_call_to(expr, itir.TYPEBUILTINS) @@ -383,9 +384,10 @@ def _infer_expr( def infer_expr( expr: itir.Expr, domain: DOMAIN, + *, offset_provider: common.OffsetProvider, symbolic_domain_sizes: Optional[dict[str, str]] = None, - allowed_unknown_domains: Optional[list[str]] = None, + allow_uninferred: bool = False, ) -> tuple[itir.Expr, ACCESSED_DOMAINS]: """ Infer the domain of all field subexpressions of `expr`. @@ -398,43 +400,33 @@ def infer_expr( - domain: The domain `expr` is read at. - symbolic_domain_sizes: A dictionary mapping axes names, e.g., `I`, `Vertex`, to a symbol name that evaluates to the length of that axis. - - allowed_unknown_domains: A list of references (as strings) for which the domain does not need - to be inferred, but can be `UNKNOWN`. This is used by `infer_program` to allow dynamic shifts - on program inputs. + - allow_uninferred: Allow expressions whose domain is either unknown (e.g. because of a + dynamic shift) or empty. Returns: A tuple containing the inferred expression with all applied `as_fieldop` (that are accessed) having a domain argument now, and a dictionary mapping symbol names referenced in `expr` to domain they are accessed at. """ - allowed_unknown_domains = allowed_unknown_domains or [] - expr, accessed_domains = _infer_expr( - expr, domain, offset_provider, symbolic_domain_sizes, allowed_unknown_domains + expr, + domain, + offset_provider=offset_provider, + symbolic_domain_sizes=symbolic_domain_sizes, + allow_uninferred=allow_uninferred, ) expr.annex.domain = domain - if any( - p not in allowed_unknown_domains and d == "UNKNOWN" for p, d in accessed_domains.items() - ): - raise ValueError("Some accessed domains are unknown, e.g. because of a dynamic shift.") - return expr, accessed_domains def _infer_stmt( stmt: itir.Stmt, - offset_provider: common.OffsetProvider, - symbolic_domain_sizes: Optional[dict[str, str]], - allowed_unknown_domains: list[str], + **kwargs: Unpack[InferenceOptions], ): if isinstance(stmt, itir.SetAt): transformed_call, _ = infer_expr( - stmt.expr, - domain_utils.SymbolicDomain.from_expr(stmt.domain), - offset_provider, - symbolic_domain_sizes, - allowed_unknown_domains=allowed_unknown_domains, + stmt.expr, domain_utils.SymbolicDomain.from_expr(stmt.domain), **kwargs ) return itir.SetAt( @@ -445,22 +437,18 @@ def _infer_stmt( elif isinstance(stmt, itir.IfStmt): return itir.IfStmt( cond=stmt.cond, - true_branch=[ - _infer_stmt(c, offset_provider, symbolic_domain_sizes, allowed_unknown_domains) - for c in stmt.true_branch - ], - false_branch=[ - _infer_stmt(c, offset_provider, symbolic_domain_sizes, allowed_unknown_domains) - for c in stmt.false_branch - ], + true_branch=[_infer_stmt(c, **kwargs) for c in stmt.true_branch], + false_branch=[_infer_stmt(c, **kwargs) for c in stmt.false_branch], ) raise ValueError(f"Unsupported stmt: {stmt}") def infer_program( program: itir.Program, + *, offset_provider: common.OffsetProvider, symbolic_domain_sizes: Optional[dict[str, str]] = None, + allow_uninferred: bool = False, ) -> itir.Program: """ Infer the domain of all field subexpressions inside a program. 
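+    If `allow_uninferred` is set, expressions whose domain is unknown (e.g. because
+    of a dynamic shift) or never accessed are tolerated.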
@@ -478,7 +466,10 @@ def infer_program( declarations=program.declarations, body=[ _infer_stmt( - stmt, offset_provider, symbolic_domain_sizes, [param.id for param in program.params] + stmt, + offset_provider=offset_provider, + symbolic_domain_sizes=symbolic_domain_sizes, + allow_uninferred=allow_uninferred, ) for stmt in program.body ], diff --git a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py index 7b6214fb1b..438a5a3421 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py @@ -179,7 +179,7 @@ def expression_test_cases(): im.as_fieldop(im.lambda_("x")(im.deref("x")))( im.ref("inp", ts.DeferredType(constraint=None)) ), - ts.DeferredType(constraint=None), + ts.DeferredType(constraint=ts.TupleType), ), # if in field-view scope ( diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py index 141091b450..c7272098eb 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py @@ -72,7 +72,7 @@ def setup_test_as_fieldop( def run_test_program( testee: itir.Program, expected: itir.Program, offset_provider: common.OffsetProvider ) -> None: - actual_program = infer_domain.infer_program(testee, offset_provider) + actual_program = infer_domain.infer_program(testee, offset_provider=offset_provider) folded_program = constant_fold_domain_exprs(actual_program) assert folded_program == expected @@ -85,12 +85,14 @@ def run_test_expr( expected_domains: dict[str, itir.Expr | dict[str | Dimension, tuple[itir.Expr, itir.Expr]]], offset_provider: common.OffsetProvider, symbolic_domain_sizes: Optional[dict[str, str]] = None, + allow_uninferred: bool = False, ): actual_call, actual_domains = infer_domain.infer_expr( testee, domain_utils.SymbolicDomain.from_expr(domain), - offset_provider, - symbolic_domain_sizes, + offset_provider=offset_provider, + symbolic_domain_sizes=symbolic_domain_sizes, + allow_uninferred=allow_uninferred, ) folded_call = constant_fold_domain_exprs(actual_call) folded_domains = constant_fold_accessed_domains(actual_domains) if actual_domains else None @@ -100,10 +102,8 @@ def run_test_expr( def canonicalize_domain(d): if isinstance(d, dict): return im.domain(grid_type, d) - elif isinstance(d, itir.FunCall): + elif isinstance(d, (itir.FunCall, infer_domain.DomainAccessDescriptor)): return d - elif d is None: - return None raise AssertionError() expected_domains = {ref: canonicalize_domain(d) for ref, d in expected_domains.items()} @@ -126,8 +126,10 @@ def constant_fold_domain_exprs(arg: itir.Node) -> itir.Node: def constant_fold_accessed_domains( domains: infer_domain.ACCESSED_DOMAINS, ) -> infer_domain.ACCESSED_DOMAINS: - def fold_domain(domain: domain_utils.SymbolicDomain | None): - if domain is None: + def fold_domain( + domain: domain_utils.SymbolicDomain | Literal[infer_domain.DomainAccessDescriptor.NEVER], + ): + if isinstance(domain, infer_domain.DomainAccessDescriptor): return domain return constant_fold_domain_exprs(domain.as_expr()) @@ -150,7 +152,7 @@ def translate_domain( shift_list = [item for sublist in shift_tuples for item in sublist] translated_domain_expr = domain_utils.SymbolicDomain.from_expr(domain).translate( - shift_list, 
offset_provider + shift_list, offset_provider=offset_provider ) return constant_fold_domain_exprs(translated_domain_expr.as_expr()) @@ -336,7 +338,7 @@ def test_nested_stencils(offset_provider): "in_field2": translate_domain(domain, {"Ioff": 0, "Joff": -2}, offset_provider), } actual_call, actual_domains = infer_domain.infer_expr( - testee, domain_utils.SymbolicDomain.from_expr(domain), offset_provider + testee, domain_utils.SymbolicDomain.from_expr(domain), offset_provider=offset_provider ) folded_domains = constant_fold_accessed_domains(actual_domains) folded_call = constant_fold_domain_exprs(actual_call) @@ -380,7 +382,7 @@ def test_nested_stencils_n_times(offset_provider, iterations): } actual_call, actual_domains = infer_domain.infer_expr( - testee, domain_utils.SymbolicDomain.from_expr(domain), offset_provider + testee, domain_utils.SymbolicDomain.from_expr(domain), offset_provider=offset_provider ) folded_domains = constant_fold_accessed_domains(actual_domains) @@ -393,7 +395,10 @@ def test_unused_input(offset_provider): stencil = im.lambda_("arg0", "arg1")(im.deref("arg0")) domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) - expected_domains = {"in_field1": {IDim: (0, 11)}, "in_field2": None} + expected_domains = { + "in_field1": {IDim: (0, 11)}, + "in_field2": infer_domain.DomainAccessDescriptor.NEVER, + } testee, expected = setup_test_as_fieldop( stencil, domain, @@ -405,7 +410,7 @@ def test_let_unused_field(offset_provider): testee = im.let("a", "c")("b") domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) expected = im.let("a", "c")("b") - expected_domains = {"b": {IDim: (0, 11)}, "c": None} + expected_domains = {"b": {IDim: (0, 11)}, "c": infer_domain.DomainAccessDescriptor.NEVER} run_test_expr(testee, expected, domain, expected_domains, offset_provider) @@ -518,7 +523,7 @@ def test_cond(offset_provider): expected = im.if_(cond, expected_field_1, expected_field_2) actual_call, actual_domains = infer_domain.infer_expr( - testee, domain_utils.SymbolicDomain.from_expr(domain), offset_provider + testee, domain_utils.SymbolicDomain.from_expr(domain), offset_provider=offset_provider ) folded_domains = constant_fold_accessed_domains(actual_domains) @@ -575,7 +580,7 @@ def test_let(offset_provider): expected_domains_sym = {"in_field": translate_domain(domain, {"Ioff": 2}, offset_provider)} actual_call2, actual_domains2 = infer_domain.infer_expr( - testee2, domain_utils.SymbolicDomain.from_expr(domain), offset_provider + testee2, domain_utils.SymbolicDomain.from_expr(domain), offset_provider=offset_provider ) folded_domains2 = constant_fold_accessed_domains(actual_domains2) folded_call2 = constant_fold_domain_exprs(actual_call2) @@ -799,7 +804,7 @@ def test_make_tuple(offset_provider): domain_utils.SymbolicDomain.from_expr(domain1), domain_utils.SymbolicDomain.from_expr(domain2), ), - offset_provider, + offset_provider=offset_provider, ) assert expected == actual @@ -811,13 +816,13 @@ def test_tuple_get_1_make_tuple(offset_provider): domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) expected = im.tuple_get(1, im.make_tuple(im.ref("a"), im.ref("b"), im.ref("c"))) expected_domains = { - "a": None, + "a": infer_domain.DomainAccessDescriptor.NEVER, "b": im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}), - "c": None, + "c": infer_domain.DomainAccessDescriptor.NEVER, } actual, actual_domains = infer_domain.infer_expr( - testee, domain_utils.SymbolicDomain.from_expr(domain), offset_provider + testee, domain_utils.SymbolicDomain.from_expr(domain), 
offset_provider=offset_provider ) assert expected == actual @@ -829,7 +834,7 @@ def test_tuple_get_1_nested_make_tuple(offset_provider): domain1 = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) domain2 = im.domain(common.GridType.CARTESIAN, {IDim: (0, 12)}) expected = im.tuple_get(1, im.make_tuple(im.ref("a"), im.make_tuple(im.ref("b"), im.ref("c")))) - expected_domains = {"a": None, "b": domain1, "c": domain2} + expected_domains = {"a": infer_domain.DomainAccessDescriptor.NEVER, "b": domain1, "c": domain2} actual, actual_domains = infer_domain.infer_expr( testee, @@ -837,7 +842,7 @@ def test_tuple_get_1_nested_make_tuple(offset_provider): domain_utils.SymbolicDomain.from_expr(domain1), domain_utils.SymbolicDomain.from_expr(domain2), ), - offset_provider, + offset_provider=offset_provider, ) assert expected == actual @@ -848,14 +853,18 @@ def test_tuple_get_let_arg_make_tuple(offset_provider): testee = im.tuple_get(1, im.let("a", im.make_tuple(im.ref("b"), im.ref("c")))("d")) domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) expected = im.tuple_get(1, im.let("a", im.make_tuple(im.ref("b"), im.ref("c")))("d")) - expected_domains = {"b": None, "c": None, "d": (None, domain)} + expected_domains = { + "b": infer_domain.DomainAccessDescriptor.NEVER, + "c": infer_domain.DomainAccessDescriptor.NEVER, + "d": (infer_domain.DomainAccessDescriptor.NEVER, domain), + } actual, actual_domains = infer_domain.infer_expr( testee, domain_utils.SymbolicDomain.from_expr( im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) ), - offset_provider, + offset_provider=offset_provider, ) assert expected == actual @@ -866,12 +875,16 @@ def test_tuple_get_let_make_tuple(offset_provider): testee = im.tuple_get(1, im.let("a", "b")(im.make_tuple(im.ref("c"), im.ref("d")))) domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) expected = im.tuple_get(1, im.let("a", "b")(im.make_tuple(im.ref("c"), im.ref("d")))) - expected_domains = {"c": None, "d": domain, "b": None} + expected_domains = { + "c": infer_domain.DomainAccessDescriptor.NEVER, + "d": domain, + "b": infer_domain.DomainAccessDescriptor.NEVER, + } actual, actual_domains = infer_domain.infer_expr( testee, domain_utils.SymbolicDomain.from_expr(domain), - offset_provider, + offset_provider=offset_provider, ) assert expected == actual @@ -899,7 +912,7 @@ def test_nested_make_tuple(offset_provider): ), domain_utils.SymbolicDomain.from_expr(domain3), ), - offset_provider, + offset_provider=offset_provider, ) assert expected == actual @@ -910,10 +923,10 @@ def test_tuple_get_1(offset_provider): testee = im.tuple_get(1, im.ref("a")) domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) expected = im.tuple_get(1, im.ref("a")) - expected_domains = {"a": (None, domain)} + expected_domains = {"a": (infer_domain.DomainAccessDescriptor.NEVER, domain)} actual, actual_domains = infer_domain.infer_expr( - testee, domain_utils.SymbolicDomain.from_expr(domain), offset_provider + testee, domain_utils.SymbolicDomain.from_expr(domain), offset_provider=offset_provider ) assert expected == actual @@ -933,7 +946,7 @@ def test_domain_tuple(offset_provider): domain_utils.SymbolicDomain.from_expr(domain1), domain_utils.SymbolicDomain.from_expr(domain2), ), - offset_provider, + offset_provider=offset_provider, ) assert expected == actual @@ -949,7 +962,7 @@ def test_as_fieldop_tuple_get(offset_provider): expected_domains = {"a": (domain, domain)} actual, actual_domains = infer_domain.infer_expr( - testee, domain_utils.SymbolicDomain.from_expr(domain), 
offset_provider + testee, domain_utils.SymbolicDomain.from_expr(domain), offset_provider=offset_provider ) assert expected == actual @@ -969,7 +982,7 @@ def test_make_tuple_2tuple_get(offset_provider): domain_utils.SymbolicDomain.from_expr(domain1), domain_utils.SymbolicDomain.from_expr(domain2), ), - offset_provider, + offset_provider=offset_provider, ) assert expected == actual @@ -986,7 +999,7 @@ def test_make_tuple_non_tuple_domain(offset_provider): expected_domains = {"in_field1": domain, "in_field2": domain} actual, actual_domains = infer_domain.infer_expr( - testee, domain_utils.SymbolicDomain.from_expr(domain), offset_provider + testee, domain_utils.SymbolicDomain.from_expr(domain), offset_provider=offset_provider ) assert expected == actual @@ -1000,7 +1013,7 @@ def test_arithmetic_builtin(offset_provider): expected_domains = {} actual_call, actual_domains = infer_domain.infer_expr( - testee, domain_utils.SymbolicDomain.from_expr(domain), offset_provider + testee, domain_utils.SymbolicDomain.from_expr(domain), offset_provider=offset_provider ) folded_call = constant_fold_domain_exprs(actual_call) @@ -1044,3 +1057,40 @@ def test_symbolic_domain_sizes(unstructured_offset_provider): unstructured_offset_provider, symbolic_domain_sizes, ) + + +@pytest.mark.parametrize("allow_uninferred", [True, False]) +def test_unknown_domain(offset_provider, allow_uninferred: bool): + stencil = im.lambda_("arg0", "arg1")(im.deref(im.shift("Ioff", im.deref("arg1"))("arg0"))) + domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 10)}) + expected_domains = { + "in_field1": infer_domain.DomainAccessDescriptor.UNKNOWN, + "in_field2": {IDim: (0, 10)}, + } + testee, expected = setup_test_as_fieldop(stencil, domain) + if allow_uninferred: + run_test_expr(testee, expected, domain, expected_domains, offset_provider, None, True) + else: + with pytest.raises(ValueError, match="Dynamic shifts not allowed"): + run_test_expr(testee, expected, domain, expected_domains, offset_provider, None, False) + + +def test_never_accessed_domain(offset_provider): + stencil = im.lambda_("arg0", "arg1")(im.deref("arg0")) + domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 10)}) + expected_domains = { + "in_field1": {IDim: (0, 10)}, + "in_field2": infer_domain.DomainAccessDescriptor.NEVER, + } + testee, expected = setup_test_as_fieldop(stencil, domain) + run_test_expr(testee, expected, domain, expected_domains, offset_provider) + + +def test_never_accessed_domain_tuple(offset_provider): + testee = im.tuple_get(0, im.make_tuple("in_field1", "in_field2")) + domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 10)}) + expected_domains = { + "in_field1": {IDim: (0, 10)}, + "in_field2": infer_domain.DomainAccessDescriptor.NEVER, + } + run_test_expr(testee, testee, domain, expected_domains, offset_provider) From 7b67ecb0edc6de68e2936dbf21a16456b00a7c33 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Wed, 27 Nov 2024 21:17:53 +0100 Subject: [PATCH 099/150] Small cleanup --- .../next/iterator/transforms/infer_domain.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/infer_domain.py b/src/gt4py/next/iterator/transforms/infer_domain.py index 56b626cc61..6e9e6a6cac 100644 --- a/src/gt4py/next/iterator/transforms/infer_domain.py +++ b/src/gt4py/next/iterator/transforms/infer_domain.py @@ -191,13 +191,15 @@ def _infer_as_fieldop( ) -> tuple[itir.FunCall, ACCESSED_DOMAINS]: assert isinstance(applied_fieldop, itir.FunCall) assert 
cpm.is_call_to(applied_fieldop.fun, "as_fieldop") - if target_domain is DomainAccessDescriptor.NEVER: - raise ValueError("'target_domain' cannot be 'NEVER'.") + if not allow_uninferred and target_domain is DomainAccessDescriptor.NEVER: + raise ValueError("'target_domain' cannot be 'NEVER' unless `allow_uninferred=True`.") # FIXME[#1582](tehrengruber): Temporary solution for `tuple_get` on scan result. See `test_solve_triag`. if isinstance(target_domain, tuple): target_domain = _domain_union(*flatten_nested_tuple(target_domain)) # type: ignore[arg-type] # mypy not smart enough - if not isinstance(target_domain, domain_utils.SymbolicDomain): - raise ValueError("'target_domain' needs to be a 'domain_utils.SymbolicDomain'.") + if not isinstance(target_domain, (domain_utils.SymbolicDomain, DomainAccessDescriptor)): + raise ValueError( + "'target_domain' needs to be a 'domain_utils.SymbolicDomain' or a 'DomainAccessDescriptor'." + ) # `as_fieldop(stencil)(inputs...)` stencil, inputs = applied_fieldop.fun.args[0], applied_fieldop.args @@ -238,7 +240,10 @@ def _infer_as_fieldop( accessed_domains = _merge_domains(accessed_domains, accessed_domains_tmp) - target_domain_expr = domain_utils.SymbolicDomain.as_expr(target_domain) + if not isinstance(target_domain, DomainAccessDescriptor): + target_domain_expr = domain_utils.SymbolicDomain.as_expr(target_domain) + else: + target_domain_expr = None transformed_call = im.as_fieldop(stencil, target_domain_expr)(*transformed_inputs) accessed_domains_without_tmp = { From 438eb6b70fd9a74d30c4f0a965de80d2a42c12c0 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Wed, 27 Nov 2024 21:29:20 +0100 Subject: [PATCH 100/150] Small cleanup --- src/gt4py/next/iterator/transforms/infer_domain.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/gt4py/next/iterator/transforms/infer_domain.py b/src/gt4py/next/iterator/transforms/infer_domain.py index 6e9e6a6cac..226bd75ccb 100644 --- a/src/gt4py/next/iterator/transforms/infer_domain.py +++ b/src/gt4py/next/iterator/transforms/infer_domain.py @@ -150,7 +150,7 @@ def _merge_domains( def _extract_accessed_domains( stencil: itir.Expr, input_ids: list[str], - target_domain: domain_utils.SymbolicDomain, + target_domain: domain_utils.SymbolicDomain | DomainAccessDescriptor, offset_provider: common.OffsetProvider, symbolic_domain_sizes: Optional[dict[str, str]], allow_uninferred: bool, @@ -172,6 +172,8 @@ def _extract_accessed_domains( domain_utils.SymbolicDomain.translate( target_domain, shift, offset_provider, symbolic_domain_sizes ) + if not isinstance(target_domain, DomainAccessDescriptor) + else target_domain for shift in shifts_list ] accessed_domains[in_field_id] = _domain_union( From 752e08aa12086dbef5c393f32d1ef8e94e4fbb2e Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Wed, 27 Nov 2024 21:53:48 +0100 Subject: [PATCH 101/150] Small fix --- src/gt4py/next/iterator/transforms/infer_domain.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/infer_domain.py b/src/gt4py/next/iterator/transforms/infer_domain.py index 226bd75ccb..77014c082d 100644 --- a/src/gt4py/next/iterator/transforms/infer_domain.py +++ b/src/gt4py/next/iterator/transforms/infer_domain.py @@ -163,8 +163,6 @@ def _extract_accessed_domains( # TODO(tehrengruber): Dynamic shifts are not supported by `SymbolicDomain.translate`. Use # special `UNKNOWN` marker for them until we have implemented a proper solution. 
if any(s == trace_shifts.Sentinel.VALUE for shift in shifts_list for s in shift): - if not allow_uninferred: - raise ValueError("Dynamic shifts not allowed if `allow_uninferred=False`") accessed_domains[in_field_id] = DomainAccessDescriptor.UNKNOWN continue From 395e6ee90007116571c6df3a13fb36d6dd0e95b7 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Wed, 27 Nov 2024 22:57:18 +0100 Subject: [PATCH 102/150] Small fix --- src/gt4py/next/iterator/transforms/infer_domain.py | 4 ++-- .../unit_tests/iterator_tests/test_type_inference.py | 2 +- .../transforms_tests/test_domain_inference.py | 9 ++------- 3 files changed, 5 insertions(+), 10 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/infer_domain.py b/src/gt4py/next/iterator/transforms/infer_domain.py index 77014c082d..3809fb43d7 100644 --- a/src/gt4py/next/iterator/transforms/infer_domain.py +++ b/src/gt4py/next/iterator/transforms/infer_domain.py @@ -405,8 +405,8 @@ def infer_expr( - domain: The domain `expr` is read at. - symbolic_domain_sizes: A dictionary mapping axes names, e.g., `I`, `Vertex`, to a symbol name that evaluates to the length of that axis. - - allow_uninferred: Allow expressions whose domain is either unknown (e.g. because of a - dynamic shift) or empty. + - allow_uninferred: Allow `as_fieldop` expressions whose domain is either unknown (e.g. + because of a dynamic shift) or empty. Returns: A tuple containing the inferred expression with all applied `as_fieldop` (that are accessed) diff --git a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py index 2e47513a7f..65a5b5888d 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py @@ -179,7 +179,7 @@ def expression_test_cases(): im.as_fieldop(im.lambda_("x")(im.deref("x")))( im.ref("inp", ts.DeferredType(constraint=None)) ), - ts.DeferredType(constraint=ts.TupleType), + ts.DeferredType(constraint=None), ), # if in field-view scope ( diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py index de7fea5101..4883718744 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py @@ -1060,8 +1060,7 @@ def test_symbolic_domain_sizes(unstructured_offset_provider): ) -@pytest.mark.parametrize("allow_uninferred", [True, False]) -def test_unknown_domain(offset_provider, allow_uninferred: bool): +def test_unknown_domain(offset_provider): stencil = im.lambda_("arg0", "arg1")(im.deref(im.shift("Ioff", im.deref("arg1"))("arg0"))) domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 10)}) expected_domains = { @@ -1069,11 +1068,7 @@ def test_unknown_domain(offset_provider, allow_uninferred: bool): "in_field2": {IDim: (0, 10)}, } testee, expected = setup_test_as_fieldop(stencil, domain) - if allow_uninferred: - run_test_expr(testee, expected, domain, expected_domains, offset_provider, None, True) - else: - with pytest.raises(ValueError, match="Dynamic shifts not allowed"): - run_test_expr(testee, expected, domain, expected_domains, offset_provider, None, False) + run_test_expr(testee, expected, domain, expected_domains, offset_provider, None) def test_never_accessed_domain(offset_provider): From ff00b55c775c18041e05dfd11772e92db166ff9e Mon Sep 
17 00:00:00 2001 From: Till Ehrengruber Date: Wed, 27 Nov 2024 22:57:26 +0100 Subject: [PATCH 103/150] Small fix --- .../iterator_tests/transforms_tests/test_domain_inference.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py index 4883718744..7a7a307901 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py @@ -1068,7 +1068,7 @@ def test_unknown_domain(offset_provider): "in_field2": {IDim: (0, 10)}, } testee, expected = setup_test_as_fieldop(stencil, domain) - run_test_expr(testee, expected, domain, expected_domains, offset_provider, None) + run_test_expr(testee, expected, domain, expected_domains, offset_provider) def test_never_accessed_domain(offset_provider): From 34d604034dd6bf52d7f2ac07164f5a24d875cd18 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Sun, 1 Dec 2024 19:58:08 +0100 Subject: [PATCH 104/150] Non-tree-size-increasing collapse tuple on ifs --- .../iterator/transforms/collapse_tuple.py | 152 +++++++++++++++--- .../next/iterator/transforms/pass_manager.py | 14 +- .../next/iterator/type_system/inference.py | 76 ++++++--- .../iterator_tests/test_type_inference.py | 19 +++ .../transforms_tests/test_collapse_tuple.py | 50 +++++- 5 files changed, 269 insertions(+), 42 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/collapse_tuple.py b/src/gt4py/next/iterator/transforms/collapse_tuple.py index e71a24127f..b71470f3e7 100644 --- a/src/gt4py/next/iterator/transforms/collapse_tuple.py +++ b/src/gt4py/next/iterator/transforms/collapse_tuple.py @@ -28,10 +28,11 @@ from gt4py.next.type_system import type_info, type_specifications as ts -def _with_altered_arg(node: ir.FunCall, arg_idx: int, new_arg: ir.Expr): +def _with_altered_arg(node: ir.FunCall, arg_idx: int, new_arg: ir.Expr | str): """Given a itir.FunCall return a new call with one of its argument replaced.""" return ir.FunCall( - fun=node.fun, args=[arg if i != arg_idx else new_arg for i, arg in enumerate(node.args)] + fun=node.fun, + args=[arg if i != arg_idx else im.ensure_expr(new_arg) for i, arg in enumerate(node.args)], ) @@ -47,6 +48,32 @@ def _is_trivial_make_tuple_call(node: ir.Expr): return True +def _is_trivial_or_tuple_thereof_expr(node: ir.Node) -> bool: + """ + Return `true` if the expr is a trivial expression or tuple thereof. + + >>> _is_trivial_or_tuple_thereof_expr(im.make_tuple("a", "b")) + True + >>> _is_trivial_or_tuple_thereof_expr(im.tuple_get(1, "a")) + True + >>> _is_trivial_or_tuple_thereof_expr( + ... im.let("t", im.make_tuple("a", "b"))(im.tuple_get(1, "t")) + ... 
) + True + """ + if cpm.is_call_to(node, "make_tuple"): + return all(_is_trivial_or_tuple_thereof_expr(arg) for arg in node.args) + if cpm.is_call_to(node, "tuple_get"): + return _is_trivial_or_tuple_thereof_expr(node.args[1]) + if isinstance(node, (ir.SymRef, ir.Literal)): + return True + if cpm.is_let(node): + return _is_trivial_or_tuple_thereof_expr(node.fun.expr) and all( # type: ignore[attr-defined] # ensured by is_let + _is_trivial_or_tuple_thereof_expr(arg) for arg in node.args + ) + return False + + # TODO(tehrengruber): Conceptually the structure of this pass makes sense: Visit depth first, # transform each node until no transformations apply anymore, whenever a node is to be transformed # go through all available transformation and apply them. However the final result here still @@ -76,28 +103,41 @@ class Flag(enum.Flag): #: `let(tup, {trivial_expr1, trivial_expr2})(foo(tup))` #: -> `foo({trivial_expr1, trivial_expr2})` INLINE_TRIVIAL_MAKE_TUPLE = enum.auto() + #: Similar as `PROPAGATE_TO_IF_ON_TUPLES`, but propagates in the opposite direction, e.g. + #: into the tree, allowing removal if tuple expressions accross `if_` calls without + #: increasing the size of the tree. This is particullary important for `if` statements + #: in the frontend, where outwards propagation can have devestating effects on the tree + #: size. In particular boundary conditions inside of scans, e.g. something like: + #: ``` + #: if level == 0 then + #: make_tuple("a", "b") + #: else + #: if level == 1 then + #: make_tuple("b", "c") + #: ` else + #: make_tuple("d", "e") + #: ``` + #: is problematic, since `PROPAGATE_TO_IF_ON_TUPLES` would also generate a branch for + #: the condition `level == 0` and `level == 1` which can never occur. Note that this + #: transformation is not mutally exclusive to `PROPAGATE_TO_IF_ON_TUPLES`. + PROPAGATE_TO_IF_ON_TUPLES_CPS = enum.auto() #: `(if cond then {1, 2} else {3, 4})[0]` -> `if cond then {1, 2}[0] else {3, 4}[0]` PROPAGATE_TO_IF_ON_TUPLES = enum.auto() #: `let((a, let(b, 1)(a_val)))(a)`-> `let(b, 1)(let(a, a_val)(a))` PROPAGATE_NESTED_LET = enum.auto() - #: `let(a, 1)(a)` -> `1` + #: `let(a, 1)(a)` -> `1` or `let(a, b)(f(a))` -> `f(a)` INLINE_TRIVIAL_LET = enum.auto() @classmethod def all(self) -> CollapseTuple.Flag: return functools.reduce(operator.or_, self.__members__.values()) + uids: eve_utils.UIDGenerator ignore_tuple_size: bool flags: Flag = Flag.all() # noqa: RUF009 [function-call-in-dataclass-default-argument] PRESERVED_ANNEX_ATTRS = ("type",) - # we use one UID generator per instance such that the generated ids are - # stable across multiple runs (required for caching to properly work) - _letify_make_tuple_uids: eve_utils.UIDGenerator = dataclasses.field( - init=False, repr=False, default_factory=lambda: eve_utils.UIDGenerator(prefix="_tuple_el") - ) - @classmethod def apply( cls, @@ -111,6 +151,7 @@ def apply( flags: Optional[Flag] = None, # allow sym references without a symbol declaration, mostly for testing allow_undeclared_symbols: bool = False, + uids: Optional[eve_utils.UIDGenerator] = None, ) -> ir.Node: """ Simplifies `make_tuple`, `tuple_get` calls. 
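A minimal usage sketch of the new CPS flag, essentially the `test_if_make_tuple_reorder_cps` case added further down in this series (import paths assumed from the surrounding diffs):

```
from gt4py.next.iterator.ir_utils import ir_makers as im
from gt4py.next.iterator.transforms.collapse_tuple import CollapseTuple

# tuple elements are reordered across the `if_` without duplicating any
# surrounding expression
testee = im.let("t", im.if_(True, im.make_tuple(1, 2), im.make_tuple(3, 4)))(
    im.make_tuple(im.tuple_get(1, "t"), im.tuple_get(0, "t"))
)
actual = CollapseTuple.apply(
    testee,
    flags=~CollapseTuple.Flag.PROPAGATE_TO_IF_ON_TUPLES,  # CPS variant stays enabled
    allow_undeclared_symbols=True,
    within_stencil=False,
)
assert actual == im.if_(True, im.make_tuple(2, 1), im.make_tuple(4, 3))
```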
@@ -127,6 +168,7 @@ def apply( """ flags = flags or cls.flags offset_provider_type = offset_provider_type or {} + uids = uids or eve_utils.UIDGenerator() if isinstance(node, (ir.Program, ir.FencilDefinition)): within_stencil = False @@ -145,6 +187,7 @@ def apply( new_node = cls( ignore_tuple_size=ignore_tuple_size, flags=flags, + uids=uids, ).visit(node, within_stencil=within_stencil) # inline to remove left-overs from LETIFY_MAKE_TUPLE_ELEMENTS. this is important @@ -185,6 +228,8 @@ def transform(self, node: ir.Node, **kwargs) -> Optional[ir.Node]: method = getattr(self, f"transform_{transformation.name.lower()}") result = method(node, **kwargs) if result is not None: + assert result is not node + itir_type_inference.reinfer(result) return result return None @@ -263,13 +308,13 @@ def transform_letify_make_tuple_elements(self, node: ir.FunCall, **kwargs) -> Op if node.fun == ir.SymRef(id="make_tuple"): # `make_tuple(expr1, expr1)` # -> `let((_tuple_el_1, expr1), (_tuple_el_2, expr2))(make_tuple(_tuple_el_1, _tuple_el_2))` - bound_vars: dict[str, ir.Expr] = {} + bound_vars: dict[ir.Sym, ir.Expr] = {} new_args: list[ir.Expr] = [] for arg in node.args: if cpm.is_call_to(node, "make_tuple") and not _is_trivial_make_tuple_call(node): - el_name = self._letify_make_tuple_uids.sequential_id() - new_args.append(im.ref(el_name)) - bound_vars[el_name] = arg + el_name = self.uids.sequential_id(prefix="__ct_el") + new_args.append(im.ref(el_name, arg.type)) + bound_vars[im.sym(el_name, arg.type)] = arg else: new_args.append(arg) @@ -312,6 +357,73 @@ def transform_propagate_to_if_on_tuples(self, node: ir.FunCall, **kwargs) -> Opt return im.if_(cond, new_true_branch, new_false_branch) return None + def transform_propagate_to_if_on_tuples_cps( + self, node: ir.FunCall, **kwargs + ) -> Optional[ir.Node]: + if not cpm.is_call_to(node, "if_"): + for i, arg in enumerate(node.args): + if cpm.is_call_to(arg, "if_"): + itir_type_inference.reinfer(arg) + if not any(isinstance(branch.type, ts.TupleType) for branch in arg.args[1:]): + continue + + cond, true_branch, false_branch = arg.args + tuple_type: ts.TupleType = true_branch.type # type: ignore[assignment] # type ensured above + tuple_len = len(tuple_type.types) + itir_type_inference.reinfer(node) + assert node.type + + # transform function into continuation-passing-style + f_type = ts.FunctionType( + pos_only_args=tuple_type.types, + pos_or_kw_args={}, + kw_only_args={}, + returns=node.type, + ) + f_params = [ + im.sym(self.uids.sequential_id(prefix="__ct_el_cps"), type_) + for type_ in tuple_type.types + ] + f_args = [im.ref(param.id, param.type) for param in f_params] + f_body = _with_altered_arg(node, i, im.make_tuple(*f_args)) + # simplify, e.g., inline trivial make_tuple args + new_f_body = self.fp_transform(f_body, **kwargs) + # if the function did not simplify there is nothing to gain. Skip + # transformation. + if new_f_body is f_body: + continue + # if the function is not trivial the transformation would still work, but + # inlining would result in a larger tree again and we didn't didn't gain + # anything compared to regular `propagate_to_if_on_tuples`. Not inling also + # works, but we don't want bound lambda functions in our tree (at least right + # now). 
+ if not _is_trivial_or_tuple_thereof_expr(new_f_body): + continue + f = im.lambda_(*f_params)(new_f_body) + + tuple_var = self.uids.sequential_id(prefix="__ct_tuple_cps") + f_var = self.uids.sequential_id(prefix="__ct_cont") + new_branches = [] + for branch in arg.args[1:]: + new_branch = im.let(tuple_var, branch)( + im.call(im.ref(f_var, f_type))( + *( + im.tuple_get(i, im.ref(tuple_var, branch.type)) + for i in range(tuple_len) + ) + ) + ) + new_branches.append(self.fp_transform(new_branch, **kwargs)) + + new_node = im.let(f_var, f)(im.if_(cond, *new_branches)) + new_node = inline_lambda(new_node, eligible_params=[True]) + assert cpm.is_call_to(new_node, "if_") + new_node = im.if_( + cond, *(self.fp_transform(branch, **kwargs) for branch in new_node.args[1:]) + ) + return new_node + return None + def transform_propagate_nested_let(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: if cpm.is_let(node): # `let((a, let(b, 1)(a_val)))(a)`-> `let(b, 1)(let(a, a_val)(a))` @@ -339,9 +451,13 @@ def transform_propagate_nested_let(self, node: ir.FunCall, **kwargs) -> Optional return None def transform_inline_trivial_let(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: - if cpm.is_let(node) and isinstance(node.fun.expr, ir.SymRef): # type: ignore[attr-defined] # ensured by is_let - # `let(a, 1)(a)` -> `1` - for arg_sym, arg in zip(node.fun.params, node.args): # type: ignore[attr-defined] # ensured by is_let - if isinstance(node.fun.expr, ir.SymRef) and node.fun.expr.id == arg_sym.id: # type: ignore[attr-defined] # ensured by is_let - return arg + if cpm.is_let(node): + if isinstance(node.fun.expr, ir.SymRef): # type: ignore[attr-defined] # ensured by is_let + # `let(a, 1)(a)` -> `1` + for arg_sym, arg in zip(node.fun.params, node.args): # type: ignore[attr-defined] # ensured by is_let + if isinstance(node.fun.expr, ir.SymRef) and node.fun.expr.id == arg_sym.id: # type: ignore[attr-defined] # ensured by is_let + return arg + if any(trivial_args := [isinstance(arg, (ir.SymRef, ir.Literal)) for arg in node.args]): + return inline_lambda(node, eligible_params=trivial_args) + return None diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index ec6f89685a..6501e7436b 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -69,6 +69,7 @@ def apply_common_transforms( tmp_uids = eve_utils.UIDGenerator(prefix="__tmp") mergeasfop_uids = eve_utils.UIDGenerator() + collapse_tuple_uids = eve_utils.UIDGenerator() ir = MergeLet().visit(ir) ir = inline_fundefs.InlineFundefs().visit(ir) @@ -80,7 +81,9 @@ def apply_common_transforms( # Inline. The domain inference can not handle "user" functions, e.g. `let f = λ(...) → ... in f(...)` ir = InlineLambdas.apply(ir, opcount_preserving=True, force_inline_lambda_args=True) # required in order to get rid of expressions without a domain (e.g. 
when a tuple element is never accessed) - ir = CollapseTuple.apply(ir, offset_provider_type=offset_provider_type) # type: ignore[assignment] # always an itir.Program + ir = CollapseTuple.apply( + ir, uids=collapse_tuple_uids, offset_provider_type=offset_provider_type + ) # type: ignore[assignment] # always an itir.Program ir = infer_domain.infer_program( ir, # type: ignore[arg-type] # always an itir.Program offset_provider=offset_provider, @@ -94,7 +97,9 @@ def apply_common_transforms( inlined = ConstantFolding.apply(inlined) # type: ignore[assignment] # always an itir.Program # This pass is required to be in the loop such that when an `if_` call with tuple arguments # is constant-folded the surrounding tuple_get calls can be removed. - inlined = CollapseTuple.apply(inlined, offset_provider_type=offset_provider_type) # type: ignore[assignment] # always an itir.Program + inlined = CollapseTuple.apply( + inlined, uids=collapse_tuple_uids, offset_provider_type=offset_provider_type + ) # type: ignore[assignment] # always an itir.Program inlined = InlineScalar.apply(inlined, offset_provider_type=offset_provider_type) # This pass is required to run after CollapseTuple as otherwise we can not inline @@ -126,7 +131,10 @@ def apply_common_transforms( # only run the unconditional version here instead of in the loop above. if unconditionally_collapse_tuples: ir = CollapseTuple.apply( - ir, ignore_tuple_size=True, offset_provider_type=offset_provider_type + ir, + ignore_tuple_size=True, + uids=collapse_tuple_uids, + offset_provider_type=offset_provider_type, ) # type: ignore[assignment] # always an itir.Program ir = NormalizeShifts().visit(ir) diff --git a/src/gt4py/next/iterator/type_system/inference.py b/src/gt4py/next/iterator/type_system/inference.py index 249019769b..026dfedbd7 100644 --- a/src/gt4py/next/iterator/type_system/inference.py +++ b/src/gt4py/next/iterator/type_system/inference.py @@ -289,7 +289,9 @@ def type_synthesizer(*args, **kwargs): assert type_info.accepts_args(fun_type, with_args=list(args), with_kwargs=kwargs) return fun_type.returns - return type_synthesizer + return ObservableTypeSynthesizer( + type_synthesizer=type_synthesizer, store_inferred_type_in_node=False + ) class SanitizeTypes(eve.NodeTranslator, eve.VisitorWithSymbolTableTrait): @@ -309,6 +311,15 @@ def visit_Node(self, node: itir.Node, *, symtable: dict[str, itir.Node]) -> itir T = TypeVar("T", bound=itir.Node) +_INITIAL_CONTEXT = { + name: ObservableTypeSynthesizer( + type_synthesizer=type_synthesizer.builtin_type_synthesizers[name], + # builtin functions are polymorphic + store_inferred_type_in_node=False, + ) + for name in type_synthesizer.builtin_type_synthesizers.keys() +} + @dataclasses.dataclass class ITIRTypeInference(eve.NodeTranslator): @@ -320,11 +331,13 @@ class ITIRTypeInference(eve.NodeTranslator): PRESERVED_ANNEX_ATTRS = ("domain",) - offset_provider_type: common.OffsetProviderType + offset_provider_type: Optional[common.OffsetProviderType] #: Mapping from a dimension name to the actual dimension instance. - dimensions: dict[str, common.Dimension] + dimensions: Optional[dict[str, common.Dimension]] #: Allow sym refs to symbols that have not been declared. Mostly used in testing. allow_undeclared_symbols: bool + #: Reinference-mode skipping already typed nodes. + reinfer: bool @classmethod def apply( @@ -345,7 +358,7 @@ def apply( offset_provider_type: Offset provider dictionary. inplace: Write types directly to the given ``node`` instead of returning a copy. 
allow_undeclared_symbols: Allow references to symbols that don't have a corresponding - declaration. This is useful for testing or inference on partially inferred sub-nodes. + declaration. This is useful for testing or inference on reinferly inferred sub-nodes. Preconditions: @@ -417,24 +430,44 @@ def apply( ) ), allow_undeclared_symbols=allow_undeclared_symbols, + reinfer=False, ) if not inplace: node = copy.deepcopy(node) - instance.visit( - node, - ctx={ - name: ObservableTypeSynthesizer( - type_synthesizer=type_synthesizer.builtin_type_synthesizers[name], - # builtin functions are polymorphic - store_inferred_type_in_node=False, - ) - for name in type_synthesizer.builtin_type_synthesizers.keys() - }, + instance.visit(node, ctx=_INITIAL_CONTEXT) + return node + + @classmethod + def apply_reinfer(cls, node: T) -> T: + """ + Given a partially typed node infer the type of ``node`` and its sub-nodes. + + Contrary to the regular inference, this method does not descend into already typed sub-nodes + and can be used as a lightweight way to restore type information during a pass. + + Note that this function is stateful, which is usually desired, and more performant. + + Arguments: + node: The :class:`itir.Node` to infer the types of. + """ + if node.type: # already inferred + return node + + instance = cls( + offset_provider_type=None, dimensions=None, allow_undeclared_symbols=True, reinfer=True ) + instance.visit(node, ctx=_INITIAL_CONTEXT) return node def visit(self, node: concepts.RootNode, **kwargs: Any) -> Any: + # we found a node that is typed, do not descend into children + if self.reinfer and isinstance(node, itir.Node) and node.type: + if isinstance(node.type, ts.FunctionType): + return _type_synthesizer_from_function_type(node.type) + return node.type + result = super().visit(node, **kwargs) + if isinstance(node, itir.Node): if isinstance(result, ts.TypeSpec): if node.type and not isinstance(node.type, ts.DeferredType): @@ -561,19 +594,22 @@ def visit_StencilClosure(self, node: itir.StencilClosure, *, ctx) -> it_ts.Stenc ) def visit_AxisLiteral(self, node: itir.AxisLiteral, **kwargs) -> ts.DimensionType: - assert ( - node.value in self.dimensions - ), f"Dimension {node.value} not present in offset provider." - return ts.DimensionType(dim=self.dimensions[node.value]) + return ts.DimensionType(dim=common.Dimension(value=node.value, kind=node.kind)) # TODO: revisit what we want to do with OffsetLiterals as we already have an Offset type in # the frontend. 
- def visit_OffsetLiteral(self, node: itir.OffsetLiteral, **kwargs) -> it_ts.OffsetLiteralType: + def visit_OffsetLiteral( + self, node: itir.OffsetLiteral, **kwargs + ) -> it_ts.OffsetLiteralType | ts.DeferredType: + if self.reinfer: + return ts.DeferredType(constraint=it_ts.OffsetLiteralType) + if _is_representable_as_int(node.value): return it_ts.OffsetLiteralType( value=ts.ScalarType(kind=getattr(ts.ScalarKind, itir.INTEGER_INDEX_BUILTIN.upper())) ) else: + assert isinstance(self.dimensions, dict) assert isinstance(node.value, str) and node.value in self.dimensions return it_ts.OffsetLiteralType(value=self.dimensions[node.value]) @@ -650,3 +686,5 @@ def visit_Node(self, node: itir.Node, **kwargs): infer = ITIRTypeInference.apply + +reinfer = ITIRTypeInference.apply_reinfer diff --git a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py index 65a5b5888d..6c58de7650 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py @@ -5,6 +5,7 @@ # # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause +import copy # TODO: test failure when something is not typed after inference is run # TODO: test lift with no args @@ -534,3 +535,21 @@ def test_as_fieldop_without_domain(): assert result.fun.args[0].type.pos_only_args[0] == it_ts.IteratorType( position_dims="unknown", defined_dims=float_i_field.dims, element_type=float_i_field.dtype ) + + +def test_reinference(): + testee = im.make_tuple(im.ref("inp1", float_i_field), im.ref("inp2", float_i_field)) + result = itir_type_inference.reinfer(copy.deepcopy(testee)) + assert result.type == ts.TupleType(types=[float_i_field, float_i_field]) + + +def test_func_reinference(): + f_type = ts.FunctionType( + pos_only_args=[], + pos_or_kw_args={}, + kw_only_args={}, + returns=float_i_field, + ) + testee = im.call(im.ref("f", f_type))() + result = itir_type_inference.reinfer(copy.deepcopy(testee)) + assert result.type == float_i_field diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py index 28090ff1e2..2212dfb6e1 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py @@ -9,6 +9,7 @@ from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.iterator.transforms.collapse_tuple import CollapseTuple from gt4py.next.type_system import type_specifications as ts +from tests.next_tests.unit_tests.iterator_tests.test_type_inference import int_type def test_simple_make_tuple_tuple_get(): @@ -127,8 +128,8 @@ def test_letify_make_tuple_elements(): # anything that is not trivial, i.e. 
a SymRef, works here el1, el2 = im.let("foo", "foo")("foo"), im.let("bar", "bar")("bar") testee = im.make_tuple(el1, el2) - expected = im.let(("_tuple_el_1", el1), ("_tuple_el_2", el2))( - im.make_tuple("_tuple_el_1", "_tuple_el_2") + expected = im.let(("__ct_el_1", el1), ("__ct_el_2", el2))( + im.make_tuple("__ct_el_1", "__ct_el_2") ) actual = CollapseTuple.apply( @@ -239,3 +240,48 @@ def test_tuple_get_on_untyped_ref(): actual = CollapseTuple.apply(testee, allow_undeclared_symbols=True, within_stencil=False) assert actual == testee + + +def test_if_make_tuple_reorder_cps(): + testee = im.let("t", im.if_(True, im.make_tuple(1, 2), im.make_tuple(3, 4)))( + im.make_tuple(im.tuple_get(1, "t"), im.tuple_get(0, "t")) + ) + expected = im.if_(True, im.make_tuple(2, 1), im.make_tuple(4, 3)) + actual = CollapseTuple.apply( + testee, + flags=~CollapseTuple.Flag.PROPAGATE_TO_IF_ON_TUPLES, + allow_undeclared_symbols=True, + within_stencil=False, + ) + assert actual == expected + + +def test_if_make_tuple_reorder_cps_nested(): + testee = im.let("t", im.if_(True, im.make_tuple(1, 2), im.make_tuple(3, 4)))( + im.let("c", im.tuple_get(0, "t"))( + im.make_tuple(im.tuple_get(1, "t"), im.tuple_get(0, "t"), "c") + ) + ) + expected = im.if_(True, im.make_tuple(2, 1, 1), im.make_tuple(4, 3, 3)) + actual = CollapseTuple.apply( + testee, + flags=~CollapseTuple.Flag.PROPAGATE_TO_IF_ON_TUPLES, + allow_undeclared_symbols=True, + within_stencil=False, + ) + assert actual == expected + + +def test_if_make_tuple_reorder_cps_external(): + external_ref = im.tuple_get(0, im.ref("external", ts.TupleType(types=[int_type]))) + testee = im.let("t", im.if_(True, im.make_tuple(1, 2), im.make_tuple(3, 4)))( + im.make_tuple(external_ref, im.tuple_get(1, "t"), im.tuple_get(0, "t")) + ) + expected = im.if_(True, im.make_tuple(external_ref, 2, 1), im.make_tuple(external_ref, 4, 3)) + actual = CollapseTuple.apply( + testee, + flags=~CollapseTuple.Flag.PROPAGATE_TO_IF_ON_TUPLES, + allow_undeclared_symbols=True, + within_stencil=False, + ) + assert actual == expected From 48abc08c252de74877207f6b536a4b42be10a7cf Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Sun, 1 Dec 2024 19:59:44 +0100 Subject: [PATCH 105/150] Fix typos --- src/gt4py/next/iterator/transforms/collapse_tuple.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/collapse_tuple.py b/src/gt4py/next/iterator/transforms/collapse_tuple.py index b71470f3e7..329d6363a0 100644 --- a/src/gt4py/next/iterator/transforms/collapse_tuple.py +++ b/src/gt4py/next/iterator/transforms/collapse_tuple.py @@ -103,8 +103,8 @@ class Flag(enum.Flag): #: `let(tup, {trivial_expr1, trivial_expr2})(foo(tup))` #: -> `foo({trivial_expr1, trivial_expr2})` INLINE_TRIVIAL_MAKE_TUPLE = enum.auto() - #: Similar as `PROPAGATE_TO_IF_ON_TUPLES`, but propagates in the opposite direction, e.g. - #: into the tree, allowing removal if tuple expressions accross `if_` calls without + #: Similar as `PROPAGATE_TO_IF_ON_TUPLES`, but propagates in the opposite direction, i.e. + #: into the tree, allowing removal of tuple expressions accross `if_` calls without #: increasing the size of the tree. This is particullary important for `if` statements #: in the frontend, where outwards propagation can have devestating effects on the tree #: size. In particular boundary conditions inside of scans, e.g. 
something like: From 3790944aa52ccb5def937a7dea1eed8aafcf826b Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Sun, 1 Dec 2024 20:02:39 +0100 Subject: [PATCH 106/150] Fix typos --- src/gt4py/next/iterator/type_system/inference.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gt4py/next/iterator/type_system/inference.py b/src/gt4py/next/iterator/type_system/inference.py index 026dfedbd7..bc437ba44c 100644 --- a/src/gt4py/next/iterator/type_system/inference.py +++ b/src/gt4py/next/iterator/type_system/inference.py @@ -358,7 +358,7 @@ def apply( offset_provider_type: Offset provider dictionary. inplace: Write types directly to the given ``node`` instead of returning a copy. allow_undeclared_symbols: Allow references to symbols that don't have a corresponding - declaration. This is useful for testing or inference on reinferly inferred sub-nodes. + declaration. This is useful for testing or inference on partially inferred sub-nodes. Preconditions: From a8a63bf69d9fb5cd0a8d8e61aa78c9bafc0eef81 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Sun, 1 Dec 2024 20:05:49 +0100 Subject: [PATCH 107/150] Disable PROPAGATE_TO_IF_ON_TUPLES by default in pass manager --- src/gt4py/next/iterator/transforms/pass_manager.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index 6501e7436b..a7aa59b9b4 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -82,7 +82,10 @@ def apply_common_transforms( ir = InlineLambdas.apply(ir, opcount_preserving=True, force_inline_lambda_args=True) # required in order to get rid of expressions without a domain (e.g. when a tuple element is never accessed) ir = CollapseTuple.apply( - ir, uids=collapse_tuple_uids, offset_provider_type=offset_provider_type + ir, + flags=~CollapseTuple.Flag.PROPAGATE_TO_IF_ON_TUPLES, + uids=collapse_tuple_uids, + offset_provider_type=offset_provider_type ) # type: ignore[assignment] # always an itir.Program ir = infer_domain.infer_program( ir, # type: ignore[arg-type] # always an itir.Program @@ -98,7 +101,10 @@ def apply_common_transforms( # This pass is required to be in the loop such that when an `if_` call with tuple arguments # is constant-folded the surrounding tuple_get calls can be removed. 
inlined = CollapseTuple.apply( - inlined, uids=collapse_tuple_uids, offset_provider_type=offset_provider_type + inlined, + flags=~CollapseTuple.Flag.PROPAGATE_TO_IF_ON_TUPLES, + uids=collapse_tuple_uids, + offset_provider_type=offset_provider_type ) # type: ignore[assignment] # always an itir.Program inlined = InlineScalar.apply(inlined, offset_provider_type=offset_provider_type) @@ -172,7 +178,9 @@ def apply_fieldview_transforms( ir = inline_fundefs.prune_unreferenced_fundefs(ir) ir = InlineLambdas.apply(ir, opcount_preserving=True, force_inline_lambda_args=True) ir = CollapseTuple.apply( - ir, offset_provider_type=common.offset_provider_to_type(offset_provider) + ir, + flags=~CollapseTuple.Flag.PROPAGATE_TO_IF_ON_TUPLES, + offset_provider_type=common.offset_provider_to_type(offset_provider) ) # type: ignore[assignment] # type is still `itir.Program` ir = infer_domain.infer_program(ir, offset_provider=offset_provider) return ir From 2c44ffc81bd637ab5edab0070582a27c0dbe6506 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Sun, 1 Dec 2024 20:25:36 +0100 Subject: [PATCH 108/150] Improve doc --- .../iterator/transforms/collapse_tuple.py | 23 ++++++++++--------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/collapse_tuple.py b/src/gt4py/next/iterator/transforms/collapse_tuple.py index 329d6363a0..2e5e2cd3f5 100644 --- a/src/gt4py/next/iterator/transforms/collapse_tuple.py +++ b/src/gt4py/next/iterator/transforms/collapse_tuple.py @@ -107,19 +107,20 @@ class Flag(enum.Flag): #: into the tree, allowing removal of tuple expressions accross `if_` calls without #: increasing the size of the tree. This is particullary important for `if` statements #: in the frontend, where outwards propagation can have devestating effects on the tree - #: size. In particular boundary conditions inside of scans, e.g. something like: + #: size, without any gained optimization potential. For example #: ``` - #: if level == 0 then - #: make_tuple("a", "b") - #: else - #: if level == 1 then - #: make_tuple("b", "c") - #: ` else - #: make_tuple("d", "e") + #: complex_lambda(if cond + #: if cond + #: {...} + #: else: + #: {...} + #: else + #: {...}) #: ``` - #: is problematic, since `PROPAGATE_TO_IF_ON_TUPLES` would also generate a branch for - #: the condition `level == 0` and `level == 1` which can never occur. Note that this - #: transformation is not mutally exclusive to `PROPAGATE_TO_IF_ON_TUPLES`. + #: is problematic, since `PROPAGATE_TO_IF_ON_TUPLES` would propagate and hence duplicate + #: `complex_lambda` three times, while we only want to get rid of the tuple expressions + #: inside of the `if_`s. + #: Note that this transformation is not mutally exclusive to `PROPAGATE_TO_IF_ON_TUPLES`. 
PROPAGATE_TO_IF_ON_TUPLES_CPS = enum.auto() #: `(if cond then {1, 2} else {3, 4})[0]` -> `if cond then {1, 2}[0] else {3, 4}[0]` PROPAGATE_TO_IF_ON_TUPLES = enum.auto() From 42b58179dcb7ac61704ea7947ce5aad8018703a6 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Sun, 1 Dec 2024 20:27:47 +0100 Subject: [PATCH 109/150] Improve typo --- src/gt4py/next/iterator/transforms/collapse_tuple.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gt4py/next/iterator/transforms/collapse_tuple.py b/src/gt4py/next/iterator/transforms/collapse_tuple.py index 2e5e2cd3f5..7e7083ce47 100644 --- a/src/gt4py/next/iterator/transforms/collapse_tuple.py +++ b/src/gt4py/next/iterator/transforms/collapse_tuple.py @@ -104,7 +104,7 @@ class Flag(enum.Flag): #: -> `foo({trivial_expr1, trivial_expr2})` INLINE_TRIVIAL_MAKE_TUPLE = enum.auto() #: Similar as `PROPAGATE_TO_IF_ON_TUPLES`, but propagates in the opposite direction, i.e. - #: into the tree, allowing removal of tuple expressions accross `if_` calls without + #: into the tree, allowing removal of tuple expressions across `if_` calls without #: increasing the size of the tree. This is particullary important for `if` statements #: in the frontend, where outwards propagation can have devestating effects on the tree #: size, without any gained optimization potential. For example From 9da19a2752ad78a0ccb7d6906f40bc9c09140103 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Sun, 1 Dec 2024 20:28:16 +0100 Subject: [PATCH 110/150] Improve typo --- src/gt4py/next/iterator/transforms/collapse_tuple.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gt4py/next/iterator/transforms/collapse_tuple.py b/src/gt4py/next/iterator/transforms/collapse_tuple.py index 7e7083ce47..1210f0f651 100644 --- a/src/gt4py/next/iterator/transforms/collapse_tuple.py +++ b/src/gt4py/next/iterator/transforms/collapse_tuple.py @@ -105,7 +105,7 @@ class Flag(enum.Flag): INLINE_TRIVIAL_MAKE_TUPLE = enum.auto() #: Similar as `PROPAGATE_TO_IF_ON_TUPLES`, but propagates in the opposite direction, i.e. #: into the tree, allowing removal of tuple expressions across `if_` calls without - #: increasing the size of the tree. This is particullary important for `if` statements + #: increasing the size of the tree. This is particularly important for `if` statements #: in the frontend, where outwards propagation can have devestating effects on the tree #: size, without any gained optimization potential. For example #: ``` From bcd9e48528283c2ecbef9258e7e3d0a764ab4b1f Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Sun, 1 Dec 2024 20:29:28 +0100 Subject: [PATCH 111/150] Improve typo --- src/gt4py/next/iterator/transforms/collapse_tuple.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gt4py/next/iterator/transforms/collapse_tuple.py b/src/gt4py/next/iterator/transforms/collapse_tuple.py index 1210f0f651..a881707fbf 100644 --- a/src/gt4py/next/iterator/transforms/collapse_tuple.py +++ b/src/gt4py/next/iterator/transforms/collapse_tuple.py @@ -120,7 +120,7 @@ class Flag(enum.Flag): #: is problematic, since `PROPAGATE_TO_IF_ON_TUPLES` would propagate and hence duplicate #: `complex_lambda` three times, while we only want to get rid of the tuple expressions #: inside of the `if_`s. - #: Note that this transformation is not mutally exclusive to `PROPAGATE_TO_IF_ON_TUPLES`. + #: Note that this transformation is not mutaly exclusive to `PROPAGATE_TO_IF_ON_TUPLES`. 
PROPAGATE_TO_IF_ON_TUPLES_CPS = enum.auto() #: `(if cond then {1, 2} else {3, 4})[0]` -> `if cond then {1, 2}[0] else {3, 4}[0]` PROPAGATE_TO_IF_ON_TUPLES = enum.auto() From 0a212bd8b7bf826616b08067b540f908fcd4a83e Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Sun, 1 Dec 2024 20:30:04 +0100 Subject: [PATCH 112/150] Improve typo --- src/gt4py/next/iterator/transforms/collapse_tuple.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gt4py/next/iterator/transforms/collapse_tuple.py b/src/gt4py/next/iterator/transforms/collapse_tuple.py index a881707fbf..eb94d30431 100644 --- a/src/gt4py/next/iterator/transforms/collapse_tuple.py +++ b/src/gt4py/next/iterator/transforms/collapse_tuple.py @@ -106,7 +106,7 @@ class Flag(enum.Flag): #: Similar as `PROPAGATE_TO_IF_ON_TUPLES`, but propagates in the opposite direction, i.e. #: into the tree, allowing removal of tuple expressions across `if_` calls without #: increasing the size of the tree. This is particularly important for `if` statements - #: in the frontend, where outwards propagation can have devestating effects on the tree + #: in the frontend, where outwards propagation can have devastating effects on the tree #: size, without any gained optimization potential. For example #: ``` #: complex_lambda(if cond From 70562fe2b254b1fa801e5e98aadada29e3cb633d Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Sun, 1 Dec 2024 20:30:32 +0100 Subject: [PATCH 113/150] Fix typo --- src/gt4py/next/iterator/transforms/collapse_tuple.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gt4py/next/iterator/transforms/collapse_tuple.py b/src/gt4py/next/iterator/transforms/collapse_tuple.py index eb94d30431..593723eb13 100644 --- a/src/gt4py/next/iterator/transforms/collapse_tuple.py +++ b/src/gt4py/next/iterator/transforms/collapse_tuple.py @@ -120,7 +120,7 @@ class Flag(enum.Flag): #: is problematic, since `PROPAGATE_TO_IF_ON_TUPLES` would propagate and hence duplicate #: `complex_lambda` three times, while we only want to get rid of the tuple expressions #: inside of the `if_`s. - #: Note that this transformation is not mutaly exclusive to `PROPAGATE_TO_IF_ON_TUPLES`. + #: Note that this transformation is not mutually exclusive to `PROPAGATE_TO_IF_ON_TUPLES`. PROPAGATE_TO_IF_ON_TUPLES_CPS = enum.auto() #: `(if cond then {1, 2} else {3, 4})[0]` -> `if cond then {1, 2}[0] else {3, 4}[0]` PROPAGATE_TO_IF_ON_TUPLES = enum.auto() From 9cee650538bfe550b4e312d0f9bdbdde2695440d Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Sun, 1 Dec 2024 20:33:26 +0100 Subject: [PATCH 114/150] Improve doc --- src/gt4py/next/iterator/transforms/collapse_tuple.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/collapse_tuple.py b/src/gt4py/next/iterator/transforms/collapse_tuple.py index 593723eb13..7e2e5b127e 100644 --- a/src/gt4py/next/iterator/transforms/collapse_tuple.py +++ b/src/gt4py/next/iterator/transforms/collapse_tuple.py @@ -109,8 +109,8 @@ class Flag(enum.Flag): #: in the frontend, where outwards propagation can have devastating effects on the tree #: size, without any gained optimization potential. 
For example #: ``` - #: complex_lambda(if cond - #: if cond + #: complex_lambda(if cond1 + #: if cond2 #: {...} #: else: #: {...} From 43f5741a1dea458000508ab679db4d828af61098 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Sun, 1 Dec 2024 20:38:01 +0100 Subject: [PATCH 115/150] Format --- src/gt4py/next/iterator/transforms/pass_manager.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index a7aa59b9b4..2ed6e93f2d 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -85,7 +85,7 @@ def apply_common_transforms( ir, flags=~CollapseTuple.Flag.PROPAGATE_TO_IF_ON_TUPLES, uids=collapse_tuple_uids, - offset_provider_type=offset_provider_type + offset_provider_type=offset_provider_type, ) # type: ignore[assignment] # always an itir.Program ir = infer_domain.infer_program( ir, # type: ignore[arg-type] # always an itir.Program @@ -104,7 +104,7 @@ def apply_common_transforms( inlined, flags=~CollapseTuple.Flag.PROPAGATE_TO_IF_ON_TUPLES, uids=collapse_tuple_uids, - offset_provider_type=offset_provider_type + offset_provider_type=offset_provider_type, ) # type: ignore[assignment] # always an itir.Program inlined = InlineScalar.apply(inlined, offset_provider_type=offset_provider_type) @@ -180,7 +180,7 @@ def apply_fieldview_transforms( ir = CollapseTuple.apply( ir, flags=~CollapseTuple.Flag.PROPAGATE_TO_IF_ON_TUPLES, - offset_provider_type=common.offset_provider_to_type(offset_provider) + offset_provider_type=common.offset_provider_to_type(offset_provider), ) # type: ignore[assignment] # type is still `itir.Program` ir = infer_domain.infer_program(ir, offset_provider=offset_provider) return ir From 7b37f1c0e01f07191ba268005948653a0b3af7a6 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Sun, 1 Dec 2024 20:55:42 +0100 Subject: [PATCH 116/150] Fix type synthesizer for partially typed arithmetic ops --- src/gt4py/next/iterator/type_system/type_synthesizer.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/gt4py/next/iterator/type_system/type_synthesizer.py b/src/gt4py/next/iterator/type_system/type_synthesizer.py index 5be9ed7438..5fc78a7c6f 100644 --- a/src/gt4py/next/iterator/type_system/type_synthesizer.py +++ b/src/gt4py/next/iterator/type_system/type_synthesizer.py @@ -94,6 +94,10 @@ def power(base: ts.ScalarType, exponent: ts.ScalarType) -> ts.ScalarType: @_register_builtin_type_synthesizer(fun_names=itir.BINARY_MATH_NUMBER_BUILTINS) def _(lhs: ts.ScalarType, rhs: ts.ScalarType) -> ts.ScalarType: + if isinstance(lhs, ts.DeferredType): + return rhs + if isinstance(rhs, ts.DeferredType): + return lhs assert lhs == rhs return lhs From a0341a6367bfd823d04f4600900943b1a622e3e5 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Sun, 1 Dec 2024 21:56:22 +0100 Subject: [PATCH 117/150] Fix test --- .../unit_tests/iterator_tests/test_type_inference.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py index 6c58de7650..7557b0b5e0 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py @@ -16,6 +16,7 @@ import pytest +from gt4py.next import common from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import ir_makers as im 
from gt4py.next.iterator.type_system import ( @@ -323,8 +324,8 @@ def test_cartesian_fencil_definition(): def test_unstructured_fencil_definition(): mesh = simple_mesh() unstructured_domain = im.call("unstructured_domain")( - im.call("named_range")(itir.AxisLiteral(value="Vertex"), 0, 1), - im.call("named_range")(itir.AxisLiteral(value="KDim"), 0, 1), + im.call("named_range")(itir.AxisLiteral(value="Vertex", kind=common.DimensionKind.HORIZONTAL), 0, 1), + im.call("named_range")(itir.AxisLiteral(value="KDim", kind=common.DimensionKind.VERTICAL), 0, 1), ) testee = itir.FencilDefinition( From 5a892f3f43d9219d144948e6875c2349151ec0f7 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Sun, 1 Dec 2024 21:59:56 +0100 Subject: [PATCH 118/150] Fix test --- .../iterator_tests/test_type_inference.py | 34 ++++++++++++++----- 1 file changed, 26 insertions(+), 8 deletions(-) diff --git a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py index 7557b0b5e0..305947ec77 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py @@ -82,7 +82,9 @@ def expression_test_cases(): (im.call("make_const_list")(True), it_ts.ListType(element_type=bool_type)), (im.call("list_get")(0, im.ref("l", it_ts.ListType(element_type=bool_type))), bool_type), ( - im.call("named_range")(itir.AxisLiteral(value="Vertex"), 0, 1), + im.call("named_range")( + itir.AxisLiteral(value="Vertex", kind=common.DimensionKind.HORIZONTAL), 0, 1 + ), it_ts.NamedRangeType(dim=Vertex), ), ( @@ -93,7 +95,9 @@ def expression_test_cases(): ), ( im.call("unstructured_domain")( - im.call("named_range")(itir.AxisLiteral(value="Vertex"), 0, 1) + im.call("named_range")( + itir.AxisLiteral(value="Vertex", kind=common.DimensionKind.HORIZONTAL), 0, 1 + ) ), it_ts.DomainType(dims=[Vertex]), ), @@ -159,8 +163,14 @@ def expression_test_cases(): im.call("as_fieldop")( im.lambda_("it")(im.deref(im.shift("V2E", 0)("it"))), im.call("unstructured_domain")( - im.call("named_range")(itir.AxisLiteral(value="Vertex"), 0, 1), - im.call("named_range")(itir.AxisLiteral(value="KDim"), 0, 1), + im.call("named_range")( + itir.AxisLiteral(value="Vertex", kind=common.DimensionKind.HORIZONTAL), + 0, + 1, + ), + im.call("named_range")( + itir.AxisLiteral(value="KDim", kind=common.DimensionKind.VERTICAL), 0, 1 + ), ), ) )(im.ref("inp", float_edge_k_field)), @@ -324,8 +334,12 @@ def test_cartesian_fencil_definition(): def test_unstructured_fencil_definition(): mesh = simple_mesh() unstructured_domain = im.call("unstructured_domain")( - im.call("named_range")(itir.AxisLiteral(value="Vertex", kind=common.DimensionKind.HORIZONTAL), 0, 1), - im.call("named_range")(itir.AxisLiteral(value="KDim", kind=common.DimensionKind.VERTICAL), 0, 1), + im.call("named_range")( + itir.AxisLiteral(value="Vertex", kind=common.DimensionKind.HORIZONTAL), 0, 1 + ), + im.call("named_range")( + itir.AxisLiteral(value="KDim", kind=common.DimensionKind.VERTICAL), 0, 1 + ), ) testee = itir.FencilDefinition( @@ -417,8 +431,12 @@ def test_function_definition(): def test_fencil_with_nb_field_input(): mesh = simple_mesh() unstructured_domain = im.call("unstructured_domain")( - im.call("named_range")(itir.AxisLiteral(value="Vertex"), 0, 1), - im.call("named_range")(itir.AxisLiteral(value="KDim"), 0, 1), + im.call("named_range")( + itir.AxisLiteral(value="Vertex", kind=common.DimensionKind.HORIZONTAL), 0, 1 + ), + im.call("named_range")( + 
itir.AxisLiteral(value="KDim", kind=common.DimensionKind.VERTICAL), 0, 1 + ), ) testee = itir.FencilDefinition( From 93688a8214edec772f3f8cea12b8adfe199f5884 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Mon, 2 Dec 2024 13:07:42 +0100 Subject: [PATCH 119/150] Improve field operator fusion --- .../iterator/transforms/collapse_tuple.py | 11 +- .../iterator/transforms/fuse_as_fieldop.py | 139 +++++++++++++---- .../inline_center_deref_lift_vars.py | 12 +- .../transforms/pass_manager_legacy.py | 2 +- .../transforms_tests/test_fuse_as_fieldop.py | 146 +++++++++++++++++- 5 files changed, 263 insertions(+), 47 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/collapse_tuple.py b/src/gt4py/next/iterator/transforms/collapse_tuple.py index 7e2e5b127e..877e5cfc9e 100644 --- a/src/gt4py/next/iterator/transforms/collapse_tuple.py +++ b/src/gt4py/next/iterator/transforms/collapse_tuple.py @@ -185,11 +185,9 @@ def apply( allow_undeclared_symbols=allow_undeclared_symbols, ) - new_node = cls( - ignore_tuple_size=ignore_tuple_size, - flags=flags, - uids=uids, - ).visit(node, within_stencil=within_stencil) + new_node = cls(ignore_tuple_size=ignore_tuple_size, flags=flags, uids=uids).visit( + node, within_stencil=within_stencil + ) # inline to remove left-overs from LETIFY_MAKE_TUPLE_ELEMENTS. this is important # as otherwise two equal expressions containing a tuple will not be equal anymore @@ -252,6 +250,7 @@ def transform_collapse_make_tuple_tuple_get( # tuple argument differs, just continue with the rest of the tree return None + itir_type_inference.reinfer(first_expr) # type is needed so reinfer on-demand assert self.ignore_tuple_size or isinstance( first_expr.type, (ts.TupleType, ts.DeferredType) ) @@ -272,7 +271,7 @@ def transform_collapse_tuple_get_make_tuple( and isinstance(node.args[0], ir.Literal) ): # `tuple_get(i, make_tuple(e_0, e_1, ..., e_i, ..., e_N))` -> `e_i` - assert type_info.is_integer(node.args[0].type) + assert not node.args[0].type or type_info.is_integer(node.args[0].type) make_tuple_call = node.args[1] idx = int(node.args[0].value) assert idx < len( diff --git a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py index 3099cf9f31..74ef61d4f6 100644 --- a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py +++ b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py @@ -18,6 +18,7 @@ inline_center_deref_lift_vars, inline_lambdas, inline_lifts, + merge_let, trace_shifts, ) from gt4py.next.iterator.type_system import ( @@ -83,7 +84,12 @@ def _inline_as_fieldop_arg( for inner_param, inner_arg in zip(stencil.params, inner_args, strict=True): if isinstance(inner_arg, itir.SymRef): - stencil_params.append(inner_param) + if inner_arg.id in extracted_args: + assert extracted_args[inner_arg.id] == inner_arg + alias = stencil_params[list(extracted_args.keys()).index(inner_arg.id)] + stencil_body = im.let(inner_param, im.ref(alias.id))(stencil_body) + else: + stencil_params.append(inner_param) extracted_args[inner_arg.id] = inner_arg elif isinstance(inner_arg, itir.Literal): # note: only literals, not all scalar expressions are required as it doesn't make sense @@ -151,24 +157,55 @@ def fuse_as_fieldop( new_param = stencil_param.id new_args = _merge_arguments(new_args, {new_param: arg}) - new_node = im.as_fieldop(im.lambda_(*new_args.keys())(new_stencil_body), domain)( - *new_args.values() - ) + stencil = im.lambda_(*new_args.keys())(new_stencil_body) # simplify stencil directly to keep the tree small - new_node = 
inline_center_deref_lift_vars.InlineCenterDerefLiftVars.apply( - new_node + new_stencil = inline_lambdas.InlineLambdas.apply( + stencil, opcount_preserving=True, force_inline_lift_args=False + ) + new_stencil = inline_center_deref_lift_vars.InlineCenterDerefLiftVars.apply( + new_stencil, is_stencil=True, uids=uids ) # to keep the tree small - new_node = inline_lambdas.InlineLambdas.apply( - new_node, opcount_preserving=True, force_inline_lift_args=True + new_stencil = merge_let.MergeLet().visit(new_stencil) + new_stencil = inline_lambdas.InlineLambdas.apply( + new_stencil, opcount_preserving=True, force_inline_lift_args=True ) - new_node = inline_lifts.InlineLifts().visit(new_node) + new_stencil = inline_lifts.InlineLifts().visit(new_stencil) - type_inference.copy_type(from_=expr, to=new_node) + new_node = im.as_fieldop(new_stencil, domain)(*new_args.values()) + type_inference.copy_type(from_=expr, to=new_node) return new_node +def _arg_inline_predicate(node: itir.Expr, shifts): + if _is_tuple_expr_of_literals(node): + return True + if (is_applied_fieldop := cpm.is_applied_as_fieldop(node)) or cpm.is_call_to(node, "if_"): + # always inline arg if it is an applied fieldop with only a single arg + if is_applied_fieldop and len(node.args) == 1: + return True + # argument is never used, will be removed when inlined + if len(shifts) == 0: + return True + # applied fieldop with list return type must always be inlined as no backend supports this + assert isinstance(node.type, ts.TypeSpec) + dtype = type_info.apply_to_primitive_constituents(type_info.extract_dtype, node.type) + if isinstance(dtype, it_ts.ListType): + return True + # only accessed at the center location + if shifts in [set(), {()}]: + return True + # TODO(tehrengruber): Disabled as the InlineCenterDerefLiftVars does not support this yet + # and it would increase the size of the tree otherwise. 
+ # if len(shifts) == 1 and not any( + # trace_shifts.Sentinel.ALL_NEIGHBORS in access for access in shifts + # ): + # return True # noqa: ERA001 [commented-out-code] + + return False + + @dataclasses.dataclass class FuseAsFieldOp(eve.NodeTranslator): """ @@ -219,34 +256,76 @@ def apply( return cls(uids=uids).visit(node) - def visit_FunCall(self, node: itir.FunCall): - node = self.generic_visit(node) + def visit_FunCall(self, node: itir.FunCall, **kwargs): + if cpm.is_applied_as_fieldop(node): # don't descend in stencil + old_node = node + node = im.as_fieldop(*node.fun.args)(*self.generic_visit(node.args)) # type: ignore[attr-defined] # ensured by cpm.is_applied_as_fieldop + type_inference.copy_type(from_=old_node, to=node) + elif kwargs.get("recurse", True): + node = self.generic_visit(node, **kwargs) + + if cpm.is_call_to(node, "make_tuple"): + as_fieldop_args = [arg for arg in node.args if cpm.is_applied_as_fieldop(arg)] + distinct_domains = set(arg.fun.args[1] for arg in as_fieldop_args) # type: ignore[attr-defined] # ensured by cpm.is_applied_as_fieldop + if len(distinct_domains) != len(as_fieldop_args): + new_els: list[itir.Expr | None] = [None for _ in node.args] + as_fieldop_args_by_domain: dict[itir.Expr, list[tuple[int, itir.Expr]]] = {} + for i, arg in enumerate(node.args): + if cpm.is_applied_as_fieldop(arg): + assert arg.type + _, domain = arg.fun.args # type: ignore[attr-defined] # ensured by cpm.is_applied_as_fieldop + as_fieldop_args_by_domain.setdefault(domain, []) + as_fieldop_args_by_domain[domain].append((i, arg)) + else: + new_els[i] = arg # keep as is + let_vars = {} + for domain, inner_as_fieldop_args in as_fieldop_args_by_domain.items(): + if len(inner_as_fieldop_args) > 1: + var = self.uids.sequential_id(prefix="__fasfop") + fused_args = im.op_as_fieldop(lambda *args: im.make_tuple(*args), domain)( + *(arg for _, arg in inner_as_fieldop_args) + ) + fused_args.type = ts.TupleType( + types=[arg.type for _, arg in inner_as_fieldop_args] # type: ignore[misc] # has type is ensured on list creation + ) + # don't recurse into nested args, but only consider newly created `as_fieldop` + let_vars[var] = self.visit(fused_args, **{**kwargs, "recurse": False}) + for outer_tuple_idx, (inner_tuple_idx, _) in enumerate( + inner_as_fieldop_args + ): + new_els[inner_tuple_idx] = im.tuple_get(outer_tuple_idx, var) + else: + i, arg = inner_as_fieldop_args[0] + new_els[i] = arg + assert not any(el is None for el in new_els) + assert let_vars + new_node = im.let(*let_vars.items())(im.make_tuple(*new_els)) + new_node = inline_lambdas.inline_lambda(new_node, opcount_preserving=True) + return new_node if cpm.is_call_to(node.fun, "as_fieldop"): node = _canonicalize_as_fieldop(node) + # when multiple `as_fieldop` calls are fused that use the same argument, this argument + # might become referenced once only. In order to be able to continue fusing such arguments + # try inlining here. 
+ if cpm.is_let(node): + new_node = inline_lambdas.inline_lambda(node, opcount_preserving=True) + if new_node is not node: # nothing has been inlined + return self.visit(new_node, **kwargs) + if cpm.is_call_to(node.fun, "as_fieldop") and isinstance(node.fun.args[0], itir.Lambda): stencil: itir.Lambda = node.fun.args[0] args: list[itir.Expr] = node.args shifts = trace_shifts.trace_stencil(stencil) - eligible_args = [] - for arg, arg_shifts in zip(args, shifts, strict=True): - assert isinstance(arg.type, ts.TypeSpec) - dtype = type_info.apply_to_primitive_constituents(type_info.extract_dtype, arg.type) - # TODO(tehrengruber): make this configurable - eligible_args.append( - _is_tuple_expr_of_literals(arg) - or ( - isinstance(arg, itir.FunCall) - and ( - cpm.is_call_to(arg.fun, "as_fieldop") - and isinstance(arg.fun.args[0], itir.Lambda) - or cpm.is_call_to(arg, "if_") - ) - and (isinstance(dtype, it_ts.ListType) or len(arg_shifts) <= 1) - ) + eligible_args = [ + _arg_inline_predicate(arg, arg_shifts) + for arg, arg_shifts in zip(args, shifts, strict=True) + ] + if any(eligible_args): + return self.visit( + fuse_as_fieldop(node, eligible_args, uids=self.uids), + **{**kwargs, "recurse": False}, ) - - return fuse_as_fieldop(node, eligible_args, uids=self.uids) return node diff --git a/src/gt4py/next/iterator/transforms/inline_center_deref_lift_vars.py b/src/gt4py/next/iterator/transforms/inline_center_deref_lift_vars.py index 95c761d7ba..9a94a4a338 100644 --- a/src/gt4py/next/iterator/transforms/inline_center_deref_lift_vars.py +++ b/src/gt4py/next/iterator/transforms/inline_center_deref_lift_vars.py @@ -7,7 +7,7 @@ # SPDX-License-Identifier: BSD-3-Clause import dataclasses -from typing import ClassVar, Optional +from typing import ClassVar, Optional, TypeVar import gt4py.next.iterator.ir_utils.common_pattern_matcher as cpm from gt4py import eve @@ -23,6 +23,9 @@ def is_center_derefed_only(node: itir.Node) -> bool: return hasattr(node.annex, "recorded_shifts") and node.annex.recorded_shifts in [set(), {()}] +T = TypeVar("T", bound=itir.Program | itir.Lambda) + + @dataclasses.dataclass class InlineCenterDerefLiftVars(eve.NodeTranslator): """ @@ -50,9 +53,14 @@ class InlineCenterDerefLiftVars(eve.NodeTranslator): uids: eve_utils.UIDGenerator @classmethod - def apply(cls, node: itir.Program, uids: Optional[eve_utils.UIDGenerator] = None): + def apply( + cls, node: T, *, is_stencil=False, uids: Optional[eve_utils.UIDGenerator] = None + ) -> T: if not uids: uids = eve_utils.UIDGenerator() + if is_stencil: + assert isinstance(node, itir.Lambda) + trace_shifts.trace_stencil(node, num_args=len(node.params), save_to_annex=True) return cls(uids=uids).visit(node) def visit_FunCall(self, node: itir.FunCall, **kwargs): diff --git a/src/gt4py/next/iterator/transforms/pass_manager_legacy.py b/src/gt4py/next/iterator/transforms/pass_manager_legacy.py index 94c962e92d..2717ae6924 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager_legacy.py +++ b/src/gt4py/next/iterator/transforms/pass_manager_legacy.py @@ -99,7 +99,7 @@ def apply_common_transforms( for _ in range(10): inlined = ir - inlined = InlineCenterDerefLiftVars.apply(inlined, uids=icdlv_uids) # type: ignore[arg-type] # always a fencil + inlined = InlineCenterDerefLiftVars.apply(inlined, uids=icdlv_uids) # type: ignore[type-var] # always a fencil inlined = _inline_lifts(inlined, lift_mode) inlined = InlineLambdas.apply( diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_fuse_as_fieldop.py 
b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_fuse_as_fieldop.py index 168e9490e0..97be4411eb 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_fuse_as_fieldop.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_fuse_as_fieldop.py @@ -10,11 +10,12 @@ from gt4py import next as gtx from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import ir_makers as im -from gt4py.next.iterator.transforms import fuse_as_fieldop +from gt4py.next.iterator.transforms import fuse_as_fieldop, collapse_tuple from gt4py.next.type_system import type_specifications as ts IDim = gtx.Dimension("IDim") +JDim = gtx.Dimension("JDim") field_type = ts.FieldType(dims=[IDim], dtype=ts.ScalarType(kind=ts.ScalarKind.INT32)) @@ -46,6 +47,25 @@ def test_trivial_literal(): assert actual == expected +def test_trivial_same_arg_twice(): + d = im.domain("cartesian_domain", {IDim: (0, 1)}) + testee = im.op_as_fieldop("plus", d)( + # note: inp1 occurs twice here + im.op_as_fieldop("multiplies", d)(im.ref("inp1", field_type), im.ref("inp1", field_type)), + im.ref("inp2", field_type), + ) + expected = im.as_fieldop( + im.lambda_("inp1", "inp2")( + im.plus(im.multiplies_(im.deref("inp1"), im.deref("inp1")), im.deref("inp2")) + ), + d, + )(im.ref("inp1", field_type), im.ref("inp2", field_type)) + actual = fuse_as_fieldop.FuseAsFieldOp.apply( + testee, offset_provider_type={}, allow_undeclared_symbols=True + ) + assert actual == expected + + def test_tuple_arg(): d = im.domain("cartesian_domain", {}) testee = im.op_as_fieldop("plus", d)( @@ -99,19 +119,101 @@ def test_no_inline(): im.plus(im.deref(im.shift("IOff", 1)("a")), im.deref(im.shift("IOff", -1)("a"))) ), d1, - )(im.as_fieldop(im.lambda_("inp1")(im.deref("inp1")), d2)(im.ref("inp1", field_type))) + )(im.op_as_fieldop("plus", d2)(im.ref("inp1", field_type), im.ref("inp2", field_type))) actual = fuse_as_fieldop.FuseAsFieldOp.apply( testee, offset_provider_type={"IOff": IDim}, allow_undeclared_symbols=True ) assert actual == testee +def test_staged_inlining(): + d = im.domain("cartesian_domain", {IDim: (0, 1)}) + testee = im.let( + "tmp", im.op_as_fieldop("plus", d)(im.ref("a", field_type), im.ref("b", field_type)) + )( + im.op_as_fieldop("plus", d)( + im.op_as_fieldop(im.lambda_("a")(im.plus("a", 1)), d)("tmp"), + im.op_as_fieldop(im.lambda_("a")(im.plus("a", 2)), d)("tmp"), + ) + ) + expected = im.as_fieldop( + im.lambda_("a", "b")( + im.let("_icdlv_1", im.plus(im.deref("a"), im.deref("b")))( + im.plus(im.plus("_icdlv_1", 1), im.plus("_icdlv_1", 2)) + ) + ), + d, + )(im.ref("a", field_type), im.ref("b", field_type)) + actual = fuse_as_fieldop.FuseAsFieldOp.apply( + testee, offset_provider_type={}, allow_undeclared_symbols=True + ) + assert actual == expected + + +def test_make_tuple_fusion_trivial(): + d = im.domain("cartesian_domain", {IDim: (0, 1)}) + testee = im.make_tuple( + im.as_fieldop("deref", d)(im.ref("a", field_type)), + im.as_fieldop("deref", d)(im.ref("a", field_type)), + ) + expected = im.as_fieldop( + im.lambda_("a")(im.make_tuple(im.deref("a"), im.deref("a"))), + d, + )(im.ref("a", field_type)) + actual = fuse_as_fieldop.FuseAsFieldOp.apply( + testee, offset_provider_type={}, allow_undeclared_symbols=True + ) + # simplify to remove unnecessary make_tuple call `{v[0], v[1]}(actual)` + actual_simplified = collapse_tuple.CollapseTuple.apply( + actual, within_stencil=False, allow_undeclared_symbols=True + ) + assert actual_simplified == expected + + +def 
test_make_tuple_fusion_different_domains(): + d1 = im.domain("cartesian_domain", {IDim: (0, 1)}) + d2 = im.domain("cartesian_domain", {JDim: (0, 1)}) + field_i_type = ts.FieldType(dims=[IDim], dtype=ts.ScalarType(kind=ts.ScalarKind.INT32)) + field_j_type = ts.FieldType(dims=[JDim], dtype=ts.ScalarType(kind=ts.ScalarKind.INT32)) + testee = im.make_tuple( + im.as_fieldop("deref", d1)(im.ref("a", field_i_type)), + im.as_fieldop("deref", d2)(im.ref("b", field_j_type)), + im.as_fieldop("deref", d1)(im.ref("c", field_i_type)), + im.as_fieldop("deref", d2)(im.ref("d", field_j_type)), + ) + expected = im.let( + ( + "__fasfop_1", + im.as_fieldop(im.lambda_("a", "c")(im.make_tuple(im.deref("a"), im.deref("c"))), d1)( + "a", "c" + ), + ), + ( + "__fasfop_2", + im.as_fieldop(im.lambda_("b", "d")(im.make_tuple(im.deref("b"), im.deref("d"))), d2)( + "b", "d" + ), + ), + )( + im.make_tuple( + im.tuple_get(0, "__fasfop_1"), + im.tuple_get(0, "__fasfop_2"), + im.tuple_get(1, "__fasfop_1"), + im.tuple_get(1, "__fasfop_2"), + ) + ) + actual = fuse_as_fieldop.FuseAsFieldOp.apply( + testee, offset_provider_type={}, allow_undeclared_symbols=True + ) + assert actual == expected + + def test_partial_inline(): d1 = im.domain("cartesian_domain", {IDim: (1, 2)}) d2 = im.domain("cartesian_domain", {IDim: (0, 3)}) testee = im.as_fieldop( # first argument read at multiple locations -> not inlined - # second argument only reat at a single location -> inlined + # second argument only read at a single location -> inlined im.lambda_("a", "b")( im.plus( im.plus(im.deref(im.shift("IOff", 1)("a")), im.deref(im.shift("IOff", -1)("a"))), @@ -120,19 +222,47 @@ def test_partial_inline(): ), d1, )( - im.as_fieldop(im.lambda_("inp1")(im.deref("inp1")), d2)(im.ref("inp1", field_type)), - im.as_fieldop(im.lambda_("inp1")(im.deref("inp1")), d2)(im.ref("inp1", field_type)), + im.op_as_fieldop("plus", d2)(im.ref("inp1", field_type), im.ref("inp2", field_type)), + im.op_as_fieldop("plus", d2)(im.ref("inp1", field_type), im.ref("inp2", field_type)), ) expected = im.as_fieldop( - im.lambda_("a", "inp1")( + im.lambda_("a", "inp1", "inp2")( im.plus( im.plus(im.deref(im.shift("IOff", 1)("a")), im.deref(im.shift("IOff", -1)("a"))), - im.deref("inp1"), + im.plus(im.deref("inp1"), im.deref("inp2")), ) ), d1, - )(im.as_fieldop(im.lambda_("inp1")(im.deref("inp1")), d2)(im.ref("inp1", field_type)), "inp1") + )( + im.op_as_fieldop("plus", d2)(im.ref("inp1", field_type), im.ref("inp2", field_type)), + "inp1", + "inp2", + ) actual = fuse_as_fieldop.FuseAsFieldOp.apply( testee, offset_provider_type={"IOff": IDim}, allow_undeclared_symbols=True ) assert actual == expected + + +def test_chained_fusion(): + d = im.domain("cartesian_domain", {IDim: (0, 1)}) + testee = im.let( + "a", im.op_as_fieldop("plus", d)(im.ref("inp1", field_type), im.ref("inp2", field_type)) + )( + im.op_as_fieldop("plus", d)( + im.as_fieldop("deref", d)(im.ref("a", field_type)), + im.as_fieldop("deref", d)(im.ref("a", field_type)), + ) + ) + expected = im.as_fieldop( + im.lambda_("inp1", "inp2")( + im.let("_icdlv_1", im.plus(im.deref("inp1"), im.deref("inp2")))( + im.plus("_icdlv_1", "_icdlv_1") + ) + ), + d, + )(im.ref("inp1", field_type), im.ref("inp2", field_type)) + actual = fuse_as_fieldop.FuseAsFieldOp.apply( + testee, offset_provider_type={}, allow_undeclared_symbols=True + ) + assert actual == expected From 7c271cad395844c8281b46bb3a5e737e313a73fe Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Mon, 2 Dec 2024 13:10:52 +0100 Subject: [PATCH 120/150] Small fix --- 
src/gt4py/next/iterator/transforms/fuse_as_fieldop.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py index 3099cf9f31..e8a221b814 100644 --- a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py +++ b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py @@ -53,7 +53,7 @@ def _canonicalize_as_fieldop(expr: itir.FunCall) -> itir.FunCall: if cpm.is_ref_to(stencil, "deref"): stencil = im.lambda_("arg")(im.deref("arg")) new_expr = im.as_fieldop(stencil, domain)(*expr.args) - type_inference.copy_type(from_=expr, to=new_expr) + type_inference.copy_type(from_=expr, to=new_expr, allow_untyped=True) return new_expr @@ -164,7 +164,7 @@ def fuse_as_fieldop( ) new_node = inline_lifts.InlineLifts().visit(new_node) - type_inference.copy_type(from_=expr, to=new_node) + type_inference.copy_type(from_=expr, to=new_node, allow_untyped=True) return new_node From 22dbf5e47f8e622f8e0588807d85da199aa3fb94 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Mon, 2 Dec 2024 15:32:40 +0100 Subject: [PATCH 121/150] Fix inlining of multiple-use, dynamically-calculated neighbor field --- .../next/iterator/transforms/fuse_as_fieldop.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py index 74ef61d4f6..afefa1dbd2 100644 --- a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py +++ b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py @@ -257,6 +257,18 @@ def apply( return cls(uids=uids).visit(node) def visit_FunCall(self, node: itir.FunCall, **kwargs): + # inline all fields with list dtype. This needs to happen before the children are visited + # such that the `as_fieldop` can be fused. + # TODO(tehrengruber): what should we do in case the field with list dtype is a let itself? + # This could duplicate other expressions which we did not intend to duplicate. + # TODO(tehrengruber): Write test-case. E.g. Adding two sparse fields. Sara observed this + # with a cast to a sparse field, but this is likely already covered. 
+ if cpm.is_let(node): + eligible_args = [isinstance(arg.type, ts.FieldType) and isinstance(arg.type.dtype, it_ts.ListType) for arg in node.args] + if any(eligible_args): + node = inline_lambdas.inline_lambda(node, eligible_params=eligible_args) + return self.visit(node) + if cpm.is_applied_as_fieldop(node): # don't descend in stencil old_node = node node = im.as_fieldop(*node.fun.args)(*self.generic_visit(node.args)) # type: ignore[attr-defined] # ensured by cpm.is_applied_as_fieldop From bfacda67aa3ad3c752d64dd2d7727334f141cad5 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Tue, 3 Dec 2024 00:34:34 +0100 Subject: [PATCH 122/150] Fix broken iterator tests containing lifts --- src/gt4py/next/iterator/transforms/fuse_as_fieldop.py | 11 +++++++++-- src/gt4py/next/iterator/transforms/pass_manager.py | 4 ++++ 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py index afefa1dbd2..ecb0f35217 100644 --- a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py +++ b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py @@ -181,7 +181,11 @@ def fuse_as_fieldop( def _arg_inline_predicate(node: itir.Expr, shifts): if _is_tuple_expr_of_literals(node): return True - if (is_applied_fieldop := cpm.is_applied_as_fieldop(node)) or cpm.is_call_to(node, "if_"): + # TODO(tehrengruber): write test case ensuring scan is not tried to be inlined (e.g. test_call_scan_operator_from_field_operator) + if ( + is_applied_fieldop := cpm.is_applied_as_fieldop(node) + and not cpm.is_call_to(node.fun.args[0], "scan") # type: ignore[attr-defined] # ensured by cpm.is_applied_as_fieldop + ) or cpm.is_call_to(node, "if_"): # always inline arg if it is an applied fieldop with only a single arg if is_applied_fieldop and len(node.args) == 1: return True @@ -264,7 +268,10 @@ def visit_FunCall(self, node: itir.FunCall, **kwargs): # TODO(tehrengruber): Write test-case. E.g. Adding two sparse fields. Sara observed this # with a cast to a sparse field, but this is likely already covered. if cpm.is_let(node): - eligible_args = [isinstance(arg.type, ts.FieldType) and isinstance(arg.type.dtype, it_ts.ListType) for arg in node.args] + eligible_args = [ + isinstance(arg.type, ts.FieldType) and isinstance(arg.type.dtype, it_ts.ListType) + for arg in node.args + ] if any(eligible_args): node = inline_lambdas.inline_lambda(node, eligible_params=eligible_args) return self.visit(node) diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index 32c2153e4b..4109a36539 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -78,6 +78,10 @@ def apply_common_transforms( ir = inline_fundefs.prune_unreferenced_fundefs(ir) # type: ignore[arg-type] # all previous passes return itir.Program ir = NormalizeShifts().visit(ir) + # TODO(tehrengruber): Many iterator test contain lifts that need to be inlined, e.g. + # test_can_deref. We didn't notice previously as FieldOpFusion did this implicitly everywhere. + ir = inline_lifts.InlineLifts().visit(ir) + # note: this increases the size of the tree # Inline. The domain inference can not handle "user" functions, e.g. `let f = λ(...) → ... 
in f(...)` ir = InlineLambdas.apply(ir, opcount_preserving=True, force_inline_lambda_args=True) From 92b753347d702f5bcecd74f5b3f7e6435aa8c7f5 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Tue, 3 Dec 2024 04:42:28 +0100 Subject: [PATCH 123/150] Support inlining into scan --- .../iterator/transforms/fuse_as_fieldop.py | 35 ++++++++++++++++--- .../inline_center_deref_lift_vars.py | 4 +-- .../next/iterator/transforms/trace_shifts.py | 10 +++--- .../transforms_tests/test_fuse_as_fieldop.py | 11 ++++++ 4 files changed, 50 insertions(+), 10 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py index ecb0f35217..a570c719c7 100644 --- a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py +++ b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py @@ -109,12 +109,37 @@ def _inline_as_fieldop_arg( ), extracted_args +def _unwrap_scan(stencil: itir.Lambda | itir.FunCall): + if cpm.is_call_to(stencil, "scan"): + scan_pass, direction, init = stencil.args + assert isinstance(scan_pass, itir.Lambda) + # remove scan pass state to be used by caller + state_param = scan_pass.params[0] + stencil_like = im.lambda_(*scan_pass.params[1:])(scan_pass.expr) + + def restore_scan(transformed_stencil_like: itir.Lambda): + new_scan_pass = im.lambda_(state_param, *transformed_stencil_like.params)( + im.call(transformed_stencil_like)( + *(param.id for param in transformed_stencil_like.params) + ) + ) + return im.call("scan")(new_scan_pass, direction, init) + + return stencil_like, restore_scan + + assert isinstance(stencil, itir.Lambda) + return stencil, lambda s: s + + def fuse_as_fieldop( expr: itir.Expr, eligible_args: list[bool], *, uids: eve_utils.UIDGenerator ) -> itir.Expr: - assert cpm.is_applied_as_fieldop(expr) and isinstance(expr.fun.args[0], itir.Lambda) # type: ignore[attr-defined] # ensured by is_applied_as_fieldop + assert cpm.is_applied_as_fieldop(expr) stencil: itir.Lambda = expr.fun.args[0] # type: ignore[attr-defined] # ensured by is_applied_as_fieldop + assert isinstance(expr.fun.args[0], itir.Lambda) or cpm.is_call_to(stencil, "scan") # type: ignore[attr-defined] # ensured by is_applied_as_fieldop + stencil, restore_scan = _unwrap_scan(stencil) + domain = expr.fun.args[1] if len(expr.fun.args) > 1 else None # type: ignore[attr-defined] # ensured by is_applied_as_fieldop args: list[itir.Expr] = expr.args @@ -158,6 +183,7 @@ def fuse_as_fieldop( new_args = _merge_arguments(new_args, {new_param: arg}) stencil = im.lambda_(*new_args.keys())(new_stencil_body) + stencil = restore_scan(stencil) # simplify stencil directly to keep the tree small new_stencil = inline_lambdas.InlineLambdas.apply( @@ -333,10 +359,11 @@ def visit_FunCall(self, node: itir.FunCall, **kwargs): if new_node is not node: # nothing has been inlined return self.visit(new_node, **kwargs) - if cpm.is_call_to(node.fun, "as_fieldop") and isinstance(node.fun.args[0], itir.Lambda): - stencil: itir.Lambda = node.fun.args[0] + if cpm.is_call_to(node.fun, "as_fieldop"): + stencil = node.fun.args[0] + assert isinstance(stencil, itir.Lambda) or cpm.is_call_to(stencil, "scan") args: list[itir.Expr] = node.args - shifts = trace_shifts.trace_stencil(stencil) + shifts = trace_shifts.trace_stencil(stencil, num_args=len(args)) eligible_args = [ _arg_inline_predicate(arg, arg_shifts) diff --git a/src/gt4py/next/iterator/transforms/inline_center_deref_lift_vars.py b/src/gt4py/next/iterator/transforms/inline_center_deref_lift_vars.py index 9a94a4a338..9169c26769 
100644 --- a/src/gt4py/next/iterator/transforms/inline_center_deref_lift_vars.py +++ b/src/gt4py/next/iterator/transforms/inline_center_deref_lift_vars.py @@ -59,8 +59,8 @@ def apply( if not uids: uids = eve_utils.UIDGenerator() if is_stencil: - assert isinstance(node, itir.Lambda) - trace_shifts.trace_stencil(node, num_args=len(node.params), save_to_annex=True) + assert isinstance(node, itir.Expr) + trace_shifts.trace_stencil(node, save_to_annex=True) return cls(uids=uids).visit(node) def visit_FunCall(self, node: itir.FunCall, **kwargs): diff --git a/src/gt4py/next/iterator/transforms/trace_shifts.py b/src/gt4py/next/iterator/transforms/trace_shifts.py index 68346b6622..b2ab49deea 100644 --- a/src/gt4py/next/iterator/transforms/trace_shifts.py +++ b/src/gt4py/next/iterator/transforms/trace_shifts.py @@ -14,15 +14,14 @@ from gt4py import eve from gt4py.eve import NodeTranslator, PreserveLocationVisitor from gt4py.next.iterator import ir -from gt4py.next.iterator.ir_utils import ir_makers as im -from gt4py.next.iterator.ir_utils.common_pattern_matcher import is_applied_lift +from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im class ValidateRecordedShiftsAnnex(eve.NodeVisitor): """Ensure every applied lift and its arguments have the `recorded_shifts` annex populated.""" def visit_FunCall(self, node: ir.FunCall): - if is_applied_lift(node): + if cpm.is_applied_lift(node): assert hasattr(node.annex, "recorded_shifts") if len(node.annex.recorded_shifts) == 0: @@ -334,8 +333,11 @@ def trace_stencil( if isinstance(stencil, ir.Lambda): assert num_args is None or num_args == len(stencil.params) num_args = len(stencil.params) + elif cpm.is_call_to(stencil, "scan"): + assert isinstance(stencil.args[0], ir.Lambda) + num_args = len(stencil.args[0].params) - 1 if not isinstance(num_args, int): - raise ValueError("Stencil must be an 'itir.Lambda' or `num_args` is given.") + raise ValueError("Stencil must be an 'itir.Lambda', scan, or `num_args` is given.") assert isinstance(num_args, int) args = [im.ref(f"__arg{i}") for i in range(num_args)] diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_fuse_as_fieldop.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_fuse_as_fieldop.py index 97be4411eb..e9cb016313 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_fuse_as_fieldop.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_fuse_as_fieldop.py @@ -266,3 +266,14 @@ def test_chained_fusion(): testee, offset_provider_type={}, allow_undeclared_symbols=True ) assert actual == expected + + +def test_inline_into_scan(): + d = im.domain("cartesian_domain", {IDim: (0, 1)}) + scan = im.call("scan")(im.lambda_("state", "a")(im.plus("state", im.deref("a"))), True, 0.0) + testee = im.as_fieldop(scan, d)(im.as_fieldop("deref")(im.ref("a", field_type))) + expected = im.as_fieldop(scan, d)(im.ref("a", field_type)) + actual = fuse_as_fieldop.FuseAsFieldOp.apply( + testee, offset_provider_type={}, allow_undeclared_symbols=True + ) + assert actual == expected From 1eeca75880c98c16493e661bb0b89a6467216e0b Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Tue, 3 Dec 2024 04:48:01 +0100 Subject: [PATCH 124/150] Small fix --- src/gt4py/next/iterator/transforms/fuse_as_fieldop.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py index a570c719c7..c987849717 100644 --- 
a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py +++ b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py @@ -54,7 +54,7 @@ def _canonicalize_as_fieldop(expr: itir.FunCall) -> itir.FunCall: if cpm.is_ref_to(stencil, "deref"): stencil = im.lambda_("arg")(im.deref("arg")) new_expr = im.as_fieldop(stencil, domain)(*expr.args) - type_inference.copy_type(from_=expr, to=new_expr) + type_inference.copy_type(from_=expr, to=new_expr, allow_untyped=True) return new_expr @@ -200,7 +200,7 @@ def fuse_as_fieldop( new_node = im.as_fieldop(new_stencil, domain)(*new_args.values()) - type_inference.copy_type(from_=expr, to=new_node) + type_inference.copy_type(from_=expr, to=new_node, allow_untyped=True) return new_node From e3306fc957d46d26c98158f899283f2e7428d6ef Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Thu, 5 Dec 2024 10:19:05 +0100 Subject: [PATCH 125/150] Address review comments. --- .../next/iterator/transforms/infer_domain.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/infer_domain.py b/src/gt4py/next/iterator/transforms/infer_domain.py index 3809fb43d7..f61b78c0ad 100644 --- a/src/gt4py/next/iterator/transforms/infer_domain.py +++ b/src/gt4py/next/iterator/transforms/infer_domain.py @@ -30,7 +30,7 @@ class DomainAccessDescriptor(eve.StrEnum): Descriptor for domains that could not be inferred. """ - #: The access if unknown because of a dynamic shift.whose extent is not known. + #: The access is unknown because of a dynamic shift.whose extent is not known. #: E.g.: `(⇑(λ(arg0, arg1) → ·⟪Ioffₒ, ·arg1⟫(arg0)))(in_field1, in_field2)` UNKNOWN = "unknown" #: The domain is never accessed. @@ -83,7 +83,7 @@ def _domain_union( return DomainAccessDescriptor.UNKNOWN filtered_domains: list[domain_utils.SymbolicDomain] = [ - d # type: ignore[misc] # domain can never be unknown because as these cases are filtered above + d # type: ignore[misc] # domain can never be unknown as these cases are filtered above for d in domains if d != DomainAccessDescriptor.NEVER ] @@ -153,7 +153,6 @@ def _extract_accessed_domains( target_domain: domain_utils.SymbolicDomain | DomainAccessDescriptor, offset_provider: common.OffsetProvider, symbolic_domain_sizes: Optional[dict[str, str]], - allow_uninferred: bool, ) -> ACCESSED_DOMAINS: accessed_domains: dict[str, domain_utils.SymbolicDomain | DomainAccessDescriptor] = {} @@ -178,6 +177,8 @@ def _extract_accessed_domains( accessed_domains.get(in_field_id, DomainAccessDescriptor.NEVER), *new_domains ) + # Widen type to allow callee to all other types that can be in ACCESSED_DOMAINS, i.e. tuple. + # Fine since we transfer ownership of return value to callee. return typing.cast(ACCESSED_DOMAINS, accessed_domains) @@ -196,10 +197,7 @@ def _infer_as_fieldop( # FIXME[#1582](tehrengruber): Temporary solution for `tuple_get` on scan result. See `test_solve_triag`. if isinstance(target_domain, tuple): target_domain = _domain_union(*flatten_nested_tuple(target_domain)) # type: ignore[arg-type] # mypy not smart enough - if not isinstance(target_domain, (domain_utils.SymbolicDomain, DomainAccessDescriptor)): - raise ValueError( - "'target_domain' needs to be a 'domain_utils.SymbolicDomain' or a 'DomainAccessDescriptor'." 
- ) + assert isinstance(target_domain, (domain_utils.SymbolicDomain, DomainAccessDescriptor)) # `as_fieldop(stencil)(inputs...)` stencil, inputs = applied_fieldop.fun.args[0], applied_fieldop.args @@ -222,7 +220,7 @@ def _infer_as_fieldop( input_ids.append(id_) inputs_accessed_domains: ACCESSED_DOMAINS = _extract_accessed_domains( - stencil, input_ids, target_domain, offset_provider, symbolic_domain_sizes, allow_uninferred + stencil, input_ids, target_domain, offset_provider, symbolic_domain_sizes ) # Recursively infer domain of inputs and update domain arg of nested `as_fieldop`s @@ -406,7 +404,7 @@ def infer_expr( - symbolic_domain_sizes: A dictionary mapping axes names, e.g., `I`, `Vertex`, to a symbol name that evaluates to the length of that axis. - allow_uninferred: Allow `as_fieldop` expressions whose domain is either unknown (e.g. - because of a dynamic shift) or empty. + because of a dynamic shift) or never accessed. Returns: A tuple containing the inferred expression with all applied `as_fieldop` (that are accessed) From 2928503374ebc1aa91acfb16e5fab74008dd65af Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Thu, 5 Dec 2024 15:05:09 +0100 Subject: [PATCH 126/150] Address review comments. --- .../next/iterator/transforms/infer_domain.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/infer_domain.py b/src/gt4py/next/iterator/transforms/infer_domain.py index f61b78c0ad..60346e069c 100644 --- a/src/gt4py/next/iterator/transforms/infer_domain.py +++ b/src/gt4py/next/iterator/transforms/infer_domain.py @@ -38,7 +38,12 @@ class DomainAccessDescriptor(eve.StrEnum): NEVER = "never" -DOMAIN: TypeAlias = domain_utils.SymbolicDomain | DomainAccessDescriptor | tuple["DOMAIN", ...] +NON_TUPLE_DOMAIN = domain_utils.SymbolicDomain | DomainAccessDescriptor +#: The domain can also be a tuple of domains, usually this only occurs for scan operators returning +#: a tuple since other occurrences for tuples are removed before domain inference. This is +#: however not a requirement of the pass and `make_tuple(vertex_field, edge_field)` infers just +#: fine to a tuple of a vertexn and an edge domain. +DOMAIN: TypeAlias = NON_TUPLE_DOMAIN | tuple["DOMAIN", ...] ACCESSED_DOMAINS: TypeAlias = dict[str, DOMAIN] @@ -150,11 +155,11 @@ def _merge_domains( def _extract_accessed_domains( stencil: itir.Expr, input_ids: list[str], - target_domain: domain_utils.SymbolicDomain | DomainAccessDescriptor, + target_domain: NON_TUPLE_DOMAIN, offset_provider: common.OffsetProvider, symbolic_domain_sizes: Optional[dict[str, str]], -) -> ACCESSED_DOMAINS: - accessed_domains: dict[str, domain_utils.SymbolicDomain | DomainAccessDescriptor] = {} +) -> dict[str, NON_TUPLE_DOMAIN]: + accessed_domains: dict[str, NON_TUPLE_DOMAIN] = {} shifts_results = trace_shifts.trace_stencil(stencil, num_args=len(input_ids)) @@ -177,9 +182,7 @@ def _extract_accessed_domains( accessed_domains.get(in_field_id, DomainAccessDescriptor.NEVER), *new_domains ) - # Widen type to allow callee to all other types that can be in ACCESSED_DOMAINS, i.e. tuple. - # Fine since we transfer ownership of return value to callee. 
- return typing.cast(ACCESSED_DOMAINS, accessed_domains) + return accessed_domains def _infer_as_fieldop( @@ -219,7 +222,7 @@ def _infer_as_fieldop( raise ValueError(f"Unsupported expression of type '{type(in_field)}'.") input_ids.append(id_) - inputs_accessed_domains: ACCESSED_DOMAINS = _extract_accessed_domains( + inputs_accessed_domains: dict[str, NON_TUPLE_DOMAIN] = _extract_accessed_domains( stencil, input_ids, target_domain, offset_provider, symbolic_domain_sizes ) From 447b4673788c920d4e7295d27c635f9a98581ab7 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Thu, 5 Dec 2024 15:13:18 +0100 Subject: [PATCH 127/150] Address review comments. --- .../next/iterator/transforms/infer_domain.py | 58 +++++++++---------- 1 file changed, 29 insertions(+), 29 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/infer_domain.py b/src/gt4py/next/iterator/transforms/infer_domain.py index 60346e069c..0332d60168 100644 --- a/src/gt4py/next/iterator/transforms/infer_domain.py +++ b/src/gt4py/next/iterator/transforms/infer_domain.py @@ -38,13 +38,13 @@ class DomainAccessDescriptor(eve.StrEnum): NEVER = "never" -NON_TUPLE_DOMAIN = domain_utils.SymbolicDomain | DomainAccessDescriptor +NonTupleDomain: TypeAlias = domain_utils.SymbolicDomain | DomainAccessDescriptor #: The domain can also be a tuple of domains, usually this only occurs for scan operators returning #: a tuple since other occurrences for tuples are removed before domain inference. This is #: however not a requirement of the pass and `make_tuple(vertex_field, edge_field)` infers just #: fine to a tuple of a vertexn and an edge domain. -DOMAIN: TypeAlias = NON_TUPLE_DOMAIN | tuple["DOMAIN", ...] -ACCESSED_DOMAINS: TypeAlias = dict[str, DOMAIN] +Domain: TypeAlias = NonTupleDomain | tuple["Domain", ...] +AccessedDomains: TypeAlias = dict[str, Domain] class InferenceOptions(typing.TypedDict): @@ -97,7 +97,7 @@ def _domain_union( return domain_utils.domain_union(*filtered_domains) -def _canonicalize_domain_structure(d1: DOMAIN, d2: DOMAIN) -> tuple[DOMAIN, DOMAIN]: +def _canonicalize_domain_structure(d1: Domain, d2: Domain) -> tuple[Domain, Domain]: """ Given two domains or composites thereof, canonicalize their structure. 
@@ -138,9 +138,9 @@ def _canonicalize_domain_structure(d1: DOMAIN, d2: DOMAIN) -> tuple[DOMAIN, DOMA def _merge_domains( - original_domains: ACCESSED_DOMAINS, - additional_domains: ACCESSED_DOMAINS, -) -> ACCESSED_DOMAINS: + original_domains: AccessedDomains, + additional_domains: AccessedDomains, +) -> AccessedDomains: new_domains = {**original_domains} for key, domain in additional_domains.items(): @@ -155,11 +155,11 @@ def _merge_domains( def _extract_accessed_domains( stencil: itir.Expr, input_ids: list[str], - target_domain: NON_TUPLE_DOMAIN, + target_domain: NonTupleDomain, offset_provider: common.OffsetProvider, symbolic_domain_sizes: Optional[dict[str, str]], -) -> dict[str, NON_TUPLE_DOMAIN]: - accessed_domains: dict[str, NON_TUPLE_DOMAIN] = {} +) -> dict[str, NonTupleDomain]: + accessed_domains: dict[str, NonTupleDomain] = {} shifts_results = trace_shifts.trace_stencil(stencil, num_args=len(input_ids)) @@ -187,12 +187,12 @@ def _extract_accessed_domains( def _infer_as_fieldop( applied_fieldop: itir.FunCall, - target_domain: DOMAIN, + target_domain: Domain, *, offset_provider: common.OffsetProvider, symbolic_domain_sizes: Optional[dict[str, str]], allow_uninferred: bool, -) -> tuple[itir.FunCall, ACCESSED_DOMAINS]: +) -> tuple[itir.FunCall, AccessedDomains]: assert isinstance(applied_fieldop, itir.FunCall) assert cpm.is_call_to(applied_fieldop.fun, "as_fieldop") if not allow_uninferred and target_domain is DomainAccessDescriptor.NEVER: @@ -222,12 +222,12 @@ def _infer_as_fieldop( raise ValueError(f"Unsupported expression of type '{type(in_field)}'.") input_ids.append(id_) - inputs_accessed_domains: dict[str, NON_TUPLE_DOMAIN] = _extract_accessed_domains( + inputs_accessed_domains: dict[str, NonTupleDomain] = _extract_accessed_domains( stencil, input_ids, target_domain, offset_provider, symbolic_domain_sizes ) # Recursively infer domain of inputs and update domain arg of nested `as_fieldop`s - accessed_domains: ACCESSED_DOMAINS = {} + accessed_domains: AccessedDomains = {} transformed_inputs: list[itir.Expr] = [] for in_field_id, in_field in zip(input_ids, inputs): transformed_input, accessed_domains_tmp = infer_expr( @@ -258,9 +258,9 @@ def _infer_as_fieldop( def _infer_let( let_expr: itir.FunCall, - input_domain: DOMAIN, + input_domain: Domain, **kwargs: Unpack[InferenceOptions], -) -> tuple[itir.FunCall, ACCESSED_DOMAINS]: +) -> tuple[itir.FunCall, AccessedDomains]: assert cpm.is_let(let_expr) assert isinstance(let_expr.fun, itir.Lambda) # just to make mypy happy let_params = {param_sym.id for param_sym in let_expr.fun.params} @@ -296,12 +296,12 @@ def _infer_let( def _infer_make_tuple( expr: itir.Expr, - domain: DOMAIN, + domain: Domain, **kwargs: Unpack[InferenceOptions], -) -> tuple[itir.Expr, ACCESSED_DOMAINS]: +) -> tuple[itir.Expr, AccessedDomains]: assert cpm.is_call_to(expr, "make_tuple") infered_args_expr = [] - actual_domains: ACCESSED_DOMAINS = {} + actual_domains: AccessedDomains = {} if not isinstance(domain, tuple): # promote domain to a tuple of domains such that it has the same structure as # the expression @@ -323,11 +323,11 @@ def _infer_make_tuple( def _infer_tuple_get( expr: itir.Expr, - domain: DOMAIN, + domain: Domain, **kwargs: Unpack[InferenceOptions], -) -> tuple[itir.Expr, ACCESSED_DOMAINS]: +) -> tuple[itir.Expr, AccessedDomains]: assert cpm.is_call_to(expr, "tuple_get") - actual_domains: ACCESSED_DOMAINS = {} + actual_domains: AccessedDomains = {} idx_expr, tuple_arg = expr.args assert isinstance(idx_expr, itir.Literal) idx = int(idx_expr.value) @@ 
-343,12 +343,12 @@ def _infer_tuple_get( def _infer_if( expr: itir.Expr, - domain: DOMAIN, + domain: Domain, **kwargs: Unpack[InferenceOptions], -) -> tuple[itir.Expr, ACCESSED_DOMAINS]: +) -> tuple[itir.Expr, AccessedDomains]: assert cpm.is_call_to(expr, "if_") infered_args_expr = [] - actual_domains: ACCESSED_DOMAINS = {} + actual_domains: AccessedDomains = {} cond, true_val, false_val = expr.args for arg in [true_val, false_val]: infered_arg_expr, actual_domains_arg = infer_expr(arg, domain, **kwargs) @@ -360,9 +360,9 @@ def _infer_if( def _infer_expr( expr: itir.Expr, - domain: DOMAIN, + domain: Domain, **kwargs: Unpack[InferenceOptions], -) -> tuple[itir.Expr, ACCESSED_DOMAINS]: +) -> tuple[itir.Expr, AccessedDomains]: if isinstance(expr, itir.SymRef): return expr, {str(expr.id): domain} elif isinstance(expr, itir.Literal): @@ -389,12 +389,12 @@ def _infer_expr( def infer_expr( expr: itir.Expr, - domain: DOMAIN, + domain: Domain, *, offset_provider: common.OffsetProvider, symbolic_domain_sizes: Optional[dict[str, str]] = None, allow_uninferred: bool = False, -) -> tuple[itir.Expr, ACCESSED_DOMAINS]: +) -> tuple[itir.Expr, AccessedDomains]: """ Infer the domain of all field subexpressions of `expr`. From 54cd6b17775d41013f1661151297697368789453 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Thu, 5 Dec 2024 15:24:59 +0100 Subject: [PATCH 128/150] Fix type annotation --- .../iterator_tests/transforms_tests/test_domain_inference.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py index 7a7a307901..2b215a59fa 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py @@ -125,8 +125,8 @@ def constant_fold_domain_exprs(arg: itir.Node) -> itir.Node: def constant_fold_accessed_domains( - domains: infer_domain.ACCESSED_DOMAINS, -) -> infer_domain.ACCESSED_DOMAINS: + domains: infer_domain.AccessedDomains, +) -> infer_domain.AccessedDomains: def fold_domain( domain: domain_utils.SymbolicDomain | Literal[infer_domain.DomainAccessDescriptor.NEVER], ): From 914a9e5bf39f564c011004cdf49ddf59f9f73c44 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Fri, 6 Dec 2024 10:19:07 +0100 Subject: [PATCH 129/150] Address review comments --- .../iterator/transforms/collapse_tuple.py | 121 +++++++++--------- .../transforms_tests/test_collapse_tuple.py | 17 +++ 2 files changed, 78 insertions(+), 60 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/collapse_tuple.py b/src/gt4py/next/iterator/transforms/collapse_tuple.py index 7e2e5b127e..f9bcccbb26 100644 --- a/src/gt4py/next/iterator/transforms/collapse_tuple.py +++ b/src/gt4py/next/iterator/transforms/collapse_tuple.py @@ -229,7 +229,7 @@ def transform(self, node: ir.Node, **kwargs) -> Optional[ir.Node]: method = getattr(self, f"transform_{transformation.name.lower()}") result = method(node, **kwargs) if result is not None: - assert result is not node + assert result is not node # transformation should have returned None, since nothing changed itir_type_inference.reinfer(result) return result return None @@ -361,69 +361,70 @@ def transform_propagate_to_if_on_tuples(self, node: ir.FunCall, **kwargs) -> Opt def transform_propagate_to_if_on_tuples_cps( self, node: ir.FunCall, **kwargs ) -> Optional[ir.Node]: - if not 
cpm.is_call_to(node, "if_"): - for i, arg in enumerate(node.args): - if cpm.is_call_to(arg, "if_"): - itir_type_inference.reinfer(arg) - if not any(isinstance(branch.type, ts.TupleType) for branch in arg.args[1:]): - continue + if cpm.is_call_to(node, "if_"): + return None - cond, true_branch, false_branch = arg.args - tuple_type: ts.TupleType = true_branch.type # type: ignore[assignment] # type ensured above - tuple_len = len(tuple_type.types) - itir_type_inference.reinfer(node) - assert node.type - - # transform function into continuation-passing-style - f_type = ts.FunctionType( - pos_only_args=tuple_type.types, - pos_or_kw_args={}, - kw_only_args={}, - returns=node.type, - ) - f_params = [ - im.sym(self.uids.sequential_id(prefix="__ct_el_cps"), type_) - for type_ in tuple_type.types - ] - f_args = [im.ref(param.id, param.type) for param in f_params] - f_body = _with_altered_arg(node, i, im.make_tuple(*f_args)) - # simplify, e.g., inline trivial make_tuple args - new_f_body = self.fp_transform(f_body, **kwargs) - # if the function did not simplify there is nothing to gain. Skip - # transformation. - if new_f_body is f_body: - continue - # if the function is not trivial the transformation would still work, but - # inlining would result in a larger tree again and we didn't didn't gain - # anything compared to regular `propagate_to_if_on_tuples`. Not inling also - # works, but we don't want bound lambda functions in our tree (at least right - # now). - if not _is_trivial_or_tuple_thereof_expr(new_f_body): - continue - f = im.lambda_(*f_params)(new_f_body) - - tuple_var = self.uids.sequential_id(prefix="__ct_tuple_cps") - f_var = self.uids.sequential_id(prefix="__ct_cont") - new_branches = [] - for branch in arg.args[1:]: - new_branch = im.let(tuple_var, branch)( - im.call(im.ref(f_var, f_type))( - *( - im.tuple_get(i, im.ref(tuple_var, branch.type)) - for i in range(tuple_len) - ) + for i, arg in enumerate(node.args): + if cpm.is_call_to(arg, "if_"): + itir_type_inference.reinfer(arg) + if not any(isinstance(branch.type, ts.TupleType) for branch in arg.args[1:]): + continue + + cond, true_branch, false_branch = arg.args + tuple_type: ts.TupleType = true_branch.type # type: ignore[assignment] # type ensured above + tuple_len = len(tuple_type.types) + + # transform function into continuation-passing-style + itir_type_inference.reinfer(node) + assert node.type + f_type = ts.FunctionType( + pos_only_args=tuple_type.types, + pos_or_kw_args={}, + kw_only_args={}, + returns=node.type, + ) + f_params = [ + im.sym(self.uids.sequential_id(prefix="__ct_el_cps"), type_) + for type_ in tuple_type.types + ] + f_args = [im.ref(param.id, param.type) for param in f_params] + f_body = _with_altered_arg(node, i, im.make_tuple(*f_args)) + # simplify, e.g., inline trivial make_tuple args + new_f_body = self.fp_transform(f_body, **kwargs) + # if the function did not simplify there is nothing to gain. Skip + # transformation. + if new_f_body is f_body: + continue + # if the function is not trivial the transformation would still work, but + # inlining would result in a larger tree again and we didn't didn't gain + # anything compared to regular `propagate_to_if_on_tuples`. Not inling also + # works, but we don't want bound lambda functions in our tree (at least right + # now). 
+ if not _is_trivial_or_tuple_thereof_expr(new_f_body): + continue + f = im.lambda_(*f_params)(new_f_body) + + tuple_var = self.uids.sequential_id(prefix="__ct_tuple_cps") + f_var = self.uids.sequential_id(prefix="__ct_cont") + new_branches = [] + for branch in arg.args[1:]: + new_branch = im.let(tuple_var, branch)( + im.call(im.ref(f_var, f_type))( + *( + im.tuple_get(i, im.ref(tuple_var, branch.type)) + for i in range(tuple_len) ) ) - new_branches.append(self.fp_transform(new_branch, **kwargs)) - - new_node = im.let(f_var, f)(im.if_(cond, *new_branches)) - new_node = inline_lambda(new_node, eligible_params=[True]) - assert cpm.is_call_to(new_node, "if_") - new_node = im.if_( - cond, *(self.fp_transform(branch, **kwargs) for branch in new_node.args[1:]) ) - return new_node - return None + new_branches.append(self.fp_transform(new_branch, **kwargs)) + + new_node = im.let(f_var, f)(im.if_(cond, *new_branches)) + new_node = inline_lambda(new_node, eligible_params=[True]) + assert cpm.is_call_to(new_node, "if_") + new_node = im.if_( + cond, *(self.fp_transform(branch, **kwargs) for branch in new_node.args[1:]) + ) + return new_node def transform_propagate_nested_let(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: if cpm.is_let(node): diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py index 2212dfb6e1..f216b48856 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py @@ -256,6 +256,23 @@ def test_if_make_tuple_reorder_cps(): assert actual == expected +def test_if_make_tuple_reorder_cps(): + testee = im.let( + ("t1", im.if_(True, im.make_tuple(1, 2), im.make_tuple(3, 4))), + ("t2", im.if_(False, im.make_tuple(5, 6), im.make_tuple(7, 8))) + )( + im.make_tuple(im.tuple_get(1, "t"), im.tuple_get(0, "t")) + ) + expected = im.if_(True, im.if_(False, im.make_tuple(2, 1), im.make_tuple(4, 3))) + actual = CollapseTuple.apply( + testee, + flags=~CollapseTuple.Flag.PROPAGATE_TO_IF_ON_TUPLES, + allow_undeclared_symbols=True, + within_stencil=False, + ) + assert actual == expected + + def test_if_make_tuple_reorder_cps_nested(): testee = im.let("t", im.if_(True, im.make_tuple(1, 2), im.make_tuple(3, 4)))( im.let("c", im.tuple_get(0, "t"))( From e043d0a7cbf4d7ff0449a1d95fb09710d23a350b Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Fri, 6 Dec 2024 10:19:46 +0100 Subject: [PATCH 130/150] Add type annotation in extract_tmp pass. --- src/gt4py/next/iterator/transforms/global_tmps.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gt4py/next/iterator/transforms/global_tmps.py b/src/gt4py/next/iterator/transforms/global_tmps.py index fbf73a8f0b..cdfc1052fe 100644 --- a/src/gt4py/next/iterator/transforms/global_tmps.py +++ b/src/gt4py/next/iterator/transforms/global_tmps.py @@ -74,7 +74,7 @@ def _transform_by_pattern( # or a tuple thereof) # - one `SetAt` statement that materializes the expression into the temporary for tmp_sym, tmp_expr in extracted_fields.items(): - domain = tmp_expr.annex.domain + domain: infer_domain.Domain = tmp_expr.annex.domain # TODO(tehrengruber): Implement. This happens when the expression is a combination # of an `if_` call with a tuple, e.g., `if_(cond, {a, b}, {c, d})`. 
As long as we are From 20fb59ecbcaf15b66103d3d182474339144b956e Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Fri, 6 Dec 2024 16:07:01 +0100 Subject: [PATCH 131/150] Update src/gt4py/next/iterator/transforms/infer_domain.py Co-authored-by: Hannes Vogt --- src/gt4py/next/iterator/transforms/infer_domain.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gt4py/next/iterator/transforms/infer_domain.py b/src/gt4py/next/iterator/transforms/infer_domain.py index 0332d60168..c1915e4ba9 100644 --- a/src/gt4py/next/iterator/transforms/infer_domain.py +++ b/src/gt4py/next/iterator/transforms/infer_domain.py @@ -42,7 +42,7 @@ class DomainAccessDescriptor(eve.StrEnum): #: The domain can also be a tuple of domains, usually this only occurs for scan operators returning #: a tuple since other occurrences for tuples are removed before domain inference. This is #: however not a requirement of the pass and `make_tuple(vertex_field, edge_field)` infers just -#: fine to a tuple of a vertexn and an edge domain. +#: fine to a tuple of a vertex and an edge domain. Domain: TypeAlias = NonTupleDomain | tuple["Domain", ...] AccessedDomains: TypeAlias = dict[str, Domain] From d8d391320a65e15625c004786fa8f47ee7f524a0 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Fri, 6 Dec 2024 16:56:23 +0100 Subject: [PATCH 132/150] Address review comments --- .../next/iterator/transforms/global_tmps.py | 2 +- .../next/iterator/transforms/infer_domain.py | 32 ++++++++++--------- 2 files changed, 18 insertions(+), 16 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/global_tmps.py b/src/gt4py/next/iterator/transforms/global_tmps.py index cdfc1052fe..334fb330d7 100644 --- a/src/gt4py/next/iterator/transforms/global_tmps.py +++ b/src/gt4py/next/iterator/transforms/global_tmps.py @@ -74,7 +74,7 @@ def _transform_by_pattern( # or a tuple thereof) # - one `SetAt` statement that materializes the expression into the temporary for tmp_sym, tmp_expr in extracted_fields.items(): - domain: infer_domain.Domain = tmp_expr.annex.domain + domain: infer_domain.DomainAccess = tmp_expr.annex.domain # TODO(tehrengruber): Implement. This happens when the expression is a combination # of an `if_` call with a tuple, e.g., `if_(cond, {a, b}, {c, d})`. As long as we are diff --git a/src/gt4py/next/iterator/transforms/infer_domain.py b/src/gt4py/next/iterator/transforms/infer_domain.py index 0332d60168..9c7d2aa3aa 100644 --- a/src/gt4py/next/iterator/transforms/infer_domain.py +++ b/src/gt4py/next/iterator/transforms/infer_domain.py @@ -29,6 +29,8 @@ class DomainAccessDescriptor(eve.StrEnum): """ Descriptor for domains that could not be inferred. """ + # TODO(tehrengruber): Revisit this concept. It is strange that we don't have a descriptor + # `KNOWN`, but since we don't need it, it wasn't added. #: The access is unknown because of a dynamic shift.whose extent is not known. #: E.g.: `(⇑(λ(arg0, arg1) → ·⟪Ioffₒ, ·arg1⟫(arg0)))(in_field1, in_field2)` @@ -38,13 +40,13 @@ class DomainAccessDescriptor(eve.StrEnum): NEVER = "never" -NonTupleDomain: TypeAlias = domain_utils.SymbolicDomain | DomainAccessDescriptor +NonTupleDomainAccess: TypeAlias = domain_utils.SymbolicDomain | DomainAccessDescriptor #: The domain can also be a tuple of domains, usually this only occurs for scan operators returning #: a tuple since other occurrences for tuples are removed before domain inference. 
This is #: however not a requirement of the pass and `make_tuple(vertex_field, edge_field)` infers just #: fine to a tuple of a vertexn and an edge domain. -Domain: TypeAlias = NonTupleDomain | tuple["Domain", ...] -AccessedDomains: TypeAlias = dict[str, Domain] +DomainAccess: TypeAlias = NonTupleDomainAccess | tuple["DomainAccess", ...] +AccessedDomains: TypeAlias = dict[str, DomainAccess] class InferenceOptions(typing.TypedDict): @@ -97,7 +99,7 @@ def _domain_union( return domain_utils.domain_union(*filtered_domains) -def _canonicalize_domain_structure(d1: Domain, d2: Domain) -> tuple[Domain, Domain]: +def _canonicalize_domain_structure(d1: DomainAccess, d2: DomainAccess) -> tuple[DomainAccess, DomainAccess]: """ Given two domains or composites thereof, canonicalize their structure. @@ -155,11 +157,11 @@ def _merge_domains( def _extract_accessed_domains( stencil: itir.Expr, input_ids: list[str], - target_domain: NonTupleDomain, + target_domain: NonTupleDomainAccess, offset_provider: common.OffsetProvider, symbolic_domain_sizes: Optional[dict[str, str]], -) -> dict[str, NonTupleDomain]: - accessed_domains: dict[str, NonTupleDomain] = {} +) -> dict[str, NonTupleDomainAccess]: + accessed_domains: dict[str, NonTupleDomainAccess] = {} shifts_results = trace_shifts.trace_stencil(stencil, num_args=len(input_ids)) @@ -187,7 +189,7 @@ def _extract_accessed_domains( def _infer_as_fieldop( applied_fieldop: itir.FunCall, - target_domain: Domain, + target_domain: DomainAccess, *, offset_provider: common.OffsetProvider, symbolic_domain_sizes: Optional[dict[str, str]], @@ -222,7 +224,7 @@ def _infer_as_fieldop( raise ValueError(f"Unsupported expression of type '{type(in_field)}'.") input_ids.append(id_) - inputs_accessed_domains: dict[str, NonTupleDomain] = _extract_accessed_domains( + inputs_accessed_domains: dict[str, NonTupleDomainAccess] = _extract_accessed_domains( stencil, input_ids, target_domain, offset_provider, symbolic_domain_sizes ) @@ -258,7 +260,7 @@ def _infer_as_fieldop( def _infer_let( let_expr: itir.FunCall, - input_domain: Domain, + input_domain: DomainAccess, **kwargs: Unpack[InferenceOptions], ) -> tuple[itir.FunCall, AccessedDomains]: assert cpm.is_let(let_expr) @@ -296,7 +298,7 @@ def _infer_let( def _infer_make_tuple( expr: itir.Expr, - domain: Domain, + domain: DomainAccess, **kwargs: Unpack[InferenceOptions], ) -> tuple[itir.Expr, AccessedDomains]: assert cpm.is_call_to(expr, "make_tuple") @@ -323,7 +325,7 @@ def _infer_make_tuple( def _infer_tuple_get( expr: itir.Expr, - domain: Domain, + domain: DomainAccess, **kwargs: Unpack[InferenceOptions], ) -> tuple[itir.Expr, AccessedDomains]: assert cpm.is_call_to(expr, "tuple_get") @@ -343,7 +345,7 @@ def _infer_tuple_get( def _infer_if( expr: itir.Expr, - domain: Domain, + domain: DomainAccess, **kwargs: Unpack[InferenceOptions], ) -> tuple[itir.Expr, AccessedDomains]: assert cpm.is_call_to(expr, "if_") @@ -360,7 +362,7 @@ def _infer_if( def _infer_expr( expr: itir.Expr, - domain: Domain, + domain: DomainAccess, **kwargs: Unpack[InferenceOptions], ) -> tuple[itir.Expr, AccessedDomains]: if isinstance(expr, itir.SymRef): @@ -389,7 +391,7 @@ def _infer_expr( def infer_expr( expr: itir.Expr, - domain: Domain, + domain: DomainAccess, *, offset_provider: common.OffsetProvider, symbolic_domain_sizes: Optional[dict[str, str]] = None, From 58323710eab2257b11f113ada1f554ce92e67bc9 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Fri, 6 Dec 2024 16:57:31 +0100 Subject: [PATCH 133/150] Address review comments --- 
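A short, illustrative sketch of how the aliases renamed in the previous commit compose (not part of any patch; the dictionary key and the `...` placeholders below are made up for illustration). A `DomainAccess` is either a symbolic domain, a `DomainAccessDescriptor`, or a possibly nested tuple thereof, and `AccessedDomains` maps argument names to such accesses:

    from gt4py.next.iterator.transforms import infer_domain

    # concrete accesses would be domain_utils.SymbolicDomain instances; elided here
    vertex_domain: infer_domain.NonTupleDomainAccess = ...
    edge_domain: infer_domain.NonTupleDomainAccess = ...
    # what `make_tuple(vertex_field, edge_field)` infers to: a tuple of a vertex and an edge domain
    tuple_access: infer_domain.DomainAccess = (vertex_domain, edge_domain)
    # a tuple element that is never read is described by a descriptor instead of a domain
    unused: infer_domain.DomainAccess = infer_domain.DomainAccessDescriptor.NEVER
    accessed: infer_domain.AccessedDomains = {"in_field": tuple_access}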
src/gt4py/next/iterator/transforms/infer_domain.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/gt4py/next/iterator/transforms/infer_domain.py b/src/gt4py/next/iterator/transforms/infer_domain.py index ff9a303851..f26d3f9ec2 100644 --- a/src/gt4py/next/iterator/transforms/infer_domain.py +++ b/src/gt4py/next/iterator/transforms/infer_domain.py @@ -29,6 +29,7 @@ class DomainAccessDescriptor(eve.StrEnum): """ Descriptor for domains that could not be inferred. """ + # TODO(tehrengruber): Revisit this concept. It is strange that we don't have a descriptor # `KNOWN`, but since we don't need it, it wasn't added. @@ -99,7 +100,9 @@ def _domain_union( return domain_utils.domain_union(*filtered_domains) -def _canonicalize_domain_structure(d1: DomainAccess, d2: DomainAccess) -> tuple[DomainAccess, DomainAccess]: +def _canonicalize_domain_structure( + d1: DomainAccess, d2: DomainAccess +) -> tuple[DomainAccess, DomainAccess]: """ Given two domains or composites thereof, canonicalize their structure. From 668e3eccde9f76159376f6c5f4b0c59769a71d70 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Fri, 6 Dec 2024 20:43:52 +0100 Subject: [PATCH 134/150] Fix format --- src/gt4py/next/iterator/transforms/pass_manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index 763c3b27b9..d967c8fbb8 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -75,7 +75,7 @@ def apply_common_transforms( # required in order to get rid of expressions without a domain (e.g. when a tuple element is never accessed) ir = CollapseTuple.apply(ir, offset_provider_type=offset_provider_type) # type: ignore[assignment] # always an itir.Program ir = inline_dynamic_shifts.InlineDynamicShifts.apply( - ir # type: ignore[arg-type] # always an itir.Program + ir ) # domain inference does not support dynamic offsets yet ir = infer_domain.infer_program( ir, From fc46edf3d51af5c7159cc28f8d831d2e3a3f68c1 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Tue, 10 Dec 2024 14:35:35 +0100 Subject: [PATCH 135/150] Add test and fix nested transformation on nested ifs --- .../next/iterator/transforms/collapse_tuple.py | 10 +++++++++- .../transforms_tests/test_collapse_tuple.py | 17 +++++++++++++---- 2 files changed, 22 insertions(+), 5 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/collapse_tuple.py b/src/gt4py/next/iterator/transforms/collapse_tuple.py index f9bcccbb26..01ccbc8ab6 100644 --- a/src/gt4py/next/iterator/transforms/collapse_tuple.py +++ b/src/gt4py/next/iterator/transforms/collapse_tuple.py @@ -65,6 +65,8 @@ def _is_trivial_or_tuple_thereof_expr(node: ir.Node) -> bool: return all(_is_trivial_or_tuple_thereof_expr(arg) for arg in node.args) if cpm.is_call_to(node, "tuple_get"): return _is_trivial_or_tuple_thereof_expr(node.args[1]) + if cpm.is_call_to(node, "if_"): + return all(_is_trivial_or_tuple_thereof_expr(arg) for arg in node.args[1:]) if isinstance(node, (ir.SymRef, ir.Literal)): return True if cpm.is_let(node): @@ -229,7 +231,9 @@ def transform(self, node: ir.Node, **kwargs) -> Optional[ir.Node]: method = getattr(self, f"transform_{transformation.name.lower()}") result = method(node, **kwargs) if result is not None: - assert result is not node # transformation should have returned None, since nothing changed + assert ( + result is not node + ) # transformation should have returned None, 
since nothing changed itir_type_inference.reinfer(result) return result return None @@ -400,6 +404,8 @@ def transform_propagate_to_if_on_tuples_cps( # anything compared to regular `propagate_to_if_on_tuples`. Not inling also # works, but we don't want bound lambda functions in our tree (at least right # now). + # TODO(tehrengruber): `if_` of trivial expression is also considered fine. This + # will duplicate the condition and unnecessarily increase the size of the tree. if not _is_trivial_or_tuple_thereof_expr(new_f_body): continue f = im.lambda_(*f_params)(new_f_body) @@ -426,6 +432,8 @@ def transform_propagate_to_if_on_tuples_cps( ) return new_node + return None + def transform_propagate_nested_let(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: if cpm.is_let(node): # `let((a, let(b, 1)(a_val)))(a)`-> `let(b, 1)(let(a, a_val)(a))` diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py index f216b48856..5e2c07ef0a 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py @@ -256,14 +256,23 @@ def test_if_make_tuple_reorder_cps(): assert actual == expected -def test_if_make_tuple_reorder_cps(): +def test_nested_if_make_tuple_reorder_cps(): testee = im.let( ("t1", im.if_(True, im.make_tuple(1, 2), im.make_tuple(3, 4))), - ("t2", im.if_(False, im.make_tuple(5, 6), im.make_tuple(7, 8))) + ("t2", im.if_(False, im.make_tuple(5, 6), im.make_tuple(7, 8))), )( - im.make_tuple(im.tuple_get(1, "t"), im.tuple_get(0, "t")) + im.make_tuple( + im.tuple_get(1, "t1"), + im.tuple_get(0, "t1"), + im.tuple_get(1, "t2"), + im.tuple_get(0, "t2"), + ) + ) + expected = im.if_( + True, + im.if_(False, im.make_tuple(2, 1, 6, 5), im.make_tuple(2, 1, 8, 7)), + im.if_(False, im.make_tuple(4, 3, 6, 5), im.make_tuple(4, 3, 8, 7)), ) - expected = im.if_(True, im.if_(False, im.make_tuple(2, 1), im.make_tuple(4, 3))) actual = CollapseTuple.apply( testee, flags=~CollapseTuple.Flag.PROPAGATE_TO_IF_ON_TUPLES, From d04c4dcfc8ee8d30fe1ca3808be5a97dc513c7da Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Mon, 6 Jan 2025 14:27:40 +0100 Subject: [PATCH 136/150] Add additional make_tuple test cases --- .../iterator/transforms/fuse_as_fieldop.py | 1 + .../transforms_tests/test_fuse_as_fieldop.py | 39 +++++++++++++++++++ 2 files changed, 40 insertions(+) diff --git a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py index 4b7cc45adc..83cc922dc6 100644 --- a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py +++ b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py @@ -310,6 +310,7 @@ def visit_FunCall(self, node: itir.FunCall, **kwargs): node = self.generic_visit(node, **kwargs) if cpm.is_call_to(node, "make_tuple"): + # TODO(tehrengruber): x, y = alpha * y, x is not fused as_fieldop_args = [arg for arg in node.args if cpm.is_applied_as_fieldop(arg)] distinct_domains = set(arg.fun.args[1] for arg in as_fieldop_args) # type: ignore[attr-defined] # ensured by cpm.is_applied_as_fieldop if len(distinct_domains) != len(as_fieldop_args): diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_fuse_as_fieldop.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_fuse_as_fieldop.py index e9cb016313..1b95183422 100644 --- 
a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_fuse_as_fieldop.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_fuse_as_fieldop.py @@ -169,6 +169,45 @@ def test_make_tuple_fusion_trivial(): ) assert actual_simplified == expected +def test_make_tuple_fusion_symref(): + d = im.domain("cartesian_domain", {IDim: (0, 1)}) + testee = im.make_tuple( + im.as_fieldop("deref", d)(im.ref("a", field_type)), + im.ref("b", field_type), + ) + expected = im.as_fieldop( + im.lambda_("a", "b")(im.make_tuple(im.deref("a"), im.deref("b"))), + d, + )(im.ref("a", field_type), im.ref("b", field_type)) + actual = fuse_as_fieldop.FuseAsFieldOp.apply( + testee, offset_provider_type={}, allow_undeclared_symbols=True + ) + # simplify to remove unnecessary make_tuple call + actual_simplified = collapse_tuple.CollapseTuple.apply( + actual, within_stencil=False, allow_undeclared_symbols=True + ) + assert actual_simplified == expected + + +def test_make_tuple_fusion_symref2(): + d = im.domain("cartesian_domain", {IDim: (0, 1)}) + testee = im.make_tuple( + im.as_fieldop("deref", d)(im.ref("a", field_type)), + im.ref("a", field_type), + ) + expected = im.as_fieldop( + im.lambda_("a")(im.make_tuple(im.deref("a"), im.deref("a"))), + d, + )(im.ref("a", field_type)) + actual = fuse_as_fieldop.FuseAsFieldOp.apply( + testee, offset_provider_type={}, allow_undeclared_symbols=True + ) + # simplify to remove unnecessary make_tuple call + actual_simplified = collapse_tuple.CollapseTuple.apply( + actual, within_stencil=False, allow_undeclared_symbols=True + ) + assert actual_simplified == expected + def test_make_tuple_fusion_different_domains(): d1 = im.domain("cartesian_domain", {IDim: (0, 1)}) From 5e1a88cdb2bc8b0975da64ee5fa8cbf74d5fa4b0 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Fri, 10 Jan 2025 00:39:28 +0100 Subject: [PATCH 137/150] Improve documentation --- .../iterator/transforms/collapse_tuple.py | 48 ++++++++++++++----- 1 file changed, 35 insertions(+), 13 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/collapse_tuple.py b/src/gt4py/next/iterator/transforms/collapse_tuple.py index ea7aad890c..b2f75ad7d6 100644 --- a/src/gt4py/next/iterator/transforms/collapse_tuple.py +++ b/src/gt4py/next/iterator/transforms/collapse_tuple.py @@ -65,6 +65,8 @@ def _is_trivial_or_tuple_thereof_expr(node: ir.Node) -> bool: return all(_is_trivial_or_tuple_thereof_expr(arg) for arg in node.args) if cpm.is_call_to(node, "tuple_get"): return _is_trivial_or_tuple_thereof_expr(node.args[1]) + # This will duplicate the condition and increase the size of the tree, but this is probably + # acceptable. if cpm.is_call_to(node, "if_"): return all(_is_trivial_or_tuple_thereof_expr(arg) for arg in node.args[1:]) if isinstance(node, (ir.SymRef, ir.Literal)): @@ -365,6 +367,27 @@ def transform_propagate_to_if_on_tuples(self, node: ir.FunCall, **kwargs) -> Opt def transform_propagate_to_if_on_tuples_cps( self, node: ir.FunCall, **kwargs ) -> Optional[ir.Node]: + # The basic idea of this transformation is to remove tuples across if-stmts by rewriting + # the expression in continuation passing style, e.g. something like a tuple reordering + # ``` + # let t = if True then {1, 2} else {3, 4} in + # {t[1], t[0]}) + # end + # ``` + # is rewritten into: + # ``` + # let cont = λ(el0, el1) → {el1, el0} in + # if True then cont(1, 2) else cont(3, 4) + # end + # ``` + # Note how the `make_tuple` call argument of the `if` disappears. 
Since lambda functions + # are currently inlined (due to limitations of the domain inference) we will only + # gain something compared `PROPAGATE_TO_IF_ON_TUPLES` if the continuation `cont` is trivial, + # e.g. a `make_tuple` call like in the example. In that case we can inline the trivial + # continuation and end up with an only moderately larger tree, e.g. + # `if True then {2, 1} else {4, 3}`. The examples in the comments below all refer to this + # tuple reordering example here. + if cpm.is_call_to(node, "if_"): return None @@ -374,14 +397,14 @@ def transform_propagate_to_if_on_tuples_cps( if not any(isinstance(branch.type, ts.TupleType) for branch in arg.args[1:]): continue - cond, true_branch, false_branch = arg.args + cond, true_branch, false_branch = arg.args # e.g. `True`, `{1, 2}`, `{3, 4}` tuple_type: ts.TupleType = true_branch.type # type: ignore[assignment] # type ensured above tuple_len = len(tuple_type.types) - # transform function into continuation-passing-style + # build and simplify continuation, e.g. λ(el0, el1) → {el1, el0} itir_type_inference.reinfer(node) assert node.type - f_type = ts.FunctionType( + f_type = ts.FunctionType( # type of continuation in order to keep full type info pos_only_args=tuple_type.types, pos_or_kw_args={}, kw_only_args={}, @@ -395,27 +418,25 @@ def transform_propagate_to_if_on_tuples_cps( f_body = _with_altered_arg(node, i, im.make_tuple(*f_args)) # simplify, e.g., inline trivial make_tuple args new_f_body = self.fp_transform(f_body, **kwargs) - # if the function did not simplify there is nothing to gain. Skip - # transformation. + # if the continuation did not simplify there is nothing to gain. Skip + # transformation of this argument. if new_f_body is f_body: continue - # if the function is not trivial the transformation would still work, but - # inlining would result in a larger tree again and we didn't didn't gain - # anything compared to regular `propagate_to_if_on_tuples`. Not inling also - # works, but we don't want bound lambda functions in our tree (at least right - # now). - # TODO(tehrengruber): `if_` of trivial expression is also considered fine. This - # will duplicate the condition and unnecessarily increase the size of the tree. + # if the function is not trivial the transformation we would get a larger tree + # after inlining so we skip transformation this argument. if not _is_trivial_or_tuple_thereof_expr(new_f_body): continue f = im.lambda_(*f_params)(new_f_body) + # this is the symbol refering to the tuple value inside the two branches of the + # if, e.g. a symbol refering to `{1, 2}` and `{3, 4}` respectively tuple_var = self.uids.sequential_id(prefix="__ct_tuple_cps") + # this is the symbol refering to our continuation, e.g. `cont` in our example. 
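# Illustrative sketch (not part of the patch series): the tuple-reordering
# example from the comment above, built with the `im` makers and exercised the
# same way as the existing `test_if_make_tuple_reorder_cps` test; the keyword
# arguments passed to `CollapseTuple.apply` are copied from those tests.
from gt4py.next.iterator.ir_utils import ir_makers as im
from gt4py.next.iterator.transforms.collapse_tuple import CollapseTuple

# `let t = if_(True, {1, 2}, {3, 4}) in {t[1], t[0]}` ...
testee = im.let("t", im.if_(True, im.make_tuple(1, 2), im.make_tuple(3, 4)))(
    im.make_tuple(im.tuple_get(1, "t"), im.tuple_get(0, "t"))
)
# ... should collapse to `if_(True, {2, 1}, {4, 3})` once the trivial
# continuation `λ(el0, el1) → {el1, el0}` has been inlined.
expected = im.if_(True, im.make_tuple(2, 1), im.make_tuple(4, 3))
actual = CollapseTuple.apply(
    testee,
    flags=~CollapseTuple.Flag.PROPAGATE_TO_IF_ON_TUPLES,
    allow_undeclared_symbols=True,
    within_stencil=False,
)
assert actual == expected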
f_var = self.uids.sequential_id(prefix="__ct_cont") new_branches = [] for branch in arg.args[1:]: new_branch = im.let(tuple_var, branch)( - im.call(im.ref(f_var, f_type))( + im.call(im.ref(f_var, f_type))( # call to the continuation *( im.tuple_get(i, im.ref(tuple_var, branch.type)) for i in range(tuple_len) @@ -424,6 +445,7 @@ def transform_propagate_to_if_on_tuples_cps( ) new_branches.append(self.fp_transform(new_branch, **kwargs)) + # assemble everything together new_node = im.let(f_var, f)(im.if_(cond, *new_branches)) new_node = inline_lambda(new_node, eligible_params=[True]) assert cpm.is_call_to(new_node, "if_") From 53d442a7e03197835b77b3cba578552f7830986f Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Fri, 10 Jan 2025 01:36:31 +0100 Subject: [PATCH 138/150] Address review comments --- src/gt4py/next/iterator/transforms/collapse_tuple.py | 11 ++++++++--- src/gt4py/next/iterator/transforms/pass_manager.py | 1 + src/gt4py/next/iterator/type_system/inference.py | 3 ++- .../transforms_tests/test_collapse_tuple.py | 2 +- 4 files changed, 12 insertions(+), 5 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/collapse_tuple.py b/src/gt4py/next/iterator/transforms/collapse_tuple.py index b2f75ad7d6..c76b9e2318 100644 --- a/src/gt4py/next/iterator/transforms/collapse_tuple.py +++ b/src/gt4py/next/iterator/transforms/collapse_tuple.py @@ -50,7 +50,10 @@ def _is_trivial_make_tuple_call(node: ir.Expr): def _is_trivial_or_tuple_thereof_expr(node: ir.Node) -> bool: """ - Return `true` if the expr is a trivial expression or tuple thereof. + Return `true` if the expr is a trivial expression (`SymRef` or `Literal`) or tuple thereof. + + Let forms with trivial body and args as well as if call with trivial branches are also + considered trivial. >>> _is_trivial_or_tuple_thereof_expr(im.make_tuple("a", "b")) True @@ -61,6 +64,8 @@ def _is_trivial_or_tuple_thereof_expr(node: ir.Node) -> bool: ... ) True """ + if isinstance(node, (ir.SymRef, ir.Literal)): + return True if cpm.is_call_to(node, "make_tuple"): return all(_is_trivial_or_tuple_thereof_expr(arg) for arg in node.args) if cpm.is_call_to(node, "tuple_get"): @@ -69,8 +74,6 @@ def _is_trivial_or_tuple_thereof_expr(node: ir.Node) -> bool: # acceptable. if cpm.is_call_to(node, "if_"): return all(_is_trivial_or_tuple_thereof_expr(arg) for arg in node.args[1:]) - if isinstance(node, (ir.SymRef, ir.Literal)): - return True if cpm.is_let(node): return _is_trivial_or_tuple_thereof_expr(node.fun.expr) and all( # type: ignore[attr-defined] # ensured by is_let _is_trivial_or_tuple_thereof_expr(arg) for arg in node.args @@ -391,6 +394,8 @@ def transform_propagate_to_if_on_tuples_cps( if cpm.is_call_to(node, "if_"): return None + # The first argument that is eligible also transforms all remaining args (They will be + # part of the continuation and recursively transformed). 
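# Illustrative sketch (not part of the patch series): with the `if_` case of
# `_is_trivial_or_tuple_thereof_expr` and the amended docstring above, an `if_`
# call whose branches are themselves trivial counts as trivial; the symbol
# names below ("cond", "a", "b", "c", "t") are made up for illustration.
from gt4py.next.iterator.ir_utils import ir_makers as im
from gt4py.next.iterator.transforms.collapse_tuple import _is_trivial_or_tuple_thereof_expr

assert _is_trivial_or_tuple_thereof_expr(im.if_("cond", "a", im.make_tuple("b", "c")))
assert _is_trivial_or_tuple_thereof_expr(
    im.if_("cond", im.make_tuple("a", "b"), im.tuple_get(0, "t"))
)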
for i, arg in enumerate(node.args): if cpm.is_call_to(arg, "if_"): itir_type_inference.reinfer(arg) diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index f3cb0cc468..6906f81e3f 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -136,6 +136,7 @@ def apply_common_transforms( ir, ignore_tuple_size=True, uids=collapse_tuple_uids, + flags=~CollapseTuple.Flag.PROPAGATE_TO_IF_ON_TUPLES, offset_provider_type=offset_provider_type, ) # type: ignore[assignment] # always an itir.Program diff --git a/src/gt4py/next/iterator/type_system/inference.py b/src/gt4py/next/iterator/type_system/inference.py index 9f7a14b0b8..16ab160f2c 100644 --- a/src/gt4py/next/iterator/type_system/inference.py +++ b/src/gt4py/next/iterator/type_system/inference.py @@ -448,7 +448,8 @@ def apply_reinfer(cls, node: T) -> T: Contrary to the regular inference, this method does not descend into already typed sub-nodes and can be used as a lightweight way to restore type information during a pass. - Note that this function is stateful, which is usually desired, and more performant. + Note that this function alters the input node, which is usually desired, and more + performant. Arguments: node: The :class:`itir.Node` to infer the types of. diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py index 5e2c07ef0a..938b998565 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py @@ -9,7 +9,7 @@ from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.iterator.transforms.collapse_tuple import CollapseTuple from gt4py.next.type_system import type_specifications as ts -from tests.next_tests.unit_tests.iterator_tests.test_type_inference import int_type +from next_tests.unit_tests.iterator_tests.test_type_inference import int_type def test_simple_make_tuple_tuple_get(): From 3b1af1eda5e5a019cdb216b0b913d5c9334748c4 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Fri, 10 Jan 2025 01:37:21 +0100 Subject: [PATCH 139/150] Address review comments --- src/gt4py/next/iterator/transforms/collapse_tuple.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gt4py/next/iterator/transforms/collapse_tuple.py b/src/gt4py/next/iterator/transforms/collapse_tuple.py index c76b9e2318..8d52e69ea3 100644 --- a/src/gt4py/next/iterator/transforms/collapse_tuple.py +++ b/src/gt4py/next/iterator/transforms/collapse_tuple.py @@ -52,7 +52,7 @@ def _is_trivial_or_tuple_thereof_expr(node: ir.Node) -> bool: """ Return `true` if the expr is a trivial expression (`SymRef` or `Literal`) or tuple thereof. - Let forms with trivial body and args as well as if call with trivial branches are also + Let forms with trivial body and args as well as `if` calls with trivial branches are also considered trivial. 
>>> _is_trivial_or_tuple_thereof_expr(im.make_tuple("a", "b")) From 96a840ae4870144366260e4b21e9e83460f1e0e7 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Fri, 10 Jan 2025 01:37:45 +0100 Subject: [PATCH 140/150] Address review comments --- src/gt4py/next/iterator/transforms/collapse_tuple.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gt4py/next/iterator/transforms/collapse_tuple.py b/src/gt4py/next/iterator/transforms/collapse_tuple.py index 8d52e69ea3..b5ceb99c1b 100644 --- a/src/gt4py/next/iterator/transforms/collapse_tuple.py +++ b/src/gt4py/next/iterator/transforms/collapse_tuple.py @@ -395,7 +395,7 @@ def transform_propagate_to_if_on_tuples_cps( return None # The first argument that is eligible also transforms all remaining args (They will be - # part of the continuation and recursively transformed). + # part of the continuation which is recursively transformed). for i, arg in enumerate(node.args): if cpm.is_call_to(arg, "if_"): itir_type_inference.reinfer(arg) From fd5fd0a4bc0da5c03d4a66089df10899559270df Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Fri, 10 Jan 2025 16:25:26 +0100 Subject: [PATCH 141/150] Use type reinference --- .../iterator/transforms/fuse_as_fieldop.py | 24 +++++++------------ .../transforms_tests/test_fuse_as_fieldop.py | 1 + 2 files changed, 10 insertions(+), 15 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py index 83cc922dc6..521579a400 100644 --- a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py +++ b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py @@ -54,7 +54,6 @@ def _canonicalize_as_fieldop(expr: itir.FunCall) -> itir.FunCall: if cpm.is_ref_to(stencil, "deref"): stencil = im.lambda_("arg")(im.deref("arg")) new_expr = im.as_fieldop(stencil, domain)(*expr.args) - type_inference.copy_type(from_=expr, to=new_expr, allow_untyped=True) return new_expr @@ -153,9 +152,7 @@ def fuse_as_fieldop( pass elif cpm.is_call_to(arg, "if_"): # TODO(tehrengruber): revisit if we want to inline if_ - type_ = arg.type arg = im.op_as_fieldop("if_")(*arg.args) - arg.type = type_ elif _is_tuple_expr_of_literals(arg): arg = im.op_as_fieldop(im.lambda_()(arg))() else: @@ -167,11 +164,12 @@ def fuse_as_fieldop( new_args = _merge_arguments(new_args, extracted_args) else: - # just a safety check if typing information is available - if arg.type and not isinstance(arg.type, ts.DeferredType): - assert isinstance(arg.type, ts.TypeSpec) - dtype = type_info.apply_to_primitive_constituents(type_info.extract_dtype, arg.type) - assert not isinstance(dtype, it_ts.ListType) + # just a safety check + type_inference.reinfer(arg) + assert isinstance(arg.type, ts.TypeSpec) + dtype = type_info.apply_to_primitive_constituents(type_info.extract_dtype, arg.type) + assert not isinstance(dtype, it_ts.ListType) + new_param: str if isinstance( arg, itir.SymRef @@ -200,7 +198,6 @@ def fuse_as_fieldop( new_node = im.as_fieldop(new_stencil, domain)(*new_args.values()) - type_inference.copy_type(from_=expr, to=new_node, allow_untyped=True) return new_node @@ -219,6 +216,7 @@ def _arg_inline_predicate(node: itir.Expr, shifts): if len(shifts) == 0: return True # applied fieldop with list return type must always be inlined as no backend supports this + type_inference.reinfer(node) assert isinstance(node.type, ts.TypeSpec) dtype = type_info.apply_to_primitive_constituents(type_info.extract_dtype, node.type) if isinstance(dtype, it_ts.ListType): @@ -294,6 +292,7 @@ def 
visit_FunCall(self, node: itir.FunCall, **kwargs): # TODO(tehrengruber): Write test-case. E.g. Adding two sparse fields. Sara observed this # with a cast to a sparse field, but this is likely already covered. if cpm.is_let(node): + type_inference.reinfer(node) eligible_args = [ isinstance(arg.type, ts.FieldType) and isinstance(arg.type.dtype, it_ts.ListType) for arg in node.args @@ -303,9 +302,7 @@ def visit_FunCall(self, node: itir.FunCall, **kwargs): return self.visit(node) if cpm.is_applied_as_fieldop(node): # don't descend in stencil - old_node = node node = im.as_fieldop(*node.fun.args)(*self.generic_visit(node.args)) # type: ignore[attr-defined] # ensured by cpm.is_applied_as_fieldop - type_inference.copy_type(from_=old_node, to=node) elif kwargs.get("recurse", True): node = self.generic_visit(node, **kwargs) @@ -318,7 +315,6 @@ def visit_FunCall(self, node: itir.FunCall, **kwargs): as_fieldop_args_by_domain: dict[itir.Expr, list[tuple[int, itir.Expr]]] = {} for i, arg in enumerate(node.args): if cpm.is_applied_as_fieldop(arg): - assert arg.type _, domain = arg.fun.args # type: ignore[attr-defined] # ensured by cpm.is_applied_as_fieldop as_fieldop_args_by_domain.setdefault(domain, []) as_fieldop_args_by_domain[domain].append((i, arg)) @@ -331,9 +327,7 @@ def visit_FunCall(self, node: itir.FunCall, **kwargs): fused_args = im.op_as_fieldop(lambda *args: im.make_tuple(*args), domain)( *(arg for _, arg in inner_as_fieldop_args) ) - fused_args.type = ts.TupleType( - types=[arg.type for _, arg in inner_as_fieldop_args] # type: ignore[misc] # has type is ensured on list creation - ) + type_inference.reinfer(arg) # don't recurse into nested args, but only consider newly created `as_fieldop` let_vars[var] = self.visit(fused_args, **{**kwargs, "recurse": False}) for outer_tuple_idx, (inner_tuple_idx, _) in enumerate( diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_fuse_as_fieldop.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_fuse_as_fieldop.py index 1b95183422..ab2080c9e0 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_fuse_as_fieldop.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_fuse_as_fieldop.py @@ -169,6 +169,7 @@ def test_make_tuple_fusion_trivial(): ) assert actual_simplified == expected + def test_make_tuple_fusion_symref(): d = im.domain("cartesian_domain", {IDim: (0, 1)}) testee = im.make_tuple( From 447f762770d63163a77f0d632ca6f8cb45ed1926 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Sat, 11 Jan 2025 15:30:27 +0100 Subject: [PATCH 142/150] make_tuple fusion with symref args --- src/gt4py/next/iterator/ir_utils/ir_makers.py | 14 +++++- .../iterator/transforms/collapse_tuple.py | 2 +- .../iterator/transforms/constant_folding.py | 5 +++ .../iterator/transforms/fuse_as_fieldop.py | 45 +++++++++++-------- .../next/iterator/transforms/inline_scalar.py | 2 + .../transforms_tests/test_fuse_as_fieldop.py | 11 ++++- 6 files changed, 56 insertions(+), 23 deletions(-) diff --git a/src/gt4py/next/iterator/ir_utils/ir_makers.py b/src/gt4py/next/iterator/ir_utils/ir_makers.py index 0839e95b5b..f4b5285add 100644 --- a/src/gt4py/next/iterator/ir_utils/ir_makers.py +++ b/src/gt4py/next/iterator/ir_utils/ir_makers.py @@ -445,7 +445,7 @@ def domain( ) -def as_fieldop(expr: itir.Expr | str, domain: Optional[itir.Expr] = None) -> call: +def as_fieldop(expr: itir.Expr | str, domain: Optional[itir.Expr] = None) -> Callable: """ Create an `as_fieldop` call. 
@@ -454,7 +454,9 @@ def as_fieldop(expr: itir.Expr | str, domain: Optional[itir.Expr] = None) -> cal >>> str(as_fieldop(lambda_("it1", "it2")(plus(deref("it1"), deref("it2"))))("field1", "field2")) '(⇑(λ(it1, it2) → ·it1 + ·it2))(field1, field2)' """ - return call( + from gt4py.next.iterator.ir_utils import domain_utils + + result = call( call("as_fieldop")( *( ( @@ -467,6 +469,14 @@ def as_fieldop(expr: itir.Expr | str, domain: Optional[itir.Expr] = None) -> cal ) ) + def _populate_domain_annex_wrapper(*args, **kwargs): + node = result(*args, **kwargs) + if domain: + node.annex.domain = domain_utils.SymbolicDomain.from_expr(domain) + return node + + return _populate_domain_annex_wrapper + def op_as_fieldop( op: str | itir.SymRef | Callable, domain: Optional[itir.FunCall] = None diff --git a/src/gt4py/next/iterator/transforms/collapse_tuple.py b/src/gt4py/next/iterator/transforms/collapse_tuple.py index 6ee527bb18..941f3c9662 100644 --- a/src/gt4py/next/iterator/transforms/collapse_tuple.py +++ b/src/gt4py/next/iterator/transforms/collapse_tuple.py @@ -137,7 +137,7 @@ def all(self) -> CollapseTuple.Flag: ignore_tuple_size: bool flags: Flag = Flag.all() # noqa: RUF009 [function-call-in-dataclass-default-argument] - PRESERVED_ANNEX_ATTRS = ("type",) + PRESERVED_ANNEX_ATTRS = ("type", "domain") @classmethod def apply( diff --git a/src/gt4py/next/iterator/transforms/constant_folding.py b/src/gt4py/next/iterator/transforms/constant_folding.py index 2084ab2518..8802d0dd84 100644 --- a/src/gt4py/next/iterator/transforms/constant_folding.py +++ b/src/gt4py/next/iterator/transforms/constant_folding.py @@ -12,6 +12,11 @@ class ConstantFolding(PreserveLocationVisitor, NodeTranslator): + PRESERVED_ANNEX_ATTRS = ( + "type", + "domain", + ) + @classmethod def apply(cls, node: ir.Node) -> ir.Node: return cls().visit(node) diff --git a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py index 521579a400..2528a7828a 100644 --- a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py +++ b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py @@ -13,7 +13,11 @@ from gt4py.eve import utils as eve_utils from gt4py.next import common from gt4py.next.iterator import ir as itir -from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im +from gt4py.next.iterator.ir_utils import ( + common_pattern_matcher as cpm, + domain_utils, + ir_makers as im, +) from gt4py.next.iterator.transforms import ( inline_center_deref_lift_vars, inline_lambdas, @@ -262,6 +266,8 @@ class FuseAsFieldOp(eve.NodeTranslator): as_fieldop(λ(inp1, inp2, inp3) → ·inp1 × ·inp2 + ·inp3, c⟨ IDimₕ: [0, 1[ ⟩)(inp1, inp2, inp3) """ # noqa: RUF002 # ignore ambiguous multiplication character + PRESERVED_ANNEX_ATTRS = ("domain",) + uids: eve_utils.UIDGenerator @classmethod @@ -292,7 +298,8 @@ def visit_FunCall(self, node: itir.FunCall, **kwargs): # TODO(tehrengruber): Write test-case. E.g. Adding two sparse fields. Sara observed this # with a cast to a sparse field, but this is likely already covered. 
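# Illustrative sketch (not part of the patch series): with the
# `_populate_domain_annex_wrapper` added to the `as_fieldop` maker above,
# applied `as_fieldop` expressions built through the makers carry their domain
# on the annex, which `FuseAsFieldOp` uses to group `make_tuple` arguments by
# domain. The dimension and argument names below are illustrative.
from gt4py import next as gtx
from gt4py.next.iterator.ir_utils import domain_utils, ir_makers as im

IDim = gtx.Dimension("IDim")
d = im.domain("cartesian_domain", {IDim: (0, 1)})
node = im.as_fieldop("deref", d)("inp")
assert isinstance(node.annex.domain, domain_utils.SymbolicDomain)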
if cpm.is_let(node): - type_inference.reinfer(node) + for arg in node.args: + type_inference.reinfer(arg) eligible_args = [ isinstance(arg.type, ts.FieldType) and isinstance(arg.type.dtype, it_ts.ListType) for arg in node.args @@ -307,35 +314,37 @@ def visit_FunCall(self, node: itir.FunCall, **kwargs): node = self.generic_visit(node, **kwargs) if cpm.is_call_to(node, "make_tuple"): - # TODO(tehrengruber): x, y = alpha * y, x is not fused - as_fieldop_args = [arg for arg in node.args if cpm.is_applied_as_fieldop(arg)] - distinct_domains = set(arg.fun.args[1] for arg in as_fieldop_args) # type: ignore[attr-defined] # ensured by cpm.is_applied_as_fieldop - if len(distinct_domains) != len(as_fieldop_args): + for arg in node.args: + type_inference.reinfer(arg) + assert hasattr(arg.annex, "domain") and isinstance( + arg.annex.domain, domain_utils.SymbolicDomain + ) + field_args = [arg for arg in node.args if isinstance(arg.type, ts.FieldType)] + distinct_domains = set(arg.annex.domain.as_expr() for arg in field_args) + if len(distinct_domains) != len(field_args): new_els: list[itir.Expr | None] = [None for _ in node.args] - as_fieldop_args_by_domain: dict[itir.Expr, list[tuple[int, itir.Expr]]] = {} + field_args_by_domain: dict[itir.Expr, list[tuple[int, itir.Expr]]] = {} for i, arg in enumerate(node.args): - if cpm.is_applied_as_fieldop(arg): - _, domain = arg.fun.args # type: ignore[attr-defined] # ensured by cpm.is_applied_as_fieldop - as_fieldop_args_by_domain.setdefault(domain, []) - as_fieldop_args_by_domain[domain].append((i, arg)) + if isinstance(arg.type, ts.FieldType): + domain = arg.annex.domain.as_expr() + field_args_by_domain.setdefault(domain, []) + field_args_by_domain[domain].append((i, arg)) else: new_els[i] = arg # keep as is let_vars = {} - for domain, inner_as_fieldop_args in as_fieldop_args_by_domain.items(): - if len(inner_as_fieldop_args) > 1: + for domain, inner_field_args in field_args_by_domain.items(): + if len(inner_field_args) > 1: var = self.uids.sequential_id(prefix="__fasfop") fused_args = im.op_as_fieldop(lambda *args: im.make_tuple(*args), domain)( - *(arg for _, arg in inner_as_fieldop_args) + *(arg for _, arg in inner_field_args) ) type_inference.reinfer(arg) # don't recurse into nested args, but only consider newly created `as_fieldop` let_vars[var] = self.visit(fused_args, **{**kwargs, "recurse": False}) - for outer_tuple_idx, (inner_tuple_idx, _) in enumerate( - inner_as_fieldop_args - ): + for outer_tuple_idx, (inner_tuple_idx, _) in enumerate(inner_field_args): new_els[inner_tuple_idx] = im.tuple_get(outer_tuple_idx, var) else: - i, arg = inner_as_fieldop_args[0] + i, arg = inner_field_args[0] new_els[i] = arg assert not any(el is None for el in new_els) assert let_vars diff --git a/src/gt4py/next/iterator/transforms/inline_scalar.py b/src/gt4py/next/iterator/transforms/inline_scalar.py index 87b576d14d..dd6470630e 100644 --- a/src/gt4py/next/iterator/transforms/inline_scalar.py +++ b/src/gt4py/next/iterator/transforms/inline_scalar.py @@ -16,6 +16,8 @@ class InlineScalar(eve.NodeTranslator): + PRESERVED_ANNEX_ATTRS = ("domain",) + @classmethod def apply(cls, program: itir.Program, offset_provider_type: common.OffsetProviderType): program = itir_inference.infer(program, offset_provider_type=offset_provider_type) diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_fuse_as_fieldop.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_fuse_as_fieldop.py index ab2080c9e0..1272559456 100644 --- 
a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_fuse_as_fieldop.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_fuse_as_fieldop.py @@ -5,6 +5,7 @@ # # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause +import copy from typing import Callable, Optional from gt4py import next as gtx @@ -19,6 +20,12 @@ field_type = ts.FieldType(dims=[IDim], dtype=ts.ScalarType(kind=ts.ScalarKind.INT32)) +def _with_domain_annex(node: itir.Expr, domain: itir.Expr): + node = copy.deepcopy(node) + node.annex.domain = domain + return node + + def test_trivial(): d = im.domain("cartesian_domain", {IDim: (0, 1)}) testee = im.op_as_fieldop("plus", d)( @@ -174,7 +181,7 @@ def test_make_tuple_fusion_symref(): d = im.domain("cartesian_domain", {IDim: (0, 1)}) testee = im.make_tuple( im.as_fieldop("deref", d)(im.ref("a", field_type)), - im.ref("b", field_type), + _with_domain_annex(im.ref("b", field_type), d), ) expected = im.as_fieldop( im.lambda_("a", "b")(im.make_tuple(im.deref("a"), im.deref("b"))), @@ -194,7 +201,7 @@ def test_make_tuple_fusion_symref2(): d = im.domain("cartesian_domain", {IDim: (0, 1)}) testee = im.make_tuple( im.as_fieldop("deref", d)(im.ref("a", field_type)), - im.ref("a", field_type), + _with_domain_annex(im.ref("a", field_type), d), ) expected = im.as_fieldop( im.lambda_("a")(im.make_tuple(im.deref("a"), im.deref("a"))), From 9263f4d3f56e108529355f2fc1d6b841409efe9f Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Tue, 14 Jan 2025 11:44:52 +0100 Subject: [PATCH 143/150] Address review comments --- src/gt4py/next/iterator/transforms/collapse_tuple.py | 8 ++++---- src/gt4py/next/iterator/type_system/inference.py | 1 + 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/collapse_tuple.py b/src/gt4py/next/iterator/transforms/collapse_tuple.py index b5ceb99c1b..44f3d99694 100644 --- a/src/gt4py/next/iterator/transforms/collapse_tuple.py +++ b/src/gt4py/next/iterator/transforms/collapse_tuple.py @@ -399,10 +399,10 @@ def transform_propagate_to_if_on_tuples_cps( for i, arg in enumerate(node.args): if cpm.is_call_to(arg, "if_"): itir_type_inference.reinfer(arg) - if not any(isinstance(branch.type, ts.TupleType) for branch in arg.args[1:]): - continue cond, true_branch, false_branch = arg.args # e.g. `True`, `{1, 2}`, `{3, 4}` + if not any(isinstance(branch.type, ts.TupleType) for branch in [true_branch, false_branch]): + continue tuple_type: ts.TupleType = true_branch.type # type: ignore[assignment] # type ensured above tuple_len = len(tuple_type.types) @@ -427,7 +427,7 @@ def transform_propagate_to_if_on_tuples_cps( # transformation of this argument. if new_f_body is f_body: continue - # if the function is not trivial the transformation we would get a larger tree + # if the function is not trivial the transformation we would create a larger tree # after inlining so we skip transformation this argument. if not _is_trivial_or_tuple_thereof_expr(new_f_body): continue @@ -439,7 +439,7 @@ def transform_propagate_to_if_on_tuples_cps( # this is the symbol refering to our continuation, e.g. `cont` in our example. 
f_var = self.uids.sequential_id(prefix="__ct_cont") new_branches = [] - for branch in arg.args[1:]: + for branch in [true_branch, false_branch]: new_branch = im.let(tuple_var, branch)( im.call(im.ref(f_var, f_type))( # call to the continuation *( diff --git a/src/gt4py/next/iterator/type_system/inference.py b/src/gt4py/next/iterator/type_system/inference.py index 16ab160f2c..0f1fac64f8 100644 --- a/src/gt4py/next/iterator/type_system/inference.py +++ b/src/gt4py/next/iterator/type_system/inference.py @@ -560,6 +560,7 @@ def visit_AxisLiteral(self, node: itir.AxisLiteral, **kwargs) -> ts.DimensionTyp def visit_OffsetLiteral( self, node: itir.OffsetLiteral, **kwargs ) -> it_ts.OffsetLiteralType | ts.DeferredType: + # `self.dimensions` not available in re-inference mode. Skip since we don't care anyway. if self.reinfer: return ts.DeferredType(constraint=it_ts.OffsetLiteralType) From 72ca50d0541418b7efcfd1f822d212d28e1c67d2 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Tue, 14 Jan 2025 11:46:21 +0100 Subject: [PATCH 144/150] Fix format --- src/gt4py/next/iterator/transforms/collapse_tuple.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/gt4py/next/iterator/transforms/collapse_tuple.py b/src/gt4py/next/iterator/transforms/collapse_tuple.py index 44f3d99694..0a0cf6d37e 100644 --- a/src/gt4py/next/iterator/transforms/collapse_tuple.py +++ b/src/gt4py/next/iterator/transforms/collapse_tuple.py @@ -401,7 +401,9 @@ def transform_propagate_to_if_on_tuples_cps( itir_type_inference.reinfer(arg) cond, true_branch, false_branch = arg.args # e.g. `True`, `{1, 2}`, `{3, 4}` - if not any(isinstance(branch.type, ts.TupleType) for branch in [true_branch, false_branch]): + if not any( + isinstance(branch.type, ts.TupleType) for branch in [true_branch, false_branch] + ): continue tuple_type: ts.TupleType = true_branch.type # type: ignore[assignment] # type ensured above tuple_len = len(tuple_type.types) From ac170b696b007e4640f717afb8a02d95b5b1e46f Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Fri, 10 Jan 2025 16:25:26 +0100 Subject: [PATCH 145/150] Add support for nested tuples and symrefs --- .../iterator/transforms/fuse_as_fieldop.py | 123 ++++++++++++++---- .../inline_center_deref_lift_vars.py | 2 +- .../next/iterator/transforms/inline_lifts.py | 4 +- .../next/iterator/transforms/merge_let.py | 3 + .../transforms_tests/test_fuse_as_fieldop.py | 31 ++++- 5 files changed, 131 insertions(+), 32 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py index 2528a7828a..38e362ef93 100644 --- a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py +++ b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py @@ -113,6 +113,27 @@ def _inline_as_fieldop_arg( def _unwrap_scan(stencil: itir.Lambda | itir.FunCall): + """ + If given a scan, extract stencil part of its scan pass and a back-transformation into a scan. + + If a regular stencil is given the stencil is left as-is and the back-transformation is the + identity function. This function allows treating a scan stencil like a regular stencil during + a transformation avoiding the complexity introduced by the different IR format. + + >>> scan = im.call("scan")( + ... im.lambda_("state", "arg")(im.plus("state", im.deref("arg"))), True, 0.0 + ... 
) + >>> stencil, back_trafo = _unwrap_scan(scan) + >>> str(stencil) + 'λ(arg) → state + ·arg' + >>> assert back_trafo(stencil) == scan + + In case a regular stencil is given it is returned as-is: + + >>> deref_stencil = im.lambda_("it")(im.deref("it")) + >>> stencil, back_trafo = _unwrap_scan(deref_stencil) + >>> assert stencil == deref_stencil + """ if cpm.is_call_to(stencil, "scan"): scan_pass, direction, init = stencil.args assert isinstance(scan_pass, itir.Lambda) @@ -238,6 +259,14 @@ def _arg_inline_predicate(node: itir.Expr, shifts): return False +def _make_tuple_element_inline_predicate(node: itir.Expr): + if cpm.is_applied_as_fieldop(node): # field, or tuple of fields + return True + if isinstance(node.type, ts.FieldType) and isinstance(node, itir.SymRef): + return True + return False + + @dataclasses.dataclass class FuseAsFieldOp(eve.NodeTranslator): """ @@ -278,6 +307,7 @@ def apply( offset_provider_type: common.OffsetProviderType, uids: Optional[eve_utils.UIDGenerator] = None, allow_undeclared_symbols=False, + within_set_at_expr: Optional[bool] = None, ): node = type_inference.infer( node, @@ -285,10 +315,28 @@ def apply( allow_undeclared_symbols=allow_undeclared_symbols, ) + if within_set_at_expr is None: + within_set_at_expr = not isinstance(node, itir.Program) + if not uids: uids = eve_utils.UIDGenerator() - return cls(uids=uids).visit(node) + return cls(uids=uids).visit(node, within_set_at_expr=within_set_at_expr) + + def visit(self, node, **kwargs): + if not kwargs.get("within_set_at_expr"): + return node + new_node = super().visit(node, **kwargs) + if isinstance(node, itir.Expr) and hasattr(node.annex, "domain"): + new_node.annex.domain = node.annex.domain + return new_node + + def visit_SetAt(self, node: itir.SetAt, **kwargs): + return itir.SetAt( + expr=self.visit(node.expr, **kwargs | {"within_set_at_expr": True}), + domain=node.domain, + target=node.target, + ) def visit_FunCall(self, node: itir.FunCall, **kwargs): # inline all fields with list dtype. 
This needs to happen before the children are visited @@ -306,50 +354,71 @@ def visit_FunCall(self, node: itir.FunCall, **kwargs): ] if any(eligible_args): node = inline_lambdas.inline_lambda(node, eligible_params=eligible_args) - return self.visit(node) + return self.visit(node, **kwargs) if cpm.is_applied_as_fieldop(node): # don't descend in stencil - node = im.as_fieldop(*node.fun.args)(*self.generic_visit(node.args)) # type: ignore[attr-defined] # ensured by cpm.is_applied_as_fieldop + node = im.as_fieldop(*node.fun.args)(*self.generic_visit(node.args, **kwargs)) # type: ignore[attr-defined] # ensured by cpm.is_applied_as_fieldop elif kwargs.get("recurse", True): node = self.generic_visit(node, **kwargs) if cpm.is_call_to(node, "make_tuple"): for arg in node.args: type_inference.reinfer(arg) - assert hasattr(arg.annex, "domain") and isinstance( - arg.annex.domain, domain_utils.SymbolicDomain + assert not isinstance(arg.type, ts.FieldType) or ( + hasattr(arg.annex, "domain") + and isinstance(arg.annex.domain, domain_utils.SymbolicDomain) ) - field_args = [arg for arg in node.args if isinstance(arg.type, ts.FieldType)] + + eligible_args = [_make_tuple_element_inline_predicate(arg) for arg in node.args] + field_args = [arg for i, arg in enumerate(node.args) if eligible_args[i]] distinct_domains = set(arg.annex.domain.as_expr() for arg in field_args) if len(distinct_domains) != len(field_args): new_els: list[itir.Expr | None] = [None for _ in node.args] - field_args_by_domain: dict[itir.Expr, list[tuple[int, itir.Expr]]] = {} + field_args_by_domain: dict[itir.FunCall, list[tuple[int, itir.Expr]]] = {} for i, arg in enumerate(node.args): - if isinstance(arg.type, ts.FieldType): + if eligible_args[i]: + assert isinstance(arg.annex.domain, domain_utils.SymbolicDomain) domain = arg.annex.domain.as_expr() field_args_by_domain.setdefault(domain, []) field_args_by_domain[domain].append((i, arg)) else: new_els[i] = arg # keep as is - let_vars = {} - for domain, inner_field_args in field_args_by_domain.items(): - if len(inner_field_args) > 1: - var = self.uids.sequential_id(prefix="__fasfop") - fused_args = im.op_as_fieldop(lambda *args: im.make_tuple(*args), domain)( - *(arg for _, arg in inner_field_args) - ) - type_inference.reinfer(arg) - # don't recurse into nested args, but only consider newly created `as_fieldop` - let_vars[var] = self.visit(fused_args, **{**kwargs, "recurse": False}) - for outer_tuple_idx, (inner_tuple_idx, _) in enumerate(inner_field_args): - new_els[inner_tuple_idx] = im.tuple_get(outer_tuple_idx, var) - else: - i, arg = inner_field_args[0] - new_els[i] = arg - assert not any(el is None for el in new_els) - assert let_vars - new_node = im.let(*let_vars.items())(im.make_tuple(*new_els)) - new_node = inline_lambdas.inline_lambda(new_node, opcount_preserving=True) + + if len(field_args_by_domain) == 1 and len( + next(iter(field_args_by_domain.values())) + ) == len(node.args): + # if we only have a single domain covering all args we don't need to create an + # unnecessary let + ((domain, inner_field_args),) = field_args_by_domain.items() + new_node = im.op_as_fieldop(lambda *args: im.make_tuple(*args), domain)( + *(arg for _, arg in inner_field_args) + ) + new_node = self.visit(new_node, **{**kwargs, "recurse": False}) + else: + let_vars = {} + for domain, inner_field_args in field_args_by_domain.items(): + if len(inner_field_args) > 1: + var = self.uids.sequential_id(prefix="__fasfop") + fused_args = im.op_as_fieldop( + lambda *args: im.make_tuple(*args), domain + )(*(arg for 
_, arg in inner_field_args)) + type_inference.reinfer(arg) + # don't recurse into nested args, but only consider newly created `as_fieldop` + # note: this will always inline as long as we inline center accessed + let_vars[var] = self.visit(fused_args, **{**kwargs, "recurse": False}) + for outer_tuple_idx, (inner_tuple_idx, _) in enumerate( + inner_field_args + ): + new_el = im.tuple_get(outer_tuple_idx, var) + new_el.annex.domain = domain_utils.SymbolicDomain.from_expr(domain) + new_els[inner_tuple_idx] = new_el + else: + i, arg = inner_field_args[0] + new_els[i] = arg + assert not any(el is None for el in new_els) + assert let_vars + new_node = im.let(*let_vars.items())(im.make_tuple(*new_els)) + new_node = inline_lambdas.inline_lambda(new_node, opcount_preserving=True) return new_node if cpm.is_call_to(node.fun, "as_fieldop"): diff --git a/src/gt4py/next/iterator/transforms/inline_center_deref_lift_vars.py b/src/gt4py/next/iterator/transforms/inline_center_deref_lift_vars.py index 9169c26769..7bd26d0f19 100644 --- a/src/gt4py/next/iterator/transforms/inline_center_deref_lift_vars.py +++ b/src/gt4py/next/iterator/transforms/inline_center_deref_lift_vars.py @@ -48,7 +48,7 @@ class InlineCenterDerefLiftVars(eve.NodeTranslator): Note: This pass uses and preserves the `recorded_shifts` annex. """ - PRESERVED_ANNEX_ATTRS: ClassVar[tuple[str, ...]] = ("recorded_shifts",) + PRESERVED_ANNEX_ATTRS: ClassVar[tuple[str, ...]] = ("domain", "recorded_shifts") uids: eve_utils.UIDGenerator diff --git a/src/gt4py/next/iterator/transforms/inline_lifts.py b/src/gt4py/next/iterator/transforms/inline_lifts.py index f27dbbb74c..07d116555d 100644 --- a/src/gt4py/next/iterator/transforms/inline_lifts.py +++ b/src/gt4py/next/iterator/transforms/inline_lifts.py @@ -8,7 +8,7 @@ import dataclasses import enum -from typing import Callable, Optional +from typing import Callable, ClassVar, Optional import gt4py.eve as eve from gt4py.eve import NodeTranslator, traits @@ -112,6 +112,8 @@ class InlineLifts( function nodes. """ + PRESERVED_ANNEX_ATTRS: ClassVar[tuple[str, ...]] = ("domain",) + class Flag(enum.IntEnum): #: `shift(...)(lift(f)(args...))` -> `lift(f)(shift(...)(args)...)` PROPAGATE_SHIFT = 1 diff --git a/src/gt4py/next/iterator/transforms/merge_let.py b/src/gt4py/next/iterator/transforms/merge_let.py index 0e7d74e594..9c0c25bd49 100644 --- a/src/gt4py/next/iterator/transforms/merge_let.py +++ b/src/gt4py/next/iterator/transforms/merge_let.py @@ -5,6 +5,7 @@ # # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause +from typing import ClassVar import gt4py.eve as eve from gt4py.next.iterator import ir as itir @@ -26,6 +27,8 @@ class MergeLet(eve.PreserveLocationVisitor, eve.NodeTranslator): This can significantly reduce the depth of the tree and its readability. 
""" + PRESERVED_ANNEX_ATTRS: ClassVar[tuple[str, ...]] = ("domain",) + def visit_FunCall(self, node: itir.FunCall): node = self.generic_visit(node) if ( diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_fuse_as_fieldop.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_fuse_as_fieldop.py index 1272559456..cc4e7529b2 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_fuse_as_fieldop.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_fuse_as_fieldop.py @@ -10,7 +10,7 @@ from gt4py import next as gtx from gt4py.next.iterator import ir as itir -from gt4py.next.iterator.ir_utils import ir_makers as im +from gt4py.next.iterator.ir_utils import ir_makers as im, domain_utils from gt4py.next.iterator.transforms import fuse_as_fieldop, collapse_tuple from gt4py.next.type_system import type_specifications as ts @@ -22,7 +22,7 @@ def _with_domain_annex(node: itir.Expr, domain: itir.Expr): node = copy.deepcopy(node) - node.annex.domain = domain + node.annex.domain = domain_utils.SymbolicDomain.from_expr(domain) return node @@ -197,7 +197,7 @@ def test_make_tuple_fusion_symref(): assert actual_simplified == expected -def test_make_tuple_fusion_symref2(): +def test_make_tuple_fusion_symref_same_ref(): d = im.domain("cartesian_domain", {IDim: (0, 1)}) testee = im.make_tuple( im.as_fieldop("deref", d)(im.ref("a", field_type)), @@ -217,6 +217,31 @@ def test_make_tuple_fusion_symref2(): assert actual_simplified == expected +def test_make_tuple_nested(): + d = im.domain("cartesian_domain", {IDim: (0, 1)}) + testee = im.make_tuple( + _with_domain_annex(im.ref("a", field_type), d), + im.make_tuple( + _with_domain_annex(im.ref("b", field_type), d), + _with_domain_annex(im.ref("c", field_type), d), + ), + ) + expected = im.as_fieldop( + im.lambda_("a", "b", "c")( + im.make_tuple(im.deref("a"), im.make_tuple(im.deref("b"), im.deref("c"))) + ), + d, + )(im.ref("a", field_type), im.ref("b", field_type), im.ref("c", field_type)) + actual = fuse_as_fieldop.FuseAsFieldOp.apply( + testee, offset_provider_type={}, allow_undeclared_symbols=True + ) + # simplify to remove unnecessary make_tuple call + actual_simplified = collapse_tuple.CollapseTuple.apply( + actual, within_stencil=False, allow_undeclared_symbols=True + ) + assert actual_simplified == expected + + def test_make_tuple_fusion_different_domains(): d1 = im.domain("cartesian_domain", {IDim: (0, 1)}) d2 = im.domain("cartesian_domain", {JDim: (0, 1)}) From 8d0a6fa01411c35d29544380c7f51803a056b99f Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Wed, 15 Jan 2025 03:27:48 +0100 Subject: [PATCH 146/150] Small fix --- src/gt4py/next/iterator/transforms/fuse_as_fieldop.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py index 495e7d0756..f5e7320e87 100644 --- a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py +++ b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py @@ -321,8 +321,6 @@ def apply( return cls(uids=uids).visit(node, within_set_at_expr=within_set_at_expr) def visit(self, node, **kwargs): - if not kwargs.get("within_set_at_expr"): - return node new_node = super().visit(node, **kwargs) if isinstance(node, itir.Expr) and hasattr(node.annex, "domain"): new_node.annex.domain = node.annex.domain @@ -336,6 +334,9 @@ def visit_SetAt(self, node: itir.SetAt, **kwargs): ) def visit_FunCall(self, node: itir.FunCall, **kwargs): + if 
not kwargs.get("within_set_at_expr"): + return node + # inline all fields with list dtype. This needs to happen before the children are visited # such that the `as_fieldop` can be fused. # TODO(tehrengruber): what should we do in case the field with list dtype is a let itself? From 2b270fef3afc6249504fbdacd6ce73229d3685f7 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Thu, 16 Jan 2025 10:47:48 +0100 Subject: [PATCH 147/150] Small cleanup --- .../iterator/transforms/fuse_as_fieldop.py | 24 +++++++++---------- 1 file changed, 11 insertions(+), 13 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py index f5e7320e87..c712e4b47f 100644 --- a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py +++ b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py @@ -346,12 +346,12 @@ def visit_FunCall(self, node: itir.FunCall, **kwargs): if cpm.is_let(node): for arg in node.args: type_inference.reinfer(arg) - eligible_args = [ + eligible_els = [ isinstance(arg.type, ts.FieldType) and isinstance(arg.type.dtype, ts.ListType) for arg in node.args ] - if any(eligible_args): - node = inline_lambdas.inline_lambda(node, eligible_params=eligible_args) + if any(eligible_els): + node = inline_lambdas.inline_lambda(node, eligible_params=eligible_els) return self.visit(node, **kwargs) if cpm.is_applied_as_fieldop(node): # don't descend in stencil @@ -367,14 +367,14 @@ def visit_FunCall(self, node: itir.FunCall, **kwargs): and isinstance(arg.annex.domain, domain_utils.SymbolicDomain) ) - eligible_args = [_make_tuple_element_inline_predicate(arg) for arg in node.args] - field_args = [arg for i, arg in enumerate(node.args) if eligible_args[i]] + eligible_els = [_make_tuple_element_inline_predicate(arg) for arg in node.args] + field_args = [arg for i, arg in enumerate(node.args) if eligible_els[i]] distinct_domains = set(arg.annex.domain.as_expr() for arg in field_args) if len(distinct_domains) != len(field_args): new_els: list[itir.Expr | None] = [None for _ in node.args] field_args_by_domain: dict[itir.FunCall, list[tuple[int, itir.Expr]]] = {} for i, arg in enumerate(node.args): - if eligible_args[i]: + if eligible_els[i]: assert isinstance(arg.annex.domain, domain_utils.SymbolicDomain) domain = arg.annex.domain.as_expr() field_args_by_domain.setdefault(domain, []) @@ -382,9 +382,7 @@ def visit_FunCall(self, node: itir.FunCall, **kwargs): else: new_els[i] = arg # keep as is - if len(field_args_by_domain) == 1 and len( - next(iter(field_args_by_domain.values())) - ) == len(node.args): + if len(field_args_by_domain) == 1 and all(eligible_els): # if we only have a single domain covering all args we don't need to create an # unnecessary let ((domain, inner_field_args),) = field_args_by_domain.items() @@ -402,7 +400,7 @@ def visit_FunCall(self, node: itir.FunCall, **kwargs): )(*(arg for _, arg in inner_field_args)) type_inference.reinfer(arg) # don't recurse into nested args, but only consider newly created `as_fieldop` - # note: this will always inline as long as we inline center accessed + # note: this will always inline (as we inline center accessed) let_vars[var] = self.visit(fused_args, **{**kwargs, "recurse": False}) for outer_tuple_idx, (inner_tuple_idx, _) in enumerate( inner_field_args @@ -436,13 +434,13 @@ def visit_FunCall(self, node: itir.FunCall, **kwargs): args: list[itir.Expr] = node.args shifts = trace_shifts.trace_stencil(stencil, num_args=len(args)) - eligible_args = [ + eligible_els = [ _arg_inline_predicate(arg, 
arg_shifts) for arg, arg_shifts in zip(args, shifts, strict=True) ] - if any(eligible_args): + if any(eligible_els): return self.visit( - fuse_as_fieldop(node, eligible_args, uids=self.uids), + fuse_as_fieldop(node, eligible_els, uids=self.uids), **{**kwargs, "recurse": False}, ) return node From 4821a44fa7cc26b9962ccf738df77047f0fc4fe6 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Thu, 16 Jan 2025 12:16:07 +0100 Subject: [PATCH 148/150] Small cleanup & additional test cases --- .../iterator/transforms/fuse_as_fieldop.py | 4 +-- .../transforms_tests/test_fuse_as_fieldop.py | 26 ++++++++++++++++++- 2 files changed, 26 insertions(+), 4 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py index c712e4b47f..f05558499b 100644 --- a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py +++ b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py @@ -226,7 +226,7 @@ def fuse_as_fieldop( def _arg_inline_predicate(node: itir.Expr, shifts): if _is_tuple_expr_of_literals(node): return True - # TODO(tehrengruber): write test case ensuring scan is not tried to be inlined (e.g. test_call_scan_operator_from_field_operator) + if ( is_applied_fieldop := cpm.is_applied_as_fieldop(node) and not cpm.is_call_to(node.fun.args[0], "scan") # type: ignore[attr-defined] # ensured by cpm.is_applied_as_fieldop @@ -341,8 +341,6 @@ def visit_FunCall(self, node: itir.FunCall, **kwargs): # such that the `as_fieldop` can be fused. # TODO(tehrengruber): what should we do in case the field with list dtype is a let itself? # This could duplicate other expressions which we did not intend to duplicate. - # TODO(tehrengruber): Write test-case. E.g. Adding two sparse fields. Sara observed this - # with a cast to a sparse field, but this is likely already covered. 
if cpm.is_let(node): for arg in node.args: type_inference.reinfer(arg) diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_fuse_as_fieldop.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_fuse_as_fieldop.py index cc4e7529b2..7d1911cc03 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_fuse_as_fieldop.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_fuse_as_fieldop.py @@ -339,13 +339,37 @@ def test_chained_fusion(): ) assert actual == expected +def test_inline_as_fieldop_with_list_dtype(): + list_field_type = ts.FieldType(dims=[IDim], dtype=ts.ListType(element_type=ts.ScalarType(kind=ts.ScalarKind.INT32))) + d = im.domain("cartesian_domain", {IDim: (0, 1)}) + testee = im.as_fieldop(im.lambda_("inp")(im.call("reduce")(im.deref("inp"), 0)), d)( + im.as_fieldop("deref")(im.ref("inp", list_field_type)) + ) + expected = im.as_fieldop(im.lambda_("inp")(im.call("reduce")(im.deref("inp"), 0)), d)( + im.ref("inp", list_field_type) + ) + actual = fuse_as_fieldop.FuseAsFieldOp.apply( + testee, offset_provider_type={}, allow_undeclared_symbols=True + ) + assert actual == expected + def test_inline_into_scan(): d = im.domain("cartesian_domain", {IDim: (0, 1)}) - scan = im.call("scan")(im.lambda_("state", "a")(im.plus("state", im.deref("a"))), True, 0.0) + scan = im.call("scan")(im.lambda_("state", "a")(im.plus("state", im.deref("a"))), True, 0) testee = im.as_fieldop(scan, d)(im.as_fieldop("deref")(im.ref("a", field_type))) expected = im.as_fieldop(scan, d)(im.ref("a", field_type)) actual = fuse_as_fieldop.FuseAsFieldOp.apply( testee, offset_provider_type={}, allow_undeclared_symbols=True ) assert actual == expected + +def test_no_inline_into_scan(): + d = im.domain("cartesian_domain", {IDim: (0, 1)}) + scan_stencil = im.call("scan")(im.lambda_("state", "a")(im.plus("state", im.deref("a"))), True, 0) + scan = im.as_fieldop(scan_stencil, d)(im.ref("a", field_type)) + testee = im.as_fieldop(im.lambda_("arg")(im.deref("arg")), d)(scan) + actual = fuse_as_fieldop.FuseAsFieldOp.apply( + testee, offset_provider_type={}, allow_undeclared_symbols=True + ) + assert actual == testee From c954291f90786fcea2908c540bbb8e425269c684 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Thu, 16 Jan 2025 12:19:50 +0100 Subject: [PATCH 149/150] Fix doctest --- src/gt4py/next/iterator/transforms/fuse_as_fieldop.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py index f05558499b..0ca1c57642 100644 --- a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py +++ b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py @@ -123,7 +123,8 @@ def _unwrap_scan(stencil: itir.Lambda | itir.FunCall): >>> stencil, back_trafo = _unwrap_scan(scan) >>> str(stencil) 'λ(arg) → state + ·arg' - >>> assert back_trafo(stencil) == scan + >>> str(back_trafo(stencil)) + 'scan(λ(state, arg) → (λ(arg) → state + ·arg)(arg), True, 0.0)' In case a regular stencil is given it is returned as-is: From 7ce9edf8edbf400acbd97a01650250e3d2b414cb Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Thu, 16 Jan 2025 12:28:17 +0100 Subject: [PATCH 150/150] Fix format --- .../transforms_tests/test_fuse_as_fieldop.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_fuse_as_fieldop.py 
b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_fuse_as_fieldop.py index 7d1911cc03..dd8b931960 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_fuse_as_fieldop.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_fuse_as_fieldop.py @@ -339,8 +339,11 @@ def test_chained_fusion(): ) assert actual == expected + def test_inline_as_fieldop_with_list_dtype(): - list_field_type = ts.FieldType(dims=[IDim], dtype=ts.ListType(element_type=ts.ScalarType(kind=ts.ScalarKind.INT32))) + list_field_type = ts.FieldType( + dims=[IDim], dtype=ts.ListType(element_type=ts.ScalarType(kind=ts.ScalarKind.INT32)) + ) d = im.domain("cartesian_domain", {IDim: (0, 1)}) testee = im.as_fieldop(im.lambda_("inp")(im.call("reduce")(im.deref("inp"), 0)), d)( im.as_fieldop("deref")(im.ref("inp", list_field_type)) @@ -364,9 +367,12 @@ def test_inline_into_scan(): ) assert actual == expected + def test_no_inline_into_scan(): d = im.domain("cartesian_domain", {IDim: (0, 1)}) - scan_stencil = im.call("scan")(im.lambda_("state", "a")(im.plus("state", im.deref("a"))), True, 0) + scan_stencil = im.call("scan")( + im.lambda_("state", "a")(im.plus("state", im.deref("a"))), True, 0 + ) scan = im.as_fieldop(scan_stencil, d)(im.ref("a", field_type)) testee = im.as_fieldop(im.lambda_("arg")(im.deref("arg")), d)(scan) actual = fuse_as_fieldop.FuseAsFieldOp.apply(