From 509ee0ff7c6088d8a89c68d05791af58aca8c1d7 Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Wed, 20 Dec 2023 23:27:13 -0800 Subject: [PATCH] Reference-to-View pass and comprehensive reference test suite (#1485) Implements a reference-to-view pass (converting references to views if they are only set to one particular subset). Also improves the simplify pipeline in the presence of Reference data descriptors and adds multiple tests that use references. --- dace/codegen/targets/cpp.py | 3 + dace/codegen/targets/cpu.py | 39 +- .../analysis/schedule_tree/sdfg_to_tree.py | 8 +- dace/sdfg/analysis/schedule_tree/treenodes.py | 7 +- dace/sdfg/validation.py | 21 +- .../passes/dead_dataflow_elimination.py | 9 + .../passes/reference_reduction.py | 247 ++++++++ dace/transformation/passes/simplify.py | 17 +- tests/sdfg/reference_test.py | 543 ++++++++++++++++++ 9 files changed, 864 insertions(+), 30 deletions(-) create mode 100644 dace/transformation/passes/reference_reduction.py diff --git a/dace/codegen/targets/cpp.py b/dace/codegen/targets/cpp.py index 9687fb1783..f3f1424297 100644 --- a/dace/codegen/targets/cpp.py +++ b/dace/codegen/targets/cpp.py @@ -1256,6 +1256,9 @@ def visit_Name(self, node: ast.Name): except KeyError: defined_type = None if (self.allow_casts and isinstance(dtype, dtypes.pointer) and memlet.subset.num_elements() == 1): + # Special case for pointer to pointer assignment + if memlet.data in self.sdfg.arrays and self.sdfg.arrays[memlet.data].dtype == dtype: + return self.generic_visit(node) return ast.parse(f"{name}[0]").body[0].value elif (self.allow_casts and (defined_type in (DefinedType.Stream, DefinedType.StreamArray)) and memlet.dynamic): diff --git a/dace/codegen/targets/cpu.py b/dace/codegen/targets/cpu.py index 7ed8a48cd7..e2497cdb77 100644 --- a/dace/codegen/targets/cpu.py +++ b/dace/codegen/targets/cpu.py @@ -717,10 +717,8 @@ def _emit_copy( state_dfg = sdfg.nodes()[state_id] - copy_shape, src_strides, dst_strides, src_expr, dst_expr = \ - cpp.memlet_copy_to_absolute_strides( - self._dispatcher, sdfg, state_dfg, edge, src_node, dst_node, - self._packed_types) + copy_shape, src_strides, dst_strides, src_expr, dst_expr = cpp.memlet_copy_to_absolute_strides( + self._dispatcher, sdfg, state_dfg, edge, src_node, dst_node, self._packed_types) # Which numbers to include in the variable argument part dynshape, dynsrc, dyndst = 1, 1, 1 @@ -904,7 +902,8 @@ def process_out_memlets(self, _, uconn, v, _, memlet = edge if skip_wcr and memlet.wcr is not None: continue - dst_node = dfg.memlet_path(edge)[-1].dst + dst_edge = dfg.memlet_path(edge)[-1] + dst_node = dst_edge.dst # Target is neither a data nor a tasklet node if isinstance(node, nodes.AccessNode) and (not isinstance(dst_node, nodes.AccessNode) @@ -952,9 +951,12 @@ def process_out_memlets(self, conntype = node.out_connectors[uconn] is_scalar = not isinstance(conntype, dtypes.pointer) + if isinstance(conntype, dtypes.pointer) and sdfg.arrays[memlet.data].dtype == conntype: + is_scalar = True # Pointer to pointer assignment is_stream = isinstance(sdfg.arrays[memlet.data], data.Stream) + is_refset = isinstance(sdfg.arrays[memlet.data], data.Reference) and dst_edge.dst_conn == 'set' - if is_scalar and not memlet.dynamic and not is_stream: + if (is_scalar and not memlet.dynamic and not is_stream) or is_refset: out_local_name = " __" + uconn in_local_name = uconn if not locals_defined: @@ -987,6 +989,9 @@ def process_out_memlets(self, if defined_type == DefinedType.Scalar: mname = cpp.ptr(memlet.data, desc, sdfg, self._frame) write_expr = f"{mname} = {in_local_name};" + elif defined_type == DefinedType.Pointer and is_refset: + mname = cpp.ptr(memlet.data, desc, sdfg, self._frame) + write_expr = f"{mname} = {in_local_name};" elif (defined_type == DefinedType.ArrayInterface and not isinstance(desc, data.View)): # Special case: No need to write anything between # array interfaces going out @@ -1473,15 +1478,21 @@ def define_out_memlet(self, sdfg, state_dfg, state_id, src_node, dst_node, edge, cdtype = src_node.out_connectors[edge.src_conn] if isinstance(sdfg.arrays[edge.data.data], data.Stream): pass - elif isinstance(cdtype, dtypes.pointer): - # If pointer, also point to output + elif isinstance(cdtype, dtypes.pointer): # If pointer, also point to output desc = sdfg.arrays[edge.data.data] - ptrname = cpp.ptr(edge.data.data, desc, sdfg, self._frame) - is_global = desc.lifetime in (dtypes.AllocationLifetime.Global, dtypes.AllocationLifetime.Persistent, - dtypes.AllocationLifetime.External) - defined_type, _ = self._dispatcher.defined_vars.get(ptrname, is_global=is_global) - base_ptr = cpp.cpp_ptr_expr(sdfg, edge.data, defined_type, codegen=self._frame) - callsite_stream.write(f'{cdtype.ctype} {edge.src_conn} = {base_ptr};', sdfg, state_id, src_node) + + # If reference set, do not emit initial assignment + is_refset = isinstance(desc, data.Reference) and state_dfg.memlet_path(edge)[-1].dst_conn == 'set' + + if not is_refset and not isinstance(desc.dtype, dtypes.pointer): + ptrname = cpp.ptr(edge.data.data, desc, sdfg, self._frame) + is_global = desc.lifetime in (dtypes.AllocationLifetime.Global, dtypes.AllocationLifetime.Persistent, + dtypes.AllocationLifetime.External) + defined_type, _ = self._dispatcher.defined_vars.get(ptrname, is_global=is_global) + base_ptr = cpp.cpp_ptr_expr(sdfg, edge.data, defined_type, codegen=self._frame) + callsite_stream.write(f'{cdtype.ctype} {edge.src_conn} = {base_ptr};', sdfg, state_id, src_node) + else: + callsite_stream.write(f'{cdtype.as_arg(edge.src_conn)};', sdfg, state_id, src_node) else: callsite_stream.write(f'{cdtype.ctype} {edge.src_conn};', sdfg, state_id, src_node) diff --git a/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py b/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py index 78b2280902..51871e6512 100644 --- a/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py +++ b/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py @@ -424,9 +424,15 @@ def prepare_schedule_tree_edges(state: SDFGState) -> Dict[gr.MultiConnectorEdge[ # 2. Check for reference sets if isinstance(e.dst, dace.nodes.AccessNode) and e.dst_conn == 'set': assert isinstance(e.dst.desc(sdfg), dace.data.Reference) + + # Determine source + if isinstance(mtree.root().edge.src, dace.nodes.CodeNode): + src_desc = mtree.root().edge.src + else: + src_desc = sdfg.arrays[e.data.data] result[e] = tn.RefSetNode(target=e.dst.data, memlet=e.data, - src_desc=sdfg.arrays[e.data.data], + src_desc=src_desc, ref_desc=sdfg.arrays[e.dst.data]) scope = state.entry_node(e.dst if mtree.downwards else e.src) scope_to_edges[scope].append(e) diff --git a/dace/sdfg/analysis/schedule_tree/treenodes.py b/dace/sdfg/analysis/schedule_tree/treenodes.py index 99918cd2a4..5d3d2a6fa8 100644 --- a/dace/sdfg/analysis/schedule_tree/treenodes.py +++ b/dace/sdfg/analysis/schedule_tree/treenodes.py @@ -36,8 +36,7 @@ class ScheduleTreeScope(ScheduleTreeNode): containers: Optional[Dict[str, data.Data]] = field(default_factory=dict, init=False) symbols: Optional[Dict[str, symbol]] = field(default_factory=dict, init=False) - def __init__(self, - children: Optional[List['ScheduleTreeNode']] = None): + def __init__(self, children: Optional[List['ScheduleTreeNode']] = None): self.children = children or [] if self.children: for child in children: @@ -350,10 +349,12 @@ class RefSetNode(ScheduleTreeNode): """ target: str memlet: Memlet - src_desc: data.Data + src_desc: Union[data.Data, nodes.CodeNode] ref_desc: data.Data def as_string(self, indent: int = 0): + if isinstance(self.src_desc, nodes.CodeNode): + return indent * INDENTATION + f'{self.target} = refset from {type(self.src_desc).__name__.lower()}' return indent * INDENTATION + f'{self.target} = refset to {self.memlet}' diff --git a/dace/sdfg/validation.py b/dace/sdfg/validation.py index a3914494c3..9feda8259c 100644 --- a/dace/sdfg/validation.py +++ b/dace/sdfg/validation.py @@ -100,7 +100,7 @@ def validate_control_flow_region(sdfg: 'dace.sdfg.SDFG', in_default_scope = False if sdfg.parent_nsdfg_node is not None: if is_in_scope(sdfg.parent_sdfg, sdfg.parent, sdfg.parent_nsdfg_node, - [dtypes.ScheduleType.Default]): + [dtypes.ScheduleType.Default]): in_default_scope = True if in_default_scope is False: eid = region.edge_id(edge) @@ -126,9 +126,9 @@ def validate_control_flow_region(sdfg: 'dace.sdfg.SDFG', if start_block not in visited: if isinstance(start_block, SDFGState): validate_state(start_block, region.node_id(start_block), sdfg, symbols, initialized_transients, references, - **context) + **context) else: - validate_control_flow_region(sdfg, start_block, initialized_transients, symbols, references, **context) + validate_control_flow_region(sdfg, start_block, initialized_transients, symbols, references, **context) # Validate all inter-state edges (including self-loops not found by DFS) for eid, edge in enumerate(region.edges()): @@ -162,7 +162,7 @@ def validate_control_flow_region(sdfg: 'dace.sdfg.SDFG', in_default_scope = False if sdfg.parent_nsdfg_node is not None: if is_in_scope(sdfg.parent_sdfg, sdfg.parent, sdfg.parent_nsdfg_node, - [dtypes.ScheduleType.Default]): + [dtypes.ScheduleType.Default]): in_default_scope = True if in_default_scope is False: raise InvalidSDFGInterstateEdgeError( @@ -453,9 +453,16 @@ def validate_state(state: 'dace.sdfg.SDFGState', # Find uninitialized transients if node.data not in initialized_transients: - if (arr.transient and state.in_degree(node) == 0 and state.out_degree(node) > 0 - # Streams do not need to be initialized - and not isinstance(arr, dt.Stream)): + if isinstance(arr, dt.Reference): # References are considered more conservatively + if any(e.dst_conn == 'set' for e in state.in_edges(node)): + initialized_transients.add(node.data) + else: + raise InvalidSDFGNodeError( + 'Reference data descriptor was used before it was set. Set ' + 'it with an incoming memlet to the "set" connector', sdfg, state_id, nid) + elif (arr.transient and state.in_degree(node) == 0 and state.out_degree(node) > 0 + # Streams do not need to be initialized + and not isinstance(arr, dt.Stream)): if node.setzero == False: warnings.warn('WARNING: Use of uninitialized transient "%s" in state %s' % (node.data, state.label)) diff --git a/dace/transformation/passes/dead_dataflow_elimination.py b/dace/transformation/passes/dead_dataflow_elimination.py index aeaf1cdbd1..d9131385d6 100644 --- a/dace/transformation/passes/dead_dataflow_elimination.py +++ b/dace/transformation/passes/dead_dataflow_elimination.py @@ -231,6 +231,10 @@ def _is_node_dead(self, node: nodes.Node, sdfg: SDFG, state: SDFGState, dead_nod # Check incoming edges for e in state.in_edges(node): + # A reference set should not be removed + if e.dst_conn == 'set': + return False + for l in state.memlet_tree(e).leaves(): # If data is connected to a side-effect tasklet/library node, cannot remove if _has_side_effects(l.src, sdfg): @@ -245,6 +249,11 @@ def _is_node_dead(self, node: nodes.Node, sdfg: SDFG, state: SDFGState, dead_nod if isinstance(desc, data.Stream) and node.data in access_set[0]: return False + # If it is a reference, it may point to other data containers, + # be conservative for now + if isinstance(desc, data.Reference): + return False + # Any other case can be marked as dead return True diff --git a/dace/transformation/passes/reference_reduction.py b/dace/transformation/passes/reference_reduction.py new file mode 100644 index 0000000000..99bd2cea24 --- /dev/null +++ b/dace/transformation/passes/reference_reduction.py @@ -0,0 +1,247 @@ +# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. +from collections import defaultdict +import copy +from typing import Any, Dict, List, Optional, Set, Tuple + +from dace import SDFG, SDFGState, data, properties, Memlet +from dace.sdfg import nodes +from dace.sdfg.analysis import cfg +from dace.transformation import pass_pipeline as ppl +from dace.transformation.passes import analysis as ap + + +@properties.make_properties +class ReferenceToView(ppl.Pass): + """ + Replaces Reference data descriptors that are only set to one source with views. + """ + + CATEGORY: str = 'Simplification' + + def modifies(self) -> ppl.Modifies: + return ppl.Modifies.Descriptors | ppl.Modifies.AccessNodes + + def should_reapply(self, modified: ppl.Modifies) -> bool: + return modified & ppl.Modifies.AccessNodes + + def depends_on(self): + return {ap.StateReachability, ap.FindAccessStates, ap.FindReferenceSources} + + def apply_pass(self, sdfg: SDFG, pipeline_results: Dict[str, Any]) -> Optional[Set[str]]: + """ + Removes redundant arrays and access nodes. + + :param sdfg: The SDFG to modify. + :param pipeline_results: If in the context of a ``Pipeline``, a dictionary that is populated with prior Pass + results as ``{Pass subclass name: returned object from pass}``. If not run in a + pipeline, an empty dictionary is expected. + :return: A set of removed data descriptor names, or None if nothing changed. + """ + reachable: Dict[SDFGState, Set[SDFGState]] = pipeline_results[ap.StateReachability.__name__][sdfg.sdfg_id] + access_states: Dict[str, Set[SDFGState]] = pipeline_results[ap.FindAccessStates.__name__][sdfg.sdfg_id] + reference_sources: Dict[str, Set[Memlet]] = pipeline_results[ap.FindReferenceSources.__name__][sdfg.sdfg_id] + + # Early exit if no references exist + if not reference_sources: + return None + + # Filter out multi-source references and tasklet-set references + candidates = set(k for k, v in reference_sources.items() + if len(v) == 1 and not isinstance(next(iter(v)), nodes.CodeNode)) + + refsets = self.find_refsets(candidates, access_states) + + result: Set[str] = self.find_candidates(sdfg, reference_sources, refsets, access_states, reachable) + if not result: + return None + + # Remove reference set edges and eliminate orphaned access nodes + self.remove_refsets(result, refsets) + + # Reconnect reference uses as views + self.reconnect_views(sdfg, result, access_states, reference_sources) + + # Modify data descriptor from Reference to View + self.change_ref_descriptors_to_views(sdfg, result) + + return result or None + + def report(self, pass_retval: Set[str]) -> str: + return f'Converted {len(pass_retval)} references to views: {pass_retval}.' + + def find_refsets(self, candidates: Set[str], + access_states: Dict[str, Set[SDFGState]]) -> Dict[str, List[Tuple[SDFGState, nodes.AccessNode]]]: + """ + Returns a dictionary of reference name to a list of tuples of (state, access node) + where the reference is set via a memlet. + """ + result: Dict[str, List[Tuple[SDFGState, nodes.AccessNode]]] = defaultdict(list) + all_states_to_consider: Set[SDFGState] = set() + for candidate in candidates: + all_states_to_consider.update(access_states[candidate]) + + for state in all_states_to_consider: + # Loop over all states that use the references once + for node in state.data_nodes(): + if node.data not in candidates: + continue + for _ in state.in_edges_by_connector(node, 'set'): + result[node.data].append((state, node)) + break + + return result + + def find_candidates( + self, + sdfg: SDFG, + reference_sources: Dict[str, Set[Memlet]], + refsets: Dict[str, List[Tuple[SDFGState, nodes.AccessNode]]], + access_states: Dict[str, Set[SDFGState]], + reachable_states: Dict[SDFGState, Set[SDFGState]], + ) -> Set[str]: + """ + Returns a set of candidates for conversion to views. + """ + result = set(refsets.keys()) + if not result: # Early return + return result + + # If memlet does not depend on any symbol, it can be kept. Otherwise, + # it may depend on a (free) symbol. There are multiple options: + # * If dependent on scope symbol (e.g., map parameter) - remove from candidates + # * If dependent on symbol defined in inter-state edges - make sure it is not changed between set and uses + # * If dependent on a free symbol - also make sure it is not changed between set and uses + for cand in list(result): # Copying the set to a list allows us to iterate over it while removing elements + source = next(iter(reference_sources[cand])) + fsyms = source.subset.free_symbols + if not fsyms: + continue + + for state, node in refsets[cand]: + # Check if any of the symbols is a scope symbol + entry = state.entry_node(node) + while entry is not None: + if fsyms & entry.new_symbols(sdfg, state, {}): + result.remove(cand) + break + entry = state.entry_node(entry) + if cand not in result: + break + + # Otherwise, they are only inter-state or free symbols. Test all paths to uses in different states + # NOTE: This is an expensive check! + for other_state in access_states[cand]: + # Filter self and unreachable states + if other_state is state or other_state not in reachable_states[state]: + continue + for path in sdfg.all_simple_paths(state, other_state, as_edges=True): + for e in path: + # The symbol was modified/reassigned in one of the paths, skip + if fsyms & e.data.assignments.keys(): + result.remove(cand) + break + if cand not in result: + break + if cand not in result: + break + + return result + + def remove_refsets( + self, + candidates: Set[str], + all_refsets: Dict[str, List[Tuple[SDFGState, nodes.AccessNode]]], + ): + for ref, refsets in all_refsets.items(): + if ref not in candidates: + continue + for state, node in refsets: + # Loop over all states that use the reference and remove reference + # set memlets, reconnecting the remaining surrounding nodes so as + # to not break scopes + edges_to_add = [] + edges_to_remove = set() + nodes_to_remove = set() + affected_nodes = set() + for e in state.in_edges_by_connector(node, 'set'): + # This is a reference set edge. Consider scope and neighbors and remove set + edges_to_remove.add(e) + affected_nodes.add(e.src) + affected_nodes.add(e.dst) + + # If source node does not have any other neighbors, it can be removed + if all(ee is e or ee.data.is_empty() for ee in state.all_edges(e.src)): + nodes_to_remove.add(e.src) + # If set reference does not have any other neighbors, it can be removed + if all(ee is e or ee.data.is_empty() for ee in state.all_edges(node)): + nodes_to_remove.add(node) + + # If in a scope, ensure reference node will not be disconnected + scope = state.entry_node(node) + if scope is not None and node not in nodes_to_remove: + edges_to_add.append((scope, None, node, None, Memlet())) + + # Modify the state graph as necessary + for e in edges_to_remove: + state.remove_edge_and_connectors(e) + for n in nodes_to_remove: + state.remove_node(n) + for e in edges_to_add: + state.add_edge(*e) + for n in affected_nodes: # Orphaned nodes + if n in nodes_to_remove: + continue + if state.degree(n) == 0: + state.remove_node(n) + + def reconnect_views(self, sdfg: SDFG, candidates: Set[str], access_states: Dict[str, Set[SDFGState]], + reference_sources: Dict[str, Set[Memlet]]): + all_states_to_consider: Set[SDFGState] = set() + for cand in candidates: + all_states_to_consider.update(access_states[cand]) + + # For each instance of the access node, connect the original data container to the view + for state in all_states_to_consider: + for node in state.data_nodes(): + if node.data not in candidates: + continue + refsource = next(iter(reference_sources[node.data])) + + needs_pred_view = any(not e.data.is_empty() for e in state.in_edges(node)) + needs_succ_view = any(not e.data.is_empty() for e in state.out_edges(node)) + if needs_pred_view: + self._create_view(refsource, state, node, predecessor=True) + if needs_succ_view: + self._create_view(refsource, state, node, predecessor=False) + + # Replace node's data container with the reference source + node.data = refsource.data + + def _create_view(self, refsource: Memlet, state: SDFGState, node: nodes.AccessNode, predecessor: bool): + """ + Creates a view access node and redirects all the edges appropriately. + """ + edges = state.in_edges if predecessor else state.out_edges + view = state.add_access(node.data) + src = (lambda e: e.src) if predecessor else (lambda _: view) + dst = (lambda _: view) if predecessor else (lambda e: e.dst) + + # Redirect edges to view + for e in edges(node): + state.remove_edge(e) + state.add_edge(src(e), e.src_conn, dst(e), e.dst_conn, e.data) + + # Use "views" connector to disambiguate potential corner cases + if predecessor: + view.add_out_connector('views') + state.add_edge(view, 'views', node, None, copy.deepcopy(refsource)) + else: + view.add_in_connector('views') + state.add_edge(node, None, view, 'views', copy.deepcopy(refsource)) + + def change_ref_descriptors_to_views(self, sdfg: SDFG, names: Set[str]): + # A slightly hacky way to replace a reference class with a view. + # Since both classes have the same superclass, and all the fields + # are the same, this is safe to perform. + for name in names: + sdfg.arrays[name].__class__ = data.View diff --git a/dace/transformation/passes/simplify.py b/dace/transformation/passes/simplify.py index 0a2539457a..1778470b14 100644 --- a/dace/transformation/passes/simplify.py +++ b/dace/transformation/passes/simplify.py @@ -14,6 +14,7 @@ from dace.transformation.passes.optional_arrays import OptionalArrayInference from dace.transformation.passes.scalar_to_symbol import ScalarToSymbolPromotion from dace.transformation.passes.prune_symbols import RemoveUnusedSymbols +from dace.transformation.passes.reference_reduction import ReferenceToView SIMPLIFY_PASSES = [ InlineSDFGs, @@ -24,12 +25,18 @@ DeadDataflowElimination, DeadStateElimination, RemoveUnusedSymbols, + ReferenceToView, ArrayElimination, ConsolidateEdges, ] _nonrecursive_passes = [ - ScalarToSymbolPromotion, DeadDataflowElimination, DeadStateElimination, ArrayElimination, ConsolidateEdges + ScalarToSymbolPromotion, + DeadDataflowElimination, + DeadStateElimination, + ArrayElimination, + ConsolidateEdges, + ReferenceToView, ] @@ -42,11 +49,11 @@ class SimplifyPass(ppl.FixedPointPipeline): CATEGORY: str = 'Simplification' - validate = properties.Property(dtype=bool, default=False, desc='Whether to validate the SDFG at the end of the pipeline.') + validate = properties.Property(dtype=bool, + default=False, + desc='Whether to validate the SDFG at the end of the pipeline.') validate_all = properties.Property(dtype=bool, default=False, desc='Whether to validate the SDFG after each pass.') - skip = properties.SetProperty(element_type=str, - default=set(), - desc='Set of pass names to skip.') + skip = properties.SetProperty(element_type=str, default=set(), desc='Set of pass names to skip.') verbose = properties.Property(dtype=bool, default=False, desc='Whether to print reports after every pass.') def __init__(self, diff --git a/tests/sdfg/reference_test.py b/tests/sdfg/reference_test.py index f1e605e315..066bd80a7f 100644 --- a/tests/sdfg/reference_test.py +++ b/tests/sdfg/reference_test.py @@ -1,11 +1,30 @@ # Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. """ Tests the use of Reference data descriptors. """ import dace +from dace.sdfg import validation +from dace.transformation.pass_pipeline import Pipeline from dace.transformation.passes.analysis import FindReferenceSources +from dace.transformation.passes.reference_reduction import ReferenceToView import numpy as np +import pytest + + +def test_unset_reference(): + sdfg = dace.SDFG('tester') + sdfg.add_reference('ref', [20], dace.float64) + state = sdfg.add_state() + t = state.add_tasklet('doit', {'a'}, {'b'}, 'b = a + 1') + state.add_edge(state.add_read('ref'), None, t, 'a', dace.Memlet('ref[0]')) + state.add_edge(t, 'b', state.add_write('ref'), None, dace.Memlet('ref[1]')) + + with pytest.raises(validation.InvalidSDFGNodeError): + sdfg.validate() def _create_branch_sdfg(): + """ + An SDFG in which a reference is set conditionally. + """ sdfg = dace.SDFG('refbranch') sdfg.add_array('A', [20], dace.float64) sdfg.add_array('B', [20], dace.float64) @@ -33,6 +52,268 @@ def _create_branch_sdfg(): return sdfg +def _create_tasklet_assignment_sdfg(): + """ + An SDFG in which a reference is set by a tasklet. + """ + sdfg = dace.SDFG('refta') + sdfg.add_array('A', [20], dace.float64) + sdfg.add_array('B', [19], dace.float64) + sdfg.add_reference('ref', [19], dace.float64) + + state = sdfg.add_state() + t = state.add_tasklet('ptrset', {'a': dace.pointer(dace.float64)}, {'o'}, 'o = a + 1') + state.add_edge(state.add_read('A'), None, t, 'a', dace.Memlet('A')) + ref = state.add_access('ref') + state.add_edge(t, 'o', ref, 'set', dace.Memlet('ref')) + t2 = state.add_tasklet('addone', {'a'}, {'o'}, 'o = a + 1') + state.add_edge(ref, None, t2, 'a', dace.Memlet('ref[0]')) + state.add_edge(t2, 'o', state.add_write('B'), None, dace.Memlet('B[0]')) + return sdfg + + +def _create_twostate_sdfg(): + """ + An SDFG in which a reference set happens on another state. + """ + sdfg = dace.SDFG('reftest') + sdfg.add_array('A', [20], dace.float64) + sdfg.add_reference('ref', [10], dace.float64) + + setstate = sdfg.add_state() + computestate = sdfg.add_state_after(setstate) + + setstate.add_edge(setstate.add_read('A'), None, setstate.add_write('ref'), 'set', dace.Memlet('A[10:20]')) + + # Read from A[10], write to A[11] + t = computestate.add_tasklet('addone', {'a'}, {'b'}, 'b = a + 1') + computestate.add_edge(computestate.add_write('ref'), None, t, 'a', dace.Memlet('ref[0]')) + computestate.add_edge(t, 'b', computestate.add_write('ref'), None, dace.Memlet('ref[1]')) + return sdfg + + +def _create_multisubset_sdfg(): + """ + A Jacobi-2d style SDFG to test multi-dimensional subsets and the use of an empty memlet + as a dependency edge. + """ + sdfg = dace.SDFG('reftest') + sdfg.add_array('A', [22, 22], dace.float64) + sdfg.add_array('B', [20, 20], dace.float64) + sdfg.add_reference('ref1', [22], dace.float64) + sdfg.add_reference('ref2', [22], dace.float64, strides=[22]) + sdfg.add_reference('ref3', [22], dace.float64, strides=[22]) + sdfg.add_reference('ref4', [22], dace.float64, strides=[22]) + sdfg.add_reference('refB', [20], dace.float64) + + state = sdfg.add_state() + + # Access nodes + a = state.add_read('A') + b = state.add_read('B') + r1 = state.add_access('ref1') + r2 = state.add_access('ref2') + r3 = state.add_access('ref3') + r4 = state.add_access('ref4') + rbset = state.add_access('refB') + rbwrite = state.add_write('refB') + + # Add reference sets + state.add_edge(a, None, r1, 'set', dace.Memlet('A[5, 0:22]')) + state.add_edge(a, None, r2, 'set', dace.Memlet('A[0:22, 5]')) + state.add_edge(a, None, r3, 'set', dace.Memlet('A[0:22, 4]')) + state.add_edge(a, None, r4, 'set', dace.Memlet('A[0:22, 3]')) + state.add_edge(b, None, rbset, 'set', dace.Memlet('B[4, 0:20]')) + + # Add tasklet + t = state.add_tasklet('stencil', {'a', 'b', 'c', 'd'}, {'o'}, 'o = 0.25 * (a + b + c + d)') + + # Connect tasklet + state.add_nedge(rbset, t, dace.Memlet()) # Happens-before edge + state.add_edge(r1, None, t, 'a', dace.Memlet('ref1[4]')) # (5,4) + state.add_edge(r2, None, t, 'b', dace.Memlet('ref2[4]')) # (4,5) + state.add_edge(r3, None, t, 'c', dace.Memlet('ref3[3]')) # (3,4) + state.add_edge(r4, None, t, 'd', dace.Memlet('ref4[4]')) # (4,3) + state.add_edge(t, 'o', rbwrite, None, dace.Memlet('refB[4]')) + + return sdfg + + +def _create_scoped_sdfg(): + """ + An SDFG in which a reference is used inside a scope. + """ + sdfg = dace.SDFG('reftest') + sdfg.add_array('A', [20, 20], dace.float64) + sdfg.add_array('B', [20, 20], dace.float64) + sdfg.add_reference('ref', [2], dace.float64, strides=[20]) + + istate = sdfg.add_state() + state = sdfg.add_state_after(istate) + + istate.add_edge(istate.add_read('A'), None, istate.add_write('ref'), 'set', dace.Memlet('A[2:4, 3]')) + + me, mx = state.add_map('mapit', dict(i='0:2')) + ref = state.add_access('ref') + inp = state.add_read('B') + t = state.add_tasklet('doit', {'r'}, {'w'}, 'w = r + 1') + out = state.add_write('A') + state.add_memlet_path(inp, me, ref, memlet=dace.Memlet('B[1, i] -> i')) + state.add_edge(ref, None, t, 'r', dace.Memlet('ref[i]')) + state.add_edge_pair(mx, t, out, internal_connector='w', internal_memlet=dace.Memlet('A[10, i]')) + + return sdfg + + +def _create_scoped_empty_memlet_sdfg(): + """ + An SDFG in which a reference is used inside a scope with no inputs. + """ + sdfg = dace.SDFG('reftest') + sdfg.add_array('A', [20, 20], dace.float64) + sdfg.add_array('B', [20, 20], dace.float64) + sdfg.add_reference('ref', [2], dace.float64, strides=[20]) + + istate = sdfg.add_state() + state = sdfg.add_state_after(istate) + + istate.add_edge(istate.add_read('A'), None, istate.add_write('ref'), 'set', dace.Memlet('A[2:4, 3]')) + + me, mx = state.add_map('mapit', dict(i='0:2')) + ref = state.add_access('ref') + t = state.add_tasklet('doit', {'r'}, {'w'}, 'w = r + 1') + out = state.add_write('B') + state.add_edge(me, None, ref, None, memlet=dace.Memlet()) + state.add_edge(ref, None, t, 'r', dace.Memlet('ref[i]')) + state.add_edge_pair(mx, t, out, internal_connector='w', internal_memlet=dace.Memlet('B[10, i]')) + + return sdfg + + +def _create_neighbor_sdfg(): + """ + An SDFG where a reference has both predecessors and successors. + """ + sdfg = dace.SDFG('reftest') + sdfg.add_array('A', [20, 20], dace.float64) + sdfg.add_array('B', [2, 2], dace.float64) + sdfg.add_reference('ref', [2, 2], dace.float64, strides=[20, 1]) + + istate = sdfg.add_state() + state = sdfg.add_state_after(istate) + + istate.add_edge(istate.add_read('A'), None, istate.add_write('ref'), 'set', dace.Memlet('A[2:4, 3:5]')) + + b = state.add_read('B') + ref1 = state.add_access('ref') + ref2 = state.add_write('ref') + state.add_mapped_tasklet('addtwo', + dict(i='0:2', j='0:2'), + dict(r=dace.Memlet('B[i, j]')), + 'w = r + 2', + dict(w=dace.Memlet('ref[i, j]')), + external_edges=True, + input_nodes=dict(B=b), + output_nodes=dict(ref=ref1)) + state.add_mapped_tasklet('sum', + dict(i='0:2'), + dict(r=dace.Memlet('ref[0, i]')), + 'w = r', + dict(w=dace.Memlet('ref[1, 0]', wcr='lambda a,b: a+b')), + external_edges=True, + input_nodes=dict(ref=ref1), + output_nodes=dict(ref=ref2)) + state.add_mapped_tasklet('addone', + dict(i='1:2'), + dict(r=dace.Memlet('ref[i - 1, i - 1]')), + 'w = r + 1', + dict(w=dace.Memlet('ref[i, i]')), + external_edges=True, + input_nodes=dict(ref=ref1), + output_nodes=dict(ref=ref2)) + return sdfg + + +def _create_loop_nonfree_symbols_sdfg(): + """ + An SDFG where a reference is set inside a loop and used outside. + """ + sdfg = dace.SDFG('reftest') + sdfg.add_array('A', [20], dace.float64) + sdfg.add_reference('ref', [1], dace.float64) + + # Create state machine + istate = sdfg.add_state() + state = sdfg.add_state() + after = sdfg.add_state() + sdfg.add_loop(istate, state, after, 'i', '0', 'i < 20', 'i + 1') + + # Reference set inside loop + state.add_edge(state.add_read('A'), None, state.add_write('ref'), 'set', dace.Memlet('A[i] -> 0')) + + # Use outisde loop + t = after.add_tasklet('setone', {}, {'out'}, 'out = 1') + after.add_edge(t, 'out', after.add_write('ref'), None, dace.Memlet('ref[0]')) + + return sdfg + + +def _create_loop_reference_internal_use(): + """ + An SDFG where a reference is set and used inside a loop. + """ + sdfg = dace.SDFG('reftest') + sdfg.add_array('A', [20], dace.float64) + sdfg.add_reference('ref', [1], dace.float64) + + # Create state machine + istate = sdfg.add_state() + state = sdfg.add_state() + after = sdfg.add_state() + sdfg.add_edge(state, after, dace.InterstateEdge()) + sdfg.add_loop(istate, state, None, 'i', '0', 'i < 20', 'i + 1', loop_end_state=after) + + # Reference set inside loop + state.add_edge(state.add_read('A'), None, state.add_write('ref'), 'set', dace.Memlet('A[i]')) + + # Use inside loop + t = after.add_tasklet('setone', {}, {'out'}, 'out = 1') + after.add_edge(t, 'out', after.add_write('ref'), None, dace.Memlet('ref[0]')) + + return sdfg + + +def _create_loop_reference_nonfree_internal_use(): + """ + An SDFG where a reference is set inside one loop and used in another, with + the same symbol name. + """ + sdfg = dace.SDFG('reftest') + sdfg.add_array('A', [20], dace.float64) + sdfg.add_reference('ref', [1], dace.float64) + + # Create state machine + istate = sdfg.add_state() + between_loops = sdfg.add_state() + + # First loop + state1 = sdfg.add_state() + sdfg.add_loop(istate, state1, between_loops, 'i', '0', 'i < 20', 'i + 1') + + # Second loop + state2 = sdfg.add_state() + sdfg.add_loop(between_loops, state2, None, 'i', '0', 'i < 20', 'i + 1') + + # Reference set inside first loop + state1.add_edge(state1.add_read('A'), None, state1.add_write('ref'), 'set', dace.Memlet('A[i]')) + + # Use inside second loop + t = state2.add_tasklet('setone', {}, {'out'}, 'out = 1') + state2.add_edge(t, 'out', state2.add_write('ref'), None, dace.Memlet('ref[0]')) + + return sdfg + + def test_reference_branch(): sdfg = _create_branch_sdfg() @@ -46,6 +327,10 @@ def test_reference_branch(): sdfg(A=A, B=B, out=out, i=1) assert np.allclose(out, A) + # Test reference-to-view - should fail to apply + result = Pipeline([ReferenceToView()]).apply_pass(sdfg, {}) + assert 'ReferenceToView' not in result or not result['ReferenceToView'] + def test_reference_sources_pass(): sdfg = _create_branch_sdfg() @@ -57,6 +342,264 @@ def test_reference_sources_pass(): assert sources == {dace.Memlet('A[0:20]', volume=1), dace.Memlet('B[0:20]', volume=1)} +def test_reference_tasklet_assignment(): + sdfg = _create_tasklet_assignment_sdfg() + + A = np.random.rand(20) + B = np.random.rand(19) + ref = np.copy(B) + ref[0] = A[1] + 1 + + sdfg(A=A, B=B) + assert np.allclose(ref, B) + + +def test_reference_tasklet_assignment_analysis(): + sdfg = _create_tasklet_assignment_sdfg() + sources = FindReferenceSources().apply_pass(sdfg, {}) + assert len(sources) == 1 # There is only one SDFG + sources = sources[0] + assert len(sources) == 1 and 'ref' in sources # There is one reference + sources = sources['ref'] + assert sources == { + next(n for n, _ in sdfg.all_nodes_recursive() if isinstance(n, dace.nodes.Tasklet) and n.label == 'ptrset') + } + + +def test_reference_tasklet_assignment_stree(): + from dace.sdfg.analysis.schedule_tree import sdfg_to_tree as s2t, treenodes as tn + sdfg = _create_tasklet_assignment_sdfg() + stree = s2t.as_schedule_tree(sdfg) + assert [type(n) for n in stree.children] == [tn.TaskletNode, tn.RefSetNode, tn.TaskletNode] + + +def test_reference_tasklet_assignment_reftoview(): + sdfg = _create_tasklet_assignment_sdfg() + + # Test reference-to-view - should fail to apply + result = Pipeline([ReferenceToView()]).apply_pass(sdfg, {}) + assert 'ReferenceToView' not in result or not result['ReferenceToView'] + + +@pytest.mark.parametrize('reftoview', (False, True)) +def test_twostate(reftoview): + sdfg = _create_twostate_sdfg() + + # Test sources + sources = FindReferenceSources().apply_pass(sdfg, {}) + assert len(sources) == 1 # There is only one SDFG + sources = sources[0] + assert len(sources) == 1 and 'ref' in sources # There is one reference + sources = sources['ref'] + assert sources == {dace.Memlet('A[10:20]')} + + if reftoview: + sdfg.simplify() + assert not any(isinstance(v, dace.data.Reference) for v in sdfg.arrays.values()) + + # Test correctness + A = np.random.rand(20) + ref = np.copy(A) + ref[11] = ref[10] + 1 + sdfg(A=A) + assert np.allclose(A, ref) + + +@pytest.mark.parametrize('reftoview', (False, True)) +def test_multisubset(reftoview): + sdfg = _create_multisubset_sdfg() + + # Test sources + sources = FindReferenceSources().apply_pass(sdfg, {}) + assert len(sources) == 1 # There is only one SDFG + sources = sources[0] + assert len(sources) == 5 + assert sources['ref1'] == {dace.Memlet('A[5, 0:22]')} + assert sources['ref2'] == {dace.Memlet('A[0:22, 5]')} + assert sources['refB'] == {dace.Memlet('B[4, 0:20]')} + + if reftoview: + sdfg.simplify() + assert not any(isinstance(v, dace.data.Reference) for v in sdfg.arrays.values()) + + # Test correctness + A = np.random.rand(22, 22) + B = np.random.rand(20, 20) + ref = np.copy(B) + ref[4, 4] = 0.25 * (A[5, 4] + A[4, 5] + A[3, 4] + A[4, 3]) + sdfg(A=A, B=B) + assert np.allclose(B, ref) + + +@pytest.mark.parametrize('reftoview', (False, True)) +def test_scoped(reftoview): + sdfg = _create_scoped_sdfg() + + # Test sources + sources = FindReferenceSources().apply_pass(sdfg, {}) + assert len(sources) == 1 # There is only one SDFG + sources = sources[0] + assert len(sources) == 1 + assert sources['ref'] == {dace.Memlet('A[2:4, 3]')} + + if reftoview: + sdfg.simplify() + assert not any(isinstance(v, dace.data.Reference) for v in sdfg.arrays.values()) + + # Test correctness + A = np.random.rand(20, 20) + B = np.random.rand(20, 20) + ref = np.copy(A) + + ref[2:4, 3] = B[1, 0:2] + ref[10, 0:2] = ref[2:4, 3] + 1 + + sdfg(A=A, B=B) + assert np.allclose(A, ref) + + +@pytest.mark.parametrize('reftoview', (False, True)) +def test_scoped_empty_memlet(reftoview): + sdfg = _create_scoped_empty_memlet_sdfg() + + # Test sources + sources = FindReferenceSources().apply_pass(sdfg, {}) + assert len(sources) == 1 # There is only one SDFG + sources = sources[0] + assert len(sources) == 1 + assert sources['ref'] == {dace.Memlet('A[2:4, 3]')} + + if reftoview: + sdfg.simplify() + assert not any(isinstance(v, dace.data.Reference) for v in sdfg.arrays.values()) + + # Test correctness + A = np.random.rand(20, 20) + B = np.random.rand(20, 20) + ref = np.copy(B) + ref[10, 0:2] = A[2:4, 3] + 1 + + sdfg(A=A, B=B) + assert np.allclose(B, ref) + + +@pytest.mark.parametrize('reftoview', (False, True)) +def test_reference_neighbors(reftoview): + sdfg = _create_neighbor_sdfg() + + # Test sources + sources = FindReferenceSources().apply_pass(sdfg, {}) + assert len(sources) == 1 # There is only one SDFG + sources = sources[0] + assert len(sources) == 1 + assert sources['ref'] == {dace.Memlet('A[2:4, 3:5]')} + + if reftoview: + sdfg.simplify() + assert not any(isinstance(v, dace.data.Reference) for v in sdfg.arrays.values()) + + # Test correctness + A = np.random.rand(20, 20) + B = np.random.rand(2, 2) + ref = np.copy(A) + ref[2:4, 3:5] = B + 2 + ref[3, 3] += np.sum(ref[2, 3:5]) + ref[3, 4] = ref[2, 3] + 1 + + sdfg(A=A, B=B) + assert np.allclose(A, ref) + + +def test_reference_loop_nonfree(): + sdfg = _create_loop_nonfree_symbols_sdfg() + + # Test sources + sources = FindReferenceSources().apply_pass(sdfg, {}) + assert len(sources) == 1 # There is only one SDFG + sources = sources[0] + assert len(sources) == 1 + assert sources['ref'] == {dace.Memlet('A[i] -> 0')} + + # Test loop-to-map - should fail to apply + from dace.transformation.interstate import LoopToMap + assert sdfg.apply_transformations(LoopToMap) == 0 + + # Test reference-to-view - should fail to apply + result = Pipeline([ReferenceToView()]).apply_pass(sdfg, {}) + assert 'ReferenceToView' not in result or not result['ReferenceToView'] + + # Test correctness + A = np.random.rand(20) + ref = np.copy(A) + ref[-1] = 1 + sdfg(A=A) + assert np.allclose(ref, A) + + +@pytest.mark.parametrize('reftoview', (False, True)) +def test_reference_loop_internal_use(reftoview): + sdfg = _create_loop_reference_internal_use() + + # Test sources + sources = FindReferenceSources().apply_pass(sdfg, {}) + assert len(sources) == 1 # There is only one SDFG + sources = sources[0] + assert len(sources) == 1 + assert sources['ref'] == {dace.Memlet('A[i]')} + + if reftoview: + sdfg.simplify() + assert not any(isinstance(v, dace.data.Reference) for v in sdfg.arrays.values()) + + # Test correctness + A = np.random.rand(20) + ref = np.copy(A) + ref[:] = 1 + sdfg(A=A) + assert np.allclose(ref, A) + + +def test_reference_loop_nonfree_internal_use(): + sdfg = _create_loop_reference_nonfree_internal_use() + + # Test sources + sources = FindReferenceSources().apply_pass(sdfg, {}) + assert len(sources) == 1 # There is only one SDFG + sources = sources[0] + assert len(sources) == 1 + assert sources['ref'] == {dace.Memlet('A[i]')} + + # Test reference-to-view - should fail to apply + result = Pipeline([ReferenceToView()]).apply_pass(sdfg, {}) + assert 'ReferenceToView' not in result or not result['ReferenceToView'] + + # Test correctness + A = np.random.rand(20) + ref = np.copy(A) + ref[-1] = 1 + sdfg(A=A) + assert np.allclose(ref, A) + + if __name__ == '__main__': + test_unset_reference() test_reference_branch() test_reference_sources_pass() + test_reference_tasklet_assignment() + test_reference_tasklet_assignment_analysis() + test_reference_tasklet_assignment_stree() + test_reference_tasklet_assignment_reftoview() + test_twostate(False) + test_twostate(True) + test_multisubset(False) + test_multisubset(True) + test_scoped(False) + test_scoped(True) + test_scoped_empty_memlet(False) + test_scoped_empty_memlet(True) + test_reference_neighbors(False) + test_reference_neighbors(True) + test_reference_loop_nonfree() + test_reference_loop_internal_use(False) + test_reference_loop_internal_use(True) + test_reference_loop_nonfree_internal_use()