Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
chaoming0625 committed Feb 21, 2024
1 parent da8807c commit 05784b9
Showing 1 changed file with 14 additions and 14 deletions.
28 changes: 14 additions & 14 deletions brainpy/_src/math/object_transform/jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,9 +146,9 @@ def _get_transform(self, *args, **kwargs):
with VariableStack() as self._dyn_vars:
rets = eval_shape(self.fun,
*args,
**kwargs,
static_argnums=self._static_argnums,
static_argnames=self._static_argnames,
**kwargs)
static_argnames=self._static_argnames,)
# in_shardings
if self._in_shardings is None:
in_shardings = None
Expand All @@ -174,18 +174,18 @@ def _get_transform(self, *args, **kwargs):
_dyn_vars_sharing = get_shardings(self._dyn_vars.subset_by_not_instance(RandomState))
out_shardings = (_dyn_vars_sharing,) + out_shardings

# jit
self._transform = jax.jit(
self._transform_function,
static_argnums=jax.tree_util.tree_map(lambda a: a + 1, self._static_argnums),
static_argnames=self._static_argnames,
donate_argnums=self._donate_argnums,
inline=self._inline,
keep_unused=self._keep_unused,
abstracted_axes=self._abstracted_axes,
in_shardings=in_shardings,
out_shardings=out_shardings,
)
# jit
self._transform = jax.jit(
self._transform_function,
static_argnums=jax.tree_util.tree_map(lambda a: a + 1, self._static_argnums),
static_argnames=self._static_argnames,
donate_argnums=self._donate_argnums,
inline=self._inline,
keep_unused=self._keep_unused,
abstracted_axes=self._abstracted_axes,
in_shardings=in_shardings,
out_shardings=out_shardings,
)
return rets

def __call__(self, *args, **kwargs):
Expand Down

0 comments on commit 05784b9

Please sign in to comment.