diff --git a/dace/transformation/dataflow/prune_connectors.py b/dace/transformation/dataflow/prune_connectors.py index ecc89bc753..865f28f7d9 100644 --- a/dace/transformation/dataflow/prune_connectors.py +++ b/dace/transformation/dataflow/prune_connectors.py @@ -1,12 +1,12 @@ # Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. -from os import stat -from typing import Any, AnyStr, Dict, Optional, Set, Tuple, Union +from typing import Set, Tuple import re -from dace import dtypes, registry, SDFG, SDFGState, symbolic, properties, data as dt +from dace import dtypes, SDFG, SDFGState, symbolic, properties, data as dt from dace.transformation import transformation as pm, helpers from dace.sdfg import nodes, utils from dace.sdfg.analysis import cfg +from dace.sdfg.state import StateSubgraphView @properties.make_properties @@ -46,23 +46,52 @@ def can_be_applied(self, graph: SDFGState, expr_index: int, sdfg: SDFG, permissi # Add WCR outputs to "do not prune" input list for e in graph.out_edges(nsdfg): if e.data.wcr is not None and e.src_conn in prune_in: - if (graph.in_degree(next(iter(graph.in_edges_by_connector(nsdfg, e.src_conn))).src) > 0): - prune_in.remove(e.src_conn) - has_before = all( - graph.in_degree(graph.memlet_path(e)[0].src) > 0 for e in graph.in_edges(nsdfg) if e.dst_conn in prune_in) - has_after = all( - graph.out_degree(graph.memlet_path(e)[-1].dst) > 0 for e in graph.out_edges(nsdfg) - if e.src_conn in prune_out) - if has_before and has_after: + prune_in.remove(e.src_conn) + + if not prune_in and not prune_out: return False - if len(prune_in) > 0 or len(prune_out) > 0: - return True - return False + return True def apply(self, state: SDFGState, sdfg: SDFG): nsdfg = self.nsdfg + # Fission subgraph around nsdfg into its own state to avoid data races + predecessors = set() + for inedge in state.in_edges(nsdfg): + if inedge.data is None: + continue + + pred = state.memlet_path(inedge)[0].src + if state.in_degree(pred) == 0: + continue + + predecessors.add(pred) + for e in state.bfs_edges(pred, reverse=True): + predecessors.add(e.src) + + subgraph = StateSubgraphView(state, predecessors) + pred_state = helpers.state_fission(sdfg, subgraph) + + subgraph_nodes = set() + subgraph_nodes.add(nsdfg) + for inedge in state.in_edges(nsdfg): + if inedge.data is None: + continue + path = state.memlet_path(inedge) + for edge in path: + subgraph_nodes.add(edge.src) + + for oedge in state.out_edges(nsdfg): + if oedge.data is None: + continue + path = state.memlet_path(oedge) + for edge in path: + subgraph_nodes.add(edge.dst) + + subgraph = StateSubgraphView(state, subgraph_nodes) + nsdfg_state = helpers.state_fission(sdfg, subgraph) + read_set, write_set = nsdfg.sdfg.read_and_write_sets() prune_in = nsdfg.in_connectors.keys() - read_set prune_out = nsdfg.out_connectors.keys() - write_set @@ -70,36 +99,26 @@ def apply(self, state: SDFGState, sdfg: SDFG): # Detect which nodes are used, so we can delete unused nodes after the # connectors have been pruned all_data_used = read_set | write_set + # Add WCR outputs to "do not prune" input list - for e in state.out_edges(nsdfg): + for e in nsdfg_state.out_edges(nsdfg): if e.data.wcr is not None and e.src_conn in prune_in: - if (state.in_degree(next(iter(state.in_edges_by_connector(nsdfg, e.src_conn))).src) > 0): - prune_in.remove(e.src_conn) - do_not_prune = set() + prune_in.remove(e.src_conn) + for conn in prune_in: - if any( - state.in_degree(state.memlet_path(e)[0].src) > 0 for e in state.in_edges(nsdfg) - if e.dst_conn == conn): - do_not_prune.add(conn) - continue - for e in state.in_edges_by_connector(nsdfg, conn): - state.remove_memlet_path(e, remove_orphans=True) + for e in nsdfg_state.in_edges_by_connector(nsdfg, conn): + nsdfg_state.remove_memlet_path(e, remove_orphans=True) for conn in prune_out: - if any( - state.out_degree(state.memlet_path(e)[-1].dst) > 0 for e in state.out_edges(nsdfg) - if e.src_conn == conn): - do_not_prune.add(conn) - continue - for e in state.out_edges_by_connector(nsdfg, conn): - state.remove_memlet_path(e, remove_orphans=True) + for e in nsdfg_state.out_edges_by_connector(nsdfg, conn): + nsdfg_state.remove_memlet_path(e, remove_orphans=True) for conn in prune_in: - if conn in nsdfg.sdfg.arrays and conn not in all_data_used and conn not in do_not_prune: + if conn in nsdfg.sdfg.arrays and conn not in all_data_used: # If the data is now unused, we can purge it from the SDFG nsdfg.sdfg.remove_data(conn) for conn in prune_out: - if conn in nsdfg.sdfg.arrays and conn not in all_data_used and conn not in do_not_prune: + if conn in nsdfg.sdfg.arrays and conn not in all_data_used: # If the data is now unused, we can purge it from the SDFG nsdfg.sdfg.remove_data(conn) diff --git a/tests/npbench/polybench/floyd_warshall_test.py b/tests/npbench/polybench/floyd_warshall_test.py index a95a417a19..7bd1e3d91d 100644 --- a/tests/npbench/polybench/floyd_warshall_test.py +++ b/tests/npbench/polybench/floyd_warshall_test.py @@ -7,7 +7,7 @@ import pytest import argparse from dace.fpga_testing import fpga_test -from dace.transformation.interstate import FPGATransformSDFG, InlineSDFG +from dace.transformation.interstate import FPGATransformSDFG, InlineSDFG, StateFusion from dace.transformation.dataflow import StreamingMemory, MapFusion, StreamingComposition, PruneConnectors from dace.transformation.auto.auto_optimize import auto_optimize, fpga_auto_opt @@ -91,15 +91,23 @@ def run_floyd_warshall(device_type: dace.dtypes.DeviceType): }]) assert pruned_conns == 1 + sdfg.apply_transformations_repeated(StateFusion) fpga_auto_opt.fpga_rr_interleave_containers_to_banks(sdfg) # In this case, we want to generate the top-level state as an host-based state, # not an FPGA kernel. We need to explicitly indicate that - sdfg.states()[0].location["is_FPGA_kernel"] = False + for state in sdfg.states(): + if any([isinstance(node, dace.nodes.NestedSDFG) for node in state.nodes()]): + state.location["is_FPGA_kernel"] = False + # we need to specialize both the top-level SDFG and the nested SDFG sdfg.specialize(dict(N=N)) - sdfg.states()[0].nodes()[0].sdfg.specialize(dict(N=N)) + for state in sdfg.states(): + for node in state.nodes(): + if isinstance(node, dace.nodes.NestedSDFG): + node.sdfg.specialize(dict(N=N)) + # run program sdfg(path=path) diff --git a/tests/transformations/prune_connectors_test.py b/tests/transformations/prune_connectors_test.py index 1b9ee4369d..e9c7e34a83 100644 --- a/tests/transformations/prune_connectors_test.py +++ b/tests/transformations/prune_connectors_test.py @@ -2,9 +2,12 @@ import argparse import numpy as np import os +import copy import pytest import dace from dace.transformation.dataflow import PruneConnectors +from dace.transformation.helpers import nest_state_subgraph +from dace.sdfg.state import StateSubgraphView def make_sdfg(): @@ -237,6 +240,84 @@ def test_unused_retval_2(): assert np.allclose(a, 1) +def test_prune_connectors_with_dependencies(): + sdfg = dace.SDFG('tester') + A, A_desc = sdfg.add_array('A', [4], dace.float64) + B, B_desc = sdfg.add_array('B', [4], dace.float64) + C, C_desc = sdfg.add_array('C', [4], dace.float64) + D, D_desc = sdfg.add_array('D', [4], dace.float64) + + state = sdfg.add_state() + a = state.add_access("A") + b1 = state.add_access("B") + b2 = state.add_access("B") + c1 = state.add_access("C") + c2 = state.add_access("C") + d = state.add_access("D") + + _, map_entry_a, map_exit_a = state.add_mapped_tasklet("a", + map_ranges={"i": "0:4"}, + inputs={"_in": dace.Memlet(data="A", subset='i')}, + outputs={"_out": dace.Memlet(data="B", subset='i')}, + code="_out = _in + 1") + state.add_edge(a, None, map_entry_a, None, dace.Memlet(data="A", subset="0:4")) + state.add_edge(map_exit_a, None, b1, None, dace.Memlet(data="B", subset="0:4")) + + tasklet_c, map_entry_c, map_exit_c = state.add_mapped_tasklet("c", + map_ranges={"i": "0:4"}, + inputs={"_in": dace.Memlet(data="C", subset='i')}, + outputs={"_out": dace.Memlet(data="C", subset='i')}, + code="_out = _in + 1") + state.add_edge(c1, None, map_entry_c, None, dace.Memlet(data="C", subset="0:4")) + state.add_edge(map_exit_c, None, c2, None, dace.Memlet(data="C", subset="0:4")) + + _, map_entry_d, map_exit_d = state.add_mapped_tasklet("d", + map_ranges={"i": "0:4"}, + inputs={"_in": dace.Memlet(data="B", subset='i')}, + outputs={"_out": dace.Memlet(data="D", subset='i')}, + code="_out = _in + 1") + state.add_edge(b2, None, map_entry_d, None, dace.Memlet(data="B", subset="0:4")) + state.add_edge(map_exit_d, None, d, None, dace.Memlet(data="D", subset="0:4")) + + sdfg.fill_scope_connectors() + + subgraph = StateSubgraphView(state, subgraph_nodes=[map_entry_c, map_exit_c, tasklet_c]) + nsdfg_node = nest_state_subgraph(sdfg, state, subgraph=subgraph) + + nsdfg_node.sdfg.add_datadesc("B1", datadesc=copy.deepcopy(B_desc)) + nsdfg_node.sdfg.arrays["B1"].transient = False + nsdfg_node.sdfg.add_datadesc("B2", datadesc=copy.deepcopy(B_desc)) + nsdfg_node.sdfg.arrays["B2"].transient = False + + nsdfg_node.add_in_connector("B1") + state.add_edge(b1, None, nsdfg_node, "B1", dace.Memlet.from_array(dataname="B", datadesc=B_desc)) + nsdfg_node.add_out_connector("B2") + state.add_edge(nsdfg_node, "B2", b2, None, dace.Memlet.from_array(dataname="B", datadesc=B_desc)) + + np_a = np.random.random(4) + np_a_ = np.copy(np_a) + np_b = np.random.random(4) + np_b_ = np.copy(np_b) + np_c = np.random.random(4) + np_c_ = np.copy(np_c) + np_d = np.random.random(4) + np_d_ = np.copy(np_d) + + sdfg(A=np_a, B=np_b, C=np_c, D=np_d) + + applied = sdfg.apply_transformations_repeated(PruneConnectors) + assert applied == 1 + assert len(sdfg.states()) == 3 + assert "B1" not in nsdfg_node.in_connectors + assert "B2" not in nsdfg_node.out_connectors + + sdfg(A=np_a_, B=np_b_, C=np_c_, D=np_d_) + assert np.allclose(np_a, np_a_) + assert np.allclose(np_b, np_b_) + assert np.allclose(np_c, np_c_) + assert np.allclose(np_d, np_d_) + + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--N", default=64) @@ -248,3 +329,4 @@ def test_unused_retval_2(): test_prune_connectors(True, n=n) test_unused_retval() test_unused_retval_2() + test_prune_connectors_with_dependencies()