Skip to content

Commit

Permalink
support unit-aware sparse computation: CSR, CSC, COO (#80)
Browse files Browse the repository at this point in the history
* support unit-aware sparse computation

* add documentation

* update sparse doc

* fix typing

* remove deprecated `jnp.round_`
  • Loading branch information
chaoming0625 authored Dec 14, 2024
1 parent 0109799 commit edeed6c
Show file tree
Hide file tree
Showing 13 changed files with 2,184 additions and 12 deletions.
3 changes: 2 additions & 1 deletion brainunit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from . import lax
from . import linalg
from . import math
from . import sparse
from ._base import *
from ._base import __all__ as _base_all
from ._celsius import *
Expand All @@ -35,7 +36,7 @@
from .constants import __all__ as _constants_all

__all__ = (
['math', 'linalg', 'autograd', 'fft', 'constants'] +
['math', 'linalg', 'autograd', 'fft', 'constants', 'sparse'] +
_common_all +
_std_units_all +
_constants_all +
Expand Down
38 changes: 38 additions & 0 deletions brainunit/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from jax.tree_util import register_pytree_node_class

from ._misc import set_module_as
from ._sparse_base import SparseMatrix

__all__ = [
# three base objects
Expand All @@ -50,6 +51,7 @@
'get_mantissa',
'get_magnitude',
'display_in_unit',
'split_mantissa_unit',
'maybe_decimal',

# functions for checking
Expand Down Expand Up @@ -717,6 +719,26 @@ def get_mantissa(obj):
get_magnitude = get_mantissa


@set_module_as('brainunit')
def split_mantissa_unit(obj):
    """
    Split a value into its mantissa and its physical unit.

    Parameters
    ----------
    obj : `object`
        The object to split. It is first converted with ``_to_quantity``,
        so plain numbers/arrays are accepted as well (presumably treated as
        unitless quantities — verify against ``_to_quantity``).

    Returns
    -------
    mantissa : `float` or `array_like`
        The mantissa of the `obj`.
    unit : Unit
        The physical unit of the `obj`.
    """
    # Normalize to a Quantity so both Quantity and raw array-like inputs work.
    obj = _to_quantity(obj)
    return obj.mantissa, obj.unit


@set_module_as('brainunit')
def have_same_dim(obj1, obj2) -> bool:
"""Test if two values have the same dimensions.
Expand Down Expand Up @@ -3033,6 +3055,8 @@ def _binary_operation(
return r

    def __add__(self, oc):
        # self + oc.  A sparse operand must perform the addition itself
        # (it knows its own storage layout), so delegate to its __radd__.
        if isinstance(oc, SparseMatrix):
            return oc.__radd__(self)
        return self._binary_operation(oc, operator.add, fail_for_mismatch=True, operator_str="+")

def __radd__(self, oc):
Expand All @@ -3043,6 +3067,8 @@ def __iadd__(self, oc):
return self._binary_operation(oc, operator.add, fail_for_mismatch=True, operator_str="+=", inplace=True)

    def __sub__(self, oc):
        # self - oc.  Delegate to the sparse operand's __rsub__ when oc is
        # a SparseMatrix; otherwise units must match (fail_for_mismatch).
        if isinstance(oc, SparseMatrix):
            return oc.__rsub__(self)
        return self._binary_operation(oc, operator.sub, fail_for_mismatch=True, operator_str="-")

def __rsub__(self, oc):
Expand All @@ -3053,6 +3079,8 @@ def __isub__(self, oc):
return self._binary_operation(oc, operator.sub, fail_for_mismatch=True, operator_str="-=", inplace=True)

    def __mul__(self, oc):
        # self * oc.  Sparse operands handle the product themselves; for the
        # dense path, units multiply (second operator.mul combines the units).
        if isinstance(oc, SparseMatrix):
            return oc.__rmul__(self)
        r = self._binary_operation(oc, operator.mul, operator.mul)
        return maybe_decimal(r)

Expand All @@ -3065,6 +3093,8 @@ def __imul__(self, oc):

    def __div__(self, oc):
        # self / oc (legacy Py2-style name; __truediv__ routes here).
        # Sparse operands handle the division themselves via __rdiv__.
        if isinstance(oc, SparseMatrix):
            return oc.__rdiv__(self)
        r = self._binary_operation(oc, operator.truediv, operator.truediv)
        return maybe_decimal(r)

Expand All @@ -3073,6 +3103,8 @@ def __idiv__(self, oc):

    def __truediv__(self, oc):
        # self / oc — delegate sparse operands to their __rtruediv__,
        # otherwise reuse the shared __div__ implementation.
        if isinstance(oc, SparseMatrix):
            return oc.__rtruediv__(self)
        return self.__div__(oc)

def __rdiv__(self, oc):
Expand All @@ -3092,6 +3124,8 @@ def __itruediv__(self, oc):

    def __floordiv__(self, oc):
        # self // oc.  Mantissas use floordiv but units combine with true
        # division (floor has no meaning for units).
        if isinstance(oc, SparseMatrix):
            return oc.__rfloordiv__(self)
        r = self._binary_operation(oc, operator.floordiv, operator.truediv)
        return maybe_decimal(r)

Expand All @@ -3108,6 +3142,8 @@ def __ifloordiv__(self, oc):

    def __mod__(self, oc):
        # self % oc.  The result keeps the left operand's unit (lambda
        # returns ua) and both operands must share dimensions.
        if isinstance(oc, SparseMatrix):
            return oc.__rmod__(self)
        r = self._binary_operation(oc, operator.mod, lambda ua, ub: ua, fail_for_mismatch=True, operator_str=r"%")
        return maybe_decimal(r)

Expand All @@ -3127,6 +3163,8 @@ def __rdivmod__(self, oc):
return self.__rfloordiv__(oc), self.__rmod__(oc)

    def __matmul__(self, oc):
        # self @ oc.  Sparse operands implement the product; for the dense
        # path units multiply, like elementwise multiplication.
        if isinstance(oc, SparseMatrix):
            return oc.__rmatmul__(self)
        r = self._binary_operation(oc, operator.matmul, operator.mul, operator_str="@")
        return maybe_decimal(r)

Expand Down
97 changes: 97 additions & 0 deletions brainunit/_sparse_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from __future__ import annotations

from typing import Sequence

from jax.experimental.sparse import JAXSparse

__all__ = [
"SparseMatrix"
]


class SparseMatrix(JAXSparse):
    """Base class for unit-aware sparse matrices (CSR, CSC, COO).

    The operators below are deliberately *not* abstract methods, because not
    every sparse format implements every operation; each raises
    ``NotImplementedError`` until a subclass overrides it.
    """

    def with_data(self, data):
        """Return a matrix with the same sparsity pattern but new ``data``."""
        # Fix: the message previously named a non-existent "assign_data".
        raise NotImplementedError(f"{self.__class__}.with_data")

    def sum(self, axis: int | Sequence[int] | None = None):
        """Sum of the stored entries.

        Only a full reduction (``axis=None``) is supported; per-axis sums
        are not implemented for any format.
        """
        if axis is not None:
            # Fix: message previously hard-coded "CSR" even though this base
            # class also serves CSC and COO.
            raise NotImplementedError(f"{self.__class__}.sum with axis is not implemented.")
        return self.data.sum()

    def __abs__(self):
        raise NotImplementedError(f"{self.__class__}.__abs__")

    def __neg__(self):
        raise NotImplementedError(f"{self.__class__}.__neg__")

    def __pos__(self):
        raise NotImplementedError(f"{self.__class__}.__pos__")

    def __matmul__(self, other):
        raise NotImplementedError(f"{self.__class__}.__matmul__")

    def __rmatmul__(self, other):
        raise NotImplementedError(f"{self.__class__}.__rmatmul__")

    def __mul__(self, other):
        raise NotImplementedError(f"{self.__class__}.__mul__")

    def __rmul__(self, other):
        raise NotImplementedError(f"{self.__class__}.__rmul__")

    def __add__(self, other):
        raise NotImplementedError(f"{self.__class__}.__add__")

    def __radd__(self, other):
        raise NotImplementedError(f"{self.__class__}.__radd__")

    def __sub__(self, other):
        raise NotImplementedError(f"{self.__class__}.__sub__")

    def __rsub__(self, other):
        raise NotImplementedError(f"{self.__class__}.__rsub__")

    def __div__(self, other):
        raise NotImplementedError(f"{self.__class__}.__div__")

    def __rdiv__(self, other):
        raise NotImplementedError(f"{self.__class__}.__rdiv__")

    def __truediv__(self, other):
        raise NotImplementedError(f"{self.__class__}.__truediv__")

    def __rtruediv__(self, other):
        raise NotImplementedError(f"{self.__class__}.__rtruediv__")

    def __floordiv__(self, other):
        raise NotImplementedError(f"{self.__class__}.__floordiv__")

    def __rfloordiv__(self, other):
        raise NotImplementedError(f"{self.__class__}.__rfloordiv__")

    def __mod__(self, other):
        raise NotImplementedError(f"{self.__class__}.__mod__")

    def __rmod__(self, other):
        raise NotImplementedError(f"{self.__class__}.__rmod__")

    def __getitem__(self, item):
        raise NotImplementedError(f"{self.__class__}.__getitem__")
37 changes: 28 additions & 9 deletions brainunit/math/_fun_keep_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import jax
import jax.numpy as jnp
from jax._src.numpy.util import promote_dtypes as _promote_dtypes
import numpy as np

from ._fun_array_creation import asarray
Expand Down Expand Up @@ -63,7 +64,7 @@
# math funcs keep unit (binary)
'fmod', 'mod', 'copysign', 'remainder',
'maximum', 'minimum', 'fmax', 'fmin', 'lcm', 'gcd', 'trace',
'add', 'subtract', 'nextafter',
'add', 'subtract', 'nextafter', 'promote_dtypes',

# math funcs keep unit
'interp', 'clip', 'histogram',
Expand Down Expand Up @@ -492,6 +493,27 @@ def broadcast_arrays(
return _broadcast_fun(jnp.broadcast_arrays, *args)


@set_module_as('brainunit.math')
def promote_dtypes(
    *args: Union[Quantity, jax.typing.ArrayLike]
) -> Union[Quantity | jax.Array | Sequence[jax.Array | Quantity]]:
    """
    Cast all inputs to a common, promoted data type.

    Promotion itself is delegated to JAX's ``promote_dtypes`` through the
    ``_broadcast_fun`` wrapper, which presumably handles Quantity inputs so
    that their units are preserved (verify against ``_broadcast_fun``).

    Parameters
    ----------
    `*args` : array_likes
        Quantities or array-likes whose dtypes should be promoted.

    Returns
    -------
    promoted : list of arrays
        The input values, each converted to the promoted common dtype.
    """
    promoted = _broadcast_fun(_promote_dtypes, *args)
    return promoted


@set_module_as('brainunit.math')
def broadcast_to(
array: Union[Quantity, jax.typing.ArrayLike],
Expand Down Expand Up @@ -3371,13 +3393,12 @@ def round_(
-------
out : jax.Array
"""
return _fun_keep_unit_unary(jnp.round_, x)
return _fun_keep_unit_unary(jnp.round, x)


@set_module_as('brainunit.math')
def around(
def round(
x: Union[Quantity, jax.typing.ArrayLike],
decimals: int = 0,
) -> jax.Array | Quantity:
"""
Round an array to the nearest integer.
Expand All @@ -3386,18 +3407,16 @@ def around(
----------
x : array_like, Quantity
Input array.
decimals : int, optional
Number of decimal places to round to (default is 0).
Returns
-------
out : jax.Array
"""
return _fun_keep_unit_unary(jnp.around, x, decimals=decimals)
return _fun_keep_unit_unary(jnp.round, x)


@set_module_as('brainunit.math')
def round(
def around(
x: Union[Quantity, jax.typing.ArrayLike],
decimals: int = 0,
) -> jax.Array | Quantity:
Expand All @@ -3415,7 +3434,7 @@ def round(
-------
out : jax.Array
"""
return _fun_keep_unit_unary(jnp.round, x, decimals=decimals)
return _fun_keep_unit_unary(jnp.around, x, decimals=decimals)


@set_module_as('brainunit.math')
Expand Down
24 changes: 23 additions & 1 deletion brainunit/math/_fun_remove_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

__all__ = [
# math funcs remove unit (unary)
'iscomplexobj', 'heaviside', 'signbit', 'sign', 'bincount', 'digitize',
'iscomplexobj', 'heaviside', 'signbit', 'sign', 'bincount', 'digitize', 'get_promote_dtypes',

# logic funcs (unary)
'all', 'any', 'logical_not',
Expand All @@ -43,6 +43,28 @@
# math funcs remove unit (unary)
# ------------------------------


@set_module_as('brainunit.math')
def get_promote_dtypes(
    *args: Union[Quantity, jax.typing.ArrayLike]
) -> jnp.dtype:
    """
    Return the dtype obtained by promoting all inputs to a common type.

    Unlike ``promote_dtypes``, no values are converted; only the promoted
    dtype is computed (so the result naturally carries no unit).

    Parameters
    ----------
    `*args` : array_likes
        Quantities or array-likes.  Pytrees (including ``Quantity``) are
        flattened to their leaves before promotion.

    Returns
    -------
    promoted : jnp.dtype
        The common dtype all inputs would be promoted to.

    Raises
    ------
    ValueError
        If no array leaves are found in ``args``.
    """
    from functools import reduce

    # Flatten pytrees so plain arrays and Quantities are handled uniformly.
    leaves = jax.tree.leaves(args)
    if not leaves:
        raise ValueError("get_promote_dtypes() requires at least one input.")
    # `jnp.promote_types` accepts exactly two dtype-likes, so take each
    # leaf's dtype first and fold pairwise.  This fixes the original
    # `jnp.promote_types(*leaves)`, which failed for one leaf, for more
    # than two leaves, and for array (non-dtype-like) leaves.
    dtypes = [jnp.result_type(leaf) for leaf in leaves]
    return reduce(jnp.promote_types, dtypes)


def _fun_remove_unit_unary(func, x, *args, **kwargs):
if isinstance(x, Quantity):
# x = x.factorless()
Expand Down
25 changes: 25 additions & 0 deletions brainunit/sparse/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================


from .coo import COO, coo_todense, coo_fromdense
from .csr import CSR, CSC, csr_todense, csr_fromdense, csc_fromdense, csc_todense

__all__ = [
"CSR", "CSC",
"csr_todense", "csr_fromdense",
"csc_todense", "csc_fromdense",
"COO", "coo_todense", "coo_fromdense"
]
Loading

0 comments on commit edeed6c

Please sign in to comment.