diff --git a/brainpy/_src/integrators/ode/exponential.py b/brainpy/_src/integrators/ode/exponential.py index e44e324e7..931d140c6 100644 --- a/brainpy/_src/integrators/ode/exponential.py +++ b/brainpy/_src/integrators/ode/exponential.py @@ -105,26 +105,137 @@ .. [2] Hochbruck, M., & Ostermann, A. (2010). Exponential integrators. Acta Numerica, 19, 209-286. """ -from functools import wraps +from functools import wraps, partial + +from jax.tree_util import tree_map + from brainpy import errors from brainpy._src import math as bm from brainpy._src.integrators import constants as C, utils, joint_eq from brainpy._src.integrators.ode.base import ODEIntegrator +from brainpy._src.math.object_transform.autograd import vjp_for_exp_euler, tree_as_jax from .generic import register_ode_integrator - __all__ = [ - 'ExponentialEuler', + 'ExponentialEuler', 'ExponentialEulerMultiODE', ] -class ExponentialEuler(ODEIntegrator): - """Exponential Euler method using automatic differentiation. +class _ExponentialEulerFamily(ODEIntegrator): + r"""Exponential Euler method using automatic differentiation. + + Parameters + ---------- + f : function, joint_eq.JointEq + The derivative function. + var_type : optional, str + The variable type. + dt : optional, float + The default numerical integration step. + name : optional, str + The integrator name. + """ - This method uses `brainpy.math.vector_grad <../../math/generated/brainpy.math.autograd.vector_grad.html>`_ + def __init__( + self, + f, + var_type=None, + dt=None, + name=None, + show_code=False, + state_delays=None, + neutral_delays=None + ): + super().__init__(f=f, + var_type=var_type, + dt=dt, + name=name, + show_code=show_code, + state_delays=state_delays, + neutral_delays=neutral_delays) + + if var_type == C.SYSTEM_VAR: + raise NotImplementedError(f'{self.__class__.__name__} does not support {C.SYSTEM_VAR}, ' + f'because the auto-differentiation ') + + # build the integrator + self.code_lines = [] + self.code_scope = {} + self.integral = self.build() + + def build(self): + parses = self._build_integrator(self.f) + all_vps = self.variables + self.parameters + + @wraps(self.f) + def integral_func(*args, **kwargs): + # format arguments + params_in = bm.Collector() + for i, arg in enumerate(args): + params_in[all_vps[i]] = arg + params_in.update(kwargs) + if C.DT not in params_in: + params_in[C.DT] = self.dt + + # call integrals + results = [] + for i, parse in enumerate(parses): + f_integral, vars_, pars_ = parse + vps = vars_ + pars_ + [C.DT] + r = f_integral(params_in[vps[0]], **{arg: params_in[arg] for arg in vps[1:] if arg in params_in}) + results.append(r) + return results if len(self.variables) > 1 else results[0] + + return integral_func + + def _build_integrator(self, eq): + if isinstance(eq, joint_eq.JointEq): + results = [] + for sub_eq in eq.eqs: + results.extend(self._build_integrator(sub_eq)) + return results + else: + vars, pars, _ = utils.get_args(eq) + + # checking + if len(vars) != 1: + raise errors.DiffEqError(C.multi_vars_msg.format(cls=self.__class__.__name__, + vars=str(vars), + eq=str(eq))) + return [(partial(self._make_integral, eq), vars, pars), ] + + def _make_integral(self, eq, v, *args, **kwargs): + raise NotImplementedError + + +class ExponentialEuler(_ExponentialEulerFamily): + r"""Exponential Euler method using automatic differentiation. + + This method uses `brainpy.math.vector_grad <../../apis/generated/brainpy.math.vector_grad.html>`_ to automatically infer the linear part of the given function. 
Therefore, it has minimal constraints on your derivative function. Arbitrary complex functions can be numerically integrated with this method. + The simplest exponential Rosenbrock method is the exponential + Rosenbrock–Euler scheme, which has order 2. + + For an ODE equation of the form + + .. math:: + + u^{\prime}=f(u), \quad u(0)=u_{0} + + its schema is given by + + .. math:: + + u_{n+1}= u_{n}+h \varphi(hL) f (u_{n}) + + where :math:`L=f^{\prime}(u_{n})` and :math:`\varphi(z)=\frac{e^{z}-1}{z}`. + + For a linear ODE system: :math:`u^{\prime} = Ay + B`, + the above equation is equal to :math:`u_{n+1}= u_{n}e^{hA}-B/A(1-e^{hA})`, + which is the exact solution for this ODE system. + Examples -------- @@ -187,10 +298,10 @@ class ExponentialEuler(ODEIntegrator): >>> >>> return dVdt >>> - >>> def update(self, tdi): - >>> h = self.int_h(self.h, tdi.t, self.V, dt=tdi.dt) - >>> n = self.int_n(self.n, tdi.t, self.V, dt=tdi.dt) - >>> V = self.int_V(self.V, tdi.t, self.h, self.n, self.input, dt=tdi.dt) + >>> def update(self): + >>> h = self.int_h(self.h, None, self.V) + >>> n = self.int_n(self.n, None, self.V) + >>> V = self.int_V(self.V, None, self.h, self.n, self.input.value) >>> self.spike.value = bm.logical_and(self.V < self.V_th, V >= self.V_th) >>> self.V.value = V >>> self.h.value = h @@ -256,11 +367,10 @@ class ExponentialEuler(ODEIntegrator): >>> IK = self.gK * n ** 4 * (V - self.EK) >>> IL = self.gL * (V - self.EL) >>> dVdt = (- INa - IK - IL + Iext) / self.C - >>> >>> return dVdt >>> - >>> def update(self, tdi): - >>> h, n, V = self.integral(self.h, self.n, self.V, tdi.t, self.input, dt=tdi.dt) + >>> def update(self): + >>> h, n, V = self.integral(self.h, self.n, self.V, None, self.input.value) >>> self.spike.value = bm.logical_and(self.V < self.V_th, V >= self.V_th) >>> self.V.value = V >>> self.h.value = h @@ -283,6 +393,163 @@ class ExponentialEuler(ODEIntegrator): The integrator name. """ + def _make_integral(self, eq, v, *args, **kwargs): + dt = kwargs.pop(C.DT, self.dt) + linear, dy = bm.vector_grad(eq, argnums=0, return_value=True)(v, *args, **kwargs) + phi = bm.exprel(dt * linear) + return v + dt * phi * dy + + +register_ode_integrator('exponential_euler', ExponentialEuler) +register_ode_integrator('exp_euler', ExponentialEuler) +register_ode_integrator('exp_euler_auto', ExponentialEuler) +register_ode_integrator('exp_auto', ExponentialEuler) + + +class ExponentialRosenbrock32(_ExponentialEulerFamily): + """A class of third-order exponential Rosenbrock methods was derived in Hochbruck et al. (2009), named as ``exprb32``. + + The method is given by + + .. 
math:: + + u_{n2} = u_n + \Delta t \phi_1(\Delta t L) f(u_n) \\ + u_{n+1} = u_n + \Delta t \phi_1(\Delta t L) f(u_n) + \Delta t 2 \phi_3(\Delta t L) (N_n(u_{n2}) - N_n(u_n)) \\ + N_n(u_n) = f(u_n) - f'(u_n) u_n \\ + \phi_1 = \frac{e^z - 1}{z} \\ + \phi_3 = \frac{e^z - 1 - z - \frac{1}{2} z^2}{z^3} + + """ + + def _make_integral(self, eq, u_n, *args, **kwargs): + dt = kwargs.pop(C.DT, self.dt) + linear, dy = bm.vector_grad(eq, argnums=0, return_value=True)(u_n, *args, **kwargs) + u_n2 = u_n + dt * bm.exprel(dt * linear) * dy + N_u_n = dy - linear * u_n + + linear2, dy2 = bm.vector_grad(eq, argnums=0, return_value=True)(u_n2, *args, **kwargs) + N_u_n2 = dy2 - linear2 * u_n2 + + D_n2 = N_u_n2 - N_u_n + phi3 = self._psi3(dt * linear) + return u_n2 + dt * 2 * phi3 * D_n2 + + def _psi3(self, x): + origin = (bm.exp(x) - 1 - x - 0.5 * x * x) / x ** 3 + return bm.where(bm.abs(x) < 1e-6, 1 / 6 + x / 24, origin) + + +register_ode_integrator('exprb32', ExponentialRosenbrock32) + + +class ExponentialEulerMultiODE(ODEIntegrator): + r"""Exponential Euler method using automatic differentiation. + + The simplest exponential Rosenbrock method is the exponential + Rosenbrock–Euler scheme, which has order 2. + + For an ODE equation of the form + + .. math:: + + u^{\prime}=f(u), \quad u(0)=u_{0} + + its schema is given by + + .. math:: + + u_{n+1}= u_{n}+h \varphi(hL) f (u_{n}) + + where :math:`L=f^{\prime}(u_{n})` and :math:`\varphi(z)=\frac{e^{z}-1}{z}`. + + For a linear ODE system: :math:`u^{\prime} = Ay + B`, + the above equation is equal to :math:`u_{n+1}= u_{n}e^{hA}-B/A(1-e^{hA})`, + which is the exact solution for this ODE system. + + Examples + -------- + + Here is an example uses ``ExponentialEuler`` to implement HH neuron model. + + .. code-block:: python + + import brainpy as bp + import brainpy.math as bm + + + class HH(bp.dyn.NeuDyn): + def __init__( + self, size, ENa=55., EK=-90., EL=-65, C=1.0, gNa=35., gK=9., + gL=0.1, V_th=20., phi=5.0, name=None + ): + super(HH, self).__init__(size=size, name=name) + + # parameters + self.ENa = ENa + self.EK = EK + self.EL = EL + self.C = C + self.gNa = gNa + self.gK = gK + self.gL = gL + self.V_th = V_th + self.phi = phi + + # variables + self.V = bm.Variable(bm.ones(size) * -65.) + self.h = bm.Variable(bm.ones(size) * 0.6) + self.n = bm.Variable(bm.ones(size) * 0.32) + self.spike = bm.Variable(bm.zeros(size, dtype=bool)) + self.input = bm.Variable(bm.zeros(size)) + + # functions + self.integral = bp.odeint(self.derivative, method='exp_euler_multi') + + def derivative(self, V, h, n, t, Iext): + m_alpha = -0.1 * (V + 35) / (bm.exp(-0.1 * (V + 35)) - 1) + m_beta = 4 * bm.exp(-(V + 60) / 18) + m = m_alpha / (m_alpha + m_beta) + INa = self.gNa * m ** 3 * h * (V - self.ENa) + IK = self.gK * n ** 4 * (V - self.EK) + IL = self.gL * (V - self.EL) + dVdt = (- INa - IK - IL + Iext) / self.C + + alpha = 0.07 * bm.exp(-(V + 58) / 20) + beta = 1 / (bm.exp(-0.1 * (V + 28)) + 1) + dhdt = self.phi * (alpha * (1 - h) - beta * h) + + alpha = -0.01 * (V + 34) / (bm.exp(-0.1 * (V + 34)) - 1) + beta = 0.125 * bm.exp(-(V + 44) / 80) + dndt = self.phi * (alpha * (1 - n) - beta * n) + + return dVdt, dhdt, dndt + + def update(self): + V, h, n = self.integral(self.V.value, self.h.value, self.n.value, None, self.input.value) + self.spike.value = bm.logical_and(self.V < self.V_th, V >= self.V_th) + self.V.value = V + self.h.value = h + self.n.value = n + self.input[:] = 0. + + + runner = bp.DSRunner(HH(1), inputs=('input', 2.), monitors=['V'], dt=0.05) + runner.run(100.) 
+ bp.visualize.line_plot(runner.mon.ts, runner.mon.V, legend='V', show=True) + + + Parameters + ---------- + f : function, joint_eq.JointEq + The derivative function. + var_type : optional, str + The variable type. + dt : optional, float + The default numerical integration step. + name : optional, str + The integrator name. + """ + def __init__( self, f, @@ -293,13 +560,13 @@ def __init__( state_delays=None, neutral_delays=None ): - super(ExponentialEuler, self).__init__(f=f, - var_type=var_type, - dt=dt, - name=name, - show_code=show_code, - state_delays=state_delays, - neutral_delays=neutral_delays) + super().__init__(f=f, + var_type=var_type, + dt=dt, + name=name, + show_code=show_code, + state_delays=state_delays, + neutral_delays=neutral_delays) if var_type == C.SYSTEM_VAR: raise NotImplementedError(f'{self.__class__.__name__} does not support {C.SYSTEM_VAR}, ' @@ -311,60 +578,29 @@ def __init__( self.integral = self.build() def build(self): - parses = self._build_integrator(self.f) all_vps = self.variables + self.parameters @wraps(self.f) def integral_func(*args, **kwargs): - # format arguments + # check arguments params_in = bm.Collector() for i, arg in enumerate(args): params_in[all_vps[i]] = arg params_in.update(kwargs) if C.DT not in params_in: params_in[C.DT] = self.dt + # separate variables and parameters + args = tree_as_jax(tuple([params_in[vp] for vp in self.variables])) + kwargs = tree_as_jax({vp: params_in[vp] for vp in self.parameters}) + # gradients + devs, dys, ins = vjp_for_exp_euler(partial(self.f, **kwargs))(*args) - # call integrals - results = [] - for i, parse in enumerate(parses): - f_integral, vars_, pars_ = parse - vps = vars_ + pars_ + [C.DT] - r = f_integral(params_in[vps[0]], **{arg: params_in[arg] for arg in vps[1:] if arg in params_in}) - results.append(r) - return results if len(self.variables) > 1 else results[0] + # integration + dt = kwargs.pop(C.DT, self.dt) + return tree_map(lambda is_, dy, dev: is_ + dt * bm.exprel(dt * dev) * dy, ins, dys, devs, + is_leaf=lambda a: isinstance(a, bm.Array)) return integral_func - def _build_integrator(self, eq): - if isinstance(eq, joint_eq.JointEq): - results = [] - for sub_eq in eq.eqs: - results.extend(self._build_integrator(sub_eq)) - return results - else: - vars, pars, _ = utils.get_args(eq) - - # checking - if len(vars) != 1: - raise errors.DiffEqError(C.multi_vars_msg.format(cls=self.__class__.__name__, - vars=str(vars), - eq=str(eq))) - - # gradient function - value_and_grad = bm.vector_grad(eq, argnums=0, return_value=True) - - # integration function - def integral(*args, **kwargs): - assert len(args) > 0 - dt = kwargs.pop(C.DT, self.dt) - linear, derivative = value_and_grad(*args, **kwargs) - phi = bm.exprel(dt * linear) - return args[0] + dt * phi * derivative - - return [(integral, vars, pars), ] - -register_ode_integrator('exponential_euler', ExponentialEuler) -register_ode_integrator('exp_euler', ExponentialEuler) -register_ode_integrator('exp_euler_auto', ExponentialEuler) -register_ode_integrator('exp_auto', ExponentialEuler) +register_ode_integrator('exp_euler_multi', ExponentialEulerMultiODE) diff --git a/brainpy/_src/math/object_transform/autograd.py b/brainpy/_src/math/object_transform/autograd.py index f5e091675..7135f0c28 100644 --- a/brainpy/_src/math/object_transform/autograd.py +++ b/brainpy/_src/math/object_transform/autograd.py @@ -1,37 +1,25 @@ # -*- coding: utf-8 -*- -import inspect -from functools import partial, wraps -from typing import Union, Callable, Dict, Sequence, Any, Optional 
+from functools import wraps +from typing import Union, Callable, Dict, Sequence, Any, Optional, Tuple import jax -import numpy as np +from jax import numpy as jnp +from jax._src.api import _vjp +from jax.api_util import argnums_partial +from jax.tree_util import (tree_flatten, tree_unflatten, tree_map, tree_structure) if jax.__version__ >= '0.4.16': from jax.extend import linear_util else: from jax import linear_util -from jax import dtypes, vmap, numpy as jnp, core -from jax._src.api import (_vjp, _jvp) -from jax.api_util import argnums_partial -from jax.interpreters import xla -from jax.tree_util import (tree_flatten, tree_unflatten, - tree_map, tree_transpose, - tree_structure) -from jax.util import safe_map - from brainpy import tools, check from brainpy._src.math.ndarray import Array, _as_jax_array_ -from .tools import (dynvar_deprecation, - node_deprecation, - get_stack_cache, - cache_stack) -from .base import (BrainPyObject, ObjectTransform) -from .variables import (Variable, - VariableStack, - current_transform_number, - new_transform) +from brainpy._src.math.compat_numpy import zeros +from .base import (ObjectTransform) +from .tools import (get_stack_cache, cache_stack) +from .variables import (Variable, VariableStack, current_transform_number, new_transform) __all__ = [ 'grad', # gradient of scalar function @@ -48,31 +36,23 @@ class GradientTransform(ObjectTransform): def __init__( self, - target: Callable, + fun: Callable, transform: Callable, - # variables and nodes - grad_vars: Any, - dyn_vars: Dict[str, Variable], - child_objs: Dict[str, Variable], - # gradient setting - argnums: Optional[Union[int, Sequence[int]]], + grad_vars: Any, + argnums: Union[int, Sequence[int]], return_value: bool, has_aux: bool, - transform_setting: Optional[Dict[str, Any]] = None, # other name: str = None, + **transform_kwargs ): super().__init__(name=name) # gradient variables - self._grad_vars, self._grad_tree = tree_flatten(grad_vars, is_leaf=lambda a: isinstance(a, Array)) - - # register variables and nodes - self.register_implicit_vars(dyn_vars, self._grad_vars) - self.register_implicit_nodes(child_objs) + self._grad_vars, self._grad_tree = tree_flatten(grad_vars, is_leaf=_isleaf) # parameters if argnums is None and len(self._grad_vars) == 0: @@ -92,68 +72,58 @@ def __init__( self._return_value = return_value self._has_aux = has_aux - # target - self.target = target + # target function + self.fun = fun - # transform - self._eval_dyn_vars = False - self._grad_transform = transform + # target transform self._dyn_vars = VariableStack() + self._eval_dyn_vars = False self._transform = None - self._grad_setting = dict() if transform_setting is None else transform_setting if self._has_aux: - self._transform = self._grad_transform( + self._transform = transform( self._f_grad_with_aux_to_transform, argnums=self._argnums, has_aux=True, - **self._grad_setting + **transform_kwargs ) else: - self._transform = self._grad_transform( + self._transform = transform( self._f_grad_without_aux_to_transform, argnums=self._argnums, has_aux=True, - **self._grad_setting + **transform_kwargs ) - def _f_grad_with_aux_to_transform(self, - grad_values: tuple, - dyn_values: dict, - *args, - **kwargs): - for k in dyn_values.keys(): - self._dyn_vars[k]._value = dyn_values[k] - for v, d in zip(self._grad_vars, grad_values): + def _f_grad_with_aux_to_transform(self, grad_vals: tuple, dyn_vals: dict, *args, **kwargs): + for k in dyn_vals.keys(): + self._dyn_vars[k]._value = dyn_vals[k] + for v, d in zip(self._grad_vars, 
grad_vals): v._value = d # Users should return the auxiliary data like:: # >>> # 1. example of return one data # >>> return scalar_loss, data # >>> # 2. example of return multiple data # >>> return scalar_loss, (data1, data2, ...) - outputs = self.target(*args, **kwargs) + outputs = self.fun(*args, **kwargs) # outputs: [0] is the value for gradient, # [1] is other values for return output0 = tree_map(lambda a: (a.value if isinstance(a, Array) else a), outputs[0]) return output0, (outputs, [v.value for v in self._grad_vars], self._dyn_vars.dict_data()) - def _f_grad_without_aux_to_transform(self, - grad_values: tuple, - dyn_values: dict, - *args, - **kwargs): - for k in dyn_values.keys(): - self._dyn_vars[k]._value = dyn_values[k] - for v, d in zip(self._grad_vars, grad_values): + def _f_grad_without_aux_to_transform(self, grad_vals: tuple, dyn_vals: dict, *args, **kwargs): + for k in dyn_vals.keys(): + self._dyn_vars[k]._value = dyn_vals[k] + for v, d in zip(self._grad_vars, grad_vals): v._value = d # Users should return the scalar value like this:: # >>> return scalar_loss - output = self.target(*args, **kwargs) + output = self.fun(*args, **kwargs) output0 = tree_map(lambda a: (a.value if isinstance(a, Array) else a), output) return output0, (output, [v.value for v in self._grad_vars], self._dyn_vars.dict_data()) def __repr__(self): name = self.__class__.__name__ - f = tools.repr_object(self.target) + f = tools.repr_object(self.fun) f = tools.repr_context(f, " " * (len(name) + 6)) format_ref = (f'{name}({self.name}, target={f}, \n' + f'{" " * len(name)} num_of_grad_vars={len(self._grad_vars)}, \n' @@ -201,7 +171,7 @@ def __call__(self, *args, **kwargs): return self._return(rets) elif not self._eval_dyn_vars: # evaluate dynamical variables - stack = get_stack_cache(self.target) + stack = get_stack_cache(self.fun) if stack is None: with new_transform(self): with VariableStack() as stack: @@ -220,7 +190,7 @@ def __call__(self, *args, **kwargs): *args, **kwargs ) - cache_stack(self.target, stack) + cache_stack(self.fun, stack) self._dyn_vars = stack self._dyn_vars.remove_by_id(*[id(v) for v in self._grad_vars]) @@ -252,24 +222,16 @@ def _make_grad( reduce_axes: Optional[Sequence[str]] = (), has_aux: Optional[bool] = None, return_value: Optional[bool] = False, - # deprecated - dyn_vars: Optional[Union[Variable, Sequence[Variable], Dict[str, Variable]]] = None, - child_objs: Optional[Union[BrainPyObject, Sequence[BrainPyObject], Dict[str, BrainPyObject]]] = None, ): - child_objs = check.is_all_objs(child_objs, out_as='dict') - dyn_vars = check.is_all_vars(dyn_vars, out_as='dict') - - return GradientTransform(target=func, + return GradientTransform(fun=func, transform=jax.grad, grad_vars=grad_vars, - dyn_vars=dyn_vars, - child_objs=child_objs, argnums=argnums, return_value=return_value, has_aux=False if has_aux is None else has_aux, - transform_setting=dict(holomorphic=holomorphic, - allow_int=allow_int, - reduce_axes=reduce_axes)) + holomorphic=holomorphic, + allow_int=allow_int, + reduce_axes=reduce_axes) def grad( @@ -281,10 +243,6 @@ def grad( reduce_axes: Optional[Sequence[str]] = (), has_aux: Optional[bool] = None, return_value: Optional[bool] = False, - - # deprecated - dyn_vars: Optional[Union[Variable, Sequence[Variable], Dict[str, Variable]]] = None, - child_objs: Optional[Union[BrainPyObject, Sequence[BrainPyObject], Dict[str, BrainPyObject]]] = None, ) -> Union[Callable, GradientTransform]: """Automatic gradient computation for functions or class objects. 
@@ -391,19 +349,6 @@ def grad( is a named batch axis, ``grad(f, reduce_axes=('batch',))`` will create a function that computes the total gradient while ``grad(f)`` will create one that computes the per-example gradient. - dyn_vars : optional, ArrayType, sequence of ArrayType, dict - The dynamically changed variables used in ``func``. - - .. deprecated:: 2.4.0 - No longer need to provide ``dyn_vars``. This function is capable of automatically - collecting the dynamical variables used in the target ``func``. - child_objs: optional, BrainPyObject, sequnce, dict - - .. versionadded:: 2.3.1 - - .. deprecated:: 2.4.0 - No longer need to provide ``child_objs``. This function is capable of automatically - collecting the children objects used in the target ``func``. Returns ------- @@ -415,14 +360,10 @@ def grad( same shapes and types as the corresponding arguments. If ``has_aux`` is True then a pair of (gradient, auxiliary_data) is returned. """ - dynvar_deprecation(dyn_vars) - node_deprecation(child_objs) if func is None: return lambda f: _make_grad(f, grad_vars=grad_vars, - dyn_vars=dyn_vars, - child_objs=child_objs, argnums=argnums, holomorphic=holomorphic, allow_int=allow_int, @@ -432,8 +373,6 @@ def grad( else: return _make_grad(func=func, grad_vars=grad_vars, - dyn_vars=dyn_vars, - child_objs=child_objs, argnums=argnums, holomorphic=holomorphic, allow_int=allow_int, @@ -442,49 +381,62 @@ def grad( return_value=return_value) -def _unravel_array_into_pytree(pytree, axis, arr, is_leaf=None): - leaves, treedef = tree_flatten(pytree, is_leaf=is_leaf) - axis = axis % arr.ndim - shapes = [arr.shape[:axis] + np.shape(l) + arr.shape[axis + 1:] for l in leaves] - parts = jnp.split(_as_jax_array_(arr), np.cumsum(safe_map(np.size, leaves[:-1])), axis) - reshaped_parts = [x.reshape(shape) for x, shape in zip(parts, shapes)] - return tree_unflatten(treedef, reshaped_parts, ) +def _isleaf(x): + return isinstance(x, Array) -def _std_basis(pytree): - leaves, _ = tree_flatten(pytree) - ndim = sum(safe_map(np.size, leaves)) - dtype = dtypes.result_type(*leaves) - flat_basis = jax.numpy.eye(ndim, dtype=dtype) - return _unravel_array_into_pytree(pytree, 1, flat_basis) +def tree_as_jax(x): + return tree_map(_as_jax_array_, x, is_leaf=_isleaf) -def _isleaf(x): - return isinstance(x, Array) +def _warp_fun_force_aux(fun: Callable, has_aux: bool): + @wraps(fun) + def new_fun(*args, **kwargs): + if has_aux: + y, aux = fun(*args, **kwargs) + y, aux = tree_as_jax((y, aux)) + return y, (y, aux) + else: + y = fun(*args, **kwargs) + y = tree_as_jax(y) + return y, y + return new_fun -def _jacrev(fun, argnums=0, holomorphic=False, allow_int=False, has_aux=False, return_value=False): - _check_callable(fun) +def _warp_fun_force_return_jax(fun: Callable): @wraps(fun) + def new_fun(*args, **kwargs): + return tree_as_jax(fun(*args, **kwargs)) + + return new_fun + + +def _jacrev( + fun: Callable, + argnums: Union[int, Sequence[int]] = 0, + holomorphic: bool = False, + allow_int: bool = False, + has_aux: bool = False, + return_value: bool = False +): + """ + Jacobian of ``fun`` using reverse-mode autodiff. + + Compared to ``jax.jacrev``, this function supports returning value ("return_value"). 
+ """ + fun = _warp_fun_force_aux(fun, has_aux) + fun_jac = jax.jacrev(fun, argnums=argnums, holomorphic=holomorphic, allow_int=allow_int, has_aux=True) + + @wraps(fun_jac) def jacfun(*args, **kwargs): - f = linear_util.wrap_init(fun, kwargs) - f_partial, dyn_args = argnums_partial(f, argnums, args, require_static_args_hashable=False) - tree_map(partial(_check_input_dtype_jacrev, holomorphic, allow_int), dyn_args) + args, kwargs = tree_as_jax((args, kwargs)) if has_aux: - y, pullback, aux = _vjp(f_partial, *dyn_args, has_aux=True) + jac, (y, aux) = fun_jac(*args, **kwargs) + return (jac, y, aux) if return_value else (jac, aux) else: - y, pullback = _vjp(f_partial, *dyn_args, has_aux=False) - tree_map(partial(_check_output_dtype_jacrev, holomorphic), y) - jac = vmap(pullback)(_std_basis(y)) - jac = jac[0] if isinstance(argnums, int) else jac - example_args = dyn_args[0] if isinstance(argnums, int) else dyn_args - jac_tree = tree_map(partial(_unravel_array_into_pytree, y, 0, is_leaf=_isleaf), jac, is_leaf=_isleaf) - jac = tree_transpose(tree_structure(example_args), tree_flatten(y, is_leaf=_isleaf)[1], jac_tree) - if return_value: - return (jac, y, aux) if has_aux else (jac, y) - else: - return (jac, aux) if has_aux else jac + jac, y = fun_jac(*args, **kwargs) + return (jac, y) if return_value else jac return jacfun @@ -497,10 +449,6 @@ def jacrev( return_value: bool = False, holomorphic: bool = False, allow_int: bool = False, - - # deprecated - dyn_vars: Optional[Union[Variable, Sequence[Variable], Dict[str, Variable]]] = None, - child_objs: Optional[Union[BrainPyObject, Sequence[BrainPyObject], Dict[str, BrainPyObject]]] = None, ) -> ObjectTransform: """Extending automatic Jacobian (reverse-mode) of ``func`` to classes. @@ -549,64 +497,49 @@ def jacrev( Whether to allow differentiating with respect to integer valued inputs. The gradient of an integer input will have a trivial vector-space dtype (float0). Default False. - dyn_vars : optional, ArrayType, sequence of ArrayType, dict - The dynamically changed variables used in ``func``. - - .. deprecated:: 2.4.0 - No longer need to provide ``dyn_vars``. This function is capable of automatically - collecting the dynamical variables used in the target ``func``. - child_objs: optional, BrainPyObject, sequnce, dict - - .. versionadded:: 2.3.1 - - .. deprecated:: 2.4.0 - No longer need to provide ``child_objs``. This function is capable of automatically - collecting the children objects used in the target ``func``. Returns ------- fun: GradientTransform The transformed object. """ - child_objs = check.is_all_objs(child_objs, out_as='dict') - dyn_vars = check.is_all_vars(dyn_vars, out_as='dict') - - return GradientTransform(target=func, + return GradientTransform(fun=func, transform=_jacrev, grad_vars=grad_vars, - dyn_vars=dyn_vars, - child_objs=child_objs, argnums=argnums, return_value=return_value, has_aux=False if has_aux is None else has_aux, - transform_setting=dict(holomorphic=holomorphic, - allow_int=allow_int)) + holomorphic=holomorphic, + allow_int=allow_int) jacobian = jacrev -def _jacfwd(fun, argnums=0, holomorphic=False, has_aux=False, return_value=False): - _check_callable(fun) +def _jacfwd( + fun: Callable, + argnums: Union[int, Sequence[int]] = 0, + holomorphic: bool = False, + has_aux: bool = False, + return_value: bool = False +): + """ + Jacobian of ``fun`` using forward-mode autodiff. - @wraps(fun) + Compared to ``jax.jacfwd``, this function supports returning value ("return_value"). 
+ """ + fun = _warp_fun_force_aux(fun, has_aux) + fun_jac = jax.jacfwd(fun, argnums=argnums, holomorphic=holomorphic, has_aux=True) + + @wraps(fun_jac) def jacfun(*args, **kwargs): - f = linear_util.wrap_init(fun, kwargs) - f_partial, dyn_args = argnums_partial(f, argnums, args, require_static_args_hashable=False) - tree_map(partial(_check_input_dtype_jacfwd, holomorphic), dyn_args) + args, kwargs = tree_as_jax((args, kwargs)) if has_aux: - pushfwd = partial(_jvp, f_partial, dyn_args, has_aux=True) - y, jac, aux = vmap(pushfwd, out_axes=(None, -1, None))(_std_basis(dyn_args)) - else: - pushfwd = partial(_jvp, f_partial, dyn_args) - y, jac = vmap(pushfwd, out_axes=(None, -1))(_std_basis(dyn_args)) - tree_map(partial(_check_output_dtype_jacfwd, holomorphic), y) - example_args = dyn_args[0] if isinstance(argnums, int) else dyn_args - jac = tree_map(partial(_unravel_array_into_pytree, example_args, -1, is_leaf=_isleaf), jac, is_leaf=_isleaf) - if return_value: - return (jac, y, aux) if has_aux else (jac, y) + jac, (y, aux) = fun_jac(*args, **kwargs) + return (jac, y, aux) if return_value else (jac, aux) else: - return (jac, aux) if has_aux else jac + jac, y = fun_jac(*args, **kwargs) + return (jac, y) if return_value else jac return jacfun @@ -618,10 +551,6 @@ def jacfwd( has_aux: Optional[bool] = None, return_value: bool = False, holomorphic: bool = False, - - # deprecated - dyn_vars: Optional[Union[Variable, Sequence[Variable], Dict[str, Variable]]] = None, - child_objs: Optional[Union[BrainPyObject, Sequence[BrainPyObject], Dict[str, BrainPyObject]]] = None, ) -> ObjectTransform: """Extending automatic Jacobian (forward-mode) of ``func`` to classes. @@ -663,37 +592,19 @@ def jacfwd( positional argument(s) to differentiate with respect to (default ``0``). holomorphic: Optional, bool. Indicates whether ``fun`` is promised to be holomorphic. Default False. - dyn_vars : optional, ArrayType, sequence of ArrayType, dict - The dynamically changed variables used in ``func``. - - .. deprecated:: 2.4.0 - No longer need to provide ``dyn_vars``. This function is capable of automatically - collecting the dynamical variables used in the target ``func``. - child_objs: optional, BrainPyObject, sequnce, dict - - .. versionadded:: 2.3.1 - - .. deprecated:: 2.4.0 - No longer need to provide ``child_objs``. This function is capable of automatically - collecting the children objects used in the target ``func``. Returns ------- obj: GradientTransform The transformed object. """ - child_objs = check.is_all_objs(child_objs, out_as='dict') - dyn_vars = check.is_all_vars(dyn_vars, out_as='dict') - - return GradientTransform(target=func, + return GradientTransform(fun=func, transform=_jacfwd, grad_vars=grad_vars, - dyn_vars=dyn_vars, - child_objs=child_objs, argnums=argnums, return_value=return_value, has_aux=False if has_aux is None else has_aux, - transform_setting=dict(holomorphic=holomorphic)) + holomorphic=holomorphic) def hessian( @@ -702,10 +613,6 @@ def hessian( argnums: Optional[Union[int, Sequence[int]]] = None, return_value: bool = False, holomorphic=False, - - # deprecated - dyn_vars: Optional[Union[Variable, Sequence[Variable], Dict[str, Variable]]] = None, - child_objs: Optional[Union[BrainPyObject, Sequence[BrainPyObject], Dict[str, BrainPyObject]]] = None, ) -> ObjectTransform: """Hessian of ``func`` as a dense array. @@ -724,53 +631,43 @@ def hessian( Indicates whether ``fun`` is promised to be holomorphic. Default False. return_value : bool Whether return the hessian values. 
- dyn_vars : optional, ArrayType, sequence of ArrayType, dict - The dynamically changed variables used in ``func``. - - .. deprecated:: 2.4.0 - No longer need to provide ``dyn_vars``. This function is capable of automatically - collecting the dynamical variables used in the target ``func``. - child_objs: optional, BrainPyObject, sequnce, dict - - .. versionadded:: 2.3.1 - - .. deprecated:: 2.4.0 - No longer need to provide ``child_objs``. This function is capable of automatically - collecting the children objects used in the target ``func``. Returns ------- obj: ObjectTransform The transformed object. """ - child_objs = check.is_all_objs(child_objs, out_as='dict') - dyn_vars = check.is_all_vars(dyn_vars, out_as='dict') return jacfwd(jacrev(func, - dyn_vars=dyn_vars, - child_objs=child_objs, grad_vars=grad_vars, argnums=argnums, holomorphic=holomorphic), - dyn_vars=dyn_vars, - child_objs=child_objs, grad_vars=grad_vars, argnums=argnums, holomorphic=holomorphic, return_value=return_value) -def functional_vector_grad(func, argnums=0, return_value=False, has_aux=False): - _check_callable(func) +def functional_vector_grad( + func: Callable, + argnums: Union[int, Sequence[int]] = 0, + return_value: bool = False, + has_aux: bool = False, + reduce_axes: Tuple = () +): + """ + Vector-Jacobian product of ``func`` using reverse-mode autodiff. + """ + func = _warp_fun_force_return_jax(func) @wraps(func) def grad_fun(*args, **kwargs): - f = linear_util.wrap_init(func, kwargs) - f_partial, dyn_args = argnums_partial(f, argnums, args, require_static_args_hashable=False) + f_partial, dyn_args = argnums_partial(linear_util.wrap_init(func, kwargs), argnums, args, + require_static_args_hashable=False) if has_aux: - y, vjp_fn, aux = _vjp(f_partial, *dyn_args, has_aux=True) + y, vjp_fn, aux = _vjp(f_partial, *dyn_args, has_aux=True, reduce_axes=reduce_axes) else: - y, vjp_fn = _vjp(f_partial, *dyn_args, has_aux=False) + y, vjp_fn = _vjp(f_partial, *dyn_args, has_aux=False, reduce_axes=reduce_axes) leaves, tree = tree_flatten(y) tangents = tree_unflatten(tree, [jnp.ones(l.shape, dtype=l.dtype) for l in leaves]) grads = vjp_fn(tangents) @@ -790,10 +687,6 @@ def vector_grad( argnums: Optional[Union[int, Sequence[int]]] = None, return_value: bool = False, has_aux: Optional[bool] = None, - - # deprecated - dyn_vars: Optional[Union[Variable, Sequence[Variable], Dict[str, Variable]]] = None, - child_objs: Optional[Union[BrainPyObject, Sequence[BrainPyObject], Dict[str, BrainPyObject]]] = None, ) -> Union[Callable, ObjectTransform]: """Take vector-valued gradients for function ``func``. @@ -818,7 +711,6 @@ def vector_grad( - "has_aux=False" + "return_value=True" => ``((var_grads, arg_grads), loss_value)``. - "has_aux=True" + "return_value=True" => ``((var_grads, arg_grads), loss_value, aux_data)``. - Parameters ---------- func: Callable @@ -833,148 +725,50 @@ def vector_grad( Whether return the loss value. argnums: Optional, integer or sequence of integers. Specifies which positional argument(s) to differentiate with respect to (default ``0``). - dyn_vars : optional, ArrayType, sequence of ArrayType, dict - The dynamically changed variables used in ``func``. - - .. deprecated:: 2.4.0 - No longer need to provide ``dyn_vars``. This function is capable of automatically - collecting the dynamical variables used in the target ``func``. - child_objs: optional, BrainPyObject, sequnce, dict - - .. versionadded:: 2.3.1 - - .. deprecated:: 2.4.0 - No longer need to provide ``child_objs``. 
This function is capable of automatically - collecting the children objects used in the target ``func``. Returns ------- func : GradientTransform The vector gradient function. """ - child_objs = check.is_all_objs(child_objs, out_as='dict') - dyn_vars = check.is_all_vars(dyn_vars, out_as='dict') if func is None: - return lambda f: GradientTransform(target=f, + return lambda f: GradientTransform(fun=f, transform=functional_vector_grad, grad_vars=grad_vars, - dyn_vars=dyn_vars, - child_objs=child_objs, argnums=argnums, return_value=return_value, has_aux=False if has_aux is None else has_aux) else: - return GradientTransform(target=func, + return GradientTransform(fun=func, transform=functional_vector_grad, grad_vars=grad_vars, - dyn_vars=dyn_vars, - child_objs=child_objs, argnums=argnums, return_value=return_value, has_aux=False if has_aux is None else has_aux) -def _check_callable(fun): - # In Python 3.10+, the only thing stopping us from supporting staticmethods - # is that we can't take weak references to them, which the C++ JIT requires. - if isinstance(fun, staticmethod): - raise TypeError(f"staticmethod arguments are not supported, got {fun}") - if not callable(fun): - raise TypeError(f"Expected a callable value, got {fun}") - if _isgeneratorfunction(fun): - raise TypeError(f"Expected a function, got a generator function: {fun}") - - -def _isgeneratorfunction(fun): - # re-implemented here because of https://bugs.python.org/issue33261 - while inspect.ismethod(fun): - fun = fun.__func__ - while isinstance(fun, partial): - fun = fun.func - return inspect.isfunction(fun) and bool(fun.__code__.co_flags & inspect.CO_GENERATOR) +def vjp_for_exp_euler(func: Callable): + """ + Vector-Jacobian product of ``func`` using reverse-mode autodiff. + """ + func = _warp_fun_force_return_jax(func) + @wraps(func) + def grad_fun(*dyn_args): + ys, y_vjp = jax.vjp(func, *dyn_args) + tree = tree_structure(ys) + out_tangents = [] + for i in range(len(dyn_args)): + raw_tangents = tuple([jnp.ones(l.shape, dtype=l.dtype) if j == i else jnp.zeros(l.shape, dtype=l.dtype) + for j, l in enumerate(dyn_args)]) + out_tangents.append(y_vjp(tree_unflatten(tree, raw_tangents))[i]) + return tree_unflatten(tree, tuple(out_tangents)), ys, tree_unflatten(tree, dyn_args) -def _check_arg(arg): - if not (isinstance(arg, core.Tracer) or _valid_jaxtype(arg)): - raise TypeError(f"Argument '{arg}' of type {type(arg)} is not a valid JAX type.") + return grad_fun -def _valid_jaxtype(arg): - try: - xla.abstractify(arg) # faster than core.get_aval - except TypeError: - return core.valid_jaxtype(arg) - else: - return True - - -def _check_output_dtype_revderiv(name, holomorphic, x): - aval = core.get_aval(x) - # if jnp.issubdtype(aval.dtype, dtypes.extended): - # raise TypeError(f"{name} with output element type {aval.dtype.name}") - if holomorphic: - if not dtypes.issubdtype(aval.dtype, np.complexfloating): - raise TypeError(f"{name} with holomorphic=True requires outputs with complex dtype, " - f"but got {aval.dtype.name}.") - elif dtypes.issubdtype(aval.dtype, np.complexfloating): - raise TypeError(f"{name} requires real-valued outputs (output dtype that is " - f"a sub-dtype of np.floating), but got {aval.dtype.name}. " - "For holomorphic differentiation, pass holomorphic=True. 
" - "For differentiation of non-holomorphic functions involving complex " - "outputs, use jax.vjp directly.") - elif not dtypes.issubdtype(aval.dtype, np.floating): - raise TypeError(f"{name} requires real-valued outputs (output dtype that is " - f"a sub-dtype of np.floating), but got {aval.dtype.name}. " - "For differentiation of functions with integer outputs, use " - "jax.vjp directly.") - - -def _check_input_dtype_revderiv(name, holomorphic, allow_int, x): - _check_arg(x) - aval = core.get_aval(x) - # if jnp.issubdtype(aval.dtype, dtypes.extended): - # raise TypeError(f"{name} with input element type {aval.dtype.name}") - if holomorphic: - if not dtypes.issubdtype(aval.dtype, np.complexfloating): - raise TypeError(f"{name} with holomorphic=True requires inputs with complex dtype, " - f"but got {aval.dtype.name}.") - if (dtypes.issubdtype(aval.dtype, np.integer) or - dtypes.issubdtype(aval.dtype, np.bool_)): - if not allow_int: - raise TypeError(f"{name} requires real- or complex-valued inputs (input dtype " - f"that is a sub-dtype of np.inexact), but got {aval.dtype.name}. " - "If you want to use Boolean- or integer-valued inputs, use vjp " - "or set allow_int to True.") - elif not dtypes.issubdtype(aval.dtype, np.inexact): - raise TypeError(f"{name} requires numerical-valued inputs (input dtype that is a " - f"sub-dtype of np.bool_ or np.number), but got {aval.dtype.name}.") - - -_check_output_dtype_jacrev = partial(_check_output_dtype_revderiv, "jacrev") -_check_input_dtype_jacrev = partial(_check_input_dtype_revderiv, "jacrev") - - -def _check_output_dtype_jacfwd(holomorphic, x): - aval = core.get_aval(x) - if holomorphic: - if not dtypes.issubdtype(aval.dtype, np.complexfloating): - raise TypeError("jacfwd with holomorphic=True requires outputs with complex dtype, " - f"but got {aval.dtype.name}.") - - -def _check_input_dtype_jacfwd(holomorphic: bool, x: Any) -> None: - _check_arg(x) - aval = core.get_aval(x) - # if jnp.issubdtype(aval.dtype, dtypes.extended): - # raise TypeError(f"jacfwd with input element type {aval.dtype.name}") - if holomorphic: - if not dtypes.issubdtype(aval.dtype, np.complexfloating): - raise TypeError("jacfwd with holomorphic=True requires inputs with complex " - f"dtype, but got {aval.dtype.name}.") - elif not dtypes.issubdtype(aval.dtype, np.floating): - raise TypeError("jacfwd requires real-valued inputs (input dtype that is " - f"a sub-dtype of np.floating), but got {aval.dtype.name}. " - "For holomorphic differentiation, pass holomorphic=True. " - "For differentiation of non-holomorphic functions involving " - "complex inputs or integer inputs, use jax.jvp directly.") +def _init_tangents(leaf, n_copy, index): + ret = zeros((n_copy,) + leaf.shape, dtype=leaf.dtype) + ret[index] = 1. + return ret.value diff --git a/brainpy/_src/math/others.py b/brainpy/_src/math/others.py index f3cf4f516..2f75449c1 100644 --- a/brainpy/_src/math/others.py +++ b/brainpy/_src/math/others.py @@ -86,18 +86,6 @@ def f(l): return tree_map(f, t) -def _exprel(x, threshold): - def true_f(x): - x2 = x * x - return 1. + x / 2. + x2 / 6. + x2 * x / 24.0 # + x2 * x2 / 120. - - def false_f(x): - return (jnp.exp(x) - 1) / x - - # return jax.lax.cond(jnp.abs(x) < threshold, true_f, false_f, x) - return jnp.where(jnp.abs(x) <= threshold, 1. + x / 2. + x * x / 6., (jnp.exp(x) - 1) / x) - - def exprel(x, threshold: float = None): """Relative error exponential, ``(exp(x) - 1)/x``. 
@@ -118,4 +106,4 @@ def exprel(x, threshold: float = None): threshold = 1e-8 else: threshold = 1e-5 - return _exprel(x, threshold) + return jnp.where(jnp.abs(x) <= threshold, 1. + x / 2. + x * x / 6., (jnp.exp(x) - 1) / x)
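
Notes (not part of the patch; illustrative sketches only).

The update implemented by ``ExponentialEuler._make_integral`` is the exponential Rosenbrock–Euler step ``u_{n+1} = u_n + dt * phi_1(dt * L) * f(u_n)`` with ``L = df/du`` evaluated at ``u_n`` and ``phi_1(z) = (exp(z) - 1) / z``. Below is a minimal standalone sketch of that step written with plain JAX; it does not use the patch's ``bm.vector_grad`` / ``bm.exprel``, and the helper names are hypothetical. For an affine equation ``u' = a*u + b`` the step reproduces the exact solution, which is the property the class docstring refers to.

.. code-block:: python

    import jax
    import jax.numpy as jnp

    def exprel(z):
      # (exp(z) - 1) / z with a Taylor fallback near zero, mirroring bm.exprel
      return jnp.where(jnp.abs(z) < 1e-5, 1. + z / 2. + z * z / 6., (jnp.exp(z) - 1.) / z)

    def exp_euler_step(f, u, dt, *args):
      # linearize in the state variable only, then apply u + dt * phi1(dt * L) * f(u)
      L = jax.grad(f, argnums=0)(u, *args)
      du = f(u, *args)
      return u + dt * exprel(dt * L) * du

    # affine test problem u' = a*u + b, whose exact solution is known in closed form
    a, b, dt, u0 = -2.0, 1.5, 0.1, 0.3
    f = lambda u: a * u + b
    u1 = exp_euler_step(f, u0, dt)
    u_exact = (u0 + b / a) * jnp.exp(a * dt) - b / a
    print(u1, u_exact)  # both ~0.38157; the step is exact for affine f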
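
``ExponentialRosenbrock32._psi3`` evaluates ``phi_3(z) = (exp(z) - 1 - z - z**2/2) / z**3`` and switches to the truncated Taylor series ``1/6 + z/24`` for ``|z| < 1e-6``, because the direct formula suffers catastrophic cancellation near zero. A small standalone sketch of that switch, using ``jax.numpy`` directly rather than ``brainpy.math``:

.. code-block:: python

    import jax.numpy as jnp

    def phi3(z):
      direct = (jnp.exp(z) - 1. - z - 0.5 * z * z) / z ** 3
      taylor = 1. / 6. + z / 24.  # leading terms of the series around z = 0
      return jnp.where(jnp.abs(z) < 1e-6, taylor, direct)

    print(phi3(jnp.float32(1e-7)))  # ~0.1666667; the direct formula is pure round-off here
    print(phi3(jnp.float32(1.0)))   # e - 2.5 ~ 0.21828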
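
``vjp_for_exp_euler`` builds the VJP of the joint derivative function once with ``jax.vjp`` and then pulls back one one-hot cotangent per state variable, so that ``ExponentialEulerMultiODE`` can apply the exponential-Euler update to each variable with its own diagonal derivative ``df_i/du_i``. A hedged sketch of that idea on a toy two-variable system (the FitzHugh–Nagumo-like right-hand side below is only for illustration):

.. code-block:: python

    import jax
    import jax.numpy as jnp

    def f(V, w):
      dV = V - V ** 3 / 3. - w
      dw = 0.08 * (V + 0.7 - 0.8 * w)
      return dV, dw

    u = (jnp.array(0.1), jnp.array(0.2))
    ys, f_vjp = jax.vjp(f, *u)

    diag = []
    for i in range(len(u)):
      # cotangent that selects output i only
      ct = tuple(jnp.ones_like(y) if j == i else jnp.zeros_like(y) for j, y in enumerate(ys))
      diag.append(f_vjp(ct)[i])  # keep only d f_i / d u_i

    print(diag)  # [1 - V**2 = 0.99, -0.064]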
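
The rewritten ``_jacrev`` / ``_jacfwd`` no longer re-implement the Jacobian machinery by hand; they wrap the target so its own output is also returned as the auxiliary value and delegate to ``jax.jacrev`` / ``jax.jacfwd`` with ``has_aux=True``, which provides ``return_value=True`` without a second forward pass. A minimal sketch of that wrapping pattern (the ``with_value`` helper below is hypothetical, not the patch's ``_warp_fun_force_aux``):

.. code-block:: python

    import jax
    import jax.numpy as jnp

    def with_value(fun):
      def wrapped(*args):
        y = fun(*args)
        return y, y  # (output to differentiate, aux carrying the primal value)
      return wrapped

    f = lambda x: jnp.sin(x) * x
    jac, y = jax.jacrev(with_value(f), has_aux=True)(jnp.array([0.5, 1.0]))
    print(jac.shape, y)  # (2, 2) Jacobian and the function value from one call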