From 799c41f9772e0342b6d00cc3a7bb11ad24c55420 Mon Sep 17 00:00:00 2001 From: Chaoming Wang Date: Sun, 10 Mar 2024 10:53:09 +0800 Subject: [PATCH] upgrade optimizers --- brainpy/_src/optimizers/brentq.py | 1 + brainpy/_src/optimizers/scipy_minimize.py | 39 +++++++++++++---------- 2 files changed, 24 insertions(+), 16 deletions(-) diff --git a/brainpy/_src/optimizers/brentq.py b/brainpy/_src/optimizers/brentq.py index 4e281ec3..22113114 100644 --- a/brainpy/_src/optimizers/brentq.py +++ b/brainpy/_src/optimizers/brentq.py @@ -37,6 +37,7 @@ def jax_brentq(fun): # else: # rtol = 1.5 * jnp.finfo(jnp.float32).eps + @jax.jit def x(a, b, args=(), xtol=2e-14, maxiter=200): # Convert to float xpre = a * 1.0 diff --git a/brainpy/_src/optimizers/scipy_minimize.py b/brainpy/_src/optimizers/scipy_minimize.py index 95a8844f..bfd810f1 100644 --- a/brainpy/_src/optimizers/scipy_minimize.py +++ b/brainpy/_src/optimizers/scipy_minimize.py @@ -1,12 +1,13 @@ +import jax import jax.numpy as jnp import numpy as np -import scipy.optimize as soptimize -from jax import grad, jit from jax.flatten_util import ravel_pytree -import brainpy._src.math as bm +import brainpy.math as bm from brainpy import errors +soptimize = None + def scipy_minimize_with_jax( fun, x0, @@ -128,9 +129,13 @@ def scipy_minimize_with_jax( See `scipy.optimize.OptimizeResult` for a description of other attributes. """ + global soptimize if soptimize is None: - raise errors.PackageMissingError(f'"scipy" must be installed when user want to use ' - f'function: {scipy_minimize_with_jax}') + try: + import scipy.optimize as soptimize + except ImportError: + raise errors.PackageMissingError(f'"scipy" must be installed when user want to use ' + f'function: {scipy_minimize_with_jax}') # Use tree flatten and unflatten to convert params x0 from PyTrees to flat arrays x0_flat, unravel = ravel_pytree(x0) @@ -144,7 +149,7 @@ def fun_wrapper(x_flat, *args): return float(r) # Wrap the gradient in a similar manner - jac = jit(grad(fun)) + jac = jax.jit(jax.grad(fun)) def jac_wrapper(x_flat, *args): x = unravel(x_flat) @@ -158,16 +163,18 @@ def callback_wrapper(x_flat, *args): return callback(x, *args) # Minimize with scipy - results = soptimize.minimize(fun_wrapper, - x0_flat, - args=args, - method=method, - jac=jac_wrapper, - callback=callback_wrapper, - bounds=bounds, - constraints=constraints, - tol=tol, - options=options) + results = soptimize.minimize( + fun_wrapper, + x0_flat, + args=args, + method=method, + jac=jac_wrapper, + callback=callback_wrapper, + bounds=bounds, + constraints=constraints, + tol=tol, + options=options + ) # pack the output back into a PyTree results["x"] = unravel(results["x"])