Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[fix] Replace jax.experimental.host_callback with jax.pure_callback #670

Merged
merged 4 commits into from
May 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion brainpy/_src/analysis/highdim/slow_points.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,7 @@ def find_fps_with_gd_method(
"""
# optimization settings
if optimizer is None:
optimizer = optim.Adam(lr=optim.ExponentialDecay(0.2, 1, 0.9999),
optimizer = optim.Adam(lr=optim.ExponentialDecayLR(0.2, 1, 0.9999),
beta1=0.9, beta2=0.999, eps=1e-8)
else:
if not isinstance(optimizer, optim.Optimizer):
Expand Down
2 changes: 1 addition & 1 deletion brainpy/_src/dyn/rates/tests/test_nvar.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ class Test_NVAR(parameterized.TestCase):
def test_NVAR(self,mode):
bm.random.seed()
input=bm.random.randn(1,5)
layer=bp.dnn.NVAR(num_in=5,
layer=bp.dyn.NVAR(num_in=5,
delay=10,
mode=mode)
if mode in [bm.NonBatchingMode()]:
Expand Down
4 changes: 2 additions & 2 deletions brainpy/_src/initialize/tests/test_decay_inits.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
# visualization
def mat_visualize(matrix, cmap=None):
if cmap is None:
cmap = plt.cm.get_cmap('coolwarm')
plt.cm.get_cmap('coolwarm')
cmap = plt.colormaps.get_cmap('coolwarm')
plt.colormaps.get_cmap('coolwarm')
im = plt.matshow(matrix, cmap=cmap)
plt.colorbar(mappable=im, shrink=0.8, aspect=15)
plt.show()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,8 @@ def dV(self, V, t, h, n, Iext):

return dVdt

def update(self, tdi):
t, dt = tdi.t, tdi.dt
def update(self):
t, dt = bp.share['t'], bp.share['dt']
V, h, n = self.integral(self.V, self.h, self.n, t, self.input, dt=dt)
self.spike.value = bm.logical_and(self.V < self.V_th, V >= self.V_th)
self.V.value = V
Expand Down
3 changes: 1 addition & 2 deletions brainpy/_src/integrators/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import jax.numpy as jnp
import numpy as np
import tqdm.auto
from jax.experimental.host_callback import id_tap
from jax.tree_util import tree_flatten

from brainpy import math as bm
Expand Down Expand Up @@ -245,7 +244,7 @@ def _step_fun_integrator(self, static_args, dyn_args, t, i):

# progress bar
if self.progress_bar:
id_tap(lambda *args: self._pbar.update(), ())
jax.pure_callback(lambda *args: self._pbar.update(), ())

# return of function monitors
shared = dict(t=t + self.dt, dt=self.dt, i=i)
Expand Down
12 changes: 5 additions & 7 deletions brainpy/_src/math/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -660,27 +660,25 @@ def searchsorted(self, v, side='left', sorter=None):
"""
return _return(self.value.searchsorted(v=_as_jax_array_(v), side=side, sorter=sorter))

def sort(self, axis=-1, kind='quicksort', order=None):
def sort(self, axis=-1, stable=True, order=None):
"""Sort an array in-place.

Parameters
----------
axis : int, optional
Axis along which to sort. Default is -1, which means sort along the
last axis.
kind : {'quicksort', 'mergesort', 'heapsort', 'stable'}
Sorting algorithm. The default is 'quicksort'. Note that both 'stable'
and 'mergesort' use timsort under the covers and, in general, the
actual implementation will vary with datatype. The 'mergesort' option
is retained for backwards compatibility.
stable : bool, optional
Whether to use a stable sorting algorithm. The default is True.
order : str or list of str, optional
When `a` is an array with fields defined, this argument specifies
which fields to compare first, second, etc. A single field can
be specified as a string, and not all fields need be specified,
but unspecified fields will still be used, in the order in which
they come up in the dtype, to break ties.
"""
self.value = self.value.sort(axis=axis, kind=kind, order=order)
self.value = self.value.sort(axis=axis, stable=stable, order=order)


def squeeze(self, axis=None):
"""Remove axes of length one from ``a``."""
Expand Down
13 changes: 6 additions & 7 deletions brainpy/_src/math/object_transform/controls.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import jax
import jax.numpy as jnp
from jax.errors import UnexpectedTracerError
from jax.experimental.host_callback import id_tap
from jax.tree_util import tree_flatten, tree_unflatten
from tqdm.auto import tqdm

Expand Down Expand Up @@ -421,14 +420,14 @@ def call(pred, x=None):
def _warp(f):
@functools.wraps(f)
def new_f(*args, **kwargs):
return jax.tree_map(_as_jax_array_, f(*args, **kwargs), is_leaf=lambda a: isinstance(a, Array))
return jax.tree.map(_as_jax_array_, f(*args, **kwargs), is_leaf=lambda a: isinstance(a, Array))

return new_f


def _warp_data(data):
def new_f(*args, **kwargs):
return jax.tree_map(_as_jax_array_, data, is_leaf=lambda a: isinstance(a, Array))
return jax.tree.map(_as_jax_array_, data, is_leaf=lambda a: isinstance(a, Array))

return new_f

Expand Down Expand Up @@ -727,7 +726,7 @@ def fun2scan(carry, x):
dyn_vars[k]._value = carry[k]
results = body_fun(*x, **unroll_kwargs)
if progress_bar:
id_tap(lambda *arg: bar.update(), ())
jax.pure_callback(lambda *arg: bar.update(), ())
return dyn_vars.dict_data(), results

if remat:
Expand Down Expand Up @@ -916,15 +915,15 @@ def fun2scan(carry, x):
dyn_vars[k]._value = dyn_vars_data[k]
carry, results = body_fun(carry, x)
if progress_bar:
id_tap(lambda *arg: bar.update(), ())
carry = jax.tree_map(_as_jax_array_, carry, is_leaf=lambda a: isinstance(a, Array))
jax.pure_callback(lambda *arg: bar.update(), ())
carry = jax.tree.map(_as_jax_array_, carry, is_leaf=lambda a: isinstance(a, Array))
return (dyn_vars.dict_data(), carry), results

if remat:
fun2scan = jax.checkpoint(fun2scan)

def call(init, operands):
init = jax.tree_map(_as_jax_array_, init, is_leaf=lambda a: isinstance(a, Array))
init = jax.tree.map(_as_jax_array_, init, is_leaf=lambda a: isinstance(a, Array))
return jax.lax.scan(f=fun2scan,
init=(dyn_vars.dict_data(), init),
xs=operands,
Expand Down
3 changes: 1 addition & 2 deletions brainpy/_src/math/object_transform/jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,9 +491,8 @@ def call_fun(self, *args, **kwargs):

return call_fun


def _make_transform(fun, stack):
@wraps(fun)
# @wraps(fun)
def _transform_function(variable_data: Dict, *args, **kwargs):
for key, v in stack.items():
v._value = variable_data[key]
Expand Down
94 changes: 47 additions & 47 deletions brainpy/_src/math/object_transform/tests/test_autograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -1172,52 +1172,52 @@ def f(a, b):



class TestHessian(unittest.TestCase):
def test_hessian5(self):
bm.set_mode(bm.training_mode)

class RNN(bp.DynamicalSystem):
def __init__(self, num_in, num_hidden):
super(RNN, self).__init__()
self.rnn = bp.dyn.RNNCell(num_in, num_hidden, train_state=True)
self.out = bp.dnn.Dense(num_hidden, 1)

def update(self, x):
return self.out(self.rnn(x))

# define the loss function
def lossfunc(inputs, targets):
runner = bp.DSTrainer(model, progress_bar=False, numpy_mon_after_run=False)
predicts = runner.predict(inputs)
loss = bp.losses.mean_squared_error(predicts, targets)
return loss

model = RNN(1, 2)
data_x = bm.random.rand(1, 1000, 1)
data_y = data_x + bm.random.randn(1, 1000, 1)

bp.reset_state(model, 1)
losshess = bm.hessian(lossfunc, grad_vars=model.train_vars())
hess_matrix = losshess(data_x, data_y)

weights = model.train_vars().unique()

# define the loss function
def loss_func_for_jax(weight_vals, inputs, targets):
for k, v in weight_vals.items():
weights[k].value = v
runner = bp.DSTrainer(model, progress_bar=False, numpy_mon_after_run=False)
predicts = runner.predict(inputs)
loss = bp.losses.mean_squared_error(predicts, targets)
return loss

bp.reset_state(model, 1)
jax_hessian = jax.hessian(loss_func_for_jax, argnums=0)({k: v.value for k, v in weights.items()}, data_x, data_y)

for k, v in hess_matrix.items():
for kk, vv in v.items():
self.assertTrue(bm.allclose(vv, jax_hessian[k][kk], atol=1e-4))

bm.clear_buffer_memory()
# class TestHessian(unittest.TestCase):
# def test_hessian5(self):
# bm.set_mode(bm.training_mode)
#
# class RNN(bp.DynamicalSystem):
# def __init__(self, num_in, num_hidden):
# super(RNN, self).__init__()
# self.rnn = bp.dyn.RNNCell(num_in, num_hidden, train_state=True)
# self.out = bp.dnn.Dense(num_hidden, 1)
#
# def update(self, x):
# return self.out(self.rnn(x))
#
# # define the loss function
# def lossfunc(inputs, targets):
# runner = bp.DSTrainer(model, progress_bar=False, numpy_mon_after_run=False)
# predicts = runner.predict(inputs)
# loss = bp.losses.mean_squared_error(predicts, targets)
# return loss
#
# model = RNN(1, 2)
# data_x = bm.random.rand(1, 1000, 1)
# data_y = data_x + bm.random.randn(1, 1000, 1)
#
# bp.reset_state(model, 1)
# losshess = bm.hessian(lossfunc, grad_vars=model.train_vars())
# hess_matrix = losshess(data_x, data_y)
#
# weights = model.train_vars().unique()
#
# # define the loss function
# def loss_func_for_jax(weight_vals, inputs, targets):
# for k, v in weight_vals.items():
# weights[k].value = v
# runner = bp.DSTrainer(model, progress_bar=False, numpy_mon_after_run=False)
# predicts = runner.predict(inputs)
# loss = bp.losses.mean_squared_error(predicts, targets)
# return loss
#
# bp.reset_state(model, 1)
# jax_hessian = jax.hessian(loss_func_for_jax, argnums=0)({k: v.value for k, v in weights.items()}, data_x, data_y)
#
# for k, v in hess_matrix.items():
# for kk, vv in v.items():
# self.assertTrue(bm.allclose(vv, jax_hessian[k][kk], atol=1e-4))
#
# bm.clear_buffer_memory()


7 changes: 7 additions & 0 deletions brainpy/_src/math/object_transform/tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,10 +239,13 @@ def test1(self):

tree = jax.tree.structure(hh)
leaves = jax.tree.leaves(hh)
# tree = jax.tree.structure(hh)
# leaves = jax.tree.leaves(hh)

print(tree)
print(leaves)
print(jax.tree.unflatten(tree, leaves))
# print(jax.tree.unflatten(tree, leaves))
print()


Expand Down Expand Up @@ -282,12 +285,16 @@ def all_close(x, y):
assert bm.allclose(x, y)

jax.tree.map(all_close, all_states, variables, is_leaf=bm.is_bp_array)
# jax.tree.map(all_close, all_states, variables, is_leaf=bm.is_bp_array)

random_state = jax.tree.map(bm.random.rand_like, all_states, is_leaf=bm.is_bp_array)
jax.tree.map(not_close, random_state, variables, is_leaf=bm.is_bp_array)
# random_state = jax.tree.map(bm.random.rand_like, all_states, is_leaf=bm.is_bp_array)
# jax.tree.map(not_close, random_state, variables, is_leaf=bm.is_bp_array)

obj.load_state_dict(random_state)
jax.tree.map(all_close, random_state, variables, is_leaf=bm.is_bp_array)
# jax.tree.map(all_close, random_state, variables, is_leaf=bm.is_bp_array)



Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def test_nodes():
A.pre = B
B.pre = A

net = bp.dyn.Network(A, B)
net = bp.Network(A, B)
abs_nodes = net.nodes(method='absolute')
rel_nodes = net.nodes(method='relative')
print()
Expand Down
4 changes: 2 additions & 2 deletions brainpy/_src/math/object_transform/tests/test_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import brainpy as bp


class GABAa_without_Variable(bp.TwoEndConn):
class GABAa_without_Variable(bp.synapses.TwoEndConn):
def __init__(self, pre, post, conn, delay=0., g_max=0.1, E=-75.,
alpha=12., beta=0.1, T=1.0, T_duration=1.0, **kwargs):
super(GABAa_without_Variable, self).__init__(pre=pre, post=post, **kwargs)
Expand Down Expand Up @@ -192,7 +192,7 @@ def test_neu_nodes_1():
assert len(neu.nodes(method='relative', include_self=False)) == 1


class GABAa_with_Variable(bp.TwoEndConn):
class GABAa_with_Variable(bp.synapses.TwoEndConn):
def __init__(self, pre, post, conn, delay=0., g_max=0.1, E=-75.,
alpha=12., beta=0.1, T=1.0, T_duration=1.0, **kwargs):
super(GABAa_with_Variable, self).__init__(pre=pre, post=post, **kwargs)
Expand Down
2 changes: 1 addition & 1 deletion brainpy/_src/math/object_transform/tests/test_controls.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ def f1():
branches=[f1,
lambda: 2, lambda: 3,
lambda: 4, lambda: 5],
dyn_vars=var_a,
# dyn_vars=var_a,
show_code=True)

self.assertTrue(f(11) == 1)
Expand Down
2 changes: 1 addition & 1 deletion brainpy/_src/math/object_transform/tests/test_jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ class MyObj:
def __init__(self):
self.a = bm.Variable(bm.ones(2))

@bm.cls_jit(static_argnums=1)
@bm.cls_jit(static_argnums=0)
def f(self, b, c):
self.a.value *= b
self.a.value /= c
Expand Down
Loading
Loading