From edbf49f2339e487d2cabac112bce011dce580dc6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lukas=20Tr=C3=BCmper?= Date: Wed, 29 Nov 2023 20:27:20 +0100 Subject: [PATCH] PruneConnectors: Fission into separate states before pruning (#1451) The PruneConnectors transformation currently avoids pruning connectors of access nodes that are in turn connected to other nodes. Fissioning first, pruning, and then fusing states simplifies the whole problem, because we can simply use the analysis implemented in StateFusion. --- .../dataflow/prune_connectors.py | 87 +++++++++++-------- .../npbench/polybench/floyd_warshall_test.py | 14 ++- .../transformations/prune_connectors_test.py | 82 +++++++++++++++++ 3 files changed, 146 insertions(+), 37 deletions(-) diff --git a/dace/transformation/dataflow/prune_connectors.py b/dace/transformation/dataflow/prune_connectors.py index ecc89bc753..865f28f7d9 100644 --- a/dace/transformation/dataflow/prune_connectors.py +++ b/dace/transformation/dataflow/prune_connectors.py @@ -1,12 +1,12 @@ # Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. 
-from os import stat -from typing import Any, AnyStr, Dict, Optional, Set, Tuple, Union +from typing import Set, Tuple import re -from dace import dtypes, registry, SDFG, SDFGState, symbolic, properties, data as dt +from dace import dtypes, SDFG, SDFGState, symbolic, properties, data as dt from dace.transformation import transformation as pm, helpers from dace.sdfg import nodes, utils from dace.sdfg.analysis import cfg +from dace.sdfg.state import StateSubgraphView @properties.make_properties @@ -46,23 +46,52 @@ def can_be_applied(self, graph: SDFGState, expr_index: int, sdfg: SDFG, permissi # Add WCR outputs to "do not prune" input list for e in graph.out_edges(nsdfg): if e.data.wcr is not None and e.src_conn in prune_in: - if (graph.in_degree(next(iter(graph.in_edges_by_connector(nsdfg, e.src_conn))).src) > 0): - prune_in.remove(e.src_conn) - has_before = all( - graph.in_degree(graph.memlet_path(e)[0].src) > 0 for e in graph.in_edges(nsdfg) if e.dst_conn in prune_in) - has_after = all( - graph.out_degree(graph.memlet_path(e)[-1].dst) > 0 for e in graph.out_edges(nsdfg) - if e.src_conn in prune_out) - if has_before and has_after: + prune_in.remove(e.src_conn) + + if not prune_in and not prune_out: return False - if len(prune_in) > 0 or len(prune_out) > 0: - return True - return False + return True def apply(self, state: SDFGState, sdfg: SDFG): nsdfg = self.nsdfg + # Fission subgraph around nsdfg into its own state to avoid data races + predecessors = set() + for inedge in state.in_edges(nsdfg): + if inedge.data is None: + continue + + pred = state.memlet_path(inedge)[0].src + if state.in_degree(pred) == 0: + continue + + predecessors.add(pred) + for e in state.bfs_edges(pred, reverse=True): + predecessors.add(e.src) + + subgraph = StateSubgraphView(state, predecessors) + pred_state = helpers.state_fission(sdfg, subgraph) + + subgraph_nodes = set() + subgraph_nodes.add(nsdfg) + for inedge in state.in_edges(nsdfg): + if inedge.data is None: + continue + path = 
state.memlet_path(inedge) + for edge in path: + subgraph_nodes.add(edge.src) + + for oedge in state.out_edges(nsdfg): + if oedge.data is None: + continue + path = state.memlet_path(oedge) + for edge in path: + subgraph_nodes.add(edge.dst) + + subgraph = StateSubgraphView(state, subgraph_nodes) + nsdfg_state = helpers.state_fission(sdfg, subgraph) + read_set, write_set = nsdfg.sdfg.read_and_write_sets() prune_in = nsdfg.in_connectors.keys() - read_set prune_out = nsdfg.out_connectors.keys() - write_set @@ -70,36 +99,26 @@ def apply(self, state: SDFGState, sdfg: SDFG): # Detect which nodes are used, so we can delete unused nodes after the # connectors have been pruned all_data_used = read_set | write_set + # Add WCR outputs to "do not prune" input list - for e in state.out_edges(nsdfg): + for e in nsdfg_state.out_edges(nsdfg): if e.data.wcr is not None and e.src_conn in prune_in: - if (state.in_degree(next(iter(state.in_edges_by_connector(nsdfg, e.src_conn))).src) > 0): - prune_in.remove(e.src_conn) - do_not_prune = set() + prune_in.remove(e.src_conn) + for conn in prune_in: - if any( - state.in_degree(state.memlet_path(e)[0].src) > 0 for e in state.in_edges(nsdfg) - if e.dst_conn == conn): - do_not_prune.add(conn) - continue - for e in state.in_edges_by_connector(nsdfg, conn): - state.remove_memlet_path(e, remove_orphans=True) + for e in nsdfg_state.in_edges_by_connector(nsdfg, conn): + nsdfg_state.remove_memlet_path(e, remove_orphans=True) for conn in prune_out: - if any( - state.out_degree(state.memlet_path(e)[-1].dst) > 0 for e in state.out_edges(nsdfg) - if e.src_conn == conn): - do_not_prune.add(conn) - continue - for e in state.out_edges_by_connector(nsdfg, conn): - state.remove_memlet_path(e, remove_orphans=True) + for e in nsdfg_state.out_edges_by_connector(nsdfg, conn): + nsdfg_state.remove_memlet_path(e, remove_orphans=True) for conn in prune_in: - if conn in nsdfg.sdfg.arrays and conn not in all_data_used and conn not in do_not_prune: + if conn in 
nsdfg.sdfg.arrays and conn not in all_data_used: # If the data is now unused, we can purge it from the SDFG nsdfg.sdfg.remove_data(conn) for conn in prune_out: - if conn in nsdfg.sdfg.arrays and conn not in all_data_used and conn not in do_not_prune: + if conn in nsdfg.sdfg.arrays and conn not in all_data_used: # If the data is now unused, we can purge it from the SDFG nsdfg.sdfg.remove_data(conn) diff --git a/tests/npbench/polybench/floyd_warshall_test.py b/tests/npbench/polybench/floyd_warshall_test.py index a95a417a19..7bd1e3d91d 100644 --- a/tests/npbench/polybench/floyd_warshall_test.py +++ b/tests/npbench/polybench/floyd_warshall_test.py @@ -7,7 +7,7 @@ import pytest import argparse from dace.fpga_testing import fpga_test -from dace.transformation.interstate import FPGATransformSDFG, InlineSDFG +from dace.transformation.interstate import FPGATransformSDFG, InlineSDFG, StateFusion from dace.transformation.dataflow import StreamingMemory, MapFusion, StreamingComposition, PruneConnectors from dace.transformation.auto.auto_optimize import auto_optimize, fpga_auto_opt @@ -91,15 +91,23 @@ def run_floyd_warshall(device_type: dace.dtypes.DeviceType): }]) assert pruned_conns == 1 + sdfg.apply_transformations_repeated(StateFusion) fpga_auto_opt.fpga_rr_interleave_containers_to_banks(sdfg) # In this case, we want to generate the top-level state as an host-based state, # not an FPGA kernel. 
We need to explicitly indicate that - sdfg.states()[0].location["is_FPGA_kernel"] = False + for state in sdfg.states(): + if any([isinstance(node, dace.nodes.NestedSDFG) for node in state.nodes()]): + state.location["is_FPGA_kernel"] = False + # we need to specialize both the top-level SDFG and the nested SDFG sdfg.specialize(dict(N=N)) - sdfg.states()[0].nodes()[0].sdfg.specialize(dict(N=N)) + for state in sdfg.states(): + for node in state.nodes(): + if isinstance(node, dace.nodes.NestedSDFG): + node.sdfg.specialize(dict(N=N)) + # run program sdfg(path=path) diff --git a/tests/transformations/prune_connectors_test.py b/tests/transformations/prune_connectors_test.py index 1b9ee4369d..e9c7e34a83 100644 --- a/tests/transformations/prune_connectors_test.py +++ b/tests/transformations/prune_connectors_test.py @@ -2,9 +2,12 @@ import argparse import numpy as np import os +import copy import pytest import dace from dace.transformation.dataflow import PruneConnectors +from dace.transformation.helpers import nest_state_subgraph +from dace.sdfg.state import StateSubgraphView def make_sdfg(): @@ -237,6 +240,84 @@ def test_unused_retval_2(): assert np.allclose(a, 1) +def test_prune_connectors_with_dependencies(): + sdfg = dace.SDFG('tester') + A, A_desc = sdfg.add_array('A', [4], dace.float64) + B, B_desc = sdfg.add_array('B', [4], dace.float64) + C, C_desc = sdfg.add_array('C', [4], dace.float64) + D, D_desc = sdfg.add_array('D', [4], dace.float64) + + state = sdfg.add_state() + a = state.add_access("A") + b1 = state.add_access("B") + b2 = state.add_access("B") + c1 = state.add_access("C") + c2 = state.add_access("C") + d = state.add_access("D") + + _, map_entry_a, map_exit_a = state.add_mapped_tasklet("a", + map_ranges={"i": "0:4"}, + inputs={"_in": dace.Memlet(data="A", subset='i')}, + outputs={"_out": dace.Memlet(data="B", subset='i')}, + code="_out = _in + 1") + state.add_edge(a, None, map_entry_a, None, dace.Memlet(data="A", subset="0:4")) + state.add_edge(map_exit_a, 
None, b1, None, dace.Memlet(data="B", subset="0:4")) + + tasklet_c, map_entry_c, map_exit_c = state.add_mapped_tasklet("c", + map_ranges={"i": "0:4"}, + inputs={"_in": dace.Memlet(data="C", subset='i')}, + outputs={"_out": dace.Memlet(data="C", subset='i')}, + code="_out = _in + 1") + state.add_edge(c1, None, map_entry_c, None, dace.Memlet(data="C", subset="0:4")) + state.add_edge(map_exit_c, None, c2, None, dace.Memlet(data="C", subset="0:4")) + + _, map_entry_d, map_exit_d = state.add_mapped_tasklet("d", + map_ranges={"i": "0:4"}, + inputs={"_in": dace.Memlet(data="B", subset='i')}, + outputs={"_out": dace.Memlet(data="D", subset='i')}, + code="_out = _in + 1") + state.add_edge(b2, None, map_entry_d, None, dace.Memlet(data="B", subset="0:4")) + state.add_edge(map_exit_d, None, d, None, dace.Memlet(data="D", subset="0:4")) + + sdfg.fill_scope_connectors() + + subgraph = StateSubgraphView(state, subgraph_nodes=[map_entry_c, map_exit_c, tasklet_c]) + nsdfg_node = nest_state_subgraph(sdfg, state, subgraph=subgraph) + + nsdfg_node.sdfg.add_datadesc("B1", datadesc=copy.deepcopy(B_desc)) + nsdfg_node.sdfg.arrays["B1"].transient = False + nsdfg_node.sdfg.add_datadesc("B2", datadesc=copy.deepcopy(B_desc)) + nsdfg_node.sdfg.arrays["B2"].transient = False + + nsdfg_node.add_in_connector("B1") + state.add_edge(b1, None, nsdfg_node, "B1", dace.Memlet.from_array(dataname="B", datadesc=B_desc)) + nsdfg_node.add_out_connector("B2") + state.add_edge(nsdfg_node, "B2", b2, None, dace.Memlet.from_array(dataname="B", datadesc=B_desc)) + + np_a = np.random.random(4) + np_a_ = np.copy(np_a) + np_b = np.random.random(4) + np_b_ = np.copy(np_b) + np_c = np.random.random(4) + np_c_ = np.copy(np_c) + np_d = np.random.random(4) + np_d_ = np.copy(np_d) + + sdfg(A=np_a, B=np_b, C=np_c, D=np_d) + + applied = sdfg.apply_transformations_repeated(PruneConnectors) + assert applied == 1 + assert len(sdfg.states()) == 3 + assert "B1" not in nsdfg_node.in_connectors + assert "B2" not in 
nsdfg_node.out_connectors + + sdfg(A=np_a_, B=np_b_, C=np_c_, D=np_d_) + assert np.allclose(np_a, np_a_) + assert np.allclose(np_b, np_b_) + assert np.allclose(np_c, np_c_) + assert np.allclose(np_d, np_d_) + + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--N", default=64) @@ -248,3 +329,4 @@ def test_unused_retval_2(): test_prune_connectors(True, n=n) test_unused_retval() test_unused_retval_2() + test_prune_connectors_with_dependencies()