Skip to content

Commit

Permalink
Updates for OO transformations and surrogate functions (#519)
Browse files Browse the repository at this point in the history
Updates for OO transforma and surrogate functions
  • Loading branch information
ztqakita authored Oct 14, 2023
2 parents dfb705e + fd7d0ad commit 1d24470
Show file tree
Hide file tree
Showing 20 changed files with 273 additions and 70 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/CI-models.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: [ "3.8", "3.9", "3.10", "3.11"]
python-version: [ "3.9", "3.10", "3.11"]

steps:
- uses: actions/checkout@v4
Expand Down Expand Up @@ -69,7 +69,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: ["3.8", "3.9", "3.10", "3.11"]
python-version: [ "3.9", "3.10", "3.11"]

steps:
- uses: actions/checkout@v4
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: [ "3.8", "3.9", "3.10", "3.11"]
python-version: [ "3.9", "3.10", "3.11"]

steps:
- uses: actions/checkout@v4
Expand Down Expand Up @@ -90,7 +90,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: ["3.8", "3.9", "3.10", "3.11"]
python-version: ["3.9", "3.10", "3.11"]

steps:
- uses: actions/checkout@v4
Expand Down
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -64,17 +64,21 @@ Then, you can run the image with the following command:
$ docker run -it --platform linux/amd64 brainpy/brainpy:latest
```


### Using BrainPy with Binder

We provide a Binder environment for BrainPy. You can use the following button to launch the environment:

[![Binder](https://mybinder.org/badge_logo.svg)](https://mybinder.org/v2/gh/brainpy/BrainPy-binder/main)


## Ecosystem

- **[BrainPy](https://github.com/brainpy/BrainPy)**: The solution for the general-purpose brain dynamics programming.
- **[brainpy-examples](https://github.com/brainpy/examples)**: Comprehensive examples of BrainPy computation.
- **[brainpy-datasets](https://github.com/brainpy/datasets)**: Neuromorphic and Cognitive Datasets for Brain Dynamics Modeling.
- [第一届神经计算建模与编程培训班](https://github.com/brainpy/1st-neural-modeling-and-programming-course)


## Citing

Expand Down
3 changes: 3 additions & 0 deletions brainpy/_src/math/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@
#


from . import brainpylib_check

# data structure
from .ndarray import *
from .delayvars import *
Expand Down Expand Up @@ -59,3 +61,4 @@
from .modes import *
from .environment import *

del brainpylib_check
29 changes: 29 additions & 0 deletions brainpy/_src/math/brainpylib_check.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from jax.lib import xla_client

# Register the CPU XLA custom calls
try:
import brainpylib
from brainpylib import cpu_ops

for _name, _value in cpu_ops.registrations().items():
xla_client.register_custom_call_target(_name, _value, platform="cpu")
except ImportError:
cpu_ops = None
brainpylib = None

# Register the GPU XLA custom calls
try:
from brainpylib import gpu_ops

for _name, _value in gpu_ops.registrations().items():
xla_client.register_custom_call_target(_name, _value, platform="gpu")
except ImportError:
gpu_ops = None

# check brainpy and brainpylib version consistency
_minimal_brainpylib_version = '0.1.10'
if brainpylib is not None:
if brainpylib.__version__ < _minimal_brainpylib_version:
raise SystemError(f'This version of brainpy needs brainpylib >= {_minimal_brainpylib_version}.')
if hasattr(brainpylib, 'check_brainpy_version'):
brainpylib.check_brainpy_version()
10 changes: 5 additions & 5 deletions brainpy/_src/math/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1510,11 +1510,11 @@ def cpu(self):
# dtype exchanging #
# ---------------- #

def bool(self): return jnp.asarray(self.value, dtypt=jnp.bool_)
def int(self): return jnp.asarray(self.value, dtypt=jnp.int32)
def long(self): return jnp.asarray(self.value, dtypt=jnp.int64)
def half(self): return jnp.asarray(self.value, dtypt=jnp.float16)
def float(self): return jnp.asarray(self.value, dtypt=jnp.float32)
def bool(self): return jnp.asarray(self.value, dtype=jnp.bool_)
def int(self): return jnp.asarray(self.value, dtype=jnp.int32)
def long(self): return jnp.asarray(self.value, dtype=jnp.int64)
def half(self): return jnp.asarray(self.value, dtype=jnp.float16)
def float(self): return jnp.asarray(self.value, dtype=jnp.float32)
def double(self): return jnp.asarray(self.value, dtype=jnp.float64)


Expand Down
2 changes: 1 addition & 1 deletion brainpy/_src/math/object_transform/autograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

from brainpy import tools, check
from brainpy._src.math.ndarray import Array, _as_jax_array_
from ._tools import (
from .tools import (
dynvar_deprecation,
node_deprecation,
get_stack_cache,
Expand Down
3 changes: 1 addition & 2 deletions brainpy/_src/math/object_transform/controls.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from brainpy import errors, tools
from brainpy._src.math.interoperability import as_jax
from brainpy._src.math.ndarray import (Array, )
from ._tools import (
from .tools import (
evaluate_dyn_vars,
evaluate_dyn_vars_with_cache,
dynvar_deprecation,
Expand All @@ -31,7 +31,6 @@
VariableStack,
new_transform,
current_transform_number,
transform_stack,
)

__all__ = [
Expand Down
10 changes: 5 additions & 5 deletions brainpy/_src/math/object_transform/jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@
from jax.sharding import Sharding

from brainpy import tools, check
from ._tools import (dynvar_deprecation,
node_deprecation,
evaluate_dyn_vars_with_cache,
evaluate_dyn_vars,
_partial_fun)
from .tools import (dynvar_deprecation,
node_deprecation,
evaluate_dyn_vars_with_cache,
evaluate_dyn_vars,
_partial_fun)
from .base import BrainPyObject, ObjectTransform
from .naming import get_stack_cache, cache_stack
from ..ndarray import Array
Expand Down
27 changes: 27 additions & 0 deletions brainpy/_src/math/object_transform/tests/test_jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,33 @@ def update(self, x):
program.update(1.)
self.assertTrue(bm.allclose(new_b + 1., program.b))

def test_class_jit2(self):
class SomeProgram(bp.BrainPyObject):
def __init__(self):
super(SomeProgram, self).__init__()
self.a = bm.zeros(2)
self.b = bm.Variable(bm.ones(2))

self.call1 = bm.jit(self.call, static_argnums=0)
self.call2 = bm.jit(self.call, static_argnames=['fit'])

def call(self, fit=True):
a = bm.random.uniform(size=2)
if fit:
a = a.at[0].set(1.)
self.b += a
return self.b

bm.random.seed(123)
program = SomeProgram()
new_b1 = program.call1(True)
new_b2 = program.call2(fit=False)
print()
print(new_b1, )
print(new_b2, )
with self.assertRaises(jax.errors.TracerBoolConversionError):
new_b3 = program.call2(False)

def test_class_jit1_with_disable(self):
class SomeProgram(bp.BrainPyObject):
def __init__(self):
Expand Down
2 changes: 1 addition & 1 deletion brainpy/_src/math/object_transform/tests/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import brainpy.math as bm
import jax
import unittest
from brainpy._src.math.object_transform._tools import evaluate_dyn_vars_with_cache
from brainpy._src.math.object_transform.tools import evaluate_dyn_vars_with_cache


class TestTool(unittest.TestCase):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import warnings
from functools import wraps
from typing import Sequence, Tuple, Any
from typing import Sequence, Tuple, Any, Callable

import jax

from brainpy._src.math.object_transform.naming import (cache_stack,
get_stack_cache)
from brainpy._src.math.object_transform.variables import VariableStack, current_transform_number
from brainpy._src.math.object_transform.variables import VariableStack

fun_in_eval_shape = []


class Empty(object):
Expand All @@ -16,11 +18,13 @@ class Empty(object):
empty = Empty()


def _partial_fun(fun,
args: tuple,
kwargs: dict,
static_argnums: Sequence[int] = (),
static_argnames: Sequence[str] = ()):
def _partial_fun(
fun: Callable,
args: tuple,
kwargs: dict,
static_argnums: Sequence[int] = (),
static_argnames: Sequence[str] = ()
):
static_args, dyn_args = [], []
for i, arg in enumerate(args):
if i in static_argnums:
Expand Down Expand Up @@ -79,7 +83,6 @@ def abstract(x):
def evaluate_dyn_vars(
f,
*args,
transform: str = None,
static_argnums: Sequence[int] = (),
static_argnames: Sequence[str] = (),
use_eval_shape: bool = True,
Expand Down Expand Up @@ -119,11 +122,52 @@ def evaluate_dyn_vars_with_cache(

with jax.ensure_compile_time_eval():
with VariableStack() as stack:
rets = jax.eval_shape(f2, *args, **kwargs)
rets = eval_shape(f2, *args, **kwargs)
cache_stack(f, stack) # cache
del args, kwargs, f2
if with_return:
return stack, rets
else:
return stack
return stack


def eval_shape(
fun: Callable,
*args,
static_argnums: Sequence[int] = (),
static_argnames: Sequence[str] = (),
**kwargs
):
"""Compute the shape/dtype of ``fun`` without any FLOPs.
Args:
fun: The callable function.
*args:
**kwargs:
static_argnums: The static argument indices.
static_argnames: The static argument names.
Returns:
The variable stack and the functional returns.
"""
# reorganize the function
if len(static_argnums) or len(static_argnames):
f2, args, kwargs = _partial_fun(fun, args, kwargs,
static_argnums=static_argnums,
static_argnames=static_argnames)
else:
f2, args, kwargs = fun, args, kwargs

# evaluate the function
fun_in_eval_shape.append(fun)
try:
with jax.ensure_compile_time_eval():
with VariableStack() as stack:
if len(fun_in_eval_shape) > 1:
returns = fun(*args, **kwargs)
else:
returns = jax.eval_shape(fun, *args, **kwargs)
finally:
fun_in_eval_shape.pop()
return stack, returns
Loading

0 comments on commit 1d24470

Please sign in to comment.