diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index bc599f7eb..d429ee6d5 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -27,6 +27,10 @@ repos: hooks: - id: pyupgrade args: [--py38-plus] + - repo: https://github.com/MarcoGorelli/auto-walrus + rev: v0.2.1 + hooks: + - id: auto-walrus - repo: https://github.com/psf/black rev: 22.6.0 hooks: diff --git a/graphblas/base.py b/graphblas/base.py index 1ebf820ad..5bcf434ec 100644 --- a/graphblas/base.py +++ b/graphblas/base.py @@ -17,8 +17,7 @@ def record_raw(text): - rec = _recorder.get(_prev_recorder) - if rec is not None: + if (rec := _recorder.get(_prev_recorder)) is not None: rec.record_raw(text) @@ -160,8 +159,7 @@ def _expect_op_message( def _expect_op(self, op, values, **kwargs): - message = _expect_op_message(self, op, values, **kwargs) - if message is not None: + if (message := _expect_op_message(self, op, values, **kwargs)) is not None: raise TypeError(message) from None diff --git a/graphblas/io.py b/graphblas/io.py index 077a959c7..ed5b3a505 100644 --- a/graphblas/io.py +++ b/graphblas/io.py @@ -210,8 +210,7 @@ def from_awkward(A, *, name=None): Vector or Matrix """ params = A.layout.parameters - missing = {"format", "shape"} - params.keys() - if missing: + if missing := {"format", "shape"} - params.keys(): raise ValueError(f"Missing parameters: {missing}") format = params["format"] shape = params["shape"] diff --git a/graphblas/operator.py b/graphblas/operator.py index 2e28b4262..ad44ee9fa 100644 --- a/graphblas/operator.py +++ b/graphblas/operator.py @@ -527,8 +527,7 @@ def __reduce__(self): def _deserialize(name, func, anonymous): if anonymous: return UnaryOp.register_anonymous(func, name, parameterized=True) - rv = UnaryOp._find(name) - if rv is not None: + if (rv := UnaryOp._find(name)) is not None: return rv return UnaryOp.register_new(name, func, parameterized=True) @@ -559,8 +558,7 @@ def __reduce__(self): def _deserialize(name, func, anonymous): if anonymous: return IndexUnaryOp.register_anonymous(func, name, parameterized=True) - rv = IndexUnaryOp._find(name) - if rv is not None: + if (rv := IndexUnaryOp._find(name)) is not None: return rv return IndexUnaryOp.register_new(name, func, parameterized=True) @@ -591,8 +589,7 @@ def __reduce__(self): def _deserialize(name, func, anonymous): if anonymous: return SelectOp.register_anonymous(func, name, parameterized=True) - rv = SelectOp._find(name) - if rv is not None: + if (rv := SelectOp._find(name)) is not None: return rv return SelectOp.register_new(name, func, parameterized=True) @@ -658,8 +655,7 @@ def __reduce__(self): def _deserialize(name, func, anonymous): if anonymous: return BinaryOp.register_anonymous(func, name, parameterized=True) - rv = BinaryOp._find(name) - if rv is not None: + if (rv := BinaryOp._find(name)) is not None: return rv return BinaryOp.register_new(name, func, parameterized=True) @@ -711,8 +707,7 @@ def __reduce__(self): def _deserialize(name, binaryop, identity, anonymous): if anonymous: return Monoid.register_anonymous(binaryop, identity, name) - rv = Monoid._find(name) - if rv is not None: + if (rv := Monoid._find(name)) is not None: return rv return Monoid.register_new(name, binaryop, identity) @@ -768,8 +763,7 @@ def __reduce__(self): def _deserialize(name, monoid, binaryop, anonymous): if anonymous: return Semiring.register_anonymous(monoid, binaryop, name) - rv = Semiring._find(name) - if rv is not None: + if (rv := Semiring._find(name)) is not None: return rv return Semiring.register_new(name, monoid, binaryop) @@ -952,8 +946,7 @@ def _initialize(cls, include_in_ops=True): @classmethod def _deserialize(cls, name, *args): - rv = cls._find(name) - if rv is not None: + if (rv := cls._find(name)) is not None: return rv # Should we verify this is what the user expects? return cls.register_new(name, *args) @@ -1239,8 +1232,7 @@ def __reduce__(self): if hasattr(self.orig_func, "_parameterized_info"): return (_deserialize_parameterized, self.orig_func._parameterized_info) return (self.register_anonymous, (self.orig_func, self.name)) - name = f"unary.{self.name}" - if name in _STANDARD_OPERATOR_NAMES: + if (name := f"unary.{self.name}") in _STANDARD_OPERATOR_NAMES: return name return (self._deserialize, (self.name, self.orig_func)) @@ -1522,8 +1514,7 @@ def __reduce__(self): if hasattr(self.orig_func, "_parameterized_info"): return (_deserialize_parameterized, self.orig_func._parameterized_info) return (self.register_anonymous, (self.orig_func, self.name)) - name = f"indexunary.{self.name}" - if name in _STANDARD_OPERATOR_NAMES: + if (name := f"indexunary.{self.name}") in _STANDARD_OPERATOR_NAMES: return name return (self._deserialize, (self.name, self.orig_func)) @@ -1646,8 +1637,7 @@ def __reduce__(self): if hasattr(self.orig_func, "_parameterized_info"): return (_deserialize_parameterized, self.orig_func._parameterized_info) return (self.register_anonymous, (self.orig_func, self.name)) - name = f"select.{self.name}" - if name in _STANDARD_OPERATOR_NAMES: + if (name := f"select.{self.name}") in _STANDARD_OPERATOR_NAMES: return name return (self._deserialize, (self.name, self.orig_func)) @@ -2357,8 +2347,7 @@ def __reduce__(self): if hasattr(self.orig_func, "_parameterized_info"): return (_deserialize_parameterized, self.orig_func._parameterized_info) return (self.register_anonymous, (self.orig_func, self.name)) - name = f"binary.{self.name}" - if name in _STANDARD_OPERATOR_NAMES: + if (name := f"binary.{self.name}") in _STANDARD_OPERATOR_NAMES: return name return (self._deserialize, (self.name, self.orig_func)) @@ -2543,8 +2532,7 @@ def __init__(self, name, binaryop=None, identity=None, *, anonymous=False): def __reduce__(self): if self._anonymous: return (self.register_anonymous, (self._binaryop, self._identity, self.name)) - name = f"monoid.{self.name}" - if name in _STANDARD_OPERATOR_NAMES: + if (name := f"monoid.{self.name}") in _STANDARD_OPERATOR_NAMES: return name return (self._deserialize, (self.name, self._binaryop, self._identity)) @@ -2963,8 +2951,7 @@ def __init__(self, name, monoid=None, binaryop=None, *, anonymous=False): def __reduce__(self): if self._anonymous: return (self.register_anonymous, (self._monoid, self._binaryop, self.name)) - name = f"semiring.{self.name}" - if name in _STANDARD_OPERATOR_NAMES: + if (name := f"semiring.{self.name}") in _STANDARD_OPERATOR_NAMES: return name return (self._deserialize, (self.name, self._monoid, self._binaryop)) diff --git a/versioneer.py b/versioneer.py index aff09869d..9c7cf8552 100644 --- a/versioneer.py +++ b/versioneer.py @@ -1466,8 +1466,7 @@ def get_versions(verbose=False): except NotThisMethod: pass - from_vcs_f = handlers.get("pieces_from_vcs") - if from_vcs_f: + if from_vcs_f := handlers.get("pieces_from_vcs"): try: pieces = from_vcs_f(cfg.tag_prefix, root, verbose) ver = render(pieces, cfg.style)