diff --git a/brainpy/_src/math/object_transform/jit.py b/brainpy/_src/math/object_transform/jit.py index 394a3dd3..6c729e1d 100644 --- a/brainpy/_src/math/object_transform/jit.py +++ b/brainpy/_src/math/object_transform/jit.py @@ -146,9 +146,9 @@ def _get_transform(self, *args, **kwargs): with VariableStack() as self._dyn_vars: rets = eval_shape(self.fun, *args, + **kwargs, static_argnums=self._static_argnums, - static_argnames=self._static_argnames, - **kwargs) + static_argnames=self._static_argnames,) # in_shardings if self._in_shardings is None: in_shardings = None @@ -174,18 +174,18 @@ def _get_transform(self, *args, **kwargs): _dyn_vars_sharing = get_shardings(self._dyn_vars.subset_by_not_instance(RandomState)) out_shardings = (_dyn_vars_sharing,) + out_shardings - # jit - self._transform = jax.jit( - self._transform_function, - static_argnums=jax.tree_util.tree_map(lambda a: a + 1, self._static_argnums), - static_argnames=self._static_argnames, - donate_argnums=self._donate_argnums, - inline=self._inline, - keep_unused=self._keep_unused, - abstracted_axes=self._abstracted_axes, - in_shardings=in_shardings, - out_shardings=out_shardings, - ) + # jit + self._transform = jax.jit( + self._transform_function, + static_argnums=jax.tree_util.tree_map(lambda a: a + 1, self._static_argnums), + static_argnames=self._static_argnames, + donate_argnums=self._donate_argnums, + inline=self._inline, + keep_unused=self._keep_unused, + abstracted_axes=self._abstracted_axes, + in_shardings=in_shardings, + out_shardings=out_shardings, + ) return rets def __call__(self, *args, **kwargs):