Merge pull request #55 from invrs-io/next_scipy
support scipy >= 1.15.0
mfschubert authored Jan 16, 2025
2 parents fe3761c + d0919ba commit 9cda27e
Showing 6 changed files with 82 additions and 108 deletions.
2 changes: 1 addition & 1 deletion .bumpversion.toml
@@ -1,5 +1,5 @@
[tool.bumpversion]
current_version = "v0.10.4"
current_version = "v0.10.5"
commit = true
commit_args = "--no-verify"
tag = true
2 changes: 1 addition & 1 deletion .github/workflows/build-ci.yml
@@ -52,7 +52,7 @@ jobs:
- name: darglint docstring validation
run: darglint src --strictness=short --ignore-raise=ValueError

tests-jax-laatest:
tests-jax-latest:
runs-on: ubuntu-latest
timeout-minutes: 5
steps:
4 changes: 2 additions & 2 deletions pyproject.toml
@@ -1,7 +1,7 @@
[project]

name = "invrs_opt"
version = "v0.10.4"
version = "v0.10.5"
description = "Algorithms for inverse design"
keywords = ["topology", "optimization", "jax", "inverse design"]
readme = "README.md"
@@ -21,7 +21,7 @@ dependencies = [
"numpy",
"requests",
"optax",
"scipy < 1.15.0",
"scipy >= 1.15.0",
"totypes",
"types-requests",
]
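
The dependency pin flips from `scipy < 1.15.0` to `scipy >= 1.15.0`: the L-BFGS-B wrapper in this package now targets the interface introduced in scipy 1.15. A minimal import-time guard, shown only as a sketch (not part of this commit; the message text is illustrative):

import scipy
from packaging import version

if version.Version(scipy.__version__) < version.Version("1.15.0"):
    raise ImportError(
        f"This build of invrs_opt requires scipy >= 1.15.0, found {scipy.__version__}."
    )
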
2 changes: 1 addition & 1 deletion src/invrs_opt/__init__.py
@@ -3,7 +3,7 @@
Copyright (c) 2023 The INVRS-IO authors.
"""

__version__ = "v0.10.4"
__version__ = "v0.10.5"
__author__ = "Martin F. Schubert <[email protected]>"

from invrs_opt import parameterization as parameterization
178 changes: 76 additions & 102 deletions src/invrs_opt/optimizers/lbfgsb.py
@@ -56,8 +56,6 @@
(True, False): 3, # Only the lower bound is `None`.
}

FORTRAN_INT = scipy_lbfgsb.types.intvar.dtype

if version.Version(jax.__version__) > version.Version("0.4.31"):
callback_sequential = functools.partial(jax.pure_callback, vmap_method="sequential")
else:
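
Two things are visible in this hunk: the `FORTRAN_INT` alias (previously read from `scipy_lbfgsb.types.intvar.dtype`) is removed, with the integer work arrays below switching to plain int32, and the existing `jax.pure_callback` shim is kept, selecting `vmap_method="sequential"` on jax newer than 0.4.31. A rough sketch of how such a shim is used, assuming the module's `callback_sequential`; the host-side function and shapes here are hypothetical, not taken from this file:

import jax
import numpy as onp

def _host_side_step(x):
    # Stand-in for a stateful numpy/scipy computation that must run on the host.
    return onp.asarray(x) + 1.0

@jax.jit
def step(x):
    # Under vmap, the "sequential" method invokes the callback once per batch element.
    return callback_sequential(
        _host_side_step,
        jax.ShapeDtypeStruct(x.shape, x.dtype),  # expected result shape/dtype
        x,
    )
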
@@ -638,19 +636,19 @@ def _example_state(params: PyTree, maxcor: int) -> PyTree:
x=jnp.zeros(n, dtype=float),
converged=jnp.asarray(False),
_maxcor=jnp.zeros((), dtype=int),
_line_search_max_steps=jnp.zeros((), dtype=int),
_lower_bound=jnp.zeros(n, dtype=float),
_upper_bound=jnp.zeros(n, dtype=float),
_bound_type=jnp.zeros(n, dtype=jnp.int32),
_ftol=jnp.zeros((), dtype=float),
_gtol=jnp.zeros((), dtype=float),
_wa=jnp.ones(_wa_size(n=n, maxcor=maxcor), dtype=float),
_iwa=jnp.ones(n * 3, dtype=jnp.int32), # Fortran int
_task=jnp.zeros(59, dtype=int),
_csave=jnp.zeros(59, dtype=int),
_lsave=jnp.zeros(4, dtype=jnp.int32), # Fortran int
_isave=jnp.zeros(44, dtype=jnp.int32), # Fortran int
_iwa=jnp.ones(n * 3, dtype=jnp.int32),
_task=jnp.zeros(2, dtype=jnp.int32),
_lsave=jnp.zeros(4, dtype=jnp.int32),
_isave=jnp.zeros(44, dtype=jnp.int32),
_dsave=jnp.zeros(29, dtype=float),
_lower_bound=jnp.zeros(n, dtype=float),
_upper_bound=jnp.zeros(n, dtype=float),
_bound_type=jnp.zeros(n, dtype=int),
_ln_task=jnp.zeros(2, dtype=jnp.int32),
_line_search_max_steps=jnp.zeros((), dtype=int),
)
return float_params, example_jax_lbfgsb_state
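
In the updated example state, the "S60" string buffers (`_task`, `_csave`) are gone; scipy >= 1.15 communicates through small integer arrays instead (`_task` and the new `_ln_task`, both int32), and the remaining integer work arrays (`_iwa`, `_lsave`, `_isave`) drop the Fortran-specific dtype in favor of int32. A hedged sketch of how a zero-filled example pytree like this is typically consumed; the host update function and its arguments are hypothetical:

# Sketch only: the example state supplies the result shapes/dtypes a host
# callback promises to return; pure_callback accepts any pytree of
# shape/dtype-carrying objects, such as these zero-filled arrays.
_, example_state = _example_state(params, maxcor=20)
new_params, new_state = callback_sequential(
    _host_update_fn,            # hypothetical numpy-side L-BFGS-B update
    (params, example_state),    # result pytree: shapes and dtypes to expect
    params, state,              # arguments passed through to the host
)
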

@@ -691,56 +689,57 @@ class ScipyLbfgsbState:

x: NDArray
converged: NDArray
# Private attributes correspond to internal variables in the `scipy.optimize.
# lbfgsb._minimize_lbfgsb` function.
# Private attributes correspond to internal variables in the
# `scipy.optimize.lbfgsb._minimize_lbfgsb` function.
_maxcor: int
_line_search_max_steps: int
_lower_bound: NDArray
_upper_bound: NDArray
_bound_type: NDArray
_ftol: NDArray
_gtol: NDArray
_wa: NDArray
_iwa: NDArray
_task: NDArray
_csave: NDArray
_lsave: NDArray
_isave: NDArray
_dsave: NDArray
_lower_bound: NDArray
_upper_bound: NDArray
_bound_type: NDArray
_line_search_max_steps: int
_ln_task: NDArray

def __post_init__(self) -> None:
"""Validates the datatypes for all state attributes."""
_validate_array_dtype(self.x, onp.float64)
_validate_array_dtype(self._wa, onp.float64)
_validate_array_dtype(self._iwa, FORTRAN_INT)
_validate_array_dtype(self._task, "S60")
_validate_array_dtype(self._csave, "S60")
_validate_array_dtype(self._lsave, FORTRAN_INT)
_validate_array_dtype(self._isave, FORTRAN_INT)
_validate_array_dtype(self._dsave, onp.float64)
_validate_array_dtype(self._lower_bound, onp.float64)
_validate_array_dtype(self._upper_bound, onp.float64)
_validate_array_dtype(self._bound_type, int)
_validate_array_dtype("x", self.x, onp.float64)
_validate_array_dtype("_lower_bound", self._lower_bound, onp.float64)
_validate_array_dtype("_upper_bound", self._upper_bound, onp.float64)
_validate_array_dtype("_ftol", self._ftol, onp.float64)
_validate_array_dtype("_gtol", self._gtol, onp.float64)
_validate_array_dtype("_wa", self._wa, onp.float64)
_validate_array_dtype("_iwa", self._iwa, onp.int32)
_validate_array_dtype("_task", self._task, onp.int32)
_validate_array_dtype("_lsave", self._lsave, onp.int32)
_validate_array_dtype("_isave", self._isave, onp.int32)
_validate_array_dtype("_dsave", self._dsave, onp.float64)
_validate_array_dtype("_ln_task", self._ln_task, onp.int32)

def to_dict(self) -> NumpyLbfgsbDict:
"""Generates a dictionary of jax arrays defining the state."""
return dict(
x=onp.asarray(self.x),
converged=onp.asarray(self.converged),
_maxcor=onp.asarray(self._maxcor),
_line_search_max_steps=onp.asarray(self._line_search_max_steps),
_lower_bound=onp.asarray(self._lower_bound),
_upper_bound=onp.asarray(self._upper_bound),
_bound_type=onp.asarray(self._bound_type),
_ftol=onp.asarray(self._ftol),
_gtol=onp.asarray(self._gtol),
_wa=onp.asarray(self._wa),
_iwa=onp.asarray(self._iwa),
_task=_array_from_s60_str(self._task),
_csave=_array_from_s60_str(self._csave),
_task=onp.asarray(self._task),
_lsave=onp.asarray(self._lsave),
_isave=onp.asarray(self._isave),
_dsave=onp.asarray(self._dsave),
_lower_bound=onp.asarray(self._lower_bound),
_upper_bound=onp.asarray(self._upper_bound),
_bound_type=onp.asarray(self._bound_type),
_line_search_max_steps=onp.asarray(self._line_search_max_steps),
_ln_task=onp.asarray(self._ln_task),
)

@classmethod
@@ -750,19 +749,19 @@ def from_jax(cls, state_dict: JaxLbfgsbDict) -> "ScipyLbfgsbState":
x=onp.array(state_dict["x"], dtype=onp.float64),
converged=onp.asarray(state_dict["converged"], dtype=bool),
_maxcor=int(state_dict["_maxcor"]),
_line_search_max_steps=int(state_dict["_line_search_max_steps"]),
_lower_bound=onp.asarray(state_dict["_lower_bound"], dtype=onp.float64),
_upper_bound=onp.asarray(state_dict["_upper_bound"], dtype=onp.float64),
_bound_type=onp.asarray(state_dict["_bound_type"], dtype=onp.int32),
_ftol=onp.asarray(state_dict["_ftol"], dtype=onp.float64),
_gtol=onp.asarray(state_dict["_gtol"], dtype=onp.float64),
_wa=onp.array(state_dict["_wa"], onp.float64),
_iwa=onp.array(state_dict["_iwa"], dtype=FORTRAN_INT),
_task=_s60_str_from_array(onp.asarray(state_dict["_task"])),
_csave=_s60_str_from_array(onp.asarray(state_dict["_csave"])),
_lsave=onp.array(state_dict["_lsave"], dtype=FORTRAN_INT),
_isave=onp.array(state_dict["_isave"], dtype=FORTRAN_INT),
_iwa=onp.array(state_dict["_iwa"], dtype=onp.int32),
_task=onp.asarray(state_dict["_task"], dtype=onp.int32),
_lsave=onp.array(state_dict["_lsave"], dtype=onp.int32),
_isave=onp.array(state_dict["_isave"], dtype=onp.int32),
_dsave=onp.array(state_dict["_dsave"], dtype=onp.float64),
_lower_bound=onp.asarray(state_dict["_lower_bound"], dtype=onp.float64),
_upper_bound=onp.asarray(state_dict["_upper_bound"], dtype=onp.float64),
_bound_type=onp.asarray(state_dict["_bound_type"], dtype=int),
_line_search_max_steps=int(state_dict["_line_search_max_steps"]),
_ln_task=onp.asarray(state_dict["_ln_task"], onp.int32),
)

@classmethod
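
With every field now a plain numeric array, `to_dict` and `from_jax` reduce to `onp.asarray` calls plus dtype casts; the "S60"-string round trip through `_array_from_s60_str` / `_s60_str_from_array` is no longer needed. A small round-trip sketch, assuming `state` is an existing `ScipyLbfgsbState` instance:

state_dict = state.to_dict()                        # plain numpy arrays
restored = ScipyLbfgsbState.from_jax(state_dict)
assert restored._task.dtype == onp.int32            # integer task buffer, no "S60" strings
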
@@ -792,7 +791,6 @@ def init(
Returns:
The `ScipyLbfgsbState`.
"""
x0 = onp.asarray(x0)
if x0.ndim > 1:
raise ValueError(f"`x0` must be rank-1 but got shape {x0.shape}.")
lower_bound = onp.asarray(lower_bound)
@@ -810,8 +808,6 @@
lower_bound_array, upper_bound_array, bound_type = _configure_bounds(
lower_bound, upper_bound
)
task = onp.zeros(1, "S60")
task[:] = TASK_START

# See initialization of internal variables in the `lbfgsb._minimize_lbfgsb`
# function.
@@ -820,31 +816,27 @@
x=onp.array(x0, onp.float64),
converged=onp.asarray(False),
_maxcor=maxcor,
_line_search_max_steps=line_search_max_steps,
_lower_bound=lower_bound_array,
_upper_bound=upper_bound_array,
_bound_type=bound_type,
_ftol=onp.asarray(ftol, onp.float64),
_gtol=onp.asarray(gtol, onp.float64),
_wa=onp.zeros(wa_size, onp.float64),
_iwa=onp.zeros(3 * n, FORTRAN_INT),
_task=task,
_csave=onp.zeros(1, "S60"),
_lsave=onp.zeros(4, FORTRAN_INT),
_isave=onp.zeros(44, FORTRAN_INT),
_iwa=onp.zeros(3 * n, onp.int32),
_task=onp.zeros(2, onp.int32),
_lsave=onp.zeros(4, onp.int32),
_isave=onp.zeros(44, onp.int32),
_dsave=onp.zeros(29, onp.float64),
_lower_bound=lower_bound_array,
_upper_bound=upper_bound_array,
_bound_type=bound_type,
_line_search_max_steps=line_search_max_steps,
_ln_task=onp.zeros(2, onp.int32),
)
# The initial state requires an update with zero value and gradient. This
# is because the initial task is "START", which does not actually require
# value and gradient evaluation.
state.update(onp.zeros(x0.shape, onp.float64), onp.zeros((), onp.float64))
return state

def update(
self,
grad: NDArray,
value: NDArray,
) -> None:
def update(self, grad: NDArray, value: NDArray) -> None:
"""Performs an in-place update of the `ScipyLbfgsbState` if not converged.
Args:
@@ -866,29 +858,27 @@ def update(
# again, advancing past such "dummy" steps.
for _ in range(3):
scipy_lbfgsb.setulb(
m=self._maxcor,
x=self.x,
l=self._lower_bound,
u=self._upper_bound,
nbd=self._bound_type,
f=value,
g=grad,
factr=self._ftol / onp.finfo(float).eps,
pgtol=self._gtol,
wa=self._wa,
iwa=self._iwa,
task=self._task,
iprint=UPDATE_IPRINT,
csave=self._csave,
lsave=self._lsave,
isave=self._isave,
dsave=self._dsave,
maxls=self._line_search_max_steps,
self._maxcor, # m
self.x, # x
self._lower_bound, # low_bnd
self._upper_bound, # upper_bnd
self._bound_type, # nbnd
value, # f
grad, # g
self._ftol / onp.finfo(float).eps, # factr
self._gtol, # pgtol
self._wa, # wa
self._iwa, # iwa
self._task, # task
self._lsave, # lsave
self._isave, # isave
self._dsave, # dsave
self._line_search_max_steps, # maxls
self._ln_task, # ln_task
)
task_str = self._task.tobytes()
if task_str.startswith(TASK_CONVERGED):
if self._task[0] == 4:
self.converged = onp.asarray(True)
if task_str.startswith(TASK_FG):
if self._task[0] == 3:
break
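
Together, `init` and `update` drive scipy's reverse-communication loop using the integer task codes visible above: `_task[0] == 3` means the driver wants a fresh value and gradient, while `_task[0] == 4` signals convergence. A standalone usage sketch; the objective, bounds, and tolerances are invented for illustration, and the argument names follow those used elsewhere in this file:

import numpy as onp

n = 5
state = ScipyLbfgsbState.init(
    x0=onp.zeros(n),
    lower_bound=[0.0] * n,
    upper_bound=[2.0] * n,
    maxcor=20,
    line_search_max_steps=100,
    ftol=1e-12,
    gtol=1e-8,
)
for _ in range(200):
    if state.converged:
        break
    value = onp.sum((state.x - 1.0) ** 2)    # f(x) = ||x - 1||^2, minimized at x = 1
    grad = 2.0 * (state.x - 1.0)
    state.update(grad=grad, value=onp.asarray(value))
# On exit, state.x should sit near 1.0 in every coordinate.
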


@@ -897,12 +887,12 @@ def _wa_size(n: int, maxcor: int) -> int:
return 2 * maxcor * n + 5 * n + 11 * maxcor**2 + 8 * maxcor


def _validate_array_dtype(x: NDArray, dtype: Union[type, str]) -> None:
def _validate_array_dtype(name: str, x: NDArray, dtype: type) -> None:
"""Validates that `x` is an array with the specified `dtype`."""
if not isinstance(x, onp.ndarray):
raise ValueError(f"`x` must be an `onp.ndarray` but got {type(x)}")
raise ValueError(f"`{name}` must be an `onp.ndarray` but got {type(x)}")
if x.dtype != dtype:
raise ValueError(f"`x` must have dtype {dtype} but got {x.dtype}")
raise ValueError(f"`{name}` must have dtype {dtype} but got {x.dtype}")


def _configure_bounds(
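
Two quick sketches for the helpers above, with arbitrary illustrative values. `_wa_size` mirrors the working-array size used by scipy's `_minimize_lbfgsb`, and `_validate_array_dtype` now also takes the attribute name so the error message identifies the offending field:

# Worked example of the workspace formula for n = 1000 and maxcor = 20:
#   2*20*1000 + 5*1000 + 11*20**2 + 8*20 = 40000 + 5000 + 4400 + 160 = 49560
assert _wa_size(n=1000, maxcor=20) == 49560

_validate_array_dtype("_task", onp.zeros(2, dtype=onp.int32), onp.int32)  # passes silently
_validate_array_dtype("_task", onp.zeros(2, dtype=onp.int64), onp.int32)  # raises ValueError naming `_task`
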
@@ -919,21 +909,5 @@
return (
onp.asarray(lower_bound_array, onp.float64),
onp.asarray(upper_bound_array, onp.float64),
onp.asarray(bound_type),
)


def _array_from_s60_str(s60_str: NDArray) -> NDArray:
"""Return a jax array for a numpy s60 string."""
assert s60_str.shape == (1,)
chars = [int(o) for o in s60_str[0]]
chars.extend([32] * (59 - len(chars)))
return onp.asarray(chars, dtype=int)


def _s60_str_from_array(array: NDArray) -> NDArray:
"""Return a numpy s60 string for a jax array."""
return onp.asarray(
[b"".join(int(i).to_bytes(length=1, byteorder="big") for i in array)],
dtype="S60",
onp.asarray(bound_type, onp.int32),
)
2 changes: 1 addition & 1 deletion tests/optimizers/test_lbfgsb.py
@@ -539,7 +539,7 @@ def test_converter_numpy_pytree(self):
onp.testing.assert_array_equal(a, b)


class ScipyLbfgsStateTest(unittest.TestCase):
class ScipyLbfgsbStateTest(unittest.TestCase):
def test_x0_shape_validation(self):
with self.assertRaisesRegex(ValueError, "`x0` must be rank-1 but got"):
lbfgsb.ScipyLbfgsbState.init(
