diff --git a/brainunit/sparse/__init__.py b/brainunit/sparse/__init__.py
index 8e51511..4266be7 100644
--- a/brainunit/sparse/__init__.py
+++ b/brainunit/sparse/__init__.py
@@ -13,11 +13,12 @@
 # limitations under the License.
 # ==============================================================================
 
-
-from .coo import COO, coo_todense, coo_fromdense
-from .csr import CSR, CSC, csr_todense, csr_fromdense, csc_fromdense, csc_todense
+from ._coo import COO, coo_todense, coo_fromdense
+from ._csr import CSR, CSC, csr_todense, csr_fromdense, csc_fromdense, csc_todense
+from .._sparse_base import SparseMatrix
 
 __all__ = [
+    "SparseMatrix",
     "CSR", "CSC", "csr_todense", "csr_fromdense", "csc_todense", "csc_fromdense",
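For orientation, a minimal sketch (not part of the diff) of how the re-exported, unit-aware API is used; it mirrors the tests added later in this PR:

```python
import brainstate as bst
import brainunit as u

dense = bst.random.randn(10, 20) * u.mV     # dense data carrying a unit
csr = u.sparse.CSR.fromdense(dense)         # unit-aware CSR matrix
y = csr @ bst.random.randn(20)              # matvec; the result keeps the mV unit
print(u.get_unit(y))
```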
+ """ + data: jax.Array | Quantity # float32[n_blocks, *block_size] + indptr: jax.Array # int32[n_block_rows + 1] + indices: jax.Array # int32[n_blocks] + shape: tuple[int, int] # (n_block_rows * block_size[0], n_block_cols * block_size[1]) + + ndim: int = property(lambda self: len(self.shape)) + num_blocks = property(lambda self: self.data.shape[0]) + block_size = property(lambda self: self.data.shape[1:]) + dtype = property(lambda self: self.data.dtype) + + def __init__(self, args, *, shape: Tuple[int, int]): + blocks, indptr, indices = args + self.data = blocks + self.indptr = indptr + self.indices = indices + super().__init__(args, shape=shape) + + def tree_flatten(self): + return (self.data,), (self.indptr, self.indices, self.shape,) + + @classmethod + def tree_unflatten(cls, data, xs): + blocks, = xs + indptr, indices, shape = data + return BlockCSR((blocks, indptr, indices), shape=shape) + + def _validate(self): + _nblocks, n, m = self.data.shape + nrows = self.indptr.shape[0] - 1 + assert self.indices.shape[0] == _nblocks + assert len(self.shape) == 2 + assert self.shape[0] == n * nrows + assert self.shape[1] % m == 0 + + @jax.jit + def todense(self) -> jax.Array: + self._validate() + return _sdd_todense(self) + + @classmethod + def fromdense(cls, dense: jax.Array, *, block_size) -> 'BlockCSR': + raise NotImplementedError + + def __matmul__(self, other) -> jax.Array: + self._validate() + return sdd_matmul(self, other) + + +@jax.jit +def _sdd_todense(mat: BlockCSR) -> jax.Array: + _, n, m = mat.data.shape + nrows = mat.shape[0] // n + unit = u.get_unit(mat.data) + blocks = u.get_mantissa(mat.data) + + def i_body(i_row, out): # each row + def j_body(x): # each block in the row + i_block, val = x + i_col = mat.indices[i_block] + val = jax.lax.dynamic_update_slice(val, blocks[i_block], (i_row * n, i_col * m)) + return i_block + 1, val + + return jax.lax.while_loop( + lambda x: x[0] < mat.indptr[i_row + 1], + j_body, + (mat.indptr[i_row], out) + )[1] + + dense = jax.lax.fori_loop(0, nrows, i_body, jnp.zeros(mat.shape, mat.dtype)) + return u.maybe_decimal(u.Quantity(dense, unit=unit)) + + +def _check_shape_consistency(x, y): + assert isinstance(y, jax.Array), f"Only support jax.Array. 
But got unsupported type {type(y)}" + assert x.ndim == y.ndim == 2 + assert x.shape[1] == y.shape[0], f"Dimension mismatch: {x.shape} @ {y.shape}" + + +def _sdd_kernel( + x_ref, # [n_blocks, bm, bn] + indices_ref, # [n_block] + indptr_ref, # [n_rows + 1] + y_ref, # [n, k] + o_ref, # [m, k] + *, + bm: int, + bn: int, + bk: int, +): + i_m = pl.program_id(axis=0) + i_k = pl.program_id(axis=1) + i_start = indptr_ref[i_m] + i_end = indptr_ref[i_m + 1] + + def body(x): + val, i_block = x + i_x_col = indices_ref[i_block] + block = pl.load(x_ref, (i_block, pl.dslice(None), pl.dslice(None))) # [bm, bn] + chunk = pl.load(y_ref, (pl.dslice(i_x_col * bn, bn), pl.dslice(i_k * bk, bk))) # [bn, bk] + return val + jnp.dot(block, chunk).astype(o_ref.dtype), i_block + 1 + + acc = jax.lax.while_loop( + lambda x: x[1] < i_end, + body, + (jnp.zeros([bm, bk], dtype=o_ref.dtype), i_start) + )[0] + pl.store(o_ref, (pl.dslice(bm * i_m, bm), pl.dslice(bk * i_k, bk)), acc) # [bm, bk] + + +@functools.partial(jax.jit, static_argnames=["debug", 'interpret', 'block_size']) +def sdd_matmul( + mat1: BlockCSR, + mat2: jax.Array, + *, + debug: bool = False, + interpret: bool = False, + block_size: int = 256, +) -> jax.Array: + _check_shape_consistency(mat1, mat2) + + # shape and dtype + m, n, k = mat1.shape[0], mat1.shape[1], mat2.shape[1] + _, bm, bn = mat1.data.shape + dtype = jnp.result_type(mat1.dtype, mat2.dtype) + + # kernel + fn = pl.pallas_call( + functools.partial(_sdd_kernel, bm=bm, bn=bn, bk=block_size), + out_shape=jax.ShapeDtypeStruct(shape=(m, k), dtype=dtype), + grid=(pl.cdiv(m, bm), pl.cdiv(k, block_size)), + debug=debug, + interpret=interpret + ) + + # call + unita = u.get_unit(mat1.data) + unitb = u.get_unit(mat2) + blocks = u.get_mantissa(mat1.data) + r = fn(blocks, mat1.indices, mat1.indptr, u.get_mantissa(mat2)) + return u.maybe_decimal(u.Quantity(r, unit=unita * unitb)) + + +@jax.jit +def native_sdd_matmul( + mat1: BlockCSR, + mat2: jax.Array, +): + _check_shape_consistency(mat1, mat2) + + dtype = jnp.result_type(mat1.dtype, mat2.dtype) + _, n, m = mat1.data.shape + + nrows = mat1.shape[0] // n + + def i_body(i): # each row + def k_body(x): + i_block, val = x + i_col = mat1.indices[i_block] + chunk = jax.lax.dynamic_slice(mat2, [i_col * m, 0], (m, mat2.shape[1])) # [m, mat2.shape[1]] + block = blocks[i_block] + return i_block + 1, val + block.dot(chunk) + + acc = jax.lax.while_loop( + lambda x: x[0] < mat1.indptr[i + 1], + k_body, + (mat1.indptr[i], jnp.zeros((n, mat2.shape[1]), dtype=jnp.float32)) + )[1] + return acc.astype(dtype) + + unita = u.get_unit(mat1.data) + unitb = u.get_unit(mat2) + blocks = u.get_mantissa(mat1.data) + mat2 = u.get_mantissa(mat2) + + out = jax.vmap(i_body)(jnp.arange(nrows)).reshape((mat1.shape[0], mat2.shape[1])) + return u.maybe_decimal(u.Quantity(out, unit=unita * unitb)) + + +def sample_sparse_matrix( + m, n, bm, bn, *, + sparse_prob=0.2, + dtype=jnp.float32 +) -> BlockCSR: + num_rows = m // bm # number of rows in the Block-ELL matrix + num_cols = n // bn # number of columns in the Block-ELL matrix + blocks_per_row = np.random.binomial(num_cols, sparse_prob, + size=[num_rows]) # [n_rows], number of data in each row + num_blocks = blocks_per_row.sum() + blocks = np.random.randn(num_blocks, bm, bn).astype(dtype) # [n_blocks, bm, bk], block values + + # [n_rows + 1], row pointers + indptr = np.zeros(num_rows + 1, dtype=np.int32) # [n_rows + 1], row pointers + indptr[1:] = np.cumsum(blocks_per_row) + + # [n_block], block indices + indices = [] + for i in range(num_rows): + 
indices.extend(np.random.choice(num_cols, blocks_per_row[i], replace=False)) + indices = jnp.array(indices) # [n_rows, max_num_blocks_per_row, 2], block indices + + return BlockCSR((jnp.asarray(blocks), jnp.asarray(indptr), jnp.asarray(indices)), shape=(m, n)) diff --git a/brainunit/sparse/_block_csr_benchmark.py b/brainunit/sparse/_block_csr_benchmark.py new file mode 100644 index 0000000..93eb3ab --- /dev/null +++ b/brainunit/sparse/_block_csr_benchmark.py @@ -0,0 +1,99 @@ +# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from __future__ import annotations + +import functools +import timeit + +import brainstate as bst +import jax +import jax.numpy as jnp + +from brainunit.sparse._block_csr import sample_sparse_matrix, sdd_matmul, native_sdd_matmul + + +def main(dtype=jnp.float16, sparse_prob=0.2): + bst.random.seed(1234) + # data + m, k, n = 4096, 4096, 4096 + bm, bn, bk = 32, 32, 256 + print(f"Matrix Shape: {m} x {k} x {n}, dtype: {dtype}, sparse_prob: {sparse_prob}") + + x = sample_sparse_matrix(m, k, bm, bn, sparse_prob=sparse_prob, dtype=dtype) + x_dense = x.todense() + y = bst.random.randn(k, n, dtype=dtype) + + # operations + interpret = jax.devices()[0].platform == "cpu" + # sdd_matmul(x, y, debug=False, block_size=bk, interpret=interpret).block_until_ready() + native_matmul = jax.jit(native_sdd_matmul) + pl_matmul = jax.jit(functools.partial(sdd_matmul, block_size=bk, interpret=interpret)) + dense_matmul = jax.jit(jnp.matmul) + native_grad = jax.jit(jax.grad(native_sdd_matmul, argnums=(0, 1))) + pl_grad = jax.jit(jax.grad(functools.partial(sdd_matmul, block_size=bk, interpret=interpret), argnums=(0, 1))) + dense_grad = jax.jit(jax.grad(jnp.matmul, argnums=(0, 1))) + + # compilation + out_pl = pl_matmul(x, y) + out_hlo = native_matmul(x, y) + out_ref = dense_matmul(x_dense, y) + # out_pl = pl_grad(x, y) + # out_hlo = native_grad(x, y) + # out_ref = dense_grad(x_dense, y) + + # print(jnp.max(jnp.abs(out_pl - out_ref))) + # print(jnp.max(jnp.abs(out_pl - out_ref) / jnp.abs(out_pl))) + # np.testing.assert_allclose(out_pl, out_ref, atol=0.04, rtol=0.04) + # np.testing.assert_allclose(out_hlo, out_ref, atol=0.04, rtol=0.04) + + n_trial1, n_trial2 = (10, 2) if interpret else (1000, 20) + duration = timeit.timeit(lambda: dense_matmul(x_dense, y).block_until_ready(), number=n_trial1) + s1_forward = duration / n_trial1 * 1000 + print(f"Dense Matmul, forward: {s1_forward:.2f}ms") + # duration = timeit.timeit(lambda: jax.block_until_ready(dense_grad(x, y)), number=n_trial1) + # s1_backward = duration / n_trial1 * 1000 + # print(f"Dense Matmul, backward: {s1_backward:.2f}ms") + + duration = timeit.timeit(lambda: pl_matmul(x, y).block_until_ready(), number=n_trial1) + s2_forward = duration / n_trial1 * 1000 + print(f"Pallas Blocksparse Matmul, forward: {s2_forward:.2f}ms") + # duration = timeit.timeit(lambda: jax.block_until_ready(pl_grad(x, y)), 
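As a sanity check one could run against this file (not part of the diff; the sizes are arbitrary, and `interpret=True` keeps the Pallas kernel runnable on CPU):

```python
import jax.numpy as jnp
import brainstate as bst
from brainunit.sparse._block_csr import sample_sparse_matrix, sdd_matmul, native_sdd_matmul

mat = sample_sparse_matrix(256, 256, 32, 32, sparse_prob=0.2, dtype=jnp.float32)
y = bst.random.randn(256, 256, dtype=jnp.float32)

ref = mat.todense() @ y                                      # dense reference
out_hlo = native_sdd_matmul(mat, y)                          # pure-JAX block loop
out_pl = sdd_matmul(mat, y, block_size=128, interpret=True)  # Pallas kernel, interpret mode
print(jnp.max(jnp.abs(out_hlo - ref)), jnp.max(jnp.abs(out_pl - ref)))
```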
diff --git a/brainunit/sparse/_block_csr_benchmark.py b/brainunit/sparse/_block_csr_benchmark.py
new file mode 100644
index 0000000..93eb3ab
--- /dev/null
+++ b/brainunit/sparse/_block_csr_benchmark.py
@@ -0,0 +1,99 @@
+# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from __future__ import annotations
+
+import functools
+import timeit
+
+import brainstate as bst
+import jax
+import jax.numpy as jnp
+
+from brainunit.sparse._block_csr import sample_sparse_matrix, sdd_matmul, native_sdd_matmul
+
+
+def main(dtype=jnp.float16, sparse_prob=0.2):
+    bst.random.seed(1234)
+    # data
+    m, k, n = 4096, 4096, 4096
+    bm, bn, bk = 32, 32, 256
+    print(f"Matrix Shape: {m} x {k} x {n}, dtype: {dtype}, sparse_prob: {sparse_prob}")
+
+    x = sample_sparse_matrix(m, k, bm, bn, sparse_prob=sparse_prob, dtype=dtype)
+    x_dense = x.todense()
+    y = bst.random.randn(k, n, dtype=dtype)
+
+    # operations
+    interpret = jax.devices()[0].platform == "cpu"
+    # sdd_matmul(x, y, debug=False, block_size=bk, interpret=interpret).block_until_ready()
+    native_matmul = jax.jit(native_sdd_matmul)
+    pl_matmul = jax.jit(functools.partial(sdd_matmul, block_size=bk, interpret=interpret))
+    dense_matmul = jax.jit(jnp.matmul)
+    native_grad = jax.jit(jax.grad(native_sdd_matmul, argnums=(0, 1)))
+    pl_grad = jax.jit(jax.grad(functools.partial(sdd_matmul, block_size=bk, interpret=interpret), argnums=(0, 1)))
+    dense_grad = jax.jit(jax.grad(jnp.matmul, argnums=(0, 1)))
+
+    # compilation
+    out_pl = pl_matmul(x, y)
+    out_hlo = native_matmul(x, y)
+    out_ref = dense_matmul(x_dense, y)
+    # out_pl = pl_grad(x, y)
+    # out_hlo = native_grad(x, y)
+    # out_ref = dense_grad(x_dense, y)
+
+    # print(jnp.max(jnp.abs(out_pl - out_ref)))
+    # print(jnp.max(jnp.abs(out_pl - out_ref) / jnp.abs(out_pl)))
+    # np.testing.assert_allclose(out_pl, out_ref, atol=0.04, rtol=0.04)
+    # np.testing.assert_allclose(out_hlo, out_ref, atol=0.04, rtol=0.04)
+
+    n_trial1, n_trial2 = (10, 2) if interpret else (1000, 20)
+    duration = timeit.timeit(lambda: dense_matmul(x_dense, y).block_until_ready(), number=n_trial1)
+    s1_forward = duration / n_trial1 * 1000
+    print(f"Dense Matmul, forward: {s1_forward:.2f}ms")
+    # duration = timeit.timeit(lambda: jax.block_until_ready(dense_grad(x, y)), number=n_trial1)
+    # s1_backward = duration / n_trial1 * 1000
+    # print(f"Dense Matmul, backward: {s1_backward:.2f}ms")
+
+    duration = timeit.timeit(lambda: pl_matmul(x, y).block_until_ready(), number=n_trial1)
+    s2_forward = duration / n_trial1 * 1000
+    print(f"Pallas Blocksparse Matmul, forward: {s2_forward:.2f}ms")
+    # duration = timeit.timeit(lambda: jax.block_until_ready(pl_grad(x, y)), number=n_trial1)
+    # s2_backward = duration / n_trial1 * 1000
+    # print(f"Pallas Blocksparse Matmul, backward: {s2_backward:.2f}ms")
+
+    duration = timeit.timeit(lambda: native_matmul(x, y).block_until_ready(), number=n_trial2)
+    s3_forward = duration / n_trial2 * 1000
+    print(f"HLO Blocksparse Matmul, forward: {s3_forward:.2f}ms")
+    # duration = timeit.timeit(lambda: jax.block_until_ready(native_grad(x, y)), number=n_trial2)
+    # s3_backward = duration / n_trial2 * 1000
+    # print(f"HLO Blocksparse Matmul, backward: {s3_backward:.2f}ms")
+
+    print(f"Forward speedup: {s1_forward / s2_forward:.2f}x (Dense vs. Pallas), "
+          f"{s3_forward / s2_forward:.2f}x (HLO vs. Pallas)")
+    # print(f"Backward speedup: {s1_backward / s2_backward:.2f}x (Dense vs. Pallas), "
+    #       f"{s3_backward / s2_backward:.2f}x (HLO vs. Pallas)")
+    print()
+
+
+if __name__ == "__main__":
+    main(jnp.float32, 0.3)
+    main(jnp.float32, 0.2)
+    main(jnp.float32, 0.1)
+    main(jnp.float32, 0.05)
+    main(jnp.float16, 0.3)
+    main(jnp.float16, 0.2)
+    main(jnp.float16, 0.1)
+    main(jnp.float16, 0.05)
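Not part of the diff, but for reference, one way to drive a single benchmark configuration from a Python session (the module path is the one added above; matrix sizes are hard-coded inside `main()`):

```python
import jax.numpy as jnp
from brainunit.sparse import _block_csr_benchmark as bench

# on CPU the script automatically falls back to Pallas interpret mode
bench.main(jnp.float32, sparse_prob=0.1)
```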
+ """ + data: jax.Array | Quantity # float32[n_blocks, *block_size] + blocks_per_row: jax.Array # int32[n_rows] + indices: jax.Array # int32[n_rows, max_num_blocks_per_row, 2] + shape: tuple[int, int] # (n_rows * block_size[0], n_cols * block_size[1]) + + ndim: int = property(lambda self: len(self.shape)) + num_blocks = property(lambda self: self.data.shape[0]) + block_size = property(lambda self: self.data.shape[1:]) + dtype = property(lambda self: self.data.dtype) + + def __init__(self, args, *, shape): + blocks, blocks_per_row, indices = args + self.data = blocks + self.blocks_per_row = blocks_per_row + self.indices = indices + super().__init__(args, shape=shape) + + def tree_flatten(self): + return (self.data,), (self.blocks_per_row, self.indices, self.shape,) + + @classmethod + def tree_unflatten(cls, data, xs): + blocks, = xs + blocks_per_row, indices, shape, = data + return BlockELL((blocks, blocks_per_row, indices), shape=shape) + + def _validate(self): + _nblocks, n, m = self.data.shape + nrows = self.blocks_per_row.shape[0] + assert self.indices.shape[0] == nrows + assert len(self.shape) == 2 + assert self.shape[0] == n * nrows + assert self.shape[1] % m == 0 + + @jax.jit + def todense(self) -> jax.Array: + self._validate() + return _sdd_todense(self) + + @classmethod + def fromdense(cls, dense: jax.Array, *, block_size) -> 'BlockCSR': + nrows, ncols = dense.shape + n, m = block_size + assert nrows % n == 0 + assert ncols % m == 0 + nrows //= n + ncols //= m + + blocks = [] + blocks_per_row = [] + indices = [] + for i in range(nrows): + row_blocks = [] + row_indices = [] + for j in range(ncols): + block = dense[i * n:(i + 1) * n, j * m:(j + 1) * m] + if not jnp.all(block == 0): + row_blocks.append(block) + row_indices.append([j, len(row_blocks) - 1]) + blocks_per_row.append(len(row_blocks)) + blocks.extend(row_blocks) + indices.append(row_indices) + + return cls( + (jnp.asarray(blocks), jnp.asarray(blocks_per_row), jnp.asarray(indices)), + shape=dense.shape + ) + + def __matmul__(self, other) -> jax.Array: + self._validate() + return sdd_matmul(self, other) + + +@jax.jit +def _sdd_todense(mat: BlockELL) -> jax.Array: + _, n, m = mat.data.shape + nrows = mat.shape[0] // n + out = jnp.zeros(mat.shape, mat.dtype) + + def i_body(i, val1): # each row + def j_body(j, val2): # each block in the row + i_col, i_block = mat.indices[i, j] + val2 = jax.lax.dynamic_update_slice(val2, mat.data[i_block], (i * n, i_col * m)) + return val2 + + return jax.lax.fori_loop(0, mat.blocks_per_row[i], j_body, val1) + + return jax.lax.fori_loop(0, nrows, i_body, out) + + +def _check_shape_consistency(x, y): + assert isinstance(y, jax.Array), f"Only support jax.Array. But got unsupported type {type(y)}" + assert x.ndim == y.ndim == 2 + assert x.shape[1] == y.shape[0], f"Dimension mismatch: {x.shape} @ {y.shape}" + + +def _sdd_kernel( + x_ref, # [n_blocks, bm, bn] + indices_ref, # [n_rows, max_num_blocks_per_row, 2] + blocks_per_row_ref, # [n_rows] + y_ref, # [n, k] + o_ref, # [m, k] + *, + bm: int, + bn: int, + bk: int, +): + i_m = pl.program_id(axis=0) + i_k = pl.program_id(axis=1) + n_block_this_row = blocks_per_row_ref[i_m] + + def body(k, val): + i_x_col = indices_ref[i_m, k, 0] + i_block = indices_ref[i_m, k, 1] + # block = x_ref[i_block, ...] 
# [bm, bn] + # chunk = y_ref[i_x_col * bn:(i_x_col + 1) * bn, i_k * bk:(i_k + 1) * bk] # [bn, bk] + block = pl.load(x_ref, (i_block, pl.dslice(None), pl.dslice(None))) # [bm, bn] + chunk = pl.load(y_ref, (pl.dslice(i_x_col * bn, bn), pl.dslice(i_k * bk, bk))) # [bn, bk] + return val + jnp.dot(block, chunk).astype(o_ref.dtype) + + acc = jax.lax.fori_loop(0, n_block_this_row, body, jnp.zeros([bm, bk], dtype=o_ref.dtype)) + pl.store(o_ref, (pl.dslice(bm * i_m, bm), pl.dslice(bk * i_k, bk)), acc) # [bm, bk] + # o_ref[i_m * bm:(i_m + 1) * bm, i_k * bk:(i_k + 1) * bk] = acc + + +@functools.partial(jax.jit, static_argnames=["debug", 'interpret', 'block_size']) +def sdd_matmul( + mat1: BlockELL, + mat2: jax.Array, + *, + debug: bool = False, + interpret: bool = False, + block_size: int = 256, +) -> jax.Array: + _check_shape_consistency(mat1, mat2) + + # shape and dtype + m, n, k = mat1.shape[0], mat1.shape[1], mat2.shape[1] + _, bm, bn = mat1.data.shape + dtype = jnp.result_type(mat1.dtype, mat2.dtype) + + # kernel + fn = pl.pallas_call( + functools.partial(_sdd_kernel, bm=bm, bn=bn, bk=block_size), + out_shape=jax.ShapeDtypeStruct(shape=(m, k), dtype=dtype), + grid=(pl.cdiv(m, bm), pl.cdiv(k, block_size)), + debug=debug, + interpret=interpret + ) + + # call + return fn(mat1.data, mat1.indices, mat1.blocks_per_row, mat2) + + +@jax.jit +def native_sdd_matmul( + mat1: BlockELL, + mat2: jax.Array, +): + _check_shape_consistency(mat1, mat2) + + dtype = jnp.result_type(mat1.dtype, mat2.dtype) + _, n, m = mat1.data.shape + nrows = mat1.shape[0] // n + + def i_body(i): # each row + num_blocks_in_row = mat1.blocks_per_row[i] + + def k_body(k, val): + i_col, i_block = mat1.indices[i, k] + chunk = jax.lax.dynamic_slice(mat2, [i_col * m, 0], (m, mat2.shape[1])) # [m, mat2.shape[1]] + block = mat1.data[i_block] + return val + block.dot(chunk) + + acc = jax.lax.fori_loop(0, num_blocks_in_row, k_body, jnp.zeros((n, mat2.shape[1]), dtype=jnp.float32)) + return acc.astype(dtype) + + out = jax.vmap(i_body)(jnp.arange(nrows)) + return out.reshape((mat1.shape[0], mat2.shape[1])) + + +def sample_sparse_matrix( + m, n, bm, bn, *, + sparse_prob=0.2, + dtype=jnp.float32 +) -> BlockELL: + num_rows = m // bm # number of rows in the Block-ELL matrix + num_cols = n // bn # number of columns in the Block-ELL matrix + blocks_per_row = np.random.binomial(num_cols, sparse_prob, + size=[num_rows]) # [n_rows], number of data in each row + num_blocks = blocks_per_row.sum() + blocks = np.random.randn(num_blocks, bm, bn).astype(dtype) # [n_blocks, bm, bk], block values + + indices = [] + block_index = 0 + max_num_blocks = blocks_per_row.max(axis=0) + for i in range(num_rows): + row = [] + num_blocks_in_row = blocks_per_row[i] + block_indices = np.sort(np.random.permutation(np.arange(num_cols))[:max_num_blocks]) + for j, b in zip(range(max_num_blocks), block_indices): + if j < num_blocks_in_row: + index = [b, block_index] + block_index += 1 + else: + index = [0, 0] + row.append(index) + indices.append(row) + indices = jnp.array(indices) # [n_rows, max_num_blocks_per_row, 2], block indices + + return BlockELL( + (jnp.asarray(blocks), jnp.asarray(blocks_per_row), jnp.asarray(indices)), + shape=(m, n) + ) diff --git a/brainunit/sparse/_block_ell_benchmark.py b/brainunit/sparse/_block_ell_benchmark.py new file mode 100644 index 0000000..8558bfc --- /dev/null +++ b/brainunit/sparse/_block_ell_benchmark.py @@ -0,0 +1,99 @@ +# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved. 
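Below is a small sketch (not in the diff) of the `BlockELL.fromdense`/`todense` round trip as fixed above; the values and block size are arbitrary, and `native_sdd_matmul` is used so no Pallas backend is required:

```python
import jax.numpy as jnp
from brainunit.sparse._block_ell import BlockELL, native_sdd_matmul

dense = jnp.zeros((8, 8), dtype=jnp.float32)
dense = dense.at[0:4, 4:8].set(1.0).at[4:8, 0:4].set(2.0)    # two nonzero 4x4 blocks

mat = BlockELL.fromdense(dense, block_size=(4, 4))
assert jnp.allclose(mat.todense(), dense)                    # round trip
out = native_sdd_matmul(mat, jnp.ones((8, 3), jnp.float32))  # block-sparse @ dense
```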
diff --git a/brainunit/sparse/_block_ell_benchmark.py b/brainunit/sparse/_block_ell_benchmark.py
new file mode 100644
index 0000000..8558bfc
--- /dev/null
+++ b/brainunit/sparse/_block_ell_benchmark.py
@@ -0,0 +1,99 @@
+# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from __future__ import annotations
+
+import functools
+import timeit
+
+import brainstate as bst
+import jax
+import jax.numpy as jnp
+
+from brainunit.sparse._block_ell import sample_sparse_matrix, sdd_matmul, native_sdd_matmul
+
+
+def main(dtype=jnp.float16, sparse_prob=0.2):
+    bst.random.seed(1234)
+    # data
+    m, k, n = 4096, 4096, 4096
+    bm, bn, bk = 32, 32, 256
+    print(f"Matrix Shape: {m} x {k} x {n}, dtype: {dtype}, sparse_prob: {sparse_prob}")
+
+    x = sample_sparse_matrix(m, k, bm, bn, sparse_prob=sparse_prob, dtype=dtype)
+    x_dense = x.todense()
+    y = bst.random.randn(k, n, dtype=dtype)
+
+    # operations
+    interpret = jax.devices()[0].platform == "cpu"
+    # sdd_matmul(x, y, debug=False, block_size=bk, interpret=interpret).block_until_ready()
+    native_matmul = jax.jit(native_sdd_matmul)
+    pl_matmul = jax.jit(functools.partial(sdd_matmul, block_size=bk, interpret=interpret))
+    dense_matmul = jax.jit(jnp.matmul)
+    native_grad = jax.jit(jax.grad(native_sdd_matmul, argnums=(0, 1)))
+    pl_grad = jax.jit(jax.grad(functools.partial(sdd_matmul, block_size=bk, interpret=interpret), argnums=(0, 1)))
+    dense_grad = jax.jit(jax.grad(jnp.matmul, argnums=(0, 1)))
+
+    # compilation
+    out_pl = pl_matmul(x, y)
+    out_hlo = native_matmul(x, y)
+    out_ref = dense_matmul(x_dense, y)
+    # out_pl = pl_grad(x, y)
+    # out_hlo = native_grad(x, y)
+    # out_ref = dense_grad(x_dense, y)
+
+    # print(jnp.max(jnp.abs(out_pl - out_ref)))
+    # print(jnp.max(jnp.abs(out_pl - out_ref) / jnp.abs(out_pl)))
+    # np.testing.assert_allclose(out_pl, out_ref, atol=0.04, rtol=0.04)
+    # np.testing.assert_allclose(out_hlo, out_ref, atol=0.04, rtol=0.04)
+
+    n_trial1, n_trial2 = (10, 2) if interpret else (1000, 20)
+    duration = timeit.timeit(lambda: dense_matmul(x_dense, y).block_until_ready(), number=n_trial1)
+    s1_forward = duration / n_trial1 * 1000
+    print(f"Dense Matmul, forward: {s1_forward:.2f}ms")
+    # duration = timeit.timeit(lambda: jax.block_until_ready(dense_grad(x, y)), number=n_trial1)
+    # s1_backward = duration / n_trial1 * 1000
+    # print(f"Dense Matmul, backward: {s1_backward:.2f}ms")
+
+    duration = timeit.timeit(lambda: pl_matmul(x, y).block_until_ready(), number=n_trial1)
+    s2_forward = duration / n_trial1 * 1000
+    print(f"Pallas Blocksparse Matmul, forward: {s2_forward:.2f}ms")
+    # duration = timeit.timeit(lambda: jax.block_until_ready(pl_grad(x, y)), number=n_trial1)
+    # s2_backward = duration / n_trial1 * 1000
+    # print(f"Pallas Blocksparse Matmul, backward: {s2_backward:.2f}ms")
+
+    duration = timeit.timeit(lambda: native_matmul(x, y).block_until_ready(), number=n_trial2)
+    s3_forward = duration / n_trial2 * 1000
+    print(f"HLO Blocksparse Matmul, forward: {s3_forward:.2f}ms")
+    # duration = timeit.timeit(lambda: jax.block_until_ready(native_grad(x, y)), number=n_trial2)
+    # s3_backward = duration / n_trial2 * 1000
+    # print(f"HLO Blocksparse Matmul, backward: {s3_backward:.2f}ms")
+
+    print(f"Forward speedup: {s1_forward / s2_forward:.2f}x (Dense vs. Pallas), "
+          f"{s3_forward / s2_forward:.2f}x (HLO vs. Pallas)")
+    # print(f"Backward speedup: {s1_backward / s2_backward:.2f}x (Dense vs. Pallas), "
+    #       f"{s3_backward / s2_backward:.2f}x (HLO vs. Pallas)")
+    print()
+
+
+if __name__ == "__main__":
+    main(jnp.float32, 0.3)
+    main(jnp.float32, 0.2)
+    main(jnp.float32, 0.1)
+    main(jnp.float32, 0.05)
+    main(jnp.float16, 0.3)
+    main(jnp.float16, 0.2)
+    main(jnp.float16, 0.1)
+    main(jnp.float16, 0.05)
diff --git a/brainunit/sparse/coo.py b/brainunit/sparse/_coo.py
similarity index 98%
rename from brainunit/sparse/coo.py
rename to brainunit/sparse/_coo.py
index a088a42..be1d0aa 100644
--- a/brainunit/sparse/coo.py
+++ b/brainunit/sparse/_coo.py
@@ -177,19 +177,24 @@ def transpose(self, axes: Tuple[int, ...] | None = None) -> COO:
                    rows_sorted=self._cols_sorted, cols_sorted=self._rows_sorted)
 
     def tree_flatten(self) -> Tuple[
-        Tuple[jax.Array | Quantity, jax.Array, jax.Array], dict[str, Any]
+        Tuple[jax.Array | Quantity,], dict[str, Any]
     ]:
-        return (self.data, self.row, self.col), self._info._asdict()
+        aux = self._info._asdict()
+        aux['row'] = self.row
+        aux['col'] = self.col
+        return (self.data,), aux
 
     @classmethod
     def tree_unflatten(cls, aux_data, children):
        obj = object.__new__(cls)
-       obj.data, obj.row, obj.col = children
-       if aux_data.keys() != {'shape', 'rows_sorted', 'cols_sorted'}:
+       obj.data, = children
+       if aux_data.keys() != {'shape', 'rows_sorted', 'cols_sorted', 'row', 'col'}:
            raise ValueError(f"COO.tree_unflatten: invalid {aux_data=}")
        obj.shape = aux_data['shape']
        obj._rows_sorted = aux_data['rows_sorted']
        obj._cols_sorted = aux_data['cols_sorted']
+       obj.row = aux_data['row']
+       obj.col = aux_data['col']
        return obj
 
     def __abs__(self):
diff --git a/brainunit/sparse/coo_test.py b/brainunit/sparse/_coo_test.py
similarity index 85%
rename from brainunit/sparse/coo_test.py
rename to brainunit/sparse/_coo_test.py
index e771616..0a4beaa 100644
--- a/brainunit/sparse/coo_test.py
+++ b/brainunit/sparse/_coo_test.py
@@ -19,6 +19,7 @@
 import unittest
 
 import brainstate as bst
+import jax
 
 import brainunit as u
 
@@ -268,3 +269,51 @@ def test_mod(self):
                 data2 % coo1.data
             )
         )
+
+    def test_grad(self):
+        for ux in [
+            u.ms,
+            u.UNITLESS,
+            u.mV,
+        ]:
+            data1 = bst.random.randn(10, 20) * ux
+            sp = u.sparse.COO.fromdense(data1)
+
+            def f(data, x):
+                return u.get_mantissa((sp.with_data(data) @ x).sum())
+
+            xs = bst.random.randn(20)
+
+            grads = jax.grad(f)(sp.data, xs)
+
+    def test_grad2(self):
+        for ux in [
+            u.ms,
+            u.UNITLESS,
+            u.mV,
+        ]:
+            data1 = bst.random.randn(10, 20) * ux
+            sp = u.sparse.COO.fromdense(data1)
+
+            def f(sp, x):
+                return u.get_mantissa((sp @ x).sum())
+
+            xs = bst.random.randn(20)
+
+            grads = jax.grad(f)(sp, xs)
+
+    def test_jit(self):
+        @jax.jit
+        def f(sp, x):
+            return sp @ x
+
+        for ux in [
+            u.ms,
+            u.UNITLESS,
+            u.mV,
+        ]:
+            data1 = bst.random.randn(10, 20) * ux
+            sp = u.sparse.COO.fromdense(data1)
+
+            xs = bst.random.randn(20)
+            ys = f(sp, xs)
self.indices, "indptr": self.indptr} @classmethod def tree_unflatten(cls, aux_data, children): obj = object.__new__(cls) - obj.data, obj.indices, obj.indptr = children - if aux_data.keys() != {'shape'}: + obj.data, = children + if aux_data.keys() != {'shape', 'indices', 'indptr'}: raise ValueError(f"CSR.tree_unflatten: invalid {aux_data=}") obj.__dict__.update(**aux_data) return obj @@ -410,14 +410,14 @@ def __rmatmul__(self, other): raise NotImplementedError(f"matmul with object of shape {other.shape}") def tree_flatten(self): - return (self.data, self.indices, self.indptr), {"shape": self.shape} + return (self.data,), {"shape": self.shape, "indices": self.indices, "indptr": self.indptr} @classmethod def tree_unflatten(cls, aux_data, children): obj = object.__new__(cls) - obj.data, obj.indices, obj.indptr = children - if aux_data.keys() != {'shape'}: - raise ValueError(f"CSC.tree_unflatten: invalid {aux_data=}") + obj.data, = children + if aux_data.keys() != {'shape', 'indices', 'indptr'}: + raise ValueError(f"CSR.tree_unflatten: invalid {aux_data=}") obj.__dict__.update(**aux_data) return obj diff --git a/brainunit/sparse/csr_test.py b/brainunit/sparse/_csr_test.py similarity index 90% rename from brainunit/sparse/csr_test.py rename to brainunit/sparse/_csr_test.py index 09f11ab..981a4cf 100644 --- a/brainunit/sparse/csr_test.py +++ b/brainunit/sparse/_csr_test.py @@ -285,6 +285,38 @@ def f(csr_data, x): grads = jax.grad(f)(csr.data, xs) + def test_grad2(self): + for ux in [ + u.ms, + u.UNITLESS, + u.mV, + ]: + data1 = bst.random.randn(10, 20) * ux + csr = u.sparse.CSR.fromdense(data1) + + def f(csr, x): + return u.get_mantissa((csr @ x).sum()) + + xs = bst.random.randn(20) + + grads = jax.grad(f)(csr, xs) + + def test_jit(self): + @jax.jit + def f(csr, x): + return csr @ x + + for ux in [ + u.ms, + u.UNITLESS, + u.mV, + ]: + data1 = bst.random.randn(10, 20) * ux + csr = u.sparse.CSR.fromdense(data1) + + xs = bst.random.randn(20) + ys = f(csr, xs) + class TestCSC(unittest.TestCase): def test_matvec(self): @@ -547,3 +579,36 @@ def f(data, x): xs = bst.random.randn(20) grads = jax.grad(f)(csc.data, xs) + + def test_grad2(self): + for ux in [ + u.ms, + u.UNITLESS, + u.mV, + ]: + data1 = bst.random.randn(10, 20) * ux + csc = u.sparse.CSC.fromdense(data1) + + def f(csc, x): + return u.get_mantissa((csc @ x).sum()) + + xs = bst.random.randn(20) + + grads = jax.grad(f)(csc, xs) + + def test_jit(self): + + @jax.jit + def f(csc, x): + return csc @ x + + for ux in [ + u.ms, + u.UNITLESS, + u.mV, + ]: + data1 = bst.random.randn(10, 20) * ux + csc = u.sparse.CSC.fromdense(data1) + + xs = bst.random.randn(20) + ys = f(csc, xs)