Updates for OO transforms and surrogate functions #519

Merged: 6 commits, Oct 14, 2023

4 changes: 2 additions & 2 deletions .github/workflows/CI-models.yml
@@ -22,7 +22,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: [ "3.8", "3.9", "3.10", "3.11"]
python-version: [ "3.9", "3.10", "3.11"]

steps:
- uses: actions/checkout@v4
@@ -69,7 +69,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: ["3.8", "3.9", "3.10", "3.11"]
python-version: [ "3.9", "3.10", "3.11"]

steps:
- uses: actions/checkout@v4
4 changes: 2 additions & 2 deletions .github/workflows/CI.yml
@@ -24,7 +24,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: [ "3.8", "3.9", "3.10", "3.11"]
python-version: [ "3.9", "3.10", "3.11"]

steps:
- uses: actions/checkout@v4
@@ -90,7 +90,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: ["3.8", "3.9", "3.10", "3.11"]
python-version: ["3.9", "3.10", "3.11"]

steps:
- uses: actions/checkout@v4
4 changes: 4 additions & 0 deletions README.md
@@ -64,17 +64,21 @@ Then, you can run the image with the following command:
$ docker run -it --platform linux/amd64 brainpy/brainpy:latest
```


### Using BrainPy with Binder

We provide a Binder environment for BrainPy. You can use the following button to launch the environment:

[![Binder](https://mybinder.org/badge_logo.svg)](https://mybinder.org/v2/gh/brainpy/BrainPy-binder/main)


## Ecosystem

- **[BrainPy](https://github.com/brainpy/BrainPy)**: The solution for general-purpose brain dynamics programming.
- **[brainpy-examples](https://github.com/brainpy/examples)**: Comprehensive examples of BrainPy computation.
- **[brainpy-datasets](https://github.com/brainpy/datasets)**: Neuromorphic and Cognitive Datasets for Brain Dynamics Modeling.
- [First Training Course on Neural Computational Modeling and Programming (第一届神经计算建模与编程培训班)](https://github.com/brainpy/1st-neural-modeling-and-programming-course)


## Citing

3 changes: 3 additions & 0 deletions brainpy/_src/math/__init__.py
@@ -30,6 +30,8 @@
#


+from . import brainpylib_check

# data structure
from .ndarray import *
from .delayvars import *
@@ -59,3 +61,4 @@
from .modes import *
from .environment import *

+del brainpylib_check
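
One note on this hunk: `brainpylib_check` is imported purely for its import-time side effects (registering brainpylib's XLA custom calls; see the new file below), and `del` then drops the name so it does not linger in the public `brainpy.math` namespace. A tiny self-contained illustration of the same import-then-delete pattern, using the stdlib `this` module:

```python
import this             # importing a module runs its top-level code once
del this                # drop the binding; the side effects persist
print('this' in dir())  # False: the namespace stays clean
```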
29 changes: 29 additions & 0 deletions brainpy/_src/math/brainpylib_check.py
@@ -0,0 +1,29 @@
from jax.lib import xla_client

# Register the CPU XLA custom calls
try:
  import brainpylib
  from brainpylib import cpu_ops

  for _name, _value in cpu_ops.registrations().items():
    xla_client.register_custom_call_target(_name, _value, platform="cpu")
except ImportError:
  cpu_ops = None
  brainpylib = None

# Register the GPU XLA custom calls
try:
  from brainpylib import gpu_ops

  for _name, _value in gpu_ops.registrations().items():
    xla_client.register_custom_call_target(_name, _value, platform="gpu")
except ImportError:
  gpu_ops = None

# check brainpy and brainpylib version consistency
_minimal_brainpylib_version = '0.1.10'
if brainpylib is not None:
  if brainpylib.__version__ < _minimal_brainpylib_version:
    raise SystemError(f'This version of brainpy needs brainpylib >= {_minimal_brainpylib_version}.')
  if hasattr(brainpylib, 'check_brainpy_version'):
    brainpylib.check_brainpy_version()
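
One caveat worth flagging in the version guard above: `__version__` strings compare lexicographically, so, for example, `'0.1.9' > '0.1.10'`, and an outdated brainpylib could slip past the check once version components reach two digits. A minimal sketch of a more robust comparison, assuming the third-party `packaging` library is available:

```python
from packaging.version import parse

assert '0.1.9' > '0.1.10'                # string comparison misorders versions
assert parse('0.1.9') < parse('0.1.10')  # parsed comparison is numeric

_minimal = '0.1.10'
installed = '0.1.9'  # stand-in for brainpylib.__version__
if parse(installed) < parse(_minimal):
  raise SystemError(f'This version of brainpy needs brainpylib >= {_minimal}.')
```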
10 changes: 5 additions & 5 deletions brainpy/_src/math/ndarray.py
@@ -1510,11 +1510,11 @@ def cpu(self):
# dtype exchanging #
# ---------------- #

-  def bool(self): return jnp.asarray(self.value, dtypt=jnp.bool_)
-  def int(self): return jnp.asarray(self.value, dtypt=jnp.int32)
-  def long(self): return jnp.asarray(self.value, dtypt=jnp.int64)
-  def half(self): return jnp.asarray(self.value, dtypt=jnp.float16)
-  def float(self): return jnp.asarray(self.value, dtypt=jnp.float32)
+  def bool(self): return jnp.asarray(self.value, dtype=jnp.bool_)
+  def int(self): return jnp.asarray(self.value, dtype=jnp.int32)
+  def long(self): return jnp.asarray(self.value, dtype=jnp.int64)
+  def half(self): return jnp.asarray(self.value, dtype=jnp.float16)
+  def float(self): return jnp.asarray(self.value, dtype=jnp.float32)
   def double(self): return jnp.asarray(self.value, dtype=jnp.float64)


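The change above is a one-character fix, `dtypt=` to `dtype=`; with the typo, each of these casts raised a `TypeError` for the unexpected keyword argument. A quick usage sketch of the corrected helpers:

```python
import brainpy.math as bm

x = bm.arange(4)          # integer array
print(x.float().dtype)    # float32
print(x.half().dtype)     # float16
print(x.bool().dtype)     # bool
# x.long() and x.double() target 64-bit dtypes and only take full effect
# once 64-bit mode is enabled, e.g. with bm.enable_x64()
```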
2 changes: 1 addition & 1 deletion brainpy/_src/math/object_transform/autograd.py
@@ -19,7 +19,7 @@

from brainpy import tools, check
from brainpy._src.math.ndarray import Array, _as_jax_array_
-from ._tools import (
+from .tools import (
dynvar_deprecation,
node_deprecation,
get_stack_cache,
3 changes: 1 addition & 2 deletions brainpy/_src/math/object_transform/controls.py
@@ -13,7 +13,7 @@
from brainpy import errors, tools
from brainpy._src.math.interoperability import as_jax
from brainpy._src.math.ndarray import (Array, )
-from ._tools import (
+from .tools import (
evaluate_dyn_vars,
evaluate_dyn_vars_with_cache,
dynvar_deprecation,
Expand All @@ -31,7 +31,6 @@
VariableStack,
new_transform,
current_transform_number,
-transform_stack,
)

__all__ = [
10 changes: 5 additions & 5 deletions brainpy/_src/math/object_transform/jit.py
@@ -14,11 +14,11 @@
from jax.sharding import Sharding

from brainpy import tools, check
-from ._tools import (dynvar_deprecation,
-                     node_deprecation,
-                     evaluate_dyn_vars_with_cache,
-                     evaluate_dyn_vars,
-                     _partial_fun)
+from .tools import (dynvar_deprecation,
+                    node_deprecation,
+                    evaluate_dyn_vars_with_cache,
+                    evaluate_dyn_vars,
+                    _partial_fun)
from .base import BrainPyObject, ObjectTransform
from .naming import get_stack_cache, cache_stack
from ..ndarray import Array
27 changes: 27 additions & 0 deletions brainpy/_src/math/object_transform/tests/test_jit.py
@@ -97,6 +97,33 @@ def update(self, x):
    program.update(1.)
    self.assertTrue(bm.allclose(new_b + 1., program.b))

  def test_class_jit2(self):
    class SomeProgram(bp.BrainPyObject):
      def __init__(self):
        super(SomeProgram, self).__init__()
        self.a = bm.zeros(2)
        self.b = bm.Variable(bm.ones(2))

        self.call1 = bm.jit(self.call, static_argnums=0)
        self.call2 = bm.jit(self.call, static_argnames=['fit'])

      def call(self, fit=True):
        a = bm.random.uniform(size=2)
        if fit:
          a = a.at[0].set(1.)
        self.b += a
        return self.b

    bm.random.seed(123)
    program = SomeProgram()
    new_b1 = program.call1(True)
    new_b2 = program.call2(fit=False)
    print()
    print(new_b1, )
    print(new_b2, )
    with self.assertRaises(jax.errors.TracerBoolConversionError):
      new_b3 = program.call2(False)

  def test_class_jit1_with_disable(self):
    class SomeProgram(bp.BrainPyObject):
      def __init__(self):
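For context on the expected failure at the end of `test_class_jit2`: `static_argnames` only applies to keyword-passed arguments here, so `program.call2(False)` delivers `fit` as a traced value, and Python's `if` cannot branch on a tracer. A minimal plain-JAX sketch of the underlying error (nothing BrainPy-specific assumed):

```python
import jax
import jax.numpy as jnp

@jax.jit
def f(x, fit):
  # `fit` is traced by jit, so `if fit` must convert a tracer to a Python bool
  return x + 1. if fit else x - 1.

try:
  f(jnp.zeros(2), True)
except jax.errors.TracerBoolConversionError:
  print('non-static flags cannot drive Python control flow')
```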
2 changes: 1 addition & 1 deletion brainpy/_src/math/object_transform/tests/test_tools.py
@@ -2,7 +2,7 @@
import brainpy.math as bm
import jax
import unittest
-from brainpy._src.math.object_transform._tools import evaluate_dyn_vars_with_cache
+from brainpy._src.math.object_transform.tools import evaluate_dyn_vars_with_cache


class TestTool(unittest.TestCase):
Expand Down
brainpy/_src/math/object_transform/{_tools.py → tools.py}
@@ -1,12 +1,14 @@
import warnings
from functools import wraps
-from typing import Sequence, Tuple, Any
+from typing import Sequence, Tuple, Any, Callable

import jax

from brainpy._src.math.object_transform.naming import (cache_stack,
                                                        get_stack_cache)
-from brainpy._src.math.object_transform.variables import VariableStack, current_transform_number
+from brainpy._src.math.object_transform.variables import VariableStack

+fun_in_eval_shape = []


class Empty(object):
@@ -16,11 +18,13 @@ class Empty(object):
empty = Empty()


-def _partial_fun(fun,
-                 args: tuple,
-                 kwargs: dict,
-                 static_argnums: Sequence[int] = (),
-                 static_argnames: Sequence[str] = ()):
+def _partial_fun(
+    fun: Callable,
+    args: tuple,
+    kwargs: dict,
+    static_argnums: Sequence[int] = (),
+    static_argnames: Sequence[str] = ()
+):
  static_args, dyn_args = [], []
  for i, arg in enumerate(args):
    if i in static_argnums:
@@ -79,7 +83,6 @@ def abstract(x):
def evaluate_dyn_vars(
    f,
    *args,
-   transform: str = None,
    static_argnums: Sequence[int] = (),
    static_argnames: Sequence[str] = (),
    use_eval_shape: bool = True,
@@ -119,11 +122,52 @@ def evaluate_dyn_vars_with_cache(

  with jax.ensure_compile_time_eval():
    with VariableStack() as stack:
-     rets = jax.eval_shape(f2, *args, **kwargs)
+     rets = eval_shape(f2, *args, **kwargs)
      cache_stack(f, stack)  # cache
  del args, kwargs, f2
  if with_return:
    return stack, rets
- else:
-   return stack
+ return stack


def eval_shape(
    fun: Callable,
    *args,
    static_argnums: Sequence[int] = (),
    static_argnames: Sequence[str] = (),
    **kwargs
):
  """Compute the shape/dtype of ``fun`` without any FLOPs.

  Args:
    fun: The callable function.
    *args:
    **kwargs:
    static_argnums: The static argument indices.
    static_argnames: The static argument names.

  Returns:
    The variable stack and the functional returns.
  """
  # reorganize the function
  if len(static_argnums) or len(static_argnames):
    f2, args, kwargs = _partial_fun(fun, args, kwargs,
                                    static_argnums=static_argnums,
                                    static_argnames=static_argnames)
  else:
    f2, args, kwargs = fun, args, kwargs

  # evaluate the function
  fun_in_eval_shape.append(fun)
  try:
    with jax.ensure_compile_time_eval():
      with VariableStack() as stack:
        if len(fun_in_eval_shape) > 1:
          returns = fun(*args, **kwargs)
        else:
          returns = jax.eval_shape(fun, *args, **kwargs)
  finally:
    fun_in_eval_shape.pop()
  return stack, returns
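
A hypothetical usage sketch of the new `eval_shape` helper: it traces `fun` abstractly (no FLOPs are spent), while the enclosing `VariableStack` records the `Variable`s the function touches, and both are returned. The exact stack contents below depend on `VariableStack`'s collection rules, so treat this as illustrative:

```python
import brainpy.math as bm
from brainpy._src.math.object_transform.tools import eval_shape

v = bm.Variable(bm.ones(3))

def f(x):
  return x * v + 1.  # reads a Variable while being traced

stack, out = eval_shape(f, bm.zeros(3))
print(len(stack))  # expected: 1, the Variable `v`
print(out)         # shape/dtype metadata for the result, no values computed
```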