Skip to content

Commit

Permalink
support binary operation between two sparse matrix
Browse files Browse the repository at this point in the history
  • Loading branch information
chaoming0625 committed Dec 14, 2024
1 parent f106941 commit 4ba6918
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 10 deletions.
38 changes: 32 additions & 6 deletions brainunit/sparse/_coo.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,37 +222,63 @@ def __pos__(self):
)

def _binary_op(self, other, op):
if isinstance(other, COO):
if id(self.row) == id(other.row) and id(self.col) == id(other.col):
return COO(
(op(self.data, other.data), self.row, self.col),
shape=self.shape,
rows_sorted=self._rows_sorted,
cols_sorted=self._cols_sorted
)
if isinstance(other, JAXSparse):
raise NotImplementedError("mul between two sparse objects.")
raise NotImplementedError(f"binary operation {op} between two sparse objects.")

other = asarray(other)
if other.size == 1:
return COO(
(op(self.data, other), self.row, self.col),
shape=self.shape
shape=self.shape,
rows_sorted=self._rows_sorted,
cols_sorted=self._cols_sorted
)
elif other.ndim == 2 and other.shape == self.shape:
other = other[self.row, self.col]
return COO(
(op(self.data, other), self.row, self.col),
shape=self.shape
shape=self.shape,
rows_sorted=self._rows_sorted,
cols_sorted=self._cols_sorted
)
else:
raise NotImplementedError(f"mul with object of shape {other.shape}")

def _binary_rop(self, other, op):
if isinstance(other, COO):
if id(self.row) == id(other.row) and id(self.col) == id(other.col):
return COO(
(op(other.data, self.data), self.row, self.col),
shape=self.shape,
rows_sorted=self._rows_sorted,
cols_sorted=self._cols_sorted
)
if isinstance(other, JAXSparse):
raise NotImplementedError("mul between two sparse objects.")
raise NotImplementedError(f"binary operation {op} between two sparse objects.")

other = asarray(other)
if other.size == 1:
return COO(
(op(other, self.data), self.row, self.col),
shape=self.shape
shape=self.shape,
rows_sorted=self._rows_sorted,
cols_sorted=self._cols_sorted
)
elif other.ndim == 2 and other.shape == self.shape:
other = other[self.row, self.col]
return COO(
(op(other, self.data), self.row, self.col),
shape=self.shape
shape=self.shape,
rows_sorted=self._rows_sorted,
cols_sorted=self._cols_sorted
)
else:
raise NotImplementedError(f"mul with object of shape {other.shape}")
Expand Down
3 changes: 3 additions & 0 deletions brainunit/sparse/_coo_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,9 @@ def f(sp, x):

grads = jax.grad(f)(sp, xs)

sp = sp + grads * 1e-3
sp = sp + 1e-3 * grads

def test_jit(self):
@jax.jit
def f(sp, x):
Expand Down
36 changes: 32 additions & 4 deletions brainunit/sparse/_csr.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,15 @@ def __pos__(self):
return CSR((self.data.__pos__(), self.indices, self.indptr), shape=self.shape)

def _binary_op(self, other, op):
if isinstance(other, CSR):
if id(other.indices) == id(self.indices) and id(other.indptr) == id(self.indptr):
return CSR(
(op(self.data, other.data), self.indices, self.indptr),
shape=self.shape
)
if isinstance(other, JAXSparse):
raise NotImplementedError("mul between two sparse objects.")
raise NotImplementedError(f"binary operation {op} between two sparse objects.")

other = asarray(other)
if other.size == 1:
return CSR(
Expand All @@ -139,8 +146,15 @@ def _binary_op(self, other, op):
raise NotImplementedError(f"mul with object of shape {other.shape}")

def _binary_rop(self, other, op):
if isinstance(other, CSR):
if id(other.indices) == id(self.indices) and id(other.indptr) == id(self.indptr):
return CSR(
(op(other.data, self.data), self.indices, self.indptr),
shape=self.shape
)
if isinstance(other, JAXSparse):
raise NotImplementedError("mul between two sparse objects.")
raise NotImplementedError(f"binary operation {op} between two sparse objects.")

other = asarray(other)
if other.size == 1:
return CSR(
Expand Down Expand Up @@ -294,8 +308,15 @@ def __pos__(self):
return CSC((self.data.__pos__(), self.indices, self.indptr), shape=self.shape)

def _binary_op(self, other, op):
if isinstance(other, CSC):
if id(other.indices) == id(self.indices) and id(other.indptr) == id(self.indptr):
return CSC(
(op(self.data, other.data), self.indices, self.indptr),
shape=self.shape
)
if isinstance(other, JAXSparse):
raise NotImplementedError("mul between two sparse objects.")
raise NotImplementedError(f"binary operation {op} between two sparse objects.")

other = asarray(other)
if other.size == 1:
return CSC(
Expand All @@ -313,8 +334,15 @@ def _binary_op(self, other, op):
raise NotImplementedError(f"mul with object of shape {other.shape}")

def _binary_rop(self, other, op):
if isinstance(other, CSC):
if id(other.indices) == id(self.indices) and id(other.indptr) == id(self.indptr):
return CSC(
(op(other.data, self.data), self.indices, self.indptr),
shape=self.shape
)
if isinstance(other, JAXSparse):
raise NotImplementedError("mul between two sparse objects.")
raise NotImplementedError(f"binary operation {op} between two sparse objects.")

other = asarray(other)
if other.size == 1:
return CSC(
Expand Down
6 changes: 6 additions & 0 deletions brainunit/sparse/_csr_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,9 @@ def f(csr, x):

grads = jax.grad(f)(csr, xs)

csr = csr + grads * 1e-3
csr = csr + 1e-3 * grads

def test_jit(self):
@jax.jit
def f(csr, x):
Expand Down Expand Up @@ -596,6 +599,9 @@ def f(csc, x):

grads = jax.grad(f)(csc, xs)

csc = csc + grads * 1e-3
csc = csc + 1e-3 * grads

def test_jit(self):

@jax.jit
Expand Down

0 comments on commit 4ba6918

Please sign in to comment.