From fa6e787f0d1723bf7f3f2bd963bc6ab9fef4b505 Mon Sep 17 00:00:00 2001 From: Chaoming Wang Date: Sat, 2 Mar 2024 23:56:06 +0800 Subject: [PATCH] enable to set `numpy_func_return = True/False` --- brainpy/_src/_delay.py | 2 +- brainpy/_src/math/compat_numpy.py | 28 +-- brainpy/_src/math/compat_tensorflow.py | 56 +++--- brainpy/_src/math/defaults.py | 21 ++- brainpy/_src/math/environment.py | 25 ++- brainpy/_src/math/ndarray.py | 21 +-- brainpy/_src/math/others.py | 4 +- brainpy/_src/math/surrogate/_compt.py | 247 ------------------------- 8 files changed, 92 insertions(+), 312 deletions(-) delete mode 100644 brainpy/_src/math/surrogate/_compt.py diff --git a/brainpy/_src/_delay.py b/brainpy/_src/_delay.py index a646fd159..bac73e012 100644 --- a/brainpy/_src/_delay.py +++ b/brainpy/_src/_delay.py @@ -144,7 +144,7 @@ def register_entry( delay_type = 'homo' else: delay_type = 'heter' - delay_step = bm.Array(delay_step) + delay_step = delay_step elif callable(delay_step): delay_step = delay_step(self.delay_target_shape) delay_type = 'heter' diff --git a/brainpy/_src/math/compat_numpy.py b/brainpy/_src/math/compat_numpy.py index 213185df1..0eb391458 100644 --- a/brainpy/_src/math/compat_numpy.py +++ b/brainpy/_src/math/compat_numpy.py @@ -103,6 +103,10 @@ _max = max +def _return(a): + return Array(a) + + def fill_diagonal(a, val, inplace=True): if a.ndim < 2: raise ValueError(f'Only support tensor has dimension >= 2, but got {a.shape}') @@ -120,30 +124,30 @@ def fill_diagonal(a, val, inplace=True): def zeros(shape, dtype=None): - return Array(jnp.zeros(shape, dtype=dtype)) + return _return(jnp.zeros(shape, dtype=dtype)) def ones(shape, dtype=None): - return Array(jnp.ones(shape, dtype=dtype)) + return _return(jnp.ones(shape, dtype=dtype)) def empty(shape, dtype=None): - return Array(jnp.zeros(shape, dtype=dtype)) + return _return(jnp.zeros(shape, dtype=dtype)) def zeros_like(a, dtype=None, shape=None): a = _as_jax_array_(a) - return Array(jnp.zeros_like(a, dtype=dtype, shape=shape)) + return _return(jnp.zeros_like(a, dtype=dtype, shape=shape)) def ones_like(a, dtype=None, shape=None): a = _as_jax_array_(a) - return Array(jnp.ones_like(a, dtype=dtype, shape=shape)) + return _return(jnp.ones_like(a, dtype=dtype, shape=shape)) def empty_like(a, dtype=None, shape=None): a = _as_jax_array_(a) - return Array(jnp.zeros_like(a, dtype=dtype, shape=shape)) + return _return(jnp.zeros_like(a, dtype=dtype, shape=shape)) def array(a, dtype=None, copy=True, order="K", ndmin=0) -> Array: @@ -155,7 +159,7 @@ def array(a, dtype=None, copy=True, order="K", ndmin=0) -> Array: leaves = [_as_jax_array_(l) for l in leaves] a = tree_unflatten(tree, leaves) res = jnp.array(a, dtype=dtype, copy=copy, order=order, ndmin=ndmin) - return Array(res) + return _return(res) def asarray(a, dtype=None, order=None): @@ -167,13 +171,13 @@ def asarray(a, dtype=None, order=None): leaves = [_as_jax_array_(l) for l in leaves] arrays = tree_unflatten(tree, leaves) res = jnp.asarray(a=arrays, dtype=dtype, order=order) - return Array(res) + return _return(res) def arange(*args, **kwargs): args = [_as_jax_array_(a) for a in args] kwargs = {k: _as_jax_array_(v) for k, v in kwargs.items()} - return Array(jnp.arange(*args, **kwargs)) + return _return(jnp.arange(*args, **kwargs)) def linspace(*args, **kwargs): @@ -181,15 +185,15 @@ def linspace(*args, **kwargs): kwargs = {k: _as_jax_array_(v) for k, v in kwargs.items()} res = jnp.linspace(*args, **kwargs) if isinstance(res, tuple): - return Array(res[0]), res[1] + return _return(res[0]), 
res[1] else: - return Array(res) + return _return(res) def logspace(*args, **kwargs): args = [_as_jax_array_(a) for a in args] kwargs = {k: _as_jax_array_(v) for k, v in kwargs.items()} - return Array(jnp.logspace(*args, **kwargs)) + return _return(jnp.logspace(*args, **kwargs)) def asanyarray(a, dtype=None, order=None): diff --git a/brainpy/_src/math/compat_tensorflow.py b/brainpy/_src/math/compat_tensorflow.py index 7e9168cfa..e9e87e24c 100644 --- a/brainpy/_src/math/compat_tensorflow.py +++ b/brainpy/_src/math/compat_tensorflow.py @@ -259,13 +259,13 @@ def segment_sum(data: Union[Array, jnp.ndarray], An array with shape :code:`(num_segments,) + data.shape[1:]` representing the segment sums. """ - return Array(jax.ops.segment_sum(as_jax(data), - as_jax(segment_ids), - num_segments, - indices_are_sorted, - unique_indices, - bucket_size, - mode)) + return _return(jax.ops.segment_sum(as_jax(data), + as_jax(segment_ids), + num_segments, + indices_are_sorted, + unique_indices, + bucket_size, + mode)) def segment_prod(data: Union[Array, jnp.ndarray], @@ -311,13 +311,13 @@ def segment_prod(data: Union[Array, jnp.ndarray], An array with shape :code:`(num_segments,) + data.shape[1:]` representing the segment sums. """ - return Array(jax.ops.segment_prod(as_jax(data), - as_jax(segment_ids), - num_segments, - indices_are_sorted, - unique_indices, - bucket_size, - mode)) + return _return(jax.ops.segment_prod(as_jax(data), + as_jax(segment_ids), + num_segments, + indices_are_sorted, + unique_indices, + bucket_size, + mode)) def segment_max(data: Union[Array, jnp.ndarray], @@ -363,13 +363,13 @@ def segment_max(data: Union[Array, jnp.ndarray], An array with shape :code:`(num_segments,) + data.shape[1:]` representing the segment sums. """ - return Array(jax.ops.segment_max(as_jax(data), - as_jax(segment_ids), - num_segments, - indices_are_sorted, - unique_indices, - bucket_size, - mode)) + return _return(jax.ops.segment_max(as_jax(data), + as_jax(segment_ids), + num_segments, + indices_are_sorted, + unique_indices, + bucket_size, + mode)) def segment_min(data: Union[Array, jnp.ndarray], @@ -415,13 +415,13 @@ def segment_min(data: Union[Array, jnp.ndarray], An array with shape :code:`(num_segments,) + data.shape[1:]` representing the segment sums. """ - return Array(jax.ops.segment_min(as_jax(data), - as_jax(segment_ids), - num_segments, - indices_are_sorted, - unique_indices, - bucket_size, - mode)) + return _return(jax.ops.segment_min(as_jax(data), + as_jax(segment_ids), + num_segments, + indices_are_sorted, + unique_indices, + bucket_size, + mode)) def cast(x, dtype): diff --git a/brainpy/_src/math/defaults.py b/brainpy/_src/math/defaults.py index 9f3c50454..eab8b9b66 100644 --- a/brainpy/_src/math/defaults.py +++ b/brainpy/_src/math/defaults.py @@ -12,32 +12,37 @@ # Default computation mode. mode = NonBatchingMode() -# '''Default computation mode.''' +# Default computation mode. membrane_scaling = IdScaling() -# '''Default time step.''' +# Default time step. dt = 0.1 -# '''Default bool data type.''' +# Default bool data type. bool_ = jnp.bool_ -# '''Default integer data type.''' +# Default integer data type. int_ = jnp.int64 if config.read('jax_enable_x64') else jnp.int32 -# '''Default float data type.''' +# Default float data type. float_ = jnp.float64 if config.read('jax_enable_x64') else jnp.float32 -# '''Default complex data type.''' +# Default complex data type. 
complex_ = jnp.complex128 if config.read('jax_enable_x64') else jnp.complex64 # register brainpy object as pytree bp_object_as_pytree = False + +# default return array type +numpy_func_return = 'bp_array' # 'bp_array','jax_array' + + if ti is not None: - # '''Default integer data type in Taichi.''' + # Default integer data type in Taichi. ti_int = ti.int64 if config.read('jax_enable_x64') else ti.int32 - # '''Default float data type in Taichi.''' + # Default float data type in Taichi. ti_float = ti.float64 if config.read('jax_enable_x64') else ti.float32 else: diff --git a/brainpy/_src/math/environment.py b/brainpy/_src/math/environment.py index d49e70f51..1948f4a7c 100644 --- a/brainpy/_src/math/environment.py +++ b/brainpy/_src/math/environment.py @@ -169,6 +169,7 @@ def __init__( int_: type = None, bool_: type = None, bp_object_as_pytree: bool = None, + numpy_func_return: bool = None, ) -> None: super().__init__() @@ -208,6 +209,10 @@ def __init__( assert isinstance(bp_object_as_pytree, bool), '"bp_object_as_pytree" must be a bool.' self.old_bp_object_as_pytree = defaults.bp_object_as_pytree + if numpy_func_return is not None: + assert isinstance(numpy_func_return, bool), '"numpy_func_return" must be a bool.' + self.old_numpy_func_return = defaults.numpy_func_return + self.dt = dt self.mode = mode self.membrane_scaling = membrane_scaling @@ -217,6 +222,7 @@ def __init__( self.int_ = int_ self.bool_ = bool_ self.bp_object_as_pytree = bp_object_as_pytree + self.numpy_func_return = numpy_func_return def __enter__(self) -> 'environment': if self.dt is not None: set_dt(self.dt) @@ -228,6 +234,7 @@ def __enter__(self) -> 'environment': if self.complex_ is not None: set_complex(self.complex_) if self.bool_ is not None: set_bool(self.bool_) if self.bp_object_as_pytree is not None: defaults.__dict__['bp_object_as_pytree'] = self.bp_object_as_pytree + if self.numpy_func_return is not None: defaults.__dict__['numpy_func_return'] = self.numpy_func_return return self def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: @@ -240,6 +247,7 @@ def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: if self.complex_ is not None: set_complex(self.old_complex) if self.bool_ is not None: set_bool(self.old_bool) if self.bp_object_as_pytree is not None: defaults.__dict__['bp_object_as_pytree'] = self.old_bp_object_as_pytree + if self.numpy_func_return is not None: defaults.__dict__['numpy_func_return'] = self.old_numpy_func_return def clone(self): return self.__class__(dt=self.dt, @@ -250,7 +258,8 @@ def clone(self): complex_=self.complex_, float_=self.float_, int_=self.int_, - bp_object_as_pytree=self.bp_object_as_pytree) + bp_object_as_pytree=self.bp_object_as_pytree, + numpy_func_return=self.numpy_func_return) def __eq__(self, other): return id(self) == id(other) @@ -279,6 +288,7 @@ def __init__( batch_size: int = 1, membrane_scaling: scales.Scaling = None, bp_object_as_pytree: bool = None, + numpy_func_return: bool = None, ): super().__init__(dt=dt, x64=x64, @@ -288,7 +298,8 @@ def __init__( bool_=bool_, membrane_scaling=membrane_scaling, mode=modes.TrainingMode(batch_size), - bp_object_as_pytree=bp_object_as_pytree) + bp_object_as_pytree=bp_object_as_pytree, + numpy_func_return=numpy_func_return) class batching_environment(environment): @@ -315,6 +326,7 @@ def __init__( batch_size: int = 1, membrane_scaling: scales.Scaling = None, bp_object_as_pytree: bool = None, + numpy_func_return: bool = None, ): super().__init__(dt=dt, x64=x64, @@ -324,7 +336,8 @@ def 
__init__( bool_=bool_, mode=modes.BatchingMode(batch_size), membrane_scaling=membrane_scaling, - bp_object_as_pytree=bp_object_as_pytree) + bp_object_as_pytree=bp_object_as_pytree, + numpy_func_return=numpy_func_return) def set( @@ -337,6 +350,7 @@ def set( int_: type = None, bool_: type = None, bp_object_as_pytree: bool = None, + numpy_func_return: bool = None, ): """Set the default computation environment. @@ -360,6 +374,8 @@ def set( The bool data type. bp_object_as_pytree: bool Whether to register brainpy object as pytree. + numpy_func_return: bool + Whether to return brainpy array in all numpy functions. """ if dt is not None: assert isinstance(dt, float), '"dt" must a float.' @@ -396,6 +412,9 @@ def set( if bp_object_as_pytree is not None: defaults.__dict__['bp_object_as_pytree'] = bp_object_as_pytree + if numpy_func_return is not None: + defaults.__dict__['numpy_func_return'] = numpy_func_return + set_environment = set diff --git a/brainpy/_src/math/ndarray.py b/brainpy/_src/math/ndarray.py index cf2b2343d..791c8d9fe 100644 --- a/brainpy/_src/math/ndarray.py +++ b/brainpy/_src/math/ndarray.py @@ -10,6 +10,7 @@ from jax.tree_util import register_pytree_node_class from brainpy.errors import MathError +from . import defaults bm = None @@ -41,8 +42,8 @@ def _check_input_array(array): def _return(a): - if isinstance(a, jax.Array) and a.ndim > 0: - return Array(a) + if defaults.numpy_func_return == 'bp_array' and isinstance(a, jax.Array) and a.ndim > 0: + return Array(a) return a @@ -1087,7 +1088,7 @@ def unsqueeze(self, dim: int) -> 'Array': See :func:`brainpy.math.unsqueeze` """ - return Array(jnp.expand_dims(self.value, dim)) + return _return(jnp.expand_dims(self.value, dim)) def expand_dims(self, axis: Union[int, Sequence[int]]) -> 'Array': """ @@ -1119,7 +1120,7 @@ def expand_dims(self, axis: Union[int, Sequence[int]]) -> 'Array': self.expand_dims(axis)==self.expand_dims(axis[0]).expand_dims(axis[1])... expand_dims(axis[len(axis)-1]) """ - return Array(jnp.expand_dims(self.value, axis)) + return _return(jnp.expand_dims(self.value, axis)) def expand_as(self, array: Union['Array', jax.Array, np.ndarray]) -> 'Array': """ @@ -1136,9 +1137,7 @@ def expand_as(self, array: Union['Array', jax.Array, np.ndarray]) -> 'Array': typically not contiguous. Furthermore, more than one element of a expanded array may refer to a single memory location. """ - if not isinstance(array, Array): - array = Array(array) - return Array(jnp.broadcast_to(self.value, array.value.shape)) + return _return(jnp.broadcast_to(self.value, array)) def pow(self, index: int): return _return(self.value ** index) @@ -1228,7 +1227,7 @@ def absolute_(self): return self.abs_() def mul(self, value): - return Array(self.value * value) + return _return(self.value * value) def mul_(self, value): """ @@ -1404,7 +1403,7 @@ def clip_(self, return self def clone(self) -> 'Array': - return Array(self.value.copy()) + return _return(self.value.copy()) def copy_(self, src: Union['Array', jax.Array, np.ndarray]) -> 'Array': self.value = jnp.copy(_as_jax_array_(src)) @@ -1423,7 +1422,7 @@ def cov_with( fweights = _as_jax_array_(fweights) aweights = _as_jax_array_(aweights) r = jnp.cov(self.value, y, rowvar, bias, fweights, aweights) - return Array(r) + return _return(r) def expand(self, *sizes) -> 'Array': """ @@ -1459,7 +1458,7 @@ def expand(self, *sizes) -> 'Array': raise ValueError( f'The expanded size of the tensor ({sizes_list[base + i]}) must match the existing size ({v}) at non-singleton ' f'dimension {i}. Target sizes: {sizes}. 
Tensor sizes: {self.shape}') - return Array(jnp.broadcast_to(self.value, sizes_list)) + return _return(jnp.broadcast_to(self.value, sizes_list)) def tree_flatten(self): return (self.value,), None diff --git a/brainpy/_src/math/others.py b/brainpy/_src/math/others.py index 94aeebb16..59588d3b9 100644 --- a/brainpy/_src/math/others.py +++ b/brainpy/_src/math/others.py @@ -11,7 +11,7 @@ from .compat_numpy import fill_diagonal from .environment import get_dt, get_int from .interoperability import as_jax -from .ndarray import Array +from .ndarray import Array, _return __all__ = [ 'shared_args_over_time', @@ -79,7 +79,7 @@ def remove_diag(arr): """ if arr.ndim != 2: raise ValueError(f'Only support 2D matrix, while we got a {arr.ndim}D array.') - eyes = Array(jnp.ones(arr.shape, dtype=bool)) + eyes = _return(jnp.ones(arr.shape, dtype=bool)) fill_diagonal(eyes, False) return jnp.reshape(arr[eyes.value], (arr.shape[0], arr.shape[1] - 1)) diff --git a/brainpy/_src/math/surrogate/_compt.py b/brainpy/_src/math/surrogate/_compt.py deleted file mode 100644 index 67b7d5158..000000000 --- a/brainpy/_src/math/surrogate/_compt.py +++ /dev/null @@ -1,247 +0,0 @@ -# -*- coding: utf-8 -*- - -import warnings - -from jax import custom_gradient, numpy as jnp - -from brainpy._src.math.compat_numpy import asarray -from brainpy._src.math.interoperability import as_jax -from brainpy._src.math.environment import get_float -from brainpy._src.math.ndarray import Array - -__all__ = [ - 'spike_with_sigmoid_grad', - 'spike_with_linear_grad', - 'spike_with_gaussian_grad', - 'spike_with_mg_grad', - - 'spike2_with_sigmoid_grad', - 'spike2_with_linear_grad', -] - - -def _consistent_type(target, compare): - return as_jax(target) if not isinstance(compare, Array) else asarray(target) - - -@custom_gradient -def spike_with_sigmoid_grad(x: Array, scale: float = 100.): - """Spike function with the sigmoid surrogate gradient. - - .. deprecated:: 2.3.1 - Please use ``brainpy.math.surrogate.sigmoid_grad()`` instead. - Will be removed after version 2.4.0. - - Parameters - ---------- - x: Array - The input data. - scale: float - The scaling factor. - """ - warnings.warn('Use `brainpy.math.surrogate.inv_square_grad()` instead.', UserWarning) - - x = as_jax(x) - z = jnp.asarray(x >= 0, dtype=get_float()) - - def grad(dE_dz): - dE_dz = as_jax(dE_dz) - dE_dx = dE_dz / (scale * jnp.abs(x) + 1.0) ** 2 - if scale is None: - return (_consistent_type(dE_dx, x),) - else: - dscale = jnp.zeros_like(scale) - return (dE_dx, dscale) - - return z, grad - - -@custom_gradient -def spike2_with_sigmoid_grad(x_new: Array, x_old: Array, scale: float = None): - """Spike function with the sigmoid surrogate gradient. - - .. deprecated:: 2.3.1 - Please use ``brainpy.math.surrogate.inv_square_grad2()`` instead. - Will be removed after version 2.4.0. - - Parameters - ---------- - x_new: Array - The input data. - x_old: Array - The input data. - scale: optional, float - The scaling factor. - """ - warnings.warn('Use `brainpy.math.surrogate.inv_square_grad2()` instead.', UserWarning) - - x_new_comp = x_new >= 0 - x_old_comp = x_old < 0 - z = jnp.asarray(jnp.logical_and(x_new_comp, x_old_comp), dtype=get_float()) - - def grad(dE_dz): - _scale = 100. 
if scale is None else scale - dx_new = (dE_dz / (_scale * jnp.abs(x_new) + 1.0) ** 2) * jnp.asarray(x_old_comp, dtype=get_float()) - dx_old = -(dE_dz / (_scale * jnp.abs(x_old) + 1.0) ** 2) * jnp.asarray(x_new_comp, dtype=get_float()) - if scale is None: - return (_consistent_type(dx_new, x_new), - _consistent_type(dx_old, x_old)) - else: - dscale = jnp.zeros_like(_scale) - return (_consistent_type(dx_new, x_new), - _consistent_type(dx_old, x_old), - _consistent_type(dscale, scale)) - - return z, grad - - -@custom_gradient -def spike_with_linear_grad(x: Array, scale: float = None): - """Spike function with the relu surrogate gradient. - - .. deprecated:: 2.3.1 - Please use ``brainpy.math.surrogate.relu_grad()`` instead. - Will be removed after version 2.4.0. - - Parameters - ---------- - x: Array - The input data. - scale: float - The scaling factor. - """ - - warnings.warn('Use `brainpy.math.surrogate.relu_grad()` instead.', UserWarning) - - z = jnp.asarray(x >= 0., dtype=get_float()) - - def grad(dE_dz): - _scale = 0.3 if scale is None else scale - dE_dx = dE_dz * jnp.maximum(1 - jnp.abs(x), 0) * _scale - if scale is None: - return (_consistent_type(dE_dx, x),) - else: - dscale = jnp.zeros_like(_scale) - return (_consistent_type(dE_dx, x), _consistent_type(dscale, _scale)) - - return z, grad - - -@custom_gradient -def spike2_with_linear_grad(x_new: Array, x_old: Array, scale: float = 10.): - """Spike function with the linear surrogate gradient. - - .. deprecated:: 2.3.1 - Please use ``brainpy.math.surrogate.relu_grad2()`` instead. - Will be removed after version 2.4.0. - - Parameters - ---------- - x_new: Array - The input data. - x_old: Array - The input data. - scale: float - The scaling factor. - """ - warnings.warn('Use `brainpy.math.surrogate.relu_grad2()` instead.', UserWarning) - - x_new_comp = x_new >= 0 - x_old_comp = x_old < 0 - z = jnp.asarray(jnp.logical_and(x_new_comp, x_old_comp), dtype=get_float()) - - def grad(dE_dz): - _scale = 0.3 if scale is None else scale - dx_new = (dE_dz * jnp.maximum(1 - jnp.abs(x_new), 0) * _scale) * jnp.asarray(x_old_comp, dtype=get_float()) - dx_old = -(dE_dz * jnp.maximum(1 - jnp.abs(x_old), 0) * _scale) * jnp.asarray(x_new_comp, dtype=get_float()) - if scale is None: - return (_consistent_type(dx_new, x_new), - _consistent_type(dx_old, x_old)) - else: - dscale = jnp.zeros_like(_scale) - return (_consistent_type(dx_new, x_new), - _consistent_type(dx_old, x_old), - _consistent_type(dscale, scale)) - - return z, grad - - -def _gaussian(x, mu, sigma): - return jnp.exp(-((x - mu) ** 2) / (2 * sigma ** 2)) / jnp.sqrt(2 * jnp.pi) / sigma - - -@custom_gradient -def spike_with_gaussian_grad(x, sigma=None, scale=None): - """Spike function with the Gaussian surrogate gradient. - - .. deprecated:: 2.3.1 - Please use ``brainpy.math.surrogate.gaussian_grad()`` instead. - Will be removed after version 2.4.0. 
- - """ - - warnings.warn('Use `brainpy.math.surrogate.gaussian_grad()` instead.', UserWarning) - - z = jnp.asarray(x >= 0., dtype=get_float()) - - def grad(dE_dz): - _scale = 0.5 if scale is None else scale - _sigma = 0.5 if sigma is None else sigma - dE_dx = dE_dz * _gaussian(x, 0., _sigma) * _scale - returns = (_consistent_type(dE_dx, x),) - if sigma is not None: - returns += (_consistent_type(jnp.zeros_like(_sigma), sigma),) - if scale is not None: - returns += (_consistent_type(jnp.zeros_like(_scale), scale),) - return returns - - return z, grad - - -@custom_gradient -def spike_with_mg_grad(x, h=None, s=None, sigma=None, scale=None): - """Spike function with the multi-Gaussian surrogate gradient. - - .. deprecated:: 2.3.1 - Please use ``brainpy.math.surrogate.multi_sigmoid_grad()`` instead. - Will be removed after version 2.4.0. - - Parameters - ---------- - x: ndarray - The variable to judge spike. - h: float - The hyper-parameters of approximate function - s: float - The hyper-parameters of approximate function - sigma: float - The gaussian sigma. - scale: float - The gradient scale. - """ - - warnings.warn('Use `brainpy.math.surrogate.multi_sigmoid_grad()` instead.', UserWarning) - - z = jnp.asarray(x >= 0., dtype=get_float()) - - def grad(dE_dz): - _sigma = 0.5 if sigma is None else sigma - _scale = 0.5 if scale is None else scale - _s = 6.0 if s is None else s - _h = 0.15 if h is None else h - dE_dx = dE_dz * (_gaussian(x, mu=0., sigma=_sigma) * (1. + _h) - - _gaussian(x, mu=_sigma, sigma=_s * _sigma) * _h - - _gaussian(x, mu=-_sigma, sigma=_s * _sigma) * _h) * _scale - returns = (_consistent_type(dE_dx, x),) - if h is not None: - returns += (_consistent_type(jnp.zeros_like(_h), h),) - if s is not None: - returns += (_consistent_type(jnp.zeros_like(_s), s),) - if sigma is not None: - returns += (_consistent_type(jnp.zeros_like(_sigma), sigma),) - if scale is not None: - returns += (_consistent_type(jnp.zeros_like(_scale), scale),) - return returns - - return z, grad -
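
A minimal usage sketch of the new flag, assuming `set()` from brainpy/_src/math/environment.py is re-exported as `brainpy.math.set`. The values shown follow the comment in defaults.py ('bp_array' / 'jax_array'); the commit subject and the `environment()` assertion instead treat the flag as a bool, so the exact accepted values are an assumption here.

    import brainpy.math as bm

    # Global switch. defaults.py initialises the flag to 'bp_array'; results that
    # pass through ndarray._return are wrapped in bm.Array only while the flag
    # equals 'bp_array' and the result is a non-scalar jax.Array.
    bm.set(numpy_func_return='jax_array')    # ask for raw jax.Array results
    a = bm.ones((2, 3))
    b = a.mul(2.0)                           # routed through ndarray._return

    # Restore the default behaviour.
    bm.set(numpy_func_return='bp_array')

The same flag can be scoped with the environment() / training_environment() / batching_environment() context managers added above, which restore the previous value on __exit__ (they validate a bool, matching the commit subject rather than the string default in defaults.py). Note also that the `_return` helper added to compat_numpy.py in this patch wraps its argument unconditionally, so the flag takes effect on code paths that go through ndarray._return, such as the Array methods and others.remove_diag.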