Skip to content

Commit

Permalink
upgrade optimizers
Browse files Browse the repository at this point in the history
  • Loading branch information
chaoming0625 committed Mar 10, 2024
1 parent 9ca2d6c commit 799c41f
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 16 deletions.
1 change: 1 addition & 0 deletions brainpy/_src/optimizers/brentq.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def jax_brentq(fun):
# else:
# rtol = 1.5 * jnp.finfo(jnp.float32).eps

@jax.jit
def x(a, b, args=(), xtol=2e-14, maxiter=200):
# Convert to float
xpre = a * 1.0
Expand Down
39 changes: 23 additions & 16 deletions brainpy/_src/optimizers/scipy_minimize.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import jax
import jax.numpy as jnp
import numpy as np
import scipy.optimize as soptimize
from jax import grad, jit
from jax.flatten_util import ravel_pytree

import brainpy._src.math as bm
import brainpy.math as bm
from brainpy import errors

soptimize = None


def scipy_minimize_with_jax(
fun, x0,
Expand Down Expand Up @@ -128,9 +129,13 @@ def scipy_minimize_with_jax(
See `scipy.optimize.OptimizeResult` for a description of other attributes.
"""
global soptimize
if soptimize is None:
raise errors.PackageMissingError(f'"scipy" must be installed when user want to use '
f'function: {scipy_minimize_with_jax}')
try:
import scipy.optimize as soptimize
except ImportError:
raise errors.PackageMissingError(f'"scipy" must be installed when user want to use '
f'function: {scipy_minimize_with_jax}')

# Use tree flatten and unflatten to convert params x0 from PyTrees to flat arrays
x0_flat, unravel = ravel_pytree(x0)
Expand All @@ -144,7 +149,7 @@ def fun_wrapper(x_flat, *args):
return float(r)

# Wrap the gradient in a similar manner
jac = jit(grad(fun))
jac = jax.jit(jax.grad(fun))

def jac_wrapper(x_flat, *args):
x = unravel(x_flat)
Expand All @@ -158,16 +163,18 @@ def callback_wrapper(x_flat, *args):
return callback(x, *args)

# Minimize with scipy
results = soptimize.minimize(fun_wrapper,
x0_flat,
args=args,
method=method,
jac=jac_wrapper,
callback=callback_wrapper,
bounds=bounds,
constraints=constraints,
tol=tol,
options=options)
results = soptimize.minimize(
fun_wrapper,
x0_flat,
args=args,
method=method,
jac=jac_wrapper,
callback=callback_wrapper,
bounds=bounds,
constraints=constraints,
tol=tol,
options=options
)

# pack the output back into a PyTree
results["x"] = unravel(results["x"])
Expand Down

0 comments on commit 799c41f

Please sign in to comment.