Skip to content

Commit

Permalink
Made the SDFGState.add_mapped_tasklet() more convenient (#1655)
Browse files Browse the repository at this point in the history
Before if the user wanted to supply in and output nodes he had to
present a `dict` that maps the data name to the access node. However,
because of the rules of a valid SDFG the key of that `dict` was always
the same as the data the access node this information is redundant. Thus
this commit allows to only pass the access nodes.
  • Loading branch information
philip-paul-mueller authored Sep 15, 2024
1 parent c8e2704 commit 95c65be
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 2 deletions.
13 changes: 11 additions & 2 deletions dace/sdfg/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -1736,8 +1736,12 @@ def add_mapped_tasklet(self,
language=dtypes.Language.Python,
debuginfo=None,
external_edges=False,
input_nodes: Optional[Dict[str, nd.AccessNode]] = None,
output_nodes: Optional[Dict[str, nd.AccessNode]] = None,
input_nodes: Optional[Union[Dict[str, nd.AccessNode],
List[nd.AccessNode],
Set[nd.AccessNode]]] = None,
output_nodes: Optional[Union[Dict[str, nd.AccessNode],
List[nd.AccessNode],
Set[nd.AccessNode]]] = None,
propagate=True) -> Tuple[nd.Tasklet, nd.MapEntry, nd.MapExit]:
""" Convenience function that adds a map entry, tasklet, map exit,
and the respective edges to external arrays.
Expand Down Expand Up @@ -1777,6 +1781,11 @@ def add_mapped_tasklet(self,
tinputs = {k: None for k, v in inputs.items()}
toutputs = {k: None for k, v in outputs.items()}

if isinstance(input_nodes, (list, set)):
input_nodes = {input_node.data: input_node for input_node in input_nodes}
if isinstance(output_nodes, (list, set)):
output_nodes = {output_node.data: output_node for output_node in output_nodes}

tasklet = nd.Tasklet(
name,
tinputs,
Expand Down
24 changes: 24 additions & 0 deletions tests/sdfg/state_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,31 @@ def double_loop(arr: dace.float32[N]):
sdfg.validate()


def test_add_mapped_tasklet():
sdfg = dace.SDFG("test_add_mapped_tasklet")
state = sdfg.add_state(is_start_block=True)

for name in "AB":
sdfg.add_array(name, (10, 10), dace.float64)
A, B = (state.add_access(name) for name in "AB")

tsklt, me, mx = state.add_mapped_tasklet(
"test_map",
map_ranges={"i": "0:10", "j": "0:10"},
inputs={"__in": dace.Memlet("A[i, j]")},
code="__out = math.sin(__in)",
outputs={"__out": dace.Memlet("B[j, i]")},
external_edges=True,
output_nodes=[B],
input_nodes={A},
)
sdfg.validate()
assert all(out_edge.dst is B for out_edge in state.out_edges(mx))
assert all(in_edge.src is A for in_edge in state.in_edges(me))


if __name__ == '__main__':
test_read_write_set()
test_read_write_set_y_formation()
test_deepcopy_state()
test_add_mapped_tasklet()

0 comments on commit 95c65be

Please sign in to comment.