From 21b1dfc8a80ffd9f1825c1c72ca34c25e3db3c4c Mon Sep 17 00:00:00 2001 From: SF-N Date: Fri, 20 Sep 2024 09:07:53 +0200 Subject: [PATCH] feat[next]: domain inference for let, make_tuple, tuple_get, cond (#1591) Infers the minimal domain of (nested) `let`, `make_tuple`, `tuple_get`, `cond` and other builtins as an extension to PR #1568 - New functions `infer_let`, `infer_make_tuple`, `infer_tuple_get`, `infer_cond` in `gt4py.next.iterator.transforms.infer_domain` - New function `infer_expr` in gt4py.next.iterator.transforms.infer_domain which calls the appropriate of the above (or `infer_as_fieldop` and `infer_program`) - Several new tests in test_infer_domain.py to test functionality Note: Temporary handling was only present until commit fc4846f and has been removed in commit e8e679d to reduce unneeded complexity. This pass will be executed before temporary extraction, hence there exist valid `domain`s in all program calls, i.e. all `SetAt` do have a domain (not `AUTO_DOMAIN`) that doesn't need to be inferred. --------- Co-authored-by: Till Ehrengruber --- .../ir_utils/common_pattern_matcher.py | 5 + .../next/iterator/transforms/global_tmps.py | 8 +- .../next/iterator/transforms/infer_domain.py | 335 +++++-- .../codegens/gtfn/itir_to_gtfn_ir.py | 11 +- .../transforms_tests/test_domain_inference.py | 878 +++++++++++++----- 5 files changed, 914 insertions(+), 323 deletions(-) diff --git a/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py b/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py index 2c31cd17da..135b18c367 100644 --- a/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py +++ b/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py @@ -42,6 +42,11 @@ def is_applied_shift(arg: itir.Node) -> TypeGuard[itir.FunCall]: ) +def is_applied_as_fieldop(arg: itir.Node) -> TypeGuard[itir.FunCall]: + """Match expressions of the form `as_fieldop(stencil)(*args)`.""" + return isinstance(arg, itir.FunCall) and is_call_to(arg.fun, "as_fieldop") + + def is_let(node: itir.Node) -> TypeGuard[itir.FunCall]: """Match expression of the form `(λ(...) → ...)(...)`.""" return isinstance(node, itir.FunCall) and isinstance(node.fun, itir.Lambda) diff --git a/src/gt4py/next/iterator/transforms/global_tmps.py b/src/gt4py/next/iterator/transforms/global_tmps.py index c0063e82d5..9bbbaa5c8f 100644 --- a/src/gt4py/next/iterator/transforms/global_tmps.py +++ b/src/gt4py/next/iterator/transforms/global_tmps.py @@ -16,7 +16,7 @@ import gt4py.next as gtx from gt4py.eve import NodeTranslator, PreserveLocationVisitor -from gt4py.eve.extended_typing import Dict, Tuple +from gt4py.eve.extended_typing import Tuple from gt4py.eve.traits import SymbolTableTrait from gt4py.eve.utils import UIDGenerator from gt4py.next import common @@ -454,7 +454,7 @@ def as_expr(self) -> ir.FunCall: def translate( self: SymbolicDomain, shift: Tuple[ir.OffsetLiteral, ...], - offset_provider: Dict[str, common.Dimension], + offset_provider: common.OffsetProvider, ) -> SymbolicDomain: dims = list(self.ranges.keys()) new_ranges = {dim: self.ranges[dim] for dim in dims} @@ -498,7 +498,7 @@ def translate( raise AssertionError("Number of shifts must be a multiple of 2.") -def domain_union(domains: list[SymbolicDomain]) -> SymbolicDomain: +def domain_union(*domains: SymbolicDomain) -> SymbolicDomain: """Return the (set) union of a list of domains.""" new_domain_ranges = {} assert all(domain.grid_type == domains[0].grid_type for domain in domains) @@ -617,7 +617,7 @@ def update_domains( consumed_domain.ranges.keys() == consumed_domains[0].ranges.keys() for consumed_domain in consumed_domains ): # scalar otherwise - domains[param] = domain_union(consumed_domains).as_expr() + domains[param] = domain_union(*consumed_domains).as_expr() return FencilWithTemporaries( fencil=ir.FencilDefinition( diff --git a/src/gt4py/next/iterator/transforms/infer_domain.py b/src/gt4py/next/iterator/transforms/infer_domain.py index e05c58e157..87f754d644 100644 --- a/src/gt4py/next/iterator/transforms/infer_domain.py +++ b/src/gt4py/next/iterator/transforms/infer_domain.py @@ -6,52 +6,131 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause +from __future__ import annotations + +import itertools +import typing +from typing import Callable, TypeAlias + from gt4py.eve import utils as eve_utils -from gt4py.eve.extended_typing import Dict, Tuple +from gt4py.next import common from gt4py.next.common import Dimension from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im from gt4py.next.iterator.transforms import trace_shifts -from gt4py.next.iterator.transforms.global_tmps import AUTO_DOMAIN, SymbolicDomain, domain_union +from gt4py.next.iterator.transforms.global_tmps import SymbolicDomain, domain_union +from gt4py.next.utils import tree_map + + +DOMAIN: TypeAlias = SymbolicDomain | None | tuple["DOMAIN", ...] +ACCESSED_DOMAINS: TypeAlias = dict[str, DOMAIN] + + +def split_dict_by_key(pred: Callable, d: dict): + """ + Split dictionary into two based on predicate. + + >>> d = {1: "a", 2: "b", 3: "c", 4: "d"} + >>> split_dict_by_key(lambda k: k % 2 == 0, d) + ({2: 'b', 4: 'd'}, {1: 'a', 3: 'c'}) + """ + a: dict = {} + b: dict = {} + for k, v in d.items(): + (a if pred(k) else b)[k] = v + return a, b + + +# TODO(tehrengruber): Revisit whether we want to move this behaviour to `domain_union`. +def _domain_union_with_none(*domains: SymbolicDomain | None) -> SymbolicDomain | None: + filtered_domains: list[SymbolicDomain] = [d for d in domains if d is not None] + if len(filtered_domains) == 0: + return None + return domain_union(*filtered_domains) + + +def canonicalize_domain_structure(d1: DOMAIN, d2: DOMAIN) -> tuple[DOMAIN, DOMAIN]: + """ + Given two domains or composites thereof, canonicalize their structure. + + If one of the arguments is a tuple the other one will be promoted to a tuple of same structure + unless it already is a tuple. Missing values are replaced by None, meaning no domain is + specified. + + >>> domain = im.domain(common.GridType.CARTESIAN, {}) + >>> canonicalize_domain_structure((domain,), (domain, domain)) == ( + ... (domain, None), + ... (domain, domain), + ... ) + True + + >>> canonicalize_domain_structure((domain, None), None) == ((domain, None), (None, None)) + True + """ + if d1 is None and isinstance(d2, tuple): + return canonicalize_domain_structure((None,) * len(d2), d2) + if d2 is None and isinstance(d1, tuple): + return canonicalize_domain_structure(d1, (None,) * len(d1)) + if isinstance(d1, tuple) and isinstance(d2, tuple): + return tuple( + zip( + *( + canonicalize_domain_structure(el1, el2) + for el1, el2 in itertools.zip_longest(d1, d2, fillvalue=None) + ) + ) + ) # type: ignore[return-value] # mypy not smart enough + return d1, d2 def _merge_domains( - original_domains: Dict[str, SymbolicDomain], additional_domains: Dict[str, SymbolicDomain] -) -> Dict[str, SymbolicDomain]: + original_domains: ACCESSED_DOMAINS, + additional_domains: ACCESSED_DOMAINS, +) -> ACCESSED_DOMAINS: new_domains = {**original_domains} - for key, value in additional_domains.items(): - if key in original_domains: - new_domains[key] = domain_union([original_domains[key], value]) - else: - new_domains[key] = value + + for key, domain in additional_domains.items(): + original_domain, domain = canonicalize_domain_structure( + original_domains.get(key, None), domain + ) + new_domains[key] = tree_map(_domain_union_with_none)(original_domain, domain) return new_domains -def extract_shifts_and_translate_domains( +def extract_accessed_domains( stencil: itir.Expr, input_ids: list[str], target_domain: SymbolicDomain, - offset_provider: Dict[str, Dimension], - accessed_domains: Dict[str, SymbolicDomain], -): + offset_provider: common.OffsetProvider, +) -> ACCESSED_DOMAINS: + accessed_domains: dict[str, SymbolicDomain | None] = {} + shifts_results = trace_shifts.trace_stencil(stencil, num_args=len(input_ids)) for in_field_id, shifts_list in zip(input_ids, shifts_results, strict=True): new_domains = [ SymbolicDomain.translate(target_domain, shift, offset_provider) for shift in shifts_list ] - if new_domains: - accessed_domains[in_field_id] = domain_union(new_domains) + # `None` means field is never accessed + accessed_domains[in_field_id] = _domain_union_with_none( + accessed_domains.get(in_field_id, None), *new_domains + ) + + return typing.cast(ACCESSED_DOMAINS, accessed_domains) def infer_as_fieldop( applied_fieldop: itir.FunCall, - target_domain: SymbolicDomain | itir.FunCall, - offset_provider: Dict[str, Dimension], -) -> Tuple[itir.FunCall, Dict[str, SymbolicDomain]]: + target_domain: DOMAIN, + offset_provider: common.OffsetProvider, +) -> tuple[itir.FunCall, ACCESSED_DOMAINS]: assert isinstance(applied_fieldop, itir.FunCall) assert cpm.is_call_to(applied_fieldop.fun, "as_fieldop") + if target_domain is None: + raise ValueError("'target_domain' cannot be 'None'.") + if not isinstance(target_domain, SymbolicDomain): + raise ValueError("'target_domain' needs to be a 'SymbolicDomain'.") # `as_fieldop(stencil)(inputs...)` stencil, inputs = applied_fieldop.fun.args[0], applied_fieldop.args @@ -60,7 +139,6 @@ def infer_as_fieldop( assert not isinstance(stencil, itir.Lambda) or len(stencil.params) == len(applied_fieldop.args) input_ids: list[str] = [] - accessed_domains: Dict[str, SymbolicDomain] = {} # Assign ids for all inputs to `as_fieldop`. `SymRef`s stay as is, nested `as_fieldop` get a # temporary id. @@ -71,31 +149,22 @@ def infer_as_fieldop( elif isinstance(in_field, itir.SymRef): id_ = in_field.id else: - raise ValueError(f"Unsupported type {type(in_field)}") + raise ValueError(f"Unsupported expression of type '{type(in_field)}'.") input_ids.append(id_) - if isinstance(target_domain, itir.FunCall): - target_domain = SymbolicDomain.from_expr(target_domain) - - extract_shifts_and_translate_domains( - stencil, input_ids, target_domain, offset_provider, accessed_domains + accessed_domains: ACCESSED_DOMAINS = extract_accessed_domains( + stencil, input_ids, target_domain, offset_provider ) - # Recursively infer domain of inputs and update domain arg of nested `as_fieldops` + # Recursively infer domain of inputs and update domain arg of nested `as_fieldop`s transformed_inputs: list[itir.Expr] = [] for in_field_id, in_field in zip(input_ids, inputs): - if isinstance(in_field, itir.FunCall): - transformed_input, accessed_domains_tmp = infer_as_fieldop( - in_field, accessed_domains[in_field_id], offset_provider - ) - transformed_inputs.append(transformed_input) + transformed_input, accessed_domains_tmp = infer_expr( + in_field, accessed_domains[in_field_id], offset_provider + ) + transformed_inputs.append(transformed_input) - # Merge accessed_domains and accessed_domains_tmp - accessed_domains = _merge_domains(accessed_domains, accessed_domains_tmp) - elif isinstance(in_field, itir.SymRef) or isinstance(in_field, itir.Literal): - transformed_inputs.append(in_field) - else: - raise ValueError(f"Unsupported type {type(in_field)}") + accessed_domains = _merge_domains(accessed_domains, accessed_domains_tmp) transformed_call = im.as_fieldop(stencil, SymbolicDomain.as_expr(target_domain))( *transformed_inputs @@ -110,81 +179,159 @@ def infer_as_fieldop( return transformed_call, accessed_domains_without_tmp -def _validate_temporary_usage(body: list[itir.Stmt], temporaries: list[str]): - assigned_targets = set() - for stmt in body: - assert isinstance(stmt, itir.SetAt) # TODO: extend for if-statements when they land - assert isinstance( - stmt.target, itir.SymRef - ) # TODO: stmt.target can be an expr, e.g. make_tuple - if stmt.target.id in assigned_targets: - raise ValueError("Temporaries can only be used once within a program.") - if stmt.target.id in temporaries: - assigned_targets.add(stmt.target.id) +def infer_let( + let_expr: itir.FunCall, + input_domain: DOMAIN, + offset_provider: common.OffsetProvider, +) -> tuple[itir.FunCall, ACCESSED_DOMAINS]: + assert cpm.is_let(let_expr) + assert isinstance(let_expr.fun, itir.Lambda) # just to make mypy happy + transformed_calls_expr, accessed_domains = infer_expr( + let_expr.fun.expr, input_domain, offset_provider + ) + + let_params = {param_sym.id for param_sym in let_expr.fun.params} + accessed_domains_let_args, accessed_domains_outer = split_dict_by_key( + lambda k: k in let_params, accessed_domains + ) + + transformed_calls_args: list[itir.Expr] = [] + for param, arg in zip(let_expr.fun.params, let_expr.args, strict=True): + transformed_calls_arg, accessed_domains_arg = infer_expr( + arg, + accessed_domains_let_args.get( + param.id, + None, + ), + offset_provider, + ) + accessed_domains_outer = _merge_domains(accessed_domains_outer, accessed_domains_arg) + transformed_calls_args.append(transformed_calls_arg) + + transformed_call = im.let( + *( + (str(param.id), call) + for param, call in zip(let_expr.fun.params, transformed_calls_args, strict=True) + ) + )(transformed_calls_expr) + + return transformed_call, accessed_domains_outer + + +def infer_make_tuple( + expr: itir.Expr, + domain: DOMAIN, + offset_provider: common.OffsetProvider, +) -> tuple[itir.Expr, ACCESSED_DOMAINS]: + assert cpm.is_call_to(expr, "make_tuple") + infered_args_expr = [] + actual_domains: ACCESSED_DOMAINS = {} + if not isinstance(domain, tuple): + # promote domain to a tuple of domains such that it has the same structure as + # the expression + # TODO(tehrengruber): Revisit. Still open how to handle IR in this case example: + # out @ c⟨ IDimₕ: [0, __out_size_0) ⟩ ← {__sym_1, __sym_2}; + domain = (domain,) * len(expr.args) + assert len(expr.args) >= len(domain) + # There may be less domains than tuple args, pad the domain with `None` in that case. + # e.g. `im.tuple_get(0, im.make_tuple(a, b), domain=domain)` + domain = (*domain, *(None for _ in range(len(expr.args) - len(domain)))) + for i, arg in enumerate(expr.args): + infered_arg_expr, actual_domains_arg = infer_expr(arg, domain[i], offset_provider) + infered_args_expr.append(infered_arg_expr) + actual_domains = _merge_domains(actual_domains, actual_domains_arg) + return im.call(expr.fun)(*infered_args_expr), actual_domains + + +def infer_tuple_get( + expr: itir.Expr, + domain: DOMAIN, + offset_provider: common.OffsetProvider, +) -> tuple[itir.Expr, ACCESSED_DOMAINS]: + assert cpm.is_call_to(expr, "tuple_get") + actual_domains: ACCESSED_DOMAINS = {} + idx, tuple_arg = expr.args + assert isinstance(idx, itir.Literal) + child_domain = tuple(None if i != int(idx.value) else domain for i in range(int(idx.value) + 1)) + infered_arg_expr, actual_domains_arg = infer_expr(tuple_arg, child_domain, offset_provider) + + infered_args_expr = im.tuple_get(idx.value, infered_arg_expr) + actual_domains = _merge_domains(actual_domains, actual_domains_arg) + return infered_args_expr, actual_domains + + +def infer_cond( + expr: itir.Expr, + domain: DOMAIN, + offset_provider: common.OffsetProvider, +) -> tuple[itir.Expr, ACCESSED_DOMAINS]: + assert cpm.is_call_to(expr, "cond") + infered_args_expr = [] + actual_domains: ACCESSED_DOMAINS = {} + cond, true_val, false_val = expr.args + for arg in [true_val, false_val]: + infered_arg_expr, actual_domains_arg = infer_expr(arg, domain, offset_provider) + infered_args_expr.append(infered_arg_expr) + actual_domains = _merge_domains(actual_domains, actual_domains_arg) + return im.call(expr.fun)(cond, *infered_args_expr), actual_domains + + +def infer_expr( + expr: itir.Expr, + domain: DOMAIN, + offset_provider: common.OffsetProvider, +) -> tuple[itir.Expr, ACCESSED_DOMAINS]: + if isinstance(expr, itir.SymRef): + return expr, {str(expr.id): domain} + elif isinstance(expr, itir.Literal): + return expr, {} + elif cpm.is_applied_as_fieldop(expr): + return infer_as_fieldop(expr, domain, offset_provider) + elif cpm.is_let(expr): + return infer_let(expr, domain, offset_provider) + elif cpm.is_call_to(expr, "make_tuple"): + return infer_make_tuple(expr, domain, offset_provider) + elif cpm.is_call_to(expr, "tuple_get"): + return infer_tuple_get(expr, domain, offset_provider) + elif cpm.is_call_to(expr, "cond"): + return infer_cond(expr, domain, offset_provider) + elif ( + cpm.is_call_to(expr, itir.ARITHMETIC_BUILTINS) + or cpm.is_call_to(expr, itir.TYPEBUILTINS) + or cpm.is_call_to(expr, "cast_") + ): + return expr, {} + else: + raise ValueError(f"Unsupported expression: {expr}") def infer_program( program: itir.Program, - offset_provider: Dict[str, Dimension], + offset_provider: dict[str, Dimension], ) -> itir.Program: - accessed_domains: dict[str, SymbolicDomain] = {} transformed_set_ats: list[itir.SetAt] = [] + assert ( + not program.function_definitions + ), "Domain propagation does not support function definitions." - temporaries: list[str] = [tmp.id for tmp in program.declarations] - - _validate_temporary_usage(program.body, temporaries) - - for set_at in reversed(program.body): + for set_at in program.body: assert isinstance(set_at, itir.SetAt) - if isinstance(set_at.expr, itir.SymRef): - transformed_set_ats.insert(0, set_at) - continue - assert isinstance(set_at.expr, itir.FunCall) - assert cpm.is_call_to(set_at.expr.fun, "as_fieldop") - assert isinstance( - set_at.target, itir.SymRef - ) # TODO: stmt.target can be an expr, e.g. make_tuple - if set_at.target.id in temporaries: - # ignore temporaries as their domain is the `AUTO_DOMAIN` placeholder - assert set_at.domain == AUTO_DOMAIN - else: - accessed_domains[set_at.target.id] = SymbolicDomain.from_expr(set_at.domain) - transformed_as_fieldop, current_accessed_domains = infer_as_fieldop( - set_at.expr, accessed_domains[set_at.target.id], offset_provider + transformed_call, _unused_domain = infer_expr( + set_at.expr, SymbolicDomain.from_expr(set_at.domain), offset_provider ) - transformed_set_ats.insert( - 0, + transformed_set_ats.append( itir.SetAt( - expr=transformed_as_fieldop, - domain=SymbolicDomain.as_expr(accessed_domains[set_at.target.id]), + expr=transformed_call, + domain=set_at.domain, target=set_at.target, ), ) - for field in current_accessed_domains: - if field in accessed_domains: - # multiple accesses to the same field -> compute union of accessed domains - if field in temporaries: - accessed_domains[field] = domain_union( - [accessed_domains[field], current_accessed_domains[field]] - ) - else: - # TODO(tehrengruber): if domain_ref is an external field the domain must - # already be larger. This should be checked, but would require additions - # to the IR. - pass - else: - accessed_domains[field] = current_accessed_domains[field] - - new_declarations = program.declarations - for temporary in new_declarations: - temporary.domain = SymbolicDomain.as_expr(accessed_domains[temporary.id]) - return itir.Program( id=program.id, function_definitions=program.function_definitions, params=program.params, - declarations=new_declarations, + declarations=program.declarations, body=transformed_set_ats, ) diff --git a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py index ad98ade084..e9a0ad16c4 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py @@ -15,6 +15,7 @@ from gt4py.eve.concepts import SymbolName from gt4py.next import common from gt4py.next.iterator import ir as itir +from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm from gt4py.next.iterator.type_system import inference as itir_type_inference from gt4py.next.program_processors.codegens.gtfn.gtfn_ir import ( Backend, @@ -208,14 +209,6 @@ def _bool_from_literal(node: itir.Node) -> bool: return node.value == "True" -def _is_applied_as_fieldop(arg: itir.Expr) -> TypeGuard[itir.FunCall]: - return ( - isinstance(arg, itir.FunCall) - and isinstance(arg.fun, itir.FunCall) - and arg.fun.fun == itir.SymRef(id="as_fieldop") - ) - - class _CannonicalizeUnstructuredDomain(eve.NodeTranslator): def visit_FunCall(self, node: itir.FunCall) -> itir.FunCall: if node.fun == itir.SymRef(id="unstructured_domain"): @@ -583,7 +576,7 @@ def visit_Stmt(self, node: itir.Stmt, **kwargs: Any) -> None: def visit_SetAt( self, node: itir.SetAt, *, extracted_functions: list, **kwargs: Any ) -> Union[StencilExecution, ScanExecution]: - assert _is_applied_as_fieldop(node.expr) + assert cpm.is_applied_as_fieldop(node.expr) stencil = node.expr.fun.args[0] # type: ignore[attr-defined] # checked in assert domain = node.domain inputs = node.expr.args diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py index eb3deb1a85..51932e0aa0 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py @@ -8,20 +8,20 @@ # TODO(SF-N): test scan operator +import pytest import numpy as np -from typing import Iterable +from typing import Iterable, Optional, Literal, Union from gt4py import eve from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im from gt4py.next.iterator import ir as itir -from gt4py.next.iterator.transforms.infer_domain import infer_as_fieldop, infer_program -from gt4py.next.iterator.transforms.global_tmps import SymbolicDomain, AUTO_DOMAIN -import pytest -from gt4py.eve.extended_typing import Dict -from gt4py.next.common import Dimension, DimensionKind +from gt4py.next.iterator.transforms import infer_domain +from gt4py.next.iterator.transforms.global_tmps import SymbolicDomain +from gt4py.next.common import Dimension from gt4py.next import common, NeighborTableOffsetProvider from gt4py.next.type_system import type_specifications as ts from gt4py.next.iterator.transforms.constant_folding import ConstantFolding +from gt4py.next import utils float_type = ts.ScalarType(kind=ts.ScalarKind.FLOAT64) IDim = common.Dimension(value="IDim", kind=common.DimensionKind.HORIZONTAL) @@ -33,11 +33,7 @@ @pytest.fixture def offset_provider(): - return { - "Ioff": IDim, - "Joff": JDim, - "Koff": KDim, - } + return {"Ioff": IDim, "Joff": JDim, "Koff": KDim} @pytest.fixture @@ -52,44 +48,66 @@ def unstructured_offset_provider(): } -def run_test_as_fieldop( - stencil: itir.Lambda, +def premap_field( + field: itir.Expr, dim: str, offset: int, domain: Optional[itir.FunCall] = None +) -> itir.Expr: + return im.as_fieldop(im.lambda_("it")(im.deref(im.shift(dim, offset)("it"))), domain)(field) + + +def setup_test_as_fieldop( + stencil: itir.Lambda | Literal["deref"], domain: itir.FunCall, - expected_domain_dict: Dict[str, Dict[str | Dimension, tuple[itir.Expr, itir.Expr]]], - offset_provider: Dict[str, Dimension], *, refs: Iterable[itir.SymRef] = None, - domain_type: str = common.GridType.CARTESIAN, -) -> None: +) -> tuple[itir.FunCall, itir.FunCall]: if refs is None: + assert isinstance(stencil, itir.Lambda) refs = [f"in_field{i+1}" for i in range(0, len(stencil.params))] testee = im.as_fieldop(stencil)(*refs) expected = im.as_fieldop(stencil, domain)(*refs) - - actual_call, actual_domains = infer_as_fieldop( - testee, SymbolicDomain.from_expr(domain), offset_provider - ) - - folded_domains = constant_fold_accessed_domains(actual_domains) - expected_domains = { - ref: SymbolicDomain.from_expr(im.domain(domain_type, d)) - for ref, d in expected_domain_dict.items() - } - - assert actual_call == expected - assert folded_domains == expected_domains + return testee, expected def run_test_program( - testee: itir.Program, expected: itir.Program, offset_provider: dict[str, Dimension] + testee: itir.Program, expected: itir.Program, offset_provider: common.OffsetProvider ) -> None: - actual_program = infer_program(testee, offset_provider) + actual_program = infer_domain.infer_program(testee, offset_provider) folded_program = constant_fold_domain_exprs(actual_program) assert folded_program == expected +def run_test_expr( + testee: itir.FunCall, + expected: itir.FunCall, + domain: itir.FunCall, + expected_domains: dict[str, itir.Expr | dict[str | Dimension, tuple[itir.Expr, itir.Expr]]], + offset_provider: common.OffsetProvider, +): + actual_call, actual_domains = infer_domain.infer_expr( + testee, SymbolicDomain.from_expr(domain), offset_provider + ) + folded_call = constant_fold_domain_exprs(actual_call) + folded_domains = constant_fold_accessed_domains(actual_domains) if actual_domains else None + + grid_type = str(domain.fun.id) + + def canonicalize_domain(d): + if isinstance(d, dict): + return im.domain(grid_type, d) + elif isinstance(d, itir.FunCall): + return d + elif d is None: + return None + raise AssertionError() + + expected_domains = {ref: canonicalize_domain(d) for ref, d in expected_domains.items()} + + assert folded_call == expected + assert folded_domains == expected_domains + + class _ConstantFoldDomainsExprs(eve.NodeTranslator): def visit_FunCall(self, node: itir.FunCall): if cpm.is_call_to(node, ("cartesian_domain", "unstructured_domain")): @@ -101,18 +119,56 @@ def constant_fold_domain_exprs(arg: itir.Node) -> itir.Node: return _ConstantFoldDomainsExprs().visit(arg) -def constant_fold_accessed_domains(domains: Dict[str, SymbolicDomain]) -> Dict[str, SymbolicDomain]: - return { - k: SymbolicDomain.from_expr(constant_fold_domain_exprs(v.as_expr())) - for k, v in domains.items() - } +def constant_fold_accessed_domains( + domains: infer_domain.ACCESSED_DOMAINS, +) -> infer_domain.ACCESSED_DOMAINS: + def fold_domain(domain: SymbolicDomain | None): + if domain is None: + return domain + return constant_fold_domain_exprs(domain.as_expr()) + + return {k: utils.tree_map(fold_domain)(v) for k, v in domains.items()} + + +def translate_domain( + domain: itir.FunCall, + shifts: dict[Union[common.Dimension, str], tuple[itir.Expr, itir.Expr]], + offset_provider: common.OffsetProvider, +) -> SymbolicDomain: + shift_tuples = [ + ( + im.ensure_offset( + itir.AxisLiteral(value=d.value, kind=d.kind) + if isinstance(d, common.Dimension) + else itir.AxisLiteral(value=d) + ), + im.ensure_offset(r), + ) + for d, r in shifts.items() + ] + + shift_list = [item for sublist in shift_tuples for item in sublist] + + translated_domain_expr = SymbolicDomain.translate( + SymbolicDomain.from_expr(domain), shift_list, offset_provider + ) + + return constant_fold_domain_exprs(translated_domain_expr.as_expr()) def test_forward_difference_x(offset_provider): stencil = im.lambda_("arg0")(im.minus(im.deref(im.shift("Ioff", 1)("arg0")), im.deref("arg0"))) domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) - expected_accessed_domains = {"in_field1": {IDim: (0, 12)}} - run_test_as_fieldop(stencil, domain, expected_accessed_domains, offset_provider) + expected_domains = {"in_field1": {IDim: (0, 12)}} + testee, expected = setup_test_as_fieldop(stencil, domain) + run_test_expr(testee, expected, domain, expected_domains, offset_provider) + + +def test_deref(offset_provider): + domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) + expected_domains = {"in_field": {IDim: (0, 11)}} + testee, expected = setup_test_as_fieldop("deref", domain, refs=["in_field"]) + run_test_expr(testee, expected, domain, expected_domains, offset_provider) def test_multi_length_shift(offset_provider): @@ -129,37 +185,21 @@ def test_multi_length_shift(offset_provider): ) ) domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) - expected_accessed_domains = {"in_field1": {IDim: (3, 14)}} - run_test_as_fieldop(stencil, domain, expected_accessed_domains, offset_provider) - - -def test_unused_input(offset_provider): - stencil = im.lambda_("arg0", "arg1")(im.deref("arg0")) - - domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) - expected_accessed_domains = { - "in_field1": {IDim: (0, 11)}, - } - run_test_as_fieldop( - stencil, - domain, - expected_accessed_domains, - offset_provider, - ) + expected_domains = {"in_field1": {IDim: (3, 14)}} + testee, expected = setup_test_as_fieldop(stencil, domain) + run_test_expr(testee, expected, domain, expected_domains, offset_provider) def test_unstructured_shift(unstructured_offset_provider): stencil = im.lambda_("arg0")(im.deref(im.shift("E2V", 1)("arg0"))) domain = im.domain(common.GridType.UNSTRUCTURED, {Edge: (0, 1)}) - expected_accessed_domains = {"in_field1": {Vertex: (0, 2)}} + expected_domains = {"in_field1": {Vertex: (0, 2)}} - run_test_as_fieldop( + testee, expected = setup_test_as_fieldop( stencil, domain, - expected_accessed_domains, - unstructured_offset_provider, - domain_type=common.GridType.UNSTRUCTURED, ) + run_test_expr(testee, expected, domain, expected_domains, unstructured_offset_provider) def test_laplace(offset_provider): @@ -179,9 +219,10 @@ def test_laplace(offset_provider): ) ) domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11), JDim: (0, 7)}) - expected_accessed_domains = {"in_field1": {IDim: (-1, 12), JDim: (-1, 8)}} + expected_domains = {"in_field1": {IDim: (-1, 12), JDim: (-1, 8)}} - run_test_as_fieldop(stencil, domain, expected_accessed_domains, offset_provider) + testee, expected = setup_test_as_fieldop(stencil, domain) + run_test_expr(testee, expected, domain, expected_domains, offset_provider) def test_shift_x_y_two_inputs(offset_provider): @@ -192,16 +233,15 @@ def test_shift_x_y_two_inputs(offset_provider): ) ) domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11), JDim: (0, 7)}) - expected_accessed_domains = { + expected_domains = { "in_field1": {IDim: (-1, 10), JDim: (0, 7)}, "in_field2": {IDim: (0, 11), JDim: (1, 8)}, } - run_test_as_fieldop( + testee, expected = setup_test_as_fieldop( stencil, domain, - expected_accessed_domains, - offset_provider, ) + run_test_expr(testee, expected, domain, expected_domains, offset_provider) def test_shift_x_y_two_inputs_literal(offset_provider): @@ -212,16 +252,15 @@ def test_shift_x_y_two_inputs_literal(offset_provider): ) ) domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11), JDim: (0, 7)}) - expected_accessed_domains = { + expected_domains = { "in_field1": {IDim: (-1, 10), JDim: (0, 7)}, } - run_test_as_fieldop( + testee, expected = setup_test_as_fieldop( stencil, domain, - expected_accessed_domains, - offset_provider, refs=(im.ref("in_field1"), 2), ) + run_test_expr(testee, expected, domain, expected_domains, offset_provider) def test_shift_x_y_z_three_inputs(offset_provider): @@ -234,36 +273,40 @@ def test_shift_x_y_z_three_inputs(offset_provider): im.deref(im.shift("Koff", -1)("arg2")), ) ) - domain_dict = { - IDim: (0, 11), - JDim: (0, 7), - KDim: (0, 3), - } - expected_domain_dict = { - "in_field1": { - IDim: (1, 12), - JDim: (0, 7), - KDim: (0, 3), - }, - "in_field2": { - IDim: (0, 11), - JDim: (1, 8), - KDim: (0, 3), - }, - "in_field3": { - IDim: (0, 11), - JDim: (0, 7), - KDim: (-1, 2), - }, + domain_dict = {IDim: (0, 11), JDim: (0, 7), KDim: (0, 3)} + expected_domains = { + "in_field1": {IDim: (1, 12), JDim: (0, 7), KDim: (0, 3)}, + "in_field2": {IDim: (0, 11), JDim: (1, 8), KDim: (0, 3)}, + "in_field3": {IDim: (0, 11), JDim: (0, 7), KDim: (-1, 2)}, } - run_test_as_fieldop( + testee, expected = setup_test_as_fieldop( stencil, im.domain(common.GridType.CARTESIAN, domain_dict), - expected_domain_dict, + ) + run_test_expr( + testee, + expected, + im.domain(common.GridType.CARTESIAN, domain_dict), + expected_domains, offset_provider, ) +def test_two_params_same_arg(offset_provider): + stencil = im.lambda_("arg0", "arg1")( + im.plus( + im.deref("arg0"), + im.deref(im.shift("Ioff", 1)("arg1")), + ) + ) + domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) + expected_domains = { + "in_field": {IDim: (0, 12)}, + } + testee, expected = setup_test_as_fieldop(stencil, domain, refs=["in_field", "in_field"]) + run_test_expr(testee, expected, domain, expected_domains, offset_provider) + + def test_nested_stencils(offset_provider): inner_stencil = im.lambda_("arg0_tmp", "arg1_tmp")( im.plus( @@ -280,8 +323,8 @@ def test_nested_stencils(offset_provider): tmp = im.as_fieldop(inner_stencil)(im.ref("in_field1"), im.ref("in_field2")) testee = im.as_fieldop(stencil)(im.ref("in_field1"), tmp) - domain_inner = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11), JDim: (-1, 6)}) domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11), JDim: (0, 7)}) + domain_inner = translate_domain(domain, {"Ioff": 0, "Joff": -1}, offset_provider) expected_inner = im.as_fieldop(inner_stencil, domain_inner)( im.ref("in_field1"), im.ref("in_field2") @@ -289,14 +332,10 @@ def test_nested_stencils(offset_provider): expected = im.as_fieldop(stencil, domain)(im.ref("in_field1"), expected_inner) expected_domains = { - "in_field1": SymbolicDomain.from_expr( - im.domain(common.GridType.CARTESIAN, {IDim: (1, 12), JDim: (-1, 7)}) - ), - "in_field2": SymbolicDomain.from_expr( - im.domain(common.GridType.CARTESIAN, {IDim: (0, 11), JDim: (-2, 5)}) - ), + "in_field1": im.domain(common.GridType.CARTESIAN, {IDim: (1, 12), JDim: (-1, 7)}), + "in_field2": translate_domain(domain, {"Ioff": 0, "Joff": -2}, offset_provider), } - actual_call, actual_domains = infer_as_fieldop( + actual_call, actual_domains = infer_domain.infer_expr( testee, SymbolicDomain.from_expr(domain), offset_provider ) folded_domains = constant_fold_accessed_domains(actual_domains) @@ -332,17 +371,15 @@ def test_nested_stencils_n_times(offset_provider, iterations): testee = testee expected_domains = { - "in_field1": SymbolicDomain.from_expr( - im.domain(common.GridType.CARTESIAN, {IDim: (1, 12), JDim: (0, 7 + iterations - 1)}) + "in_field1": im.domain( + common.GridType.CARTESIAN, {IDim: (1, 12), JDim: (0, 7 + iterations - 1)} ), - "in_field2": SymbolicDomain.from_expr( - im.domain( - common.GridType.CARTESIAN, {IDim: (0, 11), JDim: (iterations, 7 + iterations)} - ) + "in_field2": im.domain( + common.GridType.CARTESIAN, {IDim: (0, 11), JDim: (iterations, 7 + iterations)} ), } - actual_call, actual_domains = infer_as_fieldop( + actual_call, actual_domains = infer_domain.infer_expr( testee, SymbolicDomain.from_expr(domain), offset_provider ) @@ -352,24 +389,45 @@ def test_nested_stencils_n_times(offset_provider, iterations): assert folded_domains == expected_domains +def test_unused_input(offset_provider): + stencil = im.lambda_("arg0", "arg1")(im.deref("arg0")) + + domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) + expected_domains = {"in_field1": {IDim: (0, 11)}, "in_field2": None} + testee, expected = setup_test_as_fieldop( + stencil, + domain, + ) + run_test_expr(testee, expected, domain, expected_domains, offset_provider) + + +def test_let_unused_field(offset_provider): + testee = im.let("a", "c")("b") + domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) + expected = im.let("a", "c")("b") + expected_domains = {"b": {IDim: (0, 11)}, "c": None} + + run_test_expr(testee, expected, domain, expected_domains, offset_provider) + + def test_program(offset_provider): stencil = im.lambda_("arg0")(im.minus(im.deref(im.shift("Ioff", 1)("arg0")), im.deref("arg0"))) applied_as_fieldop_tmp = im.as_fieldop(stencil)(im.ref("in_field")) applied_as_fieldop = im.as_fieldop(stencil)(im.ref("tmp")) - domain_tmp = im.domain(common.GridType.CARTESIAN, {IDim: (0, 12)}) domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) + domain_tmp = im.domain(common.GridType.CARTESIAN, {IDim: (0, 12)}) - params = [im.sym(name) for name in ["in_field", "out_field", "_gtmp_auto_domain"]] + params = [im.sym(name) for name in ["in_field", "out_field"]] testee = itir.Program( id="forward_diff_with_tmp", function_definitions=[], params=params, - declarations=[itir.Temporary(id="tmp", domain=AUTO_DOMAIN, dtype=float_type)], + declarations=[itir.Temporary(id="tmp", domain=domain_tmp, dtype=float_type)], body=[ - itir.SetAt(expr=applied_as_fieldop_tmp, domain=AUTO_DOMAIN, target=im.ref("tmp")), + itir.SetAt(expr=applied_as_fieldop_tmp, domain=domain_tmp, target=im.ref("tmp")), itir.SetAt(expr=applied_as_fieldop, domain=domain, target=im.ref("out_field")), ], ) @@ -391,155 +449,543 @@ def test_program(offset_provider): run_test_program(testee, expected, offset_provider) -def test_program_two_tmps(offset_provider): - stencil = im.lambda_("arg0")(im.minus(im.deref(im.shift("Ioff", 1)("arg0")), im.deref("arg0"))) - - as_fieldop_tmp1 = im.as_fieldop(stencil)(im.ref("in_field")) - as_fieldop_tmp2 = im.as_fieldop(stencil)(im.ref("tmp1")) - as_fieldop = im.as_fieldop(stencil)(im.ref("tmp2")) - +def test_program_make_tuple(offset_provider): domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) - domain_tmp1 = im.domain(common.GridType.CARTESIAN, {IDim: (0, 13)}) - domain_tmp2 = im.domain(common.GridType.CARTESIAN, {IDim: (0, 12)}) - - params = [im.sym(name) for name in ["in_field", "out_field", "_gtmp_auto_domain"]] + params = [im.sym(name) for name in ["in_field", "out_field"]] testee = itir.Program( - id="forward_diff_with_two_tmps", + id="make_tuple_prog", function_definitions=[], params=params, - declarations=[ - itir.Temporary(id="tmp1", domain=AUTO_DOMAIN, dtype=float_type), - itir.Temporary(id="tmp2", domain=AUTO_DOMAIN, dtype=float_type), - ], + declarations=[], body=[ - itir.SetAt(expr=as_fieldop_tmp1, domain=AUTO_DOMAIN, target=im.ref("tmp1")), - itir.SetAt(expr=as_fieldop_tmp2, domain=AUTO_DOMAIN, target=im.ref("tmp2")), - itir.SetAt(expr=as_fieldop, domain=domain, target=im.ref("out_field")), + itir.SetAt( + expr=im.make_tuple(im.as_fieldop("deref")("in_field"), "in_field"), + domain=domain, + target=im.ref("out_field"), + ), ], ) - expected_expr_tmp1 = im.as_fieldop(stencil, domain_tmp1)(im.ref("in_field")) - expected_expr_tmp2 = im.as_fieldop(stencil, domain_tmp2)(im.ref("tmp1")) - expected_expr = im.as_fieldop(stencil, domain)(im.ref("tmp2")) - expected = itir.Program( - id="forward_diff_with_two_tmps", + id="make_tuple_prog", function_definitions=[], params=params, - declarations=[ - itir.Temporary(id="tmp1", domain=domain_tmp1, dtype=float_type), - itir.Temporary(id="tmp2", domain=domain_tmp2, dtype=float_type), - ], + declarations=[], body=[ - itir.SetAt(expr=expected_expr_tmp1, domain=domain_tmp1, target=im.ref("tmp1")), - itir.SetAt(expr=expected_expr_tmp2, domain=domain_tmp2, target=im.ref("tmp2")), - itir.SetAt(expr=expected_expr, domain=domain, target=im.ref("out_field")), + itir.SetAt( + expr=im.make_tuple(im.as_fieldop("deref", domain)("in_field"), "in_field"), + domain=domain, + target=im.ref("out_field"), + ), ], ) run_test_program(testee, expected, offset_provider) -@pytest.mark.xfail(raises=ValueError) -def test_program_ValueError(offset_provider): - with pytest.raises(ValueError, match=r"Temporaries can only be used once within a program."): - stencil = im.lambda_("arg0")(im.deref("arg0")) +def test_cond(offset_provider): + stencil1 = im.lambda_("arg0")(im.minus(im.deref(im.shift("Ioff", 1)("arg0")), im.deref("arg0"))) + field_1 = im.as_fieldop(stencil1)(im.ref("in_field1")) + tmp_stencil2 = im.lambda_("arg0_tmp", "arg1_tmp")( + im.plus( + im.deref(im.shift("Ioff", 1)("arg0_tmp")), + im.deref(im.shift("Ioff", -1)("arg1_tmp")), + ) + ) + stencil2 = im.lambda_("arg0", "arg1")( + im.plus( + im.deref(im.shift("Ioff", 1)("arg0")), + im.deref(im.shift("Ioff", -1)("arg1")), + ) + ) + tmp2 = im.as_fieldop(tmp_stencil2)(im.ref("in_field1"), im.ref("in_field2")) + field_2 = im.as_fieldop(stencil2)(im.ref("in_field2"), tmp2) - as_fieldop_tmp = im.as_fieldop(stencil)(im.ref("in_field")) - as_fieldop = im.as_fieldop(stencil)(im.ref("tmp")) + cond = im.deref("cond_") - domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) + testee = im.cond(cond, field_1, field_2) - params = [im.sym(name) for name in ["in_field", "out_field", "_gtmp_auto_domain"]] + domain = im.domain(common.GridType.CARTESIAN, {"IDim": (0, 11)}) + domain_tmp = translate_domain(domain, {"Ioff": -1}, offset_provider) + expected_domains_dict = {"in_field1": {IDim: (0, 12)}, "in_field2": {IDim: (-2, 12)}} + expected_tmp2 = im.as_fieldop(tmp_stencil2, domain_tmp)( + im.ref("in_field1"), im.ref("in_field2") + ) + expected_field_1 = im.as_fieldop(stencil1, domain)(im.ref("in_field1")) + expected_field_2 = im.as_fieldop(stencil2, domain)(im.ref("in_field2"), expected_tmp2) - infer_program( - itir.Program( - id="forward_diff_with_tmp", - function_definitions=[], - params=params, - declarations=[itir.Temporary(id="tmp", domain=AUTO_DOMAIN, dtype=float_type)], - body=[ - # target occurs twice here which is prohibited - itir.SetAt(expr=as_fieldop_tmp, domain=AUTO_DOMAIN, target=im.ref("tmp")), - itir.SetAt(expr=as_fieldop_tmp, domain=AUTO_DOMAIN, target=im.ref("tmp")), - itir.SetAt(expr=as_fieldop, domain=domain, target=im.ref("out_field")), - ], - ), - offset_provider, + expected = im.cond(cond, expected_field_1, expected_field_2) + + actual_call, actual_domains = infer_domain.infer_expr( + testee, SymbolicDomain.from_expr(domain), offset_provider + ) + + folded_domains = constant_fold_accessed_domains(actual_domains) + expected_domains = { + ref: im.domain(common.GridType.CARTESIAN, d) for ref, d in expected_domains_dict.items() + } + folded_call = constant_fold_domain_exprs(actual_call) + assert folded_call == expected + assert folded_domains == expected_domains + + +def test_let_scalar_expr(offset_provider): + testee = im.let("a", 1)(im.op_as_fieldop(im.plus)("a", "b")) + domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) + expected = im.let("a", 1)(im.op_as_fieldop(im.plus, domain)("a", "b")) + expected_domains = {"b": {IDim: (0, 11)}} + + run_test_expr(testee, expected, domain, expected_domains, offset_provider) + + +def test_simple_let(offset_provider): + testee = im.let("a", premap_field("in_field", "Ioff", 1))("a") + domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) + expected = im.let("a", premap_field("in_field", "Ioff", 1, domain))("a") + + expected_domains = {"in_field": translate_domain(domain, {"Ioff": 1}, offset_provider)} + run_test_expr(testee, expected, domain, expected_domains, offset_provider) + + +def test_simple_let2(offset_provider): + testee = im.let("a", "in_field")(premap_field("a", "Ioff", 1)) + domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) + expected = im.let("a", "in_field")(premap_field("a", "Ioff", 1, domain)) + + expected_domains = {"in_field": translate_domain(domain, {"Ioff": 1}, offset_provider)} + run_test_expr(testee, expected, domain, expected_domains, offset_provider) + + +def test_let(offset_provider): + testee = im.let( + "a", + premap_field("in_field", "Ioff", 1), + )(premap_field("a", "Ioff", 1)) + testee2 = premap_field(premap_field("in_field", "Ioff", 1), "Ioff", 1) + domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) + domain_a = translate_domain(domain, {"Ioff": 1}, offset_provider) + expected = im.let( + "a", + premap_field("in_field", "Ioff", 1, domain_a), + )(premap_field("a", "Ioff", 1, domain)) + expected2 = premap_field(premap_field("in_field", "Ioff", 1, domain_a), "Ioff", 1, domain) + expected_domains = {"in_field": translate_domain(domain, {"Ioff": 2}, offset_provider)} + run_test_expr(testee, expected, domain, expected_domains, offset_provider) + + expected_domains_sym = {"in_field": translate_domain(domain, {"Ioff": 2}, offset_provider)} + actual_call2, actual_domains2 = infer_domain.infer_expr( + testee2, SymbolicDomain.from_expr(domain), offset_provider + ) + folded_domains2 = constant_fold_accessed_domains(actual_domains2) + folded_call2 = constant_fold_domain_exprs(actual_call2) + assert folded_call2 == expected2 + assert expected_domains_sym == folded_domains2 + + +def test_let_two_inputs(offset_provider): + multiply_stencil = im.lambda_("it1", "it2")(im.multiplies_(im.deref("it1"), im.deref("it2"))) + + testee = im.let( + ("inner1", premap_field("in_field1", "Ioff", 1)), + ("inner2", premap_field("in_field2", "Ioff", -1)), + )(im.as_fieldop(multiply_stencil)("inner1", "inner2")) + domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) + domain_p1 = translate_domain(domain, {"Ioff": 1}, offset_provider) + domain_m1 = translate_domain(domain, {"Ioff": -1}, offset_provider) + expected = im.let( + ("inner1", premap_field("in_field1", "Ioff", 1, domain)), + ("inner2", premap_field("in_field2", "Ioff", -1, domain)), + )(im.as_fieldop(multiply_stencil, domain)("inner1", "inner2")) + expected_domains = { + "in_field1": domain_p1, + "in_field2": domain_m1, + } + run_test_expr(testee, expected, domain, expected_domains, offset_provider) + + +def test_nested_let_in_body(offset_provider): + testee = im.let("inner1", premap_field("outer", "Ioff", 1))( + im.let("inner2", premap_field("inner1", "Ioff", 1))(premap_field("inner2", "Ioff", 1)) + ) + + domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) + domain_p1 = translate_domain(domain, {"Ioff": 1}, offset_provider) + domain_p2 = translate_domain(domain, {"Ioff": 2}, offset_provider) + domain_p3 = translate_domain(domain, {"Ioff": 3}, offset_provider) + + expected = im.let( + "inner1", + premap_field("outer", "Ioff", 1, domain_p2), + )( + im.let("inner2", premap_field("inner1", "Ioff", 1, domain_p1))( + premap_field("inner2", "Ioff", 1, domain) ) + ) + expected_domains = {"outer": domain_p3} + run_test_expr(testee, expected, domain, expected_domains, offset_provider) -def test_program_tree_tmps_two_inputs(offset_provider): - stencil = im.lambda_("arg0", "arg1")( - im.minus(im.deref(im.shift("Ioff", 1)("arg0")), im.deref("arg1")) +def test_nested_let_arg(offset_provider): + testee = im.let("a", "in_field")( + im.as_fieldop( + im.lambda_("it1", "it2")( + im.multiplies_(im.deref("it1"), im.deref(im.shift("Ioff", 1)("it2"))) + ) + )("a", "in_field") ) - stencil_tmp = im.lambda_("arg0")( - im.minus(im.deref(im.shift("Ioff", 1)("arg0")), im.deref("arg0")) + + domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) + + expected = im.let("a", "in_field")( + im.as_fieldop( + im.lambda_("it1", "it2")( + im.multiplies_(im.deref("it1"), im.deref(im.shift("Ioff", 1)("it2"))) + ), + domain, + )("a", "in_field") + ) + expected_domains = {"in_field": {IDim: (0, 12)}} + run_test_expr(testee, expected, domain, expected_domains, offset_provider) + + +def test_nested_let_arg_shadowed(offset_provider): + testee = im.let("a", premap_field("in_field", "Ioff", 3))( + im.let("a", premap_field("a", "Ioff", 2))(premap_field("a", "Ioff", 1)) + ) + + domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) + domain_p1 = translate_domain(domain, {"Ioff": 1}, offset_provider) + domain_p3 = translate_domain(domain, {"Ioff": 3}, offset_provider) + domain_p6 = translate_domain(domain, {"Ioff": 6}, offset_provider) + + expected = im.let( + "a", + premap_field("in_field", "Ioff", 3, domain_p3), + )(im.let("a", premap_field("a", "Ioff", 2, domain_p1))(premap_field("a", "Ioff", 1, domain))) + expected_domains = {"in_field": domain_p6} + run_test_expr(testee, expected, domain, expected_domains, offset_provider) + + +def test_nested_let_arg_shadowed2(offset_provider): + # test that if we shadow `in_field1` its accessed domain is not affected by the accesses + # on the shadowed field + testee = im.as_fieldop( + im.lambda_("it1", "it2")(im.multiplies_(im.deref("it1"), im.deref("it2"))) + )( + premap_field("in_field1", "Ioff", 1), # only here we access `in_field1` + im.let("in_field1", "in_field2")("in_field1"), # here we actually access `in_field2` ) - stencil_tmp_minus = im.lambda_("arg0", "arg1")( - im.minus(im.deref(im.shift("Ioff", -1)("arg0")), im.deref("arg1")) + + domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) + domain_p1 = translate_domain(domain, {"Ioff": 1}, offset_provider) + + expected = im.as_fieldop( + im.lambda_("it1", "it2")(im.multiplies_(im.deref("it1"), im.deref("it2"))), domain + )( + premap_field("in_field1", "Ioff", 1, domain), + im.let("in_field1", "in_field2")("in_field1"), ) + expected_domains = {"in_field1": domain_p1, "in_field2": domain} + run_test_expr(testee, expected, domain, expected_domains, offset_provider) - as_fieldop_tmp1 = im.as_fieldop(stencil)(im.ref("in_field1"), im.ref("in_field2")) - as_fieldop_tmp2 = im.as_fieldop(stencil_tmp)(im.ref("tmp1")) - as_fieldop_out1 = im.as_fieldop(stencil_tmp)(im.ref("tmp2")) - as_fieldop_tmp3 = im.as_fieldop(stencil)(im.ref("tmp1"), im.ref("in_field2")) - as_fieldop_out2 = im.as_fieldop(stencil_tmp_minus)(im.ref("tmp2"), im.ref("tmp3")) - domain_tmp1 = im.domain(common.GridType.CARTESIAN, {IDim: (-1, 13)}) - domain_tmp2 = im.domain(common.GridType.CARTESIAN, {IDim: (-1, 12)}) - domain_tmp3 = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) - domain_out = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) - params = [ - im.sym(name) - for name in ["in_field1", "in_field2", "out_field1", "out_field2", "_gtmp_auto_domain"] - ] +def test_double_nested_let_fun_expr(offset_provider): + testee = im.let("inner1", premap_field("outer", "Ioff", 1))( + im.let("inner2", premap_field("inner1", "Ioff", -1))( + im.let("inner3", premap_field("inner2", "Ioff", -1))(premap_field("inner3", "Ioff", 3)) + ) + ) + domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) + domain_p1 = translate_domain(domain, {"Ioff": 1}, offset_provider) + domain_p2 = translate_domain(domain, {"Ioff": 2}, offset_provider) + domain_p3 = translate_domain(domain, {"Ioff": 3}, offset_provider) + + expected = im.let("inner1", premap_field("outer", "Ioff", 1, domain_p1))( + im.let("inner2", premap_field("inner1", "Ioff", -1, domain_p2))( + im.let("inner3", premap_field("inner2", "Ioff", -1, domain_p3))( + premap_field("inner3", "Ioff", 3, domain) + ) + ) + ) + + expected_domains = {"outer": domain_p2} + + run_test_expr(testee, expected, domain, expected_domains, offset_provider) + + +def test_nested_let_args(offset_provider): + testee = im.let( + "inner", + im.let("inner_arg", premap_field("outer", "Ioff", 1))( + premap_field("inner_arg", "Ioff", -1) + ), + )(premap_field("inner", "Ioff", -1)) + + domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) + domain_m1 = translate_domain(domain, {"Ioff": -1}, offset_provider) + domain_m2 = translate_domain(domain, {"Ioff": -2}, offset_provider) + + expected = im.let( + "inner", + im.let("inner_arg", premap_field("outer", "Ioff", 1, domain_m2))( + premap_field("inner_arg", "Ioff", -1, domain_m1) + ), + )(premap_field("inner", "Ioff", -1, domain)) + + expected_domains = {"outer": domain_m1} + + run_test_expr(testee, expected, domain, expected_domains, offset_provider) + + +def test_program_let(offset_provider): + stencil_tmp = im.lambda_("arg0")( + im.minus(im.deref(im.shift("Ioff", -1)("arg0")), im.deref("arg0")) + ) + let_tmp = im.let("inner", premap_field("outer", "Ioff", -1))(premap_field("inner", "Ioff", -1)) + as_fieldop = im.as_fieldop(stencil_tmp)(im.ref("tmp")) + + domain_lm2_rm1 = im.domain(common.GridType.CARTESIAN, {IDim: (-2, 10)}) + domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) + domain_lm1 = im.domain(common.GridType.CARTESIAN, {IDim: (-1, 11)}) + + params = [im.sym(name) for name in ["in_field", "out_field", "outer"]] testee = itir.Program( - id="differences_three_tmps_two_inputs", + id="forward_diff_with_tmp", function_definitions=[], params=params, - declarations=[ - itir.Temporary(id="tmp1", domain=AUTO_DOMAIN, dtype=float_type), - itir.Temporary(id="tmp2", domain=AUTO_DOMAIN, dtype=float_type), - itir.Temporary(id="tmp3", domain=AUTO_DOMAIN, dtype=float_type), - ], + declarations=[itir.Temporary(id="tmp", domain=domain_lm1, dtype=float_type)], body=[ - itir.SetAt(expr=as_fieldop_tmp1, domain=AUTO_DOMAIN, target=im.ref("tmp1")), - itir.SetAt(expr=as_fieldop_tmp2, domain=AUTO_DOMAIN, target=im.ref("tmp2")), - itir.SetAt(expr=as_fieldop_out1, domain=domain_out, target=im.ref("out_field1")), - itir.SetAt(expr=as_fieldop_tmp3, domain=AUTO_DOMAIN, target=im.ref("tmp3")), - itir.SetAt(expr=as_fieldop_out2, domain=domain_out, target=im.ref("out_field2")), + itir.SetAt(expr=let_tmp, domain=domain_lm1, target=im.ref("tmp")), + itir.SetAt(expr=as_fieldop, domain=domain, target=im.ref("out_field")), ], ) - expected_expr_tmp1 = im.as_fieldop(stencil, domain_tmp1)( - im.ref("in_field1"), im.ref("in_field2") - ) - expected_expr_tmp2 = im.as_fieldop(stencil_tmp, domain_tmp2)(im.ref("tmp1")) - expected_expr_out1 = im.as_fieldop(stencil_tmp, domain_out)(im.ref("tmp2")) - expected_expr_tmp3 = im.as_fieldop(stencil, domain_tmp3)(im.ref("tmp1"), im.ref("in_field2")) - expected_expr_out2 = im.as_fieldop(stencil_tmp_minus, domain_out)( - im.ref("tmp2"), im.ref("tmp3") + expected_let = im.let("inner", premap_field("outer", "Ioff", -1, domain_lm2_rm1))( + premap_field("inner", "Ioff", -1, domain_lm1) ) + expected_as_fieldop = im.as_fieldop(stencil_tmp, domain)(im.ref("tmp")) expected = itir.Program( - id="differences_three_tmps_two_inputs", + id="forward_diff_with_tmp", function_definitions=[], params=params, - declarations=[ - itir.Temporary(id="tmp1", domain=domain_tmp1, dtype=float_type), - itir.Temporary(id="tmp2", domain=domain_tmp2, dtype=float_type), - itir.Temporary(id="tmp3", domain=domain_tmp3, dtype=float_type), - ], + declarations=[itir.Temporary(id="tmp", domain=domain_lm1, dtype=float_type)], body=[ - itir.SetAt(expr=expected_expr_tmp1, domain=domain_tmp1, target=im.ref("tmp1")), - itir.SetAt(expr=expected_expr_tmp2, domain=domain_tmp2, target=im.ref("tmp2")), - itir.SetAt(expr=expected_expr_out1, domain=domain_out, target=im.ref("out_field1")), - itir.SetAt(expr=expected_expr_tmp3, domain=domain_tmp3, target=im.ref("tmp3")), - itir.SetAt(expr=expected_expr_out2, domain=domain_out, target=im.ref("out_field2")), + itir.SetAt(expr=expected_let, domain=domain_lm1, target=im.ref("tmp")), + itir.SetAt(expr=expected_as_fieldop, domain=domain, target=im.ref("out_field")), ], ) run_test_program(testee, expected, offset_provider) + + +def test_make_tuple(offset_provider): + testee = im.make_tuple(im.as_fieldop("deref")("in_field1"), im.as_fieldop("deref")("in_field2")) + domain1 = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) + domain2 = im.domain(common.GridType.CARTESIAN, {IDim: (0, 13)}) + expected = im.make_tuple( + im.as_fieldop("deref", domain1)("in_field1"), im.as_fieldop("deref", domain2)("in_field2") + ) + expected_domains_dict = {"in_field1": {IDim: (0, 11)}, "in_field2": {IDim: (0, 13)}} + expected_domains = { + ref: im.domain(common.GridType.CARTESIAN, d) for ref, d in expected_domains_dict.items() + } + + actual, actual_domains = infer_domain.infer_expr( + testee, + (SymbolicDomain.from_expr(domain1), SymbolicDomain.from_expr(domain2)), + offset_provider, + ) + + assert expected == actual + assert expected_domains == constant_fold_accessed_domains(actual_domains) + + +def test_tuple_get_1_make_tuple(offset_provider): + testee = im.tuple_get(1, im.make_tuple(im.ref("a"), im.ref("b"), im.ref("c"))) + domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) + expected = im.tuple_get(1, im.make_tuple(im.ref("a"), im.ref("b"), im.ref("c"))) + expected_domains = { + "a": None, + "b": im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}), + "c": None, + } + + actual, actual_domains = infer_domain.infer_expr( + testee, SymbolicDomain.from_expr(domain), offset_provider + ) + + assert expected == actual + assert expected_domains == constant_fold_accessed_domains(actual_domains) + + +def test_tuple_get_1_nested_make_tuple(offset_provider): + testee = im.tuple_get(1, im.make_tuple(im.ref("a"), im.make_tuple(im.ref("b"), im.ref("c")))) + domain1 = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) + domain2 = im.domain(common.GridType.CARTESIAN, {IDim: (0, 12)}) + expected = im.tuple_get(1, im.make_tuple(im.ref("a"), im.make_tuple(im.ref("b"), im.ref("c")))) + expected_domains = {"a": None, "b": domain1, "c": domain2} + + actual, actual_domains = infer_domain.infer_expr( + testee, + (SymbolicDomain.from_expr(domain1), SymbolicDomain.from_expr(domain2)), + offset_provider, + ) + + assert expected == actual + assert expected_domains == constant_fold_accessed_domains(actual_domains) + + +def test_tuple_get_let_arg_make_tuple(offset_provider): + testee = im.tuple_get(1, im.let("a", im.make_tuple(im.ref("b"), im.ref("c")))("d")) + domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) + expected = im.tuple_get(1, im.let("a", im.make_tuple(im.ref("b"), im.ref("c")))("d")) + expected_domains = {"b": None, "c": None, "d": (None, domain)} + + actual, actual_domains = infer_domain.infer_expr( + testee, + SymbolicDomain.from_expr(im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)})), + offset_provider, + ) + + assert expected == actual + assert expected_domains == constant_fold_accessed_domains(actual_domains) + + +def test_tuple_get_let_make_tuple(offset_provider): + testee = im.tuple_get(1, im.let("a", "b")(im.make_tuple(im.ref("c"), im.ref("d")))) + domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) + expected = im.tuple_get(1, im.let("a", "b")(im.make_tuple(im.ref("c"), im.ref("d")))) + expected_domains = {"c": None, "d": domain, "b": None} + + actual, actual_domains = infer_domain.infer_expr( + testee, + SymbolicDomain.from_expr(domain), + offset_provider, + ) + + assert expected == actual + assert expected_domains == constant_fold_accessed_domains(actual_domains) + + +def test_nested_make_tuple(offset_provider): + testee = im.make_tuple(im.make_tuple(im.ref("a"), im.ref("b")), im.ref("c")) + domain1 = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) + domain2_1 = im.domain(common.GridType.CARTESIAN, {IDim: (0, 12)}) + domain2_2 = im.domain(common.GridType.CARTESIAN, {IDim: (0, 13)}) + domain3 = im.domain(common.GridType.CARTESIAN, {IDim: (0, 14)}) + expected = im.make_tuple(im.make_tuple(im.ref("a"), im.ref("b")), im.ref("c")) + expected_domains = {"a": domain1, "b": (domain2_1, domain2_2), "c": domain3} + + actual, actual_domains = infer_domain.infer_expr( + testee, + ( + ( + SymbolicDomain.from_expr(domain1), + (SymbolicDomain.from_expr(domain2_1), SymbolicDomain.from_expr(domain2_2)), + ), + SymbolicDomain.from_expr(domain3), + ), + offset_provider, + ) + + assert expected == actual + assert expected_domains == constant_fold_accessed_domains(actual_domains) + + +def test_tuple_get_1(offset_provider): + testee = im.tuple_get(1, im.ref("a")) + domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) + expected = im.tuple_get(1, im.ref("a")) + expected_domains = {"a": (None, domain)} + + actual, actual_domains = infer_domain.infer_expr( + testee, SymbolicDomain.from_expr(domain), offset_provider + ) + + assert expected == actual + assert expected_domains == constant_fold_accessed_domains(actual_domains) + + +def test_domain_tuple(offset_provider): + testee = im.ref("a") + domain1 = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) + domain2 = im.domain(common.GridType.CARTESIAN, {IDim: (0, 12)}) + expected = im.ref("a") + expected_domains = {"a": (domain1, domain2)} + + actual, actual_domains = infer_domain.infer_expr( + testee, + (SymbolicDomain.from_expr(domain1), SymbolicDomain.from_expr(domain2)), + offset_provider, + ) + + assert expected == actual + assert expected_domains == constant_fold_accessed_domains(actual_domains) + + +def test_as_fieldop_tuple_get(offset_provider): + testee = im.op_as_fieldop(im.plus)(im.tuple_get(0, im.ref("a")), im.tuple_get(1, im.ref("a"))) + domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) + expected = im.op_as_fieldop(im.plus, domain)( + im.tuple_get(0, im.ref("a")), im.tuple_get(1, im.ref("a")) + ) + expected_domains = {"a": (domain, domain)} + + actual, actual_domains = infer_domain.infer_expr( + testee, SymbolicDomain.from_expr(domain), offset_provider + ) + + assert expected == actual + assert expected_domains == constant_fold_accessed_domains(actual_domains) + + +def test_make_tuple_2tuple_get(offset_provider): + testee = im.make_tuple(im.tuple_get(0, im.ref("a")), im.tuple_get(1, im.ref("a"))) + domain1 = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) + domain2 = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) + expected = im.make_tuple(im.tuple_get(0, im.ref("a")), im.tuple_get(1, im.ref("a"))) + expected_domains = {"a": (domain1, domain2)} + + actual, actual_domains = infer_domain.infer_expr( + testee, + (SymbolicDomain.from_expr(domain1), SymbolicDomain.from_expr(domain2)), + offset_provider, + ) + + assert expected == actual + assert expected_domains == constant_fold_accessed_domains(actual_domains) + + +def test_make_tuple_non_tuple_domain(offset_provider): + testee = im.make_tuple(im.as_fieldop("deref")("in_field1"), im.as_fieldop("deref")("in_field2")) + domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) + + expected = im.make_tuple( + im.as_fieldop("deref", domain)("in_field1"), im.as_fieldop("deref", domain)("in_field2") + ) + expected_domains = {"in_field1": domain, "in_field2": domain} + + actual, actual_domains = infer_domain.infer_expr( + testee, SymbolicDomain.from_expr(domain), offset_provider + ) + + assert expected == actual + assert expected_domains == constant_fold_accessed_domains(actual_domains) + + +def test_arithmetic_builtin(offset_provider): + testee = im.plus(im.ref("in_field1"), im.ref("in_field2")) + domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) + expected = im.plus(im.ref("in_field1"), im.ref("in_field2")) + expected_domains = {} + + actual_call, actual_domains = infer_domain.infer_expr( + testee, SymbolicDomain.from_expr(domain), offset_provider + ) + folded_call = constant_fold_domain_exprs(actual_call) + + assert folded_call == expected + assert actual_domains == expected_domains