diff --git a/brainpy/_src/integrators/runner.py b/brainpy/_src/integrators/runner.py
index 11dd42f5..dae638e1 100644
--- a/brainpy/_src/integrators/runner.py
+++ b/brainpy/_src/integrators/runner.py
@@ -9,7 +9,6 @@
 import jax.numpy as jnp
 import numpy as np
 import tqdm.auto
-from jax.experimental.host_callback import id_tap
 from jax.tree_util import tree_flatten
 
 from brainpy import math as bm
@@ -245,7 +244,7 @@ def _step_fun_integrator(self, static_args, dyn_args, t, i):
 
     # progress bar
     if self.progress_bar:
-      id_tap(lambda *args: self._pbar.update(), ())
+      jax.pure_callback(lambda *args: self._pbar.update(), ())
 
     # return of function monitors
     shared = dict(t=t + self.dt, dt=self.dt, i=i)
diff --git a/brainpy/_src/math/object_transform/controls.py b/brainpy/_src/math/object_transform/controls.py
index 3edeb08e..126ca15c 100644
--- a/brainpy/_src/math/object_transform/controls.py
+++ b/brainpy/_src/math/object_transform/controls.py
@@ -7,7 +7,6 @@
 import jax
 import jax.numpy as jnp
 from jax.errors import UnexpectedTracerError
-from jax.experimental.host_callback import id_tap
 from jax.tree_util import tree_flatten, tree_unflatten
 from tqdm.auto import tqdm
 
@@ -421,14 +420,14 @@ def call(pred, x=None):
 
   def _warp(f):
     @functools.wraps(f)
     def new_f(*args, **kwargs):
-      return jax.tree_map(_as_jax_array_, f(*args, **kwargs), is_leaf=lambda a: isinstance(a, Array))
+      return jax.tree.map(_as_jax_array_, f(*args, **kwargs), is_leaf=lambda a: isinstance(a, Array))
 
     return new_f
 
   def _warp_data(data):
     def new_f(*args, **kwargs):
-      return jax.tree_map(_as_jax_array_, data, is_leaf=lambda a: isinstance(a, Array))
+      return jax.tree.map(_as_jax_array_, data, is_leaf=lambda a: isinstance(a, Array))
 
     return new_f
 
@@ -727,7 +726,7 @@ def fun2scan(carry, x):
       dyn_vars[k]._value = carry[k]
     results = body_fun(*x, **unroll_kwargs)
     if progress_bar:
-      id_tap(lambda *arg: bar.update(), ())
+      jax.pure_callback(lambda *arg: bar.update(), ())
     return dyn_vars.dict_data(), results
 
   if remat:
@@ -916,15 +915,15 @@ def fun2scan(carry, x):
       dyn_vars[k]._value = dyn_vars_data[k]
     carry, results = body_fun(carry, x)
     if progress_bar:
-      id_tap(lambda *arg: bar.update(), ())
-    carry = jax.tree_map(_as_jax_array_, carry, is_leaf=lambda a: isinstance(a, Array))
+      jax.pure_callback(lambda *arg: bar.update(), ())
+    carry = jax.tree.map(_as_jax_array_, carry, is_leaf=lambda a: isinstance(a, Array))
     return (dyn_vars.dict_data(), carry), results
 
   if remat:
     fun2scan = jax.checkpoint(fun2scan)
 
   def call(init, operands):
-    init = jax.tree_map(_as_jax_array_, init, is_leaf=lambda a: isinstance(a, Array))
+    init = jax.tree.map(_as_jax_array_, init, is_leaf=lambda a: isinstance(a, Array))
     return jax.lax.scan(f=fun2scan,
                         init=(dyn_vars.dict_data(), init),
                         xs=operands,
diff --git a/brainpy/_src/math/random.py b/brainpy/_src/math/random.py
index 9ae012bc..2b647683 100644
--- a/brainpy/_src/math/random.py
+++ b/brainpy/_src/math/random.py
@@ -10,7 +10,6 @@
 import numpy as np
 from jax import lax, jit, vmap, numpy as jnp, random as jr, core, dtypes
 from jax._src.array import ArrayImpl
-from jax.experimental.host_callback import call
 from jax.tree_util import register_pytree_node_class
 
 from brainpy.check import jit_error_checking, jit_error_checking_no_args
@@ -1233,9 +1232,9 @@ def zipf(self, a, size: Optional[Union[int, Sequence[int]]] = None, key: Optiona
     if size is None:
       size = jnp.shape(a)
     dtype = jax.dtypes.canonicalize_dtype(jnp.int_)
-    r = call(lambda x: np.random.zipf(x, size).astype(dtype),
-             a,
-             result_shape=jax.ShapeDtypeStruct(size, dtype))
+    r = jax.pure_callback(lambda x: np.random.zipf(x, size).astype(dtype),
+                          jax.ShapeDtypeStruct(size, dtype),
+                          a)
     return _return(r)
 
   def power(self, a, size: Optional[Union[int, Sequence[int]]] = None, key: Optional[Union[int, JAX_RAND_KEY]] = None):
@@ -1244,9 +1243,9 @@ def power(self, a, size: Optional[Union[int, Sequence[int]]] = None, key: Option
       size = jnp.shape(a)
     size = _size2shape(size)
     dtype = jax.dtypes.canonicalize_dtype(jnp.float_)
-    r = call(lambda a: np.random.power(a=a, size=size).astype(dtype),
-             a,
-             result_shape=jax.ShapeDtypeStruct(size, dtype))
+    r = jax.pure_callback(lambda a: np.random.power(a=a, size=size).astype(dtype),
+                          jax.ShapeDtypeStruct(size, dtype),
+                          a)
     return _return(r)
 
   def f(self, dfnum, dfden, size: Optional[Union[int, Sequence[int]]] = None,
@@ -1260,11 +1259,11 @@ def f(self, dfnum, dfden, size: Optional[Union[int, Sequence[int]]] = None,
     size = _size2shape(size)
     d = {'dfnum': dfnum, 'dfden': dfden}
     dtype = jax.dtypes.canonicalize_dtype(jnp.float_)
-    r = call(lambda x: np.random.f(dfnum=x['dfnum'],
-                                   dfden=x['dfden'],
-                                   size=size).astype(dtype),
-             d,
-             result_shape=jax.ShapeDtypeStruct(size, dtype))
+    r = jax.pure_callback(lambda x: np.random.f(dfnum=x['dfnum'],
+                                                dfden=x['dfden'],
+                                                size=size).astype(dtype),
+                          jax.ShapeDtypeStruct(size, dtype),
+                          d)
     return _return(r)
 
   def hypergeometric(self, ngood, nbad, nsample, size: Optional[Union[int, Sequence[int]]] = None,
@@ -1280,12 +1279,12 @@ def hypergeometric(self, ngood, nbad, nsample, size: Optional[Union[int, Sequenc
     size = _size2shape(size)
     dtype = jax.dtypes.canonicalize_dtype(jnp.int_)
     d = {'ngood': ngood, 'nbad': nbad, 'nsample': nsample}
-    r = call(lambda d: np.random.hypergeometric(ngood=d['ngood'],
-                                                 nbad=d['nbad'],
-                                                 nsample=d['nsample'],
-                                                 size=size).astype(dtype),
-             d,
-             result_shape=jax.ShapeDtypeStruct(size, dtype))
+    r = jax.pure_callback(lambda d: np.random.hypergeometric(ngood=d['ngood'],
+                                                              nbad=d['nbad'],
+                                                              nsample=d['nsample'],
+                                                              size=size).astype(dtype),
+                          jax.ShapeDtypeStruct(size, dtype),
+                          d)
     return _return(r)
 
   def logseries(self, p, size: Optional[Union[int, Sequence[int]]] = None,
@@ -1295,9 +1294,9 @@ def logseries(self, p, size: Optional[Union[int, Sequence[int]]] = None,
       size = jnp.shape(p)
     size = _size2shape(size)
     dtype = jax.dtypes.canonicalize_dtype(jnp.int_)
-    r = call(lambda p: np.random.logseries(p=p, size=size).astype(dtype),
-             p,
-             result_shape=jax.ShapeDtypeStruct(size, dtype))
+    r = jax.pure_callback(lambda p: np.random.logseries(p=p, size=size).astype(dtype),
+                          jax.ShapeDtypeStruct(size, dtype),
+                          p)
     return _return(r)
 
   def noncentral_f(self, dfnum, dfden, nonc, size: Optional[Union[int, Sequence[int]]] = None,
@@ -1312,11 +1311,12 @@ def noncentral_f(self, dfnum, dfden, nonc, size: Optional[Union[int, Sequence[in
     size = _size2shape(size)
     d = {'dfnum': dfnum, 'dfden': dfden, 'nonc': nonc}
     dtype = jax.dtypes.canonicalize_dtype(jnp.float_)
-    r = call(lambda x: np.random.noncentral_f(dfnum=x['dfnum'],
-                                               dfden=x['dfden'],
-                                               nonc=x['nonc'],
-                                               size=size).astype(dtype),
-             d, result_shape=jax.ShapeDtypeStruct(size, dtype))
+    r = jax.pure_callback(lambda x: np.random.noncentral_f(dfnum=x['dfnum'],
+                                                            dfden=x['dfden'],
+                                                            nonc=x['nonc'],
+                                                            size=size).astype(dtype),
+                          jax.ShapeDtypeStruct(size, dtype),
+                          d)
     return _return(r)
 
   # PyTorch compatibility #
diff --git a/brainpy/_src/runners.py b/brainpy/_src/runners.py
index 980ef998..80609608 100644
--- a/brainpy/_src/runners.py
+++ b/brainpy/_src/runners.py
@@ -11,7 +11,6 @@
 import jax.numpy as jnp
 import numpy as np
 import tqdm.auto
-from jax.experimental.host_callback import id_tap
 from jax.tree_util import tree_map, tree_flatten
 
 from brainpy import math as bm, tools
@@ -632,12 +631,17 @@ def _step_func_predict(self, i, *x, shared_args=None):
 
     # finally
    if self.progress_bar:
-      id_tap(lambda *arg: self._pbar.update(), ())
+      jax.pure_callback(lambda: self._pbar.update(), ())
     # share.clear_shargs()
     clear_input(self.target)
 
     if self._memory_efficient:
-      id_tap(self._step_mon_on_cpu, mon)
+      mon_shape_dtype = jax.ShapeDtypeStruct(mon.shape, mon.dtype)
+      result = jax.pure_callback(
+        self._step_mon_on_cpu,
+        mon_shape_dtype,
+        mon,
+      )
       return out, None
     else:
       return out, mon
diff --git a/brainpy/_src/train/offline.py b/brainpy/_src/train/offline.py
index 2bfa419d..e801a29e 100644
--- a/brainpy/_src/train/offline.py
+++ b/brainpy/_src/train/offline.py
@@ -2,9 +2,9 @@
 
 from typing import Dict, Sequence, Union, Callable, Any
 
+import jax
 import numpy as np
 import tqdm.auto
-from jax.experimental.host_callback import id_tap
 
 import brainpy.math as bm
 from brainpy import tools
@@ -219,7 +219,7 @@ def _fun_train(self,
       targets = target_data[node.name]
       node.offline_fit(targets, fit_record)
       if self.progress_bar:
-        id_tap(lambda *args: self._pbar.update(), ())
+        jax.pure_callback(lambda *args: self._pbar.update(), ())
 
   def _step_func_monitor(self):
     res = dict()
diff --git a/brainpy/_src/train/online.py b/brainpy/_src/train/online.py
index d80764f2..862db8df 100644
--- a/brainpy/_src/train/online.py
+++ b/brainpy/_src/train/online.py
@@ -2,9 +2,9 @@
 import functools
 from typing import Dict, Sequence, Union, Callable
 
+import jax
 import numpy as np
 import tqdm.auto
-from jax.experimental.host_callback import id_tap
 from jax.tree_util import tree_map
 
 from brainpy import math as bm, tools
@@ -252,7 +252,7 @@ def _step_func_fit(self, i, xs: Sequence, ys: Dict, shared_args=None):
 
     # finally
     if self.progress_bar:
-      id_tap(lambda *arg: self._pbar.update(), ())
+      jax.pure_callback(lambda *arg: self._pbar.update(), ())
     return out, monitors
 
   def _check_interface(self):
diff --git a/brainpy/check.py b/brainpy/check.py
index fafc0551..1f809d84 100644
--- a/brainpy/check.py
+++ b/brainpy/check.py
@@ -7,7 +7,6 @@
 import numpy as np
 import numpy as onp
 from jax import numpy as jnp
-from jax.experimental.host_callback import id_tap
 from jax.lax import cond
 
 conn = None
@@ -570,7 +569,11 @@ def is_all_objs(targets: Any, out_as: str = 'tuple'):
 
 
 def _err_jit_true_branch(err_fun, x):
-  id_tap(err_fun, x)
+  if isinstance(x, (tuple, list)):
+    x_shape_dtype = tuple(jax.ShapeDtypeStruct(arr.shape, arr.dtype) for arr in x)
+  else:
+    x_shape_dtype = jax.ShapeDtypeStruct(x.shape, x.dtype)
+  jax.pure_callback(err_fun, x_shape_dtype, x)
   return
 
 
@@ -629,6 +632,6 @@ def true_err_fun(arg, transforms):
     raise err
 
   cond(remove_vmap(as_jax(pred)),
-       lambda: id_tap(true_err_fun, None),
+       lambda: jax.pure_callback(true_err_fun, None),
        lambda: None)
 
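
Note on the recurring progress-bar change: every side-effect-only `id_tap(fn, ())` above becomes `jax.pure_callback(fn, ())`, where the second positional argument of `pure_callback` is the expected result pytree (empty here). The snippet below is a minimal, self-contained sketch of that pattern, not part of the diff; the names `step`, `_tick`, `bar`, and `n_steps` are illustrative assumptions.

# Sketch: progress bar driven from inside a traced scan via jax.pure_callback.
# Assumes tqdm is installed; this is not brainpy API.
import jax
import jax.numpy as jnp
from tqdm.auto import tqdm

n_steps = 100
bar = tqdm(total=n_steps)

def _tick():
  bar.update()  # tqdm's return value is dropped so the callback returns None

def step(carry, x):
  carry = carry + x
  # result spec is an empty pytree: the callback is used only for its side effect
  jax.pure_callback(_tick, ())
  return carry, carry

_, ys = jax.lax.scan(step, jnp.float32(0.), jnp.ones(n_steps))
bar.close()

One design caveat, stated as a hedge rather than a claim about this PR: `jax.pure_callback` assumes the callback is pure, so a call whose (empty) result goes unused may in principle be dropped by dead-code elimination; `jax.debug.callback` is the documented alternative when only the side effect matters.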
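For the numpy-backed samplers in brainpy/_src/math/random.py, the rewrite also reorders arguments: `host_callback.call` took the operand second and the output spec via the `result_shape=` keyword, whereas `jax.pure_callback` takes the `jax.ShapeDtypeStruct` spec as its second positional argument with the operands after it. A hedged sketch of the zipf case follows; `zipf_via_callback` is an illustrative wrapper, not brainpy API.

# Sketch of the call -> pure_callback rewrite for a numpy-backed sampler.
import jax
import jax.numpy as jnp
import numpy as np

def zipf_via_callback(a, size):
  dtype = jax.dtypes.canonicalize_dtype(jnp.int_)
  # pure_callback(callback, result_shape_dtypes, *args): the output spec moves
  # from the old result_shape= keyword to the second positional slot; operands follow.
  return jax.pure_callback(lambda x: np.random.zipf(x, size).astype(dtype),
                           jax.ShapeDtypeStruct(size, dtype),
                           a)

# `size` stays static because it fixes the callback's output shape.
samples = jax.jit(zipf_via_callback, static_argnums=1)(2.0, (3,))

The same reordering applies to power, f, hypergeometric, logseries, and noncentral_f above; dict operands such as d = {'dfnum': ..., 'dfden': ...} pass through unchanged because pure_callback accepts arbitrary pytrees of arrays.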