From 9d9cd0123e22637eb4045889570c563cbc34326a Mon Sep 17 00:00:00 2001 From: Chaoming Wang Date: Thu, 22 Feb 2024 10:20:24 +0800 Subject: [PATCH] updates --- brainpy/_src/math/object_transform/jit.py | 46 +++++++++++------------ 1 file changed, 23 insertions(+), 23 deletions(-) diff --git a/brainpy/_src/math/object_transform/jit.py b/brainpy/_src/math/object_transform/jit.py index 6c729e1d..73eab2f9 100644 --- a/brainpy/_src/math/object_transform/jit.py +++ b/brainpy/_src/math/object_transform/jit.py @@ -148,31 +148,31 @@ def _get_transform(self, *args, **kwargs): *args, **kwargs, static_argnums=self._static_argnums, - static_argnames=self._static_argnames,) - # in_shardings - if self._in_shardings is None: - in_shardings = None - else: - if isinstance(self._in_shardings, (tuple, list)): - in_shardings = tuple(self._in_shardings) + static_argnames=self._static_argnames) + # in_shardings + if self._in_shardings is None: + in_shardings = None else: - in_shardings = (self._in_shardings,) - _dyn_vars_sharing = get_shardings(self._dyn_vars) - in_shardings = (_dyn_vars_sharing,) + in_shardings - - # out_shardings - if self._out_shardings is None: - out_shardings = None - else: - if isinstance(self._out_shardings, (tuple, list)): - out_shardings = tuple(self._out_shardings) + if isinstance(self._in_shardings, (tuple, list)): + in_shardings = tuple(self._in_shardings) + else: + in_shardings = (self._in_shardings,) + _dyn_vars_sharing = get_shardings(self._dyn_vars) + in_shardings = (_dyn_vars_sharing,) + in_shardings + + # out_shardings + if self._out_shardings is None: + out_shardings = None else: - out_shardings = (self._out_shardings,) - global RandomState - if RandomState is None: - from brainpy.math.random import RandomState - _dyn_vars_sharing = get_shardings(self._dyn_vars.subset_by_not_instance(RandomState)) - out_shardings = (_dyn_vars_sharing,) + out_shardings + if isinstance(self._out_shardings, (tuple, list)): + out_shardings = tuple(self._out_shardings) + else: + out_shardings = (self._out_shardings,) + global RandomState + if RandomState is None: + from brainpy.math.random import RandomState + _dyn_vars_sharing = get_shardings(self._dyn_vars.subset_by_not_instance(RandomState)) + out_shardings = (_dyn_vars_sharing,) + out_shardings # jit self._transform = jax.jit(