From d7e73f5177e033032ba511f8bc8752eff78edd08 Mon Sep 17 00:00:00 2001
From: Philip Mueller
Date: Wed, 15 Jan 2025 14:54:09 +0100
Subject: [PATCH] This commit fixes an error that was reported by Edoardo (@edopao).

The bug occurred because the `DistributedBufferRelocator` transformation did
not check whether its insertion would create a read-write conflict. This
commit adds such a check; the check is, however, not very sophisticated and
needs some improvements. Moreover, the example
(`model/atmosphere/dycore/tests/dycore_stencil_tests/test_compute_exner_from_rhotheta.py`)
where the bug surfaced poses additional challenges.
The main purpose of this commit is to unblock further development in ICON4Py.

Link to ICON4Py PR: https://github.com/C2SM/icon4py/pull/638
---
 .../transformations/simplify.py               | 203 ++++++++++++++----
 .../test_distributed_buffer_relocator.py      | 163 +++++++++++++-
 2 files changed, 316 insertions(+), 50 deletions(-)
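Reviewer note (illustration only, not part of the commit): the sketch below
condenses the kind of read-write conflict this patch teaches the pass to
reject into a standalone script. It mirrors
`_make_distributed_buffer_global_memory_data_race_sdfg` from the test file in
this patch and uses only DaCe calls that already appear there; the function
name `build_race_example` is invented for the illustration.

import dace


def build_race_example() -> dace.SDFG:
    """Build an SDFG where relocating the write back would create a race.

    `t` buffers a computation on `a`, while `a` is also copied into `b` in
    the same state. Moving the write back `t -> a` into that state would make
    the update of `a` race with the unordered copy `a -> b`.
    """
    sdfg = dace.SDFG("race_example")
    for name in ["a", "b", "t"]:
        sdfg.add_array(name, shape=(10, 10), dtype=dace.float64, transient=(name == "t"))
    state1 = sdfg.add_state(is_start_block=True)
    state2 = sdfg.add_state_after(state1)

    a_node = state1.add_access("a")
    state1.add_mapped_tasklet(
        "computation",
        map_ranges={"__i0": "0:10", "__i1": "0:10"},
        inputs={"__in": dace.Memlet("a[__i0, __i1]")},
        code="__out = __in + 10",
        outputs={"__out": dace.Memlet("t[__i0, __i1]")},
        input_nodes={a_node},
        external_edges=True,
    )
    # Second, unordered consumer of `a` in the same state.
    state1.add_nedge(a_node, state1.add_access("b"), dace.Memlet("a[0:10, 0:10]"))

    # The write back that `DistributedBufferRelocator` would like to relocate.
    state2.add_nedge(state2.add_access("t"), state2.add_access("a"), dace.Memlet("t[0:10, 0:10]"))
    sdfg.validate()
    return sdfg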
diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/simplify.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/simplify.py
index 4339a761fa..13e5ac15e8 100644
--- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/simplify.py
+++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/simplify.py
@@ -374,7 +374,7 @@ def apply(
         raise
 
 
-AccessLocation: TypeAlias = tuple[dace.SDFGState, dace_nodes.AccessNode]
+AccessLocation: TypeAlias = tuple[dace_nodes.AccessNode, dace.SDFGState]
 """Describes an access node and the state in which it is located.
 """
 
@@ -387,29 +387,38 @@ class DistributedBufferRelocator(dace_transformation.Pass):
     in each branch and then in the join state written back. Thus there is some
     additional storage needed.
     The transformation will look for the following situation:
-    - A transient data container, called `src_cont`, is written into another
-        container, called `dst_cont`, which is not transient.
-    - The access node of `src_cont` has an in degree of zero and an out degree of one.
-    - The access node of `dst_cont` has an in degree of of one and an
+    - A transient data container, called `temp_storage`, is written into another
+        container, called `dest_storage`, which is not transient.
+    - The access node of `temp_storage` has an in degree of zero and an out degree of one.
+    - The access node of `dest_storage` has an in degree of one and an
         out degree of zero (this might be lifted).
-    - `src_cont` is not used afterwards.
-    - `dst_cont` is only used to implement the buffering.
+    - `temp_storage` is not used afterwards.
+    - `dest_storage` is only used to implement the buffering.
 
-    The function will relocate the writing of `dst_cont` to where `src_cont` is
+    The function will relocate the writing of `dest_storage` to where `temp_storage` is
     written, which might be multiple locations.
     It will also remove the writing back.
     It is advised that after this transformation simplify is run again.
 
+    The relocation will not take place if it might create a data race. A necessary
+    but not sufficient condition for a data race is that `dest_storage` is present
+    in the state where `temp_storage` is defined. In addition, at least one of the
+    following conditions has to be met:
+    - The `dest_storage` access node that exists in that state is not connected
+        to the `temp_storage` access node.
+    - There is a `dest_storage` access node, but it has an out degree larger
+        than one.
+
     Note:
-        Essentially this transformation removes the double buffering of `dst_cont`.
-        Because we ensure that that `dst_cont` is non transient this is okay, as our
+        Essentially this transformation removes the double buffering of `dest_storage`.
+        Because we ensure that `dest_storage` is non-transient this is okay, as our
         rule guarantees this.
 
     Todo:
-        - Allow that `dst_cont` can also be transient.
-        - Allow that `dst_cont` does not need to be a sink node, this is most
+        - Allow that `dest_storage` can also be transient.
+        - Allow that `dest_storage` does not need to be a sink node, this is most
            likely most relevant if it is transient.
-        - Check if `dst_cont` is used between where we want to place it and
+        - Check if `dest_storage` is used between where we want to place it and
            where it is currently used.
     """
 
@@ -489,10 +498,10 @@ def _find_candidates(
         where the temporary is defined.
         """
         # All nodes that are used as distributed buffers.
-        candidate_src_cont: list[AccessLocation] = []
+        candidate_temp_storage: list[AccessLocation] = []
 
-        # Which `src_cont` access node is written back to which global memory.
-        src_cont_to_global: dict[dace_nodes.AccessNode, str] = {}
+        # Which `temp_storage` access node is written back to which global memory.
+        temp_storage_to_global: dict[dace_nodes.AccessNode, str] = {}
 
         for state in sdfg.states():
            # These are the possible targets we want to write into.
@@ -508,26 +517,26 @@
            if len(candidate_dst_nodes) == 0:
                continue
 
-            for src_cont in state.source_nodes():
-                if not isinstance(src_cont, dace_nodes.AccessNode):
+            for temp_storage in state.source_nodes():
+                if not isinstance(temp_storage, dace_nodes.AccessNode):
                    continue
-                if not src_cont.desc(sdfg).transient:
+                if not temp_storage.desc(sdfg).transient:
                    continue
-                if state.out_degree(src_cont) != 1:
+                if state.out_degree(temp_storage) != 1:
                    continue
                dst_candidate: dace_nodes.AccessNode = next(
-                    iter(edge.dst for edge in state.out_edges(src_cont))
+                    iter(edge.dst for edge in state.out_edges(temp_storage))
                )
                if dst_candidate not in candidate_dst_nodes:
                    continue
-                candidate_src_cont.append((src_cont, state))
-                src_cont_to_global[src_cont] = dst_candidate.data
+                candidate_temp_storage.append((temp_storage, state))
+                temp_storage_to_global[temp_storage] = dst_candidate.data
 
-        if len(candidate_src_cont) == 0:
+        if len(candidate_temp_storage) == 0:
            return []
 
        # Now we have to find the places where the temporary sources are defined.
-        # I.e. This is also the location where the original value is defined.
+        # I.e., this is also the location where the temporary source was initialized.
        result_candidates: list[tuple[AccessLocation, list[AccessLocation]]] = []
 
@@ -537,39 +546,39 @@ def find_upstream_states(dst_state: dace.SDFGState) -> set[dace.SDFGState]:
            return {
                src_state
                for src_state in sdfg.states()
                if dst_state in reachable[src_state] and dst_state is not src_state
            }
 
-        for src_cont in candidate_src_cont:
+        for temp_storage in candidate_temp_storage:
            def_locations: list[AccessLocation] = []
-            for upstream_state in find_upstream_states(src_cont[1]):
-                if src_cont[0].data in access_sets[upstream_state][1]:
+            for upstream_state in find_upstream_states(temp_storage[1]):
+                if temp_storage[0].data in access_sets[upstream_state][1]:
                    def_locations.extend(
                        (data_node, upstream_state)
                        for data_node in upstream_state.data_nodes()
-                        if data_node.data == src_cont[0].data
+                        if data_node.data == temp_storage[0].data
                    )
            if len(def_locations) != 0:
-                result_candidates.append((src_cont, def_locations))
+                result_candidates.append((temp_storage, def_locations))
 
-        # This transformation removes `src_cont` by writing its content directly
-        # to `dst_cont`, at the point where it is defined.
+        # This transformation removes `temp_storage` by writing its content directly
+        # to `dest_storage`, at the point where it is defined.
        # For this transformation to be valid the following conditions have to be met:
-        # - Between the definition of `src_cont` and the write back to `dst_cont`,
-        #   `dst_cont` can not be accessed.
-        # - Between the definitions of `src_cont` and the point where it is written
-        #   back, `src_cont` can only be accessed in the range that is written back.
-        # - After the write back point, `src_cont` shall not be accessed. This
+        # - Between the definition of `temp_storage` and the write back to `dest_storage`,
+        #   `dest_storage` can not be accessed.
+        # - Between the definitions of `temp_storage` and the point where it is written
+        #   back, `temp_storage` can only be accessed in the range that is written back.
+        # - After the write back point, `temp_storage` shall not be accessed. This
        #   restriction could be lifted.
        #
        # To keep the implementation simple, we use the conditions:
-        # - `src_cont` is only accessed were it is defined and at the write back
+        # - `temp_storage` is only accessed where it is defined and at the write back
        #   point.
-        # - Between the definitions of `src_cont` and the write back point,
-        #   `dst_cont` is not used.
+        # - Between the definitions of `temp_storage` and the write back point,
+        #   `dest_storage` is not used.
 
        result: list[tuple[AccessLocation, list[AccessLocation]]] = []
 
        for wb_localation, def_locations in result_candidates:
            for def_node, def_state in def_locations:
-                # Test if `src_cont` is only accessed where it is defined and
+                # Test if `temp_storage` is only accessed where it is defined and
                #  where it is written back.
                if gtx_transformations.util.is_accessed_downstream(
                    start_state=def_state,
@@ -579,30 +588,136 @@ def find_upstream_states(dst_state: dace.SDFGState) -> set[dace.SDFGState]:
                ):
                    break
                # check if the global data is not used between the definition of
-                # `dst_cont` and where its written back. We allow one exception,
+                # `dest_storage` and where it is written back. We allow one exception,
                # if the global data is used in the state the distributed temporary
                # is defined is used only for reading then it is ignored. This is
                # allowed because of rule 3 of ADR0018.
                glob_nodes_in_def_state = {
                    dnode
                    for dnode in def_state.data_nodes()
-                    if dnode.data == src_cont_to_global[wb_localation[0]]
+                    if dnode.data == temp_storage_to_global[wb_localation[0]]
                }
                if any(def_state.in_degree(gdnode) != 0 for gdnode in glob_nodes_in_def_state):
                    break
                if gtx_transformations.util.is_accessed_downstream(
                    start_state=def_state,
                    sdfg=sdfg,
-                    data_to_look=src_cont_to_global[wb_localation[0]],
+                    data_to_look=temp_storage_to_global[wb_localation[0]],
                    nodes_to_ignore=glob_nodes_in_def_state,
                    states_to_ignore={wb_localation[1]},
                ):
                    break
+                if self._check_read_write_dependency(sdfg, wb_localation, def_locations):
+                    break
            else:
                result.append((wb_localation, def_locations))
 
        return result
 
+    def _check_read_write_dependency(
+        self,
+        sdfg: dace.SDFG,
+        write_back_location: AccessLocation,
+        target_locations: list[AccessLocation],
+    ) -> bool:
+        """Tests if read-write conflicts would be created.
+
+        This function ensures that the substitution of `write_back_location` into
+        `target_locations` will not create a read-write conflict.
+        The rules that are used for this are outlined in the class description.
+
+        Args:
+            sdfg: The SDFG on which we operate.
+            write_back_location: Where currently the write back occurs.
+            target_locations: List of the locations where we would like to perform
+                the write back instead.
+        """
+        for target_location in target_locations:
+            if self._check_read_write_dependency_impl(sdfg, write_back_location, target_location):
+                return True
+        return False
+
+    def _check_read_write_dependency_impl(
+        self,
+        sdfg: dace.SDFG,
+        write_back_location: AccessLocation,
+        target_location: AccessLocation,
+    ) -> bool:
+        """Tests if a read-write conflict would be created for a single location.
+
+        Args:
+            sdfg: The SDFG on which we operate.
+            write_back_location: Where currently the write back occurs.
+            target_location: Location where the new write back should be performed.
+
+        Todo:
+            Refine these checks later.
+        """
+        assert write_back_location[0].data == target_location[0].data
+
+        write_back_state: dace.SDFGState = write_back_location[1]
+        write_back_node = write_back_location[0]
+        write_back_edge = next(iter(write_back_state.out_edges(write_back_node)))
+        global_data_name = write_back_edge.dst.data
+        assert not sdfg.arrays[global_data_name].transient
+
+        # This is the state in which we have to look for possible data races.
+        state_to_inspect: dace.SDFGState = target_location[1]
+
+        # This is the access node on which we will perform the write back.
+        def_location_of_intermediate: dace_nodes.AccessNode = target_location[0]
+        assert state_to_inspect.out_degree(def_location_of_intermediate) == 0
+
+        # These are all access nodes that refer to the global data whose write
+        #  back we want to move into the state `state_to_inspect`. We need them
+        #  for the second test.
+        access_to_global_data_in_this_state: set[dace_nodes.AccessNode] = set()
+
+        # The first simple test is to look if the global data is used and has
+        #  more than one output edge. Further analysis could show that such a
+        #  case could be handled, but we do not do it.
+        for dnode in state_to_inspect.data_nodes():
+            if dnode.data != global_data_name:
+                continue
+            if state_to_inspect.out_degree(dnode) >= 2:
+                return True
+            access_to_global_data_in_this_state.add(dnode)
+
+        # There is no reference to the global data, so no need to do more tests.
+        if len(access_to_global_data_in_this_state) == 0:
+            return False
+
+        # For the second test we look if `global_data_name` is referred to in
+        #  another dataflow graph in this state that is, however, not connected
+        #  to the graph that contains `def_location_of_intermediate`. We will
+        #  do this by exploring the dataflow graph from that node, and if we find
+        #  a node that refers to the global data, we remove it from
+        #  `access_to_global_data_in_this_state`; if that set is empty at the end,
+        #  then the global data is not used in another component.
+        to_process: list[dace_nodes.Node] = [def_location_of_intermediate]
+        seen: set[dace_nodes.Node] = set()
+        while len(to_process) != 0:
+            node = to_process.pop()
+            seen.add(node)
+
+            if isinstance(node, dace_nodes.AccessNode):
+                if node.data == global_data_name:
+                    access_to_global_data_in_this_state.discard(node)
+
+            # Note that we only explore the incoming edges, thus we will not necessarily
+            #  explore the whole graph. However, this is fine, because we will see the
+            #  relevant parts. To see this, assume that we would also have to check the
+            #  outgoing edges; this would mean that there was some branching point,
+            #  which is a serialization point, so the dataflow would have been invalid
+            #  before.
+            to_process.extend(
+                iedge.src for iedge in state_to_inspect.in_edges(node) if iedge.src not in seen
+            )
+
+        if len(access_to_global_data_in_this_state) == 0:
+            return False
+        return True
+
 
 @dace_properties.make_properties
 class GT4PyMoveTaskletIntoMap(dace_transformation.SingleStateTransformation):
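Reviewer note (illustration only, not part of the commit): stripped of the
pass machinery, the core of `_check_read_write_dependency_impl` is the pair of
rules sketched below as a hypothetical standalone function.
`creates_read_write_conflict` and its parameters are invented names; the
out-degree rule and the backward traversal are the ones implemented above.

import dace
from dace.sdfg import nodes as dace_nodes


def creates_read_write_conflict(
    state: dace.SDFGState,
    target_node: dace_nodes.AccessNode,
    global_data: str,
) -> bool:
    """Return True if relocating the write back into `state` could race."""
    # Rule 1: reject outright if any access to the global data in this state
    # has more than one outgoing edge.
    global_accesses = {dnode for dnode in state.data_nodes() if dnode.data == global_data}
    if any(state.out_degree(dnode) >= 2 for dnode in global_accesses):
        return True
    # Rule 2: every access to the global data must be backward-reachable from
    # `target_node` (the node defining the temporary). Anything left over lives
    # in an unrelated dataflow component, so relocating the write back into
    # this state could race with it.
    to_process: list[dace_nodes.Node] = [target_node]
    seen: set[dace_nodes.Node] = set()
    while to_process:
        node = to_process.pop()
        seen.add(node)
        global_accesses.discard(node)  # the set only holds matching access nodes
        to_process.extend(iedge.src for iedge in state.in_edges(node) if iedge.src not in seen)
    return len(global_accesses) != 0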
diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_distributed_buffer_relocator.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_distributed_buffer_relocator.py
index 1543a048ad..33c02a2055 100644
--- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_distributed_buffer_relocator.py
+++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_distributed_buffer_relocator.py
@@ -13,7 +13,7 @@
     transformations as gtx_transformations,
 )
 
-# from . import util
+from . import util
 
 # dace = pytest.importorskip("dace")
 
@@ -21,8 +21,8 @@
 import dace
 
 
-def _mk_distributed_buffer_sdfg() -> tuple[dace.SDFG, dace.SDFGState]:
-    sdfg = dace.SDFG("NAME")  # util.unique_name("distributed_buffer_sdfg"))
+def _mk_distributed_buffer_sdfg() -> tuple[dace.SDFG, dace.SDFGState, dace.SDFGState]:
+    sdfg = dace.SDFG(util.unique_name("distributed_buffer_sdfg"))
 
     for name in ["a", "b", "tmp"]:
         sdfg.add_array(name, shape=(10, 10), dtype=dace.float64, transient=False)
@@ -66,19 +66,170 @@ def _mk_distributed_buffer_sdfg() -> tuple[dace.SDFG, dace.SDFGState]:
     sdfg.validate()
     assert sdfg.number_of_nodes() == 3
 
-    return sdfg, state1
+    return sdfg, state1, state3
 
 
 def test_distributed_buffer_remover():
-    sdfg, state1 = _mk_distributed_buffer_sdfg()
+    sdfg, state1, state3 = _mk_distributed_buffer_sdfg()
     assert state1.number_of_nodes() == 5
     assert not any(dnode.data == "b" for dnode in state1.data_nodes())
 
     res = gtx_transformations.gt_reduce_distributed_buffering(sdfg)
-    assert res is not None
+    assert res[sdfg]["DistributedBufferRelocator"][state3] == {"tmp"}
 
     # Because the final state has now become empty
     assert sdfg.number_of_nodes() == 3
     assert state1.number_of_nodes() == 6
     assert any(dnode.data == "b" for dnode in state1.data_nodes())
     assert any(dnode.data == "tmp" for dnode in state1.data_nodes())
+
+
+def _make_distributed_buffer_global_memory_data_race_sdfg() -> tuple[dace.SDFG, dace.SDFGState]:
+    sdfg = dace.SDFG(util.unique_name("distributed_buffer_global_memory_data_race"))
+    arr_names = ["a", "b", "t"]
+    for name in arr_names:
+        sdfg.add_array(
+            name=name,
+            shape=(10, 10),
+            dtype=dace.float64,
+            transient=False,
+        )
+    sdfg.arrays["t"].transient = True
+
+    state1 = sdfg.add_state(is_start_block=True)
+    state2 = sdfg.add_state_after(state1)
+
+    a_state1 = state1.add_access("a")
+    state1.add_mapped_tasklet(
+        "computation",
+        map_ranges={"__i0": "0:10", "__i1": "0:10"},
+        inputs={"__in": dace.Memlet("a[__i0, __i1]")},
+        code="__out = __in + 10",
+        outputs={"__out": dace.Memlet("t[__i0, __i1]")},
+        input_nodes={a_state1},
+        external_edges=True,
+    )
+    state1.add_nedge(a_state1, state1.add_access("b"), dace.Memlet("a[0:10, 0:10]"))
+
+    state2.add_nedge(state2.add_access("t"), state2.add_access("a"), dace.Memlet("t[0:10, 0:10]"))
+    sdfg.validate()
+
+    return sdfg, state2
+
+
+def test_distributed_buffer_global_memory_data_race():
+    """Tests if the transformation realizes that it would create a data race.
+
+    If the transformation were applied, `a` would be read in two different
+    branches whose order of execution is indeterminate.
+ """ + sdfg, state2 = _make_distributed_buffer_global_memory_data_race_sdfg() + assert state2.number_of_nodes() == 2 + + sdfg.simplify() + assert sdfg.number_of_nodes() == 2 + + res = gtx_transformations.gt_reduce_distributed_buffering(sdfg) + assert "DistributedBufferRelocator" not in res[sdfg] + assert state2.number_of_nodes() == 2 + + +def _make_distributed_buffer_global_memory_data_race_sdfg2() -> ( + tuple[dace.SDFG, dace.SDFGState, dace.SDFGState] +): + sdfg = dace.SDFG(util.unique_name("distributed_buffer_global_memory_data_race2_sdfg")) + arr_names = ["a", "b", "t"] + for name in arr_names: + sdfg.add_array( + name=name, + shape=(10, 10), + dtype=dace.float64, + transient=False, + ) + sdfg.arrays["t"].transient = True + + state1 = sdfg.add_state(is_start_block=True) + state2 = sdfg.add_state_after(state1) + + state1.add_mapped_tasklet( + "computation1", + map_ranges={"__i0": "0:10", "__i1": "0:10"}, + inputs={"__in": dace.Memlet("a[__i0, __i1]")}, + code="__out = __in + 10", + outputs={"__out": dace.Memlet("t[__i0, __i1]")}, + external_edges=True, + ) + state1.add_mapped_tasklet( + "computation1", + map_ranges={"__i0": "0:10", "__i1": "0:10"}, + inputs={"__in": dace.Memlet("a[__i0, __i1]")}, + code="__out = __in - 10", + outputs={"__out": dace.Memlet("b[__i0, __i1]")}, + external_edges=True, + ) + state2.add_nedge(state2.add_access("t"), state2.add_access("a"), dace.Memlet("t[0:10, 0:10]")) + sdfg.validate() + + return sdfg, state1, state2 + + +def test_distributed_buffer_global_memory_data_race2(): + """Tests if the transformation realized that it would create a data race. + + Similar situation but now there are two different subgraphs. This is needed + because it is another branch that checks it. + """ + sdfg, state1, state2 = _make_distributed_buffer_global_memory_data_race_sdfg2() + assert state1.number_of_nodes() == 10 + assert state2.number_of_nodes() == 2 + + res = gtx_transformations.gt_reduce_distributed_buffering(sdfg) + assert "DistributedBufferRelocator" not in res[sdfg] + assert state1.number_of_nodes() == 10 + assert state2.number_of_nodes() == 2 + + +def _make_distributed_buffer_global_memory_data_no_rance() -> tuple[dace.SDFG, dace.SDFGState]: + sdfg = dace.SDFG(util.unique_name("distributed_buffer_global_memory_data_no_rance_sdfg")) + arr_names = ["a", "t"] + for name in arr_names: + sdfg.add_array( + name=name, + shape=(10, 10), + dtype=dace.float64, + transient=False, + ) + sdfg.arrays["t"].transient = True + + state1 = sdfg.add_state(is_start_block=True) + state2 = sdfg.add_state_after(state1) + + a_state1 = state1.add_access("a") + state1.add_mapped_tasklet( + "computation", + map_ranges={"__i0": "0:10", "__i1": "0:10"}, + inputs={"__in": dace.Memlet("a[__i0, __i1]")}, + code="__out = __in + 10", + outputs={"__out": dace.Memlet("t[__i0, __i1]")}, + input_nodes={a_state1}, + external_edges=True, + ) + + state2.add_nedge(state2.add_access("t"), state2.add_access("a"), dace.Memlet("t[0:10, 0:10]")) + sdfg.validate() + + return sdfg, state2 + + +def test_distributed_buffer_global_memory_data_no_rance(): + """Transformation applies if there is no data race. + + According to ADR18, pointwise dependencies are fine. This tests checks if the + checks for the read-write conflicts are not too strong. 
+ """ + sdfg, state2 = _make_distributed_buffer_global_memory_data_no_rance() + assert state2.number_of_nodes() == 2 + + res = gtx_transformations.gt_reduce_distributed_buffering(sdfg) + assert res[sdfg]["DistributedBufferRelocator"][state2] == {"t"} + assert state2.number_of_nodes() == 0