diff --git a/brainpy/_src/analysis/lowdim/lowdim_analyzer.py b/brainpy/_src/analysis/lowdim/lowdim_analyzer.py index f186659e9..c75f74e5f 100644 --- a/brainpy/_src/analysis/lowdim/lowdim_analyzer.py +++ b/brainpy/_src/analysis/lowdim/lowdim_analyzer.py @@ -3,17 +3,18 @@ import warnings from functools import partial -import numpy as np import jax +import numpy as np from jax import numpy as jnp from jax import vmap from jax.scipy.optimize import minimize -from brainpy import errors, tools import brainpy._src.math as bm -from brainpy._src.math.object_transform.base import Collector +from brainpy import errors, tools from brainpy._src.analysis import constants as C, utils from brainpy._src.analysis.base import DSAnalyzer +from brainpy._src.math.object_transform.base import Collector +from brainpy._src.optimizers.brentq import jax_brentq, ECONVERGED, brentq_candidates, brentq_roots pyplot = None @@ -316,7 +317,9 @@ def F_vmap_fp_aux(self): def F_fixed_point_opt(self): if C.F_fixed_point_opt not in self.analyzed_results: def f(start_and_end, *args): - return utils.jax_brentq(self.F_fx)(start_and_end[0], start_and_end[1], args) + return jax_brentq(utils.f_without_jaxarray_return(self.F_fx))( + start_and_end[0], start_and_end[1], args + ) self.analyzed_results[C.F_fixed_point_opt] = f return self.analyzed_results[C.F_fixed_point_opt] @@ -387,7 +390,7 @@ def _get_fixed_points(self, candidates, *args, num_seg=None, tol_aux=1e-7, loss_ # optimize the fixed points res = self.F_vmap_fp_opt(X, *args) losses = self.F_vmap_fp_aux(res['root'], *args) - valid_or_not = jnp.logical_and(res['status'] == utils.ECONVERGED, losses <= tol_aux) + valid_or_not = jnp.logical_and(res['status'] == ECONVERGED, losses <= tol_aux) ids = np.asarray(jnp.where(valid_or_not)[0]) fps = np.asarray(res['root'])[ids] args = tuple(a[ids] for a in args) @@ -569,10 +572,14 @@ def F_fixed_point_opt(self): if self._can_convert_to_one_eq(): if self.convert_type() == C.x_by_y: def f(start_and_end, *args): - return utils.jax_brentq(self.F_y_convert[1])(start_and_end[0], start_and_end[1], args) + return jax_brentq(utils.f_without_jaxarray_return(self.F_y_convert[1]))( + start_and_end[0], start_and_end[1], args + ) else: def f(start_and_end, *args): - return utils.jax_brentq(self.F_x_convert[1])(start_and_end[0], start_and_end[1], args) + return jax_brentq(utils.f_without_jaxarray_return(self.F_x_convert[1]))( + start_and_end[0], start_and_end[1], args + ) self.analyzed_results[C.F_fixed_point_opt] = f else: @@ -718,23 +725,23 @@ def _get_fx_nullcline_points(self, coords=None, tol=1e-7, num_segments=1, fp_aux # auxiliary functions f2 = lambda y, x, *pars: self.F_fx(x, y, *pars) vmap_f2 = jax.jit(vmap(f2), device=self.jit_device) - vmap_brentq_f2 = jax.jit(vmap(utils.jax_brentq(f2)), device=self.jit_device) - vmap_brentq_f1 = jax.jit(vmap(utils.jax_brentq(self.F_fx)), device=self.jit_device) + vmap_brentq_f2 = jax.jit(vmap(jax_brentq(utils.f_without_jaxarray_return(f2))), device=self.jit_device) + vmap_brentq_f1 = jax.jit(vmap(jax_brentq(utils.f_without_jaxarray_return(self.F_fx))), device=self.jit_device) # num segments for _j, Ps in enumerate(par_seg): if len(par_seg.arg_id_segments[0]) > 1: utils.output(f"{C.prefix}segment {_j} ...") if coords == self.x_var + '-' + self.y_var: - x0s, x1s, vps = utils.brentq_candidates(self.F_vmap_fx, *((xs, ys) + Ps)) - x_values_in_fx, out_args = utils.brentq_roots2(vmap_brentq_f1, x0s, x1s, *vps) + x0s, x1s, vps = brentq_candidates(self.F_vmap_fx, *((xs, ys) + Ps)) + x_values_in_fx, out_args = 
brentq_roots(vmap_brentq_f1, x0s, x1s, *vps) y_values_in_fx = out_args[0] p_values_in_fx = out_args[1:] x_values_in_fx, y_values_in_fx, p_values_in_fx = \ self._fp_filter(x_values_in_fx, y_values_in_fx, p_values_in_fx, fp_aux_filter) elif coords == self.y_var + '-' + self.x_var: - x0s, x1s, vps = utils.brentq_candidates(vmap_f2, *((ys, xs) + Ps)) - y_values_in_fx, out_args = utils.brentq_roots2(vmap_brentq_f2, x0s, x1s, *vps) + x0s, x1s, vps = brentq_candidates(vmap_f2, *((ys, xs) + Ps)) + y_values_in_fx, out_args = brentq_roots(vmap_brentq_f2, x0s, x1s, *vps) x_values_in_fx = out_args[0] p_values_in_fx = out_args[1:] x_values_in_fx, y_values_in_fx, p_values_in_fx = \ @@ -812,21 +819,21 @@ def _get_fy_nullcline_points(self, coords=None, tol=1e-7, num_segments=1, fp_aux # auxiliary functions f2 = lambda y, x, *pars: self.F_fy(x, y, *pars) vmap_f2 = jax.jit(vmap(f2), device=self.jit_device) - vmap_brentq_f2 = jax.jit(vmap(utils.jax_brentq(f2)), device=self.jit_device) - vmap_brentq_f1 = jax.jit(vmap(utils.jax_brentq(self.F_fy)), device=self.jit_device) + vmap_brentq_f2 = jax.jit(vmap(jax_brentq(utils.f_without_jaxarray_return(f2))), device=self.jit_device) + vmap_brentq_f1 = jax.jit(vmap(jax_brentq(utils.f_without_jaxarray_return(self.F_fy))), device=self.jit_device) for j, Ps in enumerate(par_seg): if len(par_seg.arg_id_segments[0]) > 1: utils.output(f"{C.prefix}segment {j} ...") if coords == self.x_var + '-' + self.y_var: - starts, ends, vps = utils.brentq_candidates(self.F_vmap_fy, *((xs, ys) + Ps)) - x_values_in_fy, out_args = utils.brentq_roots2(vmap_brentq_f1, starts, ends, *vps) + starts, ends, vps = brentq_candidates(self.F_vmap_fy, *((xs, ys) + Ps)) + x_values_in_fy, out_args = brentq_roots(vmap_brentq_f1, starts, ends, *vps) y_values_in_fy = out_args[0] p_values_in_fy = out_args[1:] x_values_in_fy, y_values_in_fy, p_values_in_fy = \ self._fp_filter(x_values_in_fy, y_values_in_fy, p_values_in_fy, fp_aux_filter) elif coords == self.y_var + '-' + self.x_var: - starts, ends, vps = utils.brentq_candidates(vmap_f2, *((ys, xs) + Ps)) - y_values_in_fy, out_args = utils.brentq_roots2(vmap_brentq_f2, starts, ends, *vps) + starts, ends, vps = brentq_candidates(vmap_f2, *((ys, xs) + Ps)) + y_values_in_fy, out_args = brentq_roots(vmap_brentq_f2, starts, ends, *vps) x_values_in_fy = out_args[0] p_values_in_fy = out_args[1:] x_values_in_fy, y_values_in_fy, p_values_in_fy = \ diff --git a/brainpy/_src/analysis/utils/__init__.py b/brainpy/_src/analysis/utils/__init__.py index be8715821..52b4ceb5b 100644 --- a/brainpy/_src/analysis/utils/__init__.py +++ b/brainpy/_src/analysis/utils/__init__.py @@ -3,7 +3,6 @@ from .function import * from .measurement import * from .model import * -from .optimization import * from .others import * from .outputs import * from .visualization import * diff --git a/brainpy/_src/analysis/utils/optimization.py b/brainpy/_src/analysis/utils/optimization.py deleted file mode 100644 index 270852327..000000000 --- a/brainpy/_src/analysis/utils/optimization.py +++ /dev/null @@ -1,591 +0,0 @@ -# -*- coding: utf-8 -*- - - -import jax.lax -import jax.numpy as jnp -import numpy as np -from jax import grad, jit, vmap -from jax.flatten_util import ravel_pytree - -from brainpy import errors -import brainpy._src.math as bm -from . 
import f_without_jaxarray_return - -try: - import scipy.optimize as soptimize -except (ModuleNotFoundError, ImportError): - soptimize = None - -__all__ = [ - 'ECONVERGED', 'ECONVERR', - - 'jax_brentq', - 'get_brentq_candidates', - 'brentq_candidates', - 'brentq_roots', - 'brentq_roots2', - 'scipy_minimize_with_jax', - 'roots_of_1d_by_x', - 'roots_of_1d_by_xy', -] - -ECONVERGED = 0 -ECONVERR = -1 - - -def _logical_or(a, b): - a = a.value if isinstance(a, bm.Array) else a - b = b.value if isinstance(b, bm.Array) else b - return jnp.logical_or(a, b) - - -def _logical_and(a, b): - a = a.value if isinstance(a, bm.Array) else a - b = b.value if isinstance(b, bm.Array) else b - return jnp.logical_and(a, b) - - -def _where(p, a, b): - p = p.value if isinstance(p, bm.Array) else p - a = a.value if isinstance(a, bm.Array) else a - b = b.value if isinstance(b, bm.Array) else b - return jnp.where(p, a, b) - - -def jax_brentq(fun): - f = f_without_jaxarray_return(fun) - assert jax.config.read('jax_enable_x64'), ('Brentq optimization need x64 support. ' - 'Please enable x64 with "brainpy.math.enable_x64()"') - rtol = 4 * jnp.finfo(jnp.float64).eps - - # if jax.config.read('jax_enable_x64'): - # rtol = 4 * jnp.finfo(jnp.float64).eps - # else: - # rtol = 1.5 * jnp.finfo(jnp.float32).eps - - def x(a, b, args=(), xtol=2e-14, maxiter=200): - # Convert to float - xpre = a * 1.0 - xcur = b * 1.0 - - # Conditional checks for intervals in methods involving bisection - fpre = f(xpre, *args) - fcur = f(xcur, *args) - - # Root found at either end of [a,b] - root = _where(fpre == 0, xpre, 0.) - status = _where(fpre == 0, ECONVERGED, ECONVERR) - root = _where(fcur == 0, xcur, root) - status = _where(fcur == 0, ECONVERGED, status) - - # Check for sign error and early termination - # Perform Brent's method - def _f1(x): - x['xblk'] = x['xpre'] - x['fblk'] = x['fpre'] - x['spre'] = x['xcur'] - x['xpre'] - x['scur'] = x['xcur'] - x['xpre'] - return x - - def _f2(x): - x['xpre'] = x['xcur'] - x['xcur'] = x['xblk'] - x['xblk'] = x['xpre'] - x['fpre'] = x['fcur'] - x['fcur'] = x['fblk'] - x['fblk'] = x['fpre'] - return x - - def _f5(x): - x['stry'] = -x['fcur'] * (x['xcur'] - x['xpre']) / (x['fcur'] - x['fpre']) - return x - - def _f6(x): - x['dpre'] = (x['fpre'] - x['fcur']) / (x['xpre'] - x['xcur']) - dblk = (x['fblk'] - x['fcur']) / (x['xblk'] - x['xcur']) - _tmp = dblk * x['dpre'] * (x['fblk'] - x['fpre']) - x['stry'] = -x['fcur'] * (x['fblk'] * dblk - x['fpre'] * x['dpre']) / _tmp - return x - - def _f3(x): - x = jax.lax.cond(x['xpre'] == x['xblk'], _f5, _f6, x) - k = jnp.min(jnp.array([abs(x['spre']), 3 * abs(x['sbis']) - x['delta']])) - j = 2 * abs(x['stry']) < k - x['spre'] = _where(j, x['scur'], x['sbis']) - x['scur'] = _where(j, x['stry'], x['sbis']) - return x - - def _f4(x): # bisect - x['spre'] = x['sbis'] - x['scur'] = x['sbis'] - return x - - def body_fun(x): - x['itr'] += 1 - x = jax.lax.cond(x['fpre'] * x['fcur'] < 0, _f1, lambda a: a, x) - x = jax.lax.cond(abs(x['fblk']) < abs(x['fcur']), _f2, lambda a: a, x) - x['delta'] = (xtol + rtol * abs(x['xcur'])) / 2 - x['sbis'] = (x['xblk'] - x['xcur']) / 2 - # Root found - j = _logical_or(x['fcur'] == 0, abs(x['sbis']) < x['delta']) - x['status'] = _where(j, ECONVERGED, x['status']) - x['root'] = _where(j, x['xcur'], x['root']) - x = jax.lax.cond(_logical_and(abs(x['spre']) > x['delta'], abs(x['fcur']) < abs(x['fpre'])), - _f3, _f4, x) - x['xpre'] = x['xcur'] - x['fpre'] = x['fcur'] - x['xcur'] += _where(abs(x['scur']) > x['delta'], - x['scur'], _where(x['sbis'] 
> 0, x['delta'], -x['delta'])) - x['fcur'] = f(x['xcur'], *args) - x['funcalls'] += 1 - return x - - def cond_fun(R): - return jnp.logical_and(R['status'] != ECONVERGED, R['itr'] <= maxiter) - - R = dict(root=root, status=status, xpre=xpre, xcur=xcur, fpre=fpre, fcur=fcur, - itr=0, funcalls=2, xblk=xpre, fblk=fpre, - sbis=(xpre - xcur) / 2, - delta=(xtol + rtol * abs(xcur)) / 2, - stry=-fcur * (xcur - xpre) / (fcur - fpre), - scur=xcur - xpre, spre=xcur - xpre, - dpre=(fpre - fcur) / (xpre - xcur)) - R = jax.lax.cond(status == ECONVERGED, - lambda x: x, - lambda x: jax.lax.while_loop(cond_fun, body_fun, x), - R) - return dict(root=R['root'], funcalls=R['funcalls'], itr=R['itr'], status=R['status']) - - return x - - -def get_brentq_candidates(f, xs, ys): - f = f_without_jaxarray_return(f) - xs = bm.as_jax(xs) - ys = bm.as_jax(ys) - Y, X = jnp.meshgrid(ys, xs) - vals = f(X, Y) - signs = jnp.sign(vals) - x_ids, y_ids = jnp.where(signs[:-1] * signs[1:] <= 0) - starts = xs[x_ids] - ends = xs[x_ids + 1] - args = ys[y_ids] - return starts, ends, args - - -def brentq_candidates(vmap_f, *values, args=()): - # change the position of meshgrid values - values = tuple((v.value if isinstance(v, bm.Array) else v) for v in values) - xs = values[0] - mesh_values = jnp.meshgrid(*values) - if jnp.ndim(mesh_values[0]) > 1: - mesh_values = tuple(jnp.moveaxis(m, 0, 1) for m in mesh_values) - mesh_values = tuple(m.flatten() for m in mesh_values) - # function outputs - signs = jnp.sign(vmap_f(*(mesh_values + args))) - # compute the selected values - signs = signs.reshape((xs.shape[0], -1)) - par_len = signs.shape[1] - signs1 = signs.at[-1].set(1) # discard the final row - signs2 = jnp.vstack((signs[1:], signs[:1])).at[-1].set(1) # discard the first row - ids = jnp.where((signs1 * signs2).flatten() <= 0)[0] - x_starts = mesh_values[0][ids] - x_ends = mesh_values[0][ids + par_len] - other_vals = tuple(v[ids] for v in mesh_values[1:]) - return x_starts, x_ends, other_vals - - -def brentq_roots(f, starts, ends, *vmap_args, args=()): - in_axes = (0, 0, tuple([0] * len(vmap_args)) + tuple([None] * len(args))) - vmap_f_opt = jax.jit(vmap(jax_brentq(f), in_axes=in_axes)) - all_args = vmap_args + args - if len(all_args): - res = vmap_f_opt(starts, ends, all_args) - else: - res = vmap_f_opt(starts, ends, ) - valid_idx = jnp.where(res['status'] == ECONVERGED)[0] - roots = res['root'][valid_idx] - vmap_args = tuple(a[valid_idx] for a in vmap_args) - return roots, vmap_args - - -def brentq_roots2(vmap_f, starts, ends, *vmap_args, args=()): - all_args = vmap_args + args - res = vmap_f(starts, ends, all_args) - valid_idx = jnp.where(res['status'] == ECONVERGED)[0] - roots = res['root'][valid_idx] - vmap_args = tuple(a[valid_idx] for a in vmap_args) - return roots, vmap_args - - -def scipy_minimize_with_jax(fun, x0, - method=None, - args=(), - bounds=None, - constraints=(), - tol=None, - callback=None, - options=None): - """ - A simple wrapper for scipy.optimize.minimize using JAX. - - Parameters - ---------- - fun: function - The objective function to be minimized, written in JAX code - so that it is automatically differentiable. It is of type, - ```fun: x, *args -> float``` where `x` is a PyTree and args - is a tuple of the fixed parameters needed to completely specify the function. - - x0: jnp.ndarray - Initial guess represented as a JAX PyTree. - - args: tuple, optional. - Extra arguments passed to the objective function - and its derivative. Must consist of valid JAX types; e.g. the leaves - of the PyTree must be floats. 
- - method : str or callable, optional - Type of solver. Should be one of - - 'Nelder-Mead' :ref:`(see here) ` - - 'Powell' :ref:`(see here) ` - - 'CG' :ref:`(see here) ` - - 'BFGS' :ref:`(see here) ` - - 'Newton-CG' :ref:`(see here) ` - - 'L-BFGS-B' :ref:`(see here) ` - - 'TNC' :ref:`(see here) ` - - 'COBYLA' :ref:`(see here) ` - - 'SLSQP' :ref:`(see here) ` - - 'trust-constr':ref:`(see here) ` - - 'dogleg' :ref:`(see here) ` - - 'trust-ncg' :ref:`(see here) ` - - 'trust-exact' :ref:`(see here) ` - - 'trust-krylov' :ref:`(see here) ` - - custom - a callable object (added in version 0.14.0), - see below for description. - If not given, chosen to be one of ``BFGS``, ``L-BFGS-B``, ``SLSQP``, - depending on if the problem has constraints or bounds. - - bounds : sequence or `Bounds`, optional - Bounds on variables for L-BFGS-B, TNC, SLSQP, Powell, and - trust-constr methods. There are two ways to specify the bounds: - 1. Instance of `Bounds` class. - 2. Sequence of ``(min, max)`` pairs for each element in `x`. None - is used to specify no bound. - Note that in order to use `bounds` you will need to manually flatten - them in the same order as your inputs `x0`. - - constraints : {Constraint, dict} or List of {Constraint, dict}, optional - Constraints definition (only for COBYLA, SLSQP and trust-constr). - Constraints for 'trust-constr' are defined as a single object or a - list of objects specifying constraints to the optimization problem. - Available constraints are: - - `LinearConstraint` - - `NonlinearConstraint` - Constraints for COBYLA, SLSQP are defined as a list of dictionaries. - Each dictionary with fields: - type : str - Constraint type: 'eq' for equality, 'ineq' for inequality. - fun : callable - The function defining the constraint. - jac : callable, optional - The Jacobian of `fun` (only for SLSQP). - args : sequence, optional - Extra arguments to be passed to the function and Jacobian. - Equality constraint means that the constraint function result is to - be zero whereas inequality means that it is to be non-negative. - Note that COBYLA only supports inequality constraints. - - Note that in order to use `constraints` you will need to manually flatten - them in the same order as your inputs `x0`. - - tol : float, optional - Tolerance for termination. For detailed control, use solver-specific - options. - - options : dict, optional - A dictionary of solver options. All methods accept the following - generic options: - maxiter : int - Maximum number of iterations to perform. Depending on the - method each iteration may use several function evaluations. - disp : bool - Set to True to print convergence messages. - For method-specific options, see :func:`show_options()`. - - callback : callable, optional - Called after each iteration. For 'trust-constr' it is a callable with - the signature: - ``callback(xk, OptimizeResult state) -> bool`` - where ``xk`` is the current parameter vector represented as a PyTree, - and ``state`` is an `OptimizeResult` object, with the same fields - as the ones from the return. If callback returns True the algorithm - execution is terminated. - - For all the other methods, the signature is: - ```callback(xk)``` - where `xk` is the current parameter vector, represented as a PyTree. - - Returns - ------- - res : The optimization result represented as a ``OptimizeResult`` object. 
- Important attributes are: - ``x``: the solution array, represented as a JAX PyTree - ``success``: a Boolean flag indicating if the optimizer exited successfully - ``message``: describes the cause of the termination. - See `scipy.optimize.OptimizeResult` for a description of other attributes. - - """ - if soptimize is None: - raise errors.PackageMissingError(f'"scipy" must be installed when user want to use ' - f'function: {scipy_minimize_with_jax}') - - # Use tree flatten and unflatten to convert params x0 from PyTrees to flat arrays - x0_flat, unravel = ravel_pytree(x0) - - # Wrap the objective function to consume flat _original_ - # numpy arrays and produce scalar outputs. - def fun_wrapper(x_flat, *args): - x = unravel(x_flat) - r = fun(x, *args) - r = r.value if isinstance(r, bm.Array) else r - return float(r) - - # Wrap the gradient in a similar manner - jac = jit(grad(fun)) - - def jac_wrapper(x_flat, *args): - x = unravel(x_flat) - g_flat, _ = ravel_pytree(jac(x, *args)) - return np.array(g_flat) - - # Wrap the callback to consume a pytree - def callback_wrapper(x_flat, *args): - if callback is not None: - x = unravel(x_flat) - return callback(x, *args) - - # Minimize with scipy - results = soptimize.minimize(fun_wrapper, - x0_flat, - args=args, - method=method, - jac=jac_wrapper, - callback=callback_wrapper, - bounds=bounds, - constraints=constraints, - tol=tol, - options=options) - - # pack the output back into a PyTree - results["x"] = unravel(results["x"]) - return results - - -def roots_of_1d_by_x(f, candidates, args=()): - """Find the roots of the given function by numerical methods. - """ - f = f_without_jaxarray_return(f) - candidates = candidates.value if isinstance(candidates, bm.Array) else candidates - args = tuple(a.value if isinstance(candidates, bm.Array) else a for a in args) - vals = f(candidates, *args) - signs = jnp.sign(vals) - zero_sign_idx = jnp.where(signs == 0)[0] - fps = candidates[zero_sign_idx] - candidate_ids = jnp.where(signs[:-1] * signs[1:] < 0)[0] - if len(candidate_ids) <= 0: - return fps - starts = candidates[candidate_ids] - ends = candidates[candidate_ids + 1] - f_opt = jax.jit(vmap(jax_brentq(f), in_axes=(0, 0, None))) - res = f_opt(starts, ends, args) - valid_idx = jnp.where(res['status'] == ECONVERGED)[0] - fps2 = res['root'][valid_idx] - return jnp.concatenate([fps, fps2]) - - -def roots_of_1d_by_xy(f, starts, ends, args): - f = f_without_jaxarray_return(f) - f_opt = jax.jit(vmap(jax_brentq(f))) - res = f_opt(starts, ends, (args,)) - valid_idx = jnp.where(res['status'] == ECONVERGED)[0] - xs = res['root'][valid_idx] - ys = args[valid_idx] - return xs, ys - - -# @tools.numba_jit -def numpy_brentq(f, a, b, args=(), xtol=2e-14, maxiter=200, rtol=4 * np.finfo(float).eps): - """ - Find a root of a function in a bracketing interval using Brent's method - adapted from Scipy's brentq. - - Uses the classic Brent's method to find a zero of the function `f` on - the sign changing interval [a , b]. - - Parameters - ---------- - f : callable - Python function returning a number. `f` must be continuous. - a : number - One end of the bracketing interval [a,b]. - b : number - The other end of the bracketing interval [a,b]. - args : tuple, optional(default=()) - Extra arguments to be used in the function call. - xtol : number, optional(default=2e-12) - The computed root ``x0`` will satisfy ``np.allclose(x, x0, - atol=xtol, rtol=rtol)``, where ``x`` is the exact root. The - parameter must be nonnegative. 
- rtol : number, optional(default=4*np.finfo(float).eps) - The computed root ``x0`` will satisfy ``np.allclose(x, x0, - atol=xtol, rtol=rtol)``, where ``x`` is the exact root. - maxiter : number, optional(default=100) - Maximum number of iterations. - """ - if xtol <= 0: - raise ValueError("xtol is too small (<= 0)") - if maxiter < 1: - raise ValueError("maxiter must be greater than 0") - - # Convert to float - xpre = a * 1.0 - xcur = b * 1.0 - - # Conditional checks for intervals in methods involving bisection - fpre = f(xpre, *args) - fcur = f(xcur, *args) - funcalls = 2 - - if fpre * fcur > 0: - raise ValueError("f(a) and f(b) must have different signs") - root = 0.0 - status = ECONVERR - - # Root found at either end of [a,b] - if fpre == 0: - root = xpre - status = ECONVERGED - if fcur == 0: - root = xcur - status = ECONVERGED - - root, status = root, status - - # Check for sign error and early termination - if status == ECONVERGED: - itr = 0 - else: - # Perform Brent's method - for itr in range(maxiter): - if fpre * fcur < 0: - xblk = xpre - fblk = fpre - spre = scur = xcur - xpre - if abs(fblk) < abs(fcur): - xpre = xcur - xcur = xblk - xblk = xpre - - fpre = fcur - fcur = fblk - fblk = fpre - - delta = (xtol + rtol * abs(xcur)) / 2 - sbis = (xblk - xcur) / 2 - - # Root found - if fcur == 0 or abs(sbis) < delta: - status = ECONVERGED - root = xcur - itr += 1 - break - - if abs(spre) > delta and abs(fcur) < abs(fpre): - if xpre == xblk: - # interpolate - stry = -fcur * (xcur - xpre) / (fcur - fpre) - else: - # extrapolate - dpre = (fpre - fcur) / (xpre - xcur) - dblk = (fblk - fcur) / (xblk - xcur) - stry = -fcur * (fblk * dblk - fpre * dpre) / \ - (dblk * dpre * (fblk - fpre)) - - if 2 * abs(stry) < min(abs(spre), 3 * abs(sbis) - delta): - # good short step - spre = scur - scur = stry - else: - # bisect - spre = sbis - scur = sbis - else: - # bisect - spre = sbis - scur = sbis - - xpre = xcur - fpre = fcur - if abs(scur) > delta: - xcur += scur - else: - xcur += (delta if sbis > 0 else -delta) - fcur = f(xcur, *args) - funcalls += 1 - - if status == ECONVERR: - raise RuntimeError("Failed to converge") - - # x, funcalls, iterations = root, funcalls, itr - return root, funcalls, itr - - -# @tools.numba_jit -def find_root_of_1d_numpy(f, f_points, args=(), tol=1e-8): - """Find the roots of the given function by numerical methods. - - Parameters - ---------- - f : callable - The function. - f_points : np.ndarray, list, tuple - The value points. - - Returns - ------- - roots : list - The roots. 
- """ - vals = f(f_points, *args) - fs_len = len(f_points) - signs = np.sign(vals) - - roots = [] - sign_l = signs[0] - point_l = f_points[0] - idx = 1 - while idx < fs_len and sign_l == 0.: - roots.append(f_points[idx - 1]) - sign_l = signs[idx] - idx += 1 - while idx < fs_len: - sign_r = signs[idx] - point_r = f_points[idx] - if sign_r == 0.: - roots.append(point_r) - if idx + 1 < fs_len: - sign_l = sign_r - point_l = point_r - else: - break - idx += 1 - else: - if not np.isnan(sign_r) and sign_l != sign_r: - root, funcalls, itr = numpy_brentq(f, point_l, point_r, args) - if abs(f(root, *args)) < tol: roots.append(root) - sign_l = sign_r - point_l = point_r - idx += 1 - - return roots diff --git a/brainpy/_src/optimizers/__init__.py b/brainpy/_src/optimizers/__init__.py index ed3b22c6b..21cbe7c6e 100644 --- a/brainpy/_src/optimizers/__init__.py +++ b/brainpy/_src/optimizers/__init__.py @@ -1,4 +1,4 @@ # -*- coding: utf-8 -*- -from .optimizer import * -from .scheduler import * +from .sgd_optimizer import * +from .sgd_scheduler import * diff --git a/brainpy/_src/optimizers/base.py b/brainpy/_src/optimizers/base.py new file mode 100644 index 000000000..26849ad1d --- /dev/null +++ b/brainpy/_src/optimizers/base.py @@ -0,0 +1,36 @@ +import abc + +from tqdm.auto import tqdm + +__all__ = ['Optimizer'] + + +class Optimizer(metaclass=abc.ABCMeta): + """ + Optimizer class created as a base for optimization initialization and + performance with different libraries. To be used with modelfitting + Fitter. + """ + + @abc.abstractmethod + def initialize(self, *args, **kwargs): + """ + Initialize the instrumentation for the optimization, based on + parameters, creates bounds for variables and attaches them to the + optimizer + """ + pass + + @abc.abstractmethod + def one_trial(self, *args, **kwargs): + """ + Returns the requested number of samples of parameter sets + """ + pass + + def minimize(self, n_iter): + results = [] + for i in tqdm(range(n_iter)): + r = self.one_trial(choice_best=i + 1 == n_iter) + results.append(r) + return results[-1] diff --git a/brainpy/_src/optimizers/brentq.py b/brainpy/_src/optimizers/brentq.py new file mode 100644 index 000000000..000c4739a --- /dev/null +++ b/brainpy/_src/optimizers/brentq.py @@ -0,0 +1,351 @@ +import jax.lax +import jax.numpy as jnp +import numpy as np + +import brainpy._src.math as bm + +ECONVERGED = 0 +ECONVERR = -1 + + +def _logical_or(a, b): + a = a.value if isinstance(a, bm.Array) else a + b = b.value if isinstance(b, bm.Array) else b + return jnp.logical_or(a, b) + + +def _logical_and(a, b): + a = a.value if isinstance(a, bm.Array) else a + b = b.value if isinstance(b, bm.Array) else b + return jnp.logical_and(a, b) + + +def _where(p, a, b): + p = p.value if isinstance(p, bm.Array) else p + a = a.value if isinstance(a, bm.Array) else a + b = b.value if isinstance(b, bm.Array) else b + return jnp.where(p, a, b) + + +def jax_brentq(fun): + assert jax.config.read('jax_enable_x64'), ('Brentq optimization need x64 support. 
' + 'Please enable x64 with "brainpy.math.enable_x64()"') + rtol = 4 * jnp.finfo(jnp.float64).eps + + # if jax.config.read('jax_enable_x64'): + # rtol = 4 * jnp.finfo(jnp.float64).eps + # else: + # rtol = 1.5 * jnp.finfo(jnp.float32).eps + + @jax.jit + def x(a, b, args=(), xtol=2e-14, maxiter=200): + # Convert to float + xpre = a * 1.0 + xcur = b * 1.0 + + # Conditional checks for intervals in methods involving bisection + fpre = fun(xpre, *args) + fcur = fun(xcur, *args) + + # Root found at either end of [a,b] + root = _where(fpre == 0, xpre, 0.) + status = _where(fpre == 0, ECONVERGED, ECONVERR) + root = _where(fcur == 0, xcur, root) + status = _where(fcur == 0, ECONVERGED, status) + + # Check for sign error and early termination + # Perform Brent's method + def _f1(x): + x['xblk'] = x['xpre'] + x['fblk'] = x['fpre'] + x['spre'] = x['xcur'] - x['xpre'] + x['scur'] = x['xcur'] - x['xpre'] + return x + + def _f2(x): + x['xpre'] = x['xcur'] + x['xcur'] = x['xblk'] + x['xblk'] = x['xpre'] + x['fpre'] = x['fcur'] + x['fcur'] = x['fblk'] + x['fblk'] = x['fpre'] + return x + + def _f5(x): + x['stry'] = -x['fcur'] * (x['xcur'] - x['xpre']) / (x['fcur'] - x['fpre']) + return x + + def _f6(x): + x['dpre'] = (x['fpre'] - x['fcur']) / (x['xpre'] - x['xcur']) + dblk = (x['fblk'] - x['fcur']) / (x['xblk'] - x['xcur']) + _tmp = dblk * x['dpre'] * (x['fblk'] - x['fpre']) + x['stry'] = -x['fcur'] * (x['fblk'] * dblk - x['fpre'] * x['dpre']) / _tmp + return x + + def _f3(x): + x = jax.lax.cond(x['xpre'] == x['xblk'], _f5, _f6, x) + k = jnp.min(jnp.array([abs(x['spre']), 3 * abs(x['sbis']) - x['delta']])) + j = 2 * abs(x['stry']) < k + x['spre'] = _where(j, x['scur'], x['sbis']) + x['scur'] = _where(j, x['stry'], x['sbis']) + return x + + def _f4(x): # bisect + x['spre'] = x['sbis'] + x['scur'] = x['sbis'] + return x + + def body_fun(x): + x['itr'] += 1 + x = jax.lax.cond(x['fpre'] * x['fcur'] < 0, _f1, lambda a: a, x) + x = jax.lax.cond(abs(x['fblk']) < abs(x['fcur']), _f2, lambda a: a, x) + x['delta'] = (xtol + rtol * abs(x['xcur'])) / 2 + x['sbis'] = (x['xblk'] - x['xcur']) / 2 + # Root found + j = _logical_or(x['fcur'] == 0, abs(x['sbis']) < x['delta']) + x['status'] = _where(j, ECONVERGED, x['status']) + x['root'] = _where(j, x['xcur'], x['root']) + x = jax.lax.cond(_logical_and(abs(x['spre']) > x['delta'], abs(x['fcur']) < abs(x['fpre'])), + _f3, _f4, x) + x['xpre'] = x['xcur'] + x['fpre'] = x['fcur'] + x['xcur'] += _where(abs(x['scur']) > x['delta'], + x['scur'], _where(x['sbis'] > 0, x['delta'], -x['delta'])) + x['fcur'] = fun(x['xcur'], *args) + x['funcalls'] += 1 + return x + + def cond_fun(R): + return jnp.logical_and(R['status'] != ECONVERGED, R['itr'] <= maxiter) + + R = dict(root=root, status=status, xpre=xpre, xcur=xcur, fpre=fpre, fcur=fcur, + itr=0, funcalls=2, xblk=xpre, fblk=fpre, + sbis=(xpre - xcur) / 2, + delta=(xtol + rtol * abs(xcur)) / 2, + stry=-fcur * (xcur - xpre) / (fcur - fpre), + scur=xcur - xpre, spre=xcur - xpre, + dpre=(fpre - fcur) / (xpre - xcur)) + R = jax.lax.cond( + status == ECONVERGED, + lambda x: x, + lambda x: jax.lax.while_loop(cond_fun, body_fun, x), + R + ) + return dict(root=R['root'], funcalls=R['funcalls'], itr=R['itr'], status=R['status']) + + return x + + +# @tools.numba_jit +def numpy_brentq( + f, a, b, args=(), xtol=2e-14, maxiter=200, rtol=4 * np.finfo(float).eps +): + """ + Find a root of a function in a bracketing interval using Brent's method + adapted from Scipy's brentq. 
+ + Uses the classic Brent's method to find a zero of the function `f` on + the sign changing interval [a , b]. + + Parameters + ---------- + f : callable + Python function returning a number. `f` must be continuous. + a : number + One end of the bracketing interval [a,b]. + b : number + The other end of the bracketing interval [a,b]. + args : tuple, optional(default=()) + Extra arguments to be used in the function call. + xtol : number, optional(default=2e-12) + The computed root ``x0`` will satisfy ``np.allclose(x, x0, + atol=xtol, rtol=rtol)``, where ``x`` is the exact root. The + parameter must be nonnegative. + rtol : number, optional(default=4*np.finfo(float).eps) + The computed root ``x0`` will satisfy ``np.allclose(x, x0, + atol=xtol, rtol=rtol)``, where ``x`` is the exact root. + maxiter : number, optional(default=100) + Maximum number of iterations. + """ + if xtol <= 0: + raise ValueError("xtol is too small (<= 0)") + if maxiter < 1: + raise ValueError("maxiter must be greater than 0") + + # Convert to float + xpre = a * 1.0 + xcur = b * 1.0 + + # Conditional checks for intervals in methods involving bisection + fpre = f(xpre, *args) + fcur = f(xcur, *args) + funcalls = 2 + + if fpre * fcur > 0: + raise ValueError("f(a) and f(b) must have different signs") + root = 0.0 + status = ECONVERR + + # Root found at either end of [a,b] + if fpre == 0: + root = xpre + status = ECONVERGED + if fcur == 0: + root = xcur + status = ECONVERGED + + root, status = root, status + + # Check for sign error and early termination + if status == ECONVERGED: + itr = 0 + else: + # Perform Brent's method + for itr in range(maxiter): + if fpre * fcur < 0: + xblk = xpre + fblk = fpre + spre = scur = xcur - xpre + if abs(fblk) < abs(fcur): + xpre = xcur + xcur = xblk + xblk = xpre + + fpre = fcur + fcur = fblk + fblk = fpre + + delta = (xtol + rtol * abs(xcur)) / 2 + sbis = (xblk - xcur) / 2 + + # Root found + if fcur == 0 or abs(sbis) < delta: + status = ECONVERGED + root = xcur + itr += 1 + break + + if abs(spre) > delta and abs(fcur) < abs(fpre): + if xpre == xblk: + # interpolate + stry = -fcur * (xcur - xpre) / (fcur - fpre) + else: + # extrapolate + dpre = (fpre - fcur) / (xpre - xcur) + dblk = (fblk - fcur) / (xblk - xcur) + stry = -fcur * (fblk * dblk - fpre * dpre) / \ + (dblk * dpre * (fblk - fpre)) + + if 2 * abs(stry) < min(abs(spre), 3 * abs(sbis) - delta): + # good short step + spre = scur + scur = stry + else: + # bisect + spre = sbis + scur = sbis + else: + # bisect + spre = sbis + scur = sbis + + xpre = xcur + fpre = fcur + if abs(scur) > delta: + xcur += scur + else: + xcur += (delta if sbis > 0 else -delta) + fcur = f(xcur, *args) + funcalls += 1 + + if status == ECONVERR: + raise RuntimeError("Failed to converge") + + # x, funcalls, iterations = root, funcalls, itr + return root, funcalls, itr + + +# @tools.numba_jit +def find_root_of_1d_numpy(f, f_points, args=(), tol=1e-8): + """Find the roots of the given function by numerical methods. + + Parameters + ---------- + f : callable + The function. + f_points : np.ndarray, list, tuple + The value points. + + Returns + ------- + roots : list + The roots. 
+ """ + vals = f(f_points, *args) + fs_len = len(f_points) + signs = np.sign(vals) + + roots = [] + sign_l = signs[0] + point_l = f_points[0] + idx = 1 + while idx < fs_len and sign_l == 0.: + roots.append(f_points[idx - 1]) + sign_l = signs[idx] + idx += 1 + while idx < fs_len: + sign_r = signs[idx] + point_r = f_points[idx] + if sign_r == 0.: + roots.append(point_r) + if idx + 1 < fs_len: + sign_l = sign_r + point_l = point_r + else: + break + idx += 1 + else: + if not np.isnan(sign_r) and sign_l != sign_r: + root, funcalls, itr = numpy_brentq(f, point_l, point_r, args) + if abs(f(root, *args)) < tol: roots.append(root) + sign_l = sign_r + point_l = point_r + idx += 1 + + return roots + + + +def brentq_candidates(vmap_f, *values, args=()): + # change the position of meshgrid values + values = tuple((v.value if isinstance(v, bm.Array) else v) for v in values) + xs = values[0] + mesh_values = jnp.meshgrid(*values) + if jnp.ndim(mesh_values[0]) > 1: + mesh_values = tuple(jnp.moveaxis(m, 0, 1) for m in mesh_values) + mesh_values = tuple(m.flatten() for m in mesh_values) + # function outputs + signs = jnp.sign(vmap_f(*(mesh_values + args))) + # compute the selected values + signs = signs.reshape((xs.shape[0], -1)) + par_len = signs.shape[1] + signs1 = signs.at[-1].set(1) # discard the final row + signs2 = jnp.vstack((signs[1:], signs[:1])).at[-1].set(1) # discard the first row + ids = jnp.where((signs1 * signs2).flatten() <= 0)[0] + x_starts = mesh_values[0][ids] + x_ends = mesh_values[0][ids + par_len] + other_vals = tuple(v[ids] for v in mesh_values[1:]) + return x_starts, x_ends, other_vals + + + + + +def brentq_roots(vmap_f, starts, ends, *vmap_args, args=()): + all_args = vmap_args + args + res = vmap_f(starts, ends, all_args) + valid_idx = jnp.where(res['status'] == ECONVERGED)[0] + roots = res['root'][valid_idx] + vmap_args = tuple(a[valid_idx] for a in vmap_args) + return roots, vmap_args + diff --git a/brainpy/_src/optimizers/nevergrad_optimizer.py b/brainpy/_src/optimizers/nevergrad_optimizer.py new file mode 100644 index 000000000..e69de29bb diff --git a/brainpy/_src/optimizers/scipy_minimize.py b/brainpy/_src/optimizers/scipy_minimize.py new file mode 100644 index 000000000..bfd810f1a --- /dev/null +++ b/brainpy/_src/optimizers/scipy_minimize.py @@ -0,0 +1,181 @@ +import jax +import jax.numpy as jnp +import numpy as np +from jax.flatten_util import ravel_pytree + +import brainpy.math as bm +from brainpy import errors + +soptimize = None + + +def scipy_minimize_with_jax( + fun, x0, + method=None, + args=(), + bounds=None, + constraints=(), + tol=None, + callback=None, + options=None +): + """ + A simple wrapper for scipy.optimize.minimize using JAX. + + Parameters + ---------- + fun: function + The objective function to be minimized, written in JAX code + so that it is automatically differentiable. It is of type, + ```fun: x, *args -> float``` where `x` is a PyTree and args + is a tuple of the fixed parameters needed to completely specify the function. + + x0: jnp.ndarray + Initial guess represented as a JAX PyTree. + + args: tuple, optional. + Extra arguments passed to the objective function + and its derivative. Must consist of valid JAX types; e.g. the leaves + of the PyTree must be floats. + + method : str or callable, optional + Type of solver. 
Should be one of + - 'Nelder-Mead' :ref:`(see here) ` + - 'Powell' :ref:`(see here) ` + - 'CG' :ref:`(see here) ` + - 'BFGS' :ref:`(see here) ` + - 'Newton-CG' :ref:`(see here) ` + - 'L-BFGS-B' :ref:`(see here) ` + - 'TNC' :ref:`(see here) ` + - 'COBYLA' :ref:`(see here) ` + - 'SLSQP' :ref:`(see here) ` + - 'trust-constr':ref:`(see here) ` + - 'dogleg' :ref:`(see here) ` + - 'trust-ncg' :ref:`(see here) ` + - 'trust-exact' :ref:`(see here) ` + - 'trust-krylov' :ref:`(see here) ` + - custom - a callable object (added in version 0.14.0), + see below for description. + If not given, chosen to be one of ``BFGS``, ``L-BFGS-B``, ``SLSQP``, + depending on if the problem has constraints or bounds. + + bounds : sequence or `Bounds`, optional + Bounds on variables for L-BFGS-B, TNC, SLSQP, Powell, and + trust-constr methods. There are two ways to specify the bounds: + 1. Instance of `Bounds` class. + 2. Sequence of ``(min, max)`` pairs for each element in `x`. None + is used to specify no bound. + Note that in order to use `bounds` you will need to manually flatten + them in the same order as your inputs `x0`. + + constraints : {Constraint, dict} or List of {Constraint, dict}, optional + Constraints definition (only for COBYLA, SLSQP and trust-constr). + Constraints for 'trust-constr' are defined as a single object or a + list of objects specifying constraints to the optimization problem. + Available constraints are: + - `LinearConstraint` + - `NonlinearConstraint` + Constraints for COBYLA, SLSQP are defined as a list of dictionaries. + Each dictionary with fields: + type : str + Constraint type: 'eq' for equality, 'ineq' for inequality. + fun : callable + The function defining the constraint. + jac : callable, optional + The Jacobian of `fun` (only for SLSQP). + args : sequence, optional + Extra arguments to be passed to the function and Jacobian. + Equality constraint means that the constraint function result is to + be zero whereas inequality means that it is to be non-negative. + Note that COBYLA only supports inequality constraints. + + Note that in order to use `constraints` you will need to manually flatten + them in the same order as your inputs `x0`. + + tol : float, optional + Tolerance for termination. For detailed control, use solver-specific + options. + + options : dict, optional + A dictionary of solver options. All methods accept the following + generic options: + maxiter : int + Maximum number of iterations to perform. Depending on the + method each iteration may use several function evaluations. + disp : bool + Set to True to print convergence messages. + For method-specific options, see :func:`show_options()`. + + callback : callable, optional + Called after each iteration. For 'trust-constr' it is a callable with + the signature: + ``callback(xk, OptimizeResult state) -> bool`` + where ``xk`` is the current parameter vector represented as a PyTree, + and ``state`` is an `OptimizeResult` object, with the same fields + as the ones from the return. If callback returns True the algorithm + execution is terminated. + + For all the other methods, the signature is: + ```callback(xk)``` + where `xk` is the current parameter vector, represented as a PyTree. + + Returns + ------- + res : The optimization result represented as a ``OptimizeResult`` object. + Important attributes are: + ``x``: the solution array, represented as a JAX PyTree + ``success``: a Boolean flag indicating if the optimizer exited successfully + ``message``: describes the cause of the termination. 
+ See `scipy.optimize.OptimizeResult` for a description of other attributes. + + """ + global soptimize + if soptimize is None: + try: + import scipy.optimize as soptimize + except ImportError: + raise errors.PackageMissingError(f'"scipy" must be installed when user want to use ' + f'function: {scipy_minimize_with_jax}') + + # Use tree flatten and unflatten to convert params x0 from PyTrees to flat arrays + x0_flat, unravel = ravel_pytree(x0) + + # Wrap the objective function to consume flat _original_ + # numpy arrays and produce scalar outputs. + def fun_wrapper(x_flat, *args): + x = unravel(x_flat) + r = fun(x, *args) + r = r.value if isinstance(r, bm.Array) else r + return float(r) + + # Wrap the gradient in a similar manner + jac = jax.jit(jax.grad(fun)) + + def jac_wrapper(x_flat, *args): + x = unravel(x_flat) + g_flat, _ = ravel_pytree(jac(x, *args)) + return np.array(g_flat) + + # Wrap the callback to consume a pytree + def callback_wrapper(x_flat, *args): + if callback is not None: + x = unravel(x_flat) + return callback(x, *args) + + # Minimize with scipy + results = soptimize.minimize( + fun_wrapper, + x0_flat, + args=args, + method=method, + jac=jac_wrapper, + callback=callback_wrapper, + bounds=bounds, + constraints=constraints, + tol=tol, + options=options + ) + + # pack the output back into a PyTree + results["x"] = unravel(results["x"]) + return results diff --git a/brainpy/_src/optimizers/optimizer.py b/brainpy/_src/optimizers/sgd_optimizer.py similarity index 98% rename from brainpy/_src/optimizers/optimizer.py rename to brainpy/_src/optimizers/sgd_optimizer.py index c2aec25a0..536b97195 100644 --- a/brainpy/_src/optimizers/optimizer.py +++ b/brainpy/_src/optimizers/sgd_optimizer.py @@ -5,13 +5,11 @@ import jax.numpy as jnp from jax.lax import cond -import brainpy as bp import brainpy.math as bm from brainpy import check -from brainpy._src.math.object_transform.base import BrainPyObject, ArrayCollector from brainpy.errors import MathError -from .scheduler import make_schedule, Scheduler +from .sgd_scheduler import make_schedule, Scheduler __all__ = [ 'Optimizer', @@ -28,7 +26,7 @@ ] -class Optimizer(BrainPyObject): +class Optimizer(bm.BrainPyObject): """Base Optimizer Class. Parameters @@ -40,7 +38,7 @@ class Optimizer(BrainPyObject): lr: Scheduler # learning rate '''Learning rate''' - vars_to_train: ArrayCollector # variables to train + vars_to_train: bm.VarDict # variables to train '''Variables to train.''' def __init__( @@ -49,9 +47,9 @@ def __init__( train_vars: Union[Sequence[bm.Variable], Dict[str, bm.Variable]] = None, name: Optional[str] = None ): - super(Optimizer, self).__init__(name=name) + super().__init__(name=name) self.lr: Scheduler = make_schedule(lr) - self.vars_to_train = ArrayCollector() + self.vars_to_train = bm.var_dict() self.register_train_vars(train_vars) def register_vars(self, train_vars: Optional[Dict[str, bm.Variable]] = None): @@ -72,8 +70,15 @@ def __repr__(self): def update(self, grads: dict): raise NotImplementedError + def zero_grad(self): + """ + Zero the gradients of all trainable variables. 
+ """ + for p in self.vars_to_train.values(): + p.value = jnp.zeros_like(p.value) -class CommonOpt(Optimizer): + +class _CommonOpt(Optimizer): def __init__( self, lr: Union[float, Scheduler, bm.Variable], @@ -81,14 +86,11 @@ def __init__( weight_decay: Optional[float] = None, name: Optional[str] = None ): - super(Optimizer, self).__init__(name=name) - self.lr: Scheduler = make_schedule(lr) - self.vars_to_train = ArrayCollector() - self.register_train_vars(train_vars) + super().__init__(name=name, lr=lr, train_vars=train_vars) self.weight_decay = check.is_float(weight_decay, min_bound=0., max_bound=1., allow_none=True) -class SGD(CommonOpt): +class SGD(_CommonOpt): r"""Stochastic gradient descent optimizer. SGD performs a parameter update for training examples :math:`x` and label @@ -138,7 +140,7 @@ def update(self, grads: dict): self.lr.step_call() -class Momentum(CommonOpt): +class Momentum(_CommonOpt): r"""Momentum optimizer. Momentum [1]_ is a method that helps accelerate SGD in the relevant direction @@ -209,7 +211,7 @@ def update(self, grads: dict): self.lr.step_call() -class MomentumNesterov(CommonOpt): +class MomentumNesterov(_CommonOpt): r"""Nesterov accelerated gradient optimizer [2]_. .. math:: @@ -273,7 +275,7 @@ def update(self, grads: dict): self.lr.step_call() -class Adagrad(CommonOpt): +class Adagrad(_CommonOpt): r"""Optimizer that implements the Adagrad algorithm. Adagrad [3]_ is an optimizer with parameter-specific learning rates, which are @@ -345,7 +347,7 @@ def __repr__(self): return f"{self.__class__.__name__}(lr={self.lr}, epsilon={self.epsilon})" -class Adadelta(CommonOpt): +class Adadelta(_CommonOpt): r"""Optimizer that implements the Adadelta algorithm. Adadelta [4]_ optimization is a stochastic gradient descent method that is based @@ -437,7 +439,7 @@ def __repr__(self): f"epsilon={self.epsilon}, rho={self.rho})") -class RMSProp(CommonOpt): +class RMSProp(_CommonOpt): r"""Optimizer that implements the RMSprop algorithm. RMSprop [5]_ and Adadelta have both been developed independently around the same time @@ -513,7 +515,7 @@ def __repr__(self): f"epsilon={self.epsilon}, rho={self.rho})") -class Adam(CommonOpt): +class Adam(_CommonOpt): """Optimizer that implements the Adam algorithm. Adam [6]_ - a stochastic gradient descent method (SGD) that computes @@ -598,7 +600,7 @@ def update(self, grads: dict): self.lr.step_call() -class LARS(CommonOpt): +class LARS(_CommonOpt): r"""Layer-wise adaptive rate scaling (LARS) optimizer [1]_. Layer-wise Adaptive Rate Scaling, or LARS, is a large batch @@ -678,7 +680,7 @@ def update(self, grads: dict): self.lr.step_call() -class Adan(CommonOpt): +class Adan(_CommonOpt): r"""Adaptive Nesterov Momentum Algorithm for Faster Optimizing Deep Models [1]_. .. math:: @@ -817,7 +819,7 @@ def update(self, grads: dict): self.lr.step_call() -class AdamW(CommonOpt): +class AdamW(_CommonOpt): r"""Adam with weight decay regularization [1]_. AdamW uses weight decay to regularize learning towards small weights, as @@ -977,7 +979,7 @@ def update(self, grads: dict): self.lr.step_call() -class SM3(CommonOpt): +class SM3(_CommonOpt): """SM3 algorithm [1]_. 
The 'Square-root of Minima of Sums of Maxima of Squared-gradients Method' diff --git a/brainpy/_src/optimizers/scheduler.py b/brainpy/_src/optimizers/sgd_scheduler.py similarity index 99% rename from brainpy/_src/optimizers/scheduler.py rename to brainpy/_src/optimizers/sgd_scheduler.py index b27398dae..d2f47da8e 100644 --- a/brainpy/_src/optimizers/scheduler.py +++ b/brainpy/_src/optimizers/sgd_scheduler.py @@ -1,6 +1,5 @@ # -*- coding: utf-8 -*- import warnings -from functools import partial from typing import Sequence, Union import jax @@ -8,7 +7,6 @@ import brainpy.math as bm from brainpy import check -from brainpy._src.math.object_transform.base import BrainPyObject from brainpy.errors import MathError @@ -25,7 +23,7 @@ def make_schedule(scalar_or_schedule): raise TypeError(type(scalar_or_schedule)) -class Scheduler(BrainPyObject): +class Scheduler(bm.BrainPyObject): """The learning rate scheduler.""" def __init__(self, lr: Union[float, bm.Variable], last_epoch: int = -1): diff --git a/brainpy/_src/optimizers/skopt_bayesian.py b/brainpy/_src/optimizers/skopt_bayesian.py new file mode 100644 index 000000000..4482d0624 --- /dev/null +++ b/brainpy/_src/optimizers/skopt_bayesian.py @@ -0,0 +1,98 @@ +from typing import Callable, Optional, Sequence + +import numpy as np + +from .base import Optimizer + +__all__ = ['SkoptBayesOptimizer'] + + +class SkoptBayesOptimizer(Optimizer): + """ + SkoptOptimizer instance creates all the tools necessary for the user + to use it with scikit-optimize library. + + Parameters + ---------- + parameter_names: list[str] + Parameters to be used as instruments. + bounds : list + List with appropiate bounds for each parameter. + method : `str`, optional + The optimization method. Possibilities: "GP", "RF", "ET", "GBRT" or + sklearn regressor, default="GP" + n_calls: int + Number of calls to ``func``. Defaults to 100. + n_jobs: int + The number of jobs to run in parallel for ``fit``. If -1, then the + number of jobs is set to the number of cores. 
+ + """ + + def __init__( + self, + loss_fun: Callable, + n_sample: int, + bounds: Optional[Sequence] = None, + method: str = 'GP', + **kwds + ): + super().__init__() + + try: + from sklearn.base import RegressorMixin # noqa + except (ImportError, ModuleNotFoundError): + raise ImportError("scikit-learn must be installed to use this class") + + # loss function + assert callable(loss_fun), "'loss_fun' must be a callable function" + self.loss_fun = loss_fun + + # method + if not (method.upper() in ["GP", "RF", "ET", "GBRT"] or isinstance(method, RegressorMixin)): + raise AssertionError(f"Provided method: {method} is not an skopt optimization or a regressor") + self.method = method + + # population size + assert n_sample > 0, "'n_sample' must be a positive integer" + self.n_sample = n_sample + + # bounds + if bounds is None: + bounds = () + self.bounds = bounds + + # others + self.kwds = kwds + + def initialize(self): + try: + from skopt.optimizer import Optimizer # noqa + from skopt.space import Real # noqa + except (ImportError, ModuleNotFoundError): + raise ImportError("scikit-optimize must be installed to use this class") + self.tested_parameters = [] + self.errors = [] + instruments = [] + for bound in self.bounds: + instrumentation = Real(*np.asarray(bound), transform='normalize') + instruments.append(instrumentation) + self.optim = Optimizer(dimensions=instruments, base_estimator=self.method, **self.kwds) + + def one_trial(self, choice_best: bool = False): + # draw parameters + parameters = self.optim.ask(n_points=self.n_sample) + self.tested_parameters.extend(parameters) + + # errors + errors = self.loss_fun(*np.asarray(parameters).T) + errors = np.asarray(errors).tolist() + self.errors.extend(errors) + + # tell + self.optim.tell(parameters, errors) + + if choice_best: + xi = self.optim.Xi + yii = np.array(self.optim.yi) + return xi[yii.argmin()] diff --git a/brainpy/_src/optimizers/tests/test_scheduler.py b/brainpy/_src/optimizers/tests/test_scheduler.py index dbdda0eda..82f2fd1e7 100644 --- a/brainpy/_src/optimizers/tests/test_scheduler.py +++ b/brainpy/_src/optimizers/tests/test_scheduler.py @@ -7,7 +7,7 @@ from absl.testing import parameterized import brainpy.math as bm -from brainpy._src.optimizers import scheduler +from brainpy._src.optimizers import sgd_scheduler show = False @@ -18,8 +18,8 @@ class TestMultiStepLR(parameterized.TestCase): ) def test2(self, last_epoch): bm.random.seed() - scheduler1 = scheduler.MultiStepLR(0.1, [10, 20], gamma=0.1, last_epoch=last_epoch) - scheduler2 = scheduler.MultiStepLR(0.1, [10, 20], gamma=0.1, last_epoch=last_epoch) + scheduler1 = sgd_scheduler.MultiStepLR(0.1, [10, 20], gamma=0.1, last_epoch=last_epoch) + scheduler2 = sgd_scheduler.MultiStepLR(0.1, [10, 20], gamma=0.1, last_epoch=last_epoch) for i in range(1, 25): lr1 = scheduler1(i + last_epoch) @@ -38,8 +38,8 @@ class TestStepLR(parameterized.TestCase): ) def test1(self, last_epoch): bm.random.seed() - scheduler1 = scheduler.StepLR(0.1, 10, gamma=0.1, last_epoch=last_epoch) - scheduler2 = scheduler.StepLR(0.1, 10, gamma=0.1, last_epoch=last_epoch) + scheduler1 = sgd_scheduler.StepLR(0.1, 10, gamma=0.1, last_epoch=last_epoch) + scheduler2 = sgd_scheduler.StepLR(0.1, 10, gamma=0.1, last_epoch=last_epoch) for i in range(1, 25): lr1 = scheduler1(i + last_epoch) lr2 = scheduler2() @@ -54,7 +54,7 @@ def test1(self): bm.random.seed() max_epoch = 50 iters = 200 - sch = scheduler.CosineAnnealingLR(0.1, T_max=5, eta_min=0, last_epoch=-1) + sch = sgd_scheduler.CosineAnnealingLR(0.1, T_max=5, 
eta_min=0, last_epoch=-1) all_lr1 = [[], []] all_lr2 = [[], []] for epoch in range(max_epoch): @@ -81,11 +81,11 @@ def test1(self): bm.random.seed() max_epoch = 50 iters = 200 - sch = scheduler.CosineAnnealingWarmRestarts(0.1, - iters, - T_0=5, - T_mult=1, - last_call=-1) + sch = sgd_scheduler.CosineAnnealingWarmRestarts(0.1, + iters, + T_0=5, + T_mult=1, + last_call=-1) all_lr1 = [] all_lr2 = [] for epoch in range(max_epoch): diff --git a/brainpy/optim.py b/brainpy/optim.py index de66e3700..66419ddd4 100644 --- a/brainpy/optim.py +++ b/brainpy/optim.py @@ -5,10 +5,10 @@ # ---------- # -from brainpy._src.optimizers.optimizer import ( +from brainpy._src.optimizers.sgd_optimizer import ( Optimizer as Optimizer, ) -from brainpy._src.optimizers.optimizer import ( +from brainpy._src.optimizers.sgd_optimizer import ( SGD as SGD, Momentum as Momentum, MomentumNesterov as MomentumNesterov, @@ -26,11 +26,11 @@ # ---------- # -from brainpy._src.optimizers.scheduler import ( +from brainpy._src.optimizers.sgd_scheduler import ( make_schedule as make_schedule, Scheduler as Scheduler, ) -from brainpy._src.optimizers.scheduler import ( +from brainpy._src.optimizers.sgd_scheduler import ( Constant as Constant, ExponentialDecay as ExponentialDecay, InverseTimeDecay as InverseTimeDecay, @@ -41,7 +41,7 @@ InverseTimeDecayLR as InverseTimeDecayLR, ExponentialDecayLR as ExponentialDecayLR ) -from brainpy._src.optimizers.scheduler import ( +from brainpy._src.optimizers.sgd_scheduler import ( StepLR as StepLR, MultiStepLR as MultiStepLR, ExponentialLR as ExponentialLR,
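
The hunks above move Brent's-method root finding out of brainpy._src.analysis.utils.optimization and into the new brainpy._src.optimizers.brentq module, with the analyzer now wrapping its vector fields via utils.f_without_jaxarray_return before handing them to jax_brentq. A minimal sketch of calling the relocated solver directly (x64 must be enabled first, as the new module asserts; the quadratic test function is invented for illustration):

import brainpy.math as bm
from brainpy._src.optimizers.brentq import jax_brentq, ECONVERGED

bm.enable_x64()                    # jax_brentq asserts x64 support

def f(x):
    return x ** 2 - 2.0            # a single root at sqrt(2) inside [0, 2]

solve = jax_brentq(f)              # returns a jitted solver(a, b, args=(), ...)
res = solve(0.0, 2.0)              # bracket whose endpoints have opposite signs
assert res['status'] == ECONVERGED
print(res['root'])                 # ~1.41421356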
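
The nullcline routines now pair brentq_candidates (sign-change bracket detection over a mesh of grid values) with brentq_roots (an already-vmapped solver applied to every bracket); note that the new brentq_roots behaves like the old brentq_roots2, i.e. the caller supplies the vmapped solver. A hedged sketch of the same pipeline outside the analyzer, using an invented function fx(x, y) whose roots in x depend on y:

import jax
import jax.numpy as jnp
from jax import vmap

import brainpy.math as bm
from brainpy._src.optimizers.brentq import jax_brentq, brentq_candidates, brentq_roots

bm.enable_x64()

def fx(x, y):
    return x ** 2 + y - 1.0                       # roots at x = +/- sqrt(1 - y)

xs = jnp.linspace(-2.0, 2.0, 41)                  # grid over the x variable
ys = jnp.linspace(-0.5, 0.5, 5)                   # grid over the second variable

# 1. find all (start, end) brackets where fx changes sign along x
starts, ends, (y_vals,) = brentq_candidates(jax.jit(vmap(fx)), xs, ys)

# 2. run the vmapped Brent solver on every bracket and keep converged roots
vmap_solver = jax.jit(vmap(jax_brentq(fx)))
roots, (y_at_roots,) = brentq_roots(vmap_solver, starts, ends, y_vals)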
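
scipy_minimize_with_jax keeps its signature but now lives in brainpy._src.optimizers.scipy_minimize and imports scipy lazily at call time. A minimal sketch, assuming scipy is installed and using an invented quadratic loss:

import jax.numpy as jnp
from brainpy._src.optimizers.scipy_minimize import scipy_minimize_with_jax

def loss(x):
    return jnp.sum((x - 3.0) ** 2)                # minimum at x = [3, 3, 3, 3]

res = scipy_minimize_with_jax(loss, x0=jnp.zeros(4), method='L-BFGS-B')
print(res.x, res.success)                         # res.x is unraveled back to a PyTree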
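
The new abstract base in brainpy/_src/optimizers/base.py expects subclasses to implement initialize() and one_trial(choice_best=...), while minimize(n_iter) drives the trials (tqdm is imported there unconditionally). A hypothetical random-search subclass, invented purely to illustrate the contract:

import numpy as np
from brainpy._src.optimizers.base import Optimizer

class RandomSearch(Optimizer):
    """Toy subclass: sample uniformly inside the bounds and keep the best point."""

    def __init__(self, loss_fun, bounds, n_sample=20):
        self.loss_fun = loss_fun
        self.bounds = np.asarray(bounds)          # shape (n_params, 2)
        self.n_sample = n_sample

    def initialize(self):
        self.best_x, self.best_loss = None, np.inf

    def one_trial(self, choice_best=False):
        lo, hi = self.bounds[:, 0], self.bounds[:, 1]
        xs = np.random.uniform(lo, hi, size=(self.n_sample, len(lo)))
        losses = np.asarray([self.loss_fun(*x) for x in xs])
        i = losses.argmin()
        if losses[i] < self.best_loss:
            self.best_loss, self.best_x = losses[i], xs[i]
        if choice_best:                           # minimize() sets this on the last trial
            return self.best_x

opt = RandomSearch(lambda a, b: (a - 1.0) ** 2 + (b + 2.0) ** 2,
                   bounds=[(-5.0, 5.0), (-5.0, 5.0)])
opt.initialize()
best = opt.minimize(n_iter=10)                    # returns the last trial's result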
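
SkoptBayesOptimizer plugs a scikit-optimize surrogate model into that same initialize/one_trial/minimize loop. A hedged usage sketch, assuming scikit-learn and scikit-optimize are installed; the two-parameter loss is invented:

from brainpy._src.optimizers.skopt_bayesian import SkoptBayesOptimizer

def loss(a, b):
    # receives one array per parameter (a batch of n_sample candidate points)
    return (a - 1.0) ** 2 + (b + 2.0) ** 2

opt = SkoptBayesOptimizer(loss_fun=loss,
                          n_sample=10,
                          bounds=[(-5.0, 5.0), (-5.0, 5.0)],
                          method='GP')
opt.initialize()
best_params = opt.minimize(n_iter=20)             # from the Optimizer base class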
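
The module renames leave the public brainpy.optim surface untouched, and sgd_optimizer.py additionally gains a zero_grad() that, as written in this patch, resets every registered training variable to zeros_like(itself). A small sketch assuming the usual public SGD and MultiStepLR signatures:

import brainpy as bp
import brainpy.math as bm

w = bm.Variable(bm.ones(3))
sch = bp.optim.MultiStepLR(0.1, [10, 20], gamma=0.1)   # same import path as before
opt = bp.optim.SGD(lr=sch, train_vars={'w': w})
opt.zero_grad()                                        # w is now all zeros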