Skip to content

Commit

Permalink
updates
Browse files Browse the repository at this point in the history
  • Loading branch information
chaoming0625 committed Feb 22, 2024
1 parent 007bae6 commit 9d9cd01
Showing 1 changed file with 23 additions and 23 deletions.
46 changes: 23 additions & 23 deletions brainpy/_src/math/object_transform/jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,31 +148,31 @@ def _get_transform(self, *args, **kwargs):
*args,
**kwargs,
static_argnums=self._static_argnums,
static_argnames=self._static_argnames,)
# in_shardings
if self._in_shardings is None:
in_shardings = None
else:
if isinstance(self._in_shardings, (tuple, list)):
in_shardings = tuple(self._in_shardings)
static_argnames=self._static_argnames)
# in_shardings
if self._in_shardings is None:
in_shardings = None
else:
in_shardings = (self._in_shardings,)
_dyn_vars_sharing = get_shardings(self._dyn_vars)
in_shardings = (_dyn_vars_sharing,) + in_shardings

# out_shardings
if self._out_shardings is None:
out_shardings = None
else:
if isinstance(self._out_shardings, (tuple, list)):
out_shardings = tuple(self._out_shardings)
if isinstance(self._in_shardings, (tuple, list)):
in_shardings = tuple(self._in_shardings)
else:
in_shardings = (self._in_shardings,)
_dyn_vars_sharing = get_shardings(self._dyn_vars)
in_shardings = (_dyn_vars_sharing,) + in_shardings

# out_shardings
if self._out_shardings is None:
out_shardings = None
else:
out_shardings = (self._out_shardings,)
global RandomState
if RandomState is None:
from brainpy.math.random import RandomState
_dyn_vars_sharing = get_shardings(self._dyn_vars.subset_by_not_instance(RandomState))
out_shardings = (_dyn_vars_sharing,) + out_shardings
if isinstance(self._out_shardings, (tuple, list)):
out_shardings = tuple(self._out_shardings)
else:
out_shardings = (self._out_shardings,)
global RandomState
if RandomState is None:
from brainpy.math.random import RandomState
_dyn_vars_sharing = get_shardings(self._dyn_vars.subset_by_not_instance(RandomState))
out_shardings = (_dyn_vars_sharing,) + out_shardings

# jit
self._transform = jax.jit(
Expand Down

0 comments on commit 9d9cd01

Please sign in to comment.