From 55da5ffc7e14a827387aa7f3b77448675efbd0bb Mon Sep 17 00:00:00 2001 From: Erik Welch Date: Fri, 28 Oct 2022 22:34:08 -0500 Subject: [PATCH] Add `gb.core.utils.ensure_type` to make it easier to work with expressions (#312) --- graphblas/core/base.py | 6 +++--- graphblas/core/formatting.py | 2 +- graphblas/core/operator.py | 6 ++++-- graphblas/core/utils.py | 23 +++++++++++++++++++++++ graphblas/tests/test_resolving.py | 18 ++++++++++++++++++ 5 files changed, 49 insertions(+), 6 deletions(-) diff --git a/graphblas/core/base.py b/graphblas/core/base.py index c9e7b08a3..c33b641a2 100644 --- a/graphblas/core/base.py +++ b/graphblas/core/base.py @@ -63,7 +63,7 @@ def _expect_type_message( return x, None elif output_type(x) in types: if config.get("autocompute"): - return x._get_value(), None + return x.new(), None extra_message = f"{extra_message}\n\n" if extra_message else "" extra_message += ( "Hint: use `graphblas.config.set(autocompute=True)` to automatically " @@ -73,7 +73,7 @@ def _expect_type_message( return x, None elif output_type(x) is types: if config.get("autocompute"): - return x._get_value(), None + return x.new(), None extra_message = f"{extra_message}\n\n" if extra_message else "" extra_message += ( "Hint: use `graphblas.config.set(autocompute=True)` to automatically " @@ -410,7 +410,7 @@ def _update(self, expr, mask=None, accum=None, replace=False, input_mask=None): if type(self) is not expr.output_type: if expr.output_type._is_scalar and config.get("autocompute"): - self._update(expr._get_value(), mask, accum, replace, input_mask) + self._update(expr.new(), mask, accum, replace, input_mask) return from .scalar import Scalar diff --git a/graphblas/core/formatting.py b/graphblas/core/formatting.py index 8f024bfbd..95506dac7 100644 --- a/graphblas/core/formatting.py +++ b/graphblas/core/formatting.py @@ -505,7 +505,7 @@ def format_scalar(scalar, expr=None): def get_expr_result(expr, html=False): try: - val = expr._get_value() + val = expr.new() except OutOfMemory: arg_string = "Result is too large to compute!" if html: diff --git a/graphblas/core/operator.py b/graphblas/core/operator.py index bf6843572..c717f81e6 100644 --- a/graphblas/core/operator.py +++ b/graphblas/core/operator.py @@ -3094,8 +3094,9 @@ def get_semiring(monoid, binaryop, name=None): See Also -------- - Semiring.register_anonymous - Semiring.register_new + semiring.register_anonymous + semiring.register_new + semiring.from_string """ monoid, opclass = find_opclass(monoid) switched = False @@ -3210,6 +3211,7 @@ def get_semiring(monoid, binaryop, name=None): monoid.register_anonymous = Monoid.register_anonymous semiring.register_new = Semiring.register_new semiring.register_anonymous = Semiring.register_anonymous +semiring.get_semiring = get_semiring select._binary_to_select.update( { diff --git a/graphblas/core/utils.py b/graphblas/core/utils.py index 8a236bbfd..53b4f9897 100644 --- a/graphblas/core/utils.py +++ b/graphblas/core/utils.py @@ -214,6 +214,29 @@ def normalize_chunks(chunks, shape): return chunksizes +def ensure_type(x, types): + """Try to ensure `x` is one of the given types, computing if necessary. + + `types` must be a type or a tuple of types as used in `isinstance`. + + For example, if `types` is a Vector, then a Vector input will be returned, + and a `VectorExpression` input will be computed and returned as a Vector. + + TypeError will be raised if the input is not or can't be converted to types. + + This function ignores `graphblas.config["autocompute"]`; it always computes + if the return type will match `types`. + """ + if isinstance(x, types): + return x + elif isinstance(types, tuple): + if output_type(x) in types: + return x.new() + elif output_type(x) is types: + return x.new() + raise TypeError(f"{type(x).__name__!r} object is not of type {types}") + + class class_property: __slots__ = "classval", "member_property" diff --git a/graphblas/tests/test_resolving.py b/graphblas/tests/test_resolving.py index 9fd525c08..74fdcca26 100644 --- a/graphblas/tests/test_resolving.py +++ b/graphblas/tests/test_resolving.py @@ -3,6 +3,7 @@ from graphblas import binary, dtypes, replace, unary from graphblas.core.expr import Updater +from graphblas.core.utils import ensure_type from graphblas import Matrix, Scalar, Vector # isort:skip (for dask-graphblas) @@ -246,3 +247,20 @@ def test_py_indices(): # All together now! idx = v[0 : v.size : 1].resolved_indexes.py_indices assert idx == slice(None) + + +def test_ensure_type(): + v = Vector.from_values([0, 1, 2], [1, 2, 3]) + A = Matrix.from_values([0, 1, 2], [2, 0, 1], [0, 2, 3]) + assert ensure_type(v, Vector) is v + assert ensure_type(A, Matrix) is A + assert ensure_type(v, (Matrix, Vector)) is v + assert ensure_type(A, (Matrix, Vector)) is A + with pytest.raises(TypeError): + ensure_type(A, Vector) + assert ensure_type(v + 1, Vector).isequal((v + 1).new()) + assert ensure_type(v + 1, (Vector,)).isequal((v + 1).new()) + assert ensure_type(A.mxm(A), Matrix).isequal(A.mxm(A).new()) + with pytest.raises(TypeError): + ensure_type(A.mxm(A), (Vector,)) + ensure_type(4, int) # why not