Skip to content

Commit

Permalink
Allow easier array UDTs, such as "INT64[3, 3]" (python-graphblas#300)
Browse files Browse the repository at this point in the history
  • Loading branch information
eriknw authored Oct 26, 2022
1 parent 5852ff3 commit 20f3dfb
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 0 deletions.
6 changes: 6 additions & 0 deletions graphblas/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,12 @@ def register_anonymous(dtype, name=None):
if isinstance(dtype, dict):
# Allow dtypes such as `{'x': int, 'y': float}` for convenience
dtype = _np.dtype([(key, lookup_dtype(val).np_type) for key, val in dtype.items()])
elif isinstance(dtype, str) and "[" in dtype and dtype.endswith("]"):
# Allow dtypes such as `"INT64[3, 4]"` for convenience
base_dtype, shape = dtype.split("[", 1)
base_dtype = lookup_dtype(base_dtype)
shape = _np.lib.format.safe_eval(f"[{shape}")
dtype = _np.dtype((base_dtype.np_type, shape))
else:
raise
if dtype in _registry:
Expand Down
2 changes: 2 additions & 0 deletions graphblas/tests/test_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -1972,6 +1972,8 @@ def test_udt():
# arrays as dtypes!
np_dtype = np.dtype("(3,)uint16")
udt2 = dtypes.register_anonymous(np_dtype, "has_subdtype")
udt2alt = dtypes.register_anonymous("UINT16[3]")
assert udt2 == udt2alt
s = Scalar(udt2)
s.value = [0, 0, 0]

Expand Down

0 comments on commit 20f3dfb

Please sign in to comment.