Skip to content

Commit

Permalink
Added a test (currently failing) that shows that we also have to propagate to views.
Browse files Browse the repository at this point in the history
  • Loading branch information
philip-paul-mueller committed Dec 20, 2024
1 parent cc9801b commit 9b36334
Showing 1 changed file with 103 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -539,3 +539,106 @@ def ref(a1, b1):
ref(**ref_args)
sdfg_level1(**res_args)
assert np.allclose(ref_args["b1"], res_args["b1"])


def _make_strides_propagation_view_nsdfg() -> dace.SDFG:
    """Build the inner SDFG, which reads one row of `a2` through a view.

    The SDFG consists of a single state. The view `v2` is connected to
    `a2[2, 0:10]`, and a mapped tasklet over `0:10` reads `v2[__i0]`, adds
    `10.` and stores the result in `b2[__i0]`. Both `a2` and `b2` are
    registered as non-transient arrays.

    Returns:
        The SDFG, after `validate()` has been called on it.
    """
    nsdfg = dace.SDFG(util.unique_name("strides_propagation_view_nsdfg"))
    nstate = nsdfg.add_state(is_start_block=True)

    # Descriptors are registered in the order `a2`, `v2`, `b2`.
    nsdfg.add_array("a2", shape=(10, 10), dtype=dace.float64, transient=False)
    nsdfg.add_view("v2", shape=(10,), dtype=dace.float64)
    nsdfg.add_array("b2", shape=(10,), dtype=dace.float64, transient=False)

    # Access nodes are created in the order `a2`, `b2`, `v2`.
    access = {}
    for data_name in ("a2", "b2", "v2"):
        access[data_name] = nstate.add_access(data_name)

    # `v2` is a view on row 2 of `a2`.
    nstate.add_edge(
        access["a2"],
        None,
        access["v2"],
        "view",
        dace.Memlet("a2[2, 0:10]"),
    )

    # Element-wise computation that reads through the view and writes `b2`.
    nstate.add_mapped_tasklet(
        "comp",
        map_ranges={"__i0": "0:10"},
        inputs={"__in1": dace.Memlet("v2[__i0]")},
        code="__out = __in1 + 10.",
        outputs={"__out": dace.Memlet("b2[__i0]")},
        input_nodes={access["v2"]},
        output_nodes={access["b2"]},
        external_edges=True,
    )

    nsdfg.validate()
    return nsdfg


def _make_strides_propagation_view_sdfg() -> tuple[dace.SDFG, dace_nodes.NestedSDFG]:
    """Build the outer SDFG that wraps the view-based nested SDFG.

    The outer SDFG has a single state and two non-transient arrays, `a1`
    (10 x 10) and `b1` (10). The whole of `a1` is passed to the nested SDFG
    as `a2`, and the nested output `b2` is written to the whole of `b1`.

    Returns:
        A pair of the outer SDFG, after `validate()` has been called on it,
        and the nested SDFG node inside it.
    """
    outer_sdfg = dace.SDFG(util.unique_name("strides_propagation_view_sdfg"))
    outer_state = outer_sdfg.add_state(is_start_block=True)

    outer_sdfg.add_array("a1", shape=(10, 10), dtype=dace.float64, transient=False)
    outer_sdfg.add_array("b1", shape=(10,), dtype=dace.float64, transient=False)

    # Create the inner SDFG and embed it into the outer state.
    inner_sdfg = _make_strides_propagation_view_nsdfg()
    nested_node = outer_state.add_nested_sdfg(
        sdfg=inner_sdfg,
        parent=outer_sdfg,
        inputs={"a2"},
        outputs={"b2"},
        symbol_mapping={},
    )

    # Input side: all of `a1` goes into the `a2` connector.
    a1_node = outer_state.add_access("a1")
    outer_state.add_edge(a1_node, None, nested_node, "a2", dace.Memlet("a1[0:10, 0:10]"))

    # Output side: the `b2` connector is written to all of `b1`.
    b1_node = outer_state.add_access("b1")
    outer_state.add_edge(nested_node, "b2", b1_node, None, dace.Memlet("b1[0:10]"))

    outer_sdfg.validate()
    return outer_sdfg, nested_node


def test_strides_propagation_view():
    """Check the SDFG result against NumPy after changing the transient strides.

    `gt_change_transient_strides()` is applied with `gpu=True` to the outer
    SDFG, which is then executed and compared argument by argument against a
    pure NumPy reference.
    """
    sdfg_level1, _nsdfg_level2 = _make_strides_propagation_view_sdfg()

    def ref(a1, b1):
        # Same computation as the nested SDFG: row 2 of `a1` plus 10.
        b1[:] = a1[2, :] + 10.0

    def _mark_all_arrays(transient):
        # Looks up the descriptors on every call rather than caching them.
        for descriptor in sdfg_level1.arrays.values():
            descriptor.transient = transient

    # The FORTRAN order simulates GPU execution.
    res_args = {
        arg_name: np.array(np.random.rand(*arg_shape), order="F", dtype=np.float64, copy=True)
        for arg_name, arg_shape in (("a1", (10, 10)), ("b1", (10,)))
    }
    ref_args = copy.deepcopy(res_args)

    # We have to temporarily turn the arguments into transients, otherwise
    # they would not be processed.
    _mark_all_arrays(True)
    gtx_transformations.gt_change_transient_strides(
        sdfg=sdfg_level1,
        gpu=True,
    )
    _mark_all_arrays(False)

    ref(**ref_args)
    sdfg_level1(**res_args)

    for name, expected in ref_args.items():
        assert np.allclose(expected, res_args[name]), f"Failed in argument {name}"

0 comments on commit 9b36334

Please sign in to comment.