From d0dcf1ca8407f02f691816eed7102057df2d8149 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philip=20M=C3=BCller?= <147368808+philip-paul-mueller@users.noreply.github.com> Date: Fri, 20 Sep 2024 11:01:38 +0200 Subject: [PATCH] Fixed `PruneConnectors` (#1660) There was a bug in the `PruneConnectors` transformation, the apply function did not prune the sets correctly. I also made some additional changes. --- .../dataflow/prune_connectors.py | 167 +++--------------- .../transformations/prune_connectors_test.py | 155 ++++++++++++++-- 2 files changed, 166 insertions(+), 156 deletions(-) diff --git a/dace/transformation/dataflow/prune_connectors.py b/dace/transformation/dataflow/prune_connectors.py index 499f488448..a8371047df 100644 --- a/dace/transformation/dataflow/prune_connectors.py +++ b/dace/transformation/dataflow/prune_connectors.py @@ -11,67 +11,66 @@ @properties.make_properties class PruneConnectors(pm.SingleStateTransformation): - """ Removes unused connectors from nested SDFGs, as well as their memlets - in the outer scope, replacing them with empty memlets if necessary. + """ + Removes unused connectors from nested SDFGs, as well as their memlets in the outer scope. - Optionally: after pruning, removes the unused containers from parent SDFG. + The transformation will not apply if this would remove all inputs and outputs. """ nsdfg = pm.PatternNode(nodes.NestedSDFG) - remove_unused_containers = properties.Property(dtype=bool, - default=False, - desc='If True, remove unused containers from parent SDFG.') - @classmethod def expressions(cls): return [utils.node_path_graph(cls.nsdfg)] def can_be_applied(self, graph: SDFGState, expr_index: int, sdfg: SDFG, permissive: bool = False) -> bool: + prune_in, prune_out = self._get_prune_sets(graph) + if not prune_in and not prune_out: + return False + + return True + + def _get_prune_sets(self, state: SDFGState) -> Tuple[Set[str], Set[str]]: + """Computes the set of the input and output connectors that can be removed. 
+ + Returns: + A tuple of two sets, the first set contains the name of all input + connectors that can be removed and the second the name of all output + connectors that can be removed. + """ nsdfg = self.nsdfg + # From the input connectors (i.e. data container on the inside) remove + # all those that are not used for reading and from the output containers + # remove those that are not used for writing. + # NOTE: If a data container is used for reading and writing then only the + # output connector is retained, except the output is a WCR, then the input + # is also retained. read_set, write_set = nsdfg.sdfg.read_and_write_sets() prune_in = nsdfg.in_connectors.keys() - read_set prune_out = nsdfg.out_connectors.keys() - write_set - # Take into account symbol mappings - strs = tuple(nsdfg.symbol_mapping.values()) - syms = tuple(symbolic.pystr_to_symbolic(s) for s in strs) - symnames = tuple(s.name if hasattr(s, 'name') else '' for s in syms) - for conn in list(prune_in): - if conn in syms or conn in symnames or conn in nsdfg.sdfg.symbols: - prune_in.remove(conn) - - # Add WCR outputs to "do not prune" input list - for e in graph.out_edges(nsdfg): + for e in state.out_edges(nsdfg): if e.data.wcr is not None and e.src_conn in prune_in: prune_in.remove(e.src_conn) - if not prune_in and not prune_out: - return False - - return True + return prune_in, prune_out def apply(self, state: SDFGState, sdfg: SDFG): nsdfg = self.nsdfg + # Determine which connectors can be removed. 
+ prune_in, prune_out = self._get_prune_sets(state) + # Fission subgraph around nsdfg into its own state to avoid data races nsdfg_state = helpers.state_fission_after(state, nsdfg) - read_set, write_set = nsdfg.sdfg.read_and_write_sets() - prune_in = nsdfg.in_connectors.keys() - read_set - prune_out = nsdfg.out_connectors.keys() - write_set - # Detect which nodes are used, so we can delete unused nodes after the # connectors have been pruned + read_set, write_set = nsdfg.sdfg.read_and_write_sets() all_data_used = read_set | write_set - # Add WCR outputs to "do not prune" input list - for e in nsdfg_state.out_edges(nsdfg): - if e.data.wcr is not None and e.src_conn in prune_in: - prune_in.remove(e.src_conn) - for conn in prune_in: for e in nsdfg_state.in_edges_by_connector(nsdfg, conn): nsdfg_state.remove_memlet_path(e, remove_orphans=True) @@ -89,18 +88,6 @@ def apply(self, state: SDFGState, sdfg: SDFG): # If the data is now unused, we can purge it from the SDFG nsdfg.sdfg.remove_data(conn) - if self.remove_unused_containers: - # Remove unused containers from parent SDFGs - containers = list(sdfg.arrays.keys()) - for name in containers: - s = nsdfg.sdfg - while s.parent_sdfg: - s = s.parent_sdfg - try: - s.remove_data(name) - except ValueError: - break - class PruneSymbols(pm.SingleStateTransformation): """ @@ -177,99 +164,3 @@ def apply(self, graph: SDFGState, sdfg: SDFG): # If not used in SDFG, remove from symbols as well if helpers.is_symbol_unused(nsdfg.sdfg, candidate): nsdfg.sdfg.remove_symbol(candidate) - - -class PruneUnusedOutputs(pm.SingleStateTransformation): - """ - Removes unused symbol mappings from nested SDFGs, as well as internal - symbols if necessary. 
- """ - - nsdfg = pm.PatternNode(nodes.NestedSDFG) - - @classmethod - def expressions(cls): - return [utils.node_path_graph(cls.nsdfg)] - - @classmethod - def _candidates(cls, nsdfg: nodes.NestedSDFG) -> Tuple[Set[str], Set[Tuple[SDFGState, nodes.AccessNode]]]: - # Start with all non-transient arrays - candidates = set(conn for conn in nsdfg.out_connectors.keys()) - candidate_nodes: Set[Tuple[SDFGState, nodes.AccessNode]] = set() - - # Remove candidates that are used more than once in the outer SDFG - state = nsdfg.sdfg.parent - sdfg = nsdfg.sdfg.parent_sdfg - for e in state.out_edges(nsdfg): - if e.data.is_empty(): - continue - outer_desc = sdfg.arrays[e.data.data] - if isinstance(outer_desc, dt.View): - candidates.remove(e.src_conn) - continue - if not outer_desc.transient: - candidates.remove(e.src_conn) - continue - if not isinstance(state.memlet_path(e)[-1].dst, nodes.AccessNode): - candidates.remove(e.src_conn) - continue - - all_access_nodes = [(s, n) for s in sdfg.nodes() for n in s.data_nodes() if n.data == e.data.data] - if len(all_access_nodes) > 1: - candidates.remove(e.src_conn) - continue - if all_access_nodes[0][0].out_degree(all_access_nodes[0][1]) > 0: - candidates.remove(e.src_conn) - continue - - if not candidates: - return set(), set() - - # Remove candidates that are used in the nested SDFG - for nstate in nsdfg.sdfg.states(): - for node in nstate.data_nodes(): - if node.data in candidates: - # If used in nested SDFG - if nstate.out_degree(node) > 0: - candidates.remove(node.data) - continue - # If a result of a code node - if any(not isinstance(nstate.memlet_path(e)[0].src, nodes.AccessNode) - for e in nstate.in_edges(node)): - candidates.remove(node.data) - continue - - # Add node for later use - candidate_nodes.add((nstate, node)) - - # Any array that is used in interstate edges is removed - for e in nsdfg.sdfg.all_interstate_edges(): - candidates -= (set(map(str, symbolic.symbols_in_ast(e.data.condition.code[0])))) - for assign in 
e.data.assignments.values(): - candidates -= (symbolic.free_symbols_and_functions(assign)) - - candidate_nodes = {n for n in candidate_nodes if n[1].data in candidates} - - return candidates, candidate_nodes - - def can_be_applied(self, graph: SDFGState, expr_index: int, sdfg: SDFG, permissive: bool = False) -> bool: - nsdfg: nodes.NestedSDFG = self.nsdfg - candidates, _ = self._candidates(nsdfg) - if len(candidates) > 0: - return True - - return False - - def apply(self, state: SDFGState, sdfg: SDFG): - nsdfg = self.nsdfg - - candidates, candidate_nodes = self._candidates(nsdfg) - for outer_edge in state.out_edges(nsdfg): - if outer_edge.src_conn in candidates: - state.remove_memlet_path(outer_edge) - sdfg.remove_data(outer_edge.data.data, validate=False) - for nstate, node in candidate_nodes: - for ie in nstate.in_edges(node): - nstate.remove_memlet_path(ie) - for cand in candidates: - nsdfg.sdfg.remove_data(cand, validate=False) diff --git a/tests/transformations/prune_connectors_test.py b/tests/transformations/prune_connectors_test.py index 59e1b125ff..4026ec3e1c 100644 --- a/tests/transformations/prune_connectors_test.py +++ b/tests/transformations/prune_connectors_test.py @@ -4,6 +4,8 @@ import os import copy import pytest +from typing import Tuple + import dace from dace.transformation.dataflow import PruneConnectors from dace.transformation.helpers import nest_state_subgraph @@ -137,17 +139,109 @@ def make_sdfg(): return sdfg_outer -@pytest.mark.parametrize("remove_unused_containers", [False, True]) -def test_prune_connectors(remove_unused_containers, n=None): +def _make_read_write_sdfg( + conforming_memlet: bool, +) -> Tuple[dace.SDFG, dace.nodes.NestedSDFG]: + """Creates an SDFG for the `test_read_write_{1, 2}` tests. + + The SDFG is rather synthetic, it has an input `in_arg` and adds to every element + 10 and stores that in array `A`, through access node `A1`. From this access node + the data flows into a nested SDFG. 
However, the data is not read but overwritten, + through a map that writes through access node `inner_A`. That access node + then writes into container `inner_B`. Both `inner_A` and `inner_B` are outputs + of the nested SDFG and are written back into data container `A` and `B`. + + Depending on `conforming_memlet` the memlet that copies `inner_A` into `inner_B` + will either be associated to `inner_A` (`True`) or `inner_B` (`False`). + This choice has consequences on if the transformation can apply or not. + + Notes: + This is most likely a bug, see [issue#1643](https://github.com/spcl/dace/issues/1643), + however, it is the historical behaviour. + """ + + # Creating the outer SDFG. + osdfg = dace.SDFG("Outer_sdfg") + ostate = osdfg.add_state(is_start_block=True) + + osdfg.add_array("in_arg", dtype=dace.float64, shape=(4, 4), transient=False) + osdfg.add_array("A", dtype=dace.float64, shape=(4, 4), transient=False) + osdfg.add_array("B", dtype=dace.float64, shape=(4, 4), transient=False) + in_arg, A1, A2, B = (ostate.add_access(name) for name in ["in_arg", "A", "A", "B"]) + + ostate.add_mapped_tasklet( + "producer", + map_ranges={"i": "0:4", "j": "0:4"}, + inputs={"__in": dace.Memlet("in_arg[i, j]")}, + code="__out = __in + 10.", + outputs={"__out": dace.Memlet("A[i, j]")}, + input_nodes={in_arg}, + output_nodes={A1}, + external_edges=True, + ) + + # Creating the inner SDFG + isdfg = dace.SDFG("Inner_sdfg") + istate = isdfg.add_state(is_start_block=True) + + isdfg.add_array("inner_A", dtype=dace.float64, shape=(4, 4), transient=False) + isdfg.add_array("inner_B", dtype=dace.float64, shape=(4, 4), transient=False) + inner_A, inner_B = (istate.add_access(name) for name in ["inner_A", "inner_B"]) + + istate.add_mapped_tasklet( + "inner_consumer", + map_ranges={"i": "0:4", "j": "0:4"}, + inputs={}, + code="__out = 10", + outputs={"__out": dace.Memlet("inner_A[i, j]")}, + output_nodes={inner_A}, + external_edges=True, + ) + + # Depending on to which data container this 
memlet is associated, + # the transformation will apply or it will not apply. + if conforming_memlet: + # Because the `data` field of the incoming and outgoing memlet are both + # set to `inner_A` the read to `inner_A` will be removed and the + # transformation can apply. + istate.add_nedge( + inner_A, + inner_B, + dace.Memlet("inner_A[0:4, 0:4] -> 0:4, 0:4"), + ) + else: + # Because the `data` field of the involved memlets differs the read to + # `inner_A` will not be removed thus the transformation can not remove + # the incoming `inner_A`. + istate.add_nedge( + inner_A, + inner_B, + dace.Memlet("inner_B[0:4, 0:4] -> 0:4, 0:4"), + ) + + # Add the nested SDFG + nsdfg = ostate.add_nested_sdfg( + sdfg=isdfg, + parent=osdfg, + inputs={"inner_A"}, + outputs={"inner_A", "inner_B"}, + ) + + # Connecting the nested SDFG + ostate.add_edge(A1, None, nsdfg, "inner_A", dace.Memlet("A[0:4, 0:4]")) + ostate.add_edge(nsdfg, "inner_A", A2, None, dace.Memlet("A[0:4, 0:4]")) + ostate.add_edge(nsdfg, "inner_B", B, None, dace.Memlet("B[0:4, 0:4]")) + + return osdfg, nsdfg + + +def test_prune_connectors(n=None): if n is None: n = 64 sdfg = make_sdfg() - if sdfg.apply_transformations_repeated(PruneConnectors, - options=[{ - 'remove_unused_containers': remove_unused_containers - }]) != 3: + if sdfg.apply_transformations_repeated(PruneConnectors) != 3: raise RuntimeError("PruneConnectors was not applied.") arr_in = np.zeros((n, n), dtype=np.uint16) @@ -158,18 +252,16 @@ def test_prune_connectors(remove_unused_containers, n=None): except FileNotFoundError: pass - if remove_unused_containers: - sdfg(read_used=arr_in, write_used=arr_out, N=n) - else: - sdfg(read_used=arr_in, - read_unused=arr_in, - read_used_outer=arr_in, - read_unused_outer=arr_in, - write_used=arr_out, - write_unused=arr_out, - write_used_outer=arr_out, - write_unused_outer=arr_out, - N=n) + # The pruned connectors are not removed so they have to be supplied. 
+ sdfg(read_used=arr_in, + read_unused=arr_in, + read_used_outer=arr_in, + read_unused_outer=arr_in, + write_used=arr_out, + write_unused=arr_out, + write_used_outer=arr_out, + write_unused_outer=arr_out, + N=n) assert np.allclose(arr_out, arr_in + 1) @@ -240,6 +332,16 @@ def test_unused_retval_2(): assert np.allclose(a, 1) +def test_read_write_1(): + # Because the memlet is conforming, we can apply the transformation. + sdfg = _make_read_write_sdfg(True) + + assert first_mode == PruneConnectors.can_be_applied_to(nsdfg=nsdfg, sdfg=osdfg, expr_index=0, permissive=False) + + + + + def test_prune_connectors_with_dependencies(): sdfg = dace.SDFG('tester') A, A_desc = sdfg.add_array('A', [4], dace.float64) @@ -318,6 +420,21 @@ def test_prune_connectors_with_dependencies(): assert np.allclose(np_d, np_d_) +def test_read_write_1(): + # Because the memlet is conforming, we can apply the transformation. + sdfg, nsdfg = _make_read_write_sdfg(True) + + assert PruneConnectors.can_be_applied_to(nsdfg=nsdfg, sdfg=sdfg, expr_index=0, permissive=False) + sdfg.apply_transformations_repeated(PruneConnectors, validate=True, validate_all=True) + + +def test_read_write_2(): + # Because the memlet is not conforming, we can not apply the transformation. + sdfg, nsdfg = _make_read_write_sdfg(False) + + assert not PruneConnectors.can_be_applied_to(nsdfg=nsdfg, sdfg=sdfg, expr_index=0, permissive=False) + + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--N", default=64) @@ -330,3 +447,5 @@ def test_prune_connectors_with_dependencies(): test_unused_retval() test_unused_retval_2() test_prune_connectors_with_dependencies() + test_read_write_1() + test_read_write_2()