From edeed6c314c01a21b0da7faaf37f1e27a718c903 Mon Sep 17 00:00:00 2001
From: Chaoming Wang
Date: Sat, 14 Dec 2024 14:38:40 +0800
Subject: [PATCH] support unit-aware sparse computation: `CSR`, `CSC`, `COO` (#80)

* support unit-aware sparse computation

* add documentation

* update sparse doc

* fix typing

* remove deprecated `jnp.round_`
---
 brainunit/__init__.py                         |   3 +-
 brainunit/_base.py                            |  38 ++
 brainunit/_sparse_base.py                     |  97 +++
 brainunit/math/_fun_keep_unit.py              |  37 +-
 brainunit/math/_fun_remove_unit.py            |  24 +-
 brainunit/sparse/__init__.py                  |  25 +
 brainunit/sparse/coo.py                       | 516 +++++++++++++++
 brainunit/sparse/coo_test.py                  | 270 ++++++++
 brainunit/sparse/csr.py                       | 601 ++++++++++++++++++
 brainunit/sparse/csr_test.py                  | 549 ++++++++++++++++
 ...{constants.rst => brainunit.constants.rst} |   0
 docs/apis/brainunit.sparse.rst                |  33 +
 docs/index.rst                                |   3 +-
 13 files changed, 2184 insertions(+), 12 deletions(-)
 create mode 100644 brainunit/_sparse_base.py
 create mode 100644 brainunit/sparse/__init__.py
 create mode 100644 brainunit/sparse/coo.py
 create mode 100644 brainunit/sparse/coo_test.py
 create mode 100644 brainunit/sparse/csr.py
 create mode 100644 brainunit/sparse/csr_test.py
 rename docs/apis/{constants.rst => brainunit.constants.rst} (100%)
 create mode 100644 docs/apis/brainunit.sparse.rst

diff --git a/brainunit/__init__.py b/brainunit/__init__.py
index 729187d..81e48e6 100644
--- a/brainunit/__init__.py
+++ b/brainunit/__init__.py
@@ -22,6 +22,7 @@
 from . import lax
 from . import linalg
 from . import math
+from . import sparse
 from ._base import *
 from ._base import __all__ as _base_all
 from ._celsius import *
@@ -35,7 +36,7 @@
 from .constants import __all__ as _constants_all
 
 __all__ = (
-    ['math', 'linalg', 'autograd', 'fft', 'constants'] +
+    ['math', 'linalg', 'autograd', 'fft', 'constants', 'sparse'] +
     _common_all +
     _std_units_all +
     _constants_all +
diff --git a/brainunit/_base.py b/brainunit/_base.py
index 1f2bd4b..140ddab 100644
--- a/brainunit/_base.py
+++ b/brainunit/_base.py
@@ -29,6 +29,7 @@
 from jax.tree_util import register_pytree_node_class
 
 from ._misc import set_module_as
+from ._sparse_base import SparseMatrix
 
 __all__ = [
     # three base objects
@@ -50,6 +51,7 @@
     'get_mantissa',
     'get_magnitude',
     'display_in_unit',
+    'split_mantissa_unit',
     'maybe_decimal',
 
     # functions for checking
@@ -717,6 +719,26 @@ def get_mantissa(obj):
 get_magnitude = get_mantissa
 
 
+def split_mantissa_unit(obj):
+    """
+    Split a Quantity into its mantissa and unit.
+
+    Parameters
+    ----------
+    obj : `object`
+        The object to split.
+
+    Returns
+    -------
+    mantissa : `float` or `array_like`
+        The mantissa of the `obj`.
+    unit : Unit
+        The physical unit of the `obj`.
+    """
+    obj = _to_quantity(obj)
+    return obj.mantissa, obj.unit
+
+
 @set_module_as('brainunit')
 def have_same_dim(obj1, obj2) -> bool:
     """Test if two values have the same dimensions.
@@ -3033,6 +3055,8 @@ def _binary_operation( return r def __add__(self, oc): + if isinstance(oc, SparseMatrix): + return oc.__radd__(self) return self._binary_operation(oc, operator.add, fail_for_mismatch=True, operator_str="+") def __radd__(self, oc): @@ -3043,6 +3067,8 @@ def __iadd__(self, oc): return self._binary_operation(oc, operator.add, fail_for_mismatch=True, operator_str="+=", inplace=True) def __sub__(self, oc): + if isinstance(oc, SparseMatrix): + return oc.__rsub__(self) return self._binary_operation(oc, operator.sub, fail_for_mismatch=True, operator_str="-") def __rsub__(self, oc): @@ -3053,6 +3079,8 @@ def __isub__(self, oc): return self._binary_operation(oc, operator.sub, fail_for_mismatch=True, operator_str="-=", inplace=True) def __mul__(self, oc): + if isinstance(oc, SparseMatrix): + return oc.__rmul__(self) r = self._binary_operation(oc, operator.mul, operator.mul) return maybe_decimal(r) @@ -3065,6 +3093,8 @@ def __imul__(self, oc): def __div__(self, oc): # self / oc + if isinstance(oc, SparseMatrix): + return oc.__rdiv__(self) r = self._binary_operation(oc, operator.truediv, operator.truediv) return maybe_decimal(r) @@ -3073,6 +3103,8 @@ def __idiv__(self, oc): def __truediv__(self, oc): # self / oc + if isinstance(oc, SparseMatrix): + return oc.__rtruediv__(self) return self.__div__(oc) def __rdiv__(self, oc): @@ -3092,6 +3124,8 @@ def __itruediv__(self, oc): def __floordiv__(self, oc): # self // oc + if isinstance(oc, SparseMatrix): + return oc.__rfloordiv__(self) r = self._binary_operation(oc, operator.floordiv, operator.truediv) return maybe_decimal(r) @@ -3108,6 +3142,8 @@ def __ifloordiv__(self, oc): def __mod__(self, oc): # self % oc + if isinstance(oc, SparseMatrix): + return oc.__rmod__(self) r = self._binary_operation(oc, operator.mod, lambda ua, ub: ua, fail_for_mismatch=True, operator_str=r"%") return maybe_decimal(r) @@ -3127,6 +3163,8 @@ def __rdivmod__(self, oc): return self.__rfloordiv__(oc), self.__rmod__(oc) def __matmul__(self, oc): + if isinstance(oc, SparseMatrix): + return oc.__rmatmul__(self) r = self._binary_operation(oc, operator.matmul, operator.mul, operator_str="@") return maybe_decimal(r) diff --git a/brainunit/_sparse_base.py b/brainunit/_sparse_base.py new file mode 100644 index 0000000..91b9562 --- /dev/null +++ b/brainunit/_sparse_base.py @@ -0,0 +1,97 @@ +# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ==============================================================================

+from __future__ import annotations
+
+from typing import Sequence
+
+from jax.experimental.sparse import JAXSparse
+
+__all__ = [
+    "SparseMatrix"
+]
+
+
+class SparseMatrix(JAXSparse):
+
+    # Not abstract methods because not all sparse classes implement them
+
+    def with_data(self, data):
+        raise NotImplementedError(f"{self.__class__}.with_data")
+
+    def sum(self, axis: int | Sequence[int] = None):
+        if axis is not None:
+            raise NotImplementedError(f"{self.__class__}.sum with axis is not implemented.")
+        return self.data.sum()
+
+    def __abs__(self):
+        raise NotImplementedError(f"{self.__class__}.__abs__")
+
+    def __neg__(self):
+        raise NotImplementedError(f"{self.__class__}.__neg__")
+
+    def __pos__(self):
+        raise NotImplementedError(f"{self.__class__}.__pos__")
+
+    def __matmul__(self, other):
+        raise NotImplementedError(f"{self.__class__}.__matmul__")
+
+    def __rmatmul__(self, other):
+        raise NotImplementedError(f"{self.__class__}.__rmatmul__")
+
+    def __mul__(self, other):
+        raise NotImplementedError(f"{self.__class__}.__mul__")
+
+    def __rmul__(self, other):
+        raise NotImplementedError(f"{self.__class__}.__rmul__")
+
+    def __add__(self, other):
+        raise NotImplementedError(f"{self.__class__}.__add__")
+
+    def __radd__(self, other):
+        raise NotImplementedError(f"{self.__class__}.__radd__")
+
+    def __sub__(self, other):
+        raise NotImplementedError(f"{self.__class__}.__sub__")
+
+    def __rsub__(self, other):
+        raise NotImplementedError(f"{self.__class__}.__rsub__")
+
+    def __div__(self, other):
+        raise NotImplementedError(f"{self.__class__}.__div__")
+
+    def __rdiv__(self, other):
+        raise NotImplementedError(f"{self.__class__}.__rdiv__")
+
+    def __truediv__(self, other):
+        raise NotImplementedError(f"{self.__class__}.__truediv__")
+
+    def __rtruediv__(self, other):
+        raise NotImplementedError(f"{self.__class__}.__rtruediv__")
+
+    def __floordiv__(self, other):
+        raise NotImplementedError(f"{self.__class__}.__floordiv__")
+
+    def __rfloordiv__(self, other):
+        raise NotImplementedError(f"{self.__class__}.__rfloordiv__")
+
+    def __mod__(self, other):
+        raise NotImplementedError(f"{self.__class__}.__mod__")
+
+    def __rmod__(self, other):
+        raise NotImplementedError(f"{self.__class__}.__rmod__")
+
+    def __getitem__(self, item):
+        raise NotImplementedError(f"{self.__class__}.__getitem__")
diff --git a/brainunit/math/_fun_keep_unit.py b/brainunit/math/_fun_keep_unit.py
index b51a446..4af911d 100644
--- a/brainunit/math/_fun_keep_unit.py
+++ b/brainunit/math/_fun_keep_unit.py
@@ -20,6 +20,7 @@
 
 import jax
 import jax.numpy as jnp
+from jax._src.numpy.util import promote_dtypes as _promote_dtypes
 import numpy as np
 
 from ._fun_array_creation import asarray
@@ -63,7 +64,7 @@
     # math funcs keep unit (binary)
     'fmod', 'mod', 'copysign', 'remainder',
     'maximum', 'minimum', 'fmax', 'fmin', 'lcm', 'gcd', 'trace',
-    'add', 'subtract', 'nextafter',
+    'add', 'subtract', 'nextafter', 'promote_dtypes',
 
     # math funcs keep unit
     'interp', 'clip', 'histogram',
@@ -492,6 +493,27 @@ def broadcast_arrays(
     return _broadcast_fun(jnp.broadcast_arrays, *args)
 
 
+@set_module_as('brainunit.math')
+def promote_dtypes(
+    *args: Union[Quantity, jax.typing.ArrayLike]
+) -> Union[Quantity, jax.Array, Sequence[jax.Array | Quantity]]:
+    """
+    Promote the data types of the inputs to a common type.
+
+    Parameters
+    ----------
+    `*args` : array_like
+        The arrays to promote.
+
+    Returns
+    -------
+    promoted : list of arrays
+        The input arrays, each cast to the common promoted data type;
+        their shapes are unchanged.
+    """
+    return _broadcast_fun(_promote_dtypes, *args)
+
+
 @set_module_as('brainunit.math')
 def broadcast_to(
     array: Union[Quantity, jax.typing.ArrayLike],
@@ -3371,13 +3393,12 @@ def round_(
     -------
     out : jax.Array
     """
-    return _fun_keep_unit_unary(jnp.round_, x)
+    return _fun_keep_unit_unary(jnp.round, x)
 
 
 @set_module_as('brainunit.math')
-def around(
+def round(
     x: Union[Quantity, jax.typing.ArrayLike],
-    decimals: int = 0,
 ) -> jax.Array | Quantity:
     """
     Round an array to the nearest integer.
 
     Parameters
     ----------
     x : array_like, Quantity
         Input array.
-    decimals : int, optional
-        Number of decimal places to round to (default is 0).
 
     Returns
     -------
     out : jax.Array
     """
-    return _fun_keep_unit_unary(jnp.around, x, decimals=decimals)
+    return _fun_keep_unit_unary(jnp.round, x)
 
 
 @set_module_as('brainunit.math')
-def round(
+def around(
     x: Union[Quantity, jax.typing.ArrayLike],
     decimals: int = 0,
 ) -> jax.Array | Quantity:
     """
     Round an array to the nearest integer.
 
     Parameters
     ----------
     x : array_like, Quantity
         Input array.
     decimals : int, optional
         Number of decimal places to round to (default is 0).
 
     Returns
     -------
     out : jax.Array
     """
-    return _fun_keep_unit_unary(jnp.round, x, decimals=decimals)
+    return _fun_keep_unit_unary(jnp.around, x, decimals=decimals)
 
 
 @set_module_as('brainunit.math')
diff --git a/brainunit/math/_fun_remove_unit.py b/brainunit/math/_fun_remove_unit.py
index 5fe35b6..773dc97 100644
--- a/brainunit/math/_fun_remove_unit.py
+++ b/brainunit/math/_fun_remove_unit.py
@@ -24,7 +24,7 @@
 __all__ = [
     # math funcs remove unit (unary)
-    'iscomplexobj', 'heaviside', 'signbit', 'sign', 'bincount', 'digitize',
+    'iscomplexobj', 'heaviside', 'signbit', 'sign', 'bincount', 'digitize', 'get_promote_dtypes',
 
     # logic funcs (unary)
     'all', 'any', 'logical_not',
 
@@ -43,6 +43,28 @@
 # math funcs remove unit (unary)
 # ------------------------------
+
+
+@set_module_as('brainunit.math')
+def get_promote_dtypes(
+    *args: Union[Quantity, jax.typing.ArrayLike]
+) -> jax.typing.DTypeLike:
+    """
+    Promote the data types of the inputs to a common type.
+
+    Parameters
+    ----------
+    `*args` : array_like
+        The arrays to promote.
+
+    Returns
+    -------
+    dtype : dtype
+        The data type to which all of the inputs can be promoted
+        under JAX's type-promotion rules.
+    """
+    return jnp.result_type(*jax.tree.leaves(args))
+
+
 def _fun_remove_unit_unary(func, x, *args, **kwargs):
     if isinstance(x, Quantity):
         # x = x.factorless()
diff --git a/brainunit/sparse/__init__.py b/brainunit/sparse/__init__.py
new file mode 100644
index 0000000..8e51511
--- /dev/null
+++ b/brainunit/sparse/__init__.py
@@ -0,0 +1,25 @@
+# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================== + + +from .coo import COO, coo_todense, coo_fromdense +from .csr import CSR, CSC, csr_todense, csr_fromdense, csc_fromdense, csc_todense + +__all__ = [ + "CSR", "CSC", + "csr_todense", "csr_fromdense", + "csc_todense", "csc_fromdense", + "COO", "coo_todense", "coo_fromdense" +] diff --git a/brainunit/sparse/coo.py b/brainunit/sparse/coo.py new file mode 100644 index 0000000..a088a42 --- /dev/null +++ b/brainunit/sparse/coo.py @@ -0,0 +1,516 @@ +# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + + +from __future__ import annotations + +import operator +from typing import Any, Tuple, Sequence, NamedTuple + +import jax +import jax.numpy as jnp +import numpy as np +from jax import lax +from jax import tree_util +from jax._src.lax.lax import _const +from jax.experimental.sparse import JAXSparse, coo_todense_p, coo_fromdense_p, coo_matmat_p, coo_matvec_p + +from brainunit._base import Quantity, split_mantissa_unit, maybe_decimal, get_mantissa, get_unit +from brainunit._sparse_base import SparseMatrix +from brainunit.math._fun_array_creation import asarray +from brainunit.math._fun_keep_unit import promote_dtypes + +__all__ = [ + 'COO', 'coo_todense', 'coo_fromdense', +] + +Dtype = Any +Shape = tuple[int, ...] + + +class COOInfo(NamedTuple): + shape: Shape + rows_sorted: bool = False + cols_sorted: bool = False + + +@tree_util.register_pytree_node_class +class COO(SparseMatrix): + """Experimental COO matrix implemented in JAX. + + Note: this class has minimal compatibility with JAX transforms such as + grad and autodiff, and offers very little functionality. In general you + should prefer :class:`jax.experimental.sparse.BCOO`. + + Additionally, there are known failures in the case that `nse` is larger + than the true number of nonzeros in the represented matrix. This situation + is better handled in BCOO. 
+ """ + data: jax.Array + row: jax.Array + col: jax.Array + shape: tuple[int, int] + nse = property(lambda self: self.data.size) + dtype = property(lambda self: self.data.dtype) + _info = property( + lambda self: COOInfo( + shape=self.shape, + rows_sorted=self._rows_sorted, + cols_sorted=self._cols_sorted) + ) + _bufs = property(lambda self: (self.data, self.row, self.col)) + _rows_sorted: bool + _cols_sorted: bool + + def __init__( + self, + args: Tuple[jax.Array | Quantity, jax.Array, jax.Array], + *, + shape: Shape, + rows_sorted: bool = False, + cols_sorted: bool = False + ): + self.data, self.row, self.col = map(asarray, args) + self._rows_sorted = rows_sorted + self._cols_sorted = cols_sorted + super().__init__(args, shape=shape) + + @classmethod + def fromdense( + cls, + mat: jax.Array, + *, + nse: int | None = None, + index_dtype: jax.typing.DTypeLike = np.int32 + ) -> COO: + return coo_fromdense(mat, nse=nse, index_dtype=index_dtype) + + def _sort_indices(self) -> COO: + """Return a copy of the COO matrix with sorted indices. + + The matrix is sorted by row indices and column indices per row. + If self._rows_sorted is True, this returns ``self`` without a copy. + """ + # TODO(jakevdp): would be benefit from lowering this to cusparse sort_rows utility? + if self._rows_sorted: + return self + row, col, data = lax.sort((self.row, self.col, self.data), num_keys=2) + return self.__class__( + (data, row, col), + shape=self.shape, + rows_sorted=True + ) + + @classmethod + def _empty( + cls, + shape: Sequence[int], + *, + dtype: jax.typing.DTypeLike | None = None, + index_dtype: jax.typing.DTypeLike = 'int32' + ) -> COO: + """Create an empty COO instance. Public method is sparse.empty().""" + shape = tuple(shape) + if len(shape) != 2: + raise ValueError(f"COO must have ndim=2; got {shape=}") + data = jnp.empty(0, dtype) + row = col = jnp.empty(0, index_dtype) + return cls( + (data, row, col), + shape=shape, + rows_sorted=True, + cols_sorted=True + ) + + @classmethod + def _eye( + cls, + N: int, + M: int, + k: int, + *, + dtype: jax.typing.DTypeLike | None = None, + index_dtype: jax.typing.DTypeLike = 'int32' + ) -> COO: + if k > 0: + diag_size = min(N, M - k) + else: + diag_size = min(N + k, M) + + if diag_size <= 0: + # if k is out of range, return an empty matrix. + return cls._empty((N, M), dtype=dtype, index_dtype=index_dtype) + + data = jnp.ones(diag_size, dtype=dtype) + idx = jnp.arange(diag_size, dtype=index_dtype) + zero = _const(idx, 0) + k = _const(idx, k) + row = lax.sub(idx, lax.cond(k >= 0, lambda: zero, lambda: k)) + col = lax.add(idx, lax.cond(k <= 0, lambda: zero, lambda: k)) + return cls((data, row, col), shape=(N, M), rows_sorted=True, cols_sorted=True) + + def with_data(self, data: jax.Array | Quantity) -> COO: + assert data.shape == self.data.shape + assert data.dtype == self.data.dtype + assert get_unit(data) == get_unit(self.data) + return COO((data, self.row, self.col), shape=self.shape) + + def todense(self) -> jax.Array: + return coo_todense(self) + + def transpose(self, axes: Tuple[int, ...] 
| None = None) -> COO: + if axes is not None: + raise NotImplementedError("axes argument to transpose()") + return COO((self.data, self.col, self.row), shape=self.shape[::-1], + rows_sorted=self._cols_sorted, cols_sorted=self._rows_sorted) + + def tree_flatten(self) -> Tuple[ + Tuple[jax.Array | Quantity, jax.Array, jax.Array], dict[str, Any] + ]: + return (self.data, self.row, self.col), self._info._asdict() + + @classmethod + def tree_unflatten(cls, aux_data, children): + obj = object.__new__(cls) + obj.data, obj.row, obj.col = children + if aux_data.keys() != {'shape', 'rows_sorted', 'cols_sorted'}: + raise ValueError(f"COO.tree_unflatten: invalid {aux_data=}") + obj.shape = aux_data['shape'] + obj._rows_sorted = aux_data['rows_sorted'] + obj._cols_sorted = aux_data['cols_sorted'] + return obj + + def __abs__(self): + return COO( + (self.data.__abs__(), self.row, self.col), + shape=self.shape, + rows_sorted=self._rows_sorted, + cols_sorted=self._cols_sorted + ) + + def __neg__(self): + return COO( + (-self.data, self.row, self.col), + shape=self.shape, + rows_sorted=self._rows_sorted, + cols_sorted=self._cols_sorted + ) + + def __pos__(self): + return COO( + (self.data.__pos__(), self.row, self.col), + shape=self.shape, + rows_sorted=self._rows_sorted, + cols_sorted=self._cols_sorted + ) + + def _binary_op(self, other, op): + if isinstance(other, JAXSparse): + raise NotImplementedError("mul between two sparse objects.") + other = asarray(other) + if other.size == 1: + return COO( + (op(self.data, other), self.row, self.col), + shape=self.shape + ) + elif other.ndim == 2 and other.shape == self.shape: + other = other[self.row, self.col] + return COO( + (op(self.data, other), self.row, self.col), + shape=self.shape + ) + else: + raise NotImplementedError(f"mul with object of shape {other.shape}") + + def _binary_rop(self, other, op): + if isinstance(other, JAXSparse): + raise NotImplementedError("mul between two sparse objects.") + other = asarray(other) + if other.size == 1: + return COO( + (op(other, self.data), self.row, self.col), + shape=self.shape + ) + elif other.ndim == 2 and other.shape == self.shape: + other = other[self.row, self.col] + return COO( + (op(other, self.data), self.row, self.col), + shape=self.shape + ) + else: + raise NotImplementedError(f"mul with object of shape {other.shape}") + + def __mul__(self, other: jax.Array | Quantity) -> COO: + return self._binary_op(other, operator.mul) + + def __rmul__(self, other: jax.Array | Quantity) -> COO: + return self._binary_rop(other, operator.mul) + + def __div__(self, other: jax.Array | Quantity) -> COO: + return self._binary_op(other, operator.truediv) + + def __rdiv__(self, other: jax.Array | Quantity) -> COO: + return self._binary_rop(other, operator.truediv) + + def __truediv__(self, other) -> COO: + return self.__div__(other) + + def __rtruediv__(self, other) -> COO: + return self.__rdiv__(other) + + def __add__(self, other) -> COO: + return self._binary_op(other, operator.add) + + def __radd__(self, other) -> COO: + return self._binary_rop(other, operator.add) + + def __sub__(self, other) -> COO: + return self._binary_op(other, operator.sub) + + def __rsub__(self, other) -> COO: + return self._binary_rop(other, operator.sub) + + def __mod__(self, other) -> COO: + return self._binary_op(other, operator.mod) + + def __rmod__(self, other) -> COO: + return self._binary_rop(other, operator.mod) + + def __matmul__( + self, other: jax.typing.ArrayLike + ) -> jax.Array | Quantity: + if isinstance(other, JAXSparse): + raise 
NotImplementedError("matmul between two sparse objects.")
+        other = asarray(other)
+        data, other = promote_dtypes(self.data, other)
+        self_promoted = COO((data, self.row, self.col), **self._info._asdict())
+        if other.ndim == 1:
+            return coo_matvec(self_promoted, other)
+        elif other.ndim == 2:
+            return coo_matmat(self_promoted, other)
+        else:
+            raise NotImplementedError(f"matmul with object of shape {other.shape}")
+
+    def __rmatmul__(
+        self,
+        other: jax.typing.ArrayLike
+    ) -> jax.Array | Quantity:
+        if isinstance(other, JAXSparse):
+            raise NotImplementedError("matmul between two sparse objects.")
+        other = asarray(other)
+        data, other = promote_dtypes(self.data, other)
+        self_promoted = COO((data, self.row, self.col), **self._info._asdict())
+        if other.ndim == 1:
+            return coo_matvec(self_promoted, other, transpose=True)
+        elif other.ndim == 2:
+            other = other.T
+            return coo_matmat(self_promoted, other, transpose=True).T
+        else:
+            raise NotImplementedError(f"matmul with object of shape {other.shape}")
+
+
+def coo_todense(mat: COO) -> jax.Array | Quantity:
+    """Convert a COO-format sparse matrix to a dense matrix.
+
+    Args:
+        mat : COO matrix
+    Returns:
+        mat_dense: dense version of ``mat``
+    """
+    return _coo_todense(mat.data, mat.row, mat.col, spinfo=mat._info)
+
+
+def coo_fromdense(
+    mat: jax.Array | Quantity,
+    *,
+    nse: int | None = None,
+    index_dtype: jax.typing.DTypeLike = jnp.int32
+) -> COO:
+    """Create a COO-format sparse matrix from a dense matrix.
+
+    Args:
+        mat : array to be converted to COO.
+        nse : number of specified entries in ``mat``. If not specified,
+            it will be computed from the input matrix.
+        index_dtype : dtype of sparse indices
+
+    Returns:
+        mat_coo : COO representation of the matrix.
+    """
+    if nse is None:
+        nse = int((get_mantissa(mat) != 0).sum())
+    nse_int = jax.core.concrete_or_error(operator.index, nse, "coo_fromdense nse argument")
+    return COO(
+        _coo_fromdense(mat, nse=nse_int, index_dtype=index_dtype),
+        shape=mat.shape,
+        rows_sorted=True
+    )
+
+
+def _coo_todense(
+    data: jax.Array | Quantity,
+    row: jax.Array,
+    col: jax.Array,
+    *,
+    spinfo: COOInfo
+) -> jax.Array | Quantity:
+    """Convert COO-format sparse matrix to a dense matrix.
+
+    Args:
+        data : array of shape ``(nse,)``.
+        row : array of shape ``(nse,)``
+        col : array of shape ``(nse,)`` and dtype ``row.dtype``
+        spinfo : COOInfo object containing matrix metadata
+
+    Returns:
+        mat : array with specified shape and dtype matching ``data``
+    """
+    data, unit = split_mantissa_unit(data)
+    r = coo_todense_p.bind(data, row, col, spinfo=spinfo)
+    return maybe_decimal(r * unit)
+
+
+def _coo_fromdense(
+    mat: jax.Array | Quantity,
+    *,
+    nse: int,
+    index_dtype: jax.typing.DTypeLike = jnp.int32
+) -> Tuple[jax.Array | Quantity, jax.Array, jax.Array]:
+    """Create COO-format sparse matrix from a dense matrix.
+
+    Args:
+        mat : array to be converted to COO.
+        nse : number of specified entries in ``mat``
+        index_dtype : dtype of sparse indices
+
+    Returns:
+        data : array of shape ``(nse,)`` and dtype ``mat.dtype``
+        row : array of shape ``(nse,)`` and dtype ``index_dtype``
+        col : array of shape ``(nse,)`` and dtype ``index_dtype``
+    """
+    mat = asarray(mat)
+    mat, unit = split_mantissa_unit(mat)
+    nse = jax.core.concrete_or_error(operator.index, nse, "nse argument of coo_fromdense()")
+    r = coo_fromdense_p.bind(mat, nse=nse, index_dtype=index_dtype)
+    if unit.is_unitless:
+        return r
+    return r[0] * unit, r[1], r[2]
+
+
+def coo_matvec(
+    mat: COO,
+    v: jax.Array | Quantity,
+    transpose: bool = False
+) -> jax.Array | Quantity:
+    """Product of COO sparse matrix and a dense vector.
+
+    Args:
+        mat : COO matrix
+        v : one-dimensional array of size ``(shape[0] if transpose else shape[1],)`` and
+            dtype ``mat.dtype``
+        transpose : boolean specifying whether to transpose the sparse matrix
+            before computing.
+
+    Returns:
+        y : array of shape ``(mat.shape[1] if transpose else mat.shape[0],)`` representing
+            the matrix vector product.
+    """
+    data, row, col = mat._bufs
+    return _coo_matvec(data, row, col, v, spinfo=mat._info, transpose=transpose)
+
+
+def _coo_matvec(
+    data: jax.Array | Quantity,
+    row: jax.Array,
+    col: jax.Array,
+    v: jax.Array | Quantity,
+    *,
+    spinfo: COOInfo,
+    transpose: bool = False
+) -> jax.Array | Quantity:
+    """Product of COO sparse matrix and a dense vector.
+
+    Args:
+        data : array of shape ``(nse,)``.
+        row : array of shape ``(nse,)``
+        col : array of shape ``(nse,)`` and dtype ``row.dtype``
+        v : array of shape ``(shape[0] if transpose else shape[1],)`` and
+            dtype ``data.dtype``
+        spinfo : COOInfo object containing the matrix shape and sort metadata
+        transpose : boolean specifying whether to transpose the sparse matrix
+            before computing.
+
+    Returns:
+        y : array of shape ``(shape[1] if transpose else shape[0],)`` representing
+            the matrix vector product.
+    """
+    data, unita = split_mantissa_unit(data)
+    v, unitv = split_mantissa_unit(v)
+    r = coo_matvec_p.bind(data, row, col, v, spinfo=spinfo, transpose=transpose)
+    return maybe_decimal(r * unita * unitv)
+
+
+def coo_matmat(
+    mat: COO,
+    B: jax.Array | Quantity,
+    *,
+    transpose: bool = False
+) -> jax.Array | Quantity:
+    """Product of COO sparse matrix and a dense matrix.
+
+    Args:
+        mat : COO matrix
+        B : array of shape ``(mat.shape[0] if transpose else mat.shape[1], cols)`` and
+            dtype ``mat.dtype``
+        transpose : boolean specifying whether to transpose the sparse matrix
+            before computing.
+
+    Returns:
+        C : array of shape ``(mat.shape[1] if transpose else mat.shape[0], cols)``
+            representing the matrix-matrix product.
+    """
+    data, row, col = mat._bufs
+    return _coo_matmat(data, row, col, B, spinfo=mat._info, transpose=transpose)
+
+
+def _coo_matmat(
+    data: jax.Array | Quantity,
+    row: jax.Array,
+    col: jax.Array,
+    B: jax.Array | Quantity,
+    *,
+    spinfo: COOInfo,
+    transpose: bool = False
+) -> jax.Array | Quantity:
+    """Product of COO sparse matrix and a dense matrix.
+
+    Args:
+        data : array of shape ``(nse,)``.
+        row : array of shape ``(nse,)``
+        col : array of shape ``(nse,)`` and dtype ``row.dtype``
+        B : array of shape ``(shape[0] if transpose else shape[1], cols)`` and
+            dtype ``data.dtype``
+        spinfo : COOInfo object containing the matrix shape and sort metadata
+        transpose : boolean specifying whether to transpose the sparse matrix
+            before computing.
+
+    Returns:
+        C : array of shape ``(shape[1] if transpose else shape[0], cols)``
+            representing the matrix-matrix product.
+ """ + data, unita = split_mantissa_unit(data) + B, unitb = split_mantissa_unit(B) + res = coo_matmat_p.bind(data, row, col, B, spinfo=spinfo, transpose=transpose) + return maybe_decimal(res * unita * unitb) diff --git a/brainunit/sparse/coo_test.py b/brainunit/sparse/coo_test.py new file mode 100644 index 0000000..e771616 --- /dev/null +++ b/brainunit/sparse/coo_test.py @@ -0,0 +1,270 @@ +# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + + +from __future__ import annotations + +import unittest + +import brainstate as bst + +import brainunit as u + + +class TestCOO(unittest.TestCase): + def test_matvec(self): + for ux, uy in [ + (u.ms, u.mV), + (u.UNITLESS, u.UNITLESS), + (u.mV, u.UNITLESS), + (u.UNITLESS, u.mV), + ]: + data = bst.random.rand(10, 20) + data = data * (data < 0.3) * ux + + coo = u.sparse.COO.fromdense(data) + + x = bst.random.random((10,)) * uy + self.assertTrue( + u.math.allclose( + x @ data, + x @ coo + ) + ) + + x = bst.random.random((20,)) * uy + self.assertTrue( + u.math.allclose( + data @ x, + coo @ x + ) + ) + + def test_matmul(self): + for ux, uy in [ + (u.ms, u.mV), + (u.UNITLESS, u.UNITLESS), + (u.mV, u.UNITLESS), + (u.UNITLESS, u.mV), + ]: + data = bst.random.rand(10, 20) + data = data * (data < 0.3) * ux + coo = u.sparse.COO.fromdense(data) + + data2 = bst.random.rand(20, 30) * uy + + self.assertTrue( + u.math.allclose( + data @ data2, + coo @ data2 + ) + ) + + data2 = bst.random.rand(30, 10) * uy + self.assertTrue( + u.math.allclose( + data2 @ data, + data2 @ coo + ) + ) + + def test_pos(self): + for ux in [ + u.ms, + u.UNITLESS, + u.mV, + ]: + data = bst.random.rand(10, 20) + data = data * (data < 0.3) * ux + + coo = u.sparse.COO.fromdense(data) + + self.assertTrue( + u.math.allclose( + coo.__pos__().data, + coo.data + ) + ) + + def test_neg(self): + for ux in [ + u.ms, + u.UNITLESS, + u.mV, + ]: + data = bst.random.rand(10, 20) + data = data * (data < 0.3) * ux + + coo = u.sparse.COO.fromdense(data) + + self.assertTrue( + u.math.allclose( + (-coo).data, + -coo.data + ) + ) + + def test_abs(self): + for ux in [ + u.ms, + u.UNITLESS, + u.mV, + ]: + data = bst.random.rand(10, 20) + data = data * (data < 0.3) * ux + + coo = u.sparse.COO.fromdense(data) + + self.assertTrue( + u.math.allclose( + abs(coo).data, + abs(coo.data) + ) + ) + + def test_add(self): + for ux in [ + u.ms, + u.UNITLESS, + u.mV, + ]: + data1 = bst.random.rand(10, 20) + data1 = data1 * (data1 < 0.3) * ux + + data2 = 2. * ux + + coo1 = u.sparse.COO.fromdense(data1) + + self.assertTrue( + u.math.allclose( + (coo1 + data2).data, + coo1.data + data2 + ) + ) + + self.assertTrue( + u.math.allclose( + (data2 + coo1).data, + data2 + coo1.data + ) + ) + + def test_sub(self): + for ux in [ + u.ms, + u.UNITLESS, + u.mV, + ]: + data1 = bst.random.rand(10, 20) + data1 = data1 * (data1 < 0.3) * ux + + data2 = 2. 
* ux + + coo1 = u.sparse.COO.fromdense(data1) + + self.assertTrue( + u.math.allclose( + (coo1 - data2).data, + coo1.data - data2 + ) + ) + + self.assertTrue( + u.math.allclose( + (data2 - coo1).data, + data2 - coo1.data + ) + ) + + def test_mul(self): + for ux in [ + u.ms, + u.UNITLESS, + u.mV, + ]: + data1 = bst.random.rand(10, 20) + data1 = data1 * (data1 < 0.3) * ux + + data2 = 2. * ux + + coo1 = u.sparse.COO.fromdense(data1) + + self.assertTrue( + u.math.allclose( + (coo1 * data2).data, + coo1.data * data2 + ) + ) + + self.assertTrue( + u.math.allclose( + (data2 * coo1).data, + data2 * coo1.data + ) + ) + + def test_div(self): + for ux in [ + u.ms, + u.UNITLESS, + u.mV, + ]: + data1 = bst.random.rand(10, 20) + data1 = data1 * (data1 < 0.3) * ux + + data2 = 2. * u.ohm + + coo1 = u.sparse.COO.fromdense(data1) + + self.assertTrue( + u.math.allclose( + (coo1 / data2).data, + coo1.data / data2 + ) + ) + + self.assertTrue( + u.math.allclose( + (data2 / coo1).data, + data2 / coo1.data + ) + ) + + def test_mod(self): + for ux in [ + u.ms, + u.UNITLESS, + u.mV, + ]: + data1 = bst.random.rand(10, 20) + data1 = data1 * (data1 < 0.3) * ux + + data2 = 2. * ux + + coo1 = u.sparse.COO.fromdense(data1) + + self.assertTrue( + u.math.allclose( + (coo1 % data2).data, + coo1.data % data2 + ) + ) + + self.assertTrue( + u.math.allclose( + (data2 % coo1).data, + data2 % coo1.data + ) + ) diff --git a/brainunit/sparse/csr.py b/brainunit/sparse/csr.py new file mode 100644 index 0000000..d40a53f --- /dev/null +++ b/brainunit/sparse/csr.py @@ -0,0 +1,601 @@ +# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from __future__ import annotations + +import operator +from typing import Tuple, Union + +import jax +import jax.numpy as jnp +import numpy as np +from jax import tree_util +from jax._src.lax.lax import _const +from jax.experimental.sparse import ( + JAXSparse, csr_fromdense_p, csr_todense_p, csr_matvec_p, csr_matmat_p +) + +from brainunit._base import Quantity, split_mantissa_unit, maybe_decimal, get_mantissa, get_unit +from brainunit._sparse_base import SparseMatrix +from brainunit.math._fun_array_creation import asarray +from brainunit.math._fun_keep_unit import promote_dtypes + +__all__ = [ + 'CSR', 'CSC', + 'csr_fromdense', 'csr_todense', + 'csc_fromdense', 'csc_todense', +] + +Shape = tuple[int, ...] + + +@tree_util.register_pytree_node_class +class CSR(SparseMatrix): + """ + Unit-aware CSR matrix. 
+ """ + data: jax.Array | Quantity + indices: jax.Array + indptr: jax.Array + shape: tuple[int, int] + nse = property(lambda self: self.data.size) + dtype = property(lambda self: self.data.dtype) + _bufs = property(lambda self: (self.data, self.indices, self.indptr)) + + def __init__(self, args, *, shape): + self.data, self.indices, self.indptr = map(asarray, args) + super().__init__(args, shape=shape) + + @classmethod + def fromdense(cls, mat, *, nse=None, index_dtype=np.int32): + if nse is None: + nse = (get_mantissa(mat) != 0).sum() + return csr_fromdense(mat, nse=nse, index_dtype=index_dtype) + + @classmethod + def _empty(cls, shape, *, dtype=None, index_dtype='int32'): + """Create an empty CSR instance. Public method is sparse.empty().""" + shape = tuple(shape) + if len(shape) != 2: + raise ValueError(f"CSR must have ndim=2; got {shape=}") + data = jnp.empty(0, dtype) + indices = jnp.empty(0, index_dtype) + indptr = jnp.zeros(shape[0] + 1, index_dtype) + return cls((data, indices, indptr), shape=shape) + + @classmethod + def _eye(cls, N, M, k, *, dtype=None, index_dtype='int32'): + if k > 0: + diag_size = min(N, M - k) + else: + diag_size = min(N + k, M) + + if diag_size <= 0: + # if k is out of range, return an empty matrix. + return cls._empty((N, M), dtype=dtype, index_dtype=index_dtype) + + data = jnp.ones(diag_size, dtype=dtype) + idx = jnp.arange(diag_size, dtype=index_dtype) + zero = _const(idx, 0) + k = _const(idx, k) + col = jax.lax.add(idx, jax.lax.cond(k <= 0, lambda: zero, lambda: k)) + indices = col.astype(index_dtype) + # TODO(jakevdp): this can be done more efficiently. + row = jax.lax.sub(idx, jax.lax.cond(k >= 0, lambda: zero, lambda: k)) + indptr = jnp.zeros(N + 1, dtype=index_dtype).at[1:].set( + jnp.cumsum(jnp.bincount(row, length=N).astype(index_dtype))) + return cls((data, indices, indptr), shape=(N, M)) + + def with_data(self, data: jax.Array | Quantity) -> CSR: + assert data.shape == self.data.shape + assert data.dtype == self.data.dtype + assert get_unit(data) == get_unit(self.data) + return CSR((data, self.indices, self.indptr), shape=self.shape) + + def todense(self): + return csr_todense(self) + + def transpose(self, axes=None): + assert axes is None + return CSC((self.data, self.indices, self.indptr), shape=self.shape[::-1]) + + def __abs__(self): + return CSR((abs(self.data), self.indices, self.indptr), shape=self.shape) + + def __neg__(self): + return CSR((-self.data, self.indices, self.indptr), shape=self.shape) + + def __pos__(self): + return CSR((self.data.__pos__(), self.indices, self.indptr), shape=self.shape) + + def _binary_op(self, other, op): + if isinstance(other, JAXSparse): + raise NotImplementedError("mul between two sparse objects.") + other = asarray(other) + if other.size == 1: + return CSR( + (op(self.data, other), self.indices, self.indptr), + shape=self.shape + ) + elif other.ndim == 2 and other.shape == self.shape: + rows, cols = _csr_to_coo(self.indices, self.indptr) + other = other[rows, cols] + return CSR( + (op(self.data, other), self.indices, self.indptr), + shape=self.shape + ) + else: + raise NotImplementedError(f"mul with object of shape {other.shape}") + + def _binary_rop(self, other, op): + if isinstance(other, JAXSparse): + raise NotImplementedError("mul between two sparse objects.") + other = asarray(other) + if other.size == 1: + return CSR( + (op(other, self.data), self.indices, self.indptr), + shape=self.shape + ) + elif other.ndim == 2 and other.shape == self.shape: + rows, cols = _csr_to_coo(self.indices, self.indptr) 
+ other = other[rows, cols] + return CSR( + (op(other, self.data), self.indices, self.indptr), + shape=self.shape + ) + else: + raise NotImplementedError(f"mul with object of shape {other.shape}") + + def __mul__(self, other: jax.Array | Quantity) -> CSR: + return self._binary_op(other, operator.mul) + + def __rmul__(self, other: jax.Array | Quantity) -> CSR: + return self._binary_rop(other, operator.mul) + + def __div__(self, other: jax.Array | Quantity) -> CSR: + return self._binary_op(other, operator.truediv) + + def __rdiv__(self, other: jax.Array | Quantity) -> CSR: + return self._binary_rop(other, operator.truediv) + + def __truediv__(self, other) -> CSR: + return self.__div__(other) + + def __rtruediv__(self, other) -> CSR: + return self.__rdiv__(other) + + def __add__(self, other) -> CSR: + return self._binary_op(other, operator.add) + + def __radd__(self, other) -> CSR: + return self._binary_rop(other, operator.add) + + def __sub__(self, other) -> CSR: + return self._binary_op(other, operator.sub) + + def __rsub__(self, other) -> CSR: + return self._binary_rop(other, operator.sub) + + def __mod__(self, other) -> CSR: + return self._binary_op(other, operator.mod) + + def __rmod__(self, other) -> CSR: + return self._binary_rop(other, operator.mod) + + def __matmul__(self, other): + if isinstance(other, JAXSparse): + raise NotImplementedError("matmul between two sparse objects.") + other = asarray(other) + data, other = promote_dtypes(self.data, other) + if other.ndim == 1: + return _csr_matvec(data, self.indices, self.indptr, other, shape=self.shape) + elif other.ndim == 2: + return _csr_matmat(data, self.indices, self.indptr, other, shape=self.shape) + else: + raise NotImplementedError(f"matmul with object of shape {other.shape}") + + def __rmatmul__(self, other): + if isinstance(other, JAXSparse): + raise NotImplementedError("matmul between two sparse objects.") + other = asarray(other) + data, other = promote_dtypes(self.data, other) + if other.ndim == 1: + return _csr_matvec(data, self.indices, self.indptr, other, shape=self.shape, transpose=True) + elif other.ndim == 2: + other = other.T + r = _csr_matmat(data, self.indices, self.indptr, other, shape=self.shape, transpose=True) + return r.T + else: + raise NotImplementedError(f"matmul with object of shape {other.shape}") + + def tree_flatten(self): + return (self.data, self.indices, self.indptr), {"shape": self.shape} + + @classmethod + def tree_unflatten(cls, aux_data, children): + obj = object.__new__(cls) + obj.data, obj.indices, obj.indptr = children + if aux_data.keys() != {'shape'}: + raise ValueError(f"CSR.tree_unflatten: invalid {aux_data=}") + obj.__dict__.update(**aux_data) + return obj + + +@tree_util.register_pytree_node_class +class CSC(SparseMatrix): + """ + Unit-aware CSC matrix. + """ + data: jax.Array + indices: jax.Array + indptr: jax.Array + shape: tuple[int, int] + nse = property(lambda self: self.data.size) + dtype = property(lambda self: self.data.dtype) + + __array_priority__ = 2000 + + def __init__(self, args, *, shape): + self.data, self.indices, self.indptr = map(asarray, args) + super().__init__(args, shape=shape) + + @classmethod + def fromdense(cls, mat, *, nse=None, index_dtype=np.int32): + if nse is None: + nse = (get_mantissa(mat) != 0).sum() + return csr_fromdense(mat.T, nse=nse, index_dtype=index_dtype).T + + @classmethod + def _empty(cls, shape, *, dtype=None, index_dtype='int32'): + """Create an empty CSC instance. 
Public method is sparse.empty().""" + shape = tuple(shape) + if len(shape) != 2: + raise ValueError(f"CSC must have ndim=2; got {shape=}") + data = jnp.empty(0, dtype) + indices = jnp.empty(0, index_dtype) + indptr = jnp.zeros(shape[1] + 1, index_dtype) + return cls((data, indices, indptr), shape=shape) + + @classmethod + def _eye(cls, N, M, k, *, dtype=None, index_dtype='int32'): + return CSR._eye(M, N, -k, dtype=dtype, index_dtype=index_dtype).T + + def with_data(self, data: jax.Array | Quantity) -> CSC: + assert data.shape == self.data.shape + assert data.dtype == self.data.dtype + assert get_unit(data) == get_unit(self.data) + return CSC((data, self.indices, self.indptr), shape=self.shape) + + def todense(self): + return csr_todense(self.T).T + + def transpose(self, axes=None): + assert axes is None + return CSR((self.data, self.indices, self.indptr), shape=self.shape[::-1]) + + def __abs__(self): + return CSC((abs(self.data), self.indices, self.indptr), shape=self.shape) + + def __neg__(self): + return CSC((-self.data, self.indices, self.indptr), shape=self.shape) + + def __pos__(self): + return CSC((self.data.__pos__(), self.indices, self.indptr), shape=self.shape) + + def _binary_op(self, other, op): + if isinstance(other, JAXSparse): + raise NotImplementedError("mul between two sparse objects.") + other = asarray(other) + if other.size == 1: + return CSC( + (op(self.data, other), self.indices, self.indptr), + shape=self.shape + ) + elif other.ndim == 2 and other.shape == self.shape: + cols, rows = _csr_to_coo(self.indices, self.indptr) + other = other[rows, cols] + return CSC( + (op(self.data, other), self.indices, self.indptr), + shape=self.shape + ) + else: + raise NotImplementedError(f"mul with object of shape {other.shape}") + + def _binary_rop(self, other, op): + if isinstance(other, JAXSparse): + raise NotImplementedError("mul between two sparse objects.") + other = asarray(other) + if other.size == 1: + return CSC( + (op(other, self.data), self.indices, self.indptr), + shape=self.shape + ) + elif other.ndim == 2 and other.shape == self.shape: + cols, rows = _csr_to_coo(self.indices, self.indptr) + other = other[rows, cols] + return CSC( + (op(other, self.data), self.indices, self.indptr), + shape=self.shape + ) + else: + raise NotImplementedError(f"mul with object of shape {other.shape}") + + def __mul__(self, other: jax.Array | Quantity) -> CSC: + return self._binary_op(other, operator.mul) + + def __rmul__(self, other: jax.Array | Quantity) -> CSC: + return self._binary_rop(other, operator.mul) + + def __div__(self, other: jax.Array | Quantity) -> CSC: + return self._binary_op(other, operator.truediv) + + def __rdiv__(self, other: jax.Array | Quantity) -> CSC: + return self._binary_rop(other, operator.truediv) + + def __truediv__(self, other) -> CSC: + return self.__div__(other) + + def __rtruediv__(self, other) -> CSC: + return self.__rdiv__(other) + + def __add__(self, other) -> CSC: + return self._binary_op(other, operator.add) + + def __radd__(self, other) -> CSC: + return self._binary_rop(other, operator.add) + + def __sub__(self, other) -> CSC: + return self._binary_op(other, operator.sub) + + def __rsub__(self, other) -> CSC: + return self._binary_rop(other, operator.sub) + + def __mod__(self, other) -> CSC: + return self._binary_op(other, operator.mod) + + def __rmod__(self, other) -> CSC: + return self._binary_rop(other, operator.mod) + + def __matmul__(self, other): + if isinstance(other, JAXSparse): + raise NotImplementedError("matmul between two sparse 
objects.") + other = asarray(other) + data, other = promote_dtypes(self.data, other) + if other.ndim == 1: + return _csr_matvec( + data, + self.indices, + self.indptr, + other, + shape=self.shape[::-1], + transpose=True + ) + elif other.ndim == 2: + return _csr_matmat( + data, + self.indices, + self.indptr, + other, + shape=self.shape[::-1], + transpose=True + ) + else: + raise NotImplementedError(f"matmul with object of shape {other.shape}") + + def __rmatmul__(self, other): + if isinstance(other, JAXSparse): + raise NotImplementedError("matmul between two sparse objects.") + other = asarray(other) + data, other = promote_dtypes(self.data, other) + if other.ndim == 1: + return _csr_matvec(data, self.indices, self.indptr, other, + shape=self.shape[::-1], transpose=False) + elif other.ndim == 2: + other = other.T + r = _csr_matmat(data, self.indices, self.indptr, other, + shape=self.shape[::-1], transpose=False) + return r.T + else: + raise NotImplementedError(f"matmul with object of shape {other.shape}") + + def tree_flatten(self): + return (self.data, self.indices, self.indptr), {"shape": self.shape} + + @classmethod + def tree_unflatten(cls, aux_data, children): + obj = object.__new__(cls) + obj.data, obj.indices, obj.indptr = children + if aux_data.keys() != {'shape'}: + raise ValueError(f"CSC.tree_unflatten: invalid {aux_data=}") + obj.__dict__.update(**aux_data) + return obj + + +Data = Union[jax.Array, Quantity] +Indices = jax.Array +Indptr = jax.Array + + +def csr_fromdense( + mat: jax.Array | Quantity, + *, nse: int | None = None, + index_dtype: jax.typing.DTypeLike = np.int32 +) -> CSR: + """Create a CSR-format sparse matrix from a dense matrix. + + Args: + mat : array to be converted to CSR. + nse : number of specified entries in ``mat``. If not specified, + it will be computed from the input matrix. + index_dtype : dtype of sparse indices + + Returns: + mat_coo : CSR representation of the matrix. + """ + if nse is None: + nse = int((get_mantissa(mat) != 0).sum()) + nse_int = jax.core.concrete_or_error(operator.index, nse, "coo_fromdense nse argument") + return CSR(_csr_fromdense(mat, nse=nse_int, index_dtype=index_dtype), shape=mat.shape) + + +def csr_todense(mat: CSR) -> jax.Array | Quantity: + """Convert a CSR-format sparse matrix to a dense matrix. + + Args: + mat : CSR matrix + Returns: + mat_dense: dense version of ``mat`` + """ + assert isinstance(mat, CSR), f"Expected CSR, got {type(mat)}" + return _csr_todense(mat.data, mat.indices, mat.indptr, shape=mat.shape) + + +def csc_todense(mat: CSC) -> jax.Array | Quantity: + """Convert a CSR-format sparse matrix to a dense matrix. + + Args: + mat : CSR matrix + Returns: + mat_dense: dense version of ``mat`` + """ + assert isinstance(mat, CSC), f"Expected CSC, got {type(mat)}" + return mat.todense() + + +def csc_fromdense( + mat: jax.Array | Quantity, + *, + nse: int | None = None, + index_dtype: jax.typing.DTypeLike = np.int32 +) -> CSC: + assert nse is None, "nse argument is not supported for CSC" + return CSC.fromdense(mat, nse=nse, index_dtype=index_dtype) + + +def _csr_fromdense( + mat: jax.Array | Quantity, + *, + nse: int, + index_dtype: jax.typing.DTypeLike = np.int32 +) -> Tuple[Data, Indices, Indptr]: + """Create CSR-format sparse matrix from a dense matrix. + + Args: + mat : array to be converted to CSR. + nse : number of specified entries in ``mat`` + index_dtype : dtype of sparse indices + + Returns: + data : array of shape ``(nse,)`` and dtype ``mat.dtype``. 
+        indices : array of shape ``(nse,)`` and dtype ``index_dtype``
+        indptr : array of shape ``(mat.shape[0] + 1,)`` and dtype ``index_dtype``
+    """
+    mat = asarray(mat)
+    mat, unit = split_mantissa_unit(mat)
+    nse = jax.core.concrete_or_error(operator.index, nse, "nse argument of csr_fromdense()")
+    r = csr_fromdense_p.bind(mat, nse=nse, index_dtype=np.dtype(index_dtype))
+    if unit.is_unitless:
+        return r
+    else:
+        return maybe_decimal(r[0] * unit), r[1], r[2]
+
+
+def _csr_todense(
+    data: jax.Array | Quantity,
+    indices: jax.Array,
+    indptr: jax.Array, *,
+    shape: Shape
+) -> jax.Array | Quantity:
+    """Convert CSR-format sparse matrix to a dense matrix.
+
+    Args:
+        data : array of shape ``(nse,)``.
+        indices : array of shape ``(nse,)``
+        indptr : array of shape ``(shape[0] + 1,)`` and dtype ``indices.dtype``
+        shape : length-2 tuple representing the matrix shape
+
+    Returns:
+        mat : array with specified shape and dtype matching ``data``
+    """
+    data, unit = split_mantissa_unit(data)
+    mat = csr_todense_p.bind(data, indices, indptr, shape=shape)
+    return maybe_decimal(mat * unit)
+
+
+def _csr_matvec(
+    data: jax.Array | Quantity,
+    indices: jax.Array,
+    indptr: jax.Array,
+    v: jax.Array | Quantity,
+    *,
+    shape: Shape,
+    transpose: bool = False
+) -> jax.Array | Quantity:
+    """Product of CSR sparse matrix and a dense vector.
+
+    Args:
+        data : array of shape ``(nse,)``.
+        indices : array of shape ``(nse,)``
+        indptr : array of shape ``(shape[0] + 1,)`` and dtype ``indices.dtype``
+        v : array of shape ``(shape[0] if transpose else shape[1],)``
+            and dtype ``data.dtype``
+        shape : length-2 tuple representing the matrix shape
+        transpose : boolean specifying whether to transpose the sparse matrix
+            before computing.
+
+    Returns:
+        y : array of shape ``(shape[1] if transpose else shape[0],)`` representing
+            the matrix vector product.
+    """
+    data, unitd = split_mantissa_unit(data)
+    v, unitv = split_mantissa_unit(v)
+    res = csr_matvec_p.bind(data, indices, indptr, v, shape=shape, transpose=transpose)
+    return maybe_decimal(res * unitd * unitv)
+
+
+def _csr_matmat(
+    data: jax.Array | Quantity,
+    indices: jax.Array,
+    indptr: jax.Array,
+    B: jax.Array | Quantity,
+    *,
+    shape: Shape,
+    transpose: bool = False
+) -> jax.Array | Quantity:
+    """Product of CSR sparse matrix and a dense matrix.
+
+    Args:
+        data : array of shape ``(nse,)``.
+        indices : array of shape ``(nse,)``
+        indptr : array of shape ``(shape[0] + 1,)`` and dtype ``indices.dtype``
+        B : array of shape ``(shape[0] if transpose else shape[1], cols)`` and
+            dtype ``data.dtype``
+        shape : length-2 tuple representing the matrix shape
+        transpose : boolean specifying whether to transpose the sparse matrix
+            before computing.
+
+    Returns:
+        C : array of shape ``(shape[1] if transpose else shape[0], cols)``
+            representing the matrix-matrix product.
+    """
+    data, unitd = split_mantissa_unit(data)
+    B, unitb = split_mantissa_unit(B)
+    res = csr_matmat_p.bind(data, indices, indptr, B, shape=shape, transpose=transpose)
+    return maybe_decimal(res * unitd * unitb)
+
+
+@jax.jit
+def _csr_to_coo(indices: jax.Array, indptr: jax.Array) -> Tuple[jax.Array, jax.Array]:
+    """Given CSR (indices, indptr) return COO (row, col)"""
+    return jnp.cumsum(jnp.zeros_like(indices).at[indptr].add(1)) - 1, indices
diff --git a/brainunit/sparse/csr_test.py b/brainunit/sparse/csr_test.py
new file mode 100644
index 0000000..09f11ab
--- /dev/null
+++ b/brainunit/sparse/csr_test.py
@@ -0,0 +1,549 @@
+# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from __future__ import annotations + +import unittest + +import brainstate as bst +import jax + +import brainunit as u + + +class TestCSR(unittest.TestCase): + def test_matvec(self): + for ux, uy in [ + (u.ms, u.mV), + (u.UNITLESS, u.UNITLESS), + (u.mV, u.UNITLESS), + (u.UNITLESS, u.mV), + ]: + data = bst.random.rand(10, 20) + data = data * (data < 0.3) * ux + + csr = u.sparse.CSR.fromdense(data) + + x = bst.random.random((10,)) * uy + self.assertTrue( + u.math.allclose( + x @ data, + x @ csr + ) + ) + + x = bst.random.random((20,)) * uy + self.assertTrue( + u.math.allclose( + data @ x, + csr @ x + ) + ) + + def test_matmul(self): + for ux, uy in [ + (u.ms, u.mV), + (u.UNITLESS, u.UNITLESS), + (u.mV, u.UNITLESS), + (u.UNITLESS, u.mV), + ]: + data = bst.random.rand(10, 20) + data = data * (data < 0.3) * ux + csr = u.sparse.CSR.fromdense(data) + + data2 = bst.random.rand(20, 30) * uy + + self.assertTrue( + u.math.allclose( + data @ data2, + csr @ data2 + ) + ) + + data2 = bst.random.rand(30, 10) * uy + self.assertTrue( + u.math.allclose( + data2 @ data, + data2 @ csr + ) + ) + + def test_pos(self): + for ux in [ + u.ms, + u.UNITLESS, + u.mV, + ]: + data = bst.random.rand(10, 20) + data = data * (data < 0.3) * ux + + csr = u.sparse.CSR.fromdense(data) + + self.assertTrue( + u.math.allclose( + csr.__pos__().data, + csr.data + ) + ) + + def test_neg(self): + for ux in [ + u.ms, + u.UNITLESS, + u.mV, + ]: + data = bst.random.rand(10, 20) + data = data * (data < 0.3) * ux + + csr = u.sparse.CSR.fromdense(data) + + self.assertTrue( + u.math.allclose( + (-csr).data, + -csr.data + ) + ) + + def test_abs(self): + for ux in [ + u.ms, + u.UNITLESS, + u.mV, + ]: + data = bst.random.rand(10, 20) + data = data * (data < 0.3) * ux + + csr = u.sparse.CSR.fromdense(data) + + self.assertTrue( + u.math.allclose( + abs(csr).data, + abs(csr.data) + ) + ) + + def test_add(self): + for ux in [ + u.ms, + u.UNITLESS, + u.mV, + ]: + data1 = bst.random.rand(10, 20) + data1 = data1 * (data1 < 0.3) * ux + + data2 = 2. * ux + + csr1 = u.sparse.CSR.fromdense(data1) + + self.assertTrue( + u.math.allclose( + (csr1 + data2).data, + csr1.data + data2 + ) + ) + + self.assertTrue( + u.math.allclose( + (data2 + csr1).data, + data2 + csr1.data + ) + ) + + def test_sub(self): + for ux in [ + u.ms, + u.UNITLESS, + u.mV, + ]: + data1 = bst.random.rand(10, 20) + data1 = data1 * (data1 < 0.3) * ux + + data2 = 2. * ux + + csr1 = u.sparse.CSR.fromdense(data1) + + self.assertTrue( + u.math.allclose( + (csr1 - data2).data, + csr1.data - data2 + ) + ) + + self.assertTrue( + u.math.allclose( + (data2 - csr1).data, + data2 - csr1.data + ) + ) + + def test_mul(self): + for ux in [ + u.ms, + u.UNITLESS, + u.mV, + ]: + data1 = bst.random.rand(10, 20) + data1 = data1 * (data1 < 0.3) * ux + + data2 = 2. 
* ux + + csr1 = u.sparse.CSR.fromdense(data1) + + self.assertTrue( + u.math.allclose( + (csr1 * data2).data, + csr1.data * data2 + ) + ) + + self.assertTrue( + u.math.allclose( + (data2 * csr1).data, + data2 * csr1.data + ) + ) + + def test_div(self): + for ux in [ + u.ms, + u.UNITLESS, + u.mV, + ]: + data1 = bst.random.rand(10, 20) + data1 = data1 * (data1 < 0.3) * ux + + data2 = 2. * u.ohm + + csr1 = u.sparse.CSR.fromdense(data1) + + self.assertTrue( + u.math.allclose( + (csr1 / data2).data, + csr1.data / data2 + ) + ) + + self.assertTrue( + u.math.allclose( + (data2 / csr1).data, + data2 / csr1.data + ) + ) + + def test_mod(self): + for ux in [ + u.ms, + u.UNITLESS, + u.mV, + ]: + data1 = bst.random.rand(10, 20) + data1 = data1 * (data1 < 0.3) * ux + + data2 = 2. * ux + + csr1 = u.sparse.CSR.fromdense(data1) + + self.assertTrue( + u.math.allclose( + (csr1 % data2).data, + csr1.data % data2 + ) + ) + + self.assertTrue( + u.math.allclose( + (data2 % csr1).data, + data2 % csr1.data + ) + ) + + def test_grad(self): + for ux in [ + u.ms, + u.UNITLESS, + u.mV, + ]: + data1 = bst.random.randn(10, 20) * ux + csr = u.sparse.CSR.fromdense(data1) + + def f(csr_data, x): + return u.get_mantissa((csr.with_data(csr_data) @ x).sum()) + + xs = bst.random.randn(20) + + grads = jax.grad(f)(csr.data, xs) + + +class TestCSC(unittest.TestCase): + def test_matvec(self): + for ux, uy in [ + (u.ms, u.mV), + (u.UNITLESS, u.UNITLESS), + (u.mV, u.UNITLESS), + (u.UNITLESS, u.mV), + ]: + data = bst.random.rand(10, 20) + data = data * (data < 0.3) * ux + + csc = u.sparse.CSC.fromdense(data) + + x = bst.random.random((20,)) * uy + self.assertTrue( + u.math.allclose( + data @ x, + csc @ x + ) + ) + + x = bst.random.random((10,)) * uy + self.assertTrue( + u.math.allclose( + x @ data, + x @ csc + ) + ) + + def test_matmul(self): + for ux, uy in [ + (u.ms, u.mV), + (u.UNITLESS, u.UNITLESS), + (u.mV, u.UNITLESS), + (u.UNITLESS, u.mV), + ]: + data = bst.random.rand(10, 20) + data = data * (data < 0.3) * ux + csr = u.sparse.CSC.fromdense(data) + + data2 = bst.random.rand(20, 30) * uy + + self.assertTrue( + u.math.allclose( + data @ data2, + csr @ data2 + ) + ) + + data2 = bst.random.rand(30, 10) * uy + self.assertTrue( + u.math.allclose( + data2 @ data, + data2 @ csr + ) + ) + + def test_pos(self): + for ux in [ + u.ms, + u.UNITLESS, + u.mV, + ]: + data = bst.random.rand(10, 20) + data = data * (data < 0.3) * ux + + csc = u.sparse.CSC.fromdense(data) + + self.assertTrue( + u.math.allclose( + csc.__pos__().data, + csc.data + ) + ) + + def test_neg(self): + for ux in [ + u.ms, + u.UNITLESS, + u.mV, + ]: + data = bst.random.rand(10, 20) + data = data * (data < 0.3) * ux + + csc = u.sparse.CSC.fromdense(data) + + self.assertTrue( + u.math.allclose( + (-csc).data, + -csc.data + ) + ) + + def test_abs(self): + for ux in [ + u.ms, + u.UNITLESS, + u.mV, + ]: + data = bst.random.rand(10, 20) + data = data * (data < 0.3) * ux + + csc = u.sparse.CSC.fromdense(data) + + self.assertTrue( + u.math.allclose( + abs(csc).data, + abs(csc.data) + ) + ) + + def test_add(self): + for ux in [ + u.ms, + u.UNITLESS, + u.mV, + ]: + data1 = bst.random.rand(10, 20) + data1 = data1 * (data1 < 0.3) * ux + + data2 = 2. 
* ux + + csc1 = u.sparse.CSC.fromdense(data1) + + self.assertTrue( + u.math.allclose( + (csc1 + data2).data, + csc1.data + data2 + ) + ) + + self.assertTrue( + u.math.allclose( + (data2 + csc1).data, + data2 + csc1.data + ) + ) + + def test_sub(self): + for ux in [ + u.ms, + u.UNITLESS, + u.mV, + ]: + data1 = bst.random.rand(10, 20) + data1 = data1 * (data1 < 0.3) * ux + + data2 = 2. * ux + + csc1 = u.sparse.CSC.fromdense(data1) + + self.assertTrue( + u.math.allclose( + (csc1 - data2).data, + csc1.data - data2 + ) + ) + + self.assertTrue( + u.math.allclose( + (data2 - csc1).data, + data2 - csc1.data + ) + ) + + def test_mul(self): + for ux in [ + u.ms, + u.UNITLESS, + u.mV, + ]: + data1 = bst.random.rand(10, 20) + data1 = data1 * (data1 < 0.3) * ux + + data2 = 2. * ux + + csc1 = u.sparse.CSC.fromdense(data1) + + self.assertTrue( + u.math.allclose( + (csc1 * data2).data, + csc1.data * data2 + ) + ) + + self.assertTrue( + u.math.allclose( + (data2 * csc1).data, + data2 * csc1.data + ) + ) + + def test_div(self): + for ux in [ + u.ms, + u.UNITLESS, + u.mV, + ]: + data1 = bst.random.rand(10, 20) + data1 = data1 * (data1 < 0.3) * ux + + data2 = 2. * u.ohm + + csc1 = u.sparse.CSC.fromdense(data1) + + self.assertTrue( + u.math.allclose( + (csc1 / data2).data, + csc1.data / data2 + ) + ) + + self.assertTrue( + u.math.allclose( + (data2 / csc1).data, + data2 / csc1.data + ) + ) + + def test_mod(self): + for ux in [ + u.ms, + u.UNITLESS, + u.mV, + ]: + data1 = bst.random.rand(10, 20) + data1 = data1 * (data1 < 0.3) * ux + + data2 = 2. * ux + + csc1 = u.sparse.CSC.fromdense(data1) + + self.assertTrue( + u.math.allclose( + (csc1 % data2).data, + csc1.data % data2 + ) + ) + + self.assertTrue( + u.math.allclose( + (data2 % csc1).data, + data2 % csc1.data + ) + ) + + def test_grad(self): + for ux in [ + u.ms, + u.UNITLESS, + u.mV, + ]: + data1 = bst.random.randn(10, 20) * ux + csc = u.sparse.CSC.fromdense(data1) + + def f(data, x): + return u.get_mantissa((csc.with_data(data) @ x).sum()) + + xs = bst.random.randn(20) + + grads = jax.grad(f)(csc.data, xs) diff --git a/docs/apis/constants.rst b/docs/apis/brainunit.constants.rst similarity index 100% rename from docs/apis/constants.rst rename to docs/apis/brainunit.constants.rst diff --git a/docs/apis/brainunit.sparse.rst b/docs/apis/brainunit.sparse.rst new file mode 100644 index 0000000..d7a122d --- /dev/null +++ b/docs/apis/brainunit.sparse.rst @@ -0,0 +1,33 @@ +``brainunit.sparse`` module +============================= + +.. currentmodule:: brainunit.sparse +.. automodule:: brainunit.sparse + + +Sparse Data Structures +---------------------- + +.. autosummary:: + :toctree: generated/ + :nosignatures: + + CSR + CSC + COO + + +Sparse Data Operations +---------------------- + +.. autosummary:: + :toctree: generated/ + :nosignatures: + + csr_todense + csr_fromdense + csc_todense + csc_fromdense + coo_todense + coo_fromdense + diff --git a/docs/index.rst b/docs/index.rst index 841debc..e6b0555 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -192,7 +192,8 @@ We are building the `brain dynamics programming ecosystem
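
For reviewers: below is a minimal usage sketch of the API this patch introduces, assembled only from calls exercised in the tests above (`CSR.fromdense`/`COO.fromdense`, `todense`, `@`, the element-wise operators, and `split_mantissa_unit`). The concrete array values and the plain `jax.numpy` construction (instead of the tests' `brainstate.random` helpers) are illustrative assumptions, not part of the diff.

```python
import jax.numpy as jnp
import brainunit as u

# A dense matrix with physical units; most entries are zero.
dense = jnp.array([[1.0, 0.0, 2.0],
                   [0.0, 0.0, 3.0]]) * u.mV

# Dense <-> sparse round-trips keep the unit attached to the data.
csr = u.sparse.CSR.fromdense(dense)
coo = u.sparse.COO.fromdense(dense)
assert u.math.allclose(csr.todense(), dense)
assert u.math.allclose(coo.todense(), dense)

# Units propagate through sparse-dense products:
# mV data times an ms vector yields an mV * ms result.
v = jnp.array([1.0, 2.0, 3.0]) * u.ms
y = csr @ v  # Quantity of shape (2,) carrying mV * ms
print(y)

# Element-wise arithmetic acts on the stored entries and keeps the
# sparsity layout; unit checking matches dense Quantity semantics.
doubled = csr * 2.0         # still a CSR, data in mV
shifted = coo + 0.5 * u.mV  # still a COO; adding a mismatched unit raises

# The new split_mantissa_unit helper exposes the raw mantissa and unit,
# which is what the sparse kernels use internally.
mantissa, unit = u.split_mantissa_unit(dense)
print(mantissa.shape, unit)  # (2, 3) mV
```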