diff --git a/dace/sdfg/state.py b/dace/sdfg/state.py index 6dca3d186e..e8a8161747 100644 --- a/dace/sdfg/state.py +++ b/dace/sdfg/state.py @@ -1736,8 +1736,12 @@ def add_mapped_tasklet(self, language=dtypes.Language.Python, debuginfo=None, external_edges=False, - input_nodes: Optional[Dict[str, nd.AccessNode]] = None, - output_nodes: Optional[Dict[str, nd.AccessNode]] = None, + input_nodes: Optional[Union[Dict[str, nd.AccessNode], + List[nd.AccessNode], + Set[nd.AccessNode]]] = None, + output_nodes: Optional[Union[Dict[str, nd.AccessNode], + List[nd.AccessNode], + Set[nd.AccessNode]]] = None, propagate=True) -> Tuple[nd.Tasklet, nd.MapEntry, nd.MapExit]: """ Convenience function that adds a map entry, tasklet, map exit, and the respective edges to external arrays. @@ -1777,6 +1781,11 @@ def add_mapped_tasklet(self, tinputs = {k: None for k, v in inputs.items()} toutputs = {k: None for k, v in outputs.items()} + if isinstance(input_nodes, (list, set)): + input_nodes = {input_node.data: input_node for input_node in input_nodes} + if isinstance(output_nodes, (list, set)): + output_nodes = {output_node.data: output_node for output_node in output_nodes} + tasklet = nd.Tasklet( name, tinputs, diff --git a/tests/sdfg/state_test.py b/tests/sdfg/state_test.py index eb4e97ba66..7ba43ac4c0 100644 --- a/tests/sdfg/state_test.py +++ b/tests/sdfg/state_test.py @@ -58,7 +58,31 @@ def double_loop(arr: dace.float32[N]): sdfg.validate() +def test_add_mapped_tasklet(): + sdfg = dace.SDFG("test_add_mapped_tasklet") + state = sdfg.add_state(is_start_block=True) + + for name in "AB": + sdfg.add_array(name, (10, 10), dace.float64) + A, B = (state.add_access(name) for name in "AB") + + tsklt, me, mx = state.add_mapped_tasklet( + "test_map", + map_ranges={"i": "0:10", "j": "0:10"}, + inputs={"__in": dace.Memlet("A[i, j]")}, + code="__out = math.sin(__in)", + outputs={"__out": dace.Memlet("B[j, i]")}, + external_edges=True, + output_nodes=[B], + input_nodes={A}, + ) + sdfg.validate() + assert all(out_edge.dst is B for out_edge in state.out_edges(mx)) + assert all(in_edge.src is A for in_edge in state.in_edges(me)) + + if __name__ == '__main__': test_read_write_set() test_read_write_set_y_formation() test_deepcopy_state() + test_add_mapped_tasklet()