From bdecb25064b103bdc481cb895e106f2bb7ae12f7 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Fri, 20 Oct 2023 13:27:54 +0000 Subject: [PATCH 01/19] Bump urllib3 from 2.0.6 to 2.0.7 (#1400) Bumps [urllib3](https://github.com/urllib3/urllib3) from 2.0.6 to 2.0.7. - [Release notes](https://github.com/urllib3/urllib3/releases) - [Changelog](https://github.com/urllib3/urllib3/blob/main/CHANGES.rst) - [Commits](https://github.com/urllib3/urllib3/compare/2.0.6...2.0.7) --- updated-dependencies: - dependency-name: urllib3 dependency-type: direct:production ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: alexnick83 <31545860+alexnick83@users.noreply.github.com> --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 996449dbef..5f804e1b4c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -20,7 +20,7 @@ PyYAML==6.0 requests==2.31.0 six==1.16.0 sympy==1.9 -urllib3==2.0.6 +urllib3==2.0.7 websockets==11.0.3 Werkzeug==2.3.5 zipp==3.15.0 From af62440be7bd3a2756279288042b4c67de0b3411 Mon Sep 17 00:00:00 2001 From: Marcin Copik Date: Fri, 20 Oct 2023 23:11:44 +0200 Subject: [PATCH 02/19] Bugfixes and extended testing for Fortran SUM (#1390) * Fix incorrect generation of sum to loop code for Fortran frontend * Support passing array with no bounds in Fortran sum() * Add test case for Foftran sum * Fix bug in offset normalization and support Fortran SUM for arrays with offsets * Expand tests for array2loop in Fortran * Add more tests covering 2D sum in Fortran * Support Fortran sum for arrays without explicit dimension access declaration * Add more tests for Fortran sum over 2D arrays --------- Co-authored-by: acalotoiu <61420859+acalotoiu@users.noreply.github.com> --- dace/frontend/fortran/ast_transforms.py | 40 +++++- tests/fortran/array_to_loop_offset.py | 104 ++++++++++++++ tests/fortran/sum_to_loop_offset.py | 176 ++++++++++++++++++++++++ 3 files changed, 313 insertions(+), 7 deletions(-) create mode 100644 tests/fortran/sum_to_loop_offset.py diff --git a/dace/frontend/fortran/ast_transforms.py b/dace/frontend/fortran/ast_transforms.py index e2a7246aed..32744c5120 100644 --- a/dace/frontend/fortran/ast_transforms.py +++ b/dace/frontend/fortran/ast_transforms.py @@ -268,7 +268,7 @@ def visit_Execution_Part_Node(self, node: ast_internal_classes.Execution_Part_No ast_internal_classes.Var_Decl_Node( name="tmp_call_" + str(temp), type=res[i].type, - sizes=None, + sizes=None ) ])) newbody.append( @@ -284,7 +284,7 @@ def visit_Execution_Part_Node(self, node: ast_internal_classes.Execution_Part_No ast_internal_classes.Var_Decl_Node( name="tmp_call_" + str(temp), type=res[i].type, - sizes=None, + sizes=None ) ])) newbody.append( @@ -458,7 +458,11 @@ def visit_Execution_Part_Node(self, node: ast_internal_classes.Execution_Part_No if self.normalize_offsets: # Find the offset of a variable to which we are assigning - var_name = child.lval.name.name + var_name = "" + if isinstance(j, ast_internal_classes.Name_Node): + var_name = j.name + else: + var_name = j.name.name variable = self.scope_vars.get_var(child.parent, var_name) offset = variable.offsets[idx] @@ -737,8 +741,7 @@ def par_Decl_Range_Finder(node: ast_internal_classes.Array_Subscript_Node, count: int, newbody: list, scope_vars: ScopeVarsDeclarations, - declaration=True, - is_sum_to_loop=False): + declaration=True): """ Helper function 
for the transformation of array operations and sums to loops :param node: The AST to be transformed @@ -753,6 +756,7 @@ def par_Decl_Range_Finder(node: ast_internal_classes.Array_Subscript_Node, currentindex = 0 indices = [] + offsets = scope_vars.get_var(node.parent, node.name.name).offsets for idx, i in enumerate(node.indices): @@ -926,14 +930,36 @@ def visit_Execution_Part_Node(self, node: ast_internal_classes.Execution_Part_No current = child.lval val = child.rval - rvals = [i for i in mywalk(val) if isinstance(i, ast_internal_classes.Array_Subscript_Node)] + + rvals = [] + for i in mywalk(val): + if isinstance(i, ast_internal_classes.Call_Expr_Node) and i.name.name == '__dace_sum': + + for arg in i.args: + + # supports syntax SUM(arr) + if isinstance(arg, ast_internal_classes.Name_Node): + array_node = ast_internal_classes.Array_Subscript_Node(parent=arg.parent) + array_node.name = arg + + # If we access SUM(arr) where arr has many dimensions, + # We need to create a ParDecl_Node for each dimension + dims = len(self.scope_vars.get_var(node.parent, arg.name).sizes) + array_node.indices = [ast_internal_classes.ParDecl_Node(type='ALL')] * dims + + rvals.append(array_node) + + # supports syntax SUM(arr(:)) + if isinstance(arg, ast_internal_classes.Array_Subscript_Node): + rvals.append(arg) + if len(rvals) != 1: raise NotImplementedError("Only one array can be summed") val = rvals[0] rangeposrval = [] rangesrval = [] - par_Decl_Range_Finder(val, rangesrval, rangeposrval, self.count, newbody, self.scope_vars, False, True) + par_Decl_Range_Finder(val, rangesrval, rangeposrval, self.count, newbody, self.scope_vars, True) range_index = 0 body = ast_internal_classes.BinOp_Node(lval=current, diff --git a/tests/fortran/array_to_loop_offset.py b/tests/fortran/array_to_loop_offset.py index 43d01d9b6b..5042859f8c 100644 --- a/tests/fortran/array_to_loop_offset.py +++ b/tests/fortran/array_to_loop_offset.py @@ -112,8 +112,112 @@ def test_fortran_frontend_arr2loop_2d_offset(): for j in range(7,10): assert a[i-1, j-1] == i * 2 +def test_fortran_frontend_arr2loop_2d_offset2(): + """ + Tests that the generated array map correctly handles offsets. + """ + test_string = """ + PROGRAM index_offset_test + implicit none + double precision, dimension(5,7:9) :: d + CALL index_test_function(d) + end + + SUBROUTINE index_test_function(d) + double precision, dimension(5,7:9) :: d + + d(:,:) = 43 + + END SUBROUTINE index_test_function + """ + + # Now test to verify it executes correctly with no offset normalization + + sdfg = fortran_parser.create_sdfg_from_string(test_string, "index_offset_test", False) + sdfg.simplify(verbose=True) + sdfg.compile() + + assert len(sdfg.data('d').shape) == 2 + assert sdfg.data('d').shape[0] == 5 + assert sdfg.data('d').shape[1] == 3 + + a = np.full([5,9], 42, order="F", dtype=np.float64) + sdfg(d=a) + for i in range(1,6): + for j in range(7,10): + assert a[i-1, j-1] == 43 + + sdfg = fortran_parser.create_sdfg_from_string(test_string, "index_offset_test", True) + sdfg.simplify(verbose=True) + sdfg.compile() + + a = np.full([5,3], 42, order="F", dtype=np.float64) + sdfg(d=a) + for i in range(0,5): + for j in range(0,3): + assert a[i, j] == 43 + +def test_fortran_frontend_arr2loop_2d_offset3(): + """ + Tests that the generated array map correctly handles offsets. 
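+    Only the subregion d(2:4, 7:8) is written, so the test also verifies
+    that all elements outside of that subregion remain untouched.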
+ """ + test_string = """ + PROGRAM index_offset_test + implicit none + double precision, dimension(5,7:9) :: d + CALL index_test_function(d) + end + + SUBROUTINE index_test_function(d) + double precision, dimension(5,7:9) :: d + + d(2:4, 7:8) = 43 + + END SUBROUTINE index_test_function + """ + + # Now test to verify it executes correctly with no offset normalization + + sdfg = fortran_parser.create_sdfg_from_string(test_string, "index_offset_test", False) + sdfg.simplify(verbose=True) + sdfg.compile() + + assert len(sdfg.data('d').shape) == 2 + assert sdfg.data('d').shape[0] == 5 + assert sdfg.data('d').shape[1] == 3 + + a = np.full([5,9], 42, order="F", dtype=np.float64) + sdfg(d=a) + for i in range(2,4): + for j in range(7,9): + assert a[i-1, j-1] == 43 + for j in range(9,10): + assert a[i-1, j-1] == 42 + + for i in [1, 5]: + for j in range(7,10): + assert a[i-1, j-1] == 42 + + sdfg = fortran_parser.create_sdfg_from_string(test_string, "index_offset_test", True) + sdfg.simplify(verbose=True) + sdfg.compile() + + a = np.full([5,3], 42, order="F", dtype=np.float64) + sdfg(d=a) + for i in range(1,4): + for j in range(0,2): + assert a[i, j] == 43 + for j in range(2,3): + assert a[i, j] == 42 + + for i in [0, 4]: + for j in range(0,3): + assert a[i, j] == 42 + if __name__ == "__main__": test_fortran_frontend_arr2loop_1d_offset() test_fortran_frontend_arr2loop_2d_offset() + test_fortran_frontend_arr2loop_2d_offset2() + test_fortran_frontend_arr2loop_2d_offset3() test_fortran_frontend_arr2loop_without_offset() diff --git a/tests/fortran/sum_to_loop_offset.py b/tests/fortran/sum_to_loop_offset.py new file mode 100644 index 0000000000..e933589e0f --- /dev/null +++ b/tests/fortran/sum_to_loop_offset.py @@ -0,0 +1,176 @@ +# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. + +import numpy as np + +from dace.frontend.fortran import ast_transforms, fortran_parser + +def test_fortran_frontend_sum2loop_1d_without_offset(): + """ + Tests that the generated array map correctly handles offsets. + """ + test_string = """ + PROGRAM index_offset_test + implicit none + double precision, dimension(7) :: d + double precision, dimension(3) :: res + CALL index_test_function(d, res) + end + + SUBROUTINE index_test_function(d, res) + double precision, dimension(7) :: d + double precision, dimension(3) :: res + + res(1) = SUM(d(:)) + res(2) = SUM(d) + res(3) = SUM(d(2:6)) + + END SUBROUTINE index_test_function + """ + + # Now test to verify it executes correctly with no offset normalization + + sdfg = fortran_parser.create_sdfg_from_string(test_string, "index_offset_test", False) + sdfg.simplify(verbose=True) + sdfg.compile() + + size = 7 + d = np.full([size], 0, order="F", dtype=np.float64) + for i in range(size): + d[i] = i + 1 + res = np.full([3], 42, order="F", dtype=np.float64) + sdfg(d=d, res=res) + assert res[0] == (1 + size) * size / 2 + assert res[1] == (1 + size) * size / 2 + assert res[2] == (2 + size - 1) * (size - 2)/ 2 + +def test_fortran_frontend_sum2loop_1d_offset(): + """ + Tests that the generated array map correctly handles offsets. 
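+    The summed array has a non-default lower bound (2:6); full sums and a
+    partial sum over d(3:5) are checked against closed-form results.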
+ """ + test_string = """ + PROGRAM index_offset_test + implicit none + double precision, dimension(2:6) :: d + double precision, dimension(3) :: res + CALL index_test_function(d,res) + end + + SUBROUTINE index_test_function(d, res) + double precision, dimension(2:6) :: d + double precision, dimension(3) :: res + + res(1) = SUM(d) + res(2) = SUM(d(:)) + res(3) = SUM(d(3:5)) + + END SUBROUTINE index_test_function + """ + + # Now test to verify it executes correctly with no offset normalization + + sdfg = fortran_parser.create_sdfg_from_string(test_string, "index_offset_test", True) + sdfg.simplify(verbose=True) + sdfg.compile() + + size = 5 + d = np.full([size], 0, order="F", dtype=np.float64) + for i in range(size): + d[i] = i + 1 + res = np.full([3], 42, order="F", dtype=np.float64) + sdfg(d=d, res=res) + assert res[0] == (1 + size) * size / 2 + assert res[1] == (1 + size) * size / 2 + assert res[2] == (2 + size - 1) * (size - 2) / 2 + +def test_fortran_frontend_arr2loop_2d(): + """ + Tests that the generated array map correctly handles offsets. + """ + test_string = """ + PROGRAM index_offset_test + implicit none + double precision, dimension(5,3) :: d + double precision, dimension(4) :: res + CALL index_test_function(d,res) + end + + SUBROUTINE index_test_function(d, res) + double precision, dimension(5,3) :: d + double precision, dimension(4) :: res + + res(1) = SUM(d) + res(2) = SUM(d(:,:)) + res(3) = SUM(d(2:4, 2)) + res(4) = SUM(d(2:4, 2:3)) + + END SUBROUTINE index_test_function + """ + + # Now test to verify it executes correctly with no offset normalization + + sdfg = fortran_parser.create_sdfg_from_string(test_string, "index_offset_test", True) + sdfg.simplify(verbose=True) + sdfg.compile() + + sizes = [5, 3] + d = np.full(sizes, 42, order="F", dtype=np.float64) + cnt = 0 + for i in range(sizes[0]): + for j in range(sizes[1]): + d[i, j] = cnt + cnt += 1 + res = np.full([4], 42, order="F", dtype=np.float64) + sdfg(d=d, res=res) + assert res[0] == 105 + assert res[1] == 105 + assert res[2] == 21 + assert res[3] == 45 + +def test_fortran_frontend_arr2loop_2d_offset(): + """ + Tests that the generated array map correctly handles offsets. 
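+    The summed array has offsets in both dimensions (2:6, 7:10); full sums
+    and a partial sum over d(3:5, 8:9) are checked.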
+ """ + test_string = """ + PROGRAM index_offset_test + implicit none + double precision, dimension(2:6,7:10) :: d + double precision, dimension(3) :: res + CALL index_test_function(d,res) + end + + SUBROUTINE index_test_function(d, res) + double precision, dimension(2:6,7:10) :: d + double precision, dimension(3) :: res + + res(1) = SUM(d) + res(2) = SUM(d(:,:)) + res(3) = SUM(d(3:5, 8:9)) + + END SUBROUTINE index_test_function + """ + + # Now test to verify it executes correctly with no offset normalization + + sdfg = fortran_parser.create_sdfg_from_string(test_string, "index_offset_test", True) + sdfg.simplify(verbose=True) + sdfg.compile() + + sizes = [5, 4] + d = np.full(sizes, 42, order="F", dtype=np.float64) + cnt = 0 + for i in range(sizes[0]): + for j in range(sizes[1]): + d[i, j] = cnt + cnt += 1 + res = np.full([3], 42, order="F", dtype=np.float64) + sdfg(d=d, res=res) + assert res[0] == 190 + assert res[1] == 190 + assert res[2] == 57 + +if __name__ == "__main__": + + test_fortran_frontend_sum2loop_1d_without_offset() + test_fortran_frontend_sum2loop_1d_offset() + test_fortran_frontend_arr2loop_2d() + test_fortran_frontend_arr2loop_2d_offset() From 66913220ea600492db59cf8e536271b36c1554bd Mon Sep 17 00:00:00 2001 From: alexnick83 <31545860+alexnick83@users.noreply.github.com> Date: Sat, 21 Oct 2023 11:22:06 +0200 Subject: [PATCH 03/19] Option for utilizing GPU global memory (#1405) * Added option to change storage of non-transient data to GPU global memory. * Fixed typos. --- dace/transformation/auto/auto_optimize.py | 23 ++++++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) diff --git a/dace/transformation/auto/auto_optimize.py b/dace/transformation/auto/auto_optimize.py index 54dbc8d4ac..644df59e5c 100644 --- a/dace/transformation/auto/auto_optimize.py +++ b/dace/transformation/auto/auto_optimize.py @@ -515,11 +515,29 @@ def make_transients_persistent(sdfg: SDFG, return result +def apply_gpu_storage(sdfg: SDFG) -> None: + """ Changes the storage of the SDFG's input and output data to GPU global memory. """ + + written_scalars = set() + for state in sdfg.nodes(): + for node in state.data_nodes(): + desc = node.desc(sdfg) + if isinstance(desc, dt.Scalar) and not desc.transient and state.in_degree(node) > 0: + written_scalars.add(node.data) + + for name, desc in sdfg.arrays.items(): + if not desc.transient and desc.storage == dtypes.StorageType.Default: + if isinstance(desc, dt.Scalar) and not name in written_scalars: + continue + desc.storage = dtypes.StorageType.GPU_Global + + def auto_optimize(sdfg: SDFG, device: dtypes.DeviceType, validate: bool = True, validate_all: bool = False, - symbols: Dict[str, int] = None) -> SDFG: + symbols: Dict[str, int] = None, + use_gpu_storage: bool = False) -> SDFG: """ Runs a basic sequence of transformations to optimize a given SDFG to decent performance. In particular, performs the following: @@ -539,6 +557,7 @@ def auto_optimize(sdfg: SDFG, have been applied. :param validate_all: If True, validates the SDFG after every step. :param symbols: Optional dict that maps symbols (str/symbolic) to int/float + :param use_gpu_storage: If True, changes the storage of non-transient data to GPU global memory. :return: The optimized SDFG. :note: Operates in-place on the given SDFG. 
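+    :note: When use_gpu_storage is True, the storage change is applied before
+           the GPU transformations.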
:note: This function is still experimental and may harm correctness in @@ -565,6 +584,8 @@ def auto_optimize(sdfg: SDFG, # Apply GPU transformations and set library node implementations if device == dtypes.DeviceType.GPU: + if use_gpu_storage: + apply_gpu_storage(sdfg) sdfg.apply_gpu_transformations() sdfg.simplify() From 0f731d6c60fdbc26fa3963c6a4c7c58a24afeb9a Mon Sep 17 00:00:00 2001 From: Jan Kleine Date: Thu, 26 Oct 2023 18:25:58 +0200 Subject: [PATCH 04/19] Add tensor storage format abstraction (#1392) * Add tensor storage format abstraction Format abstraction is based on [https://doi.org/10.1145/3276493]. * Fix type signature from OrderedDict to Dict * Fix typos sefl and Singelton * Remove OrderedDict in favor of Dict * Replace |= with .update() for backwards compatibility * Fix serialization issues --- dace/data.py | 697 +++++++++++++++++++++++++++++++++ tests/sdfg/data/tensor_test.py | 131 +++++++ 2 files changed, 828 insertions(+) create mode 100644 tests/sdfg/data/tensor_test.py diff --git a/dace/data.py b/dace/data.py index 0a9858458b..199e7dabd4 100644 --- a/dace/data.py +++ b/dace/data.py @@ -1,8 +1,10 @@ # Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. +import aenum import copy as cp import ctypes import functools +from abc import ABC, abstractmethod from collections import OrderedDict from numbers import Number from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Union @@ -482,6 +484,701 @@ def __getitem__(self, s): if isinstance(s, list) or isinstance(s, tuple): return StructArray(self, tuple(s)) return StructArray(self, (s, )) + + +class TensorIterationTypes(aenum.AutoNumberEnum): + """ + Types of tensor iteration capabilities. + + Value (Coordinate Value Iteration) allows to directly iterate over + coordinates such as when using the Dense index type. + + Position (Coordinate Position Iteratation) iterates over coordinate + positions, at which the actual coordinates lie. This is for example the case + with a compressed index, in which the pos array enables one to iterate over + the positions in the crd array that hold the actual coordinates. + """ + Value = () + Position = () + + +class TensorAssemblyType(aenum.AutoNumberEnum): + """ + Types of possible assembly strategies for the individual indices. + + NoAssembly: Assembly is not possible as such. + + Insert: index allows inserting elements at random (e.g. Dense) + + Append: index allows appending to a list of existing coordinates. Depending + on append order, this affects whether the index is ordered or not. This + could be changed by sorting the index after assembly + """ + NoAssembly = () + Insert = () + Append = () + + +class TensorIndex(ABC): + """ + Abstract base class for tensor index implementations. + """ + + @property + @abstractmethod + def iteration_type(self) -> TensorIterationTypes: + """ + Iteration capability supported by this index. + + See TensorIterationTypes for reference. + """ + pass + + @property + @abstractmethod + def locate(self) -> bool: + """ + True if the index supports locate (aka random access), False otw. + """ + pass + + @property + @abstractmethod + def assembly(self) -> TensorAssemblyType: + """ + What assembly type is supported by the index. + + See TensorAssemblyType for reference. + """ + pass + + @property + @abstractmethod + def full(self) -> bool: + """ + True if the level is full, False otw. + + A level is considered full if it encompasses all valid coordinates along + the corresponding tensor dimension. 
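+
+        For example, a dense level spanning the entire dimension is full, while
+        a compressed level that stores only the nonzero coordinates is not.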
+        """
+        pass
+
+    @property
+    @abstractmethod
+    def ordered(self) -> bool:
+        """
+        True if the level is ordered, False otw.
+
+        A level is ordered when all coordinates that share the same ancestor are
+        ordered by increasing value (e.g. in typical CSR).
+        """
+        pass
+
+    @property
+    @abstractmethod
+    def unique(self) -> bool:
+        """
+        True if the coordinates in the level are unique, False otw.
+
+        A level is considered unique if no collection of coordinates that share
+        the same ancestor contains duplicates. In CSR this is True, in COO it is
+        not.
+        """
+        pass
+
+    @property
+    @abstractmethod
+    def branchless(self) -> bool:
+        """
+        True if the level doesn't branch, False otw.
+
+        A level is considered branchless if no coordinate has a sibling (another
+        coordinate with the same ancestor) and all coordinates in the parent level
+        have a child. In other words, there is a bijection between the coordinates
+        in this level and the parent level. An example of this is the Singleton
+        index level in the COO format.
+        """
+        pass
+
+    @property
+    @abstractmethod
+    def compact(self) -> bool:
+        """
+        True if the level is compact, False otw.
+
+        A level is compact if no two coordinates are separated by an unlabeled
+        node that does not encode a coordinate. An example of a compact level
+        can be found in CSR, while the DIA format's range and offset levels are
+        not compact (they have entries that would correspond to entries outside
+        the tensor's index range, e.g. column -1).
+        """
+        pass
+
+    @abstractmethod
+    def fields(self, lvl: int, dummy_symbol: symbolic.SymExpr) -> Dict[str, Data]:
+        """
+        Generates the fields needed for the index.
+
+        :returns: a Dict of fields that need to be present in the struct
+        """
+        pass
+
+
+    def to_json(self):
+        attrs = serialize.all_properties_to_json(self)
+
+        retdict = {"type": type(self).__name__, "attributes": attrs}
+
+        return retdict
+
+
+    @classmethod
+    def from_json(cls, json_obj, context=None):
+
+        # Selecting the proper subclass
+        if json_obj['type'] == "TensorIndexDense":
+            self = TensorIndexDense.__new__(TensorIndexDense)
+        elif json_obj['type'] == "TensorIndexCompressed":
+            self = TensorIndexCompressed.__new__(TensorIndexCompressed)
+        elif json_obj['type'] == "TensorIndexSingleton":
+            self = TensorIndexSingleton.__new__(TensorIndexSingleton)
+        elif json_obj['type'] == "TensorIndexRange":
+            self = TensorIndexRange.__new__(TensorIndexRange)
+        elif json_obj['type'] == "TensorIndexOffset":
+            self = TensorIndexOffset.__new__(TensorIndexOffset)
+        else:
+            raise TypeError(f"Invalid data type, got: {json_obj['type']}")
+
+        serialize.set_properties_from_json(self, json_obj['attributes'], context=context)
+
+        return self
+
+
+@make_properties
+class TensorIndexDense(TensorIndex):
+    """
+    Dense tensor index.
+
+    Levels of this type encode the coordinate in the interval [0, N), where
+    N is the size of the corresponding dimension. This level doesn't need any
+    index structure beyond the corresponding dimension size.
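+
+    In CSR, for example, the row dimension is encoded as a dense level.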
+ """ + + _ordered = Property(dtype=bool, default=False) + _unique = Property(dtype=bool) + + @property + def iteration_type(self) -> TensorIterationTypes: + return TensorIterationTypes.Value + + @property + def locate(self) -> bool: + return True + + @property + def assembly(self) -> TensorAssemblyType: + return TensorAssemblyType.Insert + + @property + def full(self) -> bool: + return True + + @property + def ordered(self) -> bool: + return self._ordered + + @property + def unique(self) -> bool: + return self._unique + + @property + def branchless(self) -> bool: + return False + + @property + def compact(self) -> bool: + return True + + def __init__(self, ordered: bool = True, unique: bool = True): + self._ordered = ordered + self._unique = unique + + def fields(self, lvl: int, dummy_symbol: symbolic.SymExpr) -> Dict[str, Data]: + return {} + + def __repr__(self) -> str: + s = "Dense" + + non_defaults = [] + if not self._ordered: + non_defaults.append("¬O") + if not self._unique: + non_defaults.append("¬U") + + if len(non_defaults) > 0: + s += f"({','.join(non_defaults)})" + + return s + + +@make_properties +class TensorIndexCompressed(TensorIndex): + """ + Tensor level that stores coordinates in segmented array. + + Levels of this type are compressed using a segented array. The pos array + holds the start and end positions of the segment in the crd (coordinate) + array that holds the child coordinates corresponding the parent. + """ + + _full = Property(dtype=bool, default=False) + _ordered = Property(dtype=bool, default=False) + _unique = Property(dtype=bool, default=False) + + @property + def iteration_type(self) -> TensorIterationTypes: + return TensorIterationTypes.Position + + @property + def locate(self) -> bool: + return False + + @property + def assembly(self) -> TensorAssemblyType: + return TensorAssemblyType.Append + + @property + def full(self) -> bool: + return self._full + + @property + def ordered(self) -> bool: + return self._ordered + + @property + def unique(self) -> bool: + return self._unique + + @property + def branchless(self) -> bool: + return False + + @property + def compact(self) -> bool: + return True + + def __init__(self, + full: bool = False, + ordered: bool = True, + unique: bool = True): + self._full = full + self._ordered = ordered + self._unique = unique + + def fields(self, lvl: int, dummy_symbol: symbolic.SymExpr) -> Dict[str, Data]: + return { + f"idx{lvl}_pos": dtypes.int32[dummy_symbol], # TODO (later) choose better length + f"idx{lvl}_crd": dtypes.int32[dummy_symbol], # TODO (later) choose better length + } + + def __repr__(self) -> str: + s = "Compressed" + + non_defaults = [] + if self._full: + non_defaults.append("F") + if not self._ordered: + non_defaults.append("¬O") + if not self._unique: + non_defaults.append("¬U") + + if len(non_defaults) > 0: + s += f"({','.join(non_defaults)})" + + return s + + +@make_properties +class TensorIndexSingleton(TensorIndex): + """ + Tensor index that encodes a single coordinate per parent coordinate. + + Levels of this type hold exactly one coordinate for every coordinate in the + parent level. An example can be seen in the COO format, where every + coordinate but the first is encoded in this manner. 
+ """ + + _full = Property(dtype=bool, default=False) + _ordered = Property(dtype=bool, default=False) + _unique = Property(dtype=bool, default=False) + + @property + def iteration_type(self) -> TensorIterationTypes: + return TensorIterationTypes.Position + + @property + def locate(self) -> bool: + return False + + @property + def assembly(self) -> TensorAssemblyType: + return TensorAssemblyType.Append + + @property + def full(self) -> bool: + return self._full + + @property + def ordered(self) -> bool: + return self._ordered + + @property + def unique(self) -> bool: + return self._unique + + @property + def branchless(self) -> bool: + return True + + @property + def compact(self) -> bool: + return True + + def __init__(self, + full: bool = False, + ordered: bool = True, + unique: bool = True): + self._full = full + self._ordered = ordered + self._unique = unique + + def fields(self, lvl: int, dummy_symbol: symbolic.SymExpr) -> Dict[str, Data]: + return { + f"idx{lvl}_crd": dtypes.int32[dummy_symbol], # TODO (later) choose better length + } + + def __repr__(self) -> str: + s = "Singleton" + + non_defaults = [] + if self._full: + non_defaults.append("F") + if not self._ordered: + non_defaults.append("¬O") + if not self._unique: + non_defaults.append("¬U") + + if len(non_defaults) > 0: + s += f"({','.join(non_defaults)})" + + return s + + +@make_properties +class TensorIndexRange(TensorIndex): + """ + Tensor index that encodes a interval of coordinates for every parent. + + The interval is computed from an offset for each parent together with the + tensor dimension size of this level (M) and the parent level (N) parents + corresponding tensor. Given the parent coordinate i, the level encodes the + range of coordinates between max(0, -offset[i]) and min(N, M - offset[i]). + """ + + _ordered = Property(dtype=bool, default=False) + _unique = Property(dtype=bool, default=False) + + @property + def iteration_type(self) -> TensorIterationTypes: + return TensorIterationTypes.Value + + @property + def locate(self) -> bool: + return False + + @property + def assembly(self) -> TensorAssemblyType: + return TensorAssemblyType.NoAssembly + + @property + def full(self) -> bool: + return False + + @property + def ordered(self) -> bool: + return self._ordered + + @property + def unique(self) -> bool: + return self._unique + + @property + def branchless(self) -> bool: + return False + + @property + def compact(self) -> bool: + return False + + def __init__(self, ordered: bool = True, unique: bool = True): + self._ordered = ordered + self._unique = unique + + def fields(self, lvl: int, dummy_symbol: symbolic.SymExpr) -> Dict[str, Data]: + return { + f"idx{lvl}_offset": dtypes.int32[dummy_symbol], # TODO (later) choose better length + } + + def __repr__(self) -> str: + s = "Range" + + non_defaults = [] + if not self._ordered: + non_defaults.append("¬O") + if not self._unique: + non_defaults.append("¬U") + + if len(non_defaults) > 0: + s += f"({','.join(non_defaults)})" + + return s + + +@make_properties +class TensorIndexOffset(TensorIndex): + """ + Tensor index that encodes the next coordinates as offset from parent. + + Given a parent coordinate i and an offset index k, the level encodes the + coordinate j = i + offset[k]. 
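+
+    In the DIA format, for example, the column coordinate is obtained by
+    adding the stored diagonal offset to the row coordinate.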
+    """
+
+    _ordered = Property(dtype=bool, default=False)
+    _unique = Property(dtype=bool, default=False)
+
+    @property
+    def iteration_type(self) -> TensorIterationTypes:
+        return TensorIterationTypes.Position
+
+    @property
+    def locate(self) -> bool:
+        return False
+
+    @property
+    def assembly(self) -> TensorAssemblyType:
+        return TensorAssemblyType.NoAssembly
+
+    @property
+    def full(self) -> bool:
+        return False
+
+    @property
+    def ordered(self) -> bool:
+        return self._ordered
+
+    @property
+    def unique(self) -> bool:
+        return self._unique
+
+    @property
+    def branchless(self) -> bool:
+        return True
+
+    @property
+    def compact(self) -> bool:
+        return False
+
+    def __init__(self, ordered: bool = True, unique: bool = True):
+        self._ordered = ordered
+        self._unique = unique
+
+    def fields(self, lvl: int, dummy_symbol: symbolic.SymExpr) -> Dict[str, Data]:
+        return {
+            f"idx{lvl}_offset": dtypes.int32[dummy_symbol],  # TODO (later) choose better length
+        }
+
+    def __repr__(self) -> str:
+        s = "Offset"
+
+        non_defaults = []
+        if not self._ordered:
+            non_defaults.append("¬O")
+        if not self._unique:
+            non_defaults.append("¬U")
+
+        if len(non_defaults) > 0:
+            s += f"({','.join(non_defaults)})"
+
+        return s
+
+
+@make_properties
+class Tensor(Structure):
+    """
+    Abstraction for Tensor storage format.
+
+    This abstraction is based on [https://doi.org/10.1145/3276493].
+    """
+
+    value_dtype = TypeClassProperty(default=dtypes.int32, choices=dtypes.Typeclasses)
+    tensor_shape = ShapeProperty(default=[])
+    indices = ListProperty(element_type=TensorIndex)
+    index_ordering = ListProperty(element_type=symbolic.SymExpr)
+    value_count = SymbolicProperty(default=0)
+
+    def __init__(
+            self,
+            value_dtype: dtypes.Typeclasses,
+            tensor_shape,
+            indices: List[Tuple[TensorIndex, Union[int, symbolic.SymExpr]]],
+            value_count: symbolic.SymExpr,
+            name: str,
+            transient: bool = False,
+            storage: dtypes.StorageType = dtypes.StorageType.Default,
+            location: Dict[str, str] = None,
+            lifetime: dtypes.AllocationLifetime = dtypes.AllocationLifetime.Scope,
+            debuginfo: dtypes.DebugInfo = None):
+        """
+        Constructor for Tensor storage format.
+
+        Below are examples of common matrix storage formats:
+
+        .. code-block:: python
+
+            M, N, nnz = (dace.symbol(s) for s in ('M', 'N', 'nnz'))
+
+            csr = dace.data.Tensor(
+                dace.float32,
+                (M, N),
+                [(dace.data.TensorIndexDense(), 0), (dace.data.TensorIndexCompressed(), 1)],
+                nnz,
+                "CSR_Matrix",
+            )
+
+            csc = dace.data.Tensor(
+                dace.float32,
+                (M, N),
+                [(dace.data.TensorIndexDense(), 1), (dace.data.TensorIndexCompressed(), 0)],
+                nnz,
+                "CSC_Matrix",
+            )
+
+            coo = dace.data.Tensor(
+                dace.float32,
+                (M, N),
+                [
+                    (dace.data.TensorIndexCompressed(unique=False), 0),
+                    (dace.data.TensorIndexSingleton(), 1),
+                ],
+                nnz,
+                "COO_Matrix",
+            )
+
+            num_diags = dace.symbol('num_diags')  # number of diagonals stored
+
+            diag = dace.data.Tensor(
+                dace.float32,
+                (M, N),
+                [
+                    (dace.data.TensorIndexDense(), num_diags),
+                    (dace.data.TensorIndexRange(), 0),
+                    (dace.data.TensorIndexOffset(), 1),
+                ],
+                nnz,
+                "DIA_Matrix",
+            )
+
+        Below you can find examples of common 3rd order tensor storage formats:
+
+        .. code-block:: python
+
+            I, J, K, nnz = (dace.symbol(s) for s in ('I', 'J', 'K', 'nnz'))
+
+            coo = dace.data.Tensor(
+                dace.float32,
+                (I, J, K),
+                [
+                    (dace.data.TensorIndexCompressed(unique=False), 0),
+                    (dace.data.TensorIndexSingleton(unique=False), 1),
+                    (dace.data.TensorIndexSingleton(), 2),
+                ],
+                nnz,
+                "COO_3D_Tensor",
+            )
+
+            csf = dace.data.Tensor(
+                dace.float32,
+                (I, J, K),
+                [
+                    (dace.data.TensorIndexCompressed(), 0),
+                    (dace.data.TensorIndexCompressed(), 1),
+                    (dace.data.TensorIndexCompressed(), 2),
+                ],
+                nnz,
+                "CSF_3D_Tensor",
+            )
+
+        :param value_dtype: data type of the explicitly stored values.
+        :param tensor_shape: logical shape of tensor (#rows, #cols, etc...)
+        :param indices:
+            a list of tuples, each tuple represents a level in the tensor
+            storage hierarchy, specifying the level's tensor index type, and the
+            corresponding dimension this level encodes (as index of the
+            tensor_shape tuple above). The order of the dimensions may differ
+            from the logical shape of the tensor, e.g. as seen in the CSC
+            format. If an index's dimension is unrelated to the tensor shape
+            (e.g. in diagonal format where the first index's dimension is the
+            number of diagonals stored), a symbol can be specified instead.
+        :param value_count: number of explicitly stored values.
+        :param name: name of resulting struct.
+        :param others: See Structure class for remaining arguments
+        """
+
+        self.value_dtype = value_dtype
+        self.tensor_shape = tensor_shape
+        self.value_count = value_count
+
+        indices, index_ordering = zip(*indices)
+        self.indices, self.index_ordering = list(indices), list(index_ordering)
+
+        num_dims = len(tensor_shape)
+        dimension_order = [idx for idx in self.index_ordering if isinstance(idx, int)]
+
+        # all tensor dimensions must occur exactly once in indices
+        if not sorted(dimension_order) == list(range(num_dims)):
+            raise TypeError((
+                f"All tensor dimensions must be referenced exactly once in "
+                f"tensor indices. (referenced dimensions: {dimension_order}; "
+                f"tensor dimensions: {list(range(num_dims))})"
+            ))
+
+        # assembling permanent and index specific fields
+        fields = dict(
+            order=Scalar(dtypes.int32),
+            dim_sizes=dtypes.int32[num_dims],
+            value_count=value_count,
+            values=dtypes.float32[value_count],
+        )
+
+        for (lvl, index) in enumerate(indices):
+            fields.update(index.fields(lvl, value_count))
+
+        super(Tensor, self).__init__(fields, name, transient, storage, location,
+                                     lifetime, debuginfo)
+
+    def __repr__(self):
+        return f"{self.name} (dtype: {self.value_dtype}, shape: {list(self.tensor_shape)}, indices: {self.indices})"
+
+    @staticmethod
+    def from_json(json_obj, context=None):
+        if json_obj['type'] != 'Tensor':
+            raise TypeError("Invalid data type")
+
+        # Create dummy object
+        tensor = Tensor.__new__(Tensor)
+        serialize.set_properties_from_json(tensor, json_obj, context=context)
+
+        return tensor
 
 
 @make_properties
diff --git a/tests/sdfg/data/tensor_test.py b/tests/sdfg/data/tensor_test.py
new file mode 100644
index 0000000000..06d3363a8b
--- /dev/null
+++ b/tests/sdfg/data/tensor_test.py
@@ -0,0 +1,131 @@
+# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved.
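+#
+# These tests exercise the Tensor storage-format abstraction in dace.data:
+# they build CSR, DIA and COO descriptors from TensorIndex levels, check the
+# generated struct fields, and read a CSR matrix through an SDFG end-to-end.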
+import dace +import numpy as np +import pytest + +from scipy import sparse + + +def test_read_csr_tensor(): + + M, N, nnz = (dace.symbol(s) for s in ('M', 'N', 'nnz')) + csr_obj = dace.data.Tensor( + dace.float32, + (M, N), + [(dace.data.TensorIndexDense(), 0), (dace.data.TensorIndexCompressed(), 1)], + nnz, + "CSR_Tensor") + + sdfg = dace.SDFG('tensor_csr_to_dense') + + sdfg.add_datadesc('A', csr_obj) + sdfg.add_array('B', [M, N], dace.float32) + + sdfg.add_view('vindptr', csr_obj.members['idx1_pos'].shape, csr_obj.members['idx1_pos'].dtype) + sdfg.add_view('vindices', csr_obj.members['idx1_crd'].shape, csr_obj.members['idx1_crd'].dtype) + sdfg.add_view('vdata', csr_obj.members['values'].shape, csr_obj.members['values'].dtype) + + state = sdfg.add_state() + + A = state.add_access('A') + B = state.add_access('B') + + indptr = state.add_access('vindptr') + indices = state.add_access('vindices') + data = state.add_access('vdata') + + state.add_edge(A, None, indptr, 'views', dace.Memlet.from_array('A.idx1_pos', csr_obj.members['idx1_pos'])) + state.add_edge(A, None, indices, 'views', dace.Memlet.from_array('A.idx1_crd', csr_obj.members['idx1_crd'])) + state.add_edge(A, None, data, 'views', dace.Memlet.from_array('A.values', csr_obj.members['values'])) + + ime, imx = state.add_map('i', dict(i='0:M')) + jme, jmx = state.add_map('idx', dict(idx='start:stop')) + jme.add_in_connector('start') + jme.add_in_connector('stop') + t = state.add_tasklet('indirection', {'j', '__val'}, {'__out'}, '__out[i, j] = __val') + + state.add_memlet_path(indptr, ime, jme, memlet=dace.Memlet(data='vindptr', subset='i'), dst_conn='start') + state.add_memlet_path(indptr, ime, jme, memlet=dace.Memlet(data='vindptr', subset='i+1'), dst_conn='stop') + state.add_memlet_path(indices, ime, jme, t, memlet=dace.Memlet(data='vindices', subset='idx'), dst_conn='j') + state.add_memlet_path(data, ime, jme, t, memlet=dace.Memlet(data='vdata', subset='idx'), dst_conn='__val') + state.add_memlet_path(t, jmx, imx, B, memlet=dace.Memlet(data='B', subset='0:M, 0:N', volume=1), src_conn='__out') + + func = sdfg.compile() + + rng = np.random.default_rng(42) + A = sparse.random(20, 20, density=0.1, format='csr', dtype=np.float32, random_state=rng) + B = np.zeros((20, 20), dtype=np.float32) + + inpA = csr_obj.dtype._typeclass.as_ctypes()(idx1_pos=A.indptr.__array_interface__['data'][0], + idx1_crd=A.indices.__array_interface__['data'][0], + values=A.data.__array_interface__['data'][0]) + + func(A=inpA, B=B, M=A.shape[0], N=A.shape[1], nnz=A.nnz) + ref = A.toarray() + + sdfg.save("./tensor.json") + + assert np.allclose(B, ref) + + +def test_csr_fields(): + + M, N, nnz = (dace.symbol(s) for s in ('M', 'N', 'nnz')) + + csr = dace.data.Tensor( + dace.float32, + (M, N), + [(dace.data.TensorIndexDense(), 0), (dace.data.TensorIndexCompressed(), 1)], + nnz, + "CSR_Matrix", + ) + + expected_fields = ["idx1_pos", "idx1_crd"] + assert all(key in csr.members.keys() for key in expected_fields) + + +def test_dia_fields(): + + M, N, nnz, num_diags = (dace.symbol(s) for s in ('M', 'N', 'nnz', 'num_diags')) + + diag = dace.data.Tensor( + dace.float32, + (M, N), + [ + (dace.data.TensorIndexDense(), num_diags), + (dace.data.TensorIndexRange(), 0), + (dace.data.TensorIndexOffset(), 1), + ], + nnz, + "DIA_Matrix", + ) + + expected_fields = ["idx1_offset", "idx2_offset"] + assert all(key in diag.members.keys() for key in expected_fields) + + +def test_coo_fields(): + + I, J, K, nnz = (dace.symbol(s) for s in ('I', 'J', 'K', 'nnz')) + + coo = dace.data.Tensor( + 
dace.float32, + (I, J, K), + [ + (dace.data.TensorIndexCompressed(unique=False), 0), + (dace.data.TensorIndexSingleton(unique=False), 1), + (dace.data.TensorIndexSingleton(), 2), + ], + nnz, + "COO_3D_Tensor", + ) + + expected_fields = ["idx0_pos", "idx0_crd", "idx1_crd", "idx2_crd"] + assert all(key in coo.members.keys() for key in expected_fields) + + +if __name__ == "__main__": + test_read_csr_tensor() + test_csr_fields() + test_dia_fields() + test_coo_fields() From 3ddd2cccf54e3812c08c3762cd3c4257d312b7e2 Mon Sep 17 00:00:00 2001 From: Jan Kleine Date: Mon, 30 Oct 2023 14:17:30 +0100 Subject: [PATCH 05/19] Remove eroneous file creation (#1411) --- tests/sdfg/data/tensor_test.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/sdfg/data/tensor_test.py b/tests/sdfg/data/tensor_test.py index 06d3363a8b..3057539f70 100644 --- a/tests/sdfg/data/tensor_test.py +++ b/tests/sdfg/data/tensor_test.py @@ -63,8 +63,6 @@ def test_read_csr_tensor(): func(A=inpA, B=B, M=A.shape[0], N=A.shape[1], nnz=A.nnz) ref = A.toarray() - sdfg.save("./tensor.json") - assert np.allclose(B, ref) From 69d4f3d05aa84c77a44add95a19f23320e27c909 Mon Sep 17 00:00:00 2001 From: matteonussbauemer Date: Mon, 30 Oct 2023 21:59:38 +0100 Subject: [PATCH 06/19] create new branch that only contains changes to subsets.py and tests --- dace/subsets.py | 268 ++++++++++++++++++++++++---- tests/subset_covers_precise_test.py | 161 +++++++++++++++++ 2 files changed, 399 insertions(+), 30 deletions(-) create mode 100644 tests/subset_covers_precise_test.py diff --git a/dace/subsets.py b/dace/subsets.py index f8b66a565d..f2a2072343 100644 --- a/dace/subsets.py +++ b/dace/subsets.py @@ -10,21 +10,52 @@ from dace.config import Config +def nng(expr): + # When dealing with set sizes, assume symbols are non-negative + try: + # TODO: Fix in symbol definition, not here + for sym in list(expr.free_symbols): + expr = expr.subs({sym: sp.Symbol(sym.name, nonnegative=True)}) + return expr + except AttributeError: # No free_symbols in expr + return expr + +def bounding_box_cover_exact(subset_a, subset_b) -> bool: + return all([(symbolic.simplify_ext(nng(rb)) <= symbolic.simplify_ext(nng(orb))) == True + and (symbolic.simplify_ext(nng(re)) >= symbolic.simplify_ext(nng(ore))) == True + for rb, re, orb, ore in zip(subset_a.min_element(), subset_a.max_element(), + subset_b.min_element(), subset_b.max_element())]) + +def bounding_box_symbolic_positive(subset_a, subset_b, approximation = False)-> bool: + min_elements_a = subset_a.min_element_approx() if approximation else subset_a.min_element() + max_elements_a = subset_a.max_element_approx() if approximation else subset_a.max_element() + min_elements_b = subset_b.min_element_approx() if approximation else subset_b.min_element() + max_elements_b = subset_b.max_element_approx() if approximation else subset_b.max_element() + + for rb, re, orb, ore in zip(min_elements_a, max_elements_a, + min_elements_b, max_elements_b): + # NOTE: We first test for equality, which always returns True or False. If the equality test returns + # False, then we test for less-equal and greater-equal, which may return an expression, leading to + # TypeError. This is a workaround for the case where two expressions are the same or equal and + # SymPy confirms this but fails to return True when testing less-equal and greater-equal. 
+ + # lower bound: first check whether symbolic positive condition applies + if not (len(rb.free_symbols) == 0 and len(orb.free_symbols) == 1): + if not (symbolic.simplify_ext(nng(rb)) == symbolic.simplify_ext(nng(orb)) or + symbolic.simplify_ext(nng(rb)) <= symbolic.simplify_ext(nng(orb))): + return False + # upper bound: first check whether symbolic positive condition applies + if not (len(re.free_symbols) == 1 and len(ore.free_symbols) == 0): + if not (symbolic.simplify_ext(nng(re)) == symbolic.simplify_ext(nng(ore)) or + symbolic.simplify_ext(nng(re)) >= symbolic.simplify_ext(nng(ore))): + return False + return True + class Subset(object): """ Defines a subset of a data descriptor. """ def covers(self, other): """ Returns True if this subset covers (using a bounding box) another subset. """ - def nng(expr): - # When dealing with set sizes, assume symbols are non-negative - try: - # TODO: Fix in symbol definition, not here - for sym in list(expr.free_symbols): - expr = expr.subs({sym: sp.Symbol(sym.name, nonnegative=True)}) - return expr - except AttributeError: # No free_symbols in expr - return expr - symbolic_positive = Config.get('optimizer', 'symbolic_positive') if not symbolic_positive: @@ -38,28 +69,65 @@ def nng(expr): else: try: - for rb, re, orb, ore in zip(self.min_element_approx(), self.max_element_approx(), - other.min_element_approx(), other.max_element_approx()): - # NOTE: We first test for equality, which always returns True or False. If the equality test returns - # False, then we test for less-equal and greater-equal, which may return an expression, leading to - # TypeError. This is a workaround for the case where two expressions are the same or equal and - # SymPy confirms this but fails to return True when testing less-equal and greater-equal. - - # lower bound: first check whether symbolic positive condition applies - if not (len(rb.free_symbols) == 0 and len(orb.free_symbols) == 1): - if not (symbolic.simplify_ext(nng(rb)) == symbolic.simplify_ext(nng(orb)) or - symbolic.simplify_ext(nng(rb)) <= symbolic.simplify_ext(nng(orb))): - return False - - # upper bound: first check whether symbolic positive condition applies - if not (len(re.free_symbols) == 1 and len(ore.free_symbols) == 0): - if not (symbolic.simplify_ext(nng(re)) == symbolic.simplify_ext(nng(ore)) or - symbolic.simplify_ext(nng(re)) >= symbolic.simplify_ext(nng(ore))): - return False + if not bounding_box_symbolic_positive(self, other, True): + return False except TypeError: return False return True + + def covers_precise(self, other): + """ Returns True if self contains all the elements in other. """ + + # If self does not cover other with a bounding box union, return false. 
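+        # The bounding-box check is only a necessary condition; the step and
+        # start checks below make the containment test exact. For example,
+        # Range "0:10:2" precisely covers Range "0:10:4" (every fourth point
+        # is also a second point), but not Range "0:10:3".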
+        symbolic_positive = Config.get('optimizer', 'symbolic_positive')
+        try:
+            bounding_box_cover = bounding_box_cover_exact(self, other) if symbolic_positive else bounding_box_symbolic_positive(self, other)
+            if not bounding_box_cover:
+                return False
+        except TypeError:
+            return False
+
+        try:
+            # if self is an index, no further distinction is needed
+            if isinstance(self, Indices):
+                return True
+
+            elif isinstance(self, Range):
+                # other is an index, so we need to check if the step of self is such that other is covered:
+                # self.start % self.step == other.index % self.step
+                if isinstance(other, Indices):
+                    try:
+                        return all(
+                            [(symbolic.simplify_ext(nng(start)) % symbolic.simplify_ext(nng(step)) ==
+                              symbolic.simplify_ext(nng(i)) % symbolic.simplify_ext(nng(step))) == True
+                             for (start, _, step), i in zip(self.ranges, other.indices)])
+                    except:
+                        return False
+                if isinstance(other, Range):
+                    # other is a range, so in every dimension self.step has to divide other.step and
+                    # self.start % self.step == other.start % other.step
+                    try:
+                        self_steps = [r[2] for r in self.ranges]
+                        other_steps = [r[2] for r in other.ranges]
+                        for start, step, ostart, ostep in zip(self.min_element(), self_steps, other.min_element(),
+                                                              other_steps):
+                            if not (ostep % step == 0 and
+                                    ((symbolic.simplify_ext(nng(start)) == symbolic.simplify_ext(nng(ostart))) or
+                                     (symbolic.simplify_ext(nng(start)) % symbolic.simplify_ext(
+                                         nng(step)) == symbolic.simplify_ext(nng(ostart)) % symbolic.simplify_ext(
+                                             nng(ostep))) == True)):
+                                return False
+                    except:
+                        return False
+                return True
+            # unknown subset type
+            else:
+                raise TypeError
+
+        except TypeError:
+            return False
+
     def __repr__(self):
         return '%s (%s)' % (type(self).__name__, self.__str__())
 
@@ -973,6 +1041,111 @@ def intersection(self, other: 'Indices'):
                 return self
         return None
 
+class Subsetlist(Subset):
+    """
+    Wrapper subset type that stores multiple Subsets in a list.
+    """
+
+    def __init__(self, subset):
+        self.subset_list: list[Subset] = []
+        if isinstance(subset, Subsetlist):
+            self.subset_list = subset.subset_list
+        elif isinstance(subset, list):
+            for subset in subset:
+                if not subset:
+                    break
+                if isinstance(subset, (Range, Indices)):
+                    self.subset_list.append(subset)
+                else:
+                    raise NotImplementedError
+        elif isinstance(subset, (Range, Indices)):
+            self.subset_list = [subset]
+
+    def covers(self, other):
+        """
+        Returns True if this Subsetlist covers another subset (using a bounding box).
+        If other is also a Subsetlist, True is returned when there is a subset
+        in self that covers every subset in other. If other is a different type
+        of subset, True is returned when some subset in self covers other.
+        """
+
+        if isinstance(other, Subsetlist):
+            for subset in self.subset_list:
+                # check if there is a subset in self that covers every subset in other
+                if all(subset.covers(s) for s in other.subset_list):
+                    return True
+            # return False if that's not the case for any of the subsets in self
+            return False
+        else:
+            return any(s.covers(other) for s in self.subset_list)
+
+    def covers_precise(self, other):
+        """
+        Returns True if this Subsetlist precisely covers another subset.
+        If other is also a Subsetlist, True is returned when some subset in
+        self precisely covers every subset in other.
If other is a different type of subset + true is returned when one of the subsets in self is equal to other + """ + + if isinstance(other, Subsetlist): + for subset in self.subset_list: + # check if ther is a subset in self that covers every subset in other + if all(subset.covers_precise(s) for s in other.subset_list): + return True + # return False if that's not the case for any of the subsets in self + return False + else: + return any(s.covers_precise(other) for s in self.subset_list) + + def __str__(self): + string = '' + for subset in self.subset_list: + if not string == '': + string += " " + string += subset.__str__() + return string + + def dims(self): + if not self.subset_list: + return 0 + return next(iter(self.subset_list)).dims() + + def union(self, other: Subset): + """In place union of self with another Subset""" + try: + if isinstance(other, Subsetlist): + self.subset_list += other.subset_list + elif isinstance(other, Indices) or isinstance(other, Range): + self.subset_list.append(other) + else: + raise TypeError + except TypeError: # cannot determine truth value of Relational + return None + + @property + def free_symbols(self) -> Set[str]: + result = set() + for subset in self.subset_list: + result |= subset.free_symbols + return result + + def replace(self, repl_dict): + for subset in self.subset_list: + subset.replace(repl_dict) + + def num_elements(self): + # TODO: write something more meaningful here + min = 0 + for subset in self.subset_list: + try: + if subset.num_elements() < min or min ==0: + min = subset.num_elements() + except: + continue + + return min + + def _union_special_cases(arb: symbolic.SymbolicType, brb: symbolic.SymbolicType, are: symbolic.SymbolicType, bre: symbolic.SymbolicType): @@ -1038,6 +1211,8 @@ def bounding_box_union(subset_a: Subset, subset_b: Subset) -> Range: return Range(result) + + def union(subset_a: Subset, subset_b: Subset) -> Subset: """ Compute the union of two Subset objects. If the subsets are not of the same type, degenerates to bounding-box @@ -1056,6 +1231,9 @@ def union(subset_a: Subset, subset_b: Subset) -> Subset: return subset_b elif subset_a is None and subset_b is None: raise TypeError('Both subsets cannot be None') + elif isinstance(subset_a, Subsetlist) or isinstance( + subset_b, Subsetlist): + return list_union(subset_a, subset_b) elif type(subset_a) != type(subset_b): return bounding_box_union(subset_a, subset_b) elif isinstance(subset_a, Indices): @@ -1066,13 +1244,43 @@ def union(subset_a: Subset, subset_b: Subset) -> Subset: # TODO(later): More involved Strided-Tiled Range union return bounding_box_union(subset_a, subset_b) else: - warnings.warn('Unrecognized Subset type %s in union, degenerating to' - ' bounding box' % type(subset_a).__name__) + warnings.warn( + 'Unrecognized Subset type %s in union, degenerating to' + ' bounding box' % type(subset_a).__name__) return bounding_box_union(subset_a, subset_b) except TypeError: # cannot determine truth value of Relational return None +def list_union(subset_a: Subset, subset_b: Subset) -> Subset: + """ + Returns the union of two Subset lists. + + :param subset_a: The first subset. + :param subset_b: The second subset. + :return: A Subsetlist object that contains all elements of subset_a and subset_b. 
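+
+    A minimal illustrative example (assuming two plain ranges)::
+
+        a = Range.from_string('0:5')
+        b = Range.from_string('10:15')
+        u = list_union(a, b)  # a Subsetlist wrapping both ranges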
+ """ + # TODO(later): Merge subsets in both lists if possible + try: + if subset_a is not None and subset_b is None: + return subset_a + elif subset_b is not None and subset_a is None: + return subset_b + elif subset_a is None and subset_b is None: + raise TypeError('Both subsets cannot be None') + elif type(subset_a) != type(subset_b): + if isinstance(subset_b, Subsetlist): + return Subsetlist(subset_b.subset_list.append(subset_a)) + else: + return Subsetlist(subset_a.subset_list.append(subset_b)) + elif isinstance(subset_a, Subsetlist): + return Subsetlist(subset_a.subset_list + subset_b.subset_list) + else: + return Subsetlist([subset_a, subset_b]) + + except TypeError: + return None + def intersects(subset_a: Subset, subset_b: Subset) -> Union[bool, None]: """ Returns True if two subsets intersect, False if they do not, or diff --git a/tests/subset_covers_precise_test.py b/tests/subset_covers_precise_test.py new file mode 100644 index 0000000000..793926ab1c --- /dev/null +++ b/tests/subset_covers_precise_test.py @@ -0,0 +1,161 @@ +# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. + +import pytest +import dace +from dace.subsets import Indices, Subset, Range +from dace.config import Config + + +def test_integer_overlap_no_cover(): + # two overlapping subsets, neither of them covering the other + subset1 = Range.from_string("0:10:1") + subset2 = Range.from_string("5:11:1") + + assert (subset1.covers_precise(subset2) is False) + assert (subset2.covers_precise(subset1) is False) + + subset1 = Range.from_string("0:10:1, 3:8:1") + subset2 = Range.from_string("5:11:1, 2:9:1") + assert (subset1.covers_precise(subset2) is False) + assert (subset2.covers_precise(subset1) is False) + + +def test_integer_bounding_box_cover_coprime_step(): + # bb of subset1 covers bb of subset2 but step sizes of the subsets are coprime + subset1 = Range.from_string("0:10:3") + subset2 = Range.from_string("0:10:2") + + assert (subset1.covers_precise(subset2) is False) + assert (subset2.covers_precise(subset1) is False) + + subset1 = Range.from_string("0:10:3, 5:10:2") + subset2 = Range.from_string("0:10:2, 5:10:4") + assert (subset1.covers_precise(subset2) is False) + assert (subset2.covers_precise(subset1) is False) + + subset1 = Range.from_string("0:10:3, 6:10:2") + subset2 = Range.from_string("0:10:2, 5:10:4") + assert (subset1.covers_precise(subset2) is False) + assert (subset2.covers_precise(subset1) is False) + + +def test_integer_same_step_different_start(): + subset1 = Range.from_string("0:10:3") + subset2 = Range.from_string("1:10:3") + + assert (subset1.covers_precise(subset2) is False) + + +def test_integer_bounding_box_symbolic_step(): + + subset1 = Range.from_string("0:20:s") + subset2 = Range.from_string("0:10:s") + subset3 = Range.from_string("0:10:2 * s") + + assert (subset1.covers_precise(subset2)) + assert (subset1.covers_precise(subset3)) + assert (subset3.covers_precise(subset1) is False) + assert (subset3.covers_precise(subset2) is False) + + subset1 = Range.from_string("0:20:s, 30:50:k") + subset2 = Range.from_string("0:10:s, 40:50:k") + assert (subset1.covers_precise(subset2) is False) + + +def test_symbolic_boundaries(): + + subset1 = Range.from_string("N:M:1") + subset2 = Range.from_string("N:M:2") + assert (subset1.covers_precise(subset2)) + assert (subset2.covers_precise(subset1) is False) + + subset1 = Range.from_string("N + 1:M:1") + subset2 = Range.from_string("N:M:2") + assert (subset1.covers_precise(subset2) is False) + assert 
(subset2.covers_precise(subset1) is False) + + subset1 = Range.from_string("-N:M:1") + subset2 = Range.from_string("N:M:2") + assert (subset1.covers_precise(subset2) is False) + assert (subset2.covers_precise(subset1) is False) + + +def test_symbolic_boundaries_not_symbolic_positive(): + Config.set('optimizer', 'symbolic_positive', value=False) + + subset1 = Range.from_string("N:M:1") + subset2 = Range.from_string("N:M:2") + assert (subset1.covers_precise(subset2)) + assert (subset2.covers_precise(subset1) is False) + + subset1 = Range.from_string("N + 1:M:1") + subset2 = Range.from_string("N:M:2") + assert (subset1.covers_precise(subset2) is False) + assert (subset2.covers_precise(subset1) is False) + + subset1 = Range.from_string("-N:M:1") + subset2 = Range.from_string("N:M:2") + assert (subset1.covers_precise(subset2) is False) + assert (subset2.covers_precise(subset1) is False) + + +def test_range_indices(): + subset1 = Indices.from_string('0') + subset2 = Range.from_string('0:2:1') + assert (subset2.covers_precise(subset1)) + assert (subset1.covers_precise(subset2) is False) + subset1 = Indices.from_string('0') + subset2 = Range.from_string('0:1:1') + assert (subset2.covers_precise(subset1)) + assert (subset1.covers_precise(subset2)) + subset1 = Indices.from_string('0, 1') + subset2 = Range.from_string('0:2:1, 2:4:1') + assert (subset2.covers_precise(subset1) is False) + assert (subset1.covers_precise(subset2) is False) + +def test_index_index(): + subset1 = Indices.from_string('1') + subset2 = Indices.from_string('1') + assert (subset2.covers_precise(subset1)) + assert (subset1.covers_precise(subset2)) + subset1 = Indices.from_string('1') + subset2 = Indices.from_string('2') + assert (subset2.covers_precise(subset1) is False) + assert (subset1.covers_precise(subset2) is False) + subset1 = Indices.from_string('1, 2') + subset2 = Indices.from_string('1, 2') + assert (subset2.covers_precise(subset1)) + assert (subset1.covers_precise(subset2)) + subset1 = Indices.from_string('2, 1') + subset2 = Indices.from_string('1, 2') + assert (subset2.covers_precise(subset1) is False) + assert (subset1.covers_precise(subset2) is False) + subset1 = Indices.from_string('i') + subset2 = Indices.from_string('j') + assert (subset2.covers_precise(subset1) is False) + assert (subset1.covers_precise(subset2) is False) + subset1 = Indices.from_string('i') + subset2 = Indices.from_string('i') + assert (subset2.covers_precise(subset1)) + assert (subset1.covers_precise(subset2)) + subset1 = Indices.from_string('i, j') + subset2 = Indices.from_string('i, k') + assert (subset2.covers_precise(subset1) is False) + assert (subset1.covers_precise(subset2) is False) + subset1 = Indices.from_string('i, j') + subset2 = Indices.from_string('i, j') + assert (subset2.covers_precise(subset1)) + assert (subset1.covers_precise(subset2)) + + + + +if __name__ == "__main__": + test_integer_overlap_no_cover() + test_integer_bounding_box_cover_coprime_step() + test_integer_same_step_different_start() + test_integer_bounding_box_symbolic_step() + test_symbolic_boundaries() + test_symbolic_boundaries_not_symbolic_positive() + test_range_indices() + test_index_index() From c1935b6b8995404240b82ebd548d958d6ab68502 Mon Sep 17 00:00:00 2001 From: matteonussbauemer Date: Mon, 30 Oct 2023 22:15:43 +0100 Subject: [PATCH 07/19] formatting --- tests/subset_covers_precise_test.py | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/tests/subset_covers_precise_test.py b/tests/subset_covers_precise_test.py 
index 793926ab1c..644cfa20ee 100644 --- a/tests/subset_covers_precise_test.py +++ b/tests/subset_covers_precise_test.py @@ -7,7 +7,9 @@ def test_integer_overlap_no_cover(): - # two overlapping subsets, neither of them covering the other + """ + two overlapping subsets, neither of them covering the other + """ subset1 = Range.from_string("0:10:1") subset2 = Range.from_string("5:11:1") @@ -21,7 +23,9 @@ def test_integer_overlap_no_cover(): def test_integer_bounding_box_cover_coprime_step(): - # bb of subset1 covers bb of subset2 but step sizes of the subsets are coprime + """ + boundingbox of subset1 covers bb of subset2 but step sizes of the subsets are coprime + """ subset1 = Range.from_string("0:10:3") subset2 = Range.from_string("0:10:2") @@ -47,7 +51,6 @@ def test_integer_same_step_different_start(): def test_integer_bounding_box_symbolic_step(): - subset1 = Range.from_string("0:20:s") subset2 = Range.from_string("0:10:s") subset3 = Range.from_string("0:10:2 * s") @@ -63,7 +66,6 @@ def test_integer_bounding_box_symbolic_step(): def test_symbolic_boundaries(): - subset1 = Range.from_string("N:M:1") subset2 = Range.from_string("N:M:2") assert (subset1.covers_precise(subset2)) @@ -104,10 +106,12 @@ def test_range_indices(): subset2 = Range.from_string('0:2:1') assert (subset2.covers_precise(subset1)) assert (subset1.covers_precise(subset2) is False) + subset1 = Indices.from_string('0') subset2 = Range.from_string('0:1:1') assert (subset2.covers_precise(subset1)) assert (subset1.covers_precise(subset2)) + subset1 = Indices.from_string('0, 1') subset2 = Range.from_string('0:2:1, 2:4:1') assert (subset2.covers_precise(subset1) is False) @@ -118,30 +122,37 @@ def test_index_index(): subset2 = Indices.from_string('1') assert (subset2.covers_precise(subset1)) assert (subset1.covers_precise(subset2)) + subset1 = Indices.from_string('1') subset2 = Indices.from_string('2') assert (subset2.covers_precise(subset1) is False) assert (subset1.covers_precise(subset2) is False) + subset1 = Indices.from_string('1, 2') subset2 = Indices.from_string('1, 2') assert (subset2.covers_precise(subset1)) assert (subset1.covers_precise(subset2)) + subset1 = Indices.from_string('2, 1') subset2 = Indices.from_string('1, 2') assert (subset2.covers_precise(subset1) is False) assert (subset1.covers_precise(subset2) is False) + subset1 = Indices.from_string('i') subset2 = Indices.from_string('j') assert (subset2.covers_precise(subset1) is False) assert (subset1.covers_precise(subset2) is False) + subset1 = Indices.from_string('i') subset2 = Indices.from_string('i') assert (subset2.covers_precise(subset1)) assert (subset1.covers_precise(subset2)) + subset1 = Indices.from_string('i, j') subset2 = Indices.from_string('i, k') assert (subset2.covers_precise(subset1) is False) assert (subset1.covers_precise(subset2) is False) + subset1 = Indices.from_string('i, j') subset2 = Indices.from_string('i, j') assert (subset2.covers_precise(subset1)) From ecbca2d990272fce01e8c3d1fdc67a8d2984f462 Mon Sep 17 00:00:00 2001 From: matteonussbauemer Date: Tue, 31 Oct 2023 12:29:08 +0100 Subject: [PATCH 08/19] rename Subsetlist to SubsetUnion (cherry picked from commit e75e782b86fa6476af84ec59d878624e79369a18) --- dace/properties.py | 2 +- dace/subsets.py | 26 +++++++++++++------------- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/dace/properties.py b/dace/properties.py index 44f8b4fbcc..e02a54ad1f 100644 --- a/dace/properties.py +++ b/dace/properties.py @@ -1153,7 +1153,7 @@ def allow_none(self): def __set__(self, obj, 
val):
         if isinstance(val, str):
             val = self.from_string(val)
-        if (val is not None and not isinstance(val, sbs.Range) and not isinstance(val, sbs.Indices)):
+        if (val is not None and not isinstance(val, sbs.Range) and not isinstance(val, sbs.Indices) and not isinstance(val, sbs.SubsetUnion)):
             raise TypeError("Subset property must be either Range or Indices: got {}".format(type(val).__name__))
         super(SubsetProperty, self).__set__(obj, val)
 
diff --git a/dace/subsets.py b/dace/subsets.py
index f2a2072343..f53520c5aa 100644
--- a/dace/subsets.py
+++ b/dace/subsets.py
@@ -1041,14 +1041,14 @@ def intersection(self, other: 'Indices'):
             return self
         return None
 
-class Subsetlist(Subset):
+class SubsetUnion(Subset):
     """
     Wrapper subset type that stores multiple Subsets in a list.
     """
 
     def __init__(self, subset):
         self.subset_list: list[Subset] = []
-        if isinstance(subset, Subsetlist):
+        if isinstance(subset, SubsetUnion):
             self.subset_list = subset.subset_list
         elif isinstance(subset, list):
             for subset in subset:
@@ -1069,7 +1069,7 @@ def covers(self, other):
         true is returned when one of the subsets in self is equal to other.
 
         """
-        if isinstance(other, Subsetlist):
+        if isinstance(other, SubsetUnion):
             for subset in self.subset_list:
                 # check if there is a subset in self that covers every subset in other
                 if all(subset.covers(s) for s in other.subset_list):
@@ -1087,7 +1087,7 @@ def covers_precise(self, other):
         true is returned when one of the subsets in self is equal to other
 
         """
-        if isinstance(other, Subsetlist):
+        if isinstance(other, SubsetUnion):
             for subset in self.subset_list:
                 # check if there is a subset in self that covers every subset in other
                 if all(subset.covers_precise(s) for s in other.subset_list):
@@ -1113,7 +1113,7 @@ def dims(self):
     def union(self, other: Subset):
         """In place union of self with another Subset"""
         try:
-            if isinstance(other, Subsetlist):
+            if isinstance(other, SubsetUnion):
                 self.subset_list += other.subset_list
             elif isinstance(other, Indices) or isinstance(other, Range):
                 self.subset_list.append(other)
@@ -1231,8 +1231,8 @@ def union(subset_a: Subset, subset_b: Subset) -> Subset:
             return subset_b
         elif subset_a is None and subset_b is None:
             raise TypeError('Both subsets cannot be None')
-        elif isinstance(subset_a, Subsetlist) or isinstance(
-                subset_b, Subsetlist):
+        elif isinstance(subset_a, SubsetUnion) or isinstance(
+                subset_b, SubsetUnion):
             return list_union(subset_a, subset_b)
         elif type(subset_a) != type(subset_b):
             return bounding_box_union(subset_a, subset_b)
@@ -1269,14 +1269,14 @@ def list_union(subset_a: Subset, subset_b: Subset) -> Subset:
         elif subset_a is None and subset_b is None:
             raise TypeError('Both subsets cannot be None')
         elif type(subset_a) != type(subset_b):
-            if isinstance(subset_b, Subsetlist):
-                return Subsetlist(subset_b.subset_list + [subset_a])
+            if isinstance(subset_b, SubsetUnion):
+                return SubsetUnion(subset_b.subset_list + [subset_a])
             else:
-                return Subsetlist(subset_a.subset_list + [subset_b])
+                return SubsetUnion(subset_a.subset_list + [subset_b])
-        elif isinstance(subset_a, Subsetlist):
-            return Subsetlist(subset_a.subset_list + subset_b.subset_list)
+        elif isinstance(subset_a, SubsetUnion):
+            return SubsetUnion(subset_a.subset_list + subset_b.subset_list)
         else:
-            return Subsetlist([subset_a, subset_b])
+            return SubsetUnion([subset_a, subset_b])
 
     except TypeError:
         return None
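To see what the corrected mixed-type branch of list_union produces, here is a small sketch; it assumes only that SubsetUnion's constructor accepts a list of subsets, as its __init__ above suggests:

    from dace.subsets import Range, SubsetUnion, list_union

    u = SubsetUnion([Range.from_string("0:5:1")])
    v = list_union(u, Range.from_string("10:15:1"))
    # With list concatenation (list.append returns None, so the earlier code
    # would have built a union from None), both subsets are preserved:
    assert isinstance(v, SubsetUnion)
    assert len(v.subset_list) == 2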
From 3c10cb126ee68b2558f636a73aedf71f4863a125 Mon Sep 17 00:00:00 2001
From: matteonussbauemer
Date: Tue, 31 Oct 2023 12:29:08 +0100
Subject: [PATCH 09/19] rename occurrences in comments

(cherry picked from commit 1301c3a4ae6d4e634e0bdc38f94bfcf1ff677c88)
---
 dace/subsets.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/dace/subsets.py b/dace/subsets.py
index f53520c5aa..068b330a07 100644
--- a/dace/subsets.py
+++ b/dace/subsets.py
@@ -1063,8 +1063,8 @@ def __init__(self, subset):
 
     def covers(self, other):
         """
-        Returns True if this Subsetlist covers another subset (using a bounding box).
-        If other is another SubsetList then self and other will
+        Returns True if this SubsetUnion covers another subset (using a bounding box).
+        If other is another SubsetUnion then self and other will
         only return true if self is other. If other is a different type of subset
         true is returned when one of the subsets in self is equal to other.
 
@@ -1081,8 +1081,8 @@ def covers(self, other):
 
     def covers_precise(self, other):
         """
-        Returns True if this Subsetlist covers another
-        subset. If other is another SubsetList then self and other will
+        Returns True if this SubsetUnion covers another
+        subset. If other is another SubsetUnion then self and other will
         only return true if self is other. If other is a different type of subset
         true is returned when one of the subsets in self is equal to other
 
@@ -1258,7 +1258,7 @@ def list_union(subset_a: Subset, subset_b: Subset) -> Subset:
 
     :param subset_a: The first subset.
     :param subset_b: The second subset.
-    :return: A Subsetlist object that contains all elements of subset_a and subset_b.
+    :return: A SubsetUnion object that contains all elements of subset_a and subset_b.
     """
     # TODO(later): Merge subsets in both lists if possible
     try:

From 6965b96aed93445d15f7a45720a27b3f310bc053 Mon Sep 17 00:00:00 2001
From: matteonussbauemer
Date: Tue, 31 Oct 2023 18:57:09 +0100
Subject: [PATCH 10/19] upgrade sympy to 1.12

---
 setup.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/setup.py b/setup.py
index a0ac2e2d49..cd5189437e 100644
--- a/setup.py
+++ b/setup.py
@@ -73,7 +73,7 @@
     },
     include_package_data=True,
     install_requires=[
-        'numpy', 'networkx >= 2.5', 'astunparse', 'sympy<=1.9', 'pyyaml', 'ply', 'websockets', 'requests', 'flask',
+        'numpy', 'networkx >= 2.5', 'astunparse', 'sympy>=1.12', 'pyyaml', 'ply', 'websockets', 'requests', 'flask',
         'fparser >= 0.1.3', 'aenum >= 3.1', 'dataclasses; python_version < "3.7"', 'dill',
         'pyreadline;platform_system=="Windows"', 'typing-compat; python_version < "3.8"'
     ] + cmake_requires,

From 9ff33a709b4d90d515b69975802debabc6a9d1ff Mon Sep 17 00:00:00 2001
From: Christos Kotsalos
Date: Wed, 1 Nov 2023 19:50:05 +0100
Subject: [PATCH 11/19] Fix for VS Code debug console: view opens sdfg in VS Code and not in browser (#1419)

* Fix for VS Code debug console: view opens sdfg in VS Code and not in browser

* Fix for VS Code debug console: view opens sdfg in VS Code and not in browser

---------

Co-authored-by: Christos Kotsalos
---
 dace/cli/sdfv.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/dace/cli/sdfv.py b/dace/cli/sdfv.py
index 3be8e1ca45..c0ff3da36d 100644
--- a/dace/cli/sdfv.py
+++ b/dace/cli/sdfv.py
@@ -36,7 +36,11 @@ def view(sdfg: dace.SDFG, filename: Optional[Union[str, int]] = None):
     """
     # If vscode is open, try to open it inside vscode
     if filename is None:
-        if 'VSCODE_IPC_HOOK_CLI' in os.environ or 'VSCODE_GIT_IPC_HANDLE' in os.environ:
+        if (
+            'VSCODE_IPC_HOOK' in os.environ
+            or 'VSCODE_IPC_HOOK_CLI' in os.environ
+            or 'VSCODE_GIT_IPC_HANDLE' in os.environ
+        ):
             filename = tempfile.mktemp(suffix='.sdfg')
             sdfg.save(filename)
             os.system(f'code {filename}')
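The condition added to sdfv.py above amounts to an environment-variable probe for VS Code; the helper below (a hypothetical name, not part of the patch) summarizes the same check:

    import os

    def running_inside_vscode() -> bool:
        # Each of these variables is set in some VS Code context
        # (window IPC, CLI/integrated terminal, or git integration).
        return any(var in os.environ
                   for var in ('VSCODE_IPC_HOOK', 'VSCODE_IPC_HOOK_CLI', 'VSCODE_GIT_IPC_HANDLE'))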
From bd7a82b9a1f46e139b0a350afa1e40ee71c56c3f Mon Sep 17 00:00:00 2001
From: matteonussbauemer
Date: Wed, 1 Nov 2023 21:33:09 +0100
Subject: [PATCH 12/19] Annotate tests with expected outcome

---
 tests/subset_covers_precise_test.py | 73 ++++++++++++++++++++++-------
 1 file changed, 55 insertions(+), 18 deletions(-)

diff --git a/tests/subset_covers_precise_test.py b/tests/subset_covers_precise_test.py
index 644cfa20ee..185932ab53 100644
--- a/tests/subset_covers_precise_test.py
+++ b/tests/subset_covers_precise_test.py
@@ -1,14 +1,16 @@
 # Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved.
 
 import pytest
+
 import dace
-from dace.subsets import Indices, Subset, Range
 from dace.config import Config
+from dace.subsets import Indices, Range
 
 
-def test_integer_overlap_no_cover():
+def test_integer_overlap_same_step_no_cover():
     """
-    two overlapping subsets, neither of them covering the other
+    Tests ranges with overlapping bounding boxes, neither of them covering the other.
+    The ranges have the same step size, so covers_precise should return False.
     """
     subset1 = Range.from_string("0:10:1")
     subset2 = Range.from_string("5:11:1")
@@ -16,15 +18,16 @@ def test_integer_overlap_no_cover():
     assert (subset1.covers_precise(subset2) is False)
     assert (subset2.covers_precise(subset1) is False)
 
-    subset1 = Range.from_string("0:10:1, 3:8:1")
-    subset2 = Range.from_string("5:11:1, 2:9:1")
+    subset1 = Range.from_string("0:10:2")
+    subset2 = Range.from_string("2:11:1")
     assert (subset1.covers_precise(subset2) is False)
     assert (subset2.covers_precise(subset1) is False)
 
 
 def test_integer_bounding_box_cover_coprime_step():
     """
-    boundingbox of subset1 covers bb of subset2 but step sizes of the subsets are coprime
+    Tests ranges where the bounding box of subset1 covers the bounding box of subset2, but
+    the step sizes of the subsets are coprime, so subset1 does not cover subset2.
     """
     subset1 = Range.from_string("0:10:3")
     subset2 = Range.from_string("0:10:2")
@@ -44,6 +47,11 @@ def test_integer_bounding_box_cover_coprime_step():
 
 
 def test_integer_same_step_different_start():
+    """
+    Tests ranges where the bounding box of subset1 covers the bounding box of subset2,
+    but since subset2 starts at an offset that is not a multiple of subset1's step size,
+    it is not contained in subset1.
+    """
     subset1 = Range.from_string("0:10:3")
     subset2 = Range.from_string("1:10:3")
 
@@ -51,6 +59,14 @@ def test_integer_same_step_different_start():
 
 
 def test_integer_bounding_box_symbolic_step():
+    """
+    Tests ranges where the step is symbolic but the start and end are not.
+    For 2 subsets s1 and s2 where s1's start is equal to s2's start and both subsets' step
+    sizes are symbolic, s1.covers_precise(s2) should return true iff s2's step size is
+    a multiple of s1's step size.
+    For 2 subsets s1 and s2 where s1's start is not equal to s2's start and both subsets' step
+    sizes are symbolic, s1.covers_precise(s2) should return false.
+    """
     subset1 = Range.from_string("0:20:s")
     subset2 = Range.from_string("0:10:s")
     subset3 = Range.from_string("0:10:2 * s")
@@ -60,12 +76,17 @@ def test_integer_bounding_box_symbolic_step():
     assert (subset3.covers_precise(subset1) is False)
     assert (subset3.covers_precise(subset2) is False)
 
-    subset1 = Range.from_string("0:20:s, 30:50:k")
-    subset2 = Range.from_string("0:10:s, 40:50:k")
+    subset1 = Range.from_string("30:50:k")
+    subset2 = Range.from_string("40:50:k")
     assert (subset1.covers_precise(subset2) is False)
 
 
-def test_symbolic_boundaries():
+def test_ranges_symbolic_boundaries():
+    """
+    Tests where the boundaries of ranges are symbolic.
+    The function subset1.covers_precise(subset2) should return true only when the
+    start, end, and step size of subset1 are multiples of those in subset2.
+    """
     subset1 = Range.from_string("N:M:1")
     subset2 = Range.from_string("N:M:2")
     assert (subset1.covers_precise(subset2))
@@ -83,6 +104,9 @@ def test_symbolic_boundaries():
 
 
 def test_symbolic_boundaries_not_symbolic_positive():
+    """
+    Tests from test_ranges_symbolic_boundaries with the symbolic_positive flag deactivated.
+    """
     Config.set('optimizer', 'symbolic_positive', value=False)
 
@@ -102,22 +126,35 @@ def test_symbolic_boundaries_not_symbolic_positive():
 
 
 def test_range_indices():
+    """
+    Tests the handling of indices covering ranges and vice versa.
+    Given a range r and indices i:
+    If r's bounding box covers i, r.covers_precise(i) should return true iff
+    i is covered by the step of r.
+    i.covers_precise(r) should return true iff r.start == r.end == i.
+    If i is not in r's bounding box, i.covers_precise(r) and r.covers_precise(i)
+    should return false.
+    """
-    subset1 = Indices.from_string('0')
+    subset1 = Indices.from_string('1')
     subset2 = Range.from_string('0:2:1')
     assert (subset2.covers_precise(subset1))
     assert (subset1.covers_precise(subset2) is False)
 
-    subset1 = Indices.from_string('0')
-    subset2 = Range.from_string('0:1:1')
-    assert (subset2.covers_precise(subset1))
-    assert (subset1.covers_precise(subset2))
+    subset1 = Indices.from_string('3')
+    subset2 = Range.from_string('0:4:2')
+    assert (subset2.covers_precise(subset1) is False)
+    assert (subset1.covers_precise(subset2) is False)
 
-    subset1 = Indices.from_string('0, 1')
-    subset2 = Range.from_string('0:2:1, 2:4:1')
+    subset1 = Indices.from_string('3')
+    subset2 = Range.from_string('0:2:1')
     assert (subset2.covers_precise(subset1) is False)
     assert (subset1.covers_precise(subset2) is False)
 
 
 def test_index_index():
+    """
+    Tests the handling of indices covering indices.
+    Given two indices i1 and i2, i1.covers_precise(i2) should return true iff i1 == i2.
+    """
     subset1 = Indices.from_string('1')
     subset2 = Indices.from_string('1')
     assert (subset2.covers_precise(subset1))
@@ -162,11 +199,11 @@ def test_index_index():
 
 
 if __name__ == "__main__":
-    test_integer_overlap_no_cover()
+    test_integer_overlap_same_step_no_cover()
     test_integer_bounding_box_cover_coprime_step()
     test_integer_same_step_different_start()
     test_integer_bounding_box_symbolic_step()
-    test_symbolic_boundaries()
+    test_ranges_symbolic_boundaries()
     test_symbolic_boundaries_not_symbolic_positive()
     test_range_indices()
     test_index_index()
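As a quick illustration of the index/range rule documented in test_range_indices, the sketch below (mirroring one of the assertions above) shows an index that lies inside a range's bounding box yet is not covered because it misses the step grid:

    from dace.subsets import Indices, Range

    i = Indices.from_string('3')
    r = Range.from_string('0:4:2')  # elements {0, 2, 4}
    assert r.covers_precise(i) is False  # 3 is in [0, 4] but not on the step grid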
From 1e83e7112d6d79a96bd3637df5ff9d2b0563c6a8 Mon Sep 17 00:00:00 2001
From: matteonussbauemer
Date: Wed, 1 Nov 2023 21:51:28 +0100
Subject: [PATCH 13/19] Upgrade sympy version in requirements.txt

---
 requirements.txt | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/requirements.txt b/requirements.txt
index 5f804e1b4c..12c50a2eb5 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -19,7 +19,7 @@ ply==3.11
 PyYAML==6.0
 requests==2.31.0
 six==1.16.0
-sympy==1.9
+sympy==1.12
 urllib3==2.0.7
 websockets==11.0.3
 Werkzeug==2.3.5

From ca75b88c96b2c97592515ca437a9749e1cea080d Mon Sep 17 00:00:00 2001
From: matteonussbauemer
Date: Wed, 1 Nov 2023 23:18:50 +0100
Subject: [PATCH 14/19] fix config in test

---
 tests/subset_covers_precise_test.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/tests/subset_covers_precise_test.py b/tests/subset_covers_precise_test.py
index 185932ab53..8c688ea6c1 100644
--- a/tests/subset_covers_precise_test.py
+++ b/tests/subset_covers_precise_test.py
@@ -107,6 +107,7 @@ def test_symbolic_boundaries_not_symbolic_positive():
     """
     Tests from test_ranges_symbolic_boundaries with the symbolic_positive flag deactivated.
     """
+    symbolic_positive = Config.get('optimizer', 'symbolic_positive')
     Config.set('optimizer', 'symbolic_positive', value=False)
 
     subset1 = Range.from_string("N:M:1")
@@ -124,6 +125,8 @@ def test_symbolic_boundaries_not_symbolic_positive():
     assert (subset1.covers_precise(subset2) is False)
     assert (subset2.covers_precise(subset1) is False)
 
+    Config.set('optimizer', 'symbolic_positive', value=symbolic_positive)
+
 
 def test_range_indices():
     """
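The save-and-restore pattern used by this fix generalizes to a small context manager; temporary_config below is a hypothetical helper, sketched using only the Config.get/Config.set calls that already appear in this series:

    from contextlib import contextmanager

    from dace.config import Config

    @contextmanager
    def temporary_config(*key_path, value):
        old = Config.get(*key_path)
        Config.set(*key_path, value=value)
        try:
            yield
        finally:
            Config.set(*key_path, value=old)

    # e.g.: with temporary_config('optimizer', 'symbolic_positive', value=False): ...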
""" + symbolic_positive = Config.get('optimizer', 'symbolic_positive') Config.set('optimizer', 'symbolic_positive', value=False) subset1 = Range.from_string("N:M:1") @@ -124,6 +125,8 @@ def test_symbolic_boundaries_not_symbolic_positive(): assert (subset1.covers_precise(subset2) is False) assert (subset2.covers_precise(subset1) is False) + Config.set('optimizer', 'symbolic_positive', value=symbolic_positive) + def test_range_indices(): """ From d947bf87b1120f7612af5264a9fc690605920e50 Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Thu, 2 Nov 2023 17:03:51 +0100 Subject: [PATCH 15/19] Hierarchical Control Flow / Control Flow Regions (#1404) * Adds just the framework for integral loops * Fix duplicate collapsed property on states * Fix inorrect parent class initialization * Add deprecation warning to is_start_state kwarg * Symbols and start block fixes * More symbol fixes * label and state list fix * Remove loop scope for now * Renaming * revert to traditional nodes-based iteration for now * Update docs * Add test for deprecation * Improve iteration function names * Remove obsolete property * Improve type naming * Remove obsolete scope_subgraph method --- dace/codegen/instrumentation/papi.py | 13 +- .../analysis/schedule_tree/sdfg_to_tree.py | 2 +- dace/sdfg/nodes.py | 5 +- dace/sdfg/replace.py | 29 +- dace/sdfg/sdfg.py | 261 +----- dace/sdfg/state.py | 824 +++++++++++++++--- dace/sdfg/utils.py | 56 +- .../dataflow/double_buffering.py | 6 +- dace/transformation/interstate/loop_unroll.py | 3 +- .../interstate/multistate_inline.py | 2 +- doc/sdfg/images/elements.svg | 592 +++++++++++-- doc/sdfg/ir.rst | 21 +- requirements.txt | 4 +- .../sdfg/nested_control_flow_regions_test.py | 18 + tests/sdfg_validate_names_test.py | 2 +- 15 files changed, 1331 insertions(+), 507 deletions(-) create mode 100644 tests/sdfg/nested_control_flow_regions_test.py diff --git a/dace/codegen/instrumentation/papi.py b/dace/codegen/instrumentation/papi.py index c0d3b657a1..4885611408 100644 --- a/dace/codegen/instrumentation/papi.py +++ b/dace/codegen/instrumentation/papi.py @@ -12,7 +12,7 @@ from dace.sdfg.graph import SubgraphView from dace.memlet import Memlet from dace.sdfg import scope_contains_scope -from dace.sdfg.state import StateGraphView +from dace.sdfg.state import DataflowGraphView import sympy as sp import os @@ -392,7 +392,7 @@ def should_instrument_entry(map_entry: EntryNode) -> bool: return cond @staticmethod - def has_surrounding_perfcounters(node, dfg: StateGraphView): + def has_surrounding_perfcounters(node, dfg: DataflowGraphView): """ Returns true if there is a possibility that this node is part of a section that is profiled. """ parent = dfg.entry_node(node) @@ -605,7 +605,7 @@ def get_memlet_byte_size(sdfg: dace.SDFG, memlet: Memlet): return memlet.volume * memdata.dtype.bytes @staticmethod - def get_out_memlet_costs(sdfg: dace.SDFG, state_id: int, node: nodes.Node, dfg: StateGraphView): + def get_out_memlet_costs(sdfg: dace.SDFG, state_id: int, node: nodes.Node, dfg: DataflowGraphView): scope_dict = sdfg.node(state_id).scope_dict() out_costs = 0 @@ -636,7 +636,10 @@ def get_out_memlet_costs(sdfg: dace.SDFG, state_id: int, node: nodes.Node, dfg: return out_costs @staticmethod - def get_tasklet_byte_accesses(tasklet: nodes.CodeNode, dfg: StateGraphView, sdfg: dace.SDFG, state_id: int) -> str: + def get_tasklet_byte_accesses(tasklet: nodes.CodeNode, + dfg: DataflowGraphView, + sdfg: dace.SDFG, + state_id: int) -> str: """ Get the amount of bytes processed by `tasklet`. 
The formula is sum(inedges * size) + sum(outedges * size) """ in_accum = [] @@ -693,7 +696,7 @@ def get_memory_input_size(node, sdfg, state_id) -> str: return sym2cpp(input_size) @staticmethod - def accumulate_byte_movement(outermost_node, node, dfg: StateGraphView, sdfg, state_id): + def accumulate_byte_movement(outermost_node, node, dfg: DataflowGraphView, sdfg, state_id): itvars = dict() # initialize an empty dict diff --git a/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py b/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py index 917f748cb8..084d46f47d 100644 --- a/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py +++ b/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py @@ -275,7 +275,7 @@ def remove_name_collisions(sdfg: SDFG): # Rename duplicate states for state in nsdfg.nodes(): if state.label in state_names_seen: - state.set_label(data.find_new_name(state.label, state_names_seen)) + state.label = data.find_new_name(state.label, state_names_seen) state_names_seen.add(state.label) replacements: Dict[str, str] = {} diff --git a/dace/sdfg/nodes.py b/dace/sdfg/nodes.py index 32369a19a3..a28e9fce38 100644 --- a/dace/sdfg/nodes.py +++ b/dace/sdfg/nodes.py @@ -262,9 +262,8 @@ def label(self): def __label__(self, sdfg, state): return self.data - def desc(self, sdfg): - from dace.sdfg import SDFGState, ScopeSubgraphView - if isinstance(sdfg, (SDFGState, ScopeSubgraphView)): + def desc(self, sdfg: Union['dace.sdfg.SDFG', 'dace.sdfg.SDFGState', 'dace.sdfg.ScopeSubgraphView']): + if isinstance(sdfg, (dace.sdfg.SDFGState, dace.sdfg.ScopeSubgraphView)): sdfg = sdfg.parent return sdfg.arrays[self.data] diff --git a/dace/sdfg/replace.py b/dace/sdfg/replace.py index 4b36fad4fe..a2c7b9a43c 100644 --- a/dace/sdfg/replace.py +++ b/dace/sdfg/replace.py @@ -175,17 +175,18 @@ def replace_datadesc_names(sdfg, repl: Dict[str, str]): sdfg.constants_prop[repl[aname]] = sdfg.constants_prop[aname] del sdfg.constants_prop[aname] - # Replace in interstate edges - for e in sdfg.edges(): - e.data.replace_dict(repl, replace_keys=False) - - for state in sdfg.nodes(): - # Replace in access nodes - for node in state.data_nodes(): - if node.data in repl: - node.data = repl[node.data] - - # Replace in memlets - for edge in state.edges(): - if edge.data.data in repl: - edge.data.data = repl[edge.data.data] + for cf in sdfg.all_control_flow_regions(): + # Replace in interstate edges + for e in cf.edges(): + e.data.replace_dict(repl, replace_keys=False) + + for block in cf.nodes(): + if isinstance(block, dace.SDFGState): + # Replace in access nodes + for node in block.data_nodes(): + if node.data in repl: + node.data = repl[node.data] + # Replace in memlets + for edge in block.edges(): + if edge.data.data in repl: + edge.data.data = repl[edge.data.data] diff --git a/dace/sdfg/sdfg.py b/dace/sdfg/sdfg.py index a85e773337..fdf8835c7e 100644 --- a/dace/sdfg/sdfg.py +++ b/dace/sdfg/sdfg.py @@ -30,7 +30,7 @@ from dace.frontend.python import astutils, wrappers from dace.sdfg import nodes as nd from dace.sdfg.graph import OrderedDiGraph, Edge, SubgraphView -from dace.sdfg.state import SDFGState +from dace.sdfg.state import SDFGState, ControlFlowRegion from dace.sdfg.propagation import propagate_memlets_sdfg from dace.distr_types import ProcessGrid, SubArray, RedistrArray from dace.dtypes import validate_name @@ -402,7 +402,7 @@ def label(self): @make_properties -class SDFG(OrderedDiGraph[SDFGState, InterstateEdge]): +class SDFG(ControlFlowRegion): """ The main intermediate representation of code in DaCe. 
A Stateful DataFlow multiGraph (SDFG) is a directed graph of directed @@ -499,8 +499,6 @@ def __init__(self, self._parent_sdfg = None self._parent_nsdfg_node = None self._sdfg_list = [self] - self._start_state: Optional[int] = None - self._cached_start_state: Optional[SDFGState] = None self._arrays = NestedDict() # type: Dict[str, dt.Array] self._labels: Set[str] = set() self.global_code = {'frame': CodeBlock("", dtypes.Language.CPP)} @@ -531,14 +529,14 @@ def __deepcopy__(self, memo): memo[id(self)] = result for k, v in self.__dict__.items(): # Skip derivative attributes - if k in ('_cached_start_state', '_edges', '_nodes', '_parent', '_parent_sdfg', '_parent_nsdfg_node', + if k in ('_cached_start_block', '_edges', '_nodes', '_parent', '_parent_sdfg', '_parent_nsdfg_node', '_sdfg_list', '_transformation_hist'): continue setattr(result, k, copy.deepcopy(v, memo)) # Copy edges and nodes result._edges = copy.deepcopy(self._edges, memo) result._nodes = copy.deepcopy(self._nodes, memo) - result._cached_start_state = copy.deepcopy(self._cached_start_state, memo) + result._cached_start_block = copy.deepcopy(self._cached_start_block, memo) # Copy parent attributes for k in ('_parent', '_parent_sdfg', '_parent_nsdfg_node'): if id(getattr(self, k)) in memo: @@ -583,7 +581,7 @@ def to_json(self, hash=False): tmp['attributes']['constants_prop'] = json.loads(dace.serialize.dumps(tmp['attributes']['constants_prop'])) tmp['sdfg_list_id'] = int(self.sdfg_id) - tmp['start_state'] = self._start_state + tmp['start_state'] = self._start_block tmp['attributes']['name'] = self.name if hash: @@ -627,7 +625,7 @@ def from_json(cls, json_obj, context_info=None): ret.add_edge(nodelist[int(e.src)], nodelist[int(e.dst)], e.data) if 'start_state' in json_obj: - ret._start_state = json_obj['start_state'] + ret._start_block = json_obj['start_state'] return ret @@ -753,14 +751,7 @@ def replace_dict(self, for array in self.arrays.values(): replace_properties_dict(array, repldict, symrepl) - if replace_in_graph: - # Replace in inter-state edges - for edge in self.edges(): - edge.data.replace_dict(repldict, replace_keys=replace_keys) - - # Replace in states - for state in self.nodes(): - state.replace_dict(repldict, symrepl) + super().replace_dict(repldict, symrepl, replace_in_graph, replace_keys) def add_symbol(self, name, stype): """ Adds a symbol to the SDFG. @@ -787,34 +778,11 @@ def remove_symbol(self, name): @property def start_state(self): - """ Returns the starting state of this SDFG. """ - if self._cached_start_state is not None: - return self._cached_start_state - - source_nodes = self.source_nodes() - if len(source_nodes) == 1: - self._cached_start_state = source_nodes[0] - return source_nodes[0] - # If starting state is ambiguous (i.e., loop to initial state or more - # than one possible start state), allow manually overriding start state - if self._start_state is not None: - self._cached_start_state = self.node(self._start_state) - return self._cached_start_state - raise ValueError('Ambiguous or undefined starting state for SDFG, ' - 'please use "is_start_state=True" when adding the ' - 'starting state with "add_state"') + return self.start_block @start_state.setter def start_state(self, state_id): - """ Manually sets the starting state of this SDFG. - - :param state_id: The node ID (use `node_id(state)`) of the - state to set. 
- """ - if state_id < 0 or state_id >= self.number_of_nodes(): - raise ValueError("Invalid state ID") - self._start_state = state_id - self._cached_start_state = self.node(state_id) + self.start_block = state_id def set_global_code(self, cpp_code: str, location: str = 'frame'): """ @@ -1127,7 +1095,7 @@ def remove_data(self, name, validate=True): # Verify that there are no access nodes that use this data if validate: - for state in self.nodes(): + for state in self.states(): for node in state.nodes(): if isinstance(node, nd.AccessNode) and node.data == name: raise ValueError(f"Cannot remove data descriptor " @@ -1243,75 +1211,14 @@ def parent_sdfg(self, value): def parent_nsdfg_node(self, value): self._parent_nsdfg_node = value - def add_node(self, node, is_start_state=False): - """ Adds a new node to the SDFG. Must be an SDFGState or a subclass - thereof. - - :param node: The node to add. - :param is_start_state: If True, sets this node as the starting - state. - """ - if not isinstance(node, SDFGState): - raise TypeError("Expected SDFGState, got " + str(type(node))) - super(SDFG, self).add_node(node) - self._cached_start_state = None - if is_start_state is True: - self.start_state = len(self.nodes()) - 1 - self._cached_start_state = node - def remove_node(self, node: SDFGState): - if node is self._cached_start_state: - self._cached_start_state = None + if node is self._cached_start_block: + self._cached_start_block = None return super().remove_node(node) - def add_edge(self, u, v, edge): - """ Adds a new edge to the SDFG. Must be an InterstateEdge or a - subclass thereof. - - :param u: Source node. - :param v: Destination node. - :param edge: The edge to add. - """ - if not isinstance(u, SDFGState): - raise TypeError("Expected SDFGState, got: {}".format(type(u).__name__)) - if not isinstance(v, SDFGState): - raise TypeError("Expected SDFGState, got: {}".format(type(v).__name__)) - if not isinstance(edge, InterstateEdge): - raise TypeError("Expected InterstateEdge, got: {}".format(type(edge).__name__)) - if v is self._cached_start_state: - self._cached_start_state = None - return super(SDFG, self).add_edge(u, v, edge) - def states(self): - """ Alias that returns the nodes (states) in this SDFG. """ - return self.nodes() - - def all_nodes_recursive(self) -> Iterator[Tuple[nd.Node, Union['SDFG', 'SDFGState']]]: - """ Iterate over all nodes in this SDFG, including states, nodes in - states, and recursive states and nodes within nested SDFGs, - returning tuples on the form (node, parent), where the parent is - either the SDFG (for states) or a DFG (nodes). """ - for node in self.nodes(): - yield node, self - yield from node.all_nodes_recursive() - - def all_sdfgs_recursive(self): - """ Iterate over this and all nested SDFGs. """ - yield self - for state in self.nodes(): - for node in state.nodes(): - if isinstance(node, nd.NestedSDFG): - yield from node.sdfg.all_sdfgs_recursive() - - def all_edges_recursive(self): - """ Iterate over all edges in this SDFG, including state edges, - inter-state edges, and recursively edges within nested SDFGs, - returning tuples on the form (edge, parent), where the parent is - either the SDFG (for states) or a DFG (nodes). """ - for e in self.edges(): - yield e, self - for node in self.nodes(): - yield from node.all_edges_recursive() + """ Returns the states in this SDFG, recursing into state scope blocks. 
""" + return list(self.all_states()) def arrays_recursive(self): """ Iterate over all arrays in this SDFG, including arrays within @@ -1323,19 +1230,15 @@ def arrays_recursive(self): if isinstance(node, nd.NestedSDFG): yield from node.sdfg.arrays_recursive() - def used_symbols(self, all_symbols: bool, keep_defined_in_mapping: bool=False) -> Set[str]: - """ - Returns a set of symbol names that are used by the SDFG, but not - defined within it. This property is used to determine the symbolic - parameters of the SDFG. - - :param all_symbols: If False, only returns the set of symbols that will be used - in the generated code and are needed as arguments. - :param keep_defined_in_mapping: If True, symbols defined in inter-state edges that are in the symbol mapping - will be removed from the set of defined symbols. - """ - defined_syms = set() - free_syms = set() + def _used_symbols_internal(self, + all_symbols: bool, + defined_syms: Optional[Set]=None, + free_syms: Optional[Set]=None, + used_before_assignment: Optional[Set]=None, + keep_defined_in_mapping: bool=False) -> Tuple[Set[str], Set[str], Set[str]]: + defined_syms = set() if defined_syms is None else defined_syms + free_syms = set() if free_syms is None else free_syms + used_before_assignment = set() if used_before_assignment is None else used_before_assignment # Exclude data descriptor names and constants for name in self.arrays.keys(): @@ -1349,54 +1252,10 @@ def used_symbols(self, all_symbols: bool, keep_defined_in_mapping: bool=False) - for code in self.exit_code.values(): free_syms |= symbolic.symbols_in_code(code.as_string, self.symbols.keys()) - # Add free state symbols - used_before_assignment = set() - - try: - ordered_states = self.topological_sort(self.start_state) - except ValueError: # Failsafe (e.g., for invalid or empty SDFGs) - ordered_states = self.nodes() - - for state in ordered_states: - state_fsyms = state.used_symbols(all_symbols) - free_syms |= state_fsyms - - # Add free inter-state symbols - for e in self.out_edges(state): - # NOTE: First we get the true InterstateEdge free symbols, then we compute the newly defined symbols by - # subracting the (true) free symbols from the edge's assignment keys. This way we can correctly - # compute the symbols that are used before being assigned. - efsyms = e.data.used_symbols(all_symbols) - defined_syms |= set(e.data.assignments.keys()) - (efsyms | state_fsyms) - used_before_assignment.update(efsyms - defined_syms) - free_syms |= efsyms - - # Remove symbols that were used before they were assigned - defined_syms -= used_before_assignment - - # Remove from defined symbols those that are in the symbol mapping - if self.parent_nsdfg_node is not None and keep_defined_in_mapping: - defined_syms -= set(self.parent_nsdfg_node.symbol_mapping.keys()) - - # Add the set of SDFG symbol parameters - # If all_symbols is False, those symbols would only be added in the case of non-Python tasklets - if all_symbols: - free_syms |= set(self.symbols.keys()) - - # Subtract symbols defined in inter-state edges and constants - return free_syms - defined_syms - - @property - def free_symbols(self) -> Set[str]: - """ - Returns a set of symbol names that are used by the SDFG, but not - defined within it. This property is used to determine the symbolic - parameters of the SDFG and verify that ``SDFG.symbols`` is complete. - - :note: Assumes that the graph is valid (i.e., without undefined or - overlapping symbols). 
- """ - return self.used_symbols(all_symbols=True) + return super()._used_symbols_internal( + all_symbols=all_symbols, keep_defined_in_mapping=keep_defined_in_mapping, + defined_syms=defined_syms, free_syms=free_syms, used_before_assignment=used_before_assignment + ) def get_all_toplevel_symbols(self) -> Set[str]: """ @@ -1608,16 +1467,16 @@ def shared_transients(self, check_toplevel=True) -> List[str]: shared = [] # If a transient is present in an inter-state edge, it is shared - for interstate_edge in self.edges(): + for interstate_edge in self.all_interstate_edges(): for sym in interstate_edge.data.free_symbols: if sym in self.arrays and self.arrays[sym].transient: seen[sym] = interstate_edge shared.append(sym) # If transient is accessed in more than one state, it is shared - for state in self.nodes(): - for node in state.nodes(): - if isinstance(node, nd.AccessNode) and node.desc(self).transient: + for state in self.states(): + for node in state.data_nodes(): + if node.desc(self).transient: if (check_toplevel and node.desc(self).toplevel) or (node.data in seen and seen[node.data] != state): shared.append(node.data) @@ -1706,62 +1565,6 @@ def from_file(filename: str) -> 'SDFG': # Dynamic SDFG creation API ############################## - def add_state(self, label=None, is_start_state=False) -> 'SDFGState': - """ Adds a new SDFG state to this graph and returns it. - - :param label: State label. - :param is_start_state: If True, resets SDFG starting state to this - state. - :return: A new SDFGState object. - """ - if self._labels is None or len(self._labels) != self.number_of_nodes(): - self._labels = set(s.label for s in self.nodes()) - label = label or 'state' - existing_labels = self._labels - label = dt.find_new_name(label, existing_labels) - state = SDFGState(label, self) - self._labels.add(label) - - self.add_node(state, is_start_state=is_start_state) - return state - - def add_state_before(self, state: 'SDFGState', label=None, is_start_state=False) -> 'SDFGState': - """ Adds a new SDFG state before an existing state, reconnecting - predecessors to it instead. - - :param state: The state to prepend the new state before. - :param label: State label. - :param is_start_state: If True, resets SDFG starting state to this - state. - :return: A new SDFGState object. - """ - new_state = self.add_state(label, is_start_state) - # Reconnect - for e in self.in_edges(state): - self.remove_edge(e) - self.add_edge(e.src, new_state, e.data) - # Add unconditional connection between the new state and the current - self.add_edge(new_state, state, InterstateEdge()) - return new_state - - def add_state_after(self, state: 'SDFGState', label=None, is_start_state=False) -> 'SDFGState': - """ Adds a new SDFG state after an existing state, reconnecting - it to the successors instead. - - :param state: The state to append the new state after. - :param label: State label. - :param is_start_state: If True, resets SDFG starting state to this - state. - :return: A new SDFGState object. - """ - new_state = self.add_state(label, is_start_state) - # Reconnect - for e in self.out_edges(state): - self.remove_edge(e) - self.add_edge(new_state, e.dst, e.data) - # Add unconditional connection between the current and the new state - self.add_edge(state, new_state, InterstateEdge()) - return new_state def _find_new_name(self, name: str): """ Tries to find a new name by adding an underscore and a number. 
""" @@ -2482,7 +2285,7 @@ def __call__(self, *args, **kwargs): def fill_scope_connectors(self): """ Fills missing scope connectors (i.e., "IN_#"/"OUT_#" on entry/exit nodes) according to data on the memlets. """ - for state in self.nodes(): + for state in self.states(): state.fill_scope_connectors() def predecessor_state_transitions(self, state): diff --git a/dace/sdfg/state.py b/dace/sdfg/state.py index 1ff8fe4cf1..097365fbc3 100644 --- a/dace/sdfg/state.py +++ b/dace/sdfg/state.py @@ -2,6 +2,7 @@ """ Contains classes of a single SDFG state and dataflow subgraphs. """ import ast +import abc import collections import copy import inspect @@ -19,7 +20,7 @@ from dace.properties import (CodeBlock, DictProperty, EnumProperty, Property, SubsetProperty, SymbolicProperty, CodeProperty, make_properties) from dace.sdfg import nodes as nd -from dace.sdfg.graph import MultiConnectorEdge, OrderedMultiDiConnectorGraph, SubgraphView +from dace.sdfg.graph import MultiConnectorEdge, OrderedMultiDiConnectorGraph, SubgraphView, OrderedDiGraph, Edge from dace.sdfg.propagation import propagate_memlet from dace.sdfg.validation import validate_state from dace.subsets import Range, Subset @@ -28,6 +29,11 @@ import dace.sdfg.scope +NodeT = Union[nd.Node, 'ControlFlowBlock'] +EdgeT = Union[MultiConnectorEdge[mm.Memlet], Edge['dace.sdfg.InterstateEdge']] +GraphT = Union['ControlFlowRegion', 'SDFGState'] + + def _getdebuginfo(old_dinfo=None) -> dtypes.DebugInfo: """ Returns a DebugInfo object for the position that called this function. @@ -66,13 +72,248 @@ def _make_iterators(ndrange): return params, map_range -class StateGraphView(object): +class BlockGraphView(object): """ - Read-only view interface of an SDFG state, containing methods for memlet - tracking, traversal, subgraph creation, queries, and replacements. - ``SDFGState`` and ``StateSubgraphView`` inherit from this class to share + Read-only view interface of an SDFG control flow block, containing methods for memlet tracking, traversal, subgraph + creation, queries, and replacements. ``ControlFlowBlock`` and ``StateSubgraphView`` inherit from this class to share methods. """ + + + ################################################################### + # Typing overrides + + @overload + def nodes(self) -> List[NodeT]: + ... + + @overload + def edges(self) -> List[EdgeT]: + ... + + @overload + def in_degree(self, node: NodeT) -> int: + ... + + @overload + def out_degree(self, node: NodeT) -> int: + ... + + ################################################################### + # Traversal methods + + @abc.abstractmethod + def all_nodes_recursive(self) -> Iterator[Tuple[NodeT, GraphT]]: + """ + Iterate over all nodes in this graph or subgraph. + This includes control flow blocks, nodes in those blocks, and recursive control flow blocks and nodes within + nested SDFGs. It returns tuples of the form (node, parent), where the node is either a dataflow node, in which + case the parent is an SDFG state, or a control flow block, in which case the parent is a control flow graph + (i.e., an SDFG or a scope block). + """ + raise NotImplementedError() + + @abc.abstractmethod + def all_edges_recursive(self) -> Iterator[Tuple[EdgeT, GraphT]]: + """ + Iterate over all edges in this graph or subgraph. + This includes dataflow edges, inter-state edges, and recursive edges within nested SDFGs. 
It returns tuples of + the form (edge, parent), where the edge is either a dataflow edge, in which case the parent is an SDFG state, or + an inter-stte edge, in which case the parent is a control flow graph (i.e., an SDFG or a scope block). + """ + raise NotImplementedError() + + @abc.abstractmethod + def data_nodes(self) -> List[nd.AccessNode]: + """ + Returns all data nodes (i.e., AccessNodes, arrays) present in this graph or subgraph. + Note: This does not recurse into nested SDFGs. + """ + raise NotImplementedError() + + @abc.abstractmethod + def entry_node(self, node: nd.Node) -> nd.EntryNode: + """ Returns the entry node that wraps the current node, or None if it is top-level in a state. """ + raise NotImplementedError() + + @abc.abstractmethod + def exit_node(self, entry_node: nd.EntryNode) -> nd.ExitNode: + """ Returns the exit node leaving the context opened by the given entry node. """ + raise NotImplementedError() + + ################################################################### + # Memlet-tracking methods + + @abc.abstractmethod + def memlet_path(self, edge: MultiConnectorEdge[mm.Memlet]) -> List[MultiConnectorEdge[mm.Memlet]]: + """ + Given one edge, returns a list of edges representing a path between its source and sink nodes. + Used for memlet tracking. + + :note: Behavior is undefined when there is more than one path involving this edge. + :param edge: An edge within a state (memlet). + :return: A list of edges from a source node to a destination node. + """ + raise NotImplementedError() + + @abc.abstractmethod + def memlet_tree(self, edge: MultiConnectorEdge) -> mm.MemletTree: + """ + Given one edge, returns a tree of edges between its node source(s) and sink(s). + Used for memlet tracking. + + :param edge: An edge within a state (memlet). + :return: A tree of edges whose root is the source/sink node (depending on direction) and associated children + edges. + """ + raise NotImplementedError() + + @abc.abstractmethod + def in_edges_by_connector(self, node: nd.Node, connector: AnyStr) -> Iterable[MultiConnectorEdge[mm.Memlet]]: + """ + Returns a generator over edges entering the given connector of the given node. + + :param node: Destination node of edges. + :param connector: Destination connector of edges. + """ + raise NotImplementedError() + + @abc.abstractmethod + def out_edges_by_connector(self, node: nd.Node, connector: AnyStr) -> Iterable[MultiConnectorEdge[mm.Memlet]]: + """ + Returns a generator over edges exiting the given connector of the given node. + + :param node: Source node of edges. + :param connector: Source connector of edges. + """ + raise NotImplementedError() + + @abc.abstractmethod + def edges_by_connector(self, node: nd.Node, connector: AnyStr) -> Iterable[MultiConnectorEdge[mm.Memlet]]: + """ + Returns a generator over edges entering or exiting the given connector of the given node. + + :param node: Source/destination node of edges. + :param connector: Source/destination connector of edges. + """ + raise NotImplementedError() + + ################################################################### + # Query, subgraph, and replacement methods + + @abc.abstractmethod + def used_symbols(self, all_symbols: bool, keep_defined_in_mapping: bool=False) -> Set[str]: + """ + Returns a set of symbol names that are used in the graph. + + :param all_symbols: If False, only returns symbols that are needed as arguments (only used in generated code). 
+        :param keep_defined_in_mapping: If True, symbols defined in inter-state edges that are in the symbol mapping
+                                        will be removed from the set of defined symbols.
+        """
+        raise NotImplementedError()
+
+    @property
+    def free_symbols(self) -> Set[str]:
+        """
+        Returns a set of symbol names that are used, but not defined, in this graph view.
+        In the case of an SDFG, this property is used to determine the symbolic parameters of the SDFG and
+        verify that ``SDFG.symbols`` is complete.
+
+        :note: Assumes that the graph is valid (i.e., without undefined or overlapping symbols).
+        """
+        return self.used_symbols(all_symbols=True)
+
+    @abc.abstractmethod
+    def read_and_write_sets(self) -> Tuple[Set[AnyStr], Set[AnyStr]]:
+        """
+        Determines what data is read and written in this graph.
+        Does not include reads to subsets of containers that have previously been written within the same state.
+
+        :return: A two-tuple of sets of things denoting ({data read}, {data written}).
+        """
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def unordered_arglist(self,
+                          defined_syms=None,
+                          shared_transients=None) -> Tuple[Dict[str, dt.Data], Dict[str, dt.Data]]:
+        raise NotImplementedError()
+
+    def arglist(self, defined_syms=None, shared_transients=None) -> Dict[str, dt.Data]:
+        """
+        Returns an ordered dictionary of arguments (names and types) required to invoke this subgraph.
+
+        The arguments differ from SDFG.arglist, but follow the same order,
+        namely: <array arguments>, <scalar arguments>.
+
+        Data arguments contain:
+            * All used non-transient data containers in the subgraph
+            * All used transient data containers that were allocated outside.
+              This includes data from memlets, transients shared across multiple states, and transients that could not
+              be allocated within the subgraph (due to their ``AllocationLifetime`` or according to the
+              ``dtypes.can_allocate`` function).
+
+        Scalar arguments contain:
+            * Free symbols in this state/subgraph.
+            * All transient and non-transient scalar data containers used in this subgraph.
+
+        This structure will create a sorted list of pointers followed by a sorted list of PoDs and structs.
+
+        :return: An ordered dictionary of (name, data descriptor type) of all the arguments, sorted as defined here.
+        """
+        data_args, scalar_args = self.unordered_arglist(defined_syms, shared_transients)
+
+        # Fill up ordered dictionary
+        result = collections.OrderedDict()
+        for k, v in itertools.chain(sorted(data_args.items()), sorted(scalar_args.items())):
+            result[k] = v
+
+        return result
+
+    def signature_arglist(self, with_types=True, for_call=False):
+        """ Returns a list of arguments necessary to call this state or subgraph, formatted as a list of C definitions.
+
+        :param with_types: If True, includes argument types in the result.
+        :param for_call: If True, returns arguments that can be used when calling the SDFG.
+        :return: A list of strings. For example: `['float *A', 'int b']`.
+        """
+        return [v.as_arg(name=k, with_types=with_types, for_call=for_call) for k, v in self.arglist().items()]
+
+    @abc.abstractmethod
+    def top_level_transients(self) -> Set[str]:
+        """Iterate over top-level transients of this graph."""
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def all_transients(self) -> List[str]:
+        """Iterate over all transients in this graph."""
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def replace(self, name: str, new_name: str):
+        """
+        Finds and replaces all occurrences of a symbol or array in this graph.
+
+        :param name: Name to find.
+        :param new_name: Name to replace.
+ """ + raise NotImplementedError() + + @abc.abstractmethod + def replace_dict(self, + repl: Dict[str, str], + symrepl: Optional[Dict[symbolic.SymbolicType, symbolic.SymbolicType]] = None): + """ + Finds and replaces all occurrences of a set of symbols or arrays in this graph. + + :param repl: Mapping from names to replacements. + :param symrepl: Optional symbolic version of ``repl``. + """ + raise NotImplementedError() + + +@make_properties +class DataflowGraphView(BlockGraphView, abc.ABC): def __init__(self, *args, **kwargs): self._clear_scopedict_cache() @@ -91,29 +332,29 @@ def edges(self) -> List[MultiConnectorEdge[mm.Memlet]]: ################################################################### # Traversal methods - def all_nodes_recursive(self): + def all_nodes_recursive(self) -> Iterator[Tuple[NodeT, GraphT]]: for node in self.nodes(): yield node, self if isinstance(node, nd.NestedSDFG): yield from node.sdfg.all_nodes_recursive() - def all_edges_recursive(self): + def all_edges_recursive(self) -> Iterator[Tuple[EdgeT, GraphT]]: for e in self.edges(): yield e, self for node in self.nodes(): if isinstance(node, nd.NestedSDFG): yield from node.sdfg.all_edges_recursive() - def data_nodes(self): + def data_nodes(self) -> List[nd.AccessNode]: """ Returns all data_nodes (arrays) present in this state. """ return [n for n in self.nodes() if isinstance(n, nd.AccessNode)] - def entry_node(self, node: nd.Node) -> nd.EntryNode: + def entry_node(self, node: nd.Node) -> Optional[nd.EntryNode]: """ Returns the entry node that wraps the current node, or None if it is top-level in a state. """ return self.scope_dict()[node] - def exit_node(self, entry_node: nd.EntryNode) -> nd.ExitNode: + def exit_node(self, entry_node: nd.EntryNode) -> Optional[nd.ExitNode]: """ Returns the exit node leaving the context opened by the given entry node. """ node_to_children = self.scope_children() @@ -152,7 +393,7 @@ def memlet_path(self, edge: MultiConnectorEdge[mm.Memlet]) -> List[MultiConnecto result.insert(0, next_edge) curedge = next_edge - # Prepend outgoing edges until reaching the sink node + # Append outgoing edges until reaching the sink node curedge = edge while not isinstance(curedge.dst, (nd.CodeNode, nd.AccessNode)): # Trace through scope entry using IN_# -> OUT_# @@ -168,13 +409,6 @@ def memlet_path(self, edge: MultiConnectorEdge[mm.Memlet]) -> List[MultiConnecto return result def memlet_tree(self, edge: MultiConnectorEdge) -> mm.MemletTree: - """ Given one edge, returns a tree of edges between its node source(s) - and sink(s). Used for memlet tracking. - - :param edge: An edge within this state. - :return: A tree of edges whose root is the source/sink node - (depending on direction) and associated children edges. - """ propagate_forward = False propagate_backward = False if ((isinstance(edge.src, nd.EntryNode) and edge.src_conn is not None) or @@ -246,30 +480,12 @@ def traverse(node): return traverse(tree_root) def in_edges_by_connector(self, node: nd.Node, connector: AnyStr) -> Iterable[MultiConnectorEdge[mm.Memlet]]: - """ Returns a generator over edges entering the given connector of the - given node. - - :param node: Destination node of edges. - :param connector: Destination connector of edges. - """ return (e for e in self.in_edges(node) if e.dst_conn == connector) def out_edges_by_connector(self, node: nd.Node, connector: AnyStr) -> Iterable[MultiConnectorEdge[mm.Memlet]]: - """ Returns a generator over edges exiting the given connector of the - given node. - - :param node: Source node of edges. 
- :param connector: Source connector of edges. - """ return (e for e in self.out_edges(node) if e.src_conn == connector) def edges_by_connector(self, node: nd.Node, connector: AnyStr) -> Iterable[MultiConnectorEdge[mm.Memlet]]: - """ Returns a generator over edges entering or exiting the given - connector of the given node. - - :param node: Source/destination node of edges. - :param connector: Source/destination connector of edges. - """ return itertools.chain(self.in_edges_by_connector(node, connector), self.out_edges_by_connector(node, connector)) @@ -297,8 +513,6 @@ def scope_tree(self) -> 'dace.sdfg.scope.ScopeTree': result = {} - sdfg_symbols = self.parent.symbols.keys() - # Get scopes for node, scopenodes in sdc.items(): if node is None: @@ -325,15 +539,7 @@ def scope_leaves(self) -> List['dace.sdfg.scope.ScopeTree']: self._scope_leaves_cached = [scope for scope in st.values() if len(scope.children) == 0] return copy.copy(self._scope_leaves_cached) - def scope_dict(self, return_ids: bool = False, validate: bool = True) -> Dict[nd.Node, Optional[nd.Node]]: - """ Returns a dictionary that maps each SDFG node to its parent entry - node, or to None if the node is not in any scope. - - :param return_ids: Return node ID numbers instead of node objects. - :param validate: Ensure that the graph is not malformed when - computing dictionary. - :return: The mapping from a node to its parent scope entry node. - """ + def scope_dict(self, return_ids: bool = False, validate: bool = True) -> Dict[nd.Node, Union['SDFGState', nd.Node]]: from dace.sdfg.scope import _scope_dict_inner, _scope_dict_to_ids result = None result = copy.copy(self._scope_dict_toparent_cached) @@ -367,16 +573,7 @@ def scope_dict(self, return_ids: bool = False, validate: bool = True) -> Dict[nd def scope_children(self, return_ids: bool = False, - validate: bool = True) -> Dict[Optional[nd.EntryNode], List[nd.Node]]: - """ Returns a dictionary that maps each SDFG entry node to its children, - not including the children of children entry nodes. The key `None` - contains a list of top-level nodes (i.e., not in any scope). - - :param return_ids: Return node ID numbers instead of node objects. - :param validate: Ensure that the graph is not malformed when - computing dictionary. - :return: The mapping from a node to a list of children nodes. - """ + validate: bool = True) -> Dict[Union[nd.Node, 'SDFGState'], List[nd.Node]]: from dace.sdfg.scope import _scope_dict_inner, _scope_dict_to_ids result = None if self._scope_dict_tochildren_cached is not None: @@ -419,13 +616,7 @@ def is_leaf_memlet(self, e): return False return True - def used_symbols(self, all_symbols: bool) -> Set[str]: - """ - Returns a set of symbol names that are used in the state. - - :param all_symbols: If False, only returns the set of symbols that will be used - in the generated code and are needed as arguments. - """ + def used_symbols(self, all_symbols: bool, keep_defined_in_mapping: bool=False) -> Set[str]: state = self.graph if isinstance(self, SubgraphView) else self sdfg = state.parent new_symbols = set() @@ -579,33 +770,9 @@ def read_and_write_sets(self) -> Tuple[Set[AnyStr], Set[AnyStr]]: read_set, write_set = self._read_and_write_sets() return set(read_set.keys()), set(write_set.keys()) - def arglist(self, defined_syms=None, shared_transients=None) -> Dict[str, dt.Data]: - """ - Returns an ordered dictionary of arguments (names and types) required - to invoke this SDFG state or subgraph thereof. 
-
-        The arguments differ from SDFG.arglist, but follow the same order,
-        namely: <array arguments>, <scalar arguments>.
-
-        Data arguments contain:
-            * All used non-transient data containers in the subgraph
-            * All used transient data containers that were allocated outside.
-              This includes data from memlets, transients shared across multiple
-              states, and transients that could not be allocated within the
-              subgraph (due to their ``AllocationLifetime`` or according to the
-              ``dtypes.can_allocate`` function).
-
-        Scalar arguments contain:
-            * Free symbols in this state/subgraph.
-            * All transient and non-transient scalar data containers used in
-              this subgraph.
-
-        This structure will create a sorted list of pointers followed by a
-        sorted list of PoDs and structs.
-
-        :return: An ordered dictionary of (name, data descriptor type) of all
-                 the arguments, sorted as defined here.
-        """
+    def unordered_arglist(self,
+                          defined_syms=None,
+                          shared_transients=None) -> Tuple[Dict[str, dt.Data], Dict[str, dt.Data]]:
         sdfg: 'dace.sdfg.SDFG' = self.parent
         shared_transients = shared_transients or sdfg.shared_transients()
         sdict = self.scope_dict()
@@ -699,12 +866,7 @@ def arglist(self, defined_syms=None, shared_transients=None) -> Dict[str, dt.Dat
                 if not str(k).startswith('__dace') and str(k) not in sdfg.constants
             })
 
-        # Fill up ordered dictionary
-        result = collections.OrderedDict()
-        for k, v in itertools.chain(sorted(data_args.items()), sorted(scalar_args.items())):
-            result[k] = v
-
-        return result
+        return data_args, scalar_args
 
     def signature_arglist(self, with_types=True, for_call=False):
         """ Returns a list of arguments necessary to call this state or
@@ -749,22 +911,212 @@ def replace(self, name: str, new_name: str):
     def replace_dict(self,
                      repl: Dict[str, str],
                      symrepl: Optional[Dict[symbolic.SymbolicType, symbolic.SymbolicType]] = None):
-        """ Finds and replaces all occurrences of a set of symbols or arrays in this state.
-
-        :param repl: Mapping from names to replacements.
-        :param symrepl: Optional symbolic version of ``repl``.
-        """
         from dace.sdfg.replace import replace_dict
         replace_dict(self, repl, symrepl)
 
 
 @make_properties
+class ControlGraphView(BlockGraphView, abc.ABC):
+
+    ###################################################################
+    # Typing overrides
+
+    @overload
+    def nodes(self) -> List['ControlFlowBlock']:
+        ...
+
+    @overload
+    def edges(self) -> List[Edge['dace.sdfg.InterstateEdge']]:
+        ...
+
+    ###################################################################
+    # Traversal methods
+
+    def all_nodes_recursive(self) -> Iterator[Tuple[NodeT, GraphT]]:
+        for node in self.nodes():
+            yield node, self
+            yield from node.all_nodes_recursive()
+
+    def all_edges_recursive(self) -> Iterator[Tuple[EdgeT, GraphT]]:
+        for e in self.edges():
+            yield e, self
+        for node in self.nodes():
+            yield from node.all_edges_recursive()
+
+    def data_nodes(self) -> List[nd.AccessNode]:
+        data_nodes = []
+        for node in self.nodes():
+            data_nodes.extend(node.data_nodes())
+        return data_nodes
+
+    def entry_node(self, node: nd.Node) -> Optional[nd.EntryNode]:
+        for block in self.nodes():
+            if node in block.nodes():
+                return block.entry_node(node)
+        return None
+
+    def exit_node(self, entry_node: nd.EntryNode) -> Optional[nd.ExitNode]:
+        for block in self.nodes():
+            if entry_node in block.nodes():
+                return block.exit_node(entry_node)
+        return None
+
+    ###################################################################
+    # Memlet-tracking methods
+
+    def memlet_path(self, edge: MultiConnectorEdge[mm.Memlet]) -> List[MultiConnectorEdge[mm.Memlet]]:
+        for block in self.nodes():
+            if edge in block.edges():
+                return block.memlet_path(edge)
+        return []
+
+    def memlet_tree(self, edge: MultiConnectorEdge) -> mm.MemletTree:
+        for block in self.nodes():
+            if edge in block.edges():
+                return block.memlet_tree(edge)
+        return mm.MemletTree(edge)
+
+    def in_edges_by_connector(self, node: nd.Node, connector: AnyStr) -> Iterable[MultiConnectorEdge[mm.Memlet]]:
+        for block in self.nodes():
+            if node in block.nodes():
+                return block.in_edges_by_connector(node, connector)
+        return []
+
+    def out_edges_by_connector(self, node: nd.Node, connector: AnyStr) -> Iterable[MultiConnectorEdge[mm.Memlet]]:
+        for block in self.nodes():
+            if node in block.nodes():
+                return block.out_edges_by_connector(node, connector)
+        return []
+
+    def edges_by_connector(self, node: nd.Node, connector: AnyStr) -> Iterable[MultiConnectorEdge[mm.Memlet]]:
+        for block in self.nodes():
+            if node in block.nodes():
+                return block.edges_by_connector(node, connector)
+        return []
+
+    ###################################################################
+    # Query, subgraph, and replacement methods
+
+    @abc.abstractmethod
+    def _used_symbols_internal(self,
+                               all_symbols: bool,
+                               defined_syms: Optional[Set] = None,
+                               free_syms: Optional[Set] = None,
+                               used_before_assignment: Optional[Set] = None,
+                               keep_defined_in_mapping: bool = False) -> Tuple[Set[str], Set[str], Set[str]]:
+        raise NotImplementedError()
+
+    def used_symbols(self, all_symbols: bool, keep_defined_in_mapping: bool=False) -> Set[str]:
+        return self._used_symbols_internal(all_symbols, keep_defined_in_mapping=keep_defined_in_mapping)[0]
+
+    def read_and_write_sets(self) -> Tuple[Set[AnyStr], Set[AnyStr]]:
+        read_set = set()
+        write_set = set()
+        for block in self.nodes():
+            for edge in self.in_edges(block):
+                read_set |= edge.data.free_symbols & self.sdfg.arrays.keys()
+            rs, ws = block.read_and_write_sets()
+            read_set.update(rs)
+            write_set.update(ws)
+        return read_set, write_set
+
+    def unordered_arglist(self,
+                          defined_syms=None,
+                          shared_transients=None) -> Tuple[Dict[str, dt.Data], Dict[str, dt.Data]]:
+        data_args = {}
+        scalar_args = {}
+        for block in self.nodes():
+            n_data_args, n_scalar_args = block.unordered_arglist(defined_syms, shared_transients)
+            data_args.update(n_data_args)
+            scalar_args.update(n_scalar_args)
+        return data_args, scalar_args
+
+    def top_level_transients(self) -> Set[str]:
+        res = set()
for block in self.nodes(): + res.update(block.top_level_transients()) + return res + + def all_transients(self) -> List[str]: + res = [] + for block in self.nodes(): + res.extend(block.all_transients()) + return dtypes.deduplicate(res) + + def replace(self, name: str, new_name: str): + for n in self.nodes(): + n.replace(name, new_name) + + def replace_dict(self, + repl: Dict[str, str], + symrepl: Optional[Dict[symbolic.SymbolicType, symbolic.SymbolicType]] = None, + replace_in_graph: bool = True, replace_keys: bool = False): + symrepl = symrepl or { + symbolic.symbol(k): symbolic.pystr_to_symbolic(v) if isinstance(k, str) else v + for k, v in repl.items() + } + + if replace_in_graph: + # Replace in inter-state edges + for edge in self.edges(): + edge.data.replace_dict(repl, replace_keys=replace_keys) + + # Replace in states + for state in self.nodes(): + state.replace_dict(repl, symrepl) + +@make_properties +class ControlFlowBlock(BlockGraphView, abc.ABC): + + is_collapsed = Property(dtype=bool, desc='Show this block as collapsed', default=False) + + _label: str + + def __init__(self, label: str=''): + super(ControlFlowBlock, self).__init__() + self._label = label + self._default_lineinfo = None + self.is_collapsed = False + + def set_default_lineinfo(self, lineinfo: dace.dtypes.DebugInfo): + """ + Sets the default source line information to be lineinfo, or None to + revert to default mode. + """ + self._default_lineinfo = lineinfo + + def to_json(self, parent=None): + tmp = { + 'type': self.__class__.__name__, + 'collapsed': self.is_collapsed, + 'label': self._label, + 'id': parent.node_id(self) if parent is not None else None, + } + return tmp + + def __str__(self): + return self._label + + def __repr__(self) -> str: + return f'ControlFlowBlock ({self.label})' + + @property + def label(self) -> str: + return self._label + + @label.setter + def label(self, label: str): + self._label = label + + @property + def name(self) -> str: + return self._label + + +@make_properties +class SDFGState(OrderedMultiDiConnectorGraph[nd.Node, mm.Memlet], ControlFlowBlock, DataflowGraphView): """ An acyclic dataflow multigraph in an SDFG, corresponding to a single state in the SDFG state machine. """ - is_collapsed = Property(dtype=bool, desc="Show this node/scope/state as collapsed", default=False) - nosync = Property(dtype=bool, default=False, desc="Do not synchronize at the end of the state") instrument = EnumProperty(dtype=dtypes.InstrumentationType, @@ -803,13 +1155,14 @@ def __init__(self, label=None, sdfg=None, debuginfo=None, location=None): :param debuginfo: Source code locator for debugging. 
""" from dace.sdfg.sdfg import SDFG # Avoid import loop + OrderedMultiDiConnectorGraph.__init__(self) + ControlFlowBlock.__init__(self, label) super(SDFGState, self).__init__() self._label = label self._parent: SDFG = sdfg self._graph = self # Allowing MemletTrackingView mixin to work self._clear_scopedict_cache() self._debuginfo = debuginfo - self.is_collapsed = False self.nosync = False self.location = location if location is not None else {} self._default_lineinfo = None @@ -839,33 +1192,12 @@ def parent(self): def parent(self, value): self._parent = value - def __str__(self): - return self._label - - @property - def label(self): - return self._label - - @property - def name(self): - return self._label - - def set_label(self, label): - self._label = label - def is_empty(self): return self.number_of_nodes() == 0 def validate(self) -> None: validate_state(self) - def set_default_lineinfo(self, lineinfo: dtypes.DebugInfo): - """ - Sets the default source line information to be lineinfo, or None to - revert to default mode. - """ - self._default_lineinfo = lineinfo - def nodes(self) -> List[nd.Node]: # Added for type hints return super().nodes() @@ -1981,8 +2313,244 @@ def fill_scope_connectors(self): node.add_in_connector(edge.dst_conn) -class StateSubgraphView(SubgraphView, StateGraphView): +class StateSubgraphView(SubgraphView, DataflowGraphView): """ A read-only subgraph view of an SDFG state. """ def __init__(self, graph, subgraph_nodes): super().__init__(graph, subgraph_nodes) + + +@make_properties +class ControlFlowRegion(OrderedDiGraph[ControlFlowBlock, 'dace.sdfg.InterstateEdge'], ControlGraphView, + ControlFlowBlock): + + def __init__(self, + label: str=''): + OrderedDiGraph.__init__(self) + ControlGraphView.__init__(self) + ControlFlowBlock.__init__(self, label) + + self._labels: Set[str] = set() + self._start_block: Optional[int] = None + self._cached_start_block: Optional[ControlFlowBlock] = None + + def add_edge(self, src: ControlFlowBlock, dst: ControlFlowBlock, data: 'dace.sdfg.InterstateEdge'): + """ Adds a new edge to the graph. Must be an InterstateEdge or a subclass thereof. + + :param u: Source node. + :param v: Destination node. + :param edge: The edge to add. 
+ """ + if not isinstance(src, ControlFlowBlock): + raise TypeError('Expected ControlFlowBlock, got ' + str(type(src))) + if not isinstance(dst, ControlFlowBlock): + raise TypeError('Expected ControlFlowBlock, got ' + str(type(dst))) + if not isinstance(data, dace.sdfg.InterstateEdge): + raise TypeError('Expected InterstateEdge, got ' + str(type(data))) + if dst is self._cached_start_block: + self._cached_start_block = None + return super().add_edge(src, dst, data) + + def add_node(self, node, is_start_block=False, *, is_start_state: bool=None): + if not isinstance(node, ControlFlowBlock): + raise TypeError('Expected ControlFlowBlock, got ' + str(type(node))) + super().add_node(node) + self._cached_start_block = None + start_block = is_start_block + if is_start_state is not None: + warnings.warn('is_start_state is deprecated, use is_start_block instead', DeprecationWarning) + start_block = is_start_state + + if start_block: + self.start_block = len(self.nodes()) - 1 + self._cached_start_block = node + + def add_state(self, label=None, is_start_block=False, *, is_start_state: bool=None) -> SDFGState: + if self._labels is None or len(self._labels) != self.number_of_nodes(): + self._labels = set(s.label for s in self.nodes()) + label = label or 'state' + existing_labels = self._labels + label = dt.find_new_name(label, existing_labels) + state = SDFGState(label) + state.parent = self + self._labels.add(label) + start_block = is_start_block + if is_start_state is not None: + warnings.warn('is_start_state is deprecated, use is_start_block instead', DeprecationWarning) + start_block = is_start_state + self.add_node(state, is_start_block=start_block) + return state + + def add_state_before(self, state: SDFGState, label=None, is_start_state=False) -> SDFGState: + """ Adds a new SDFG state before an existing state, reconnecting predecessors to it instead. + + :param state: The state to prepend the new state before. + :param label: State label. + :param is_start_state: If True, resets scope block starting state to this state. + :return: A new SDFGState object. + """ + new_state = self.add_state(label, is_start_state) + # Reconnect + for e in self.in_edges(state): + self.remove_edge(e) + self.add_edge(e.src, new_state, e.data) + # Add unconditional connection between the new state and the current + self.add_edge(new_state, state, dace.sdfg.InterstateEdge()) + return new_state + + def add_state_after(self, state: SDFGState, label=None, is_start_state=False) -> SDFGState: + """ Adds a new SDFG state after an existing state, reconnecting it to the successors instead. + + :param state: The state to append the new state after. + :param label: State label. + :param is_start_state: If True, resets SDFG starting state to this state. + :return: A new SDFGState object. 
+        """
+        new_state = self.add_state(label, is_start_state)
+        # Reconnect
+        for e in self.out_edges(state):
+            self.remove_edge(e)
+            self.add_edge(new_state, e.dst, e.data)
+        # Add unconditional connection between the current and the new state
+        self.add_edge(state, new_state, dace.sdfg.InterstateEdge())
+        return new_state
+
+    def _used_symbols_internal(self,
+                               all_symbols: bool,
+                               defined_syms: Optional[Set] = None,
+                               free_syms: Optional[Set] = None,
+                               used_before_assignment: Optional[Set] = None,
+                               keep_defined_in_mapping: bool = False) -> Tuple[Set[str], Set[str], Set[str]]:
+        defined_syms = set() if defined_syms is None else defined_syms
+        free_syms = set() if free_syms is None else free_syms
+        used_before_assignment = set() if used_before_assignment is None else used_before_assignment
+
+        try:
+            ordered_blocks = self.topological_sort(self.start_block)
+        except ValueError:  # Failsafe (e.g., for invalid or empty SDFGs)
+            ordered_blocks = self.nodes()
+
+        for block in ordered_blocks:
+            state_symbols = set()
+            if isinstance(block, ControlFlowRegion):
+                b_free_syms, b_defined_syms, b_used_before_syms = block._used_symbols_internal(all_symbols)
+                free_syms |= b_free_syms
+                defined_syms |= b_defined_syms
+                used_before_assignment |= b_used_before_syms
+                state_symbols = b_free_syms
+            else:
+                state_symbols = block.used_symbols(all_symbols)
+                free_syms |= state_symbols
+
+            # Add free inter-state symbols
+            for e in self.out_edges(block):
+                # NOTE: First we get the true InterstateEdge free symbols, then we compute the newly defined symbols
+                # by subtracting the (true) free symbols from the edge's assignment keys. This way we can correctly
+                # compute the symbols that are used before being assigned.
+                efsyms = e.data.used_symbols(all_symbols)
+                defined_syms |= set(e.data.assignments.keys()) - (efsyms | state_symbols)
+                used_before_assignment.update(efsyms - defined_syms)
+                free_syms |= efsyms
+
+        # Remove symbols that were used before they were assigned.
+        defined_syms -= used_before_assignment
+
+        if isinstance(self, dace.SDFG):
+            # Remove from defined symbols those that are in the symbol mapping
+            if self.parent_nsdfg_node is not None and keep_defined_in_mapping:
+                defined_syms -= set(self.parent_nsdfg_node.symbol_mapping.keys())
+
+            # Add the set of SDFG symbol parameters
+            # If all_symbols is False, those symbols would only be added in the case of non-Python tasklets
+            if all_symbols:
+                free_syms |= set(self.symbols.keys())
+
+        # Subtract symbols defined in inter-state edges and constants from the list of free symbols.
+        free_syms -= defined_syms
+
+        return free_syms, defined_syms, used_before_assignment
+
+    def to_json(self, parent=None):
+        graph_json = OrderedDiGraph.to_json(self)
+        block_json = ControlFlowBlock.to_json(self, parent)
+        graph_json.update(block_json)
+        return graph_json
+
+    ###################################################################
+    # Traversal methods
+
+    def all_control_flow_regions(self, recursive=False) -> Iterator['ControlFlowRegion']:
+        """ Iterate over this and all nested control flow regions. """
+        yield self
+        for block in self.nodes():
+            if isinstance(block, SDFGState) and recursive:
+                for node in block.nodes():
+                    if isinstance(node, nd.NestedSDFG):
+                        yield from node.sdfg.all_control_flow_regions(recursive=recursive)
+            elif isinstance(block, ControlFlowRegion):
+                yield from block.all_control_flow_regions(recursive=recursive)
+
+    def all_sdfgs_recursive(self) -> Iterator['dace.SDFG']:
+        """ Iterate over this and all nested SDFGs. """
+        for cfg in self.all_control_flow_regions(recursive=True):
+            if isinstance(cfg, dace.SDFG):
+                yield cfg
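+
+    # Note: `all_control_flow_regions(recursive=True)` also descends into
+    # nested SDFGs, while `all_states` below only recurses through control
+    # flow regions of this graph. For example, the number of states in this
+    # region, including states nested in child regions, is
+    # `sum(1 for _ in self.all_states())`.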
""" + for cfg in self.all_control_flow_regions(recursive=True): + if isinstance(cfg, dace.SDFG): + yield cfg + + def all_states(self) -> Iterator[SDFGState]: + """ Iterate over all states in this control flow graph. """ + for block in self.nodes(): + if isinstance(block, SDFGState): + yield block + elif isinstance(block, ControlFlowRegion): + yield from block.all_states() + + def all_control_flow_blocks(self, recursive=False) -> Iterator[ControlFlowBlock]: + """ Iterate over all control flow blocks in this control flow graph. """ + for cfg in self.all_control_flow_regions(recursive=recursive): + for block in cfg.nodes(): + yield block + + def all_interstate_edges(self, recursive=False) -> Iterator[Edge['dace.sdfg.InterstateEdge']]: + """ Iterate over all interstate edges in this control flow graph. """ + for cfg in self.all_control_flow_regions(recursive=recursive): + for edge in cfg.edges(): + yield edge + + ################################################################### + # Getters & setters, overrides + + def __str__(self): + return ControlFlowBlock.__str__(self) + + def __repr__(self) -> str: + return f'{self.__class__.__name__} ({self.label})' + + @property + def start_block(self): + """ Returns the starting block of this ControlFlowGraph. """ + if self._cached_start_block is not None: + return self._cached_start_block + + source_nodes = self.source_nodes() + if len(source_nodes) == 1: + self._cached_start_block = source_nodes[0] + return source_nodes[0] + # If the starting block is ambiguous allow manual override. + if self._start_block is not None: + self._cached_start_block = self.node(self._start_block) + return self._cached_start_block + raise ValueError('Ambiguous or undefined starting block for ControlFlowGraph, ' + 'please use "is_start_block=True" when adding the ' + 'starting block with "add_state" or "add_node"') + + @start_block.setter + def start_block(self, block_id): + """ Manually sets the starting block of this ControlFlowGraph. + + :param block_id: The node ID (use `node_id(block)`) of the block to set. 
+ """ + if block_id < 0 or block_id >= self.number_of_nodes(): + raise ValueError('Invalid state ID') + self._start_block = block_id + self._cached_start_block = self.node(block_id) diff --git a/dace/sdfg/utils.py b/dace/sdfg/utils.py index 1078414161..621f8a9e16 100644 --- a/dace/sdfg/utils.py +++ b/dace/sdfg/utils.py @@ -668,7 +668,7 @@ def consolidate_edges(sdfg: SDFG, starting_scope=None) -> int: from dace.sdfg.propagation import propagate_memlets_scope total_consolidated = 0 - for state in sdfg.nodes(): + for state in sdfg.states(): # Start bottom-up if starting_scope and starting_scope.entry not in state.nodes(): continue @@ -1206,8 +1206,8 @@ def fuse_states(sdfg: SDFG, permissive: bool = False, progress: bool = None) -> counter = 0 if progress is True or progress is None: fusible_states = 0 - for sd in sdfg.all_sdfgs_recursive(): - fusible_states += sd.number_of_edges() + for cfg in sdfg.all_control_flow_regions(): + fusible_states += cfg.number_of_edges() if progress is True: pbar = tqdm(total=fusible_states, desc='Fusing states') @@ -1217,30 +1217,32 @@ def fuse_states(sdfg: SDFG, permissive: bool = False, progress: bool = None) -> for sd in sdfg.all_sdfgs_recursive(): id = sd.sdfg_id - while True: - edges = list(sd.nx.edges) - applied = 0 - skip_nodes = set() - for u, v in edges: - if (progress is None and tqdm is not None and (time.time() - start) > 5): - progress = True - pbar = tqdm(total=fusible_states, desc='Fusing states', initial=counter) - - if u in skip_nodes or v in skip_nodes: - continue - candidate = {StateFusion.first_state: u, StateFusion.second_state: v} - sf = StateFusion() - sf.setup_match(sd, id, -1, candidate, 0, override=True) - if sf.can_be_applied(sd, 0, sd, permissive=permissive): - sf.apply(sd, sd) - applied += 1 - counter += 1 - if progress: - pbar.update(1) - skip_nodes.add(u) - skip_nodes.add(v) - if applied == 0: - break + for cfg in sd.all_control_flow_regions(): + while True: + edges = list(cfg.nx.edges) + applied = 0 + skip_nodes = set() + for u, v in edges: + if (progress is None and tqdm is not None and (time.time() - start) > 5): + progress = True + pbar = tqdm(total=fusible_states, desc='Fusing states', initial=counter) + + if (u in skip_nodes or v in skip_nodes or not isinstance(v, SDFGState) or + not isinstance(u, SDFGState)): + continue + candidate = {StateFusion.first_state: u, StateFusion.second_state: v} + sf = StateFusion() + sf.setup_match(cfg, id, -1, candidate, 0, override=True) + if sf.can_be_applied(cfg, 0, sd, permissive=permissive): + sf.apply(cfg, sd) + applied += 1 + counter += 1 + if progress: + pbar.update(1) + skip_nodes.add(u) + skip_nodes.add(v) + if applied == 0: + break if progress: pbar.close() return counter diff --git a/dace/transformation/dataflow/double_buffering.py b/dace/transformation/dataflow/double_buffering.py index 8ff70a6355..6efe6543ca 100644 --- a/dace/transformation/dataflow/double_buffering.py +++ b/dace/transformation/dataflow/double_buffering.py @@ -128,7 +128,7 @@ def apply(self, graph: sd.SDFGState, sdfg: sd.SDFG): ############################## # Add initial reads to initial nested state initial_state: sd.SDFGState = nsdfg_node.sdfg.start_state - initial_state.set_label('%s_init' % map_entry.map.label) + initial_state.label = '%s_init' % map_entry.map.label for edge in edges_to_replace: initial_state.add_node(edge.src) rnode = edge.src @@ -152,7 +152,7 @@ def apply(self, graph: sd.SDFGState, sdfg: sd.SDFG): # Add the main state's contents to the last state, modifying # memlets appropriately. 
 final_state: sd.SDFGState = nsdfg_node.sdfg.sink_nodes()[0]
-        final_state.set_label('%s_final_computation' % map_entry.map.label)
+        final_state.label = '%s_final_computation' % map_entry.map.label
         dup_nstate = copy.deepcopy(nstate)
         final_state.add_nodes_from(dup_nstate.nodes())
         for e in dup_nstate.edges():
@@ -183,7 +183,7 @@ def apply(self, graph: sd.SDFGState, sdfg: sd.SDFG):
             nstate.add_edge(rnode, edge.src_conn, wnode, edge.dst_conn, new_memlet)
 
-        nstate.set_label('%s_double_buffered' % map_entry.map.label)
+        nstate.label = '%s_double_buffered' % map_entry.map.label
         # Divide by loop stride
         new_expr = symbolic.pystr_to_symbolic('((%s / %s) + 1) %% 2' % (map_param, map_rstride))
         sd.replace(nstate, '__dace_db_param', new_expr)
diff --git a/dace/transformation/interstate/loop_unroll.py b/dace/transformation/interstate/loop_unroll.py
index 47d438a2fc..b1dbfdd5c9 100644
--- a/dace/transformation/interstate/loop_unroll.py
+++ b/dace/transformation/interstate/loop_unroll.py
@@ -116,8 +116,7 @@ def instantiate_loop(
 
         # Replace iterate with value in each state
         for state in new_states:
-            state.set_label(state.label + '_' + itervar + '_' +
-                            (state_suffix if state_suffix is not None else str(value)))
+            state.label = state.label + '_' + itervar + '_' + (state_suffix if state_suffix is not None else str(value))
             state.replace(itervar, value)
 
         # Add subgraph to original SDFG
diff --git a/dace/transformation/interstate/multistate_inline.py b/dace/transformation/interstate/multistate_inline.py
index 74dd51a483..4d560ab70a 100644
--- a/dace/transformation/interstate/multistate_inline.py
+++ b/dace/transformation/interstate/multistate_inline.py
@@ -334,7 +334,7 @@ def apply(self, outer_state: SDFGState, sdfg: SDFG):
                 if nstate.label in statenames:
                     newname = data.find_new_name(nstate.label, statenames)
                     statenames.add(newname)
-                    nstate.set_label(newname)
+                    nstate.label = newname
 
         #######################################################
         # Add nested SDFG states into top-level SDFG
diff --git a/doc/sdfg/images/elements.svg b/doc/sdfg/images/elements.svg
index 80d35e39f0..6402de8e1d 100644
--- a/doc/sdfg/images/elements.svg
+++ b/doc/sdfg/images/elements.svg
@@ -1,90 +1,506 @@
[SVG figure diff omitted: the elements legend is redrawn at a larger size with the same text labels (Access Nodes: Transient, Global; Stream; View; Reference; Tasklet; Nested SDFG; Consume; Map; Library Node; Memlets "A[0], CR: Sum, Volume: 1" and "B[i, j], Volume: 1"; Write-Conflict Resolution; State; State Transition) and a new "Control Flow Region" element.]
diff --git a/doc/sdfg/ir.rst b/doc/sdfg/ir.rst
index 3c651fab19..f7bbb0ff79 100644
--- a/doc/sdfg/ir.rst
+++ b/doc/sdfg/ir.rst
@@ -29,7 +29,7 @@ Some of the main differences between SDFGs and other representations are:
 The Language
 ------------
 
-In a nutshell, an SDFG is a state machine of acyclic dataflow multigraphs. Here is an example graph:
+In a nutshell, an SDFG is a hierarchical state machine of acyclic dataflow multigraphs. Here is an example graph:
 
 .. raw:: html
@@ -43,7 +43,7 @@ In a nutshell, an SDFG is a
Here
 The cyan rectangles are called **states** and together they form a state machine, executing the code from the starting
 state and following the blue edge that matches the conditions. In each state, an acyclic multigraph controls execution
-through dataflow. There are four elements in the above state:
+through dataflow. There are four elements in the above states:
 
 * **Access nodes** (ovals) that give access to data containers
 * **Memlets** (edges/dotted arrows) that represent units of data movement
@@ -58,7 +58,14 @@ The state machine shown in the example is a for-loop (``for _ in range(5)``). Th
 the guard state controls the loop, and at the end the result is copied to the special ``__return`` data container,
 which designates the return value of the function.
 
-There are other kinds of elements in an SDFG, as detailed below.
+The state machine is analogous to a control flow graph, where states represent basic blocks. Multiple such basic
+blocks, such as the loop described above, can be grouped into a **control flow region**, which is represented by a
+single graph node in the SDFG's state machine; this is useful for optimization and analysis. The SDFG itself can be
+thought of as one big control flow region: control flow regions are directed graphs whose nodes are states or other
+control flow regions, and whose edges are state transitions.
+
+In addition to the elements seen in the example above, there are other kinds of elements in an SDFG, which are detailed
+below.
 
 .. _sdfg-lang:
@@ -142,6 +149,12 @@ new value, and specifies how the update is performed. In the summation example,
 end of each state there is an implicit synchronization point, so it will not finish executing until all the last nodes
 have been reached (this assumption can be removed in extreme cases, see :class:`~dace.sdfg.state.SDFGState.nosync`).
 
+**Control Flow Region**: Forms a directed graph of states and other control flow regions, where edges are state
+transitions. This allows representing complex control flow in a single graph node, which is useful for analysis and
+optimization. The SDFG itself is a control flow region, which means that control flow regions are recursive, i.e.,
+hierarchical. Like the SDFG, each control flow region has a unique starting block, which is the entry point to the
+region and is executed first.
+
 **State Transition**: Transitions, internally referred to as *inter-state edges*, specify how execution proceeds after
 the end of a State. Inter-state edges optionally contain a symbolic *condition* that is checked at the end of the
 preceding state. If any of the conditions are true, execution will continue to the destination of this edge (the
@@ -783,5 +796,7 @@ file uses the :func:`~dace.sdfg.sdfg.SDFG.from_file` static method. For example,
 
 The ``compress`` argument can be used to save a smaller (``gzip`` compressed) file. It can keep the same extension,
 but it is customary to use ``.sdfg.gz`` or ``.sdfgz`` to let others know it is compressed.
+It is recommended to use this option for large SDFGs, as it not only saves space, but also speeds up loading and
+editing of the SDFG in visualization tools and the VSCode extension.
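
The behavior documented above maps directly onto the Python API introduced in this series. The following is a minimal sketch (assuming this patch is applied; `my_program` and the state labels are placeholder names) that builds a two-state region with the new `is_start_block` flag, inspects it through the control flow region interface, and saves it compressed as recommended:

    import dace

    # Build a minimal state machine: a start state and one successor.
    sdfg = dace.SDFG('my_program')
    init = sdfg.add_state('init', is_start_block=True)  # replaces the deprecated is_start_state
    compute = sdfg.add_state_after(init, 'compute')     # connected by an unconditional transition

    # The SDFG itself is a control flow region: it has a unique starting block
    # and supports the recursive state traversal added in this patch.
    assert sdfg.start_block is init
    assert [s.label for s in sdfg.all_states()] == ['init', 'compute']

    # Save gzip-compressed, as recommended above for large SDFGs.
    sdfg.save('my_program.sdfgz', compress=True)
    loaded = dace.SDFG.from_file('my_program.sdfgz')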
diff --git a/requirements.txt b/requirements.txt index 5f804e1b4c..27560949fb 100644 --- a/requirements.txt +++ b/requirements.txt @@ -14,9 +14,9 @@ Jinja2==3.1.2 MarkupSafe==2.1.3 mpmath==1.3.0 networkx==3.1 -numpy==1.24.3 +numpy==1.26.1 ply==3.11 -PyYAML==6.0 +PyYAML==6.0.1 requests==2.31.0 six==1.16.0 sympy==1.9 diff --git a/tests/sdfg/nested_control_flow_regions_test.py b/tests/sdfg/nested_control_flow_regions_test.py new file mode 100644 index 0000000000..f29c093dad --- /dev/null +++ b/tests/sdfg/nested_control_flow_regions_test.py @@ -0,0 +1,18 @@ +# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. +import pytest + +import dace + + +def test_is_start_state_deprecation(): + sdfg = dace.SDFG('deprecation_test') + with pytest.deprecated_call(): + sdfg.add_state('state1', is_start_state=True) + sdfg2 = dace.SDFG('deprecation_test2') + state = dace.SDFGState('state2') + with pytest.deprecated_call(): + sdfg2.add_node(state, is_start_state=True) + + +if __name__ == '__main__': + test_is_start_state_deprecation() diff --git a/tests/sdfg_validate_names_test.py b/tests/sdfg_validate_names_test.py index dad79c8950..1650a4e4b1 100644 --- a/tests/sdfg_validate_names_test.py +++ b/tests/sdfg_validate_names_test.py @@ -28,7 +28,7 @@ def test_state_duplication(self): sdfg = dace.SDFG('ok') s1 = sdfg.add_state('also_ok') s2 = sdfg.add_state('also_ok') - s2.set_label('also_ok') + s2.label = 'also_ok' sdfg.add_edge(s1, s2, dace.InterstateEdge()) sdfg.validate() self.fail('Failed to detect duplicate state') From dff301c3d28c4cb3d0a6ba6c017bce22f941f6f6 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 2 Nov 2023 17:25:33 +0100 Subject: [PATCH 16/19] Bump werkzeug from 2.3.5 to 3.0.1 (#1409) Bumps [werkzeug](https://github.com/pallets/werkzeug) from 2.3.5 to 3.0.1. - [Release notes](https://github.com/pallets/werkzeug/releases) - [Changelog](https://github.com/pallets/werkzeug/blob/main/CHANGES.rst) - [Commits](https://github.com/pallets/werkzeug/compare/2.3.5...3.0.1) --- updated-dependencies: - dependency-name: werkzeug dependency-type: direct:production ... 
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 27560949fb..266b3368c8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -22,5 +22,5 @@ six==1.16.0 sympy==1.9 urllib3==2.0.7 websockets==11.0.3 -Werkzeug==2.3.5 +Werkzeug==3.0.1 zipp==3.15.0 From ab11b20a66e720b5250ab46580787b20c87418e2 Mon Sep 17 00:00:00 2001 From: matteonussbauemer Date: Thu, 2 Nov 2023 18:43:32 +0100 Subject: [PATCH 17/19] set sympy version back to 1.9 --- requirements.txt | 2 +- setup.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/requirements.txt b/requirements.txt index 12c50a2eb5..5f804e1b4c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -19,7 +19,7 @@ ply==3.11 PyYAML==6.0 requests==2.31.0 six==1.16.0 -sympy==1.12 +sympy==1.9 urllib3==2.0.7 websockets==11.0.3 Werkzeug==2.3.5 diff --git a/setup.py b/setup.py index cd5189437e..a0ac2e2d49 100644 --- a/setup.py +++ b/setup.py @@ -73,7 +73,7 @@ }, include_package_data=True, install_requires=[ - 'numpy', 'networkx >= 2.5', 'astunparse', 'sympy>=1.12', 'pyyaml', 'ply', 'websockets', 'requests', 'flask', + 'numpy', 'networkx >= 2.5', 'astunparse', 'sympy<=1.9', 'pyyaml', 'ply', 'websockets', 'requests', 'flask', 'fparser >= 0.1.3', 'aenum >= 3.1', 'dataclasses; python_version < "3.7"', 'dill', 'pyreadline;platform_system=="Windows"', 'typing-compat; python_version < "3.8"' ] + cmake_requires, From 5389a3136605f5ad59a0bd610eaca75906e1069c Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Thu, 2 Nov 2023 18:00:21 -0700 Subject: [PATCH 18/19] GPU code generation: User-specified block/thread/warp location (#1358) * Remove persistent GPU kernel write scope heuristics * Allow CUDA device-level tasklets to have user-specified block/thread/warp specialization * Logic fixes for CPU dispatch in GPU code generator --- dace/codegen/targets/cuda.py | 90 ++++++++++++++++++++++++++++++++---- tests/cuda_block_test.py | 38 +++++++++++++++ 2 files changed, 120 insertions(+), 8 deletions(-) diff --git a/dace/codegen/targets/cuda.py b/dace/codegen/targets/cuda.py index a465d2bbc0..fb8ae90187 100644 --- a/dace/codegen/targets/cuda.py +++ b/dace/codegen/targets/cuda.py @@ -445,7 +445,7 @@ def node_dispatch_predicate(self, sdfg, state, node): if hasattr(node, 'schedule'): # NOTE: Works on nodes and scopes if node.schedule in dtypes.GPU_SCHEDULES: return True - if isinstance(node, nodes.NestedSDFG) and CUDACodeGen._in_device_code: + if CUDACodeGen._in_device_code: return True return False @@ -1324,11 +1324,11 @@ def generate_devicelevel_state(self, sdfg, state, function_stream, callsite_stre if write_scope == 'grid': callsite_stream.write("if (blockIdx.x == 0 " - "&& threadIdx.x == 0) " - "{ // sub-graph begin", sdfg, state.node_id) + "&& threadIdx.x == 0) " + "{ // sub-graph begin", sdfg, state.node_id) elif write_scope == 'block': callsite_stream.write("if (threadIdx.x == 0) " - "{ // sub-graph begin", sdfg, state.node_id) + "{ // sub-graph begin", sdfg, state.node_id) else: callsite_stream.write("{ // subgraph begin", sdfg, state.node_id) else: @@ -2519,15 +2519,17 @@ def generate_devicelevel_scope(self, sdfg, dfg_scope, state_id, function_stream, def generate_node(self, sdfg, dfg, state_id, node, function_stream, callsite_stream): if self.node_dispatch_predicate(sdfg, dfg, node): # Dynamically obtain node generator according to class name - gen = getattr(self, 
'_generate_' + type(node).__name__)
-            gen(sdfg, dfg, state_id, node, function_stream, callsite_stream)
-            return
+            gen = getattr(self, '_generate_' + type(node).__name__, False)
+            if gen is not False:  # Not every node type has a code generator here
+                gen(sdfg, dfg, state_id, node, function_stream, callsite_stream)
+                return
 
         if not CUDACodeGen._in_device_code:
             self._cpu_codegen.generate_node(sdfg, dfg, state_id, node, function_stream, callsite_stream)
             return
 
-        self._locals.clear_scope(self._code_state.indentation + 1)
+        if isinstance(node, nodes.ExitNode):
+            self._locals.clear_scope(self._code_state.indentation + 1)
 
         if CUDACodeGen._in_device_code and isinstance(node, nodes.MapExit):
             return  # skip
@@ -2591,6 +2593,78 @@ def _generate_MapExit(self, sdfg, dfg, state_id, node, function_stream, callsite
 
         self._cpu_codegen._generate_MapExit(sdfg, dfg, state_id, node, function_stream, callsite_stream)
 
+    def _get_thread_id(self) -> str:
+        result = 'threadIdx.x'
+        if self._block_dims[1] != 1:
+            result += f' + ({sym2cpp(self._block_dims[0])}) * threadIdx.y'
+        if self._block_dims[2] != 1:
+            result += f' + ({sym2cpp(self._block_dims[0] * self._block_dims[1])}) * threadIdx.z'
+        return result
+
+    def _get_warp_id(self) -> str:
+        return f'(({self._get_thread_id()}) / warpSize)'
+
+    def _get_block_id(self) -> str:
+        result = 'blockIdx.x'
+        if self._block_dims[1] != 1:
+            result += f' + gridDim.x * blockIdx.y'
+        if self._block_dims[2] != 1:
+            result += f' + gridDim.x * gridDim.y * blockIdx.z'
+        return result
+
+    def _generate_condition_from_location(self, name: str, index_expr: str, node: nodes.Tasklet,
+                                          callsite_stream: CodeIOStream) -> int:
+        if name not in node.location:
+            return 0
+
+        location: Union[int, str, subsets.Range] = node.location[name]
+        if isinstance(location, str) and ':' in location:
+            location = subsets.Range.from_string(location)
+        elif symbolic.issymbolic(location):
+            location = sym2cpp(location)
+
+        if isinstance(location, subsets.Range):
+            # Range of indices
+            if len(location) != 1:
+                raise ValueError(f'Only one-dimensional ranges are allowed for {name} specialization, {location} given')
+            begin, end, stride = location[0]
+            rb, re, rs = sym2cpp(begin), sym2cpp(end), sym2cpp(stride)
+            cond = ''
+            cond += f'(({index_expr}) >= {rb}) && (({index_expr}) <= {re})'
+            if stride != 1:
+                cond += f' && ((({index_expr}) - {rb}) % {rs} == 0)'
+
+            callsite_stream.write(f'if ({cond}) {{')
+        else:
+            # Single element
+            callsite_stream.write(f'if (({index_expr}) == {location}) {{')
+
+        return 1
+
+    def _generate_Tasklet(self, sdfg: SDFG, dfg, state_id: int, node: nodes.Tasklet, function_stream: CodeIOStream,
+                          callsite_stream: CodeIOStream):
+        generated_preamble_scopes = 0
+        if self._in_device_code:
+            # If the location dictionary prescribes that the code should run on a certain group of
+            # threads/blocks, add a condition
+            generated_preamble_scopes += self._generate_condition_from_location('gpu_thread', self._get_thread_id(),
+                                                                                node, callsite_stream)
+            generated_preamble_scopes += self._generate_condition_from_location('gpu_warp', self._get_warp_id(), node,
+                                                                                callsite_stream)
+            generated_preamble_scopes += self._generate_condition_from_location('gpu_block', self._get_block_id(), node,
+                                                                                callsite_stream)
+
+        # Call standard tasklet generation
+        old_codegen = self._cpu_codegen.calling_codegen
+        self._cpu_codegen.calling_codegen = self
+        self._cpu_codegen._generate_Tasklet(sdfg, dfg, state_id, node, function_stream, callsite_stream)
+        self._cpu_codegen.calling_codegen = old_codegen
+
+        if
generated_preamble_scopes > 0: + # Generate appropriate postamble + for i in range(generated_preamble_scopes): + callsite_stream.write('}', sdfg, state_id, node) + def make_ptr_vector_cast(self, *args, **kwargs): return cpp.make_ptr_vector_cast(*args, **kwargs) diff --git a/tests/cuda_block_test.py b/tests/cuda_block_test.py index f77e80673f..676785e0e5 100644 --- a/tests/cuda_block_test.py +++ b/tests/cuda_block_test.py @@ -10,8 +10,10 @@ @dace.program(dace.float64[N], dace.float64[N]) def cudahello(V, Vout): + @dace.mapscope(_[0:N:32]) def multiplication(i): + @dace.map(_[0:32]) def mult_block(bi): in_V << V[i + bi] @@ -55,6 +57,7 @@ def test_gpu(): @pytest.mark.gpu def test_different_block_sizes_nesting(): + @dace.program def nested(V: dace.float64[34], v1: dace.float64[1]): with dace.tasklet: @@ -105,6 +108,7 @@ def diffblocks(V: dace.float64[130], v1: dace.float64[4], v2: dace.float64[128]) @pytest.mark.gpu def test_custom_block_size_onemap(): + @dace.program def tester(A: dace.float64[400, 300]): for i, j in dace.map[0:400, 0:300]: @@ -132,6 +136,7 @@ def tester(A: dace.float64[400, 300]): @pytest.mark.gpu def test_custom_block_size_twomaps(): + @dace.program def tester(A: dace.float64[400, 300, 2, 32]): for i, j in dace.map[0:400, 0:300]: @@ -154,9 +159,42 @@ def tester(A: dace.float64[400, 300, 2, 32]): sdfg.compile() +@pytest.mark.gpu +def test_block_thread_specialization(): + + @dace.program + def tester(A: dace.float64[200]): + for i in dace.map[0:200:32]: + for bi in dace.map[0:32]: + with dace.tasklet: + a >> A[i + bi] + a = 1 + with dace.tasklet: # Tasklet to be specialized + a >> A[i + bi] + a = 2 + + sdfg = tester.to_sdfg() + sdfg.apply_gpu_transformations(sequential_innermaps=False) + tasklet = next(n for n, _ in sdfg.all_nodes_recursive() + if isinstance(n, dace.nodes.Tasklet) and '2' in n.code.as_string) + tasklet.location['gpu_thread'] = dace.subsets.Range.from_string('2:9:3') + tasklet.location['gpu_block'] = 1 + + code = sdfg.generate_code()[1].clean_code # Get GPU code (second file) + assert '>= 2' in code and '<= 8' in code + assert ' == 1' in code + + a = np.random.rand(200) + ref = np.ones_like(a) + ref[32:64][2:9:3] = 2 + sdfg(a) + assert np.allclose(a, ref) + + if __name__ == "__main__": test_cpu() test_gpu() test_different_block_sizes_nesting() test_custom_block_size_onemap() test_custom_block_size_twomaps() + test_block_thread_specialization() From 9430e874dadcf77e45a03887e66fd6da4a9cc4b0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lukas=20Tr=C3=BCmper?= Date: Fri, 3 Nov 2023 09:07:31 +0100 Subject: [PATCH 19/19] AugAssignToWCR: Support for more cases and increased test coverage (#1359) --- .../transformation/dataflow/wcr_conversion.py | 152 ++++++----- tests/transformations/wcr_conversion_test.py | 247 ++++++++++++++++++ 2 files changed, 332 insertions(+), 67 deletions(-) create mode 100644 tests/transformations/wcr_conversion_test.py diff --git a/dace/transformation/dataflow/wcr_conversion.py b/dace/transformation/dataflow/wcr_conversion.py index e95674adc1..7f4fbc654d 100644 --- a/dace/transformation/dataflow/wcr_conversion.py +++ b/dace/transformation/dataflow/wcr_conversion.py @@ -2,10 +2,14 @@ """ Transformations to convert subgraphs to write-conflict resolutions. 
""" import ast import re -from dace import registry, nodes, dtypes +import copy +from dace import registry, nodes, dtypes, Memlet from dace.transformation import transformation, helpers as xfh from dace.sdfg import graph as gr, utils as sdutil from dace import SDFG, SDFGState +from dace.sdfg.state import StateSubgraphView +from dace.transformation import helpers +from dace.sdfg.propagation import propagate_memlets_state class AugAssignToWCR(transformation.SingleStateTransformation): @@ -20,6 +24,7 @@ class AugAssignToWCR(transformation.SingleStateTransformation): map_exit = transformation.PatternNode(nodes.MapExit) _EXPRESSIONS = ['+', '-', '*', '^', '%'] #, '/'] + _FUNCTIONS = ['min', 'max'] _EXPR_MAP = {'-': ('+', '-({expr})'), '/': ('*', '((decltype({expr}))1)/({expr})')} _PYOP_MAP = {ast.Add: '+', ast.Sub: '-', ast.Mult: '*', ast.BitXor: '^', ast.Mod: '%', ast.Div: '/'} @@ -27,6 +32,7 @@ class AugAssignToWCR(transformation.SingleStateTransformation): def expressions(cls): return [ sdutil.node_path_graph(cls.input, cls.tasklet, cls.output), + sdutil.node_path_graph(cls.input, cls.map_entry, cls.tasklet, cls.map_exit, cls.output) ] def can_be_applied(self, graph, expr_index, sdfg, permissive=False): @@ -38,7 +44,6 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive=False): # Free tasklet if expr_index == 0: - # Only free tasklets supported for now if graph.entry_node(tasklet) is not None: return False @@ -49,8 +54,6 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive=False): # Make sure augmented assignment can be fissioned as necessary if any(not isinstance(e.src, nodes.AccessNode) for e in graph.in_edges(tasklet)): return False - if graph.in_degree(inarr) > 0 and graph.out_degree(outarr) > 0: - return False outedge = graph.edges_between(tasklet, outarr)[0] else: # Free map @@ -65,12 +68,10 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive=False): if len(graph.edges_between(tasklet, mx)) > 1: return False - # Currently no fission is supported + # Make sure augmented assignment can be fissioned as necessary if any(e.src is not me and not isinstance(e.src, nodes.AccessNode) for e in graph.in_edges(me) + graph.in_edges(tasklet)): return False - if graph.in_degree(inarr) > 0: - return False outedge = graph.edges_between(tasklet, mx)[0] @@ -78,6 +79,7 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive=False): outconn = outedge.src_conn ops = '[%s]' % ''.join(re.escape(o) for o in AugAssignToWCR._EXPRESSIONS) + funcs = '|'.join(re.escape(o) for o in AugAssignToWCR._FUNCTIONS) if tasklet.language is dtypes.Language.Python: # Match a single assignment with a binary operation as RHS @@ -108,18 +110,33 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive=False): # Try to match a single C assignment that can be converted to WCR inconn = edge.dst_conn lhs = r'^\s*%s\s*=\s*%s\s*%s.*;$' % (re.escape(outconn), re.escape(inconn), ops) - rhs = r'^\s*%s\s*=\s*.*%s\s*%s;$' % (re.escape(outconn), ops, re.escape(inconn)) - if re.match(lhs, cstr) is None: - continue + # rhs: a = (...) 
op b + rhs = r'^\s*%s\s*=\s*\(.*\)\s*%s\s*%s;$' % (re.escape(outconn), ops, re.escape(inconn)) + func_lhs = r'^\s*%s\s*=\s*(%s)\(\s*%s\s*,.*\)\s*;$' % (re.escape(outconn), funcs, re.escape(inconn)) + func_rhs = r'^\s*%s\s*=\s*(%s)\(.*,\s*%s\s*\)\s*;$' % (re.escape(outconn), funcs, re.escape(inconn)) + if re.match(lhs, cstr) is None and re.match(rhs, cstr) is None: + if re.match(func_lhs, cstr) is None and re.match(func_rhs, cstr) is None: + inconns = list(self.tasklet.in_connectors) + if len(inconns) != 2: + continue + + # Special case: a = op b + other_inconn = inconns[0] if inconns[0] != inconn else inconns[1] + rhs2 = r'^\s*%s\s*=\s*%s\s*%s\s*%s;$' % (re.escape(outconn), re.escape(other_inconn), ops, + re.escape(inconn)) + if re.match(rhs2, cstr) is None: + continue + # Same memlet if edge.data.subset != outedge.data.subset: continue # If in map, only match if the subset is independent of any # map indices (otherwise no conflict) - if (expr_index == 1 and len(outedge.data.subset.free_symbols - & set(me.map.params)) == len(me.map.params)): - continue + if expr_index == 1: + if not permissive and len(outedge.data.subset.free_symbols & set(me.map.params)) == len( + me.map.params): + continue return True else: @@ -132,50 +149,22 @@ def apply(self, state: SDFGState, sdfg: SDFG): input: nodes.AccessNode = self.input tasklet: nodes.Tasklet = self.tasklet output: nodes.AccessNode = self.output + if self.expr_index == 1: + me = self.map_entry + mx = self.map_exit # If state fission is necessary to keep semantics, do it first - if (self.expr_index == 0 and state.in_degree(input) > 0 and state.out_degree(output) == 0): - newstate = sdfg.add_state_after(state) - newstate.add_node(tasklet) - new_input, new_output = None, None - - # Keep old edges for after we remove tasklet from the original state - in_edges = list(state.in_edges(tasklet)) - out_edges = list(state.out_edges(tasklet)) - - for e in in_edges: - r = newstate.add_read(e.src.data) - newstate.add_edge(r, e.src_conn, e.dst, e.dst_conn, e.data) - if e.src is input: - new_input = r - for e in out_edges: - w = newstate.add_write(e.dst.data) - newstate.add_edge(e.src, e.src_conn, w, e.dst_conn, e.data) - if e.dst is output: - new_output = w - - # Remove tasklet and resulting isolated nodes - state.remove_node(tasklet) - for e in in_edges: - if state.degree(e.src) == 0: - state.remove_node(e.src) - for e in out_edges: - if state.degree(e.dst) == 0: - state.remove_node(e.dst) - - # Reset state and nodes for rest of transformation - input = new_input - output = new_output - state = newstate - # End of state fission + if state.in_degree(input) > 0: + subgraph_nodes = set([e.src for e in state.bfs_edges(input, reverse=True)]) + subgraph_nodes.add(input) + + subgraph = StateSubgraphView(state, subgraph_nodes) + helpers.state_fission(sdfg, subgraph) if self.expr_index == 0: inedges = state.edges_between(input, tasklet) outedge = state.edges_between(tasklet, output)[0] else: - me = self.map_entry - mx = self.map_exit - inedges = state.edges_between(me, tasklet) outedge = state.edges_between(tasklet, mx)[0] @@ -183,6 +172,7 @@ def apply(self, state: SDFGState, sdfg: SDFG): outconn = outedge.src_conn ops = '[%s]' % ''.join(re.escape(o) for o in AugAssignToWCR._EXPRESSIONS) + funcs = '|'.join(re.escape(o) for o in AugAssignToWCR._FUNCTIONS) # Change tasklet code if tasklet.language is dtypes.Language.Python: @@ -206,13 +196,40 @@ def apply(self, state: SDFGState, sdfg: SDFG): inconn = edge.dst_conn match = re.match(r'^\s*%s\s*=\s*%s\s*(%s)(.*);$' % 
(re.escape(outconn), re.escape(inconn), ops), cstr) if match is None: - # match = re.match( - # r'^\s*%s\s*=\s*(.*)\s*(%s)\s*%s;$' % - # (re.escape(outconn), ops, re.escape(inconn)), cstr) - # if match is None: - continue - # op = match.group(2) - # expr = match.group(1) + match = re.match( + r'^\s*%s\s*=\s*\((.*)\)\s*(%s)\s*%s;$' % (re.escape(outconn), ops, re.escape(inconn)), cstr) + if match is None: + func_rhs = r'^\s*%s\s*=\s*(%s)\((.*),\s*%s\s*\)\s*;$' % (re.escape(outconn), funcs, + re.escape(inconn)) + match = re.match(func_rhs, cstr) + if match is None: + func_lhs = r'^\s*%s\s*=\s*(%s)\(\s*%s\s*,(.*)\)\s*;$' % (re.escape(outconn), funcs, + re.escape(inconn)) + match = re.match(func_lhs, cstr) + if match is None: + inconns = list(self.tasklet.in_connectors) + if len(inconns) != 2: + continue + + # Special case: a = op b + other_inconn = inconns[0] if inconns[0] != inconn else inconns[1] + rhs2 = r'^\s*%s\s*=\s*(%s)\s*(%s)\s*%s;$' % ( + re.escape(outconn), re.escape(other_inconn), ops, re.escape(inconn)) + match = re.match(rhs2, cstr) + if match is None: + continue + else: + op = match.group(2) + expr = match.group(1) + else: + op = match.group(1) + expr = match.group(2) + else: + op = match.group(1) + expr = match.group(2) + else: + op = match.group(2) + expr = match.group(1) else: op = match.group(1) expr = match.group(2) @@ -232,16 +249,14 @@ def apply(self, state: SDFGState, sdfg: SDFG): raise NotImplementedError # Change output edge - outedge.data.wcr = f'lambda a,b: a {op} b' - - if self.expr_index == 0: - # Remove input node and connector - state.remove_edge_and_connectors(inedge) - if state.degree(input) == 0: - state.remove_node(input) + if op in AugAssignToWCR._FUNCTIONS: + outedge.data.wcr = f'lambda a,b: {op}(a, b)' else: - # Remove input edge and dst connector, but not necessarily src - state.remove_memlet_path(inedge) + outedge.data.wcr = f'lambda a,b: a {op} b' + + # Remove input node and connector + state.remove_memlet_path(inedge) + propagate_memlets_state(sdfg, state) # If outedge leads to non-transient, and this is a nested SDFG, # propagate outwards @@ -252,6 +267,9 @@ def apply(self, state: SDFGState, sdfg: SDFG): sd = sd.parent_sdfg outedge = next(iter(nstate.out_edges_by_connector(nsdfg, outedge.data.data))) for outedge in nstate.memlet_path(outedge): - outedge.data.wcr = f'lambda a,b: a {op} b' + if op in AugAssignToWCR._FUNCTIONS: + outedge.data.wcr = f'lambda a,b: {op}(a, b)' + else: + outedge.data.wcr = f'lambda a,b: a {op} b' # At this point we are leading to an access node again and can # traverse further up diff --git a/tests/transformations/wcr_conversion_test.py b/tests/transformations/wcr_conversion_test.py new file mode 100644 index 0000000000..091b2a9db8 --- /dev/null +++ b/tests/transformations/wcr_conversion_test.py @@ -0,0 +1,247 @@ +import dace + +from dace.transformation.dataflow import AugAssignToWCR + + +def test_aug_assign_tasklet_lhs(): + + @dace.program + def sdfg_aug_assign_tasklet_lhs(A: dace.float64[32], B: dace.float64[32]): + for i in range(32): + with dace.tasklet: + a << A[i] + k << B[i] + b >> A[i] + b = a + k + + sdfg = sdfg_aug_assign_tasklet_lhs.to_sdfg() + sdfg.simplify() + + applied = sdfg.apply_transformations_repeated(AugAssignToWCR) + assert applied == 1 + + +def test_aug_assign_tasklet_lhs_brackets(): + + @dace.program + def sdfg_aug_assign_tasklet_lhs_brackets(A: dace.float64[32], B: dace.float64[32]): + for i in range(32): + with dace.tasklet: + a << A[i] + k << B[i] + b >> A[i] + b = a + (k + 1) + + sdfg = 
sdfg_aug_assign_tasklet_lhs_brackets.to_sdfg() + sdfg.simplify() + + applied = sdfg.apply_transformations_repeated(AugAssignToWCR) + assert applied == 1 + + +def test_aug_assign_tasklet_rhs(): + + @dace.program + def sdfg_aug_assign_tasklet_rhs(A: dace.float64[32], B: dace.float64[32]): + for i in range(32): + with dace.tasklet: + a << A[i] + k << B[i] + b >> A[i] + b = k + a + + sdfg = sdfg_aug_assign_tasklet_rhs.to_sdfg() + sdfg.simplify() + + applied = sdfg.apply_transformations_repeated(AugAssignToWCR) + assert applied == 1 + + +def test_aug_assign_tasklet_rhs_brackets(): + + @dace.program + def sdfg_aug_assign_tasklet_rhs_brackets(A: dace.float64[32], B: dace.float64[32]): + for i in range(32): + with dace.tasklet: + a << A[i] + k << B[i] + b >> A[i] + b = (k + 1) + a + + sdfg = sdfg_aug_assign_tasklet_rhs_brackets.to_sdfg() + sdfg.simplify() + + applied = sdfg.apply_transformations_repeated(AugAssignToWCR) + assert applied == 1 + + +def test_aug_assign_tasklet_lhs_cpp(): + + @dace.program + def sdfg_aug_assign_tasklet_lhs_cpp(A: dace.float64[32], B: dace.float64[32]): + for i in range(32): + with dace.tasklet(language=dace.Language.CPP): + a << A[i] + k << B[i] + b >> A[i] + """ + b = a + k; + """ + + sdfg = sdfg_aug_assign_tasklet_lhs_cpp.to_sdfg() + sdfg.simplify() + + applied = sdfg.apply_transformations_repeated(AugAssignToWCR) + assert applied == 1 + + +def test_aug_assign_tasklet_lhs_brackets_cpp(): + + @dace.program + def sdfg_aug_assign_tasklet_lhs_brackets_cpp(A: dace.float64[32], B: dace.float64[32]): + for i in range(32): + with dace.tasklet(language=dace.Language.CPP): + a << A[i] + k << B[i] + b >> A[i] + """ + b = a + (k + 1); + """ + + sdfg = sdfg_aug_assign_tasklet_lhs_brackets_cpp.to_sdfg() + sdfg.simplify() + + applied = sdfg.apply_transformations_repeated(AugAssignToWCR) + assert applied == 1 + + +def test_aug_assign_tasklet_rhs_brackets_cpp(): + + @dace.program + def sdfg_aug_assign_tasklet_rhs_brackets_cpp(A: dace.float64[32], B: dace.float64[32]): + for i in range(32): + with dace.tasklet(language=dace.Language.CPP): + a << A[i] + k << B[i] + b >> A[i] + """ + b = (k + 1) + a; + """ + + sdfg = sdfg_aug_assign_tasklet_rhs_brackets_cpp.to_sdfg() + sdfg.simplify() + + applied = sdfg.apply_transformations_repeated(AugAssignToWCR) + assert applied == 1 + + +def test_aug_assign_tasklet_func_lhs_cpp(): + + @dace.program + def sdfg_aug_assign_tasklet_func_lhs_cpp(A: dace.float64[32], B: dace.float64[32]): + for i in range(32): + with dace.tasklet(language=dace.Language.CPP): + a << A[i] + c << B[i] + b >> A[i] + """ + b = min(a, c); + """ + + sdfg = sdfg_aug_assign_tasklet_func_lhs_cpp.to_sdfg() + sdfg.simplify() + + applied = sdfg.apply_transformations_repeated(AugAssignToWCR) + assert applied == 1 + + +def test_aug_assign_tasklet_func_rhs_cpp(): + + @dace.program + def sdfg_aug_assign_tasklet_func_rhs_cpp(A: dace.float64[32], B: dace.float64[32]): + for i in range(32): + with dace.tasklet(language=dace.Language.CPP): + a << A[i] + c << B[i] + b >> A[i] + """ + b = min(c, a); + """ + + sdfg = sdfg_aug_assign_tasklet_func_rhs_cpp.to_sdfg() + sdfg.simplify() + + applied = sdfg.apply_transformations_repeated(AugAssignToWCR) + assert applied == 1 + + +def test_aug_assign_free_map(): + + @dace.program + def sdfg_aug_assign_free_map(A: dace.float64[32], B: dace.float64[32]): + for i in dace.map[0:32]: + with dace.tasklet(language=dace.Language.CPP): + a << A[0] + k << B[i] + b >> A[0] + """ + b = k * a; + """ + + sdfg = sdfg_aug_assign_free_map.to_sdfg() + sdfg.simplify() + 
+ applied = sdfg.apply_transformations_repeated(AugAssignToWCR) + assert applied == 1 + + +def test_aug_assign_state_fission_map(): + + @dace.program + def sdfg_aug_assign_state_fission(A: dace.float64[32], B: dace.float64[32]): + for i in dace.map[0:32]: + with dace.tasklet: + a << B[i] + b >> A[i] + b = a + + for i in dace.map[0:32]: + with dace.tasklet: + a << A[0] + b >> A[0] + b = a * 2 + + for i in dace.map[0:32]: + with dace.tasklet: + a << A[0] + b >> A[0] + b = a * 2 + + sdfg = sdfg_aug_assign_state_fission.to_sdfg() + sdfg.simplify() + + applied = sdfg.apply_transformations_repeated(AugAssignToWCR) + assert applied == 2 + + +def test_free_map_permissive(): + + @dace.program + def sdfg_free_map_permissive(A: dace.float64[32], B: dace.float64[32]): + for i in dace.map[0:32]: + with dace.tasklet(language=dace.Language.CPP): + a << A[i] + k << B[i] + b >> A[i] + """ + b = k * a; + """ + + sdfg = sdfg_free_map_permissive.to_sdfg() + sdfg.simplify() + + applied = sdfg.apply_transformations_repeated(AugAssignToWCR, permissive=False) + assert applied == 0 + + applied = sdfg.apply_transformations_repeated(AugAssignToWCR, permissive=True) + assert applied == 1
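
As a closing illustration, the converted graphs remain numerically correct. The sketch below is a hypothetical extra test in the style of the file above (not part of this patch, assuming DaCe's usual in-place update of array arguments); it applies the transformation and checks that the output memlet now carries a sum write-conflict resolution:

    import numpy as np

    import dace
    from dace.transformation.dataflow import AugAssignToWCR


    def test_aug_assign_numerics():
        @dace.program
        def tester(A: dace.float64[32], B: dace.float64[32]):
            for i in range(32):
                with dace.tasklet:
                    a << A[i]
                    k << B[i]
                    b >> A[i]
                    b = a + k

        sdfg = tester.to_sdfg()
        sdfg.simplify()
        assert sdfg.apply_transformations_repeated(AugAssignToWCR) == 1

        # The augmented assignment is now a write-conflict resolution
        # (lambda a,b: a + b) on the output memlet, not a read-modify-write.
        assert any(e.data.wcr is not None for state in sdfg.states() for e in state.edges())

        A = np.random.rand(32)
        B = np.random.rand(32)
        expected = A + B
        sdfg(A=A, B=B)  # A is updated in place
        assert np.allclose(A, expected)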