Skip to content

Commit

Permalink
move optimization methods in brainpy.analysis into `brainpy.optimiz…
Browse files Browse the repository at this point in the history
…ers`
  • Loading branch information
chaoming0625 committed Mar 10, 2024
1 parent 799c41f commit bde2c1e
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 120 deletions.
18 changes: 9 additions & 9 deletions brainpy/_src/analysis/lowdim/lowdim_analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from brainpy._src.analysis import constants as C, utils
from brainpy._src.analysis.base import DSAnalyzer
from brainpy._src.math.object_transform.base import Collector
from brainpy._src.optimizers.brentq import jax_brentq, ECONVERGED
from brainpy._src.optimizers.brentq import jax_brentq, ECONVERGED, brentq_candidates, brentq_roots

pyplot = None

Expand Down Expand Up @@ -733,15 +733,15 @@ def _get_fx_nullcline_points(self, coords=None, tol=1e-7, num_segments=1, fp_aux
if len(par_seg.arg_id_segments[0]) > 1:
utils.output(f"{C.prefix}segment {_j} ...")
if coords == self.x_var + '-' + self.y_var:
x0s, x1s, vps = utils.brentq_candidates(self.F_vmap_fx, *((xs, ys) + Ps))
x_values_in_fx, out_args = utils.brentq_roots2(vmap_brentq_f1, x0s, x1s, *vps)
x0s, x1s, vps = brentq_candidates(self.F_vmap_fx, *((xs, ys) + Ps))
x_values_in_fx, out_args = brentq_roots(vmap_brentq_f1, x0s, x1s, *vps)
y_values_in_fx = out_args[0]
p_values_in_fx = out_args[1:]
x_values_in_fx, y_values_in_fx, p_values_in_fx = \
self._fp_filter(x_values_in_fx, y_values_in_fx, p_values_in_fx, fp_aux_filter)
elif coords == self.y_var + '-' + self.x_var:
x0s, x1s, vps = utils.brentq_candidates(vmap_f2, *((ys, xs) + Ps))
y_values_in_fx, out_args = utils.brentq_roots2(vmap_brentq_f2, x0s, x1s, *vps)
x0s, x1s, vps = brentq_candidates(vmap_f2, *((ys, xs) + Ps))
y_values_in_fx, out_args = brentq_roots(vmap_brentq_f2, x0s, x1s, *vps)
x_values_in_fx = out_args[0]
p_values_in_fx = out_args[1:]
x_values_in_fx, y_values_in_fx, p_values_in_fx = \
Expand Down Expand Up @@ -825,15 +825,15 @@ def _get_fy_nullcline_points(self, coords=None, tol=1e-7, num_segments=1, fp_aux
for j, Ps in enumerate(par_seg):
if len(par_seg.arg_id_segments[0]) > 1: utils.output(f"{C.prefix}segment {j} ...")
if coords == self.x_var + '-' + self.y_var:
starts, ends, vps = utils.brentq_candidates(self.F_vmap_fy, *((xs, ys) + Ps))
x_values_in_fy, out_args = utils.brentq_roots2(vmap_brentq_f1, starts, ends, *vps)
starts, ends, vps = brentq_candidates(self.F_vmap_fy, *((xs, ys) + Ps))
x_values_in_fy, out_args = brentq_roots(vmap_brentq_f1, starts, ends, *vps)
y_values_in_fy = out_args[0]
p_values_in_fy = out_args[1:]
x_values_in_fy, y_values_in_fy, p_values_in_fy = \
self._fp_filter(x_values_in_fy, y_values_in_fy, p_values_in_fy, fp_aux_filter)
elif coords == self.y_var + '-' + self.x_var:
starts, ends, vps = utils.brentq_candidates(vmap_f2, *((ys, xs) + Ps))
y_values_in_fy, out_args = utils.brentq_roots2(vmap_brentq_f2, starts, ends, *vps)
starts, ends, vps = brentq_candidates(vmap_f2, *((ys, xs) + Ps))
y_values_in_fy, out_args = brentq_roots(vmap_brentq_f2, starts, ends, *vps)
x_values_in_fy = out_args[0]
p_values_in_fy = out_args[1:]
x_values_in_fy, y_values_in_fy, p_values_in_fy = \
Expand Down
1 change: 0 additions & 1 deletion brainpy/_src/analysis/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from .function import *
from .measurement import *
from .model import *
from .optimization import *
from .others import *
from .outputs import *
from .visualization import *
110 changes: 0 additions & 110 deletions brainpy/_src/analysis/utils/optimization.py

This file was deleted.

36 changes: 36 additions & 0 deletions brainpy/_src/optimizers/brentq.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,3 +313,39 @@ def find_root_of_1d_numpy(f, f_points, args=(), tol=1e-8):
idx += 1

return roots



def brentq_candidates(vmap_f, *values, args=()):
# change the position of meshgrid values
values = tuple((v.value if isinstance(v, bm.Array) else v) for v in values)
xs = values[0]
mesh_values = jnp.meshgrid(*values)
if jnp.ndim(mesh_values[0]) > 1:
mesh_values = tuple(jnp.moveaxis(m, 0, 1) for m in mesh_values)
mesh_values = tuple(m.flatten() for m in mesh_values)
# function outputs
signs = jnp.sign(vmap_f(*(mesh_values + args)))
# compute the selected values
signs = signs.reshape((xs.shape[0], -1))
par_len = signs.shape[1]
signs1 = signs.at[-1].set(1) # discard the final row
signs2 = jnp.vstack((signs[1:], signs[:1])).at[-1].set(1) # discard the first row
ids = jnp.where((signs1 * signs2).flatten() <= 0)[0]
x_starts = mesh_values[0][ids]
x_ends = mesh_values[0][ids + par_len]
other_vals = tuple(v[ids] for v in mesh_values[1:])
return x_starts, x_ends, other_vals





def brentq_roots(vmap_f, starts, ends, *vmap_args, args=()):
all_args = vmap_args + args
res = vmap_f(starts, ends, all_args)
valid_idx = jnp.where(res['status'] == ECONVERGED)[0]
roots = res['root'][valid_idx]
vmap_args = tuple(a[valid_idx] for a in vmap_args)
return roots, vmap_args

0 comments on commit bde2c1e

Please sign in to comment.