Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support binary operation between two sparse matrix #82

Merged
merged 1 commit into from
Dec 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 32 additions & 6 deletions brainunit/sparse/_coo.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,37 +222,63 @@ def __pos__(self):
)

def _binary_op(self, other, op):
if isinstance(other, COO):
if id(self.row) == id(other.row) and id(self.col) == id(other.col):
return COO(
(op(self.data, other.data), self.row, self.col),
shape=self.shape,
rows_sorted=self._rows_sorted,
cols_sorted=self._cols_sorted
)
if isinstance(other, JAXSparse):
raise NotImplementedError("mul between two sparse objects.")
raise NotImplementedError(f"binary operation {op} between two sparse objects.")

other = asarray(other)
if other.size == 1:
return COO(
(op(self.data, other), self.row, self.col),
shape=self.shape
shape=self.shape,
rows_sorted=self._rows_sorted,
cols_sorted=self._cols_sorted
)
elif other.ndim == 2 and other.shape == self.shape:
other = other[self.row, self.col]
return COO(
(op(self.data, other), self.row, self.col),
shape=self.shape
shape=self.shape,
rows_sorted=self._rows_sorted,
cols_sorted=self._cols_sorted
)
else:
raise NotImplementedError(f"mul with object of shape {other.shape}")

def _binary_rop(self, other, op):
if isinstance(other, COO):
if id(self.row) == id(other.row) and id(self.col) == id(other.col):
return COO(
(op(other.data, self.data), self.row, self.col),
shape=self.shape,
rows_sorted=self._rows_sorted,
cols_sorted=self._cols_sorted
)
if isinstance(other, JAXSparse):
raise NotImplementedError("mul between two sparse objects.")
raise NotImplementedError(f"binary operation {op} between two sparse objects.")

other = asarray(other)
if other.size == 1:
return COO(
(op(other, self.data), self.row, self.col),
shape=self.shape
shape=self.shape,
rows_sorted=self._rows_sorted,
cols_sorted=self._cols_sorted
)
elif other.ndim == 2 and other.shape == self.shape:
other = other[self.row, self.col]
return COO(
(op(other, self.data), self.row, self.col),
shape=self.shape
shape=self.shape,
rows_sorted=self._rows_sorted,
cols_sorted=self._cols_sorted
)
else:
raise NotImplementedError(f"mul with object of shape {other.shape}")
Expand Down
3 changes: 3 additions & 0 deletions brainunit/sparse/_coo_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,9 @@ def f(sp, x):

grads = jax.grad(f)(sp, xs)

sp = sp + grads * 1e-3
sp = sp + 1e-3 * grads

def test_jit(self):
@jax.jit
def f(sp, x):
Expand Down
36 changes: 32 additions & 4 deletions brainunit/sparse/_csr.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,15 @@ def __pos__(self):
return CSR((self.data.__pos__(), self.indices, self.indptr), shape=self.shape)

def _binary_op(self, other, op):
if isinstance(other, CSR):
if id(other.indices) == id(self.indices) and id(other.indptr) == id(self.indptr):
return CSR(
(op(self.data, other.data), self.indices, self.indptr),
shape=self.shape
)
if isinstance(other, JAXSparse):
raise NotImplementedError("mul between two sparse objects.")
raise NotImplementedError(f"binary operation {op} between two sparse objects.")

other = asarray(other)
if other.size == 1:
return CSR(
Expand All @@ -139,8 +146,15 @@ def _binary_op(self, other, op):
raise NotImplementedError(f"mul with object of shape {other.shape}")

def _binary_rop(self, other, op):
if isinstance(other, CSR):
if id(other.indices) == id(self.indices) and id(other.indptr) == id(self.indptr):
return CSR(
(op(other.data, self.data), self.indices, self.indptr),
shape=self.shape
)
if isinstance(other, JAXSparse):
raise NotImplementedError("mul between two sparse objects.")
raise NotImplementedError(f"binary operation {op} between two sparse objects.")

other = asarray(other)
if other.size == 1:
return CSR(
Expand Down Expand Up @@ -294,8 +308,15 @@ def __pos__(self):
return CSC((self.data.__pos__(), self.indices, self.indptr), shape=self.shape)

def _binary_op(self, other, op):
if isinstance(other, CSC):
if id(other.indices) == id(self.indices) and id(other.indptr) == id(self.indptr):
return CSC(
(op(self.data, other.data), self.indices, self.indptr),
shape=self.shape
)
if isinstance(other, JAXSparse):
raise NotImplementedError("mul between two sparse objects.")
raise NotImplementedError(f"binary operation {op} between two sparse objects.")

other = asarray(other)
if other.size == 1:
return CSC(
Expand All @@ -313,8 +334,15 @@ def _binary_op(self, other, op):
raise NotImplementedError(f"mul with object of shape {other.shape}")

def _binary_rop(self, other, op):
if isinstance(other, CSC):
if id(other.indices) == id(self.indices) and id(other.indptr) == id(self.indptr):
return CSC(
(op(other.data, self.data), self.indices, self.indptr),
shape=self.shape
)
if isinstance(other, JAXSparse):
raise NotImplementedError("mul between two sparse objects.")
raise NotImplementedError(f"binary operation {op} between two sparse objects.")

other = asarray(other)
if other.size == 1:
return CSC(
Expand Down
6 changes: 6 additions & 0 deletions brainunit/sparse/_csr_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,9 @@ def f(csr, x):

grads = jax.grad(f)(csr, xs)

csr = csr + grads * 1e-3
csr = csr + 1e-3 * grads

def test_jit(self):
@jax.jit
def f(csr, x):
Expand Down Expand Up @@ -596,6 +599,9 @@ def f(csc, x):

grads = jax.grad(f)(csc, xs)

csc = csc + grads * 1e-3
csc = csc + 1e-3 * grads

def test_jit(self):

@jax.jit
Expand Down
Loading