diff --git a/dace/sdfg/utils.py b/dace/sdfg/utils.py index 68980c3b10..4746fefe97 100644 --- a/dace/sdfg/utils.py +++ b/dace/sdfg/utils.py @@ -1295,32 +1295,43 @@ def inline_sdfgs(sdfg: SDFG, permissive: bool = False, progress: bool = None, mu from dace.transformation.interstate import InlineSDFG, InlineMultistateSDFG counter = 0 - nsdfgs = [(n, p) for n, p in sdfg.all_nodes_recursive() if isinstance(n, NestedSDFG)] - - for node, state in optional_progressbar(reversed(nsdfgs), title='Inlining SDFGs', n=len(nsdfgs), progress=progress): - id = node.sdfg.cfg_id - sd = state.parent + nsdfgs = [n for n, _ in sdfg.all_nodes_recursive() if isinstance(n, NestedSDFG)] + for nsdfg_node in optional_progressbar(reversed(nsdfgs), title='Inlining SDFGs', n=len(nsdfgs), progress=progress): # We have to reevaluate every time due to changing IDs - state_id = sd.node_id(state) + # e.g., InlineMultistateSDFG may fission states + parent_state = nsdfg_node.sdfg.parent + parent_sdfg = parent_state.parent + parent_state_id = parent_sdfg.node_id(parent_state) + if multistate: candidate = { - InlineMultistateSDFG.nested_sdfg: node, + InlineMultistateSDFG.nested_sdfg: nsdfg_node, } inliner = InlineMultistateSDFG() - inliner.setup_match(sd, id, state_id, candidate, 0, override=True) - if inliner.can_be_applied(state, 0, sd, permissive=permissive): - inliner.apply(state, sd) + inliner.setup_match(sdfg=parent_sdfg, + cfg_id=parent_sdfg.sdfg_id, + state_id=parent_state_id, + subgraph=candidate, + expr_index=0, + override=True) + if inliner.can_be_applied(parent_state, 0, parent_sdfg, permissive=permissive): + inliner.apply(parent_state, parent_sdfg) counter += 1 continue candidate = { - InlineSDFG.nested_sdfg: node, + InlineSDFG.nested_sdfg: nsdfg_node, } inliner = InlineSDFG() - inliner.setup_match(sd, id, state_id, candidate, 0, override=True) - if inliner.can_be_applied(state, 0, sd, permissive=permissive): - inliner.apply(state, sd) + inliner.setup_match(sdfg=parent_sdfg, + cfg_id=parent_sdfg.sdfg_id, + state_id=parent_state_id, + subgraph=candidate, + expr_index=0, + override=True) + if inliner.can_be_applied(parent_state, 0, parent_sdfg, permissive=permissive): + inliner.apply(parent_state, parent_sdfg) counter += 1 return counter diff --git a/dace/transformation/dataflow/prune_connectors.py b/dace/transformation/dataflow/prune_connectors.py index 865f28f7d9..36352fef0d 100644 --- a/dace/transformation/dataflow/prune_connectors.py +++ b/dace/transformation/dataflow/prune_connectors.py @@ -57,40 +57,7 @@ def apply(self, state: SDFGState, sdfg: SDFG): nsdfg = self.nsdfg # Fission subgraph around nsdfg into its own state to avoid data races - predecessors = set() - for inedge in state.in_edges(nsdfg): - if inedge.data is None: - continue - - pred = state.memlet_path(inedge)[0].src - if state.in_degree(pred) == 0: - continue - - predecessors.add(pred) - for e in state.bfs_edges(pred, reverse=True): - predecessors.add(e.src) - - subgraph = StateSubgraphView(state, predecessors) - pred_state = helpers.state_fission(sdfg, subgraph) - - subgraph_nodes = set() - subgraph_nodes.add(nsdfg) - for inedge in state.in_edges(nsdfg): - if inedge.data is None: - continue - path = state.memlet_path(inedge) - for edge in path: - subgraph_nodes.add(edge.src) - - for oedge in state.out_edges(nsdfg): - if oedge.data is None: - continue - path = state.memlet_path(oedge) - for edge in path: - subgraph_nodes.add(edge.dst) - - subgraph = StateSubgraphView(state, subgraph_nodes) - nsdfg_state = helpers.state_fission(sdfg, subgraph) + nsdfg_state = helpers.state_fission_after(sdfg, state, nsdfg) read_set, write_set = nsdfg.sdfg.read_and_write_sets() prune_in = nsdfg.in_connectors.keys() - read_set diff --git a/dace/transformation/dataflow/wcr_conversion.py b/dace/transformation/dataflow/wcr_conversion.py index 3ef508f7e5..1a0ecf6bc4 100644 --- a/dace/transformation/dataflow/wcr_conversion.py +++ b/dace/transformation/dataflow/wcr_conversion.py @@ -77,8 +77,7 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive=False): # If in map, only match if the subset is independent of any # map indices (otherwise no conflict) - if not permissive and len(outedge.data.subset.free_symbols & set(me.map.params)) == len( - me.map.params): + if not permissive and len(outedge.data.subset.free_symbols & set(me.map.params)) == len(me.map.params): return False # Get relevant output connector @@ -151,18 +150,16 @@ def apply(self, state: SDFGState, sdfg: SDFG): # If state fission is necessary to keep semantics, do it first if state.in_degree(input) > 0: - subgraph_nodes = set([e.src for e in state.bfs_edges(input, reverse=True)]) - subgraph_nodes.add(input) - - subgraph = StateSubgraphView(state, subgraph_nodes) - helpers.state_fission(sdfg, subgraph) + new_state = helpers.state_fission_after(sdfg, state, tasklet) + else: + new_state = state if self.expr_index == 0: - inedges = state.edges_between(input, tasklet) - outedge = state.edges_between(tasklet, output)[0] + inedges = new_state.edges_between(input, tasklet) + outedge = new_state.edges_between(tasklet, output)[0] else: - inedges = state.edges_between(me, tasklet) - outedge = state.edges_between(tasklet, mx)[0] + inedges = new_state.edges_between(me, tasklet) + outedge = new_state.edges_between(tasklet, mx)[0] # Get relevant output connector outconn = outedge.src_conn @@ -253,8 +250,8 @@ def apply(self, state: SDFGState, sdfg: SDFG): outedge.data.wcr = f'lambda a,b: a {op} b' # Remove input node and connector - state.remove_memlet_path(inedge) - propagate_memlets_state(sdfg, state) + new_state.remove_memlet_path(inedge) + propagate_memlets_state(sdfg, new_state) # If outedge leads to non-transient, and this is a nested SDFG, # propagate outwards diff --git a/dace/transformation/helpers.py b/dace/transformation/helpers.py index c39d744c39..cd73b96a68 100644 --- a/dace/transformation/helpers.py +++ b/dace/transformation/helpers.py @@ -687,6 +687,85 @@ def state_fission(sdfg: SDFG, subgraph: graph.SubgraphView, label: Optional[str] return newstate +def state_fission_after(sdfg: SDFG, state: SDFGState, node: nodes.Node, label: Optional[str] = None) -> SDFGState: + """ + """ + newstate = sdfg.add_state_after(state, label=label) + + # Bookkeeping + nodes_to_move = set([node]) + boundary_nodes = set() + orig_edges = set() + + # Collect predecessors + if not isinstance(node, nodes.AccessNode): + for edge in state.in_edges(node): + for e in state.memlet_path(edge): + nodes_to_move.add(e.src) + orig_edges.add(e) + + # Collect nodes_to_move + for edge in state.bfs_edges(node): + nodes_to_move.add(edge.dst) + orig_edges.add(edge) + + if not isinstance(edge.dst, nodes.AccessNode): + for iedge in state.in_edges(edge.dst): + if iedge == edge: + continue + + for e in state.memlet_path(iedge): + nodes_to_move.add(e.src) + orig_edges.add(e) + + # Define boundary nodes + for node in set(nodes_to_move): + if isinstance(node, nodes.AccessNode): + for iedge in state.in_edges(node): + if iedge.src not in nodes_to_move: + boundary_nodes.add(node) + break + + if node in boundary_nodes: + continue + + for oedge in state.out_edges(node): + if oedge.dst not in nodes_to_move: + boundary_nodes.add(node) + break + + # Duplicate boundary nodes + new_nodes = {} + for node in boundary_nodes: + node_ = copy.deepcopy(node) + state.add_node(node_) + new_nodes[node] = node_ + + for edge in state.edges(): + if edge.src in boundary_nodes and edge.dst in boundary_nodes: + state.add_edge(new_nodes[edge.src], edge.src_conn, new_nodes[edge.dst], edge.dst_conn, + copy.deepcopy(edge.data)) + elif edge.src in boundary_nodes: + state.add_edge(new_nodes[edge.src], edge.src_conn, edge.dst, edge.dst_conn, copy.deepcopy(edge.data)) + elif edge.dst in boundary_nodes: + state.add_edge(edge.src, edge.src_conn, new_nodes[edge.dst], edge.dst_conn, copy.deepcopy(edge.data)) + + # Move nodes + state.remove_nodes_from(nodes_to_move) + + for n in nodes_to_move: + if isinstance(n, nodes.NestedSDFG): + # Set the new parent state + n.sdfg.parent = newstate + + newstate.add_nodes_from(nodes_to_move) + + for e in orig_edges: + newstate.add_edge(e.src, e.src_conn, e.dst, e.dst_conn, e.data) + + return newstate + + def _get_internal_subset(internal_memlet: Memlet, external_memlet: Memlet, use_src_subset: bool = False, diff --git a/dace/transformation/interstate/multistate_inline.py b/dace/transformation/interstate/multistate_inline.py index 8623bdf468..0e4f1b4852 100644 --- a/dace/transformation/interstate/multistate_inline.py +++ b/dace/transformation/interstate/multistate_inline.py @@ -20,6 +20,7 @@ from dace.transformation import transformation, helpers from dace.properties import make_properties, Property from dace import data +from dace.sdfg.state import StateSubgraphView @make_properties @@ -85,56 +86,48 @@ def can_be_applied(self, state: SDFGState, expr_index, sdfg, permissive=False): if nested_sdfg.schedule == dtypes.ScheduleType.FPGA_Device: return False - # Ensure the state only contains a nested SDFG and input/output access - # nodes - for node in state.nodes(): - if isinstance(node, nodes.NestedSDFG): - if node is not nested_sdfg: - return False - elif isinstance(node, nodes.AccessNode): - # Must be connected to nested SDFG - # if nested_sdfg in state.predecessors(nested_sdfg): - # if state.in_degree(node) > 0: - # return False - found = False - for e in state.out_edges(node): - if e.dst is not nested_sdfg: - return False - if state.in_degree(node) > 0: - return False - # Only accept full ranges for now. TODO(later): Improve - if e.data.subset != subsets.Range.from_array(sdfg.arrays[node.data]): - return False - if e.dst_conn in nested_sdfg.sdfg.arrays: - # Do not accept views. TODO(later): Improve - outer_desc = sdfg.arrays[node.data] - inner_desc = nested_sdfg.sdfg.arrays[e.dst_conn] - if (outer_desc.shape != inner_desc.shape or outer_desc.strides != inner_desc.strides): - return False - found = True - - for e in state.in_edges(node): - if e.src is not nested_sdfg: - return False - if state.out_degree(node) > 0: - return False - # Only accept full ranges for now. TODO(later): Improve - if e.data.subset != subsets.Range.from_array(sdfg.arrays[node.data]): - return False - if e.src_conn in nested_sdfg.sdfg.arrays: - # Do not accept views. TODO(later): Improve - outer_desc = sdfg.arrays[node.data] - inner_desc = nested_sdfg.sdfg.arrays[e.src_conn] - if (outer_desc.shape != inner_desc.shape or outer_desc.strides != inner_desc.strides): - return False - found = True - - # elif nested_sdfg in state.successors(nested_sdfg): - # if state.out_degree(node) > 0: - # return False - if not found: - return False - else: + # Not nested in scope + if state.entry_node(nested_sdfg) is not None: + return False + + # Must be + # - connected to access nodes only + # - read full subsets + # - not use views inside + for edge in state.in_edges(nested_sdfg): + if edge.data.data is None: + return False + + if not isinstance(edge.src, nodes.AccessNode): + return False + + if edge.data.subset != subsets.Range.from_array(sdfg.arrays[edge.data.data]): + return False + + outer_desc = sdfg.arrays[edge.data.data] + if isinstance(outer_desc, data.View): + return False + + inner_desc = nested_sdfg.sdfg.arrays[edge.dst_conn] + if (outer_desc.shape != inner_desc.shape or outer_desc.strides != inner_desc.strides): + return False + + for edge in state.out_edges(nested_sdfg): + if edge.data.data is None: + return False + + if not isinstance(edge.dst, nodes.AccessNode): + return False + + if edge.data.subset != subsets.Range.from_array(sdfg.arrays[edge.data.data]): + return False + + outer_desc = sdfg.arrays[edge.data.data] + if isinstance(outer_desc, data.View): + return False + + inner_desc = nested_sdfg.sdfg.arrays[edge.src_conn] + if (outer_desc.shape != inner_desc.shape or outer_desc.strides != inner_desc.strides): return False return True @@ -168,16 +161,27 @@ def apply(self, outer_state: SDFGState, sdfg: SDFG): for ise in sdfg.edges(): outer_symbols.update(ise.data.new_symbols(sdfg, outer_symbols)) + # Isolate nsdfg in a separate state + # 1. Push nsdfg node plus dependencies down into new state + nsdfg_state = helpers.state_fission_after(sdfg, outer_state, nsdfg_node) + # 2. Push successors of nsdfg node into a later state + direct_subgraph = set() + direct_subgraph.add(nsdfg_node) + direct_subgraph.update(nsdfg_state.predecessors(nsdfg_node)) + direct_subgraph.update(nsdfg_state.successors(nsdfg_node)) + direct_subgraph = StateSubgraphView(nsdfg_state, direct_subgraph) + nsdfg_state = helpers.state_fission(sdfg, direct_subgraph) + # Find original source/destination edges (there is only one edge per # connector, according to match) inputs: Dict[str, MultiConnectorEdge] = {} outputs: Dict[str, MultiConnectorEdge] = {} input_set: Dict[str, str] = {} output_set: Dict[str, str] = {} - for e in outer_state.in_edges(nsdfg_node): + for e in nsdfg_state.in_edges(nsdfg_node): inputs[e.dst_conn] = e input_set[e.data.data] = e.dst_conn - for e in outer_state.out_edges(nsdfg_node): + for e in nsdfg_state.out_edges(nsdfg_node): outputs[e.src_conn] = e output_set[e.data.data] = e.src_conn @@ -260,7 +264,6 @@ def apply(self, outer_state: SDFGState, sdfg: SDFG): name = sdfg.add_datadesc(new_name, datadesc, find_new_name=True) transients[edge.data.data] = name - # All constants (and associated transients) become constants of the parent for cstname, (csttype, cstval) in nsdfg.constants_prop.items(): if cstname in sdfg.constants: @@ -273,7 +276,6 @@ def apply(self, outer_state: SDFGState, sdfg: SDFG): else: sdfg.constants_prop[cstname] = (csttype, cstval) - ####################################################### # Replace data on inlined SDFG nodes/edges @@ -352,9 +354,9 @@ def apply(self, outer_state: SDFGState, sdfg: SDFG): sinks = nsdfg.sink_nodes() # Reconnect state machine - for e in sdfg.in_edges(outer_state): + for e in sdfg.in_edges(nsdfg_state): sdfg.add_edge(e.src, source, e.data) - for e in sdfg.out_edges(outer_state): + for e in sdfg.out_edges(nsdfg_state): for sink in sinks: sdfg.add_edge(sink, e.dst, dc(e.data)) # Redirect sink incoming edges with a `False` condition to e.dst (return statements) @@ -363,7 +365,7 @@ def apply(self, outer_state: SDFGState, sdfg: SDFG): sdfg.add_edge(e2.src, e.dst, InterstateEdge()) # Modify start state as necessary - if outer_start_state is outer_state: + if outer_start_state is nsdfg_state: sdfg.start_state = sdfg.node_id(source) # TODO: Modify memlets by offsetting @@ -418,7 +420,7 @@ def apply(self, outer_state: SDFGState, sdfg: SDFG): ####################################################### # Remove nested SDFG and state - sdfg.remove_node(outer_state) + sdfg.remove_node(nsdfg_state) sdfg._cfg_list = sdfg.reset_cfg_list() diff --git a/tests/inlining_test.py b/tests/inlining_test.py index d207aa6c2c..7c3510daed 100644 --- a/tests/inlining_test.py +++ b/tests/inlining_test.py @@ -127,15 +127,16 @@ def outerprog(A: dace.float64[20]): nested(A) sdfg = outerprog.to_sdfg(simplify=True) - from dace.transformation.interstate import InlineMultistateSDFG - sdfg.apply_transformations(InlineMultistateSDFG) - assert sdfg.number_of_nodes() in (4, 5) A = np.random.rand(20) expected = np.copy(A) outerprog.f(expected) - outerprog(A) + from dace.transformation.interstate import InlineMultistateSDFG + sdfg.apply_transformations(InlineMultistateSDFG) + assert sdfg.number_of_nodes() in (4, 5) + + sdfg(A) assert np.allclose(A, expected) @@ -152,18 +153,105 @@ def outerprog(A: dace.float64[20]): nested(A) sdfg = outerprog.to_sdfg(simplify=True) - from dace.transformation.interstate import InlineMultistateSDFG - sdfg.apply_transformations(InlineMultistateSDFG) - assert sdfg.number_of_nodes() in (7, 8) A = np.random.rand(20) expected = np.copy(A) outerprog.f(expected) - outerprog(A) + from dace.transformation.interstate import InlineMultistateSDFG + sdfg.apply_transformations(InlineMultistateSDFG) + assert sdfg.number_of_nodes() in (7, 8) + + sdfg(A) assert np.allclose(A, expected) +def test_multistate_inline_outer_dependencies(): + + @dace.program + def nested(A: dace.float64[20]): + for i in range(1, 20): + A[i] += A[i - 1] + + @dace.program + def outerprog(A: dace.float64[20], B: dace.float64[20]): + for i in dace.map[0:20]: + with dace.tasklet: + a >> A[i] + b >> B[i] + + a = 0 + b = 1 + + nested(A) + + for i in dace.map[0:20]: + with dace.tasklet: + a << A[i] + b >> A[i] + + b = 2 * a + + sdfg = outerprog.to_sdfg(simplify=False) + sdfg.apply_transformations_repeated((StateFusion, InlineSDFG)) + assert len(sdfg.states()) == 1 + + A = np.random.rand(20) + B = np.random.rand(20) + expected_a = np.copy(A) + expected_b = np.copy(B) + outerprog.f(expected_a, expected_b) + + from dace.transformation.interstate import InlineMultistateSDFG + sdfg.apply_transformations(InlineMultistateSDFG) + + sdfg(A, B) + assert np.allclose(A, expected_a) + assert np.allclose(B, expected_b) + + +def test_multistate_inline_concurrent_subgraphs(): + + @dace.program + def nested(A: dace.float64[10], B: dace.float64[10]): + for i in range(1, 10): + B[i] = A[i] + + @dace.program + def outerprog(A: dace.float64[10], B: dace.float64[10], C: dace.float64[10]): + nested(A, B) + + for i in dace.map[0:10]: + with dace.tasklet: + a << A[i] + c >> C[i] + + c = 2 * a + + sdfg = outerprog.to_sdfg(simplify=False) + dace.propagate_memlets_sdfg(sdfg) + sdfg.apply_transformations_repeated((StateFusion, InlineSDFG)) + assert len(sdfg.states()) == 1 + assert len([node for node in sdfg.start_state.data_nodes()]) == 3 + + A = np.random.rand(10) + B = np.random.rand(10) + C = np.random.rand(10) + expected_a = np.copy(A) + expected_b = np.copy(B) + expected_c = np.copy(C) + outerprog.f(expected_a, expected_b, expected_c) + + from dace.transformation.interstate import InlineMultistateSDFG + applied = sdfg.apply_transformations(InlineMultistateSDFG) + assert applied == 1 + + sdfg(A, B, C) + assert np.allclose(A, expected_a) + assert np.allclose(B, expected_b) + assert np.allclose(C, expected_c) + + def test_inline_symexpr(): nsdfg = dace.SDFG('inner') nsdfg.add_array('a', [20], dace.float64) @@ -372,6 +460,8 @@ def test(A: dace.float64[96, 32], B: dace.float64[42, 32]): # test_regression_reshape_unsqueeze() test_empty_memlets() test_multistate_inline() + test_multistate_inline_outer_dependencies() + test_multistate_inline_concurrent_subgraphs() test_multistate_inline_samename() test_inline_symexpr() test_inline_unsqueeze() diff --git a/tests/transformations/prune_connectors_test.py b/tests/transformations/prune_connectors_test.py index e9c7e34a83..59e1b125ff 100644 --- a/tests/transformations/prune_connectors_test.py +++ b/tests/transformations/prune_connectors_test.py @@ -307,7 +307,7 @@ def test_prune_connectors_with_dependencies(): applied = sdfg.apply_transformations_repeated(PruneConnectors) assert applied == 1 - assert len(sdfg.states()) == 3 + assert len(sdfg.states()) == 2 assert "B1" not in nsdfg_node.in_connectors assert "B2" not in nsdfg_node.out_connectors