Replace jax.experimental.host_callback with jax.pure_callback
Routhleck committed May 6, 2024
1 parent 4b5b61f commit 3cf58ce
Showing 6 changed files with 33 additions and 15 deletions.
3 changes: 2 additions & 1 deletion brainpy/_src/integrators/runner.py
@@ -245,7 +245,8 @@ def _step_fun_integrator(self, static_args, dyn_args, t, i):

# progress bar
if self.progress_bar:
id_tap(lambda *args: self._pbar.update(), ())
jax.pure_callback(lambda *args: self._pbar.update(), ())
# id_tap(lambda *args: self._pbar.update(), ())

# return of function monitors
shared = dict(t=t + self.dt, dt=self.dt, i=i)
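For reference, a minimal runnable sketch of this progress-bar pattern in isolation, assuming a hypothetical tqdm bar named pbar; the callback returns None so that it matches the empty output spec ():

import jax
import jax.numpy as jnp
import tqdm.auto as tqdm

pbar = tqdm.tqdm(total=100)  # hypothetical bar sized to the number of steps

def _tick(*args):
    pbar.update()  # host-side side effect
    return None    # no outputs, matching result_shape_dtypes=()

@jax.jit
def step(x):
    jax.pure_callback(_tick, ())  # empty output spec: the call exists only for its side effect
    return x + 1.0

x = jnp.float32(0.0)
for _ in range(100):
    x = step(x)
pbar.close()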
6 changes: 4 additions & 2 deletions brainpy/_src/math/object_transform/controls.py
@@ -727,7 +727,8 @@ def fun2scan(carry, x):
dyn_vars[k]._value = carry[k]
results = body_fun(*x, **unroll_kwargs)
if progress_bar:
id_tap(lambda *arg: bar.update(), ())
jax.pure_callback(lambda *arg: bar.update(), ())
# id_tap(lambda *arg: bar.update(), ())
return dyn_vars.dict_data(), results

if remat:
@@ -916,7 +917,8 @@ def fun2scan(carry, x):
dyn_vars[k]._value = dyn_vars_data[k]
carry, results = body_fun(carry, x)
if progress_bar:
id_tap(lambda *arg: bar.update(), ())
jax.pure_callback(lambda *arg: bar.update(), ())
# id_tap(lambda *arg: bar.update(), ())
carry = jax.tree_map(_as_jax_array_, carry, is_leaf=lambda a: isinstance(a, Array))
return (dyn_vars.dict_data(), carry), results

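The same pattern inside jax.lax.scan, sketched with a hypothetical bar sized to the number of scan steps; the callback fires once per step on the host:

import jax
import jax.numpy as jnp
import tqdm.auto as tqdm

n_steps = 1000
bar = tqdm.tqdm(total=n_steps)  # hypothetical bar

def _tick(*arg):
    bar.update()
    return None  # matches the empty output spec ()

def body(carry, x):
    carry = carry + x
    jax.pure_callback(_tick, ())  # progress update, no outputs
    return carry, carry

carry, ys = jax.lax.scan(body, jnp.float32(0.0), jnp.ones(n_steps))
bar.close()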
11 changes: 9 additions & 2 deletions brainpy/_src/runners.py
@@ -632,12 +632,19 @@ def _step_func_predict(self, i, *x, shared_args=None):

# finally
if self.progress_bar:
id_tap(lambda *arg: self._pbar.update(), ())
jax.pure_callback(lambda: self._pbar.update(), ())
# id_tap(lambda *arg: self._pbar.update(), ())
# share.clear_shargs()
clear_input(self.target)

if self._memory_efficient:
id_tap(self._step_mon_on_cpu, mon)
mon_shape_dtype = jax.ShapeDtypeStruct(mon.shape, mon.dtype)
result = jax.pure_callback(
self._step_mon_on_cpu,
mon_shape_dtype,
mon,
)
# id_tap(self._step_mon_on_cpu, mon)
return out, None
else:
return out, mon
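A sketch of the ShapeDtypeStruct idiom used in the memory-efficient branch above: the shape and dtype of the host function's result are declared up front so pure_callback can return data into the traced computation. The host function and shapes here are hypothetical:

import jax
import jax.numpy as jnp
import numpy as np

def host_process(mon):
    # Runs on the host with a NumPy array, e.g. downcast before storing.
    return np.asarray(mon, dtype=np.float16)

@jax.jit
def step(mon):
    out_spec = jax.ShapeDtypeStruct(mon.shape, jnp.float16)  # what the callback will return
    return jax.pure_callback(host_process, out_spec, mon)

print(step(jnp.ones((4, 8))).dtype)  # float16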
4 changes: 3 additions & 1 deletion brainpy/_src/train/offline.py
@@ -2,6 +2,7 @@

from typing import Dict, Sequence, Union, Callable, Any

import jax
import numpy as np
import tqdm.auto
from jax.experimental.host_callback import id_tap
@@ -219,7 +220,8 @@ def _fun_train(self,
targets = target_data[node.name]
node.offline_fit(targets, fit_record)
if self.progress_bar:
id_tap(lambda *args: self._pbar.update(), ())
jax.pure_callback(lambda *args: self._pbar.update(), ())
# id_tap(lambda *args: self._pbar.update(), ())

def _step_func_monitor(self):
res = dict()
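Because pure_callback treats the callback as functionally pure, JAX may elide calls whose results are unused; for side-effect-only updates like these progress ticks, jax.debug.callback is a documented alternative (not what this commit uses), sketched here with a hypothetical bar:

import jax
import tqdm.auto as tqdm

pbar = tqdm.tqdm(total=200)  # hypothetical bar

def _tick():
    pbar.update()

@jax.jit
def train_step(w, g):
    jax.debug.callback(_tick)  # side-effect-only host call; no output spec needed
    return w - 0.01 * g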
5 changes: 3 additions & 2 deletions brainpy/_src/train/online.py
@@ -2,6 +2,7 @@
import functools
from typing import Dict, Sequence, Union, Callable

import jax
import numpy as np
import tqdm.auto
from jax.experimental.host_callback import id_tap
@@ -23,7 +24,6 @@
'ForceTrainer',
]


class OnlineTrainer(DSTrainer):
"""Online trainer for models with recurrent dynamics.
@@ -252,7 +252,8 @@ def _step_func_fit(self, i, xs: Sequence, ys: Dict, shared_args=None):

# finally
if self.progress_bar:
id_tap(lambda *arg: self._pbar.update(), ())
jax.pure_callback(lambda *arg: self._pbar.update(), ())
# id_tap(lambda *arg: self._pbar.update(), ())
return out, monitors

def _check_interface(self):
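jax.experimental.io_callback is the other callback in the same migration path; it is intended for impure callbacks and can be ordered relative to other io_callback calls. A hedged sketch with a hypothetical host-side logger:

import jax
import jax.numpy as jnp
from jax.experimental import io_callback

def log_loss(loss):
    print(f'loss = {float(loss):.4f}')  # runs on the host

@jax.jit
def fit_step(w, x, y):
    def loss_fn(w):
        return jnp.mean((x @ w - y) ** 2)
    loss, grads = jax.value_and_grad(loss_fn)(w)
    io_callback(log_loss, None, loss, ordered=True)  # None: the callback returns nothing
    return w - 0.1 * grads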
19 changes: 12 additions & 7 deletions brainpy/check.py
@@ -351,12 +351,12 @@ def is_float(
if min_bound is not None:
jit_error_checking_no_args(value < min_bound,
ValueError(f"{name} must be a float bigger than {min_bound}, "
f"while we got {value}"))
f"while we got {value}"))

if max_bound is not None:
jit_error_checking_no_args(value > max_bound,
ValueError(f"{name} must be a float smaller than {max_bound}, "
f"while we got {value}"))
f"while we got {value}"))
return value


@@ -389,11 +389,11 @@ def is_integer(value: int, name=None, min_bound=None, max_bound=None, allow_none
if min_bound is not None:
jit_error_checking_no_args(jnp.any(value < min_bound),
ValueError(f"{name} must be an int bigger than {min_bound}, "
f"while we got {value}"))
f"while we got {value}"))
if max_bound is not None:
jit_error_checking_no_args(jnp.any(value > max_bound),
ValueError(f"{name} must be an int smaller than {max_bound}, "
f"while we got {value}"))
f"while we got {value}"))
return value


@@ -570,7 +570,12 @@ def is_all_objs(targets: Any, out_as: str = 'tuple'):


def _err_jit_true_branch(err_fun, x):
id_tap(err_fun, x)
if isinstance(x, (tuple, list)):
x_shape_dtype = tuple(jax.ShapeDtypeStruct(arr.shape, arr.dtype) for arr in x)
else:
x_shape_dtype = jax.ShapeDtypeStruct(x.shape, x.dtype)
jax.pure_callback(err_fun, x_shape_dtype, x)
# id_tap(err_fun, x)
return
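
The tuple/array branching above can also be written as a single tree map over any pytree of arrays; a small sketch:

import jax

def shape_dtype_of(x):
    # Build a ShapeDtypeStruct spec for a single array, a tuple/list, or any pytree of arrays.
    return jax.tree_util.tree_map(lambda a: jax.ShapeDtypeStruct(a.shape, a.dtype), x)

# e.g. jax.pure_callback(err_fun, shape_dtype_of(x), x)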


@@ -629,6 +634,6 @@ def true_err_fun(arg, transforms):
raise err

cond(remove_vmap(as_jax(pred)),
lambda: id_tap(true_err_fun, None),
# lambda: id_tap(true_err_fun, None),
lambda: jax.pure_callback(true_err_fun, None),
lambda: None)
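
A self-contained sketch of the jit-time check idiom above, written with jax.debug.callback for the raising branch (an alternative to the pure_callback used in the commit); the helper name and message are hypothetical, and an exception raised inside the host callback surfaces as a runtime error when the result is awaited:

import jax
import jax.numpy as jnp
from jax import lax

def check_all_positive(x, name='x'):
    def _raise():
        raise ValueError(f'{name} must be positive.')

    # Both branches return nothing; only the true branch touches the host.
    lax.cond(jnp.any(x <= 0),
             lambda: jax.debug.callback(_raise),
             lambda: None)
    return x

@jax.jit
def f(x):
    return jnp.log(check_all_positive(x))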
