Skip to content

Commit

Permalink
Fix infix expression _value and _expr usage (python-graphblas#418)
Browse files Browse the repository at this point in the history
Previously, `infixexpr._value` was sometimes `MatrixExpression` and sometimes `Matrix`. Scary!
  • Loading branch information
eriknw authored Mar 28, 2023
1 parent 0c1b102 commit 41fb71a
Show file tree
Hide file tree
Showing 8 changed files with 55 additions and 24 deletions.
6 changes: 3 additions & 3 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ repos:
- id: mixed-line-ending
- id: trailing-whitespace
- repo: https://github.com/abravalheri/validate-pyproject
rev: v0.12.1
rev: v0.12.2
hooks:
- id: validate-pyproject
name: Validate pyproject.toml
Expand All @@ -47,7 +47,7 @@ repos:
- id: black
- id: black-jupyter
- repo: https://github.com/charliermarsh/ruff-pre-commit
rev: v0.0.257
rev: v0.0.259
hooks:
- id: ruff
args: [--fix-only]
Expand Down Expand Up @@ -75,7 +75,7 @@ repos:
additional_dependencies: [tomli]
files: ^(graphblas|docs)/
- repo: https://github.com/charliermarsh/ruff-pre-commit
rev: v0.0.257
rev: v0.0.259
hooks:
- id: ruff
- repo: https://github.com/sphinx-contrib/sphinx-lint
Expand Down
1 change: 1 addition & 0 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ dependencies:
# - snakeviz
# - sphinx-lint
# - sympy
# - tuna
# - twine
# - vim
# - yesqa
Expand Down
31 changes: 20 additions & 11 deletions graphblas/core/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,33 +478,34 @@ def __bool__(self):


class InfixExprBase:
__slots__ = "left", "right", "_value", "__weakref__"
__slots__ = "left", "right", "_expr", "__weakref__"
_is_scalar = False

def __init__(self, left, right):
self.left = left
self.right = right
self._value = None
self._expr = None

def new(self, dtype=None, *, mask=None, name=None, **opts):
if (
mask is None
and self._value is not None
and (dtype is None or self._value.dtype == dtype)
and self._expr is not None
and self._expr._value is not None
and (dtype is None or self._expr._value.dtype == dtype)
):
rv = self._value
rv = self._expr._value
if name is not None:
rv.name = name
self._value = None
self._expr._value = None
return rv
expr = self._to_expr()
return expr.new(dtype, mask=mask, name=name, **opts)

def _to_expr(self):
if self._value is None:
if self._expr is None:
# Rely on the default operator for `x @ y`
self._value = getattr(self.left, self.method_name)(self.right)
return self._value
self._expr = getattr(self.left, self.method_name)(self.right)
return self._expr

def _get_value(self, attr=None, default=None):
expr = self._to_expr()
Expand Down Expand Up @@ -536,10 +537,18 @@ def __repr__(self):

@property
def dtype(self):
if self._value is not None:
return self._value.dtype
return self._to_expr().dtype

@property
def _value(self):
if self._expr is None:
return None
return self._expr._value

@_value.setter
def _value(self, val):
self._to_expr()._value = val


# Mistakes
utils._output_types[AmbiguousAssignOrExtract] = AmbiguousAssignOrExtract
Expand Down
1 change: 1 addition & 0 deletions graphblas/core/formatting.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# This file imports pandas, so it should only be imported when formatting
import numpy as np

from .. import backend, config, monoid, unary
Expand Down
16 changes: 8 additions & 8 deletions graphblas/core/infix.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@


def _ewise_add_to_expr(self):
if self._value is not None:
return self._value
if self._expr is not None:
return self._expr
if self.left.dtype == BOOL and self.right.dtype == BOOL:
self._value = self.left.ewise_add(self.right, lor)
return self._value
self._expr = self.left.ewise_add(self.right, lor)
return self._expr
raise TypeError(
"Bad dtypes for `x | y`! Automatic computation of `x | y` infix expressions is only valid "
f"for BOOL dtypes. The argument dtypes are {self.left.dtype} and {self.right.dtype}.\n\n"
Expand All @@ -30,11 +30,11 @@ def _ewise_add_to_expr(self):


def _ewise_mult_to_expr(self):
if self._value is not None:
return self._value
if self._expr is not None:
return self._expr
if self.left.dtype == BOOL and self.right.dtype == BOOL:
self._value = self.left.ewise_mult(self.right, land)
return self._value
self._expr = self.left.ewise_mult(self.right, land)
return self._expr
raise TypeError(
"Bad dtypes for `x & y`! Automatic computation of `x & y` infix expressions is only valid "
f"for BOOL dtypes. The argument dtypes are {self.left.dtype} and {self.right.dtype}.\n\n"
Expand Down
3 changes: 2 additions & 1 deletion graphblas/core/recorder.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from ..dtypes import DataType
from . import base, lib
from .base import _recorder
from .formatting import CSS_STYLE
from .mask import Mask
from .matrix import TransposedMatrix
from .operator import TypedOpBase
Expand Down Expand Up @@ -103,6 +102,8 @@ def is_recording(self):
return self._token is not None and _recorder.get(base._prev_recorder) is self

def _repr_base_(self):
from .formatting import CSS_STYLE

status = (
'<div style="'
"height: 12px; "
Expand Down
18 changes: 18 additions & 0 deletions graphblas/tests/test_infix.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,3 +342,21 @@ def test_inplace_infix(s1, v1, v2, A1, A2):
expr @= A
with pytest.raises(TypeError, match="not supported"):
s1 @= v1


@autocompute
def test_infix_expr_value_types():
"""Test bug where `infix_expr._value` was used as MatrixExpression or Matrix"""
from graphblas.core.matrix import MatrixExpression

A = Matrix(int, 3, 3)
A << 1
expr = A @ A.T
assert expr._expr is None
assert expr._value is None
assert type(expr._get_value()) is Matrix
assert type(expr._expr) is MatrixExpression
assert type(expr.new()) is Matrix
assert expr._expr is not None
assert expr._value is None
assert type(expr.new()) is Matrix
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,8 @@ ignore = [
"graphblas/core/ss/matrix.py" = ["NPY002"] # numba doesn't support rng generator yet
"graphblas/core/ss/vector.py" = ["NPY002"] # numba doesn't support rng generator yet
"graphblas/ss/_core.py" = ["N999"] # We want _core.py to be underscopre
"graphblas/tests/*py" = ["S101", "T201", "D103", "D100", "SIM300"] # Allow assert, print, no docstring, and yoda
# Allow assert, pickle, RNG, print, no docstring, and yoda in tests
"graphblas/tests/*py" = ["S101", "S301", "S311", "T201", "D103", "D100", "SIM300"]
"graphblas/tests/test_formatting.py" = ["E501"] # Allow long lines
"graphblas/**/__init__.py" = ["F401"] # Allow unused imports (w/o defining `__all__`)
"scripts/*.py" = ["INP001"] # Not a package
Expand Down

0 comments on commit 41fb71a

Please sign in to comment.