Skip to content

Commit

Permalink
These are the changes Edoardo implemented to fix some issues in the op…
Browse files Browse the repository at this point in the history
…timization pipeline when confronted with scans.
  • Loading branch information
edopao authored and philip-paul-mueller committed Dec 17, 2024
1 parent 06b398a commit d9218b6
Show file tree
Hide file tree
Showing 5 changed files with 233 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,11 @@
gt_simplify,
gt_substitute_compiletime_symbols,
)
from .strides import gt_change_transient_strides
from .strides import (
gt_change_transient_strides,
gt_map_strides_to_dst_nested_sdfg,
gt_map_strides_to_src_nested_sdfg,
)
from .util import gt_find_constant_arguments, gt_make_transients_persistent


Expand All @@ -59,6 +63,8 @@
"gt_gpu_transformation",
"gt_inline_nested_sdfg",
"gt_make_transients_persistent",
"gt_map_strides_to_dst_nested_sdfg",
"gt_map_strides_to_src_nested_sdfg",
"gt_reduce_distributed_buffering",
"gt_set_gpu_blocksize",
"gt_set_iteration_order",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def gt_gpu_transformation(

if try_removing_trivial_maps:
# In DaCe a Tasklet, outside of a Map, can not write into an _array_ that is on
# GPU. `sdfg.appyl_gpu_transformations()` will wrap such Tasklets in a Map. So
# GPU. `sdfg.apply_gpu_transformations()` will wrap such Tasklets in a Map. So
# we might end up with lots of these trivial Maps, each requiring a separate
# kernel launch. To prevent this we will combine these trivial maps, if
# possible, with their downstream maps.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -950,7 +950,7 @@ def _perform_pointwise_test(

def apply(
self,
graph: dace.SDFGState | dace.SDFG,
graph: dace.SDFGState,
sdfg: dace.SDFG,
) -> None:
# Removal
Expand All @@ -971,6 +971,9 @@ def apply(
tmp_out_subset = dace_subsets.Range.from_array(tmp_desc)
assert glob_in_subset is not None

# Recursively visit the nested SDFGs for mapping from inner to outer strides on the vertical dimension
gtx_transformations.gt_map_strides_to_src_nested_sdfg(sdfg, graph, map_to_tmp_edge, glob_ac)

# We now remove the `tmp` node, and create a new connection between
# the global node and the map exit.
new_map_to_glob_edge = graph.add_edge(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@
# Please, refer to the LICENSE file in the root directory.
# SPDX-License-Identifier: BSD-3-Clause

import functools
from typing import Iterable

import dace
from dace import data as dace_data

Expand Down Expand Up @@ -64,6 +67,13 @@ def _gt_change_transient_strides_non_recursive_impl(
# we simply have to reverse the order.
new_stride_order = list(range(ndim))
desc.set_strides_from_layout(*new_stride_order)
for state in sdfg.states():
for data_node in state.data_nodes():
if data_node.data == top_level_transient:
for in_edge in state.in_edges(data_node):
gt_map_strides_to_src_nested_sdfg(sdfg, state, in_edge, data_node)
for out_edge in state.out_edges(data_node):
gt_map_strides_to_dst_nested_sdfg(sdfg, state, out_edge, data_node)


def _find_toplevel_transients(
Expand Down Expand Up @@ -97,3 +107,125 @@ def _find_toplevel_transients(
continue
top_level_transients.add(data)
return top_level_transients


def gt_map_strides_to_dst_nested_sdfg(
    sdfg: dace.SDFG,
    state: dace.SDFGState,
    edge: dace.sdfg.graph.Edge,
    outer_node: dace.nodes.AccessNode,
) -> None:
    """Propagate the strides of a data node into the nested SDFGs that read it.

    The function follows `edge` towards its destination. Map entries are
    traversed transparently; when a nested SDFG is reached, the strides of
    `outer_node` are mapped onto the inner array bound to the destination
    connector. The propagation then continues recursively inside the nested
    SDFG, starting from every access node that refers to that inner array.

    Args:
        sdfg: The SDFG that contains `state` and the descriptor of `outer_node`.
        state: The state that contains `edge`.
        edge: The edge that reads from the data node; the nested SDFG is expected
            as the destination (possibly behind one or more map entries).
        outer_node: The data node whose strides should be propagated.
    """
    if isinstance(edge.dst, dace.nodes.MapEntry):
        # Find the destination of the edge entering the map entry node. Only the
        # `IN_` prefix is rewritten, such that data names that happen to contain
        # `IN_` are preserved. Edges without a pass-through connector (e.g. empty
        # memlets or dynamic map ranges) do not carry data into the map scope.
        if edge.dst_conn is None or not edge.dst_conn.startswith("IN_"):
            return
        map_entry_out_conn = "OUT_" + edge.dst_conn[len("IN_") :]
        for edge_from_map_entry in state.out_edges_by_connector(edge.dst, map_entry_out_conn):
            gt_map_strides_to_dst_nested_sdfg(sdfg, state, edge_from_map_entry, outer_node)
        return

    if not isinstance(edge.dst, dace.nodes.NestedSDFG):
        return

    outer_strides = outer_node.desc(sdfg).strides
    _gt_map_strides_to_nested_sdfg(edge.dst, edge.dst_conn, edge.data, outer_strides)

    # Continue inside the nested SDFG. The inner array has the name of the
    # connector, and the inner edges/nodes live in the nested SDFG and its
    # states, thus these have to be used as context for the recursion.
    nested_sdfg = edge.dst.sdfg
    for inner_state in nested_sdfg.states():
        for inner_node in inner_state.data_nodes():
            if inner_node.data == edge.dst_conn:
                for inner_edge in inner_state.out_edges(inner_node):
                    gt_map_strides_to_dst_nested_sdfg(
                        nested_sdfg, inner_state, inner_edge, inner_node
                    )


def gt_map_strides_to_src_nested_sdfg(
    sdfg: dace.SDFG,
    state: dace.SDFGState,
    edge: dace.sdfg.graph.Edge,
    outer_node: dace.nodes.AccessNode,
) -> None:
    """Propagate the strides of a data node into the nested SDFGs that write it.

    The function follows `edge` backwards towards its source. Map exits are
    traversed transparently; when a nested SDFG is reached, the strides of
    `outer_node` are mapped onto the inner array bound to the source connector,
    unless that inner data is a scalar. The propagation then continues
    recursively inside the nested SDFG, starting from every access node that
    refers to that inner array.

    Args:
        sdfg: The SDFG that contains `state` and the descriptor of `outer_node`.
        state: The state that contains `edge`.
        edge: The edge that writes to the data node; the nested SDFG is expected
            as the source (possibly behind one or more map exits).
        outer_node: The data node whose strides should be propagated.
    """
    if isinstance(edge.src, dace.nodes.MapExit):
        # Find the source of the edge entering the map exit node. Only the
        # `OUT_` prefix is rewritten, such that data names that happen to
        # contain `OUT_` are preserved. Edges without a pass-through connector
        # (e.g. empty memlets) do not carry data out of the map scope.
        if edge.src_conn is None or not edge.src_conn.startswith("OUT_"):
            return
        map_exit_in_conn = "IN_" + edge.src_conn[len("OUT_") :]
        for edge_to_map_exit in state.in_edges_by_connector(edge.src, map_exit_in_conn):
            gt_map_strides_to_src_nested_sdfg(sdfg, state, edge_to_map_exit, outer_node)
        return

    if not isinstance(edge.src, dace.nodes.NestedSDFG):
        return

    if isinstance(edge.src.sdfg.data(edge.src_conn), dace.data.Scalar):
        return  # no strides to propagate

    outer_strides = outer_node.desc(sdfg).strides
    _gt_map_strides_to_nested_sdfg(edge.src, edge.src_conn, edge.data, outer_strides)

    # Continue inside the nested SDFG. The inner array has the name of the
    # connector, and the inner edges/nodes live in the nested SDFG and its
    # states, thus these have to be used as context for the recursion.
    nested_sdfg = edge.src.sdfg
    for inner_state in nested_sdfg.states():
        for inner_node in inner_state.data_nodes():
            if inner_node.data == edge.src_conn:
                for inner_edge in inner_state.in_edges(inner_node):
                    gt_map_strides_to_src_nested_sdfg(
                        nested_sdfg, inner_state, inner_edge, inner_node
                    )


def _gt_map_strides_to_nested_sdfg(
    nsdfg_node: dace.nodes.NestedSDFG,
    inner_data: str,
    edge_data: dace.Memlet,
    outer_strides: Iterable[int | dace.symbolic.SymExpr],
) -> None:
    """Map the strides of an outer array onto a non-transient inner array.

    Args:
        nsdfg_node: The nested SDFG node whose inner data is updated.
        inner_data: Name of the data descriptor inside the nested SDFG.
        edge_data: Memlet on the edge connected to the nested SDFG; the sizes of
            its subset decide which outer dimensions are passed inside.
        outer_strides: Strides of the outer array, one per memlet dimension.
    """
    # Dimensions in which the memlet selects a single element are not visible
    # inside the nested SDFG, hence their strides are dropped.
    subset_sizes = edge_data.subset.size()
    new_strides = tuple(
        outer_stride
        for outer_stride, extent in zip(outer_strides, subset_sizes, strict=True)
        if extent != 1
    )

    nested_sdfg = nsdfg_node.sdfg
    inner_desc = nested_sdfg.arrays[inner_data]
    assert not inner_desc.transient

    if isinstance(inner_desc, dace.data.Scalar):
        assert len(new_strides) == 0
        return
    assert isinstance(inner_desc, dace.data.Array)

    inner_strides_are_symbols = all(
        isinstance(inner_stride, dace.symbol) for inner_stride in inner_desc.strides
    )
    if not inner_strides_are_symbols:
        # At least one inner stride is not a plain symbol: overwrite the strides
        # of the inner descriptor directly.
        inner_desc.set_shape(inner_desc.shape, new_strides)
    else:
        # All inner strides are symbols: bind them through the symbol mapping.
        for inner_stride, outer_stride in zip(inner_desc.strides, new_strides, strict=True):
            nsdfg_node.symbol_mapping[inner_stride.name] = outer_stride

    # Collect the symbols the new strides refer to, and register the ones that
    # the nested SDFG does not know yet, mapping each of them to itself.
    referenced_symbols: list[dace.symbol] = []
    for stride in new_strides:
        if dace.symbolic.issymbolic(stride):
            referenced_symbols.extend(stride.free_symbols)  # type: ignore[union-attr]

    missing_symbols = {sym for sym in referenced_symbols if sym.name not in nested_sdfg.symbols}
    for sym in missing_symbols:
        nested_sdfg.add_symbol(sym.name, sym.dtype)
        nsdfg_node.symbol_mapping[sym.name] = sym
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,6 @@
import dace


def _make_test_data(names: list[str]) -> dict[str, np.ndarray]:
return {name: np.array(np.random.rand(10), dtype=np.float64, copy=True) for name in names}


def _make_test_sdfg(
output_name: str = "G",
input_name: str = "G",
Expand Down Expand Up @@ -262,3 +258,92 @@ def test_map_buffer_elimination_not_apply():
validate_all=True,
)
assert count == 0


def test_map_buffer_elimination_with_nested_sdfgs():
    """
    Check stride propagation through two levels of nested SDFGs.

    Each level writes its result through a transient buffer. After the buffers
    are eliminated, the arrays in the nested SDFGs must use the trailing strides
    of the top-level output array.
    """

    out_strides = tuple(dace.symbol(f"stride{i}", dace.int32) for i in range(3))

    # Top-level SDFG: 3D output with symbolic strides, written via a transient.
    outer_sdfg = dace.SDFG(util.unique_name("map_buffer"))
    inp, inp_desc = outer_sdfg.add_array("__inp", (10,), dace.float64)
    out, out_desc = outer_sdfg.add_array(
        "__out", (10, 10, 10), dace.float64, strides=out_strides
    )
    tmp, _ = outer_sdfg.add_temp_transient_like(out_desc)
    outer_state = outer_sdfg.add_state()
    tmp_node = outer_state.add_access(tmp)

    # First nesting level: 2D output, written via a transient.
    mid_sdfg = dace.SDFG(util.unique_name("map_buffer"))
    inp1, inp1_desc = mid_sdfg.add_array("__inp", (10,), dace.float64)
    out1, out1_desc = mid_sdfg.add_array("__out", (10, 10), dace.float64)
    tmp1, _ = mid_sdfg.add_temp_transient_like(out1_desc)
    mid_state = mid_sdfg.add_state()
    tmp1_node = mid_state.add_access(tmp1)

    # Second nesting level: 1D output, written via a transient.
    leaf_sdfg = dace.SDFG(util.unique_name("map_buffer"))
    inp2, _ = leaf_sdfg.add_array("__inp", (10,), dace.float64)
    out2, out2_desc = leaf_sdfg.add_array("__out", (10,), dace.float64)
    tmp2, _ = leaf_sdfg.add_temp_transient_like(out2_desc)
    leaf_state = leaf_sdfg.add_state()
    tmp2_node = leaf_state.add_access(tmp2)

    # Innermost computation: a pointwise map writing into the transient.
    leaf_state.add_mapped_tasklet(
        "broadcast2",
        map_ranges={"__i": "0:10"},
        code="__oval = __ival + 1.0",
        inputs={"__ival": dace.Memlet(f"{inp2}[__i]")},
        outputs={"__oval": dace.Memlet(f"{tmp2}[__i]")},
        output_nodes={tmp2_node},
        external_edges=True,
    )
    leaf_state.add_nedge(
        tmp2_node,
        leaf_state.add_access(out2),
        dace.Memlet.from_array(out2, out2_desc),
    )

    # Embed the innermost SDFG into the first nesting level, inside a map.
    leaf_nsdfg_node = mid_state.add_nested_sdfg(
        leaf_sdfg, mid_sdfg, inputs={"__inp"}, outputs={"__out"}
    )
    mid_entry, mid_exit = mid_state.add_map("broadcast1", ndrange={"__i": "0:10"})
    mid_state.add_memlet_path(
        mid_state.add_access(inp1),
        mid_entry,
        leaf_nsdfg_node,
        dst_conn="__inp",
        memlet=dace.Memlet.from_array(inp1, inp1_desc),
    )
    mid_state.add_memlet_path(
        leaf_nsdfg_node,
        mid_exit,
        tmp1_node,
        src_conn="__out",
        memlet=dace.Memlet(f"{tmp1}[__i, 0:10]"),
    )
    mid_state.add_nedge(
        tmp1_node,
        mid_state.add_access(out1),
        dace.Memlet.from_array(out1, out1_desc),
    )

    # Embed the first nesting level into the top-level SDFG, inside a map.
    mid_nsdfg_node = outer_state.add_nested_sdfg(
        mid_sdfg, outer_sdfg, inputs={"__inp"}, outputs={"__out"}
    )
    outer_entry, outer_exit = outer_state.add_map("broadcast", ndrange={"__i": "0:10"})
    outer_state.add_memlet_path(
        outer_state.add_access(inp),
        outer_entry,
        mid_nsdfg_node,
        dst_conn="__inp",
        memlet=dace.Memlet.from_array(inp, inp_desc),
    )
    outer_state.add_memlet_path(
        mid_nsdfg_node,
        outer_exit,
        tmp_node,
        src_conn="__out",
        memlet=dace.Memlet(f"{tmp}[__i, 0:10, 0:10]"),
    )
    outer_state.add_nedge(
        tmp_node,
        outer_state.add_access(out),
        dace.Memlet.from_array(out, out_desc),
    )

    outer_sdfg.validate()

    transformation = gtx_transformations.GT4PyMapBufferElimination(
        assume_pointwise=False,
    )
    count = outer_sdfg.apply_transformations_repeated(
        transformation,
        validate=True,
        validate_all=True,
    )

    # One buffer is removed per nesting level.
    assert count == 3
    assert out1_desc.strides == out_desc.strides[1:]
    assert out2_desc.strides == out_desc.strides[2:]

0 comments on commit d9218b6

Please sign in to comment.