From bde2c1e6fd286062a7b0c3dae09eb20b608a8797 Mon Sep 17 00:00:00 2001 From: Chaoming Wang Date: Sun, 10 Mar 2024 10:57:46 +0800 Subject: [PATCH] move optimization methods in `brainpy.analysis` into `brainpy.optimizers` --- .../_src/analysis/lowdim/lowdim_analyzer.py | 18 +-- brainpy/_src/analysis/utils/__init__.py | 1 - brainpy/_src/analysis/utils/optimization.py | 110 ------------------ brainpy/_src/optimizers/brentq.py | 36 ++++++ 4 files changed, 45 insertions(+), 120 deletions(-) delete mode 100644 brainpy/_src/analysis/utils/optimization.py diff --git a/brainpy/_src/analysis/lowdim/lowdim_analyzer.py b/brainpy/_src/analysis/lowdim/lowdim_analyzer.py index c24f6d591..c75f74e5f 100644 --- a/brainpy/_src/analysis/lowdim/lowdim_analyzer.py +++ b/brainpy/_src/analysis/lowdim/lowdim_analyzer.py @@ -14,7 +14,7 @@ from brainpy._src.analysis import constants as C, utils from brainpy._src.analysis.base import DSAnalyzer from brainpy._src.math.object_transform.base import Collector -from brainpy._src.optimizers.brentq import jax_brentq, ECONVERGED +from brainpy._src.optimizers.brentq import jax_brentq, ECONVERGED, brentq_candidates, brentq_roots pyplot = None @@ -733,15 +733,15 @@ def _get_fx_nullcline_points(self, coords=None, tol=1e-7, num_segments=1, fp_aux if len(par_seg.arg_id_segments[0]) > 1: utils.output(f"{C.prefix}segment {_j} ...") if coords == self.x_var + '-' + self.y_var: - x0s, x1s, vps = utils.brentq_candidates(self.F_vmap_fx, *((xs, ys) + Ps)) - x_values_in_fx, out_args = utils.brentq_roots2(vmap_brentq_f1, x0s, x1s, *vps) + x0s, x1s, vps = brentq_candidates(self.F_vmap_fx, *((xs, ys) + Ps)) + x_values_in_fx, out_args = brentq_roots(vmap_brentq_f1, x0s, x1s, *vps) y_values_in_fx = out_args[0] p_values_in_fx = out_args[1:] x_values_in_fx, y_values_in_fx, p_values_in_fx = \ self._fp_filter(x_values_in_fx, y_values_in_fx, p_values_in_fx, fp_aux_filter) elif coords == self.y_var + '-' + self.x_var: - x0s, x1s, vps = utils.brentq_candidates(vmap_f2, *((ys, xs) + Ps)) - y_values_in_fx, out_args = utils.brentq_roots2(vmap_brentq_f2, x0s, x1s, *vps) + x0s, x1s, vps = brentq_candidates(vmap_f2, *((ys, xs) + Ps)) + y_values_in_fx, out_args = brentq_roots(vmap_brentq_f2, x0s, x1s, *vps) x_values_in_fx = out_args[0] p_values_in_fx = out_args[1:] x_values_in_fx, y_values_in_fx, p_values_in_fx = \ @@ -825,15 +825,15 @@ def _get_fy_nullcline_points(self, coords=None, tol=1e-7, num_segments=1, fp_aux for j, Ps in enumerate(par_seg): if len(par_seg.arg_id_segments[0]) > 1: utils.output(f"{C.prefix}segment {j} ...") if coords == self.x_var + '-' + self.y_var: - starts, ends, vps = utils.brentq_candidates(self.F_vmap_fy, *((xs, ys) + Ps)) - x_values_in_fy, out_args = utils.brentq_roots2(vmap_brentq_f1, starts, ends, *vps) + starts, ends, vps = brentq_candidates(self.F_vmap_fy, *((xs, ys) + Ps)) + x_values_in_fy, out_args = brentq_roots(vmap_brentq_f1, starts, ends, *vps) y_values_in_fy = out_args[0] p_values_in_fy = out_args[1:] x_values_in_fy, y_values_in_fy, p_values_in_fy = \ self._fp_filter(x_values_in_fy, y_values_in_fy, p_values_in_fy, fp_aux_filter) elif coords == self.y_var + '-' + self.x_var: - starts, ends, vps = utils.brentq_candidates(vmap_f2, *((ys, xs) + Ps)) - y_values_in_fy, out_args = utils.brentq_roots2(vmap_brentq_f2, starts, ends, *vps) + starts, ends, vps = brentq_candidates(vmap_f2, *((ys, xs) + Ps)) + y_values_in_fy, out_args = brentq_roots(vmap_brentq_f2, starts, ends, *vps) x_values_in_fy = out_args[0] p_values_in_fy = out_args[1:] x_values_in_fy, y_values_in_fy, p_values_in_fy = \ diff --git a/brainpy/_src/analysis/utils/__init__.py b/brainpy/_src/analysis/utils/__init__.py index be8715821..52b4ceb5b 100644 --- a/brainpy/_src/analysis/utils/__init__.py +++ b/brainpy/_src/analysis/utils/__init__.py @@ -3,7 +3,6 @@ from .function import * from .measurement import * from .model import * -from .optimization import * from .others import * from .outputs import * from .visualization import * diff --git a/brainpy/_src/analysis/utils/optimization.py b/brainpy/_src/analysis/utils/optimization.py deleted file mode 100644 index 28d1cb10f..000000000 --- a/brainpy/_src/analysis/utils/optimization.py +++ /dev/null @@ -1,110 +0,0 @@ -# -*- coding: utf-8 -*- - - -import jax.lax -import jax.numpy as jnp -from jax import vmap - -import brainpy._src.math as bm -from brainpy._src.optimizers.brentq import jax_brentq, ECONVERGED -from . import f_without_jaxarray_return - -__all__ = [ - 'get_brentq_candidates', - 'brentq_candidates', - 'brentq_roots', - 'brentq_roots2', - 'roots_of_1d_by_x', - 'roots_of_1d_by_xy', -] - - -def get_brentq_candidates(f, xs, ys): - f = f_without_jaxarray_return(f) - xs = bm.as_jax(xs) - ys = bm.as_jax(ys) - Y, X = jnp.meshgrid(ys, xs) - vals = f(X, Y) - signs = jnp.sign(vals) - x_ids, y_ids = jnp.where(signs[:-1] * signs[1:] <= 0) - starts = xs[x_ids] - ends = xs[x_ids + 1] - args = ys[y_ids] - return starts, ends, args - - -def brentq_candidates(vmap_f, *values, args=()): - # change the position of meshgrid values - values = tuple((v.value if isinstance(v, bm.Array) else v) for v in values) - xs = values[0] - mesh_values = jnp.meshgrid(*values) - if jnp.ndim(mesh_values[0]) > 1: - mesh_values = tuple(jnp.moveaxis(m, 0, 1) for m in mesh_values) - mesh_values = tuple(m.flatten() for m in mesh_values) - # function outputs - signs = jnp.sign(vmap_f(*(mesh_values + args))) - # compute the selected values - signs = signs.reshape((xs.shape[0], -1)) - par_len = signs.shape[1] - signs1 = signs.at[-1].set(1) # discard the final row - signs2 = jnp.vstack((signs[1:], signs[:1])).at[-1].set(1) # discard the first row - ids = jnp.where((signs1 * signs2).flatten() <= 0)[0] - x_starts = mesh_values[0][ids] - x_ends = mesh_values[0][ids + par_len] - other_vals = tuple(v[ids] for v in mesh_values[1:]) - return x_starts, x_ends, other_vals - - -def brentq_roots(f, starts, ends, *vmap_args, args=()): - in_axes = (0, 0, tuple([0] * len(vmap_args)) + tuple([None] * len(args))) - vmap_f_opt = jax.jit(vmap(jax_brentq(f_without_jaxarray_return(f)), in_axes=in_axes)) - all_args = vmap_args + args - if len(all_args): - res = vmap_f_opt(starts, ends, all_args) - else: - res = vmap_f_opt(starts, ends, ) - valid_idx = jnp.where(res['status'] == ECONVERGED)[0] - roots = res['root'][valid_idx] - vmap_args = tuple(a[valid_idx] for a in vmap_args) - return roots, vmap_args - - -def brentq_roots2(vmap_f, starts, ends, *vmap_args, args=()): - all_args = vmap_args + args - res = vmap_f(starts, ends, all_args) - valid_idx = jnp.where(res['status'] == ECONVERGED)[0] - roots = res['root'][valid_idx] - vmap_args = tuple(a[valid_idx] for a in vmap_args) - return roots, vmap_args - -def roots_of_1d_by_x(f, candidates, args=()): - """Find the roots of the given function by numerical methods. - """ - f = f_without_jaxarray_return(f) - candidates = candidates.value if isinstance(candidates, bm.Array) else candidates - args = tuple(a.value if isinstance(candidates, bm.Array) else a for a in args) - vals = f(candidates, *args) - signs = jnp.sign(vals) - zero_sign_idx = jnp.where(signs == 0)[0] - fps = candidates[zero_sign_idx] - candidate_ids = jnp.where(signs[:-1] * signs[1:] < 0)[0] - if len(candidate_ids) <= 0: - return fps - starts = candidates[candidate_ids] - ends = candidates[candidate_ids + 1] - f_opt = jax.jit(vmap(jax_brentq(f), in_axes=(0, 0, None))) - res = f_opt(starts, ends, args) - valid_idx = jnp.where(res['status'] == ECONVERGED)[0] - fps2 = res['root'][valid_idx] - return jnp.concatenate([fps, fps2]) - - -def roots_of_1d_by_xy(f, starts, ends, args): - f_opt = jax.jit(vmap(jax_brentq(f_without_jaxarray_return(f)))) - res = f_opt(starts, ends, (args,)) - valid_idx = jnp.where(res['status'] == ECONVERGED)[0] - xs = res['root'][valid_idx] - ys = args[valid_idx] - return xs, ys - - diff --git a/brainpy/_src/optimizers/brentq.py b/brainpy/_src/optimizers/brentq.py index 22113114d..000c4739a 100644 --- a/brainpy/_src/optimizers/brentq.py +++ b/brainpy/_src/optimizers/brentq.py @@ -313,3 +313,39 @@ def find_root_of_1d_numpy(f, f_points, args=(), tol=1e-8): idx += 1 return roots + + + +def brentq_candidates(vmap_f, *values, args=()): + # change the position of meshgrid values + values = tuple((v.value if isinstance(v, bm.Array) else v) for v in values) + xs = values[0] + mesh_values = jnp.meshgrid(*values) + if jnp.ndim(mesh_values[0]) > 1: + mesh_values = tuple(jnp.moveaxis(m, 0, 1) for m in mesh_values) + mesh_values = tuple(m.flatten() for m in mesh_values) + # function outputs + signs = jnp.sign(vmap_f(*(mesh_values + args))) + # compute the selected values + signs = signs.reshape((xs.shape[0], -1)) + par_len = signs.shape[1] + signs1 = signs.at[-1].set(1) # discard the final row + signs2 = jnp.vstack((signs[1:], signs[:1])).at[-1].set(1) # discard the first row + ids = jnp.where((signs1 * signs2).flatten() <= 0)[0] + x_starts = mesh_values[0][ids] + x_ends = mesh_values[0][ids + par_len] + other_vals = tuple(v[ids] for v in mesh_values[1:]) + return x_starts, x_ends, other_vals + + + + + +def brentq_roots(vmap_f, starts, ends, *vmap_args, args=()): + all_args = vmap_args + args + res = vmap_f(starts, ends, all_args) + valid_idx = jnp.where(res['status'] == ECONVERGED)[0] + roots = res['root'][valid_idx] + vmap_args = tuple(a[valid_idx] for a in vmap_args) + return roots, vmap_args +