diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py index 2e58eccec8..3a33ee1e35 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py @@ -191,6 +191,7 @@ def _visit_lift_in_neighbors_reduction( neighbor_index_node: dace.nodes.AccessNode, neighbor_value_node: dace.nodes.AccessNode, ) -> list[ValueExpr]: + assert transformer.context.reduce_identity is not None neighbor_dim = offset_provider.neighbor_axis.value origin_dim = offset_provider.origin_axis.value @@ -220,7 +221,7 @@ def _visit_lift_in_neighbors_reduction( input_nodes = {} iterator_index_nodes = {} - lifted_index_connectors = set() + lifted_index_connectors = [] for x, y in inner_inputs: if isinstance(y, IteratorExpr): @@ -228,7 +229,7 @@ def _visit_lift_in_neighbors_reduction( input_nodes[field_connector] = y.field for dim, connector in inner_index_table.items(): if dim == neighbor_dim: - lifted_index_connectors.add(connector) + lifted_index_connectors.append(connector) iterator_index_nodes[connector] = y.indices[dim] else: assert isinstance(y, ValueExpr) @@ -298,6 +299,30 @@ def _visit_lift_in_neighbors_reduction( memlet=dace.Memlet(data=neighbor_value_node.data, subset=",".join(map_entry.params)), ) + if offset_provider.has_skip_values: + # check neighbor validity on if/else inter-state edge + start_state = lift_context.body.add_state("start", is_start_block=True) + skip_neighbor_state = lift_context.body.add_state("skip_neighbor") + skip_neighbor_state.add_edge( + skip_neighbor_state.add_tasklet( + "identity", {}, {"val"}, f"val = {transformer.context.reduce_identity.value}" + ), + "val", + skip_neighbor_state.add_access(inner_outputs[0].value.data), + None, + dace.Memlet(data=inner_outputs[0].value.data, subset="0"), + ) + lift_context.body.add_edge( + start_state, + skip_neighbor_state, + dace.InterstateEdge(condition=f"{lifted_index_connectors[0]} == {neighbor_skip_value}"), + ) + lift_context.body.add_edge( + start_state, + lift_context.state, + dace.InterstateEdge(condition=f"{lifted_index_connectors[0]} != {neighbor_skip_value}"), + ) + return [ValueExpr(neighbor_value_node, inner_outputs[0].dtype)] @@ -467,7 +492,7 @@ def builtin_neighbors( neighbor_valid_node = state.add_access(neighbor_valid_var, debuginfo=di) neighbor_valid_tasklet = state.add_tasklet( - "check_valid_neighbor", + f"check_valid_neighbor_{offset_dim}", {"__idx"}, {"__valid"}, f"__valid = True if __idx != {neighbor_skip_value} else False", @@ -1223,7 +1248,7 @@ def _visit_reduce(self, node: itir.FunCall): nreduce_shape = args_shape[0] input_args = [arg[0] for arg in args] - input_valid = [arg[1] for arg in args if len(arg) == 2] + input_valid_args = [arg[1] for arg in args if len(arg) == 2] nreduce_index = tuple(f"_i{i}" for i in range(len(nreduce_shape))) nreduce_domain = {idx: f"0:{size}" for idx, size in zip(nreduce_index, nreduce_shape)} @@ -1255,41 +1280,56 @@ def _visit_reduce(self, node: itir.FunCall): self.context.body, lambda_context.body, input_mapping ) - if input_valid: + if input_valid_args: """ - The neighbors builtin returns an array of booleans in case the connectivity table - contains skip values. These boolean values indicate whether the neighbor value is present or not, - and are used below to construct an if/else branch to bypass the lambda call for neighbor skip values. + The neighbors builtin returns an array of booleans in case the connectivity table contains skip values. + These booleans indicate whether the neighbor is present or not, and are used in a tasklet to select + the result of field access or the identity value, respectively. If the neighbor table has full connectivity (no skip values by type definition), the input_valid node - is not built, and the construction of the if/else branch below is also skipped. + is not built, and the construction of the select tasklet below is also skipped. """ - input_args.append(input_valid[0]) - input_valid_node = input_valid[0].value + input_args.append(input_valid_args[0]) + input_valid_node = input_valid_args[0].value + lambda_output_node = inner_outputs[0].value # add input connector to nested sdfg - input_mapping["is_valid"] = create_memlet_at(input_valid_node.data, nreduce_index) - # check neighbor validity on if/else inter-state edge - start_state = lambda_context.body.add_state("start", is_start_block=True) - skip_neighbor_state = lambda_context.body.add_state("skip_neighbor") - skip_neighbor_state.add_edge( - skip_neighbor_state.add_tasklet( - "identity", {}, {"val"}, f"val = {reduce_identity}" - ), - "val", - skip_neighbor_state.add_access(inner_outputs[0].value.data), + lambda_context.body.add_scalar("_valid_neighbor", dace.dtypes.bool) + input_mapping["_valid_neighbor"] = create_memlet_at( + input_valid_node.data, nreduce_index + ) + # add select tasklet before writing to output node + # TODO: consider replacing it with a select-memlet once it is supported by DaCe SDFG API + output_edge = lambda_context.state.in_edges(lambda_output_node)[0] + assert isinstance( + lambda_context.body.arrays[output_edge.src.data], dace.data.Scalar + ) + select_tasklet = lambda_context.state.add_tasklet( + "neighbor_select", + {"_inp", "_valid"}, + {"_out"}, + f"_out = _inp if _valid else {reduce_identity}", + ) + lambda_context.state.add_edge( + output_edge.src, None, - dace.Memlet(data=inner_outputs[0].value.data, subset="0"), + select_tasklet, + "_inp", + dace.Memlet(data=output_edge.src.data, subset="0"), ) - lambda_context.body.add_scalar("is_valid", dace.dtypes.bool) - lambda_context.body.add_edge( - start_state, - skip_neighbor_state, - dace.InterstateEdge(condition="is_valid == False"), + lambda_context.state.add_edge( + lambda_context.state.add_access("_valid_neighbor"), + None, + select_tasklet, + "_valid", + dace.Memlet(data="_valid_neighbor", subset="0"), ) - lambda_context.body.add_edge( - start_state, - lambda_context.state, - dace.InterstateEdge(condition="is_valid == True"), + lambda_context.state.add_edge( + select_tasklet, + "_out", + lambda_output_node, + None, + dace.Memlet(data=lambda_output_node.data, subset="0"), ) + lambda_context.state.remove_edge(output_edge) reduce_input_node = self.context.state.add_access(reduce_input_name, debuginfo=di) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index 3c9c4e686c..e499f83f86 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -515,9 +515,9 @@ def combine(a: cases.IField, b: cases.IField) -> cases.IField: @pytest.mark.uses_reduction_over_lift_expressions def test_nested_reduction(unstructured_case): @gtx.field_operator - def testee(a: cases.EField) -> cases.EField: - tmp = neighbor_sum(a(V2E), axis=V2EDim) - tmp_2 = neighbor_sum(tmp(E2V), axis=E2VDim) + def testee(a: cases.VField) -> cases.VField: + tmp = neighbor_sum(a(E2V), axis=E2VDim) + tmp_2 = neighbor_sum(tmp(V2E), axis=V2EDim) return tmp_2 cases.verify_with_default_data( @@ -525,12 +525,12 @@ def testee(a: cases.EField) -> cases.EField: testee, ref=lambda a: np.sum( np.sum( - a[unstructured_case.offset_provider["V2E"].table], + a[unstructured_case.offset_provider["E2V"].table], axis=1, initial=0, - where=unstructured_case.offset_provider["V2E"].table != common.SKIP_VALUE, - )[unstructured_case.offset_provider["E2V"].table], + )[unstructured_case.offset_provider["V2E"].table], axis=1, + where=unstructured_case.offset_provider["V2E"].table != common.SKIP_VALUE, ), comparison=lambda a, tmp_2: np.all(a == tmp_2), )