Skip to content

Commit

Permalink
auto-walrus (python-graphblas#287)
Browse files Browse the repository at this point in the history
* auto-walrus

* update

* auto-walrus v0.1.9

* and now auto-walrus v0.2.1
  • Loading branch information
eriknw authored Oct 19, 2022
1 parent d94aec4 commit 5fc34d7
Show file tree
Hide file tree
Showing 5 changed files with 21 additions and 34 deletions.
4 changes: 4 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@ repos:
hooks:
- id: pyupgrade
args: [--py38-plus]
- repo: https://github.com/MarcoGorelli/auto-walrus
rev: v0.2.1
hooks:
- id: auto-walrus
- repo: https://github.com/psf/black
rev: 22.6.0
hooks:
Expand Down
6 changes: 2 additions & 4 deletions graphblas/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,7 @@


def record_raw(text):
rec = _recorder.get(_prev_recorder)
if rec is not None:
if (rec := _recorder.get(_prev_recorder)) is not None:
rec.record_raw(text)


Expand Down Expand Up @@ -160,8 +159,7 @@ def _expect_op_message(


def _expect_op(self, op, values, **kwargs):
message = _expect_op_message(self, op, values, **kwargs)
if message is not None:
if (message := _expect_op_message(self, op, values, **kwargs)) is not None:
raise TypeError(message) from None


Expand Down
3 changes: 1 addition & 2 deletions graphblas/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,8 +210,7 @@ def from_awkward(A, *, name=None):
Vector or Matrix
"""
params = A.layout.parameters
missing = {"format", "shape"} - params.keys()
if missing:
if missing := {"format", "shape"} - params.keys():
raise ValueError(f"Missing parameters: {missing}")
format = params["format"]
shape = params["shape"]
Expand Down
39 changes: 13 additions & 26 deletions graphblas/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -527,8 +527,7 @@ def __reduce__(self):
def _deserialize(name, func, anonymous):
if anonymous:
return UnaryOp.register_anonymous(func, name, parameterized=True)
rv = UnaryOp._find(name)
if rv is not None:
if (rv := UnaryOp._find(name)) is not None:
return rv
return UnaryOp.register_new(name, func, parameterized=True)

Expand Down Expand Up @@ -559,8 +558,7 @@ def __reduce__(self):
def _deserialize(name, func, anonymous):
if anonymous:
return IndexUnaryOp.register_anonymous(func, name, parameterized=True)
rv = IndexUnaryOp._find(name)
if rv is not None:
if (rv := IndexUnaryOp._find(name)) is not None:
return rv
return IndexUnaryOp.register_new(name, func, parameterized=True)

Expand Down Expand Up @@ -591,8 +589,7 @@ def __reduce__(self):
def _deserialize(name, func, anonymous):
if anonymous:
return SelectOp.register_anonymous(func, name, parameterized=True)
rv = SelectOp._find(name)
if rv is not None:
if (rv := SelectOp._find(name)) is not None:
return rv
return SelectOp.register_new(name, func, parameterized=True)

Expand Down Expand Up @@ -658,8 +655,7 @@ def __reduce__(self):
def _deserialize(name, func, anonymous):
if anonymous:
return BinaryOp.register_anonymous(func, name, parameterized=True)
rv = BinaryOp._find(name)
if rv is not None:
if (rv := BinaryOp._find(name)) is not None:
return rv
return BinaryOp.register_new(name, func, parameterized=True)

Expand Down Expand Up @@ -711,8 +707,7 @@ def __reduce__(self):
def _deserialize(name, binaryop, identity, anonymous):
if anonymous:
return Monoid.register_anonymous(binaryop, identity, name)
rv = Monoid._find(name)
if rv is not None:
if (rv := Monoid._find(name)) is not None:
return rv
return Monoid.register_new(name, binaryop, identity)

Expand Down Expand Up @@ -768,8 +763,7 @@ def __reduce__(self):
def _deserialize(name, monoid, binaryop, anonymous):
if anonymous:
return Semiring.register_anonymous(monoid, binaryop, name)
rv = Semiring._find(name)
if rv is not None:
if (rv := Semiring._find(name)) is not None:
return rv
return Semiring.register_new(name, monoid, binaryop)

Expand Down Expand Up @@ -952,8 +946,7 @@ def _initialize(cls, include_in_ops=True):

@classmethod
def _deserialize(cls, name, *args):
rv = cls._find(name)
if rv is not None:
if (rv := cls._find(name)) is not None:
return rv # Should we verify this is what the user expects?
return cls.register_new(name, *args)

Expand Down Expand Up @@ -1239,8 +1232,7 @@ def __reduce__(self):
if hasattr(self.orig_func, "_parameterized_info"):
return (_deserialize_parameterized, self.orig_func._parameterized_info)
return (self.register_anonymous, (self.orig_func, self.name))
name = f"unary.{self.name}"
if name in _STANDARD_OPERATOR_NAMES:
if (name := f"unary.{self.name}") in _STANDARD_OPERATOR_NAMES:
return name
return (self._deserialize, (self.name, self.orig_func))

Expand Down Expand Up @@ -1522,8 +1514,7 @@ def __reduce__(self):
if hasattr(self.orig_func, "_parameterized_info"):
return (_deserialize_parameterized, self.orig_func._parameterized_info)
return (self.register_anonymous, (self.orig_func, self.name))
name = f"indexunary.{self.name}"
if name in _STANDARD_OPERATOR_NAMES:
if (name := f"indexunary.{self.name}") in _STANDARD_OPERATOR_NAMES:
return name
return (self._deserialize, (self.name, self.orig_func))

Expand Down Expand Up @@ -1646,8 +1637,7 @@ def __reduce__(self):
if hasattr(self.orig_func, "_parameterized_info"):
return (_deserialize_parameterized, self.orig_func._parameterized_info)
return (self.register_anonymous, (self.orig_func, self.name))
name = f"select.{self.name}"
if name in _STANDARD_OPERATOR_NAMES:
if (name := f"select.{self.name}") in _STANDARD_OPERATOR_NAMES:
return name
return (self._deserialize, (self.name, self.orig_func))

Expand Down Expand Up @@ -2357,8 +2347,7 @@ def __reduce__(self):
if hasattr(self.orig_func, "_parameterized_info"):
return (_deserialize_parameterized, self.orig_func._parameterized_info)
return (self.register_anonymous, (self.orig_func, self.name))
name = f"binary.{self.name}"
if name in _STANDARD_OPERATOR_NAMES:
if (name := f"binary.{self.name}") in _STANDARD_OPERATOR_NAMES:
return name
return (self._deserialize, (self.name, self.orig_func))

Expand Down Expand Up @@ -2543,8 +2532,7 @@ def __init__(self, name, binaryop=None, identity=None, *, anonymous=False):
def __reduce__(self):
if self._anonymous:
return (self.register_anonymous, (self._binaryop, self._identity, self.name))
name = f"monoid.{self.name}"
if name in _STANDARD_OPERATOR_NAMES:
if (name := f"monoid.{self.name}") in _STANDARD_OPERATOR_NAMES:
return name
return (self._deserialize, (self.name, self._binaryop, self._identity))

Expand Down Expand Up @@ -2963,8 +2951,7 @@ def __init__(self, name, monoid=None, binaryop=None, *, anonymous=False):
def __reduce__(self):
if self._anonymous:
return (self.register_anonymous, (self._monoid, self._binaryop, self.name))
name = f"semiring.{self.name}"
if name in _STANDARD_OPERATOR_NAMES:
if (name := f"semiring.{self.name}") in _STANDARD_OPERATOR_NAMES:
return name
return (self._deserialize, (self.name, self._monoid, self._binaryop))

Expand Down
3 changes: 1 addition & 2 deletions versioneer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1466,8 +1466,7 @@ def get_versions(verbose=False):
except NotThisMethod:
pass

from_vcs_f = handlers.get("pieces_from_vcs")
if from_vcs_f:
if from_vcs_f := handlers.get("pieces_from_vcs"):
try:
pieces = from_vcs_f(cfg.tag_prefix, root, verbose)
ver = render(pieces, cfg.style)
Expand Down

0 comments on commit 5fc34d7

Please sign in to comment.