Skip to content

Commit

Permalink
Fixed PruneConnectors (#1660)
Browse files Browse the repository at this point in the history
There was a bug in the `PruneConnectors` transformation, the apply
function did not prune the sets correctly.
I also made some additional changes.
  • Loading branch information
philip-paul-mueller authored Sep 20, 2024
1 parent c2bacca commit d0dcf1c
Show file tree
Hide file tree
Showing 2 changed files with 166 additions and 156 deletions.
167 changes: 29 additions & 138 deletions dace/transformation/dataflow/prune_connectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,67 +11,66 @@

@properties.make_properties
class PruneConnectors(pm.SingleStateTransformation):
""" Removes unused connectors from nested SDFGs, as well as their memlets
in the outer scope, replacing them with empty memlets if necessary.
"""
Removes unused connectors from nested SDFGs, as well as their memlets in the outer scope.
Optionally: after pruning, removes the unused containers from parent SDFG.
The transformation will not apply if this would remove all inputs and outputs.
"""

nsdfg = pm.PatternNode(nodes.NestedSDFG)

remove_unused_containers = properties.Property(dtype=bool,
default=False,
desc='If True, remove unused containers from parent SDFG.')

@classmethod
def expressions(cls):
return [utils.node_path_graph(cls.nsdfg)]

def can_be_applied(self, graph: SDFGState, expr_index: int, sdfg: SDFG, permissive: bool = False) -> bool:

prune_in, prune_out = self._get_prune_sets(graph)
if not prune_in and not prune_out:
return False

return True

def _get_prune_sets(self, state: SDFGState) -> Tuple[Set[str], Set[str]]:
"""Computes the set of the input and output connectors that can be removed.
Returns:
A tuple of two sets, the first set contains the name of all input
connectors that can be removed and the second the name of all output
connectors that can be removed.
"""
nsdfg = self.nsdfg

# From the input connectors (i.e. data container on the inside) remove
# all those that are not used for reading and from the output containers
# remove those that are not used fro reading.
# NOTE: If a data container is used for reading and writing then only the
# output connector is retained, except the output is a WCR, then the input
# is also retained.
read_set, write_set = nsdfg.sdfg.read_and_write_sets()
prune_in = nsdfg.in_connectors.keys() - read_set
prune_out = nsdfg.out_connectors.keys() - write_set

# Take into account symbol mappings
strs = tuple(nsdfg.symbol_mapping.values())
syms = tuple(symbolic.pystr_to_symbolic(s) for s in strs)
symnames = tuple(s.name if hasattr(s, 'name') else '' for s in syms)
for conn in list(prune_in):
if conn in syms or conn in symnames or conn in nsdfg.sdfg.symbols:
prune_in.remove(conn)

# Add WCR outputs to "do not prune" input list
for e in graph.out_edges(nsdfg):
for e in state.out_edges(nsdfg):
if e.data.wcr is not None and e.src_conn in prune_in:
prune_in.remove(e.src_conn)

if not prune_in and not prune_out:
return False

return True
return prune_in, prune_out

def apply(self, state: SDFGState, sdfg: SDFG):
nsdfg = self.nsdfg

# Determine which connectors can be removed.
prune_in, prune_out = self._get_prune_sets(state)

# Fission subgraph around nsdfg into its own state to avoid data races
nsdfg_state = helpers.state_fission_after(state, nsdfg)

read_set, write_set = nsdfg.sdfg.read_and_write_sets()
prune_in = nsdfg.in_connectors.keys() - read_set
prune_out = nsdfg.out_connectors.keys() - write_set

# Detect which nodes are used, so we can delete unused nodes after the
# connectors have been pruned
read_set, write_set = nsdfg.sdfg.read_and_write_sets()
all_data_used = read_set | write_set

# Add WCR outputs to "do not prune" input list
for e in nsdfg_state.out_edges(nsdfg):
if e.data.wcr is not None and e.src_conn in prune_in:
prune_in.remove(e.src_conn)

for conn in prune_in:
for e in nsdfg_state.in_edges_by_connector(nsdfg, conn):
nsdfg_state.remove_memlet_path(e, remove_orphans=True)
Expand All @@ -89,18 +88,6 @@ def apply(self, state: SDFGState, sdfg: SDFG):
# If the data is now unused, we can purge it from the SDFG
nsdfg.sdfg.remove_data(conn)

if self.remove_unused_containers:
# Remove unused containers from parent SDFGs
containers = list(sdfg.arrays.keys())
for name in containers:
s = nsdfg.sdfg
while s.parent_sdfg:
s = s.parent_sdfg
try:
s.remove_data(name)
except ValueError:
break


class PruneSymbols(pm.SingleStateTransformation):
"""
Expand Down Expand Up @@ -177,99 +164,3 @@ def apply(self, graph: SDFGState, sdfg: SDFG):
# If not used in SDFG, remove from symbols as well
if helpers.is_symbol_unused(nsdfg.sdfg, candidate):
nsdfg.sdfg.remove_symbol(candidate)


class PruneUnusedOutputs(pm.SingleStateTransformation):
"""
Removes unused symbol mappings from nested SDFGs, as well as internal
symbols if necessary.
"""

nsdfg = pm.PatternNode(nodes.NestedSDFG)

@classmethod
def expressions(cls):
return [utils.node_path_graph(cls.nsdfg)]

@classmethod
def _candidates(cls, nsdfg: nodes.NestedSDFG) -> Tuple[Set[str], Set[Tuple[SDFGState, nodes.AccessNode]]]:
# Start with all non-transient arrays
candidates = set(conn for conn in nsdfg.out_connectors.keys())
candidate_nodes: Set[Tuple[SDFGState, nodes.AccessNode]] = set()

# Remove candidates that are used more than once in the outer SDFG
state = nsdfg.sdfg.parent
sdfg = nsdfg.sdfg.parent_sdfg
for e in state.out_edges(nsdfg):
if e.data.is_empty():
continue
outer_desc = sdfg.arrays[e.data.data]
if isinstance(outer_desc, dt.View):
candidates.remove(e.src_conn)
continue
if not outer_desc.transient:
candidates.remove(e.src_conn)
continue
if not isinstance(state.memlet_path(e)[-1].dst, nodes.AccessNode):
candidates.remove(e.src_conn)
continue

all_access_nodes = [(s, n) for s in sdfg.nodes() for n in s.data_nodes() if n.data == e.data.data]
if len(all_access_nodes) > 1:
candidates.remove(e.src_conn)
continue
if all_access_nodes[0][0].out_degree(all_access_nodes[0][1]) > 0:
candidates.remove(e.src_conn)
continue

if not candidates:
return set(), set()

# Remove candidates that are used in the nested SDFG
for nstate in nsdfg.sdfg.states():
for node in nstate.data_nodes():
if node.data in candidates:
# If used in nested SDFG
if nstate.out_degree(node) > 0:
candidates.remove(node.data)
continue
# If a result of a code node
if any(not isinstance(nstate.memlet_path(e)[0].src, nodes.AccessNode)
for e in nstate.in_edges(node)):
candidates.remove(node.data)
continue

# Add node for later use
candidate_nodes.add((nstate, node))

# Any array that is used in interstate edges is removed
for e in nsdfg.sdfg.all_interstate_edges():
candidates -= (set(map(str, symbolic.symbols_in_ast(e.data.condition.code[0]))))
for assign in e.data.assignments.values():
candidates -= (symbolic.free_symbols_and_functions(assign))

candidate_nodes = {n for n in candidate_nodes if n[1].data in candidates}

return candidates, candidate_nodes

def can_be_applied(self, graph: SDFGState, expr_index: int, sdfg: SDFG, permissive: bool = False) -> bool:
nsdfg: nodes.NestedSDFG = self.nsdfg
candidates, _ = self._candidates(nsdfg)
if len(candidates) > 0:
return True

return False

def apply(self, state: SDFGState, sdfg: SDFG):
nsdfg = self.nsdfg

candidates, candidate_nodes = self._candidates(nsdfg)
for outer_edge in state.out_edges(nsdfg):
if outer_edge.src_conn in candidates:
state.remove_memlet_path(outer_edge)
sdfg.remove_data(outer_edge.data.data, validate=False)
for nstate, node in candidate_nodes:
for ie in nstate.in_edges(node):
nstate.remove_memlet_path(ie)
for cand in candidates:
nsdfg.sdfg.remove_data(cand, validate=False)
Loading

0 comments on commit d0dcf1c

Please sign in to comment.