diff --git a/dace/transformation/interstate/loop_detection.py b/dace/transformation/interstate/loop_detection.py index da225232fe..93c2f6ea1c 100644 --- a/dace/transformation/interstate/loop_detection.py +++ b/dace/transformation/interstate/loop_detection.py @@ -3,12 +3,11 @@ import sympy as sp import networkx as nx -import typing -from typing import AnyStr, Optional, Tuple, List +from typing import AnyStr, Optional, Tuple, List, Set from dace import sdfg as sd, symbolic -from dace.sdfg import graph as gr, utils as sdutil -from dace.sdfg.state import ControlFlowRegion +from dace.sdfg import graph as gr, utils as sdutil, InterstateEdge +from dace.sdfg.state import ControlFlowRegion, ControlFlowBlock from dace.transformation import transformation @@ -17,10 +16,19 @@ class DetectLoop(transformation.PatternTransformation): """ Detects a for-loop construct from an SDFG. """ - loop_guard = transformation.PatternNode(sd.SDFGState) + # Always available loop_begin = transformation.PatternNode(sd.SDFGState) exit_state = transformation.PatternNode(sd.SDFGState) + # Available for natural loops + loop_guard = transformation.PatternNode(sd.SDFGState) + + # Available for rotated loops + loop_latch = transformation.PatternNode(sd.SDFGState) + + # Available for rotated and self loops + entry_state = transformation.PatternNode(sd.SDFGState) + @classmethod def expressions(cls): # Case 1: Loop with one state @@ -31,39 +39,98 @@ def expressions(cls): sdfg.add_edge(cls.loop_begin, cls.loop_guard, sd.InterstateEdge()) # Case 2: Loop with multiple states (no back-edge from state) + # The reason for the second case is that subgraph isomorphism requires accounting for every involved edge msdfg = gr.OrderedDiGraph() msdfg.add_nodes_from([cls.loop_guard, cls.loop_begin, cls.exit_state]) msdfg.add_edge(cls.loop_guard, cls.loop_begin, sd.InterstateEdge()) msdfg.add_edge(cls.loop_guard, cls.exit_state, sd.InterstateEdge()) - return [sdfg, msdfg] + # Case 3: Rotated single-state loop + # Here the loop latch (like guard) is the last state in the loop + rsdfg = gr.OrderedDiGraph() + rsdfg.add_nodes_from([cls.entry_state, cls.loop_latch, cls.loop_begin, cls.exit_state]) + rsdfg.add_edge(cls.entry_state, cls.loop_begin, sd.InterstateEdge()) + rsdfg.add_edge(cls.loop_begin, cls.loop_latch, sd.InterstateEdge()) + rsdfg.add_edge(cls.loop_latch, cls.loop_begin, sd.InterstateEdge()) + rsdfg.add_edge(cls.loop_latch, cls.exit_state, sd.InterstateEdge()) + + # Case 4: Rotated multi-state loop + # The reason for this case is also that subgraph isomorphism requires accounting for every involved edge + rmsdfg = gr.OrderedDiGraph() + rmsdfg.add_nodes_from([cls.entry_state, cls.loop_latch, cls.loop_begin, cls.exit_state]) + rmsdfg.add_edge(cls.entry_state, cls.loop_begin, sd.InterstateEdge()) + rmsdfg.add_edge(cls.loop_latch, cls.loop_begin, sd.InterstateEdge()) + rmsdfg.add_edge(cls.loop_latch, cls.exit_state, sd.InterstateEdge()) + + # Case 5: Self-loop + ssdfg = gr.OrderedDiGraph() + ssdfg.add_nodes_from([cls.entry_state, cls.loop_begin, cls.exit_state]) + ssdfg.add_edge(cls.entry_state, cls.loop_begin, sd.InterstateEdge()) + ssdfg.add_edge(cls.loop_begin, cls.loop_begin, sd.InterstateEdge()) + ssdfg.add_edge(cls.loop_begin, cls.exit_state, sd.InterstateEdge()) + + return [sdfg, msdfg, rsdfg, rmsdfg, ssdfg] + + def can_be_applied(self, + graph: ControlFlowRegion, + expr_index: int, + sdfg: sd.SDFG, + permissive: bool = False) -> bool: + if expr_index == 0: + return self.detect_loop(graph, False) is not None + elif expr_index == 1: + return self.detect_loop(graph, True) is not None + elif expr_index == 2: + return self.detect_rotated_loop(graph, False) is not None + elif expr_index == 3: + return self.detect_rotated_loop(graph, True) is not None + elif expr_index == 4: + return self.detect_self_loop(graph) is not None + + raise ValueError(f'Invalid expression index {expr_index}') + + def detect_loop(self, graph: ControlFlowRegion, multistate_loop: bool) -> Optional[str]: + """ + Detects a loop of the form: + + .. code-block:: text + + ---------------- + | v + entry -> guard -> body exit + ^ | + ---------- + - def can_be_applied(self, graph, expr_index, sdfg, permissive=False): + :param graph: The graph to look for the loop. + :param multistate_loop: Whether the loop contains multiple states. + :return: The loop variable or ``None`` if not detected. + """ guard = self.loop_guard begin = self.loop_begin # A for-loop guard only has two incoming edges (init and increment) guard_inedges = graph.in_edges(guard) if len(guard_inedges) < 2: - return False + return None # A for-loop guard only has two outgoing edges (loop and exit-loop) guard_outedges = graph.out_edges(guard) if len(guard_outedges) != 2: - return False + return None # All incoming edges to the guard must set the same variable - itvar = None + itvar: Optional[Set[str]] = None for iedge in guard_inedges: if itvar is None: itvar = set(iedge.data.assignments.keys()) else: itvar &= iedge.data.assignments.keys() if itvar is None: - return False + return None # Outgoing edges must be a negation of each other if guard_outedges[0].data.condition_sympy() != (sp.Not(guard_outedges[1].data.condition_sympy())): - return False + return None # All nodes inside loop must be dominated by loop guard dominators = nx.dominance.immediate_dominators(graph.nx, graph.start_block) @@ -84,23 +151,274 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive=False): break dom = dominators[dom] else: - return False + return None + + if backedge is None: + return None + + # The backedge must reassign the iteration variable + itvar &= backedge.data.assignments.keys() + if len(itvar) != 1: + # Either no consistent iteration variable found, or too many + # consistent iteration variables found + return None + + return next(iter(itvar)) + + def detect_rotated_loop(self, graph: ControlFlowRegion, multistate_loop: bool) -> Optional[str]: + """ + Detects a loop of the form: + + .. code-block:: text + + entry -> body -> latch -> exit + ^ | + ---------- + + + :param graph: The graph to look for the loop. + :param multistate_loop: Whether the loop contains multiple states. + :return: The loop variable or ``None`` if not detected. + """ + latch = self.loop_latch + begin = self.loop_begin + + # A for-loop start has at least two incoming edges (init and increment) + begin_inedges = graph.in_edges(begin) + if len(begin_inedges) < 2: + return None + # A for-loop latch only has two outgoing edges (loop condition and exit-loop) + latch_outedges = graph.out_edges(latch) + if len(latch_outedges) != 2: + return None + + # All incoming edges to the start of the loop must set the same variable + itvar = None + for iedge in begin_inedges: + if itvar is None: + itvar = set(iedge.data.assignments.keys()) + else: + itvar &= iedge.data.assignments.keys() + if itvar is None: + return None + + # Outgoing edges must be a negation of each other + if latch_outedges[0].data.condition_sympy() != (sp.Not(latch_outedges[1].data.condition_sympy())): + return None + + # All nodes inside loop must be dominated by loop start + dominators = nx.dominance.immediate_dominators(graph.nx, graph.start_block) + loop_nodes = list(sdutil.dfs_conditional(graph, sources=[begin], condition=lambda _, child: child != latch)) + loop_nodes += [latch] + backedge = None + for node in loop_nodes: + for e in graph.out_edges(node): + if e.dst == begin: + backedge = e + break + + # Traverse the dominator tree upwards, if we reached the beginning, + # the node is in the loop. If we reach any node in the loop + # without passing through the loop start, fail. + dom = node + while dom != dominators[dom]: + if dom == begin: + break + dom = dominators[dom] + else: + return None if backedge is None: - return False + return None - # The backedge must assignment the iteration variable + # The backedge must reassign the iteration variable itvar &= backedge.data.assignments.keys() if len(itvar) != 1: # Either no consistent iteration variable found, or too many # consistent iteration variables found - return False + return None + + return next(iter(itvar)) - return True + def detect_self_loop(self, graph: ControlFlowRegion) -> Optional[str]: + """ + Detects a loop of the form: + + .. code-block:: text + + entry -> body state -> exit + ^ | + ------ + + + :param graph: The graph to look for the loop. + :return: The loop variable or ``None`` if not detected. + """ + body = self.loop_begin + + # A self-loop body must have only two incoming edges (initialize, increment) + body_inedges = graph.in_edges(body) + if len(body_inedges) != 2: + return None + # A self-loop body must have only two outgoing edges (condition success + increment, condition fail) + body_outedges = graph.out_edges(body) + if len(body_outedges) != 2: + return None + + # All incoming edges to the body must set the same variable + itvar = None + for iedge in body_inedges: + if itvar is None: + itvar = set(iedge.data.assignments.keys()) + else: + itvar &= iedge.data.assignments.keys() + if itvar is None: + return None + + # Outgoing edges must be a negation of each other + if body_outedges[0].data.condition_sympy() != (sp.Not(body_outedges[1].data.condition_sympy())): + return None + + # Backedge is the self-edge + edges = graph.edges_between(body, body) + if len(edges) != 1: + return None + backedge = edges[0] + + # The backedge must reassign the iteration variable + itvar &= backedge.data.assignments.keys() + if len(itvar) != 1: + # Either no consistent iteration variable found, or too many + # consistent iteration variables found + return None + + return next(iter(itvar)) def apply(self, _, sdfg): pass + ############################################ + # Functionality that provides loop metadata + + def loop_information( + self, + itervar: Optional[str] = None + ) -> Optional[Tuple[AnyStr, Tuple[symbolic.SymbolicType, symbolic.SymbolicType, symbolic.SymbolicType], Tuple[ + List[sd.SDFGState], sd.SDFGState]]]: + + entry = self.loop_begin + if self.expr_index <= 1: + guard = self.loop_guard + return find_for_loop(guard.parent_graph, guard, entry, itervar) + elif self.expr_index in (2, 3): + latch = self.loop_latch + return find_rotated_for_loop(latch.parent_graph, latch, entry, itervar) + elif self.expr_index == 4: + return find_rotated_for_loop(entry.parent_graph, entry, entry, itervar) + + raise ValueError(f'Invalid expression index {self.expr_index}') + + def loop_body(self) -> List[ControlFlowBlock]: + """ + Returns a list of all control flow blocks (or states) contained in the loop. + """ + begin = self.loop_begin + graph = begin.parent_graph + if self.expr_index in (0, 1): + guard = self.loop_guard + return list(sdutil.dfs_conditional(graph, sources=[begin], condition=lambda _, child: child != guard)) + elif self.expr_index in (2, 3): + latch = self.loop_latch + loop_nodes = list(sdutil.dfs_conditional(graph, sources=[begin], condition=lambda _, child: child != latch)) + loop_nodes += [latch] + return loop_nodes + elif self.expr_index == 4: + return [begin] + + return [] + + def loop_meta_states(self) -> List[ControlFlowBlock]: + """ + Returns the non-body control-flow blocks of this loop (e.g., guard, latch). + """ + if self.expr_index in (0, 1): + return [self.loop_guard] + if self.expr_index in (2, 3): + return [self.loop_latch] + return [] + + def loop_init_edge(self) -> gr.Edge[InterstateEdge]: + """ + Returns the initialization edge of the loop (assignment to the beginning of the loop range). + """ + begin = self.loop_begin + graph = begin.parent_graph + if self.expr_index in (0, 1): + guard = self.loop_guard + body = self.loop_body() + return next(e for e in graph.in_edges(guard) if e.src not in body) + elif self.expr_index in (2, 3): + latch = self.loop_latch + return next(e for e in graph.in_edges(begin) if e.src is not latch) + elif self.expr_index == 4: + return next(e for e in graph.in_edges(begin) if e.src is not begin) + + raise ValueError(f'Invalid expression index {self.expr_index}') + + def loop_exit_edge(self) -> gr.Edge[InterstateEdge]: + """ + Returns the negative condition edge that exits the loop. + """ + exitstate = self.exit_state + graph = exitstate.parent_graph + if self.expr_index in (0, 1): + guard = self.loop_guard + return graph.edges_between(guard, exitstate)[0] + elif self.expr_index in (2, 3): + latch = self.loop_latch + return graph.edges_between(latch, exitstate)[0] + elif self.expr_index == 4: + begin = self.loop_begin + return graph.edges_between(begin, exitstate)[0] + + raise ValueError(f'Invalid expression index {self.expr_index}') + + def loop_condition_edge(self) -> gr.Edge[InterstateEdge]: + """ + Returns the positive condition edge that (re-)enters the loop after the bound check. + """ + begin = self.loop_begin + graph = begin.parent_graph + if self.expr_index in (0, 1): + guard = self.loop_guard + return graph.edges_between(guard, begin)[0] + elif self.expr_index in (2, 3): + latch = self.loop_latch + return graph.edges_between(latch, begin)[0] + elif self.expr_index == 4: + begin = self.loop_begin + return graph.edges_between(begin, begin)[0] + + raise ValueError(f'Invalid expression index {self.expr_index}') + + def loop_increment_edge(self) -> gr.Edge[InterstateEdge]: + """ + Returns the back-edge that increments the loop induction variable. + """ + begin = self.loop_begin + graph = begin.parent_graph + if self.expr_index in (0, 1): + guard = self.loop_guard + body = self.loop_body() + return next(e for e in graph.in_edges(guard) if e.src in body) + elif self.expr_index in (2, 3): + body = self.loop_body() + return next(e for e in graph.in_edges(begin) if e.src in body) + elif self.expr_index == 4: + return graph.edges_between(begin, begin)[0] + + raise ValueError(f'Invalid expression index {self.expr_index}') + def find_for_loop( graph: ControlFlowRegion, @@ -114,7 +432,8 @@ def find_for_loop( :param guard: State from which the outgoing edges detect whether to exit the loop or not. - :param entry: First state in the loop "body". + :param entry: First state in the loop body. + :param itervar: An optional field that overrides the analyzed iteration variable. :return: (iteration variable, (start, end, stride), (start_states, last_loop_state)), or None if proper for-loop was not detected. ``end`` is inclusive. @@ -123,7 +442,7 @@ def find_for_loop( # Extract state transition edge information guard_inedges = graph.in_edges(guard) condition_edge = graph.edges_between(guard, entry)[0] - + # All incoming edges to the guard must set the same variable if itervar is None: itervars = None @@ -137,7 +456,7 @@ def find_for_loop( else: # Ambiguous or no iteration variable return None - + condition = condition_edge.data.condition_sympy() # Find the stride edge. All in-edges to the guard except for the stride edge @@ -206,3 +525,109 @@ def find_for_loop( return None return itervar, (start, end, stride), (start_states, last_loop_state) + + +def find_rotated_for_loop( + graph: ControlFlowRegion, + latch: sd.SDFGState, + entry: sd.SDFGState, + itervar: Optional[str] = None +) -> Optional[Tuple[AnyStr, Tuple[symbolic.SymbolicType, symbolic.SymbolicType, symbolic.SymbolicType], Tuple[ + List[sd.SDFGState], sd.SDFGState]]]: + """ + Finds rotated loop range from state machine. + + :param latch: State from which the outgoing edges detect whether to exit + the loop or not. + :param entry: First state in the loop body. + :param itervar: An optional field that overrides the analyzed iteration variable. + :return: (iteration variable, (start, end, stride), + (start_states, last_loop_state)), or None if proper + for-loop was not detected. ``end`` is inclusive. + """ + # Extract state transition edge information + entry_inedges = graph.in_edges(entry) + condition_edge = graph.edges_between(latch, entry)[0] + + # All incoming edges to the loop entry must set the same variable + if itervar is None: + itervars = None + for iedge in entry_inedges: + if itervars is None: + itervars = set(iedge.data.assignments.keys()) + else: + itervars &= iedge.data.assignments.keys() + if itervars and len(itervars) == 1: + itervar = next(iter(itervars)) + else: + # Ambiguous or no iteration variable + return None + + condition = condition_edge.data.condition_sympy() + + # Find the stride edge. All in-edges to the entry except for the stride edge + # should have exactly the same assignment, since a valid for loop can only + # have one assignment. + init_edges = [] + init_assignment = None + step_edge = None + itersym = symbolic.symbol(itervar) + for iedge in entry_inedges: + assignment = iedge.data.assignments[itervar] + if itersym in symbolic.pystr_to_symbolic(assignment).free_symbols: + if step_edge is None: + step_edge = iedge + else: + # More than one edge with the iteration variable as a free + # symbol, which is not legal. Invalid for loop. + return None + else: + if init_assignment is None: + init_assignment = assignment + init_edges.append(iedge) + elif init_assignment != assignment: + # More than one init assignment variations mean that this for + # loop is not valid. + return None + else: + init_edges.append(iedge) + if step_edge is None or len(init_edges) == 0 or init_assignment is None: + # Less than two assignment variations, can't be a valid for loop. + return None + + # Get the init expression and the stride. + start = symbolic.pystr_to_symbolic(init_assignment) + stride = (symbolic.pystr_to_symbolic(step_edge.data.assignments[itervar]) - itersym) + + # Get a list of the last states before the loop and a reference to the last + # loop state. + start_states = [] + for init_edge in init_edges: + start_state = init_edge.src + if start_state not in start_states: + start_states.append(start_state) + last_loop_state = step_edge.src + + # Find condition by matching expressions + end: Optional[symbolic.SymbolicType] = None + a = sp.Wild('a') + match = condition.match(itersym < a) + if match: + end = match[a] - 1 + if end is None: + match = condition.match(itersym <= a) + if match: + end = match[a] + if end is None: + match = condition.match(itersym > a) + if match: + end = match[a] + 1 + if end is None: + match = condition.match(itersym >= a) + if match: + end = match[a] + + if end is None: # No match found + return None + + return itervar, (start, end, stride), (start_states, last_loop_state) diff --git a/dace/transformation/interstate/loop_peeling.py b/dace/transformation/interstate/loop_peeling.py index 5dc998c724..c2e50cd37a 100644 --- a/dace/transformation/interstate/loop_peeling.py +++ b/dace/transformation/interstate/loop_peeling.py @@ -79,17 +79,16 @@ def _modify_cond(self, condition, var, step): def apply(self, graph: ControlFlowRegion, sdfg: sd.SDFG): #################################################################### # Obtain loop information - guard: sd.SDFGState = self.loop_guard begin: sd.SDFGState = self.loop_begin after_state: sd.SDFGState = self.exit_state # Obtain iteration variable, range, and stride - condition_edge = graph.edges_between(guard, begin)[0] - not_condition_edge = graph.edges_between(guard, after_state)[0] - itervar, rng, loop_struct = find_for_loop(graph, guard, begin) + condition_edge = self.loop_condition_edge() + not_condition_edge = self.loop_exit_edge() + itervar, rng, loop_struct = self.loop_information() # Get loop states - loop_states = list(sdutil.dfs_conditional(graph, sources=[begin], condition=lambda _, child: child != guard)) + loop_states = self.loop_body() first_id = loop_states.index(begin) last_state = loop_struct[1] last_id = loop_states.index(last_state) @@ -104,7 +103,7 @@ def apply(self, graph: ControlFlowRegion, sdfg: sd.SDFG): init_edges = [] before_states = loop_struct[0] for before_state in before_states: - init_edge = graph.edges_between(before_state, guard)[0] + init_edge = self.loop_init_edge() init_edge.data.assignments[itervar] = str(rng[0] + self.count * rng[2]) init_edges.append(init_edge) append_states = before_states @@ -133,7 +132,7 @@ def apply(self, graph: ControlFlowRegion, sdfg: sd.SDFG): if append_state not in before_states: for init_edge in init_edges: graph.remove_edge(init_edge) - graph.add_edge(append_state, guard, init_edges[0].data) + graph.add_edge(append_state, init_edge.dst, init_edges[0].data) else: # If begin, change initialization assignment and prepend states before # guard @@ -164,4 +163,4 @@ def apply(self, graph: ControlFlowRegion, sdfg: sd.SDFG): # Reconnect edge to guard state from last peeled iteration if prepend_state != after_state: graph.remove_edge(not_condition_edge) - graph.add_edge(guard, prepend_state, not_condition_edge.data) + graph.add_edge(not_condition_edge.src, prepend_state, not_condition_edge.data) diff --git a/dace/transformation/interstate/loop_to_map.py b/dace/transformation/interstate/loop_to_map.py index 7344b54161..39410f2547 100644 --- a/dace/transformation/interstate/loop_to_map.py +++ b/dace/transformation/interstate/loop_to_map.py @@ -95,15 +95,16 @@ def can_be_applied(self, graph: SDFGState, expr_index: int, sdfg: SDFG, permissi if not super().can_be_applied(graph, expr_index, sdfg, permissive): return False - guard = self.loop_guard begin = self.loop_begin # Guard state should not contain any dataflow - if len(guard.nodes()) != 0: - return False + if expr_index <= 1: + guard = self.loop_guard + if len(guard.nodes()) != 0: + return False # If loop cannot be detected, fail - found = find_for_loop(graph, guard, begin, itervar=self.itervar) + found = self.loop_information(itervar=self.itervar) if not found: return False @@ -123,7 +124,7 @@ def can_be_applied(self, graph: SDFGState, expr_index: int, sdfg: SDFG, permissi return False # Find all loop-body states - states: List[SDFGState] = list(sdutil.dfs_conditional(sdfg, [begin], lambda _, c: c is not guard)) + states: List[SDFGState] = self.loop_body() assert (body_end in states) @@ -349,22 +350,15 @@ def apply(self, _, sdfg: sd.SDFG): from dace.sdfg.propagation import align_memlet # Obtain loop information - guard: sd.SDFGState = self.loop_guard + itervar, (start, end, step), (_, body_end) = self.loop_information(itervar=self.itervar) + states = self.loop_body() body: sd.SDFGState = self.loop_begin - after: sd.SDFGState = self.exit_state - - # Obtain iteration variable, range, and stride - itervar, (start, end, step), (_, body_end) = find_for_loop(sdfg, guard, body, itervar=self.itervar) - - # Find all loop-body states - states = set() - to_visit = [body] - while to_visit: - state = to_visit.pop(0) - for _, dst, _ in sdfg.out_edges(state): - if dst not in states and dst is not guard: - to_visit.append(dst) - states.add(state) + exit_state = self.exit_state + entry_edge = self.loop_condition_edge() + init_edge = self.loop_init_edge() + after_edge = self.loop_exit_edge() + condition_edge = self.loop_condition_edge() + increment_edge = self.loop_increment_edge() nsdfg = None @@ -425,7 +419,7 @@ def apply(self, _, sdfg: sd.SDFG): nsdfg = SDFG("loop_body", constants=sdfg.constants_prop, parent=new_body) nsdfg.add_node(body, is_start_state=True) body.parent = nsdfg - exit_state = nsdfg.add_state('exit') + nexit_state = nsdfg.add_state('exit') nsymbols = dict() for state in states: if state is body: @@ -438,20 +432,48 @@ def apply(self, _, sdfg: sd.SDFG): for src, dst, data in sdfg.in_edges(state): nsymbols.update({s: sdfg.symbols[s] for s in data.assignments.keys() if s in sdfg.symbols}) nsdfg.add_edge(src, dst, data) - nsdfg.add_edge(body_end, exit_state, InterstateEdge()) + nsdfg.add_edge(body_end, nexit_state, InterstateEdge()) - # Move guard -> body edge to guard -> new_body - for src, dst, data, in sdfg.edges_between(guard, body): - sdfg.add_edge(src, new_body, data) - # Move body_end -> guard edge to new_body -> guard - for src, dst, data in sdfg.edges_between(body_end, guard): - sdfg.add_edge(new_body, dst, data) + increment_edge = None - # Delete loop-body states and edges from parent SDFG - for state in states: - for e in sdfg.all_edges(state): + # Specific instructions for loop type + if self.expr_index <= 1: # Natural loop with guard + guard = self.loop_guard + + # Move guard -> body edge to guard -> new_body + for e in sdfg.edges_between(guard, body): + sdfg.remove_edge(e) + condition_edge = sdfg.add_edge(e.src, new_body, e.data) + # Move body_end -> guard edge to new_body -> guard + for e in sdfg.edges_between(body_end, guard): sdfg.remove_edge(e) - sdfg.remove_node(state) + increment_edge = sdfg.add_edge(new_body, e.dst, e.data) + + + elif 1 < self.expr_index <= 3: # Rotated loop + entrystate = self.entry_state + latch = self.loop_latch + + # Move entry edge to entry -> new_body + for src, dst, data, in sdfg.edges_between(entrystate, body): + init_edge = sdfg.add_edge(src, new_body, data) + + # Move body_end -> latch to new_body -> latch + for src, dst, data in sdfg.edges_between(latch, exit_state): + after_edge = sdfg.add_edge(new_body, dst, data) + + elif self.expr_index == 4: # Self-loop + entrystate = self.entry_state + + # Move entry edge to entry -> new_body + for src, dst, data in sdfg.edges_between(entrystate, body): + init_edge = sdfg.add_edge(src, new_body, data) + for src, dst, data in sdfg.edges_between(body, exit_state): + after_edge = sdfg.add_edge(new_body, dst, data) + + + # Delete loop-body states and edges from parent SDFG + sdfg.remove_nodes_from(states) # Add NestedSDFG arrays for name in read_set | write_set: @@ -490,12 +512,13 @@ def apply(self, _, sdfg: sd.SDFG): # correct map with a positive increment start, end, step = end, start, -step + reentry_assignments = {k: v for k, v in condition_edge.data.assignments.items() if k != itervar} + # If necessary, make a nested SDFG with assignments - isedge = sdfg.edges_between(guard, body)[0] symbols_to_remove = set() - if len(isedge.data.assignments) > 0: + if len(reentry_assignments) > 0: nsdfg = helpers.nest_state_subgraph(sdfg, body, gr.SubgraphView(body, body.nodes())) - for sym in isedge.data.free_symbols: + for sym in entry_edge.data.free_symbols: if sym in nsdfg.symbol_mapping or sym in nsdfg.in_connectors: continue if sym in sdfg.symbols: @@ -522,12 +545,12 @@ def apply(self, _, sdfg: sd.SDFG): nstate = nsdfg.sdfg.node(0) init_state = nsdfg.sdfg.add_state_before(nstate) nisedge = nsdfg.sdfg.edges_between(init_state, nstate)[0] - nisedge.data.assignments = isedge.data.assignments + nisedge.data.assignments = reentry_assignments symbols_to_remove = set(nisedge.data.assignments.keys()) for k in nisedge.data.assignments.keys(): if k in nsdfg.symbol_mapping: del nsdfg.symbol_mapping[k] - isedge.data.assignments = {} + condition_edge.data.assignments = {} source_nodes = body.source_nodes() sink_nodes = body.sink_nodes() @@ -541,8 +564,8 @@ def apply(self, _, sdfg: sd.SDFG): continue # Arrays written with subsets that do not depend on the loop variable must be thread-local map_dependency = False - for e in state.in_edges(node): - subset = e.data.get_dst_subset(e, state) + for e in body.in_edges(node): + subset = e.data.get_dst_subset(e, body) if any(str(s) == itervar for s in subset.free_symbols): map_dependency = True break @@ -644,25 +667,26 @@ def apply(self, _, sdfg: sd.SDFG): if not source_nodes and not sink_nodes: body.add_nedge(entry, exit, memlet.Memlet()) - # Get rid of the loop exit condition edge - after_edge = sdfg.edges_between(guard, after)[0] + # Get rid of the loop exit condition edge (it will be readded below) sdfg.remove_edge(after_edge) # Remove the assignment on the edge to the guard - for e in sdfg.in_edges(guard): + for e in [init_edge, increment_edge]: + if e is None: + continue if itervar in e.data.assignments: del e.data.assignments[itervar] # Remove the condition on the entry edge - condition_edge = sdfg.edges_between(guard, body)[0] condition_edge.data.condition = CodeBlock("1") # Get rid of backedge to guard - sdfg.remove_edge(sdfg.edges_between(body, guard)[0]) + if increment_edge is not None: + sdfg.remove_edge(increment_edge) # Route body directly to after state, maintaining any other assignments # it might have had - sdfg.add_edge(body, after, sd.InterstateEdge(assignments=after_edge.data.assignments)) + sdfg.add_edge(body, exit_state, sd.InterstateEdge(assignments=after_edge.data.assignments)) # If this had made the iteration variable a free symbol, we can remove # it from the SDFG symbols diff --git a/dace/transformation/interstate/loop_unroll.py b/dace/transformation/interstate/loop_unroll.py index e6592b5519..663745c0d6 100644 --- a/dace/transformation/interstate/loop_unroll.py +++ b/dace/transformation/interstate/loop_unroll.py @@ -30,9 +30,7 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive=False): if not super().can_be_applied(graph, expr_index, sdfg, permissive): return False - guard = self.loop_guard - begin = self.loop_begin - found = find_for_loop(graph, guard, begin) + found = self.loop_information() # If loop cannot be detected, fail if not found: @@ -49,20 +47,19 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive=False): def apply(self, graph: ControlFlowRegion, sdfg): # Obtain loop information - guard: sd.SDFGState = self.loop_guard begin: sd.SDFGState = self.loop_begin after_state: sd.SDFGState = self.exit_state # Obtain iteration variable, range, and stride, together with the last # state(s) before the loop and the last loop state. - itervar, rng, loop_struct = find_for_loop(graph, guard, begin) + itervar, rng, loop_struct = self.loop_information() # Loop must be fully unrollable for now. if self.count != 0: raise NotImplementedError # TODO(later) # Get loop states - loop_states = list(sdutil.dfs_conditional(graph, sources=[begin], condition=lambda _, child: child != guard)) + loop_states = self.loop_body() first_id = loop_states.index(begin) last_state = loop_struct[1] last_id = loop_states.index(last_state) @@ -91,7 +88,7 @@ def apply(self, graph: ControlFlowRegion, sdfg): unrolled_states.append((new_states[first_id], new_states[last_id])) # Get any assignments that might be on the edge to the after state - after_assignments = (graph.edges_between(guard, after_state)[0].data.assignments) + after_assignments = self.loop_exit_edge().data.assignments # Connect new states to before and after states without conditions if unrolled_states: @@ -101,7 +98,8 @@ def apply(self, graph: ControlFlowRegion, sdfg): graph.add_edge(unrolled_states[-1][1], after_state, sd.InterstateEdge(assignments=after_assignments)) # Remove old states from SDFG - graph.remove_nodes_from([guard] + loop_states) + guard_or_latch = self.loop_meta_states() + graph.remove_nodes_from(guard_or_latch + loop_states) def instantiate_loop( self, diff --git a/dace/transformation/interstate/move_loop_into_map.py b/dace/transformation/interstate/move_loop_into_map.py index 916f9c5e41..29a9906fe0 100644 --- a/dace/transformation/interstate/move_loop_into_map.py +++ b/dace/transformation/interstate/move_loop_into_map.py @@ -35,12 +35,10 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive=False): return False # Obtain loop information - guard: sd.SDFGState = self.loop_guard body: sd.SDFGState = self.loop_begin - after: sd.SDFGState = self.exit_state # Obtain iteration variable, range, and stride - loop_info = find_for_loop(sdfg, guard, body) + loop_info = self.loop_information() if not loop_info: return False itervar, (start, end, step), (_, body_end) = loop_info @@ -157,11 +155,10 @@ def test_subset_dependency(subset: sbs.Subset, mparams: Set[int]) -> Tuple[bool, def apply(self, _, sdfg: sd.SDFG): # Obtain loop information - guard: sd.SDFGState = self.loop_guard body: sd.SDFGState = self.loop_begin # Obtain iteration variable, range, and stride - itervar, (start, end, step), _ = find_for_loop(sdfg, guard, body) + itervar, (start, end, step), _ = self.loop_information() forward_loop = step > 0 @@ -194,26 +191,31 @@ def apply(self, _, sdfg: sd.SDFG): else: guard_body_edge = e - for body_inedge in sdfg.in_edges(body): - if body_inedge.src is guard: - guard_body_edge.data.assignments.update(body_inedge.data.assignments) - sdfg.remove_edge(body_inedge) - for body_outedge in sdfg.out_edges(body): - sdfg.remove_edge(body_outedge) - for guard_inedge in sdfg.in_edges(guard): - before_guard_edge.data.assignments.update(guard_inedge.data.assignments) - guard_inedge.data.assignments = {} - sdfg.add_edge(guard_inedge.src, body, guard_inedge.data) - sdfg.remove_edge(guard_inedge) - for guard_outedge in sdfg.out_edges(guard): - if guard_outedge.dst is body: - guard_body_edge.data.assignments.update(guard_outedge.data.assignments) - else: - guard_after_edge.data.assignments.update(guard_outedge.data.assignments) - guard_outedge.data.condition = CodeBlock("1") - sdfg.add_edge(body, guard_outedge.dst, guard_outedge.data) - sdfg.remove_edge(guard_outedge) - sdfg.remove_node(guard) + if self.expr_index <= 1: + guard = self.loop_guard + for body_inedge in sdfg.in_edges(body): + if body_inedge.src is guard: + guard_body_edge.data.assignments.update(body_inedge.data.assignments) + sdfg.remove_edge(body_inedge) + for body_outedge in sdfg.out_edges(body): + sdfg.remove_edge(body_outedge) + for guard_inedge in sdfg.in_edges(guard): + before_guard_edge.data.assignments.update(guard_inedge.data.assignments) + guard_inedge.data.assignments = {} + sdfg.add_edge(guard_inedge.src, body, guard_inedge.data) + sdfg.remove_edge(guard_inedge) + for guard_outedge in sdfg.out_edges(guard): + if guard_outedge.dst is body: + guard_body_edge.data.assignments.update(guard_outedge.data.assignments) + else: + guard_after_edge.data.assignments.update(guard_outedge.data.assignments) + guard_outedge.data.condition = CodeBlock("1") + sdfg.add_edge(body, guard_outedge.dst, guard_outedge.data) + sdfg.remove_edge(guard_outedge) + sdfg.remove_node(guard) + else: # Rotated or self loops + raise NotImplementedError('MoveLoopIntoMap not implemented for rotated and self-loops') + if itervar in nsdfg.symbol_mapping: del nsdfg.symbol_mapping[itervar] if itervar in sdfg.symbols: diff --git a/dace/transformation/interstate/trivial_loop_elimination.py b/dace/transformation/interstate/trivial_loop_elimination.py index d214cb5343..411d9ff07d 100644 --- a/dace/transformation/interstate/trivial_loop_elimination.py +++ b/dace/transformation/interstate/trivial_loop_elimination.py @@ -18,12 +18,8 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive=False): if not super().can_be_applied(graph, expr_index, sdfg, permissive): return False - # Obtain loop information - guard: sd.SDFGState = self.loop_guard - body: sd.SDFGState = self.loop_begin - # Obtain iteration variable, range, and stride - loop_info = find_for_loop(sdfg, guard, body) + loop_info = self.loop_information() if not loop_info: return False _, (start, end, step), _ = loop_info @@ -41,39 +37,26 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive=False): def apply(self, _, sdfg: sd.SDFG): # Obtain loop information - guard: sd.SDFGState = self.loop_guard - body: sd.SDFGState = self.loop_begin - # Obtain iteration variable, range and stride - itervar, (start, end, step), (_, body_end) = find_for_loop(sdfg, guard, body) - - # Find all loop-body states - states = set() - to_visit = [body] - while to_visit: - state = to_visit.pop(0) - for _, dst, _ in sdfg.out_edges(state): - if dst not in states and dst is not guard: - to_visit.append(dst) - states.add(state) + itervar, (start, end, step), (_, body_end) = self.loop_information() + states = self.loop_body() for state in states: state.replace(itervar, start) - # remove loop - for body_inedge in sdfg.in_edges(body): - sdfg.remove_edge(body_inedge) - for body_outedge in sdfg.out_edges(body_end): - sdfg.remove_edge(body_outedge) + # Remove loop + sdfg.remove_edge(self.loop_increment_edge()) + + init_edge = self.loop_init_edge() + init_edge.data.assignments = {} + sdfg.add_edge(init_edge.src, self.loop_begin, init_edge.data) + sdfg.remove_edge(init_edge) + + exit_edge = self.loop_exit_edge() + exit_edge.data.condition = CodeBlock("1") + sdfg.add_edge(body_end, exit_edge.dst, exit_edge.data) + sdfg.remove_edge(exit_edge) - for guard_inedge in sdfg.in_edges(guard): - guard_inedge.data.assignments = {} - sdfg.add_edge(guard_inedge.src, body, guard_inedge.data) - sdfg.remove_edge(guard_inedge) - for guard_outedge in sdfg.out_edges(guard): - guard_outedge.data.condition = CodeBlock("1") - sdfg.add_edge(body_end, guard_outedge.dst, guard_outedge.data) - sdfg.remove_edge(guard_outedge) - sdfg.remove_node(guard) + sdfg.remove_nodes_from(self.loop_meta_states()) if itervar in sdfg.symbols and helpers.is_symbol_unused(sdfg, itervar): sdfg.remove_symbol(itervar) diff --git a/dace/transformation/transformation.py b/dace/transformation/transformation.py index 25d61d1ce8..2b37c579a7 100644 --- a/dace/transformation/transformation.py +++ b/dace/transformation/transformation.py @@ -30,9 +30,11 @@ from typing import Any, Dict, Generic, List, Optional, Set, Type, TypeVar, Union, Callable import pydoc import warnings +from typing import TypeVar +PassT = TypeVar('PassT', bound=ppl.Pass) -def experimental_cfg_block_compatible(cls: ppl.Pass): +def experimental_cfg_block_compatible(cls: PassT) -> PassT: cls.__experimental_cfg_block_compatible__ = True return cls diff --git a/tests/transformations/loop_detection_test.py b/tests/transformations/loop_detection_test.py new file mode 100644 index 0000000000..5469f45762 --- /dev/null +++ b/tests/transformations/loop_detection_test.py @@ -0,0 +1,164 @@ +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. +import dace +import pytest + +from dace.transformation.interstate.loop_detection import DetectLoop +from dace.transformation import transformation as xf + + +class CountLoops(DetectLoop, xf.MultiStateTransformation): + + def can_be_applied(self, graph, expr_index, sdfg, permissive=False): + return super().can_be_applied(graph, expr_index, sdfg, permissive) + + +def test_pyloop(): + + @dace.program + def tester(a: dace.float64[20]): + for i in range(1, 20): + a[i] = a[i - 1] + 1 + + sdfg = tester.to_sdfg() + xform = CountLoops() + assert sdfg.apply_transformations(xform) == 1 + itvar, rng, _ = xform.loop_information() + assert itvar == 'i' + assert rng == (1, 19, 1) + + +def test_loop_rotated(): + sdfg = dace.SDFG('tester') + sdfg.add_symbol('N', dace.int32) + + entry = sdfg.add_state('entry', is_start_block=True) + body = sdfg.add_state('body') + latch = sdfg.add_state('latch') + exitstate = sdfg.add_state('exitstate') + + sdfg.add_edge(entry, body, dace.InterstateEdge(assignments=dict(i=0))) + sdfg.add_edge(body, latch, dace.InterstateEdge()) + sdfg.add_edge(latch, body, dace.InterstateEdge('i < N', assignments=dict(i='i + 2'))) + sdfg.add_edge(latch, exitstate, dace.InterstateEdge('i >= N')) + + xform = CountLoops() + assert sdfg.apply_transformations(xform) == 1 + itvar, rng, _ = xform.loop_information() + assert itvar == 'i' + assert rng == (0, dace.symbol('N') - 1, 2) + + +@pytest.mark.skip('Extra incrementation states should not be supported by loop detection') +def test_loop_rotated_extra_increment(): + sdfg = dace.SDFG('tester') + sdfg.add_symbol('N', dace.int32) + + entry = sdfg.add_state('entry', is_start_block=True) + body = sdfg.add_state('body') + latch = sdfg.add_state('latch') + increment = sdfg.add_state('increment') + exitstate = sdfg.add_state('exitstate') + + sdfg.add_edge(entry, body, dace.InterstateEdge(assignments=dict(i=0))) + sdfg.add_edge(latch, increment, dace.InterstateEdge('i < N')) + sdfg.add_edge(increment, body, dace.InterstateEdge(assignments=dict(i='i + 1'))) + sdfg.add_edge(latch, exitstate, dace.InterstateEdge('i >= N')) + + xform = CountLoops() + assert sdfg.apply_transformations(xform) == 1 + itvar, rng, _ = xform.loop_information() + assert itvar == 'i' + assert rng == (0, dace.symbol('N') - 1, 1) + + +def test_self_loop(): + # Tests a single-state loop + sdfg = dace.SDFG('tester') + sdfg.add_symbol('N', dace.int32) + + entry = sdfg.add_state('entry', is_start_block=True) + body = sdfg.add_state('body') + exitstate = sdfg.add_state('exitstate') + + sdfg.add_edge(entry, body, dace.InterstateEdge(assignments=dict(i=2))) + sdfg.add_edge(body, body, dace.InterstateEdge('i < N', assignments=dict(i='i + 3'))) + sdfg.add_edge(body, exitstate, dace.InterstateEdge('i >= N')) + + xform = CountLoops() + assert sdfg.apply_transformations(xform) == 1 + itvar, rng, _ = xform.loop_information() + assert itvar == 'i' + assert rng == (2, dace.symbol('N') - 1, 3) + + +def test_loop_llvm_canonical(): + sdfg = dace.SDFG('tester') + sdfg.add_symbol('N', dace.int32) + + entry = sdfg.add_state('entry', is_start_block=True) + guard = sdfg.add_state_after(entry, 'guard') + preheader = sdfg.add_state('preheader') + body = sdfg.add_state('body') + latch = sdfg.add_state('latch') + loopexit = sdfg.add_state('loopexit') + exitstate = sdfg.add_state('exitstate') + + sdfg.add_edge(guard, exitstate, dace.InterstateEdge('N <= 0')) + sdfg.add_edge(guard, preheader, dace.InterstateEdge('N > 0')) + sdfg.add_edge(preheader, body, dace.InterstateEdge(assignments=dict(i=0))) + sdfg.add_edge(body, latch, dace.InterstateEdge()) + sdfg.add_edge(latch, body, dace.InterstateEdge('i < N', assignments=dict(i='i + 1'))) + sdfg.add_edge(latch, loopexit, dace.InterstateEdge('i >= N')) + sdfg.add_edge(loopexit, exitstate, dace.InterstateEdge()) + + xform = CountLoops() + assert sdfg.apply_transformations(xform) == 1 + itvar, rng, _ = xform.loop_information() + assert itvar == 'i' + assert rng == (0, dace.symbol('N') - 1, 1) + + +@pytest.mark.skip('Extra incrementation states should not be supported by loop detection') +@pytest.mark.parametrize('with_bounds_check', (False, True)) +def test_loop_llvm_canonical_with_extras(with_bounds_check): + sdfg = dace.SDFG('tester') + sdfg.add_symbol('N', dace.int32) + + entry = sdfg.add_state('entry', is_start_block=True) + guard = sdfg.add_state_after(entry, 'guard') + preheader = sdfg.add_state('preheader') + body = sdfg.add_state('body') + latch = sdfg.add_state('latch') + increment1 = sdfg.add_state('increment1') + increment2 = sdfg.add_state('increment2') + loopexit = sdfg.add_state('loopexit') + exitstate = sdfg.add_state('exitstate') + + if with_bounds_check: + sdfg.add_edge(guard, exitstate, dace.InterstateEdge('N <= 0')) + sdfg.add_edge(guard, preheader, dace.InterstateEdge('N > 0')) + else: + sdfg.add_edge(guard, preheader, dace.InterstateEdge()) + sdfg.add_edge(preheader, body, dace.InterstateEdge(assignments=dict(i=0))) + sdfg.add_edge(body, latch, dace.InterstateEdge()) + sdfg.add_edge(latch, increment1, dace.InterstateEdge('i < N')) + sdfg.add_edge(increment1, increment2, dace.InterstateEdge(assignments=dict(i='i + 1'))) + sdfg.add_edge(increment2, body, dace.InterstateEdge()) + sdfg.add_edge(latch, loopexit, dace.InterstateEdge('i >= N')) + sdfg.add_edge(loopexit, exitstate, dace.InterstateEdge()) + + xform = CountLoops() + assert sdfg.apply_transformations(xform) == 1 + itvar, rng, _ = xform.loop_information() + assert itvar == 'i' + assert rng == (0, dace.symbol('N') - 1, 1) + + +if __name__ == '__main__': + test_pyloop() + test_loop_rotated() + # test_loop_rotated_extra_increment() + test_self_loop() + test_loop_llvm_canonical() + # test_loop_llvm_canonical_with_extras(False) + # test_loop_llvm_canonical_with_extras(True) diff --git a/tests/transformations/loop_to_map_test.py b/tests/transformations/loop_to_map_test.py index 8cd6947bb5..2cab97da78 100644 --- a/tests/transformations/loop_to_map_test.py +++ b/tests/transformations/loop_to_map_test.py @@ -10,7 +10,7 @@ import dace from dace.sdfg import nodes, propagation -from dace.transformation.interstate import LoopToMap +from dace.transformation.interstate import LoopToMap, StateFusion from dace.transformation.interstate.loop_detection import DetectLoop @@ -723,6 +723,71 @@ def internal_write(inp0: dace.int32[10], inp1: dace.int32[10], out: dace.int32[1 assert np.array_equal(val, ref) +@pytest.mark.parametrize('simplify', (False, True)) +def test_rotated_loop_to_map(simplify): + sdfg = dace.SDFG('tester') + sdfg.add_symbol('N', dace.int32) + N = dace.symbol('N') + sdfg.add_array('A', [N], dace.float64) + + entry = sdfg.add_state('entry', is_start_block=True) + guard = sdfg.add_state_after(entry, 'guard') + preheader = sdfg.add_state('preheader') + body = sdfg.add_state('body') + latch = sdfg.add_state('latch') + loopexit = sdfg.add_state('loopexit') + exitstate = sdfg.add_state('exitstate') + + sdfg.add_edge(guard, exitstate, dace.InterstateEdge('N <= 0')) + sdfg.add_edge(guard, preheader, dace.InterstateEdge('N > 0')) + sdfg.add_edge(preheader, body, dace.InterstateEdge(assignments=dict(i=0))) + sdfg.add_edge(body, latch, dace.InterstateEdge()) + sdfg.add_edge(latch, body, dace.InterstateEdge('i < N', assignments=dict(i='i + 1'))) + sdfg.add_edge(latch, loopexit, dace.InterstateEdge('i >= N')) + sdfg.add_edge(loopexit, exitstate, dace.InterstateEdge()) + + t = body.add_tasklet('addone', {'inp'}, {'out'}, 'out = inp + 1') + body.add_edge(body.add_read('A'), None, t, 'inp', dace.Memlet('A[i]')) + body.add_edge(t, 'out', body.add_write('A'), None, dace.Memlet('A[i]')) + + if simplify: + sdfg.apply_transformations_repeated(StateFusion) + + assert sdfg.apply_transformations_repeated(LoopToMap) == 1 + + a = np.random.rand(20) + ref = a + 1 + sdfg(A=a, N=20) + assert np.allclose(a, ref) + + +def test_self_loop_to_map(): + sdfg = dace.SDFG('tester') + sdfg.add_symbol('N', dace.int32) + N = dace.symbol('N') + sdfg.add_array('A', [N], dace.float64) + + entry = sdfg.add_state('entry', is_start_block=True) + body = sdfg.add_state('body') + exitstate = sdfg.add_state('exitstate') + + sdfg.add_edge(entry, body, dace.InterstateEdge(assignments=dict(i=2))) + sdfg.add_edge(body, body, dace.InterstateEdge('i < N', assignments=dict(i='i + 2'))) + sdfg.add_edge(body, exitstate, dace.InterstateEdge('i >= N')) + + t = body.add_tasklet('addone', {'inp'}, {'out'}, 'out = inp + 1') + body.add_edge(body.add_read('A'), None, t, 'inp', dace.Memlet('A[i]')) + body.add_edge(t, 'out', body.add_write('A'), None, dace.Memlet('A[i]')) + + assert sdfg.apply_transformations_repeated(LoopToMap) == 1 + + a = np.random.rand(20) + ref = np.copy(a) + ref[2::2] += 1 + sdfg(A=a, N=20) + assert np.allclose(a, ref) + + if __name__ == "__main__": parser = argparse.ArgumentParser() @@ -759,3 +824,6 @@ def internal_write(inp0: dace.int32[10], inp1: dace.int32[10], out: dace.int32[1 test_nested_loops() test_internal_write() test_specialize() + test_rotated_loop_to_map(False) + test_rotated_loop_to_map(True) + test_self_loop_to_map()