Skip to content

Commit

Permalink
Use State Fissioning to Generalize Transformations (#1462)
Browse files Browse the repository at this point in the history
  • Loading branch information
lukastruemper authored Mar 1, 2024
1 parent b1a7f8a commit ff6e064
Show file tree
Hide file tree
Showing 7 changed files with 274 additions and 128 deletions.
39 changes: 25 additions & 14 deletions dace/sdfg/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1295,32 +1295,43 @@ def inline_sdfgs(sdfg: SDFG, permissive: bool = False, progress: bool = None, mu
from dace.transformation.interstate import InlineSDFG, InlineMultistateSDFG

counter = 0
nsdfgs = [(n, p) for n, p in sdfg.all_nodes_recursive() if isinstance(n, NestedSDFG)]

for node, state in optional_progressbar(reversed(nsdfgs), title='Inlining SDFGs', n=len(nsdfgs), progress=progress):
id = node.sdfg.cfg_id
sd = state.parent
nsdfgs = [n for n, _ in sdfg.all_nodes_recursive() if isinstance(n, NestedSDFG)]

for nsdfg_node in optional_progressbar(reversed(nsdfgs), title='Inlining SDFGs', n=len(nsdfgs), progress=progress):
# We have to reevaluate every time due to changing IDs
state_id = sd.node_id(state)
# e.g., InlineMultistateSDFG may fission states
parent_state = nsdfg_node.sdfg.parent
parent_sdfg = parent_state.parent
parent_state_id = parent_sdfg.node_id(parent_state)

if multistate:
candidate = {
InlineMultistateSDFG.nested_sdfg: node,
InlineMultistateSDFG.nested_sdfg: nsdfg_node,
}
inliner = InlineMultistateSDFG()
inliner.setup_match(sd, id, state_id, candidate, 0, override=True)
if inliner.can_be_applied(state, 0, sd, permissive=permissive):
inliner.apply(state, sd)
inliner.setup_match(sdfg=parent_sdfg,
cfg_id=parent_sdfg.sdfg_id,
state_id=parent_state_id,
subgraph=candidate,
expr_index=0,
override=True)
if inliner.can_be_applied(parent_state, 0, parent_sdfg, permissive=permissive):
inliner.apply(parent_state, parent_sdfg)
counter += 1
continue

candidate = {
InlineSDFG.nested_sdfg: node,
InlineSDFG.nested_sdfg: nsdfg_node,
}
inliner = InlineSDFG()
inliner.setup_match(sd, id, state_id, candidate, 0, override=True)
if inliner.can_be_applied(state, 0, sd, permissive=permissive):
inliner.apply(state, sd)
inliner.setup_match(sdfg=parent_sdfg,
cfg_id=parent_sdfg.sdfg_id,
state_id=parent_state_id,
subgraph=candidate,
expr_index=0,
override=True)
if inliner.can_be_applied(parent_state, 0, parent_sdfg, permissive=permissive):
inliner.apply(parent_state, parent_sdfg)
counter += 1

return counter
Expand Down
35 changes: 1 addition & 34 deletions dace/transformation/dataflow/prune_connectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,40 +57,7 @@ def apply(self, state: SDFGState, sdfg: SDFG):
nsdfg = self.nsdfg

# Fission subgraph around nsdfg into its own state to avoid data races
predecessors = set()
for inedge in state.in_edges(nsdfg):
if inedge.data is None:
continue

pred = state.memlet_path(inedge)[0].src
if state.in_degree(pred) == 0:
continue

predecessors.add(pred)
for e in state.bfs_edges(pred, reverse=True):
predecessors.add(e.src)

subgraph = StateSubgraphView(state, predecessors)
pred_state = helpers.state_fission(sdfg, subgraph)

subgraph_nodes = set()
subgraph_nodes.add(nsdfg)
for inedge in state.in_edges(nsdfg):
if inedge.data is None:
continue
path = state.memlet_path(inedge)
for edge in path:
subgraph_nodes.add(edge.src)

for oedge in state.out_edges(nsdfg):
if oedge.data is None:
continue
path = state.memlet_path(oedge)
for edge in path:
subgraph_nodes.add(edge.dst)

subgraph = StateSubgraphView(state, subgraph_nodes)
nsdfg_state = helpers.state_fission(sdfg, subgraph)
nsdfg_state = helpers.state_fission_after(sdfg, state, nsdfg)

read_set, write_set = nsdfg.sdfg.read_and_write_sets()
prune_in = nsdfg.in_connectors.keys() - read_set
Expand Down
23 changes: 10 additions & 13 deletions dace/transformation/dataflow/wcr_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,7 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive=False):

# If in map, only match if the subset is independent of any
# map indices (otherwise no conflict)
if not permissive and len(outedge.data.subset.free_symbols & set(me.map.params)) == len(
me.map.params):
if not permissive and len(outedge.data.subset.free_symbols & set(me.map.params)) == len(me.map.params):
return False

# Get relevant output connector
Expand Down Expand Up @@ -151,18 +150,16 @@ def apply(self, state: SDFGState, sdfg: SDFG):

# If state fission is necessary to keep semantics, do it first
if state.in_degree(input) > 0:
subgraph_nodes = set([e.src for e in state.bfs_edges(input, reverse=True)])
subgraph_nodes.add(input)

subgraph = StateSubgraphView(state, subgraph_nodes)
helpers.state_fission(sdfg, subgraph)
new_state = helpers.state_fission_after(sdfg, state, tasklet)
else:
new_state = state

if self.expr_index == 0:
inedges = state.edges_between(input, tasklet)
outedge = state.edges_between(tasklet, output)[0]
inedges = new_state.edges_between(input, tasklet)
outedge = new_state.edges_between(tasklet, output)[0]
else:
inedges = state.edges_between(me, tasklet)
outedge = state.edges_between(tasklet, mx)[0]
inedges = new_state.edges_between(me, tasklet)
outedge = new_state.edges_between(tasklet, mx)[0]

# Get relevant output connector
outconn = outedge.src_conn
Expand Down Expand Up @@ -253,8 +250,8 @@ def apply(self, state: SDFGState, sdfg: SDFG):
outedge.data.wcr = f'lambda a,b: a {op} b'

# Remove input node and connector
state.remove_memlet_path(inedge)
propagate_memlets_state(sdfg, state)
new_state.remove_memlet_path(inedge)
propagate_memlets_state(sdfg, new_state)

# If outedge leads to non-transient, and this is a nested SDFG,
# propagate outwards
Expand Down
79 changes: 79 additions & 0 deletions dace/transformation/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -687,6 +687,85 @@ def state_fission(sdfg: SDFG, subgraph: graph.SubgraphView, label: Optional[str]
return newstate


def state_fission_after(sdfg: SDFG, state: SDFGState, node: nodes.Node, label: Optional[str] = None) -> SDFGState:
"""
"""
newstate = sdfg.add_state_after(state, label=label)

# Bookkeeping
nodes_to_move = set([node])
boundary_nodes = set()
orig_edges = set()

# Collect predecessors
if not isinstance(node, nodes.AccessNode):
for edge in state.in_edges(node):
for e in state.memlet_path(edge):
nodes_to_move.add(e.src)
orig_edges.add(e)

# Collect nodes_to_move
for edge in state.bfs_edges(node):
nodes_to_move.add(edge.dst)
orig_edges.add(edge)

if not isinstance(edge.dst, nodes.AccessNode):
for iedge in state.in_edges(edge.dst):
if iedge == edge:
continue

for e in state.memlet_path(iedge):
nodes_to_move.add(e.src)
orig_edges.add(e)

# Define boundary nodes
for node in set(nodes_to_move):
if isinstance(node, nodes.AccessNode):
for iedge in state.in_edges(node):
if iedge.src not in nodes_to_move:
boundary_nodes.add(node)
break

if node in boundary_nodes:
continue

for oedge in state.out_edges(node):
if oedge.dst not in nodes_to_move:
boundary_nodes.add(node)
break

# Duplicate boundary nodes
new_nodes = {}
for node in boundary_nodes:
node_ = copy.deepcopy(node)
state.add_node(node_)
new_nodes[node] = node_

for edge in state.edges():
if edge.src in boundary_nodes and edge.dst in boundary_nodes:
state.add_edge(new_nodes[edge.src], edge.src_conn, new_nodes[edge.dst], edge.dst_conn,
copy.deepcopy(edge.data))
elif edge.src in boundary_nodes:
state.add_edge(new_nodes[edge.src], edge.src_conn, edge.dst, edge.dst_conn, copy.deepcopy(edge.data))
elif edge.dst in boundary_nodes:
state.add_edge(edge.src, edge.src_conn, new_nodes[edge.dst], edge.dst_conn, copy.deepcopy(edge.data))

# Move nodes
state.remove_nodes_from(nodes_to_move)

for n in nodes_to_move:
if isinstance(n, nodes.NestedSDFG):
# Set the new parent state
n.sdfg.parent = newstate

newstate.add_nodes_from(nodes_to_move)

for e in orig_edges:
newstate.add_edge(e.src, e.src_conn, e.dst, e.dst_conn, e.data)

return newstate


def _get_internal_subset(internal_memlet: Memlet,
external_memlet: Memlet,
use_src_subset: bool = False,
Expand Down
Loading

0 comments on commit ff6e064

Please sign in to comment.