Skip to content

Commit

Permalink
feat[next]: domain inference for let, make_tuple, tuple_get, cond (#1591
Browse files Browse the repository at this point in the history
)

Infers the minimal domain of (nested) `let`, `make_tuple`, `tuple_get`,
`cond` and other builtins as an extension to PR #1568

- New functions `infer_let`, `infer_make_tuple`, `infer_tuple_get`,
`infer_cond` in `gt4py.next.iterator.transforms.infer_domain`
- New function `infer_expr` in
gt4py.next.iterator.transforms.infer_domain which calls the appropriate
of the above (or `infer_as_fieldop` and `infer_program`)
- Several new tests in test_infer_domain.py to test functionality

Note:
Temporary handling was only present until commit fc4846f and has been
removed in commit e8e679d to reduce unneeded complexity. This pass will
be executed before temporary extraction, hence there exist valid
`domain`s in all program calls, i.e. all `SetAt` do have a domain (not
`AUTO_DOMAIN`) that doesn't need to be inferred.

---------

Co-authored-by: Till Ehrengruber <[email protected]>
  • Loading branch information
SF-N and tehrengruber authored Sep 20, 2024
1 parent fe6dbd4 commit 21b1dfc
Show file tree
Hide file tree
Showing 5 changed files with 914 additions and 323 deletions.
5 changes: 5 additions & 0 deletions src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,11 @@ def is_applied_shift(arg: itir.Node) -> TypeGuard[itir.FunCall]:
)


def is_applied_as_fieldop(arg: itir.Node) -> TypeGuard[itir.FunCall]:
"""Match expressions of the form `as_fieldop(stencil)(*args)`."""
return isinstance(arg, itir.FunCall) and is_call_to(arg.fun, "as_fieldop")


def is_let(node: itir.Node) -> TypeGuard[itir.FunCall]:
"""Match expression of the form `(λ(...) → ...)(...)`."""
return isinstance(node, itir.FunCall) and isinstance(node.fun, itir.Lambda)
Expand Down
8 changes: 4 additions & 4 deletions src/gt4py/next/iterator/transforms/global_tmps.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

import gt4py.next as gtx
from gt4py.eve import NodeTranslator, PreserveLocationVisitor
from gt4py.eve.extended_typing import Dict, Tuple
from gt4py.eve.extended_typing import Tuple
from gt4py.eve.traits import SymbolTableTrait
from gt4py.eve.utils import UIDGenerator
from gt4py.next import common
Expand Down Expand Up @@ -454,7 +454,7 @@ def as_expr(self) -> ir.FunCall:
def translate(
self: SymbolicDomain,
shift: Tuple[ir.OffsetLiteral, ...],
offset_provider: Dict[str, common.Dimension],
offset_provider: common.OffsetProvider,
) -> SymbolicDomain:
dims = list(self.ranges.keys())
new_ranges = {dim: self.ranges[dim] for dim in dims}
Expand Down Expand Up @@ -498,7 +498,7 @@ def translate(
raise AssertionError("Number of shifts must be a multiple of 2.")


def domain_union(domains: list[SymbolicDomain]) -> SymbolicDomain:
def domain_union(*domains: SymbolicDomain) -> SymbolicDomain:
"""Return the (set) union of a list of domains."""
new_domain_ranges = {}
assert all(domain.grid_type == domains[0].grid_type for domain in domains)
Expand Down Expand Up @@ -617,7 +617,7 @@ def update_domains(
consumed_domain.ranges.keys() == consumed_domains[0].ranges.keys()
for consumed_domain in consumed_domains
): # scalar otherwise
domains[param] = domain_union(consumed_domains).as_expr()
domains[param] = domain_union(*consumed_domains).as_expr()

return FencilWithTemporaries(
fencil=ir.FencilDefinition(
Expand Down
Loading

0 comments on commit 21b1dfc

Please sign in to comment.