Skip to content

Commit

Permalink
Merge pull request #53 from invrs-io/back
Browse files Browse the repository at this point in the history
support jax >=0.4.31
  • Loading branch information
mfschubert authored Jan 15, 2025
2 parents f090995 + 90c60bb commit 33db47f
Show file tree
Hide file tree
Showing 6 changed files with 40 additions and 13 deletions.
5 changes: 1 addition & 4 deletions .bumpversion.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[tool.bumpversion]
current_version = "v0.10.2"
current_version = "v0.10.3"
commit = true
commit_args = "--no-verify"
tag = true
Expand All @@ -15,8 +15,5 @@ message = "Version updated from {current_version} to {new_version}"
[[tool.bumpversion.files]]
filename = "pyproject.toml"

[[tool.bumpversion.files]]
filename = "README.md"

[[tool.bumpversion.files]]
filename = "src/invrs_opt/__init__.py"
26 changes: 25 additions & 1 deletion .github/workflows/build-ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ jobs:
- name: darglint docstring validation
run: darglint src --strictness=short --ignore-raise=ValueError

tests:
tests-jax-laatest:
runs-on: ubuntu-latest
timeout-minutes: 5
steps:
Expand All @@ -74,3 +74,27 @@ jobs:
- name: Run Python tests
run: |
pytest --cov=invrs_opt tests
tests-jax-0_4_31:
runs-on: ubuntu-latest
timeout-minutes: 5
steps:
- name: Checkout repository
uses: actions/checkout@v4

- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: "3.10"
cache: "pip"
cache-dependency-path: pyproject.toml

- name: Setup environment
run: |
python -m pip install --upgrade pip
pip install ".[tests,dev]"
pip install --upgrade "jax==0.4.31"
- name: Run Python tests
run: |
pytest --cov=invrs_opt tests
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# invrs-opt - Optimization algorithms for inverse design
`v0.10.2`
![Continuous integration](https://github.com/invrs-io/opt/actions/workflows/build-ci.yml/badge.svg)
![PyPI version](https://img.shields.io/pypi/v/invrs-opt)

## Overview

Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[project]

name = "invrs_opt"
version = "v0.10.2"
version = "v0.10.3"
description = "Algorithms for inverse design"
keywords = ["topology", "optimization", "jax", "inverse design"]
readme = "README.md"
Expand All @@ -16,7 +16,7 @@ maintainers = [
]

dependencies = [
"jax >= 0.4.35",
"jax >= 0.4.31",
"jaxlib",
"numpy",
"requests",
Expand Down
2 changes: 1 addition & 1 deletion src/invrs_opt/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
Copyright (c) 2023 The INVRS-IO authors.
"""

__version__ = "v0.10.2"
__version__ = "v0.10.3"
__author__ = "Martin F. Schubert <[email protected]>"

from invrs_opt import parameterization as parameterization
Expand Down
13 changes: 9 additions & 4 deletions src/invrs_opt/optimizers/lbfgsb.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
"""

import dataclasses
import functools
from packaging import version
from typing import Any, Dict, Optional, Sequence, Tuple, Union

import jax
Expand Down Expand Up @@ -56,6 +58,11 @@

FORTRAN_INT = scipy_lbfgsb.types.intvar.dtype

if version.Version(jax.__version__) > version.Version("0.4.31"):
callback_sequential = functools.partial(jax.pure_callback, vmap_method="sequential")
else:
callback_sequential = functools.partial(jax.pure_callback, vectorized=False)


def lbfgsb(
*,
Expand Down Expand Up @@ -316,11 +323,10 @@ def _init_state_pure(latent_params: PyTree) -> Tuple[PyTree, NumpyLbfgsbDict]:

latent_params = _init_latents(params)
metadata, latents = param_base.partition_density_metadata(latent_params)
latents, jax_lbfgsb_state = jax.pure_callback(
latents, jax_lbfgsb_state = callback_sequential(
_init_state_pure,
_example_state(latents, maxcor),
latents,
vmap_method="sequential",
)
latent_params = param_base.combine_density_metadata(metadata, latents)
return (
Expand Down Expand Up @@ -401,13 +407,12 @@ def _constraint_loss_latents(latents: PyTree) -> jnp.ndarray:
latents_grad
) # type: ignore[no-untyped-call]

flat_latent_updates, jax_lbfgsb_state = jax.pure_callback(
flat_latent_updates, jax_lbfgsb_state = callback_sequential(
_update_pure,
(flat_latents_grad, jax_lbfgsb_state),
flat_latents_grad,
value,
jax_lbfgsb_state,
vmap_method="sequential",
)
latent_updates = unflatten_fn(flat_latent_updates)
latent_params = _apply_updates(
Expand Down

0 comments on commit 33db47f

Please sign in to comment.