diff --git a/graphblas/dtypes.py b/graphblas/dtypes.py index 33dd033a1..fd66ba3b2 100644 --- a/graphblas/dtypes.py +++ b/graphblas/dtypes.py @@ -91,6 +91,12 @@ def register_anonymous(dtype, name=None): if isinstance(dtype, dict): # Allow dtypes such as `{'x': int, 'y': float}` for convenience dtype = _np.dtype([(key, lookup_dtype(val).np_type) for key, val in dtype.items()]) + elif isinstance(dtype, str) and "[" in dtype and dtype.endswith("]"): + # Allow dtypes such as `"INT64[3, 4]"` for convenience + base_dtype, shape = dtype.split("[", 1) + base_dtype = lookup_dtype(base_dtype) + shape = _np.lib.format.safe_eval(f"[{shape}") + dtype = _np.dtype((base_dtype.np_type, shape)) else: raise if dtype in _registry: diff --git a/graphblas/tests/test_vector.py b/graphblas/tests/test_vector.py index 73b898ef5..8e010fe0b 100644 --- a/graphblas/tests/test_vector.py +++ b/graphblas/tests/test_vector.py @@ -1972,6 +1972,8 @@ def test_udt(): # arrays as dtypes! np_dtype = np.dtype("(3,)uint16") udt2 = dtypes.register_anonymous(np_dtype, "has_subdtype") + udt2alt = dtypes.register_anonymous("UINT16[3]") + assert udt2 == udt2alt s = Scalar(udt2) s.value = [0, 0, 0]