From 6fa021287628a4064236fdd7b4464942843d6303 Mon Sep 17 00:00:00 2001 From: BenWeber42 Date: Thu, 27 Jun 2024 12:31:33 +0200 Subject: [PATCH] Rename misleading topological_sort to bfs_nodes (#1590) This is currently work in progress to see how we can best fix this misleading naming. Fixes https://github.com/spcl/dace/issues/1560 Since https://github.com/spcl/dace/issues/1560 is still in flux, we have to make sure the PR stays in sync with what we are discussing in https://github.com/spcl/dace/issues/1560. Additionally, at the call-sites of the _previous topological_sort_ (before renaming), there are various comments to use a topoligical sort. After the renaming, they become misleading, so we should probably fix/improve those comments. --- dace/codegen/targets/cuda.py | 2 +- dace/codegen/targets/fpga.py | 2 +- dace/codegen/targets/framecode.py | 4 ++-- dace/sdfg/graph.py | 20 ++++++++++---------- dace/sdfg/state.py | 2 +- 5 files changed, 15 insertions(+), 15 deletions(-) diff --git a/dace/codegen/targets/cuda.py b/dace/codegen/targets/cuda.py index 61a44b8fb2..4731165309 100644 --- a/dace/codegen/targets/cuda.py +++ b/dace/codegen/targets/cuda.py @@ -2096,7 +2096,7 @@ def get_next_scope_entries(self, dfg, scope_entry): # Get all non-sequential scopes from the same level all_scopes = [ - node for node in parent_scope.topological_sort(scope_entry) + node for node in parent_scope.bfs_nodes(scope_entry) if isinstance(node, nodes.EntryNode) and node.map.schedule != dtypes.ScheduleType.Sequential ] diff --git a/dace/codegen/targets/fpga.py b/dace/codegen/targets/fpga.py index db47324268..fb85bdb464 100644 --- a/dace/codegen/targets/fpga.py +++ b/dace/codegen/targets/fpga.py @@ -1848,7 +1848,7 @@ def get_next_scope_entries(self, sdfg, dfg, scope_entry): parent_scope = dfg.scope_subgraph(parent_scope_entry) # Get all scopes from the same level - all_scopes = [node for node in parent_scope.topological_sort() if isinstance(node, dace.sdfg.nodes.EntryNode)] + all_scopes = [node for node in parent_scope.bfs_nodes() if isinstance(node, dace.sdfg.nodes.EntryNode)] return all_scopes[all_scopes.index(scope_entry) + 1:] diff --git a/dace/codegen/targets/framecode.py b/dace/codegen/targets/framecode.py index c1abf82b69..d1e540c39e 100644 --- a/dace/codegen/targets/framecode.py +++ b/dace/codegen/targets/framecode.py @@ -476,7 +476,7 @@ def dispatch_state(state: SDFGState) -> str: cft = cflow.structured_control_flow_tree(sdfg, dispatch_state) else: # If disabled, generate entire graph as general control flow block - states_topological = list(sdfg.topological_sort(sdfg.start_state)) + states_topological = list(sdfg.bfs_nodes(sdfg.start_state)) last = states_topological[-1] cft = cflow.GeneralBlock(dispatch_state, None, [cflow.SingleState(dispatch_state, s, s is last) for s in states_topological], [], @@ -553,7 +553,7 @@ def determine_allocation_lifetime(self, top_sdfg: SDFG): array_names = sdfg.arrays.keys( ) #set(k for k, v in sdfg.arrays.items() if v.lifetime == dtypes.AllocationLifetime.Scope) # Iterate topologically to get state-order - for state in sdfg.topological_sort(): + for state in sdfg.bfs_nodes(): for node in state.data_nodes(): if node.data not in array_names: continue diff --git a/dace/sdfg/graph.py b/dace/sdfg/graph.py index 91ed698896..567e5e84d2 100644 --- a/dace/sdfg/graph.py +++ b/dace/sdfg/graph.py @@ -6,7 +6,7 @@ import networkx as nx from dace.dtypes import deduplicate import dace.serialize -from typing import Any, Callable, Generic, Iterable, List, Sequence, TypeVar, Union +from typing import Any, Callable, Generic, Iterable, List, Optional, Sequence, TypeVar, Union class NodeNotFoundError(Exception): @@ -364,19 +364,19 @@ def sink_nodes(self) -> List[NodeT]: """Returns nodes with no outgoing edges.""" return [n for n in self.nodes() if self.out_degree(n) == 0] - def topological_sort(self, source: NodeT = None) -> Sequence[NodeT]: - """Returns nodes in topological order iff the graph contains exactly - one node with no incoming edges.""" + def bfs_nodes(self, source: Optional[NodeT] = None) -> Iterable[NodeT]: + """Returns an iterable over nodes traversed in breadth-first search + order starting from ``source``.""" if source is not None: sources = [source] else: sources = self.source_nodes() - if len(sources) == 0: - sources = [self.nodes()[0]] - #raise RuntimeError("No source nodes found") - if len(sources) > 1: - sources = [self.nodes()[0]] - #raise RuntimeError("Multiple source nodes found") + if len(sources) != 1: + source = next(iter(self.nodes()), None) + if source is None: + return [] # graph has no nodes + sources = [source] + seen = OrderedDict() # No OrderedSet in Python queue = deque(sources) while len(queue) > 0: diff --git a/dace/sdfg/state.py b/dace/sdfg/state.py index 736a4799df..45a7913f6a 100644 --- a/dace/sdfg/state.py +++ b/dace/sdfg/state.py @@ -2675,7 +2675,7 @@ def _used_symbols_internal(self, used_before_assignment = set() if used_before_assignment is None else used_before_assignment try: - ordered_blocks = self.topological_sort(self.start_block) + ordered_blocks = self.bfs_nodes(self.start_block) except ValueError: # Failsafe (e.g., for invalid or empty SDFGs) ordered_blocks = self.nodes()