diff --git a/brainpy/__init__.py b/brainpy/__init__.py index c52358720..c8f834c6d 100644 --- a/brainpy/__init__.py +++ b/brainpy/__init__.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- -__version__ = "2.4.6.post4" +__version__ = "2.4.6.post5" # fundamental supporting modules from brainpy import errors, check, tools diff --git a/brainpy/_src/math/ndarray.py b/brainpy/_src/math/ndarray.py index 61746c038..81ccba932 100644 --- a/brainpy/_src/math/ndarray.py +++ b/brainpy/_src/math/ndarray.py @@ -108,6 +108,15 @@ def sharding(self): def addressable_shards(self): return self._value.addressable_shards + def __check_shape_dtype(self, other_value, self_value): + if other_value.shape != self_value.shape: + raise MathError(f"The shape of the original data is {self_value.shape}, " + f"while we got {other_value.shape}.") + if other_value.dtype != self_value.dtype: + raise MathError(f"The dtype of the original data is {self_value.dtype}, " + f"while we got {other_value.dtype}.") + + @property def value(self): # return the value @@ -126,12 +135,8 @@ def value(self, value): else: value = jnp.asarray(value) # check - if value.shape != self_value.shape: - raise MathError(f"The shape of the original data is {self_value.shape}, " - f"while we got {value.shape}.") - if value.dtype != self_value.dtype: - raise MathError(f"The dtype of the original data is {self_value.dtype}, " - f"while we got {value.dtype}.") + self.__check_shape_dtype(value, self_value) + # assign self._value = value def update(self, value): @@ -1569,12 +1574,8 @@ def value(self, value): else: value = jnp.asarray(value) # check - if value.shape != self_value.shape: - raise MathError(f"The shape of the original data is {self_value.shape}, " - f"while we got {value.shape}.") - if value.dtype != self_value.dtype: - raise MathError(f"The dtype of the original data is {self_value.dtype}, " - f"while we got {value.dtype}.") + self.__check_shape_dtype(value, self_value) + # assign self._value = value diff --git a/brainpy/_src/math/object_transform/variables.py b/brainpy/_src/math/object_transform/variables.py index 5014da0bf..35eea1774 100644 --- a/brainpy/_src/math/object_transform/variables.py +++ b/brainpy/_src/math/object_transform/variables.py @@ -240,14 +240,6 @@ def current_transform_number(): return len(transform_stack) -def _stack_add_read(var: 'Variable'): - pass - - -def _stack_add_write(var: 'Variable'): - pass - - @register_pytree_node_class class Variable(Array): """The pointer to specify the dynamical variable.