From b339b82615eaabc9e2a5eb93f5656553863a7c6b Mon Sep 17 00:00:00 2001 From: edopao Date: Mon, 14 Oct 2024 16:24:47 +0200 Subject: [PATCH 1/5] feat[next]: Add IR transform to remove unnecessary cast expressions (#1688) Add IR transformation that removes cast expressions where the argument is already in the target type. --- .../next/iterator/transforms/prune_casts.py | 45 +++++++++++++++++++ .../runners/dace_fieldview/gtir_sdfg.py | 3 ++ .../transforms_tests/test_prune_casts.py | 23 ++++++++++ 3 files changed, 71 insertions(+) create mode 100644 src/gt4py/next/iterator/transforms/prune_casts.py create mode 100644 tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_prune_casts.py diff --git a/src/gt4py/next/iterator/transforms/prune_casts.py b/src/gt4py/next/iterator/transforms/prune_casts.py new file mode 100644 index 0000000000..0720394db5 --- /dev/null +++ b/src/gt4py/next/iterator/transforms/prune_casts.py @@ -0,0 +1,45 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +from gt4py.eve import NodeTranslator, PreserveLocationVisitor +from gt4py.next.iterator import ir +from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm +from gt4py.next.type_system import type_specifications as ts + + +class PruneCasts(PreserveLocationVisitor, NodeTranslator): + """ + Removes cast expressions where the argument is already in the target type. + + This transformation requires the IR to be fully type-annotated, + therefore it should be applied after type-inference. + """ + + def visit_FunCall(self, node: ir.FunCall) -> ir.Node: + node = self.generic_visit(node) + + if not cpm.is_call_to(node, "cast_"): + return node + + value, type_constructor = node.args + + assert ( + value.type + and isinstance(type_constructor, ir.SymRef) + and (type_constructor.id in ir.TYPEBUILTINS) + ) + dtype = ts.ScalarType(kind=getattr(ts.ScalarKind, type_constructor.id.upper())) + + if value.type == dtype: + return value + + return node + + @classmethod + def apply(cls, node: ir.Node) -> ir.Node: + return cls().visit(node) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py index 7d878dde99..09d5d6c0d0 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py @@ -26,6 +26,7 @@ from gt4py.next import common as gtx_common, utils as gtx_utils from gt4py.next.iterator import ir as gtir from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm +from gt4py.next.iterator.transforms import prune_casts as ir_prune_casts from gt4py.next.iterator.type_system import inference as gtir_type_inference from gt4py.next.program_processors.runners.dace_common import utility as dace_utils from gt4py.next.program_processors.runners.dace_fieldview import ( @@ -656,7 +657,9 @@ def build_sdfg_from_gtir( Returns: An SDFG in the DaCe canonical form (simplified) """ + ir = gtir_type_inference.infer(ir, offset_provider=offset_provider) + ir = ir_prune_casts.PruneCasts().visit(ir) ir = dace_gtir_utils.patch_gtir(ir) sdfg_genenerator = GTIRToSDFG(offset_provider) sdfg = sdfg_genenerator.visit(ir) diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_prune_casts.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_prune_casts.py new file mode 100644 index 0000000000..462eed8408 --- /dev/null +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_prune_casts.py @@ -0,0 +1,23 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +from gt4py.next.iterator.ir_utils import ir_makers as im +from gt4py.next.type_system import type_specifications as ts +from gt4py.next.iterator.transforms.prune_casts import PruneCasts +from gt4py.next.iterator.type_system import inference as type_inference + + +def test_prune_casts_simple(): + x_ref = im.ref("x", ts.ScalarType(kind=ts.ScalarKind.FLOAT32)) + y_ref = im.ref("y", ts.ScalarType(kind=ts.ScalarKind.FLOAT64)) + testee = im.call("plus")(im.call("cast_")(x_ref, "float64"), im.call("cast_")(y_ref, "float64")) + testee = type_inference.infer(testee, offset_provider={}, allow_undeclared_symbols=True) + + expected = im.call("plus")(im.call("cast_")(x_ref, "float64"), y_ref) + actual = PruneCasts.apply(testee) + assert actual == expected From 9feb51db27bde798245d3f80f4075e622bd42173 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Mon, 14 Oct 2024 17:01:26 +0200 Subject: [PATCH 2/5] [next]: Fix inline lambda pass opcount preserving option (#1687) In #1531 the `itir.Node` class got a `type` attribute, that until now contributed to the hash computation of all nodes. As such two `itir.SymRef` with the same `id`, but one with a type inferred and one without (i.e. `None`) got a different hash value. Consequently the `inline_lambda` pass did not recognize them as a reference to the same symbol and erroneously inlined the expression even with `opcount_preserving=True`. This PR fixes the hash computation, such that again `node1 == node2` implies `hash(node1) == hash(node2)`. --- src/gt4py/next/iterator/ir.py | 12 ++++++++---- .../transforms_tests/test_inline_lambdas.py | 16 ++++++++++++++++ 2 files changed, 24 insertions(+), 4 deletions(-) diff --git a/src/gt4py/next/iterator/ir.py b/src/gt4py/next/iterator/ir.py index b2a549501f..42da4c83a6 100644 --- a/src/gt4py/next/iterator/ir.py +++ b/src/gt4py/next/iterator/ir.py @@ -37,10 +37,14 @@ def __str__(self) -> str: return pformat(self) def __hash__(self) -> int: - return hash(type(self)) ^ hash( - tuple( - hash(tuple(v)) if isinstance(v, list) else hash(v) - for v in self.iter_children_values() + return hash( + ( + type(self), + *( + tuple(v) if isinstance(v, list) else v + for (k, v) in self.iter_children_items() + if k not in ["location", "type"] + ), ) ) diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_lambdas.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_lambdas.py index e45281734b..2e0a83d33b 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_lambdas.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_lambdas.py @@ -8,6 +8,7 @@ import pytest +from gt4py.next.type_system import type_specifications as ts from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.iterator.transforms.inline_lambdas import InlineLambdas @@ -39,6 +40,21 @@ ), im.multiplies_(im.plus(2, 1), im.plus("x", "x")), ), + ( + # ensure opcount preserving option works whether `itir.SymRef` has a type or not + "typed_ref", + im.let("a", im.call("opaque")())( + im.plus(im.ref("a", ts.ScalarType(kind=ts.ScalarKind.FLOAT32)), im.ref("a", None)) + ), + { + True: im.let("a", im.call("opaque")())( + im.plus( # stays as is + im.ref("a", ts.ScalarType(kind=ts.ScalarKind.FLOAT32)), im.ref("a", None) + ) + ), + False: im.plus(im.call("opaque")(), im.call("opaque")()), + }, + ), ] From 5ce0fb8b9234c0f514ac83cc060dee6b549684c1 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Wed, 16 Oct 2024 17:01:57 +0200 Subject: [PATCH 3/5] feat[next]: gtir lowering of broadcasted scalars (#1677) --- src/gt4py/next/ffront/foast_to_gtir.py | 5 ++++- .../unit_tests/ffront_tests/test_foast_to_gtir.py | 11 +++++++++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/src/gt4py/next/ffront/foast_to_gtir.py b/src/gt4py/next/ffront/foast_to_gtir.py index 948a8481d7..9cb0ce05f5 100644 --- a/src/gt4py/next/ffront/foast_to_gtir.py +++ b/src/gt4py/next/ffront/foast_to_gtir.py @@ -373,7 +373,10 @@ def create_if(true_: itir.Expr, false_: itir.Expr) -> itir.FunCall: _visit_concat_where = _visit_where # TODO(havogt): upgrade concat_where def _visit_broadcast(self, node: foast.Call, **kwargs: Any) -> itir.FunCall: - return self.visit(node.args[0], **kwargs) + expr = self.visit(node.args[0], **kwargs) + if isinstance(node.args[0].type, ts.ScalarType): + return im.as_fieldop(im.ref("deref"))(expr) + return expr def _visit_math_built_in(self, node: foast.Call, **kwargs: Any) -> itir.FunCall: return self._map(self.visit(node.func, **kwargs), *node.args) diff --git a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py index 3951c410dc..09f18246dc 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py @@ -916,3 +916,14 @@ def foo(inp: gtx.Field[[TDim], float64]): assert lowered.id == "foo" assert lowered.expr == im.ref("inp") + + +def test_scalar_broadcast(): + def foo(): + return broadcast(1, (UDim, TDim)) + + parsed = FieldOperatorParser.apply_to_function(foo) + lowered = FieldOperatorLowering.apply(parsed) + + assert lowered.id == "foo" + assert lowered.expr == im.as_fieldop("deref")(1) From 3f7fceed483e8a34b17fdf4d9a2625ecb0896759 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Wed, 16 Oct 2024 17:03:51 +0200 Subject: [PATCH 4/5] feat[next]: Allow type inference without domain argument to `as_fieldop` (#1689) In case we don't have a domain argument to `as_fieldop` we can not infer the exact result type. In order to still allow some passes which don't need this information to run before the domain inference, we continue with a dummy domain. One example is the CollapseTuple pass which only needs information about the structure, e.g. how many tuple elements does this node have, but not the dimensions of a field. Note that it might appear as if using the TraceShift pass would allow us to deduce the return type of `as_fieldop` without a domain, but this is not the case, since we don't have information on the ordering of dimensions. In this example ``` as_fieldop(it1, it2 -> deref(it1) + deref(it2))(i_field, j_field) ``` it is unclear if the result has dimension I, J or J, I. --- .../next/iterator/type_system/inference.py | 4 ++- .../type_system/type_specifications.py | 2 +- .../iterator/type_system/type_synthesizer.py | 27 ++++++++++++++++--- .../iterator_tests/test_type_inference.py | 13 +++++++++ 4 files changed, 40 insertions(+), 6 deletions(-) diff --git a/src/gt4py/next/iterator/type_system/inference.py b/src/gt4py/next/iterator/type_system/inference.py index fccaa56232..a13c7fb816 100644 --- a/src/gt4py/next/iterator/type_system/inference.py +++ b/src/gt4py/next/iterator/type_system/inference.py @@ -504,9 +504,11 @@ def visit_Program(self, node: itir.Program, *, ctx) -> it_ts.ProgramType: def visit_Temporary(self, node: itir.Temporary, *, ctx) -> ts.FieldType | ts.TupleType: domain = self.visit(node.domain, ctx=ctx) assert isinstance(domain, it_ts.DomainType) + assert domain.dims != "unknown" assert node.dtype return type_info.apply_to_primitive_constituents( - lambda dtype: ts.FieldType(dims=domain.dims, dtype=dtype), node.dtype + lambda dtype: ts.FieldType(dims=domain.dims, dtype=dtype), # type: ignore[arg-type] # ensured by domain.dims != "unknown" above + node.dtype, ) def visit_IfStmt(self, node: itir.IfStmt, *, ctx) -> None: diff --git a/src/gt4py/next/iterator/type_system/type_specifications.py b/src/gt4py/next/iterator/type_system/type_specifications.py index cfe3987b8c..94a174dca4 100644 --- a/src/gt4py/next/iterator/type_system/type_specifications.py +++ b/src/gt4py/next/iterator/type_system/type_specifications.py @@ -20,7 +20,7 @@ class NamedRangeType(ts.TypeSpec): @dataclasses.dataclass(frozen=True) class DomainType(ts.DataType): - dims: list[common.Dimension] + dims: list[common.Dimension] | Literal["unknown"] @dataclasses.dataclass(frozen=True) diff --git a/src/gt4py/next/iterator/type_system/type_synthesizer.py b/src/gt4py/next/iterator/type_system/type_synthesizer.py index 77cd39389a..c836de1391 100644 --- a/src/gt4py/next/iterator/type_system/type_synthesizer.py +++ b/src/gt4py/next/iterator/type_system/type_synthesizer.py @@ -271,17 +271,36 @@ def _convert_as_fieldop_input_to_iterator( @_register_builtin_type_synthesizer def as_fieldop( - stencil: TypeSynthesizer, domain: it_ts.DomainType, offset_provider: common.OffsetProvider + stencil: TypeSynthesizer, + domain: Optional[it_ts.DomainType] = None, + *, + offset_provider: common.OffsetProvider, ) -> TypeSynthesizer: + # In case we don't have a domain argument to `as_fieldop` we can not infer the exact result + # type. In order to still allow some passes which don't need this information to run before the + # domain inference, we continue with a dummy domain. One example is the CollapseTuple pass + # which only needs information about the structure, e.g. how many tuple elements does this node + # have, but not the dimensions of a field. + # Note that it might appear as if using the TraceShift pass would allow us to deduce the return + # type of `as_fieldop` without a domain, but this is not the case, since we don't have + # information on the ordering of dimensions. In this example + # `as_fieldop(it1, it2 -> deref(it1) + deref(it2))(i_field, j_field)` + # it is unclear if the result has dimension I, J or J, I. + if domain is None: + domain = it_ts.DomainType(dims="unknown") + @TypeSynthesizer - def applied_as_fieldop(*fields) -> ts.FieldType: + def applied_as_fieldop(*fields) -> ts.FieldType | ts.DeferredType: stencil_return = stencil( *(_convert_as_fieldop_input_to_iterator(domain, field) for field in fields), offset_provider=offset_provider, ) assert isinstance(stencil_return, ts.DataType) return type_info.apply_to_primitive_constituents( - lambda el_type: ts.FieldType(dims=domain.dims, dtype=el_type), stencil_return + lambda el_type: ts.FieldType(dims=domain.dims, dtype=el_type) + if domain.dims != "unknown" + else ts.DeferredType(constraint=ts.FieldType), + stencil_return, ) return applied_as_fieldop @@ -329,7 +348,7 @@ def applied_reduce(*args: it_ts.ListType, offset_provider: common.OffsetProvider @_register_builtin_type_synthesizer -def shift(*offset_literals, offset_provider) -> TypeSynthesizer: +def shift(*offset_literals, offset_provider: common.OffsetProvider) -> TypeSynthesizer: @TypeSynthesizer def apply_shift( it: it_ts.IteratorType | ts.DeferredType, diff --git a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py index 05cd6b6854..20a1d7e9b7 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py @@ -478,3 +478,16 @@ def test_if_stmt(): result = itir_type_inference.infer(testee, offset_provider={}, allow_undeclared_symbols=True) assert result.cond.type == bool_type assert result.true_branch[0].expr.type == float_i_field + + +def test_as_fieldop_without_domain(): + testee = im.as_fieldop(im.lambda_("it")(im.deref(im.shift("IOff", 1)("it"))))( + im.ref("inp", float_i_field) + ) + result = itir_type_inference.infer( + testee, offset_provider={"IOff": IDim}, allow_undeclared_symbols=True + ) + assert result.type == ts.DeferredType(constraint=ts.FieldType) + assert result.fun.args[0].type.pos_only_args[0] == it_ts.IteratorType( + position_dims="unknown", defined_dims=float_i_field.dims, element_type=float_i_field.dtype + ) From 0a27c7a415a8cb7ec61e2a3fe2cdd4595a3481d7 Mon Sep 17 00:00:00 2001 From: edopao Date: Thu, 17 Oct 2024 16:08:03 +0200 Subject: [PATCH 5/5] feat[next][dace]: GTIR-to-DaCe lowering of map-reduce (only full connectivity) (#1683) This PR adds support for lowering of `map_` and `make_const_list` builtin functions. However, the current implementation only supports neighbor tables with full connectivity (no skip values). The support for skip values will be added in next PR. To be noted: - This PR generalizes the handling of tasklets without arguments inside a map scope. The return type for `input_connections` is extended to contain a `TaskletConnection` variant, which is lowered to an empty edge from map entry node to the tasklet node. - The result of `make_const_list` is a scalar value to be broadcasted on a local field. However, in order to keep the lowering simple, this value is represented as a 1D 1-element array (`shape=(1,)`). --- .../ir_utils/common_pattern_matcher.py | 10 + .../next/iterator/transforms/fuse_maps.py | 19 +- .../runners/dace_common/utility.py | 9 +- .../gtir_builtin_translators.py | 127 +++++-- .../runners/dace_fieldview/gtir_dataflow.py | 329 +++++++++++++----- .../dace_fieldview/gtir_python_codegen.py | 11 + .../runners/dace_fieldview/gtir_sdfg.py | 25 +- .../runners/dace_fieldview/utility.py | 46 +-- .../dace_tests/test_gtir_to_sdfg.py | 111 +++--- 9 files changed, 453 insertions(+), 234 deletions(-) diff --git a/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py b/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py index 4aea7ef149..16a88b282a 100644 --- a/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py +++ b/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py @@ -22,6 +22,16 @@ def is_applied_lift(arg: itir.Node) -> TypeGuard[itir.FunCall]: ) +def is_applied_map(arg: itir.Node) -> TypeGuard[itir.FunCall]: + """Match expressions of the form `map(λ(...) → ...)(...)`.""" + return ( + isinstance(arg, itir.FunCall) + and isinstance(arg.fun, itir.FunCall) + and isinstance(arg.fun.fun, itir.SymRef) + and arg.fun.fun.id == "map_" + ) + + def is_applied_reduce(arg: itir.Node) -> TypeGuard[itir.FunCall]: """Match expressions of the form `reduce(λ(...) → ...)(...)`.""" return ( diff --git a/src/gt4py/next/iterator/transforms/fuse_maps.py b/src/gt4py/next/iterator/transforms/fuse_maps.py index 430d794880..8d27178682 100644 --- a/src/gt4py/next/iterator/transforms/fuse_maps.py +++ b/src/gt4py/next/iterator/transforms/fuse_maps.py @@ -7,7 +7,6 @@ # SPDX-License-Identifier: BSD-3-Clause import dataclasses -from typing import TypeGuard from gt4py.eve import NodeTranslator, traits from gt4py.eve.utils import UIDGenerator @@ -16,14 +15,6 @@ from gt4py.next.iterator.transforms import inline_lambdas -def _is_map(node: ir.Node) -> TypeGuard[ir.FunCall]: - return ( - isinstance(node, ir.FunCall) - and isinstance(node.fun, ir.FunCall) - and node.fun.fun == ir.SymRef(id="map_") - ) - - @dataclasses.dataclass(frozen=True) class FuseMaps(traits.PreserveLocationVisitor, traits.VisitorWithSymbolTableTrait, NodeTranslator): """ @@ -58,10 +49,10 @@ def _as_lambda(self, fun: ir.SymRef | ir.Lambda, param_count: int) -> ir.Lambda: def visit_FunCall(self, node: ir.FunCall, **kwargs): node = self.generic_visit(node) - if _is_map(node) or cpm.is_applied_reduce(node): - if any(_is_map(arg) for arg in node.args): + if cpm.is_applied_map(node) or cpm.is_applied_reduce(node): + if any(cpm.is_applied_map(arg) for arg in node.args): first_param = ( - 0 if _is_map(node) else 1 + 0 if cpm.is_applied_map(node) else 1 ) # index of the first param of op that maps to args (0 for map, 1 for reduce) assert isinstance(node.fun, ir.FunCall) assert isinstance(node.fun.args[0], (ir.Lambda, ir.SymRef)) @@ -76,7 +67,7 @@ def visit_FunCall(self, node: ir.FunCall, **kwargs): new_params.append(outer_op.params[0]) for i in range(len(node.args)): - if _is_map(node.args[i]): + if cpm.is_applied_map(node.args[i]): map_call = node.args[i] assert isinstance(map_call, ir.FunCall) assert isinstance(map_call.fun, ir.FunCall) @@ -102,7 +93,7 @@ def visit_FunCall(self, node: ir.FunCall, **kwargs): new_body ) # removes one level of nesting (the recursive inliner could simplify more, however this can also be done on the full tree later) new_op = ir.Lambda(params=new_params, expr=new_body) - if _is_map(node): + if cpm.is_applied_map(node): return ir.FunCall( fun=ir.FunCall(fun=ir.SymRef(id="map_"), args=[new_op]), args=new_args ) diff --git a/src/gt4py/next/program_processors/runners/dace_common/utility.py b/src/gt4py/next/program_processors/runners/dace_common/utility.py index dec34ecbac..d678fdab7f 100644 --- a/src/gt4py/next/program_processors/runners/dace_common/utility.py +++ b/src/gt4py/next/program_processors/runners/dace_common/utility.py @@ -37,12 +37,13 @@ def as_dace_type(type_: ts.ScalarType) -> dace.typeclass: raise ValueError(f"Scalar type '{type_}' not supported.") -def as_scalar_type(typestr: str) -> ts.ScalarType: - """Obtain GT4Py scalar type from generic numpy string representation.""" +def as_itir_type(dtype: dace.typeclass) -> ts.ScalarType: + """Get GT4Py scalar representation of a DaCe type.""" + type_name = str(dtype.as_numpy_dtype()) try: - kind = getattr(ts.ScalarKind, typestr.upper()) + kind = getattr(ts.ScalarKind, type_name.upper()) except AttributeError as ex: - raise ValueError(f"Data type {typestr} not supported.") from ex + raise ValueError(f"Data type {type_name} not supported.") from ex return ts.ScalarType(kind) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py index e91bd880c6..8fb1451efb 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py @@ -10,12 +10,13 @@ import abc import dataclasses -from typing import TYPE_CHECKING, Iterable, Optional, Protocol, TypeAlias +from typing import TYPE_CHECKING, Final, Iterable, Optional, Protocol, TypeAlias import dace import dace.subsets as sbs from gt4py.next import common as gtx_common, utils as gtx_utils +from gt4py.next.ffront import fbuiltins as gtx_fbuiltins from gt4py.next.iterator import ir as gtir from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm from gt4py.next.iterator.type_system import type_specifications as itir_ts @@ -32,16 +33,29 @@ from gt4py.next.program_processors.runners.dace_fieldview import gtir_sdfg -IteratorIndexDType: TypeAlias = dace.int32 # type of iterator indexes - - @dataclasses.dataclass(frozen=True) class Field: data_node: dace.nodes.AccessNode data_type: ts.FieldType | ts.ScalarType +FieldopDomain: TypeAlias = list[ + tuple[gtx_common.Dimension, dace.symbolic.SymbolicType, dace.symbolic.SymbolicType] +] +""" +Domain of a field operator represented as a list of tuples with 3 elements: + - dimension definition + - symbolic expression for lower bound (inclusive) + - symbolic expression for upper bound (exclusive) +""" + + FieldopResult: TypeAlias = Field | tuple[Field | tuple, ...] +"""Result of a field operator, can be either a field or a tuple fields.""" + + +INDEX_DTYPE: Final[dace.typeclass] = dace.dtype_to_typeclass(gtx_fbuiltins.IndexType) +"""Data type used for field indexing.""" class PrimitiveTranslator(Protocol): @@ -81,11 +95,11 @@ def _parse_fieldop_arg( sdfg: dace.SDFG, state: dace.SDFGState, sdfg_builder: gtir_sdfg.SDFGBuilder, - domain: list[ - tuple[gtx_common.Dimension, dace.symbolic.SymbolicType, dace.symbolic.SymbolicType] - ], + domain: FieldopDomain, reduce_identity: Optional[gtir_dataflow.SymbolExpr], ) -> gtir_dataflow.IteratorExpr | gtir_dataflow.MemletExpr: + """Helper method to visit an expression passed as argument to a field operator.""" + arg = sdfg_builder.visit( node, sdfg=sdfg, @@ -101,10 +115,7 @@ def _parse_fieldop_arg( return gtir_dataflow.MemletExpr(arg.data_node, sbs.Indices([0])) elif isinstance(arg.data_type, ts.FieldType): indices: dict[gtx_common.Dimension, gtir_dataflow.ValueExpr] = { - dim: gtir_dataflow.SymbolExpr( - dace_gtir_utils.get_map_variable(dim), - IteratorIndexDType, - ) + dim: gtir_dataflow.SymbolExpr(dace_gtir_utils.get_map_variable(dim), INDEX_DTYPE) for dim, _, _ in domain } dims = arg.data_type.dims + ( @@ -120,12 +131,11 @@ def _parse_fieldop_arg( def _create_temporary_field( sdfg: dace.SDFG, state: dace.SDFGState, - domain: list[ - tuple[gtx_common.Dimension, dace.symbolic.SymbolicType, dace.symbolic.SymbolicType] - ], + domain: FieldopDomain, node_type: ts.FieldType, - output_desc: dace.data.Data, + dataflow_output: gtir_dataflow.DataflowOutputEdge, ) -> Field: + """Helper method to allocate a temporary field where to write the output of a field operator.""" domain_dims, _, domain_ubs = zip(*domain) field_dims = list(domain_dims) # It should be enough to allocate an array with shape (upper_bound - lower_bound) @@ -138,6 +148,7 @@ def _create_temporary_field( # eliminate most of transient arrays. field_shape = list(domain_ubs) + output_desc = dataflow_output.result.node.desc(sdfg) if isinstance(output_desc, dace.data.Array): assert isinstance(node_type.dtype, itir_ts.ListType) assert isinstance(node_type.dtype.element_type, ts.ScalarType) @@ -157,7 +168,31 @@ def _create_temporary_field( return Field(field_node, field_type) -def translate_as_field_op( +def extract_domain(node: gtir.Node) -> FieldopDomain: + """ + Visits the domain of a field operator and returns a list of dimensions and + the corresponding lower and upper bounds. The returned lower bound is inclusive, + the upper bound is exclusive: [lower_bound, upper_bound[ + """ + assert cpm.is_call_to(node, ("cartesian_domain", "unstructured_domain")) + + domain = [] + for named_range in node.args: + assert cpm.is_call_to(named_range, "named_range") + assert len(named_range.args) == 3 + axis = named_range.args[0] + assert isinstance(axis, gtir.AxisLiteral) + lower_bound, upper_bound = ( + dace.symbolic.pystr_to_symbolic(gtir_python_codegen.get_source(arg)) + for arg in named_range.args[1:3] + ) + dim = gtx_common.Dimension(axis.value, axis.kind) + domain.append((dim, lower_bound, upper_bound)) + + return domain + + +def translate_as_fieldop( node: gtir.Node, sdfg: dace.SDFG, state: dace.SDFGState, @@ -188,25 +223,55 @@ def translate_as_field_op( assert isinstance(domain_expr, gtir.FunCall) # parse the domain of the field operator - domain = dace_gtir_utils.get_domain(domain_expr) + domain = extract_domain(domain_expr) + # The reduction identity value is used in place of skip values when building + # a list of neighbor values in the unstructured domain. + # + # A reduction on neighbor values can be either expressed in local view (itir): + # vertices @ u⟨ Vertexₕ: [0, nvertices) ⟩ + # ← as_fieldop( + # λ(it) → reduce(plus, 0)(neighbors(V2Eₒ, it)), u⟨ Vertexₕ: [0, nvertices) ⟩ + # )(edges); + # + # or in field view (gtir): + # vertices @ u⟨ Vertexₕ: [0, nvertices) ⟩ + # ← as_fieldop(λ(it) → reduce(plus, 0)(·it), u⟨ Vertexₕ: [0, nvertices) ⟩)( + # as_fieldop(λ(it) → neighbors(V2Eₒ, it), u⟨ Vertexₕ: [0, nvertices) ⟩)(edges) + # ); + # + # In local view, the list of neighbors is (recursively) built while visiting + # the current expression. + # In field view, the list of neighbors is built as argument to the current + # expression. Therefore, the reduction identity value needs to be passed to + # the argument visitor (`reduce_identity_for_args = reduce_identity`). if cpm.is_applied_reduce(stencil_expr.expr): if reduce_identity is not None: - raise NotImplementedError("nested reductions not supported.") - - # the reduce identity value is used to fill the skip values in neighbors list - _, _, reduce_identity = gtir_dataflow.get_reduce_params(stencil_expr.expr) + raise NotImplementedError("Nested reductions are not supported.") + _, _, reduce_identity_for_args = gtir_dataflow.get_reduce_params(stencil_expr.expr) + elif cpm.is_call_to(stencil_expr.expr, "neighbors"): + # When the visitor hits a neighbors expression, we stop carrying the reduce + # identity further (`reduce_identity_for_args = None`) because the reduce + # identity value is filled in place of skip values in the context of neighbors + # itself, not in the arguments context. + # Besides, setting `reduce_identity_for_args = None` enables a sanity check + # that the sequence 'reduce(V2E) -> neighbors(V2E) -> reduce(C2E) -> neighbors(C2E)' + # is accepted, while 'reduce(V2E) -> reduce(C2E) -> neighbors(V2E) -> neighbors(C2E)' + # is not. The latter sequence would raise the 'NotImplementedError' exception above. + reduce_identity_for_args = None + else: + reduce_identity_for_args = reduce_identity # visit the list of arguments to be passed to the lambda expression stencil_args = [ - _parse_fieldop_arg(arg, sdfg, state, sdfg_builder, domain, reduce_identity) + _parse_fieldop_arg(arg, sdfg, state, sdfg_builder, domain, reduce_identity_for_args) for arg in node.args ] # represent the field operator as a mapped tasklet graph, which will range over the field domain taskgen = gtir_dataflow.LambdaToDataflow(sdfg, state, sdfg_builder, reduce_identity) input_edges, output = taskgen.visit(stencil_expr, args=stencil_args) - output_desc = output.expr.node.desc(sdfg) + output_desc = output.result.node.desc(sdfg) domain_index = sbs.Indices([dace_gtir_utils.get_map_variable(dim) for dim, _, _ in domain]) if isinstance(node.type.dtype, itir_ts.ListType): @@ -220,11 +285,17 @@ def translate_as_field_op( output_subset = sbs.Range.from_indices(domain_index) # create map range corresponding to the field operator domain - map_ranges = {dace_gtir_utils.get_map_variable(dim): f"{lb}:{ub}" for dim, lb, ub in domain} - me, mx = sdfg_builder.add_map("field_op", state, map_ranges) + me, mx = sdfg_builder.add_map( + "fieldop", + state, + ndrange={ + dace_gtir_utils.get_map_variable(dim): f"{lower_bound}:{upper_bound}" + for dim, lower_bound, upper_bound in domain + }, + ) # allocate local temporary storage for the result field - result_field = _create_temporary_field(sdfg, state, domain, node.type, output_desc) + result_field = _create_temporary_field(sdfg, state, domain, node.type, output) # here we setup the edges from the map entry node for edge in input_edges: @@ -439,7 +510,7 @@ def translate_tuple_get( if not isinstance(node.args[0], gtir.Literal): raise ValueError("Tuple can only be subscripted with compile-time constants.") - assert node.args[0].type == dace_utils.as_scalar_type(gtir.INTEGER_INDEX_BUILTIN) + assert node.args[0].type == dace_utils.as_itir_type(INDEX_DTYPE) index = int(node.args[0].value) data_nodes = sdfg_builder.visit( @@ -566,7 +637,7 @@ def translate_symbol_ref( if TYPE_CHECKING: # Use type-checking to assert that all translator functions implement the `PrimitiveTranslator` protocol __primitive_translators: list[PrimitiveTranslator] = [ - translate_as_field_op, + translate_as_fieldop, translate_if, translate_literal, translate_make_tuple, diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py index 9739d7927a..0e571fc17d 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py @@ -18,7 +18,7 @@ from gt4py import eve from gt4py.next import common as gtx_common from gt4py.next.iterator import ir as gtir -from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm +from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im from gt4py.next.iterator.type_system import type_specifications as itir_ts from gt4py.next.program_processors.runners.dace_common import utility as dace_utils from gt4py.next.program_processors.runners.dace_fieldview import ( @@ -145,7 +145,7 @@ class DataflowOutputEdge: """ state: dace.SDFGState - expr: DataExpr + result: DataExpr def connect( self, @@ -154,13 +154,13 @@ def connect( subset: sbs.Range, ) -> None: # retrieve the node which writes the result - last_node = self.state.in_edges(self.expr.node)[0].src + last_node = self.state.in_edges(self.result.node)[0].src if isinstance(last_node, dace.nodes.Tasklet): # the last transient node can be deleted - last_node_connector = self.state.in_edges(self.expr.node)[0].src_conn - self.state.remove_node(self.expr.node) + last_node_connector = self.state.in_edges(self.result.node)[0].src_conn + self.state.remove_node(self.result.node) else: - last_node = self.expr.node + last_node = self.result.node last_node_connector = None self.state.add_memlet_path( @@ -272,7 +272,12 @@ def _add_map( ], **kwargs: Any, ) -> Tuple[dace.nodes.MapEntry, dace.nodes.MapExit]: - """Helper method to add a map with unique name in current state.""" + """ + Helper method to add a map in current state. + + The subgraph builder ensures that the map receives a unique name, + by adding a unique suffix to the provided name. + """ return self.subgraph_builder.add_map(name, self.state, ndrange, **kwargs) def _add_tasklet( @@ -283,7 +288,12 @@ def _add_tasklet( code: str, **kwargs: Any, ) -> dace.nodes.Tasklet: - """Helper method to add a tasklet with unique name in current state.""" + """ + Helper method to add a tasklet in current state. + + The subgraph builder ensures that the tasklet receives a unique name, + by adding a unique suffix to the provided name. + """ tasklet_node = self.subgraph_builder.add_tasklet( name, self.state, inputs, outputs, code, **kwargs ) @@ -295,15 +305,68 @@ def _add_tasklet( self.input_edges.append(edge) return tasklet_node + def _add_mapped_tasklet( + self, + name: str, + map_ranges: Dict[str, str | dace.subsets.Subset] + | List[Tuple[str, str | dace.subsets.Subset]], + inputs: Union[Set[str], Dict[str, dace.dtypes.typeclass]], + code: str, + outputs: Union[Set[str], Dict[str, dace.dtypes.typeclass]], + **kwargs: Any, + ) -> tuple[dace.nodes.Tasklet, dace.nodes.MapEntry, dace.nodes.MapExit]: + """ + Helper method to add a mapped tasklet in current state. + + The subgraph builder ensures that the tasklet receives a unique name, + by adding a unique suffix to the provided name. + """ + return self.subgraph_builder.add_mapped_tasklet( + name, self.state, map_ranges, inputs, code, outputs, **kwargs + ) + + def _construct_local_view(self, field: MemletExpr | DataExpr) -> DataExpr: + if isinstance(field, MemletExpr): + desc = field.node.desc(self.sdfg) + local_dim_indices = [i for i, size in enumerate(field.subset.size()) if size != 1] + if len(local_dim_indices) == 0: + # we are accessing a single-element array with shape (1,) + view_shape = (1,) + view_strides = (1,) + else: + view_shape = tuple(desc.shape[i] for i in local_dim_indices) + view_strides = tuple(desc.strides[i] for i in local_dim_indices) + view, _ = self.sdfg.add_view( + f"{field.node.data}_view", + view_shape, + desc.dtype, + strides=view_strides, + find_new_name=True, + ) + local_view_node = self.state.add_access(view) + self._add_input_data_edge(field.node, field.subset, local_view_node) + + return DataExpr(local_view_node, desc.dtype) + + else: + return field + def _construct_tasklet_result( self, dtype: dace.typeclass, src_node: dace.nodes.Tasklet, src_connector: str, + use_array: bool = False, ) -> DataExpr: temp_name = self.sdfg.temp_data_name() - self.sdfg.add_scalar(temp_name, dtype, transient=True) - data_type = dace_utils.as_scalar_type(str(dtype.as_numpy_dtype())) + if use_array: + # In some cases, such as result data with list-type annotation, we want + # that output data is represented as an array (single-element 1D array) + # in order to allow for composition of array shape in external memlets. + self.sdfg.add_array(temp_name, (1,), dtype, transient=True) + else: + self.sdfg.add_scalar(temp_name, dtype, transient=True) + data_type = dace_utils.as_itir_type(dtype) temp_node = self.state.add_access(temp_name) self._add_edge( src_node, @@ -412,6 +475,7 @@ def _visit_deref(self, node: gtir.FunCall) -> ValueExpr: def _visit_neighbors(self, node: gtir.FunCall) -> DataExpr: assert len(node.args) == 2 + assert isinstance(node.type, itir_ts.ListType) assert isinstance(node.args[0], gtir.OffsetLiteral) offset = node.args[0].value @@ -422,9 +486,6 @@ def _visit_neighbors(self, node: gtir.FunCall) -> DataExpr: it = self.visit(node.args[1]) assert isinstance(it, IteratorExpr) assert offset_provider.neighbor_axis in it.dimensions - neighbor_dim_index = it.dimensions.index(offset_provider.neighbor_axis) - assert offset_provider.neighbor_axis not in it.indices - assert offset_provider.origin_axis not in it.dimensions assert offset_provider.origin_axis in it.indices origin_index = it.indices[offset_provider.origin_axis] assert isinstance(origin_index, SymbolExpr) @@ -446,38 +507,24 @@ def _visit_neighbors(self, node: gtir.FunCall) -> DataExpr: # node). For the specific case of `neighbors` we need to nest the neighbors map # inside the field map and the memlets will traverse the external map and write # to the view nodes. The simplify pass will remove the redundant access nodes. - field_slice_view, field_slice_desc = self.sdfg.add_view( - f"{offset_provider.neighbor_axis.value}_view", - (field_desc.shape[neighbor_dim_index],), - field_desc.dtype, - strides=(field_desc.strides[neighbor_dim_index],), - find_new_name=True, - ) - field_slice_node = self.state.add_access(field_slice_view) - field_subset = ",".join( - it.indices[dim].value # type: ignore[union-attr] - if dim != offset_provider.neighbor_axis - else f"0:{size}" - for dim, size in zip(it.dimensions, field_desc.shape, strict=True) - ) - self._add_input_data_edge( - it.field, - sbs.Range.from_string(field_subset), - field_slice_node, - ) - - connectivity_slice_view, _ = self.sdfg.add_view( - "neighbors_view", - (offset_provider.max_neighbors,), - connectivity_desc.dtype, - strides=(connectivity_desc.strides[1],), - find_new_name=True, + field_slice = self._construct_local_view( + MemletExpr( + it.field, + sbs.Range.from_string( + ",".join( + it.indices[dim].value # type: ignore[union-attr] + if dim != offset_provider.neighbor_axis + else f"0:{size}" + for dim, size in zip(it.dimensions, field_desc.shape, strict=True) + ) + ), + ) ) - connectivity_slice_node = self.state.add_access(connectivity_slice_view) - self._add_input_data_edge( - self.state.add_access(connectivity), - sbs.Range.from_string(f"{origin_index.value}, 0:{offset_provider.max_neighbors}"), - connectivity_slice_node, + connectivity_slice = self._construct_local_view( + MemletExpr( + self.state.add_access(connectivity), + sbs.Range.from_string(f"{origin_index.value}, 0:{offset_provider.max_neighbors}"), + ) ) neighbors_temp, _ = self.sdfg.add_temp_transient( @@ -487,64 +534,135 @@ def _visit_neighbors(self, node: gtir.FunCall) -> DataExpr: offset_dim = gtx_common.Dimension(offset, kind=gtx_common.DimensionKind.LOCAL) neighbor_idx = dace_gtir_utils.get_map_variable(offset_dim) - me, mx = self._add_map( - f"{offset}_neighbors", - { - neighbor_idx: f"0:{offset_provider.max_neighbors}", - }, - ) + index_connector = "__index" + output_connector = "__val" + tasklet_expression = f"{output_connector} = __field[{index_connector}]" + input_memlets = { + "__field": self.sdfg.make_array_memlet(field_slice.node.data), + index_connector: dace.Memlet(data=connectivity_slice.node.data, subset=neighbor_idx), + } + input_nodes = { + field_slice.node.data: field_slice.node, + connectivity_slice.node.data: connectivity_slice.node, + } + if offset_provider.has_skip_values: assert self.reduce_identity is not None assert self.reduce_identity.dtype == field_desc.dtype - # TODO: Investigate if a NestedSDFG brings benefits - tasklet_node = self._add_tasklet( - "gather_neighbors_with_skip_values", - {"__field", index_connector}, - {"__val"}, - f"__val = __field[{index_connector}] if {index_connector} != {gtx_common._DEFAULT_SKIP_VALUE} else {self.reduce_identity.dtype}({self.reduce_identity.value})", - ) + tasklet_expression += f" if {index_connector} != {gtx_common._DEFAULT_SKIP_VALUE} else {field_desc.dtype}({self.reduce_identity.value})" + + self._add_mapped_tasklet( + name=f"{offset}_neighbors", + map_ranges={neighbor_idx: f"0:{offset_provider.max_neighbors}"}, + code=tasklet_expression, + inputs=input_memlets, + input_nodes=input_nodes, + outputs={ + output_connector: dace.Memlet(data=neighbors_temp, subset=neighbor_idx), + }, + output_nodes={neighbors_temp: neighbors_node}, + external_edges=True, + ) - else: - tasklet_node = self._add_tasklet( - "gather_neighbors", - {"__field", index_connector}, - {"__val"}, - f"__val = __field[{index_connector}]", - ) + return DataExpr(neighbors_node, node.type) - self.state.add_memlet_path( - field_slice_node, - me, - tasklet_node, - dst_conn="__field", - memlet=dace.Memlet.from_array(field_slice_view, field_slice_desc), - ) - self.state.add_memlet_path( - connectivity_slice_node, - me, - tasklet_node, - dst_conn=index_connector, - memlet=dace.Memlet(data=connectivity_slice_view, subset=neighbor_idx), - ) - self.state.add_memlet_path( - tasklet_node, - mx, - neighbors_node, - src_conn="__val", - memlet=dace.Memlet(data=neighbors_temp, subset=neighbor_idx), - ) + def _visit_map(self, node: gtir.FunCall) -> DataExpr: + """ + A map node defines an operation to be mapped on all elements of input arguments. + + The map operation is applied on the local dimension of input fields. + In the example below, the local dimension consists of a list of neighbor + values as the first argument, and a list of constant values `1.0`: + `map_(plus)(neighbors(V2E, it), make_const_list(1.0))` + + The `plus` operation is lowered to a tasklet inside a map that computes + the domain of the local dimension (in this example, max neighbors in V2E). + The result is a 1D local field, with same size as the input local dimension. + In above example, the result would be an array with size V2E.max_neighbors, + containing the V2E neighbor values incremented by 1.0. + """ assert isinstance(node.type, itir_ts.ListType) - return DataExpr(neighbors_node, node.type) + assert isinstance(node.fun, gtir.FunCall) + assert len(node.fun.args) == 1 # the operation to be mapped on the arguments + + assert isinstance(node.type.element_type, ts.ScalarType) + dtype = dace_utils.as_dace_type(node.type.element_type) + + input_args = [self.visit(arg) for arg in node.args] + input_connectors = [f"__arg{i}" for i in range(len(input_args))] + output_connector = "__out" + + # Here we build the body of the tasklet + fun_node = im.call(node.fun.args[0])(*input_connectors) + fun_python_code = gtir_python_codegen.get_source(fun_node) + tasklet_expression = f"{output_connector} = {fun_python_code}" + + # TODO(edopao): extract offset_dim from the input arguments + offset_dim = gtx_common.Dimension("", gtx_common.DimensionKind.LOCAL) + map_index = dace_gtir_utils.get_map_variable(offset_dim) + + # The dataflow we build in this class has some loose connections on input edges. + # These edges are described as set of nodes, that will have to be connected to + # external data source nodes passing through the map entry node of the field map. + # Similarly to `neighbors` expressions, the `map_` input edges terminate on view + # nodes (see `_construct_local_view` in the for-loop below), because it is simpler + # than representing map-to-map edges (which require memlets with 2 pass-nodes). + input_memlets = {} + input_nodes = {} + local_size: Optional[int] = None + for conn, input_expr in zip(input_connectors, input_args): + input_node = self._construct_local_view(input_expr).node + input_desc = input_node.desc(self.sdfg) + # we assume that there is a single local dimension + if len(input_desc.shape) != 1: + raise ValueError(f"More than one local dimension in map expression {node}.") + input_size = input_desc.shape[0] + if input_size == 1: + input_memlets[conn] = dace.Memlet(data=input_node.data, subset="0") + elif local_size is not None and input_size != local_size: + raise ValueError(f"Invalid node {node}") + else: + input_memlets[conn] = dace.Memlet(data=input_node.data, subset=map_index) + local_size = input_size + + input_nodes[input_node.data] = input_node + + if local_size is None: + # corner case where map is applied to 1-element lists + assert len(input_nodes) >= 1 + local_size = 1 + + out, _ = self.sdfg.add_temp_transient((local_size,), dtype) + out_node = self.state.add_access(out) + + self._add_mapped_tasklet( + name="map", + map_ranges={map_index: f"0:{local_size}"}, + code=tasklet_expression, + inputs=input_memlets, + input_nodes=input_nodes, + outputs={ + output_connector: dace.Memlet(data=out, subset=map_index), + }, + output_nodes={out: out_node}, + external_edges=True, + ) + + return DataExpr(out_node, dtype) def _visit_reduce(self, node: gtir.FunCall) -> DataExpr: + assert isinstance(node.type, ts.ScalarType) op_name, reduce_init, reduce_identity = get_reduce_params(node) - dtype = reduce_identity.dtype - # We store the value of reduce identity in the visitor context while visiting - # the input to reduction; this value will be use by the `neighbors` visitor - # to fill the skip values in the neighbors list. + # The input to reduction is a list of elements on a local dimension. + # This list is provided by an argument that typically calls the neighbors + # builtin function, to built a list of neighbor values for each element + # in the field target dimension. + # We store the value of reduce identity in the visitor context to have it + # available while visiting the input to reduction; this value might be used + # by the `neighbors` visitor to fill the skip values in the neighbors list. prev_reduce_identity = self.reduce_identity self.reduce_identity = reduce_identity @@ -585,7 +703,7 @@ def _visit_reduce(self, node: gtir.FunCall) -> DataExpr: ) temp_name = self.sdfg.temp_data_name() - self.sdfg.add_scalar(temp_name, dtype, transient=True) + self.sdfg.add_scalar(temp_name, reduce_identity.dtype, transient=True) temp_node = self.state.add_access(temp_name) self.state.add_nedge( @@ -593,7 +711,6 @@ def _visit_reduce(self, node: gtir.FunCall) -> DataExpr: temp_node, dace.Memlet(data=temp_name, subset="0"), ) - assert isinstance(node.type, ts.ScalarType) return DataExpr(temp_node, node.type) def _split_shift_args( @@ -816,9 +933,6 @@ def _visit_generic_builtin(self, node: gtir.FunCall) -> DataExpr: Generic handler called by `visit_FunCall()` when it encounters a builtin function that does not match any other specific handler. """ - assert isinstance(node.type, ts.ScalarType) - dtype = dace_utils.as_dace_type(node.type) - node_internals = [] node_connections: dict[str, MemletExpr | DataExpr] = {} for i, arg in enumerate(node.args): @@ -863,7 +977,27 @@ def _visit_generic_builtin(self, node: gtir.FunCall) -> DataExpr: connector, ) - return self._construct_tasklet_result(dtype, tasklet_node, "result") + if isinstance(node.type, itir_ts.ListType): + # The only builtin function (so far) handled here that returns a list + # is 'make_const_list'. There are other builtin functions (map_, neighbors) + # that return a list but they are handled in specialized visit methods. + # This method (the generic visitor for builtin functions) always returns + # a single value. This is also the case of 'make_const_list' expression: + # it simply broadcasts a scalar on the local domain of another expression, + # for example 'map_(plus)(neighbors(V2Eₒ, it), make_const_list(1.0))'. + # Therefore we handle `ListType` as a single-element array with shape (1,) + # that will be accessed in a map expression on a local domain. + assert isinstance(node.type.element_type, ts.ScalarType) + dtype = dace_utils.as_dace_type(node.type.element_type) + # In order to ease the lowring of the parent expression on local dimension, + # we represent the scalar value as a single-element 1D array. + use_array = True + else: + assert isinstance(node.type, ts.ScalarType) + dtype = dace_utils.as_dace_type(node.type) + use_array = False + + return self._construct_tasklet_result(dtype, tasklet_node, "result", use_array=use_array) def visit_FunCall(self, node: gtir.FunCall) -> IteratorExpr | ValueExpr: if cpm.is_call_to(node, "deref"): @@ -872,6 +1006,9 @@ def visit_FunCall(self, node: gtir.FunCall) -> IteratorExpr | ValueExpr: elif cpm.is_call_to(node, "neighbors"): return self._visit_neighbors(node) + elif cpm.is_applied_map(node): + return self._visit_map(node) + elif cpm.is_applied_reduce(node): return self._visit_reduce(node) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_python_codegen.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_python_codegen.py index f133a9224d..6aee33c56e 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_python_codegen.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_python_codegen.py @@ -75,6 +75,7 @@ def builtin_cast(*args: Any) -> str: val, target_type = args + assert target_type in gtir.TYPEBUILTINS return MATH_BUILTINS_MAPPING[target_type].format(val) @@ -83,9 +84,19 @@ def builtin_if(*args: Any) -> str: return f"{true_val} if {cond} else {false_val}" +def make_const_list(arg: str) -> str: + """ + Takes a single scalar argument and broadcasts this value on the local dimension + of map expression. In a dataflow, we represent it as a tasklet that writes + a value to a scalar node. + """ + return arg + + GENERAL_BUILTIN_MAPPING: dict[str, Callable[[Any], str]] = { "cast_": builtin_cast, "if_": builtin_if, + "make_const_list": make_const_list, } diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py index 09d5d6c0d0..d79d887318 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py @@ -77,6 +77,21 @@ def add_tasklet( unique_name = self.unique_tasklet_name(name) return state.add_tasklet(unique_name, inputs, outputs, code, **kwargs) + def add_mapped_tasklet( + self, + name: str, + state: dace.SDFGState, + map_ranges: Dict[str, str | dace.subsets.Subset] + | List[Tuple[str, str | dace.subsets.Subset]], + inputs: Union[Set[str], Dict[str, dace.dtypes.typeclass]], + code: str, + outputs: Union[Set[str], Dict[str, dace.dtypes.typeclass]], + **kwargs: Any, + ) -> tuple[dace.nodes.Tasklet, dace.nodes.MapEntry, dace.nodes.MapExit]: + """Wrapper of `dace.SDFGState.add_mapped_tasklet` that assigns unique name.""" + unique_name = self.unique_tasklet_name(name) + return state.add_mapped_tasklet(unique_name, map_ranges, inputs, code, outputs, **kwargs) + class SDFGBuilder(DataflowBuilder, Protocol): """Visitor interface available to GTIR-primitive translators.""" @@ -111,7 +126,7 @@ class GTIRToSDFG(eve.NodeVisitor, SDFGBuilder): map_uids: eve.utils.UIDGenerator = dataclasses.field( init=False, repr=False, default_factory=lambda: eve.utils.UIDGenerator(prefix="map") ) - tesklet_uids: eve.utils.UIDGenerator = dataclasses.field( + tasklet_uids: eve.utils.UIDGenerator = dataclasses.field( init=False, repr=False, default_factory=lambda: eve.utils.UIDGenerator(prefix="tlet") ) @@ -125,7 +140,7 @@ def unique_map_name(self, name: str) -> str: return f"{self.map_uids.sequential_id()}_{name}" def unique_tasklet_name(self, name: str) -> str: - return f"{self.tesklet_uids.sequential_id()}_{name}" + return f"{self.tasklet_uids.sequential_id()}_{name}" def _make_array_shape_and_strides( self, name: str, dims: Sequence[gtx_common.Dimension] @@ -353,7 +368,9 @@ def visit_SetAt( target_fields = self._visit_expression(stmt.target, sdfg, state, use_temp=False) # convert domain expression to dictionary to ease access to dimension boundaries - domain = dace_gtir_utils.get_domain_ranges(stmt.domain) + domain = { + dim: (lb, ub) for dim, lb, ub in gtir_builtin_translators.extract_domain(stmt.domain) + } expr_input_args = { sym_id @@ -422,7 +439,7 @@ def visit_FunCall( node, sdfg, head_state, self, reduce_identity ) elif cpm.is_applied_as_fieldop(node): - return gtir_builtin_translators.translate_as_field_op( + return gtir_builtin_translators.translate_as_fieldop( node, sdfg, head_state, self, reduce_identity ) elif isinstance(node.fun, gtir.Lambda): diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py b/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py index 2988b01a61..855dc9c91a 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py @@ -11,61 +11,19 @@ import itertools from typing import Any -import dace - from gt4py import eve from gt4py.next import common as gtx_common from gt4py.next.iterator import ir as gtir from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm -from gt4py.next.program_processors.runners.dace_fieldview import gtir_python_codegen from gt4py.next.type_system import type_specifications as ts -def get_domain( - node: gtir.Expr, -) -> list[tuple[gtx_common.Dimension, dace.symbolic.SymbolicType, dace.symbolic.SymbolicType]]: - """ - Specialized visit method for domain expressions. - - Returns for each domain dimension the corresponding range. - - TODO: Domain expressions will be recurrent in the GTIR program. An interesting idea - would be to cache the results of lowering here (e.g. using `functools.lru_cache`) - """ - assert cpm.is_call_to(node, ("cartesian_domain", "unstructured_domain")) - - domain = [] - for named_range in node.args: - assert cpm.is_call_to(named_range, "named_range") - assert len(named_range.args) == 3 - axis = named_range.args[0] - assert isinstance(axis, gtir.AxisLiteral) - bounds = [ - dace.symbolic.pystr_to_symbolic(gtir_python_codegen.get_source(arg)) - for arg in named_range.args[1:3] - ] - dim = gtx_common.Dimension(axis.value, axis.kind) - domain.append((dim, bounds[0], bounds[1])) - - return domain - - -def get_domain_ranges( - node: gtir.Expr, -) -> dict[gtx_common.Dimension, tuple[dace.symbolic.SymbolicType, dace.symbolic.SymbolicType]]: - """ - Returns domain represented in dictionary form. - """ - domain = get_domain(node) - - return {dim: (lb, ub) for dim, lb, ub in domain} - - def get_map_variable(dim: gtx_common.Dimension) -> str: """ Format map variable name based on the naming convention for application-specific SDFG transformations. """ suffix = "dim" if dim.kind == gtx_common.DimensionKind.LOCAL else "" + # TODO(edopao): raise exception if dim.value is empty return f"i_{dim.value}_gtx_{dim.kind}{suffix}" @@ -140,7 +98,7 @@ def visit_FunCall(self, node: gtir.FunCall) -> gtir.Node: ) node.args = [] - node.args = [self.visit(arg) for arg in node.args] + node.args = self.visit(node.args) node.fun = self.visit(node.fun) return node diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py index e819cdcd8c..98e15dac3c 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py @@ -1323,63 +1323,86 @@ def test_gtir_reduce_with_skip_values(): def test_gtir_reduce_dot_product(): - # FIXME[#1582](edopao): Enable testcase when type inference is working - pytest.skip("Field of lists not fully supported as a type in GTIR yet") init_value = np.random.rand() vertex_domain = im.domain(gtx_common.GridType.UNSTRUCTURED, ranges={Vertex: (0, "nvertices")}) - testee = gtir.Program( - id="reduce_dot_product", - function_definitions=[], - params=[ - gtir.Sym(id="edges", type=EFTYPE), - gtir.Sym(id="vertices", type=VFTYPE), - gtir.Sym(id="nvertices", type=SIZE_TYPE), - ], - declarations=[], - body=[ - gtir.SetAt( - expr=im.call( - im.call("as_fieldop")( - im.lambda_("it")( - im.call(im.call("reduce")("plus", im.literal_from_value(init_value)))( - im.deref("it") - ) - ), - vertex_domain, - ) - )( - im.op_as_fieldop("multiplies", vertex_domain)( - im.as_fieldop_neighbors("V2E", "edges", vertex_domain), - im.as_fieldop_neighbors("V2E", "edges", vertex_domain), - ), - ), - domain=vertex_domain, - target=gtir.SymRef(id="vertices"), - ) - ], - ) - connectivity_V2E = SIMPLE_MESH_OFFSET_PROVIDER["V2E"] assert isinstance(connectivity_V2E, gtx_common.NeighborTable) - sdfg = dace_backend.build_sdfg_from_gtir(testee, SIMPLE_MESH_OFFSET_PROVIDER) - e = np.random.rand(SIMPLE_MESH.num_edges) v = np.empty(SIMPLE_MESH.num_vertices, dtype=e.dtype) v_ref = [ - reduce(lambda x, y: x + y, e[v2e_neighbors] * e[v2e_neighbors], init_value) + functools.reduce( + lambda x, y: x + y, (e[v2e_neighbors] * e[v2e_neighbors]) + 1.0, init_value + ) for v2e_neighbors in connectivity_V2E.table ] - sdfg( - e, - v, - connectivity_V2E=connectivity_V2E.table, - **FSYMBOLS, - **make_mesh_symbols(SIMPLE_MESH), + stencil_inlined = im.call( + im.call("as_fieldop")( + im.lambda_("it")( + im.call(im.call("reduce")("plus", im.literal_from_value(init_value)))( + im.map_("plus")( + im.map_("multiplies")( + im.neighbors("V2E", "it"), + im.neighbors("V2E", "it"), + ), + im.call("make_const_list")(1.0), + ) + ) + ), + vertex_domain, + ) + )("edges") + + stencil_fieldview = im.call( + im.call("as_fieldop")( + im.lambda_("it")( + im.call(im.call("reduce")("plus", im.literal_from_value(init_value)))( + im.deref("it") + ) + ), + vertex_domain, + ) + )( + im.op_as_fieldop(im.map_("plus"), vertex_domain)( + im.op_as_fieldop(im.map_("multiplies"), vertex_domain)( + im.as_fieldop_neighbors("V2E", "edges", vertex_domain), + im.as_fieldop_neighbors("V2E", "edges", vertex_domain), + ), + im.op_as_fieldop("make_const_list", vertex_domain)(1.0), + ) ) - assert np.allclose(v, v_ref) + + for i, stencil in enumerate([stencil_inlined, stencil_fieldview]): + testee = gtir.Program( + id=f"reduce_dot_product_{i}", + function_definitions=[], + params=[ + gtir.Sym(id="edges", type=EFTYPE), + gtir.Sym(id="vertices", type=VFTYPE), + gtir.Sym(id="nvertices", type=SIZE_TYPE), + ], + declarations=[], + body=[ + gtir.SetAt( + expr=stencil, + domain=vertex_domain, + target=gtir.SymRef(id="vertices"), + ) + ], + ) + + sdfg = dace_backend.build_sdfg_from_gtir(testee, SIMPLE_MESH_OFFSET_PROVIDER) + + sdfg( + e, + v, + connectivity_V2E=connectivity_V2E.table, + **FSYMBOLS, + **make_mesh_symbols(SIMPLE_MESH), + ) + assert np.allclose(v, v_ref) def test_gtir_reduce_with_cond_neighbors():