Skip to content

Commit

Permalink
Updated and fixed the MapExpansion transformation. (#1532)
Browse files Browse the repository at this point in the history
Before the transformation ignored the tiling parameter of the range,
this was fixed. The transformation is now also able to limit the
expansion. Thus it is possible to only expand the first k dimensions,
the remaining dimensions will be kept in a multi dimensional map.

This is a feature that will be need in GT4Py.
  • Loading branch information
philip-paul-mueller authored Feb 27, 2024
1 parent 608aa80 commit ba1587e
Show file tree
Hide file tree
Showing 2 changed files with 126 additions and 14 deletions.
87 changes: 74 additions & 13 deletions dace/transformation/dataflow/map_expansion.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import copy
import dace
from dace import dtypes, subsets, symbolic
from dace.properties import EnumProperty, make_properties
from dace.properties import EnumProperty, make_properties, Property
from dace.sdfg import nodes
from dace.sdfg import utils as sdutil
from dace.sdfg.graph import OrderedMultiDiConnectorGraph
Expand All @@ -18,8 +18,9 @@
class MapExpansion(pm.SingleStateTransformation):
""" Implements the map-expansion pattern.
Map-expansion takes an N-dimensional map and expands it to N
unidimensional maps.
Map-expansion takes an N-dimensional map and expands it.
It will generate the k nested unidimensional map and a (N-k)-dimensional inner most map.
If k is not specified all maps are expanded.
New edges abide by the following rules:
1. If there are no edges coming from the outside, use empty memlets
Expand All @@ -33,6 +34,11 @@ class MapExpansion(pm.SingleStateTransformation):
dtype=dtypes.ScheduleType,
default=dtypes.ScheduleType.Sequential,
allow_none=True)
expansion_limit = Property(desc="How many unidimensional maps will be creaed, known as k. "
"If None, the default no limit is in place.",
dtype=int,
allow_none=True,
default=None)

@classmethod
def expressions(cls):
Expand All @@ -43,22 +49,77 @@ def can_be_applied(self, graph: dace.SDFGState, expr_index: int, sdfg: dace.SDFG
# includes an N-dimensional map, with N greater than one.
return self.map_entry.map.get_param_num() > 1

def generate_new_maps(self,
current_map: nodes.Map):
if self.expansion_limit is None:
full_expand = True
elif isinstance(self.expansion_limit, int):
full_expand = False
if self.expansion_limit <= 0: # These are invalid, so we make a full expansion
full_expand = True
elif (self.map_entry.map.get_param_num() - self.expansion_limit) <= 1:
full_expand = True
else:
raise TypeError(f"Does not know how to handle type {type(self.expansion_limit).__name__}")

inner_schedule = self.inner_schedule or current_map.schedule
if full_expand:
new_maps = [
nodes.Map(
current_map.label + '_' + str(param), [param],
subsets.Range([param_range]),
schedule=inner_schedule if dim != 0 else current_map.schedule)
for dim, param, param_range in zip(range(len(current_map.params)), current_map.params, current_map.range)
]
for i, new_map in enumerate(new_maps):
new_map.range.tile_sizes[0] = current_map.range.tile_sizes[i]

else:
k = self.expansion_limit
new_maps: list[nodes.Map] = []

# Unidimensional maps
for dim in range(0, k):
dim_param = current_map.params[dim]
dim_range = current_map.range.ranges[dim]
dim_tile = current_map.range.tile_sizes[dim]
new_maps.append(
nodes.Map(
current_map.label + '_' + str(dim_param),
[dim_param],
subsets.Range([dim_range]),
schedule=inner_schedule if dim != 0 else current_map.schedule ))
new_maps[-1].range.tile_sizes[0] = dim_tile

# Multidimensional maps
mdim_params = current_map.params[k:]
mdim_ranges = current_map.range.ranges[k:]
mdim_tiles = current_map.range.tile_sizes[k:]
new_maps.append(
nodes.Map(
current_map.label, # The original name
mdim_params,
mdim_ranges,
schedule=inner_schedule ))
new_maps[-1].range.tile_sizes = mdim_tiles
return new_maps

def apply(self, graph: dace.SDFGState, sdfg: dace.SDFG):
# Extract the map and its entry and exit nodes.
map_entry = self.map_entry
map_exit = graph.exit_node(map_entry)
current_map = map_entry.map

# Create new maps
inner_schedule = self.inner_schedule or current_map.schedule
new_maps = [
nodes.Map(current_map.label + '_' + str(param), [param],
subsets.Range([param_range]),
schedule=inner_schedule)
for param, param_range in zip(current_map.params[1:], current_map.range[1:])
]
current_map.params = [current_map.params[0]]
current_map.range = subsets.Range([current_map.range[0]])
# Generate the new maps that we should use.
new_maps = self.generate_new_maps(current_map)

if not new_maps: # No changes should be made -> noops
return

# Reuse the map that is already existing for the first one.
current_map.params = new_maps[0].params
current_map.range = new_maps[0].range
new_maps.pop(0)

# Create new map entries and exits
entries = [nodes.MapEntry(new_map) for new_map in new_maps]
Expand Down
53 changes: 52 additions & 1 deletion tests/transformations/map_expansion_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def toexpand(B: dace.float64[4, 4]):
continue

# (Fast) MapExpansion should not add memlet paths for each memlet to a tasklet
if sdfg.start_state.entry_node(node) is None:
if state.entry_node(node) is None:
assert state.in_degree(node) == 0
assert state.out_degree(node) == 1
assert len(node.out_connectors) == 0
Expand Down Expand Up @@ -113,7 +113,58 @@ def mymap(i: _[0:20], j: _[rng[0]:rng[1]], k: _[0:5]):
print('Difference:', diff2)
assert (diff <= 1e-5) and (diff2 <= 1e-5)

def test_expand_with_limits():
@dace.program
def expansion(A: dace.float32[20, 30, 5]):
@dace.map
def mymap(i: _[0:20], j: _[0:30], k: _[0:5]):
a << A[i, j, k]
b >> A[i, j, k]
b = a * 2

A = np.random.rand(20, 30, 5).astype(np.float32)
expected = A.copy()
expected *= 2

sdfg = expansion.to_sdfg()
sdfg.simplify()
sdfg(A=A)
diff = np.linalg.norm(A - expected)
print('Difference (before transformation):', diff)

sdfg.apply_transformations(MapExpansion, options=dict(expansion_limit=1))

map_entries = set()
state = sdfg.start_state
for node in state.nodes():
if not isinstance(node, dace.nodes.MapEntry):
continue

if state.entry_node(node) is None:
assert state.in_degree(node) == 1
assert state.out_degree(node) == 1
assert len(node.out_connectors) == 1
assert len(node.map.range.ranges) == 1
assert node.map.range.ranges[0][1] - node.map.range.ranges[0][0] + 1 == 20
else:
assert state.in_degree(node) == 1
assert state.out_degree(node) == 1
assert len(node.out_connectors) == 1
assert len(node.map.range.ranges) == 2
assert list(map(lambda x: x[1] - x[0] + 1, node.map.range.ranges)) == [30, 5]

map_entries.add(node)

sdfg(A=A)
expected *= 2
diff2 = np.linalg.norm(A - expected)
print('Difference:', diff2)
assert (diff <= 1e-5) and (diff2 <= 1e-5)
assert len(map_entries) == 2


if __name__ == '__main__':
test_expand_with_inputs()
test_expand_without_inputs()
test_expand_without_dynamic_inputs()
test_expand_with_limits()

0 comments on commit ba1587e

Please sign in to comment.