Skip to content

Commit

Permalink
Add gb.core.utils.ensure_type to make it easier to work with expres…
Browse files Browse the repository at this point in the history
  • Loading branch information
eriknw authored Oct 29, 2022
1 parent 90adb6c commit 55da5ff
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 6 deletions.
6 changes: 3 additions & 3 deletions graphblas/core/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def _expect_type_message(
return x, None
elif output_type(x) in types:
if config.get("autocompute"):
return x._get_value(), None
return x.new(), None
extra_message = f"{extra_message}\n\n" if extra_message else ""
extra_message += (
"Hint: use `graphblas.config.set(autocompute=True)` to automatically "
Expand All @@ -73,7 +73,7 @@ def _expect_type_message(
return x, None
elif output_type(x) is types:
if config.get("autocompute"):
return x._get_value(), None
return x.new(), None
extra_message = f"{extra_message}\n\n" if extra_message else ""
extra_message += (
"Hint: use `graphblas.config.set(autocompute=True)` to automatically "
Expand Down Expand Up @@ -410,7 +410,7 @@ def _update(self, expr, mask=None, accum=None, replace=False, input_mask=None):

if type(self) is not expr.output_type:
if expr.output_type._is_scalar and config.get("autocompute"):
self._update(expr._get_value(), mask, accum, replace, input_mask)
self._update(expr.new(), mask, accum, replace, input_mask)
return
from .scalar import Scalar

Expand Down
2 changes: 1 addition & 1 deletion graphblas/core/formatting.py
Original file line number Diff line number Diff line change
Expand Up @@ -505,7 +505,7 @@ def format_scalar(scalar, expr=None):

def get_expr_result(expr, html=False):
try:
val = expr._get_value()
val = expr.new()
except OutOfMemory:
arg_string = "Result is too large to compute!"
if html:
Expand Down
6 changes: 4 additions & 2 deletions graphblas/core/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3094,8 +3094,9 @@ def get_semiring(monoid, binaryop, name=None):
See Also
--------
Semiring.register_anonymous
Semiring.register_new
semiring.register_anonymous
semiring.register_new
semiring.from_string
"""
monoid, opclass = find_opclass(monoid)
switched = False
Expand Down Expand Up @@ -3210,6 +3211,7 @@ def get_semiring(monoid, binaryop, name=None):
monoid.register_anonymous = Monoid.register_anonymous
semiring.register_new = Semiring.register_new
semiring.register_anonymous = Semiring.register_anonymous
semiring.get_semiring = get_semiring

select._binary_to_select.update(
{
Expand Down
23 changes: 23 additions & 0 deletions graphblas/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,29 @@ def normalize_chunks(chunks, shape):
return chunksizes


def ensure_type(x, types):
"""Try to ensure `x` is one of the given types, computing if necessary.
`types` must be a type or a tuple of types as used in `isinstance`.
For example, if `types` is a Vector, then a Vector input will be returned,
and a `VectorExpression` input will be computed and returned as a Vector.
TypeError will be raised if the input is not or can't be converted to types.
This function ignores `graphblas.config["autocompute"]`; it always computes
if the return type will match `types`.
"""
if isinstance(x, types):
return x
elif isinstance(types, tuple):
if output_type(x) in types:
return x.new()
elif output_type(x) is types:
return x.new()
raise TypeError(f"{type(x).__name__!r} object is not of type {types}")


class class_property:
__slots__ = "classval", "member_property"

Expand Down
18 changes: 18 additions & 0 deletions graphblas/tests/test_resolving.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from graphblas import binary, dtypes, replace, unary
from graphblas.core.expr import Updater
from graphblas.core.utils import ensure_type

from graphblas import Matrix, Scalar, Vector # isort:skip (for dask-graphblas)

Expand Down Expand Up @@ -246,3 +247,20 @@ def test_py_indices():
# All together now!
idx = v[0 : v.size : 1].resolved_indexes.py_indices
assert idx == slice(None)


def test_ensure_type():
v = Vector.from_values([0, 1, 2], [1, 2, 3])
A = Matrix.from_values([0, 1, 2], [2, 0, 1], [0, 2, 3])
assert ensure_type(v, Vector) is v
assert ensure_type(A, Matrix) is A
assert ensure_type(v, (Matrix, Vector)) is v
assert ensure_type(A, (Matrix, Vector)) is A
with pytest.raises(TypeError):
ensure_type(A, Vector)
assert ensure_type(v + 1, Vector).isequal((v + 1).new())
assert ensure_type(v + 1, (Vector,)).isequal((v + 1).new())
assert ensure_type(A.mxm(A), Matrix).isequal(A.mxm(A).new())
with pytest.raises(TypeError):
ensure_type(A.mxm(A), (Vector,))
ensure_type(4, int) # why not

0 comments on commit 55da5ff

Please sign in to comment.