diff --git a/ci/cscs-ci.yml b/ci/cscs-ci.yml index b8902f37f7..243716c459 100644 --- a/ci/cscs-ci.yml +++ b/ci/cscs-ci.yml @@ -52,8 +52,6 @@ stages: variables: CUDA_VERSION: 12.4.1 CUPY_PACKAGE: cupy-cuda12x - # TODO: re-enable CI job when Todi is back in operational state - when: manual build_py311_baseimage_x86_64: extends: .build_baseimage_x86_64 @@ -176,9 +174,6 @@ build_py38_image_x86_64: - SUBPACKAGE: next VARIANT: [-nomesh, -atlas] SUBVARIANT: [-cuda12x, -cpu] - before_script: - # TODO: remove start of CUDA MPS daemon once CI-CD can handle CRAY_CUDA_MPS - - CUDA_MPS_PIPE_DIRECTORY="/tmp/nvidia-mps" nvidia-cuda-mps-control -d variables: # Grace-Hopper gpu architecture is not enabled by default in CUDA build CUDAARCHS: "90" diff --git a/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py b/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py index e3dac7a578..e01d6ea51f 100644 --- a/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py +++ b/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py @@ -27,6 +27,16 @@ def is_applied_lift(arg: itir.Node) -> TypeGuard[itir.FunCall]: ) +def is_applied_reduce(arg: itir.Node) -> TypeGuard[itir.FunCall]: + """Match expressions of the form `reduce(λ(...) → ...)(...)`.""" + return ( + isinstance(arg, itir.FunCall) + and isinstance(arg.fun, itir.FunCall) + and isinstance(arg.fun.fun, itir.SymRef) + and arg.fun.fun.id == "reduce" + ) + + def is_applied_shift(arg: itir.Node) -> TypeGuard[itir.FunCall]: """Match expressions of the form `shift(λ(...) → ...)(...)`.""" return ( diff --git a/src/gt4py/next/iterator/ir_utils/ir_makers.py b/src/gt4py/next/iterator/ir_utils/ir_makers.py index b1c9ec409f..c97d4c0e56 100644 --- a/src/gt4py/next/iterator/ir_utils/ir_makers.py +++ b/src/gt4py/next/iterator/ir_utils/ir_makers.py @@ -423,3 +423,32 @@ def as_fieldop(expr: itir.Expr, domain: Optional[itir.FunCall] = None) -> call: ) ) ) + + +def op_as_fieldop( + op: str | itir.SymRef | Callable, domain: Optional[itir.FunCall] = None +) -> Callable[..., itir.FunCall]: + """ + Promotes a function `op` to a field_operator. + + Args: + op: a function from values to value. + domain: the domain of the returned field. + + Returns: + A function from Fields to Field. 
+ + Examples: + >>> str(op_as_fieldop("op")("a", "b")) + '(⇑(λ(__arg0, __arg1) → op(·__arg0, ·__arg1)))(a, b)' + """ + if isinstance(op, (str, itir.SymRef, itir.Lambda)): + op = call(op) + + def _impl(*its: itir.Expr) -> itir.FunCall: + args = [ + f"__arg{i}" for i in range(len(its)) + ] # TODO: `op` must not contain `SymRef(id="__argX")` + return as_fieldop(lambda_(*args)(op(*[deref(arg) for arg in args])), domain)(*its) + + return _impl diff --git a/src/gt4py/next/iterator/transforms/fuse_maps.py b/src/gt4py/next/iterator/transforms/fuse_maps.py index c10cb6f3e7..a8089e521e 100644 --- a/src/gt4py/next/iterator/transforms/fuse_maps.py +++ b/src/gt4py/next/iterator/transforms/fuse_maps.py @@ -18,6 +18,7 @@ from gt4py.eve import NodeTranslator, traits from gt4py.eve.utils import UIDGenerator from gt4py.next.iterator import ir +from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm from gt4py.next.iterator.transforms import inline_lambdas @@ -29,14 +30,6 @@ def _is_map(node: ir.Node) -> TypeGuard[ir.FunCall]: ) -def _is_reduce(node: ir.Node) -> TypeGuard[ir.FunCall]: - return ( - isinstance(node, ir.FunCall) - and isinstance(node.fun, ir.FunCall) - and node.fun.fun == ir.SymRef(id="reduce") - ) - - @dataclasses.dataclass(frozen=True) class FuseMaps(traits.PreserveLocationVisitor, traits.VisitorWithSymbolTableTrait, NodeTranslator): """ @@ -71,7 +64,7 @@ def _as_lambda(self, fun: ir.SymRef | ir.Lambda, param_count: int) -> ir.Lambda: def visit_FunCall(self, node: ir.FunCall, **kwargs): node = self.generic_visit(node) - if _is_map(node) or _is_reduce(node): + if _is_map(node) or cpm.is_applied_reduce(node): if any(_is_map(arg) for arg in node.args): first_param = ( 0 if _is_map(node) else 1 @@ -83,7 +76,7 @@ def visit_FunCall(self, node: ir.FunCall, **kwargs): inlined_args = [] new_params = [] new_args = [] - if _is_reduce(node): + if cpm.is_applied_reduce(node): # param corresponding to reduce acc inlined_args.append(ir.SymRef(id=outer_op.params[0].id)) new_params.append(outer_op.params[0]) @@ -119,7 +112,7 @@ def visit_FunCall(self, node: ir.FunCall, **kwargs): return ir.FunCall( fun=ir.FunCall(fun=ir.SymRef(id="map_"), args=[new_op]), args=new_args ) - else: # _is_reduce(node) + else: # is_applied_reduce(node) return ir.FunCall( fun=ir.FunCall(fun=ir.SymRef(id="reduce"), args=[new_op, node.fun.args[1]]), args=new_args, diff --git a/src/gt4py/next/iterator/transforms/unroll_reduce.py b/src/gt4py/next/iterator/transforms/unroll_reduce.py index 47b8556c4e..75cde58723 100644 --- a/src/gt4py/next/iterator/transforms/unroll_reduce.py +++ b/src/gt4py/next/iterator/transforms/unroll_reduce.py @@ -20,6 +20,7 @@ from gt4py.eve.utils import UIDGenerator from gt4py.next import common from gt4py.next.iterator import ir as itir +from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm from gt4py.next.iterator.ir_utils.common_pattern_matcher import is_applied_lift @@ -60,16 +61,12 @@ def _get_partial_offset_tags(reduce_args: Iterable[itir.Expr]) -> Iterable[str]: return [_get_partial_offset_tag(arg) for arg in _get_neighbors_args(reduce_args)] -def _is_reduce(node: itir.FunCall) -> TypeGuard[itir.FunCall]: - return isinstance(node.fun, itir.FunCall) and node.fun.fun == itir.SymRef(id="reduce") - - def _get_connectivity( applied_reduce_node: itir.FunCall, offset_provider: dict[str, common.Dimension | common.Connectivity], ) -> common.Connectivity: """Return single connectivity that is compatible with the arguments of the reduce.""" - if not _is_reduce(applied_reduce_node): + 
if not cpm.is_applied_reduce(applied_reduce_node): raise ValueError("Expected a call to a 'reduce' object, i.e. 'reduce(...)(...)'.") connectivities: list[common.Connectivity] = [] @@ -158,6 +155,6 @@ def _visit_reduce(self, node: itir.FunCall, **kwargs) -> itir.Expr: def visit_FunCall(self, node: itir.FunCall, **kwargs) -> itir.Expr: node = self.generic_visit(node, **kwargs) - if _is_reduce(node): + if cpm.is_applied_reduce(node): return self._visit_reduce(node, **kwargs) return node diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py index 0ce9952a51..e8dff30a1a 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py @@ -24,7 +24,7 @@ from gt4py.next import common as gtx_common from gt4py.next.iterator import ir as gtir from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm -from gt4py.next.iterator.type_system import type_specifications as itir_ts +from gt4py.next.iterator.type_system import type_specifications as gtir_ts from gt4py.next.program_processors.runners.dace_fieldview import ( gtir_python_codegen, gtir_to_tasklet, @@ -38,6 +38,7 @@ IteratorIndexDType: TypeAlias = dace.int32 # type of iterator indexes +LetSymbol: TypeAlias = tuple[gtir.Literal | gtir.SymRef, ts.FieldType | ts.ScalarType] TemporaryData: TypeAlias = tuple[dace.nodes.Node, ts.FieldType | ts.ScalarType] @@ -49,6 +50,7 @@ def __call__( sdfg: dace.SDFG, state: dace.SDFGState, sdfg_builder: gtir_to_sdfg.SDFGBuilder, + let_symbols: dict[str, LetSymbol], ) -> list[TemporaryData]: """Creates the dataflow subgraph representing a GTIR primitive function. @@ -60,6 +62,9 @@ def __call__( sdfg: The SDFG where the primitive subgraph should be instantiated state: The SDFG state where the result of the primitive function should be made available sdfg_builder: The object responsible for visiting child nodes of the primitive node. + let_symbols: Mapping of symbols (i.e. lambda parameters and/or local constants + like the identity value in a reduction context) to temporary fields + or symbolic expressions. 
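For concreteness, a small sketch of what a `let_symbols` mapping can contain, following the `LetSymbol` alias introduced above. All names (`x1`, `__tmp_1`) are hypothetical; lambda parameters point to temporary data containers via `gtir.SymRef`, while constants such as the reduce identity are stored as `gtir.Literal`:

```python
# Hypothetical example of a `let_symbols` mapping; names are illustrative only.
from gt4py.next import common as gtx_common
from gt4py.next.iterator import ir as gtir
from gt4py.next.type_system import type_specifications as ts

float64 = ts.ScalarType(kind=ts.ScalarKind.FLOAT64)
IDim = gtx_common.Dimension("IDim")

let_symbols = {
    # let-parameter "x1" bound to the transient field "__tmp_1" produced in a previous state
    "x1": (gtir.SymRef(id="__tmp_1"), ts.FieldType(dims=[IDim], dtype=float64)),
    # reserved key used to propagate the reduce identity value to nested translations
    "reduce": (gtir.Literal(value="0.0", type=float64), float64),
}
```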
Returns: A list of data access nodes and the associated GT4Py data type, which provide @@ -77,8 +82,14 @@ def _parse_arg_expr( domain: list[ tuple[gtx_common.Dimension, dace.symbolic.SymbolicType, dace.symbolic.SymbolicType] ], + let_symbols: dict[str, LetSymbol], ) -> gtir_to_tasklet.IteratorExpr | gtir_to_tasklet.MemletExpr: - fields: list[TemporaryData] = sdfg_builder.visit(node, sdfg=sdfg, head_state=state) + fields: list[TemporaryData] = sdfg_builder.visit( + node, + sdfg=sdfg, + head_state=state, + let_symbols=let_symbols, + ) assert len(fields) == 1 data_node, arg_type = fields[0] @@ -96,11 +107,12 @@ def _parse_arg_expr( ) for dim, _, _ in domain } - return gtir_to_tasklet.IteratorExpr( - data_node, - arg_type.dims, - indices, + dims = arg_type.dims + ( + # we add an extra anonymous dimension in the iterator definition to enable + # dereferencing elements in `ListType` + [gtx_common.Dimension("")] if isinstance(arg_type.dtype, gtir_ts.ListType) else [] ) + return gtir_to_tasklet.IteratorExpr(data_node, dims, indices) def _create_temporary_field( @@ -125,27 +137,20 @@ def _create_temporary_field( field_offset = [-lb for lb in domain_lbs] if isinstance(output_desc, dace.data.Array): - # extend the result arrays with the local dimensions added by the field operator e.g. `neighbors`) - assert isinstance(output_field_type, ts.FieldType) - if isinstance(node_type.dtype, itir_ts.ListType): - raise NotImplementedError - else: - field_dtype = node_type.dtype - assert output_field_type.dtype == field_dtype - field_dims.extend(output_field_type.dims) + assert isinstance(node_type.dtype, gtir_ts.ListType) + field_dtype = node_type.dtype.element_type + # extend the result arrays with the local dimensions added by the field operator (e.g. `neighbors`) field_shape.extend(output_desc.shape) else: assert isinstance(output_desc, dace.data.Scalar) - assert isinstance(output_field_type, ts.ScalarType) field_dtype = node_type.dtype - assert output_field_type == field_dtype # allocate local temporary storage for the result field temp_name, _ = sdfg.add_temp_transient( field_shape, dace_fieldview_util.as_dace_type(field_dtype), offset=field_offset ) field_node = state.add_access(temp_name) - field_type = ts.FieldType(field_dims, field_dtype) + field_type = ts.FieldType(field_dims, node_type.dtype) return field_node, field_type @@ -155,10 +160,12 @@ def translate_as_field_op( sdfg: dace.SDFG, state: dace.SDFGState, sdfg_builder: gtir_to_sdfg.SDFGBuilder, + let_symbols: dict[str, LetSymbol], ) -> list[TemporaryData]: - """Generates the dataflow subgraph for the `as_field_op` builtin function.""" + """Generates the dataflow subgraph for the `as_fieldop` builtin function.""" assert isinstance(node, gtir.FunCall) assert cpm.is_call_to(node.fun, "as_fieldop") + assert isinstance(node.type, ts.FieldType) fun_node = node.fun assert len(fun_node.args) == 2 @@ -172,11 +179,40 @@ def translate_as_field_op( domain = dace_fieldview_util.get_domain(domain_expr) assert isinstance(node.type, ts.FieldType) + reduce_identity: Optional[gtir_to_tasklet.SymbolExpr] = None + if cpm.is_applied_reduce(stencil_expr.expr): + # 'reduce' is a reserved keyword of the DSL and we will never find a user-defined symbol + # with this name. Since 'reduce' will never collide with a user-defined symbol, it is safe + # to use it internally to store the reduce identity value as a let-symbol. 
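To make the init/identity distinction concrete: `get_reduce_params` (added in gtir_to_tasklet.py below) derives the identity via `dace.dtypes.reduction_identity`, while the init value comes from the `reduce(op, init)` literal in the IR. A hedged plain-Python sketch, with made-up numeric values, of why only the identity is safe to substitute for skipped neighbors:

```python
# Sketch: the init value seeds the accumulator once, the identity is substituted
# for every skipped neighbor, so identity contributions must vanish in the reduction.
import dace

dtype = dace.float64
identity = dace.dtypes.reduction_identity(dtype, dace.dtypes.ReductionType.Sum)  # 0 for `plus`
init = 1.0                  # hypothetical init literal carried by `reduce(plus, init)`
values = [2.0, 3.0]         # valid neighbor values
skipped = [identity] * 2    # two skip values replaced by the identity

result = init + sum(values) + sum(skipped)
assert result == init + sum(values)  # the identity terms do not change the result
```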
+ if "reduce" in let_symbols: + raise NotImplementedError("nested reductions not supported.") + + # the reduce identity value is used to fill the skip values in neighbors list + _, _, reduce_identity = gtir_to_tasklet.get_reduce_params(stencil_expr.expr) + + # we store the reduce identity value as a constant let-symbol + let_symbols = let_symbols | { + "reduce": ( + gtir.Literal(value=str(reduce_identity.value), type=stencil_expr.expr.type), + reduce_identity.dtype, + ) + } + + elif "reduce" in let_symbols: + # a parent node is a reduction node, so we are visiting the current node in the context of a reduction + reduce_symbol, _ = let_symbols["reduce"] + assert isinstance(reduce_symbol, gtir.Literal) + reduce_identity = gtir_to_tasklet.SymbolExpr( + reduce_symbol.value, dace_fieldview_util.as_dace_type(reduce_symbol.type) + ) + # first visit the list of arguments and build a symbol map - stencil_args = [_parse_arg_expr(arg, sdfg, state, sdfg_builder, domain) for arg in node.args] + stencil_args = [ + _parse_arg_expr(arg, sdfg, state, sdfg_builder, domain, let_symbols) for arg in node.args + ] # represent the field operator as a mapped tasklet graph, which will range over the field domain - taskgen = gtir_to_tasklet.LambdaToTasklet(sdfg, state, sdfg_builder) + taskgen = gtir_to_tasklet.LambdaToTasklet(sdfg, state, sdfg_builder, reduce_identity) input_connections, output_expr = taskgen.visit(stencil_expr, args=stencil_args) assert isinstance(output_expr, gtir_to_tasklet.ValueExpr) output_desc = output_expr.node.desc(sdfg) @@ -193,7 +229,7 @@ def translate_as_field_op( # allocate local temporary storage for the result field field_node, field_type = _create_temporary_field( - sdfg, state, domain, node.type, output_desc, output_expr.field_type + sdfg, state, domain, node.type, output_desc, output_expr.dtype ) # assume tasklet with single output @@ -236,6 +272,7 @@ def translate_cond( sdfg: dace.SDFG, state: dace.SDFGState, sdfg_builder: gtir_to_sdfg.SDFGBuilder, + let_symbols: dict[str, LetSymbol], ) -> list[TemporaryData]: """Generates the dataflow subgraph for the `cond` builtin function.""" assert cpm.is_call_to(node, "cond") @@ -273,8 +310,18 @@ def translate_cond( sdfg.add_edge(cond_state, false_state, dace.InterstateEdge(condition=(f"not bool({cond})"))) sdfg.add_edge(false_state, state, dace.InterstateEdge()) - true_br_args = sdfg_builder.visit(true_expr, sdfg=sdfg, head_state=true_state) - false_br_args = sdfg_builder.visit(false_expr, sdfg=sdfg, head_state=false_state) + true_br_args = sdfg_builder.visit( + true_expr, + sdfg=sdfg, + head_state=true_state, + let_symbols=let_symbols, + ) + false_br_args = sdfg_builder.visit( + false_expr, + sdfg=sdfg, + head_state=false_state, + let_symbols=let_symbols, + ) output_nodes = [] for true_br, false_br in zip(true_br_args, false_br_args, strict=True): @@ -304,55 +351,89 @@ def translate_cond( return output_nodes +def _get_symbolic_value( + sdfg: dace.SDFG, + state: dace.SDFGState, + sdfg_builder: gtir_to_sdfg.SDFGBuilder, + symbolic_expr: dace.symbolic.SymExpr, + scalar_type: ts.ScalarType, + temp_name: Optional[str] = None, +) -> dace.nodes.AccessNode: + tasklet_node = sdfg_builder.add_tasklet( + "get_value", + state, + {}, + {"__out"}, + f"__out = {symbolic_expr}", + ) + temp_name, _ = sdfg.add_scalar( + f"__{temp_name or 'tmp'}", + dace_fieldview_util.as_dace_type(scalar_type), + find_new_name=True, + transient=True, + ) + data_node = state.add_access(temp_name) + state.add_edge( + tasklet_node, + "__out", + data_node, + None, + 
dace.Memlet(data=temp_name, subset="0"), + ) + return data_node + + +def translate_literal( + node: gtir.Node, + sdfg: dace.SDFG, + state: dace.SDFGState, + sdfg_builder: gtir_to_sdfg.SDFGBuilder, + let_symbols: dict[str, LetSymbol], +) -> list[TemporaryData]: + """Generates the dataflow subgraph for a `ir.Literal` node.""" + assert isinstance(node, gtir.Literal) + + data_type = node.type + data_node = _get_symbolic_value(sdfg, state, sdfg_builder, node.value, data_type) + + return [(data_node, data_type)] + + def translate_symbol_ref( node: gtir.Node, sdfg: dace.SDFG, state: dace.SDFGState, sdfg_builder: gtir_to_sdfg.SDFGBuilder, + let_symbols: dict[str, LetSymbol], ) -> list[TemporaryData]: """Generates the dataflow subgraph for a `ir.SymRef` node.""" - assert isinstance(node, (gtir.Literal, gtir.SymRef)) - - data_type: ts.FieldType | ts.ScalarType - if isinstance(node, gtir.Literal): - sym_value = node.value - data_type = node.type - temp_name = "literal" + assert isinstance(node, gtir.SymRef) + + sym_value = str(node.id) + if sym_value in let_symbols: + let_node, sym_type = let_symbols[sym_value] + if isinstance(let_node, gtir.Literal): + # this branch handles the case where a let-symbol is mapped to some constant value + return sdfg_builder.visit(let_node) + # The `let_symbols` dictionary maps a `gtir.SymRef` string to a temporary + # data container. These symbols are visited and initialized in a state + # that precedes the current state, therefore a new access node needs to + # be created in the state where they are accessed. + sym_value = str(let_node.id) else: - sym_value = str(node.id) - data_type = sdfg_builder.get_symbol_type(sym_value) - temp_name = sym_value + sym_type = sdfg_builder.get_symbol_type(sym_value) - if isinstance(data_type, ts.FieldType): - # add access node to current state + # Create a new access node in the current state. It is possible that multiple + # access nodes are created in one state for the same data container. + # We rely on the dace simplify pass to remove duplicated access nodes. 
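As an aside, the scalar path handled by `_get_symbolic_value` above boils down to the following standalone DaCe pattern: a tasklet evaluates a symbolic expression and writes it into a transient scalar, whose access node is then available to downstream memlets. This is only a minimal sketch; the symbol `N` and the name `__tmp` are illustrative:

```python
# Standalone sketch of the "tasklet writes a symbolic value into a transient scalar" pattern.
import dace

sdfg = dace.SDFG("symbolic_value_sketch")
sdfg.add_symbol("N", dace.int32)
state = sdfg.add_state("main")

tasklet = state.add_tasklet("get_value", {}, {"__out"}, "__out = N + 1")
tmp_name, _ = sdfg.add_scalar("__tmp", dace.int32, transient=True, find_new_name=True)
tmp_node = state.add_access(tmp_name)
state.add_edge(tasklet, "__out", tmp_node, None, dace.Memlet(data=tmp_name, subset="0"))

sdfg.validate()
```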
+ if isinstance(sym_type, ts.FieldType): sym_node = state.add_access(sym_value) - else: - # scalar symbols are passed to the SDFG as symbols: build tasklet node - # to write the symbol to a scalar access node - tasklet_node = sdfg_builder.add_tasklet( - f"get_{temp_name}", - state, - {}, - {"__out"}, - f"__out = {sym_value}", - ) - temp_name, _ = sdfg.add_scalar( - f"__{temp_name}", - dace_fieldview_util.as_dace_type(data_type), - find_new_name=True, - transient=True, - ) - sym_node = state.add_access(temp_name) - state.add_edge( - tasklet_node, - "__out", - sym_node, - None, - dace.Memlet(data=sym_node.data, subset="0"), + sym_node = _get_symbolic_value( + sdfg, state, sdfg_builder, sym_value, sym_type, temp_name=sym_value ) - return [(sym_node, data_type)] + return [(sym_node, sym_type)] if TYPE_CHECKING: diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py index b0cf2215cf..53dca9e689 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py @@ -108,7 +108,7 @@ class GTIRToSDFG(eve.NodeVisitor, SDFGBuilder): """ offset_provider: dict[str, gtx_common.Connectivity | gtx_common.Dimension] - symbol_types: dict[str, ts.FieldType | ts.ScalarType] = dataclasses.field( + global_symbols: dict[str, ts.FieldType | ts.ScalarType] = dataclasses.field( default_factory=lambda: {} ) map_uids: eve.utils.UIDGenerator = dataclasses.field( @@ -119,12 +119,10 @@ class GTIRToSDFG(eve.NodeVisitor, SDFGBuilder): ) def get_offset_provider(self, offset: str) -> gtx_common.Connectivity | gtx_common.Dimension: - assert offset in self.offset_provider return self.offset_provider[offset] def get_symbol_type(self, symbol_name: str) -> ts.FieldType | ts.ScalarType: - assert symbol_name in self.symbol_types - return self.symbol_types[symbol_name] + return self.global_symbols[symbol_name] def unique_map_name(self, name: str) -> str: return f"{self.map_uids.sequential_id()}_{name}" @@ -184,7 +182,7 @@ def _add_storage( # TODO: unclear why mypy complains about incompatible types assert isinstance(symbol_type, (ts.FieldType, ts.ScalarType)) - self.symbol_types[name] = symbol_type + self.global_symbols[name] = symbol_type def _add_storage_for_temporary(self, temp_decl: gtir.Temporary) -> dict[str, str]: """ @@ -210,7 +208,7 @@ def _visit_expression( to have the same memory layout as the target array. 
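For orientation, a hedged sketch of how this SDFG builder is typically driven, condensed from the unit tests further below. `IDim`, `IFTYPE`, `N` mirror the test fixtures; the size/stride symbol names in the commented call follow the tests' convention and are an assumption here:

```python
# Sketch: lower a tiny GTIR program with the new `op_as_fieldop` helper to an SDFG.
from gt4py.next import common as gtx_common
from gt4py.next.iterator import ir as gtir
from gt4py.next.iterator.ir_utils import ir_makers as im
from gt4py.next.program_processors.runners import dace_fieldview as dace_backend
from gt4py.next.type_system import type_specifications as ts

IDim = gtx_common.Dimension("IDim")
IFTYPE = ts.FieldType(dims=[IDim], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64))
SIZE_TYPE = ts.ScalarType(ts.ScalarKind.INT32)

domain = im.call("cartesian_domain")(
    im.call("named_range")(gtir.AxisLiteral(value=IDim.value), 0, "size")
)
testee = gtir.Program(
    id="sum_example",
    function_definitions=[],
    params=[
        gtir.Sym(id="x", type=IFTYPE),
        gtir.Sym(id="y", type=IFTYPE),
        gtir.Sym(id="z", type=IFTYPE),
        gtir.Sym(id="size", type=SIZE_TYPE),
    ],
    declarations=[],
    body=[
        gtir.SetAt(
            expr=im.op_as_fieldop("plus", domain)("x", "y"),
            domain=domain,
            target=gtir.SymRef(id="z"),
        )
    ],
)

sdfg = dace_backend.build_sdfg_from_gtir(testee, {"IDim": IDim})
sdfg.validate()
# The SDFG can then be called as in the tests, passing the field buffers plus the
# size/stride symbols (see FSYMBOLS there), e.g.:
#   sdfg(x=a, y=b, z=c, size=N, __x_size_0=N, __x_stride_0=1, ...)
```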
""" results: list[gtir_builtin_translators.TemporaryData] = self.visit( - node, sdfg=sdfg, head_state=head_state + node, sdfg=sdfg, head_state=head_state, let_symbols={} ) field_nodes = [] @@ -303,7 +301,7 @@ def visit_SetAt(self, stmt: gtir.SetAt, sdfg: dace.SDFG, state: dace.SDFGState) for expr_node, target_node in zip(expr_nodes, target_nodes, strict=True): target_array = sdfg.arrays[target_node.data] assert not target_array.transient - target_symbol_type = self.symbol_types[target_node.data] + target_symbol_type = self.global_symbols[target_node.data] if isinstance(target_symbol_type, ts.FieldType): subset = ",".join( @@ -324,38 +322,102 @@ def visit_FunCall( node: gtir.FunCall, sdfg: dace.SDFG, head_state: dace.SDFGState, + let_symbols: dict[str, gtir_builtin_translators.LetSymbol], ) -> list[gtir_builtin_translators.TemporaryData]: # use specialized dataflow builder classes for each builtin function if cpm.is_call_to(node, "cond"): - return gtir_builtin_translators.translate_cond(node, sdfg, head_state, self) + return gtir_builtin_translators.translate_cond( + node, sdfg, head_state, self, let_symbols + ) elif cpm.is_call_to(node.fun, "as_fieldop"): - return gtir_builtin_translators.translate_as_field_op(node, sdfg, head_state, self) + return gtir_builtin_translators.translate_as_field_op( + node, sdfg, head_state, self, let_symbols + ) + elif isinstance(node.fun, gtir.Lambda): + # We use a separate state to ensure that the lambda arguments are evaluated + # before the computation starts. This is required in case the let-symbols + # are used in conditional branch execution, which happens in different states. + lambda_state = sdfg.add_state_before(head_state, f"{head_state.label}_symbols") + + node_args = [] + for arg in node.args: + node_args.extend( + self.visit( + arg, + sdfg=sdfg, + head_state=lambda_state, + let_symbols=let_symbols, + ) + ) + + # some cleanup: remove isolated nodes for program arguments in lambda state + isolated_node_args = [node for node, _ in node_args if lambda_state.degree(node) == 0] + assert all( + isinstance(node, dace.nodes.AccessNode) and node.data in self.global_symbols + for node in isolated_node_args + ) + lambda_state.remove_nodes_from(isolated_node_args) + + return self.visit( + node.fun, + sdfg=sdfg, + head_state=head_state, + let_symbols=let_symbols, + args=node_args, + ) else: raise NotImplementedError(f"Unexpected 'FunCall' expression ({node}).") - def visit_Lambda(self, node: gtir.Lambda) -> Any: + def visit_Lambda( + self, + node: gtir.Lambda, + sdfg: dace.SDFG, + head_state: dace.SDFGState, + let_symbols: dict[str, gtir_builtin_translators.LetSymbol], + args: list[gtir_builtin_translators.TemporaryData], + ) -> list[gtir_builtin_translators.TemporaryData]: """ - This visitor class should never encounter `itir.Lambda` expressions - because a lambda represents a stencil, which operates from iterator to values. - In fieldview, lambdas should only be arguments to field operators (`as_field_op`). + Translates a `Lambda` node to a tasklet subgraph in the current SDFG state. + + All arguments to lambda functions are fields (i.e. `as_fieldop`, field or scalar `gtir.SymRef`, + nested let-lambdas thereof). The dictionary called `let_symbols` maps the lambda parameters + to symbols, e.g. temporary fields or program arguments. If the lambda has a parameter whose name + is already present in `let_symbols`, i.e. 
a parameter with the same name as a previously defined + symbol, the parameter will shadow the previous symbol during traversal of the lambda expression. """ - raise RuntimeError("Unexpected 'itir.Lambda' node encountered in GTIR.") + lambda_symbols = let_symbols | { + str(p.id): (gtir.SymRef(id=temp_node.data), type_) + for p, (temp_node, type_) in zip(node.params, args, strict=True) + } + + return self.visit( + node.expr, + sdfg=sdfg, + head_state=head_state, + let_symbols=lambda_symbols, + ) def visit_Literal( self, node: gtir.Literal, sdfg: dace.SDFG, head_state: dace.SDFGState, + let_symbols: dict[str, gtir_builtin_translators.LetSymbol], ) -> list[gtir_builtin_translators.TemporaryData]: - return gtir_builtin_translators.translate_symbol_ref(node, sdfg, head_state, self) + return gtir_builtin_translators.translate_literal( + node, sdfg, head_state, self, let_symbols={} + ) def visit_SymRef( self, node: gtir.SymRef, sdfg: dace.SDFG, head_state: dace.SDFGState, + let_symbols: dict[str, gtir_builtin_translators.LetSymbol], ) -> list[gtir_builtin_translators.TemporaryData]: - return gtir_builtin_translators.translate_symbol_ref(node, sdfg, head_state, self) + return gtir_builtin_translators.translate_symbol_ref( + node, sdfg, head_state, self, let_symbols + ) def build_sdfg_from_gtir( diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py index f41f7cd500..188b19c577 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py @@ -25,6 +25,7 @@ from gt4py.next import common as gtx_common from gt4py.next.iterator import ir as gtir from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm +from gt4py.next.iterator.type_system import type_specifications as gtir_ts from gt4py.next.program_processors.runners.dace_fieldview import ( gtir_python_codegen, gtir_to_sdfg, @@ -54,7 +55,7 @@ class ValueExpr: """Result of the computation implemented by a tasklet node.""" node: dace.nodes.AccessNode - field_type: ts.FieldType | ts.ScalarType + dtype: gtir_ts.ListType | ts.ScalarType # Define alias for the elements needed to setup input connections to a map scope @@ -70,17 +71,61 @@ class ValueExpr: @dataclasses.dataclass(frozen=True) class IteratorExpr: - """Iterator for field access to be consumed by `deref` or `shift` builtin functions.""" + """ + Iterator for field access to be consumed by `deref` or `shift` builtin functions. + + Args: + field: The field this iterator operates on. + dimensions: Field domain represented as a sorted list of dimensions. + In order to dereference an element in the field, we need index values + for all the dimensions in the right order. + indices: Maps each dimension to an index value, which could be either a symbolic value + or the result of a tasklet computation like neighbors connectivity or dynamic offset. 
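The shadowing rule described in the `visit_Lambda` docstring above can be seen in the IR itself. A small sketch built with the `ir_makers` helpers, mirroring the nested let-expression exercised in `test_gtir_let_lambda`; the factors 2.0/3.0 and the domain are made up for illustration:

```python
# Sketch: the inner let-parameter "x2" shadows the outer one while its body is traversed.
from gt4py.next.iterator import ir as gtir
from gt4py.next.iterator.ir_utils import ir_makers as im

domain = im.call("cartesian_domain")(
    im.call("named_range")(gtir.AxisLiteral(value="IDim"), 0, "size")
)
expr = im.let("x2", im.op_as_fieldop("multiplies", domain)(2.0, "x"))(
    # the inner binding below shadows the outer "x2" inside its body
    im.let("x2", im.op_as_fieldop("multiplies", domain)(3.0, "x"))(
        im.op_as_fieldop("plus", domain)("x2", "x2")
    )
)
print(expr)  # pretty-printed nested λ-applications; the innermost "x2" refers to 3*x
```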
+ + """ field: dace.nodes.AccessNode dimensions: list[gtx_common.Dimension] indices: dict[gtx_common.Dimension, IteratorIndexExpr] +DACE_REDUCTION_MAPPING: dict[str, dace.dtypes.ReductionType] = { + "minimum": dace.dtypes.ReductionType.Min, + "maximum": dace.dtypes.ReductionType.Max, + "plus": dace.dtypes.ReductionType.Sum, + "multiplies": dace.dtypes.ReductionType.Product, + "and_": dace.dtypes.ReductionType.Logical_And, + "or_": dace.dtypes.ReductionType.Logical_Or, + "xor_": dace.dtypes.ReductionType.Logical_Xor, + "minus": dace.dtypes.ReductionType.Sub, + "divides": dace.dtypes.ReductionType.Div, +} + + +def get_reduce_params(node: gtir.FunCall) -> tuple[str, SymbolExpr, SymbolExpr]: + assert node.type + dtype = dace_fieldview_util.as_dace_type(node.type) + + assert isinstance(node.fun, gtir.FunCall) + assert len(node.fun.args) == 2 + assert isinstance(node.fun.args[0], gtir.SymRef) + op_name = str(node.fun.args[0]) + assert isinstance(node.fun.args[1], gtir.Literal) + assert node.fun.args[1].type == node.type + reduce_init = SymbolExpr(node.fun.args[1].value, dtype) + + if op_name not in DACE_REDUCTION_MAPPING: + raise RuntimeError(f"Reduction operation '{op_name}' not supported.") + identity_value = dace.dtypes.reduction_identity(dtype, DACE_REDUCTION_MAPPING[op_name]) + reduce_identity = SymbolExpr(identity_value, dtype) + + return op_name, reduce_init, reduce_identity + + class LambdaToTasklet(eve.NodeVisitor): """Translates an `ir.Lambda` expression to a dataflow graph. - Lambda functions should only be encountered as argument to the `as_field_op` + Lambda functions should only be encountered as argument to the `as_fieldop` builtin function, therefore the dataflow graph generated here typically represents the stencil function of a field operator. 
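`get_reduce_params` below assumes its input has the applied-reduce shape `reduce(op, init)(args...)`, which is exactly what the new `cpm.is_applied_reduce` matcher recognizes. A quick sketch of that shape using the same `ir_makers` calls as the unit tests:

```python
# Sketch: build an applied reduce expression and check it with the new pattern matcher.
from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm
from gt4py.next.iterator.ir_utils import ir_makers as im

applied_reduce = im.call(im.call("reduce")("plus", im.literal_from_value(0.0)))("it")
assert cpm.is_applied_reduce(applied_reduce)
assert not cpm.is_applied_reduce(im.deref("it"))
```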
""" @@ -88,6 +133,7 @@ class LambdaToTasklet(eve.NodeVisitor): sdfg: dace.SDFG state: dace.SDFGState subgraph_builder: gtir_to_sdfg.DataflowBuilder + reduce_identity: Optional[SymbolExpr] input_connections: list[InputConnection] symbol_map: dict[str, IteratorExpr | MemletExpr | SymbolExpr] @@ -96,10 +142,12 @@ def __init__( sdfg: dace.SDFG, state: dace.SDFGState, subgraph_builder: gtir_to_sdfg.DataflowBuilder, + reduce_identity: Optional[SymbolExpr], ): self.sdfg = sdfg self.state = state self.subgraph_builder = subgraph_builder + self.reduce_identity = reduce_identity self.input_connections = [] self.symbol_map = {} @@ -191,7 +239,12 @@ def _visit_deref(self, node: gtir.FunCall) -> MemletExpr | ValueExpr: assert len(field_desc.shape) == len(it.dimensions) if all(isinstance(index, SymbolExpr) for index in it.indices.values()): # when all indices are symblic expressions, we can perform direct field access through a memlet - field_subset = sbs.Indices([it.indices[dim].value for dim in it.dimensions]) # type: ignore[union-attr] + field_subset = sbs.Range( + (it.indices[dim].value, it.indices[dim].value, 1) # type: ignore[union-attr] + if dim in it.indices + else (0, size - 1, 1) + for dim, size in zip(it.dimensions, field_desc.shape) + ) return MemletExpr(it.field, field_subset) else: @@ -256,6 +309,192 @@ def _visit_deref(self, node: gtir.FunCall) -> MemletExpr | ValueExpr: assert isinstance(it, MemletExpr) return it + def _visit_neighbors(self, node: gtir.FunCall) -> ValueExpr: + assert len(node.args) == 2 + + assert isinstance(node.args[0], gtir.OffsetLiteral) + offset = node.args[0].value + assert isinstance(offset, str) + offset_provider = self.subgraph_builder.get_offset_provider(offset) + assert isinstance(offset_provider, gtx_common.Connectivity) + + it = self.visit(node.args[1]) + assert isinstance(it, IteratorExpr) + assert offset_provider.neighbor_axis in it.dimensions + neighbor_dim_index = it.dimensions.index(offset_provider.neighbor_axis) + assert offset_provider.neighbor_axis not in it.indices + assert offset_provider.origin_axis not in it.dimensions + assert offset_provider.origin_axis in it.indices + origin_index = it.indices[offset_provider.origin_axis] + assert isinstance(origin_index, SymbolExpr) + assert all(isinstance(index, SymbolExpr) for index in it.indices.values()) + + field_desc = it.field.desc(self.sdfg) + connectivity = dace_fieldview_util.connectivity_identifier(offset) + # initially, the storage for the connectivty tables is created as transient; + # when the tables are used, the storage is changed to non-transient, + # as the corresponding arrays are supposed to be allocated by the SDFG caller + connectivity_desc = self.sdfg.arrays[connectivity] + connectivity_desc.transient = False + + # The visitor is constructing a list of input connections that will be handled + # by `translate_as_fieldop` (the primitive translator), that is responsible + # of creating the map for the field domain. For each input connection, it will + # create a memlet that will write to a node specified by the third attribute + # in the `InputConnection` tuple (either a tasklet, or a view node, or a library + # node). For the specific case of `neighbors` we need to nest the neighbors map + # inside the field map and the memlets will traverse the external map and write + # to the view nodes. The simplify pass will remove the redundant access nodes. 
+ field_slice_view, field_slice_desc = self.sdfg.add_view( + f"{offset_provider.neighbor_axis.value}_view", + (field_desc.shape[neighbor_dim_index],), + field_desc.dtype, + strides=(field_desc.strides[neighbor_dim_index],), + find_new_name=True, + ) + field_slice_node = self.state.add_access(field_slice_view) + field_subset = ",".join( + it.indices[dim].value # type: ignore[union-attr] + if dim != offset_provider.neighbor_axis + else f"0:{size}" + for dim, size in zip(it.dimensions, field_desc.shape, strict=True) + ) + self._add_entry_memlet_path( + it.field, + sbs.Range.from_string(field_subset), + field_slice_node, + ) + + connectivity_slice_view, _ = self.sdfg.add_view( + "neighbors_view", + (offset_provider.max_neighbors,), + connectivity_desc.dtype, + strides=(connectivity_desc.strides[1],), + find_new_name=True, + ) + connectivity_slice_node = self.state.add_access(connectivity_slice_view) + self._add_entry_memlet_path( + self.state.add_access(connectivity), + sbs.Range.from_string(f"{origin_index.value}, 0:{offset_provider.max_neighbors}"), + connectivity_slice_node, + ) + + neighbors_temp, _ = self.sdfg.add_temp_transient( + (offset_provider.max_neighbors,), field_desc.dtype + ) + neighbors_node = self.state.add_access(neighbors_temp) + + offset_dim = gtx_common.Dimension(offset) + neighbor_idx = dace_fieldview_util.get_map_variable(offset_dim) + me, mx = self._add_map( + f"{offset}_neighbors", + { + neighbor_idx: f"0:{offset_provider.max_neighbors}", + }, + ) + index_connector = "__index" + if offset_provider.has_skip_values: + assert self.reduce_identity is not None + assert self.reduce_identity.dtype == field_desc.dtype + # TODO: Investigate if a NestedSDFG brings benefits + tasklet_node = self._add_tasklet( + "gather_neighbors_with_skip_values", + {"__field", index_connector}, + {"__val"}, + f"__val = __field[{index_connector}] if {index_connector} != {gtx_common._DEFAULT_SKIP_VALUE} else {self.reduce_identity.dtype}({self.reduce_identity.value})", + ) + + else: + tasklet_node = self._add_tasklet( + "gather_neighbors", + {"__field", index_connector}, + {"__val"}, + f"__val = __field[{index_connector}]", + ) + + self.state.add_memlet_path( + field_slice_node, + me, + tasklet_node, + dst_conn="__field", + memlet=dace.Memlet.from_array(field_slice_view, field_slice_desc), + ) + self.state.add_memlet_path( + connectivity_slice_node, + me, + tasklet_node, + dst_conn=index_connector, + memlet=dace.Memlet(data=connectivity_slice_view, subset=neighbor_idx), + ) + self.state.add_memlet_path( + tasklet_node, + mx, + neighbors_node, + src_conn="__val", + memlet=dace.Memlet(data=neighbors_temp, subset=neighbor_idx), + ) + + assert isinstance(node.type, gtir_ts.ListType) + return ValueExpr(neighbors_node, node.type) + + def _visit_reduce(self, node: gtir.FunCall) -> ValueExpr: + op_name, reduce_init, reduce_identity = get_reduce_params(node) + dtype = reduce_identity.dtype + + # We store the value of reduce identity in the visitor context while visiting + # the input to reduction; this value will be use by the `neighbors` visitor + # to fill the skip values in the neighbors list. 
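The net effect of the neighbors gather plus the reduction, written out in plain NumPy (not DaCe), mirrors the reference computations used in the unit tests: gather the edge values selected by a connectivity row, substitute the reduce identity wherever the table holds the skip value (-1), then fold with the init value. The numeric values and the connectivity row below are made up:

```python
# NumPy sketch of the per-vertex semantics produced by the neighbors map + reduce node.
import functools
import numpy as np

identity = 0.0                   # identity of `plus`, provided by the enclosing reduction
init_value = 1.0                 # init value carried by `reduce(plus, init_value)`
edge_values = np.array([10.0, 20.0, 30.0, 40.0])
v2e_row = np.array([3, -1, 0])   # hypothetical connectivity row with one skip value

gathered = [edge_values[i] if i != -1 else identity for i in v2e_row]
result = functools.reduce(lambda x, y: x + y, gathered, init_value)
assert result == init_value + edge_values[3] + edge_values[0]
```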
+ prev_reduce_identity = self.reduce_identity + self.reduce_identity = reduce_identity + + try: + input_expr = self.visit(node.args[0]) + finally: + # ensure that we leave the visitor in the same state as we entered + self.reduce_identity = prev_reduce_identity + + assert isinstance(input_expr, MemletExpr | ValueExpr) + input_desc = input_expr.node.desc(self.sdfg) + assert isinstance(input_desc, dace.data.Array) + + if len(input_desc.shape) > 1: + assert isinstance(input_expr, MemletExpr) + ndims = len(input_desc.shape) - 1 + # the axis to be reduced is always the last one, because `reduce` is supposed + # to operate on `ListType` + assert set(input_expr.subset.size()[0:ndims]) == {1} + reduce_axes = [ndims] + else: + reduce_axes = None + + reduce_wcr = "lambda x, y: " + gtir_python_codegen.format_builtin(op_name, "x", "y") + reduce_node = self.state.add_reduce(reduce_wcr, reduce_axes, reduce_init.value) + + if isinstance(input_expr, MemletExpr): + self._add_entry_memlet_path( + input_expr.node, + input_expr.subset, + reduce_node, + ) + else: + self.state.add_nedge( + input_expr.node, + reduce_node, + dace.Memlet.from_array(input_expr.node.data, input_desc), + ) + + temp_name = self.sdfg.temp_data_name() + self.sdfg.add_scalar(temp_name, dtype, transient=True) + temp_node = self.state.add_access(temp_name) + + self.state.add_nedge( + reduce_node, + temp_node, + dace.Memlet(data=temp_name, subset="0"), + ) + assert isinstance(node.type, ts.ScalarType) + return ValueExpr(temp_node, node.type) + def _split_shift_args( self, args: list[gtir.Expr] ) -> tuple[tuple[gtir.Expr, gtir.Expr], Optional[list[gtir.Expr]]]: @@ -472,6 +711,12 @@ def visit_FunCall(self, node: gtir.FunCall) -> IteratorExpr | MemletExpr | Value if cpm.is_call_to(node, "deref"): return self._visit_deref(node) + elif cpm.is_call_to(node, "neighbors"): + return self._visit_neighbors(node) + + elif cpm.is_applied_reduce(node): + return self._visit_reduce(node) + elif cpm.is_applied_shift(node): return self._visit_shift(node) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py b/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py index 8c419c9dd0..e8a2eddb99 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py @@ -21,14 +21,22 @@ from gt4py.next import common as gtx_common from gt4py.next.iterator import ir as gtir from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm +from gt4py.next.iterator.type_system import type_specifications as gtir_ts from gt4py.next.program_processors.runners.dace_fieldview import gtir_python_codegen from gt4py.next.type_system import type_specifications as ts def as_dace_type(type_: ts.TypeSpec) -> dace.typeclass: """Converts GT4Py scalar type to corresponding DaCe type.""" - assert isinstance(type_, ts.ScalarType) - match type_.kind: + if isinstance(type_, ts.ScalarType): + scalar_type = type_ + elif isinstance(type_, gtir_ts.ListType): + assert isinstance(type_.element_type, ts.ScalarType) + scalar_type = type_.element_type + else: + raise NotImplementedError + + match scalar_type.kind: case ts.ScalarKind.BOOL: return dace.bool_ case ts.ScalarKind.INT32: @@ -40,7 +48,7 @@ def as_dace_type(type_: ts.TypeSpec) -> dace.typeclass: case ts.ScalarKind.FLOAT64: return dace.float64 case _: - raise ValueError(f"Scalar type '{type_}' not supported.") + raise ValueError(f"Scalar type '{scalar_type}' not supported.") def as_scalar_type(typestr: str) -> 
ts.ScalarType: diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py index e0d820331b..a4d04511fa 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_dace_fieldview.py @@ -18,9 +18,11 @@ """ import copy +import functools from gt4py.next import common as gtx_common from gt4py.next.iterator import ir as gtir from gt4py.next.iterator.ir_utils import ir_makers as im +from gt4py.next.iterator.type_system import type_specifications as gtir_ts from gt4py.next.program_processors.runners import dace_fieldview as dace_backend from gt4py.next.type_system import type_specifications as ts from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import ( @@ -28,8 +30,10 @@ Edge, IDim, MeshDescriptor, + V2EDim, Vertex, simple_mesh, + skip_value_mesh, ) import numpy as np import pytest @@ -42,6 +46,7 @@ CFTYPE = ts.FieldType(dims=[Cell], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64)) EFTYPE = ts.FieldType(dims=[Edge], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64)) VFTYPE = ts.FieldType(dims=[Vertex], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64)) +V2E_FTYPE = ts.FieldType(dims=[Vertex, V2EDim], dtype=EFTYPE.dtype) CARTESIAN_OFFSETS = { "IDim": IDim, } @@ -49,6 +54,10 @@ SIMPLE_MESH_OFFSET_PROVIDER: dict[str, gtx_common.Connectivity | gtx_common.Dimension] = ( SIMPLE_MESH.offset_provider | CARTESIAN_OFFSETS ) +SKIP_VALUE_MESH: MeshDescriptor = skip_value_mesh() +SKIP_VALUE_MESH_OFFSET_PROVIDER: dict[str, gtx_common.Connectivity | gtx_common.Dimension] = ( + SKIP_VALUE_MESH.offset_provider | CARTESIAN_OFFSETS +) SIZE_TYPE = ts.ScalarType(ts.ScalarKind.INT32) FSYMBOLS = dict( __w_size_0=N, @@ -132,10 +141,7 @@ def test_gtir_update(): im.lambda_("a")(im.plus(im.deref("a"), 1.0)), domain, )("x") - stencil2 = im.as_fieldop( - im.lambda_("a", "b")(im.plus(im.deref("a"), im.deref("b"))), - domain, - )("x", 1.0) + stencil2 = im.op_as_fieldop("plus", domain)("x", 1.0) for i, stencil in enumerate([stencil1, stencil2]): testee = gtir.Program( @@ -179,10 +185,7 @@ def test_gtir_sum2(): declarations=[], body=[ gtir.SetAt( - expr=im.as_fieldop( - im.lambda_("a", "b")(im.plus(im.deref("a"), im.deref("b"))), - domain, - )("x", "y"), + expr=im.op_as_fieldop("plus", domain)("x", "y"), domain=domain, target=gtir.SymRef(id="z"), ) @@ -214,10 +217,7 @@ def test_gtir_sum2_sym(): declarations=[], body=[ gtir.SetAt( - expr=im.as_fieldop( - im.lambda_("a", "b")(im.plus(im.deref("a"), im.deref("b"))), - domain, - )("x", "x"), + expr=im.op_as_fieldop("plus", domain)("x", "x"), domain=domain, target=gtir.SymRef(id="z"), ) @@ -237,15 +237,9 @@ def test_gtir_sum3(): domain = im.call("cartesian_domain")( im.call("named_range")(gtir.AxisLiteral(value=IDim.value), 0, "size") ) - stencil1 = im.as_fieldop( - im.lambda_("a", "b")(im.plus(im.deref("a"), im.deref("b"))), - domain, - )( + stencil1 = im.op_as_fieldop("plus", domain)( "x", - im.as_fieldop( - im.lambda_("a", "b")(im.plus(im.deref("a"), im.deref("b"))), - domain, - )("y", "w"), + im.op_as_fieldop("plus", domain)("y", "w"), ) stencil2 = im.as_fieldop( im.lambda_("a", "b", "c")(im.plus(im.deref("a"), im.plus(im.deref("b"), im.deref("c")))), @@ -304,21 +298,12 @@ def test_gtir_cond(): declarations=[], body=[ gtir.SetAt( - expr=im.as_fieldop( - im.lambda_("a", "b")(im.plus(im.deref("a"), 
im.deref("b"))), - domain, - )( + expr=im.op_as_fieldop("plus", domain)( "x", im.call("cond")( gtir.SymRef(id="pred"), - im.as_fieldop( - im.lambda_("a", "b")(im.plus(im.deref("a"), im.deref("b"))), - domain, - )("y", "scalar"), - im.as_fieldop( - im.lambda_("a", "b")(im.plus(im.deref("a"), im.deref("b"))), - domain, - )("w", "scalar"), + im.op_as_fieldop("plus", domain)("y", "scalar"), + im.op_as_fieldop("plus", domain)("w", "scalar"), ), ), domain=domain, @@ -358,20 +343,11 @@ def test_gtir_cond_nested(): gtir.SetAt( expr=im.call("cond")( gtir.SymRef(id="pred_1"), - im.as_fieldop( - im.lambda_("a", "b")(im.plus(im.deref("a"), im.deref("b"))), - domain, - )("x", 1.0), + im.op_as_fieldop("plus", domain)("x", 1.0), im.call("cond")( gtir.SymRef(id="pred_2"), - im.as_fieldop( - im.lambda_("a", "b")(im.plus(im.deref("a"), im.deref("b"))), - domain, - )("x", 2.0), - im.as_fieldop( - im.lambda_("a", "b")(im.plus(im.deref("a"), im.deref("b"))), - domain, - )("x", 3.0), + im.op_as_fieldop("plus", domain)("x", 2.0), + im.op_as_fieldop("plus", domain)("x", 3.0), ), ), domain=domain, @@ -408,18 +384,12 @@ def test_gtir_cartesian_shift_left(): domain, )("x") # fieldview flavor of same stencil, in which a temporary field is initialized with the `DELTA` constant value - stencil1_fieldview = im.as_fieldop( - im.lambda_("a", "b")(im.plus(im.deref("a"), im.deref("b"))), - domain, - )( + stencil1_fieldview = im.op_as_fieldop("plus", domain)( im.as_fieldop( im.lambda_("a")(im.deref(im.shift("IDim", OFFSET)("a"))), domain, )("x"), - im.as_fieldop( - im.lambda_()(DELTA), - domain, - )(), + im.as_fieldop(im.lambda_()(DELTA), domain)(), ) # use dynamic offset retrieved from field @@ -428,10 +398,7 @@ def test_gtir_cartesian_shift_left(): domain, )("x", "x_offset") # fieldview flavor of same stencil - stencil2_fieldview = im.as_fieldop( - im.lambda_("a", "b")(im.plus(im.deref("a"), im.deref("b"))), - domain, - )( + stencil2_fieldview = im.op_as_fieldop("plus", domain)( im.as_fieldop( im.lambda_("a", "off")(im.deref(im.shift("IDim", im.deref("off"))("a"))), domain, @@ -447,19 +414,13 @@ def test_gtir_cartesian_shift_left(): domain, )("x", "x_offset") # fieldview flavor of same stencil - stencil3_fieldview = im.as_fieldop( - im.lambda_("a", "b")(im.plus(im.deref("a"), im.deref("b"))), - domain, - )( + stencil3_fieldview = im.op_as_fieldop("plus", domain)( im.as_fieldop( im.lambda_("a", "off")(im.deref(im.shift("IDim", im.deref("off"))("a"))), domain, )( "x", - im.as_fieldop( - im.lambda_("it")(im.plus(im.deref("it"), 0)), - domain, - )("x_offset"), + im.op_as_fieldop("plus", domain)("x_offset", 0), ), im.as_fieldop(im.lambda_()(DELTA), domain)(), ) @@ -520,10 +481,7 @@ def test_gtir_cartesian_shift_right(): domain, )("x") # fieldview flavor of same stencil, in which a temporary field is initialized with the `DELTA` constant value - stencil1_fieldview = im.as_fieldop( - im.lambda_("a", "b")(im.plus(im.deref("a"), im.deref("b"))), - domain, - )( + stencil1_fieldview = im.op_as_fieldop("plus", domain)( im.as_fieldop( im.lambda_("a")(im.deref(im.shift("IDim", -OFFSET)("a"))), domain, @@ -537,10 +495,7 @@ def test_gtir_cartesian_shift_right(): domain, )("x", "x_offset") # fieldview flavor of same stencil - stencil2_fieldview = im.as_fieldop( - im.lambda_("a", "b")(im.plus(im.deref("a"), im.deref("b"))), - domain, - )( + stencil2_fieldview = im.op_as_fieldop("plus", domain)( im.as_fieldop( im.lambda_("a", "off")(im.deref(im.shift("IDim", im.deref("off"))("a"))), domain, @@ -556,19 +511,13 @@ def 
test_gtir_cartesian_shift_right(): domain, )("x", "x_offset") # fieldview flavor of same stencil - stencil3_fieldview = im.as_fieldop( - im.lambda_("a", "b")(im.plus(im.deref("a"), im.deref("b"))), - domain, - )( + stencil3_fieldview = im.op_as_fieldop("plus", domain)( im.as_fieldop( im.lambda_("a", "off")(im.deref(im.shift("IDim", im.deref("off"))("a"))), domain, )( "x", - im.as_fieldop( - im.lambda_("it")(im.plus(im.deref("it"), 0)), - domain, - )("x_offset"), + im.op_as_fieldop("plus", domain)("x_offset", 0), ), im.as_fieldop(im.lambda_()(DELTA), domain)(), ) @@ -699,12 +648,12 @@ def test_gtir_connectivity_shift(): )( "ev_field", "c2e_offset", - im.as_fieldop( - im.lambda_("it")(im.plus(im.deref("it"), 0)), + im.op_as_fieldop( + "plus", im.call("unstructured_domain")( im.call("named_range")(gtir.AxisLiteral(value=Edge.value), 0, "nedges"), ), - )("e2v_offset"), + )("e2v_offset", 0), ) CE_FTYPE = ts.FieldType(dims=[Cell, Edge], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64)) @@ -834,3 +783,532 @@ def test_gtir_connectivity_shift_chain(): __edges_out_stride_0=1, ) assert np.allclose(e_out, ref) + + +def test_gtir_neighbors_as_input(): + # FIXME[#1582](edopao): Enable testcase when type inference is working + pytest.skip("Field of lists not fully supported by GTIR type inference") + init_value = np.random.rand() + vertex_domain = im.call("unstructured_domain")( + im.call("named_range")(gtir.AxisLiteral(value=Vertex.value), 0, "nvertices"), + ) + testee = gtir.Program( + id=f"neighbors_as_input", + function_definitions=[], + params=[ + gtir.Sym(id="v2e_field", type=V2E_FTYPE), + gtir.Sym(id="vertex", type=EFTYPE), + gtir.Sym(id="nvertices", type=SIZE_TYPE), + ], + declarations=[], + body=[ + gtir.SetAt( + expr=im.call( + im.call("as_fieldop")( + im.lambda_("it")( + im.call(im.call("reduce")("plus", im.literal_from_value(init_value)))( + "it" + ) + ), + vertex_domain, + ) + )("v2e_field"), + domain=vertex_domain, + target=gtir.SymRef(id="vertex"), + ) + ], + ) + + sdfg = dace_backend.build_sdfg_from_gtir(testee, SIMPLE_MESH_OFFSET_PROVIDER) + + connectivity_V2E = SIMPLE_MESH_OFFSET_PROVIDER["V2E"] + assert isinstance(connectivity_V2E, gtx_common.NeighborTable) + + v2e_field = np.random.rand(SIMPLE_MESH.num_vertices, connectivity_V2E.max_neighbors) + v = np.empty(SIMPLE_MESH.num_vertices, dtype=v2e_field.dtype) + + v_ref = [ + functools.reduce(lambda x, y: x + y, v2e_neighbors, init_value) + for v2e_neighbors in v2e_field + ] + + sdfg( + v2e_field=v2e_field, + vertex=v, + **FSYMBOLS, + **make_mesh_symbols(SIMPLE_MESH), + __v2e_field_size_0=SIMPLE_MESH.num_vertices, + __v2e_field_size_1=connectivity_V2E.max_neighbors, + __v2e_field_stride_0=connectivity_V2E.max_neighbors, + __v2e_field_stride_1=1, + ) + assert np.allclose(v, v_ref) + + +def test_gtir_neighbors_as_output(): + # FIXME[#1582](edopao): Enable testcase when type inference is working + pytest.skip("Field of lists not fully supported by GTIR type inference") + vertex_domain = im.call("unstructured_domain")( + im.call("named_range")(gtir.AxisLiteral(value=Vertex.value), 0, "nvertices"), + ) + v2e_domain = im.call("unstructured_domain")( + im.call("named_range")(gtir.AxisLiteral(value=Vertex.value), 0, "nvertices"), + im.call("named_range")( + gtir.AxisLiteral(value=V2EDim.value), + 0, + SIMPLE_MESH_OFFSET_PROVIDER["V2E"].max_neighbors, + ), + ) + testee = gtir.Program( + id=f"neighbors_as_output", + function_definitions=[], + params=[ + gtir.Sym(id="edges", type=EFTYPE), + gtir.Sym(id="v2e_field", type=V2E_FTYPE), + 
gtir.Sym(id="nvertices", type=SIZE_TYPE), + ], + declarations=[], + body=[ + gtir.SetAt( + expr=im.call( + im.call("as_fieldop")( + im.lambda_("it")(im.neighbors("V2E", "it")), + vertex_domain, + ) + )("edges"), + domain=v2e_domain, + target=gtir.SymRef(id="v2e_field"), + ) + ], + ) + + sdfg = dace_backend.build_sdfg_from_gtir(testee, SIMPLE_MESH_OFFSET_PROVIDER) + + connectivity_V2E = SIMPLE_MESH_OFFSET_PROVIDER["V2E"] + assert isinstance(connectivity_V2E, gtx_common.NeighborTable) + + e = np.random.rand(SIMPLE_MESH.num_edges) + v2e_field = np.empty([SIMPLE_MESH.num_vertices, connectivity_V2E.max_neighbors], dtype=e.dtype) + + sdfg( + edges=e, + v2e_field=v2e_field, + connectivity_V2E=connectivity_V2E.table, + **FSYMBOLS, + **make_mesh_symbols(SIMPLE_MESH), + __v2e_field_size_0=SIMPLE_MESH.num_vertices, + __v2e_field_size_1=connectivity_V2E.max_neighbors, + __v2e_field_stride_0=connectivity_V2E.max_neighbors, + __v2e_field_stride_1=1, + ) + assert np.allclose(v2e_field, e[connectivity_V2E.table]) + + +def test_gtir_reduce(): + init_value = np.random.rand() + vertex_domain = im.call("unstructured_domain")( + im.call("named_range")(gtir.AxisLiteral(value=Vertex.value), 0, "nvertices"), + ) + stencil_inlined = im.call( + im.call("as_fieldop")( + im.lambda_("it")( + im.call(im.call("reduce")("plus", im.literal_from_value(init_value)))( + im.neighbors("V2E", "it") + ) + ), + vertex_domain, + ) + )("edges") + stencil_fieldview = im.call( + im.call("as_fieldop")( + im.lambda_("it")( + im.call(im.call("reduce")("plus", im.literal_from_value(init_value)))( + im.deref("it") + ) + ), + vertex_domain, + ) + )( + im.call( + im.call("as_fieldop")( + im.lambda_("it")(im.neighbors("V2E", "it")), + vertex_domain, + ) + )("edges") + ) + + connectivity_V2E = SIMPLE_MESH_OFFSET_PROVIDER["V2E"] + assert isinstance(connectivity_V2E, gtx_common.NeighborTable) + + e = np.random.rand(SIMPLE_MESH.num_edges) + v_ref = [ + functools.reduce(lambda x, y: x + y, e[v2e_neighbors], init_value) + for v2e_neighbors in connectivity_V2E.table + ] + + for i, stencil in enumerate([stencil_inlined, stencil_fieldview]): + testee = gtir.Program( + id=f"reduce_{i}", + function_definitions=[], + params=[ + gtir.Sym(id="edges", type=EFTYPE), + gtir.Sym(id="vertices", type=VFTYPE), + gtir.Sym(id="nvertices", type=SIZE_TYPE), + ], + declarations=[], + body=[ + gtir.SetAt( + expr=stencil, + domain=vertex_domain, + target=gtir.SymRef(id="vertices"), + ) + ], + ) + sdfg = dace_backend.build_sdfg_from_gtir(testee, SIMPLE_MESH_OFFSET_PROVIDER) + + # new empty output field + v = np.empty(SIMPLE_MESH.num_vertices, dtype=e.dtype) + + sdfg( + edges=e, + vertices=v, + connectivity_V2E=connectivity_V2E.table, + **FSYMBOLS, + **make_mesh_symbols(SIMPLE_MESH), + ) + assert np.allclose(v, v_ref) + + +def test_gtir_reduce_with_skip_values(): + init_value = np.random.rand() + vertex_domain = im.call("unstructured_domain")( + im.call("named_range")(gtir.AxisLiteral(value=Vertex.value), 0, "nvertices"), + ) + stencil_inlined = im.call( + im.call("as_fieldop")( + im.lambda_("it")( + im.call(im.call("reduce")("plus", im.literal_from_value(init_value)))( + im.neighbors("V2E", "it") + ) + ), + vertex_domain, + ) + )("edges") + stencil_fieldview = im.call( + im.call("as_fieldop")( + im.lambda_("it")( + im.call(im.call("reduce")("plus", im.literal_from_value(init_value)))( + im.deref("it") + ) + ), + vertex_domain, + ) + )( + im.call( + im.call("as_fieldop")( + im.lambda_("it")(im.neighbors("V2E", "it")), + vertex_domain, + ) + )("edges") + ) + + 
connectivity_V2E = SKIP_VALUE_MESH_OFFSET_PROVIDER["V2E"] + assert isinstance(connectivity_V2E, gtx_common.NeighborTable) + + e = np.random.rand(SKIP_VALUE_MESH.num_edges) + v_ref = [ + functools.reduce( + lambda x, y: x + y, [e[i] if i != -1 else 0.0 for i in v2e_neighbors], init_value + ) + for v2e_neighbors in connectivity_V2E.table + ] + + for i, stencil in enumerate([stencil_inlined, stencil_fieldview]): + testee = gtir.Program( + id=f"reduce_with_skip_values_{i}", + function_definitions=[], + params=[ + gtir.Sym(id="edges", type=EFTYPE), + gtir.Sym(id="vertices", type=VFTYPE), + gtir.Sym(id="nvertices", type=SIZE_TYPE), + ], + declarations=[], + body=[ + gtir.SetAt( + expr=stencil, + domain=vertex_domain, + target=gtir.SymRef(id="vertices"), + ) + ], + ) + sdfg = dace_backend.build_sdfg_from_gtir(testee, SKIP_VALUE_MESH_OFFSET_PROVIDER) + + # new empty output field + v = np.empty(SKIP_VALUE_MESH.num_vertices, dtype=e.dtype) + + sdfg( + edges=e, + vertices=v, + connectivity_V2E=connectivity_V2E.table, + **FSYMBOLS, + **make_mesh_symbols(SKIP_VALUE_MESH), + ) + assert np.allclose(v, v_ref) + + +def test_gtir_reduce_dot_product(): + # FIXME[#1582](edopao): Enable testcase when type inference is working + pytest.skip("Field of lists not fully supported as a type in GTIR yet") + init_value = np.random.rand() + vertex_domain = im.call("unstructured_domain")( + im.call("named_range")(gtir.AxisLiteral(value=Vertex.value), 0, "nvertices"), + ) + v2e_domain = im.call("unstructured_domain")( + im.call("named_range")(gtir.AxisLiteral(value=Vertex.value), 0, "nvertices"), + im.call("named_range")( + gtir.AxisLiteral(value=V2EDim.value), + 0, + SIMPLE_MESH_OFFSET_PROVIDER["V2E"].max_neighbors, + ), + ) + + testee = gtir.Program( + id=f"reduce_dot_product", + function_definitions=[], + params=[ + gtir.Sym(id="edges", type=EFTYPE), + gtir.Sym(id="vertices", type=VFTYPE), + gtir.Sym(id="nvertices", type=SIZE_TYPE), + ], + declarations=[], + body=[ + gtir.SetAt( + expr=im.call( + im.call("as_fieldop")( + im.lambda_("it")( + im.call(im.call("reduce")("plus", im.literal_from_value(init_value)))( + im.deref("it") + ) + ), + vertex_domain, + ) + )( + im.op_as_fieldop("multiplies", vertex_domain)( + im.call( + im.call("as_fieldop")( + im.lambda_("it")(im.neighbors("V2E", "it")), + vertex_domain, + ) + )("edges"), + im.call( + im.call("as_fieldop")( + im.lambda_("it")(im.neighbors("V2E", "it")), + vertex_domain, + ) + )("edges"), + ), + ), + domain=vertex_domain, + target=gtir.SymRef(id="vertices"), + ) + ], + ) + + connectivity_V2E = SIMPLE_MESH_OFFSET_PROVIDER["V2E"] + assert isinstance(connectivity_V2E, gtx_common.NeighborTable) + + sdfg = dace_backend.build_sdfg_from_gtir(testee, SIMPLE_MESH_OFFSET_PROVIDER) + + e = np.random.rand(SIMPLE_MESH.num_edges) + v = np.empty(SIMPLE_MESH.num_vertices, dtype=e.dtype) + v_ref = [ + reduce(lambda x, y: x + y, e[v2e_neighbors] * e[v2e_neighbors], init_value) + for v2e_neighbors in connectivity_V2E.table + ] + + sdfg( + edges=e, + vertices=v, + connectivity_V2E=connectivity_V2E.table, + **FSYMBOLS, + **make_mesh_symbols(SIMPLE_MESH), + ) + assert np.allclose(v, v_ref) + + +def test_gtir_reduce_with_cond_neighbors(): + init_value = np.random.rand() + vertex_domain = im.call("unstructured_domain")( + im.call("named_range")(gtir.AxisLiteral(value=Vertex.value), 0, "nvertices"), + ) + testee = gtir.Program( + id=f"reduce_with_cond_neighbors", + function_definitions=[], + params=[ + gtir.Sym(id="pred", type=ts.ScalarType(ts.ScalarKind.BOOL)), + gtir.Sym(id="edges", 
type=EFTYPE), + gtir.Sym(id="vertices", type=VFTYPE), + gtir.Sym(id="nvertices", type=SIZE_TYPE), + ], + declarations=[], + body=[ + gtir.SetAt( + expr=im.as_fieldop( + im.lambda_("it")( + im.call(im.call("reduce")("plus", im.literal_from_value(init_value)))( + im.deref("it") + ) + ), + vertex_domain, + )( + im.call("cond")( + gtir.SymRef(id="pred"), + im.as_fieldop( + im.lambda_("it")(im.neighbors("V2E_FULL", "it")), + vertex_domain, + )("edges"), + im.as_fieldop( + im.lambda_("it")(im.neighbors("V2E", "it")), + vertex_domain, + )("edges"), + ) + ), + domain=vertex_domain, + target=gtir.SymRef(id="vertices"), + ) + ], + ) + + connectivity_V2E_simple = SIMPLE_MESH_OFFSET_PROVIDER["V2E"] + assert isinstance(connectivity_V2E_simple, gtx_common.NeighborTable) + connectivity_V2E_skip_values = copy.deepcopy(SKIP_VALUE_MESH_OFFSET_PROVIDER["V2E"]) + assert isinstance(connectivity_V2E_skip_values, gtx_common.NeighborTable) + assert SKIP_VALUE_MESH.num_vertices <= SIMPLE_MESH.num_vertices + connectivity_V2E_skip_values.table = np.concatenate( + ( + connectivity_V2E_skip_values.table[:, 0 : connectivity_V2E_simple.max_neighbors], + connectivity_V2E_simple.table[SKIP_VALUE_MESH.num_vertices :, :], + ), + axis=0, + ) + connectivity_V2E_skip_values.max_neighbors = connectivity_V2E_simple.max_neighbors + + e = np.random.rand(SIMPLE_MESH.num_edges) + + for use_full in [False, True]: + sdfg = dace_backend.build_sdfg_from_gtir( + testee, + SIMPLE_MESH_OFFSET_PROVIDER | {"V2E_FULL": connectivity_V2E_skip_values}, + ) + + v = np.empty(SIMPLE_MESH.num_vertices, dtype=e.dtype) + v_ref = [ + functools.reduce( + lambda x, y: x + y, [e[i] if i != -1 else 0.0 for i in v2e_neighbors], init_value + ) + for v2e_neighbors in ( + connectivity_V2E_simple.table if use_full else connectivity_V2E_skip_values.table + ) + ] + sdfg( + pred=np.bool_(use_full), + edges=e, + vertices=v, + connectivity_V2E=connectivity_V2E_skip_values.table, + connectivity_V2E_FULL=connectivity_V2E_simple.table, + **FSYMBOLS, + **make_mesh_symbols(SIMPLE_MESH), + __connectivity_V2E_FULL_size_0=SIMPLE_MESH.num_edges, + __connectivity_V2E_FULL_size_1=connectivity_V2E_skip_values.max_neighbors, + __connectivity_V2E_FULL_stride_0=connectivity_V2E_skip_values.max_neighbors, + __connectivity_V2E_FULL_stride_1=1, + ) + assert np.allclose(v, v_ref) + + +def test_gtir_let_lambda(): + domain = im.call("cartesian_domain")( + im.call("named_range")(gtir.AxisLiteral(value=IDim.value), 0, "size") + ) + testee = gtir.Program( + id="let_lambda", + function_definitions=[], + params=[ + gtir.Sym(id="x", type=IFTYPE), + gtir.Sym(id="y", type=IFTYPE), + gtir.Sym(id="size", type=SIZE_TYPE), + ], + declarations=[], + body=[ + gtir.SetAt( + # `x1` is a let-lambda expression representing `x * 3` + # `x2` is a let-lambda expression representing `x * 4` + # - note that the let-symbol `x2` is used twice, in a nested let-expression, to test aliasing of the symbol + # `x3` is a let-lambda expression simply accessing `x` field symref + expr=im.let("x1", im.op_as_fieldop("multiplies", domain)(3.0, "x"))( + im.let( + "x2", + im.let("x2", im.op_as_fieldop("multiplies", domain)(2.0, "x"))( + im.op_as_fieldop("plus", domain)("x2", "x2") + ), + )( + im.let("x3", "x")( + im.op_as_fieldop("plus", domain)( + "x1", im.op_as_fieldop("plus", domain)("x2", "x3") + ) + ) + ) + ), + domain=domain, + target=gtir.SymRef(id="y"), + ) + ], + ) + + a = np.random.rand(N) + b = np.empty_like(a) + + sdfg = dace_backend.build_sdfg_from_gtir(testee, {}) + + sdfg(x=a, y=b, **FSYMBOLS) + assert 
np.allclose(b, a * 8) + + +def test_gtir_let_lambda_with_cond(): + domain = im.call("cartesian_domain")( + im.call("named_range")(gtir.AxisLiteral(value=IDim.value), 0, "size") + ) + testee = gtir.Program( + id="let_lambda_with_cond", + function_definitions=[], + params=[ + gtir.Sym(id="x", type=IFTYPE), + gtir.Sym(id="y", type=IFTYPE), + gtir.Sym(id="pred", type=ts.ScalarType(ts.ScalarKind.BOOL)), + gtir.Sym(id="size", type=SIZE_TYPE), + ], + declarations=[], + body=[ + gtir.SetAt( + expr=im.let("x1", "x")( + im.let("x2", im.op_as_fieldop("multiplies", domain)(2.0, "x"))( + im.call("cond")( + gtir.SymRef(id="pred"), + im.as_fieldop(im.lambda_("a")(im.deref("a")), domain)("x1"), + im.as_fieldop(im.lambda_("a")(im.deref("a")), domain)("x2"), + ) + ) + ), + domain=domain, + target=gtir.SymRef(id="y"), + ) + ], + ) + + sdfg = dace_backend.build_sdfg_from_gtir(testee, {}) + + a = np.random.rand(N) + for s in [False, True]: + b = np.empty_like(a) + sdfg(pred=np.bool_(s), x=a, y=b, **FSYMBOLS) + assert np.allclose(b, a if s else a * 2)
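A short arithmetic check of the expected result in `test_gtir_let_lambda` above: with those nested let-bindings, `x1 = 3*x`, `x2 = 2*x + 2*x = 4*x` and `x3 = x`, so `y = x1 + (x2 + x3) = 8*x`, matching the `np.allclose(b, a * 8)` assertion:

```python
# Plain NumPy restatement of the let-lambda test's expected value.
import numpy as np

x = np.random.rand(10)
x1 = 3.0 * x
x2_inner = 2.0 * x
x2 = x2_inner + x2_inner
x3 = x
assert np.allclose(x1 + (x2 + x3), 8.0 * x)
```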