class WCRToAugAssign(transformation.SingleStateTransformation):
    """
    Converts a tasklet with a write-conflict resolution to an augmented assignment subgraph (e.g., "a = a + b").

    Two patterns are matched: a tasklet writing directly to an access node, and a tasklet writing to an access
    node through a map exit. In both cases the WCR is removed from the write and a new ``augassign`` tasklet is
    inserted that reads the current value of the output container, combines it with the value produced by the
    matched tasklet, and writes the result back.

    Only WCR functions that are two-parameter lambdas whose body is a binary operation are supported.
    """
    tasklet = transformation.PatternNode(nodes.Tasklet)
    output = transformation.PatternNode(nodes.AccessNode)
    map_exit = transformation.PatternNode(nodes.MapExit)

    # NOTE(review): none of the methods of this class read the three tables below; they are kept so that the
    # class attributes stay unchanged for any external user. Confirm whether they can be dropped.
    _EXPRESSIONS = ['+', '-', '*', '^', '%']  #, '/']
    _EXPR_MAP = {'-': ('+', '-({expr})'), '/': ('*', '((decltype({expr}))1)/({expr})')}
    _PYOP_MAP = {ast.Add: '+', ast.Sub: '-', ast.Mult: '*', ast.BitXor: '^', ast.Mod: '%', ast.Div: '/'}

    @classmethod
    def expressions(cls):
        """
        Returns the two pattern graphs: ``tasklet -> output`` (index 0) and
        ``tasklet -> map_exit -> output`` (index 1).
        """
        return [
            sdutil.node_path_graph(cls.tasklet, cls.output),
            sdutil.node_path_graph(cls.tasklet, cls.map_exit, cls.output)
        ]

    @staticmethod
    def _augassign_code(wcr):
        """
        Translates a WCR lambda into the right-hand side expression of the augmented-assignment tasklet.

        The WCR lambda takes the current value first and the new value second (see the ``wcr`` parameter of the
        DaCe ``Memlet`` documentation). In the generated expression, the new value (produced by the matched
        tasklet) is called ``__in1`` and the current value (read from the output container) is called ``__in2``,
        which are the connectors ``apply`` wires them to. Renaming is driven by the lambda's parameter names, so
        operand order is preserved for non-commutative operators and nested operands are handled.

        :param wcr: The WCR lambda as a string.
        :return: The expression as a string, or None if the WCR is not a two-parameter lambda whose body is a
                 binary operation.
        """
        try:
            tree = ast.parse(wcr)
        except (SyntaxError, TypeError, ValueError):
            return None
        if len(tree.body) != 1 or not isinstance(tree.body[0], ast.Expr):
            return None
        func = tree.body[0].value
        if not isinstance(func, ast.Lambda) or not isinstance(func.body, ast.BinOp):
            return None
        params = func.args
        # ``posonlyargs`` does not exist on older Python versions, hence the getattr.
        if (len(params.args) != 2 or getattr(params, 'posonlyargs', None) or params.vararg or params.kwonlyargs
                or params.kwarg):
            return None

        current_name, new_name = (param.arg for param in params.args)
        renames = {current_name: '__in2', new_name: '__in1'}
        # Every node is visited exactly once, so a lambda whose parameters are already called
        # ``__in1``/``__in2`` is renamed (swapped) correctly.
        for node in ast.walk(func.body):
            if isinstance(node, ast.Name) and node.id in renames:
                node.id = renames[node.id]
        return astutils.unparse(func.body)

    def can_be_applied(self, graph, expr_index, sdfg, permissive=False):
        """
        Checks that exactly one edge leaves the tasklet towards the matched sink, that it carries a supported
        WCR, and that its access is neither over-approximated nor dynamic.
        """
        sink = self.output if expr_index == 0 else self.map_exit
        edges = graph.edges_between(self.tasklet, sink)
        if len(edges) != 1:
            return False
        memlet = edges[0].data
        if memlet.wcr is None:
            return False

        # Reject WCR functions that ``apply`` cannot translate, instead of failing after the match.
        if self._augassign_code(memlet.wcr) is None:
            return False

        # If the access subset on the WCR edge is overapproximated (i.e., the access may be dynamic), we do not
        # support swapping to an augmented assignment pattern with this transformation.
        if memlet.subset.num_elements() > memlet.volume or memlet.dynamic is True:
            return False

        return True

    def apply(self, state: SDFGState, sdfg: SDFG):
        """
        Replaces the WCR write with an explicit read-modify-write through a new ``augassign`` tasklet.

        :param state: The state containing the matched subgraph.
        :param sdfg: The SDFG containing the state; a transient scalar is added to it.
        :raises NotImplementedError: If the WCR is not a two-parameter lambda with a binary-operation body.
        """
        through_map = self.expr_index != 0
        sink = self.map_exit if through_map else self.output
        edge = state.edges_between(self.tasklet, sink)[0]

        # Translate before mutating anything, so an unsupported WCR leaves the graph untouched.
        code = self._augassign_code(edge.data.wcr)
        if code is None:
            raise NotImplementedError

        if through_map:
            map_entry = state.entry_node(self.map_exit)
            # The WCR is set on every edge of the path through the map exit; clear all of them.
            for path_edge in state.memlet_path(edge):
                path_edge.data.wcr = None
        else:
            edge.data.wcr = None

        in_access = state.add_access(self.output.data)
        new_tasklet = state.add_tasklet('augassign', {'__in1', '__in2'}, {'__out'}, f"__out = {code}")
        scal_name, scal_desc = sdfg.add_scalar('tmp',
                                               sdfg.arrays[self.output.data].dtype,
                                               transient=True,
                                               find_new_name=True)

        # New value: forwarded from the matched tasklet through the transient scalar.
        state.add_edge(self.tasklet, edge.src_conn, new_tasklet, '__in1', Memlet.from_array(scal_name, scal_desc))

        # Current value: read from the output container with the same (now WCR-free) memlet as the write.
        if through_map:
            state.add_memlet_path(in_access, map_entry, new_tasklet, memlet=copy.deepcopy(edge.data),
                                  dst_conn='__in2')
        else:
            state.add_edge(in_access, None, new_tasklet, '__in2', copy.deepcopy(edge.data))

        # The new tasklet takes over the original write.
        state.add_edge(new_tasklet, '__out', sink, edge.dst_conn, edge.data)
        state.remove_edge(edge)