
Commit fd2366f

Applied Edoardo's comments.
philip-paul-mueller committed Jul 31, 2024
1 parent 03f4b1a commit fd2366f
Showing 5 changed files with 65 additions and 61 deletions.
@@ -190,6 +190,7 @@ def gt_auto_optimize(
Something along the lines of "Fuse if operational intensity goes up, but
not if we have too much internal space (register pressure)."
- Create a custom array elimination pass that honors rule 1.
+ - Check if a pipeline could be used to speed up some computations.
"""
device = dace.DeviceType.GPU if gpu else dace.DeviceType.CPU

@@ -33,7 +33,7 @@ class KBlocking(transformation.SingleStateTransformation):
dimension, that is commonly called "k", but identified with `block_dim`.
All dimensions except `k` are unaffected by this transformation. In the outer
- Map the will replace the `k` range, currently `k = 0:N`, with
+ Map the transformation will replace the `k` range, currently `k = 0:N`, with
`__coarse_k = 0:N:B`, where `N` is the original size of the range and `B`
is the block size, passed as `blocking_size`. The transformation also handles the
case if `N % B != 0`.
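To make this concrete, here is a minimal pure-Python sketch (not DaCe code; `N` and `B` are assumed example values) of the loop structure the transformation creates:

```python
# Sketch of the KBlocking loop structure: the outer Map runs over
# `__coarse_k = 0:N:B`, an inner loop recovers the original `k`,
# and `min` handles the tail iteration when `N % B != 0`.
N = 10  # assumed original range size
B = 4   # assumed blocking size (`blocking_size`); N % B != 0 here

visited = []
for coarse_k in range(0, N, B):                      # outer Map: `__coarse_k = 0:N:B`
    for k in range(coarse_k, min(coarse_k + B, N)):  # reconstructed `k`
        visited.append(k)

assert visited == list(range(N))  # every original `k` is visited exactly once
```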
@@ -231,7 +231,7 @@ def apply(
# of the node in one go.
relocated_nodes.add(edge_dst)

- # In order to be useful we have to temporary store the data the
+ # In order to be useful we have to temporarily store the data the
# independent node generates
assert graph.out_degree(edge_dst) == 1 # TODO(phimuell): Lift
if isinstance(edge_dst, nodes.AccessNode):
@@ -36,7 +36,7 @@ class BaseMapPromoter(transformation.SingleStateTransformation):
The transformation operates on two Maps, first the "source map". This map
describes the Map that should be used as template. The second one is "map to
promote". After the transformation the "map to promote" will have the same
- map parameter than the "source map" has.
+ map parameter as the "source map" has.
In order to properly work, the parameters of "source map" must be a strict
superset of the ones of "map to promote". Furthermore, this transformation
@@ -52,6 +52,8 @@ class BaseMapPromoter(transformation.SingleStateTransformation):
promote_vertical: If `True` promote vertical dimensions; `True` by default.
promote_local: If `True` promote local dimensions; `True` by default.
promote_horizontal: If `True` promote horizontal dimensions; `False` by default.
+ promote_all: Do not impose any restriction on what to promote. The only
+ reasonable value is `True` or `None`.
Note:
This ignores tiling.
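A hedged sketch of the documented precondition and effect (plain Python set logic, not the DaCe API):

```python
# The parameters of the "source map" must be a strict superset of the
# ones of the "map to promote"; promotion then gives the promoted map
# the same parameters as the source map.
def can_promote(source_params: set[str], to_promote_params: set[str]) -> bool:
    return source_params > to_promote_params  # `>` is the strict-superset test

def promoted_params(source_params: set[str]) -> set[str]:
    return set(source_params)  # promoted map ends up with the source parameters

assert can_promote({"i", "j", "k"}, {"i", "j"})
assert not can_promote({"i", "j"}, {"i", "j"})  # equal sets are not a *strict* superset
assert promoted_params({"i", "j", "k"}) == {"i", "j", "k"}
```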
@@ -311,9 +313,9 @@ class SerialMapPromoter(BaseMapPromoter):
def expressions(cls) -> Any:
"""Get the match expressions.
- The function generates two different match expression. The first match
- describes the case where the top map must be promoted, while the second
- case is the second/lower map must be promoted.
+ The function generates two match expressions. The first match describes
+ the case where the top map must be promoted, while the second describes
+ the case where the second/lower map must be promoted.
"""
return [
dace.sdfg.utils.node_path_graph(
@@ -35,12 +35,12 @@ class SerialMapFusion(map_fusion_helper.MapFusionHelper):
Things that are improved, compared to the native DaCe implementation:
- Nested Maps.
- - Temporary arrays and the correct propagation of their Memelts.
+ - Temporary arrays and the correct propagation of their Memlets.
- Top Maps that have multiple outputs.
Conceptually this transformation removes the exit of the first or upper map
- and the entry of the lower or second map and then rewriting the connections
- appropriate.
+ and the entry of the lower or second map and then rewrites the connections
+ appropriately.
This transformation assumes that an SDFG obeys the structure that is outlined
[here](https://hackmd.io/klvzLnzMR6GZBWtRU8HbDg#Requirements-on-SDFG). For that
@@ -70,9 +70,9 @@ def expressions(cls) -> Any:
The transformation matches the exit node of the top Map that is connected to
an access node that again is connected to the entry node of the second Map.
- An important note is, that the transformation operates not just on these nodes,
- but more or less anything that has an outgoing connection of the first Map,
- and is connected to the second map.
+ Note that the transformation operates not just on the matched nodes,
+ but more or less on anything that has an incoming connection from the
+ first Map or an outgoing connection to the second Map entry.
"""
return [dace.sdfg.utils.node_path_graph(cls.map_exit1, cls.access_node, cls.map_entry2)]

@@ -86,7 +86,7 @@ def can_be_applied(
"""Tests if the matched Maps can be merged.
The two Maps are mergeable iff:
- - The `can_be_fused()` of the base succeed, which checks some basic constrains.
+ - The `can_be_fused()` of the base succeeds, which checks some basic constraints.
- The decomposition exists and at least one of the intermediate sets
is not empty.
"""
@@ -102,7 +102,8 @@ def can_be_applied(
return False

# Two maps can be serially fused if the node decomposition exists and
- # there at least one of the intermediate output sets is not empty.
+ # at least one of the intermediate output sets is not empty. The state
+ # of the pure outputs is irrelevant for serial map fusion.
output_partition = self.partition_first_outputs(
state=graph,
sdfg=sdfg,
@@ -111,7 +112,8 @@ def can_be_applied(
)
if output_partition is None:
return False
- if not (output_partition[1] or output_partition[2]):
+ _, exclusive_outputs, shared_outputs = output_partition
+ if not (exclusive_outputs or shared_outputs):
return False
return True
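The condition can be pictured with a stand-alone sketch (plain Python; the `(pure, exclusive, shared)` triple layout follows the tuple unpacking above, everything else is assumed):

```python
# Hedged sketch of the fusibility test: the partition is a triple of edge
# sets `(pure, exclusive, shared)`, or `None` if no decomposition exists.
from typing import Optional

def fusible(partition: Optional[tuple[set, set, set]]) -> bool:
    if partition is None:  # no valid node decomposition
        return False
    _, exclusive_outputs, shared_outputs = partition
    # pure outputs alone do not enable serial fusion
    return bool(exclusive_outputs or shared_outputs)

assert not fusible(None)
assert not fusible(({"pure_edge"}, set(), set()))
assert fusible((set(), {"tmp_edge"}, set()))
```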

@@ -216,12 +218,12 @@ def handle_intermediate_set(
state: The state in which the map is processed.
sdfg: The SDFG that should be optimized.
map_exit_1: The exit of the first/top map.
- map_entry_1: The entry of the second map.
+ map_entry_2: The entry of the second map.
map_exit_2: The exit of the second map.
is_exclusive_set: If `True` `intermediate_outputs` is the exclusive set.
Notes:
- Before the transformation the `state` does not be to be valid and
+ Before the transformation the `state` does not have to be valid and
after this function has run the state is (most likely) invalid.
Todo:
@@ -303,7 +305,7 @@ def handle_intermediate_set(
)
new_inter_node: nodes.AccessNode = state.add_access(new_inter_name)

- # New we will reroute the output Memlet, thus it will no longer going
+ # Now we will reroute the output Memlet, thus it will no longer pass
# through the Map exit but through the newly created intermediate.
# we will delete the previous edge later.
pre_exit_memlet: dace.Memlet = pre_exit_edge.data
@@ -314,11 +316,11 @@ def handle_intermediate_set(
assert pre_exit_memlet.data == inter_name
new_pre_exit_memlet.data = new_inter_name

- # Now we have to fix the subset of the Memlet.
- # Before the subset of the Memlet dependent on the Map variables,
+ # Now we have to modify the subset of the Memlet.
+ # Before the subset of the Memlet was dependent on the Map variables,
# however, this is no longer the case, as we removed them. This change
# has to be reflected in the Memlet.
- # NOTE: Assert above ensures that the bellow is correct.
+ # NOTE: Assert above ensures that the below is correct.
new_pre_exit_memlet.replace(memlet_repl)
if is_scalar:
new_pre_exit_memlet.subset = "0"
@@ -350,14 +352,14 @@ def handle_intermediate_set(
producer_edge.data.replace(memlet_repl)
if is_scalar:
producer_edge.data.dst_subset = "0"
-else:
-    if producer_edge.data.dst_subset is not None:
-        producer_edge.data.dst_subset.pop(squeezed_dims)
+elif producer_edge.data.dst_subset is not None:
+    producer_edge.data.dst_subset.pop(squeezed_dims)

# Now after we have handled the input of the new intermediate node,
- # we must handle its output. For this we have to "inject" the temporary
- # in the second map. We do this by finding the input connectors on the
- # map entry, such that we know where we have to reroute inside the Map.
+ # we must handle its output. For this we have to "inject" the newly
+ # created intermediate into the second map. We do this by finding
+ # the input connectors on the map entry, such that we know where we
+ # have to reroute inside the Map.
# NOTE: Assumes that the map (if connected) is the direct neighbour.
conn_names: set[str] = set()
for inter_node_out_edge in state.out_edges(inter_node):
@@ -386,7 +388,7 @@ def handle_intermediate_set(
# Memlet and the correctness of the code below.
new_inner_memlet = copy.deepcopy(inner_edge.data)
new_inner_memlet.replace(memlet_repl)
- new_inner_memlet.data = new_inter_name # Because of the assert above, this will not chenge the direction.
+ new_inner_memlet.data = new_inter_name # Because of the assert above, this will not change the direction.

# Now remove the old edge that started at the second map entry.
# Also add the new edge that started at the new intermediate.
@@ -402,9 +404,8 @@ def handle_intermediate_set(
# Now we do subset modification to ensure that nothing failed.
if is_scalar:
new_inner_memlet.src_subset = "0"
-else:
-    if new_inner_memlet.src_subset is not None:
-        new_inner_memlet.src_subset.pop(squeezed_dims)
+elif new_inner_memlet.src_subset is not None:
+    new_inner_memlet.src_subset.pop(squeezed_dims)

# Now clean the Memlets of that tree to use the new intermediate node.
for consumer_tree in state.memlet_tree(new_inner_edge).traverse_children():
@@ -413,9 +414,8 @@ def handle_intermediate_set(
consumer_edge.data.data = new_inter_name
if is_scalar:
consumer_edge.data.src_subset = "0"
-else:
-    if consumer_edge.data.subset is not None:
-        consumer_edge.data.subset.pop(squeezed_dims)
+elif consumer_edge.data.subset is not None:
+    consumer_edge.data.subset.pop(squeezed_dims)

# The edge that leaves the second map entry was already deleted.
# We will now delete the edges that brought the data.
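The subset fix-up that recurs throughout this function can be summarized in a small stand-alone sketch (plain Python lists stand in for DaCe subsets; all names are illustrative):

```python
# Hedged sketch of the recurring subset fix-up: a scalar collapses to the
# single index "0", otherwise the squeezed dimensions are popped.
def fix_subset(subset: list[str] | None, is_scalar: bool,
               squeezed_dims: list[int]) -> list[str] | None:
    if is_scalar:
        return ["0"]
    if subset is None:
        return None
    return [s for i, s in enumerate(subset) if i not in squeezed_dims]

assert fix_subset(["i", "0:10", "k"], False, squeezed_dims=[1]) == ["i", "k"]
assert fix_subset(["i"], True, squeezed_dims=[]) == ["0"]
```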
@@ -14,17 +14,19 @@

"""Common functionality for the transformations/optimization pipeline."""

-from typing import Iterable
+from typing import Iterable, Union

import dace
-from dace.sdfg import graph as dace_graph, nodes
+from dace.sdfg import graph as dace_graph, nodes as dace_nodes


-def is_nested_sdfg(sdfg: dace.SDFG) -> bool:
-    """Tests if `sdfg` is a neseted sdfg."""
+def is_nested_sdfg(
+    sdfg: Union[dace.SDFG, dace.SDFGState, dace_nodes.NestedSDFG],
+) -> bool:
+    """Tests if `sdfg` is a NestedSDFG."""
if isinstance(sdfg, dace.SDFGState):
sdfg = sdfg.parent
-if isinstance(sdfg, dace.nodes.NestedSDFG):
+if isinstance(sdfg, dace_nodes.NestedSDFG):
return True
elif isinstance(sdfg, dace.SDFG):
if sdfg.parent_nsdfg_node is not None:
@@ -36,17 +38,17 @@ def is_nested_sdfg(sdfg: dace.SDFG) -> bool:

def all_nodes_between(
graph: dace.SDFG | dace.SDFGState,
-    begin: nodes.Node,
-    end: nodes.Node,
+    begin: dace_nodes.Node,
+    end: dace_nodes.Node,
reverse: bool = False,
-) -> set[nodes.Node] | None:
+) -> set[dace_nodes.Node] | None:
"""Find all nodes that are reachable from `begin` but bound by `end`.
- Essentially the function starts a DFS at `begin`, which is never part of the
- returned set, if at a node an edge is found that lead to `end`, the function
- will ignore this edge. However, it will find every node that is reachable
- from `begin` that is reachable by a path that does not visit `end`.
- In case `end` is never found the function will return `None`.
+ Essentially the function starts a DFS at `begin`. If an edge is found that leads
+ to `end`, this edge is ignored. It will thus find any node that is reachable
+ from `begin` by a path that does not involve `end`. The returned set will
+ never contain `end` nor `begin`. In case `end` is never found the function
+ will return `None`.
If `reverse` is set to `True` the function will start exploring at `end` and
follows the outgoing edges, i.e. the meaning of `end` and `begin` are swapped.
@@ -58,25 +60,24 @@ def all_nodes_between(
reverse: Perform a backward DFS.
Notes:
- - The returned set will never contain the node `begin`.
- - The returned set will also contain the nodes of path that starts at
-   `begin` and ends at a node that is not `end`.
"""

-def next_nodes(node: nodes.Node) -> Iterable[nodes.Node]:
+def next_nodes(node: dace_nodes.Node) -> Iterable[dace_nodes.Node]:
if reverse:
return (edge.src for edge in graph.in_edges(node))
return (edge.dst for edge in graph.out_edges(node))

if reverse:
begin, end = end, begin

-to_visit: list[nodes.Node] = [begin]
-seen: set[nodes.Node] = set()
+to_visit: list[dace_nodes.Node] = [begin]
+seen: set[dace_nodes.Node] = set()
found_end: bool = False

while len(to_visit) > 0:
-n: nodes.Node = to_visit.pop()
+n: dace_nodes.Node = to_visit.pop()
if n == end:
found_end = True
continue
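The documented semantics can be reproduced on a plain adjacency-dict graph (a hedged sketch, not the DaCe types):

```python
# Stand-alone sketch of `all_nodes_between`: DFS from `begin`; edges into
# `end` are not followed; `None` signals that `end` was never reached.
def nodes_between(graph: dict[str, list[str]], begin: str, end: str) -> set[str] | None:
    to_visit, seen, found_end = [begin], set(), False
    while to_visit:
        n = to_visit.pop()
        if n == end:
            found_end = True
            continue  # never enter `end`, never put it into the result
        if n in seen:
            continue
        seen.add(n)
        to_visit.extend(graph.get(n, []))
    if not found_end:
        return None
    seen.discard(begin)  # `begin` is never part of the returned set
    return seen

g = {"a": ["b", "d"], "b": ["d"], "d": ["e"]}
assert nodes_between(g, "a", "d") == {"b"}  # `e` is only reachable through `d`
assert nodes_between(g, "a", "x") is None   # `end` never found
```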
@@ -94,10 +94,10 @@ def next_nodes(node: nodes.Node) -> Iterable[nodes.Node]:

def find_downstream_consumers(
state: dace.SDFGState,
-begin: nodes.Node | dace_graph.MultiConnectorEdge[dace.Memlet],
+begin: dace_nodes.Node | dace_graph.MultiConnectorEdge[dace.Memlet],
only_tasklets: bool = False,
reverse: bool = False,
-) -> set[tuple[nodes.Node, dace_graph.MultiConnectorEdge[dace.Memlet]]]:
+) -> set[tuple[dace_nodes.Node, dace_graph.MultiConnectorEdge[dace.Memlet]]]:
"""Find all downstream connectors of `begin`.
A consumer, for this function, is any node that is neither an entry nor
@@ -123,25 +124,25 @@ def find_downstream_consumers(
else:
to_visit = list(state.out_edges(begin))
seen: set[dace_graph.MultiConnectorEdge[dace.Memlet]] = set()
-found: set[tuple[nodes.Node, dace_graph.MultiConnectorEdge[dace.Memlet]]] = set()
+found: set[tuple[dace_nodes.Node, dace_graph.MultiConnectorEdge[dace.Memlet]]] = set()

while len(to_visit) != 0:
curr_edge: dace_graph.MultiConnectorEdge[dace.Memlet] = to_visit.pop()
-next_node: nodes.Node = curr_edge.src if reverse else curr_edge.dst
+next_node: dace_nodes.Node = curr_edge.src if reverse else curr_edge.dst

if curr_edge in seen:
continue
seen.add(curr_edge)

-if isinstance(next_node, (nodes.MapEntry, nodes.MapExit)):
+if isinstance(next_node, (dace_nodes.MapEntry, dace_nodes.MapExit)):
if reverse:
target_conn = curr_edge.src_conn[4:]
new_edges = state.in_edges_by_connector(curr_edge.src, "IN_" + target_conn)
else:
# In forward mode a Map entry could also mean the definition of a
# dynamic map range.
if (not curr_edge.dst_conn.startswith("IN_")) and isinstance(
-next_node, nodes.MapEntry
+next_node, dace_nodes.MapEntry
):
# This edge defines a dynamic map range, which is a consumer
if not only_tasklets:
@@ -152,7 +153,7 @@ def find_downstream_consumers(
to_visit.extend(new_edges)
del new_edges
else:
-if only_tasklets and (not isinstance(next_node, nodes.Tasklet)):
+if only_tasklets and (not isinstance(next_node, dace_nodes.Tasklet)):
continue
found.add((next_node, curr_edge))
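A toy model of the traversal (hedged sketch; plain strings stand in for the DaCe node types, and connectors, dynamic map ranges and `only_tasklets` are ignored) shows the intended consumer semantics:

```python
# Stand-alone sketch of `find_downstream_consumers` on a plain edge list:
# scope nodes (MapEntry/MapExit) are stepped through, every other node is
# reported as a consumer together with the edge through which it was reached.
def downstream_consumers(edges: list[tuple[str, str]], begin: str,
                         is_scope_node) -> set[tuple[str, tuple[str, str]]]:
    to_visit = [e for e in edges if e[0] == begin]
    seen: set[tuple[str, str]] = set()
    found: set[tuple[str, tuple[str, str]]] = set()
    while to_visit:
        edge = to_visit.pop()
        if edge in seen:
            continue
        seen.add(edge)
        dst = edge[1]
        if is_scope_node(dst):  # continue behind the scope node
            to_visit.extend(e for e in edges if e[0] == dst)
        else:
            found.add((dst, edge))  # a consumer: traversal stops here
    return found

edges = [("A", "map_entry"), ("map_entry", "T"), ("T", "map_exit")]
is_scope = lambda n: n.startswith(("map_entry", "map_exit"))
assert downstream_consumers(edges, "A", is_scope) == {("T", ("map_entry", "T"))}
```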

@@ -161,9 +162,9 @@

def find_upstream_producers(
state: dace.SDFGState,
-begin: nodes.Node | dace_graph.MultiConnectorEdge[dace.Memlet],
+begin: dace_nodes.Node | dace_graph.MultiConnectorEdge[dace.Memlet],
only_tasklets: bool = False,
-) -> set[tuple[nodes.Node, dace_graph.MultiConnectorEdge[dace.Memlet]]]:
+) -> set[tuple[dace_nodes.Node, dace_graph.MultiConnectorEdge[dace.Memlet]]]:
"""Same as `find_downstream_consumers()` but with `reverse` set to `True`."""
return find_downstream_consumers(
state=state,
