From 0a292611aa655c34abe497bcfeb72e33361939c5 Mon Sep 17 00:00:00 2001 From: edopao Date: Fri, 16 Feb 2024 11:01:40 +0100 Subject: [PATCH] fix[next][dace]: Bugfix for nested neighbor reduction (#1457) In case of nested neighbor reduction with lift expression on inner node, the DaCe backend should generate a conditional state transition to field access, based on the value of neighbor index provided by the outer connectivity table. Additional change. The previous selection of valid neighbors implemented as conditional inter-state edge is replaced by a select tasklet, which makes the SDFG easier to read. --- .../runners/dace_iterator/itir_to_tasklet.py | 102 ++++++++++++------ .../ffront_tests/test_execution.py | 12 +-- 2 files changed, 77 insertions(+), 37 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py index 2e58eccec8..3a33ee1e35 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py @@ -191,6 +191,7 @@ def _visit_lift_in_neighbors_reduction( neighbor_index_node: dace.nodes.AccessNode, neighbor_value_node: dace.nodes.AccessNode, ) -> list[ValueExpr]: + assert transformer.context.reduce_identity is not None neighbor_dim = offset_provider.neighbor_axis.value origin_dim = offset_provider.origin_axis.value @@ -220,7 +221,7 @@ def _visit_lift_in_neighbors_reduction( input_nodes = {} iterator_index_nodes = {} - lifted_index_connectors = set() + lifted_index_connectors = [] for x, y in inner_inputs: if isinstance(y, IteratorExpr): @@ -228,7 +229,7 @@ def _visit_lift_in_neighbors_reduction( input_nodes[field_connector] = y.field for dim, connector in inner_index_table.items(): if dim == neighbor_dim: - lifted_index_connectors.add(connector) + lifted_index_connectors.append(connector) iterator_index_nodes[connector] = y.indices[dim] else: assert isinstance(y, ValueExpr) @@ -298,6 +299,30 @@ def _visit_lift_in_neighbors_reduction( memlet=dace.Memlet(data=neighbor_value_node.data, subset=",".join(map_entry.params)), ) + if offset_provider.has_skip_values: + # check neighbor validity on if/else inter-state edge + start_state = lift_context.body.add_state("start", is_start_block=True) + skip_neighbor_state = lift_context.body.add_state("skip_neighbor") + skip_neighbor_state.add_edge( + skip_neighbor_state.add_tasklet( + "identity", {}, {"val"}, f"val = {transformer.context.reduce_identity.value}" + ), + "val", + skip_neighbor_state.add_access(inner_outputs[0].value.data), + None, + dace.Memlet(data=inner_outputs[0].value.data, subset="0"), + ) + lift_context.body.add_edge( + start_state, + skip_neighbor_state, + dace.InterstateEdge(condition=f"{lifted_index_connectors[0]} == {neighbor_skip_value}"), + ) + lift_context.body.add_edge( + start_state, + lift_context.state, + dace.InterstateEdge(condition=f"{lifted_index_connectors[0]} != {neighbor_skip_value}"), + ) + return [ValueExpr(neighbor_value_node, inner_outputs[0].dtype)] @@ -467,7 +492,7 @@ def builtin_neighbors( neighbor_valid_node = state.add_access(neighbor_valid_var, debuginfo=di) neighbor_valid_tasklet = state.add_tasklet( - "check_valid_neighbor", + f"check_valid_neighbor_{offset_dim}", {"__idx"}, {"__valid"}, f"__valid = True if __idx != {neighbor_skip_value} else False", @@ -1223,7 +1248,7 @@ def _visit_reduce(self, node: itir.FunCall): nreduce_shape = args_shape[0] input_args = [arg[0] for arg in args] - input_valid = [arg[1] for arg in args if len(arg) == 2] + input_valid_args = [arg[1] for arg in args if len(arg) == 2] nreduce_index = tuple(f"_i{i}" for i in range(len(nreduce_shape))) nreduce_domain = {idx: f"0:{size}" for idx, size in zip(nreduce_index, nreduce_shape)} @@ -1255,41 +1280,56 @@ def _visit_reduce(self, node: itir.FunCall): self.context.body, lambda_context.body, input_mapping ) - if input_valid: + if input_valid_args: """ - The neighbors builtin returns an array of booleans in case the connectivity table - contains skip values. These boolean values indicate whether the neighbor value is present or not, - and are used below to construct an if/else branch to bypass the lambda call for neighbor skip values. + The neighbors builtin returns an array of booleans in case the connectivity table contains skip values. + These booleans indicate whether the neighbor is present or not, and are used in a tasklet to select + the result of field access or the identity value, respectively. If the neighbor table has full connectivity (no skip values by type definition), the input_valid node - is not built, and the construction of the if/else branch below is also skipped. + is not built, and the construction of the select tasklet below is also skipped. """ - input_args.append(input_valid[0]) - input_valid_node = input_valid[0].value + input_args.append(input_valid_args[0]) + input_valid_node = input_valid_args[0].value + lambda_output_node = inner_outputs[0].value # add input connector to nested sdfg - input_mapping["is_valid"] = create_memlet_at(input_valid_node.data, nreduce_index) - # check neighbor validity on if/else inter-state edge - start_state = lambda_context.body.add_state("start", is_start_block=True) - skip_neighbor_state = lambda_context.body.add_state("skip_neighbor") - skip_neighbor_state.add_edge( - skip_neighbor_state.add_tasklet( - "identity", {}, {"val"}, f"val = {reduce_identity}" - ), - "val", - skip_neighbor_state.add_access(inner_outputs[0].value.data), + lambda_context.body.add_scalar("_valid_neighbor", dace.dtypes.bool) + input_mapping["_valid_neighbor"] = create_memlet_at( + input_valid_node.data, nreduce_index + ) + # add select tasklet before writing to output node + # TODO: consider replacing it with a select-memlet once it is supported by DaCe SDFG API + output_edge = lambda_context.state.in_edges(lambda_output_node)[0] + assert isinstance( + lambda_context.body.arrays[output_edge.src.data], dace.data.Scalar + ) + select_tasklet = lambda_context.state.add_tasklet( + "neighbor_select", + {"_inp", "_valid"}, + {"_out"}, + f"_out = _inp if _valid else {reduce_identity}", + ) + lambda_context.state.add_edge( + output_edge.src, None, - dace.Memlet(data=inner_outputs[0].value.data, subset="0"), + select_tasklet, + "_inp", + dace.Memlet(data=output_edge.src.data, subset="0"), ) - lambda_context.body.add_scalar("is_valid", dace.dtypes.bool) - lambda_context.body.add_edge( - start_state, - skip_neighbor_state, - dace.InterstateEdge(condition="is_valid == False"), + lambda_context.state.add_edge( + lambda_context.state.add_access("_valid_neighbor"), + None, + select_tasklet, + "_valid", + dace.Memlet(data="_valid_neighbor", subset="0"), ) - lambda_context.body.add_edge( - start_state, - lambda_context.state, - dace.InterstateEdge(condition="is_valid == True"), + lambda_context.state.add_edge( + select_tasklet, + "_out", + lambda_output_node, + None, + dace.Memlet(data=lambda_output_node.data, subset="0"), ) + lambda_context.state.remove_edge(output_edge) reduce_input_node = self.context.state.add_access(reduce_input_name, debuginfo=di) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index 3c9c4e686c..e499f83f86 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -515,9 +515,9 @@ def combine(a: cases.IField, b: cases.IField) -> cases.IField: @pytest.mark.uses_reduction_over_lift_expressions def test_nested_reduction(unstructured_case): @gtx.field_operator - def testee(a: cases.EField) -> cases.EField: - tmp = neighbor_sum(a(V2E), axis=V2EDim) - tmp_2 = neighbor_sum(tmp(E2V), axis=E2VDim) + def testee(a: cases.VField) -> cases.VField: + tmp = neighbor_sum(a(E2V), axis=E2VDim) + tmp_2 = neighbor_sum(tmp(V2E), axis=V2EDim) return tmp_2 cases.verify_with_default_data( @@ -525,12 +525,12 @@ def testee(a: cases.EField) -> cases.EField: testee, ref=lambda a: np.sum( np.sum( - a[unstructured_case.offset_provider["V2E"].table], + a[unstructured_case.offset_provider["E2V"].table], axis=1, initial=0, - where=unstructured_case.offset_provider["V2E"].table != common.SKIP_VALUE, - )[unstructured_case.offset_provider["E2V"].table], + )[unstructured_case.offset_provider["V2E"].table], axis=1, + where=unstructured_case.offset_provider["V2E"].table != common.SKIP_VALUE, ), comparison=lambda a, tmp_2: np.all(a == tmp_2), )