diff --git a/dace/dtypes.py b/dace/dtypes.py index 465e73b2b1..d7076dc987 100644 --- a/dace/dtypes.py +++ b/dace/dtypes.py @@ -389,7 +389,6 @@ def __init__(self, wrapped_type, typename=None): # Convert python basic types if isinstance(wrapped_type, str): try: - if wrapped_type == "bool": wrapped_type = numpy.bool_ else: diff --git a/dace/sdfg/nodes.py b/dace/sdfg/nodes.py index eb073f4319..1880fbb7e1 100644 --- a/dace/sdfg/nodes.py +++ b/dace/sdfg/nodes.py @@ -144,6 +144,43 @@ def add_out_connector(self, connector_name: str, dtype: dtypes.typeclass = None, self.out_connectors = connectors return True + def _add_scope_connectors( + self, + connector_name: str, + dtype: Optional[dtypes.typeclass] = None, + force: bool = False, + ) -> None: + """ Adds input and output connector names to `self` in one step. + + The function will add an input connector with name `'IN_' + connector_name` + and an output connector with name `'OUT_' + connector_name`. + The function is a shorthand for calling `add_in_connector()` and `add_out_connector()`. + + :param connector_name: The base name of the new connectors. + :param dtype: The type of the connectors, or `None` for auto-detect. + :param force: Add connector even if input or output connector of that name already exists. + :return: True if the operation is successful, otherwise False. + """ + in_connector_name = "IN_" + connector_name + out_connector_name = "OUT_" + connector_name + if not force: + if in_connector_name in self.in_connectors or in_connector_name in self.out_connectors: + return False + if out_connector_name in self.in_connectors or out_connector_name in self.out_connectors: + return False + # We force unconditionally because we have performed the tests above. + self.add_in_connector( + connector_name=in_connector_name, + dtype=dtype, + force=True, + ) + self.add_out_connector( + connector_name=out_connector_name, + dtype=dtype, + force=True, + ) + return True + def remove_in_connector(self, connector_name: str): """ Removes an input connector from the node. @@ -741,6 +778,9 @@ class EntryNode(Node): def validate(self, sdfg, state): self.map.validate(sdfg, state, self) + add_scope_connectors = Node._add_scope_connectors + + # ------------------------------------------------------------------------------ @@ -752,6 +792,8 @@ class ExitNode(Node): def validate(self, sdfg, state): self.map.validate(sdfg, state, self) + add_scope_connectors = Node._add_scope_connectors + # ------------------------------------------------------------------------------ diff --git a/tests/sdfg/nodes_test.py b/tests/sdfg/nodes_test.py new file mode 100644 index 0000000000..baf6d4765d --- /dev/null +++ b/tests/sdfg/nodes_test.py @@ -0,0 +1,35 @@ +import dace + +def test_add_scope_connectors(): + sdfg = dace.SDFG("add_scope_connectors_sdfg") + state = sdfg.add_state(is_start_block=True) + me: dace.nodes.MapEntry + mx: dace.nodes.MapExit + me, mx = state.add_map("test_map", ndrange={"__i0": "0:10"}) + assert all( + len(mn.in_connectors) == 0 and len(mn.out_connectors) == 0 + for mn in [me, mx] + ) + me.add_in_connector("IN_T", dtype=dace.float64) + assert len(me.in_connectors) == 1 and me.in_connectors["IN_T"] is dace.float64 and len(me.out_connectors) == 0 + assert len(mx.in_connectors) == 0 and len(mx.out_connectors) == 0 + + # Because there is already an `IN_T` this call will fail. + assert not me.add_scope_connectors("T") + assert len(me.in_connectors) == 1 and me.in_connectors["IN_T"] is dace.float64 and len(me.out_connectors) == 0 + assert len(mx.in_connectors) == 0 and len(mx.out_connectors) == 0 + + # Now it will work, because we specify force, however, the current type for `IN_T` will be overridden. + assert me.add_scope_connectors("T", force=True) + assert len(me.in_connectors) == 1 and me.in_connectors["IN_T"].type is None + assert len(me.out_connectors) == 1 and me.out_connectors["OUT_T"].type is None + assert len(mx.in_connectors) == 0 and len(mx.out_connectors) == 0 + + # Now tries to the full adding. + assert mx.add_scope_connectors("B", dtype=dace.int64) + assert len(mx.in_connectors) == 1 and mx.in_connectors["IN_B"] is dace.int64 + assert len(mx.out_connectors) == 1 and mx.out_connectors["OUT_B"] is dace.int64 + + +if __name__ == "__main__": + test_add_scope_connectors()