Skip to content

Commit

Permalink
WCRToAugAssign (#1098)
Browse files Browse the repository at this point in the history
Introduces a transformation that converts WCR to an augmented
assignment.

---------

Co-authored-by: Philipp Schaad <[email protected]>
  • Loading branch information
alexnick83 and phschaad authored Oct 3, 2024
1 parent 74a31cb commit 51871a7
Show file tree
Hide file tree
Showing 3 changed files with 131 additions and 7 deletions.
2 changes: 1 addition & 1 deletion dace/transformation/dataflow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
RedundantArrayCopying3)
from .merge_arrays import InMergeArrays, OutMergeArrays, MergeSourceSinkArrays
from .prune_connectors import PruneConnectors, PruneSymbols
from .wcr_conversion import AugAssignToWCR
from .wcr_conversion import AugAssignToWCR, WCRToAugAssign
from .tasklet_fusion import TaskletFusion
from .trivial_tasklet_elimination import TrivialTaskletElimination

Expand Down
91 changes: 85 additions & 6 deletions dace/transformation/dataflow/wcr_conversion.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved.
# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved.
""" Transformations to convert subgraphs to write-conflict resolutions. """
import ast
import copy
import re
import copy
from dace import registry, nodes, dtypes, Memlet
from dace.transformation import transformation, helpers as xfh
from dace.sdfg import graph as gr, utils as sdutil
from dace import SDFG, SDFGState
from dace.sdfg.state import StateSubgraphView
from dace import nodes, dtypes, Memlet
from dace.frontend.python import astutils
from dace.transformation import transformation
from dace.sdfg import utils as sdutil
from dace import Memlet, SDFG, SDFGState
from dace.transformation import helpers
from dace.sdfg.propagation import propagate_memlets_state

Expand Down Expand Up @@ -268,3 +269,81 @@ def apply(self, state: SDFGState, sdfg: SDFG):
outedge.data.wcr = f'lambda a,b: a {op} b'
# At this point we are leading to an access node again and can
# traverse further up


class WCRToAugAssign(transformation.SingleStateTransformation):
"""
Converts a tasklet with a write-conflict resolution to an augmented assignment subgraph (e.g., "a = a + b").
"""
tasklet = transformation.PatternNode(nodes.Tasklet)
output = transformation.PatternNode(nodes.AccessNode)
map_exit = transformation.PatternNode(nodes.MapExit)

_EXPRESSIONS = ['+', '-', '*', '^', '%'] #, '/']
_EXPR_MAP = {'-': ('+', '-({expr})'), '/': ('*', '((decltype({expr}))1)/({expr})')}
_PYOP_MAP = {ast.Add: '+', ast.Sub: '-', ast.Mult: '*', ast.BitXor: '^', ast.Mod: '%', ast.Div: '/'}

@classmethod
def expressions(cls):
return [
sdutil.node_path_graph(cls.tasklet, cls.output),
sdutil.node_path_graph(cls.tasklet, cls.map_exit, cls.output)
]

def can_be_applied(self, graph, expr_index, sdfg, permissive=False):
if expr_index == 0:
edges = graph.edges_between(self.tasklet, self.output)
else:
edges = graph.edges_between(self.tasklet, self.map_exit)
if len(edges) != 1:
return False
if edges[0].data.wcr is None:
return False

# If the access subset on the WCR edge is overapproximated (i.e., the access may be dynamic), we do not support
# swapping to an augmented assignment pattern with this transformation.
if edges[0].data.subset.num_elements() > edges[0].data.volume or edges[0].data.dynamic is True:
return False

return True

def apply(self, state: SDFGState, sdfg: SDFG):
if self.expr_index == 0:
edge = state.edges_between(self.tasklet, self.output)[0]
wcr = ast.parse(edge.data.wcr).body[0].value.body
if isinstance(wcr, ast.BinOp):
wcr.left.id = '__in1'
wcr.right.id = '__in2'
code = astutils.unparse(wcr)
else:
raise NotImplementedError
edge.data.wcr = None
in_access = state.add_access(self.output.data)
new_tasklet = state.add_tasklet('augassign', {'__in1', '__in2'}, {'__out'}, f"__out = {code}")
scal_name, scal_desc = sdfg.add_scalar('tmp', sdfg.arrays[self.output.data].dtype, transient=True,
find_new_name=True)
state.add_edge(self.tasklet, edge.src_conn, new_tasklet, '__in1', Memlet.from_array(scal_name, scal_desc))
state.add_edge(in_access, None, new_tasklet, '__in2', copy.deepcopy(edge.data))
state.add_edge(new_tasklet, '__out', self.output, edge.dst_conn, edge.data)
state.remove_edge(edge)
else:
edge = state.edges_between(self.tasklet, self.map_exit)[0]
map_entry = state.entry_node(self.map_exit)
wcr = ast.parse(edge.data.wcr).body[0].value.body
if isinstance(wcr, ast.BinOp):
wcr.left.id = '__in1'
wcr.right.id = '__in2'
code = astutils.unparse(wcr)
else:
raise NotImplementedError
for e in state.memlet_path(edge):
e.data.wcr = None
in_access = state.add_access(self.output.data)
new_tasklet = state.add_tasklet('augassign', {'__in1', '__in2'}, {'__out'}, f"__out = {code}")
scal_name, scal_desc = sdfg.add_scalar('tmp', sdfg.arrays[self.output.data].dtype, transient=True,
find_new_name=True)
state.add_edge(self.tasklet, edge.src_conn, new_tasklet, '__in1', Memlet.from_array(scal_name, scal_desc))
state.add_memlet_path(in_access, map_entry, new_tasklet, memlet=copy.deepcopy(edge.data), dst_conn='__in2')
state.add_edge(new_tasklet, '__out', self.map_exit, edge.dst_conn, edge.data)
state.remove_edge(edge)

45 changes: 45 additions & 0 deletions tests/transformations/wcr_to_augassign_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved.
""" Tests WCRToAugAssign. """

import dace
import numpy as np
from dace.transformation.dataflow import WCRToAugAssign


def test_tasklet():

@dace.program
def test():
a = np.zeros((10,))
for i in dace.map[1:9]:
a[i-1] += 1
return a

sdfg = test.to_sdfg(simplify=False)
sdfg.apply_transformations(WCRToAugAssign)

val = sdfg()
ref = test.f()
assert(np.allclose(val, ref))


def test_mapped_tasklet():

@dace.program
def test():
a = np.zeros((10,))
for i in dace.map[1:9]:
a[i-1] += 1
return a

sdfg = test.to_sdfg(simplify=True)
sdfg.apply_transformations(WCRToAugAssign)

val = sdfg()
ref = test.f()
assert(np.allclose(val, ref))


if __name__ == '__main__':
test_tasklet()
test_mapped_tasklet()

0 comments on commit 51871a7

Please sign in to comment.