Skip to content

Commit

Permalink
Analysis passes for access range analysis (#1484)
Browse files Browse the repository at this point in the history
Adds two analysis passes to help with analyzing data access sets: access
ranges and Reference sources. To enable constructing sets of memlets,
this PR also reintroduces data descriptor names to memlet hashes.
  • Loading branch information
tbennun authored Dec 18, 2023
1 parent 09d37e9 commit bf56e4d
Show file tree
Hide file tree
Showing 4 changed files with 163 additions and 10 deletions.
11 changes: 4 additions & 7 deletions dace/memlet.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,8 +169,7 @@ def to_json(self):
attrs['is_data_src'] = self._is_data_src

# Fill in legacy (DEPRECATED) values for backwards compatibility
attrs['num_accesses'] = \
str(self.volume) if not self.dynamic else -1
attrs['num_accesses'] = str(self.volume) if not self.dynamic else -1

return {"type": "Memlet", "attributes": attrs}

Expand Down Expand Up @@ -421,13 +420,11 @@ def from_array(dataname, datadesc, wcr=None):
return Memlet.simple(dataname, rng, wcr_str=wcr)

def __hash__(self):
return hash((self.volume, self.src_subset, self.dst_subset, str(self.wcr)))
return hash((self.data, self.volume, self.src_subset, self.dst_subset, str(self.wcr)))

def __eq__(self, other):
return all([
self.volume == other.volume, self.src_subset == other.src_subset, self.dst_subset == other.dst_subset,
self.wcr == other.wcr
])
return all((self.data == other.data, self.volume == other.volume, self.src_subset == other.src_subset,
self.dst_subset == other.dst_subset, self.wcr == other.wcr))

def replace(self, repl_dict):
"""
Expand Down
80 changes: 79 additions & 1 deletion dace/transformation/passes/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from collections import defaultdict
from dace.transformation import pass_pipeline as ppl
from dace import SDFG, SDFGState, properties, InterstateEdge
from dace import SDFG, SDFGState, properties, InterstateEdge, Memlet, data as dt
from dace.sdfg.graph import Edge
from dace.sdfg import nodes as nd
from dace.sdfg.analysis import cfg
Expand Down Expand Up @@ -505,3 +505,81 @@ def apply_pass(self, top_sdfg: SDFG, pipeline_results: Dict[str, Any]) -> Dict[i
del result[desc][write]
top_result[sdfg.sdfg_id] = result
return top_result


@properties.make_properties
class AccessRanges(ppl.Pass):
    """
    Analysis pass that gathers, for every data descriptor, the set of memlets through which
    it is accessed (i.e., its read and write ranges).
    """

    CATEGORY: str = 'Analysis'

    def modifies(self) -> ppl.Modifies:
        # Pure analysis: the pass reports that it changes nothing
        return ppl.Modifies.Nothing

    def should_reapply(self, modified: ppl.Modifies) -> bool:
        # The results are built from memlets, so they are stale once memlets were modified
        return modified & ppl.Modifies.Memlets

    def apply_pass(self, top_sdfg: SDFG, _) -> Dict[int, Dict[str, Set[Memlet]]]:
        """
        Collects the access memlets of every data container in ``top_sdfg`` and all SDFGs
        yielded by ``all_sdfgs_recursive``.

        :param top_sdfg: The SDFG to analyze.
        :param _: Unused second argument (the results of previous passes in a pipeline).
        :return: A dictionary that maps each SDFG ID to a dictionary mapping data descriptor
                 names to the set of memlets found for them.
        """
        ranges_per_sdfg: Dict[int, Dict[str, Set[Memlet]]] = {}

        for sdfg in top_sdfg.all_sdfgs_recursive():
            accesses: Dict[str, Set[Memlet]] = defaultdict(set)
            for state in sdfg.states():
                for anode in state.data_nodes():
                    for edge in state.all_edges(anode):
                        # Edges entering the ``set`` connector set a reference rather than access data
                        sets_reference = edge.dst is anode and edge.dst_conn == 'set'
                        # Reference sets and empty memlets do not contribute an access range
                        if sets_reference or edge.data.is_empty():
                            continue
                        # Record the memlet at the root of the memlet tree (hopefully propagated)
                        root_edge = state.memlet_tree(edge).root().edge
                        accesses[anode.data].add(root_edge.data)
            ranges_per_sdfg[sdfg.sdfg_id] = accesses

        return ranges_per_sdfg


@properties.make_properties
class FindReferenceSources(ppl.Pass):
    """
    For each Reference data descriptor, finds all sources used to set it. A source is the memlet
    on an edge entering the reference's ``set`` connector or, if a code node (such as a Tasklet)
    was used to set the reference, that code node itself.
    """

    CATEGORY: str = 'Analysis'

    def modifies(self) -> ppl.Modifies:
        # Pure analysis: the pass reports that it changes nothing
        return ppl.Modifies.Nothing

    def should_reapply(self, modified: ppl.Modifies) -> bool:
        # The results are built from memlets, so they are stale once memlets were modified
        return modified & ppl.Modifies.Memlets

    def apply_pass(self, top_sdfg: SDFG, _) -> Dict[int, Dict[str, Set[Union[Memlet, nd.CodeNode]]]]:
        """
        Collects the sources of every Reference in ``top_sdfg`` and all SDFGs yielded by
        ``all_sdfgs_recursive``.

        :param top_sdfg: The SDFG to analyze.
        :param _: Unused second argument (the results of previous passes in a pipeline).
        :return: A dictionary that maps each SDFG ID to a dictionary mapping Reference data
                 descriptor names to the set of their sources (memlets or code nodes).
        """
        top_result: Dict[int, Dict[str, Set[Union[Memlet, nd.CodeNode]]]] = dict()

        for sdfg in top_sdfg.all_sdfgs_recursive():
            # Sources can be memlets as well as code nodes, matching the declared return type
            result: Dict[str, Set[Union[Memlet, nd.CodeNode]]] = defaultdict(set)
            reference_descs = {k for k, v in sdfg.arrays.items() if isinstance(v, dt.Reference)}
            if not reference_descs:
                # Without Reference descriptors no access node can match; skip the state traversal
                top_result[sdfg.sdfg_id] = result
                continue

            for state in sdfg.states():
                for anode in state.data_nodes():
                    if anode.data not in reference_descs:
                        continue
                    for e in state.in_edges(anode):
                        # Only edges entering the ``set`` connector set the reference
                        if e.dst_conn != 'set':
                            continue
                        # Follow the memlet path back to the node it originates from
                        true_src = state.memlet_path(e)[0].src
                        if isinstance(true_src, nd.CodeNode):
                            # Code -> Reference
                            result[anode.data].add(true_src)
                        else:
                            # Array -> Reference
                            result[anode.data].add(e.data)
            top_result[sdfg.sdfg_id] = result
        return top_result
61 changes: 61 additions & 0 deletions tests/passes/access_ranges_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved.
""" Tests the AccessRanges analysis pass. """
import dace
from dace.transformation.passes.analysis import AccessRanges
import numpy as np

# Symbol used as the size of both dimensions of array ``A`` in the test programs below
N = dace.symbol('N')


def test_simple():
    """ Checks that a single map write to ``A`` yields exactly one write memlet over the map range. """

    @dace.program
    def tester(A: dace.float64[N, N], B: dace.float64[20, 20]):
        for i, j in dace.map[0:20, 0:N]:
            A[i, j] = 1

    sdfg = tester.to_sdfg(simplify=True)
    ranges_per_sdfg = AccessRanges().apply_pass(sdfg, {})
    assert len(ranges_per_sdfg) == 1  # Only one SDFG
    array_ranges = ranges_per_sdfg[0]
    assert len(array_ranges) == 1  # Only one array is accessed

    # Expected write memlet: the memlet's subset refers to the destination
    expected_write = dace.Memlet('A[0:20, 0:N]')
    expected_write._is_data_src = False

    assert array_ranges['A'] == {expected_write}


def test_simple_ranges():
    """ Checks the memlets collected for ``A`` and ``B`` when ``A`` is accessed through several ranges. """

    @dace.program
    def tester(A: dace.float64[N, N], B: dace.float64[20, 20]):
        A[:, :] = 0
        A[1:21, 1:21] = B
        A[0, 0] += 1

    def write_memlet(expr):
        # Builds a memlet from ``expr`` and marks its data as not being the source
        result = dace.Memlet(expr)
        result._is_data_src = False
        return result

    sdfg = tester.to_sdfg(simplify=True)
    ranges_per_sdfg = AccessRanges().apply_pass(sdfg, {})
    assert len(ranges_per_sdfg) == 1  # Only one SDFG
    array_ranges = ranges_per_sdfg[0]
    assert len(array_ranges) == 2  # Two arrays are accessed

    b_ranges = array_ranges['B']
    assert len(b_ranges) == 1
    expected_b_subset = dace.subsets.Range([(0, 19, 1), (0, 19, 1)])
    assert next(iter(b_ranges)).src_subset == expected_b_subset

    # Expected read/write memlets of A
    expected_a_ranges = {
        write_memlet('A[0:N, 0:N]'),
        write_memlet('A[1:21, 1:21] -> 0:20, 0:20'),
        dace.Memlet('A[0, 0]'),
        write_memlet('A[0, 0]'),
    }

    assert array_ranges['A'] == expected_a_ranges


# Run the tests directly when this file is executed as a script
if __name__ == '__main__':
    test_simple()
    test_simple_ranges()
21 changes: 19 additions & 2 deletions tests/sdfg/reference_test.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
# Copyright 2019-2022 ETH Zurich and the DaCe authors. All rights reserved.
# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved.
""" Tests the use of Reference data descriptors. """
import dace
from dace.transformation.passes.analysis import FindReferenceSources
import numpy as np


def test_reference_branch():
def _create_branch_sdfg():
sdfg = dace.SDFG('refbranch')
sdfg.add_array('A', [20], dace.float64)
sdfg.add_array('B', [20], dace.float64)
Expand All @@ -29,6 +30,11 @@ def test_reference_branch():
r = finish.add_read('ref')
w = finish.add_write('out')
finish.add_nedge(r, w, dace.Memlet('ref'))
return sdfg


def test_reference_branch():
sdfg = _create_branch_sdfg()

A = np.random.rand(20)
B = np.random.rand(20)
Expand All @@ -41,5 +47,16 @@ def test_reference_branch():
assert np.allclose(out, A)


def test_reference_sources_pass():
    """ Checks that FindReferenceSources reports both memlets that may set ``ref`` in the branch SDFG. """
    sdfg = _create_branch_sdfg()
    sources_per_sdfg = FindReferenceSources().apply_pass(sdfg, {})
    assert len(sources_per_sdfg) == 1  # There is only one SDFG

    reference_sources = sources_per_sdfg[0]
    # There is one reference
    assert len(reference_sources) == 1
    assert 'ref' in reference_sources

    expected_sources = {dace.Memlet('A[0:20]', volume=1), dace.Memlet('B[0:20]', volume=1)}
    assert reference_sources['ref'] == expected_sources


# Run the tests directly when this file is executed as a script
if __name__ == '__main__':
    test_reference_branch()
    test_reference_sources_pass()

0 comments on commit bf56e4d

Please sign in to comment.