diff --git a/.github/workflows/CI-models.yml b/.github/workflows/CI-models.yml
index 523c5c2a7..b07008e78 100644
--- a/.github/workflows/CI-models.yml
+++ b/.github/workflows/CI-models.yml
@@ -22,7 +22,7 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        python-version: [ "3.8", "3.9", "3.10", "3.11"]
+        python-version: [ "3.9", "3.10", "3.11"]
 
     steps:
     - uses: actions/checkout@v4
@@ -69,7 +69,7 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        python-version: ["3.8", "3.9", "3.10", "3.11"]
+        python-version: [ "3.9", "3.10", "3.11"]
 
     steps:
     - uses: actions/checkout@v4
diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml
index 8f81ae5a5..1ccb482a5 100644
--- a/.github/workflows/CI.yml
+++ b/.github/workflows/CI.yml
@@ -24,7 +24,7 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        python-version: [ "3.8", "3.9", "3.10", "3.11"]
+        python-version: [ "3.9", "3.10", "3.11"]
 
     steps:
     - uses: actions/checkout@v4
@@ -90,7 +90,7 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        python-version: ["3.8", "3.9", "3.10", "3.11"]
+        python-version: ["3.9", "3.10", "3.11"]
 
     steps:
     - uses: actions/checkout@v4
diff --git a/README.md b/README.md
index fd527ef0b..4dc098b67 100644
--- a/README.md
+++ b/README.md
@@ -64,17 +64,21 @@
 $ docker run -it --platform linux/amd64 brainpy/brainpy:latest
 ```
 
+
 ### Using BrainPy with Binder
 
 We provide a Binder environment for BrainPy. You can use the following button to launch the environment:
 
 [![Binder](https://mybinder.org/badge_logo.svg)](https://mybinder.org/v2/gh/brainpy/BrainPy-binder/main)
 
+
 ## Ecosystem
 
 - **[BrainPy](https://github.com/brainpy/BrainPy)**: The solution for the general-purpose brain dynamics programming.
 - **[brainpy-examples](https://github.com/brainpy/examples)**: Comprehensive examples of BrainPy computation.
 - **[brainpy-datasets](https://github.com/brainpy/datasets)**: Neuromorphic and Cognitive Datasets for Brain Dynamics Modeling.
+- [第一届神经计算建模与编程培训班](https://github.com/brainpy/1st-neural-modeling-and-programming-course)
+
 
 
 ## Citing
diff --git a/brainpy/_src/math/__init__.py b/brainpy/_src/math/__init__.py
index 5ca00da42..cdeed51ec 100644
--- a/brainpy/_src/math/__init__.py
+++ b/brainpy/_src/math/__init__.py
@@ -30,6 +30,8 @@
 #
 
+from . import brainpylib_check
+
 # data structure
 from .ndarray import *
 from .delayvars import *
@@ -59,3 +61,4 @@
 from .modes import *
 from .environment import *
 
+del brainpylib_check
diff --git a/brainpy/_src/math/brainpylib_check.py b/brainpy/_src/math/brainpylib_check.py
new file mode 100644
index 000000000..95e029471
--- /dev/null
+++ b/brainpy/_src/math/brainpylib_check.py
@@ -0,0 +1,29 @@
+from jax.lib import xla_client
+
+# Register the CPU XLA custom calls
+try:
+  import brainpylib
+  from brainpylib import cpu_ops
+
+  for _name, _value in cpu_ops.registrations().items():
+    xla_client.register_custom_call_target(_name, _value, platform="cpu")
+except ImportError:
+  cpu_ops = None
+  brainpylib = None
+
+# Register the GPU XLA custom calls
+try:
+  from brainpylib import gpu_ops
+
+  for _name, _value in gpu_ops.registrations().items():
+    xla_client.register_custom_call_target(_name, _value, platform="gpu")
+except ImportError:
+  gpu_ops = None
+
+# check brainpy and brainpylib version consistency
+_minimal_brainpylib_version = '0.1.10'
+if brainpylib is not None:
+  if brainpylib.__version__ < _minimal_brainpylib_version:
+    raise SystemError(f'This version of brainpy needs brainpylib >= {_minimal_brainpylib_version}.')
+  if hasattr(brainpylib, 'check_brainpy_version'):
+    brainpylib.check_brainpy_version()
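For context: `brainpylib_check.py` runs once at `import brainpy.math` time, registers every custom call exported by `brainpylib`, and is then dropped from the package namespace by the `del brainpylib_check` line above. A minimal, self-contained sketch of the same registration pattern, using a hypothetical compiled extension named `my_ops` instead of `brainpylib`:

```python
from jax.lib import xla_client

try:
    import my_ops  # hypothetical extension module exposing registrations()

    # Each entry maps a custom-call target name to the capsule implementing it.
    for name, capsule in my_ops.registrations().items():
        xla_client.register_custom_call_target(name, capsule, platform="cpu")
except ImportError:
    my_ops = None  # degrade gracefully when the optional extension is absent
```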
diff --git a/brainpy/_src/math/ndarray.py b/brainpy/_src/math/ndarray.py
index cb5e739e4..0c9bf8f54 100644
--- a/brainpy/_src/math/ndarray.py
+++ b/brainpy/_src/math/ndarray.py
@@ -1510,11 +1510,11 @@ def cpu(self):
 
   # dtype exchanging #
   # ---------------- #
 
-  def bool(self): return jnp.asarray(self.value, dtypt=jnp.bool_)
-  def int(self): return jnp.asarray(self.value, dtypt=jnp.int32)
-  def long(self): return jnp.asarray(self.value, dtypt=jnp.int64)
-  def half(self): return jnp.asarray(self.value, dtypt=jnp.float16)
-  def float(self): return jnp.asarray(self.value, dtypt=jnp.float32)
+  def bool(self): return jnp.asarray(self.value, dtype=jnp.bool_)
+  def int(self): return jnp.asarray(self.value, dtype=jnp.int32)
+  def long(self): return jnp.asarray(self.value, dtype=jnp.int64)
+  def half(self): return jnp.asarray(self.value, dtype=jnp.float16)
+  def float(self): return jnp.asarray(self.value, dtype=jnp.float32)
   def double(self): return jnp.asarray(self.value, dtype=jnp.float64)
diff --git a/brainpy/_src/math/object_transform/autograd.py b/brainpy/_src/math/object_transform/autograd.py
index f8dd1d8f8..299ed4202 100644
--- a/brainpy/_src/math/object_transform/autograd.py
+++ b/brainpy/_src/math/object_transform/autograd.py
@@ -19,7 +19,7 @@
 
 from brainpy import tools, check
 from brainpy._src.math.ndarray import Array, _as_jax_array_
-from ._tools import (
+from .tools import (
   dynvar_deprecation,
   node_deprecation,
   get_stack_cache,
diff --git a/brainpy/_src/math/object_transform/controls.py b/brainpy/_src/math/object_transform/controls.py
index 61c7b7f0d..39032da84 100644
--- a/brainpy/_src/math/object_transform/controls.py
+++ b/brainpy/_src/math/object_transform/controls.py
@@ -13,7 +13,7 @@
 from brainpy import errors, tools
 from brainpy._src.math.interoperability import as_jax
 from brainpy._src.math.ndarray import (Array, )
-from ._tools import (
+from .tools import (
   evaluate_dyn_vars,
   evaluate_dyn_vars_with_cache,
   dynvar_deprecation,
@@ -31,7 +31,6 @@
   VariableStack,
   new_transform,
   current_transform_number,
-  transform_stack,
 )
 
 __all__ = [
diff --git a/brainpy/_src/math/object_transform/jit.py b/brainpy/_src/math/object_transform/jit.py
index f8d2ad5db..7bb36f4e2 100644
--- a/brainpy/_src/math/object_transform/jit.py
+++ b/brainpy/_src/math/object_transform/jit.py
@@ -14,11 +14,11 @@
 from jax.sharding import Sharding
 
 from brainpy import tools, check
-from ._tools import (dynvar_deprecation,
-                     node_deprecation,
-                     evaluate_dyn_vars_with_cache,
-                     evaluate_dyn_vars,
-                     _partial_fun)
+from .tools import (dynvar_deprecation,
+                    node_deprecation,
+                    evaluate_dyn_vars_with_cache,
+                    evaluate_dyn_vars,
+                    _partial_fun)
 from .base import BrainPyObject, ObjectTransform
 from .naming import get_stack_cache, cache_stack
 from ..ndarray import Array
diff --git a/brainpy/_src/math/object_transform/tests/test_jit.py b/brainpy/_src/math/object_transform/tests/test_jit.py
index 4d66f3e7d..d52903d43 100644
--- a/brainpy/_src/math/object_transform/tests/test_jit.py
+++ b/brainpy/_src/math/object_transform/tests/test_jit.py
@@ -97,6 +97,33 @@ def update(self, x):
     program.update(1.)
     self.assertTrue(bm.allclose(new_b + 1., program.b))
 
+  def test_class_jit2(self):
+    class SomeProgram(bp.BrainPyObject):
+      def __init__(self):
+        super(SomeProgram, self).__init__()
+        self.a = bm.zeros(2)
+        self.b = bm.Variable(bm.ones(2))
+
+        self.call1 = bm.jit(self.call, static_argnums=0)
+        self.call2 = bm.jit(self.call, static_argnames=['fit'])
+
+      def call(self, fit=True):
+        a = bm.random.uniform(size=2)
+        if fit:
+          a = a.at[0].set(1.)
+        self.b += a
+        return self.b
+
+    bm.random.seed(123)
+    program = SomeProgram()
+    new_b1 = program.call1(True)
+    new_b2 = program.call2(fit=False)
+    print()
+    print(new_b1, )
+    print(new_b2, )
+    with self.assertRaises(jax.errors.TracerBoolConversionError):
+      new_b3 = program.call2(False)
+
   def test_class_jit1_with_disable(self):
     class SomeProgram(bp.BrainPyObject):
       def __init__(self):
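For context: the new `test_class_jit2` pins down the difference between `static_argnums` and `static_argnames` in `bm.jit`. Only an argument actually marked (and passed) as static may drive a Python branch; the same value passed positionally is traced and raises `TracerBoolConversionError`. A minimal sketch of that behaviour outside the test class (function names here are illustrative):

```python
import brainpy.math as bm

def call(x, fit=True):
    # `fit` selects a Python branch, so it must be static under jit.
    return x + 1. if fit else x - 1.

call_by_name = bm.jit(call, static_argnames=['fit'])

print(call_by_name(bm.ones(2), fit=False))  # OK: keyword `fit` is static
# call_by_name(bm.ones(2), False)           # raises jax.errors.TracerBoolConversionError:
#                                             a positional `fit` is traced, not static
```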
diff --git a/brainpy/_src/math/object_transform/tests/test_tools.py b/brainpy/_src/math/object_transform/tests/test_tools.py
index 22357c0b2..fa57ee80d 100644
--- a/brainpy/_src/math/object_transform/tests/test_tools.py
+++ b/brainpy/_src/math/object_transform/tests/test_tools.py
@@ -2,7 +2,7 @@
 import brainpy.math as bm
 import jax
 import unittest
-from brainpy._src.math.object_transform._tools import evaluate_dyn_vars_with_cache
+from brainpy._src.math.object_transform.tools import evaluate_dyn_vars_with_cache
 
 
 class TestTool(unittest.TestCase):
diff --git a/brainpy/_src/math/object_transform/_tools.py b/brainpy/_src/math/object_transform/tools.py
similarity index 70%
rename from brainpy/_src/math/object_transform/_tools.py
rename to brainpy/_src/math/object_transform/tools.py
index 6e126f093..7b519590a 100644
--- a/brainpy/_src/math/object_transform/_tools.py
+++ b/brainpy/_src/math/object_transform/tools.py
@@ -1,12 +1,14 @@
 import warnings
 from functools import wraps
-from typing import Sequence, Tuple, Any
+from typing import Sequence, Tuple, Any, Callable
 
 import jax
 
 from brainpy._src.math.object_transform.naming import (cache_stack, get_stack_cache)
-from brainpy._src.math.object_transform.variables import VariableStack, current_transform_number
+from brainpy._src.math.object_transform.variables import VariableStack
+
+fun_in_eval_shape = []
 
 
 class Empty(object):
@@ -16,11 +18,13 @@ class Empty(object):
 empty = Empty()
 
 
-def _partial_fun(fun,
-                 args: tuple,
-                 kwargs: dict,
-                 static_argnums: Sequence[int] = (),
-                 static_argnames: Sequence[str] = ()):
+def _partial_fun(
+    fun: Callable,
+    args: tuple,
+    kwargs: dict,
+    static_argnums: Sequence[int] = (),
+    static_argnames: Sequence[str] = ()
+):
   static_args, dyn_args = [], []
   for i, arg in enumerate(args):
     if i in static_argnums:
@@ -79,7 +83,6 @@ def abstract(x):
 def evaluate_dyn_vars(
     f,
     *args,
-    transform: str = None,
     static_argnums: Sequence[int] = (),
     static_argnames: Sequence[str] = (),
     use_eval_shape: bool = True,
@@ -119,7 +122,7 @@ def evaluate_dyn_vars_with_cache(
 
   with jax.ensure_compile_time_eval():
     with VariableStack() as stack:
-      rets = jax.eval_shape(f2, *args, **kwargs)
+      rets = eval_shape(f2, *args, **kwargs)
       cache_stack(f, stack)  # cache
     del args, kwargs, f2
   if with_return:
@@ -127,3 +130,44 @@
   else:
     return stack
   return stack
+
+
+def eval_shape(
+    fun: Callable,
+    *args,
+    static_argnums: Sequence[int] = (),
+    static_argnames: Sequence[str] = (),
+    **kwargs
+):
+  """Compute the shape/dtype of ``fun`` without any FLOPs.
+
+  Args:
+    fun: The callable function.
+    *args:
+    **kwargs:
+    static_argnums: The static argument indices.
+    static_argnames: The static argument names.
+
+  Returns:
+    The variable stack and the functional returns.
+  """
+  # reorganize the function
+  if len(static_argnums) or len(static_argnames):
+    f2, args, kwargs = _partial_fun(fun, args, kwargs,
+                                    static_argnums=static_argnums,
+                                    static_argnames=static_argnames)
+  else:
+    f2, args, kwargs = fun, args, kwargs
+
+  # evaluate the function
+  fun_in_eval_shape.append(fun)
+  try:
+    with jax.ensure_compile_time_eval():
+      with VariableStack() as stack:
+        if len(fun_in_eval_shape) > 1:
+          returns = fun(*args, **kwargs)
+        else:
+          returns = jax.eval_shape(fun, *args, **kwargs)
+  finally:
+    fun_in_eval_shape.pop()
+  return stack, returns
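For context: `eval_shape` is the new public helper of this module (re-exported as `brainpy.math.eval_shape` by the `oo_transform.py` change later in this diff). A rough usage sketch, assuming that re-export; the `Counter` object is purely illustrative:

```python
import brainpy as bp
import brainpy.math as bm

class Counter(bp.BrainPyObject):
    def __init__(self):
        super().__init__()
        self.count = bm.Variable(bm.zeros(1))

    def update(self, x):
        self.count += 1.
        return x * self.count

obj = Counter()
# Collect the Variables touched by `update` and the abstract return value,
# without executing any real computation.
stack, out = bm.eval_shape(obj.update, bm.ones(10))
print(stack)  # VariableStack containing `obj.count`
print(out)    # abstract shape/dtype of the result
```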
diff --git a/brainpy/_src/math/surrogate/_one_input.py b/brainpy/_src/math/surrogate/_one_input.py
index 382bfdda3..007d216ed 100644
--- a/brainpy/_src/math/surrogate/_one_input.py
+++ b/brainpy/_src/math/surrogate/_one_input.py
@@ -78,7 +78,7 @@ def surrogate_fun(self, x):
     return sci.special.expit(x)
 
   def surrogate_grad(self, dz, x):
-    sgax = sci.special.expit(x * self.alpha)
+    sgax = sci.special.expit(as_jax(x) * self.alpha)
     dx = as_jax(dz) * (1. - sgax) * sgax * self.alpha
     return dx
@@ -159,6 +159,7 @@ def __init__(self, alpha: float = 1., forward_use_surrogate=False):
     self.alpha = alpha
 
   def surrogate_fun(self, x):
+    x = as_jax(x)
     z = jnp.where(x < -1 / self.alpha,
                   0.,
                   jnp.where(x > 1 / self.alpha,
@@ -167,6 +168,7 @@ def surrogate_fun(self, x):
     return z
 
   def surrogate_grad(self, dz, x):
+    x = as_jax(x)
     dx = jnp.where(jnp.abs(x) > 1 / self.alpha, 0., dz * (-(self.alpha * x) ** 2 + self.alpha))
     return dx
@@ -263,10 +265,12 @@ def __init__(self, alpha: float = 1., forward_use_surrogate=False):
     self.alpha = alpha
 
   def surrogate_grad(self, dz, x):
+    x = as_jax(x)
     dx = (self.alpha / 2) * jnp.exp(-self.alpha * jnp.abs(x))
     return dx * as_jax(dz)
 
   def surrogate_fun(self, x):
+    x = as_jax(x)
     return jnp.where(x < 0, jnp.exp(self.alpha * x) / 2, 1 - jnp.exp(-self.alpha * x) / 2)
 
   def __repr__(self):
@@ -352,10 +356,12 @@ def __init__(self, alpha=1., forward_use_surrogate=False):
     self.alpha = alpha
 
   def surrogate_grad(self, dz, x):
+    x = as_jax(x)
     dx = self.alpha * 0.5 / (1 + jnp.abs(self.alpha * x)) ** 2
     return dx * as_jax(dz)
 
   def surrogate_fun(self, x):
+    x = as_jax(x)
     return x / (2 / self.alpha + 2 * jnp.abs(x)) + 0.5
 
   def __repr__(self):
@@ -436,10 +442,12 @@ def __init__(self, alpha=1., forward_use_surrogate=False):
     self.alpha = alpha
 
   def surrogate_grad(self, dz, x):
+    x = as_jax(x)
     dx = self.alpha * 0.5 / (1 + (jnp.pi / 2 * self.alpha * x) ** 2)
     return dx * as_jax(dz)
 
   def surrogate_fun(self, x):
+    x = as_jax(x)
     return jnp.arctan2(jnp.pi / 2 * self.alpha * x) / jnp.pi + 0.5
 
   def __repr__(self):
@@ -519,10 +527,12 @@ def __init__(self, alpha=1., forward_use_surrogate=False):
     self.alpha = alpha
 
   def surrogate_grad(self, dz, x):
+    x = as_jax(x)
     dx = as_jax(dz) / (1 / self.alpha + jnp.abs(x))
     return dx
 
   def surrogate_fun(self, x):
+    x = as_jax(x)
     return jnp.where(x < 0, -1., 1.) * jnp.log(jnp.abs(self.alpha * x) + 1)
 
   def __repr__(self):
@@ -615,10 +625,12 @@ def __init__(self, alpha=1., forward_use_surrogate=False):
     self.alpha = alpha
 
   def surrogate_grad(self, dz, x):
+    x = as_jax(x)
     dx = (self.alpha / jnp.sqrt(jnp.pi)) * jnp.exp(-jnp.power(self.alpha, 2) * x * x)
     return dx * as_jax(dz)
 
   def surrogate_fun(self, x):
+    x = as_jax(x)
     return sci.special.erf(-self.alpha * x) * 0.5
 
   def __repr__(self):
@@ -709,6 +721,7 @@ def __init__(self, c=0.01, w=1., forward_use_surrogate=False):
     self.w = w
 
   def surrogate_fun(self, x):
+    x = as_jax(x)
     z = jnp.where(x < -self.w,
                   self.c * x + self.c * self.w,
                   jnp.where(x > self.w,
@@ -717,6 +730,7 @@ def surrogate_fun(self, x):
     return z
 
   def surrogate_grad(self, dz, x):
+    x = as_jax(x)
     dx = jnp.where(jnp.abs(x) > self.w, self.c, 1 / self.w)
     return dx * as_jax(dz)
@@ -822,6 +836,7 @@ def __init__(self, n=2, t_period=8., forward_use_surrogate=False):
     self.t_period = t_period
 
   def surrogate_grad(self, dz, x):
+    x = as_jax(x)
     w = jnp.pi * 2. / self.t_period
     dx = jnp.cos(w * x)
     for i in range(2, self.n):
@@ -830,6 +845,7 @@ def surrogate_grad(self, dz, x):
     return dx * as_jax(dz)
 
   def surrogate_fun(self, x):
+    x = as_jax(x)
     w = jnp.pi * 2. / self.t_period
     ret = jnp.sin(w * x)
     for i in range(2, self.n):
@@ -919,12 +935,14 @@ def __init__(self, alpha=4., beta=1., epsilon=1e-8, forward_use_surrogate=False)
     self.epsilon = epsilon
 
   def surrogate_fun(self, x):
+    x = as_jax(x)
     z = jnp.where(x < 0.,
                   sci.special.expit(x * self.alpha),
                   self.beta * jnp.log(jnp.abs((x + 1.)) + self.epsilon) + 0.5)
     return z
 
   def surrogate_grad(self, dz, x):
+    x = as_jax(x)
     sg = sci.special.expit(self.alpha * x)
     dx = jnp.where(x < 0., self.alpha * sg * (1. - sg), self.beta / (x + 1.))
     return dx * as_jax(dz)
@@ -1023,10 +1041,12 @@ def __init__(self, alpha=2., forward_use_surrogate=False):
     self.alpha = alpha
 
   def surrogate_grad(self, dz, x):
+    x = as_jax(x)
     dx = jnp.power(1 + 2 / (self.alpha + 1) * jnp.abs(x), -self.alpha)
     return dx * as_jax(dz)
 
   def surrogate_fun(self, x):
+    x = as_jax(x)
     z = jnp.where(x < 0.,
                   0.5 * jnp.power(1 - 2 / (self.alpha - 1) * jnp.abs(x), 1 - self.alpha),
                   1. - 0.5 * jnp.power(1 + 2 / (self.alpha - 1) * jnp.abs(x), 1 - self.alpha))
@@ -1117,9 +1137,11 @@ def __init__(self, alpha=0.1, beta=1., forward_use_surrogate=False):
     self.beta = beta
 
   def surrogate_fun(self, x):
+    x = as_jax(x)
     return jnp.where(x < 0., self.alpha * x, self.beta * x)
 
   def surrogate_grad(self, dz, x):
+    x = as_jax(x)
     dx = jnp.where(x < 0., self.alpha, self.beta)
     return dx * as_jax(dz)
@@ -1209,6 +1231,7 @@ def __init__(self, alpha=0., forward_use_surrogate=False):
     self.alpha = alpha
 
   def surrogate_fun(self, x):
+    x = as_jax(x)
     z = jnp.where(x > 1,
                   jnp.log(x),
                   jnp.where(x > 0,
@@ -1217,6 +1240,7 @@ def surrogate_fun(self, x):
     return z
 
   def surrogate_grad(self, dz, x):
+    x = as_jax(x)
     dx = jnp.where(x > 1,
                    1 / x,
                    jnp.where(x > 0,
@@ -1314,6 +1338,7 @@ def __init__(self, alpha=0.3, width=1.):
     self.width = width
 
   def surrogate_grad(self, dz, x):
+    x = as_jax(x)
     dx = jnp.maximum(self.alpha * self.width - jnp.abs(x) * self.alpha, 0)
     return dx * as_jax(dz)
@@ -1393,6 +1418,7 @@ def __init__(self, sigma=0.5, alpha=0.5):
     self.alpha = alpha
 
   def surrogate_grad(self, dz, x):
+    x = as_jax(x)
     dx = jnp.exp(-(x ** 2) / 2 * jnp.power(self.sigma, 2)) / (jnp.sqrt(2 * jnp.pi) * self.sigma)
     return self.alpha * dx * as_jax(dz)
@@ -1473,6 +1499,7 @@ def __init__(self, h=0.15, s=6.0, sigma=0.5, scale=0.5):
     self.scale = scale
 
   def surrogate_grad(self, dz, x):
+    x = as_jax(x)
     g1 = jnp.exp(-x ** 2 / (2 * jnp.power(self.sigma, 2))) / (jnp.sqrt(2 * jnp.pi) * self.sigma)
     g2 = jnp.exp(-(x - self.sigma) ** 2 / (2 * jnp.power(self.s * self.sigma, 2))
                  ) / (jnp.sqrt(2 * jnp.pi) * self.s * self.sigma)
@@ -1564,6 +1591,7 @@ def __init__(self, alpha=100.):
     self.alpha = alpha
 
   def surrogate_grad(self, dz, x):
+    x = as_jax(x)
     dx = as_jax(dz) / (self.alpha * jnp.abs(x) + 1.0) ** 2
     return dx
@@ -1634,6 +1662,7 @@ def __init__(self, alpha=1.):
     self.alpha = alpha
 
   def surrogate_grad(self, dz, x):
+    x = as_jax(x)
     dx = as_jax(dz) * jnp.exp(-self.alpha * jnp.abs(x))
     return dx
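For context: every class touched above follows the same recipe, namely a hard spike in the forward pass and a smooth surrogate derivative in the backward pass, and the inserted `x = as_jax(x)` lines simply normalise `brainpy.math.Array` inputs before the JAX math. A self-contained sketch of that recipe in plain JAX (not BrainPy's actual implementation), using the sigmoid surrogate from the first hunk of this file:

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import expit

alpha = 4.  # surrogate sharpness

@jax.custom_vjp
def spike(x):
    # forward pass: hard Heaviside step
    return jnp.asarray(x > 0., dtype=jnp.float32)

def spike_fwd(x):
    return spike(x), x

def spike_bwd(x, dz):
    sg = expit(alpha * x)                  # cf. Sigmoid.surrogate_grad above
    return (dz * sg * (1. - sg) * alpha,)

spike.defvjp(spike_fwd, spike_bwd)

print(jax.grad(lambda v: spike(v).sum())(jnp.linspace(-1., 1., 5)))
```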
diff --git a/brainpy/_src/tools/package.py b/brainpy/_src/tools/package.py
index 870b88129..4c83bdd51 100644
--- a/brainpy/_src/tools/package.py
+++ b/brainpy/_src/tools/package.py
@@ -24,8 +24,6 @@
 ]
 
 
-_minimal_brainpylib_version = '0.1.10'
-
 
 def import_numba():
   if numba is None:
@@ -38,8 +36,6 @@ def import_brainpylib():
   if brainpylib is None:
     raise ModuleNotFoundError('brainpylib is needed. Please install brainpylib through:\n'
                               '> pip install brainpylib\n\n')
-  if brainpylib.__version__ < _minimal_brainpylib_version:
-    raise SystemError(f'This version of brainpy needs brainpylib >= {_minimal_brainpylib_version}.')
   return brainpylib
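For context: with the minimal-version gate moved into `brainpylib_check.py`, `import_brainpylib()` now only guards against a missing installation; the version consistency check happens once, when `brainpy.math` is imported. A small sketch of the resulting behaviour:

```python
from brainpy._src.tools.package import import_brainpylib

# Raises ModuleNotFoundError with an install hint if brainpylib is absent;
# no version comparison happens here any more.
brainpylib = import_brainpylib()
print(brainpylib.__version__)
```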
diff --git a/brainpy/math/ndarray.py b/brainpy/math/ndarray.py
index 6d679d111..fee2b264e 100644
--- a/brainpy/math/ndarray.py
+++ b/brainpy/math/ndarray.py
@@ -5,4 +5,5 @@
   Array as Tensor,
   ndarray as ndarray,
   JaxArray as JaxArray,
+  ShardedArray as ShardedArray,
 )
diff --git a/brainpy/math/oo_transform.py b/brainpy/math/oo_transform.py
index 94ab09a9d..0b012f869 100644
--- a/brainpy/math/oo_transform.py
+++ b/brainpy/math/oo_transform.py
@@ -1,20 +1,26 @@
 # -*- coding: utf-8 -*-
 
-from brainpy._src.math.object_transform.base import (BrainPyObject as BrainPyObject,
-                                                     FunAsObject as FunAsObject)
+from brainpy._src.math.object_transform.base import (
+  BrainPyObject as BrainPyObject,
+  FunAsObject as FunAsObject
+)
 from brainpy._src.math.object_transform.function import (Partial as Partial)
-from brainpy._src.math.object_transform.base import (NodeList as NodeList,
-                                                     NodeDict as NodeDict,
-                                                     node_dict as node_dict,
-                                                     node_list as node_list, )
-from brainpy._src.math.object_transform.variables import (Variable as Variable,
-                                                          Parameter as Parameter,
-                                                          TrainVar as TrainVar,
-                                                          VariableView as VariableView,
-                                                          VarList as VarList,
-                                                          VarDict as VarDict,
-                                                          var_list as var_list,
-                                                          var_dict as var_dict, )
+from brainpy._src.math.object_transform.base import (
+  NodeList as NodeList,
+  NodeDict as NodeDict,
+  node_dict as node_dict,
+  node_list as node_list,
+)
+from brainpy._src.math.object_transform.variables import (
+  Variable as Variable,
+  Parameter as Parameter,
+  TrainVar as TrainVar,
+  VariableView as VariableView,
+  VarList as VarList,
+  VarDict as VarDict,
+  var_list as var_list,
+  var_dict as var_dict,
+)
 
 from brainpy._src.math.object_transform.autograd import (
   grad as grad,
@@ -46,3 +52,8 @@
   to_object as to_object,
   function as function,
 )
+
+from brainpy._src.math.object_transform.tools import (
+  eval_shape as eval_shape,
+)
+
diff --git a/brainpy/math/surrogate.py b/brainpy/math/surrogate.py
index 7fb4a05c5..3f3daa2b7 100644
--- a/brainpy/math/surrogate.py
+++ b/brainpy/math/surrogate.py
@@ -1,10 +1,6 @@
 # -*- coding: utf-8 -*-
 
 
-# from brainpy._src.math.surrogate._utils import (
-#   vjp_custom as vjp_custom
-# )
-
 from brainpy._src.math.surrogate.base import (
   Surrogate
 )
diff --git a/docs/apis/brainpy.math.oo_transform.rst b/docs/apis/brainpy.math.oo_transform.rst
index 2d279a4ee..5ee94c615 100644
--- a/docs/apis/brainpy.math.oo_transform.rst
+++ b/docs/apis/brainpy.math.oo_transform.rst
@@ -1,16 +1,19 @@
 Object-oriented Transformations
 ===============================
 
+.. currentmodule:: brainpy.math
+.. automodule:: brainpy.math
+
+
 .. contents::
    :local:
    :depth: 1
 
+
 Objects and Variables
 ---------------------
 
-.. currentmodule:: brainpy.math
-.. automodule:: brainpy.math
 
 .. autosummary::
    :toctree: generated/
@@ -34,11 +37,10 @@ Objects and Variables
   var_dict
 
 
+
 Object-oriented Transformations
 -------------------------------
 
-.. currentmodule:: brainpy.math
-.. automodule:: brainpy.math
 
 .. autosummary::
    :toctree: generated/
@@ -61,4 +63,17 @@ Object-oriented Transformations
   jit
   cls_jit
   to_object
-  function
\ No newline at end of file
+  function
+
+
+Helpers for Object-oriented Transformations
+-------------------------------------------
+
+
+.. autosummary::
+   :toctree: generated/
+   :nosignatures:
+   :template: classtemplate.rst
+
+  eval_shape
+
diff --git a/docs/apis/brainpy.math.rst b/docs/apis/brainpy.math.rst
index b108840a4..4e9426700 100644
--- a/docs/apis/brainpy.math.rst
+++ b/docs/apis/brainpy.math.rst
@@ -12,11 +12,21 @@ General Mathematical Operators
 
 
-Array Interoperability
-----------------------
+BrainPy Array
+-------------
+
+.. autosummary::
+   :toctree: generated/
+   :nosignatures:
+   :template: classtemplate.rst
+
+  Array
+  ShardedArray
 
-.. currentmodule:: brainpy.math
-.. automodule:: brainpy.math
+
+
+Array Interoperability to JAX
+-----------------------------
 
 .. autosummary::
    :toctree: generated/
@@ -25,17 +35,38 @@ Array Interoperability
   as_device_array
   as_jax
+
+
+
+
+Array Interoperability to NumPy
+-------------------------------
+
+.. autosummary::
+   :toctree: generated/
+   :nosignatures:
+   :template: classtemplate.rst
+
   as_ndarray
   as_numpy
+
+
+
+Array Interoperability to BrainPy
+---------------------------------
+
+.. autosummary::
+   :toctree: generated/
+   :nosignatures:
+   :template: classtemplate.rst
+
   as_variable
+  asarray
 
 
 Activation Functions
 --------------------
 
-.. currentmodule:: brainpy.math
-.. automodule:: brainpy.math
-
 .. autosummary::
    :toctree: generated/
    :nosignatures:
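For context: the regrouped interoperability tables list converters that already exist in `brainpy.math`; a quick sketch of what each documented name does:

```python
import brainpy.math as bm

x = bm.asarray([1., 2., 3.])   # brainpy.math.Array
jx = bm.as_jax(x)              # convert to a jax.Array
nx = bm.as_numpy(x)            # convert to a numpy.ndarray
v = bm.as_variable(x)          # wrap as a brainpy.math.Variable
print(type(x), type(jx), type(nx), type(v))
```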
diff --git a/docs/tutorial_advanced/operator_custom_with_numba.ipynb b/docs/tutorial_advanced/operator_custom_with_numba.ipynb
index 84d4deb79..215d41418 100644
--- a/docs/tutorial_advanced/operator_custom_with_numba.ipynb
+++ b/docs/tutorial_advanced/operator_custom_with_numba.ipynb
@@ -9,6 +9,15 @@
     "# Operator Customization with Numba"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "source": [
+    "## English version"
+   ],
+   "metadata": {
+    "collapsed": false
+   }
+  },
   {
    "cell_type": "markdown",
    "source": [
@@ -48,7 +57,7 @@
   {
    "cell_type": "markdown",
    "source": [
-    "## ``brainpy.math.CustomOpByNumba``\n",
+    "### ``brainpy.math.CustomOpByNumba``\n",
     "\n",
     "``brainpy.math.CustomOpByNumba`` is also called ``brainpy.math.XLACustomOp``.\n",
     "\n",
@@ -122,7 +131,7 @@
   {
    "cell_type": "markdown",
    "source": [
-    "## Return multiple values ``multiple_returns=True``\n",
+    "### Return multiple values ``multiple_returns=True``\n",
     "\n",
     "If the result of our computation needs to return multiple arrays, then we need to use ``multiple_returns=True`` in our use of registering the operator. In this case, ``outs`` will be a list containing multiple arrays, not an array.\n",
     "\n",
@@ -155,7 +164,7 @@
   {
    "cell_type": "markdown",
    "source": [
-    "## Non-Tracer parameters\n",
+    "### Non-Tracer parameters\n",
     "\n",
     "In the ``eval_shape`` function, all arguments are abstract information (containing only the shape and type) if they are arguments that can be traced by ``jax.jit``. However, if we infer the output data type requires additional information beyond the input parameter information, then we need to define non-Tracer parameters.\n",
     "\n",
@@ -206,7 +215,7 @@
   {
    "cell_type": "markdown",
    "source": [
-    "## Example: A sparse operator\n",
+    "### Example: A sparse operator\n",
     "\n",
     "To illustrate the effectiveness of this approach, we define in this an event-driven sparse computation operator."
    ],
@@ -282,6 +291,15 @@
     }
    }
   },
+  {
+   "cell_type": "markdown",
+   "source": [
+    "## 中文版"
+   ],
+   "metadata": {
+    "collapsed": false
+   }
+  },
   {
    "cell_type": "markdown",
    "source": [
@@ -321,7 +339,7 @@
   {
    "cell_type": "markdown",
    "source": [
-    "## ``brainpy.math.CustomOpByNumba``接口\n",
+    "### ``brainpy.math.CustomOpByNumba``接口\n",
     "\n",
     "``brainpy.math.CustomOpByNumba`` 也叫做``brainpy.math.XLACustomOp``。\n",
     "\n",
@@ -396,7 +414,7 @@
   {
    "cell_type": "markdown",
    "source": [
-    "## 返回多个值 ``multiple_returns=True``\n",
+    "### 返回多个值 ``multiple_returns=True``\n",
     "\n",
     "如果我们的计算结果需要返回多个数组,那么,我们在注册算子的使用需要使用``multiple_returns=True``。此时,``outs``将会是一个包含多个数组的列表,而不是一个数组。\n",
     "\n",
@@ -428,7 +446,7 @@
   {
    "cell_type": "markdown",
    "source": [
-    "## 非Tracer参数\n",
+    "### 非Tracer参数\n",
     "\n",
     "在``eval_shape``函数中推断数据类型时,如果所有参数都是可以被``jax.jit``追踪的参数,那么所有参数都是抽象信息(只包含形状和类型)。如果有时推断输出数据类型时还需要除输入参数信息以外的额外信息,此时我们需要定义非Tracer参数。\n",
     "\n",
@@ -479,7 +497,7 @@
   {
    "cell_type": "markdown",
    "source": [
-    "## 示例:一个稀疏算子\n",
+    "### 示例:一个稀疏算子\n",
     "\n",
     "为了说明这种方法的有效性,我们在这个定义一个事件驱动的稀疏计算算子。"
    ],