Skip to content

Commit

Permalink
Maps With Zero Parameters (#1649)
Browse files Browse the repository at this point in the history
Before a map without any parameter was considered not invalid, it would
pass validation, but most likly compilation would fail (except it is a
serial map).
This PR adds:
- Disallows such maps.
- Fixes a small bug in the constructor of the `Map` object.
- It updates `TrivialMapElimination` such that it correctly handles the
case if it has dynamic map ranges.
- It removes the `TrivialMapRangeElimination` transformation as it is
redundant and contained a bug.

---------

Co-authored-by: Tal Ben-Nun <[email protected]>
  • Loading branch information
philip-paul-mueller and tbennun authored Sep 15, 2024
1 parent 95c65be commit d31dd7b
Show file tree
Hide file tree
Showing 6 changed files with 142 additions and 147 deletions.
9 changes: 7 additions & 2 deletions dace/sdfg/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -932,7 +932,7 @@ def __init__(self,
self.label = label
self.schedule = schedule
self.unroll = unroll
self.collapse = 1
self.collapse = collapse
self.params = params
self.range = ndrange
self.debuginfo = debuginfo
Expand All @@ -948,7 +948,12 @@ def __repr__(self):

def validate(self, sdfg, state, node):
if not dtypes.validate_name(self.label):
raise NameError('Invalid map name "%s"' % self.label)
raise NameError(f'Invalid map name "{self.label}"')
if self.get_param_num() == 0:
raise ValueError('There must be at least one parameter in a map.')
if self.get_param_num() != self.range.dims():
raise ValueError(f'There are {self.get_param_num()} parameters but the range'
f' has {self.range.dims()} dimensions.')

def get_param_num(self):
""" Returns the number of map dimension parameters/symbols. """
Expand Down
1 change: 0 additions & 1 deletion dace/transformation/dataflow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from .map_fission import MapFission
from .map_unroll import MapUnroll
from .trivial_map_elimination import TrivialMapElimination
from .trivial_map_range_elimination import TrivialMapRangeElimination
from .otf_map_fusion import OTFMapFusion

# Data movement
Expand Down
106 changes: 69 additions & 37 deletions dace/transformation/dataflow/trivial_map_elimination.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved.
""" Contains classes that implement the trivial-map-elimination transformation. """

import dace
from dace.sdfg import nodes
from dace.sdfg import utils as sdutil
from dace.transformation import transformation
Expand All @@ -10,12 +11,17 @@

@make_properties
class TrivialMapElimination(transformation.SingleStateTransformation):
""" Implements the Trivial-Map Elimination pattern.
"""Implements the Trivial-Map Elimination pattern.
Trivial-Map Elimination removes all dimensions containing only one
element from a map. If this applies to all ranges the map is removed.
Example: Map[i=0:I,j=7] -> Map[i=0:I]
Example: Map[i=0 ,j=7] -> nothing
Trivial-Map Elimination removes all dimensions containing only one
element from a map. If this applies to all ranges the map is removed.
Example: Map[i=0:I,j=7] -> Map[i=0:I]
Example: Map[i=0 ,j=7] -> nothing
There are some special cases:
- GPU maps are ignored as they are syntactically needed.
- If all map ranges are trivial and the map has dynamic map ranges,
the map is not removed, and one map parameter is retained.
"""

map_entry = transformation.PatternNode(nodes.MapEntry)
Expand All @@ -26,52 +32,78 @@ def expressions(cls):

def can_be_applied(self, graph, expr_index, sdfg, permissive=False):
map_entry = self.map_entry
return any(r[0] == r[1] for r in map_entry.map.range)

if map_entry.map.schedule in (dace.dtypes.GPU_SCHEDULES + [dace.ScheduleType.GPU_Default]):
return False
if not any(r[0] == r[1] for r in map_entry.map.range):
return False
if (map_entry.map.get_param_num()) == 1 and (
any(not e.dst_conn.startswith("IN_") for e in graph.in_edges(map_entry) if not e.data.is_empty())
):
# There is only one map parameter and there are dynamic map ranges, this can not be resolved.
return False
return True

def apply(self, graph, sdfg):
map_entry = self.map_entry
map_exit = graph.exit_node(map_entry)

remaining_ranges = []
remaining_params = []
scope = graph.scope_subgraph(map_entry)
for map_param, ranges in zip(map_entry.map.params, map_entry.map.range.ranges):
map_from, map_to, _ = ranges
if map_from == map_to:
# Replace the map index variable with the value it obtained
scope = graph.scope_subgraph(map_entry)
scope.replace(map_param, map_from)
else:
remaining_ranges.append(ranges)
remaining_params.append(map_param)

map_entry.map.range.ranges = remaining_ranges
map_entry.map.range = remaining_ranges
map_entry.map.params = remaining_params

if len(remaining_ranges) == 0:
# Redirect map entry's out edges
write_only_map = True
for edge in graph.out_edges(map_entry):
path = graph.memlet_path(edge)
index = path.index(edge)

if not edge.data.is_empty():
# Add an edge directly from the previous source connector to the destination
graph.add_edge(path[index - 1].src, path[index - 1].src_conn, edge.dst, edge.dst_conn, edge.data)
write_only_map = False

# Redirect map exit's in edges.
for edge in graph.in_edges(map_exit):
path = graph.memlet_path(edge)
index = path.index(edge)

# Add an edge directly from the source to the next destination connector
if len(path) > index + 1:
graph.add_edge(edge.src, edge.src_conn, path[index + 1].dst, path[index + 1].dst_conn, edge.data)
if write_only_map:
outer_exit = path[index+1].dst
outer_entry = graph.entry_node(outer_exit)
if outer_entry is not None:
graph.add_edge(outer_entry, None, edge.src, None, Memlet())

# Remove map
graph.remove_nodes_from([map_entry, map_exit])
if len(remaining_params) != 0:
# There are still some dimensions left, so no need to remove the map
pass

elif any(not e.dst_conn.startswith("IN_") for e in graph.in_edges(map_entry) if not e.data.is_empty()):
# The map has dynamic map ranges, thus we can not remove the map.
# Instead we add one dimension back to keep the SDFG valid.
map_entry.map.params = [map_param]
map_entry.map.range = [ranges]

else:
# The map is empty and there are no dynamic map ranges.
self.remove_empty_map(graph, sdfg)

def remove_empty_map(self, graph, sdfg):
map_entry = self.map_entry
map_exit = graph.exit_node(map_entry)

# Redirect map entry's out edges
write_only_map = True
for edge in graph.out_edges(map_entry):
if edge.data.is_empty():
continue
# Add an edge directly from the previous source connector to the destination
path = graph.memlet_path(edge)
index = path.index(edge)
graph.add_edge(path[index - 1].src, path[index - 1].src_conn, edge.dst, edge.dst_conn, edge.data)
write_only_map = False

# Redirect map exit's in edges.
for edge in graph.in_edges(map_exit):
path = graph.memlet_path(edge)
index = path.index(edge)

# Add an edge directly from the source to the next destination connector
if len(path) > index + 1:
graph.add_edge(edge.src, edge.src_conn, path[index + 1].dst, path[index + 1].dst_conn, edge.data)
if write_only_map:
outer_exit = path[index+1].dst
outer_entry = graph.entry_node(outer_exit)
if outer_entry is not None:
graph.add_edge(outer_entry, None, edge.src, None, Memlet())

# Remove map
graph.remove_nodes_from([map_entry, map_exit])
48 changes: 0 additions & 48 deletions dace/transformation/dataflow/trivial_map_range_elimination.py

This file was deleted.

67 changes: 66 additions & 1 deletion tests/trivial_map_elimination_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,37 @@ def trivial_map_init_sdfg():
return sdfg


def trivial_map_with_dynamic_map_range_sdfg():
sdfg = dace.SDFG("trivial_map_with_dynamic_map_range")
state = sdfg.add_state("state1", is_start_block=True)

for name in "ABC":
sdfg.add_scalar(name, dtype=dace.float32, transient=False)
A, B, C = (state.add_access(name) for name in "ABC")

_, me, _ = state.add_mapped_tasklet(
name="MAP",
map_ranges=[("__i", "0:1"), ("__j", "10:11")],
inputs={"__in": dace.Memlet("A[0]")},
input_nodes={"A": A},
code="__out = __in + 1",
outputs={"__out": dace.Memlet("B[0]")},
output_nodes={"B": B},
external_edges=True,
)
state.add_edge(
C,
None,
me,
"dynamic_variable",
dace.Memlet("C[0]"),
)
me.add_in_connector("dynamic_variable")
sdfg.validate()

return sdfg


def trivial_map_pseudo_init_sdfg():
sdfg = dace.SDFG('trivial_map_range_expanded')
sdfg.add_array('A', [5, 1], dace.float64)
Expand Down Expand Up @@ -160,7 +191,6 @@ def test_can_be_applied(self):

count = graph.apply_transformations(TrivialMapElimination, validate=False, validate_all=False)
graph.validate()
#graph.view()

self.assertGreater(count, 0)

Expand Down Expand Up @@ -188,5 +218,40 @@ def test_reconnects_edges(self):
self.assertEqual(len(state.out_edges(map_entries[0])), 1)


class TrivialMapEliminationWithDynamicMapRangesTest(unittest.TestCase):
"""
Tests the case where the map has trivial ranges and dynamic map ranges.
"""

def test_can_be_applied(self):
graph = trivial_map_with_dynamic_map_range_sdfg()

count = graph.apply_transformations(TrivialMapElimination)
graph.validate()

self.assertEqual(count, 1)


def test_removes_map(self):
graph = trivial_map_with_dynamic_map_range_sdfg()

graph.apply_transformations(TrivialMapElimination)

state = graph.nodes()[0]
map_entries = [n for n in state.nodes() if isinstance(n, dace.sdfg.nodes.MapEntry)]
self.assertEqual(len(map_entries), 1)
self.assertEqual(state.in_degree(map_entries[0]), 2)
self.assertTrue(any(e.dst_conn.startswith("IN_") for e in state.in_edges(map_entries[0])))
self.assertTrue(any(not e.dst_conn.startswith("IN_") for e in state.in_edges(map_entries[0])))

def test_not_remove_dynamic_map_range(self):
graph = trivial_map_with_dynamic_map_range_sdfg()

count1 = graph.apply_transformations(TrivialMapElimination)
self.assertEqual(count1, 1)

count2 = graph.apply_transformations(TrivialMapElimination)
self.assertEqual(count2, 0)

if __name__ == '__main__':
unittest.main()
58 changes: 0 additions & 58 deletions tests/trivial_map_range_elimination_test.py

This file was deleted.

0 comments on commit d31dd7b

Please sign in to comment.