From 53d1971a69aac72efb4b43c8c620c81632dcf501 Mon Sep 17 00:00:00 2001
From: Chaoming Wang
Date: Sun, 15 Dec 2024 15:04:13 +0800
Subject: [PATCH] fix pickling of Quantity

---
 README.md                |  1 +
 brainunit/_base.py       | 20 ++++----
 brainunit/_base_test.py  | 21 +++++++++
 brainunit/sparse/_coo.py | 78 ++++++++++++++++++++++++++------
 brainunit/sparse/_csr.py | 99 +++++++++++++++++++++++++++++++---------
 5 files changed, 175 insertions(+), 44 deletions(-)

diff --git a/README.md b/README.md
index b1ee3cf..06fb05a 100644
--- a/README.md
+++ b/README.md
@@ -16,6 +16,7 @@
     PyPI version
     Continuous Integration
+    PyPI Downloads
diff --git a/brainunit/_base.py b/brainunit/_base.py
index 140ddab..ea4ba90 100644
--- a/brainunit/_base.py
+++ b/brainunit/_base.py
@@ -3296,16 +3296,16 @@ def __round__(self, ndigits: int = None) -> 'Quantity':
         # self = self.factorless()
         return Quantity(self.mantissa.__round__(ndigits), unit=self.unit)

-    def __reduce__(self):
-        """
-        Method used by Pickle object serialization.
-
-        Returns
-        -------
-        tuple
-            The tuple of the class and the arguments required to reconstruct the object.
-        """
-        return array_with_unit, (self.mantissa, self.unit, None)
+    # def __reduce__(self):
+    #     """
+    #     Method used by Pickle object serialization.
+    #
+    #     Returns
+    #     -------
+    #     tuple
+    #         The tuple of the class and the arguments required to reconstruct the object.
+    #     """
+    #     return array_with_unit, (self.mantissa, self.unit, None)

     # ----------------------- #
     #      NumPy methods      #
diff --git a/brainunit/_base_test.py b/brainunit/_base_test.py
index a569c3e..9864244 100644
--- a/brainunit/_base_test.py
+++ b/brainunit/_base_test.py
@@ -15,6 +15,7 @@

 import os
+import tempfile

 os.environ['JAX_TRACEBACK_FILTERING'] = 'off'

 import itertools
@@ -47,6 +48,7 @@
 )
 from brainunit._unit_common import *
 from brainunit._unit_shortcuts import kHz, ms, mV, nS
+import pickle


 class TestDimension(unittest.TestCase):
@@ -1452,6 +1454,25 @@ def d_function2(true_result):

     d_function2(1)


+def test_pickle():
+    tmpdir = tempfile.gettempdir()
+    filename = os.path.join(tmpdir, "test.pkl")
+    a = 3 * mV
+    with open(filename, "wb") as f:
+        # pickle.dump(a, f)
+        # pickle.dump(u.mV, f)
+        pickle.dump(a, f)
+
+    with open(filename, "rb") as f:
+        b = pickle.load(f)
+        print(b)
+
+
+
+
+
+
+
 def test_str_repr():
     """
diff --git a/brainunit/sparse/_coo.py b/brainunit/sparse/_coo.py
index 88d912c..05a1fe6 100644
--- a/brainunit/sparse/_coo.py
+++ b/brainunit/sparse/_coo.py
@@ -106,9 +106,14 @@ def _sort_indices(self) -> COO:
         # TODO(jakevdp): would be benefit from lowering this to cusparse sort_rows utility?
         if self._rows_sorted:
             return self
-        row, col, data = lax.sort((self.row, self.col, self.data), num_keys=2)
+        data, unit = split_mantissa_unit(self.data)
+        row, col, data = lax.sort((self.row, self.col, data), num_keys=2)
         return self.__class__(
-            (data, row, col),
+            (
+                maybe_decimal(Quantity(data, unit=unit)),
+                row,
+                col
+            ),
             shape=self.shape,
             rows_sorted=True
         )
@@ -159,7 +164,12 @@ def _eye(
         k = _const(idx, k)
         row = lax.sub(idx, lax.cond(k >= 0, lambda: zero, lambda: k))
         col = lax.add(idx, lax.cond(k <= 0, lambda: zero, lambda: k))
-        return cls((data, row, col), shape=(N, M), rows_sorted=True, cols_sorted=True)
+        return cls(
+            (data, row, col),
+            shape=(N, M),
+            rows_sorted=True,
+            cols_sorted=True
+        )

     def with_data(self, data: jax.Array | Quantity) -> COO:
         assert data.shape == self.data.shape
@@ -173,8 +183,12 @@ def todense(self) -> jax.Array:
     def transpose(self, axes: Tuple[int, ...] | None = None) -> COO:
         if axes is not None:
             raise NotImplementedError("axes argument to transpose()")
-        return COO((self.data, self.col, self.row), shape=self.shape[::-1],
-                   rows_sorted=self._cols_sorted, cols_sorted=self._rows_sorted)
+        return COO(
+            (self.data, self.col, self.row),
+            shape=self.shape[::-1],
+            rows_sorted=self._cols_sorted,
+            cols_sorted=self._rows_sorted
+        )

     def tree_flatten(self) -> Tuple[
         Tuple[jax.Array | Quantity,], dict[str, Any]
     ]:
@@ -225,7 +239,11 @@ def _binary_op(self, other, op):
         if isinstance(other, COO):
             if id(self.row) == id(other.row) and id(self.col) == id(other.col):
                 return COO(
-                    (op(self.data, other.data), self.row, self.col),
+                    (
+                        op(self.data, other.data),
+                        self.row,
+                        self.col
+                    ),
                     shape=self.shape,
                     rows_sorted=self._rows_sorted,
                     cols_sorted=self._cols_sorted
@@ -236,7 +254,11 @@ def _binary_op(self, other, op):
         other = asarray(other)
         if other.size == 1:
             return COO(
-                (op(self.data, other), self.row, self.col),
+                (
+                    op(self.data, other),
+                    self.row,
+                    self.col
+                ),
                 shape=self.shape,
                 rows_sorted=self._rows_sorted,
                 cols_sorted=self._cols_sorted
@@ -244,7 +266,11 @@ def _binary_op(self, other, op):
         elif other.ndim == 2 and other.shape == self.shape:
             other = other[self.row, self.col]
             return COO(
-                (op(self.data, other), self.row, self.col),
+                (
+                    op(self.data, other),
+                    self.row,
+                    self.col
+                ),
                 shape=self.shape,
                 rows_sorted=self._rows_sorted,
                 cols_sorted=self._cols_sorted
@@ -256,7 +282,11 @@ def _binary_rop(self, other, op):
         if isinstance(other, COO):
             if id(self.row) == id(other.row) and id(self.col) == id(other.col):
                 return COO(
-                    (op(other.data, self.data), self.row, self.col),
+                    (
+                        op(other.data, self.data),
+                        self.row,
+                        self.col
+                    ),
                     shape=self.shape,
                     rows_sorted=self._rows_sorted,
                     cols_sorted=self._cols_sorted
@@ -267,7 +297,11 @@ def _binary_rop(self, other, op):
         other = asarray(other)
         if other.size == 1:
             return COO(
-                (op(other, self.data), self.row, self.col),
+                (
+                    op(other, self.data),
+                    self.row,
+                    self.col
+                ),
                 shape=self.shape,
                 rows_sorted=self._rows_sorted,
                 cols_sorted=self._cols_sorted
@@ -275,7 +309,11 @@ def _binary_rop(self, other, op):
         elif other.ndim == 2 and other.shape == self.shape:
             other = other[self.row, self.col]
             return COO(
-                (op(other, self.data), self.row, self.col),
+                (
+                    op(other, self.data),
+                    self.row,
+                    self.col
+                ),
                 shape=self.shape,
                 rows_sorted=self._rows_sorted,
                 cols_sorted=self._cols_sorted
@@ -326,7 +364,14 @@ def __matmul__(
             raise NotImplementedError("matmul between two sparse objects.")
         other = asarray(other)
         data, other = promote_dtypes(self.data, other)
-        self_promoted = COO((data, self.row, self.col), **self._info._asdict())
+        self_promoted = COO(
+            (
+                data,
+                self.row,
+                self.col
+            ),
+            **self._info._asdict()
+        )
         if other.ndim == 1:
             return coo_matvec(self_promoted, other)
         elif other.ndim == 2:
@@ -342,7 +387,14 @@ def __rmatmul__(
             raise NotImplementedError("matmul between two sparse objects.")
         other = asarray(other)
         data, other = promote_dtypes(self.data, other)
-        self_promoted = COO((data, self.row, self.col), **self._info._asdict())
+        self_promoted = COO(
+            (
+                data,
+                self.row,
+                self.col
+            ),
+            **self._info._asdict()
+        )
         if other.ndim == 1:
             return coo_matvec(self_promoted, other, transpose=True)
         elif other.ndim == 2:
diff --git a/brainunit/sparse/_csr.py b/brainunit/sparse/_csr.py
index 612ba1b..94316e1 100644
--- a/brainunit/sparse/_csr.py
+++ b/brainunit/sparse/_csr.py
@@ -123,7 +123,9 @@ def _binary_op(self, other, op):
         if isinstance(other, CSR):
             if id(other.indices) == id(self.indices) and id(other.indptr) == id(self.indptr):
                 return CSR(
-                    (op(self.data, other.data), self.indices, self.indptr),
+                    (op(self.data, other.data),
+                     self.indices,
+                     self.indptr),
                     shape=self.shape
                 )
         if isinstance(other, JAXSparse):
@@ -139,7 +141,9 @@ def _binary_op(self, other, op):
             rows, cols = _csr_to_coo(self.indices, self.indptr)
             other = other[rows, cols]
             return CSR(
-                (op(self.data, other), self.indices, self.indptr),
+                (op(self.data, other),
+                 self.indices,
+                 self.indptr),
                 shape=self.shape
             )
         else:
@@ -149,7 +153,9 @@ def _binary_rop(self, other, op):
         if isinstance(other, CSR):
             if id(other.indices) == id(self.indices) and id(other.indptr) == id(self.indptr):
                 return CSR(
-                    (op(other.data, self.data), self.indices, self.indptr),
+                    (op(other.data, self.data),
+                     self.indices,
+                     self.indptr),
                     shape=self.shape
                 )
         if isinstance(other, JAXSparse):
@@ -158,14 +164,18 @@ def _binary_rop(self, other, op):
         other = asarray(other)
         if other.size == 1:
             return CSR(
-                (op(other, self.data), self.indices, self.indptr),
+                (op(other, self.data),
+                 self.indices,
+                 self.indptr),
                 shape=self.shape
             )
         elif other.ndim == 2 and other.shape == self.shape:
             rows, cols = _csr_to_coo(self.indices, self.indptr)
             other = other[rows, cols]
             return CSR(
-                (op(other, self.data), self.indices, self.indptr),
+                (op(other, self.data),
+                 self.indices,
+                 self.indptr),
                 shape=self.shape
             )
         else:
@@ -213,9 +223,21 @@ def __matmul__(self, other):
         other = asarray(other)
         data, other = promote_dtypes(self.data, other)
         if other.ndim == 1:
-            return _csr_matvec(data, self.indices, self.indptr, other, shape=self.shape)
+            return _csr_matvec(
+                data,
+                self.indices,
+                self.indptr,
+                other,
+                shape=self.shape
+            )
         elif other.ndim == 2:
-            return _csr_matmat(data, self.indices, self.indptr, other, shape=self.shape)
+            return _csr_matmat(
+                data,
+                self.indices,
+                self.indptr,
+                other,
+                shape=self.shape
+            )
         else:
             raise NotImplementedError(f"matmul with object of shape {other.shape}")

@@ -225,10 +247,24 @@ def __rmatmul__(self, other):
         other = asarray(other)
         data, other = promote_dtypes(self.data, other)
         if other.ndim == 1:
-            return _csr_matvec(data, self.indices, self.indptr, other, shape=self.shape, transpose=True)
+            return _csr_matvec(
+                data,
+                self.indices,
+                self.indptr,
+                other,
+                shape=self.shape,
+                transpose=True
+            )
         elif other.ndim == 2:
             other = other.T
-            r = _csr_matmat(data, self.indices, self.indptr, other, shape=self.shape, transpose=True)
+            r = _csr_matmat(
+                data,
+                self.indices,
+                self.indptr,
+                other,
+                shape=self.shape,
+                transpose=True
+            )
             return r.T
         else:
             raise NotImplementedError(f"matmul with object of shape {other.shape}")
@@ -258,8 +294,6 @@ class CSC(SparseMatrix):
     nse = property(lambda self: self.data.size)
     dtype = property(lambda self: self.data.dtype)

-    __array_priority__ = 2000
-
     def __init__(self, args, *, shape):
         self.data, self.indices, self.indptr = map(asarray, args)
         super().__init__(args, shape=shape)
@@ -311,7 +345,9 @@ def _binary_op(self, other, op):
         if isinstance(other, CSC):
             if id(other.indices) == id(self.indices) and id(other.indptr) == id(self.indptr):
                 return CSC(
-                    (op(self.data, other.data), self.indices, self.indptr),
+                    (op(self.data, other.data),
+                     self.indices,
+                     self.indptr),
                     shape=self.shape
                 )
         if isinstance(other, JAXSparse):
@@ -320,14 +356,18 @@ def _binary_op(self, other, op):
         other = asarray(other)
         if other.size == 1:
             return CSC(
-                (op(self.data, other), self.indices, self.indptr),
+                (op(self.data, other),
+                 self.indices,
+                 self.indptr),
                 shape=self.shape
             )
         elif other.ndim == 2 and other.shape == self.shape:
             cols, rows = _csr_to_coo(self.indices, self.indptr)
             other = other[rows, cols]
             return CSC(
-                (op(self.data, other), self.indices, self.indptr),
+                (op(self.data, other),
+                 self.indices,
+                 self.indptr),
                 shape=self.shape
             )
         else:
@@ -337,7 +377,9 @@ def _binary_rop(self, other, op):
         if isinstance(other, CSC):
             if id(other.indices) == id(self.indices) and id(other.indptr) == id(self.indptr):
                 return CSC(
-                    (op(other.data, self.data), self.indices, self.indptr),
+                    (op(other.data, self.data),
+                     self.indices,
+                     self.indptr),
                     shape=self.shape
                 )
         if isinstance(other, JAXSparse):
@@ -346,14 +388,18 @@ def _binary_rop(self, other, op):
         other = asarray(other)
         if other.size == 1:
             return CSC(
-                (op(other, self.data), self.indices, self.indptr),
+                (op(other, self.data),
+                 self.indices,
+                 self.indptr),
                 shape=self.shape
            )
         elif other.ndim == 2 and other.shape == self.shape:
             cols, rows = _csr_to_coo(self.indices, self.indptr)
             other = other[rows, cols]
             return CSC(
-                (op(other, self.data), self.indices, self.indptr),
+                (op(other, self.data),
+                 self.indices,
+                 self.indptr),
                 shape=self.shape
             )
         else:
@@ -427,12 +473,23 @@ def __rmatmul__(self, other):
         other = asarray(other)
         data, other = promote_dtypes(self.data, other)
         if other.ndim == 1:
-            return _csr_matvec(data, self.indices, self.indptr, other,
-                               shape=self.shape[::-1], transpose=False)
+            return _csr_matvec(
+                data,
+                self.indices,
+                self.indptr,
+                other,
+                shape=self.shape[::-1],
+                transpose=False
+            )
         elif other.ndim == 2:
             other = other.T
-            r = _csr_matmat(data, self.indices, self.indptr, other,
-                            shape=self.shape[::-1], transpose=False)
+            r = _csr_matmat(
+                data,
+                self.indices,
+                self.indptr, other,
+                shape=self.shape[::-1],
+                transpose=False
+            )
             return r.T
         else:
             raise NotImplementedError(f"matmul with object of shape {other.shape}")
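
Note: with the custom __reduce__ commented out, Quantity appears to fall back to Python's default pickling protocol, and the new test_pickle only writes one value to a temporary file and prints what comes back. Below is a minimal in-memory round-trip sketch, illustrative only and not part of the commit; it assumes `import brainunit as u` exposes mV, that `.unit` and `.mantissa` (both attribute names appear in the patch) are the right things to compare, and that the mantissa here is a scalar.

    import pickle

    import brainunit as u

    # Round-trip a scalar Quantity entirely in memory (no tempfile needed).
    a = 3 * u.mV
    b = pickle.loads(pickle.dumps(a))

    # Compare unit and mantissa separately instead of relying on `a == b`,
    # which for array-valued quantities would be elementwise.
    assert b.unit == a.unit
    assert bool(b.mantissa == a.mantissa)

Using pickle.dumps/pickle.loads keeps the check independent of the filesystem, which is usually preferable for a serialization unit test.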