Skip to content

Commit

Permalink
More robust loop detection (#1646)
Browse files Browse the repository at this point in the history
Generalizes the behavior of loop detection to support rotated loops,
single-state loops, and, as a result, LLVM canonical loops. The PR also
refactors the loop analysis methods and generalizes the `DetectLoop`
transformation subclasses, such as `LoopToMap` and `LoopPeeling`.
  • Loading branch information
tbennun authored Sep 7, 2024
1 parent e1daf32 commit 0a2c55a
Show file tree
Hide file tree
Showing 9 changed files with 806 additions and 141 deletions.
465 changes: 445 additions & 20 deletions dace/transformation/interstate/loop_detection.py

Large diffs are not rendered by default.

15 changes: 7 additions & 8 deletions dace/transformation/interstate/loop_peeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,17 +79,16 @@ def _modify_cond(self, condition, var, step):
def apply(self, graph: ControlFlowRegion, sdfg: sd.SDFG):
####################################################################
# Obtain loop information
guard: sd.SDFGState = self.loop_guard
begin: sd.SDFGState = self.loop_begin
after_state: sd.SDFGState = self.exit_state

# Obtain iteration variable, range, and stride
condition_edge = graph.edges_between(guard, begin)[0]
not_condition_edge = graph.edges_between(guard, after_state)[0]
itervar, rng, loop_struct = find_for_loop(graph, guard, begin)
condition_edge = self.loop_condition_edge()
not_condition_edge = self.loop_exit_edge()
itervar, rng, loop_struct = self.loop_information()

# Get loop states
loop_states = list(sdutil.dfs_conditional(graph, sources=[begin], condition=lambda _, child: child != guard))
loop_states = self.loop_body()
first_id = loop_states.index(begin)
last_state = loop_struct[1]
last_id = loop_states.index(last_state)
Expand All @@ -104,7 +103,7 @@ def apply(self, graph: ControlFlowRegion, sdfg: sd.SDFG):
init_edges = []
before_states = loop_struct[0]
for before_state in before_states:
init_edge = graph.edges_between(before_state, guard)[0]
init_edge = self.loop_init_edge()
init_edge.data.assignments[itervar] = str(rng[0] + self.count * rng[2])
init_edges.append(init_edge)
append_states = before_states
Expand Down Expand Up @@ -133,7 +132,7 @@ def apply(self, graph: ControlFlowRegion, sdfg: sd.SDFG):
if append_state not in before_states:
for init_edge in init_edges:
graph.remove_edge(init_edge)
graph.add_edge(append_state, guard, init_edges[0].data)
graph.add_edge(append_state, init_edge.dst, init_edges[0].data)
else:
# If begin, change initialization assignment and prepend states before
# guard
Expand Down Expand Up @@ -164,4 +163,4 @@ def apply(self, graph: ControlFlowRegion, sdfg: sd.SDFG):
# Reconnect edge to guard state from last peeled iteration
if prepend_state != after_state:
graph.remove_edge(not_condition_edge)
graph.add_edge(guard, prepend_state, not_condition_edge.data)
graph.add_edge(not_condition_edge.src, prepend_state, not_condition_edge.data)
114 changes: 69 additions & 45 deletions dace/transformation/interstate/loop_to_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,15 +95,16 @@ def can_be_applied(self, graph: SDFGState, expr_index: int, sdfg: SDFG, permissi
if not super().can_be_applied(graph, expr_index, sdfg, permissive):
return False

guard = self.loop_guard
begin = self.loop_begin

# Guard state should not contain any dataflow
if len(guard.nodes()) != 0:
return False
if expr_index <= 1:
guard = self.loop_guard
if len(guard.nodes()) != 0:
return False

# If loop cannot be detected, fail
found = find_for_loop(graph, guard, begin, itervar=self.itervar)
found = self.loop_information(itervar=self.itervar)
if not found:
return False

Expand All @@ -123,7 +124,7 @@ def can_be_applied(self, graph: SDFGState, expr_index: int, sdfg: SDFG, permissi
return False

# Find all loop-body states
states: List[SDFGState] = list(sdutil.dfs_conditional(sdfg, [begin], lambda _, c: c is not guard))
states: List[SDFGState] = self.loop_body()

assert (body_end in states)

Expand Down Expand Up @@ -349,22 +350,15 @@ def apply(self, _, sdfg: sd.SDFG):
from dace.sdfg.propagation import align_memlet

# Obtain loop information
guard: sd.SDFGState = self.loop_guard
itervar, (start, end, step), (_, body_end) = self.loop_information(itervar=self.itervar)
states = self.loop_body()
body: sd.SDFGState = self.loop_begin
after: sd.SDFGState = self.exit_state

# Obtain iteration variable, range, and stride
itervar, (start, end, step), (_, body_end) = find_for_loop(sdfg, guard, body, itervar=self.itervar)

# Find all loop-body states
states = set()
to_visit = [body]
while to_visit:
state = to_visit.pop(0)
for _, dst, _ in sdfg.out_edges(state):
if dst not in states and dst is not guard:
to_visit.append(dst)
states.add(state)
exit_state = self.exit_state
entry_edge = self.loop_condition_edge()
init_edge = self.loop_init_edge()
after_edge = self.loop_exit_edge()
condition_edge = self.loop_condition_edge()
increment_edge = self.loop_increment_edge()

nsdfg = None

Expand Down Expand Up @@ -425,7 +419,7 @@ def apply(self, _, sdfg: sd.SDFG):
nsdfg = SDFG("loop_body", constants=sdfg.constants_prop, parent=new_body)
nsdfg.add_node(body, is_start_state=True)
body.parent = nsdfg
exit_state = nsdfg.add_state('exit')
nexit_state = nsdfg.add_state('exit')
nsymbols = dict()
for state in states:
if state is body:
Expand All @@ -438,20 +432,48 @@ def apply(self, _, sdfg: sd.SDFG):
for src, dst, data in sdfg.in_edges(state):
nsymbols.update({s: sdfg.symbols[s] for s in data.assignments.keys() if s in sdfg.symbols})
nsdfg.add_edge(src, dst, data)
nsdfg.add_edge(body_end, exit_state, InterstateEdge())
nsdfg.add_edge(body_end, nexit_state, InterstateEdge())

# Move guard -> body edge to guard -> new_body
for src, dst, data, in sdfg.edges_between(guard, body):
sdfg.add_edge(src, new_body, data)
# Move body_end -> guard edge to new_body -> guard
for src, dst, data in sdfg.edges_between(body_end, guard):
sdfg.add_edge(new_body, dst, data)
increment_edge = None

# Delete loop-body states and edges from parent SDFG
for state in states:
for e in sdfg.all_edges(state):
# Specific instructions for loop type
if self.expr_index <= 1: # Natural loop with guard
guard = self.loop_guard

# Move guard -> body edge to guard -> new_body
for e in sdfg.edges_between(guard, body):
sdfg.remove_edge(e)
condition_edge = sdfg.add_edge(e.src, new_body, e.data)
# Move body_end -> guard edge to new_body -> guard
for e in sdfg.edges_between(body_end, guard):
sdfg.remove_edge(e)
sdfg.remove_node(state)
increment_edge = sdfg.add_edge(new_body, e.dst, e.data)


elif 1 < self.expr_index <= 3: # Rotated loop
entrystate = self.entry_state
latch = self.loop_latch

# Move entry edge to entry -> new_body
for src, dst, data, in sdfg.edges_between(entrystate, body):
init_edge = sdfg.add_edge(src, new_body, data)

# Move latch -> exit_state edge to new_body -> exit_state
for src, dst, data in sdfg.edges_between(latch, exit_state):
after_edge = sdfg.add_edge(new_body, dst, data)

elif self.expr_index == 4: # Self-loop
entrystate = self.entry_state

# Move entry edge to entry -> new_body
for src, dst, data in sdfg.edges_between(entrystate, body):
init_edge = sdfg.add_edge(src, new_body, data)
for src, dst, data in sdfg.edges_between(body, exit_state):
after_edge = sdfg.add_edge(new_body, dst, data)


# Delete loop-body states and edges from parent SDFG
sdfg.remove_nodes_from(states)

# Add NestedSDFG arrays
for name in read_set | write_set:
Expand Down Expand Up @@ -490,12 +512,13 @@ def apply(self, _, sdfg: sd.SDFG):
# correct map with a positive increment
start, end, step = end, start, -step

reentry_assignments = {k: v for k, v in condition_edge.data.assignments.items() if k != itervar}

# If necessary, make a nested SDFG with assignments
isedge = sdfg.edges_between(guard, body)[0]
symbols_to_remove = set()
if len(isedge.data.assignments) > 0:
if len(reentry_assignments) > 0:
nsdfg = helpers.nest_state_subgraph(sdfg, body, gr.SubgraphView(body, body.nodes()))
for sym in isedge.data.free_symbols:
for sym in entry_edge.data.free_symbols:
if sym in nsdfg.symbol_mapping or sym in nsdfg.in_connectors:
continue
if sym in sdfg.symbols:
Expand All @@ -522,12 +545,12 @@ def apply(self, _, sdfg: sd.SDFG):
nstate = nsdfg.sdfg.node(0)
init_state = nsdfg.sdfg.add_state_before(nstate)
nisedge = nsdfg.sdfg.edges_between(init_state, nstate)[0]
nisedge.data.assignments = isedge.data.assignments
nisedge.data.assignments = reentry_assignments
symbols_to_remove = set(nisedge.data.assignments.keys())
for k in nisedge.data.assignments.keys():
if k in nsdfg.symbol_mapping:
del nsdfg.symbol_mapping[k]
isedge.data.assignments = {}
condition_edge.data.assignments = {}

source_nodes = body.source_nodes()
sink_nodes = body.sink_nodes()
Expand All @@ -541,8 +564,8 @@ def apply(self, _, sdfg: sd.SDFG):
continue
# Arrays written with subsets that do not depend on the loop variable must be thread-local
map_dependency = False
for e in state.in_edges(node):
subset = e.data.get_dst_subset(e, state)
for e in body.in_edges(node):
subset = e.data.get_dst_subset(e, body)
if any(str(s) == itervar for s in subset.free_symbols):
map_dependency = True
break
Expand Down Expand Up @@ -644,25 +667,26 @@ def apply(self, _, sdfg: sd.SDFG):
if not source_nodes and not sink_nodes:
body.add_nedge(entry, exit, memlet.Memlet())

# Get rid of the loop exit condition edge
after_edge = sdfg.edges_between(guard, after)[0]
# Get rid of the loop exit condition edge (it will be readded below)
sdfg.remove_edge(after_edge)

# Remove the assignment on the edge to the guard
for e in sdfg.in_edges(guard):
for e in [init_edge, increment_edge]:
if e is None:
continue
if itervar in e.data.assignments:
del e.data.assignments[itervar]

# Remove the condition on the entry edge
condition_edge = sdfg.edges_between(guard, body)[0]
condition_edge.data.condition = CodeBlock("1")

# Get rid of backedge to guard
sdfg.remove_edge(sdfg.edges_between(body, guard)[0])
if increment_edge is not None:
sdfg.remove_edge(increment_edge)

# Route body directly to after state, maintaining any other assignments
# it might have had
sdfg.add_edge(body, after, sd.InterstateEdge(assignments=after_edge.data.assignments))
sdfg.add_edge(body, exit_state, sd.InterstateEdge(assignments=after_edge.data.assignments))

# If this had made the iteration variable a free symbol, we can remove
# it from the SDFG symbols
Expand Down
14 changes: 6 additions & 8 deletions dace/transformation/interstate/loop_unroll.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,7 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive=False):
if not super().can_be_applied(graph, expr_index, sdfg, permissive):
return False

guard = self.loop_guard
begin = self.loop_begin
found = find_for_loop(graph, guard, begin)
found = self.loop_information()

# If loop cannot be detected, fail
if not found:
Expand All @@ -49,20 +47,19 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive=False):

def apply(self, graph: ControlFlowRegion, sdfg):
# Obtain loop information
guard: sd.SDFGState = self.loop_guard
begin: sd.SDFGState = self.loop_begin
after_state: sd.SDFGState = self.exit_state

# Obtain iteration variable, range, and stride, together with the last
# state(s) before the loop and the last loop state.
itervar, rng, loop_struct = find_for_loop(graph, guard, begin)
itervar, rng, loop_struct = self.loop_information()

# Loop must be fully unrollable for now.
if self.count != 0:
raise NotImplementedError # TODO(later)

# Get loop states
loop_states = list(sdutil.dfs_conditional(graph, sources=[begin], condition=lambda _, child: child != guard))
loop_states = self.loop_body()
first_id = loop_states.index(begin)
last_state = loop_struct[1]
last_id = loop_states.index(last_state)
Expand Down Expand Up @@ -91,7 +88,7 @@ def apply(self, graph: ControlFlowRegion, sdfg):
unrolled_states.append((new_states[first_id], new_states[last_id]))

# Get any assignments that might be on the edge to the after state
after_assignments = (graph.edges_between(guard, after_state)[0].data.assignments)
after_assignments = self.loop_exit_edge().data.assignments

# Connect new states to before and after states without conditions
if unrolled_states:
Expand All @@ -101,7 +98,8 @@ def apply(self, graph: ControlFlowRegion, sdfg):
graph.add_edge(unrolled_states[-1][1], after_state, sd.InterstateEdge(assignments=after_assignments))

# Remove old states from SDFG
graph.remove_nodes_from([guard] + loop_states)
guard_or_latch = self.loop_meta_states()
graph.remove_nodes_from(guard_or_latch + loop_states)

def instantiate_loop(
self,
Expand Down
52 changes: 27 additions & 25 deletions dace/transformation/interstate/move_loop_into_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,10 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive=False):
return False

# Obtain loop information
guard: sd.SDFGState = self.loop_guard
body: sd.SDFGState = self.loop_begin
after: sd.SDFGState = self.exit_state

# Obtain iteration variable, range, and stride
loop_info = find_for_loop(sdfg, guard, body)
loop_info = self.loop_information()
if not loop_info:
return False
itervar, (start, end, step), (_, body_end) = loop_info
Expand Down Expand Up @@ -157,11 +155,10 @@ def test_subset_dependency(subset: sbs.Subset, mparams: Set[int]) -> Tuple[bool,

def apply(self, _, sdfg: sd.SDFG):
# Obtain loop information
guard: sd.SDFGState = self.loop_guard
body: sd.SDFGState = self.loop_begin

# Obtain iteration variable, range, and stride
itervar, (start, end, step), _ = find_for_loop(sdfg, guard, body)
itervar, (start, end, step), _ = self.loop_information()

forward_loop = step > 0

Expand Down Expand Up @@ -194,26 +191,31 @@ def apply(self, _, sdfg: sd.SDFG):
else:
guard_body_edge = e

for body_inedge in sdfg.in_edges(body):
if body_inedge.src is guard:
guard_body_edge.data.assignments.update(body_inedge.data.assignments)
sdfg.remove_edge(body_inedge)
for body_outedge in sdfg.out_edges(body):
sdfg.remove_edge(body_outedge)
for guard_inedge in sdfg.in_edges(guard):
before_guard_edge.data.assignments.update(guard_inedge.data.assignments)
guard_inedge.data.assignments = {}
sdfg.add_edge(guard_inedge.src, body, guard_inedge.data)
sdfg.remove_edge(guard_inedge)
for guard_outedge in sdfg.out_edges(guard):
if guard_outedge.dst is body:
guard_body_edge.data.assignments.update(guard_outedge.data.assignments)
else:
guard_after_edge.data.assignments.update(guard_outedge.data.assignments)
guard_outedge.data.condition = CodeBlock("1")
sdfg.add_edge(body, guard_outedge.dst, guard_outedge.data)
sdfg.remove_edge(guard_outedge)
sdfg.remove_node(guard)
if self.expr_index <= 1:
guard = self.loop_guard
for body_inedge in sdfg.in_edges(body):
if body_inedge.src is guard:
guard_body_edge.data.assignments.update(body_inedge.data.assignments)
sdfg.remove_edge(body_inedge)
for body_outedge in sdfg.out_edges(body):
sdfg.remove_edge(body_outedge)
for guard_inedge in sdfg.in_edges(guard):
before_guard_edge.data.assignments.update(guard_inedge.data.assignments)
guard_inedge.data.assignments = {}
sdfg.add_edge(guard_inedge.src, body, guard_inedge.data)
sdfg.remove_edge(guard_inedge)
for guard_outedge in sdfg.out_edges(guard):
if guard_outedge.dst is body:
guard_body_edge.data.assignments.update(guard_outedge.data.assignments)
else:
guard_after_edge.data.assignments.update(guard_outedge.data.assignments)
guard_outedge.data.condition = CodeBlock("1")
sdfg.add_edge(body, guard_outedge.dst, guard_outedge.data)
sdfg.remove_edge(guard_outedge)
sdfg.remove_node(guard)
else: # Rotated or self loops
raise NotImplementedError('MoveLoopIntoMap not implemented for rotated and self-loops')

if itervar in nsdfg.symbol_mapping:
del nsdfg.symbol_mapping[itervar]
if itervar in sdfg.symbols:
Expand Down
Loading

0 comments on commit 0a2c55a

Please sign in to comment.