fix pickling of Quantity
chaoming0625 committed Dec 15, 2024 · 1 parent 3485721 · commit 53d1971
Showing 5 changed files with 175 additions and 44 deletions.
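
In practical terms, the fix means a Quantity now survives a plain pickle round trip. A minimal usage sketch (it assumes the top-level brainunit package exposes mV, as the commented-out `u.mV` line in the new test suggests):

import pickle
import brainunit as u

q = 3 * u.mV                               # a Quantity: mantissa 3, unit mV
restored = pickle.loads(pickle.dumps(q))   # round trip through pickle
print(restored)                            # expected to show the same value and unit
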
1 change: 1 addition & 0 deletions README.md
@@ -16,6 +16,7 @@
</a>
<a href="https://badge.fury.io/py/brainunit"><img alt="PyPI version" src="https://badge.fury.io/py/brainunit.svg"></a>
<a href="https://github.com/chaobrain/brainunit/actions/workflows/CI.yml"><img alt="Continuous Integration" src="https://github.com/chaobrain/brainunit/actions/workflows/CI.yml/badge.svg"></a>
<a href="https://pepy.tech/projects/brainunit"><img src="https://static.pepy.tech/badge/brainunit" alt="PyPI Downloads"></a>
</p>


20 changes: 10 additions & 10 deletions brainunit/_base.py
@@ -3296,16 +3296,16 @@ def __round__(self, ndigits: int = None) -> 'Quantity':
        # self = self.factorless()
        return Quantity(self.mantissa.__round__(ndigits), unit=self.unit)

    def __reduce__(self):
        """
        Method used by Pickle object serialization.
        Returns
        -------
        tuple
            The tuple of the class and the arguments required to reconstruct the object.
        """
        return array_with_unit, (self.mantissa, self.unit, None)
    # def __reduce__(self):
    #     """
    #     Method used by Pickle object serialization.
    #
    #     Returns
    #     -------
    #     tuple
    #         The tuple of the class and the arguments required to reconstruct the object.
    #     """
    #     return array_with_unit, (self.mantissa, self.unit, None)

    # ----------------------- #
    # NumPy methods #
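
For context, the disabled method relied on Python's `__reduce__` protocol: pickle stores a reconstructor callable plus the arguments needed to rebuild the object. The sketch below illustrates that protocol on a hypothetical `QuantityLike` class; it shows the removed approach in general, not brainunit's actual implementation:

import pickle

class QuantityLike:
    # Hypothetical stand-in for Quantity, just to show the protocol.
    def __init__(self, mantissa, unit):
        self.mantissa = mantissa
        self.unit = unit

    def __reduce__(self):
        # Pickle will call QuantityLike(mantissa, unit) to rebuild the object,
        # mirroring the removed `return array_with_unit, (self.mantissa, self.unit, None)`.
        return QuantityLike, (self.mantissa, self.unit)

q = QuantityLike(3.0, "mV")
restored = pickle.loads(pickle.dumps(q))
assert (restored.mantissa, restored.unit) == (3.0, "mV")
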
21 changes: 21 additions & 0 deletions brainunit/_base_test.py
@@ -15,6 +15,7 @@


import os
import tempfile

os.environ['JAX_TRACEBACK_FILTERING'] = 'off'
import itertools
@@ -47,6 +48,7 @@
)
from brainunit._unit_common import *
from brainunit._unit_shortcuts import kHz, ms, mV, nS
import pickle


class TestDimension(unittest.TestCase):
@@ -1452,6 +1454,25 @@ def d_function2(true_result):
d_function2(1)


def test_pickle():
    tmpdir = tempfile.gettempdir()
    filename = os.path.join(tmpdir, "test.pkl")
    a = 3 * mV
    with open(filename, "wb") as f:
        # pickle.dump(a, f)
        # pickle.dump(u.mV, f)
        pickle.dump(a, f)

    with open(filename, "rb") as f:
        b = pickle.load(f)
        print(b)
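
The new test only prints the restored value. A slightly stronger check, sketched below as a hypothetical follow-up that is not part of this commit (it assumes comparing the `unit` and `mantissa` attributes seen in `_base.py` behaves as expected):

def test_pickle_roundtrip_values():
    a = 3 * mV
    b = pickle.loads(pickle.dumps(a))
    # The restored Quantity should carry the same unit and the same mantissa.
    assert b.unit == a.unit
    assert b.mantissa == a.mantissa
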








def test_str_repr():
"""
78 changes: 65 additions & 13 deletions brainunit/sparse/_coo.py
@@ -106,9 +106,14 @@ def _sort_indices(self) -> COO:
# TODO(jakevdp): would be benefit from lowering this to cusparse sort_rows utility?
if self._rows_sorted:
return self
row, col, data = lax.sort((self.row, self.col, self.data), num_keys=2)
data, unit = split_mantissa_unit(self.data)
row, col, data = lax.sort((self.row, self.col, data), num_keys=2)
return self.__class__(
(data, row, col),
(
maybe_decimal(Quantity(data, unit=unit)),
row,
col
),
shape=self.shape,
rows_sorted=True
)
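
The substantive change in this hunk is that sorting now strips the unit before calling `lax.sort`, which operates on raw arrays only, and reattaches the unit afterwards. A minimal sketch of that pattern, assuming `split_mantissa_unit`, `maybe_decimal`, and `Quantity` are reachable from the top-level `brainunit` namespace (inside `_coo.py` they are used directly):

import jax.numpy as jnp
from jax import lax
import brainunit as u

data = jnp.array([3.0, 1.0, 2.0]) * u.mV        # Quantity carrying a unit
mantissa, unit = u.split_mantissa_unit(data)     # raw array plus its unit
sorted_mantissa = lax.sort(mantissa)             # JAX sort sees only plain values
result = u.maybe_decimal(u.Quantity(sorted_mantissa, unit=unit))  # unit restored
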
@@ -159,7 +164,12 @@ def _eye(
k = _const(idx, k)
row = lax.sub(idx, lax.cond(k >= 0, lambda: zero, lambda: k))
col = lax.add(idx, lax.cond(k <= 0, lambda: zero, lambda: k))
return cls((data, row, col), shape=(N, M), rows_sorted=True, cols_sorted=True)
return cls(
(data, row, col),
shape=(N, M),
rows_sorted=True,
cols_sorted=True
)

def with_data(self, data: jax.Array | Quantity) -> COO:
assert data.shape == self.data.shape
@@ -173,8 +183,12 @@ def todense(self) -> jax.Array:
def transpose(self, axes: Tuple[int, ...] | None = None) -> COO:
if axes is not None:
raise NotImplementedError("axes argument to transpose()")
return COO((self.data, self.col, self.row), shape=self.shape[::-1],
rows_sorted=self._cols_sorted, cols_sorted=self._rows_sorted)
return COO(
(self.data, self.col, self.row),
shape=self.shape[::-1],
rows_sorted=self._cols_sorted,
cols_sorted=self._rows_sorted
)

def tree_flatten(self) -> Tuple[
Tuple[jax.Array | Quantity,], dict[str, Any]
@@ -225,7 +239,11 @@ def _binary_op(self, other, op):
if isinstance(other, COO):
if id(self.row) == id(other.row) and id(self.col) == id(other.col):
return COO(
(op(self.data, other.data), self.row, self.col),
(
op(self.data, other.data),
self.row,
self.col
),
shape=self.shape,
rows_sorted=self._rows_sorted,
cols_sorted=self._cols_sorted
@@ -236,15 +254,23 @@ def _binary_op(self, other, op):
other = asarray(other)
if other.size == 1:
return COO(
(op(self.data, other), self.row, self.col),
(
op(self.data, other),
self.row,
self.col
),
shape=self.shape,
rows_sorted=self._rows_sorted,
cols_sorted=self._cols_sorted
)
elif other.ndim == 2 and other.shape == self.shape:
other = other[self.row, self.col]
return COO(
(op(self.data, other), self.row, self.col),
(
op(self.data, other),
self.row,
self.col
),
shape=self.shape,
rows_sorted=self._rows_sorted,
cols_sorted=self._cols_sorted
@@ -256,7 +282,11 @@ def _binary_rop(self, other, op):
if isinstance(other, COO):
if id(self.row) == id(other.row) and id(self.col) == id(other.col):
return COO(
(op(other.data, self.data), self.row, self.col),
(
op(other.data, self.data),
self.row,
self.col
),
shape=self.shape,
rows_sorted=self._rows_sorted,
cols_sorted=self._cols_sorted
@@ -267,15 +297,23 @@ def _binary_rop(self, other, op):
other = asarray(other)
if other.size == 1:
return COO(
(op(other, self.data), self.row, self.col),
(
op(other, self.data),
self.row,
self.col
),
shape=self.shape,
rows_sorted=self._rows_sorted,
cols_sorted=self._cols_sorted
)
elif other.ndim == 2 and other.shape == self.shape:
other = other[self.row, self.col]
return COO(
(op(other, self.data), self.row, self.col),
(
op(other, self.data),
self.row,
self.col
),
shape=self.shape,
rows_sorted=self._rows_sorted,
cols_sorted=self._cols_sorted
@@ -326,7 +364,14 @@ def __matmul__(
raise NotImplementedError("matmul between two sparse objects.")
other = asarray(other)
data, other = promote_dtypes(self.data, other)
self_promoted = COO((data, self.row, self.col), **self._info._asdict())
self_promoted = COO(
(
data,
self.row,
self.col
),
**self._info._asdict()
)
if other.ndim == 1:
return coo_matvec(self_promoted, other)
elif other.ndim == 2:
@@ -342,7 +387,14 @@ def __rmatmul__(
raise NotImplementedError("matmul between two sparse objects.")
other = asarray(other)
data, other = promote_dtypes(self.data, other)
self_promoted = COO((data, self.row, self.col), **self._info._asdict())
self_promoted = COO(
(
data,
self.row,
self.col
),
**self._info._asdict()
)
if other.ndim == 1:
return coo_matvec(self_promoted, other, transpose=True)
elif other.ndim == 2:
(Diff for the fifth changed file not shown.)

