diff --git a/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py b/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py index 2c31cd17da..135b18c367 100644 --- a/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py +++ b/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py @@ -42,6 +42,11 @@ def is_applied_shift(arg: itir.Node) -> TypeGuard[itir.FunCall]: ) +def is_applied_as_fieldop(arg: itir.Node) -> TypeGuard[itir.FunCall]: + """Match expressions of the form `as_fieldop(stencil)(*args)`.""" + return isinstance(arg, itir.FunCall) and is_call_to(arg.fun, "as_fieldop") + + def is_let(node: itir.Node) -> TypeGuard[itir.FunCall]: """Match expression of the form `(λ(...) → ...)(...)`.""" return isinstance(node, itir.FunCall) and isinstance(node.fun, itir.Lambda) diff --git a/src/gt4py/next/iterator/transforms/global_tmps.py b/src/gt4py/next/iterator/transforms/global_tmps.py index c0063e82d5..9bbbaa5c8f 100644 --- a/src/gt4py/next/iterator/transforms/global_tmps.py +++ b/src/gt4py/next/iterator/transforms/global_tmps.py @@ -16,7 +16,7 @@ import gt4py.next as gtx from gt4py.eve import NodeTranslator, PreserveLocationVisitor -from gt4py.eve.extended_typing import Dict, Tuple +from gt4py.eve.extended_typing import Tuple from gt4py.eve.traits import SymbolTableTrait from gt4py.eve.utils import UIDGenerator from gt4py.next import common @@ -454,7 +454,7 @@ def as_expr(self) -> ir.FunCall: def translate( self: SymbolicDomain, shift: Tuple[ir.OffsetLiteral, ...], - offset_provider: Dict[str, common.Dimension], + offset_provider: common.OffsetProvider, ) -> SymbolicDomain: dims = list(self.ranges.keys()) new_ranges = {dim: self.ranges[dim] for dim in dims} @@ -498,7 +498,7 @@ def translate( raise AssertionError("Number of shifts must be a multiple of 2.") -def domain_union(domains: list[SymbolicDomain]) -> SymbolicDomain: +def domain_union(*domains: SymbolicDomain) -> SymbolicDomain: """Return the (set) union of a list of domains.""" new_domain_ranges = {} assert all(domain.grid_type == domains[0].grid_type for domain in domains) @@ -617,7 +617,7 @@ def update_domains( consumed_domain.ranges.keys() == consumed_domains[0].ranges.keys() for consumed_domain in consumed_domains ): # scalar otherwise - domains[param] = domain_union(consumed_domains).as_expr() + domains[param] = domain_union(*consumed_domains).as_expr() return FencilWithTemporaries( fencil=ir.FencilDefinition( diff --git a/src/gt4py/next/iterator/transforms/infer_domain.py b/src/gt4py/next/iterator/transforms/infer_domain.py index e05c58e157..87f754d644 100644 --- a/src/gt4py/next/iterator/transforms/infer_domain.py +++ b/src/gt4py/next/iterator/transforms/infer_domain.py @@ -6,52 +6,131 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause +from __future__ import annotations + +import itertools +import typing +from typing import Callable, TypeAlias + from gt4py.eve import utils as eve_utils -from gt4py.eve.extended_typing import Dict, Tuple +from gt4py.next import common from gt4py.next.common import Dimension from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im from gt4py.next.iterator.transforms import trace_shifts -from gt4py.next.iterator.transforms.global_tmps import AUTO_DOMAIN, SymbolicDomain, domain_union +from gt4py.next.iterator.transforms.global_tmps import SymbolicDomain, domain_union +from gt4py.next.utils import tree_map + + +DOMAIN: TypeAlias = SymbolicDomain | None | tuple["DOMAIN", ...] +ACCESSED_DOMAINS: TypeAlias = dict[str, DOMAIN] + + +def split_dict_by_key(pred: Callable, d: dict): + """ + Split dictionary into two based on predicate. + + >>> d = {1: "a", 2: "b", 3: "c", 4: "d"} + >>> split_dict_by_key(lambda k: k % 2 == 0, d) + ({2: 'b', 4: 'd'}, {1: 'a', 3: 'c'}) + """ + a: dict = {} + b: dict = {} + for k, v in d.items(): + (a if pred(k) else b)[k] = v + return a, b + + +# TODO(tehrengruber): Revisit whether we want to move this behaviour to `domain_union`. +def _domain_union_with_none(*domains: SymbolicDomain | None) -> SymbolicDomain | None: + filtered_domains: list[SymbolicDomain] = [d for d in domains if d is not None] + if len(filtered_domains) == 0: + return None + return domain_union(*filtered_domains) + + +def canonicalize_domain_structure(d1: DOMAIN, d2: DOMAIN) -> tuple[DOMAIN, DOMAIN]: + """ + Given two domains or composites thereof, canonicalize their structure. + + If one of the arguments is a tuple the other one will be promoted to a tuple of same structure + unless it already is a tuple. Missing values are replaced by None, meaning no domain is + specified. + + >>> domain = im.domain(common.GridType.CARTESIAN, {}) + >>> canonicalize_domain_structure((domain,), (domain, domain)) == ( + ... (domain, None), + ... (domain, domain), + ... ) + True + + >>> canonicalize_domain_structure((domain, None), None) == ((domain, None), (None, None)) + True + """ + if d1 is None and isinstance(d2, tuple): + return canonicalize_domain_structure((None,) * len(d2), d2) + if d2 is None and isinstance(d1, tuple): + return canonicalize_domain_structure(d1, (None,) * len(d1)) + if isinstance(d1, tuple) and isinstance(d2, tuple): + return tuple( + zip( + *( + canonicalize_domain_structure(el1, el2) + for el1, el2 in itertools.zip_longest(d1, d2, fillvalue=None) + ) + ) + ) # type: ignore[return-value] # mypy not smart enough + return d1, d2 def _merge_domains( - original_domains: Dict[str, SymbolicDomain], additional_domains: Dict[str, SymbolicDomain] -) -> Dict[str, SymbolicDomain]: + original_domains: ACCESSED_DOMAINS, + additional_domains: ACCESSED_DOMAINS, +) -> ACCESSED_DOMAINS: new_domains = {**original_domains} - for key, value in additional_domains.items(): - if key in original_domains: - new_domains[key] = domain_union([original_domains[key], value]) - else: - new_domains[key] = value + + for key, domain in additional_domains.items(): + original_domain, domain = canonicalize_domain_structure( + original_domains.get(key, None), domain + ) + new_domains[key] = tree_map(_domain_union_with_none)(original_domain, domain) return new_domains -def extract_shifts_and_translate_domains( +def extract_accessed_domains( stencil: itir.Expr, input_ids: list[str], target_domain: SymbolicDomain, - offset_provider: Dict[str, Dimension], - accessed_domains: Dict[str, SymbolicDomain], -): + offset_provider: common.OffsetProvider, +) -> ACCESSED_DOMAINS: + accessed_domains: dict[str, SymbolicDomain | None] = {} + shifts_results = trace_shifts.trace_stencil(stencil, num_args=len(input_ids)) for in_field_id, shifts_list in zip(input_ids, shifts_results, strict=True): new_domains = [ SymbolicDomain.translate(target_domain, shift, offset_provider) for shift in shifts_list ] - if new_domains: - accessed_domains[in_field_id] = domain_union(new_domains) + # `None` means field is never accessed + accessed_domains[in_field_id] = _domain_union_with_none( + accessed_domains.get(in_field_id, None), *new_domains + ) + + return typing.cast(ACCESSED_DOMAINS, accessed_domains) def infer_as_fieldop( applied_fieldop: itir.FunCall, - target_domain: SymbolicDomain | itir.FunCall, - offset_provider: Dict[str, Dimension], -) -> Tuple[itir.FunCall, Dict[str, SymbolicDomain]]: + target_domain: DOMAIN, + offset_provider: common.OffsetProvider, +) -> tuple[itir.FunCall, ACCESSED_DOMAINS]: assert isinstance(applied_fieldop, itir.FunCall) assert cpm.is_call_to(applied_fieldop.fun, "as_fieldop") + if target_domain is None: + raise ValueError("'target_domain' cannot be 'None'.") + if not isinstance(target_domain, SymbolicDomain): + raise ValueError("'target_domain' needs to be a 'SymbolicDomain'.") # `as_fieldop(stencil)(inputs...)` stencil, inputs = applied_fieldop.fun.args[0], applied_fieldop.args @@ -60,7 +139,6 @@ def infer_as_fieldop( assert not isinstance(stencil, itir.Lambda) or len(stencil.params) == len(applied_fieldop.args) input_ids: list[str] = [] - accessed_domains: Dict[str, SymbolicDomain] = {} # Assign ids for all inputs to `as_fieldop`. `SymRef`s stay as is, nested `as_fieldop` get a # temporary id. @@ -71,31 +149,22 @@ def infer_as_fieldop( elif isinstance(in_field, itir.SymRef): id_ = in_field.id else: - raise ValueError(f"Unsupported type {type(in_field)}") + raise ValueError(f"Unsupported expression of type '{type(in_field)}'.") input_ids.append(id_) - if isinstance(target_domain, itir.FunCall): - target_domain = SymbolicDomain.from_expr(target_domain) - - extract_shifts_and_translate_domains( - stencil, input_ids, target_domain, offset_provider, accessed_domains + accessed_domains: ACCESSED_DOMAINS = extract_accessed_domains( + stencil, input_ids, target_domain, offset_provider ) - # Recursively infer domain of inputs and update domain arg of nested `as_fieldops` + # Recursively infer domain of inputs and update domain arg of nested `as_fieldop`s transformed_inputs: list[itir.Expr] = [] for in_field_id, in_field in zip(input_ids, inputs): - if isinstance(in_field, itir.FunCall): - transformed_input, accessed_domains_tmp = infer_as_fieldop( - in_field, accessed_domains[in_field_id], offset_provider - ) - transformed_inputs.append(transformed_input) + transformed_input, accessed_domains_tmp = infer_expr( + in_field, accessed_domains[in_field_id], offset_provider + ) + transformed_inputs.append(transformed_input) - # Merge accessed_domains and accessed_domains_tmp - accessed_domains = _merge_domains(accessed_domains, accessed_domains_tmp) - elif isinstance(in_field, itir.SymRef) or isinstance(in_field, itir.Literal): - transformed_inputs.append(in_field) - else: - raise ValueError(f"Unsupported type {type(in_field)}") + accessed_domains = _merge_domains(accessed_domains, accessed_domains_tmp) transformed_call = im.as_fieldop(stencil, SymbolicDomain.as_expr(target_domain))( *transformed_inputs @@ -110,81 +179,159 @@ def infer_as_fieldop( return transformed_call, accessed_domains_without_tmp -def _validate_temporary_usage(body: list[itir.Stmt], temporaries: list[str]): - assigned_targets = set() - for stmt in body: - assert isinstance(stmt, itir.SetAt) # TODO: extend for if-statements when they land - assert isinstance( - stmt.target, itir.SymRef - ) # TODO: stmt.target can be an expr, e.g. make_tuple - if stmt.target.id in assigned_targets: - raise ValueError("Temporaries can only be used once within a program.") - if stmt.target.id in temporaries: - assigned_targets.add(stmt.target.id) +def infer_let( + let_expr: itir.FunCall, + input_domain: DOMAIN, + offset_provider: common.OffsetProvider, +) -> tuple[itir.FunCall, ACCESSED_DOMAINS]: + assert cpm.is_let(let_expr) + assert isinstance(let_expr.fun, itir.Lambda) # just to make mypy happy + transformed_calls_expr, accessed_domains = infer_expr( + let_expr.fun.expr, input_domain, offset_provider + ) + + let_params = {param_sym.id for param_sym in let_expr.fun.params} + accessed_domains_let_args, accessed_domains_outer = split_dict_by_key( + lambda k: k in let_params, accessed_domains + ) + + transformed_calls_args: list[itir.Expr] = [] + for param, arg in zip(let_expr.fun.params, let_expr.args, strict=True): + transformed_calls_arg, accessed_domains_arg = infer_expr( + arg, + accessed_domains_let_args.get( + param.id, + None, + ), + offset_provider, + ) + accessed_domains_outer = _merge_domains(accessed_domains_outer, accessed_domains_arg) + transformed_calls_args.append(transformed_calls_arg) + + transformed_call = im.let( + *( + (str(param.id), call) + for param, call in zip(let_expr.fun.params, transformed_calls_args, strict=True) + ) + )(transformed_calls_expr) + + return transformed_call, accessed_domains_outer + + +def infer_make_tuple( + expr: itir.Expr, + domain: DOMAIN, + offset_provider: common.OffsetProvider, +) -> tuple[itir.Expr, ACCESSED_DOMAINS]: + assert cpm.is_call_to(expr, "make_tuple") + infered_args_expr = [] + actual_domains: ACCESSED_DOMAINS = {} + if not isinstance(domain, tuple): + # promote domain to a tuple of domains such that it has the same structure as + # the expression + # TODO(tehrengruber): Revisit. Still open how to handle IR in this case example: + # out @ c⟨ IDimₕ: [0, __out_size_0) ⟩ ← {__sym_1, __sym_2}; + domain = (domain,) * len(expr.args) + assert len(expr.args) >= len(domain) + # There may be less domains than tuple args, pad the domain with `None` in that case. + # e.g. `im.tuple_get(0, im.make_tuple(a, b), domain=domain)` + domain = (*domain, *(None for _ in range(len(expr.args) - len(domain)))) + for i, arg in enumerate(expr.args): + infered_arg_expr, actual_domains_arg = infer_expr(arg, domain[i], offset_provider) + infered_args_expr.append(infered_arg_expr) + actual_domains = _merge_domains(actual_domains, actual_domains_arg) + return im.call(expr.fun)(*infered_args_expr), actual_domains + + +def infer_tuple_get( + expr: itir.Expr, + domain: DOMAIN, + offset_provider: common.OffsetProvider, +) -> tuple[itir.Expr, ACCESSED_DOMAINS]: + assert cpm.is_call_to(expr, "tuple_get") + actual_domains: ACCESSED_DOMAINS = {} + idx, tuple_arg = expr.args + assert isinstance(idx, itir.Literal) + child_domain = tuple(None if i != int(idx.value) else domain for i in range(int(idx.value) + 1)) + infered_arg_expr, actual_domains_arg = infer_expr(tuple_arg, child_domain, offset_provider) + + infered_args_expr = im.tuple_get(idx.value, infered_arg_expr) + actual_domains = _merge_domains(actual_domains, actual_domains_arg) + return infered_args_expr, actual_domains + + +def infer_cond( + expr: itir.Expr, + domain: DOMAIN, + offset_provider: common.OffsetProvider, +) -> tuple[itir.Expr, ACCESSED_DOMAINS]: + assert cpm.is_call_to(expr, "cond") + infered_args_expr = [] + actual_domains: ACCESSED_DOMAINS = {} + cond, true_val, false_val = expr.args + for arg in [true_val, false_val]: + infered_arg_expr, actual_domains_arg = infer_expr(arg, domain, offset_provider) + infered_args_expr.append(infered_arg_expr) + actual_domains = _merge_domains(actual_domains, actual_domains_arg) + return im.call(expr.fun)(cond, *infered_args_expr), actual_domains + + +def infer_expr( + expr: itir.Expr, + domain: DOMAIN, + offset_provider: common.OffsetProvider, +) -> tuple[itir.Expr, ACCESSED_DOMAINS]: + if isinstance(expr, itir.SymRef): + return expr, {str(expr.id): domain} + elif isinstance(expr, itir.Literal): + return expr, {} + elif cpm.is_applied_as_fieldop(expr): + return infer_as_fieldop(expr, domain, offset_provider) + elif cpm.is_let(expr): + return infer_let(expr, domain, offset_provider) + elif cpm.is_call_to(expr, "make_tuple"): + return infer_make_tuple(expr, domain, offset_provider) + elif cpm.is_call_to(expr, "tuple_get"): + return infer_tuple_get(expr, domain, offset_provider) + elif cpm.is_call_to(expr, "cond"): + return infer_cond(expr, domain, offset_provider) + elif ( + cpm.is_call_to(expr, itir.ARITHMETIC_BUILTINS) + or cpm.is_call_to(expr, itir.TYPEBUILTINS) + or cpm.is_call_to(expr, "cast_") + ): + return expr, {} + else: + raise ValueError(f"Unsupported expression: {expr}") def infer_program( program: itir.Program, - offset_provider: Dict[str, Dimension], + offset_provider: dict[str, Dimension], ) -> itir.Program: - accessed_domains: dict[str, SymbolicDomain] = {} transformed_set_ats: list[itir.SetAt] = [] + assert ( + not program.function_definitions + ), "Domain propagation does not support function definitions." - temporaries: list[str] = [tmp.id for tmp in program.declarations] - - _validate_temporary_usage(program.body, temporaries) - - for set_at in reversed(program.body): + for set_at in program.body: assert isinstance(set_at, itir.SetAt) - if isinstance(set_at.expr, itir.SymRef): - transformed_set_ats.insert(0, set_at) - continue - assert isinstance(set_at.expr, itir.FunCall) - assert cpm.is_call_to(set_at.expr.fun, "as_fieldop") - assert isinstance( - set_at.target, itir.SymRef - ) # TODO: stmt.target can be an expr, e.g. make_tuple - if set_at.target.id in temporaries: - # ignore temporaries as their domain is the `AUTO_DOMAIN` placeholder - assert set_at.domain == AUTO_DOMAIN - else: - accessed_domains[set_at.target.id] = SymbolicDomain.from_expr(set_at.domain) - transformed_as_fieldop, current_accessed_domains = infer_as_fieldop( - set_at.expr, accessed_domains[set_at.target.id], offset_provider + transformed_call, _unused_domain = infer_expr( + set_at.expr, SymbolicDomain.from_expr(set_at.domain), offset_provider ) - transformed_set_ats.insert( - 0, + transformed_set_ats.append( itir.SetAt( - expr=transformed_as_fieldop, - domain=SymbolicDomain.as_expr(accessed_domains[set_at.target.id]), + expr=transformed_call, + domain=set_at.domain, target=set_at.target, ), ) - for field in current_accessed_domains: - if field in accessed_domains: - # multiple accesses to the same field -> compute union of accessed domains - if field in temporaries: - accessed_domains[field] = domain_union( - [accessed_domains[field], current_accessed_domains[field]] - ) - else: - # TODO(tehrengruber): if domain_ref is an external field the domain must - # already be larger. This should be checked, but would require additions - # to the IR. - pass - else: - accessed_domains[field] = current_accessed_domains[field] - - new_declarations = program.declarations - for temporary in new_declarations: - temporary.domain = SymbolicDomain.as_expr(accessed_domains[temporary.id]) - return itir.Program( id=program.id, function_definitions=program.function_definitions, params=program.params, - declarations=new_declarations, + declarations=program.declarations, body=transformed_set_ats, ) diff --git a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py index ad98ade084..e9a0ad16c4 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py @@ -15,6 +15,7 @@ from gt4py.eve.concepts import SymbolName from gt4py.next import common from gt4py.next.iterator import ir as itir +from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm from gt4py.next.iterator.type_system import inference as itir_type_inference from gt4py.next.program_processors.codegens.gtfn.gtfn_ir import ( Backend, @@ -208,14 +209,6 @@ def _bool_from_literal(node: itir.Node) -> bool: return node.value == "True" -def _is_applied_as_fieldop(arg: itir.Expr) -> TypeGuard[itir.FunCall]: - return ( - isinstance(arg, itir.FunCall) - and isinstance(arg.fun, itir.FunCall) - and arg.fun.fun == itir.SymRef(id="as_fieldop") - ) - - class _CannonicalizeUnstructuredDomain(eve.NodeTranslator): def visit_FunCall(self, node: itir.FunCall) -> itir.FunCall: if node.fun == itir.SymRef(id="unstructured_domain"): @@ -583,7 +576,7 @@ def visit_Stmt(self, node: itir.Stmt, **kwargs: Any) -> None: def visit_SetAt( self, node: itir.SetAt, *, extracted_functions: list, **kwargs: Any ) -> Union[StencilExecution, ScanExecution]: - assert _is_applied_as_fieldop(node.expr) + assert cpm.is_applied_as_fieldop(node.expr) stencil = node.expr.fun.args[0] # type: ignore[attr-defined] # checked in assert domain = node.domain inputs = node.expr.args diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py index eb3deb1a85..51932e0aa0 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py @@ -8,20 +8,20 @@ # TODO(SF-N): test scan operator +import pytest import numpy as np -from typing import Iterable +from typing import Iterable, Optional, Literal, Union from gt4py import eve from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im from gt4py.next.iterator import ir as itir -from gt4py.next.iterator.transforms.infer_domain import infer_as_fieldop, infer_program -from gt4py.next.iterator.transforms.global_tmps import SymbolicDomain, AUTO_DOMAIN -import pytest -from gt4py.eve.extended_typing import Dict -from gt4py.next.common import Dimension, DimensionKind +from gt4py.next.iterator.transforms import infer_domain +from gt4py.next.iterator.transforms.global_tmps import SymbolicDomain +from gt4py.next.common import Dimension from gt4py.next import common, NeighborTableOffsetProvider from gt4py.next.type_system import type_specifications as ts from gt4py.next.iterator.transforms.constant_folding import ConstantFolding +from gt4py.next import utils float_type = ts.ScalarType(kind=ts.ScalarKind.FLOAT64) IDim = common.Dimension(value="IDim", kind=common.DimensionKind.HORIZONTAL) @@ -33,11 +33,7 @@ @pytest.fixture def offset_provider(): - return { - "Ioff": IDim, - "Joff": JDim, - "Koff": KDim, - } + return {"Ioff": IDim, "Joff": JDim, "Koff": KDim} @pytest.fixture @@ -52,44 +48,66 @@ def unstructured_offset_provider(): } -def run_test_as_fieldop( - stencil: itir.Lambda, +def premap_field( + field: itir.Expr, dim: str, offset: int, domain: Optional[itir.FunCall] = None +) -> itir.Expr: + return im.as_fieldop(im.lambda_("it")(im.deref(im.shift(dim, offset)("it"))), domain)(field) + + +def setup_test_as_fieldop( + stencil: itir.Lambda | Literal["deref"], domain: itir.FunCall, - expected_domain_dict: Dict[str, Dict[str | Dimension, tuple[itir.Expr, itir.Expr]]], - offset_provider: Dict[str, Dimension], *, refs: Iterable[itir.SymRef] = None, - domain_type: str = common.GridType.CARTESIAN, -) -> None: +) -> tuple[itir.FunCall, itir.FunCall]: if refs is None: + assert isinstance(stencil, itir.Lambda) refs = [f"in_field{i+1}" for i in range(0, len(stencil.params))] testee = im.as_fieldop(stencil)(*refs) expected = im.as_fieldop(stencil, domain)(*refs) - - actual_call, actual_domains = infer_as_fieldop( - testee, SymbolicDomain.from_expr(domain), offset_provider - ) - - folded_domains = constant_fold_accessed_domains(actual_domains) - expected_domains = { - ref: SymbolicDomain.from_expr(im.domain(domain_type, d)) - for ref, d in expected_domain_dict.items() - } - - assert actual_call == expected - assert folded_domains == expected_domains + return testee, expected def run_test_program( - testee: itir.Program, expected: itir.Program, offset_provider: dict[str, Dimension] + testee: itir.Program, expected: itir.Program, offset_provider: common.OffsetProvider ) -> None: - actual_program = infer_program(testee, offset_provider) + actual_program = infer_domain.infer_program(testee, offset_provider) folded_program = constant_fold_domain_exprs(actual_program) assert folded_program == expected +def run_test_expr( + testee: itir.FunCall, + expected: itir.FunCall, + domain: itir.FunCall, + expected_domains: dict[str, itir.Expr | dict[str | Dimension, tuple[itir.Expr, itir.Expr]]], + offset_provider: common.OffsetProvider, +): + actual_call, actual_domains = infer_domain.infer_expr( + testee, SymbolicDomain.from_expr(domain), offset_provider + ) + folded_call = constant_fold_domain_exprs(actual_call) + folded_domains = constant_fold_accessed_domains(actual_domains) if actual_domains else None + + grid_type = str(domain.fun.id) + + def canonicalize_domain(d): + if isinstance(d, dict): + return im.domain(grid_type, d) + elif isinstance(d, itir.FunCall): + return d + elif d is None: + return None + raise AssertionError() + + expected_domains = {ref: canonicalize_domain(d) for ref, d in expected_domains.items()} + + assert folded_call == expected + assert folded_domains == expected_domains + + class _ConstantFoldDomainsExprs(eve.NodeTranslator): def visit_FunCall(self, node: itir.FunCall): if cpm.is_call_to(node, ("cartesian_domain", "unstructured_domain")): @@ -101,18 +119,56 @@ def constant_fold_domain_exprs(arg: itir.Node) -> itir.Node: return _ConstantFoldDomainsExprs().visit(arg) -def constant_fold_accessed_domains(domains: Dict[str, SymbolicDomain]) -> Dict[str, SymbolicDomain]: - return { - k: SymbolicDomain.from_expr(constant_fold_domain_exprs(v.as_expr())) - for k, v in domains.items() - } +def constant_fold_accessed_domains( + domains: infer_domain.ACCESSED_DOMAINS, +) -> infer_domain.ACCESSED_DOMAINS: + def fold_domain(domain: SymbolicDomain | None): + if domain is None: + return domain + return constant_fold_domain_exprs(domain.as_expr()) + + return {k: utils.tree_map(fold_domain)(v) for k, v in domains.items()} + + +def translate_domain( + domain: itir.FunCall, + shifts: dict[Union[common.Dimension, str], tuple[itir.Expr, itir.Expr]], + offset_provider: common.OffsetProvider, +) -> SymbolicDomain: + shift_tuples = [ + ( + im.ensure_offset( + itir.AxisLiteral(value=d.value, kind=d.kind) + if isinstance(d, common.Dimension) + else itir.AxisLiteral(value=d) + ), + im.ensure_offset(r), + ) + for d, r in shifts.items() + ] + + shift_list = [item for sublist in shift_tuples for item in sublist] + + translated_domain_expr = SymbolicDomain.translate( + SymbolicDomain.from_expr(domain), shift_list, offset_provider + ) + + return constant_fold_domain_exprs(translated_domain_expr.as_expr()) def test_forward_difference_x(offset_provider): stencil = im.lambda_("arg0")(im.minus(im.deref(im.shift("Ioff", 1)("arg0")), im.deref("arg0"))) domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) - expected_accessed_domains = {"in_field1": {IDim: (0, 12)}} - run_test_as_fieldop(stencil, domain, expected_accessed_domains, offset_provider) + expected_domains = {"in_field1": {IDim: (0, 12)}} + testee, expected = setup_test_as_fieldop(stencil, domain) + run_test_expr(testee, expected, domain, expected_domains, offset_provider) + + +def test_deref(offset_provider): + domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) + expected_domains = {"in_field": {IDim: (0, 11)}} + testee, expected = setup_test_as_fieldop("deref", domain, refs=["in_field"]) + run_test_expr(testee, expected, domain, expected_domains, offset_provider) def test_multi_length_shift(offset_provider): @@ -129,37 +185,21 @@ def test_multi_length_shift(offset_provider): ) ) domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) - expected_accessed_domains = {"in_field1": {IDim: (3, 14)}} - run_test_as_fieldop(stencil, domain, expected_accessed_domains, offset_provider) - - -def test_unused_input(offset_provider): - stencil = im.lambda_("arg0", "arg1")(im.deref("arg0")) - - domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) - expected_accessed_domains = { - "in_field1": {IDim: (0, 11)}, - } - run_test_as_fieldop( - stencil, - domain, - expected_accessed_domains, - offset_provider, - ) + expected_domains = {"in_field1": {IDim: (3, 14)}} + testee, expected = setup_test_as_fieldop(stencil, domain) + run_test_expr(testee, expected, domain, expected_domains, offset_provider) def test_unstructured_shift(unstructured_offset_provider): stencil = im.lambda_("arg0")(im.deref(im.shift("E2V", 1)("arg0"))) domain = im.domain(common.GridType.UNSTRUCTURED, {Edge: (0, 1)}) - expected_accessed_domains = {"in_field1": {Vertex: (0, 2)}} + expected_domains = {"in_field1": {Vertex: (0, 2)}} - run_test_as_fieldop( + testee, expected = setup_test_as_fieldop( stencil, domain, - expected_accessed_domains, - unstructured_offset_provider, - domain_type=common.GridType.UNSTRUCTURED, ) + run_test_expr(testee, expected, domain, expected_domains, unstructured_offset_provider) def test_laplace(offset_provider): @@ -179,9 +219,10 @@ def test_laplace(offset_provider): ) ) domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11), JDim: (0, 7)}) - expected_accessed_domains = {"in_field1": {IDim: (-1, 12), JDim: (-1, 8)}} + expected_domains = {"in_field1": {IDim: (-1, 12), JDim: (-1, 8)}} - run_test_as_fieldop(stencil, domain, expected_accessed_domains, offset_provider) + testee, expected = setup_test_as_fieldop(stencil, domain) + run_test_expr(testee, expected, domain, expected_domains, offset_provider) def test_shift_x_y_two_inputs(offset_provider): @@ -192,16 +233,15 @@ def test_shift_x_y_two_inputs(offset_provider): ) ) domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11), JDim: (0, 7)}) - expected_accessed_domains = { + expected_domains = { "in_field1": {IDim: (-1, 10), JDim: (0, 7)}, "in_field2": {IDim: (0, 11), JDim: (1, 8)}, } - run_test_as_fieldop( + testee, expected = setup_test_as_fieldop( stencil, domain, - expected_accessed_domains, - offset_provider, ) + run_test_expr(testee, expected, domain, expected_domains, offset_provider) def test_shift_x_y_two_inputs_literal(offset_provider): @@ -212,16 +252,15 @@ def test_shift_x_y_two_inputs_literal(offset_provider): ) ) domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11), JDim: (0, 7)}) - expected_accessed_domains = { + expected_domains = { "in_field1": {IDim: (-1, 10), JDim: (0, 7)}, } - run_test_as_fieldop( + testee, expected = setup_test_as_fieldop( stencil, domain, - expected_accessed_domains, - offset_provider, refs=(im.ref("in_field1"), 2), ) + run_test_expr(testee, expected, domain, expected_domains, offset_provider) def test_shift_x_y_z_three_inputs(offset_provider): @@ -234,36 +273,40 @@ def test_shift_x_y_z_three_inputs(offset_provider): im.deref(im.shift("Koff", -1)("arg2")), ) ) - domain_dict = { - IDim: (0, 11), - JDim: (0, 7), - KDim: (0, 3), - } - expected_domain_dict = { - "in_field1": { - IDim: (1, 12), - JDim: (0, 7), - KDim: (0, 3), - }, - "in_field2": { - IDim: (0, 11), - JDim: (1, 8), - KDim: (0, 3), - }, - "in_field3": { - IDim: (0, 11), - JDim: (0, 7), - KDim: (-1, 2), - }, + domain_dict = {IDim: (0, 11), JDim: (0, 7), KDim: (0, 3)} + expected_domains = { + "in_field1": {IDim: (1, 12), JDim: (0, 7), KDim: (0, 3)}, + "in_field2": {IDim: (0, 11), JDim: (1, 8), KDim: (0, 3)}, + "in_field3": {IDim: (0, 11), JDim: (0, 7), KDim: (-1, 2)}, } - run_test_as_fieldop( + testee, expected = setup_test_as_fieldop( stencil, im.domain(common.GridType.CARTESIAN, domain_dict), - expected_domain_dict, + ) + run_test_expr( + testee, + expected, + im.domain(common.GridType.CARTESIAN, domain_dict), + expected_domains, offset_provider, ) +def test_two_params_same_arg(offset_provider): + stencil = im.lambda_("arg0", "arg1")( + im.plus( + im.deref("arg0"), + im.deref(im.shift("Ioff", 1)("arg1")), + ) + ) + domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) + expected_domains = { + "in_field": {IDim: (0, 12)}, + } + testee, expected = setup_test_as_fieldop(stencil, domain, refs=["in_field", "in_field"]) + run_test_expr(testee, expected, domain, expected_domains, offset_provider) + + def test_nested_stencils(offset_provider): inner_stencil = im.lambda_("arg0_tmp", "arg1_tmp")( im.plus( @@ -280,8 +323,8 @@ def test_nested_stencils(offset_provider): tmp = im.as_fieldop(inner_stencil)(im.ref("in_field1"), im.ref("in_field2")) testee = im.as_fieldop(stencil)(im.ref("in_field1"), tmp) - domain_inner = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11), JDim: (-1, 6)}) domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11), JDim: (0, 7)}) + domain_inner = translate_domain(domain, {"Ioff": 0, "Joff": -1}, offset_provider) expected_inner = im.as_fieldop(inner_stencil, domain_inner)( im.ref("in_field1"), im.ref("in_field2") @@ -289,14 +332,10 @@ def test_nested_stencils(offset_provider): expected = im.as_fieldop(stencil, domain)(im.ref("in_field1"), expected_inner) expected_domains = { - "in_field1": SymbolicDomain.from_expr( - im.domain(common.GridType.CARTESIAN, {IDim: (1, 12), JDim: (-1, 7)}) - ), - "in_field2": SymbolicDomain.from_expr( - im.domain(common.GridType.CARTESIAN, {IDim: (0, 11), JDim: (-2, 5)}) - ), + "in_field1": im.domain(common.GridType.CARTESIAN, {IDim: (1, 12), JDim: (-1, 7)}), + "in_field2": translate_domain(domain, {"Ioff": 0, "Joff": -2}, offset_provider), } - actual_call, actual_domains = infer_as_fieldop( + actual_call, actual_domains = infer_domain.infer_expr( testee, SymbolicDomain.from_expr(domain), offset_provider ) folded_domains = constant_fold_accessed_domains(actual_domains) @@ -332,17 +371,15 @@ def test_nested_stencils_n_times(offset_provider, iterations): testee = testee expected_domains = { - "in_field1": SymbolicDomain.from_expr( - im.domain(common.GridType.CARTESIAN, {IDim: (1, 12), JDim: (0, 7 + iterations - 1)}) + "in_field1": im.domain( + common.GridType.CARTESIAN, {IDim: (1, 12), JDim: (0, 7 + iterations - 1)} ), - "in_field2": SymbolicDomain.from_expr( - im.domain( - common.GridType.CARTESIAN, {IDim: (0, 11), JDim: (iterations, 7 + iterations)} - ) + "in_field2": im.domain( + common.GridType.CARTESIAN, {IDim: (0, 11), JDim: (iterations, 7 + iterations)} ), } - actual_call, actual_domains = infer_as_fieldop( + actual_call, actual_domains = infer_domain.infer_expr( testee, SymbolicDomain.from_expr(domain), offset_provider ) @@ -352,24 +389,45 @@ def test_nested_stencils_n_times(offset_provider, iterations): assert folded_domains == expected_domains +def test_unused_input(offset_provider): + stencil = im.lambda_("arg0", "arg1")(im.deref("arg0")) + + domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) + expected_domains = {"in_field1": {IDim: (0, 11)}, "in_field2": None} + testee, expected = setup_test_as_fieldop( + stencil, + domain, + ) + run_test_expr(testee, expected, domain, expected_domains, offset_provider) + + +def test_let_unused_field(offset_provider): + testee = im.let("a", "c")("b") + domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) + expected = im.let("a", "c")("b") + expected_domains = {"b": {IDim: (0, 11)}, "c": None} + + run_test_expr(testee, expected, domain, expected_domains, offset_provider) + + def test_program(offset_provider): stencil = im.lambda_("arg0")(im.minus(im.deref(im.shift("Ioff", 1)("arg0")), im.deref("arg0"))) applied_as_fieldop_tmp = im.as_fieldop(stencil)(im.ref("in_field")) applied_as_fieldop = im.as_fieldop(stencil)(im.ref("tmp")) - domain_tmp = im.domain(common.GridType.CARTESIAN, {IDim: (0, 12)}) domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) + domain_tmp = im.domain(common.GridType.CARTESIAN, {IDim: (0, 12)}) - params = [im.sym(name) for name in ["in_field", "out_field", "_gtmp_auto_domain"]] + params = [im.sym(name) for name in ["in_field", "out_field"]] testee = itir.Program( id="forward_diff_with_tmp", function_definitions=[], params=params, - declarations=[itir.Temporary(id="tmp", domain=AUTO_DOMAIN, dtype=float_type)], + declarations=[itir.Temporary(id="tmp", domain=domain_tmp, dtype=float_type)], body=[ - itir.SetAt(expr=applied_as_fieldop_tmp, domain=AUTO_DOMAIN, target=im.ref("tmp")), + itir.SetAt(expr=applied_as_fieldop_tmp, domain=domain_tmp, target=im.ref("tmp")), itir.SetAt(expr=applied_as_fieldop, domain=domain, target=im.ref("out_field")), ], ) @@ -391,155 +449,543 @@ def test_program(offset_provider): run_test_program(testee, expected, offset_provider) -def test_program_two_tmps(offset_provider): - stencil = im.lambda_("arg0")(im.minus(im.deref(im.shift("Ioff", 1)("arg0")), im.deref("arg0"))) - - as_fieldop_tmp1 = im.as_fieldop(stencil)(im.ref("in_field")) - as_fieldop_tmp2 = im.as_fieldop(stencil)(im.ref("tmp1")) - as_fieldop = im.as_fieldop(stencil)(im.ref("tmp2")) - +def test_program_make_tuple(offset_provider): domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) - domain_tmp1 = im.domain(common.GridType.CARTESIAN, {IDim: (0, 13)}) - domain_tmp2 = im.domain(common.GridType.CARTESIAN, {IDim: (0, 12)}) - - params = [im.sym(name) for name in ["in_field", "out_field", "_gtmp_auto_domain"]] + params = [im.sym(name) for name in ["in_field", "out_field"]] testee = itir.Program( - id="forward_diff_with_two_tmps", + id="make_tuple_prog", function_definitions=[], params=params, - declarations=[ - itir.Temporary(id="tmp1", domain=AUTO_DOMAIN, dtype=float_type), - itir.Temporary(id="tmp2", domain=AUTO_DOMAIN, dtype=float_type), - ], + declarations=[], body=[ - itir.SetAt(expr=as_fieldop_tmp1, domain=AUTO_DOMAIN, target=im.ref("tmp1")), - itir.SetAt(expr=as_fieldop_tmp2, domain=AUTO_DOMAIN, target=im.ref("tmp2")), - itir.SetAt(expr=as_fieldop, domain=domain, target=im.ref("out_field")), + itir.SetAt( + expr=im.make_tuple(im.as_fieldop("deref")("in_field"), "in_field"), + domain=domain, + target=im.ref("out_field"), + ), ], ) - expected_expr_tmp1 = im.as_fieldop(stencil, domain_tmp1)(im.ref("in_field")) - expected_expr_tmp2 = im.as_fieldop(stencil, domain_tmp2)(im.ref("tmp1")) - expected_expr = im.as_fieldop(stencil, domain)(im.ref("tmp2")) - expected = itir.Program( - id="forward_diff_with_two_tmps", + id="make_tuple_prog", function_definitions=[], params=params, - declarations=[ - itir.Temporary(id="tmp1", domain=domain_tmp1, dtype=float_type), - itir.Temporary(id="tmp2", domain=domain_tmp2, dtype=float_type), - ], + declarations=[], body=[ - itir.SetAt(expr=expected_expr_tmp1, domain=domain_tmp1, target=im.ref("tmp1")), - itir.SetAt(expr=expected_expr_tmp2, domain=domain_tmp2, target=im.ref("tmp2")), - itir.SetAt(expr=expected_expr, domain=domain, target=im.ref("out_field")), + itir.SetAt( + expr=im.make_tuple(im.as_fieldop("deref", domain)("in_field"), "in_field"), + domain=domain, + target=im.ref("out_field"), + ), ], ) run_test_program(testee, expected, offset_provider) -@pytest.mark.xfail(raises=ValueError) -def test_program_ValueError(offset_provider): - with pytest.raises(ValueError, match=r"Temporaries can only be used once within a program."): - stencil = im.lambda_("arg0")(im.deref("arg0")) +def test_cond(offset_provider): + stencil1 = im.lambda_("arg0")(im.minus(im.deref(im.shift("Ioff", 1)("arg0")), im.deref("arg0"))) + field_1 = im.as_fieldop(stencil1)(im.ref("in_field1")) + tmp_stencil2 = im.lambda_("arg0_tmp", "arg1_tmp")( + im.plus( + im.deref(im.shift("Ioff", 1)("arg0_tmp")), + im.deref(im.shift("Ioff", -1)("arg1_tmp")), + ) + ) + stencil2 = im.lambda_("arg0", "arg1")( + im.plus( + im.deref(im.shift("Ioff", 1)("arg0")), + im.deref(im.shift("Ioff", -1)("arg1")), + ) + ) + tmp2 = im.as_fieldop(tmp_stencil2)(im.ref("in_field1"), im.ref("in_field2")) + field_2 = im.as_fieldop(stencil2)(im.ref("in_field2"), tmp2) - as_fieldop_tmp = im.as_fieldop(stencil)(im.ref("in_field")) - as_fieldop = im.as_fieldop(stencil)(im.ref("tmp")) + cond = im.deref("cond_") - domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) + testee = im.cond(cond, field_1, field_2) - params = [im.sym(name) for name in ["in_field", "out_field", "_gtmp_auto_domain"]] + domain = im.domain(common.GridType.CARTESIAN, {"IDim": (0, 11)}) + domain_tmp = translate_domain(domain, {"Ioff": -1}, offset_provider) + expected_domains_dict = {"in_field1": {IDim: (0, 12)}, "in_field2": {IDim: (-2, 12)}} + expected_tmp2 = im.as_fieldop(tmp_stencil2, domain_tmp)( + im.ref("in_field1"), im.ref("in_field2") + ) + expected_field_1 = im.as_fieldop(stencil1, domain)(im.ref("in_field1")) + expected_field_2 = im.as_fieldop(stencil2, domain)(im.ref("in_field2"), expected_tmp2) - infer_program( - itir.Program( - id="forward_diff_with_tmp", - function_definitions=[], - params=params, - declarations=[itir.Temporary(id="tmp", domain=AUTO_DOMAIN, dtype=float_type)], - body=[ - # target occurs twice here which is prohibited - itir.SetAt(expr=as_fieldop_tmp, domain=AUTO_DOMAIN, target=im.ref("tmp")), - itir.SetAt(expr=as_fieldop_tmp, domain=AUTO_DOMAIN, target=im.ref("tmp")), - itir.SetAt(expr=as_fieldop, domain=domain, target=im.ref("out_field")), - ], - ), - offset_provider, + expected = im.cond(cond, expected_field_1, expected_field_2) + + actual_call, actual_domains = infer_domain.infer_expr( + testee, SymbolicDomain.from_expr(domain), offset_provider + ) + + folded_domains = constant_fold_accessed_domains(actual_domains) + expected_domains = { + ref: im.domain(common.GridType.CARTESIAN, d) for ref, d in expected_domains_dict.items() + } + folded_call = constant_fold_domain_exprs(actual_call) + assert folded_call == expected + assert folded_domains == expected_domains + + +def test_let_scalar_expr(offset_provider): + testee = im.let("a", 1)(im.op_as_fieldop(im.plus)("a", "b")) + domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) + expected = im.let("a", 1)(im.op_as_fieldop(im.plus, domain)("a", "b")) + expected_domains = {"b": {IDim: (0, 11)}} + + run_test_expr(testee, expected, domain, expected_domains, offset_provider) + + +def test_simple_let(offset_provider): + testee = im.let("a", premap_field("in_field", "Ioff", 1))("a") + domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) + expected = im.let("a", premap_field("in_field", "Ioff", 1, domain))("a") + + expected_domains = {"in_field": translate_domain(domain, {"Ioff": 1}, offset_provider)} + run_test_expr(testee, expected, domain, expected_domains, offset_provider) + + +def test_simple_let2(offset_provider): + testee = im.let("a", "in_field")(premap_field("a", "Ioff", 1)) + domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) + expected = im.let("a", "in_field")(premap_field("a", "Ioff", 1, domain)) + + expected_domains = {"in_field": translate_domain(domain, {"Ioff": 1}, offset_provider)} + run_test_expr(testee, expected, domain, expected_domains, offset_provider) + + +def test_let(offset_provider): + testee = im.let( + "a", + premap_field("in_field", "Ioff", 1), + )(premap_field("a", "Ioff", 1)) + testee2 = premap_field(premap_field("in_field", "Ioff", 1), "Ioff", 1) + domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) + domain_a = translate_domain(domain, {"Ioff": 1}, offset_provider) + expected = im.let( + "a", + premap_field("in_field", "Ioff", 1, domain_a), + )(premap_field("a", "Ioff", 1, domain)) + expected2 = premap_field(premap_field("in_field", "Ioff", 1, domain_a), "Ioff", 1, domain) + expected_domains = {"in_field": translate_domain(domain, {"Ioff": 2}, offset_provider)} + run_test_expr(testee, expected, domain, expected_domains, offset_provider) + + expected_domains_sym = {"in_field": translate_domain(domain, {"Ioff": 2}, offset_provider)} + actual_call2, actual_domains2 = infer_domain.infer_expr( + testee2, SymbolicDomain.from_expr(domain), offset_provider + ) + folded_domains2 = constant_fold_accessed_domains(actual_domains2) + folded_call2 = constant_fold_domain_exprs(actual_call2) + assert folded_call2 == expected2 + assert expected_domains_sym == folded_domains2 + + +def test_let_two_inputs(offset_provider): + multiply_stencil = im.lambda_("it1", "it2")(im.multiplies_(im.deref("it1"), im.deref("it2"))) + + testee = im.let( + ("inner1", premap_field("in_field1", "Ioff", 1)), + ("inner2", premap_field("in_field2", "Ioff", -1)), + )(im.as_fieldop(multiply_stencil)("inner1", "inner2")) + domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) + domain_p1 = translate_domain(domain, {"Ioff": 1}, offset_provider) + domain_m1 = translate_domain(domain, {"Ioff": -1}, offset_provider) + expected = im.let( + ("inner1", premap_field("in_field1", "Ioff", 1, domain)), + ("inner2", premap_field("in_field2", "Ioff", -1, domain)), + )(im.as_fieldop(multiply_stencil, domain)("inner1", "inner2")) + expected_domains = { + "in_field1": domain_p1, + "in_field2": domain_m1, + } + run_test_expr(testee, expected, domain, expected_domains, offset_provider) + + +def test_nested_let_in_body(offset_provider): + testee = im.let("inner1", premap_field("outer", "Ioff", 1))( + im.let("inner2", premap_field("inner1", "Ioff", 1))(premap_field("inner2", "Ioff", 1)) + ) + + domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) + domain_p1 = translate_domain(domain, {"Ioff": 1}, offset_provider) + domain_p2 = translate_domain(domain, {"Ioff": 2}, offset_provider) + domain_p3 = translate_domain(domain, {"Ioff": 3}, offset_provider) + + expected = im.let( + "inner1", + premap_field("outer", "Ioff", 1, domain_p2), + )( + im.let("inner2", premap_field("inner1", "Ioff", 1, domain_p1))( + premap_field("inner2", "Ioff", 1, domain) ) + ) + expected_domains = {"outer": domain_p3} + run_test_expr(testee, expected, domain, expected_domains, offset_provider) -def test_program_tree_tmps_two_inputs(offset_provider): - stencil = im.lambda_("arg0", "arg1")( - im.minus(im.deref(im.shift("Ioff", 1)("arg0")), im.deref("arg1")) +def test_nested_let_arg(offset_provider): + testee = im.let("a", "in_field")( + im.as_fieldop( + im.lambda_("it1", "it2")( + im.multiplies_(im.deref("it1"), im.deref(im.shift("Ioff", 1)("it2"))) + ) + )("a", "in_field") ) - stencil_tmp = im.lambda_("arg0")( - im.minus(im.deref(im.shift("Ioff", 1)("arg0")), im.deref("arg0")) + + domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) + + expected = im.let("a", "in_field")( + im.as_fieldop( + im.lambda_("it1", "it2")( + im.multiplies_(im.deref("it1"), im.deref(im.shift("Ioff", 1)("it2"))) + ), + domain, + )("a", "in_field") + ) + expected_domains = {"in_field": {IDim: (0, 12)}} + run_test_expr(testee, expected, domain, expected_domains, offset_provider) + + +def test_nested_let_arg_shadowed(offset_provider): + testee = im.let("a", premap_field("in_field", "Ioff", 3))( + im.let("a", premap_field("a", "Ioff", 2))(premap_field("a", "Ioff", 1)) + ) + + domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) + domain_p1 = translate_domain(domain, {"Ioff": 1}, offset_provider) + domain_p3 = translate_domain(domain, {"Ioff": 3}, offset_provider) + domain_p6 = translate_domain(domain, {"Ioff": 6}, offset_provider) + + expected = im.let( + "a", + premap_field("in_field", "Ioff", 3, domain_p3), + )(im.let("a", premap_field("a", "Ioff", 2, domain_p1))(premap_field("a", "Ioff", 1, domain))) + expected_domains = {"in_field": domain_p6} + run_test_expr(testee, expected, domain, expected_domains, offset_provider) + + +def test_nested_let_arg_shadowed2(offset_provider): + # test that if we shadow `in_field1` its accessed domain is not affected by the accesses + # on the shadowed field + testee = im.as_fieldop( + im.lambda_("it1", "it2")(im.multiplies_(im.deref("it1"), im.deref("it2"))) + )( + premap_field("in_field1", "Ioff", 1), # only here we access `in_field1` + im.let("in_field1", "in_field2")("in_field1"), # here we actually access `in_field2` ) - stencil_tmp_minus = im.lambda_("arg0", "arg1")( - im.minus(im.deref(im.shift("Ioff", -1)("arg0")), im.deref("arg1")) + + domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) + domain_p1 = translate_domain(domain, {"Ioff": 1}, offset_provider) + + expected = im.as_fieldop( + im.lambda_("it1", "it2")(im.multiplies_(im.deref("it1"), im.deref("it2"))), domain + )( + premap_field("in_field1", "Ioff", 1, domain), + im.let("in_field1", "in_field2")("in_field1"), ) + expected_domains = {"in_field1": domain_p1, "in_field2": domain} + run_test_expr(testee, expected, domain, expected_domains, offset_provider) - as_fieldop_tmp1 = im.as_fieldop(stencil)(im.ref("in_field1"), im.ref("in_field2")) - as_fieldop_tmp2 = im.as_fieldop(stencil_tmp)(im.ref("tmp1")) - as_fieldop_out1 = im.as_fieldop(stencil_tmp)(im.ref("tmp2")) - as_fieldop_tmp3 = im.as_fieldop(stencil)(im.ref("tmp1"), im.ref("in_field2")) - as_fieldop_out2 = im.as_fieldop(stencil_tmp_minus)(im.ref("tmp2"), im.ref("tmp3")) - domain_tmp1 = im.domain(common.GridType.CARTESIAN, {IDim: (-1, 13)}) - domain_tmp2 = im.domain(common.GridType.CARTESIAN, {IDim: (-1, 12)}) - domain_tmp3 = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) - domain_out = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) - params = [ - im.sym(name) - for name in ["in_field1", "in_field2", "out_field1", "out_field2", "_gtmp_auto_domain"] - ] +def test_double_nested_let_fun_expr(offset_provider): + testee = im.let("inner1", premap_field("outer", "Ioff", 1))( + im.let("inner2", premap_field("inner1", "Ioff", -1))( + im.let("inner3", premap_field("inner2", "Ioff", -1))(premap_field("inner3", "Ioff", 3)) + ) + ) + domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) + domain_p1 = translate_domain(domain, {"Ioff": 1}, offset_provider) + domain_p2 = translate_domain(domain, {"Ioff": 2}, offset_provider) + domain_p3 = translate_domain(domain, {"Ioff": 3}, offset_provider) + + expected = im.let("inner1", premap_field("outer", "Ioff", 1, domain_p1))( + im.let("inner2", premap_field("inner1", "Ioff", -1, domain_p2))( + im.let("inner3", premap_field("inner2", "Ioff", -1, domain_p3))( + premap_field("inner3", "Ioff", 3, domain) + ) + ) + ) + + expected_domains = {"outer": domain_p2} + + run_test_expr(testee, expected, domain, expected_domains, offset_provider) + + +def test_nested_let_args(offset_provider): + testee = im.let( + "inner", + im.let("inner_arg", premap_field("outer", "Ioff", 1))( + premap_field("inner_arg", "Ioff", -1) + ), + )(premap_field("inner", "Ioff", -1)) + + domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) + domain_m1 = translate_domain(domain, {"Ioff": -1}, offset_provider) + domain_m2 = translate_domain(domain, {"Ioff": -2}, offset_provider) + + expected = im.let( + "inner", + im.let("inner_arg", premap_field("outer", "Ioff", 1, domain_m2))( + premap_field("inner_arg", "Ioff", -1, domain_m1) + ), + )(premap_field("inner", "Ioff", -1, domain)) + + expected_domains = {"outer": domain_m1} + + run_test_expr(testee, expected, domain, expected_domains, offset_provider) + + +def test_program_let(offset_provider): + stencil_tmp = im.lambda_("arg0")( + im.minus(im.deref(im.shift("Ioff", -1)("arg0")), im.deref("arg0")) + ) + let_tmp = im.let("inner", premap_field("outer", "Ioff", -1))(premap_field("inner", "Ioff", -1)) + as_fieldop = im.as_fieldop(stencil_tmp)(im.ref("tmp")) + + domain_lm2_rm1 = im.domain(common.GridType.CARTESIAN, {IDim: (-2, 10)}) + domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) + domain_lm1 = im.domain(common.GridType.CARTESIAN, {IDim: (-1, 11)}) + + params = [im.sym(name) for name in ["in_field", "out_field", "outer"]] testee = itir.Program( - id="differences_three_tmps_two_inputs", + id="forward_diff_with_tmp", function_definitions=[], params=params, - declarations=[ - itir.Temporary(id="tmp1", domain=AUTO_DOMAIN, dtype=float_type), - itir.Temporary(id="tmp2", domain=AUTO_DOMAIN, dtype=float_type), - itir.Temporary(id="tmp3", domain=AUTO_DOMAIN, dtype=float_type), - ], + declarations=[itir.Temporary(id="tmp", domain=domain_lm1, dtype=float_type)], body=[ - itir.SetAt(expr=as_fieldop_tmp1, domain=AUTO_DOMAIN, target=im.ref("tmp1")), - itir.SetAt(expr=as_fieldop_tmp2, domain=AUTO_DOMAIN, target=im.ref("tmp2")), - itir.SetAt(expr=as_fieldop_out1, domain=domain_out, target=im.ref("out_field1")), - itir.SetAt(expr=as_fieldop_tmp3, domain=AUTO_DOMAIN, target=im.ref("tmp3")), - itir.SetAt(expr=as_fieldop_out2, domain=domain_out, target=im.ref("out_field2")), + itir.SetAt(expr=let_tmp, domain=domain_lm1, target=im.ref("tmp")), + itir.SetAt(expr=as_fieldop, domain=domain, target=im.ref("out_field")), ], ) - expected_expr_tmp1 = im.as_fieldop(stencil, domain_tmp1)( - im.ref("in_field1"), im.ref("in_field2") - ) - expected_expr_tmp2 = im.as_fieldop(stencil_tmp, domain_tmp2)(im.ref("tmp1")) - expected_expr_out1 = im.as_fieldop(stencil_tmp, domain_out)(im.ref("tmp2")) - expected_expr_tmp3 = im.as_fieldop(stencil, domain_tmp3)(im.ref("tmp1"), im.ref("in_field2")) - expected_expr_out2 = im.as_fieldop(stencil_tmp_minus, domain_out)( - im.ref("tmp2"), im.ref("tmp3") + expected_let = im.let("inner", premap_field("outer", "Ioff", -1, domain_lm2_rm1))( + premap_field("inner", "Ioff", -1, domain_lm1) ) + expected_as_fieldop = im.as_fieldop(stencil_tmp, domain)(im.ref("tmp")) expected = itir.Program( - id="differences_three_tmps_two_inputs", + id="forward_diff_with_tmp", function_definitions=[], params=params, - declarations=[ - itir.Temporary(id="tmp1", domain=domain_tmp1, dtype=float_type), - itir.Temporary(id="tmp2", domain=domain_tmp2, dtype=float_type), - itir.Temporary(id="tmp3", domain=domain_tmp3, dtype=float_type), - ], + declarations=[itir.Temporary(id="tmp", domain=domain_lm1, dtype=float_type)], body=[ - itir.SetAt(expr=expected_expr_tmp1, domain=domain_tmp1, target=im.ref("tmp1")), - itir.SetAt(expr=expected_expr_tmp2, domain=domain_tmp2, target=im.ref("tmp2")), - itir.SetAt(expr=expected_expr_out1, domain=domain_out, target=im.ref("out_field1")), - itir.SetAt(expr=expected_expr_tmp3, domain=domain_tmp3, target=im.ref("tmp3")), - itir.SetAt(expr=expected_expr_out2, domain=domain_out, target=im.ref("out_field2")), + itir.SetAt(expr=expected_let, domain=domain_lm1, target=im.ref("tmp")), + itir.SetAt(expr=expected_as_fieldop, domain=domain, target=im.ref("out_field")), ], ) run_test_program(testee, expected, offset_provider) + + +def test_make_tuple(offset_provider): + testee = im.make_tuple(im.as_fieldop("deref")("in_field1"), im.as_fieldop("deref")("in_field2")) + domain1 = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) + domain2 = im.domain(common.GridType.CARTESIAN, {IDim: (0, 13)}) + expected = im.make_tuple( + im.as_fieldop("deref", domain1)("in_field1"), im.as_fieldop("deref", domain2)("in_field2") + ) + expected_domains_dict = {"in_field1": {IDim: (0, 11)}, "in_field2": {IDim: (0, 13)}} + expected_domains = { + ref: im.domain(common.GridType.CARTESIAN, d) for ref, d in expected_domains_dict.items() + } + + actual, actual_domains = infer_domain.infer_expr( + testee, + (SymbolicDomain.from_expr(domain1), SymbolicDomain.from_expr(domain2)), + offset_provider, + ) + + assert expected == actual + assert expected_domains == constant_fold_accessed_domains(actual_domains) + + +def test_tuple_get_1_make_tuple(offset_provider): + testee = im.tuple_get(1, im.make_tuple(im.ref("a"), im.ref("b"), im.ref("c"))) + domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) + expected = im.tuple_get(1, im.make_tuple(im.ref("a"), im.ref("b"), im.ref("c"))) + expected_domains = { + "a": None, + "b": im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}), + "c": None, + } + + actual, actual_domains = infer_domain.infer_expr( + testee, SymbolicDomain.from_expr(domain), offset_provider + ) + + assert expected == actual + assert expected_domains == constant_fold_accessed_domains(actual_domains) + + +def test_tuple_get_1_nested_make_tuple(offset_provider): + testee = im.tuple_get(1, im.make_tuple(im.ref("a"), im.make_tuple(im.ref("b"), im.ref("c")))) + domain1 = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) + domain2 = im.domain(common.GridType.CARTESIAN, {IDim: (0, 12)}) + expected = im.tuple_get(1, im.make_tuple(im.ref("a"), im.make_tuple(im.ref("b"), im.ref("c")))) + expected_domains = {"a": None, "b": domain1, "c": domain2} + + actual, actual_domains = infer_domain.infer_expr( + testee, + (SymbolicDomain.from_expr(domain1), SymbolicDomain.from_expr(domain2)), + offset_provider, + ) + + assert expected == actual + assert expected_domains == constant_fold_accessed_domains(actual_domains) + + +def test_tuple_get_let_arg_make_tuple(offset_provider): + testee = im.tuple_get(1, im.let("a", im.make_tuple(im.ref("b"), im.ref("c")))("d")) + domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) + expected = im.tuple_get(1, im.let("a", im.make_tuple(im.ref("b"), im.ref("c")))("d")) + expected_domains = {"b": None, "c": None, "d": (None, domain)} + + actual, actual_domains = infer_domain.infer_expr( + testee, + SymbolicDomain.from_expr(im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)})), + offset_provider, + ) + + assert expected == actual + assert expected_domains == constant_fold_accessed_domains(actual_domains) + + +def test_tuple_get_let_make_tuple(offset_provider): + testee = im.tuple_get(1, im.let("a", "b")(im.make_tuple(im.ref("c"), im.ref("d")))) + domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) + expected = im.tuple_get(1, im.let("a", "b")(im.make_tuple(im.ref("c"), im.ref("d")))) + expected_domains = {"c": None, "d": domain, "b": None} + + actual, actual_domains = infer_domain.infer_expr( + testee, + SymbolicDomain.from_expr(domain), + offset_provider, + ) + + assert expected == actual + assert expected_domains == constant_fold_accessed_domains(actual_domains) + + +def test_nested_make_tuple(offset_provider): + testee = im.make_tuple(im.make_tuple(im.ref("a"), im.ref("b")), im.ref("c")) + domain1 = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) + domain2_1 = im.domain(common.GridType.CARTESIAN, {IDim: (0, 12)}) + domain2_2 = im.domain(common.GridType.CARTESIAN, {IDim: (0, 13)}) + domain3 = im.domain(common.GridType.CARTESIAN, {IDim: (0, 14)}) + expected = im.make_tuple(im.make_tuple(im.ref("a"), im.ref("b")), im.ref("c")) + expected_domains = {"a": domain1, "b": (domain2_1, domain2_2), "c": domain3} + + actual, actual_domains = infer_domain.infer_expr( + testee, + ( + ( + SymbolicDomain.from_expr(domain1), + (SymbolicDomain.from_expr(domain2_1), SymbolicDomain.from_expr(domain2_2)), + ), + SymbolicDomain.from_expr(domain3), + ), + offset_provider, + ) + + assert expected == actual + assert expected_domains == constant_fold_accessed_domains(actual_domains) + + +def test_tuple_get_1(offset_provider): + testee = im.tuple_get(1, im.ref("a")) + domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) + expected = im.tuple_get(1, im.ref("a")) + expected_domains = {"a": (None, domain)} + + actual, actual_domains = infer_domain.infer_expr( + testee, SymbolicDomain.from_expr(domain), offset_provider + ) + + assert expected == actual + assert expected_domains == constant_fold_accessed_domains(actual_domains) + + +def test_domain_tuple(offset_provider): + testee = im.ref("a") + domain1 = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) + domain2 = im.domain(common.GridType.CARTESIAN, {IDim: (0, 12)}) + expected = im.ref("a") + expected_domains = {"a": (domain1, domain2)} + + actual, actual_domains = infer_domain.infer_expr( + testee, + (SymbolicDomain.from_expr(domain1), SymbolicDomain.from_expr(domain2)), + offset_provider, + ) + + assert expected == actual + assert expected_domains == constant_fold_accessed_domains(actual_domains) + + +def test_as_fieldop_tuple_get(offset_provider): + testee = im.op_as_fieldop(im.plus)(im.tuple_get(0, im.ref("a")), im.tuple_get(1, im.ref("a"))) + domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) + expected = im.op_as_fieldop(im.plus, domain)( + im.tuple_get(0, im.ref("a")), im.tuple_get(1, im.ref("a")) + ) + expected_domains = {"a": (domain, domain)} + + actual, actual_domains = infer_domain.infer_expr( + testee, SymbolicDomain.from_expr(domain), offset_provider + ) + + assert expected == actual + assert expected_domains == constant_fold_accessed_domains(actual_domains) + + +def test_make_tuple_2tuple_get(offset_provider): + testee = im.make_tuple(im.tuple_get(0, im.ref("a")), im.tuple_get(1, im.ref("a"))) + domain1 = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) + domain2 = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) + expected = im.make_tuple(im.tuple_get(0, im.ref("a")), im.tuple_get(1, im.ref("a"))) + expected_domains = {"a": (domain1, domain2)} + + actual, actual_domains = infer_domain.infer_expr( + testee, + (SymbolicDomain.from_expr(domain1), SymbolicDomain.from_expr(domain2)), + offset_provider, + ) + + assert expected == actual + assert expected_domains == constant_fold_accessed_domains(actual_domains) + + +def test_make_tuple_non_tuple_domain(offset_provider): + testee = im.make_tuple(im.as_fieldop("deref")("in_field1"), im.as_fieldop("deref")("in_field2")) + domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) + + expected = im.make_tuple( + im.as_fieldop("deref", domain)("in_field1"), im.as_fieldop("deref", domain)("in_field2") + ) + expected_domains = {"in_field1": domain, "in_field2": domain} + + actual, actual_domains = infer_domain.infer_expr( + testee, SymbolicDomain.from_expr(domain), offset_provider + ) + + assert expected == actual + assert expected_domains == constant_fold_accessed_domains(actual_domains) + + +def test_arithmetic_builtin(offset_provider): + testee = im.plus(im.ref("in_field1"), im.ref("in_field2")) + domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) + expected = im.plus(im.ref("in_field1"), im.ref("in_field2")) + expected_domains = {} + + actual_call, actual_domains = infer_domain.infer_expr( + testee, SymbolicDomain.from_expr(domain), offset_provider + ) + folded_call = constant_fold_domain_exprs(actual_call) + + assert folded_call == expected + assert actual_domains == expected_domains