diff --git a/dace/codegen/control_flow.py b/dace/codegen/control_flow.py index 7701a19ec2..f5559984e7 100644 --- a/dace/codegen/control_flow.py +++ b/dace/codegen/control_flow.py @@ -275,9 +275,13 @@ def as_cpp(self, codegen, symbols) -> str: expr += elem.as_cpp(codegen, symbols) # In a general block, emit transitions and assignments after each individual block or region. if isinstance(elem, BasicCFBlock) or (isinstance(elem, RegionBlock) and elem.region): - cfg = elem.state.parent_graph if isinstance(elem, BasicCFBlock) else elem.region.parent_graph + if isinstance(elem, BasicCFBlock): + g_elem = elem.state + else: + g_elem = elem.region + cfg = g_elem.parent_graph sdfg = cfg if isinstance(cfg, SDFG) else cfg.sdfg - out_edges = cfg.out_edges(elem.state) if isinstance(elem, BasicCFBlock) else cfg.out_edges(elem.region) + out_edges = cfg.out_edges(g_elem) for j, e in enumerate(out_edges): if e not in self.gotos_to_ignore: # Skip gotos to immediate successors @@ -532,26 +536,27 @@ def as_cpp(self, codegen, symbols) -> str: expr = '' if self.loop.update_statement and self.loop.init_statement and self.loop.loop_variable: - # Initialize to either "int i = 0" or "i = 0" depending on whether the type has been defined. - defined_vars = codegen.dispatcher.defined_vars - if not defined_vars.has(self.loop.loop_variable): - try: - init = f'{symbols[self.loop.loop_variable]} ' - except KeyError: - init = 'auto ' - symbols[self.loop.loop_variable] = None - init += unparse_interstate_edge(self.loop.init_statement.code[0], sdfg, codegen=codegen, symbols=symbols) + init = unparse_interstate_edge(self.loop.init_statement.code[0], sdfg, codegen=codegen, symbols=symbols) init = init.strip(';') update = unparse_interstate_edge(self.loop.update_statement.code[0], sdfg, codegen=codegen, symbols=symbols) update = update.strip(';') if self.loop.inverted: - expr += f'{init};\n' - expr += 'do {\n' - expr += _clean_loop_body(self.body.as_cpp(codegen, symbols)) - expr += f'{update};\n' - expr += f'\n}} while({cond});\n' + if self.loop.update_before_condition: + expr += f'{init};\n' + expr += 'do {\n' + expr += _clean_loop_body(self.body.as_cpp(codegen, symbols)) + expr += f'{update};\n' + expr += f'}} while({cond});\n' + else: + expr += f'{init};\n' + expr += 'while (1) {\n' + expr += _clean_loop_body(self.body.as_cpp(codegen, symbols)) + expr += f'if (!({cond}))\n' + expr += 'break;\n' + expr += f'{update};\n' + expr += '}\n' else: expr += f'for ({init}; {cond}; {update}) {{\n' expr += _clean_loop_body(self.body.as_cpp(codegen, symbols)) diff --git a/dace/codegen/targets/framecode.py b/dace/codegen/targets/framecode.py index 488c1c7fbd..d71ea40fee 100644 --- a/dace/codegen/targets/framecode.py +++ b/dace/codegen/targets/framecode.py @@ -15,12 +15,14 @@ from dace.codegen.prettycode import CodeIOStream from dace.codegen.common import codeblock_to_cpp, sym2cpp from dace.codegen.targets.target import TargetCodeGenerator +from dace.codegen.tools.type_inference import infer_expr_type +from dace.frontend.python import astutils from dace.sdfg import SDFG, SDFGState, nodes from dace.sdfg import scope as sdscope from dace.sdfg import utils from dace.sdfg.analysis import cfg as cfg_analysis -from dace.sdfg.state import ControlFlowRegion -from dace.transformation.passes.analysis import StateReachability +from dace.sdfg.state import ControlFlowRegion, LoopRegion +from dace.transformation.passes.analysis import StateReachability, loop_analysis def _get_or_eval_sdfg_first_arg(func, sdfg): @@ -916,6 +918,24 @@ def 
generate_code(self, interstate_symbols.update(symbols) global_symbols.update(symbols) + if isinstance(cfr, LoopRegion) and cfr.loop_variable is not None and cfr.init_statement is not None: + init_assignment = cfr.init_statement.code[0] + update_assignment = cfr.update_statement.code[0] + if isinstance(init_assignment, astutils.ast.Assign): + init_assignment = init_assignment.value + if isinstance(update_assignment, astutils.ast.Assign): + update_assignment = update_assignment.value + if not cfr.loop_variable in interstate_symbols: + l_end = loop_analysis.get_loop_end(cfr) + l_start = loop_analysis.get_init_assignment(cfr) + l_step = loop_analysis.get_loop_stride(cfr) + sym_type = dtypes.result_type_of(infer_expr_type(l_start, global_symbols), + infer_expr_type(l_step, global_symbols), + infer_expr_type(l_end, global_symbols)) + interstate_symbols[cfr.loop_variable] = sym_type + if not cfr.loop_variable in global_symbols: + global_symbols[cfr.loop_variable] = interstate_symbols[cfr.loop_variable] + for isvarName, isvarType in interstate_symbols.items(): if isvarType is None: raise TypeError(f'Type inference failed for symbol {isvarName}') diff --git a/dace/frontend/python/newast.py b/dace/frontend/python/newast.py index 0d40e13282..cacf15d785 100644 --- a/dace/frontend/python/newast.py +++ b/dace/frontend/python/newast.py @@ -2565,8 +2565,7 @@ def visit_If(self, node: ast.If): self._on_block_added(cond_block) if_body = ControlFlowRegion(cond_block.label + '_body', sdfg=self.sdfg) - cond_block.branches.append((CodeBlock(cond), if_body)) - if_body.parent_graph = self.cfg_target + cond_block.add_branch(CodeBlock(cond), if_body) # Visit recursively self._recursive_visit(node.body, 'if', node.lineno, if_body, False) @@ -2575,9 +2574,7 @@ def visit_If(self, node: ast.If): if len(node.orelse) > 0: else_body = ControlFlowRegion(f'{cond_block.label}_else_{node.orelse[0].lineno}', sdfg=self.sdfg) - #cond_block.branches.append((CodeBlock(cond_else), else_body)) - cond_block.branches.append((None, else_body)) - else_body.parent_graph = self.cfg_target + cond_block.add_branch(None, else_body) # Visit recursively self._recursive_visit(node.orelse, 'else', node.lineno, else_body, False) diff --git a/dace/frontend/python/parser.py b/dace/frontend/python/parser.py index b0ef56907f..d99be1265d 100644 --- a/dace/frontend/python/parser.py +++ b/dace/frontend/python/parser.py @@ -499,6 +499,8 @@ def _parse(self, args, kwargs, simplify=None, save=False, validate=False) -> SDF sdutils.inline_control_flow_regions(nsdfg) sdfg.using_experimental_blocks = self.use_experimental_cfg_blocks + sdfg.reset_cfg_list() + # Apply simplification pass automatically if not cached and (simplify == True or (simplify is None and Config.get_bool('optimizer', 'automatic_simplification'))): diff --git a/dace/sdfg/analysis/schedule_tree/treenodes.py b/dace/sdfg/analysis/schedule_tree/treenodes.py index 619b71b770..3b447fa15a 100644 --- a/dace/sdfg/analysis/schedule_tree/treenodes.py +++ b/dace/sdfg/analysis/schedule_tree/treenodes.py @@ -162,10 +162,17 @@ def as_string(self, indent: int = 0): loop = self.header.loop if loop.update_statement and loop.init_statement and loop.loop_variable: if loop.inverted: - pre_header = indent * INDENTATION + f'{loop.init_statement.as_string}\n' - header = indent * INDENTATION + 'do:\n' - pre_footer = (indent + 1) * INDENTATION + f'{loop.update_statement.as_string}\n' - footer = indent * INDENTATION + f'while {loop.loop_condition.as_string}' + if loop.update_before_condition: + pre_header = indent * 
INDENTATION + f'{loop.init_statement.as_string}\n' + header = indent * INDENTATION + 'do:\n' + pre_footer = (indent + 1) * INDENTATION + f'{loop.update_statement.as_string}\n' + footer = indent * INDENTATION + f'while {loop.loop_condition.as_string}' + else: + pre_header = indent * INDENTATION + f'{loop.init_statement.as_string}\n' + header = indent * INDENTATION + 'while True:\n' + pre_footer = (indent + 1) * INDENTATION + f'if (not {loop.loop_condition.as_string}):\n' + pre_footer += (indent + 2) * INDENTATION + 'break\n' + footer = (indent + 1) * INDENTATION + f'{loop.update_statement.as_string}\n' return pre_header + header + super().as_string(indent) + '\n' + pre_footer + footer else: result = (indent * INDENTATION + diff --git a/dace/sdfg/analysis/writeset_underapproximation.py b/dace/sdfg/analysis/writeset_underapproximation.py index bfd5f4cb00..a0f84e93a6 100644 --- a/dace/sdfg/analysis/writeset_underapproximation.py +++ b/dace/sdfg/analysis/writeset_underapproximation.py @@ -1,42 +1,36 @@ # Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. """ -Pass derived from ``propagation.py`` that under-approximates write-sets of for-loops and Maps in -an SDFG. +Pass derived from ``propagation.py`` that under-approximates write-sets of for-loops and Maps in an SDFG. """ -from collections import defaultdict import copy +from dataclasses import dataclass, field import itertools +import sys import warnings -from typing import Any, Dict, List, Set, Tuple, Type, Union +from collections import defaultdict +from typing import Dict, List, Set, Tuple, Union + +if sys.version_info >= (3, 8): + from typing import TypedDict +else: + from typing_extensions import TypedDict + import sympy import dace +from dace import SDFG, Memlet, data, dtypes, registry, subsets, symbolic +from dace.sdfg import SDFGState +from dace.sdfg import graph +from dace.sdfg import graph as gr +from dace.sdfg import nodes, scope +from dace.sdfg.analysis import cfg as cfg_analysis +from dace.sdfg.nodes import AccessNode, NestedSDFG +from dace.sdfg.state import LoopRegion from dace.symbolic import issymbolic, pystr_to_symbolic, simplify -from dace.transformation.pass_pipeline import Modifies, Pass -from dace import registry, subsets, symbolic, dtypes, data, SDFG, Memlet -from dace.sdfg.nodes import NestedSDFG, AccessNode -from dace.sdfg import nodes, SDFGState, graph as gr -from dace.sdfg.analysis import cfg from dace.transformation import pass_pipeline as ppl -from dace.sdfg import graph -from dace.sdfg import scope - -# dictionary mapping each edge to a copy of the memlet of that edge with its write set -# underapproximated -approximation_dict: Dict[graph.Edge, Memlet] = {} -# dictionary that maps loop headers to "border memlets" that are written to in the -# corresponding loop -loop_write_dict: Dict[SDFGState, Dict[str, Memlet]] = {} -# dictionary containing information about the for loops in the SDFG -loop_dict: Dict[SDFGState, Tuple[SDFGState, SDFGState, - List[SDFGState], str, subsets.Range]] = {} -# dictionary mapping each nested SDFG to the iteration variables surrounding it -iteration_variables: Dict[SDFG, Set[str]] = {} -# dictionary mapping each state to the iteration variables surrounding it -# (including the ones from surrounding SDFGs) -ranges_per_state: Dict[SDFGState, - Dict[str, subsets.Range]] = defaultdict(lambda: {}) +from dace.transformation import transformation +from dace.transformation.pass_pipeline import Modifies @registry.make_registry @@ -81,7 +75,7 @@ def can_be_applied(self, 
expressions, variable_context, node_range, orig_edges): # Return False if iteration variable appears in multiple dimensions # or if two iteration variables appear in the same dimension - if not self._iteration_variables_appear_multiple_times(data_dims, expressions, other_params, params): + if not self._iteration_variables_appear_only_once(data_dims, expressions, other_params, params): return False node_range = self._make_range(node_range) @@ -89,27 +83,25 @@ def can_be_applied(self, expressions, variable_context, node_range, orig_edges): for dim in range(data_dims): dexprs = [] for expr in expressions: - if isinstance(expr[dim], symbolic.SymExpr): - dexprs.append(expr[dim].expr) - elif isinstance(expr[dim], tuple): - dexprs.append( - (expr[dim][0].expr if isinstance(expr[dim][0], symbolic.SymExpr) else - expr[dim][0], expr[dim][1].expr if isinstance( - expr[dim][1], symbolic.SymExpr) else expr[dim][1], expr[dim][2].expr - if isinstance(expr[dim][2], symbolic.SymExpr) else expr[dim][2])) + expr_dim = expr[dim] + if isinstance(expr_dim, symbolic.SymExpr): + dexprs.append(expr_dim.expr) + elif isinstance(expr_dim, tuple): + dexprs.append((expr_dim[0].expr if isinstance(expr_dim[0], symbolic.SymExpr) else expr_dim[0], + expr_dim[1].expr if isinstance(expr_dim[1], symbolic.SymExpr) else expr_dim[1], + expr_dim[2].expr if isinstance(expr_dim[2], symbolic.SymExpr) else expr_dim[2])) else: - dexprs.append(expr[dim]) + dexprs.append(expr_dim) for pattern_class in SeparableUnderapproximationMemletPattern.extensions().keys(): smpattern = pattern_class() - if smpattern.can_be_applied(dexprs, variable_context, node_range, orig_edges, dim, - data_dims): + if smpattern.can_be_applied(dexprs, variable_context, node_range, orig_edges, dim, data_dims): self.patterns_per_dim[dim] = smpattern break return None not in self.patterns_per_dim - def _iteration_variables_appear_multiple_times(self, data_dims, expressions, other_params, params): + def _iteration_variables_appear_only_once(self, data_dims, expressions, other_params, params): for expr in expressions: for param in params: occured_before = False @@ -146,8 +138,7 @@ def _iteration_variables_appear_multiple_times(self, data_dims, expressions, oth def _make_range(self, node_range): return subsets.Range([(rb.expr if isinstance(rb, symbolic.SymExpr) else rb, - re.expr if isinstance( - re, symbolic.SymExpr) else re, + re.expr if isinstance(re, symbolic.SymExpr) else re, rs.expr if isinstance(rs, symbolic.SymExpr) else rs) for rb, re, rs in node_range]) @@ -160,19 +151,16 @@ def propagate(self, array, expressions, node_range): dexprs = [] for expr in expressions: - if isinstance(expr[i], symbolic.SymExpr): - dexprs.append(expr[i].expr) - elif isinstance(expr[i], tuple): - dexprs.append(( - expr[i][0].expr if isinstance( - expr[i][0], symbolic.SymExpr) else expr[i][0], - expr[i][1].expr if isinstance( - expr[i][1], symbolic.SymExpr) else expr[i][1], - expr[i][2].expr if isinstance( - expr[i][2], symbolic.SymExpr) else expr[i][2], - expr.tile_sizes[i])) + expr_i = expr[i] + if isinstance(expr_i, symbolic.SymExpr): + dexprs.append(expr_i.expr) + elif isinstance(expr_i, tuple): + dexprs.append((expr_i[0].expr if isinstance(expr_i[0], symbolic.SymExpr) else expr_i[0], + expr_i[1].expr if isinstance(expr_i[1], symbolic.SymExpr) else expr_i[1], + expr_i[2].expr if isinstance(expr_i[2], symbolic.SymExpr) else expr_i[2], + expr.tile_sizes[i])) else: - dexprs.append(expr[i]) + dexprs.append(expr_i) result[i] = smpattern.propagate(array, dexprs, node_range) @@ -417,7 
+405,7 @@ def _find_unconditionally_executed_states(sdfg: SDFG) -> Set[SDFGState]: sdfg.add_edge(sink_node, dummy_sink, dace.sdfg.InterstateEdge()) # get all the nodes that are executed unconditionally in the state-machine a.k.a nodes # that dominate the sink states - dominators = cfg.all_dominators(sdfg) + dominators = cfg_analysis.all_dominators(sdfg) states = dominators[dummy_sink] # remove dummy state sdfg.remove_node(dummy_sink) @@ -689,21 +677,44 @@ def _merge_subsets(subset_a: subsets.Subset, subset_b: subsets.Subset) -> subset return subset_b +@dataclass +class UnderapproximateWritesDict: + approximation: Dict[graph.Edge, Memlet] = field(default_factory=dict) + loop_approximation: Dict[SDFGState, Dict[str, Memlet]] = field(default_factory=dict) + loops: Dict[SDFGState, + Tuple[SDFGState, SDFGState, List[SDFGState], str, subsets.Range]] = field(default_factory=dict) + + +@transformation.experimental_cfg_block_compatible class UnderapproximateWrites(ppl.Pass): + # Dictionary mapping each edge to a copy of the memlet of that edge with its write set underapproximated. + approximation_dict: Dict[graph.Edge, Memlet] + # Dictionary that maps loop headers to "border memlets" that are written to in the corresponding loop. + loop_write_dict: Dict[SDFGState, Dict[str, Memlet]] + # Dictionary containing information about the for loops in the SDFG. + loop_dict: Dict[SDFGState, Tuple[SDFGState, SDFGState, List[SDFGState], str, subsets.Range]] + # Dictionary mapping each nested SDFG to the iteration variables surrounding it. + iteration_variables: Dict[SDFG, Set[str]] + # Mapping of state to the iteration variables surrounding them, including the ones from surrounding SDFGs. + ranges_per_state: Dict[SDFGState, Dict[str, subsets.Range]] + + def __init__(self): + super().__init__() + self.approximation_dict = {} + self.loop_write_dict = {} + self.loop_dict = {} + self.iteration_variables = {} + self.ranges_per_state = defaultdict(lambda: {}) + def modifies(self) -> Modifies: - return ppl.Modifies.Everything + return ppl.Modifies.States def should_reapply(self, modified: ppl.Modifies) -> bool: - # If anything was modified, reapply - return modified & ppl.Modifies.States | ppl.Modifies.Edges | ppl.Modifies.Symbols | ppl.Modifies.Nodes - - def apply_pass( - self, sdfg: dace.SDFG, pipeline_results: Dict[str, Any] - ) -> Dict[str, Union[ - Dict[graph.Edge, Memlet], - Dict[SDFGState, Dict[str, Memlet]], - Dict[SDFGState, Tuple[SDFGState, SDFGState, List[SDFGState], str, subsets.Range]]]]: + # If anything was modified, reapply. + return modified & ppl.Modifies.Everything + + def apply_pass(self, top_sdfg: dace.SDFG, _) -> Dict[int, UnderapproximateWritesDict]: """ Applies the pass to the given SDFG. @@ -725,55 +736,71 @@ def apply_pass( :notes: The only modification this pass performs on the SDFG is splitting interstate edges. 
""" - # clear the global dictionaries - approximation_dict.clear() - loop_write_dict.clear() - loop_dict.clear() - iteration_variables.clear() - ranges_per_state.clear() - - # fill the approximation dictionary with the original edges as keys and the edges with the - # approximated memlets as values - for (edge, parent) in sdfg.all_edges_recursive(): - if isinstance(parent, SDFGState): - approximation_dict[edge] = copy.deepcopy(edge.data) - if not isinstance(approximation_dict[edge].subset, - subsets.SubsetUnion) and approximation_dict[edge].subset: - approximation_dict[edge].subset = subsets.SubsetUnion( - [approximation_dict[edge].subset]) - if not isinstance(approximation_dict[edge].dst_subset, - subsets.SubsetUnion) and approximation_dict[edge].dst_subset: - approximation_dict[edge].dst_subset = subsets.SubsetUnion( - [approximation_dict[edge].dst_subset]) - if not isinstance(approximation_dict[edge].src_subset, - subsets.SubsetUnion) and approximation_dict[edge].src_subset: - approximation_dict[edge].src_subset = subsets.SubsetUnion( - [approximation_dict[edge].src_subset]) - - self._underapproximate_writes_sdfg(sdfg) - - # Replace None with empty SubsetUnion in each Memlet - for entry in approximation_dict.values(): - if entry.subset is None: - entry.subset = subsets.SubsetUnion([]) - return { - "approximation": approximation_dict, - "loop_approximation": loop_write_dict, - "loops": loop_dict - } + result = defaultdict(lambda: UnderapproximateWritesDict()) + + for sdfg in top_sdfg.all_sdfgs_recursive(): + # Clear the global dictionaries. + self.approximation_dict = {} + self.loop_write_dict = {} + self.loop_dict = {} + self.iteration_variables = {} + self.ranges_per_state = defaultdict(lambda: {}) + + # fill the approximation dictionary with the original edges as keys and the edges with the + # approximated memlets as values + for (edge, parent) in sdfg.all_edges_recursive(): + if isinstance(parent, SDFGState): + self.approximation_dict[edge] = copy.deepcopy(edge.data) + if not isinstance(self.approximation_dict[edge].subset, + subsets.SubsetUnion) and self.approximation_dict[edge].subset: + self.approximation_dict[edge].subset = subsets.SubsetUnion([ + self.approximation_dict[edge].subset + ]) + if not isinstance(self.approximation_dict[edge].dst_subset, + subsets.SubsetUnion) and self.approximation_dict[edge].dst_subset: + self.approximation_dict[edge].dst_subset = subsets.SubsetUnion([ + self.approximation_dict[edge].dst_subset + ]) + if not isinstance(self.approximation_dict[edge].src_subset, + subsets.SubsetUnion) and self.approximation_dict[edge].src_subset: + self.approximation_dict[edge].src_subset = subsets.SubsetUnion([ + self.approximation_dict[edge].src_subset + ]) + + self._underapproximate_writes_sdfg(sdfg) + + # Replace None with empty SubsetUnion in each Memlet + for entry in self.approximation_dict.values(): + if entry.subset is None: + entry.subset = subsets.SubsetUnion([]) + + result[sdfg.cfg_id].approximation = self.approximation_dict + result[sdfg.cfg_id].loop_approximation = self.loop_write_dict + result[sdfg.cfg_id].loops = self.loop_dict + + return result def _underapproximate_writes_sdfg(self, sdfg: SDFG): """ Underapproximates write-sets of loops, maps and nested SDFGs in the given SDFG. 
""" from dace.transformation.helpers import split_interstate_edges + from dace.transformation.passes.analysis import loop_analysis split_interstate_edges(sdfg) loops = self._find_for_loops(sdfg) - loop_dict.update(loops) + self.loop_dict.update(loops) + + for region in sdfg.all_control_flow_regions(): + if isinstance(region, LoopRegion): + start = loop_analysis.get_init_assignment(region) + stop = loop_analysis.get_loop_end(region) + stride = loop_analysis.get_loop_stride(region) + for state in region.all_states(): + self.ranges_per_state[state][region.loop_variable] = subsets.Range([(start, stop, stride)]) - for state in sdfg.nodes(): - self._underapproximate_writes_state(sdfg, state) + for state in region.all_states(): + self._underapproximate_writes_state(sdfg, state) self._underapproximate_writes_loops(loops, sdfg) @@ -792,8 +819,8 @@ def _find_for_loops(self, """ # We import here to avoid cyclic imports. - from dace.transformation.interstate.loop_detection import find_for_loop from dace.sdfg import utils as sdutils + from dace.transformation.interstate.loop_detection import find_for_loop # dictionary mapping loop headers to beginstate, loopstates, looprange identified_loops = {} @@ -885,13 +912,12 @@ def _find_for_loops(self, sources=[begin], condition=lambda _, child: child != guard) - if itvar not in ranges_per_state[begin]: + if itvar not in self.ranges_per_state[begin]: for loop_state in loop_states: - ranges_per_state[loop_state][itervar] = subsets.Range([ - rng]) + self.ranges_per_state[loop_state][itervar] = subsets.Range([rng]) loop_state_list.append(loop_state) - ranges_per_state[guard][itervar] = subsets.Range([rng]) + self.ranges_per_state[guard][itervar] = subsets.Range([rng]) identified_loops[guard] = (begin, last_loop_state, loop_state_list, itvar, subsets.Range([rng])) @@ -934,8 +960,11 @@ def _underapproximate_writes_state(self, sdfg: SDFG, state: SDFGState): # approximation_dict # First, propagate nested SDFGs in a bottom-up fashion + dnodes: Set[nodes.AccessNode] = set() for node in state.nodes(): - if isinstance(node, nodes.NestedSDFG): + if isinstance(node, AccessNode): + dnodes.add(node) + elif isinstance(node, nodes.NestedSDFG): self._find_live_iteration_variables(node, sdfg, state) # Propagate memlets inside the nested SDFG. @@ -947,6 +976,15 @@ def _underapproximate_writes_state(self, sdfg: SDFG, state: SDFGState): # Process scopes from the leaves upwards self._underapproximate_writes_scope(sdfg, state, state.scope_leaves()) + # Make sure any scalar writes are also added if they have not been processed yet. 
+ for dn in dnodes: + desc = sdfg.data(dn.data) + if isinstance(desc, data.Scalar) or (isinstance(desc, data.Array) and desc.total_size == 1): + for iedge in state.in_edges(dn): + if not iedge in self.approximation_dict: + self.approximation_dict[iedge] = copy.deepcopy(iedge.data) + self.approximation_dict[iedge]._edge = iedge + def _find_live_iteration_variables(self, nsdfg: nodes.NestedSDFG, sdfg: SDFG, @@ -963,15 +1001,14 @@ def symbol_map(mapping, symbol): return None map_iteration_variables = _collect_iteration_variables(state, nsdfg) - sdfg_iteration_variables = iteration_variables[ - sdfg] if sdfg in iteration_variables else set() - state_iteration_variables = ranges_per_state[state].keys() + sdfg_iteration_variables = self.iteration_variables[sdfg] if sdfg in self.iteration_variables else set() + state_iteration_variables = self.ranges_per_state[state].keys() iteration_variables_local = (map_iteration_variables | sdfg_iteration_variables | state_iteration_variables) mapped_iteration_variables = set( map(lambda x: symbol_map(nsdfg.symbol_mapping, x), iteration_variables_local)) if mapped_iteration_variables: - iteration_variables[nsdfg.sdfg] = mapped_iteration_variables + self.iteration_variables[nsdfg.sdfg] = mapped_iteration_variables def _underapproximate_writes_nested_sdfg( self, @@ -1025,12 +1062,11 @@ def _init_border_memlet(template_memlet: Memlet, # Collect all memlets belonging to this access node memlets = [] for edge in edges: - inside_memlet = approximation_dict[edge] + inside_memlet = self.approximation_dict[edge] memlets.append(inside_memlet) # initialize border memlet if it does not exist already if border_memlet is None: - border_memlet = _init_border_memlet( - inside_memlet, node.label) + border_memlet = _init_border_memlet(inside_memlet, node.label) # Given all of this access nodes' memlets union all the subsets to one SubsetUnion if len(memlets) > 0: @@ -1042,18 +1078,16 @@ def _init_border_memlet(template_memlet: Memlet, border_memlet.subset, subset) # collect the memlets for each loop in the NSDFG - if state in loop_write_dict: - for node_label, loop_memlet in loop_write_dict[state].items(): + if state in self.loop_write_dict: + for node_label, loop_memlet in self.loop_write_dict[state].items(): if node_label not in border_memlets: continue border_memlet = border_memlets[node_label] # initialize border memlet if it does not exist already if border_memlet is None: - border_memlet = _init_border_memlet( - loop_memlet, node_label) + border_memlet = _init_border_memlet(loop_memlet, node_label) # compute the union of the ranges to merge the subsets. - border_memlet.subset = _merge_subsets( - border_memlet.subset, loop_memlet.subset) + border_memlet.subset = _merge_subsets(border_memlet.subset, loop_memlet.subset) # Make sure any potential NSDFG symbol mapping is correctly reversed # when propagating out. @@ -1068,17 +1102,16 @@ def _init_border_memlet(template_memlet: Memlet, # Propagate the inside 'border' memlets outside the SDFG by # offsetting, and unsqueezing if necessary. 
for edge in parent_state.out_edges(nsdfg_node): - out_memlet = approximation_dict[edge] + out_memlet = self.approximation_dict[edge] if edge.src_conn in border_memlets: internal_memlet = border_memlets[edge.src_conn] if internal_memlet is None: out_memlet.subset = None out_memlet.dst_subset = None - approximation_dict[edge] = out_memlet + self.approximation_dict[edge] = out_memlet continue - out_memlet = _unsqueeze_memlet_subsetunion(internal_memlet, out_memlet, parent_sdfg, - nsdfg_node) - approximation_dict[edge] = out_memlet + out_memlet = _unsqueeze_memlet_subsetunion(internal_memlet, out_memlet, parent_sdfg, nsdfg_node) + self.approximation_dict[edge] = out_memlet def _underapproximate_writes_loop(self, sdfg: SDFG, @@ -1099,9 +1132,7 @@ def _underapproximate_writes_loop(self, propagate_memlet_loop will be called recursively on the outermost loopheaders """ - def _init_border_memlet(template_memlet: Memlet, - node_label: str - ): + def _init_border_memlet(template_memlet: Memlet, node_label: str): ''' Creates a Memlet with the same data as the template_memlet, stores it in the border_memlets dictionary and returns it. @@ -1111,8 +1142,7 @@ def _init_border_memlet(template_memlet: Memlet, border_memlets[node_label] = border_memlet return border_memlet - def filter_subsets(itvar: str, itrange: subsets.Range, - memlet: Memlet) -> List[subsets.Subset]: + def filter_subsets(itvar: str, itrange: subsets.Range, memlet: Memlet) -> List[subsets.Subset]: # helper method that filters out subsets that do not depend on the iteration variable # if the iteration range is symbolic @@ -1134,7 +1164,7 @@ def filter_subsets(itvar: str, itrange: subsets.Range, if rng.num_elements() == 0: return # make sure there is no break out of the loop - dominators = cfg.all_dominators(sdfg) + dominators = cfg_analysis.all_dominators(sdfg) if any(begin not in dominators[s] and not begin is s for s in loop_states): return border_memlets = defaultdict(None) @@ -1159,7 +1189,7 @@ def filter_subsets(itvar: str, itrange: subsets.Range, # collect all the subsets of the incoming memlets for the current access node for edge in edges: - inside_memlet = copy.copy(approximation_dict[edge]) + inside_memlet = copy.copy(self.approximation_dict[edge]) # filter out subsets that could become empty depending on assignments # of symbols filtered_subsets = filter_subsets( @@ -1177,35 +1207,27 @@ def filter_subsets(itvar: str, itrange: subsets.Range, self._underapproximate_writes_loop_subset(sdfg, memlets, border_memlet, sdfg.arrays[node.label], itvar, rng) - if state not in loop_write_dict: + if state not in self.loop_write_dict: continue # propagate the border memlets of nested loop - for node_label, other_border_memlet in loop_write_dict[state].items(): + for node_label, other_border_memlet in self.loop_write_dict[state].items(): # filter out subsets that could become empty depending on symbol assignments - filtered_subsets = filter_subsets( - itvar, rng, other_border_memlet) + filtered_subsets = filter_subsets(itvar, rng, other_border_memlet) if not filtered_subsets: continue - other_border_memlet.subset = subsets.SubsetUnion( - filtered_subsets) + other_border_memlet.subset = subsets.SubsetUnion(filtered_subsets) border_memlet = border_memlets.get(node_label) if border_memlet is None: - border_memlet = _init_border_memlet( - other_border_memlet, node_label) + border_memlet = _init_border_memlet(other_border_memlet, node_label) self._underapproximate_writes_loop_subset(sdfg, [other_border_memlet], border_memlet, 
sdfg.arrays[node_label], itvar, rng) - loop_write_dict[loop_header] = border_memlets + self.loop_write_dict[loop_header] = border_memlets - def _underapproximate_writes_loop_subset(self, - sdfg: dace.SDFG, - memlets: List[Memlet], - dst_memlet: Memlet, - arr: dace.data.Array, - itvar: str, - rng: subsets.Subset, + def _underapproximate_writes_loop_subset(self, sdfg: dace.SDFG, memlets: List[Memlet], dst_memlet: Memlet, + arr: dace.data.Array, itvar: str, rng: subsets.Subset, loop_nest_itvars: Union[Set[str], None] = None): """ Helper function that takes a list of (border) memlets, propagates them out of a @@ -1223,16 +1245,11 @@ def _underapproximate_writes_loop_subset(self, if len(memlets) > 0: params = [itvar] # get all the other iteration variables surrounding this memlet - surrounding_itvars = iteration_variables[sdfg] if sdfg in iteration_variables else set( - ) + surrounding_itvars = self.iteration_variables[sdfg] if sdfg in self.iteration_variables else set() if loop_nest_itvars: surrounding_itvars |= loop_nest_itvars - subset = self._underapproximate_subsets(memlets, - arr, - params, - rng, - use_dst=True, + subset = self._underapproximate_subsets(memlets, arr, params, rng, use_dst=True, surrounding_itvars=surrounding_itvars).subset if subset is None or len(subset.subset_list) == 0: @@ -1240,9 +1257,7 @@ def _underapproximate_writes_loop_subset(self, # compute the union of the ranges to merge the subsets. dst_memlet.subset = _merge_subsets(dst_memlet.subset, subset) - def _underapproximate_writes_scope(self, - sdfg: SDFG, - state: SDFGState, + def _underapproximate_writes_scope(self, sdfg: SDFG, state: SDFGState, scopes: Union[scope.ScopeTree, List[scope.ScopeTree]]): """ Propagate memlets from the given scopes outwards. @@ -1253,8 +1268,7 @@ def _underapproximate_writes_scope(self, """ # for each map scope find the iteration variables of surrounding maps - surrounding_map_vars: Dict[scope.ScopeTree, - Set[str]] = _collect_itvars_scope(scopes) + surrounding_map_vars: Dict[scope.ScopeTree, Set[str]] = _collect_itvars_scope(scopes) if isinstance(scopes, scope.ScopeTree): scopes_to_process = [scopes] else: @@ -1272,8 +1286,7 @@ def _underapproximate_writes_scope(self, sdfg, state, surrounding_map_vars) - self._underapproximate_writes_node( - state, scope_node.exit, surrounding_iteration_variables) + self._underapproximate_writes_node(state, scope_node.exit, surrounding_iteration_variables) # Add parent to next frontier next_scopes.add(scope_node.parent) scopes_to_process = next_scopes @@ -1286,9 +1299,8 @@ def _collect_iteration_variables_scope_node(self, surrounding_map_vars: Dict[scope.ScopeTree, Set[str]]) -> Set[str]: map_iteration_variables = surrounding_map_vars[ scope_node] if scope_node in surrounding_map_vars else set() - sdfg_iteration_variables = iteration_variables[ - sdfg] if sdfg in iteration_variables else set() - loop_iteration_variables = ranges_per_state[state].keys() + sdfg_iteration_variables = self.iteration_variables[sdfg] if sdfg in self.iteration_variables else set() + loop_iteration_variables = self.ranges_per_state[state].keys() surrounding_iteration_variables = (map_iteration_variables | sdfg_iteration_variables | loop_iteration_variables) @@ -1308,12 +1320,8 @@ def _underapproximate_writes_node(self, :param surrounding_itvars: Iteration variables that surround the map scope """ if isinstance(node, nodes.EntryNode): - internal_edges = [ - e for e in dfg_state.out_edges(node) if e.src_conn and e.src_conn.startswith('OUT_') - ] - external_edges = [ - e for e 
in dfg_state.in_edges(node) if e.dst_conn and e.dst_conn.startswith('IN_') - ] + internal_edges = [e for e in dfg_state.out_edges(node) if e.src_conn and e.src_conn.startswith('OUT_')] + external_edges = [e for e in dfg_state.in_edges(node) if e.dst_conn and e.dst_conn.startswith('IN_')] def geticonn(e): return e.src_conn[4:] @@ -1323,12 +1331,8 @@ def geteconn(e): use_dst = False else: - internal_edges = [ - e for e in dfg_state.in_edges(node) if e.dst_conn and e.dst_conn.startswith('IN_') - ] - external_edges = [ - e for e in dfg_state.out_edges(node) if e.src_conn and e.src_conn.startswith('OUT_') - ] + internal_edges = [e for e in dfg_state.in_edges(node) if e.dst_conn and e.dst_conn.startswith('IN_')] + external_edges = [e for e in dfg_state.out_edges(node) if e.src_conn and e.src_conn.startswith('OUT_')] def geticonn(e): return e.dst_conn[3:] @@ -1339,21 +1343,17 @@ def geteconn(e): use_dst = True for edge in external_edges: - if approximation_dict[edge].is_empty(): + if self.approximation_dict[edge].is_empty(): new_memlet = Memlet() else: internal_edge = next( e for e in internal_edges if geticonn(e) == geteconn(edge)) - aligned_memlet = self._align_memlet( - dfg_state, internal_edge, dst=use_dst) - new_memlet = self._underapproximate_memlets(dfg_state, - aligned_memlet, - node, - True, - connector=geteconn( - edge), + aligned_memlet = self._align_memlet(dfg_state, internal_edge, dst=use_dst) + new_memlet = self._underapproximate_memlets(dfg_state, aligned_memlet, node, True, + connector=geteconn(edge), surrounding_itvars=surrounding_itvars) - approximation_dict[edge] = new_memlet + new_memlet._edge = edge + self.approximation_dict[edge] = new_memlet def _align_memlet(self, state: SDFGState, @@ -1373,16 +1373,16 @@ def _align_memlet(self, is_src = edge.data._is_data_src # Memlet is already aligned if is_src is None or (is_src and not dst) or (not is_src and dst): - res = approximation_dict[edge] + res = self.approximation_dict[edge] return res # Data<->Code memlets always have one data container mpath = state.memlet_path(edge) if not isinstance(mpath[0].src, AccessNode) or not isinstance(mpath[-1].dst, AccessNode): - return approximation_dict[edge] + return self.approximation_dict[edge] # Otherwise, find other data container - result = copy.deepcopy(approximation_dict[edge]) + result = copy.deepcopy(self.approximation_dict[edge]) if dst: node = mpath[-1].dst else: @@ -1390,8 +1390,8 @@ def _align_memlet(self, # Fix memlet fields result.data = node.data - result.subset = approximation_dict[edge].other_subset - result.other_subset = approximation_dict[edge].subset + result.subset = self.approximation_dict[edge].other_subset + result.other_subset = self.approximation_dict[edge].subset result._is_data_src = not is_src return result @@ -1448,9 +1448,9 @@ def _underapproximate_memlets(self, # and union their subsets if union_inner_edges: aggdata = [ - approximation_dict[e] + self.approximation_dict[e] for e in neighboring_edges - if approximation_dict[e].data == memlet.data and approximation_dict[e] != memlet + if self.approximation_dict[e].data == memlet.data and self.approximation_dict[e] != memlet ] else: aggdata = [] @@ -1459,8 +1459,7 @@ def _underapproximate_memlets(self, if arr is None: if memlet.data not in sdfg.arrays: - raise KeyError('Data descriptor (Array, Stream) "%s" not defined in SDFG.' % - memlet.data) + raise KeyError('Data descriptor (Array, Stream) "%s" not defined in SDFG.' 
% memlet.data) # FIXME: A memlet alone (without an edge) cannot figure out whether it is data<->data or data<->code # so this test cannot be used diff --git a/dace/sdfg/propagation.py b/dace/sdfg/propagation.py index 1c038dd2e4..f62bb6eb58 100644 --- a/dace/sdfg/propagation.py +++ b/dace/sdfg/propagation.py @@ -4,21 +4,22 @@ from internal memory accesses and scope ranges). """ -from collections import deque import copy -from dace.symbolic import issymbolic, pystr_to_symbolic, simplify -import itertools import functools +import itertools +import warnings +from collections import deque +from typing import List, Set + import sympy -from sympy import ceiling, Symbol +from sympy import Symbol, ceiling from sympy.concrete.summations import Sum -import warnings -import networkx as nx -from dace import registry, subsets, symbolic, dtypes, data +from dace import data, dtypes, registry, subsets, symbolic from dace.memlet import Memlet -from dace.sdfg import nodes, graph as gr -from typing import List, Set +from dace.sdfg import graph as gr +from dace.sdfg import nodes +from dace.symbolic import issymbolic, pystr_to_symbolic, simplify @registry.make_registry @@ -61,17 +62,17 @@ def can_be_applied(self, expressions, variable_context, node_range, orig_edges): for rb, re, rs in node_range]) for dim in range(data_dims): - dexprs = [] for expr in expressions: - if isinstance(expr[dim], symbolic.SymExpr): - dexprs.append(expr[dim].approx) - elif isinstance(expr[dim], tuple): - dexprs.append((expr[dim][0].approx if isinstance(expr[dim][0], symbolic.SymExpr) else expr[dim][0], - expr[dim][1].approx if isinstance(expr[dim][1], symbolic.SymExpr) else expr[dim][1], - expr[dim][2].approx if isinstance(expr[dim][2], symbolic.SymExpr) else expr[dim][2])) + expr_dim = expr[dim] + if isinstance(expr_dim, symbolic.SymExpr): + dexprs.append(expr_dim.approx) + elif isinstance(expr_dim, tuple): + dexprs.append((expr_dim[0].approx if isinstance(expr_dim[0], symbolic.SymExpr) else expr_dim[0], + expr_dim[1].approx if isinstance(expr_dim[1], symbolic.SymExpr) else expr_dim[1], + expr_dim[2].approx if isinstance(expr_dim[2], symbolic.SymExpr) else expr_dim[2])) else: - dexprs.append(expr[dim]) + dexprs.append(expr_dim) for pattern_class in SeparableMemletPattern.extensions().keys(): smpattern = pattern_class() @@ -93,15 +94,16 @@ def propagate(self, array, expressions, node_range): dexprs = [] for expr in expressions: - if isinstance(expr[i], symbolic.SymExpr): - dexprs.append(expr[i].approx) - elif isinstance(expr[i], tuple): - dexprs.append((expr[i][0].approx if isinstance(expr[i][0], symbolic.SymExpr) else expr[i][0], - expr[i][1].approx if isinstance(expr[i][1], symbolic.SymExpr) else expr[i][1], - expr[i][2].approx if isinstance(expr[i][2], symbolic.SymExpr) else expr[i][2], + expr_i = expr[i] + if isinstance(expr_i, symbolic.SymExpr): + dexprs.append(expr_i.approx) + elif isinstance(expr_i, tuple): + dexprs.append((expr_i[0].approx if isinstance(expr_i[0], symbolic.SymExpr) else expr_i[0], + expr_i[1].approx if isinstance(expr_i[1], symbolic.SymExpr) else expr_i[1], + expr_i[2].approx if isinstance(expr_i[2], symbolic.SymExpr) else expr_i[2], expr.tile_sizes[i])) else: - dexprs.append(expr[i]) + dexprs.append(expr_i) result[i] = smpattern.propagate(array, dexprs, overapprox_range) @@ -569,8 +571,8 @@ def _annotate_loop_ranges(sdfg, unannotated_cycle_states): """ # We import here to avoid cyclic imports. 
- from dace.transformation.interstate.loop_detection import find_for_loop from dace.sdfg import utils as sdutils + from dace.transformation.interstate.loop_detection import find_for_loop condition_edges = {} @@ -739,8 +741,8 @@ def propagate_states(sdfg, concretize_dynamic_unbounded=False) -> None: # We import here to avoid cyclic imports. from dace.sdfg import InterstateEdge - from dace.transformation.helpers import split_interstate_edges from dace.sdfg.analysis import cfg + from dace.transformation.helpers import split_interstate_edges # Reset the state edge annotations (which may have changed due to transformations) reset_state_annotations(sdfg) diff --git a/dace/sdfg/state.py b/dace/sdfg/state.py index 8d443e6beb..2ae6109b31 100644 --- a/dace/sdfg/state.py +++ b/dace/sdfg/state.py @@ -2987,35 +2987,52 @@ class LoopRegion(ControlFlowRegion): inverted = Property(dtype=bool, default=False, desc='If True, the loop condition is checked after the first iteration.') + update_before_condition = Property(dtype=bool, + default=True, + desc='If False, the loop condition is checked before the update statement is' + + ' executed. This only applies to inverted loops, turning them from a typical ' + + 'do-while style into a while(true) with a break before the update (at the end ' + + 'of an iteration) if the condition no longer holds.') loop_variable = Property(dtype=str, default='', desc='The loop variable, if given') def __init__(self, label: str, - condition_expr: Optional[str] = None, + condition_expr: Optional[Union[str, CodeBlock]] = None, loop_var: Optional[str] = None, - initialize_expr: Optional[str] = None, - update_expr: Optional[str] = None, + initialize_expr: Optional[Union[str, CodeBlock]] = None, + update_expr: Optional[Union[str, CodeBlock]] = None, inverted: bool = False, - sdfg: Optional['SDFG'] = None): + sdfg: Optional['SDFG'] = None, + update_before_condition = True): super(LoopRegion, self).__init__(label, sdfg) if initialize_expr is not None: - self.init_statement = CodeBlock(initialize_expr) + if isinstance(initialize_expr, CodeBlock): + self.init_statement = initialize_expr + else: + self.init_statement = CodeBlock(initialize_expr) else: self.init_statement = None if condition_expr: - self.loop_condition = CodeBlock(condition_expr) + if isinstance(condition_expr, CodeBlock): + self.loop_condition = condition_expr + else: + self.loop_condition = CodeBlock(condition_expr) else: self.loop_condition = CodeBlock('True') if update_expr is not None: - self.update_statement = CodeBlock(update_expr) + if isinstance(update_expr, CodeBlock): + self.update_statement = update_expr + else: + self.update_statement = CodeBlock(update_expr) else: self.update_statement = None self.loop_variable = loop_var or '' self.inverted = inverted + self.update_before_condition = update_before_condition def inline(self) -> Tuple[bool, Any]: """ @@ -3234,7 +3251,12 @@ def __repr__(self) -> str: @property def branches(self) -> List[Tuple[Optional[CodeBlock], ControlFlowRegion]]: return self._branches - + + def add_branch(self, condition: Optional[CodeBlock], branch: ControlFlowRegion): + self._branches.append([condition, branch]) + branch.parent_graph = self.parent_graph + branch.sdfg = self.sdfg + def nodes(self) -> List['ControlFlowBlock']: return [node for _, node in self._branches if node is not None] diff --git a/dace/transformation/helpers.py b/dace/transformation/helpers.py index 74a3d2ee12..6ca4602079 100644 --- a/dace/transformation/helpers.py +++ b/dace/transformation/helpers.py @@ -379,7 +379,7 
@@ def nest_state_subgraph(sdfg: SDFG, SDFG. :raise ValueError: The subgraph is contained in more than one scope. """ - if state.parent != sdfg: + if state.sdfg != sdfg: raise KeyError('State does not belong to given SDFG') if subgraph is not state and subgraph.graph is not state: raise KeyError('Subgraph does not belong to given state') @@ -433,7 +433,7 @@ def nest_state_subgraph(sdfg: SDFG, # top-level graph) data_in_subgraph = set(n.data for n in subgraph.nodes() if isinstance(n, nodes.AccessNode)) # Find other occurrences in SDFG - other_nodes = set(n.data for s in sdfg.nodes() for n in s.nodes() + other_nodes = set(n.data for s in sdfg.states() for n in s.nodes() if isinstance(n, nodes.AccessNode) and n not in subgraph.nodes()) subgraph_transients = set() for data in data_in_subgraph: diff --git a/dace/transformation/interstate/loop_detection.py b/dace/transformation/interstate/loop_detection.py index 93c2f6ea1c..8081447132 100644 --- a/dace/transformation/interstate/loop_detection.py +++ b/dace/transformation/interstate/loop_detection.py @@ -1,9 +1,9 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. """ Loop detection transformation """ import sympy as sp import networkx as nx -from typing import AnyStr, Optional, Tuple, List, Set +from typing import AnyStr, Iterable, Optional, Tuple, List, Set from dace import sdfg as sd, symbolic from dace.sdfg import graph as gr, utils as sdutil, InterstateEdge @@ -29,6 +29,9 @@ class DetectLoop(transformation.PatternTransformation): # Available for rotated and self loops entry_state = transformation.PatternNode(sd.SDFGState) + # Available for explicit-latch rotated loops + loop_break = transformation.PatternNode(sd.SDFGState) + @classmethod def expressions(cls): # Case 1: Loop with one state @@ -69,7 +72,46 @@ def expressions(cls): ssdfg.add_edge(cls.loop_begin, cls.loop_begin, sd.InterstateEdge()) ssdfg.add_edge(cls.loop_begin, cls.exit_state, sd.InterstateEdge()) - return [sdfg, msdfg, rsdfg, rmsdfg, ssdfg] + # Case 6: Rotated multi-state loop with explicit exiting and latch states + mlrmsdfg = gr.OrderedDiGraph() + mlrmsdfg.add_nodes_from([cls.entry_state, cls.loop_break, cls.loop_latch, cls.loop_begin, cls.exit_state]) + mlrmsdfg.add_edge(cls.entry_state, cls.loop_begin, sd.InterstateEdge()) + mlrmsdfg.add_edge(cls.loop_latch, cls.loop_begin, sd.InterstateEdge()) + mlrmsdfg.add_edge(cls.loop_break, cls.exit_state, sd.InterstateEdge()) + mlrmsdfg.add_edge(cls.loop_break, cls.loop_latch, sd.InterstateEdge()) + + # Case 7: Rotated single-state loop with explicit exiting and latch states + mlrsdfg = gr.OrderedDiGraph() + mlrsdfg.add_nodes_from([cls.entry_state, cls.loop_latch, cls.loop_begin, cls.exit_state]) + mlrsdfg.add_edge(cls.entry_state, cls.loop_begin, sd.InterstateEdge()) + mlrsdfg.add_edge(cls.loop_latch, cls.loop_begin, sd.InterstateEdge()) + mlrsdfg.add_edge(cls.loop_begin, cls.exit_state, sd.InterstateEdge()) + mlrsdfg.add_edge(cls.loop_begin, cls.loop_latch, sd.InterstateEdge()) + + # Case 8: Guarded rotated multi-state loop with explicit exiting and latch states (modification of case 6) + gmlrmsdfg = gr.OrderedDiGraph() + gmlrmsdfg.add_nodes_from([cls.entry_state, cls.loop_break, cls.loop_latch, cls.loop_begin, cls.exit_state]) + gmlrmsdfg.add_edge(cls.entry_state, cls.loop_begin, sd.InterstateEdge()) + gmlrmsdfg.add_edge(cls.loop_latch, cls.loop_begin, sd.InterstateEdge()) + gmlrmsdfg.add_edge(cls.loop_begin, cls.loop_break, 
sd.InterstateEdge()) + gmlrmsdfg.add_edge(cls.loop_break, cls.exit_state, sd.InterstateEdge()) + gmlrmsdfg.add_edge(cls.loop_break, cls.loop_latch, sd.InterstateEdge()) + + return [sdfg, msdfg, rsdfg, rmsdfg, ssdfg, mlrmsdfg, mlrsdfg, gmlrmsdfg] + + @property + def inverted(self) -> bool: + """ + Whether the loop matched a pattern of an inverted (do-while style) loop. + """ + return self.expr_index in (2, 3, 5, 6, 7) + + @property + def first_loop_block(self) -> ControlFlowBlock: + """ + The first control flow block executed in each loop iteration. + """ + return self.loop_guard if self.expr_index <= 1 else self.loop_begin def can_be_applied(self, graph: ControlFlowRegion, @@ -77,19 +119,26 @@ def can_be_applied(self, sdfg: sd.SDFG, permissive: bool = False) -> bool: if expr_index == 0: - return self.detect_loop(graph, False) is not None + return self.detect_loop(graph, multistate_loop=False, accept_missing_itvar=permissive) is not None elif expr_index == 1: - return self.detect_loop(graph, True) is not None + return self.detect_loop(graph, multistate_loop=True, accept_missing_itvar=permissive) is not None elif expr_index == 2: - return self.detect_rotated_loop(graph, False) is not None + return self.detect_rotated_loop(graph, multistate_loop=False, accept_missing_itvar=permissive) is not None elif expr_index == 3: - return self.detect_rotated_loop(graph, True) is not None + return self.detect_rotated_loop(graph, multistate_loop=True, accept_missing_itvar=permissive) is not None elif expr_index == 4: - return self.detect_self_loop(graph) is not None + return self.detect_self_loop(graph, accept_missing_itvar=permissive) is not None + elif expr_index in (5, 7): + return self.detect_rotated_loop(graph, multistate_loop=True, accept_missing_itvar=permissive, + separate_latch=True) is not None + elif expr_index == 6: + return self.detect_rotated_loop(graph, multistate_loop=False, accept_missing_itvar=permissive, + separate_latch=True) is not None raise ValueError(f'Invalid expression index {expr_index}') - def detect_loop(self, graph: ControlFlowRegion, multistate_loop: bool) -> Optional[str]: + def detect_loop(self, graph: ControlFlowRegion, multistate_loop: bool, + accept_missing_itvar: bool = False) -> Optional[str]: """ Detects a loop of the form: @@ -159,13 +208,19 @@ def detect_loop(self, graph: ControlFlowRegion, multistate_loop: bool) -> Option # The backedge must reassign the iteration variable itvar &= backedge.data.assignments.keys() if len(itvar) != 1: - # Either no consistent iteration variable found, or too many - # consistent iteration variables found - return None + if not accept_missing_itvar: + # Either no consistent iteration variable found, or too many consistent iteration variables found + return None + else: + if len(itvar) == 0: + return '' + else: + return None return next(iter(itvar)) - def detect_rotated_loop(self, graph: ControlFlowRegion, multistate_loop: bool) -> Optional[str]: + def detect_rotated_loop(self, graph: ControlFlowRegion, multistate_loop: bool, + accept_missing_itvar: bool = False, separate_latch: bool = False) -> Optional[str]: """ Detects a loop of the form: @@ -181,6 +236,9 @@ def detect_rotated_loop(self, graph: ControlFlowRegion, multistate_loop: bool) - :return: The loop variable or ``None`` if not detected. 
""" latch = self.loop_latch + ltest = self.loop_latch + if separate_latch: + ltest = self.loop_break if multistate_loop else self.loop_begin begin = self.loop_begin # A for-loop start has at least two incoming edges (init and increment) @@ -188,18 +246,14 @@ def detect_rotated_loop(self, graph: ControlFlowRegion, multistate_loop: bool) - if len(begin_inedges) < 2: return None # A for-loop latch only has two outgoing edges (loop condition and exit-loop) - latch_outedges = graph.out_edges(latch) + latch_outedges = graph.out_edges(ltest) if len(latch_outedges) != 2: return None - # All incoming edges to the start of the loop must set the same variable - itvar = None - for iedge in begin_inedges: - if itvar is None: - itvar = set(iedge.data.assignments.keys()) - else: - itvar &= iedge.data.assignments.keys() - if itvar is None: + # A for-loop latch can further only have one incoming edge (the increment edge). A while-loop, i.e., a loop + # with no explicit iteration variable, may have more than that. + latch_inedges = graph.in_edges(latch) + if not accept_missing_itvar and len(latch_inedges) != 1: return None # Outgoing edges must be a negation of each other @@ -208,8 +262,13 @@ def detect_rotated_loop(self, graph: ControlFlowRegion, multistate_loop: bool) - # All nodes inside loop must be dominated by loop start dominators = nx.dominance.immediate_dominators(graph.nx, graph.start_block) - loop_nodes = list(sdutil.dfs_conditional(graph, sources=[begin], condition=lambda _, child: child != latch)) - loop_nodes += [latch] + if begin is ltest: + loop_nodes = [begin] + else: + loop_nodes = list(sdutil.dfs_conditional(graph, sources=[begin], condition=lambda _, child: child != ltest)) + loop_nodes.append(latch) + if ltest is not latch and ltest is not begin: + loop_nodes.append(ltest) backedge = None for node in loop_nodes: for e in graph.out_edges(node): @@ -231,16 +290,9 @@ def detect_rotated_loop(self, graph: ControlFlowRegion, multistate_loop: bool) - if backedge is None: return None - # The backedge must reassign the iteration variable - itvar &= backedge.data.assignments.keys() - if len(itvar) != 1: - # Either no consistent iteration variable found, or too many - # consistent iteration variables found - return None + return rotated_loop_find_itvar(begin_inedges, latch_inedges, backedge, ltest, accept_missing_itvar)[0] - return next(iter(itvar)) - - def detect_self_loop(self, graph: ControlFlowRegion) -> Optional[str]: + def detect_self_loop(self, graph: ControlFlowRegion, accept_missing_itvar: bool = False) -> Optional[str]: """ Detects a loop of the form: @@ -288,9 +340,14 @@ def detect_self_loop(self, graph: ControlFlowRegion) -> Optional[str]: # The backedge must reassign the iteration variable itvar &= backedge.data.assignments.keys() if len(itvar) != 1: - # Either no consistent iteration variable found, or too many - # consistent iteration variables found - return None + if not accept_missing_itvar: + # Either no consistent iteration variable found, or too many consistent iteration variables found + return None + else: + if len(itvar) == 0: + return '' + else: + return None return next(iter(itvar)) @@ -310,9 +367,10 @@ def loop_information( if self.expr_index <= 1: guard = self.loop_guard return find_for_loop(guard.parent_graph, guard, entry, itervar) - elif self.expr_index in (2, 3): + elif self.expr_index in (2, 3, 5, 6, 7): latch = self.loop_latch - return find_rotated_for_loop(latch.parent_graph, latch, entry, itervar) + return find_rotated_for_loop(latch.parent_graph, latch, entry, 
itervar, + separate_latch=(self.expr_index in (5, 6, 7))) elif self.expr_index == 4: return find_rotated_for_loop(entry.parent_graph, entry, entry, itervar) @@ -334,6 +392,14 @@ def loop_body(self) -> List[ControlFlowBlock]: return loop_nodes elif self.expr_index == 4: return [begin] + elif self.expr_index in (5, 7): + ltest = self.loop_break + latch = self.loop_latch + loop_nodes = list(sdutil.dfs_conditional(graph, sources=[begin], condition=lambda _, child: child != ltest)) + loop_nodes += [ltest, latch] + return loop_nodes + elif self.expr_index == 6: + return [begin, self.loop_latch] return [] @@ -343,8 +409,10 @@ def loop_meta_states(self) -> List[ControlFlowBlock]: """ if self.expr_index in (0, 1): return [self.loop_guard] - if self.expr_index in (2, 3): + if self.expr_index in (2, 3, 6): return [self.loop_latch] + if self.expr_index in (5, 7): + return [self.loop_break, self.loop_latch] return [] def loop_init_edge(self) -> gr.Edge[InterstateEdge]: @@ -357,7 +425,7 @@ def loop_init_edge(self) -> gr.Edge[InterstateEdge]: guard = self.loop_guard body = self.loop_body() return next(e for e in graph.in_edges(guard) if e.src not in body) - elif self.expr_index in (2, 3): + elif self.expr_index in (2, 3, 5, 6, 7): latch = self.loop_latch return next(e for e in graph.in_edges(begin) if e.src is not latch) elif self.expr_index == 4: @@ -377,9 +445,12 @@ def loop_exit_edge(self) -> gr.Edge[InterstateEdge]: elif self.expr_index in (2, 3): latch = self.loop_latch return graph.edges_between(latch, exitstate)[0] - elif self.expr_index == 4: + elif self.expr_index in (4, 6): begin = self.loop_begin return graph.edges_between(begin, exitstate)[0] + elif self.expr_index in (5, 7): + ltest = self.loop_break + return graph.edges_between(ltest, exitstate)[0] raise ValueError(f'Invalid expression index {self.expr_index}') @@ -398,6 +469,10 @@ def loop_condition_edge(self) -> gr.Edge[InterstateEdge]: elif self.expr_index == 4: begin = self.loop_begin return graph.edges_between(begin, begin)[0] + elif self.expr_index in (5, 6, 7): + latch = self.loop_latch + ltest = self.loop_break if self.expr_index in (5, 7) else self.loop_begin + return graph.edges_between(ltest, latch)[0] raise ValueError(f'Invalid expression index {self.expr_index}') @@ -411,15 +486,93 @@ def loop_increment_edge(self) -> gr.Edge[InterstateEdge]: guard = self.loop_guard body = self.loop_body() return next(e for e in graph.in_edges(guard) if e.src in body) - elif self.expr_index in (2, 3): - body = self.loop_body() - return next(e for e in graph.in_edges(begin) if e.src in body) + elif self.expr_index in (2, 3, 5, 6, 7): + _, step_edge = rotated_loop_find_itvar(graph.in_edges(begin), graph.in_edges(self.loop_latch), + graph.edges_between(self.loop_latch, begin)[0], self.loop_latch) + return step_edge elif self.expr_index == 4: return graph.edges_between(begin, begin)[0] raise ValueError(f'Invalid expression index {self.expr_index}') +def rotated_loop_find_itvar(begin_inedges: List[gr.Edge[InterstateEdge]], + latch_inedges: List[gr.Edge[InterstateEdge]], + backedge: gr.Edge[InterstateEdge], latch: ControlFlowBlock, + accept_missing_itvar: bool = False) -> Tuple[Optional[str], + Optional[gr.Edge[InterstateEdge]]]: + # The iteration variable must be assigned (initialized) on all edges leading into the beginning block, which + # are not the backedge. Gather all variabes for which that holds - they are all candidates for the iteration + # variable (Phase 1). 
Said iteration variable must then be incremented: + # EITHER: On the backedge, in which case the increment is only executed if the loop does not exit. This + # corresponds to a while(true) loop that checks the condition at the end of the loop body and breaks + # if it does not hold before incrementing. (Scenario 1) + # OR: On the edge(s) leading into the latch, in which case the increment is executed BEFORE the condition is + # checked - which corresponds to a do-while loop. (Scenario 2) + # For either case, the iteration variable may only be incremented on one of these places. Filter the candidates + # down to each variable for which this condition holds (Phase 2). If there is exactly one candidate remaining, + # that is the iteration variable. Otherwise it cannot be determined. + + # Phase 1: Gather iteration variable candidates. + itvar_candidates = None + for e in begin_inedges: + if e is backedge: + continue + if itvar_candidates is None: + itvar_candidates = set(e.data.assignments.keys()) + else: + itvar_candidates &= set(e.data.assignments.keys()) + + # Phase 2: Filter down the candidates according to incrementation edges. + step_edge = None + filtered_candidates = set() + backedge_incremented = set(backedge.data.assignments.keys()) + latch_incremented = None + if backedge.src is not backedge.dst: + # If this is a self loop, there are no edges going into the latch to be considered. The only incoming edges are + # from outside the loop. + for e in latch_inedges: + if e is backedge: + continue + if latch_incremented is None: + latch_incremented = set(e.data.assignments.keys()) + else: + latch_incremented &= set(e.data.assignments.keys()) + if latch_incremented is None: + latch_incremented = set() + for cand in itvar_candidates: + if cand in backedge_incremented: + # Scenario 1. + + # Note, only allow this scenario if the backedge leads directly from the latch to the entry, i.e., there is + # no intermediate block on the backedge path. + if backedge.src is not latch: + continue + + if cand not in latch_incremented: + filtered_candidates.add(cand) + elif cand in latch_incremented: + # Scenario 2. + if cand not in backedge_incremented: + filtered_candidates.add(cand) + if len(filtered_candidates) != 1: + if not accept_missing_itvar: + # Either no consistent iteration variable found, or too many consistent iteration variables found + return None, None + else: + if len(filtered_candidates) == 0: + return '', None + else: + return None, None + else: + itvar = next(iter(filtered_candidates)) + if itvar in backedge_incremented: + step_edge = backedge + elif len(latch_inedges) == 1: + step_edge = latch_inedges[0] + return itvar, step_edge + + def find_for_loop( graph: ControlFlowRegion, guard: sd.SDFGState, @@ -520,6 +673,10 @@ def find_for_loop( match = condition.match(itersym >= a) if match: end = match[a] + if end is None: + match = condition.match(sp.Ne(itersym + stride, a)) + if match: + end = match[a] - stride if end is None: # No match found return None @@ -531,14 +688,14 @@ def find_rotated_for_loop( graph: ControlFlowRegion, latch: sd.SDFGState, entry: sd.SDFGState, - itervar: Optional[str] = None + itervar: Optional[str] = None, + separate_latch: bool = False, ) -> Optional[Tuple[AnyStr, Tuple[symbolic.SymbolicType, symbolic.SymbolicType, symbolic.SymbolicType], Tuple[ List[sd.SDFGState], sd.SDFGState]]]: """ Finds rotated loop range from state machine. - :param latch: State from which the outgoing edges detect whether to exit - the loop or not. 
+ :param latch: State from which the outgoing edges detect whether to reenter the loop or not. :param entry: First state in the loop body. :param itervar: An optional field that overrides the analyzed iteration variable. :return: (iteration variable, (start, end, stride), @@ -547,20 +704,19 @@ def find_rotated_for_loop( """ # Extract state transition edge information entry_inedges = graph.in_edges(entry) - condition_edge = graph.edges_between(latch, entry)[0] - - # All incoming edges to the loop entry must set the same variable + if separate_latch: + condition_edge = graph.in_edges(latch)[0] + backedge = graph.edges_between(latch, entry)[0] + else: + condition_edge = graph.edges_between(latch, entry)[0] + backedge = condition_edge + latch_inedges = graph.in_edges(latch) + + self_loop = latch is entry + step_edge = None if itervar is None: - itervars = None - for iedge in entry_inedges: - if itervars is None: - itervars = set(iedge.data.assignments.keys()) - else: - itervars &= iedge.data.assignments.keys() - if itervars and len(itervars) == 1: - itervar = next(iter(itervars)) - else: - # Ambiguous or no iteration variable + itervar, step_edge = rotated_loop_find_itvar(entry_inedges, latch_inedges, backedge, latch) + if itervar is None: return None condition = condition_edge.data.condition_sympy() @@ -570,18 +726,12 @@ def find_rotated_for_loop( # have one assignment. init_edges = [] init_assignment = None - step_edge = None itersym = symbolic.symbol(itervar) for iedge in entry_inedges: + if iedge is condition_edge: + continue assignment = iedge.data.assignments[itervar] - if itersym in symbolic.pystr_to_symbolic(assignment).free_symbols: - if step_edge is None: - step_edge = iedge - else: - # More than one edge with the iteration variable as a free - # symbol, which is not legal. Invalid for loop. - return None - else: + if itersym not in symbolic.pystr_to_symbolic(assignment).free_symbols: if init_assignment is None: init_assignment = assignment init_edges.append(iedge) @@ -591,10 +741,16 @@ def find_rotated_for_loop( return None else: init_edges.append(iedge) - if step_edge is None or len(init_edges) == 0 or init_assignment is None: + if len(init_edges) == 0 or init_assignment is None: # Less than two assignment variations, can't be a valid for loop. return None + if self_loop: + step_edge = condition_edge + else: + if step_edge is None: + return None + # Get the init expression and the stride. start = symbolic.pystr_to_symbolic(init_assignment) stride = (symbolic.pystr_to_symbolic(step_edge.data.assignments[itervar]) - itersym) @@ -626,6 +782,10 @@ def find_rotated_for_loop( match = condition.match(itersym >= a) if match: end = match[a] + if end is None: + match = condition.match(sp.Ne(itersym + stride, a)) + if match: + end = match[a] - stride if end is None: # No match found return None diff --git a/dace/transformation/interstate/loop_lifting.py b/dace/transformation/interstate/loop_lifting.py new file mode 100644 index 0000000000..072c2519ed --- /dev/null +++ b/dace/transformation/interstate/loop_lifting.py @@ -0,0 +1,99 @@ +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. 
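Editor's note (not part of the patch): the `sp.Ne` fallback added to the range matching above lets a `!=` exit check yield a loop end. A minimal, illustrative sketch of that match, assuming a plain sympy symbol and a unit stride; the names below are made up for the example.

```python
# Sketch only: mirrors the Ne-pattern fallback added to find_for_loop and
# find_rotated_for_loop above; not part of the patch.
import sympy as sp

i = sp.Symbol('i')          # stands in for the detected iteration symbol
a = sp.Wild('a')
stride = 1
condition = sp.Ne(i + stride, 10)              # exit check of the form "i + 1 != 10"
match = condition.match(sp.Ne(i + stride, a))  # structural match, binds a -> 10
end = match[a] - stride                        # inclusive loop end, here 9
assert end == 9
```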
+ +from dace import properties +from dace.sdfg.sdfg import SDFG, InterstateEdge +from dace.sdfg.state import ControlFlowRegion, LoopRegion +from dace.transformation import transformation +from dace.transformation.interstate.loop_detection import DetectLoop + + +@properties.make_properties +@transformation.experimental_cfg_block_compatible +class LoopLifting(DetectLoop, transformation.MultiStateTransformation): + + def can_be_applied(self, graph: transformation.ControlFlowRegion, expr_index: int, sdfg: transformation.SDFG, + permissive: bool = False) -> bool: + # Check loop detection with permissive = True, which allows loops where no iteration variable could be detected. + # We want this to detect while loops. + if not super().can_be_applied(graph, expr_index, sdfg, permissive=True): + return False + + # Check that there's a condition edge, that's the only requirement to lift it into loop. + cond_edge = self.loop_condition_edge() + if not cond_edge or cond_edge.data.condition is None: + return False + return True + + def apply(self, graph: ControlFlowRegion, sdfg: SDFG): + first_state = self.first_loop_block + after = self.exit_state + + loop_info = self.loop_information() + + body = self.loop_body() + meta = self.loop_meta_states() + full_body = set(body) + full_body.update(meta) + cond_edge = self.loop_condition_edge() + incr_edge = self.loop_increment_edge() + inverted = self.inverted + init_edge = self.loop_init_edge() + exit_edge = self.loop_exit_edge() + + label = 'loop_' + first_state.label + if loop_info is None: + itvar = None + init_expr = None + incr_expr = None + else: + incr_expr = f'{loop_info[0]} = {incr_edge.data.assignments[loop_info[0]]}' + init_expr = f'{loop_info[0]} = {init_edge.data.assignments[loop_info[0]]}' + itvar = loop_info[0] + + left_over_assignments = {} + for k in init_edge.data.assignments.keys(): + if k != itvar: + left_over_assignments[k] = init_edge.data.assignments[k] + left_over_incr_assignments = {} + if incr_edge is not None: + for k in incr_edge.data.assignments.keys(): + if k != itvar: + left_over_incr_assignments[k] = incr_edge.data.assignments[k] + + if inverted and incr_edge is cond_edge: + update_before_condition = False + else: + update_before_condition = True + + loop = LoopRegion(label, condition_expr=cond_edge.data.condition, loop_var=itvar, initialize_expr=init_expr, + update_expr=incr_expr, inverted=inverted, sdfg=sdfg, + update_before_condition=update_before_condition) + + graph.add_node(loop) + graph.add_edge(init_edge.src, loop, + InterstateEdge(condition=init_edge.data.condition, assignments=left_over_assignments)) + graph.add_edge(loop, after, InterstateEdge(assignments=exit_edge.data.assignments)) + + loop.add_node(first_state, is_start_block=True) + added = set() + for e in graph.all_edges(*full_body): + if e.src in full_body and e.dst in full_body: + if not e in added: + added.add(e) + if e is incr_edge: + if left_over_incr_assignments != {}: + dst = loop.add_state(label + '_tail') if not inverted else e.dst + loop.add_edge(e.src, dst, InterstateEdge(assignments=left_over_incr_assignments)) + elif e is cond_edge: + if not inverted: + e.data.condition = properties.CodeBlock('1') + loop.add_edge(e.src, e.dst, e.data) + else: + loop.add_edge(e.src, e.dst, e.data) + + # Remove old loop. 
+ for n in full_body: + graph.remove_node(n) + + sdfg.root_sdfg.using_experimental_blocks = True + sdfg.reset_cfg_list() diff --git a/dace/transformation/pass_pipeline.py b/dace/transformation/pass_pipeline.py index 494f9c39ae..9a8154df90 100644 --- a/dace/transformation/pass_pipeline.py +++ b/dace/transformation/pass_pipeline.py @@ -29,7 +29,8 @@ class Modifies(Flag): Memlets = auto() #: Memlets' existence, contents, or properties were modified Nodes = AccessNodes | Scopes | Tasklets | NestedSDFGs #: Modification of any dataflow node (contained in an SDFG state) was made Edges = InterstateEdges | Memlets #: Any edge (memlet or inter-state) was modified - Everything = Descriptors | Symbols | States | InterstateEdges | Nodes | Memlets #: Modification to arbitrary parts of SDFGs (nodes, edges, or properties) + CFG = States | InterstateEdges #: A CFG (any level) was modified (connectivity or number of control flow blocks, but not their contents) + Everything = Descriptors | Symbols | CFG | Nodes | Memlets #: Modification to arbitrary parts of SDFGs (nodes, edges, or properties) @properties.make_properties diff --git a/dace/transformation/passes/analysis/__init__.py b/dace/transformation/passes/analysis/__init__.py new file mode 100644 index 0000000000..5bc1f6e3f3 --- /dev/null +++ b/dace/transformation/passes/analysis/__init__.py @@ -0,0 +1 @@ +from .analysis import * diff --git a/dace/transformation/passes/analysis.py b/dace/transformation/passes/analysis/analysis.py similarity index 81% rename from dace/transformation/passes/analysis.py rename to dace/transformation/passes/analysis/analysis.py index c8bb0b7a9c..095319f807 100644 --- a/dace/transformation/passes/analysis.py +++ b/dace/transformation/passes/analysis/analysis.py @@ -1,7 +1,8 @@ -# Copyright 2019-2022 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. from collections import defaultdict -from dace.transformation import pass_pipeline as ppl +from dace.sdfg.state import ControlFlowBlock, ControlFlowRegion, LoopRegion +from dace.transformation import pass_pipeline as ppl, transformation from dace import SDFG, SDFGState, properties, InterstateEdge, Memlet, data as dt, symbolic from dace.sdfg.graph import Edge from dace.sdfg import nodes as nd @@ -16,6 +17,7 @@ @properties.make_properties +@transformation.experimental_cfg_block_compatible class StateReachability(ppl.Pass): """ Evaluates state reachability (which other states can be executed after each state). @@ -28,25 +30,106 @@ def modifies(self) -> ppl.Modifies: def should_reapply(self, modified: ppl.Modifies) -> bool: # If anything was modified, reapply - return modified & ppl.Modifies.States + return modified & ppl.Modifies.CFG + + def depends_on(self): + return {ControlFlowBlockReachability} - def apply_pass(self, top_sdfg: SDFG, _) -> Dict[int, Dict[SDFGState, Set[SDFGState]]]: + def apply_pass(self, top_sdfg: SDFG, pipeline_res: Dict) -> Dict[int, Dict[SDFGState, Set[SDFGState]]]: """ :return: A dictionary mapping each state to its other reachable states. """ + # Ensure control flow block reachability is run if not run within a pipeline. 
+ if pipeline_res is None or not ControlFlowBlockReachability.__name__ in pipeline_res: + cf_block_reach_dict = ControlFlowBlockReachability().apply_pass(top_sdfg, {}) + else: + cf_block_reach_dict = pipeline_res[ControlFlowBlockReachability.__name__] reachable: Dict[int, Dict[SDFGState, Set[SDFGState]]] = {} for sdfg in top_sdfg.all_sdfgs_recursive(): - result: Dict[SDFGState, Set[SDFGState]] = {} + result: Dict[SDFGState, Set[SDFGState]] = defaultdict(set) + for state in sdfg.states(): + for reached in cf_block_reach_dict[state.parent_graph.cfg_id][state]: + if isinstance(reached, SDFGState): + result[state].add(reached) + reachable[sdfg.cfg_id] = result + return reachable + + +@properties.make_properties +@transformation.experimental_cfg_block_compatible +class ControlFlowBlockReachability(ppl.Pass): + """ + Evaluates control flow block reachability (which control flow block can be executed after each control flow block) + """ + + CATEGORY: str = 'Analysis' + + contain_to_single_level = properties.Property(dtype=bool, default=False) + + def __init__(self, contain_to_single_level=False) -> None: + super().__init__() + + self.contain_to_single_level = contain_to_single_level + def modifies(self) -> ppl.Modifies: + return ppl.Modifies.Nothing + + def should_reapply(self, modified: ppl.Modifies) -> bool: + return modified & ppl.Modifies.CFG + + def _region_closure(self, region: ControlFlowRegion, + block_reach: Dict[int, Dict[ControlFlowBlock, Set[ControlFlowBlock]]]) -> Set[SDFGState]: + closure: Set[SDFGState] = set() + if isinstance(region, LoopRegion): + # Any point inside the loop may reach any other point inside the loop again. + # TODO(later): This is an overapproximation. A branch terminating in a break is excluded from this. + closure.update(region.all_control_flow_blocks()) + + # Add all states that this region can reach in its parent graph to the closure. + for reached_block in block_reach[region.parent_graph.cfg_id][region]: + if isinstance(reached_block, ControlFlowRegion): + closure.update(reached_block.all_control_flow_blocks()) + closure.add(reached_block) + + # Walk up the parent tree. + pivot = region.parent_graph + while pivot and not isinstance(pivot, SDFG): + closure.update(self._region_closure(pivot, block_reach)) + pivot = pivot.parent_graph + return closure + + def apply_pass(self, top_sdfg: SDFG, _) -> Dict[int, Dict[ControlFlowBlock, Set[ControlFlowBlock]]]: + """ + :return: For each control flow region, a dictionary mapping each control flow block to its other reachable + control flow blocks in the same region. + """ + single_level_reachable: Dict[int, Dict[ControlFlowBlock, Set[ControlFlowBlock]]] = defaultdict( + lambda: defaultdict(set) + ) + for cfg in top_sdfg.all_control_flow_regions(recursive=True): # In networkx this is currently implemented naively for directed graphs. 
# The implementation below is faster # tc: nx.DiGraph = nx.transitive_closure(sdfg.nx) + for n, v in reachable_nodes(cfg.nx): + single_level_reachable[cfg.cfg_id][n] = set(v) + if isinstance(cfg, LoopRegion): + single_level_reachable[cfg.cfg_id][n].update(cfg.nodes()) - for n, v in reachable_nodes(sdfg.nx): - result[n] = set(v) - - reachable[sdfg.cfg_id] = result + if self.contain_to_single_level: + return single_level_reachable + reachable: Dict[int, Dict[ControlFlowBlock, Set[ControlFlowBlock]]] = {} + for sdfg in top_sdfg.all_sdfgs_recursive(): + for cfg in sdfg.all_control_flow_regions(): + result: Dict[ControlFlowBlock, Set[ControlFlowBlock]] = defaultdict(set) + for block in cfg.nodes(): + for reached in single_level_reachable[block.parent_graph.cfg_id][block]: + if isinstance(reached, ControlFlowRegion): + result[block].update(reached.all_control_flow_blocks()) + result[block].add(reached) + if block.parent_graph is not sdfg: + result[block].update(self._region_closure(block.parent_graph, single_level_reachable)) + reachable[cfg.cfg_id] = result return reachable @@ -99,6 +182,7 @@ def reachable_nodes(G): @properties.make_properties +@transformation.experimental_cfg_block_compatible class SymbolAccessSets(ppl.Pass): """ Evaluates symbol access sets (which symbols are read/written in each state or interstate edge). @@ -116,25 +200,27 @@ def should_reapply(self, modified: ppl.Modifies) -> bool: def apply_pass(self, top_sdfg: SDFG, _) -> Dict[int, Dict[Union[SDFGState, Edge[InterstateEdge]], Tuple[Set[str], Set[str]]]]: """ - :return: A dictionary mapping each state to a tuple of its (read, written) data descriptors. + :return: A dictionary mapping each state and interstate edge to a tuple of its (read, written) symbols. """ - top_result: Dict[int, Dict[SDFGState, Tuple[Set[str], Set[str]]]] = {} + top_result: Dict[int, Dict[Union[SDFGState, Edge[InterstateEdge]], Tuple[Set[str], Set[str]]]] = {} for sdfg in top_sdfg.all_sdfgs_recursive(): - adesc = set(sdfg.arrays.keys()) - result: Dict[SDFGState, Tuple[Set[str], Set[str]]] = {} - for state in sdfg.nodes(): - readset = state.free_symbols - # No symbols may be written to inside states. - result[state] = (readset, set()) - for oedge in sdfg.out_edges(state): - edge_readset = oedge.data.read_symbols() - adesc - edge_writeset = set(oedge.data.assignments.keys()) - result[oedge] = (edge_readset, edge_writeset) - top_result[sdfg.cfg_id] = result + for cfg in sdfg.all_control_flow_regions(): + adesc = set(sdfg.arrays.keys()) + result: Dict[SDFGState, Tuple[Set[str], Set[str]]] = {} + for block in cfg.nodes(): + if isinstance(block, SDFGState): + # No symbols may be written to inside states. + result[block] = (block.free_symbols, set()) + for oedge in cfg.out_edges(block): + edge_readset = oedge.data.read_symbols() - adesc + edge_writeset = set(oedge.data.assignments.keys()) + result[oedge] = (edge_readset, edge_writeset) + top_result[cfg.cfg_id] = result return top_result @properties.make_properties +@transformation.experimental_cfg_block_compatible class AccessSets(ppl.Pass): """ Evaluates memory access sets (which arrays/data descriptors are read/written in each state). @@ -179,6 +265,7 @@ def apply_pass(self, top_sdfg: SDFG, _) -> Dict[int, Dict[SDFGState, Tuple[Set[s @properties.make_properties +@transformation.experimental_cfg_block_compatible class FindAccessStates(ppl.Pass): """ For each data descriptor, creates a set of states in which access nodes of that data are used. 
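Editor's note (not part of the patch): a minimal sketch of querying the new `ControlFlowBlockReachability` pass on its own, on a deliberately tiny, made-up SDFG; the result layout follows the `apply_pass` return type documented above.

```python
# Sketch only: run the new reachability pass standalone and inspect its result.
import dace
from dace.transformation.passes.analysis import ControlFlowBlockReachability

sdfg = dace.SDFG('reachability_example')  # illustrative two-state SDFG
first = sdfg.add_state('first', is_start_block=True)
second = sdfg.add_state('second')
sdfg.add_edge(first, second, dace.InterstateEdge())

per_cfg = ControlFlowBlockReachability().apply_pass(sdfg, {})
# per_cfg[cfg_id][block] is the set of control flow blocks reachable from `block`
# within the same control flow region (expanded through nested regions).
assert second in per_cfg[sdfg.cfg_id][first]
```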
@@ -201,13 +288,13 @@ def apply_pass(self, top_sdfg: SDFG, _) -> Dict[int, Dict[str, Set[SDFGState]]]: for sdfg in top_sdfg.all_sdfgs_recursive(): result: Dict[str, Set[SDFGState]] = defaultdict(set) - for state in sdfg.nodes(): + for state in sdfg.states(): for anode in state.data_nodes(): result[anode.data].add(state) # Edges that read from arrays add to both ends' access sets anames = sdfg.arrays.keys() - for e in sdfg.edges(): + for e in sdfg.all_interstate_edges(): fsyms = e.data.free_symbols & anames for access in fsyms: result[access].update({e.src, e.dst}) @@ -217,6 +304,7 @@ def apply_pass(self, top_sdfg: SDFG, _) -> Dict[int, Dict[str, Set[SDFGState]]]: @properties.make_properties +@transformation.experimental_cfg_block_compatible class FindAccessNodes(ppl.Pass): """ For each data descriptor, creates a dictionary mapping states to all read and write access nodes with the given @@ -242,7 +330,7 @@ def apply_pass(self, top_sdfg: SDFG, for sdfg in top_sdfg.all_sdfgs_recursive(): result: Dict[str, Dict[SDFGState, Tuple[Set[nd.AccessNode], Set[nd.AccessNode]]]] = defaultdict( lambda: defaultdict(lambda: [set(), set()])) - for state in sdfg.nodes(): + for state in sdfg.states(): for anode in state.data_nodes(): if state.in_degree(anode) > 0: result[anode.data][state][1].add(anode) @@ -508,6 +596,7 @@ def apply_pass(self, top_sdfg: SDFG, pipeline_results: Dict[str, Any]) -> Dict[i @properties.make_properties +@transformation.experimental_cfg_block_compatible class AccessRanges(ppl.Pass): """ For each data descriptor, finds all memlets used to access it (read/write ranges). @@ -544,6 +633,7 @@ def apply_pass(self, top_sdfg: SDFG, _) -> Dict[int, Dict[str, Set[Memlet]]]: @properties.make_properties +@transformation.experimental_cfg_block_compatible class FindReferenceSources(ppl.Pass): """ For each Reference data descriptor, finds all memlets used to set it. If a Tasklet was used @@ -586,6 +676,7 @@ def apply_pass(self, top_sdfg: SDFG, _) -> Dict[int, Dict[str, Set[Union[Memlet, @properties.make_properties +@transformation.experimental_cfg_block_compatible class DeriveSDFGConstraints(ppl.Pass): CATEGORY: str = 'Analysis' diff --git a/dace/transformation/passes/analysis/loop_analysis.py b/dace/transformation/passes/analysis/loop_analysis.py new file mode 100644 index 0000000000..3d15f73c73 --- /dev/null +++ b/dace/transformation/passes/analysis/loop_analysis.py @@ -0,0 +1,116 @@ +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. +""" +Various analyses concerning LoopRegions, and utility functions to get information about LoopRegions for other passes. +""" + +import ast +from typing import Any, Dict, Optional +from dace.frontend.python import astutils + +import sympy + +from dace import symbolic +from dace.sdfg.state import LoopRegion + + +class FindAssignment(ast.NodeVisitor): + + assignments: Dict[str, str] + multiple: bool + + def __init__(self): + self.assignments = {} + self.multiple = False + + def visit_Assign(self, node: ast.Assign) -> Any: + for tgt in node.targets: + if isinstance(tgt, ast.Name): + if tgt.id in self.assignments: + self.multiple = True + self.assignments[tgt.id] = astutils.unparse(node.value) + return self.generic_visit(node) + + +def get_loop_end(loop: LoopRegion) -> Optional[symbolic.SymbolicType]: + """ + Parse a loop region to identify the end value of the iteration variable under normal loop termination (no break). 
+ """ + end: Optional[symbolic.SymbolicType] = None + a = sympy.Wild('a') + condition = symbolic.pystr_to_symbolic(loop.loop_condition.as_string) + itersym = symbolic.pystr_to_symbolic(loop.loop_variable) + match = condition.match(itersym < a) + if match: + end = match[a] - 1 + if end is None: + match = condition.match(itersym <= a) + if match: + end = match[a] + if end is None: + match = condition.match(itersym > a) + if match: + end = match[a] + 1 + if end is None: + match = condition.match(itersym >= a) + if match: + end = match[a] + return end + + +def get_init_assignment(loop: LoopRegion) -> Optional[symbolic.SymbolicType]: + """ + Parse a loop region's init statement to identify the exact init assignment expression. + """ + init_stmt = loop.init_statement + if init_stmt is None: + return None + + init_codes_list = init_stmt.code if isinstance(init_stmt.code, list) else [init_stmt.code] + assignments: Dict[str, str] = {} + for code in init_codes_list: + visitor = FindAssignment() + visitor.visit(code) + if visitor.multiple: + return None + for assign in visitor.assignments: + if assign in assignments: + return None + assignments[assign] = visitor.assignments[assign] + + if loop.loop_variable in assignments: + return symbolic.pystr_to_symbolic(assignments[loop.loop_variable]) + + return None + + +def get_update_assignment(loop: LoopRegion) -> Optional[symbolic.SymbolicType]: + """ + Parse a loop region's update statement to identify the exact update assignment expression. + """ + update_stmt = loop.update_statement + if update_stmt is None: + return None + + update_codes_list = update_stmt.code if isinstance(update_stmt.code, list) else [update_stmt.code] + assignments: Dict[str, str] = {} + for code in update_codes_list: + visitor = FindAssignment() + visitor.visit(code) + if visitor.multiple: + return None + for assign in visitor.assignments: + if assign in assignments: + return None + assignments[assign] = visitor.assignments[assign] + + if loop.loop_variable in assignments: + return symbolic.pystr_to_symbolic(assignments[loop.loop_variable]) + + return None + + +def get_loop_stride(loop: LoopRegion) -> Optional[symbolic.SymbolicType]: + update_assignment = get_update_assignment(loop) + if update_assignment: + return update_assignment - symbolic.pystr_to_symbolic(loop.loop_variable) + return None diff --git a/dace/transformation/passes/simplification/control_flow_raising.py b/dace/transformation/passes/simplification/control_flow_raising.py new file mode 100644 index 0000000000..abe305f12c --- /dev/null +++ b/dace/transformation/passes/simplification/control_flow_raising.py @@ -0,0 +1,96 @@ +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. + +from typing import Optional, Tuple +import networkx as nx +from dace import properties +from dace.sdfg.analysis import cfg as cfg_analysis +from dace.sdfg.sdfg import SDFG, InterstateEdge +from dace.sdfg.state import ConditionalBlock, ControlFlowBlock, ControlFlowRegion +from dace.sdfg.utils import dfs_conditional +from dace.transformation import pass_pipeline as ppl, transformation +from dace.transformation.interstate.loop_lifting import LoopLifting + + +@properties.make_properties +@transformation.experimental_cfg_block_compatible +class ControlFlowRaising(ppl.Pass): + """ + Raises all detectable control flow that can be expressed with native SDFG structures, such as loops and branching. 
+ """ + + CATEGORY: str = 'Simplification' + + def modifies(self) -> ppl.Modifies: + return ppl.Modifies.CFG + + def should_reapply(self, modified: ppl.Modifies) -> bool: + return modified & ppl.Modifies.CFG + + def _lift_conditionals(self, sdfg: SDFG) -> int: + cfgs = list(sdfg.all_control_flow_regions()) + n_cond_regions_pre = len([x for x in sdfg.all_control_flow_blocks() if isinstance(x, ConditionalBlock)]) + + for region in cfgs: + sinks = region.sink_nodes() + dummy_exit = region.add_state('__DACE_DUMMY') + for s in sinks: + region.add_edge(s, dummy_exit, InterstateEdge()) + idom = nx.immediate_dominators(region.nx, region.start_block) + alldoms = cfg_analysis.all_dominators(region, idom) + branch_merges = cfg_analysis.branch_merges(region, idom, alldoms) + + for block in region.nodes(): + graph = block.parent_graph + oedges = graph.out_edges(block) + if len(oedges) > 1 and block in branch_merges: + merge_block = branch_merges[block] + + # Construct the branching block. + conditional = ConditionalBlock('conditional_' + block.label, sdfg, graph) + graph.add_node(conditional) + # Connect it. + graph.add_edge(block, conditional, InterstateEdge()) + + # Populate branches. + for i, oe in enumerate(oedges): + branch_name = 'branch_' + str(i) + '_' + block.label + branch = ControlFlowRegion(branch_name, sdfg) + conditional.add_branch(oe.data.condition, branch) + if oe.dst is merge_block: + # Empty branch. + continue + + branch_nodes = set(dfs_conditional(graph, [oe.dst], lambda _, x: x is not merge_block)) + branch_start = branch.add_state(branch_name + '_start', is_start_block=True) + branch.add_nodes_from(branch_nodes) + branch_end = branch.add_state(branch_name + '_end') + branch.add_edge(branch_start, oe.dst, InterstateEdge(assignments=oe.data.assignments)) + added = set() + for e in graph.all_edges(*branch_nodes): + if not (e in added): + added.add(e) + if e is oe: + continue + elif e.dst is merge_block: + branch.add_edge(e.src, branch_end, e.data) + else: + branch.add_edge(e.src, e.dst, e.data) + graph.remove_nodes_from(branch_nodes) + + # Connect to the end of the branch / what happens after. + if merge_block is not dummy_exit: + graph.add_edge(conditional, merge_block, InterstateEdge()) + region.remove_node(dummy_exit) + + n_cond_regions_post = len([x for x in sdfg.all_control_flow_blocks() if isinstance(x, ConditionalBlock)]) + return n_cond_regions_post - n_cond_regions_pre + + def apply_pass(self, top_sdfg: SDFG, _) -> Optional[Tuple[int, int]]: + lifted_loops = 0 + lifted_branches = 0 + for sdfg in top_sdfg.all_sdfgs_recursive(): + lifted_loops += sdfg.apply_transformations_repeated([LoopLifting], validate_all=False, validate=False) + lifted_branches += self._lift_conditionals(sdfg) + if lifted_branches == 0 and lifted_loops == 0: + return None + return lifted_loops, lifted_branches diff --git a/dace/transformation/subgraph/expansion.py b/dace/transformation/subgraph/expansion.py index db1e9b59ab..aa182e8c80 100644 --- a/dace/transformation/subgraph/expansion.py +++ b/dace/transformation/subgraph/expansion.py @@ -1,26 +1,21 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. """ This module contains classes that implement the expansion transformation. 
""" -from dace import dtypes, registry, symbolic, subsets +from dace import dtypes, symbolic, subsets from dace.sdfg import nodes -from dace.memlet import Memlet from dace.sdfg import replace, SDFG, dynamic_map_inputs from dace.sdfg.graph import SubgraphView from dace.transformation import transformation from dace.properties import make_properties, Property -from dace.sdfg.propagation import propagate_memlets_sdfg from dace.transformation.subgraph import helpers from collections import defaultdict from copy import deepcopy as dcpy -from typing import List, Union import itertools -import dace.libraries.standard as stdlib import warnings -import sys def offset_map(state, map_entry): diff --git a/dace/transformation/subgraph/helpers.py b/dace/transformation/subgraph/helpers.py index b2af49c879..0ea1903522 100644 --- a/dace/transformation/subgraph/helpers.py +++ b/dace/transformation/subgraph/helpers.py @@ -1,20 +1,11 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. """ Subgraph Transformation Helper API """ -from dace import dtypes, registry, symbolic, subsets -from dace.sdfg import nodes, utils -from dace.memlet import Memlet -from dace.sdfg import replace, SDFG, SDFGState -from dace.properties import make_properties, Property -from dace.sdfg.propagation import propagate_memlets_sdfg +from dace import subsets +from dace.sdfg import nodes from dace.sdfg.graph import SubgraphView -from collections import defaultdict import copy -from typing import List, Union, Dict, Tuple, Set - -import dace.libraries.standard as stdlib - -import itertools +from typing import List, Dict, Set # **************** # Helper functions diff --git a/tests/passes/simplification/control_flow_raising_test.py b/tests/passes/simplification/control_flow_raising_test.py new file mode 100644 index 0000000000..53e01df12f --- /dev/null +++ b/tests/passes/simplification/control_flow_raising_test.py @@ -0,0 +1,98 @@ +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. 
+ +import dace +import numpy as np +from dace.sdfg.state import ConditionalBlock +from dace.transformation.pass_pipeline import FixedPointPipeline, Pipeline +from dace.transformation.passes.simplification.control_flow_raising import ControlFlowRaising + + +def test_dataflow_if_check(): + + @dace.program + def dataflow_if_check(A: dace.int32[10], i: dace.int64): + if A[i] < 10: + return 0 + elif A[i] == 10: + return 10 + return 100 + + sdfg = dataflow_if_check.to_sdfg() + + assert not any(isinstance(x, ConditionalBlock) for x in sdfg.nodes()) + + ppl = FixedPointPipeline([ControlFlowRaising()]) + ppl.__experimental_cfg_block_compatible__ = True + ppl.apply_pass(sdfg, {}) + + assert any(isinstance(x, ConditionalBlock) for x in sdfg.nodes()) + + A = np.zeros((10,), np.int32) + A[4] = 10 + A[5] = 100 + assert sdfg(A, 0)[0] == 0 + assert sdfg(A, 4)[0] == 10 + assert sdfg(A, 5)[0] == 100 + assert sdfg(A, 6)[0] == 0 + + +def test_nested_if_chain(): + + @dace.program + def nested_if_chain(i: dace.int64): + if i < 2: + return 0 + else: + if i < 4: + return 1 + else: + if i < 6: + return 2 + else: + if i < 8: + return 3 + else: + return 4 + + sdfg = nested_if_chain.to_sdfg() + + assert not any(isinstance(x, ConditionalBlock) for x in sdfg.nodes()) + + assert nested_if_chain(0)[0] == 0 + assert nested_if_chain(2)[0] == 1 + assert nested_if_chain(4)[0] == 2 + assert nested_if_chain(7)[0] == 3 + assert nested_if_chain(15)[0] == 4 + + +def test_elif_chain(): + + @dace.program + def elif_chain(i: dace.int64): + if i < 2: + return 0 + elif i < 4: + return 1 + elif i < 6: + return 2 + elif i < 8: + return 3 + else: + return 4 + + elif_chain.use_experimental_cfg_blocks = True + sdfg = elif_chain.to_sdfg() + + assert any(isinstance(x, ConditionalBlock) for x in sdfg.nodes()) + + assert elif_chain(0)[0] == 0 + assert elif_chain(2)[0] == 1 + assert elif_chain(4)[0] == 2 + assert elif_chain(7)[0] == 3 + assert elif_chain(15)[0] == 4 + + +if __name__ == '__main__': + test_dataflow_if_check() + test_nested_if_chain() + test_elif_chain() diff --git a/tests/passes/writeset_underapproximation_test.py b/tests/passes/writeset_underapproximation_test.py index 7d5272d80a..96df87b5e7 100644 --- a/tests/passes/writeset_underapproximation_test.py +++ b/tests/passes/writeset_underapproximation_test.py @@ -1,7 +1,8 @@ # Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. 
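Editor's note (not part of the patch): the test updates below all switch to the pass's new per-CFG result layout. A condensed sketch of that access pattern on a deliberately tiny, made-up SDFG:

```python
# Sketch only: UnderapproximateWrites results are now keyed by cfg_id and
# wrapped in an UnderapproximateWritesDict, as exercised by the tests below.
import dace
from dace.transformation.pass_pipeline import Pipeline
from dace.sdfg.analysis.writeset_underapproximation import (UnderapproximateWrites,
                                                            UnderapproximateWritesDict)

sdfg = dace.SDFG('writes_example')  # illustrative one-state SDFG writing A[0]
sdfg.add_array('A', (10,), dace.int32)
state = sdfg.add_state('state', is_start_block=True)
tasklet = state.add_tasklet('t', {}, {'out'}, 'out = 1')
state.add_edge(tasklet, 'out', state.add_access('A'), None, dace.Memlet('A[0]'))

pipeline = Pipeline([UnderapproximateWrites()])
results = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__]
per_sdfg: UnderapproximateWritesDict = results[sdfg.cfg_id]
writes = per_sdfg.approximation            # per-memlet write underapproximation
loop_writes = per_sdfg.loop_approximation  # per-loop write approximation
```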
+from typing import Dict import dace -from dace.sdfg.analysis.writeset_underapproximation import UnderapproximateWrites +from dace.sdfg.analysis.writeset_underapproximation import UnderapproximateWrites, UnderapproximateWritesDict from dace.subsets import Range from dace.transformation.pass_pipeline import Pipeline @@ -9,8 +10,6 @@ M = dace.symbol("M") K = dace.symbol("K") -pipeline = Pipeline([UnderapproximateWrites()]) - def test_2D_map_overwrites_2D_array(): """ @@ -33,9 +32,10 @@ def test_2D_map_overwrites_2D_array(): output_nodes={'B': a1}, external_edges=True) + pipeline = Pipeline([UnderapproximateWrites()]) results = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__] - result = results['approximation'] + result = results[sdfg.cfg_id].approximation edge = map_state.in_edges(a1)[0] result_subset_list = result[edge].subset.subset_list result_subset = result_subset_list[0] @@ -65,9 +65,10 @@ def test_2D_map_added_indices(): output_nodes={"B": a1}, external_edges=True) + pipeline = Pipeline([UnderapproximateWrites()]) results = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__] - result = results["approximation"] + result = results[sdfg.cfg_id].approximation edge = map_state.in_edges(a1)[0] assert (len(result[edge].subset.subset_list) == 0) @@ -94,9 +95,10 @@ def test_2D_map_multiplied_indices(): output_nodes={"B": a1}, external_edges=True) + pipeline = Pipeline([UnderapproximateWrites()]) results = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__] - result = results["approximation"] + result = results[sdfg.cfg_id].approximation edge = map_state.in_edges(a1)[0] assert (len(result[edge].subset.subset_list) == 0) @@ -121,9 +123,10 @@ def test_1D_map_one_index_multiple_dims(): output_nodes={"B": a1}, external_edges=True) + pipeline = Pipeline([UnderapproximateWrites()]) results = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__] - result = results["approximation"] + result = results[sdfg.cfg_id].approximation edge = map_state.in_edges(a1)[0] assert (len(result[edge].subset.subset_list) == 0) @@ -146,9 +149,10 @@ def test_1D_map_one_index_squared(): output_nodes={"B": a1}, external_edges=True) + pipeline = Pipeline([UnderapproximateWrites()]) results = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__] - result = results["approximation"] + result = results[sdfg.cfg_id].approximation edge = map_state.in_edges(a1)[0] assert (len(result[edge].subset.subset_list) == 0) @@ -185,9 +189,10 @@ def test_map_tree_full_write(): inner_edge_1 = map_state.add_edge(inner_map_exit_1, "OUT_B", map_exit, "IN_B", dace.Memlet(data="B")) outer_edge = map_state.add_edge(map_exit, "OUT_B", a1, None, dace.Memlet(data="B")) + pipeline = Pipeline([UnderapproximateWrites()]) results = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__] - result = results["approximation"] + result = results[sdfg.cfg_id].approximation expected_subset_outer_edge = Range.from_string("0:M, 0:N") expected_subset_inner_edge = Range.from_string("0:M, _i") result_inner_edge_0 = result[inner_edge_0].subset.subset_list[0] @@ -230,9 +235,10 @@ def test_map_tree_no_write_multiple_indices(): inner_edge_1 = map_state.add_edge(inner_map_exit_1, "OUT_B", map_exit, "IN_B", dace.Memlet(data="B")) outer_edge = map_state.add_edge(map_exit, "OUT_B", a1, None, dace.Memlet(data="B")) + pipeline = Pipeline([UnderapproximateWrites()]) results = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__] - result = results["approximation"] + result = results[sdfg.cfg_id].approximation 
result_inner_edge_0 = result[inner_edge_0].subset.subset_list result_inner_edge_1 = result[inner_edge_1].subset.subset_list result_outer_edge = result[outer_edge].subset.subset_list @@ -273,9 +279,10 @@ def test_map_tree_multiple_indices_per_dimension(): inner_edge_1 = map_state.add_edge(inner_map_exit_1, "OUT_B", map_exit, "IN_B", dace.Memlet(data="B")) outer_edge = map_state.add_edge(map_exit, "OUT_B", a1, None, dace.Memlet(data="B")) + pipeline = Pipeline([UnderapproximateWrites()]) results = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__] - result = results["approximation"] + result = results[sdfg.cfg_id].approximation expected_subset_outer_edge = Range.from_string("0:M, 0:N") expected_subset_inner_edge_1 = Range.from_string("0:M, _i") result_inner_edge_1 = result[inner_edge_1].subset.subset_list[0] @@ -300,11 +307,12 @@ def loop(A: dace.float64[N, M]): sdfg = loop.to_sdfg(simplify=True) + pipeline = Pipeline([UnderapproximateWrites()]) results = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__] nsdfg = sdfg.cfg_list[1].parent_nsdfg_node map_state = sdfg.states()[0] - result = results["approximation"] + result = results[sdfg.cfg_id].approximation edge = map_state.out_edges(nsdfg)[0] assert (len(result[edge].subset.subset_list) == 0) @@ -323,11 +331,12 @@ def loop(A: dace.float64[N, M]): sdfg = loop.to_sdfg(simplify=True) + pipeline = Pipeline([UnderapproximateWrites()]) results = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__] map_state = sdfg.states()[0] edge = map_state.in_edges(map_state.data_nodes()[0])[0] - result = results["approximation"] + result = results[sdfg.cfg_id].approximation expected_subset = Range.from_string("0:N, 0:M") assert (str(result[edge].subset.subset_list[0]) == str(expected_subset)) @@ -357,9 +366,10 @@ def test_map_in_loop(): output_nodes={"B": a1}, external_edges=True) + pipeline = Pipeline([UnderapproximateWrites()]) results = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__] - result = results["loop_approximation"] + result = results[sdfg.cfg_id].loop_approximation expected_subset = Range.from_string("0:N, 0:M") assert (str(result[guard]["B"].subset.subset_list[0]) == str(expected_subset)) @@ -390,9 +400,10 @@ def test_map_in_loop_multiplied_indices_first_dimension(): output_nodes={"B": a1}, external_edges=True) + pipeline = Pipeline([UnderapproximateWrites()]) results = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__] - result = results["loop_approximation"] + result = results[sdfg.cfg_id].loop_approximation assert (guard not in result.keys() or len(result[guard]) == 0) @@ -421,9 +432,10 @@ def test_map_in_loop_multiplied_indices_second_dimension(): output_nodes={"B": a1}, external_edges=True) + pipeline = Pipeline([UnderapproximateWrites()]) results = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__] - result = results["loop_approximation"] + result = results[sdfg.cfg_id].loop_approximation assert (guard not in result.keys() or len(result[guard]) == 0) @@ -444,8 +456,9 @@ def nested_loop(A: dace.float64[M, N]): sdfg = nested_loop.to_sdfg(simplify=True) + pipeline = Pipeline([UnderapproximateWrites()]) result = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__] - write_approx = result["approximation"] + write_approx = result[sdfg.cfg_id].approximation # find write set accessnode = None write_set = None @@ -478,9 +491,10 @@ def nested_loop(A: dace.float64[M, N]): sdfg = nested_loop.to_sdfg(simplify=True) + pipeline = Pipeline([UnderapproximateWrites()]) result = 
pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__] - write_approx = result["approximation"] + write_approx = result[sdfg.cfg_id].approximation # find write set accessnode = None write_set = None @@ -510,15 +524,16 @@ def nested_loop(A: dace.float64[M, N]): sdfg = nested_loop.to_sdfg(simplify=True) + pipeline = Pipeline([UnderapproximateWrites()]) result = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__] - write_approx = result["approximation"] + write_approx = result[sdfg.cfg_id].approximation # find write set accessnode = None write_set = None - for node, _ in sdfg.all_nodes_recursive(): + for node, parent in sdfg.all_nodes_recursive(): if isinstance(node, dace.nodes.AccessNode): - if node.data == "A": + if node.data == "A" and parent.out_degree(node) == 0: accessnode = node for edge, memlet in write_approx.items(): if edge.dst is accessnode: @@ -531,6 +546,7 @@ def test_nested_sdfg_in_map_branches(): Nested SDFG that overwrites second dimension of array conditionally. --> should approximate write-set of map as empty """ + # No, should be approximated precisely - at least certainly with CF regions..? @dace.program def nested_loop(A: dace.float64[M, N]): @@ -542,15 +558,16 @@ def nested_loop(A: dace.float64[M, N]): sdfg = nested_loop.to_sdfg(simplify=True) - result = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__] + pipeline = Pipeline([UnderapproximateWrites()]) + result: Dict[int, UnderapproximateWritesDict] = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__] - write_approx = result["approximation"] + write_approx = result[sdfg.cfg_id].approximation # find write set accessnode = None write_set = None - for node, _ in sdfg.all_nodes_recursive(): + for node, parent in sdfg.all_nodes_recursive(): if isinstance(node, dace.nodes.AccessNode): - if node.data == "A": + if node.data == "A" and parent.out_degree(node) == 0: accessnode = node for edge, memlet in write_approx.items(): if edge.dst is accessnode: @@ -574,9 +591,10 @@ def test_simple_loop_overwrite(): loop_tasklet = loop_body.add_tasklet("overwrite", {}, {"a"}, "a = 0") loop_body.add_edge(loop_tasklet, "a", a0, None, dace.Memlet("A[i]")) - result = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__]["loop_approximation"] + pipeline = Pipeline([UnderapproximateWrites()]) + result: UnderapproximateWritesDict = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__][sdfg.cfg_id] - assert (str(result[guard]["A"].subset) == str(Range.from_array(sdfg.arrays["A"]))) + assert (str(result.loop_approximation[guard]["A"].subset) == str(Range.from_array(sdfg.arrays["A"]))) def test_loop_2D_overwrite(): @@ -598,7 +616,8 @@ def test_loop_2D_overwrite(): loop_tasklet = loop_body.add_tasklet("overwrite", {}, {"a"}, "a = 0") loop_body.add_edge(loop_tasklet, "a", a0, None, dace.Memlet("A[j,i]")) - result = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__]["loop_approximation"] + pipeline = Pipeline([UnderapproximateWrites()]) + result = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__][sdfg.cfg_id].loop_approximation assert (str(result[guard1]["A"].subset) == str(Range.from_array(sdfg.arrays["A"]))) assert (str(result[guard2]["A"].subset) == "j, 0:N") @@ -629,7 +648,8 @@ def test_loop_2D_propagation_gap_symbolic(): loop_tasklet = loop_body.add_tasklet("overwrite", {}, {"a"}, "a = 0") loop_body.add_edge(loop_tasklet, "a", a0, None, dace.Memlet("A[j,i]")) - result = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__]["loop_approximation"] + 
pipeline = Pipeline([UnderapproximateWrites()]) + result = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__][sdfg.cfg_id].loop_approximation assert ("A" not in result[guard1].keys()) assert ("A" not in result[guard2].keys()) @@ -657,7 +677,8 @@ def test_2_loops_overwrite(): loop_tasklet_2 = loop_body_2.add_tasklet("overwrite", {}, {"a"}, "a = 0") loop_body_2.add_edge(loop_tasklet_2, "a", a1, None, dace.Memlet("A[i]")) - result = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__]["loop_approximation"] + pipeline = Pipeline([UnderapproximateWrites()]) + result = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__][sdfg.cfg_id].loop_approximation assert (str(result[guard_1]["A"].subset) == str(Range.from_array(sdfg.arrays["A"]))) assert (str(result[guard_2]["A"].subset) == str(Range.from_array(sdfg.arrays["A"]))) @@ -687,7 +708,8 @@ def test_loop_2D_overwrite_propagation_gap_non_empty(): loop_tasklet = loop_body.add_tasklet("overwrite", {}, {"a"}, "a = 0") loop_body.add_edge(loop_tasklet, "a", a0, None, dace.Memlet("A[j,i]")) - result = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__]["loop_approximation"] + pipeline = Pipeline([UnderapproximateWrites()]) + result = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__][sdfg.cfg_id].loop_approximation assert (str(result[guard1]["A"].subset) == str(Range.from_array(sdfg.arrays["A"]))) assert (str(result[guard2]["A"].subset) == "j, 0:N") @@ -717,7 +739,8 @@ def test_loop_nest_multiplied_indices(): loop_tasklet = loop_body.add_tasklet("overwrite", {}, {"a"}, "a = 0") loop_body.add_edge(loop_tasklet, "a", a0, None, dace.Memlet("A[i,i*j]")) - result = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__]["loop_approximation"] + pipeline = Pipeline([UnderapproximateWrites()]) + result = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__][sdfg.cfg_id].loop_approximation assert (guard1 not in result.keys() or "A" not in result[guard1].keys()) assert (guard2 not in result.keys() or "A" not in result[guard2].keys()) @@ -748,7 +771,8 @@ def test_loop_nest_empty_nested_loop(): loop_tasklet = loop_body.add_tasklet("overwrite", {}, {"a"}, "a = 0") loop_body.add_edge(loop_tasklet, "a", a0, None, dace.Memlet("A[j,i]")) - result = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__]["loop_approximation"] + pipeline = Pipeline([UnderapproximateWrites()]) + result = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__][sdfg.cfg_id].loop_approximation assert (guard1 not in result.keys() or "A" not in result[guard1].keys()) assert (guard2 not in result.keys() or "A" not in result[guard2].keys()) @@ -779,7 +803,8 @@ def test_loop_nest_inner_loop_conditional(): loop_tasklet = loop_body.add_tasklet("overwrite", {}, {"a"}, "a = 0") loop_body.add_edge(loop_tasklet, "a", a0, None, dace.Memlet("A[k]")) - result = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__]["loop_approximation"] + pipeline = Pipeline([UnderapproximateWrites()]) + result = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__][sdfg.cfg_id].loop_approximation assert (guard1 not in result.keys() or "A" not in result[guard1].keys()) assert (guard2 in result.keys() and "A" in result[guard2].keys() and str(result[guard2]['A'].subset) == "0:N") @@ -799,9 +824,10 @@ def nested_loop(A: dace.float64[M, N]): sdfg = nested_loop.to_sdfg(simplify=True) + pipeline = Pipeline([UnderapproximateWrites()]) result = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__] - write_approx = 
result["approximation"] + write_approx = result[sdfg.cfg_id].approximation write_set = None accessnode = None for node, _ in sdfg.all_nodes_recursive(): @@ -828,10 +854,11 @@ def nested_loop(A: dace.float64[M, N]): sdfg = nested_loop.to_sdfg(simplify=True) + pipeline = Pipeline([UnderapproximateWrites()]) result = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__] # find write set - write_approx = result["approximation"] + write_approx = result[sdfg.cfg_id].approximation accessnode = None write_set = None for node, _ in sdfg.all_nodes_recursive(): @@ -864,9 +891,10 @@ def test_loop_break(): loop_tasklet = loop_body_1.add_tasklet("overwrite", {}, {"a"}, "a = 0") loop_body_1.add_edge(loop_tasklet, "a", a0, None, dace.Memlet("A[i]")) + pipeline = Pipeline([UnderapproximateWrites()]) results = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__] - result = results["loop_approximation"] + result = results[sdfg.cfg_id].loop_approximation assert (guard3 not in result.keys() or "A" not in result[guard3].keys()) diff --git a/tests/sdfg/conditional_region_test.py b/tests/sdfg/conditional_region_test.py index 4e4eda3f44..0be40f43d3 100644 --- a/tests/sdfg/conditional_region_test.py +++ b/tests/sdfg/conditional_region_test.py @@ -10,20 +10,20 @@ def test_cond_region_if(): sdfg = dace.SDFG('regular_if') - sdfg.add_array("A", (1,), dace.float32) - sdfg.add_symbol("i", dace.int32) + sdfg.add_array('A', (1,), dace.float32) + sdfg.add_symbol('i', dace.int32) state0 = sdfg.add_state('state0', is_start_block=True) - if1 = ConditionalBlock("if1") + if1 = ConditionalBlock('if1') sdfg.add_node(if1) sdfg.add_edge(state0, if1, InterstateEdge()) - if_body = ControlFlowRegion("if_body", sdfg=sdfg) - if1.branches.append((CodeBlock("i == 1"), if_body)) + if_body = ControlFlowRegion('if_body', sdfg=sdfg) + if1.add_branch(CodeBlock('i == 1'), if_body) - state1 = if_body.add_state("state1", is_start_block=True) + state1 = if_body.add_state('state1', is_start_block=True) acc_a = state1.add_access('A') - t1 = state1.add_tasklet("t1", None, {"a"}, "a = 100") + t1 = state1.add_tasklet('t1', None, {'a'}, 'a = 100') state1.add_edge(t1, 'a', acc_a, None, dace.Memlet('A[0]')) assert sdfg.is_valid() @@ -36,14 +36,14 @@ def test_cond_region_if(): assert A[0] == 1 def test_serialization(): - sdfg = SDFG("test_serialization") - cond_region = ConditionalBlock("cond_region") + sdfg = SDFG('test_serialization') + cond_region = ConditionalBlock('cond_region') sdfg.add_node(cond_region, is_start_block=True) - sdfg.add_symbol("i", dace.int32) + sdfg.add_symbol('i', dace.int32) for j in range(10): - cfg = ControlFlowRegion(f"cfg_{j}", sdfg) - cond_region.branches.append((CodeBlock(f"i == {j}"), cfg)) + cfg = ControlFlowRegion(f'cfg_{j}', sdfg) + cond_region.add_branch(CodeBlock(f'i == {j}'), cfg) assert sdfg.is_valid() @@ -52,32 +52,32 @@ def test_serialization(): new_cond_region: ConditionalBlock = new_sdfg.nodes()[0] for j in range(10): condition, cfg = new_cond_region.branches[j] - assert condition == CodeBlock(f"i == {j}") - assert cfg.label == f"cfg_{j}" + assert condition == CodeBlock(f'i == {j}') + assert cfg.label == f'cfg_{j}' def test_if_else(): sdfg = dace.SDFG('regular_if_else') - sdfg.add_array("A", (1,), dace.float32) - sdfg.add_symbol("i", dace.int32) + sdfg.add_array('A', (1,), dace.float32) + sdfg.add_symbol('i', dace.int32) state0 = sdfg.add_state('state0', is_start_block=True) - if1 = ConditionalBlock("if1") + if1 = ConditionalBlock('if1') sdfg.add_node(if1) sdfg.add_edge(state0, if1, 
InterstateEdge()) - if_body = ControlFlowRegion("if_body", sdfg=sdfg) - state1 = if_body.add_state("state1", is_start_block=True) + if_body = ControlFlowRegion('if_body', sdfg=sdfg) + state1 = if_body.add_state('state1', is_start_block=True) acc_a = state1.add_access('A') - t1 = state1.add_tasklet("t1", None, {"a"}, "a = 100") + t1 = state1.add_tasklet('t1', None, {'a'}, 'a = 100') state1.add_edge(t1, 'a', acc_a, None, dace.Memlet('A[0]')) - if1.branches.append((CodeBlock("i == 1"), if_body)) + if1.add_branch(CodeBlock('i == 1'), if_body) - else_body = ControlFlowRegion("else_body", sdfg=sdfg) - state2 = else_body.add_state("state1", is_start_block=True) + else_body = ControlFlowRegion('else_body', sdfg=sdfg) + state2 = else_body.add_state('state1', is_start_block=True) acc_a2 = state2.add_access('A') - t2 = state2.add_tasklet("t2", None, {"a"}, "a = 200") + t2 = state2.add_tasklet('t2', None, {'a'}, 'a = 200') state2.add_edge(t2, 'a', acc_a2, None, dace.Memlet('A[0]')) - if1.branches.append((CodeBlock("i == 0"), else_body)) + if1.add_branch(CodeBlock('i == 0'), else_body) assert sdfg.is_valid() A = np.ones((1,), dtype=np.float32) diff --git a/tests/sdfg/loop_region_test.py b/tests/sdfg/loop_region_test.py index 6aca54f40c..dedafb67ba 100644 --- a/tests/sdfg/loop_region_test.py +++ b/tests/sdfg/loop_region_test.py @@ -86,6 +86,27 @@ def _make_do_for_loop() -> SDFG: return sdfg +def _make_do_for_inverted_cond_loop() -> SDFG: + sdfg = dace.SDFG('do_for_inverted_cond') + sdfg.using_experimental_blocks = True + sdfg.add_symbol('i', dace.int32) + sdfg.add_array('A', [10], dace.float32) + state0 = sdfg.add_state('state0', is_start_block=True) + loop1 = LoopRegion(label='loop1', condition_expr='i < 8', loop_var='i', initialize_expr='i = 0', + update_expr='i = i + 1', inverted=True, update_before_condition=False) + sdfg.add_node(loop1) + state1 = loop1.add_state('state1', is_start_block=True) + acc_a = state1.add_access('A') + t1 = state1.add_tasklet('t1', None, {'a'}, 'a = i') + state1.add_edge(t1, 'a', acc_a, None, dace.Memlet('A[i]')) + state2 = loop1.add_state('state2') + loop1.add_edge(state1, state2, dace.InterstateEdge()) + state3 = sdfg.add_state('state3') + sdfg.add_edge(state0, loop1, dace.InterstateEdge()) + sdfg.add_edge(loop1, state3, dace.InterstateEdge()) + return sdfg + + def _make_triple_nested_for_loop() -> SDFG: sdfg = dace.SDFG('gemm') sdfg.using_experimental_blocks = True @@ -177,6 +198,19 @@ def test_loop_do_for(): assert np.allclose(a_validation, a_test) +def test_loop_do_for_inverted_condition(): + sdfg = _make_do_for_inverted_cond_loop() + + assert sdfg.is_valid() + + a_validation = np.zeros([10], dtype=np.float32) + a_test = np.zeros([10], dtype=np.float32) + sdfg(A=a_test) + for i in range(9): + a_validation[i] = i + assert np.allclose(a_validation, a_test) + + def test_loop_triple_nested_for(): sdfg = _make_triple_nested_for_loop() @@ -249,6 +283,21 @@ def test_loop_to_stree_do_for(): f'{tn.INDENTATION}while (i < 10)') +def test_loop_to_stree_do_for_inverted_cond(): + sdfg = _make_do_for_inverted_cond_loop() + + assert sdfg.is_valid() + + stree = s2t.as_schedule_tree(sdfg) + + assert stree.as_string() == (f'{tn.INDENTATION}i = 0\n' + + f'{tn.INDENTATION}while True:\n' + + f'{2 * tn.INDENTATION}A[i] = tasklet()\n' + + f'{2 * tn.INDENTATION}if (not (i < 8)):\n' + + f'{3 * tn.INDENTATION}break\n' + + f'{2 * tn.INDENTATION}i = (i + 1)\n') + + def test_loop_to_stree_triple_nested_for(): sdfg = _make_triple_nested_for_loop() @@ -267,9 +316,11 @@ def 
test_loop_to_stree_triple_nested_for(): test_loop_regular_while() test_loop_do_while() test_loop_do_for() + test_loop_do_for_inverted_condition() test_loop_triple_nested_for() test_loop_to_stree_regular_for() test_loop_to_stree_regular_while() test_loop_to_stree_do_while() test_loop_to_stree_do_for() + test_loop_to_stree_do_for_inverted_cond() test_loop_to_stree_triple_nested_for() diff --git a/tests/transformations/interstate/loop_lifting_test.py b/tests/transformations/interstate/loop_lifting_test.py new file mode 100644 index 0000000000..20f244621c --- /dev/null +++ b/tests/transformations/interstate/loop_lifting_test.py @@ -0,0 +1,217 @@ +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. +""" Tests loop raising transformations. """ + +import numpy as np +import pytest +import dace +from dace.memlet import Memlet +from dace.sdfg.sdfg import SDFG, InterstateEdge +from dace.sdfg.state import LoopRegion +from dace.transformation.interstate.loop_lifting import LoopLifting + + +def test_lift_regular_for_loop(): + sdfg = SDFG('regular_for') + N = dace.symbol('N') + sdfg.add_symbol('i', dace.int32) + sdfg.add_symbol('j', dace.int32) + sdfg.add_symbol('k', dace.int32) + sdfg.add_array('A', (N,), dace.int32) + start_state = sdfg.add_state('start', is_start_block=True) + init_state = sdfg.add_state('init') + guard_state = sdfg.add_state('guard') + main_state = sdfg.add_state('loop_state') + loop_exit = sdfg.add_state('exit') + final_state = sdfg.add_state('final') + sdfg.add_edge(start_state, init_state, InterstateEdge(assignments={'j': 0})) + sdfg.add_edge(init_state, guard_state, InterstateEdge(assignments={'i': 0, 'k': 0})) + sdfg.add_edge(guard_state, main_state, InterstateEdge(condition='i < N')) + sdfg.add_edge(main_state, guard_state, InterstateEdge(assignments={'i': 'i + 2', 'j': 'j + 1'})) + sdfg.add_edge(guard_state, loop_exit, InterstateEdge(condition='i >= N', assignments={'k': 2})) + sdfg.add_edge(loop_exit, final_state, InterstateEdge()) + a_access = main_state.add_access('A') + w_tasklet = main_state.add_tasklet('t1', {}, {'out'}, 'out = 1') + main_state.add_edge(w_tasklet, 'out', a_access, None, Memlet('A[i]')) + a_access_2 = loop_exit.add_access('A') + w_tasklet_2 = loop_exit.add_tasklet('t1', {}, {'out'}, 'out = k') + loop_exit.add_edge(w_tasklet_2, 'out', a_access_2, None, Memlet('A[1]')) + a_access_3 = final_state.add_access('A') + w_tasklet_3 = final_state.add_tasklet('t1', {}, {'out'}, 'out = j') + final_state.add_edge(w_tasklet_3, 'out', a_access_3, None, Memlet('A[3]')) + + N = 30 + A = np.zeros((N,)).astype(np.int32) + A_valid = np.zeros((N,)).astype(np.int32) + sdfg(A=A_valid, N=N) + sdfg.apply_transformations_repeated([LoopLifting]) + + assert sdfg.using_experimental_blocks == True + assert any(isinstance(x, LoopRegion) for x in sdfg.nodes()) + + sdfg(A=A, N=N) + + assert np.allclose(A_valid, A) + + +@pytest.mark.parametrize('increment_before_condition', (True, False)) +def test_lift_loop_llvm_canonical(increment_before_condition): + addendum = '_incr_before_cond' if increment_before_condition else '' + sdfg = dace.SDFG('llvm_canonical' + addendum) + N = dace.symbol('N') + sdfg.add_symbol('i', dace.int32) + sdfg.add_symbol('j', dace.int32) + sdfg.add_symbol('k', dace.int32) + sdfg.add_array('A', (N,), dace.int32) + + entry = sdfg.add_state('entry', is_start_block=True) + guard = sdfg.add_state('guard') + preheader = sdfg.add_state('preheader') + body = sdfg.add_state('body') + latch = sdfg.add_state('latch') + loopexit = 
sdfg.add_state('loopexit') + exitstate = sdfg.add_state('exitstate') + + sdfg.add_edge(entry, guard, InterstateEdge(assignments={'j': 0})) + sdfg.add_edge(guard, exitstate, InterstateEdge(condition='N <= 0')) + sdfg.add_edge(guard, preheader, InterstateEdge(condition='N > 0')) + sdfg.add_edge(preheader, body, InterstateEdge(assignments={'i': 0, 'k': 0})) + if increment_before_condition: + sdfg.add_edge(body, latch, InterstateEdge(assignments={'i': 'i + 2', 'j': 'j + 1'})) + sdfg.add_edge(latch, body, InterstateEdge(condition='i < N')) + sdfg.add_edge(latch, loopexit, InterstateEdge(condition='i >= N', assignments={'k': 2})) + else: + sdfg.add_edge(body, latch, InterstateEdge(assignments={'j': 'j + 1'})) + sdfg.add_edge(latch, body, InterstateEdge(condition='i < N - 2', assignments={'i': 'i + 2'})) + sdfg.add_edge(latch, loopexit, InterstateEdge(condition='i >= N - 2', assignments={'k': 2})) + sdfg.add_edge(loopexit, exitstate, InterstateEdge()) + + a_access = body.add_access('A') + w_tasklet = body.add_tasklet('t1', {}, {'out'}, 'out = 1') + body.add_edge(w_tasklet, 'out', a_access, None, Memlet('A[i]')) + a_access_2 = loopexit.add_access('A') + w_tasklet_2 = loopexit.add_tasklet('t1', {}, {'out'}, 'out = k') + loopexit.add_edge(w_tasklet_2, 'out', a_access_2, None, Memlet('A[1]')) + a_access_3 = exitstate.add_access('A') + w_tasklet_3 = exitstate.add_tasklet('t1', {}, {'out'}, 'out = j') + exitstate.add_edge(w_tasklet_3, 'out', a_access_3, None, Memlet('A[3]')) + + N = 30 + A = np.zeros((N,)).astype(np.int32) + A_valid = np.zeros((N,)).astype(np.int32) + sdfg(A=A_valid, N=N) + sdfg.apply_transformations_repeated([LoopLifting]) + + assert sdfg.using_experimental_blocks == True + assert any(isinstance(x, LoopRegion) for x in sdfg.nodes()) + + sdfg(A=A, N=N) + + assert np.allclose(A_valid, A) + + +def test_lift_loop_llvm_canonical_while(): + sdfg = dace.SDFG('llvm_canonical_while') + N = dace.symbol('N') + sdfg.add_symbol('j', dace.int32) + sdfg.add_symbol('k', dace.int32) + sdfg.add_array('A', (N,), dace.int32) + sdfg.add_scalar('i', dace.int32, transient=True) + + entry = sdfg.add_state('entry', is_start_block=True) + guard = sdfg.add_state('guard') + preheader = sdfg.add_state('preheader') + body = sdfg.add_state('body') + latch = sdfg.add_state('latch') + loopexit = sdfg.add_state('loopexit') + exitstate = sdfg.add_state('exitstate') + + sdfg.add_edge(entry, guard, InterstateEdge(assignments={'j': 0})) + sdfg.add_edge(guard, exitstate, InterstateEdge(condition='N <= 0')) + sdfg.add_edge(guard, preheader, InterstateEdge(condition='N > 0')) + sdfg.add_edge(preheader, body, InterstateEdge(assignments={'k': 0})) + sdfg.add_edge(body, latch, InterstateEdge(assignments={'j': 'j + 1'})) + sdfg.add_edge(latch, body, InterstateEdge(condition='i < N - 2')) + sdfg.add_edge(latch, loopexit, InterstateEdge(condition='i >= N - 2', assignments={'k': 2})) + sdfg.add_edge(loopexit, exitstate, InterstateEdge()) + + i_init_write = entry.add_access('i') + iw_init_tasklet = entry.add_tasklet('ti', {}, {'out'}, 'out = 0') + entry.add_edge(iw_init_tasklet, 'out', i_init_write, None, Memlet('i[0]')) + a_access = body.add_access('A') + w_tasklet = body.add_tasklet('t1', {}, {'out'}, 'out = 1') + body.add_edge(w_tasklet, 'out', a_access, None, Memlet('A[i]')) + i_read = body.add_access('i') + i_write = body.add_access('i') + iw_tasklet = body.add_tasklet('t2', {'in1'}, {'out'}, 'out = in1 + 2') + body.add_edge(i_read, None, iw_tasklet, 'in1', Memlet('i[0]')) + body.add_edge(iw_tasklet, 'out', i_write, None, 
Memlet('i[0]')) + a_access_2 = loopexit.add_access('A') + w_tasklet_2 = loopexit.add_tasklet('t1', {}, {'out'}, 'out = k') + loopexit.add_edge(w_tasklet_2, 'out', a_access_2, None, Memlet('A[1]')) + a_access_3 = exitstate.add_access('A') + w_tasklet_3 = exitstate.add_tasklet('t1', {}, {'out'}, 'out = j') + exitstate.add_edge(w_tasklet_3, 'out', a_access_3, None, Memlet('A[3]')) + + N = 30 + A = np.zeros((N,)).astype(np.int32) + A_valid = np.zeros((N,)).astype(np.int32) + sdfg(A=A_valid, N=N) + sdfg.apply_transformations_repeated([LoopLifting]) + + assert sdfg.using_experimental_blocks == True + assert any(isinstance(x, LoopRegion) for x in sdfg.nodes()) + + sdfg(A=A, N=N) + + assert np.allclose(A_valid, A) + + +def test_do_while(): + sdfg = SDFG('regular_for') + N = dace.symbol('N') + sdfg.add_symbol('i', dace.int32) + sdfg.add_symbol('j', dace.int32) + sdfg.add_symbol('k', dace.int32) + sdfg.add_array('A', (N,), dace.int32) + start_state = sdfg.add_state('start', is_start_block=True) + init_state = sdfg.add_state('init') + guard_state = sdfg.add_state('guard') + main_state = sdfg.add_state('loop_state') + loop_exit = sdfg.add_state('exit') + final_state = sdfg.add_state('final') + sdfg.add_edge(start_state, init_state, InterstateEdge(assignments={'j': 0})) + sdfg.add_edge(init_state, main_state, InterstateEdge(assignments={'i': 0, 'k': 0})) + sdfg.add_edge(main_state, guard_state, InterstateEdge(assignments={'i': 'i + 2', 'j': 'j + 1'})) + sdfg.add_edge(guard_state, main_state, InterstateEdge(condition='i < N')) + sdfg.add_edge(guard_state, loop_exit, InterstateEdge(condition='i >= N', assignments={'k': 2})) + sdfg.add_edge(loop_exit, final_state, InterstateEdge()) + a_access = main_state.add_access('A') + w_tasklet = main_state.add_tasklet('t1', {}, {'out'}, 'out = 1') + main_state.add_edge(w_tasklet, 'out', a_access, None, Memlet('A[i]')) + a_access_2 = loop_exit.add_access('A') + w_tasklet_2 = loop_exit.add_tasklet('t1', {}, {'out'}, 'out = k') + loop_exit.add_edge(w_tasklet_2, 'out', a_access_2, None, Memlet('A[1]')) + a_access_3 = final_state.add_access('A') + w_tasklet_3 = final_state.add_tasklet('t1', {}, {'out'}, 'out = j') + final_state.add_edge(w_tasklet_3, 'out', a_access_3, None, Memlet('A[3]')) + + N = 30 + A = np.zeros((N,)).astype(np.int32) + A_valid = np.zeros((N,)).astype(np.int32) + sdfg(A=A_valid, N=N) + sdfg.apply_transformations_repeated([LoopLifting]) + + assert sdfg.using_experimental_blocks == True + assert any(isinstance(x, LoopRegion) for x in sdfg.nodes()) + + sdfg(A=A, N=N) + + assert np.allclose(A_valid, A) + + +if __name__ == '__main__': + test_lift_regular_for_loop() + test_lift_loop_llvm_canonical(True) + test_lift_loop_llvm_canonical(False) + test_lift_loop_llvm_canonical_while() + test_do_while() diff --git a/tests/transformations/loop_detection_test.py b/tests/transformations/loop_detection_test.py index 5469f45762..323a27787a 100644 --- a/tests/transformations/loop_detection_test.py +++ b/tests/transformations/loop_detection_test.py @@ -27,7 +27,8 @@ def tester(a: dace.float64[20]): assert rng == (1, 19, 1) -def test_loop_rotated(): +@pytest.mark.parametrize('increment_before_condition', (True, False)) +def test_loop_rotated(increment_before_condition): sdfg = dace.SDFG('tester') sdfg.add_symbol('N', dace.int32) @@ -37,8 +38,12 @@ def test_loop_rotated(): exitstate = sdfg.add_state('exitstate') sdfg.add_edge(entry, body, dace.InterstateEdge(assignments=dict(i=0))) - sdfg.add_edge(body, latch, dace.InterstateEdge()) - sdfg.add_edge(latch, body, 
dace.InterstateEdge('i < N', assignments=dict(i='i + 2'))) + if increment_before_condition: + sdfg.add_edge(body, latch, dace.InterstateEdge(assignments=dict(i='i + 2'))) + sdfg.add_edge(latch, body, dace.InterstateEdge('i < N')) + else: + sdfg.add_edge(body, latch, dace.InterstateEdge()) + sdfg.add_edge(latch, body, dace.InterstateEdge('i < N', assignments=dict(i='i + 2'))) sdfg.add_edge(latch, exitstate, dace.InterstateEdge('i >= N')) xform = CountLoops() @@ -48,8 +53,9 @@ def test_loop_rotated(): assert rng == (0, dace.symbol('N') - 1, 2) -@pytest.mark.skip('Extra incrementation states should not be supported by loop detection') def test_loop_rotated_extra_increment(): + # Extra incrementation states (i.e., something more than a single edge between the latch and the body) should not + # be allowed and consequently not be detected as loops. sdfg = dace.SDFG('tester') sdfg.add_symbol('N', dace.int32) @@ -60,15 +66,13 @@ def test_loop_rotated_extra_increment(): exitstate = sdfg.add_state('exitstate') sdfg.add_edge(entry, body, dace.InterstateEdge(assignments=dict(i=0))) + sdfg.add_edge(body, latch, dace.InterstateEdge()) sdfg.add_edge(latch, increment, dace.InterstateEdge('i < N')) sdfg.add_edge(increment, body, dace.InterstateEdge(assignments=dict(i='i + 1'))) sdfg.add_edge(latch, exitstate, dace.InterstateEdge('i >= N')) xform = CountLoops() - assert sdfg.apply_transformations(xform) == 1 - itvar, rng, _ = xform.loop_information() - assert itvar == 'i' - assert rng == (0, dace.symbol('N') - 1, 1) + assert sdfg.apply_transformations(xform) == 0 def test_self_loop(): @@ -91,7 +95,8 @@ def test_self_loop(): assert rng == (2, dace.symbol('N') - 1, 3) -def test_loop_llvm_canonical(): +@pytest.mark.parametrize('increment_before_condition', (True, False)) +def test_loop_llvm_canonical(increment_before_condition): sdfg = dace.SDFG('tester') sdfg.add_symbol('N', dace.int32) @@ -106,8 +111,12 @@ def test_loop_llvm_canonical(): sdfg.add_edge(guard, exitstate, dace.InterstateEdge('N <= 0')) sdfg.add_edge(guard, preheader, dace.InterstateEdge('N > 0')) sdfg.add_edge(preheader, body, dace.InterstateEdge(assignments=dict(i=0))) - sdfg.add_edge(body, latch, dace.InterstateEdge()) - sdfg.add_edge(latch, body, dace.InterstateEdge('i < N', assignments=dict(i='i + 1'))) + if increment_before_condition: + sdfg.add_edge(body, latch, dace.InterstateEdge(assignments=dict(i='i + 1'))) + sdfg.add_edge(latch, body, dace.InterstateEdge('i < N')) + else: + sdfg.add_edge(body, latch, dace.InterstateEdge()) + sdfg.add_edge(latch, body, dace.InterstateEdge('i < N', assignments=dict(i='i + 1'))) sdfg.add_edge(latch, loopexit, dace.InterstateEdge('i >= N')) sdfg.add_edge(loopexit, exitstate, dace.InterstateEdge()) @@ -118,9 +127,10 @@ def test_loop_llvm_canonical(): assert rng == (0, dace.symbol('N') - 1, 1) -@pytest.mark.skip('Extra incrementation states should not be supported by loop detection') @pytest.mark.parametrize('with_bounds_check', (False, True)) def test_loop_llvm_canonical_with_extras(with_bounds_check): + # Extra incrementation states (i.e., something more than a single edge between the latch and the body) should not + # be allowed and consequently not be detected as loops. 
sdfg = dace.SDFG('tester') sdfg.add_symbol('N', dace.int32) @@ -148,17 +158,16 @@ def test_loop_llvm_canonical_with_extras(with_bounds_check): sdfg.add_edge(loopexit, exitstate, dace.InterstateEdge()) xform = CountLoops() - assert sdfg.apply_transformations(xform) == 1 - itvar, rng, _ = xform.loop_information() - assert itvar == 'i' - assert rng == (0, dace.symbol('N') - 1, 1) + assert sdfg.apply_transformations(xform) == 0 if __name__ == '__main__': test_pyloop() - test_loop_rotated() - # test_loop_rotated_extra_increment() + test_loop_rotated(True) + test_loop_rotated(False) + test_loop_rotated_extra_increment() test_self_loop() - test_loop_llvm_canonical() - # test_loop_llvm_canonical_with_extras(False) - # test_loop_llvm_canonical_with_extras(True) + test_loop_llvm_canonical(True) + test_loop_llvm_canonical(False) + test_loop_llvm_canonical_with_extras(False) + test_loop_llvm_canonical_with_extras(True)
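
Note on the inverted-condition loop exercised above: with inverted=True and update_before_condition=False, the condition is evaluated after the body and the update only runs when the condition still holds, which is why test_loop_do_for_inverted_condition validates A[0..8] = i and leaves A[9] untouched. The following is a minimal Python sketch of that control flow (an illustration of the expected semantics, matching the schedule-tree string asserted in test_loop_to_stree_do_for_inverted_cond, not code generated by DaCe):

    import numpy as np

    A = np.zeros((10,), dtype=np.float32)
    i = 0                   # init_statement: 'i = 0'
    while True:
        A[i] = i            # loop body: tasklet 'a = i' writing A[i]
        if not (i < 8):     # condition checked *after* the body
            break
        i = i + 1           # update runs only when the condition held
    # The body executes for i = 0..8, so A[:9] == [0, 1, ..., 8] and A[9] remains 0.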
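Similarly, the two parametrizations of test_lift_loop_llvm_canonical wire the latch differently (increment on the body-to-latch edge with guard 'i < N', versus guard 'i < N - 2' with the increment on the back edge), but both visit the same indices of A, which is why a single pre-lifting run serves as the reference for either variant. A rough Python model of the two wirings for N = 30 (helper names are illustrative, assumed equivalence for this test only):

    # Hypothetical helpers modeling the two latch wirings in test_lift_loop_llvm_canonical.
    N = 30

    def increment_then_check():          # increment_before_condition=True
        visited, i = [], 0
        while True:
            visited.append(i)            # body: A[i] = 1
            i += 2                       # assignment on the body -> latch edge
            if not (i < N):              # back edge guarded by 'i < N'
                break
        return visited

    def check_then_increment():          # increment_before_condition=False
        visited, i = [], 0
        while True:
            visited.append(i)            # body: A[i] = 1
            if not (i < N - 2):          # back edge guarded by 'i < N - 2'
                break
            i += 2                       # assignment happens on the back edge
        return visited

    assert increment_then_check() == check_then_increment() == list(range(0, N, 2))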